Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions env/wrappers/vec_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,21 @@


def _worker(remote, parent_remote, env_fn_wrapper):
"""
Worker process for handling environment interactions in a subprocess.

Reward computation:
- The reward at each step comes directly from env.step(action), which returns
(obs, reward, done, truncated, info). The truncated flag is ignored here.
- The per-step reward is determined by the underlying environment's reward function.
- Rewards are accumulated over the episode in a list. When done=True,
the episode_return (sum of all step rewards) is stored in info["episode_return"].
- The per-step reward is sent back to the main process via remote.send().

For ASID (Automatic System Identification), an alternative reward can be computed
using ASIDRewardWrapper which calculates reward = trace(J_transpose * J), where J is
the Jacobian of observations w.r.t. physics parameters estimated via finite differences.
"""
parent_remote.close()
env = _patch_env(env_fn_wrapper.var())
# env = env_fn_wrapper.var()
Expand All @@ -18,12 +33,15 @@ def _worker(remote, parent_remote, env_fn_wrapper):
cmd, data = remote.recv()
# gym interface
if cmd == "step":
# Reward is computed by the underlying environment's step function.
# For standard RL tasks, this is the task-specific reward.
# For ASID, use ASIDRewardWrapper to compute information-theoretic rewards.
obs, reward, done, _, info = env.step(data)
rewards.append(reward)
rewards.append(reward) # Accumulate per-step rewards
successes.append(info.get("success", 0))
if done:
info["terminal_obs"] = obs
info["episode_return"] = sum(rewards)
info["episode_return"] = sum(rewards) # Total episode reward
info["episode_success"] = float(sum(successes) > 0)
rewards, successes = [], []
obs, _ = env.reset()
Expand Down