Skip to content

Commit 17836cb

Browse files
committed
Fix load_from_state to work with subclass constructors
Moved load_from_state to SweBenchEnv and HarborEnv, since their constructors require different arguments (a tasks list). The base class now provides a _restore_from_snapshot helper that subclasses call after constructing the environment.

- Add load_from_state implementation to SweBenchEnv
- Add load_from_state implementation to HarborEnv
- Refactor base class to use the _restore_from_snapshot helper
- Add test for load_from_state functionality
1 parent abae672 commit 17836cb

File tree

4 files changed

+175
-42
lines changed

4 files changed

+175
-42
lines changed

src/ares/environments/base.py

Lines changed: 14 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -565,55 +565,27 @@ async def export_state(
565565

566566
return snap
567567

568-
@classmethod
569-
async def load_from_state(
570-
cls,
571-
snapshot_path: snapshot.EnvironmentSnapshot | pathlib.Path,
572-
*,
573-
container_factory: containers.ContainerFactory | None = None,
574-
code_agent_factory: code_agent_base.CodeAgentFactory | None = None,
575-
) -> "CodeBaseEnv":
576-
"""Restore environment from snapshot.
568+
async def _restore_from_snapshot(self, snap: snapshot.EnvironmentSnapshot) -> None:
569+
"""Internal helper to restore state from snapshot.
577570
578-
Args:
579-
snapshot_path: EnvironmentSnapshot or path to snapshot.json
580-
container_factory: Override factory (uses snapshot metadata if None)
581-
code_agent_factory: Override factory (uses default if None)
571+
Called by subclass load_from_state implementations after creating the environment.
582572
583-
Returns:
584-
Restored environment (NOT active, use async with)
573+
Args:
574+
snap: The snapshot to restore from
585575
"""
586-
# Load snapshot if path provided
587-
if isinstance(snapshot_path, pathlib.Path):
588-
snap = snapshot.EnvironmentSnapshot.load_from_file(snapshot_path)
589-
else:
590-
snap = snapshot_path
591-
592-
# Deserialize task
593-
task = cls._deserialize_task(snap.task_data, snap.task_type)
594-
595-
# Create environment instance
596-
container_factory = container_factory or cls._default_container_factory(snap)
597-
code_agent_factory = code_agent_factory or mini_swe_agent.MiniSWECodeAgent
576+
# Restore container
577+
self._container = await self._restore_container(snap)
598578

599-
# Note: This creates a base CodeBaseEnv which is abstract
600-
# In practice, this should be called on SweBenchEnv or HarborEnv subclasses
601-
env = cls(
602-
container_factory=container_factory,
603-
code_agent_factory=code_agent_factory,
604-
step_limit=snap.step_limit,
605-
)
579+
# Deserialize and set task
580+
task = self._deserialize_task(snap.task_data, snap.task_type)
581+
self._current_task = task
606582

607-
# Restore container
608-
env._container = await env._restore_container(snap)
609-
env._current_task = task
610-
env._step_count = snap.step_count
611-
env._requires_reset = snap.requires_reset
583+
# Restore state
584+
self._step_count = snap.step_count
585+
self._requires_reset = snap.requires_reset
612586

613587
# Store messages for later restoration
614-
env._saved_agent_messages = snap.agent_messages
615-
616-
return env
588+
self._saved_agent_messages = snap.agent_messages
617589

618590
@abc.abstractmethod
619591
def _serialize_task(self, task: TaskType) -> dict:

src/ares/environments/harbor_env.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,51 @@ def _deserialize_task(cls, task_data: dict, task_type: str) -> harbor_task.Task:
179179
"""Deserialize Harbor task (reload from directory)."""
180180
del task_type # Unused - validated by caller
181181
return harbor_task.Task(task_dir=pathlib.Path(task_data["task_dir"]))
182+
183+
@classmethod
async def load_from_state(
    cls,
    snapshot_path: "base.snapshot.EnvironmentSnapshot | pathlib.Path | str",
    *,
    container_factory: containers.ContainerFactory | None = None,
    code_agent_factory: code_agent_base.CodeAgentFactory | None = None,
    tracker: stat_tracker.StatTracker | None = None,
) -> "HarborEnv":
    """Restore a HarborEnv from a snapshot.

    Args:
        snapshot_path: An EnvironmentSnapshot, or a filesystem path
            (``pathlib.Path``, ``str`` or any ``os.PathLike``) to a
            snapshot.json file to load.
        container_factory: Override factory. Defaults to
            ``base.ares_daytona.DaytonaContainer`` if None.
        code_agent_factory: Override factory. Defaults to
            ``mini_swe_agent.MiniSWECodeAgent`` if None.
        tracker: Optional stat tracker forwarded to the constructor.

    Returns:
        Restored environment (NOT active, use async with)
    """
    import os

    # Imported locally, as in the original implementation.
    from ares.environments import snapshot as snapshot_module

    # Anything path-like is loaded from disk; everything else is assumed to
    # already be a snapshot object. Accepting str avoids a plain string path
    # being mistaken for a snapshot and failing later with AttributeError.
    if isinstance(snapshot_path, (str, os.PathLike)):
        snap = snapshot_module.EnvironmentSnapshot.load_from_file(pathlib.Path(snapshot_path))
    else:
        snap = snapshot_path

    # The constructor needs a tasks list, so rebuild the single snapshotted task.
    task = cls._deserialize_task(snap.task_data, snap.task_type)

    if container_factory is None:
        container_factory = base.ares_daytona.DaytonaContainer
    if code_agent_factory is None:
        code_agent_factory = mini_swe_agent.MiniSWECodeAgent

    env = cls(
        tasks=[task],
        container_factory=container_factory,
        code_agent_factory=code_agent_factory,
        step_limit=snap.step_limit,
        tracker=tracker,
    )

    # Restore state using the base-class helper.
    await env._restore_from_snapshot(snap)

    return env

src/ares/environments/snapshot_test.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,3 +334,65 @@ async def test_export_state_raises_if_no_task():
334334

335335
with tempfile.TemporaryDirectory() as tmp_dir, pytest.raises(RuntimeError, match="No current task set"):
336336
await env.export_state(pathlib.Path(tmp_dir))
337+
338+
339+
@pytest.mark.asyncio
async def test_load_from_state_creates_valid_env(tmp_path: pathlib.Path):
    """Test load_from_state creates a properly initialized environment.

    Exports a snapshot from a hand-populated SweBenchEnv, restores it through
    SweBenchEnv.load_from_state, and checks the restored fields.
    """

    # Create a mock container with download_dir and upload_dir support
    class MockContainerWithDirOps(mock_container.MockContainer):
        def __init__(self):
            super().__init__()
            self.resources = None
            self.image = "python:3.12"  # Add image attribute for snapshot

        async def download_dir(self, remote_path: str, local_path: pathlib.Path):
            del remote_path  # Unused in mock
            local_path.parent.mkdir(parents=True, exist_ok=True)
            local_path.write_text("mock tarball")

        async def upload_dir(self, local_path: pathlib.Path, remote_path: str):
            """Mock upload_dir for container restoration."""
            del local_path, remote_path  # Unused in mock

    # Factory handed to load_from_state. Named distinctly so it is not
    # confused with mock_container.MockContainerFactory used below.
    class DirOpsContainerFactory:
        @classmethod
        def from_image(cls, *, image: str, name: str | None = None, resources=None):
            del image, name, resources
            return MockContainerWithDirOps()

        @classmethod
        def from_dockerfile(cls, *, dockerfile_path, name: str | None = None, resources=None):
            del dockerfile_path, name, resources
            return MockContainerWithDirOps()

    # Create and export state
    env = swebench_env.SweBenchEnv(
        tasks=[_MOCK_SWEBENCH_TASK],
        container_factory=mock_container.MockContainerFactory,
        step_limit=42,
    )

    container = MockContainerWithDirOps()
    await container.start()
    env._container = container
    env._current_task = _MOCK_SWEBENCH_TASK
    env._step_count = 7
    env._requires_reset = False
    env._code_agent_task = None

    snap = await env.export_state(tmp_path, snapshot_id="test-load")

    # Load from snapshot
    restored_env = await swebench_env.SweBenchEnv.load_from_state(snap, container_factory=DirOpsContainerFactory)

    try:
        # Verify restoration
        assert restored_env._step_count == 7
        assert restored_env._step_limit == 42
        assert restored_env._requires_reset is False
        assert restored_env._current_task.instance_id == _MOCK_SWEBENCH_TASK.instance_id
        assert restored_env._container is not None
    finally:
        # Cleanup must run even when an assertion above fails.
        await restored_env.close()

src/ares/environments/swebench_env.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import functools
1111
import json
1212
import logging
13+
import pathlib
1314
import random
1415
import time
1516
from typing import Any, Literal, cast
@@ -266,3 +267,53 @@ def _deserialize_task(cls, task_data: dict, task_type: str) -> SwebenchTask:
266267
del task_type # Unused - validated by caller
267268
# The field validators will convert JSON strings to lists
268269
return SwebenchTask.model_validate(task_data)
270+
271+
@classmethod
async def load_from_state(
    cls,
    snapshot_path: "base.snapshot.EnvironmentSnapshot | pathlib.Path | str",
    *,
    container_factory: containers.ContainerFactory | None = None,
    code_agent_factory: code_agent_base.CodeAgentFactory | None = None,
    tracker: stat_tracker.StatTracker | None = None,
) -> "SweBenchEnv":
    """Restore a SweBenchEnv from a snapshot.

    Args:
        snapshot_path: An EnvironmentSnapshot, or a filesystem path
            (``pathlib.Path``, ``str`` or any ``os.PathLike``) to a
            snapshot.json file to load.
        container_factory: Override factory. Defaults to
            ``base.ares_daytona.DaytonaContainer`` if None.
        code_agent_factory: Override factory. Defaults to
            ``mini_swe_agent.MiniSWECodeAgent`` if None.
        tracker: Optional stat tracker forwarded to the constructor.

    Returns:
        Restored environment (NOT active, use async with)
    """
    import os

    # Imported locally, as in the original implementation. pathlib is
    # already imported at module level, so no local alias is needed for it.
    from ares.environments import snapshot as snapshot_module

    # Anything path-like is loaded from disk; everything else is assumed to
    # already be a snapshot object. Accepting str avoids a plain string path
    # being mistaken for a snapshot and failing later with AttributeError.
    if isinstance(snapshot_path, (str, os.PathLike)):
        snap = snapshot_module.EnvironmentSnapshot.load_from_file(pathlib.Path(snapshot_path))
    else:
        snap = snapshot_path

    # The constructor needs a tasks list, so rebuild the single snapshotted task.
    task = cls._deserialize_task(snap.task_data, snap.task_type)

    if container_factory is None:
        container_factory = base.ares_daytona.DaytonaContainer
    if code_agent_factory is None:
        code_agent_factory = mini_swe_agent.MiniSWECodeAgent

    env = cls(
        tasks=[task],
        container_factory=container_factory,
        code_agent_factory=code_agent_factory,
        step_limit=snap.step_limit,
        tracker=tracker,
    )

    # Restore state using the base-class helper.
    await env._restore_from_snapshot(snap)

    return env

0 commit comments

Comments
 (0)