Coverage for packages/core/src/langgate/core/schemas/config.py: 95%

55 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-08-02 18:47 +0000

1"""Schema definitions for YAML configuration validation.""" 

2 

3from typing import Any, Literal 

4 

5from pydantic import ( 

6 BaseModel, 

7 ConfigDict, 

8 Field, 

9 SecretStr, 

10 field_validator, 

11 model_validator, 

12) 

13 

14from langgate.core.fields import UrlOrEnvVar 

15from langgate.core.models import Modality 

16 

17 

class ServiceModelPatternConfig(BaseModel):
    """Parameter settings attached to a model pattern within a service.

    Every field is optional and falls back to a fresh empty container, so an
    empty mapping is a valid pattern configuration.
    """

    # NOTE(review): the four fields below presumably describe how request
    # parameters are defaulted, overridden, dropped and renamed; the code that
    # applies them is not part of this module - confirm against the consumer.
    default_params: dict[str, Any] = Field(default_factory=dict)
    override_params: dict[str, Any] = Field(default_factory=dict)
    remove_params: list[str] = Field(default_factory=list)
    rename_params: dict[str, str] = Field(default_factory=dict)

25 

26 

class ServiceConfig(BaseModel):
    """Configuration for a service provider.

    Only ``api_key`` is required. Keys that are not declared here are still
    accepted because the model is configured with ``extra="allow"``.
    """

    # Literal keys end up wrapped in SecretStr by ``validate_api_key``;
    # "${...}" references are kept as plain strings.
    api_key: str | SecretStr
    base_url: UrlOrEnvVar | None = None
    api_format: str | None = None
    default_params: dict[str, Any] = Field(default_factory=dict)
    override_params: dict[str, Any] = Field(default_factory=dict)
    remove_params: list[str] = Field(default_factory=list)
    rename_params: dict[str, str] = Field(default_factory=dict)
    # Keyed by pattern string; how patterns are matched is decided elsewhere.
    model_patterns: dict[str, ServiceModelPatternConfig] = Field(default_factory=dict)

    model_config = ConfigDict(extra="allow")

    @field_validator("api_key")
    @classmethod
    def validate_api_key(cls, v: str | SecretStr) -> str | SecretStr:
        """Convert string API keys to SecretStr.

        Args:
            v: The already type-validated ``api_key`` value.

        Returns:
            ``v`` unchanged when it is not a ``str`` or when it is a string of
            the form ``${...}``; otherwise ``SecretStr(v)``.
        """
        if not isinstance(v, str):
            return v
        # Keep environment variable references as strings
        if v.startswith("${") and v.endswith("}"):
            return v
        return SecretStr(v)

50 

51 

class ModelServiceConfig(BaseModel):
    """Service configuration for a specific model.

    Both fields are required strings.
    """

    # Checked against the keys of ``ConfigSchema.services`` by
    # ``ConfigSchema.validate_model_service_providers``.
    provider: str
    # NOTE(review): presumably the identifier the provider itself uses for the
    # model - confirm against the code that consumes this schema.
    model_id: str

57 

58 

class ModelConfig(BaseModel):
    """Configuration for a specific model.

    ``id`` and ``service`` are required; everything else is optional.
    Undeclared keys are accepted because the model is configured with
    ``extra="allow"``.
    """

    # Required identification and routing information.
    id: str
    service: ModelServiceConfig

    # Optional descriptive metadata.
    modality: Modality | None = None
    name: str | None = None
    description: str | None = None
    api_format: str | None = None

    # Optional parameter settings; each defaults to a fresh empty container.
    default_params: dict[str, Any] = Field(default_factory=dict)
    override_params: dict[str, Any] = Field(default_factory=dict)
    remove_params: list[str] = Field(default_factory=list)
    rename_params: dict[str, str] = Field(default_factory=dict)

    model_config = ConfigDict(extra="allow")

74 

75 

class ConfigSchema(BaseModel):
    """Root schema for the configuration YAML.

    All top-level sections are optional and default to empty containers;
    ``models_merge_mode`` defaults to ``"merge"``.
    """

    default_params: dict[str, Any] = Field(default_factory=dict)
    # Keyed by provider name; model entries refer to these keys.
    services: dict[str, ServiceConfig] = Field(default_factory=dict)
    # Keyed by modality, each holding the list of models for that modality.
    models: dict[str, list[ModelConfig]] = Field(default_factory=dict)
    app_config: dict[str, Any] = Field(default_factory=dict)
    models_merge_mode: Literal["merge", "replace", "extend"] = "merge"

    @model_validator(mode="after")
    def validate_model_service_providers(self) -> "ConfigSchema":
        """Ensure every model points at a provider declared under ``services``.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: For the first model encountered whose
                ``service.provider`` is not a key of ``services``.
        """
        available_providers = set(self.services)

        for modality, model_list in self.models.items():
            for model in model_list:
                if model.service.provider in available_providers:
                    continue
                # Fail on the first unknown provider and list the valid ones.
                raise ValueError(
                    f"Model '{model.id}' in modality '{modality}' references service provider '{model.service.provider}' "
                    f"which is not defined in services. Available providers: {sorted(available_providers)}"
                )

        return self

97 

98 return self