Coverage for packages/registry/src/langgate/registry/config.py: 70%

89 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-11 05:31 +0000

1"""Configuration handling for the registry.""" 

2 

3import importlib.resources 

4import json 

5from pathlib import Path 

6from typing import Any 

7 

8from dotenv import load_dotenv 

9 

10from langgate.core.logging import get_logger 

11from langgate.core.schemas.config import ConfigSchema 

12from langgate.core.utils.config_utils import load_yaml_config, resolve_path 

13 

# Module-level logger for this file; calls below pass an event name plus
# keyword fields (e.g. ``path=...``, ``models_data_path=...``).
logger = get_logger(__name__)

15 

16 

17class RegistryConfig: 

18 """Configuration handler for the registry.""" 

19 

20 def __init__( 

21 self, 

22 models_data_path: Path | None = None, 

23 config_path: Path | None = None, 

24 env_file_path: Path | None = None, 

25 ): 

26 """ 

27 Args: 

28 models_data_path: Path to the models data JSON file 

29 config_path: Path to the main configuration YAML file 

30 env_file_path: Path to a `.env` file for environment variables 

31 """ 

32 # Set up default paths 

33 cwd = Path.cwd() 

34 # Get package resource paths 

35 registry_resources = importlib.resources.files("langgate.registry") 

36 core_resources = importlib.resources.files("langgate.core") 

37 self.default_models_path = Path( 

38 str(registry_resources.joinpath("data", "default_models.json")) 

39 ) 

40 default_config_path = Path( 

41 str(core_resources.joinpath("data", "default_config.yaml")) 

42 ) 

43 

44 # Define default paths with priorities 

45 # Models data: args > env > cwd > package_dir 

46 cwd_models_path = cwd / "langgate_models.json" 

47 

48 # Config: args > env > cwd > package_dir 

49 cwd_config_path = cwd / "langgate_config.yaml" 

50 

51 # Env file: args > env > cwd 

52 cwd_env_path = cwd / ".env" 

53 

54 # Resolve paths using priority order 

55 self.models_data_path = resolve_path( 

56 "LANGGATE_MODELS", 

57 models_data_path, 

58 cwd_models_path if cwd_models_path.exists() else self.default_models_path, 

59 "models_data_path", 

60 ) 

61 

62 self.config_path = resolve_path( 

63 "LANGGATE_CONFIG", 

64 config_path, 

65 cwd_config_path if cwd_config_path.exists() else default_config_path, 

66 "config_path", 

67 ) 

68 

69 self.env_file_path = resolve_path( 

70 "LANGGATE_ENV_FILE", env_file_path, cwd_env_path, "env_file_path" 

71 ) 

72 

73 # Load environment variables from .env file if it exists 

74 if self.env_file_path.exists(): 

75 load_dotenv(self.env_file_path) 

76 logger.debug("loaded_env_file", path=str(self.env_file_path)) 

77 

78 # Initialize data structures 

79 self.models_data: dict[str, dict[str, Any]] = {} 

80 self.global_config: dict[str, Any] = {} 

81 self.service_config: dict[str, dict[str, Any]] = {} 

82 self.model_mappings: dict[str, dict[str, Any]] = {} 

83 self.models_merge_mode: str = "merge" # Default value 

84 

85 # Load configuration 

86 self._load_config() 

87 

88 def _load_config(self) -> None: 

89 """Load configuration from files.""" 

90 try: 

91 # Load main configuration first to get merge mode 

92 self._load_main_config() 

93 

94 # Load model data with merge mode available 

95 self._load_model_data() 

96 

97 except Exception: 

98 logger.exception( 

99 "failed_to_load_config", 

100 models_data_path=str(self.models_data_path), 

101 config_path=str(self.config_path), 

102 ) 

103 raise 

104 

105 def _has_user_models(self) -> bool: 

106 """Check if user has specified custom models.""" 

107 return ( 

108 self.models_data_path != self.default_models_path 

109 and self.models_data_path.exists() 

110 ) 

111 

112 def _load_json_file(self, path: Path) -> dict[str, dict[str, Any]]: 

113 """Load and return JSON data from file.""" 

114 try: 

115 with open(path) as f: 

116 return json.load(f) 

117 except FileNotFoundError: 

