-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathoptimizers.h
More file actions
45 lines (38 loc) · 1.14 KB
/
optimizers.h
File metadata and controls
45 lines (38 loc) · 1.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
/* Optimizers. */
#ifndef OPTIMIZERS_H
#define OPTIMIZERS_H
#include "xtensor/xarray.hpp"
/* Optimization Base Class.
 *
 * Abstract interface for parameter-update rules. Concrete optimizers
 * implement update(), which receives the weights by non-const reference
 * and the gradient by const reference.
 */
class Optimizer {
protected:
    // Step size for derived update rules.
    double learning_rate;
public:
    /* Constructs the optimizer with the given learning rate.
     * Defined out of line; presumably initializes `learning_rate`. */
    Optimizer(double learning_rate);

    // Virtual so that destroying a derived optimizer through an Optimizer*
    // (e.g. std::unique_ptr<Optimizer>) also runs the derived destructor and
    // frees its xarray state; without it that delete is undefined behavior.
    virtual ~Optimizer() = default;

    // Declaring the destructor would otherwise suppress the implicit move
    // operations and deprecate the implicit copies; default all of them so
    // copy/move behavior is exactly what it was before (Rule of Five).
    Optimizer(const Optimizer &) = default;
    Optimizer &operator=(const Optimizer &) = default;
    Optimizer(Optimizer &&) = default;
    Optimizer &operator=(Optimizer &&) = default;

    /* Apply one optimization step.
     * @param weights parameters to modify in place.
     * @param grad    gradient with respect to `weights` (presumably the same
     *                shape — confirm against the implementations).
     */
    virtual void update(xt::xarray<double> &weights, const xt::xarray<double> &grad) = 0;
};
class SGDOptimizer: public Optimizer {
public:
using Optimizer::Optimizer;
void update(xt::xarray<double> &weights, const xt::xarray<double> &grad);
};
class MomentumOptimizer: public Optimizer {
private:
double momentum;
xt::xarray<double> velocity;
public:
MomentumOptimizer(double learning_rate, double momentum=0.9);
void update(xt::xarray<double> &weights, const xt::xarray<double> &grad);
};
/* Based on Kingma et al., 2014 */
class AdamOptimizer: public Optimizer {
private:
xt::xarray<double> first_moment;
xt::xarray<double> second_moment;
int t = 0;
double beta1, beta2;
double eps;
public:
AdamOptimizer(double learning_rate, double beta1=0.9, double beta2=0.999, double eps=1e-8);
void update(xt::xarray<double> &weights, const xt::xarray<double> &grad);
};
#endif