diff --git a/.github/workflows/unit-tests-docker.yml b/.github/workflows/unit-tests-docker.yml
new file mode 100644
index 0000000..3158532
--- /dev/null
+++ b/.github/workflows/unit-tests-docker.yml
@@ -0,0 +1,34 @@
+name: Unit Tests
+on:
+  push:
+    branches: [ "master", "develop" ]
+  pull_request:
+    branches: [ "master", "develop" ]
+
+jobs:
+  unit_tests:
+    name: Unit Tests
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+
+      - name: Build Docker image
+        run: docker build -t oprl .
+
+      - name: Unit Tests
+        run: docker run --rm oprl
+
+      - name: Extract coverage
+        run: |
+          docker run --rm -v $(pwd):/host oprl sh -c "
+            pytest --cov=oprl --cov-report=xml &&
+            cp coverage.xml /host/
+          "
+
+      - name: Upload coverage
+        uses: codecov/codecov-action@v5
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+          slug: schatty/oprl
+          files: ./coverage.xml
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..029e0f1
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,12 @@
+FROM python:3.10.8
+
+WORKDIR /app
+
+RUN pip install --no-cache-dir --upgrade pip
+
+COPY . .
+
+RUN pip install --no-cache-dir . && pip install --no-cache-dir pytest pytest-cov
+
+# Run tests by default
+CMD ["pytest", "tests/functional"]
diff --git a/README.md b/README.md
index c9f6eaf..4a64378 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,8 @@
 A Modular Library for Off-Policy Reinforcement Learning with a focus on SafeRL and distributed computing. Benchmarking results are available on the associated homepage: [Homepage](https://schatty.github.io/oprl/)
 
 [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+[![codecov](https://codecov.io/gh/schatty/oprl/branch/master/graph/badge.svg)](https://codecov.io/gh/schatty/oprl)
+
 
 
 # Disclaimer
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..f531d30
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,44 @@
+[build-system]
+requires = ["setuptools>=64"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "oprl"
+version = "0.1.0"
+description = "A modular library for off-policy reinforcement learning"
+readme = "README.md"
+requires-python = "==3.10.8"
+license = {text = "MIT"}
+authors = [
+    {name = "Igor Kuznetsov"},
+]
+classifiers = [
+    "Development Status :: 3 - Alpha",
+    "Intended Audience :: Developers",
+    "Programming Language :: Python :: 3.10",
+]
+dependencies = [
+    "torch==2.2.2",
+    "tensorboard==2.15.1",
+    "packaging==23.2",
+    "dm-control==1.0.11",
+    "mujoco==2.3.3",
+    "numpy==1.26.4",
+]
+
+[project.optional-dependencies]
+dev = [
+    "pytest>=6.0",
+    "black",
+    "flake8",
+]
+
+[project.urls]
+"Homepage" = "https://schatty.github.io/oprl"
+
+[tool.setuptools.packages.find]
+where = ["src"]
+include = ["oprl*"]
+
+[tool.setuptools.package-dir]
+"" = "src"
diff --git a/requirements.txt b/requirements.txt
index e099f43..2759849 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 torch==2.2.2
 tensorboard==2.15.1
 packaging==23.2
-dm-control==1.0.16
-mujoco==3.1.3
+dm-control==1.0.11
+mujoco==2.3.3
+numpy==1.26.4
diff --git a/src/oprl/configs/ddpg.py b/src/oprl/configs/ddpg.py
index fa24a3f..22437f0 100644
--- a/src/oprl/configs/ddpg.py
+++ b/src/oprl/configs/ddpg.py
@@ -26,11 +26,11 @@ def make_env(seed: int):
 config = {
     "state_dim": STATE_DIM,
     "action_dim": ACTION_DIM,
-    "num_steps": int(1_000_000),
+    "num_steps": int(100_000),
     "eval_every": 2500,
     "device": args.device,
     "save_buffer": False,
-    "visualise_every": 0,
+    "visualise_every": 50000,
     "estimate_q_every": 5000,
     "log_every": 2500,
 }
@@ -48,7 +48,6 @@ def make_algo(logger):
 
 
 def make_logger(seed: int) -> Logger:
-    global config
     log_dir = create_logdir(logdir="logs", algo="DDPG", env=args.env, seed=seed)
     return FileLogger(log_dir, config)
 
diff --git a/src/oprl/configs/sac.py b/src/oprl/configs/sac.py
index 2981aae..475fd0e 100644
--- a/src/oprl/configs/sac.py
+++ b/src/oprl/configs/sac.py
@@ -50,7 +50,6 @@ def make_algo(logger):
 
 
 def make_logger(seed: int) -> Logger:
-    global config
     log_dir = create_logdir(logdir="logs", algo="SAC", env=args.env, seed=seed)
     return FileLogger(log_dir, config)
 
diff --git a/src/oprl/configs/td3.py b/src/oprl/configs/td3.py
index 9cb7174..c8dac65 100644
--- a/src/oprl/configs/td3.py
+++ b/src/oprl/configs/td3.py
@@ -48,8 +48,6 @@ def make_algo(logger):
 
 
 def make_logger(seed: int) -> Logger:
-    global config
-
     log_dir = create_logdir(logdir="logs", algo="TD3", env=args.env, seed=seed)
     return FileLogger(log_dir, config)
 
diff --git a/src/oprl/configs/tqc.py b/src/oprl/configs/tqc.py
index b2242f1..640071e 100644
--- a/src/oprl/configs/tqc.py
+++ b/src/oprl/configs/tqc.py
@@ -48,7 +48,6 @@ def make_algo(logger: Logger):
 
 
 def make_logger(seed: int) -> Logger:
-    global config
     log_dir = create_logdir(logdir="logs", algo="TQC", env=args.env, seed=seed)
     return FileLogger(log_dir, config)
 
diff --git a/src/oprl/configs/utils.py b/src/oprl/configs/utils.py
index 0e216ab..50c16e7 100644
--- a/src/oprl/configs/utils.py
+++ b/src/oprl/configs/utils.py
@@ -29,7 +29,7 @@ def parse_args() -> argparse.Namespace:
 
 
 def create_logdir(logdir: str, algo: str, env: str, seed: int) -> str:
-    dt = datetime.now().strftime("%Y_%m_%d_%Hh%Mm")
+    dt = datetime.now().strftime("%Y_%m_%d_%Hh%Mm%Ss")
     log_dir = os.path.join(logdir, algo, f"{algo}-env_{env}-seed_{seed}-{dt}")
     logging.info(f"LOGDIR: {log_dir}")
     return log_dir
diff --git a/src/oprl/env.py b/src/oprl/env.py
index f4d8fe5..6017ab2 100644
--- a/src/oprl/env.py
+++ b/src/oprl/env.py
@@ -56,12 +56,12 @@ def step(
     ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]:
         obs, reward, cost, terminated, truncated, info = self._env.step(action)
         info["cost"] = cost
-        return obs, reward, terminated, truncated, info
+        return obs.astype("float32"), reward, terminated, truncated, info
 
     def reset(self) -> tuple[npt.ArrayLike, dict[str, Any]]:
         obs, info = self._env.reset(seed=self._seed)
         self._env.step(self._env.action_space.sample())
-        return obs, info
+        return obs.astype("float32"), info
 
     def sample_action(self):
         return self._env.action_space.sample()
@@ -129,7 +129,7 @@ def render(self) -> npt.ArrayLike:
             width=self._render_width,
         )
         img = img.astype(np.uint8)
-        return np.expand_dims(img, 0)
+        return img
 
     def _flat_obs(self, obs: OrderedDict) -> npt.ArrayLike:
         obs_flatten = []
diff --git a/src/oprl/trainers/base_trainer.py b/src/oprl/trainers/base_trainer.py
index 160649b..c9fccb2 100644
--- a/src/oprl/trainers/base_trainer.py
+++ b/src/oprl/trainers/base_trainer.py
@@ -24,6 +24,7 @@ def __init__(
         eval_interval: int = int(2e3),
         num_eval_episodes: int = 10,
         save_buffer_every: int = 0,
+        save_policy_every: int = int(50_000),
         visualise_every: int = 0,
         estimate_q_every: int = 0,
         stdout_log_every: int = int(1e5),
@@ -60,6 +61,7 @@ def __init__(
         self._visualize_every = visualise_every
         self._estimate_q_every = estimate_q_every
         self._stdout_log_every = stdout_log_every
+        self._save_policy_every = save_policy_every
         self._logger = logger
         self.seed = seed
 
@@ -106,6 +108,7 @@ def train(self):
             self._eval_routine(env_step, batch)
             self._visualize(env_step)
             self._save_buffer(env_step)
+            self._save_policy(env_step)
             self._log_stdout(env_step, batch)
 
     def _eval_routine(self, env_step: int, batch):
@@ -151,9 +154,14 @@ def _visualize(self, env_step: int):
             self._logger.log_video("eval_policy", imgs, env_step)
 
     def _save_buffer(self, env_step: int):
+        # TODO: buffer saving is currently broken
         if self._save_buffer_every > 0 and env_step % self._save_buffer_every == 0:
             self.buffer.save(f"{self.log_dir}/buffers/buffer_step_{env_step}.pickle")
 
+    def _save_policy(self, env_step: int):
+        if self._save_policy_every > 0 and env_step % self._save_policy_every == 0:
+            self._logger.save_weights(self._algo.actor, env_step)
+
     def _estimate_q(self, env_step: int):
         if self._estimate_q_every > 0 and env_step % self._estimate_q_every == 0:
             q_true = self.estimate_true_q()
@@ -187,7 +195,7 @@ def visualise_policy(self):
                 action = self._algo.exploit(state)
                 state, _, terminated, truncated, _ = env.step(action)
                 done = terminated or truncated
-            return np.concatenate(imgs)
+            return np.concatenate(imgs, dtype="uint8")
         except Exception as e:
             print(f"Failed to visualise a policy: {e}")
             return None
diff --git a/src/oprl/trainers/safe_trainer.py b/src/oprl/trainers/safe_trainer.py
index ae3a85e..a5edf9f 100644
--- a/src/oprl/trainers/safe_trainer.py
+++ b/src/oprl/trainers/safe_trainer.py
@@ -23,6 +23,7 @@ def __init__(
         eval_interval: int = int(2e3),
         num_eval_episodes: int = 10,
         save_buffer_every: int = 0,
+        save_policy_every: int = int(50_000),
         visualise_every: int = 0,
         estimate_q_every: int = 0,
         stdout_log_every: int = int(1e5),
@@ -65,6 +66,7 @@ def __init__(
             eval_interval=eval_interval,
             num_eval_episodes=num_eval_episodes,
             save_buffer_every=save_buffer_every,
+            save_policy_every=save_policy_every,
             visualise_every=visualise_every,
             estimate_q_every=estimate_q_every,
             stdout_log_every=stdout_log_every,
@@ -97,10 +99,11 @@ def train(self):
             if len(self.buffer) < self.batch_size:
                 continue
             batch = self.buffer.sample(self.batch_size)
-            self._algo.update(batch)
+            self._algo.update(*batch)
 
             self._eval_routine(env_step, batch)
             self._visualize(env_step)
+            self._save_policy(env_step)
             self._save_buffer(env_step)
             self._log_stdout(env_step, batch)
 
diff --git a/src/oprl/utils/logger.py b/src/oprl/utils/logger.py
index 563c0c1..c78175a 100644
--- a/src/oprl/utils/logger.py
+++ b/src/oprl/utils/logger.py
@@ -3,9 +3,11 @@
 import os
 import shutil
 from abc import ABC, abstractmethod
 from typing import Any
 
 import numpy as np
+import torch
+import torch.nn as nn
 from torch.utils.tensorboard.writer import SummaryWriter
 
 
@@ -66,11 +68,16 @@ def log_scalar(self, tag: str, value: float, step: int) -> None:
         self._log_scalar_to_file(tag, value, step)
 
     def log_video(self, tag: str, imgs, step: int) -> None:
-        os.makedirs(os.path.join(self._log_dir, "images"))
+        os.makedirs(os.path.join(self._log_dir, "images"), exist_ok=True)
         fn = os.path.join(self._log_dir, "images", f"{tag}_step_{step}.npz")
         with open(fn, "wb") as f:
             np.save(f, imgs)
 
+    def save_weights(self, weights: nn.Module, step: int) -> None:
+        os.makedirs(os.path.join(self._log_dir, "weights"), exist_ok=True)
+        fn = os.path.join(self._log_dir, "weights", f"step_{step}.w")
+        torch.save(weights, fn)
+
     def _log_scalar_to_file(self, tag: str, val: float, step: int) -> None:
         fn = os.path.join(self._log_dir, f"{tag}.log")
         os.makedirs(os.path.dirname(fn), exist_ok=True)
diff --git a/src/setup.py b/src/setup.py
deleted file mode 100644
index 06324f0..0000000
--- a/src/setup.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from setuptools import setup
-
-setup(
-    name="oprl",
-    version="1.0",
-    author="Igor Kuznetsov",
-    description="Reinforcement Learning Off-policy Algorithms with PyTorch",
-    long_description="todo",
-    url="todo",
-    keywords="reinforcement, learning, off-policy",
-    python_requires=">=3.7",
-    # packages=find_packages(include=['exampleproject', 'exampleproject.*']),
-    # install_requires=[
-    #     'PyYAML',
-    #     'pandas==0.23.3',
-    #     'numpy>=1.14.5',
-    #     'matplotlib>=2.2.0,,
-    #     'jupyter'
-    # ],
-)
diff --git a/tests/functional/requirements.txt b/tests/functional/requirements.txt
deleted file mode 100644
index 2e41aa3..0000000
--- a/tests/functional/requirements.txt
+++ /dev/null
@@ -1,6 +0,0 @@
-pytest==8.0.1
-torch==2.2.0
-tensorboard==2.15.1
-packaging==23.2
-dm-control==1.0.16
-mujoco==3.1.3
diff --git a/tests/functional/src/test_env.py b/tests/functional/test_env.py
similarity index 76%
rename from tests/functional/src/test_env.py
rename to tests/functional/test_env.py
index 3f056c4..2cf2ec2 100644
--- a/tests/functional/src/test_env.py
+++ b/tests/functional/test_env.py
@@ -1,8 +1,9 @@
 import pytest
 
-from oprl.env import DMControlEnv
+from oprl.env import make_env
 
-dm_control_envs = [
+
+dm_control_envs: list[str] = [
     "acrobot-swingup",
     "ball_in_cup-catch",
     "cartpole-balance",
@@ -30,9 +31,20 @@
 ]
 
 
-@pytest.mark.parametrize("env_name", dm_control_envs)
-def test_dm_control_envs(env_name: str):
-    env = DMControlEnv(env_name, seed=0)
+safety_envs: list[str] = [
+    "SafetyPointGoal1-v0",
+    "SafetyPointButton1-v0",
+    "SafetyPointPush1-v0",
+    "SafetyPointCircle1-v0",
+]
+
+
+env_names: list[str] = dm_control_envs  # + safety_envs
+
+
+@pytest.mark.parametrize("env_name", env_names)
+def test_envs(env_name: str) -> None:
+    env = make_env(env_name, seed=0)
     obs, info = env.reset()
     assert obs.shape[0] == env.observation_space.shape[0]
     assert isinstance(info, dict), "Info is expected to be a dict"
diff --git a/tests/functional/src/test_rl_algos.py b/tests/functional/test_rl_algos.py
similarity index 100%
rename from tests/functional/src/test_rl_algos.py
rename to tests/functional/test_rl_algos.py
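
Note on the checkpointing path introduced above (not part of the diff): `BaseTrainer._save_policy()` calls `FileLogger.save_weights()`, which `torch.save`s the entire actor module to `<logdir>/weights/step_<N>.w`. A minimal sketch of loading such a checkpoint back, assuming the pinned `torch==2.2.2`; the checkpoint path below is hypothetical, since the real directory name is timestamped by `create_logdir()`:

```python
import torch

# Hypothetical path for illustration; create_logdir() produces
# logs/<ALGO>/<ALGO>-env_<env>-seed_<seed>-<timestamp>/ and
# save_weights() writes weights/step_<N>.w inside it.
ckpt = "logs/DDPG/DDPG-env_walker-walk-seed_0-2024_01_01_00h00m00s/weights/step_50000.w"

# save_weights() pickles the whole nn.Module (not just a state_dict),
# so torch.load returns a ready-to-use actor network.
actor = torch.load(ckpt, map_location="cpu")
actor.eval()

# The loaded actor can then be queried for actions, mirroring how
# BaseTrainer.visualise_policy() uses the algo's exploit(state).
```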