Coverage for packages/transform/src/langgate/transform/transformer.py: 21%
66 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-04-09 21:23 +0000
1"""Parameter transformation for LLM API calls."""
3import os
4from collections.abc import Callable
5from typing import Any
7from langgate.core.logging import get_logger
# Module-level logger; with_env_vars uses it to warn when a referenced
# environment variable is not set.
logger = get_logger(__name__)
12class ParamTransformer:
13 """
14 Handle LLM API call parameter transformations with a fluent interface.
16 This transformer allows for a series of transformations to be applied to
17 a set of parameters in a specific order, including:
18 - Adding default values
19 - Applying overrides
20 - Removing parameters
21 - Renaming parameters
22 - Setting the model ID
23 - Substituting environment variables
24 """
26 def __init__(self, initial_params: dict[str, Any] | None = None):
27 self.params = initial_params or {}
28 self.transforms: list[Callable[[dict[str, Any]], dict[str, Any]]] = []
30 def with_defaults(self, defaults: dict[str, Any]) -> "ParamTransformer":
31 """Add default parameters that will be used if not present in input.
33 Args:
34 defaults: Dictionary of default parameters
36 Returns:
37 Self for method chaining
38 """
40 def apply_defaults(params: dict[str, Any]) -> dict[str, Any]:
41 result = params.copy()
42 for key, value in defaults.items():
43 if key not in result:
44 result[key] = value
45 return result
47 self.transforms.append(apply_defaults)
48 return self
50 def with_overrides(self, overrides: dict[str, Any]) -> "ParamTransformer":
51 """Add overrides that will replace any existing parameters."""
53 def apply_overrides(params: dict[str, Any]) -> dict[str, Any]:
54 result = params.copy()
55 for key, value in overrides.items():
56 if isinstance(value, dict) and isinstance(result.get(key), dict):
57 # Deep merge for nested dicts
58 result[key] = {**result[key], **value}
59 else:
60 result[key] = value
61 return result
63 self.transforms.append(apply_overrides)
64 return self
66 def removing(self, keys: list[str]) -> "ParamTransformer":
67 """Remove specified parameters."""
69 def apply_removals(params: dict[str, Any]) -> dict[str, Any]:
70 return {k: v for k, v in params.items() if k not in keys}
72 self.transforms.append(apply_removals)
73 return self
75 def renaming(self, key_map: dict[str, str]) -> "ParamTransformer":
76 """Rename parameters according to the key_map."""
78 def apply_renames(params: dict[str, Any]) -> dict[str, Any]:
79 result = params.copy()
80 for old_key, new_key in key_map.items():
81 if old_key in result:
82 result[new_key] = result.pop(old_key)
83 return result
85 self.transforms.append(apply_renames)
86 return self
88 def with_model_id(self, model_id: str) -> "ParamTransformer":
89 """Set the model parameter."""
90 return self.with_overrides({"model": model_id})
92 def with_env_vars(self) -> "ParamTransformer":
93 """Substitute environment variables in string values."""
95 def apply_env_vars(params: dict[str, Any]) -> dict[str, Any]:
96 def substitute(value):
97 if (
98 isinstance(value, str)
99 and value.startswith("${")
100 and value.endswith("}")
101 ):
102 env_var = value[2:-1]
103 if env_var in os.environ:
104 return os.environ[env_var]
105 logger.warning("env_variable_not_found", variable=env_var)
106 return value
107 if isinstance(value, dict):
108 return {k: substitute(v) for k, v in value.items()}
109 if isinstance(value, list):
110 return [substitute(item) for item in value]
111 return value
113 return {k: substitute(v) for k, v in params.items()}
115 self.transforms.append(apply_env_vars)
116 return self
118 def transform(self, input_params: dict[str, Any] | None = None) -> dict[str, Any]:
119 """Apply all transformations to input parameters."""
120 result = (input_params or {}).copy()
122 for transform in self.transforms:
123 result = transform(result)
125 return result