118 logger.warning( 

119 "model_data_file_not_found", 

120 models_data_path=str(path), 

121 ) 

122 return {} 

123 

124 def _merge_models( 

125 self, 

126 default_models: dict[str, dict[str, Any]], 

127 user_models: dict[str, dict[str, Any]], 

128 merge_mode: str, 

129 ) -> dict[str, dict[str, Any]]: 

130 """Merge user models with default models based on merge mode.""" 

131 if merge_mode == "merge": 

132 # User models override defaults, new models are added 

133 merged = default_models.copy() 

134 for model_id, model_config in user_models.items(): 

135 if model_id in merged: 

136 logger.debug( 

137 "overriding_default_model", 

138 model_id=model_id, 

139 original_name=merged[model_id].get("name"), 

140 new_name=model_config.get("name"), 

141 ) 

142 merged[model_id] = model_config 

143 return merged 

144 

145 if merge_mode == "extend": 

146 # User models are added, conflicts with defaults cause errors 

147 conflicts = set(default_models.keys()) & set(user_models.keys()) 

148 if conflicts: 

149 raise ValueError( 

150 f"Model ID conflicts found in extend mode: {', '.join(conflicts)}. " 

151 f"Use 'merge' mode to allow overrides or rename conflicting models." 

152 ) 

153 return {**default_models, **user_models} 

154 

155 raise ValueError(f"Unknown merge mode: {merge_mode}") 

156 

157 def _load_model_data(self) -> None: 

158 """Load model data from JSON file(s) based on merge mode.""" 

159 try: 

160 # Always load default models first 

161 default_models = self._load_json_file(self.default_models_path) 

162 

163 has_user_models = self._has_user_models() 

164 

165 if self.models_merge_mode == "replace" or not has_user_models: 

166 # Use only user models or default if no user models 

167 if has_user_models: 

168 self.models_data = self._load_json_file(self.models_data_path) 

169 else: 

170 self.models_data = default_models 

171 else: 

172 # merge user models with defaults 

173 user_models = ( 

174 self._load_json_file(self.models_data_path) 

175 if has_user_models 

176 else {} 

177 ) 

178 self.models_data = self._merge_models( 

179 default_models, user_models, self.models_merge_mode 

180 ) 

181 

182 logger.info( 

183 "loaded_model_data", 

184 models_data_path=str(self.models_data_path), 

185 merge_mode=self.models_merge_mode, 

186 model_count=len(self.models_data), 

187 ) 

188 except Exception: 

189 logger.exception("failed_to_load_model_data") 

190 raise 

191 

192 def _load_main_config(self) -> None: 

193 """Load main configuration from YAML file.""" 

194 config = load_yaml_config(self.config_path, ConfigSchema, logger) 

195 

196 # Extract merge mode 

197 self.models_merge_mode = config.models_merge_mode 

198 

199 # Extract validated data 

200 self.global_config = { 

201 "default_params": config.default_params, 

202 "models_merge_mode": config.models_merge_mode, 

203 } 

204 

205 # Extract service provider config 

206 self.service_config = { 

207 k: v.model_dump(exclude_none=True) for k, v in config.services.items() 

208 } 

209 

210 # Process model mappings 

211 self._process_model_mappings(config.models) 

212 

213 def _process_model_mappings(self, models_config: dict[str, list]) -> None: 

214 """Process model mappings from validated configuration. 

215 

216 Args: 

217 models_config: Dict of modality to list of model configurations 

218 """ 

219 self.model_mappings = {} 

220 

221 for _, model_list in models_config.items(): 

222 for model_config in model_list: 

223 model_data = model_config.model_dump(exclude_none=True) 

224 model_id = model_data["id"] 

225 service = model_data["service"] 

226 

227 # Store mapping info with proper type handling 

228 self.model_mappings[model_id] = { 

229 "service_provider": service["provider"], 

230 "service_model_id": service["model_id"], 

231 "override_params": model_data.get("override_params", {}), 

232 "remove_params": model_data.get("remove_params", []), 

233 "rename_params": model_data.get("rename_params", {}), 

234 "name": model_data.get("name"), 

235 "model_provider": model_data.get("model_provider"), 

236 "model_provider_name": model_data.get("model_provider_name"), 

237 "description": model_data.get("description"), 

238 }