55"""
66
77import json
8+ import ssl
89from contextlib import asynccontextmanager
910from typing import List
10- from unittest .mock import MagicMock , patch
11+ from unittest .mock import AsyncMock , MagicMock , patch
1112
1213import httpx
1314import pytest
@@ -31,6 +32,8 @@ def __init__(self, *_args, **_kwargs):
3132 self ._response_body : bytes = b"{}"
3233 self ._post_calls : List [dict ] = []
3334 self ._stream_response = None
35+ self ._error_on_next : BaseException | None = None
36+ self ._stream_error_on_next : BaseException | None = None
3437
3538 def set_response (self , status_code : int , body : dict ) -> None :
3639 self ._response_status = status_code
@@ -43,8 +46,19 @@ def set_stream_response(self, status_code: int, chunks: List[bytes]) -> None:
4346 def post_calls (self ) -> List [dict ]:
4447 return self ._post_calls
4548
49+ def fail_next_post (self , exc : BaseException ) -> None :
50+ """Make the next post() call raise *exc*, then revert to normal."""
51+ self ._error_on_next = exc
52+
53+ def fail_next_stream (self , exc : BaseException ) -> None :
54+ """Make the next stream() call raise *exc*, then revert to normal."""
55+ self ._stream_error_on_next = exc
56+
4657 async def post (self , url : str , * , json = None , headers = None , timeout = None ) -> "_FakeResponse" :
4758 self ._post_calls .append ({"url" : url , "json" : json , "headers" : headers , "timeout" : timeout })
59+ if self ._error_on_next is not None :
60+ exc , self ._error_on_next = self ._error_on_next , None
61+ raise exc
4862 resp = _FakeResponse (self ._response_status , self ._response_body )
4963 if self ._response_status >= 400 :
5064 resp .raise_for_status = MagicMock (side_effect = httpx .HTTPStatusError ("error" , request = MagicMock (), response = MagicMock ()))
@@ -53,6 +67,9 @@ async def post(self, url: str, *, json=None, headers=None, timeout=None) -> "_Fa
5367 @asynccontextmanager
5468 async def stream (self , method : str , url : str , * , json = None , headers = None , timeout = None ):
5569 self ._post_calls .append ({"method" : method , "url" : url , "json" : json , "headers" : headers , "timeout" : timeout })
70+ if self ._stream_error_on_next is not None :
71+ exc , self ._stream_error_on_next = self ._stream_error_on_next , None
72+ raise exc
5673 yield self ._stream_response
5774
5875 async def aclose (self ):
@@ -535,3 +552,248 @@ def test_registry_success(self):
535552 assert cert == b"cert-bytes"
536553 assert tee_id == "tee-42"
537554 assert pay_addr == "0xPay"
555+
556+
557+ # ── SSL cause detection tests ────────────────────────────────────────
558+
559+
560+ class TestHasSSLCause :
561+ def test_direct_ssl_error (self ):
562+ assert LLM ._has_ssl_cause (ssl .SSLError ("cert expired" )) is True
563+
564+ def test_wrapped_ssl_error_via_cause (self ):
565+ root = ssl .SSLError ("handshake failed" )
566+ wrapper = RuntimeError ("connection error" )
567+ wrapper .__cause__ = root
568+ assert LLM ._has_ssl_cause (wrapper ) is True
569+
570+ def test_wrapped_ssl_error_via_context (self ):
571+ root = ssl .SSLError ("cert verify failed" )
572+ wrapper = OSError ("transport error" )
573+ wrapper .__context__ = root
574+ assert LLM ._has_ssl_cause (wrapper ) is True
575+
576+ def test_deeply_nested_ssl (self ):
577+ root = ssl .SSLError ("deep" )
578+ mid = OSError ("mid" )
579+ mid .__cause__ = root
580+ top = RuntimeError ("top" )
581+ top .__cause__ = mid
582+ assert LLM ._has_ssl_cause (top ) is True
583+
584+ def test_non_ssl_error (self ):
585+ assert LLM ._has_ssl_cause (ValueError ("not ssl" )) is False
586+
587+ def test_non_ssl_chain (self ):
588+ root = TimeoutError ("timed out" )
589+ wrapper = RuntimeError ("oops" )
590+ wrapper .__cause__ = root
591+ assert LLM ._has_ssl_cause (wrapper ) is False
592+
593+ def test_cycle_detection (self ):
594+ """Self-referencing chain should not loop forever."""
595+ exc = RuntimeError ("cycle" )
596+ exc .__cause__ = exc
597+ assert LLM ._has_ssl_cause (exc ) is False
598+
599+
600+ # ── SSL retry tests (non-streaming) ──────────────────────────────────
601+
602+
603+ def _make_ssl_error (msg : str = "certificate verify failed" ) -> Exception :
604+ """Create a RuntimeError wrapping an SSLError, mimicking httpx behaviour."""
605+ root = ssl .SSLError (msg )
606+ wrapper = RuntimeError (f"connection failed: { msg } " )
607+ wrapper .__cause__ = root
608+ return wrapper
609+
610+
611+ @pytest .mark .asyncio
612+ class TestSSLRetryCompletion :
613+ async def test_retries_on_ssl_and_succeeds (self , fake_http ):
614+ """First call hits SSL error → refresh → second call succeeds."""
615+ fake_http .set_response (200 , {"completion" : "retried ok" , "tee_signature" : "s" , "tee_timestamp" : "t" })
616+ fake_http .fail_next_post (_make_ssl_error ())
617+ llm = _make_llm ()
618+
619+ result = await llm .completion (model = TEE_LLM .GPT_5 , prompt = "Hi" )
620+
621+ assert result .completion_output == "retried ok"
622+ # Two post calls: the failed one + the retry
623+ assert len (fake_http .post_calls ) == 2
624+
625+ async def test_non_ssl_error_not_retried (self , fake_http ):
626+ """A non-SSL error should propagate immediately, no retry."""
627+ fake_http .fail_next_post (ConnectionError ("DNS failed" ))
628+ llm = _make_llm ()
629+
630+ with pytest .raises (RuntimeError , match = "TEE LLM completion failed" ):
631+ await llm .completion (model = TEE_LLM .GPT_5 , prompt = "Hi" )
632+ assert len (fake_http .post_calls ) == 1
633+
634+ async def test_second_ssl_failure_propagates (self , fake_http ):
635+ """If the retry also hits SSL, the error should propagate."""
636+ fake_http .set_response (200 , {"completion" : "ok" })
637+
638+ call_count = 0
639+
640+ async def always_ssl (* args , ** kwargs ):
641+ nonlocal call_count
642+ call_count += 1
643+ raise _make_ssl_error ()
644+
645+ fake_http .post = always_ssl
646+ llm = _make_llm ()
647+
648+ # The second SSL error bubbles out of _call_with_ssl_retry as a
649+ # RuntimeError (our wrapper), then caught by completion()'s outer
650+ # handler which re-wraps it.
651+ with pytest .raises (RuntimeError ):
652+ await llm .completion (model = TEE_LLM .GPT_5 , prompt = "Hi" )
653+ assert call_count == 2 # original + retry
654+
655+
656+ @pytest .mark .asyncio
657+ class TestSSLRetryChat :
658+ async def test_retries_on_ssl_and_succeeds (self , fake_http ):
659+ fake_http .set_response (
660+ 200 ,
661+ {"choices" : [{"message" : {"role" : "assistant" , "content" : "retry ok" }, "finish_reason" : "stop" }]},
662+ )
663+ fake_http .fail_next_post (_make_ssl_error ())
664+ llm = _make_llm ()
665+
666+ result = await llm .chat (model = TEE_LLM .GPT_5 , messages = [{"role" : "user" , "content" : "Hi" }])
667+
668+ assert result .chat_output ["content" ] == "retry ok"
669+ assert len (fake_http .post_calls ) == 2
670+
671+ async def test_non_ssl_error_not_retried (self , fake_http ):
672+ fake_http .fail_next_post (TimeoutError ("timed out" ))
673+ llm = _make_llm ()
674+
675+ with pytest .raises (RuntimeError , match = "TEE LLM chat failed" ):
676+ await llm .chat (model = TEE_LLM .GPT_5 , messages = [{"role" : "user" , "content" : "Hi" }])
677+ assert len (fake_http .post_calls ) == 1
678+
679+
680+ # ── SSL retry tests (streaming) ──────────────────────────────────────
681+
682+
683+ @pytest .mark .asyncio
684+ class TestSSLRetryStreaming :
685+ async def test_retries_stream_on_ssl_before_chunks (self , fake_http ):
686+ """SSL failure during stream setup (no chunks yielded) → retry succeeds."""
687+ fake_http .set_stream_response (
688+ 200 ,
689+ [
690+ b'data: {"model":"m","choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":"stop"}]}\n \n ' ,
691+ b"data: [DONE]\n \n " ,
692+ ],
693+ )
694+ fake_http .fail_next_stream (_make_ssl_error ())
695+ llm = _make_llm ()
696+
697+ gen = await llm .chat (
698+ model = TEE_LLM .GPT_5 ,
699+ messages = [{"role" : "user" , "content" : "Hi" }],
700+ stream = True ,
701+ )
702+ chunks = [c async for c in gen ]
703+
704+ assert len (chunks ) == 1
705+ assert chunks [0 ].choices [0 ].delta .content == "ok"
706+ # Two stream attempts: failed + retry
707+ assert len (fake_http .post_calls ) == 2
708+
709+ async def test_no_retry_after_chunks_yielded (self , fake_http ):
710+ """SSL failure AFTER chunks were yielded must raise, not retry."""
711+
712+ class _FailMidStream :
713+ """Yields one chunk then raises SSL."""
714+
715+ def __init__ (self ):
716+ self .status_code = 200
717+
718+ async def aiter_raw (self ):
719+ yield b'data: {"model":"m","choices":[{"index":0,"delta":{"content":"partial"},"finish_reason":null}]}\n \n '
720+ raise _make_ssl_error ()
721+
722+ async def aread (self ) -> bytes :
723+ return b""
724+
725+ fake_http ._stream_response = _FailMidStream ()
726+ llm = _make_llm ()
727+
728+ gen = await llm .chat (
729+ model = TEE_LLM .GPT_5 ,
730+ messages = [{"role" : "user" , "content" : "Hi" }],
731+ stream = True ,
732+ )
733+
734+ with pytest .raises (RuntimeError ):
735+ _ = [c async for c in gen ]
736+
737+ # Only one stream call — no retry after partial output
738+ assert len (fake_http .post_calls ) == 1
739+
740+ async def test_non_ssl_stream_error_not_retried (self , fake_http ):
741+ fake_http .fail_next_stream (ConnectionError ("reset" ))
742+ llm = _make_llm ()
743+
744+ gen = await llm .chat (
745+ model = TEE_LLM .GPT_5 ,
746+ messages = [{"role" : "user" , "content" : "Hi" }],
747+ stream = True ,
748+ )
749+
750+ with pytest .raises (ConnectionError ):
751+ _ = [c async for c in gen ]
752+ assert len (fake_http .post_calls ) == 1
753+
754+
755+ # ── _refresh_tee_and_reset tests ─────────────────────────────────────
756+
757+
758+ @pytest .mark .asyncio
759+ class TestRefreshTeeAndReset :
760+ async def test_replaces_http_client (self ):
761+ """After refresh, the http client should be a new instance."""
762+ clients_created = []
763+
764+ def make_client (* args , ** kwargs ):
765+ c = FakeHTTPClient ()
766+ clients_created .append (c )
767+ return c
768+
769+ with (
770+ patch (_PATCHES ["x402_httpx" ], side_effect = make_client ),
771+ patch (_PATCHES ["x402_client" ]),
772+ patch (_PATCHES ["signer" ]),
773+ patch (_PATCHES ["register_exact" ]),
774+ patch (_PATCHES ["register_upto" ]),
775+ ):
776+ llm = _make_llm ()
777+ old_client = llm ._http_client
778+
779+ await llm ._refresh_tee_and_reset ()
780+
781+ assert llm ._http_client is not old_client
782+ assert len (clients_created ) == 2 # init + refresh
783+
784+ async def test_closes_old_client (self , fake_http ):
785+ llm = _make_llm ()
786+ old_client = llm ._http_client
787+ old_client .aclose = AsyncMock ()
788+
789+ await llm ._refresh_tee_and_reset ()
790+
791+ old_client .aclose .assert_awaited_once ()
792+
793+ async def test_close_failure_is_swallowed (self , fake_http ):
794+ llm = _make_llm ()
795+ old_client = llm ._http_client
796+ old_client .aclose = AsyncMock (side_effect = OSError ("already closed" ))
797+
798+ # Should not raise
799+ await llm ._refresh_tee_and_reset ()
0 commit comments