Coverage for packages/client/src/langgate/client/protocol.py: 32%

75 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-11 05:31 +0000

1"""Protocol definitions for LangGate clients.""" 

2 

import asyncio
from abc import ABC, abstractmethod
from collections.abc import Sequence
from datetime import datetime, timedelta
from typing import Generic, Protocol, TypeVar

from langgate.core.logging import get_logger
from langgate.core.models import ImageModelInfo, LLMInfo

10 

11logger = get_logger(__name__) 

12 

13 

# Covariant type variables bounded by the core model-info types, so a client
# declared over a subclass of LLMInfo/ImageModelInfo is still assignable to
# the base protocols below.
LLMInfoT = TypeVar("LLMInfoT", bound=LLMInfo, covariant=True)
ImageInfoT = TypeVar("ImageInfoT", bound=ImageModelInfo, covariant=True)

16 

17 

class LLMRegistryClientProtocol(Protocol[LLMInfoT]):
    """Protocol for LLM registry clients.

    Structural interface: any client exposing these two async methods
    satisfies the protocol without explicit inheritance.
    """

    async def get_llm_info(self, model_id: str) -> LLMInfoT:
        """Get LLM information by ID.

        Args:
            model_id: Registry identifier of the LLM.

        Returns:
            Information about the requested LLM.
        """
        ...

    async def list_llms(self) -> Sequence[LLMInfoT]:
        """List all available LLMs.

        Returns:
            A sequence of info objects for every registered LLM.
        """
        ...

28 

29 

class ImageRegistryClientProtocol(Protocol[ImageInfoT]):
    """Protocol for image model registry clients.

    Structural interface mirroring ``LLMRegistryClientProtocol`` for the
    image-model modality.
    """

    async def get_image_model_info(self, model_id: str) -> ImageInfoT:
        """Get image model information by ID.

        Args:
            model_id: Registry identifier of the image model.

        Returns:
            Information about the requested image model.
        """
        ...

    async def list_image_models(self) -> Sequence[ImageInfoT]:
        """List all available image models.

        Returns:
            A sequence of info objects for every registered image model.
        """
        ...

40 

41 

class RegistryClientProtocol(
    LLMRegistryClientProtocol[LLMInfoT], ImageRegistryClientProtocol[ImageInfoT]
):
    """Protocol for clients supporting multiple modalities.

    NOTE(review): ``Protocol`` is not listed among the bases here, so per
    PEP 544 this is a regular (nominal) class rather than a structural
    protocol — implementations must inherit from it explicitly, as
    ``BaseRegistryClient`` does below. Confirm this is intentional.
    """

46 

47 

