-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels.py
More file actions
31 lines (26 loc) · 863 Bytes
/
models.py
File metadata and controls
31 lines (26 loc) · 863 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import numpy as np
import torch
from torch import nn
from definitions import device
def load_model(model: nn.Module, file_name: str, verbose: bool = True) -> nn.Module:
    """Load saved weights into ``model`` and prepare it for inference.

    Args:
        model: Module whose architecture matches the saved weights; its
            parameters are overwritten in place via ``load_state_dict``.
        file_name: Path to a file whose deserialized content is passed to
            ``model.load_state_dict`` (i.e. it must hold a state dict).
        verbose: If True, print a confirmation message after loading.

    Returns:
        The same ``model`` instance, switched to eval mode and moved to
        ``device``.
    """
    # map_location remaps stored tensors onto the target device while
    # deserializing. Without it, a checkpoint saved from a CUDA model cannot
    # be loaded on a CPU-only machine, and GPU tensors are restored onto the
    # device they were saved from rather than onto ``device``.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (weights_only=True is an option on torch >= 1.13).
    state_dict = torch.load(file_name, map_location=device)
    model.load_state_dict(state_dict)
    # Disable training-only behaviour (dropout, batch-norm updates, ...).
    model.eval()
    model.to(device)
    if verbose:
        print(f"Model {file_name} loaded.")
    return model
def create_model(
    x_train: np.ndarray,
    layer_dims: list = None,
    verbose: bool = True,
) -> nn.Sequential:
    """Build a fully connected network with a single output unit.

    The network is a stack of ``Linear -> LeakyReLU`` pairs, one per entry in
    ``layer_dims``, followed by a final ``Linear`` layer mapping to one output.

    Args:
        x_train: Training inputs. Only ``x_train.shape[1]`` is read; it sets
            the width of the first layer. Presumably shaped
            (n_samples, n_features) — confirm against callers.
        layer_dims: Hidden-layer widths, in order. Any sequence of ints is
            accepted; defaults to ``[100, 50]``. The caller's object is never
            mutated.
        verbose: If True, print the constructed model.

    Returns:
        The ``nn.Sequential`` model, moved to ``device``.
    """
    if layer_dims is None:
        layer_dims = [100, 50]
    # Prepend the input width. Building a new list keeps the caller's
    # sequence untouched and also accepts tuples.
    dims = [x_train.shape[1]] + list(layer_dims)

    layers = []
    # One hidden Linear + activation for each consecutive (in, out) width pair.
    for in_dim, out_dim in zip(dims, dims[1:]):
        layers.append(nn.Linear(in_dim, out_dim))
        layers.append(nn.LeakyReLU())
    # Output head: the last width maps to a single unit, with no activation.
    layers.append(nn.Linear(dims[-1], 1))

    model = nn.Sequential(*layers)
    if verbose:
        print(model)
    model.to(device)
    return model