Coverage for packages/core/src/langgate/core/models.py: 99%
113 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-08-02 18:47 +0000
1"""Core domain data models for LangGate."""
3from datetime import UTC, datetime
4from decimal import Decimal
5from enum import Enum
6from typing import Annotated, Any, NewType, Self
8from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
10from langgate.core.fields import NormalizedDecimal
# ---------------------------------------------------------------------------
# Identifier types
# ---------------------------------------------------------------------------
# Both ids are plain ``str`` at runtime; ``NewType`` exists only so static
# type checkers can keep the two kinds of provider id apart.

# Id of the inference *service* provider (the API service requests go to).
# It is used by the proxy to route requests to the correct service and is not
# intended to be exposed to external consumers of the registry.
ServiceProviderId = NewType("ServiceProviderId", str)

# Id of the *model* provider (whoever created the model). This may differ
# from the service provider that actually serves inference for the model.
ModelProviderId = NewType("ModelProviderId", str)

# Well-known model provider ids, predefined for convenience.
MODEL_PROVIDER_OPENAI = ModelProviderId("openai")
MODEL_PROVIDER_ANTHROPIC = ModelProviderId("anthropic")
MODEL_PROVIDER_META = ModelProviderId("meta")
MODEL_PROVIDER_GOOGLE = ModelProviderId("google")
MODEL_PROVIDER_MINIMAX = ModelProviderId("minimax")
MODEL_PROVIDER_DEEPSEEK = ModelProviderId("deepseek")
# Note: the id value is "mistralai", not "mistral".
MODEL_PROVIDER_MISTRAL = ModelProviderId("mistralai")
MODEL_PROVIDER_ALIBABA = ModelProviderId("alibaba")
MODEL_PROVIDER_XAI = ModelProviderId("xai")
MODEL_PROVIDER_COHERE = ModelProviderId("cohere")
MODEL_PROVIDER_ELEUTHERIA = ModelProviderId("eleutheria")
class Modality(str, Enum):
    """Supported model modalities.

    Because members also subclass ``str``, they compare equal to their raw
    values (``Modality.TEXT == "text"``) and can be looked up from them
    (``Modality("image") is Modality.IMAGE``).
    """

    # Textual input/output.
    TEXT = "text"
    # Image input/output.
    IMAGE = "image"
class ServiceProvider(BaseModel):
    """Information about a service provider (API service).

    Holds what is needed to reach the service: its id, base URL and API key.
    """

    # Routing identifier of the service.
    id: ServiceProviderId
    # Root URL of the service's API.
    base_url: str
    # Credential; ``SecretStr`` keeps the raw value out of reprs/str output
    # (read it explicitly with ``get_secret_value()``).
    api_key: SecretStr
    # Per-provider parameters; a fresh dict is created for every instance.
    # NOTE(review): presumably merged into outgoing requests — confirm with
    # the proxy code.
    default_params: dict[str, Any] = Field(default_factory=dict)
# Loosely-specified model provider record: every field is optional and
# unknown keys are kept (``extra="allow"``). ``ModelProvider`` narrows this
# by making ``id`` and ``name`` mandatory.
class ModelProviderBase(BaseModel):
    model_config = ConfigDict(extra="allow")

    id: ModelProviderId | None = None
    name: str | None = None
    description: str | None = None
class ModelProvider(ModelProviderBase):
    """Information about a model provider (creator).

    Same shape as ``ModelProviderBase`` except that ``id`` and ``name`` must
    be supplied; ``description`` stays optional and extra keys are allowed.
    """

    # ``Field(default=...)`` marks each field as required, overriding the
    # ``None`` default declared on the base class.
    id: ModelProviderId = Field(default=...)
    name: str = Field(default=...)
class ContextWindow(BaseModel):
    """Context window information for a model.

    Both limits are token counts and default to 0; unknown keys are kept.
    """

    model_config = ConfigDict(extra="allow")

    # Token budget on the input/prompt side.
    # NOTE(review): 0 presumably means "unknown/unspecified" — confirm.
    max_input_tokens: int = 0
    # Token budget for generated output.
    max_output_tokens: int = 0
class ModelCapabilities(BaseModel):
    """Capabilities of a language model.

    Every flag is tri-state: ``True``/``False`` when declared, ``None`` when
    the capability was not specified. Unknown keys are kept.
    """

    model_config = ConfigDict(extra="allow")

    # Tool / function calling.
    supports_tools: bool | None = None
    supports_parallel_tool_calls: bool | None = None
    # Non-text modalities.
    supports_vision: bool | None = None
    supports_audio_input: bool | None = None
    supports_audio_output: bool | None = None
    # Serving / generation features.
    supports_prompt_caching: bool | None = None
    supports_reasoning: bool | None = None
    supports_response_schema: bool | None = None
    supports_system_messages: bool | None = None
