'''Evaluate a CIFAR10 predictive-coding network whose recurrent step uses an explicit solver.'''
import os
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import Subset
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import dual_annealing
import numpy as np


class PcConvBp_DS(nn.Module):
    """Predictive-coding conv layer whose refinement step is an explicit solver.

    The layer keeps the same parameters as the convolutional layer it replaces
    (``FFconv``, ``FBconv``, ``b0``, ``bypass``).  The feed-forward response is
    refined by minimising ``||x - y @ W.T||`` where ``W`` is a weight matrix
    loaded from disk for the given layer index.

    Args:
        inchan, outchan: channel counts of the input and of the representation.
        kernel_size, stride, padding, bias: forwarded to the two convolutions.
        cls: stored on the instance; not used by this layer's forward pass.
        solver: ``'SGD'``, ``'SA'`` (dual annealing) or ``'LD'`` (plain
            gradient descent on the squared error).
        num_iterations: number of solver iterations.
        train_weight: when true, the 'SGD' solver also updates the weight
            matrix and saves it back to ``./expanded_weights_train``.
        noise_level: if not ``None``, scale of the multiplicative Gaussian
            noise applied to the loaded weights (only when not training them).
    """

    def __init__(self, inchan, outchan, kernel_size=3, stride=1, padding=1, cls=0, bias=False,
                 solver='SGD', num_iterations=5, train_weight=False, noise_level=None):
        super().__init__()
        self.noise_level = noise_level
        self.solver = solver
        self.train_weight = train_weight
        self.num_iterations = num_iterations
        self.padding = padding
        self.stride = stride
        self.kernel_size = kernel_size
        self.C_in = inchan
        self.C_out = outchan
        self.FFconv = nn.Conv2d(inchan, outchan, self.kernel_size, self.stride, self.padding, bias=bias)
        self.FBconv = nn.ConvTranspose2d(outchan, inchan, self.kernel_size, self.stride, self.padding, bias=bias)
        self.b0 = nn.ParameterList([nn.Parameter(torch.zeros(1, outchan, 1, 1))])
        self.relu = nn.ReLU(inplace=True)
        self.cls = cls
        self.bypass = nn.Conv2d(inchan, outchan, kernel_size=1, stride=1, bias=False)

    def forward(self, x, layer_idx):
        """Return the solver-refined representation of ``x`` plus the 1x1 bypass."""
        y = self.relu(self.FFconv(x))
        y = self.find_optimal_r(x, y, layer_idx, solver=self.solver)
        y = y + self.bypass(x)
        return y

    def find_optimal_r(self, x, y, layer_idx, solver):
        """Refine ``y`` so that ``y @ W.T`` reconstructs the flattened ``x``.

        Args:
            x: layer input, flattened per sample before solving.
            y: initial representation; it is zero-padded by ``self.padding``
                on each spatial side before being flattened.
            layer_idx: index used to pick the weight file on disk.
            solver: ``'SGD'``, ``'SA'`` or ``'LD'``.

        Returns:
            A detached tensor with the spatial size of the unpadded ``y``,
            on ``y``'s device.

        Raises:
            ValueError: if ``solver`` is not one of the supported names.
        """
        if solver not in ('SGD', 'SA', 'LD'):
            raise ValueError(f'Solver {solver} not supported')

        expanded_weights = self._load_expanded_weights(layer_idx)
        flattened_x = torch.flatten(x, start_dim=1).clone().detach()
        y = F.pad(y, (self.padding, self.padding, self.padding, self.padding))
        initial_y = torch.flatten(y, start_dim=1).detach()

        if solver == 'SGD':
            flattened_y = self._solve_sgd(flattened_x, initial_y, expanded_weights, layer_idx)
        elif solver == 'SA':
            flattened_y = self._solve_sa(flattened_x, initial_y, expanded_weights)
        else:
            flattened_y = self._solve_ld(flattened_x, initial_y, expanded_weights)

        optimal_y = self._restore_shape(flattened_y.detach(), y.shape)
        return optimal_y.to(y.device)

    def _load_expanded_weights(self, layer_idx):
        """Load the weight matrix for ``layer_idx``, adding noise if configured."""
        if self.train_weight:
            return torch.load(f'./expanded_weights_train/expanded_weights_{layer_idx}.pt')
        expanded_weights = torch.load(f'./expanded_weights/PCN_5/expanded_weights_{layer_idx}.pt')
        if self.noise_level is not None:
            expanded_weights = self._perturb_weights(expanded_weights, self.noise_level)
        return expanded_weights

    @staticmethod
    def _perturb_weights(weights, noise_level):
        """Return ``weights * (1 + noise_level * N(0, 1))`` element-wise."""
        if weights.is_sparse:
            # Zero entries stay zero under multiplicative noise, so only the
            # stored values need random numbers; this avoids a dense randn of
            # the full matrix shape.
            weights = weights.coalesce()
            values = weights.values()
            noisy = values + noise_level * torch.randn_like(values) * values
            return torch.sparse_coo_tensor(weights.indices(), noisy, weights.shape)
        return weights + noise_level * torch.randn_like(weights) * weights

    def _solve_sgd(self, flattened_x, initial_y, expanded_weights, layer_idx):
        """Minimise ``||x - y @ W.T||`` over ``y`` with SGD (lr=0.001)."""
        expanded_weights = expanded_weights.to(initial_y.device)
        if self.train_weight:
            # The tensor handed to the optimiser must itself be a leaf that
            # requires grad, created after the device move.
            expanded_weights = expanded_weights.detach().requires_grad_(True)

        # The solver needs autograd even when the caller runs under no_grad().
        with torch.enable_grad():
            flattened_y = initial_y.clone().detach().requires_grad_(True)
            optimizer_y = torch.optim.SGD([flattened_y], lr=0.001)
            optimizer_w = torch.optim.SGD([expanded_weights], lr=0.001) if self.train_weight else None
            for _ in range(self.num_iterations):
                optimizer_y.zero_grad()
                energy = torch.norm(flattened_x - flattened_y @ expanded_weights.T, p=2)
                energy.backward()
                optimizer_y.step()

            # NOTE(review): the original indentation was lost; the weight
            # refinement is assumed to run once after the inference loop.
            if self.train_weight:
                for _ in range(5):
                    optimizer_w.zero_grad()
                    energy = torch.norm(flattened_x - flattened_y @ expanded_weights.T, p=2)
                    energy.backward()
                    optimizer_w.step()
                torch.save(expanded_weights.detach(),
                           f'./expanded_weights_train/expanded_weights_{layer_idx}.pt')
        return flattened_y

    def _solve_sa(self, flattened_x, initial_y, expanded_weights):
        """Minimise the same energy with SciPy's dual annealing on the CPU."""
        flattened_x_np = flattened_x.cpu().numpy()
        expanded_weights_np = expanded_weights.detach().to_dense().cpu().numpy()
        flattened_y_np = initial_y.cpu().numpy()
        batch = flattened_y_np.shape[0]

        def e_f(y, x, W):
            # dual_annealing works on a 1-D vector; restore the batch axis and
            # use the Frobenius norm so the value matches torch.norm(p=2).
            return float(np.linalg.norm(x - y.reshape(batch, -1) @ W.T))

        # One (-2.5, 2.5) box constraint per element of the flattened y.
        bounds = [(-2.5, 2.5)] * flattened_y_np.size
        result = dual_annealing(e_f, bounds, x0=flattened_y_np.ravel(),
                                args=(flattened_x_np, expanded_weights_np),
                                maxiter=self.num_iterations, maxfun=5)
        return torch.tensor(result.x, dtype=torch.float32)

    def _solve_ld(self, flattened_x, initial_y, expanded_weights, lr=0.001):
        """Run ``num_iterations`` gradient steps on ``||x - y @ W.T||**2``.

        The gradient with respect to ``y`` is ``2 * y @ W.T @ W - 2 * x @ W``.
        ``W.T @ W`` is never formed; two matrix products are chained instead.
        Works for any batch size.
        """
        expanded_weights = expanded_weights.to(initial_y.device)
        weights_t = expanded_weights.t()
        y_t = initial_y.t()                              # (n_y, batch)
        c = -2 * (weights_t @ flattened_x.t())           # (n_y, batch)
        for _ in range(self.num_iterations):
            gradient = 2 * (weights_t @ (expanded_weights @ y_t)) + c
            y_t = y_t - lr * gradient
        return y_t.t().contiguous()

    def _restore_shape(self, flattened_y, padded_shape):
        """Reshape to the padded layout and cut the padding border off."""
        optimal_y = flattened_y.reshape(padded_shape)
        pad = self.padding
        if pad > 0:
            # A plain [pad:-pad] slice would be empty for pad == 0.
            optimal_y = optimal_y[:, :, pad:-pad, pad:-pad]
        return optimal_y

    def Energy_Function(self, x, W, y):
        """Return ``sqrt(x x^T - 2 x W y^T + y W^T W y^T)``.

        For single-row ``x`` and ``y`` this equals ``||x - y @ W.T||``.
        """
        energy = torch.sqrt(x @ x.T - 2 * x @ W @ y.T + y @ (W.T @ W) @ y.T)
        return energy

    def expand_weights_to_matrix(self, input_shape, weight_tensor, stride=1, padding=0):
        """Unroll convolution weights into a sparse matrix.

        Multiplying the returned matrix by a flattened, already zero-padded
        input of shape ``(C_in, H_in + 2*padding, W_in + 2*padding)`` gives the
        flattened ``(C_out, H_out, W_out)`` cross-correlation output.

        Args:
            input_shape: ``(C_in, H_in, W_in)`` of the unpadded input.
            weight_tensor: weights of shape ``(C_out, C_in, K, K)``.
            stride: convolution stride.
            padding: zero padding assumed on each spatial side of the input.

        Returns:
            A float32 sparse COO tensor of shape
            ``(C_out*H_out*W_out, C_in*(H_in+2*padding)*(W_in+2*padding))``
            holding only the non-zero weights.

        Raises:
            ValueError: if the weight's input-channel count differs from ``C_in``.
        """
        C_in, H_in, W_in = input_shape
        C_out, weight_c_in, K, _ = weight_tensor.shape
        if weight_c_in != C_in:
            raise ValueError(f'weight_tensor has {weight_c_in} input channels, expected {C_in}')

        H_pad = H_in + 2 * padding
        W_pad = W_in + 2 * padding
        H_out = (H_pad - K) // stride + 1
        W_out = (W_pad - K) // stride + 1
        grid = (C_out, H_out, W_out, C_in, K, K)

        def axis(length, dim):
            # arange placed on one axis of a 6-D broadcastable index grid.
            shape = [1] * len(grid)
            shape[dim] = length
            return torch.arange(length, dtype=torch.long).view(shape)

        c_out, h, w, c_in, i, j = (axis(length, dim) for dim, length in enumerate(grid))
        rows = ((c_out * H_out + h) * W_out + w).expand(grid).reshape(-1)
        cols = ((c_in * H_pad + (h * stride + i)) * W_pad + (w * stride + j)).expand(grid).reshape(-1)
        weights = weight_tensor.detach().to(device='cpu', dtype=torch.float32)
        values = weights[:, None, None].expand(grid).reshape(-1)

        keep = values != 0
        indices = torch.stack((rows[keep], cols[keep]))
        size = (C_out * H_out * W_out, C_in * H_pad * W_pad)
        return torch.sparse_coo_tensor(indices, values[keep], size=size)


