Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,27 @@ __pycache__/
.Python

.vscode
.cursorindexingignore

# Claude Code worktree config — per-session local state, not part of the project
.claude/

# Jupyter Notebook
.ipynb_checkpoints

# Sandbox additions (not upstream)
.venv/
# .venv as a bare name covers both a real directory and a symlink to a
# shared parent venv (a trailing-slash pattern misses the symlink case).
.venv
recordings/
# tensorboard event files live in subdirs of runs/ (e.g. runs/<exp>/events.out.tfevents.*)
runs/**/events.out.tfevents.*
# per-training-job stdout logs written by frontend/server.py on /api/training/start
runs/_training_*.log
# Per-run summary CSVs that deep_rl_zoo.<algo>.run_atari appends to on
# every training run. Untracked so any TRAIN keeps the working tree
# clean — historical snapshots remain in git history if needed.
logs/*_atari_results.csv
# new training checkpoints; bundled ones are kept via !-rules below
checkpoints/*.ckpt
!checkpoints/IQN_Pong_2.ckpt
Expand Down
9 changes: 9 additions & 0 deletions deep_rl_zoo/gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@
import os
import datetime
import numpy as np

# NumPy 2.0 removed np.bool8; gym 0.25.2 (pinned by this repo) still references
# it in its passive_env_checker.py during env.step. The upstream README §9
# notes this as a known deprecation warning, but on NumPy 2.0+ the attribute
# is gone entirely and gym blows up. Restore the alias so every entry point
# (run_atari / run_classic / eval_agent / frontend.stream_eval) is covered.
if not hasattr(np, "bool8"):
np.bool8 = np.bool_ # type: ignore[attr-defined]

import cv2
import logging
import gym
Expand Down
28 changes: 27 additions & 1 deletion frontend/stream_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@
from typing import NamedTuple

import numpy as np

# NumPy 2.0 removed np.bool8; gym 0.25.2 (pinned by this repo) still references
# it in its step/reset paths. Restore the alias before any gym import.
if not hasattr(np, "bool8"):
np.bool8 = np.bool_ # type: ignore[attr-defined]

import torch

import gym
Expand Down Expand Up @@ -99,7 +105,7 @@ def _build_iqn(state_dim, action_dim, device, checkpoint_path, env_name) -> tupl
return network, actor


def _build_prioritized_dqn(state_dim, action_dim, device, checkpoint_path, env_name) -> tuple:
def _build_dqn(state_dim, action_dim, device, checkpoint_path, env_name) -> tuple:
network = DqnConvNet(state_dim=state_dim, action_dim=action_dim)
ckpt = PyTorchCheckpoint(environment_name=env_name, agent_name="DQN", restore_only=True)
ckpt.register_pair(("network", network))
Expand All @@ -114,6 +120,25 @@ def _build_prioritized_dqn(state_dim, action_dim, device, checkpoint_path, env_n
return network, actor


def _build_prioritized_dqn(state_dim, action_dim, device, checkpoint_path, env_name) -> tuple:
network = DqnConvNet(state_dim=state_dim, action_dim=action_dim)
# The bundled PER-DQN_Pong_4.ckpt is stamped with agent_name="PER-DQN"
# (matches deep_rl_zoo/prioritized_dqn/run_atari.py line 96). The previous
# "DQN" value here caused PyTorchCheckpoint.restore to reject the file with
# 'agent_name "PER-DQN" and "DQN" mismatch.'
ckpt = PyTorchCheckpoint(environment_name=env_name, agent_name="PER-DQN", restore_only=True)
ckpt.register_pair(("network", network))
ckpt.restore(checkpoint_path)
network.eval()
actor = greedy_actors.EpsilonGreedyActor(
network=network,
exploration_epsilon=0.01,
random_state=np.random.RandomState(0),
device=device,
)
return network, actor


def _build_rainbow(state_dim, action_dim, device, checkpoint_path, env_name) -> tuple:
atoms = torch.linspace(-10.0, 10.0, 51)
network = RainbowDqnConvNet(state_dim=state_dim, action_dim=action_dim, atoms=atoms)
Expand Down Expand Up @@ -145,6 +170,7 @@ def _build_ppo_rnd(state_dim, action_dim, device, checkpoint_path, env_name) ->


ALGO_FACTORIES: dict[str, Callable] = {
"dqn": _build_dqn,
"iqn": _build_iqn,
"prioritized_dqn": _build_prioritized_dqn,
"rainbow": _build_rainbow,
Expand Down
11 changes: 0 additions & 11 deletions logs/dqn_atari_results.csv

This file was deleted.

4 changes: 0 additions & 4 deletions logs/iqn_atari_results.csv

This file was deleted.

11 changes: 0 additions & 11 deletions logs/per_dqn_atari_results.csv

This file was deleted.

Loading
Loading