46 changes: 46 additions & 0 deletions machine/models/Receiver.py
@@ -0,0 +1,46 @@
import torch
import torch.nn as nn

from .baseRNN import BaseRNN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Receiver(BaseRNN):
"""
    Applies an RNN to a message produced by a Sender.
Args:
vocab_size (int): size of the vocabulary
embedding_size (int): the size of the embedding of input variables
hidden_size (int): the size of the hidden dimension of the rnn
rnn_cell (str, optional): type of RNN cell (default: gru)
    Inputs:
        m (torch.Tensor): the message produced by the Sender. During training,
            a batch of (relaxed) one-hot vectors over the vocabulary, shape
            [batch_size, max_seq_len, vocab_size]; during evaluation, integer
            token indices, shape [batch_size, max_seq_len]
    Outputs:
        output (torch.Tensor): the hidden states of all timesteps, shape
            [batch_size, max_seq_len, hidden_size]
        state (torch.Tensor): (h, c) of the last timestep if LSTM, h if GRU
"""

def __init__(self, vocab_size, embedding_size, hidden_size, rnn_cell='gru'):
super().__init__(vocab_size, -1, hidden_size,
input_dropout_p=0, dropout_p=0,
n_layers=1, rnn_cell=rnn_cell)

self.rnn = self.rnn_cell(embedding_size, hidden_size, num_layers=1, batch_first=True)
self.embedding = nn.Parameter(torch.empty((vocab_size, embedding_size), dtype=torch.float32))

self.reset_parameters()

def reset_parameters(self):
nn.init.normal_(self.embedding, 0.0, 0.1)

nn.init.xavier_uniform_(self.rnn.weight_ih_l0)
nn.init.orthogonal_(self.rnn.weight_hh_l0)
nn.init.constant_(self.rnn.bias_ih_l0, val=0)
# cuDNN bias order: https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t
        # add some positive bias for the forget gate: the gate order is [b_i, b_f, b_g, b_o], so set [0, 1, 0, 0]
nn.init.constant_(self.rnn.bias_hh_l0, val=0)
nn.init.constant_(self.rnn.bias_hh_l0[self.hidden_size:2 * self.hidden_size], val=1)
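        # note: this slice is the LSTM forget gate; for a GRU (gate order [r, z, n])
        # the same slice is the update gate z, so the +1 bias initially favors
        # keeping the previous hidden state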

def forward(self, messages):
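        # training: messages are (relaxed) one-hot vectors, so a matmul with the
        # embedding matrix is a differentiable lookup; at eval time messages are
        # integer indices, so a plain index lookup suffices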
emb = torch.matmul(messages, self.embedding) if self.training else self.embedding[messages]
return self.rnn(emb)
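A minimal usage sketch (assuming the module is importable as in the tests below): in train mode the Receiver expects (relaxed) one-hot message vectors, while in eval mode it expects integer token indices.

import torch
from machine.models.Receiver import Receiver

receiver = Receiver(vocab_size=4, embedding_size=8, hidden_size=16, rnn_cell='gru')

# train mode: (relaxed) one-hot messages of shape [batch_size, seq_len, vocab_size]
receiver.train()
soft_m = torch.softmax(torch.randn(2, 5, 4), dim=-1)
outputs, h = receiver(soft_m)  # outputs: [2, 5, 16], h: [1, 2, 16]

# eval mode: integer token indices of shape [batch_size, seq_len]
receiver.eval()
hard_m = torch.randint(high=4, size=(2, 5))
outputs, h = receiver(hard_m)  # same output shapes as above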
73 changes: 73 additions & 0 deletions test/test_receiver.py
@@ -0,0 +1,73 @@
import unittest

import torch

from machine.models.Receiver import Receiver

class TestReceiver(unittest.TestCase):

@classmethod
    def setUpClass(cls):
        cls.vocab_size = 4
        cls.embedding_size = 8
        cls.hidden_size = 16
        cls.seq_len = 5

def test_lstm_train(self):
receiver = Receiver(self.vocab_size, self.embedding_size,
self.hidden_size, rnn_cell='lstm')

batch_size = 2
m = torch.rand([batch_size, self.seq_len, self.vocab_size])

receiver.train()

        outputs, (h, c) = receiver(m)

self.assertEqual(outputs.shape, (batch_size, self.seq_len, self.hidden_size))
self.assertEqual(h.squeeze().shape, (batch_size, self.hidden_size))
self.assertEqual(c.squeeze().shape, (batch_size, self.hidden_size))

def test_gru_train(self):
receiver = Receiver(self.vocab_size, self.embedding_size,
self.hidden_size, rnn_cell='gru')

batch_size = 2
m = torch.rand([batch_size, self.seq_len, self.vocab_size])

receiver.train()

outputs, h = receiver(m)

self.assertEqual(outputs.shape, (batch_size, self.seq_len, self.hidden_size))
self.assertEqual(h.squeeze().shape, (batch_size, self.hidden_size))

def test_lstm_eval(self):
receiver = Receiver(self.vocab_size, self.embedding_size,
self.hidden_size, rnn_cell='lstm')

batch_size = 2
m = torch.randint(high=self.vocab_size, size=(batch_size, self.seq_len))

receiver.eval()

        outputs, (h, c) = receiver(m)

self.assertEqual(outputs.shape, (batch_size, self.seq_len, self.hidden_size))
self.assertEqual(h.squeeze().shape, (batch_size, self.hidden_size))
self.assertEqual(c.squeeze().shape, (batch_size, self.hidden_size))

def test_gru_eval(self):
receiver = Receiver(self.vocab_size, self.embedding_size,
self.hidden_size, rnn_cell='gru')

batch_size = 2
m = torch.randint(high=self.vocab_size, size=(batch_size, self.seq_len))

receiver.eval()

outputs, h = receiver(m)

self.assertEqual(outputs.shape, (batch_size, self.seq_len, self.hidden_size))
self.assertEqual(h.squeeze().shape, (batch_size, self.hidden_size))
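These tests can be run with the standard unittest runner from the repository root, e.g.:

python -m unittest test.test_receiver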