Skip to content

Commit cc6919e

Browse files
author
balogh.adam@icloud.com
committed
in memory cert + registry test
1 parent 56a1ce4 commit cc6919e

4 files changed

Lines changed: 261 additions & 16 deletions

File tree

Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ docs:
3131
# Testing
3232
# ============================================================================
3333

34-
test: utils_test client_test langchain_adapter_test opg_token_test
34+
test: utils_test client_test langchain_adapter_test opg_token_test tee_registry_test
3535

3636
utils_test:
3737
pytest tests/utils_test.py -v
@@ -45,6 +45,9 @@ langchain_adapter_test:
4545
opg_token_test:
4646
pytest tests/opg_token_test.py -v
4747

48+
tee_registry_test:
49+
pytest tests/tee_registry_test.py -v
50+
4851
integrationtest:
4952
python integrationtest/agent/test_agent.py
5053
python integrationtest/workflow_models/test_workflow_models.py

src/opengradient/client/tee_registry.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import logging
44
import ssl
5-
import tempfile
65
from dataclasses import dataclass
76
from typing import List, Optional
87

@@ -138,15 +137,8 @@ def build_ssl_context_from_der(der_cert: bytes) -> ssl.SSLContext:
138137
"""
139138
pem = ssl.DER_cert_to_PEM_cert(der_cert)
140139

141-
cert_file = tempfile.NamedTemporaryFile(
142-
prefix="og_tee_tls_", suffix=".pem", delete=False, mode="w"
143-
)
144-
cert_file.write(pem)
145-
cert_file.flush()
146-
cert_file.close()
147-
148140
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
149-
ctx.load_verify_locations(cert_file.name)
141+
ctx.load_verify_locations(cadata=pem)
150142
ctx.check_hostname = False # TEE cert may be issued for a hostname; we connect via IP
151143
ctx.verify_mode = ssl.CERT_REQUIRED
152144
return ctx

tests/langchain_adapter_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,34 +26,34 @@ def mock_client():
2626
@pytest.fixture
2727
def model(mock_client):
2828
"""Create an OpenGradientChatModel with a mocked client."""
29-
return OpenGradientChatModel(private_key="0x" + "a" * 64, model_cid=TEE_LLM.GPT_4O)
29+
return OpenGradientChatModel(private_key="0x" + "a" * 64, model_cid=TEE_LLM.GPT_5)
3030

3131

3232
class TestOpenGradientChatModel:
3333
def test_initialization(self, model):
3434
"""Test model initializes with correct fields."""
35-
assert model.model_cid == TEE_LLM.GPT_4O
35+
assert model.model_cid == TEE_LLM.GPT_5
3636
assert model.max_tokens == 300
3737
assert model.x402_settlement_mode == x402SettlementMode.SETTLE_BATCH
3838
assert model._llm_type == "opengradient"
3939

4040
def test_initialization_custom_max_tokens(self, mock_client):
4141
"""Test model initializes with custom max_tokens."""
42-
model = OpenGradientChatModel(private_key="0x" + "a" * 64, model_cid=TEE_LLM.CLAUDE_3_5_HAIKU, max_tokens=1000)
42+
model = OpenGradientChatModel(private_key="0x" + "a" * 64, model_cid=TEE_LLM.CLAUDE_HAIKU_4_5, max_tokens=1000)
4343
assert model.max_tokens == 1000
4444

4545
def test_initialization_custom_settlement_mode(self, mock_client):
4646
"""Test model initializes with custom settlement mode."""
4747
model = OpenGradientChatModel(
4848
private_key="0x" + "a" * 64,
49-
model_cid=TEE_LLM.GPT_4O,
49+
model_cid=TEE_LLM.GPT_5,
5050
x402_settlement_mode=x402SettlementMode.SETTLE,
5151
)
5252
assert model.x402_settlement_mode == x402SettlementMode.SETTLE
5353

5454
def test_identifying_params(self, model):
5555
"""Test _identifying_params returns model name."""
56-
assert model._identifying_params == {"model_name": TEE_LLM.GPT_4O}
56+
assert model._identifying_params == {"model_name": TEE_LLM.GPT_5}
5757

5858

5959
class TestGenerate:
@@ -210,7 +210,7 @@ def test_passes_correct_params_to_client(self, model, mock_client):
210210
model._generate([HumanMessage(content="Hi")], stop=["END"])
211211

212212
mock_client.llm.chat.assert_called_once_with(
213-
model=TEE_LLM.GPT_4O,
213+
model=TEE_LLM.GPT_5,
214214
messages=[{"role": "user", "content": "Hi"}],
215215
stop_sequence=["END"],
216216
max_tokens=300,

tests/tee_registry_test.py

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
import os
2+
import ssl
3+
import sys
4+
from unittest.mock import MagicMock, patch
5+
6+
import pytest
7+
8+
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
9+
10+
from src.opengradient.client.tee_registry import (
11+
TEE_TYPE_LLM_PROXY,
12+
TEE_TYPE_VALIDATOR,
13+
TEEEndpoint,
14+
TEERegistry,
15+
build_ssl_context_from_der,
16+
)
17+
18+
19+
# --- Helpers ---
20+
21+
22+
def _make_tee_info(
23+
endpoint="https://tee.example.com",
24+
payment_address="0xPayment",
25+
tls_cert_der=b"\x01\x02\x03",
26+
active=True,
27+
):
28+
"""Build a tuple matching the TEEInfo struct order from the contract."""
29+
return (
30+
"0xOwner", # owner
31+
payment_address, # paymentAddress
32+
endpoint, # endpoint
33+
b"pubkey", # publicKey
34+
tls_cert_der, # tlsCertificate
35+
b"pcrhash", # pcrHash
36+
0, # teeType
37+
active, # active
38+
1000, # registeredAt
39+
2000, # lastUpdatedAt
40+
)
41+
42+
43+
def _make_self_signed_der() -> bytes:
44+
"""Generate a minimal self-signed DER certificate for testing."""
45+
from cryptography import x509
46+
from cryptography.hazmat.primitives import hashes, serialization
47+
from cryptography.hazmat.primitives.asymmetric import rsa
48+
from cryptography.x509.oid import NameOID
49+
import datetime
50+
51+
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
52+
subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")])
53+
cert = (
54+
x509.CertificateBuilder()
55+
.subject_name(subject)
56+
.issuer_name(issuer)
57+
.public_key(key.public_key())
58+
.serial_number(x509.random_serial_number())
59+
.not_valid_before(datetime.datetime.now(datetime.UTC))
60+
.not_valid_after(datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=1))
61+
.sign(key, hashes.SHA256())
62+
)
63+
return cert.public_bytes(serialization.Encoding.DER)
64+
65+
66+
# --- Fixtures ---
67+
68+
69+
@pytest.fixture
70+
def mock_contract():
71+
"""Create a TEERegistry with a mocked Web3 contract."""
72+
with (
73+
patch("src.opengradient.client.tee_registry.Web3") as mock_web3_cls,
74+
patch("src.opengradient.client.tee_registry.get_abi") as mock_get_abi,
75+
):
76+
mock_get_abi.return_value = []
77+
mock_web3 = MagicMock()
78+
mock_web3_cls.return_value = mock_web3
79+
mock_web3_cls.HTTPProvider.return_value = MagicMock()
80+
mock_web3_cls.to_checksum_address.side_effect = lambda x: x
81+
82+
contract = MagicMock()
83+
mock_web3.eth.contract.return_value = contract
84+
85+
registry = TEERegistry(rpc_url="http://localhost:8545", registry_address="0xRegistry")
86+
yield registry, contract
87+
88+
89+
# --- TEERegistry Tests ---
90+
91+
92+
class TestGetActiveTeesByType:
93+
def test_returns_active_tees(self, mock_contract):
94+
registry, contract = mock_contract
95+
96+
tee_id = b"\xaa" * 32
97+
contract.functions.getTEEsByType.return_value.call.return_value = [tee_id]
98+
contract.functions.getTEE.return_value.call.return_value = _make_tee_info()
99+
100+
result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY)
101+
102+
assert len(result) == 1
103+
assert result[0].tee_id == tee_id.hex()
104+
assert result[0].endpoint == "https://tee.example.com"
105+
assert result[0].payment_address == "0xPayment"
106+
assert result[0].tls_cert_der == b"\x01\x02\x03"
107+
108+
def test_skips_inactive_tees(self, mock_contract):
109+
registry, contract = mock_contract
110+
111+
tee_id = b"\xbb" * 32
112+
contract.functions.getTEEsByType.return_value.call.return_value = [tee_id]
113+
contract.functions.getTEE.return_value.call.return_value = _make_tee_info(active=False)
114+
115+
result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY)
116+
assert len(result) == 0
117+
118+
def test_skips_tee_with_empty_endpoint(self, mock_contract):
119+
registry, contract = mock_contract
120+
121+
tee_id = b"\xcc" * 32
122+
contract.functions.getTEEsByType.return_value.call.return_value = [tee_id]
123+
contract.functions.getTEE.return_value.call.return_value = _make_tee_info(endpoint="")
124+
125+
result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY)
126+
assert len(result) == 0
127+
128+
def test_skips_tee_with_empty_cert(self, mock_contract):
129+
registry, contract = mock_contract
130+
131+
tee_id = b"\xdd" * 32
132+
contract.functions.getTEEsByType.return_value.call.return_value = [tee_id]
133+
contract.functions.getTEE.return_value.call.return_value = _make_tee_info(tls_cert_der=b"")
134+
135+
result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY)
136+
assert len(result) == 0
137+
138+
def test_returns_empty_on_rpc_failure(self, mock_contract):
139+
registry, contract = mock_contract
140+
141+
contract.functions.getTEEsByType.return_value.call.side_effect = Exception("RPC error")
142+
143+
result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY)
144+
assert result == []
145+
146+
def test_skips_individual_tee_on_lookup_failure(self, mock_contract):
147+
registry, contract = mock_contract
148+
149+
good_id = b"\xaa" * 32
150+
bad_id = b"\xbb" * 32
151+
contract.functions.getTEEsByType.return_value.call.return_value = [bad_id, good_id]
152+
153+
def get_tee_side_effect(tee_id):
154+
mock = MagicMock()
155+
if tee_id == bad_id:
156+
mock.call.side_effect = Exception("lookup failed")
157+
else:
158+
mock.call.return_value = _make_tee_info()
159+
return mock
160+
161+
contract.functions.getTEE.side_effect = get_tee_side_effect
162+
163+
result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY)
164+
assert len(result) == 1
165+
assert result[0].tee_id == good_id.hex()
166+
167+
def test_multiple_active_tees(self, mock_contract):
168+
registry, contract = mock_contract
169+
170+
ids = [b"\x01" * 32, b"\x02" * 32, b"\x03" * 32]
171+
contract.functions.getTEEsByType.return_value.call.return_value = ids
172+
173+
def get_tee_side_effect(tee_id):
174+
mock = MagicMock()
175+
mock.call.return_value = _make_tee_info(
176+
endpoint=f"https://tee-{tee_id.hex()[:4]}.example.com"
177+
)
178+
return mock
179+
180+
contract.functions.getTEE.side_effect = get_tee_side_effect
181+
182+
result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY)
183+
assert len(result) == 3
184+
185+
def test_validator_type_label(self, mock_contract):
186+
"""Ensure validator type queries work the same way."""
187+
registry, contract = mock_contract
188+
189+
contract.functions.getTEEsByType.return_value.call.return_value = []
190+
191+
result = registry.get_active_tees_by_type(TEE_TYPE_VALIDATOR)
192+
assert result == []
193+
contract.functions.getTEEsByType.assert_called_once_with(TEE_TYPE_VALIDATOR)
194+
195+
196+
class TestGetLlmTee:
197+
def test_returns_first_active_tee(self, mock_contract):
198+
registry, contract = mock_contract
199+
200+
ids = [b"\x01" * 32, b"\x02" * 32]
201+
contract.functions.getTEEsByType.return_value.call.return_value = ids
202+
contract.functions.getTEE.return_value.call.return_value = _make_tee_info()
203+
204+
result = registry.get_llm_tee()
205+
206+
assert result is not None
207+
assert result.tee_id == ids[0].hex()
208+
209+
def test_returns_none_when_no_tees(self, mock_contract):
210+
registry, contract = mock_contract
211+
212+
contract.functions.getTEEsByType.return_value.call.return_value = []
213+
214+
result = registry.get_llm_tee()
215+
assert result is None
216+
217+
def test_queries_llm_proxy_type(self, mock_contract):
218+
registry, contract = mock_contract
219+
220+
contract.functions.getTEEsByType.return_value.call.return_value = []
221+
registry.get_llm_tee()
222+
223+
contract.functions.getTEEsByType.assert_called_once_with(TEE_TYPE_LLM_PROXY)
224+
225+
226+
# --- build_ssl_context_from_der Tests ---
227+
228+
229+
class TestBuildSslContextFromDer:
230+
def test_returns_ssl_context(self):
231+
der_cert = _make_self_signed_der()
232+
ctx = build_ssl_context_from_der(der_cert)
233+
234+
assert isinstance(ctx, ssl.SSLContext)
235+
236+
def test_hostname_check_disabled(self):
237+
der_cert = _make_self_signed_der()
238+
ctx = build_ssl_context_from_der(der_cert)
239+
240+
assert ctx.check_hostname is False
241+
242+
def test_cert_required(self):
243+
der_cert = _make_self_signed_der()
244+
ctx = build_ssl_context_from_der(der_cert)
245+
246+
assert ctx.verify_mode == ssl.CERT_REQUIRED
247+
248+
def test_rejects_invalid_der(self):
249+
with pytest.raises(Exception):
250+
build_ssl_context_from_der(b"not-a-valid-cert")

0 commit comments

Comments
 (0)