Skip to content

Commit 1084388

Browse files
author
balogh.adam@icloud.com
committed
refactor
1 parent 1df5d4c commit 1084388

5 files changed

Lines changed: 348 additions & 290 deletions

File tree

src/opengradient/client/llm.py

Lines changed: 42 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from ..types import TEE_LLM, StreamChoice, StreamChunk, StreamDelta, TextGenerationOutput, x402SettlementMode
1717
from .opg_token import Permit2ApprovalResult, ensure_opg_approval
18-
from .tee_connection import TEEConnection
18+
from .tee_connection import RegistryTEEConnection, StaticTEEConnection, TEEConnectionInterface
1919
from .tee_registry import TEERegistry
2020

2121
logger = logging.getLogger(__name__)
@@ -62,62 +62,67 @@ class LLM:
6262
below the requested amount.
6363
6464
Usage:
65+
# Via on-chain registry (default)
6566
llm = og.LLM(private_key="0x...")
6667
68+
# Via hardcoded URL (development / self-hosted)
69+
llm = og.LLM.from_url(private_key="0x...", llm_server_url="https://1.2.3.4")
70+
6771
# One-time approval (idempotent — skips if allowance is already sufficient)
6872
llm.ensure_opg_approval(opg_amount=5)
6973
7074
result = await llm.chat(model=TEE_LLM.CLAUDE_HAIKU_4_5, messages=[...])
7175
result = await llm.completion(model=TEE_LLM.CLAUDE_HAIKU_4_5, prompt="Hello")
72-
73-
Args:
74-
private_key (str): Ethereum private key for signing x402 payments.
75-
rpc_url (str): RPC URL for the OpenGradient network. Used to fetch the
76-
active TEE endpoint from the on-chain registry when ``llm_server_url``
77-
is not provided.
78-
tee_registry_address (str): Address of the on-chain TEE registry contract.
79-
llm_server_url (str, optional): Bypass the registry and connect directly
80-
to this TEE endpoint URL (e.g. ``"https://1.2.3.4"``). When set,
81-
TLS certificate verification is disabled automatically because
82-
self-hosted TEE servers typically use self-signed certificates.
83-
84-
.. warning::
85-
Using ``llm_server_url`` disables TLS certificate verification,
86-
which removes protection against man-in-the-middle attacks.
87-
Only connect to servers you trust and over secure network paths.
8876
"""
8977

9078
def __init__(
9179
self,
9280
private_key: str,
9381
rpc_url: str = DEFAULT_RPC_URL,
9482
tee_registry_address: str = DEFAULT_TEE_REGISTRY_ADDRESS,
95-
llm_server_url: Optional[str] = None,
9683
):
9784
if not private_key:
98-
raise ValueError(
99-
"A private key is required to use the LLM client. "
100-
"Pass a valid private_key to the constructor."
101-
)
85+
raise ValueError("A private key is required to use the LLM client. Pass a valid private_key to the constructor.")
10286
self._wallet_account: LocalAccount = Account.from_key(private_key)
10387

104-
# x402 payment stack (created once, reused across TEE refreshes)
105-
signer = EthAccountSigner(self._wallet_account)
106-
x402_client = x402Client()
107-
register_exact_evm_client(x402_client, signer, networks=[BASE_TESTNET_NETWORK])
108-
register_upto_evm_client(x402_client, signer, networks=[BASE_TESTNET_NETWORK])
88+
x402_client = LLM._build_x402_client(private_key)
89+
onchain_registry = TEERegistry(rpc_url=rpc_url, registry_address=tee_registry_address)
90+
self._tee: TEEConnectionInterface = RegistryTEEConnection(x402_client=x402_client, registry=onchain_registry)
10991

110-
registry: Optional[TEERegistry] = (
111-
TEERegistry(rpc_url=rpc_url, registry_address=tee_registry_address)
112-
if llm_server_url is None
113-
else None
114-
)
92+
@classmethod
93+
def from_url(
94+
cls,
95+
private_key: str,
96+
llm_server_url: str,
97+
) -> "LLM":
98+
"""**[Dev]** Create an LLM client with a hardcoded TEE endpoint URL.
11599
116-
self._tee = TEEConnection(
117-
x402_client=x402_client,
118-
registry=registry,
119-
llm_server_url=llm_server_url,
120-
)
100+
Intended for development and self-hosted TEE servers. TLS certificate
101+
verification is disabled because these servers typically use self-signed
102+
certificates. For production use, prefer the default constructor which
103+
resolves TEEs from the on-chain registry.
104+
105+
Args:
106+
private_key: Ethereum private key for signing x402 payments.
107+
llm_server_url: The TEE endpoint URL (e.g. ``"https://1.2.3.4"``).
108+
"""
109+
instance = cls.__new__(cls)
110+
if not private_key:
111+
raise ValueError("A private key is required to use the LLM client. Pass a valid private_key to the constructor.")
112+
instance._wallet_account = Account.from_key(private_key)
113+
x402_client = cls._build_x402_client(private_key)
114+
instance._tee = StaticTEEConnection(x402_client=x402_client, endpoint=llm_server_url)
115+
return instance
116+
117+
@staticmethod
118+
def _build_x402_client(private_key: str) -> x402Client:
119+
"""Build the x402 payment stack from a private key."""
120+
account = Account.from_key(private_key)
121+
signer = EthAccountSigner(account)
122+
client = x402Client()
123+
register_exact_evm_client(client, signer, networks=[BASE_TESTNET_NETWORK])
124+
register_upto_evm_client(client, signer, networks=[BASE_TESTNET_NETWORK])
125+
return client
121126

122127
# ── Lifecycle ───────────────────────────────────────────────────────
123128

src/opengradient/client/tee_connection.py

Lines changed: 93 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import ssl
66
from dataclasses import dataclass
7-
from typing import Dict, Optional, Union
7+
from typing import Dict, Optional, Protocol, Union
88

99
from x402 import x402Client
1010
from x402.http.clients import x402HttpxClient
@@ -34,36 +34,80 @@ def metadata(self) -> Dict:
3434
)
3535

3636

37-
class TEEConnection:
38-
"""Maintains a verified connection to a single TEE endpoint.
37+
class TEEConnectionInterface(Protocol):
38+
"""Interface for TEE connection implementations."""
3939

