|
1 | 1 | # mypy: ignore-errors |
2 | 2 | import asyncio |
3 | 3 | import json |
4 | | -from queue import Queue |
5 | | -from threading import Thread |
6 | | -from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Sequence, Union, cast |
| 4 | +from enum import Enum |
| 5 | +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Optional, Sequence, Union, cast |
7 | 6 |
|
8 | 7 | from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun |
9 | 8 | from langchain_core.language_models.base import LanguageModelInput |
|
33 | 32 |
|
34 | 33 | __all__ = ["OpenGradientChatModel"] |
35 | 34 |
|
36 | | -_STREAM_END = object() |
37 | | - |
38 | 35 |
|
39 | 36 | def _extract_content(content: Any) -> str: |
40 | 37 | """Normalize content to a plain string. |
@@ -122,28 +119,29 @@ def _parse_tool_call_chunk(tool_call: Dict[str, Any], default_index: int) -> Too |
122 | 119 | ) |
123 | 120 |
|
124 | 121 |
|
125 | | -def _run_coro_sync(coro: Any) -> Any: |
| 122 | +def _run_coro_sync(coro_factory: Callable[[], Awaitable[Any]]) -> Any: |
126 | 123 | try: |
127 | 124 | asyncio.get_running_loop() |
128 | 125 | except RuntimeError: |
129 | | - return asyncio.run(coro) |
130 | | - |
131 | | - queue: Queue[Any] = Queue(maxsize=1) |
| 126 | + return asyncio.run(coro_factory()) |
132 | 127 |
|
133 | | - def _runner() -> None: |
134 | | - try: |
135 | | - queue.put(asyncio.run(coro)) |
136 | | - except BaseException as exc: # noqa: BLE001 |
137 | | - queue.put(exc) |
| 128 | + raise RuntimeError( |
| 129 | + "Synchronous LangChain calls cannot run inside an active event loop for this adapter. " |
| 130 | + "Use `ainvoke`/`astream` instead of `invoke`/`stream`." |
| 131 | + ) |
138 | 132 |
|
139 | | - thread = Thread(target=_runner, daemon=True) |
140 | | - thread.start() |
141 | | - outcome = queue.get() |
142 | | - thread.join() |
143 | 133 |
|
144 | | - if isinstance(outcome, BaseException): |
145 | | - raise outcome |
146 | | - return outcome |
| 134 | +def _validate_model_string(model: Union[TEE_LLM, str]) -> Union[TEE_LLM, str]: |
| 135 | + if isinstance(model, Enum): |
| 136 | + model_str = str(model.value) |
| 137 | + else: |
| 138 | + model_str = str(model) |
| 139 | + if "/" not in model_str: |
| 140 | + raise ValueError( |
| 141 | + f"Unsupported model value '{model_str}'. " |
| 142 | + "Expected provider/model format (for example: 'openai/gpt-5')." |
| 143 | + ) |
| 144 | + return model |
147 | 145 |
|
148 | 146 |
|
149 | 147 | class OpenGradientChatModel(BaseChatModel): |
@@ -176,6 +174,7 @@ def __init__( |
176 | 174 | resolved_model_cid = model_cid or model |
177 | 175 | if resolved_model_cid is None: |
178 | 176 | raise ValueError("model_cid (or model) is required.") |
| 177 | + resolved_model_cid = _validate_model_string(resolved_model_cid) |
179 | 178 | super().__init__( |
180 | 179 | model_cid=resolved_model_cid, |
181 | 180 | max_tokens=max_tokens, |
@@ -213,7 +212,7 @@ async def aclose(self) -> None: |
213 | 212 |
|
214 | 213 | def close(self) -> None: |
215 | 214 | if self._owns_client: |
216 | | - _run_coro_sync(self._llm.close()) |
| 215 | + _run_coro_sync(self._llm.close) |
217 | 216 |
|
218 | 217 | def bind_tools( |
219 | 218 | self, |
@@ -309,9 +308,11 @@ def _build_chat_kwargs(self, sdk_messages: List[Dict[str, Any]], stop: Optional[ |
309 | 308 | x402_settlement_mode = kwargs.get("x402_settlement_mode", self.x402_settlement_mode) |
310 | 309 | if isinstance(x402_settlement_mode, str): |
311 | 310 | x402_settlement_mode = x402SettlementMode(x402_settlement_mode) |
| 311 | + model = kwargs.get("model", self.model_cid) |
| 312 | + model = _validate_model_string(model) |
312 | 313 |
|
313 | 314 | return { |
314 | | - "model": kwargs.get("model", self.model_cid), |
| 315 | + "model": model, |
315 | 316 | "messages": sdk_messages, |
316 | 317 | "stop_sequence": stop, |
317 | 318 | "max_tokens": kwargs.get("max_tokens", self.max_tokens), |
@@ -346,7 +347,7 @@ def _generate( |
346 | 347 | ) -> ChatResult: |
347 | 348 | sdk_messages = self._convert_messages_to_sdk(messages) |
348 | 349 | chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=False, **kwargs) |
349 | | - chat_output = _run_coro_sync(self._llm.chat(**chat_kwargs)) |
| 350 | + chat_output = _run_coro_sync(lambda: self._llm.chat(**chat_kwargs)) |
350 | 351 | if not isinstance(chat_output, TextGenerationOutput): |
351 | 352 | raise RuntimeError("Expected non-streaming chat output but received streaming generator.") |
352 | 353 | return self._build_chat_result(chat_output) |
@@ -374,33 +375,30 @@ def _stream( |
374 | 375 | ) -> Iterator[ChatGenerationChunk]: |
375 | 376 | sdk_messages = self._convert_messages_to_sdk(messages) |
376 | 377 | chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=True, **kwargs) |
377 | | - queue: Queue[Any] = Queue() |
378 | | - |
379 | | - def _runner() -> None: |
380 | | - async def _run() -> None: |
381 | | - stream = await self._llm.chat(**chat_kwargs) |
382 | | - async for chunk in cast(AsyncIterator[StreamChunk], stream): |
383 | | - queue.put(self._stream_chunk_to_generation(chunk)) |
384 | | - |
385 | | - try: |
386 | | - asyncio.run(_run()) |
387 | | - except BaseException as exc: # noqa: BLE001 |
388 | | - queue.put(exc) |
389 | | - finally: |
390 | | - queue.put(_STREAM_END) |
391 | | - |
392 | | - thread = Thread(target=_runner, daemon=True) |
393 | | - thread.start() |
394 | | - |
395 | | - while True: |
396 | | - item = queue.get() |
397 | | - if item is _STREAM_END: |
398 | | - break |
399 | | - if isinstance(item, BaseException): |
400 | | - raise item |
401 | | - yield cast(ChatGenerationChunk, item) |
402 | | - |
403 | | - thread.join() |
| 378 | + try: |
| 379 | + asyncio.get_running_loop() |
| 380 | + except RuntimeError: |
| 381 | + pass |
| 382 | + else: |
| 383 | + raise RuntimeError( |
| 384 | + "Synchronous stream cannot run inside an active event loop for this adapter. " |
| 385 | + "Use `astream` instead." |
| 386 | + ) |
| 387 | + |
| 388 | + loop = asyncio.new_event_loop() |
| 389 | + try: |
| 390 | + stream = loop.run_until_complete(self._llm.chat(**chat_kwargs)) |
| 391 | + stream_iter = cast(AsyncIterator[StreamChunk], stream) |
| 392 | + |
| 393 | + while True: |
| 394 | + try: |
| 395 | + chunk = loop.run_until_complete(stream_iter.__anext__()) |
| 396 | + except StopAsyncIteration: |
| 397 | + break |
| 398 | + yield self._stream_chunk_to_generation(chunk) |
| 399 | + finally: |
| 400 | + loop.run_until_complete(loop.shutdown_asyncgens()) |
| 401 | + loop.close() |
404 | 402 |
|
405 | 403 | async def _astream( |
406 | 404 | self, |
|
0 commit comments