diff --git a/main.py b/main.py index 6c12a77..0e2d825 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,7 @@ '''Train CIFAR10/CIFAR100 with PyTorch.''' import argparse import os -from optimizers import (KFACOptimizer, SKFACOptimizer, EKFACOptimizer, KBFGSOptimizer, KBFGSLOptimizer, KBFGSL2LOOPOptimizer, KBFGSLMEOptimizer, NGDOptimizer, NGDStreamOptimizer) +from optimizers import (KFACOptimizer, SKFACOptimizer, EKFACOptimizer, KBFGSOptimizer, KBFGSLOptimizer, KBFGSL2LOOPOptimizer, KBFGSLMEOptimizer, NGDOptimizer, NGDStreamOptimizer, NystromOptimizer) import torch import torch.nn as nn import torch.optim as optim @@ -233,6 +233,15 @@ reduce_sum=args.reduce_sum, diag=args.diag) +elif optim_name == 'nystrom': + print('Nystrom optimizer selected') + optimizer = NystromOptimizer(net, + lr=args.learning_rate, + momentum=args.momentum, + damping=args.damping, + weight_decay=args.weight_decay, + freq=args.freq) + elif optim_name == 'ngd_stream': # SAEED: TODO fix batchnorm or remove it totally print('NGD Stream Optimizer selected') @@ -470,7 +479,7 @@ def closure(): optimizer.step() ### new optimizer test - elif optim_name in ['kngd', 'ngd_stream'] : + elif optim_name in ['kngd', 'ngd_stream', 'nystrom'] : inputs, targets = inputs.to(args.device), targets.to(args.device) optimizer.zero_grad() outputs = net(inputs) diff --git a/optimizers/__init__.py b/optimizers/__init__.py index 39e73a7..85af4d7 100644 --- a/optimizers/__init__.py +++ b/optimizers/__init__.py @@ -7,6 +7,7 @@ from .kbfgsl_mem_eff import KBFGSLMEOptimizer from .ngd import NGDOptimizer from .ngd_stream import NGDStreamOptimizer +from .nystrom import NystromOptimizer def get_optimizer(name): @@ -28,5 +29,7 @@ def get_optimizer(name): return NGDOptimizer elif name == 'ngd_stream': return NGDStreamOptimizer + elif name == 'nystrom': + return NystromOptimizer else: raise NotImplementedError \ No newline at end of file diff --git a/optimizers/nystrom.py b/optimizers/nystrom.py new file mode 100644 index 
# =============================================================================
# optimizers/nystrom.py  (new file)
# =============================================================================
import math

import torch
import torch.optim as optim
from torch import einsum, eye, matmul, cumsum
from torch.linalg import inv, svd

# In the repository layout this module obtains its statistics helpers with
#     from utils.nystrom_utils import (ComputeI, ComputeG)
# In this single listing both classes are defined in the
# utils/nystrom_utils.py section further down; restore the import above when
# the listing is split back into its two files.


