Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/openenv/core/env_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def __init__(
) # Convert MB to bytes
self._provider = provider
self._ws: Optional[ClientConnection] = None
self._ws_loop: Optional[asyncio.AbstractEventLoop] = None

def __setattr__(self, name: str, value: Any) -> None:
"""Prevent modification of _mode after initialization."""
Expand All @@ -200,7 +201,19 @@ async def connect(self) -> "EnvClient":
ConnectionError: If connection cannot be established
"""
if self._ws is not None:
return self
if self._ws_loop is asyncio.get_running_loop():
return self
# Connected from a different event loop than the one running
# now -- e.g. `client = await Client.from_env(...)` inside
# `asyncio.run(...)`, then `client.sync()` drives every later
# call on `SyncEnvClient`'s own dedicated background loop. The
# websocket object is bound to internals of the original loop,
# which is typically already closed by the time we get here, so
# it cannot be reused (or even cleanly closed) from this loop.
# Drop the stale reference and reconnect fresh below rather than
# silently no-op-ing onto a dead connection.
self._ws = None
self._ws_loop = None

# Disable the proxy for localhost connections via the per-connection
# `proxy` argument rather than mutating the process-global NO_PROXY
Expand All @@ -217,6 +230,7 @@ async def connect(self) -> "EnvClient":
max_size=self._max_message_size,
**connect_kwargs,
)
self._ws_loop = asyncio.get_running_loop()
except Exception as e:
raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e

Expand All @@ -235,6 +249,7 @@ async def disconnect(self) -> None:
except Exception:
pass
self._ws = None
self._ws_loop = None

async def _ensure_connected(self) -> None:
"""Ensure WebSocket connection is established."""
Expand Down
85 changes: 85 additions & 0 deletions tests/test_core/test_generic_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,91 @@ async def fake_ws_connect(*args, **kwargs):
assert "NO_PROXY" not in os.environ


class TestForeignLoopReconnect:
"""`connect()` must not silently no-op onto a websocket bound to a
different (often already-closed) event loop -- e.g. a client connected
via `await Client.from_env(...)` inside `asyncio.run(...)`, then driven
afterwards through `.sync()`'s own dedicated background loop.
"""

@pytest.mark.asyncio
async def test_same_loop_reconnect_is_a_noop(self):
"""Calling connect() again on the loop that created _ws is still a
cheap no-op -- only a genuinely different loop should reconnect.
"""
client = GenericEnvClient(base_url="http://localhost:8000")

async def fake_ws_connect(*args, **kwargs):
return MagicMock()

with patch(
"openenv.core.env_client.ws_connect", side_effect=fake_ws_connect
) as mock_connect:
await client.connect()
await client.connect()

mock_connect.assert_called_once()

@pytest.mark.asyncio
async def test_foreign_loop_triggers_reconnect_not_noop(self):
"""If _ws was bound to a different loop than the one connect() is
now running on, connect() must drop the stale reference and
establish a fresh connection on the current loop, rather than
returning early as if already connected.
"""
client = GenericEnvClient(base_url="http://localhost:8000")

first_ws = MagicMock()

async def fake_ws_connect_first(*args, **kwargs):
return first_ws

with patch(
"openenv.core.env_client.ws_connect", side_effect=fake_ws_connect_first
):
await client.connect()

assert client._ws is first_ws

# Simulate "connected on a different loop": the real scenario is two
# different asyncio event loops, which isn't practical to spin up
# for a unit test, so we directly substitute a sentinel loop object
# that isn't the one currently running.
client._ws_loop = object()

second_ws = MagicMock()

async def fake_ws_connect_second(*args, **kwargs):
return second_ws

with patch(
"openenv.core.env_client.ws_connect", side_effect=fake_ws_connect_second
) as mock_connect:
await client.connect()

mock_connect.assert_called_once()
assert client._ws is second_ws
assert client._ws_loop is asyncio.get_running_loop()

@pytest.mark.asyncio
async def test_disconnect_clears_ws_loop(self):
"""disconnect() must clear _ws_loop along with _ws so a later
connect() on the same loop doesn't think it's still connected.
"""
client = GenericEnvClient(base_url="http://localhost:8000")

async def fake_ws_connect(*args, **kwargs):
return MagicMock()

with patch("openenv.core.env_client.ws_connect", side_effect=fake_ws_connect):
await client.connect()

await client.disconnect()

assert client._ws is None
assert client._ws_loop is None


# ============================================================================
# Integration Tests (require running server)
# ============================================================================
Expand Down