Coverage for packages/registry/src/langgate/registry/local.py: 52%

64 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-10-04 15:46 +0000

1"""LocalRegistryClient for direct registry access.""" 

2 

3from collections.abc import Sequence 

4from datetime import timedelta 

5from typing import Generic, cast, get_args 

6 

7from langgate.client.protocol import BaseRegistryClient, ImageInfoT, LLMInfoT 

8from langgate.core.logging import get_logger 

9from langgate.core.models import ImageModelInfo, LLMInfo 

10from langgate.registry.models import ModelRegistry 

11 

12logger = get_logger(__name__) 

13 

14 

class BaseLocalRegistryClient(
    BaseRegistryClient[LLMInfoT, ImageInfoT], Generic[LLMInfoT, ImageInfoT]
):
    """
    Base local registry client that calls ModelRegistry directly.

    This client is used when you want to embed the registry in your application
    rather than connecting to a remote registry service.

    The concrete model info classes are resolved in this order:

    1. Classes passed explicitly to ``__init__`` (stored on the instance).
    2. Class attributes ``llm_info_cls`` / ``image_info_cls`` already present
       on the subclass (declared or inherited).
    3. The type arguments the subclass supplied when parameterising this
       class, e.g. ``class C(BaseLocalRegistryClient[MyLLM, MyImage])``.

    Type Parameters:
        LLMInfoT: The LLMInfo-derived model class to use for responses
        ImageInfoT: The ImageModelInfo-derived model class to use for responses
    """

    __orig_bases__: tuple
    llm_info_cls: type[LLMInfoT]
    image_info_cls: type[ImageInfoT]

    def __init__(
        self,
        llm_info_cls: type[LLMInfoT] | None = None,
        image_info_cls: type[ImageInfoT] | None = None,
    ):
        """Initialize the client with a ModelRegistry instance.

        Args:
            llm_info_cls: Optional class to use for LLM info responses. When
                omitted, the class-level ``llm_info_cls`` is used.
            image_info_cls: Optional class to use for image model info
                responses. When omitted, the class-level ``image_info_cls``
                is used.
        """
        # Cache refreshing is no-op for local registry clients.
        # Since this client is local, we don't need to refresh the cache.
        # TODO: Move caching to the base HTTP client class instead.
        cache_ttl = timedelta(days=365)
        super().__init__(cache_ttl=cache_ttl)
        self.registry = ModelRegistry()

        # Set model info classes if provided, otherwise they are inferred from the class
        if llm_info_cls is not None:
            self.llm_info_cls = llm_info_cls
        if image_info_cls is not None:
            self.image_info_cls = image_info_cls

        logger.debug("initialized_base_local_registry_client")

    def __init_subclass__(cls, **kwargs):
        """Set up model classes when this class is subclassed.

        Only attributes that are not already available on the subclass
        (declared or inherited) are filled in from the generic type arguments.
        """
        super().__init_subclass__(**kwargs)

        if not hasattr(cls, "llm_info_cls") or not hasattr(cls, "image_info_cls"):
            llm_cls, image_cls = cls._get_model_info_classes()
            if not hasattr(cls, "llm_info_cls"):
                cls.llm_info_cls = llm_cls
            if not hasattr(cls, "image_info_cls"):
                cls.image_info_cls = image_cls

    @classmethod
    def _get_model_info_classes(cls) -> tuple[type[LLMInfoT], type[ImageInfoT]]:
        """Extract the model classes from generic type parameters.

        The parameterised ``BaseLocalRegistryClient[...]`` base is located by
        searching ``__orig_bases__`` rather than assuming it is the first
        entry, so subclasses may list mixins (or other generics) before it.

        Returns:
            The ``(llm_info_cls, image_info_cls)`` pair taken from the first
            two type arguments of the matching base.

        Raises:
            IndexError: If no suitable base is found and the first original
                base does not carry at least two type arguments.
        """
        orig_bases = cls.__orig_bases__

        for base in orig_bases:
            # Plain (non-generic) bases have no ``__origin__``; skip them.
            origin = getattr(base, "__origin__", None)
            if not (
                isinstance(origin, type)
                and issubclass(origin, BaseLocalRegistryClient)
            ):
                continue
            args = get_args(base)
            if len(args) >= 2:
                return args[0], args[1]

        # Historic behaviour: fall back to the first original base.
        args = get_args(orig_bases[0])
        return args[0], args[1]

    async def _fetch_llm_info(self, model_id: str) -> LLMInfoT:
        """Get LLM information directly from registry.

        Args:
            model_id: The ID of the LLM to get information for

        Returns:
            Information about the requested LLM with the type expected by this client
        """
        # Get the LLM info from the registry
        info = self.registry.get_llm_info(model_id)

        # If llm_info_cls is LLMInfo (not a subclass), we can return it as-is
        if self.llm_info_cls is LLMInfo:
            return cast(LLMInfoT, info)

        # Otherwise, validate against the subclass schema
        return self.llm_info_cls.model_validate(info.model_dump())

    async def _fetch_image_model_info(self, model_id: str) -> ImageInfoT:
        """Get image model information directly from registry.

        Args:
            model_id: The ID of the image model to get information for

        Returns:
            Information about the requested image model with the type expected by this client
        """
        # Get the image model info from the registry
        info = self.registry.get_image_model_info(model_id)

        # If image_info_cls is ImageModelInfo (not a subclass), we can return it as-is
        if self.image_info_cls is ImageModelInfo:
            return cast(ImageInfoT, info)

        # Otherwise, validate against the subclass schema
        return self.image_info_cls.model_validate(info.model_dump())

    async def _fetch_all_llms(self) -> Sequence[LLMInfoT]:
        """List all available LLMs directly from registry.

        Returns:
            A sequence of LLM information objects of the type expected by this client.
        """
        models = self.registry.list_llms()

        # If llm_info_cls is LLMInfo (not a subclass), we can return the list as-is
        if self.llm_info_cls is LLMInfo:
            return cast(Sequence[LLMInfoT], models)

        # Otherwise, we need to validate each model against the subclass schema
        return [
            self.llm_info_cls.model_validate(model.model_dump()) for model in models
        ]

    async def _fetch_all_image_models(self) -> Sequence[ImageInfoT]:
        """List all available image models directly from registry.

        Returns:
            A sequence of image model information objects of the type expected by this client.
        """
        models = self.registry.list_image_models()

        # If image_info_cls is ImageModelInfo (not a subclass), we can return the list as-is
        if self.image_info_cls is ImageModelInfo:
            return cast(Sequence[ImageInfoT], models)

        # Otherwise, we need to validate each model against the subclass schema
        return [
            self.image_info_cls.model_validate(model.model_dump()) for model in models
        ]

142 

143 

class LocalRegistryClient(BaseLocalRegistryClient[LLMInfo, ImageModelInfo]):
    """
    Local registry client that calls ModelRegistry directly using the default schemas.

    This is implemented as a singleton for convenient access across an application.
    Each class in the hierarchy owns its own singleton: instantiating a subclass
    never returns (or is blocked by) an instance created for a parent class.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        """Create or return the singleton instance for ``cls``.

        Positional and keyword arguments are accepted so subclasses may define
        their own ``__init__`` signature; they are not used here.
        """
        # Read from this class's own namespace only. A normal attribute lookup
        # would fall through to a parent's already-created instance and hand
        # back an object of the wrong type for subclasses.
        instance = cls.__dict__.get("_instance")
        if instance is None:
            instance = super().__new__(cls)
            cls._instance = instance
        return instance

    def __init__(self):
        """Initialize the client with a ModelRegistry instance.

        ``__init__`` runs on every ``LocalRegistryClient()`` call, so the
        ``_initialized`` marker ensures the base initialisation happens once
        per singleton instance.
        """
        if not hasattr(self, "_initialized"):
            super().__init__()
            self._initialized = True
            logger.debug("initialized_local_registry_client_singleton")