40-
Handles initial resolution from the on-chain registry (or an explicit URL),
41-
TLS certificate pinning, background health checks, and automatic failover
42-
when the current TEE becomes unavailable.
40+
def get(self) -> ActiveTEE: ...
41+
def ensure_refresh_loop(self) -> None: ...
42+
async def reconnect(self) -> None: ...
43+
async def close(self) -> None: ...
4344

44-
Use ``get()`` to obtain the current ``ActiveTEE`` snapshot for making requests.
45+
46+
class StaticTEEConnection:
47+
"""TEE connection with a hardcoded endpoint URL.
48+
49+
No registry lookup, no background refresh. TLS certificate verification
50+
is disabled because self-hosted TEE servers typically use self-signed certs.
4551
4652
Args:
4753
x402_client: Configured x402 payment client for creating HTTP clients.
48-
registry: TEERegistry for looking up active TEEs. None when using an explicit URL.
49-
llm_server_url: Bypass the registry and connect directly to this URL.
54+
endpoint: The TEE endpoint URL to connect to.
5055
"""
5156

52-
def __init__(
53-
self,
54-
x402_client: x402Client,
55-
registry: Optional[TEERegistry] = None,
56-
llm_server_url: Optional[str] = None,
57-
):
57+
def __init__(self, x402_client: x402Client, endpoint: str):
58+
self._x402_client = x402_client
59+
self._endpoint = endpoint
60+
self._active: ActiveTEE = self._connect()
61+
62+
def get(self) -> ActiveTEE:
63+
"""Return a snapshot of the current TEE connection."""
64+
return self._active
65+
66+
def _connect(self) -> ActiveTEE:
67+
return ActiveTEE(
68+
endpoint=self._endpoint,
69+
http_client=x402HttpxClient(self._x402_client, verify=False),
70+
tee_id=None,
71+
payment_address=None,
72+
)
73+
74+
def ensure_refresh_loop(self) -> None:
75+
"""No-op — static connections don't refresh."""
76+
pass
77+
78+
async def reconnect(self) -> None:
79+
"""Rebuild the HTTP client (same endpoint)."""
80+
old_client = self._active.http_client
81+
self._active = self._connect()
82+
try:
83+
await old_client.aclose()
84+
except Exception:
85+
logger.debug("Failed to close previous HTTP client during reconnect.", exc_info=True)
86+
87+
async def close(self) -> None:
88+
"""Close the HTTP client."""
89+
await self._active.http_client.aclose()
90+
91+
92+
class RegistryTEEConnection:
93+
"""TEE connection resolved from the on-chain registry.
94+
95+
Handles TLS certificate pinning, background health checks, and automatic
96+
failover when the current TEE becomes unavailable.
97+
98+
Args:
99+
x402_client: Configured x402 payment client for creating HTTP clients.
100+
registry: TEERegistry for looking up active TEEs.
101+
"""
102+
103+
def __init__(self, x402_client: x402Client, registry: TEERegistry):
58104
self._x402_client = x402_client
59105
self._registry = registry
60-
self._llm_server_url = llm_server_url
61106

