Coverage for packages/transform/src/langgate/transform/local.py: 18%
118 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-04-09 21:23 +0000
1"""Local transformer client implementation."""
3import importlib.resources
4import os
5from pathlib import Path
6from typing import Any
8from dotenv import load_dotenv
9from pydantic import SecretStr
11from langgate.core.logging import get_logger
12from langgate.core.schemas.config import ConfigSchema, ModelConfig
13from langgate.core.utils.config_utils import load_yaml_config, resolve_path
14from langgate.transform.protocol import TransformerClientProtocol
15from langgate.transform.transformer import ParamTransformer
# Module-level logger for this module. It is called with an event name plus
# keyword context (e.g. ``logger.error("model_not_found", model_id=...)``) and
# also via ``await logger.awarning(...)``, so presumably a structlog-style
# logger returned by ``langgate.core.logging.get_logger`` — TODO confirm.
logger = get_logger(__name__)
class LocalTransformerClient(TransformerClientProtocol):
    """
    Local transformer client for parameter transformations.

    This client handles parameter transformations based on local configuration.

    The class is a singleton: every instantiation returns the same object, and
    only the first call to ``__init__`` resolves paths and loads configuration.
    Arguments passed to later instantiations are ignored.
    """

    # Shared singleton instance, created lazily by ``__new__``.
    _instance: "LocalTransformerClient | None" = None
    # Flipped to True (on the instance) once ``__init__`` has loaded config, so
    # re-instantiating the singleton does not reload it.
    _initialized: bool = False

    def __new__(cls, *args, **kwargs):
        """Create or return the singleton instance.

        Constructor arguments are accepted so that ``__init__`` receives them,
        but they are not used here.
        """
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self, config_path: Path | None = None, env_file_path: Path | None = None
    ):
        """Initialize the client.

        Each path is resolved by ``resolve_path`` from an environment variable
        (``LANGGATE_CONFIG`` / ``LANGGATE_ENV_FILE``), the explicit argument,
        and a fallback. The config fallback is ``./langgate_config.yaml`` when
        that file exists, otherwise ``data/default_config.yaml`` bundled with
        the ``langgate.core`` package. The env-file fallback is ``./.env``.

        Args:
            config_path: Path to the configuration file
            env_file_path: Path to the environment file
        """
        # Singleton: only the first construction loads configuration.
        if self._initialized:
            return

        # Set up default paths
        cwd = Path.cwd()

        core_resources = importlib.resources.files("langgate.core")
        default_config_path = Path(
            str(core_resources.joinpath("data", "default_config.yaml"))
        )
        cwd_config_path = cwd / "langgate_config.yaml"
        cwd_env_path = cwd / ".env"

        self.config_path = resolve_path(
            "LANGGATE_CONFIG",
            config_path,
            cwd_config_path if cwd_config_path.exists() else default_config_path,
            "config_path",
            logger,
        )
        self.env_file_path = resolve_path(
            "LANGGATE_ENV_FILE",
            env_file_path,
            cwd_env_path,
            "env_file_path",
            logger,
        )

        # Cache for configs
        self._global_config: dict[str, Any] = {}
        self._service_config: dict[str, dict[str, Any]] = {}
        self._model_mappings: dict[str, dict[str, Any]] = {}

        # Load configuration
        self._load_config()
        self._initialized = True
        logger.debug("initialized_local_transformer_client")

    def _load_config(self) -> None:
        """Load configuration from file, then the optional ``.env`` file.

        When ``load_yaml_config`` returns a falsy value (presumably a missing
        or invalid config file — see that helper), an empty configuration is
        installed instead of raising.
        """
        config = load_yaml_config(self.config_path, ConfigSchema, logger)

        if config:
            # Extract validated data
            self._global_config = {"default_params": config.default_params}
            self._service_config = {
                k: v.model_dump(exclude_none=True) for k, v in config.services.items()
            }
            self._process_model_mappings(config.models)
        else:
            self._set_empty_config()

        # Load environment variables from .env file if it exists
        if self.env_file_path.exists():
            load_dotenv(self.env_file_path)
            logger.debug("loaded_transformer_env_file", path=str(self.env_file_path))

    def _set_empty_config(self) -> None:
        """Set empty/default config state, typically used in error scenarios."""
        self._global_config = {"default_params": {}}
        self._service_config = {}
        self._model_mappings = {}

    def _process_model_mappings(self, models_config: list[ModelConfig]) -> None:
        """Process model mappings from validated configuration.

        Builds ``self._model_mappings`` keyed by model ``id``. Each entry holds
        the provider, the provider-side model id, and the model's
        default/override/remove/rename parameter settings.

        Args:
            models_config: List of validated model configurations
        """
        self._model_mappings = {}

        for model_config in models_config:
            model_data = model_config.model_dump(exclude_none=True)
            service = model_data["service"]

            # ``exclude_none=True`` drops unset optional fields, so fall back to
            # empty containers instead of raising KeyError.
            self._model_mappings[model_data["id"]] = {
                "service_provider": service["provider"],
                "service_model_id": service["model_id"],
                "default_params": model_data.get("default_params", {}),
                "override_params": model_data.get("override_params", {}),
                "remove_params": model_data.get("remove_params", []),
                "rename_params": model_data.get("rename_params", {}),
            }

    @staticmethod
    def _collect_pattern_configs(
        service_config: dict[str, Any], model_id: str
    ) -> tuple[list[Any], list[Any], list[Any], list[Any]]:
        """Gather settings from every service model pattern matching ``model_id``.

        A pattern matches when it occurs as a substring of ``model_id``.
        Results keep the order in which patterns appear in the service config.

        Args:
            service_config: The dumped service configuration (may be empty).
            model_id: The requested model ID.

        Returns:
            ``(defaults, overrides, removals, renames)``. Defaults, overrides
            and renames are lists of per-pattern mappings. Removals are
            flattened into a single list.
        """
        defaults: list[Any] = []
        overrides: list[Any] = []
        removals: list[Any] = []
        renames: list[Any] = []

        # ``model_patterns`` is optional. The provider may be missing from the
        # config (service_config == {}), or the key may have been dropped by
        # ``exclude_none``.
        patterns = service_config.get("model_patterns") or {}
        for pattern, pattern_config in patterns.items():
            if pattern not in model_id:
                continue
            if "default_params" in pattern_config:
                defaults.append(pattern_config["default_params"])
            if "override_params" in pattern_config:
                overrides.append(pattern_config["override_params"])
            if "remove_params" in pattern_config:
                removals.extend(pattern_config["remove_params"])
            if "rename_params" in pattern_config:
                renames.append(pattern_config["rename_params"])

        return defaults, overrides, removals, renames

    @staticmethod
    async def _apply_service_credentials(
        transformer: ParamTransformer,
        service_config: dict[str, Any],
        service_provider: str,
    ) -> None:
        """Queue the service's ``api_key`` and ``base_url`` overrides.

        An ``api_key`` of the form ``${VAR}`` is read from the environment and
        wrapped in ``SecretStr``. If ``VAR`` is unset, a warning is logged and
        no key is added. Any other truthy ``api_key`` is used verbatim, with a
        warning. A string ``base_url`` is passed through unchanged.
        """
        api_key_val = service_config.get("api_key")
        if (
            api_key_val
            and isinstance(api_key_val, str)
            and api_key_val.startswith("${")
            and api_key_val.endswith("}")
        ):
            env_var = api_key_val[2:-1]
            env_val = os.environ.get(env_var)
            if env_val is not None:
                transformer.with_overrides({"api_key": SecretStr(env_val)})
            else:
                logger.warning(
                    "api_key_env_var_not_found",
                    variable=env_var,
                    service=service_provider,
                )
        elif api_key_val:
            # Handle non-env var API keys if needed (though config should use env vars)
            await logger.awarning(
                "api_key_set_directly_in_config",
                service=service_provider,
                msg="API key set directly in config, not recommended",
            )
            transformer.with_overrides({"api_key": api_key_val})

        base_url_val = service_config.get("base_url")
        if base_url_val and isinstance(base_url_val, str):
            transformer.with_overrides({"base_url": base_url_val})

    async def get_params(
        self, model_id: str, input_params: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Get transformed parameters for the specified model.

        Parameter precedence and transformation order:

        Defaults (applied only if key doesn't exist yet):
        1. Model-specific defaults (highest precedence for defaults)
        2. Pattern defaults (matching patterns applied in config order)
        3. Service provider defaults
        4. Global defaults (lowest precedence for defaults)

        Overrides/Removals/Renames (applied in order, later steps overwrite/modify earlier ones):
        1. Input parameters (initial state)
        2. Service-level API keys and base URLs
        3. Service-level overrides, removals, renames
        4. Pattern-level overrides, removals, renames (matching patterns applied in config order)
        5. Model-specific overrides, removals, renames (highest precedence)
        6. Model ID (always overwritten with service_model_id)
        7. Environment variable substitution (applied last to all string values)

        A service model pattern "matches" when it is a substring of
        ``model_id``. A service without ``model_patterns``, or a provider
        absent from the config, simply contributes no pattern rules.

        Args:
            model_id: The ID of the model to get transformed parameters for
            input_params: The parameters to transform

        Returns:
            The transformed parameters

        Raises:
            ValueError: If the model is not found in the configuration.
        """
        if model_id not in self._model_mappings:
            logger.error("model_not_found", model_id=model_id)
            raise ValueError(f"Model '{model_id}' not found in configuration")

        mapping = self._model_mappings[model_id]
        service_provider = mapping["service_provider"]
        service_model_id = mapping["service_model_id"]
        service_config = self._service_config.get(service_provider, {})

        transformer = ParamTransformer()

        (
            pattern_defaults,
            pattern_overrides,
            pattern_removals,
            pattern_renames,
        ) = self._collect_pattern_configs(service_config, model_id)

        # Apply defaults (highest to lowest precedence). A default only fills a
        # missing key, so the first writer wins. Patterns are therefore applied
        # in reverse, which lets later patterns win, mirroring override order.
        transformer.with_defaults(mapping.get("default_params", {}))
        for defaults in reversed(pattern_defaults):
            transformer.with_defaults(defaults)
        transformer.with_defaults(service_config.get("default_params", {}))
        transformer.with_defaults(self._global_config.get("default_params", {}))

        # Apply overrides, removals, renames (service -> pattern -> model precedence)

        # Service level: credentials first, then configured transformations.
        await self._apply_service_credentials(
            transformer, service_config, service_provider
        )

        service_override_params = service_config.get("override_params", {})
        if isinstance(service_override_params, dict):
            transformer.with_overrides(service_override_params)
        service_remove_params = service_config.get("remove_params", [])
        if isinstance(service_remove_params, list):
            transformer.removing(service_remove_params)
        service_rename_params = service_config.get("rename_params", {})
        if isinstance(service_rename_params, dict):
            transformer.renaming(service_rename_params)

        # Pattern level transformations (apply collected configs)
        for overrides in pattern_overrides:
            transformer.with_overrides(overrides)
        if pattern_removals:
            transformer.removing(pattern_removals)
        for renames in pattern_renames:
            transformer.renaming(renames)

        # Model level transformations (highest precedence for overrides/removals/renames)
        transformer.with_overrides(mapping.get("override_params", {}))
        transformer.removing(mapping.get("remove_params", []))
        transformer.renaming(mapping.get("rename_params", {}))

        # Final transformations
        # Set the final model ID to be sent to the service provider
        transformer.with_model_id(service_model_id)

        # Substitute environment variables in string values
        transformer.with_env_vars()

        # Execute Transformation
        # Applies all chained transformations in order
        return transformer.transform(input_params)