From 22829dba0c1fc5c38ccc66e8c0bb8ab1eec5043b Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Wed, 15 Sep 2021 03:20:29 -0400 Subject: [PATCH 1/4] add Nystrom approximation of Fisher. --- main.py | 19 +++++++++++++++++-- optimizers/__init__.py | 3 +++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 6c12a77..f3b096c 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,7 @@ '''Train CIFAR10/CIFAR100 with PyTorch.''' import argparse import os -from optimizers import (KFACOptimizer, SKFACOptimizer, EKFACOptimizer, KBFGSOptimizer, KBFGSLOptimizer, KBFGSL2LOOPOptimizer, KBFGSLMEOptimizer, NGDOptimizer, NGDStreamOptimizer) +from optimizers import (KFACOptimizer, SKFACOptimizer, EKFACOptimizer, KBFGSOptimizer, KBFGSLOptimizer, KBFGSL2LOOPOptimizer, KBFGSLMEOptimizer, NGDOptimizer, NGDStreamOptimizer, NystromOptimizer) import torch import torch.nn as nn import torch.optim as optim @@ -233,6 +233,21 @@ reduce_sum=args.reduce_sum, diag=args.diag) +elif optim_name == 'nystrom': + print('Nystrom optimizer selected') + optimizer = NystromOptimizer(net, + lr=args.learning_rate, + momentum=args.momentum, + damping=args.damping, + kl_clip=args.kl_clip, + weight_decay=args.weight_decay, + freq=args.freq, + gamma=args.gamma, + low_rank=args.low_rank, + super_opt=args.super_opt, + reduce_sum=args.reduce_sum, + diag=args.diag) + elif optim_name == 'ngd_stream': # SAEED: TODO fix batchnorm or remove it totally print('NGD Stream Optimizer selected') @@ -470,7 +485,7 @@ def closure(): optimizer.step() ### new optimizer test - elif optim_name in ['kngd', 'ngd_stream'] : + elif optim_name in ['kngd', 'ngd_stream', 'nystrom'] : inputs, targets = inputs.to(args.device), targets.to(args.device) optimizer.zero_grad() outputs = net(inputs) diff --git a/optimizers/__init__.py b/optimizers/__init__.py index 39e73a7..85af4d7 100644 --- a/optimizers/__init__.py +++ b/optimizers/__init__.py @@ -7,6 +7,7 @@ from .kbfgsl_mem_eff import KBFGSLMEOptimizer from .ngd import NGDOptimizer from .ngd_stream import NGDStreamOptimizer +from .nystrom import NystromOptimizer def get_optimizer(name): @@ -28,5 +29,7 @@ def get_optimizer(name): return NGDOptimizer elif name == 'ngd_stream': return NGDStreamOptimizer + elif name == 'nystrom': + return NystromOptimizer else: raise NotImplementedError \ No newline at end of file From 6be385527c278119969ad2101719dd974187d859 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Wed, 15 Sep 2021 03:21:51 -0400 Subject: [PATCH 2/4] Nystrom (rank 1) approx of Fisher + SMW for Linear layer. --- optimizers/nystrom.py | 417 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 417 insertions(+) create mode 100644 optimizers/nystrom.py diff --git a/optimizers/nystrom.py b/optimizers/nystrom.py new file mode 100644 index 0000000..31d9b33 --- /dev/null +++ b/optimizers/nystrom.py @@ -0,0 +1,417 @@ +import math + +import torch +import torch.optim as optim + +from utils.nystrom_utils import (ComputeI, ComputeG) +from torch import einsum, eye, matmul, cumsum +from torch.linalg import inv, svd + +class NystromOptimizer(optim.Optimizer): + def __init__(self, + model, + lr=0.01, + momentum=0.9, + damping=0.1, + kl_clip=0.01, + weight_decay=0.003, + freq=100, + gamma=0.9, + low_rank='true', + super_opt='false', + reduce_sum='false', + diag='false'): + if lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if weight_decay < 0.0: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, momentum=momentum, damping=damping, + weight_decay=weight_decay) + + super(NystromOptimizer, self).__init__(model.parameters(), defaults) + + self.known_modules = {'Linear', 'Conv2d'} + self.modules = [] + # self.grad_outputs = {} + self.IHandler = ComputeI() + self.GHandler = ComputeG() + self.model = model + self._prepare_model() + + self.steps = 0 + self.m_C = {} + self.m_I = {} + self.m_G = {} + self.m_UV = {} + self.m_NGD_Kernel = {} + self.m_bias_Kernel = {} + + self.kl_clip = kl_clip + self.freq = freq + self.gamma = gamma + self.low_rank = low_rank + self.super_opt = super_opt + self.reduce_sum = reduce_sum + self.diag = diag + self.damping = damping + + def _save_input(self, module, input): + # storing the optimized input in forward pass + if torch.is_grad_enabled() and self.steps % self.freq == 0: + II, I = self.IHandler(input[0].data, module, self.super_opt, self.reduce_sum, self.diag) + self.m_I[module] = II, I + + def _save_grad_output(self, module, grad_input, grad_output): + # storing the optimized gradients in backward pass + if self.acc_stats and self.steps % self.freq == 0: + GG, G = self.GHandler(grad_output[0].data, module, self.super_opt, self.reduce_sum, self.diag) + self.m_G[module] = GG, G + + def _prepare_model(self): + count = 0 + print(self.model) + print('NGD keeps the following modules:') + for module in self.model.modules(): + classname = module.__class__.__name__ + if classname in self.known_modules: + self.modules.append(module) + module.register_forward_pre_hook(self._save_input) + module.register_backward_hook(self._save_grad_output) + print('(%s): %s' % (count, module)) + count += 1 + + def _update_inv(self, m): + classname = m.__class__.__name__.lower() + # print('=== _update_inv ===') + if classname == 'linear': + assert(m.optimized == True) + I = self.m_I[m][1] + # print('I:', I.shape) + G = self.m_G[m][1] + # print('G:', G.shape) + n = I.shape[0] + J = einsum('ni,no->nio', (I, G)).reshape(n, -1) + # print('J:', J.shape) + p = einsum('np->p', J * J) + # print('p:', p.shape) + p_ = p / torch.sum(p) + i = torch.multinomial(p_, num_samples=1) + # print('i:', i) + Ji = J[:,i].reshape(-1) + # print('Ji:', Ji.shape) + C = einsum('n,np->np', Ji, J) + # print('c:', C.shape) + C = einsum('np->p', C) + # print('C:', C.shape) + w = p[i] + # print('w:', w) + s = math.sqrt(1 / (torch.dot(C, C) + self.damping * w)) + # print('s:', s) + self.m_C[m] = C * s + + # II = self.m_I[m][0] + # GG = self.m_G[m][0] + # n = II.shape[0] + + # ### bias kernel is GG (II = all ones) + # bias_kernel = GG / n + # bias_inv = inv(bias_kernel + self.damping * eye(n).to(GG.device)) + # self.m_bias_Kernel[m] = bias_inv + + # NGD_kernel = (II * GG) / n + # NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(II.device)) + # self.m_NGD_Kernel[m] = NGD_inv + + # self.m_I[m] = (None, self.m_I[m][1]) + # self.m_G[m] = (None, self.m_G[m][1]) + # torch.cuda.empty_cache() + elif classname == 'conv2d': + # SAEED: @TODO: we don't need II and GG after computations, clear the memory + if m.optimized == True: + # print('=== optimized ===') + II = self.m_I[m][0] + GG = self.m_G[m][0] + n = II.shape[0] + + NGD_kernel = None + if self.reduce_sum == 'true': + if self.diag == 'true': + NGD_kernel = (II * GG / n) + NGD_inv = torch.reciprocal(NGD_kernel + self.damping) + else: + NGD_kernel = II * GG / n + NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(II.device)) + else: + NGD_kernel = (einsum('nqlp->nq', II * GG)) / n + NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(II.device)) + + self.m_NGD_Kernel[m] = NGD_inv + + self.m_I[m] = (None, self.m_I[m][1]) + self.m_G[m] = (None, self.m_G[m][1]) + torch.cuda.empty_cache() + else: + # SAEED: @TODO memory cleanup + I = self.m_I[m][1] + G = self.m_G[m][1] + n = I.shape[0] + AX = einsum("nkl,nml->nkm", (I, G)) + + del I + del G + + AX_ = AX.reshape(n , -1) + out = matmul(AX_, AX_.t()) + + del AX + + NGD_kernel = out / n + ### low-rank approximation of Jacobian + if self.low_rank == 'true': + # print('=== low rank ===') + V, S, U = svd(AX_.T, full_matrices=False) + U = U.t() + V = V.t() + cs = cumsum(S, dim = 0) + sum_s = sum(S) + index = ((cs - self.gamma * sum_s) <= 0).sum() + U = U[:, 0:index] + S = S[0:index] + V = V[0:index, :] + self.m_UV[m] = U, S, V + + del AX_ + + NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(NGD_kernel.device)) + self.m_NGD_Kernel[m] = NGD_inv + + del NGD_inv + self.m_I[m] = None, self.m_I[m][1] + self.m_G[m] = None, self.m_G[m][1] + torch.cuda.empty_cache() + + def _get_natural_grad(self, m, damping): + grad = m.weight.grad.data + classname = m.__class__.__name__.lower() + + if classname == 'linear': + assert(m.optimized == True) + I = self.m_I[m][1] + G = self.m_G[m][1] + n = I.shape[0] + + # print('grad:', grad.shape) + g = grad.reshape(-1) + # print('g:', g.shape) + C = self.m_C[m] + # print('C:', C.shape) + v = torch.dot(C, g) * C + # print('v:', v.shape) + v = v.view_as(grad) + # print('v.view:', v.shape) + + updates = (grad - v) / damping, None + + # NGD_inv = self.m_NGD_Kernel[m] + # grad_prod = einsum("ni,oi->no", (I, grad)) + # grad_prod = einsum("no,no->n", (grad_prod, G)) + + # v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() + + # gv = einsum("n,no->no", (v, G)) + # gv = einsum("no,ni->oi", (gv, I)) + # gv = gv / n + + # bias_update = None + # if m.bias is not None: + # grad_bias = m.bias.grad.data + # if self.steps % self.freq == 0: + # grad_prod_bias = einsum("o,no->n", (grad_bias, G)) + # v = matmul(self.m_bias_Kernel[m], grad_prod_bias.unsqueeze(1)).squeeze() + # gv_bias = einsum('n,no->o', (v, G)) + # gv_bias = gv_bias / n + # bias_update = (grad_bias - gv_bias) / damping + # else: + # bias_update = grad_bias + + # updates = (grad - gv)/damping, bias_update + + elif classname == 'conv2d': + raise NotImplementedError + # grad_reshape = grad.reshape(grad.shape[0], -1) + # if m.optimized == True: + # # print('=== optimized ===') + # I = self.m_I[m][1] + # G = self.m_G[m][1] + # n = I.shape[0] + # NGD_inv = self.m_NGD_Kernel[m] + + # if self.reduce_sum == 'true': + # x1 = einsum("nk,mk->nm", (I, grad_reshape)) + # grad_prod = einsum("nm,nm->n", (x1, G)) + + # if self.diag == 'true': + # v = NGD_inv * grad_prod + # else: + # v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() + + # gv = einsum("n,nm->nm", (v, G)) + # gv = einsum("nm,nk->mk", (gv, I)) + # else: + # x1 = einsum("nkl,mk->nml", (I, grad_reshape)) + # grad_prod = einsum("nml,nml->n", (x1, G)) + # v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() + # gv = einsum("n,nml->nml", (v, G)) + # gv = einsum("nml,nkl->mk", (gv, I)) + # gv = gv.view_as(grad) + # gv = gv / n + + # bias_update = None + # if m.bias is not None: + # bias_update = m.bias.grad.data + + # updates = (grad - gv)/damping, bias_update + + # else: + # # TODO(bmu): fix low rank + # if self.low_rank.lower() == 'true': + # # print("=== low rank ===") + + # ###### using low rank structure + # U, S, V = self.m_UV[m] + # NGD_inv = self.m_NGD_Kernel[m] + # G = self.m_G[m][1] + # n = NGD_inv.shape[0] + + # grad_prod = V @ grad_reshape.t().reshape(-1, 1) + # grad_prod = torch.diag(S) @ grad_prod + # grad_prod = U @ grad_prod + # grad_prod = grad_prod.squeeze() + + # bias_update = None + # if m.bias is not None: + # bias_update = m.bias.grad.data + + # v = matmul(NGD_inv, (grad_prod).unsqueeze(1)).squeeze() + + # gv = U.t() @ v.unsqueeze(1) + # gv = torch.diag(S) @ gv + # gv = V.t() @ gv + + # gv = gv.reshape(grad_reshape.shape[1], grad_reshape.shape[0]).t() + # gv = gv.view_as(grad) + # gv = gv / n + + # updates = (grad - gv)/damping, bias_update + # else: + # I = self.m_I[m][1] + # G = self.m_G[m][1] + # AX = einsum('nkl,nml->nkm', (I, G)) + + # del I + # del G + + # n = AX.shape[0] + + # NGD_inv = self.m_NGD_Kernel[m] + + # grad_prod = einsum('nkm,mk->n', (AX, grad_reshape)) + # v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() + # gv = einsum('nkm,n->mk', (AX, v)) + # gv = gv.view_as(grad) + # gv = gv / n + + # bias_update = None + # if m.bias is not None: + # bias_update = m.bias.grad.data + + # updates = (grad - gv) / damping, bias_update + + # del AX + # del NGD_inv + # torch.cuda.empty_cache() + + return updates + + + def _kl_clip_and_update_grad(self, updates, lr): + # do kl clip + + # vg_sum = 0 + + # for m in self.model.modules(): + # classname = m.__class__.__name__ + # if classname in self.known_modules: + # v = updates[m] + # vg_sum += (v[0] * m.weight.grad.data).sum().item() + # if m.bias is not None: + # vg_sum += (v[1] * m.bias.grad.data).sum().item() + # elif classname in ['BatchNorm1d', 'BatchNorm2d']: + # vg_sum += (m.weight.grad.data * m.weight.grad.data).sum().item() + # if m.bias is not None: + # vg_sum += (m.bias.grad.data * m.bias.grad.data).sum().item() + + # vg_sum = vg_sum * (lr ** 2) + + # nu = min(1.0, math.sqrt(self.kl_clip / vg_sum)) + + for m in self.model.modules(): + if m.__class__.__name__ in ['Linear', 'Conv2d']: + v = updates[m] + m.weight.grad.data.copy_(v[0]) + # m.weight.grad.data.mul_(nu) + if v[1] is not None: + m.bias.grad.data.copy_(v[1]) + # m.bias.grad.data.mul_(nu) + # elif m.__class__.__name__ in ['BatchNorm1d', 'BatchNorm2d']: + # m.weight.grad.data.mul_(nu) + # if m.bias is not None: + # m.bias.grad.data.mul_(nu) + + def _step(self, closure): + # FIXME (CW): Modified based on SGD (removed nestrov and dampening in momentum.) + # FIXME (CW): 1. no nesterov, 2. buf.mul_(momentum).add_(1 - dampening , d_p) + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + for p in group['params']: + # print('=== step ===') + if p.grad is None: + continue + d_p = p.grad.data + + # if momentum != 0: + # param_state = self.state[p] + # if 'momentum_buffer' not in param_state: + # buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) + # buf.mul_(momentum).add_(d_p) + # else: + # buf = param_state['momentum_buffer'] + # buf.mul_(momentum).add_(1, d_p) + # d_p.copy_(buf) + + # if weight_decay != 0 and self.steps >= 10 * self.freq: + if weight_decay != 0: + d_p.add_(weight_decay, p.data) + + p.data.add_(-group['lr'], d_p) + # print('d_p:', d_p.shape) + # print(d_p) + + def step(self, closure=None): + group = self.param_groups[0] + lr = group['lr'] + damping = group['damping'] + updates = {} + for m in self.modules: + classname = m.__class__.__name__ + if self.steps % self.freq == 0: + self._update_inv(m) + v = self._get_natural_grad(m, damping) + updates[m] = v + self._kl_clip_and_update_grad(updates, lr) + + self._step(closure) + self.steps += 1 From c1e4e16830d034dc974c57ae79761d2ed3fe097c Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Wed, 15 Sep 2021 03:34:05 -0400 Subject: [PATCH 3/4] Nystrom (rank 1) approx of Fisher + SMW + RS for Conv2d layer. --- optimizers/nystrom.py | 169 +++++++++++++++++++++++++---------------- utils/nystrom_utils.py | 145 +++++++++++++++++++++++++++++++++++ 2 files changed, 249 insertions(+), 65 deletions(-) create mode 100644 utils/nystrom_utils.py diff --git a/optimizers/nystrom.py b/optimizers/nystrom.py index 31d9b33..4051e0c 100644 --- a/optimizers/nystrom.py +++ b/optimizers/nystrom.py @@ -86,7 +86,7 @@ def _update_inv(self, m): classname = m.__class__.__name__.lower() # print('=== _update_inv ===') if classname == 'linear': - assert(m.optimized == True) + # assert(m.optimized == True) I = self.m_I[m][1] # print('I:', I.shape) G = self.m_G[m][1] @@ -128,76 +128,101 @@ def _update_inv(self, m): # self.m_G[m] = (None, self.m_G[m][1]) # torch.cuda.empty_cache() elif classname == 'conv2d': + # assert(m.optimized == True) + I = self.m_I[m][1] + # print('I:', I.shape) + G = self.m_G[m][1] + # print('G:', G.shape) + n = I.shape[0] + J = einsum('ni,no->nio', (I, G)).reshape(n, -1) + # print('J:', J.shape) + p = einsum('np->p', J * J) + # print('p:', p.shape) + p_ = p / torch.sum(p) + i = torch.multinomial(p_, num_samples=1) + # print('i:', i) + Ji = J[:,i].reshape(-1) + # print('Ji:', Ji.shape) + C = einsum('n,np->np', Ji, J) + # print('c:', C.shape) + C = einsum('np->p', C) + # print('C:', C.shape) + w = p[i] + # print('w:', w) + s = math.sqrt(1 / (torch.dot(C, C) + self.damping * w)) + # print('s:', s) + self.m_C[m] = C * s + # SAEED: @TODO: we don't need II and GG after computations, clear the memory - if m.optimized == True: - # print('=== optimized ===') - II = self.m_I[m][0] - GG = self.m_G[m][0] - n = II.shape[0] - - NGD_kernel = None - if self.reduce_sum == 'true': - if self.diag == 'true': - NGD_kernel = (II * GG / n) - NGD_inv = torch.reciprocal(NGD_kernel + self.damping) - else: - NGD_kernel = II * GG / n - NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(II.device)) - else: - NGD_kernel = (einsum('nqlp->nq', II * GG)) / n - NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(II.device)) - - self.m_NGD_Kernel[m] = NGD_inv - - self.m_I[m] = (None, self.m_I[m][1]) - self.m_G[m] = (None, self.m_G[m][1]) - torch.cuda.empty_cache() - else: - # SAEED: @TODO memory cleanup - I = self.m_I[m][1] - G = self.m_G[m][1] - n = I.shape[0] - AX = einsum("nkl,nml->nkm", (I, G)) - - del I - del G - - AX_ = AX.reshape(n , -1) - out = matmul(AX_, AX_.t()) - - del AX - - NGD_kernel = out / n - ### low-rank approximation of Jacobian - if self.low_rank == 'true': - # print('=== low rank ===') - V, S, U = svd(AX_.T, full_matrices=False) - U = U.t() - V = V.t() - cs = cumsum(S, dim = 0) - sum_s = sum(S) - index = ((cs - self.gamma * sum_s) <= 0).sum() - U = U[:, 0:index] - S = S[0:index] - V = V[0:index, :] - self.m_UV[m] = U, S, V - - del AX_ - - NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(NGD_kernel.device)) - self.m_NGD_Kernel[m] = NGD_inv - - del NGD_inv - self.m_I[m] = None, self.m_I[m][1] - self.m_G[m] = None, self.m_G[m][1] - torch.cuda.empty_cache() + # if m.optimized == True: + # # print('=== optimized ===') + # II = self.m_I[m][0] + # GG = self.m_G[m][0] + # n = II.shape[0] + + # NGD_kernel = None + # if self.reduce_sum == 'true': + # if self.diag == 'true': + # NGD_kernel = (II * GG / n) + # NGD_inv = torch.reciprocal(NGD_kernel + self.damping) + # else: + # NGD_kernel = II * GG / n + # NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(II.device)) + # else: + # NGD_kernel = (einsum('nqlp->nq', II * GG)) / n + # NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(II.device)) + + # self.m_NGD_Kernel[m] = NGD_inv + + # self.m_I[m] = (None, self.m_I[m][1]) + # self.m_G[m] = (None, self.m_G[m][1]) + # torch.cuda.empty_cache() + # else: + # # SAEED: @TODO memory cleanup + # I = self.m_I[m][1] + # G = self.m_G[m][1] + # n = I.shape[0] + # AX = einsum("nkl,nml->nkm", (I, G)) + + # del I + # del G + + # AX_ = AX.reshape(n , -1) + # out = matmul(AX_, AX_.t()) + + # del AX + + # NGD_kernel = out / n + # ### low-rank approximation of Jacobian + # if self.low_rank == 'true': + # # print('=== low rank ===') + # V, S, U = svd(AX_.T, full_matrices=False) + # U = U.t() + # V = V.t() + # cs = cumsum(S, dim = 0) + # sum_s = sum(S) + # index = ((cs - self.gamma * sum_s) <= 0).sum() + # U = U[:, 0:index] + # S = S[0:index] + # V = V[0:index, :] + # self.m_UV[m] = U, S, V + + # del AX_ + + # NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(NGD_kernel.device)) + # self.m_NGD_Kernel[m] = NGD_inv + + # del NGD_inv + # self.m_I[m] = None, self.m_I[m][1] + # self.m_G[m] = None, self.m_G[m][1] + # torch.cuda.empty_cache() def _get_natural_grad(self, m, damping): grad = m.weight.grad.data classname = m.__class__.__name__.lower() if classname == 'linear': - assert(m.optimized == True) + # assert(m.optimized == True) I = self.m_I[m][1] G = self.m_G[m][1] n = I.shape[0] @@ -239,7 +264,21 @@ def _get_natural_grad(self, m, damping): # updates = (grad - gv)/damping, bias_update elif classname == 'conv2d': - raise NotImplementedError + I = self.m_I[m][1] + G = self.m_G[m][1] + n = I.shape[0] + + # print('grad:', grad.shape) + g = grad.reshape(-1) + # print('g:', g.shape) + C = self.m_C[m] + # print('C:', C.shape) + v = torch.dot(C, g) * C + # print('v:', v.shape) + v = v.view_as(grad) + # print('v.view:', v.shape) + + updates = (grad - v) / damping, None # grad_reshape = grad.reshape(grad.shape[0], -1) # if m.optimized == True: # # print('=== optimized ===') diff --git a/utils/nystrom_utils.py b/utils/nystrom_utils.py new file mode 100644 index 0000000..67502ed --- /dev/null +++ b/utils/nystrom_utils.py @@ -0,0 +1,145 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum, matmul +from torch.nn import Unfold + +class ComputeI: + + @classmethod + def compute_cov_a(cls, a, module, super_opt='false', reduce_sum='false', diag='false'): + return cls.__call__(a, module, super_opt, reduce_sum, diag) + + @classmethod + def __call__(cls, a, module, super_opt='false', reduce_sum='false', diag='false'): + if isinstance(module, nn.Linear): + II, I = cls.linear(a, module, super_opt, reduce_sum, diag) + return II, I + elif isinstance(module, nn.Conv2d): + II, I = cls.conv2d(a, module, super_opt, reduce_sum, diag) + return II, I + else: + # FIXME(CW): for extension to other layers. + # raise NotImplementedError + return None + + @staticmethod + def conv2d(input, module, super_opt='false', reduce_sum='false', diag='false'): + f = Unfold( + kernel_size=module.kernel_size, + dilation=module.dilation, + padding=module.padding, + stride=module.stride, + ) + I = f(input) + I = einsum('nil->ni', I) + # N = I.shape[0] + # K = I.shape[1] + # L = I.shape[2] + # M = module.out_channels + # module.param_shapes = [N, K, L, M] + + return None, I + + # if reduce_sum == 'true': + # I = einsum("nkl->nk", I) + # if diag == 'true': + # I /= L + # II = torch.sum(I * I, dim=1) + # else: + # II = einsum("nk,qk->nq", (I, I)) + # module.optimized = True + # return II, I + + # flag = False + # if super_opt == 'true': + # flag = N * (L * L) * (K + M) < K * M * L + N * K * M + # else: + # flag = (L * L) * (K + M) < K * M + + # if flag == True: + # II = einsum("nkl,qkp->nqlp", (I, I)) + # module.optimized = True + # return II, I + # else: + # module.optimized = False + # return None, I + + @staticmethod + def linear(input, module, super_opt='false', reduce_sum='false', diag='false'): + I = input + # II = einsum("ni,li->nl", (I, I)) + module.optimized = True + II = None + return II, I + +class ComputeG: + + @classmethod + def compute_cov_g(cls, g, module, super_opt='false', reduce_sum='false', diag='false'): + """ + :param g: gradient + :param module: the corresponding module + :return: + """ + return cls.__call__(g, module, super_opt, reduce_sum, diag) + + @classmethod + def __call__(cls, g, module, super_opt='false', reduce_sum='false', diag='false'): + if isinstance(module, nn.Conv2d): + GG, G = cls.conv2d(g, module, super_opt, reduce_sum, diag) + return GG, G + elif isinstance(module, nn.Linear): + GG, G = cls.linear(g, module, super_opt, reduce_sum, diag) + return GG, G + else: + return None + + + @staticmethod + def conv2d(g, module, super_opt='false', reduce_sum='false', diag='false'): + n = g.shape[0] + g_out_sc = n * g + grad_output_viewed = g_out_sc.reshape(g_out_sc.shape[0], g_out_sc.shape[1], -1) + G = grad_output_viewed + G = einsum('nol->no', G) + return None, G + + # N = module.param_shapes[0] + # K = module.param_shapes[1] + # L = module.param_shapes[2] + # M = module.param_shapes[3] + + # if reduce_sum == 'true': + # G = einsum("nkl->nk", G) + # if diag == 'true': + # G /= L + # GG = torch.sum(G * G, dim=1) + # else: + # GG = einsum("nk,qk->nq", (G, G)) + # module.optimized = True + # return GG, G + + # flag = False + # if super_opt == 'true': + # flag = N * (L * L) * (K + M) < K * M * L + N * K * M + # else: + # flag = (L * L) * (K + M) < K * M + + # if flag == True : + # GG = einsum("nml,qmp->nqlp", (G, G)) + # module.optimized = True + # return GG, G + # else: + # module.optimized = False + # return None, G + + @staticmethod + def linear(g, module, super_opt='false', reduce_sum='false', diag='false'): + n = g.shape[0] + g_out_sc = n * g + G = g_out_sc + # GG = einsum("no,lo->nl", (G, G)) + module.optimized = True + GG = None + return GG, G From be37bc6eded03ad018b42df0a4a32ec5654db8a6 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Wed, 15 Sep 2021 11:27:11 -0400 Subject: [PATCH 4/4] clean up. --- main.py | 16 +-- optimizers/nystrom.py | 287 ++--------------------------------------- utils/nystrom_utils.py | 95 +++----------- 3 files changed, 32 insertions(+), 366 deletions(-) diff --git a/main.py b/main.py index f3b096c..0e2d825 100644 --- a/main.py +++ b/main.py @@ -236,17 +236,11 @@ elif optim_name == 'nystrom': print('Nystrom optimizer selected') optimizer = NystromOptimizer(net, - lr=args.learning_rate, - momentum=args.momentum, - damping=args.damping, - kl_clip=args.kl_clip, - weight_decay=args.weight_decay, - freq=args.freq, - gamma=args.gamma, - low_rank=args.low_rank, - super_opt=args.super_opt, - reduce_sum=args.reduce_sum, - diag=args.diag) + lr=args.learning_rate, + momentum=args.momentum, + damping=args.damping, + weight_decay=args.weight_decay, + freq=args.freq) elif optim_name == 'ngd_stream': # SAEED: TODO fix batchnorm or remove it totally diff --git a/optimizers/nystrom.py b/optimizers/nystrom.py index 4051e0c..491b8c8 100644 --- a/optimizers/nystrom.py +++ b/optimizers/nystrom.py @@ -13,14 +13,8 @@ def __init__(self, lr=0.01, momentum=0.9, damping=0.1, - kl_clip=0.01, weight_decay=0.003, - freq=100, - gamma=0.9, - low_rank='true', - super_opt='false', - reduce_sum='false', - diag='false'): + freq=100): if lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) if momentum < 0.0: @@ -44,29 +38,20 @@ def __init__(self, self.m_C = {} self.m_I = {} self.m_G = {} - self.m_UV = {} - self.m_NGD_Kernel = {} - self.m_bias_Kernel = {} - self.kl_clip = kl_clip self.freq = freq - self.gamma = gamma - self.low_rank = low_rank - self.super_opt = super_opt - self.reduce_sum = reduce_sum - self.diag = diag self.damping = damping def _save_input(self, module, input): # storing the optimized input in forward pass if torch.is_grad_enabled() and self.steps % self.freq == 0: - II, I = self.IHandler(input[0].data, module, self.super_opt, self.reduce_sum, self.diag) + II, I = self.IHandler(input[0].data, module) self.m_I[module] = II, I def _save_grad_output(self, module, grad_input, grad_output): # storing the optimized gradients in backward pass if self.acc_stats and self.steps % self.freq == 0: - GG, G = self.GHandler(grad_output[0].data, module, self.super_opt, self.reduce_sum, self.diag) + GG, G = self.GHandler(grad_output[0].data, module) self.m_G[module] = GG, G def _prepare_model(self): @@ -84,293 +69,43 @@ def _prepare_model(self): def _update_inv(self, m): classname = m.__class__.__name__.lower() - # print('=== _update_inv ===') - if classname == 'linear': - # assert(m.optimized == True) + if classname in ['linear', 'conv2d']: I = self.m_I[m][1] - # print('I:', I.shape) G = self.m_G[m][1] - # print('G:', G.shape) n = I.shape[0] - J = einsum('ni,no->nio', (I, G)).reshape(n, -1) - # print('J:', J.shape) - p = einsum('np->p', J * J) - # print('p:', p.shape) - p_ = p / torch.sum(p) - i = torch.multinomial(p_, num_samples=1) - # print('i:', i) - Ji = J[:,i].reshape(-1) - # print('Ji:', Ji.shape) - C = einsum('n,np->np', Ji, J) - # print('c:', C.shape) - C = einsum('np->p', C) - # print('C:', C.shape) - w = p[i] - # print('w:', w) - s = math.sqrt(1 / (torch.dot(C, C) + self.damping * w)) - # print('s:', s) - self.m_C[m] = C * s - - # II = self.m_I[m][0] - # GG = self.m_G[m][0] - # n = II.shape[0] - - # ### bias kernel is GG (II = all ones) - # bias_kernel = GG / n - # bias_inv = inv(bias_kernel + self.damping * eye(n).to(GG.device)) - # self.m_bias_Kernel[m] = bias_inv - # NGD_kernel = (II * GG) / n - # NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(II.device)) - # self.m_NGD_Kernel[m] = NGD_inv - - # self.m_I[m] = (None, self.m_I[m][1]) - # self.m_G[m] = (None, self.m_G[m][1]) - # torch.cuda.empty_cache() - elif classname == 'conv2d': - # assert(m.optimized == True) - I = self.m_I[m][1] - # print('I:', I.shape) - G = self.m_G[m][1] - # print('G:', G.shape) - n = I.shape[0] J = einsum('ni,no->nio', (I, G)).reshape(n, -1) - # print('J:', J.shape) p = einsum('np->p', J * J) - # print('p:', p.shape) p_ = p / torch.sum(p) i = torch.multinomial(p_, num_samples=1) - # print('i:', i) Ji = J[:,i].reshape(-1) - # print('Ji:', Ji.shape) + C = einsum('n,np->np', Ji, J) - # print('c:', C.shape) C = einsum('np->p', C) - # print('C:', C.shape) w = p[i] - # print('w:', w) s = math.sqrt(1 / (torch.dot(C, C) + self.damping * w)) - # print('s:', s) - self.m_C[m] = C * s - - # SAEED: @TODO: we don't need II and GG after computations, clear the memory - # if m.optimized == True: - # # print('=== optimized ===') - # II = self.m_I[m][0] - # GG = self.m_G[m][0] - # n = II.shape[0] - - # NGD_kernel = None - # if self.reduce_sum == 'true': - # if self.diag == 'true': - # NGD_kernel = (II * GG / n) - # NGD_inv = torch.reciprocal(NGD_kernel + self.damping) - # else: - # NGD_kernel = II * GG / n - # NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(II.device)) - # else: - # NGD_kernel = (einsum('nqlp->nq', II * GG)) / n - # NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(II.device)) - - # self.m_NGD_Kernel[m] = NGD_inv - - # self.m_I[m] = (None, self.m_I[m][1]) - # self.m_G[m] = (None, self.m_G[m][1]) - # torch.cuda.empty_cache() - # else: - # # SAEED: @TODO memory cleanup - # I = self.m_I[m][1] - # G = self.m_G[m][1] - # n = I.shape[0] - # AX = einsum("nkl,nml->nkm", (I, G)) - - # del I - # del G - - # AX_ = AX.reshape(n , -1) - # out = matmul(AX_, AX_.t()) - # del AX - - # NGD_kernel = out / n - # ### low-rank approximation of Jacobian - # if self.low_rank == 'true': - # # print('=== low rank ===') - # V, S, U = svd(AX_.T, full_matrices=False) - # U = U.t() - # V = V.t() - # cs = cumsum(S, dim = 0) - # sum_s = sum(S) - # index = ((cs - self.gamma * sum_s) <= 0).sum() - # U = U[:, 0:index] - # S = S[0:index] - # V = V[0:index, :] - # self.m_UV[m] = U, S, V - - # del AX_ + self.m_C[m] = C * s - # NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(NGD_kernel.device)) - # self.m_NGD_Kernel[m] = NGD_inv + else: + raise NotImplementedError - # del NGD_inv - # self.m_I[m] = None, self.m_I[m][1] - # self.m_G[m] = None, self.m_G[m][1] - # torch.cuda.empty_cache() def _get_natural_grad(self, m, damping): grad = m.weight.grad.data classname = m.__class__.__name__.lower() - if classname == 'linear': - # assert(m.optimized == True) - I = self.m_I[m][1] - G = self.m_G[m][1] - n = I.shape[0] - - # print('grad:', grad.shape) - g = grad.reshape(-1) - # print('g:', g.shape) - C = self.m_C[m] - # print('C:', C.shape) - v = torch.dot(C, g) * C - # print('v:', v.shape) - v = v.view_as(grad) - # print('v.view:', v.shape) - - updates = (grad - v) / damping, None - - # NGD_inv = self.m_NGD_Kernel[m] - # grad_prod = einsum("ni,oi->no", (I, grad)) - # grad_prod = einsum("no,no->n", (grad_prod, G)) - - # v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() - - # gv = einsum("n,no->no", (v, G)) - # gv = einsum("no,ni->oi", (gv, I)) - # gv = gv / n - - # bias_update = None - # if m.bias is not None: - # grad_bias = m.bias.grad.data - # if self.steps % self.freq == 0: - # grad_prod_bias = einsum("o,no->n", (grad_bias, G)) - # v = matmul(self.m_bias_Kernel[m], grad_prod_bias.unsqueeze(1)).squeeze() - # gv_bias = einsum('n,no->o', (v, G)) - # gv_bias = gv_bias / n - # bias_update = (grad_bias - gv_bias) / damping - # else: - # bias_update = grad_bias - - # updates = (grad - gv)/damping, bias_update - - elif classname == 'conv2d': + if classname in ['linear', 'conv2d']: I = self.m_I[m][1] G = self.m_G[m][1] n = I.shape[0] - # print('grad:', grad.shape) g = grad.reshape(-1) - # print('g:', g.shape) C = self.m_C[m] - # print('C:', C.shape) v = torch.dot(C, g) * C - # print('v:', v.shape) v = v.view_as(grad) - # print('v.view:', v.shape) updates = (grad - v) / damping, None - # grad_reshape = grad.reshape(grad.shape[0], -1) - # if m.optimized == True: - # # print('=== optimized ===') - # I = self.m_I[m][1] - # G = self.m_G[m][1] - # n = I.shape[0] - # NGD_inv = self.m_NGD_Kernel[m] - - # if self.reduce_sum == 'true': - # x1 = einsum("nk,mk->nm", (I, grad_reshape)) - # grad_prod = einsum("nm,nm->n", (x1, G)) - - # if self.diag == 'true': - # v = NGD_inv * grad_prod - # else: - # v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() - - # gv = einsum("n,nm->nm", (v, G)) - # gv = einsum("nm,nk->mk", (gv, I)) - # else: - # x1 = einsum("nkl,mk->nml", (I, grad_reshape)) - # grad_prod = einsum("nml,nml->n", (x1, G)) - # v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() - # gv = einsum("n,nml->nml", (v, G)) - # gv = einsum("nml,nkl->mk", (gv, I)) - # gv = gv.view_as(grad) - # gv = gv / n - - # bias_update = None - # if m.bias is not None: - # bias_update = m.bias.grad.data - - # updates = (grad - gv)/damping, bias_update - - # else: - # # TODO(bmu): fix low rank - # if self.low_rank.lower() == 'true': - # # print("=== low rank ===") - - # ###### using low rank structure - # U, S, V = self.m_UV[m] - # NGD_inv = self.m_NGD_Kernel[m] - # G = self.m_G[m][1] - # n = NGD_inv.shape[0] - - # grad_prod = V @ grad_reshape.t().reshape(-1, 1) - # grad_prod = torch.diag(S) @ grad_prod - # grad_prod = U @ grad_prod - # grad_prod = grad_prod.squeeze() - - # bias_update = None - # if m.bias is not None: - # bias_update = m.bias.grad.data - - # v = matmul(NGD_inv, (grad_prod).unsqueeze(1)).squeeze() - - # gv = U.t() @ v.unsqueeze(1) - # gv = torch.diag(S) @ gv - # gv = V.t() @ gv - - # gv = gv.reshape(grad_reshape.shape[1], grad_reshape.shape[0]).t() - # gv = gv.view_as(grad) - # gv = gv / n - - # updates = (grad - gv)/damping, bias_update - # else: - # I = self.m_I[m][1] - # G = self.m_G[m][1] - # AX = einsum('nkl,nml->nkm', (I, G)) - - # del I - # del G - - # n = AX.shape[0] - - # NGD_inv = self.m_NGD_Kernel[m] - - # grad_prod = einsum('nkm,mk->n', (AX, grad_reshape)) - # v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() - # gv = einsum('nkm,n->mk', (AX, v)) - # gv = gv.view_as(grad) - # gv = gv / n - - # bias_update = None - # if m.bias is not None: - # bias_update = m.bias.grad.data - - # updates = (grad - gv) / damping, bias_update - - # del AX - # del NGD_inv - # torch.cuda.empty_cache() return updates @@ -401,8 +136,8 @@ def _kl_clip_and_update_grad(self, updates, lr): v = updates[m] m.weight.grad.data.copy_(v[0]) # m.weight.grad.data.mul_(nu) - if v[1] is not None: - m.bias.grad.data.copy_(v[1]) + # if v[1] is not None: + # m.bias.grad.data.copy_(v[1]) # m.bias.grad.data.mul_(nu) # elif m.__class__.__name__ in ['BatchNorm1d', 'BatchNorm2d']: # m.weight.grad.data.mul_(nu) diff --git a/utils/nystrom_utils.py b/utils/nystrom_utils.py index 67502ed..45f02a3 100644 --- a/utils/nystrom_utils.py +++ b/utils/nystrom_utils.py @@ -7,16 +7,16 @@ class ComputeI: @classmethod - def compute_cov_a(cls, a, module, super_opt='false', reduce_sum='false', diag='false'): - return cls.__call__(a, module, super_opt, reduce_sum, diag) + def compute_cov_a(cls, a, module): + return cls.__call__(a, module) @classmethod - def __call__(cls, a, module, super_opt='false', reduce_sum='false', diag='false'): + def __call__(cls, a, module): if isinstance(module, nn.Linear): - II, I = cls.linear(a, module, super_opt, reduce_sum, diag) + II, I = cls.linear(a, module) return II, I elif isinstance(module, nn.Conv2d): - II, I = cls.conv2d(a, module, super_opt, reduce_sum, diag) + II, I = cls.conv2d(a, module) return II, I else: # FIXME(CW): for extension to other layers. @@ -24,7 +24,7 @@ def __call__(cls, a, module, super_opt='false', reduce_sum='false', diag='false' return None @staticmethod - def conv2d(input, module, super_opt='false', reduce_sum='false', diag='false'): + def conv2d(input, module): f = Unfold( kernel_size=module.kernel_size, dilation=module.dilation, @@ -33,71 +33,38 @@ def conv2d(input, module, super_opt='false', reduce_sum='false', diag='false'): ) I = f(input) I = einsum('nil->ni', I) - # N = I.shape[0] - # K = I.shape[1] - # L = I.shape[2] - # M = module.out_channels - # module.param_shapes = [N, K, L, M] - return None, I - - # if reduce_sum == 'true': - # I = einsum("nkl->nk", I) - # if diag == 'true': - # I /= L - # II = torch.sum(I * I, dim=1) - # else: - # II = einsum("nk,qk->nq", (I, I)) - # module.optimized = True - # return II, I - - # flag = False - # if super_opt == 'true': - # flag = N * (L * L) * (K + M) < K * M * L + N * K * M - # else: - # flag = (L * L) * (K + M) < K * M - - # if flag == True: - # II = einsum("nkl,qkp->nqlp", (I, I)) - # module.optimized = True - # return II, I - # else: - # module.optimized = False - # return None, I @staticmethod - def linear(input, module, super_opt='false', reduce_sum='false', diag='false'): - I = input - # II = einsum("ni,li->nl", (I, I)) - module.optimized = True - II = None - return II, I + def linear(input, module): + I = input + return None, I class ComputeG: @classmethod - def compute_cov_g(cls, g, module, super_opt='false', reduce_sum='false', diag='false'): + def compute_cov_g(cls, g, module): """ :param g: gradient :param module: the corresponding module :return: """ - return cls.__call__(g, module, super_opt, reduce_sum, diag) + return cls.__call__(g, module) @classmethod - def __call__(cls, g, module, super_opt='false', reduce_sum='false', diag='false'): + def __call__(cls, g, module): if isinstance(module, nn.Conv2d): - GG, G = cls.conv2d(g, module, super_opt, reduce_sum, diag) + GG, G = cls.conv2d(g, module) return GG, G elif isinstance(module, nn.Linear): - GG, G = cls.linear(g, module, super_opt, reduce_sum, diag) + GG, G = cls.linear(g, module) return GG, G else: return None @staticmethod - def conv2d(g, module, super_opt='false', reduce_sum='false', diag='false'): + def conv2d(g, module): n = g.shape[0] g_out_sc = n * g grad_output_viewed = g_out_sc.reshape(g_out_sc.shape[0], g_out_sc.shape[1], -1) @@ -105,41 +72,11 @@ def conv2d(g, module, super_opt='false', reduce_sum='false', diag='false'): G = einsum('nol->no', G) return None, G - # N = module.param_shapes[0] - # K = module.param_shapes[1] - # L = module.param_shapes[2] - # M = module.param_shapes[3] - - # if reduce_sum == 'true': - # G = einsum("nkl->nk", G) - # if diag == 'true': - # G /= L - # GG = torch.sum(G * G, dim=1) - # else: - # GG = einsum("nk,qk->nq", (G, G)) - # module.optimized = True - # return GG, G - - # flag = False - # if super_opt == 'true': - # flag = N * (L * L) * (K + M) < K * M * L + N * K * M - # else: - # flag = (L * L) * (K + M) < K * M - - # if flag == True : - # GG = einsum("nml,qmp->nqlp", (G, G)) - # module.optimized = True - # return GG, G - # else: - # module.optimized = False - # return None, G - @staticmethod - def linear(g, module, super_opt='false', reduce_sum='false', diag='false'): + def linear(g, module): n = g.shape[0] g_out_sc = n * g G = g_out_sc - # GG = einsum("no,lo->nl", (G, G)) module.optimized = True GG = None return GG, G