Coverage for packages/core/src/langgate/core/schemas/config.py: 94%
51 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-06-18 14:26 +0000
1"""Schema definitions for YAML configuration validation."""
3from typing import Any
5from pydantic import (
6 BaseModel,
7 ConfigDict,
8 Field,
9 SecretStr,
10 field_validator,
11 model_validator,
12)
14from langgate.core.fields import UrlOrEnvVar
class ServiceModelPatternConfig(BaseModel):
    """Parameter adjustments for models matched by a pattern inside one service.

    Instances are the values of ``ServiceConfig.model_patterns``; the dict key
    is presumably the pattern matched against model IDs (the matching code is
    not shown here). Every field defaults to a fresh, empty container, so an
    entry may set any subset of them.
    """

    # Parameter name -> value pairs (merge semantics are decided by the consumer).
    default_params: dict[str, Any] = Field(default_factory=dict)
    # Parameter name -> value pairs (merge semantics are decided by the consumer).
    override_params: dict[str, Any] = Field(default_factory=dict)
    # Names of parameters to drop.
    remove_params: list[str] = Field(default_factory=list)
    # Old parameter name -> new parameter name.
    rename_params: dict[str, str] = Field(default_factory=dict)
class ServiceConfig(BaseModel):
    """Configuration for a service provider.

    Holds the provider credential (``api_key``), optional endpoint and API
    format, service-wide parameter adjustments (``default_params``,
    ``override_params``, ``remove_params``, ``rename_params``) and optional
    per-pattern adjustments in ``model_patterns``. Undeclared keys are kept
    because the model allows extra fields.
    """

    # Plain-text keys are wrapped in SecretStr by ``validate_api_key``;
    # ``${...}`` environment-variable references stay plain strings.
    api_key: str | SecretStr
    base_url: UrlOrEnvVar | None = None
    api_format: str | None = None
    default_params: dict[str, Any] = Field(default_factory=dict)
    override_params: dict[str, Any] = Field(default_factory=dict)
    remove_params: list[str] = Field(default_factory=list)
    rename_params: dict[str, str] = Field(default_factory=dict)
    # Keyed by model pattern (presumably matched against model IDs elsewhere).
    model_patterns: dict[str, ServiceModelPatternConfig] = Field(default_factory=dict)

    # ``protected_namespaces=()``: the field ``model_patterns`` starts with
    # ``model_``, which pydantic v2 releases before 2.10 reserve and warn about
    # when the class is created. Clearing the namespaces silences that warning
    # without affecting validation.
    model_config = ConfigDict(extra="allow", protected_namespaces=())

    @field_validator("api_key")
    @classmethod
    def validate_api_key(cls, v: str | SecretStr) -> str | SecretStr:
        """Convert string API keys to SecretStr.

        Args:
            v: The ``api_key`` value after pydantic's type validation.

        Returns:
            ``v`` unchanged when it is not a ``str`` (i.e. already a
            ``SecretStr``) or when it is a ``${...}`` environment-variable
            reference; otherwise ``SecretStr(v)``.
        """
        if not isinstance(v, str):
            return v
        # Keep environment variable references as plain strings so they stay
        # readable for later substitution (presumably done by the config loader).
        if v.startswith("${") and v.endswith("}"):
            return v
        return SecretStr(v)
class ModelServiceConfig(BaseModel):
    """Service configuration for a specific model.

    Names the service provider that serves the model and the model identifier
    to use with that provider.
    """

    # Must be a key of ``ConfigSchema.services``; this is checked by
    # ``ConfigSchema.validate_model_service_providers``.
    provider: str
    # Identifier passed to the provider (presumably may differ from ModelConfig.id).
    model_id: str

    # ``model_id`` starts with ``model_``, which pydantic v2 releases before
    # 2.10 reserve as a protected namespace and warn about when the class is
    # created. Clearing the namespaces silences the warning; validation is
    # unchanged.
    model_config = ConfigDict(protected_namespaces=())
class ModelConfig(BaseModel):
    """One entry of the configuration's ``models`` list.

    ``service`` links the model to a provider that must be declared under the
    top-level ``services`` mapping. The four ``*_params`` fields have the same
    names and types as the parameter-adjustment fields on ``ServiceConfig``;
    each defaults to a fresh, empty container. Undeclared keys are retained
    because extra fields are allowed.
    """

    # Identity and provider linkage.
    id: str
    service: ModelServiceConfig

    # Optional descriptive metadata.
    name: str | None = None
    description: str | None = None
    api_format: str | None = None

    # Per-model parameter adjustments.
    default_params: dict[str, Any] = Field(default_factory=dict)
    override_params: dict[str, Any] = Field(default_factory=dict)
    remove_params: list[str] = Field(default_factory=list)
    rename_params: dict[str, str] = Field(default_factory=dict)

    model_config = ConfigDict(extra="allow")
class ConfigSchema(BaseModel):
    """Top-level schema that a parsed configuration YAML document must match.

    Every section is optional and defaults to an empty container. After field
    validation, ``validate_model_service_providers`` cross-checks that each
    model refers to a declared service.
    """

    default_params: dict[str, Any] = Field(default_factory=dict)
    # Service name -> service configuration; keys are what models reference.
    services: dict[str, ServiceConfig] = Field(default_factory=dict)
    models: list[ModelConfig] = Field(default_factory=list)
    app_config: dict[str, Any] = Field(default_factory=dict)

    @model_validator(mode="after")
    def validate_model_service_providers(self) -> "ConfigSchema":
        """Check that every model's ``service.provider`` is a key of ``services``.

        Returns:
            ``self`` unchanged when every provider is known.

        Raises:
            ValueError: naming the first model (in list order) whose provider
                is missing, together with the sorted list of known providers.
        """
        known = self.services.keys()
        offender = next(
            (entry for entry in self.models if entry.service.provider not in known),
            None,
        )
        if offender is not None:
            raise ValueError(
                f"Model '{offender.id}' references service provider '{offender.service.provider}' "
                f"which is not defined in services. Available providers: {sorted(known)}"
            )
        return self