62-
self._active: Optional[ActiveTEE] = None
63107
self._refresh_lock = asyncio.Lock()
64108
self._refresh_task: Optional[asyncio.Task] = None
65109

66-
self._connect()
110+
self._active: ActiveTEE = self._connect()
67111

68112
# ── Public API ──────────────────────────────────────────────────────
69113

@@ -73,28 +117,46 @@ def get(self) -> ActiveTEE:
73117

74118
# ── Connection management ───────────────────────────────────────────
75119

76-
def _connect(self) -> None:
120+
def _resolve_tee(self):
121+
"""Resolve TEE endpoint and metadata from the on-chain registry.
122+
123+
Returns:
124+
The TEE object from the registry.
125+
126+
Raises:
127+
RuntimeError: If the registry lookup fails.
128+
ValueError: If no active LLM proxy TEE is found.
129+
"""
130+
try:
131+
tee = self._registry.get_llm_tee()
132+
except Exception as e:
133+
raise RuntimeError(f"Failed to fetch LLM TEE endpoint from registry: {e}") from e
134+
135+
if tee is None:
136+
raise ValueError("No active LLM proxy TEE found in the registry. Pass llm_server_url explicitly to override.")
137+
138+
logger.info("Using TEE endpoint from registry: %s (teeId=%s)", tee.endpoint, tee.tee_id)
139+
return tee
140+
141+
def _connect(self) -> ActiveTEE:
77142
"""Resolve TEE from registry and create a secure HTTP client."""
78-
endpoint, tls_cert_der, tee_id, payment_address = self._resolve_tee(
79-
self._llm_server_url,
80-
self._registry,
81-
)
143+
tee = self._resolve_tee()
82144

83-
ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None
84-
tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else (self._llm_server_url is None)
145+
ssl_ctx = build_ssl_context_from_der(tee.tls_cert_der) if tee.tls_cert_der else None
146+
tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else True
85147

