Coverage for packages/registry/src/langgate/registry/models.py: 69%

119 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-11 05:31 +0000

1"""Model information and registry implementation.""" 

2 

3import os 

4from typing import Any, Literal, overload 

5 

6from langgate.core.logging import get_logger 

7from langgate.core.models import ( 

8 ContextWindow, 

9 ImageGenerationCost, 

10 ImageModelCost, 

11 ImageModelInfo, 

12 LLMInfo, 

13 Modality, 

14 ModelCapabilities, 

15 ModelProvider, 

16 ModelProviderId, 

17 TokenCosts, 

18) 

19from langgate.registry.config import RegistryConfig 

20 

21logger = get_logger(__name__) 

22 

23 

24# Type alias for any concrete model info - extend this union when adding new modalities 

25type _ModelInfo = LLMInfo | ImageModelInfo 

26 

27 

class ModelRegistry:
    """Singleton registry for managing model information.

    Models are cached per :class:`Modality` (currently text and image) and are
    built from the exposed-model mappings and service-model data supplied by
    :class:`RegistryConfig`.
    """

    # Shared singleton instance; set by the first __new__ call.
    _instance = None

    def __new__(cls, *args, **kwargs):
        """Create or return the singleton instance."""
        if cls._instance is None:
            logger.debug("creating_model_registry_singleton")
            cls._instance = super().__new__(cls)
        else:
            logger.debug(
                "reusing_registry_singleton",
                initialized=hasattr(cls._instance, "_initialized"),
            )
        return cls._instance

    def __init__(self, config: RegistryConfig | None = None):
        """Initialize the registry.

        Args:
            config: Optional configuration object, will create a new one if not provided
        """
        # __init__ runs on every ModelRegistry() call even though __new__
        # returns the shared instance, so guard against re-initialization.
        if not hasattr(self, "_initialized"):
            self.config = config or RegistryConfig()

            # Separate caches by modality
            self._model_caches: dict[Modality, dict[str, _ModelInfo]] = {
                modality: {} for modality in Modality
            }

            # Build the model cache
            try:
                self._build_model_cache()
            except Exception:
                logger.exception(
                    "failed_to_build_model_cache",
                    models_data_path=str(self.config.models_data_path),
                    config_path=str(self.config.config_path),
                    env_file_path=str(self.config.env_file_path),
                    env_file_exists=self.config.env_file_path.exists(),
                )
                # Tolerate partial builds: only re-raise when no modality
                # loaded any data at all.
                if not any(
                    self._model_caches[modality] for modality in self._model_caches
                ):  # Only raise if we have no data
                    raise
            logger.debug(
                "initialized_model_registry_singleton",
                models_data_path=str(self.config.models_data_path),
                config_path=str(self.config.config_path),
                env_file_path=str(self.config.env_file_path),
                env_file_exists=self.config.env_file_path.exists(),
            )
            self._initialized = True

    def _build_model_cache(self) -> None:
        """Build cached model information for all modalities.

        Raises:
            ValueError: if a mapped model has no "mode" field in its data.
        """
        # Clear all caches
        for modality in Modality:
            self._model_caches[modality] = {}

        for model_id, mapping in self.config.model_mappings.items():
            service_provider: str = mapping["service_provider"]
            service_model_id: str = mapping["service_model_id"]

            # Check if we have data for this service model ID
            full_service_model_id = f"{service_provider}/{service_model_id}"

            # Try to find model data
            model_data = {}
            if full_service_model_id in self.config.models_data:
                # Copy so per-model mutation never leaks into shared config data.
                model_data = self.config.models_data[full_service_model_id].copy()
            else:
                logger.warning(
                    "no_model_data_found",
                    msg="No model data found for service provider model ID",
                    help="""Check that the model data file contains the correct service provider model ID.
                    To add new models for a service provider, add the model data to your langgate_models.json file.""",
                    full_service_model_id=full_service_model_id,
                    service_provider=service_provider,
                    service_model_id=service_model_id,
                    exposed_model_id=model_id,
                )

            # Map "mode" field from JSON to Modality enum
            mode = model_data.get("mode")
            if not mode:
                raise ValueError(
                    f"Model {model_id} does not have a valid 'mode' field in model data."
                )
            # NOTE(review): every mode other than "image" falls back to TEXT —
            # confirm this is intended if new modes (e.g. embedding) are added.
            modality = Modality.IMAGE if mode == "image" else Modality.TEXT

            if modality == Modality.TEXT:
                llm_info = self._build_llm_info(model_id, mapping, model_data)
                self._model_caches[Modality.TEXT][model_id] = llm_info
            elif modality == Modality.IMAGE:
                image_info = self._build_image_model_info(model_id, mapping, model_data)
                self._model_caches[Modality.IMAGE][model_id] = image_info

        # Check if registry is empty after building all caches
        if not any(self._model_caches[modality] for modality in self._model_caches):
            logger.warning(
                "empty_model_registry",
                help="No models were loaded into the registry cache. Check configuration files.",
                models_data_path=str(self.config.models_data_path),
                config_path=str(self.config.config_path),
            )

    def _get_model_metadata(
        self, model_id: str, mapping: dict[str, Any], model_data: dict[str, Any]
    ) -> dict[str, Any]:
        """Get generic model metadata from configuration and model data.

        Returns:
            Dict with keys ``name``, ``description``, ``model_provider_id``
            and ``provider_display_name``.

        Raises:
            ValueError: if no model provider ID can be determined.
        """
        # Determine the model provider (creator of the model)
        # Model provider might differ from the inference service provider
        # The service provider is not intended to be exposed to external consumers of the registry
        # The service provider is used by the proxy for routing requests to the correct service
        # Priority order:
        # 1. Explicitly set in mapping (takes precedence)
        # 2. Value coming from the service-level model data
        model_provider_id: str = mapping.get("model_provider") or model_data.get(
            "model_provider", ""
        )
        if not model_provider_id:
            raise ValueError(
                f"Model {model_id} does not have a valid provider ID. Set `model_provider` in model data."
            )

        # Get the provider display name, either from data or fallback to ID
        provider_display_name = mapping.get("model_provider_name") or model_data.get(
            "model_provider_name", model_provider_id.title()
        )

        # Name can come from multiple sources in decreasing priority:
        # config mapping, model data, the model portion of the fully specified
        # model ID (text after the last "/"), then the entire model ID.
        model_name_setting = mapping.get("name") or model_data.get("name")
        if not model_name_setting:
            logger.warning(
                "model_name_not_set",
                msg="Model name not found in config or model data files",
                model_id=model_id,
            )
        # FIX: the previous conditional-expression form bound `or` tighter than
        # the ternary, so an explicitly configured name was discarded whenever
        # the derived specifier was empty (model_id ending in "/"). A plain
        # `or` chain implements the documented priority order directly.
        name = model_name_setting or model_id.split("/")[-1] or model_id

        # Description can come from config mapping or model data
        description = mapping.get("description") or model_data.get("description")

        return {
            "name": name,
            "description": description,
            "model_provider_id": model_provider_id,
            "provider_display_name": provider_display_name,
        }

    def _build_llm_info(
        self, model_id: str, mapping: dict[str, Any], model_data: dict[str, Any]
    ) -> LLMInfo:
        """Build LLMInfo from configuration data."""

        generic_metadata = self._get_model_metadata(model_id, mapping, model_data)

        # Extract context window if available (empty dict -> model defaults)
        context_window = ContextWindow.model_validate(model_data.get("context", {}))

        # Extract capabilities if available
        capabilities = ModelCapabilities.model_validate(
            model_data.get("capabilities", {})
        )

        # Extract costs if available
        costs = TokenCosts.model_validate(model_data.get("costs", {}))

        # Create complete model info
        return LLMInfo(
            id=model_id,
            name=generic_metadata["name"],
            description=generic_metadata["description"],
            provider_id=ModelProviderId(generic_metadata["model_provider_id"]),
            provider=ModelProvider(
                id=ModelProviderId(generic_metadata["model_provider_id"]),
                name=generic_metadata["provider_display_name"],
                description=None,
            ),
            costs=costs,
            capabilities=capabilities,
            context_window=context_window,
        )

    def _build_image_model_info(
        self, model_id: str, mapping: dict[str, Any], model_data: dict[str, Any]
    ) -> ImageModelInfo:
        """Build ImageModelInfo from configuration data."""

        generic_metadata = self._get_model_metadata(model_id, mapping, model_data)

        # Extract and validate image model costs
        costs_data = model_data.get("costs", {})

        # Handle token costs if present (for hybrid models like gpt-image-1)
        token_costs = None
        if "token_costs" in costs_data:
            token_costs = TokenCosts.model_validate(costs_data["token_costs"])

        # Handle image generation costs (required)
        image_generation_data = costs_data.get("image_generation", {})
        image_generation = ImageGenerationCost.model_validate(image_generation_data)

        costs = ImageModelCost(
            token_costs=token_costs, image_generation=image_generation
        )

        # Create complete model info
        return ImageModelInfo(
            id=model_id,
            name=generic_metadata["name"],
            description=generic_metadata["description"],
            provider_id=ModelProviderId(generic_metadata["model_provider_id"]),
            provider=ModelProvider(
                id=ModelProviderId(generic_metadata["model_provider_id"]),
                name=generic_metadata["provider_display_name"],
                description=None,
            ),
            costs=costs,
        )

    # Generic accessor methods with overloads for type safety
    @overload
    def get_model_info(self, model_id: str, modality: Literal["text"]) -> LLMInfo: ...

    @overload
    def get_model_info(
        self, model_id: str, modality: Literal["image"]
    ) -> ImageModelInfo: ...

    def get_model_info(self, model_id: str, modality: str) -> LLMInfo | ImageModelInfo:
        """Get model information by ID and modality with proper typing.

        Raises:
            ValueError: if the model is not found in the modality's cache
                (or if ``modality`` is not a valid Modality value).
        """
        modality_enum = Modality(modality)
        if model_id not in self._model_caches[modality_enum]:
            raise ValueError(f"{modality.capitalize()} model {model_id} not found")
        return self._model_caches[modality_enum][model_id]

    @overload
    def list_models(self, modality: Literal["text"]) -> list[LLMInfo]: ...

    @overload
    def list_models(self, modality: Literal["image"]) -> list[ImageModelInfo]: ...

    def list_models(self, modality: str) -> list[LLMInfo] | list[ImageModelInfo]:
        """List all available models for a given modality with proper typing."""
        modality_enum = Modality(modality)
        return list(self._model_caches[modality_enum].values())  # type: ignore[return-value]

    # Backward compatibility methods for existing LLM functionality
    def get_llm_info(self, model_id: str) -> LLMInfo:
        """Get LLM information by model ID (backward compatibility)."""
        return self.get_model_info(model_id, "text")

    def get_image_model_info(self, model_id: str) -> ImageModelInfo:
        """Get image model information by model ID."""
        return self.get_model_info(model_id, "image")

    def list_llms(self) -> list[LLMInfo]:
        """List all available LLMs."""
        return self.list_models("text")

    def list_image_models(self) -> list[ImageModelInfo]:
        """List all available image models."""
        return self.list_models("image")

    def get_provider_info(self, model_id: str) -> dict[str, Any]:
        """Get provider information for a model to use in the proxy.

        Returns:
            Dict with ``provider`` plus optional ``base_url`` / ``api_key``.

        Raises:
            ValueError: if the model or its service provider is unknown.
        """
        if not (mapping := self.config.model_mappings.get(model_id)):
            raise ValueError(f"Model {model_id} not found")

        service_provider = mapping["service_provider"]

        if not (service_config := self.config.service_config.get(service_provider)):
            raise ValueError(f"Service provider {service_provider} not found")

        provider_info = {"provider": service_provider}

        # Add base URL
        if "base_url" in service_config:
            provider_info["base_url"] = service_config["base_url"]

        # Add API key (resolve "${ENV_VAR}" placeholders from the environment)
        if "api_key" in service_config:
            api_key: str = service_config["api_key"]
            if api_key.startswith("${") and api_key.endswith("}"):
                env_var = api_key[2:-1]
                if env_var in os.environ:
                    provider_info["api_key"] = os.environ[env_var]
                else:
                    # Omit the key rather than forwarding an unresolved placeholder.
                    logger.warning("env_variable_not_found", variable=env_var)
            else:
                provider_info["api_key"] = api_key

        return provider_info