Coverage for src / dataknobs_llm / llm / providers / base.py: 18%

76 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-15 10:28 -0700

1"""Base adapter for synchronous LLM provider access.""" 

2 

3from typing import List, Union, Dict, Any 

4 

5from ..base import ( 

6 LLMMessage, LLMResponse, 

7 AsyncLLMProvider, ModelCapability 

8) 

9 

10 

class SyncProviderAdapter:
    """Sync adapter for async LLM providers.

    Each public method drives the corresponding coroutine of the wrapped
    async provider to completion with ``loop.run_until_complete`` on the
    calling thread's event loop, so the provider can be used from ordinary
    (non-async) code.

    The adapter cannot be used while an event loop is already running in the
    calling thread (``run_until_complete`` cannot re-enter a running loop);
    in that situation await the async provider directly instead.
    """

    def __init__(self, async_provider: AsyncLLMProvider):
        """Initialize with async provider.

        Args:
            async_provider: The async provider to wrap.
        """
        self.async_provider = async_provider

    @staticmethod
    def _get_loop():
        """Return an event loop that can run coroutines on this thread.

        Reuses the thread's current event loop when one is set and still
        open; otherwise creates a new loop and installs it as the thread's
        current loop so subsequent calls keep using the same loop.

        Raises:
            RuntimeError: If an event loop is already running in this thread.
        """
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # No running loop: the supported case.
        else:
            raise RuntimeError(
                "SyncProviderAdapter cannot be used while an event loop is "
                "running in this thread; await the async provider directly."
            )

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop is set for this thread (e.g. a worker thread,
            # or after asyncio.run() cleared it).
            loop = None

        if loop is None or loop.is_closed():
            # A closed loop cannot run anything; replace it rather than fail.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop

    def _run(self, coro: Any) -> Any:
        """Run ``coro`` to completion on this thread's loop and return its result.

        Args:
            coro: Awaitable produced by one of the async provider's methods.

        Returns:
            Whatever the awaitable resolves to.

        Raises:
            RuntimeError: If an event loop is already running in this thread.
        """
        try:
            loop = self._get_loop()
        except BaseException:
            # The caller already created the coroutine object; close it so
            # Python does not emit a "coroutine was never awaited" warning.
            close = getattr(coro, "close", None)
            if close is not None:
                close()
            raise
        return loop.run_until_complete(coro)

    def initialize(self) -> None:
        """Initialize the provider synchronously.

        Blocks until the async provider's ``initialize()`` coroutine finishes.
        """
        return self._run(self.async_provider.initialize())

    def close(self) -> None:
        """Close the provider synchronously.

        Blocks until the async provider's ``close()`` coroutine finishes.
        """
        return self._run(self.async_provider.close())

    def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate completion synchronously.

        Args:
            messages: A prompt string or a list of messages.
            **kwargs: Forwarded unchanged to the async provider's ``complete``.

        Returns:
            The response produced by the async provider's ``complete``.
        """
        return self._run(self.async_provider.complete(messages, **kwargs))

    def stream(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ):
        """Stream completion synchronously.

        This is a generator: nothing runs until the first item is requested.
        Each item is produced by driving the async provider's
        ``stream_complete`` one step on this thread's event loop.

        Args:
            messages: A prompt string or a list of messages.
            **kwargs: Forwarded unchanged to ``stream_complete``.

        Yields:
            Each chunk produced by the async provider's ``stream_complete``.
        """
        loop = self._get_loop()

        # Wrapping in a local async generator guarantees ``aclose()`` exists
        # regardless of what kind of async iterator stream_complete returns.
        async def _stream():
            async for chunk in self.async_provider.stream_complete(messages, **kwargs):
                yield chunk

        async_gen = _stream()
        try:
            while True:
                try:
                    chunk = loop.run_until_complete(async_gen.__anext__())
                except StopAsyncIteration:
                    break
                # Yield outside the try so exceptions thrown into this
                # generator by the consumer are not mistaken for end-of-stream.
                yield chunk
        finally:
            # Finalize the async generator (runs its cleanup) even when the
            # consumer stops early; skip if the loop was closed meanwhile, as
            # run_until_complete would only raise and mask the real error.
            if not loop.is_closed():
                loop.run_until_complete(async_gen.aclose())

    def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings synchronously.

        Args:
            texts: A single text or a list of texts.
            **kwargs: Forwarded unchanged to the async provider's ``embed``.

        Returns:
            The embedding(s) produced by the async provider's ``embed``.
        """
        return self._run(self.async_provider.embed(texts, **kwargs))

    def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Make function call synchronously.

        Args:
            messages: Conversation messages.
            functions: Function definitions passed to the async provider.
            **kwargs: Forwarded unchanged to ``function_call``.

        Returns:
            The response produced by the async provider's ``function_call``.
        """
        return self._run(
            self.async_provider.function_call(messages, functions, **kwargs)
        )

    def validate_model(self) -> bool:
        """Validate model synchronously.

        Returns:
            The result of the async provider's ``validate_model``.
        """
        return self._run(self.async_provider.validate_model())  # type: ignore

    def get_capabilities(self) -> List[ModelCapability]:
        """Get capabilities synchronously (delegates directly, no event loop)."""
        return self.async_provider.get_capabilities()

    @property
    def is_initialized(self) -> bool:
        """Check if provider is initialized (delegates to the async provider)."""
        return self.async_provider.is_initialized