Coverage for packages/transform/src/langgate/transform/local.py: 18%

118 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-04-09 21:23 +0000

1"""Local transformer client implementation.""" 

2 

3import importlib.resources 

4import os 

5from pathlib import Path 

6from typing import Any 

7 

8from dotenv import load_dotenv 

9from pydantic import SecretStr 

10 

11from langgate.core.logging import get_logger 

12from langgate.core.schemas.config import ConfigSchema, ModelConfig 

13from langgate.core.utils.config_utils import load_yaml_config, resolve_path 

14from langgate.transform.protocol import TransformerClientProtocol 

15from langgate.transform.transformer import ParamTransformer 

16 

# Module-level logger used by LocalTransformerClient; calls pass an event-name
# string plus keyword context (e.g. ``logger.debug("...", path=...)``).
logger = get_logger(__name__)

18 

19 

class LocalTransformerClient(TransformerClientProtocol):
    """
    Local transformer client for parameter transformations.

    This client handles parameter transformations based on local configuration
    (a YAML config file plus an optional ``.env`` file).

    The class is a singleton: every construction returns the same instance, and
    configuration is only loaded by the first ``__init__`` that completes.
    Paths passed to later constructions are ignored.
    """

    # The single shared instance, created lazily in ``__new__``.
    _instance = None
    # Set to True at the end of ``__init__`` so re-construction is a no-op.
    _initialized: bool = False

    def __new__(cls, *args, **kwargs):
        """Create or return the singleton instance."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self, config_path: Path | None = None, env_file_path: Path | None = None
    ):
        """Initialize the client (only on first construction).

        Args:
            config_path: Path to the configuration file. Resolved by
                ``resolve_path`` from this argument, the ``LANGGATE_CONFIG``
                environment variable, and a fallback of
                ``./langgate_config.yaml`` if it exists, otherwise the packaged
                ``langgate.core`` ``data/default_config.yaml``.
            env_file_path: Path to the environment file. Resolved by
                ``resolve_path`` from this argument, the ``LANGGATE_ENV_FILE``
                environment variable, and a fallback of ``./.env``.
        """
        # Shared singleton: do not re-resolve paths or reload config.
        if self._initialized:
            return

        # Set up default paths
        cwd = Path.cwd()
        core_resources = importlib.resources.files("langgate.core")
        default_config_path = Path(
            str(core_resources.joinpath("data", "default_config.yaml"))
        )
        cwd_config_path = cwd / "langgate_config.yaml"
        cwd_env_path = cwd / ".env"

        self.config_path = resolve_path(
            "LANGGATE_CONFIG",
            config_path,
            cwd_config_path if cwd_config_path.exists() else default_config_path,
            "config_path",
            logger,
        )
        self.env_file_path = resolve_path(
            "LANGGATE_ENV_FILE",
            env_file_path,
            cwd_env_path,
            "env_file_path",
            logger,
        )

        # Parsed configuration caches, populated by ``_load_config``.
        self._global_config: dict[str, Any] = {}
        self._service_config: dict[str, dict[str, Any]] = {}
        self._model_mappings: dict[str, dict[str, Any]] = {}

        self._load_config()
        self._initialized = True
        logger.debug("initialized_local_transformer_client")

    def _load_config(self) -> None:
        """Load and cache configuration, then load the ``.env`` file if present.

        If ``load_yaml_config`` returns a falsy value, the caches are reset to
        an empty configuration via ``_set_empty_config`` instead of raising.
        """
        config = load_yaml_config(self.config_path, ConfigSchema, logger)

        if config:
            self._global_config = {"default_params": config.default_params}
            # exclude_none drops unset optional service fields entirely, so
            # consumers must read them with .get().
            self._service_config = {
                name: service.model_dump(exclude_none=True)
                for name, service in config.services.items()
            }
            self._process_model_mappings(config.models)
        else:
            self._set_empty_config()

        # Load environment variables from .env file if it exists
        if self.env_file_path.exists():
            load_dotenv(self.env_file_path)
            logger.debug("loaded_transformer_env_file", path=str(self.env_file_path))

    def _set_empty_config(self) -> None:
        """Set empty/default config state, typically used in error scenarios."""
        self._global_config = {"default_params": {}}
        self._service_config = {}
        self._model_mappings = {}

    def _process_model_mappings(self, models_config: list[ModelConfig]) -> None:
        """Build the ``model_id -> mapping`` lookup from validated model configs.

        ``model_dump(exclude_none=True)`` omits fields whose value is ``None``,
        so the optional transformation params are read with empty defaults
        rather than indexed directly.

        Args:
            models_config: List of validated model configurations.
        """
        self._model_mappings = {}

        for model_config in models_config:
            model_data = model_config.model_dump(exclude_none=True)
            service = model_data["service"]

            self._model_mappings[model_data["id"]] = {
                "service_provider": service["provider"],
                "service_model_id": service["model_id"],
                "default_params": model_data.get("default_params", {}),
                "override_params": model_data.get("override_params", {}),
                "remove_params": model_data.get("remove_params", []),
                "rename_params": model_data.get("rename_params", {}),
            }

    @staticmethod
    def _collect_pattern_configs(
        model_id: str, model_patterns: dict[str, dict[str, Any]]
    ) -> tuple[
        list[dict[str, Any]], list[dict[str, Any]], list[Any], list[dict[str, Any]]
    ]:
        """Gather params from every service pattern contained in ``model_id``.

        Args:
            model_id: The model ID being transformed.
            model_patterns: The service's ``model_patterns`` mapping of
                pattern string to pattern config.

        Returns:
            ``(defaults, overrides, removals, renames)``, each in config order.
            Removals from all matching patterns are flattened into one list.
        """
        defaults: list[dict[str, Any]] = []
        overrides: list[dict[str, Any]] = []
        removals: list[Any] = []
        renames: list[dict[str, Any]] = []

        for pattern, pattern_config in model_patterns.items():
            # Patterns match by plain substring containment.
            if pattern not in model_id:
                continue
            if "default_params" in pattern_config:
                defaults.append(pattern_config["default_params"])
            if "override_params" in pattern_config:
                overrides.append(pattern_config["override_params"])
            if "remove_params" in pattern_config:
                removals.extend(pattern_config["remove_params"])
            if "rename_params" in pattern_config:
                renames.append(pattern_config["rename_params"])

        return defaults, overrides, removals, renames

    async def get_params(
        self, model_id: str, input_params: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Get transformed parameters for the specified model.

        Parameter precedence and transformation order:

        Defaults (applied only if key doesn't exist yet):
        1. Model-specific defaults (highest precedence for defaults)
        2. Pattern defaults (from matching patterns; registered in reverse
           config order)
        3. Service provider defaults
        4. Global defaults (lowest precedence for defaults)

        Overrides/Removals/Renames (applied in order, later steps overwrite/modify earlier ones):
        1. Input parameters (initial state)
        2. Service-level API keys and base URLs
        3. Service-level overrides, removals, renames
        4. Pattern-level: overrides of all matching patterns (config order),
           then their combined removals, then their renames (config order)
        5. Model-specific overrides, removals, renames (highest precedence)
        6. Model ID (always overwritten with service_model_id)
        7. Environment variable substitution (applied last to all string values)

        A pattern matches when the pattern string is a substring of
        ``model_id``. A model whose provider has no ``services`` entry, or
        whose service has no ``model_patterns``, is transformed with empty
        service/pattern configuration.

        Args:
            model_id: The ID of the model to get transformed parameters for
            input_params: The parameters to transform

        Returns:
            The transformed parameters

        Raises:
            ValueError: If the model is not found in the configuration.
        """
        if model_id not in self._model_mappings:
            logger.error("model_not_found", model_id=model_id)
            raise ValueError(f"Model '{model_id}' not found in configuration")

        mapping = self._model_mappings[model_id]
        service_provider = mapping["service_provider"]
        service_model_id = mapping["service_model_id"]
        # Missing service entries / model_patterns fall back to empty configs
        # instead of raising KeyError.
        service_config = self._service_config.get(service_provider, {})
        (
            pattern_defaults,
            pattern_overrides,
            pattern_removals,
            pattern_renames,
        ) = self._collect_pattern_configs(
            model_id, service_config.get("model_patterns") or {}
        )

        transformer = ParamTransformer()

        # Apply defaults (highest to lowest precedence)
        transformer.with_defaults(mapping.get("default_params", {}))
        for defaults in reversed(pattern_defaults):
            transformer.with_defaults(defaults)
        transformer.with_defaults(service_config.get("default_params", {}))
        # Global defaults come straight from the schema (not dumped with
        # exclude_none), so guard against None.
        transformer.with_defaults(self._global_config.get("default_params") or {})

        # Apply overrides, removals, renames (service -> pattern -> model precedence)

        # Service level credentials: "${VAR}" api keys are resolved from the
        # environment now and wrapped in SecretStr.
        api_key_val = service_config.get("api_key")
        if (
            api_key_val
            and isinstance(api_key_val, str)
            and api_key_val.startswith("${")
            and api_key_val.endswith("}")
        ):
            env_var = api_key_val[2:-1]
            env_val = os.environ.get(env_var)
            if env_val is not None:
                transformer.with_overrides({"api_key": SecretStr(env_val)})
            else:
                logger.warning(
                    "api_key_env_var_not_found",
                    variable=env_var,
                    service=service_provider,
                )
        elif api_key_val:
            # Handle non-env var API keys if needed (though config should use env vars)
            await logger.awarning(
                "api_key_set_directly_in_config",
                service=service_provider,
                msg="API key set directly in config, not recommended",
            )
            transformer.with_overrides({"api_key": api_key_val})

        base_url_val = service_config.get("base_url")
        if base_url_val and isinstance(base_url_val, str):
            transformer.with_overrides({"base_url": base_url_val})

        # Service overrides, removals, renames
        service_override_params = service_config.get("override_params", {})
        if isinstance(service_override_params, dict):
            transformer.with_overrides(service_override_params)
        service_remove_params = service_config.get("remove_params", [])
        if isinstance(service_remove_params, list):
            transformer.removing(service_remove_params)
        service_rename_params = service_config.get("rename_params", {})
        if isinstance(service_rename_params, dict):
            transformer.renaming(service_rename_params)

        # Pattern level transformations (apply collected configs)
        for overrides in pattern_overrides:
            transformer.with_overrides(overrides)
        if pattern_removals:
            transformer.removing(pattern_removals)
        for renames in pattern_renames:
            transformer.renaming(renames)

        # Model level transformations (highest precedence for overrides/removals/renames)
        transformer.with_overrides(mapping.get("override_params", {}))
        transformer.removing(mapping.get("remove_params", []))
        transformer.renaming(mapping.get("rename_params", {}))

        # Set the final model ID to be sent to the service provider
        transformer.with_model_id(service_model_id)

        # Substitute environment variables in string values
        transformer.with_env_vars()

        # Applies all chained transformations in order
        return transformer.transform(input_params)

