# This code is based on code from OpenAI baselines. (https://github.com/openai/baselines)
import tensorflow as tf


@tf.function
def huber_loss(x, delta=1.0):
    """Reference: https://en.wikipedia.org/wiki/Huber_loss"""
    return tf.where(
        tf.abs(x) < delta,
        tf.square(x) * 0.5,
        delta * (tf.abs(x) - 0.5 * delta)
    )


class DQfD(tf.Module):
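    """Deep Q-learning from Demonstrations (DQfD, Hester et al., 2018).

    The training loss combines four terms: the 1-step (double) DQN loss,
    an n-step return loss weighted by `lambda1`, a large-margin supervised
    loss on demonstration transitions weighted by `lambda2` (with expert
    margin `exp_margin`), and L2 regularization weighted by `lambda3`.
    """
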
    def __init__(self,
                 q_func,
                 observation_shape,
                 num_actions,
                 lr,
                 grad_norm_clipping=None,
                 gamma=0.99,
                 n_step=10,
                 exp_margin=0.8,
                 lambda1=1.0,
                 lambda2=1.0,
                 lambda3=1e-5,
                 double_q=True,
                 param_noise=False,
                 param_noise_filter_func=None):
        self.num_actions = num_actions
        self.gamma = gamma
        self.n_step = n_step
        self.exp_margin = exp_margin
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.lambda3 = lambda3
        self.double_q = double_q
        self.param_noise = param_noise
        self.param_noise_filter_func = param_noise_filter_func
        self.grad_norm_clipping = grad_norm_clipping
        self.optimizer = tf.keras.optimizers.Adam(lr)

        with tf.name_scope('q_network'):
            self.q_network = q_func(observation_shape, num_actions)
        with tf.name_scope('target_q_network'):
            self.target_q_network = q_func(observation_shape, num_actions)
        self.eps = tf.Variable(0., name="eps")

    @tf.function
    def step(self, obs, stochastic=True, update_eps=-1):
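        """Select actions for a batch of observations with epsilon-greedy
        exploration; `update_eps >= 0` overwrites the stored epsilon.
        The two trailing `None`s are unused placeholder return values."""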
        if self.param_noise:
            raise NotImplementedError('parameter-space noise is not supported yet')
        else:
            q_values = self.q_network(obs)
            deterministic_actions = tf.argmax(q_values, axis=1)
            # With probability eps, replace the greedy action by a uniformly random one.
            batch_size = tf.shape(obs)[0]
            random_actions = tf.random.uniform(tf.stack([batch_size]), minval=0, maxval=self.num_actions, dtype=tf.int64)
            chose_random = tf.random.uniform(tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < self.eps
            stochastic_actions = tf.where(chose_random, random_actions, deterministic_actions)

            if stochastic:
                output_actions = stochastic_actions
            else:
                output_actions = deterministic_actions

            if update_eps >= 0:
                self.eps.assign(update_eps)
            return output_actions, self.eps, None, None

    @tf.function
    def train(self, obs0, actions, rewards, obs1, dones, is_demos, importance_weights, obsn=None, rewards_n=None, dones_n=None):
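        """Run one gradient step on a batch of transitions.

        `obs0, actions, rewards, obs1, dones` form the 1-step transitions,
        `is_demos` flags demonstration transitions, `importance_weights`
        come from prioritized replay, and the optional `obsn, rewards_n,
        dones_n` carry the observation n steps ahead together with the
        discounted n-step return. Returns the per-sample TD errors (useful
        for priority updates) and the individual loss terms."""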
        with tf.GradientTape() as tape:
            # ====================1-step loss===================
            q_t = self.q_network(obs0)
            one_hot_actions = tf.one_hot(actions, self.num_actions, dtype=tf.float32)
            q_t_selected = tf.reduce_sum(q_t * one_hot_actions, 1)

            q_tp1 = self.target_q_network(obs1)
            if self.double_q:
                # Double DQN: pick the argmax action with the online network,
                # evaluate it with the target network.
                q_tp1_using_online_net = self.q_network(obs1)
                q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
                q_tp1_best = tf.reduce_sum(q_tp1 * tf.one_hot(q_tp1_best_using_online_net, self.num_actions, dtype=tf.float32), 1)
            else:
                q_tp1_best = tf.reduce_max(q_tp1, 1)

            dones = tf.cast(dones, q_tp1_best.dtype)
            q_tp1_best_masked = (1.0 - dones) * q_tp1_best
            q_t_selected_target = rewards + self.gamma * q_tp1_best_masked
            td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
            loss_dq = huber_loss(td_error)
            # ====================n-step loss===================
            if obsn is not None:
                q_tpn = self.target_q_network(obsn)
                if self.double_q:
                    q_tpn_using_online_net = self.q_network(obsn)
                    q_tpn_best_using_online_net = tf.argmax(q_tpn_using_online_net, 1)
                    q_tpn_best = tf.reduce_sum(q_tpn * tf.one_hot(q_tpn_best_using_online_net, self.num_actions, dtype=tf.float32), 1)
                else:
                    q_tpn_best = tf.reduce_max(q_tpn, 1)
                dones_n = tf.cast(dones_n, q_tpn_best.dtype)
                q_tpn_best_masked = (1.0 - dones_n) * q_tpn_best
                # rewards_n is the discounted n-step return, so the bootstrap
                # term is discounted by gamma ** n_step.
                q_tn_selected_target = rewards_n + (self.gamma ** self.n_step) * q_tpn_best_masked
                n_td_error = (q_t_selected - tf.stop_gradient(q_tn_selected_target)) * tf.cast(is_demos, q_tp1_best.dtype)
                loss_n = self.lambda1 * huber_loss(n_td_error)
            else:
                n_td_error = tf.constant(0.)
                loss_n = tf.constant(0.)
            # ==========large margin classification loss=========
            # J_E = max_a [Q(s, a) + l(a_E, a)] - Q(s, a_E), where the margin
            # l(a_E, a) is exp_margin for a != a_E and 0 for the expert action.
            is_demo = tf.cast(is_demos, q_tp1_best.dtype)
            margin_l = self.exp_margin * (tf.ones_like(one_hot_actions, dtype=tf.float32) - one_hot_actions)
            margin_masked = tf.reduce_max(q_t + margin_l, 1)
            loss_E = self.lambda2 * is_demo * (margin_masked - q_t_selected)

            # ==========L2 loss=========
            loss_l2 = self.lambda3 * tf.reduce_sum([tf.reduce_sum(tf.square(variables)) for variables in self.q_network.trainable_variables])

            all_loss = loss_n + loss_dq + loss_E
            weighted_error = tf.reduce_mean(importance_weights * all_loss) + loss_l2

        grads = tape.gradient(weighted_error, self.q_network.trainable_variables)
        if self.grad_norm_clipping:
            clipped_grads = []
            for grad in grads:
                clipped_grads.append(tf.clip_by_norm(grad, self.grad_norm_clipping))
            grads = clipped_grads
        grads_and_vars = zip(grads, self.q_network.trainable_variables)
        self.optimizer.apply_gradients(grads_and_vars)

        return td_error, n_td_error, loss_dq, loss_n, loss_E, loss_l2, weighted_error

    @tf.function(autograph=False)
    def update_target(self):
        # Copy the online network's weights into the target network.
        q_vars = self.q_network.trainable_variables
        target_q_vars = self.target_q_network.trainable_variables
        for var, var_target in zip(q_vars, target_q_vars):
            var_target.assign(var)
        # With autograph disabled, this print runs only while the function
        # is being traced, not on every call.
        print("target network update")