Skip to content

Commit

Permalink
Add unit test for frozen weights handling in 'Experiment'.
Browse files Browse the repository at this point in the history
  • Loading branch information
pandrey-fr committed May 30, 2023
1 parent 957bf56 commit b21b0af
Showing 1 changed file with 55 additions and 12 deletions.
67 changes: 55 additions & 12 deletions tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,11 +1274,10 @@ def test_experiment_28_agg_optimizer_updates(self):
torch_aggregate,
numpy_aggregate
)

global_model_collections = (
torch_global_model,
numpy_global_model
)
)
for aggregates, global_model in zip(aggregates_collection, global_model_collections):
# set up the Optimizer on Researcher side
lr = .12345
Expand All @@ -1304,14 +1303,58 @@ def test_experiment_28_agg_optimizer_updates(self):
for k, v in agg_updates.items():
self.assertTrue(np.isclose(agg_updates[k], aggregates[k]).all())

def test_experiment_29_agg_optimizer_updates_with_frozen_layers(self):
"""Test that the researcher-side optimize properly handles frozen weights."""
# Set up placeholder model weights, and a weights-getter function.
global_model = {
"layer_frozen_kernel": torch.randn((8,4)),
"layer_frozen_bias": torch.randn(4),
"layer_trainable_kernel": torch.randn((4,1)),
"layer_trainable_bias": torch.randn(1),
}
aggregates = {
"layer_frozen_kernel": global_model["layer_frozen_kernel"].clone(),
"layer_frozen_bias": global_model["layer_frozen_bias"].clone(),
"layer_trainable_kernel": torch.randn((4,1)),
"layer_trainable_bias": torch.randn(1),
}
def get_weights(only_trainable=False):
"""Access the model's weights, opt. restricted to trainable ones."""
if only_trainable:
return {
key: val.clone()
for key, val in global_model.items()
if "trainable" in key
}
return copy.deepcopy(global_model)
# Attach a mock Job with proper weight-access to the tested Experiment.
mock_job = create_autospec(Job, instance=True)
mock_job.training_plan.get_model_params.side_effect = get_weights
self.test_exp._job = mock_job
self.test_exp._global_model = global_model
# Set up an Optimizer, and expected results.
lrate = 0.8
self.test_exp.set_agg_optimizer(Optimizer(lr=lrate))
expected = {
key: (
(val - lrate * (val - aggregates[key]))
if "trainable" in key
else val
)
for key, val in global_model.items()
}
# Perform the optimization step and compare to expected results.
results = self.test_exp._run_agg_optimizer(aggregates)
for key, val in results.items():
self.assertTrue(np.isclose(val, expected[key]).all())

@patch('fedbiomed.researcher.aggregators.fedavg.FedAverage.aggregate')
@patch('fedbiomed.researcher.job.Job.training_plan', new_callable=PropertyMock)
@patch('fedbiomed.researcher.job.Job.training_replies', new_callable=PropertyMock)
@patch('fedbiomed.researcher.job.Job.start_nodes_training_round')
@patch('fedbiomed.researcher.job.Job.update_parameters')
@patch('fedbiomed.researcher.job.Job.__init__')
def test_experiment_29_strategy(self,
def test_experiment_30_strategy(self,
mock_job_init,
mock_job_updates_params,
mock_job_training,
Expand Down Expand Up @@ -1404,7 +1447,7 @@ def test_experiment_29_strategy(self,
self.test_exp.run_once()

@patch('fedbiomed.researcher.experiment.Experiment.run_once')
def test_experiment_30_run(self, mock_exp_run_once):
def test_experiment_31_run(self, mock_exp_run_once):
""" Testing run method of Experiment class """

def run_once_side_effect(increase, test_after=False):
Expand Down Expand Up @@ -1492,7 +1535,7 @@ def run_once_side_effect(increase, test_after=False):

@patch('builtins.open')
@patch('fedbiomed.researcher.job.Job.training_plan_file', new_callable=PropertyMock)
def test_experiment_31_training_plan_file(self,
def test_experiment_32_training_plan_file(self,
mock_training_plan_file,
mock_open):
""" Testing getter training_plan_file of the experiment class """
Expand Down Expand Up @@ -1533,7 +1576,7 @@ def test_experiment_31_training_plan_file(self,

@patch('fedbiomed.researcher.job.Job.__init__', return_value=None)
@patch('fedbiomed.researcher.job.Job.check_training_plan_is_approved_by_nodes')
def test_experiment_32_check_training_plan_status(self,
def test_experiment_33_check_training_plan_status(self,
mock_job_model_is_approved,
mock_job):
"""Testing method that checks model status """
Expand All @@ -1551,7 +1594,7 @@ def test_experiment_32_check_training_plan_status(self,
self.assertDictEqual(result, expected_approved_result,
'check_training_plan_status did not return expected value')

def test_experiment_33_breakpoint_raises(self):
def test_experiment_34_breakpoint_raises(self):
""" Testing the scenarios where the method breakpoint() raises error """

# Test if self._round_current is less than 1
Expand Down Expand Up @@ -1584,7 +1627,7 @@ def test_experiment_33_breakpoint_raises(self):
@patch('fedbiomed.researcher.experiment.choose_bkpt_file')
# testing _save_breakpoint + _save_aggregated_params
# (not exactly a unit test, but probably more interesting)
def test_experiment_34_save_breakpoint(
def test_experiment_35_save_breakpoint(
self,
patch_choose_bkpt_file,
patch_create_ul,
Expand Down Expand Up @@ -1746,7 +1789,7 @@ def training_plan(self):
# test load_breakpoint + _load_aggregated_params
# cannot test Experiment constructor, need to fake it
# (not exactly a unit test, but probably more interesting)
def test_experiment_35_static_load_breakpoint(self,
def test_experiment_36_static_load_breakpoint(self,
patch_find_breakpoint_path,
patch_training_plan
):
Expand Down Expand Up @@ -1958,7 +2001,7 @@ def load(self, aggreg, update_model):
self.assertTrue(loaded_exp.secagg.active)

@patch('fedbiomed.researcher.experiment.create_unique_file_link')
def test_experiment_36_static_save_aggregated_params(self,
def test_experiment_37_static_save_aggregated_params(self,
mock_create_unique_file_link):
"""Testing static private method of experiment for saving aggregated params"""

Expand Down Expand Up @@ -1988,7 +2031,7 @@ def test_experiment_36_static_save_aggregated_params(self,
agg_p = Experiment._save_aggregated_params(aggregated_params_init=agg_params, breakpoint_path='/')
self.assertDictEqual(agg_p, expected_agg_params, '_save_aggregated_params result is not as expected')

def test_experiment_37_static_load_aggregated_params(self):
def test_experiment_38_static_load_aggregated_params(self):
"""Testing static method for loading aggregated params of Experiment"""

# Test invalid type of aggregated params (should be dict)
Expand All @@ -2014,7 +2057,7 @@ def test_experiment_37_static_load_aggregated_params(self):
result = Experiment._load_aggregated_params(agg_params)
self.assertDictEqual(result, expected, '_load_aggregated_params did not return as expected')

def test_experiment_38_private_create_object(self):
def test_experiment_39_private_create_object(self):
"""tests `_create_object_ method :
Importing class, creating and initializing multiple objects from
breakpoint state for object and file containing class code
Expand Down

0 comments on commit b21b0af

Please sign in to comment.