# test_scaffold.py -- forked from fedbiomed/fedbiomed (381 lines, 17.4 KB)
import os
import unittest
from unittest.mock import MagicMock, patch
from fedbiomed.common.exceptions import FedbiomedAggregatorError
from fedbiomed.common.optimizers.generic_optimizers import NativeTorchOptimizer
from fedbiomed.researcher.aggregators.fedavg import FedAverage
from fedbiomed.researcher.aggregators.functional import federated_averaging
from fedbiomed.researcher.datasets import FederatedDataSet
from fedbiomed.researcher.responses import Responses
from testsupport.fake_uuid import FakeUuid
import torch
from torch.nn import Linear
from fedbiomed.researcher.aggregators.scaffold import Scaffold
import copy
import random
from testsupport.base_case import ResearcherTestCase
class TestScaffold(ResearcherTestCase):
    """Unit tests for the Scaffold aggregator class."""

    # runs before each test
    def setUp(self):
        """Build a small linear model, a fake federation and fake replies."""
        self.model = Linear(10, 3)
        self.n_nodes = 4
        self.node_ids = [f'node_{i}' for i in range(self.n_nodes)]
        # one (empty) dataset entry per node
        self.fds = FederatedDataSet({node: {} for node in self.node_ids})
        # independent randomly-initialized model weights for every node
        self.models = {node_id: Linear(10, 3).state_dict() for node_id in self.node_ids}
        # model whose parameters are all set to zero
        self.zero_model = copy.deepcopy(self.model)
        for param in self.zero_model.parameters():
            param.data.fill_(0)
        # fake per-node training replies carrying an 'optimizer_args' learning rate
        replies = Responses([])
        for node_id in self.node_ids:
            replies.append({'node_id': node_id, 'optimizer_args': {'lr' : [.1]}})
        self.responses = Responses([replies])
        # one random aggregation weight per node
        self.weights = [{node_id: random.random()} for node_id in self.node_ids]
    # after the tests
    def tearDown(self):
        """No per-test cleanup is needed; defined for symmetry with setUp."""
        pass
