Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions pearl/utils/instantiations/environments/gym_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pearl.utils.instantiations.spaces.discrete import DiscreteSpace
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
from torch import Tensor
import torch

try:
import gymnasium as gym
Expand Down Expand Up @@ -198,11 +199,35 @@ def _get_gym_action(
return pearl_to_gym_action_transform(pearl_action)


def _tuple_to_box_space(gym_space: gym.Space) -> BoxSpace:
    """
    Flattens a Gymnasium ``Tuple`` space into a single ``BoxSpace``.

    The bounds of the subspaces are concatenated in subspace order:

    - a ``Box`` subspace contributes its flattened ``low`` / ``high`` arrays;
    - a ``Discrete`` subspace contributes one element bounded by
      ``[start, start + n - 1]``, where ``start`` falls back to 0 when the
      subspace has no such attribute.

    Subspaces are recognised by class name rather than ``isinstance``
    (presumably so that both ``gym`` and ``gymnasium`` classes are accepted --
    TODO confirm).

    Raises:
        NotImplementedError: if a subspace is neither ``Box`` nor ``Discrete``.
    """
    assert gym_space.__class__.__name__ == "Tuple"
    bounds = []  # one (low, high) pair of 1-D tensors per subspace
    for subspace in gym_space.spaces:  # pyre-ignore[16]
        name = subspace.__class__.__name__
        if name == "Box":
            pair = (
                torch.tensor(subspace.low).flatten(),
                torch.tensor(subspace.high).flatten(),
            )
        elif name == "Discrete":
            first = getattr(subspace, "start", 0)
            last = first + subspace.n - 1
            pair = (torch.tensor([float(first)]), torch.tensor([float(last)]))
        else:
            raise NotImplementedError(f"Unsupported subspace type: {name}")
        bounds.append(pair)
    low = torch.cat([lo for lo, _ in bounds]).float()
    high = torch.cat([hi for _, hi in bounds]).float()
    return BoxSpace(low=low, high=high)


def _get_pearl_space(
gym_space: gym.Space, gym_to_pearl_map: dict[str, Any]
) -> ActionSpace:
"""Returns the Pearl action space for this environment."""
gym_space_name = gym_space.__class__.__name__
if gym_space_name == "Tuple":
return _tuple_to_box_space(gym_space)
try:
pearl_action_space_cls = gym_to_pearl_map[gym_space_name]
except KeyError:
Expand Down
17 changes: 10 additions & 7 deletions pearl/utils/instantiations/spaces/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@


def reshape_to_1d_tensor(x: Tensor) -> Tensor:
    """
    Reshapes ``x`` to a 1-D tensor.

    Scalars are expanded and ``(1, d)`` tensors are squeezed. For tensors with
    more than one dimension, the elements are flattened.

    Args:
        x: The tensor to reshape; any number of dimensions is accepted.

    Returns:
        A 1-D tensor holding the elements of ``x``. When ``x`` is already
        1-D it is returned unchanged (same object).
    """
    if x.ndim == 1:
        return x
    if x.ndim == 0:
        return x.unsqueeze(dim=0)  # scalar -> `1`
    if x.ndim == 2 and x.shape[0] == 1:
        return x.squeeze(dim=0)  # `1 x d` -> `d`
    # Any other shape: collapse all dimensions into one.
    return x.flatten()
60 changes: 60 additions & 0 deletions test/unit/test_box_space_nd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
import numpy as np
import torch

try:
import gymnasium as gym
except ModuleNotFoundError:
import gym

from pearl.utils.instantiations.environments.gym_environment import GymEnvironment
from pearl.utils.instantiations.spaces.box import BoxSpace


class DummyBoxGymEnv(gym.Env):
    """
    Minimal test environment with an N-D ``uint8`` ``Box`` observation space.

    Every observation is an all-zero array of the configured shape, and the
    action space is ``Discrete(2)``.
    """

    def __init__(self, shape: tuple[int, ...]):
        super().__init__()
        self._shape = shape
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=shape, dtype=np.uint8
        )
        self.action_space = gym.spaces.Discrete(2)

    def _blank_observation(self):
        # All-zero observation of the configured shape, shared by reset/step.
        return np.zeros(self._shape, dtype=np.uint8)

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        return self._blank_observation(), {}

    def step(self, action):
        # Per the Gymnasium step API the tuple is
        # (observation, reward, terminated, truncated, info).
        return self._blank_observation(), 0.0, True, False, {}


class TestBoxSpaceND(unittest.TestCase):
    """Checks that ``GymEnvironment`` keeps N-D ``Box`` observation shapes."""

    def _run_shape_test(self, shape: tuple[int, ...]) -> None:
        """
        Wraps a ``DummyBoxGymEnv`` of the given ``shape`` and asserts that the
        observation space, the reset observation and the step observation all
        report that shape.
        """
        env = GymEnvironment(DummyBoxGymEnv(shape))
        self.assertIsInstance(env.observation_space, BoxSpace)
        self.assertEqual(env.observation_space.shape, torch.Size(shape))
        obs, _ = env.reset()
        self.assertEqual(obs.shape, shape)
        step_result = env.step(torch.tensor(0))
        self.assertEqual(step_result.observation.shape, shape)

    def test_box_space_3d_observation(self) -> None:
        # 3-D observation shape.
        self._run_shape_test((2, 3, 4))

    def test_box_space_4d_observation(self) -> None:
        # 4-D observation shape.
        self._run_shape_test((2, 3, 4, 5))


if __name__ == "__main__":
unittest.main()
52 changes: 52 additions & 0 deletions test/unit/test_tuple_space_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
import torch

try:
import gymnasium as gym
except ModuleNotFoundError:
import gym

from pearl.utils.instantiations.environments.gym_environment import GymEnvironment
from pearl.utils.instantiations.spaces.box import BoxSpace


class DummyTupleGymEnv(gym.Env):
    """
    Minimal test environment whose observation space is a ``Tuple`` of two
    ``Discrete`` spaces (sizes 2 and 3). It always emits the same observation.
    """

    # Fixed observation returned by both ``reset`` and ``step``.
    _OBSERVATION = (0, 1)

    def __init__(self):
        super().__init__()
        self.observation_space = gym.spaces.Tuple(
            (gym.spaces.Discrete(2), gym.spaces.Discrete(3))
        )
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        return self._OBSERVATION, {}

    def step(self, action):
        # Per the Gymnasium step API the tuple is
        # (observation, reward, terminated, truncated, info).
        return self._OBSERVATION, 0.0, True, False, {}


class TestTupleSpaceSupport(unittest.TestCase):
    """Checks that ``GymEnvironment`` accepts a ``Tuple`` observation space."""

    def test_gym_environment_tuple_observation(self) -> None:
        env = GymEnvironment(DummyTupleGymEnv())
        # The two Discrete subspaces should map to a BoxSpace whose first
        # dimension has size 2 (one entry per subspace).
        self.assertIsInstance(env.observation_space, BoxSpace)
        self.assertEqual(env.observation_space.shape[0], 2)
        # Observations are expected to come back as the tuple the wrapped
        # environment emitted.
        obs, _ = env.reset()
        self.assertEqual(obs, (0, 1))
        step_result = env.step(torch.tensor(0))
        self.assertEqual(step_result.observation, (0, 1))


if __name__ == "__main__":
unittest.main()
Loading