Coverage for packages/transform/src/langgate/transform/local.py: 16%
128 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-08-29 16:00 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-08-29 16:00 +0000
1"""Local transformer client implementation."""
3import importlib.resources
4import os
5from pathlib import Path
6from typing import Any
8from dotenv import load_dotenv
9from pydantic import SecretStr
11from langgate.core.logging import get_logger
12from langgate.core.models import Modality
13from langgate.core.schemas.config import ConfigSchema, ModelConfig
14from langgate.core.utils.config_utils import load_yaml_config, resolve_path
15from langgate.transform.protocol import TransformerClientProtocol
16from langgate.transform.transformer import ParamTransformer
# Module-level logger; call sites in this module pass an event name plus keyword
# fields, e.g. logger.debug("event_name", key=value).
logger = get_logger(__name__)
class LocalTransformerClient(TransformerClientProtocol):
    """
    Local transformer client for parameter transformations.

    This client handles parameter transformations based on local configuration.

    The class is a singleton: ``__new__`` always returns the same instance, and
    ``__init__`` only loads configuration on the first successful construction,
    so the ``config_path``/``env_file_path`` arguments of later constructions
    are ignored.
    """

    _instance = None
    # Class-level default; set to True on the instance once config has loaded.
    _initialized: bool = False

    def __new__(cls, *args, **kwargs):
        """Create or return the singleton instance."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self, config_path: Path | None = None, env_file_path: Path | None = None
    ):
        """Initialize the client (a no-op after the first construction).

        Args:
            config_path: Path to the configuration file. Passed to
                ``resolve_path`` along with the ``LANGGATE_CONFIG`` environment
                variable name; the fallback is ``./langgate_config.yaml`` if
                that file exists, otherwise the ``data/default_config.yaml``
                resource shipped with ``langgate.core``.
            env_file_path: Path to the environment file. Passed to
                ``resolve_path`` along with ``LANGGATE_ENV_FILE``; the fallback
                is ``./.env``.
        """
        # ``_initialized`` always exists (class attribute), so no hasattr needed.
        if self._initialized:
            return

        # Set up default paths
        cwd = Path.cwd()

        core_resources = importlib.resources.files("langgate.core")
        # NOTE(review): converting the resource to a filesystem Path assumes the
        # package is installed unzipped — confirm for zipapp/wheel-in-zip use.
        default_config_path = Path(
            str(core_resources.joinpath("data", "default_config.yaml"))
        )
        cwd_config_path = cwd / "langgate_config.yaml"
        cwd_env_path = cwd / ".env"

        self.config_path = resolve_path(
            "LANGGATE_CONFIG",
            config_path,
            cwd_config_path if cwd_config_path.exists() else default_config_path,
            "config_path",
            logger,
        )
        self.env_file_path = resolve_path(
            "LANGGATE_ENV_FILE",
            env_file_path,
            cwd_env_path,
            "env_file_path",
            logger,
        )

        # Cache for configs
        self._global_config: dict[str, Any] = {}
        self._service_config: dict[str, dict[str, Any]] = {}
        self._model_mappings: dict[str, dict[str, Any]] = {}

        # Load configuration; if this raises, _initialized stays False and the
        # next construction retries.
        self._load_config()
        self._initialized = True
        logger.debug("initialized_local_transformer_client")

    def _load_config(self) -> None:
        """Load and cache the validated configuration, then the optional .env file.

        Populates ``_global_config``, ``_service_config`` and ``_model_mappings``
        from ``self.config_path``. If ``self.env_file_path`` exists it is loaded
        into the process environment so that ``${VAR}`` references resolved
        later in ``get_params`` can see its values.
        """
        config = load_yaml_config(self.config_path, ConfigSchema, logger)

        # Extract validated data
        self._global_config = {"default_params": config.default_params}
        self._service_config = {
            k: v.model_dump(exclude_none=True) for k, v in config.services.items()
        }
        self._process_model_mappings(config.models)

        # Load environment variables from .env file if it exists
        if self.env_file_path.exists():
            load_dotenv(self.env_file_path)
            logger.debug("loaded_transformer_env_file", path=str(self.env_file_path))

    def _process_model_mappings(
        self, models_config: dict[str, list[ModelConfig]]
    ) -> None:
        """Process model mappings from validated configuration.

        Rebuilds ``self._model_mappings`` as ``{model_id: mapping}``. Optional
        parameter sections dropped by ``model_dump(exclude_none=True)`` are
        stored as empty containers instead of raising ``KeyError``.

        Args:
            models_config: Dict of modality to list of model configurations
        """
        self._model_mappings = {}

        for modality, model_list in models_config.items():
            for model_config in model_list:
                model_data = model_config.model_dump(exclude_none=True)
                service = model_data["service"]

                # Store mapping info including modality. A model id that appears
                # more than once overwrites the earlier entry.
                self._model_mappings[model_data["id"]] = {
                    "modality": modality,
                    "service_provider": service["provider"],
                    "service_model_id": service["model_id"],
                    "api_format": model_data.get("api_format"),
                    "default_params": model_data.get("default_params", {}),
                    "override_params": model_data.get("override_params", {}),
                    "remove_params": model_data.get("remove_params", []),
                    "rename_params": model_data.get("rename_params", {}),
                }

    @staticmethod
    def _select_service_defaults(
        service_defaults: dict[str, Any], modality: str | None
    ) -> dict[str, Any]:
        """Pick the service-level defaults that apply to a model of ``modality``.

        Service defaults are either flat (parameter -> value) or keyed by
        modality. The dict is treated as modality-keyed if ANY of its keys is a
        ``Modality`` value; then only the entry for ``modality`` is returned
        (empty when ``modality`` is falsy or has no entry). Flat dicts are
        returned unchanged.
        """
        if not service_defaults:
            return {}
        modality_values = {m.value for m in Modality}
        if any(key in modality_values for key in service_defaults):
            return service_defaults.get(modality, {}) if modality else {}
        return service_defaults

    async def get_params(
        self, model_id: str, input_params: dict[str, Any]
    ) -> tuple[str, dict[str, Any]]:
        """
        Get transformed parameters for the specified model.

        Parameter precedence and transformation order:

        Defaults (a default is applied only if the key doesn't exist yet, so the
        first source to supply a key wins):
        1. Model-specific defaults (highest precedence for defaults)
        2. Pattern defaults (matching patterns applied in REVERSE config order,
           so a later matching pattern beats an earlier one, as with overrides)
        3. Service provider defaults (flat, or the entry for the model's modality
           when the service defaults are keyed by modality)
        4. Global defaults for the model's modality (lowest precedence)

        Overrides/Removals/Renames (applied in order, later steps overwrite/modify
        earlier ones):
        1. Input parameters (initial state)
        2. Service-level API keys and base URLs
        3. Service-level overrides, removals, renames
        4. Pattern-level overrides, then removals, then renames (matching patterns
           applied in config order)
        5. Model-specific overrides, removals, renames (highest precedence)
        6. Model ID (always overwritten with service_model_id)
        7. Environment variable substitution (applied last to all string values)

        A service ``model_patterns`` entry matches when the pattern string is a
        substring of ``model_id``. The service ``api_key`` is emitted as a
        ``SecretStr`` whether it comes from a ``${ENV_VAR}`` reference or
        (discouraged) a literal string in the config.

        Args:
            model_id: The ID of the model to get transformed parameters for
            input_params: The parameters to transform

        Returns:
            A tuple containing (api_format, transformed_parameters). The format
            falls back model ``api_format`` -> service ``api_format`` -> service
            provider name.

        Raises:
            ValueError: If the model is not found in the configuration.
        """
        mapping = self._model_mappings.get(model_id)
        if mapping is None:
            logger.error("model_not_found", model_id=model_id)
            raise ValueError(f"Model '{model_id}' not found in configuration")

        service_provider = mapping["service_provider"]
        service_model_id = mapping["service_model_id"]
        model_modality = mapping.get("modality")
        service_config = self._service_config.get(service_provider, {})

        # Pattern configs whose pattern is a substring of the model id, in config order.
        matching_patterns = [
            pattern_config
            for pattern, pattern_config in service_config.get(
                "model_patterns", {}
            ).items()
            if pattern in model_id
        ]

        transformer = ParamTransformer()

        # --- Defaults, highest precedence first (first source to set a key wins) ---
        transformer.with_defaults(mapping.get("default_params", {}))
        # Reversed so the last matching pattern's defaults win, mirroring the
        # pattern overrides below where the last pattern is applied last.
        for pattern_config in reversed(matching_patterns):
            if "default_params" in pattern_config:
                transformer.with_defaults(pattern_config["default_params"])

        service_defaults = self._select_service_defaults(
            service_config.get("default_params", {}), model_modality
        )
        if service_defaults:
            transformer.with_defaults(service_defaults)

        # Global defaults are looked up by modality; guard against a None value.
        global_defaults = self._global_config.get("default_params") or {}
        transformer.with_defaults(
            global_defaults.get(model_modality, {}) if model_modality else {}
        )

        # --- Overrides, removals, renames (service -> pattern -> model) ---

        # Service level credentials
        api_key_val = service_config.get("api_key")
        if (
            isinstance(api_key_val, str)
            and api_key_val.startswith("${")
            and api_key_val.endswith("}")
        ):
            env_var = api_key_val[2:-1]
            env_val = os.environ.get(env_var)
            if env_val is not None:
                transformer.with_overrides({"api_key": SecretStr(env_val)})
            else:
                logger.warning(
                    "api_key_env_var_not_found",
                    variable=env_var,
                    service=service_provider,
                )
        elif api_key_val:
            await logger.awarning(
                "api_key_set_directly_in_config",
                service=service_provider,
                msg="API key set directly in config, not recommended",
            )
            # Wrap plain strings like the env-var path so callers always get a
            # SecretStr and the key stays masked in reprs/logs; other types
            # (e.g. an already-wrapped SecretStr) pass through unchanged.
            transformer.with_overrides(
                {
                    "api_key": SecretStr(api_key_val)
                    if isinstance(api_key_val, str)
                    else api_key_val
                }
            )

        base_url_val = service_config.get("base_url")
        if base_url_val and isinstance(base_url_val, str):
            transformer.with_overrides({"base_url": base_url_val})

        # Service overrides, removals, renames
        service_override_params = service_config.get("override_params", {})
        if isinstance(service_override_params, dict):
            transformer.with_overrides(service_override_params)
        service_remove_params = service_config.get("remove_params", [])
        if isinstance(service_remove_params, list):
            transformer.removing(service_remove_params)
        service_rename_params = service_config.get("rename_params", {})
        if isinstance(service_rename_params, dict):
            transformer.renaming(service_rename_params)

        # Pattern level: all overrides, then all removals (merged), then all renames.
        for pattern_config in matching_patterns:
            if "override_params" in pattern_config:
                transformer.with_overrides(pattern_config["override_params"])
        pattern_removals = [
            name
            for pattern_config in matching_patterns
            for name in pattern_config.get("remove_params", [])
        ]
        if pattern_removals:
            transformer.removing(pattern_removals)
        for pattern_config in matching_patterns:
            if "rename_params" in pattern_config:
                transformer.renaming(pattern_config["rename_params"])

        # Model level transformations (highest precedence for overrides/removals/renames)
        transformer.with_overrides(mapping.get("override_params", {}))
        transformer.removing(mapping.get("remove_params", []))
        transformer.renaming(mapping.get("rename_params", {}))

        # Final transformations: provider-side model id, then ${VAR} substitution.
        transformer.with_model_id(service_model_id)
        transformer.with_env_vars()

        api_format = (
            mapping.get("api_format")
            or service_config.get("api_format")
            or service_provider
        )

        # Applies all chained transformations in order
        transformed_params = transformer.transform(input_params)

        return api_format, transformed_params