diff --git a/main.py b/main.py index 4b894d4..d6ea5b8 100644 --- a/main.py +++ b/main.py @@ -50,6 +50,7 @@ parser.add_argument('--load_path', default='', type=str) parser.add_argument('--log_dir', default='runs/pretrain', type=str) parser.add_argument('--save_inv', default='false', type=str) +parser.add_argument('--save_kernel', default='false', type=str) parser.add_argument('--optimizer', default='kfac', type=str) @@ -182,6 +183,8 @@ buf[name] = torch.zeros_like(param.data).to(args.device) if args.save_inv == 'true': os.mkdir('ngd') + if args.save_kernel == 'true': + os.mkdir('ngd_kernel') elif optim_name == 'exact_ngd': print('Exact NGD optimizer selected.') @@ -457,7 +460,7 @@ def closure(): sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device) if args.trial == 'true': - update_list, loss = optimal_JJT_v2(outputs, sampled_y, args.batch_size, damping=damp, alpha=0.95, low_rank=args.low_rank, gamma=args.gamma, memory_efficient=args.memory_efficient, super_opt=args.super_opt) + update_list, loss = optimal_JJT_v2(outputs, sampled_y, args.batch_size, damping=damp, alpha=0.95, low_rank=args.low_rank, gamma=args.gamma, memory_efficient=args.memory_efficient, super_opt=args.super_opt, save_kernel=args.save_kernel) else: update_list, loss = optimal_JJT(outputs, sampled_y, args.batch_size, damping=damp, alpha=0.95, low_rank=args.low_rank, gamma=args.gamma, memory_efficient=args.memory_efficient) @@ -587,7 +590,49 @@ def closure(): gv = gv / n update = (grad - gv)/damp m.weight.grad.copy_(update) - + elif isinstance(m, nn.LayerNorm): + I, G = m.I, m.G + if len(I.shape) == 2: + mean = I.mean(dim=-1).unsqueeze(-1) + var = I.var(dim=-1, unbiased=False).unsqueeze(-1) + else: + mean = I.mean((-2, -1), keepdims=True) + var = I.var((-2, -1), unbiased=False, keepdims=True) + x_hat = (I - mean) / (var + m.eps).sqrt() + + J = G * x_hat + J = J.reshape(J.shape[0], -1) + JJT = torch.matmul(J, J.t()) + + grad_prod = torch.matmul(J, grad.reshape(-1)) + + NGD_kernel = JJT / n + NGD_inv = torch.linalg.inv(NGD_kernel + damp * torch.eye(n).to(grad.device)) + v = torch.matmul(NGD_inv, grad_prod) + + gv = torch.matmul(J.t(), v) / n + + update = (grad.reshape(-1) - gv) / damp + update = update.reshape(m.weight.grad.shape) + m.weight.grad.copy_(update) + + grad = m.bias.grad.reshape(-1) + + J = G + J = J.reshape(J.shape[0], -1) + JJT = torch.matmul(J, J.t()) + + grad_prod = torch.matmul(J, grad) + + NGD_kernel = JJT / n + NGD_inv = torch.linalg.inv(NGD_kernel + damp * torch.eye(n).to(grad.device)) + v = torch.matmul(NGD_inv, grad_prod) + + gv = torch.matmul(J.t(), v) / n + + update = (grad - gv) / damp + update = update.reshape(m.bias.grad.shape) + m.bias.grad.copy_(update) # last part of SMW formula @@ -654,6 +699,22 @@ def closure(): train_loss = train_loss/(batch_idx + 1) if args.step_info == 'true': TRAIN_INFO['epoch_time'].append(float("{:.4f}".format(epoch_time))) + + # save NGD kernels + if args.save_kernel == 'true' and optim_name == 'ngd': + if module_names == 'children': + all_modules = net.children() + elif module_names == 'features': + all_modules = net.features.children() + + count = 0 + for m in all_modules: + if m.__class__.__name__ in ['Linear', 'Conv2d', 'LayerNorm']: + if hasattr(m, "NGD_kernel"): + with open('ngd_kernel/' + str(epoch) + '_m_' + str(count) + '_kernel.npy', 'wb') as f: + np.save(f, m.NGD_kernel.cpu().numpy()) + count += 1 + # save diagonal blocks of exact Fisher inverse or its approximations if args.save_inv == 'true': if module_names == 'children': @@ -686,6 +747,24 @@ def closure(): np.save(f, ((torch.eye(JTDJ.size(0)).to(JTDJ.device) - JTDJ) / damping).cpu().numpy()) count += 1 + elif m.__class__.__name__ == 'LayerNorm': + with torch.no_grad(): + I, G = m.I, m.G + if len(I.shape) == 2: + mean = I.mean(dim=-1).unsqueeze(-1) + var = I.var(dim=-1, unbiased=False).unsqueeze(-1) + else: + mean = I.mean((-2, -1), keepdims=True) + var = I.var((-2, -1), unbiased=False, keepdims=True) + x_hat = (I - mean) / (var + m.eps).sqrt() + + J = G * x_hat + J = J.reshape(J.shape[0], -1) + JTDJ = torch.matmul(J.t(), torch.matmul(m.NGD_inv, J)) / args.batch_size + + with open('ngd/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f: + np.save(f, ((torch.eye(JTDJ.size(0)).to(JTDJ.device) - JTDJ) / damping).cpu().numpy()) + count += 1 elif optim_name == 'exact_ngd': for m in all_modules: if m.__class__.__name__ in ['Conv2d', 'Linear']: @@ -775,11 +854,11 @@ def optimal_JJT(outputs, targets, batch_size, damping=1.0, alpha=0.95, low_rank= update_list[name] = fisher_vals[2] return update_list, loss -def optimal_JJT_v2(outputs, targets, batch_size, damping=1.0, alpha=0.95, low_rank='false', gamma=0.95, memory_efficient='false', super_opt='false'): +def optimal_JJT_v2(outputs, targets, batch_size, damping=1.0, alpha=0.95, low_rank='false', gamma=0.95, memory_efficient='false', super_opt='false', save_kernel='false'): jac_list = 0 vjp = 0 update_list = {} - with backpack(FisherBlockEff(damping, alpha, low_rank, gamma, memory_efficient, super_opt)): + with backpack(FisherBlockEff(damping, alpha, low_rank, gamma, memory_efficient, super_opt, save_kernel)): loss = criterion(outputs, targets) loss.backward() for name, param in net.named_parameters(): diff --git a/models/mnist/convnet.py b/models/mnist/convnet.py index 3506f11..b667634 100644 --- a/models/mnist/convnet.py +++ b/models/mnist/convnet.py @@ -12,14 +12,21 @@ def __init__(self, num_classes=10, **kwargs): super(ConvNet, self).__init__() self.features = nn.Sequential( nn.Conv2d(1, 128, kernel_size=3, stride=1, padding=1), + # nn.LayerNorm([128, 28, 28], elementwise_affine=False), + nn.LayerNorm([28, 28], elementwise_affine=False), nn.ReLU(), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), + # nn.LayerNorm([128, 28, 28], elementwise_affine=False), + nn.LayerNorm([28, 28], elementwise_affine=False), nn.ReLU(), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), + # nn.LayerNorm([128, 28, 28], elementwise_affine=False), + nn.LayerNorm([28, 28], elementwise_affine=False), nn.ReLU(), nn.MaxPool2d(kernel_size=3), nn.Flatten(), nn.Linear(9*9*128, 500), + nn.LayerNorm([500], elementwise_affine=False), nn.ReLU(), nn.Linear(500, 10), ) diff --git a/models/mnist/toy.py b/models/mnist/toy.py index b632b1e..dd2c795 100644 --- a/models/mnist/toy.py +++ b/models/mnist/toy.py @@ -11,12 +11,15 @@ def __init__(self, num_classes=10, **kwargs): super(ToyNet, self).__init__() self.features = nn.Sequential( nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), + nn.LayerNorm([16, 28, 28]), nn.ReLU(), nn.MaxPool2d(kernel_size=3), nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1), + nn.LayerNorm([16, 9, 9]), nn.ReLU(), nn.MaxPool2d(kernel_size=3), nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1), + nn.LayerNorm([16, 3, 3]), nn.ReLU(), nn.Flatten(), nn.Linear(3*3*16, 10)