11"""LLM chat and completion via TEE-verified execution with x402 payments."""
22
3- import asyncio
43import json
54import logging
6- import ssl
7- import time
85from dataclasses import dataclass
96from typing import AsyncGenerator , Awaitable , Callable , Dict , List , Optional , TypeVar , Union
107import httpx
118
129from eth_account import Account
1310from eth_account .account import LocalAccount
1411from x402 import x402Client
15- from x402 .http .clients import x402HttpxClient
1612from x402 .mechanisms .evm import EthAccountSigner
1713from x402 .mechanisms .evm .exact .register import register_exact_evm_client
1814from x402 .mechanisms .evm .upto .register import register_upto_evm_client
1915
2016from ..types import TEE_LLM , StreamChoice , StreamChunk , StreamDelta , TextGenerationOutput , x402SettlementMode
2117from .opg_token import Permit2ApprovalResult , ensure_opg_approval
22- from .tee_registry import TEERegistry , build_ssl_context_from_der
18+ from .tee_connection import TEEConnection
19+ from .tee_registry import TEERegistry
2320
2421logger = logging .getLogger (__name__ )
2522T = TypeVar ("T" )
3431_CHAT_ENDPOINT = "/v1/chat/completions"
3532_COMPLETION_ENDPOINT = "/v1/completions"
3633_REQUEST_TIMEOUT = 60
37- _TEE_REFRESH_INTERVAL = 300 # Re-resolve TEE from registry every 5 minutes
3834
3935
40- @dataclass
36+ @dataclass ( frozen = True )
4137class _ChatParams :
4238 """Bundles the common parameters for chat/completion requests."""
4339
@@ -99,96 +95,30 @@ def __init__(
9995 llm_server_url : Optional [str ] = None ,
10096 ):
10197 self ._wallet_account : LocalAccount = Account .from_key (private_key )
102- self ._rpc_url = rpc_url
103- self ._tee_registry_address = tee_registry_address
104- self ._llm_server_url = llm_server_url
10598
10699 # x402 payment stack (created once, reused across TEE refreshes)
107100 signer = EthAccountSigner (self ._wallet_account )
108- self ._x402_client = x402Client ()
109- register_exact_evm_client (self ._x402_client , signer , networks = [BASE_TESTNET_NETWORK ])
110- register_upto_evm_client (self ._x402_client , signer , networks = [BASE_TESTNET_NETWORK ])
111-
112- self ._connect_tee ()
113- self ._tee_refreshed_at : float = time .monotonic ()
114- self ._refresh_lock = asyncio .Lock ()
115-
116- # ── TEE resolution and connection ───────────────────────────────────────────
117-
118- def _connect_tee (self ) -> None :
119- """Resolve TEE from registry and create a secure HTTP client for it."""
120- endpoint , tls_cert_der , tee_id , tee_payment_address = self ._resolve_tee (
121- self ._llm_server_url ,
122- self ._rpc_url ,
123- self ._tee_registry_address ,
101+ x402_client = x402Client ()
102+ register_exact_evm_client (x402_client , signer , networks = [BASE_TESTNET_NETWORK ])
103+ register_upto_evm_client (x402_client , signer , networks = [BASE_TESTNET_NETWORK ])
104+
105+ registry : Optional [TEERegistry ] = (
106+ TEERegistry (rpc_url = rpc_url , registry_address = tee_registry_address )
107+ if llm_server_url is None
108+ else None
124109 )
125- self ._tee_id = tee_id
126- self ._tee_endpoint = endpoint
127- self ._tee_payment_address = tee_payment_address
128-
129- ssl_ctx = build_ssl_context_from_der (tls_cert_der ) if tls_cert_der else None
130- self ._tls_verify : Union [ssl .SSLContext , bool ] = ssl_ctx if ssl_ctx else (self ._llm_server_url is None )
131- self ._http_client = x402HttpxClient (self ._x402_client , verify = self ._tls_verify )
132-
133- async def _refresh_tee (self ) -> None :
134- """Re-resolve TEE from the registry and rebuild the HTTP client."""
135- async with self ._refresh_lock :
136- old_http_client = self ._http_client
137- self ._connect_tee ()
138- self ._tee_refreshed_at = time .monotonic ()
139- try :
140- await old_http_client .aclose ()
141- except Exception :
142- logger .debug ("Failed to close previous HTTP client during TEE refresh." , exc_info = True )
143-
144- async def _maybe_refresh_tee (self ) -> None :
145- """Re-resolve TEE if the current one is older than ``_TEE_REFRESH_INTERVAL``.
146-
147- Skips the refresh for explicit ``llm_server_url`` overrides since they
148- bypass the registry entirely.
149- """
150- if self ._llm_server_url is not None :
151- return
152- if time .monotonic () - self ._tee_refreshed_at < _TEE_REFRESH_INTERVAL :
153- return
154- logger .debug ("TEE endpoint stale (>%ds); refreshing from registry." , _TEE_REFRESH_INTERVAL )
155- await self ._refresh_tee ()
156-
157-
158- @staticmethod
159- def _resolve_tee (
160- tee_endpoint_override : Optional [str ],
161- og_rpc_url : Optional [str ],
162- tee_registry_address : Optional [str ],
163- ) -> tuple :
164- """Resolve TEE endpoint and metadata from the on-chain registry or explicit URL.
165-
166- Returns:
167- (endpoint, tls_cert_der, tee_id, payment_address)
168- """
169- if tee_endpoint_override is not None :
170- return tee_endpoint_override , None , None , None
171110
172- if og_rpc_url is None or tee_registry_address is None :
173- raise ValueError ("Either llm_server_url or both rpc_url and tee_registry_address must be provided." )
174-
175- try :
176- registry = TEERegistry (rpc_url = og_rpc_url , registry_address = tee_registry_address )
177- tee = registry .get_llm_tee ()
178- except Exception as e :
179- raise RuntimeError (f"Failed to fetch LLM TEE endpoint from registry ({ tee_registry_address } on { og_rpc_url } ): { e } . " ) from e
180-
181- if tee is None :
182- raise ValueError ("No active LLM proxy TEE found in the registry. Pass llm_server_url explicitly to override." )
183-
184- logger .info ("Using TEE endpoint from registry: %s (teeId=%s)" , tee .endpoint , tee .tee_id )
185- return tee .endpoint , tee .tls_cert_der , tee .tee_id , tee .payment_address
111+ self ._tee = TEEConnection (
112+ x402_client = x402_client ,
113+ registry = registry ,
114+ llm_server_url = llm_server_url ,
115+ )
186116
187117 # ── Lifecycle ───────────────────────────────────────────────────────
188118
189119 async def close (self ) -> None :
190- """Close the underlying HTTP client."""
191- await self ._http_client . aclose ()
120+ """Cancel the background refresh loop and close the HTTP client."""
121+ await self ._tee . close ()
192122
193123 # ── Request helpers ─────────────────────────────────────────────────
194124
@@ -215,13 +145,6 @@ def _chat_payload(self, params: _ChatParams, messages: List[Dict], stream: bool
215145 payload ["tool_choice" ] = params .tool_choice or "auto"
216146 return payload
217147
218- def _tee_metadata (self ) -> Dict :
219- return dict (
220- tee_id = self ._tee_id ,
221- tee_endpoint = self ._tee_endpoint ,
222- tee_payment_address = self ._tee_payment_address ,
223- )
224-
225148 async def _call_with_tee_retry (
226149 self ,
227150 operation_name : str ,
@@ -232,7 +155,7 @@ async def _call_with_tee_retry(
232155 Only retries when the request never reached the server (no HTTP response).
233156 Server-side errors (4xx/5xx) are not retried.
234157 """
235- await self ._maybe_refresh_tee ()
158+ self ._tee . ensure_refresh_loop ()
236159 try :
237160 return await call ()
238161 except httpx .HTTPStatusError :
@@ -243,7 +166,7 @@ async def _call_with_tee_retry(
243166 operation_name ,
244167 exc ,
245168 )
246- await self ._refresh_tee ()
169+ await self ._tee . reconnect ()
247170 return await call ()
248171
249172 # ── Public API ──────────────────────────────────────────────────────
@@ -316,8 +239,9 @@ async def completion(
316239 payload ["stop" ] = stop_sequence
317240
318241 async def _request () -> TextGenerationOutput :
319- response = await self ._http_client .post (
320- self ._tee_endpoint + _COMPLETION_ENDPOINT ,
242+ tee = self ._tee .get ()
243+ response = await tee .http_client .post (
244+ tee .endpoint + _COMPLETION_ENDPOINT ,
321245 json = payload ,
322246 headers = self ._headers (x402_settlement_mode ),
323247 timeout = _REQUEST_TIMEOUT ,
@@ -330,7 +254,7 @@ async def _request() -> TextGenerationOutput:
330254 completion_output = result .get ("completion" ),
331255 tee_signature = result .get ("tee_signature" ),
332256 tee_timestamp = result .get ("tee_timestamp" ),
333- ** self . _tee_metadata (),
257+ ** tee . metadata (),
334258 )
335259
336260 try :
@@ -405,8 +329,9 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text
405329 payload = self ._chat_payload (params , messages )
406330
407331 async def _request () -> TextGenerationOutput :
408- response = await self ._http_client .post (
409- self ._tee_endpoint + _CHAT_ENDPOINT ,
332+ tee = self ._tee .get ()
333+ response = await tee .http_client .post (
334+ tee .endpoint + _CHAT_ENDPOINT ,
410335 json = payload ,
411336 headers = self ._headers (params .x402_settlement_mode ),
412337 timeout = _REQUEST_TIMEOUT ,
@@ -432,7 +357,7 @@ async def _request() -> TextGenerationOutput:
432357 chat_output = message ,
433358 tee_signature = result .get ("tee_signature" ),
434359 tee_timestamp = result .get ("tee_timestamp" ),
435- ** self . _tee_metadata (),
360+ ** tee . metadata (),
436361 )
437362
438363 try :
@@ -469,15 +394,16 @@ async def _chat_tools_as_stream(self, params: _ChatParams, messages: List[Dict])
469394
470395 async def _chat_stream (self , params : _ChatParams , messages : List [Dict ]) -> AsyncGenerator [StreamChunk , None ]:
471396 """Async SSE streaming implementation."""
472- await self ._maybe_refresh_tee ()
397+ self ._tee . ensure_refresh_loop ()
473398 headers = self ._headers (params .x402_settlement_mode )
474399 payload = self ._chat_payload (params , messages , stream = True )
475400
476401 chunks_yielded = False
477402 try :
478- async with self ._http_client .stream (
403+ tee = self ._tee .get ()
404+ async with tee .http_client .stream (
479405 "POST" ,
480- self . _tee_endpoint + _CHAT_ENDPOINT ,
406+ tee . endpoint + _CHAT_ENDPOINT ,
481407 json = payload ,
482408 headers = headers ,
483409 timeout = _REQUEST_TIMEOUT ,
@@ -496,11 +422,12 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async
496422 exc ,
497423 )
498424
499- await self ._refresh_tee ()
425+ await self ._tee .reconnect ()
426+ tee = self ._tee .get ()
500427 headers = self ._headers (params .x402_settlement_mode )
501- async with self . _http_client .stream (
428+ async with tee . http_client .stream (
502429 "POST" ,
503- self . _tee_endpoint + _CHAT_ENDPOINT ,
430+ tee . endpoint + _CHAT_ENDPOINT ,
504431 json = payload ,
505432 headers = headers ,
506433 timeout = _REQUEST_TIMEOUT ,
@@ -546,7 +473,8 @@ async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, Non
546473
547474 chunk = StreamChunk .from_sse_data (data )
548475 if chunk .is_final :
549- chunk .tee_id = self ._tee_id
550- chunk .tee_endpoint = self ._tee_endpoint
551- chunk .tee_payment_address = self ._tee_payment_address
476+ tee = self ._tee .get ()
477+ chunk .tee_id = tee .tee_id
478+ chunk .tee_endpoint = tee .endpoint
479+ chunk .tee_payment_address = tee .payment_address
552480 yield chunk
0 commit comments