-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot.py
More file actions
75 lines (56 loc) · 1.97 KB
/
plot.py
File metadata and controls
75 lines (56 loc) · 1.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import matplotlib.pyplot as plt
import json, pandas as pd, glob, tqdm
# Emit SVG text as text elements rather than glyph outlines.
plt.rcParams['svg.fonttype'] = 'none'

# Collect every logged row, grouped by the 'table' field of each JSON line.
# In each log.jsons file, the 'args' and 'hyperparams' lines hold run-level
# settings that are merged into every later row of that same file.
tables = {}
# Sorted so that the row order (and anything derived from it) is reproducible.
for path in tqdm.tqdm(sorted(glob.glob('saved/*/log.jsons'))):
    args, hparams = {}, {}
    with open(path, encoding='utf-8') as fh:
        for line in fh:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline) instead of
                # letting json.loads raise on them.
                continue
            row = json.loads(line)
            table = row.pop('table')
            if table == 'args':
                args = row
            elif table == 'hyperparams':
                hparams = row
            else:
                # Run-level settings are applied last, so they win over
                # same-named keys already present in the row.
                row.update(args)
                row.update(hparams)
                tables.setdefault(table, []).append(row)

# One DataFrame per table name (e.g. dfs['epoch']).
dfs = {t: pd.DataFrame(v) for t, v in tables.items()}
# Rows logged under table 'epoch' (presumably one per training epoch per run).
epoch = dfs['epoch']
# x position used for the 'inf' network; figure 3 labels tick 6 as 'inf'.
INF_PLOTDIM = 6
# Vectorised replacement for the row-wise apply: keep `ndim`, except that
# rows whose netspec is 'inf' are placed at INF_PLOTDIM.
epoch['plotdim'] = epoch['ndim'].where(epoch['netspec'] != 'inf', INF_PLOTDIM)
# Figure 1: distribution of final t1p (top-1 %) per target frequency.
# NOTE(review): i == 29 is assumed to be the last logged epoch and dt == 0.5
# the reference time step -- confirm against the training configuration.
final = epoch[(epoch.i == 29) & (epoch.dt == 0.5)]
# groupby sorts keys, so boxes are ordered by increasing frequency.
tgtfreq = {f'{freq}': grp.t1p.values for freq, grp in final.groupby('tgtfreq')}
# Start a fresh figure: with a non-interactive backend plt.show() does not
# close the current figure, so later plots would be drawn onto this one.
plt.figure()
plt.boxplot(list(tgtfreq.values()))
# Boxes sit at x = 1..n. Labelling them via xticks avoids boxplot's `labels=`
# kwarg, which Matplotlib 3.9 renamed to `tick_labels`.
plt.xticks(range(1, len(tgtfreq) + 1), list(tgtfreq.keys()))
plt.xlabel('Ftarget (Hz)')
plt.ylabel('TOP1 (%)')
plt.savefig('1.svg')
plt.show()
# Restrict to the 10 Hz target frequency. `final` stays filtered for the
# following figure as well.
final = final[final.tgtfreq == 10]

# Figure 2: median final accuracy vs. number of RSNN neurons, one line per netspec.
plt.figure()  # fresh figure so earlier plots are not drawn underneath
for net, grp in final.groupby('netspec'):
    # groupby sorts its keys, so the medians are already ordered by nhidden.
    med = grp.groupby('nhidden')['t1p'].median()
    plt.plot(med.index, med.values, label=net)
plt.legend(title='Network Dimensions', ncol=2)
plt.xlabel('RSNN Neurons (#)')
plt.ylabel('Accuracy')
plt.savefig('2.svg')
plt.show()
# Figure 3: median final accuracy vs. network dimension, one line per nhidden.
plt.figure()  # fresh figure so earlier plots are not drawn underneath
for nh, grp in final.groupby('nhidden'):
    # groupby sorts its keys, so the medians are already ordered by plotdim.
    med = grp.groupby('plotdim')['t1p'].median()
    plt.plot(med.index, med.values, label=nh)
# Positions 0-5 are literal dimensions; position 6 is where 'inf' networks
# are placed in the plotdim column.
plt.xticks(range(7), '0 1 2 3 4 5 inf'.split())
plt.legend(title='Nhidden', ncol=2)
plt.xlabel('Network Dimension')
plt.ylabel('Accuracy')
plt.savefig('3.svg')
plt.show()
# Figure 4: final accuracy vs. epsilon for the 10- and 50-neuron networks.
# NOTE(review): the original referenced the undefined names `df` and
# `finaleps` (NameError when run as a script). They are reconstructed here
# from `epoch`; confirm the intended row filter (dt / tgtfreq are not
# restricted here).
# Epsilon is the second 'e'-separated field of netspec. Netspecs without one
# (e.g. 'inf') or with a non-numeric fragment become NaN instead of raising.
epoch['eps'] = pd.to_numeric(
    epoch['netspec'].str.split('e').str[1], errors='coerce')
finaleps = epoch[(epoch.i == 29) & epoch.eps.notna()]

plt.figure()  # fresh figure so earlier plots are not drawn underneath
selected = finaleps[finaleps.nhidden.isin([10, 50])]
for n, grp in selected.groupby('nhidden'):
    # Median over repeated rows, ordered by eps, so each curve is a single
    # monotone-in-x line (as in figures 2 and 3).
    med = grp.groupby('eps')['t1p'].median()
    plt.plot(med.index * 100, med.values, label=n)
plt.xlabel('Epsilon (%)')
plt.ylabel('Test accuracy (%)')
plt.ylim(50, 100)
plt.legend()
plt.savefig('4.svg')
plt.show()