Coverage for packages/registry/src/langgate/registry/models.py: 60%

75 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-04-09 21:23 +0000

1"""Model information and registry implementation.""" 

2 

3import os 

4from typing import Any 

5 

6from langgate.core.logging import get_logger 

7from langgate.core.models import ( 

8 ContextWindow, 

9 LLMInfo, 

10 ModelCapabilities, 

11 ModelCost, 

12 ModelProvider, 

13 ModelProviderId, 

14) 

15from langgate.registry.config import RegistryConfig 

16 

17logger = get_logger(__name__) 

18 

19 

class ModelRegistry:
    """Registry for managing model information."""

    # Singleton instance shared by every construction of ModelRegistry.
    _instance = None

    def __new__(cls, *args, **kwargs):
        """Create or return the singleton instance."""
        if cls._instance is not None:
            # An instance already exists — hand it back and note whether
            # __init__ has completed on it yet.
            logger.debug(
                "reusing_registry_singleton",
                initialized=hasattr(cls._instance, "_initialized"),
            )
            return cls._instance

        logger.debug("creating_model_registry_singleton")
        cls._instance = super().__new__(cls)
        return cls._instance

36 

37 def __init__(self, config: RegistryConfig | None = None): 

38 """Initialize the registry. 

39 

40 Args: 

41 config: Optional configuration object, will create a new one if not provided 

42 """ 

43 if not hasattr(self, "_initialized"): 

44 self.config = config or RegistryConfig() 

45 

46 # Cached model info objects 

47 self._models_cache: dict[str, LLMInfo] = {} 

48 

49 # Build the model cache 

50 try: 

51 self._build_model_cache() 

52 except Exception: 

53 logger.exception( 

54 "failed_to_build_model_cache", 

55 models_data_path=str(self.config.models_data_path), 

56 config_path=str(self.config.config_path), 

57 env_file_path=str(self.config.env_file_path), 

58 env_file_exists=self.config.env_file_path.exists(), 

59 ) 

60 if not self._models_cache: # Only raise if we have no data 

61 raise 

62 logger.debug( 

63 "initialized_model_registry_singleton", 

64 models_data_path=str(self.config.models_data_path), 

65 config_path=str(self.config.config_path), 

66 env_file_path=str(self.config.env_file_path), 

67 env_file_exists=self.config.env_file_path.exists(), 

68 ) 

69 self._initialized = True 

70 

71 def _build_model_cache(self) -> None: 

72 """Build the cached model information from configuration.""" 

73 self._models_cache = {} 

74 

75 for model_id, mapping in self.config.model_mappings.items(): 

76 service_provider: str = mapping["service_provider"] 

77 service_model_id: str = mapping["service_model_id"] 

78 

79 # Check if we have data for this service model ID 

80 full_service_model_id = f"{service_provider}/{service_model_id}" 

81 

82 # Try to find model data 

83 model_data = {} 

84 if full_service_model_id in self.config.models_data: 

85 model_data = self.config.models_data[full_service_model_id].copy() 

86 else: 

87 logger.warning( 

88 "no_model_data_found", 

89 msg="No model data found for service provider model ID", 

90 help="""Check that the model data file contains the correct service provider model ID. 

91 To add new models for a service provider, add the model data to your langgate_models.json file.""", 

92 full_service_model_id=full_service_model_id, 

93 service_provider=service_provider, 

94 service_model_id=service_model_id, 

95 exposed_model_id=model_id, 

96 available_models=list(self._models_cache.keys()), 

97 ) 

98 

99 # Extract context window if available 

100 context_window = ContextWindow.model_validate(model_data.get("context", {})) 

101 

102 # Extract capabilities if available 

103 capabilities = ModelCapabilities.model_validate( 

104 model_data.get("capabilities", {}) 

105 ) 

106 

107 # Extract costs if available 

108 costs = ModelCost.model_validate(model_data.get("costs", {})) 

109 

110 # Determine the model provider (creator of the model) 

111 # Model provider might differ from the inference service provider 

112 # The service provider is not intended to be exposed to external consumers of the registry 

113 # The service provider is used by the proxy for routing requests to the correct service 

114 if "model_provider" in model_data: 

115 model_provider_id: str = model_data["model_provider"] 

116 else: 

117 raise ValueError( 

118 f"Model {model_id} does not have a valid provider ID, Set `model_provider` in model data." 

119 ) 

120 

121 # Get the provider display name, either from data or fallback to ID 

122 provider_display_name = model_data.get( 

123 "model_provider_name", model_provider_id.title() 

124 ) 

125 

126 # Name can come from multiple sources in decreasing priority 

127 # Use the model name from the config if available, otherwise use the model data name, 

128 # If no name is provided, default to the model portion of the fully specified model ID 

129 # (if it contains any slashes), or else the entire model ID. 

130 model_name_setting = mapping.get("name") or model_data.get("name") 

131 if not model_name_setting: 

132 logger.warning( 

133 "model_name_not_set", 

134 msg="Model name not found in config or model data files", 

135 model_id=model_id, 

136 ) 

137 name = ( 

138 model_name_setting or model_specifier 

139 if (model_specifier := model_id.split("/")[-1]) 

140 else model_id 

141 ) 

142 

143 # Description can come from config mapping or model data 

144 description = mapping.get("description") or model_data.get("description") 

145 

146 # Create complete model info 

147 self._models_cache[model_id] = LLMInfo( 

148 id=model_id, 

149 name=name, 

150 description=description, 

151 provider=ModelProvider( 

152 id=ModelProviderId(model_provider_id), 

153 name=provider_display_name, 

154 description=None, 

155 ), 

156 costs=costs, 

157 capabilities=capabilities, 

158 context_window=context_window, 

159 ) 

160 

161 if not self._models_cache: 

162 logger.warning( 

163 "empty_model_registry", 

164 help="No models were loaded into the registry cache. Check configuration files.", 

165 models_data_path=str(self.config.models_data_path), 

166 config_path=str(self.config.config_path), 

167 ) 

168 

169 def get_model_info(self, model_id: str) -> LLMInfo: 

170 """Get model information by model ID.""" 

171 if model_id not in self._models_cache: 

172 raise ValueError(f"Model {model_id} not found") 

173 

174 return self._models_cache[model_id] 

175 

176 def list_models(self) -> list[LLMInfo]: 

177 """List all available models.""" 

178 return list(self._models_cache.values()) 

179 

180 def get_provider_info(self, model_id: str) -> dict[str, Any]: 

181 """Get provider information for a model to use in the proxy.""" 

182 if not (mapping := self.config.model_mappings.get(model_id)): 

183 raise ValueError(f"Model {model_id} not found") 

184 

185 service_provider = mapping["service_provider"] 

186 

187 if not (service_config := self.config.service_config.get(service_provider)): 

188 raise ValueError(f"Service provider {service_provider} not found") 

189 

190 provider_info = {"provider": service_provider} 

191 

192 # Add base URL 

193 if "base_url" in service_config: 

194 provider_info["base_url"] = service_config["base_url"] 

195 

196 # Add API key (resolve from environment) 

197 if "api_key" in service_config: 

198 api_key: str = service_config["api_key"] 

199 if api_key.startswith("${") and api_key.endswith("}"): 

200 env_var = api_key[2:-1] 

201 if env_var in os.environ: 

202 provider_info["api_key"] = os.environ[env_var] 

203 else: 

204 logger.warning("env_variable_not_found", variable=env_var) 

205 else: 

206 provider_info["api_key"] = api_key 

207 

208 return provider_info