From 431fa2bca61e7a837e37cdd87d2431dc1bee8bae Mon Sep 17 00:00:00 2001 From: Ellen Xu Date: Thu, 28 Nov 2024 14:53:28 -0800 Subject: [PATCH 1/2] add norm layer --- networks/agent.py | 75 ++++++++++++++++++----------------------- networks/alphazero.py | 2 +- networks/utils.py | 31 +++++++++++++++++ run_a0c.py | 76 ++++++++++++++++++++---------------------- utils/running_stats.py | 45 ------------------------- 5 files changed, 102 insertions(+), 127 deletions(-) create mode 100644 networks/utils.py delete mode 100644 utils/running_stats.py diff --git a/networks/agent.py b/networks/agent.py index 7fb763f..367197d 100644 --- a/networks/agent.py +++ b/networks/agent.py @@ -3,6 +3,7 @@ import numpy as np import torch.nn.functional as F from torch import nn +from networks.utils import Normalizer def mlp(hidden_sizes: list[int], activation: nn.Module = nn.Tanh, output_activation: nn.Module = nn.Identity): layers = [] @@ -11,33 +12,6 @@ def mlp(hidden_sizes: list[int], activation: nn.Module = nn.Tanh, output_activat layers += [nn.Linear(hidden_sizes[j], hidden_sizes[j+1]), act()] return nn.Sequential(*layers) -class MLPCategorical(nn.Module): - def __init__(self, obs_dim: int, hidden_sizes: list[int], act_dim: int, activation: nn.Module = nn.Tanh) -> None: - super(MLPCategorical, self).__init__() - self.mlp = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.mlp(x) - - def get_policy(self, obs: torch.Tensor) -> torch.Tensor: - logits = self.forward(obs) - probs = F.softmax(logits, dim=-1) - return probs - - def get_action(self, obs: torch.Tensor, deterministic: bool = False) -> torch.Tensor: - probs = self.get_policy(obs) - action = torch.argmax(probs, dim=-1) if deterministic else torch.multinomial(probs, num_samples=1) - return action.detach().cpu().numpy().squeeze() - - def get_logprob(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor: - logits = self.forward(obs) - logprob = F.log_softmax(logits, dim=-1) - logprob = logprob.gather(1, act.unsqueeze(-1)).squeeze(-1) - prob = logprob.exp() - entropy = -(logprob*prob) - assert logprob.shape == act.shape - return logprob, entropy.sum(dim=-1) - class MLPGaussian(nn.Module): def __init__(self, obs_dim: int, hidden_sizes: list[int], act_dim: int, activation: nn.Module = nn.Tanh, log_std: float = 3.) -> None: # TODO: 0 super(MLPGaussian, self).__init__() @@ -45,6 +19,8 @@ def __init__(self, obs_dim: int, hidden_sizes: list[int], act_dim: int, activati self.log_std = torch.nn.Parameter(torch.full((act_dim,), log_std, dtype=torch.float32)) def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.obs_normalizer: + x = self.obs_normalizer.norm(x) return self.mlp(x) def get_policy(self, obs: torch.Tensor) -> torch.Tensor: @@ -74,6 +50,8 @@ def __init__(self, obs_dim, hidden_sizes, act_dim, activation=nn.Tanh, bias=True self.act_bound = act_bound def forward(self, x: torch.Tensor): + if self.obs_normalizer: + x = self.obs_normalizer.norm(x) return self.mlp(x) def get_policy(self, obs: torch.Tensor): @@ -101,31 +79,38 @@ def get_logprob(self, obs: torch.Tensor, act: torch.Tensor): class MLPCritic(nn.Module): def __init__(self, obs_dim: int, hidden_sizes: list[int], activation: nn.Module = nn.Tanh) -> None: - super(MLPCritic, self).__init__() + super().__init__() self.mlp = mlp([obs_dim] + list(hidden_sizes) + [1], activation) - self.register_buffer('mean', torch.tensor(0.0)) - self.register_buffer('std', torch.tensor(1.0)) + self.return_normalizer = Normalizer(1) # value is scalar + + def forward(self, x: torch.Tensor, out_norm: bool = False) -> torch.Tensor: + if self.obs_normalizer: + x = self.obs_normalizer.norm(x) + value = self.mlp(x) + return self.return_normalizer.denorm(value) if not out_norm else value # default is unnormalized - def forward(self, x: torch.Tensor, normalize: bool = False) -> torch.Tensor: - if normalize: - return self.mlp(x) # mlp predicts normalized value - return self.mlp(x) * self.std + self.mean # -> unnormalize in mcts - class ActorCritic(nn.Module): - def __init__(self, obs_dim: int, hidden_sizes: dict[str, list[int]], act_dim: int, discrete: bool = False, shared_layers: bool = True, act_bound: tuple[float, float] = None) -> None: - super(ActorCritic, self).__init__() + def __init__(self, obs_dim: int, hidden_sizes: dict[str, list[int]], act_dim: int, + discrete: bool = False, shared_layers: bool = True, act_bound: tuple[float, float] = None) -> None: + super().__init__() + self.obs_normalizer = Normalizer(obs_dim) model_class = MLPCategorical if discrete else (MLPGaussian if not act_bound else MLPBeta) print('using model', model_class, 'with action bound', act_bound) - if act_bound: # if bounded then use MLPBeta + + if act_bound: self.actor = model_class(obs_dim, hidden_sizes["pi"], act_dim, act_bound=act_bound) - act_dim *= 2 # MLPBeta outputs two parameters alpha, beta + act_dim *= 2 else: self.actor = model_class(obs_dim, hidden_sizes["pi"], act_dim) self.critic = MLPCritic(obs_dim, hidden_sizes["vf"]) + # share normalizer with actor and critic + self.actor.obs_normalizer = self.obs_normalizer + self.critic.obs_normalizer = self.obs_normalizer + if shared_layers and len(hidden_sizes["pi"]) > 1: self.shared = mlp([obs_dim] + hidden_sizes["pi"][:-1], nn.Tanh) - self.actor.mlp = nn.Sequential( # override + self.actor.mlp = nn.Sequential( self.shared, mlp([hidden_sizes["pi"][-2], hidden_sizes["pi"][-1], act_dim], nn.Tanh) ) @@ -134,7 +119,13 @@ def __init__(self, obs_dim: int, hidden_sizes: dict[str, list[int]], act_dim: in mlp([hidden_sizes["vf"][-2], hidden_sizes["vf"][-1], 1], nn.Tanh) ) + def update_normalizers(self, obs, returns=None): + """Update running statistics for normalization""" + self.obs_normalizer.update(obs) + if returns is not None: + self.critic.return_normalizer.update(returns.reshape(-1, 1)) + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - actor_out = self.actor(x) + actor_out = self.actor(x) # normalization handled in forward pass critic_out = self.critic(x) - return actor_out, critic_out + return actor_out, critic_out \ No newline at end of file diff --git a/networks/alphazero.py b/networks/alphazero.py index d6c9345..1bb7369 100644 --- a/networks/alphazero.py +++ b/networks/alphazero.py @@ -3,7 +3,7 @@ import numpy as np from networks.mcts import MCTS, State -class A0C(MCTS): +class A0CModel(MCTS): """AlphaZero Continuous. Uses NN to guide MCTS search.""" def __init__(self, model, exploration_weight=1e-3, gamma=0.99, k=1, alpha=0.5, device='cpu'): super().__init__(exploration_weight, gamma, k, alpha) diff --git a/networks/utils.py b/networks/utils.py new file mode 100644 index 0000000..8cb3cba --- /dev/null +++ b/networks/utils.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn + +class Normalizer(nn.Module): + """Normalization layer that tracks running statistics""" + def __init__(self, shape): + super().__init__() + self.register_buffer('mean', torch.zeros(shape)) + self.register_buffer('std', torch.ones(shape)) + self.register_buffer('count', torch.zeros(1)) + self.shape = shape + + def update(self, x): + """Welford's online algorithm for updating mean and std""" + batch_mean = x.mean(dim=0) + batch_var = x.var(dim=0, unbiased=False) + batch_count = x.shape[0] + + delta = batch_mean - self.mean + self.mean += delta * batch_count / (self.count + batch_count) + m_a = self.std ** 2 * self.count + m_b = batch_var * batch_count + M2 = m_a + m_b + delta ** 2 * self.count * batch_count / (self.count + batch_count) + self.std = torch.sqrt(M2 / (self.count + batch_count)) + self.count += batch_count + + def norm(self, x): + return (x - self.mean) / (self.std + 1e-8) + + def denorm(self, x): + return x * self.std + self.mean \ No newline at end of file diff --git a/run_a0c.py b/run_a0c.py index 35d5a00..64b11b5 100644 --- a/run_a0c.py +++ b/run_a0c.py @@ -10,10 +10,10 @@ import matplotlib.pyplot as plt from torchrl.data import ReplayBuffer, LazyTensorStorage from tensordict import TensorDict -from networks.alphazero import A0C as A0CModel +from networks.alphazero import A0CModel from networks.agent import ActorCritic +from networks.mcts import MCTS from run_mcts import CartState -from utils.running_stats import RunningStats import warnings warnings.filterwarnings("ignore", category=UserWarning) @@ -27,12 +27,10 @@ def __init__(self, env, model, lr=1e-1, epochs=10, ent_coeff=0.01, env_bs=1, dev self.ent_coeff = ent_coeff self.optimizer = optim.Adam(self.model.parameters(), lr=lr) self.replay_buffer = ReplayBuffer(storage=LazyTensorStorage(max_size=100000, device=device)) - self.start = time.time() self.device = device self.debug = debug self.mcts = A0CModel(model, device=device) # self.mcts = MCTS() - self.running_stats = RunningStats() self.hist = {'iter': [], 'reward': [], 'value_loss': [], 'policy_loss': [], 'total_loss': []} def _compute_return(self, states): @@ -45,26 +43,18 @@ def _compute_return(self, states): returns.append(v_target) return np.array(returns) - def _normalize_return(self, returns): - for r in returns: - self.running_stats.push(r) - - self.mu = self.running_stats.mean() - self.sigma = self.running_stats.standard_deviation() - self.model.critic.mean = torch.tensor(self.mu, device=self.device) # update value net - self.model.critic.std = torch.tensor(self.sigma, device=self.device) - - normalized_returns = (returns - self.mu) / (self.sigma + 1e-8) - return normalized_returns - def value_loss(self, states, returns): - # value loss is MSE/MAE with normalized Q values - V = self.model.critic(states, normalize=True).squeeze() # normalize = True, unnorm only in mcts search - normalized_returns = self._normalize_return(returns) - value_loss = nn.L1Loss()(V, normalized_returns) + V = self.model.critic(states, out_norm=True).squeeze() # normalize output for loss, unnorm in mcts + # normalize returns for loss + returns = self.model.critic.return_normalizer.norm(returns) + value_loss = nn.L1Loss()(V, returns) if self.debug: - print(f"value {V[0]} return {normalized_returns[0]}") - print(f"value range: {V.min():.3f} {V.max():.3f}, returns range: {normalized_returns.min():.3f} {normalized_returns.max():.3f}") + print(f"value {V[0]} return {returns[0]}") + print(f"normalized value range: {V.min().item():.3f} {V.max().item():.3f}") + print(f"normalized returns range: {returns.min().item():.3f} {returns.max().item():.3f}") + normalized_states = self.model.critic.obs_normalizer.norm(states) + print(f"state {states[0]}, normalized {normalized_states[0]}") + print(f"normalized obs range: {normalized_states.min().item():.3f} {normalized_states.max().item():.3f}") return value_loss def policy_loss(self, mcts_states, mcts_actions, mcts_counts): @@ -89,13 +79,14 @@ def mcts_rollout(self, max_steps=1000, deterministic=False): state, _ = self.env.reset() for _ in range(max_steps): - s = CartState.from_array(state) + s = CartState.from_array(state) # TODO: get rid of cartstate best_action, _ = self.mcts.get_action(s, d=10, n=100, deterministic=True) # get single action and prob next_state, reward, terminated, truncated, info = self.env.step(np.array([[best_action]])) states.append(state) rewards.append(reward) done = terminated or truncated + # mcts returns actions, counts actions, norm_counts = self.mcts.get_policy(s) mcts_counts.append(norm_counts) mcts_actions.append(actions) @@ -108,18 +99,25 @@ def mcts_rollout(self, max_steps=1000, deterministic=False): return states, rewards, mcts_states, mcts_actions, mcts_counts def train(self, max_iters=1000, n_episodes=10, n_steps=30): + start = time.time() for i in range(max_iters): - # collect data using mcts + # collect data for _ in range(n_episodes): states, rewards, mcts_states, mcts_actions, mcts_counts = self.mcts_rollout(n_steps) returns = self._compute_return(states) - episode_dict = TensorDict( { - "states": states, - "returns": returns, - "mcts_states": np.array(mcts_states), - "mcts_actions": np.array(mcts_actions), - "mcts_counts": np.array(mcts_counts), + # update normalizers with new data + self.model.update_normalizers( + torch.FloatTensor(states).to(self.device), + torch.FloatTensor(returns).to(self.device) + ) + + episode_dict = TensorDict({ + "states": states, + "returns": returns, + "mcts_states": np.array(mcts_states), + "mcts_actions": np.array(mcts_actions), + "mcts_counts": np.array(mcts_counts), }, batch_size=len(states)) self.replay_buffer.extend(episode_dict) @@ -149,9 +147,9 @@ def train(self, max_iters=1000, n_episodes=10, n_steps=30): self.hist['total_loss'].append(loss.item()) print(f"actor loss {policy_loss.item():.3f} value loss {value_loss.item():.3f} l2 loss {l2_loss.item():.3f}") - print(f"iter {i}, reward {avg_reward:.3f}, t {time.time()-self.start:.2f}") + print(f"iter {i}, reward {avg_reward:.3f}, t {time.time()-start:.2f}") - print(f"Total time: {time.time() - self.start}") + print(f"Total time: {time.time() - start}") return self.model, self.hist def rollout(model, env, max_steps=1000, deterministic=False): @@ -195,13 +193,13 @@ def plot_losses(hist, save_path=None): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--max_iters", type=int, default=30) + parser.add_argument("--max_iters", type=int, default=10) parser.add_argument("--n_eps", type=int, default=10) parser.add_argument("--n_steps", type=int, default=30) parser.add_argument("--env_bs", type=int, default=1) # TODO: batch parser.add_argument("--save", default=True) parser.add_argument("--noise_mode", default=None) - parser.add_argument("--debug", default=False) + parser.add_argument("--debug", default=True) parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() @@ -212,16 +210,16 @@ def plot_losses(hist, save_path=None): best_model, hist = a0c.train(args.max_iters, args.n_eps, args.n_steps) # run value net online planner - # from run_mcts import run_mcts - # env = gym.make("CartLatAccel-v1", render_mode="human", noise_mode=args.noise_mode) - # reward = run_mcts(a0c.mcts, env, max_steps=200, search_depth=10, n_sims=100, seed=args.seed) - # print(f"reward {reward}") + from run_mcts import run_mcts + env = gym.make("CartLatAccel-v1", render_mode="human", noise_mode=args.noise_mode) + reward = run_mcts(a0c.mcts, env, max_steps=200, search_depth=10, n_sims=100, seed=args.seed) + print(f"mcts reward {reward}") # run actor net model print("rollout out best actor") env = gym.make("CartLatAccel-v1", noise_mode=args.noise_mode, env_bs=1, render_mode="human") rewards = rollout(best_model.actor, env, max_steps=200, deterministic=True) - print(f"reward {sum(rewards)}") + print(f"actor reward {sum(rewards)}") if args.save: os.makedirs('out', exist_ok=True) diff --git a/utils/running_stats.py b/utils/running_stats.py deleted file mode 100644 index ccab4b0..0000000 --- a/utils/running_stats.py +++ /dev/null @@ -1,45 +0,0 @@ -import math -import numpy as np - -# https://stackoverflow.com/a/17637351 -class RunningStats: - - def __init__(self): - self.n = 0 - self.old_m = 0 - self.new_m = 0 - self.old_s = 0 - self.new_s = 0 - - self.min = np.inf - self.max = -np.inf - - def clear(self): - self.n = 0 - - def push(self, x): - self.n += 1 - - if self.n == 1: - self.old_m = self.new_m = x - self.old_s = 0 - else: - self.new_m = self.old_m + (x - self.old_m) / self.n - self.new_s = self.old_s + (x - self.old_m) * (x - self.new_m) - - self.old_m = self.new_m - self.old_s = self.new_s - - if x < self.min: - self.min = x - if x > self.max: - self.max = x - - def mean(self): - return self.new_m if self.n else 0.0 - - def variance(self): - return self.new_s / (self.n - 1) if self.n > 1 else 0.0 - - def standard_deviation(self): - return math.sqrt(self.variance()) From dfa3900ddd43766b6599bc1e34016d5dc1b42b2b Mon Sep 17 00:00:00 2001 From: Ellen Xu Date: Thu, 28 Nov 2024 15:00:33 -0800 Subject: [PATCH 2/2] seed eval --- run_a0c.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/run_a0c.py b/run_a0c.py index 64b11b5..70c0209 100644 --- a/run_a0c.py +++ b/run_a0c.py @@ -132,12 +132,7 @@ def train(self, max_iters=1000, n_episodes=10, n_steps=30): loss.backward() self.optimizer.step() - # evaluate current actor net - eps_rewards = [] - for _ in range(5): - rewards = rollout(self.model.actor, self.env, max_steps=n_steps, deterministic=True) - eps_rewards.append(sum(rewards)) - avg_reward = np.mean(eps_rewards) + avg_reward = np.mean(evaluate(self.model.actor, self.env)) # log metrics self.hist['iter'].append(i) @@ -151,9 +146,16 @@ def train(self, max_iters=1000, n_episodes=10, n_steps=30): print(f"Total time: {time.time() - start}") return self.model, self.hist - -def rollout(model, env, max_steps=1000, deterministic=False): - state, _ = env.reset() + +def evaluate(model, env, n_episodes=10, n_steps=200, seed=42): + eps_rewards = [] + for i in range(n_episodes): + rewards = rollout(model, env, max_steps=n_steps, seed=seed+i, deterministic=True) + eps_rewards.append(sum(rewards)) + return eps_rewards + +def rollout(model, env, max_steps=1000, seed=42, deterministic=False): + state, _ = env.reset(seed=seed) rewards = [] for _ in range(max_steps): state_tensor = torch.FloatTensor(state).unsqueeze(0)