Coverage for packages/client/src/langgate/client/protocol.py: 35%

48 statements  


1"""Protocol definitions for LangGate clients.""" 

2 

3from abc import ABC, abstractmethod 

4from collections.abc import Sequence 

5from datetime import datetime, timedelta 

6from typing import Generic, Protocol, TypeVar 

7 

8from langgate.core.logging import get_logger 

9from langgate.core.models import LLMInfo 

10 

11logger = get_logger(__name__) 

12 

13 

14LLMInfoT = TypeVar("LLMInfoT", bound=LLMInfo, covariant=True) 

15 

16 

17class RegistryClientProtocol(Protocol[LLMInfoT]): 

18 """Protocol for registry clients.""" 

19 

20 async def get_model_info(self, model_id: str) -> LLMInfoT: 

21 """Get model information by ID.""" 

22 ... 

23 

24 async def list_models(self) -> Sequence[LLMInfoT]: 

25 """List all available models.""" 

26 ... 

27 

28 
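
# Usage sketch (an illustration, not part of the original module): any object
# whose async methods match these signatures satisfies RegistryClientProtocol
# structurally; no inheritance is required. Because LLMInfoT is covariant, a
# client parameterized with a subclass of LLMInfo is accepted wherever a
# RegistryClientProtocol[LLMInfo] is expected. The helper name below is
# hypothetical, and `ainfo` assumes the structlog-style async methods this
# module already uses (awarning, adebug, aexception).
async def log_available_models(client: RegistryClientProtocol[LLMInfo]) -> None:
    """Hypothetical helper typed against the protocol, not a concrete client."""
    models = await client.list_models()
    await logger.ainfo("available_models", model_count=len(models))
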

class BaseRegistryClient(ABC, Generic[LLMInfoT]):
    """Base class for registry clients with common operations."""

    def __init__(self, cache_ttl: timedelta | None = None) -> None:
        """Initialize the client with cache settings."""
        self._model_cache: dict[str, LLMInfoT] = {}
        self._last_cache_refresh: datetime | None = None
        # TODO: allow this to be set via config or env var
        self._cache_ttl = cache_ttl or timedelta(minutes=60)

    async def get_model_info(self, model_id: str) -> LLMInfoT:
        """Get model info with caching.

        Args:
            model_id: The ID of the model to get information for

        Returns:
            Information about the requested model

        Raises:
            ValueError: If the model is not found
        """
        if self._should_refresh_cache():
            await self._refresh_cache()

        model = self._model_cache.get(model_id)
        if model is None:
            # If not found after potential refresh, try fetching individually
            await logger.awarning(
                "cache_miss_fetching_individual_model", model_id=model_id
            )
            try:
                model = await self._fetch_model_info(model_id)
                self._model_cache[model_id] = model
            except Exception as exc:
                await logger.aexception(
                    "failed_to_fetch_individual_model", model_id=model_id
                )
                raise ValueError(f"Model '{model_id}' not found in registry.") from exc
        return model

    async def list_models(self) -> Sequence[LLMInfoT]:
        """List all models with caching.

        Returns:
            A list of all available models
        """
        if self._should_refresh_cache():
            await self._refresh_cache()

        return list(self._model_cache.values())

    @abstractmethod
    async def _fetch_model_info(self, model_id: str) -> LLMInfoT:
        """Fetch model info from the source."""
        ...

    @abstractmethod
    async def _fetch_all_models(self) -> Sequence[LLMInfoT]:
        """Fetch all models from the source."""
        ...

    async def _refresh_cache(self) -> None:
        """Refresh the model cache."""
        await logger.adebug("refreshing_model_cache")
        try:
            models = await self._fetch_all_models()
            self._model_cache = {model.id: model for model in models}
            self._last_cache_refresh = datetime.now()
            await logger.adebug(
                "refreshed_model_cache", model_count=len(self._model_cache)
            )
        except Exception as exc:
            await logger.aexception("failed_to_refresh_model_cache")

            # Keep the stale cache rather than clearing it: serving stale
            # entries is likely better than serving none. Reset the refresh
            # timestamp so the next call retries the fetch.
            self._last_cache_refresh = None
            raise exc

    def _should_refresh_cache(self) -> bool:
        """Check if cache should be refreshed."""
        return (
            self._last_cache_refresh is None
            or (datetime.now() - self._last_cache_refresh) > self._cache_ttl
        )
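

# Minimal concrete client sketch (hypothetical; the package's real
# implementations live elsewhere). It assumes an HTTP registry exposing
# /models and /models/{id} endpoints, the `httpx` library, and that LLMInfo
# is a Pydantic v2 model with `model_validate`: all assumptions made for
# illustration only.
class ExampleHTTPRegistryClient(BaseRegistryClient[LLMInfo]):
    """Toy registry client that fetches model metadata over HTTP."""

    def __init__(self, base_url: str, cache_ttl: timedelta | None = None) -> None:
        super().__init__(cache_ttl=cache_ttl)
        self._base_url = base_url.rstrip("/")

    async def _fetch_model_info(self, model_id: str) -> LLMInfo:
        import httpx  # assumed dependency for this sketch

        async with httpx.AsyncClient() as http:
            resp = await http.get(f"{self._base_url}/models/{model_id}")
            resp.raise_for_status()
            return LLMInfo.model_validate(resp.json())

    async def _fetch_all_models(self) -> Sequence[LLMInfo]:
        import httpx  # assumed dependency for this sketch

        async with httpx.AsyncClient() as http:
            resp = await http.get(f"{self._base_url}/models")
            resp.raise_for_status()
            return [LLMInfo.model_validate(item) for item in resp.json()]


# Usage sketch: the first get_model_info call populates the cache via
# _refresh_cache; later calls within the TTL are served from memory.
#
#     client = ExampleHTTPRegistryClient("https://registry.example.com")
#     info = await client.get_model_info("gpt-4o")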