131 async def get_params( 

132 self, model_id: str, input_params: dict[str, Any] 

133 ) -> dict[str, Any]: 

134 """ 

135 Get transformed parameters for the specified model. 

136 

137 Parameter precedence and transformation order: 

138 

139 Defaults (applied only if key doesn't exist yet): 

140 1. Model-specific defaults (highest precedence for defaults) 

141 2. Pattern defaults (matching patterns applied in config order) 

142 3. Service provider defaults 

143 4. Global defaults (lowest precedence for defaults) 

144 

145 Overrides/Removals/Renames (applied in order, later steps overwrite/modify earlier ones): 

146 1. Input parameters (initial state) 

147 2. Service-level API keys and base URLs 

148 3. Service-level overrides, removals, renames 

149 4. Pattern-level overrides, removals, renames (matching patterns applied in config order) 

150 5. Model-specific overrides, removals, renames (highest precedence) 

151 6. Model ID (always overwritten with service_model_id) 

152 7. Environment variable substitution (applied last to all string values) 

153 

154 Args: 

155 model_id: The ID of the model to get transformed parameters for 

156 input_params: The parameters to transform 

157 

158 Returns: 

159 The transformed parameters 

160 

161 Raises: 

162 ValueError: If the model is not found in the configuration. 

163 """ 