86-
self._active = ActiveTEE(
87-
endpoint=endpoint,
148+
return ActiveTEE(
149+
endpoint=tee.endpoint,
88150
http_client=x402HttpxClient(self._x402_client, verify=tls_verify),
89-
tee_id=tee_id,
90-
payment_address=payment_address,
151+
tee_id=tee.tee_id,
152+
payment_address=tee.payment_address,
91153
)
92154

93155
async def reconnect(self) -> None:
94156
"""Connect to a new TEE from the registry and rebuild the HTTP client."""
95157
async with self._refresh_lock:
96158
old_client = self._active.http_client
97-
self._connect()
159+
self._active = self._connect()
98160
try:
99161
await old_client.aclose()
100162
except Exception:
@@ -105,11 +167,8 @@ async def reconnect(self) -> None:
105167
def ensure_refresh_loop(self) -> None:
106168
"""Start the background TEE refresh loop if not already running.
107169
108-
No-op when ``llm_server_url`` is set (bypasses the registry).
109170
Called lazily from async request methods since ``__init__`` is synchronous.
110171
"""
111-
if self._llm_server_url is not None:
112-
return
113172
if self._refresh_task is not None and not self._refresh_task.done():
114173
return
115174
self._refresh_task = asyncio.create_task(self._tee_refresh_loop())
@@ -139,34 +198,4 @@ async def close(self) -> None:
139198
if self._refresh_task is not None:
140199
self._refresh_task.cancel()
141200
self._refresh_task = None
142-
if self._active is not None:
143-
await self._active.http_client.aclose()
144-
145-
# ── Static helpers ──────────────────────────────────────────────────
146-
147-
@staticmethod
148-
def _resolve_tee(
149-
tee_endpoint_override: Optional[str],
150-
registry: Optional[TEERegistry],
151-
) -> tuple:
152-
"""Resolve TEE endpoint and metadata from the on-chain registry or explicit URL.
153-
154-
Returns:
155-
(endpoint, tls_cert_der, tee_id, payment_address)
156-
"""
157-
if tee_endpoint_override is not None:
158-
return tee_endpoint_override, None, None, None
159-
160-
if registry is None:
161-
raise ValueError("Either llm_server_url or a TEERegistry instance must be provided.")
162-
163-
try:
164-
tee = registry.get_llm_tee()
165-
except Exception as e:
166-
raise RuntimeError(f"Failed to fetch LLM TEE endpoint from registry: {e}") from e
167-
168-
if tee is None:
169-
raise ValueError("No active LLM proxy TEE found in the registry. Pass llm_server_url explicitly to override.")
170-
171-
logger.info("Using TEE endpoint from registry: %s (teeId=%s)", tee.endpoint, tee.tee_id)
172-
return tee.endpoint, tee.tls_cert_der, tee.tee_id, tee.payment_address
201+
await self._active.http_client.aclose()

tests/client_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_llm_initialization(self, mock_tee_registry):
7777
def test_llm_initialization_custom_url(self, mock_tee_registry):
7878
"""Test LLM initialization with custom server URL."""
7979
custom_llm_url = "https://custom.llm.server"
80-
llm = LLM(private_key=FAKE_PRIVATE_KEY, llm_server_url=custom_llm_url)
80+
llm = LLM.from_url(private_key=FAKE_PRIVATE_KEY, llm_server_url=custom_llm_url)
8181
assert llm._tee.get().endpoint == custom_llm_url
8282

8383

tests/llm_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,10 @@ def _make_llm(
137137
endpoint: str = "https://test.tee.server",
138138
) -> LLM:
139139
"""Build an LLM with an explicit server URL (skips registry lookup)."""
140-
llm = LLM(private_key=FAKE_PRIVATE_KEY, llm_server_url=endpoint)
141-
# llm_server_url path sets tee_id/payment_address to None; replace with test values.
142140
from dataclasses import replace
141+
142+
llm = LLM.from_url(private_key=FAKE_PRIVATE_KEY, llm_server_url=endpoint)
143+
# from_url sets tee_id/payment_address to None; replace with test values.
143144
llm._tee._active = replace(llm._tee.get(), tee_id="test-tee-id", payment_address="0xTestPayment")
144145
return llm
145146

0 commit comments

Comments
 (0)