Coverage for packages/transform/src/langgate/transform/local.py: 16%

128 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-08-29 16:00 +0000

1"""Local transformer client implementation.""" 

2 

import importlib.resources
import os
from pathlib import Path
from typing import Any

from dotenv import load_dotenv
from pydantic import SecretStr

from langgate.core.logging import get_logger
from langgate.core.models import Modality
from langgate.core.schemas.config import ConfigSchema, ModelConfig
from langgate.core.utils.config_utils import load_yaml_config, resolve_path
from langgate.transform.protocol import TransformerClientProtocol
from langgate.transform.transformer import ParamTransformer

# Module-level logger used by LocalTransformerClient; it is also passed to
# the config helpers resolve_path() and load_yaml_config().
logger = get_logger(__name__)

19 

20 

class LocalTransformerClient(TransformerClientProtocol):
    """
    Local transformer client for parameter transformations.

    This client handles parameter transformations based on local configuration:
    a YAML file validated against ``ConfigSchema`` plus an optional env file
    loaded with ``load_dotenv``.

    The class is a singleton: every instantiation returns the same object and
    only the first successful ``__init__`` reads configuration, so arguments
    passed to later instantiations are ignored.
    """

    # The shared instance handed out by ``__new__`` (None until first use).
    _instance = None
    # Flipped to True at the end of the first successful ``__init__``.
    _initialized: bool = False

    def __new__(cls, *args, **kwargs):
        """Create or return the singleton instance."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self, config_path: Path | None = None, env_file_path: Path | None = None
    ):
        """Initialize the client.

        Does nothing if the singleton has already been initialized. Otherwise
        both paths are resolved through ``resolve_path`` from an environment
        variable name (``LANGGATE_CONFIG`` / ``LANGGATE_ENV_FILE``), the
        explicit argument, and a fallback; the precedence between those
        sources is defined by ``resolve_path``. The fallbacks are
        ``./langgate_config.yaml`` if it exists (otherwise the
        ``data/default_config.yaml`` bundled with ``langgate.core``) and
        ``./.env``.

        Args:
            config_path: Path to the configuration file
            env_file_path: Path to the environment file
        """
        # The class-level default is False, so this is only True once a
        # previous __init__ call has completed successfully.
        if self._initialized:
            return

        # Set up default paths
        cwd = Path.cwd()
        core_resources = importlib.resources.files("langgate.core")
        default_config_path = Path(
            str(core_resources.joinpath("data", "default_config.yaml"))
        )
        cwd_config_path = cwd / "langgate_config.yaml"
        cwd_env_path = cwd / ".env"

        self.config_path = resolve_path(
            "LANGGATE_CONFIG",
            config_path,
            cwd_config_path if cwd_config_path.exists() else default_config_path,
            "config_path",
            logger,
        )
        self.env_file_path = resolve_path(
            "LANGGATE_ENV_FILE",
            env_file_path,
            cwd_env_path,
            "env_file_path",
            logger,
        )

        # Caches populated by _load_config()
        self._global_config: dict[str, Any] = {}
        self._service_config: dict[str, dict[str, Any]] = {}
        self._model_mappings: dict[str, dict[str, Any]] = {}

        self._load_config()
        self._initialized = True
        logger.debug("initialized_local_transformer_client")

    def _load_config(self) -> None:
        """Load configuration from file.

        Parses ``self.config_path`` into the global, service and model caches,
        then loads ``self.env_file_path`` into the process environment if that
        file exists.
        """
        config = load_yaml_config(self.config_path, ConfigSchema, logger)

        self._global_config = {"default_params": config.default_params}
        self._service_config = {
            name: service.model_dump(exclude_none=True)
            for name, service in config.services.items()
        }
        self._process_model_mappings(config.models)

        if self.env_file_path.exists():
            load_dotenv(self.env_file_path)
            logger.debug("loaded_transformer_env_file", path=str(self.env_file_path))

    def _process_model_mappings(
        self, models_config: dict[str, list[ModelConfig]]
    ) -> None:
        """Process model mappings from validated configuration.

        Rebuilds ``self._model_mappings`` keyed by model id. If the same id
        appears more than once, the last occurrence wins.

        Args:
            models_config: Dict of modality to list of model configurations
        """
        self._model_mappings = {}

        for modality, model_list in models_config.items():
            for model_config in model_list:
                # exclude_none drops unset optional fields, so the optional
                # transformation settings are read with empty fallbacks
                # (matching how get_params consumes them).
                model_data = model_config.model_dump(exclude_none=True)
                model_id = model_data["id"]
                service = model_data["service"]

                self._model_mappings[model_id] = {
                    "modality": modality,
                    "service_provider": service["provider"],
                    "service_model_id": service["model_id"],
                    "api_format": model_data.get("api_format"),
                    "default_params": model_data.get("default_params", {}),
                    "override_params": model_data.get("override_params", {}),
                    "remove_params": model_data.get("remove_params", []),
                    "rename_params": model_data.get("rename_params", {}),
                }

    async def get_params(
        self, model_id: str, input_params: dict[str, Any]
    ) -> tuple[str, dict[str, Any]]:
        """
        Get transformed parameters for the specified model.

        Parameter precedence and transformation order:

        Defaults (applied only if key doesn't exist yet, so the first default
        registered for a key wins):
        1. Model-specific defaults (highest precedence for defaults)
        2. Pattern defaults (when several patterns match, a pattern that
           appears later in the config takes precedence, mirroring overrides)
        3. Service provider defaults (flat, or grouped by modality)
        4. Global defaults for the model's modality (lowest precedence)

        Overrides/Removals/Renames (applied in order, later steps overwrite/modify earlier ones):
        1. Input parameters (initial state)
        2. Service-level API keys and base URLs
        3. Service-level overrides, removals, renames
        4. Pattern-level overrides, removals, renames (matching patterns applied
           in config order: all pattern overrides, then the merged removals,
           then all pattern renames)
        5. Model-specific overrides, removals, renames (highest precedence)
        6. Model ID (always overwritten with service_model_id)
        7. Environment variable substitution (applied last to all string values)

        A service pattern matches when it occurs as a substring of ``model_id``.
        The returned API format is the model's ``api_format``, else the
        service's ``api_format``, else the service provider name.

        Args:
            model_id: The ID of the model to get transformed parameters for
            input_params: The parameters to transform

        Returns:
            A tuple containing (api_format, transformed_parameters)

        Raises:
            ValueError: If the model is not found in the configuration.
        """
        if model_id not in self._model_mappings:
            logger.error("model_not_found", model_id=model_id)
            raise ValueError(f"Model '{model_id}' not found in configuration")

        mapping = self._model_mappings[model_id]
        service_provider = mapping["service_provider"]
        service_model_id = mapping["service_model_id"]
        service_config = self._service_config.get(service_provider, {})

        transformer = ParamTransformer()

        (
            pattern_defaults,
            pattern_overrides,
            pattern_removals,
            pattern_renames,
        ) = self._collect_pattern_configs(model_id, service_config)

        # Defaults, registered from highest to lowest precedence
        self._apply_defaults(transformer, mapping, service_config, pattern_defaults)

        # Overrides, removals, renames: service -> pattern -> model precedence
        self._apply_service_transformations(
            transformer, service_config, service_provider
        )

        for overrides in pattern_overrides:
            transformer.with_overrides(overrides)
        if pattern_removals:
            transformer.removing(pattern_removals)
        for renames in pattern_renames:
            transformer.renaming(renames)

        transformer.with_overrides(mapping.get("override_params", {}))
        transformer.removing(mapping.get("remove_params", []))
        transformer.renaming(mapping.get("rename_params", {}))

        # Final steps: send the provider's model id, then substitute env vars
        # in string values.
        transformer.with_model_id(service_model_id)
        transformer.with_env_vars()

        # Fallback hierarchy: model.api_format -> service.api_format -> provider
        api_format = (
            mapping.get("api_format")
            or service_config.get("api_format")
            or service_provider
        )

        # Execute all chained transformations in registration order
        transformed_params = transformer.transform(input_params)

        return api_format, transformed_params

    @staticmethod
    def _collect_pattern_configs(
        model_id: str, service_config: dict[str, Any]
    ) -> tuple[list[Any], list[Any], list[Any], list[Any]]:
        """Gather settings from every service pattern that matches ``model_id``.

        A pattern matches when it occurs as a substring of ``model_id``; matches
        are collected in the iteration order of ``model_patterns``.

        Returns:
            ``(defaults, overrides, removals, renames)``. ``removals`` is one
            flat list merged across all matching patterns; the other three hold
            one entry per matching pattern that defines that setting.
        """
        defaults: list[Any] = []
        overrides: list[Any] = []
        removals: list[Any] = []
        renames: list[Any] = []

        for pattern, pattern_config in service_config.get("model_patterns", {}).items():
            if pattern not in model_id:
                continue
            if "default_params" in pattern_config:
                defaults.append(pattern_config["default_params"])
            if "override_params" in pattern_config:
                overrides.append(pattern_config["override_params"])
            if "remove_params" in pattern_config:
                removals.extend(pattern_config["remove_params"])
            if "rename_params" in pattern_config:
                renames.append(pattern_config["rename_params"])

        return defaults, overrides, removals, renames

    def _apply_defaults(
        self,
        transformer: ParamTransformer,
        mapping: dict[str, Any],
        service_config: dict[str, Any],
        pattern_defaults: list[Any],
    ) -> None:
        """Register defaults on ``transformer`` from highest to lowest precedence.

        Order: model defaults, matching pattern defaults (last matching pattern
        first), service defaults, then the global defaults for the model's
        modality.
        """
        transformer.with_defaults(mapping.get("default_params", {}))
        for defaults in reversed(pattern_defaults):
            transformer.with_defaults(defaults)

        model_modality = mapping.get("modality")

        service_defaults = service_config.get("default_params", {})
        if service_defaults:
            # Service defaults are either flat ({param: value}) or grouped by
            # modality ({modality: {param: value}}); any key equal to a
            # Modality value marks the grouped form.
            modality_values = {m.value for m in Modality}
            if any(key in modality_values for key in service_defaults):
                if model_modality:
                    modality_defaults = service_defaults.get(model_modality, {})
                    if modality_defaults:
                        transformer.with_defaults(modality_defaults)
            else:
                transformer.with_defaults(service_defaults)

        # Global defaults are only looked up by modality. ``or {}`` guards
        # against a config whose default_params is None.
        global_defaults = self._global_config.get("default_params") or {}
        transformer.with_defaults(
            global_defaults.get(model_modality, {}) if model_modality else {}
        )

    @staticmethod
    def _apply_service_transformations(
        transformer: ParamTransformer,
        service_config: dict[str, Any],
        service_provider: str,
    ) -> None:
        """Register service-level API key, base URL, overrides, removals, renames."""
        api_key_val = service_config.get("api_key")
        if (
            api_key_val
            and isinstance(api_key_val, str)
            and api_key_val.startswith("${")
            and api_key_val.endswith("}")
        ):
            # "${VAR}" placeholder: resolve it now and wrap it in SecretStr
            env_var = api_key_val[2:-1]
            env_val = os.environ.get(env_var)
            if env_val is not None:
                transformer.with_overrides({"api_key": SecretStr(env_val)})
            else:
                logger.warning(
                    "api_key_env_var_not_found",
                    variable=env_var,
                    service=service_provider,
                )
        elif api_key_val:
            logger.warning(
                "api_key_set_directly_in_config",
                service=service_provider,
                msg="API key set directly in config, not recommended",
            )
            # NOTE(review): unlike the env-var branch, this value is passed
            # through as-is rather than wrapped in SecretStr; confirm that
            # downstream consumers accept both forms.
            transformer.with_overrides({"api_key": api_key_val})

        base_url_val = service_config.get("base_url")
        if base_url_val and isinstance(base_url_val, str):
            transformer.with_overrides({"base_url": base_url_val})

        # Only apply settings of the expected container type
        service_override_params = service_config.get("override_params", {})
        if isinstance(service_override_params, dict):
            transformer.with_overrides(service_override_params)
        service_remove_params = service_config.get("remove_params", [])
        if isinstance(service_remove_params, list):
            transformer.removing(service_remove_params)
        service_rename_params = service_config.get("rename_params", {})
        if isinstance(service_rename_params, dict):
            transformer.renaming(service_rename_params)