-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path rocket.py
More file actions
18 lines (14 loc) · 747 Bytes
/
rocket.py
File metadata and controls
18 lines (14 loc) · 747 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from tsai.models.MINIROCKET_Pytorch import MiniRocketFeatures, get_minirocket_features
class MiniRocketTransformer:
    """Scikit-learn-style wrapper around tsai's ``MiniRocketFeatures``.

    Exposes ``fit`` / ``transform`` / ``fit_transform`` so MiniRocket feature
    extraction can be dropped into a transformer-style pipeline, plus ``to``
    to forward a device move to the wrapped feature module.

    Args:
        c_in: Forwarded unchanged to ``MiniRocketFeatures``.
        seq_len: Forwarded unchanged to ``MiniRocketFeatures``.
        num_features: Forwarded to ``MiniRocketFeatures(num_features=...)``.
        max_dilations_per_kernel: Forwarded to
            ``MiniRocketFeatures(max_dilations_per_kernel=...)``.
        random_state: Forwarded to ``MiniRocketFeatures(random_state=...)``.
        chunksize: Stored and passed both to ``MiniRocketFeatures.fit`` and
            to ``get_minirocket_features``.
        **kwargs: Accepted but ignored. NOTE(review): presumably kept so a
            shared config dict can be splatted in; a misspelled keyword is
            silently dropped rather than rejected.
    """

    def __init__(self, c_in, seq_len, num_features=10000, max_dilations_per_kernel=32,
                 random_state=None, chunksize=1024, **kwargs):
        self.chunksize = chunksize
        self.mrf = MiniRocketFeatures(
            c_in,
            seq_len,
            num_features=num_features,
            max_dilations_per_kernel=max_dilations_per_kernel,
            random_state=random_state,
        )

    def fit(self, X, y=None):
        """Fit the wrapped feature module on ``X`` and return ``self``.

        ``y`` is ignored (presumably present for scikit-learn API
        compatibility). Returning ``self`` allows chaining, e.g.
        ``t.fit(X).transform(X)``.
        """
        self.mrf.fit(X, self.chunksize)
        return self

    def transform(self, X):
        """Return MiniRocket features for ``X`` with the last axis squeezed.

        NOTE(review): assumes ``get_minirocket_features`` returns an array
        with a trailing singleton axis, e.g. (n_samples, n_features, 1), so
        the result is 2-D — confirm against the tsai version in use.
        """
        features = get_minirocket_features(X, self.mrf, self.chunksize)
        return features.squeeze(-1)

    def fit_transform(self, X, y=None):
        """Fit on ``X`` and return its features.

        Equivalent to ``self.fit(X, y).transform(X)``.
        """
        return self.fit(X, y).transform(X)

    def to(self, device):
        """Call ``to(device)`` on the wrapped feature module and return ``self``."""
        self.mrf.to(device)
        return self