test_base_training_plan.py
import unittest
from unittest.mock import patch, MagicMock
from typing import Any, Dict, Optional
import logging
import torch
import numpy as np
from fedbiomed.common.exceptions import FedbiomedError, FedbiomedTrainingPlanError
from fedbiomed.common.constants import ProcessTypes
from fedbiomed.common.training_plans._base_training_plan import BaseTrainingPlan # noqa
# Import the full module as well: it is needed to patch module-level names
# (such as ``open``) when testing save_code. Do not delete the line below.
import fedbiomed.common.training_plans._base_training_plan  # noqa
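

# Minimal concrete subclass of BaseTrainingPlan: every abstract method is
# stubbed out so that the behaviour of the base class can be tested in
# isolation.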
class SimpleTrainingPlan(BaseTrainingPlan):

    def training_routine(
            self,
            history_monitor: Optional['HistoryMonitor'] = None,
            node_args: Optional[Dict[str, Any]] = None
    ) -> None:
        pass

    def post_init(
            self,
            model_args: Dict[str, Any],
            training_args: Dict[str, Any]
    ) -> None:
        pass

    def model(self):
        pass

    def predict(
            self,
            data: Any,
    ) -> np.ndarray:
        pass

    def init_optimizer(self):
        pass


class TestBaseTrainingPlan(unittest.TestCase):
    """Test class for BaseTrainingPlan."""

    def setUp(self):
        self.tp = SimpleTrainingPlan()
        # Silence log output during the tests; logging.disable expects an
        # integer level, not the string 'CRITICAL'
        logging.disable(logging.CRITICAL)

    def tearDown(self) -> None:
        logging.disable(logging.NOTSET)

    def test_base_training_plan_01_add_dependency(self):
        """Test adding dependencies."""
        expected = ['from torch import nn']
        self.tp.add_dependency(expected)
        self.assertListEqual(expected, self.tp._dependencies, 'Cannot set dependency properly')

    def test_base_training_plan_02_set_dataset_path(self):
        """Test setting the dataset path."""
        expected = '/path/to/my/data.csv'
        self.tp.set_dataset_path(expected)
        self.assertEqual(expected, self.tp.dataset_path, 'Cannot set `dataset_path` properly')

    def test_base_training_plan_03_save_code(self):
        """Test the save_code method of BaseTrainingPlan."""
        expected_filepath = 'path/to/model.py'

        # Test without dependencies
        with patch.object(fedbiomed.common.training_plans._base_training_plan, 'open', MagicMock()) as mock_open:
            self.tp.save_code(expected_filepath)
            mock_open.assert_called_once_with(expected_filepath, "w", encoding="utf-8")

        # Test with added dependencies
        with patch.object(fedbiomed.common.training_plans._base_training_plan, 'open', MagicMock()) as mock_open:
            self.tp.add_dependency(['from fedbiomed.common.training_plans import TorchTrainingPlan'])
            self.tp.save_code(expected_filepath)
            mock_open.assert_called_once_with(expected_filepath, "w", encoding="utf-8")

        # Test that errors raised by get_class_source are wrapped
        with patch('fedbiomed.common.training_plans._base_training_plan.get_class_source') \
                as mock_get_class_source:
            mock_get_class_source.side_effect = FedbiomedError
            with self.assertRaises(FedbiomedTrainingPlanError):
                self.tp.save_code(expected_filepath)

        # Test that errors raised by the open function are wrapped
        with patch.object(fedbiomed.common.training_plans._base_training_plan, 'open', MagicMock()) as mock_open:
            mock_open.side_effect = OSError
            with self.assertRaises(FedbiomedTrainingPlanError):
                self.tp.save_code(expected_filepath)

            mock_open.side_effect = PermissionError
            with self.assertRaises(FedbiomedTrainingPlanError):
                self.tp.save_code(expected_filepath)

            mock_open.side_effect = MemoryError
            with self.assertRaises(FedbiomedTrainingPlanError):
                self.tp.save_code(expected_filepath)

    def test_base_training_plan_04_add_preprocess(self):
        def method(args):
            pass

        # Test raising error due to a wrong process type
        with self.assertRaises(FedbiomedTrainingPlanError):
            self.tp.add_preprocess(method, 'WrongType')

        # Test raising error due to a non-callable method argument
        with self.assertRaises(FedbiomedTrainingPlanError):
            self.tp.add_preprocess('not-callable', ProcessTypes.DATA_LOADER)

        # Test the proper scenario
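        # Registered pre-processes appear to be keyed by the callable's
        # __name__, hence the 'method' key checked below.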
        self.tp.add_preprocess(method, ProcessTypes.DATA_LOADER)
        self.assertTrue('method' in self.tp.pre_processes, 'add_preprocess could not add process properly')

    def test_base_training_plan_05_set_data_loaders(self):
        test_data_loader = [1, 2, 3]
        train_data_loader = [1, 2, 3]
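        # Plain lists stand in for actual DataLoader objects here:
        # set_data_loaders is expected to simply store whatever it is given.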
        self.tp.set_data_loaders(train_data_loader, test_data_loader)
        self.assertListEqual(self.tp.training_data_loader, train_data_loader)
        self.assertListEqual(self.tp.testing_data_loader, test_data_loader)

    def test_base_training_plan_06__create_metric_result(self):
        """Test the private method that creates the metric result dict.

        This test also implicitly covers the method
        _check_metric_types_is_int_or_float.
        """
        metric = 14
        result = BaseTrainingPlan._create_metric_result_dict(metric=metric, metric_name='Custom')
        self.assertDictEqual(result, {'Custom': 14})

        with self.assertRaises(FedbiomedTrainingPlanError):
            metric = True
            BaseTrainingPlan._create_metric_result_dict(metric=metric, metric_name='Custom')

        with self.assertRaises(FedbiomedTrainingPlanError):
            metric = 'True'
            BaseTrainingPlan._create_metric_result_dict(metric=metric, metric_name='Custom')

        metric = [14, 14, 14.5]
        result = BaseTrainingPlan._create_metric_result_dict(metric=metric, metric_name='Custom')
        self.assertDictEqual(result, {'Custom_1': 14, 'Custom_2': 14, 'Custom_3': 14.5})

        with self.assertRaises(FedbiomedTrainingPlanError):
            metric = ['14', '14', '14']
            BaseTrainingPlan._create_metric_result_dict(metric=metric, metric_name='Custom')

        metric = {'my_metric': 12, 'other_metric': 14.15}
        result = BaseTrainingPlan._create_metric_result_dict(metric=metric, metric_name='Custom')
        self.assertDictEqual(result, metric)

        with self.assertRaises(FedbiomedTrainingPlanError):
            metric = {'my_metric': 'True', 'other_metric': 14.15}
            BaseTrainingPlan._create_metric_result_dict(metric=metric, metric_name='Custom')

        # Testing torch tensors
        metric = torch.tensor(14)
        result = BaseTrainingPlan._create_metric_result_dict(metric=metric, metric_name='Custom')
        self.assertDictEqual(result, {'Custom': 14})

        metric = [torch.tensor(14), torch.tensor(14), torch.tensor(14)]
        result = BaseTrainingPlan._create_metric_result_dict(metric=metric, metric_name='Custom')
        self.assertDictEqual(result, {'Custom_1': 14, 'Custom_2': 14, 'Custom_3': 14})

        metric = {"m1": torch.tensor(14), "m2": torch.tensor(14)}
        result = BaseTrainingPlan._create_metric_result_dict(metric=metric, metric_name='Custom')
        self.assertDictEqual(result, {'m1': 14, 'm2': 14})

        metric = {"m1": torch.tensor(14.5), "m2": torch.tensor(14.5)}
        result = BaseTrainingPlan._create_metric_result_dict(metric=metric, metric_name='Custom')
        self.assertDictEqual(result, {'m1': 14.5, 'm2': 14.5})

        with self.assertRaises(FedbiomedTrainingPlanError):
            metric = {"m1": torch.tensor([14.5, 14.5]), "m2": torch.tensor([14.5, 14.5])}
            BaseTrainingPlan._create_metric_result_dict(metric=metric, metric_name='Custom')

        # Testing numpy arrays
        metric = np.array([14, 14, 14])
        result = BaseTrainingPlan._create_metric_result_dict(metric=metric, metric_name='Custom')
        self.assertDictEqual(result, {'Custom_1': 14, 'Custom_2': 14, 'Custom_3': 14})

        metric = np.array([14.5, 14.5, 14.5])
        result = BaseTrainingPlan._create_metric_result_dict(metric=metric, metric_name='Custom')
        self.assertDictEqual(result, {'Custom_1': 14.5, 'Custom_2': 14.5, 'Custom_3': 14.5})

        # np.float64 replaces the abstract np.floating, which recent NumPy
        # versions reject as a dtype argument
        metric = np.array([14.5, 14.5, 14.5], dtype=np.float64)
        result = BaseTrainingPlan._create_metric_result_dict(metric=metric, metric_name='Custom')
        self.assertDictEqual(result, {'Custom_1': 14.5, 'Custom_2': 14.5, 'Custom_3': 14.5})

        with self.assertRaises(FedbiomedTrainingPlanError):
            metric = {"m1": np.array([14.5, 14.5]), "m2": np.array([14.5, 14.5])}
            BaseTrainingPlan._create_metric_result_dict(metric=metric, metric_name='Custom')

        metric = np.float64(4.5)
        result = BaseTrainingPlan._create_metric_result_dict(metric=metric, metric_name='Custom')
        self.assertDictEqual(result, {'Custom': 4.5})
        self.assertIsInstance(result["Custom"], float)

    def test_base_training_plan_07_training_data(self):
        """Test that the training_data method raises an error."""
        # The training_data method must be defined by the user, which is why
        # the default implementation in BaseTrainingPlan raises an error.
        with self.assertRaises(FedbiomedTrainingPlanError):
            self.tp.training_data()

    def test_base_training_plan_08_get_learning_rate(self):
        pass

    def test_base_training_plan_09_infer_batch_size(self):
        """Test that the utility to infer batch size works correctly.

        Supported data types are:
        - torch tensor
        - numpy array
        - dict mapping {modality: tensor/array}
        - tuple or list containing the above
        """
        batch_size = 4
        tp = SimpleTrainingPlan()
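        # MagicMock(spec=...) yields mocks that pass isinstance checks for the
        # given type without allocating real tensors or arrays; only __len__
        # needs to behave realistically here.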
        # Test simple case: data is a tensor
        data = MagicMock(spec=torch.Tensor)
        data.__len__.return_value = batch_size
        self.assertEqual(tp._infer_batch_size(data), batch_size)

        # Test simple case: data is an array
        data = MagicMock(spec=np.ndarray)
        data.__len__.return_value = batch_size
        self.assertEqual(tp._infer_batch_size(data), batch_size)

        # Test complex case: data is a dict of tensors
        data = {
            'T1': MagicMock(spec=torch.Tensor)
        }
        data['T1'].__len__.return_value = batch_size
        self.assertEqual(tp._infer_batch_size(data), batch_size)

        # Test complex case: data is a list of dicts of arrays
        data = [
            {'T1': MagicMock(spec=np.ndarray)},
            {'T1': MagicMock(spec=np.ndarray)},
        ]
        data[0]['T1'].__len__.return_value = batch_size
        self.assertEqual(tp._infer_batch_size(data), batch_size)
        # Test random arbitrarily-nested data
        collection_types = ('list', 'tuple', 'dict')
        leaf_types = (torch.Tensor, np.ndarray)
        # np.random.randint excludes `high`, so `high` must equal len(...) for
        # every element to be selectable
        data_leaf_type = leaf_types[np.random.randint(low=0, high=len(leaf_types), size=1)[0]]
        data = MagicMock(spec=data_leaf_type)
        data.__len__.return_value = batch_size
        num_nesting_levels = np.random.randint(low=1, high=5, size=1)[0]
        nesting_types = list()  # for record-keeping purposes
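        # Each iteration wraps the current structure in a randomly chosen
        # container; _infer_batch_size must recurse down to the mocked leaf's
        # __len__ regardless of nesting depth.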
        for _ in range(num_nesting_levels):
            collection_type = collection_types[np.random.randint(low=0, high=len(collection_types), size=1)[0]]
            nesting_types.append(collection_type)
            if collection_type == 'list':
                data = [data, data]
            elif collection_type == 'tuple':
                data = (data, data)
            else:
                data = {'T1': data, 'T2': data}
        self.assertEqual(tp._infer_batch_size(data), batch_size,
                         f'Inferring batch size failed on arbitrary nested collection of {nesting_types[::-1]} '
                         f'and leaf type {data_leaf_type.__name__}')


if __name__ == '__main__':  # pragma: no cover
    unittest.main()