Commit

first commit

alecwangcq committed Mar 11, 2019
0 parents commit 3dec464

Showing 25 changed files with 2,414 additions and 0 deletions.
140 changes: 140 additions & 0 deletions .gitignore
@@ -0,0 +1,140 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# data
data.cifar10/
data.cifar100/
*.gz

# C extensions
*.so

# Distribution / packaging
.Python
checkpoint/
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
#*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env
*.tar

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

tmp
runs
run

# PyCharm
.idea/

# macOS metadata
.DS_Store
._.DS_Store
._*

#
data/
log/
summary/
data/kernel_toy/*.pth
data/AS/gp-structure-search
#*.data
data/mnist_data
*.npz
*.txt
#*.png
#*.pdf
*.jpeg
*.jpg
#results/
*.pyc
*__pycache__

checkpoint/
runs/
1 change: 1 addition & 0 deletions README.md
@@ -0,0 +1 @@
# K-FAC_pytorch
243 changes: 243 additions & 0 deletions main.py
@@ -0,0 +1,243 @@
'''Train CIFAR10/CIFAR100 with PyTorch.'''
import argparse
import os
from optimizers import (KFACOptimizer, EKFACOptimizer)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR

from tqdm import tqdm
from tensorboardX import SummaryWriter
from utils.network_utils import get_network
from utils.data_utils import get_dataloader


# fetch args
parser = argparse.ArgumentParser()


parser.add_argument('--network', default='vgg16_bn', type=str)
parser.add_argument('--depth', default=19, type=int)
parser.add_argument('--dataset', default='cifar10', type=str)

# densenet
parser.add_argument('--growthRate', default=12, type=int)
parser.add_argument('--compressionRate', default=2, type=int)

# wrn, densenet
parser.add_argument('--widen_factor', default=1, type=int)
parser.add_argument('--dropRate', default=0.0, type=float)


parser.add_argument('--device', default='cuda', type=str)
parser.add_argument('--resume', '-r', action='store_true')
parser.add_argument('--load_path', default='', type=str)
parser.add_argument('--log_dir', default='runs/pretrain', type=str)


parser.add_argument('--optimizer', default='kfac', type=str)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--epoch', default=100, type=int)
parser.add_argument('--milestone', default=None, type=str)
parser.add_argument('--learning_rate', default=0.01, type=float)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--stat_decay', default=0.95, type=float)
parser.add_argument('--damping', default=1e-3, type=float)
parser.add_argument('--kl_clip', default=1e-2, type=float)
parser.add_argument('--weight_decay', default=3e-3, type=float)
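# K-FAC/EKFAC update intervals, measured in optimizer steps: TCov sets how often
# curvature (covariance) statistics are accumulated, TScal how often EKFAC
# re-estimates its eigenvalue scalings, and TInv how often the Kronecker-factor
# inverses are recomputed.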
parser.add_argument('--TCov', default=10, type=int)
parser.add_argument('--TScal', default=10, type=int)
parser.add_argument('--TInv', default=100, type=int)


parser.add_argument('--prefix', default=None, type=str)
args = parser.parse_args()
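
# Example invocation (hypothetical; any of the flags above may be overridden):
#   python main.py --dataset cifar10 --network vgg16_bn --optimizer kfac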

# init model
nc = {
'cifar10': 10,
'cifar100': 100
}
num_classes = nc[args.dataset]
net = get_network(args.network,
depth=args.depth,
num_classes=num_classes,
growthRate=args.growthRate,
compressionRate=args.compressionRate,
widen_factor=args.widen_factor,
dropRate=args.dropRate)
net = net.to(args.device)

# init dataloader
trainloader, testloader = get_dataloader(dataset=args.dataset,
train_batch_size=args.batch_size,
test_batch_size=256)

# init optimizer and lr scheduler
optim_name = args.optimizer.lower()
tag = optim_name
if optim_name == 'sgd':
optimizer = optim.SGD(net.parameters(),
lr=args.learning_rate,
momentum=args.momentum,
weight_decay=args.weight_decay)
elif optim_name == 'kfac':
optimizer = KFACOptimizer(net,
lr=args.learning_rate,
momentum=args.momentum,
stat_decay=args.stat_decay,
damping=args.damping,
kl_clip=args.kl_clip,
weight_decay=args.weight_decay,
TCov=args.TCov,
TInv=args.TInv)
elif optim_name == 'ekfac':
optimizer = EKFACOptimizer(net,
lr=args.learning_rate,
momentum=args.momentum,
stat_decay=args.stat_decay,
damping=args.damping,
kl_clip=args.kl_clip,
weight_decay=args.weight_decay,
TCov=args.TCov,
TScal=args.TScal,
TInv=args.TInv)
else:
raise NotImplementedError

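# Default schedule: decay the LR 10x at 50% and 75% of training; --milestone
# accepts a comma-separated list of epochs to override it.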
if args.milestone is None:
lr_scheduler = MultiStepLR(optimizer, milestones=[int(args.epoch*0.5), int(args.epoch*0.75)], gamma=0.1)
else:
milestone = [int(_) for _ in args.milestone.split(',')]
lr_scheduler = MultiStepLR(optimizer, milestones=milestone, gamma=0.1)

# init criterion
criterion = nn.CrossEntropyLoss()

start_epoch = 0
best_acc = 0
if args.resume:
print('==> Resuming from checkpoint..')
    assert os.path.isfile(args.load_path), 'Error: no checkpoint file found!'
checkpoint = torch.load(args.load_path)
net.load_state_dict(checkpoint['net'])
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']
print('==> Loaded checkpoint at epoch: %d, acc: %.2f%%' % (start_epoch, best_acc))

# init summary writer

log_dir = os.path.join(args.log_dir, args.dataset, args.network, args.optimizer,
'lr%.3f_wd%.4f_damping%.4f' %
(args.learning_rate, args.weight_decay, args.damping))
if not os.path.isdir(log_dir):
os.makedirs(log_dir)
writer = SummaryWriter(log_dir)


def train(epoch):
print('\nEpoch: %d' % epoch)
net.train()
train_loss = 0
correct = 0
total = 0

    # advance the LR schedule once per epoch, before the batch loop
    lr_scheduler.step()
desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
(tag, lr_scheduler.get_lr()[0], 0, 0, correct, total))

writer.add_scalar('train/lr', lr_scheduler.get_lr()[0], epoch)

prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True)
for batch_idx, (inputs, targets) in prog_bar:
inputs, targets = inputs.to(args.device), targets.to(args.device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
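        # Every TCov steps, K-FAC/EKFAC accumulate fresh curvature statistics.
        # Labels are sampled from the model's own predictive distribution so
        # the resulting gradients estimate the true Fisher rather than the
        # empirical Fisher evaluated at the training labels.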
        if optim_name in ['kfac', 'ekfac'] and optimizer.steps % optimizer.TCov == 0:
            # compute true fisher
            optimizer.acc_stats = True
            with torch.no_grad():
                sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs.cpu().data, dim=1),
                                              1).squeeze().to(args.device)
            loss_sample = criterion(outputs, sampled_y)
            loss_sample.backward(retain_graph=True)
            optimizer.acc_stats = False
            optimizer.zero_grad()  # clear the gradient for computing true-fisher.
loss.backward()
optimizer.step()

train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()

desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
(tag, lr_scheduler.get_lr()[0], train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
prog_bar.set_description(desc, refresh=True)

writer.add_scalar('train/loss', train_loss/(batch_idx + 1), epoch)
writer.add_scalar('train/acc', 100. * correct / total, epoch)


def test(epoch):
global best_acc
net.eval()
test_loss = 0
correct = 0
total = 0
    desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (tag, lr_scheduler.get_lr()[0], 0, 0, correct, total))

prog_bar = tqdm(enumerate(testloader), total=len(testloader), desc=desc, leave=True)
with torch.no_grad():
for batch_idx, (inputs, targets) in prog_bar:
inputs, targets = inputs.to(args.device), targets.to(args.device)
outputs = net(inputs)
loss = criterion(outputs, targets)

test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()

desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (tag, lr_scheduler.get_lr()[0], test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
prog_bar.set_description(desc, refresh=True)

# Save checkpoint.
acc = 100.*correct/total

writer.add_scalar('test/loss', test_loss / (batch_idx + 1), epoch)
writer.add_scalar('test/acc', 100. * correct / total, epoch)

if acc > best_acc:
print('Saving..')
state = {
'net': net.state_dict(),
'acc': acc,
'epoch': epoch,
'loss': test_loss,
'args': args
}

torch.save(state, '%s/%s_%s_%s%s_best.t7' % (log_dir,
args.optimizer,
args.dataset,
args.network,
args.depth))
best_acc = acc


def main():
for epoch in range(start_epoch, args.epoch):
train(epoch)
test(epoch)
return best_acc


if __name__ == '__main__':
main()


Empty file added models/__init__.py
Empty file.