diff --git a/.github/workflows/unit-tests-docker.yml b/.github/workflows/unit-tests-docker.yml
index 7adeafd..770a3fc 100644
--- a/.github/workflows/unit-tests-docker.yml
+++ b/.github/workflows/unit-tests-docker.yml
@@ -5,7 +5,7 @@ on:
jobs:
unit_tests:
- name: Unit Tests
+ name: Unit Tests - Coverage - MyPy
runs-on: ubuntu-latest
steps:
- name: Checkout code
@@ -13,14 +13,26 @@ jobs:
- name: Build Docker image
run: docker build -t oprl .
+
+ - name: Ruff
+ run: |
+ docker run --rm oprl sh -c "
+ uv run ruff check src
+ "
- name: Unit Tests
run: docker run --rm oprl
+ - name: MyPy
+ run: |
+ docker run --rm oprl sh -c "
+ uv run mypy --ignore-missing-imports --python-version 3.10 src/oprl/trainers src/oprl/buffers
+ "
+
- name: Extract coverage
run: |
docker run --rm -v $(pwd):/host oprl sh -c "
- pytest tests/functional --cov=oprl --cov-report=xml &&
+ uv run pytest tests/functional --cov=oprl --cov-report=xml &&
cp coverage.xml /host/
"
diff --git a/Dockerfile b/Dockerfile
index 688df96..9bc28d4 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,19 +1,18 @@
FROM python:3.10.8
+COPY --from=ghcr.io/astral-sh/uv:0.7.21 /uv /uvx /bin/
WORKDIR /app
-RUN pip install --no-cache-dir --upgrade pip
+COPY . .
+
+RUN uv sync --locked && uv pip install pytest pytest-cov mypy
# Install SafetyGymansium from external lib
RUN wget https://github.com/PKU-Alignment/safety-gymnasium/archive/refs/heads/main.zip && \
unzip main.zip && \
cd safety-gymnasium-main && \
- pip install . && \
+ uv pip install . && \
cd ..
-COPY . .
-
-RUN pip install --no-cache-dir . && pip install pytest pytest-cov
-
# Run tests by default
-CMD ["pytest", "tests/functional"]
+CMD ["uv", "run", "pytest", "tests/functional"]
diff --git a/README.md b/README.md
index 4a64378..70a9f2e 100644
--- a/README.md
+++ b/README.md
@@ -1,48 +1,69 @@
-
+
+
+
-# OPRL
-
-A Modular Library for Off-Policy Reinforcement Learning with a focus on SafeRL and distributed computing. Benchmarking resutls are available at associated homepage: [Homepage](https://schatty.github.io/oprl/)
+A Modular Library for Off-Policy Reinforcement Learning with a focus on SafeRL and distributed computing. The code supports the `SafetyGymnasium` environment set, providing a starting point for developing SafeRL solutions. The distributed setting is implemented via the `pika` library and will be improved in the near future.
[](https://github.com/psf/black)
[](https://codecov.io/gh/schatty/oprl)
+### Roadmap 🏗
+- [x] Support for SafetyGymnasium
+- [x] Style and readability improvements
+- [ ] REDQ, DrQ Algorithms support
+- [ ] Distributed Training Improvements
+## In a Snapshot
-# Disclaimer
-The project is under an active renovation, for the old code with D4PG algorithm working with multiprocessing queues and `mujoco_py` please refer to the branch `d4pg_legacy`.
+Environments Support
-### Roadmap 🏗
-- [x] Switching to `mujoco 3.1.1`
-- [x] Replacing multiprocessing queues with RabbitMQ for distributed RL
-- [x] Baselines with DDPG, TQC for `dm_control` for 1M step
-- [x] Tests
-- [x] Support for SafetyGymnasium
-- [ ] Style and readability improvements
-- [ ] Baselines with Distributed algorithms for `dm_control`
-- [ ] D4PG logic on top of TQC
+| DMControl Suite | SafetyGymnasium | Gymnasium |
+| -------- | -------- | -------- |
+
+Algorithms
+
+| DDPG | TD3 | SAC | TQC |
+| --- | --- | --- | --- |
-# Installation
+## Installation
+
+The project supports [uv](https://docs.astral.sh/uv/) for package management and [ruff](https://github.com/astral-sh/ruff) for formatting checks. To install the project via uv in a virtualenv:
```
-pip install -r requirements.txt
-cd src && pip install -e .
+uv venv
+source .venv/bin/activate
+uv sync
```
+### Installing SafetyGymnasium
+
For working with [SafetyGymnasium](https://github.com/PKU-Alignment/safety-gymnasium) install it manually
```
git clone https://github.com/PKU-Alignment/safety-gymnasium
-cd safety-gymnasium && pip install -e .
+cd safety-gymnasium && uv pip install -e .
+```
+
+## Tests
+
+To run tests locally:
+
```
+uv pip install pytest
+uv run pytest tests/functional
+```
+
+## RL Training
-# Usage
+All training is set via python config files located in `configs` folder. To make your own configuration, change the code there or create a similar one. During training, all the code is copied to logs folder to ensure full experimental reproducibility.
+
+### Single Agent
To run DDPG in a single process
```
-python src/oprl/configs/ddpg.py --env walker-walk
+python configs/ddpg.py --env walker-walk
```
-To run distributed DDPG
+### Distributed
Run RabbitMQ
```
@@ -51,15 +72,7 @@ docker run -it --rm --name rabbitmq -p 5672:5672 -p 15672:15672 rabbitmq:3.12-ma
Run training
```
-python src/oprl/configs/d3pg.py --env walker-walk
-```
-
-## Tests
-
-```
-cd src && pip install -e .
-cd .. && pip install -r tests/functional/requirements.txt
-python -m pytest tests
+python configs/distrib_ddpg.py --env walker-walk
```
## Results
@@ -67,7 +80,27 @@ python -m pytest tests
Results for single process DDPG and TQC:

-## Acknowledgements
-* DDPG and TD3 code is based on the official TD3 implementation: [sfujim/TD3](https://github.com/sfujim/TD3)
-* TQC code is based on the official TQC implementation: [SamsungLabs/tqc](https://github.com/SamsungLabs/tqc)
-* SafetyGymnasium: [PKU-Alignment/safety-gymnasium](https://github.com/PKU-Alignment/safety-gymnasium)
+## Cite
+
+__OPRL__
+```
+@inproceedings{
+ kuznetsov2024safer,
+ title={Safer Reinforcement Learning by Going Off-policy: a Benchmark},
+ author={Igor Kuznetsov},
+ booktitle={ICML 2024 Next Generation of AI Safety Workshop},
+ year={2024},
+ url={https://openreview.net/forum?id=pAmTC9EdGq}
+}
+```
+
+__SafetyGymnasium__
+```
+@inproceedings{ji2023safety,
+ title={Safety Gymnasium: A Unified Safe Reinforcement Learning Benchmark},
+ author={Jiaming Ji and Borong Zhang and Jiayi Zhou and Xuehai Pan and Weidong Huang and Ruiyang Sun and Yiran Geng and Yifan Zhong and Josef Dai and Yaodong Yang},
+ booktitle={Thirty-seventh Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
+ year={2023},
+ url={https://openreview.net/forum?id=WZmlxIuIGR}
+}
+```
diff --git a/configs/ddpg.py b/configs/ddpg.py
new file mode 100644
index 0000000..a183970
--- /dev/null
+++ b/configs/ddpg.py
@@ -0,0 +1,73 @@
+from oprl.algos.protocols import AlgorithmProtocol
+from oprl.algos.ddpg import DDPG
+from oprl.buffers.protocols import ReplayBufferProtocol
+from oprl.buffers.episodic_buffer import EpisodicReplayBuffer
+from oprl.parse_args import parse_args
+from oprl.logging import (
+ LoggerProtocol,
+ make_text_logger_func,
+)
+from oprl.environment.protocols import EnvProtocol
+from oprl.environment import make_env as _make_env
+from oprl.runners.train import run_training
+from oprl.runners.config import CommonParameters
+
+
+args = parse_args()
+
+def make_env(seed: int) -> EnvProtocol:
+ return _make_env(args.env, seed=seed)
+
+
+env = make_env(seed=0)
+STATE_DIM: int = env.observation_space.shape[0]
+ACTION_DIM: int = env.action_space.shape[0]
+
+
+config = CommonParameters(
+ state_dim=STATE_DIM,
+ action_dim=ACTION_DIM,
+ num_steps=int(100_000),
+ eval_every=2500,
+ device=args.device,
+ estimate_q_every=5000,
+ log_every=2500,
+)
+
+
+def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol:
+ return DDPG(
+ state_dim=STATE_DIM,
+ action_dim=ACTION_DIM,
+ device=args.device,
+ logger=logger,
+ ).create()
+
+
+def make_replay_buffer() -> ReplayBufferProtocol:
+ return EpisodicReplayBuffer(
+ buffer_size_transitions=max(config.num_steps, int(1e6)),
+ state_dim=STATE_DIM,
+ action_dim=ACTION_DIM,
+ device=config.device,
+ ).create()
+
+
+make_logger = make_text_logger_func(
+ algo="DDPG",
+ env=args.env,
+)
+
+
+if __name__ == "__main__":
+ run_training(
+ make_algo=make_algo,
+ make_env=make_env,
+ make_replay_buffer=make_replay_buffer,
+ make_logger=make_logger,
+ config=config,
+ seeds=args.seeds,
+ start_seed=args.start_seed
+ )
+
+
diff --git a/configs/distrib_ddpg.py b/configs/distrib_ddpg.py
new file mode 100644
index 0000000..70451d7
--- /dev/null
+++ b/configs/distrib_ddpg.py
@@ -0,0 +1,92 @@
+import os
+import logging
+
+import torch.nn as nn
+
+from oprl.algos.ddpg import DDPG
+from oprl.algos.nn_models import DeterministicPolicy
+from oprl.algos.protocols import AlgorithmProtocol, PolicyProtocol
+from oprl.buffers.protocols import ReplayBufferProtocol
+from oprl.environment import make_env as _make_env
+from oprl.buffers.episodic_buffer import EpisodicReplayBuffer
+from oprl.logging import (
+ LoggerProtocol,
+ FileTxtLogger,
+ get_logs_path,
+)
+from oprl.parse_args import parse_args_distrib
+from oprl.runners.config import DistribConfig
+from oprl.runners.train_distrib import run_distrib_training
+from oprl.distrib.env_worker import run_env_worker
+from oprl.distrib.policy_update_worker import run_policy_update_worker
+
+
+config = DistribConfig(
+ batch_size=128,
+ num_env_workers=4,
+ episodes_per_worker=100,
+ warmup_epochs=16,
+ episode_length=1000,
+ learner_num_waits=10,
+)
+
+
+args = parse_args_distrib()
+
+def make_env(seed: int):
+ return _make_env(args.env, seed=seed)
+
+
+env = make_env(seed=0)
+STATE_DIM = env.observation_space.shape[0]
+ACTION_DIM = env.action_space.shape[0]
+logging.info(f"Env state {STATE_DIM}\tEnv action {ACTION_DIM}")
+
+
+def make_logger() -> LoggerProtocol:
+ logs_root = os.environ.get("OPRL_LOGS", "logs")
+ log_dir = get_logs_path(logdir=logs_root, algo="DistribDDPG", env=args.env, seed=0)
+ logger = FileTxtLogger(log_dir)
+ logger.copy_source_code()
+ return logger
+
+
+def make_policy() -> PolicyProtocol:
+ return DeterministicPolicy(
+ state_dim=STATE_DIM,
+ action_dim=ACTION_DIM,
+ hidden_units=(256, 256),
+ hidden_activation=nn.ReLU(inplace=True),
+ device=args.device,
+ )
+
+
+def make_replay_buffer() -> ReplayBufferProtocol:
+ return EpisodicReplayBuffer(
+ buffer_size_transitions=int(1_000_000),
+ state_dim=STATE_DIM,
+ action_dim=ACTION_DIM,
+ device=args.device,
+ ).create()
+
+
+def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol:
+ return DDPG(
+ logger=logger,
+ state_dim=STATE_DIM,
+ action_dim=ACTION_DIM,
+ device=args.device,
+ ).create()
+
+
+if __name__ == "__main__":
+ run_distrib_training(
+ run_env_worker=run_env_worker,
+ run_policy_update_worker=run_policy_update_worker,
+ make_env=make_env,
+ make_algo=make_algo,
+ make_policy=make_policy,
+ make_replay_buffer=make_replay_buffer,
+ make_logger=make_logger,
+ config=config,
+ )
diff --git a/configs/sac.py b/configs/sac.py
new file mode 100644
index 0000000..48f5dce
--- /dev/null
+++ b/configs/sac.py
@@ -0,0 +1,71 @@
+from oprl.algos.protocols import AlgorithmProtocol
+from oprl.algos.sac import SAC
+from oprl.parse_args import parse_args
+from oprl.logging import (
+ LoggerProtocol,
+ make_text_logger_func,
+)
+from oprl.environment import make_env as _make_env
+from oprl.buffers.protocols import ReplayBufferProtocol
+from oprl.buffers.episodic_buffer import EpisodicReplayBuffer
+from oprl.runners.train import run_training
+from oprl.runners.config import CommonParameters
+
+args = parse_args()
+
+
+def make_env(seed: int):
+ return _make_env(args.env, seed=seed)
+
+
+env = make_env(seed=0)
+STATE_DIM: int = env.observation_space.shape[0]
+ACTION_DIM: int = env.action_space.shape[0]
+
+
+config = CommonParameters(
+ state_dim=STATE_DIM,
+ action_dim=ACTION_DIM,
+ num_steps=int(100_000),
+ eval_every=2500,
+ device=args.device,
+ estimate_q_every=5000,
+ log_every=1000,
+)
+
+
+def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol:
+ return SAC(
+ state_dim=STATE_DIM,
+ action_dim=ACTION_DIM,
+ device=args.device,
+ logger=logger,
+ ).create()
+
+
+make_logger = make_text_logger_func(
+ algo="SAC",
+ env=args.env,
+)
+
+
+def make_replay_buffer() -> ReplayBufferProtocol:
+ return EpisodicReplayBuffer(
+ buffer_size_transitions=max(config.num_steps, int(1e6)),
+ state_dim=STATE_DIM,
+ action_dim=ACTION_DIM,
+ device=config.device,
+ ).create()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ run_training(
+ make_algo=make_algo,
+ make_env=make_env,
+ make_replay_buffer=make_replay_buffer,
+ make_logger=make_logger,
+ config=config,
+ seeds=args.seeds,
+ start_seed=args.start_seed
+ )
diff --git a/configs/td3.py b/configs/td3.py
new file mode 100644
index 0000000..827f27f
--- /dev/null
+++ b/configs/td3.py
@@ -0,0 +1,70 @@
+from oprl.algos.protocols import AlgorithmProtocol
+from oprl.algos.td3 import TD3
+from oprl.parse_args import parse_args
+from oprl.buffers.protocols import ReplayBufferProtocol
+from oprl.buffers.episodic_buffer import EpisodicReplayBuffer
+from oprl.logging import (
+ LoggerProtocol,
+ make_text_logger_func,
+)
+from oprl.environment import make_env as _make_env
+from oprl.runners.train import run_training
+from oprl.runners.config import CommonParameters
+
+args = parse_args()
+
+
+def make_env(seed: int):
+ return _make_env(args.env, seed=seed)
+
+
+env = make_env(seed=0)
+STATE_DIM: int = env.observation_space.shape[0]
+ACTION_DIM: int = env.action_space.shape[0]
+
+
+config = CommonParameters(
+ state_dim=STATE_DIM,
+ action_dim=ACTION_DIM,
+ num_steps=int(100_000),
+ eval_every=2500,
+ device=args.device,
+ estimate_q_every=5000,
+ log_every=2500,
+)
+
+
+def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol:
+ return TD3(
+ state_dim=STATE_DIM,
+ action_dim=ACTION_DIM,
+ device=args.device,
+ logger=logger,
+ ).create()
+
+
+def make_replay_buffer() -> ReplayBufferProtocol:
+ return EpisodicReplayBuffer(
+ buffer_size_transitions=max(config.num_steps, int(1e6)),
+ state_dim=STATE_DIM,
+ action_dim=ACTION_DIM,
+ device=config.device,
+ ).create()
+
+
+make_logger = make_text_logger_func(
+ algo="TD3",
+ env=args.env,
+)
+
+
+if __name__ == "__main__":
+ run_training(
+ make_algo=make_algo,
+ make_env=make_env,
+ make_logger=make_logger,
+ make_replay_buffer=make_replay_buffer,
+ config=config,
+ seeds=args.seeds,
+ start_seed=args.start_seed
+ )
diff --git a/configs/tqc.py b/configs/tqc.py
new file mode 100644
index 0000000..fdfef89
--- /dev/null
+++ b/configs/tqc.py
@@ -0,0 +1,71 @@
+from oprl.algos.protocols import AlgorithmProtocol
+from oprl.algos.tqc import TQC
+from oprl.parse_args import parse_args
+from oprl.logging import (
+ LoggerProtocol,
+ make_text_logger_func,
+)
+from oprl.environment import make_env as _make_env
+from oprl.buffers.protocols import ReplayBufferProtocol
+from oprl.buffers.episodic_buffer import EpisodicReplayBuffer
+from oprl.runners.train import run_training
+from oprl.runners.config import CommonParameters
+
+args = parse_args()
+
+
+def make_env(seed: int):
+ return _make_env(args.env, seed=seed)
+
+
+env = make_env(seed=0)
+STATE_DIM: int = env.observation_space.shape[0]
+ACTION_DIM: int = env.action_space.shape[0]
+
+
+config = CommonParameters(
+ state_dim=STATE_DIM,
+ action_dim=ACTION_DIM,
+ num_steps=int(100_000),
+ eval_every=2500,
+ device=args.device,
+ estimate_q_every=0, # TODO: unsupported logic
+ log_every=2500,
+)
+
+
+def make_algo(logger: LoggerProtocol) -> AlgorithmProtocol:
+ return TQC(
+ state_dim=STATE_DIM,
+ action_dim=ACTION_DIM,
+ device=args.device,
+ logger=logger,
+ ).create()
+
+
+make_logger = make_text_logger_func(
+ algo="TQC",
+ env=args.env,
+)
+
+
+def make_replay_buffer() -> ReplayBufferProtocol:
+ return EpisodicReplayBuffer(
+ buffer_size_transitions=max(config.num_steps, int(1e6)),
+ state_dim=STATE_DIM,
+ action_dim=ACTION_DIM,
+ device=config.device,
+ ).create()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ run_training(
+ make_algo=make_algo,
+ make_env=make_env,
+ make_replay_buffer=make_replay_buffer,
+ make_logger=make_logger,
+ config=config,
+ seeds=args.seeds,
+ start_seed=args.start_seed
+ )
diff --git a/pyproject.toml b/pyproject.toml
index f531d30..57cdc33 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,10 @@ dependencies = [
"dm-control==1.0.11",
"mujoco==2.3.3",
"numpy==1.26.4",
+ "pika==1.3.2",
+ "pydantic_settings==2.10.1",
+ "gymnasium==0.28.1",
+ "ruff>=0.12.3",
]
[project.optional-dependencies]
diff --git a/scripts/visualize_policy_from_weights.py b/scripts/visualize_policy_from_weights.py
new file mode 100644
index 0000000..8e4b3a7
--- /dev/null
+++ b/scripts/visualize_policy_from_weights.py
@@ -0,0 +1,89 @@
+import click
+import torch as t
+import numpy as np
+from PIL import Image
+
+from oprl.environment import make_env
+
+
+def create_webp_gif(numpy_arrays, output_path, duration=100, loop=0):
+ """
+ Create a WebP animated image from a list of NumPy arrays.
+
+ Args:
+ numpy_arrays: List of NumPy arrays (each representing a frame)
+ output_path: Output file path (should end with .webp)
+ duration: Duration between frames in milliseconds
+ loop: Number of loops (0 = infinite loop)
+ """
+ # Convert NumPy arrays to PIL Images
+ pil_images = []
+
+ for arr in numpy_arrays:
+ # Ensure the array is in the right format
+ if arr.dtype != np.uint8:
+ # Normalize to 0-255 range if needed
+ if arr.max() <= 1.0:
+ arr = (arr * 255).astype(np.uint8)
+ else:
+ arr = arr.astype(np.uint8)
+
+ # Handle different array shapes
+ if len(arr.shape) == 2: # Grayscale
+ img = Image.fromarray(arr, mode='L')
+ elif len(arr.shape) == 3: # RGB/RGBA
+ if arr.shape[2] == 3:
+ img = Image.fromarray(arr, mode='RGB')
+ elif arr.shape[2] == 4:
+ img = Image.fromarray(arr, mode='RGBA')
+ else:
+ raise ValueError(f"Unsupported number of channels: {arr.shape[2]}")
+ else:
+ raise ValueError(f"Unsupported array shape: {arr.shape}")
+
+ pil_images.append(img)
+
+ # Save as animated WebP
+ pil_images[0].save(
+ output_path,
+ format='WebP',
+ save_all=True,
+ append_images=pil_images[1:],
+ duration=duration,
+ loop=loop,
+ optimize=True
+ )
+
+
+@click.command()
+@click.option("--policy", "-p", help="Path to policy weights.")
+@click.option("--output", "-o", default="policy.webp", help="Path to output file.")
+@click.option("--env", "-e", default="walker-walk", help="Environment name.")
+@click.option("--seed", "-s", default=0, help="Environment seed.")
+def visualize_policy(policy, output, env, seed):
+ env = make_env(env, seed=seed)
+
+ actor = t.load(policy, weights_only=False)
+ print("Actor loaded: ", type(actor))
+
+ imgs = []
+ state, _ = env.reset()
+ done = False
+ while not done:
+ img = np.expand_dims(env.render(), axis=0) # [1, W, H, C]
+ imgs.append(img)
+ action = actor.exploit(t.from_numpy(state))
+ state, _, terminated, truncated, _ = env.step(action)
+ done = terminated or truncated
+
+ print("imgs: ", len(imgs), imgs[0].shape)
+ frames = np.concatenate(imgs, dtype="uint8", axis=0)
+ print("frames: ", frames.shape)
+
+ # Create the WebP GIF
+ create_webp_gif(frames, output, duration=25)
+ print("WebP GIF for dm_control created successfully!")
+
+
+if __name__ == "__main__":
+ visualize_policy()
diff --git a/src/oprl/algos/__init__.py b/src/oprl/algos/__init__.py
index b6f3a8f..8b13789 100644
--- a/src/oprl/algos/__init__.py
+++ b/src/oprl/algos/__init__.py
@@ -1,17 +1 @@
-from typing import Protocol
-
-import torch as t
-
-
-class OffPolicyAlgorithm(Protocol):
- def create(self) -> "OffPolicyAlgorithm": ...
-
- def update(
- self,
- state: t.Tensor,
- action: t.Tensor,
- reward: t.Tensor,
- done: t.Tensor,
- next_state: t.Tensor
- ) -> None: ...
diff --git a/src/oprl/algos/base_algorithm.py b/src/oprl/algos/base_algorithm.py
new file mode 100644
index 0000000..7c88623
--- /dev/null
+++ b/src/oprl/algos/base_algorithm.py
@@ -0,0 +1,15 @@
+from abc import ABC
+from typing import Any
+
+from oprl.algos.protocols import AlgorithmProtocol
+
+
+class OffPolicyAlgorithm(ABC, AlgorithmProtocol):
+ def check_created(self) -> None:
+ if not self._created:
+ raise RuntimeError(
+ f"Algorithm {type(self).__name__} has not been created with `create()`."
+ )
+
+ def get_policy_state_dict(self) -> dict[str, Any]:
+ return self.actor.state_dict()
diff --git a/src/oprl/algos/ddpg.py b/src/oprl/algos/ddpg.py
index 714a166..782bd58 100644
--- a/src/oprl/algos/ddpg.py
+++ b/src/oprl/algos/ddpg.py
@@ -1,29 +1,40 @@
from copy import deepcopy
-from dataclasses import dataclass
-from typing import Any
+from dataclasses import dataclass, field
-import numpy as np
-import numpy.typing as npt
import torch as t
from torch import nn
-from oprl.algos import OffPolicyAlgorithm
-from oprl.algos.nn import Critic, DeterministicPolicy
-from oprl.algos.utils import disable_gradient
-from oprl.utils.logger import Logger, StdLogger
+from oprl.algos.protocols import PolicyProtocol
+from oprl.algos.base_algorithm import OffPolicyAlgorithm
+from oprl.algos.nn_models import Critic, DeterministicPolicy
+from oprl.algos.nn_functions import disable_gradient
+from oprl.logging import LoggerProtocol
+# TODO: Do I need max_action all the time? need to check envs for their max actions
+
@dataclass
class DDPG(OffPolicyAlgorithm):
+ logger: LoggerProtocol
state_dim: int
action_dim: int
expl_noise: float = 0.1
- discount: float = 0.99
+ gamma: float = 0.99
+ lr_actor: float = 3e-4
+ lr_critic: float = 3e-4
tau: float = 5e-3
batch_size: int = 256
max_action: float = 1.
device: str = "cpu"
- logger: Logger = StdLogger()
+
+ actor: PolicyProtocol = field(init=False)
+ actor_target: PolicyProtocol = field(init=False)
+ optim_actor: t.optim.Optimizer = field(init=False)
+ critic: nn.Module = field(init=False)
+ critic_target: nn.Module = field(init=False)
+ optim_critic: t.optim.Optimizer = field(init=False)
+ update_step: int = 0
+ _created: bool = False
def create(self) -> "DDPG":
self.actor = DeterministicPolicy(
@@ -31,15 +42,20 @@ def create(self) -> "DDPG":
action_dim=self.action_dim,
hidden_units=(256, 256),
hidden_activation=nn.ReLU(inplace=True),
+ expl_noise=self.expl_noise,
+ max_action=self.max_action,
+ device=self.device,
).to(self.device)
self.actor_target = deepcopy(self.actor)
disable_gradient(self.actor_target)
- self.optim_actor = t.optim.Adam(self.actor.parameters(), lr=3e-4)
+ self.optim_actor = t.optim.Adam(self.actor.parameters(), lr=self.lr_actor)
self.critic = Critic(self.state_dim, self.action_dim).to(self.device)
self.critic_target = deepcopy(self.critic)
disable_gradient(self.critic_target)
- self.optim_critic = t.optim.Adam(self.critic.parameters(), lr=3e-4)
+ self.optim_critic = t.optim.Adam(self.critic.parameters(), lr=self.lr_critic)
+
+ self._created = True
return self
def update(
@@ -49,11 +65,10 @@ def update(
reward: t.Tensor,
done: t.Tensor,
next_state: t.Tensor,
- ):
+ ) -> None:
self._update_critic(state, action, reward, done, next_state)
self._update_actor(state)
- # Update the frozen target models
for param, target_param in zip(
self.critic.parameters(), self.critic_target.parameters()
):
@@ -77,40 +92,17 @@ def _update_critic(
next_state: t.Tensor,
) -> None:
target_Q = self.critic_target(next_state, self.actor_target(next_state))
- target_Q = reward + (1.0 - done) * self.discount * target_Q.detach()
+ target_Q = reward + (1.0 - done) * self.gamma * target_Q.detach()
current_Q = self.critic(state, action)
critic_loss = (current_Q - target_Q).pow(2).mean()
-
self.optim_critic.zero_grad()
critic_loss.backward()
self.optim_critic.step()
def _update_actor(self, state: t.Tensor) -> None:
actor_loss = -self.critic(state, self.actor(state)).mean()
-
self.optim_actor.zero_grad()
actor_loss.backward()
self.optim_actor.step()
- def exploit(self, state: npt.ArrayLike) -> npt.ArrayLike:
- state = t.tensor(state, device=self.device).unsqueeze_(0)
- with t.no_grad():
- action = self.actor(state).cpu()
- return action.numpy().flatten()
-
- # TODO: remove explore from algo to agent completely
- def explore(self, state: npt.ArrayLike) -> npt.ArrayLike:
- state = t.tensor(state, device=self.device).unsqueeze_(0)
-
- with t.no_grad():
- noise = (
- t.randn(self.action_dim) * self.max_action * self.expl_noise
- ).to(self.device)
- action = self.actor(state) + noise
-
- a = action.cpu().numpy()[0]
- return np.clip(a, -self.max_action, self.max_action)
-
- def get_policy_state_dict(self) -> dict[str, Any]:
- return self.actor.state_dict()
diff --git a/src/oprl/algos/nn_functions.py b/src/oprl/algos/nn_functions.py
new file mode 100644
index 0000000..e168e02
--- /dev/null
+++ b/src/oprl/algos/nn_functions.py
@@ -0,0 +1,16 @@
+import torch as t
+import torch.nn as nn
+
+
+def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
+ """Update target network using Polyak-Ruppert Averaging."""
+ with t.no_grad():
+ for tgt, src in zip(target.parameters(), source.parameters()):
+ tgt.data.mul_(1.0 - tau)
+ tgt.data.add_(tau * src.data)
+
+
+def disable_gradient(network: nn.Module) -> None:
+ """Disable gradient calculations of the network."""
+ for param in network.parameters():
+ param.requires_grad = False
diff --git a/src/oprl/algos/nn.py b/src/oprl/algos/nn_models.py
similarity index 68%
rename from src/oprl/algos/nn.py
rename to src/oprl/algos/nn_models.py
index cbac507..f35987d 100644
--- a/src/oprl/algos/nn.py
+++ b/src/oprl/algos/nn_models.py
@@ -1,13 +1,27 @@
+from typing import Final
+
import numpy as np
import numpy.typing as npt
import torch as t
import torch.nn as nn
-from torch.distributions import Distribution, Normal
+from torch.distributions import Normal
from torch.nn.functional import logsigmoid
-from oprl.algos.utils import initialize_weight
-LOG_STD_MIN_MAX = (-20, 2)
+LOG_STD_MIN_MAX: Final[tuple[float, float]] = (-20, 2)
+
+
+def initialize_weight_orthogonal(m: nn.Module, gain: int = nn.init.calculate_gain("relu")):
+ if isinstance(m, nn.Linear):
+ nn.init.orthogonal_(m.weight.data, gain)
+ m.bias.data.fill_(0.0)
+ # delta-orthogonal initialization.
+ elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+ assert m.weight.size(2) == m.weight.size(3)
+ m.weight.data.fill_(0.0)
+ m.bias.data.fill_(0.0)
+ mid = m.weight.size(2) // 2
+ nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)
class Critic(nn.Module):
@@ -17,9 +31,8 @@ def __init__(
action_dim: int,
hidden_units: tuple[int, ...] = (256, 256),
hidden_activation: nn.Module = nn.ReLU(inplace=True),
- ):
+ ) -> None:
super().__init__()
-
self.q1 = MLP(
input_dim=state_dim + action_dim,
output_dim=1,
@@ -27,7 +40,7 @@ def __init__(
hidden_activation=hidden_activation,
)
- def forward(self, states: t.Tensor, actions: t.Tensor):
+ def forward(self, states: t.Tensor, actions: t.Tensor) -> t.Tensor:
x = t.cat([states, actions], dim=-1)
return self.q1(x)
@@ -45,7 +58,6 @@ def __init__(
hidden_activation: nn.Module = nn.ReLU(inplace=True),
):
super().__init__()
-
self.q1 = MLP(
input_dim=state_dim + action_dim,
output_dim=1,
@@ -101,7 +113,7 @@ def __init__(
state_dim: int,
action_dim: int,
hidden_units: tuple[int, ...] = (256, 256),
- hidden_activation=nn.ReLU(inplace=True),
+ hidden_activation: nn.Module = nn.ReLU(inplace=True),
max_action: float = 1.0,
expl_noise: float = 0.1,
device: str = "cpu",
@@ -113,7 +125,7 @@ def __init__(
output_dim=action_dim,
hidden_units=hidden_units,
hidden_activation=hidden_activation,
- ).apply(initialize_weight)
+ ).apply(initialize_weight_orthogonal)
self._device = device
self._action_shape = action_dim
@@ -123,28 +135,36 @@ def __init__(
def forward(self, states: t.Tensor) -> t.Tensor:
return t.tanh(self.mlp(states))
- def exploit(self, state: npt.ArrayLike) -> npt.NDArray:
- state = t.tensor(state).unsqueeze_(0).to(self._device)
- return self.forward(state).cpu().numpy().flatten()
-
- def explore(self, state: npt.ArrayLike) -> npt.NDArray:
- state = t.tensor(state, device=self._device).unsqueeze_(0)
-
+ def exploit(self, state: npt.NDArray) -> npt.NDArray:
+ state_tensor = t.tensor(state).unsqueeze_(0).to(self._device)
with t.no_grad():
- noise = (t.randn(self._action_shape) * self._expl_noise).to(self._device)
- action = self.mlp(state) + noise
+ action = self.forward(state_tensor)
+ return action.cpu().numpy().flatten()
- a = action.cpu().numpy()[0]
- return np.clip(a, -self._max_action, self._max_action)
+ def explore(self, state: npt.NDArray) -> npt.NDArray:
+ state_tensor = t.tensor(state, device=self._device).unsqueeze_(0)
+ noise = (t.randn(self._action_shape) * self._expl_noise).to(self._device)
+ with t.no_grad():
+ action = self.mlp(state_tensor) + noise
+ action = action.cpu().numpy()[0]
+ return np.clip(action, -self._max_action, self._max_action)
class GaussianActor(nn.Module):
- def __init__(self, state_dim, action_dim, hidden_units, hidden_activation):
+ def __init__(
+ self,
+ state_dim: int,
+ action_dim: int,
+ hidden_units: tuple[int, ...],
+ hidden_activation: nn.Module,
+ device: str,
+ ):
super().__init__()
self.action_dim = action_dim
self.net = MLP(
state_dim, 2 * action_dim, hidden_units, hidden_activation=hidden_activation
)
+ self.device = device
def forward(self, obs: t.Tensor) -> tuple[t.Tensor, t.Tensor | None]:
mean, log_std = self.net(obs).split([self.action_dim, self.action_dim], dim=1)
@@ -161,13 +181,21 @@ def forward(self, obs: t.Tensor) -> tuple[t.Tensor, t.Tensor | None]:
log_prob = None
return action, log_prob
- @property
- def device(self):
- return next(self.parameters()).device
+ def explore(self, state: npt.NDArray) -> npt.NDArray:
+ state_tensor = t.tensor(state, device=self.device).unsqueeze_(0)
+ with t.no_grad():
+ action, _ = self.forward(state_tensor)
+ return action.cpu().numpy()[0]
+
+ def exploit(self, state: npt.NDArray) -> npt.NDArray:
+ self.eval()
+ action = self.explore(state)
+ self.train()
+ return action
-class TanhNormal(Distribution):
- def __init__(self, normal_mean: t.Tensor, normal_std: t.Tensor, device: str):
+class TanhNormal:
+ def __init__(self, normal_mean: t.Tensor, normal_std: t.Tensor, device: str) -> None:
super().__init__()
self.normal_mean = normal_mean
self.normal_std = normal_std
@@ -179,8 +207,7 @@ def __init__(self, normal_mean: t.Tensor, normal_std: t.Tensor, device: str):
def log_prob(self, pre_tanh: t.Tensor) -> t.Tensor:
log_det = 2 * np.log(2) + logsigmoid(2 * pre_tanh) + logsigmoid(-2 * pre_tanh)
- result = self.normal.log_prob(pre_tanh) - log_det
- return result
+ return self.normal.log_prob(pre_tanh) - log_det
def rsample(self) -> tuple[t.Tensor, t.Tensor]:
pretanh = self.normal_mean + self.normal_std * self.standard_normal.sample()
diff --git a/src/oprl/algos/protocols.py b/src/oprl/algos/protocols.py
new file mode 100644
index 0000000..025b490
--- /dev/null
+++ b/src/oprl/algos/protocols.py
@@ -0,0 +1,42 @@
+from typing import Protocol, Any
+
+import numpy.typing as npt
+
+import torch as t
+import torch.nn as nn
+
+from oprl.logging import LoggerProtocol
+
+
+class PolicyProtocol(Protocol):
+ def explore(self, state: npt.NDArray) -> npt.NDArray: ...
+
+ def exploit(self, state: npt.NDArray) -> npt.NDArray: ...
+
+ def __call__(*args, **kwargs) -> t.Tensor: ...
+
+ def state_dict(self) -> dict: ...
+
+
+class AlgorithmProtocol(Protocol):
+ actor: PolicyProtocol
+ critic: nn.Module
+ logger: LoggerProtocol
+ _created: bool
+
+ def create(self) -> "AlgorithmProtocol": ...
+
+ def check_created(self) -> None: ...
+
+ def update(
+ self,
+ state: t.Tensor,
+ action: t.Tensor,
+ reward: t.Tensor,
+ done: t.Tensor,
+ next_state: t.Tensor
+ ) -> None: ...
+
+ def get_policy_state_dict(self) -> dict[str, Any]:
+ return self.actor.state_dict()
+
diff --git a/src/oprl/algos/sac.py b/src/oprl/algos/sac.py
index aa1ea04..a22fcec 100644
--- a/src/oprl/algos/sac.py
+++ b/src/oprl/algos/sac.py
@@ -1,20 +1,21 @@
from copy import deepcopy
-from dataclasses import dataclass
+from dataclasses import dataclass, field
import numpy as np
-import numpy.typing as npt
import torch as t
from torch import nn
from torch.optim import Adam
-from oprl.algos import OffPolicyAlgorithm
-from oprl.algos.nn import DoubleCritic, GaussianActor
-from oprl.algos.utils import disable_gradient, soft_update
-from oprl.utils.logger import Logger, StdLogger
+from oprl.algos.protocols import PolicyProtocol
+from oprl.algos.base_algorithm import OffPolicyAlgorithm
+from oprl.algos.nn_models import DoubleCritic, GaussianActor
+from oprl.algos.nn_functions import disable_gradient, soft_update
+from oprl.logging import LoggerProtocol
@dataclass
class SAC(OffPolicyAlgorithm):
+ logger: LoggerProtocol
state_dim: int
action_dim: int
batch_size: int = 256
@@ -27,7 +28,16 @@ class SAC(OffPolicyAlgorithm):
target_update_coef: float = 5e-3
device: str = "cpu"
log_every: int = 5000
- logger: Logger = StdLogger()
+
+ actor: PolicyProtocol = field(init=False)
+ actor_target: PolicyProtocol = field(init=False)
+ optim_actor: t.optim.Optimizer = field(init=False)
+ critic: nn.Module = field(init=False)
+ critic_target: nn.Module = field(init=False)
+ optim_critic: t.optim.Optimizer = field(init=False)
+ alpha: float = field(init=False)
+ update_step: int = 0
+ _created: bool = False
def create(self) -> "SAC":
self.actor = GaussianActor(
@@ -35,6 +45,7 @@ def create(self) -> "SAC":
action_dim=self.action_dim,
hidden_units=(256, 256),
hidden_activation=nn.ReLU(inplace=True),
+ device=self.device,
).to(self.device)
self.critic = DoubleCritic(
@@ -50,15 +61,15 @@ def create(self) -> "SAC":
self.optim_actor = Adam(self.actor.parameters(), lr=self.lr_actor)
self.optim_critic = Adam(self.critic.parameters(), lr=self.lr_critic)
- self._alpha = self.alpha_init
+ self.alpha = self.alpha_init
if self.tune_alpha:
self.log_alpha = t.tensor(
- np.log(self._alpha), device=self.device, requires_grad=True
+ np.log(self.alpha), device=self.device, requires_grad=True
)
self.optim_alpha = t.optim.Adam([self.log_alpha], lr=self.lr_alpha)
self.target_entropy = -float(self.action_dim)
- self.update_step = 0
+ self._created = True
return self
def update(
@@ -72,7 +83,6 @@ def update(
self.update_critic(state, action, reward, done, next_state)
self.update_actor(state)
soft_update(self.critic_target, self.critic, self.target_update_coef)
-
self.update_step += 1
def update_critic(
@@ -87,7 +97,7 @@ def update_critic(
with t.no_grad():
next_actions, log_pis = self.actor(next_states)
q1_next, q2_next = self.critic_target(next_states, next_actions)
- q_next = t.min(q1_next, q2_next) - self._alpha * log_pis
+ q_next = t.min(q1_next, q2_next) - self.alpha * log_pis
q_target = rewards + (1.0 - dones) * self.gamma * q_next
@@ -113,7 +123,7 @@ def update_critic(
def update_actor(self, state: t.Tensor) -> None:
actions, log_pi = self.actor(state)
qs1, qs2 = self.critic(state, actions)
- loss_actor = self._alpha * log_pi.mean() - t.min(qs1, qs2).mean()
+ loss_actor = self.alpha * log_pi.mean() - t.min(qs1, qs2).mean()
self.optim_actor.zero_grad()
loss_actor.backward()
@@ -128,7 +138,7 @@ def update_actor(self, state: t.Tensor) -> None:
loss_alpha.backward()
self.optim_alpha.step()
with t.no_grad():
- self._alpha = self.log_alpha.exp().item()
+ self.alpha = self.log_alpha.exp().item()
if self.update_step % self.log_every == 0:
if self.tune_alpha:
@@ -138,20 +148,8 @@ def update_actor(self, state: t.Tensor) -> None:
self.logger.log_scalars(
{
"algo/loss_actor": loss_actor.item(),
- "algo/alpha": self._alpha,
+ "algo/alpha": self.alpha,
"algo/log_pi": log_pi.cpu().mean(),
},
self.update_step,
)
-
- def explore(self, state: npt.ArrayLike) -> npt.ArrayLike:
- state = t.tensor(state, device=self.device).unsqueeze_(0)
- with t.no_grad():
- action, _ = self.actor(state)
- return action.cpu().numpy()[0]
-
- def exploit(self, state: npt.ArrayLike) -> npt.ArrayLike:
- self.actor.eval()
- action = self.explore(state)
- self.actor.train()
- return action
diff --git a/src/oprl/algos/td3.py b/src/oprl/algos/td3.py
index b912520..066b930 100644
--- a/src/oprl/algos/td3.py
+++ b/src/oprl/algos/td3.py
@@ -1,20 +1,20 @@
from copy import deepcopy
-from dataclasses import dataclass
+from dataclasses import dataclass, field
-import numpy as np
-import numpy.typing as npt
import torch as t
from torch import nn
from torch.optim import Adam
-from oprl.algos import OffPolicyAlgorithm
-from oprl.algos.nn import DeterministicPolicy, DoubleCritic
-from oprl.algos.utils import disable_gradient, soft_update
-from oprl.utils.logger import Logger, StdLogger
+from oprl.algos.protocols import PolicyProtocol
+from oprl.algos.base_algorithm import OffPolicyAlgorithm
+from oprl.algos.nn_models import DeterministicPolicy, DoubleCritic
+from oprl.algos.nn_functions import disable_gradient, soft_update
+from oprl.logging import LoggerProtocol
@dataclass
class TD3(OffPolicyAlgorithm):
+ logger: LoggerProtocol
state_dim: int
action_dim: int
batch_size: int = 256
@@ -22,15 +22,22 @@ class TD3(OffPolicyAlgorithm):
expl_noise: float = 0.1
noise_clip: float = 0.5
policy_freq: int = 2
- discount: float = 0.99
+ gamma: float = 0.99
lr_actor: float = 3e-4
lr_critic: float = 3e-4
max_action: float = 1.0
tau: float = 5e-3
log_every: int = 5000
device: str = "cpu"
- logger: Logger = StdLogger()
+
+ actor: PolicyProtocol = field(init=False)
+ actor_target: PolicyProtocol = field(init=False)
+ optim_actor: t.optim.Optimizer = field(init=False)
+ critic: nn.Module = field(init=False)
+ critic_target: nn.Module = field(init=False)
+ optim_critic: t.optim.Optimizer = field(init=False)
update_step: int = 0
+ _created: bool = False
def create(self) -> "TD3":
self.actor = DeterministicPolicy(
@@ -38,6 +45,8 @@ def create(self) -> "TD3":
action_dim=self.action_dim,
hidden_units=(256, 256),
hidden_activation=nn.ReLU(inplace=True),
+ expl_noise=self.expl_noise,
+ device=self.device,
).to(self.device)
self.actor_target = deepcopy(self.actor).to(self.device).eval()
disable_gradient(self.actor_target)
@@ -54,34 +63,25 @@ def create(self) -> "TD3":
disable_gradient(self.critic_target)
self.optim_critic = Adam(self.critic.parameters(), lr=self.lr_critic)
- return self
-
- def exploit(self, state: npt.ArrayLike) -> npt.ArrayLike:
- state = t.tensor(state, device=self.device).unsqueeze_(0)
- with t.no_grad():
- action = self.actor(state)
- return action.cpu().numpy().flatten()
- def explore(self, state: npt.ArrayLike) -> npt.ArrayLike:
- state = t.tensor(state, device=self.device).unsqueeze_(0)
- noise = (t.randn(self.action_dim) * self.max_action * self.expl_noise).to(
- self.device
- )
-
- with t.no_grad():
- action = self.actor(state) + noise
+ self._created = True
+ return self
- a = action.cpu().numpy()[0]
- return np.clip(a, -self.max_action, self.max_action)
- def update(self, state: t.Tensor, action, reward, done, next_state) -> None:
+ def update(
+ self,
+ state: t.Tensor,
+ action: t.Tensor,
+ reward: t.Tensor,
+ done: t.Tensor,
+ next_state: t.Tensor,
+ ) -> None:
self._update_critic(state, action, reward, done, next_state)
if self.update_step % self.policy_freq == 0:
self._update_actor(state)
soft_update(self.critic_target, self.critic, self.tau)
soft_update(self.actor_target, self.actor, self.tau)
-
self.update_step += 1
def _update_critic(
@@ -105,7 +105,7 @@ def _update_critic(
q1_next, q2_next = self.critic_target(next_state, next_actions)
q_next = t.min(q1_next, q2_next)
- q_target = reward + (1.0 - done) * self.discount * q_next
+ q_target = reward + (1.0 - done) * self.gamma * q_next
td_error1 = (q1 - q_target).pow(2).mean()
td_error2 = (q2 - q_target).pow(2).mean()
diff --git a/src/oprl/algos/tqc.py b/src/oprl/algos/tqc.py
index 8ac5e40..025c184 100644
--- a/src/oprl/algos/tqc.py
+++ b/src/oprl/algos/tqc.py
@@ -1,26 +1,22 @@
import copy
-from dataclasses import dataclass
+from dataclasses import dataclass, field
import numpy as np
-import numpy.typing as npt
import torch as t
import torch.nn as nn
-from oprl.algos import OffPolicyAlgorithm
-from oprl.algos.nn import MLP, GaussianActor
-from oprl.utils.logger import Logger, StdLogger
+from oprl.algos.protocols import PolicyProtocol
+from oprl.algos.base_algorithm import OffPolicyAlgorithm
+from oprl.algos.nn_models import MLP, GaussianActor
+from oprl.logging import LoggerProtocol
def quantile_huber_loss_f(
quantiles: t.Tensor, samples: t.Tensor, device: str
) -> t.Tensor:
"""
- Args:
- quantiles: [batch, n_nets, n_quantiles].
- samples: [batch, n_nets * n_quantiles - top_quantiles_to_drop].
-
- Returns:
- loss as a torch value.
+ quantiles: [batch, n_nets, n_quantiles].
+ samples: [batch, n_nets * n_quantiles - top_quantiles_to_drop].
"""
pairwise_delta = (
samples[:, None, None, :] - quantiles[:, :, :, None]
@@ -41,7 +37,7 @@ def quantile_huber_loss_f(
class QuantileQritic(nn.Module):
- def __init__(self, state_dim: int, action_dim: int, n_quantiles: int, n_nets: int):
+ def __init__(self, state_dim: int, action_dim: int, n_quantiles: int, n_nets: int) -> None:
super().__init__()
self.nets = []
self.n_quantiles = n_quantiles
@@ -64,17 +60,31 @@ def forward(self, state: t.Tensor, action: t.Tensor) -> t.Tensor:
@dataclass
class TQC(OffPolicyAlgorithm):
+ logger: LoggerProtocol
state_dim: int
action_dim: int
- discount: float = 0.99
+ gamma: float = 0.99
+ lr_actor: float = 3e-4
+ lr_critic: float = 3e-4
+ lr_alpha: float = 3e-4
tau: float = 0.005
top_quantiles_to_drop: int = 2
n_quantiles: int = 25
n_nets: int = 5
log_every: int = 5000
device: str = "cpu"
- logger: Logger = StdLogger()
- update_step = 0
+
+ actor: PolicyProtocol = field(init=False)
+ actor_target: PolicyProtocol = field(init=False)
+ actor_optimizer: t.optim.Optimizer = field(init=False)
+ critic: QuantileQritic = field(init=False)
+ critic_target: QuantileQritic = field(init=False)
+ critic_optimizer: t.optim.Optimizer = field(init=False)
+ target_entropy: float = field(init=False)
+ alpha_optimizer: t.optim.Optimizer = field(init=False)
+ quantiles_total: int = field(init=False)
+ update_step: int = 0
+ _created: bool = False
def create(self) -> "TQC":
self.target_entropy = -np.prod(self.action_dim).item()
@@ -83,6 +93,7 @@ def create(self) -> "TQC":
self.action_dim,
hidden_units=(256, 256),
hidden_activation=nn.ReLU(),
+ device=self.device,
).to(self.device)
self.critic = QuantileQritic(
self.state_dim,
@@ -92,13 +103,14 @@ def create(self) -> "TQC":
).to(self.device)
self.critic_target = copy.deepcopy(self.critic)
self.log_alpha = t.tensor(np.log(0.2), requires_grad=True, device=self.device)
- self._quantiles_total = self.critic.n_quantiles * self.critic.n_nets
+ self.quantiles_total = self.critic.n_quantiles * self.critic.n_nets
# TODO: check hyperparams
- self.actor_optimizer = t.optim.Adam(self.actor.parameters(), lr=3e-4)
- self.critic_optimizer = t.optim.Adam(self.critic.parameters(), lr=3e-4)
- self.alpha_optimizer = t.optim.Adam([self.log_alpha], lr=3e-4)
+ self.actor_optimizer = t.optim.Adam(self.actor.parameters(), lr=self.lr_actor)
+ self.critic_optimizer = t.optim.Adam(self.critic.parameters(), lr=self.lr_critic)
+ self.alpha_optimizer = t.optim.Adam([self.log_alpha], lr=self.lr_alpha)
+ self._created = True
return self
def update(
@@ -124,11 +136,11 @@ def update(
) # batch x nets x quantiles
sorted_z, _ = t.sort(next_z.reshape(batch_size, -1))
sorted_z_part = sorted_z[
- :, : self._quantiles_total - self.top_quantiles_to_drop
+ :, : self.quantiles_total - self.top_quantiles_to_drop
]
# compute target
- target = reward + (1 - done) * self.discount * (
+ target = reward + (1 - done) * self.gamma * (
sorted_z_part - alpha * next_log_pi
)
@@ -175,15 +187,3 @@ def update(
)
self.update_step += 1
-
- def explore(self, state: npt.ArrayLike) -> npt.ArrayLike:
- state = t.tensor(state, device=self.device).unsqueeze_(0)
- with t.no_grad():
- action, _ = self.actor(state)
- return action.cpu().numpy()[0]
-
- def exploit(self, state: npt.ArrayLike) -> npt.ArrayLike:
- self.actor.eval()
- action = self.explore(state)
- self.actor.train()
- return action
diff --git a/src/oprl/algos/utils.py b/src/oprl/algos/utils.py
deleted file mode 100644
index 2ad4fd0..0000000
--- a/src/oprl/algos/utils.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-class Clamp(nn.Module):
- def forward(self, log_stds):
- return log_stds.clamp_(-20, 2)
-
-
-def initialize_weight(m, gain=nn.init.calculate_gain("relu")):
- # Initialize linear layers with the orthogonal initialization.
- if isinstance(m, nn.Linear):
- nn.init.orthogonal_(m.weight.data, gain)
- m.bias.data.fill_(0.0)
-
- # Initialize conv layers with the delta-orthogonal initialization.
- elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
- assert m.weight.size(2) == m.weight.size(3)
- m.weight.data.fill_(0.0)
- m.bias.data.fill_(0.0)
- mid = m.weight.size(2) // 2
- nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)
-
-
-def soft_update(target, source, tau):
- """Update target network using Polyak-Ruppert Averaging."""
- with torch.no_grad():
- for tgt, src in zip(target.parameters(), source.parameters()):
- tgt.data.mul_(1.0 - tau)
- tgt.data.add_(tau * src.data)
-
-
-def disable_gradient(network):
- """Disable gradient calculations of the network."""
- for param in network.parameters():
- param.requires_grad = False
diff --git a/src/oprl/buffers/__init__.py b/src/oprl/buffers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/oprl/buffers/episodic_buffer.py b/src/oprl/buffers/episodic_buffer.py
new file mode 100644
index 0000000..2c7bc69
--- /dev/null
+++ b/src/oprl/buffers/episodic_buffer.py
@@ -0,0 +1,140 @@
+from dataclasses import dataclass, field
+
+import numpy as np
+import numpy.typing as npt
+import torch as t
+
+from oprl.buffers.protocols import ReplayBufferProtocol
+
+
+Transition = tuple[npt.NDArray, npt.NDArray, float, bool, npt.NDArray]
+
+
+@dataclass
+class EpisodicReplayBuffer(ReplayBufferProtocol):
+ buffer_size_transitions: int
+ state_dim: int
+ action_dim: int
+ gamma: float = 0.99
+ max_episode_lenth: int = 1000
+ episodes_counter: int = 1
+ device: str = "cpu"
+
+ _tensors: dict[str, t.Tensor] = field(init=False)
+ _max_episodes: int = field(init=False)
+ _ep_pointer: int = 0
+ _number_transitions: int = 0
+ _created: bool = False
+
+ def create(self) -> "EpisodicReplayBuffer":
+ self._max_episodes = self.buffer_size_transitions // self.max_episode_lenth
+ self._tensors = {
+ "actions": t.empty(
+ (self._max_episodes, self.max_episode_lenth, self.action_dim),
+ dtype=t.float32,
+ device=self.device,
+ ),
+ "rewards": t.empty(
+ (self._max_episodes, self.max_episode_lenth, 1),
+ dtype=t.float32,
+ device=self.device
+ ),
+ "dones": t.empty(
+ (self._max_episodes, self.max_episode_lenth, 1),
+ dtype=t.float32,
+ device=self.device
+ ),
+ "states": t.empty(
+ (self._max_episodes, self.max_episode_lenth + 1, self.state_dim),
+ dtype=t.float32,
+ device=self.device,
+ ),
+ }
+ self.ep_lens = [0] * self._max_episodes
+ self._created = True
+ return self
+
+ def check_created(self) -> None:
+ if not self._created:
+ raise RuntimeError("Replay buffer has to be created with `.create()`.")
+
+ @property
+ def states(self) -> t.Tensor:
+ self.check_created()
+ return self._tensors["states"]
+
+ @property
+ def actions(self) -> t.Tensor:
+ self.check_created()
+ return self._tensors["actions"]
+
+ @property
+ def rewards(self) -> t.Tensor:
+ self.check_created()
+ return self._tensors["rewards"]
+
+ @property
+ def dones(self) -> t.Tensor:
+ self.check_created()
+ return self._tensors["dones"]
+
+ def add_transition(
+ self,
+ state: npt.NDArray,
+ action: npt.NDArray,
+ reward: float,
+ done: bool,
+ episode_done: bool | None = None
+ ) -> None:
+ self.states[self._ep_pointer, self.ep_lens[self._ep_pointer]].copy_(
+ t.from_numpy(state)
+ )
+ self.actions[self._ep_pointer, self.ep_lens[self._ep_pointer]].copy_(
+ t.from_numpy(action)
+ )
+ self.rewards[self._ep_pointer, self.ep_lens[self._ep_pointer]] = reward
+ self.dones[self._ep_pointer, self.ep_lens[self._ep_pointer]] = float(done)
+ self.ep_lens[self._ep_pointer] += 1
+ self._number_transitions = min(self._number_transitions + 1, self.buffer_size_transitions)
+ # TODO: Switch to the episodic append and remove condition below
+ if episode_done:
+ self._inc_episode()
+
+ def _inc_episode(self):
+ self._ep_pointer = (self._ep_pointer + 1) % self._max_episodes
+ self.episodes_counter = min(self.episodes_counter + 1, self._max_episodes)
+ self._number_transitions -= self.ep_lens[self._ep_pointer]
+ self.ep_lens[self._ep_pointer] = 0
+
+ def add_episode(self, episode: list[Transition]) -> None:
+ for s, a, r, d, _ in episode:
+ self.add_transition(s, a, r, d, episode_done=d)
+ self._inc_episode()
+
+ def _inds_to_episodic(self, inds: npt.NDArray) -> tuple[npt.NDArray, npt.NDArray]:
+ start_inds = np.cumsum([0] + self.ep_lens[: self.episodes_counter - 1])
+ end_inds = start_inds + np.array(self.ep_lens[: self.episodes_counter])
+ ep_inds = np.argmin(
+ inds.reshape(-1, 1) >= np.tile(end_inds, (len(inds), 1)), axis=1
+ )
+ step_inds = inds - start_inds[ep_inds]
+ return ep_inds, step_inds
+
+ def sample(self, batch_size: int) -> tuple[t.Tensor, t.Tensor, t.Tensor, t.Tensor, t.Tensor]:
+ inds = np.random.randint(low=0, high=self._number_transitions, size=batch_size)
+ ep_inds, step_inds = self._inds_to_episodic(inds)
+
+ return (
+ self.states[ep_inds, step_inds],
+ self.actions[ep_inds, step_inds],
+ self.rewards[ep_inds, step_inds],
+ self.dones[ep_inds, step_inds],
+ self.states[ep_inds, step_inds + 1],
+ )
+
+ @property
+ def last_episode_length(self) -> int:
+ return self.ep_lens[self._ep_pointer]
+
+ def __len__(self) -> int:
+ return self._number_transitions
diff --git a/src/oprl/buffers/protocols.py b/src/oprl/buffers/protocols.py
new file mode 100644
index 0000000..beac0c0
--- /dev/null
+++ b/src/oprl/buffers/protocols.py
@@ -0,0 +1,27 @@
+from typing import Protocol, runtime_checkable
+
+import torch as t
+
+
+@runtime_checkable
+class ReplayBufferProtocol(Protocol):
+ episodes_counter: int
+ _created: bool
+
+ def create(self) -> "ReplayBufferProtocol": ...
+
+ def check_created(self) -> None: ...
+
+ def add_transition(self, state, action, reward, done, episode_done=None): ...
+
+ def add_episode(self, episode): ...
+
+ def sample(self, batch_size) -> tuple[
+ t.Tensor, t.Tensor, t.Tensor, t.Tensor, t.Tensor
+ ]: ...
+
+ def __len__(self) -> int: ...
+
+ @property
+ def last_episode_length(self) -> int: ...
+
diff --git a/src/oprl/configs/d3pg.py b/src/oprl/configs/d3pg.py
deleted file mode 100644
index 93b60cc..0000000
--- a/src/oprl/configs/d3pg.py
+++ /dev/null
@@ -1,113 +0,0 @@
-import argparse
-import logging
-from multiprocessing import Process
-
-import torch.nn as nn
-
-from oprl.algos.ddpg import DDPG, DeterministicPolicy
-from oprl.configs.utils import create_logdir
-from oprl.distrib.distrib_runner import env_worker, policy_update_worker
-from oprl.utils.utils import set_logging
-
-set_logging(logging.INFO)
-from oprl.env import make_env as _make_env
-from oprl.trainers.buffers.episodic_buffer import EpisodicReplayBuffer
-from oprl.utils.logger import FileLogger, Logger
-
-
-def parse_args():
- parser = argparse.ArgumentParser(description="Run training")
- parser.add_argument("--config", type=str, help="Path to the config file.")
- parser.add_argument(
- "--env", type=str, default="cartpole-balance", help="Name of the environment."
- )
- parser.add_argument(
- "--device", type=str, default="cpu", help="Device to perform training on."
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- return parser.parse_args()
-
-
-# -------- Distrib params -----------
-
-ENV_WORKERS = 4
-N_EPISODES = 50 # 500 # Number of episodes each env worker would perform
-
-# -----------------------------------
-
-args = parse_args()
-
-
-def make_env(seed: int):
- return _make_env(args.env, seed=seed)
-
-
-env = make_env(seed=0)
-STATE_DIM = env.observation_space.shape[0]
-ACTION_DIM = env.action_space.shape[0]
-logging.info(f"Env state {STATE_DIM}\tEnv action {ACTION_DIM}")
-
-
-log_dir = create_logdir(logdir="logs", algo="D3PG", env=args.env, seed=args.seed)
-logging.info(f"LOG_DIR: {log_dir}")
-
-
-def make_logger(seed: int) -> Logger:
- log_dir = create_logdir(logdir="logs", algo="D3PG", env=args.env, seed=seed)
- # TODO: add here actual config
- return FileLogger(log_dir, {})
-
-
-def make_policy():
- return DeterministicPolicy(
- state_dim=STATE_DIM,
- action_dim=ACTION_DIM,
- hidden_units=(256, 256),
- hidden_activation=nn.ReLU(inplace=True),
- device=args.device,
- )
-
-
-def make_buffer():
- return EpisodicReplayBuffer(
- buffer_size=int(1_000_000),
- state_dim=STATE_DIM,
- action_dim=ACTION_DIM,
- device=args.device,
- gamma=0.99,
- )
-
-
-def make_algo():
- logger = make_logger(args.seed)
-
- algo = DDPG(
- state_dim=STATE_DIM,
- action_dim=ACTION_DIM,
- device=args.device,
- logger=logger,
- )
- return algo
-
-
-if __name__ == "__main__":
- processes = []
-
- for i_env in range(ENV_WORKERS):
- processes.append(
- Process(target=env_worker, args=(make_env, make_policy, N_EPISODES, i_env))
- )
- processes.append(
- Process(
- target=policy_update_worker,
- args=(make_algo, make_env, make_buffer, ENV_WORKERS),
- )
- )
-
- for p in processes:
- p.start()
-
- for p in processes:
- p.join()
-
- logging.info("Training OK.")
diff --git a/src/oprl/configs/ddpg.py b/src/oprl/configs/ddpg.py
deleted file mode 100644
index 22437f0..0000000
--- a/src/oprl/configs/ddpg.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import logging
-
-from oprl.algos.ddpg import DDPG
-from oprl.configs.utils import create_logdir, parse_args
-from oprl.utils.utils import set_logging
-
-set_logging(logging.INFO)
-from oprl.env import make_env as _make_env
-from oprl.utils.logger import FileLogger, Logger
-from oprl.utils.run_training import run_training
-
-args = parse_args()
-
-
-def make_env(seed: int):
- return _make_env(args.env, seed=seed)
-
-
-env = make_env(seed=0)
-STATE_DIM: int = env.observation_space.shape[0]
-ACTION_DIM: int = env.action_space.shape[0]
-
-
-# -------- Config params -----------
-
-config = {
- "state_dim": STATE_DIM,
- "action_dim": ACTION_DIM,
- "num_steps": int(100_000),
- "eval_every": 2500,
- "device": args.device,
- "save_buffer": False,
- "visualise_every": 50000,
- "estimate_q_every": 5000,
- "log_every": 2500,
-}
-
-# -----------------------------------
-
-
-def make_algo(logger):
- return DDPG(
- state_dim=STATE_DIM,
- action_dim=ACTION_DIM,
- device=args.device,
- logger=logger,
- )
-
-
-def make_logger(seed: int) -> Logger:
- log_dir = create_logdir(logdir="logs", algo="DDPG", env=args.env, seed=seed)
- return FileLogger(log_dir, config)
-
-
-if __name__ == "__main__":
- args = parse_args()
- run_training(make_algo, make_env, make_logger, config, args.seeds, args.start_seed)
diff --git a/src/oprl/configs/sac.py b/src/oprl/configs/sac.py
deleted file mode 100644
index 475fd0e..0000000
--- a/src/oprl/configs/sac.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import logging
-
-from oprl.algos.sac import SAC
-from oprl.configs.utils import create_logdir, parse_args
-from oprl.utils.utils import set_logging
-
-set_logging(logging.INFO)
-from oprl.env import make_env as _make_env
-from oprl.utils.logger import FileLogger, Logger
-from oprl.utils.run_training import run_training
-
-logging.basicConfig(level=logging.INFO)
-
-args = parse_args()
-
-
-def make_env(seed: int):
- return _make_env(args.env, seed=seed)
-
-
-env = make_env(seed=0)
-STATE_DIM: int = env.observation_space.shape[0]
-ACTION_DIM: int = env.action_space.shape[0]
-
-
-# -------- Config params -----------
-
-config = {
- "state_dim": STATE_DIM,
- "action_dim": ACTION_DIM,
- "num_steps": int(1_000_000),
- "eval_every": 2500,
- "device": args.device,
- "save_buffer": False,
- "visualise_every": 0,
- "estimate_q_every": 5000,
- "log_every": 1000,
-}
-
-# -----------------------------------
-
-
-def make_algo(logger):
- return SAC(
- state_dim=STATE_DIM,
- action_dim=ACTION_DIM,
- device=args.device,
- logger=logger,
- )
-
-
-def make_logger(seed: int) -> Logger:
- log_dir = create_logdir(logdir="logs", algo="SAC", env=args.env, seed=seed)
- return FileLogger(log_dir, config)
-
-
-if __name__ == "__main__":
- args = parse_args()
- run_training(make_algo, make_env, make_logger, config, args.seeds, args.start_seed)
diff --git a/src/oprl/configs/td3.py b/src/oprl/configs/td3.py
deleted file mode 100644
index c8dac65..0000000
--- a/src/oprl/configs/td3.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import logging
-
-from oprl.algos.td3 import TD3
-from oprl.configs.utils import create_logdir, parse_args
-from oprl.utils.utils import set_logging
-
-set_logging(logging.INFO)
-from oprl.env import make_env as _make_env
-from oprl.utils.logger import FileLogger, Logger
-from oprl.utils.run_training import run_training
-
-args = parse_args()
-
-
-def make_env(seed: int):
- return _make_env(args.env, seed=seed)
-
-
-env = make_env(seed=0)
-STATE_DIM: int = env.observation_space.shape[0]
-ACTION_DIM: int = env.action_space.shape[0]
-
-
-# -------- Config params -----------
-
-config = {
- "state_dim": STATE_DIM,
- "action_dim": ACTION_DIM,
- "num_steps": int(1_000_000),
- "eval_every": 2500,
- "device": args.device,
- "save_buffer": False,
- "visualise_every": 0,
- "estimate_q_every": 5000,
- "log_every": 2500,
-}
-
-# -----------------------------------
-
-
-def make_algo(logger):
- return TD3(
- state_dim=STATE_DIM,
- action_dim=ACTION_DIM,
- device=args.device,
- logger=logger,
- )
-
-
-def make_logger(seed: int) -> Logger:
- log_dir = create_logdir(logdir="logs", algo="TD3", env=args.env, seed=seed)
- return FileLogger(log_dir, config)
-
-
-if __name__ == "__main__":
- args = parse_args()
- run_training(make_algo, make_env, make_logger, config, args.seeds, args.start_seed)
diff --git a/src/oprl/configs/tqc.py b/src/oprl/configs/tqc.py
deleted file mode 100644
index 640071e..0000000
--- a/src/oprl/configs/tqc.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import logging
-
-from oprl.algos.tqc import TQC
-from oprl.configs.utils import create_logdir, parse_args
-from oprl.utils.utils import set_logging
-
-set_logging(logging.INFO)
-from oprl.env import make_env as _make_env
-from oprl.utils.logger import FileLogger, Logger
-from oprl.utils.run_training import run_training
-
-args = parse_args()
-
-
-def make_env(seed: int):
- return _make_env(args.env, seed=seed)
-
-
-env = make_env(seed=0)
-STATE_DIM: int = env.observation_space.shape[0]
-ACTION_DIM: int = env.action_space.shape[0]
-
-
-# -------- Config params -----------
-
-config = {
- "state_dim": STATE_DIM,
- "action_dim": ACTION_DIM,
- "num_steps": int(1_000_000),
- "eval_every": 2500,
- "device": args.device,
- "save_buffer": False,
- "visualise_every": 0,
- "estimate_q_every": 0, # TODO: Here is the unsupported logic
- "log_every": 2500,
-}
-
-# -----------------------------------
-
-
-def make_algo(logger: Logger):
- return TQC(
- state_dim=STATE_DIM,
- action_dim=ACTION_DIM,
- device=args.device,
- logger=logger,
- )
-
-
-def make_logger(seed: int) -> Logger:
- log_dir = create_logdir(logdir="logs", algo="TQC", env=args.env, seed=seed)
- return FileLogger(log_dir, config)
-
-
-if __name__ == "__main__":
- args = parse_args()
- run_training(make_algo, make_env, make_logger, config, args.seeds, args.start_seed)
diff --git a/src/oprl/distrib/__init__.py b/src/oprl/distrib/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/oprl/distrib/distrib_runner.py b/src/oprl/distrib/distrib_runner.py
deleted file mode 100644
index 58698be..0000000
--- a/src/oprl/distrib/distrib_runner.py
+++ /dev/null
@@ -1,159 +0,0 @@
-import logging
-import pickle
-import time
-from itertools import count
-from multiprocessing import Process
-
-import numpy as np
-import pika
-import torch
-
-
-class Queue:
- def __init__(self, name: str, host: str = "localhost"):
- self._name = name
-
- connection = pika.BlockingConnection(pika.ConnectionParameters(host=host))
- self.channel = connection.channel()
- self.channel.queue_declare(queue=name)
-
- def push(self, data) -> None:
- self.channel.basic_publish(exchange="", routing_key=self._name, body=data)
-
- def pop(self) -> bytes | None:
- method_frame, header_frame, body = self.channel.basic_get(queue=self._name)
- if method_frame:
- self.channel.basic_ack(method_frame.delivery_tag)
- return body
- return None
-
-
-def env_worker(make_env, make_policy, n_episodes, id_worker):
- env = make_env(seed=0)
- logging.info("Env created.")
-
- policy = make_policy()
- logging.info("Policy created.")
-
- q_env = Queue(f"env_{id_worker}")
- q_policy = Queue(f"policy_{id_worker}")
- logging.info("Queue created.")
-
- episodes = []
-
- total_env_step = 0
- # TODO: Move parameter to config
- start_steps = 1000
- for i_ep in range(n_episodes):
- if i_ep % 10 == 0:
- logging.info(f"AGENT {id_worker} EPISODE {i_ep}")
-
- episode = []
- state, _ = env.reset()
- # TODO: Move parameter to config
- for env_step in range(1000):
- if total_env_step <= start_steps:
- action = env.sample_action()
- else:
- action = policy.explore(state)
-
- next_state, reward, terminated, truncated, _ = env.step(action)
- episode.append([state, action, reward, terminated, next_state])
-
- if terminated or truncated:
- break
- state = next_state
- total_env_step += 1
-
- q_env.push(pickle.dumps(episode))
-
- while True:
- data = q_policy.pop()
- if data is None:
- logging.info("Waiting for the policy..")
- time.sleep(2.0)
- continue
-
- policy.load_state_dict(pickle.loads(data))
- break
-
- logging.info("Episode by env worker is done.")
-
-
-def policy_update_worker(make_algo, make_env_test, make_buffer, n_workers):
- algo = make_algo()
- logging.info("Algo created.")
- buffer = make_buffer()
- logging.info("Buffer created.")
-
- q_envs = []
- q_policies = []
- for i_env in range(n_workers):
- q_envs.append(Queue(f"env_{i_env}"))
- q_policies.append(Queue(f"policy_{i_env}"))
- logging.info("Learner queue created.")
-
- batch_size = 128
-
- logging.info("Warming up the learner...")
- time.sleep(2.0)
-
- for i_epoch in count(0):
- logging.info(f"Epoch: {i_epoch}")
- n_waits = 0
- for i_env in range(n_workers):
- while True:
- data = q_envs[i_env].pop()
- if data:
- episode = pickle.loads(data)
- buffer.add_episode(episode)
- break
- else:
- logging.info("Waiting for the env data...")
- # TODO: not optimal wait for each queue
- time.sleep(1)
- n_waits += 1
- if n_waits == 10:
- logging.info("Learner tired to wait, exiting...")
- return
- continue
-
- # TODO: Remove hardcoded value
- if i_epoch > 16:
- for i in range(1000 * 4):
- batch = buffer.sample(batch_size)
- algo.update(*batch)
- if i % int(1000) == 0:
- logging.info(f"\tUpdating {i}")
-
- policy_state_dict = algo.get_policy_state_dict()
-
- policy_serialized = pickle.dumps(policy_state_dict)
- for i_env in range(n_workers):
- q_policies[i_env].push(policy_serialized)
-
- if True:
- mean_reward = evaluate(algo, make_env_test)
- algo.logger.log_scalar("trainer/ep_reward", mean_reward, i_epoch)
-
- logging.info("Update by policy update worker done.")
-
-
-def evaluate(algo, make_env_test, num_eval_episodes: int = 5, seed: int = 0):
- returns = []
- for i_ep in range(num_eval_episodes):
- env_test = make_env_test(seed * 100 + i_ep)
- state, _ = env_test.reset()
-
- episode_return = 0.0
- terminated, truncated = False, False
-
- while not (terminated or truncated):
- action = algo.exploit(state)
- state, reward, terminated, truncated, _ = env_test.step(action)
- episode_return += reward
-
- returns.append(episode_return)
-
- mean_return = np.mean(returns)
- return mean_return
diff --git a/src/oprl/distrib/env_worker.py b/src/oprl/distrib/env_worker.py
new file mode 100644
index 0000000..67bcdeb
--- /dev/null
+++ b/src/oprl/distrib/env_worker.py
@@ -0,0 +1,65 @@
+import pickle
+import time
+from typing import Callable
+
+from oprl.algos.protocols import PolicyProtocol
+from oprl.environment.protocols import EnvProtocol
+from oprl.logging import create_stdout_logger
+from oprl.runners.config import DistribConfig
+from oprl.distrib.queue import Queue
+
+
+logger = create_stdout_logger()
+
+
+def run_env_worker(
+ make_env: Callable[[int], EnvProtocol],
+ make_policy: Callable[[], PolicyProtocol],
+ config: DistribConfig,
+ id_worker: int,
+) -> None:
+ env = make_env(seed=0)
+ logger.info("Env created.")
+
+ policy = make_policy()
+ logger.info("Policy created.")
+
+ q_env = Queue(f"env_{id_worker}")
+ q_policy = Queue(f"policy_{id_worker}")
+ logger.info("Queue created.")
+
+ total_env_step = 0
+ for i_ep in range(config.episodes_per_worker):
+ print("Running episode: ", i_ep)
+ if i_ep % 10 == 0:
+ logger.info(f"AGENT {id_worker} EPISODE {i_ep}")
+
+ episode = []
+ state, _ = env.reset()
+ for _ in range(config.episode_length):
+ if total_env_step <= config.warmup_env_steps:
+ action = env.sample_action()
+ else:
+ action = policy.explore(state)
+
+ next_state, reward, terminated, truncated, _ = env.step(action)
+ episode.append([state, action, reward, terminated, next_state])
+
+ if terminated or truncated:
+ break
+ state = next_state
+ total_env_step += 1
+
+ q_env.push(pickle.dumps(episode))
+
+ while True:
+ data = q_policy.pop()
+ if data is None:
+ logger.info("Waiting for the policy..")
+ time.sleep(2.0)
+ continue
+ policy.load_state_dict(pickle.loads(data))
+ break
+
+ logger.info("Episode by env worker is done.")
+
diff --git a/src/oprl/distrib/policy_update_worker.py b/src/oprl/distrib/policy_update_worker.py
new file mode 100644
index 0000000..4dea035
--- /dev/null
+++ b/src/oprl/distrib/policy_update_worker.py
@@ -0,0 +1,119 @@
+import pickle
+import time
+from itertools import count
+from pathlib import Path
+from typing import Callable
+
+import numpy as np
+import torch as t
+import torch.nn as nn
+
+from oprl.algos.protocols import AlgorithmProtocol
+from oprl.environment.protocols import EnvProtocol
+from oprl.logging import create_stdout_logger, LoggerProtocol
+from oprl.runners.config import DistribConfig
+from oprl.buffers.protocols import ReplayBufferProtocol
+from oprl.distrib.queue import Queue
+
+
+logger = create_stdout_logger()
+
+
+def run_policy_update_worker(
+ make_algo: Callable[[LoggerProtocol], AlgorithmProtocol],
+ make_env_test: Callable[[int], EnvProtocol],
+ make_buffer: Callable[[], ReplayBufferProtocol],
+ make_logger: Callable[[], LoggerProtocol],
+ config: DistribConfig,
+) -> None:
+ scalar_logger = make_logger()
+ algo = make_algo(scalar_logger)
+ logger.info("Algo created.")
+ buffer = make_buffer()
+ logger.info("Buffer created.")
+
+ q_envs = []
+ q_policies = []
+ for i_env in range(config.num_env_workers):
+ q_envs.append(Queue(f"env_{i_env}"))
+ q_policies.append(Queue(f"policy_{i_env}"))
+ logger.info("Learner queue created.")
+
+ logger.info("Warming up the learner...")
+ time.sleep(2.0)
+
+ for i_epoch in count(0):
+ logger.info(f"Epoch: {i_epoch}")
+ n_waits = 0
+ for i_env in range(config.num_env_workers):
+ while True:
+ data = q_envs[i_env].pop()
+ if data:
+ episode = pickle.loads(data)
+ buffer.add_episode(episode)
+ break
+ else:
+ logger.info("Waiting for the env data...")
+ # TODO: not optimal wait for each queue
+ time.sleep(1)
+ n_waits += 1
+ if n_waits == config.learner_num_waits:
+ logger.info("Learner is not receiving data, exiting...")
+ return
+ continue
+
+ if i_epoch > config.warmup_epochs:
+ for i in range(config.episode_length * config.num_env_workers):
+ batch = buffer.sample(config.batch_size)
+ algo.update(*batch)
+ if i % int(1000) == 0:
+ logger.info(f"\tUpdating {i}")
+
+ policy_state_dict = algo.get_policy_state_dict()
+
+ policy_serialized = pickle.dumps(policy_state_dict)
+ for i_env in range(config.num_env_workers):
+ q_policies[i_env].push(policy_serialized)
+
+
+ if i_epoch > 0 and i_epoch % 10 == 0:
+ mean_reward = evaluate(algo, make_env_test)
+ logger.info(f"Evaluating policy [epoch {i_epoch}]: {mean_reward}")
+ algo.logger.log_scalar("trainer/ep_reward", mean_reward, i_epoch)
+ save_policy(
+ policy=algo.actor,
+ save_path=algo.logger.log_dir / "weights" / f"epoch_{i_epoch}.w"
+ )
+ logger.info("Weights saved.")
+
+ logger.info("Update by policy update worker done.")
+
+
+def save_policy(policy: nn.Module, save_path: Path) -> None:
+ save_path.parents[0].mkdir(exist_ok=True)
+ t.save(
+ policy,
+ save_path
+ )
+
+
+def evaluate(
+    algo: AlgorithmProtocol,
+    make_env_test: Callable[[int], EnvProtocol],
+    num_eval_episodes: int = 5,
+    seed: int = 0
+) -> float:
+    """Run the deterministic policy for several episodes; return the mean return."""
+    returns = []
+    for i_ep in range(num_eval_episodes):
+        env_test = make_env_test(seed * 100 + i_ep)
+        state, _ = env_test.reset()
+        episode_return = 0.0
+        terminated, truncated = False, False
+        while not (terminated or truncated):
+            action = algo.actor.exploit(state)
+            state, reward, terminated, truncated, _ = env_test.step(action)
+            episode_return += reward
+        returns.append(episode_return)
+    # Cast: np.mean returns np.floating, but the signature promises a plain float.
+    return float(np.mean(returns))
diff --git a/src/oprl/distrib/queue.py b/src/oprl/distrib/queue.py
new file mode 100644
index 0000000..cfa9e91
--- /dev/null
+++ b/src/oprl/distrib/queue.py
@@ -0,0 +1,20 @@
+import pika
+
+
+class Queue:
+ def __init__(self, name: str, host: str = "localhost") -> None:
+ self._name = name
+ connection = pika.BlockingConnection(pika.ConnectionParameters(host=host))
+ self.channel = connection.channel()
+ self.channel.queue_declare(queue=name)
+
+ def push(self, data) -> None:
+ self.channel.basic_publish(exchange="", routing_key=self._name, body=data)
+
+ def pop(self) -> bytes | None:
+ method_frame, _, body = self.channel.basic_get(queue=self._name)
+ if method_frame:
+ self.channel.basic_ack(method_frame.delivery_tag)
+ return body
+ return None
+
diff --git a/src/oprl/distrib_train.py b/src/oprl/distrib_train.py
deleted file mode 100644
index 4810768..0000000
--- a/src/oprl/distrib_train.py
+++ /dev/null
@@ -1,108 +0,0 @@
-import argparse
-import os
-import time
-from datetime import datetime
-from multiprocessing import Process
-
-import torch
-import torch.nn as nn
-from algos.ddpg import DDPG, DeterministicPolicy
-from distrib.distrib_runner import env_worker, policy_update_worker
-from env import DMControlEnv, make_env
-from trainers.buffers.episodic_buffer import EpisodicReplayBuffer
-from utils.logger import Logger
-
-print("Imports ok.")
-
-
-def parse_args():
- parser = argparse.ArgumentParser(description="Run training")
- parser.add_argument("--config", type=str, help="Path to the config file.")
- parser.add_argument(
- "--env", type=str, default="cartpole-balance", help="Name of the environment."
- )
- parser.add_argument(
- "--device", type=str, default="cpu", help="Device to perform training on."
- )
- return parser.parse_args()
-
-
-args = parse_args()
-
-
-def make_env(seed: int):
- """
- Args:
- name: Environment name.
- """
- return DMControlEnv(args.env, seed=seed)
-
-
-env = make_env(seed=0)
-
-STATE_SHAPE = env.observation_space.shape
-ACTION_SHAPE = env.action_space.shape
-print("STATE ACTION SHAPE: ", STATE_SHAPE, ACTION_SHAPE)
-
-
-def make_policy():
- return DeterministicPolicy(
- state_dim=STATE_SHAPE,
- action_dim=ACTION_SHAPE,
- hidden_units=[256, 256],
- hidden_activation=nn.ReLU(inplace=True),
- )
-
-
-def make_buffer():
- buffer = EpisodicReplayBuffer(
- buffer_size=int(100_000),
- state_shape=STATE_SHAPE,
- action_shape=ACTION_SHAPE,
- device="cpu",
- gamma=0.99,
- )
- return buffer
-
-
-def make_algo():
- time = datetime.now().strftime("%Y-%m-%d_%H_%M")
- log_dir = os.path.join("logs_debug", "DDPG", f"DDPG-env_ENV-seedSEED-{time}")
- print("LOGDIR: ", log_dir)
- logger = Logger(log_dir, {})
-
- algo = DDPG(
- state_dim=STATE_SHAPE,
- action_dim=ACTION_SHAPE,
- device="cpu",
- seed=0,
- logger=logger,
- )
- return algo
-
-
-if __name__ == "__main__":
- ENV_WORKERS = 2
-
- seed = 0
-
- processes = []
-
- for i_env in range(ENV_WORKERS):
- processes.append(
- Process(target=env_worker, args=(make_env, make_policy, i_env))
- )
- processes.append(
- Process(
- target=policy_update_worker,
- args=(make_algo, make_env, make_buffer, ENV_WORKERS),
- )
- )
-
- for p in processes:
- p.start()
-
- for p in processes:
- p.join()
-
- print("OK.")
diff --git a/src/oprl/env.py b/src/oprl/env.py
deleted file mode 100644
index 6017ab2..0000000
--- a/src/oprl/env.py
+++ /dev/null
@@ -1,230 +0,0 @@
-from abc import ABC, abstractmethod
-from collections import OrderedDict
-from typing import Any
-
-import numpy as np
-import numpy.typing as npt
-from dm_control import suite
-
-
-class BaseEnv(ABC):
- @abstractmethod
- def reset(self) -> tuple[npt.ArrayLike, dict[str, Any]]:
- pass
-
- @abstractmethod
- def step(
- self, action: npt.ArrayLike
- ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]:
- pass
-
- @abstractmethod
- def sample_action(self) -> npt.ArrayLike:
- pass
-
- @property
- def env_family(self) -> str:
- return ""
-
-
-class DummyEnv(BaseEnv):
- def reset(self) -> tuple[npt.ArrayLike, dict[str, Any]]:
- return np.array([]), {}
-
- def step(
- self, action: npt.ArrayLike
- ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]:
- return np.array([]), np.array([]), False, False, {}
-
- def sample_action(self) -> npt.ArrayLike:
- return np.array([])
-
- @property
- def env_family(self) -> str:
- return ""
-
-
-class SafetyGym(BaseEnv):
- def __init__(self, env_name: str, seed: int):
- import safety_gymnasium as gym
-
- self._env = gym.make(env_name)
- self._seed = seed
-
- def step(
- self, action: npt.ArrayLike
- ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]:
- obs, reward, cost, terminated, truncated, info = self._env.step(action)
- info["cost"] = cost
- return obs.astype("float32"), reward, terminated, truncated, info
-
- def reset(self) -> tuple[npt.ArrayLike, dict[str, Any]]:
- obs, info = self._env.reset(seed=self._seed)
- self._env.step(self._env.action_space.sample())
- return obs.astype("float32"), info
-
- def sample_action(self):
- return self._env.action_space.sample()
-
- @property
- def observation_space(self):
- return self._env.observation_space
-
- @property
- def action_space(self):
- return self._env.action_space
-
- @property
- def env_family(self) -> str:
- return "safety_gymnasium"
-
-
-class DMControlEnv(BaseEnv):
- def __init__(self, env: str, seed: int):
- domain, task = env.split("-")
- self.random_state = np.random.RandomState(seed)
- self.env = suite.load(domain, task, task_kwargs={"random": self.random_state})
-
- self._render_width = 200
- self._render_height = 200
- self._camera_id = 0
-
- def reset(self, *args, **kwargs) -> tuple[npt.ArrayLike, dict[str, Any]]:
- obs = self._flat_obs(self.env.reset().observation)
- return obs, {}
-
- def step(
- self, action: npt.ArrayLike
- ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]:
- time_step = self.env.step(action)
- obs = self._flat_obs(time_step.observation)
-
- terminated = False
- truncated = self.env._step_count >= self.env._step_limit
-
- return obs, time_step.reward, terminated, truncated, {}
-
- def sample_action(self) -> npt.ArrayLike:
- spec = self.env.action_spec()
- action = self.random_state.uniform(spec.minimum, spec.maximum, spec.shape)
- return action
-
- @property
- def observation_space(self) -> npt.ArrayLike:
- return np.zeros(
- sum(int(np.prod(v.shape)) for v in self.env.observation_spec().values())
- )
-
- @property
- def action_space(self) -> npt.ArrayLike:
- return np.zeros(self.env.action_spec().shape[0])
-
- def render(self) -> npt.ArrayLike:
- """
- returned shape: [1, W, H, C]
- """
- img = self.env.physics.render(
- camera_id=self._camera_id,
- height=self._render_width,
- width=self._render_width,
- )
- img = img.astype(np.uint8)
- return img
-
- def _flat_obs(self, obs: OrderedDict) -> npt.ArrayLike:
- obs_flatten = []
- for _, o in obs.items():
- if len(o.shape) == 0:
- obs_flatten.append(np.array([o]))
- elif len(o.shape) == 2 and o.shape[1] > 1:
- obs_flatten.append(o.flatten())
- else:
- obs_flatten.append(o)
- return np.concatenate(obs_flatten, dtype="float32")
-
- @property
- def env_family(self) -> str:
- return "dm_control"
-
-
-ENV_MAPPER = {
- "dm_control": set(
- [
- "acrobot-swingup",
- "ball_in_cup-catch",
- "cartpole-balance",
- "cartpole-swingup",
- "cheetah-run",
- "finger-spin",
- "finger-turn_easy",
- "finger-turn_hard",
- "fish-upright",
- "fish-swim",
- "hopper-stand",
- "hopper-hop",
- "humanoid-stand",
- "humanoid-walk",
- "humanoid-run",
- "pendulum-swingup",
- "point_mass-easy",
- "reacher-easy",
- "reacher-hard",
- "swimmer-swimmer6",
- "swimmer-swimmer15",
- "walker-stand",
- "walker-walk",
- "walker-run",
- ]
- ),
- "safety_gymnasium": set(
- [
- "SafetyPointGoal1-v0",
- "SafetyPointGoal2-v0",
- "SafetyPointButton1-v0",
- "SafetyPointButton2-v0",
- "SafetyPointPush1-v0",
- "SafetyPointPush2-v0",
- "SafetyPointCircle1-v0",
- "SafetyPointCircle2-v0",
- "SafetyCarGoal1-v0",
- "SafetyCarGoal2-v0",
- "SafetyCarButton1-v0",
- "SafetyCarButton2-v0",
- "SafetyCarPush1-v0",
- "SafetyCarPush2-v0",
- "SafetyCarCircle1-v0",
- "SafetyCarCircle2-v0",
- "SafetyAntGoal1-v0",
- "SafetyAntGoal2-v0",
- "SafetyAntButton1-v0",
- "SafetyAntButton2-v0",
- "SafetyAntPush1-v0",
- "SafetyAntPush2-v0",
- "SafetyAntCircle1-v0",
- "SafetyAntCircle2-v0",
- "SafetyDoggoGoal1-v0",
- "SafetyDoggoGoal2-v0",
- "SafetyDoggoButton1-v0",
- "SafetyDoggoButton2-v0",
- "SafetyDoggoPush1-v0",
- "SafetyDoggoPush2-v0",
- "SafetyDoggoCircle1-v0",
- "SafetyDoggoCircle2-v0",
- ]
- ),
-}
-
-
-def make_env(name: str, seed: int):
- """
- Args:
- name: Environment name.
- """
- for env_type, env_set in ENV_MAPPER.items():
- if name in env_set:
- if env_type == "dm_control":
- return DMControlEnv(name, seed=seed)
- elif env_type == "safety_gymnasium":
- return SafetyGym(name, seed=seed)
- else:
- raise ValueError(f"Unsupported environment: {name}")
diff --git a/src/oprl/environment/__init__.py b/src/oprl/environment/__init__.py
new file mode 100644
index 0000000..3e0fa36
--- /dev/null
+++ b/src/oprl/environment/__init__.py
@@ -0,0 +1,8 @@
+from oprl.environment.protocols import EnvProtocol
+from oprl.environment.dm_control import DMControlEnv
+from oprl.environment.safety_gymnasium import SafetyGym
+from oprl.environment.make_env import make_env
+
+__all__ = ["EnvProtocol", "DMControlEnv", "SafetyGym", "make_env"]
+
+
diff --git a/src/oprl/environment/dm_control.py b/src/oprl/environment/dm_control.py
new file mode 100644
index 0000000..124b6b8
--- /dev/null
+++ b/src/oprl/environment/dm_control.py
@@ -0,0 +1,73 @@
+from collections import OrderedDict
+from typing import Any
+
+import numpy as np
+import numpy.typing as npt
+from dm_control import suite
+
+from oprl.environment.protocols import EnvProtocol
+
+
+class DMControlEnv(EnvProtocol):
+ def __init__(self, env: str, seed: int) -> None:
+ domain, task = env.split("-")
+ self.random_state = np.random.RandomState(seed)
+ self.env = suite.load(domain, task, task_kwargs={"random": self.random_state})
+
+ self._render_width = 200
+ self._render_height = 200
+ self._camera_id = 0
+
+ def reset(self) -> tuple[npt.NDArray, dict[str, Any]]:
+ obs = self._flat_obs(self.env.reset().observation)
+ return obs, {}
+
+ def step(
+ self, action: npt.NDArray
+ ) -> tuple[npt.NDArray, float, bool, bool, dict[str, Any]]:
+ time_step = self.env.step(action)
+ obs = self._flat_obs(time_step.observation)
+
+ terminated = False
+ truncated = self.env._step_count >= self.env._step_limit
+
+ return obs, time_step.reward, terminated, truncated, {}
+
+ def sample_action(self) -> npt.NDArray:
+ spec = self.env.action_spec()
+ action = self.random_state.uniform(spec.minimum, spec.maximum, spec.shape)
+ return action
+
+ @property
+ def observation_space(self) -> npt.NDArray:
+ return np.zeros(
+ sum(int(np.prod(v.shape)) for v in self.env.observation_spec().values())
+ )
+
+ @property
+ def action_space(self) -> npt.NDArray:
+ return np.zeros(self.env.action_spec().shape[0])
+
+ def render(self) -> npt.NDArray: # [1, W, H, C]
+ img = self.env.physics.render(
+ camera_id=self._camera_id,
+ height=self._render_width,
+ width=self._render_width,
+ )
+ img = img.astype(np.uint8)
+ return img
+
+ def _flat_obs(self, obs: OrderedDict) -> npt.NDArray:
+ obs_flatten = []
+ for _, o in obs.items():
+ if len(o.shape) == 0:
+ obs_flatten.append(np.array([o]))
+ elif len(o.shape) == 2 and o.shape[1] > 1:
+ obs_flatten.append(o.flatten())
+ else:
+ obs_flatten.append(o)
+ return np.concatenate(obs_flatten, dtype="float32")
+
+ @property
+ def env_family(self) -> str:
+ return "dm_control"
diff --git a/src/oprl/environment/gymnasium.py b/src/oprl/environment/gymnasium.py
new file mode 100644
index 0000000..60f802a
--- /dev/null
+++ b/src/oprl/environment/gymnasium.py
@@ -0,0 +1,42 @@
+import numpy.typing as npt
+from typing import Any
+
+import gymnasium as gym
+
+from oprl.environment.protocols import EnvProtocol
+
+
+class Gymnasium(EnvProtocol):
+ def __init__(self, env_name: str, seed: int) -> None:
+ self._env = gym.make(env_name)
+ self._seed = seed
+
+ def step(
+ self,
+ action: npt.NDArray,
+ ) -> tuple[npt.NDArray, float, bool, bool, dict[str, Any]]:
+ obs, reward, terminated, truncated, info = self._env.step(action)
+ return obs.astype("float32"), float(reward), terminated, bool(truncated), info
+
+ def reset(self) -> tuple[npt.NDArray, dict[str, Any]]:
+ obs, info = self._env.reset(seed=self._seed)
+ self._env.step(self._env.action_space.sample())
+ return obs.astype("float32"), info
+
+ def sample_action(self) -> npt.NDArray:
+ return self._env.action_space.sample()
+
+ def render(self) -> npt.NDArray:
+ return self._env.render()
+
+ @property
+ def observation_space(self) -> npt.NDArray:
+ return self._env.observation_space
+
+ @property
+ def action_space(self) -> npt.NDArray:
+ return self._env.action_space
+
+ @property
+ def env_family(self) -> str:
+ return "gymnasium"
diff --git a/src/oprl/environment/make_env.py b/src/oprl/environment/make_env.py
new file mode 100644
index 0000000..6f85544
--- /dev/null
+++ b/src/oprl/environment/make_env.py
@@ -0,0 +1,103 @@
+from collections import defaultdict
+
+# Import submodules directly (not the package __init__) to avoid a
+# partial-initialization cycle: oprl.environment.__init__ imports this module.
+from oprl.environment.protocols import EnvProtocol
+from oprl.environment.dm_control import DMControlEnv
+from oprl.environment.safety_gymnasium import SafetyGym
+ENV_MAPPER: defaultdict[str, set[str]] = defaultdict(set)
+ENV_MAPPER["dm_control"] = set(
+ [
+ "acrobot-swingup",
+ "ball_in_cup-catch",
+ "cartpole-balance",
+ "cartpole-swingup",
+ "cheetah-run",
+ "finger-spin",
+ "finger-turn_easy",
+ "finger-turn_hard",
+ "fish-upright",
+ "fish-swim",
+ "hopper-stand",
+ "hopper-hop",
+ "humanoid-stand",
+ "humanoid-walk",
+ "humanoid-run",
+ "pendulum-swingup",
+ "point_mass-easy",
+ "reacher-easy",
+ "reacher-hard",
+ "swimmer-swimmer6",
+ "swimmer-swimmer15",
+ "walker-stand",
+ "walker-walk",
+ "walker-run",
+ ]
+)
+
+
+ENV_MAPPER["gymnasium"] = set(
+ [
+ "Ant-v4",
+ "Hopper-v4",
+ "HalfCheetah-v4",
+ "HumanoidStandup-v4",
+ "Humanoid-v4",
+ "InvertedDoublePendulum-v4",
+ "InvertedPendulum-v4",
+ "Pusher-v4",
+ "Reacher-v4",
+ "Swimmer-v4",
+ "Walker2d-v4",
+ ]
+)
+
+
+ENV_MAPPER["safety_gymnasium"] = set(
+ [
+ "SafetyPointGoal1-v0",
+ "SafetyPointGoal2-v0",
+ "SafetyPointButton1-v0",
+ "SafetyPointButton2-v0",
+ "SafetyPointPush1-v0",
+ "SafetyPointPush2-v0",
+ "SafetyPointCircle1-v0",
+ "SafetyPointCircle2-v0",
+ "SafetyCarGoal1-v0",
+ "SafetyCarGoal2-v0",
+ "SafetyCarButton1-v0",
+ "SafetyCarButton2-v0",
+ "SafetyCarPush1-v0",
+ "SafetyCarPush2-v0",
+ "SafetyCarCircle1-v0",
+ "SafetyCarCircle2-v0",
+ "SafetyAntGoal1-v0",
+ "SafetyAntGoal2-v0",
+ "SafetyAntButton1-v0",
+ "SafetyAntButton2-v0",
+ "SafetyAntPush1-v0",
+ "SafetyAntPush2-v0",
+ "SafetyAntCircle1-v0",
+ "SafetyAntCircle2-v0",
+ "SafetyDoggoGoal1-v0",
+ "SafetyDoggoGoal2-v0",
+ "SafetyDoggoButton1-v0",
+ "SafetyDoggoButton2-v0",
+ "SafetyDoggoPush1-v0",
+ "SafetyDoggoPush2-v0",
+ "SafetyDoggoCircle1-v0",
+ "SafetyDoggoCircle2-v0",
+ ]
+)
+
+
+def make_env(name: str, seed: int) -> EnvProtocol:
+ for env_type, env_set in ENV_MAPPER.items():
+ if name in env_set:
+ if env_type == "dm_control":
+ return DMControlEnv(name, seed=seed)
+ elif env_type == "safety_gymnasium":
+ return SafetyGym(name, seed=seed)
+ elif env_type == "gymnasium":
+ return Gymnasium(name, seed=seed)
+ raise ValueError(f"Unsupported environment: {name}")
diff --git a/src/oprl/environment/protocols.py b/src/oprl/environment/protocols.py
new file mode 100644
index 0000000..d086c87
--- /dev/null
+++ b/src/oprl/environment/protocols.py
@@ -0,0 +1,35 @@
+from typing import Protocol, Any
+
+import numpy.typing as npt
+
+
+class EnvProtocol(Protocol):
+ def __init__(self, env_name: str, seed: int) -> None:
+ ...
+
+ def step(
+ self, action: npt.NDArray
+ ) -> tuple[npt.NDArray, float, bool, bool, dict[str, Any]]:
+ ...
+
+ def reset(self) -> tuple[npt.NDArray, dict[str, Any]]:
+ ...
+
+ def sample_action(self) -> npt.NDArray:
+ ...
+
+ def render(self) -> npt.NDArray:
+ ...
+
+ @property
+ def observation_space(self) -> npt.NDArray:
+ ...
+
+ @property
+ def action_space(self) -> npt.NDArray:
+ ...
+
+ @property
+ def env_family(self) -> str:
+ ...
+
diff --git a/src/oprl/environment/safety_gymnasium.py b/src/oprl/environment/safety_gymnasium.py
new file mode 100644
index 0000000..910ecab
--- /dev/null
+++ b/src/oprl/environment/safety_gymnasium.py
@@ -0,0 +1,42 @@
+import numpy.typing as npt
+from typing import Any
+
+from oprl.environment.protocols import EnvProtocol
+
+
+class SafetyGym(EnvProtocol):
+ def __init__(self, env_name: str, seed: int) -> None:
+ import safety_gymnasium as gym
+ self._env = gym.make(env_name, render_mode='rgb_array', camera_name="fixednear")
+ self._seed = seed
+
+ def step(
+ self, action: npt.NDArray
+ ) -> tuple[npt.NDArray, float, bool, bool, dict[str, Any]]:
+ obs, reward, cost, terminated, truncated, info = self._env.step(action)
+ info["cost"] = cost
+ return obs.astype("float32"), float(reward), terminated, bool(truncated), info
+
+ def reset(self) -> tuple[npt.NDArray, dict[str, Any]]:
+ obs, info = self._env.reset(seed=self._seed)
+ self._env.step(self._env.action_space.sample())
+ return obs.astype("float32"), info
+
+ def sample_action(self) -> npt.NDArray:
+ return self._env.action_space.sample()
+
+ def render(self) -> npt.NDArray:
+ return self._env.render()
+
+ @property
+ def observation_space(self) -> npt.NDArray:
+ return self._env.observation_space
+
+ @property
+ def action_space(self) -> npt.NDArray:
+ return self._env.action_space
+
+ @property
+ def env_family(self) -> str:
+ return "safety_gymnasium"
+
diff --git a/src/oprl/logging.py b/src/oprl/logging.py
new file mode 100644
index 0000000..d34d6e9
--- /dev/null
+++ b/src/oprl/logging.py
@@ -0,0 +1,102 @@
+import os
+from pathlib import Path
+import sys
+import logging
+from datetime import datetime
+import shutil
+from abc import ABC, abstractmethod
+from typing import Protocol, Callable, runtime_checkable
+
+from torch.utils.tensorboard.writer import SummaryWriter
+
+
+@runtime_checkable
+class LoggerProtocol(Protocol):
+ log_dir: Path
+
+ def log_scalar(self, tag: str, value: float, step: int) -> None: ...
+
+ def log_scalars(self, values: dict[str, float], step: int) -> None: ...
+
+
+
+def get_logs_path(logdir: str, algo: str, env: str, seed: int) -> Path:
+ dt = datetime.now().strftime("%Y_%m_%d_%Hh%Mm%Ss")
+ log_dir = Path(logdir) / algo / f"{algo}-env_{env}-seed_{seed}-{dt}"
+ logging.info(f"LOGDIR: {log_dir}")
+ return log_dir
+
+
+def create_stdout_logger(name: str | None = None) -> logging.Logger:
+    """Return an INFO-level logger, named after the caller's module by default."""
+    if name is None:
+        import inspect
+        frame = inspect.currentframe()
+        # currentframe() / f_back may be None on some interpreters — guard it.
+        if frame is not None and frame.f_back is not None:
+            name = os.path.splitext(os.path.basename(frame.f_back.f_code.co_filename))[0]
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.INFO)
+    return logger
+
+
+def copy_exp_dir(log_dir: Path) -> None:
+ cur_dir = Path(__file__).parents[0]
+ dest_dir = log_dir / "src"
+ shutil.copytree(cur_dir, dest_dir)
+ logging.info(f"Source copied into {dest_dir}")
+
+
+def make_text_logger_func(algo: str, env: str) -> Callable[[int], LoggerProtocol]:
+ def make_logger(seed: int) -> LoggerProtocol:
+ logs_root = os.environ.get("OPRL_LOGS", "logs")
+ log_dir = get_logs_path(logdir=logs_root, algo=algo, env=env, seed=seed)
+ logger = FileTxtLogger(log_dir)
+ logger.copy_source_code()
+ return logger
+ return make_logger
+
+
+class BaseLogger(ABC):
+ @abstractmethod
+ def log_scalar(self, tag: str, value: float, step: int) -> None:
+ ...
+
+    def log_scalars(self, values: dict[str, float], step: int) -> None:
+        """Log every tag -> value pair at the given step.
+
+        The previous generator expression was never consumed and logged nothing.
+        """
+        for tag, value in values.items():
+            self.log_scalar(tag, value, step)
+
+
+logger = create_stdout_logger(__name__)
+
+
+class FileTxtLogger(BaseLogger):
+ def __init__(self, logdir: Path | str) -> None:
+ self.writer = SummaryWriter(logdir)
+ self.log_dir = Path(logdir)
+
+ def copy_source_code(self) -> None:
+ copy_exp_dir(self.log_dir)
+ logger.info(f"Source code is copied to {self.log_dir}")
+ self._copy_config_file()
+
+ def _copy_config_file(self) -> None:
+ main_module = sys.modules.get('__main__')
+ if main_module and hasattr(main_module, '__file__'):
+ shutil.copyfile(main_module.__file__, self.log_dir / Path(main_module.__file__).name)
+ else:
+ logger.warning("Failed to copy config file.")
+
+ def log_scalar(self, tag: str, value: float, step: int) -> None:
+ self.writer.add_scalar(tag, value, step)
+ self._log_scalar_to_file(tag, value, step)
+
+    def _log_scalar_to_file(self, tag: str, val: float, step: int) -> None:
+        log_path = self.log_dir / f"{tag}.log"
+        # Tags like "trainer/ep_reward" nest a directory: create the full chain.
+        log_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(log_path, "a") as f:
+            f.write(f"{step} {val}\n")
+
diff --git a/src/oprl/configs/utils.py b/src/oprl/parse_args.py
similarity index 61%
rename from src/oprl/configs/utils.py
rename to src/oprl/parse_args.py
index 50c16e7..66bb580 100644
--- a/src/oprl/configs/utils.py
+++ b/src/oprl/parse_args.py
@@ -1,7 +1,4 @@
import argparse
-import logging
-import os
-from datetime import datetime
def parse_args() -> argparse.Namespace:
@@ -28,8 +25,14 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args()
-def create_logdir(logdir: str, algo: str, env: str, seed: int) -> str:
- dt = datetime.now().strftime("%Y_%m_%d_%Hh%Mm%Ss")
- log_dir = os.path.join(logdir, algo, f"{algo}-env_{env}-seed_{seed}-{dt}")
- logging.info(f"LOGDIR: {log_dir}")
- return log_dir
+def parse_args_distrib():
+ parser = argparse.ArgumentParser(description="Run distrib training")
+ parser.add_argument("--config", type=str, help="Path to the config file.")
+ parser.add_argument(
+ "--env", type=str, default="cartpole-balance", help="Name of the environment."
+ )
+ parser.add_argument(
+ "--device", type=str, default="cpu", help="Device to perform training on."
+ )
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ return parser.parse_args()
diff --git a/src/oprl/runners/__init__.py b/src/oprl/runners/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/oprl/runners/config.py b/src/oprl/runners/config.py
new file mode 100644
index 0000000..b84b6d5
--- /dev/null
+++ b/src/oprl/runners/config.py
@@ -0,0 +1,22 @@
+from pydantic_settings import BaseSettings
+
+
+class CommonParameters(BaseSettings):
+ state_dim: int
+ action_dim: int
+ num_steps: int
+ eval_every: int = 2500
+ estimate_q_every: int = 5000
+ log_every: int = 2500
+ device: str = "cpu"
+
+
+class DistribConfig(BaseSettings):
+ batch_size: int = 128
+ num_env_workers: int = 4
+ episodes_per_worker: int = 100
+ warmup_epochs: int = 16
+ episode_length: int = 1000
+ learner_num_waits: int = 10
+ warmup_env_steps: int = 1000
+
diff --git a/src/oprl/runners/train.py b/src/oprl/runners/train.py
new file mode 100644
index 0000000..c128890
--- /dev/null
+++ b/src/oprl/runners/train.py
@@ -0,0 +1,86 @@
+from typing import Callable
+import logging
+import random
+from multiprocessing import Process
+
+import numpy as np
+import torch as t
+
+from oprl.algos.protocols import AlgorithmProtocol
+from oprl.buffers.protocols import ReplayBufferProtocol
+from oprl.environment.protocols import EnvProtocol
+from oprl.trainers.base_trainer import BaseTrainer
+from oprl.trainers.safe_trainer import SafeTrainer
+from oprl.logging import LoggerProtocol
+from oprl.runners.config import CommonParameters
+
+
+def set_seed(seed: int) -> None:
+ random.seed(seed)
+ np.random.seed(seed)
+ t.manual_seed(seed)
+
+
+def run_training(
+ make_algo: Callable[[LoggerProtocol], AlgorithmProtocol],
+ make_env: Callable[[int], EnvProtocol],
+ make_replay_buffer: Callable[[], ReplayBufferProtocol],
+ make_logger: Callable[[int], LoggerProtocol],
+ config: CommonParameters,
+ seeds: int = 1,
+ start_seed: int = 0
+) -> None:
+ if seeds == 1:
+ _run_training_func(make_algo, make_env, make_replay_buffer, make_logger, config, 0)
+ else:
+ processes = []
+ for seed in range(start_seed, start_seed + seeds):
+ processes.append(
+ Process(
+ target=_run_training_func,
+ args=(make_algo, make_env, make_replay_buffer, make_logger, config, seed),
+ )
+ )
+
+ for i, p in enumerate(processes):
+ p.start()
+ logging.info(f"Starting process {i}...")
+ for p in processes:
+ p.join()
+ logging.info("Training finished.")
+
+
+def _run_training_func(
+ make_algo: Callable[[LoggerProtocol], AlgorithmProtocol],
+ make_env: Callable[[int], EnvProtocol],
+ make_replay_buffer: Callable[[], ReplayBufferProtocol],
+ make_logger: Callable[[int], LoggerProtocol],
+ config: CommonParameters,
+ seed: int,
+) -> None:
+ set_seed(seed)
+ env = make_env(seed)
+ replay_buffer = make_replay_buffer()
+ logger = make_logger(seed)
+ algo = make_algo(logger)
+
+ base_trainer = BaseTrainer(
+ env=env,
+ make_env_test=make_env,
+ algo=algo,
+ replay_buffer=replay_buffer,
+ num_steps=config.num_steps,
+ eval_interval=config.eval_every,
+ device=config.device,
+ estimate_q_every=config.estimate_q_every,
+ stdout_log_every=config.log_every,
+ seed=seed,
+ logger=logger,
+ )
+ if env.env_family in ["dm_control", "gymnasium"]:
+ trainer = base_trainer
+ elif env.env_family == "safety_gymnasium":
+ trainer = SafeTrainer(trainer=base_trainer)
+ else:
+ raise ValueError(f"Unsupported env family: {env.env_family}")
+ trainer.train()
diff --git a/src/oprl/runners/train_distrib.py b/src/oprl/runners/train_distrib.py
new file mode 100644
index 0000000..cf639cd
--- /dev/null
+++ b/src/oprl/runners/train_distrib.py
@@ -0,0 +1,40 @@
+from typing import Callable
+from multiprocessing import Process
+
+from oprl.algos.protocols import AlgorithmProtocol, PolicyProtocol
+from oprl.buffers.protocols import ReplayBufferProtocol
+from oprl.environment.protocols import EnvProtocol
+from oprl.logging import LoggerProtocol, create_stdout_logger
+from oprl.runners.config import DistribConfig
+
+
+logger = create_stdout_logger()
+
+
+def run_distrib_training(
+ run_env_worker: Callable,
+ run_policy_update_worker: Callable,
+ make_env: Callable[[int], EnvProtocol],
+ make_algo: Callable[[LoggerProtocol], AlgorithmProtocol],
+ make_policy: Callable[[], PolicyProtocol],
+ make_replay_buffer: Callable[[], ReplayBufferProtocol],
+ make_logger: Callable[[], LoggerProtocol],
+ config: DistribConfig
+) -> None:
+ processes = []
+ for i_env in range(config.num_env_workers):
+ processes.append(
+ Process(target=run_env_worker, args=(make_env, make_policy, config, i_env))
+ )
+ processes.append(
+ Process(
+ target=run_policy_update_worker,
+ args=(make_algo, make_env, make_replay_buffer, make_logger, config),
+ )
+ )
+
+ for p in processes:
+ p.start()
+ for p in processes:
+ p.join()
+ logger.info("Training Finished.")
diff --git a/src/oprl/trainers/__init__.py b/src/oprl/trainers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/oprl/trainers/base_trainer.py b/src/oprl/trainers/base_trainer.py
index c9fccb2..7c15529 100644
--- a/src/oprl/trainers/base_trainer.py
+++ b/src/oprl/trainers/base_trainer.py
@@ -1,243 +1,172 @@
-import os
-from typing import Any, Callable
+from dataclasses import dataclass
+from typing import Callable
import numpy as np
-import torch
+import torch as t
+
+from oprl.algos.protocols import AlgorithmProtocol
+from oprl.environment import EnvProtocol
+from oprl.buffers.protocols import ReplayBufferProtocol
+from oprl.logging import LoggerProtocol, create_stdout_logger
+
+from oprl.trainers.protocols import TrainerProtocol
+
+
+logger = create_stdout_logger()
+
+
+@dataclass
+class BaseTrainer(TrainerProtocol):
+ logger: LoggerProtocol
+ env: EnvProtocol
+ make_env_test: Callable[[int], EnvProtocol]
+ replay_buffer: ReplayBufferProtocol
+ algo: AlgorithmProtocol
+ gamma: float = 0.99
+ num_steps: int = int(1e6)
+ start_steps: int = int(10e3)
+ batch_size: int = 128
+ eval_interval: int = int(2e3)
+ num_eval_episodes: int = 10
+ save_buffer_every: int = 0
+ save_policy_every: int = int(100_000)
+ estimate_q_every: int = 0
+ stdout_log_every: int = int(1e5)
+ device: str = "cpu"
+ seed: int = 0
+
+ def train(self) -> None:
+ self.algo.check_created()
+ self.replay_buffer.check_created()
-from oprl.env import BaseEnv
-from oprl.trainers.buffers.episodic_buffer import EpisodicReplayBuffer
-from oprl.utils.logger import Logger, StdLogger
-
-
-class BaseTrainer:
- def __init__(
- self,
- state_dim: int,
- action_dim: int,
- env: BaseEnv,
- make_env_test: Callable[[int], BaseEnv],
- algo: Any | None = None,
- buffer_size: int = int(1e6),
- gamma: float = 0.99,
- num_steps: int = int(1e6),
- start_steps: int = int(10e3),
- batch_size: int = 128,
- eval_interval: int = int(2e3),
- num_eval_episodes: int = 10,
- save_buffer_every: int = 0,
- save_policy_every: int = int(50_000),
- visualise_every: int = 0,
- estimate_q_every: int = 0,
- stdout_log_every: int = int(1e5),
- device: str = "cpu",
- seed: int = 0,
- logger: Logger = StdLogger(),
- ):
- """
- Args:
- state_dim: Dimension of the observation.
- action_dim: Dimension of the action.
- env: Enviornment object.
- make_env_test: Environment object for evaluation.
- algo: Codename for the algo (SAC).
- buffer_size: Buffer size in transitions.
- gamma: Discount factor.
- num_step: Number of env steps to train.
- start_steps: Number of environment steps not to perform training at the beginning.
- batch_size: Batch-size.
- eval_interval: Number of env step after which perform evaluation.
- save_buffer_every: Number of env steps after which save replay buffer.
- visualise_every: Number of env steps after which perform vizualisation.
- device: Name of the device.
- stdout_log_every: Number of evn steps after which log info to stdout.
- seed: Random seed.
- logger: Logger instance.
- """
- self._env = env
- self._make_env_test = make_env_test
- self._algo = algo
- self._gamma = gamma
- self._device = device
- self._save_buffer_every = save_buffer_every
- self._visualize_every = visualise_every
- self._estimate_q_every = estimate_q_every
- self._stdout_log_every = stdout_log_every
- self._save_policy_every= save_policy_every
- self._logger = logger
- self.seed = seed
-
- self.buffer = EpisodicReplayBuffer(
- buffer_size=buffer_size,
- state_dim=state_dim,
- action_dim=action_dim,
- device=device,
- gamma=gamma,
- )
-
- self.batch_size = batch_size
- self.num_steps = num_steps
- self.start_steps = start_steps
- self.eval_interval = eval_interval
- self.num_eval_episodes = num_eval_episodes
-
- def train(self):
ep_step = 0
- state, _ = self._env.reset()
-
+ state, _ = self.env.reset()
for env_step in range(self.num_steps + 1):
ep_step += 1
if env_step <= self.start_steps:
- action = self._env.sample_action()
+ action = self.env.sample_action()
else:
- action = self._algo.explore(state)
- next_state, reward, terminated, truncated, _ = self._env.step(action)
+ action = self.algo.actor.explore(state)
+ next_state, reward, terminated, truncated, _ = self.env.step(action)
- self.buffer.append(
+ self.replay_buffer.add_transition(
state, action, reward, terminated, episode_done=terminated or truncated
)
if terminated or truncated:
- next_state, _ = self._env.reset()
+ next_state, _ = self.env.reset()
ep_step = 0
state = next_state
- if len(self.buffer) < self.batch_size:
+ if len(self.replay_buffer) < self.batch_size:
continue
- batch = self.buffer.sample(self.batch_size)
- self._algo.update(*batch)
+ (
+ states,
+ actions,
+ rewards,
+ dones,
+ next_states
+ ) = self.replay_buffer.sample(self.batch_size)
+ self.algo.update(states, actions, rewards, dones, next_states)
- self._eval_routine(env_step, batch)
- self._visualize(env_step)
- self._save_buffer(env_step)
+ self._log_evaluation(env_step, rewards)
self._save_policy(env_step)
- self._log_stdout(env_step, batch)
+ self._log_stdout(env_step, rewards)
- def _eval_routine(self, env_step: int, batch):
+ def _log_evaluation(self, env_step: int, rewards: t.Tensor) -> None:
if env_step % self.eval_interval == 0:
- self._log_evaluation(env_step)
-
- self._logger.log_scalar("trainer/avg_reward", batch[2].mean(), env_step)
- self._logger.log_scalar(
- "trainer/buffer_transitions", len(self.buffer), env_step
+ eval_metrics = self.evaluate()
+ self.logger.log_scalar("trainer/ep_reward", eval_metrics["return"], env_step)
+ self.logger.log_scalar("trainer/avg_reward", rewards.mean().item(), env_step)
+ self.logger.log_scalar(
+ "trainer/buffer_transitions", len(self.replay_buffer), env_step
)
- self._logger.log_scalar(
- "trainer/buffer_episodes", self.buffer.num_episodes, env_step
+ self.logger.log_scalar(
+ "trainer/buffer_episodes", self.replay_buffer.episodes_counter, env_step
)
- self._logger.log_scalar(
+ self.logger.log_scalar(
"trainer/buffer_last_ep_len",
- self.buffer.get_last_ep_len(),
+ self.replay_buffer.last_episode_length,
env_step,
)
- def _log_evaluation(self, env_step: int):
+ def evaluate(self) -> dict[str, float]:
returns = []
for i_ep in range(self.num_eval_episodes):
- env_test = self._make_env_test(seed=self.seed + i_ep)
+ env_test = self.make_env_test(self.seed + i_ep)
state, _ = env_test.reset()
episode_return = 0.0
terminated, truncated = False, False
while not (terminated or truncated):
- action = self._algo.exploit(state)
+ action = self.algo.actor.exploit(state)
state, reward, terminated, truncated, _ = env_test.step(action)
episode_return += reward
returns.append(episode_return)
- mean_return = np.mean(returns)
- self._logger.log_scalar("trainer/ep_reward", mean_return, env_step)
-
- def _visualize(self, env_step: int):
- if self._visualize_every > 0 and env_step % self._visualize_every == 0:
- imgs = self.visualise_policy() # [T, W, H, C]
- if imgs is not None:
- self._logger.log_video("eval_policy", imgs, env_step)
-
- def _save_buffer(self, env_step: int):
- # TODO: doesn't work
- if self._save_buffer_every > 0 and env_step % self._save_buffer_every == 0:
- self.buffer.save(f"{self.log_dir}/buffers/buffer_step_{env_step}.pickle")
-
- def _save_policy(self, env_step: int):
- if self._save_policy_every > 0 and env_step % self._save_policy_every == 0:
- self._logger.save_weights(self._algo.actor, env_step)
+ return {
+ "return": float(np.mean(returns))
+ }
+
+ def _save_policy(self, env_step: int) -> None:
+ if self.save_policy_every > 0 and env_step % self.save_policy_every == 0:
+ weights_path = self.logger.log_dir / "weights" / f"{env_step}.w"
+ weights_path.parents[0].mkdir(exist_ok=True)
+ t.save(
+ self.algo.actor,
+ weights_path
+ )
- def _estimate_q(self, env_step: int):
- if self._estimate_q_every > 0 and env_step % self._estimate_q_every == 0:
+ def _estimate_q(self, env_step: int) -> None:
+ if self.estimate_q_every > 0 and env_step % self.estimate_q_every == 0:
q_true = self.estimate_true_q()
q_critic = self.estimate_critic_q()
if q_true is not None:
- self._logger.log_scalar("trainer/Q-estimate", q_true, env_step)
- self._logger.log_scalar("trainer/Q-critic", q_critic, env_step)
- self._logger.log_scalar(
+ self.logger.log_scalar("trainer/Q-estimate", q_true, env_step)
+ self.logger.log_scalar("trainer/Q-critic", q_critic, env_step)
+ self.logger.log_scalar(
"trainer/Q_asb_diff", q_critic - q_true, env_step
)
- def _log_stdout(self, env_step: int, batch):
- if env_step % self._stdout_log_every == 0:
+ def _log_stdout(self, env_step: int, rewards: t.Tensor) -> None:
+ if env_step % self.stdout_log_every == 0:
perc = int(env_step / self.num_steps * 100)
- print(
- f"Env step {env_step:8d} ({perc:2d}%) Avg Reward {batch[2].mean():10.3f}"
+ logger.info(
+ f"Env step {env_step:8d} ({perc:2d}%) Avg Reward {rewards.mean():10.3f}"
)
- def visualise_policy(self):
- """
- returned shape: [N, C, W, H]
- """
- env = self._make_env_test(seed=self.seed)
- try:
- imgs = []
+ def estimate_true_q(self, eval_episodes: int = 10) -> float:
+ qs = []
+ for i_eval in range(eval_episodes):
+ env = self.make_env_test(self.seed * 100 + i_eval)
state, _ = env.reset()
- done = False
- while not done:
- img = env.render()
- imgs.append(img)
- action = self._algo.exploit(state)
- state, _, terminated, truncated, _ = env.step(action)
- done = terminated or truncated
- return np.concatenate(imgs, dtype="uint8")
- except Exception as e:
- print(f"Failed to visualise a policy: {e}")
- return None
-
- def estimate_true_q(self, eval_episodes: int = 10) -> float | None:
- try:
- qs = []
- for i_eval in range(eval_episodes):
- env = self._make_env_test(seed=self.seed * 100 + i_eval)
- print("Before reset etimate q")
- state, _ = env.reset()
-
- q = 0
- s_i = 1
- while True:
- action = self._algo.exploit(state)
- state, r, terminated, truncated, _ = env.step(action)
- q += r * self._gamma**s_i
- s_i += 1
- if terminated or truncated:
- break
- qs.append(q)
+ q = 0
+ s_i = 1
+ while True:
+ action = self.algo.actor.exploit(state)
+ state, r, terminated, truncated, _ = env.step(action)
+ q += r * self.gamma ** s_i
+ s_i += 1
+ if terminated or truncated:
+ break
+ qs.append(q)
- return np.mean(qs, dtype=float)
- except Exception as e:
- print(f"Failed to estimate Q-value: {e}")
- return None
+ return np.mean(qs, dtype=float)
def estimate_critic_q(self, num_episodes: int = 10) -> float:
qs = []
for i_eval in range(num_episodes):
- env = self._make_env_test(seed=self.seed * 100 + i_eval)
-
+ env = self.make_env_test(self.seed * 100 + i_eval)
state, _ = env.reset()
- action = self._algo.exploit(state)
+ action = self.algo.actor.exploit(state)
- state = torch.tensor(state).unsqueeze(0).float().to(self._device)
- action = torch.tensor(action).unsqueeze(0).float().to(self._device)
+ state = t.tensor(state).unsqueeze(0).float().to(self.device)
+ action = t.tensor(action).unsqueeze(0).float().to(self.device)
- q = self._algo.critic(state, action)
+ q = self.algo.critic(state, action)
# TODO: TQC is not supported by this logic, need to update
if isinstance(q, tuple):
q = q[0]
@@ -245,27 +174,3 @@ def estimate_critic_q(self, num_episodes: int = 10) -> float:
qs.append(q)
return np.mean(qs, dtype=float)
-
-
-def run_training(make_algo, make_env, make_logger, config: dict[str, Any], seed: int):
- env = make_env(seed=seed)
- logger = make_logger(seed)
-
- trainer = BaseTrainer(
- state_dim=config["state_shape"],
- action_dim=config["action_shape"],
- env=env,
- make_env_test=make_env,
- algo=make_algo(logger, seed),
- num_steps=config["num_steps"],
- eval_interval=config["eval_every"],
- device=config["device"],
- save_buffer_every=config["save_buffer"],
- visualise_every=config["visualise_every"],
- estimate_q_every=config["estimate_q_every"],
- stdout_log_every=config["log_every"],
- seed=seed,
- logger=logger,
- )
-
- trainer.train()
diff --git a/src/oprl/trainers/buffers/episodic_buffer.py b/src/oprl/trainers/buffers/episodic_buffer.py
deleted file mode 100644
index 8f7b7cf..0000000
--- a/src/oprl/trainers/buffers/episodic_buffer.py
+++ /dev/null
@@ -1,160 +0,0 @@
-import os
-import pickle
-
-import numpy as np
-import torch
-
-
-class EpisodicReplayBuffer:
- def __init__(
- self,
- buffer_size: int,
- state_dim: int,
- action_dim: int,
- device: str,
- gamma: float,
- max_episode_len: int = 1000,
- dtype=torch.float,
- ):
- """
- Args:
- buffer_size: Max number of transitions in buffer.
- state_dim: Dimension of the state.
- action_dim: Dimension of the action.
- device: Device to place buffer.
- gamma: Discount factor for N-step.
- max_episode_len: Max length of the episode to store.
- dtype: Data type.
- """
- self.buffer_size = buffer_size
- self.max_episodes = buffer_size // max_episode_len
- self.max_episode_len = max_episode_len
- self.state_dim = state_dim
- self.action_dim = action_dim
- self.device = device
- self.gamma = gamma
-
- self.ep_pointer = 0
- self.cur_episodes = 1
- self.cur_size = 0
-
- self.actions = torch.empty(
- (self.max_episodes, max_episode_len, action_dim),
- dtype=dtype,
- device=device,
- )
- self.rewards = torch.empty(
- (self.max_episodes, max_episode_len, 1), dtype=dtype, device=device
- )
- self.dones = torch.empty(
- (self.max_episodes, max_episode_len, 1), dtype=dtype, device=device
- )
- self.states = torch.empty(
- (self.max_episodes, max_episode_len + 1, state_dim),
- dtype=dtype,
- device=device,
- )
- self.ep_lens = [0] * self.max_episodes
-
- self.actions_for_std = torch.empty(
- (100, action_dim), dtype=dtype, device=device
- )
- self.actions_for_std_cnt = 0
-
- # TODO: rename to add
- def append(self, state, action, reward, done, episode_done=None):
- """
- Args:
- state: state.
- action: action.
- reward: reward.
- done: done only if episode ends naturally.
- episode_done: done that can be set to True if time limit is reached.
- """
- self.states[self.ep_pointer, self.ep_lens[self.ep_pointer]].copy_(
- torch.from_numpy(state)
- )
- self.actions[self.ep_pointer, self.ep_lens[self.ep_pointer]].copy_(
- torch.from_numpy(action)
- )
- self.rewards[self.ep_pointer, self.ep_lens[self.ep_pointer]] = float(reward)
- self.dones[self.ep_pointer, self.ep_lens[self.ep_pointer]] = float(done)
-
- self.actions_for_std[self.actions_for_std_cnt % 100].copy_(
- torch.from_numpy(action)
- )
- self.actions_for_std_cnt += 1
-
- self.ep_lens[self.ep_pointer] += 1
- self.cur_size = min(self.cur_size + 1, self.buffer_size)
- if episode_done:
- self._inc_episode()
-
- def _inc_episode(self):
- self.ep_pointer = (self.ep_pointer + 1) % self.max_episodes
- self.cur_episodes = min(self.cur_episodes + 1, self.max_episodes)
- self.cur_size -= self.ep_lens[self.ep_pointer]
- self.ep_lens[self.ep_pointer] = 0
-
- def add_episode(self, episode):
- for s, a, r, d, s_ in episode:
- self.append(s, a, r, d, episode_done=d)
- if d:
- break
- else:
- self._inc_episode()
-
- def _inds_to_episodic(self, inds):
- start_inds = np.cumsum([0] + self.ep_lens[: self.cur_episodes - 1])
- end_inds = start_inds + np.array(self.ep_lens[: self.cur_episodes])
- ep_inds = np.argmin(
- inds.reshape(-1, 1) >= np.tile(end_inds, (len(inds), 1)), axis=1
- )
- step_inds = inds - start_inds[ep_inds]
-
- return ep_inds, step_inds
-
- def sample(self, batch_size):
- inds = np.random.randint(low=0, high=self.cur_size, size=batch_size)
- ep_inds, step_inds = self._inds_to_episodic(inds)
-
- return (
- self.states[ep_inds, step_inds],
- self.actions[ep_inds, step_inds],
- self.rewards[ep_inds, step_inds],
- self.dones[ep_inds, step_inds],
- self.states[ep_inds, step_inds + 1],
- )
-
- def save(self, path: str):
- """
- Args:
- path: Path to pickle file.
- """
- dirname = os.path.dirname(path)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
-
- data = {
- "states": self.states.cpu(),
- "actions": self.actions.cpu(),
- "rewards": self.rewards.cpu(),
- "dones": self.dones.cpu(),
- "ep_lens": self.ep_lens,
- }
- try:
- with open(path, "wb") as f:
- pickle.dump(data, f)
- print(f"Replay buffer saved to {path}")
- except Exception as e:
- print(f"Failed to save replay buffer: {e}")
-
- def __len__(self):
- return self.cur_size
-
- @property
- def num_episodes(self):
- return self.cur_episodes
-
- def get_last_ep_len(self):
- return self.ep_lens[self.ep_pointer]
diff --git a/src/oprl/trainers/protocols.py b/src/oprl/trainers/protocols.py
new file mode 100644
index 0000000..6b3e986
--- /dev/null
+++ b/src/oprl/trainers/protocols.py
@@ -0,0 +1,7 @@
+from typing import Protocol
+
+class TrainerProtocol(Protocol):
+ def train(self) -> None: ...
+
+ def evaluate(self) -> dict[str, float]: ...
+
diff --git a/src/oprl/trainers/safe_trainer.py b/src/oprl/trainers/safe_trainer.py
index a5edf9f..dc4d4f2 100644
--- a/src/oprl/trainers/safe_trainer.py
+++ b/src/oprl/trainers/safe_trainer.py
@@ -1,119 +1,84 @@
-from typing import Any, Callable
+from dataclasses import dataclass
+import torch as t
import numpy as np
-from oprl.env import BaseEnv
-from oprl.trainers.base_trainer import BaseTrainer
-from oprl.utils.logger import Logger, StdLogger
-
-
-class SafeTrainer(BaseTrainer):
- def __init__(
- self,
- state_dim: int,
- action_dim: int,
- env: BaseEnv,
- make_env_test: Callable[[int], BaseEnv],
- algo: Any | None = None,
- buffer_size: int = int(1e6),
- gamma: float = 0.99,
- num_steps=int(1e6),
- start_steps: int = int(10e3),
- batch_size: int = 128,
- eval_interval: int = int(2e3),
- num_eval_episodes: int = 10,
- save_buffer_every: int = 0,
- save_policy_every: int = int(50_000),
- visualise_every: int = 0,
- estimate_q_every: int = 0,
- stdout_log_every: int = int(1e5),
- device: str = "cpu",
- seed: int = 0,
- logger: Logger = StdLogger(),
- ):
- """
- Args:
- state_dim: Dimension of the observation.
- action_dim: Dimension of the action.
- env: Enviornment object.
- make_env_test: Environment object for evaluation.
- algo: Codename for the algo (SAC).
- buffer_size: Buffer size in transitions.
- gamma: Discount factor.
- num_step: Number of env steps to train.
- start_steps: Number of environment steps not to perform training at the beginning.
- batch_size: Batch-size.
- eval_interval: Number of env step after which perform evaluation.
- save_buffer_every: Number of env steps after which save replay buffer.
- visualise_every: Number of env steps after which perform vizualisation.
- stdout_log_every: Number of evn steps after which log info to stdout.
- device: Name of the device.
- seed: Random seed.
- logger: Logger instance.
- """
- super().__init__(
- state_dim=state_dim,
- action_dim=action_dim,
- env=env,
- make_env_test=make_env_test,
- algo=algo,
- buffer_size=buffer_size,
- gamma=gamma,
- device=device,
- num_steps=num_steps,
- start_steps=start_steps,
- batch_size=batch_size,
- eval_interval=eval_interval,
- num_eval_episodes=num_eval_episodes,
- save_buffer_every=save_buffer_every,
- save_policy_every=save_policy_every,
- visualise_every=visualise_every,
- estimate_q_every=estimate_q_every,
- stdout_log_every=stdout_log_every,
- seed=seed,
- logger=logger,
- )
+from oprl.trainers.base_trainer import BaseTrainer, TrainerProtocol
+
+
+@dataclass
+class SafeTrainer(TrainerProtocol):
+ trainer: BaseTrainer
def train(self):
+ self.trainer.algo.check_created()
+ self.trainer.replay_buffer.check_created()
+
ep_step = 0
- state, _ = self._env.reset()
+ state, _ = self.trainer.env.reset()
total_cost = 0
-
- for env_step in range(self.num_steps + 1):
+ for env_step in range(self.trainer.num_steps + 1):
ep_step += 1
- if env_step <= self.start_steps:
- action = self._env.sample_action()
+ if env_step <= self.trainer.start_steps:
+ action = self.trainer.env.sample_action()
else:
- action = self._algo.explore(state)
- next_state, reward, terminated, truncated, info = self._env.step(action)
+ action = self.trainer.algo.actor.explore(state)
+ next_state, reward, terminated, truncated, info = self.trainer.env.step(action)
total_cost += info["cost"]
- self.buffer.append(
+ self.trainer.replay_buffer.add_transition(
state, action, reward, terminated, episode_done=terminated or truncated
)
if terminated or truncated:
- next_state, _ = self._env.reset()
+ next_state, _ = self.trainer.env.reset()
ep_step = 0
state = next_state
- if len(self.buffer) < self.batch_size:
+ if len(self.trainer.replay_buffer) < self.trainer.batch_size:
continue
- batch = self.buffer.sample(self.batch_size)
- self._algo.update(*batch)
+ (
+ states,
+ actions,
+ rewards,
+ dones,
+ next_states
+ ) = self.trainer.replay_buffer.sample(self.trainer.batch_size)
+ self.trainer.algo.update(states, actions, rewards, dones, next_states)
+
+ self._log_evaluation(env_step, rewards)
+ self.trainer._save_policy(env_step)
+ self.trainer._log_stdout(env_step, rewards)
+
+ self.trainer.logger.log_scalar("trainer/total_cost", total_cost, self.trainer.num_steps)
- self._eval_routine(env_step, batch)
- self._visualize(env_step)
- self._save_policy(env_step)
- self._save_buffer(env_step)
- self._log_stdout(env_step, batch)
+ def _log_evaluation(self, env_step: int, rewards: t.Tensor) -> None:
+ if env_step % self.trainer.eval_interval == 0:
+ eval_metrics = self.evaluate()
+ self.trainer.logger.log_scalar(
+ "trainer/ep_reward", eval_metrics["return"], env_step
+ )
+ self.trainer.logger.log_scalar(
+ "trainer/ep_cost", eval_metrics["cost"], env_step
+ )
- self._logger.log_scalar("trainer/total_cost", total_cost, self.num_steps)
+ self.trainer.logger.log_scalar("trainer/avg_reward", rewards.mean().item(), env_step)
+ self.trainer.logger.log_scalar(
+ "trainer/buffer_transitions", len(self.trainer.replay_buffer), env_step
+ )
+ self.trainer.logger.log_scalar(
+ "trainer/buffer_episodes", self.trainer.replay_buffer.episodes_counter, env_step
+ )
+ self.trainer.logger.log_scalar(
+ "trainer/buffer_last_ep_len",
+ self.trainer.replay_buffer.last_episode_length,
+ env_step,
+ )
- def _log_evaluation(self, env_step: int):
+ def evaluate(self) -> dict[str, float]:
returns = []
costs = []
- for i_ep in range(self.num_eval_episodes):
- env_test = self._make_env_test(seed=self.seed + i_ep)
+ for i_ep in range(self.trainer.num_eval_episodes):
+ env_test = self.trainer.make_env_test(seed=self.trainer.seed + i_ep)
state, _ = env_test.reset()
episode_return = 0
@@ -121,7 +86,7 @@ def _log_evaluation(self, env_step: int):
terminated, truncated = False, False
while not (terminated or truncated):
- action = self._algo.exploit(state)
+ action = self.trainer.algo.actor.exploit(state)
state, reward, terminated, truncated, info = env_test.step(action)
episode_return += reward
episode_cost += info["cost"]
@@ -129,9 +94,8 @@ def _log_evaluation(self, env_step: int):
returns.append(episode_return)
costs.append(episode_cost)
- self._logger.log_scalar(
- "trainer/ep_reward", np.mean(returns, dtype=float), env_step
- )
- self._logger.log_scalar(
- "trainer/ep_cost", np.mean(costs, dtype=float), env_step
- )
+ return {
+ "return": float(np.mean(returns)),
+ "cost": float(np.mean(costs)),
+ }
+
diff --git a/src/oprl/utils/config.py b/src/oprl/utils/config.py
deleted file mode 100644
index 95bc38e..0000000
--- a/src/oprl/utils/config.py
+++ /dev/null
@@ -1,10 +0,0 @@
-import importlib.util
-import sys
-
-
-def load_config(path: str):
- spec = importlib.util.spec_from_file_location("config", path)
- config = importlib.util.module_from_spec(spec)
- sys.modules["config"] = config
- spec.loader.exec_module(config)
- return config
diff --git a/src/oprl/utils/logger.py b/src/oprl/utils/logger.py
deleted file mode 100644
index c78175a..0000000
--- a/src/oprl/utils/logger.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import json
-import logging
-import os
-import shutil
-from abc import ABC, abstractmethod
-from sys import path
-from typing import Any
-
-import numpy as np
-import torch
-import torch.nn as nn
-from torch.utils.tensorboard.writer import SummaryWriter
-
-
-def copy_exp_dir(log_dir: str) -> None:
- cur_dir = os.path.join(os.getcwd(), "src")
- dest_dir = os.path.join(log_dir, "src")
- shutil.copytree(cur_dir, dest_dir)
- logging.info(f"Source copied into {dest_dir}")
-
-
-def save_json_config(config: dict[str, Any], path: str):
- with open(path, "w") as f:
- json.dump(config, f)
-
-
-class Logger(ABC):
-
- def log_scalars(self, values: dict[str, float], step: int):
- """
- Args:
- values: Dict with tag -> value to log.
- step: Iter step.
- """
- (self.log_scalar(k, v, step) for k, v in values.items())
-
- @abstractmethod
- def log_scalar(self, tag: str, value: float, step: int):
- logging.info(f"{tag}\t{value}\tat step {step}")
-
- @abstractmethod
- def log_video(self, tag: str, imgs, step: int):
- logging.warning("Skipping logging video in STDOUT logger")
-
-
-class StdLogger(Logger):
- def __init__(self, *args, **kwargs):
- pass
-
- def log_scalar(self, tag: str, value: float, step: int):
- logging.info(f"{tag}\t{value}\tat step {step}")
-
- def log_video(self, *args, **kwargs):
- logging.warning("Skipping logging video in STDOUT logger")
-
-
-class FileLogger(Logger):
- def __init__(self, logdir: str, config: dict[str, Any]):
- self.writer = SummaryWriter(logdir)
-
- self._log_dir = logdir
-
- logging.info(f"Source code is copied to {logdir}")
- copy_exp_dir(logdir)
- save_json_config(config, os.path.join(logdir, "config.json"))
-
- def log_scalar(self, tag: str, value: float, step: int) -> None:
- self.writer.add_scalar(tag, value, step)
- self._log_scalar_to_file(tag, value, step)
-
- def log_video(self, tag: str, imgs, step: int) -> None:
- os.makedirs(os.path.join(self._log_dir, "images"), exist_ok=True)
- fn = os.path.join(self._log_dir, "images", f"{tag}_step_{step}.npz")
- with open(fn, "wb") as f:
- np.save(f, imgs)
-
- def save_weights(self, weights: nn.Module, step: int) -> None:
- os.makedirs(os.path.join(self._log_dir, "weights"), exist_ok=True)
- fn = os.path.join(self._log_dir, "weights", f"step_{step}.w")
- torch.save(
- weights,
- fn
- )
-
- def _log_scalar_to_file(self, tag: str, val: float, step: int) -> None:
- fn = os.path.join(self._log_dir, f"{tag}.log")
- os.makedirs(os.path.dirname(fn), exist_ok=True)
- with open(fn, "a") as f:
- f.write(f"{step} {val}\n")
diff --git a/src/oprl/utils/run_training.py b/src/oprl/utils/run_training.py
deleted file mode 100644
index d329a47..0000000
--- a/src/oprl/utils/run_training.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import logging
-from multiprocessing import Process
-
-from oprl.trainers.base_trainer import BaseTrainer
-from oprl.trainers.safe_trainer import SafeTrainer
-from oprl.utils.utils import set_seed
-
-
-def run_training(
- make_algo, make_env, make_logger, config, seeds: int = 1, start_seed: int = 0
-):
- if seeds == 1:
- _run_training_func(make_algo, make_env, make_logger, config, 0)
- else:
- processes = []
- for seed in range(start_seed, start_seed + seeds):
- processes.append(
- Process(
- target=_run_training_func,
- args=(make_algo, make_env, make_logger, config, seed),
- )
- )
-
- for i, p in enumerate(processes):
- p.start()
- logging.info(f"Starting process {i}...")
-
- for p in processes:
- p.join()
-
- logging.info("Training OK.")
-
-
-def _run_training_func(make_algo, make_env, make_logger, config, seed: int):
- set_seed(seed)
- env = make_env(seed=seed)
- logger = make_logger(seed)
-
- if env.env_family == "dm_control":
- trainer_class = BaseTrainer
- elif env.env_family == "safety_gymnasium":
- trainer_class = SafeTrainer
- else:
- raise ValueError(f"Unsupported env family: {env.env_family}")
-
- trainer = trainer_class(
- state_dim=config["state_dim"],
- action_dim=config["action_dim"],
- env=env,
- make_env_test=make_env,
- algo=make_algo(logger),
- num_steps=config["num_steps"],
- eval_interval=config["eval_every"],
- device=config["device"],
- save_buffer_every=config["save_buffer"],
- visualise_every=config["visualise_every"],
- estimate_q_every=config["estimate_q_every"],
- stdout_log_every=config["log_every"],
- seed=seed,
- logger=logger,
- )
-
- trainer.train()
diff --git a/src/oprl/utils/utils.py b/src/oprl/utils/utils.py
deleted file mode 100644
index 29736cc..0000000
--- a/src/oprl/utils/utils.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import logging
-import os
-import random
-import shutil
-import sys
-from glob import glob
-
-import imageio
-import numpy as np
-import torch as t
-
-
-class OUNoise(object):
- def __init__(
- self,
- dim,
- low,
- high,
- mu=0.0,
- theta=0.15,
- max_sigma=0.3,
- min_sigma=0.3,
- decay_period=10_000,
- ):
- self.mu = mu
- self.theta = theta
- self.sigma = max_sigma
- self.max_sigma = max_sigma
- self.min_sigma = min_sigma
- self.decay_period = decay_period
- self.action_dim = dim
- self.low = low
- self.high = high
-
- def reset(self):
- self.state = np.ones(self.action_dim) * self.mu
-
- def evolve_state(self):
- x = self.state
- dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.action_dim)
- self.state = x + dx
- return self.state
-
- def get_action(self, action, t=0):
- ou_state = self.evolve_state()
- self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * min(
- 1.0, t / self.decay_period
- )
- action = action.cpu().detach().numpy()
- return np.clip(action + ou_state, self.low, self.high)
-
-
-def make_gif(source_dir, output):
- """
- Make gif file from set of .jpeg images.
- Args:
- source_dir (str): path with .jpeg images
- output (str): path to the output .gif file
- Returns: None
- """
- batch_sort = lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")])
- image_paths = sorted(glob(os.path.join(source_dir, "*.png")), key=batch_sort)
-
- images = []
- for filename in image_paths:
- images.append(imageio.imread(filename))
- imageio.mimsave(output, images)
-
-
-def empty_torch_queue(q):
- while True:
- try:
- o = q.get_nowait()
- del o
- except:
- break
- q.close()
-
-
-def copy_exp_dir(log_dir: str) -> None:
- cur_dir = os.path.join(os.getcwd(), "src")
- dest_dir = os.path.join(log_dir, "src")
- shutil.copy(cur_dir, dest_dir)
- logging.info(f"Source copied into {dest_dir}")
-
-
-def set_seed(seed: int) -> None:
- random.seed(seed)
- np.random.seed(seed)
- t.manual_seed(seed)
-
-
-def set_logging(level: int):
- logging.basicConfig(
- level=level,
- format="%(asctime)s | %(filename)s:%(lineno)d\t %(levelname)s - %(message)s",
- stream=sys.stdout,
- )
diff --git a/tests/functional/test_env.py b/tests/functional/test_env.py
index 1facf87..918da39 100644
--- a/tests/functional/test_env.py
+++ b/tests/functional/test_env.py
@@ -1,6 +1,6 @@
import pytest
-from oprl.env import make_env
+from oprl.environment import make_env
dm_control_envs: list[str] = [
@@ -31,6 +31,27 @@
]
+gymnasium_envs: list[str] = [
+ "Ant-v4",
+ "Hopper-v4",
+ "HalfCheetah-v4",
+ "HumanoidStandup-v4",
+ "Humanoid-v4",
+ "InvertedDoublePendulum-v4",
+ "InvertedPendulum-v4",
+ "Pusher-v4",
+ "Reacher-v4",
+ "Swimmer-v4",
+ "Walker2d-v4",
+]
+
+
+# TODO:
+# gymnasium_robotics_envs: list[str] = [
+# "FetchPickAndPlace-v3",
+# ]
+
+
safety_envs: list[str] = [
"SafetyPointGoal1-v0",
"SafetyPointButton1-v0",
@@ -39,7 +60,7 @@
]
-env_names: list[str] = dm_control_envs + safety_envs
+env_names: list[str] = dm_control_envs + safety_envs + gymnasium_envs
@pytest.mark.parametrize("env_name", env_names)
diff --git a/tests/functional/test_logging.py b/tests/functional/test_logging.py
new file mode 100644
index 0000000..f09d8d8
--- /dev/null
+++ b/tests/functional/test_logging.py
@@ -0,0 +1,19 @@
+from logging import Logger
+
+from oprl.logging import (
+ LoggerProtocol,
+ create_stdout_logger,
+ make_text_logger_func,
+)
+
+
+def test_create_stdout_logger() -> None:
+ logger = create_stdout_logger()
+ assert isinstance(logger, Logger)
+
+
+def test_create_text_logger_func() -> None:
+ func = make_text_logger_func("test_algo", "test_env")
+ logger = func(0)
+ assert isinstance(logger, LoggerProtocol)
+
diff --git a/tests/functional/test_replay_buffer.py b/tests/functional/test_replay_buffer.py
new file mode 100644
index 0000000..dde0244
--- /dev/null
+++ b/tests/functional/test_replay_buffer.py
@@ -0,0 +1,21 @@
+from oprl.buffers.episodic_buffer import EpisodicReplayBuffer
+from oprl.buffers.protocols import ReplayBufferProtocol
+
+
+def test_replay_buffer() -> None:
+ state_dim = 7
+ max_episode_length = 10
+ num_transitions = 100
+ buffer = EpisodicReplayBuffer(
+ buffer_size_transitions=num_transitions,
+ state_dim=state_dim,
+ action_dim=3,
+ max_episode_lenth=max_episode_length,
+ ).create()
+ assert isinstance(buffer, ReplayBufferProtocol)
+
+ states = buffer.states
+ assert len(states.shape) == 3
+ assert states.shape[0] == num_transitions // max_episode_length
+ assert states.shape[1] == max_episode_length + 1
+ assert states.shape[2] == state_dim
diff --git a/tests/functional/test_rl_algos.py b/tests/functional/test_rl_algos.py
index 876f505..3f32ce0 100644
--- a/tests/functional/test_rl_algos.py
+++ b/tests/functional/test_rl_algos.py
@@ -1,28 +1,24 @@
import pytest
import torch
+import numpy.typing as npt
+from oprl.algos.protocols import AlgorithmProtocol
from oprl.algos.ddpg import DDPG
from oprl.algos.sac import SAC
from oprl.algos.td3 import TD3
from oprl.algos.tqc import TQC
-from oprl.env import DMControlEnv
+from oprl.environment import DMControlEnv
+from oprl.environment.protocols import EnvProtocol
+from oprl.logging import FileTxtLogger
-rl_algo_classes = [DDPG, SAC, TD3, TQC]
+rl_algo_classes: list[type[AlgorithmProtocol]] = [DDPG, SAC, TD3, TQC]
-@pytest.mark.parametrize("algo_class", rl_algo_classes)
-def test_rl_algo_run(algo_class):
- env = DMControlEnv("walker-walk", seed=0)
- obs, _ = env.reset(env.sample_action())
-
- algo = algo_class(
- state_dim=env.observation_space.shape[0],
- action_dim=env.action_space.shape[0],
- ).create()
- action = algo.exploit(obs)
+def _run_common_test(algo: AlgorithmProtocol, env: EnvProtocol, obs: npt.NDArray) -> None:
+ action = algo.actor.exploit(obs)
assert action.ndim == 1
- action = algo.explore(obs)
+ action = algo.actor.explore(obs)
assert action.ndim == 1
_batch_size = 8
@@ -33,3 +29,32 @@ def test_rl_algo_run(algo_class):
batch_rewards = torch.randn(_batch_size, 1)
batch_dones = torch.randint(2, (_batch_size, 1))
algo.update(batch_obs, batch_actions, batch_rewards, batch_dones, batch_obs)
+
+
+@pytest.mark.parametrize("algo_class", rl_algo_classes)
+def test_ddpg_td3_tqc(algo_class: type[AlgorithmProtocol]) -> None:
+ env = DMControlEnv("walker-walk", seed=0)
+ # TODO: Change to mocked logger
+ logger = FileTxtLogger(".")
+ obs, _ = env.reset()
+ algo = algo_class(
+ logger=logger,
+ state_dim=env.observation_space.shape[0],
+ action_dim=env.action_space.shape[0],
+ ).create()
+ _run_common_test(algo, env, obs)
+
+
+@pytest.mark.parametrize("tune_alpha", [True, False])
+def test_sac(tune_alpha: bool) -> None:
+ env = DMControlEnv("walker-walk", seed=0)
+ # TODO: Change to mocked logger
+ logger = FileTxtLogger(".")
+ obs, _ = env.reset()
+ algo = SAC(
+ logger=logger,
+ tune_alpha=tune_alpha,
+ state_dim=env.observation_space.shape[0],
+ action_dim=env.action_space.shape[0],
+ ).create()
+ _run_common_test(algo, env, obs)
diff --git a/uv.lock b/uv.lock
new file mode 100644
index 0000000..c0ea2d1
--- /dev/null
+++ b/uv.lock
@@ -0,0 +1,1141 @@
+version = 1
+revision = 2
+requires-python = "==3.10.8"
+
+[[package]]
+name = "absl-py"
+version = "2.3.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/10/2a/c93173ffa1b39c1d0395b7e842bbdc62e556ca9d8d3b5572926f3e4ca752/absl_py-2.3.1.tar.gz", hash = "sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9", size = 116588, upload-time = "2025-07-03T09:31:44.05Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/8f/aa/ba0014cc4659328dc818a28827be78e6d97312ab0cb98105a770924dc11e/absl_py-2.3.1-py3-none-any.whl", hash = "sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d", size = 135811, upload-time = "2025-07-03T09:31:42.253Z" },
+]
+
+[[package]]
+name = "annotated-types"
+version = "0.7.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081, upload-time = "2024-05-20T21:33:25.928Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" },
+]
+
+[[package]]
+name = "attrs"
+version = "25.3.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/5a/b0/1367933a8532ee6ff8d63537de4f1177af4bff9f3e829baf7331f595bb24/attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b", size = 812032, upload-time = "2025-03-13T11:10:22.779Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3", size = 63815, upload-time = "2025-03-13T11:10:21.14Z" },
+]
+
+[[package]]
+name = "black"
+version = "25.1.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "click" },
+ { name = "mypy-extensions" },
+ { name = "packaging" },
+ { name = "pathspec" },
+ { name = "platformdirs" },
+ { name = "tomli" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/94/49/26a7b0f3f35da4b5a65f081943b7bcd22d7002f5f0fb8098ec1ff21cb6ef/black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666", size = 649449, upload-time = "2025-01-29T04:15:40.373Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/4d/3b/4ba3f93ac8d90410423fdd31d7541ada9bcee1df32fb90d26de41ed40e1d/black-25.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759e7ec1e050a15f89b770cefbf91ebee8917aac5c20483bc2d80a6c3a04df32", size = 1629419, upload-time = "2025-01-29T05:37:06.642Z" },
+ { url = "https://files.pythonhosted.org/packages/b4/02/0bde0485146a8a5e694daed47561785e8b77a0466ccc1f3e485d5ef2925e/black-25.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e519ecf93120f34243e6b0054db49c00a35f84f195d5bce7e9f5cfc578fc2da", size = 1461080, upload-time = "2025-01-29T05:37:09.321Z" },
+ { url = "https://files.pythonhosted.org/packages/52/0e/abdf75183c830eaca7589144ff96d49bce73d7ec6ad12ef62185cc0f79a2/black-25.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:055e59b198df7ac0b7efca5ad7ff2516bca343276c466be72eb04a3bcc1f82d7", size = 1766886, upload-time = "2025-01-29T04:18:24.432Z" },
+ { url = "https://files.pythonhosted.org/packages/dc/a6/97d8bb65b1d8a41f8a6736222ba0a334db7b7b77b8023ab4568288f23973/black-25.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:db8ea9917d6f8fc62abd90d944920d95e73c83a5ee3383493e35d271aca872e9", size = 1419404, upload-time = "2025-01-29T04:19:04.296Z" },
+ { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646, upload-time = "2025-01-29T04:15:38.082Z" },
+]
+
+[[package]]
+name = "cachetools"
+version = "5.5.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/6c/81/3747dad6b14fa2cf53fcf10548cf5aea6913e96fab41a3c198676f8948a5/cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4", size = 28380, upload-time = "2025-02-20T21:01:19.524Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a", size = 10080, upload-time = "2025-02-20T21:01:16.647Z" },
+]
+
+[[package]]
+name = "certifi"
+version = "2025.7.14"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/b3/76/52c535bcebe74590f296d6c77c86dabf761c41980e1347a2422e4aa2ae41/certifi-2025.7.14.tar.gz", hash = "sha256:8ea99dbdfaaf2ba2f9bac77b9249ef62ec5218e7c2b2e903378ed5fccf765995", size = 163981, upload-time = "2025-07-14T03:29:28.449Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/4f/52/34c6cf5bb9285074dc3531c437b3919e825d976fde097a7a73f79e726d03/certifi-2025.7.14-py3-none-any.whl", hash = "sha256:6b31f564a415d79ee77df69d757bb49a5bb53bd9f756cbbe24394ffd6fc1f4b2", size = 162722, upload-time = "2025-07-14T03:29:26.863Z" },
+]
+
+[[package]]
+name = "charset-normalizer"
+version = "3.4.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e4/33/89c2ced2b67d1c2a61c19c6751aa8902d46ce3dacb23600a283619f5a12d/charset_normalizer-3.4.2.tar.gz", hash = "sha256:5baececa9ecba31eff645232d59845c07aa030f0c81ee70184a90d35099a0e63", size = 126367, upload-time = "2025-05-02T08:34:42.01Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/95/28/9901804da60055b406e1a1c5ba7aac1276fb77f1dde635aabfc7fd84b8ab/charset_normalizer-3.4.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7c48ed483eb946e6c04ccbe02c6b4d1d48e51944b6db70f697e089c193404941", size = 201818, upload-time = "2025-05-02T08:31:46.725Z" },
+ { url = "https://files.pythonhosted.org/packages/d9/9b/892a8c8af9110935e5adcbb06d9c6fe741b6bb02608c6513983048ba1a18/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2d318c11350e10662026ad0eb71bb51c7812fc8590825304ae0bdd4ac283acd", size = 144649, upload-time = "2025-05-02T08:31:48.889Z" },
+ { url = "https://files.pythonhosted.org/packages/7b/a5/4179abd063ff6414223575e008593861d62abfc22455b5d1a44995b7c101/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9cbfacf36cb0ec2897ce0ebc5d08ca44213af24265bd56eca54bee7923c48fd6", size = 155045, upload-time = "2025-05-02T08:31:50.757Z" },
+ { url = "https://files.pythonhosted.org/packages/3b/95/bc08c7dfeddd26b4be8c8287b9bb055716f31077c8b0ea1cd09553794665/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18dd2e350387c87dabe711b86f83c9c78af772c748904d372ade190b5c7c9d4d", size = 147356, upload-time = "2025-05-02T08:31:52.634Z" },
+ { url = "https://files.pythonhosted.org/packages/a8/2d/7a5b635aa65284bf3eab7653e8b4151ab420ecbae918d3e359d1947b4d61/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8075c35cd58273fee266c58c0c9b670947c19df5fb98e7b66710e04ad4e9ff86", size = 149471, upload-time = "2025-05-02T08:31:56.207Z" },
+ { url = "https://files.pythonhosted.org/packages/ae/38/51fc6ac74251fd331a8cfdb7ec57beba8c23fd5493f1050f71c87ef77ed0/charset_normalizer-3.4.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5bf4545e3b962767e5c06fe1738f951f77d27967cb2caa64c28be7c4563e162c", size = 151317, upload-time = "2025-05-02T08:31:57.613Z" },
+ { url = "https://files.pythonhosted.org/packages/b7/17/edee1e32215ee6e9e46c3e482645b46575a44a2d72c7dfd49e49f60ce6bf/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:7a6ab32f7210554a96cd9e33abe3ddd86732beeafc7a28e9955cdf22ffadbab0", size = 146368, upload-time = "2025-05-02T08:31:59.468Z" },
+ { url = "https://files.pythonhosted.org/packages/26/2c/ea3e66f2b5f21fd00b2825c94cafb8c326ea6240cd80a91eb09e4a285830/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:b33de11b92e9f75a2b545d6e9b6f37e398d86c3e9e9653c4864eb7e89c5773ef", size = 154491, upload-time = "2025-05-02T08:32:01.219Z" },
+ { url = "https://files.pythonhosted.org/packages/52/47/7be7fa972422ad062e909fd62460d45c3ef4c141805b7078dbab15904ff7/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:8755483f3c00d6c9a77f490c17e6ab0c8729e39e6390328e42521ef175380ae6", size = 157695, upload-time = "2025-05-02T08:32:03.045Z" },
+ { url = "https://files.pythonhosted.org/packages/2f/42/9f02c194da282b2b340f28e5fb60762de1151387a36842a92b533685c61e/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:68a328e5f55ec37c57f19ebb1fdc56a248db2e3e9ad769919a58672958e8f366", size = 154849, upload-time = "2025-05-02T08:32:04.651Z" },
+ { url = "https://files.pythonhosted.org/packages/67/44/89cacd6628f31fb0b63201a618049be4be2a7435a31b55b5eb1c3674547a/charset_normalizer-3.4.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:21b2899062867b0e1fde9b724f8aecb1af14f2778d69aacd1a5a1853a597a5db", size = 150091, upload-time = "2025-05-02T08:32:06.719Z" },
+ { url = "https://files.pythonhosted.org/packages/1f/79/4b8da9f712bc079c0f16b6d67b099b0b8d808c2292c937f267d816ec5ecc/charset_normalizer-3.4.2-cp310-cp310-win32.whl", hash = "sha256:e8082b26888e2f8b36a042a58307d5b917ef2b1cacab921ad3323ef91901c71a", size = 98445, upload-time = "2025-05-02T08:32:08.66Z" },
+ { url = "https://files.pythonhosted.org/packages/7d/d7/96970afb4fb66497a40761cdf7bd4f6fca0fc7bafde3a84f836c1f57a926/charset_normalizer-3.4.2-cp310-cp310-win_amd64.whl", hash = "sha256:f69a27e45c43520f5487f27627059b64aaf160415589230992cec34c5e18a509", size = 105782, upload-time = "2025-05-02T08:32:10.46Z" },
+ { url = "https://files.pythonhosted.org/packages/20/94/c5790835a017658cbfabd07f3bfb549140c3ac458cfc196323996b10095a/charset_normalizer-3.4.2-py3-none-any.whl", hash = "sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0", size = 52626, upload-time = "2025-05-02T08:34:40.053Z" },
+]
+
+[[package]]
+name = "click"
+version = "8.2.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "colorama", marker = "sys_platform == 'win32'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342, upload-time = "2025-05-20T23:19:49.832Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215, upload-time = "2025-05-20T23:19:47.796Z" },
+]
+
+[[package]]
+name = "cloudpickle"
+version = "3.1.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/52/39/069100b84d7418bc358d81669d5748efb14b9cceacd2f9c75f550424132f/cloudpickle-3.1.1.tar.gz", hash = "sha256:b216fa8ae4019d5482a8ac3c95d8f6346115d8835911fd4aefd1a445e4242c64", size = 22113, upload-time = "2025-01-14T17:02:05.085Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/7e/e8/64c37fadfc2816a7701fa8a6ed8d87327c7d54eacfbfb6edab14a2f2be75/cloudpickle-3.1.1-py3-none-any.whl", hash = "sha256:c8c5a44295039331ee9dad40ba100a9c7297b6f988e50e87ccdf3765a668350e", size = 20992, upload-time = "2025-01-14T17:02:02.417Z" },
+]
+
+[[package]]
+name = "colorama"
+version = "0.4.6"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" },
+]
+
+[[package]]
+name = "dm-control"
+version = "1.0.11"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "absl-py" },
+ { name = "dm-env" },
+ { name = "dm-tree" },
+ { name = "glfw" },
+ { name = "labmaze" },
+ { name = "lxml" },
+ { name = "mujoco" },
+ { name = "numpy" },
+ { name = "protobuf" },
+ { name = "pyopengl" },
+ { name = "pyparsing" },
+ { name = "requests" },
+ { name = "scipy" },
+ { name = "setuptools" },
+ { name = "tqdm" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/e7/42/fd5aecc74c747c16e98097a07053438fb2ca6d13300b1a9eb27bddaad62c/dm_control-1.0.11.tar.gz", hash = "sha256:ac222c91a34be9d9d7573a168bdce791c8a6693cb84bd3de988090a96e8df010", size = 38991406, upload-time = "2023-03-22T16:45:07.783Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/c7/b9/11639f413e2f407e71b4fcdcc83bd5914c785180667e5eed93ce6f3c28db/dm_control-1.0.11-py3-none-any.whl", hash = "sha256:2b46def2cfc5a547f61b496fee00287fd2af52c9cd5ba7e1e7a59a6973adaad9", size = 39291059, upload-time = "2023-03-22T16:44:54.358Z" },
+]
+
+[[package]]
+name = "dm-env"
+version = "1.6"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "absl-py" },
+ { name = "dm-tree" },
+ { name = "numpy" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/62/c9/93e8d6239d5806508a2ee4b370e67c6069943ca149f59f533923737a99b7/dm-env-1.6.tar.gz", hash = "sha256:a436eb1c654c39e0c986a516cee218bea7140b510fceff63f97eb4fcff3d93de", size = 20187, upload-time = "2022-12-21T00:25:29.306Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/08/7e/36d548040e61337bf9182637a589c44da407a47a923ee88aec7f0e89867c/dm_env-1.6-py3-none-any.whl", hash = "sha256:0eabb6759dd453b625e041032f7ae0c1e87d4eb61b6a96b9ca586483837abf29", size = 26339, upload-time = "2022-12-21T00:25:37.128Z" },
+]
+
+[[package]]
+name = "dm-tree"
+version = "0.1.9"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "absl-py" },
+ { name = "attrs" },
+ { name = "numpy" },
+ { name = "wrapt" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/a6/83/ce29720ccf934c6cfa9b9c95ebbe96558386e66886626066632b5e44afed/dm_tree-0.1.9.tar.gz", hash = "sha256:a4c7db3d3935a5a2d5e4b383fc26c6b0cd6f78c6d4605d3e7b518800ecd5342b", size = 35623, upload-time = "2025-01-30T20:45:37.13Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/77/d2/88f685534d87072a5174fe229e77aab6b7da50092d5151ebc172f6270b5c/dm_tree-0.1.9-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5d5b28ee2e461b6af65330c143806a6d0945dcabbb8d22d2ba863e6dabd9254e", size = 173568, upload-time = "2025-03-31T08:35:38.425Z" },
+ { url = "https://files.pythonhosted.org/packages/d1/6a/64924e102f559c1380263a28a751f20a1bdd18e85ea599e216feead84adf/dm_tree-0.1.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54d5616015412311df154908069fcf2c2d8786f6088a2ae3554d186cdf2b1e15", size = 146935, upload-time = "2025-01-30T20:45:16.505Z" },
+ { url = "https://files.pythonhosted.org/packages/7c/79/ba0f7274164eb6bd06a36c2f8cb21b0debc32fd9ba8e73a7c9e50c90041b/dm_tree-0.1.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831699d2c60a1b38776a193b7143ae0acad0a687d87654e6d3342584166816bc", size = 152892, upload-time = "2025-01-30T20:45:18.021Z" },
+ { url = "https://files.pythonhosted.org/packages/bf/20/8b96a34a15c5c4d1d6af44795963fa44381716975aabac83beab4fe80974/dm_tree-0.1.9-cp310-cp310-win_amd64.whl", hash = "sha256:1ae3cbff592bb3f2e197f5a8030de4a94e292e6cdd85adeea0b971d07a1b85f2", size = 101469, upload-time = "2025-01-30T20:45:19.197Z" },
+]
+
+[[package]]
+name = "exceptiongroup"
+version = "1.3.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10", size = 16674, upload-time = "2025-05-10T17:42:49.33Z" },
+]
+
+[[package]]
+name = "farama-notifications"
+version = "0.0.4"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/2e/2c/8384832b7a6b1fd6ba95bbdcae26e7137bb3eedc955c42fd5cdcc086cfbf/Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18", size = 2131, upload-time = "2023-02-27T18:28:41.047Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/05/2c/ffc08c54c05cdce6fbed2aeebc46348dbe180c6d2c541c7af7ba0aa5f5f8/Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae", size = 2511, upload-time = "2023-02-27T18:28:39.447Z" },
+]
+
+[[package]]
+name = "filelock"
+version = "3.18.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/0a/10/c23352565a6544bdc5353e0b15fc1c563352101f30e24bf500207a54df9a/filelock-3.18.0.tar.gz", hash = "sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2", size = 18075, upload-time = "2025-03-14T07:11:40.47Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl", hash = "sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de", size = 16215, upload-time = "2025-03-14T07:11:39.145Z" },
+]
+
+[[package]]
+name = "flake8"
+version = "7.3.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "mccabe" },
+ { name = "pycodestyle" },
+ { name = "pyflakes" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/9b/af/fbfe3c4b5a657d79e5c47a2827a362f9e1b763336a52f926126aa6dc7123/flake8-7.3.0.tar.gz", hash = "sha256:fe044858146b9fc69b551a4b490d69cf960fcb78ad1edcb84e7fbb1b4a8e3872", size = 48326, upload-time = "2025-06-20T19:31:35.838Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/9f/56/13ab06b4f93ca7cac71078fbe37fcea175d3216f31f85c3168a6bbd0bb9a/flake8-7.3.0-py2.py3-none-any.whl", hash = "sha256:b9696257b9ce8beb888cdbe31cf885c90d31928fe202be0889a7cdafad32f01e", size = 57922, upload-time = "2025-06-20T19:31:34.425Z" },
+]
+
+[[package]]
+name = "fsspec"
+version = "2025.7.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/8b/02/0835e6ab9cfc03916fe3f78c0956cfcdb6ff2669ffa6651065d5ebf7fc98/fsspec-2025.7.0.tar.gz", hash = "sha256:786120687ffa54b8283d942929540d8bc5ccfa820deb555a2b5d0ed2b737bf58", size = 304432, upload-time = "2025-07-15T16:05:21.19Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/2f/e0/014d5d9d7a4564cf1c40b5039bc882db69fd881111e03ab3657ac0b218e2/fsspec-2025.7.0-py3-none-any.whl", hash = "sha256:8b012e39f63c7d5f10474de957f3ab793b47b45ae7d39f2fb735f8bbe25c0e21", size = 199597, upload-time = "2025-07-15T16:05:19.529Z" },
+]
+
+[[package]]
+name = "glfw"
+version = "2.9.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/38/97/a2d667c98b8474f6b8294042488c1bd488681fb3cb4c3b9cdac1a9114287/glfw-2.9.0.tar.gz", hash = "sha256:077111a150ff09bc302c5e4ae265a5eb6aeaff0c8b01f727f7fb34e3764bb8e2", size = 31453, upload-time = "2025-04-15T15:39:54.142Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/21/71/13dd8a8d547809543d21de9438a3a76a8728fc7966d01ad9fb54599aebf5/glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-macosx_10_6_intel.whl", hash = "sha256:183da99152f63469e9263146db2eb1b6cc4ee0c4082b280743e57bd1b0a3bd70", size = 105297, upload-time = "2025-04-15T15:39:39.677Z" },
+ { url = "https://files.pythonhosted.org/packages/f8/a2/45e6dceec1e0a0ffa8dd3c0ecf1e11d74639a55186243129160c6434d456/glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-macosx_11_0_arm64.whl", hash = "sha256:aef5b555673b9555216e4cd7bc0bdbbb9983f66c620a85ba7310cfcfda5cd38c", size = 102146, upload-time = "2025-04-15T15:39:42.354Z" },
+ { url = "https://files.pythonhosted.org/packages/d2/72/b6261ed918e3747c6070fe80636c63a3c8f1c42ce122670315eeeada156f/glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-manylinux2014_aarch64.whl", hash = "sha256:fcc430cb21984afba74945b7df38a5e1a02b36c0b4a2a2bab42b4a26d7cc51d6", size = 230002, upload-time = "2025-04-15T15:39:43.933Z" },
+ { url = "https://files.pythonhosted.org/packages/45/d6/7f95786332e8b798569b8e60db2ee081874cec2a62572b8ec55c309d85b7/glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-manylinux2014_x86_64.whl", hash = "sha256:7f85b58546880466ac445fc564c5c831ca93c8a99795ab8eaf0a2d521af293d7", size = 241949, upload-time = "2025-04-15T15:39:45.28Z" },
+ { url = "https://files.pythonhosted.org/packages/a1/e6/093ab7874a74bba351e754f6e7748c031bd7276702135da6cbcd00e1f3e2/glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-manylinux_2_28_aarch64.whl", hash = "sha256:2123716c8086b80b797e849a534fc6f21aebca300519e57c80618a65ca8135dc", size = 231016, upload-time = "2025-04-15T15:39:46.669Z" },
+ { url = "https://files.pythonhosted.org/packages/7f/ba/de3630757c7d7fc2086aaf3994926d6b869d31586e4d0c14f1666af31b93/glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-manylinux_2_28_x86_64.whl", hash = "sha256:4e11271e49eb9bc53431ade022e284d5a59abeace81fe3b178db1bf3ccc0c449", size = 243489, upload-time = "2025-04-15T15:39:48.321Z" },
+ { url = "https://files.pythonhosted.org/packages/32/36/c3bada8503681806231d1705ea1802bac8febf69e4186b9f0f0b9e2e4f7e/glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-win32.whl", hash = "sha256:8e4fbff88e4e953bb969b6813195d5de4641f886530cc8083897e56b00bf2c8e", size = 552655, upload-time = "2025-04-15T15:39:50.029Z" },
+ { url = "https://files.pythonhosted.org/packages/cb/70/7f2f052ca20c3b69892818f2ee1fea53b037ea9145ff75b944ed1dc4ff82/glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-win_amd64.whl", hash = "sha256:9aa3ae51601601c53838315bd2a03efb1e6bebecd072b2f64ddbd0b2556d511a", size = 559441, upload-time = "2025-04-15T15:39:52.531Z" },
+]
+
+[[package]]
+name = "google-auth"
+version = "2.40.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "cachetools" },
+ { name = "pyasn1-modules" },
+ { name = "rsa" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/9e/9b/e92ef23b84fa10a64ce4831390b7a4c2e53c0132568d99d4ae61d04c8855/google_auth-2.40.3.tar.gz", hash = "sha256:500c3a29adedeb36ea9cf24b8d10858e152f2412e3ca37829b3fa18e33d63b77", size = 281029, upload-time = "2025-06-04T18:04:57.577Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/17/63/b19553b658a1692443c62bd07e5868adaa0ad746a0751ba62c59568cd45b/google_auth-2.40.3-py2.py3-none-any.whl", hash = "sha256:1370d4593e86213563547f97a92752fc658456fe4514c809544f330fed45a7ca", size = 216137, upload-time = "2025-06-04T18:04:55.573Z" },
+]
+
+[[package]]
+name = "google-auth-oauthlib"
+version = "1.2.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "google-auth" },
+ { name = "requests-oauthlib" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/fb/87/e10bf24f7bcffc1421b84d6f9c3377c30ec305d082cd737ddaa6d8f77f7c/google_auth_oauthlib-1.2.2.tar.gz", hash = "sha256:11046fb8d3348b296302dd939ace8af0a724042e8029c1b872d87fabc9f41684", size = 20955, upload-time = "2025-04-22T16:40:29.172Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ac/84/40ee070be95771acd2f4418981edb834979424565c3eec3cd88b6aa09d24/google_auth_oauthlib-1.2.2-py3-none-any.whl", hash = "sha256:fd619506f4b3908b5df17b65f39ca8d66ea56986e5472eb5978fd8f3786f00a2", size = 19072, upload-time = "2025-04-22T16:40:28.174Z" },
+]
+
+[[package]]
+name = "grpcio"
+version = "1.73.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/79/e8/b43b851537da2e2f03fa8be1aef207e5cbfb1a2e014fbb6b40d24c177cd3/grpcio-1.73.1.tar.gz", hash = "sha256:7fce2cd1c0c1116cf3850564ebfc3264fba75d3c74a7414373f1238ea365ef87", size = 12730355, upload-time = "2025-06-26T01:53:24.622Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/8f/51/a5748ab2773d893d099b92653039672f7e26dd35741020972b84d604066f/grpcio-1.73.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:2d70f4ddd0a823436c2624640570ed6097e40935c9194482475fe8e3d9754d55", size = 5365087, upload-time = "2025-06-26T01:51:44.541Z" },
+ { url = "https://files.pythonhosted.org/packages/ae/12/c5ee1a5dfe93dbc2eaa42a219e2bf887250b52e2e2ee5c036c4695f2769c/grpcio-1.73.1-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:3841a8a5a66830261ab6a3c2a3dc539ed84e4ab019165f77b3eeb9f0ba621f26", size = 10608921, upload-time = "2025-06-26T01:51:48.111Z" },
+ { url = "https://files.pythonhosted.org/packages/c4/6d/b0c6a8120f02b7d15c5accda6bfc43bc92be70ada3af3ba6d8e077c00374/grpcio-1.73.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:628c30f8e77e0258ab788750ec92059fc3d6628590fb4b7cea8c102503623ed7", size = 5803221, upload-time = "2025-06-26T01:51:50.486Z" },
+ { url = "https://files.pythonhosted.org/packages/a6/7a/3c886d9f1c1e416ae81f7f9c7d1995ae72cd64712d29dab74a6bafacb2d2/grpcio-1.73.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:67a0468256c9db6d5ecb1fde4bf409d016f42cef649323f0a08a72f352d1358b", size = 6444603, upload-time = "2025-06-26T01:51:52.203Z" },
+ { url = "https://files.pythonhosted.org/packages/42/07/f143a2ff534982c9caa1febcad1c1073cdec732f6ac7545d85555a900a7e/grpcio-1.73.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68b84d65bbdebd5926eb5c53b0b9ec3b3f83408a30e4c20c373c5337b4219ec5", size = 6040969, upload-time = "2025-06-26T01:51:55.028Z" },
+ { url = "https://files.pythonhosted.org/packages/fb/0f/523131b7c9196d0718e7b2dac0310eb307b4117bdbfef62382e760f7e8bb/grpcio-1.73.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c54796ca22b8349cc594d18b01099e39f2b7ffb586ad83217655781a350ce4da", size = 6132201, upload-time = "2025-06-26T01:51:56.867Z" },
+ { url = "https://files.pythonhosted.org/packages/ad/18/010a055410eef1d3a7a1e477ec9d93b091ac664ad93e9c5f56d6cc04bdee/grpcio-1.73.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:75fc8e543962ece2f7ecd32ada2d44c0c8570ae73ec92869f9af8b944863116d", size = 6774718, upload-time = "2025-06-26T01:51:58.338Z" },
+ { url = "https://files.pythonhosted.org/packages/16/11/452bfc1ab39d8ee748837ab8ee56beeae0290861052948785c2c445fb44b/grpcio-1.73.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6a6037891cd2b1dd1406b388660522e1565ed340b1fea2955b0234bdd941a862", size = 6304362, upload-time = "2025-06-26T01:51:59.802Z" },
+ { url = "https://files.pythonhosted.org/packages/1e/1c/c75ceee626465721e5cb040cf4b271eff817aa97388948660884cb7adffa/grpcio-1.73.1-cp310-cp310-win32.whl", hash = "sha256:cce7265b9617168c2d08ae570fcc2af4eaf72e84f8c710ca657cc546115263af", size = 3679036, upload-time = "2025-06-26T01:52:01.817Z" },
+ { url = "https://files.pythonhosted.org/packages/62/2e/42cb31b6cbd671a7b3dbd97ef33f59088cf60e3cf2141368282e26fafe79/grpcio-1.73.1-cp310-cp310-win_amd64.whl", hash = "sha256:6a2b372e65fad38842050943f42ce8fee00c6f2e8ea4f7754ba7478d26a356ee", size = 4340208, upload-time = "2025-06-26T01:52:03.674Z" },
+]
+
+[[package]]
+name = "gymnasium"
+version = "0.28.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "cloudpickle" },
+ { name = "farama-notifications" },
+ { name = "jax-jumpy" },
+ { name = "numpy" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/58/6a/c304954dc009648a21db245a8f56f63c8da8a025d446dd0fd67319726003/gymnasium-0.28.1.tar.gz", hash = "sha256:4c2c745808792c8f45c6e88ad0a5504774394e0c126f6e3db555e720d3da6f24", size = 796462, upload-time = "2023-03-25T12:02:00.613Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/60/82/3762ef4555791a729ae554e13c011efe5e8347d7eba9ea5ed245a8d1b234/gymnasium-0.28.1-py3-none-any.whl", hash = "sha256:7bc9a5bce1022f997d1dbc152fc91d1ac977bad9cc7794cdc25437010867cabf", size = 925534, upload-time = "2023-03-25T12:01:58.35Z" },
+]
+
+[[package]]
+name = "idna"
+version = "3.10"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490, upload-time = "2024-09-15T18:07:39.745Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" },
+]
+
+[[package]]
+name = "iniconfig"
+version = "2.1.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" },
+]
+
+[[package]]
+name = "jax-jumpy"
+version = "1.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "numpy" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/52/6a/b6affff68f172a4c8316d9ab9b7d952e865df15b854f158690991864e0fe/jax-jumpy-1.0.0.tar.gz", hash = "sha256:195fb955cc4c2b7f0b1453e3cb1fb1c414a51a407ffac7a51e69a73cb30d59ad", size = 19417, upload-time = "2023-03-17T16:52:56.598Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/73/23/338caee543d80584916da20f018aeb017764509d964fd347b97f41f97baa/jax_jumpy-1.0.0-py3-none-any.whl", hash = "sha256:ab7e01454bba462de3c4d098e3e585c302a8f06bc36d9182ab4e7e4aa7067c5e", size = 20368, upload-time = "2023-03-17T16:52:55.437Z" },
+]
+
+[[package]]
+name = "jinja2"
+version = "3.1.6"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "markupsafe" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" },
+]
+
+[[package]]
+name = "labmaze"
+version = "1.0.6"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "absl-py" },
+ { name = "numpy" },
+ { name = "setuptools" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/93/0a/139c4ae896b9413bd4ca69c62b08ee98dcfc78a9cbfdb7cadd0dce2ad31d/labmaze-1.0.6.tar.gz", hash = "sha256:2e8de7094042a77d6972f1965cf5c9e8f971f1b34d225752f343190a825ebe73", size = 4670455, upload-time = "2022-12-05T18:42:43.566Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/a0/0c/6a3941f48644c0b9305c7a22bd51974be1fed8e9233b16c893d728805143/labmaze-1.0.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b2ddef976dfd8d992b19cfa6c633f2eba7576d759c2082da534e3f727479a84a", size = 4815423, upload-time = "2022-12-05T18:41:47.351Z" },
+ { url = "https://files.pythonhosted.org/packages/d0/fe/b038c6a15732eb064767dc92ca39a38b2f5df183576384f0cfb6a4840f69/labmaze-1.0.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:157efaa93228c8ccce5cae337902dd652093e0fba9d3a0f6506e4bee272bb66f", size = 4806825, upload-time = "2022-12-05T18:41:49.922Z" },
+ { url = "https://files.pythonhosted.org/packages/59/ec/2762281d4f26845b20bb7529742a6914fcb07c8e7c522175b879df0127cf/labmaze-1.0.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3ce98b9541c5fe6a306e411e7d018121dd646f2c9978d763fad86f9f30c5f57", size = 4871532, upload-time = "2022-12-05T18:41:52.784Z" },
+ { url = "https://files.pythonhosted.org/packages/4d/93/abac7877e1d7de984a2f0f5be561ff0dc795ae7e22595cf2f7c7032cd27e/labmaze-1.0.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e6433bd49bc541791de8191040526fddfebb77151620eb04203453f43ee486a", size = 4875892, upload-time = "2022-12-05T18:41:55.603Z" },
+ { url = "https://files.pythonhosted.org/packages/c4/10/5262db11b3c1db8e4fbc3feed9baed4f95db6047b8d9dcaf4f9fb8da9ba3/labmaze-1.0.6-cp310-cp310-win_amd64.whl", hash = "sha256:6a507fc35961f1b1479708e2716f65e0d0611cefb55f31a77be29ce2339b6fef", size = 4812953, upload-time = "2022-12-05T18:41:58.098Z" },
+]
+
+[[package]]
+name = "lxml"
+version = "6.0.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/c5/ed/60eb6fa2923602fba988d9ca7c5cdbd7cf25faa795162ed538b527a35411/lxml-6.0.0.tar.gz", hash = "sha256:032e65120339d44cdc3efc326c9f660f5f7205f3a535c1fdbf898b29ea01fb72", size = 4096938, upload-time = "2025-06-26T16:28:19.373Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/4b/e9/9c3ca02fbbb7585116c2e274b354a2d92b5c70561687dd733ec7b2018490/lxml-6.0.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:35bc626eec405f745199200ccb5c6b36f202675d204aa29bb52e27ba2b71dea8", size = 8399057, upload-time = "2025-06-26T16:25:02.169Z" },
+ { url = "https://files.pythonhosted.org/packages/86/25/10a6e9001191854bf283515020f3633b1b1f96fd1b39aa30bf8fff7aa666/lxml-6.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:246b40f8a4aec341cbbf52617cad8ab7c888d944bfe12a6abd2b1f6cfb6f6082", size = 4569676, upload-time = "2025-06-26T16:25:05.431Z" },
+ { url = "https://files.pythonhosted.org/packages/f5/a5/378033415ff61d9175c81de23e7ad20a3ffb614df4ffc2ffc86bc6746ffd/lxml-6.0.0-cp310-cp310-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:2793a627e95d119e9f1e19720730472f5543a6d84c50ea33313ce328d870f2dd", size = 5291361, upload-time = "2025-06-26T16:25:07.901Z" },
+ { url = "https://files.pythonhosted.org/packages/5a/a6/19c87c4f3b9362b08dc5452a3c3bce528130ac9105fc8fff97ce895ce62e/lxml-6.0.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:46b9ed911f36bfeb6338e0b482e7fe7c27d362c52fde29f221fddbc9ee2227e7", size = 5008290, upload-time = "2025-06-28T18:47:13.196Z" },
+ { url = "https://files.pythonhosted.org/packages/09/d1/e9b7ad4b4164d359c4d87ed8c49cb69b443225cb495777e75be0478da5d5/lxml-6.0.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2b4790b558bee331a933e08883c423f65bbcd07e278f91b2272489e31ab1e2b4", size = 5163192, upload-time = "2025-06-28T18:47:17.279Z" },
+ { url = "https://files.pythonhosted.org/packages/56/d6/b3eba234dc1584744b0b374a7f6c26ceee5dc2147369a7e7526e25a72332/lxml-6.0.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e2030956cf4886b10be9a0285c6802e078ec2391e1dd7ff3eb509c2c95a69b76", size = 5076973, upload-time = "2025-06-26T16:25:10.936Z" },
+ { url = "https://files.pythonhosted.org/packages/8e/47/897142dd9385dcc1925acec0c4afe14cc16d310ce02c41fcd9010ac5d15d/lxml-6.0.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d23854ecf381ab1facc8f353dcd9adeddef3652268ee75297c1164c987c11dc", size = 5297795, upload-time = "2025-06-26T16:25:14.282Z" },
+ { url = "https://files.pythonhosted.org/packages/fb/db/551ad84515c6f415cea70193a0ff11d70210174dc0563219f4ce711655c6/lxml-6.0.0-cp310-cp310-manylinux_2_31_armv7l.whl", hash = "sha256:43fe5af2d590bf4691531b1d9a2495d7aab2090547eaacd224a3afec95706d76", size = 4776547, upload-time = "2025-06-26T16:25:17.123Z" },
+ { url = "https://files.pythonhosted.org/packages/e0/14/c4a77ab4f89aaf35037a03c472f1ccc54147191888626079bd05babd6808/lxml-6.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:74e748012f8c19b47f7d6321ac929a9a94ee92ef12bc4298c47e8b7219b26541", size = 5124904, upload-time = "2025-06-26T16:25:19.485Z" },
+ { url = "https://files.pythonhosted.org/packages/70/b4/12ae6a51b8da106adec6a2e9c60f532350a24ce954622367f39269e509b1/lxml-6.0.0-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:43cfbb7db02b30ad3926e8fceaef260ba2fb7df787e38fa2df890c1ca7966c3b", size = 4805804, upload-time = "2025-06-26T16:25:21.949Z" },
+ { url = "https://files.pythonhosted.org/packages/a9/b6/2e82d34d49f6219cdcb6e3e03837ca5fb8b7f86c2f35106fb8610ac7f5b8/lxml-6.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:34190a1ec4f1e84af256495436b2d196529c3f2094f0af80202947567fdbf2e7", size = 5323477, upload-time = "2025-06-26T16:25:24.475Z" },
+ { url = "https://files.pythonhosted.org/packages/a1/e6/b83ddc903b05cd08a5723fefd528eee84b0edd07bdf87f6c53a1fda841fd/lxml-6.0.0-cp310-cp310-win32.whl", hash = "sha256:5967fe415b1920a3877a4195e9a2b779249630ee49ece22021c690320ff07452", size = 3613840, upload-time = "2025-06-26T16:25:27.345Z" },
+ { url = "https://files.pythonhosted.org/packages/40/af/874fb368dd0c663c030acb92612341005e52e281a102b72a4c96f42942e1/lxml-6.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:f3389924581d9a770c6caa4df4e74b606180869043b9073e2cec324bad6e306e", size = 3993584, upload-time = "2025-06-26T16:25:29.391Z" },
+ { url = "https://files.pythonhosted.org/packages/4a/f4/d296bc22c17d5607653008f6dd7b46afdfda12efd31021705b507df652bb/lxml-6.0.0-cp310-cp310-win_arm64.whl", hash = "sha256:522fe7abb41309e9543b0d9b8b434f2b630c5fdaf6482bee642b34c8c70079c8", size = 3681400, upload-time = "2025-06-26T16:25:31.421Z" },
+ { url = "https://files.pythonhosted.org/packages/66/e1/2c22a3cff9e16e1d717014a1e6ec2bf671bf56ea8716bb64466fcf820247/lxml-6.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:dbdd7679a6f4f08152818043dbb39491d1af3332128b3752c3ec5cebc0011a72", size = 3898804, upload-time = "2025-06-26T16:27:59.751Z" },
+ { url = "https://files.pythonhosted.org/packages/2b/3a/d68cbcb4393a2a0a867528741fafb7ce92dac5c9f4a1680df98e5e53e8f5/lxml-6.0.0-pp310-pypy310_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:40442e2a4456e9910875ac12951476d36c0870dcb38a68719f8c4686609897c4", size = 4216406, upload-time = "2025-06-28T18:47:45.518Z" },
+ { url = "https://files.pythonhosted.org/packages/15/8f/d9bfb13dff715ee3b2a1ec2f4a021347ea3caf9aba93dea0cfe54c01969b/lxml-6.0.0-pp310-pypy310_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:db0efd6bae1c4730b9c863fc4f5f3c0fa3e8f05cae2c44ae141cb9dfc7d091dc", size = 4326455, upload-time = "2025-06-28T18:47:48.411Z" },
+ { url = "https://files.pythonhosted.org/packages/01/8b/fde194529ee8a27e6f5966d7eef05fa16f0567e4a8e8abc3b855ef6b3400/lxml-6.0.0-pp310-pypy310_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9ab542c91f5a47aaa58abdd8ea84b498e8e49fe4b883d67800017757a3eb78e8", size = 4268788, upload-time = "2025-06-26T16:28:02.776Z" },
+ { url = "https://files.pythonhosted.org/packages/99/a8/3b8e2581b4f8370fc9e8dc343af4abdfadd9b9229970fc71e67bd31c7df1/lxml-6.0.0-pp310-pypy310_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:013090383863b72c62a702d07678b658fa2567aa58d373d963cca245b017e065", size = 4411394, upload-time = "2025-06-26T16:28:05.179Z" },
+ { url = "https://files.pythonhosted.org/packages/e7/a5/899a4719e02ff4383f3f96e5d1878f882f734377f10dfb69e73b5f223e44/lxml-6.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:c86df1c9af35d903d2b52d22ea3e66db8058d21dc0f59842ca5deb0595921141", size = 3517946, upload-time = "2025-06-26T16:28:07.665Z" },
+]
+
+[[package]]
+name = "markdown"
+version = "3.8.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/d7/c2/4ab49206c17f75cb08d6311171f2d65798988db4360c4d1485bd0eedd67c/markdown-3.8.2.tar.gz", hash = "sha256:247b9a70dd12e27f67431ce62523e675b866d254f900c4fe75ce3dda62237c45", size = 362071, upload-time = "2025-06-19T17:12:44.483Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/96/2b/34cc11786bc00d0f04d0f5fdc3a2b1ae0b6239eef72d3d345805f9ad92a1/markdown-3.8.2-py3-none-any.whl", hash = "sha256:5c83764dbd4e00bdd94d85a19b8d55ccca20fe35b2e678a1422b380324dd5f24", size = 106827, upload-time = "2025-06-19T17:12:42.994Z" },
+]
+
+[[package]]
+name = "markupsafe"
+version = "3.0.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/b2/97/5d42485e71dfc078108a86d6de8fa46db44a1a9295e89c5d6d4a06e23a62/markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0", size = 20537, upload-time = "2024-10-18T15:21:54.129Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/04/90/d08277ce111dd22f77149fd1a5d4653eeb3b3eaacbdfcbae5afb2600eebd/MarkupSafe-3.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8", size = 14357, upload-time = "2024-10-18T15:20:51.44Z" },
+ { url = "https://files.pythonhosted.org/packages/04/e1/6e2194baeae0bca1fae6629dc0cbbb968d4d941469cbab11a3872edff374/MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158", size = 12393, upload-time = "2024-10-18T15:20:52.426Z" },
+ { url = "https://files.pythonhosted.org/packages/1d/69/35fa85a8ece0a437493dc61ce0bb6d459dcba482c34197e3efc829aa357f/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38a9ef736c01fccdd6600705b09dc574584b89bea478200c5fbf112a6b0d5579", size = 21732, upload-time = "2024-10-18T15:20:53.578Z" },
+ { url = "https://files.pythonhosted.org/packages/22/35/137da042dfb4720b638d2937c38a9c2df83fe32d20e8c8f3185dbfef05f7/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbcb445fa71794da8f178f0f6d66789a28d7319071af7a496d4d507ed566270d", size = 20866, upload-time = "2024-10-18T15:20:55.06Z" },
+ { url = "https://files.pythonhosted.org/packages/29/28/6d029a903727a1b62edb51863232152fd335d602def598dade38996887f0/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57cb5a3cf367aeb1d316576250f65edec5bb3be939e9247ae594b4bcbc317dfb", size = 20964, upload-time = "2024-10-18T15:20:55.906Z" },
+ { url = "https://files.pythonhosted.org/packages/cc/cd/07438f95f83e8bc028279909d9c9bd39e24149b0d60053a97b2bc4f8aa51/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:3809ede931876f5b2ec92eef964286840ed3540dadf803dd570c3b7e13141a3b", size = 21977, upload-time = "2024-10-18T15:20:57.189Z" },
+ { url = "https://files.pythonhosted.org/packages/29/01/84b57395b4cc062f9c4c55ce0df7d3108ca32397299d9df00fedd9117d3d/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e07c3764494e3776c602c1e78e298937c3315ccc9043ead7e685b7f2b8d47b3c", size = 21366, upload-time = "2024-10-18T15:20:58.235Z" },
+ { url = "https://files.pythonhosted.org/packages/bd/6e/61ebf08d8940553afff20d1fb1ba7294b6f8d279df9fd0c0db911b4bbcfd/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b424c77b206d63d500bcb69fa55ed8d0e6a3774056bdc4839fc9298a7edca171", size = 21091, upload-time = "2024-10-18T15:20:59.235Z" },
+ { url = "https://files.pythonhosted.org/packages/11/23/ffbf53694e8c94ebd1e7e491de185124277964344733c45481f32ede2499/MarkupSafe-3.0.2-cp310-cp310-win32.whl", hash = "sha256:fcabf5ff6eea076f859677f5f0b6b5c1a51e70a376b0579e0eadef8db48c6b50", size = 15065, upload-time = "2024-10-18T15:21:00.307Z" },
+ { url = "https://files.pythonhosted.org/packages/44/06/e7175d06dd6e9172d4a69a72592cb3f7a996a9c396eee29082826449bbc3/MarkupSafe-3.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:6af100e168aa82a50e186c82875a5893c5597a0c1ccdb0d8b40240b1f28b969a", size = 15514, upload-time = "2024-10-18T15:21:01.122Z" },
+]
+
+[[package]]
+name = "mccabe"
+version = "0.7.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e7/ff/0ffefdcac38932a54d2b5eed4e0ba8a408f215002cd178ad1df0f2806ff8/mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325", size = 9658, upload-time = "2022-01-24T01:14:51.113Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/27/1a/1f68f9ba0c207934b35b86a8ca3aad8395a3d6dd7921c0686e23853ff5a9/mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e", size = 7350, upload-time = "2022-01-24T01:14:49.62Z" },
+]
+
+[[package]]
+name = "mpmath"
+version = "1.3.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e0/47/dd32fa426cc72114383ac549964eecb20ecfd886d1e5ccf5340b55b02f57/mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", size = 508106, upload-time = "2023-03-07T16:47:11.061Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" },
+]
+
+[[package]]
+name = "mujoco"
+version = "2.3.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "absl-py" },
+ { name = "glfw" },
+ { name = "numpy" },
+ { name = "pyopengl" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/7c/03/7ce6078745085febd22fc4586a63016a71125031b97883d456ce1d64e5ed/mujoco-2.3.3.zip", hash = "sha256:8bd074d3c5d9d25416cf2a5b82b337a7431a6e20edbd0da7fbc05ee5255c1aaa", size = 633278, upload-time = "2023-03-20T18:23:59.89Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/21/2c/f59255b4fbd159a374c4077721bc5baac11afcc15b8a28f8c7658ac89df5/mujoco-2.3.3-2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:7b95a0b7ae8bb9e36d04ba475a950025791c43087845235bb92bd2dd1787589a", size = 4308671, upload-time = "2023-03-20T18:09:30.955Z" },
+ { url = "https://files.pythonhosted.org/packages/31/20/afc0ef5d5b9d96f3853458329f39518e638f41c0439c94ae2b97a9ab9f3a/mujoco-2.3.3-2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8b9f8e9e6d47fe60f96dc54be780a66a38d5ec2ff94d091ad54c6f87468b4b6a", size = 4177357, upload-time = "2023-03-20T18:09:34.016Z" },
+ { url = "https://files.pythonhosted.org/packages/6d/da/27b0ef31aa23f64c21e7129ddf378c2548be45d87433f89e2bea4cafc811/mujoco-2.3.3-2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9cbd3b60ac30f0b6661de050ca3e1de906b9c28ed9d084bdd708c141953247d", size = 3995762, upload-time = "2023-03-20T18:09:36.335Z" },
+ { url = "https://files.pythonhosted.org/packages/6d/27/90cc9b4f88c5b797417e1fbeacb7590cd85f7e464a8ab79f60c885708e39/mujoco-2.3.3-2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c7fd195bca86102788d86dbc2773c59c1f0ee3933b35eb0f6a86f0f2aeb5065", size = 4270849, upload-time = "2023-03-20T18:09:38.593Z" },
+ { url = "https://files.pythonhosted.org/packages/cd/b3/e9119ebbbe9ea830e6c8ab7eafc0de7c82b38d6b71d2b2e38ba20e43a1b7/mujoco-2.3.3-2-cp310-cp310-win_amd64.whl", hash = "sha256:f3595e992770eff3f842cb80f7eb2b7b1b3e78995b6ecc247f98036da17ef74f", size = 3190540, upload-time = "2023-03-20T18:09:41.101Z" },
+]
+
+[[package]]
+name = "mypy-extensions"
+version = "1.1.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" },
+]
+
+[[package]]
+name = "networkx"
+version = "3.4.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263, upload-time = "2024-10-21T12:39:36.247Z" },
+]
+
+[[package]]
+name = "numpy"
+version = "1.26.4"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/65/6e/09db70a523a96d25e115e71cc56a6f9031e7b8cd166c1ac8438307c14058/numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010", size = 15786129, upload-time = "2024-02-06T00:26:44.495Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/a7/94/ace0fdea5241a27d13543ee117cbc65868e82213fb31a8eb7fe9ff23f313/numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0", size = 20631468, upload-time = "2024-02-05T23:48:01.194Z" },
+ { url = "https://files.pythonhosted.org/packages/20/f7/b24208eba89f9d1b58c1668bc6c8c4fd472b20c45573cb767f59d49fb0f6/numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a", size = 13966411, upload-time = "2024-02-05T23:48:29.038Z" },
+ { url = "https://files.pythonhosted.org/packages/fc/a5/4beee6488160798683eed5bdb7eead455892c3b4e1f78d79d8d3f3b084ac/numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4", size = 14219016, upload-time = "2024-02-05T23:48:54.098Z" },
+ { url = "https://files.pythonhosted.org/packages/4b/d7/ecf66c1cd12dc28b4040b15ab4d17b773b87fa9d29ca16125de01adb36cd/numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f", size = 18240889, upload-time = "2024-02-05T23:49:25.361Z" },
+ { url = "https://files.pythonhosted.org/packages/24/03/6f229fe3187546435c4f6f89f6d26c129d4f5bed40552899fcf1f0bf9e50/numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a", size = 13876746, upload-time = "2024-02-05T23:49:51.983Z" },
+ { url = "https://files.pythonhosted.org/packages/39/fe/39ada9b094f01f5a35486577c848fe274e374bbf8d8f472e1423a0bbd26d/numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2", size = 18078620, upload-time = "2024-02-05T23:50:22.515Z" },
+ { url = "https://files.pythonhosted.org/packages/d5/ef/6ad11d51197aad206a9ad2286dc1aac6a378059e06e8cf22cd08ed4f20dc/numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07", size = 5972659, upload-time = "2024-02-05T23:50:35.834Z" },
+ { url = "https://files.pythonhosted.org/packages/19/77/538f202862b9183f54108557bfda67e17603fc560c384559e769321c9d92/numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5", size = 15808905, upload-time = "2024-02-05T23:51:03.701Z" },
+]
+
+[[package]]
+name = "nvidia-cublas-cu12"
+version = "12.1.3.1"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/37/6d/121efd7382d5b0284239f4ab1fc1590d86d34ed4a4a2fdb13b30ca8e5740/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728", size = 410594774, upload-time = "2023-04-19T15:50:03.519Z" },
+]
+
+[[package]]
+name = "nvidia-cuda-cupti-cu12"
+version = "12.1.105"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/7e/00/6b218edd739ecfc60524e585ba8e6b00554dd908de2c9c66c1af3e44e18d/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e", size = 14109015, upload-time = "2023-04-19T15:47:32.502Z" },
+]
+
+[[package]]
+name = "nvidia-cuda-nvrtc-cu12"
+version = "12.1.105"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b6/9f/c64c03f49d6fbc56196664d05dba14e3a561038a81a638eeb47f4d4cfd48/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2", size = 23671734, upload-time = "2023-04-19T15:48:32.42Z" },
+]
+
+[[package]]
+name = "nvidia-cuda-runtime-cu12"
+version = "12.1.105"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/eb/d5/c68b1d2cdfcc59e72e8a5949a37ddb22ae6cade80cd4a57a84d4c8b55472/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40", size = 823596, upload-time = "2023-04-19T15:47:22.471Z" },
+]
+
+[[package]]
+name = "nvidia-cudnn-cu12"
+version = "8.9.2.26"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "nvidia-cublas-cu12" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ff/74/a2e2be7fb83aaedec84f391f082cf765dfb635e7caa9b49065f73e4835d8/nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9", size = 731725872, upload-time = "2023-06-01T19:24:57.328Z" },
+]
+
+[[package]]
+name = "nvidia-cufft-cu12"
+version = "11.0.2.54"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/86/94/eb540db023ce1d162e7bea9f8f5aa781d57c65aed513c33ee9a5123ead4d/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56", size = 121635161, upload-time = "2023-04-19T15:50:46Z" },
+]
+
+[[package]]
+name = "nvidia-curand-cu12"
+version = "10.3.2.106"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/44/31/4890b1c9abc496303412947fc7dcea3d14861720642b49e8ceed89636705/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0", size = 56467784, upload-time = "2023-04-19T15:51:04.804Z" },
+]
+
+[[package]]
+name = "nvidia-cusolver-cu12"
+version = "11.4.5.107"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "nvidia-cublas-cu12" },
+ { name = "nvidia-cusparse-cu12" },
+ { name = "nvidia-nvjitlink-cu12" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928, upload-time = "2023-04-19T15:51:25.781Z" },
+]
+
+[[package]]
+name = "nvidia-cusparse-cu12"
+version = "12.1.0.106"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "nvidia-nvjitlink-cu12" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278, upload-time = "2023-04-19T15:51:49.939Z" },
+]
+
+[[package]]
+name = "nvidia-nccl-cu12"
+version = "2.19.3"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/38/00/d0d4e48aef772ad5aebcf70b73028f88db6e5640b36c38e90445b7a57c45/nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl", hash = "sha256:a9734707a2c96443331c1e48c717024aa6678a0e2a4cb66b2c364d18cee6b48d", size = 165987969, upload-time = "2023-10-24T16:16:24.789Z" },
+]
+
+[[package]]
+name = "nvidia-nvjitlink-cu12"
+version = "12.9.86"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/46/0c/c75bbfb967457a0b7670b8ad267bfc4fffdf341c074e0a80db06c24ccfd4/nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:e3f1171dbdc83c5932a45f0f4c99180a70de9bd2718c1ab77d14104f6d7147f9", size = 39748338, upload-time = "2025-06-05T20:10:25.613Z" },
+]
+
+[[package]]
+name = "nvidia-nvtx-cu12"
+version = "12.1.105"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/da/d3/8057f0587683ed2fcd4dbfbdfdfa807b9160b809976099d36b8f60d08f03/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5", size = 99138, upload-time = "2023-04-19T15:48:43.556Z" },
+]
+
+[[package]]
+name = "oauthlib"
+version = "3.3.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/0b/5f/19930f824ffeb0ad4372da4812c50edbd1434f678c90c2733e1188edfc63/oauthlib-3.3.1.tar.gz", hash = "sha256:0f0f8aa759826a193cf66c12ea1af1637f87b9b4622d46e866952bb022e538c9", size = 185918, upload-time = "2025-06-19T22:48:08.269Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/be/9c/92789c596b8df838baa98fa71844d84283302f7604ed565dafe5a6b5041a/oauthlib-3.3.1-py3-none-any.whl", hash = "sha256:88119c938d2b8fb88561af5f6ee0eec8cc8d552b7bb1f712743136eb7523b7a1", size = 160065, upload-time = "2025-06-19T22:48:06.508Z" },
+]
+
+[[package]]
+name = "oprl"
+version = "0.1.0"
+source = { editable = "." }
+dependencies = [
+ { name = "dm-control" },
+ { name = "gymnasium" },
+ { name = "mujoco" },
+ { name = "numpy" },
+ { name = "packaging" },
+ { name = "pika" },
+ { name = "pydantic-settings" },
+ { name = "ruff" },
+ { name = "tensorboard" },
+ { name = "torch" },
+]
+
+[package.optional-dependencies]
+dev = [
+ { name = "black" },
+ { name = "flake8" },
+ { name = "pytest" },
+]
+
+[package.metadata]
+requires-dist = [
+ { name = "black", marker = "extra == 'dev'" },
+ { name = "dm-control", specifier = "==1.0.11" },
+ { name = "flake8", marker = "extra == 'dev'" },
+ { name = "gymnasium", specifier = "==0.28.1" },
+ { name = "mujoco", specifier = "==2.3.3" },
+ { name = "numpy", specifier = "==1.26.4" },
+ { name = "packaging", specifier = "==23.2" },
+ { name = "pika", specifier = "==1.3.2" },
+ { name = "pydantic-settings", specifier = "==2.10.1" },
+ { name = "pytest", marker = "extra == 'dev'", specifier = ">=6.0" },
+ { name = "ruff", specifier = ">=0.12.3" },
+ { name = "tensorboard", specifier = "==2.15.1" },
+ { name = "torch", specifier = "==2.2.2" },
+]
+provides-extras = ["dev"]
+
+[[package]]
+name = "packaging"
+version = "23.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/fb/2b/9b9c33ffed44ee921d0967086d653047286054117d584f1b1a7c22ceaf7b/packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5", size = 146714, upload-time = "2023-10-01T13:50:05.279Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ec/1a/610693ac4ee14fcdf2d9bf3c493370e4f2ef7ae2e19217d7a237ff42367d/packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7", size = 53011, upload-time = "2023-10-01T13:50:03.745Z" },
+]
+
+[[package]]
+name = "pathspec"
+version = "0.12.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" },
+]
+
+[[package]]
+name = "pika"
+version = "1.3.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/db/db/d4102f356af18f316c67f2cead8ece307f731dd63140e2c71f170ddacf9b/pika-1.3.2.tar.gz", hash = "sha256:b2a327ddddf8570b4965b3576ac77091b850262d34ce8c1d8cb4e4146aa4145f", size = 145029, upload-time = "2023-05-05T14:25:43.368Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/f9/f3/f412836ec714d36f0f4ab581b84c491e3f42c6b5b97a6c6ed1817f3c16d0/pika-1.3.2-py3-none-any.whl", hash = "sha256:0779a7c1fafd805672796085560d290213a465e4f6f76a6fb19e378d8041a14f", size = 155415, upload-time = "2023-05-05T14:25:41.484Z" },
+]
+
+[[package]]
+name = "platformdirs"
+version = "4.3.8"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/fe/8b/3c73abc9c759ecd3f1f7ceff6685840859e8070c4d947c93fae71f6a0bf2/platformdirs-4.3.8.tar.gz", hash = "sha256:3d512d96e16bcb959a814c9f348431070822a6496326a4be0911c40b5a74c2bc", size = 21362, upload-time = "2025-05-07T22:47:42.121Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/fe/39/979e8e21520d4e47a0bbe349e2713c0aac6f3d853d0e5b34d76206c439aa/platformdirs-4.3.8-py3-none-any.whl", hash = "sha256:ff7059bb7eb1179e2685604f4aaf157cfd9535242bd23742eadc3c13542139b4", size = 18567, upload-time = "2025-05-07T22:47:40.376Z" },
+]
+
+[[package]]
+name = "pluggy"
+version = "1.6.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
+]
+
+[[package]]
+name = "protobuf"
+version = "4.23.4"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/d3/1c/de86d82a5fc780feca36ef52c1231823bb3140266af8a04ed6286957aa6e/protobuf-4.23.4.tar.gz", hash = "sha256:ccd9430c0719dce806b93f89c91de7977304729e55377f872a92465d548329a9", size = 400173, upload-time = "2023-07-06T23:28:22.071Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/6f/a7/872807299eb114956c665fb1717ce106a8874db08a724651ac4f78c1198c/protobuf-4.23.4-cp310-abi3-win32.whl", hash = "sha256:5fea3c64d41ea5ecf5697b83e41d09b9589e6f20b677ab3c48e5f242d9b7897b", size = 402981, upload-time = "2023-07-06T23:27:57.401Z" },
+ { url = "https://files.pythonhosted.org/packages/80/70/dc63d340d27b8ff22022d7dd14b8d6d68b479a003eacdc4507150a286d9a/protobuf-4.23.4-cp310-abi3-win_amd64.whl", hash = "sha256:7b19b6266d92ca6a2a87effa88ecc4af73ebc5cfde194dc737cf8ef23a9a3b12", size = 422467, upload-time = "2023-07-06T23:28:00.387Z" },
+ { url = "https://files.pythonhosted.org/packages/cb/d3/a164038605494d49acc4f9cda1c0bc200b96382c53edd561387263bb181d/protobuf-4.23.4-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:8547bf44fe8cec3c69e3042f5c4fb3e36eb2a7a013bb0a44c018fc1e427aafbd", size = 400308, upload-time = "2023-07-06T23:28:02.356Z" },
+ { url = "https://files.pythonhosted.org/packages/71/42/3a7fc57f360f728f38eca6656e8d00edaf22bc0ffc35dd2936f23e5fbb3e/protobuf-4.23.4-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:fee88269a090ada09ca63551bf2f573eb2424035bcf2cb1b121895b01a46594a", size = 303455, upload-time = "2023-07-06T23:28:04.292Z" },
+ { url = "https://files.pythonhosted.org/packages/01/cb/445b3e465abdb8042a41957dc8f60c54620dc7540dbcf9b458a921531ca2/protobuf-4.23.4-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:effeac51ab79332d44fba74660d40ae79985901ac21bca408f8dc335a81aa597", size = 304498, upload-time = "2023-07-06T23:28:06.277Z" },
+ { url = "https://files.pythonhosted.org/packages/b0/07/fb712cce15ba456f7c24b82b97c8a7db2233f07037ffe61c9011660c592a/protobuf-4.23.4-py3-none-any.whl", hash = "sha256:e9d0be5bf34b275b9f87ba7407796556abeeba635455d036c7351f7c183ef8ff", size = 173332, upload-time = "2023-07-06T23:28:20.053Z" },
+]
+
+[[package]]
+name = "pyasn1"
+version = "0.6.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322, upload-time = "2024-09-10T22:41:42.55Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135, upload-time = "2024-09-11T16:00:36.122Z" },
+]
+
+[[package]]
+name = "pyasn1-modules"
+version = "0.4.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "pyasn1" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892, upload-time = "2025-03-28T02:41:22.17Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload-time = "2025-03-28T02:41:19.028Z" },
+]
+
+[[package]]
+name = "pycodestyle"
+version = "2.14.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/11/e0/abfd2a0d2efe47670df87f3e3a0e2edda42f055053c85361f19c0e2c1ca8/pycodestyle-2.14.0.tar.gz", hash = "sha256:c4b5b517d278089ff9d0abdec919cd97262a3367449ea1c8b49b91529167b783", size = 39472, upload-time = "2025-06-20T18:49:48.75Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d7/27/a58ddaf8c588a3ef080db9d0b7e0b97215cee3a45df74f3a94dbbf5c893a/pycodestyle-2.14.0-py2.py3-none-any.whl", hash = "sha256:dd6bf7cb4ee77f8e016f9c8e74a35ddd9f67e1d5fd4184d86c3b98e07099f42d", size = 31594, upload-time = "2025-06-20T18:49:47.491Z" },
+]
+
+[[package]]
+name = "pydantic"
+version = "2.11.7"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "annotated-types" },
+ { name = "pydantic-core" },
+ { name = "typing-extensions" },
+ { name = "typing-inspection" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/00/dd/4325abf92c39ba8623b5af936ddb36ffcfe0beae70405d456ab1fb2f5b8c/pydantic-2.11.7.tar.gz", hash = "sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db", size = 788350, upload-time = "2025-06-14T08:33:17.137Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/6a/c0/ec2b1c8712ca690e5d61979dee872603e92b8a32f94cc1b72d53beab008a/pydantic-2.11.7-py3-none-any.whl", hash = "sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b", size = 444782, upload-time = "2025-06-14T08:33:14.905Z" },
+]
+
+[[package]]
+name = "pydantic-core"
+version = "2.33.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/ad/88/5f2260bdfae97aabf98f1778d43f69574390ad787afb646292a638c923d4/pydantic_core-2.33.2.tar.gz", hash = "sha256:7cb8bc3605c29176e1b105350d2e6474142d7c1bd1d9327c4a9bdb46bf827acc", size = 435195, upload-time = "2025-04-23T18:33:52.104Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e5/92/b31726561b5dae176c2d2c2dc43a9c5bfba5d32f96f8b4c0a600dd492447/pydantic_core-2.33.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2b3d326aaef0c0399d9afffeb6367d5e26ddc24d351dbc9c636840ac355dc5d8", size = 2028817, upload-time = "2025-04-23T18:30:43.919Z" },
+ { url = "https://files.pythonhosted.org/packages/a3/44/3f0b95fafdaca04a483c4e685fe437c6891001bf3ce8b2fded82b9ea3aa1/pydantic_core-2.33.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e5b2671f05ba48b94cb90ce55d8bdcaaedb8ba00cc5359f6810fc918713983d", size = 1861357, upload-time = "2025-04-23T18:30:46.372Z" },
+ { url = "https://files.pythonhosted.org/packages/30/97/e8f13b55766234caae05372826e8e4b3b96e7b248be3157f53237682e43c/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0069c9acc3f3981b9ff4cdfaf088e98d83440a4c7ea1bc07460af3d4dc22e72d", size = 1898011, upload-time = "2025-04-23T18:30:47.591Z" },
+ { url = "https://files.pythonhosted.org/packages/9b/a3/99c48cf7bafc991cc3ee66fd544c0aae8dc907b752f1dad2d79b1b5a471f/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d53b22f2032c42eaaf025f7c40c2e3b94568ae077a606f006d206a463bc69572", size = 1982730, upload-time = "2025-04-23T18:30:49.328Z" },
+ { url = "https://files.pythonhosted.org/packages/de/8e/a5b882ec4307010a840fb8b58bd9bf65d1840c92eae7534c7441709bf54b/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0405262705a123b7ce9f0b92f123334d67b70fd1f20a9372b907ce1080c7ba02", size = 2136178, upload-time = "2025-04-23T18:30:50.907Z" },
+ { url = "https://files.pythonhosted.org/packages/e4/bb/71e35fc3ed05af6834e890edb75968e2802fe98778971ab5cba20a162315/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b25d91e288e2c4e0662b8038a28c6a07eaac3e196cfc4ff69de4ea3db992a1b", size = 2736462, upload-time = "2025-04-23T18:30:52.083Z" },
+ { url = "https://files.pythonhosted.org/packages/31/0d/c8f7593e6bc7066289bbc366f2235701dcbebcd1ff0ef8e64f6f239fb47d/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6bdfe4b3789761f3bcb4b1ddf33355a71079858958e3a552f16d5af19768fef2", size = 2005652, upload-time = "2025-04-23T18:30:53.389Z" },
+ { url = "https://files.pythonhosted.org/packages/d2/7a/996d8bd75f3eda405e3dd219ff5ff0a283cd8e34add39d8ef9157e722867/pydantic_core-2.33.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:efec8db3266b76ef9607c2c4c419bdb06bf335ae433b80816089ea7585816f6a", size = 2113306, upload-time = "2025-04-23T18:30:54.661Z" },
+ { url = "https://files.pythonhosted.org/packages/ff/84/daf2a6fb2db40ffda6578a7e8c5a6e9c8affb251a05c233ae37098118788/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:031c57d67ca86902726e0fae2214ce6770bbe2f710dc33063187a68744a5ecac", size = 2073720, upload-time = "2025-04-23T18:30:56.11Z" },
+ { url = "https://files.pythonhosted.org/packages/77/fb/2258da019f4825128445ae79456a5499c032b55849dbd5bed78c95ccf163/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:f8de619080e944347f5f20de29a975c2d815d9ddd8be9b9b7268e2e3ef68605a", size = 2244915, upload-time = "2025-04-23T18:30:57.501Z" },
+ { url = "https://files.pythonhosted.org/packages/d8/7a/925ff73756031289468326e355b6fa8316960d0d65f8b5d6b3a3e7866de7/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:73662edf539e72a9440129f231ed3757faab89630d291b784ca99237fb94db2b", size = 2241884, upload-time = "2025-04-23T18:30:58.867Z" },
+ { url = "https://files.pythonhosted.org/packages/0b/b0/249ee6d2646f1cdadcb813805fe76265745c4010cf20a8eba7b0e639d9b2/pydantic_core-2.33.2-cp310-cp310-win32.whl", hash = "sha256:0a39979dcbb70998b0e505fb1556a1d550a0781463ce84ebf915ba293ccb7e22", size = 1910496, upload-time = "2025-04-23T18:31:00.078Z" },
+ { url = "https://files.pythonhosted.org/packages/66/ff/172ba8f12a42d4b552917aa65d1f2328990d3ccfc01d5b7c943ec084299f/pydantic_core-2.33.2-cp310-cp310-win_amd64.whl", hash = "sha256:b0379a2b24882fef529ec3b4987cb5d003b9cda32256024e6fe1586ac45fc640", size = 1955019, upload-time = "2025-04-23T18:31:01.335Z" },
+ { url = "https://files.pythonhosted.org/packages/30/68/373d55e58b7e83ce371691f6eaa7175e3a24b956c44628eb25d7da007917/pydantic_core-2.33.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5c4aa4e82353f65e548c476b37e64189783aa5384903bfea4f41580f255fddfa", size = 2023982, upload-time = "2025-04-23T18:32:53.14Z" },
+ { url = "https://files.pythonhosted.org/packages/a4/16/145f54ac08c96a63d8ed6442f9dec17b2773d19920b627b18d4f10a061ea/pydantic_core-2.33.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d946c8bf0d5c24bf4fe333af284c59a19358aa3ec18cb3dc4370080da1e8ad29", size = 1858412, upload-time = "2025-04-23T18:32:55.52Z" },
+ { url = "https://files.pythonhosted.org/packages/41/b1/c6dc6c3e2de4516c0bb2c46f6a373b91b5660312342a0cf5826e38ad82fa/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:87b31b6846e361ef83fedb187bb5b4372d0da3f7e28d85415efa92d6125d6e6d", size = 1892749, upload-time = "2025-04-23T18:32:57.546Z" },
+ { url = "https://files.pythonhosted.org/packages/12/73/8cd57e20afba760b21b742106f9dbdfa6697f1570b189c7457a1af4cd8a0/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa9d91b338f2df0508606f7009fde642391425189bba6d8c653afd80fd6bb64e", size = 2067527, upload-time = "2025-04-23T18:32:59.771Z" },
+ { url = "https://files.pythonhosted.org/packages/e3/d5/0bb5d988cc019b3cba4a78f2d4b3854427fc47ee8ec8e9eaabf787da239c/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2058a32994f1fde4ca0480ab9d1e75a0e8c87c22b53a3ae66554f9af78f2fe8c", size = 2108225, upload-time = "2025-04-23T18:33:04.51Z" },
+ { url = "https://files.pythonhosted.org/packages/f1/c5/00c02d1571913d496aabf146106ad8239dc132485ee22efe08085084ff7c/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:0e03262ab796d986f978f79c943fc5f620381be7287148b8010b4097f79a39ec", size = 2069490, upload-time = "2025-04-23T18:33:06.391Z" },
+ { url = "https://files.pythonhosted.org/packages/22/a8/dccc38768274d3ed3a59b5d06f59ccb845778687652daa71df0cab4040d7/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:1a8695a8d00c73e50bff9dfda4d540b7dee29ff9b8053e38380426a85ef10052", size = 2237525, upload-time = "2025-04-23T18:33:08.44Z" },
+ { url = "https://files.pythonhosted.org/packages/d4/e7/4f98c0b125dda7cf7ccd14ba936218397b44f50a56dd8c16a3091df116c3/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:fa754d1850735a0b0e03bcffd9d4b4343eb417e47196e4485d9cca326073a42c", size = 2238446, upload-time = "2025-04-23T18:33:10.313Z" },
+ { url = "https://files.pythonhosted.org/packages/ce/91/2ec36480fdb0b783cd9ef6795753c1dea13882f2e68e73bce76ae8c21e6a/pydantic_core-2.33.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a11c8d26a50bfab49002947d3d237abe4d9e4b5bdc8846a63537b6488e197808", size = 2066678, upload-time = "2025-04-23T18:33:12.224Z" },
+]
+
+[[package]]
+name = "pydantic-settings"
+version = "2.10.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "pydantic" },
+ { name = "python-dotenv" },
+ { name = "typing-inspection" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/68/85/1ea668bbab3c50071ca613c6ab30047fb36ab0da1b92fa8f17bbc38fd36c/pydantic_settings-2.10.1.tar.gz", hash = "sha256:06f0062169818d0f5524420a360d632d5857b83cffd4d42fe29597807a1614ee", size = 172583, upload-time = "2025-06-24T13:26:46.841Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/58/f0/427018098906416f580e3cf1366d3b1abfb408a0652e9f31600c24a1903c/pydantic_settings-2.10.1-py3-none-any.whl", hash = "sha256:a60952460b99cf661dc25c29c0ef171721f98bfcb52ef8d9ea4c943d7c8cc796", size = 45235, upload-time = "2025-06-24T13:26:45.485Z" },
+]
+
+[[package]]
+name = "pyflakes"
+version = "3.4.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/45/dc/fd034dc20b4b264b3d015808458391acbf9df40b1e54750ef175d39180b1/pyflakes-3.4.0.tar.gz", hash = "sha256:b24f96fafb7d2ab0ec5075b7350b3d2d2218eab42003821c06344973d3ea2f58", size = 64669, upload-time = "2025-06-20T18:45:27.834Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/c2/2f/81d580a0fb83baeb066698975cb14a618bdbed7720678566f1b046a95fe8/pyflakes-3.4.0-py2.py3-none-any.whl", hash = "sha256:f742a7dbd0d9cb9ea41e9a24a918996e8170c799fa528688d40dd582c8265f4f", size = 63551, upload-time = "2025-06-20T18:45:26.937Z" },
+]
+
+[[package]]
+name = "pygments"
+version = "2.19.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
+]
+
+[[package]]
+name = "pyopengl"
+version = "3.1.9"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/c0/42/71080db298df3ddb7e3090bfea8fd7c300894d8b10954c22f8719bd434eb/pyopengl-3.1.9.tar.gz", hash = "sha256:28ebd82c5f4491a418aeca9672dffb3adbe7d33b39eada4548a5b4e8c03f60c8", size = 1913642, upload-time = "2025-01-20T02:17:53.263Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/92/44/8634af40b0db528b5b37e901c0dc67321354880d251bf8965901d57693a5/PyOpenGL-3.1.9-py3-none-any.whl", hash = "sha256:15995fd3b0deb991376805da36137a4ae5aba6ddbb5e29ac1f35462d130a3f77", size = 3190341, upload-time = "2025-01-20T02:17:50.913Z" },
+]
+
+[[package]]
+name = "pyparsing"
+version = "3.2.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/bb/22/f1129e69d94ffff626bdb5c835506b3a5b4f3d070f17ea295e12c2c6f60f/pyparsing-3.2.3.tar.gz", hash = "sha256:b9c13f1ab8b3b542f72e28f634bad4de758ab3ce4546e4301970ad6fa77c38be", size = 1088608, upload-time = "2025-03-25T05:01:28.114Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl", hash = "sha256:a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf", size = 111120, upload-time = "2025-03-25T05:01:24.908Z" },
+]
+
+[[package]]
+name = "pytest"
+version = "8.4.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "colorama", marker = "sys_platform == 'win32'" },
+ { name = "exceptiongroup" },
+ { name = "iniconfig" },
+ { name = "packaging" },
+ { name = "pluggy" },
+ { name = "pygments" },
+ { name = "tomli" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" },
+]
+
+[[package]]
+name = "python-dotenv"
+version = "1.1.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f6/b0/4bc07ccd3572a2f9df7e6782f52b0c6c90dcbb803ac4a167702d7d0dfe1e/python_dotenv-1.1.1.tar.gz", hash = "sha256:a8a6399716257f45be6a007360200409fce5cda2661e3dec71d23dc15f6189ab", size = 41978, upload-time = "2025-06-24T04:21:07.341Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/5f/ed/539768cf28c661b5b068d66d96a2f155c4971a5d55684a514c1a0e0dec2f/python_dotenv-1.1.1-py3-none-any.whl", hash = "sha256:31f23644fe2602f88ff55e1f5c79ba497e01224ee7737937930c448e4d0e24dc", size = 20556, upload-time = "2025-06-24T04:21:06.073Z" },
+]
+
+[[package]]
+name = "requests"
+version = "2.32.4"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "certifi" },
+ { name = "charset-normalizer" },
+ { name = "idna" },
+ { name = "urllib3" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/e1/0a/929373653770d8a0d7ea76c37de6e41f11eb07559b103b1c02cafb3f7cf8/requests-2.32.4.tar.gz", hash = "sha256:27d0316682c8a29834d3264820024b62a36942083d52caf2f14c0591336d3422", size = 135258, upload-time = "2025-06-09T16:43:07.34Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c", size = 64847, upload-time = "2025-06-09T16:43:05.728Z" },
+]
+
+[[package]]
+name = "requests-oauthlib"
+version = "2.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "oauthlib" },
+ { name = "requests" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/42/f2/05f29bc3913aea15eb670be136045bf5c5bbf4b99ecb839da9b422bb2c85/requests-oauthlib-2.0.0.tar.gz", hash = "sha256:b3dffaebd884d8cd778494369603a9e7b58d29111bf6b41bdc2dcd87203af4e9", size = 55650, upload-time = "2024-03-22T20:32:29.939Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/3b/5d/63d4ae3b9daea098d5d6f5da83984853c1bbacd5dc826764b249fe119d24/requests_oauthlib-2.0.0-py2.py3-none-any.whl", hash = "sha256:7dd8a5c40426b779b0868c404bdef9768deccf22749cde15852df527e6269b36", size = 24179, upload-time = "2024-03-22T20:32:28.055Z" },
+]
+
+[[package]]
+name = "rsa"
+version = "4.9.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "pyasn1" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/da/8a/22b7beea3ee0d44b1916c0c1cb0ee3af23b700b6da9f04991899d0c555d4/rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75", size = 29034, upload-time = "2025-04-16T09:51:18.218Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/64/8d/0133e4eb4beed9e425d9a98ed6e081a55d195481b7632472be1af08d2f6b/rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762", size = 34696, upload-time = "2025-04-16T09:51:17.142Z" },
+]
+
+[[package]]
+name = "ruff"
+version = "0.12.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/c3/2a/43955b530c49684d3c38fcda18c43caf91e99204c2a065552528e0552d4f/ruff-0.12.3.tar.gz", hash = "sha256:f1b5a4b6668fd7b7ea3697d8d98857390b40c1320a63a178eee6be0899ea2d77", size = 4459341, upload-time = "2025-07-11T13:21:16.086Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e2/fd/b44c5115539de0d598d75232a1cc7201430b6891808df111b8b0506aae43/ruff-0.12.3-py3-none-linux_armv6l.whl", hash = "sha256:47552138f7206454eaf0c4fe827e546e9ddac62c2a3d2585ca54d29a890137a2", size = 10430499, upload-time = "2025-07-11T13:20:26.321Z" },
+ { url = "https://files.pythonhosted.org/packages/43/c5/9eba4f337970d7f639a37077be067e4ec80a2ad359e4cc6c5b56805cbc66/ruff-0.12.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:0a9153b000c6fe169bb307f5bd1b691221c4286c133407b8827c406a55282041", size = 11213413, upload-time = "2025-07-11T13:20:30.017Z" },
+ { url = "https://files.pythonhosted.org/packages/e2/2c/fac3016236cf1fe0bdc8e5de4f24c76ce53c6dd9b5f350d902549b7719b2/ruff-0.12.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa6b24600cf3b750e48ddb6057e901dd5b9aa426e316addb2a1af185a7509882", size = 10586941, upload-time = "2025-07-11T13:20:33.046Z" },
+ { url = "https://files.pythonhosted.org/packages/c5/0f/41fec224e9dfa49a139f0b402ad6f5d53696ba1800e0f77b279d55210ca9/ruff-0.12.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2506961bf6ead54887ba3562604d69cb430f59b42133d36976421bc8bd45901", size = 10783001, upload-time = "2025-07-11T13:20:35.534Z" },
+ { url = "https://files.pythonhosted.org/packages/0d/ca/dd64a9ce56d9ed6cad109606ac014860b1c217c883e93bf61536400ba107/ruff-0.12.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c4faaff1f90cea9d3033cbbcdf1acf5d7fb11d8180758feb31337391691f3df0", size = 10269641, upload-time = "2025-07-11T13:20:38.459Z" },
+ { url = "https://files.pythonhosted.org/packages/63/5c/2be545034c6bd5ce5bb740ced3e7014d7916f4c445974be11d2a406d5088/ruff-0.12.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40dced4a79d7c264389de1c59467d5d5cefd79e7e06d1dfa2c75497b5269a5a6", size = 11875059, upload-time = "2025-07-11T13:20:41.517Z" },
+ { url = "https://files.pythonhosted.org/packages/8e/d4/a74ef1e801ceb5855e9527dae105eaff136afcb9cc4d2056d44feb0e4792/ruff-0.12.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:0262d50ba2767ed0fe212aa7e62112a1dcbfd46b858c5bf7bbd11f326998bafc", size = 12658890, upload-time = "2025-07-11T13:20:44.442Z" },
+ { url = "https://files.pythonhosted.org/packages/13/c8/1057916416de02e6d7c9bcd550868a49b72df94e3cca0aeb77457dcd9644/ruff-0.12.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12371aec33e1a3758597c5c631bae9a5286f3c963bdfb4d17acdd2d395406687", size = 12232008, upload-time = "2025-07-11T13:20:47.374Z" },
+ { url = "https://files.pythonhosted.org/packages/f5/59/4f7c130cc25220392051fadfe15f63ed70001487eca21d1796db46cbcc04/ruff-0.12.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:560f13b6baa49785665276c963edc363f8ad4b4fc910a883e2625bdb14a83a9e", size = 11499096, upload-time = "2025-07-11T13:20:50.348Z" },
+ { url = "https://files.pythonhosted.org/packages/d4/01/a0ad24a5d2ed6be03a312e30d32d4e3904bfdbc1cdbe63c47be9d0e82c79/ruff-0.12.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:023040a3499f6f974ae9091bcdd0385dd9e9eb4942f231c23c57708147b06311", size = 11688307, upload-time = "2025-07-11T13:20:52.945Z" },
+ { url = "https://files.pythonhosted.org/packages/93/72/08f9e826085b1f57c9a0226e48acb27643ff19b61516a34c6cab9d6ff3fa/ruff-0.12.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:883d844967bffff5ab28bba1a4d246c1a1b2933f48cb9840f3fdc5111c603b07", size = 10661020, upload-time = "2025-07-11T13:20:55.799Z" },
+ { url = "https://files.pythonhosted.org/packages/80/a0/68da1250d12893466c78e54b4a0ff381370a33d848804bb51279367fc688/ruff-0.12.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2120d3aa855ff385e0e562fdee14d564c9675edbe41625c87eeab744a7830d12", size = 10246300, upload-time = "2025-07-11T13:20:58.222Z" },
+ { url = "https://files.pythonhosted.org/packages/6a/22/5f0093d556403e04b6fd0984fc0fb32fbb6f6ce116828fd54306a946f444/ruff-0.12.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6b16647cbb470eaf4750d27dddc6ebf7758b918887b56d39e9c22cce2049082b", size = 11263119, upload-time = "2025-07-11T13:21:01.503Z" },
+ { url = "https://files.pythonhosted.org/packages/92/c9/f4c0b69bdaffb9968ba40dd5fa7df354ae0c73d01f988601d8fac0c639b1/ruff-0.12.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e1417051edb436230023575b149e8ff843a324557fe0a265863b7602df86722f", size = 11746990, upload-time = "2025-07-11T13:21:04.524Z" },
+ { url = "https://files.pythonhosted.org/packages/fe/84/7cc7bd73924ee6be4724be0db5414a4a2ed82d06b30827342315a1be9e9c/ruff-0.12.3-py3-none-win32.whl", hash = "sha256:dfd45e6e926deb6409d0616078a666ebce93e55e07f0fb0228d4b2608b2c248d", size = 10589263, upload-time = "2025-07-11T13:21:07.148Z" },
+ { url = "https://files.pythonhosted.org/packages/07/87/c070f5f027bd81f3efee7d14cb4d84067ecf67a3a8efb43aadfc72aa79a6/ruff-0.12.3-py3-none-win_amd64.whl", hash = "sha256:a946cf1e7ba3209bdef039eb97647f1c77f6f540e5845ec9c114d3af8df873e7", size = 11695072, upload-time = "2025-07-11T13:21:11.004Z" },
+ { url = "https://files.pythonhosted.org/packages/e0/30/f3eaf6563c637b6e66238ed6535f6775480db973c836336e4122161986fc/ruff-0.12.3-py3-none-win_arm64.whl", hash = "sha256:5f9c7c9c8f84c2d7f27e93674d27136fbf489720251544c4da7fb3d742e011b1", size = 10805855, upload-time = "2025-07-11T13:21:13.547Z" },
+]
+
+[[package]]
+name = "scipy"
+version = "1.15.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "numpy" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/0f/37/6964b830433e654ec7485e45a00fc9a27cf868d622838f6b6d9c5ec0d532/scipy-1.15.3.tar.gz", hash = "sha256:eae3cf522bc7df64b42cad3925c876e1b0b6c35c1337c93e12c0f366f55b0eaf", size = 59419214, upload-time = "2025-05-08T16:13:05.955Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/78/2f/4966032c5f8cc7e6a60f1b2e0ad686293b9474b65246b0c642e3ef3badd0/scipy-1.15.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:a345928c86d535060c9c2b25e71e87c39ab2f22fc96e9636bd74d1dbf9de448c", size = 38702770, upload-time = "2025-05-08T16:04:20.849Z" },
+ { url = "https://files.pythonhosted.org/packages/a0/6e/0c3bf90fae0e910c274db43304ebe25a6b391327f3f10b5dcc638c090795/scipy-1.15.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:ad3432cb0f9ed87477a8d97f03b763fd1d57709f1bbde3c9369b1dff5503b253", size = 30094511, upload-time = "2025-05-08T16:04:27.103Z" },
+ { url = "https://files.pythonhosted.org/packages/ea/b1/4deb37252311c1acff7f101f6453f0440794f51b6eacb1aad4459a134081/scipy-1.15.3-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:aef683a9ae6eb00728a542b796f52a5477b78252edede72b8327a886ab63293f", size = 22368151, upload-time = "2025-05-08T16:04:31.731Z" },
+ { url = "https://files.pythonhosted.org/packages/38/7d/f457626e3cd3c29b3a49ca115a304cebb8cc6f31b04678f03b216899d3c6/scipy-1.15.3-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:1c832e1bd78dea67d5c16f786681b28dd695a8cb1fb90af2e27580d3d0967e92", size = 25121732, upload-time = "2025-05-08T16:04:36.596Z" },
+ { url = "https://files.pythonhosted.org/packages/db/0a/92b1de4a7adc7a15dcf5bddc6e191f6f29ee663b30511ce20467ef9b82e4/scipy-1.15.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:263961f658ce2165bbd7b99fa5135195c3a12d9bef045345016b8b50c315cb82", size = 35547617, upload-time = "2025-05-08T16:04:43.546Z" },
+ { url = "https://files.pythonhosted.org/packages/8e/6d/41991e503e51fc1134502694c5fa7a1671501a17ffa12716a4a9151af3df/scipy-1.15.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2abc762b0811e09a0d3258abee2d98e0c703eee49464ce0069590846f31d40", size = 37662964, upload-time = "2025-05-08T16:04:49.431Z" },
+ { url = "https://files.pythonhosted.org/packages/25/e1/3df8f83cb15f3500478c889be8fb18700813b95e9e087328230b98d547ff/scipy-1.15.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ed7284b21a7a0c8f1b6e5977ac05396c0d008b89e05498c8b7e8f4a1423bba0e", size = 37238749, upload-time = "2025-05-08T16:04:55.215Z" },
+ { url = "https://files.pythonhosted.org/packages/93/3e/b3257cf446f2a3533ed7809757039016b74cd6f38271de91682aa844cfc5/scipy-1.15.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5380741e53df2c566f4d234b100a484b420af85deb39ea35a1cc1be84ff53a5c", size = 40022383, upload-time = "2025-05-08T16:05:01.914Z" },
+ { url = "https://files.pythonhosted.org/packages/d1/84/55bc4881973d3f79b479a5a2e2df61c8c9a04fcb986a213ac9c02cfb659b/scipy-1.15.3-cp310-cp310-win_amd64.whl", hash = "sha256:9d61e97b186a57350f6d6fd72640f9e99d5a4a2b8fbf4b9ee9a841eab327dc13", size = 41259201, upload-time = "2025-05-08T16:05:08.166Z" },
+]
+
+[[package]]
+name = "setuptools"
+version = "80.9.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/18/5d/3bf57dcd21979b887f014ea83c24ae194cfcd12b9e0fda66b957c69d1fca/setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c", size = 1319958, upload-time = "2025-05-27T00:56:51.443Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486, upload-time = "2025-05-27T00:56:49.664Z" },
+]
+
+[[package]]
+name = "six"
+version = "1.17.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031, upload-time = "2024-12-04T17:35:28.174Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" },
+]
+
+[[package]]
+name = "sympy"
+version = "1.14.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "mpmath" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921, upload-time = "2025-04-27T18:05:01.611Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" },
+]
+
+[[package]]
+name = "tensorboard"
+version = "2.15.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "absl-py" },
+ { name = "google-auth" },
+ { name = "google-auth-oauthlib" },
+ { name = "grpcio" },
+ { name = "markdown" },
+ { name = "numpy" },
+ { name = "protobuf" },
+ { name = "requests" },
+ { name = "setuptools" },
+ { name = "six" },
+ { name = "tensorboard-data-server" },
+ { name = "werkzeug" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/6e/0c/1059a6682cf2cc1fcc0d5327837b5672fe4f5574255fa5430d0a8ceb75e9/tensorboard-2.15.1-py3-none-any.whl", hash = "sha256:c46c1d1cf13a458c429868a78b2531d8ff5f682058d69ec0840b0bc7a38f1c0f", size = 5539710, upload-time = "2023-11-02T20:49:50.813Z" },
+]
+
+[[package]]
+name = "tensorboard-data-server"
+version = "0.7.2"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/7a/13/e503968fefabd4c6b2650af21e110aa8466fe21432cd7c43a84577a89438/tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb", size = 2356, upload-time = "2023-10-23T21:23:32.16Z" },
+ { url = "https://files.pythonhosted.org/packages/b7/85/dabeaf902892922777492e1d253bb7e1264cadce3cea932f7ff599e53fea/tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60", size = 4823598, upload-time = "2023-10-23T21:23:33.714Z" },
+ { url = "https://files.pythonhosted.org/packages/73/c6/825dab04195756cf8ff2e12698f22513b3db2f64925bdd41671bfb33aaa5/tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530", size = 6590363, upload-time = "2023-10-23T21:23:35.583Z" },
+]
+
+[[package]]
+name = "tomli"
+version = "2.2.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/18/87/302344fed471e44a87289cf4967697d07e532f2421fdaf868a303cbae4ff/tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff", size = 17175, upload-time = "2024-11-27T22:38:36.873Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257, upload-time = "2024-11-27T22:38:35.385Z" },
+]
+
+[[package]]
+name = "torch"
+version = "2.2.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "filelock" },
+ { name = "fsspec" },
+ { name = "jinja2" },
+ { name = "networkx" },
+ { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "sympy" },
+ { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "typing-extensions" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/33/b3/1fcc3bccfddadfd6845dcbfe26eb4b099f1dfea5aa0e5cfb92b3c98dba5b/torch-2.2.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:bc889d311a855dd2dfd164daf8cc903a6b7273a747189cebafdd89106e4ad585", size = 755526581, upload-time = "2024-03-27T21:06:46.5Z" },
+ { url = "https://files.pythonhosted.org/packages/c3/7c/aeb0c5789a3f10cf909640530cd75b314959b9d9914a4996ed2c7bf8779d/torch-2.2.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:15dffa4cc3261fa73d02f0ed25f5fa49ecc9e12bf1ae0a4c1e7a88bbfaad9030", size = 86623646, upload-time = "2024-03-27T21:10:22.719Z" },
+ { url = "https://files.pythonhosted.org/packages/3a/81/684d99e536b20e869a7c1222cf1dd233311fb05d3628e9570992bfb65760/torch-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:11e8fe261233aeabd67696d6b993eeb0896faa175c6b41b9a6c9f0334bdad1c5", size = 198579616, upload-time = "2024-03-27T21:10:15.41Z" },
+ { url = "https://files.pythonhosted.org/packages/3b/55/7192974ab13e5e5577f45d14ce70d42f5a9a686b4f57bbe8c9ab45c4a61a/torch-2.2.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:b2e2200b245bd9f263a0d41b6a2dab69c4aca635a01b30cca78064b0ef5b109e", size = 150788930, upload-time = "2024-03-27T21:08:09.98Z" },
+ { url = "https://files.pythonhosted.org/packages/33/6b/21496316c9b8242749ee2a9064406271efdf979e91d440e8a3806b5e84bf/torch-2.2.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:877b3e6593b5e00b35bbe111b7057464e76a7dd186a287280d941b564b0563c2", size = 59707286, upload-time = "2024-03-27T21:10:28.154Z" },
+]
+
+[[package]]
+name = "tqdm"
+version = "4.67.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "colorama", marker = "sys_platform == 'win32'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" },
+]
+
+[[package]]
+name = "triton"
+version = "2.2.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "filelock" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/95/05/ed974ce87fe8c8843855daa2136b3409ee1c126707ab54a8b72815c08b49/triton-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2294514340cfe4e8f4f9e5c66c702744c4a117d25e618bd08469d0bfed1e2e5", size = 167900779, upload-time = "2024-01-10T03:11:56.576Z" },
+]
+
+[[package]]
+name = "typing-extensions"
+version = "4.14.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/98/5a/da40306b885cc8c09109dc2e1abd358d5684b1425678151cdaed4731c822/typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36", size = 107673, upload-time = "2025-07-04T13:28:34.16Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b5/00/d631e67a838026495268c2f6884f3711a15a9a2a96cd244fdaea53b823fb/typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76", size = 43906, upload-time = "2025-07-04T13:28:32.743Z" },
+]
+
+[[package]]
+name = "typing-inspection"
+version = "0.4.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/f8/b1/0c11f5058406b3af7609f121aaa6b609744687f1d158b3c3a5bf4cc94238/typing_inspection-0.4.1.tar.gz", hash = "sha256:6ae134cc0203c33377d43188d4064e9b357dba58cff3185f22924610e70a9d28", size = 75726, upload-time = "2025-05-21T18:55:23.885Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/17/69/cd203477f944c353c31bade965f880aa1061fd6bf05ded0726ca845b6ff7/typing_inspection-0.4.1-py3-none-any.whl", hash = "sha256:389055682238f53b04f7badcb49b989835495a96700ced5dab2d8feae4b26f51", size = 14552, upload-time = "2025-05-21T18:55:22.152Z" },
+]
+
+[[package]]
+name = "urllib3"
+version = "2.5.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/15/22/9ee70a2574a4f4599c47dd506532914ce044817c7752a79b6a51286319bc/urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760", size = 393185, upload-time = "2025-06-18T14:07:41.644Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" },
+]
+
+[[package]]
+name = "werkzeug"
+version = "3.1.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "markupsafe" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/9f/69/83029f1f6300c5fb2471d621ab06f6ec6b3324685a2ce0f9777fd4a8b71e/werkzeug-3.1.3.tar.gz", hash = "sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746", size = 806925, upload-time = "2024-11-08T15:52:18.093Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/52/24/ab44c871b0f07f491e5d2ad12c9bd7358e527510618cb1b803a88e986db1/werkzeug-3.1.3-py3-none-any.whl", hash = "sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e", size = 224498, upload-time = "2024-11-08T15:52:16.132Z" },
+]
+
+[[package]]
+name = "wrapt"
+version = "1.17.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/c3/fc/e91cc220803d7bc4db93fb02facd8461c37364151b8494762cc88b0fbcef/wrapt-1.17.2.tar.gz", hash = "sha256:41388e9d4d1522446fe79d3213196bd9e3b301a336965b9e27ca2788ebd122f3", size = 55531, upload-time = "2025-01-14T10:35:45.465Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/5a/d1/1daec934997e8b160040c78d7b31789f19b122110a75eca3d4e8da0049e1/wrapt-1.17.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3d57c572081fed831ad2d26fd430d565b76aa277ed1d30ff4d40670b1c0dd984", size = 53307, upload-time = "2025-01-14T10:33:13.616Z" },
+ { url = "https://files.pythonhosted.org/packages/1b/7b/13369d42651b809389c1a7153baa01d9700430576c81a2f5c5e460df0ed9/wrapt-1.17.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5e251054542ae57ac7f3fba5d10bfff615b6c2fb09abeb37d2f1463f841ae22", size = 38486, upload-time = "2025-01-14T10:33:15.947Z" },
+ { url = "https://files.pythonhosted.org/packages/62/bf/e0105016f907c30b4bd9e377867c48c34dc9c6c0c104556c9c9126bd89ed/wrapt-1.17.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:80dd7db6a7cb57ffbc279c4394246414ec99537ae81ffd702443335a61dbf3a7", size = 38777, upload-time = "2025-01-14T10:33:17.462Z" },
+ { url = "https://files.pythonhosted.org/packages/27/70/0f6e0679845cbf8b165e027d43402a55494779295c4b08414097b258ac87/wrapt-1.17.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a6e821770cf99cc586d33833b2ff32faebdbe886bd6322395606cf55153246c", size = 83314, upload-time = "2025-01-14T10:33:21.282Z" },
+ { url = "https://files.pythonhosted.org/packages/0f/77/0576d841bf84af8579124a93d216f55d6f74374e4445264cb378a6ed33eb/wrapt-1.17.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b60fb58b90c6d63779cb0c0c54eeb38941bae3ecf7a73c764c52c88c2dcb9d72", size = 74947, upload-time = "2025-01-14T10:33:24.414Z" },
+ { url = "https://files.pythonhosted.org/packages/90/ec/00759565518f268ed707dcc40f7eeec38637d46b098a1f5143bff488fe97/wrapt-1.17.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b870b5df5b71d8c3359d21be8f0d6c485fa0ebdb6477dda51a1ea54a9b558061", size = 82778, upload-time = "2025-01-14T10:33:26.152Z" },
+ { url = "https://files.pythonhosted.org/packages/f8/5a/7cffd26b1c607b0b0c8a9ca9d75757ad7620c9c0a9b4a25d3f8a1480fafc/wrapt-1.17.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4011d137b9955791f9084749cba9a367c68d50ab8d11d64c50ba1688c9b457f2", size = 81716, upload-time = "2025-01-14T10:33:27.372Z" },
+ { url = "https://files.pythonhosted.org/packages/7e/09/dccf68fa98e862df7e6a60a61d43d644b7d095a5fc36dbb591bbd4a1c7b2/wrapt-1.17.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:1473400e5b2733e58b396a04eb7f35f541e1fb976d0c0724d0223dd607e0f74c", size = 74548, upload-time = "2025-01-14T10:33:28.52Z" },
+ { url = "https://files.pythonhosted.org/packages/b7/8e/067021fa3c8814952c5e228d916963c1115b983e21393289de15128e867e/wrapt-1.17.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3cedbfa9c940fdad3e6e941db7138e26ce8aad38ab5fe9dcfadfed9db7a54e62", size = 81334, upload-time = "2025-01-14T10:33:29.643Z" },
+ { url = "https://files.pythonhosted.org/packages/4b/0d/9d4b5219ae4393f718699ca1c05f5ebc0c40d076f7e65fd48f5f693294fb/wrapt-1.17.2-cp310-cp310-win32.whl", hash = "sha256:582530701bff1dec6779efa00c516496968edd851fba224fbd86e46cc6b73563", size = 36427, upload-time = "2025-01-14T10:33:30.832Z" },
+ { url = "https://files.pythonhosted.org/packages/72/6a/c5a83e8f61aec1e1aeef939807602fb880e5872371e95df2137142f5c58e/wrapt-1.17.2-cp310-cp310-win_amd64.whl", hash = "sha256:58705da316756681ad3c9c73fd15499aa4d8c69f9fd38dc8a35e06c12468582f", size = 38774, upload-time = "2025-01-14T10:33:32.897Z" },
+ { url = "https://files.pythonhosted.org/packages/2d/82/f56956041adef78f849db6b289b282e72b55ab8045a75abad81898c28d19/wrapt-1.17.2-py3-none-any.whl", hash = "sha256:b18f2d1533a71f069c7f82d524a52599053d4c7166e9dd374ae2136b7f40f7c8", size = 23594, upload-time = "2025-01-14T10:35:44.018Z" },
+]