-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTFMixer.py
More file actions
736 lines (595 loc) · 26.5 KB
/
TFMixer.py
File metadata and controls
736 lines (595 loc) · 26.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
# models/TFMixer.py
# -*- coding: utf-8 -*-
"""
Bridging Time and Frequency: A Joint Modeling Framework for Irregular Multivariate Time Series Forecasting.
This module implements the TFMixer model, which integrates global frequency analysis
(via Non-Uniform DFT) and local patch-based representation learning (Query Attention and MLP-Mixer)
to capture both global and local temporal dependencies.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat
class MaskedRevIN(nn.Module):
    """Reversible instance normalization that ignores missing observations.

    Per-sample, per-variable mean and standard deviation are computed over the
    time axis using only positions where ``mask`` is non-zero. They are stored
    on the module by ``mode="norm"`` and reused by ``mode="denorm"`` to map
    predictions back to the original scale.
    """

    def __init__(self, num_features: int, eps=1e-5, affine=False):
        """
        Args:
            num_features: Number of variables (N).
            eps: Small constant added to the variance to avoid division by zero.
            affine: Whether to apply a learnable per-variable affine transform.
        """
        super(MaskedRevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        if self.affine:
            self._init_params()

    def _init_params(self):
        # Affine transform parameters: per-variable scale (weight) and shift (bias).
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x, mask):
        """Compute masked mean / stdev over the time axis and store them on ``self``.

        Args:
            x: [Batch, Seq_Len, N_vars]
            mask: [Batch, Seq_Len, N_vars] (1 = observed, 0 = missing/padding).
                Bool or integer masks are accepted; they are cast to ``x.dtype``.
        """
        # Cast so bool/int masks act as 0/1 floats below and the clamped count
        # stays floating point regardless of the mask's original dtype.
        mask = mask.to(x.dtype)
        # 1. Number of observed points per (sample, variable): [Batch, 1, N_vars].
        valid_count = mask.sum(dim=1, keepdim=True)
        # A fully-missing variable would give 0/0. With the count clamped to 1
        # the numerator is also 0, so its mean becomes 0.
        valid_count = torch.clamp(valid_count, min=1.0)
        # 2. Masked mean: sum(x * mask) / count.
        mean = (x * mask).sum(dim=1, keepdim=True) / valid_count
        # 3. Masked variance -> stdev. Multiplying by the mask ensures only
        #    observed positions contribute to the squared deviations.
        var = ((x - mean) ** 2 * mask).sum(dim=1, keepdim=True) / valid_count
        stdev = torch.sqrt(var + self.eps)
        # Statistics are treated as constants: no gradient flows into them.
        self.mean = mean.detach()
        self.stdev = stdev.detach()

    def forward(self, x, mask, mode: str):
        """Normalize (``mode="norm"``) or de-normalize (``mode="denorm"``) ``x``.

        Args:
            x: [Batch, Seq_Len, N]
            mask: [Batch, Seq_Len, N] observation mask; only used for "norm".
            mode: "norm" computes statistics from (x, mask) and normalizes;
                "denorm" reuses the statistics from the last "norm" call.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            RuntimeError: if "denorm" is requested before any "norm" call.
            NotImplementedError: for any other ``mode`` value.
        """
        if mode == "norm":
            self._get_statistics(x, mask)
            # Missing positions (stored as 0) become non-zero here; downstream
            # code is expected to re-apply the mask where it matters.
            x = (x - self.mean) / self.stdev
            if self.affine:
                x = x * self.affine_weight + self.affine_bias
        elif mode == "denorm":
            if getattr(self, "stdev", None) is None:
                raise RuntimeError(
                    "MaskedRevIN: call with mode='norm' before mode='denorm'"
                )
            # Inverse of "norm": (x - bias) / weight * std + mean.
            if self.affine:
                x = (x - self.affine_bias) / (self.affine_weight + 1e-10)
            x = x * self.stdev + self.mean
        else:
            raise NotImplementedError(f"Unknown MaskedRevIN mode: {mode!r}")
        return x
class PositionalEncoding(nn.Module):
    """
    Standard sinusoidal positional encoding.

    Adds a fixed (non-learned) position signal to a ``(Batch, Seq_Len, d_model)``
    input. Even feature indices hold sines and odd indices hold cosines of
    geometrically spaced frequencies. Both even and odd ``d_model`` are supported.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Geometric progression of frequencies: 10000^(-2i / d_model).
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        # Sine on even indices, cosine on odd indices. For odd d_model there is
        # one fewer odd column than even column, so the cosine part uses only
        # the first d_model // 2 frequencies (otherwise the assignment fails).
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        # Buffer: follows .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape (Batch, Seq_Len, d_model), Seq_Len <= max_len.
        Returns:
            Tensor with positional encoding added (same shape as ``x``).
        Raises:
            ValueError: if Seq_Len exceeds the ``max_len`` given at construction.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"PositionalEncoding: sequence length {seq_len} exceeds "
                f"max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len, :]
