-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprint_total_cf.py
More file actions
60 lines (44 loc) · 1.9 KB
/
print_total_cf.py
File metadata and controls
60 lines (44 loc) · 1.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from sklearn.metrics import confusion_matrix
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
# import csv
# Recordings to aggregate, given as parallel lists: (subject[i], sequences[i])
# identifies one prediction CSV.
subject = [1, 1, 1, 2, 2, 2]
sequences = [2, 3, 4, 2, 3, 4]

# Per-file chunks are collected here and merged once after the loop. Each list
# is seeded with an empty int array so the merged result goes through the same
# dtype promotion as repeatedly appending to an empty int array would.
mv_true_parts = [np.array([], dtype=int)]
mv_pred_parts = [np.array([], dtype=int)]
st_true_parts = [np.array([], dtype=int)]
st_pred_parts = [np.array([], dtype=int)]

for sub, seq in zip(subject, sequences):
    pred_file = f'DATA/EventFrames/TEST_FULL_SEQUENCE_V2/predictionsv2_subject{sub:02}_seq{seq:02}.csv'
    frame = pd.read_csv(pred_file)

    # Movement ground truth: every row is used.
    mv_true_parts.append(frame['y_gt_mv'].to_numpy())

    # Static ground truth: rows labelled -1 are dropped and the remaining
    # ids are shifted down by 6.
    # NOTE(review): presumably -1 marks rows without a static pose and static
    # class ids start at 6 -- confirm against the code that writes the CSVs.
    st_true_all = frame['y_gt_st'].to_numpy()
    keep_rows = st_true_all != -1
    st_true_parts.append(st_true_all[keep_rows] - 6)

    # Movement predictions: every row is used.
    mv_pred_parts.append(frame['y_pred_mv'].to_numpy())

    # Static predictions: same row selection and shift as the ground truth,
    # so the two static arrays stay aligned.
    st_pred_parts.append(frame['y_pred_st'].to_numpy()[keep_rows] - 6)

# Flat arrays over all recordings, in the order the files were read.
gt_mv = np.concatenate(mv_true_parts)
pred_mv = np.concatenate(mv_pred_parts)
gt_st = np.concatenate(st_true_parts)
pred_st = np.concatenate(st_pred_parts)
# Tick labels for the two confusion-matrix figures.
# NOTE(review): the order is assumed to match the sorted class ids found in
# the data -- confirm against the label encoding used at training time.
labels_mv = ['Head', 'Hands', 'RollL', 'RollR', 'Legs', 'Arms']
labels_st = ['LieL', 'LieR', 'LieU', 'LieD']

# Raw-count confusion matrices (not used by the plots below, which are
# row-normalised and built directly from the label arrays).
cm_mv = confusion_matrix(gt_mv, pred_mv)
cm_st = confusion_matrix(gt_st, pred_st)

# Movement labels confusion matrix.
# from_predictions() already draws the matrix on a new figure, so no extra
# .plot() call is made: calling .plot() without an axis would open a second,
# untitled figure, and that one would be what gets saved.
disp_mv = ConfusionMatrixDisplay.from_predictions(
    y_true=gt_mv, y_pred=pred_mv,
    display_labels=labels_mv, normalize='true')
disp_mv.ax_.set_title('Movement labels')
# Save and close the figure that carries the title, rather than relying on
# whichever figure pyplot currently considers "current".
disp_mv.figure_.savefig('cm_mv_full_sequence_base.png', dpi=300, bbox_inches='tight')
plt.close(disp_mv.figure_)

# Static labels confusion matrix (same procedure as above).
disp_st = ConfusionMatrixDisplay.from_predictions(
    y_true=gt_st, y_pred=pred_st,
    display_labels=labels_st, normalize='true')
disp_st.ax_.set_title('Static labels')
disp_st.figure_.savefig('cm_st_full_sequence_base.png', dpi=300, bbox_inches='tight')
plt.close(disp_st.figure_)

print('Done')