Coverage for packages/transform/src/langgate/transform/local.py: 17%

115 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-06-18 14:26 +0000

1"""Local transformer client implementation.""" 

2 

3import importlib.resources 

4import os 

5from pathlib import Path 

6from typing import Any 

7 

8from dotenv import load_dotenv 

9from pydantic import SecretStr 

10 

11from langgate.core.logging import get_logger 

12from langgate.core.schemas.config import ConfigSchema, ModelConfig 

13from langgate.core.utils.config_utils import load_yaml_config, resolve_path 

14from langgate.transform.protocol import TransformerClientProtocol 

15from langgate.transform.transformer import ParamTransformer 

16 

17logger = get_logger(__name__) 

18 

19 

class LocalTransformerClient(TransformerClientProtocol):
    """
    Local transformer client for parameter transformations.

    This client handles parameter transformations based on local configuration.

    The class is a singleton: ``__new__`` always hands back the same instance
    and ``__init__`` only does its work the first time it runs, so arguments
    passed to later constructor calls are ignored.
    """

    # Shared singleton instance, created lazily by ``__new__``.
    _instance = None
    # Class-level default; shadowed by an instance attribute set to True once
    # ``__init__`` has completed.
    _initialized: bool = False

    def __new__(cls, *args, **kwargs):
        """Create or return the singleton instance."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self, config_path: Path | None = None, env_file_path: Path | None = None
    ):
        """Initialize the client.

        Initialization runs once per singleton instance; subsequent calls
        return immediately without touching the stored paths or caches.

        Args:
            config_path: Path to the configuration file
            env_file_path: Path to the environment file
        """
        # ``_initialized`` always resolves (class attribute defaults to False),
        # so no ``hasattr`` guard is needed.
        if self._initialized:
            return

        # Set up default paths
        cwd = Path.cwd()

        core_resources = importlib.resources.files("langgate.core")
        default_config_path = Path(
            str(core_resources.joinpath("data", "default_config.yaml"))
        )
        cwd_config_path = cwd / "langgate_config.yaml"
        cwd_env_path = cwd / ".env"

        # A config file in the working directory is preferred over the
        # packaged default when it exists.
        self.config_path = resolve_path(
            "LANGGATE_CONFIG",
            config_path,
            cwd_config_path if cwd_config_path.exists() else default_config_path,
            "config_path",
            logger,
        )
        self.env_file_path = resolve_path(
            "LANGGATE_ENV_FILE",
            env_file_path,
            cwd_env_path,
            "env_file_path",
            logger,
        )

        # Cache for configs
        self._global_config: dict[str, Any] = {}
        self._service_config: dict[str, dict[str, Any]] = {}
        self._model_mappings: dict[str, dict[str, Any]] = {}

        # Load configuration
        self._load_config()
        self._initialized = True
        logger.debug("initialized_local_transformer_client")

    def _load_config(self) -> None:
        """Load configuration from file.

        Populates the global, service and model-mapping caches from the
        validated config, then loads the env file if it is present.
        """
        config = load_yaml_config(self.config_path, ConfigSchema, logger)

        # Extract validated data
        self._global_config = {"default_params": config.default_params}
        self._service_config = {
            k: v.model_dump(exclude_none=True) for k, v in config.services.items()
        }
        self._process_model_mappings(config.models)

        # Load environment variables from .env file if it exists
        if self.env_file_path.exists():
            load_dotenv(self.env_file_path)
            logger.debug("loaded_transformer_env_file", path=str(self.env_file_path))

    def _process_model_mappings(self, models_config: list[ModelConfig]) -> None:
        """Process model mappings from validated configuration.

        Rebuilds ``self._model_mappings`` from scratch, keyed by model ID.

        Args:
            models_config: List of validated model configurations
        """
        self._model_mappings = {}

        for model_config in models_config:
            # ``exclude_none=True`` drops keys whose value is None, so the
            # optional parameter sections are read with ``.get`` and an empty
            # default rather than by hard indexing.
            model_data = model_config.model_dump(exclude_none=True)
            model_id = model_data["id"]
            service = model_data["service"]

            # Store mapping info
            self._model_mappings[model_id] = {
                "service_provider": service["provider"],
                "service_model_id": service["model_id"],
                "api_format": model_data.get("api_format"),
                "default_params": model_data.get("default_params", {}),
                "override_params": model_data.get("override_params", {}),
                "remove_params": model_data.get("remove_params", []),
                "rename_params": model_data.get("rename_params", {}),
            }

    async def get_params(
        self, model_id: str, input_params: dict[str, Any]
    ) -> tuple[str, dict[str, Any]]:
        """
        Get transformed parameters for the specified model.

        Parameter precedence and transformation order:

        Defaults (applied only if key doesn't exist yet):
        1. Model-specific defaults (highest precedence for defaults)
        2. Pattern defaults (matching patterns applied in config order)
        3. Service provider defaults
        4. Global defaults (lowest precedence for defaults)

        Overrides/Removals/Renames (applied in order, later steps overwrite/modify earlier ones):
        1. Input parameters (initial state)
        2. Service-level API keys and base URLs
        3. Service-level overrides, removals, renames
        4. Pattern-level overrides, removals, renames (matching patterns applied in config order)
        5. Model-specific overrides, removals, renames (highest precedence)
        6. Model ID (always overwritten with service_model_id)
        7. Environment variable substitution (applied last to all string values)

        Args:
            model_id: The ID of the model to get transformed parameters for
            input_params: The parameters to transform

        Returns:
            A tuple containing (api_format, transformed_parameters)

        Raises:
            ValueError: If the model is not found in the configuration.
        """
        if model_id not in self._model_mappings:
            logger.error("model_not_found", model_id=model_id)
            raise ValueError(f"Model '{model_id}' not found in configuration")

        mapping = self._model_mappings[model_id]
        service_provider = mapping["service_provider"]
        service_model_id = mapping["service_model_id"]
        service_config = self._service_config.get(service_provider, {})

        transformer = ParamTransformer()

        # Collect pattern configurations
        matching_pattern_defaults = []
        matching_pattern_overrides = []
        matching_pattern_removals = []
        matching_pattern_renames = []

        model_patterns = service_config.get("model_patterns", {})
        for pattern, pattern_config in model_patterns.items():
            # A pattern matches when it occurs as a substring of the model ID.
            if pattern in model_id:
                if "default_params" in pattern_config:
                    matching_pattern_defaults.append(pattern_config["default_params"])
                if "override_params" in pattern_config:
                    matching_pattern_overrides.append(pattern_config["override_params"])
                if "remove_params" in pattern_config:
                    matching_pattern_removals.extend(pattern_config["remove_params"])
                if "rename_params" in pattern_config:
                    matching_pattern_renames.append(pattern_config["rename_params"])

        # Apply defaults (highest to lowest precedence)
        transformer.with_defaults(mapping.get("default_params", {}))
        # Pattern defaults are registered in reverse config order.
        for defaults in reversed(matching_pattern_defaults):
            transformer.with_defaults(defaults)
        transformer.with_defaults(service_config.get("default_params", {}))
        transformer.with_defaults(self._global_config.get("default_params", {}))

        # Apply overrides, removals, renames (service -> pattern -> model precedence)

        # Service level transformations
        api_key_val = service_config.get("api_key")
        if (
            api_key_val
            and isinstance(api_key_val, str)
            and api_key_val.startswith("${")
            and api_key_val.endswith("}")
        ):
            # Value is a "${VAR}" placeholder: resolve it from the environment
            # and wrap it so the key is carried as a SecretStr.
            env_var = api_key_val[2:-1]
            env_val = os.environ.get(env_var)
            if env_val is not None:
                transformer.with_overrides({"api_key": SecretStr(env_val)})
            else:
                logger.warning(
                    "api_key_env_var_not_found",
                    variable=env_var,
                    service=service_provider,
                )
        elif api_key_val:
            # Handle non-env var API keys if needed (though config should use env vars)
            await logger.awarning(
                "api_key_set_directly_in_config",
                service=service_provider,
                msg="API key set directly in config, not recommended",
            )
            transformer.with_overrides({"api_key": api_key_val})

        base_url_val = service_config.get("base_url")
        if base_url_val and isinstance(base_url_val, str):
            transformer.with_overrides({"base_url": base_url_val})

        # Service overrides, removals, renames
        service_override_params = service_config.get("override_params", {})
        if isinstance(service_override_params, dict):
            transformer.with_overrides(service_override_params)
        service_remove_params = service_config.get("remove_params", [])
        if isinstance(service_remove_params, list):
            transformer.removing(service_remove_params)
        service_rename_params = service_config.get("rename_params", {})
        if isinstance(service_rename_params, dict):
            transformer.renaming(service_rename_params)

        # Pattern level transformations (apply collected configs)
        for overrides in matching_pattern_overrides:
            transformer.with_overrides(overrides)
        if matching_pattern_removals:
            transformer.removing(matching_pattern_removals)
        for renames in matching_pattern_renames:
            transformer.renaming(renames)

        # Model level transformations (highest precedence for overrides/removals/renames)
        transformer.with_overrides(mapping.get("override_params", {}))
        transformer.removing(mapping.get("remove_params", []))
        transformer.renaming(mapping.get("rename_params", {}))

        # Final transformations
        # Set the final model ID to be sent to the service provider
        transformer.with_model_id(service_model_id)

        # Substitute environment variables in string values
        transformer.with_env_vars()

        # Resolve API format based on fallback hierarchy: model.api_format → service.api_format → service.provider
        api_format = (
            mapping.get("api_format")
            or service_config.get("api_format")
            or service_provider
        )

        # Execute Transformation
        # Applies all chained transformations in order
        transformed_params = transformer.transform(input_params)

        return api_format, transformed_params