import torch
import torch.nn as nn

from .baseRNN import BaseRNN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Receiver(BaseRNN):
    """
    Applies an RNN to a message produced by a Sender.

    Args:
        vocab_size (int): size of the vocabulary
        embedding_size (int): the size of the embedding of input variables
        hidden_size (int): the size of the hidden dimension of the rnn
        rnn_cell (str, optional): type of RNN cell (default: gru)

    Inputs:
        messages (torch.tensor): the message produced by the Sender.
            In training mode this is a relaxed (float, distribution-over-vocab)
            encoding of shape [batch_size, max_seq_len, vocab_size]; in eval
            mode it is a LongTensor of token indices of shape
            [batch_size, max_seq_len].

    Outputs:
        output (torch.tensor): the batch of appended hidden states at each
            timestep, shape [batch_size, max_seq_len, hidden_size].
        state (torch.tensor): (h, c) of the last timestep if LSTM, h if GRU.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, rnn_cell='gru'):
        # max_len is irrelevant for a receiver, so a dummy value (-1) is passed
        # to the BaseRNN constructor.
        super().__init__(vocab_size, -1, hidden_size,
                         input_dropout_p=0, dropout_p=0,
                         n_layers=1, rnn_cell=rnn_cell)

        self.rnn = self.rnn_cell(embedding_size, hidden_size,
                                 num_layers=1, batch_first=True)
        # A raw Parameter (rather than nn.Embedding) so the training path can
        # take a differentiable matrix product with relaxed message
        # distributions; eval mode indexes it directly.
        self.embedding = nn.Parameter(
            torch.empty((vocab_size, embedding_size), dtype=torch.float32))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the embedding matrix and the RNN weights/biases."""
        nn.init.normal_(self.embedding, 0.0, 0.1)

        nn.init.xavier_uniform_(self.rnn.weight_ih_l0)
        nn.init.orthogonal_(self.rnn.weight_hh_l0)
        nn.init.constant_(self.rnn.bias_ih_l0, val=0)
        # PyTorch/cuDNN LSTM bias layout is [b_i, b_f, b_g, b_o]
        # (see https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t),
        # so the second quarter is the forget gate: give it a positive bias so
        # the cell remembers by default -> biases become [0, 1, 0, 0].
        # NOTE(review): for a GRU (bias layout [b_r, b_z, b_n]) this same slice
        # is the update-gate bias instead -- confirm that is intended.
        nn.init.constant_(self.rnn.bias_hh_l0, val=0)
        nn.init.constant_(
            self.rnn.bias_hh_l0[self.hidden_size:2 * self.hidden_size], val=1)

    def forward(self, messages):
        # Training: messages are relaxed distributions over the vocabulary, so
        # embed them with a differentiable matmul. Eval: messages are discrete
        # token indices, so a plain lookup suffices.
        if self.training:
            emb = torch.matmul(messages, self.embedding)
        else:
            emb = self.embedding[messages]
        return self.rnn(emb)
import unittest

import torch

from machine.models.Receiver import Receiver


class TestReceiver(unittest.TestCase):
    """Shape checks for Receiver in train (relaxed) and eval (discrete) modes,
    for both LSTM and GRU cells."""

    @classmethod
    def setUpClass(cls):
        # Small dimensions keep the tests fast.
        cls.vocab_size = 4
        cls.embedding_size = 8
        cls.hidden_size = 16
        cls.seq_len = 5

    def _make_receiver(self, rnn_cell):
        # One-line factory so every test builds the model the same way.
        return Receiver(self.vocab_size, self.embedding_size,
                        self.hidden_size, rnn_cell=rnn_cell)

    def _relaxed_message(self, batch_size):
        # Training-mode input: float distribution over the vocabulary.
        return torch.rand([batch_size, self.seq_len, self.vocab_size])

    def _discrete_message(self, batch_size):
        # Eval-mode input: integer token indices.
        return torch.randint(high=self.vocab_size,
                             size=(batch_size, self.seq_len))

    def test_lstm_train(self):
        receiver = self._make_receiver('lstm')

        batch_size = 2
        m = self._relaxed_message(batch_size)

        receiver.train()

        outputs, (h, c) = receiver(m)

        self.assertEqual(outputs.shape,
                         (batch_size, self.seq_len, self.hidden_size))
        self.assertEqual(h.squeeze().shape, (batch_size, self.hidden_size))
        self.assertEqual(c.squeeze().shape, (batch_size, self.hidden_size))

    def test_gru_train(self):
        receiver = self._make_receiver('gru')

        batch_size = 2
        m = self._relaxed_message(batch_size)

        receiver.train()

        outputs, h = receiver(m)

        self.assertEqual(outputs.shape,
                         (batch_size, self.seq_len, self.hidden_size))
        self.assertEqual(h.squeeze().shape, (batch_size, self.hidden_size))

    def test_lstm_eval(self):
        receiver = self._make_receiver('lstm')

        batch_size = 2
        m = self._discrete_message(batch_size)

        receiver.eval()

        outputs, (h, c) = receiver(m)

        self.assertEqual(outputs.shape,
                         (batch_size, self.seq_len, self.hidden_size))
        self.assertEqual(h.squeeze().shape, (batch_size, self.hidden_size))
        self.assertEqual(c.squeeze().shape, (batch_size, self.hidden_size))

    def test_gru_eval(self):
        receiver = self._make_receiver('gru')

        batch_size = 2
        m = self._discrete_message(batch_size)

        receiver.eval()

        outputs, h = receiver(m)

        self.assertEqual(outputs.shape,
                         (batch_size, self.seq_len, self.hidden_size))
        self.assertEqual(h.squeeze().shape, (batch_size, self.hidden_size))


if __name__ == "__main__":
    unittest.main()