Coverage for packages/transform/src/langgate/transform/local.py: 17%
115 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-06-18 14:26 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-06-18 14:26 +0000
1"""Local transformer client implementation."""
3import importlib.resources
4import os
5from pathlib import Path
6from typing import Any
8from dotenv import load_dotenv
9from pydantic import SecretStr
11from langgate.core.logging import get_logger
12from langgate.core.schemas.config import ConfigSchema, ModelConfig
13from langgate.core.utils.config_utils import load_yaml_config, resolve_path
14from langgate.transform.protocol import TransformerClientProtocol
15from langgate.transform.transformer import ParamTransformer
# Module-level logger used for all log output of LocalTransformerClient below.
logger = get_logger(__name__)
class LocalTransformerClient(TransformerClientProtocol):
    """
    Local transformer client for parameter transformations.

    This client handles parameter transformations based on local configuration.

    The class is a singleton: every construction returns the same instance,
    and only the first construction loads configuration. Arguments passed to
    later constructions are ignored.
    """

    _instance = None
    _initialized: bool = False

    def __new__(cls, *args, **kwargs):
        """Create or return the singleton instance."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self, config_path: Path | None = None, env_file_path: Path | None = None
    ):
        """Initialize the client (only on the first construction).

        Args:
            config_path: Path to the configuration file. Resolved through
                ``resolve_path`` together with the ``LANGGATE_CONFIG``
                environment variable; the fallback is
                ``./langgate_config.yaml`` when it exists, otherwise the
                ``data/default_config.yaml`` shipped with ``langgate.core``.
            env_file_path: Path to the environment file. Resolved through
                ``resolve_path`` together with ``LANGGATE_ENV_FILE``; the
                fallback is ``./.env``.
        """
        # ``_initialized`` is declared on the class, so it always exists; once
        # the shared instance is set up, later constructions are no-ops.
        if self._initialized:
            return

        # Set up default paths
        cwd = Path.cwd()

        core_resources = importlib.resources.files("langgate.core")
        default_config_path = Path(
            str(core_resources.joinpath("data", "default_config.yaml"))
        )
        cwd_config_path = cwd / "langgate_config.yaml"
        cwd_env_path = cwd / ".env"

        self.config_path = resolve_path(
            "LANGGATE_CONFIG",
            config_path,
            cwd_config_path if cwd_config_path.exists() else default_config_path,
            "config_path",
            logger,
        )
        self.env_file_path = resolve_path(
            "LANGGATE_ENV_FILE",
            env_file_path,
            cwd_env_path,
            "env_file_path",
            logger,
        )

        # Cache for configs
        self._global_config: dict[str, Any] = {}
        self._service_config: dict[str, dict[str, Any]] = {}
        self._model_mappings: dict[str, dict[str, Any]] = {}

        # Load configuration
        self._load_config()
        self._initialized = True
        logger.debug("initialized_local_transformer_client")

    def _load_config(self) -> None:
        """Load and cache configuration, then load the optional .env file.

        The YAML file at ``self.config_path`` is validated against
        ``ConfigSchema``; global defaults, per-service settings and per-model
        mappings are cached for use by ``get_params``.
        """
        config = load_yaml_config(self.config_path, ConfigSchema, logger)

        # Extract validated data. A null top-level ``default_params`` falls
        # back to an empty dict so the transformer always receives a mapping.
        self._global_config = {"default_params": config.default_params or {}}
        self._service_config = {
            k: v.model_dump(exclude_none=True) for k, v in config.services.items()
        }
        self._process_model_mappings(config.models)

        # Load environment variables from .env file if it exists. ``${VAR}``
        # references in the config are resolved later, inside get_params.
        if self.env_file_path.exists():
            load_dotenv(self.env_file_path)
            logger.debug("loaded_transformer_env_file", path=str(self.env_file_path))

    def _process_model_mappings(self, models_config: list[ModelConfig]) -> None:
        """Process model mappings from validated configuration.

        Rebuilds ``self._model_mappings`` keyed by each model's ``id``.

        Args:
            models_config: List of validated model configurations
        """
        self._model_mappings = {}

        for model_config in models_config:
            model_data = model_config.model_dump(exclude_none=True)
            service = model_data["service"]

            # ``exclude_none`` drops unset optional fields, so fall back to
            # empty containers instead of raising KeyError for them.
            self._model_mappings[model_data["id"]] = {
                "service_provider": service["provider"],
                "service_model_id": service["model_id"],
                "api_format": model_data.get("api_format"),
                "default_params": model_data.get("default_params", {}),
                "override_params": model_data.get("override_params", {}),
                "remove_params": model_data.get("remove_params", []),
                "rename_params": model_data.get("rename_params", {}),
            }

    @staticmethod
    def _collect_pattern_configs(
        service_config: dict[str, Any], model_id: str
    ) -> tuple[
        list[dict[str, Any]], list[dict[str, Any]], list[Any], list[dict[str, Any]]
    ]:
        """Gather settings from every service model pattern matching ``model_id``.

        A pattern matches when it is a substring of ``model_id``. Results keep
        the iteration order of the service's ``model_patterns`` mapping.

        Args:
            service_config: The service's dumped configuration dict.
            model_id: The requested model ID.

        Returns:
            A tuple ``(defaults, overrides, removals, renames)``; ``removals``
            is flattened into a single list across all matching patterns.
        """
        defaults: list[dict[str, Any]] = []
        overrides: list[dict[str, Any]] = []
        removals: list[Any] = []
        renames: list[dict[str, Any]] = []

        model_patterns = service_config.get("model_patterns", {})
        for pattern, pattern_config in model_patterns.items():
            if pattern not in model_id:
                continue
            if "default_params" in pattern_config:
                defaults.append(pattern_config["default_params"])
            if "override_params" in pattern_config:
                overrides.append(pattern_config["override_params"])
            if "remove_params" in pattern_config:
                removals.extend(pattern_config["remove_params"])
            if "rename_params" in pattern_config:
                renames.append(pattern_config["rename_params"])

        return defaults, overrides, removals, renames

    @staticmethod
    async def _apply_service_connection(
        transformer: ParamTransformer,
        service_config: dict[str, Any],
        service_provider: str,
    ) -> None:
        """Register the service's ``api_key`` and ``base_url`` as overrides.

        An API key of the form ``${VAR}`` is read from the environment and
        wrapped in ``SecretStr``; if ``VAR`` is unset, a warning is logged and
        no key is registered. Any other truthy API key is registered as-is,
        with a warning that keys should not be set directly in config.

        Args:
            transformer: The transformer to register overrides on.
            service_config: The service's dumped configuration dict.
            service_provider: Service name, used only for log context.
        """
        api_key_val = service_config.get("api_key")
        # The service schema may hold the key as a plain string or as a
        # SecretStr; unwrap it so a "${VAR}" reference is detected either way.
        raw_api_key = (
            api_key_val.get_secret_value()
            if isinstance(api_key_val, SecretStr)
            else api_key_val
        )
        if (
            isinstance(raw_api_key, str)
            and raw_api_key.startswith("${")
            and raw_api_key.endswith("}")
        ):
            env_var = raw_api_key[2:-1]
            env_val = os.environ.get(env_var)
            if env_val is not None:
                transformer.with_overrides({"api_key": SecretStr(env_val)})
            else:
                await logger.awarning(
                    "api_key_env_var_not_found",
                    variable=env_var,
                    service=service_provider,
                )
        elif api_key_val:
            # Handle non-env var API keys if needed (though config should use env vars)
            await logger.awarning(
                "api_key_set_directly_in_config",
                service=service_provider,
                msg="API key set directly in config, not recommended",
            )
            transformer.with_overrides({"api_key": api_key_val})

        base_url_val = service_config.get("base_url")
        if base_url_val and isinstance(base_url_val, str):
            transformer.with_overrides({"base_url": base_url_val})

    async def get_params(
        self, model_id: str, input_params: dict[str, Any]
    ) -> tuple[str, dict[str, Any]]:
        """
        Get transformed parameters for the specified model.

        Parameter precedence and transformation order:

        Defaults (applied only if key doesn't exist yet):
        1. Model-specific defaults (highest precedence for defaults)
        2. Pattern defaults (later matching patterns take precedence over
           earlier ones)
        3. Service provider defaults
        4. Global defaults (lowest precedence for defaults)

        Overrides/Removals/Renames (applied in order, later steps overwrite/modify earlier ones):
        1. Input parameters (initial state)
        2. Service-level API keys and base URLs
        3. Service-level overrides, removals, renames
        4. Pattern-level overrides, removals, renames (matching patterns
           applied in config order, so later patterns win)
        5. Model-specific overrides, removals, renames (highest precedence)
        6. Model ID (always overwritten with service_model_id)
        7. Environment variable substitution (applied last to all string values)

        A service model pattern matches when it is a substring of ``model_id``.

        Args:
            model_id: The ID of the model to get transformed parameters for
            input_params: The parameters to transform

        Returns:
            A tuple containing (api_format, transformed_parameters)

        Raises:
            ValueError: If the model is not found in the configuration.
        """
        if model_id not in self._model_mappings:
            logger.error("model_not_found", model_id=model_id)
            raise ValueError(f"Model '{model_id}' not found in configuration")

        mapping = self._model_mappings[model_id]
        service_provider = mapping["service_provider"]
        service_model_id = mapping["service_model_id"]
        service_config = self._service_config.get(service_provider, {})

        (
            pattern_defaults,
            pattern_overrides,
            pattern_removals,
            pattern_renames,
        ) = self._collect_pattern_configs(service_config, model_id)

        transformer = ParamTransformer()

        # Defaults are registered from highest to lowest precedence, because
        # a default only applies when the key is still absent. Reversing the
        # pattern list lets later-matching patterns win, mirroring overrides.
        transformer.with_defaults(mapping.get("default_params", {}))
        for defaults in reversed(pattern_defaults):
            transformer.with_defaults(defaults)
        transformer.with_defaults(service_config.get("default_params", {}))
        transformer.with_defaults(self._global_config.get("default_params", {}))

        # Service level: connection settings first, then overrides/removals/renames.
        await self._apply_service_connection(
            transformer, service_config, service_provider
        )

        service_override_params = service_config.get("override_params", {})
        if isinstance(service_override_params, dict):
            transformer.with_overrides(service_override_params)
        service_remove_params = service_config.get("remove_params", [])
        if isinstance(service_remove_params, list):
            transformer.removing(service_remove_params)
        service_rename_params = service_config.get("rename_params", {})
        if isinstance(service_rename_params, dict):
            transformer.renaming(service_rename_params)

        # Pattern level transformations (apply collected configs)
        for overrides in pattern_overrides:
            transformer.with_overrides(overrides)
        if pattern_removals:
            transformer.removing(pattern_removals)
        for renames in pattern_renames:
            transformer.renaming(renames)

        # Model level transformations (highest precedence for overrides/removals/renames)
        transformer.with_overrides(mapping.get("override_params", {}))
        transformer.removing(mapping.get("remove_params", []))
        transformer.renaming(mapping.get("rename_params", {}))

        # Set the final model ID to be sent to the service provider
        transformer.with_model_id(service_model_id)

        # Substitute environment variables in string values
        transformer.with_env_vars()

        # Resolve API format based on fallback hierarchy: model.api_format → service.api_format → service.provider
        api_format = (
            mapping.get("api_format")
            or service_config.get("api_format")
            or service_provider
        )

        # Applies all chained transformations in registration order
        transformed_params = transformer.transform(input_params)

        return api_format, transformed_params