75 changes: 33 additions & 42 deletions networks/agent.py
@@ -3,6 +3,7 @@
import numpy as np
import torch.nn.functional as F
from torch import nn
from networks.utils import Normalizer

def mlp(hidden_sizes: list[int], activation: nn.Module = nn.Tanh, output_activation: nn.Module = nn.Identity):
layers = []
@@ -11,40 +12,15 @@ def mlp(hidden_sizes: list[int], activation: nn.Module = nn.Tanh, output_activat
layers += [nn.Linear(hidden_sizes[j], hidden_sizes[j+1]), act()]
return nn.Sequential(*layers)

class MLPCategorical(nn.Module):
def __init__(self, obs_dim: int, hidden_sizes: list[int], act_dim: int, activation: nn.Module = nn.Tanh) -> None:
super(MLPCategorical, self).__init__()
self.mlp = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.mlp(x)

def get_policy(self, obs: torch.Tensor) -> torch.Tensor:
logits = self.forward(obs)
probs = F.softmax(logits, dim=-1)
return probs

def get_action(self, obs: torch.Tensor, deterministic: bool = False) -> np.ndarray:
probs = self.get_policy(obs)
action = torch.argmax(probs, dim=-1) if deterministic else torch.multinomial(probs, num_samples=1)
return action.detach().cpu().numpy().squeeze()

def get_logprob(self, obs: torch.Tensor, act: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
logits = self.forward(obs)
logprob_all = F.log_softmax(logits, dim=-1) # [batch, act_dim]
entropy = -(logprob_all * logprob_all.exp()).sum(dim=-1) # entropy over the full action distribution
logprob = logprob_all.gather(1, act.unsqueeze(-1)).squeeze(-1) # log-prob of the taken action
assert logprob.shape == act.shape
return logprob, entropy

class MLPGaussian(nn.Module):
def __init__(self, obs_dim: int, hidden_sizes: list[int], act_dim: int, activation: nn.Module = nn.Tanh, log_std: float = 3.) -> None: # TODO: 0
super(MLPGaussian, self).__init__()
self.mlp = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation)
self.log_std = torch.nn.Parameter(torch.full((act_dim,), log_std, dtype=torch.float32))

def forward(self, x: torch.Tensor) -> torch.Tensor:
if getattr(self, "obs_normalizer", None) is not None: # set by ActorCritic; guard for standalone use
x = self.obs_normalizer.norm(x)
return self.mlp(x)

def get_policy(self, obs: torch.Tensor) -> torch.Tensor:
@@ -74,6 +50,8 @@ def __init__(self, obs_dim, hidden_sizes, act_dim, activation=nn.Tanh, bias=True
self.act_bound = act_bound

def forward(self, x: torch.Tensor):
if getattr(self, "obs_normalizer", None) is not None: # set by ActorCritic; guard for standalone use
x = self.obs_normalizer.norm(x)
return self.mlp(x)

def get_policy(self, obs: torch.Tensor):
@@ -101,31 +79,38 @@ def get_logprob(self, obs: torch.Tensor, act: torch.Tensor):

class MLPCritic(nn.Module):
def __init__(self, obs_dim: int, hidden_sizes: list[int], activation: nn.Module = nn.Tanh) -> None:
super(MLPCritic, self).__init__()
super().__init__()
self.mlp = mlp([obs_dim] + list(hidden_sizes) + [1], activation)
self.register_buffer('mean', torch.tensor(0.0))
self.register_buffer('std', torch.tensor(1.0))
self.return_normalizer = Normalizer(1) # value is scalar

def forward(self, x: torch.Tensor, out_norm: bool = False) -> torch.Tensor:
if getattr(self, "obs_normalizer", None) is not None: # set by ActorCritic; guard for standalone use
x = self.obs_normalizer.norm(x)
value = self.mlp(x)
return self.return_normalizer.denorm(value) if not out_norm else value # default is unnormalized

def forward(self, x: torch.Tensor, normalize: bool = False) -> torch.Tensor:
if normalize:
return self.mlp(x) # mlp predicts normalized value
return self.mlp(x) * self.std + self.mean # -> unnormalize in mcts

class ActorCritic(nn.Module):
def __init__(self, obs_dim: int, hidden_sizes: dict[str, list[int]], act_dim: int, discrete: bool = False, shared_layers: bool = True, act_bound: tuple[float, float] = None) -> None:
super(ActorCritic, self).__init__()
def __init__(self, obs_dim: int, hidden_sizes: dict[str, list[int]], act_dim: int,
discrete: bool = False, shared_layers: bool = True, act_bound: tuple[float, float] = None) -> None:
super().__init__()
self.obs_normalizer = Normalizer(obs_dim)
model_class = MLPCategorical if discrete else (MLPGaussian if not act_bound else MLPBeta)
print('using model', model_class, 'with action bound', act_bound)
if act_bound: # if bounded then use MLPBeta

if act_bound:
self.actor = model_class(obs_dim, hidden_sizes["pi"], act_dim, act_bound=act_bound)
act_dim *= 2 # MLPBeta outputs two parameters alpha, beta
act_dim *= 2
else:
self.actor = model_class(obs_dim, hidden_sizes["pi"], act_dim)
self.critic = MLPCritic(obs_dim, hidden_sizes["vf"])

# share normalizer with actor and critic
self.actor.obs_normalizer = self.obs_normalizer
self.critic.obs_normalizer = self.obs_normalizer

if shared_layers and len(hidden_sizes["pi"]) > 1:
self.shared = mlp([obs_dim] + hidden_sizes["pi"][:-1], nn.Tanh)
self.actor.mlp = nn.Sequential( # override
self.actor.mlp = nn.Sequential(
self.shared,
mlp([hidden_sizes["pi"][-2], hidden_sizes["pi"][-1], act_dim], nn.Tanh)
)
@@ -134,7 +119,13 @@ def __init__(self, obs_dim: int, hidden_sizes: dict[str, list[int]], act_dim: in
mlp([hidden_sizes["vf"][-2], hidden_sizes["vf"][-1], 1], nn.Tanh)
)

def update_normalizers(self, obs, returns=None):
"""Update running statistics for normalization"""
self.obs_normalizer.update(obs)
if returns is not None:
self.critic.return_normalizer.update(returns.reshape(-1, 1))

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
actor_out = self.actor(x)
actor_out = self.actor(x) # normalization handled in forward pass
critic_out = self.critic(x)
return actor_out, critic_out
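
For context, a minimal sketch of how the refactored ActorCritic and its shared Normalizer appear to fit together in a training step. The constructor arguments, tensor shapes, and data below are illustrative assumptions, not taken from this diff.

import torch
from networks.agent import ActorCritic

# hypothetical dimensions: 4-dim observation, 1-dim bounded action
model = ActorCritic(obs_dim=4, hidden_sizes={"pi": [32, 32], "vf": [32, 32]}, act_dim=1, act_bound=(-1.0, 1.0))

obs = torch.randn(8, 4)    # fake batch of observations
returns = torch.randn(8)   # fake Monte-Carlo returns for the critic's return normalizer

# update running statistics first, so subsequent forward passes see normalized inputs
model.update_normalizers(obs, returns)

# both heads normalize obs through the same shared Normalizer inside their forward passes
actor_out, value = model(obs)
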
2 changes: 1 addition & 1 deletion networks/alphazero.py
@@ -3,7 +3,7 @@
import numpy as np
from networks.mcts import MCTS, State

class A0C(MCTS):
class A0CModel(MCTS):
"""AlphaZero Continuous. Uses NN to guide MCTS search."""
def __init__(self, model, exploration_weight=1e-3, gamma=0.99, k=1, alpha=0.5, device='cpu'):
super().__init__(exploration_weight, gamma, k, alpha)
31 changes: 31 additions & 0 deletions networks/utils.py
@@ -0,0 +1,31 @@
import torch
import torch.nn as nn

class Normalizer(nn.Module):
"""Normalization layer that tracks running statistics"""
def __init__(self, shape):
super().__init__()
self.register_buffer('mean', torch.zeros(shape))
self.register_buffer('std', torch.ones(shape))
self.register_buffer('count', torch.zeros(1))
self.shape = shape

def update(self, x):
"""Welford's online algorithm for updating mean and std"""
batch_mean = x.mean(dim=0)
batch_var = x.var(dim=0, unbiased=False)
batch_count = x.shape[0]

delta = batch_mean - self.mean
self.mean += delta * batch_count / (self.count + batch_count)
m_a = self.std ** 2 * self.count
m_b = batch_var * batch_count
M2 = m_a + m_b + delta ** 2 * self.count * batch_count / (self.count + batch_count)
self.std = torch.sqrt(M2 / (self.count + batch_count))
self.count += batch_count

def norm(self, x):
return (x - self.mean) / (self.std + 1e-8)

def denorm(self, x):
return x * self.std + self.mean
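
A standalone usage sketch of the new Normalizer, with made-up data, to illustrate the running-stat update and the norm/denorm round trip:

import torch
from networks.utils import Normalizer

norm = Normalizer(4)                      # running stats for 4-dim observations
for _ in range(3):
    batch = torch.randn(32, 4) * 5 + 2    # fake batches with non-zero mean and scale
    norm.update(batch)                    # batched Welford-style update of mean/std/count

x = torch.randn(8, 4) * 5 + 2
z = norm.norm(x)                          # approximately zero-mean, unit-std
x_back = norm.denorm(z)                   # inverse transform, up to the 1e-8 epsilon
assert torch.allclose(x_back, x, atol=1e-4)

Because mean, std, and count are registered buffers, they travel with the module's state_dict, which is presumably why Normalizer subclasses nn.Module rather than holding plain tensors.
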
96 changes: 48 additions & 48 deletions run_a0c.py
@@ -10,10 +10,10 @@
import matplotlib.pyplot as plt
from torchrl.data import ReplayBuffer, LazyTensorStorage
from tensordict import TensorDict
from networks.alphazero import A0C as A0CModel
from networks.alphazero import A0CModel
from networks.agent import ActorCritic
from networks.mcts import MCTS
from run_mcts import CartState
from utils.running_stats import RunningStats

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
@@ -27,12 +27,10 @@ def __init__(self, env, model, lr=1e-1, epochs=10, ent_coeff=0.01, env_bs=1, dev
self.ent_coeff = ent_coeff
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
self.replay_buffer = ReplayBuffer(storage=LazyTensorStorage(max_size=100000, device=device))
self.start = time.time()
self.device = device
self.debug = debug
self.mcts = A0CModel(model, device=device)
# self.mcts = MCTS()
self.running_stats = RunningStats()
self.hist = {'iter': [], 'reward': [], 'value_loss': [], 'policy_loss': [], 'total_loss': []}

def _compute_return(self, states):
@@ -45,26 +43,18 @@ def _compute_return(self, states):
returns.append(v_target)
return np.array(returns)

def _normalize_return(self, returns):
for r in returns:
self.running_stats.push(r)

self.mu = self.running_stats.mean()
self.sigma = self.running_stats.standard_deviation()
self.model.critic.mean = torch.tensor(self.mu, device=self.device) # update value net
self.model.critic.std = torch.tensor(self.sigma, device=self.device)

normalized_returns = (returns - self.mu) / (self.sigma + 1e-8)
return normalized_returns

def value_loss(self, states, returns):
# value loss is L1 (MAE) against normalized returns
V = self.model.critic(states, normalize=True).squeeze() # normalize = True, unnorm only in mcts search
normalized_returns = self._normalize_return(returns)
value_loss = nn.L1Loss()(V, normalized_returns)
V = self.model.critic(states, out_norm=True).squeeze() # normalize output for loss, unnorm in mcts
# normalize returns for loss
returns = self.model.critic.return_normalizer.norm(returns)
value_loss = nn.L1Loss()(V, returns)
if self.debug:
print(f"value {V[0]} return {normalized_returns[0]}")
print(f"value range: {V.min():.3f} {V.max():.3f}, returns range: {normalized_returns.min():.3f} {normalized_returns.max():.3f}")
print(f"value {V[0]} return {returns[0]}")
print(f"normalized value range: {V.min().item():.3f} {V.max().item():.3f}")
print(f"normalized returns range: {returns.min().item():.3f} {returns.max().item():.3f}")
normalized_states = self.model.critic.obs_normalizer.norm(states)
print(f"state {states[0]}, normalized {normalized_states[0]}")
print(f"normalized obs range: {normalized_states.min().item():.3f} {normalized_states.max().item():.3f}")
return value_loss

def policy_loss(self, mcts_states, mcts_actions, mcts_counts):
@@ -89,13 +79,14 @@ def mcts_rollout(self, max_steps=1000, deterministic=False):
state, _ = self.env.reset()

for _ in range(max_steps):
s = CartState.from_array(state)
s = CartState.from_array(state) # TODO: get rid of cartstate
best_action, _ = self.mcts.get_action(s, d=10, n=100, deterministic=True) # get single action and prob
next_state, reward, terminated, truncated, info = self.env.step(np.array([[best_action]]))
states.append(state)
rewards.append(reward)
done = terminated or truncated

# mcts returns actions, counts
actions, norm_counts = self.mcts.get_policy(s)
mcts_counts.append(norm_counts)
mcts_actions.append(actions)
Expand All @@ -108,18 +99,25 @@ def mcts_rollout(self, max_steps=1000, deterministic=False):
return states, rewards, mcts_states, mcts_actions, mcts_counts

def train(self, max_iters=1000, n_episodes=10, n_steps=30):
start = time.time()
for i in range(max_iters):
# collect data using mcts
# collect data
for _ in range(n_episodes):
states, rewards, mcts_states, mcts_actions, mcts_counts = self.mcts_rollout(n_steps)
returns = self._compute_return(states)

episode_dict = TensorDict( {
"states": states,
"returns": returns,
"mcts_states": np.array(mcts_states),
"mcts_actions": np.array(mcts_actions),
"mcts_counts": np.array(mcts_counts),
# update normalizers with new data
self.model.update_normalizers(
torch.FloatTensor(states).to(self.device),
torch.FloatTensor(returns).to(self.device)
)

episode_dict = TensorDict({
"states": states,
"returns": returns,
"mcts_states": np.array(mcts_states),
"mcts_actions": np.array(mcts_actions),
"mcts_counts": np.array(mcts_counts),
}, batch_size=len(states))
self.replay_buffer.extend(episode_dict)

@@ -134,12 +132,7 @@ def train(self, max_iters=1000, n_episodes=10, n_steps=30):
loss.backward()
self.optimizer.step()

# evaluate current actor net
eps_rewards = []
for _ in range(5):
rewards = rollout(self.model.actor, self.env, max_steps=n_steps, deterministic=True)
eps_rewards.append(sum(rewards))
avg_reward = np.mean(eps_rewards)
avg_reward = np.mean(evaluate(self.model.actor, self.env))

# log metrics
self.hist['iter'].append(i)
@@ -149,13 +142,20 @@ def train(self, max_iters=1000, n_episodes=10, n_steps=30):
self.hist['total_loss'].append(loss.item())

print(f"actor loss {policy_loss.item():.3f} value loss {value_loss.item():.3f} l2 loss {l2_loss.item():.3f}")
print(f"iter {i}, reward {avg_reward:.3f}, t {time.time()-self.start:.2f}")
print(f"iter {i}, reward {avg_reward:.3f}, t {time.time()-start:.2f}")

print(f"Total time: {time.time() - self.start}")
print(f"Total time: {time.time() - start}")
return self.model, self.hist

def rollout(model, env, max_steps=1000, deterministic=False):
state, _ = env.reset()

def evaluate(model, env, n_episodes=10, n_steps=200, seed=42):
eps_rewards = []
for i in range(n_episodes):
rewards = rollout(model, env, max_steps=n_steps, seed=seed+i, deterministic=True)
eps_rewards.append(sum(rewards))
return eps_rewards

def rollout(model, env, max_steps=1000, seed=42, deterministic=False):
state, _ = env.reset(seed=seed)
rewards = []
for _ in range(max_steps):
state_tensor = torch.FloatTensor(state).unsqueeze(0)
@@ -195,13 +195,13 @@ def plot_losses(hist, save_path=None):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--max_iters", type=int, default=30)
parser.add_argument("--max_iters", type=int, default=10)
parser.add_argument("--n_eps", type=int, default=10)
parser.add_argument("--n_steps", type=int, default=30)
parser.add_argument("--env_bs", type=int, default=1) # TODO: batch
parser.add_argument("--save", default=True)
parser.add_argument("--noise_mode", default=None)
parser.add_argument("--debug", default=False)
parser.add_argument("--debug", default=True)
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()

@@ -212,16 +212,16 @@ def plot_losses(hist, save_path=None):
best_model, hist = a0c.train(args.max_iters, args.n_eps, args.n_steps)

# run value net online planner
# from run_mcts import run_mcts
# env = gym.make("CartLatAccel-v1", render_mode="human", noise_mode=args.noise_mode)
# reward = run_mcts(a0c.mcts, env, max_steps=200, search_depth=10, n_sims=100, seed=args.seed)
# print(f"reward {reward}")
from run_mcts import run_mcts
env = gym.make("CartLatAccel-v1", render_mode="human", noise_mode=args.noise_mode)
reward = run_mcts(a0c.mcts, env, max_steps=200, search_depth=10, n_sims=100, seed=args.seed)
print(f"mcts reward {reward}")

# run actor net model
print("rollout out best actor")
env = gym.make("CartLatAccel-v1", noise_mode=args.noise_mode, env_bs=1, render_mode="human")
rewards = rollout(best_model.actor, env, max_steps=200, deterministic=True)
print(f"reward {sum(rewards)}")
print(f"actor reward {sum(rewards)}")

if args.save:
os.makedirs('out', exist_ok=True)
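
The value_loss change above leans on the critic's return_normalizer: targets are normalized before the L1 loss, while the critic's default forward denormalizes its prediction for use inside MCTS. A standalone sketch of that round trip with made-up returns (the numbers are illustrative only):

import torch
from networks.utils import Normalizer

return_normalizer = Normalizer(1)                              # value target is a scalar
returns = torch.linspace(0.0, 50.0, steps=64).reshape(-1, 1)   # fake Monte-Carlo returns
return_normalizer.update(returns)

targets = return_normalizer.norm(returns)            # what the L1 loss compares against
print(targets.mean().item(), targets.std().item())   # roughly 0 and 1
recovered = return_normalizer.denorm(targets)        # what MCTS consumes via critic(states)
assert torch.allclose(recovered, returns, atol=1e-3)
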