forked from shakes76/PatternAnalysis-2023
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict
More file actions
122 lines (91 loc) · 4.05 KB
/
predict
File metadata and controls
122 lines (91 loc) · 4.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Example Predictions"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
from tqdm.auto import tqdm
from PIL import Image
import torch.utils.data
from torchvision import datasets, transforms, utils
import matplotlib.pyplot as plt
import modules
import dataset
# Script configuration: replace with the preferred device and local path(s).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Torch version ", torch.__version__)
path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/"
vqvae_save_path = "C:/Users/61423/COMP3710/COMP3710 Report/Models/"
pixelCNN_save_path = "C:/Users/61423/COMP3710/COMP3710 Report/Models/"

# Load the models.
# Both checkpoints are used below as live model objects (they are called and
# have methods invoked on them), i.e. they are whole pickled modules rather
# than state dicts.  PyTorch >= 2.6 defaults ``weights_only=True``, which
# refuses to unpickle such objects, so the flag is passed explicitly
# (the keyword exists from PyTorch 1.13 onwards).
# NOTE(security): weights_only=False runs the pickle machinery and can execute
# arbitrary code -- only load checkpoint files that you created or trust.
vqvae_model = torch.load(
    os.path.join(vqvae_save_path, "trained_vqvae.pth"),
    map_location=device,
    weights_only=False,
)
cnn_model = torch.load(
    os.path.join(pixelCNN_save_path, "PixelCNN model.pth"),
    map_location=device,
    weights_only=False,
)

# Sub-networks of the VQ-VAE, fetched with ordinary attribute access instead
# of calling the ``__getattr__`` dunder directly.
encoder = vqvae_model.encoder
quantiser = vqvae_model.quantiser
decoder = vqvae_model.decoder

# Update this parameter if using a newer model
embedding_dim = 32

# Load data
processed_data = dataset.DataPreparer(
    path,
    "keras_png_slices_train/",
    "keras_png_slices_validate/",
    "keras_png_slices_test/",
    batch_size=128,
)
# VQVAE outputs
def show_reconstructed_imgs(num_shown=2):
    """Plot test images next to their codebook indices and reconstructions.

    The first ``num_shown`` entries of ``processed_data.test_dataset`` are
    passed through ``vqvae_model``.  Each figure row then shows, left to
    right: the first channel of the input image, the encoding indices the
    model returned for it, and the first channel of the reconstruction.

    Args:
        num_shown: Number of test images (figure rows) to display.
    """
    input_imgs = processed_data.test_dataset[0:num_shown]
    input_imgs = input_imgs.to(device, dtype=torch.float32)
    with torch.no_grad():
        # The model output is unpacked as a 4-tuple; only the reconstruction
        # (first) and the encoding indices (last) are needed here.
        output_imgs, _, _, encoding_indices = vqvae_model(input_imgs)

    # squeeze=False keeps ``ax`` two-dimensional even for a single row, so the
    # ``ax[0, col]`` title lookups below also work when num_shown == 1.
    _fig, ax = plt.subplots(num_shown, 3, squeeze=False)
    plt.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0)

    for col, title in enumerate(("Input Image", "CodeBook Indices", "Reconstructed Image")):
        ax[0, col].set_title(title)
    for panel in ax.flat:
        panel.axis('off')

    for axes_row, original, codes, rebuilt in zip(ax, input_imgs, encoding_indices, output_imgs):
        axes_row[0].imshow(original[0].cpu().numpy(), cmap='gray')
        axes_row[1].imshow(codes.cpu().numpy(), cmap='gray')
        axes_row[2].imshow(rebuilt[0].cpu().numpy(), cmap='gray')
    plt.show()
