Skip to content

Commit eb09eb5

Browse files
author
balogh.adam@icloud.com
committed
use frozen
1 parent c7501cd commit eb09eb5

5 files changed

Lines changed: 253 additions & 152 deletions

File tree

src/opengradient/client/llm.py

Lines changed: 40 additions & 112 deletions
Original file line number | Diff line number | Diff line change
@@ -1,25 +1,22 @@
11
"""LLM chat and completion via TEE-verified execution with x402 payments."""
22

3-
import asyncio
43
import json
54
import logging
6-
import ssl
7-
import time
85
from dataclasses import dataclass
96
from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union
107
import httpx
118

129
from eth_account import Account
1310
from eth_account.account import LocalAccount
1411
from x402 import x402Client
15-
from x402.http.clients import x402HttpxClient
1612
from x402.mechanisms.evm import EthAccountSigner
1713
from x402.mechanisms.evm.exact.register import register_exact_evm_client
1814
from x402.mechanisms.evm.upto.register import register_upto_evm_client
1915

2016
from ..types import TEE_LLM, StreamChoice, StreamChunk, StreamDelta, TextGenerationOutput, x402SettlementMode
2117
from .opg_token import Permit2ApprovalResult, ensure_opg_approval
22-
from .tee_registry import TEERegistry, build_ssl_context_from_der
18+
from .tee_connection import TEEConnection
19+
from .tee_registry import TEERegistry
2320

2421
logger = logging.getLogger(__name__)
2522
T = TypeVar("T")
@@ -34,10 +31,9 @@
3431
_CHAT_ENDPOINT = "/v1/chat/completions"
3532
_COMPLETION_ENDPOINT = "/v1/completions"
3633
_REQUEST_TIMEOUT = 60
37-
_TEE_REFRESH_INTERVAL = 300 # Re-resolve TEE from registry every 5 minutes
3834

3935

40-
@dataclass
36+
@dataclass(frozen=True)
4137
class _ChatParams:
4238
"""Bundles the common parameters for chat/completion requests."""
4339

@@ -99,96 +95,30 @@ def __init__(
9995
llm_server_url: Optional[str] = None,
10096
):
10197
self._wallet_account: LocalAccount = Account.from_key(private_key)
102-
self._rpc_url = rpc_url
103-
self._tee_registry_address = tee_registry_address
104-
self._llm_server_url = llm_server_url
10598

