33import json
44import logging
55import ssl
6+ import threading
67from dataclasses import dataclass
7- from typing import AsyncGenerator , Dict , List , Optional , Union
8+ from typing import AsyncGenerator , Awaitable , Callable , Dict , List , Optional , TypeVar , Union
89
910from eth_account import Account
1011from eth_account .account import LocalAccount
1920from .tee_registry import TEERegistry , build_ssl_context_from_der
2021
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Generic result type for helpers that wrap an arbitrary awaitable call.
T = TypeVar("T")

# Default RPC endpoint and TEE registry contract address.
# NOTE(review): presumably used when the caller supplies no override — confirm at the call sites.
DEFAULT_RPC_URL = "https://ogevmdevnet.opengradient.ai"
DEFAULT_TEE_REGISTRY_ADDRESS = "0x4e72238852f3c918f4E4e57AeC9280dDB0c80248"
@@ -94,31 +96,76 @@ def __init__(
9496 llm_server_url : Optional [str ] = None ,
9597 ):
9698 self ._wallet_account : LocalAccount = Account .from_key (private_key )
97-
98- endpoint , tls_cert_der , tee_id , tee_payment_address = self ._resolve_tee (
99- llm_server_url ,
100- rpc_url ,
101- tee_registry_address ,
99+ self ._rpc_url = rpc_url
100+ self ._tee_registry_address = tee_registry_address
101+ self ._llm_server_url = llm_server_url
102+ self ._reset_lock = threading .Lock ()
103+
104+ self ._refresh_tee_config ()
105+ self ._init_x402_stack ()
106+
107+ def _refresh_tee_config (self ) -> None :
108+ """Resolve TEE metadata from the registry and update TLS config."""
109+ endpoint , tls_cert_der , tee_id , payment_addr = self ._resolve_tee (
110+ self ._llm_server_url , self ._rpc_url , self ._tee_registry_address ,
102111 )
103-
112+ ssl_ctx = build_ssl_context_from_der ( tls_cert_der ) if tls_cert_der else None
104113 self ._tee_id = tee_id
105114 self ._tee_endpoint = endpoint
106- self ._tee_payment_address = tee_payment_address
107-
108- ssl_ctx = build_ssl_context_from_der (tls_cert_der ) if tls_cert_der else None
109- # When connecting directly via llm_server_url, skip cert verification —
110- # self-hosted TEE servers commonly use self-signed certificates.
111- verify_ssl = llm_server_url is None
112- self ._tls_verify : Union [ssl .SSLContext , bool ] = ssl_ctx if ssl_ctx else verify_ssl
115+ self ._tee_payment_address = payment_addr
116+ self ._tls_verify : Union [ssl .SSLContext , bool ] = ssl_ctx if ssl_ctx else (self ._llm_server_url is None )
113117
114- # x402 client and signer
def _init_x402_stack(self) -> None:
    """Create the x402 client, register the wallet signer on it, and build the HTTP client.

    Sets ``_x402_client`` and ``_http_client``; the HTTP client is created
    with the current ``_tls_verify`` value.
    """
    wallet_signer = EthAccountSigner(self._wallet_account)
    client = self._x402_client = x402Client()

    # Both registration helpers receive the same client, signer and network list.
    for register in (register_exact_evm_client, register_upto_evm_client):
        register(client, wallet_signer, networks=[BASE_TESTNET_NETWORK])

    # Per the earlier note in this file: x402HttpxClient is an
    # httpx.AsyncClient subclass, so construction is synchronous and
    # connections open lazily.
    self._http_client = x402HttpxClient(client, verify=self._tls_verify)
121125
126+ @staticmethod
127+ def _has_ssl_cause (exc : BaseException ) -> bool :
128+ """Return true when the exception chain contains an SSL error."""
129+ visited : set [int ] = set ()
130+ current : Optional [BaseException ] = exc
131+ while current is not None and id (current ) not in visited :
132+ visited .add (id (current ))
133+ if isinstance (current , ssl .SSLError ):
134+ return True
135+ current = current .__cause__ or current .__context__
136+ return False
137+
138+ async def _refresh_tee_and_reset (self ) -> None :
139+ """Re-resolve TEE and rebuild the HTTP client with fresh TLS config."""
140+ with self ._reset_lock :
141+ old_http_client = self ._http_client
142+ self ._refresh_tee_config ()
143+ self ._init_x402_stack ()
144+
145+ try :
146+ await old_http_client .aclose ()
147+ except Exception :
148+ logger .debug ("Failed to close previous HTTP client during TEE refresh." , exc_info = True )
149+
async def _call_with_ssl_retry(
    self,
    operation_name: str,
    call: Callable[[], Awaitable[T]],
) -> T:
    """Await ``call()``, retrying exactly once after an SSL-related failure.

    Args:
        operation_name: Label used in the warning log message.
        call: Zero-argument callable producing a fresh awaitable per attempt.

    Returns:
        The result of the first successful attempt.

    Raises:
        Exception: A non-SSL failure from the first attempt is re-raised
            unchanged; any failure of the refresh or of the second attempt
            propagates.
    """
    try:
        result = await call()
    except Exception as exc:
        if self._has_ssl_cause(exc):
            logger.warning(
                "SSL failure during %s; refreshing TEE and retrying once: %s",
                operation_name,
                exc,
            )
            # Refresh and retry inside the handler so the first failure stays
            # attached as context to anything raised from here on.
            await self._refresh_tee_and_reset()
            return await call()
        raise
    return result
168+
122169 # ── TEE resolution ──────────────────────────────────────────────────
123170
124171 @staticmethod
@@ -248,7 +295,6 @@ async def completion(
248295 RuntimeError: If the inference fails.
249296 """
250297 model_id = model .split ("/" )[1 ]
251- headers = self ._headers (x402_settlement_mode )
252298 payload : Dict = {
253299 "model" : model_id ,
254300 "prompt" : prompt ,
@@ -258,11 +304,11 @@ async def completion(
258304 if stop_sequence :
259305 payload ["stop" ] = stop_sequence
260306
261- try :
307+ async def _request () -> TextGenerationOutput :
262308 response = await self ._http_client .post (
263309 self ._tee_endpoint + _COMPLETION_ENDPOINT ,
264310 json = payload ,
265- headers = headers ,
311+ headers = self . _headers ( x402_settlement_mode ) ,
266312 timeout = _REQUEST_TIMEOUT ,
267313 )
268314 response .raise_for_status ()
@@ -275,6 +321,9 @@ async def completion(
275321 tee_timestamp = result .get ("tee_timestamp" ),
276322 ** self ._tee_metadata (),
277323 )
324+
325+ try :
326+ return await self ._call_with_ssl_retry ("completion" , _request )
278327 except RuntimeError :
279328 raise
280329 except Exception as e :
@@ -342,14 +391,13 @@ async def chat(
342391
343392 async def _chat_request (self , params : _ChatParams , messages : List [Dict ]) -> TextGenerationOutput :
344393 """Non-streaming chat request."""
345- headers = self ._headers (params .x402_settlement_mode )
346394 payload = self ._chat_payload (params , messages )
347395
348- try :
396+ async def _request () -> TextGenerationOutput :
349397 response = await self ._http_client .post (
350398 self ._tee_endpoint + _CHAT_ENDPOINT ,
351399 json = payload ,
352- headers = headers ,
400+ headers = self . _headers ( params . x402_settlement_mode ) ,
353401 timeout = _REQUEST_TIMEOUT ,
354402 )
355403 response .raise_for_status ()
@@ -375,6 +423,9 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text
375423 tee_timestamp = result .get ("tee_timestamp" ),
376424 ** self ._tee_metadata (),
377425 )
426+
427+ try :
428+ return await self ._call_with_ssl_retry ("chat" , _request )
378429 except RuntimeError :
379430 raise
380431 except Exception as e :
@@ -410,6 +461,29 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async
410461 headers = self ._headers (params .x402_settlement_mode )
411462 payload = self ._chat_payload (params , messages , stream = True )
412463
464+ chunks_yielded = False
465+ try :
466+ async with self ._http_client .stream (
467+ "POST" ,
468+ self ._tee_endpoint + _CHAT_ENDPOINT ,
469+ json = payload ,
470+ headers = headers ,
471+ timeout = _REQUEST_TIMEOUT ,
472+ ) as response :
473+ async for chunk in self ._parse_sse_response (response ):
474+ chunks_yielded = True
475+ yield chunk
476+ return
477+ except Exception as exc :
478+ if chunks_yielded or not self ._has_ssl_cause (exc ):
479+ raise
480+ logger .warning (
481+ "SSL failure during stream setup; refreshing TEE and retrying once: %s" ,
482+ exc ,
483+ )
484+
485+ await self ._refresh_tee_and_reset ()
486+ headers = self ._headers (params .x402_settlement_mode )
413487 async with self ._http_client .stream (
414488 "POST" ,
415489 self ._tee_endpoint + _CHAT_ENDPOINT ,
0 commit comments