# PixelCNN Indices Generation
def show_generated_indices(shown_imgs=2):
    """Plot real, masked and PixelCNN-completed codebook index grids.

    The first ``shown_imgs`` test images are encoded and quantised to obtain
    their index grids.  The bottom half of every grid is overwritten with -1
    and the result is handed to ``cnn_model.sample`` via ``ind``.  Each figure
    row shows the real grid, the masked grid and the sampled grid.

    Args:
        shown_imgs: Number of test images (figure rows) to display.
    """
    print(" > Showing Images")

    # Inputs.  Cast to float32 for consistency with show_reconstructed_imgs;
    # no gradients are needed for plotting.
    test_batch = processed_data.test_dataset[0:shown_imgs]
    with torch.no_grad():
        encoder_output = encoder(test_batch.to(device, dtype=torch.float32))
        _, _, indices = quantiser(encoder_output)

    # Insert a singleton channel axis: (batch, rows, cols) -> (batch, 1, rows, cols).
    batch, rows, cols = indices.shape[0], indices.shape[1], indices.shape[2]
    indices = indices.reshape((batch, 1, rows, cols))

    # Masked inputs: only the top half of each grid is shown to the model.
    # (rows // 2 is 16 for a 32x32 grid.)
    masked_indices = indices.clone()
    masked_indices[:, :, rows // 2:, :] = -1

    # A copy is passed so the masked grid plotted below cannot be modified by
    # the sampler.  NOTE(review): assumes the PixelCNN was trained on grids of
    # the size the quantiser produces (32x32 for the shipped model) -- confirm
    # when swapping in a different model.
    gen_indices = cnn_model.sample((batch, 1, rows, cols), ind=masked_indices.clone())

    # squeeze=False keeps ``ax`` two-dimensional even when shown_imgs == 1.
    _fig, ax = plt.subplots(shown_imgs, 3, squeeze=False)
    for panel in ax.flat:
        panel.axis('off')
    for col, title in enumerate(("Real", "Masked", "Generated")):
        ax[0, col].set_title(title)

    for axes_row, real, masked, generated in zip(ax, indices, masked_indices, gen_indices):
        axes_row[0].imshow(real[0].long().cpu().numpy(), cmap='gray')
        axes_row[1].imshow(masked[0].cpu().numpy(), cmap='gray')
        axes_row[2].imshow(generated[0].cpu().numpy(), cmap='gray')
    plt.show()
def show_generated_output():
    """Sample a full index grid with the PixelCNN and show the decoded image.

    One test image is encoded and quantised purely to obtain an index tensor
    of the right shape, dtype and device.  Every entry is then set to -1, so
    nothing of the real image is shown to ``cnn_model.sample``.  The sampled
    indices are turned into an image with ``vqvae_model.img_from_indices``
    and its first channel is displayed.
    """
    # Template indices from a single test image.
    test_batch = processed_data.test_dataset[0:1]
    with torch.no_grad():
        encoder_output = encoder(test_batch.to(device, dtype=torch.float32))
        _, _, indices = quantiser(encoder_output)

    # Insert a singleton channel axis: (batch, rows, cols) -> (batch, 1, rows, cols).
    batch, rows, cols = indices.shape[0], indices.shape[1], indices.shape[2]
    indices = indices.reshape((batch, 1, rows, cols))

    # Fully masked input: every position is -1.
    masked_indices = torch.full_like(indices, -1)

    # NOTE(review): rows/cols come from the quantiser output (32x32 for the
    # shipped model); assumes the PixelCNN was trained on that grid size.
    gen_indices = cnn_model.sample((1, 1, rows, cols), ind=masked_indices)

    # The last entry of the shape tuple is the module-level ``embedding_dim``;
    # update that constant when using a model with a different embedding size.
    with torch.no_grad():
        # no_grad guarantees the result can be converted with .numpy().
        generated = vqvae_model.img_from_indices(
            gen_indices[0], (1, rows, cols, embedding_dim)
        )
    plt.imshow(generated[0][0].cpu().numpy(), cmap='gray')
    plt.show()
# Run every demo in turn with its default arguments: VQ-VAE reconstructions,
# PixelCNN completion of half-masked index grids, then a fully generated image.
for _demo in (show_reconstructed_imgs, show_generated_indices, show_generated_output):
    _demo()