Skip to content

Commit abae672

Browse files
committed
Add environment state snapshotting for RL research
Implements snapshot/restore functionality to save and replay episodes from specific checkpoints. Useful for debugging, trajectory analysis, and mechanistic interpretability. - Add EnvironmentSnapshot dataclass for serializing env state - Implement export_state() and load_from_state() methods on CodeBaseEnv - Support both SWEBench and Harbor environments - Save container filesystem as tarball with JSON metadata - Snapshots only work at episode boundaries (can't snapshot mid-episode) - Add comprehensive test coverage Closes withmartian#39
1 parent 430ef60 commit abae672

6 files changed

Lines changed: 857 additions & 3 deletions

File tree

examples/03_state_snapshotting.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
"""Example demonstrating environment state snapshotting and restoration.
2+
3+
This example shows how to:
4+
1. Create a snapshot after reset (at episode boundary)
5+
2. Save the snapshot to disk
6+
3. Restore an environment from a saved snapshot
7+
4. Continue execution from the restored state
8+
9+
Example usage:
10+
11+
1. Make sure you have examples dependencies installed
12+
`uv sync --group examples`
13+
2. Run the example
14+
`uv run -m examples.03_state_snapshotting`
15+
"""
16+
17+
import asyncio
18+
import pathlib
19+
import tempfile
20+
21+
from ares.code_agents import mini_swe_agent
22+
from ares.containers import docker
23+
from ares.environments import snapshot
24+
from ares.environments import swebench_env
25+
from ares.llms import chat_completions_compatible
26+
27+
28+
async def main():
29+
# Create an LLM client
30+
agent = chat_completions_compatible.ChatCompletionCompatibleLLMClient(model="openai/gpt-4o-mini")
31+
32+
# Load SWE-bench tasks
33+
all_tasks = swebench_env.swebench_verified_tasks()
34+
tasks = [all_tasks[0]]
35+
36+
print(f"Running on task: {tasks[0].instance_id}")
37+
print(f"Repository: {tasks[0].repo}")
38+
print("-" * 80)
39+
40+
# Create a temporary directory for snapshots
41+
with tempfile.TemporaryDirectory() as snapshot_dir:
42+
snapshot_path = pathlib.Path(snapshot_dir)
43+
44+
# === PART 1: Create and save a snapshot ===
45+
print("\n[PART 1] Creating initial environment and snapshot...")
46+
47+
async with swebench_env.SweBenchEnv(
48+
tasks=tasks,
49+
code_agent_factory=mini_swe_agent.MiniSWECodeAgent,
50+
container_factory=docker.DockerContainer,
51+
) as env:
52+
# Reset the environment to get the first timestep
53+
ts = await env.reset()
54+
print(f"Environment reset complete. Step count: {env._step_count}")
55+
56+
# Take a few steps before snapshotting
57+
for i in range(3):
58+
action = await agent(ts.observation)
59+
print(f" Step {i}: Taking action...")
60+
ts = await env.step(action)
61+
62+
if ts.last():
63+
print(" Episode completed early")
64+
break
65+
66+
print(f"Current step count: {env._step_count}")
67+
68+
# Wait for agent to finish current operation (reach episode boundary)
69+
# In practice, you'd snapshot after step() returns with done=True
70+
# or after reset() completes. For this example, we'll simulate
71+
# waiting for agent to finish.
72+
if not ts.last():
73+
print("\n Note: For snapshotting, we need to be at episode boundary.")
74+
print(" Cancelling agent task to reach boundary...")
75+
if env._code_agent_task and not env._code_agent_task.done():
76+
env._code_agent_task.cancel()
77+
import contextlib
78+
79+
with contextlib.suppress(asyncio.CancelledError):
80+
await env._code_agent_task
81+
82+
# Now we can export state (at episode boundary)
83+
print("\n Exporting state snapshot...")
84+
snap = await env.export_state(snapshot_path, snapshot_id="example-snapshot")
85+
86+
print(f" ✓ Snapshot created: {snap.snapshot_id}")
87+
print(f" ✓ Snapshot saved to: {snap.snapshot_dir}")
88+
print(f" ✓ Step count in snapshot: {snap.step_count}")
89+
print(f" ✓ Task type: {snap.task_type}")
90+
print(f" ✓ Container type: {snap.container_type}")
91+
92+
# === PART 2: Restore from snapshot ===
93+
print("\n[PART 2] Restoring environment from snapshot...")
94+
95+
# Load snapshot metadata
96+
snapshot_file = snapshot_path / "example-snapshot" / "snapshot.json"
97+
loaded_snap = snapshot.EnvironmentSnapshot.load_from_file(snapshot_file)
98+
99+
print(f" ✓ Loaded snapshot: {loaded_snap.snapshot_id}")
100+
print(f" ✓ Original step count: {loaded_snap.step_count}")
101+
102+
# Restore environment from snapshot
103+
# Note: This creates a new environment instance with the saved state
104+
restored_env = await swebench_env.SweBenchEnv.load_from_state(
105+
loaded_snap,
106+
container_factory=docker.DockerContainer,
107+
code_agent_factory=mini_swe_agent.MiniSWECodeAgent,
108+
)
109+
110+
print(" ✓ Environment restored")
111+
print(f" ✓ Restored step count: {restored_env._step_count}")
112+
print(f" ✓ Task: {restored_env._current_task.instance_id}")
113+
114+
# Use the restored environment in async context
115+
async with restored_env:
116+
print("\n[PART 3] Continuing from restored state...")
117+
118+
# The environment is now at the same state as when we snapshotted
119+
# We can continue taking steps from here
120+
ts = await restored_env.reset() # Reset to start a new episode
121+
step_count = 0
122+
123+
# Take a few more steps to demonstrate it works
124+
while not ts.last() and step_count < 3:
125+
action = await agent(ts.observation)
126+
print(f" Step {step_count}: Taking action from restored env...")
127+
ts = await restored_env.step(action)
128+
step_count += 1
129+
130+
print(f"\n ✓ Completed {step_count} additional steps from restored state")
131+
132+
print("\n" + "=" * 80)
133+
print("Snapshot example completed successfully!")
134+
print("=" * 80)
135+
print("\nKey takeaways:")
136+
print(" 1. Snapshots can only be taken at episode boundaries")
137+
print(" 2. Snapshots save: task state, container filesystem, agent messages")
138+
print(" 3. Restored environments can continue execution normally")
139+
print(" 4. Use cases: debugging, RL replay, mechanistic interpretability")
140+
141+
142+
if __name__ == "__main__":
143+
asyncio.run(main())

src/ares/environments/base.py

Lines changed: 198 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
import abc
66
import asyncio
77
import atexit
8+
import dataclasses
9+
import datetime
810
import functools
911
import logging
1012
import os
1113
import pathlib
1214
import time
1315
from types import TracebackType
14-
from typing import Literal, NamedTuple, Protocol, Self
16+
from typing import Any, Literal, NamedTuple, Protocol, Self
1517
import uuid
1618

1719
from numpy.typing import NDArray
@@ -21,6 +23,7 @@
2123
from ares.containers import containers
2224
from ares.containers import daytona as ares_daytona
2325
from ares.environments import base
26+
from ares.environments import snapshot
2427
from ares.experiment_tracking import stat_tracker
2528
from ares.llms import llm_clients
2629
from ares.llms import queue_mediated_client
@@ -287,9 +290,11 @@ def __init__(
287290
self._is_active = False
288291
self._container: containers.Container | None = None
289292
self._current_task: TaskType | None = None
293+
self._code_agent: code_agent_base.CodeAgent | None = None
290294
self._code_agent_task: asyncio.Task[None] | None = None
291295
self._step_count = 0
292-
self._is_active = False
296+
self._requires_reset = False
297+
self._saved_agent_messages: list[dict] = []
293298

294299
# Register for cleanup on exit.
295300
_ENVIRONMENT_JANITOR.register_for_cleanup(self)
@@ -430,6 +435,197 @@ def _assert_active(self) -> None:
430435
if not self._is_active:
431436
raise RuntimeError("Environment is not active.")
432437

438+
def _require_container(self) -> containers.Container:
439+
"""Get container or raise if not available."""
440+
if self._container is None:
441+
raise RuntimeError("Container is not available.")
442+
return self._container
443+
444+
def _require_task(self) -> TaskType:
445+
"""Get current task or raise if not available."""
446+
if self._current_task is None:
447+
raise RuntimeError("No current task set.")
448+
return self._current_task
449+
450+
def _validate_snapshot_allowed(self) -> None:
451+
"""Raise error if snapshot not allowed (mid-episode)."""
452+
if self._code_agent_task is not None and not self._code_agent_task.done():
453+
raise RuntimeError(
454+
"Cannot snapshot during active episode. Call export_state() after reset() or after final step()."
455+
)
456+
457+
def _get_task_type(self) -> Literal["swebench", "harbor"]:
458+
"""Return 'swebench' or 'harbor'. Override in subclasses if needed."""
459+
# This will be overridden in subclasses if needed
460+
raise NotImplementedError("Override _get_task_type in subclass")
461+
462+
def _get_container_type(self, container: containers.Container) -> Literal["daytona", "docker"]:
463+
"""Return 'daytona' or 'docker'."""
464+
from ares.containers.daytona import DaytonaContainer
465+
466+
return "daytona" if isinstance(container, DaytonaContainer) else "docker"
467+
468+
def _get_agent_messages(self) -> list[dict]:
469+
"""Get agent message history if available."""
470+
if self._code_agent is not None and hasattr(self._code_agent, "_messages"):
471+
return list(self._code_agent._messages)
472+
return []
473+
474+
async def _restore_container(self, snap: snapshot.EnvironmentSnapshot) -> containers.Container:
475+
"""Restore container from filesystem snapshot."""
476+
# Create new container from original image/dockerfile
477+
if snap.container_image:
478+
container = self._container_factory.from_image(
479+
image=snap.container_image,
480+
resources=containers.Resources(**snap.container_resources) if snap.container_resources else None,
481+
)
482+
elif snap.container_dockerfile_path:
483+
container = self._container_factory.from_dockerfile(
484+
dockerfile_path=pathlib.Path(snap.container_dockerfile_path),
485+
resources=containers.Resources(**snap.container_resources) if snap.container_resources else None,
486+
)
487+
else:
488+
raise ValueError("Snapshot must have either image or dockerfile_path")
489+
490+
# Start container
491+
await container.start()
492+
493+
# Restore filesystem from tarball
494+
fs_path = snap.snapshot_dir / "container_fs.tar.gz"
495+
if fs_path.exists():
496+
await container.upload_dir(fs_path, "/")
497+
498+
return container
499+
500+
@classmethod
501+
def _default_container_factory(cls, snap: snapshot.EnvironmentSnapshot) -> containers.ContainerFactory:
502+
"""Create default container factory from snapshot metadata."""
503+
del snap # Unused - could be used in future to select factory based on container_type
504+
# Default to Daytona
505+
return ares_daytona.DaytonaContainer
506+
507+
async def export_state(
508+
self,
509+
output_dir: pathlib.Path,
510+
*,
511+
snapshot_id: str | None = None,
512+
) -> snapshot.EnvironmentSnapshot:
513+
"""Export environment state to snapshot.
514+
515+
Args:
516+
output_dir: Directory to save snapshot files (tarballs, metadata)
517+
snapshot_id: Optional ID (defaults to UUID)
518+
519+
Returns:
520+
EnvironmentSnapshot with metadata
521+
522+
Raises:
523+
RuntimeError: If called during active episode (running code agent)
524+
"""
525+
# Validate episode boundary
526+
self._validate_snapshot_allowed()
527+
528+
snapshot_id = snapshot_id or str(uuid.uuid4())
529+
snapshot_dir = output_dir / snapshot_id
530+
snapshot_dir.mkdir(parents=True, exist_ok=True)
531+
532+
# 1. Download container filesystem
533+
container = self._require_container()
534+
fs_path = snapshot_dir / "container_fs.tar.gz"
535+
await container.download_dir("/", fs_path)
536+
537+
# 2. Serialize task
538+
task = self._require_task()
539+
task_data = self._serialize_task(task)
540+
541+
# 3. Get agent messages (if agent exists)
542+
agent_messages = self._get_agent_messages()
543+
544+
# 4. Create snapshot metadata
545+
snap = snapshot.EnvironmentSnapshot(
546+
snapshot_id=snapshot_id,
547+
created_at=datetime.datetime.now().isoformat(),
548+
snapshot_dir=snapshot_dir,
549+
step_count=self._step_count,
550+
step_limit=self._step_limit,
551+
requires_reset=self._requires_reset,
552+
task_type=self._get_task_type(),
553+
task_data=task_data,
554+
container_type=self._get_container_type(container),
555+
container_image=getattr(container, "image", None),
556+
container_dockerfile_path=(
557+
str(getattr(container, "dockerfile_path", None)) if hasattr(container, "dockerfile_path") else None
558+
),
559+
container_resources=dataclasses.asdict(container.resources) if container.resources else None,
560+
agent_messages=agent_messages,
561+
)
562+
563+
# Save metadata JSON
564+
snap.save_to_file(snapshot_dir / "snapshot.json")
565+
566+
return snap
567+
568+
@classmethod
569+
async def load_from_state(
570+
cls,
571+
snapshot_path: snapshot.EnvironmentSnapshot | pathlib.Path,
572+
*,
573+
container_factory: containers.ContainerFactory | None = None,
574+
code_agent_factory: code_agent_base.CodeAgentFactory | None = None,
575+
) -> "CodeBaseEnv":
576+
"""Restore environment from snapshot.
577+
578+
Args:
579+
snapshot_path: EnvironmentSnapshot or path to snapshot.json
580+
container_factory: Override factory (uses snapshot metadata if None)
581+
code_agent_factory: Override factory (uses default if None)
582+
583+
Returns:
584+
Restored environment (NOT active, use async with)
585+
"""
586+
# Load snapshot if path provided
587+
if isinstance(snapshot_path, pathlib.Path):
588+
snap = snapshot.EnvironmentSnapshot.load_from_file(snapshot_path)
589+
else:
590+
snap = snapshot_path
591+
592+
# Deserialize task
593+
task = cls._deserialize_task(snap.task_data, snap.task_type)
594+
595+
# Create environment instance
596+
container_factory = container_factory or cls._default_container_factory(snap)
597+
code_agent_factory = code_agent_factory or mini_swe_agent.MiniSWECodeAgent
598+
599+
# Note: This creates a base CodeBaseEnv which is abstract
600+
# In practice, this should be called on SweBenchEnv or HarborEnv subclasses
601+
env = cls(
602+
container_factory=container_factory,
603+
code_agent_factory=code_agent_factory,
604+
step_limit=snap.step_limit,
605+
)
606+
607+
# Restore container
608+
env._container = await env._restore_container(snap)
609+
env._current_task = task
610+
env._step_count = snap.step_count
611+
env._requires_reset = snap.requires_reset
612+
613+
# Store messages for later restoration
614+
env._saved_agent_messages = snap.agent_messages
615+
616+
return env
617+
618+
@abc.abstractmethod
619+
def _serialize_task(self, task: TaskType) -> dict:
620+
"""Serialize task to dict. Override in subclasses."""
621+
pass
622+
623+
@classmethod
624+
@abc.abstractmethod
625+
def _deserialize_task(cls, task_data: dict, task_type: str) -> Any:
626+
"""Deserialize task from dict. Override in subclasses."""
627+
pass
628+
433629
@abc.abstractmethod
434630
async def _reset_task(self) -> None:
435631
"""Should set `self._current_task` with a TaskType"""

0 commit comments

Comments
 (0)