Skip to content

Commit 5e1e10a

Browse files
committed
tests
1 parent f680d72 commit 5e1e10a

1 file changed

Lines changed: 263 additions & 1 deletion

File tree

tests/llm_test.py

Lines changed: 263 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
"""
66

77
import json
8+
import ssl
89
from contextlib import asynccontextmanager
910
from typing import List
10-
from unittest.mock import MagicMock, patch
11+
from unittest.mock import AsyncMock, MagicMock, patch
1112

1213
import httpx
1314
import pytest
@@ -31,6 +32,8 @@ def __init__(self, *_args, **_kwargs):
3132
self._response_body: bytes = b"{}"
3233
self._post_calls: List[dict] = []
3334
self._stream_response = None
35+
self._error_on_next: BaseException | None = None
36+
self._stream_error_on_next: BaseException | None = None
3437

3538
def set_response(self, status_code: int, body: dict) -> None:
3639
self._response_status = status_code
@@ -43,8 +46,19 @@ def set_stream_response(self, status_code: int, chunks: List[bytes]) -> None:
4346
def post_calls(self) -> List[dict]:
4447
return self._post_calls
4548

49+
def fail_next_post(self, exc: BaseException) -> None:
50+
"""Make the next post() call raise *exc*, then revert to normal."""
51+
self._error_on_next = exc
52+
53+
def fail_next_stream(self, exc: BaseException) -> None:
54+
"""Make the next stream() call raise *exc*, then revert to normal."""
55+
self._stream_error_on_next = exc
56+
4657
async def post(self, url: str, *, json=None, headers=None, timeout=None) -> "_FakeResponse":
4758
self._post_calls.append({"url": url, "json": json, "headers": headers, "timeout": timeout})
59+
if self._error_on_next is not None:
60+
exc, self._error_on_next = self._error_on_next, None
61+
raise exc
4862
resp = _FakeResponse(self._response_status, self._response_body)
4963
if self._response_status >= 400:
5064
resp.raise_for_status = MagicMock(side_effect=httpx.HTTPStatusError("error", request=MagicMock(), response=MagicMock()))
@@ -53,6 +67,9 @@ async def post(self, url: str, *, json=None, headers=None, timeout=None) -> "_Fa
5367
@asynccontextmanager
5468
async def stream(self, method: str, url: str, *, json=None, headers=None, timeout=None):
5569
self._post_calls.append({"method": method, "url": url, "json": json, "headers": headers, "timeout": timeout})
70+
if self._stream_error_on_next is not None:
71+
exc, self._stream_error_on_next = self._stream_error_on_next, None
72+
raise exc
5673
yield self._stream_response
5774

5875
async def aclose(self):
@@ -535,3 +552,248 @@ def test_registry_success(self):
535552
assert cert == b"cert-bytes"
536553
assert tee_id == "tee-42"
537554
assert pay_addr == "0xPay"
555+
556+
557+
# ── SSL cause detection tests ────────────────────────────────────────
558+
559+
560+
class TestHasSSLCause:
561+
def test_direct_ssl_error(self):
562+
assert LLM._has_ssl_cause(ssl.SSLError("cert expired")) is True
563+
564+
def test_wrapped_ssl_error_via_cause(self):
565+
root = ssl.SSLError("handshake failed")
566+
wrapper = RuntimeError("connection error")
567+
wrapper.__cause__ = root
568+
assert LLM._has_ssl_cause(wrapper) is True
569+
570+
def test_wrapped_ssl_error_via_context(self):
571+
root = ssl.SSLError("cert verify failed")
572+
wrapper = OSError("transport error")
573+
wrapper.__context__ = root
574+
assert LLM._has_ssl_cause(wrapper) is True
575+
576+
def test_deeply_nested_ssl(self):
577+
root = ssl.SSLError("deep")
578+
mid = OSError("mid")
579+
mid.__cause__ = root
580+
top = RuntimeError("top")
581+
top.__cause__ = mid
582+
assert LLM._has_ssl_cause(top) is True
583+
584+
def test_non_ssl_error(self):
585+
assert LLM._has_ssl_cause(ValueError("not ssl")) is False
586+
587+
def test_non_ssl_chain(self):
588+
root = TimeoutError("timed out")
589+
wrapper = RuntimeError("oops")
590+
wrapper.__cause__ = root
591+
assert LLM._has_ssl_cause(wrapper) is False
592+
593+
def test_cycle_detection(self):
594+
"""Self-referencing chain should not loop forever."""
595+
exc = RuntimeError("cycle")
596+
exc.__cause__ = exc
597+
assert LLM._has_ssl_cause(exc) is False
598+
599+
600+
# ── SSL retry tests (non-streaming) ──────────────────────────────────
601+
602+
603+
def _make_ssl_error(msg: str = "certificate verify failed") -> Exception:
604+
"""Create a RuntimeError wrapping an SSLError, mimicking httpx behaviour."""
605+
root = ssl.SSLError(msg)
606+
wrapper = RuntimeError(f"connection failed: {msg}")
607+
wrapper.__cause__ = root
608+
return wrapper
609+
610+
611+
@pytest.mark.asyncio
612+
class TestSSLRetryCompletion:
613+
async def test_retries_on_ssl_and_succeeds(self, fake_http):
614+
"""First call hits SSL error → refresh → second call succeeds."""
615+
fake_http.set_response(200, {"completion": "retried ok", "tee_signature": "s", "tee_timestamp": "t"})
616+
fake_http.fail_next_post(_make_ssl_error())
617+
llm = _make_llm()
618+
619+
result = await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi")
620+
621+
assert result.completion_output == "retried ok"
622+
# Two post calls: the failed one + the retry
623+
assert len(fake_http.post_calls) == 2
624+
625+
async def test_non_ssl_error_not_retried(self, fake_http):
626+
"""A non-SSL error should propagate immediately, no retry."""
627+
fake_http.fail_next_post(ConnectionError("DNS failed"))
628+
llm = _make_llm()
629+
630+
with pytest.raises(RuntimeError, match="TEE LLM completion failed"):
631+
await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi")
632+
assert len(fake_http.post_calls) == 1
633+
634+
async def test_second_ssl_failure_propagates(self, fake_http):
635+
"""If the retry also hits SSL, the error should propagate."""
636+
fake_http.set_response(200, {"completion": "ok"})
637+
638+
call_count = 0
639+
640+
async def always_ssl(*args, **kwargs):
641+
nonlocal call_count
642+
call_count += 1
643+
raise _make_ssl_error()
644+
645+
fake_http.post = always_ssl
646+
llm = _make_llm()
647+
648+
# The second SSL error bubbles out of _call_with_ssl_retry as a
649+
# RuntimeError (our wrapper), then caught by completion()'s outer
650+
# handler which re-wraps it.
651+
with pytest.raises(RuntimeError):
652+
await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi")
653+
assert call_count == 2 # original + retry
654+
655+
656+
@pytest.mark.asyncio
657+
class TestSSLRetryChat:
658+
async def test_retries_on_ssl_and_succeeds(self, fake_http):
659+
fake_http.set_response(
660+
200,
661+
{"choices": [{"message": {"role": "assistant", "content": "retry ok"}, "finish_reason": "stop"}]},
662+
)
663+
fake_http.fail_next_post(_make_ssl_error())
664+
llm = _make_llm()
665+
666+
result = await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}])
667+
668+
assert result.chat_output["content"] == "retry ok"
669+
assert len(fake_http.post_calls) == 2
670+
671+
async def test_non_ssl_error_not_retried(self, fake_http):
672+
fake_http.fail_next_post(TimeoutError("timed out"))
673+
llm = _make_llm()
674+
675+
with pytest.raises(RuntimeError, match="TEE LLM chat failed"):
676+
await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}])
677+
assert len(fake_http.post_calls) == 1
678+
679+
680+
# ── SSL retry tests (streaming) ──────────────────────────────────────
681+
682+
683+
@pytest.mark.asyncio
684+
class TestSSLRetryStreaming:
685+
async def test_retries_stream_on_ssl_before_chunks(self, fake_http):
686+
"""SSL failure during stream setup (no chunks yielded) → retry succeeds."""
687+
fake_http.set_stream_response(
688+
200,
689+
[
690+
b'data: {"model":"m","choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":"stop"}]}\n\n',
691+
b"data: [DONE]\n\n",
692+
],
693+
)
694+
fake_http.fail_next_stream(_make_ssl_error())
695+
llm = _make_llm()
696+
697+
gen = await llm.chat(
698+
model=TEE_LLM.GPT_5,
699+
messages=[{"role": "user", "content": "Hi"}],
700+
stream=True,
701+
)
702+
chunks = [c async for c in gen]
703+
704+
assert len(chunks) == 1
705+
assert chunks[0].choices[0].delta.content == "ok"
706+
# Two stream attempts: failed + retry
707+
assert len(fake_http.post_calls) == 2
708+
709+
async def test_no_retry_after_chunks_yielded(self, fake_http):
710+
"""SSL failure AFTER chunks were yielded must raise, not retry."""
711+
712+
class _FailMidStream:
713+
"""Yields one chunk then raises SSL."""
714+
715+
def __init__(self):
716+
self.status_code = 200
717+
718+
async def aiter_raw(self):
719+
yield b'data: {"model":"m","choices":[{"index":0,"delta":{"content":"partial"},"finish_reason":null}]}\n\n'
720+
raise _make_ssl_error()
721+
722+
async def aread(self) -> bytes:
723+
return b""
724+
725+
fake_http._stream_response = _FailMidStream()
726+
llm = _make_llm()
727+
728+
gen = await llm.chat(
729+
model=TEE_LLM.GPT_5,
730+
messages=[{"role": "user", "content": "Hi"}],
731+
stream=True,
732+
)
733+
734+
with pytest.raises(RuntimeError):
735+
_ = [c async for c in gen]
736+
737+
# Only one stream call — no retry after partial output
738+
assert len(fake_http.post_calls) == 1
739+
740+
async def test_non_ssl_stream_error_not_retried(self, fake_http):
741+
fake_http.fail_next_stream(ConnectionError("reset"))
742+
llm = _make_llm()
743+
744+
gen = await llm.chat(
745+
model=TEE_LLM.GPT_5,
746+
messages=[{"role": "user", "content": "Hi"}],
747+
stream=True,
748+
)
749+
750+
with pytest.raises(ConnectionError):
751+
_ = [c async for c in gen]
752+
assert len(fake_http.post_calls) == 1
753+
754+
755+
# ── _refresh_tee_and_reset tests ─────────────────────────────────────
756+
757+
758+
@pytest.mark.asyncio
759+
class TestRefreshTeeAndReset:
760+
async def test_replaces_http_client(self):
761+
"""After refresh, the http client should be a new instance."""
762+
clients_created = []
763+
764+
def make_client(*args, **kwargs):
765+
c = FakeHTTPClient()
766+
clients_created.append(c)
767+
return c
768+
769+
with (
770+
patch(_PATCHES["x402_httpx"], side_effect=make_client),
771+
patch(_PATCHES["x402_client"]),
772+
patch(_PATCHES["signer"]),
773+
patch(_PATCHES["register_exact"]),
774+
patch(_PATCHES["register_upto"]),
775+
):
776+
llm = _make_llm()
777+
old_client = llm._http_client
778+
779+
await llm._refresh_tee_and_reset()
780+
781+
assert llm._http_client is not old_client
782+
assert len(clients_created) == 2 # init + refresh
783+
784+
async def test_closes_old_client(self, fake_http):
785+
llm = _make_llm()
786+
old_client = llm._http_client
787+
old_client.aclose = AsyncMock()
788+
789+
await llm._refresh_tee_and_reset()
790+
791+
old_client.aclose.assert_awaited_once()
792+
793+
async def test_close_failure_is_swallowed(self, fake_http):
794+
llm = _make_llm()
795+
old_client = llm._http_client
796+
old_client.aclose = AsyncMock(side_effect=OSError("already closed"))
797+
798+
# Should not raise
799+
await llm._refresh_tee_and_reset()

0 commit comments

Comments
 (0)