class NystromOptimizer(optim.Optimizer):
    """SGD with a rank-one, damped curvature correction on layer weights.

    For every ``Linear``/``Conv2d`` module the optimizer records the layer
    input (forward pre-hook) and the gradient w.r.t. the layer output
    (backward hook).  Every ``freq`` steps it builds, per layer, a matrix
    ``J`` of shape (batch, weight.numel()) from those two statistics, samples
    one coordinate ``i`` with probability proportional to ``sum_n J[n, i]**2``
    and stores the scaled column ``c = J^T J[:, i] * s`` with
    ``s = 1 / sqrt(|J^T J[:, i]|^2 + damping * (J^T J)[i, i])``.

    At every step the weight gradient ``g`` is replaced by
    ``(g - (c . g) c) / damping``, which by the Sherman-Morrison identity is
    ``(C C^T / w + damping * I)^-1 g`` for the unscaled column ``C`` and
    ``w = (J^T J)[i, i]``.  Biases and all other parameters receive plain SGD
    with L2 weight decay.

    Args:
        model: module whose ``Linear``/``Conv2d`` sub-modules are hooked and
            whose parameters are optimised.
        lr: learning rate (>= 0).
        momentum: validated and stored in the param group, but NOT applied by
            the current update rule.
        damping: positive damping term (the ``damping`` in the formulas above).
        weight_decay: L2 coefficient added to every gradient (>= 0).
        freq: number of steps between refreshes of the per-layer column
            (>= 1).

    Raises:
        ValueError: if any hyper-parameter is outside the range given above.

    Attributes set for callers:
        acc_stats: when False the backward hook stops recording output
            gradients; defaults to True and may be toggled by training code.
        steps: number of completed ``step`` calls.
    """

    def __init__(self,
                 model,
                 lr=0.01,
                 momentum=0.9,
                 damping=0.1,
                 weight_decay=0.003,
                 freq=100):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        # damping is a divisor in _get_natural_grad and freq a modulus in the
        # hooks, so both must be strictly positive.
        if damping <= 0.0:
            raise ValueError("Invalid damping value: {}".format(damping))
        if freq < 1:
            raise ValueError("Invalid freq value: {}".format(freq))
        defaults = dict(lr=lr, momentum=momentum, damping=damping,
                        weight_decay=weight_decay)

        super(NystromOptimizer, self).__init__(model.parameters(), defaults)

        self.known_modules = {'Linear', 'Conv2d'}
        self.modules = []
        self.IHandler = ComputeI()
        self.GHandler = ComputeG()
        self.model = model

        self.steps = 0
        self.m_C = {}  # module -> scaled column vector used for the correction
        self.m_I = {}  # module -> (None, layer-input statistic)
        self.m_G = {}  # module -> (None, output-gradient statistic)

        self.freq = freq
        self.damping = damping
        # Read by _save_grad_output; it must exist before the first backward
        # pass, otherwise the hook raises AttributeError.
        self.acc_stats = True

        # Register hooks last so that every attribute they read already exists.
        self._prepare_model()

    def _save_input(self, module, input):
        """Forward pre-hook: record the layer input on refresh steps."""
        if torch.is_grad_enabled() and self.steps % self.freq == 0:
            II, I = self.IHandler(input[0].data, module)
            self.m_I[module] = II, I

    def _save_grad_output(self, module, grad_input, grad_output):
        """Backward hook: record the output gradient on refresh steps."""
        if self.acc_stats and self.steps % self.freq == 0:
            GG, G = self.GHandler(grad_output[0].data, module)
            self.m_G[module] = GG, G

    def _prepare_model(self):
        """Print the model and attach both hooks to every supported layer."""
        count = 0
        print(self.model)
        print('Nystrom keeps the following modules:')
        for module in self.model.modules():
            classname = module.__class__.__name__
            if classname in self.known_modules:
                self.modules.append(module)
                module.register_forward_pre_hook(self._save_input)
                # Only grad_output is consumed.  The legacy hook is kept on
                # purpose: the "full" backward hook rejects in-place ops on
                # the layer output (e.g. a following ReLU(inplace=True)).
                module.register_backward_hook(self._save_grad_output)
                print('(%s): %s' % (count, module))
                count += 1

    def _update_inv(self, m):
        """Sample one column of ``J^T J`` for module ``m`` and cache it scaled.

        Raises:
            NotImplementedError: if ``m`` is not a Linear/Conv2d module.
            KeyError: if no statistics were recorded for ``m``.
        """
        classname = m.__class__.__name__.lower()
        if classname not in ('linear', 'conv2d'):
            raise NotImplementedError

        I = self.m_I[m][1]
        G = self.m_G[m][1]
        n = I.shape[0]

        # Rows of J must use the same (out, in) flattening as
        # weight.grad.reshape(-1) in _get_natural_grad, hence 'noi'.
        # NOTE(review): Conv2d with groups > 1 yields a different element
        # count than the weight -- confirm such layers are not used.
        J = einsum('ni,no->noi', I, G).reshape(n, -1)

        p = (J * J).sum(dim=0)  # diagonal of J^T J
        total = p.sum()
        if not bool(total > 0):
            # No gradient signal (or NaN): multinomial would raise.  A zero
            # vector makes the update reduce to grad / damping.
            self.m_C[m] = torch.zeros_like(p)
            return

        idx = torch.multinomial(p / total, num_samples=1)
        col = J[:, idx].reshape(-1)  # J[:, i]
        C = col @ J                  # J^T J[:, i]
        w = p[idx]                   # (J^T J)[i, i], > 0 for a sampled index

        scale = torch.rsqrt(torch.dot(C, C) + self.damping * w)
        self.m_C[m] = C * scale

    def _get_natural_grad(self, m, damping):
        """Return ``((g - (c . g) c) / damping, None)`` for module ``m``.

        The second tuple slot is the (unused) bias update.

        Raises:
            NotImplementedError: if ``m`` is not a Linear/Conv2d module.
        """
        classname = m.__class__.__name__.lower()
        if classname not in ('linear', 'conv2d'):
            raise NotImplementedError

        grad = m.weight.grad.data
        g = grad.reshape(-1)
        C = self.m_C[m]
        v = (torch.dot(C, g) * C).view_as(grad)
        return (grad - v) / damping, None

    def _kl_clip_and_update_grad(self, updates, lr):
        """Overwrite each layer's weight gradient with its computed update.

        Despite the name, no KL clipping is performed and bias gradients are
        left untouched; ``lr`` is accepted only for signature compatibility.
        """
        for m, v in updates.items():
            m.weight.grad.data.copy_(v[0])

    def _step(self, closure):
        """Apply ``p -= lr * (grad + weight_decay * p)`` to every parameter.

        FIXME (CW): modified based on SGD -- no nesterov, no dampening, and
        the momentum buffer is currently disabled.  ``closure`` is unused.
        """
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    # In place on purpose: p.grad ends up holding the decayed
                    # gradient, as before.
                    d_p.add_(p.data, alpha=weight_decay)
                p.data.add_(d_p, alpha=-group['lr'])

    def step(self, closure=None):
        """Precondition layer weight gradients, then take one SGD step.

        ``closure`` is accepted for ``Optimizer`` API compatibility but is
        never evaluated; the method returns None.  ``lr`` and ``damping`` are
        read from the first param group.
        """
        group = self.param_groups[0]
        lr = group['lr']
        damping = group['damping']
        refresh = self.steps % self.freq == 0

        updates = {}
        for m in self.modules:
            if m.weight.grad is None:
                # Layer did not take part in this backward pass (or is
                # frozen); _step skips its parameters as well.
                continue
            if refresh:
                self._update_inv(m)
            updates[m] = self._get_natural_grad(m, damping)
        self._kl_clip_and_update_grad(updates, lr)

        self._step(closure)
        self.steps += 1


