
Commit 992c941

wadekarg and claude committed
Add DQN, Neural REINFORCE, and A2C algorithms to CartPole RL lab
Ports all three deep RL algorithms from Jordan Lei's cartpole repo to TypeScript:

- nnUtils.ts: shared neural net primitives (linear, relu, softmax, He init, Adam optimizer, backprop helpers)
- dqn.ts: DQN with 4→128→64→2 network, experience replay (5K buffer), target network (sync every 50 steps), per-step ε decay, gradient clipping
- neuralReinforce.ts: Neural REINFORCE with 4→128→2 softmax policy, normalized returns, Adam
- a2c.ts: A2C with separate Actor (4→128→2) and Critic (4→128→1) networks, advantage = G_t − V(s_t)

UI updates:

- ClassicCartPolePage: 6 algorithm tabs (random, Q-learning, REINFORCE, DQN, Neural REINFORCE, A2C) with per-algorithm hyperparameter sliders
- ClassicStepBreakdownPanel: new views for DQN (Q-values, buffer status, Bellman eq), Neural REINFORCE (probs + return normalization), A2C (actor probs + critic value + advantage)
- classicCartpoleExplainer: descriptions for all 3 new algorithms
- classicCartpoleStepBreakdown: DQN/NeuralReinforce/A2C breakdown types and compute functions

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 178b967 commit 992c941
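
The DQN described in the message is not part of the diff shown below (only a2c.ts is). For context, a minimal sketch of the Bellman target it references, with a frozen target network re-synced every 50 steps per the message; the function name and signature here are illustrative, not the actual dqn.ts API:

// Illustrative only — the real implementation lives in dqn.ts, which is not shown here.
// y = r for terminal transitions, else y = r + γ · max_a' Q_target(s', a').
function bellmanTarget(
  qTarget: (s: number[]) => number[], // frozen target network: state → Q-values [2]
  reward: number,
  nextState: number[],
  done: boolean,
  gamma = 0.99,
): number {
  return done ? reward : reward + gamma * Math.max(...qTarget(nextState))
}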

8 files changed: 1,568 additions & 11 deletions


a2c.ts: 220 additions & 0 deletions
@@ -0,0 +1,220 @@
/**
 * a2c.ts — Advantage Actor-Critic (A2C) for Classic CartPole.
 *
 * Ported from Jordan Lei's Actor-Critic implementation.
 *
 * Architecture:
 *   Actor (policy):  4 → 128 → 2 (softmax — probability of left/right)
 *   Critic (value):  4 → 128 → 1 (linear — estimates V(s))
 *
 * Update (per episode):
 *   advantage_t = G_t − V(s_t)
 *   actor_loss  = mean(−log π(a_t|s_t) × advantage_t.detach)
 *   critic_loss = 0.0005 × mean(advantage_t²)
 *
 * Both actor and critic use separate Adam optimizers.
 */

import type { Agent } from '../types'
import type { ClassicCartPoleState, ClassicCartPoleAction } from '../../environments/classicCartpole'
import {
  linear,
  relu,
  softmax,
  heInit,
  zerosInit,
  adamInit,
  adamUpdate,
  linearBackward,
  accumWeightGrad,
  accumBiasGrad,
  sampleCategorical,
  AdamState,
} from './nnUtils'

// ─── Trajectory record ────────────────────────────────────────────────────────

interface A2CStep {
  x: number[]        // state [4]
  // Actor activations
  actH1: number[]    // post-ReLU [128]
  actPre1: number[]  // pre-ReLU [128]
  probs: number[]    // π(·|s) [2]
  action: number
  // Critic activations
  criH1: number[]    // post-ReLU [128]
  criPre1: number[]  // pre-ReLU [128]
  value: number      // V(s)
  reward: number
}

// ─── A2CAgent ─────────────────────────────────────────────────────────────────

export class A2CAgent implements Agent<ClassicCartPoleState, ClassicCartPoleAction> {
  // Actor weights: 4 → 128 → 2
  private actW1: number[] // [128 * 4]
  private actB1: number[] // [128]
  private actW2: number[] // [2 * 128]
  private actB2: number[] // [2]

  // Critic weights: 4 → 128 → 1
  private criW1: number[] // [128 * 4]
  private criB1: number[] // [128]
  private criW2: number[] // [1 * 128]
  private criB2: number[] // [1]

  // Adam states
  private adamActW1: AdamState; private adamActB1: AdamState
  private adamActW2: AdamState; private adamActB2: AdamState
  private adamCriW1: AdamState; private adamCriB1: AdamState
  private adamCriW2: AdamState; private adamCriB2: AdamState

  private readonly lr: number
  private readonly gamma: number
  private readonly criticScale: number // = 0.0005 from Jordan's impl

  private trajectory: A2CStep[] = []
  private lastProbs = [0.5, 0.5]
  private lastValue = 0

  constructor(lr = 0.005, gamma = 0.99, criticScale = 0.0005) {
    this.lr = lr
    this.gamma = gamma
    this.criticScale = criticScale
    this.actW1 = heInit(128, 4); this.actB1 = zerosInit(128)
    this.actW2 = heInit(2, 128); this.actB2 = zerosInit(2)
    this.criW1 = heInit(128, 4); this.criB1 = zerosInit(128)
    this.criW2 = heInit(1, 128); this.criB2 = zerosInit(1)
    this.adamActW1 = adamInit(128 * 4); this.adamActB1 = adamInit(128)
    this.adamActW2 = adamInit(2 * 128); this.adamActB2 = adamInit(2)
    this.adamCriW1 = adamInit(128 * 4); this.adamCriB1 = adamInit(128)
    this.adamCriW2 = adamInit(1 * 128); this.adamCriB2 = adamInit(1)
  }

  private actorForward(x: number[]) {
    const actPre1 = linear(this.actW1, this.actB1, x, 128, 4)
    const actH1 = relu(actPre1)
    const logits = linear(this.actW2, this.actB2, actH1, 2, 128)
    const probs = softmax(logits)
    return { probs, actH1, actPre1 }
  }

  private criticForward(x: number[]) {
    const criPre1 = linear(this.criW1, this.criB1, x, 128, 4)
    const criH1 = relu(criPre1)
    const valueArr = linear(this.criW2, this.criB2, criH1, 1, 128)
    return { value: valueArr[0], criH1, criPre1 }
  }

  act(state: ClassicCartPoleState): ClassicCartPoleAction {
    const x = [state.x, state.xDot, state.theta, state.thetaDot]
    const { probs } = this.actorForward(x)
    const { value } = this.criticForward(x)
    this.lastProbs = probs
    this.lastValue = value
    return sampleCategorical(probs) as ClassicCartPoleAction
  }

  learn(
    state: ClassicCartPoleState,
    action: ClassicCartPoleAction,
    reward: number,
    _nextState: ClassicCartPoleState,
    done: boolean,
  ): void {
    const x = [state.x, state.xDot, state.theta, state.thetaDot]
    const { probs, actH1, actPre1 } = this.actorForward(x)
    const { value, criH1, criPre1 } = this.criticForward(x)

    this.trajectory.push({ x, actH1, actPre1, probs, action, criH1, criPre1, value, reward })

    if (!done) return

    // ── Episode ended: compute returns and advantages ──────────────────────

    const T = this.trajectory.length
    const G = new Array<number>(T)
    let g = 0
    for (let t = T - 1; t >= 0; t--) {
      g = this.trajectory[t].reward + this.gamma * g
      G[t] = g
    }
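    // Worked example (illustrative numbers, not from the source): with rewards
    // [1, 1, 1] and gamma = 0.99, the loop above yields
    //   G[2] = 1
    //   G[1] = 1 + 0.99 · 1    = 1.99
    //   G[0] = 1 + 0.99 · 1.99 = 2.9701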

    // advantage_t = G_t - V(s_t)
    const advantages = this.trajectory.map((step, t) => G[t] - step.value)

    // ── Actor gradients ────────────────────────────────────────────────────
    const dActW1 = zerosInit(128 * 4); const dActB1 = zerosInit(128)
    const dActW2 = zerosInit(2 * 128); const dActB2 = zerosInit(2)

    for (let t = 0; t < T; t++) {
      const { x: xt, actH1, actPre1, probs: pt, action: at } = this.trajectory[t]
      const adv = advantages[t] / T // average over episode (detached from critic)

      // dL_actor/dlogits[j] = adv * (probs[j] - I(j==action))
      const dLogits = pt.map((p, j) => adv * (p - (j === at ? 1 : 0)))
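      // Why this form: with L_t = −adv · log softmax(logits)[a_t], the standard
      // softmax/log-likelihood derivative gives ∂L_t/∂logits_j =
      // adv · (softmax(logits)_j − 1{j = a_t}); adv above already carries the
      // 1/T factor from the episode mean.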

      const dActH1 = linearBackward(this.actW2, dLogits, 2, 128)
      accumWeightGrad(dActW2, dLogits, actH1, 2, 128)
      accumBiasGrad(dActB2, dLogits)

      const dActPre1 = dActH1.map((v, i) => (actPre1[i] > 0 ? v : 0))
      accumWeightGrad(dActW1, dActPre1, xt, 128, 4)
      accumBiasGrad(dActB1, dActPre1)
    }

    // ── Critic gradients ───────────────────────────────────────────────────
    const dCriW1 = zerosInit(128 * 4); const dCriB1 = zerosInit(128)
    const dCriW2 = zerosInit(1 * 128); const dCriB2 = zerosInit(1)

    for (let t = 0; t < T; t++) {
      const { x: xt, criH1, criPre1 } = this.trajectory[t]
      const adv = advantages[t]

      // L_critic = criticScale * mean((G_t - V)^2), so dL/dV = -2 * criticScale * advantage / T
      const dOut1 = [-2 * this.criticScale * adv / T]

      const dCriH1 = linearBackward(this.criW2, dOut1, 1, 128)
      accumWeightGrad(dCriW2, dOut1, criH1, 1, 128)
      accumBiasGrad(dCriB2, dOut1)

      const dCriPre1 = dCriH1.map((v, i) => (criPre1[i] > 0 ? v : 0))
      accumWeightGrad(dCriW1, dCriPre1, xt, 128, 4)
      accumBiasGrad(dCriB1, dCriPre1)
    }

    // ── Adam updates ───────────────────────────────────────────────────────
    adamUpdate(this.actW1, dActW1, this.adamActW1, this.lr)
    adamUpdate(this.actB1, dActB1, this.adamActB1, this.lr)
    adamUpdate(this.actW2, dActW2, this.adamActW2, this.lr)
    adamUpdate(this.actB2, dActB2, this.adamActB2, this.lr)

    adamUpdate(this.criW1, dCriW1, this.adamCriW1, this.lr)
    adamUpdate(this.criB1, dCriB1, this.adamCriB1, this.lr)
    adamUpdate(this.criW2, dCriW2, this.adamCriW2, this.lr)
    adamUpdate(this.criB2, dCriB2, this.adamCriB2, this.lr)

    this.trajectory = []
  }

  getValues(): Record<string, number[]> {
    return {
      probs: [...this.lastProbs],
      value: [this.lastValue],
    }
  }

  reset(): void {
    this.actW1 = heInit(128, 4); this.actB1 = zerosInit(128)
    this.actW2 = heInit(2, 128); this.actB2 = zerosInit(2)
    this.criW1 = heInit(128, 4); this.criB1 = zerosInit(128)
    this.criW2 = heInit(1, 128); this.criB2 = zerosInit(1)
    this.adamActW1 = adamInit(128 * 4); this.adamActB1 = adamInit(128)
    this.adamActW2 = adamInit(2 * 128); this.adamActB2 = adamInit(2)
    this.adamCriW1 = adamInit(128 * 4); this.adamCriB1 = adamInit(128)
    this.adamCriW2 = adamInit(1 * 128); this.adamCriB2 = adamInit(1)
    this.trajectory = []
    this.lastProbs = [0.5, 0.5]
    this.lastValue = 0
  }
}
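
For readers trying the port out: a minimal sketch of driving A2CAgent through episodes. The env object and its reset()/step() shape are assumptions for illustration only — the actual classicCartpole environment API may differ.

// Hypothetical training loop — the env shape below is assumed, not from this repo.
import { A2CAgent } from './a2c'
import type { ClassicCartPoleState, ClassicCartPoleAction } from '../../environments/classicCartpole'

declare const env: {
  reset(): ClassicCartPoleState
  step(a: ClassicCartPoleAction): { state: ClassicCartPoleState; reward: number; done: boolean }
}

const agent = new A2CAgent() // defaults: lr = 0.005, gamma = 0.99, criticScale = 0.0005
for (let episode = 0; episode < 500; episode++) {
  let state = env.reset()
  let done = false
  while (!done) {
    const action = agent.act(state)                 // sample from π(·|s)
    const { state: next, reward, done: d } = env.step(action)
    agent.learn(state, action, reward, next, d)     // buffers the step; full A2C update on done
    state = next
    done = d
  }
}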
