Skip to content

Create model.py2#7

Open
elonmasai7 wants to merge 1 commit into
johnma2006:masterfrom
elonmasai7:patch-1
Open

Create model.py2#7
elonmasai7 wants to merge 1 commit into
johnma2006:masterfrom
elonmasai7:patch-1

Conversation

@elonmasai7
Copy link
Copy Markdown

another view
import math
import torch
from torch import nn
from torch.nn.utils import weight_norm

class Mamba(nn.Module):
def init(self, d_model, d_state, n_layers, d_inner, dropout=0.1):
super().init()

    # We don't want to learn position embeddings.
    # We'll do a simple positional encoding.
    # Note that we divide by sqrt(d_model), which you'll find across other Transformer implementations,
    # and serve the same purpose as with standard attention.
    # `to(torch.float32)` is there only because this code is intended to be seamlessly used with mixed precision.
    self.pos_enc = torch.arange(0, 64, dtype=torch.float32).view(1, -1).to(torch.float32) / math.sqrt(d_model)

    layers = []
    for _ in range(n_layers):
        layers.append(MambaLayer(d_model, d_state, d_inner, dropout=dropout))
    self.layers = nn.ModuleList(layers)

    # Final dense layers.
    self.fc = nn.Linear(d_model, 50257)

def forward(self, x, state_init=None):
    # The input has `l` sequences of length `L` and `b` batch size.
    # `x` has shape: `(l, b, L, d_model)`.
    # We assume the first dimension is the `l` sequence one.
    l, b, L, d = x.shape

    if state_init is None:
        state_init = torch.zeros(l, b, 1, d // 2, dtype=x.dtype, device=x.device)

    x = x + self.pos_enc[:L, None]
    states, outs = [], []

    for layer in self.layers:
        x, state = layer(x, state_init)
        states.append(state)
        # `outs` will eventually have shape `(l, b, L, d)`.
        outs.append(x)

    return self.fc(torch.cat(outs, dim=-1)), torch.cat(states, dim=-2)

class MambaLayer(nn.Module):
def init(self, d_model, d_state, d_inner, dropout=0.1):
super().init()
d_model_half = d_model // 2

    self.lin_A = nn.Linear(d_model, d_model_half)
    self.lin_D = nn.Linear(d_model, d_model_half)

    self.lin_in = nn.Linear(d_model, d_inner)
    self.lin_B1 = nn.Linear(d_inner, d_model_half)
    self.lin_B2 = nn.Linear(d_state, d_model_half)
    self.lin_C = weight_norm(nn.Linear(d_model_half, d_model_half))

    self.dropout = nn.Dropout(dropout)

def forward(self, x, state_init):
    # We output both the state AND the transformed sequence (`x`).
    # The `x` shape is expected to be `(l, b, L, d)`.
    # The `state_init` shape is expected to be `(l, b, 1, n)`.

    l, b, L, d = x.shape
    d_model_half = d // 2

    # We learned to use tanh activation for A and D.
    A = torch.tanh(self.lin_A(x))
    D = torch.tanh(self.lin_D(x))

    a = self.dropout(self.lin_in(x))
    b1 = self.lin_B1(a)
    b2 = self.dropout(self.lin_B2(state_init))
    B = b1 + b2
    c = self.lin_C(self.dropout(A * B))
    state = D * state_init + c[:, :, :, None]

    # It looks like state_init might be off by one timestep from A, B, C, D, but this is
    # not the case because we will start the loop on the 2nd timestep. It is perfectly
    # consistent with the equations of Mamba (see [1] Algorithm 2).
    # Intuitively, we also need to use `state_init` at time `t - 1` rather than `t` to compute
    # `x_t`. Indeed, `state_t - 1` is a consequence of `x_t - 1` and `u_t - 1`.
    # If we were to use `state_t`, this would be equivalent to having `δ_t = 1` instead of
    # `δ_t = 0`, which is the case under the "zero-input" assumption made by the authors
    # (see Equation (7) in [1]).
    x = A * B + C

    # We obtain a new state and a new output sequence `x`.
    return x, state

another view
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant