Skip to content

Commit 12b6b11

Browse files
authored
Merge pull request #5 from LinGinQiu/my-fix
fix: rename _convert_X in TimesNet to avoid BaseClassifier collision
2 parents 31c9c95 + f02b12c commit 12b6b11

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

multiverse/classification/_timesnet.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ def _resolve_device(self) -> torch.device:
184184
return torch.device(self.device)
185185
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
186186

187-
def _convert_X(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
187+
@staticmethod
188+
def _preprocess_X(X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
188189
"""
189190
Convert aeon numpy3D input to TimesNet layout and create a mask.
190191
@@ -296,7 +297,7 @@ def _fit(self, X: np.ndarray, y):
296297
"X.shape[2] to match seq_len."
297298
)
298299

299-
x_t, mask = self._convert_X(X)
300+
x_t, mask = self._preprocess_X(X)
300301

301302
if self.standardise:
302303
self.scaler_ = _StandardisePerChannel().fit(x_t)
@@ -432,7 +433,7 @@ def _predict_proba(self, X: np.ndarray) -> np.ndarray:
432433
f"seq_len={self.seq_len_}."
433434
)
434435

435-
x_t, mask = self._convert_X(X)
436+
x_t, mask = self._preprocess_X(X)
436437

437438
if self.scaler_ is not None:
438439
x_t = self.scaler_.transform(x_t)

0 commit comments

Comments
 (0)