Coverage for packages/core/src/langgate/core/models.py: 99%

113 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-08-02 18:47 +0000

1"""Core domain data models for LangGate.""" 

2 

3from datetime import UTC, datetime 

4from decimal import Decimal 

5from enum import Enum 

6from typing import Annotated, Any, NewType, Self 

7 

8from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator 

9 

10from langgate.core.fields import NormalizedDecimal 

11 

# Identifier types. NewType gives static type checkers two distinct names,
# while both remain plain ``str`` values at runtime.
#
# ServiceProviderId names the inference service (API endpoint) that requests
# are sent to. It is used by the proxy to route requests to the correct
# service and is not intended to be exposed to external consumers of the
# registry.
ServiceProviderId = NewType("ServiceProviderId", str)

# ModelProviderId names whoever created a model. This may differ from the
# service provider that actually serves inference for it.
ModelProviderId = NewType("ModelProviderId", str)

# Well-known model provider ids, provided for convenience.
MODEL_PROVIDER_OPENAI = ModelProviderId("openai")
MODEL_PROVIDER_ANTHROPIC = ModelProviderId("anthropic")
MODEL_PROVIDER_META = ModelProviderId("meta")
MODEL_PROVIDER_GOOGLE = ModelProviderId("google")
MODEL_PROVIDER_MINIMAX = ModelProviderId("minimax")
MODEL_PROVIDER_DEEPSEEK = ModelProviderId("deepseek")
MODEL_PROVIDER_MISTRAL = ModelProviderId("mistralai")
MODEL_PROVIDER_ALIBABA = ModelProviderId("alibaba")
MODEL_PROVIDER_XAI = ModelProviderId("xai")
MODEL_PROVIDER_COHERE = ModelProviderId("cohere")
# NOTE(review): EleutherAI is usually identified as "eleutherai"; confirm
# that "eleutheria" is the intended id before relying on it for lookups.
MODEL_PROVIDER_ELEUTHERIA = ModelProviderId("eleutheria")

31 

32 

class Modality(str, Enum):
    """Model modalities supported by LangGate.

    Members subclass ``str``, so each one compares equal to its raw value
    (for example ``Modality.TEXT == "text"``).
    """

    # Declaration order is the enum's iteration order.
    TEXT = "text"
    IMAGE = "image"

38 

39 

class ServiceProvider(BaseModel):
    """An inference API service that requests can be sent to.

    Holds the service's base URL, its API key, and any default parameters
    configured for it.
    """

    id: ServiceProviderId
    base_url: str
    # pydantic's SecretStr masks the value when printed; read it explicitly
    # with ``get_secret_value()``.
    api_key: SecretStr
    # default_factory gives every instance its own, initially empty, dict.
    default_params: dict[str, Any] = Field(default_factory=dict)

47 

48 

