|
1 | 1 | """Tests for RegistryTEEConnection and ActiveTEE.""" |
2 | 2 |
|
3 | 3 | import asyncio |
| 4 | +import datetime |
| 5 | +import os |
4 | 6 | import ssl |
| 7 | +import tempfile |
5 | 8 | from unittest.mock import AsyncMock, MagicMock, patch |
6 | 9 |
|
| 10 | +import httpx |
7 | 11 | import pytest |
| 12 | +from cryptography import x509 |
| 13 | +from cryptography.hazmat.primitives import hashes, serialization |
| 14 | +from cryptography.hazmat.primitives.asymmetric import rsa |
| 15 | +from cryptography.x509.oid import NameOID |
| 16 | +from x402 import x402Client |
8 | 17 |
|
9 | 18 | from src.opengradient.client.tee_connection import ( |
10 | 19 | ActiveTEE, |
11 | 20 | RegistryTEEConnection, |
12 | 21 | ) |
| 22 | +from src.opengradient.client.tee_registry import build_ssl_context_from_der |
13 | 23 |
|
14 | 24 |
|
15 | 25 | # ── Helpers ────────────────────────────────────────────────────────── |
@@ -327,3 +337,128 @@ async def test_close_without_refresh_task(self): |
327 | 337 | conn = _make_registry_connection(registry=mock_reg) |
328 | 338 |
|
329 | 339 | await conn.close() # should not raise |
| 340 | + |
| 341 | + |
| 342 | +# ── TLS certificate verification (real handshake) ──────────────────── |
| 343 | + |
| 344 | + |
| 345 | +def _make_self_signed_cert(): |
| 346 | + """Generate a self-signed cert. Returns (der_bytes, pem_cert_bytes, pem_key_bytes).""" |
| 347 | + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) |
| 348 | + subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "localhost")]) |
| 349 | + cert = ( |
| 350 | + x509.CertificateBuilder() |
| 351 | + .subject_name(subject) |
| 352 | + .issuer_name(issuer) |
| 353 | + .public_key(key.public_key()) |
| 354 | + .serial_number(x509.random_serial_number()) |
| 355 | + .not_valid_before(datetime.datetime.now(datetime.UTC)) |
| 356 | + .not_valid_after(datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=1)) |
| 357 | + .sign(key, hashes.SHA256()) |
| 358 | + ) |
| 359 | + return ( |
| 360 | + cert.public_bytes(serialization.Encoding.DER), |
| 361 | + cert.public_bytes(serialization.Encoding.PEM), |
| 362 | + key.private_bytes(serialization.Encoding.PEM, serialization.PrivateFormat.TraditionalOpenSSL, serialization.NoEncryption()), |
| 363 | + ) |
| 364 | + |
| 365 | + |
| 366 | +@pytest.fixture |
| 367 | +async def tls_server(): |
| 368 | + """Spin up a local TLS server with a self-signed cert.""" |
| 369 | + der, pem_cert, pem_key = _make_self_signed_cert() |
| 370 | + |
| 371 | + cert_file = tempfile.NamedTemporaryFile(suffix=".pem", delete=False) |
| 372 | + key_file = tempfile.NamedTemporaryFile(suffix=".pem", delete=False) |
| 373 | + try: |
| 374 | + cert_file.write(pem_cert) |
| 375 | + cert_file.close() |
| 376 | + key_file.write(pem_key) |
| 377 | + key_file.close() |
| 378 | + |
| 379 | + server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) |
| 380 | + server_ctx.load_cert_chain(cert_file.name, key_file.name) |
| 381 | + |
| 382 | + async def handler(reader, writer): |
| 383 | + await reader.read(4096) |
| 384 | + writer.write(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok") |
| 385 | + await writer.drain() |
| 386 | + writer.close() |
| 387 | + |
| 388 | + server = await asyncio.start_server(handler, "127.0.0.1", 0, ssl=server_ctx) |
| 389 | + port = server.sockets[0].getsockname()[1] |
| 390 | + |
| 391 | + yield {"port": port, "der": der} |
| 392 | + |
| 393 | + server.close() |
| 394 | + await server.wait_closed() |
| 395 | + finally: |
| 396 | + os.unlink(cert_file.name) |
| 397 | + os.unlink(key_file.name) |
| 398 | + |
| 399 | + |
| 400 | +def _registry_with_real_cert(tls_server): |
| 401 | + """Return a mock registry that serves the local TLS server's real DER cert.""" |
| 402 | + return _mock_registry_with_tee( |
| 403 | + endpoint=f"https://127.0.0.1:{tls_server['port']}", |
| 404 | + tls_cert_der=tls_server["der"], |
| 405 | + tee_id="tee-real", |
| 406 | + payment_address="0xRealPay", |
| 407 | + ) |
| 408 | + |
| 409 | + |
| 410 | +@pytest.mark.asyncio |
| 411 | +class TestTlsCertVerification: |
| 412 | + """End-to-end TLS handshake tests through RegistryTEEConnection. |
| 413 | +
|
| 414 | + A real local TLS server is started with a self-signed cert. The registry |
| 415 | + mock returns that cert's DER bytes. RegistryTEEConnection._connect() runs |
| 416 | + its real code (build_ssl_context_from_der → x402HttpxClient(verify=ctx)) |
| 417 | + so the full cert-pinning path is exercised with an actual TLS handshake. |
| 418 | + """ |
| 419 | + |
| 420 | + async def test_connect_succeeds_with_matching_cert(self, tls_server): |
| 421 | + mock_reg = _registry_with_real_cert(tls_server) |
| 422 | + conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg) |
| 423 | + |
| 424 | + resp = await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/") |
| 425 | + assert resp.status_code == 200 |
| 426 | + assert conn.get().tee_id == "tee-real" |
| 427 | + assert conn.get().payment_address == "0xRealPay" |
| 428 | + await conn.close() |
| 429 | + |
| 430 | + async def test_connect_fails_with_wrong_cert(self, tls_server): |
| 431 | + wrong_der, _, _ = _make_self_signed_cert() # different key pair |
| 432 | + mock_reg = _mock_registry_with_tee( |
| 433 | + endpoint=f"https://127.0.0.1:{tls_server['port']}", |
| 434 | + tls_cert_der=wrong_der, |
| 435 | + ) |
| 436 | + conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg) |
| 437 | + |
| 438 | + with pytest.raises(httpx.ConnectError): |
| 439 | + await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/") |
| 440 | + await conn.close() |
| 441 | + |
| 442 | + async def test_connect_fails_with_no_cert_pinning(self, tls_server): |
| 443 | + """Without a pinned cert (tls_cert_der=None), system CAs are used |
| 444 | + which won't trust our self-signed server cert.""" |
| 445 | + mock_reg = _mock_registry_with_tee( |
| 446 | + endpoint=f"https://127.0.0.1:{tls_server['port']}", |
| 447 | + tls_cert_der=None, |
| 448 | + ) |
| 449 | + conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg) |
| 450 | + |
| 451 | + with pytest.raises(httpx.ConnectError): |
| 452 | + await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/") |
| 453 | + await conn.close() |
| 454 | + |
| 455 | + async def test_reconnect_picks_up_new_cert(self, tls_server): |
| 456 | + """After reconnect, the connection uses the freshly-resolved cert.""" |
| 457 | + mock_reg = _registry_with_real_cert(tls_server) |
| 458 | + conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg) |
| 459 | + |
| 460 | + await conn.reconnect() |
| 461 | + |
| 462 | + resp = await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/") |
| 463 | + assert resp.status_code == 200 |
| 464 | + await conn.close() |
0 commit comments