Coverage for packages/registry/src/langgate/registry/models.py: 69%
119 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-11 05:31 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-11 05:31 +0000
1"""Model information and registry implementation."""
3import os
4from typing import Any, Literal, overload
6from langgate.core.logging import get_logger
7from langgate.core.models import (
8 ContextWindow,
9 ImageGenerationCost,
10 ImageModelCost,
11 ImageModelInfo,
12 LLMInfo,
13 Modality,
14 ModelCapabilities,
15 ModelProvider,
16 ModelProviderId,
17 TokenCosts,
18)
19from langgate.registry.config import RegistryConfig
# Module-level structured logger for registry lifecycle and cache events.
logger = get_logger(__name__)

# Type alias for any concrete model info - extend this union when adding new modalities
type _ModelInfo = LLMInfo | ImageModelInfo
class ModelRegistry:
    """Registry for managing model information.

    Singleton that builds per-modality caches of typed model metadata
    (text LLMs and image-generation models) from a ``RegistryConfig``,
    and resolves service-provider routing details for the proxy.
    """

    # Shared singleton instance, created lazily in __new__.
    _instance = None

    def __new__(cls, *args, **kwargs):
        """Create or return the singleton instance."""
        if cls._instance is None:
            logger.debug("creating_model_registry_singleton")
            cls._instance = super().__new__(cls)
        else:
            logger.debug(
                "reusing_registry_singleton",
                initialized=hasattr(cls._instance, "_initialized"),
            )
        return cls._instance

    def __init__(self, config: RegistryConfig | None = None):
        """Initialize the registry.

        Args:
            config: Optional configuration object, will create a new one if not provided

        Note:
            Because this class is a singleton, ``config`` is only honored on
            the first construction; subsequent calls reuse the existing
            instance and silently ignore the argument.
        """
        # Guard so the singleton is only initialized once.
        if not hasattr(self, "_initialized"):
            self.config = config or RegistryConfig()

            # Separate caches by modality
            self._model_caches: dict[Modality, dict[str, _ModelInfo]] = {
                modality: {} for modality in Modality
            }

            # Build the model cache
            try:
                self._build_model_cache()
            except Exception:
                logger.exception(
                    "failed_to_build_model_cache",
                    models_data_path=str(self.config.models_data_path),
                    config_path=str(self.config.config_path),
                    env_file_path=str(self.config.env_file_path),
                    env_file_exists=self.config.env_file_path.exists(),
                )
                # Only re-raise if we have no data at all; a partially
                # built cache is still usable.
                if not any(self._model_caches.values()):
                    raise
            logger.debug(
                "initialized_model_registry_singleton",
                models_data_path=str(self.config.models_data_path),
                config_path=str(self.config.config_path),
                env_file_path=str(self.config.env_file_path),
                env_file_exists=self.config.env_file_path.exists(),
            )
            self._initialized = True

    def _build_model_cache(self) -> None:
        """Build cached model information for all modalities.

        Iterates the configured model mappings, joins each with its
        service-level model data, and populates the per-modality caches.

        Raises:
            ValueError: If a mapped model's data lacks a ``mode`` field.
        """
        # Clear all caches
        for modality in Modality:
            self._model_caches[modality] = {}

        for model_id, mapping in self.config.model_mappings.items():
            service_provider: str = mapping["service_provider"]
            service_model_id: str = mapping["service_model_id"]

            # Check if we have data for this service model ID
            full_service_model_id = f"{service_provider}/{service_model_id}"

            # Try to find model data; copy so later mutation cannot leak
            # back into the shared config.
            model_data = {}
            if full_service_model_id in self.config.models_data:
                model_data = self.config.models_data[full_service_model_id].copy()
            else:
                logger.warning(
                    "no_model_data_found",
                    msg="No model data found for service provider model ID",
                    help="""Check that the model data file contains the correct service provider model ID.
                    To add new models for a service provider, add the model data to your langgate_models.json file.""",
                    full_service_model_id=full_service_model_id,
                    service_provider=service_provider,
                    service_model_id=service_model_id,
                    exposed_model_id=model_id,
                )

            # Map "mode" field from JSON to Modality enum
            mode = model_data.get("mode")
            if not mode:
                raise ValueError(
                    f"Model {model_id} does not have a valid 'mode' field in model data."
                )
            # Any non-"image" mode (e.g. "chat") is treated as a text model.
            modality = Modality.IMAGE if mode == "image" else Modality.TEXT

            if modality == Modality.TEXT:
                llm_info = self._build_llm_info(model_id, mapping, model_data)
                self._model_caches[Modality.TEXT][model_id] = llm_info
            elif modality == Modality.IMAGE:
                image_info = self._build_image_model_info(model_id, mapping, model_data)
                self._model_caches[Modality.IMAGE][model_id] = image_info

        # Check if registry is empty after building all caches
        if not any(self._model_caches.values()):
            logger.warning(
                "empty_model_registry",
                help="No models were loaded into the registry cache. Check configuration files.",
                models_data_path=str(self.config.models_data_path),
                config_path=str(self.config.config_path),
            )

    def _get_model_metadata(
        self, model_id: str, mapping: dict[str, Any], model_data: dict[str, Any]
    ) -> dict[str, Any]:
        """Get generic model metadata from configuration and model data.

        Args:
            model_id: The exposed (external) model identifier.
            mapping: The config mapping entry for this model.
            model_data: Service-level data for the underlying model.

        Returns:
            Dict with ``name``, ``description``, ``model_provider_id``, and
            ``provider_display_name`` keys.

        Raises:
            ValueError: If no model provider ID can be resolved.
        """
        # Determine the model provider (creator of the model)
        # Model provider might differ from the inference service provider
        # The service provider is not intended to be exposed to external consumers of the registry
        # The service provider is used by the proxy for routing requests to the correct service
        # Priority order:
        # 1. Explicitly set in mapping (takes precedence)
        # 2. Value coming from the service-level model data
        model_provider_id: str = mapping.get("model_provider") or model_data.get(
            "model_provider", ""
        )
        if not model_provider_id:
            raise ValueError(
                f"Model {model_id} does not have a valid provider ID. Set `model_provider` in model data."
            )

        # Get the provider display name, either from data or fallback to ID
        provider_display_name = mapping.get("model_provider_name") or model_data.get(
            "model_provider_name", model_provider_id.title()
        )

        # Name can come from multiple sources in decreasing priority:
        # config mapping, then model data, then the model portion of the
        # fully specified ID (text after the final "/"), then the entire
        # model ID (when the final segment is empty, e.g. a trailing slash).
        model_name_setting = mapping.get("name") or model_data.get("name")
        if not model_name_setting:
            logger.warning(
                "model_name_not_set",
                msg="Model name not found in config or model data files",
                model_id=model_id,
            )
        # Fixed: the previous conditional-expression form discarded an
        # explicitly configured name whenever the ID's final segment was
        # empty; this chain honors the priority order unconditionally.
        name = model_name_setting or model_id.split("/")[-1] or model_id

        # Description can come from config mapping or model data
        description = mapping.get("description") or model_data.get("description")

        return {
            "name": name,
            "description": description,
            "model_provider_id": model_provider_id,
            "provider_display_name": provider_display_name,
        }

    def _build_llm_info(
        self, model_id: str, mapping: dict[str, Any], model_data: dict[str, Any]
    ) -> LLMInfo:
        """Build LLMInfo from configuration data.

        Missing sections default to empty dicts so the pydantic models
        apply their own field defaults.
        """
        generic_metadata = self._get_model_metadata(model_id, mapping, model_data)

        # Extract context window if available
        context_window = ContextWindow.model_validate(model_data.get("context", {}))

        # Extract capabilities if available
        capabilities = ModelCapabilities.model_validate(
            model_data.get("capabilities", {})
        )

        # Extract costs if available
        costs = TokenCosts.model_validate(model_data.get("costs", {}))

        # Create complete model info
        return LLMInfo(
            id=model_id,
            name=generic_metadata["name"],
            description=generic_metadata["description"],
            provider_id=ModelProviderId(generic_metadata["model_provider_id"]),
            provider=ModelProvider(
                id=ModelProviderId(generic_metadata["model_provider_id"]),
                name=generic_metadata["provider_display_name"],
                description=None,
            ),
            costs=costs,
            capabilities=capabilities,
            context_window=context_window,
        )

    def _build_image_model_info(
        self, model_id: str, mapping: dict[str, Any], model_data: dict[str, Any]
    ) -> ImageModelInfo:
        """Build ImageModelInfo from configuration data."""
        generic_metadata = self._get_model_metadata(model_id, mapping, model_data)

        # Extract and validate image model costs
        costs_data = model_data.get("costs", {})

        # Handle token costs if present (for hybrid models like gpt-image-1)
        token_costs = None
        if "token_costs" in costs_data:
            token_costs = TokenCosts.model_validate(costs_data["token_costs"])

        # Handle image generation costs (required)
        image_generation_data = costs_data.get("image_generation", {})
        image_generation = ImageGenerationCost.model_validate(image_generation_data)

        costs = ImageModelCost(
            token_costs=token_costs, image_generation=image_generation
        )

        # Create complete model info
        return ImageModelInfo(
            id=model_id,
            name=generic_metadata["name"],
            description=generic_metadata["description"],
            provider_id=ModelProviderId(generic_metadata["model_provider_id"]),
            provider=ModelProvider(
                id=ModelProviderId(generic_metadata["model_provider_id"]),
                name=generic_metadata["provider_display_name"],
                description=None,
            ),
            costs=costs,
        )

    # Generic accessor methods with overloads for type safety
    @overload
    def get_model_info(self, model_id: str, modality: Literal["text"]) -> LLMInfo: ...

    @overload
    def get_model_info(
        self, model_id: str, modality: Literal["image"]
    ) -> ImageModelInfo: ...

    def get_model_info(self, model_id: str, modality: str) -> LLMInfo | ImageModelInfo:
        """Get model information by ID and modality with proper typing.

        Raises:
            ValueError: If ``modality`` is not a valid ``Modality`` value,
                or if the model is not found in that modality's cache.
        """
        modality_enum = Modality(modality)
        if model_id not in self._model_caches[modality_enum]:
            raise ValueError(f"{modality.capitalize()} model {model_id} not found")
        return self._model_caches[modality_enum][model_id]

    @overload
    def list_models(self, modality: Literal["text"]) -> list[LLMInfo]: ...

    @overload
    def list_models(self, modality: Literal["image"]) -> list[ImageModelInfo]: ...

    def list_models(self, modality: str) -> list[LLMInfo] | list[ImageModelInfo]:
        """List all available models for a given modality with proper typing."""
        modality_enum = Modality(modality)
        return list(self._model_caches[modality_enum].values())  # type: ignore[return-value]

    # Backward compatibility methods for existing LLM functionality
    def get_llm_info(self, model_id: str) -> LLMInfo:
        """Get LLM information by model ID (backward compatibility)."""
        return self.get_model_info(model_id, "text")

    def get_image_model_info(self, model_id: str) -> ImageModelInfo:
        """Get image model information by model ID."""
        return self.get_model_info(model_id, "image")

    def list_llms(self) -> list[LLMInfo]:
        """List all available LLMs."""
        return self.list_models("text")

    def list_image_models(self) -> list[ImageModelInfo]:
        """List all available image models."""
        return self.list_models("image")

    def get_provider_info(self, model_id: str) -> dict[str, Any]:
        """Get provider information for a model to use in the proxy.

        Returns:
            Dict with ``provider`` plus, when configured, ``base_url`` and
            ``api_key`` (``${VAR}`` values are resolved from the environment).

        Raises:
            ValueError: If the model or its service provider is unknown.
        """
        if not (mapping := self.config.model_mappings.get(model_id)):
            raise ValueError(f"Model {model_id} not found")

        service_provider = mapping["service_provider"]

        if not (service_config := self.config.service_config.get(service_provider)):
            raise ValueError(f"Service provider {service_provider} not found")

        provider_info = {"provider": service_provider}

        # Add base URL
        if "base_url" in service_config:
            provider_info["base_url"] = service_config["base_url"]

        # Add API key (resolve from environment)
        # Only the exact form "${VAR}" is substituted; anything else is
        # treated as a literal key.
        if "api_key" in service_config:
            api_key: str = service_config["api_key"]
            if api_key.startswith("${") and api_key.endswith("}"):
                env_var = api_key[2:-1]
                if env_var in os.environ:
                    provider_info["api_key"] = os.environ[env_var]
                else:
                    # Best-effort: omit the key rather than fail routing.
                    logger.warning("env_variable_not_found", variable=env_var)
            else:
                provider_info["api_key"] = api_key

        return provider_info