def assert_coherent_states(self, agg: Scaffold) -> None:
"""Raise if the states of a Scaffold aggregator are not coherent.
In this context, coherence means that:
- the global state is the average of node-wise ones
- the delta variables match their definition (c_i - c)
"""
# Check that the global state is the average of local ones.
for key, val in agg.global_state.items():
avg = sum(states[key] for states in agg.nodes_states.values()) / len(agg.nodes_states)
self.assertTrue((val == avg).all())
# Check that delta variables match their definition.
for node_id in self.node_ids:
self.assertTrue(all(
(agg.nodes_deltas[node_id][key] == agg.nodes_states[node_id][key] - val).all()
for key, val in agg.global_state.items()
))
def test_1_init_correction_states(self):
"""Test that 'init_correction_states' works properly."""
agg = Scaffold(server_lr=1., fds=self.fds)
global_model = self.model.state_dict()
agg.init_correction_states(global_model)
# Check that the global state has proper keys and zero values.
self.assertEqual(agg.global_state.keys(), global_model.keys())
self.assertTrue(all(
(agg.global_state[key] == 0).all() for key in agg.global_state
))
# Check that node dicts have proper keys.
self.assertEqual(list(agg.nodes_states), self.node_ids)
self.assertEqual(list(agg.nodes_deltas), self.node_ids)
# Check that node-wise states and deltas have proper keys and values.
for node in self.node_ids:
self.assertEqual(agg.nodes_states[node].keys(), global_model.keys())
self.assertEqual(agg.nodes_deltas[node].keys(), global_model.keys())
self.assertTrue(all(
(agg.nodes_states[node][key] == 0).all() for key in global_model
))
self.assertTrue(all(
(agg.nodes_deltas[node][key] == 0).all() for key in global_model
))
def test_2_update_correction_state_all_nodes(self):
"""Test that 'update_correction_states' works properly.
Case: all nodes were sampled, all with zero-valued updates.
"""
# Instantiate a Scaffold aggregator and initialize its states.
agg = Scaffold(server_lr=1., fds=self.fds)
agg.init_correction_states(self.model.state_dict())
agg.nodes_lr = {k: [1] * self.n_nodes for k in self.node_ids}
# Test with zero-valued client models, i.e. updates equal to model.
agg.update_correction_states(
{node_id: self.model.state_dict() for node_id in self.node_ids},
n_updates=1,
)
# Check that the local states were properly updated.
for node_id in self.node_ids:
self.assertTrue(all(
(agg.nodes_states[node_id][key] == val).all()
for key, val in self.model.state_dict().items()
))
self.assert_coherent_states(agg)
def test_3_update_correction_state_single_node(self):
"""Test that 'update_correction_states' works properly.
Case: a single node was sampled, with random-valued updates.
"""
# Instantiate a Scaffold aggregator and initialize its states.
agg = Scaffold(server_lr=1., fds=self.fds)
agg.init_correction_states(self.model.state_dict())
agg.nodes_lr = {k: [1] * self.n_nodes for k in self.node_ids}
# Test when a single client has non-zero-updates after 4 steps.
updates = {
key: torch.rand_like(val)
for key, val in self.model.state_dict().items()
}
agg.update_correction_states({self.node_ids[0]: updates}, n_updates=4)
# Check that this client's local state was properly updated.
# Note that the previous delta is zero.
self.assertTrue(all(
(agg.nodes_states[self.node_ids[0]][key] == (val / 4.0)).all()
for key, val in updates.items()
))
# Check that other clients' local state was left unaltered.
for node_id in self.node_ids[1:]:
self.assertTrue(all(
(agg.nodes_states[node_id][key] == 0.).all()
for key in self.model.state_dict()
))
# Check that the global state and deltas were properly updated.
self.assert_coherent_states(agg)
def test_4_aggregate(self):
"""Test that 'aggregate' works properly."""
training_plan = MagicMock()
training_plan.get_model_params = MagicMock(return_value = self.node_ids)
agg = Scaffold(server_lr=.2, fds=self.fds)
n_round = 0
weights = {node_id: 1./self.n_nodes for node_id in self.node_ids}
# assuming that global model has all its coefficients to 0
aggregated_model_params_scaffold = agg.aggregate(
model_params=copy.deepcopy(self.models),
weights=weights,
global_model=copy.deepcopy(self.zero_model.state_dict()),
training_plan=training_plan,
training_replies=self.responses,
n_round=n_round
)
aggregated_model_params_fedavg = FedAverage().aggregate(
copy.deepcopy(self.models), weights
)
# we check that fedavg and scaffold give proportional results provided:
# - all previous correction state model are set to 0 (round 0)
# - model proportions are the same
# then:
# fedavg: x_i <- x_i / n_nodes
# scaffold: x_i <- server_lr * x_i / n_nodes
for v_s, v_f in zip(
aggregated_model_params_scaffold.values(),
aggregated_model_params_fedavg.values()
):
self.assertTrue(torch.isclose(v_s, v_f * .2).all())
# check that at the end of aggregation, all correction states are non zeros (
for deltas in agg.nodes_deltas.values():
for layer in deltas.values():
self.assertFalse(torch.nonzero(layer).all())
def test_5_setting_scaffold_with_wrong_parameters(self):
"""test_5_setting_scaffold_with_wrong_parameters: tests that scaffold is
returning an error when set with incorrect parameters
"""
# test 1: `server_lr` should be different than 0
for x in (0, 0.):
with self.assertRaises(FedbiomedAggregatorError):
Scaffold(server_lr = x)
# test 2: calling `init_correction_states` without any federated dataset
with self.assertRaises(FedbiomedAggregatorError):
scaffold = Scaffold()
scaffold.init_correction_states(self.model.state_dict())
# test 3: `n_updates` should be a positive and non zero integer
training_plan = MagicMock()
for x in (-1, .2, 0, 0., -3.2):
with self.assertRaises(FedbiomedAggregatorError):
scaffold = Scaffold()
scaffold.check_values(n_updates=x, training_plan=training_plan)
# test 4: `FederatedDataset` has not been specified
with self.assertRaises(FedbiomedAggregatorError):
scaffold = Scaffold()
scaffold.check_values(n_updates=1, training_plan=training_plan)
with self.assertRaises(FedbiomedAggregatorError):
scaffold = Scaffold()
scaffold.check_values(n_updates=None, training_plan=training_plan)
def test_6_create_aggregator_args(self):
agg = Scaffold(fds=self.fds)
agg_thr_msg, agg_thr_file = agg.create_aggregator_args(self.model.state_dict(),
self.node_ids)
for node_id in self.node_ids:
for (k, v), (k0, v0) in zip(agg.nodes_deltas[node_id].items(),
self.zero_model.state_dict().items()):
self.assertTrue(torch.isclose(v, v0).all())
# check that each element returned by method contains key 'aggregator_name'
for node_id in self.node_ids:
self.assertTrue(agg_thr_msg[node_id].get('aggregator_name', False))
self.assertTrue(agg_thr_file[node_id].get('aggregator_name', False))
# check `agg_thr_file` contains node correction state
for node_id in self.node_ids:
self.assertDictEqual(agg_thr_file[node_id]['aggregator_correction'], agg.nodes_deltas[node_id])
# checking case where a node has been added to the training (repeating same tests above)
self.n_nodes += 1
self.node_ids.append(f'node_{self.n_nodes}')
self.fds.data()[f'node_{self.n_nodes}'] = {}
agg_thr_msg, agg_thr_file = agg.create_aggregator_args(self.model.state_dict(),
self.node_ids)
for node_id in self.node_ids:
self.assertTrue(agg_thr_msg[node_id].get('aggregator_name', False))
self.assertTrue(agg_thr_file[node_id].get('aggregator_name', False))
# check `agg_thr_file` contains node correction state
for node_id in self.node_ids:
self.assertDictEqual(agg_thr_file[node_id]['aggregator_correction'], agg.nodes_deltas[node_id])
@patch('uuid.uuid4')
def test_7_save_state(self, uuid_patch):
uuid_patch.return_value = FakeUuid()
server_lr = .5
fds = FederatedDataSet({node_id: {} for node_id in self.node_ids})
bkpt_path = '/path/to/my/breakpoint'
scaffold = Scaffold(server_lr, fds=fds)
scaffold.init_correction_states(self.model.state_dict())
with patch("fedbiomed.common.serializer.Serializer.dump") as save_patch:
state = scaffold.save_state(breakpoint_path=bkpt_path, global_model=self.model.state_dict())
self.assertEqual(save_patch.call_count, self.n_nodes + 1,
f"'Serializer.dump' should be called {self.n_nodes} times: once for each node + \
one more time for global_state")
for node_id in self.node_ids:
self.assertEqual(state['parameters']['aggregator_correction'][node_id],
os.path.join(bkpt_path, 'aggregator_correction_' + str(node_id) + '.mpk'))
self.assertEqual(state['parameters']['server_lr'], server_lr)
self.assertEqual(state['parameters']['global_state_filename'], os.path.join(bkpt_path,
'global_state_'
+ str(FakeUuid.VALUE) + '.mpk'))
self.assertEqual(state['class'], Scaffold.__name__)
self.assertEqual(state['module'], Scaffold.__module__)
def test_8_load_state(self):
"""Test that 'load_state' triggers the proper amount of calls."""
server_lr = .5
fds = FederatedDataSet({node_id: {} for node_id in self.node_ids})
bkpt_path = '/path/to/my/breakpoint'
scaffold = Scaffold(server_lr, fds=fds)
# create a state (not actually saving the associated contents)
with patch("fedbiomed.common.serializer.Serializer.dump"):
state = scaffold.save_state(
breakpoint_path=bkpt_path, global_model=self.model.state_dict()
)
# action
with patch("fedbiomed.common.serializer.Serializer.load") as load_patch:
scaffold.load_state(state)
self.assertEqual(load_patch.call_count, self.n_nodes + 1,
f"'Serializer.load' should be called {self.n_nodes} times: once for each node + \
one more time for global_state")
def test_9_load_state_2(self):
"""Test that 'load_state' properly assigns loaded values."""
server_lr = .5
fds = FederatedDataSet({node_id: {} for node_id in self.node_ids})
bkpt_path = '/path/to/my/breakpoint'
scaffold = Scaffold(server_lr, fds=fds)
# create a state (not actually saving the associated contents)
with patch("fedbiomed.common.serializer.Serializer.dump"):
state = scaffold.save_state(
breakpoint_path=bkpt_path, global_model=self.model.state_dict()
)
# action
with patch(
"fedbiomed.common.serializer.Serializer.load",
return_value=self.model.state_dict()
):
scaffold.load_state(state)
# tests
for node_id in self.node_ids:
for (k,v), (k_ref, v_ref) in zip(scaffold.nodes_deltas[node_id].items(),
self.model.state_dict().items()):
self.assertTrue(torch.isclose(v, v_ref).all())
for (k, v), (k_0, v_0) in zip(scaffold.global_state.items(),
self.model.state_dict().items()):
self.assertTrue(torch.isclose(v, v_0).all())
self.assertEqual(scaffold.server_lr, server_lr)
def test_10_set_nodes_learning_rate_after_training(self):
n_rounds = 3
# test case were learning rates change from one layer to another
lr = [.1,.2,.3]
n_model_layer = len(lr) # number of layers model contains
training_replies = {r:
Responses( [{'node_id': node_id, 'optimizer_args': {'lr': lr}}
for node_id in self.node_ids])
for r in range(n_rounds)}
#assert n_model_layer == len(lr), "error in test: n_model_layer must be equal to the length of list of learning rate"
training_plan = MagicMock()
get_model_params_mock = MagicMock()
get_model_params_mock.__len__ = MagicMock(return_value=n_model_layer)
training_plan.get_model_params.return_value = get_model_params_mock
fds = FederatedDataSet({node_id: {} for node_id in self.node_ids})
scaffold = Scaffold(fds=fds)
for n_round in range(n_rounds):
node_lr = scaffold.set_nodes_learning_rate_after_training(training_plan=training_plan,
training_replies=training_replies,
n_round=n_round)
test_node_lr = {node_id: lr for node_id in self.node_ids}
self.assertDictEqual(node_lr, test_node_lr)
# same test with a mix of nodes present in training_replies and non present
fds = FederatedDataSet({node_id: {} for node_id in self.node_ids + ['node_99']})
optim_w = MagicMock(spec=NativeTorchOptimizer)
optim_w.get_learning_rate = MagicMock(return_value=lr)
training_plan.optimizer = MagicMock(return_value=optim_w)
#training_plan.get_learning_rate = MagicMock(return_value=lr)
scaffold = Scaffold(fds=fds)
for n_round in range(n_rounds):
node_lr = scaffold.set_nodes_learning_rate_after_training(training_plan=training_plan,
training_replies=training_replies,
n_round=n_round)
# test case where len(lr) != n_model_layer
lr += [.333]
training_plan.get_learning_rate = MagicMock(return_value=lr)
for n_round in range(n_rounds):
with self.assertRaises(FedbiomedAggregatorError):
scaffold.set_nodes_learning_rate_after_training(training_plan=training_plan,
training_replies=training_replies,
n_round=n_round)
# TODO:
# ideas for further tests:
# test 1: check that with one client only, correction terms are zeros
# test 2: check that for 2 clients, correction terms have opposite values
# Run the test suite when this file is executed directly.
if __name__ == '__main__':  # pragma: no cover
    unittest.main()