Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions blur.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import math
import numbers
import torch
from torch import nn
from torch.nn import functional as F


class GaussianSmoothing(nn.Module):
    """
    Apply gaussian smoothing on a
    1d, 2d or 3d tensor. Filtering is performed separately for each channel
    in the input using a depthwise convolution.

    Note: ``forward`` applies a *valid* (unpadded) convolution, so each
    spatial dimension of the output shrinks by ``kernel_size - 1``; pad the
    input beforehand if the original size must be preserved.

    Arguments:
        channels (int, sequence): Number of channels of the input tensors. Output will
            have this number of channels as well.
        kernel_size (int, sequence): Size of the gaussian kernel.
        sigma (float, sequence): Standard deviation of the gaussian kernel.
        dim (int, optional): The number of dimensions of the data.
            Default value is 2 (spatial).
    """

    def __init__(self, channels, kernel_size, sigma, dim=2):
        super(GaussianSmoothing, self).__init__()
        # Broadcast scalar arguments to one value per spatial dimension.
        if isinstance(kernel_size, numbers.Number):
            kernel_size = [kernel_size] * dim
        if isinstance(sigma, numbers.Number):
            sigma = [sigma] * dim

        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            # Gaussian pdf: exp(-((x - mean) / std)^2 / 2).
            # BUGFIX: the original wrote (mgrid - mean) / (2 * std) inside
            # the square, which silently widened the effective sigma by a
            # factor of sqrt(2). The constant prefactor cancels in the
            # normalization below but is kept for clarity.
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
                torch.exp(-((mgrid - mean) / std) ** 2 / 2)

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)

        # Reshape to depthwise convolutional weight:
        # (channels, 1, *kernel_size), used with groups == channels so each
        # input channel is filtered independently.
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        # A buffer (not a Parameter): moves with .to()/.cuda() but is
        # excluded from gradient updates.
        self.register_buffer('weight', kernel)
        self.groups = channels

        if dim == 1:
            self.conv = F.conv1d
        elif dim == 2:
            self.conv = F.conv2d
        elif dim == 3:
            self.conv = F.conv3d
        else:
            raise RuntimeError(
                'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
            )

    def forward(self, input):
        """
        Apply gaussian filter to input.
        Arguments:
            input (torch.Tensor): Input to apply gaussian filter on.
        Returns:
            filtered (torch.Tensor): Filtered output.
        """
        return self.conv(input, weight=self.weight, groups=self.groups)
110 changes: 110 additions & 0 deletions datagen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from tqdm import tqdm
import time
import pickle
import random
from blur import GaussianSmoothing


