fixing unit tests wrt changes introduced in last commit
ybouilla committed Apr 18, 2023
1 parent f97c2b6 commit 5d29fdc
Showing 4 changed files with 100 additions and 53 deletions.
22 changes: 11 additions & 11 deletions fedbiomed/common/optimizers/generic_optimizers.py
@@ -92,17 +92,17 @@ def init_training(self):
"""
self._model.init_training()

def train_model(self,
inputs: Union[torch.Tensor, np.ndarray],
target: Union[torch.Tensor, np.ndarray],
**kwargs):
"""Performs a training of the model
Args:
inputs: inputs data
target: targeted data
"""
self._model.train(inputs, target, **kwargs)
# def train_model(self,
# inputs: Union[torch.Tensor, np.ndarray],
# target: Union[torch.Tensor, np.ndarray],
# **kwargs):
# """Performs a training of the model

# Args:
# inputs: inputs data
# target: targeted data
# """
# self._model.train(inputs, target, **kwargs)

@abstractmethod
def step(self):
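With `train_model` commented out of the optimizer wrapper, callers presumably drive training through the `Model` wrapper directly. A minimal usage sketch under that assumption — `optim_wrapper`, `model_wrapper`, `inputs`, and `target` are illustrative names, not identifiers from this commit:

    # Hypothetical sketch: training without the removed `train_model` helper,
    # assuming the Model wrapper still exposes `train()` and the optimizer
    # wrapper keeps `init_training()` and the abstract `step()` shown above.
    optim_wrapper.init_training()        # delegates to self._model.init_training()
    model_wrapper.train(inputs, target)  # replaces optim_wrapper.train_model(inputs, target)
    optim_wrapper.step()                 # framework-specific update, implemented by subclasses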
3 changes: 3 additions & 0 deletions tests/test_base_training_plan.py
@@ -35,6 +35,9 @@ def predict(
         data: Any,
     ) -> np.ndarray:
         pass
 
+    def init_optimizer(self):
+        pass
+
 
 class TestBaseTrainingPlan(unittest.TestCase):
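For context, `init_optimizer` appears to have become a required method of training plans, which is why this test double now stubs it. A minimal sketch of a concrete implementation under that assumption — the import path, `self._model`, and the SGD choice are illustrative, not taken from this commit:

    # Hypothetical concrete plan; attribute and import names are assumptions.
    import torch
    from fedbiomed.common.training_plans import BaseTrainingPlan

    class MyTrainingPlan(BaseTrainingPlan):
        def init_optimizer(self):
            # build the native framework optimizer for this plan's model
            self._optimizer = torch.optim.SGD(self._model.parameters(), lr=0.01)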
122 changes: 80 additions & 42 deletions tests/test_generic_optimizer.py
@@ -1,5 +1,6 @@
 import copy
 import random
+from typing import Any, Dict, List
 import unittest
 from unittest.mock import MagicMock, patch, Mock
 from fedbiomed.common.exceptions import FedbiomedOptimizerError
@@ -37,6 +38,15 @@ def setUp(self) -> None:

         self._torch_model_wrappers = (TorchModel(self._torch_model),)
         self._torch_zero_model_wrappers = (TorchModel(self._zero_model),)
 
+        self.modules = [
+            ScaffoldServerModule(),
+            GaussianNoiseModule(),
+            YogiMomentumModule(),
+            L2Clipping(),
+            AdaGradModule(),
+            YogiModule()]
+        self.regularizers = [FedProxRegularizer(), LassoRegularizer(), RidgeRegularizer()]
+
     def tearDown(self) -> None:
         return super().tearDown()
@@ -119,7 +129,7 @@ def test_declearnoptimizer_03_step_method_2_SklearnOptimizer(self):
             copy.deepcopy(model.get_weights()) for model in self._sklearn_model_wrappers
         ]
         fake_retrieved_grads = [NumpyVector(grads) + 1 for grads in fake_retrieved_grads]
-
+        # operation: do a SGD step with all gradients equal 1 and learning rate equals 1
         for sklearn_optim_wrapper, zero_model, grads in zip(initialized_sklearn_optim,
                                                             self._sklearn_model_wrappers,
                                                             fake_retrieved_grads):
@@ -130,7 +140,7 @@ def test_declearnoptimizer_03_step_method_2_SklearnOptimizer(self):

         for (l, val), (l_ref, val_ref) in zip(zero_model.get_weights().items(),
                                               sklearn_optim_wrapper._model.get_weights().items()):
-            self.assertTrue(np.all(val - 1 == val_ref))
+            self.assertTrue(np.all(val - 1 == val_ref))  # NOTA: all val values are equal 0
 
     def test_declearnoptimizer_04_get_learning_rate(self):
         learning_rate = .12345
@@ -153,47 +163,55 @@ def test_declearnoptimizer_04_get_learning_rate(self):

     def test_declearnoptimizer_05_aux_variables(self):
         learning_rate = .12345
-        empty_aux = {}
-        optim = FedOptimizer(lr=learning_rate, modules = [ScaffoldServerModule()])
-        optim.set_aux(empty_aux)
-
-        aux = { 'scaffold':
-               {
-                   'node-1': {'state': 1.},
-                   'node-2': {'state': 2.},
-                   'node-3': {'state': 3.}
-               }
-        }
-
-        optim.set_aux(aux)
-
-        collected_aux_vars = optim.get_aux()
-        expected_aux_vars = {
-            'scaffold':
-                {
-                    'node-1': {'delta': -1.},
-                    'node-2': {'delta': 0.},
-                    'node-3': {'delta': 1.}
-                }
-        }
-
-        self.assertDictEqual(collected_aux_vars, expected_aux_vars)
-
-        optim = FedOptimizer(lr=learning_rate, modules = [ScaffoldClientModule()])
-        aux = {'scaffold': {'delta': 1.}}
-        optim.set_aux(aux)
-        expected_aux = {'scaffold': {'state': 0.}}
-        collected_aux = optim.get_aux()
-        self.assertDictEqual(expected_aux, collected_aux)
+
+        for model_wrappers in (self._torch_model_wrappers, self._sklearn_model_wrappers):
+            for model in model_wrappers:
+                optim = FedOptimizer(lr=learning_rate, modules = [ScaffoldServerModule()])
+                optim_wrapper = DeclearnOptimizer(model, optim)
+                empty_aux = {}
+                optim_wrapper.set_aux(empty_aux)
+
+                self.assertDictEqual(optim_wrapper.get_aux(), {})
+                aux = { 'scaffold':
+                       {
+                           'node-1': {'state': 1.},
+                           'node-2': {'state': 2.},
+                           'node-3': {'state': 3.}
+                       }
+                }
+
+                optim_wrapper.set_aux(aux)
+
+                collected_aux_vars = optim_wrapper.get_aux()
+                expected_aux_vars = {
+                    'scaffold':
+                        {
+                            'node-1': {'delta': -1.},
+                            'node-2': {'delta': 0.},
+                            'node-3': {'delta': 1.}
+                        }
+                }
+
+                self.assertDictEqual(collected_aux_vars, expected_aux_vars)
+
+                optim = FedOptimizer(lr=learning_rate, modules = [ScaffoldClientModule()])
+                optim_wrapper = DeclearnOptimizer(model, optim)
+                aux = {'scaffold': {'delta': 1.}}
+                optim_wrapper.set_aux(aux)
+                expected_aux = {'scaffold': {'state': 0.}}
+                collected_aux = optim_wrapper.get_aux()
+                self.assertDictEqual(expected_aux, collected_aux)
 
     def test_declearnoptimizer_06_states(self):
-        def check_state(state, learning_rate, w_decay, modules, regs, model):
+        def check_state(state: Dict[str, Any], learning_rate: float, w_decay: float, modules: List, regs: List, model):
             self.assertEqual(state['config']['lrate'], learning_rate)
             self.assertEqual(state['config']['w_decay'], w_decay)
             self.assertListEqual(state['config']['regularizers'], [(reg.name, reg.get_config()) for reg in regs])
             self.assertListEqual(state['config']['modules'], [(mod.name , mod.get_config()) for mod in modules])
             new_optim_wrapper = DeclearnOptimizer.load_state(model, state)
             self.assertDictEqual(new_optim_wrapper.save_state(), state)
+            self.assertIsInstance(new_optim_wrapper.optimizer, FedOptimizer)
 
         learning_rate = .12345
         w_decay = .54321
@@ -206,15 +224,6 @@ def check_state(state, learning_rate, w_decay, modules, regs, model):
                 state = optim_wrapper.save_state()
 
                 check_state(state, learning_rate, w_decay, [], [], model)
 
-        self.modules = [
-            ScaffoldServerModule(),
-            GaussianNoiseModule(),
-            YogiMomentumModule(),
-            L2Clipping(),
-            AdaGradModule(),
-            YogiModule()]
-        self.regularizers = [FedProxRegularizer(), LassoRegularizer(), RidgeRegularizer()]
-
         nb_tests = 10  # number of times the following test will be executed
 
@@ -236,7 +245,32 @@ def check_state(state, learning_rate, w_decay, modules, regs, model):

     def test_declearnoptimizer_05_declearn_optimizers(self):
         # TODO: test here several declearn optimizers, regardless of the framework used
-        pass
+        nb_tests = 10  # number of times the following test will be executed
+        for _ in range(nb_tests):
+            for model_wrappers in (self._torch_model_wrappers, self._sklearn_model_wrappers):
+                for model in model_wrappers:
+
+                    # test DeclearnOptimizer with random modules and regularizers
+                    selected_modules = random.sample(self.modules, random.randint(0, len(self.modules)))
+                    selected_reg = random.sample(self.regularizers, random.randint(0, len(self.regularizers)))
+
+                    optim = FedOptimizer(lr=.01,
+                                         decay=.1,
+                                         modules=selected_modules,
+                                         regularizers=selected_reg)
+                    optim_wrapper = DeclearnOptimizer(model, optim)
+
+    def test_declearnoptimizer_06_scaffold_1_sklearnModel(self):
+        # test with one server and one node on a SklearnModel
+        researcher_optim = FedOptimizer(lr=.01, modules=[ScaffoldServerModule()])
+
+        node_optim = FedOptimizer(lr=.01, modules=[ScaffoldClientModule()])
+
+        for model_wrappers in (self._torch_model_wrappers, self._sklearn_model_wrappers):
+            for model in model_wrappers:
+                researcher_optim_wrapper = DeclearnOptimizer(model, researcher_optim)
+                node_optim_wrapper = DeclearnOptimizer(model, node_optim)


class TestTorchBasedOptimizer(unittest.TestCase):
# make sure torch based optimizers do the same action on torch models - regardless of their nature
@@ -341,9 +375,13 @@ def test_sklearnbasedoptimizer_01_get_learning_rate(self):
     def test_sklearnbasedoptimizer_02_step(self):
         pass
 
-    def test_sklearnbasedoptimizer_03_processing(self):
+    def test_sklearnbasedoptimizer_03_optimizer_processing(self):
         pass
 
+    def test_sklearnbasedoptimizer_04_invalid_method(self):
+        # test that zero_grad raises error if model is pytorch
+        torch_model = MagicMock(spec=TorchModel)
+        pass
+
 # class TestDeclearnTorchOptimizer(unittest.TestCase):
 
 #     def test_declearntorchoptimizer_01_zero_grad_error(self):
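The aux-variable expectations in `test_declearnoptimizer_05_aux_variables` above encode the Scaffold exchange: the server maps each node's state to a delta (the state minus the mean of all states, per the asserted values), and a client that has only received a delta reports a zero state. A condensed sketch of that round trip, assuming nothing beyond the `set_aux`/`get_aux` behaviour the test asserts:

    # Hypothetical standalone sketch of the Scaffold aux-variable round trip.
    server_optim = FedOptimizer(lr=.12345, modules=[ScaffoldServerModule()])
    server_optim.set_aux({'scaffold': {'node-1': {'state': 1.},
                                       'node-2': {'state': 2.},
                                       'node-3': {'state': 3.}}})
    deltas = server_optim.get_aux()
    # states 1., 2., 3. have mean 2., hence deltas -1., 0., +1. per node

    node_optim = FedOptimizer(lr=.12345, modules=[ScaffoldClientModule()])
    node_optim.set_aux({'scaffold': deltas['scaffold']['node-1']})
    node_optim.get_aux()   # {'scaffold': {'state': 0.}} before any local step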
6 changes: 6 additions & 0 deletions tests/testsupport/fake_training_plan.py
@@ -58,7 +58,13 @@ def load(self, path: str, to_params: bool):
             to_params (bool): originally, whether to return parameter into
                 the model or into a dictionary. Unused in this dummy class.
         """
+    def init_optimizer(self, optimizer_args: Dict[str, Any]):
+        """Fakes `init_optimizer` method, used to initialize an optimizer (either framework
+        specific like pytorch optimizer, or non-framework specific like declearn)
+        Args:
+            optimizer_args: optimizer parameters as a dictionary
+        """
 
     def save(self, filename: str, results: Dict[str, Any] = None):
         """
         Fakes `save` method of TrainingPlan classes, originally used for
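A short usage sketch of the new fake — the class name and argument values are illustrative, since only the method itself appears in this diff:

    # Hypothetical call site: the fake accepts and ignores optimizer arguments.
    plan = FakeTorchTrainingPlan()
    plan.init_optimizer(optimizer_args={'lr': 0.01})   # no-op in the fake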
