Coverage for packages/core/src/langgate/core/schemas/config.py: 94%

51 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-06-18 14:26 +0000

1"""Schema definitions for YAML configuration validation.""" 

2 

3from typing import Any 

4 

5from pydantic import ( 

6 BaseModel, 

7 ConfigDict, 

8 Field, 

9 SecretStr, 

10 field_validator, 

11 model_validator, 

12) 

13 

14from langgate.core.fields import UrlOrEnvVar 

15 

16 

class ServiceModelPatternConfig(BaseModel):
    """Parameter settings applied to models matched by a pattern within a service.

    Instances are the values of ``ServiceConfig.model_patterns``. Every field
    defaults to an empty container, so an empty mapping in YAML is valid.
    """

    # Parameters to supply by default (presumably only when the caller did not
    # set them; confirm against the consumer of this schema).
    default_params: dict[str, Any] = Field(default_factory=dict)
    # Parameters to force to the given values.
    override_params: dict[str, Any] = Field(default_factory=dict)
    # Names of parameters to drop.
    remove_params: list[str] = Field(default_factory=list)
    # Mapping of old parameter name -> new parameter name.
    rename_params: dict[str, str] = Field(default_factory=dict)

24 

25 

class ServiceConfig(BaseModel):
    """Configuration for a service provider.

    Unknown keys are accepted and retained on the model (``extra="allow"``).
    """

    # Literal keys are wrapped in SecretStr by ``validate_api_key``; values of
    # the form "${...}" (environment-variable references) stay plain strings.
    api_key: str | SecretStr
    base_url: UrlOrEnvVar | None = None
    api_format: str | None = None
    # Service-wide parameter adjustments (semantics defined by the consumer of
    # this schema — not visible here).
    default_params: dict[str, Any] = Field(default_factory=dict)
    override_params: dict[str, Any] = Field(default_factory=dict)
    remove_params: list[str] = Field(default_factory=list)
    rename_params: dict[str, str] = Field(default_factory=dict)
    # Pattern string -> parameter adjustments for models matching that pattern.
    model_patterns: dict[str, ServiceModelPatternConfig] = Field(default_factory=dict)

    model_config = ConfigDict(extra="allow")

    @field_validator("api_key")
    @classmethod
    def validate_api_key(cls, v: str | SecretStr) -> str | SecretStr:
        """Convert string API keys to SecretStr.

        Args:
            v: The already type-validated ``api_key`` value.

        Returns:
            ``v`` unchanged if it is already a SecretStr or is an
            environment-variable reference of the form ``"${...}"``;
            otherwise ``SecretStr(v)``.
        """
        if not isinstance(v, str):
            # Already a SecretStr: nothing to do.
            return v
        # Keep environment variable references as strings so the raw
        # "${...}" placeholder remains inspectable rather than masked.
        if v.startswith("${") and v.endswith("}"):
            return v
        return SecretStr(v)

49 

50 

class ModelServiceConfig(BaseModel):
    """Service configuration for a specific model.

    Identifies which configured provider serves the model and the identifier
    that provider uses for it.
    """

    # Must name a key of ``ConfigSchema.services`` (checked by ConfigSchema).
    provider: str
    # The model's identifier on the provider side.
    model_id: str

56 

57 

class ModelConfig(BaseModel):
    """Configuration for a specific model.

    Unknown keys are accepted and retained on the model (``extra="allow"``).
    """

    # --- identity -------------------------------------------------------
    id: str
    service: ModelServiceConfig

    # --- optional descriptive metadata -----------------------------------
    name: str | None = None
    description: str | None = None
    api_format: str | None = None

    # --- per-model parameter adjustments ---------------------------------
    # (how these combine with service-level settings is decided by the
    # consumer of this schema — not visible here)
    default_params: dict[str, Any] = Field(default_factory=dict)
    override_params: dict[str, Any] = Field(default_factory=dict)
    remove_params: list[str] = Field(default_factory=list)
    rename_params: dict[str, str] = Field(default_factory=dict)

    model_config = ConfigDict(extra="allow")

72 

73 

class ConfigSchema(BaseModel):
    """Root schema for the configuration YAML.

    Every section is optional and defaults to an empty container. After field
    validation, each model's ``service.provider`` is checked against the keys
    of ``services``.
    """

    default_params: dict[str, Any] = Field(default_factory=dict)
    # Provider name -> service configuration.
    services: dict[str, ServiceConfig] = Field(default_factory=dict)
    models: list[ModelConfig] = Field(default_factory=list)
    app_config: dict[str, Any] = Field(default_factory=dict)

    @model_validator(mode="after")
    def validate_model_service_providers(self) -> "ConfigSchema":
        """Validate that all model service providers exist in the services configuration.

        Returns:
            ``self``, unchanged, when every model's provider is a key of
            ``services``.

        Raises:
            ValueError: for the first model (in list order) whose provider is
                not defined; the message lists the defined providers, sorted.
        """
        known = self.services
        orphan = next(
            (m for m in self.models if m.service.provider not in known),
            None,
        )
        if orphan is None:
            return self

        raise ValueError(
            f"Model '{orphan.id}' references service provider '{orphan.service.provider}' "
            f"which is not defined in services. Available providers: {sorted(known)}"
        )