From 18f0a0a2b9d3356e5d61fb91dbf4e9427990a2b5 Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Sat, 5 Jul 2025 20:35:35 +0400 Subject: [PATCH 01/30] Move configs from package and refine --- {src/oprl/configs => configs}/d3pg.py | 0 {src/oprl/configs => configs}/ddpg.py | 27 ++++++++++--------- {src/oprl/configs => configs}/sac.py | 31 +++++++++++----------- {src/oprl/configs => configs}/td3.py | 30 ++++++++++++---------- {src/oprl/configs => configs}/tqc.py | 29 +++++++++++---------- src/oprl/configs/utils.py | 35 ------------------------- src/oprl/trainers/base_trainer.py | 37 +++------------------------ src/oprl/utils/run_training.py | 6 +---- src/oprl/utils/utils.py | 13 ---------- 9 files changed, 67 insertions(+), 141 deletions(-) rename {src/oprl/configs => configs}/d3pg.py (100%) rename {src/oprl/configs => configs}/ddpg.py (70%) rename {src/oprl/configs => configs}/sac.py (68%) rename {src/oprl/configs => configs}/td3.py (67%) rename {src/oprl/configs => configs}/tqc.py (69%) delete mode 100644 src/oprl/configs/utils.py diff --git a/src/oprl/configs/d3pg.py b/configs/d3pg.py similarity index 100% rename from src/oprl/configs/d3pg.py rename to configs/d3pg.py diff --git a/src/oprl/configs/ddpg.py b/configs/ddpg.py similarity index 70% rename from src/oprl/configs/ddpg.py rename to configs/ddpg.py index 22437f0..afd34b9 100644 --- a/src/oprl/configs/ddpg.py +++ b/configs/ddpg.py @@ -1,12 +1,14 @@ -import logging - +from oprl.algos import OffPolicyAlgorithm from oprl.algos.ddpg import DDPG -from oprl.configs.utils import create_logdir, parse_args -from oprl.utils.utils import set_logging - -set_logging(logging.INFO) +from oprl.parse_args import parse_args +from oprl.logging import ( + create_logdir, + set_logging, + FileTxtLogger, + LoggerProtocol +) +set_logging() from oprl.env import make_env as _make_env -from oprl.utils.logger import FileLogger, Logger from oprl.utils.run_training import run_training args = parse_args() @@ -29,7 +31,6 @@ def make_env(seed: int): "num_steps": int(100_000), "eval_every": 2500, "device": args.device, - "save_buffer": False, "visualise_every": 50000, "estimate_q_every": 5000, "log_every": 2500, @@ -38,18 +39,20 @@ def make_env(seed: int): # ----------------------------------- -def make_algo(logger): +def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: return DDPG( state_dim=STATE_DIM, action_dim=ACTION_DIM, device=args.device, logger=logger, - ) + ).create() -def make_logger(seed: int) -> Logger: +def make_logger(seed: int) -> LoggerProtocol: log_dir = create_logdir(logdir="logs", algo="DDPG", env=args.env, seed=seed) - return FileLogger(log_dir, config) + logger = FileTxtLogger(log_dir, config) + logger.copy_source_code() + return logger if __name__ == "__main__": diff --git a/src/oprl/configs/sac.py b/configs/sac.py similarity index 68% rename from src/oprl/configs/sac.py rename to configs/sac.py index 475fd0e..b9c09f4 100644 --- a/src/oprl/configs/sac.py +++ b/configs/sac.py @@ -1,16 +1,16 @@ -import logging - +from oprl.algos import OffPolicyAlgorithm from oprl.algos.sac import SAC -from oprl.configs.utils import create_logdir, parse_args -from oprl.utils.utils import set_logging - -set_logging(logging.INFO) +from oprl.parse_args import parse_args +from oprl.logging import ( + create_logdir, + set_logging, + FileTxtLogger, + LoggerProtocol +) +set_logging() from oprl.env import make_env as _make_env -from oprl.utils.logger import FileLogger, Logger from oprl.utils.run_training import run_training 
-logging.basicConfig(level=logging.INFO) - args = parse_args() @@ -28,10 +28,9 @@ def make_env(seed: int): config = { "state_dim": STATE_DIM, "action_dim": ACTION_DIM, - "num_steps": int(1_000_000), + "num_steps": int(100_000), "eval_every": 2500, "device": args.device, - "save_buffer": False, "visualise_every": 0, "estimate_q_every": 5000, "log_every": 1000, @@ -40,18 +39,20 @@ def make_env(seed: int): # ----------------------------------- -def make_algo(logger): +def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: return SAC( state_dim=STATE_DIM, action_dim=ACTION_DIM, device=args.device, logger=logger, - ) + ).create() -def make_logger(seed: int) -> Logger: +def make_logger(seed: int) -> LoggerProtocol: log_dir = create_logdir(logdir="logs", algo="SAC", env=args.env, seed=seed) - return FileLogger(log_dir, config) + logger = FileTxtLogger(log_dir, config) + logger.copy_source_code() + return logger if __name__ == "__main__": diff --git a/src/oprl/configs/td3.py b/configs/td3.py similarity index 67% rename from src/oprl/configs/td3.py rename to configs/td3.py index c8dac65..6a996be 100644 --- a/src/oprl/configs/td3.py +++ b/configs/td3.py @@ -1,12 +1,14 @@ -import logging - +from oprl.algos import OffPolicyAlgorithm from oprl.algos.td3 import TD3 -from oprl.configs.utils import create_logdir, parse_args -from oprl.utils.utils import set_logging - -set_logging(logging.INFO) +from oprl.parse_args import parse_args +from oprl.logging import ( + create_logdir, + set_logging, + FileTxtLogger, + LoggerProtocol +) +set_logging() from oprl.env import make_env as _make_env -from oprl.utils.logger import FileLogger, Logger from oprl.utils.run_training import run_training args = parse_args() @@ -26,11 +28,9 @@ def make_env(seed: int): config = { "state_dim": STATE_DIM, "action_dim": ACTION_DIM, - "num_steps": int(1_000_000), + "num_steps": int(100_000), "eval_every": 2500, "device": args.device, - "save_buffer": False, - "visualise_every": 0, "estimate_q_every": 5000, "log_every": 2500, } @@ -38,18 +38,20 @@ def make_env(seed: int): # ----------------------------------- -def make_algo(logger): +def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: return TD3( state_dim=STATE_DIM, action_dim=ACTION_DIM, device=args.device, logger=logger, - ) + ).create() -def make_logger(seed: int) -> Logger: +def make_logger(seed: int) -> LoggerProtocol: log_dir = create_logdir(logdir="logs", algo="TD3", env=args.env, seed=seed) - return FileLogger(log_dir, config) + logger = FileTxtLogger(log_dir, config) + logger.copy_source_code() + return logger if __name__ == "__main__": diff --git a/src/oprl/configs/tqc.py b/configs/tqc.py similarity index 69% rename from src/oprl/configs/tqc.py rename to configs/tqc.py index 640071e..13ccc5b 100644 --- a/src/oprl/configs/tqc.py +++ b/configs/tqc.py @@ -1,12 +1,14 @@ -import logging - +from oprl.algos import OffPolicyAlgorithm from oprl.algos.tqc import TQC -from oprl.configs.utils import create_logdir, parse_args -from oprl.utils.utils import set_logging - -set_logging(logging.INFO) +from oprl.parse_args import parse_args +from oprl.logging import ( + create_logdir, + set_logging, + FileTxtLogger, + LoggerProtocol +) +set_logging() from oprl.env import make_env as _make_env -from oprl.utils.logger import FileLogger, Logger from oprl.utils.run_training import run_training args = parse_args() @@ -26,10 +28,9 @@ def make_env(seed: int): config = { "state_dim": STATE_DIM, "action_dim": ACTION_DIM, - "num_steps": int(1_000_000), + "num_steps": int(100_000), 
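
Each of these configs is an executable script: parse_args() supplies --env, --seeds, --start_seed and --device, and the __main__ block hands the factory callables to run_training, which launches one training process per seed. A minimal sketch of the equivalent direct call, with illustrative values and the run_training signature as it stands in this patch:

    # Roughly what `python configs/ddpg.py --env walker-walk --seeds 3` resolves to:
    run_training(
        make_algo,    # logger -> algorithm instance
        make_env,     # seed -> environment
        make_logger,  # seed -> FileTxtLogger under logs/<ALGO>/...
        config,       # read by the trainer: num_steps, eval_every, device, ...
        3,            # seeds: number of parallel training processes
        0,            # start_seed
    )
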
"eval_every": 2500, "device": args.device, - "save_buffer": False, "visualise_every": 0, "estimate_q_every": 0, # TODO: Here is the unsupported logic "log_every": 2500, @@ -38,18 +39,20 @@ def make_env(seed: int): # ----------------------------------- -def make_algo(logger: Logger): +def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: return TQC( state_dim=STATE_DIM, action_dim=ACTION_DIM, device=args.device, logger=logger, - ) + ).create() -def make_logger(seed: int) -> Logger: +def make_logger(seed: int) -> LoggerProtocol: log_dir = create_logdir(logdir="logs", algo="TQC", env=args.env, seed=seed) - return FileLogger(log_dir, config) + logger = FileTxtLogger(log_dir, config) + logger.copy_source_code() + return logger if __name__ == "__main__": diff --git a/src/oprl/configs/utils.py b/src/oprl/configs/utils.py deleted file mode 100644 index 50c16e7..0000000 --- a/src/oprl/configs/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -import argparse -import logging -import os -from datetime import datetime - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Run training") - parser.add_argument("--config", type=str, help="Path to the config file.") - parser.add_argument( - "--env", type=str, default="cartpole-balance", help="Name of the environment." - ) - parser.add_argument( - "--seeds", - type=int, - default=1, - help="Number of parallel processes launched with different random seeds.", - ) - parser.add_argument( - "--start_seed", - type=int, - default=0, - help="Number of the first seed. Following seeds will be incremented from it.", - ) - parser.add_argument( - "--device", type=str, default="cpu", help="Device to perform training on." - ) - return parser.parse_args() - - -def create_logdir(logdir: str, algo: str, env: str, seed: int) -> str: - dt = datetime.now().strftime("%Y_%m_%d_%Hh%Mm%Ss") - log_dir = os.path.join(logdir, algo, f"{algo}-env_{env}-seed_{seed}-{dt}") - logging.info(f"LOGDIR: {log_dir}") - return log_dir diff --git a/src/oprl/trainers/base_trainer.py b/src/oprl/trainers/base_trainer.py index c9fccb2..adaa6ab 100644 --- a/src/oprl/trainers/base_trainer.py +++ b/src/oprl/trainers/base_trainer.py @@ -1,4 +1,3 @@ -import os from typing import Any, Callable import numpy as np @@ -6,7 +5,7 @@ from oprl.env import BaseEnv from oprl.trainers.buffers.episodic_buffer import EpisodicReplayBuffer -from oprl.utils.logger import Logger, StdLogger +from oprl.logging import LoggerProtocol class BaseTrainer: @@ -25,13 +24,12 @@ def __init__( eval_interval: int = int(2e3), num_eval_episodes: int = 10, save_buffer_every: int = 0, - save_policy_every: int = int(50_000), - visualise_every: int = 0, + save_policy_every: int = int(100_000), estimate_q_every: int = 0, stdout_log_every: int = int(1e5), device: str = "cpu", seed: int = 0, - logger: Logger = StdLogger(), + logger: LoggerProtocol | None = None, ): """ Args: @@ -59,7 +57,6 @@ def __init__( self._gamma = gamma self._device = device self._save_buffer_every = save_buffer_every - self._visualize_every = visualise_every self._estimate_q_every = estimate_q_every self._stdout_log_every = stdout_log_every self._save_policy_every= save_policy_every @@ -107,7 +104,6 @@ def train(self): self._algo.update(*batch) self._eval_routine(env_step, batch) - self._visualize(env_step) self._save_buffer(env_step) self._save_policy(env_step) self._log_stdout(env_step, batch) @@ -148,12 +144,6 @@ def _log_evaluation(self, env_step: int): mean_return = np.mean(returns) self._logger.log_scalar("trainer/ep_reward", 
mean_return, env_step) - def _visualize(self, env_step: int): - if self._visualize_every > 0 and env_step % self._visualize_every == 0: - imgs = self.visualise_policy() # [T, W, H, C] - if imgs is not None: - self._logger.log_video("eval_policy", imgs, env_step) - def _save_buffer(self, env_step: int): # TODO: doesn't work if self._save_buffer_every > 0 and env_step % self._save_buffer_every == 0: @@ -181,26 +171,6 @@ def _log_stdout(self, env_step: int, batch): f"Env step {env_step:8d} ({perc:2d}%) Avg Reward {batch[2].mean():10.3f}" ) - def visualise_policy(self): - """ - returned shape: [N, C, W, H] - """ - env = self._make_env_test(seed=self.seed) - try: - imgs = [] - state, _ = env.reset() - done = False - while not done: - img = env.render() - imgs.append(img) - action = self._algo.exploit(state) - state, _, terminated, truncated, _ = env.step(action) - done = terminated or truncated - return np.concatenate(imgs, dtype="uint8") - except Exception as e: - print(f"Failed to visualise a policy: {e}") - return None - def estimate_true_q(self, eval_episodes: int = 10) -> float | None: try: qs = [] @@ -261,7 +231,6 @@ def run_training(make_algo, make_env, make_logger, config: dict[str, Any], seed: eval_interval=config["eval_every"], device=config["device"], save_buffer_every=config["save_buffer"], - visualise_every=config["visualise_every"], estimate_q_every=config["estimate_q_every"], stdout_log_every=config["log_every"], seed=seed, diff --git a/src/oprl/utils/run_training.py b/src/oprl/utils/run_training.py index d329a47..1458204 100644 --- a/src/oprl/utils/run_training.py +++ b/src/oprl/utils/run_training.py @@ -24,11 +24,9 @@ def run_training( for i, p in enumerate(processes): p.start() logging.info(f"Starting process {i}...") - for p in processes: p.join() - - logging.info("Training OK.") + logging.info("Training finished.") def _run_training_func(make_algo, make_env, make_logger, config, seed: int): @@ -52,8 +50,6 @@ def _run_training_func(make_algo, make_env, make_logger, config, seed: int): num_steps=config["num_steps"], eval_interval=config["eval_every"], device=config["device"], - save_buffer_every=config["save_buffer"], - visualise_every=config["visualise_every"], estimate_q_every=config["estimate_q_every"], stdout_log_every=config["log_every"], seed=seed, diff --git a/src/oprl/utils/utils.py b/src/oprl/utils/utils.py index 29736cc..d168b0d 100644 --- a/src/oprl/utils/utils.py +++ b/src/oprl/utils/utils.py @@ -77,22 +77,9 @@ def empty_torch_queue(q): q.close() -def copy_exp_dir(log_dir: str) -> None: - cur_dir = os.path.join(os.getcwd(), "src") - dest_dir = os.path.join(log_dir, "src") - shutil.copy(cur_dir, dest_dir) - logging.info(f"Source copied into {dest_dir}") - - def set_seed(seed: int) -> None: random.seed(seed) np.random.seed(seed) t.manual_seed(seed) -def set_logging(level: int): - logging.basicConfig( - level=level, - format="%(asctime)s | %(filename)s:%(lineno)d\t %(levelname)s - %(message)s", - stream=sys.stdout, - ) From 47dd28007e334d86a730c0600a317498ad0d364e Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Sun, 6 Jul 2025 23:19:38 +0400 Subject: [PATCH 02/30] Replace inheritance with composition for trainer, make visualization script --- configs/ddpg.py | 23 +++- scripts/visualize_policy_from_weights.py | 89 +++++++++++++ src/oprl/algos/nn.py | 7 +- src/oprl/buffers/__init__.py | 0 .../{trainers => }/buffers/episodic_buffer.py | 79 +++++------ src/oprl/distrib/__init__.py | 0 src/oprl/env.py | 5 +- src/oprl/logging.py | 95 +++++++++++++ src/oprl/parse_args.py 
| 26 ++++ src/oprl/trainers/__init__.py | 0 src/oprl/trainers/base_trainer.py | 55 ++------ src/oprl/trainers/safe_trainer.py | 126 ++++++------------ src/oprl/utils/run_training.py | 31 ++--- 13 files changed, 350 insertions(+), 186 deletions(-) create mode 100644 scripts/visualize_policy_from_weights.py create mode 100644 src/oprl/buffers/__init__.py rename src/oprl/{trainers => }/buffers/episodic_buffer.py (72%) create mode 100644 src/oprl/distrib/__init__.py create mode 100644 src/oprl/logging.py create mode 100644 src/oprl/parse_args.py create mode 100644 src/oprl/trainers/__init__.py diff --git a/configs/ddpg.py b/configs/ddpg.py index afd34b9..3180549 100644 --- a/configs/ddpg.py +++ b/configs/ddpg.py @@ -1,5 +1,6 @@ from oprl.algos import OffPolicyAlgorithm from oprl.algos.ddpg import DDPG +from oprl.buffers.episodic_buffer import ReplayBufferProtocol, EpisodicReplayBuffer from oprl.parse_args import parse_args from oprl.logging import ( create_logdir, @@ -33,6 +34,7 @@ def make_env(seed: int): "device": args.device, "visualise_every": 50000, "estimate_q_every": 5000, + "gamma": 0.99, "log_every": 2500, } @@ -44,10 +46,21 @@ def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: state_dim=STATE_DIM, action_dim=ACTION_DIM, device=args.device, + discount=config["gamma"], logger=logger, ).create() +def make_replay_buffer() -> ReplayBufferProtocol: + return EpisodicReplayBuffer( + buffer_size=max(config["num_steps"], int(1e6)), + state_dim=STATE_DIM, + action_dim=ACTION_DIM, + device=config["device"], + gamma=config["gamma"], + ) + + def make_logger(seed: int) -> LoggerProtocol: log_dir = create_logdir(logdir="logs", algo="DDPG", env=args.env, seed=seed) logger = FileTxtLogger(log_dir, config) @@ -57,4 +70,12 @@ def make_logger(seed: int) -> LoggerProtocol: if __name__ == "__main__": args = parse_args() - run_training(make_algo, make_env, make_logger, config, args.seeds, args.start_seed) + run_training( + make_algo=make_algo, + make_env=make_env, + make_replay_buffer=make_replay_buffer, + make_logger=make_logger, + config=config, + seeds=args.seeds, + start_seed=args.start_seed + ) diff --git a/scripts/visualize_policy_from_weights.py b/scripts/visualize_policy_from_weights.py new file mode 100644 index 0000000..659cbd5 --- /dev/null +++ b/scripts/visualize_policy_from_weights.py @@ -0,0 +1,89 @@ +import click +import torch +import numpy as np +from PIL import Image + +from oprl.env import make_env + + +def create_webp_gif(numpy_arrays, output_path, duration=100, loop=0): + """ + Create a WebP animated image from a list of NumPy arrays. 
+
+    Args:
+        numpy_arrays: List of NumPy arrays (each representing a frame)
+        output_path: Output file path (should end with .webp)
+        duration: Duration between frames in milliseconds
+        loop: Number of loops (0 = infinite loop)
+    """
+    # Convert NumPy arrays to PIL Images
+    pil_images = []
+
+    for arr in numpy_arrays:
+        # Ensure the array is in the right format
+        if arr.dtype != np.uint8:
+            # Normalize to 0-255 range if needed
+            if arr.max() <= 1.0:
+                arr = (arr * 255).astype(np.uint8)
+            else:
+                arr = arr.astype(np.uint8)
+
+        # Handle different array shapes
+        if len(arr.shape) == 2:  # Grayscale
+            img = Image.fromarray(arr, mode='L')
+        elif len(arr.shape) == 3:  # RGB/RGBA
+            if arr.shape[2] == 3:
+                img = Image.fromarray(arr, mode='RGB')
+            elif arr.shape[2] == 4:
+                img = Image.fromarray(arr, mode='RGBA')
+            else:
+                raise ValueError(f"Unsupported number of channels: {arr.shape[2]}")
+        else:
+            raise ValueError(f"Unsupported array shape: {arr.shape}")
+
+        pil_images.append(img)
+
+    # Save as animated WebP
+    pil_images[0].save(
+        output_path,
+        format='WebP',
+        save_all=True,
+        append_images=pil_images[1:],
+        duration=duration,
+        loop=loop,
+        optimize=True
+    )
+
+
+@click.command()
+@click.option("--policy", "-p", help="Path to policy weights.")
+@click.option("--output", "-o", default="policy.webp", help="Path to output file.")
+@click.option("--env", "-e", default="walker-walk", help="Environment name.")
+@click.option("--seed", "-s", default=0, help="Environment seed.")
+def visualize_policy(policy, output, env, seed):
+    env = make_env(env, seed=seed)
+
+    actor = torch.load(policy, weights_only=False)
+    print("Actor loaded: ", type(actor))
+
+    imgs = []
+    state, _ = env.reset()
+    done = False
+    while not done:
+        img = np.expand_dims(env.render(), axis=0)  # [1, W, H, C]
+        imgs.append(img)
+        action = actor.exploit(torch.from_numpy(state))
+        state, _, terminated, truncated, _ = env.step(action)
+        done = terminated or truncated
+
+    print("imgs: ", len(imgs), imgs[0].shape)
+    frames = np.concatenate(imgs, dtype="uint8", axis=0)
+    print("frames: ", frames.shape)
+
+    # Create the animated WebP
+    create_webp_gif(frames, output, duration=25)
+    print("Animated WebP for dm_control created successfully!")
+
+
+if __name__ == "__main__":
+    visualize_policy()
diff --git a/src/oprl/algos/nn.py b/src/oprl/algos/nn.py
index cbac507..935f535 100644
--- a/src/oprl/algos/nn.py
+++ b/src/oprl/algos/nn.py
@@ -125,7 +125,7 @@ def forward(self, states: t.Tensor) -> t.Tensor:
 
     def exploit(self, state: npt.ArrayLike) -> npt.NDArray:
         state = t.tensor(state).unsqueeze_(0).to(self._device)
-        return self.forward(state).cpu().numpy().flatten()
+        return self.forward(state).detach().cpu().numpy().flatten()
 
     def explore(self, state: npt.ArrayLike) -> npt.NDArray:
         state = t.tensor(state, device=self._device).unsqueeze_(0)
@@ -161,6 +161,11 @@ def forward(self, obs: t.Tensor) -> tuple[t.Tensor, t.Tensor | None]:
             log_prob = None
         return action, log_prob
 
+    def exploit(self, state: npt.ArrayLike) -> npt.NDArray:
+        state = t.tensor(state).unsqueeze_(0).to(self.device)
+        action, _ = self.forward(state)
+        return action.detach().cpu().numpy().flatten()
+
     @property
     def device(self):
         return next(self.parameters()).device
diff --git a/src/oprl/buffers/__init__.py b/src/oprl/buffers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/oprl/trainers/buffers/episodic_buffer.py b/src/oprl/buffers/episodic_buffer.py
similarity index 72%
rename from src/oprl/trainers/buffers/episodic_buffer.py
rename to src/oprl/buffers/episodic_buffer.py
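
The create_webp_gif helper above only assumes an iterable of H x W x C frames; a self-contained sketch with synthetic input (all values illustrative):

    import numpy as np

    # 60 random RGB frames -> looping animated WebP, 40 ms per frame
    frames = [np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8) for _ in range(60)]
    create_webp_gif(frames, "random.webp", duration=40)
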
index 8f7b7cf..34733c7 100644 --- a/src/oprl/trainers/buffers/episodic_buffer.py +++ b/src/oprl/buffers/episodic_buffer.py @@ -1,8 +1,25 @@ import os import pickle +from typing import Protocol import numpy as np -import torch +import numpy.typing as npt +import torch as t + + +class ReplayBufferProtocol(Protocol): + def add_transition(self, state, action, reward, done, episode_done=None): ... + + def add_episode(self, episode): ... + + def sample(self, batch_size) -> tuple[ + t.Tensor, t.Tensor, t.Tensor, t.Tensor, t.Tensor + ]: ... + + def save(self, path: str) -> None: ... + + @property + def last_episode_length(self) -> int: ... class EpisodicReplayBuffer: @@ -14,18 +31,8 @@ def __init__( device: str, gamma: float, max_episode_len: int = 1000, - dtype=torch.float, + dtype=t.float, ): - """ - Args: - buffer_size: Max number of transitions in buffer. - state_dim: Dimension of the state. - action_dim: Dimension of the action. - device: Device to place buffer. - gamma: Discount factor for N-step. - max_episode_len: Max length of the episode to store. - dtype: Data type. - """ self.buffer_size = buffer_size self.max_episodes = buffer_size // max_episode_len self.max_episode_len = max_episode_len @@ -38,50 +45,41 @@ def __init__( self.cur_episodes = 1 self.cur_size = 0 - self.actions = torch.empty( + self.actions = t.empty( (self.max_episodes, max_episode_len, action_dim), dtype=dtype, device=device, ) - self.rewards = torch.empty( + self.rewards = t.empty( (self.max_episodes, max_episode_len, 1), dtype=dtype, device=device ) - self.dones = torch.empty( + self.dones = t.empty( (self.max_episodes, max_episode_len, 1), dtype=dtype, device=device ) - self.states = torch.empty( + self.states = t.empty( (self.max_episodes, max_episode_len + 1, state_dim), dtype=dtype, device=device, ) self.ep_lens = [0] * self.max_episodes - self.actions_for_std = torch.empty( + self.actions_for_std = t.empty( (100, action_dim), dtype=dtype, device=device ) self.actions_for_std_cnt = 0 - # TODO: rename to add - def append(self, state, action, reward, done, episode_done=None): - """ - Args: - state: state. - action: action. - reward: reward. - done: done only if episode ends naturally. - episode_done: done that can be set to True if time limit is reached. - """ + def add_transition(self, state: npt.ArrayLike, action: npt.ArrayLike, reward: float, done: bool, episode_done: bool | None = None): self.states[self.ep_pointer, self.ep_lens[self.ep_pointer]].copy_( - torch.from_numpy(state) + t.from_numpy(state) ) self.actions[self.ep_pointer, self.ep_lens[self.ep_pointer]].copy_( - torch.from_numpy(action) + t.from_numpy(action) ) self.rewards[self.ep_pointer, self.ep_lens[self.ep_pointer]] = float(reward) self.dones[self.ep_pointer, self.ep_lens[self.ep_pointer]] = float(done) self.actions_for_std[self.actions_for_std_cnt % 100].copy_( - torch.from_numpy(action) + t.from_numpy(action) ) self.actions_for_std_cnt += 1 @@ -96,9 +94,9 @@ def _inc_episode(self): self.cur_size -= self.ep_lens[self.ep_pointer] self.ep_lens[self.ep_pointer] = 0 - def add_episode(self, episode): - for s, a, r, d, s_ in episode: - self.append(s, a, r, d, episode_done=d) + def add_episode(self, episode: list): + for s, a, r, d, _ in episode: + self.add_transition(s, a, r, d, episode_done=d) if d: break else: @@ -127,10 +125,6 @@ def sample(self, batch_size): ) def save(self, path: str): - """ - Args: - path: Path to pickle file. 
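
With the rename, trainers talk to the buffer only through ReplayBufferProtocol: add_transition/add_episode to write, sample to read, save to snapshot. A minimal sketch of the add/sample cycle against the constructor in this hunk (dimensions and unpacking names are illustrative; index 2 must be the rewards tensor, since the trainer logs batch[2].mean()):

    import numpy as np
    from oprl.buffers.episodic_buffer import EpisodicReplayBuffer

    buf = EpisodicReplayBuffer(
        buffer_size=10_000, state_dim=4, action_dim=2, device="cpu", gamma=0.99
    )
    s = np.zeros(4, dtype=np.float32)
    a = np.zeros(2, dtype=np.float32)
    for step in range(200):
        buf.add_transition(s, a, reward=0.0, done=False, episode_done=(step == 199))
    states, actions, rewards, dones, next_states = buf.sample(32)
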
- """ dirname = os.path.dirname(path) if not os.path.exists(dirname): os.makedirs(dirname) @@ -149,12 +143,13 @@ def save(self, path: str): except Exception as e: print(f"Failed to save replay buffer: {e}") - def __len__(self): + def __len__(self) -> int: return self.cur_size - @property - def num_episodes(self): - return self.cur_episodes + # @property + # def num_episodes(self): + # return self.cur_episodes - def get_last_ep_len(self): + @property + def last_episode_length(self): return self.ep_lens[self.ep_pointer] diff --git a/src/oprl/distrib/__init__.py b/src/oprl/distrib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/oprl/env.py b/src/oprl/env.py index 6017ab2..d777ecf 100644 --- a/src/oprl/env.py +++ b/src/oprl/env.py @@ -48,7 +48,7 @@ class SafetyGym(BaseEnv): def __init__(self, env_name: str, seed: int): import safety_gymnasium as gym - self._env = gym.make(env_name) + self._env = gym.make(env_name, render_mode='rgb_array', camera_name="fixednear") self._seed = seed def step( @@ -62,6 +62,9 @@ def reset(self) -> tuple[npt.ArrayLike, dict[str, Any]]: obs, info = self._env.reset(seed=self._seed) self._env.step(self._env.action_space.sample()) return obs.astype("float32"), info + + def render(self) -> npt.NDArray: + return self._env.render() def sample_action(self): return self._env.action_space.sample() diff --git a/src/oprl/logging.py b/src/oprl/logging.py new file mode 100644 index 0000000..3a92384 --- /dev/null +++ b/src/oprl/logging.py @@ -0,0 +1,95 @@ +import os +import sys +import logging +from datetime import datetime +import json +import shutil +from abc import ABC, abstractmethod +from typing import Any, Protocol + +import torch as t +import torch.nn as nn +from torch.utils.tensorboard.writer import SummaryWriter + + +def create_logdir(logdir: str, algo: str, env: str, seed: int) -> str: + dt = datetime.now().strftime("%Y_%m_%d_%Hh%Mm%Ss") + log_dir = os.path.join(logdir, algo, f"{algo}-env_{env}-seed_{seed}-{dt}") + logging.info(f"LOGDIR: {log_dir}") + return log_dir + + +def set_logging(level: int = logging.INFO) -> None: + logging.basicConfig( + level=level, + format="%(asctime)s | %(filename)s:%(lineno)d\t %(levelname)s - %(message)s", + stream=sys.stdout, + ) + + +def copy_exp_dir(log_dir: str) -> None: + cur_dir = os.path.join(os.getcwd(), "src") + dest_dir = os.path.join(log_dir, "src") + shutil.copytree(cur_dir, dest_dir) + logging.info(f"Source copied into {dest_dir}") + + +def save_json_config(config: dict[str, Any], path: str): + with open(path, "w") as f: + json.dump(config, f) + + +class LoggerProtocol(Protocol): + def log_scalar(self, tag: str, value: float, step: int) -> None: ... + + def log_scalars(self, values: dict[str, float], step: int) -> None: ... + + +class BaseLogger(ABC): + @abstractmethod + def log_scalar(self, tag: str, value: float, step: int) -> None: + ... + + def log_scalars(self, values: dict[str, float], step: int) -> None: + """ + Args: + values: Dict with tag -> value to log. + step: Iter step. 
+        """
+        for k, v in values.items():
+            self.log_scalar(k, v, step)
+
+
+class StdLogger(BaseLogger):
+    def log_scalar(self, tag: str, value: float, step: int) -> None:
+        logging.info(f"{tag}\t{value}\tat step {step}")
+
+
+class FileTxtLogger(BaseLogger):
+    def __init__(self, logdir: str, config: dict[str, Any]) -> None:
+        self.writer = SummaryWriter(logdir)
+        self.log_dir = logdir
+        self.config = config
+
+    def copy_source_code(self) -> None:
+        logging.info(f"Source code is copied to {self.log_dir}")
+        copy_exp_dir(self.log_dir)
+        save_json_config(self.config, os.path.join(self.log_dir, "config.json"))
+
+    def log_scalar(self, tag: str, value: float, step: int) -> None:
+        self.writer.add_scalar(tag, value, step)
+        self._log_scalar_to_file(tag, value, step)
+
+    def save_weights(self, weights: nn.Module, step: int) -> None:
+        os.makedirs(os.path.join(self.log_dir, "weights"), exist_ok=True)
+        fn = os.path.join(self.log_dir, "weights", f"step_{step}.w")
+        t.save(
+            weights,
+            fn
+        )
+
+    def _log_scalar_to_file(self, tag: str, val: float, step: int) -> None:
+        fn = os.path.join(self.log_dir, f"{tag}.log")
+        os.makedirs(os.path.dirname(fn), exist_ok=True)
+        with open(fn, "a") as f:
+            f.write(f"{step} {val}\n")
+
diff --git a/src/oprl/parse_args.py b/src/oprl/parse_args.py
new file mode 100644
index 0000000..d01867e
--- /dev/null
+++ b/src/oprl/parse_args.py
@@ -0,0 +1,26 @@
+import argparse
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Run training")
+    parser.add_argument("--config", type=str, help="Path to the config file.")
+    parser.add_argument(
+        "--env", type=str, default="cartpole-balance", help="Name of the environment."
+    )
+    parser.add_argument(
+        "--seeds",
+        type=int,
+        default=1,
+        help="Number of parallel processes launched with different random seeds.",
+    )
+    parser.add_argument(
+        "--start_seed",
+        type=int,
+        default=0,
+        help="Number of the first seed. Following seeds will be incremented from it.",
+    )
+    parser.add_argument(
+        "--device", type=str, default="cpu", help="Device to perform training on."
+    )
+    return parser.parse_args()
+
diff --git a/src/oprl/trainers/__init__.py b/src/oprl/trainers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/oprl/trainers/base_trainer.py b/src/oprl/trainers/base_trainer.py
index adaa6ab..2248693 100644
--- a/src/oprl/trainers/base_trainer.py
+++ b/src/oprl/trainers/base_trainer.py
@@ -1,22 +1,22 @@
+from dataclasses import dataclass
 from typing import Any, Callable
 
 import numpy as np
 import torch
 
 from oprl.env import BaseEnv
-from oprl.trainers.buffers.episodic_buffer import EpisodicReplayBuffer
+from oprl.buffers.episodic_buffer import EpisodicReplayBuffer, ReplayBufferProtocol
 from oprl.logging import LoggerProtocol
 
 
+@dataclass
 class BaseTrainer:
     def __init__(
         self,
-        state_dim: int,
-        action_dim: int,
         env: BaseEnv,
         make_env_test: Callable[[int], BaseEnv],
+        replay_buffer: ReplayBufferProtocol,
         algo: Any | None = None,
-        buffer_size: int = int(1e6),
         gamma: float = 0.99,
         num_steps: int = int(1e6),
         start_steps: int = int(10e3),
@@ -31,26 +31,6 @@ def __init__(
     ):
-        """
-        Args:
-            state_dim: Dimension of the observation.
-            action_dim: Dimension of the action.
-            env: Enviornment object.
-            make_env_test: Environment object for evaluation.
-            algo: Codename for the algo (SAC).
-            buffer_size: Buffer size in transitions.
-            gamma: Discount factor.
-            num_step: Number of env steps to train.
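
FileTxtLogger fans each scalar out to two sinks: the TensorBoard writer and a plain-text per-tag file. A short sketch of the resulting layout (the run directory name includes a timestamp from create_logdir):

    from oprl.logging import FileTxtLogger, create_logdir

    log_dir = create_logdir(logdir="logs", algo="DDPG", env="walker-walk", seed=0)
    logger = FileTxtLogger(log_dir, config={"num_steps": 100_000})
    logger.log_scalar("trainer/ep_reward", 123.4, step=1000)
    # -> a TensorBoard event in log_dir, plus the line "1000 123.4"
    #    appended to <log_dir>/trainer/ep_reward.log
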
- start_steps: Number of environment steps not to perform training at the beginning. - batch_size: Batch-size. - eval_interval: Number of env step after which perform evaluation. - save_buffer_every: Number of env steps after which save replay buffer. - visualise_every: Number of env steps after which perform vizualisation. - device: Name of the device. - stdout_log_every: Number of evn steps after which log info to stdout. - seed: Random seed. - logger: Logger instance. - """ self._env = env self._make_env_test = make_env_test self._algo = algo @@ -62,15 +42,7 @@ def __init__( self._save_policy_every= save_policy_every self._logger = logger self.seed = seed - - self.buffer = EpisodicReplayBuffer( - buffer_size=buffer_size, - state_dim=state_dim, - action_dim=action_dim, - device=device, - gamma=gamma, - ) - + self.replay_buffer = replay_buffer self.batch_size = batch_size self.num_steps = num_steps self.start_steps = start_steps @@ -89,7 +61,7 @@ def train(self): action = self._algo.explore(state) next_state, reward, terminated, truncated, _ = self._env.step(action) - self.buffer.append( + self.replay_buffer.add_transition( state, action, reward, terminated, episode_done=terminated or truncated ) if terminated or truncated: @@ -97,10 +69,10 @@ def train(self): ep_step = 0 state = next_state - if len(self.buffer) < self.batch_size: + if len(self.replay_buffer) < self.batch_size: continue - batch = self.buffer.sample(self.batch_size) + batch = self.replay_buffer.sample(self.batch_size) self._algo.update(*batch) self._eval_routine(env_step, batch) @@ -114,14 +86,14 @@ def _eval_routine(self, env_step: int, batch): self._logger.log_scalar("trainer/avg_reward", batch[2].mean(), env_step) self._logger.log_scalar( - "trainer/buffer_transitions", len(self.buffer), env_step + "trainer/buffer_transitions", len(self.replay_buffer), env_step ) self._logger.log_scalar( - "trainer/buffer_episodes", self.buffer.num_episodes, env_step + "trainer/buffer_episodes", self.replay_buffer.num_episodes, env_step ) self._logger.log_scalar( "trainer/buffer_last_ep_len", - self.buffer.get_last_ep_len(), + self.replay_buffer.get_last_ep_len(), env_step, ) @@ -147,7 +119,7 @@ def _log_evaluation(self, env_step: int): def _save_buffer(self, env_step: int): # TODO: doesn't work if self._save_buffer_every > 0 and env_step % self._save_buffer_every == 0: - self.buffer.save(f"{self.log_dir}/buffers/buffer_step_{env_step}.pickle") + self.replay_buffer.save(f"{self.log_dir}/buffers/buffer_step_{env_step}.pickle") def _save_policy(self, env_step: int): if self._save_policy_every > 0 and env_step % self._save_policy_every == 0: @@ -217,7 +189,7 @@ def estimate_critic_q(self, num_episodes: int = 10) -> float: return np.mean(qs, dtype=float) -def run_training(make_algo, make_env, make_logger, config: dict[str, Any], seed: int): +def run_training(make_algo, make_env, make_replay_buffer, make_logger, config: dict[str, Any], seed: int): env = make_env(seed=seed) logger = make_logger(seed) @@ -227,6 +199,7 @@ def run_training(make_algo, make_env, make_logger, config: dict[str, Any], seed: env=env, make_env_test=make_env, algo=make_algo(logger, seed), + replay_buffer=make_replay_buffer(), num_steps=config["num_steps"], eval_interval=config["eval_every"], device=config["device"], diff --git a/src/oprl/trainers/safe_trainer.py b/src/oprl/trainers/safe_trainer.py index a5edf9f..2b3c33f 100644 --- a/src/oprl/trainers/safe_trainer.py +++ b/src/oprl/trainers/safe_trainer.py @@ -4,116 +4,72 @@ from oprl.env import BaseEnv from 
oprl.trainers.base_trainer import BaseTrainer -from oprl.utils.logger import Logger, StdLogger +from oprl.logging import LoggerProtocol -class SafeTrainer(BaseTrainer): +class SafeTrainer: def __init__( self, - state_dim: int, - action_dim: int, - env: BaseEnv, - make_env_test: Callable[[int], BaseEnv], - algo: Any | None = None, - buffer_size: int = int(1e6), - gamma: float = 0.99, - num_steps=int(1e6), - start_steps: int = int(10e3), - batch_size: int = 128, - eval_interval: int = int(2e3), - num_eval_episodes: int = 10, - save_buffer_every: int = 0, - save_policy_every: int = int(50_000), - visualise_every: int = 0, - estimate_q_every: int = 0, - stdout_log_every: int = int(1e5), - device: str = "cpu", - seed: int = 0, - logger: Logger = StdLogger(), + trainer: BaseTrainer ): - """ - Args: - state_dim: Dimension of the observation. - action_dim: Dimension of the action. - env: Enviornment object. - make_env_test: Environment object for evaluation. - algo: Codename for the algo (SAC). - buffer_size: Buffer size in transitions. - gamma: Discount factor. - num_step: Number of env steps to train. - start_steps: Number of environment steps not to perform training at the beginning. - batch_size: Batch-size. - eval_interval: Number of env step after which perform evaluation. - save_buffer_every: Number of env steps after which save replay buffer. - visualise_every: Number of env steps after which perform vizualisation. - stdout_log_every: Number of evn steps after which log info to stdout. - device: Name of the device. - seed: Random seed. - logger: Logger instance. - """ - super().__init__( - state_dim=state_dim, - action_dim=action_dim, - env=env, - make_env_test=make_env_test, - algo=algo, - buffer_size=buffer_size, - gamma=gamma, - device=device, - num_steps=num_steps, - start_steps=start_steps, - batch_size=batch_size, - eval_interval=eval_interval, - num_eval_episodes=num_eval_episodes, - save_buffer_every=save_buffer_every, - save_policy_every=save_policy_every, - visualise_every=visualise_every, - estimate_q_every=estimate_q_every, - stdout_log_every=stdout_log_every, - seed=seed, - logger=logger, - ) + self.trainer = trainer def train(self): ep_step = 0 - state, _ = self._env.reset() + state, _ = self.trainer._env.reset() total_cost = 0 - for env_step in range(self.num_steps + 1): + for env_step in range(self.trainer.num_steps + 1): ep_step += 1 - if env_step <= self.start_steps: - action = self._env.sample_action() + if env_step <= self.trainer.start_steps: + action = self.trainer._env.sample_action() else: - action = self._algo.explore(state) - next_state, reward, terminated, truncated, info = self._env.step(action) + action = self.trainer._algo.explore(state) + next_state, reward, terminated, truncated, info = self.trainer._env.step(action) total_cost += info["cost"] - self.buffer.append( + self.trainer.replay_buffer.add_transition( state, action, reward, terminated, episode_done=terminated or truncated ) if terminated or truncated: - next_state, _ = self._env.reset() + next_state, _ = self.trainer._env.reset() ep_step = 0 state = next_state - if len(self.buffer) < self.batch_size: + if len(self.trainer.replay_buffer) < self.trainer.batch_size: continue - batch = self.buffer.sample(self.batch_size) - self._algo.update(*batch) + batch = self.trainer.replay_buffer.sample(self.trainer.batch_size) + self.trainer._algo.update(*batch) self._eval_routine(env_step, batch) - self._visualize(env_step) - self._save_policy(env_step) - self._save_buffer(env_step) - self._log_stdout(env_step, 
batch) + self.trainer._save_policy(env_step) + self.trainer._save_buffer(env_step) + self.trainer._log_stdout(env_step, batch) + + self.trainer._logger.log_scalar("trainer/total_cost", total_cost, self.trainer.num_steps) + + def _eval_routine(self, env_step: int, batch): + if env_step % self.trainer.eval_interval == 0: + self._log_evaluation(env_step) - self._logger.log_scalar("trainer/total_cost", total_cost, self.num_steps) + self.trainer._logger.log_scalar("trainer/avg_reward", batch[2].mean(), env_step) + self.trainer._logger.log_scalar( + "trainer/buffer_transitions", len(self.trainer.replay_buffer), env_step + ) + self.trainer._logger.log_scalar( + "trainer/buffer_episodes", self.trainer.replay_buffer.cur_episodes, env_step + ) + self.trainer._logger.log_scalar( + "trainer/buffer_last_ep_len", + self.trainer.replay_buffer.last_episode_length, + env_step, + ) def _log_evaluation(self, env_step: int): returns = [] costs = [] - for i_ep in range(self.num_eval_episodes): - env_test = self._make_env_test(seed=self.seed + i_ep) + for i_ep in range(self.trainer.num_eval_episodes): + env_test = self.trainer._make_env_test(seed=self.trainer.seed + i_ep) state, _ = env_test.reset() episode_return = 0 @@ -121,7 +77,7 @@ def _log_evaluation(self, env_step: int): terminated, truncated = False, False while not (terminated or truncated): - action = self._algo.exploit(state) + action = self.trainer._algo.exploit(state) state, reward, terminated, truncated, info = env_test.step(action) episode_return += reward episode_cost += info["cost"] @@ -129,9 +85,9 @@ def _log_evaluation(self, env_step: int): returns.append(episode_return) costs.append(episode_cost) - self._logger.log_scalar( + self.trainer._logger.log_scalar( "trainer/ep_reward", np.mean(returns, dtype=float), env_step ) - self._logger.log_scalar( + self.trainer._logger.log_scalar( "trainer/ep_cost", np.mean(costs, dtype=float), env_step ) diff --git a/src/oprl/utils/run_training.py b/src/oprl/utils/run_training.py index 1458204..8fdb6a7 100644 --- a/src/oprl/utils/run_training.py +++ b/src/oprl/utils/run_training.py @@ -1,5 +1,6 @@ import logging from multiprocessing import Process +from oprl.env import BaseEnv from oprl.trainers.base_trainer import BaseTrainer from oprl.trainers.safe_trainer import SafeTrainer @@ -7,17 +8,17 @@ def run_training( - make_algo, make_env, make_logger, config, seeds: int = 1, start_seed: int = 0 + make_algo, make_env, make_replay_buffer, make_logger, config, seeds: int = 1, start_seed: int = 0 ): if seeds == 1: - _run_training_func(make_algo, make_env, make_logger, config, 0) + _run_training_func(make_algo, make_env, make_replay_buffer, make_logger, config, 0) else: processes = [] for seed in range(start_seed, start_seed + seeds): processes.append( Process( target=_run_training_func, - args=(make_algo, make_env, make_logger, config, seed), + args=(make_algo, make_env, make_replay_buffer, make_logger, config, seed), ) ) @@ -29,24 +30,18 @@ def run_training( logging.info("Training finished.") -def _run_training_func(make_algo, make_env, make_logger, config, seed: int): +def _run_training_func(make_algo, make_env, make_replay_buffer, make_logger, config, seed: int): set_seed(seed) env = make_env(seed=seed) + replay_buffer = make_replay_buffer() logger = make_logger(seed) + algo = make_algo(logger) - if env.env_family == "dm_control": - trainer_class = BaseTrainer - elif env.env_family == "safety_gymnasium": - trainer_class = SafeTrainer - else: - raise ValueError(f"Unsupported env family: {env.env_family}") - - 
trainer = trainer_class( - state_dim=config["state_dim"], - action_dim=config["action_dim"], + base_trainer = BaseTrainer( env=env, make_env_test=make_env, - algo=make_algo(logger), + algo=algo, + replay_buffer=replay_buffer, num_steps=config["num_steps"], eval_interval=config["eval_every"], device=config["device"], @@ -55,5 +50,11 @@ def _run_training_func(make_algo, make_env, make_logger, config, seed: int): seed=seed, logger=logger, ) + if env.env_family == "dm_control": + trainer = base_trainer + elif env.env_family == "safety_gymnasium": + trainer = SafeTrainer(trainer=base_trainer) + else: + raise ValueError(f"Unsupported env family: {env.env_family}") trainer.train() From f1193d5f9df5f30766892443edf21796d3202862 Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Mon, 7 Jul 2025 02:43:31 +0400 Subject: [PATCH 03/30] Refactor replay buffer --- configs/ddpg.py | 2 +- src/oprl/buffers/episodic_buffer.py | 152 +++++++++++++++------------- src/oprl/trainers/base_trainer.py | 6 +- 3 files changed, 87 insertions(+), 73 deletions(-) diff --git a/configs/ddpg.py b/configs/ddpg.py index 3180549..6c647b5 100644 --- a/configs/ddpg.py +++ b/configs/ddpg.py @@ -58,7 +58,7 @@ def make_replay_buffer() -> ReplayBufferProtocol: action_dim=ACTION_DIM, device=config["device"], gamma=config["gamma"], - ) + ).create() def make_logger(seed: int) -> LoggerProtocol: diff --git a/src/oprl/buffers/episodic_buffer.py b/src/oprl/buffers/episodic_buffer.py index 34733c7..fcf9b3b 100644 --- a/src/oprl/buffers/episodic_buffer.py +++ b/src/oprl/buffers/episodic_buffer.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field import os import pickle from typing import Protocol @@ -22,77 +23,96 @@ def save(self, path: str) -> None: ... def last_episode_length(self) -> int: ... +@dataclass class EpisodicReplayBuffer: - def __init__( - self, - buffer_size: int, - state_dim: int, - action_dim: int, - device: str, - gamma: float, - max_episode_len: int = 1000, - dtype=t.float, - ): - self.buffer_size = buffer_size - self.max_episodes = buffer_size // max_episode_len - self.max_episode_len = max_episode_len - self.state_dim = state_dim - self.action_dim = action_dim - self.device = device - self.gamma = gamma - - self.ep_pointer = 0 - self.cur_episodes = 1 - self.cur_size = 0 - - self.actions = t.empty( - (self.max_episodes, max_episode_len, action_dim), - dtype=dtype, - device=device, - ) - self.rewards = t.empty( - (self.max_episodes, max_episode_len, 1), dtype=dtype, device=device - ) - self.dones = t.empty( - (self.max_episodes, max_episode_len, 1), dtype=dtype, device=device - ) - self.states = t.empty( - (self.max_episodes, max_episode_len + 1, state_dim), - dtype=dtype, - device=device, - ) - self.ep_lens = [0] * self.max_episodes + buffer_size: int + state_dim: int + action_dim: int + gamma: float + max_episode_lenth: int = 1000 + device: str = "cpu" + + _tensors: dict[str, t.Tensor] = field(default_factory=dict, init=False) + _max_episodes: int | None = None + _ep_pointer: int = 0 + episodes_counter: int = 1 + _number_transitions = 0 + _created: bool = False + + def _check_if_created(self) -> None: + if not self._created: + raise RuntimeError("Trying to work with non created buffer. 
Invoke .create() first.") + + def create(self) -> "EpisodicReplayBuffer": + self._max_episodes = self.buffer_size // self.max_episode_lenth + self._tensors = { + "actions": t.empty( + (self._max_episodes, self.max_episode_lenth, self.action_dim), + dtype=t.float32, + device=self.device, + ), + "rewards": t.empty( + (self._max_episodes, self.max_episode_lenth, 1), + dtype=t.float32, + device=self.device + ), + "dones": t.empty( + (self._max_episodes, self.max_episode_lenth, 1), + dtype=t.float32, + device=self.device + ), + "states": t.empty( + (self._max_episodes, self.max_episode_lenth + 1, self.state_dim), + dtype=t.float32, + device=self.device, + ), + } + self.ep_lens = [0] * self._max_episodes + self._created = True + return self - self.actions_for_std = t.empty( - (100, action_dim), dtype=dtype, device=device - ) - self.actions_for_std_cnt = 0 + @property + def states(self) -> t.Tensor: + self._check_if_created() + return self._tensors["states"] + + @property + def actions(self) -> t.Tensor: + self._check_if_created() + return self._tensors["actions"] + + @property + def rewards(self) -> t.Tensor: + self._check_if_created() + return self._tensors["rewards"] + + @property + def dones(self) -> t.Tensor: + self._check_if_created() + return self._tensors["dones"] def add_transition(self, state: npt.ArrayLike, action: npt.ArrayLike, reward: float, done: bool, episode_done: bool | None = None): - self.states[self.ep_pointer, self.ep_lens[self.ep_pointer]].copy_( + self.states[self._ep_pointer, self.ep_lens[self._ep_pointer]].copy_( t.from_numpy(state) ) - self.actions[self.ep_pointer, self.ep_lens[self.ep_pointer]].copy_( + self.actions[self._ep_pointer, self.ep_lens[self._ep_pointer]].copy_( t.from_numpy(action) ) - self.rewards[self.ep_pointer, self.ep_lens[self.ep_pointer]] = float(reward) - self.dones[self.ep_pointer, self.ep_lens[self.ep_pointer]] = float(done) + self.rewards[self._ep_pointer, self.ep_lens[self._ep_pointer]] = float(reward) + self.dones[self._ep_pointer, self.ep_lens[self._ep_pointer]] = float(done) + + self.ep_lens[self._ep_pointer] += 1 - self.actions_for_std[self.actions_for_std_cnt % 100].copy_( - t.from_numpy(action) - ) - self.actions_for_std_cnt += 1 - self.ep_lens[self.ep_pointer] += 1 - self.cur_size = min(self.cur_size + 1, self.buffer_size) + self._number_transitions = min(self._number_transitions + 1, self.buffer_size) if episode_done: self._inc_episode() def _inc_episode(self): - self.ep_pointer = (self.ep_pointer + 1) % self.max_episodes - self.cur_episodes = min(self.cur_episodes + 1, self.max_episodes) - self.cur_size -= self.ep_lens[self.ep_pointer] - self.ep_lens[self.ep_pointer] = 0 + self._ep_pointer = (self._ep_pointer + 1) % self._max_episodes + self.episodes_counter = min(self.episodes_counter + 1, self._max_episodes) + self._number_transitions -= self.ep_lens[self._ep_pointer] + self.ep_lens[self._ep_pointer] = 0 def add_episode(self, episode: list): for s, a, r, d, _ in episode: @@ -103,8 +123,8 @@ def add_episode(self, episode: list): self._inc_episode() def _inds_to_episodic(self, inds): - start_inds = np.cumsum([0] + self.ep_lens[: self.cur_episodes - 1]) - end_inds = start_inds + np.array(self.ep_lens[: self.cur_episodes]) + start_inds = np.cumsum([0] + self.ep_lens[: self.episodes_counter - 1]) + end_inds = start_inds + np.array(self.ep_lens[: self.episodes_counter]) ep_inds = np.argmin( inds.reshape(-1, 1) >= np.tile(end_inds, (len(inds), 1)), axis=1 ) @@ -113,7 +133,7 @@ def _inds_to_episodic(self, inds): return ep_inds, step_inds def 
sample(self, batch_size): - inds = np.random.randint(low=0, high=self.cur_size, size=batch_size) + inds = np.random.randint(low=0, high=self._number_transitions, size=batch_size) ep_inds, step_inds = self._inds_to_episodic(inds) return ( @@ -143,13 +163,9 @@ def save(self, path: str): except Exception as e: print(f"Failed to save replay buffer: {e}") - def __len__(self) -> int: - return self.cur_size - - # @property - # def num_episodes(self): - # return self.cur_episodes - @property def last_episode_length(self): - return self.ep_lens[self.ep_pointer] + return self.ep_lens[self._ep_pointer] + + def __len__(self) -> int: + return self._number_transitions diff --git a/src/oprl/trainers/base_trainer.py b/src/oprl/trainers/base_trainer.py index 2248693..46809cd 100644 --- a/src/oprl/trainers/base_trainer.py +++ b/src/oprl/trainers/base_trainer.py @@ -89,11 +89,11 @@ def _eval_routine(self, env_step: int, batch): "trainer/buffer_transitions", len(self.replay_buffer), env_step ) self._logger.log_scalar( - "trainer/buffer_episodes", self.replay_buffer.num_episodes, env_step + "trainer/buffer_episodes", self.replay_buffer.episodes_counter, env_step ) self._logger.log_scalar( "trainer/buffer_last_ep_len", - self.replay_buffer.get_last_ep_len(), + self.replay_buffer.last_episode_length, env_step, ) @@ -194,8 +194,6 @@ def run_training(make_algo, make_env, make_replay_buffer, make_logger, config: d logger = make_logger(seed) trainer = BaseTrainer( - state_dim=config["state_shape"], - action_dim=config["action_shape"], env=env, make_env_test=make_env, algo=make_algo(logger, seed), From 39eb435d69cd0a9116b1676b21b1be246c16b4cf Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Mon, 7 Jul 2025 21:27:26 +0400 Subject: [PATCH 04/30] Refactor environment module --- src/oprl/buffers/episodic_buffer.py | 16 +- src/oprl/env.py | 233 ---------------------------- src/oprl/trainers/base_trainer.py | 2 +- tests/functional/test_env.py | 2 +- tests/functional/test_rl_algos.py | 2 +- 5 files changed, 7 insertions(+), 248 deletions(-) delete mode 100644 src/oprl/env.py diff --git a/src/oprl/buffers/episodic_buffer.py b/src/oprl/buffers/episodic_buffer.py index fcf9b3b..7263a8f 100644 --- a/src/oprl/buffers/episodic_buffer.py +++ b/src/oprl/buffers/episodic_buffer.py @@ -24,18 +24,18 @@ def last_episode_length(self) -> int: ... 
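
After patch 03 the buffer is a two-phase dataclass: construction only records the configuration, and create() allocates the episode tensors (touching them earlier raises the RuntimeError from _check_if_created). A sketch of the intended call pattern, mirroring make_replay_buffer in configs/ddpg.py (sizes illustrative):

    buf = EpisodicReplayBuffer(
        buffer_size=1_000_000,
        state_dim=24,
        action_dim=6,
        gamma=0.99,
    ).create()   # allocates the states/actions/rewards/dones tensors
    assert len(buf) == 0  # no transitions until add_transition() is called
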
@dataclass -class EpisodicReplayBuffer: +class EpisodicReplayBuffer(ReplayBufferProtocol): buffer_size: int state_dim: int action_dim: int gamma: float max_episode_lenth: int = 1000 device: str = "cpu" + episodes_counter: int = 1 _tensors: dict[str, t.Tensor] = field(default_factory=dict, init=False) - _max_episodes: int | None = None + _max_episodes: int = field(init=False) _ep_pointer: int = 0 - episodes_counter: int = 1 _number_transitions = 0 _created: bool = False @@ -100,13 +100,8 @@ def add_transition(self, state: npt.ArrayLike, action: npt.ArrayLike, reward: fl ) self.rewards[self._ep_pointer, self.ep_lens[self._ep_pointer]] = float(reward) self.dones[self._ep_pointer, self.ep_lens[self._ep_pointer]] = float(done) - self.ep_lens[self._ep_pointer] += 1 - - self._number_transitions = min(self._number_transitions + 1, self.buffer_size) - if episode_done: - self._inc_episode() def _inc_episode(self): self._ep_pointer = (self._ep_pointer + 1) % self._max_episodes @@ -117,10 +112,7 @@ def _inc_episode(self): def add_episode(self, episode: list): for s, a, r, d, _ in episode: self.add_transition(s, a, r, d, episode_done=d) - if d: - break - else: - self._inc_episode() + self._inc_episode() def _inds_to_episodic(self, inds): start_inds = np.cumsum([0] + self.ep_lens[: self.episodes_counter - 1]) diff --git a/src/oprl/env.py b/src/oprl/env.py deleted file mode 100644 index d777ecf..0000000 --- a/src/oprl/env.py +++ /dev/null @@ -1,233 +0,0 @@ -from abc import ABC, abstractmethod -from collections import OrderedDict -from typing import Any - -import numpy as np -import numpy.typing as npt -from dm_control import suite - - -class BaseEnv(ABC): - @abstractmethod - def reset(self) -> tuple[npt.ArrayLike, dict[str, Any]]: - pass - - @abstractmethod - def step( - self, action: npt.ArrayLike - ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]: - pass - - @abstractmethod - def sample_action(self) -> npt.ArrayLike: - pass - - @property - def env_family(self) -> str: - return "" - - -class DummyEnv(BaseEnv): - def reset(self) -> tuple[npt.ArrayLike, dict[str, Any]]: - return np.array([]), {} - - def step( - self, action: npt.ArrayLike - ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]: - return np.array([]), np.array([]), False, False, {} - - def sample_action(self) -> npt.ArrayLike: - return np.array([]) - - @property - def env_family(self) -> str: - return "" - - -class SafetyGym(BaseEnv): - def __init__(self, env_name: str, seed: int): - import safety_gymnasium as gym - - self._env = gym.make(env_name, render_mode='rgb_array', camera_name="fixednear") - self._seed = seed - - def step( - self, action: npt.ArrayLike - ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]: - obs, reward, cost, terminated, truncated, info = self._env.step(action) - info["cost"] = cost - return obs.astype("float32"), reward, terminated, truncated, info - - def reset(self) -> tuple[npt.ArrayLike, dict[str, Any]]: - obs, info = self._env.reset(seed=self._seed) - self._env.step(self._env.action_space.sample()) - return obs.astype("float32"), info - - def render(self) -> npt.NDArray: - return self._env.render() - - def sample_action(self): - return self._env.action_space.sample() - - @property - def observation_space(self): - return self._env.observation_space - - @property - def action_space(self): - return self._env.action_space - - @property - def env_family(self) -> str: - return "safety_gymnasium" - - -class DMControlEnv(BaseEnv): - def __init__(self, env: str, 
seed: int): - domain, task = env.split("-") - self.random_state = np.random.RandomState(seed) - self.env = suite.load(domain, task, task_kwargs={"random": self.random_state}) - - self._render_width = 200 - self._render_height = 200 - self._camera_id = 0 - - def reset(self, *args, **kwargs) -> tuple[npt.ArrayLike, dict[str, Any]]: - obs = self._flat_obs(self.env.reset().observation) - return obs, {} - - def step( - self, action: npt.ArrayLike - ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]: - time_step = self.env.step(action) - obs = self._flat_obs(time_step.observation) - - terminated = False - truncated = self.env._step_count >= self.env._step_limit - - return obs, time_step.reward, terminated, truncated, {} - - def sample_action(self) -> npt.ArrayLike: - spec = self.env.action_spec() - action = self.random_state.uniform(spec.minimum, spec.maximum, spec.shape) - return action - - @property - def observation_space(self) -> npt.ArrayLike: - return np.zeros( - sum(int(np.prod(v.shape)) for v in self.env.observation_spec().values()) - ) - - @property - def action_space(self) -> npt.ArrayLike: - return np.zeros(self.env.action_spec().shape[0]) - - def render(self) -> npt.ArrayLike: - """ - returned shape: [1, W, H, C] - """ - img = self.env.physics.render( - camera_id=self._camera_id, - height=self._render_width, - width=self._render_width, - ) - img = img.astype(np.uint8) - return img - - def _flat_obs(self, obs: OrderedDict) -> npt.ArrayLike: - obs_flatten = [] - for _, o in obs.items(): - if len(o.shape) == 0: - obs_flatten.append(np.array([o])) - elif len(o.shape) == 2 and o.shape[1] > 1: - obs_flatten.append(o.flatten()) - else: - obs_flatten.append(o) - return np.concatenate(obs_flatten, dtype="float32") - - @property - def env_family(self) -> str: - return "dm_control" - - -ENV_MAPPER = { - "dm_control": set( - [ - "acrobot-swingup", - "ball_in_cup-catch", - "cartpole-balance", - "cartpole-swingup", - "cheetah-run", - "finger-spin", - "finger-turn_easy", - "finger-turn_hard", - "fish-upright", - "fish-swim", - "hopper-stand", - "hopper-hop", - "humanoid-stand", - "humanoid-walk", - "humanoid-run", - "pendulum-swingup", - "point_mass-easy", - "reacher-easy", - "reacher-hard", - "swimmer-swimmer6", - "swimmer-swimmer15", - "walker-stand", - "walker-walk", - "walker-run", - ] - ), - "safety_gymnasium": set( - [ - "SafetyPointGoal1-v0", - "SafetyPointGoal2-v0", - "SafetyPointButton1-v0", - "SafetyPointButton2-v0", - "SafetyPointPush1-v0", - "SafetyPointPush2-v0", - "SafetyPointCircle1-v0", - "SafetyPointCircle2-v0", - "SafetyCarGoal1-v0", - "SafetyCarGoal2-v0", - "SafetyCarButton1-v0", - "SafetyCarButton2-v0", - "SafetyCarPush1-v0", - "SafetyCarPush2-v0", - "SafetyCarCircle1-v0", - "SafetyCarCircle2-v0", - "SafetyAntGoal1-v0", - "SafetyAntGoal2-v0", - "SafetyAntButton1-v0", - "SafetyAntButton2-v0", - "SafetyAntPush1-v0", - "SafetyAntPush2-v0", - "SafetyAntCircle1-v0", - "SafetyAntCircle2-v0", - "SafetyDoggoGoal1-v0", - "SafetyDoggoGoal2-v0", - "SafetyDoggoButton1-v0", - "SafetyDoggoButton2-v0", - "SafetyDoggoPush1-v0", - "SafetyDoggoPush2-v0", - "SafetyDoggoCircle1-v0", - "SafetyDoggoCircle2-v0", - ] - ), -} - - -def make_env(name: str, seed: int): - """ - Args: - name: Environment name. 
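
The family mapping being deleted here reappears unchanged in the new oprl.environment package below, and run_training still dispatches on env_family to choose BaseTrainer or SafeTrainer. A minimal sketch from the caller's side (the environment name is taken from ENV_MAPPER above):

    from oprl.environment import make_env

    env = make_env("SafetyPointGoal1-v0", seed=0)
    assert env.env_family == "safety_gymnasium"  # routed to SafeTrainer
    obs, info = env.reset()
    obs, reward, terminated, truncated, info = env.step(env.sample_action())
    cost = info["cost"]  # SafetyGym adds the constraint cost to step info
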
- """ - for env_type, env_set in ENV_MAPPER.items(): - if name in env_set: - if env_type == "dm_control": - return DMControlEnv(name, seed=seed) - elif env_type == "safety_gymnasium": - return SafetyGym(name, seed=seed) - else: - raise ValueError(f"Unsupported environment: {name}") diff --git a/src/oprl/trainers/base_trainer.py b/src/oprl/trainers/base_trainer.py index 46809cd..4a6d0a0 100644 --- a/src/oprl/trainers/base_trainer.py +++ b/src/oprl/trainers/base_trainer.py @@ -5,7 +5,7 @@ import torch from oprl.env import BaseEnv -from oprl.buffers.episodic_buffer import EpisodicReplayBuffer, ReplayBufferProtocol +from oprl.buffers.episodic_buffer import ReplayBufferProtocol from oprl.logging import LoggerProtocol diff --git a/tests/functional/test_env.py b/tests/functional/test_env.py index 1facf87..0ff2bc5 100644 --- a/tests/functional/test_env.py +++ b/tests/functional/test_env.py @@ -1,6 +1,6 @@ import pytest -from oprl.env import make_env +from oprl.environment import make_env dm_control_envs: list[str] = [ diff --git a/tests/functional/test_rl_algos.py b/tests/functional/test_rl_algos.py index 876f505..234d909 100644 --- a/tests/functional/test_rl_algos.py +++ b/tests/functional/test_rl_algos.py @@ -5,7 +5,7 @@ from oprl.algos.sac import SAC from oprl.algos.td3 import TD3 from oprl.algos.tqc import TQC -from oprl.env import DMControlEnv +from oprl.environment import DMControlEnv rl_algo_classes = [DDPG, SAC, TD3, TQC] From bd39d3025c193eec9250d3bb48f687dc9e8e10b6 Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Mon, 7 Jul 2025 21:27:42 +0400 Subject: [PATCH 05/30] Add new environment files --- src/oprl/environment/__init__.py | 8 +++ src/oprl/environment/dm_control.py | 73 +++++++++++++++++++++ src/oprl/environment/make_env.py | 80 ++++++++++++++++++++++++ src/oprl/environment/protocol.py | 35 +++++++++++ src/oprl/environment/safety_gymnasium.py | 42 +++++++++++++ 5 files changed, 238 insertions(+) create mode 100644 src/oprl/environment/__init__.py create mode 100644 src/oprl/environment/dm_control.py create mode 100644 src/oprl/environment/make_env.py create mode 100644 src/oprl/environment/protocol.py create mode 100644 src/oprl/environment/safety_gymnasium.py diff --git a/src/oprl/environment/__init__.py b/src/oprl/environment/__init__.py new file mode 100644 index 0000000..26f88ef --- /dev/null +++ b/src/oprl/environment/__init__.py @@ -0,0 +1,8 @@ +from oprl.environment.protocol import EnvProtocol +from oprl.environment.dm_control import DMControlEnv +from oprl.environment.safety_gymnasium import SafetyGym +from oprl.environment.make_env import make_env + +___all__ = ['DMControlEnv', 'SafetyGym', "make_env", "EnvProtocol"] + + diff --git a/src/oprl/environment/dm_control.py b/src/oprl/environment/dm_control.py new file mode 100644 index 0000000..2ce564f --- /dev/null +++ b/src/oprl/environment/dm_control.py @@ -0,0 +1,73 @@ +from collections import OrderedDict +from typing import Any + +import numpy as np +import numpy.typing as npt +from dm_control import suite + +from oprl.environment import EnvProtocol + + +class DMControlEnv(EnvProtocol): + def __init__(self, env: str, seed: int): + domain, task = env.split("-") + self.random_state = np.random.RandomState(seed) + self.env = suite.load(domain, task, task_kwargs={"random": self.random_state}) + + self._render_width = 200 + self._render_height = 200 + self._camera_id = 0 + + def reset(self, *args, **kwargs) -> tuple[npt.ArrayLike, dict[str, Any]]: + obs = self._flat_obs(self.env.reset().observation) + return obs, {} + + def 
+        self, action: npt.ArrayLike
+    ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]:
+        time_step = self.env.step(action)
+        obs = self._flat_obs(time_step.observation)
+
+        terminated = False
+        truncated = self.env._step_count >= self.env._step_limit
+
+        return obs, time_step.reward, terminated, truncated, {}
+
+    def sample_action(self) -> npt.ArrayLike:
+        spec = self.env.action_spec()
+        action = self.random_state.uniform(spec.minimum, spec.maximum, spec.shape)
+        return action
+
+    @property
+    def observation_space(self) -> npt.ArrayLike:
+        return np.zeros(
+            sum(int(np.prod(v.shape)) for v in self.env.observation_spec().values())
+        )
+
+    @property
+    def action_space(self) -> npt.ArrayLike:
+        return np.zeros(self.env.action_spec().shape[0])
+
+    def render(self) -> npt.ArrayLike:  # [H, W, C]
+        img = self.env.physics.render(
+            camera_id=self._camera_id,
+            height=self._render_height,
+            width=self._render_width,
+        )
+        img = img.astype(np.uint8)
+        return img
+
+    def _flat_obs(self, obs: OrderedDict) -> npt.ArrayLike:
+        obs_flatten = []
+        for _, o in obs.items():
+            if len(o.shape) == 0:
+                obs_flatten.append(np.array([o]))
+            elif len(o.shape) == 2 and o.shape[1] > 1:
+                obs_flatten.append(o.flatten())
+            else:
+                obs_flatten.append(o)
+        return np.concatenate(obs_flatten, dtype="float32")
+
+    @property
+    def env_family(self) -> str:
+        return "dm_control"
diff --git a/src/oprl/environment/make_env.py b/src/oprl/environment/make_env.py
new file mode 100644
index 0000000..f4f9f12
--- /dev/null
+++ b/src/oprl/environment/make_env.py
@@ -0,0 +1,80 @@
+from oprl.environment import EnvProtocol, DMControlEnv, SafetyGym
+
+
+ENV_MAPPER = {
+    "dm_control": set(
+        [
+            "acrobot-swingup",
+            "ball_in_cup-catch",
+            "cartpole-balance",
+            "cartpole-swingup",
+            "cheetah-run",
+            "finger-spin",
+            "finger-turn_easy",
+            "finger-turn_hard",
+            "fish-upright",
+            "fish-swim",
+            "hopper-stand",
+            "hopper-hop",
+            "humanoid-stand",
+            "humanoid-walk",
+            "humanoid-run",
+            "pendulum-swingup",
+            "point_mass-easy",
+            "reacher-easy",
+            "reacher-hard",
+            "swimmer-swimmer6",
+            "swimmer-swimmer15",
+            "walker-stand",
+            "walker-walk",
+            "walker-run",
+        ]
+    ),
+    "safety_gymnasium": set(
+        [
+            "SafetyPointGoal1-v0",
+            "SafetyPointGoal2-v0",
+            "SafetyPointButton1-v0",
+            "SafetyPointButton2-v0",
+            "SafetyPointPush1-v0",
+            "SafetyPointPush2-v0",
+            "SafetyPointCircle1-v0",
+            "SafetyPointCircle2-v0",
+            "SafetyCarGoal1-v0",
+            "SafetyCarGoal2-v0",
+            "SafetyCarButton1-v0",
+            "SafetyCarButton2-v0",
+            "SafetyCarPush1-v0",
+            "SafetyCarPush2-v0",
+            "SafetyCarCircle1-v0",
+            "SafetyCarCircle2-v0",
+            "SafetyAntGoal1-v0",
+            "SafetyAntGoal2-v0",
+            "SafetyAntButton1-v0",
+            "SafetyAntButton2-v0",
+            "SafetyAntPush1-v0",
+            "SafetyAntPush2-v0",
+            "SafetyAntCircle1-v0",
+            "SafetyAntCircle2-v0",
+            "SafetyDoggoGoal1-v0",
+            "SafetyDoggoGoal2-v0",
+            "SafetyDoggoButton1-v0",
+            "SafetyDoggoButton2-v0",
+            "SafetyDoggoPush1-v0",
+            "SafetyDoggoPush2-v0",
+            "SafetyDoggoCircle1-v0",
+            "SafetyDoggoCircle2-v0",
+        ]
+    ),
+}
+
+
+def make_env(name: str, seed: int) -> EnvProtocol:
+    for env_type, env_set in ENV_MAPPER.items():
+        if name in env_set:
+            if env_type == "dm_control":
+                return DMControlEnv(name, seed=seed)
+            elif env_type == "safety_gymnasium":
+                return SafetyGym(name, seed=seed)
+    else:
+        raise ValueError(f"Unsupported environment: {name}")
diff --git a/src/oprl/environment/protocol.py b/src/oprl/environment/protocol.py
new file mode 100644
index 0000000..3f1a585
--- /dev/null
+++ b/src/oprl/environment/protocol.py
@@ -0,0 +1,35 @@
+from typing
import Protocol, Any + +import numpy.typing as npt + + +class EnvProtocol(Protocol): + def __init__(self, env_name: str, seed: int) -> None: + ... + + def step( + self, action: npt.ArrayLike + ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]: + ... + + def reset(self) -> tuple[npt.ArrayLike, dict[str, Any]]: + ... + + def sample_action(self) -> npt.ArrayLike: + ... + + def render(self) -> npt.ArrayLike: + ... + + @property + def observation_space(self) -> npt.ArrayLike: + ... + + @property + def action_space(self) -> npt.ArrayLike: + ... + + @property + def env_family(self) -> str: + ... + diff --git a/src/oprl/environment/safety_gymnasium.py b/src/oprl/environment/safety_gymnasium.py new file mode 100644 index 0000000..0d31a31 --- /dev/null +++ b/src/oprl/environment/safety_gymnasium.py @@ -0,0 +1,42 @@ +import numpy.typing as npt +from typing import Any + +from oprl.environment import EnvProtocol + + +class SafetyGym(EnvProtocol): + def __init__(self, env_name: str, seed: int) -> None: + import safety_gymnasium as gym + self._env = gym.make(env_name, render_mode='rgb_array', camera_name="fixednear") + self._seed = seed + + def step( + self, action: npt.ArrayLike + ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]: + obs, reward, cost, terminated, truncated, info = self._env.step(action) + info["cost"] = cost + return obs.astype("float32"), reward, terminated, truncated, info + + def reset(self) -> tuple[npt.ArrayLike, dict[str, Any]]: + obs, info = self._env.reset(seed=self._seed) + self._env.step(self._env.action_space.sample()) + return obs.astype("float32"), info + + def sample_action(self): + return self._env.action_space.sample() + + def render(self) -> npt.ArrayLike: + return self._env.render() + + @property + def observation_space(self) -> npt.ArrayLike: + return self._env.observation_space + + @property + def action_space(self) -> npt.ArrayLike: + return self._env.action_space + + @property + def env_family(self) -> str: + return "safety_gymnasium" + From 6feea32a230b69838af392d1a7e893a6959369e5 Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Mon, 7 Jul 2025 22:29:10 +0400 Subject: [PATCH 06/30] Remove utils folder --- configs/ddpg.py | 4 +- configs/sac.py | 4 +- configs/td3.py | 4 +- configs/tqc.py | 4 +- src/oprl/algos/ddpg.py | 4 +- src/oprl/buffers/episodic_buffer.py | 3 + src/oprl/{utils/run_training.py => train.py} | 11 ++- src/oprl/trainers/base_trainer.py | 6 +- src/oprl/trainers/safe_trainer.py | 4 - src/oprl/utils/config.py | 10 --- src/oprl/utils/logger.py | 89 -------------------- src/oprl/utils/utils.py | 85 ------------------- 12 files changed, 25 insertions(+), 203 deletions(-) rename src/oprl/{utils/run_training.py => train.py} (92%) delete mode 100644 src/oprl/utils/config.py delete mode 100644 src/oprl/utils/logger.py delete mode 100644 src/oprl/utils/utils.py diff --git a/configs/ddpg.py b/configs/ddpg.py index 6c647b5..e173a07 100644 --- a/configs/ddpg.py +++ b/configs/ddpg.py @@ -9,8 +9,8 @@ LoggerProtocol ) set_logging() -from oprl.env import make_env as _make_env -from oprl.utils.run_training import run_training +from oprl.environment import make_env as _make_env +from oprl.train import run_training args = parse_args() diff --git a/configs/sac.py b/configs/sac.py index b9c09f4..3165a03 100644 --- a/configs/sac.py +++ b/configs/sac.py @@ -8,8 +8,8 @@ LoggerProtocol ) set_logging() -from oprl.env import make_env as _make_env -from oprl.utils.run_training import run_training +from oprl.environment import make_env as 
_make_env +from oprl.train import run_training args = parse_args() diff --git a/configs/td3.py b/configs/td3.py index 6a996be..7c77d66 100644 --- a/configs/td3.py +++ b/configs/td3.py @@ -8,8 +8,8 @@ LoggerProtocol ) set_logging() -from oprl.env import make_env as _make_env -from oprl.utils.run_training import run_training +from oprl.environment import make_env as _make_env +from oprl.train import run_training args = parse_args() diff --git a/configs/tqc.py b/configs/tqc.py index 13ccc5b..6cf045d 100644 --- a/configs/tqc.py +++ b/configs/tqc.py @@ -8,8 +8,8 @@ LoggerProtocol ) set_logging() -from oprl.env import make_env as _make_env -from oprl.utils.run_training import run_training +from oprl.environment import make_env as _make_env +from oprl.train import run_training args = parse_args() diff --git a/src/oprl/algos/ddpg.py b/src/oprl/algos/ddpg.py index 714a166..bf08413 100644 --- a/src/oprl/algos/ddpg.py +++ b/src/oprl/algos/ddpg.py @@ -10,20 +10,20 @@ from oprl.algos import OffPolicyAlgorithm from oprl.algos.nn import Critic, DeterministicPolicy from oprl.algos.utils import disable_gradient -from oprl.utils.logger import Logger, StdLogger +from oprl.logging import LoggerProtocol @dataclass class DDPG(OffPolicyAlgorithm): state_dim: int action_dim: int + logger: LoggerProtocol expl_noise: float = 0.1 discount: float = 0.99 tau: float = 5e-3 batch_size: int = 256 max_action: float = 1. device: str = "cpu" - logger: Logger = StdLogger() def create(self) -> "DDPG": self.actor = DeterministicPolicy( diff --git a/src/oprl/buffers/episodic_buffer.py b/src/oprl/buffers/episodic_buffer.py index 7263a8f..5eef9c3 100644 --- a/src/oprl/buffers/episodic_buffer.py +++ b/src/oprl/buffers/episodic_buffer.py @@ -102,6 +102,9 @@ def add_transition(self, state: npt.ArrayLike, action: npt.ArrayLike, reward: fl self.dones[self._ep_pointer, self.ep_lens[self._ep_pointer]] = float(done) self.ep_lens[self._ep_pointer] += 1 self._number_transitions = min(self._number_transitions + 1, self.buffer_size) + # TODO: Switch to the episodic append and remove condition below + if episode_done: + self._inc_episode() def _inc_episode(self): self._ep_pointer = (self._ep_pointer + 1) % self._max_episodes diff --git a/src/oprl/utils/run_training.py b/src/oprl/train.py similarity index 92% rename from src/oprl/utils/run_training.py rename to src/oprl/train.py index 8fdb6a7..374a74e 100644 --- a/src/oprl/utils/run_training.py +++ b/src/oprl/train.py @@ -1,10 +1,17 @@ import logging +import random +import numpy as np +import torch as t from multiprocessing import Process -from oprl.env import BaseEnv from oprl.trainers.base_trainer import BaseTrainer from oprl.trainers.safe_trainer import SafeTrainer -from oprl.utils.utils import set_seed + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + t.manual_seed(seed) def run_training( diff --git a/src/oprl/trainers/base_trainer.py b/src/oprl/trainers/base_trainer.py index 4a6d0a0..3a3693f 100644 --- a/src/oprl/trainers/base_trainer.py +++ b/src/oprl/trainers/base_trainer.py @@ -4,7 +4,7 @@ import numpy as np import torch -from oprl.env import BaseEnv +from oprl.environment import EnvProtocol from oprl.buffers.episodic_buffer import ReplayBufferProtocol from oprl.logging import LoggerProtocol @@ -13,8 +13,8 @@ class BaseTrainer: def __init__( self, - env: BaseEnv, - make_env_test: Callable[[int], BaseEnv], + env: EnvProtocol, + make_env_test: Callable[[int], EnvProtocol], replay_buffer: ReplayBufferProtocol, algo: Any | None = None, gamma: float = 0.99, 
diff --git a/src/oprl/trainers/safe_trainer.py b/src/oprl/trainers/safe_trainer.py index 2b3c33f..456d47a 100644 --- a/src/oprl/trainers/safe_trainer.py +++ b/src/oprl/trainers/safe_trainer.py @@ -1,10 +1,6 @@ -from typing import Any, Callable - import numpy as np -from oprl.env import BaseEnv from oprl.trainers.base_trainer import BaseTrainer -from oprl.logging import LoggerProtocol class SafeTrainer: diff --git a/src/oprl/utils/config.py b/src/oprl/utils/config.py deleted file mode 100644 index 95bc38e..0000000 --- a/src/oprl/utils/config.py +++ /dev/null @@ -1,10 +0,0 @@ -import importlib.util -import sys - - -def load_config(path: str): - spec = importlib.util.spec_from_file_location("config", path) - config = importlib.util.module_from_spec(spec) - sys.modules["config"] = config - spec.loader.exec_module(config) - return config diff --git a/src/oprl/utils/logger.py b/src/oprl/utils/logger.py deleted file mode 100644 index c78175a..0000000 --- a/src/oprl/utils/logger.py +++ /dev/null @@ -1,89 +0,0 @@ -import json -import logging -import os -import shutil -from abc import ABC, abstractmethod -from sys import path -from typing import Any - -import numpy as np -import torch -import torch.nn as nn -from torch.utils.tensorboard.writer import SummaryWriter - - -def copy_exp_dir(log_dir: str) -> None: - cur_dir = os.path.join(os.getcwd(), "src") - dest_dir = os.path.join(log_dir, "src") - shutil.copytree(cur_dir, dest_dir) - logging.info(f"Source copied into {dest_dir}") - - -def save_json_config(config: dict[str, Any], path: str): - with open(path, "w") as f: - json.dump(config, f) - - -class Logger(ABC): - - def log_scalars(self, values: dict[str, float], step: int): - """ - Args: - values: Dict with tag -> value to log. - step: Iter step. - """ - (self.log_scalar(k, v, step) for k, v in values.items()) - - @abstractmethod - def log_scalar(self, tag: str, value: float, step: int): - logging.info(f"{tag}\t{value}\tat step {step}") - - @abstractmethod - def log_video(self, tag: str, imgs, step: int): - logging.warning("Skipping logging video in STDOUT logger") - - -class StdLogger(Logger): - def __init__(self, *args, **kwargs): - pass - - def log_scalar(self, tag: str, value: float, step: int): - logging.info(f"{tag}\t{value}\tat step {step}") - - def log_video(self, *args, **kwargs): - logging.warning("Skipping logging video in STDOUT logger") - - -class FileLogger(Logger): - def __init__(self, logdir: str, config: dict[str, Any]): - self.writer = SummaryWriter(logdir) - - self._log_dir = logdir - - logging.info(f"Source code is copied to {logdir}") - copy_exp_dir(logdir) - save_json_config(config, os.path.join(logdir, "config.json")) - - def log_scalar(self, tag: str, value: float, step: int) -> None: - self.writer.add_scalar(tag, value, step) - self._log_scalar_to_file(tag, value, step) - - def log_video(self, tag: str, imgs, step: int) -> None: - os.makedirs(os.path.join(self._log_dir, "images"), exist_ok=True) - fn = os.path.join(self._log_dir, "images", f"{tag}_step_{step}.npz") - with open(fn, "wb") as f: - np.save(f, imgs) - - def save_weights(self, weights: nn.Module, step: int) -> None: - os.makedirs(os.path.join(self._log_dir, "weights"), exist_ok=True) - fn = os.path.join(self._log_dir, "weights", f"step_{step}.w") - torch.save( - weights, - fn - ) - - def _log_scalar_to_file(self, tag: str, val: float, step: int) -> None: - fn = os.path.join(self._log_dir, f"{tag}.log") - os.makedirs(os.path.dirname(fn), exist_ok=True) - with open(fn, "a") as f: - f.write(f"{step} {val}\n") diff 
--git a/src/oprl/utils/utils.py b/src/oprl/utils/utils.py deleted file mode 100644 index d168b0d..0000000 --- a/src/oprl/utils/utils.py +++ /dev/null @@ -1,85 +0,0 @@ -import logging -import os -import random -import shutil -import sys -from glob import glob - -import imageio -import numpy as np -import torch as t - - -class OUNoise(object): - def __init__( - self, - dim, - low, - high, - mu=0.0, - theta=0.15, - max_sigma=0.3, - min_sigma=0.3, - decay_period=10_000, - ): - self.mu = mu - self.theta = theta - self.sigma = max_sigma - self.max_sigma = max_sigma - self.min_sigma = min_sigma - self.decay_period = decay_period - self.action_dim = dim - self.low = low - self.high = high - - def reset(self): - self.state = np.ones(self.action_dim) * self.mu - - def evolve_state(self): - x = self.state - dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.action_dim) - self.state = x + dx - return self.state - - def get_action(self, action, t=0): - ou_state = self.evolve_state() - self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * min( - 1.0, t / self.decay_period - ) - action = action.cpu().detach().numpy() - return np.clip(action + ou_state, self.low, self.high) - - -def make_gif(source_dir, output): - """ - Make gif file from set of .jpeg images. - Args: - source_dir (str): path with .jpeg images - output (str): path to the output .gif file - Returns: None - """ - batch_sort = lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")]) - image_paths = sorted(glob(os.path.join(source_dir, "*.png")), key=batch_sort) - - images = [] - for filename in image_paths: - images.append(imageio.imread(filename)) - imageio.mimsave(output, images) - - -def empty_torch_queue(q): - while True: - try: - o = q.get_nowait() - del o - except: - break - q.close() - - -def set_seed(seed: int) -> None: - random.seed(seed) - np.random.seed(seed) - t.manual_seed(seed) - - From 3880f027257930569eae5a51a35654521d7dbb45 Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Sat, 12 Jul 2025 16:02:16 +0400 Subject: [PATCH 07/30] Make trainers dataclasses --- src/oprl/trainers/base_trainer.py | 153 +++++++++++++----------------- src/oprl/trainers/safe_trainer.py | 67 ++++++------- 2 files changed, 103 insertions(+), 117 deletions(-) diff --git a/src/oprl/trainers/base_trainer.py b/src/oprl/trainers/base_trainer.py index 3a3693f..23a6f11 100644 --- a/src/oprl/trainers/base_trainer.py +++ b/src/oprl/trainers/base_trainer.py @@ -1,71 +1,58 @@ from dataclasses import dataclass -from typing import Any, Callable +from typing import Any, Callable, Protocol import numpy as np import torch +from oprl.algos import OffPolicyAlgorithm from oprl.environment import EnvProtocol from oprl.buffers.episodic_buffer import ReplayBufferProtocol from oprl.logging import LoggerProtocol +class TrainerProtocol(Protocol): + def train(self) -> None: ... + + def evaluate(self) -> dict[str, float]: ... 
+ + @dataclass -class BaseTrainer: - def __init__( - self, - env: EnvProtocol, - make_env_test: Callable[[int], EnvProtocol], - replay_buffer: ReplayBufferProtocol, - algo: Any | None = None, - gamma: float = 0.99, - num_steps: int = int(1e6), - start_steps: int = int(10e3), - batch_size: int = 128, - eval_interval: int = int(2e3), - num_eval_episodes: int = 10, - save_buffer_every: int = 0, - save_policy_every: int = int(100_000), - estimate_q_every: int = 0, - stdout_log_every: int = int(1e5), - device: str = "cpu", - seed: int = 0, - logger: LoggerProtocol | None = None, - ): - self._env = env - self._make_env_test = make_env_test - self._algo = algo - self._gamma = gamma - self._device = device - self._save_buffer_every = save_buffer_every - self._estimate_q_every = estimate_q_every - self._stdout_log_every = stdout_log_every - self._save_policy_every= save_policy_every - self._logger = logger - self.seed = seed - self.replay_buffer = replay_buffer - self.batch_size = batch_size - self.num_steps = num_steps - self.start_steps = start_steps - self.eval_interval = eval_interval - self.num_eval_episodes = num_eval_episodes - - def train(self): +class BaseTrainer(TrainerProtocol): + env: EnvProtocol + make_env_test: Callable[[int], EnvProtocol] + replay_buffer: ReplayBufferProtocol + algo: OffPolicyAlgorithm | None = None + gamma: float = 0.99 + num_steps: int = int(1e6) + start_steps: int = int(10e3) + batch_size: int = 128 + eval_interval: int = int(2e3) + num_eval_episodes: int = 10 + save_buffer_every: int = 0 + save_policy_every: int = int(100_000) + estimate_q_every: int = 0 + stdout_log_every: int = int(1e5) + device: str = "cpu" + seed: int = 0 + logger: LoggerProtocol | None = None + + def train(self) -> None: ep_step = 0 - state, _ = self._env.reset() + state, _ = self.env.reset() for env_step in range(self.num_steps + 1): ep_step += 1 if env_step <= self.start_steps: - action = self._env.sample_action() + action = self.env.sample_action() else: - action = self._algo.explore(state) - next_state, reward, terminated, truncated, _ = self._env.step(action) + action = self.algo.explore(state) + next_state, reward, terminated, truncated, _ = self.env.step(action) self.replay_buffer.add_transition( state, action, reward, terminated, episode_done=terminated or truncated ) if terminated or truncated: - next_state, _ = self._env.reset() + next_state, _ = self.env.reset() ep_step = 0 state = next_state @@ -73,71 +60,67 @@ def train(self): continue batch = self.replay_buffer.sample(self.batch_size) - self._algo.update(*batch) + self.algo.update(*batch) - self._eval_routine(env_step, batch) - self._save_buffer(env_step) + self._log_evaluation(env_step, batch) self._save_policy(env_step) self._log_stdout(env_step, batch) - def _eval_routine(self, env_step: int, batch): + def _log_evaluation(self, env_step: int, batch): if env_step % self.eval_interval == 0: - self._log_evaluation(env_step) + eval_metrics = self.evaluate() + self.logger.log_scalar("trainer/ep_reward", eval_metrics["return"], env_step) - self._logger.log_scalar("trainer/avg_reward", batch[2].mean(), env_step) - self._logger.log_scalar( + self.logger.log_scalar("trainer/avg_reward", batch[2].mean(), env_step) + self.logger.log_scalar( "trainer/buffer_transitions", len(self.replay_buffer), env_step ) - self._logger.log_scalar( + self.logger.log_scalar( "trainer/buffer_episodes", self.replay_buffer.episodes_counter, env_step ) - self._logger.log_scalar( + self.logger.log_scalar( "trainer/buffer_last_ep_len", 
self.replay_buffer.last_episode_length, env_step, ) - def _log_evaluation(self, env_step: int): + def evaluate(self) -> dict[str, float]: returns = [] for i_ep in range(self.num_eval_episodes): - env_test = self._make_env_test(seed=self.seed + i_ep) + env_test = self.make_env_test(seed=self.seed + i_ep) state, _ = env_test.reset() episode_return = 0.0 terminated, truncated = False, False while not (terminated or truncated): - action = self._algo.exploit(state) + action = self.algo.exploit(state) state, reward, terminated, truncated, _ = env_test.step(action) episode_return += reward returns.append(episode_return) - mean_return = np.mean(returns) - self._logger.log_scalar("trainer/ep_reward", mean_return, env_step) - - def _save_buffer(self, env_step: int): - # TODO: doesn't work - if self._save_buffer_every > 0 and env_step % self._save_buffer_every == 0: - self.replay_buffer.save(f"{self.log_dir}/buffers/buffer_step_{env_step}.pickle") + return { + "return": float(np.mean(returns)) + } def _save_policy(self, env_step: int): - if self._save_policy_every > 0 and env_step % self._save_policy_every == 0: - self._logger.save_weights(self._algo.actor, env_step) + if self.save_policy_every > 0 and env_step % self.save_policy_every == 0: + self.logger.save_weights(self.algo.actor, env_step) def _estimate_q(self, env_step: int): - if self._estimate_q_every > 0 and env_step % self._estimate_q_every == 0: + if self.estimate_q_every > 0 and env_step % self.estimate_q_every == 0: q_true = self.estimate_true_q() q_critic = self.estimate_critic_q() if q_true is not None: - self._logger.log_scalar("trainer/Q-estimate", q_true, env_step) - self._logger.log_scalar("trainer/Q-critic", q_critic, env_step) - self._logger.log_scalar( + self.logger.log_scalar("trainer/Q-estimate", q_true, env_step) + self.logger.log_scalar("trainer/Q-critic", q_critic, env_step) + self.logger.log_scalar( "trainer/Q_asb_diff", q_critic - q_true, env_step ) def _log_stdout(self, env_step: int, batch): - if env_step % self._stdout_log_every == 0: + if env_step % self.stdout_log_every == 0: perc = int(env_step / self.num_steps * 100) print( f"Env step {env_step:8d} ({perc:2d}%) Avg Reward {batch[2].mean():10.3f}" @@ -147,16 +130,16 @@ def estimate_true_q(self, eval_episodes: int = 10) -> float | None: try: qs = [] for i_eval in range(eval_episodes): - env = self._make_env_test(seed=self.seed * 100 + i_eval) + env = self.make_env_test(seed=self.seed * 100 + i_eval) print("Before reset etimate q") state, _ = env.reset() q = 0 s_i = 1 while True: - action = self._algo.exploit(state) + action = self.algo.exploit(state) state, r, terminated, truncated, _ = env.step(action) - q += r * self._gamma**s_i + q += r * self.gamma ** s_i s_i += 1 if terminated or truncated: break @@ -171,15 +154,15 @@ def estimate_true_q(self, eval_episodes: int = 10) -> float | None: def estimate_critic_q(self, num_episodes: int = 10) -> float: qs = [] for i_eval in range(num_episodes): - env = self._make_env_test(seed=self.seed * 100 + i_eval) + env = self.make_env_test(seed=self.seed * 100 + i_eval) state, _ = env.reset() - action = self._algo.exploit(state) + action = self.algo.exploit(state) - state = torch.tensor(state).unsqueeze(0).float().to(self._device) - action = torch.tensor(action).unsqueeze(0).float().to(self._device) + state = torch.tensor(state).unsqueeze(0).float().to(self.device) + action = torch.tensor(action).unsqueeze(0).float().to(self.device) - q = self._algo.critic(state, action) + q = self.algo.critic(state, action) # TODO: TQC is not 
supported by this logic, need to update
         if isinstance(q, tuple):
             q = q[0]
@@ -189,14 +172,14 @@ def estimate_critic_q(self, num_episodes: int = 10) -> float:
     return np.mean(qs, dtype=float)
 
 
 def run_training(make_algo, make_env, make_replay_buffer, make_logger, config: dict[str, Any], seed: int):
     env = make_env(seed=seed)
     logger = make_logger(seed)
 
     trainer = BaseTrainer(
         env=env,
         make_env_test=make_env,
         algo=make_algo(logger, seed),
         replay_buffer=make_replay_buffer(),
         num_steps=config["num_steps"],
         eval_interval=config["eval_every"],
diff --git a/src/oprl/trainers/safe_trainer.py b/src/oprl/trainers/safe_trainer.py
index 456d47a..8124742 100644
--- a/src/oprl/trainers/safe_trainer.py
+++ b/src/oprl/trainers/safe_trainer.py
@@ -1,71 +1,75 @@
+from dataclasses import dataclass
+
 import numpy as np
 
-from oprl.trainers.base_trainer import BaseTrainer
+from oprl.trainers.base_trainer import BaseTrainer, TrainerProtocol
 
 
-class SafeTrainer:
-    def __init__(
-        self,
-        trainer: BaseTrainer
-    ):
-        self.trainer = trainer
+@dataclass
+class SafeTrainer(TrainerProtocol):
+    trainer: BaseTrainer
 
     def train(self):
         ep_step = 0
-        state, _ = self.trainer._env.reset()
+        state, _ = self.trainer.env.reset()
 
         total_cost = 0
         for env_step in range(self.trainer.num_steps + 1):
             ep_step += 1
 
             if env_step <= self.trainer.start_steps:
-                action = self.trainer._env.sample_action()
+                action = self.trainer.env.sample_action()
             else:
-                action = self.trainer._algo.explore(state)
-            next_state, reward, terminated, truncated, info = self.trainer._env.step(action)
+                action = self.trainer.algo.explore(state)
+            next_state, reward, terminated, truncated, info = self.trainer.env.step(action)
             total_cost += info["cost"]
 
             self.trainer.replay_buffer.add_transition(
                 state, action, reward, terminated, episode_done=terminated or truncated
             )
 
             if terminated or truncated:
-                next_state, _ = self.trainer._env.reset()
+                next_state, _ = self.trainer.env.reset()
                 ep_step = 0
             state = next_state
 
             if len(self.trainer.replay_buffer) < self.trainer.batch_size:
                 continue
 
             batch = self.trainer.replay_buffer.sample(self.trainer.batch_size)
-            self.trainer._algo.update(*batch)
+            self.trainer.algo.update(*batch)
 
-            self._eval_routine(env_step, batch)
+            self._log_evaluation(env_step, batch)
             self.trainer._save_policy(env_step)
-            self.trainer._save_buffer(env_step)
             self.trainer._log_stdout(env_step, batch)
 
-        self.trainer._logger.log_scalar("trainer/total_cost", total_cost, self.trainer.num_steps)
+        self.trainer.logger.log_scalar("trainer/total_cost", total_cost, self.trainer.num_steps)
 
-    def _eval_routine(self, env_step: int, batch):
+    def _log_evaluation(self, env_step: int, batch) -> None:
         if env_step % self.trainer.eval_interval == 0:
-            self._log_evaluation(env_step)
+            eval_metrics = self.evaluate()
+            self.trainer.logger.log_scalar(
+                "trainer/ep_reward", eval_metrics["return"], env_step
+            )
+            self.trainer.logger.log_scalar(
+                "trainer/ep_cost", eval_metrics["cost"], env_step
+            )
 
-        self.trainer._logger.log_scalar("trainer/avg_reward", batch[2].mean(), env_step)
-        self.trainer._logger.log_scalar(
+        self.trainer.logger.log_scalar("trainer/avg_reward", batch[2].mean(), env_step)
+        self.trainer.logger.log_scalar(
             "trainer/buffer_transitions", len(self.trainer.replay_buffer), env_step
         )
-        self.trainer._logger.log_scalar(
-            "trainer/buffer_episodes",
self.trainer.replay_buffer.cur_episodes, env_step + self.trainer.logger.log_scalar( + "trainer/buffer_episodes", self.trainer.replay_buffer.episodes_counter, env_step ) - self.trainer._logger.log_scalar( + self.trainer.logger.log_scalar( "trainer/buffer_last_ep_len", self.trainer.replay_buffer.last_episode_length, env_step, ) - def _log_evaluation(self, env_step: int): + def evaluate(self) -> dict[str, float]: returns = [] costs = [] for i_ep in range(self.trainer.num_eval_episodes): - env_test = self.trainer._make_env_test(seed=self.trainer.seed + i_ep) + env_test = self.trainer.make_env_test(seed=self.trainer.seed + i_ep) state, _ = env_test.reset() episode_return = 0 @@ -73,7 +77,7 @@ def _log_evaluation(self, env_step: int): terminated, truncated = False, False while not (terminated or truncated): - action = self.trainer._algo.exploit(state) + action = self.trainer.algo.exploit(state) state, reward, terminated, truncated, info = env_test.step(action) episode_return += reward episode_cost += info["cost"] @@ -81,9 +85,8 @@ def _log_evaluation(self, env_step: int): returns.append(episode_return) costs.append(episode_cost) - self.trainer._logger.log_scalar( - "trainer/ep_reward", np.mean(returns, dtype=float), env_step - ) - self.trainer._logger.log_scalar( - "trainer/ep_cost", np.mean(costs, dtype=float), env_step - ) + return { + "return": float(np.mean(returns)), + "cost": float(np.mean(costs)), + } + From faa8d0d36475c58808467485f7ee9b6e140a2963 Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Sat, 12 Jul 2025 20:28:38 +0400 Subject: [PATCH 08/30] Add option for storing logs outside src folder, make logs dir creation common code --- configs/ddpg.py | 15 +++++++-------- configs/sac.py | 15 +++++++-------- configs/td3.py | 15 +++++++-------- configs/tqc.py | 15 +++++++-------- src/oprl/logging.py | 24 +++++++++++++++++------- 5 files changed, 45 insertions(+), 39 deletions(-) diff --git a/configs/ddpg.py b/configs/ddpg.py index e173a07..f030391 100644 --- a/configs/ddpg.py +++ b/configs/ddpg.py @@ -3,10 +3,9 @@ from oprl.buffers.episodic_buffer import ReplayBufferProtocol, EpisodicReplayBuffer from oprl.parse_args import parse_args from oprl.logging import ( - create_logdir, set_logging, - FileTxtLogger, - LoggerProtocol + LoggerProtocol, + make_text_logger_func, ) set_logging() from oprl.environment import make_env as _make_env @@ -61,11 +60,11 @@ def make_replay_buffer() -> ReplayBufferProtocol: ).create() -def make_logger(seed: int) -> LoggerProtocol: - log_dir = create_logdir(logdir="logs", algo="DDPG", env=args.env, seed=seed) - logger = FileTxtLogger(log_dir, config) - logger.copy_source_code() - return logger +make_logger = make_text_logger_func( + config=config, + algo="DDPG", + env=args.env, +) if __name__ == "__main__": diff --git a/configs/sac.py b/configs/sac.py index 3165a03..47ea3fb 100644 --- a/configs/sac.py +++ b/configs/sac.py @@ -2,10 +2,9 @@ from oprl.algos.sac import SAC from oprl.parse_args import parse_args from oprl.logging import ( - create_logdir, set_logging, - FileTxtLogger, - LoggerProtocol + LoggerProtocol, + make_text_logger_func, ) set_logging() from oprl.environment import make_env as _make_env @@ -48,11 +47,11 @@ def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: ).create() -def make_logger(seed: int) -> LoggerProtocol: - log_dir = create_logdir(logdir="logs", algo="SAC", env=args.env, seed=seed) - logger = FileTxtLogger(log_dir, config) - logger.copy_source_code() - return logger +make_logger = make_text_logger_func( + config=config, + 
algo="SAC", + env=args.env, +) if __name__ == "__main__": diff --git a/configs/td3.py b/configs/td3.py index 7c77d66..9eeaf15 100644 --- a/configs/td3.py +++ b/configs/td3.py @@ -2,10 +2,9 @@ from oprl.algos.td3 import TD3 from oprl.parse_args import parse_args from oprl.logging import ( - create_logdir, set_logging, - FileTxtLogger, - LoggerProtocol + LoggerProtocol, + make_text_logger_func, ) set_logging() from oprl.environment import make_env as _make_env @@ -47,11 +46,11 @@ def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: ).create() -def make_logger(seed: int) -> LoggerProtocol: - log_dir = create_logdir(logdir="logs", algo="TD3", env=args.env, seed=seed) - logger = FileTxtLogger(log_dir, config) - logger.copy_source_code() - return logger +make_logger = make_text_logger_func( + config=config, + algo="TD3", + env=args.env, +) if __name__ == "__main__": diff --git a/configs/tqc.py b/configs/tqc.py index 6cf045d..2731615 100644 --- a/configs/tqc.py +++ b/configs/tqc.py @@ -2,10 +2,9 @@ from oprl.algos.tqc import TQC from oprl.parse_args import parse_args from oprl.logging import ( - create_logdir, set_logging, - FileTxtLogger, - LoggerProtocol + LoggerProtocol, + make_text_logger_func, ) set_logging() from oprl.environment import make_env as _make_env @@ -48,11 +47,11 @@ def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: ).create() -def make_logger(seed: int) -> LoggerProtocol: - log_dir = create_logdir(logdir="logs", algo="TQC", env=args.env, seed=seed) - logger = FileTxtLogger(log_dir, config) - logger.copy_source_code() - return logger +make_logger = make_text_logger_func( + config=config, + algo="TQC", + env=args.env, +) if __name__ == "__main__": diff --git a/src/oprl/logging.py b/src/oprl/logging.py index 3a92384..2d39b80 100644 --- a/src/oprl/logging.py +++ b/src/oprl/logging.py @@ -5,13 +5,19 @@ import json import shutil from abc import ABC, abstractmethod -from typing import Any, Protocol +from typing import Any, Protocol, Callable import torch as t import torch.nn as nn from torch.utils.tensorboard.writer import SummaryWriter +class LoggerProtocol(Protocol): + def log_scalar(self, tag: str, value: float, step: int) -> None: ... + + def log_scalars(self, values: dict[str, float], step: int) -> None: ... + + def create_logdir(logdir: str, algo: str, env: str, seed: int) -> str: dt = datetime.now().strftime("%Y_%m_%d_%Hh%Mm%Ss") log_dir = os.path.join(logdir, algo, f"{algo}-env_{env}-seed_{seed}-{dt}") @@ -34,17 +40,21 @@ def copy_exp_dir(log_dir: str) -> None: logging.info(f"Source copied into {dest_dir}") +def make_text_logger_func(config: dict, algo, env) -> Callable: + def make_logger(seed: int) -> LoggerProtocol: + logs_root = os.environ.get("OPRL_LOGS", "logs") + log_dir = create_logdir(logdir=logs_root, algo=algo, env=env, seed=seed) + logger = FileTxtLogger(log_dir, config) + logger.copy_source_code() + return logger + return make_logger + + def save_json_config(config: dict[str, Any], path: str): with open(path, "w") as f: json.dump(config, f) -class LoggerProtocol(Protocol): - def log_scalar(self, tag: str, value: float, step: int) -> None: ... - - def log_scalars(self, values: dict[str, float], step: int) -> None: ... 
- - class BaseLogger(ABC): @abstractmethod def log_scalar(self, tag: str, value: float, step: int) -> None: From 1edb17894c47be77ec9453bf8e6a4d8c38dd0af3 Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Sat, 12 Jul 2025 21:54:17 +0400 Subject: [PATCH 09/30] Introduce explicit protocols --- configs/ddpg.py | 5 +- configs/sac.py | 22 +++++++- configs/td3.py | 22 +++++++- configs/tqc.py | 22 +++++++- scripts/visualize_policy_from_weights.py | 2 +- src/oprl/algos/__init__.py | 16 ------ src/oprl/algos/ddpg.py | 36 +++++-------- src/oprl/algos/nn_functions.py | 15 ++++++ src/oprl/algos/{nn.py => nn_models.py} | 51 ++++++++++++------- src/oprl/algos/protocols.py | 24 +++++++++ src/oprl/algos/sac.py | 27 ++++------ src/oprl/algos/td3.py | 28 +++------- src/oprl/algos/tqc.py | 25 ++++----- src/oprl/algos/utils.py | 36 ------------- src/oprl/buffers/episodic_buffer.py | 18 +------ src/oprl/buffers/protocols.py | 20 ++++++++ src/oprl/environment/__init__.py | 2 +- src/oprl/environment/dm_control.py | 2 +- .../environment/{protocol.py => protocols.py} | 0 src/oprl/environment/safety_gymnasium.py | 2 +- src/oprl/trainers/base_trainer.py | 12 ++--- src/oprl/trainers/protocols.py | 7 +++ 22 files changed, 214 insertions(+), 180 deletions(-) create mode 100644 src/oprl/algos/nn_functions.py rename src/oprl/algos/{nn.py => nn_models.py} (80%) create mode 100644 src/oprl/algos/protocols.py delete mode 100644 src/oprl/algos/utils.py create mode 100644 src/oprl/buffers/protocols.py rename src/oprl/environment/{protocol.py => protocols.py} (100%) create mode 100644 src/oprl/trainers/protocols.py diff --git a/configs/ddpg.py b/configs/ddpg.py index f030391..6b0bd99 100644 --- a/configs/ddpg.py +++ b/configs/ddpg.py @@ -1,4 +1,4 @@ -from oprl.algos import OffPolicyAlgorithm +from oprl.algos.protocols import OffPolicyAlgorithm from oprl.algos.ddpg import DDPG from oprl.buffers.episodic_buffer import ReplayBufferProtocol, EpisodicReplayBuffer from oprl.parse_args import parse_args @@ -33,7 +33,6 @@ def make_env(seed: int): "device": args.device, "visualise_every": 50000, "estimate_q_every": 5000, - "gamma": 0.99, "log_every": 2500, } @@ -45,7 +44,6 @@ def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: state_dim=STATE_DIM, action_dim=ACTION_DIM, device=args.device, - discount=config["gamma"], logger=logger, ).create() @@ -56,7 +54,6 @@ def make_replay_buffer() -> ReplayBufferProtocol: state_dim=STATE_DIM, action_dim=ACTION_DIM, device=config["device"], - gamma=config["gamma"], ).create() diff --git a/configs/sac.py b/configs/sac.py index 47ea3fb..b5750b0 100644 --- a/configs/sac.py +++ b/configs/sac.py @@ -1,4 +1,4 @@ -from oprl.algos import OffPolicyAlgorithm +from oprl.algos.protocols import OffPolicyAlgorithm from oprl.algos.sac import SAC from oprl.parse_args import parse_args from oprl.logging import ( @@ -8,6 +8,7 @@ ) set_logging() from oprl.environment import make_env as _make_env +from oprl.buffers.episodic_buffer import ReplayBufferProtocol, EpisodicReplayBuffer from oprl.train import run_training args = parse_args() @@ -54,6 +55,23 @@ def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: ) +def make_replay_buffer() -> ReplayBufferProtocol: + return EpisodicReplayBuffer( + buffer_size=max(config["num_steps"], int(1e6)), + state_dim=STATE_DIM, + action_dim=ACTION_DIM, + device=config["device"], + ).create() + + if __name__ == "__main__": args = parse_args() - run_training(make_algo, make_env, make_logger, config, args.seeds, args.start_seed) + run_training( + make_algo=make_algo, + 
make_env=make_env, + make_replay_buffer=make_replay_buffer, + make_logger=make_logger, + config=config, + seeds=args.seeds, + start_seed=args.start_seed + ) diff --git a/configs/td3.py b/configs/td3.py index 9eeaf15..d4dcb16 100644 --- a/configs/td3.py +++ b/configs/td3.py @@ -1,6 +1,7 @@ -from oprl.algos import OffPolicyAlgorithm +from oprl.algos.protocols import OffPolicyAlgorithm from oprl.algos.td3 import TD3 from oprl.parse_args import parse_args +from oprl.buffers.episodic_buffer import ReplayBufferProtocol, EpisodicReplayBuffer from oprl.logging import ( set_logging, LoggerProtocol, @@ -46,6 +47,15 @@ def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: ).create() +def make_replay_buffer() -> ReplayBufferProtocol: + return EpisodicReplayBuffer( + buffer_size=max(config["num_steps"], int(1e6)), + state_dim=STATE_DIM, + action_dim=ACTION_DIM, + device=config["device"], + ).create() + + make_logger = make_text_logger_func( config=config, algo="TD3", @@ -55,4 +65,12 @@ def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: if __name__ == "__main__": args = parse_args() - run_training(make_algo, make_env, make_logger, config, args.seeds, args.start_seed) + run_training( + make_algo=make_algo, + make_env=make_env, + make_logger=make_logger, + make_replay_buffer=make_replay_buffer, + config=config, + seeds=args.seeds, + start_seed=args.start_seed + ) diff --git a/configs/tqc.py b/configs/tqc.py index 2731615..882d9d3 100644 --- a/configs/tqc.py +++ b/configs/tqc.py @@ -1,4 +1,4 @@ -from oprl.algos import OffPolicyAlgorithm +from oprl.algos.protocols import OffPolicyAlgorithm from oprl.algos.tqc import TQC from oprl.parse_args import parse_args from oprl.logging import ( @@ -8,6 +8,7 @@ ) set_logging() from oprl.environment import make_env as _make_env +from oprl.buffers.episodic_buffer import ReplayBufferProtocol, EpisodicReplayBuffer from oprl.train import run_training args = parse_args() @@ -54,6 +55,23 @@ def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: ) +def make_replay_buffer() -> ReplayBufferProtocol: + return EpisodicReplayBuffer( + buffer_size=max(config["num_steps"], int(1e6)), + state_dim=STATE_DIM, + action_dim=ACTION_DIM, + device=config["device"], + ).create() + + if __name__ == "__main__": args = parse_args() - run_training(make_algo, make_env, make_logger, config, args.seeds, args.start_seed) + run_training( + make_algo=make_algo, + make_env=make_env, + make_replay_buffer=make_replay_buffer, + make_logger=make_logger, + config=config, + seeds=args.seeds, + start_seed=args.start_seed + ) diff --git a/scripts/visualize_policy_from_weights.py b/scripts/visualize_policy_from_weights.py index 659cbd5..2683785 100644 --- a/scripts/visualize_policy_from_weights.py +++ b/scripts/visualize_policy_from_weights.py @@ -3,7 +3,7 @@ import numpy as np from PIL import Image -from oprl.env import make_env +from oprl.environment import make_env def create_webp_gif(numpy_arrays, output_path, duration=100, loop=0): diff --git a/src/oprl/algos/__init__.py b/src/oprl/algos/__init__.py index b6f3a8f..8b13789 100644 --- a/src/oprl/algos/__init__.py +++ b/src/oprl/algos/__init__.py @@ -1,17 +1 @@ -from typing import Protocol - -import torch as t - - -class OffPolicyAlgorithm(Protocol): - def create(self) -> "OffPolicyAlgorithm": ... - - def update( - self, - state: t.Tensor, - action: t.Tensor, - reward: t.Tensor, - done: t.Tensor, - next_state: t.Tensor - ) -> None: ... 
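A note on the design this commit lands: the OffPolicyAlgorithm contract deleted from src/oprl/algos/__init__.py above reappears in src/oprl/algos/protocols.py (alongside PolicyProtocol), and both are typing.Protocol classes, so conformance is structural: any class exposing matching method signatures type-checks against the protocol, with no inheritance or registration required. A minimal sketch of that property, assuming only typing and torch; NoopAlgo is a hypothetical toy class, not part of the repository:

from typing import Protocol

import torch as t


class OffPolicyAlgorithm(Protocol):
    def create(self) -> "OffPolicyAlgorithm": ...

    def update(
        self,
        state: t.Tensor,
        action: t.Tensor,
        reward: t.Tensor,
        done: t.Tensor,
        next_state: t.Tensor,
    ) -> None: ...


class NoopAlgo:
    # Satisfies OffPolicyAlgorithm structurally: no subclassing required.
    def create(self) -> "NoopAlgo":
        return self  # a real algorithm would build its networks here

    def update(self, state, action, reward, done, next_state) -> None:
        pass  # a real algorithm would take a gradient step here


algo: OffPolicyAlgorithm = NoopAlgo().create()  # accepted by a type checker

In the diffs below, DDPG(OffPolicyAlgorithm) still subclasses the protocol explicitly; that is permitted and documents intent, but it is not what makes the classes conform.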
diff --git a/src/oprl/algos/ddpg.py b/src/oprl/algos/ddpg.py index bf08413..c983a20 100644 --- a/src/oprl/algos/ddpg.py +++ b/src/oprl/algos/ddpg.py @@ -1,18 +1,19 @@ from copy import deepcopy -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any -import numpy as np import numpy.typing as npt import torch as t from torch import nn -from oprl.algos import OffPolicyAlgorithm -from oprl.algos.nn import Critic, DeterministicPolicy -from oprl.algos.utils import disable_gradient +from oprl.algos.protocols import OffPolicyAlgorithm, PolicyProtocol +from oprl.algos.nn_models import Critic, DeterministicPolicy +from oprl.algos.nn_functions import disable_gradient from oprl.logging import LoggerProtocol +# TODO: Do I need max_action all the time? need to check envs for their max actions + @dataclass class DDPG(OffPolicyAlgorithm): state_dim: int @@ -25,12 +26,18 @@ class DDPG(OffPolicyAlgorithm): max_action: float = 1. device: str = "cpu" + actor: PolicyProtocol = field(init=False) + critic: nn.Module = field(init=False) + def create(self) -> "DDPG": self.actor = DeterministicPolicy( state_dim=self.state_dim, action_dim=self.action_dim, hidden_units=(256, 256), hidden_activation=nn.ReLU(inplace=True), + expl_noise=self.expl_noise, + max_action=self.max_action, + device=self.device, ).to(self.device) self.actor_target = deepcopy(self.actor) disable_gradient(self.actor_target) @@ -81,36 +88,21 @@ def _update_critic( current_Q = self.critic(state, action) critic_loss = (current_Q - target_Q).pow(2).mean() - self.optim_critic.zero_grad() critic_loss.backward() self.optim_critic.step() def _update_actor(self, state: t.Tensor) -> None: actor_loss = -self.critic(state, self.actor(state)).mean() - self.optim_actor.zero_grad() actor_loss.backward() self.optim_actor.step() def exploit(self, state: npt.ArrayLike) -> npt.ArrayLike: - state = t.tensor(state, device=self.device).unsqueeze_(0) - with t.no_grad(): - action = self.actor(state).cpu() - return action.numpy().flatten() + return self.actor.exploit(state) - # TODO: remove explore from algo to agent completely def explore(self, state: npt.ArrayLike) -> npt.ArrayLike: - state = t.tensor(state, device=self.device).unsqueeze_(0) - - with t.no_grad(): - noise = ( - t.randn(self.action_dim) * self.max_action * self.expl_noise - ).to(self.device) - action = self.actor(state) + noise - - a = action.cpu().numpy()[0] - return np.clip(a, -self.max_action, self.max_action) + return self.actor.explore(state) def get_policy_state_dict(self) -> dict[str, Any]: return self.actor.state_dict() diff --git a/src/oprl/algos/nn_functions.py b/src/oprl/algos/nn_functions.py new file mode 100644 index 0000000..c57726c --- /dev/null +++ b/src/oprl/algos/nn_functions.py @@ -0,0 +1,15 @@ +import torch + + +def soft_update(target, source, tau): + """Update target network using Polyak-Ruppert Averaging.""" + with torch.no_grad(): + for tgt, src in zip(target.parameters(), source.parameters()): + tgt.data.mul_(1.0 - tau) + tgt.data.add_(tau * src.data) + + +def disable_gradient(network): + """Disable gradient calculations of the network.""" + for param in network.parameters(): + param.requires_grad = False diff --git a/src/oprl/algos/nn.py b/src/oprl/algos/nn_models.py similarity index 80% rename from src/oprl/algos/nn.py rename to src/oprl/algos/nn_models.py index 935f535..bb69064 100644 --- a/src/oprl/algos/nn.py +++ b/src/oprl/algos/nn_models.py @@ -1,15 +1,28 @@ import numpy as np +from numpy._typing import NDArray import numpy.typing 
as npt import torch as t import torch.nn as nn from torch.distributions import Distribution, Normal from torch.nn.functional import logsigmoid -from oprl.algos.utils import initialize_weight LOG_STD_MIN_MAX = (-20, 2) +def initialize_weight_orthogonal(m, gain=nn.init.calculate_gain("relu")): + if isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight.data, gain) + m.bias.data.fill_(0.0) + # delta-orthogonal initialization. + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + assert m.weight.size(2) == m.weight.size(3) + m.weight.data.fill_(0.0) + m.bias.data.fill_(0.0) + mid = m.weight.size(2) // 2 + nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain) + + class Critic(nn.Module): def __init__( self, @@ -113,7 +126,7 @@ def __init__( output_dim=action_dim, hidden_units=hidden_units, hidden_activation=hidden_activation, - ).apply(initialize_weight) + ).apply(initialize_weight_orthogonal) self._device = device self._action_shape = action_dim @@ -125,26 +138,27 @@ def forward(self, states: t.Tensor) -> t.Tensor: def exploit(self, state: npt.ArrayLike) -> npt.NDArray: state = t.tensor(state).unsqueeze_(0).to(self._device) - return self.forward(state).detach().cpu().numpy().flatten() + with t.no_grad(): + action = self.forward(state) + return action.cpu().numpy().flatten() def explore(self, state: npt.ArrayLike) -> npt.NDArray: state = t.tensor(state, device=self._device).unsqueeze_(0) - + noise = (t.randn(self._action_shape) * self._expl_noise).to(self._device) with t.no_grad(): - noise = (t.randn(self._action_shape) * self._expl_noise).to(self._device) action = self.mlp(state) + noise - - a = action.cpu().numpy()[0] - return np.clip(a, -self._max_action, self._max_action) + action = action.cpu().numpy()[0] + return np.clip(action, -self._max_action, self._max_action) class GaussianActor(nn.Module): - def __init__(self, state_dim, action_dim, hidden_units, hidden_activation): + def __init__(self, state_dim, action_dim, hidden_units, hidden_activation, device: str): super().__init__() self.action_dim = action_dim self.net = MLP( state_dim, 2 * action_dim, hidden_units, hidden_activation=hidden_activation ) + self.device = device def forward(self, obs: t.Tensor) -> tuple[t.Tensor, t.Tensor | None]: mean, log_std = self.net(obs).split([self.action_dim, self.action_dim], dim=1) @@ -161,14 +175,17 @@ def forward(self, obs: t.Tensor) -> tuple[t.Tensor, t.Tensor | None]: log_prob = None return action, log_prob - def exploit(self, state: npt.ArrayLike) -> npt.NDArray: - state = t.tensor(state).unsqueeze_(0).to(self.device) - action, _ = self.forward(state) - return action.detach().cpu().numpy().flatten() - - @property - def device(self): - return next(self.parameters()).device + def explore(self, state: npt.NDArray) -> npt.NDArray: + state_tensor = t.tensor(state, device=self.device).unsqueeze_(0) + with t.no_grad(): + action, _ = self.forward(state_tensor) + return action.cpu().numpy()[0] + + def exploit(self, state: npt.NDArray) -> npt.NDArray: + self.eval() + action = self.explore(state) + self.train() + return action class TanhNormal(Distribution): diff --git a/src/oprl/algos/protocols.py b/src/oprl/algos/protocols.py new file mode 100644 index 0000000..f6af71e --- /dev/null +++ b/src/oprl/algos/protocols.py @@ -0,0 +1,24 @@ +from typing import Protocol + +import numpy.typing as npt + +import torch as t + + +class OffPolicyAlgorithm(Protocol): + def create(self) -> "OffPolicyAlgorithm": ... 
+ + def update( + self, + state: t.Tensor, + action: t.Tensor, + reward: t.Tensor, + done: t.Tensor, + next_state: t.Tensor + ) -> None: ... + + +class PolicyProtocol(Protocol): + def explore(self, state: npt.ArrayLike) -> npt.NDArray: ... + + def exploit(self, state: npt.ArrayLike) -> npt.NDArray: ... diff --git a/src/oprl/algos/sac.py b/src/oprl/algos/sac.py index aa1ea04..910d624 100644 --- a/src/oprl/algos/sac.py +++ b/src/oprl/algos/sac.py @@ -7,14 +7,15 @@ from torch import nn from torch.optim import Adam -from oprl.algos import OffPolicyAlgorithm -from oprl.algos.nn import DoubleCritic, GaussianActor -from oprl.algos.utils import disable_gradient, soft_update -from oprl.utils.logger import Logger, StdLogger +from oprl.algos.protocols import OffPolicyAlgorithm +from oprl.algos.nn_models import DoubleCritic, GaussianActor +from oprl.algos.nn_functions import disable_gradient, soft_update +from oprl.logging import LoggerProtocol @dataclass class SAC(OffPolicyAlgorithm): + logger: LoggerProtocol state_dim: int action_dim: int batch_size: int = 256 @@ -27,7 +28,6 @@ class SAC(OffPolicyAlgorithm): target_update_coef: float = 5e-3 device: str = "cpu" log_every: int = 5000 - logger: Logger = StdLogger() def create(self) -> "SAC": self.actor = GaussianActor( @@ -35,6 +35,7 @@ def create(self) -> "SAC": action_dim=self.action_dim, hidden_units=(256, 256), hidden_activation=nn.ReLU(inplace=True), + device=self.device, ).to(self.device) self.critic = DoubleCritic( @@ -144,14 +145,8 @@ def update_actor(self, state: t.Tensor) -> None: self.update_step, ) - def explore(self, state: npt.ArrayLike) -> npt.ArrayLike: - state = t.tensor(state, device=self.device).unsqueeze_(0) - with t.no_grad(): - action, _ = self.actor(state) - return action.cpu().numpy()[0] - - def exploit(self, state: npt.ArrayLike) -> npt.ArrayLike: - self.actor.eval() - action = self.explore(state) - self.actor.train() - return action + def explore(self, state: npt.NDArray) -> npt.NDArray: + return self.actor.explore(state) + + def exploit(self, state: npt.NDArray) -> npt.NDArray: + return self.actor.exploit(state) diff --git a/src/oprl/algos/td3.py b/src/oprl/algos/td3.py index b912520..56fd615 100644 --- a/src/oprl/algos/td3.py +++ b/src/oprl/algos/td3.py @@ -1,20 +1,20 @@ from copy import deepcopy from dataclasses import dataclass -import numpy as np import numpy.typing as npt import torch as t from torch import nn from torch.optim import Adam -from oprl.algos import OffPolicyAlgorithm -from oprl.algos.nn import DeterministicPolicy, DoubleCritic -from oprl.algos.utils import disable_gradient, soft_update -from oprl.utils.logger import Logger, StdLogger +from oprl.algos.protocols import OffPolicyAlgorithm +from oprl.algos.nn_models import DeterministicPolicy, DoubleCritic +from oprl.algos.nn_functions import disable_gradient, soft_update +from oprl.logging import LoggerProtocol @dataclass class TD3(OffPolicyAlgorithm): + logger: LoggerProtocol state_dim: int action_dim: int batch_size: int = 256 @@ -29,7 +29,6 @@ class TD3(OffPolicyAlgorithm): tau: float = 5e-3 log_every: int = 5000 device: str = "cpu" - logger: Logger = StdLogger() update_step: int = 0 def create(self) -> "TD3": @@ -38,6 +37,7 @@ def create(self) -> "TD3": action_dim=self.action_dim, hidden_units=(256, 256), hidden_activation=nn.ReLU(inplace=True), + device=self.device, ).to(self.device) self.actor_target = deepcopy(self.actor).to(self.device).eval() disable_gradient(self.actor_target) @@ -57,22 +57,10 @@ def create(self) -> "TD3": return self def 
exploit(self, state: npt.ArrayLike) -> npt.ArrayLike:
-        state = t.tensor(state, device=self.device).unsqueeze_(0)
-        with t.no_grad():
-            action = self.actor(state)
-        return action.cpu().numpy().flatten()
+        return self.actor.exploit(state)
 
     def explore(self, state: npt.ArrayLike) -> npt.ArrayLike:
-        state = t.tensor(state, device=self.device).unsqueeze_(0)
-        noise = (t.randn(self.action_dim) * self.max_action * self.expl_noise).to(
-            self.device
-        )
-
-        with t.no_grad():
-            action = self.actor(state) + noise
-
-        a = action.cpu().numpy()[0]
-        return np.clip(a, -self.max_action, self.max_action)
+        return self.actor.explore(state)
 
     def update(self, state: t.Tensor, action, reward, done, next_state) -> None:
         self._update_critic(state, action, reward, done, next_state)
diff --git a/src/oprl/algos/tqc.py b/src/oprl/algos/tqc.py
index 8ac5e40..c167d0f 100644
--- a/src/oprl/algos/tqc.py
+++ b/src/oprl/algos/tqc.py
@@ -6,9 +6,9 @@
 import torch as t
 import torch.nn as nn
 
-from oprl.algos import OffPolicyAlgorithm
-from oprl.algos.nn import MLP, GaussianActor
-from oprl.utils.logger import Logger, StdLogger
+from oprl.algos.protocols import OffPolicyAlgorithm
+from oprl.algos.nn_models import MLP, GaussianActor
+from oprl.logging import LoggerProtocol
 
 
 def quantile_huber_loss_f(
@@ -64,6 +64,7 @@ def forward(self, state: t.Tensor, action: t.Tensor) -> t.Tensor:
 
 @dataclass
 class TQC(OffPolicyAlgorithm):
+    logger: LoggerProtocol
     state_dim: int
     action_dim: int
     discount: float = 0.99
@@ -73,7 +74,6 @@ class TQC(OffPolicyAlgorithm):
     n_nets: int = 5
     log_every: int = 5000
    device: str = "cpu"
-    logger: Logger = StdLogger()
     update_step = 0
 
     def create(self) -> "TQC":
@@ -83,6 +83,7 @@ def create(self) -> "TQC":
             self.action_dim,
             hidden_units=(256, 256),
             hidden_activation=nn.ReLU(),
+            device=self.device,
         ).to(self.device)
         self.critic = QuantileQritic(
             self.state_dim,
@@ -176,14 +177,8 @@
 
         self.update_step += 1
 
-    def explore(self, state: npt.ArrayLike) -> npt.ArrayLike:
-        state = t.tensor(state, device=self.device).unsqueeze_(0)
-        with t.no_grad():
-            action, _ = self.actor(state)
-        return action.cpu().numpy()[0]
-
-    def exploit(self, state: npt.ArrayLike) -> npt.ArrayLike:
-        self.actor.eval()
-        action = self.explore(state)
-        self.actor.train()
-        return action
+    def explore(self, state: npt.NDArray) -> npt.NDArray:
+        return self.actor.explore(state)
+
+    def exploit(self, state: npt.NDArray) -> npt.NDArray:
+        return self.actor.exploit(state)
diff --git a/src/oprl/algos/utils.py b/src/oprl/algos/utils.py
deleted file mode 100644
index 2ad4fd0..0000000
--- a/src/oprl/algos/utils.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-class Clamp(nn.Module):
-    def forward(self, log_stds):
-        return log_stds.clamp_(-20, 2)
-
-
-def initialize_weight(m, gain=nn.init.calculate_gain("relu")):
-    # Initialize linear layers with the orthogonal initialization.
-    if isinstance(m, nn.Linear):
-        nn.init.orthogonal_(m.weight.data, gain)
-        m.bias.data.fill_(0.0)
-
-    # Initialize conv layers with the delta-orthogonal initialization.
- elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): - assert m.weight.size(2) == m.weight.size(3) - m.weight.data.fill_(0.0) - m.bias.data.fill_(0.0) - mid = m.weight.size(2) // 2 - nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain) - - -def soft_update(target, source, tau): - """Update target network using Polyak-Ruppert Averaging.""" - with torch.no_grad(): - for tgt, src in zip(target.parameters(), source.parameters()): - tgt.data.mul_(1.0 - tau) - tgt.data.add_(tau * src.data) - - -def disable_gradient(network): - """Disable gradient calculations of the network.""" - for param in network.parameters(): - param.requires_grad = False diff --git a/src/oprl/buffers/episodic_buffer.py b/src/oprl/buffers/episodic_buffer.py index 5eef9c3..7be132d 100644 --- a/src/oprl/buffers/episodic_buffer.py +++ b/src/oprl/buffers/episodic_buffer.py @@ -1,26 +1,12 @@ from dataclasses import dataclass, field import os import pickle -from typing import Protocol import numpy as np import numpy.typing as npt import torch as t - -class ReplayBufferProtocol(Protocol): - def add_transition(self, state, action, reward, done, episode_done=None): ... - - def add_episode(self, episode): ... - - def sample(self, batch_size) -> tuple[ - t.Tensor, t.Tensor, t.Tensor, t.Tensor, t.Tensor - ]: ... - - def save(self, path: str) -> None: ... - - @property - def last_episode_length(self) -> int: ... +from oprl.buffers.protocols import ReplayBufferProtocol @dataclass @@ -28,7 +14,7 @@ class EpisodicReplayBuffer(ReplayBufferProtocol): buffer_size: int state_dim: int action_dim: int - gamma: float + gamma: float = 0.99 max_episode_lenth: int = 1000 device: str = "cpu" episodes_counter: int = 1 diff --git a/src/oprl/buffers/protocols.py b/src/oprl/buffers/protocols.py new file mode 100644 index 0000000..d70b781 --- /dev/null +++ b/src/oprl/buffers/protocols.py @@ -0,0 +1,20 @@ +from typing import Protocol + +import torch as t + + +class ReplayBufferProtocol(Protocol): + def create(self) -> "ReplayBufferProtocol": ... + def add_transition(self, state, action, reward, done, episode_done=None): ... + + def add_episode(self, episode): ... + + def sample(self, batch_size) -> tuple[ + t.Tensor, t.Tensor, t.Tensor, t.Tensor, t.Tensor + ]: ... + + def save(self, path: str) -> None: ... + + @property + def last_episode_length(self) -> int: ... 
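
The extracted ReplayBufferProtocol is a structural type: the trainer is written against this interface, and any buffer that provides the same methods satisfies it without inheriting from a base class. A minimal sketch of code typed against the protocol (the warmup helper below is hypothetical and not part of this series; it assumes the 5-tuple env.step() and a policy exposing explore(), as used elsewhere in the repo):

from oprl.buffers.protocols import ReplayBufferProtocol


def warmup(buffer: ReplayBufferProtocol, env, policy, steps: int) -> None:
    # Pre-fill the buffer before learning starts; any conforming buffer,
    # episodic or otherwise, can be passed in unchanged.
    state, _ = env.reset()
    for _ in range(steps):
        action = policy.explore(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        buffer.add_transition(
            state, action, reward, terminated,
            episode_done=terminated or truncated,
        )
        state = next_state
        if terminated or truncated:
            state, _ = env.reset()
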
+ diff --git a/src/oprl/environment/__init__.py b/src/oprl/environment/__init__.py index 26f88ef..1dc2688 100644 --- a/src/oprl/environment/__init__.py +++ b/src/oprl/environment/__init__.py @@ -1,4 +1,4 @@ -from oprl.environment.protocol import EnvProtocol +from oprl.environment.protocols import EnvProtocol from oprl.environment.dm_control import DMControlEnv from oprl.environment.safety_gymnasium import SafetyGym from oprl.environment.make_env import make_env diff --git a/src/oprl/environment/dm_control.py b/src/oprl/environment/dm_control.py index 2ce564f..02db9df 100644 --- a/src/oprl/environment/dm_control.py +++ b/src/oprl/environment/dm_control.py @@ -5,7 +5,7 @@ import numpy.typing as npt from dm_control import suite -from oprl.environment import EnvProtocol +from oprl.environment.protocols import EnvProtocol class DMControlEnv(EnvProtocol): diff --git a/src/oprl/environment/protocol.py b/src/oprl/environment/protocols.py similarity index 100% rename from src/oprl/environment/protocol.py rename to src/oprl/environment/protocols.py diff --git a/src/oprl/environment/safety_gymnasium.py b/src/oprl/environment/safety_gymnasium.py index 0d31a31..10c1614 100644 --- a/src/oprl/environment/safety_gymnasium.py +++ b/src/oprl/environment/safety_gymnasium.py @@ -1,7 +1,7 @@ import numpy.typing as npt from typing import Any -from oprl.environment import EnvProtocol +from oprl.environment.protocols import EnvProtocol class SafetyGym(EnvProtocol): diff --git a/src/oprl/trainers/base_trainer.py b/src/oprl/trainers/base_trainer.py index 23a6f11..dd3c849 100644 --- a/src/oprl/trainers/base_trainer.py +++ b/src/oprl/trainers/base_trainer.py @@ -1,19 +1,15 @@ from dataclasses import dataclass -from typing import Any, Callable, Protocol +from typing import Any, Callable import numpy as np import torch -from oprl.algos import OffPolicyAlgorithm +from oprl.algos.protocols import OffPolicyAlgorithm from oprl.environment import EnvProtocol -from oprl.buffers.episodic_buffer import ReplayBufferProtocol +from oprl.buffers.protocols import ReplayBufferProtocol from oprl.logging import LoggerProtocol - -class TrainerProtocol(Protocol): - def train(self) -> None: ... - - def evaluate(self) -> dict[str, float]: ... +from oprl.trainers.protocols import TrainerProtocol @dataclass diff --git a/src/oprl/trainers/protocols.py b/src/oprl/trainers/protocols.py new file mode 100644 index 0000000..6b3e986 --- /dev/null +++ b/src/oprl/trainers/protocols.py @@ -0,0 +1,7 @@ +from typing import Protocol + +class TrainerProtocol(Protocol): + def train(self) -> None: ... + + def evaluate(self) -> dict[str, float]: ... 
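
TrainerProtocol gets the same treatment: orchestration code can depend on the protocol alone, and BaseTrainer or the SafeTrainer wrapper satisfy it purely structurally. A small illustrative sketch (launch is a made-up helper, not part of the series):

from oprl.trainers.protocols import TrainerProtocol


def launch(trainer: TrainerProtocol) -> dict[str, float]:
    # Accepts BaseTrainer, SafeTrainer, or any future trainer
    # that provides train() and evaluate().
    trainer.train()
    return trainer.evaluate()
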
+ From f4dc0945496b8a43c7ac2012ab3a2e3df77c9f48 Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Sat, 12 Jul 2025 23:03:48 +0400 Subject: [PATCH 10/30] Introduce base algorithm with explore and exploit functionality --- configs/ddpg.py | 9 ++++---- configs/sac.py | 9 ++++---- configs/td3.py | 9 ++++---- configs/tqc.py | 9 ++++---- src/oprl/algos/base_algorithm.py | 18 ++++++++++++++++ src/oprl/algos/ddpg.py | 19 +++++++---------- src/oprl/algos/nn_models.py | 1 - src/oprl/algos/protocols.py | 16 +++++++++------ src/oprl/algos/sac.py | 32 +++++++++++++++-------------- src/oprl/algos/td3.py | 21 +++++++++++-------- src/oprl/algos/tqc.py | 28 ++++++++++++++----------- src/oprl/buffers/episodic_buffer.py | 10 ++++----- src/oprl/train.py | 1 - src/oprl/trainers/base_trainer.py | 30 ++++----------------------- 14 files changed, 109 insertions(+), 103 deletions(-) create mode 100644 src/oprl/algos/base_algorithm.py diff --git a/configs/ddpg.py b/configs/ddpg.py index 6b0bd99..1e55a65 100644 --- a/configs/ddpg.py +++ b/configs/ddpg.py @@ -1,6 +1,7 @@ -from oprl.algos.protocols import OffPolicyAlgorithm +from oprl.algos.protocols import AlgorithmProtocol from oprl.algos.ddpg import DDPG -from oprl.buffers.episodic_buffer import ReplayBufferProtocol, EpisodicReplayBuffer +from oprl.buffers.protocols import ReplayBufferProtocol +from oprl.buffers.episodic_buffer import EpisodicReplayBuffer from oprl.parse_args import parse_args from oprl.logging import ( set_logging, @@ -39,7 +40,7 @@ def make_env(seed: int): # ----------------------------------- -def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: +def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol: return DDPG( state_dim=STATE_DIM, action_dim=ACTION_DIM, @@ -50,7 +51,7 @@ def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: def make_replay_buffer() -> ReplayBufferProtocol: return EpisodicReplayBuffer( - buffer_size=max(config["num_steps"], int(1e6)), + buffer_size_transitions=max(config["num_steps"], int(1e6)), state_dim=STATE_DIM, action_dim=ACTION_DIM, device=config["device"], diff --git a/configs/sac.py b/configs/sac.py index b5750b0..345ee5c 100644 --- a/configs/sac.py +++ b/configs/sac.py @@ -1,4 +1,4 @@ -from oprl.algos.protocols import OffPolicyAlgorithm +from oprl.algos.protocols import AlgorithmProtocol from oprl.algos.sac import SAC from oprl.parse_args import parse_args from oprl.logging import ( @@ -8,7 +8,8 @@ ) set_logging() from oprl.environment import make_env as _make_env -from oprl.buffers.episodic_buffer import ReplayBufferProtocol, EpisodicReplayBuffer +from oprl.buffers.protocols import ReplayBufferProtocol +from oprl.buffers.episodic_buffer import EpisodicReplayBuffer from oprl.train import run_training args = parse_args() @@ -39,7 +40,7 @@ def make_env(seed: int): # ----------------------------------- -def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: +def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol: return SAC( state_dim=STATE_DIM, action_dim=ACTION_DIM, @@ -57,7 +58,7 @@ def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: def make_replay_buffer() -> ReplayBufferProtocol: return EpisodicReplayBuffer( - buffer_size=max(config["num_steps"], int(1e6)), + buffer_size_transitions=max(config["num_steps"], int(1e6)), state_dim=STATE_DIM, action_dim=ACTION_DIM, device=config["device"], diff --git a/configs/td3.py b/configs/td3.py index d4dcb16..31dc557 100644 --- a/configs/td3.py +++ b/configs/td3.py @@ -1,7 +1,8 @@ -from oprl.algos.protocols import OffPolicyAlgorithm +from 
oprl.algos.protocols import AlgorithmProtocol from oprl.algos.td3 import TD3 from oprl.parse_args import parse_args -from oprl.buffers.episodic_buffer import ReplayBufferProtocol, EpisodicReplayBuffer +from oprl.buffers.protocols import ReplayBufferProtocol +from oprl.buffers.episodic_buffer import EpisodicReplayBuffer from oprl.logging import ( set_logging, LoggerProtocol, @@ -38,7 +39,7 @@ def make_env(seed: int): # ----------------------------------- -def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: +def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol: return TD3( state_dim=STATE_DIM, action_dim=ACTION_DIM, @@ -49,7 +50,7 @@ def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: def make_replay_buffer() -> ReplayBufferProtocol: return EpisodicReplayBuffer( - buffer_size=max(config["num_steps"], int(1e6)), + buffer_size_transitions=max(config["num_steps"], int(1e6)), state_dim=STATE_DIM, action_dim=ACTION_DIM, device=config["device"], diff --git a/configs/tqc.py b/configs/tqc.py index 882d9d3..fb051bf 100644 --- a/configs/tqc.py +++ b/configs/tqc.py @@ -1,4 +1,4 @@ -from oprl.algos.protocols import OffPolicyAlgorithm +from oprl.algos.protocols import AlgorithmProtocol from oprl.algos.tqc import TQC from oprl.parse_args import parse_args from oprl.logging import ( @@ -8,7 +8,8 @@ ) set_logging() from oprl.environment import make_env as _make_env -from oprl.buffers.episodic_buffer import ReplayBufferProtocol, EpisodicReplayBuffer +from oprl.buffers.protocols import ReplayBufferProtocol +from oprl.buffers.episodic_buffer import EpisodicReplayBuffer from oprl.train import run_training args = parse_args() @@ -39,7 +40,7 @@ def make_env(seed: int): # ----------------------------------- -def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: +def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol: return TQC( state_dim=STATE_DIM, action_dim=ACTION_DIM, @@ -57,7 +58,7 @@ def make_algo(logger: LoggerProtocol) -> OffPolicyAlgorithm: def make_replay_buffer() -> ReplayBufferProtocol: return EpisodicReplayBuffer( - buffer_size=max(config["num_steps"], int(1e6)), + buffer_size_transitions=max(config["num_steps"], int(1e6)), state_dim=STATE_DIM, action_dim=ACTION_DIM, device=config["device"], diff --git a/src/oprl/algos/base_algorithm.py b/src/oprl/algos/base_algorithm.py new file mode 100644 index 0000000..a689026 --- /dev/null +++ b/src/oprl/algos/base_algorithm.py @@ -0,0 +1,18 @@ + +from abc import ABC +from typing import Any + +import numpy.typing as npt + +from oprl.algos.protocols import AlgorithmProtocol + + +class OffPolicyAlgorithm(ABC, AlgorithmProtocol): + def exploit(self, state: npt.NDArray) -> npt.NDArray: + return self.actor.exploit(state) + + def explore(self, state: npt.NDArray) -> npt.NDArray: + return self.actor.explore(state) + + def get_policy_state_dict(self) -> dict[str, Any]: + return self.actor.state_dict() diff --git a/src/oprl/algos/ddpg.py b/src/oprl/algos/ddpg.py index c983a20..fd44be4 100644 --- a/src/oprl/algos/ddpg.py +++ b/src/oprl/algos/ddpg.py @@ -1,12 +1,11 @@ from copy import deepcopy from dataclasses import dataclass, field -from typing import Any -import numpy.typing as npt import torch as t from torch import nn -from oprl.algos.protocols import OffPolicyAlgorithm, PolicyProtocol +from oprl.algos.protocols import PolicyProtocol +from oprl.algos.base_algorithm import OffPolicyAlgorithm from oprl.algos.nn_models import Critic, DeterministicPolicy from oprl.algos.nn_functions import disable_gradient from oprl.logging import 
LoggerProtocol @@ -16,9 +15,9 @@ @dataclass class DDPG(OffPolicyAlgorithm): + logger: LoggerProtocol state_dim: int action_dim: int - logger: LoggerProtocol expl_noise: float = 0.1 discount: float = 0.99 tau: float = 5e-3 @@ -27,7 +26,11 @@ class DDPG(OffPolicyAlgorithm): device: str = "cpu" actor: PolicyProtocol = field(init=False) + actor_target: PolicyProtocol = field(init=False) + optim_actor: t.optim.Optimizer = field(init=False) critic: nn.Module = field(init=False) + critic_target: nn.Module = field(init=False) + optim_critic: t.optim.Optimizer = field(init=False) def create(self) -> "DDPG": self.actor = DeterministicPolicy( @@ -98,11 +101,3 @@ def _update_actor(self, state: t.Tensor) -> None: actor_loss.backward() self.optim_actor.step() - def exploit(self, state: npt.ArrayLike) -> npt.ArrayLike: - return self.actor.exploit(state) - - def explore(self, state: npt.ArrayLike) -> npt.ArrayLike: - return self.actor.explore(state) - - def get_policy_state_dict(self) -> dict[str, Any]: - return self.actor.state_dict() diff --git a/src/oprl/algos/nn_models.py b/src/oprl/algos/nn_models.py index bb69064..547e289 100644 --- a/src/oprl/algos/nn_models.py +++ b/src/oprl/algos/nn_models.py @@ -1,5 +1,4 @@ import numpy as np -from numpy._typing import NDArray import numpy.typing as npt import torch as t import torch.nn as nn diff --git a/src/oprl/algos/protocols.py b/src/oprl/algos/protocols.py index f6af71e..e9a7b0c 100644 --- a/src/oprl/algos/protocols.py +++ b/src/oprl/algos/protocols.py @@ -5,8 +5,16 @@ import torch as t -class OffPolicyAlgorithm(Protocol): - def create(self) -> "OffPolicyAlgorithm": ... +class PolicyProtocol(Protocol): + def explore(self, state: npt.ArrayLike) -> npt.NDArray: ... + + def exploit(self, state: npt.ArrayLike) -> npt.NDArray: ... + + +class AlgorithmProtocol(Protocol): + actor: PolicyProtocol + + def create(self) -> "AlgorithmProtocol": ... def update( self, @@ -18,7 +26,3 @@ def update( ) -> None: ... -class PolicyProtocol(Protocol): - def explore(self, state: npt.ArrayLike) -> npt.NDArray: ... - - def exploit(self, state: npt.ArrayLike) -> npt.NDArray: ... 
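
Declaring actor: PolicyProtocol on AlgorithmProtocol is what lets the new OffPolicyAlgorithm base class implement explore(), exploit(), and get_policy_state_dict() once by delegating to the actor; a concrete algorithm now only supplies create() and update(). A compressed sketch of a conforming algorithm under these assumptions (NoOpAlgo is a made-up example and its hyperparameters are arbitrary):

from dataclasses import dataclass, field

from torch import nn

from oprl.algos.base_algorithm import OffPolicyAlgorithm
from oprl.algos.nn_models import DeterministicPolicy
from oprl.algos.protocols import PolicyProtocol


@dataclass
class NoOpAlgo(OffPolicyAlgorithm):
    state_dim: int
    action_dim: int
    device: str = "cpu"
    actor: PolicyProtocol = field(init=False)

    def create(self) -> "NoOpAlgo":
        self.actor = DeterministicPolicy(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            hidden_units=(64, 64),
            hidden_activation=nn.ReLU(inplace=True),
            device=self.device,
        ).to(self.device)
        return self

    def update(self, state, action, reward, done, next_state) -> None:
        pass  # explore()/exploit() are inherited from the base class
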
diff --git a/src/oprl/algos/sac.py b/src/oprl/algos/sac.py index 910d624..b459fbc 100644 --- a/src/oprl/algos/sac.py +++ b/src/oprl/algos/sac.py @@ -1,13 +1,13 @@ from copy import deepcopy -from dataclasses import dataclass +from dataclasses import dataclass, field import numpy as np -import numpy.typing as npt import torch as t from torch import nn from torch.optim import Adam -from oprl.algos.protocols import OffPolicyAlgorithm +from oprl.algos.protocols import PolicyProtocol +from oprl.algos.base_algorithm import OffPolicyAlgorithm from oprl.algos.nn_models import DoubleCritic, GaussianActor from oprl.algos.nn_functions import disable_gradient, soft_update from oprl.logging import LoggerProtocol @@ -28,6 +28,15 @@ class SAC(OffPolicyAlgorithm): target_update_coef: float = 5e-3 device: str = "cpu" log_every: int = 5000 + + actor: PolicyProtocol = field(init=False) + actor_target: PolicyProtocol = field(init=False) + optim_actor: t.optim.Optimizer = field(init=False) + critic: nn.Module = field(init=False) + critic_target: nn.Module = field(init=False) + optim_critic: t.optim.Optimizer = field(init=False) + alpha: float = field(init=False) + update_step: int = field(init=False) def create(self) -> "SAC": self.actor = GaussianActor( @@ -51,10 +60,10 @@ def create(self) -> "SAC": self.optim_actor = Adam(self.actor.parameters(), lr=self.lr_actor) self.optim_critic = Adam(self.critic.parameters(), lr=self.lr_critic) - self._alpha = self.alpha_init + self.alpha = self.alpha_init if self.tune_alpha: self.log_alpha = t.tensor( - np.log(self._alpha), device=self.device, requires_grad=True + np.log(self.alpha), device=self.device, requires_grad=True ) self.optim_alpha = t.optim.Adam([self.log_alpha], lr=self.lr_alpha) self.target_entropy = -float(self.action_dim) @@ -73,7 +82,6 @@ def update( self.update_critic(state, action, reward, done, next_state) self.update_actor(state) soft_update(self.critic_target, self.critic, self.target_update_coef) - self.update_step += 1 def update_critic( @@ -88,7 +96,7 @@ def update_critic( with t.no_grad(): next_actions, log_pis = self.actor(next_states) q1_next, q2_next = self.critic_target(next_states, next_actions) - q_next = t.min(q1_next, q2_next) - self._alpha * log_pis + q_next = t.min(q1_next, q2_next) - self.alpha * log_pis q_target = rewards + (1.0 - dones) * self.gamma * q_next @@ -114,7 +122,7 @@ def update_critic( def update_actor(self, state: t.Tensor) -> None: actions, log_pi = self.actor(state) qs1, qs2 = self.critic(state, actions) - loss_actor = self._alpha * log_pi.mean() - t.min(qs1, qs2).mean() + loss_actor = self.alpha * log_pi.mean() - t.min(qs1, qs2).mean() self.optim_actor.zero_grad() loss_actor.backward() @@ -129,7 +137,7 @@ def update_actor(self, state: t.Tensor) -> None: loss_alpha.backward() self.optim_alpha.step() with t.no_grad(): - self._alpha = self.log_alpha.exp().item() + self.alpha = self.log_alpha.exp().item() if self.update_step % self.log_every == 0: if self.tune_alpha: @@ -144,9 +152,3 @@ def update_actor(self, state: t.Tensor) -> None: }, self.update_step, ) - - def explore(self, state: npt.NDArray) -> npt.NDArray: - return self.actor.explore(state) - - def exploit(self, state: npt.NDArray) -> npt.NDArray: - return self.actor.exploit(state) diff --git a/src/oprl/algos/td3.py b/src/oprl/algos/td3.py index 56fd615..ea6e749 100644 --- a/src/oprl/algos/td3.py +++ b/src/oprl/algos/td3.py @@ -1,12 +1,12 @@ from copy import deepcopy -from dataclasses import dataclass +from dataclasses import dataclass, field -import numpy.typing 
as npt import torch as t from torch import nn from torch.optim import Adam -from oprl.algos.protocols import OffPolicyAlgorithm +from oprl.algos.protocols import PolicyProtocol +from oprl.algos.base_algorithm import OffPolicyAlgorithm from oprl.algos.nn_models import DeterministicPolicy, DoubleCritic from oprl.algos.nn_functions import disable_gradient, soft_update from oprl.logging import LoggerProtocol @@ -29,7 +29,14 @@ class TD3(OffPolicyAlgorithm): tau: float = 5e-3 log_every: int = 5000 device: str = "cpu" - update_step: int = 0 + + actor: PolicyProtocol = field(init=False) + actor_target: PolicyProtocol = field(init=False) + optim_actor: t.optim.Optimizer = field(init=False) + critic: nn.Module = field(init=False) + critic_target: nn.Module = field(init=False) + optim_critic: t.optim.Optimizer = field(init=False) + update_step: int = field(init=False) def create(self) -> "TD3": self.actor = DeterministicPolicy( @@ -37,6 +44,7 @@ def create(self) -> "TD3": action_dim=self.action_dim, hidden_units=(256, 256), hidden_activation=nn.ReLU(inplace=True), + expl_noise=self.expl_noise, device=self.device, ).to(self.device) self.actor_target = deepcopy(self.actor).to(self.device).eval() @@ -56,11 +64,6 @@ def create(self) -> "TD3": self.optim_critic = Adam(self.critic.parameters(), lr=self.lr_critic) return self - def exploit(self, state: npt.ArrayLike) -> npt.ArrayLike: - return self.actor.exploit(state) - - def explore(self, state: npt.ArrayLike) -> npt.ArrayLike: - return self.actor.explore(state) def update(self, state: t.Tensor, action, reward, done, next_state) -> None: self._update_critic(state, action, reward, done, next_state) diff --git a/src/oprl/algos/tqc.py b/src/oprl/algos/tqc.py index c167d0f..e632896 100644 --- a/src/oprl/algos/tqc.py +++ b/src/oprl/algos/tqc.py @@ -1,12 +1,12 @@ import copy -from dataclasses import dataclass +from dataclasses import dataclass, field import numpy as np -import numpy.typing as npt import torch as t import torch.nn as nn -from oprl.algos.protocol import OffPolicyAlgorithm +from oprl.algos.protocol import PolicyProtocol +from oprl.algos.base_algorithm import OffPolicyAlgorithm from oprl.algos.nn_models import MLP, GaussianActor from oprl.logging import LoggerProtocol @@ -74,7 +74,17 @@ class TQC(OffPolicyAlgorithm): n_nets: int = 5 log_every: int = 5000 device: str = "cpu" - update_step = 0 + + actor: PolicyProtocol = field(init=False) + actor_target: PolicyProtocol = field(init=False) + actor_optimizer: t.optim.Optimizer = field(init=False) + critic: QuantileQritic = field(init=False) + critic_target: QuantileQritic = field(init=False) + critic_optimizer: t.optim.Optimizer = field(init=False) + target_entropy: float = field(init=False) + alpha_optimizer: t.optim.Optimizer = field(init=False) + quantiles_total: int = field(init=False) + update_step: int = field(init=False) def create(self) -> "TQC": self.target_entropy = -np.prod(self.action_dim).item() @@ -93,7 +103,7 @@ def create(self) -> "TQC": ).to(self.device) self.critic_target = copy.deepcopy(self.critic) self.log_alpha = t.tensor(np.log(0.2), requires_grad=True, device=self.device) - self._quantiles_total = self.critic.n_quantiles * self.critic.n_nets + self.quantiles_total = self.critic.n_quantiles * self.critic.n_nets # TODO: check hyperparams self.actor_optimizer = t.optim.Adam(self.actor.parameters(), lr=3e-4) @@ -125,7 +135,7 @@ def update( ) # batch x nets x quantiles sorted_z, _ = t.sort(next_z.reshape(batch_size, -1)) sorted_z_part = sorted_z[ - :, : self._quantiles_total - 
self.top_quantiles_to_drop + :, : self.quantiles_total - self.top_quantiles_to_drop ] # compute target @@ -176,9 +186,3 @@ def update( ) self.update_step += 1 - - def explore(self, state: npt.NDArray) -> npt.NDArray: - return self.actor.explore(state) - - def exploit(self, state: npt.NDArray) -> npt.NDArray: - return self.actor.exploit(state) diff --git a/src/oprl/buffers/episodic_buffer.py b/src/oprl/buffers/episodic_buffer.py index 7be132d..93390b7 100644 --- a/src/oprl/buffers/episodic_buffer.py +++ b/src/oprl/buffers/episodic_buffer.py @@ -11,15 +11,15 @@ @dataclass class EpisodicReplayBuffer(ReplayBufferProtocol): - buffer_size: int + buffer_size_transitions: int state_dim: int action_dim: int gamma: float = 0.99 max_episode_lenth: int = 1000 - device: str = "cpu" episodes_counter: int = 1 + device: str = "cpu" - _tensors: dict[str, t.Tensor] = field(default_factory=dict, init=False) + _tensors: dict[str, t.Tensor] = field(init=False) _max_episodes: int = field(init=False) _ep_pointer: int = 0 _number_transitions = 0 @@ -30,7 +30,7 @@ def _check_if_created(self) -> None: raise RuntimeError("Trying to work with non created buffer. Invoke .create() first.") def create(self) -> "EpisodicReplayBuffer": - self._max_episodes = self.buffer_size // self.max_episode_lenth + self._max_episodes = self.buffer_size_transitions // self.max_episode_lenth self._tensors = { "actions": t.empty( (self._max_episodes, self.max_episode_lenth, self.action_dim), @@ -87,7 +87,7 @@ def add_transition(self, state: npt.ArrayLike, action: npt.ArrayLike, reward: fl self.rewards[self._ep_pointer, self.ep_lens[self._ep_pointer]] = float(reward) self.dones[self._ep_pointer, self.ep_lens[self._ep_pointer]] = float(done) self.ep_lens[self._ep_pointer] += 1 - self._number_transitions = min(self._number_transitions + 1, self.buffer_size) + self._number_transitions = min(self._number_transitions + 1, self.buffer_size_transitions) # TODO: Switch to the episodic append and remove condition below if episode_done: self._inc_episode() diff --git a/src/oprl/train.py b/src/oprl/train.py index 374a74e..475ad7a 100644 --- a/src/oprl/train.py +++ b/src/oprl/train.py @@ -63,5 +63,4 @@ def _run_training_func(make_algo, make_env, make_replay_buffer, make_logger, con trainer = SafeTrainer(trainer=base_trainer) else: raise ValueError(f"Unsupported env family: {env.env_family}") - trainer.train() diff --git a/src/oprl/trainers/base_trainer.py b/src/oprl/trainers/base_trainer.py index dd3c849..be1089e 100644 --- a/src/oprl/trainers/base_trainer.py +++ b/src/oprl/trainers/base_trainer.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from typing import Any, Callable +from typing import Callable import numpy as np import torch -from oprl.algos.protocols import OffPolicyAlgorithm +from oprl.algos.protocols import AlgorithmProtocol from oprl.environment import EnvProtocol from oprl.buffers.protocols import ReplayBufferProtocol from oprl.logging import LoggerProtocol @@ -14,10 +14,11 @@ @dataclass class BaseTrainer(TrainerProtocol): + logger: LoggerProtocol env: EnvProtocol make_env_test: Callable[[int], EnvProtocol] replay_buffer: ReplayBufferProtocol - algo: OffPolicyAlgorithm | None = None + algo: AlgorithmProtocol gamma: float = 0.99 num_steps: int = int(1e6) start_steps: int = int(10e3) @@ -30,7 +31,6 @@ class BaseTrainer(TrainerProtocol): stdout_log_every: int = int(1e5) device: str = "cpu" seed: int = 0 - logger: LoggerProtocol | None = None def train(self) -> None: ep_step = 0 @@ -166,25 +166,3 @@ def estimate_critic_q(self, 
num_episodes: int = 10) -> float:
         qs.append(q)
 
         return np.mean(qs, dtype=float)
-
-
-def run_training(makealgo, makeenv, make_replay_buffer, makelogger, config: dict[str, Any], seed: int):
-    env = makeenv(seed=seed)
-    logger = makelogger(seed)
-
-    trainer = BaseTrainer(
-        env=env,
-        makeenv_test=makeenv,
-        algo=makealgo(logger, seed),
-        replay_buffer=make_replay_buffer(),
-        num_steps=config["num_steps"],
-        eval_interval=config["eval_every"],
-        device=config["device"],
-        save_buffer_every=config["save_buffer"],
-        estimate_q_every=config["estimate_q_every"],
-        stdout_log_every=config["log_every"],
-        seed=seed,
-        logger=logger,
-    )
-
-    trainer.train()

From efe2f9dd5074cca551e8978b3618c99115658810 Mon Sep 17 00:00:00 2001
From: Igor Kuznetsov
Date: Sat, 12 Jul 2025 23:20:42 +0400
Subject: [PATCH 11/30] Ensure algo and buffer have been created in trainer

---
 src/oprl/algos/base_algorithm.py    | 6 ++++++
 src/oprl/algos/ddpg.py              | 3 +++
 src/oprl/algos/protocols.py         | 3 +++
 src/oprl/algos/tqc.py               | 7 +++++--
 src/oprl/buffers/episodic_buffer.py | 8 ++++----
 src/oprl/buffers/protocols.py       | 5 +++++
 src/oprl/trainers/base_trainer.py   | 5 +++--
 src/oprl/trainers/safe_trainer.py   | 4 +++-
 8 files changed, 32 insertions(+), 9 deletions(-)

diff --git a/src/oprl/algos/base_algorithm.py b/src/oprl/algos/base_algorithm.py
index a689026..35f2ff7 100644
--- a/src/oprl/algos/base_algorithm.py
+++ b/src/oprl/algos/base_algorithm.py
@@ -8,6 +8,12 @@
 
 class OffPolicyAlgorithm(ABC, AlgorithmProtocol):
+    def check_created(self) -> None:
+        if not self._created:
+            raise RuntimeError(
+                f"Algorithm {type(self).__name__} has not been created with `create()`."
+            )
+
     def exploit(self, state: npt.NDArray) -> npt.NDArray:
         return self.actor.exploit(state)
 
diff --git a/src/oprl/algos/ddpg.py b/src/oprl/algos/ddpg.py
index fd44be4..3ce17ec 100644
--- a/src/oprl/algos/ddpg.py
+++ b/src/oprl/algos/ddpg.py
@@ -31,6 +31,7 @@ class DDPG(OffPolicyAlgorithm):
     critic: nn.Module = field(init=False)
     critic_target: nn.Module = field(init=False)
     optim_critic: t.optim.Optimizer = field(init=False)
+    _created: bool = False
 
     def create(self) -> "DDPG":
         self.actor = DeterministicPolicy(
@@ -50,6 +51,8 @@ def create(self) -> "DDPG":
         self.critic_target = deepcopy(self.critic)
         disable_gradient(self.critic_target)
         self.optim_critic = t.optim.Adam(self.critic.parameters(), lr=3e-4)
+
+        self._created = True
         return self
 
     def update(
diff --git a/src/oprl/algos/protocols.py b/src/oprl/algos/protocols.py
index e9a7b0c..e505fc1 100644
--- a/src/oprl/algos/protocols.py
+++ b/src/oprl/algos/protocols.py
@@ -13,9 +13,12 @@ def exploit(self, state: npt.ArrayLike) -> npt.NDArray: ...
 
 class AlgorithmProtocol(Protocol):
     actor: PolicyProtocol
+    _created: bool
 
     def create(self) -> "AlgorithmProtocol": ...
 
+    def check_created(self) -> None: ...
+ def update( self, state: t.Tensor, diff --git a/src/oprl/algos/tqc.py b/src/oprl/algos/tqc.py index e632896..a0abf3f 100644 --- a/src/oprl/algos/tqc.py +++ b/src/oprl/algos/tqc.py @@ -5,7 +5,7 @@ import torch as t import torch.nn as nn -from oprl.algos.protocol import PolicyProtocol +from oprl.algos.protocols import PolicyProtocol from oprl.algos.base_algorithm import OffPolicyAlgorithm from oprl.algos.nn_models import MLP, GaussianActor from oprl.logging import LoggerProtocol @@ -84,7 +84,8 @@ class TQC(OffPolicyAlgorithm): target_entropy: float = field(init=False) alpha_optimizer: t.optim.Optimizer = field(init=False) quantiles_total: int = field(init=False) - update_step: int = field(init=False) + update_step: int = 0 + _created: bool = False def create(self) -> "TQC": self.target_entropy = -np.prod(self.action_dim).item() @@ -110,6 +111,8 @@ def create(self) -> "TQC": self.critic_optimizer = t.optim.Adam(self.critic.parameters(), lr=3e-4) self.alpha_optimizer = t.optim.Adam([self.log_alpha], lr=3e-4) + self.udpate_step = 0 + self._created = True return self def update( diff --git a/src/oprl/buffers/episodic_buffer.py b/src/oprl/buffers/episodic_buffer.py index 93390b7..dd33e75 100644 --- a/src/oprl/buffers/episodic_buffer.py +++ b/src/oprl/buffers/episodic_buffer.py @@ -25,10 +25,6 @@ class EpisodicReplayBuffer(ReplayBufferProtocol): _number_transitions = 0 _created: bool = False - def _check_if_created(self) -> None: - if not self._created: - raise RuntimeError("Trying to work with non created buffer. Invoke .create() first.") - def create(self) -> "EpisodicReplayBuffer": self._max_episodes = self.buffer_size_transitions // self.max_episode_lenth self._tensors = { @@ -57,6 +53,10 @@ def create(self) -> "EpisodicReplayBuffer": self._created = True return self + def check_created(self) -> None: + if not self._created: + raise RuntimeError("Replay buffer has to be created with `.create()`.") + @property def states(self) -> t.Tensor: self._check_if_created() diff --git a/src/oprl/buffers/protocols.py b/src/oprl/buffers/protocols.py index d70b781..7a0d2a0 100644 --- a/src/oprl/buffers/protocols.py +++ b/src/oprl/buffers/protocols.py @@ -4,7 +4,12 @@ class ReplayBufferProtocol(Protocol): + _created: bool + def create(self) -> "ReplayBufferProtocol": ... + + def check_created(self) -> None: ... + def add_transition(self, state, action, reward, done, episode_done=None): ... def add_episode(self, episode): ... 
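
The create()/check_created() pair makes two-phase initialization explicit: dataclass construction stays cheap and free of side effects, create() performs the heavy allocation, and check_created() turns a forgotten create() into an immediate, descriptive RuntimeError rather than a late AttributeError deep in training. Sketched usage (the dimensions are arbitrary):

from oprl.buffers.episodic_buffer import EpisodicReplayBuffer

buffer = EpisodicReplayBuffer(
    buffer_size_transitions=1_000_000,
    state_dim=8,
    action_dim=2,
)  # cheap: no tensors allocated yet

try:
    buffer.check_created()
except RuntimeError as err:
    print(err)  # "Replay buffer has to be created with `.create()`."

buffer = buffer.create()  # allocates the per-episode tensors
buffer.check_created()    # passes silently now
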
diff --git a/src/oprl/trainers/base_trainer.py b/src/oprl/trainers/base_trainer.py
index be1089e..b66f18c 100644
--- a/src/oprl/trainers/base_trainer.py
+++ b/src/oprl/trainers/base_trainer.py
@@ -33,9 +33,11 @@ class BaseTrainer(TrainerProtocol):
     seed: int = 0
 
     def train(self) -> None:
+        self.algo.check_created()
+        self.replay_buffer.check_created()
+
         ep_step = 0
         state, _ = self.env.reset()
-
         for env_step in range(self.num_steps + 1):
             ep_step += 1
             if env_step <= self.start_steps:
@@ -66,7 +68,6 @@ def _log_evaluation(self, env_step: int, batch):
         if env_step % self.eval_interval == 0:
             eval_metrics = self.evaluate()
             self.logger.log_scalar("trainer/ep_reward", eval_metrics["return"], env_step)
-
             self.logger.log_scalar("trainer/avg_reward", batch[2].mean(), env_step)
             self.logger.log_scalar(
                 "trainer/buffer_transitions", len(self.replay_buffer), env_step
diff --git a/src/oprl/trainers/safe_trainer.py b/src/oprl/trainers/safe_trainer.py
index 8124742..8aa038f 100644
--- a/src/oprl/trainers/safe_trainer.py
+++ b/src/oprl/trainers/safe_trainer.py
@@ -10,10 +10,12 @@ class SafeTrainer(TrainerProtocol):
     trainer: BaseTrainer
 
     def train(self):
+        self.trainer.algo.check_created()
+        self.trainer.replay_buffer.check_created()
+
         ep_step = 0
         state, _ = self.trainer.env.reset()
         total_cost = 0
-
         for env_step in range(self.trainer.num_steps + 1):
             ep_step += 1
             if env_step <= self.trainer.start_steps:

From d4b54866818b9058c246b16db1491ee287ec1fe8 Mon Sep 17 00:00:00 2001
From: Igor Kuznetsov
Date: Sun, 13 Jul 2025 02:09:49 +0400
Subject: [PATCH 12/30] Add proper stdout logging instead of bare prints

---
 configs/ddpg.py                          |  9 ++-
 configs/sac.py                           |  2 +-
 configs/td3.py                           |  3 +-
 configs/tqc.py                           |  2 +-
 scripts/visualize_policy_from_weights.py |  6 +-
 src/oprl/algos/base_algorithm.py         |  1 -
 src/oprl/algos/nn_functions.py           |  4 +-
 src/oprl/algos/tqc.py                    |  8 +--
 src/oprl/buffers/episodic_buffer.py      | 27 ++-------
 src/oprl/logging.py                      | 58 ++++++++++---------
 src/oprl/parse_args.py                   |  1 -
 src/oprl/{ => runners}/train.py          |  0
 .../train_distrib.py}                     |  3 -
 src/oprl/trainers/base_trainer.py        | 16 ++---
 14 files changed, 57 insertions(+), 83 deletions(-)
 rename src/oprl/{ => runners}/train.py (100%)
 rename src/oprl/{distrib_train.py => runners/train_distrib.py} (98%)

diff --git a/configs/ddpg.py b/configs/ddpg.py
index 1e55a65..11f487b 100644
--- a/configs/ddpg.py
+++ b/configs/ddpg.py
@@ -4,16 +4,14 @@
 from oprl.buffers.episodic_buffer import EpisodicReplayBuffer
 from oprl.parse_args import parse_args
 from oprl.logging import (
-    set_logging,
     LoggerProtocol,
     make_text_logger_func,
 )
-set_logging()
 from oprl.environment import make_env as _make_env
-from oprl.train import run_training
+from oprl.runners.train import run_training
 
-args = parse_args()
+args = parse_args()
 
 
 def make_env(seed: int):
     return _make_env(args.env, seed=seed)
@@ -66,7 +64,6 @@ def make_replay_buffer() -> ReplayBufferProtocol:
 
 
 if __name__ == "__main__":
-    args = parse_args()
     run_training(
         make_algo=make_algo,
         make_env=make_env,
@@ -76,3 +73,5 @@ def make_replay_buffer() -> ReplayBufferProtocol:
         seeds=args.seeds,
         start_seed=args.start_seed
     )
+
+
diff --git a/configs/sac.py b/configs/sac.py
index 345ee5c..17deae7 100644
--- a/configs/sac.py
+++ b/configs/sac.py
@@ -10,7 +10,7 @@
 from oprl.environment import make_env as _make_env
 from oprl.buffers.protocols import ReplayBufferProtocol
 from oprl.buffers.episodic_buffer import EpisodicReplayBuffer
-from oprl.train import run_training
+from oprl.runners.train import run_training
 
 args = parse_args()
 
diff --git a/configs/td3.py b/configs/td3.py
index 31dc557..61b6d77 100644 --- a/configs/td3.py +++ b/configs/td3.py @@ -10,7 +10,7 @@ ) set_logging() from oprl.environment import make_env as _make_env -from oprl.train import run_training +from oprl.runners.train import run_training args = parse_args() @@ -65,7 +65,6 @@ def make_replay_buffer() -> ReplayBufferProtocol: if __name__ == "__main__": - args = parse_args() run_training( make_algo=make_algo, make_env=make_env, diff --git a/configs/tqc.py b/configs/tqc.py index fb051bf..fd2db34 100644 --- a/configs/tqc.py +++ b/configs/tqc.py @@ -10,7 +10,7 @@ from oprl.environment import make_env as _make_env from oprl.buffers.protocols import ReplayBufferProtocol from oprl.buffers.episodic_buffer import EpisodicReplayBuffer -from oprl.train import run_training +from oprl.runners.train import run_training args = parse_args() diff --git a/scripts/visualize_policy_from_weights.py b/scripts/visualize_policy_from_weights.py index 2683785..8e4b3a7 100644 --- a/scripts/visualize_policy_from_weights.py +++ b/scripts/visualize_policy_from_weights.py @@ -1,5 +1,5 @@ import click -import torch +import torch as t import numpy as np from PIL import Image @@ -63,7 +63,7 @@ def create_webp_gif(numpy_arrays, output_path, duration=100, loop=0): def visualize_policy(policy, output, env, seed): env = make_env(env, seed=seed) - actor = torch.load(policy, weights_only=False) + actor = t.load(policy, weights_only=False) print("Actor loaded: ", type(actor)) imgs = [] @@ -72,7 +72,7 @@ def visualize_policy(policy, output, env, seed): while not done: img = np.expand_dims(env.render(), axis=0) # [1, W, H, C] imgs.append(img) - action = actor.exploit(torch.from_numpy(state)) + action = actor.exploit(t.from_numpy(state)) state, _, terminated, truncated, _ = env.step(action) done = terminated or truncated diff --git a/src/oprl/algos/base_algorithm.py b/src/oprl/algos/base_algorithm.py index 35f2ff7..c3e5aa6 100644 --- a/src/oprl/algos/base_algorithm.py +++ b/src/oprl/algos/base_algorithm.py @@ -1,4 +1,3 @@ - from abc import ABC from typing import Any diff --git a/src/oprl/algos/nn_functions.py b/src/oprl/algos/nn_functions.py index c57726c..72db6d6 100644 --- a/src/oprl/algos/nn_functions.py +++ b/src/oprl/algos/nn_functions.py @@ -1,9 +1,9 @@ -import torch +import torch as t def soft_update(target, source, tau): """Update target network using Polyak-Ruppert Averaging.""" - with torch.no_grad(): + with t.no_grad(): for tgt, src in zip(target.parameters(), source.parameters()): tgt.data.mul_(1.0 - tau) tgt.data.add_(tau * src.data) diff --git a/src/oprl/algos/tqc.py b/src/oprl/algos/tqc.py index a0abf3f..c803525 100644 --- a/src/oprl/algos/tqc.py +++ b/src/oprl/algos/tqc.py @@ -15,12 +15,8 @@ def quantile_huber_loss_f( quantiles: t.Tensor, samples: t.Tensor, device: str ) -> t.Tensor: """ - Args: - quantiles: [batch, n_nets, n_quantiles]. - samples: [batch, n_nets * n_quantiles - top_quantiles_to_drop]. - - Returns: - loss as a torch value. + quantiles: [batch, n_nets, n_quantiles]. + samples: [batch, n_nets * n_quantiles - top_quantiles_to_drop]. 
""" pairwise_delta = ( samples[:, None, None, :] - quantiles[:, :, :, None] diff --git a/src/oprl/buffers/episodic_buffer.py b/src/oprl/buffers/episodic_buffer.py index dd33e75..4b0802b 100644 --- a/src/oprl/buffers/episodic_buffer.py +++ b/src/oprl/buffers/episodic_buffer.py @@ -59,22 +59,22 @@ def check_created(self) -> None: @property def states(self) -> t.Tensor: - self._check_if_created() + self.check_created() return self._tensors["states"] @property def actions(self) -> t.Tensor: - self._check_if_created() + self.check_created() return self._tensors["actions"] @property def rewards(self) -> t.Tensor: - self._check_if_created() + self.check_created() return self._tensors["rewards"] @property def dones(self) -> t.Tensor: - self._check_if_created() + self.check_created() return self._tensors["dones"] def add_transition(self, state: npt.ArrayLike, action: npt.ArrayLike, reward: float, done: bool, episode_done: bool | None = None): @@ -125,25 +125,6 @@ def sample(self, batch_size): self.states[ep_inds, step_inds + 1], ) - def save(self, path: str): - dirname = os.path.dirname(path) - if not os.path.exists(dirname): - os.makedirs(dirname) - - data = { - "states": self.states.cpu(), - "actions": self.actions.cpu(), - "rewards": self.rewards.cpu(), - "dones": self.dones.cpu(), - "ep_lens": self.ep_lens, - } - try: - with open(path, "wb") as f: - pickle.dump(data, f) - print(f"Replay buffer saved to {path}") - except Exception as e: - print(f"Failed to save replay buffer: {e}") - @property def last_episode_length(self): return self.ep_lens[self._ep_pointer] diff --git a/src/oprl/logging.py b/src/oprl/logging.py index 2d39b80..d8ab7d3 100644 --- a/src/oprl/logging.py +++ b/src/oprl/logging.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import sys import logging from datetime import datetime @@ -18,24 +19,28 @@ def log_scalar(self, tag: str, value: float, step: int) -> None: ... def log_scalars(self, values: dict[str, float], step: int) -> None: ... 
-def create_logdir(logdir: str, algo: str, env: str, seed: int) -> str: +def get_logs_path(logdir: str, algo: str, env: str, seed: int) -> Path: dt = datetime.now().strftime("%Y_%m_%d_%Hh%Mm%Ss") - log_dir = os.path.join(logdir, algo, f"{algo}-env_{env}-seed_{seed}-{dt}") + log_dir = Path(logdir) / algo / f"{algo}-env_{env}-seed_{seed}-{dt}" logging.info(f"LOGDIR: {log_dir}") return log_dir -def set_logging(level: int = logging.INFO) -> None: - logging.basicConfig( - level=level, - format="%(asctime)s | %(filename)s:%(lineno)d\t %(levelname)s - %(message)s", - stream=sys.stdout, - ) +def create_stdout_logger(name=None): + if name is None: + import inspect + frame = inspect.currentframe().f_back + filename = os.path.basename(frame.f_code.co_filename) + name = os.path.splitext(filename)[0] + + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + return logger -def copy_exp_dir(log_dir: str) -> None: - cur_dir = os.path.join(os.getcwd(), "src") - dest_dir = os.path.join(log_dir, "src") +def copy_exp_dir(log_dir: Path) -> None: + cur_dir = Path(__file__).parents[0] + dest_dir = log_dir / "src" shutil.copytree(cur_dir, dest_dir) logging.info(f"Source copied into {dest_dir}") @@ -43,14 +48,14 @@ def copy_exp_dir(log_dir: str) -> None: def make_text_logger_func(config: dict, algo, env) -> Callable: def make_logger(seed: int) -> LoggerProtocol: logs_root = os.environ.get("OPRL_LOGS", "logs") - log_dir = create_logdir(logdir=logs_root, algo=algo, env=env, seed=seed) + log_dir = get_logs_path(logdir=logs_root, algo=algo, env=env, seed=seed) logger = FileTxtLogger(log_dir, config) logger.copy_source_code() return logger return make_logger -def save_json_config(config: dict[str, Any], path: str): +def save_json_config(config: dict[str, Any], path: Path) -> None: with open(path, "w") as f: json.dump(config, f) @@ -69,37 +74,34 @@ def log_scalars(self, values: dict[str, float], step: int) -> None: (self.log_scalar(k, v, step) for k, v in values.items()) -class StdLogger(BaseLogger): - def log_scalar(self, tag: str, value: float, step: int) -> None: - logging.info(f"{tag}\t{value}\tat step {step}") - - class FileTxtLogger(BaseLogger): - def __init__(self, logdir: str, config: dict[str, Any]) -> None: + def __init__(self, logdir: Path | str, config: dict[str, Any]) -> None: self.writer = SummaryWriter(logdir) - self.log_dir = logdir + self.log_dir = Path(logdir) self.config = config def copy_source_code(self) -> None: - logging.info(f"Source code is copied to {self.log_dir}") copy_exp_dir(self.log_dir) - save_json_config(self.config, os.path.join(self.log_dir, "config.json")) + logging.info(f"Source code is copied to {self.log_dir}.") + save_json_config(self.config, self.log_dir / "config.json") def log_scalar(self, tag: str, value: float, step: int) -> None: self.writer.add_scalar(tag, value, step) self._log_scalar_to_file(tag, value, step) def save_weights(self, weights: nn.Module, step: int) -> None: - os.makedirs(os.path.join(self.log_dir, "weights"), exist_ok=True) - fn = os.path.join(self.log_dir, "weights", f"step_{step}.w") + weights_path = self.log_dir / "weights" / f"{step}.w" + weights_path.parents[0].mkdir(exist_ok=True) t.save( weights, - fn + weights_path ) def _log_scalar_to_file(self, tag: str, val: float, step: int) -> None: - fn = os.path.join(self.log_dir, f"{tag}.log") - os.makedirs(os.path.dirname(fn), exist_ok=True) - with open(fn, "a") as f: + log_path = self.log_dir / f"{tag}.log" + log_path.parents[0].mkdir(exist_ok=True) + with open(log_path, "a") as f: 
f.write(f"{step} {val}\n") + + diff --git a/src/oprl/parse_args.py b/src/oprl/parse_args.py index d01867e..060355c 100644 --- a/src/oprl/parse_args.py +++ b/src/oprl/parse_args.py @@ -23,4 +23,3 @@ def parse_args() -> argparse.Namespace: "--device", type=str, default="cpu", help="Device to perform training on." ) return parser.parse_args() - diff --git a/src/oprl/train.py b/src/oprl/runners/train.py similarity index 100% rename from src/oprl/train.py rename to src/oprl/runners/train.py diff --git a/src/oprl/distrib_train.py b/src/oprl/runners/train_distrib.py similarity index 98% rename from src/oprl/distrib_train.py rename to src/oprl/runners/train_distrib.py index 4810768..9ee3b14 100644 --- a/src/oprl/distrib_train.py +++ b/src/oprl/runners/train_distrib.py @@ -1,10 +1,8 @@ import argparse import os -import time from datetime import datetime from multiprocessing import Process -import torch import torch.nn as nn from algos.ddpg import DDPG, DeterministicPolicy from distrib.distrib_runner import env_worker, policy_update_worker @@ -12,7 +10,6 @@ from trainers.buffers.episodic_buffer import EpisodicReplayBuffer from utils.logger import Logger -print("Imports ok.") def parse_args(): diff --git a/src/oprl/trainers/base_trainer.py b/src/oprl/trainers/base_trainer.py index b66f18c..0fbf55b 100644 --- a/src/oprl/trainers/base_trainer.py +++ b/src/oprl/trainers/base_trainer.py @@ -2,16 +2,19 @@ from typing import Callable import numpy as np -import torch +import torch as t from oprl.algos.protocols import AlgorithmProtocol from oprl.environment import EnvProtocol from oprl.buffers.protocols import ReplayBufferProtocol -from oprl.logging import LoggerProtocol +from oprl.logging import LoggerProtocol, create_stdout_logger from oprl.trainers.protocols import TrainerProtocol +logger = create_stdout_logger() + + @dataclass class BaseTrainer(TrainerProtocol): logger: LoggerProtocol @@ -119,7 +122,7 @@ def _estimate_q(self, env_step: int): def _log_stdout(self, env_step: int, batch): if env_step % self.stdout_log_every == 0: perc = int(env_step / self.num_steps * 100) - print( + logger.info( f"Env step {env_step:8d} ({perc:2d}%) Avg Reward {batch[2].mean():10.3f}" ) @@ -128,7 +131,6 @@ def estimate_true_q(self, eval_episodes: int = 10) -> float | None: qs = [] for i_eval in range(eval_episodes): env = self.make_env_test(seed=self.seed * 100 + i_eval) - print("Before reset etimate q") state, _ = env.reset() q = 0 @@ -145,7 +147,7 @@ def estimate_true_q(self, eval_episodes: int = 10) -> float | None: return np.mean(qs, dtype=float) except Exception as e: - print(f"Failed to estimate Q-value: {e}") + logger.warning(f"Failed to estimate Q-value: {e}") return None def estimate_critic_q(self, num_episodes: int = 10) -> float: @@ -156,8 +158,8 @@ def estimate_critic_q(self, num_episodes: int = 10) -> float: state, _ = env.reset() action = self.algo.exploit(state) - state = torch.tensor(state).unsqueeze(0).float().to(self.device) - action = torch.tensor(action).unsqueeze(0).float().to(self.device) + state = t.tensor(state).unsqueeze(0).float().to(self.device) + action = t.tensor(action).unsqueeze(0).float().to(self.device) q = self.algo.critic(state, action) # TODO: TQC is not supported by this logic, need to update From 0e6511ff0829200bb55eff2fd0fcaeb4f66d0afa Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Sun, 13 Jul 2025 15:21:21 +0400 Subject: [PATCH 13/30] Make algos configurable --- src/oprl/algos/ddpg.py | 11 +++++++---- src/oprl/algos/protocols.py | 4 ++-- src/oprl/algos/sac.py | 3 +-- 
src/oprl/algos/td3.py | 6 +++--- src/oprl/algos/tqc.py | 14 ++++++++------ src/oprl/buffers/episodic_buffer.py | 2 -- src/oprl/buffers/protocols.py | 2 -- 7 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/oprl/algos/ddpg.py b/src/oprl/algos/ddpg.py index 3ce17ec..46abe88 100644 --- a/src/oprl/algos/ddpg.py +++ b/src/oprl/algos/ddpg.py @@ -19,7 +19,9 @@ class DDPG(OffPolicyAlgorithm): state_dim: int action_dim: int expl_noise: float = 0.1 - discount: float = 0.99 + gamma: float = 0.99 + lr_actor: float = 3e-4 + lr_critic: float = 3e-4 tau: float = 5e-3 batch_size: int = 256 max_action: float = 1. @@ -31,6 +33,7 @@ class DDPG(OffPolicyAlgorithm): critic: nn.Module = field(init=False) critic_target: nn.Module = field(init=False) optim_critic: t.optim.Optimizer = field(init=False) + update_step: int = 0 _created: bool = False def create(self) -> "DDPG": @@ -45,12 +48,12 @@ def create(self) -> "DDPG": ).to(self.device) self.actor_target = deepcopy(self.actor) disable_gradient(self.actor_target) - self.optim_actor = t.optim.Adam(self.actor.parameters(), lr=3e-4) + self.optim_actor = t.optim.Adam(self.actor.parameters(), lr=self.lr_actor) self.critic = Critic(self.state_dim, self.action_dim).to(self.device) self.critic_target = deepcopy(self.critic) disable_gradient(self.critic_target) - self.optim_critic = t.optim.Adam(self.critic.parameters(), lr=3e-4) + self.optim_critic = t.optim.Adam(self.critic.parameters(), lr=self.lr_critic) self._created = True return self @@ -90,7 +93,7 @@ def _update_critic( next_state: t.Tensor, ) -> None: target_Q = self.critic_target(next_state, self.actor_target(next_state)) - target_Q = reward + (1.0 - done) * self.discount * target_Q.detach() + target_Q = reward + (1.0 - done) * self.gamma * target_Q.detach() current_Q = self.critic(state, action) critic_loss = (current_Q - target_Q).pow(2).mean() diff --git a/src/oprl/algos/protocols.py b/src/oprl/algos/protocols.py index e505fc1..fdcf626 100644 --- a/src/oprl/algos/protocols.py +++ b/src/oprl/algos/protocols.py @@ -6,9 +6,9 @@ class PolicyProtocol(Protocol): - def explore(self, state: npt.ArrayLike) -> npt.NDArray: ... + def explore(self, state: npt.NDArray) -> npt.NDArray: ... - def exploit(self, state: npt.ArrayLike) -> npt.NDArray: ... + def exploit(self, state: npt.NDArray) -> npt.NDArray: ... 
class AlgorithmProtocol(Protocol): diff --git a/src/oprl/algos/sac.py b/src/oprl/algos/sac.py index b459fbc..eb84a8d 100644 --- a/src/oprl/algos/sac.py +++ b/src/oprl/algos/sac.py @@ -36,7 +36,7 @@ class SAC(OffPolicyAlgorithm): critic_target: nn.Module = field(init=False) optim_critic: t.optim.Optimizer = field(init=False) alpha: float = field(init=False) - update_step: int = field(init=False) + update_step: int = 0 def create(self) -> "SAC": self.actor = GaussianActor( @@ -67,7 +67,6 @@ def create(self) -> "SAC": ) self.optim_alpha = t.optim.Adam([self.log_alpha], lr=self.lr_alpha) self.target_entropy = -float(self.action_dim) - self.update_step = 0 return self diff --git a/src/oprl/algos/td3.py b/src/oprl/algos/td3.py index ea6e749..04b287b 100644 --- a/src/oprl/algos/td3.py +++ b/src/oprl/algos/td3.py @@ -22,7 +22,7 @@ class TD3(OffPolicyAlgorithm): expl_noise: float = 0.1 noise_clip: float = 0.5 policy_freq: int = 2 - discount: float = 0.99 + gamma: float = 0.99 lr_actor: float = 3e-4 lr_critic: float = 3e-4 max_action: float = 1.0 @@ -36,7 +36,7 @@ class TD3(OffPolicyAlgorithm): critic: nn.Module = field(init=False) critic_target: nn.Module = field(init=False) optim_critic: t.optim.Optimizer = field(init=False) - update_step: int = field(init=False) + update_step: int = 0 def create(self) -> "TD3": self.actor = DeterministicPolicy( @@ -96,7 +96,7 @@ def _update_critic( q1_next, q2_next = self.critic_target(next_state, next_actions) q_next = t.min(q1_next, q2_next) - q_target = reward + (1.0 - done) * self.discount * q_next + q_target = reward + (1.0 - done) * self.gamma * q_next td_error1 = (q1 - q_target).pow(2).mean() td_error2 = (q2 - q_target).pow(2).mean() diff --git a/src/oprl/algos/tqc.py b/src/oprl/algos/tqc.py index c803525..fc0d552 100644 --- a/src/oprl/algos/tqc.py +++ b/src/oprl/algos/tqc.py @@ -63,7 +63,10 @@ class TQC(OffPolicyAlgorithm): logger: LoggerProtocol state_dim: int action_dim: int - discount: float = 0.99 + gamma: float = 0.99 + lr_actor = 3e-4 + lr_critic = 3e-4 + lr_alpha = 3e-4 tau: float = 0.005 top_quantiles_to_drop: int = 2 n_quantiles: int = 25 @@ -103,11 +106,10 @@ def create(self) -> "TQC": self.quantiles_total = self.critic.n_quantiles * self.critic.n_nets # TODO: check hyperparams - self.actor_optimizer = t.optim.Adam(self.actor.parameters(), lr=3e-4) - self.critic_optimizer = t.optim.Adam(self.critic.parameters(), lr=3e-4) - self.alpha_optimizer = t.optim.Adam([self.log_alpha], lr=3e-4) + self.actor_optimizer = t.optim.Adam(self.actor.parameters(), lr=self.lr_actor) + self.critic_optimizer = t.optim.Adam(self.critic.parameters(), lr=self.lr_critic) + self.alpha_optimizer = t.optim.Adam([self.log_alpha], lr=self.lr_alpha) - self.udpate_step = 0 self._created = True return self @@ -138,7 +140,7 @@ def update( ] # compute target - target = reward + (1 - done) * self.discount * ( + target = reward + (1 - done) * self.gamma * ( sorted_z_part - alpha * next_log_pi ) diff --git a/src/oprl/buffers/episodic_buffer.py b/src/oprl/buffers/episodic_buffer.py index 4b0802b..030b8b5 100644 --- a/src/oprl/buffers/episodic_buffer.py +++ b/src/oprl/buffers/episodic_buffer.py @@ -1,6 +1,4 @@ from dataclasses import dataclass, field -import os -import pickle import numpy as np import numpy.typing as npt diff --git a/src/oprl/buffers/protocols.py b/src/oprl/buffers/protocols.py index 7a0d2a0..e914959 100644 --- a/src/oprl/buffers/protocols.py +++ b/src/oprl/buffers/protocols.py @@ -18,8 +18,6 @@ def sample(self, batch_size) -> tuple[ t.Tensor, t.Tensor, t.Tensor, 
t.Tensor, t.Tensor ]: ... - def save(self, path: str) -> None: ... - @property def last_episode_length(self) -> int: ... From 2ed37d001d5bf21b6f053275ecd8013c7b04f41a Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Sun, 13 Jul 2025 15:55:53 +0400 Subject: [PATCH 14/30] Replace bare dict with pydantic config in configs --- configs/ddpg.py | 29 ++++++++++++----------------- configs/sac.py | 31 ++++++++++++------------------- configs/td3.py | 30 ++++++++++++------------------ configs/tqc.py | 31 ++++++++++++------------------- pyproject.toml | 1 + src/oprl/algos/sac.py | 4 +++- src/oprl/algos/td3.py | 3 +++ src/oprl/logging.py | 29 ++++++++++++++++------------- src/oprl/runners/config.py | 11 +++++++++++ src/oprl/runners/train.py | 10 +++++----- 10 files changed, 87 insertions(+), 92 deletions(-) create mode 100644 src/oprl/runners/config.py diff --git a/configs/ddpg.py b/configs/ddpg.py index 11f487b..468074b 100644 --- a/configs/ddpg.py +++ b/configs/ddpg.py @@ -9,6 +9,7 @@ ) from oprl.environment import make_env as _make_env from oprl.runners.train import run_training +from oprl.runners.config import CommonParameters args = parse_args() @@ -22,20 +23,15 @@ def make_env(seed: int): ACTION_DIM: int = env.action_space.shape[0] -# -------- Config params ----------- - -config = { - "state_dim": STATE_DIM, - "action_dim": ACTION_DIM, - "num_steps": int(100_000), - "eval_every": 2500, - "device": args.device, - "visualise_every": 50000, - "estimate_q_every": 5000, - "log_every": 2500, -} - -# ----------------------------------- +config = CommonParameters( + state_dim=STATE_DIM, + action_dim=ACTION_DIM, + num_steps=int(100_000), + eval_every=2500, + device=args.device, + estimate_q_every=5000, + log_every=2500, +) def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol: @@ -49,15 +45,14 @@ def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol: def make_replay_buffer() -> ReplayBufferProtocol: return EpisodicReplayBuffer( - buffer_size_transitions=max(config["num_steps"], int(1e6)), + buffer_size_transitions=max(config.num_steps, int(1e6)), state_dim=STATE_DIM, action_dim=ACTION_DIM, - device=config["device"], + device=config.device, ).create() make_logger = make_text_logger_func( - config=config, algo="DDPG", env=args.env, ) diff --git a/configs/sac.py b/configs/sac.py index 17deae7..48f5dce 100644 --- a/configs/sac.py +++ b/configs/sac.py @@ -2,15 +2,14 @@ from oprl.algos.sac import SAC from oprl.parse_args import parse_args from oprl.logging import ( - set_logging, LoggerProtocol, make_text_logger_func, ) -set_logging() from oprl.environment import make_env as _make_env from oprl.buffers.protocols import ReplayBufferProtocol from oprl.buffers.episodic_buffer import EpisodicReplayBuffer from oprl.runners.train import run_training +from oprl.runners.config import CommonParameters args = parse_args() @@ -24,20 +23,15 @@ def make_env(seed: int): ACTION_DIM: int = env.action_space.shape[0] -# -------- Config params ----------- - -config = { - "state_dim": STATE_DIM, - "action_dim": ACTION_DIM, - "num_steps": int(100_000), - "eval_every": 2500, - "device": args.device, - "visualise_every": 0, - "estimate_q_every": 5000, - "log_every": 1000, -} - -# ----------------------------------- +config = CommonParameters( + state_dim=STATE_DIM, + action_dim=ACTION_DIM, + num_steps=int(100_000), + eval_every=2500, + device=args.device, + estimate_q_every=5000, + log_every=1000, +) def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol: @@ -50,7 +44,6 @@ def make_algo(logger: LoggerProtocol) -> 
AlgorithmProtocol: make_logger = make_text_logger_func( - config=config, algo="SAC", env=args.env, ) @@ -58,10 +51,10 @@ def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol: def make_replay_buffer() -> ReplayBufferProtocol: return EpisodicReplayBuffer( - buffer_size_transitions=max(config["num_steps"], int(1e6)), + buffer_size_transitions=max(config.num_steps, int(1e6)), state_dim=STATE_DIM, action_dim=ACTION_DIM, - device=config["device"], + device=config.device, ).create() diff --git a/configs/td3.py b/configs/td3.py index 61b6d77..827f27f 100644 --- a/configs/td3.py +++ b/configs/td3.py @@ -4,13 +4,12 @@ from oprl.buffers.protocols import ReplayBufferProtocol from oprl.buffers.episodic_buffer import EpisodicReplayBuffer from oprl.logging import ( - set_logging, LoggerProtocol, make_text_logger_func, ) -set_logging() from oprl.environment import make_env as _make_env from oprl.runners.train import run_training +from oprl.runners.config import CommonParameters args = parse_args() @@ -24,19 +23,15 @@ def make_env(seed: int): ACTION_DIM: int = env.action_space.shape[0] -# -------- Config params ----------- - -config = { - "state_dim": STATE_DIM, - "action_dim": ACTION_DIM, - "num_steps": int(100_000), - "eval_every": 2500, - "device": args.device, - "estimate_q_every": 5000, - "log_every": 2500, -} - -# ----------------------------------- +config = CommonParameters( + state_dim=STATE_DIM, + action_dim=ACTION_DIM, + num_steps=int(100_000), + eval_every=2500, + device=args.device, + estimate_q_every=5000, + log_every=2500, +) def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol: @@ -50,15 +45,14 @@ def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol: def make_replay_buffer() -> ReplayBufferProtocol: return EpisodicReplayBuffer( - buffer_size_transitions=max(config["num_steps"], int(1e6)), + buffer_size_transitions=max(config.num_steps, int(1e6)), state_dim=STATE_DIM, action_dim=ACTION_DIM, - device=config["device"], + device=config.device, ).create() make_logger = make_text_logger_func( - config=config, algo="TD3", env=args.env, ) diff --git a/configs/tqc.py b/configs/tqc.py index fd2db34..fdfef89 100644 --- a/configs/tqc.py +++ b/configs/tqc.py @@ -2,15 +2,14 @@ from oprl.algos.tqc import TQC from oprl.parse_args import parse_args from oprl.logging import ( - set_logging, LoggerProtocol, make_text_logger_func, ) -set_logging() from oprl.environment import make_env as _make_env from oprl.buffers.protocols import ReplayBufferProtocol from oprl.buffers.episodic_buffer import EpisodicReplayBuffer from oprl.runners.train import run_training +from oprl.runners.config import CommonParameters args = parse_args() @@ -24,20 +23,15 @@ def make_env(seed: int): ACTION_DIM: int = env.action_space.shape[0] -# -------- Config params ----------- - -config = { - "state_dim": STATE_DIM, - "action_dim": ACTION_DIM, - "num_steps": int(100_000), - "eval_every": 2500, - "device": args.device, - "visualise_every": 0, - "estimate_q_every": 0, # TODO: Here is the unsupported logic - "log_every": 2500, -} - -# ----------------------------------- +config = CommonParameters( + state_dim=STATE_DIM, + action_dim=ACTION_DIM, + num_steps=int(100_000), + eval_every=2500, + device=args.device, + estimate_q_every=0, # TODO: unsupported logic + log_every=2500, +) def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol: @@ -50,7 +44,6 @@ def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol: make_logger = make_text_logger_func( - config=config, algo="TQC", env=args.env, ) @@ -58,10 +51,10 @@ def 
make_algo(logger: LoggerProtocol) -> AlgorithmProtocol: def make_replay_buffer() -> ReplayBufferProtocol: return EpisodicReplayBuffer( - buffer_size_transitions=max(config["num_steps"], int(1e6)), + buffer_size_transitions=max(config.num_steps, int(1e6)), state_dim=STATE_DIM, action_dim=ACTION_DIM, - device=config["device"], + device=config.device, ).create() diff --git a/pyproject.toml b/pyproject.toml index f531d30..0eee46a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "dm-control==1.0.11", "mujoco==2.3.3", "numpy==1.26.4", + "pydantic_settings==2.10.1", ] [project.optional-dependencies] diff --git a/src/oprl/algos/sac.py b/src/oprl/algos/sac.py index eb84a8d..a22fcec 100644 --- a/src/oprl/algos/sac.py +++ b/src/oprl/algos/sac.py @@ -37,6 +37,7 @@ class SAC(OffPolicyAlgorithm): optim_critic: t.optim.Optimizer = field(init=False) alpha: float = field(init=False) update_step: int = 0 + _created: bool = False def create(self) -> "SAC": self.actor = GaussianActor( @@ -68,6 +69,7 @@ def create(self) -> "SAC": self.optim_alpha = t.optim.Adam([self.log_alpha], lr=self.lr_alpha) self.target_entropy = -float(self.action_dim) + self._created = True return self def update( @@ -146,7 +148,7 @@ def update_actor(self, state: t.Tensor) -> None: self.logger.log_scalars( { "algo/loss_actor": loss_actor.item(), - "algo/alpha": self._alpha, + "algo/alpha": self.alpha, "algo/log_pi": log_pi.cpu().mean(), }, self.update_step, diff --git a/src/oprl/algos/td3.py b/src/oprl/algos/td3.py index 04b287b..6d06bd6 100644 --- a/src/oprl/algos/td3.py +++ b/src/oprl/algos/td3.py @@ -37,6 +37,7 @@ class TD3(OffPolicyAlgorithm): critic_target: nn.Module = field(init=False) optim_critic: t.optim.Optimizer = field(init=False) update_step: int = 0 + _created: bool = False def create(self) -> "TD3": self.actor = DeterministicPolicy( @@ -62,6 +63,8 @@ def create(self) -> "TD3": disable_gradient(self.critic_target) self.optim_critic = Adam(self.critic.parameters(), lr=self.lr_critic) + + self._created = True return self diff --git a/src/oprl/logging.py b/src/oprl/logging.py index d8ab7d3..6c9d7d4 100644 --- a/src/oprl/logging.py +++ b/src/oprl/logging.py @@ -3,10 +3,9 @@ import sys import logging from datetime import datetime -import json import shutil from abc import ABC, abstractmethod -from typing import Any, Protocol, Callable +from typing import Protocol, Callable import torch as t import torch.nn as nn @@ -45,21 +44,16 @@ def copy_exp_dir(log_dir: Path) -> None: logging.info(f"Source copied into {dest_dir}") -def make_text_logger_func(config: dict, algo, env) -> Callable: +def make_text_logger_func(algo, env) -> Callable: def make_logger(seed: int) -> LoggerProtocol: logs_root = os.environ.get("OPRL_LOGS", "logs") log_dir = get_logs_path(logdir=logs_root, algo=algo, env=env, seed=seed) - logger = FileTxtLogger(log_dir, config) + logger = FileTxtLogger(log_dir) logger.copy_source_code() return logger return make_logger -def save_json_config(config: dict[str, Any], path: Path) -> None: - with open(path, "w") as f: - json.dump(config, f) - - class BaseLogger(ABC): @abstractmethod def log_scalar(self, tag: str, value: float, step: int) -> None: @@ -74,16 +68,25 @@ def log_scalars(self, values: dict[str, float], step: int) -> None: (self.log_scalar(k, v, step) for k, v in values.items()) +logger = create_stdout_logger() + + class FileTxtLogger(BaseLogger): - def __init__(self, logdir: Path | str, config: dict[str, Any]) -> None: + def __init__(self, logdir: Path | str) -> None: 
self.writer = SummaryWriter(logdir) self.log_dir = Path(logdir) - self.config = config def copy_source_code(self) -> None: copy_exp_dir(self.log_dir) - logging.info(f"Source code is copied to {self.log_dir}.") - save_json_config(self.config, self.log_dir / "config.json") + logger.info(f"Source code is copied to {self.log_dir}") + self._copy_config_file() + + def _copy_config_file(self) -> None: + main_module = sys.modules.get('__main__') + if main_module and hasattr(main_module, '__file__'): + shutil.copyfile(main_module.__file__, self.log_dir / Path(main_module.__file__).name) + else: + logger.warning("Failed to copy config file.") def log_scalar(self, tag: str, value: float, step: int) -> None: self.writer.add_scalar(tag, value, step) diff --git a/src/oprl/runners/config.py b/src/oprl/runners/config.py new file mode 100644 index 0000000..80fb6d1 --- /dev/null +++ b/src/oprl/runners/config.py @@ -0,0 +1,11 @@ +from pydantic_settings import BaseSettings + + +class CommonParameters(BaseSettings): + state_dim: int + action_dim: int + num_steps: int + eval_every: int = 2500 + estimate_q_every: int = 5000 + log_every: int = 2500 + device: str = "cpu" diff --git a/src/oprl/runners/train.py b/src/oprl/runners/train.py index 475ad7a..fc510ff 100644 --- a/src/oprl/runners/train.py +++ b/src/oprl/runners/train.py @@ -49,11 +49,11 @@ def _run_training_func(make_algo, make_env, make_replay_buffer, make_logger, con make_env_test=make_env, algo=algo, replay_buffer=replay_buffer, - num_steps=config["num_steps"], - eval_interval=config["eval_every"], - device=config["device"], - estimate_q_every=config["estimate_q_every"], - stdout_log_every=config["log_every"], + num_steps=config.num_steps, + eval_interval=config.eval_every, + device=config.device, + estimate_q_every=config.estimate_q_every, + stdout_log_every=config.log_every, seed=seed, logger=logger, ) From fa44399ec8243a036c98cb09ac45efff7168da20 Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Thu, 17 Jul 2025 15:33:38 +0300 Subject: [PATCH 15/30] Refine annotations --- configs/ddpg.py | 3 +- src/oprl/algos/ddpg.py | 3 +- src/oprl/algos/nn_functions.py | 5 +- src/oprl/algos/nn_models.py | 40 ++++--- src/oprl/algos/protocols.py | 4 + src/oprl/algos/td3.py | 10 +- src/oprl/algos/tqc.py | 2 +- src/oprl/buffers/episodic_buffer.py | 23 ++-- src/oprl/buffers/protocols.py | 3 + src/oprl/environment/dm_control.py | 18 +-- src/oprl/environment/make_env.py | 138 ++++++++++++----------- src/oprl/environment/protocols.py | 14 +-- src/oprl/environment/safety_gymnasium.py | 16 +-- src/oprl/logging.py | 21 +--- src/oprl/runners/train.py | 28 ++++- src/oprl/trainers/base_trainer.py | 81 ++++++------- src/oprl/trainers/safe_trainer.py | 29 +++-- tests/functional/test_rl_algos.py | 16 ++- 18 files changed, 257 insertions(+), 197 deletions(-) diff --git a/configs/ddpg.py b/configs/ddpg.py index 468074b..a183970 100644 --- a/configs/ddpg.py +++ b/configs/ddpg.py @@ -7,6 +7,7 @@ LoggerProtocol, make_text_logger_func, ) +from oprl.environment.protocols import EnvProtocol from oprl.environment import make_env as _make_env from oprl.runners.train import run_training from oprl.runners.config import CommonParameters @@ -14,7 +15,7 @@ args = parse_args() -def make_env(seed: int): +def make_env(seed: int) -> EnvProtocol: return _make_env(args.env, seed=seed) diff --git a/src/oprl/algos/ddpg.py b/src/oprl/algos/ddpg.py index 46abe88..782bd58 100644 --- a/src/oprl/algos/ddpg.py +++ b/src/oprl/algos/ddpg.py @@ -65,11 +65,10 @@ def update( reward: t.Tensor, done: t.Tensor, 
next_state: t.Tensor, - ): + ) -> None: self._update_critic(state, action, reward, done, next_state) self._update_actor(state) - # Update the frozen target models for param, target_param in zip( self.critic.parameters(), self.critic_target.parameters() ): diff --git a/src/oprl/algos/nn_functions.py b/src/oprl/algos/nn_functions.py index 72db6d6..e168e02 100644 --- a/src/oprl/algos/nn_functions.py +++ b/src/oprl/algos/nn_functions.py @@ -1,7 +1,8 @@ import torch as t +import torch.nn as nn -def soft_update(target, source, tau): +def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None: """Update target network using Polyak-Ruppert Averaging.""" with t.no_grad(): for tgt, src in zip(target.parameters(), source.parameters()): @@ -9,7 +10,7 @@ def soft_update(target, source, tau): tgt.data.add_(tau * src.data) -def disable_gradient(network): +def disable_gradient(network: nn.Module) -> None: """Disable gradient calculations of the network.""" for param in network.parameters(): param.requires_grad = False diff --git a/src/oprl/algos/nn_models.py b/src/oprl/algos/nn_models.py index 547e289..599cc15 100644 --- a/src/oprl/algos/nn_models.py +++ b/src/oprl/algos/nn_models.py @@ -1,3 +1,5 @@ +from typing import Final + import numpy as np import numpy.typing as npt import torch as t @@ -6,10 +8,10 @@ from torch.nn.functional import logsigmoid -LOG_STD_MIN_MAX = (-20, 2) +LOG_STD_MIN_MAX: Final[tuple[float, float]] = (-20, 2) -def initialize_weight_orthogonal(m, gain=nn.init.calculate_gain("relu")): +def initialize_weight_orthogonal(m: nn.Module, gain: float = nn.init.calculate_gain("relu")): if isinstance(m, nn.Linear): nn.init.orthogonal_(m.weight.data, gain) m.bias.data.fill_(0.0) @@ -29,9 +31,8 @@ def __init__( action_dim: int, hidden_units: tuple[int, ...] = (256, 256), hidden_activation: nn.Module = nn.ReLU(inplace=True), - ): + ) -> None: super().__init__() - self.q1 = MLP( input_dim=state_dim + action_dim, output_dim=1, @@ -39,7 +40,7 @@ def __init__( hidden_activation=hidden_activation, ) - def forward(self, states: t.Tensor, actions: t.Tensor): + def forward(self, states: t.Tensor, actions: t.Tensor) -> t.Tensor: x = t.cat([states, actions], dim=-1) return self.q1(x) @@ -57,7 +58,6 @@ def __init__( hidden_activation: nn.Module = nn.ReLU(inplace=True), ): super().__init__() - self.q1 = MLP( input_dim=state_dim + action_dim, output_dim=1, @@ -113,7 +113,7 @@ def __init__( state_dim: int, action_dim: int, hidden_units: tuple[int, ...] 
= (256, 256), - hidden_activation=nn.ReLU(inplace=True), + hidden_activation: nn.Module = nn.ReLU(inplace=True), max_action: float = 1.0, expl_noise: float = 0.1, device: str = "cpu", @@ -135,23 +135,30 @@ def __init__( def forward(self, states: t.Tensor) -> t.Tensor: return t.tanh(self.mlp(states)) - def exploit(self, state: npt.ArrayLike) -> npt.NDArray: - state = t.tensor(state).unsqueeze_(0).to(self._device) + def exploit(self, state: npt.NDArray) -> npt.NDArray: + state_tensor = t.tensor(state).unsqueeze_(0).to(self._device) with t.no_grad(): - action = self.forward(state) + action = self.forward(state_tensor) return action.cpu().numpy().flatten() - def explore(self, state: npt.ArrayLike) -> npt.NDArray: - state = t.tensor(state, device=self._device).unsqueeze_(0) + def explore(self, state: npt.NDArray) -> npt.NDArray: + state_tensor = t.tensor(state, device=self._device).unsqueeze_(0) noise = (t.randn(self._action_shape) * self._expl_noise).to(self._device) with t.no_grad(): - action = self.mlp(state) + noise + action = self.mlp(state_tensor) + noise action = action.cpu().numpy()[0] return np.clip(action, -self._max_action, self._max_action) class GaussianActor(nn.Module): - def __init__(self, state_dim, action_dim, hidden_units, hidden_activation, device: str): + def __init__( + self, + state_dim: int, + action_dim: int, + hidden_units: tuple[int, ...], + hidden_activation: nn.Module, + device: str, + ): super().__init__() self.action_dim = action_dim self.net = MLP( @@ -188,7 +195,7 @@ def exploit(self, state: npt.NDArray) -> npt.NDArray: class TanhNormal(Distribution): - def __init__(self, normal_mean: t.Tensor, normal_std: t.Tensor, device: str): + def __init__(self, normal_mean: t.Tensor, normal_std: t.Tensor, device: str) -> None: super().__init__() self.normal_mean = normal_mean self.normal_std = normal_std @@ -200,8 +207,7 @@ def __init__(self, normal_mean: t.Tensor, normal_std: t.Tensor, device: str): def log_prob(self, pre_tanh: t.Tensor) -> t.Tensor: log_det = 2 * np.log(2) + logsigmoid(2 * pre_tanh) + logsigmoid(-2 * pre_tanh) - result = self.normal.log_prob(pre_tanh) - log_det - return result + return self.normal.log_prob(pre_tanh) - log_det def rsample(self) -> tuple[t.Tensor, t.Tensor]: pretanh = self.normal_mean + self.normal_std * self.standard_normal.sample() diff --git a/src/oprl/algos/protocols.py b/src/oprl/algos/protocols.py index fdcf626..94f7322 100644 --- a/src/oprl/algos/protocols.py +++ b/src/oprl/algos/protocols.py @@ -3,6 +3,7 @@ import numpy.typing as npt import torch as t +import torch.nn as nn class PolicyProtocol(Protocol): @@ -10,9 +11,12 @@ def explore(self, state: npt.NDArray) -> npt.NDArray: ... def exploit(self, state: npt.NDArray) -> npt.NDArray: ... + def __call__(*args, **kwargs) -> t.Tensor: ... + class AlgorithmProtocol(Protocol): actor: PolicyProtocol + critic: nn.Module _created: bool def create(self) -> "AlgorithmProtocol": ... 
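
A note on the structural typing above: PolicyProtocol and AlgorithmProtocol are typing.Protocol classes, so any object with matching members conforms without inheriting from them. (As written, the added __call__ omits an explicit self, so the instance is absorbed by *args; strict type checkers may flag this.) A minimal sketch of a conforming policy — not part of the patch; RandomPolicy and select_action are hypothetical names for illustration:

import numpy as np
import numpy.typing as npt
import torch as t

from oprl.algos.protocols import PolicyProtocol


class RandomPolicy:
    # Conforms to PolicyProtocol structurally; no inheritance needed.
    def __init__(self, action_dim: int) -> None:
        self.action_dim = action_dim

    def explore(self, state: npt.NDArray) -> npt.NDArray:
        # Uniform random action, matching explore()'s signature.
        return np.random.uniform(-1.0, 1.0, self.action_dim)

    def exploit(self, state: npt.NDArray) -> npt.NDArray:
        # Deterministic action for greedy evaluation.
        return np.zeros(self.action_dim)

    def __call__(self, *args, **kwargs) -> t.Tensor:
        # Present only to satisfy the __call__ member added in this patch.
        return t.zeros(self.action_dim)


def select_action(policy: PolicyProtocol, state: npt.NDArray, greedy: bool) -> npt.NDArray:
    # Accepts DeterministicPolicy, GaussianActor, or RandomPolicy alike.
    return policy.exploit(state) if greedy else policy.explore(state)
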
diff --git a/src/oprl/algos/td3.py b/src/oprl/algos/td3.py index 6d06bd6..066b930 100644 --- a/src/oprl/algos/td3.py +++ b/src/oprl/algos/td3.py @@ -68,14 +68,20 @@ def create(self) -> "TD3": return self - def update(self, state: t.Tensor, action, reward, done, next_state) -> None: + def update( + self, + state: t.Tensor, + action: t.Tensor, + reward: t.Tensor, + done: t.Tensor, + next_state: t.Tensor, + ) -> None: self._update_critic(state, action, reward, done, next_state) if self.update_step % self.policy_freq == 0: self._update_actor(state) soft_update(self.critic_target, self.critic, self.tau) soft_update(self.actor_target, self.actor, self.tau) - self.update_step += 1 def _update_critic( diff --git a/src/oprl/algos/tqc.py b/src/oprl/algos/tqc.py index fc0d552..025c184 100644 --- a/src/oprl/algos/tqc.py +++ b/src/oprl/algos/tqc.py @@ -37,7 +37,7 @@ def quantile_huber_loss_f( class QuantileQritic(nn.Module): - def __init__(self, state_dim: int, action_dim: int, n_quantiles: int, n_nets: int): + def __init__(self, state_dim: int, action_dim: int, n_quantiles: int, n_nets: int) -> None: super().__init__() self.nets = [] self.n_quantiles = n_quantiles diff --git a/src/oprl/buffers/episodic_buffer.py b/src/oprl/buffers/episodic_buffer.py index 030b8b5..2c7bc69 100644 --- a/src/oprl/buffers/episodic_buffer.py +++ b/src/oprl/buffers/episodic_buffer.py @@ -7,6 +7,9 @@ from oprl.buffers.protocols import ReplayBufferProtocol +Transition = tuple[npt.NDArray, npt.NDArray, float, bool, npt.NDArray] + + @dataclass class EpisodicReplayBuffer(ReplayBufferProtocol): buffer_size_transitions: int @@ -75,14 +78,21 @@ def dones(self) -> t.Tensor: self.check_created() return self._tensors["dones"] - def add_transition(self, state: npt.ArrayLike, action: npt.ArrayLike, reward: float, done: bool, episode_done: bool | None = None): + def add_transition( + self, + state: npt.NDArray, + action: npt.NDArray, + reward: float, + done: bool, + episode_done: bool | None = None + ) -> None: self.states[self._ep_pointer, self.ep_lens[self._ep_pointer]].copy_( t.from_numpy(state) ) self.actions[self._ep_pointer, self.ep_lens[self._ep_pointer]].copy_( t.from_numpy(action) ) - self.rewards[self._ep_pointer, self.ep_lens[self._ep_pointer]] = float(reward) + self.rewards[self._ep_pointer, self.ep_lens[self._ep_pointer]] = reward self.dones[self._ep_pointer, self.ep_lens[self._ep_pointer]] = float(done) self.ep_lens[self._ep_pointer] += 1 self._number_transitions = min(self._number_transitions + 1, self.buffer_size_transitions) @@ -96,22 +106,21 @@ def _inc_episode(self): self._number_transitions -= self.ep_lens[self._ep_pointer] self.ep_lens[self._ep_pointer] = 0 - def add_episode(self, episode: list): + def add_episode(self, episode: list[Transition]) -> None: for s, a, r, d, _ in episode: self.add_transition(s, a, r, d, episode_done=d) self._inc_episode() - def _inds_to_episodic(self, inds): + def _inds_to_episodic(self, inds: npt.NDArray) -> tuple[npt.NDArray, npt.NDArray]: start_inds = np.cumsum([0] + self.ep_lens[: self.episodes_counter - 1]) end_inds = start_inds + np.array(self.ep_lens[: self.episodes_counter]) ep_inds = np.argmin( inds.reshape(-1, 1) >= np.tile(end_inds, (len(inds), 1)), axis=1 ) step_inds = inds - start_inds[ep_inds] - return ep_inds, step_inds - def sample(self, batch_size): + def sample(self, batch_size: int) -> tuple[t.Tensor, t.Tensor, t.Tensor, t.Tensor, t.Tensor]: inds = np.random.randint(low=0, high=self._number_transitions, size=batch_size) ep_inds, step_inds = self._inds_to_episodic(inds) 
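
The index arithmetic in _inds_to_episodic is the subtle part of this buffer: flat transition indices are mapped back to (episode, step) pairs via cumulative episode lengths. A standalone numpy walk-through of the same arithmetic, simplified to plain lists outside the class:

import numpy as np

# Suppose three stored episodes of lengths 3, 5, 2 (10 transitions total).
ep_lens = [3, 5, 2]
start_inds = np.cumsum([0] + ep_lens[:-1])   # [0, 3, 8]
end_inds = start_inds + np.array(ep_lens)    # [3, 8, 10]

inds = np.array([0, 4, 9])                   # flat transition indices
# argmin over the boolean matrix picks the first column where
# ind < end_inds, i.e. the first episode ending past this index;
# sampling below _number_transitions guarantees such a column exists.
ep_inds = np.argmin(inds.reshape(-1, 1) >= np.tile(end_inds, (len(inds), 1)), axis=1)
step_inds = inds - start_inds[ep_inds]

print(ep_inds)    # [0, 1, 2]
print(step_inds)  # [0, 1, 1]
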
@@ -124,7 +133,7 @@ def sample(self, batch_size): ) @property - def last_episode_length(self): + def last_episode_length(self) -> int: return self.ep_lens[self._ep_pointer] def __len__(self) -> int: diff --git a/src/oprl/buffers/protocols.py b/src/oprl/buffers/protocols.py index e914959..4c744c5 100644 --- a/src/oprl/buffers/protocols.py +++ b/src/oprl/buffers/protocols.py @@ -4,6 +4,7 @@ class ReplayBufferProtocol(Protocol): + episodes_counter: int _created: bool def create(self) -> "ReplayBufferProtocol": ... @@ -18,6 +19,8 @@ def sample(self, batch_size) -> tuple[ t.Tensor, t.Tensor, t.Tensor, t.Tensor, t.Tensor ]: ... + def __len__(self) -> int: ... + @property def last_episode_length(self) -> int: ... diff --git a/src/oprl/environment/dm_control.py b/src/oprl/environment/dm_control.py index 02db9df..124b6b8 100644 --- a/src/oprl/environment/dm_control.py +++ b/src/oprl/environment/dm_control.py @@ -9,7 +9,7 @@ class DMControlEnv(EnvProtocol): - def __init__(self, env: str, seed: int): + def __init__(self, env: str, seed: int) -> None: domain, task = env.split("-") self.random_state = np.random.RandomState(seed) self.env = suite.load(domain, task, task_kwargs={"random": self.random_state}) @@ -18,13 +18,13 @@ def __init__(self, env: str, seed: int): self._render_height = 200 self._camera_id = 0 - def reset(self, *args, **kwargs) -> tuple[npt.ArrayLike, dict[str, Any]]: + def reset(self) -> tuple[npt.NDArray, dict[str, Any]]: obs = self._flat_obs(self.env.reset().observation) return obs, {} def step( - self, action: npt.ArrayLike - ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]: + self, action: npt.NDArray + ) -> tuple[npt.NDArray, float, bool, bool, dict[str, Any]]: time_step = self.env.step(action) obs = self._flat_obs(time_step.observation) @@ -33,22 +33,22 @@ def step( return obs, time_step.reward, terminated, truncated, {} - def sample_action(self) -> npt.ArrayLike: + def sample_action(self) -> npt.NDArray: spec = self.env.action_spec() action = self.random_state.uniform(spec.minimum, spec.maximum, spec.shape) return action @property - def observation_space(self) -> npt.ArrayLike: + def observation_space(self) -> npt.NDArray: return np.zeros( sum(int(np.prod(v.shape)) for v in self.env.observation_spec().values()) ) @property - def action_space(self) -> npt.ArrayLike: + def action_space(self) -> npt.NDArray: return np.zeros(self.env.action_spec().shape[0]) - def render(self) -> npt.ArrayLike: # [1, W, H, C] + def render(self) -> npt.NDArray: # [1, W, H, C] img = self.env.physics.render( camera_id=self._camera_id, height=self._render_width, @@ -57,7 +57,7 @@ def render(self) -> npt.ArrayLike: # [1, W, H, C] img = img.astype(np.uint8) return img - def _flat_obs(self, obs: OrderedDict) -> npt.ArrayLike: + def _flat_obs(self, obs: OrderedDict) -> npt.NDArray: obs_flatten = [] for _, o in obs.items(): if len(o.shape) == 0: diff --git a/src/oprl/environment/make_env.py b/src/oprl/environment/make_env.py index f4f9f12..3b3a158 100644 --- a/src/oprl/environment/make_env.py +++ b/src/oprl/environment/make_env.py @@ -1,72 +1,75 @@ +from collections import defaultdict + from oprl.environment import EnvProtocol, DMControlEnv, SafetyGym -ENV_MAPPER = { - "dm_control": set( - [ - "acrobot-swingup", - "ball_in_cup-catch", - "cartpole-balance", - "cartpole-swingup", - "cheetah-run", - "finger-spin", - "finger-turn_easy", - "finger-turn_hard", - "fish-upright", - "fish-swim", - "hopper-stand", - "hopper-hop", - "humanoid-stand", - "humanoid-walk", - "humanoid-run", - 
"pendulum-swingup", - "point_mass-easy", - "reacher-easy", - "reacher-hard", - "swimmer-swimmer6", - "swimmer-swimmer15", - "walker-stand", - "walker-walk", - "walker-run", - ] - ), - "safety_gymnasium": set( - [ - "SafetyPointGoal1-v0", - "SafetyPointGoal2-v0", - "SafetyPointButton1-v0", - "SafetyPointButton2-v0", - "SafetyPointPush1-v0", - "SafetyPointPush2-v0", - "SafetyPointCircle1-v0", - "SafetyPointCircle2-v0", - "SafetyCarGoal1-v0", - "SafetyCarGoal2-v0", - "SafetyCarButton1-v0", - "SafetyCarButton2-v0", - "SafetyCarPush1-v0", - "SafetyCarPush2-v0", - "SafetyCarCircle1-v0", - "SafetyCarCircle2-v0", - "SafetyAntGoal1-v0", - "SafetyAntGoal2-v0", - "SafetyAntButton1-v0", - "SafetyAntButton2-v0", - "SafetyAntPush1-v0", - "SafetyAntPush2-v0", - "SafetyAntCircle1-v0", - "SafetyAntCircle2-v0", - "SafetyDoggoGoal1-v0", - "SafetyDoggoGoal2-v0", - "SafetyDoggoButton1-v0", - "SafetyDoggoButton2-v0", - "SafetyDoggoPush1-v0", - "SafetyDoggoPush2-v0", - "SafetyDoggoCircle1-v0", - "SafetyDoggoCircle2-v0", - ] - ), -} + +ENV_MAPPER: defaultdict[str, set[str]] = defaultdict(set) +ENV_MAPPER["dm_control"] = set( + [ + "acrobot-swingup", + "ball_in_cup-catch", + "cartpole-balance", + "cartpole-swingup", + "cheetah-run", + "finger-spin", + "finger-turn_easy", + "finger-turn_hard", + "fish-upright", + "fish-swim", + "hopper-stand", + "hopper-hop", + "humanoid-stand", + "humanoid-walk", + "humanoid-run", + "pendulum-swingup", + "point_mass-easy", + "reacher-easy", + "reacher-hard", + "swimmer-swimmer6", + "swimmer-swimmer15", + "walker-stand", + "walker-walk", + "walker-run", + ] +) + +ENV_MAPPER["safety_gymnasium"] = set( + [ + "SafetyPointGoal1-v0", + "SafetyPointGoal2-v0", + "SafetyPointButton1-v0", + "SafetyPointButton2-v0", + "SafetyPointPush1-v0", + "SafetyPointPush2-v0", + "SafetyPointCircle1-v0", + "SafetyPointCircle2-v0", + "SafetyCarGoal1-v0", + "SafetyCarGoal2-v0", + "SafetyCarButton1-v0", + "SafetyCarButton2-v0", + "SafetyCarPush1-v0", + "SafetyCarPush2-v0", + "SafetyCarCircle1-v0", + "SafetyCarCircle2-v0", + "SafetyAntGoal1-v0", + "SafetyAntGoal2-v0", + "SafetyAntButton1-v0", + "SafetyAntButton2-v0", + "SafetyAntPush1-v0", + "SafetyAntPush2-v0", + "SafetyAntCircle1-v0", + "SafetyAntCircle2-v0", + "SafetyDoggoGoal1-v0", + "SafetyDoggoGoal2-v0", + "SafetyDoggoButton1-v0", + "SafetyDoggoButton2-v0", + "SafetyDoggoPush1-v0", + "SafetyDoggoPush2-v0", + "SafetyDoggoCircle1-v0", + "SafetyDoggoCircle2-v0", + ] +) def make_env(name: str, seed: int) -> EnvProtocol: @@ -76,5 +79,4 @@ def make_env(name: str, seed: int) -> EnvProtocol: return DMControlEnv(name, seed=seed) elif env_type == "safety_gymnasium": return SafetyGym(name, seed=seed) - else: - raise ValueError(f"Unsupported environment: {name}") + raise ValueError(f"Unsupported environment: {name}") diff --git a/src/oprl/environment/protocols.py b/src/oprl/environment/protocols.py index 3f1a585..d086c87 100644 --- a/src/oprl/environment/protocols.py +++ b/src/oprl/environment/protocols.py @@ -8,25 +8,25 @@ def __init__(self, env_name: str, seed: int) -> None: ... def step( - self, action: npt.ArrayLike - ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]: + self, action: npt.NDArray + ) -> tuple[npt.NDArray, float, bool, bool, dict[str, Any]]: ... - def reset(self) -> tuple[npt.ArrayLike, dict[str, Any]]: + def reset(self) -> tuple[npt.NDArray, dict[str, Any]]: ... - def sample_action(self) -> npt.ArrayLike: + def sample_action(self) -> npt.NDArray: ... - def render(self) -> npt.ArrayLike: + def render(self) -> npt.NDArray: ... 
@property - def observation_space(self) -> npt.ArrayLike: + def observation_space(self) -> npt.NDArray: ... @property - def action_space(self) -> npt.ArrayLike: + def action_space(self) -> npt.NDArray: ... @property diff --git a/src/oprl/environment/safety_gymnasium.py b/src/oprl/environment/safety_gymnasium.py index 10c1614..910ecab 100644 --- a/src/oprl/environment/safety_gymnasium.py +++ b/src/oprl/environment/safety_gymnasium.py @@ -11,29 +11,29 @@ def __init__(self, env_name: str, seed: int) -> None: self._seed = seed def step( - self, action: npt.ArrayLike - ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]: + self, action: npt.NDArray + ) -> tuple[npt.NDArray, float, bool, bool, dict[str, Any]]: obs, reward, cost, terminated, truncated, info = self._env.step(action) info["cost"] = cost - return obs.astype("float32"), reward, terminated, truncated, info + return obs.astype("float32"), float(reward), terminated, bool(truncated), info - def reset(self) -> tuple[npt.ArrayLike, dict[str, Any]]: + def reset(self) -> tuple[npt.NDArray, dict[str, Any]]: obs, info = self._env.reset(seed=self._seed) self._env.step(self._env.action_space.sample()) return obs.astype("float32"), info - def sample_action(self): + def sample_action(self) -> npt.NDArray: return self._env.action_space.sample() - def render(self) -> npt.ArrayLike: + def render(self) -> npt.NDArray: return self._env.render() @property - def observation_space(self) -> npt.ArrayLike: + def observation_space(self) -> npt.NDArray: return self._env.observation_space @property - def action_space(self) -> npt.ArrayLike: + def action_space(self) -> npt.NDArray: return self._env.action_space @property diff --git a/src/oprl/logging.py b/src/oprl/logging.py index 6c9d7d4..e3cf137 100644 --- a/src/oprl/logging.py +++ b/src/oprl/logging.py @@ -7,17 +7,18 @@ from abc import ABC, abstractmethod from typing import Protocol, Callable -import torch as t -import torch.nn as nn from torch.utils.tensorboard.writer import SummaryWriter class LoggerProtocol(Protocol): + log_dir: Path + def log_scalar(self, tag: str, value: float, step: int) -> None: ... def log_scalars(self, values: dict[str, float], step: int) -> None: ... 
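
While logging.py is being reworked here, one pre-existing issue is worth flagging: BaseLogger.log_scalars (visible as context further down this diff) wraps its log_scalar calls in a generator expression that is never consumed, so the per-key calls never actually run. A plain loop would fix it — a suggested follow-up, not part of this patch:

def log_scalars(self, values: dict[str, float], step: int) -> None:
    # Iterate eagerly: the generator-expression version builds a lazy
    # generator and discards it, so log_scalar was never invoked.
    for k, v in values.items():
        self.log_scalar(k, v, step)
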
+ def get_logs_path(logdir: str, algo: str, env: str, seed: int) -> Path: dt = datetime.now().strftime("%Y_%m_%d_%Hh%Mm%Ss") log_dir = Path(logdir) / algo / f"{algo}-env_{env}-seed_{seed}-{dt}" @@ -25,7 +26,7 @@ def get_logs_path(logdir: str, algo: str, env: str, seed: int) -> Path: return log_dir -def create_stdout_logger(name=None): +def create_stdout_logger(name: str | None = None): if name is None: import inspect frame = inspect.currentframe().f_back @@ -44,7 +45,7 @@ def copy_exp_dir(log_dir: Path) -> None: logging.info(f"Source copied into {dest_dir}") -def make_text_logger_func(algo, env) -> Callable: +def make_text_logger_func(algo: str, env: str) -> Callable[[int], LoggerProtocol]: def make_logger(seed: int) -> LoggerProtocol: logs_root = os.environ.get("OPRL_LOGS", "logs") log_dir = get_logs_path(logdir=logs_root, algo=algo, env=env, seed=seed) @@ -68,7 +69,7 @@ def log_scalars(self, values: dict[str, float], step: int) -> None: (self.log_scalar(k, v, step) for k, v in values.items()) -logger = create_stdout_logger() +logger = create_stdout_logger(__name__) class FileTxtLogger(BaseLogger): @@ -92,19 +93,9 @@ def log_scalar(self, tag: str, value: float, step: int) -> None: self.writer.add_scalar(tag, value, step) self._log_scalar_to_file(tag, value, step) - def save_weights(self, weights: nn.Module, step: int) -> None: - weights_path = self.log_dir / "weights" / f"{step}.w" - weights_path.parents[0].mkdir(exist_ok=True) - t.save( - weights, - weights_path - ) - def _log_scalar_to_file(self, tag: str, val: float, step: int) -> None: log_path = self.log_dir / f"{tag}.log" log_path.parents[0].mkdir(exist_ok=True) with open(log_path, "a") as f: f.write(f"{step} {val}\n") - - diff --git a/src/oprl/runners/train.py b/src/oprl/runners/train.py index fc510ff..2c0180d 100644 --- a/src/oprl/runners/train.py +++ b/src/oprl/runners/train.py @@ -1,11 +1,18 @@ +from typing import Callable import logging import random +from multiprocessing import Process + import numpy as np import torch as t -from multiprocessing import Process +from oprl.algos.protocols import AlgorithmProtocol +from oprl.buffers.protocols import ReplayBufferProtocol +from oprl.environment.protocols import EnvProtocol from oprl.trainers.base_trainer import BaseTrainer from oprl.trainers.safe_trainer import SafeTrainer +from oprl.logging import LoggerProtocol +from oprl.runners.config import CommonParameters def set_seed(seed: int) -> None: @@ -15,8 +22,14 @@ def set_seed(seed: int) -> None: def run_training( - make_algo, make_env, make_replay_buffer, make_logger, config, seeds: int = 1, start_seed: int = 0 -): + make_algo: Callable[[LoggerProtocol], AlgorithmProtocol], + make_env: Callable[[int], EnvProtocol], + make_replay_buffer: Callable[[], ReplayBufferProtocol], + make_logger: Callable[[int], LoggerProtocol], + config: CommonParameters, + seeds: int = 1, + start_seed: int = 0 +) -> None: if seeds == 1: _run_training_func(make_algo, make_env, make_replay_buffer, make_logger, config, 0) else: @@ -37,7 +50,14 @@ def run_training( logging.info("Training finished.") -def _run_training_func(make_algo, make_env, make_replay_buffer, make_logger, config, seed: int): +def _run_training_func( + make_algo: Callable[[LoggerProtocol], AlgorithmProtocol], + make_env: Callable[[int], EnvProtocol], + make_replay_buffer: Callable[[], ReplayBufferProtocol], + make_logger: Callable[[int], LoggerProtocol], + config: CommonParameters, + seed: int, +) -> None: set_seed(seed) env = make_env(seed=seed) replay_buffer = make_replay_buffer() diff 
--git a/src/oprl/trainers/base_trainer.py b/src/oprl/trainers/base_trainer.py index 0fbf55b..b8adb3e 100644 --- a/src/oprl/trainers/base_trainer.py +++ b/src/oprl/trainers/base_trainer.py @@ -60,18 +60,24 @@ def train(self) -> None: if len(self.replay_buffer) < self.batch_size: continue - batch = self.replay_buffer.sample(self.batch_size) - self.algo.update(*batch) - - self._log_evaluation(env_step, batch) + ( + states, + actions, + rewards, + dones, + next_states + ) = self.replay_buffer.sample(self.batch_size) + self.algo.update(states, actions, rewards, dones, next_states) + + self._log_evaluation(env_step, rewards) self._save_policy(env_step) - self._log_stdout(env_step, batch) + self._log_stdout(env_step, rewards) - def _log_evaluation(self, env_step: int, batch): + def _log_evaluation(self, env_step: int, rewards: t.Tensor) -> None: if env_step % self.eval_interval == 0: eval_metrics = self.evaluate() self.logger.log_scalar("trainer/ep_reward", eval_metrics["return"], env_step) - self.logger.log_scalar("trainer/avg_reward", batch[2].mean(), env_step) + self.logger.log_scalar("trainer/avg_reward", rewards.mean().item(), env_step) self.logger.log_scalar( "trainer/buffer_transitions", len(self.replay_buffer), env_step ) @@ -94,7 +100,7 @@ def evaluate(self) -> dict[str, float]: terminated, truncated = False, False while not (terminated or truncated): - action = self.algo.exploit(state) + action = self.algo.actor.exploit(state) state, reward, terminated, truncated, _ = env_test.step(action) episode_return += reward @@ -104,11 +110,16 @@ def evaluate(self) -> dict[str, float]: "return": float(np.mean(returns)) } - def _save_policy(self, env_step: int): + def _save_policy(self, env_step: int) -> None: if self.save_policy_every > 0 and env_step % self.save_policy_every == 0: - self.logger.save_weights(self.algo.actor, env_step) + weights_path = self.logger.log_dir / "weights" / f"{env_step}.w" + weights_path.parents[0].mkdir(exist_ok=True) + t.save( + self.algo.actor, + weights_path + ) - def _estimate_q(self, env_step: int): + def _estimate_q(self, env_step: int) -> None: if self.estimate_q_every > 0 and env_step % self.estimate_q_every == 0: q_true = self.estimate_true_q() q_critic = self.estimate_critic_q() @@ -119,44 +130,38 @@ def _estimate_q(self, env_step: int): "trainer/Q_asb_diff", q_critic - q_true, env_step ) - def _log_stdout(self, env_step: int, batch): + def _log_stdout(self, env_step: int, rewards: t.Tensor) -> None: if env_step % self.stdout_log_every == 0: perc = int(env_step / self.num_steps * 100) logger.info( - f"Env step {env_step:8d} ({perc:2d}%) Avg Reward {batch[2].mean():10.3f}" + f"Env step {env_step:8d} ({perc:2d}%) Avg Reward {rewards.mean():10.3f}" ) - def estimate_true_q(self, eval_episodes: int = 10) -> float | None: - try: - qs = [] - for i_eval in range(eval_episodes): - env = self.make_env_test(seed=self.seed * 100 + i_eval) - state, _ = env.reset() - - q = 0 - s_i = 1 - while True: - action = self.algo.exploit(state) - state, r, terminated, truncated, _ = env.step(action) - q += r * self.gamma ** s_i - s_i += 1 - if terminated or truncated: - break - - qs.append(q) - - return np.mean(qs, dtype=float) - except Exception as e: - logger.warning(f"Failed to estimate Q-value: {e}") - return None + def estimate_true_q(self, eval_episodes: int = 10) -> float: + qs = [] + for i_eval in range(eval_episodes): + env = self.make_env_test(seed=self.seed * 100 + i_eval) + state, _ = env.reset() + + q = 0 + s_i = 1 + while True: + action = self.algo.actor.exploit(state) 
+ state, r, terminated, truncated, _ = env.step(action) + q += r * self.gamma ** s_i + s_i += 1 + if terminated or truncated: + break + qs.append(q) + + return np.mean(qs, dtype=float) def estimate_critic_q(self, num_episodes: int = 10) -> float: qs = [] for i_eval in range(num_episodes): env = self.make_env_test(seed=self.seed * 100 + i_eval) - state, _ = env.reset() - action = self.algo.exploit(state) + action = self.algo.actor.exploit(state) state = t.tensor(state).unsqueeze(0).float().to(self.device) action = t.tensor(action).unsqueeze(0).float().to(self.device) diff --git a/src/oprl/trainers/safe_trainer.py b/src/oprl/trainers/safe_trainer.py index 8aa038f..dc4d4f2 100644 --- a/src/oprl/trainers/safe_trainer.py +++ b/src/oprl/trainers/safe_trainer.py @@ -1,5 +1,6 @@ from dataclasses import dataclass +import torch as t import numpy as np from oprl.trainers.base_trainer import BaseTrainer, TrainerProtocol @@ -10,8 +11,8 @@ class SafeTrainer(TrainerProtocol): trainer: BaseTrainer def train(self): - self.algo.check_created() - self.replay_buffer.check_created() + self.trainer.algo.check_created() + self.trainer.replay_buffer.check_created() ep_step = 0 state, _ = self.trainer.env.reset() @@ -21,7 +22,7 @@ def train(self): if env_step <= self.trainer.start_steps: action = self.trainer.env.sample_action() else: - action = self.trainer.algo.explore(state) + action = self.trainer.algo.actor.explore(state) next_state, reward, terminated, truncated, info = self.trainer.env.step(action) total_cost += info["cost"] @@ -35,16 +36,22 @@ def train(self): if len(self.trainer.replay_buffer) < self.trainer.batch_size: continue - batch = self.trainer.replay_buffer.sample(self.trainer.batch_size) - self.trainer.algo.update(*batch) - - self._log_evaluation(env_step, batch) + ( + states, + actions, + rewards, + dones, + next_states + ) = self.trainer.replay_buffer.sample(self.trainer.batch_size) + self.trainer.algo.update(states, actions, rewards, dones, next_states) + + self._log_evaluation(env_step, rewards) self.trainer._save_policy(env_step) - self.trainer._log_stdout(env_step, batch) + self.trainer._log_stdout(env_step, rewards) self.trainer.logger.log_scalar("trainer/total_cost", total_cost, self.trainer.num_steps) - def _log_evaluation(self, env_step: int, batch) -> None: + def _log_evaluation(self, env_step: int, rewards: t.Tensor) -> None: if env_step % self.trainer.eval_interval == 0: eval_metrics = self.evaluate() self.trainer.logger.log_scalar( @@ -54,7 +61,7 @@ def _log_evaluation(self, env_step: int, batch) -> None: "trainer/ep_cost", eval_metrics["cost"], env_step ) - self.trainer.logger.log_scalar("trainer/avg_reward", batch[2].mean(), env_step) + self.trainer.logger.log_scalar("trainer/avg_reward", rewards.mean().item(), env_step) self.trainer.logger.log_scalar( "trainer/buffer_transitions", len(self.trainer.replay_buffer), env_step ) @@ -79,7 +86,7 @@ def evaluate(self) -> dict[str, float]: terminated, truncated = False, False while not (terminated or truncated): - action = self.trainer.algo.exploit(state) + action = self.trainer.algo.actor.exploit(state) state, reward, terminated, truncated, info = env_test.step(action) episode_return += reward episode_cost += info["cost"] diff --git a/tests/functional/test_rl_algos.py b/tests/functional/test_rl_algos.py index 234d909..a202249 100644 --- a/tests/functional/test_rl_algos.py +++ b/tests/functional/test_rl_algos.py @@ -1,28 +1,34 @@ import pytest import torch +from oprl.algos.protocols import AlgorithmProtocol from oprl.algos.ddpg import 
DDPG from oprl.algos.sac import SAC from oprl.algos.td3 import TD3 from oprl.algos.tqc import TQC from oprl.environment import DMControlEnv +from oprl.logging import FileTxtLogger -rl_algo_classes = [DDPG, SAC, TD3, TQC] + +rl_algo_classes: list[type[AlgorithmProtocol]] = [DDPG, SAC, TD3, TQC] @pytest.mark.parametrize("algo_class", rl_algo_classes) -def test_rl_algo_run(algo_class): +def test_rl_algo_run(algo_class: type[AlgorithmProtocol]) -> None: env = DMControlEnv("walker-walk", seed=0) - obs, _ = env.reset(env.sample_action()) + # TODO: Change to mocked logger + logger = FileTxtLogger(".") + obs, _ = env.reset() algo = algo_class( + logger=logger, state_dim=env.observation_space.shape[0], action_dim=env.action_space.shape[0], ).create() - action = algo.exploit(obs) + action = algo.actor.exploit(obs) assert action.ndim == 1 - action = algo.explore(obs) + action = algo.actor.explore(obs) assert action.ndim == 1 _batch_size = 8 From 80f63252f4a007c4d86d7100de22f1bdc474a85a Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Thu, 17 Jul 2025 16:51:07 +0300 Subject: [PATCH 16/30] Make distrib code work --- configs/d3pg.py | 113 ----------------------------- configs/distrib_ddpg.py | 96 ++++++++++++++++++++++++ pyproject.toml | 1 + src/oprl/distrib/distrib_runner.py | 67 +++++++++++------ src/oprl/parse_args.py | 13 ++++ src/oprl/runners/train_distrib.py | 9 ++- 6 files changed, 162 insertions(+), 137 deletions(-) delete mode 100644 configs/d3pg.py create mode 100644 configs/distrib_ddpg.py diff --git a/configs/d3pg.py b/configs/d3pg.py deleted file mode 100644 index 93b60cc..0000000 --- a/configs/d3pg.py +++ /dev/null @@ -1,113 +0,0 @@ -import argparse -import logging -from multiprocessing import Process - -import torch.nn as nn - -from oprl.algos.ddpg import DDPG, DeterministicPolicy -from oprl.configs.utils import create_logdir -from oprl.distrib.distrib_runner import env_worker, policy_update_worker -from oprl.utils.utils import set_logging - -set_logging(logging.INFO) -from oprl.env import make_env as _make_env -from oprl.trainers.buffers.episodic_buffer import EpisodicReplayBuffer -from oprl.utils.logger import FileLogger, Logger - - -def parse_args(): - parser = argparse.ArgumentParser(description="Run training") - parser.add_argument("--config", type=str, help="Path to the config file.") - parser.add_argument( - "--env", type=str, default="cartpole-balance", help="Name of the environment." - ) - parser.add_argument( - "--device", type=str, default="cpu", help="Device to perform training on." 
- ) - parser.add_argument("--seed", type=int, default=0, help="Random seed") - return parser.parse_args() - - -# -------- Distrib params ----------- - -ENV_WORKERS = 4 -N_EPISODES = 50 # 500 # Number of episodes each env worker would perform - -# ----------------------------------- - -args = parse_args() - - -def make_env(seed: int): - return _make_env(args.env, seed=seed) - - -env = make_env(seed=0) -STATE_DIM = env.observation_space.shape[0] -ACTION_DIM = env.action_space.shape[0] -logging.info(f"Env state {STATE_DIM}\tEnv action {ACTION_DIM}") - - -log_dir = create_logdir(logdir="logs", algo="D3PG", env=args.env, seed=args.seed) -logging.info(f"LOG_DIR: {log_dir}") - - -def make_logger(seed: int) -> Logger: - log_dir = create_logdir(logdir="logs", algo="D3PG", env=args.env, seed=seed) - # TODO: add here actual config - return FileLogger(log_dir, {}) - - -def make_policy(): - return DeterministicPolicy( - state_dim=STATE_DIM, - action_dim=ACTION_DIM, - hidden_units=(256, 256), - hidden_activation=nn.ReLU(inplace=True), - device=args.device, - ) - - -def make_buffer(): - return EpisodicReplayBuffer( - buffer_size=int(1_000_000), - state_dim=STATE_DIM, - action_dim=ACTION_DIM, - device=args.device, - gamma=0.99, - ) - - -def make_algo(): - logger = make_logger(args.seed) - - algo = DDPG( - state_dim=STATE_DIM, - action_dim=ACTION_DIM, - device=args.device, - logger=logger, - ) - return algo - - -if __name__ == "__main__": - processes = [] - - for i_env in range(ENV_WORKERS): - processes.append( - Process(target=env_worker, args=(make_env, make_policy, N_EPISODES, i_env)) - ) - processes.append( - Process( - target=policy_update_worker, - args=(make_algo, make_env, make_buffer, ENV_WORKERS), - ) - ) - - for p in processes: - p.start() - - for p in processes: - p.join() - - logging.info("Training OK.") diff --git a/configs/distrib_ddpg.py b/configs/distrib_ddpg.py new file mode 100644 index 0000000..ef8c90e --- /dev/null +++ b/configs/distrib_ddpg.py @@ -0,0 +1,96 @@ +import os +import argparse +import logging +from multiprocessing import Process + +import torch.nn as nn + +from oprl.algos.ddpg import DDPG +from oprl.algos.nn_models import DeterministicPolicy +from oprl.distrib.distrib_runner import env_worker, policy_update_worker + +from oprl.environment import make_env as _make_env +from oprl.buffers.episodic_buffer import EpisodicReplayBuffer +from oprl.logging import ( + LoggerProtocol, + FileTxtLogger, + get_logs_path, +) +from oprl.parse_args import parse_args_distrib + + + +# -------- Distrib params ----------- + +ENV_WORKERS = 4 +EPISODES_PER_WORKER = 100 # Number of episodes each env worker would perform + +# ----------------------------------- + +args = parse_args_distrib() + +def make_env(seed: int): + return _make_env(args.env, seed=seed) + + +env = make_env(seed=0) +STATE_DIM = env.observation_space.shape[0] +ACTION_DIM = env.action_space.shape[0] +logging.info(f"Env state {STATE_DIM}\tEnv action {ACTION_DIM}") + + +def make_logger(): + logs_root = os.environ.get("OPRL_LOGS", "logs") + log_dir = get_logs_path(logdir=logs_root, algo="DistribDDPG", env=args.env, seed=0) + logger = FileTxtLogger(log_dir) + logger.copy_source_code() + return logger + + +def make_policy(): + return DeterministicPolicy( + state_dim=STATE_DIM, + action_dim=ACTION_DIM, + hidden_units=(256, 256), + hidden_activation=nn.ReLU(inplace=True), + device=args.device, + ) + + +def make_buffer(): + return EpisodicReplayBuffer( + buffer_size_transitions=int(1_000_000), + state_dim=STATE_DIM, + 
action_dim=ACTION_DIM, + device=args.device, + ).create() + + +def make_algo(logger: LoggerProtocol): + return DDPG( + logger=logger, + state_dim=STATE_DIM, + action_dim=ACTION_DIM, + device=args.device, + ).create() + + +if __name__ == "__main__": + processes = [] + + for i_env in range(ENV_WORKERS): + processes.append( + Process(target=env_worker, args=(make_env, make_policy, EPISODES_PER_WORKER, i_env)) + ) + processes.append( + Process( + target=policy_update_worker, + args=(make_algo, make_env, make_buffer, make_logger, ENV_WORKERS), + ) + ) + + for p in processes: + p.start() + for p in processes: + p.join() + logging.info("Training OK.") diff --git a/pyproject.toml b/pyproject.toml index 0eee46a..6b4b570 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "dm-control==1.0.11", "mujoco==2.3.3", "numpy==1.26.4", + "pika==1.3.2", "pydantic_settings==2.10.1", ] diff --git a/src/oprl/distrib/distrib_runner.py b/src/oprl/distrib/distrib_runner.py index 58698be..6ea4a37 100644 --- a/src/oprl/distrib/distrib_runner.py +++ b/src/oprl/distrib/distrib_runner.py @@ -1,12 +1,17 @@ -import logging import pickle import time from itertools import count -from multiprocessing import Process +from pathlib import Path import numpy as np +import torch as t +import torch.nn as nn import pika -import torch + +from oprl.logging import create_stdout_logger + + +logger = create_stdout_logger() class Queue: @@ -30,23 +35,22 @@ def pop(self) -> bytes | None: def env_worker(make_env, make_policy, n_episodes, id_worker): env = make_env(seed=0) - logging.info("Env created.") + logger.info("Env created.") policy = make_policy() - logging.info("Policy created.") + logger.info("Policy created.") q_env = Queue(f"env_{id_worker}") q_policy = Queue(f"policy_{id_worker}") - logging.info("Queue created.") - - episodes = [] + logger.info("Queue created.") total_env_step = 0 # TODO: Move parameter to config start_steps = 1000 for i_ep in range(n_episodes): + print("Running episode: ", i_ep) if i_ep % 10 == 0: - logging.info(f"AGENT {id_worker} EPISODE {i_ep}") + logger.info(f"AGENT {id_worker} EPISODE {i_ep}") episode = [] state, _ = env.reset() @@ -70,36 +74,45 @@ def env_worker(make_env, make_policy, n_episodes, id_worker): while True: data = q_policy.pop() if data is None: - logging.info("Waiting for the policy..") + logger.info("Waiting for the policy..") time.sleep(2.0) continue policy.load_state_dict(pickle.loads(data)) break - logging.info("Episode by env worker is done.") + logger.info("Episode by env worker is done.") -def policy_update_worker(make_algo, make_env_test, make_buffer, n_workers): - algo = make_algo() - logging.info("Algo created.") +def save_policy(policy: nn.Module, save_path: Path): + save_path.parents[0].mkdir(exist_ok=True) + t.save( + policy, + save_path + ) + + +def policy_update_worker(make_algo, make_env_test, make_buffer, make_logger, n_workers): + scalar_logger = make_logger() + algo = make_algo(scalar_logger) + logger.info("Algo created.") buffer = make_buffer() - logging.info("Buffer created.") + logger.info("Buffer created.") q_envs = [] q_policies = [] for i_env in range(n_workers): q_envs.append(Queue(f"env_{i_env}")) q_policies.append(Queue(f"policy_{i_env}")) - logging.info("Learner queue created.") + logger.info("Learner queue created.") batch_size = 128 - logging.info("Warming up the learner...") + logger.info("Warming up the learner...") time.sleep(2.0) for i_epoch in count(0): - logging.info(f"Epoch: {i_epoch}") + logger.info(f"Epoch: {i_epoch}") 
n_waits = 0 for i_env in range(n_workers): while True: @@ -109,12 +122,12 @@ def policy_update_worker(make_algo, make_env_test, make_buffer, n_workers): buffer.add_episode(episode) break else: - logging.info("Waiting for the env data...") + logger.info("Waiting for the env data...") # TODO: not optimal wait for each queue time.sleep(1) n_waits += 1 if n_waits == 10: - logging.info("Learner tired to wait, exiting...") + logger.info("Learner tired to wait, exiting...") return continue @@ -124,7 +137,7 @@ def policy_update_worker(make_algo, make_env_test, make_buffer, n_workers): batch = buffer.sample(batch_size) algo.update(*batch) if i % int(1000) == 0: - logging.info(f"\tUpdating {i}") + logger.info(f"\tUpdating {i}") policy_state_dict = algo.get_policy_state_dict() @@ -132,11 +145,19 @@ def policy_update_worker(make_algo, make_env_test, make_buffer, n_workers): for i_env in range(n_workers): q_policies[i_env].push(policy_serialized) - if True: + + if i_epoch > 0 and i_epoch % 10 == 0: mean_reward = evaluate(algo, make_env_test) + logger.info(f"Evaluating policy [epoch {i_epoch}]: {mean_reward}") algo.logger.log_scalar("trainer/ep_reward", mean_reward, i_epoch) - logging.info("Update by policy update worker done.") + save_policy( + policy=algo.actor, + save_path=algo.logger.log_dir / "weights" / f"epoch_{i_epoch}.w" + ) + logger.info(f"Weights saved.") + + logger.info("Update by policy update worker done.") def evaluate(algo, make_env_test, num_eval_episodes: int = 5, seed: int = 0): diff --git a/src/oprl/parse_args.py b/src/oprl/parse_args.py index 060355c..66bb580 100644 --- a/src/oprl/parse_args.py +++ b/src/oprl/parse_args.py @@ -23,3 +23,16 @@ def parse_args() -> argparse.Namespace: "--device", type=str, default="cpu", help="Device to perform training on." ) return parser.parse_args() + + +def parse_args_distrib(): + parser = argparse.ArgumentParser(description="Run distrib training") + parser.add_argument("--config", type=str, help="Path to the config file.") + parser.add_argument( + "--env", type=str, default="cartpole-balance", help="Name of the environment." + ) + parser.add_argument( + "--device", type=str, default="cpu", help="Device to perform training on." 
+ ) + parser.add_argument("--seed", type=int, default=0, help="Random seed") + return parser.parse_args() diff --git a/src/oprl/runners/train_distrib.py b/src/oprl/runners/train_distrib.py index 9ee3b14..0230103 100644 --- a/src/oprl/runners/train_distrib.py +++ b/src/oprl/runners/train_distrib.py @@ -4,7 +4,12 @@ from multiprocessing import Process import torch.nn as nn -from algos.ddpg import DDPG, DeterministicPolicy + +from oprl.algos.ddpg import DDPG +from oprl.algos.nn_models import DeterministicPolicy +print("Imports ok.") + +""" from distrib.distrib_runner import env_worker, policy_update_worker from env import DMControlEnv, make_env from trainers.buffers.episodic_buffer import EpisodicReplayBuffer @@ -12,6 +17,7 @@ + def parse_args(): parser = argparse.ArgumentParser(description="Run training") parser.add_argument("--config", type=str, help="Path to the config file.") @@ -103,3 +109,4 @@ def make_algo(): p.join() print("OK.") +""" From 8e00d9c69e6a82c9587898d084fab7346eae5093 Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Thu, 17 Jul 2025 18:17:46 +0300 Subject: [PATCH 17/30] Refactor distrib training --- configs/distrib_ddpg.py | 62 ++++----- src/oprl/algos/protocols.py | 7 +- src/oprl/distrib/distrib_runner.py | 163 ++--------------------- src/oprl/distrib/env_worker.py | 65 +++++++++ src/oprl/distrib/policy_update_worker.py | 119 +++++++++++++++++ src/oprl/distrib/queue.py | 20 +++ src/oprl/runners/config.py | 11 ++ src/oprl/runners/train_distrib.py | 116 +++------------- 8 files changed, 286 insertions(+), 277 deletions(-) create mode 100644 src/oprl/distrib/env_worker.py create mode 100644 src/oprl/distrib/policy_update_worker.py create mode 100644 src/oprl/distrib/queue.py diff --git a/configs/distrib_ddpg.py b/configs/distrib_ddpg.py index ef8c90e..70451d7 100644 --- a/configs/distrib_ddpg.py +++ b/configs/distrib_ddpg.py @@ -1,14 +1,12 @@ import os -import argparse import logging -from multiprocessing import Process import torch.nn as nn from oprl.algos.ddpg import DDPG from oprl.algos.nn_models import DeterministicPolicy -from oprl.distrib.distrib_runner import env_worker, policy_update_worker - +from oprl.algos.protocols import AlgorithmProtocol, PolicyProtocol +from oprl.buffers.protocols import ReplayBufferProtocol from oprl.environment import make_env as _make_env from oprl.buffers.episodic_buffer import EpisodicReplayBuffer from oprl.logging import ( @@ -17,16 +15,22 @@ get_logs_path, ) from oprl.parse_args import parse_args_distrib +from oprl.runners.config import DistribConfig +from oprl.runners.train_distrib import run_distrib_training +from oprl.distrib.env_worker import run_env_worker +from oprl.distrib.policy_update_worker import run_policy_update_worker + + +config = DistribConfig( + batch_size=128, + num_env_workers=4, + episodes_per_worker=100, + warmup_epochs=16, + episode_length=1000, + learner_num_waits=10, +) - -# -------- Distrib params ----------- - -ENV_WORKERS = 4 -EPISODES_PER_WORKER = 100 # Number of episodes each env worker would perform - -# ----------------------------------- - args = parse_args_distrib() def make_env(seed: int): @@ -39,7 +43,7 @@ def make_env(seed: int): logging.info(f"Env state {STATE_DIM}\tEnv action {ACTION_DIM}") -def make_logger(): +def make_logger() -> LoggerProtocol: logs_root = os.environ.get("OPRL_LOGS", "logs") log_dir = get_logs_path(logdir=logs_root, algo="DistribDDPG", env=args.env, seed=0) logger = FileTxtLogger(log_dir) @@ -47,7 +51,7 @@ def make_logger(): return logger -def make_policy(): +def make_policy() 
-> PolicyProtocol: return DeterministicPolicy( state_dim=STATE_DIM, action_dim=ACTION_DIM, @@ -57,7 +61,7 @@ def make_policy(): ) -def make_buffer(): +def make_replay_buffer() -> ReplayBufferProtocol: return EpisodicReplayBuffer( buffer_size_transitions=int(1_000_000), state_dim=STATE_DIM, @@ -66,7 +70,7 @@ def make_buffer(): ).create() -def make_algo(logger: LoggerProtocol): +def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol: return DDPG( logger=logger, state_dim=STATE_DIM, @@ -76,21 +80,13 @@ def make_algo(logger: LoggerProtocol): if __name__ == "__main__": - processes = [] - - for i_env in range(ENV_WORKERS): - processes.append( - Process(target=env_worker, args=(make_env, make_policy, EPISODES_PER_WORKER, i_env)) - ) - processes.append( - Process( - target=policy_update_worker, - args=(make_algo, make_env, make_buffer, make_logger, ENV_WORKERS), - ) + run_distrib_training( + run_env_worker=run_env_worker, + run_policy_update_worker=run_policy_update_worker, + make_env=make_env, + make_algo=make_algo, + make_policy=make_policy, + make_replay_buffer=make_replay_buffer, + make_logger=make_logger, + config=config, ) - - for p in processes: - p.start() - for p in processes: - p.join() - logging.info("Training OK.") diff --git a/src/oprl/algos/protocols.py b/src/oprl/algos/protocols.py index 94f7322..065d12f 100644 --- a/src/oprl/algos/protocols.py +++ b/src/oprl/algos/protocols.py @@ -1,10 +1,12 @@ -from typing import Protocol +from typing import Protocol, Any import numpy.typing as npt import torch as t import torch.nn as nn +from oprl.logging import LoggerProtocol + class PolicyProtocol(Protocol): def explore(self, state: npt.NDArray) -> npt.NDArray: ... @@ -17,6 +19,7 @@ def __call__(*args, **kwargs) -> t.Tensor: ... class AlgorithmProtocol(Protocol): actor: PolicyProtocol critic: nn.Module + logger: LoggerProtocol _created: bool def create(self) -> "AlgorithmProtocol": ... @@ -32,4 +35,6 @@ def update( next_state: t.Tensor ) -> None: ... 
+ def get_policy_state_dict(self) -> dict[str, Any]: + return self.actor.state_dict() diff --git a/src/oprl/distrib/distrib_runner.py b/src/oprl/distrib/distrib_runner.py index 6ea4a37..1b438a4 100644 --- a/src/oprl/distrib/distrib_runner.py +++ b/src/oprl/distrib/distrib_runner.py @@ -2,22 +2,26 @@ import time from itertools import count from pathlib import Path +from typing import Callable import numpy as np import torch as t import torch.nn as nn import pika -from oprl.logging import create_stdout_logger +from oprl.algos.protocols import AlgorithmProtocol, PolicyProtocol +from oprl.environment.protocols import EnvProtocol +from oprl.logging import create_stdout_logger, LoggerProtocol +from oprl.runners.config import DistribConfig +from oprl.buffers.protocols import ReplayBufferProtocol logger = create_stdout_logger() class Queue: - def __init__(self, name: str, host: str = "localhost"): + def __init__(self, name: str, host: str = "localhost") -> None: self._name = name - connection = pika.BlockingConnection(pika.ConnectionParameters(host=host)) self.channel = connection.channel() self.channel.queue_declare(queue=name) @@ -26,155 +30,16 @@ def push(self, data) -> None: self.channel.basic_publish(exchange="", routing_key=self._name, body=data) def pop(self) -> bytes | None: - method_frame, header_frame, body = self.channel.basic_get(queue=self._name) + method_frame, _, body = self.channel.basic_get(queue=self._name) if method_frame: self.channel.basic_ack(method_frame.delivery_tag) return body return None -def env_worker(make_env, make_policy, n_episodes, id_worker): - env = make_env(seed=0) - logger.info("Env created.") - - policy = make_policy() - logger.info("Policy created.") - - q_env = Queue(f"env_{id_worker}") - q_policy = Queue(f"policy_{id_worker}") - logger.info("Queue created.") - - total_env_step = 0 - # TODO: Move parameter to config - start_steps = 1000 - for i_ep in range(n_episodes): - print("Running episode: ", i_ep) - if i_ep % 10 == 0: - logger.info(f"AGENT {id_worker} EPISODE {i_ep}") - - episode = [] - state, _ = env.reset() - # TODO: Move parameter to config - for env_step in range(1000): - if total_env_step <= start_steps: - action = env.sample_action() - else: - action = policy.explore(state) - - next_state, reward, terminated, truncated, _ = env.step(action) - episode.append([state, action, reward, terminated, next_state]) - - if terminated or truncated: - break - state = next_state - total_env_step += 1 - - q_env.push(pickle.dumps(episode)) - - while True: - data = q_policy.pop() - if data is None: - logger.info("Waiting for the policy..") - time.sleep(2.0) - continue - - policy.load_state_dict(pickle.loads(data)) - break - - logger.info("Episode by env worker is done.") - - -def save_policy(policy: nn.Module, save_path: Path): - save_path.parents[0].mkdir(exist_ok=True) - t.save( - policy, - save_path - ) - - -def policy_update_worker(make_algo, make_env_test, make_buffer, make_logger, n_workers): - scalar_logger = make_logger() - algo = make_algo(scalar_logger) - logger.info("Algo created.") - buffer = make_buffer() - logger.info("Buffer created.") - - q_envs = [] - q_policies = [] - for i_env in range(n_workers): - q_envs.append(Queue(f"env_{i_env}")) - q_policies.append(Queue(f"policy_{i_env}")) - logger.info("Learner queue created.") - - batch_size = 128 - - logger.info("Warming up the learner...") - time.sleep(2.0) - - for i_epoch in count(0): - logger.info(f"Epoch: {i_epoch}") - n_waits = 0 - for i_env in range(n_workers): - while True: - data = 
diff --git a/src/oprl/distrib/distrib_runner.py b/src/oprl/distrib/distrib_runner.py
index 6ea4a37..1b438a4 100644
--- a/src/oprl/distrib/distrib_runner.py
+++ b/src/oprl/distrib/distrib_runner.py
@@ -2,22 +2,26 @@
 import time
 from itertools import count
 from pathlib import Path
+from typing import Callable
 
 import numpy as np
 import torch as t
 import torch.nn as nn
 import pika
 
-from oprl.logging import create_stdout_logger
+from oprl.algos.protocols import AlgorithmProtocol, PolicyProtocol
+from oprl.environment.protocols import EnvProtocol
+from oprl.logging import create_stdout_logger, LoggerProtocol
+from oprl.runners.config import DistribConfig
+from oprl.buffers.protocols import ReplayBufferProtocol
 
 
 logger = create_stdout_logger()
 
 
 class Queue:
-    def __init__(self, name: str, host: str = "localhost"):
+    def __init__(self, name: str, host: str = "localhost") -> None:
         self._name = name
-
         connection = pika.BlockingConnection(pika.ConnectionParameters(host=host))
         self.channel = connection.channel()
         self.channel.queue_declare(queue=name)
@@ -26,155 +30,16 @@ def push(self, data) -> None:
         self.channel.basic_publish(exchange="", routing_key=self._name, body=data)
 
     def pop(self) -> bytes | None:
-        method_frame, header_frame, body = self.channel.basic_get(queue=self._name)
+        method_frame, _, body = self.channel.basic_get(queue=self._name)
         if method_frame:
             self.channel.basic_ack(method_frame.delivery_tag)
             return body
         return None
 
 
-def env_worker(make_env, make_policy, n_episodes, id_worker):
-    env = make_env(seed=0)
-    logger.info("Env created.")
-
-    policy = make_policy()
-    logger.info("Policy created.")
-
-    q_env = Queue(f"env_{id_worker}")
-    q_policy = Queue(f"policy_{id_worker}")
-    logger.info("Queue created.")
-
-    total_env_step = 0
-    # TODO: Move parameter to config
-    start_steps = 1000
-    for i_ep in range(n_episodes):
-        print("Running episode: ", i_ep)
-        if i_ep % 10 == 0:
-            logger.info(f"AGENT {id_worker} EPISODE {i_ep}")
-
-        episode = []
-        state, _ = env.reset()
-        # TODO: Move parameter to config
-        for env_step in range(1000):
-            if total_env_step <= start_steps:
-                action = env.sample_action()
-            else:
-                action = policy.explore(state)
-
-            next_state, reward, terminated, truncated, _ = env.step(action)
-            episode.append([state, action, reward, terminated, next_state])
-
-            if terminated or truncated:
-                break
-            state = next_state
-            total_env_step += 1
-
-        q_env.push(pickle.dumps(episode))
-
-        while True:
-            data = q_policy.pop()
-            if data is None:
-                logger.info("Waiting for the policy..")
-                time.sleep(2.0)
-                continue
-
-            policy.load_state_dict(pickle.loads(data))
-            break
-
-    logger.info("Episode by env worker is done.")
-
-
-def save_policy(policy: nn.Module, save_path: Path):
-    save_path.parents[0].mkdir(exist_ok=True)
-    t.save(
-        policy,
-        save_path
-    )
-
-
-def policy_update_worker(make_algo, make_env_test, make_buffer, make_logger, n_workers):
-    scalar_logger = make_logger()
-    algo = make_algo(scalar_logger)
-    logger.info("Algo created.")
-    buffer = make_buffer()
-    logger.info("Buffer created.")
-
-    q_envs = []
-    q_policies = []
-    for i_env in range(n_workers):
-        q_envs.append(Queue(f"env_{i_env}"))
-        q_policies.append(Queue(f"policy_{i_env}"))
-    logger.info("Learner queue created.")
-
-    batch_size = 128
-
-    logger.info("Warming up the learner...")
-    time.sleep(2.0)
-
-    for i_epoch in count(0):
-        logger.info(f"Epoch: {i_epoch}")
-        n_waits = 0
-        for i_env in range(n_workers):
-            while True:
-                data = q_envs[i_env].pop()
-                if data:
-                    episode = pickle.loads(data)
-                    buffer.add_episode(episode)
-                    break
-                else:
-                    logger.info("Waiting for the env data...")
-                    # TODO: not optimal wait for each queue
-                    time.sleep(1)
-                    n_waits += 1
-                    if n_waits == 10:
-                        logger.info("Learner tired to wait, exiting...")
-                        return
-                    continue
-
-        # TODO: Remove hardcoded value
-        if i_epoch > 16:
-            for i in range(1000 * 4):
-                batch = buffer.sample(batch_size)
-                algo.update(*batch)
-                if i % int(1000) == 0:
-                    logger.info(f"\tUpdating {i}")
-
-            policy_state_dict = algo.get_policy_state_dict()
-
-            policy_serialized = pickle.dumps(policy_state_dict)
-            for i_env in range(n_workers):
-                q_policies[i_env].push(policy_serialized)
-
-
-        if i_epoch > 0 and i_epoch % 10 == 0:
-            mean_reward = evaluate(algo, make_env_test)
-            logger.info(f"Evaluating policy [epoch {i_epoch}]: {mean_reward}")
-            algo.logger.log_scalar("trainer/ep_reward", mean_reward, i_epoch)
-
-            save_policy(
-                policy=algo.actor,
-                save_path=algo.logger.log_dir / "weights" / f"epoch_{i_epoch}.w"
-            )
-            logger.info(f"Weights saved.")
-
-    logger.info("Update by policy update worker done.")
-
-
-def evaluate(algo, make_env_test, num_eval_episodes: int = 5, seed: int = 0):
-    returns = []
-    for i_ep in range(num_eval_episodes):
-        env_test = make_env_test(seed * 100 + i_ep)
-        state, _ = env_test.reset()
-
-        episode_return = 0.0
-        terminated, truncated = False, False
-
-        while not (terminated or truncated):
-            action = algo.exploit(state)
-            state, reward, terminated, truncated, _ = env_test.step(action)
-            episode_return += reward
-
-        returns.append(episode_return)
-
-    mean_return = np.mean(returns)
-    return mean_return
+
+
+
+
+
+
diff --git a/src/oprl/distrib/env_worker.py b/src/oprl/distrib/env_worker.py
new file mode 100644
index 0000000..67bcdeb
--- /dev/null
+++ b/src/oprl/distrib/env_worker.py
@@ -0,0 +1,65 @@
+import pickle
+import time
+from typing import Callable
+
+from oprl.algos.protocols import PolicyProtocol
+from oprl.environment.protocols import EnvProtocol
+from oprl.logging import create_stdout_logger
+from oprl.runners.config import DistribConfig
+from oprl.distrib.queue import Queue
+
+
+logger = create_stdout_logger()
+
+
+def run_env_worker(
+    make_env: Callable[[int], EnvProtocol],
+    make_policy: Callable[[], PolicyProtocol],
+    config: DistribConfig,
+    id_worker: int,
+) -> None:
+    env = make_env(seed=0)
+    logger.info("Env created.")
+
+    policy = make_policy()
+    logger.info("Policy created.")
+
+    q_env = Queue(f"env_{id_worker}")
+    q_policy = Queue(f"policy_{id_worker}")
+    logger.info("Queue created.")
+
+    total_env_step = 0
+    for i_ep in range(config.episodes_per_worker):
+        logger.debug(f"Running episode: {i_ep}")
+        if i_ep % 10 == 0:
+            logger.info(f"AGENT {id_worker} EPISODE {i_ep}")
+
+        episode = []
+        state, _ = env.reset()
+        for _ in range(config.episode_length):
+            if total_env_step <= config.warmup_env_steps:
+                action = env.sample_action()
+            else:
+                action = policy.explore(state)
+
+            next_state, reward, terminated, truncated, _ = env.step(action)
+            episode.append([state, action, reward, terminated, next_state])
+
+            if terminated or truncated:
+                break
+            state = next_state
+            total_env_step += 1
+
+        q_env.push(pickle.dumps(episode))
+
+        while True:
+            data = q_policy.pop()
+            if data is None:
+                logger.info("Waiting for the policy...")
+                time.sleep(2.0)
+                continue
+            policy.load_state_dict(pickle.loads(data))
+            break
+
+    logger.info("Episode by env worker is done.")
+
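Each finished episode crosses the broker as a single pickled list of [state, action, reward, terminated, next_state] entries, exactly as built in the loop above. A broker-free sketch of the payload format:

import pickle

import numpy as np

state = np.zeros(3, dtype="float32")
action = np.zeros(1, dtype="float32")
episode = [
    [state, action, 0.5, False, state],   # one transition per list entry
    [state, action, 1.0, True, state],    # terminated=True on the final step
]

payload = pickle.dumps(episode)   # what q_env.push() sends
assert pickle.loads(payload)[1][3] is True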
diff --git a/src/oprl/distrib/policy_update_worker.py b/src/oprl/distrib/policy_update_worker.py
new file mode 100644
index 0000000..ad17c8b
--- /dev/null
+++ b/src/oprl/distrib/policy_update_worker.py
@@ -0,0 +1,119 @@
+import pickle
+import time
+from itertools import count
+from pathlib import Path
+from typing import Callable
+
+import numpy as np
+import torch as t
+import torch.nn as nn
+
+from oprl.algos.protocols import AlgorithmProtocol
+from oprl.environment.protocols import EnvProtocol
+from oprl.logging import create_stdout_logger, LoggerProtocol
+from oprl.runners.config import DistribConfig
+from oprl.buffers.protocols import ReplayBufferProtocol
+from oprl.distrib.queue import Queue
+
+
+logger = create_stdout_logger()
+
+
+def run_policy_update_worker(
+    make_algo: Callable[[LoggerProtocol], AlgorithmProtocol],
+    make_env_test: Callable[[int], EnvProtocol],
+    make_buffer: Callable[[], ReplayBufferProtocol],
+    make_logger: Callable[[], LoggerProtocol],
+    config: DistribConfig,
+) -> None:
+    scalar_logger = make_logger()
+    algo = make_algo(scalar_logger)
+    logger.info("Algo created.")
+    buffer = make_buffer()
+    logger.info("Buffer created.")
+
+    q_envs = []
+    q_policies = []
+    for i_env in range(config.num_env_workers):
+        q_envs.append(Queue(f"env_{i_env}"))
+        q_policies.append(Queue(f"policy_{i_env}"))
+    logger.info("Learner queue created.")
+
+    logger.info("Warming up the learner...")
+    time.sleep(2.0)
+
+    for i_epoch in count(0):
+        logger.info(f"Epoch: {i_epoch}")
+        n_waits = 0
+        for i_env in range(config.num_env_workers):
+            while True:
+                data = q_envs[i_env].pop()
+                if data:
+                    episode = pickle.loads(data)
+                    buffer.add_episode(episode)
+                    break
+                else:
+                    logger.info("Waiting for the env data...")
+                    # TODO: not optimal wait for each queue
+                    time.sleep(1)
+                    n_waits += 1
+                    if n_waits == config.learner_num_waits:
+                        logger.info("Learner is not receiving data, exiting...")
+                        return
+                    continue
+
+        if i_epoch > config.warmup_epochs:
+            for i in range(config.episode_length * config.num_env_workers):
+                batch = buffer.sample(config.batch_size)
+                algo.update(*batch)
+                if i % 1000 == 0:
+                    logger.info(f"\tUpdating {i}")
+
+            policy_state_dict = algo.get_policy_state_dict()
+
+            policy_serialized = pickle.dumps(policy_state_dict)
+            for i_env in range(config.num_env_workers):
+                q_policies[i_env].push(policy_serialized)
+
+
+        if i_epoch > 0 and i_epoch % 10 == 0:
+            mean_reward = evaluate(algo, make_env_test)
+            logger.info(f"Evaluating policy [epoch {i_epoch}]: {mean_reward}")
+            algo.logger.log_scalar("trainer/ep_reward", mean_reward, i_epoch)
+            save_policy(
+                policy=algo.actor,
+                save_path=algo.logger.log_dir / "weights" / f"epoch_{i_epoch}.w"
+            )
+            logger.info(f"Weights saved.")
+
+    logger.info("Update by policy update worker done.")
+
+
+def save_policy(policy: nn.Module, save_path: Path) -> None:
+    save_path.parent.mkdir(parents=True, exist_ok=True)
+    t.save(
+        policy,
+        save_path
+    )
+
+
+def evaluate(
+    algo: AlgorithmProtocol,
+    make_env_test: Callable[[int], EnvProtocol],
+    num_eval_episodes: int = 5,
+    seed: int = 0
+) -> float:
+    returns = []
+    for i_ep in range(num_eval_episodes):
+        env_test = make_env_test(seed * 100 + i_ep)
+        state, _ = env_test.reset()
+
+        episode_return = 0.0
+        terminated, truncated = False, False
+        while not (terminated or truncated):
+            action = algo.actor.exploit(state)
+            state, reward, terminated, truncated, _ = env_test.step(action)
+            episode_return += reward
+        returns.append(episode_return)
+
+    return np.mean(returns)
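With the DistribConfig defaults introduced below, each post-warmup epoch drains one episode per worker and then runs episode_length * num_env_workers gradient steps, i.e. roughly one update per freshly collected transition. A quick check of that arithmetic:

episode_length = 1000    # DistribConfig.episode_length
num_env_workers = 4      # DistribConfig.num_env_workers

updates_per_epoch = episode_length * num_env_workers
assert updates_per_epoch == 4_000    # matches the old hardcoded 1000 * 4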
diff --git a/src/oprl/distrib/queue.py b/src/oprl/distrib/queue.py
new file mode 100644
index 0000000..cfa9e91
--- /dev/null
+++ b/src/oprl/distrib/queue.py
@@ -0,0 +1,20 @@
+import pika
+
+
+class Queue:
+    def __init__(self, name: str, host: str = "localhost") -> None:
+        self._name = name
+        connection = pika.BlockingConnection(pika.ConnectionParameters(host=host))
+        self.channel = connection.channel()
+        self.channel.queue_declare(queue=name)
+
+    def push(self, data) -> None:
+        self.channel.basic_publish(exchange="", routing_key=self._name, body=data)
+
+    def pop(self) -> bytes | None:
+        method_frame, _, body = self.channel.basic_get(queue=self._name)
+        if method_frame:
+            self.channel.basic_ack(method_frame.delivery_tag)
+            return body
+        return None
+
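Queue is a thin wrapper over a RabbitMQ queue via pika's blocking connection: push publishes to the default exchange under the queue name, and pop is a non-blocking basic_get that acks on receipt and returns None when the queue is empty. A smoke test, assuming a broker is running on localhost:

from oprl.distrib.queue import Queue

q = Queue("smoke_test")   # declares the queue if it does not exist yet
q.push(b"hello")

assert q.pop() == b"hello"
assert q.pop() is None    # empty queue yields None instead of blocking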
- """ - return DMControlEnv(args.env, seed=seed) - - -env = make_env(seed=0) - -STATE_SHAPE = env.observation_space.shape -ACTION_SHAPE = env.action_space.shape -print("STATE ACTION SHAPE: ", STATE_SHAPE, ACTION_SHAPE) - - -def make_policy(): - return DeterministicPolicy( - state_dim=STATE_SHAPE, - action_dim=ACTION_SHAPE, - hidden_units=[256, 256], - hidden_activation=nn.ReLU(inplace=True), - ) - - -def make_buffer(): - buffer = EpisodicReplayBuffer( - buffer_size=int(100_000), - state_shape=STATE_SHAPE, - action_shape=ACTION_SHAPE, - device="cpu", - gamma=0.99, - ) - return buffer - - -def make_algo(): - time = datetime.now().strftime("%Y-%m-%d_%H_%M") - log_dir = os.path.join("logs_debug", "DDPG", f"DDPG-env_ENV-seedSEED-{time}") - print("LOGDIR: ", log_dir) - logger = Logger(log_dir, {}) - - algo = DDPG( - state_dim=STATE_SHAPE, - action_dim=ACTION_SHAPE, - device="cpu", - seed=0, - logger=logger, - ) - return algo - - -if __name__ == "__main__": - ENV_WORKERS = 2 - - seed = 0 - +def run_distrib_training( + run_env_worker: Callable, + run_policy_update_worker: Callable, + make_env: Callable[[int], EnvProtocol], + make_algo: Callable[[LoggerProtocol], AlgorithmProtocol], + make_policy: Callable[[], PolicyProtocol], + make_replay_buffer: Callable[[], ReplayBufferProtocol], + make_logger: Callable[[], LoggerProtocol], + config: DistribConfig +) -> None: processes = [] - - for i_env in range(ENV_WORKERS): + for i_env in range(config.num_env_workers): processes.append( - Process(target=env_worker, args=(make_env, make_policy, i_env)) + Process(target=run_env_worker, args=(make_env, make_policy, config, i_env)) ) processes.append( Process( - target=policy_update_worker, - args=(make_algo, make_env, make_buffer, ENV_WORKERS), + target=run_policy_update_worker, + args=(make_algo, make_env, make_replay_buffer, make_logger, config), ) ) for p in processes: p.start() - for p in processes: p.join() - - print("OK.") -""" + logger.info("Training Finished.") From 45f006d779f07efb31a26cefe93b2544c2168ebe Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Thu, 17 Jul 2025 18:35:43 +0300 Subject: [PATCH 18/30] Add explicit gymnasium support --- pyproject.toml | 1 + src/oprl/environment/gymnasium.py | 42 +++++++++++++++++++++++++++++++ src/oprl/environment/make_env.py | 21 ++++++++++++++++ tests/functional/test_env.py | 17 ++++++++++++- 4 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 src/oprl/environment/gymnasium.py diff --git a/pyproject.toml b/pyproject.toml index 6b4b570..67aa352 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "numpy==1.26.4", "pika==1.3.2", "pydantic_settings==2.10.1", + "gymnasium==0.28.1", ] [project.optional-dependencies] diff --git a/src/oprl/environment/gymnasium.py b/src/oprl/environment/gymnasium.py new file mode 100644 index 0000000..a2c122a --- /dev/null +++ b/src/oprl/environment/gymnasium.py @@ -0,0 +1,42 @@ +import numpy.typing as npt +from typing import Any + +import gymnasium as gym + +from oprl.environment.protocols import EnvProtocol + + +class Gymnasium(EnvProtocol): + def __init__(self, env_name: str, seed: int) -> None: + self._env = gym.make(env_name) + self._seed = seed + + def step( + self, + action: npt.NDArray, + ) -> tuple[npt.NDArray, float, bool, bool, dict[str, Any]]: + obs, reward, terminated, truncated, info = self._env.step(action) + return obs.astype("float32"), float(reward), terminated, bool(truncated), info + + def reset(self) -> tuple[npt.NDArray, dict[str, Any]]: + obs, info = 
From 45f006d779f07efb31a26cefe93b2544c2168ebe Mon Sep 17 00:00:00 2001
From: Igor Kuznetsov
Date: Thu, 17 Jul 2025 18:35:43 +0300
Subject: [PATCH 18/30] Add explicit gymnasium support

---
 pyproject.toml                    |  1 +
 src/oprl/environment/gymnasium.py | 41 ++++++++++++++++++++++++++++++
 src/oprl/environment/make_env.py  | 21 ++++++++++++++++
 tests/functional/test_env.py      | 17 ++++++++++++-
 4 files changed, 79 insertions(+), 1 deletion(-)
 create mode 100644 src/oprl/environment/gymnasium.py

diff --git a/pyproject.toml b/pyproject.toml
index 6b4b570..67aa352 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
     "numpy==1.26.4",
     "pika==1.3.2",
     "pydantic_settings==2.10.1",
+    "gymnasium==0.28.1",
 ]
 
 [project.optional-dependencies]
diff --git a/src/oprl/environment/gymnasium.py b/src/oprl/environment/gymnasium.py
new file mode 100644
index 0000000..a2c122a
--- /dev/null
+++ b/src/oprl/environment/gymnasium.py
@@ -0,0 +1,41 @@
+import numpy.typing as npt
+from typing import Any
+
+import gymnasium as gym
+
+from oprl.environment.protocols import EnvProtocol
+
+
+class Gymnasium(EnvProtocol):
+    def __init__(self, env_name: str, seed: int) -> None:
+        self._env = gym.make(env_name)
+        self._seed = seed
+
+    def step(
+        self,
+        action: npt.NDArray,
+    ) -> tuple[npt.NDArray, float, bool, bool, dict[str, Any]]:
+        obs, reward, terminated, truncated, info = self._env.step(action)
+        return obs.astype("float32"), float(reward), terminated, bool(truncated), info
+
+    def reset(self) -> tuple[npt.NDArray, dict[str, Any]]:
+        obs, info = self._env.reset(seed=self._seed)
+        return obs.astype("float32"), info
+
+    def sample_action(self) -> npt.NDArray:
+        return self._env.action_space.sample()
+
+    def render(self) -> npt.NDArray:
+        return self._env.render()
+
+    @property
+    def observation_space(self) -> gym.Space:
+        return self._env.observation_space
+
+    @property
+    def action_space(self) -> gym.Space:
+        return self._env.action_space
+
+    @property
+    def env_family(self) -> str:
+        return "safety_gymnasium"
diff --git a/src/oprl/environment/make_env.py b/src/oprl/environment/make_env.py
index 3b3a158..6f85544 100644
--- a/src/oprl/environment/make_env.py
+++ b/src/oprl/environment/make_env.py
@@ -1,6 +1,7 @@
 from collections import defaultdict
 
 from oprl.environment import EnvProtocol, DMControlEnv, SafetyGym
+from oprl.environment.gymnasium import Gymnasium
 
 
@@ -34,6 +35,24 @@
     ]
 )
 
+
+ENV_MAPPER["gymnasium"] = set(
+    [
+        "Ant-v4",
+        "Hopper-v4",
+        "HalfCheetah-v4",
+        "HumanoidStandup-v4",
+        "Humanoid-v4",
+        "InvertedDoublePendulum-v4",
+        "InvertedPendulum-v4",
+        "Pusher-v4",
+        "Reacher-v4",
+        "Swimmer-v4",
+        "Walker2d-v4",
+    ]
+)
+
+
 ENV_MAPPER["safety_gymnasium"] = set(
     [
         "SafetyPointGoal1-v0",
@@ -79,4 +98,6 @@ def make_env(name: str, seed: int) -> EnvProtocol:
         return DMControlEnv(name, seed=seed)
     elif env_type == "safety_gymnasium":
         return SafetyGym(name, seed=seed)
+    elif env_type == "gymnasium":
+        return Gymnasium(name, seed=seed)
     raise ValueError(f"Unsupported environment: {name}")
diff --git a/tests/functional/test_env.py b/tests/functional/test_env.py
index 0ff2bc5..4a38b8e 100644
--- a/tests/functional/test_env.py
+++ b/tests/functional/test_env.py
@@ -31,6 +31,21 @@
 ]
 
 
+gymnasium_envs: list[str] = [
+    "Ant-v4",
+    "Hopper-v4",
+    "HalfCheetah-v4",
+    "HumanoidStandup-v4",
+    "Humanoid-v4",
+    "InvertedDoublePendulum-v4",
+    "InvertedPendulum-v4",
+    "Pusher-v4",
+    "Reacher-v4",
+    "Swimmer-v4",
+    "Walker2d-v4",
+]
+
+
 safety_envs: list[str] = [
     "SafetyPointGoal1-v0",
     "SafetyPointButton1-v0",
@@ -39,7 +54,7 @@
 ]
 
 
-env_names: list[str] = dm_control_envs + safety_envs
+env_names: list[str] = dm_control_envs + safety_envs + gymnasium_envs
 
 
 @pytest.mark.parametrize("env_name", env_names)

From 6b972d7791dd0f6ee573f487523e13a5bf55608f Mon Sep 17 00:00:00 2001
From: Igor Kuznetsov
Date: Thu, 17 Jul 2025 18:48:26 +0300
Subject: [PATCH 19/30] Fix: gymnasium support

---
 src/oprl/environment/gymnasium.py | 2 +-
 src/oprl/runners/train.py         | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/oprl/environment/gymnasium.py b/src/oprl/environment/gymnasium.py
index a2c122a..60f802a 100644
--- a/src/oprl/environment/gymnasium.py
+++ b/src/oprl/environment/gymnasium.py
@@ -38,4 +38,4 @@ def action_space(self) -> gym.Space:
 
     @property
     def env_family(self) -> str:
-        return "safety_gymnasium"
+        return "gymnasium"
diff --git a/src/oprl/runners/train.py b/src/oprl/runners/train.py
index 2c0180d..d801dac 100644
--- a/src/oprl/runners/train.py
+++ b/src/oprl/runners/train.py
@@ -77,7 +77,7 @@ def _run_training_func(
         seed=seed,
         logger=logger,
     )
-    if env.env_family == "dm_control":
+    if env.env_family in ["dm_control", "gymnasium"]:
         trainer = base_trainer
     elif env.env_family == "safety_gymnasium":
         trainer = SafeTrainer(trainer=base_trainer)
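With the mapper and wrapper in place, gymnasium tasks resolve through the same factory as the DM Control and Safety-Gymnasium ones; a short usage sketch:

from oprl.environment import make_env

env = make_env("Hopper-v4", seed=0)   # resolves to the Gymnasium wrapper
state, _ = env.reset()
next_state, reward, terminated, truncated, info = env.step(env.sample_action())
assert env.env_family == "gymnasium"  # after the fix in PATCH 19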
From 5b815f203b8c56212cf939de5fd3611526a69bdd Mon Sep 17 00:00:00 2001
From: Igor Kuznetsov
Date: Thu, 17 Jul 2025 19:18:43 +0300
Subject: [PATCH 20/30] Add uv ruff

---
 pyproject.toml                           |    1 +
 src/oprl/distrib/distrib_runner.py       |   45 -
 src/oprl/distrib/policy_update_worker.py |    2 +-
 src/oprl/environment/__init__.py         |    2 +-
 tests/functional/test_env.py             |    6 +
 uv.lock                                  | 1141 ++++++++++++++++++++++
 6 files changed, 1150 insertions(+), 47 deletions(-)
 delete mode 100644 src/oprl/distrib/distrib_runner.py
 create mode 100644 uv.lock

diff --git a/pyproject.toml b/pyproject.toml
index 67aa352..57cdc33 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
     "pika==1.3.2",
     "pydantic_settings==2.10.1",
     "gymnasium==0.28.1",
+    "ruff>=0.12.3",
 ]
 
 [project.optional-dependencies]
diff --git a/src/oprl/distrib/distrib_runner.py b/src/oprl/distrib/distrib_runner.py
deleted file mode 100644
index 1b438a4..0000000
--- a/src/oprl/distrib/distrib_runner.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import pickle
-import time
-from itertools import count
-from pathlib import Path
-from typing import Callable
-
-import numpy as np
-import torch as t
-import torch.nn as nn
-import pika
-
-from oprl.algos.protocols import AlgorithmProtocol, PolicyProtocol
-from oprl.environment.protocols import EnvProtocol
-from oprl.logging import create_stdout_logger, LoggerProtocol
-from oprl.runners.config import DistribConfig
-from oprl.buffers.protocols import ReplayBufferProtocol
-
-
-logger = create_stdout_logger()
-
-
-class Queue:
-    def __init__(self, name: str, host: str = "localhost") -> None:
-        self._name = name
-        connection = pika.BlockingConnection(pika.ConnectionParameters(host=host))
-        self.channel = connection.channel()
-        self.channel.queue_declare(queue=name)
-
-    def push(self, data) -> None:
-        self.channel.basic_publish(exchange="", routing_key=self._name, body=data)
-
-    def pop(self) -> bytes | None:
-        method_frame, _, body = self.channel.basic_get(queue=self._name)
-        if method_frame:
-            self.channel.basic_ack(method_frame.delivery_tag)
-            return body
-        return None
-
-
-
-
-
-
-
-
diff --git a/src/oprl/distrib/policy_update_worker.py b/src/oprl/distrib/policy_update_worker.py
index ad17c8b..4dea035 100644
--- a/src/oprl/distrib/policy_update_worker.py
+++ b/src/oprl/distrib/policy_update_worker.py
@@ -84,7 +84,7 @@ def run_policy_update_worker(
                 policy=algo.actor,
                 save_path=algo.logger.log_dir / "weights" / f"epoch_{i_epoch}.w"
             )
-            logger.info(f"Weights saved.")
+            logger.info("Weights saved.")
 
     logger.info("Update by policy update worker done.")
diff --git a/src/oprl/environment/__init__.py b/src/oprl/environment/__init__.py
index 1dc2688..3e0fa36 100644
--- a/src/oprl/environment/__init__.py
+++ b/src/oprl/environment/__init__.py
@@ -3,6 +3,6 @@
 from oprl.environment.safety_gymnasium import SafetyGym
 from oprl.environment.make_env import make_env
 
-___all__ = ['DMControlEnv', 'SafetyGym', "make_env", "EnvProtocol"]
+__all__ = ["DMControlEnv", "SafetyGym", "make_env", "EnvProtocol"]
 
 
diff --git a/tests/functional/test_env.py b/tests/functional/test_env.py
index 4a38b8e..918da39 100644
--- a/tests/functional/test_env.py
+++ b/tests/functional/test_env.py
@@ -46,6 +46,12 @@
 ]
 
 
+# TODO:
+# gymnasium_robotics_envs: list[str] = [
+#     "FetchPickAndPlace-v3",
+# ]
+
+
 safety_envs: list[str] = [
     "SafetyPointGoal1-v0",
     "SafetyPointButton1-v0",
diff --git a/uv.lock b/uv.lock
new file mode 100644
index 0000000..c0ea2d1
--- /dev/null
+++ b/uv.lock
@@ -0,0 +1,1141 @@
+version = 1
+revision = 2
+requires-python = "==3.10.8"
+
+[[package]]
+name = "absl-py"
+version = "2.3.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url =
"https://files.pythonhosted.org/packages/10/2a/c93173ffa1b39c1d0395b7e842bbdc62e556ca9d8d3b5572926f3e4ca752/absl_py-2.3.1.tar.gz", hash = "sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9", size = 116588, upload-time = "2025-07-03T09:31:44.05Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/aa/ba0014cc4659328dc818a28827be78e6d97312ab0cb98105a770924dc11e/absl_py-2.3.1-py3-none-any.whl", hash = "sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d", size = 135811, upload-time = "2025-07-03T09:31:42.253Z" }, +] + +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081, upload-time = "2024-05-20T21:33:25.928Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, +] + +[[package]] +name = "attrs" +version = "25.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5a/b0/1367933a8532ee6ff8d63537de4f1177af4bff9f3e829baf7331f595bb24/attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b", size = 812032, upload-time = "2025-03-13T11:10:22.779Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3", size = 63815, upload-time = "2025-03-13T11:10:21.14Z" }, +] + +[[package]] +name = "black" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "mypy-extensions" }, + { name = "packaging" }, + { name = "pathspec" }, + { name = "platformdirs" }, + { name = "tomli" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/94/49/26a7b0f3f35da4b5a65f081943b7bcd22d7002f5f0fb8098ec1ff21cb6ef/black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666", size = 649449, upload-time = "2025-01-29T04:15:40.373Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/3b/4ba3f93ac8d90410423fdd31d7541ada9bcee1df32fb90d26de41ed40e1d/black-25.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759e7ec1e050a15f89b770cefbf91ebee8917aac5c20483bc2d80a6c3a04df32", size = 1629419, upload-time = "2025-01-29T05:37:06.642Z" }, + { url = "https://files.pythonhosted.org/packages/b4/02/0bde0485146a8a5e694daed47561785e8b77a0466ccc1f3e485d5ef2925e/black-25.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e519ecf93120f34243e6b0054db49c00a35f84f195d5bce7e9f5cfc578fc2da", size = 1461080, upload-time = "2025-01-29T05:37:09.321Z" }, + { url = "https://files.pythonhosted.org/packages/52/0e/abdf75183c830eaca7589144ff96d49bce73d7ec6ad12ef62185cc0f79a2/black-25.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:055e59b198df7ac0b7efca5ad7ff2516bca343276c466be72eb04a3bcc1f82d7", size = 1766886, upload-time = 
"2025-01-29T04:18:24.432Z" }, + { url = "https://files.pythonhosted.org/packages/dc/a6/97d8bb65b1d8a41f8a6736222ba0a334db7b7b77b8023ab4568288f23973/black-25.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:db8ea9917d6f8fc62abd90d944920d95e73c83a5ee3383493e35d271aca872e9", size = 1419404, upload-time = "2025-01-29T04:19:04.296Z" }, + { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646, upload-time = "2025-01-29T04:15:38.082Z" }, +] + +[[package]] +name = "cachetools" +version = "5.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/81/3747dad6b14fa2cf53fcf10548cf5aea6913e96fab41a3c198676f8948a5/cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4", size = 28380, upload-time = "2025-02-20T21:01:19.524Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a", size = 10080, upload-time = "2025-02-20T21:01:16.647Z" }, +] + +[[package]] +name = "certifi" +version = "2025.7.14" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b3/76/52c535bcebe74590f296d6c77c86dabf761c41980e1347a2422e4aa2ae41/certifi-2025.7.14.tar.gz", hash = "sha256:8ea99dbdfaaf2ba2f9bac77b9249ef62ec5218e7c2b2e903378ed5fccf765995", size = 163981, upload-time = "2025-07-14T03:29:28.449Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/52/34c6cf5bb9285074dc3531c437b3919e825d976fde097a7a73f79e726d03/certifi-2025.7.14-py3-none-any.whl", hash = "sha256:6b31f564a415d79ee77df69d757bb49a5bb53bd9f756cbbe24394ffd6fc1f4b2", size = 162722, upload-time = "2025-07-14T03:29:26.863Z" }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e4/33/89c2ced2b67d1c2a61c19c6751aa8902d46ce3dacb23600a283619f5a12d/charset_normalizer-3.4.2.tar.gz", hash = "sha256:5baececa9ecba31eff645232d59845c07aa030f0c81ee70184a90d35099a0e63", size = 126367, upload-time = "2025-05-02T08:34:42.01Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/28/9901804da60055b406e1a1c5ba7aac1276fb77f1dde635aabfc7fd84b8ab/charset_normalizer-3.4.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7c48ed483eb946e6c04ccbe02c6b4d1d48e51944b6db70f697e089c193404941", size = 201818, upload-time = "2025-05-02T08:31:46.725Z" }, + { url = "https://files.pythonhosted.org/packages/d9/9b/892a8c8af9110935e5adcbb06d9c6fe741b6bb02608c6513983048ba1a18/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2d318c11350e10662026ad0eb71bb51c7812fc8590825304ae0bdd4ac283acd", size = 144649, upload-time = "2025-05-02T08:31:48.889Z" }, + { url = "https://files.pythonhosted.org/packages/7b/a5/4179abd063ff6414223575e008593861d62abfc22455b5d1a44995b7c101/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9cbfacf36cb0ec2897ce0ebc5d08ca44213af24265bd56eca54bee7923c48fd6", size = 155045, upload-time = "2025-05-02T08:31:50.757Z" }, + { url = 
"https://files.pythonhosted.org/packages/3b/95/bc08c7dfeddd26b4be8c8287b9bb055716f31077c8b0ea1cd09553794665/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18dd2e350387c87dabe711b86f83c9c78af772c748904d372ade190b5c7c9d4d", size = 147356, upload-time = "2025-05-02T08:31:52.634Z" }, + { url = "https://files.pythonhosted.org/packages/a8/2d/7a5b635aa65284bf3eab7653e8b4151ab420ecbae918d3e359d1947b4d61/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8075c35cd58273fee266c58c0c9b670947c19df5fb98e7b66710e04ad4e9ff86", size = 149471, upload-time = "2025-05-02T08:31:56.207Z" }, + { url = "https://files.pythonhosted.org/packages/ae/38/51fc6ac74251fd331a8cfdb7ec57beba8c23fd5493f1050f71c87ef77ed0/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5bf4545e3b962767e5c06fe1738f951f77d27967cb2caa64c28be7c4563e162c", size = 151317, upload-time = "2025-05-02T08:31:57.613Z" }, + { url = "https://files.pythonhosted.org/packages/b7/17/edee1e32215ee6e9e46c3e482645b46575a44a2d72c7dfd49e49f60ce6bf/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:7a6ab32f7210554a96cd9e33abe3ddd86732beeafc7a28e9955cdf22ffadbab0", size = 146368, upload-time = "2025-05-02T08:31:59.468Z" }, + { url = "https://files.pythonhosted.org/packages/26/2c/ea3e66f2b5f21fd00b2825c94cafb8c326ea6240cd80a91eb09e4a285830/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:b33de11b92e9f75a2b545d6e9b6f37e398d86c3e9e9653c4864eb7e89c5773ef", size = 154491, upload-time = "2025-05-02T08:32:01.219Z" }, + { url = "https://files.pythonhosted.org/packages/52/47/7be7fa972422ad062e909fd62460d45c3ef4c141805b7078dbab15904ff7/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:8755483f3c00d6c9a77f490c17e6ab0c8729e39e6390328e42521ef175380ae6", size = 157695, upload-time = "2025-05-02T08:32:03.045Z" }, + { url = "https://files.pythonhosted.org/packages/2f/42/9f02c194da282b2b340f28e5fb60762de1151387a36842a92b533685c61e/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:68a328e5f55ec37c57f19ebb1fdc56a248db2e3e9ad769919a58672958e8f366", size = 154849, upload-time = "2025-05-02T08:32:04.651Z" }, + { url = "https://files.pythonhosted.org/packages/67/44/89cacd6628f31fb0b63201a618049be4be2a7435a31b55b5eb1c3674547a/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:21b2899062867b0e1fde9b724f8aecb1af14f2778d69aacd1a5a1853a597a5db", size = 150091, upload-time = "2025-05-02T08:32:06.719Z" }, + { url = "https://files.pythonhosted.org/packages/1f/79/4b8da9f712bc079c0f16b6d67b099b0b8d808c2292c937f267d816ec5ecc/charset_normalizer-3.4.2-cp310-cp310-win32.whl", hash = "sha256:e8082b26888e2f8b36a042a58307d5b917ef2b1cacab921ad3323ef91901c71a", size = 98445, upload-time = "2025-05-02T08:32:08.66Z" }, + { url = "https://files.pythonhosted.org/packages/7d/d7/96970afb4fb66497a40761cdf7bd4f6fca0fc7bafde3a84f836c1f57a926/charset_normalizer-3.4.2-cp310-cp310-win_amd64.whl", hash = "sha256:f69a27e45c43520f5487f27627059b64aaf160415589230992cec34c5e18a509", size = 105782, upload-time = "2025-05-02T08:32:10.46Z" }, + { url = "https://files.pythonhosted.org/packages/20/94/c5790835a017658cbfabd07f3bfb549140c3ac458cfc196323996b10095a/charset_normalizer-3.4.2-py3-none-any.whl", hash = "sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0", size = 52626, upload-time 
= "2025-05-02T08:34:40.053Z" }, +] + +[[package]] +name = "click" +version = "8.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342, upload-time = "2025-05-20T23:19:49.832Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215, upload-time = "2025-05-20T23:19:47.796Z" }, +] + +[[package]] +name = "cloudpickle" +version = "3.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/52/39/069100b84d7418bc358d81669d5748efb14b9cceacd2f9c75f550424132f/cloudpickle-3.1.1.tar.gz", hash = "sha256:b216fa8ae4019d5482a8ac3c95d8f6346115d8835911fd4aefd1a445e4242c64", size = 22113, upload-time = "2025-01-14T17:02:05.085Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/e8/64c37fadfc2816a7701fa8a6ed8d87327c7d54eacfbfb6edab14a2f2be75/cloudpickle-3.1.1-py3-none-any.whl", hash = "sha256:c8c5a44295039331ee9dad40ba100a9c7297b6f988e50e87ccdf3765a668350e", size = 20992, upload-time = "2025-01-14T17:02:02.417Z" }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + +[[package]] +name = "dm-control" +version = "1.0.11" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "dm-env" }, + { name = "dm-tree" }, + { name = "glfw" }, + { name = "labmaze" }, + { name = "lxml" }, + { name = "mujoco" }, + { name = "numpy" }, + { name = "protobuf" }, + { name = "pyopengl" }, + { name = "pyparsing" }, + { name = "requests" }, + { name = "scipy" }, + { name = "setuptools" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e7/42/fd5aecc74c747c16e98097a07053438fb2ca6d13300b1a9eb27bddaad62c/dm_control-1.0.11.tar.gz", hash = "sha256:ac222c91a34be9d9d7573a168bdce791c8a6693cb84bd3de988090a96e8df010", size = 38991406, upload-time = "2023-03-22T16:45:07.783Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/b9/11639f413e2f407e71b4fcdcc83bd5914c785180667e5eed93ce6f3c28db/dm_control-1.0.11-py3-none-any.whl", hash = "sha256:2b46def2cfc5a547f61b496fee00287fd2af52c9cd5ba7e1e7a59a6973adaad9", size = 39291059, upload-time = "2023-03-22T16:44:54.358Z" }, +] + +[[package]] +name = "dm-env" +version = "1.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "dm-tree" }, + { name = "numpy" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/62/c9/93e8d6239d5806508a2ee4b370e67c6069943ca149f59f533923737a99b7/dm-env-1.6.tar.gz", hash = "sha256:a436eb1c654c39e0c986a516cee218bea7140b510fceff63f97eb4fcff3d93de", size = 20187, upload-time = "2022-12-21T00:25:29.306Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/7e/36d548040e61337bf9182637a589c44da407a47a923ee88aec7f0e89867c/dm_env-1.6-py3-none-any.whl", hash = "sha256:0eabb6759dd453b625e041032f7ae0c1e87d4eb61b6a96b9ca586483837abf29", size = 26339, upload-time = "2022-12-21T00:25:37.128Z" }, +] + +[[package]] +name = "dm-tree" +version = "0.1.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "attrs" }, + { name = "numpy" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a6/83/ce29720ccf934c6cfa9b9c95ebbe96558386e66886626066632b5e44afed/dm_tree-0.1.9.tar.gz", hash = "sha256:a4c7db3d3935a5a2d5e4b383fc26c6b0cd6f78c6d4605d3e7b518800ecd5342b", size = 35623, upload-time = "2025-01-30T20:45:37.13Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/d2/88f685534d87072a5174fe229e77aab6b7da50092d5151ebc172f6270b5c/dm_tree-0.1.9-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5d5b28ee2e461b6af65330c143806a6d0945dcabbb8d22d2ba863e6dabd9254e", size = 173568, upload-time = "2025-03-31T08:35:38.425Z" }, + { url = "https://files.pythonhosted.org/packages/d1/6a/64924e102f559c1380263a28a751f20a1bdd18e85ea599e216feead84adf/dm_tree-0.1.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54d5616015412311df154908069fcf2c2d8786f6088a2ae3554d186cdf2b1e15", size = 146935, upload-time = "2025-01-30T20:45:16.505Z" }, + { url = "https://files.pythonhosted.org/packages/7c/79/ba0f7274164eb6bd06a36c2f8cb21b0debc32fd9ba8e73a7c9e50c90041b/dm_tree-0.1.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831699d2c60a1b38776a193b7143ae0acad0a687d87654e6d3342584166816bc", size = 152892, upload-time = "2025-01-30T20:45:18.021Z" }, + { url = "https://files.pythonhosted.org/packages/bf/20/8b96a34a15c5c4d1d6af44795963fa44381716975aabac83beab4fe80974/dm_tree-0.1.9-cp310-cp310-win_amd64.whl", hash = "sha256:1ae3cbff592bb3f2e197f5a8030de4a94e292e6cdd85adeea0b971d07a1b85f2", size = 101469, upload-time = "2025-01-30T20:45:19.197Z" }, +] + +[[package]] +name = "exceptiongroup" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10", size = 16674, upload-time = "2025-05-10T17:42:49.33Z" }, +] + +[[package]] +name = "farama-notifications" +version = "0.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2e/2c/8384832b7a6b1fd6ba95bbdcae26e7137bb3eedc955c42fd5cdcc086cfbf/Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18", size = 2131, upload-time = "2023-02-27T18:28:41.047Z" 
} +wheels = [ + { url = "https://files.pythonhosted.org/packages/05/2c/ffc08c54c05cdce6fbed2aeebc46348dbe180c6d2c541c7af7ba0aa5f5f8/Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae", size = 2511, upload-time = "2023-02-27T18:28:39.447Z" }, +] + +[[package]] +name = "filelock" +version = "3.18.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0a/10/c23352565a6544bdc5353e0b15fc1c563352101f30e24bf500207a54df9a/filelock-3.18.0.tar.gz", hash = "sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2", size = 18075, upload-time = "2025-03-14T07:11:40.47Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl", hash = "sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de", size = 16215, upload-time = "2025-03-14T07:11:39.145Z" }, +] + +[[package]] +name = "flake8" +version = "7.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mccabe" }, + { name = "pycodestyle" }, + { name = "pyflakes" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9b/af/fbfe3c4b5a657d79e5c47a2827a362f9e1b763336a52f926126aa6dc7123/flake8-7.3.0.tar.gz", hash = "sha256:fe044858146b9fc69b551a4b490d69cf960fcb78ad1edcb84e7fbb1b4a8e3872", size = 48326, upload-time = "2025-06-20T19:31:35.838Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9f/56/13ab06b4f93ca7cac71078fbe37fcea175d3216f31f85c3168a6bbd0bb9a/flake8-7.3.0-py2.py3-none-any.whl", hash = "sha256:b9696257b9ce8beb888cdbe31cf885c90d31928fe202be0889a7cdafad32f01e", size = 57922, upload-time = "2025-06-20T19:31:34.425Z" }, +] + +[[package]] +name = "fsspec" +version = "2025.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/02/0835e6ab9cfc03916fe3f78c0956cfcdb6ff2669ffa6651065d5ebf7fc98/fsspec-2025.7.0.tar.gz", hash = "sha256:786120687ffa54b8283d942929540d8bc5ccfa820deb555a2b5d0ed2b737bf58", size = 304432, upload-time = "2025-07-15T16:05:21.19Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2f/e0/014d5d9d7a4564cf1c40b5039bc882db69fd881111e03ab3657ac0b218e2/fsspec-2025.7.0-py3-none-any.whl", hash = "sha256:8b012e39f63c7d5f10474de957f3ab793b47b45ae7d39f2fb735f8bbe25c0e21", size = 199597, upload-time = "2025-07-15T16:05:19.529Z" }, +] + +[[package]] +name = "glfw" +version = "2.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/38/97/a2d667c98b8474f6b8294042488c1bd488681fb3cb4c3b9cdac1a9114287/glfw-2.9.0.tar.gz", hash = "sha256:077111a150ff09bc302c5e4ae265a5eb6aeaff0c8b01f727f7fb34e3764bb8e2", size = 31453, upload-time = "2025-04-15T15:39:54.142Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/21/71/13dd8a8d547809543d21de9438a3a76a8728fc7966d01ad9fb54599aebf5/glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-macosx_10_6_intel.whl", hash = "sha256:183da99152f63469e9263146db2eb1b6cc4ee0c4082b280743e57bd1b0a3bd70", size = 105297, upload-time = "2025-04-15T15:39:39.677Z" }, + { url = "https://files.pythonhosted.org/packages/f8/a2/45e6dceec1e0a0ffa8dd3c0ecf1e11d74639a55186243129160c6434d456/glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-macosx_11_0_arm64.whl", hash = 
"sha256:aef5b555673b9555216e4cd7bc0bdbbb9983f66c620a85ba7310cfcfda5cd38c", size = 102146, upload-time = "2025-04-15T15:39:42.354Z" }, + { url = "https://files.pythonhosted.org/packages/d2/72/b6261ed918e3747c6070fe80636c63a3c8f1c42ce122670315eeeada156f/glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-manylinux2014_aarch64.whl", hash = "sha256:fcc430cb21984afba74945b7df38a5e1a02b36c0b4a2a2bab42b4a26d7cc51d6", size = 230002, upload-time = "2025-04-15T15:39:43.933Z" }, + { url = "https://files.pythonhosted.org/packages/45/d6/7f95786332e8b798569b8e60db2ee081874cec2a62572b8ec55c309d85b7/glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-manylinux2014_x86_64.whl", hash = "sha256:7f85b58546880466ac445fc564c5c831ca93c8a99795ab8eaf0a2d521af293d7", size = 241949, upload-time = "2025-04-15T15:39:45.28Z" }, + { url = "https://files.pythonhosted.org/packages/a1/e6/093ab7874a74bba351e754f6e7748c031bd7276702135da6cbcd00e1f3e2/glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-manylinux_2_28_aarch64.whl", hash = "sha256:2123716c8086b80b797e849a534fc6f21aebca300519e57c80618a65ca8135dc", size = 231016, upload-time = "2025-04-15T15:39:46.669Z" }, + { url = "https://files.pythonhosted.org/packages/7f/ba/de3630757c7d7fc2086aaf3994926d6b869d31586e4d0c14f1666af31b93/glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-manylinux_2_28_x86_64.whl", hash = "sha256:4e11271e49eb9bc53431ade022e284d5a59abeace81fe3b178db1bf3ccc0c449", size = 243489, upload-time = "2025-04-15T15:39:48.321Z" }, + { url = "https://files.pythonhosted.org/packages/32/36/c3bada8503681806231d1705ea1802bac8febf69e4186b9f0f0b9e2e4f7e/glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-win32.whl", hash = "sha256:8e4fbff88e4e953bb969b6813195d5de4641f886530cc8083897e56b00bf2c8e", size = 552655, upload-time = "2025-04-15T15:39:50.029Z" }, + { url = "https://files.pythonhosted.org/packages/cb/70/7f2f052ca20c3b69892818f2ee1fea53b037ea9145ff75b944ed1dc4ff82/glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-win_amd64.whl", hash = "sha256:9aa3ae51601601c53838315bd2a03efb1e6bebecd072b2f64ddbd0b2556d511a", size = 559441, upload-time = "2025-04-15T15:39:52.531Z" }, +] + +[[package]] +name = "google-auth" +version = "2.40.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cachetools" }, + { name = "pyasn1-modules" }, + { name = "rsa" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/9b/e92ef23b84fa10a64ce4831390b7a4c2e53c0132568d99d4ae61d04c8855/google_auth-2.40.3.tar.gz", hash = "sha256:500c3a29adedeb36ea9cf24b8d10858e152f2412e3ca37829b3fa18e33d63b77", size = 281029, upload-time = "2025-06-04T18:04:57.577Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/63/b19553b658a1692443c62bd07e5868adaa0ad746a0751ba62c59568cd45b/google_auth-2.40.3-py2.py3-none-any.whl", hash = "sha256:1370d4593e86213563547f97a92752fc658456fe4514c809544f330fed45a7ca", size = 216137, upload-time = "2025-06-04T18:04:55.573Z" }, +] + +[[package]] +name = "google-auth-oauthlib" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "requests-oauthlib" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/fb/87/e10bf24f7bcffc1421b84d6f9c3377c30ec305d082cd737ddaa6d8f77f7c/google_auth_oauthlib-1.2.2.tar.gz", hash = "sha256:11046fb8d3348b296302dd939ace8af0a724042e8029c1b872d87fabc9f41684", size = 20955, upload-time = "2025-04-22T16:40:29.172Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ac/84/40ee070be95771acd2f4418981edb834979424565c3eec3cd88b6aa09d24/google_auth_oauthlib-1.2.2-py3-none-any.whl", hash = "sha256:fd619506f4b3908b5df17b65f39ca8d66ea56986e5472eb5978fd8f3786f00a2", size = 19072, upload-time = "2025-04-22T16:40:28.174Z" }, +] + +[[package]] +name = "grpcio" +version = "1.73.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/79/e8/b43b851537da2e2f03fa8be1aef207e5cbfb1a2e014fbb6b40d24c177cd3/grpcio-1.73.1.tar.gz", hash = "sha256:7fce2cd1c0c1116cf3850564ebfc3264fba75d3c74a7414373f1238ea365ef87", size = 12730355, upload-time = "2025-06-26T01:53:24.622Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/51/a5748ab2773d893d099b92653039672f7e26dd35741020972b84d604066f/grpcio-1.73.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:2d70f4ddd0a823436c2624640570ed6097e40935c9194482475fe8e3d9754d55", size = 5365087, upload-time = "2025-06-26T01:51:44.541Z" }, + { url = "https://files.pythonhosted.org/packages/ae/12/c5ee1a5dfe93dbc2eaa42a219e2bf887250b52e2e2ee5c036c4695f2769c/grpcio-1.73.1-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:3841a8a5a66830261ab6a3c2a3dc539ed84e4ab019165f77b3eeb9f0ba621f26", size = 10608921, upload-time = "2025-06-26T01:51:48.111Z" }, + { url = "https://files.pythonhosted.org/packages/c4/6d/b0c6a8120f02b7d15c5accda6bfc43bc92be70ada3af3ba6d8e077c00374/grpcio-1.73.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:628c30f8e77e0258ab788750ec92059fc3d6628590fb4b7cea8c102503623ed7", size = 5803221, upload-time = "2025-06-26T01:51:50.486Z" }, + { url = "https://files.pythonhosted.org/packages/a6/7a/3c886d9f1c1e416ae81f7f9c7d1995ae72cd64712d29dab74a6bafacb2d2/grpcio-1.73.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:67a0468256c9db6d5ecb1fde4bf409d016f42cef649323f0a08a72f352d1358b", size = 6444603, upload-time = "2025-06-26T01:51:52.203Z" }, + { url = "https://files.pythonhosted.org/packages/42/07/f143a2ff534982c9caa1febcad1c1073cdec732f6ac7545d85555a900a7e/grpcio-1.73.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68b84d65bbdebd5926eb5c53b0b9ec3b3f83408a30e4c20c373c5337b4219ec5", size = 6040969, upload-time = "2025-06-26T01:51:55.028Z" }, + { url = "https://files.pythonhosted.org/packages/fb/0f/523131b7c9196d0718e7b2dac0310eb307b4117bdbfef62382e760f7e8bb/grpcio-1.73.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c54796ca22b8349cc594d18b01099e39f2b7ffb586ad83217655781a350ce4da", size = 6132201, upload-time = "2025-06-26T01:51:56.867Z" }, + { url = "https://files.pythonhosted.org/packages/ad/18/010a055410eef1d3a7a1e477ec9d93b091ac664ad93e9c5f56d6cc04bdee/grpcio-1.73.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:75fc8e543962ece2f7ecd32ada2d44c0c8570ae73ec92869f9af8b944863116d", size = 6774718, upload-time = "2025-06-26T01:51:58.338Z" }, + { url = "https://files.pythonhosted.org/packages/16/11/452bfc1ab39d8ee748837ab8ee56beeae0290861052948785c2c445fb44b/grpcio-1.73.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6a6037891cd2b1dd1406b388660522e1565ed340b1fea2955b0234bdd941a862", size = 6304362, upload-time = "2025-06-26T01:51:59.802Z" }, + { url 
= "https://files.pythonhosted.org/packages/1e/1c/c75ceee626465721e5cb040cf4b271eff817aa97388948660884cb7adffa/grpcio-1.73.1-cp310-cp310-win32.whl", hash = "sha256:cce7265b9617168c2d08ae570fcc2af4eaf72e84f8c710ca657cc546115263af", size = 3679036, upload-time = "2025-06-26T01:52:01.817Z" }, + { url = "https://files.pythonhosted.org/packages/62/2e/42cb31b6cbd671a7b3dbd97ef33f59088cf60e3cf2141368282e26fafe79/grpcio-1.73.1-cp310-cp310-win_amd64.whl", hash = "sha256:6a2b372e65fad38842050943f42ce8fee00c6f2e8ea4f7754ba7478d26a356ee", size = 4340208, upload-time = "2025-06-26T01:52:03.674Z" }, +] + +[[package]] +name = "gymnasium" +version = "0.28.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cloudpickle" }, + { name = "farama-notifications" }, + { name = "jax-jumpy" }, + { name = "numpy" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/58/6a/c304954dc009648a21db245a8f56f63c8da8a025d446dd0fd67319726003/gymnasium-0.28.1.tar.gz", hash = "sha256:4c2c745808792c8f45c6e88ad0a5504774394e0c126f6e3db555e720d3da6f24", size = 796462, upload-time = "2023-03-25T12:02:00.613Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/82/3762ef4555791a729ae554e13c011efe5e8347d7eba9ea5ed245a8d1b234/gymnasium-0.28.1-py3-none-any.whl", hash = "sha256:7bc9a5bce1022f997d1dbc152fc91d1ac977bad9cc7794cdc25437010867cabf", size = 925534, upload-time = "2023-03-25T12:01:58.35Z" }, +] + +[[package]] +name = "idna" +version = "3.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490, upload-time = "2024-09-15T18:07:39.745Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, +] + +[[package]] +name = "iniconfig" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, +] + +[[package]] +name = "jax-jumpy" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/6a/b6affff68f172a4c8316d9ab9b7d952e865df15b854f158690991864e0fe/jax-jumpy-1.0.0.tar.gz", hash = "sha256:195fb955cc4c2b7f0b1453e3cb1fb1c414a51a407ffac7a51e69a73cb30d59ad", size = 19417, upload-time = "2023-03-17T16:52:56.598Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/23/338caee543d80584916da20f018aeb017764509d964fd347b97f41f97baa/jax_jumpy-1.0.0-py3-none-any.whl", hash = "sha256:ab7e01454bba462de3c4d098e3e585c302a8f06bc36d9182ab4e7e4aa7067c5e", size = 20368, 
upload-time = "2023-03-17T16:52:55.437Z" }, +] + +[[package]] +name = "jinja2" +version = "3.1.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, +] + +[[package]] +name = "labmaze" +version = "1.0.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "numpy" }, + { name = "setuptools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/93/0a/139c4ae896b9413bd4ca69c62b08ee98dcfc78a9cbfdb7cadd0dce2ad31d/labmaze-1.0.6.tar.gz", hash = "sha256:2e8de7094042a77d6972f1965cf5c9e8f971f1b34d225752f343190a825ebe73", size = 4670455, upload-time = "2022-12-05T18:42:43.566Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/0c/6a3941f48644c0b9305c7a22bd51974be1fed8e9233b16c893d728805143/labmaze-1.0.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b2ddef976dfd8d992b19cfa6c633f2eba7576d759c2082da534e3f727479a84a", size = 4815423, upload-time = "2022-12-05T18:41:47.351Z" }, + { url = "https://files.pythonhosted.org/packages/d0/fe/b038c6a15732eb064767dc92ca39a38b2f5df183576384f0cfb6a4840f69/labmaze-1.0.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:157efaa93228c8ccce5cae337902dd652093e0fba9d3a0f6506e4bee272bb66f", size = 4806825, upload-time = "2022-12-05T18:41:49.922Z" }, + { url = "https://files.pythonhosted.org/packages/59/ec/2762281d4f26845b20bb7529742a6914fcb07c8e7c522175b879df0127cf/labmaze-1.0.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3ce98b9541c5fe6a306e411e7d018121dd646f2c9978d763fad86f9f30c5f57", size = 4871532, upload-time = "2022-12-05T18:41:52.784Z" }, + { url = "https://files.pythonhosted.org/packages/4d/93/abac7877e1d7de984a2f0f5be561ff0dc795ae7e22595cf2f7c7032cd27e/labmaze-1.0.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e6433bd49bc541791de8191040526fddfebb77151620eb04203453f43ee486a", size = 4875892, upload-time = "2022-12-05T18:41:55.603Z" }, + { url = "https://files.pythonhosted.org/packages/c4/10/5262db11b3c1db8e4fbc3feed9baed4f95db6047b8d9dcaf4f9fb8da9ba3/labmaze-1.0.6-cp310-cp310-win_amd64.whl", hash = "sha256:6a507fc35961f1b1479708e2716f65e0d0611cefb55f31a77be29ce2339b6fef", size = 4812953, upload-time = "2022-12-05T18:41:58.098Z" }, +] + +[[package]] +name = "lxml" +version = "6.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c5/ed/60eb6fa2923602fba988d9ca7c5cdbd7cf25faa795162ed538b527a35411/lxml-6.0.0.tar.gz", hash = "sha256:032e65120339d44cdc3efc326c9f660f5f7205f3a535c1fdbf898b29ea01fb72", size = 4096938, upload-time = "2025-06-26T16:28:19.373Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/e9/9c3ca02fbbb7585116c2e274b354a2d92b5c70561687dd733ec7b2018490/lxml-6.0.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:35bc626eec405f745199200ccb5c6b36f202675d204aa29bb52e27ba2b71dea8", size = 8399057, 
upload-time = "2025-06-26T16:25:02.169Z" }, + { url = "https://files.pythonhosted.org/packages/86/25/10a6e9001191854bf283515020f3633b1b1f96fd1b39aa30bf8fff7aa666/lxml-6.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:246b40f8a4aec341cbbf52617cad8ab7c888d944bfe12a6abd2b1f6cfb6f6082", size = 4569676, upload-time = "2025-06-26T16:25:05.431Z" }, + { url = "https://files.pythonhosted.org/packages/f5/a5/378033415ff61d9175c81de23e7ad20a3ffb614df4ffc2ffc86bc6746ffd/lxml-6.0.0-cp310-cp310-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:2793a627e95d119e9f1e19720730472f5543a6d84c50ea33313ce328d870f2dd", size = 5291361, upload-time = "2025-06-26T16:25:07.901Z" }, + { url = "https://files.pythonhosted.org/packages/5a/a6/19c87c4f3b9362b08dc5452a3c3bce528130ac9105fc8fff97ce895ce62e/lxml-6.0.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:46b9ed911f36bfeb6338e0b482e7fe7c27d362c52fde29f221fddbc9ee2227e7", size = 5008290, upload-time = "2025-06-28T18:47:13.196Z" }, + { url = "https://files.pythonhosted.org/packages/09/d1/e9b7ad4b4164d359c4d87ed8c49cb69b443225cb495777e75be0478da5d5/lxml-6.0.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2b4790b558bee331a933e08883c423f65bbcd07e278f91b2272489e31ab1e2b4", size = 5163192, upload-time = "2025-06-28T18:47:17.279Z" }, + { url = "https://files.pythonhosted.org/packages/56/d6/b3eba234dc1584744b0b374a7f6c26ceee5dc2147369a7e7526e25a72332/lxml-6.0.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e2030956cf4886b10be9a0285c6802e078ec2391e1dd7ff3eb509c2c95a69b76", size = 5076973, upload-time = "2025-06-26T16:25:10.936Z" }, + { url = "https://files.pythonhosted.org/packages/8e/47/897142dd9385dcc1925acec0c4afe14cc16d310ce02c41fcd9010ac5d15d/lxml-6.0.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d23854ecf381ab1facc8f353dcd9adeddef3652268ee75297c1164c987c11dc", size = 5297795, upload-time = "2025-06-26T16:25:14.282Z" }, + { url = "https://files.pythonhosted.org/packages/fb/db/551ad84515c6f415cea70193a0ff11d70210174dc0563219f4ce711655c6/lxml-6.0.0-cp310-cp310-manylinux_2_31_armv7l.whl", hash = "sha256:43fe5af2d590bf4691531b1d9a2495d7aab2090547eaacd224a3afec95706d76", size = 4776547, upload-time = "2025-06-26T16:25:17.123Z" }, + { url = "https://files.pythonhosted.org/packages/e0/14/c4a77ab4f89aaf35037a03c472f1ccc54147191888626079bd05babd6808/lxml-6.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:74e748012f8c19b47f7d6321ac929a9a94ee92ef12bc4298c47e8b7219b26541", size = 5124904, upload-time = "2025-06-26T16:25:19.485Z" }, + { url = "https://files.pythonhosted.org/packages/70/b4/12ae6a51b8da106adec6a2e9c60f532350a24ce954622367f39269e509b1/lxml-6.0.0-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:43cfbb7db02b30ad3926e8fceaef260ba2fb7df787e38fa2df890c1ca7966c3b", size = 4805804, upload-time = "2025-06-26T16:25:21.949Z" }, + { url = "https://files.pythonhosted.org/packages/a9/b6/2e82d34d49f6219cdcb6e3e03837ca5fb8b7f86c2f35106fb8610ac7f5b8/lxml-6.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:34190a1ec4f1e84af256495436b2d196529c3f2094f0af80202947567fdbf2e7", size = 5323477, upload-time = "2025-06-26T16:25:24.475Z" }, + { url = "https://files.pythonhosted.org/packages/a1/e6/b83ddc903b05cd08a5723fefd528eee84b0edd07bdf87f6c53a1fda841fd/lxml-6.0.0-cp310-cp310-win32.whl", hash = "sha256:5967fe415b1920a3877a4195e9a2b779249630ee49ece22021c690320ff07452", size = 3613840, 
upload-time = "2025-06-26T16:25:27.345Z" }, + { url = "https://files.pythonhosted.org/packages/40/af/874fb368dd0c663c030acb92612341005e52e281a102b72a4c96f42942e1/lxml-6.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:f3389924581d9a770c6caa4df4e74b606180869043b9073e2cec324bad6e306e", size = 3993584, upload-time = "2025-06-26T16:25:29.391Z" }, + { url = "https://files.pythonhosted.org/packages/4a/f4/d296bc22c17d5607653008f6dd7b46afdfda12efd31021705b507df652bb/lxml-6.0.0-cp310-cp310-win_arm64.whl", hash = "sha256:522fe7abb41309e9543b0d9b8b434f2b630c5fdaf6482bee642b34c8c70079c8", size = 3681400, upload-time = "2025-06-26T16:25:31.421Z" }, + { url = "https://files.pythonhosted.org/packages/66/e1/2c22a3cff9e16e1d717014a1e6ec2bf671bf56ea8716bb64466fcf820247/lxml-6.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:dbdd7679a6f4f08152818043dbb39491d1af3332128b3752c3ec5cebc0011a72", size = 3898804, upload-time = "2025-06-26T16:27:59.751Z" }, + { url = "https://files.pythonhosted.org/packages/2b/3a/d68cbcb4393a2a0a867528741fafb7ce92dac5c9f4a1680df98e5e53e8f5/lxml-6.0.0-pp310-pypy310_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:40442e2a4456e9910875ac12951476d36c0870dcb38a68719f8c4686609897c4", size = 4216406, upload-time = "2025-06-28T18:47:45.518Z" }, + { url = "https://files.pythonhosted.org/packages/15/8f/d9bfb13dff715ee3b2a1ec2f4a021347ea3caf9aba93dea0cfe54c01969b/lxml-6.0.0-pp310-pypy310_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:db0efd6bae1c4730b9c863fc4f5f3c0fa3e8f05cae2c44ae141cb9dfc7d091dc", size = 4326455, upload-time = "2025-06-28T18:47:48.411Z" }, + { url = "https://files.pythonhosted.org/packages/01/8b/fde194529ee8a27e6f5966d7eef05fa16f0567e4a8e8abc3b855ef6b3400/lxml-6.0.0-pp310-pypy310_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9ab542c91f5a47aaa58abdd8ea84b498e8e49fe4b883d67800017757a3eb78e8", size = 4268788, upload-time = "2025-06-26T16:28:02.776Z" }, + { url = "https://files.pythonhosted.org/packages/99/a8/3b8e2581b4f8370fc9e8dc343af4abdfadd9b9229970fc71e67bd31c7df1/lxml-6.0.0-pp310-pypy310_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:013090383863b72c62a702d07678b658fa2567aa58d373d963cca245b017e065", size = 4411394, upload-time = "2025-06-26T16:28:05.179Z" }, + { url = "https://files.pythonhosted.org/packages/e7/a5/899a4719e02ff4383f3f96e5d1878f882f734377f10dfb69e73b5f223e44/lxml-6.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:c86df1c9af35d903d2b52d22ea3e66db8058d21dc0f59842ca5deb0595921141", size = 3517946, upload-time = "2025-06-26T16:28:07.665Z" }, +] + +[[package]] +name = "markdown" +version = "3.8.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/c2/4ab49206c17f75cb08d6311171f2d65798988db4360c4d1485bd0eedd67c/markdown-3.8.2.tar.gz", hash = "sha256:247b9a70dd12e27f67431ce62523e675b866d254f900c4fe75ce3dda62237c45", size = 362071, upload-time = "2025-06-19T17:12:44.483Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/96/2b/34cc11786bc00d0f04d0f5fdc3a2b1ae0b6239eef72d3d345805f9ad92a1/markdown-3.8.2-py3-none-any.whl", hash = "sha256:5c83764dbd4e00bdd94d85a19b8d55ccca20fe35b2e678a1422b380324dd5f24", size = 106827, upload-time = "2025-06-19T17:12:42.994Z" }, +] + +[[package]] +name = "markupsafe" +version = "3.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/b2/97/5d42485e71dfc078108a86d6de8fa46db44a1a9295e89c5d6d4a06e23a62/markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0", size = 20537, upload-time = "2024-10-18T15:21:54.129Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/90/d08277ce111dd22f77149fd1a5d4653eeb3b3eaacbdfcbae5afb2600eebd/MarkupSafe-3.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8", size = 14357, upload-time = "2024-10-18T15:20:51.44Z" }, + { url = "https://files.pythonhosted.org/packages/04/e1/6e2194baeae0bca1fae6629dc0cbbb968d4d941469cbab11a3872edff374/MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158", size = 12393, upload-time = "2024-10-18T15:20:52.426Z" }, + { url = "https://files.pythonhosted.org/packages/1d/69/35fa85a8ece0a437493dc61ce0bb6d459dcba482c34197e3efc829aa357f/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38a9ef736c01fccdd6600705b09dc574584b89bea478200c5fbf112a6b0d5579", size = 21732, upload-time = "2024-10-18T15:20:53.578Z" }, + { url = "https://files.pythonhosted.org/packages/22/35/137da042dfb4720b638d2937c38a9c2df83fe32d20e8c8f3185dbfef05f7/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbcb445fa71794da8f178f0f6d66789a28d7319071af7a496d4d507ed566270d", size = 20866, upload-time = "2024-10-18T15:20:55.06Z" }, + { url = "https://files.pythonhosted.org/packages/29/28/6d029a903727a1b62edb51863232152fd335d602def598dade38996887f0/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57cb5a3cf367aeb1d316576250f65edec5bb3be939e9247ae594b4bcbc317dfb", size = 20964, upload-time = "2024-10-18T15:20:55.906Z" }, + { url = "https://files.pythonhosted.org/packages/cc/cd/07438f95f83e8bc028279909d9c9bd39e24149b0d60053a97b2bc4f8aa51/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:3809ede931876f5b2ec92eef964286840ed3540dadf803dd570c3b7e13141a3b", size = 21977, upload-time = "2024-10-18T15:20:57.189Z" }, + { url = "https://files.pythonhosted.org/packages/29/01/84b57395b4cc062f9c4c55ce0df7d3108ca32397299d9df00fedd9117d3d/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e07c3764494e3776c602c1e78e298937c3315ccc9043ead7e685b7f2b8d47b3c", size = 21366, upload-time = "2024-10-18T15:20:58.235Z" }, + { url = "https://files.pythonhosted.org/packages/bd/6e/61ebf08d8940553afff20d1fb1ba7294b6f8d279df9fd0c0db911b4bbcfd/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b424c77b206d63d500bcb69fa55ed8d0e6a3774056bdc4839fc9298a7edca171", size = 21091, upload-time = "2024-10-18T15:20:59.235Z" }, + { url = "https://files.pythonhosted.org/packages/11/23/ffbf53694e8c94ebd1e7e491de185124277964344733c45481f32ede2499/MarkupSafe-3.0.2-cp310-cp310-win32.whl", hash = "sha256:fcabf5ff6eea076f859677f5f0b6b5c1a51e70a376b0579e0eadef8db48c6b50", size = 15065, upload-time = "2024-10-18T15:21:00.307Z" }, + { url = "https://files.pythonhosted.org/packages/44/06/e7175d06dd6e9172d4a69a72592cb3f7a996a9c396eee29082826449bbc3/MarkupSafe-3.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:6af100e168aa82a50e186c82875a5893c5597a0c1ccdb0d8b40240b1f28b969a", size = 15514, upload-time = "2024-10-18T15:21:01.122Z" }, +] + +[[package]] +name = "mccabe" +version = "0.7.0" 
+source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/ff/0ffefdcac38932a54d2b5eed4e0ba8a408f215002cd178ad1df0f2806ff8/mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325", size = 9658, upload-time = "2022-01-24T01:14:51.113Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/1a/1f68f9ba0c207934b35b86a8ca3aad8395a3d6dd7921c0686e23853ff5a9/mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e", size = 7350, upload-time = "2022-01-24T01:14:49.62Z" }, +] + +[[package]] +name = "mpmath" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/47/dd32fa426cc72114383ac549964eecb20ecfd886d1e5ccf5340b55b02f57/mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", size = 508106, upload-time = "2023-03-07T16:47:11.061Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" }, +] + +[[package]] +name = "mujoco" +version = "2.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "glfw" }, + { name = "numpy" }, + { name = "pyopengl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7c/03/7ce6078745085febd22fc4586a63016a71125031b97883d456ce1d64e5ed/mujoco-2.3.3.zip", hash = "sha256:8bd074d3c5d9d25416cf2a5b82b337a7431a6e20edbd0da7fbc05ee5255c1aaa", size = 633278, upload-time = "2023-03-20T18:23:59.89Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/21/2c/f59255b4fbd159a374c4077721bc5baac11afcc15b8a28f8c7658ac89df5/mujoco-2.3.3-2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:7b95a0b7ae8bb9e36d04ba475a950025791c43087845235bb92bd2dd1787589a", size = 4308671, upload-time = "2023-03-20T18:09:30.955Z" }, + { url = "https://files.pythonhosted.org/packages/31/20/afc0ef5d5b9d96f3853458329f39518e638f41c0439c94ae2b97a9ab9f3a/mujoco-2.3.3-2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8b9f8e9e6d47fe60f96dc54be780a66a38d5ec2ff94d091ad54c6f87468b4b6a", size = 4177357, upload-time = "2023-03-20T18:09:34.016Z" }, + { url = "https://files.pythonhosted.org/packages/6d/da/27b0ef31aa23f64c21e7129ddf378c2548be45d87433f89e2bea4cafc811/mujoco-2.3.3-2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9cbd3b60ac30f0b6661de050ca3e1de906b9c28ed9d084bdd708c141953247d", size = 3995762, upload-time = "2023-03-20T18:09:36.335Z" }, + { url = "https://files.pythonhosted.org/packages/6d/27/90cc9b4f88c5b797417e1fbeacb7590cd85f7e464a8ab79f60c885708e39/mujoco-2.3.3-2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c7fd195bca86102788d86dbc2773c59c1f0ee3933b35eb0f6a86f0f2aeb5065", size = 4270849, upload-time = "2023-03-20T18:09:38.593Z" }, + { url = "https://files.pythonhosted.org/packages/cd/b3/e9119ebbbe9ea830e6c8ab7eafc0de7c82b38d6b71d2b2e38ba20e43a1b7/mujoco-2.3.3-2-cp310-cp310-win_amd64.whl", hash = "sha256:f3595e992770eff3f842cb80f7eb2b7b1b3e78995b6ecc247f98036da17ef74f", size = 3190540, upload-time = "2023-03-20T18:09:41.101Z" }, +] + +[[package]] +name = "mypy-extensions" +version = "1.1.0" +source = { registry = 
"https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, +] + +[[package]] +name = "networkx" +version = "3.4.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263, upload-time = "2024-10-21T12:39:36.247Z" }, +] + +[[package]] +name = "numpy" +version = "1.26.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/65/6e/09db70a523a96d25e115e71cc56a6f9031e7b8cd166c1ac8438307c14058/numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010", size = 15786129, upload-time = "2024-02-06T00:26:44.495Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/94/ace0fdea5241a27d13543ee117cbc65868e82213fb31a8eb7fe9ff23f313/numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0", size = 20631468, upload-time = "2024-02-05T23:48:01.194Z" }, + { url = "https://files.pythonhosted.org/packages/20/f7/b24208eba89f9d1b58c1668bc6c8c4fd472b20c45573cb767f59d49fb0f6/numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a", size = 13966411, upload-time = "2024-02-05T23:48:29.038Z" }, + { url = "https://files.pythonhosted.org/packages/fc/a5/4beee6488160798683eed5bdb7eead455892c3b4e1f78d79d8d3f3b084ac/numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4", size = 14219016, upload-time = "2024-02-05T23:48:54.098Z" }, + { url = "https://files.pythonhosted.org/packages/4b/d7/ecf66c1cd12dc28b4040b15ab4d17b773b87fa9d29ca16125de01adb36cd/numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f", size = 18240889, upload-time = "2024-02-05T23:49:25.361Z" }, + { url = "https://files.pythonhosted.org/packages/24/03/6f229fe3187546435c4f6f89f6d26c129d4f5bed40552899fcf1f0bf9e50/numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a", size = 13876746, upload-time = "2024-02-05T23:49:51.983Z" }, + { url = "https://files.pythonhosted.org/packages/39/fe/39ada9b094f01f5a35486577c848fe274e374bbf8d8f472e1423a0bbd26d/numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = 
"sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2", size = 18078620, upload-time = "2024-02-05T23:50:22.515Z" }, + { url = "https://files.pythonhosted.org/packages/d5/ef/6ad11d51197aad206a9ad2286dc1aac6a378059e06e8cf22cd08ed4f20dc/numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07", size = 5972659, upload-time = "2024-02-05T23:50:35.834Z" }, + { url = "https://files.pythonhosted.org/packages/19/77/538f202862b9183f54108557bfda67e17603fc560c384559e769321c9d92/numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5", size = 15808905, upload-time = "2024-02-05T23:51:03.701Z" }, +] + +[[package]] +name = "nvidia-cublas-cu12" +version = "12.1.3.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/37/6d/121efd7382d5b0284239f4ab1fc1590d86d34ed4a4a2fdb13b30ca8e5740/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728", size = 410594774, upload-time = "2023-04-19T15:50:03.519Z" }, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.1.105" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/00/6b218edd739ecfc60524e585ba8e6b00554dd908de2c9c66c1af3e44e18d/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e", size = 14109015, upload-time = "2023-04-19T15:47:32.502Z" }, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.1.105" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b6/9f/c64c03f49d6fbc56196664d05dba14e3a561038a81a638eeb47f4d4cfd48/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2", size = 23671734, upload-time = "2023-04-19T15:48:32.42Z" }, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.1.105" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/d5/c68b1d2cdfcc59e72e8a5949a37ddb22ae6cade80cd4a57a84d4c8b55472/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40", size = 823596, upload-time = "2023-04-19T15:47:22.471Z" }, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "8.9.2.26" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/ff/74/a2e2be7fb83aaedec84f391f082cf765dfb635e7caa9b49065f73e4835d8/nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9", size = 731725872, upload-time = "2023-06-01T19:24:57.328Z" }, +] + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.0.2.54" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/86/94/eb540db023ce1d162e7bea9f8f5aa781d57c65aed513c33ee9a5123ead4d/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56", size = 121635161, upload-time = "2023-04-19T15:50:46Z" }, +] + +[[package]] 
+name = "nvidia-curand-cu12" +version = "10.3.2.106" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/31/4890b1c9abc496303412947fc7dcea3d14861720642b49e8ceed89636705/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0", size = 56467784, upload-time = "2023-04-19T15:51:04.804Z" }, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.4.5.107" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cusparse-cu12" }, + { name = "nvidia-nvjitlink-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928, upload-time = "2023-04-19T15:51:25.781Z" }, +] + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.1.0.106" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278, upload-time = "2023-04-19T15:51:49.939Z" }, +] + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.19.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/38/00/d0d4e48aef772ad5aebcf70b73028f88db6e5640b36c38e90445b7a57c45/nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl", hash = "sha256:a9734707a2c96443331c1e48c717024aa6678a0e2a4cb66b2c364d18cee6b48d", size = 165987969, upload-time = "2023-10-24T16:16:24.789Z" }, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.9.86" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/0c/c75bbfb967457a0b7670b8ad267bfc4fffdf341c074e0a80db06c24ccfd4/nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:e3f1171dbdc83c5932a45f0f4c99180a70de9bd2718c1ab77d14104f6d7147f9", size = 39748338, upload-time = "2025-06-05T20:10:25.613Z" }, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.1.105" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/da/d3/8057f0587683ed2fcd4dbfbdfdfa807b9160b809976099d36b8f60d08f03/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5", size = 99138, upload-time = "2023-04-19T15:48:43.556Z" }, +] + +[[package]] +name = "oauthlib" +version = "3.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/5f/19930f824ffeb0ad4372da4812c50edbd1434f678c90c2733e1188edfc63/oauthlib-3.3.1.tar.gz", hash = "sha256:0f0f8aa759826a193cf66c12ea1af1637f87b9b4622d46e866952bb022e538c9", size = 185918, upload-time = "2025-06-19T22:48:08.269Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/9c/92789c596b8df838baa98fa71844d84283302f7604ed565dafe5a6b5041a/oauthlib-3.3.1-py3-none-any.whl", hash = 
"sha256:88119c938d2b8fb88561af5f6ee0eec8cc8d552b7bb1f712743136eb7523b7a1", size = 160065, upload-time = "2025-06-19T22:48:06.508Z" }, +] + +[[package]] +name = "oprl" +version = "0.1.0" +source = { editable = "." } +dependencies = [ + { name = "dm-control" }, + { name = "gymnasium" }, + { name = "mujoco" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pika" }, + { name = "pydantic-settings" }, + { name = "ruff" }, + { name = "tensorboard" }, + { name = "torch" }, +] + +[package.optional-dependencies] +dev = [ + { name = "black" }, + { name = "flake8" }, + { name = "pytest" }, +] + +[package.metadata] +requires-dist = [ + { name = "black", marker = "extra == 'dev'" }, + { name = "dm-control", specifier = "==1.0.11" }, + { name = "flake8", marker = "extra == 'dev'" }, + { name = "gymnasium", specifier = "==0.28.1" }, + { name = "mujoco", specifier = "==2.3.3" }, + { name = "numpy", specifier = "==1.26.4" }, + { name = "packaging", specifier = "==23.2" }, + { name = "pika", specifier = "==1.3.2" }, + { name = "pydantic-settings", specifier = "==2.10.1" }, + { name = "pytest", marker = "extra == 'dev'", specifier = ">=6.0" }, + { name = "ruff", specifier = ">=0.12.3" }, + { name = "tensorboard", specifier = "==2.15.1" }, + { name = "torch", specifier = "==2.2.2" }, +] +provides-extras = ["dev"] + +[[package]] +name = "packaging" +version = "23.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fb/2b/9b9c33ffed44ee921d0967086d653047286054117d584f1b1a7c22ceaf7b/packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5", size = 146714, upload-time = "2023-10-01T13:50:05.279Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/1a/610693ac4ee14fcdf2d9bf3c493370e4f2ef7ae2e19217d7a237ff42367d/packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7", size = 53011, upload-time = "2023-10-01T13:50:03.745Z" }, +] + +[[package]] +name = "pathspec" +version = "0.12.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" }, +] + +[[package]] +name = "pika" +version = "1.3.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/db/db/d4102f356af18f316c67f2cead8ece307f731dd63140e2c71f170ddacf9b/pika-1.3.2.tar.gz", hash = "sha256:b2a327ddddf8570b4965b3576ac77091b850262d34ce8c1d8cb4e4146aa4145f", size = 145029, upload-time = "2023-05-05T14:25:43.368Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/f3/f412836ec714d36f0f4ab581b84c491e3f42c6b5b97a6c6ed1817f3c16d0/pika-1.3.2-py3-none-any.whl", hash = "sha256:0779a7c1fafd805672796085560d290213a465e4f6f76a6fb19e378d8041a14f", size = 155415, upload-time = "2023-05-05T14:25:41.484Z" }, +] + +[[package]] +name = "platformdirs" +version = "4.3.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/fe/8b/3c73abc9c759ecd3f1f7ceff6685840859e8070c4d947c93fae71f6a0bf2/platformdirs-4.3.8.tar.gz", hash = "sha256:3d512d96e16bcb959a814c9f348431070822a6496326a4be0911c40b5a74c2bc", size = 21362, upload-time = "2025-05-07T22:47:42.121Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/39/979e8e21520d4e47a0bbe349e2713c0aac6f3d853d0e5b34d76206c439aa/platformdirs-4.3.8-py3-none-any.whl", hash = "sha256:ff7059bb7eb1179e2685604f4aaf157cfd9535242bd23742eadc3c13542139b4", size = 18567, upload-time = "2025-05-07T22:47:40.376Z" }, +] + +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + +[[package]] +name = "protobuf" +version = "4.23.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d3/1c/de86d82a5fc780feca36ef52c1231823bb3140266af8a04ed6286957aa6e/protobuf-4.23.4.tar.gz", hash = "sha256:ccd9430c0719dce806b93f89c91de7977304729e55377f872a92465d548329a9", size = 400173, upload-time = "2023-07-06T23:28:22.071Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/a7/872807299eb114956c665fb1717ce106a8874db08a724651ac4f78c1198c/protobuf-4.23.4-cp310-abi3-win32.whl", hash = "sha256:5fea3c64d41ea5ecf5697b83e41d09b9589e6f20b677ab3c48e5f242d9b7897b", size = 402981, upload-time = "2023-07-06T23:27:57.401Z" }, + { url = "https://files.pythonhosted.org/packages/80/70/dc63d340d27b8ff22022d7dd14b8d6d68b479a003eacdc4507150a286d9a/protobuf-4.23.4-cp310-abi3-win_amd64.whl", hash = "sha256:7b19b6266d92ca6a2a87effa88ecc4af73ebc5cfde194dc737cf8ef23a9a3b12", size = 422467, upload-time = "2023-07-06T23:28:00.387Z" }, + { url = "https://files.pythonhosted.org/packages/cb/d3/a164038605494d49acc4f9cda1c0bc200b96382c53edd561387263bb181d/protobuf-4.23.4-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:8547bf44fe8cec3c69e3042f5c4fb3e36eb2a7a013bb0a44c018fc1e427aafbd", size = 400308, upload-time = "2023-07-06T23:28:02.356Z" }, + { url = "https://files.pythonhosted.org/packages/71/42/3a7fc57f360f728f38eca6656e8d00edaf22bc0ffc35dd2936f23e5fbb3e/protobuf-4.23.4-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:fee88269a090ada09ca63551bf2f573eb2424035bcf2cb1b121895b01a46594a", size = 303455, upload-time = "2023-07-06T23:28:04.292Z" }, + { url = "https://files.pythonhosted.org/packages/01/cb/445b3e465abdb8042a41957dc8f60c54620dc7540dbcf9b458a921531ca2/protobuf-4.23.4-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:effeac51ab79332d44fba74660d40ae79985901ac21bca408f8dc335a81aa597", size = 304498, upload-time = "2023-07-06T23:28:06.277Z" }, + { url = "https://files.pythonhosted.org/packages/b0/07/fb712cce15ba456f7c24b82b97c8a7db2233f07037ffe61c9011660c592a/protobuf-4.23.4-py3-none-any.whl", hash = "sha256:e9d0be5bf34b275b9f87ba7407796556abeeba635455d036c7351f7c183ef8ff", size = 173332, upload-time = "2023-07-06T23:28:20.053Z" }, +] + +[[package]] +name = "pyasn1" 
+version = "0.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322, upload-time = "2024-09-10T22:41:42.55Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135, upload-time = "2024-09-11T16:00:36.122Z" }, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892, upload-time = "2025-03-28T02:41:22.17Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload-time = "2025-03-28T02:41:19.028Z" }, +] + +[[package]] +name = "pycodestyle" +version = "2.14.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/11/e0/abfd2a0d2efe47670df87f3e3a0e2edda42f055053c85361f19c0e2c1ca8/pycodestyle-2.14.0.tar.gz", hash = "sha256:c4b5b517d278089ff9d0abdec919cd97262a3367449ea1c8b49b91529167b783", size = 39472, upload-time = "2025-06-20T18:49:48.75Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/27/a58ddaf8c588a3ef080db9d0b7e0b97215cee3a45df74f3a94dbbf5c893a/pycodestyle-2.14.0-py2.py3-none-any.whl", hash = "sha256:dd6bf7cb4ee77f8e016f9c8e74a35ddd9f67e1d5fd4184d86c3b98e07099f42d", size = 31594, upload-time = "2025-06-20T18:49:47.491Z" }, +] + +[[package]] +name = "pydantic" +version = "2.11.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/dd/4325abf92c39ba8623b5af936ddb36ffcfe0beae70405d456ab1fb2f5b8c/pydantic-2.11.7.tar.gz", hash = "sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db", size = 788350, upload-time = "2025-06-14T08:33:17.137Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/c0/ec2b1c8712ca690e5d61979dee872603e92b8a32f94cc1b72d53beab008a/pydantic-2.11.7-py3-none-any.whl", hash = "sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b", size = 444782, upload-time = "2025-06-14T08:33:14.905Z" }, +] + +[[package]] +name = "pydantic-core" +version = "2.33.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ad/88/5f2260bdfae97aabf98f1778d43f69574390ad787afb646292a638c923d4/pydantic_core-2.33.2.tar.gz", hash = "sha256:7cb8bc3605c29176e1b105350d2e6474142d7c1bd1d9327c4a9bdb46bf827acc", size = 435195, upload-time = "2025-04-23T18:33:52.104Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/e5/92/b31726561b5dae176c2d2c2dc43a9c5bfba5d32f96f8b4c0a600dd492447/pydantic_core-2.33.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2b3d326aaef0c0399d9afffeb6367d5e26ddc24d351dbc9c636840ac355dc5d8", size = 2028817, upload-time = "2025-04-23T18:30:43.919Z" }, + { url = "https://files.pythonhosted.org/packages/a3/44/3f0b95fafdaca04a483c4e685fe437c6891001bf3ce8b2fded82b9ea3aa1/pydantic_core-2.33.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e5b2671f05ba48b94cb90ce55d8bdcaaedb8ba00cc5359f6810fc918713983d", size = 1861357, upload-time = "2025-04-23T18:30:46.372Z" }, + { url = "https://files.pythonhosted.org/packages/30/97/e8f13b55766234caae05372826e8e4b3b96e7b248be3157f53237682e43c/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0069c9acc3f3981b9ff4cdfaf088e98d83440a4c7ea1bc07460af3d4dc22e72d", size = 1898011, upload-time = "2025-04-23T18:30:47.591Z" }, + { url = "https://files.pythonhosted.org/packages/9b/a3/99c48cf7bafc991cc3ee66fd544c0aae8dc907b752f1dad2d79b1b5a471f/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d53b22f2032c42eaaf025f7c40c2e3b94568ae077a606f006d206a463bc69572", size = 1982730, upload-time = "2025-04-23T18:30:49.328Z" }, + { url = "https://files.pythonhosted.org/packages/de/8e/a5b882ec4307010a840fb8b58bd9bf65d1840c92eae7534c7441709bf54b/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0405262705a123b7ce9f0b92f123334d67b70fd1f20a9372b907ce1080c7ba02", size = 2136178, upload-time = "2025-04-23T18:30:50.907Z" }, + { url = "https://files.pythonhosted.org/packages/e4/bb/71e35fc3ed05af6834e890edb75968e2802fe98778971ab5cba20a162315/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b25d91e288e2c4e0662b8038a28c6a07eaac3e196cfc4ff69de4ea3db992a1b", size = 2736462, upload-time = "2025-04-23T18:30:52.083Z" }, + { url = "https://files.pythonhosted.org/packages/31/0d/c8f7593e6bc7066289bbc366f2235701dcbebcd1ff0ef8e64f6f239fb47d/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6bdfe4b3789761f3bcb4b1ddf33355a71079858958e3a552f16d5af19768fef2", size = 2005652, upload-time = "2025-04-23T18:30:53.389Z" }, + { url = "https://files.pythonhosted.org/packages/d2/7a/996d8bd75f3eda405e3dd219ff5ff0a283cd8e34add39d8ef9157e722867/pydantic_core-2.33.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:efec8db3266b76ef9607c2c4c419bdb06bf335ae433b80816089ea7585816f6a", size = 2113306, upload-time = "2025-04-23T18:30:54.661Z" }, + { url = "https://files.pythonhosted.org/packages/ff/84/daf2a6fb2db40ffda6578a7e8c5a6e9c8affb251a05c233ae37098118788/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:031c57d67ca86902726e0fae2214ce6770bbe2f710dc33063187a68744a5ecac", size = 2073720, upload-time = "2025-04-23T18:30:56.11Z" }, + { url = "https://files.pythonhosted.org/packages/77/fb/2258da019f4825128445ae79456a5499c032b55849dbd5bed78c95ccf163/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:f8de619080e944347f5f20de29a975c2d815d9ddd8be9b9b7268e2e3ef68605a", size = 2244915, upload-time = "2025-04-23T18:30:57.501Z" }, + { url = "https://files.pythonhosted.org/packages/d8/7a/925ff73756031289468326e355b6fa8316960d0d65f8b5d6b3a3e7866de7/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = 
"sha256:73662edf539e72a9440129f231ed3757faab89630d291b784ca99237fb94db2b", size = 2241884, upload-time = "2025-04-23T18:30:58.867Z" }, + { url = "https://files.pythonhosted.org/packages/0b/b0/249ee6d2646f1cdadcb813805fe76265745c4010cf20a8eba7b0e639d9b2/pydantic_core-2.33.2-cp310-cp310-win32.whl", hash = "sha256:0a39979dcbb70998b0e505fb1556a1d550a0781463ce84ebf915ba293ccb7e22", size = 1910496, upload-time = "2025-04-23T18:31:00.078Z" }, + { url = "https://files.pythonhosted.org/packages/66/ff/172ba8f12a42d4b552917aa65d1f2328990d3ccfc01d5b7c943ec084299f/pydantic_core-2.33.2-cp310-cp310-win_amd64.whl", hash = "sha256:b0379a2b24882fef529ec3b4987cb5d003b9cda32256024e6fe1586ac45fc640", size = 1955019, upload-time = "2025-04-23T18:31:01.335Z" }, + { url = "https://files.pythonhosted.org/packages/30/68/373d55e58b7e83ce371691f6eaa7175e3a24b956c44628eb25d7da007917/pydantic_core-2.33.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5c4aa4e82353f65e548c476b37e64189783aa5384903bfea4f41580f255fddfa", size = 2023982, upload-time = "2025-04-23T18:32:53.14Z" }, + { url = "https://files.pythonhosted.org/packages/a4/16/145f54ac08c96a63d8ed6442f9dec17b2773d19920b627b18d4f10a061ea/pydantic_core-2.33.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d946c8bf0d5c24bf4fe333af284c59a19358aa3ec18cb3dc4370080da1e8ad29", size = 1858412, upload-time = "2025-04-23T18:32:55.52Z" }, + { url = "https://files.pythonhosted.org/packages/41/b1/c6dc6c3e2de4516c0bb2c46f6a373b91b5660312342a0cf5826e38ad82fa/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:87b31b6846e361ef83fedb187bb5b4372d0da3f7e28d85415efa92d6125d6e6d", size = 1892749, upload-time = "2025-04-23T18:32:57.546Z" }, + { url = "https://files.pythonhosted.org/packages/12/73/8cd57e20afba760b21b742106f9dbdfa6697f1570b189c7457a1af4cd8a0/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa9d91b338f2df0508606f7009fde642391425189bba6d8c653afd80fd6bb64e", size = 2067527, upload-time = "2025-04-23T18:32:59.771Z" }, + { url = "https://files.pythonhosted.org/packages/e3/d5/0bb5d988cc019b3cba4a78f2d4b3854427fc47ee8ec8e9eaabf787da239c/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2058a32994f1fde4ca0480ab9d1e75a0e8c87c22b53a3ae66554f9af78f2fe8c", size = 2108225, upload-time = "2025-04-23T18:33:04.51Z" }, + { url = "https://files.pythonhosted.org/packages/f1/c5/00c02d1571913d496aabf146106ad8239dc132485ee22efe08085084ff7c/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:0e03262ab796d986f978f79c943fc5f620381be7287148b8010b4097f79a39ec", size = 2069490, upload-time = "2025-04-23T18:33:06.391Z" }, + { url = "https://files.pythonhosted.org/packages/22/a8/dccc38768274d3ed3a59b5d06f59ccb845778687652daa71df0cab4040d7/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:1a8695a8d00c73e50bff9dfda4d540b7dee29ff9b8053e38380426a85ef10052", size = 2237525, upload-time = "2025-04-23T18:33:08.44Z" }, + { url = "https://files.pythonhosted.org/packages/d4/e7/4f98c0b125dda7cf7ccd14ba936218397b44f50a56dd8c16a3091df116c3/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:fa754d1850735a0b0e03bcffd9d4b4343eb417e47196e4485d9cca326073a42c", size = 2238446, upload-time = "2025-04-23T18:33:10.313Z" }, + { url = 
"https://files.pythonhosted.org/packages/ce/91/2ec36480fdb0b783cd9ef6795753c1dea13882f2e68e73bce76ae8c21e6a/pydantic_core-2.33.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a11c8d26a50bfab49002947d3d237abe4d9e4b5bdc8846a63537b6488e197808", size = 2066678, upload-time = "2025-04-23T18:33:12.224Z" }, +] + +[[package]] +name = "pydantic-settings" +version = "2.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/85/1ea668bbab3c50071ca613c6ab30047fb36ab0da1b92fa8f17bbc38fd36c/pydantic_settings-2.10.1.tar.gz", hash = "sha256:06f0062169818d0f5524420a360d632d5857b83cffd4d42fe29597807a1614ee", size = 172583, upload-time = "2025-06-24T13:26:46.841Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/58/f0/427018098906416f580e3cf1366d3b1abfb408a0652e9f31600c24a1903c/pydantic_settings-2.10.1-py3-none-any.whl", hash = "sha256:a60952460b99cf661dc25c29c0ef171721f98bfcb52ef8d9ea4c943d7c8cc796", size = 45235, upload-time = "2025-06-24T13:26:45.485Z" }, +] + +[[package]] +name = "pyflakes" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/45/dc/fd034dc20b4b264b3d015808458391acbf9df40b1e54750ef175d39180b1/pyflakes-3.4.0.tar.gz", hash = "sha256:b24f96fafb7d2ab0ec5075b7350b3d2d2218eab42003821c06344973d3ea2f58", size = 64669, upload-time = "2025-06-20T18:45:27.834Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/2f/81d580a0fb83baeb066698975cb14a618bdbed7720678566f1b046a95fe8/pyflakes-3.4.0-py2.py3-none-any.whl", hash = "sha256:f742a7dbd0d9cb9ea41e9a24a918996e8170c799fa528688d40dd582c8265f4f", size = 63551, upload-time = "2025-06-20T18:45:26.937Z" }, +] + +[[package]] +name = "pygments" +version = "2.19.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, +] + +[[package]] +name = "pyopengl" +version = "3.1.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c0/42/71080db298df3ddb7e3090bfea8fd7c300894d8b10954c22f8719bd434eb/pyopengl-3.1.9.tar.gz", hash = "sha256:28ebd82c5f4491a418aeca9672dffb3adbe7d33b39eada4548a5b4e8c03f60c8", size = 1913642, upload-time = "2025-01-20T02:17:53.263Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/92/44/8634af40b0db528b5b37e901c0dc67321354880d251bf8965901d57693a5/PyOpenGL-3.1.9-py3-none-any.whl", hash = "sha256:15995fd3b0deb991376805da36137a4ae5aba6ddbb5e29ac1f35462d130a3f77", size = 3190341, upload-time = "2025-01-20T02:17:50.913Z" }, +] + +[[package]] +name = "pyparsing" +version = "3.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bb/22/f1129e69d94ffff626bdb5c835506b3a5b4f3d070f17ea295e12c2c6f60f/pyparsing-3.2.3.tar.gz", hash = 
"sha256:b9c13f1ab8b3b542f72e28f634bad4de758ab3ce4546e4301970ad6fa77c38be", size = 1088608, upload-time = "2025-03-25T05:01:28.114Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl", hash = "sha256:a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf", size = 111120, upload-time = "2025-03-25T05:01:24.908Z" }, +] + +[[package]] +name = "pytest" +version = "8.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "exceptiongroup" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, + { name = "tomli" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, +] + +[[package]] +name = "python-dotenv" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f6/b0/4bc07ccd3572a2f9df7e6782f52b0c6c90dcbb803ac4a167702d7d0dfe1e/python_dotenv-1.1.1.tar.gz", hash = "sha256:a8a6399716257f45be6a007360200409fce5cda2661e3dec71d23dc15f6189ab", size = 41978, upload-time = "2025-06-24T04:21:07.341Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5f/ed/539768cf28c661b5b068d66d96a2f155c4971a5d55684a514c1a0e0dec2f/python_dotenv-1.1.1-py3-none-any.whl", hash = "sha256:31f23644fe2602f88ff55e1f5c79ba497e01224ee7737937930c448e4d0e24dc", size = 20556, upload-time = "2025-06-24T04:21:06.073Z" }, +] + +[[package]] +name = "requests" +version = "2.32.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e1/0a/929373653770d8a0d7ea76c37de6e41f11eb07559b103b1c02cafb3f7cf8/requests-2.32.4.tar.gz", hash = "sha256:27d0316682c8a29834d3264820024b62a36942083d52caf2f14c0591336d3422", size = 135258, upload-time = "2025-06-09T16:43:07.34Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c", size = 64847, upload-time = "2025-06-09T16:43:05.728Z" }, +] + +[[package]] +name = "requests-oauthlib" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "oauthlib" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/f2/05f29bc3913aea15eb670be136045bf5c5bbf4b99ecb839da9b422bb2c85/requests-oauthlib-2.0.0.tar.gz", hash = "sha256:b3dffaebd884d8cd778494369603a9e7b58d29111bf6b41bdc2dcd87203af4e9", size = 55650, upload-time = "2024-03-22T20:32:29.939Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/3b/5d/63d4ae3b9daea098d5d6f5da83984853c1bbacd5dc826764b249fe119d24/requests_oauthlib-2.0.0-py2.py3-none-any.whl", hash = "sha256:7dd8a5c40426b779b0868c404bdef9768deccf22749cde15852df527e6269b36", size = 24179, upload-time = "2024-03-22T20:32:28.055Z" }, +] + +[[package]] +name = "rsa" +version = "4.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/da/8a/22b7beea3ee0d44b1916c0c1cb0ee3af23b700b6da9f04991899d0c555d4/rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75", size = 29034, upload-time = "2025-04-16T09:51:18.218Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/64/8d/0133e4eb4beed9e425d9a98ed6e081a55d195481b7632472be1af08d2f6b/rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762", size = 34696, upload-time = "2025-04-16T09:51:17.142Z" }, +] + +[[package]] +name = "ruff" +version = "0.12.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/2a/43955b530c49684d3c38fcda18c43caf91e99204c2a065552528e0552d4f/ruff-0.12.3.tar.gz", hash = "sha256:f1b5a4b6668fd7b7ea3697d8d98857390b40c1320a63a178eee6be0899ea2d77", size = 4459341, upload-time = "2025-07-11T13:21:16.086Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/fd/b44c5115539de0d598d75232a1cc7201430b6891808df111b8b0506aae43/ruff-0.12.3-py3-none-linux_armv6l.whl", hash = "sha256:47552138f7206454eaf0c4fe827e546e9ddac62c2a3d2585ca54d29a890137a2", size = 10430499, upload-time = "2025-07-11T13:20:26.321Z" }, + { url = "https://files.pythonhosted.org/packages/43/c5/9eba4f337970d7f639a37077be067e4ec80a2ad359e4cc6c5b56805cbc66/ruff-0.12.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:0a9153b000c6fe169bb307f5bd1b691221c4286c133407b8827c406a55282041", size = 11213413, upload-time = "2025-07-11T13:20:30.017Z" }, + { url = "https://files.pythonhosted.org/packages/e2/2c/fac3016236cf1fe0bdc8e5de4f24c76ce53c6dd9b5f350d902549b7719b2/ruff-0.12.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa6b24600cf3b750e48ddb6057e901dd5b9aa426e316addb2a1af185a7509882", size = 10586941, upload-time = "2025-07-11T13:20:33.046Z" }, + { url = "https://files.pythonhosted.org/packages/c5/0f/41fec224e9dfa49a139f0b402ad6f5d53696ba1800e0f77b279d55210ca9/ruff-0.12.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2506961bf6ead54887ba3562604d69cb430f59b42133d36976421bc8bd45901", size = 10783001, upload-time = "2025-07-11T13:20:35.534Z" }, + { url = "https://files.pythonhosted.org/packages/0d/ca/dd64a9ce56d9ed6cad109606ac014860b1c217c883e93bf61536400ba107/ruff-0.12.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c4faaff1f90cea9d3033cbbcdf1acf5d7fb11d8180758feb31337391691f3df0", size = 10269641, upload-time = "2025-07-11T13:20:38.459Z" }, + { url = "https://files.pythonhosted.org/packages/63/5c/2be545034c6bd5ce5bb740ced3e7014d7916f4c445974be11d2a406d5088/ruff-0.12.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40dced4a79d7c264389de1c59467d5d5cefd79e7e06d1dfa2c75497b5269a5a6", size = 11875059, upload-time = "2025-07-11T13:20:41.517Z" }, + { url = "https://files.pythonhosted.org/packages/8e/d4/a74ef1e801ceb5855e9527dae105eaff136afcb9cc4d2056d44feb0e4792/ruff-0.12.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = 
"sha256:0262d50ba2767ed0fe212aa7e62112a1dcbfd46b858c5bf7bbd11f326998bafc", size = 12658890, upload-time = "2025-07-11T13:20:44.442Z" }, + { url = "https://files.pythonhosted.org/packages/13/c8/1057916416de02e6d7c9bcd550868a49b72df94e3cca0aeb77457dcd9644/ruff-0.12.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12371aec33e1a3758597c5c631bae9a5286f3c963bdfb4d17acdd2d395406687", size = 12232008, upload-time = "2025-07-11T13:20:47.374Z" }, + { url = "https://files.pythonhosted.org/packages/f5/59/4f7c130cc25220392051fadfe15f63ed70001487eca21d1796db46cbcc04/ruff-0.12.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:560f13b6baa49785665276c963edc363f8ad4b4fc910a883e2625bdb14a83a9e", size = 11499096, upload-time = "2025-07-11T13:20:50.348Z" }, + { url = "https://files.pythonhosted.org/packages/d4/01/a0ad24a5d2ed6be03a312e30d32d4e3904bfdbc1cdbe63c47be9d0e82c79/ruff-0.12.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:023040a3499f6f974ae9091bcdd0385dd9e9eb4942f231c23c57708147b06311", size = 11688307, upload-time = "2025-07-11T13:20:52.945Z" }, + { url = "https://files.pythonhosted.org/packages/93/72/08f9e826085b1f57c9a0226e48acb27643ff19b61516a34c6cab9d6ff3fa/ruff-0.12.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:883d844967bffff5ab28bba1a4d246c1a1b2933f48cb9840f3fdc5111c603b07", size = 10661020, upload-time = "2025-07-11T13:20:55.799Z" }, + { url = "https://files.pythonhosted.org/packages/80/a0/68da1250d12893466c78e54b4a0ff381370a33d848804bb51279367fc688/ruff-0.12.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2120d3aa855ff385e0e562fdee14d564c9675edbe41625c87eeab744a7830d12", size = 10246300, upload-time = "2025-07-11T13:20:58.222Z" }, + { url = "https://files.pythonhosted.org/packages/6a/22/5f0093d556403e04b6fd0984fc0fb32fbb6f6ce116828fd54306a946f444/ruff-0.12.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6b16647cbb470eaf4750d27dddc6ebf7758b918887b56d39e9c22cce2049082b", size = 11263119, upload-time = "2025-07-11T13:21:01.503Z" }, + { url = "https://files.pythonhosted.org/packages/92/c9/f4c0b69bdaffb9968ba40dd5fa7df354ae0c73d01f988601d8fac0c639b1/ruff-0.12.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e1417051edb436230023575b149e8ff843a324557fe0a265863b7602df86722f", size = 11746990, upload-time = "2025-07-11T13:21:04.524Z" }, + { url = "https://files.pythonhosted.org/packages/fe/84/7cc7bd73924ee6be4724be0db5414a4a2ed82d06b30827342315a1be9e9c/ruff-0.12.3-py3-none-win32.whl", hash = "sha256:dfd45e6e926deb6409d0616078a666ebce93e55e07f0fb0228d4b2608b2c248d", size = 10589263, upload-time = "2025-07-11T13:21:07.148Z" }, + { url = "https://files.pythonhosted.org/packages/07/87/c070f5f027bd81f3efee7d14cb4d84067ecf67a3a8efb43aadfc72aa79a6/ruff-0.12.3-py3-none-win_amd64.whl", hash = "sha256:a946cf1e7ba3209bdef039eb97647f1c77f6f540e5845ec9c114d3af8df873e7", size = 11695072, upload-time = "2025-07-11T13:21:11.004Z" }, + { url = "https://files.pythonhosted.org/packages/e0/30/f3eaf6563c637b6e66238ed6535f6775480db973c836336e4122161986fc/ruff-0.12.3-py3-none-win_arm64.whl", hash = "sha256:5f9c7c9c8f84c2d7f27e93674d27136fbf489720251544c4da7fb3d742e011b1", size = 10805855, upload-time = "2025-07-11T13:21:13.547Z" }, +] + +[[package]] +name = "scipy" +version = "1.15.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0f/37/6964b830433e654ec7485e45a00fc9a27cf868d622838f6b6d9c5ec0d532/scipy-1.15.3.tar.gz", hash 
= "sha256:eae3cf522bc7df64b42cad3925c876e1b0b6c35c1337c93e12c0f366f55b0eaf", size = 59419214, upload-time = "2025-05-08T16:13:05.955Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/2f/4966032c5f8cc7e6a60f1b2e0ad686293b9474b65246b0c642e3ef3badd0/scipy-1.15.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:a345928c86d535060c9c2b25e71e87c39ab2f22fc96e9636bd74d1dbf9de448c", size = 38702770, upload-time = "2025-05-08T16:04:20.849Z" }, + { url = "https://files.pythonhosted.org/packages/a0/6e/0c3bf90fae0e910c274db43304ebe25a6b391327f3f10b5dcc638c090795/scipy-1.15.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:ad3432cb0f9ed87477a8d97f03b763fd1d57709f1bbde3c9369b1dff5503b253", size = 30094511, upload-time = "2025-05-08T16:04:27.103Z" }, + { url = "https://files.pythonhosted.org/packages/ea/b1/4deb37252311c1acff7f101f6453f0440794f51b6eacb1aad4459a134081/scipy-1.15.3-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:aef683a9ae6eb00728a542b796f52a5477b78252edede72b8327a886ab63293f", size = 22368151, upload-time = "2025-05-08T16:04:31.731Z" }, + { url = "https://files.pythonhosted.org/packages/38/7d/f457626e3cd3c29b3a49ca115a304cebb8cc6f31b04678f03b216899d3c6/scipy-1.15.3-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:1c832e1bd78dea67d5c16f786681b28dd695a8cb1fb90af2e27580d3d0967e92", size = 25121732, upload-time = "2025-05-08T16:04:36.596Z" }, + { url = "https://files.pythonhosted.org/packages/db/0a/92b1de4a7adc7a15dcf5bddc6e191f6f29ee663b30511ce20467ef9b82e4/scipy-1.15.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:263961f658ce2165bbd7b99fa5135195c3a12d9bef045345016b8b50c315cb82", size = 35547617, upload-time = "2025-05-08T16:04:43.546Z" }, + { url = "https://files.pythonhosted.org/packages/8e/6d/41991e503e51fc1134502694c5fa7a1671501a17ffa12716a4a9151af3df/scipy-1.15.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2abc762b0811e09a0d3258abee2d98e0c703eee49464ce0069590846f31d40", size = 37662964, upload-time = "2025-05-08T16:04:49.431Z" }, + { url = "https://files.pythonhosted.org/packages/25/e1/3df8f83cb15f3500478c889be8fb18700813b95e9e087328230b98d547ff/scipy-1.15.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ed7284b21a7a0c8f1b6e5977ac05396c0d008b89e05498c8b7e8f4a1423bba0e", size = 37238749, upload-time = "2025-05-08T16:04:55.215Z" }, + { url = "https://files.pythonhosted.org/packages/93/3e/b3257cf446f2a3533ed7809757039016b74cd6f38271de91682aa844cfc5/scipy-1.15.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5380741e53df2c566f4d234b100a484b420af85deb39ea35a1cc1be84ff53a5c", size = 40022383, upload-time = "2025-05-08T16:05:01.914Z" }, + { url = "https://files.pythonhosted.org/packages/d1/84/55bc4881973d3f79b479a5a2e2df61c8c9a04fcb986a213ac9c02cfb659b/scipy-1.15.3-cp310-cp310-win_amd64.whl", hash = "sha256:9d61e97b186a57350f6d6fd72640f9e99d5a4a2b8fbf4b9ee9a841eab327dc13", size = 41259201, upload-time = "2025-05-08T16:05:08.166Z" }, +] + +[[package]] +name = "setuptools" +version = "80.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/18/5d/3bf57dcd21979b887f014ea83c24ae194cfcd12b9e0fda66b957c69d1fca/setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c", size = 1319958, upload-time = "2025-05-27T00:56:51.443Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486, upload-time = "2025-05-27T00:56:49.664Z" }, +] + +[[package]] +name = "six" +version = "1.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031, upload-time = "2024-12-04T17:35:28.174Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, +] + +[[package]] +name = "sympy" +version = "1.14.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mpmath" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921, upload-time = "2025-04-27T18:05:01.611Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, +] + +[[package]] +name = "tensorboard" +version = "2.15.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "google-auth" }, + { name = "google-auth-oauthlib" }, + { name = "grpcio" }, + { name = "markdown" }, + { name = "numpy" }, + { name = "protobuf" }, + { name = "requests" }, + { name = "setuptools" }, + { name = "six" }, + { name = "tensorboard-data-server" }, + { name = "werkzeug" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/6e/0c/1059a6682cf2cc1fcc0d5327837b5672fe4f5574255fa5430d0a8ceb75e9/tensorboard-2.15.1-py3-none-any.whl", hash = "sha256:c46c1d1cf13a458c429868a78b2531d8ff5f682058d69ec0840b0bc7a38f1c0f", size = 5539710, upload-time = "2023-11-02T20:49:50.813Z" }, +] + +[[package]] +name = "tensorboard-data-server" +version = "0.7.2" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/13/e503968fefabd4c6b2650af21e110aa8466fe21432cd7c43a84577a89438/tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb", size = 2356, upload-time = "2023-10-23T21:23:32.16Z" }, + { url = "https://files.pythonhosted.org/packages/b7/85/dabeaf902892922777492e1d253bb7e1264cadce3cea932f7ff599e53fea/tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60", size = 4823598, upload-time = "2023-10-23T21:23:33.714Z" }, + { url = "https://files.pythonhosted.org/packages/73/c6/825dab04195756cf8ff2e12698f22513b3db2f64925bdd41671bfb33aaa5/tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530", size = 6590363, upload-time = 
"2023-10-23T21:23:35.583Z" }, +] + +[[package]] +name = "tomli" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/18/87/302344fed471e44a87289cf4967697d07e532f2421fdaf868a303cbae4ff/tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff", size = 17175, upload-time = "2024-11-27T22:38:36.873Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257, upload-time = "2024-11-27T22:38:35.385Z" }, +] + +[[package]] +name = "torch" +version = "2.2.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "fsspec" }, + { name = "jinja2" }, + { name = "networkx" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "sympy" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "typing-extensions" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/b3/1fcc3bccfddadfd6845dcbfe26eb4b099f1dfea5aa0e5cfb92b3c98dba5b/torch-2.2.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:bc889d311a855dd2dfd164daf8cc903a6b7273a747189cebafdd89106e4ad585", size = 755526581, upload-time = "2024-03-27T21:06:46.5Z" }, + { url = "https://files.pythonhosted.org/packages/c3/7c/aeb0c5789a3f10cf909640530cd75b314959b9d9914a4996ed2c7bf8779d/torch-2.2.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:15dffa4cc3261fa73d02f0ed25f5fa49ecc9e12bf1ae0a4c1e7a88bbfaad9030", size = 86623646, upload-time = "2024-03-27T21:10:22.719Z" }, + { url = "https://files.pythonhosted.org/packages/3a/81/684d99e536b20e869a7c1222cf1dd233311fb05d3628e9570992bfb65760/torch-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:11e8fe261233aeabd67696d6b993eeb0896faa175c6b41b9a6c9f0334bdad1c5", size = 198579616, upload-time = "2024-03-27T21:10:15.41Z" }, + { url = "https://files.pythonhosted.org/packages/3b/55/7192974ab13e5e5577f45d14ce70d42f5a9a686b4f57bbe8c9ab45c4a61a/torch-2.2.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:b2e2200b245bd9f263a0d41b6a2dab69c4aca635a01b30cca78064b0ef5b109e", size = 150788930, upload-time = "2024-03-27T21:08:09.98Z" }, + { url = 
"https://files.pythonhosted.org/packages/33/6b/21496316c9b8242749ee2a9064406271efdf979e91d440e8a3806b5e84bf/torch-2.2.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:877b3e6593b5e00b35bbe111b7057464e76a7dd186a287280d941b564b0563c2", size = 59707286, upload-time = "2024-03-27T21:10:28.154Z" }, +] + +[[package]] +name = "tqdm" +version = "4.67.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, +] + +[[package]] +name = "triton" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/05/ed974ce87fe8c8843855daa2136b3409ee1c126707ab54a8b72815c08b49/triton-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2294514340cfe4e8f4f9e5c66c702744c4a117d25e618bd08469d0bfed1e2e5", size = 167900779, upload-time = "2024-01-10T03:11:56.576Z" }, +] + +[[package]] +name = "typing-extensions" +version = "4.14.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/98/5a/da40306b885cc8c09109dc2e1abd358d5684b1425678151cdaed4731c822/typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36", size = 107673, upload-time = "2025-07-04T13:28:34.16Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b5/00/d631e67a838026495268c2f6884f3711a15a9a2a96cd244fdaea53b823fb/typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76", size = 43906, upload-time = "2025-07-04T13:28:32.743Z" }, +] + +[[package]] +name = "typing-inspection" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f8/b1/0c11f5058406b3af7609f121aaa6b609744687f1d158b3c3a5bf4cc94238/typing_inspection-0.4.1.tar.gz", hash = "sha256:6ae134cc0203c33377d43188d4064e9b357dba58cff3185f22924610e70a9d28", size = 75726, upload-time = "2025-05-21T18:55:23.885Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/69/cd203477f944c353c31bade965f880aa1061fd6bf05ded0726ca845b6ff7/typing_inspection-0.4.1-py3-none-any.whl", hash = "sha256:389055682238f53b04f7badcb49b989835495a96700ced5dab2d8feae4b26f51", size = 14552, upload-time = "2025-05-21T18:55:22.152Z" }, +] + +[[package]] +name = "urllib3" +version = "2.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/15/22/9ee70a2574a4f4599c47dd506532914ce044817c7752a79b6a51286319bc/urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760", size = 393185, upload-time = "2025-06-18T14:07:41.644Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, +] + +[[package]] +name = "werkzeug" +version = "3.1.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9f/69/83029f1f6300c5fb2471d621ab06f6ec6b3324685a2ce0f9777fd4a8b71e/werkzeug-3.1.3.tar.gz", hash = "sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746", size = 806925, upload-time = "2024-11-08T15:52:18.093Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/24/ab44c871b0f07f491e5d2ad12c9bd7358e527510618cb1b803a88e986db1/werkzeug-3.1.3-py3-none-any.whl", hash = "sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e", size = 224498, upload-time = "2024-11-08T15:52:16.132Z" }, +] + +[[package]] +name = "wrapt" +version = "1.17.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/fc/e91cc220803d7bc4db93fb02facd8461c37364151b8494762cc88b0fbcef/wrapt-1.17.2.tar.gz", hash = "sha256:41388e9d4d1522446fe79d3213196bd9e3b301a336965b9e27ca2788ebd122f3", size = 55531, upload-time = "2025-01-14T10:35:45.465Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/d1/1daec934997e8b160040c78d7b31789f19b122110a75eca3d4e8da0049e1/wrapt-1.17.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3d57c572081fed831ad2d26fd430d565b76aa277ed1d30ff4d40670b1c0dd984", size = 53307, upload-time = "2025-01-14T10:33:13.616Z" }, + { url = "https://files.pythonhosted.org/packages/1b/7b/13369d42651b809389c1a7153baa01d9700430576c81a2f5c5e460df0ed9/wrapt-1.17.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5e251054542ae57ac7f3fba5d10bfff615b6c2fb09abeb37d2f1463f841ae22", size = 38486, upload-time = "2025-01-14T10:33:15.947Z" }, + { url = "https://files.pythonhosted.org/packages/62/bf/e0105016f907c30b4bd9e377867c48c34dc9c6c0c104556c9c9126bd89ed/wrapt-1.17.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:80dd7db6a7cb57ffbc279c4394246414ec99537ae81ffd702443335a61dbf3a7", size = 38777, upload-time = "2025-01-14T10:33:17.462Z" }, + { url = "https://files.pythonhosted.org/packages/27/70/0f6e0679845cbf8b165e027d43402a55494779295c4b08414097b258ac87/wrapt-1.17.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a6e821770cf99cc586d33833b2ff32faebdbe886bd6322395606cf55153246c", size = 83314, upload-time = "2025-01-14T10:33:21.282Z" }, + { url = "https://files.pythonhosted.org/packages/0f/77/0576d841bf84af8579124a93d216f55d6f74374e4445264cb378a6ed33eb/wrapt-1.17.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b60fb58b90c6d63779cb0c0c54eeb38941bae3ecf7a73c764c52c88c2dcb9d72", size = 74947, upload-time = "2025-01-14T10:33:24.414Z" }, + { url = "https://files.pythonhosted.org/packages/90/ec/00759565518f268ed707dcc40f7eeec38637d46b098a1f5143bff488fe97/wrapt-1.17.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b870b5df5b71d8c3359d21be8f0d6c485fa0ebdb6477dda51a1ea54a9b558061", size = 82778, upload-time = "2025-01-14T10:33:26.152Z" }, + { url = 
"https://files.pythonhosted.org/packages/f8/5a/7cffd26b1c607b0b0c8a9ca9d75757ad7620c9c0a9b4a25d3f8a1480fafc/wrapt-1.17.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4011d137b9955791f9084749cba9a367c68d50ab8d11d64c50ba1688c9b457f2", size = 81716, upload-time = "2025-01-14T10:33:27.372Z" }, + { url = "https://files.pythonhosted.org/packages/7e/09/dccf68fa98e862df7e6a60a61d43d644b7d095a5fc36dbb591bbd4a1c7b2/wrapt-1.17.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:1473400e5b2733e58b396a04eb7f35f541e1fb976d0c0724d0223dd607e0f74c", size = 74548, upload-time = "2025-01-14T10:33:28.52Z" }, + { url = "https://files.pythonhosted.org/packages/b7/8e/067021fa3c8814952c5e228d916963c1115b983e21393289de15128e867e/wrapt-1.17.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3cedbfa9c940fdad3e6e941db7138e26ce8aad38ab5fe9dcfadfed9db7a54e62", size = 81334, upload-time = "2025-01-14T10:33:29.643Z" }, + { url = "https://files.pythonhosted.org/packages/4b/0d/9d4b5219ae4393f718699ca1c05f5ebc0c40d076f7e65fd48f5f693294fb/wrapt-1.17.2-cp310-cp310-win32.whl", hash = "sha256:582530701bff1dec6779efa00c516496968edd851fba224fbd86e46cc6b73563", size = 36427, upload-time = "2025-01-14T10:33:30.832Z" }, + { url = "https://files.pythonhosted.org/packages/72/6a/c5a83e8f61aec1e1aeef939807602fb880e5872371e95df2137142f5c58e/wrapt-1.17.2-cp310-cp310-win_amd64.whl", hash = "sha256:58705da316756681ad3c9c73fd15499aa4d8c69f9fd38dc8a35e06c12468582f", size = 38774, upload-time = "2025-01-14T10:33:32.897Z" }, + { url = "https://files.pythonhosted.org/packages/2d/82/f56956041adef78f849db6b289b282e72b55ab8045a75abad81898c28d19/wrapt-1.17.2-py3-none-any.whl", hash = "sha256:b18f2d1533a71f069c7f82d524a52599053d4c7166e9dd374ae2136b7f40f7c8", size = 23594, upload-time = "2025-01-14T10:35:44.018Z" }, +] From 164b6e374e7d84da0691c119720806bc96543f8b Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Thu, 17 Jul 2025 20:04:53 +0300 Subject: [PATCH 21/30] Add uv to dockerfile and ci --- .github/workflows/unit-tests-docker.yml | 2 +- Dockerfile | 17 ++++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/.github/workflows/unit-tests-docker.yml b/.github/workflows/unit-tests-docker.yml index 7adeafd..fa63c00 100644 --- a/.github/workflows/unit-tests-docker.yml +++ b/.github/workflows/unit-tests-docker.yml @@ -20,7 +20,7 @@ jobs: - name: Extract coverage run: | docker run --rm -v $(pwd):/host oprl sh -c " - pytest tests/functional --cov=oprl --cov-report=xml && + uv run pytest tests/functional --cov=oprl --cov-report=xml && cp coverage.xml /host/ " diff --git a/Dockerfile b/Dockerfile index 688df96..1f1d26b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,19 +1,22 @@ FROM python:3.10.8 +COPY --from=ghcr.io/astral-sh/uv:0.7.21 /uv /uvx /bin/ WORKDIR /app -RUN pip install --no-cache-dir --upgrade pip +# RUN pip install --no-cache-dir --upgrade pip + + +COPY . . + +RUN uv sync --locked && uv pip install pytest pytest-cov +# RUN pip install --no-cache-dir . && pip install pytest pytest-cov # Install SafetyGymansium from external lib RUN wget https://github.com/PKU-Alignment/safety-gymnasium/archive/refs/heads/main.zip && \ unzip main.zip && \ cd safety-gymnasium-main && \ - pip install . && \ + uv pip install . && \ cd .. -COPY . . - -RUN pip install --no-cache-dir . 
&& pip install pytest pytest-cov - # Run tests by default -CMD ["pytest", "tests/functional"] +CMD ["uv", "run", "pytest", "tests/functional"] From 32f1b1f985d3ae20f4f592afcadae0f11ed72840 Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Thu, 17 Jul 2025 20:59:03 +0300 Subject: [PATCH 22/30] Add tests --- src/oprl/buffers/protocols.py | 3 ++- src/oprl/logging.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/oprl/buffers/protocols.py b/src/oprl/buffers/protocols.py index 4c744c5..beac0c0 100644 --- a/src/oprl/buffers/protocols.py +++ b/src/oprl/buffers/protocols.py @@ -1,8 +1,9 @@ -from typing import Protocol +from typing import Protocol, runtime_checkable import torch as t +@runtime_checkable class ReplayBufferProtocol(Protocol): episodes_counter: int _created: bool diff --git a/src/oprl/logging.py b/src/oprl/logging.py index e3cf137..d34d6e9 100644 --- a/src/oprl/logging.py +++ b/src/oprl/logging.py @@ -5,11 +5,12 @@ from datetime import datetime import shutil from abc import ABC, abstractmethod -from typing import Protocol, Callable +from typing import Protocol, Callable, runtime_checkable from torch.utils.tensorboard.writer import SummaryWriter +@runtime_checkable class LoggerProtocol(Protocol): log_dir: Path @@ -26,7 +27,7 @@ def get_logs_path(logdir: str, algo: str, env: str, seed: int) -> Path: return log_dir -def create_stdout_logger(name: str | None = None): +def create_stdout_logger(name: str = None): if name is None: import inspect frame = inspect.currentframe().f_back From dc8db60dcba68d0b37405cac3c1c75e9b9f19e3b Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Thu, 17 Jul 2025 21:26:33 +0300 Subject: [PATCH 23/30] Add mypy check for custom modules --- .github/workflows/unit-tests-docker.yml | 7 ++++++- Dockerfile | 6 +----- src/oprl/algos/nn_models.py | 4 ++-- src/oprl/algos/protocols.py | 2 ++ src/oprl/runners/train.py | 2 +- src/oprl/trainers/base_trainer.py | 8 ++++---- 6 files changed, 16 insertions(+), 13 deletions(-) diff --git a/.github/workflows/unit-tests-docker.yml b/.github/workflows/unit-tests-docker.yml index fa63c00..33c54c7 100644 --- a/.github/workflows/unit-tests-docker.yml +++ b/.github/workflows/unit-tests-docker.yml @@ -5,7 +5,7 @@ on: jobs: unit_tests: - name: Unit Tests + name: Unit Tests - Coverage - MyPy runs-on: ubuntu-latest steps: - name: Checkout code @@ -17,6 +17,11 @@ jobs: - name: Unit Tests run: docker run --rm oprl + - name: MyPy + run: docker run --rm sh -c " + uv run mypy --ignore-missing-imports --python-version 3.10 src/oprl/runners src/oprl/trainers src/opr/buffers + " + - name: Extract coverage run: | docker run --rm -v $(pwd):/host oprl sh -c " diff --git a/Dockerfile b/Dockerfile index 1f1d26b..9bc28d4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,13 +3,9 @@ COPY --from=ghcr.io/astral-sh/uv:0.7.21 /uv /uvx /bin/ WORKDIR /app -# RUN pip install --no-cache-dir --upgrade pip - - COPY . . -RUN uv sync --locked && uv pip install pytest pytest-cov -# RUN pip install --no-cache-dir . 
&& pip install pytest pytest-cov +RUN uv sync --locked && uv pip install pytest pytest-cov mypy # Install SafetyGymansium from external lib RUN wget https://github.com/PKU-Alignment/safety-gymnasium/archive/refs/heads/main.zip && \ diff --git a/src/oprl/algos/nn_models.py b/src/oprl/algos/nn_models.py index 599cc15..0e37964 100644 --- a/src/oprl/algos/nn_models.py +++ b/src/oprl/algos/nn_models.py @@ -11,7 +11,7 @@ LOG_STD_MIN_MAX: Final[tuple[float, float]] = (-20, 2) -def initialize_weight_orthogonal(m: nn.Module, gain: float = nn.init.calculate_gain("relu")): +def initialize_weight_orthogonal(m: nn.Module, gain: int = nn.init.calculate_gain("relu")): if isinstance(m, nn.Linear): nn.init.orthogonal_(m.weight.data, gain) m.bias.data.fill_(0.0) @@ -194,7 +194,7 @@ def exploit(self, state: npt.NDArray) -> npt.NDArray: return action -class TanhNormal(Distribution): +class TanhNormal: def __init__(self, normal_mean: t.Tensor, normal_std: t.Tensor, device: str) -> None: super().__init__() self.normal_mean = normal_mean diff --git a/src/oprl/algos/protocols.py b/src/oprl/algos/protocols.py index 065d12f..025b490 100644 --- a/src/oprl/algos/protocols.py +++ b/src/oprl/algos/protocols.py @@ -15,6 +15,8 @@ def exploit(self, state: npt.NDArray) -> npt.NDArray: ... def __call__(*args, **kwargs) -> t.Tensor: ... + def state_dict(self) -> dict: ... + class AlgorithmProtocol(Protocol): actor: PolicyProtocol diff --git a/src/oprl/runners/train.py b/src/oprl/runners/train.py index d801dac..c128890 100644 --- a/src/oprl/runners/train.py +++ b/src/oprl/runners/train.py @@ -59,7 +59,7 @@ def _run_training_func( seed: int, ) -> None: set_seed(seed) - env = make_env(seed=seed) + env = make_env(seed) replay_buffer = make_replay_buffer() logger = make_logger(seed) algo = make_algo(logger) diff --git a/src/oprl/trainers/base_trainer.py b/src/oprl/trainers/base_trainer.py index b8adb3e..7c15529 100644 --- a/src/oprl/trainers/base_trainer.py +++ b/src/oprl/trainers/base_trainer.py @@ -46,7 +46,7 @@ def train(self) -> None: if env_step <= self.start_steps: action = self.env.sample_action() else: - action = self.algo.explore(state) + action = self.algo.actor.explore(state) next_state, reward, terminated, truncated, _ = self.env.step(action) self.replay_buffer.add_transition( @@ -93,7 +93,7 @@ def _log_evaluation(self, env_step: int, rewards: t.Tensor) -> None: def evaluate(self) -> dict[str, float]: returns = [] for i_ep in range(self.num_eval_episodes): - env_test = self.make_env_test(seed=self.seed + i_ep) + env_test = self.make_env_test(self.seed + i_ep) state, _ = env_test.reset() episode_return = 0.0 @@ -140,7 +140,7 @@ def _log_stdout(self, env_step: int, rewards: t.Tensor) -> None: def estimate_true_q(self, eval_episodes: int = 10) -> float: qs = [] for i_eval in range(eval_episodes): - env = self.make_env_test(seed=self.seed * 100 + i_eval) + env = self.make_env_test(self.seed * 100 + i_eval) state, _ = env.reset() q = 0 @@ -159,7 +159,7 @@ def estimate_true_q(self, eval_episodes: int = 10) -> float: def estimate_critic_q(self, num_episodes: int = 10) -> float: qs = [] for i_eval in range(num_episodes): - env = self.make_env_test(seed=self.seed * 100 + i_eval) + env = self.make_env_test(self.seed * 100 + i_eval) state, _ = env.reset() action = self.algo.actor.exploit(state) From 2df0394a108d8959587649740432641759290fe1 Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Thu, 17 Jul 2025 21:37:29 +0300 Subject: [PATCH 24/30] Add ruff to ci pipeline --- .github/workflows/unit-tests-docker.yml | 13 +++++-- 
src/oprl/algos/base_algorithm.py | 8 ----- src/oprl/algos/nn_models.py | 2 +- tests/functional/test_logging.py | 19 +++++++++++ tests/functional/test_replay_buffer.py | 21 ++++++++++++ tests/functional/test_rl_algos.py | 45 ++++++++++++++++++------- 6 files changed, 83 insertions(+), 25 deletions(-) create mode 100644 tests/functional/test_logging.py create mode 100644 tests/functional/test_replay_buffer.py diff --git a/.github/workflows/unit-tests-docker.yml b/.github/workflows/unit-tests-docker.yml index 33c54c7..dc72da6 100644 --- a/.github/workflows/unit-tests-docker.yml +++ b/.github/workflows/unit-tests-docker.yml @@ -13,14 +13,21 @@ jobs: - name: Build Docker image run: docker build -t oprl . + + - name: Ruff + run: | + docker run --rm sh -c " + uv run ruff check src + " - name: Unit Tests run: docker run --rm oprl - name: MyPy - run: docker run --rm sh -c " - uv run mypy --ignore-missing-imports --python-version 3.10 src/oprl/runners src/oprl/trainers src/opr/buffers - " + run: | + docker run --rm sh -c " + uv run mypy --ignore-missing-imports --python-version 3.10 src/oprl/runners src/oprl/trainers src/opr/buffers + " - name: Extract coverage run: | diff --git a/src/oprl/algos/base_algorithm.py b/src/oprl/algos/base_algorithm.py index c3e5aa6..7c88623 100644 --- a/src/oprl/algos/base_algorithm.py +++ b/src/oprl/algos/base_algorithm.py @@ -1,8 +1,6 @@ from abc import ABC from typing import Any -import numpy.typing as npt - from oprl.algos.protocols import AlgorithmProtocol @@ -13,11 +11,5 @@ def check_created(self) -> None: f"Algorithm {type(self).__name__} has not been created with `create()`." ) - def exploit(self, state: npt.NDArray) -> npt.NDArray: - return self.actor.exploit(state) - - def explore(self, state: npt.NDArray) -> npt.NDArray: - return self.actor.explore(state) - def get_policy_state_dict(self) -> dict[str, Any]: return self.actor.state_dict() diff --git a/src/oprl/algos/nn_models.py b/src/oprl/algos/nn_models.py index 0e37964..f35987d 100644 --- a/src/oprl/algos/nn_models.py +++ b/src/oprl/algos/nn_models.py @@ -4,7 +4,7 @@ import numpy.typing as npt import torch as t import torch.nn as nn -from torch.distributions import Distribution, Normal +from torch.distributions import Normal from torch.nn.functional import logsigmoid diff --git a/tests/functional/test_logging.py b/tests/functional/test_logging.py new file mode 100644 index 0000000..f09d8d8 --- /dev/null +++ b/tests/functional/test_logging.py @@ -0,0 +1,19 @@ +from logging import Logger + +from oprl.logging import ( + LoggerProtocol, + create_stdout_logger, + make_text_logger_func, +) + + +def test_create_stdout_logger() -> None: + logger = create_stdout_logger() + assert isinstance(logger, Logger) + + +def test_create_text_logger_func() -> None: + func = make_text_logger_func("test_algo", "test_env") + logger = func(0) + assert isinstance(logger, LoggerProtocol) + diff --git a/tests/functional/test_replay_buffer.py b/tests/functional/test_replay_buffer.py new file mode 100644 index 0000000..dde0244 --- /dev/null +++ b/tests/functional/test_replay_buffer.py @@ -0,0 +1,21 @@ +from oprl.buffers.episodic_buffer import EpisodicReplayBuffer +from oprl.buffers.protocols import ReplayBufferProtocol + + +def test_replay_buffer() -> None: + state_dim = 7 + max_episode_length = 10 + num_transitions = 100 + buffer = EpisodicReplayBuffer( + buffer_size_transitions=num_transitions, + state_dim=state_dim, + action_dim=3, + max_episode_lenth=max_episode_length, + ).create() + assert isinstance(buffer, ReplayBufferProtocol) 
+ + states = buffer.states + assert len(states.shape) == 3 + assert states.shape[0] == num_transitions // max_episode_length + assert states.shape[1] == max_episode_length + 1 + assert states.shape[2] == state_dim diff --git a/tests/functional/test_rl_algos.py b/tests/functional/test_rl_algos.py index a202249..3f32ce0 100644 --- a/tests/functional/test_rl_algos.py +++ b/tests/functional/test_rl_algos.py @@ -1,5 +1,6 @@ import pytest import torch +import numpy.typing as npt from oprl.algos.protocols import AlgorithmProtocol from oprl.algos.ddpg import DDPG @@ -7,24 +8,13 @@ from oprl.algos.td3 import TD3 from oprl.algos.tqc import TQC from oprl.environment import DMControlEnv +from oprl.environment.protocols import EnvProtocol from oprl.logging import FileTxtLogger rl_algo_classes: list[type[AlgorithmProtocol]] = [DDPG, SAC, TD3, TQC] - -@pytest.mark.parametrize("algo_class", rl_algo_classes) -def test_rl_algo_run(algo_class: type[AlgorithmProtocol]) -> None: - env = DMControlEnv("walker-walk", seed=0) - # TODO: Change to mocked logger - logger = FileTxtLogger(".") - obs, _ = env.reset() - - algo = algo_class( - logger=logger, - state_dim=env.observation_space.shape[0], - action_dim=env.action_space.shape[0], - ).create() +def _run_common_test(algo: AlgorithmProtocol, env: EnvProtocol, obs: npt.NDArray) -> None: action = algo.actor.exploit(obs) assert action.ndim == 1 @@ -39,3 +29,32 @@ def test_rl_algo_run(algo_class: type[AlgorithmProtocol]) -> None: batch_rewards = torch.randn(_batch_size, 1) batch_dones = torch.randint(2, (_batch_size, 1)) algo.update(batch_obs, batch_actions, batch_rewards, batch_dones, batch_obs) + + +@pytest.mark.parametrize("algo_class", rl_algo_classes) +def test_ddpg_td3_tqc(algo_class: type[AlgorithmProtocol]) -> None: + env = DMControlEnv("walker-walk", seed=0) + # TODO: Change to mocked logger + logger = FileTxtLogger(".") + obs, _ = env.reset() + algo = algo_class( + logger=logger, + state_dim=env.observation_space.shape[0], + action_dim=env.action_space.shape[0], + ).create() + _run_common_test(algo, env, obs) + + +@pytest.mark.parametrize("tune_alpha", [True, False]) +def test_sac(tune_alpha: bool) -> None: + env = DMControlEnv("walker-walk", seed=0) + # TODO: Change to mocked logger + logger = FileTxtLogger(".") + obs, _ = env.reset() + algo = SAC( + logger=logger, + tune_alpha=tune_alpha, + state_dim=env.observation_space.shape[0], + action_dim=env.action_space.shape[0], + ).create() + _run_common_test(algo, env, obs) From 29aa426bb9b0ca0fbc10906a5470bf139ac6937f Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Thu, 17 Jul 2025 21:41:18 +0300 Subject: [PATCH 25/30] Fix in ci --- .github/workflows/unit-tests-docker.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unit-tests-docker.yml b/.github/workflows/unit-tests-docker.yml index dc72da6..e2816b3 100644 --- a/.github/workflows/unit-tests-docker.yml +++ b/.github/workflows/unit-tests-docker.yml @@ -25,7 +25,7 @@ jobs: - name: MyPy run: | - docker run --rm sh -c " + docker run --rm oprl sh -c " uv run mypy --ignore-missing-imports --python-version 3.10 src/oprl/runners src/oprl/trainers src/opr/buffers " From 6c9b56d9d11a4dba5369bb38db065d9ce22bf4ca Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Thu, 17 Jul 2025 21:44:38 +0300 Subject: [PATCH 26/30] Fix ci --- .github/workflows/unit-tests-docker.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unit-tests-docker.yml b/.github/workflows/unit-tests-docker.yml index 
e2816b3..68f879c 100644 --- a/.github/workflows/unit-tests-docker.yml +++ b/.github/workflows/unit-tests-docker.yml @@ -16,7 +16,7 @@ jobs: - name: Ruff run: | - docker run --rm sh -c " + docker run --rm oprl sh -c " uv run ruff check src " From 00e255d0859a9070c464fbf6a51c62e9fa2f201d Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Thu, 17 Jul 2025 21:47:47 +0300 Subject: [PATCH 27/30] Fix mypy typo in ci --- .github/workflows/unit-tests-docker.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unit-tests-docker.yml b/.github/workflows/unit-tests-docker.yml index 68f879c..4f9558c 100644 --- a/.github/workflows/unit-tests-docker.yml +++ b/.github/workflows/unit-tests-docker.yml @@ -26,7 +26,7 @@ jobs: - name: MyPy run: | docker run --rm oprl sh -c " - uv run mypy --ignore-missing-imports --python-version 3.10 src/oprl/runners src/oprl/trainers src/opr/buffers + uv run mypy --ignore-missing-imports --python-version 3.10 src/oprl/runners src/oprl/trainers src/oprl/buffers " - name: Extract coverage From 6efbecb61335d750459f391ed92a4fade911fb61 Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Thu, 17 Jul 2025 21:52:10 +0300 Subject: [PATCH 28/30] Add missing module file --- src/oprl/runners/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/oprl/runners/__init__.py diff --git a/src/oprl/runners/__init__.py b/src/oprl/runners/__init__.py new file mode 100644 index 0000000..e69de29 From 23cfd4afc6fba1ffebbdbc6d2c6774ca6c971ad4 Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Thu, 17 Jul 2025 22:00:31 +0300 Subject: [PATCH 29/30] Exclude runners --- .github/workflows/unit-tests-docker.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unit-tests-docker.yml b/.github/workflows/unit-tests-docker.yml index 4f9558c..770a3fc 100644 --- a/.github/workflows/unit-tests-docker.yml +++ b/.github/workflows/unit-tests-docker.yml @@ -26,7 +26,7 @@ jobs: - name: MyPy run: | docker run --rm oprl sh -c " - uv run mypy --ignore-missing-imports --python-version 3.10 src/oprl/runners src/oprl/trainers src/oprl/buffers + uv run mypy --ignore-missing-imports --python-version 3.10 src/oprl/trainers src/oprl/buffers " - name: Extract coverage From 60f8f07519133a11e9e42daa5913bf8cb2eaff74 Mon Sep 17 00:00:00 2001 From: Igor Kuznetsov Date: Fri, 18 Jul 2025 00:33:45 +0300 Subject: [PATCH 30/30] Update README.md --- README.md | 103 +++++++++++++++++++++++++++++++++++------------------- 1 file changed, 68 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 4a64378..70a9f2e 100644 --- a/README.md +++ b/README.md @@ -1,48 +1,69 @@ -oprl_logo +
+<div align="center">
+  <img alt="Description">
+</div>
-A Modular Library for Off-Policy Reinforcement Learning with a focus on SafeRL and distributed computing. Benchmarking resutls are available at associated homepage: [Homepage](https://schatty.github.io/oprl/)
+A Modular Library for Off-Policy Reinforcement Learning with a focus on SafeRL and distributed computing.
+The code supports the `SafetyGymnasium` environment suite as a starting point for developing SafeRL solutions. The distributed setting is implemented via the `pika` library and will be improved in the near future.
 
 [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
 [![codecov](https://codecov.io/gh/schatty/oprl/branch/master/graph/badge.svg)](https://codecov.io/gh/schatty/oprl)
 
+### Roadmap 🏗
+- [x] Support for SafetyGymnasium
+- [x] Style and readability improvements
+- [ ] REDQ, DrQ Algorithms support
+- [ ] Distributed Training Improvements
+
+## In a Snapshot
+
-# Disclaimer
-The project is under an active renovation, for the old code with D4PG algorithm working with multiprocessing queues and `mujoco_py` please refer to the branch `d4pg_legacy`.
-
+Supported environments
+
-### Roadmap 🏗
-- [x] Switching to `mujoco 3.1.1`
-- [x] Replacing multiprocessing queues with RabbitMQ for distributed RL
-- [x] Baselines with DDPG, TQC for `dm_control` for 1M step
-- [x] Tests
-- [x] Support for SafetyGymnasium
-- [ ] Style and readability improvements
-- [ ] Baselines with Distributed algorithms for `dm_control`
-- [ ] D4PG logic on top of TQC
-
+| DMControl Suite | SafetyGymnasium | Gymnasium |
+| -------- | -------- | -------- |
+
+Algorithms
+
+| DDPG | TD3 | SAC | TQC |
+| --- | --- | --- | --- |
+
-# Installation
+## Installation
+
+The project supports [uv](https://docs.astral.sh/uv/) for package management and [ruff](https://github.com/astral-sh/ruff) for formatting checks. To install via uv in a virtualenv:
 
 ```
-pip install -r requirements.txt
-cd src && pip install -e .
+uv venv
+source .venv/bin/activate
+uv sync
 ```
 
+### Installing SafetyGymnasium
+
 For working with [SafetyGymnasium](https://github.com/PKU-Alignment/safety-gymnasium) install it manually
 
 ```
 git clone https://github.com/PKU-Alignment/safety-gymnasium
-cd safety-gymnasium && pip install -e .
+cd safety-gymnasium && uv pip install -e .
+```
+
+## Tests
+
+To run tests locally:
+
 ```
+uv pip install pytest
+uv run pytest tests/functional
+```
+
+## RL Training
 
-# Usage
+All training is configured via Python config files located in the `configs` folder. To make your own configuration, change the code there or create a similar one (an illustrative sketch follows below). During training, all the code is copied to the logs folder to ensure full experimental reproducibility.
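+
+A minimal sketch of such a config module is shown below. It is illustrative only: the factory names (`make_env`, `make_logger`, `make_algo`) mirror the entry points used by the runners, while the exact import paths and constructor arguments are assumptions that may differ between versions.
+
+```
+from oprl.algos.ddpg import DDPG
+from oprl.algos.protocols import AlgorithmProtocol
+from oprl.environment import DMControlEnv
+from oprl.logging import FileTxtLogger, LoggerProtocol
+
+
+def make_env(seed: int) -> DMControlEnv:
+    # Environment factory: one seeded environment instance per run.
+    return DMControlEnv("walker-walk", seed=seed)
+
+
+def make_logger(seed: int) -> LoggerProtocol:
+    # Minimal text logger; a real config would create a per-seed log dir.
+    return FileTxtLogger(".")
+
+
+def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol:
+    # Build the algorithm; `.create()` finalizes it before training.
+    env = make_env(seed=0)
+    return DDPG(
+        logger=logger,
+        state_dim=env.observation_space.shape[0],
+        action_dim=env.action_space.shape[0],
+    ).create()
+```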
+
+### Single Agent
 
 To run DDPG in a single process
 ```
-python src/oprl/configs/ddpg.py --env walker-walk
+python configs/ddpg.py --env walker-walk
 ```
 
-To run distributed DDPG
+### Distributed
 
 Run RabbitMQ
 ```
 docker run -it --rm --name rabbitmq -p 5672:5672 -p 15672:15672 rabbitmq:3.12-management
 ```
 
 Run training
 ```
-python src/oprl/configs/d3pg.py --env walker-walk
+python configs/distrib_ddpg.py --env walker-walk
 ```
 
-## Tests
-
-```
-cd src && pip install -e .
-cd .. && pip install -r tests/functional/requirements.txt
-python -m pytest tests
-```
-
 ## Results
 
 Results for single process DDPG and TQC:
 ![ddpg_tqc_eval](https://github.com/schatty/d4pg-pytorch/assets/23639048/f2c32f62-63b4-4a66-a636-4ce0ea1522f6)
 
-## Acknowledgements
-* DDPG and TD3 code is based on the official TD3 implementation: [sfujim/TD3](https://github.com/sfujim/TD3)
-* TQC code is based on the official TQC implementation: [SamsungLabs/tqc](https://github.com/SamsungLabs/tqc)
-* SafetyGymnasium: [PKU-Alignment/safety-gymnasium](https://github.com/PKU-Alignment/safety-gymnasium)
+## Cite
+
+__OPRL__
+```
+@inproceedings{
+  kuznetsov2024safer,
+  title={Safer Reinforcement Learning by Going Off-policy: a Benchmark},
+  author={Igor Kuznetsov},
+  booktitle={ICML 2024 Next Generation of AI Safety Workshop},
+  year={2024},
+  url={https://openreview.net/forum?id=pAmTC9EdGq}
+}
+```
+
+__SafetyGymnasium__
+```
+@inproceedings{ji2023safety,
+  title={Safety Gymnasium: A Unified Safe Reinforcement Learning Benchmark},
+  author={Jiaming Ji and Borong Zhang and Jiayi Zhou and Xuehai Pan and Weidong Huang and Ruiyang Sun and Yiran Geng and Yifan Zhong and Josef Dai and Yaodong Yang},
+  booktitle={Thirty-seventh Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
+  year={2023},
+  url={https://openreview.net/forum?id=WZmlxIuIGR}
+}
+```