-
Notifications
You must be signed in to change notification settings - Fork 178
/
Copy pathpytorch_ignite_simple.py
134 lines (103 loc) · 4.55 KB
/
pytorch_ignite_simple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
Optuna example that optimizes convolutional neural networks using PyTorch Ignite.
In this example, we optimize the validation accuracy of fashion product recognition using
PyTorch Ignite and FashionMNIST. We optimize the neural network architecture as well as the
regularization. As it is too time-consuming to use the whole FashionMNIST dataset,
we here use a small subset of it.
You can run this example as follows; pruning can be turned on and off with the `--pruning`
(or `-p`) flag.
$ python pytorch_ignite_simple.py [--pruning]
"""
import argparse
from ignite.engine import create_supervised_evaluator
from ignite.engine import create_supervised_trainer
from ignite.engine import Events
from ignite.metrics import Accuracy
import optuna
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from torchvision.datasets.mnist import FashionMNIST
from torchvision.transforms import Compose
from torchvision.transforms import Normalize
from torchvision.transforms import ToTensor
# Number of passes over the training subset per trial.
EPOCHS = 10

# Mini-batch sizes of the training and validation loaders.
TRAIN_BATCH_SIZE = 64
VAL_BATCH_SIZE = 1_000

# Only a prefix of each FashionMNIST split is used so that trials stay fast.
N_TRAIN_EXAMPLES = 3_000
N_VALID_EXAMPLES = 1_000
class Net(nn.Module):
    """Small two-stage CNN classifier whose regularization is tuned by Optuna.

    Args:
        trial: Optuna trial used to sample the hyperparameters below.

    Tuned hyperparameters:
        ``dropout_rate``: channel-dropout probability applied to the ``conv2`` output,
            sampled from [0, 1].
        ``fc2_input_dim``: width of the hidden fully connected layer, sampled from [40, 80].
    """

    def __init__(self, trial):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # We optimize the dropout rate of the convolutional part of the network.
        dropout_rate = trial.suggest_float("dropout_rate", 0, 1)
        self.conv2_drop = nn.Dropout2d(p=dropout_rate)
        fc2_input_dim = trial.suggest_int("fc2_input_dim", 40, 80)
        # 320 = 20 channels * 4 * 4: two conv(k=5) + max_pool(2) stages shrink a
        # 28x28 input to 4x4, so fc1 assumes (batch, 1, 28, 28) inputs.
        self.fc1 = nn.Linear(320, fc2_input_dim)
        self.fc2 = nn.Linear(fc2_input_dim, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, 10)``.

        ``x`` is expected to have shape ``(batch, 1, 28, 28)``; other spatial sizes
        raise a shape error in ``fc1``.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Keep the batch dimension explicit. ``view(-1, 320)`` would silently regroup
        # features across samples for mis-sized inputs instead of failing loudly.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        # Fixed dropout (PyTorch default p=0.5) on the hidden layer; not tuned.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)
def get_data_loaders(train_batch_size, val_batch_size):
    """Return ``(train_loader, val_loader)`` over fixed-size prefixes of FashionMNIST.

    Only the first ``N_TRAIN_EXAMPLES`` training images and the first
    ``N_VALID_EXAMPLES`` test images are used, to keep each trial fast. The data is
    stored under the current working directory (``root="."``).

    Args:
        train_batch_size: Mini-batch size of the shuffled training loader.
        val_batch_size: Mini-batch size of the unshuffled validation loader.
    """
    # NOTE(review): (0.1307,)/(0.3081,) look like the usual MNIST mean/std rather
    # than FashionMNIST statistics - confirm this is intended.
    preprocess = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

    # Only the training split requests a download; the validation split relies on
    # the files already being present under the same root.
    full_train = FashionMNIST(download=True, root=".", transform=preprocess, train=True)
    full_val = FashionMNIST(download=False, root=".", transform=preprocess, train=False)

    def _prefix_loader(dataset, n_examples, batch_size, shuffle):
        # Restrict ``dataset`` to its first ``n_examples`` items and batch them.
        return DataLoader(
            Subset(dataset, range(n_examples)), batch_size=batch_size, shuffle=shuffle
        )

    train_loader = _prefix_loader(full_train, N_TRAIN_EXAMPLES, train_batch_size, True)
    val_loader = _prefix_loader(full_val, N_VALID_EXAMPLES, val_batch_size, False)
    return train_loader, val_loader
def objective(trial):
    """Train a sampled CNN on a FashionMNIST subset and return its validation accuracy.

    Args:
        trial: Optuna trial that samples the ``Net`` hyperparameters and receives
            intermediate accuracy reports through the pruning handler.

    Returns:
        The evaluator's ``"accuracy"`` metric from the run after the last epoch.
    """
    # Build this trial's model and move it to the GPU when one is available.
    model = Net(trial)
    if torch.cuda.is_available():
        device = "cuda"
        model.cuda(device)
    else:
        device = "cpu"

    trainer = create_supervised_trainer(model, Adam(model.parameters()), F.nll_loss, device=device)
    evaluator = create_supervised_evaluator(model, metrics={"accuracy": Accuracy()}, device=device)

    # Each time the evaluator finishes, hand its accuracy to Optuna so the study's
    # pruner can stop the trial early. The trainer is passed so the handler can
    # relate each report to the current training epoch.
    evaluator.add_event_handler(
        Events.COMPLETED,
        optuna.integration.PyTorchIgnitePruningHandler(trial, "accuracy", trainer),
    )

    train_loader, val_loader = get_data_loaders(TRAIN_BATCH_SIZE, VAL_BATCH_SIZE)

    def log_results(engine):
        # Validate at the end of every training epoch and report the result.
        evaluator.run(val_loader)
        accuracy = evaluator.state.metrics["accuracy"]
        print("Epoch: {} Validation accuracy: {:.2f}".format(engine.state.epoch, accuracy))

    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_results)
    trainer.run(train_loader, max_epochs=EPOCHS)
    return evaluator.state.metrics["accuracy"]
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="PyTorch Ignite example.")
    cli.add_argument(
        "--pruning",
        "-p",
        action="store_true",
        help="Activate the pruning feature. `MedianPruner` stops unpromising "
        "trials at the early stages of training.",
    )
    options = cli.parse_args()

    # Without --pruning, NopPruner never stops a trial early.
    if options.pruning:
        chosen_pruner = optuna.pruners.MedianPruner()
    else:
        chosen_pruner = optuna.pruners.NopPruner()

    # The objective returns validation accuracy, so larger is better.
    study = optuna.create_study(direction="maximize", pruner=chosen_pruner)
    # Bounded by 100 trials and a 600-second timeout.
    study.optimize(objective, n_trials=100, timeout=600)

    print("Number of finished trials: ", len(study.trials))

    print("Best trial:")
    best = study.best_trial
    print(" Value: ", best.value)
    print(" Params: ")
    for param_name, param_value in best.params.items():
        print(" {}: {}".format(param_name, param_value))