Skip to content

Commit f680d72

Browse files
committed
cleanup
1 parent 2cc17ec commit f680d72

3 files changed

Lines changed: 98 additions & 24 deletions

File tree

examples/llm_chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
async def main():
1212
llm = og.LLM(private_key=os.environ.get("OG_PRIVATE_KEY"))
13-
llm.ensure_opg_approval(opg_amount=0.1)
13+
llm.ensure_opg_approval(opg_amount=1)
1414

1515
messages = [
1616
{"role": "user", "content": "What is the capital of France?"},

examples/llm_chat_streaming.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
async def main():
88
llm = og.LLM(private_key=os.environ.get("OG_PRIVATE_KEY"))
9-
llm.ensure_opg_approval(opg_amount=0.1)
9+
llm.ensure_opg_approval(opg_amount=1)
1010

1111
messages = [
1212
{"role": "user", "content": "What is Python?"},

src/opengradient/client/llm.py

Lines changed: 96 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import json
44
import logging
55
import ssl
6+
import threading
67
from dataclasses import dataclass
7-
from typing import AsyncGenerator, Dict, List, Optional, Union
8+
from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union
89

910
from eth_account import Account
1011
from eth_account.account import LocalAccount
@@ -19,6 +20,7 @@
1920
from .tee_registry import TEERegistry, build_ssl_context_from_der
2021

2122
logger = logging.getLogger(__name__)
23+
T = TypeVar("T")
2224

2325
DEFAULT_RPC_URL = "https://ogevmdevnet.opengradient.ai"
2426
DEFAULT_TEE_REGISTRY_ADDRESS = "0x4e72238852f3c918f4E4e57AeC9280dDB0c80248"
@@ -94,31 +96,76 @@ def __init__(
9496
llm_server_url: Optional[str] = None,
9597
):
9698
self._wallet_account: LocalAccount = Account.from_key(private_key)
97-
98-
endpoint, tls_cert_der, tee_id, tee_payment_address = self._resolve_tee(
99-
llm_server_url,
100-
rpc_url,
101-
tee_registry_address,
99+
self._rpc_url = rpc_url
100+
self._tee_registry_address = tee_registry_address
101+
self._llm_server_url = llm_server_url
102+
self._reset_lock = threading.Lock()
103+
104+
self._refresh_tee_config()
105+
self._init_x402_stack()
106+
107+
def _refresh_tee_config(self) -> None:
    """Re-resolve the TEE and store its identity, endpoint and TLS settings.

    Calls ``_resolve_tee`` with the server URL, RPC URL and registry address
    saved on the instance, then updates ``_tee_id``, ``_tee_endpoint``,
    ``_tee_payment_address`` and ``_tls_verify``.
    """
    resolved = self._resolve_tee(
        self._llm_server_url,
        self._rpc_url,
        self._tee_registry_address,
    )
    endpoint, cert_der, tee_id, payment_address = resolved

    # The SSL context is built before any attribute is assigned, so an error
    # raised while building it leaves the previously stored values untouched.
    pinned_ctx = build_ssl_context_from_der(cert_der) if cert_der else None

    self._tee_id = tee_id
    self._tee_endpoint = endpoint
    self._tee_payment_address = payment_address

    # With no certificate to pin, fall back to a plain boolean. When
    # connecting directly via llm_server_url, skip cert verification:
    # self-hosted TEE servers commonly use self-signed certificates.
    fallback_verify = self._llm_server_url is None
    self._tls_verify: Union[ssl.SSLContext, bool] = pinned_ctx if pinned_ctx else fallback_verify
113117

114-
# x402 client and signer
118+
def _init_x402_stack(self) -> None:
    """Build the x402 signer and client, then the HTTP client that wraps them.

    Stores the new client objects on ``_x402_client`` and ``_http_client``;
    the HTTP client is created with the current ``_tls_verify`` setting.
    """
    wallet_signer = EthAccountSigner(self._wallet_account)
    x402 = x402Client()
    self._x402_client = x402

    # Register both EVM payment schemes for the same signer and network.
    for register in (register_exact_evm_client, register_upto_evm_client):
        register(x402, wallet_signer, networks=[BASE_TESTNET_NETWORK])

    # httpx.AsyncClient subclass - construction is sync, connections open lazily
    self._http_client = x402HttpxClient(x402, verify=self._tls_verify)
121125

126+
@staticmethod
def _has_ssl_cause(exc: BaseException) -> bool:
    """Return True when an ``ssl.SSLError`` is reachable from *exc*.

    Every exception reachable through ``__cause__`` or ``__context__`` links
    is inspected, starting with *exc* itself. Both links of each exception
    are followed, so an SSL error that is only reachable through
    ``__context__`` is still found when ``__cause__`` is also set.

    Args:
        exc: The exception whose chain should be inspected.

    Returns:
        True if *exc* or any chained exception is an ``ssl.SSLError``
        (including subclasses), otherwise False.
    """
    seen_ids = set()
    pending = [exc]
    while pending:
        current = pending.pop()
        # Links are compared against None explicitly rather than by
        # truthiness; ids guard against cyclic chains.
        if current is None or id(current) in seen_ids:
            continue
        seen_ids.add(id(current))
        if isinstance(current, ssl.SSLError):
            return True
        # Pushed last so the explicit cause is examined first.
        pending.append(current.__context__)
        pending.append(current.__cause__)
    return False
137+
138+
async def _refresh_tee_and_reset(self) -> None:
139+
"""Re-resolve TEE and rebuild the HTTP client with fresh TLS config."""
140+
with self._reset_lock:
141+
old_http_client = self._http_client
142+
self._refresh_tee_config()
143+
self._init_x402_stack()
144+
145+
try:
146+
await old_http_client.aclose()
147+
except Exception:
148+
logger.debug("Failed to close previous HTTP client during TEE refresh.", exc_info=True)
149+
150+
async def _call_with_ssl_retry(
151+
self,
152+
operation_name: str,
153+
call: Callable[[], Awaitable[T]],
154+
) -> T:
155+
"""Retry once with fresh TEE/TLS state when the failure is SSL-related."""
156+
try:
157+
return await call()
158+
except Exception as exc:
159+
if not self._has_ssl_cause(exc):
160+
raise
161+
logger.warning(
162+
"SSL failure during %s; refreshing TEE and retrying once: %s",
163+
operation_name,
164+
exc,
165+
)
166+
await self._refresh_tee_and_reset()
167+
return await call()
168+
122169
# ── TEE resolution ──────────────────────────────────────────────────
123170

124171
@staticmethod
@@ -248,7 +295,6 @@ async def completion(
248295
RuntimeError: If the inference fails.
249296
"""
250297
model_id = model.split("/")[1]
251-
headers = self._headers(x402_settlement_mode)
252298
payload: Dict = {
253299
"model": model_id,
254300
"prompt": prompt,
@@ -258,11 +304,11 @@ async def completion(
258304
if stop_sequence:
259305
payload["stop"] = stop_sequence
260306

261-
try:
307+
async def _request() -> TextGenerationOutput:
262308
response = await self._http_client.post(
263309
self._tee_endpoint + _COMPLETION_ENDPOINT,
264310
json=payload,
265-
headers=headers,
311+
headers=self._headers(x402_settlement_mode),
266312
timeout=_REQUEST_TIMEOUT,
267313
)
268314
response.raise_for_status()
@@ -275,6 +321,9 @@ async def completion(
275321
tee_timestamp=result.get("tee_timestamp"),
276322
**self._tee_metadata(),
277323
)
324+
325+
try:
326+
return await self._call_with_ssl_retry("completion", _request)
278327
except RuntimeError:
279328
raise
280329
except Exception as e:
@@ -342,14 +391,13 @@ async def chat(
342391

343392
async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> TextGenerationOutput:
344393
"""Non-streaming chat request."""
345-
headers = self._headers(params.x402_settlement_mode)
346394
payload = self._chat_payload(params, messages)
347395

348-
try:
396+
async def _request() -> TextGenerationOutput:
349397
response = await self._http_client.post(
350398
self._tee_endpoint + _CHAT_ENDPOINT,
351399
json=payload,
352-
headers=headers,
400+
headers=self._headers(params.x402_settlement_mode),
353401
timeout=_REQUEST_TIMEOUT,
354402
)
355403
response.raise_for_status()
@@ -375,6 +423,9 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text
375423
tee_timestamp=result.get("tee_timestamp"),
376424
**self._tee_metadata(),
377425
)
426+
427+
try:
428+
return await self._call_with_ssl_retry("chat", _request)
378429
except RuntimeError:
379430
raise
380431
except Exception as e:
@@ -410,6 +461,29 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async
410461
headers = self._headers(params.x402_settlement_mode)
411462
payload = self._chat_payload(params, messages, stream=True)
412463

464+
chunks_yielded = False
465+
try:
466+
async with self._http_client.stream(
467+
"POST",
468+
self._tee_endpoint + _CHAT_ENDPOINT,
469+
json=payload,
470+
headers=headers,
471+
timeout=_REQUEST_TIMEOUT,
472+
) as response:
473+
async for chunk in self._parse_sse_response(response):
474+
chunks_yielded = True
475+
yield chunk
476+
return
477+
except Exception as exc:
478+
if chunks_yielded or not self._has_ssl_cause(exc):
479+
raise
480+
logger.warning(
481+
"SSL failure during stream setup; refreshing TEE and retrying once: %s",
482+
exc,
483+
)
484+
485+
await self._refresh_tee_and_reset()
486+
headers = self._headers(params.x402_settlement_mode)
413487
async with self._http_client.stream(
414488
"POST",
415489
self._tee_endpoint + _CHAT_ENDPOINT,

0 commit comments

Comments
 (0)