-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathmodeling_moss_audio_tokenizer.py
More file actions
1843 lines (1505 loc) · 69.2 KB
/
modeling_moss_audio_tokenizer.py
File metadata and controls
1843 lines (1505 loc) · 69.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MossAudioTokenizer model."""
from __future__ import annotations
import copy
import math
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass
from typing import cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_utils import PreTrainedAudioTokenizerBase
from transformers.utils import ModelOutput, auto_docstring, logging
from .configuration_moss_audio_tokenizer import MossAudioTokenizerConfig
# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)
# =============================================================================
# Output Classes
# =============================================================================
@dataclass
@auto_docstring
class MossAudioTokenizerEncoderOutput(ModelOutput):
    r"""
    audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*):
        Discrete audio codes computed using the encoder and quantizer.
    audio_codes_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Valid lengths for each sample's audio codes.
    encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, hidden_size, sequence_length)`, *optional*):
        Hidden states from the encoder before quantization.
    """

    # Result of the encoding step. Every field defaults to None so producers fill only what
    # they computed.
    # NOTE(review): ModelOutput presumably exposes fields positionally in declaration order
    # (to_tuple / integer indexing) — keep this order stable.
    audio_codes: torch.Tensor | None = None
    audio_codes_lengths: torch.Tensor | None = None
    encoder_hidden_states: torch.Tensor | None = None
@dataclass
@auto_docstring
class MossAudioTokenizerDecoderOutput(ModelOutput):
    r"""
    audio (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
        Decoded audio waveform.
    audio_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Valid lengths for each sample's audio.
    """

    # Result of the decoding step; fields default to None.
    # NOTE(review): ModelOutput presumably relies on declaration order for positional access —
    # keep this order stable.
    audio: torch.Tensor | None = None
    audio_lengths: torch.Tensor | None = None
@dataclass
@auto_docstring
class MossAudioTokenizerOutput(ModelOutput):
    r"""
    audio (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
        Decoded audio waveform.
    audio_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Valid lengths for each sample's audio.
    audio_codes (`torch.LongTensor` of shape `(num_quantizers, batch_size, sequence_length)`, *optional*):
        Discrete audio codes computed using the encoder and quantizer.
    audio_codes_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Valid lengths for each sample's audio codes.
    """

    # Combined result carrying both the reconstructed audio and the intermediate codes.
    # NOTE(review): ModelOutput presumably relies on declaration order for positional access —
    # keep this order stable.
    audio: torch.Tensor | None = None
    audio_lengths: torch.Tensor | None = None
    audio_codes: torch.Tensor | None = None
    audio_codes_lengths: torch.Tensor | None = None
# =============================================================================
# Streaming Module Base Classes
# =============================================================================
@dataclass
class StreamingState:
    """Base per-module state used while a module runs in streaming mode.

    Attributes:
        batch_size: number of streams processed side by side.
        device: device on which the state tensors are allocated.
        exec_mask: boolean tensor of shape ``(batch_size,)`` created in ``__post_init__``;
            subclasses use it to decide which streams advance their offsets.
    """

    batch_size: int
    device: torch.device

    def __post_init__(self):
        # Every stream starts out active.
        self.exec_mask = torch.ones(self.batch_size, dtype=torch.bool, device=self.device)

    def set_exec_mask(self, exec_mask: torch.Tensor):
        """Overwrite the execution mask in place (the tensor object itself is kept)."""
        self.exec_mask[:] = exec_mask

    def reset(self, reset_mask: torch.Tensor) -> None:
        """Re-activate the streams selected by ``reset_mask``; other entries are unchanged."""
        all_active = torch.ones_like(self.exec_mask)
        self.exec_mask[:] = torch.where(reset_mask, all_active, self.exec_mask)

    def __enter__(self):
        # Entered through an ExitStack; handing back the state itself helps when debugging.
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Nothing to release; exceptions are never suppressed.
        return None
class StreamingModule(nn.Module):
    """Base class for modules that carry state across successive forward calls.

    ``streaming(batch_size)`` returns an ``ExitStack``. While that stack is open, this module
    and every ``StreamingModule`` reachable through its children (skipping detached
    descendants) hold a state object created by ``_init_streaming_state``. Closing the stack
    sets all of those states back to ``None``.
    """

    def __init__(self) -> None:
        super().__init__()
        self._streaming_state: StreamingState | None = None
        # When True and this module is reached as a descendant, the traversal skips it and its
        # whole subtree.
        self._streaming_detached: bool = False
        # (qualified_name, module) pairs collected on the first traversal and reused afterwards.
        self._cached_children: list[tuple[str, StreamingModule]] | None = None

    @property
    def is_streaming(self):
        """True while a streaming state is attached to this module."""
        return self._streaming_state is not None

    def _apply_named_streaming(self, fn):
        """Call ``fn(name, module)`` on this module and each streaming descendant.

        The visited list is computed on the first call and cached. Later calls therefore see
        the module tree as it was at that first call.
        """

        def _handle_module(prefix: str, module: nn.Module):
            if isinstance(module, StreamingModule):
                # Detached descendants manage their own streaming lifecycle; the root itself
                # (prefix == "") is always included.
                if module._streaming_detached and prefix != "":
                    return
                if self._cached_children is None:
                    raise RuntimeError("Internal error: _cached_children should be initialized before traversal.")
                self._cached_children.append((prefix, module))
            # Recurse through every child, including plain containers such as nn.ModuleList,
            # so streaming modules nested inside them are found as well.
            for name, child in module.named_children():
                new_prefix = f"{prefix}.{name}" if prefix else name
                _handle_module(new_prefix, child)

        # Walk the tree only once. Walking again would append every module a second time, and
        # the next streaming() call would then fail with "already streaming!".
        if self._cached_children is None:
            self._cached_children = []
            _handle_module("", self)
        for name, child in self._cached_children:
            fn(name, child)

    def _start_streaming(self, batch_size: int, exit_stack: ExitStack):
        """Attach a fresh state to every visited module and enter it on ``exit_stack``.

        Raises:
            RuntimeError: if a visited module is already streaming.
        """

        def _start_streaming_fn(name: str, module: StreamingModule):
            if module._streaming_state is not None:
                raise RuntimeError(f"{name} is already streaming!")
            state = module._init_streaming_state(batch_size)
            exit_stack.enter_context(state)
            module._streaming_state = state

        self._apply_named_streaming(_start_streaming_fn)

    def _stop_streaming(self) -> None:
        """Drop the streaming state of every visited module."""

        def _stop_streaming_fn(name: str, module: StreamingModule):
            module._streaming_state = None

        self._apply_named_streaming(_stop_streaming_fn)

    def _init_streaming_state(self, batch_size: int) -> StreamingState:
        """Create the default state on the module's device.

        The device comes from the first parameter, else the first buffer, else CPU. Modules
        without parameters (e.g. a bare container) no longer fail with a bare
        ``StopIteration``.
        """
        reference = next(self.parameters(), None)
        if reference is None:
            reference = next(self.buffers(), None)
        device = reference.device if reference is not None else torch.device("cpu")
        return StreamingState(batch_size, device)

    def streaming(self, batch_size: int) -> ExitStack:
        """Enter streaming mode for ``batch_size`` parallel streams.

        Use as ``with module.streaming(batch_size): ...``. Leaving the block runs
        ``_stop_streaming``, which clears every visited module's state.
        """
        exit_stack = ExitStack()
        self._start_streaming(batch_size, exit_stack)
        exit_stack.callback(self._stop_streaming)
        return exit_stack
class StreamingContainer(StreamingModule):
    """Marker base class for streaming modules that mainly group other modules.

    It adds no behaviour; all streaming-state handling is inherited from ``StreamingModule``.
    """
# =============================================================================
# Normalization Layers
# =============================================================================
class MossAudioTokenizerRMSNorm(nn.Module):
    """Root-mean-square normalization over dimension 2 with a learned per-channel gain.

    Args:
        dim: size of the normalized (last of three) dimension.
        eps: added to the mean square before the reciprocal square root.
        dtype: if given, the input is cast to this dtype for the computation and the
            parameter is created in it; the output is cast back to the input dtype.
        device: device for the gain parameter.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-5,
        dtype: torch.dtype | None = None,
        device=None,
    ):
        super().__init__()
        self.eps = eps
        self.dtype = dtype
        # Gain initialised to 1, broadcastable over (batch, time, dim) inputs.
        self.alpha = nn.Parameter(torch.ones((1, 1, dim), device=device, dtype=dtype))

    def forward(self, x: torch.Tensor):
        out_dtype = x.dtype
        h = x if self.dtype is None else x.to(self.dtype)
        var = self.eps + torch.mean(h**2, dim=2, keepdim=True)
        gain = self.alpha.to(var) * torch.rsqrt(var)
        return (h * gain).to(out_dtype)
class MossAudioTokenizerLayerScale(nn.Module):
    """Learned per-channel multiplicative scale (Touvron et al. 2021, "LayerScale").

    Args:
        channels: number of channels (length of the scale vector).
        init: initial value of every scale entry.
        channel_last: if True the channel axis is the last one; otherwise the scale is
            broadcast as a column (channels on the second-to-last axis).
    """

    def __init__(
        self,
        channels: int,
        init: float = 1e-4,
        channel_last: bool = True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.channel_last = channel_last
        self.scale = nn.Parameter(torch.full((channels,), init, requires_grad=True, device=device, dtype=dtype))

    def forward(self, x: torch.Tensor):
        factor = self.scale if self.channel_last else self.scale[:, None]
        return factor * x
def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
    """Build the normalization layer named by ``norm_type``.

    Supported values: ``"layer_norm"`` (eps 1e-5), ``"rms_norm"`` (eps 1e-5) and
    ``"rms_norm_f32"`` (eps 1e-8, always computed in float32; any ``dtype`` in ``kwargs``
    is discarded). Remaining ``kwargs`` are forwarded to the layer constructor.

    Raises:
        ValueError: for any other ``norm_type``.
    """
    if norm_type == "layer_norm":
        return nn.LayerNorm(dim, eps=1e-5, **kwargs)
    if norm_type in {"rms_norm"}:
        return MossAudioTokenizerRMSNorm(dim, eps=1e-5, **kwargs)
    if norm_type not in {"rms_norm_f32"}:
        raise ValueError(f"Unknown norm type: {norm_type}")
    # float32 variant: the caller's dtype is replaced by torch.float.
    kwargs.pop("dtype", None)
    return MossAudioTokenizerRMSNorm(dim, eps=1e-8, dtype=torch.float, **kwargs)
# =============================================================================
# Rotary Position Embedding
# =============================================================================
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    offset: torch.Tensor,
    max_period: float = 10_000,
    time_before_heads: bool = False,
):
    """Rotate query/key feature pairs by position-dependent angles (rotary embedding).

    Consecutive feature pairs ``(2i, 2i + 1)`` are treated as complex numbers and rotated by
    ``position * max_period ** (-2i / D)``. The position of time step ``t`` in batch row ``b``
    is ``offset[b] + t``. The trig is done in float32, and results are cast back to the input
    dtype.

    Args:
        q, k: tensors of identical 4-D shape, ``(B, H, T, D)`` or, with ``time_before_heads``,
            ``(B, T, H, D)``.
        offset: per-batch starting positions; must hold either ``B`` elements or one element
            when ``B == 1``.
        max_period: base of the geometric frequency progression.
        time_before_heads: selects which of the two layouts above is used.

    Returns:
        The rotated ``(q, k)`` with the input shape and dtype.

    Raises:
        ValueError: if shapes differ or ``D`` is not a positive even number.
    """
    if time_before_heads:
        batch, steps, _, dim = q.shape
    else:
        batch, _, steps, dim = q.shape
    if k.shape != q.shape:
        raise ValueError(f"Expected k.shape == q.shape, got k={tuple(k.shape)} q={tuple(q.shape)}")
    if dim <= 0 or (dim % 2) != 0:
        raise ValueError(f"RoPE requires an even last dimension, got D={dim}")

    half = dim // 2
    freqs = torch.exp(torch.arange(half, device=q.device, dtype=torch.float32) * (-math.log(max_period) * 2 / dim))
    times = offset.float().view(-1, 1) + torch.arange(steps, device=q.device, dtype=torch.float32)
    # Put the time axis where it lives in q/k so the angles broadcast over heads.
    times = times.view(batch, -1, 1, 1) if time_before_heads else times.view(batch, 1, -1, 1)
    angles = freqs * times
    cos_a = torch.cos(angles)
    sin_a = torch.sin(angles)

    lead = q.shape[:-1]
    out_dtype = q.dtype

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        pairs = t.view(*lead, half, 2)
        real, imag = pairs[..., 0].float(), pairs[..., 1].float()
        out_real = real * cos_a - imag * sin_a
        out_imag = real * sin_a + imag * cos_a
        return torch.stack([out_real.to(out_dtype), out_imag.to(out_dtype)], dim=-1).view(*lead, dim)

    return _rotate(q), _rotate(k)
class MossAudioTokenizerRotaryEmbedding(nn.Module):
    """Parameter-free module wrapper around ``apply_rope`` that remembers ``max_period``."""

    def __init__(self, max_period: float = 10000.0):
        super().__init__()
        self.max_period = max_period

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        offset: torch.Tensor,
        time_before_heads: bool = False,
    ):
        """Return ``(q, k)`` rotated for positions starting at ``offset``."""
        rotated = apply_rope(q, k, offset, self.max_period, time_before_heads)
        return rotated
# =============================================================================
# Gating Modules
# =============================================================================
class MossAudioTokenizerActivationGating(nn.Module):
    """Gated feed-forward block: ``linear_out(activation(gate) * value)``.

    ``linear_in`` produces ``2 * hidden`` features. The first half is the gate and the second
    half the value. When ``dim_feedforward == 4 * dim``, ``hidden = (21 * dim) // 8``;
    otherwise ``hidden = (2 * dim_feedforward) // 3``.
    """

    def __init__(self, dim: int, dim_feedforward: int, activation, **factory_kwargs):
        super().__init__()
        hidden = (21 * dim) // 8 if dim_feedforward == 4 * dim else (2 * dim_feedforward) // 3
        self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs)
        self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs)
        self.activation = activation

    def forward(self, x: torch.Tensor):
        projected = self.linear_in(x)
        batch, steps, _ = projected.shape
        gate, value = projected.view(batch, steps, 2, -1).unbind(dim=2)
        return self.linear_out(self.activation(gate) * value)
def _get_activation(name: str):
if name in ["sigmoid", "tanh", "relu"]:
return getattr(torch, name)
elif name in ["leaky_relu", "elu", "gelu", "silu", "mish", "softsign"]:
return getattr(F, name)
elif name == "identity":
return nn.Identity()
else:
raise ValueError(f"Unknown activation {name}")
def make_gating(name: str, dim: int, dim_feedforward: int, **factory_kwargs) -> nn.Module:
    """Build a gated FFN whose gate activation is looked up by ``name``."""
    activation = _get_activation(name)
    return MossAudioTokenizerActivationGating(dim, dim_feedforward, activation, **factory_kwargs)
# =============================================================================
# Positional Embeddings
# =============================================================================
def create_sin_embedding(
    positions: torch.Tensor,
    dim: int,
    max_period: float = 10000,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Sinusoidal position embedding laid out as ``[B, T, C]``.

    Channel ``i`` of the cosine half (and of the sine half) uses period
    ``max_period ** (i / (dim // 2 - 1))``. The result concatenates all cosines, then all
    sines, along the last axis.

    Args:
        positions: positions broadcastable against a ``(1, 1, dim // 2)`` tensor,
            e.g. shape ``(B, T, 1)``.
        dim: embedding size; must be even and at least 4.
        max_period: largest period.
        dtype: computation/output dtype.

    Raises:
        ValueError: if ``dim`` is odd or smaller than 4.
    """
    if dim % 2 != 0:
        raise ValueError(f"Sinusoidal embedding requires even dim, got dim={dim}")
    half_dim = dim // 2
    if half_dim <= 1:
        raise ValueError(f"Sinusoidal embedding requires dim >= 4, got dim={dim}")
    pos = positions.to(dtype)
    device = pos.device
    exponents = torch.arange(half_dim, device=device, dtype=dtype).view(1, 1, -1)
    # Built as a tensor on the target device rather than a Python scalar.
    base = torch.full([], max_period, device=device, dtype=dtype)
    phase = pos / (base ** (exponents / (half_dim - 1)))
    return torch.cat([phase.cos(), phase.sin()], dim=-1)
# =============================================================================
# KV Cache for Attention
# =============================================================================
class KVCacheResult:
    """Keys, values and per-slot positions returned by a KV cache.

    Unpacks as ``keys, values, positions``. ``positions`` gives the time index stored in each
    key slot per batch row.
    """

    __slots__ = ("keys", "values", "positions")

    def __init__(self, keys: torch.Tensor, values: torch.Tensor, positions: torch.Tensor):
        self.keys, self.values, self.positions = keys, values, positions

    def __iter__(self):
        """Yield keys, values, positions in that order."""
        yield from (self.keys, self.values, self.positions)

    @staticmethod
    def from_kv(keys: torch.Tensor, values: torch.Tensor) -> KVCacheResult:
        """Wrap uncached ``(B, H, T, D)`` keys/values; positions are ``0..T-1`` for every row."""
        batch, _, steps, _ = keys.shape
        positions = torch.arange(steps, device=keys.device, dtype=torch.long).expand(batch, -1)
        return KVCacheResult(keys, values, positions)
class RingKVCache:
    """Fixed-capacity circular key/value buffer for streaming attention.

    Keys and values share one preallocated tensor of shape
    ``(2, batch_size, num_heads, capacity, dim_per_head)``. New time steps overwrite the
    oldest slots. With ``respect_exec_mask`` every batch row keeps its own write cursor and
    only rows selected by ``exec_mask`` advance it. Otherwise a single shared cursor always
    advances.
    """

    def __init__(
        self,
        batch_size: int,
        num_heads: int,
        dim_per_head: int,
        capacity: int,
        respect_exec_mask: bool = True,
        device: torch.device = torch.device("cuda"),
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.capacity = capacity
        cache_shape = (2, batch_size, num_heads, capacity, dim_per_head)
        self.cache = torch.zeros(cache_shape, device=device, dtype=dtype)
        self.respect_exec_mask = respect_exec_mask
        # Total number of steps written so far: one counter per row, or one shared counter.
        cursor_count = batch_size if respect_exec_mask else 1
        self.end_offset = torch.zeros(cursor_count, device=device, dtype=torch.long)

    def reset(self, reset_mask: torch.Tensor) -> None:
        """Rewind the write cursor of the rows selected by ``reset_mask`` to zero."""
        cleared = torch.zeros_like(self.end_offset)
        self.end_offset[:] = torch.where(reset_mask, cleared, self.end_offset)

    def complete(self, k: torch.Tensor, v: torch.Tensor, exec_mask: torch.Tensor) -> KVCacheResult:
        """Write ``(B, H, T, D)`` keys/values and return the whole buffer with positions.

        ``positions[b, s]`` is the absolute time step stored in slot ``s``, or ``-1`` if that
        slot has not been written yet for row ``b``. The returned keys/values are views into
        the internal buffer.

        Raises:
            ValueError: if ``T`` is not positive.
        """
        batch, heads, steps, head_dim = k.shape
        if steps <= 0:
            raise ValueError(f"Expected T > 0, got T={steps}")
        cursor = self.end_offset

        # Ring slots receiving the new steps (row-wise when cursors are per row).
        write_slots = (torch.arange(steps, device=cursor.device, dtype=cursor.dtype) + cursor.view(-1, 1)) % self.capacity
        if self.respect_exec_mask:
            scatter_index = write_slots.view(batch, 1, steps, 1).expand(-1, heads, steps, head_dim)
            self.cache[0].scatter_(2, scatter_index, k)
            self.cache[1].scatter_(2, scatter_index, v)
        else:
            self.cache[0].index_copy_(2, write_slots[0], k)
            self.cache[1].index_copy_(2, write_slots[0], v)
        keys, values = self.cache[0], self.cache[1]

        # Slots at or before the last written slot hold recent steps; slots after it hold
        # steps from one lap earlier.
        slot_ids = torch.arange(self.capacity, device=cursor.device, dtype=torch.long)
        last_pos = cursor.view(-1, 1) + steps - 1
        delta = slot_ids - last_pos % self.capacity
        positions = torch.where(delta <= 0, last_pos + delta, last_pos + delta - self.capacity)

        # Advance the cursor(s), then hide slots that were never written.
        if self.respect_exec_mask:
            self.end_offset[:] = torch.where(exec_mask, self.end_offset + steps, self.end_offset)
        else:
            self.end_offset.add_(steps)
        never_written = slot_ids >= self.end_offset.view(-1, 1)
        positions = torch.where(never_written, torch.full_like(positions, -1), positions)
        return KVCacheResult(keys, values, positions)
# =============================================================================
# Multi-Head Attention
# =============================================================================
@dataclass
class MHAState(StreamingState):
    """Streaming state of an attention module.

    Attributes:
        kv_cache: ring buffer of past keys/values, or None.
        offset: per-row count of processed steps (long tensor of shape ``(batch_size,)``).
        offset_cpu: host-side step counter, used to pick per-step weights.
    """

    kv_cache: RingKVCache | None
    offset: torch.Tensor
    offset_cpu: int

    def reset(self, reset_mask: torch.Tensor):
        """Reset selected rows; ``offset_cpu`` is zeroed unconditionally."""
        super().reset(reset_mask)
        cleared = torch.zeros_like(self.offset)
        self.offset[:] = torch.where(reset_mask, cleared, self.offset)
        cache = self.kv_cache
        if cache is not None:
            cache.reset(reset_mask)
        self.offset_cpu = 0
def apply_weights_per_step(
    modules: nn.ModuleList,
    schedule: list[int] | None,
    x: torch.Tensor,
    offset: int | None,
) -> torch.Tensor:
    """Apply a (possibly) different module to each time step of a ``(B, T, C)`` input.

    With a single module it is applied to the whole input. Otherwise step ``t`` uses module
    ``offset + t``, remapped through ``schedule`` when one is given. Per-step outputs are
    concatenated along dim 1.

    Raises:
        ValueError: if ``offset`` is missing with several modules, or an index falls outside
            ``schedule`` or ``modules``.
    """
    if len(modules) == 1:
        return modules[0](x)
    if offset is None:
        raise ValueError("offset must be provided when using per-step weights (len(modules) > 1).")
    _, steps, _ = x.shape
    outputs = []
    for t in range(steps):
        idx = offset + t
        if schedule is not None:
            if not 0 <= idx < len(schedule):
                raise ValueError(
                    f"weights_per_step_schedule is too short for module_index={idx} (len={len(schedule)})."
                )
            idx = schedule[idx]
        if not 0 <= idx < len(modules):
            raise ValueError(f"module_index={idx} out of range for len(modules)={len(modules)}.")
        outputs.append(modules[idx](x[:, t : t + 1]))
    return torch.cat(outputs, 1)
class MossAudioTokenizerMultiheadAttention(StreamingModule):
    """Multi-head attention with streaming support."""

    # Self-attention only: `forward` projects q, k and v from `query` with one fused
    # projection (`in_projs`). With `weights_per_step`, `mult` separate in/out projection
    # pairs exist and `apply_weights_per_step` selects one per time step.
    # While streaming, past keys/values are kept in the RingKVCache held by MHAState.

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        causal: bool = False,
        context: int | None = None,
        rope: MossAudioTokenizerRotaryEmbedding | None = None,
        weights_per_step: int = 0,
        weights_per_step_schedule: list[int] | None = None,
        device=None,
        dtype=None,
    ):
        """
        Args:
            embed_dim: model width, split into `num_heads` heads of size
                embed_dim // num_heads (divisibility is not checked here).
            num_heads: number of attention heads.
            causal: if True, each query only attends to keys at the same or earlier
                positions (and, with `context`, at most `context - 1` steps back).
            context: attention window; also the KV-cache capacity when streaming.
            rope: optional rotary embedding applied to q and k.
            weights_per_step: if non-zero, use per-time-step projection weights.
            weights_per_step_schedule: optional map from time step to weight index.
            device, dtype: forwarded to the projection layers.
        """
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.embed_dim = embed_dim
        self.causal = causal
        self.context = context
        self.rope = rope
        self.num_heads = num_heads
        self.weights_per_step = weights_per_step
        self.weights_per_step_schedule = weights_per_step_schedule
        # Fused q/k/v projection width.
        out_dim = 3 * embed_dim
        # Number of projection copies: 1, or one per weight index used by the schedule.
        mult = 1
        if weights_per_step:
            mult = max(weights_per_step_schedule) + 1 if weights_per_step_schedule else weights_per_step
        self.mult = mult
        self.out_projs = nn.ModuleList(
            [nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs) for _ in range(mult)]
        )
        self.in_projs = nn.ModuleList(
            [nn.Linear(embed_dim, out_dim, bias=False, **factory_kwargs) for _ in range(mult)]
        )
        # Converts older single-projection checkpoint keys when loading a state dict.
        self._register_load_state_dict_pre_hook(self._load_hook, with_module=True)

    @staticmethod
    def _load_hook(module, state_dict, prefix, *_):
        """Split legacy stacked projection weights into the per-step `in_projs`/`out_projs` keys.

        A source tensor whose leading dimension is `mult * rows` is viewed as `mult` chunks.
        Chunk `i` is stored under `...projs.{i}.weight`, and the legacy key is removed.
        NOTE(review): the `_scb` suffix is presumably a quantization scale tensor — confirm.
        """
        mappings = {
            "in_proj_weight": "in_projs.{i}.weight",
            "in_proj.weight": "in_projs.{i}.weight",
            "out_proj.weight": "out_projs.{i}.weight",
        }
        mult = module.mult
        for suffix in ["", "_scb"]:
            for source, target in mappings.items():
                this_source = prefix + source + suffix
                if this_source in state_dict:
                    weight = state_dict[this_source]
                    _, *OD = weight.shape
                    weight = weight.view(mult, -1, *OD)
                    for i in range(mult):
                        state_dict[prefix + target.format(i=i) + suffix] = weight[i]
                    state_dict.pop(this_source)

    def _init_streaming_state(self, batch_size: int) -> MHAState:
        """Create the streaming state with a KV cache on the projections' device/dtype."""
        in_proj = cast(nn.Linear, self.in_projs[0])
        device = cast(torch.device, in_proj.weight.device)
        dtype = cast(torch.dtype, in_proj.weight.dtype)
        dim_per_head = self.embed_dim // self.num_heads
        # Cache capacity: the attention window if set, otherwise weights_per_step, otherwise
        # a fixed 1024 slots. The ring overwrites its oldest entries once more steps than
        # that have been streamed.
        if self.context is None:
            capacity = self.weights_per_step if self.weights_per_step else 1024
        else:
            capacity = self.context
        kv_cache = RingKVCache(
            batch_size,
            self.num_heads,
            dim_per_head,
            capacity,
            # Per-step weights need all rows in lock-step, so the cache uses a shared cursor.
            respect_exec_mask=not self.weights_per_step,
            device=cast(torch.device, device),
            dtype=cast(torch.dtype, dtype),
        )
        return MHAState(
            batch_size,
            cast(torch.device, device),
            kv_cache,
            offset=torch.zeros(batch_size, device=cast(torch.device, device), dtype=torch.long),
            offset_cpu=0,
        )

    def _complete_kv(self, k, v) -> KVCacheResult:
        """Return the keys/values to attend to: the cache contents when streaming, else k/v as-is."""
        state = cast(MHAState | None, self._streaming_state)
        if state is None:
            return KVCacheResult.from_kv(k, v)
        if state.kv_cache is None:
            return KVCacheResult.from_kv(k, v)
        return state.kv_cache.complete(k, v, state.exec_mask)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        """Self-attention over a `(B, T, embed_dim)` input.

        `key` and `value` are accepted for API compatibility but are not used: q, k and v
        are all projected from `query`. Returns a `(B, T, embed_dim)` tensor.
        """
        state = cast(MHAState | None, self._streaming_state)
        B, T = query.shape[:2]
        # Positions of this chunk start at 0 outside streaming, else at the stored offsets.
        if state is None:
            offset = torch.zeros(B, device=query.device, dtype=torch.long)
            offset_cpu = 0
        else:
            offset = state.offset
            offset_cpu = state.offset_cpu
        projected = apply_weights_per_step(self.in_projs, self.weights_per_step_schedule, query, offset_cpu)
        dim_per_head = self.embed_dim // self.num_heads
        # (B, T, 3 * E) -> (3, B, heads, T, dim_per_head)
        projected = projected.reshape(B, T, 3, self.num_heads, dim_per_head).permute(2, 0, 3, 1, 4)
        q, k, v = projected[0], projected[1], projected[2]
        if self.rope:
            q, k = self.rope(q, k, offset, time_before_heads=False)
        # pos_k: time index of each key slot; -1 marks cache slots never written.
        k, v, pos_k = self._complete_kv(k, v)
        pos_k = pos_k[:, None]
        if self.causal:
            # Boolean mask (B, 1, T, Tk): key is valid, not in the future, and inside the window.
            pos_q = offset.view(-1, 1, 1) + torch.arange(T, device=q.device, dtype=torch.long).view(-1, 1)
            delta = pos_q - pos_k
            attn_bias = (pos_k >= 0) & (delta >= 0)
            if self.context is not None:
                attn_bias = attn_bias & (delta < self.context)
            attn_bias = attn_bias[:, None]
        else:
            attn_bias = None
        x = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0)
        x = x.transpose(1, 2).reshape(B, T, self.embed_dim)
        x = apply_weights_per_step(self.out_projs, self.weights_per_step_schedule, x, offset_cpu)
        if state is not None:
            # Only rows enabled in exec_mask advance; the host counter always advances.
            state.offset[:] = torch.where(state.exec_mask, state.offset + T, state.offset)
            state.offset_cpu += T
        return x
# =============================================================================
# Transformer Layer
# =============================================================================
@dataclass
class LayerState(StreamingState):
    """Streaming state of a transformer layer: a host-side step counter."""

    offset_cpu: int = 0

    def reset(self, reset_mask: torch.Tensor):
        """Reset the base mask; the step counter is zeroed regardless of ``reset_mask``."""
        super().reset(reset_mask)
        self.offset_cpu = 0
class MossAudioTokenizerTransformerLayer(StreamingModule):
    """Pre-norm transformer layer (self-attention, then feed-forward) with streaming support.

    Each sub-block computes ``x + layer_scale(sub_block(norm(x)))``. The feed-forward path is
    either a plain ``linear2(activation(linear1(x)))`` MLP (``gating == "none"``) or a gated
    FFN from ``make_gating``. With ``weights_per_step``, one gated FFN is kept per weight
    index and chosen per time step.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_feedforward: int = 2048,
        causal: bool = False,
        context: int | None = None,
        rope: MossAudioTokenizerRotaryEmbedding | None = None,
        norm: str = "layer_norm",
        layer_scale: float | None = None,
        gating: str = "none",
        weights_per_step: int = 0,
        weights_per_step_schedule: list[int] | None = None,
        activation=F.gelu,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.self_attn = MossAudioTokenizerMultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            causal=causal,
            context=context,
            rope=rope,
            weights_per_step=weights_per_step,
            weights_per_step_schedule=weights_per_step_schedule,
            **factory_kwargs,
        )
        self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)
        self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)
        self.weights_per_step = weights_per_step
        self.weights_per_step_schedule = weights_per_step_schedule
        # Exactly one feed-forward flavour is filled in below; the others stay None.
        self.gating: nn.Module | nn.ModuleList | None = None
        self.linear1: nn.Module | None = None
        self.linear2: nn.Module | None = None
        self.activation = activation

        # Number of per-step FFN copies (same rule as the attention module's `mult`).
        if not weights_per_step:
            num_weights = 1
        elif weights_per_step_schedule:
            num_weights = max(weights_per_step_schedule) + 1
        else:
            num_weights = weights_per_step

        if gating == "none":
            # NOTE(review): weights_per_step does not create per-step MLP copies in this branch.
            self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False, **factory_kwargs)
            self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False, **factory_kwargs)
        elif weights_per_step:
            # dim_feedforward is either one width shared by all copies or a list of widths.
            ff_widths = [dim_feedforward] * num_weights if isinstance(dim_feedforward, int) else dim_feedforward
            self.gating = nn.ModuleList(
                [make_gating(gating, d_model, width, **factory_kwargs) for width in ff_widths]
            )
        else:
            self.gating = make_gating(gating, d_model, dim_feedforward, **factory_kwargs)

        if layer_scale is not None:
            scale_kwargs = cast(dict[str, object], factory_kwargs)
            self.layer_scale_1 = MossAudioTokenizerLayerScale(
                channels=d_model, init=layer_scale, channel_last=True, **scale_kwargs
            )
            self.layer_scale_2 = MossAudioTokenizerLayerScale(
                channels=d_model, init=layer_scale, channel_last=True, **scale_kwargs
            )
        else:
            self.layer_scale_1 = nn.Identity()
            self.layer_scale_2 = nn.Identity()

    def _init_streaming_state(self, batch_size: int) -> LayerState:
        """Layer state on the device of the first parameter, counter starting at 0."""
        device = next(self.parameters()).device
        return LayerState(batch_size, device, offset_cpu=0)

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        """Residual feed-forward sub-block."""
        state = self._streaming_state
        # Per-step FFN selection starts from the number of steps already processed.
        step = state.offset_cpu if isinstance(state, LayerState) else 0
        residual = x
        normed = self.norm2(x)
        if self.gating is not None:
            if self.weights_per_step:
                assert isinstance(self.gating, nn.ModuleList)
                update = apply_weights_per_step(self.gating, self.weights_per_step_schedule, normed, step)
            else:
                update = self.gating(normed)
        else:
            assert self.linear1 is not None
            assert self.linear2 is not None
            update = self.linear2(self.activation(self.linear1(normed)))
        return residual.to(update) + self.layer_scale_2(update)

    def _sa_block(self, x: torch.Tensor):
        """Residual self-attention sub-block."""
        residual = x
        normed = self.norm1(x)
        update = self.self_attn(normed, normed, normed)
        return residual.to(update) + self.layer_scale_1(update)

    def forward(self, x: torch.Tensor):
        out = self._ff_block(self._sa_block(x))
        state = self._streaming_state
        if state is not None:
            assert isinstance(state, LayerState)
            state.offset_cpu += out.shape[1]
        return out
# =============================================================================
# Streaming Transformer
# =============================================================================
@dataclass
class TransformerState(StreamingState):
    """Streaming state of a transformer stack: per-row count of processed steps."""

    offsets: torch.Tensor

    def reset(self, reset_mask: torch.Tensor):
        """Reset the base mask and zero the offsets of the selected rows."""
        super().reset(reset_mask)
        cleared = torch.zeros_like(self.offsets)
        self.offsets[:] = torch.where(reset_mask, cleared, self.offsets)
class MossAudioTokenizerTransformer(StreamingModule):
    """Stack of transformer layers with optional sinusoidal and/or rotary positions.

    ``positional_embedding`` may be ``"sin"`` (added sinusoidal embedding), ``"rope"``
    (rotary embedding shared by all layers) or ``"sin_rope"`` (both). Extra ``kwargs`` are
    forwarded to every ``MossAudioTokenizerTransformerLayer``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_layers: int,
        dim_feedforward: int = 2048,
        causal: bool = False,
        context: int | None = None,
        positional_embedding: str = "sin",
        max_period: float = 10_000,
        positional_scale: float = 1.0,
        device=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"d_model must be divisible by num_heads, got d_model={d_model}, num_heads={num_heads}")
        self.positional_embedding = positional_embedding
        self.max_period = max_period
        self.positional_scale = positional_scale
        self.rope: MossAudioTokenizerRotaryEmbedding | None = None
        if positional_embedding in {"rope", "sin_rope"}:
            self.rope = MossAudioTokenizerRotaryEmbedding(max_period=max_period)
        self.layers = nn.ModuleList(
            [
                MossAudioTokenizerTransformerLayer(
                    d_model=d_model,
                    num_heads=num_heads,
                    dim_feedforward=dim_feedforward,
                    causal=causal,
                    context=context,
                    rope=self.rope,
                    device=device,
                    dtype=dtype,
                    **kwargs,
                )
                for _ in range(num_layers)
            ]
        )

    def _init_streaming_state(self, batch_size: int) -> TransformerState:
        """State with zeroed per-row offsets on the device of the first parameter."""
        device = next(self.parameters()).device
        zero_offsets = torch.zeros(batch_size, device=device, dtype=torch.long)
        return TransformerState(batch_size, device, offsets=zero_offsets)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Run a ``(B, T, C)`` input through all layers and return the same shape."""
        _, steps, channels = x.shape
        state = self._streaming_state
        if isinstance(state, TransformerState):
            offsets = state.offsets
        else:
            offsets = torch.zeros(1, dtype=torch.long, device=x.device)

        if self.positional_embedding in {"sin", "sin_rope"}:
            positions = torch.arange(steps, device=x.device).view(1, -1, 1) + offsets.view(-1, 1, 1)
            pos_emb = create_sin_embedding(positions, channels, max_period=self.max_period, dtype=x.dtype)
            x = x + self.positional_scale * pos_emb

        for layer in self.layers:
            x = layer(x, *args, **kwargs)

        if state is not None:
            assert isinstance(state, TransformerState)
            state.offsets[:] = torch.where(state.exec_mask, state.offsets + steps, state.offsets)
        return x
class MossAudioTokenizerProjectedTransformer(StreamingContainer):
    """Transformer wrapped with linear input/output projections on channel-first tensors.

    ``forward`` takes ``(B, input_dimension, T)`` and returns
    ``(B, output_dimension, T)`` together with the unchanged ``input_lengths``. A projection
    is an ``nn.Identity`` when the dimensions already match ``d_model``.

    NOTE(review): ``conv_layout`` is stored but not consulted; ``forward`` always transposes
    between channel-first and channel-last.
    """

    def __init__(
        self,
        input_dimension: int,
        output_dimension: int,
        d_model: int,
        *,
        conv_layout: bool = False,
        module_type: str,
        **kwargs,
    ):
        super().__init__()
        self.module_type = module_type
        self.downsample_ratio: int = 1
        self.input_dimension = input_dimension
        self.output_dimension = output_dimension
        self.input_proj = (
            nn.Identity() if d_model == input_dimension else nn.Linear(input_dimension, d_model, bias=False)
        )
        self.transformer = MossAudioTokenizerTransformer(d_model=d_model, **kwargs)
        self.conv_layout = conv_layout
        self.output_proj = (
            nn.Identity() if d_model == output_dimension else nn.Linear(d_model, output_dimension, bias=False)
        )

    def forward(self, x, input_lengths, *args, **kwargs):
        channel_last = x.transpose(1, 2)  # (B, D, T) -> (B, T, D)
        hidden = self.transformer(self.input_proj(channel_last), *args, **kwargs)
        out = self.output_proj(hidden).transpose(1, 2)  # (B, T, D) -> (B, D, T)
        return out, input_lengths
# =============================================================================
# Patched Pretransform Module
# =============================================================================
class MossAudioTokenizerPatchedPretransform(nn.Module):
    """Reversible patching between sample-rate and frame-rate layouts.

    ``encode`` folds every ``patch_size`` consecutive time steps into the channel axis:
    ``(B, D, T) -> (B, D * patch_size, T // patch_size)``. ``decode`` is its exact inverse.
    ``forward`` runs one or the other, as chosen by ``is_downsample`` at construction.
    """

    def __init__(self, patch_size: int, is_downsample: bool, module_type: str, **kwargs):
        super().__init__()
        self.patch_size = patch_size
        self.downsample_ratio: int = patch_size
        self.is_downsample = is_downsample
        self.module_type = module_type

    def encode(self, x, input_lengths):
        """Fold time into channels and convert valid lengths to frame counts.

        ``T`` must be a multiple of ``patch_size`` (the reshape fails otherwise).
        Lengths use ceil division, so a final frame that is only partly valid is still
        counted. This matches padding the waveform up to a multiple of the downsample rate.
        """
        b, d, _ = x.shape
        h = self.patch_size
        x = x.reshape(b, d, -1, h).permute(0, 1, 3, 2).reshape(b, d * h, -1)
        # Ceil division: floor would drop the last partially padded frame.
        output_lengths = (input_lengths + h - 1) // h
        return x, output_lengths

    def decode(self, x, input_lengths):
        """Unfold channels back into time; lengths are scaled by ``patch_size``.

        The channel count must be divisible by ``patch_size`` (the reshape fails otherwise).
        """
        b, dh, l = x.shape
        h = self.patch_size
        d = dh // h
        x = x.reshape(b, d, h, l).permute(0, 1, 3, 2).reshape(b, d, l * h)
        output_lengths = input_lengths * self.patch_size
        return x, output_lengths

    def forward(self, x, input_lengths):
        if self.is_downsample:
            return self.encode(x, input_lengths)
        return self.decode(x, input_lengths)
# =============================================================================
# Vector Quantization
# =============================================================================
def WNConv1d(*args, **kwargs):
"""Weight-normalized Conv1d."""
return nn.utils.parametrizations.weight_norm(nn.Conv1d(*args, **kwargs))
class MossAudioTokenizerVectorQuantize(nn.Module):
"""Single codebook vector quantization (inference only)."""
def __init__(
self,
input_dim: int,
codebook_size: int,
codebook_dim: int,
**kwargs,
):
super().__init__()
self.input_dim = input_dim
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
if input_dim != codebook_dim:
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
else:
self.in_proj = nn.Identity()
self.out_proj = nn.Identity()
self.codebook = nn.Embedding(codebook_size, codebook_dim)
@torch.no_grad()
def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
z: Input tensor of shape (B, D, T)