forked from saeedsoori/k-fac
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot.py
More file actions
28 lines (24 loc) · 1006 Bytes
/
plot.py
File metadata and controls
28 lines (24 loc) · 1006 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
# Command-line options. They are parsed at import time and exposed through the
# module-level ``args``, which plot() reads (``args.module_idx``, ``args.scale``).
parser = argparse.ArgumentParser(
    description='Plot saved inverse matrices (<method>/<epoch>_m_<idx>_inv.npy) as heatmaps.')
parser.add_argument('--module_idx', default='1', type=str,
                    help='index of the module whose matrices are plotted '
                         '(default: %(default)s)')
# NOTE: string-typed on purpose; only the exact value 'true' enables log scaling.
parser.add_argument('--scale', default='true', type=str,
                    help="'true' plots log(|x| + 1e-3); any other value plots the raw "
                         "entries (default: %(default)s)")
args = parser.parse_args()
def plot(module_idx=None, scale=None,
         epochs=('0', '9', '29', '59'),
         methods=('ngd', 'exact', 'kfac')):
    """Render saved inverse-matrix dumps as heatmaps and save them as PNGs.

    For every (epoch, method) pair this loads
    ``<method>/<epoch>_m_<module_idx>_inv.npy``, optionally log-scales it,
    draws it with a horizontal colorbar, and writes the picture to
    ``<method>/<epoch>_m_<module_idx>_inv.png`` before showing it.
    Missing input files are reported on stdout and skipped.

    Args:
        module_idx: Module index used in the file names. ``None`` (default)
            uses the ``--module_idx`` command-line value.
        scale: When exactly ``'true'``, plot ``log(|inv| + 1e-3)`` instead of
            the raw values. ``None`` (default) uses the ``--scale``
            command-line value.
        epochs: Epoch labels to plot; each becomes part of the file name.
        methods: Method sub-directories to read from and write into.
    """
    if module_idx is None:
        module_idx = args.module_idx
    if scale is None:
        scale = args.scale

    for epoch in epochs:
        for method in methods:
            stem = os.path.join(method, f'{epoch}_m_{module_idx}_inv')
            # Only the load can legitimately hit a missing file; keep the
            # try narrow so plotting errors are not misreported as missing data.
            try:
                inv = np.load(stem + '.npy')
            except FileNotFoundError:
                print(f'[Error] missing: epoch = {epoch} method = {method} '
                      f'for module {module_idx}')
                continue

            if scale == 'true':
                # Log-compress magnitudes so small entries stay visible; the
                # 1e-3 offset keeps log() finite where an entry is exactly 0.
                inv = np.log(np.abs(inv) + 1e-3)

            fig, ax = plt.subplots(figsize=(18, 18))
            im = ax.imshow(inv, cmap='coolwarm', vmin=np.min(inv), vmax=np.max(inv))
            fig.colorbar(im, orientation='horizontal')
            # Save before show(): with interactive backends the figure may be
            # torn down once its window is closed, yielding a blank PNG.
            fig.savefig(stem + '.png')
            plt.show()
            # Release the figure; otherwise each large figure stays alive,
            # notably under non-interactive backends where show() is a no-op.
            plt.close(fig)
if __name__ == "__main__":
    # Script entry point: plot every configured epoch/method heatmap.
    plot()