Coverage for packages/core/src/langgate/core/schemas/config.py: 95%
55 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-08-02 18:47 +0000
1"""Schema definitions for YAML configuration validation."""
3from typing import Any, Literal
5from pydantic import (
6 BaseModel,
7 ConfigDict,
8 Field,
9 SecretStr,
10 field_validator,
11 model_validator,
12)
14from langgate.core.fields import UrlOrEnvVar
15from langgate.core.models import Modality
class ServiceModelPatternConfig(BaseModel):
    """Configuration for model pattern matching within a service."""

    # The four parameter-adjustment fields below mirror the ones declared on
    # ServiceConfig and ModelConfig. How they are applied (and their precedence
    # relative to service- and model-level values) is decided by whatever
    # consumes this schema, not here. NOTE(review): confirm against the resolver.
    #
    # Every field defaults to a fresh, empty container via ``default_factory``,
    # so instances never share mutable defaults.
    #
    # Unlike ServiceConfig/ModelConfig, this class sets no ``model_config``, so
    # pydantic's default handling of unknown keys applies.

    # Presumably parameter defaults for models matching the pattern.
    default_params: dict[str, Any] = Field(default_factory=dict)
    # Presumably parameters forced to these values for matching models.
    override_params: dict[str, Any] = Field(default_factory=dict)
    # Presumably names of parameters to strip for matching models.
    remove_params: list[str] = Field(default_factory=list)
    # Presumably a mapping of parameter renames (old name -> new name; verify).
    rename_params: dict[str, str] = Field(default_factory=dict)
class ServiceConfig(BaseModel):
    """Configuration for a service provider."""

    # Either a literal key, which ``validate_api_key`` wraps in ``SecretStr`` so
    # it is masked in reprs, or a ``${...}`` environment-variable reference that
    # is kept as a plain string (presumably resolved later by the loader).
    api_key: str | SecretStr
    base_url: UrlOrEnvVar | None = None
    api_format: str | None = None
    # Service-wide parameter adjustments; each defaults to a fresh empty container.
    default_params: dict[str, Any] = Field(default_factory=dict)
    override_params: dict[str, Any] = Field(default_factory=dict)
    remove_params: list[str] = Field(default_factory=list)
    rename_params: dict[str, str] = Field(default_factory=dict)
    # Per-pattern parameter adjustments, keyed by a model pattern string.
    model_patterns: dict[str, ServiceModelPatternConfig] = Field(default_factory=dict)

    # ``extra="allow"`` keeps unknown keys on the instance instead of dropping them.
    # ``protected_namespaces=()`` opts out of pydantic's ``model_`` protected
    # namespace. Pydantic < 2.10 would otherwise emit a UserWarning at class
    # definition time because ``model_patterns`` starts with ``model_``.
    model_config = ConfigDict(extra="allow", protected_namespaces=())

    @field_validator("api_key")
    @classmethod
    def validate_api_key(cls, v: str | SecretStr) -> str | SecretStr:
        """Wrap literal string API keys in ``SecretStr``.

        Args:
            v: The ``api_key`` value after pydantic's type validation.

        Returns:
            ``v`` unchanged if it is already a ``SecretStr`` or if it is an
            environment-variable reference of the form ``${...}``. Otherwise
            ``SecretStr(v)``.
        """
        if isinstance(v, str):
            # Keep environment variable references as strings so they stay
            # recognisable as placeholders rather than being masked as secrets.
            if v.startswith("${") and v.endswith("}"):
                return v
            return SecretStr(v)
        return v
class ModelServiceConfig(BaseModel):
    """Service configuration for a specific model."""

    # Name of the service provider. ConfigSchema's model validator requires it
    # to be a key of ``ConfigSchema.services``.
    provider: str
    # Identifier of the model on the provider side (presumably what is sent to
    # the provider's API — confirm against the client code).
    model_id: str

    # ``model_id`` starts with ``model_``, which pydantic < 2.10 treats as a
    # protected namespace and warns about at class definition time. Opting out
    # keeps config loading warning-free. Extra-key handling stays at pydantic's default.
    model_config = ConfigDict(protected_namespaces=())
class ModelConfig(BaseModel):
    """Configuration for a specific model."""

    # Model identifier. ConfigSchema's provider validator quotes it in its error message.
    id: str
    # Which provider serves this model and the provider-side model id.
    service: ModelServiceConfig
    # Optional modality tag. NOTE(review): ConfigSchema.models appears to be
    # keyed by modality as well; nothing in this class checks that the two agree.
    modality: Modality | None = None
    # Optional human-readable name and description.
    name: str | None = None
    description: str | None = None
    # Optional API format. ServiceConfig declares a field of the same name;
    # presumably this one takes precedence per model — confirm against the resolver.
    api_format: str | None = None
    # Per-model parameter adjustments, mirroring the same-named fields on
    # ServiceConfig and ServiceModelPatternConfig. Each defaults to a fresh,
    # empty container via ``default_factory``.
    default_params: dict[str, Any] = Field(default_factory=dict)
    override_params: dict[str, Any] = Field(default_factory=dict)
    remove_params: list[str] = Field(default_factory=list)
    rename_params: dict[str, str] = Field(default_factory=dict)

    # Unknown keys are kept on the instance instead of being dropped.
    model_config = ConfigDict(extra="allow")
class ConfigSchema(BaseModel):
    """Root schema for the configuration YAML."""

    # Global parameter defaults.
    default_params: dict[str, Any] = Field(default_factory=dict)
    # Service providers, keyed by provider name.
    services: dict[str, ServiceConfig] = Field(default_factory=dict)
    # Model definitions grouped under a key (presumably the modality name).
    models: dict[str, list[ModelConfig]] = Field(default_factory=dict)
    # Free-form application settings.
    app_config: dict[str, Any] = Field(default_factory=dict)
    # Strategy for combining these models with another model set; interpreted
    # by the config loader, not by this schema.
    models_merge_mode: Literal["merge", "replace", "extend"] = "merge"

    @model_validator(mode="after")
    def validate_model_service_providers(self) -> "ConfigSchema":
        """Ensure every model names a provider declared under ``services``.

        Models are scanned in insertion order of ``models`` and of each list.
        The first model whose ``service.provider`` is not a key of ``services``
        aborts validation.

        Returns:
            ``self`` when every provider reference resolves.

        Raises:
            ValueError: naming the first offending model, its group key and
                provider, plus the sorted list of known providers.
        """
        known = set(self.services)

        # First (group key, model) pair whose provider is undeclared, or None.
        offender = next(
            (
                (group, entry)
                for group, entries in self.models.items()
                for entry in entries
                if entry.service.provider not in known
            ),
            None,
        )
        if offender is None:
            return self

        modality, model = offender
        raise ValueError(
            f"Model '{model.id}' in modality '{modality}' references service provider '{model.service.provider}' "
            f"which is not defined in services. Available providers: {sorted(known)}"
        )