# Decimal aliases for prices and quantities. All three resolve to the same
# runtime type (``NormalizedDecimal``); the trailing string in ``Annotated`` is
# only a label for readers (pydantic ignores plain-string metadata).
# NOTE(review): the normalization itself is defined in langgate.core.fields.
TokenCost = Annotated[NormalizedDecimal, "TokenCost"]
Percentage = Annotated[NormalizedDecimal, "Percentage"]
TokenUsage = Annotated[NormalizedDecimal, "TokenUsage"]
class BaseModelInfo(BaseModel):
    """Base class for all model types with common fields.

    Every field is optional at this level; concrete subclasses re-declare
    the ones they require. Unknown keys are kept (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow")

    id: str | None = None
    name: str | None = None
    provider_id: ModelProviderId | None = None
    description: str | None = None
class TokenCosts(BaseModel):
    """Token-based cost information for models that charge for input/output tokens.

    The standard input/output prices default to ``Decimal("0")``; every other
    price is optional. Unknown keys are kept.
    """

    model_config = ConfigDict(extra="allow")

    # Standard per-token prices (``Decimal()`` == 0 when not given).
    input_cost_per_token: TokenCost = Field(default_factory=Decimal)
    output_cost_per_token: TokenCost = Field(default_factory=Decimal)
    # Batch-API per-token prices, when the provider offers them.
    input_cost_per_token_batches: TokenCost | None = None
    output_cost_per_token_batches: TokenCost | None = None
    # Cached-input prices.
    # NOTE(review): these two names look like overlapping ways of expressing
    # the cached-input price — confirm which one consumers read.
    cache_read_input_token_cost: TokenCost | None = None
    input_cached_cost_per_token: TokenCost | None = None
# Common model fields plus optional LLM-specific sections; ``LLMInfo`` makes
# each of these sections always present (defaulting to an empty instance).
class LLMInfoBase(BaseModelInfo):
    costs: TokenCosts | None = None
    capabilities: ModelCapabilities | None = None
    context_window: ContextWindow | None = None
class LLMInfo(LLMInfoBase):
    """Information about a language model."""

    # Identity fields are required here (``Field(default=...)`` overrides the
    # optional defaults inherited from ``BaseModelInfo``).
    id: str = Field(default=...)  # "gpt-4o"
    name: str = Field(default=...)
    # Overwritten with ``provider.id`` by ``_validate_provider_id`` during
    # validation, so the value supplied by the caller is not what is stored.
    provider_id: ModelProviderId = Field(default=...)

    provider: ModelProvider  # Who created it (shown to users)
    description: str | None = None
    # Unlike ``LLMInfoBase`` these sections are never None: each defaults to
    # a fresh, empty instance of its model.
    costs: TokenCosts = Field(default_factory=TokenCosts)
    capabilities: ModelCapabilities = Field(default_factory=ModelCapabilities)
    context_window: ContextWindow = Field(default_factory=ContextWindow)
    # When not supplied, defaults to the current timezone-aware UTC time.
    updated_dt: datetime = Field(default_factory=lambda: datetime.now(UTC))

    @model_validator(mode="after")
    def _validate_provider_id(self) -> Self:
        """Ensure provider_id matches provider.id.

        Runs after field validation and overwrites ``provider_id`` with
        ``provider.id`` instead of rejecting a mismatch.

        Returns:
            The same (now consistent) instance, as required for
            ``mode="after"`` model validators.
        """
        self.provider_id = self.provider.id
        return self
class ImageGenerationCost(BaseModel):
    """Cost information for image generation with support for multiple pricing models.

    Exactly one of the four pricing fields must be provided; this is enforced
    by ``validate_exactly_one_pricing_model``.
    """

    # Simple flat-rate pricing (most providers).
    flat_rate: Decimal | None = None

    # Dimension-based pricing (OpenAI): a two-level mapping of string keys to
    # prices. NOTE(review): presumably quality tier -> image size -> price;
    # confirm against the registry data.
    quality_tiers: dict[str, dict[str, Decimal]] | None = None

    # Usage-based pricing.
    cost_per_megapixel: Decimal | None = None
    cost_per_second: Decimal | None = None

    @model_validator(mode="after")
    def validate_exactly_one_pricing_model(self) -> Self:
        """Ensure exactly one pricing model is specified.

        A pricing field counts as specified whenever it is not None, so an
        empty ``quality_tiers`` dict still counts as one.

        Raises:
            ValueError: If none, or more than one, of the pricing fields is
                set (pydantic reports it as a ``ValidationError``).
        """
        candidates = (
            self.flat_rate,
            self.quality_tiers,
            self.cost_per_megapixel,
            self.cost_per_second,
        )
        configured = [value for value in candidates if value is not None]
        if len(configured) != 1:
            raise ValueError("Exactly one pricing model must be set.")
        return self
class ImageModelCost(BaseModel):
    """Cost information for image generation models.

    ``image_generation`` is mandatory. ``token_costs`` is optional and
    presumably set only for models that also bill per token (confirm).
    """

    token_costs: TokenCosts | None = None
    image_generation: ImageGenerationCost
# Common model fields plus optional image-model costs; ``ImageModelInfo``
# makes ``costs`` mandatory.
class ImageModelInfoBase(BaseModelInfo):
    costs: ImageModelCost | None = None
class ImageModelInfo(ImageModelInfoBase):
    """Information about an image generation model."""

    # Required identity fields (override the optional base-class defaults).
    id: str = Field(default=...)
    name: str = Field(default=...)
    # Overwritten with ``provider.id`` by ``_validate_provider_id`` during
    # validation, so the caller-supplied value is not what is stored.
    provider_id: ModelProviderId = Field(default=...)

    provider: ModelProvider
    description: str | None = None
    # Costs are mandatory for image models (no default).
    costs: ImageModelCost = Field(default=...)
    # When not supplied, defaults to the current timezone-aware UTC time.
    updated_dt: datetime = Field(default_factory=lambda: datetime.now(UTC))

    @model_validator(mode="after")
    def _validate_provider_id(self) -> Self:
        """Ensure provider_id matches provider.id.

        Runs after field validation and overwrites ``provider_id`` with
        ``provider.id`` rather than raising on a mismatch; returns the same
        instance, as ``mode="after"`` validators must.
        """
        self.provider_id = self.provider.id
        return self