|
5 | 5 | import abc |
6 | 6 | import asyncio |
7 | 7 | import atexit |
| 8 | +import dataclasses |
| 9 | +import datetime |
8 | 10 | import functools |
9 | 11 | import logging |
10 | 12 | import os |
11 | 13 | import pathlib |
12 | 14 | import time |
13 | 15 | from types import TracebackType |
14 | | -from typing import Literal, NamedTuple, Protocol, Self |
| 16 | +from typing import Any, Literal, NamedTuple, Protocol, Self |
15 | 17 | import uuid |
16 | 18 |
|
17 | 19 | from numpy.typing import NDArray |
|
21 | 23 | from ares.containers import containers |
22 | 24 | from ares.containers import daytona as ares_daytona |
23 | 25 | from ares.environments import base |
| 26 | +from ares.environments import snapshot |
24 | 27 | from ares.experiment_tracking import stat_tracker |
25 | 28 | from ares.llms import llm_clients |
26 | 29 | from ares.llms import queue_mediated_client |
@@ -287,9 +290,11 @@ def __init__( |
287 | 290 | self._is_active = False |
288 | 291 | self._container: containers.Container | None = None |
289 | 292 | self._current_task: TaskType | None = None |
| 293 | + self._code_agent: code_agent_base.CodeAgent | None = None |
290 | 294 | self._code_agent_task: asyncio.Task[None] | None = None |
291 | 295 | self._step_count = 0 |
292 | | - self._is_active = False |
| 296 | + self._requires_reset = False |
| 297 | + self._saved_agent_messages: list[dict] = [] |
293 | 298 |
|
294 | 299 | # Register for cleanup on exit. |
295 | 300 | _ENVIRONMENT_JANITOR.register_for_cleanup(self) |
@@ -430,6 +435,197 @@ def _assert_active(self) -> None: |
430 | 435 | if not self._is_active: |
431 | 436 | raise RuntimeError("Environment is not active.") |
432 | 437 |
|
| 438 | + def _require_container(self) -> containers.Container: |
| 439 | + """Get container or raise if not available.""" |
| 440 | + if self._container is None: |
| 441 | + raise RuntimeError("Container is not available.") |
| 442 | + return self._container |
| 443 | + |
| 444 | + def _require_task(self) -> TaskType: |
| 445 | + """Get current task or raise if not available.""" |
| 446 | + if self._current_task is None: |
| 447 | + raise RuntimeError("No current task set.") |
| 448 | + return self._current_task |
| 449 | + |
| 450 | + def _validate_snapshot_allowed(self) -> None: |
| 451 | + """Raise error if snapshot not allowed (mid-episode).""" |
| 452 | + if self._code_agent_task is not None and not self._code_agent_task.done(): |
| 453 | + raise RuntimeError( |
| 454 | + "Cannot snapshot during active episode. Call export_state() after reset() or after final step()." |
| 455 | + ) |
| 456 | + |
| 457 | + def _get_task_type(self) -> Literal["swebench", "harbor"]: |
| 458 | + """Return 'swebench' or 'harbor'. Override in subclasses if needed.""" |
| 459 | + # This will be overridden in subclasses if needed |
| 460 | + raise NotImplementedError("Override _get_task_type in subclass") |
| 461 | + |
| 462 | + def _get_container_type(self, container: containers.Container) -> Literal["daytona", "docker"]: |
| 463 | + """Return 'daytona' or 'docker'.""" |
| 464 | + from ares.containers.daytona import DaytonaContainer |
| 465 | + |
| 466 | + return "daytona" if isinstance(container, DaytonaContainer) else "docker" |
| 467 | + |
| 468 | + def _get_agent_messages(self) -> list[dict]: |
| 469 | + """Get agent message history if available.""" |
| 470 | + if self._code_agent is not None and hasattr(self._code_agent, "_messages"): |
| 471 | + return list(self._code_agent._messages) |
| 472 | + return [] |
| 473 | + |
| 474 | + async def _restore_container(self, snap: snapshot.EnvironmentSnapshot) -> containers.Container: |
| 475 | + """Restore container from filesystem snapshot.""" |
| 476 | + # Create new container from original image/dockerfile |
| 477 | + if snap.container_image: |
| 478 | + container = self._container_factory.from_image( |
| 479 | + image=snap.container_image, |
| 480 | + resources=containers.Resources(**snap.container_resources) if snap.container_resources else None, |
| 481 | + ) |
| 482 | + elif snap.container_dockerfile_path: |
| 483 | + container = self._container_factory.from_dockerfile( |
| 484 | + dockerfile_path=pathlib.Path(snap.container_dockerfile_path), |
| 485 | + resources=containers.Resources(**snap.container_resources) if snap.container_resources else None, |
| 486 | + ) |
| 487 | + else: |
| 488 | + raise ValueError("Snapshot must have either image or dockerfile_path") |
| 489 | + |
| 490 | + # Start container |
| 491 | + await container.start() |
| 492 | + |
| 493 | + # Restore filesystem from tarball |
| 494 | + fs_path = snap.snapshot_dir / "container_fs.tar.gz" |
| 495 | + if fs_path.exists(): |
| 496 | + await container.upload_dir(fs_path, "/") |
| 497 | + |
| 498 | + return container |
| 499 | + |
| 500 | + @classmethod |
| 501 | + def _default_container_factory(cls, snap: snapshot.EnvironmentSnapshot) -> containers.ContainerFactory: |
| 502 | + """Create default container factory from snapshot metadata.""" |
| 503 | + del snap # Unused - could be used in future to select factory based on container_type |
| 504 | + # Default to Daytona |
| 505 | + return ares_daytona.DaytonaContainer |
| 506 | + |
| 507 | + async def export_state( |
| 508 | + self, |
| 509 | + output_dir: pathlib.Path, |
| 510 | + *, |
| 511 | + snapshot_id: str | None = None, |
| 512 | + ) -> snapshot.EnvironmentSnapshot: |
| 513 | + """Export environment state to snapshot. |
| 514 | +
|
| 515 | + Args: |
| 516 | + output_dir: Directory to save snapshot files (tarballs, metadata) |
| 517 | + snapshot_id: Optional ID (defaults to UUID) |
| 518 | +
|
| 519 | + Returns: |
| 520 | + EnvironmentSnapshot with metadata |
| 521 | +
|
| 522 | + Raises: |
| 523 | + RuntimeError: If called during active episode (running code agent) |
| 524 | + """ |
| 525 | + # Validate episode boundary |
| 526 | + self._validate_snapshot_allowed() |
| 527 | + |
| 528 | + snapshot_id = snapshot_id or str(uuid.uuid4()) |
| 529 | + snapshot_dir = output_dir / snapshot_id |
| 530 | + snapshot_dir.mkdir(parents=True, exist_ok=True) |
| 531 | + |
| 532 | + # 1. Download container filesystem |
| 533 | + container = self._require_container() |
| 534 | + fs_path = snapshot_dir / "container_fs.tar.gz" |
| 535 | + await container.download_dir("/", fs_path) |
| 536 | + |
| 537 | + # 2. Serialize task |
| 538 | + task = self._require_task() |
| 539 | + task_data = self._serialize_task(task) |
| 540 | + |
| 541 | + # 3. Get agent messages (if agent exists) |
| 542 | + agent_messages = self._get_agent_messages() |
| 543 | + |
| 544 | + # 4. Create snapshot metadata |
| 545 | + snap = snapshot.EnvironmentSnapshot( |
| 546 | + snapshot_id=snapshot_id, |
| 547 | + created_at=datetime.datetime.now().isoformat(), |
| 548 | + snapshot_dir=snapshot_dir, |
| 549 | + step_count=self._step_count, |
| 550 | + step_limit=self._step_limit, |
| 551 | + requires_reset=self._requires_reset, |
| 552 | + task_type=self._get_task_type(), |
| 553 | + task_data=task_data, |
| 554 | + container_type=self._get_container_type(container), |
| 555 | + container_image=getattr(container, "image", None), |
| 556 | + container_dockerfile_path=( |
| 557 | + str(getattr(container, "dockerfile_path", None)) if hasattr(container, "dockerfile_path") else None |
| 558 | + ), |
| 559 | + container_resources=dataclasses.asdict(container.resources) if container.resources else None, |
| 560 | + agent_messages=agent_messages, |
| 561 | + ) |
| 562 | + |
| 563 | + # Save metadata JSON |
| 564 | + snap.save_to_file(snapshot_dir / "snapshot.json") |
| 565 | + |
| 566 | + return snap |
| 567 | + |
| 568 | + @classmethod |
| 569 | + async def load_from_state( |
| 570 | + cls, |
| 571 | + snapshot_path: snapshot.EnvironmentSnapshot | pathlib.Path, |
| 572 | + *, |
| 573 | + container_factory: containers.ContainerFactory | None = None, |
| 574 | + code_agent_factory: code_agent_base.CodeAgentFactory | None = None, |
| 575 | + ) -> "CodeBaseEnv": |
| 576 | + """Restore environment from snapshot. |
| 577 | +
|
| 578 | + Args: |
| 579 | + snapshot_path: EnvironmentSnapshot or path to snapshot.json |
| 580 | + container_factory: Override factory (uses snapshot metadata if None) |
| 581 | + code_agent_factory: Override factory (uses default if None) |
| 582 | +
|
| 583 | + Returns: |
| 584 | + Restored environment (NOT active, use async with) |
| 585 | + """ |
| 586 | + # Load snapshot if path provided |
| 587 | + if isinstance(snapshot_path, pathlib.Path): |
| 588 | + snap = snapshot.EnvironmentSnapshot.load_from_file(snapshot_path) |
| 589 | + else: |
| 590 | + snap = snapshot_path |
| 591 | + |
| 592 | + # Deserialize task |
| 593 | + task = cls._deserialize_task(snap.task_data, snap.task_type) |
| 594 | + |
| 595 | + # Create environment instance |
| 596 | + container_factory = container_factory or cls._default_container_factory(snap) |
| 597 | + code_agent_factory = code_agent_factory or mini_swe_agent.MiniSWECodeAgent |
| 598 | + |
| 599 | + # Note: This creates a base CodeBaseEnv which is abstract |
| 600 | + # In practice, this should be called on SweBenchEnv or HarborEnv subclasses |
| 601 | + env = cls( |
| 602 | + container_factory=container_factory, |
| 603 | + code_agent_factory=code_agent_factory, |
| 604 | + step_limit=snap.step_limit, |
| 605 | + ) |
| 606 | + |
| 607 | + # Restore container |
| 608 | + env._container = await env._restore_container(snap) |
| 609 | + env._current_task = task |
| 610 | + env._step_count = snap.step_count |
| 611 | + env._requires_reset = snap.requires_reset |
| 612 | + |
| 613 | + # Store messages for later restoration |
| 614 | + env._saved_agent_messages = snap.agent_messages |
| 615 | + |
| 616 | + return env |
| 617 | + |
| 618 | + @abc.abstractmethod |
| 619 | + def _serialize_task(self, task: TaskType) -> dict: |
| 620 | + """Serialize task to dict. Override in subclasses.""" |
| 621 | + pass |
| 622 | + |
| 623 | + @classmethod |
| 624 | + @abc.abstractmethod |
| 625 | + def _deserialize_task(cls, task_data: dict, task_type: str) -> Any: |
| 626 | + """Deserialize task from dict. Override in subclasses.""" |
| 627 | + pass |
| 628 | + |
433 | 629 | @abc.abstractmethod |
434 | 630 | async def _reset_task(self) -> None: |
435 | 631 | """Should set `self._current_task` with a TaskType""" |
|
0 commit comments