-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot.py
More file actions
33 lines (27 loc) · 847 Bytes
/
Copy pathplot.py
File metadata and controls
33 lines (27 loc) · 847 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from matplotlib import pyplot as plt
def plot_learning_curve(training_res: list, validation_res: list, metric: str, title: str, filename: str,
                        *, show: bool = True):
    '''
    Plot training and validation values of a metric against the epoch index
    and save the figure to a file.

    Parameters
    ----------
    training_res : array_like (1d)
        The training points to plot, one per epoch.
    validation_res : array_like (1d)
        The validation points to plot, one per epoch. May have a different
        length than ``training_res``; each curve is drawn against its own
        0-based epoch indices.
    metric : str
        The metric being plotted; used in the legend labels and as the
        y-axis label.
    title : str
        The title of the plot.
    filename : str
        The file to save the plot to (passed to ``Figure.savefig``).
    show : bool, optional (keyword-only)
        If True (the default, matching previous behavior), display the
        figure with ``plt.show()`` after saving. Pass False for headless or
        batch runs where blocking on a window is undesirable.
    '''
    fig, ax = plt.subplots()
    try:
        # Each series gets its own x range so curves of different lengths
        # (e.g. validation evaluated for fewer epochs) don't raise a
        # shape-mismatch error in ax.plot.
        ax.plot(range(len(training_res)), training_res, label="Training " + metric)
        ax.plot(range(len(validation_res)), validation_res, label="Validation " + metric)
        ax.legend()
        ax.set(xlabel="Epoch", ylabel=metric)
        ax.set_title(title)
        fig.savefig(filename)
        if show:
            plt.show()
    finally:
        # Close this specific figure (not merely the "current" one), even if
        # plotting or saving failed, so repeated calls don't leak figures.
        plt.close(fig)