Coverage for packages/registry/src/langgate/registry/config.py: 70%
89 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-11 05:31 +0000
1"""Configuration handling for the registry."""
3import importlib.resources
4import json
5from pathlib import Path
6from typing import Any
8from dotenv import load_dotenv
10from langgate.core.logging import get_logger
11from langgate.core.schemas.config import ConfigSchema
12from langgate.core.utils.config_utils import load_yaml_config, resolve_path
# Module-level logger used by RegistryConfig; call sites pass an event name
# (e.g. "loaded_model_data") followed by keyword context fields.
logger = get_logger(__name__)
17class RegistryConfig:
18 """Configuration handler for the registry."""
20 def __init__(
21 self,
22 models_data_path: Path | None = None,
23 config_path: Path | None = None,
24 env_file_path: Path | None = None,
25 ):
26 """
27 Args:
28 models_data_path: Path to the models data JSON file
29 config_path: Path to the main configuration YAML file
30 env_file_path: Path to a `.env` file for environment variables
31 """
32 # Set up default paths
33 cwd = Path.cwd()
34 # Get package resource paths
35 registry_resources = importlib.resources.files("langgate.registry")
36 core_resources = importlib.resources.files("langgate.core")
37 self.default_models_path = Path(
38 str(registry_resources.joinpath("data", "default_models.json"))
39 )
40 default_config_path = Path(
41 str(core_resources.joinpath("data", "default_config.yaml"))
42 )
44 # Define default paths with priorities
45 # Models data: args > env > cwd > package_dir
46 cwd_models_path = cwd / "langgate_models.json"
48 # Config: args > env > cwd > package_dir
49 cwd_config_path = cwd / "langgate_config.yaml"
51 # Env file: args > env > cwd
52 cwd_env_path = cwd / ".env"
54 # Resolve paths using priority order
55 self.models_data_path = resolve_path(
56 "LANGGATE_MODELS",
57 models_data_path,
58 cwd_models_path if cwd_models_path.exists() else self.default_models_path,
59 "models_data_path",
60 )
62 self.config_path = resolve_path(
63 "LANGGATE_CONFIG",
64 config_path,
65 cwd_config_path if cwd_config_path.exists() else default_config_path,
66 "config_path",
67 )
69 self.env_file_path = resolve_path(
70 "LANGGATE_ENV_FILE", env_file_path, cwd_env_path, "env_file_path"
71 )
73 # Load environment variables from .env file if it exists
74 if self.env_file_path.exists():
75 load_dotenv(self.env_file_path)
76 logger.debug("loaded_env_file", path=str(self.env_file_path))
78 # Initialize data structures
79 self.models_data: dict[str, dict[str, Any]] = {}
80 self.global_config: dict[str, Any] = {}
81 self.service_config: dict[str, dict[str, Any]] = {}
82 self.model_mappings: dict[str, dict[str, Any]] = {}
83 self.models_merge_mode: str = "merge" # Default value
85 # Load configuration
86 self._load_config()
88 def _load_config(self) -> None:
89 """Load configuration from files."""
90 try:
91 # Load main configuration first to get merge mode
92 self._load_main_config()
94 # Load model data with merge mode available
95 self._load_model_data()
97 except Exception:
98 logger.exception(
99 "failed_to_load_config",
100 models_data_path=str(self.models_data_path),
101 config_path=str(self.config_path),
102 )
103 raise
105 def _has_user_models(self) -> bool:
106 """Check if user has specified custom models."""
107 return (
108 self.models_data_path != self.default_models_path
109 and self.models_data_path.exists()
110 )
112 def _load_json_file(self, path: Path) -> dict[str, dict[str, Any]]:
113 """Load and return JSON data from file."""
114 try:
115 with open(path) as f:
116 return json.load(f)
117 except FileNotFoundError:
118 logger.warning(
119 "model_data_file_not_found",
120 models_data_path=str(path),
121 )
122 return {}
124 def _merge_models(
125 self,
126 default_models: dict[str, dict[str, Any]],
127 user_models: dict[str, dict[str, Any]],
128 merge_mode: str,
129 ) -> dict[str, dict[str, Any]]:
130 """Merge user models with default models based on merge mode."""
131 if merge_mode == "merge":
132 # User models override defaults, new models are added
133 merged = default_models.copy()
134 for model_id, model_config in user_models.items():
135 if model_id in merged:
136 logger.debug(
137 "overriding_default_model",
138 model_id=model_id,
139 original_name=merged[model_id].get("name"),
140 new_name=model_config.get("name"),
141 )
142 merged[model_id] = model_config
143 return merged
145 if merge_mode == "extend":
146 # User models are added, conflicts with defaults cause errors
147 conflicts = set(default_models.keys()) & set(user_models.keys())
148 if conflicts:
149 raise ValueError(
150 f"Model ID conflicts found in extend mode: {', '.join(conflicts)}. "
151 f"Use 'merge' mode to allow overrides or rename conflicting models."
152 )
153 return {**default_models, **user_models}
155 raise ValueError(f"Unknown merge mode: {merge_mode}")
157 def _load_model_data(self) -> None:
158 """Load model data from JSON file(s) based on merge mode."""
159 try:
160 # Always load default models first
161 default_models = self._load_json_file(self.default_models_path)
163 has_user_models = self._has_user_models()
165 if self.models_merge_mode == "replace" or not has_user_models:
166 # Use only user models or default if no user models
167 if has_user_models:
168 self.models_data = self._load_json_file(self.models_data_path)
169 else:
170 self.models_data = default_models
171 else:
172 # merge user models with defaults
173 user_models = (
174 self._load_json_file(self.models_data_path)
175 if has_user_models
176 else {}
177 )
178 self.models_data = self._merge_models(
179 default_models, user_models, self.models_merge_mode
180 )
182 logger.info(
183 "loaded_model_data",
184 models_data_path=str(self.models_data_path),
185 merge_mode=self.models_merge_mode,
186 model_count=len(self.models_data),
187 )
188 except Exception:
189 logger.exception("failed_to_load_model_data")
190 raise
192 def _load_main_config(self) -> None:
193 """Load main configuration from YAML file."""
194 config = load_yaml_config(self.config_path, ConfigSchema, logger)
196 # Extract merge mode
197 self.models_merge_mode = config.models_merge_mode
199 # Extract validated data
200 self.global_config = {
201 "default_params": config.default_params,
202 "models_merge_mode": config.models_merge_mode,
203 }
205 # Extract service provider config
206 self.service_config = {
207 k: v.model_dump(exclude_none=True) for k, v in config.services.items()
208 }
210 # Process model mappings
211 self._process_model_mappings(config.models)
213 def _process_model_mappings(self, models_config: dict[str, list]) -> None:
214 """Process model mappings from validated configuration.
216 Args:
217 models_config: Dict of modality to list of model configurations
218 """
219 self.model_mappings = {}
221 for _, model_list in models_config.items():
222 for model_config in model_list:
223 model_data = model_config.model_dump(exclude_none=True)
224 model_id = model_data["id"]
225 service = model_data["service"]
227 # Store mapping info with proper type handling
228 self.model_mappings[model_id] = {
229 "service_provider": service["provider"],
230 "service_model_id": service["model_id"],
231 "override_params": model_data.get("override_params", {}),
232 "remove_params": model_data.get("remove_params", []),
233 "rename_params": model_data.get("rename_params", {}),
234 "name": model_data.get("name"),
235 "model_provider": model_data.get("model_provider"),
236 "model_provider_name": model_data.get("model_provider_name"),
237 "description": model_data.get("description"),
238 }