164 if model_id not in self._model_mappings: 

165 logger.error("model_not_found", model_id=model_id) 

166 raise ValueError(f"Model '{model_id}' not found in configuration") 

167 

168 mapping = self._model_mappings[model_id] 

169 service_provider = mapping["service_provider"] 

170 service_model_id = mapping["service_model_id"] 

171 service_config = self._service_config.get(service_provider, {}) 

172 

173 transformer = ParamTransformer() 

174 

175 # Collect pattern configurations 

176 matching_pattern_defaults = [] 

177 matching_pattern_overrides = [] 

178 matching_pattern_removals = [] 

179 matching_pattern_renames = [] 

180 

181 for pattern, pattern_config in service_config["model_patterns"].items(): 

182 if pattern in model_id: 

183 if "default_params" in pattern_config: 

184 matching_pattern_defaults.append(pattern_config["default_params"]) 

185 if "override_params" in pattern_config: 

186 matching_pattern_overrides.append(pattern_config["override_params"]) 

187 if "remove_params" in pattern_config: 

188 matching_pattern_removals.extend(pattern_config["remove_params"]) 

189 if "rename_params" in pattern_config: 

190 matching_pattern_renames.append(pattern_config["rename_params"]) 

191 

192 # Apply defaults (highest to lowest precedence) 