10699
# x402 payment stack (created once, reused across TEE refreshes)
107100
signer = EthAccountSigner(self._wallet_account)
108-
self._x402_client = x402Client()
109-
register_exact_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK])
110-
register_upto_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK])
111-
112-
self._connect_tee()
113-
self._tee_refreshed_at: float = time.monotonic()
114-
self._refresh_lock = asyncio.Lock()
115-
116-
# ── TEE resolution and connection ───────────────────────────────────────────
117-
118-
def _connect_tee(self) -> None:
119-
"""Resolve TEE from registry and create a secure HTTP client for it."""
120-
endpoint, tls_cert_der, tee_id, tee_payment_address = self._resolve_tee(
121-
self._llm_server_url,
122-
self._rpc_url,
123-
self._tee_registry_address,
101+
x402_client = x402Client()
102+
register_exact_evm_client(x402_client, signer, networks=[BASE_TESTNET_NETWORK])
103+
register_upto_evm_client(x402_client, signer, networks=[BASE_TESTNET_NETWORK])
104+
105+
registry: Optional[TEERegistry] = (
106+
TEERegistry(rpc_url=rpc_url, registry_address=tee_registry_address)
107+
if llm_server_url is None
108+
else None
124109
)
125-
self._tee_id = tee_id
126-
self._tee_endpoint = endpoint
127-
self._tee_payment_address = tee_payment_address
128-
129-
ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None
130-
self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else (self._llm_server_url is None)
131-
self._http_client = x402HttpxClient(self._x402_client, verify=self._tls_verify)
132-
133-
async def _refresh_tee(self) -> None:
134-
"""Re-resolve TEE from the registry and rebuild the HTTP client."""
135-
async with self._refresh_lock:
136-
old_http_client = self._http_client
137-
self._connect_tee()
138-
self._tee_refreshed_at = time.monotonic()
139-
try:
140-
await old_http_client.aclose()
141-
except Exception:
142-
logger.debug("Failed to close previous HTTP client during TEE refresh.", exc_info=True)
143-
144-
async def _maybe_refresh_tee(self) -> None:
145-
"""Re-resolve TEE if the current one is older than ``_TEE_REFRESH_INTERVAL``.
146-
147-
Skips the refresh for explicit ``llm_server_url`` overrides since they
148-
bypass the registry entirely.
149-
"""
150-
if self._llm_server_url is not None:
151-
return
152-
if time.monotonic() - self._tee_refreshed_at < _TEE_REFRESH_INTERVAL:
153-
return
154-
logger.debug("TEE endpoint stale (>%ds); refreshing from registry.", _TEE_REFRESH_INTERVAL)
155-
await self._refresh_tee()
156-
157-
158-
@staticmethod
159-
def _resolve_tee(
160-
tee_endpoint_override: Optional[str],
161-
og_rpc_url: Optional[str],
162-
tee_registry_address: Optional[str],
163-
) -> tuple:
164-
"""Resolve TEE endpoint and metadata from the on-chain registry or explicit URL.
165-
166-
Returns:
167-
(endpoint, tls_cert_der, tee_id, payment_address)
168-
"""
169-
if tee_endpoint_override is not None:
170-
return tee_endpoint_override, None, None, None
171110

172-
if og_rpc_url is None or tee_registry_address is None:
173-
raise ValueError("Either llm_server_url or both rpc_url and tee_registry_address must be provided.")
174-
175-
try:
176-
registry = TEERegistry(rpc_url=og_rpc_url, registry_address=tee_registry_address)
177-
tee = registry.get_llm_tee()
178-
except Exception as e:
179-
raise RuntimeError(f"Failed to fetch LLM TEE endpoint from registry ({tee_registry_address} on {og_rpc_url}): {e}. ") from e
180-
181-
if tee is None:
182-
raise ValueError("No active LLM proxy TEE found in the registry. Pass llm_server_url explicitly to override.")
183-
184-
logger.info("Using TEE endpoint from registry: %s (teeId=%s)", tee.endpoint, tee.tee_id)
185-
return tee.endpoint, tee.tls_cert_der, tee.tee_id, tee.payment_address
111+
self._tee = TEEConnection(
112+
x402_client=x402_client,
113+
registry=registry,
114+
llm_server_url=llm_server_url,
115+
)
186116

187117
# ── Lifecycle ───────────────────────────────────────────────────────
188118

189119
async def close(self) -> None:
190-
"""Close the underlying HTTP client."""
191-
await self._http_client.aclose()
120+
"""Cancel the background refresh loop and close the HTTP client."""
121+
await self._tee.close()
192122

193123
# ── Request helpers ─────────────────────────────────────────────────
194124

@@ -215,13 +145,6 @@ def _chat_payload(self, params: _ChatParams, messages: List[Dict], stream: bool
215145
payload["tool_choice"] = params.tool_choice or "auto"
216146
return payload
217147

218-
def _tee_metadata(self) -> Dict:
219-
return dict(
220-
tee_id=self._tee_id,
221-
tee_endpoint=self._tee_endpoint,
222-
tee_payment_address=self._tee_payment_address,
223-
)
224-
225148
async def _call_with_tee_retry(
226149
self,
227150
operation_name: str,
@@ -232,7 +155,7 @@ async def _call_with_tee_retry(
232155
Only retries when the request never reached the server (no HTTP response).
233156
Server-side errors (4xx/5xx) are not retried.
234157
"""
235-
await self._maybe_refresh_tee()
158+
self._tee.ensure_refresh_loop()
236159
try:
237160
return await call()
238161
except httpx.HTTPStatusError:
@@ -243,7 +166,7 @@ async def _call_with_tee_retry(
243166
operation_name,
244167
exc,
245168
)
246-
await self._refresh_tee()
169+
await self._tee.reconnect()
247170
return await call()
248171

249172
# ── Public API ──────────────────────────────────────────────────────
@@ -316,8 +239,9 @@ async def completion(
316239
payload["stop"] = stop_sequence
317240

318241
async def _request() -> TextGenerationOutput:
319-
response = await self._http_client.post(
320-
self._tee_endpoint + _COMPLETION_ENDPOINT,
242+
tee = self._tee.get()
243+
response = await tee.http_client.post(
244+
tee.endpoint + _COMPLETION_ENDPOINT,
321245
json=payload,
322246
headers=self._headers(x402_settlement_mode),
323247
timeout=_REQUEST_TIMEOUT,
@@ -330,7 +254,7 @@ async def _request() -> TextGenerationOutput:
330254
completion_output=result.get("completion"),
331255
tee_signature=result.get("tee_signature"),
332256
tee_timestamp=result.get("tee_timestamp"),
333-
**self._tee_metadata(),
257+
**tee.metadata(),
334258
)
335259

336260
try:
@@ -405,8 +329,9 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text
405329
payload = self._chat_payload(params, messages)
406330

407331
async def _request() -> TextGenerationOutput:
408-
response = await self._http_client.post(
409-
self._tee_endpoint + _CHAT_ENDPOINT,
332+
tee = self._tee.get()
333+
response = await tee.http_client.post(
334+
tee.endpoint + _CHAT_ENDPOINT,
410335
json=payload,
411336
headers=self._headers(params.x402_settlement_mode),
412337
timeout=_REQUEST_TIMEOUT,
@@ -432,7 +357,7 @@ async def _request() -> TextGenerationOutput:
432357
chat_output=message,
433358
tee_signature=result.get("tee_signature"),
434359
tee_timestamp=result.get("tee_timestamp"),
435-
**self._tee_metadata(),
360+
**tee.metadata(),
436361
)
437362

438363
try:
@@ -469,15 +394,16 @@ async def _chat_tools_as_stream(self, params: _ChatParams, messages: List[Dict])
469394

470395
async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> AsyncGenerator[StreamChunk, None]:
471396
"""Async SSE streaming implementation."""
472-
await self._maybe_refresh_tee()
397+
self._tee.ensure_refresh_loop()
473398
headers = self._headers(params.x402_settlement_mode)
474399
payload = self._chat_payload(params, messages, stream=True)
475400

476401
chunks_yielded = False
477402
try:
478-
async with self._http_client.stream(
403+
tee = self._tee.get()
404+
async with tee.http_client.stream(
479405
"POST",
480-
self._tee_endpoint + _CHAT_ENDPOINT,
406+
tee.endpoint + _CHAT_ENDPOINT,
481407
json=payload,
482408
headers=headers,
483409
timeout=_REQUEST_TIMEOUT,
@@ -496,11 +422,12 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async
496422
exc,
497423
)
498424

499-
await self._refresh_tee()
425+
await self._tee.reconnect()
426+
tee = self._tee.get()
500427
headers = self._headers(params.x402_settlement_mode)
501-
async with self._http_client.stream(
428+
async with tee.http_client.stream(
502429
"POST",
503-
self._tee_endpoint + _CHAT_ENDPOINT,
430+
tee.endpoint + _CHAT_ENDPOINT,
504431
json=payload,
505432
headers=headers,
506433
timeout=_REQUEST_TIMEOUT,
@@ -546,7 +473,8 @@ async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, Non
546473

547474
chunk = StreamChunk.from_sse_data(data)
548475
if chunk.is_final:
549-
chunk.tee_id = self._tee_id
550-
chunk.tee_endpoint = self._tee_endpoint
551-
chunk.tee_payment_address = self._tee_payment_address
476+
tee = self._tee.get()
477+
chunk.tee_id = tee.tee_id
478+
chunk.tee_endpoint = tee.endpoint
479+
chunk.tee_payment_address = tee.payment_address
552480
yield chunk

0 commit comments

Comments
 (0)