# =============================================================================
# utils/nystrom_utils.py  (new file)
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum, matmul
from torch.nn import Unfold


class ComputeI:
    """Reduce a layer's input to one row per sample.

    Instances are callable and the class exposes the same entry point as
    ``ComputeI.compute_cov_a``.  Every handler returns a ``(None, I)`` pair;
    the first slot is always None in this implementation.
    """

    @classmethod
    def compute_cov_a(cls, a, module):
        """Alias of ``__call__``."""
        return cls.__call__(a, module)

    @classmethod
    def __call__(cls, a, module):
        """Dispatch on the layer type; returns None for unsupported layers."""
        if isinstance(module, nn.Linear):
            return cls.linear(a, module)
        if isinstance(module, nn.Conv2d):
            return cls.conv2d(a, module)
        # FIXME(CW): for extension to other layers.
        # raise NotImplementedError
        return None

    @staticmethod
    def conv2d(input, module):
        """Unfold ``input`` into patches and sum them over all positions.

        Returns ``(None, I)`` where ``I`` has one row per sample and one
        column per element of a single unfolded patch.
        """
        patches = F.unfold(
            input,
            kernel_size=module.kernel_size,
            dilation=module.dilation,
            padding=module.padding,
            stride=module.stride,
        )
        # (n, patch_elems, positions) -> (n, patch_elems)
        return None, patches.sum(dim=2)

    @staticmethod
    def linear(input, module):
        """Return ``(None, input)`` unchanged."""
        return None, input


class ComputeG:
    """Reduce a layer's output gradient to one row per sample.

    Gradients are multiplied by the batch size before being returned.
    NOTE(review): presumably this undoes the 1/n of a mean-reduced loss --
    confirm against the training loop.
    """

    @classmethod
    def compute_cov_g(cls, g, module):
        """
        :param g: gradient
        :param module: the corresponding module
        :return: same as ``__call__``
        """
        return cls.__call__(g, module)

    @classmethod
    def __call__(cls, g, module):
        """Dispatch on the layer type; returns None for unsupported layers."""
        if isinstance(module, nn.Conv2d):
            return cls.conv2d(g, module)
        if isinstance(module, nn.Linear):
            return cls.linear(g, module)
        return None

    @staticmethod
    def conv2d(g, module):
        """Return ``(None, G)``: batch-scaled ``g`` summed over positions."""
        n = g.shape[0]
        scaled = n * g
        flat = scaled.reshape(scaled.shape[0], scaled.shape[1], -1)
        return None, flat.sum(dim=2)

    @staticmethod
    def linear(g, module):
        """Return ``(None, n * g)`` and flag the module as ``optimized``."""
        n = g.shape[0]
        module.optimized = True
        return None, n * g