diff --git a/tome/patch/mae.py b/tome/patch/mae.py
index 8a9e842d..83cce7fd 100644
--- a/tome/patch/mae.py
+++ b/tome/patch/mae.py
@@ -15,7 +15,7 @@
 
 from tome.utils import parse_r
 
-from .timm import ToMeBlock, ToMeAttention
+from .timm import ToMeBlock, FlashAttnToMeAttention
 
 
 def make_tome_class(transformer_class):
@@ -100,4 +100,4 @@ def apply_patch(
             module.__class__ = ToMeBlock
             module._tome_info = model._tome_info
         elif isinstance(module, Attention):
-            module.__class__ = ToMeAttention
+            module.__class__ = FlashAttnToMeAttention
diff --git a/tome/patch/timm.py b/tome/patch/timm.py
index ae2b8fc8..5577db72 100644
--- a/tome/patch/timm.py
+++ b/tome/patch/timm.py
@@ -17,6 +17,9 @@
 from tome.merge import bipartite_soft_matching, merge_source, merge_wavg
 from tome.utils import parse_r
 
+import torch.nn.functional as F  # needed for the manual qkv projection below
+from flash_attn import flash_attn_qkvpacked_func
+
 
 class ToMeBlock(Block):
     """
@@ -96,6 +99,44 @@ def forward(
         return x, k.mean(1)
 
 
+class FlashAttnToMeAttention(Attention):
+    """
+    Modifications:
+     - Apply FlashAttention instead of materializing the attention matrix
+     - Do not apply proportional attention (MAE models)
+     - Return the mean of k over heads from attention
+    """
+
+    def forward(
+        self, x: torch.Tensor, size: torch.Tensor = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Note: this is copied from timm.models.vision_transformer.Attention with modifications.
+        B, N, C = x.shape
+        # BEiT/MAE-style blocks keep separate q_bias/v_bias parameters; plain timm
+        # Attention modules do not, so fall back to the packed qkv bias (may be None).
+        try:
+            qkv_bias = torch.cat(
+                (self.q_bias,
+                 torch.zeros_like(self.v_bias, requires_grad=False),
+                 self.v_bias))
+        except AttributeError:
+            qkv_bias = self.qkv.bias
+        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+        # flash_attn_qkvpacked_func expects qkv of shape (B, N, 3, num_heads, head_dim).
+        qkv = qkv.reshape(B, N, 3, self.num_heads, -1)
+        k = qkv.permute(2, 0, 3, 1, 4)[1]
+
+        x = flash_attn_qkvpacked_func(qkv, dropout_p=0.0,
+                                      softmax_scale=self.scale,
+                                      causal=False)
+        x = x.reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        # Return k as well here
+        return x, k.mean(1)
+
+
 def make_tome_class(transformer_class):
     class ToMeVisionTransformer(transformer_class):
         """
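
Outside the patch itself, here is a minimal shape-check sketch (plain PyTorch on CPU, no flash-attn or GPU required) of the packed-qkv layout the new attention relies on: flash_attn_qkvpacked_func consumes qkv shaped (B, N, 3, num_heads, head_dim), and the same packed tensor provides the head-averaged k that ToMe's merging metric uses. The unfused attention at the end is only a stand-in for verifying shapes, not FlashAttention, and the dimensions are illustrative assumptions.

```python
import torch

# Illustrative ViT-B-ish dimensions (assumptions, not taken from the patch)
B, N, num_heads, head_dim = 2, 197, 12, 64
C = num_heads * head_dim

# Packed projection output, reshaped exactly as in FlashAttnToMeAttention.forward
qkv = torch.randn(B, N, 3 * C).reshape(B, N, 3, num_heads, head_dim)

# k extraction used for the ToMe merging metric:
# (B, num_heads, N, head_dim), then averaged over heads
k = qkv.permute(2, 0, 3, 1, 4)[1]
assert k.mean(1).shape == (B, N, head_dim)

# Unfused reference attention over the same packed layout (shape check only; a
# stand-in for flash_attn_qkvpacked_func, which returns (B, N, num_heads, head_dim))
q, k_, v = qkv.unbind(dim=2)                      # each (B, N, num_heads, head_dim)
attn = (q.transpose(1, 2) @ k_.transpose(1, 2).transpose(-2, -1)) * head_dim ** -0.5
out = (attn.softmax(dim=-1) @ v.transpose(1, 2)).transpose(1, 2)
assert out.reshape(B, N, C).shape == (B, N, C)    # matches x.reshape(B, N, C) in the patch
```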