From 53e24311d9107e48bf063e57fbdbcab94f8c59c9 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Sun, 13 Jun 2021 22:49:37 -0400 Subject: [PATCH 1/9] Add LayerNorm in 3C1F for Fashion-MNIST. --- models/mnist/convnet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/models/mnist/convnet.py b/models/mnist/convnet.py index 3506f11..577395e 100644 --- a/models/mnist/convnet.py +++ b/models/mnist/convnet.py @@ -12,14 +12,18 @@ def __init__(self, num_classes=10, **kwargs): super(ConvNet, self).__init__() self.features = nn.Sequential( nn.Conv2d(1, 128, kernel_size=3, stride=1, padding=1), + nn.LayerNorm([128, 28, 28]), nn.ReLU(), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), + nn.LayerNorm([128, 28, 28]), nn.ReLU(), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), + nn.LayerNorm([128, 28, 28]), nn.ReLU(), nn.MaxPool2d(kernel_size=3), nn.Flatten(), nn.Linear(9*9*128, 500), + nn.LayerNorm([500]), nn.ReLU(), nn.Linear(500, 10), ) From 26738a2d03b2439400ea3ea4501e135f18daec42 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Mon, 14 Jun 2021 04:31:15 -0400 Subject: [PATCH 2/9] Try take mean and var over surface only (not volume) for 4D tensor. --- models/mnist/convnet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/models/mnist/convnet.py b/models/mnist/convnet.py index 577395e..4a9f5c1 100644 --- a/models/mnist/convnet.py +++ b/models/mnist/convnet.py @@ -12,13 +12,13 @@ def __init__(self, num_classes=10, **kwargs): super(ConvNet, self).__init__() self.features = nn.Sequential( nn.Conv2d(1, 128, kernel_size=3, stride=1, padding=1), - nn.LayerNorm([128, 28, 28]), + nn.LayerNorm([28, 28]), nn.ReLU(), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), - nn.LayerNorm([128, 28, 28]), + nn.LayerNorm([28, 28]), nn.ReLU(), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), - nn.LayerNorm([128, 28, 28]), + nn.LayerNorm([28, 28]), nn.ReLU(), nn.MaxPool2d(kernel_size=3), nn.Flatten(), From 781daaf6425bc1df7b7054b2c5e719ccf7550626 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Mon, 14 Jun 2021 04:32:41 -0400 Subject: [PATCH 3/9] Add LayerNorm to main loop. --- main.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 4b894d4..4af2f62 100644 --- a/main.py +++ b/main.py @@ -587,7 +587,30 @@ def closure(): gv = gv / n update = (grad - gv)/damp m.weight.grad.copy_(update) - + elif isinstance(m, nn.LayerNorm): + I, G = m.I, m.G + mean = I.mean(dim=-1).unsqueeze(-1) + var = I.var(dim=-1, unbiased=False).unsqueeze(-1) + x_hat = (I - mean) / (var + m.eps).sqrt() + + if len(I.shape) == 2: + J = G * x_hat + else: + J = torch.einsum('ncf,ncf->nf', G, x_hat) + J = J.reshape(J.shape[0], -1) + JJT = torch.matmul(J, J.t()) + + grad_prod = torch.matmul(J, grad.reshape(-1)) + + NGD_kernel = JJT / n + NGD_inv = torch.linalg.inv(NGD_kernel + damp * torch.eye(n).to(grad.device)) + v = torch.matmul(NGD_inv, grad_prod) + + gv = torch.matmul(J.t(), v) / n + + update = (grad.reshape(-1) - gv) / damp + update = update.reshape(m.weight.grad.shape) + m.weight.grad.copy_(update) # last part of SMW formula From 8825596f7078407a190383f24ce6841706474db6 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Wed, 16 Jun 2021 13:31:15 -0400 Subject: [PATCH 4/9] Add bias. --- main.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index 4af2f62..ca572c4 100644 --- a/main.py +++ b/main.py @@ -589,14 +589,15 @@ def closure(): m.weight.grad.copy_(update) elif isinstance(m, nn.LayerNorm): I, G = m.I, m.G - mean = I.mean(dim=-1).unsqueeze(-1) - var = I.var(dim=-1, unbiased=False).unsqueeze(-1) - x_hat = (I - mean) / (var + m.eps).sqrt() - if len(I.shape) == 2: - J = G * x_hat + mean = I.mean(dim=-1).unsqueeze(-1) + var = I.var(dim=-1, unbiased=False).unsqueeze(-1) else: - J = torch.einsum('ncf,ncf->nf', G, x_hat) + mean = I.mean((-2, -1), keepdims=True) + var = I.var((-2, -1), unbiased=False, keepdims=True) + x_hat = (I - mean) / (var + m.eps).sqrt() + + J = G * x_hat J = J.reshape(J.shape[0], -1) JJT = torch.matmul(J, J.t()) @@ -611,6 +612,24 @@ def closure(): update = (grad.reshape(-1) - gv) / damp update = update.reshape(m.weight.grad.shape) m.weight.grad.copy_(update) + + grad = m.bias.grad.reshape(-1) + + J = G + J = J.reshape(J.shape[0], -1) + JJT = torch.matmul(J, J.t()) + + grad_prod = torch.matmul(J, grad) + + NGD_kernel = JJT / n + NGD_inv = torch.linalg.inv(NGD_kernel + damp * torch.eye(n).to(grad.device)) + v = torch.matmul(NGD_inv, grad_prod) + + gv = torch.matmul(J.t(), v) / n + + update = (grad - gv) / damp + update = update.reshape(m.bias.grad.shape) + m.bias.grad.copy_(update) # last part of SMW formula From 950a0854db958a139fe9c34d68ba838d521729bc Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Wed, 16 Jun 2021 23:02:00 -0400 Subject: [PATCH 5/9] Update model. --- models/mnist/convnet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/models/mnist/convnet.py b/models/mnist/convnet.py index 4a9f5c1..577395e 100644 --- a/models/mnist/convnet.py +++ b/models/mnist/convnet.py @@ -12,13 +12,13 @@ def __init__(self, num_classes=10, **kwargs): super(ConvNet, self).__init__() self.features = nn.Sequential( nn.Conv2d(1, 128, kernel_size=3, stride=1, padding=1), - nn.LayerNorm([28, 28]), + nn.LayerNorm([128, 28, 28]), nn.ReLU(), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), - nn.LayerNorm([28, 28]), + nn.LayerNorm([128, 28, 28]), nn.ReLU(), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), - nn.LayerNorm([28, 28]), + nn.LayerNorm([128, 28, 28]), nn.ReLU(), nn.MaxPool2d(kernel_size=3), nn.Flatten(), From f94a1b2638cbaf7da73a365b2f008a7254788e50 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Thu, 17 Jun 2021 08:13:22 -0400 Subject: [PATCH 6/9] Update model. --- models/mnist/toy.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/models/mnist/toy.py b/models/mnist/toy.py index b632b1e..dd2c795 100644 --- a/models/mnist/toy.py +++ b/models/mnist/toy.py @@ -11,12 +11,15 @@ def __init__(self, num_classes=10, **kwargs): super(ToyNet, self).__init__() self.features = nn.Sequential( nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), + nn.LayerNorm([16, 28, 28]), nn.ReLU(), nn.MaxPool2d(kernel_size=3), nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1), + nn.LayerNorm([16, 9, 9]), nn.ReLU(), nn.MaxPool2d(kernel_size=3), nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1), + nn.LayerNorm([16, 3, 3]), nn.ReLU(), nn.Flatten(), nn.Linear(3*3*16, 10) From 9178beb9e5c9b8e75f5e856c33ceb6c4b1b3cedc Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Thu, 17 Jun 2021 08:14:21 -0400 Subject: [PATCH 7/9] Save Fisher inv block of LayerNorm. --- main.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/main.py b/main.py index ca572c4..c42ea21 100644 --- a/main.py +++ b/main.py @@ -728,6 +728,24 @@ def closure(): np.save(f, ((torch.eye(JTDJ.size(0)).to(JTDJ.device) - JTDJ) / damping).cpu().numpy()) count += 1 + elif m.__class__.__name__ == 'LayerNorm': + with torch.no_grad(): + I, G = m.I, m.G + if len(I.shape) == 2: + mean = I.mean(dim=-1).unsqueeze(-1) + var = I.var(dim=-1, unbiased=False).unsqueeze(-1) + else: + mean = I.mean((-2, -1), keepdims=True) + var = I.var((-2, -1), unbiased=False, keepdims=True) + x_hat = (I - mean) / (var + m.eps).sqrt() + + J = G * x_hat + J = J.reshape(J.shape[0], -1) + JTDJ = torch.matmul(J.t(), torch.matmul(m.NGD_inv, J)) / args.batch_size + + with open('ngd/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f: + np.save(f, ((torch.eye(JTDJ.size(0)).to(JTDJ.device) - JTDJ) / damping).cpu().numpy()) + count += 1 elif optim_name == 'exact_ngd': for m in all_modules: if m.__class__.__name__ in ['Conv2d', 'Linear']: From b6a1ab82b4d6dc46d3254d6135c31a41469f3908 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Fri, 18 Jun 2021 23:31:35 -0400 Subject: [PATCH 8/9] Add save NGD_kernel option. --- main.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index c42ea21..d6ea5b8 100644 --- a/main.py +++ b/main.py @@ -50,6 +50,7 @@ parser.add_argument('--load_path', default='', type=str) parser.add_argument('--log_dir', default='runs/pretrain', type=str) parser.add_argument('--save_inv', default='false', type=str) +parser.add_argument('--save_kernel', default='false', type=str) parser.add_argument('--optimizer', default='kfac', type=str) @@ -182,6 +183,8 @@ buf[name] = torch.zeros_like(param.data).to(args.device) if args.save_inv == 'true': os.mkdir('ngd') + if args.save_kernel == 'true': + os.mkdir('ngd_kernel') elif optim_name == 'exact_ngd': print('Exact NGD optimizer selected.') @@ -457,7 +460,7 @@ def closure(): sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device) if args.trial == 'true': - update_list, loss = optimal_JJT_v2(outputs, sampled_y, args.batch_size, damping=damp, alpha=0.95, low_rank=args.low_rank, gamma=args.gamma, memory_efficient=args.memory_efficient, super_opt=args.super_opt) + update_list, loss = optimal_JJT_v2(outputs, sampled_y, args.batch_size, damping=damp, alpha=0.95, low_rank=args.low_rank, gamma=args.gamma, memory_efficient=args.memory_efficient, super_opt=args.super_opt, save_kernel=args.save_kernel) else: update_list, loss = optimal_JJT(outputs, sampled_y, args.batch_size, damping=damp, alpha=0.95, low_rank=args.low_rank, gamma=args.gamma, memory_efficient=args.memory_efficient) @@ -696,6 +699,22 @@ def closure(): train_loss = train_loss/(batch_idx + 1) if args.step_info == 'true': TRAIN_INFO['epoch_time'].append(float("{:.4f}".format(epoch_time))) + + # save NGD kernels + if args.save_kernel == 'true' and optim_name == 'ngd': + if module_names == 'children': + all_modules = net.children() + elif module_names == 'features': + all_modules = net.features.children() + + count = 0 + for m in all_modules: + if m.__class__.__name__ in ['Linear', 'Conv2d', 'LayerNorm']: + if hasattr(m, "NGD_kernel"): + with open('ngd_kernel/' + str(epoch) + '_m_' + str(count) + '_kernel.npy', 'wb') as f: + np.save(f, m.NGD_kernel.cpu().numpy()) + count += 1 + # save diagonal blocks of exact Fisher inverse or its approximations if args.save_inv == 'true': if module_names == 'children': @@ -835,11 +854,11 @@ def optimal_JJT(outputs, targets, batch_size, damping=1.0, alpha=0.95, low_rank= update_list[name] = fisher_vals[2] return update_list, loss -def optimal_JJT_v2(outputs, targets, batch_size, damping=1.0, alpha=0.95, low_rank='false', gamma=0.95, memory_efficient='false', super_opt='false'): +def optimal_JJT_v2(outputs, targets, batch_size, damping=1.0, alpha=0.95, low_rank='false', gamma=0.95, memory_efficient='false', super_opt='false', save_kernel='false'): jac_list = 0 vjp = 0 update_list = {} - with backpack(FisherBlockEff(damping, alpha, low_rank, gamma, memory_efficient, super_opt)): + with backpack(FisherBlockEff(damping, alpha, low_rank, gamma, memory_efficient, super_opt, save_kernel)): loss = criterion(outputs, targets) loss.backward() for name, param in net.named_parameters(): From 8103e989e0aa790c0b0079bce73373a6f1573add Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Fri, 18 Jun 2021 23:42:23 -0400 Subject: [PATCH 9/9] Try no learnable parameter for LayerNorm. --- models/mnist/convnet.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/models/mnist/convnet.py b/models/mnist/convnet.py index 577395e..b667634 100644 --- a/models/mnist/convnet.py +++ b/models/mnist/convnet.py @@ -12,18 +12,21 @@ def __init__(self, num_classes=10, **kwargs): super(ConvNet, self).__init__() self.features = nn.Sequential( nn.Conv2d(1, 128, kernel_size=3, stride=1, padding=1), - nn.LayerNorm([128, 28, 28]), + # nn.LayerNorm([128, 28, 28], elementwise_affine=False), + nn.LayerNorm([28, 28], elementwise_affine=False), nn.ReLU(), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), - nn.LayerNorm([128, 28, 28]), + # nn.LayerNorm([128, 28, 28], elementwise_affine=False), + nn.LayerNorm([28, 28], elementwise_affine=False), nn.ReLU(), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), - nn.LayerNorm([128, 28, 28]), + # nn.LayerNorm([128, 28, 28], elementwise_affine=False), + nn.LayerNorm([28, 28], elementwise_affine=False), nn.ReLU(), nn.MaxPool2d(kernel_size=3), nn.Flatten(), nn.Linear(9*9*128, 500), - nn.LayerNorm([500]), + nn.LayerNorm([500], elementwise_affine=False), nn.ReLU(), nn.Linear(500, 10), )