class GlobalRawNUDFT(nn.Module):
    """
    Global Non-Uniform Discrete Fourier Transform (NUDFT) layer.

    Projects each (possibly irregular, masked) univariate series onto a bank of
    learnable frequencies, encodes the resulting real/imaginary coefficients,
    optionally mixes them across variables, and maps them to ``d_model``.
    The encoded coefficients can also be synthesized back into the time
    domain at arbitrary timestamps via ``forecast``.
    """

    def __init__(self, d_model: int, num_freqs: int = 64, n_vars: int = None):
        """
        Args:
            d_model: Hidden dimension of the returned embedding.
            num_freqs: Number of learnable frequencies (initialized to 1..num_freqs).
            n_vars: Number of variables per sample. When given, a residual MLP
                mixes the spectrum across variables; when None, no mixing.
        """
        super(GlobalRawNUDFT, self).__init__()
        self.num_freqs = num_freqs
        self.d_model = d_model
        self.n_vars = n_vars
        self.freqs = nn.Parameter(torch.arange(1, num_freqs + 1, dtype=torch.float32))

        # Cross-variable mixer: operates on the variable axis, one frequency
        # coefficient at a time.
        self.var_mixer = (
            None if n_vars is None else self._two_layer_mlp(n_vars, n_vars, n_vars)
        )

        coeff_width = 2 * num_freqs  # real + imaginary coefficients
        # Refines the raw spectrum in place (same width in and out).
        self.spectrum_encoder = self._two_layer_mlp(coeff_width, coeff_width, coeff_width)
        # Maps the spectrum to the model's hidden size.
        self.freq_mlp = self._two_layer_mlp(coeff_width, 2 * d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

    @staticmethod
    def _two_layer_mlp(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Sequential:
        """Build Linear -> GELU -> Dropout(0.2) -> Linear, the shape shared by every MLP here."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, mask: torch.Tensor):
        """
        Compute and encode the NUDFT spectrum of each series.

        Args:
            x: Values, shape (Batch * N_vars, Length, 1).
            t: Timestamps, shape (Batch * N_vars, Length, 1).
            mask: Observation mask, shape (Batch * N_vars, Length, 1).

        Returns:
            embedding: (Batch * N_vars, d_model), layer-normalized.
            coeffs: (Batch * N_vars, 2 * num_freqs), the encoded (and, if
                enabled, variable-mixed) spectrum used by ``forecast``.
        """
        omega_t = 2.0 * math.pi * t * self.freqs.view(1, 1, -1)
        observed = x * mask
        # Average over observed points; the floor of 5 also damps the spectrum
        # of series with very few observations.
        n_obs = mask.sum(dim=1).clamp(min=5.0)
        real_part = (observed * torch.cos(omega_t)).sum(dim=1) / n_obs
        imag_part = (observed * -torch.sin(omega_t)).sum(dim=1) / n_obs
        coeffs = self.spectrum_encoder(torch.cat((real_part, imag_part), dim=-1))

        if self.var_mixer is not None:
            total_rows, width = coeffs.shape
            n_vars = self.n_vars
            # [B*N, F2] -> [B, N, F2] -> [B, F2, N] so the mixer's Linear acts on N.
            per_sample = coeffs.view(total_rows // n_vars, n_vars, width).transpose(1, 2)
            per_sample = per_sample + self.var_mixer(per_sample)
            coeffs = per_sample.transpose(1, 2).reshape(total_rows, width)

        return self.norm(self.freq_mlp(coeffs)), coeffs

    def forecast(self, spectrum: torch.Tensor, t_future: torch.Tensor) -> torch.Tensor:
        """
        Harmonic synthesis: y(t) = sum_k Real_k * cos(w_k t) - Imag_k * sin(w_k t).

        Args:
            spectrum: (B * N, 2 * num_freqs) coefficients; first half real, second half imaginary.
            t_future: (B * N, L_pred, 1) timestamps at which to evaluate.

        Returns:
            (B * N, L_pred, 1) synthesized values.
        """
        n_freq = self.num_freqs
        cos_coef = spectrum[:, :n_freq].unsqueeze(1)
        sin_coef = spectrum[:, n_freq:].unsqueeze(1)
        omega_t = 2.0 * math.pi * t_future * self.freqs.view(1, 1, -1)
        wave = (cos_coef * torch.cos(omega_t)).sum(dim=-1) + (
            sin_coef * -torch.sin(omega_t)
        ).sum(dim=-1)
        return wave.unsqueeze(-1)
class TFMixer_SubModel(nn.Module):
    """
    Core architecture of the TFMixer model.
    Combines Temporal Convolution (TTCN), Patch Attention, and dual-mixing (Time/Variable) layers.

    ``forward`` runs two branches on the observed history:
      * a global frequency branch (``GlobalRawNUDFT``) giving [B, N, d_model], and
      * a local patch branch (TTCN -> learnable-query attention -> patch/variable
        MLP mixing) giving [B, N, d_model].
    They are fused, decoded at the timestamps in ``y_mark``, and a harmonic
    extrapolation of the learned spectrum is added. A reconstruction loss of
    the observed history from that spectrum is returned as well.
    """

    def __init__(self, configs):
        """
        Args:
            configs: Config object; this constructor reads ``d_model``, ``enc_in``,
                ``dropout``, ``tfmixer_num_freqs``, ``tfmixer_npatch``,
                ``tfmixer_te_dim``, ``tfmixer_nlayer``, ``tfmixer_tf_layer`` and
                ``tfmixer_K``.

        NOTE(review): modules and parameters below draw from the global RNG in
        the order they are created; reordering them changes initialization
        under a fixed seed.
        """
        super(TFMixer_SubModel, self).__init__()
        self.configs = configs
        # --- Hyperparameters ---
        self.hid_dim = configs.d_model
        self.N = configs.enc_in  # Number of variables
        self.dropout = configs.dropout
        self.num_freqs = configs.tfmixer_num_freqs
        # Patching and Architecture params
        self.n_patch = configs.tfmixer_npatch  # Number of patches the history is split into
        self.te_dim = configs.tfmixer_te_dim  # Time Encoding dimension
        self.n_layer = configs.tfmixer_nlayer  # Number of dual-mixing blocks
        self.tf_layer = configs.tfmixer_tf_layer  # Hidden-width expansion factor of the mixer MLPs
        self.K = configs.tfmixer_K  # Number of learnable attention queries per variable
        self.outlayer = "Linear"  # Output aggregation type ("Linear" or "CNN")
        # --- Modules Initialization ---
        # 1. Global Frequency Branch (n_vars=None: no cross-variable mixing inside the NUDFT)
        self.global_nudft = GlobalRawNUDFT(
            d_model=self.hid_dim, num_freqs=self.num_freqs, n_vars=None
        )
        # Learnable scalars: weight of the frequency features in the fusion, and
        # weight of the harmonic (seasonal) extrapolation added to the output.
        self.fusion_weight = nn.Parameter(torch.tensor(0.5))
        self.seasonal_weight = nn.Parameter(torch.tensor(0.1))
        # 2. Time Encoding (TE) Generators: 1 linear term + (te_dim - 1) sinusoidal terms
        self.te_scale = nn.Linear(1, 1)
        self.te_periodic = nn.Linear(1, self.te_dim - 1)
        # Node Embeddings (learnable identity vector for each variable)
        self.node_emb = nn.Parameter(torch.randn(1, 1, self.N, self.hid_dim))
        # 3. TTCN: a convolution whose per-position filters are generated from the input itself.
        # Each input point is [value, time encoding] -> 1 + te_dim features.
        input_dim = 1 + self.te_dim
        # One channel is reserved for the patch-validity flag concatenated in
        # IMTS_Model_Logic, so TTCN output + flag has exactly hid_dim channels.
        ttcn_dim = self.hid_dim - 1
        self.ttcn_dim = ttcn_dim
        # Produces, for every time point, an (input_dim x ttcn_dim) filter (flattened).
        self.Filter_Generators = nn.Sequential(
            nn.Linear(input_dim, ttcn_dim, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(ttcn_dim, ttcn_dim, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(ttcn_dim, input_dim * ttcn_dim, bias=True),
        )
        self.T_bias = nn.Parameter(torch.randn(1, ttcn_dim))
        # 4. Patch Attention Components
        # K learnable queries of width ttcn_dim + 1 (== hid_dim) attend over the patches.
        self.query_patches = nn.Parameter(torch.randn(1, 1, self.K, ttcn_dim + 1))
        self.ADD_PE = PositionalEncoding(self.hid_dim)
        # 5. Mixer Layers (Dual-Mixing)
        # Patch Mixer: MLP applied along the K (query) axis
        self.patch_mixer_layers = nn.ModuleList()
        for _ in range(self.n_layer):
            self.patch_mixer_layers.append(
                nn.Sequential(
                    nn.Linear(self.K, self.K * self.tf_layer),
                    nn.GELU(),
                    nn.Dropout(self.dropout),
                    nn.Linear(self.K * self.tf_layer, self.K),
                    nn.Dropout(self.dropout),
                )
            )
        # Variable Mixer: MLP applied along the N (variable) axis
        self.var_mixer_layers = nn.ModuleList()
        for _ in range(self.n_layer):
            self.var_mixer_layers.append(
                nn.Sequential(
                    nn.Linear(self.N, self.N * self.tf_layer),
                    nn.GELU(),
                    nn.Dropout(self.dropout),
                    nn.Linear(self.N * self.tf_layer, self.N),
                    nn.Dropout(self.dropout),
                )
            )
        self.norm_patch = nn.ModuleList(
            [nn.LayerNorm(self.hid_dim) for _ in range(self.n_layer)]
        )
        self.norm_var = nn.ModuleList(
            [nn.LayerNorm(self.hid_dim) for _ in range(self.n_layer)]
        )
        # 6. Output Layers
        # Collapses the K query outputs of each variable into one hid_dim vector.
        if self.outlayer == "Linear":
            self.temporal_agg = nn.Sequential(
                nn.Linear(self.hid_dim * self.K, self.hid_dim)
            )
        elif self.outlayer == "CNN":
            self.temporal_agg = nn.Sequential(
                nn.Conv1d(self.hid_dim, self.hid_dim, kernel_size=self.K)
            )
        # Maps [hidden state, time encoding of a target timestamp] -> one scalar prediction.
        self.decoder = nn.Sequential(
            nn.Linear(self.hid_dim + self.te_dim, self.hid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.hid_dim, self.hid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.hid_dim, 1),
        )

    def LearnableTE(self, tt: torch.Tensor) -> torch.Tensor:
        """
        Generates learnable time encodings consisting of a scaling term and periodic terms.

        Args:
            tt: Timestamps with a trailing singleton axis, shape (..., 1).
        Returns:
            (..., te_dim): one linear feature followed by te_dim - 1 sine features.
        """
        out1 = self.te_scale(tt)
        out2 = torch.sin(self.te_periodic(tt))
        return torch.cat([out1, out2], -1)

    def TTCN(self, X_int: torch.Tensor, mask_X: torch.Tensor) -> torch.Tensor:
        """
        Input-conditioned convolution over the points of each patch.

        Filters are generated from the input data itself (hyper-network style),
        normalized with a masked softmax over the time axis, and applied as a
        weighted sum over time and input features.

        Args:
            X_int: (N_dim, Lx, input_dim) with input_dim = 1 + te_dim; one row per patch.
            mask_X: (N_dim, Lx, 1) observation mask.
        Returns:
            (N_dim, ttcn_dim) patch representations (after ReLU).
        """
        N_dim, Lx, _ = mask_X.shape
        # Generate filters: (N_dim, Lx, input_dim * ttcn_dim)
        Filter = self.Filter_Generators(X_int)
        # Push unobserved positions to -1e8 so they get ~0 weight in the softmax
        Filter_mask = Filter * mask_X + (1 - mask_X) * (-1e8)
        # Normalize filters over the sequence length (time axis)
        Filter_seqnorm = F.softmax(Filter_mask, dim=-2)
        Filter_seqnorm = Filter_seqnorm.view(N_dim, Lx, self.ttcn_dim, -1)
        # Broadcast input to (N_dim, Lx, ttcn_dim, input_dim) and apply the filters:
        # sum over time (dim -3) and then over input features (dim -1).
        X_int_broad = X_int.unsqueeze(dim=-2).repeat(1, 1, self.ttcn_dim, 1)
        ttcn_out = torch.sum(torch.sum(X_int_broad * Filter_seqnorm, dim=-3), dim=-1)
        h_t = torch.relu(ttcn_out + self.T_bias)
        return h_t

    def IMTS_Model_Logic(self, x, mask_X, B):
        """
        Main logic for patch processing, attention, and mixing.

        Args:
            x: (B * N * n_patch, patch_len, 1 + te_dim) patch points with time encodings,
                ordered batch-major, then variable, then patch.
            mask_X: (B * N * n_patch, patch_len, 1) observation mask.
            B: Batch size.
        Returns:
            (B, N, hid_dim) local representation per variable.
        """
        # 1. Patch Representation via TTCN
        # True where the patch contains at least one observation: (B*N*n_patch, 1)
        mask_patch = mask_X.sum(dim=1) > 0
        x_patch = self.TTCN(x, mask_X)
        # Concatenate convolution output with the patch validity flag -> hid_dim channels
        x_patch = torch.cat([x_patch, mask_patch], dim=-1)
        # Reshape to [Batch, Variables, Num_Patches, Dim]
        x_patch = x_patch.view(B, self.N, self.n_patch, -1)
        B_curr, N, M, D = x_patch.shape
        # 2. Add Node Embeddings (Variable Identity): (1, N, 1, D) broadcasts over patches
        node_identities = self.node_emb.permute(0, 2, 1, 3)
        x_patch = x_patch + node_identities
        # 3. Add Positional Encoding along the patch axis
        x_patch = x_patch.view(B_curr * N, M, D)
        x_patch = self.ADD_PE(x_patch).view(B_curr, N, M, D)
        # 4. Attention: K learnable queries attend over the M patches of each variable.
        # Scores are not masked; empty patches differ only via their validity flag channel.
        Q = self.query_patches.expand(B_curr, N, -1, -1)
        K_mat = x_patch
        V_mat = x_patch
        scores = torch.matmul(Q, K_mat.transpose(-1, -2)) / math.sqrt(D)
        attn_weights = torch.softmax(scores, dim=-1)
        x_patch = torch.matmul(attn_weights, V_mat)  # [Batch, N, K, D]
        x_curr = x_patch
        # 5. Dual Mixing Blocks (Time-Mixing & Variable-Mixing)
        for layer in range(self.n_layer):
            # --- A. Patch (query-axis) Mixing ---
            # Permute to apply Linear layer on dimension K
            # [B, N, K, D] -> [B, N, D, K]
            x_in = x_curr.permute(0, 1, 3, 2)
            # MLP applied over K dimension
            x_out = self.patch_mixer_layers[layer](x_in)
            # Restore: [B, N, D, K] -> [B, N, K, D]
            x_out = x_out.permute(0, 1, 3, 2)
            # Residual Connection + Normalization
            x_curr = self.norm_patch[layer](x_curr + x_out)
            # --- B. Variable Mixing ---
            # Permute to apply Linear layer on dimension N
            # [B, N, K, D] -> [B, K, D, N]
            x_in_var = x_curr.permute(0, 2, 3, 1)
            # MLP applied over N dimension (inter-variable interaction)
            x_out_var = self.var_mixer_layers[layer](x_in_var)
            # Restore: [B, K, D, N] -> [B, N, K, D]
            x_out_var = x_out_var.permute(0, 3, 1, 2)
            # Residual Connection + Normalization
            x_curr = self.norm_var[layer](x_curr + x_out_var)
        # 6. Output Aggregation: [B, N, K, D] -> [B, N, hid_dim]
        if self.outlayer == "CNN":
            x_curr = x_curr.reshape(B_curr * self.N, self.K, -1).permute(0, 2, 1)
            x_curr = self.temporal_agg(x_curr)
            x_curr = x_curr.view(B_curr, self.N, -1)
        elif self.outlayer == "Linear":
            # Flatten the K query outputs before aggregation
            x_curr = x_curr.reshape(B_curr, self.N, -1)  # (B, N, K * D)
            x_curr = self.temporal_agg(x_curr)  # (B, N, hid_dim)
        return x_curr

    def _prepare_global_inputs(
        self, x: torch.Tensor, x_mark: torch.Tensor, x_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepares flattened per-variable inputs for global frequency processing.

        Args:
            x: [B, L_obs, N] values.
            x_mark: [B, L_obs, *] time features; only channel 0 is used.
            x_mask: [B, L_obs, N] observation mask.
        Returns:
            x_flat, t_flat, mask_flat: each [B * N, L_obs, 1];
            time_features: [B, L_obs, 1] (channel 0 of x_mark).
        """
        B, L_obs, N_vars = x.shape
        # Channel 0 of x_mark is used as the timestamp.
        # NOTE(review): assumes it holds the (normalized) observation time — confirm with the data loader.
        time_features = x_mark[:, :, [0]]
        time_features_expanded = time_features.repeat(1, 1, N_vars)
        # Flatten: [Batch, L, N] -> [Batch * N, L, 1]
        x_flat = x.permute(0, 2, 1).reshape(B * N_vars, L_obs, 1)
        t_flat = time_features_expanded.permute(0, 2, 1).reshape(B * N_vars, L_obs, 1)
        mask_flat = x_mask.permute(0, 2, 1).reshape(B * N_vars, L_obs, 1)
        return x_flat, t_flat, mask_flat, time_features

    def _process_local_branch(
        self, x: torch.Tensor, time_features: torch.Tensor, x_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Handles local patch creation, padding, time encoding, and the attention mechanism.

        Args:
            x: [B, L_obs, N] values.
            time_features: [B, L_obs, 1] timestamps.
            x_mask: [B, L_obs, N] observation mask.
        Returns:
            [B, N, hid_dim] local representation.
        """
        B, L_obs, N_vars = x.shape
        # 1. Right-pad with zeros (mask 0) so L is divisible by n_patch
        remainder = L_obs % self.n_patch
        if remainder != 0:
            padding_len = self.n_patch - remainder
            pad_x = torch.zeros(
                (B, padding_len, N_vars), device=x.device, dtype=x.dtype
            )
            pad_mask = torch.zeros(
                (B, padding_len, N_vars), device=x.device, dtype=x_mask.dtype
            )
            pad_t = torch.zeros(
                (B, padding_len, 1), device=x.device, dtype=time_features.dtype
            )
            x_padded = torch.cat([x, pad_x], dim=1)
            mask_padded = torch.cat([x_mask, pad_mask], dim=1)
            t_padded = torch.cat([time_features, pad_t], dim=1)
        else:
            x_padded, mask_padded, t_padded = x, x_mask, time_features
        # 2. Reshape into n_patch contiguous patches of equal length
        L_padded = x_padded.shape[1]
        patch_len = L_padded // self.n_patch
        X_patches = x_padded.reshape(B, self.n_patch, patch_len, N_vars)
        Mask_patches = mask_padded.reshape(B, self.n_patch, patch_len, N_vars)
        T_patches = t_padded.reshape(B, self.n_patch, patch_len, 1).repeat(
            1, 1, 1, N_vars
        )
        # 3. Permute to [B, N, n_patch, patch_len] and flatten to [B*N*n_patch, patch_len, 1]
        X_patches = X_patches.permute(0, 3, 1, 2).reshape(-1, patch_len, 1)
        Mask_patches = Mask_patches.permute(0, 3, 1, 2).reshape(-1, patch_len, 1)
        T_patches = T_patches.permute(0, 3, 1, 2).reshape(-1, patch_len, 1)
        # 4. Append time encodings to the values and run the patch model
        te_his = self.LearnableTE(T_patches)
        X_with_te = torch.cat([X_patches, te_his], dim=-1)
        return self.IMTS_Model_Logic(X_with_te, Mask_patches, B)

    def _decode_prediction(self, h: torch.Tensor, y_mark: torch.Tensor) -> torch.Tensor:
        """
        Expands the latent state and projects it to the output horizon using the decoder.

        Args:
            h: [B, N, hid_dim] fused representation.
            y_mark: [B, L_pred, *] future time features; channel 0 is used as the timestamp.
        Returns:
            [B, L_pred, N] base predictions.
        """
        B, N_vars, _ = h.shape
        time_steps_to_predict = y_mark[:, :, [0]]
        L_pred = time_steps_to_predict.shape[1]
        # Expand hidden state: [B, N, D] -> [B, N, L_pred, D]
        h_expanded = h.unsqueeze(dim=-2).repeat(1, 1, L_pred, 1)
        # Prepare future time encodings: [B, N, L_pred, te_dim]
        time_steps_exp = time_steps_to_predict.view(B, 1, L_pred, 1).repeat(
            1, N_vars, 1, 1
        )
        te_pred = self.LearnableTE(time_steps_exp)
        # Concatenate and decode
        decoder_input = torch.cat([h_expanded, te_pred], dim=-1)
        outputs_raw = self.decoder(decoder_input)
        # Result shape: [B, L_pred, N]
        return outputs_raw.squeeze(-1).permute(0, 2, 1)

    def _add_seasonal_extrapolation(
        self,
        base_prediction: torch.Tensor,
        raw_spectrum: torch.Tensor,
        y_mark: torch.Tensor,
    ) -> torch.Tensor:
        """
        Computes the harmonic seasonal forecast and adds it (scaled by
        ``seasonal_weight``) to the base prediction.

        Args:
            base_prediction: [B, L_pred, N].
            raw_spectrum: [B * N, 2 * num_freqs] from the global branch.
            y_mark: [B, L_pred, *]; channel 0 is used as the timestamp.
        Returns:
            [B, L_pred, N].
        """
        B, L_pred, N_vars = base_prediction.shape
        time_steps_to_predict = y_mark[:, :, [0]]
        # Prepare future timestamps: [B*N, L_pred, 1]
        t_pred_flat = (
            time_steps_to_predict.repeat(1, 1, N_vars)
            .permute(0, 2, 1)
            .reshape(B * N_vars, -1, 1)
        )
        # Harmonic Forecasting
        seasonal_forecast_flat = self.global_nudft.forecast(raw_spectrum, t_pred_flat)
        # Reshape to [B, L_pred, N_vars]
        seasonal_forecast = seasonal_forecast_flat.view(B, N_vars, -1).permute(0, 2, 1)
        return base_prediction + self.seasonal_weight * seasonal_forecast

    def _compute_reconstruction_loss(
        self,
        raw_spectrum: torch.Tensor,
        x_flat: torch.Tensor,
        t_flat: torch.Tensor,
        mask_flat: torch.Tensor,
    ) -> torch.Tensor:
        """
        Self-supervised reconstruction loss: masked mean absolute error between
        the observed history and its harmonic synthesis from ``raw_spectrum``.

        Returns:
            Scalar tensor; sum of |error| over observed points / (count + 1e-5).
        """
        # Reconstruct history from spectrum
        recon_history_flat = self.global_nudft.forecast(raw_spectrum, t_flat)
        # Calculate masked MAE
        recon_err = torch.abs(recon_history_flat - x_flat)
        recon_err_masked = recon_err * mask_flat
        valid_points = mask_flat.sum()
        return recon_err_masked.sum() / (valid_points + 1e-5)

    def forward(
        self,
        x: torch.Tensor,
        x_mark: torch.Tensor,
        x_mask: torch.Tensor,
        y_mark: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the model.

        Pipeline:
            1. Data Preparation
            2. Global Frequency Branch (NUDFT)
            3. Local Time Branch (Patch Attention)
            4. Feature Fusion
            5. Decoding
            6. Seasonal Extrapolation
            7. Loss Calculation

        Args:
            x: [B, L_obs, N] observed values.
            x_mark: [B, L_obs, *] time features (channel 0 used as timestamp).
            x_mask: [B, L_obs, N] observation mask.
            y_mark: [B, L_pred, *] future time features (channel 0 used as timestamp).
        Returns:
            outputs: [B, L_pred, N] predictions.
            recon_loss: scalar reconstruction loss of the observed history.
        """
        B, L_obs, N_vars = x.shape
        # 1. Prepare Inputs for Global Branch
        x_flat, t_flat, mask_flat, time_features = self._prepare_global_inputs(
            x, x_mark, x_mask
        )
        # 2. Global Frequency Branch
        h_freq_flat, raw_spectrum = self.global_nudft(x_flat, t_flat, mask_flat)
        h_freq = h_freq_flat.view(B, N_vars, -1)
        # 3. Local Time Branch (Patch Processing)
        h_time = self._process_local_branch(x, time_features, x_mask)
        # 4. Feature Fusion (learnable weight on the frequency features)
        h = h_time + self.fusion_weight * h_freq
        # 5. Decoding (Base Prediction)
        outputs = self._decode_prediction(h, y_mark)
        # 6. Seasonal Extrapolation (Harmonic Bias)
        outputs = self._add_seasonal_extrapolation(outputs, raw_spectrum, y_mark)
        # 7. Reconstruction Loss (Self-Supervision)
        recon_loss = self._compute_reconstruction_loss(
            raw_spectrum, x_flat, t_flat, mask_flat
        )
        return outputs, recon_loss
class Model(nn.Module):
    """
    High-level wrapper for the TFMixer forecasting model.

    Handles dictionary-based input/output for compatibility with experiment
    runners, and optionally wraps the core model with MaskedRevIN
    normalization (enabled when ``configs.mask_revin == True``).
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.configs = configs
        self.task_name = configs.task_name
        self.revin = MaskedRevIN(configs.enc_in)
        self.model = TFMixer_SubModel(configs)

    def forward(
        self, x: torch.Tensor, x_mark: torch.Tensor, x_mask: torch.Tensor, **kwargs
    ) -> dict:
        """
        Args:
            x: Input sequence.
            x_mark: Input time features.
            x_mask: Input validity mask.
            **kwargs: Must contain 'y_mark' (future time features); may contain
                'y' and 'y_mask'.
        Returns:
            dict: {
                "pred": Predictions,
                "true": Ground Truth (if available, else None),
                "mask": Ground Truth mask (if available, else None),
                "recon_loss": Reconstruction loss
            }
        Raises:
            ValueError: if 'y_mark' is missing; the decoder needs future timestamps.
        """
        y_mark = kwargs.get("y_mark")
        # Fail fast with a clear message (previously surfaced as an obscure
        # TypeError deep inside decoding) and before any normalization state changes.
        if y_mark is None:
            raise ValueError("TFMixer Model.forward requires the 'y_mark' keyword argument")

        # Strict comparison kept on purpose so non-bool config values (e.g.
        # strings) behave exactly as before.
        use_revin = self.configs.mask_revin == True  # noqa: E712

        if use_revin:
            x = self.revin(x, x_mask, "norm")
        predictions, recon_loss = self.model(x, x_mark, x_mask, y_mark)
        if use_revin:
            predictions = self.revin(predictions, None, "denorm")

        y = kwargs.get("y")
        y_mask = kwargs.get("y_mask")
        # "MS" keeps only the last variable; otherwise keep all variables.
        f_dim = -1 if self.configs.features == "MS" else 0
        return {
            "pred": predictions[:, :, f_dim:],
            "true": y[:, :, f_dim:] if y is not None else None,
            "mask": y_mask[:, :, f_dim:] if y_mask is not None else None,
            "recon_loss": recon_loss,
        }