# io_params.py (forked from gmum/few-shot-hypernets-public)
from enum import StrEnum, auto
from typing import Optional
from tap import Tap
from pathlib import Path
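# Note (not in the original file): enum.StrEnum is available from Python 3.11;
# auto() lowercases the member name, so e.g. Arg.Dataset.miniImagenet has the
# value "miniimagenet".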
class Arg:
    class Method(StrEnum):
        baseline = auto()
        baselinepp = "baseline++"
        DKT = "DKT"
        protonet = auto()
        matchingnet = auto()
        relationnet = auto()
        relationnet_softmax = auto()
        maml = auto()
        maml_approx = auto()
        hyper_maml = auto()
        bayes_hmaml = auto()
        hyper_shot = auto()
        hn_ppa = auto()
        hn_poc = auto()

    class Model(StrEnum):
        Conv4 = "Conv4"
        Conv4Pool = "Conv4Pool"
        Conv4S = "Conv4S"
        Conv6 = "Conv6"
        ResNet10 = "ResNet10"
        ResNet18 = "ResNet18"
        ResNet34 = "ResNet34"
        ResNet50 = "ResNet50"
        ResNet101 = "ResNet101"
        Conv4WithKernel = "Conv4WithKernel"
        ResNetWithKernel = "ResNetWithKernel"

    class Dataset(StrEnum):
        CUB = "CUB"
        miniImagenet = auto()
        cross = auto()
        omniglot = auto()
        # emnist = auto()
        cross_char = auto()

    class Optim(StrEnum):
        adam = auto()
        sgd = auto()

    class Scheduler(StrEnum):
        none = auto()
        multisteplr = auto()
        cosine = auto()
        reducelronplateau = auto()

    class Split(StrEnum):
        novel = auto()
        base = auto()
        val = auto()

    @classmethod
    def list(cls, enum_cls: type[StrEnum]) -> list[str]:
        return [member.value for member in enum_cls]
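# Illustrative note (not in the original file): Arg.list flattens an enum into
# its plain string values, which is convenient for argparse-style "choices":
#
#     Arg.list(Arg.Optim)      == ["adam", "sgd"]
#     Arg.list(Arg.Method)[:3] == ["baseline", "baseline++", "DKT"]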
class ParamStruct(Tap):
    seed: int = 0
    "Seed for NumPy and PyTorch."
    dataset: Optional[Arg.Dataset] = None
    "The dataset used for training the model. Refer to Arg.Dataset for allowed values."
    model = Arg.Model.Conv4
    "The model used for prediction. Refer to Arg.Model for allowed values."
    # ResNet50 and ResNet101 are not used in the paper
    method: Optional[Arg.Method] = None
    "The method used in conjunction with the model. Refer to Arg.Method for allowed values."
    # relationnet_softmax replaces the L2 norm with softmax to expedite training;
    # maml_approx uses a first-order approximation of the gradient for efficiency
    train_n_way = 5
    "Number of classes per training episode (N-way)."
    # baseline and baseline++ ignore this parameter
    test_n_way = 5
    "Number of classes per testing (validation) episode."
    # baseline and baseline++ ignore this parameter
    n_shot = 5
    "Number of labeled examples per class (same as n_support)."
    # baseline and baseline++ only use this parameter during finetuning
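    # Illustrative note (not in the original file): a single episode draws
    # test_n_way classes, each contributing n_shot support examples plus
    # n_query query examples (n_query is set at runtime; see below). With the
    # defaults above and e.g. n_query = 16, one episode would contain
    # 5 * (5 + 16) = 105 images.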
    train_aug = False
    "Whether to perform data augmentation during training."
    checkpoint_suffix: str = ""
    "Suffix for differentiating custom experiments."
    # saved in save/checkpoints/[dataset]
    lr: float = 1e-3
    "Learning rate."
    optim = Arg.Optim.adam
    "Optimizer."
    n_val_perms: int = 1
    "Number of task permutations in evaluation."
    lr_scheduler = Arg.Scheduler.none
    "LR scheduler."
    milestones: list[int] | None = None
    "Milestones for multisteplr."
    maml_save_feature_network = False
    "Whether to save the feature network used in MAML."
    maml_adapt_classifier = False
    "Adapt only the classifier during the second gradient calculation."
    evaluate_model = False
    "Skip the training phase and perform the final test."
    # region train
    num_classes: int = 200
    "Total number of classes in the softmax, only used in baseline."
    # must be larger than the maximum label value in the base classes
    save_freq: int = 500
    "Save frequency."
    # TODO: pass to pl.Trainer, perhaps without defining a custom callback
    stop_epoch: Optional[int] = None
    "Stopping epoch."
    # for meta-learning methods, each epoch contains 100 episodes;
    # the default epoch number is dataset-dependent, see train.py
    resume: bool = False
    "Continue from the previously trained model with the largest epoch."
    warmup: bool = False
    "Continue from baseline; ignored if resume is true."
    # never used in the paper
    es_epoch = 250
    "Epoch at which to check whether the validation accuracy threshold has been reached; stop if not."
    es_threshold: float = 50.0
    "Validation accuracy threshold for early stopping."
    eval_freq = 1
    "Evaluation frequency."
    # endregion
    split = Arg.Split.novel
    "Which dataset split to use: base/val/novel."
    # default is novel, but you can also test base/val class accuracy if you want
    save_iter: Optional[int] = None
    "Save features from the model trained in epoch x; use the best model if x is None."
    adaptation = False
    "Whether to perform further adaptation at test time."
    repeat = 5
    "Repeat the test N times with different seeds and take the mean. The seed range is [seed, seed + repeat]."
    n_query: Optional[int] = None
    "By default, this parameter is computed at runtime based on n_way."
    args_file: Optional[Path] = None
    "Optional path to a .json file specifying the arguments of a previous run."
class ParamHolder(ParamStruct):
    def __init__(self):
        super().__init__()
        self.history: set[str] = set()

    # __getattr__ is only invoked when normal lookup fails, so it never saw
    # successful accesses; __getattribute__ intercepts every lookup instead.
    def __getattribute__(self, item):
        it = super().__getattribute__(item)
        if item.startswith("_") or item == "history":
            return it
        history = self.__dict__.get("history")
        if history is not None and item not in history:
            print("Getting", item, "=", it)
            history.add(item)
        return it

    def get_ignored_args(self) -> list[str]:
        return sorted(k for k in vars(self) if k not in self.history)
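# A hedged usage sketch (not part of the original file): parse CLI flags into
# a ParamHolder, then report which parameters were never read. Assumes the
# typed-argument-parser package ("tap") is installed; each first access is
# also echoed by the "Getting ..." log above.
if __name__ == "__main__":
    params = ParamHolder().parse_args()
    print("n-way / n-shot:", params.train_n_way, params.n_shot)
    print("Never accessed:", params.get_ignored_args())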