Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 84 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@
parser.add_argument('--subsample', default='false', type=str)
parser.add_argument('--num_ss_patches', default=0, type=int)

# warmup learning rate
parser.add_argument('--warmup', default='false', type=str)
parser.add_argument('--warmup-epoch', default=5, type=int)
parser.add_argument('--warmup-init-lr', default=0.01, type=float)

args = parser.parse_args()

# init model
Expand Down Expand Up @@ -792,6 +797,80 @@ def test(epoch):
test_loss = test_loss/(batch_idx + 1)
return acc, test_loss

def _linear_warmup_lr(epoch, base_lr, init_lr, num_epochs):
    """Return the learning rate for warmup epoch ``epoch``.

    Interpolates linearly from ``init_lr`` at epoch 0 to ``base_lr`` at epoch
    ``num_epochs - 1``.  With fewer than two warmup epochs there is nothing to
    interpolate over, so the denominator is clamped to 1 instead of dividing
    by zero (the result is then ``init_lr`` for epoch 0).
    """
    span = max(num_epochs - 1, 1)
    return (base_lr - init_lr) / span * epoch + init_lr


def warmup(epoch):
    """Run one warmup epoch over ``trainloader`` with a linearly ramped LR.

    The learning rate for this epoch is interpolated between
    ``args.warmup_init_lr`` and ``args.learning_rate`` over
    ``args.warmup_epoch`` epochs, written into every parameter group of the
    global ``optimizer`` and logged to ``writer`` as ``train/lr``.

    Only the ``'kngd'`` optimizer branch is implemented here.

    Args:
        epoch: Zero-based index of the warmup epoch.

    Returns:
        Tuple ``(acc, train_loss)``: training accuracy in percent and the
        mean per-batch training loss for this epoch.
    """
    torch.set_printoptions(precision=16)
    print('=== Warmup ===')
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    step_st_time = time.time()
    epoch_time = 0
    print('\nKFAC/KBFGS damping: %f' % damping)
    print('\nNGD damping: %f' % (damping))

    warmup_lr = _linear_warmup_lr(
        epoch, args.learning_rate, args.warmup_init_lr, args.warmup_epoch)
    # Actually apply the warmup rate; without this the value is only logged
    # and the optimizer keeps stepping with its previous learning rate.
    for group in optimizer.param_groups:
        group['lr'] = warmup_lr

    desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (tag, warmup_lr, 0, 0, correct, total))

    writer.add_scalar('train/lr', warmup_lr, epoch)

    last_batch_idx = len(trainloader) - 1
    prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True)
    for batch_idx, (inputs, targets) in prog_bar:
        # NOTE(review): only 'kngd' is handled; for any other optim_name the
        # loss bookkeeping below has no `loss`/`outputs` to read -- confirm
        # that warmup is only ever used with 'kngd'.
        if optim_name in ['kngd']:
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            if optimizer.steps % optimizer.freq == 0:
                # Backward pass on labels sampled from the model's own
                # predictive distribution while acc_stats is enabled
                # ("true Fisher" in the original comment).
                optimizer.acc_stats = True
                with torch.no_grad():
                    probs = torch.nn.functional.softmax(outputs, dim=1)
                    # squeeze(1) rather than squeeze(): keeps the batch
                    # dimension when the batch holds a single sample.
                    sampled_y = torch.multinomial(probs, 1).squeeze(1).to(args.device)
                loss_sample = criterion(outputs, sampled_y)
                loss_sample.backward(retain_graph=True)
                optimizer.acc_stats = False
                # Clear the gradients produced by the sampled-label pass.
                optimizer.zero_grad()
                if args.partial_backprop == 'true':
                    # Backpropagate only through samples whose sampled label
                    # differs from the real target.
                    idx = sampled_y != targets
                    loss = criterion(outputs[idx, :], targets[idx])
            loss.backward()
            optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (tag, warmup_lr, train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
        prog_bar.set_description(desc, refresh=True)
        if args.step_info == 'true' and (batch_idx % 50 == 0 or batch_idx == last_batch_idx):
            # Time spent since the previous snapshot; evaluation time below is
            # excluded because the timer is restarted after test().
            step_saved_time = time.time() - step_st_time
            epoch_time += step_saved_time
            test_acc, test_loss = test(epoch)
            TRAIN_INFO['train_acc'].append(float("{:.4f}".format(100. * correct / total)))
            TRAIN_INFO['test_acc'].append(float("{:.4f}".format(test_acc)))
            TRAIN_INFO['train_loss'].append(float("{:.4f}".format(train_loss/(batch_idx + 1))))
            TRAIN_INFO['test_loss'].append(float("{:.4f}".format(test_loss)))
            TRAIN_INFO['total_time'].append(float("{:.4f}".format(step_saved_time)))
            if args.debug_mem == 'true':
                TRAIN_INFO['memory'].append(torch.cuda.memory_reserved())
            step_st_time = time.time()
            # Switch back to training mode after the evaluation call.
            net.train()

    writer.add_scalar('train/loss', train_loss/(batch_idx + 1), epoch)
    writer.add_scalar('train/acc', 100. * correct / total, epoch)
    acc = 100. * correct / total
    train_loss = train_loss/(batch_idx + 1)
    if args.step_info == 'true':
        TRAIN_INFO['epoch_time'].append(float("{:.4f}".format(epoch_time)))

    return acc, train_loss

def optimal_JJT(outputs, targets, batch_size, damping=1.0, alpha=0.95, low_rank='false', gamma=0.95, memory_efficient='false'):
jac_list = 0
vjp = 0
Expand Down Expand Up @@ -830,6 +909,10 @@ def main():
if args.debug_mem == 'true':
TRAIN_INFO['memory'].append(torch.cuda.memory_reserved())
st_time = time.time()

# args.warmup is parsed as a string flag ('true'/'false'); a bare truthiness
# test is True even for the default 'false', so compare against 'true'
# explicitly, the same way args.debug_mem is checked.
if args.warmup == 'true':
    for epoch in range(args.warmup_epoch):
        train_acc, train_loss = warmup(epoch)
for epoch in range(start_epoch, args.epoch):
ep_st_time = time.time()
train_acc, train_loss = train(epoch)
Expand Down Expand Up @@ -937,7 +1020,6 @@ def get_accuracy(data):

def memory_cleanup(module):
"""Remove I/O stored by backpack during the forward pass.

Deletes the attributes created by `hook_store_io` and `hook_store_shapes`.
"""
# if self.mem_clean_up:
Expand All @@ -955,6 +1037,4 @@ def memory_cleanup(module):
i += 1

if __name__ == '__main__':
main()


main()
27 changes: 19 additions & 8 deletions models/cifar/resnet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import absolute_import

__all__ = ['ResNet34']
__all__ = ['ResNet32']

'''ResNet in PyTorch.
Reference:
Expand All @@ -13,6 +13,15 @@
import torch.nn.functional as F


class LambdaLayer(nn.Module):
    """Module that delegates its forward pass to a stored callable.

    The callable is kept on ``self.lambd`` and is invoked with the input as
    its only argument; its return value is the module's output.
    """

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)


class BasicBlock(nn.Module):
expansion = 1

Expand All @@ -27,11 +36,13 @@ def __init__(self, in_planes, planes, stride=1):

self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
)
# self.shortcut = nn.Sequential(
# nn.Conv2d(in_planes, self.expansion*planes,
# kernel_size=1, stride=stride, bias=False),
# nn.BatchNorm2d(self.expansion*planes)
# )
self.shortcut = LambdaLayer(lambda x:
F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))

def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
Expand Down Expand Up @@ -105,5 +116,5 @@ def forward(self, x):
return out


def ResNet34(**kwargs):
return ResNet(BasicBlock, [5, 5, 5])
def ResNet32(**kwargs):
    """Construct a ``ResNet`` from ``BasicBlock`` with the list ``[5, 5, 5]``.

    Any keyword arguments are accepted but are not passed on to ``ResNet``.
    """
    layout = [5, 5, 5]
    return ResNet(BasicBlock, layout)
20 changes: 10 additions & 10 deletions optimizers/ngd.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,21 +341,21 @@ def _step(self, closure):
if p.grad is None:
continue
d_p = p.grad.data

# if momentum != 0:
# param_state = self.state[p]
# if 'momentum_buffer' not in param_state:
# buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
# buf.mul_(momentum).add_(d_p)
# else:
# buf = param_state['momentum_buffer']
# buf.mul_(momentum).add_(1, d_p)
# d_p.copy_(buf)

# if weight_decay != 0 and self.steps >= 10 * self.freq:
if weight_decay != 0:
d_p.add_(weight_decay, p.data)

if momentum != 0:
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
buf.mul_(momentum).add_(d_p)
else:
buf = param_state['momentum_buffer']
buf.mul_(momentum).add_(1, d_p)
d_p.copy_(buf)

p.data.add_(-group['lr'], d_p)
# print('d_p:', d_p.shape)
# print(d_p)
Expand Down
7 changes: 3 additions & 4 deletions utils/network_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from models.cifar import (alexnet, densenet, ResNet34,
from models.cifar import (alexnet, densenet, ResNet32,
vgg16_bn, vgg19_bn, vgg16, vgg13, vgg11_bn,
wrn, inception, googlenet, xception, nasnet, resnext, mobilenetv2)
from models.mnist import (fc, convnet, bn, toy, autoencoder)
Expand All @@ -12,7 +12,7 @@ def get_network(network, **kwargs):
'convnet': convnet,
'alexnet': alexnet,
'densenet': densenet,
'resnet34': ResNet34,
'resnet32': ResNet32,
'vgg16_bn': vgg16_bn,
'vgg19_bn': vgg19_bn,
'vgg16': vgg16,
Expand All @@ -29,5 +29,4 @@ def get_network(network, **kwargs):

}

return networks[network](**kwargs)

return networks[network](**kwargs)