class ModelProviderBase(BaseModel):
    """Partially specified model-provider data in which every field is optional.

    Keys not declared below are retained rather than rejected
    (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow")

    id: ModelProviderId | None = None
    name: str | None = None
    description: str | None = None

55 

56 

class ModelProvider(ModelProviderBase):
    """Information about a model provider (creator).

    Narrows :class:`ModelProviderBase` by requiring ``id`` and ``name``.
    ``description`` stays optional, and the inherited config still allows
    extra keys.
    """

    # ``Field(...)`` marks each field as required, overriding the ``None``
    # default inherited from the base class.
    id: ModelProviderId = Field(...)
    name: str = Field(...)

62 

63 

class ContextWindow(BaseModel):
    """Token limits of a model's context window.

    Both limits default to ``0`` when not supplied. Undeclared keys are kept
    as extra fields.
    """

    model_config = ConfigDict(extra="allow")

    max_input_tokens: int = 0
    max_output_tokens: int = 0

71 

72 

class ModelCapabilities(BaseModel):
    """Feature flags describing what a language model supports.

    Every flag defaults to ``None`` when not supplied, so a flag may be
    ``True``, ``False`` or unset. Flags not listed here are accepted as extra
    fields.
    """

    model_config = ConfigDict(extra="allow")

    # Tool calling
    supports_tools: bool | None = None
    supports_parallel_tool_calls: bool | None = None
    # Non-text input / output
    supports_vision: bool | None = None
    supports_audio_input: bool | None = None
    supports_audio_output: bool | None = None
    # Other features
    supports_prompt_caching: bool | None = None
    supports_reasoning: bool | None = None
    supports_response_schema: bool | None = None
    supports_system_messages: bool | None = None

87 

88 

# Annotated aliases over NormalizedDecimal. The string metadata only labels
# each alias's purpose; nothing in this module reads it.
TokenCost = Annotated[NormalizedDecimal, "TokenCost"]  # per-token prices
Percentage = Annotated[NormalizedDecimal, "Percentage"]
TokenUsage = Annotated[NormalizedDecimal, "TokenUsage"]

92 

93 

class BaseModelInfo(BaseModel):
    """Fields shared by every kind of model description.

    All fields are optional at this level, and concrete subclasses make the
    ones they need mandatory. Unrecognised keys are kept (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow")

    id: str | None = None
    name: str | None = None
    provider_id: ModelProviderId | None = None
    description: str | None = None

103 

104 

class TokenCosts(BaseModel):
    """Per-token pricing for models that bill input and output tokens.

    The two base rates default to zero (``Decimal()``). Every other rate is
    optional and stays ``None`` unless configured. Extra keys are kept.
    """

    model_config = ConfigDict(extra="allow")

    # Base rates; ``Decimal()`` evaluates to Decimal("0").
    input_cost_per_token: TokenCost = Field(default_factory=Decimal)
    output_cost_per_token: TokenCost = Field(default_factory=Decimal)
    # Presumably rates for batched requests (TODO confirm with providers).
    input_cost_per_token_batches: TokenCost | None = None
    output_cost_per_token_batches: TokenCost | None = None
    # NOTE(review): these two cached-input rates appear to overlap; confirm
    # which one consumers actually read.
    cache_read_input_token_cost: TokenCost | None = None
    input_cached_cost_per_token: TokenCost | None = None

116 

117 

class LLMInfoBase(BaseModelInfo):
    """Optional language-model details layered on :class:`BaseModelInfo`.

    Costs, capabilities and context window may each be omitted (``None``).
    """

    costs: TokenCosts | None = None
    capabilities: ModelCapabilities | None = None
    context_window: ContextWindow | None = None

122 

123 

class LLMInfo(LLMInfoBase):
    """Information about a language model.

    Tightens :class:`LLMInfoBase` in two ways:

    * ``id``, ``name``, ``provider_id`` and the creating ``provider`` are
      required.
    * ``costs``, ``capabilities`` and ``context_window`` default to empty
      instances instead of ``None``.
    """

    id: str = Field(default=...)  # e.g. "gpt-4o"
    name: str = Field(default=...)
    # Required on input, but always replaced with ``provider.id`` by
    # ``_validate_provider_id`` after validation.
    provider_id: ModelProviderId = Field(default=...)

    provider: ModelProvider  # Who created it (shown to users)
    description: str | None = None
    costs: TokenCosts = Field(default_factory=TokenCosts)
    capabilities: ModelCapabilities = Field(default_factory=ModelCapabilities)
    context_window: ContextWindow = Field(default_factory=ContextWindow)
    # Defaults to the current time as a timezone-aware UTC datetime.
    updated_dt: datetime = Field(default_factory=lambda: datetime.now(UTC))

    @model_validator(mode="after")
    def _validate_provider_id(self) -> Self:
        """Ensure provider_id matches provider.id.

        Any ``provider_id`` supplied by the caller is overwritten with
        ``provider.id``, so the two values can never disagree.

        Returns:
            The validated model instance.
        """
        self.provider_id = self.provider.id
        return self

142 

143 

class ImageGenerationCost(BaseModel):
    """Image-generation pricing in which exactly one pricing scheme is populated.

    Available schemes:

    * ``flat_rate``: a single flat price (most providers).
    * ``quality_tiers``: nested price mapping used for dimension-based
      pricing (OpenAI). Presumably keyed by quality tier, then by image
      dimensions; confirm against the config data.
    * ``cost_per_megapixel``: usage-based price per megapixel.
    * ``cost_per_second``: usage-based price per second.
    """

    flat_rate: Decimal | None = None
    quality_tiers: dict[str, dict[str, Decimal]] | None = None
    cost_per_megapixel: Decimal | None = None
    cost_per_second: Decimal | None = None

    @model_validator(mode="after")
    def validate_exactly_one_pricing_model(self) -> Self:
        """Reject instances where zero or several pricing schemes are set.

        A scheme counts as set whenever it is not ``None``. An empty
        ``quality_tiers`` dict therefore still counts as set.

        Raises:
            ValueError: if the number of set schemes is not exactly one.
        """
        candidates = (
            self.flat_rate,
            self.quality_tiers,
            self.cost_per_megapixel,
            self.cost_per_second,
        )
        populated = [value for value in candidates if value is not None]
        if len(populated) != 1:
            raise ValueError("Exactly one pricing model must be set.")
        return self

169 

170 

class ImageModelCost(BaseModel):
    """Pricing for an image generation model.

    ``image_generation`` is mandatory. ``token_costs`` optionally carries
    token-based pricing alongside it.
    """

    token_costs: TokenCosts | None = None
    image_generation: ImageGenerationCost

176 

177 

class ImageModelInfoBase(BaseModelInfo):
    """Optional image-model details layered on :class:`BaseModelInfo`."""

    costs: ImageModelCost | None = None

180 

181 

class ImageModelInfo(ImageModelInfoBase):
    """Information about an image generation model.

    Unlike :class:`ImageModelInfoBase`, the identity fields, the creating
    provider and the cost information are all mandatory here.
    """

    id: str = Field(...)
    name: str = Field(...)
    # Overwritten with ``provider.id`` by the validator below.
    provider_id: ModelProviderId = Field(...)

    provider: ModelProvider  # creator of the model
    description: str | None = None
    costs: ImageModelCost = Field(...)
    # Defaults to the current time as a timezone-aware UTC datetime.
    updated_dt: datetime = Field(default_factory=lambda: datetime.now(UTC))

    @model_validator(mode="after")
    def _validate_provider_id(self) -> Self:
        """Derive provider_id from the nested provider so the two always agree."""
        self.provider_id = self.provider.id
        return self