-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path run_visualization.py
More file actions
43 lines (36 loc) · 964 Bytes
/
run_visualization.py
File metadata and controls
43 lines (36 loc) · 964 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""
Run inference and visualization for an input song.

Uses ``MotiveAnalyser`` from analyse/analyser.py to load a model, encode the
input, cluster the encodings, and produce the analyser's plots.

Usage::

    python run_visualization.py --input_path <song>
        [--config bert.yaml] [--checkpoint_path <checkpoint>]
"""
import argparse
import os

import numpy as np
import torch
import yaml

from analyse.analyser import MotiveAnalyser


def _build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Run inference and visualization for an input song."
    )
    parser.add_argument(
        "--config",
        default="bert.yaml",
        type=str,
        help="Path to the YAML config file (default: %(default)s).",
    )
    parser.add_argument(
        "--input_path",
        required=True,
        type=str,
        help="Path to the input passed to MotiveAnalyser.",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        help="Checkpoint to use; overrides 'active_checkpoint' in the config.",
    )
    return parser


def _load_config(path):
    """Load and return the parsed YAML document stored at *path*.

    The file is closed deterministically, unlike a bare ``open()`` call.
    """
    with open(path, encoding="utf-8") as config_file:
        # NOTE(review): FullLoader can construct some Python-specific tags;
        # only load config files you trust. ``yaml.safe_load`` is sufficient
        # if the config is plain YAML.
        return yaml.load(config_file, Loader=yaml.FullLoader)


def main(argv=None):
    """Parse arguments, run the MotiveAnalyser pipeline, and plot results.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``; ``None``
            (the default) uses the real command line.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    config = _load_config(args.config)
    if not isinstance(config, dict):
        # An empty or non-mapping YAML file would otherwise fail later with
        # an opaque TypeError (on the assignment below or inside the analyser).
        parser.error(f"config file {args.config!r} must contain a YAML mapping")

    if args.checkpoint_path is not None:
        # Command-line checkpoint takes precedence over the config's value.
        config["active_checkpoint"] = args.checkpoint_path  # for inference

    analyser = MotiveAnalyser(config, args.input_path)
    analyser.load_model()
    analyser.encode()
    analyser.perform_clustering()
    analyser.plot_colored_pr()
    analyser.plot_hm()


if __name__ == '__main__':
    main()