Skip to content

Commit 11af494

Browse files
author
balogh.adam@icloud.com
committed
tls test ete
1 parent 97e36ca commit 11af494

2 files changed

Lines changed: 137 additions & 4 deletions

File tree

src/opengradient/client/tee_connection.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,10 @@ def _connect(self) -> ActiveTEE:
142142
"""Resolve TEE from registry and create a secure HTTP client."""
143143
tee = self._resolve_tee()
144144

145-
ssl_ctx = build_ssl_context_from_der(tee.tls_cert_der) if tee.tls_cert_der else None
146-
tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else True
147-
145+
ssl_ctx = build_ssl_context_from_der(tee.tls_cert_der)
148146
return ActiveTEE(
149147
endpoint=tee.endpoint,
150-
http_client=x402HttpxClient(self._x402_client, verify=tls_verify),
148+
http_client=x402HttpxClient(self._x402_client, verify=ssl_ctx),
151149
tee_id=tee.tee_id,
152150
payment_address=tee.payment_address,
153151
)

tests/tee_connection_test.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
11
"""Tests for RegistryTEEConnection and ActiveTEE."""
22

33
import asyncio
4+
import datetime
5+
import os
46
import ssl
7+
import tempfile
58
from unittest.mock import AsyncMock, MagicMock, patch
69

10+
import httpx
711
import pytest
12+
from cryptography import x509
13+
from cryptography.hazmat.primitives import hashes, serialization
14+
from cryptography.hazmat.primitives.asymmetric import rsa
15+
from cryptography.x509.oid import NameOID
16+
from x402 import x402Client
817

918
from src.opengradient.client.tee_connection import (
1019
ActiveTEE,
1120
RegistryTEEConnection,
1221
)
22+
from src.opengradient.client.tee_registry import build_ssl_context_from_der
1323

1424

1525
# ── Helpers ──────────────────────────────────────────────────────────
@@ -327,3 +337,128 @@ async def test_close_without_refresh_task(self):
327337
conn = _make_registry_connection(registry=mock_reg)
328338

329339
await conn.close() # should not raise
340+
341+
342+
# ── TLS certificate verification (real handshake) ────────────────────
343+
344+
345+
def _make_self_signed_cert():
346+
"""Generate a self-signed cert. Returns (der_bytes, pem_cert_bytes, pem_key_bytes)."""
347+
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
348+
subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "localhost")])
349+
cert = (
350+
x509.CertificateBuilder()
351+
.subject_name(subject)
352+
.issuer_name(issuer)
353+
.public_key(key.public_key())
354+
.serial_number(x509.random_serial_number())
355+
.not_valid_before(datetime.datetime.now(datetime.UTC))
356+
.not_valid_after(datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=1))
357+
.sign(key, hashes.SHA256())
358+
)
359+
return (
360+
cert.public_bytes(serialization.Encoding.DER),
361+
cert.public_bytes(serialization.Encoding.PEM),
362+
key.private_bytes(serialization.Encoding.PEM, serialization.PrivateFormat.TraditionalOpenSSL, serialization.NoEncryption()),
363+
)
364+
365+
366+
@pytest.fixture
367+
async def tls_server():
368+
"""Spin up a local TLS server with a self-signed cert."""
369+
der, pem_cert, pem_key = _make_self_signed_cert()
370+
371+
cert_file = tempfile.NamedTemporaryFile(suffix=".pem", delete=False)
372+
key_file = tempfile.NamedTemporaryFile(suffix=".pem", delete=False)
373+
try:
374+
cert_file.write(pem_cert)
375+
cert_file.close()
376+
key_file.write(pem_key)
377+
key_file.close()
378+
379+
server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
380+
server_ctx.load_cert_chain(cert_file.name, key_file.name)
381+
382+
async def handler(reader, writer):
383+
await reader.read(4096)
384+
writer.write(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok")
385+
await writer.drain()
386+
writer.close()
387+
388+
server = await asyncio.start_server(handler, "127.0.0.1", 0, ssl=server_ctx)
389+
port = server.sockets[0].getsockname()[1]
390+
391+
yield {"port": port, "der": der}
392+
393+
server.close()
394+
await server.wait_closed()
395+
finally:
396+
os.unlink(cert_file.name)
397+
os.unlink(key_file.name)
398+
399+
400+
def _registry_with_real_cert(tls_server):
401+
"""Return a mock registry that serves the local TLS server's real DER cert."""
402+
return _mock_registry_with_tee(
403+
endpoint=f"https://127.0.0.1:{tls_server['port']}",
404+
tls_cert_der=tls_server["der"],
405+
tee_id="tee-real",
406+
payment_address="0xRealPay",
407+
)
408+
409+
410+
@pytest.mark.asyncio
411+
class TestTlsCertVerification:
412+
"""End-to-end TLS handshake tests through RegistryTEEConnection.
413+
414+
A real local TLS server is started with a self-signed cert. The registry
415+
mock returns that cert's DER bytes. RegistryTEEConnection._connect() runs
416+
its real code (build_ssl_context_from_der → x402HttpxClient(verify=ctx))
417+
so the full cert-pinning path is exercised with an actual TLS handshake.
418+
"""
419+
420+
async def test_connect_succeeds_with_matching_cert(self, tls_server):
421+
mock_reg = _registry_with_real_cert(tls_server)
422+
conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg)
423+
424+
resp = await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/")
425+
assert resp.status_code == 200
426+
assert conn.get().tee_id == "tee-real"
427+
assert conn.get().payment_address == "0xRealPay"
428+
await conn.close()
429+
430+
async def test_connect_fails_with_wrong_cert(self, tls_server):
431+
wrong_der, _, _ = _make_self_signed_cert() # different key pair
432+
mock_reg = _mock_registry_with_tee(
433+
endpoint=f"https://127.0.0.1:{tls_server['port']}",
434+
tls_cert_der=wrong_der,
435+
)
436+
conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg)
437+
438+
with pytest.raises(httpx.ConnectError):
439+
await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/")
440+
await conn.close()
441+
442+
async def test_connect_fails_with_no_cert_pinning(self, tls_server):
443+
"""Without a pinned cert (tls_cert_der=None), system CAs are used
444+
which won't trust our self-signed server cert."""
445+
mock_reg = _mock_registry_with_tee(
446+
endpoint=f"https://127.0.0.1:{tls_server['port']}",
447+
tls_cert_der=None,
448+
)
449+
conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg)
450+
451+
with pytest.raises(httpx.ConnectError):
452+
await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/")
453+
await conn.close()
454+
455+
async def test_reconnect_picks_up_new_cert(self, tls_server):
456+
"""After reconnect, the connection uses the freshly-resolved cert."""
457+
mock_reg = _registry_with_real_cert(tls_server)
458+
conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg)
459+
460+
await conn.reconnect()
461+
462+
resp = await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/")
463+
assert resp.status_code == 200
464+
await conn.close()

0 commit comments

Comments
 (0)