
Commit 353e343

Fix up loop-binding issues in ImageTransformerV2
1 parent 11ea20d commit 353e343

1 file changed: k_diffusion/models/image_transformer_v2.py

Lines changed: 13 additions & 10 deletions

@@ -653,6 +653,18 @@ class MappingSpec:
     dropout: float
 
 
+def make_layer_factory(spec, mapping):
+    if isinstance(spec.self_attn, GlobalAttentionSpec):
+        return lambda _: GlobalTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=spec.dropout)
+    elif isinstance(spec.self_attn, NeighborhoodAttentionSpec):
+        return lambda _: NeighborhoodTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=spec.dropout)
+    elif isinstance(spec.self_attn, ShiftedWindowAttentionSpec):
+        return lambda i: ShiftedWindowTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.window_size, i, dropout=spec.dropout)
+    elif isinstance(spec.self_attn, NoAttentionSpec):
+        return lambda _: NoAttentionTransformerLayer(spec.width, spec.d_ff, mapping.width, dropout=spec.dropout)
+    raise ValueError(f"unsupported self attention spec {spec.self_attn}")
+
+
 # Model class
 
 class ImageTransformerDenoiserModelV2(nn.Module):
@@ -672,16 +684,7 @@ def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_c
 
         self.down_levels, self.up_levels = nn.ModuleList(), nn.ModuleList()
         for i, spec in enumerate(levels):
-            if isinstance(spec.self_attn, GlobalAttentionSpec):
-                layer_factory = lambda _: GlobalTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=spec.dropout)
-            elif isinstance(spec.self_attn, NeighborhoodAttentionSpec):
-                layer_factory = lambda _: NeighborhoodTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=spec.dropout)
-            elif isinstance(spec.self_attn, ShiftedWindowAttentionSpec):
-                layer_factory = lambda i: ShiftedWindowTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.window_size, i, dropout=spec.dropout)
-            elif isinstance(spec.self_attn, NoAttentionSpec):
-                layer_factory = lambda _: NoAttentionTransformerLayer(spec.width, spec.d_ff, mapping.width, dropout=spec.dropout)
-            else:
-                raise ValueError(f"unsupported self attention spec {spec.self_attn}")
+            layer_factory = make_layer_factory(spec, mapping)
 
             if i < len(levels) - 1:
                 self.down_levels.append(Level([layer_factory(i) for i in range(spec.depth)]))
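The issue this commit targets is Python's late binding of closure variables: a lambda created inside a for loop captures the loop variable itself, not its value at creation time, so factories that survive past the current iteration all end up seeing the last `spec`. Moving the lambda creation into a separate function (here `make_layer_factory`) fixes that, because the function's arguments create a fresh binding for each call. Below is a minimal standalone sketch of the pitfall and the fix; the `spec` dicts and `make_factory` helper are illustrative only, not the real k-diffusion classes.

# Illustrative sketch of the late-binding pitfall, using made-up "spec" dicts
# rather than the real LevelSpec/transformer layer classes.

def broken_factories(specs):
    factories = []
    for spec in specs:
        # Each lambda closes over the *variable* `spec`, not its current
        # value, so once the loop finishes they all refer to the last spec.
        factories.append(lambda: {"width": spec["width"]})
    return factories


def make_factory(spec):
    # The argument `spec` is a fresh binding per call, so the returned
    # lambda keeps the value it was created with -- the same idea as
    # make_layer_factory(spec, mapping) in this commit.
    return lambda: {"width": spec["width"]}


def fixed_factories(specs):
    return [make_factory(spec) for spec in specs]


if __name__ == "__main__":
    specs = [{"width": 128}, {"width": 256}, {"width": 512}]
    print([f() for f in broken_factories(specs)])  # all {'width': 512}
    print([f() for f in fixed_factories(specs)])   # 128, 256, 512 as intended

The same effect could be obtained with default-argument capture (lambda spec=spec: ...) or functools.partial; the commit opts for a named factory function, which also removes the duplicated if/elif chain from __init__.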
