Skip to content

Commit 6072d20

Browse files
committed
updates
1 parent bfe7146 commit 6072d20

2 files changed

Lines changed: 71 additions & 51 deletions

File tree

src/opengradient/agents/og_langchain.py

Lines changed: 49 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
# mypy: ignore-errors
22
import asyncio
33
import json
4-
from queue import Queue
5-
from threading import Thread
6-
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Sequence, Union, cast
4+
from enum import Enum
5+
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Optional, Sequence, Union, cast
76

87
from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
98
from langchain_core.language_models.base import LanguageModelInput
@@ -33,8 +32,6 @@
3332

3433
__all__ = ["OpenGradientChatModel"]
3534

36-
_STREAM_END = object()
37-
3835

3936
def _extract_content(content: Any) -> str:
4037
"""Normalize content to a plain string.
@@ -122,28 +119,29 @@ def _parse_tool_call_chunk(tool_call: Dict[str, Any], default_index: int) -> Too
122119
)
123120

124121

125-
def _run_coro_sync(coro: Any) -> Any:
122+
def _run_coro_sync(coro_factory: Callable[[], Awaitable[Any]]) -> Any:
126123
try:
127124
asyncio.get_running_loop()
128125
except RuntimeError:
129-
return asyncio.run(coro)
130-
131-
queue: Queue[Any] = Queue(maxsize=1)
126+
return asyncio.run(coro_factory())
132127

133-
def _runner() -> None:
134-
try:
135-
queue.put(asyncio.run(coro))
136-
except BaseException as exc: # noqa: BLE001
137-
queue.put(exc)
128+
raise RuntimeError(
129+
"Synchronous LangChain calls cannot run inside an active event loop for this adapter. "
130+
"Use `ainvoke`/`astream` instead of `invoke`/`stream`."
131+
)
138132

139-
thread = Thread(target=_runner, daemon=True)
140-
thread.start()
141-
outcome = queue.get()
142-
thread.join()
143133

144-
if isinstance(outcome, BaseException):
145-
raise outcome
146-
return outcome
134+
def _validate_model_string(model: Union[TEE_LLM, str]) -> Union[TEE_LLM, str]:
135+
if isinstance(model, Enum):
136+
model_str = str(model.value)
137+
else:
138+
model_str = str(model)
139+
if "/" not in model_str:
140+
raise ValueError(
141+
f"Unsupported model value '{model_str}'. "
142+
"Expected provider/model format (for example: 'openai/gpt-5')."
143+
)
144+
return model
147145

148146

149147
class OpenGradientChatModel(BaseChatModel):
@@ -176,6 +174,7 @@ def __init__(
176174
resolved_model_cid = model_cid or model
177175
if resolved_model_cid is None:
178176
raise ValueError("model_cid (or model) is required.")
177+
resolved_model_cid = _validate_model_string(resolved_model_cid)
179178
super().__init__(
180179
model_cid=resolved_model_cid,
181180
max_tokens=max_tokens,
@@ -213,7 +212,7 @@ async def aclose(self) -> None:
213212

214213
def close(self) -> None:
215214
if self._owns_client:
216-
_run_coro_sync(self._llm.close())
215+
_run_coro_sync(self._llm.close)
217216

218217
def bind_tools(
219218
self,
@@ -309,9 +308,11 @@ def _build_chat_kwargs(self, sdk_messages: List[Dict[str, Any]], stop: Optional[
309308
x402_settlement_mode = kwargs.get("x402_settlement_mode", self.x402_settlement_mode)
310309
if isinstance(x402_settlement_mode, str):
311310
x402_settlement_mode = x402SettlementMode(x402_settlement_mode)
311+
model = kwargs.get("model", self.model_cid)
312+
model = _validate_model_string(model)
312313

313314
return {
314-
"model": kwargs.get("model", self.model_cid),
315+
"model": model,
315316
"messages": sdk_messages,
316317
"stop_sequence": stop,
317318
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
@@ -346,7 +347,7 @@ def _generate(
346347
) -> ChatResult:
347348
sdk_messages = self._convert_messages_to_sdk(messages)
348349
chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=False, **kwargs)
349-
chat_output = _run_coro_sync(self._llm.chat(**chat_kwargs))
350+
chat_output = _run_coro_sync(lambda: self._llm.chat(**chat_kwargs))
350351
if not isinstance(chat_output, TextGenerationOutput):
351352
raise RuntimeError("Expected non-streaming chat output but received streaming generator.")
352353
return self._build_chat_result(chat_output)
@@ -374,33 +375,30 @@ def _stream(
374375
) -> Iterator[ChatGenerationChunk]:
375376
sdk_messages = self._convert_messages_to_sdk(messages)
376377
chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=True, **kwargs)
377-
queue: Queue[Any] = Queue()
378-
379-
def _runner() -> None:
380-
async def _run() -> None:
381-
stream = await self._llm.chat(**chat_kwargs)
382-
async for chunk in cast(AsyncIterator[StreamChunk], stream):
383-
queue.put(self._stream_chunk_to_generation(chunk))
384-
385-
try:
386-
asyncio.run(_run())
387-
except BaseException as exc: # noqa: BLE001
388-
queue.put(exc)
389-
finally:
390-
queue.put(_STREAM_END)
391-
392-
thread = Thread(target=_runner, daemon=True)
393-
thread.start()
394-
395-
while True:
396-
item = queue.get()
397-
if item is _STREAM_END:
398-
break
399-
if isinstance(item, BaseException):
400-
raise item
401-
yield cast(ChatGenerationChunk, item)
402-
403-
thread.join()
378+
try:
379+
asyncio.get_running_loop()
380+
except RuntimeError:
381+
pass
382+
else:
383+
raise RuntimeError(
384+
"Synchronous stream cannot run inside an active event loop for this adapter. "
385+
"Use `astream` instead."
386+
)
387+
388+
loop = asyncio.new_event_loop()
389+
try:
390+
stream = loop.run_until_complete(self._llm.chat(**chat_kwargs))
391+
stream_iter = cast(AsyncIterator[StreamChunk], stream)
392+
393+
while True:
394+
try:
395+
chunk = loop.run_until_complete(stream_iter.__anext__())
396+
except StopAsyncIteration:
397+
break
398+
yield self._stream_chunk_to_generation(chunk)
399+
finally:
400+
loop.run_until_complete(loop.shutdown_asyncgens())
401+
loop.close()
404402

405403
async def _astream(
406404
self,

tests/langchain_adapter_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ def test_initialization_without_private_key_or_client_raises(self):
6464
with pytest.raises(ValueError, match="private_key is required"):
6565
OpenGradientChatModel(private_key=None, model_cid=TEE_LLM.GPT_5)
6666

67+
def test_initialization_with_invalid_model_string_raises(self):
68+
with pytest.raises(ValueError, match="provider/model format"):
69+
OpenGradientChatModel(private_key="0x" + "a" * 64, model_cid="gpt-5")
70+
6771
def test_identifying_params(self, model):
6872
"""Test _identifying_params returns model name."""
6973
assert model._identifying_params == {"model_name": TEE_LLM.GPT_5, "temperature": 0.0, "max_tokens": 300}
@@ -168,6 +172,24 @@ def test_empty_chat_output(self, model, mock_llm_client):
168172

169173
assert result.generations[0].message.content == ""
170174

175+
def test_generate_with_invalid_model_kwarg_raises(self, model):
176+
with pytest.raises(ValueError, match="provider/model format"):
177+
model._generate([HumanMessage(content="Hi")], model="gpt-5")
178+
179+
def test_sync_generate_inside_running_loop_raises(self, model):
180+
async def run_test():
181+
with pytest.raises(RuntimeError, match="Use `ainvoke`/`astream`"):
182+
model._generate([HumanMessage(content="Hi")])
183+
184+
asyncio.run(run_test())
185+
186+
def test_sync_stream_inside_running_loop_raises(self, model):
187+
async def run_test():
188+
with pytest.raises(RuntimeError, match="Use `astream`"):
189+
next(model._stream([HumanMessage(content="Hi")]))
190+
191+
asyncio.run(run_test())
192+
171193

172194
class TestMessageConversion:
173195
def test_converts_all_message_types(self, model, mock_llm_client):

0 commit comments

Comments
 (0)