-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathautoencoder.py
More file actions
52 lines (39 loc) · 1.26 KB
/
autoencoder.py
File metadata and controls
52 lines (39 loc) · 1.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class AutoEncoder:
    """Autoencoder composed of an encoder, an optional latent stage and a decoder.

    Each component is duck-typed: ``encoder`` and ``decoder`` must expose
    ``forward``, ``backward`` and a ``layers`` list; the optional ``latent``
    component must expose ``forward`` and ``backward`` (assumed single-argument,
    TODO confirm against the latent implementations used with this class).
    """

    def __init__(self, encoder, decoder, latent=None):
        """Store the model components.

        Args:
            encoder: Component whose ``forward`` maps the input to a code.
            decoder: Component whose ``forward`` maps a code to the output.
            latent: Optional component applied between encoder and decoder.
                Defaults to ``None``, meaning no intermediate stage.
        """
        self.encoder = encoder
        self.decoder = decoder
        # ``backward`` previously read ``self.latent`` without it ever being
        # assigned, raising AttributeError on every call.
        self.latent = latent

    def forward(self, x):
        """Run ``x`` through encoder, optional latent stage, then decoder.

        Returns:
            Whatever ``decoder.forward`` returns.
        """
        z = self.encoder.forward(x)
        # Keep the forward pass symmetric with ``backward``, which routes
        # the gradient through ``latent.backward`` when a latent is present.
        if self.latent is not None:
            z = self.latent.forward(z)
        return self.decoder.forward(z)

    def backward(self, grad):
        """Propagate ``grad`` back through decoder, optional latent, encoder.

        Args:
            grad: Gradient with respect to the output of ``forward``.

        Returns:
            The gradient returned by ``encoder.backward``.
        """
        grad = self.decoder.backward(grad)
        # ``is not None`` rather than truthiness: a latent object that
        # defines ``__len__``/``__bool__`` must not be skipped by accident.
        if self.latent is not None:
            grad = self.latent.backward(grad)
        return self.encoder.backward(grad)

    @property
    def layers(self):
        """Return a new list of the encoder's layers followed by the decoder's."""
        return list(self.encoder.layers) + list(self.decoder.layers)
class VariationalAutoEncoder:
    """Variational autoencoder: encoder -> latent sampler -> decoder.

    The encoder's ``forward`` returns a ``(mu, logvar)`` pair; the ``latent``
    component turns that pair into a code ``z`` (presumably by sampling with
    the reparameterization trick; this depends on the latent implementation),
    and the decoder maps ``z`` to the output.
    """

    def __init__(self, encoder, decoder, latent):
        """Store the model components.

        Args:
            encoder: Component whose ``forward`` returns ``(mu, logvar)`` and
                whose ``backward`` accepts ``(dmu, dlogvar)``.
            decoder: Component with ``forward``/``backward`` over the code.
            latent: Component whose ``forward(mu, logvar)`` yields ``z``,
                whose ``backward(grad)`` returns ``(dmu, dlogvar)``, and which
                provides ``kl_divergence()``.
        """
        self.encoder = encoder
        self.decoder = decoder
        self.latent = latent

    def forward(self, x):
        """Encode ``x``, obtain ``z`` from the latent stage, and decode it."""
        mu, logvar = self.encoder.forward(x)
        z = self.latent.forward(mu, logvar)
        return self.decoder.forward(z)

    def backward(self, grad):
        """Propagate ``grad`` back through decoder, latent stage and encoder.

        Args:
            grad: Gradient with respect to the output of ``forward``.

        Returns:
            The gradient returned by ``encoder.backward``. The original code
            discarded this and returned ``None``, so callers could not chain
            gradients the way they can with ``AutoEncoder.backward``.
        """
        grad = self.decoder.backward(grad)
        # NOTE(review): no KL-term gradient is added here; presumably the
        # latent component folds it into (dmu, dlogvar). Confirm.
        dmu, dlogvar = self.latent.backward(grad)
        return self.encoder.backward(dmu, dlogvar)

    def kl_loss(self):
        """Return the value reported by ``latent.kl_divergence()``."""
        return self.latent.kl_divergence()

    @property
    def layers(self):
        """Return a new list of the encoder's layers followed by the decoder's."""
        return self.encoder.layers + self.decoder.layers