66import time
77
88import pandas as pd
9- from aeon .classification .interval_based import TimeSeriesForestClassifier
9+ from aeon .classification .convolution_based import RocketClassifier
1010
11- from multiverse .datasets import list_datasets , load_dataset
12-
13-
14- def main ():
15- ap = argparse .ArgumentParser ()
16- ap .add_argument ("--datasets" , default = "all" , help = "Comma-separated list, or 'all'" )
17- ap .add_argument ("--out" , default = "benchmark_results.csv" )
18- args = ap .parse_args ()
19-
20- if args .datasets .strip ().lower () == "all" :
21- datasets = list_datasets ()
22- else :
23- datasets = [d .strip () for d in args .datasets .split ("," ) if d .strip ()]
11+ from aeon .datasets import load_classification
12+ from aeon .datasets .tsc_datasets import multivariate_equal_length
2413
14+ def experiment_example ():
15+ datasets = ["BasicMotions" ]
2516 rows = []
2617 for name in datasets :
27- X_train , y_train = load_dataset (name , "train" )
28- X_test , y_test = load_dataset (name , "test" )
18+ X_train , y_train = load_classification (name , "train" )
19+ X_test , y_test = load_classification (name , "test" )
20+
21+ clf = RocketClassifier (n_kernels = 500 , random_state = 0 )
2922
30- clf = TimeSeriesForestClassifier (n_estimators = 200 , random_state = 0 )
3123
3224 t0 = time .time ()
3325 clf .fit (X_train , y_train )
@@ -52,4 +44,4 @@ def main():
5244
5345
5446if __name__ == "__main__" :
55- main ()
47+ experiment_example ()
0 commit comments