Coverage for packages/transform/src/langgate/transform/transformer.py: 21%

66 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-04-09 21:23 +0000

1"""Parameter transformation for LLM API calls.""" 

2 

3import os 

4from collections.abc import Callable 

5from typing import Any 

6 

7from langgate.core.logging import get_logger 

8 

9logger = get_logger(__name__) 

10 

11 

12class ParamTransformer: 

13 """ 

14 Handle LLM API call parameter transformations with a fluent interface. 

15 

16 This transformer allows for a series of transformations to be applied to 

17 a set of parameters in a specific order, including: 

18 - Adding default values 

19 - Applying overrides 

20 - Removing parameters 

21 - Renaming parameters 

22 - Setting the model ID 

23 - Substituting environment variables 

24 """ 

25 

26 def __init__(self, initial_params: dict[str, Any] | None = None): 

27 self.params = initial_params or {} 

28 self.transforms: list[Callable[[dict[str, Any]], dict[str, Any]]] = [] 

29 

30 def with_defaults(self, defaults: dict[str, Any]) -> "ParamTransformer": 

31 """Add default parameters that will be used if not present in input. 

32 

33 Args: 

34 defaults: Dictionary of default parameters 

35 

36 Returns: 

37 Self for method chaining 

38 """ 

39 

40 def apply_defaults(params: dict[str, Any]) -> dict[str, Any]: 

41 result = params.copy() 

42 for key, value in defaults.items(): 

43 if key not in result: 

44 result[key] = value 

45 return result 

46 

47 self.transforms.append(apply_defaults) 

48 return self 

49 

50 def with_overrides(self, overrides: dict[str, Any]) -> "ParamTransformer": 

51 """Add overrides that will replace any existing parameters.""" 

52 

53 def apply_overrides(params: dict[str, Any]) -> dict[str, Any]: 

54 result = params.copy() 

55 for key, value in overrides.items(): 

56 if isinstance(value, dict) and isinstance(result.get(key), dict): 

57 # Deep merge for nested dicts 

58 result[key] = {**result[key], **value} 

59 else: 

60 result[key] = value 

61 return result 

62 

63 self.transforms.append(apply_overrides) 

64 return self 

65 

66 def removing(self, keys: list[str]) -> "ParamTransformer": 

67 """Remove specified parameters.""" 

68 

69 def apply_removals(params: dict[str, Any]) -> dict[str, Any]: 

70 return {k: v for k, v in params.items() if k not in keys} 

71 

72 self.transforms.append(apply_removals) 

73 return self 

74 

75 def renaming(self, key_map: dict[str, str]) -> "ParamTransformer": 

76 """Rename parameters according to the key_map.""" 

77 

78 def apply_renames(params: dict[str, Any]) -> dict[str, Any]: 

79 result = params.copy() 

80 for old_key, new_key in key_map.items(): 

81 if old_key in result: 

82 result[new_key] = result.pop(old_key) 

83 return result 

84 

85 self.transforms.append(apply_renames) 

86 return self 

87 

88 def with_model_id(self, model_id: str) -> "ParamTransformer": 

89 """Set the model parameter.""" 

90 return self.with_overrides({"model": model_id}) 

91 

92 def with_env_vars(self) -> "ParamTransformer": 

93 """Substitute environment variables in string values.""" 

94 

95 def apply_env_vars(params: dict[str, Any]) -> dict[str, Any]: 

96 def substitute(value): 

97 if ( 

98 isinstance(value, str) 

99 and value.startswith("${") 

100 and value.endswith("}") 

101 ): 

102 env_var = value[2:-1] 

103 if env_var in os.environ: 

104 return os.environ[env_var] 

105 logger.warning("env_variable_not_found", variable=env_var) 

106 return value 

107 if isinstance(value, dict): 

108 return {k: substitute(v) for k, v in value.items()} 

109 if isinstance(value, list): 

110 return [substitute(item) for item in value] 

111 return value 

112 

113 return {k: substitute(v) for k, v in params.items()} 

114 

115 self.transforms.append(apply_env_vars) 

116 return self 

117 

118 def transform(self, input_params: dict[str, Any] | None = None) -> dict[str, Any]: 

119 """Apply all transformations to input parameters.""" 

120 result = (input_params or {}).copy() 

121 

122 for transform in self.transforms: 

123 result = transform(result) 

124 

125 return result