193 transformer.with_defaults(mapping.get("default_params", {})) 

194 for defaults in reversed(matching_pattern_defaults): 

195 transformer.with_defaults(defaults) 

196 transformer.with_defaults(service_config.get("default_params", {})) 

197 transformer.with_defaults(self._global_config.get("default_params", {})) 

198 

199 # Apply overrides, removals, renames (service -> pattern -> model precedence) 

200 

201 # Service level transformations 

202 api_key_val = service_config.get("api_key") 

203 if ( 

204 api_key_val 

205 and isinstance(api_key_val, str) 

206 and api_key_val.startswith("${") 

207 and api_key_val.endswith("}") 

208 ): 

209 env_var = api_key_val[2:-1] 

210 env_val = os.environ.get(env_var) 

211 if env_val is not None: 

212 transformer.with_overrides({"api_key": SecretStr(env_val)}) 

213 else: 

214 logger.warning( 

215 "api_key_env_var_not_found", 

216 variable=env_var, 

217 service=service_provider, 

218 ) 

219 elif api_key_val: 

220 # Handle non-env var API keys if needed (though config should use env vars) 

221 await logger.awarning( 

222 "api_key_set_directly_in_config", 

223 service=service_provider, 

224 msg="API key set directly in config, not recommended", 

225 ) 

226 transformer.with_overrides({"api_key": api_key_val}) 

227 

228 base_url_val = service_config.get("base_url") 

229 if base_url_val and isinstance(base_url_val, str): 

230 transformer.with_overrides({"base_url": base_url_val}) 

231 

232 # Service overrides, removals, renames 

233 service_override_params = service_config.get("override_params", {}) 

234 if isinstance(service_override_params, dict): 

235 transformer.with_overrides(service_override_params) 

236 service_remove_params = service_config.get("remove_params", []) 

237 if isinstance(service_remove_params, list): 

238 transformer.removing(service_remove_params) 

239 service_rename_params = service_config.get("rename_params", {}) 

240 if isinstance(service_rename_params, dict): 

241 transformer.renaming(service_rename_params) 

242 

243 # Pattern level transformations (apply collected configs) 

244 for overrides in matching_pattern_overrides: 

245 transformer.with_overrides(overrides) 

246 if matching_pattern_removals: 

247 transformer.removing(matching_pattern_removals) 

248 for renames in matching_pattern_renames: 

249 transformer.renaming(renames) 

250 

251 # Model level transformations (highest precedence for overrides/removals/renames) 

252 transformer.with_overrides(mapping.get("override_params", {})) 

253 transformer.removing(mapping.get("remove_params", [])) 

254 transformer.renaming(mapping.get("rename_params", {})) 

255 

256 # Final transformations 

257 # Set the final model ID to be sent to the service provider 

258 transformer.with_model_id(service_model_id) 

259 

260 # Substitute environment variables in string values 

261 transformer.with_env_vars() 

262 

263 # Execute Transformation 

264 # Applies all chained transformations in order 

265 return transformer.transform(input_params)