def mnist_loader(train=False):
    """Build a DataLoader over MNIST from '../data'.

    Uses 64-sample shuffled batches for the training split and
    1000-sample batches for the test split; pixels are normalized with
    the standard MNIST mean/std.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset = datasets.MNIST('../data', train=train, transform=transform)
    batch_size = 64 if train else 1000
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


def save(filename, data):
    """Pickle ``data`` to ``filename`` using the highest available protocol."""
    with open(filename, 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)


def load(filename):
    """Unpickle and return the object stored in ``filename``."""
    with open(filename, 'rb') as handle:
        return pickle.load(handle)


def generate():
    """Build occluded/perturbed MNIST test-set variants and pickle them under data/.

    Variants: half cuts (vertical/horizontal), diagonal cut, quarter cut, and
    triple-square variants (zeroed, random noise, patches copied from one donor
    image, patches copied from three donor images, gaussian-blurred patches) at
    three square sizes (5x5, 7x7, 9x9).

    Mutates the underlying ``dataset.test_data`` tensors of freshly created
    loaders in place, then pickles the whole DataLoader objects.
    NOTE(review): relies on the deprecated ``test_data``/``test_labels``
    attributes of torchvision's MNIST dataset — newer torchvision renames
    these to ``data``/``targets``; confirm against the installed version.
    """
    train_loader = mnist_loader(train=True)
    test_loader = mnist_loader()  # kept unmodified; also the donor for 'replaced' patches
    test_loader_vertical_cut = mnist_loader()
    test_loader_horizontal_cut = mnist_loader()
    test_loader_diagonal_cut = mnist_loader()
    test_loader_quarter_cut = mnist_loader()
    test_loader_triple_cut = [mnist_loader(), mnist_loader(), mnist_loader()]  # 5x5, 7x7 and 9x9
    test_loader_triple_cut_noise = [mnist_loader(), mnist_loader(), mnist_loader()]
    test_loader_triple_cut_replaced1 = [mnist_loader(), mnist_loader(), mnist_loader()]
    test_loader_triple_cut_replaced3 = [mnist_loader(), mnist_loader(), mnist_loader()]
    test_loader_triple_cut_blur = [mnist_loader(), mnist_loader(), mnist_loader()]

    print('Generating new test sets...')

    def get_random_pairs(test_loader, num):
        """Return three distinct test-set indices whose labels all differ from
        the label of sample ``num`` (donor images for the 'replaced' variants)."""
        label = test_loader.dataset.test_labels[num]
        while True:
            pairs = []
            # Redraw until the three indices are pairwise distinct.
            while len(set(pairs)) != 3:
                pairs = [random.randint(0, 10000 - 1) for i in range(3)]

            got_duplicate_label = False
            for pair in pairs:
                if test_loader.dataset.test_labels[pair] == label:
                    got_duplicate_label = True

            # Retry the whole draw if any candidate shares the target's label.
            if got_duplicate_label:
                continue
            else:
                return pairs

    # Single-channel 5x5 gaussian kernel, sigma 1 — used for the blurred-patch variant.
    smoothing = GaussianSmoothing(1, 5, 1)

    for num in tqdm(range(0, 10000)):
        random_pairs = get_random_pairs(test_loader, num)

        # Reflect-pad 28x28 -> 32x32 so the valid 5x5 convolution yields 28x28 again.
        sample = test_loader.dataset.test_data[num].type('torch.FloatTensor')
        sample = F.pad(sample.reshape(1, 1, 28, 28), (2, 2, 2, 2), mode='reflect')
        blur = smoothing(sample).reshape(28, 28)

        for x in range(28):
            for y in range(28):
                # Zero out halves / quarters of the image in place.
                if y < 14:
                    test_loader_vertical_cut.dataset.test_data[num, x, y] = 0
                if x < 14:
                    test_loader_horizontal_cut.dataset.test_data[num, x, y] = 0
                # NOTE(review): row 14 and column 14 are left untouched by the
                # diagonal cut (strict inequalities) — presumably intentional; confirm.
                if (x < 14 and y > 14) or (x > 14 and y < 14):
                    test_loader_diagonal_cut.dataset.test_data[num, x, y] = 0
                if x < 14 and y < 14:
                    test_loader_quarter_cut.dataset.test_data[num, x, y] = 0
                for i in range(3):
                    half = i + 2  # squares will have side 2*half + 1
                    # Three fixed square patches centred at (10,10), (22,22) and (12,21).
                    if (10 - half <= x <= 10 + half and 10 - half <= y <= 10 + half) or (22 - half <= x <= 22 + half and 22 - half <= y <= 22 + half) or (12 - half <= x <= 12 + half and 21 - half <= y <= 21 + half):
                        test_loader_triple_cut[i].dataset.test_data[num, x, y] = 0
                        test_loader_triple_cut_noise[i].dataset.test_data[num, x, y] = random.randint(0, 255)
                        # 'replaced1': all three patches copied from the same donor image...
                        test_loader_triple_cut_replaced1[i].dataset.test_data[num, x, y] = test_loader.dataset.test_data[random_pairs[0], x, y]
                        test_loader_triple_cut_blur[i].dataset.test_data[num, x, y] = blur[x, y]
                    # ...'replaced3': each patch copied from its own donor image.
                    if 10 - half <= x <= 10 + half and 10 - half <= y <= 10 + half:
                        test_loader_triple_cut_replaced3[i].dataset.test_data[num, x, y] = test_loader.dataset.test_data[random_pairs[0], x, y]
                    elif 22 - half <= x <= 22 + half and 22 - half <= y <= 22 + half:
                        test_loader_triple_cut_replaced3[i].dataset.test_data[num, x, y] = test_loader.dataset.test_data[random_pairs[1], x, y]
                    elif 12 - half <= x <= 12 + half and 21 - half <= y <= 21 + half:
                        test_loader_triple_cut_replaced3[i].dataset.test_data[num, x, y] = test_loader.dataset.test_data[random_pairs[2], x, y]

    save('data/train_loader.pickle', train_loader)
    save('data/test_loader.pickle', test_loader)
    save('data/test_loader_vcut.pickle', test_loader_vertical_cut)
    save('data/test_loader_hcut.pickle', test_loader_horizontal_cut)
    save('data/test_loader_dcut.pickle', test_loader_diagonal_cut)
    save('data/test_loader_qcut.pickle', test_loader_quarter_cut)
    save('data/test_loader_tcut.pickle', test_loader_triple_cut)
    save('data/test_loader_noise.pickle', test_loader_triple_cut_noise)
    save('data/test_loader_replaced1.pickle', test_loader_triple_cut_replaced1)
    save('data/test_loader_replaced3.pickle', test_loader_triple_cut_replaced3)
    save('data/test_loader_blur.pickle', test_loader_triple_cut_blur)

    print('Datasets saved')
68 changes: 34 additions & 34 deletions mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm
from blur import GaussianSmoothing
import random
import time
from datagen import generate, load
import sys

torch.manual_seed(1)
random.seed(0)

LR = 0.1
MOM = 0.5
Expand Down Expand Up @@ -74,40 +79,22 @@ def test(model, device, test_loader):
device = torch.device('cuda' if use_cuda else 'cpu')
kwargs = {'num_workers': 12, 'pin_memory': True} if use_cuda else {}

if '--generate' in sys.argv:
generate()

def mnist_loader(train=False):
return torch.utils.data.DataLoader(
datasets.MNIST('../data', train=train, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=64 if train else 1000, shuffle=True, **kwargs)
train_loader = load('data/train_loader.pickle')
test_loader = load('data/test_loader.pickle')
test_loader_vertical_cut = load('data/test_loader_vcut.pickle')
test_loader_horizontal_cut = load('data/test_loader_hcut.pickle')
test_loader_diagonal_cut = load('data/test_loader_dcut.pickle')
test_loader_quarter_cut = load('data/test_loader_qcut.pickle')
test_loader_triple_cut = load('data/test_loader_tcut.pickle')
test_loader_triple_cut_noise = load('data/test_loader_noise.pickle')
test_loader_triple_cut_replaced1 = load('data/test_loader_replaced1.pickle')
test_loader_triple_cut_replaced3 = load('data/test_loader_replaced3.pickle')
test_loader_triple_cut_blur = load('data/test_loader_blur.pickle')


train_loader = mnist_loader(train=True)
test_loader = mnist_loader()
test_loader_vertical_cut = mnist_loader()
test_loader_horizontal_cut = mnist_loader()
test_loader_diagonal_cut = mnist_loader()
test_loader_triple_cut = mnist_loader()




print('Generating new test sets...')

for num in tqdm(range(0, 10000)):
for x in range(28):
for y in range(28):
if y < 14:
test_loader_vertical_cut.dataset.test_data[num, x, y] = 0
if x < 14:
test_loader_horizontal_cut.dataset.test_data[num, x, y] = 0
if (x < 14 and y > 14) or (x > 14 and y < 14):
test_loader_diagonal_cut.dataset.test_data[num, x, y] = 0
if (5 < x < 15 and 5 < y < 15) or (17 < x < 27 and 10 < y < 20) or (7 < x < 17 and 16 < y < 26):
test_loader_triple_cut.dataset.test_data[num, x, y] = 0

# import matplotlib.pyplot as plt

# plt.imshow(test_loader.dataset.test_data[343], cmap='gray')
Expand Down Expand Up @@ -212,12 +199,25 @@ def mnist_loader(train=False):
# Testing:

models = [model_normal, model_negative_relu, model_hybrid, model_hybrid_nr, model_hybrid_alt]
model_names = ['Normal:', 'HCUT:', 'VCUT:', 'DCUT:', 'TCUT:']

datasets = [test_loader, test_loader_horizontal_cut, test_loader_vertical_cut, test_loader_diagonal_cut, test_loader_triple_cut]
datasets = [test_loader, test_loader_horizontal_cut, test_loader_vertical_cut, test_loader_diagonal_cut, test_loader_quarter_cut]
dataset_names = ['Normal:', 'HCUT:', 'VCUT:', 'DCUT:', 'QCUT:']

for i in range(3):
size = "{0}x{0}".format(5 + 2 * i)
datasets.append(test_loader_triple_cut[i])
dataset_names.append("TCUT {}:".format(size))
datasets.append(test_loader_triple_cut_blur[i])
dataset_names.append("Blur {}:".format(size))
datasets.append(test_loader_triple_cut_noise[i])
dataset_names.append("Noise {}:".format(size))
datasets.append(test_loader_triple_cut_replaced1[i])
dataset_names.append("Replaced1 {}:".format(size))
datasets.append(test_loader_triple_cut_replaced3[i])
dataset_names.append("Replaced3 {}:".format(size))

for i, dataset in enumerate(datasets):
print('Testing -- ' + model_names[i])
print('Testing -- ' + dataset_names[i])
for model in models:
test(model, device, dataset)

Expand Down