@@ -653,6 +653,18 @@ class MappingSpec:
     dropout: float


+def make_layer_factory(spec, mapping):
+    if isinstance(spec.self_attn, GlobalAttentionSpec):
+        return lambda _: GlobalTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=spec.dropout)
+    elif isinstance(spec.self_attn, NeighborhoodAttentionSpec):
+        return lambda _: NeighborhoodTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=spec.dropout)
+    elif isinstance(spec.self_attn, ShiftedWindowAttentionSpec):
+        return lambda i: ShiftedWindowTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.window_size, i, dropout=spec.dropout)
+    elif isinstance(spec.self_attn, NoAttentionSpec):
+        return lambda _: NoAttentionTransformerLayer(spec.width, spec.d_ff, mapping.width, dropout=spec.dropout)
+    raise ValueError(f"unsupported self attention spec {spec.self_attn}")
+
+
 # Model class

 class ImageTransformerDenoiserModelV2(nn.Module):
@@ -672,16 +684,7 @@ def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_c

         self.down_levels, self.up_levels = nn.ModuleList(), nn.ModuleList()
         for i, spec in enumerate(levels):
-            if isinstance(spec.self_attn, GlobalAttentionSpec):
-                layer_factory = lambda _: GlobalTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=spec.dropout)
-            elif isinstance(spec.self_attn, NeighborhoodAttentionSpec):
-                layer_factory = lambda _: NeighborhoodTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=spec.dropout)
-            elif isinstance(spec.self_attn, ShiftedWindowAttentionSpec):
-                layer_factory = lambda i: ShiftedWindowTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.window_size, i, dropout=spec.dropout)
-            elif isinstance(spec.self_attn, NoAttentionSpec):
-                layer_factory = lambda _: NoAttentionTransformerLayer(spec.width, spec.d_ff, mapping.width, dropout=spec.dropout)
-            else:
-                raise ValueError(f"unsupported self attention spec {spec.self_attn}")
+            layer_factory = make_layer_factory(spec, mapping)

             if i < len(levels) - 1:
                 self.down_levels.append(Level([layer_factory(i) for i in range(spec.depth)]))
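For context, a minimal sketch of how the extracted factory is exercised, mirroring the loop in __init__. The SimpleNamespace objects and all literal values are illustrative stand-ins, and the keyword construction of ShiftedWindowAttentionSpec assumes it is a simple dataclass over d_head and window_size; only the attributes the diff actually reads are populated.

from types import SimpleNamespace

# Illustrative spec objects: only the attributes make_layer_factory reads
# (spec.width, spec.d_ff, spec.dropout, spec.self_attn.*, mapping.width),
# plus spec.depth for the loop below, are set here.
attn = ShiftedWindowAttentionSpec(d_head=64, window_size=8)
spec = SimpleNamespace(depth=2, width=256, d_ff=512, self_attn=attn, dropout=0.0)
mapping = SimpleNamespace(width=256)

layer_factory = make_layer_factory(spec, mapping)
# The factory receives the layer index: shifted-window layers consume it
# (to vary the window shift per layer), the other variants accept and ignore it.
layers = [layer_factory(i) for i in range(spec.depth)]

A side benefit of the extraction: spec is now bound as a function parameter rather than a loop variable captured by the lambdas, so the returned factories remain correct even if called after the loop advances (Python closures bind late).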