diff --git a/main.py b/main.py index 3a52b1c..51f8c29 100644 --- a/main.py +++ b/main.py @@ -99,6 +99,11 @@ parser.add_argument('--subsample', default='false', type=str) parser.add_argument('--num_ss_patches', default=0, type=int) +# warmup learning rate +parser.add_argument('--warmup', default='false', type=str) +parser.add_argument('--warmup-epoch', default=5, type=int) +parser.add_argument('--warmup-init-lr', default=0.01, type=float) + args = parser.parse_args() # init model @@ -792,6 +797,80 @@ def test(epoch): test_loss = test_loss/(batch_idx + 1) return acc, test_loss +def warmup(epoch): + torch.set_printoptions(precision=16) + print('=== Warmup ===') + print('\nEpoch: %d' % epoch) + net.train() + train_loss = 0 + correct = 0 + total = 0 + step_st_time = time.time() + epoch_time = 0 + print('\nKFAC/KBFGS damping: %f' % damping) + print('\nNGD damping: %f' % (damping)) + + warmup_lr = (args.learning_rate - args.warmup_init_lr) / (args.warmup_epoch - 1) * epoch + args.warmup_init_lr + desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' % + (tag, warmup_lr, 0, 0, correct, total)) + + writer.add_scalar('train/lr', warmup_lr, epoch) + + prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True) + for batch_idx, (inputs, targets) in prog_bar: + ### new optimizer test + if optim_name in ['kngd'] : + inputs, targets = inputs.to(args.device), targets.to(args.device) + optimizer.zero_grad() + outputs = net(inputs) + loss = criterion(outputs, targets) + if optimizer.steps % optimizer.freq == 0: + # compute true fisher + optimizer.acc_stats = True + with torch.no_grad(): + sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device) + loss_sample = criterion(outputs, sampled_y) + loss_sample.backward(retain_graph=True) + optimizer.acc_stats = False + optimizer.zero_grad() # clear the gradient for computing true-fisher. + if args.partial_backprop == 'true': + idx = (sampled_y == targets) == False + loss = criterion(outputs[idx,:], targets[idx]) + # print('extra:', idx.sum().item()) + loss.backward() + optimizer.step() + + + train_loss += loss.item() + _, predicted = outputs.max(1) + total += targets.size(0) + correct += predicted.eq(targets).sum().item() + desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' % + (tag, warmup_lr, train_loss / (batch_idx + 1), 100. * correct / total, correct, total)) + prog_bar.set_description(desc, refresh=True) + if args.step_info == 'true' and (batch_idx % 50 == 0 or batch_idx == len(prog_bar) - 1): + step_saved_time = time.time() - step_st_time + epoch_time += step_saved_time + test_acc, test_loss = test(epoch) + TRAIN_INFO['train_acc'].append(float("{:.4f}".format(100. * correct / total))) + TRAIN_INFO['test_acc'].append(float("{:.4f}".format(test_acc))) + TRAIN_INFO['train_loss'].append(float("{:.4f}".format(train_loss/(batch_idx + 1)))) + TRAIN_INFO['test_loss'].append(float("{:.4f}".format(test_loss))) + TRAIN_INFO['total_time'].append(float("{:.4f}".format(step_saved_time))) + if args.debug_mem == 'true': + TRAIN_INFO['memory'].append(torch.cuda.memory_reserved()) + step_st_time = time.time() + net.train() + + writer.add_scalar('train/loss', train_loss/(batch_idx + 1), epoch) + writer.add_scalar('train/acc', 100. * correct / total, epoch) + acc = 100. * correct / total + train_loss = train_loss/(batch_idx + 1) + if args.step_info == 'true': + TRAIN_INFO['epoch_time'].append(float("{:.4f}".format(epoch_time))) + + return acc, train_loss + def optimal_JJT(outputs, targets, batch_size, damping=1.0, alpha=0.95, low_rank='false', gamma=0.95, memory_efficient='false'): jac_list = 0 vjp = 0 @@ -830,6 +909,10 @@ def main(): if args.debug_mem == 'true': TRAIN_INFO['memory'].append(torch.cuda.memory_reserved()) st_time = time.time() + + if args.warmup: + for epoch in range(args.warmup_epoch): + train_acc, train_loss = warmup(epoch) for epoch in range(start_epoch, args.epoch): ep_st_time = time.time() train_acc, train_loss = train(epoch) @@ -937,7 +1020,6 @@ def get_accuracy(data): def memory_cleanup(module): """Remove I/O stored by backpack during the forward pass. - Deletes the attributes created by `hook_store_io` and `hook_store_shapes`. """ # if self.mem_clean_up: @@ -955,6 +1037,4 @@ def memory_cleanup(module): i += 1 if __name__ == '__main__': - main() - - + main() \ No newline at end of file diff --git a/models/cifar/resnet.py b/models/cifar/resnet.py index 1c68511..1407d84 100644 --- a/models/cifar/resnet.py +++ b/models/cifar/resnet.py @@ -1,6 +1,6 @@ from __future__ import absolute_import -__all__ = ['ResNet34'] +__all__ = ['ResNet32'] '''ResNet in PyTorch. Reference: @@ -13,6 +13,15 @@ import torch.nn.functional as F +class LambdaLayer(nn.Module): + def __init__(self, lambd): + super(LambdaLayer, self).__init__() + self.lambd = lambd + + def forward(self, x): + return self.lambd(x) + + class BasicBlock(nn.Module): expansion = 1 @@ -27,11 +36,13 @@ def __init__(self, in_planes, planes, stride=1): self.shortcut = nn.Sequential() if stride != 1 or in_planes != self.expansion*planes: - self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, self.expansion*planes, - kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion*planes) - ) + # self.shortcut = nn.Sequential( + # nn.Conv2d(in_planes, self.expansion*planes, + # kernel_size=1, stride=stride, bias=False), + # nn.BatchNorm2d(self.expansion*planes) + # ) + self.shortcut = LambdaLayer(lambda x: + F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) @@ -105,5 +116,5 @@ def forward(self, x): return out -def ResNet34(**kwargs): - return ResNet(BasicBlock, [5, 5, 5]) +def ResNet32(**kwargs): + return ResNet(BasicBlock, [5, 5, 5]) \ No newline at end of file diff --git a/optimizers/ngd.py b/optimizers/ngd.py index 6577151..987e6c6 100644 --- a/optimizers/ngd.py +++ b/optimizers/ngd.py @@ -341,21 +341,21 @@ def _step(self, closure): if p.grad is None: continue d_p = p.grad.data - - # if momentum != 0: - # param_state = self.state[p] - # if 'momentum_buffer' not in param_state: - # buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) - # buf.mul_(momentum).add_(d_p) - # else: - # buf = param_state['momentum_buffer'] - # buf.mul_(momentum).add_(1, d_p) - # d_p.copy_(buf) # if weight_decay != 0 and self.steps >= 10 * self.freq: if weight_decay != 0: d_p.add_(weight_decay, p.data) + if momentum != 0: + param_state = self.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) + buf.mul_(momentum).add_(d_p) + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(1, d_p) + d_p.copy_(buf) + p.data.add_(-group['lr'], d_p) # print('d_p:', d_p.shape) # print(d_p) diff --git a/utils/network_utils.py b/utils/network_utils.py index d33804f..a3d2e99 100644 --- a/utils/network_utils.py +++ b/utils/network_utils.py @@ -1,4 +1,4 @@ -from models.cifar import (alexnet, densenet, ResNet34, +from models.cifar import (alexnet, densenet, ResNet32, vgg16_bn, vgg19_bn, vgg16, vgg13, vgg11_bn, wrn, inception, googlenet, xception, nasnet, resnext, mobilenetv2) from models.mnist import (fc, convnet, bn, toy, autoencoder) @@ -12,7 +12,7 @@ def get_network(network, **kwargs): 'convnet': convnet, 'alexnet': alexnet, 'densenet': densenet, - 'resnet34': ResNet34, + 'resnet32': ResNet32, 'vgg16_bn': vgg16_bn, 'vgg19_bn': vgg19_bn, 'vgg16': vgg16, @@ -29,5 +29,4 @@ def get_network(network, **kwargs): } - return networks[network](**kwargs) - + return networks[network](**kwargs) \ No newline at end of file