71 changes: 36 additions & 35 deletions ptan/common/wrappers.py
@@ -2,8 +2,8 @@
# Mostly copy-pasted from https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
import numpy as np
from collections import deque
import gym
from gym import spaces
import gymnasium as gym
from gymnasium import spaces
import cv2


@@ -20,20 +20,20 @@ def __init__(self, env=None, noop_max=30):
def step(self, action):
return self.env.step(action)

def reset(self):
def reset(self, **kwargs):
""" Do no-op action for a number of steps in [1, noop_max]."""
self.env.reset()
self.env.reset(**kwargs)
if self.override_num_noops is not None:
noops = self.override_num_noops
else:
noops = np.random.randint(1, self.noop_max + 1)
assert noops > 0
obs = None
for _ in range(noops):
obs, _, done, _ = self.env.step(0)
obs, _, done, _, info = self.env.step(0)
if done:
obs = self.env.reset()
return obs
obs, info = self.env.reset()
return obs, info


class FireResetEnv(gym.Wrapper):
@@ -46,15 +46,15 @@ def __init__(self, env=None):
def step(self, action):
return self.env.step(action)

def reset(self):
self.env.reset()
obs, _, done, _ = self.env.step(1)
def reset(self, **kwargs):
self.env.reset(**kwargs)
obs, _, done, _, info = self.env.step(1)
if done:
self.env.reset()
obs, _, done, _ = self.env.step(2)
obs, _, done, _, info = self.env.step(2)
if done:
self.env.reset()
return obs
return obs, info


class EpisodicLifeEnv(gym.Wrapper):
@@ -68,33 +68,32 @@ def __init__(self, env=None):
self.was_real_reset = False

def step(self, action):
obs, reward, done, info = self.env.step(action)
obs, reward, done, truncated, info = self.env.step(action)
self.was_real_done = done
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
if lives < self.lives and lives > 0:
# for Qbert sometimes we stay in lives == 0 condition for a few frames
# so its important to keep lives > 0, so that we only reset once
# the environment advertises done.
# so it's important to keep lives > 0, so that we only reset once the environment advertises done.
done = True
self.lives = lives
return obs, reward, done, info
return obs, reward, done, truncated, info

def reset(self):
def reset(self, **kwargs):
"""Reset only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
"""
if self.was_real_done:
obs = self.env.reset()
obs, info = self.env.reset(**kwargs)
self.was_real_reset = True
else:
# no-op step to advance from terminal/lost life state
obs, _, _, _ = self.env.step(0)
obs, _, _, _, info = self.env.step(0)
self.was_real_reset = False
self.lives = self.env.unwrapped.ale.lives()
return obs
return obs, info


class MaxAndSkipEnv(gym.Wrapper):
@@ -109,22 +108,20 @@ def step(self, action):
total_reward = 0.0
done = None
for _ in range(self._skip):
obs, reward, done, info = self.env.step(action)
obs, reward, done, truncated, info = self.env.step(action)
self._obs_buffer.append(obs)
total_reward += reward
if done:
break

max_frame = np.max(np.stack(self._obs_buffer), axis=0)
return max_frame, total_reward, done, info
return max_frame, total_reward, done, truncated, info

def reset(self):
def reset(self, **kwargs):
"""Clear past frame buffer and init. to first obs. from inner env."""
self._obs_buffer.clear()
obs = self.env.reset()
obs, info = self.env.reset(**kwargs)
self._obs_buffer.append(obs)
return obs
return obs, info


class ProcessFrame84(gym.ObservationWrapper):
@@ -184,18 +181,19 @@ def __init__(self, env, k):
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0]*k, shp[1], shp[2]), dtype=np.float32)
self.observation_space = spaces.Box(
low=0, high=255, shape=(shp[0]*k, shp[1], shp[2]), dtype=np.float32)

def reset(self):
ob = self.env.reset()
def reset(self, **kwargs):
ob, info = self.env.reset(**kwargs)
for _ in range(self.k):
self.frames.append(ob)
return self._get_ob()
return self._get_ob(), info

def step(self, action):
ob, reward, done, info = self.env.step(action)
ob, reward, done, truncated, info = self.env.step(action)
self.frames.append(ob)
return self._get_ob(), reward, done, info
return self._get_ob(), reward, done, truncated, info

def _get_ob(self):
assert len(self.frames) == self.k
@@ -216,13 +214,15 @@ class ImageToPyTorch(gym.ObservationWrapper):
def __init__(self, env):
super(ImageToPyTorch, self).__init__(env)
old_shape = self.observation_space.shape
self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]),
dtype=np.float32)
new_shape = (old_shape[-1], old_shape[0], old_shape[1])
self.observation_space = gym.spaces.Box(
low=0.0, high=1.0, shape=new_shape, dtype=np.float32)

def observation(self, observation):
return np.swapaxes(observation, 2, 0)



def wrap_dqn(env, stack_frames=4, episodic_life=True, reward_clipping=True):
"""Apply a common set of wrappers for Atari games."""
assert 'NoFrameskip' in env.spec.id
@@ -237,4 +237,5 @@ def wrap_dqn(env, stack_frames=4, episodic_life=True, reward_clipping=True):
env = FrameStack(env, stack_frames)
if reward_clipping:
env = ClippedRewardsWrapper(env)

return env
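
For reference, the wrappers above now target gymnasium's API: reset() returns an (obs, info) pair and step() returns a five-tuple of (obs, reward, terminated, truncated, info). A minimal usage sketch of the migrated interface, assuming an Atari-enabled install where "PongNoFrameskip-v4" is registered (the env id is illustrative):

import gymnasium as gym

env = gym.make("PongNoFrameskip-v4")
obs, info = env.reset()                  # reset() now returns (obs, info)
done = False
while not done:
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    # gymnasium splits the old done flag in two; an episode ends on either
    done = terminated or truncated
env.close()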
4 changes: 2 additions & 2 deletions ptan/common/wrappers_simple.py
@@ -1,7 +1,7 @@
"""
Simple wrappers
"""
import gym
import gymnasium as gym
import numpy as np
import collections

@@ -24,7 +24,7 @@ def __init__(self, env, k):
dtype=env.observation_space.dtype)

def reset(self):
ob = self.env.reset()
ob = self.env.reset()[0]
for _ in range(self.k):
self.frames.append(ob)
return self._get_ob()
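Note that this simpler FrameStack keeps its legacy single-value reset by indexing [0] into gymnasium's (obs, info) pair, silently dropping the info dict. A fully gymnasium-compliant variant would propagate it as well; a sketch of what that could look like (not part of this diff):

def reset(self, **kwargs):
    # keep both halves of gymnasium's (obs, info) return instead of indexing [0]
    ob, info = self.env.reset(**kwargs)
    for _ in range(self.k):
        self.frames.append(ob)
    return self._get_ob(), info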
16 changes: 8 additions & 8 deletions ptan/experience.py
@@ -1,4 +1,4 @@
import gym
import gymnasium as gym
import torch
import random
import collections
@@ -50,7 +50,7 @@ def __iter__(self):
states, agent_states, histories, cur_rewards, cur_steps = [], [], [], [], []
env_lens = []
for env in self.pool:
obs = env.reset()
obs, info = env.reset()
# if the environment is vectorized, all its output is lists of results.
# Details are here: https://github.com/openai/universe/blob/master/doc/env_semantics.rst
if self.vectorized:
@@ -89,9 +89,9 @@ def __iter__(self):
global_ofs = 0
for env_idx, (env, action_n) in enumerate(zip(self.pool, grouped_actions)):
if self.vectorized:
next_state_n, r_n, is_done_n, _ = env.step(action_n)
next_state_n, r_n, is_done_n, _, _ = env.step(action_n)
else:
next_state, r, is_done, _ = env.step(action_n[0])
next_state, r, is_done, _, _ = env.step(action_n[0])
next_state_n, r_n, is_done_n = [next_state], [r], [is_done]

for ofs, (action, next_state, r, is_done) in enumerate(zip(action_n, next_state_n, r_n, is_done_n)):
@@ -119,7 +119,7 @@ def __iter__(self):
cur_rewards[idx] = 0.0
cur_steps[idx] = 0
# vectorized envs are reset automatically
states[idx] = env.reset() if not self.vectorized else None
states[idx] = env.reset()[0] if not self.vectorized else None
agent_states[idx] = self.agent.initial_state()
history.clear()
global_ofs += len(action_n)
@@ -233,7 +233,7 @@ def __init__(self, env, agent, gamma, steps_count=5):

def __iter__(self):
pool_size = len(self.pool)
states = [np.array(e.reset()) for e in self.pool]
states = [np.array(e.reset()[0]) for e in self.pool]
mb_states = np.zeros((pool_size, self.steps_count) + states[0].shape, dtype=states[0].dtype)
mb_rewards = np.zeros((pool_size, self.steps_count), dtype=np.float32)
mb_values = np.zeros((pool_size, self.steps_count), dtype=np.float32)
@@ -250,11 +250,11 @@ def __iter__(self):
dones = []
new_states = []
for env_idx, (e, action) in enumerate(zip(self.pool, actions)):
o, r, done, _ = e.step(action)
o, r, done, _, _ = e.step(action)
total_rewards[env_idx] += r
total_steps[env_idx] += 1
if done:
o = e.reset()
o = e.reset()[0]
self.total_rewards.append(total_rewards[env_idx])
self.total_steps.append(total_steps[env_idx])
total_rewards[env_idx] = 0.0
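The experience sources above unpack gymnasium's five-tuple but discard the truncated flag, so is_done reflects termination only. Since gymnasium's TimeLimit wrapper now signals episode cut-offs via truncated rather than done, a variant that also ends episodes on truncation might fold the two flags together; a sketch with illustrative names:

# Sketch only: combine gymnasium's terminated/truncated pair into one flag.
next_state, r, terminated, truncated, _ = env.step(action_n[0])
is_done = terminated or truncated  # treat time-limit truncation as episode end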
4 changes: 2 additions & 2 deletions requirements.txt
@@ -3,5 +3,5 @@ opencv-python
gym
atari-py
nose
torch==1.7.0
torch-ignite==0.4.2
torch==1.13.0
torch-ignite==0.4.12
4 changes: 2 additions & 2 deletions setup.py
@@ -4,7 +4,7 @@
import setuptools


requirements = ['torch==1.7.0', 'gym', 'atari-py', 'numpy', 'opencv-python']
requirements = ['torch==1.13.0', 'gymnasium', 'atari-py', 'numpy', 'opencv-python']


setuptools.setup(
@@ -13,7 +13,7 @@
author_email="max.lapan@gmail.com",
license='GPL-v3',
description="PyTorch reinforcement learning framework",
version="0.7",
version="1.0",
packages=setuptools.find_packages(),
install_requires=requirements,
)