-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathPredictionUI_MNIST_FlatNet.py
More file actions
77 lines (60 loc) · 2.23 KB
/
PredictionUI_MNIST_FlatNet.py
File metadata and controls
77 lines (60 loc) · 2.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import gradio as gr
# DEFINE MODEL (required for loading weights)
class FlattenLayer(nn.Module):
    """Collapse every dimension except the leading (batch) one.

    An input of shape (N, d1, d2, ...) is returned as (N, d1 * d2 * ...).
    The layer defines no parameters or submodules, so it contributes no
    entries to the model's state dict.
    """

    def forward(self, x):
        n_samples = x.size(0)
        # -1 lets reshape infer the product of all trailing dimensions.
        return x.reshape(n_samples, -1)
class FlatNet(nn.Module):
    """Two-layer fully connected classifier.

    Pipeline: flatten -> Linear(1*28*28 -> 512) -> ReLU -> Linear(512 -> 10).
    Each input sample must therefore hold 1*28*28 = 784 values (e.g. a
    1x28x28 MNIST image). The output is raw, unnormalised logits of shape
    (N, 10).

    The submodule attribute names matter: ``linear1`` and ``linear2`` produce
    the state-dict keys (``linear1.weight``, ``linear2.bias``, ...) that a
    saved checkpoint must match.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept stable: it fixes parameter-init order
        # and the ordering of the state dict.
        self.flatten = FlattenLayer()
        self.linear1 = nn.Linear(1 * 28 * 28, 512)
        self.linear2 = nn.Linear(512, 10)
        self.activation_fn = nn.ReLU()

    def forward(self, x):
        hidden = self.activation_fn(self.linear1(self.flatten(x)))
        return self.linear2(hidden)
# LOAD PRE-TRAINED MODEL
model = FlatNet()
# map_location='cpu' lets a checkpoint whose tensors were saved on a GPU be
# loaded on a machine without CUDA. The weights end up on the CPU either way,
# because load_state_dict copies them into the CPU-resident model.
# NOTE(review): torch.load unpickles the file, which can run arbitrary code;
# only load checkpoints from a trusted source.
checkpoint = torch.load('flatnet_checkpoint.pt', map_location='cpu')
# strict=True: the checkpoint's keys must match FlatNet's parameters exactly,
# so an architecture mismatch fails loudly here instead of silently.
model.load_state_dict(checkpoint['model_state_dict'], strict=True)
# SWITCH MODEL TO PREDICTION ONLY MODE
# (affects layers such as dropout/batch-norm; FlatNet has none today, but
# this keeps inference correct if they are added later).
model.eval()
# Preprocessing applied to every input image; the same transforms that were
# used during training.
#   ToTensor: array/PIL image -> float tensor, channels first (for uint8
#             input, pixel values are scaled from [0, 255] to [0, 1]).
#   Normalize: (x - 0.5) / 0.5, mapping [0, 1] to [-1, 1].
_preprocess_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform = transforms.Compose(_preprocess_steps)
# Function for processing input image
# Since we're only interested in prediction, we disable the gradient computations
@torch.no_grad()
def recognize_digit(image):
#print(type(image))
#print(image.shape)
image_tensor = transform(image) # 1, 28, 28
image_tensor = image_tensor.unsqueeze(0) # add dummy batch dimension 1, 1, 28, 28
#print(image_tensor.shape)
logits = model(image_tensor)
preds = F.softmax(logits, dim=1) # convert to probabilities
preds_list = preds.tolist()[0] # take the first batch (there is only one)
#print(preds_list)
return {str(i): preds_list[i] for i in range(10)}
# Output widget: renders the {label: probability} dict returned by
# recognize_digit, showing only the three highest-scoring digits.
output_labels = gr.outputs.Label(num_top_classes=3)

# Text shown at the top of the app page.
_app_text = {
    'title': 'MNIST Drawing Application',
    'description': (
        'Draw a number 0 through 9 on the sketchpad, '
        'and click submit to see the model predictions.'
    ),
}

# Main UI: each submitted sketchpad drawing is passed to recognize_digit and
# the result is displayed through output_labels.
interface = gr.Interface(
    fn=recognize_digit,
    inputs='sketchpad',
    outputs=output_labels,
    **_app_text,
)
interface.launch()