Reconnect EnvClient when its websocket belongs to another event loop.

connect() used to return early whenever `_ws` was set. This patch records
the event loop that created the websocket in `_ws_loop`; connect() now only
short-circuits when that loop is the one currently running, otherwise it
drops the stale reference and opens a fresh connection. disconnect() clears
`_ws_loop` together with `_ws`. The added tests cover the same-loop no-op,
the foreign-loop reconnect, and the disconnect cleanup.

diff --git a/src/openenv/core/env_client.py b/src/openenv/core/env_client.py
index 7bfd33be2..39bd510e9 100644
--- a/src/openenv/core/env_client.py
+++ b/src/openenv/core/env_client.py
@@ -182,6 +182,7 @@ def __init__(
         )  # Convert MB to bytes
         self._provider = provider
         self._ws: Optional[ClientConnection] = None
+        self._ws_loop: Optional[asyncio.AbstractEventLoop] = None
 
     def __setattr__(self, name: str, value: Any) -> None:
         """Prevent modification of _mode after initialization."""
@@ -200,7 +201,19 @@ async def connect(self) -> "EnvClient":
             ConnectionError: If connection cannot be established
         """
         if self._ws is not None:
-            return self
+            if self._ws_loop is asyncio.get_running_loop():
+                return self
+            # Connected from a different event loop than the one running
+            # now -- e.g. `client = await Client.from_env(...)` inside
+            # `asyncio.run(...)`, then `client.sync()` drives every later
+            # call on `SyncEnvClient`'s own dedicated background loop. The
+            # websocket object is bound to internals of the original loop,
+            # which is typically already closed by the time we get here, so
+            # it cannot be reused (or even cleanly closed) from this loop.
+            # Drop the stale reference and reconnect fresh below rather than
+            # silently no-op-ing onto a dead connection.
+            self._ws = None
+            self._ws_loop = None
 
         # Disable the proxy for localhost connections via the per-connection
         # `proxy` argument rather than mutating the process-global NO_PROXY
@@ -217,6 +230,7 @@ async def connect(self) -> "EnvClient":
                 max_size=self._max_message_size,
                 **connect_kwargs,
             )
+            self._ws_loop = asyncio.get_running_loop()
         except Exception as e:
             raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e
 
@@ -235,6 +249,7 @@ async def disconnect(self) -> None:
             except Exception:
                 pass
             self._ws = None
+            self._ws_loop = None
 
     async def _ensure_connected(self) -> None:
         """Ensure WebSocket connection is established."""
diff --git a/tests/test_core/test_generic_client.py b/tests/test_core/test_generic_client.py
index 9a9121ce6..98632a269 100644
--- a/tests/test_core/test_generic_client.py
+++ b/tests/test_core/test_generic_client.py
@@ -695,6 +695,91 @@ async def fake_ws_connect(*args, **kwargs):
         assert "NO_PROXY" not in os.environ
 
 
+class TestForeignLoopReconnect:
+    """`connect()` must not silently no-op onto a websocket bound to a
+    different (often already-closed) event loop -- e.g. a client connected
+    via `await Client.from_env(...)` inside `asyncio.run(...)`, then driven
+    afterwards through `.sync()`'s own dedicated background loop.
+    """
+
+    @pytest.mark.asyncio
+    async def test_same_loop_reconnect_is_a_noop(self):
+        """Calling connect() again on the loop that created _ws is still a
+        cheap no-op -- only a genuinely different loop should reconnect.
+        """
+        client = GenericEnvClient(base_url="http://localhost:8000")
+
+        async def fake_ws_connect(*args, **kwargs):
+            return MagicMock()
+
+        with patch(
+            "openenv.core.env_client.ws_connect", side_effect=fake_ws_connect
+        ) as mock_connect:
+            await client.connect()
+            await client.connect()
+
+        mock_connect.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_foreign_loop_triggers_reconnect_not_noop(self):
+        """If _ws was bound to a different loop than the one connect() is
+        now running on, connect() must drop the stale reference and
+        establish a fresh connection on the current loop, rather than
+        returning early as if already connected.
+        """
+        client = GenericEnvClient(base_url="http://localhost:8000")
+
+        first_ws = MagicMock()
+
+        async def fake_ws_connect_first(*args, **kwargs):
+            return first_ws
+
+        with patch(
+            "openenv.core.env_client.ws_connect", side_effect=fake_ws_connect_first
+        ):
+            await client.connect()
+
+        assert client._ws is first_ws
+
+        # Simulate "connected on a different loop": the real scenario is two
+        # different asyncio event loops, which isn't practical to spin up
+        # for a unit test, so we directly substitute a sentinel loop object
+        # that isn't the one currently running.
+        client._ws_loop = object()
+
+        second_ws = MagicMock()
+
+        async def fake_ws_connect_second(*args, **kwargs):
+            return second_ws
+
+        with patch(
+            "openenv.core.env_client.ws_connect", side_effect=fake_ws_connect_second
+        ) as mock_connect:
+            await client.connect()
+
+        mock_connect.assert_called_once()
+        assert client._ws is second_ws
+        assert client._ws_loop is asyncio.get_running_loop()
+
+    @pytest.mark.asyncio
+    async def test_disconnect_clears_ws_loop(self):
+        """disconnect() must clear _ws_loop along with _ws so a later
+        connect() on the same loop doesn't think it's still connected.
+        """
+        client = GenericEnvClient(base_url="http://localhost:8000")
+
+        async def fake_ws_connect(*args, **kwargs):
+            return MagicMock()
+
+        with patch("openenv.core.env_client.ws_connect", side_effect=fake_ws_connect):
+            await client.connect()
+
+        await client.disconnect()
+
+        assert client._ws is None
+        assert client._ws_loop is None
+
+
 # ============================================================================
 # Integration Tests (require running server)
 # ============================================================================