class BaseRegistryClient(
    ABC,
    Generic[LLMInfoT, ImageInfoT],
    RegistryClientProtocol[LLMInfoT, ImageInfoT],
):
    """Base class for registry clients with common operations.

    Maintains per-modality in-memory caches (LLMs and image models) that are
    refreshed together once the configured TTL elapses. Subclasses supply the
    transport by implementing the abstract ``_fetch_*`` hooks.
    """

    def __init__(self, cache_ttl: timedelta | None = None) -> None:
        """Initialize the client with cache settings.

        Args:
            cache_ttl: How long cached model data stays fresh before a full
                refresh is triggered. Defaults to 60 minutes.
        """
        self._llm_cache: dict[str, LLMInfoT] = {}
        self._image_cache: dict[str, ImageInfoT] = {}
        self._last_cache_refresh: datetime | None = None
        # TODO: allow this to be set via config or env var
        self._cache_ttl = cache_ttl or timedelta(minutes=60)

    # LLM methods
    async def get_llm_info(self, model_id: str) -> LLMInfoT:
        """Get LLM information by ID.

        Args:
            model_id: Registry identifier of the LLM.

        Returns:
            The cached (or freshly fetched) LLM info.

        Raises:
            ValueError: If the model cannot be found in the registry.
        """
        if self._should_refresh_cache():
            await self._refresh_cache()

        model = self._llm_cache.get(model_id)
        if model is None:
            # If not found after potential refresh, try fetching individually
            await logger.awarning(
                "cache_miss_fetching_individual_model",
                model_id=model_id,
                model_type="llm",
            )
            try:
                model = await self._fetch_llm_info(model_id)
                self._llm_cache[model_id] = model
            except Exception as exc:
                await logger.aexception(
                    "failed_to_fetch_individual_model",
                    model_id=model_id,
                    model_type="llm",
                )
                raise ValueError(
                    f"LLM model '{model_id}' not found in registry."
                ) from exc
        return model

    async def list_llms(self) -> Sequence[LLMInfoT]:
        """List all available LLMs.

        Returns:
            A snapshot list of the cached LLM info objects.
        """
        if self._should_refresh_cache():
            await self._refresh_cache()
        return list(self._llm_cache.values())

    # Image model methods
    async def get_image_model_info(self, model_id: str) -> ImageInfoT:
        """Get image model information by ID.

        Args:
            model_id: Registry identifier of the image model.

        Returns:
            The cached (or freshly fetched) image model info.

        Raises:
            ValueError: If the model cannot be found in the registry.
        """
        if self._should_refresh_cache():
            await self._refresh_cache()

        model = self._image_cache.get(model_id)
        if model is None:
            # If not found after potential refresh, try fetching individually
            await logger.awarning(
                "cache_miss_fetching_individual_model",
                model_id=model_id,
                model_type="image",
            )
            try:
                model = await self._fetch_image_model_info(model_id)
                self._image_cache[model_id] = model
            except Exception as exc:
                await logger.aexception(
                    "failed_to_fetch_individual_model",
                    model_id=model_id,
                    model_type="image",
                )
                raise ValueError(
                    f"Image model '{model_id}' not found in registry."
                ) from exc
        return model

    async def list_image_models(self) -> Sequence[ImageInfoT]:
        """List all available image models.

        Returns:
            A snapshot list of the cached image model info objects.
        """
        if self._should_refresh_cache():
            await self._refresh_cache()
        return list(self._image_cache.values())

    @abstractmethod
    async def _fetch_llm_info(self, model_id: str) -> LLMInfoT:
        """Fetch LLM info from the source."""
        ...

    @abstractmethod
    async def _fetch_image_model_info(self, model_id: str) -> ImageInfoT:
        """Fetch image model info from the source."""
        ...

    @abstractmethod
    async def _fetch_all_llms(self) -> Sequence[LLMInfoT]:
        """Fetch all LLMs from the source."""
        ...

    @abstractmethod
    async def _fetch_all_image_models(self) -> Sequence[ImageInfoT]:
        """Fetch all image models from the source."""
        ...

    async def _refresh_cache(self) -> None:
        """Refresh both model caches.

        On failure the previous (stale) cache contents are kept — better
        than serving nothing — but the refresh timestamp is cleared so the
        next call retries immediately.
        """
        await logger.adebug("refreshing_model_caches")
        try:
            # Refresh both caches in parallel: the two fetches are
            # independent, so run them concurrently rather than awaiting
            # them one after the other.
            llms, image_models = await asyncio.gather(
                self._fetch_all_llms(),
                self._fetch_all_image_models(),
            )

            self._llm_cache = {model.id: model for model in llms}
            self._image_cache = {model.id: model for model in image_models}

            self._last_cache_refresh = datetime.now()
            await logger.adebug(
                "refreshed_model_caches",
                llm_count=len(self._llm_cache),
                image_count=len(self._image_cache),
            )
        except Exception:
            await logger.aexception("failed_to_refresh_model_caches")
            # Keep the stale cache but force a retry on the next call.
            self._last_cache_refresh = None
            # Bare re-raise preserves the original traceback.
            raise

    def _should_refresh_cache(self) -> bool:
        """Check if cache should be refreshed.

        Returns:
            True when no refresh has happened yet, or the TTL has elapsed.
            Uses naive local time consistently with ``_refresh_cache``.
        """
        return (
            self._last_cache_refresh is None
            or (datetime.now() - self._last_cache_refresh) > self._cache_ttl
        )