diff --git a/multiverse/classification/_timesnet.py b/multiverse/classification/_timesnet.py index 21fa1ab..b3a311d 100644 --- a/multiverse/classification/_timesnet.py +++ b/multiverse/classification/_timesnet.py @@ -184,7 +184,8 @@ def _resolve_device(self) -> torch.device: return torch.device(self.device) return torch.device("cuda" if torch.cuda.is_available() else "cpu") - def _convert_X(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + @staticmethod + def _preprocess_X(X: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """ Convert aeon numpy3D input to TimesNet layout and create a mask. @@ -296,7 +297,7 @@ def _fit(self, X: np.ndarray, y): "X.shape[2] to match seq_len." ) - x_t, mask = self._convert_X(X) + x_t, mask = self._preprocess_X(X) if self.standardise: self.scaler_ = _StandardisePerChannel().fit(x_t) @@ -432,7 +433,7 @@ def _predict_proba(self, X: np.ndarray) -> np.ndarray: f"seq_len={self.seq_len_}." ) - x_t, mask = self._convert_X(X) + x_t, mask = self._preprocess_X(X) if self.scaler_ is not None: x_t = self.scaler_.transform(x_t)