''' Architecture PredNetBpD '''
class PredNetBpD(nn.Module):
    """Five-layer predictive-coding classifier with one optional solver layer.

    Args:
        num_classes: size of the final linear layer's output.
        cls: number of recurrent time steps passed to each layer.
        Tied: must be false; no tied layer is available in this module.
        solver: ``None`` (all layers are ``PcConvBp``) or one of
            ``'SGD'``, ``'SA'``, ``'LD'``.
        layer_number: 1-based index of the layer replaced by ``PcConvBp_DS``;
            required with a solver and must be ``None`` without one.
        num_iterations: solver iterations; ``None`` uses the layer default (5).
        train_weight, noise_level: forwarded to ``PcConvBp_DS``.

    Raises:
        NotImplementedError: if ``Tied`` is true.
        ValueError: for an unknown solver or an inconsistent ``layer_number``.
    """

    def __init__(self, num_classes=10, cls=0, Tied=False,
                 solver=None, layer_number=None, num_iterations=None, train_weight=False,
                 noise_level=None):
        super().__init__()
        # Project-local import kept inside the constructor so this module can
        # be imported (e.g. for PcConvBp_DS alone) without prednet.py.
        from prednet import PcConvBp

        self.ics = [3, 32, 64, 64, 128]      # input channels
        self.ocs = [32, 64, 64, 128, 128]    # output channels
        self.maxpool = [False, True, False, True, False]  # downsample flag
        self.cls = cls                       # num of time steps
        self.nlays = len(self.ics)

        if Tied:
            # The tied layer class was never imported here, so this path
            # could only fail with a NameError; fail with a clear message.
            raise NotImplementedError('Tied layers are not available in this script')

        if solver is None:
            print('No solver in used, still using convolution in recurrent layer')
            if layer_number is not None:
                raise ValueError('layer_number must be None if solver is None')
            solver_index = None
        elif solver in ('SGD', 'SA', 'LD'):
            print(f'Solver {solver} is in use')
            if layer_number is None:
                raise ValueError('layer_number must be provided if solver is not None')
            if not 1 <= layer_number <= self.nlays:
                raise ValueError(
                    f'layer_number must be less than or equal to the number of layers: {self.nlays}')
            solver_index = layer_number - 1
        else:
            raise ValueError(f'Solver {solver} not supported')

        solver_kwargs = dict(solver=solver, train_weight=train_weight, noise_level=noise_level)
        if num_iterations is not None:
            solver_kwargs['num_iterations'] = num_iterations

        self.PcConvs = nn.ModuleList()
        for i, (ic, oc) in enumerate(zip(self.ics, self.ocs)):
            if i == solver_index:
                self.PcConvs.append(PcConvBp_DS(ic, oc, cls=self.cls, **solver_kwargs))
            else:
                self.PcConvs.append(PcConvBp(ic, oc, cls=self.cls))

        if noise_level is not None:
            print(f'Adding noise to the solver {solver} with noise level {noise_level}')
        self.BNs = nn.ModuleList([nn.BatchNorm2d(ic) for ic in self.ics])
        # Linear layer
        self.linear = nn.Linear(self.ocs[-1], num_classes)
        self.maxpool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU(inplace=True)
        self.BNend = nn.BatchNorm2d(self.ocs[-1])

    def forward(self, x):
        """Run BatchNorm -> PC layer (-> max-pool) per stage, then classify."""
        for i in range(self.nlays):
            x = self.BNs[i](x)
            x = self.PcConvs[i](x, i)  # the 0-based stage index selects the weight file
            if self.maxpool[i]:
                x = self.maxpool2d(x)

        # classifier: global average pool over the remaining spatial extent
        out = F.avg_pool2d(self.relu(self.BNend(x)), x.size(-1))
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def evaluate(model, loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``.

    The model is switched to eval mode so BatchNorm uses its running
    statistics, and the pass runs under ``no_grad`` (the SGD solver re-enables
    autograd locally).  A running accuracy is printed after every batch.
    Returns ``0.0`` for an empty loader.
    """
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            output_tensor = model(inputs)
            # Get the predicted class
            _, predicted = torch.max(output_tensor, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
            print(f' Temperary Accuracy: {100 * correct / total:.2f}%')
    return 100 * correct / total if total else 0.0


def main():
    """Evaluate the checkpointed network on the CIFAR10 test split."""
    # Only the script entry point needs these packages.
    import torchvision
    import torchvision.transforms as transforms
    from tqdm import tqdm

    batchsize = 128
    test_ratio = 1
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Using device: {device}')
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True,
                                           transform=transform_test)

    # Keep the first `test_ratio` fraction of the test set.
    subset_size = int(test_ratio * len(testset))
    test_subset = Subset(testset, list(range(subset_size)))
    testloader = torch.utils.data.DataLoader(test_subset, batch_size=batchsize,
                                             shuffle=False, num_workers=6)

    checkpoint_weight = torch.load(
        'checkpoint/PredNetBpD_5_5CLS_FalseNes_0.001WD_FalseTIED_2REP_best_ckpt.t7',
        map_location=device)
    prednet = PredNetBpD(num_classes=10, cls=5, Tied=False,
                         solver='SGD', layer_number=5, num_iterations=5, train_weight=False,
                         noise_level=None)
    prednet = nn.DataParallel(prednet.to(device))
    prednet.load_state_dict(checkpoint_weight['net'])

    accuracy = evaluate(prednet, tqdm(testloader, total=len(testloader)), device)
    print(f'Test Accuracy: {accuracy:.2f}%')


if __name__ == '__main__':
    main()
main_cifar(model='PredNetBpD_3', circles=5, gpunum=1, Tied=False, weightDecay=1e-3, nesterov=False): use_cuda = True # torch.cuda.is_available() best_acc = 0 # best test accuracy start_epoch = 0 # start from epoch 0 or last checkpoint epoch - batchsize = 128 + batchsize = 512 root = './' rep = 1 lr = 0.01 - models = {'PredNetBpD':PredNetBpD} + models = {'PredNetBpD_3':PredNetBpD_3} modelname = model+'_'+str(circles)+'CLS_'+str(nesterov)+'Nes_'+str(weightDecay)+'WD_'+str(Tied)+'TIED_'+str(rep)+'REP' # clearn folder @@ -45,20 +45,19 @@ def main_cifar(model='PredNetBpD', circles=5, gpunum=1, Tied=False, weightDecay= transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),]) - trainset = torchvision.datasets.CIFAR100(root='../data', train=True, download=True, transform=transform_train) + trainset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=batchsize, shuffle=True, num_workers=2) - testset = torchvision.datasets.CIFAR100(root='../data', train=False, download=True, transform=transform_test) - testloader = torch.utils.data.DataLoader(testset, batch_size=10, shuffle=False, num_workers=2) + testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform_test) + testloader = torch.utils.data.DataLoader(testset, batch_size=batchsize, shuffle=False, num_workers=2) # Model print('==> Building model..') - net = models[model](num_classes=100,cls=circles,Tied=Tied) - + net = models[model](num_classes=10,cls=circles,Tied=Tied) # Define objective function criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=lr, weight_decay=weightDecay, nesterov=nesterov) - + # Parallel computing if use_cuda: net.cuda() @@ -150,7 +149,6 @@ def decrease_learning_rate(): for param_group in optimizer.param_groups: param_group['lr'] 
class _PredNetBpDStack(nn.Module):
    """Shared body of the fixed-depth PredNetBpD variants.

    Subclasses supply the per-stage channel tables and max-pool flags; the
    attribute names (``PcConvs``, ``BNs``, ``linear``, ``BNend``) are the ones
    used by the stand-alone classes, so state-dict keys are the same.
    """

    def __init__(self, ics, ocs, maxpool, num_classes=10, cls=0, Tied=False):
        super().__init__()
        self.ics = ics          # input channels
        self.ocs = ocs          # output channels
        self.maxpool = maxpool  # downsample flag
        self.cls = cls          # num of time steps
        self.nlays = len(self.ics)

        # construct PC layers; the tied class is looked up only when requested
        layer_cls = PcConvBpTied if Tied else PcConvBp
        self.PcConvs = nn.ModuleList(
            [layer_cls(ic, oc, cls=self.cls) for ic, oc in zip(self.ics, self.ocs)])
        self.BNs = nn.ModuleList([nn.BatchNorm2d(ic) for ic in self.ics])
        # Linear layer
        self.linear = nn.Linear(self.ocs[-1], num_classes)
        self.maxpool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU(inplace=True)
        self.BNend = nn.BatchNorm2d(self.ocs[-1])

    def forward(self, x):
        """Run BatchNorm -> PC layer (-> max-pool) per stage, then classify."""
        for i in range(self.nlays):
            x = self.BNs[i](x)
            x = self.PcConvs[i](x, i)  # each layer also receives its stage index
            if self.maxpool[i]:
                x = self.maxpool2d(x)

        # classifier: global average pool over the remaining spatial extent
        out = F.avg_pool2d(self.relu(self.BNend(x)), x.size(-1))
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


''' Architecture PredNetBpD_5 '''
class PredNetBpD_5(_PredNetBpDStack):
    """Five-stage network: channels 3-32-64-64-128-128, pooling after stages 2 and 4."""

    def __init__(self, num_classes=10, cls=0, Tied=False):
        super().__init__(ics=[3, 32, 64, 64, 128],
                         ocs=[32, 64, 64, 128, 128],
                         maxpool=[False, True, False, True, False],
                         num_classes=num_classes, cls=cls, Tied=Tied)


''' Architecture PredNetBpD_3 '''
class PredNetBpD_3(_PredNetBpDStack):
    """Three-stage network: channels 3-32-64-64, pooling after stage 2."""

    def __init__(self, num_classes=10, cls=0, Tied=False):
        super().__init__(ics=[3, 32, 64],
                         ocs=[32, 64, 64],
                         maxpool=[False, True, False],
                         num_classes=num_classes, cls=cls, Tied=Tied)


class PcConvBp_SGD(nn.Module):
    """Predictive-coding conv layer whose refinement is one SGD step on an explicit energy.

    The feedback kernel is unrolled into a dense matrix ``W`` on every call and
    the padded representation ``y`` takes a gradient step on
    ``-2 x W y^T + y W^T W y^T``.
    """

    def __init__(self, inchan, outchan, kernel_size=3, stride=1, padding=1, cls=0, bias=False):
        super().__init__()
        self.padding = padding
        self.stride = stride
        self.kernel_size = kernel_size
        self.FFconv = nn.Conv2d(inchan, outchan, self.kernel_size, self.stride, self.padding, bias=bias)
        self.FBconv = nn.ConvTranspose2d(outchan, inchan, self.kernel_size, self.stride, self.padding, bias=bias)
        self.b0 = nn.ParameterList([nn.Parameter(torch.zeros(1, outchan, 1, 1))])
        self.relu = nn.ReLU(inplace=True)
        self.cls = cls
        self.bypass = nn.Conv2d(inchan, outchan, kernel_size=1, stride=1, bias=False)

    def forward(self, x, layer_idx=None):
        """Return the refined representation of ``x`` plus the 1x1 bypass.

        ``layer_idx`` is accepted and ignored so the layer can be called like
        ``PcConvBp.forward(x, layer_idx)``.
        """
        y = self.relu(self.FFconv(x))
        y, _energies = self.find_optimal_r(x, y)
        y = y + self.bypass(x)
        return y

    def find_optimal_r(self, x, y):
        """Take SGD steps on the energy, starting from the zero-padded ``y``.

        Args:
            x: layer input; flattened per sample.
            y: initial representation, padded by ``self.padding`` per side.

        Returns:
            ``(optimal_y, energy_record)``: the detached, un-padded result and
            the list of energy values recorded before each step.
        """
        weight = self.FBconv.weight.data
        expanded_weights = self.expand_weights_to_matrix(
            y.shape[1:], weight.permute(1, 0, 2, 3), stride=self.stride, padding=self.padding)
        expanded_weights = expanded_weights.to(y.device)
        # One row per sample, so a batch is not merged into a single vector.
        flattened_x = torch.flatten(x, start_dim=1).clone().detach()

        num_iterations = 1
        pad = self.padding
        y = F.pad(y, (pad, pad, pad, pad))
        energy_record = []
        # The step needs autograd even if the caller runs under no_grad().
        with torch.enable_grad():
            flattened_y = torch.flatten(y, start_dim=1).clone().detach().requires_grad_(True)
            optimizer = torch.optim.SGD([flattened_y], lr=0.01)
            for _ in range(num_iterations):
                optimizer.zero_grad()
                energy = self.Energy_Function(flattened_x, expanded_weights, flattened_y)
                energy.backward()
                optimizer.step()
                energy_record.append(energy.item())

        # flattened_y has exactly the padded layout of y; restore it and cut
        # the border off (a plain [pad:-pad] slice is empty for pad == 0).
        optimal_y = flattened_y.detach().view(y.shape)
        if pad > 0:
            optimal_y = optimal_y[:, :, pad:-pad, pad:-pad]
        optimal_y = optimal_y.to(y.device)

        return optimal_y, energy_record

    def Energy_Function(self, x, W, y):
        """Return ``-2 x W y^T + y W^T W y^T`` summed over rows, shape ``(1, 1)``.

        Evaluated through ``y @ W.T`` so that ``W.T @ W`` is never formed.
        """
        projected = y @ W.T
        energy = -2 * (x * projected).sum() + (projected * projected).sum()
        return energy.reshape(1, 1)

    def expand_weights_to_matrix(self, input_shape, weight_tensor, stride=1, padding=0):
        """
        Expand the convolution weights to a matrix suitable for multiplying with a
        flattened, already zero-padded input vector.

        Args:
        - input_shape (tuple): Shape of the unpadded input (C_in, H_in, W_in)
        - weight_tensor (torch.Tensor): Convolution weights of shape (C_out, C_in, K, K)
        - stride (int): Stride of the convolution
        - padding (int): Padding size

        Returns:
        - expanded_weights (torch.Tensor): Dense float32 matrix of shape
          (C_out*H_out*W_out, C_in*(H_in+2*padding)*(W_in+2*padding))

        Raises:
        - ValueError: if the weight's input-channel count differs from C_in
        """
        C_in, H_in, W_in = input_shape
        C_out, weight_c_in, K, _ = weight_tensor.shape
        if weight_c_in != C_in:
            raise ValueError(f'weight_tensor has {weight_c_in} input channels, expected {C_in}')

        H_pad = H_in + 2 * padding
        W_pad = W_in + 2 * padding
        H_out = (H_pad - K) // stride + 1
        W_out = (W_pad - K) // stride + 1
        grid = (C_out, H_out, W_out, C_in, K, K)

        def axis(length, dim):
            # arange placed on one axis of a 6-D broadcastable index grid.
            shape = [1] * len(grid)
            shape[dim] = length
            return torch.arange(length, dtype=torch.long).view(shape)

        c_out, h, w, c_in, i, j = (axis(length, dim) for dim, length in enumerate(grid))
        rows = ((c_out * H_out + h) * W_out + w).expand(grid).reshape(-1)
        cols = ((c_in * H_pad + (h * stride + i)) * W_pad + (w * stride + j)).expand(grid).reshape(-1)
        weights = weight_tensor.detach().to(device='cpu', dtype=torch.float32)
        values = weights[:, None, None].expand(grid).reshape(-1)

        expanded_weights = torch.zeros((C_out * H_out * W_out, C_in * H_pad * W_pad))
        # Every (row, col) pair occurs once, so plain indexed assignment is safe.
        expanded_weights[rows, cols] = values
        return expanded_weights
cls=0, Tied = False): # construct PC layers # Unlike PCN v1, we do not have a tied version here. We may or may not incorporate a tied version in the future. if Tied == False: - self.PcConvs = nn.ModuleList([PcConvBp(self.ics[i], self.ocs[i], cls=self.cls) for i in range(self.nlays)]) + if solver is None: + print('No solver in used, still using convolution in recurrent layer') + self.PcConvs = nn.ModuleList([PcConvBp(self.ics[i], self.ocs[i], cls=self.cls) for i in range(self.nlays)]) + elif solver == 'SGD': + print(f'Solver {solver} is in use') + self.PcConvs = nn.ModuleList([PcConvBp_SGD(3, 64, cls=self.cls)]) + self.PcConvs.extend([PcConvBp(self.ics[i], self.ocs[i], cls=self.cls) for i in range(1, self.nlays)]) + else: + print(f'Solver {solver} not supported') else: self.PcConvs = nn.ModuleList([PcConvBpTied(self.ics[i], self.ocs[i], cls=self.cls) for i in range(self.nlays)]) + self.BNs = nn.ModuleList([nn.BatchNorm2d(self.ics[i]) for i in range(self.nlays)]) # Linear layer self.linear = nn.Linear(self.ocs[-1], num_classes) @@ -111,7 +298,7 @@ def forward(self, x): out = self.linear(out) return out -''' Architecture PredNetBpD ''' +''' Architecture PredNetBpC ''' class PredNetBpC(nn.Module): def __init__(self, num_classes=10, cls=0, Tied = False): super().__init__() diff --git a/utils.py b/utils.py index 4c9b3f9..42b454d 100644 --- a/utils.py +++ b/utils.py @@ -7,7 +7,8 @@ import sys import time import math - +import torch +import numpy as np import torch.nn as nn import torch.nn.init as init @@ -122,3 +123,5 @@ def format_time(seconds): if f == '': f = '0ms' return f + +