-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
86 lines (77 loc) · 3.21 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
from torchvision import datasets, transforms
import torch
import logging
import os
from collections import OrderedDict
from torch.utils.data.sampler import SubsetRandomSampler
# from torch.utils.
def clamp(X, lower_limit, upper_limit):
    """Clip ``X`` element-wise into the range ``[lower_limit, upper_limit]``.

    The upper bound is applied first and the lower bound second. Where the
    bounds cross (``lower_limit > upper_limit``), the result is therefore
    ``lower_limit``. ``torch.min``/``torch.max`` broadcast their two tensor
    arguments, so the limits may be tensors broadcastable against ``X``.
    """
    capped_above = torch.min(X, upper_limit)
    return torch.max(capped_above, lower_limit)
# Per-channel mean / standard deviation (3 channels each), used by
# transforms.Normalize in get_loaders() and by normalize().
cifar10_mean, cifar10_std = (0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)
imagenet_mean, imagenet_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
def normalize(args, X):
    """Normalize ``X`` with the per-channel statistics of ``args.dataset``.

    Computes ``(X - mean) / std``. For ``args.dataset == "cifar"`` the
    CIFAR-10 statistics are used; for ``"imagenette"`` or ``"imagenet"`` the
    ImageNet statistics are used. The statistics are reshaped to ``(3, 1, 1)``
    and broadcast against ``X``, which assumes a 3-channel axis at position
    -3 (e.g. ``(N, 3, H, W)``) — TODO confirm with callers.

    Args:
        args: Namespace-like object; only ``args.dataset`` is read.
        X: Image tensor to normalize.

    Returns:
        The normalized tensor.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    if args.dataset == "cifar":
        mean, std = cifar10_mean, cifar10_std
    elif args.dataset in ("imagenette", "imagenet"):
        mean, std = imagenet_mean, imagenet_std
    else:
        # Without this branch an unknown dataset crashed later with an
        # opaque UnboundLocalError.
        raise ValueError(f"Unsupported dataset for normalize: {args.dataset!r}")
    # Create the statistics on X's device rather than hard-coding .cuda().
    # This lets CPU inputs work and avoids a device mismatch when X lives on
    # a GPU other than the current default one.
    mu = torch.tensor(mean, device=X.device).view(3, 1, 1)
    sigma = torch.tensor(std, device=X.device).view(3, 1, 1)
    return (X - mu) / sigma
def _dataset_stats(dataset):
    """Return the ``(mean, std)`` normalization tuples for ``dataset``.

    Raises:
        ValueError: If ``dataset`` is not ``"cifar"``, ``"imagenette"`` or
            ``"imagenet"``.
    """
    if dataset == "cifar":
        return cifar10_mean, cifar10_std
    if dataset in ("imagenette", "imagenet"):
        return imagenet_mean, imagenet_std
    raise ValueError(f"Unsupported dataset: {dataset!r}")


def _build_transforms(args, mean, std):
    """Return ``(train_transform, test_transform)`` for image preprocessing.

    Train: resize to ``args.resize`` x ``args.resize``, take a random
    ``args.crop`` crop with 4 pixels of padding, flip horizontally at random,
    convert to a tensor and normalize.
    Test: resize, convert to a tensor and normalize (no augmentation).
    """
    train_transform = transforms.Compose([
        transforms.Resize([args.resize, args.resize]),
        transforms.RandomCrop(args.crop, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_transform = transforms.Compose([
        transforms.Resize([args.resize, args.resize]),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    return train_transform, test_transform


def _build_datasets(args, train_transform, test_transform):
    """Return ``(train_dataset, test_dataset)`` for ``args.dataset``.

    CIFAR-10 is read with ``download=False``, so its files must already be
    present in ``args.data_dir``. ImageNette and ImageNet are read as
    ``ImageFolder`` datasets from the ``train`` and ``val`` sub-directories
    of ``args.data_dir``.

    Raises:
        ValueError: If ``args.dataset`` is not supported.
    """
    if args.dataset == "cifar":
        # Originally this downloaded the dataset; it now expects the files
        # to already exist under args.data_dir.
        train_dataset = datasets.CIFAR10(
            args.data_dir, train=True, transform=train_transform, download=False)
        test_dataset = datasets.CIFAR10(
            args.data_dir, train=False, transform=test_transform, download=False)
    elif args.dataset in ("imagenette", "imagenet"):
        # os.path.join works whether or not data_dir ends with a separator.
        # Plain string concatenation required a trailing slash.
        train_dataset = datasets.ImageFolder(
            os.path.join(args.data_dir, "train"), train_transform)
        test_dataset = datasets.ImageFolder(
            os.path.join(args.data_dir, "val"), test_transform)
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")
    return train_dataset, test_dataset


def get_loaders(args):
    """Build the train and test DataLoaders for ``args.dataset``.

    Supported values of ``args.dataset`` are ``"cifar"`` (torchvision
    ``CIFAR10``), ``"imagenette"`` and ``"imagenet"`` (``ImageFolder`` over
    ``<data_dir>/train`` and ``<data_dir>/val``).

    Args:
        args: Namespace-like object. Reads ``dataset``, ``data_dir``,
            ``resize``, ``crop``, ``batch_size`` and ``accum_steps``. An
            optional ``num_workers`` attribute sets the DataLoader worker
            count; it defaults to 16 when absent or ``None``.

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles and uses
        ``args.batch_size``. The test loader does not shuffle and uses
        ``(args.batch_size // args.accum_steps) * 2``, with a minimum of 1.
        Both loaders use ``pin_memory=True``.

    Raises:
        ValueError: If ``args.dataset`` is not supported.
    """
    mean, std = _dataset_stats(args.dataset)
    train_transform, test_transform = _build_transforms(args, mean, std)
    train_dataset, test_dataset = _build_datasets(
        args, train_transform, test_transform)

    num_workers = getattr(args, "num_workers", None)
    if num_workers is None:
        num_workers = 16

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
    )
    # Twice batch_size // accum_steps (presumably the gradient-accumulation
    # micro-batch). The parentheses make the original precedence explicit.
    # Clamping to 1 avoids an invalid batch size of 0 when
    # batch_size < accum_steps.
    test_batch_size = max(1, (args.batch_size // args.accum_steps) * 2)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=test_batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=num_workers,
    )
    return train_loader, test_loader
# Logger for this module, named after the module itself.
_logger = logging.getLogger(name=__name__)