diff --git a/fedbiomed/common/training_plans/_base_training_plan.py b/fedbiomed/common/training_plans/_base_training_plan.py
index 73beaca0f..4d055560a 100644
--- a/fedbiomed/common/training_plans/_base_training_plan.py
+++ b/fedbiomed/common/training_plans/_base_training_plan.py
@@ -14,7 +14,9 @@
 from fedbiomed.common import utils
 from fedbiomed.common.constants import ErrorNumbers, ProcessTypes
 from fedbiomed.common.data import NPDataLoader
-from fedbiomed.common.exceptions import FedbiomedError, FedbiomedTrainingPlanError
+from fedbiomed.common.exceptions import (
+    FedbiomedError, FedbiomedModelError, FedbiomedTrainingPlanError
+)
 from fedbiomed.common.logger import logger
 from fedbiomed.common.metrics import Metrics, MetricTypes
 from fedbiomed.common.models import Model
@@ -532,3 +534,50 @@ def after_training_params(self) -> Dict[str, Any]:
         The trained parameters to aggregate.
         """
         return self.get_model_params()
+
+    def export_model(self, filename: str) -> None:
+        """Export the wrapped model to a dump file.
+
+        Args:
+            filename: Path to the file where the model will be saved.
+
+        !!! info "Notes"
+            This method is designed to save the model to a local dump
+            file for easy re-use by the same user, possibly outside of
+            Fed-BioMed. It is not designed to produce trustworthy data
+            dumps and is not used to exchange models and their weights
+            as part of the federated learning process.
+
+            To save the model parameters for sharing as part of the FL process,
+            use the `after_training_params` method (or the `get_model_params`
+            method outside of a training context) and export results using
+            [`Serializer`][fedbiomed.common.serializer.Serializer].
+        """
+        self._model.export(filename)
+
+    def import_model(self, filename: str) -> None:
+        """Import and replace the wrapped model from a dump file.
+
+        Args:
+            filename: Path to the file where the model has been exported.
+
+        !!! info "Notes"
+            This method is designed to load the model from a local dump
+            file, which might not be in a trustworthy format. It should
+            therefore only be used to re-load data exported locally, not
+            data received from someone else, including other FL peers.
+
+            To load model parameters shared as part of the FL process, use the
+            [`Serializer`][fedbiomed.common.serializer.Serializer] to read the
+            network-exchanged file, and the `set_model_params` method to assign
+            the loaded values into the wrapped model.
+        """
+        try:
+            self._model.reload(filename)
+        except FedbiomedModelError as exc:
+            msg = (
+                f"{ErrorNumbers.FB304.value}: failed to import a model from "
+                f"a dump file: {exc}"
+            )
+            logger.critical(msg)
+            raise FedbiomedTrainingPlanError(msg) from exc
diff --git a/fedbiomed/common/training_plans/_sklearn_training_plan.py b/fedbiomed/common/training_plans/_sklearn_training_plan.py
index c9ac22ff6..0a8b6beab 100644
--- a/fedbiomed/common/training_plans/_sklearn_training_plan.py
+++ b/fedbiomed/common/training_plans/_sklearn_training_plan.py
@@ -16,7 +16,7 @@
 from fedbiomed.common.constants import ErrorNumbers, TrainingPlans
 from fedbiomed.common.data import NPDataLoader
-from fedbiomed.common.exceptions import FedbiomedModelError, FedbiomedTrainingPlanError
+from fedbiomed.common.exceptions import FedbiomedTrainingPlanError
 from fedbiomed.common.logger import logger
 from fedbiomed.common.metrics import MetricTypes
 from fedbiomed.common.models import SkLearnModel
@@ -42,6 +42,11 @@ class SKLearnTrainingPlan(BaseTrainingPlan, metaclass=ABCMeta):
         training data at the beginning of the training routine.
        training_data_loader: Data loader used in the training routine.
        testing_data_loader: Data loader used in the validation routine.
+
+    !!! info "Notes"
+        The trained model may be exported via the `export_model` method,
+        resulting in a dump file that may be reloaded using `joblib.load`
+        outside of Fed-BioMed.
     """
 
     _model_cls: Type[BaseEstimator]  # wrapped model class
@@ -266,83 +271,6 @@ def _classes_from_concatenated_train_test(self) -> np.ndarray:
         """
         return np.unique([t for loader in (self.training_data_loader, self.testing_data_loader) for d, t in loader])
 
-    def save(
-        self,
-        filename: str,
-        params: Union[None, Dict[str, np.ndarray], Dict[str, Any]] = None
-    ) -> None:
-        """Save the wrapped model and its trainable parameters.
-
-        This method is designed for parameter communication. It
-        uses the joblib.dump function, which in turn uses pickle
-        to serialize the model. Note that unpickling objects can
-        lead to arbitrary code execution; hence use with care.
-
-        Args:
-            filename: Path to the output file.
-            params: Model parameters to enforce and save.
-                This may either be a {name: array} parameters dict, or a
-                nested dict that stores such a parameters dict under the
-                'model_params' key (in the context of the Round class).
-
-        Notes:
-            Save can be called from Job or Round.
-            * From [`Round`][fedbiomed.node.round.Round] it is called with params (as a complex dict).
-            * From [`Job`][fedbiomed.researcher.job.Job] it is called with no params in constructor, and
-              with params in update_parameters.
-        """
-        # Optionally overwrite the wrapped model's weights.
-        if params:
-            if isinstance(params.get('model_params'), dict):  # in a Round
-                params = params["model_params"]
-            # for key, val in params.items():
-            #     setattr(self._model, key, val)
-            self._model.set_weights(params)
-        # Save the wrapped model (using joblib, hence pickle).
-        self._model.export(filename)
-
-    def load(
-        self,
-        filename: str,
-        to_params: bool = False
-    ) -> Union[BaseEstimator, Dict[str, Dict[str, np.ndarray]]]:
-        """Load a scikit-learn model dump, overwriting the wrapped model.
-
-        This method uses the joblib.load function, which in turn uses
-        pickle to deserialize the model. Note that unpickling objects
-        can lead to arbitrary code execution; hence use with care.
-
-        This function updates the `_model` private attribute with the
-        loaded instance, and returns either that same model or a dict
-        wrapping its trainable parameters.
-
-        Args:
-            filename: The path to the pickle file to load.
-            to_params: Whether to return the model's parameters
-                wrapped as a dict rather than the model instance.
-
-        Notes:
-            Load can be called from a Job or Round:
-            * From [`Round`][fedbiomed.node.round.Round] it is called to return the model.
-            * From [`Job`][fedbiomed.researcher.job.Job] it is called with to return its parameters dict.
-
-        Returns:
-            Dictionary with the loaded parameters.
-        """
-        # Deserialize the dump, type-check the instance and assign it.
-        try:
-            self._model.reload(filename)
-        except FedbiomedModelError as exc:
-            msg = f"{ErrorNumbers.FB304.value}: failed to reload wrapped model: {exc}"
-            logger.critical(msg)
-            raise FedbiomedTrainingPlanError(msg) from exc
-
-        # Optionally return the model's pseudo state dict instead of it.
-        if to_params:
-            params = self._model.get_weights()
-            return {"model_params": params}
-        return self.model()
-
     def type(self) -> TrainingPlans:
         """Getter for training plan type """
         return self.__type
diff --git a/fedbiomed/common/training_plans/_torchnn.py b/fedbiomed/common/training_plans/_torchnn.py
index 8e8c7fec8..30621e757 100644
--- a/fedbiomed/common/training_plans/_torchnn.py
+++ b/fedbiomed/common/training_plans/_torchnn.py
@@ -31,12 +31,11 @@ class TorchTrainingPlan(BaseTrainingPlan, metaclass=ABCMeta):
     An abstraction over pytorch module to run pytorch models and scripts on node side.
 
     Researcher model (resp. params) will be:
-    1. saved on a '*.py' (resp. '*.pt') files,
+    1. saved to '*.py' (resp. '*.mpk') files,
     2. uploaded on a HTTP server (network layer),
     3. then Downloaded from the HTTP server on node side,
     4. finally, read and executed on node side.
-
     Researcher must define/override:
     - a `training_data()` function
     - a `training_step()` function
@@ -52,6 +51,11 @@ class TorchTrainingPlan(BaseTrainingPlan, metaclass=ABCMeta):
         correction_state: an OrderedDict of {'parameter name': torch.Tensor} where the keys correspond to
             the names of the model parameters contained in self._model.named_parameters(), and the values
             correspond to the correction to be applied to that parameter.
+
+    !!! info "Notes"
+        The trained model may be exported via the `export_model` method,
+        resulting in a dump file that may be reloaded using `torch.load`
+        outside of Fed-BioMed.
     """
 
     def __init__(self):
@@ -575,36 +579,6 @@ def testing_routine(
         finally:
             self.model().train()  # restore training behaviors
 
-    # provided by fedbiomed
-    def save(self, filename: str, params: dict = None) -> None:
-        """Save the torch training parameters from this training plan or from given `params` to a file
-
-        Args:
-            filename (str): Path to the destination file
-            params (dict): Parameters to save to a file, should be structured as a torch state_dict()
-
-        """
-        if params is not None:
-            return torch.save(params, filename)
-        else:
-            return self._model.export(filename)
-
-    # provided by fedbiomed
-    def load(self, filename: str, to_params: bool = False) -> dict:
-        """Load the torch training parameters to this training plan or to a data structure from a file
-
-        Args:
-            filename: path to the source file
-            to_params: if False, load params to this pytorch object; if True load params to a data structure
-
-        Returns:
-            Contains parameters
-        """
-        params = torch.load(filename)
-        if to_params is False:
-            self._model.reload(filename)
-        return params
-
     def set_aggregator_args(self, aggregator_args: Dict[str, Any]):
         """Handles and loads aggregators arguments sent through MQTT and file
         exchanged system. If sent through file exchanged system, loads the arguments.
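Taken together, the training-plan changes above replace the pickle-based `save`/`load` pair with an explicit split between local model dumps (`export_model`/`import_model`) and FL parameter exchange (`after_training_params` plus `Serializer`). The following is a minimal usage sketch, not part of the diff: `MyTrainingPlan` is a hypothetical, already-initialized training plan subclass, and only the methods named in this diff are used.

```python
# Sketch only: `MyTrainingPlan` is a hypothetical TorchTrainingPlan subclass,
# assumed to be defined and initialized by the user.
from fedbiomed.common.serializer import Serializer

plan = MyTrainingPlan()  # hypothetical user-defined training plan

# Local, single-user persistence (pickle/torch-based dump; not a trustworthy
# input format, so only reload files you exported yourself):
plan.export_model("my_model_dump")
plan.import_model("my_model_dump")

# FL-process parameter exchange (MessagePack files via Serializer):
Serializer.dump(plan.after_training_params(), "params.mpk")
plan.set_model_params(Serializer.load("params.mpk"))
```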
diff --git a/fedbiomed/researcher/aggregators/scaffold.py b/fedbiomed/researcher/aggregators/scaffold.py index 990e4c48b..494765c7d 100644 --- a/fedbiomed/researcher/aggregators/scaffold.py +++ b/fedbiomed/researcher/aggregators/scaffold.py @@ -10,9 +10,11 @@ import numpy as np import torch + from fedbiomed.common.logger import logger from fedbiomed.common.constants import TrainingPlans from fedbiomed.common.exceptions import FedbiomedAggregatorError +from fedbiomed.common.serializer import Serializer from fedbiomed.common.training_plans import BaseTrainingPlan from fedbiomed.researcher.aggregators.aggregator import Aggregator @@ -96,7 +98,7 @@ def __init__(self, server_lr: float = 1., fds: Optional[FederatedDataSet] = None self.nodes_lr: Dict[str, List[float]] = {} if fds is not None: self.set_fds(fds) - + self._aggregator_args = {} # we need `_aggregator_args` to be not None #self.update_aggregator_params()FedbiomedAggregatorError: @@ -153,7 +155,7 @@ def aggregate(self, # Compute the new aggregated model parameters. aggregated_parameters = self.scaling(model_params, global_model) - + # At round 0, initialize zero-valued correction states. if n_round == 0: self.init_correction_states(global_model, node_ids) @@ -175,7 +177,7 @@ def create_aggregator_args(self, Returns: Tuple[Dict, Dict]: first dictionary contains parameters that will be sent through MQTT message service, second dictionary parameters that will be sent through file exchange message. - Aggregators args are dictionary mapping node_id to SCAFFOLD parameters specific to + Aggregators args are dictionary mapping node_id to SCAFFOLD parameters specific to each `Nodes`. """ @@ -208,7 +210,7 @@ def check_values(self, n_updates: int, training_plan: BaseTrainingPlan) -> True: Args: n_updates (int): number of updates. Must be non-zero and an integer. - training_plan (BaseTrainingPlan): training plan. used for checking if optimizer is SGD, otherwise, + training_plan (BaseTrainingPlan): training plan. used for checking if optimizer is SGD, otherwise, triggers warning. Raises: @@ -217,7 +219,7 @@ def check_values(self, n_updates: int, training_plan: BaseTrainingPlan) -> True: FedbiomedAggregatorError: triggered if number of updates equals 0 or is not an integer FedbiomedAggregatorError: triggered if [FederatedDataset][fedbiomed.researcher.datasets.FederatedDataset] has not been set. - + """ if n_updates is None: raise FedbiomedAggregatorError("Cannot perform Scaffold: missing 'num_updates' entry in the training_args") @@ -353,8 +355,8 @@ def update_correction_states(self, eta_l: nodes' learning rate x: global model before updates y_i: local model updates - - Remark: + + Remark: c^{t=0} = 0 Args: @@ -374,7 +376,7 @@ def update_correction_states(self, total_nb_nodes = len(self._fds.node_ids()) # Compute the node-wise average of corrected gradients (ACG_i). # i.e. 
(x^t - y_i^t}) / (K * eta_l)
-        local_state_updates: Dict[str, Mapping[str, Union[torch.Tensor, np.ndarray]]] = {}
+        local_state_updates: Dict[str, Mapping[str, Union[torch.Tensor, np.ndarray]]] = {}
         for node_id, params in local_models.items():
             local_state_updates[node_id] = {
                 key: (global_model[key] - val) / (self.nodes_lr[node_id][idx] * n_updates)
@@ -445,10 +447,10 @@ def save_state(
     ) -> Dict[str, Any]:
         # adding aggregator parameters to the breakpoint that wont be sent to nodes
         self._aggregator_args['server_lr'] = self.server_lr
-
+
         # saving global state variable into a file
-        filename = os.path.join(breakpoint_path, 'global_state_' + str(uuid.uuid4()) + '.pt')
-        training_plan.save(filename, self.global_state)
+        filename = os.path.join(breakpoint_path, f"global_state_{uuid.uuid4()}.mpk")
+        Serializer.dump(self.global_state, filename)
         self._aggregator_args['global_state_filename'] = filename
         # adding aggregator parameters that will be sent to nodes afterwards
         return super().save_state(
@@ -462,7 +464,7 @@ def load_state(self, state: Dict[str, Any] = None):
 
         # loading global state
         global_state_filename = self._aggregator_args['global_state_filename']
-        self.global_state = training_plan.load(global_state_filename, to_params=True)
+        self.global_state = Serializer.load(global_state_filename)
 
         for node_id in self._aggregator_args['aggregator_correction'].keys():
             arg_filename = self._aggregator_args['aggregator_correction'][node_id]
diff --git a/fedbiomed/researcher/job.py b/fedbiomed/researcher/job.py
index aaa1730d3..f70f95c43 100644
--- a/fedbiomed/researcher/job.py
+++ b/fedbiomed/researcher/job.py
@@ -534,8 +534,9 @@ def update_parameters(
             sys.exit(-1)
 
     def save_state(self, breakpoint_path: str) -> dict:
-        """Creates current state of the job to be included in a breakpoint. Includes creating links to files included
-        in the job state.
+        """Creates current state of the job to be included in a breakpoint.
+
+        Includes creating links to files included in the job state.
 
         Args:
             breakpoint_path: path to the existing breakpoint directory
@@ -554,15 +555,15 @@
         }
 
         state['model_params_path'] = create_unique_link(
-            breakpoint_path,
-            'aggregated_params_current', '.pt',
+            breakpoint_path, 'aggregated_params_current', '.mpk',
             os.path.join('..', os.path.basename(state["model_params_path"]))
         )
 
         for round_replies in state['training_replies']:
             for response in round_replies:
-                node_params_path = create_unique_file_link(breakpoint_path,
-                                                           response['params_path'])
+                node_params_path = create_unique_file_link(
+                    breakpoint_path, response['params_path']
+                )
                 response['params_path'] = node_params_path
 
         return state
@@ -580,12 +581,11 @@ def load_state(self, saved_state: Dict[str, Any]) -> None:
         self.update_parameters(filename=saved_state.get("model_params_path"))
         # Reload the latest training replies.
self._training_replies = self._load_training_replies( - saved_state.get('training_replies'), - self._training_plan.load + saved_state.get('training_replies') ) @staticmethod - def _save_training_replies(training_replies: Dict[int, Responses]) -> List[List[dict]]: + def _save_training_replies(training_replies: Dict[int, Responses]) -> List[List[Dict[str, Any]]]: """Extracts a copy of `training_replies` and prepares it for saving in breakpoint - strip unwanted fields @@ -609,13 +609,11 @@ def _save_training_replies(training_replies: Dict[int, Responses]) -> List[List[ return converted_training_replies @staticmethod - def _load_training_replies(bkpt_training_replies: List[List[dict]], - func_load_params: Callable) -> Dict[int, Responses]: + def _load_training_replies(bkpt_training_replies: List[List[dict]]) -> Dict[int, Responses]: """Reads training replies from a formatted breakpoint file, and build a job training replies data structure . Args: bkpt_training_replies: Extract from training replies saved in breakpoint - func_load_params: Function for loading parameters from file to training replies data structure Returns: Training replies of already executed rounds of the job @@ -626,9 +624,7 @@ def _load_training_replies(bkpt_training_replies: List[List[dict]], loaded_training_reply = Responses(bkpt_training_replies[round_]) # reload parameters from file params_path for node in loaded_training_reply: - node['params'] = func_load_params( - node['params_path'], to_params=True)['model_params'] - + node['params'] = Serializer.load(node['params_path'])['model_params'] training_replies[round_] = loaded_training_reply return training_replies @@ -776,38 +772,31 @@ def training_args(self, training_args: dict): def start_training(self): """Sends training task to nodes and waits for the responses""" - + # Run import statements (very unsafely). for i in self._training_plan._dependencies: exec(i, globals()) - is_failed = False - error_message = '' - - # Run the training routine - if not is_failed: - results = {} - try: - self._training_plan.set_dataset_path(self.dataset_path) - data_manager = self._training_plan.training_data() - tp_type = self._training_plan.type() - data_manager.load(tp_type=tp_type) - train_loader, test_loader = data_manager.split(test_ratio=0) - self._training_plan.training_data_loader = train_loader - self._training_plan.testing_data_loader = test_loader - self._training_plan.training_routine() - except Exception as e: - is_failed = True - error_message = "Cannot train model in job : " + str(e) + error = "" - if not is_failed: + # Run the training routine. + try: + self._training_plan.set_dataset_path(self.dataset_path) + data_manager = self._training_plan.training_data() + tp_type = self._training_plan.type() + data_manager.load(tp_type=tp_type) + train_loader, test_loader = data_manager.split(test_ratio=0) + self._training_plan.training_data_loader = train_loader + self._training_plan.testing_data_loader = test_loader + self._training_plan.training_routine() + except Exception as exc: + logger.error("Cannot train model in job: %s", repr(exc)) + # Save the current parameters. 
+ else: try: - # TODO : should test status code but not yet returned - # by upload_file - filename = environ['TMP_DIR'] + '/local_params_' + str(uuid.uuid4()) + '.pt' - self._training_plan.save(filename, results) - except Exception as e: - is_failed = True - error_message = "Cannot write results: " + str(e) - - if error_message != '': - logger.error(error_message) + # TODO: should test status code but not yet returned by upload_file + path = os.path.join( + environ["TMP_DIR"], f"local_params_{uuid.uuid4()}.mpk" + ) + Serializer.dump(self._training_plan.get_model_params(), path) + except Exception as exc: + logger.error("Cannot write results: %s", repr(exc)) diff --git a/tests/test_fedbiosklearn.py b/tests/test_fedbiosklearn.py index 158a9c936..9cf66c70c 100644 --- a/tests/test_fedbiosklearn.py +++ b/tests/test_fedbiosklearn.py @@ -17,7 +17,7 @@ import logging import numpy as np from copy import deepcopy -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, create_autospec, patch from sklearn.linear_model import SGDClassifier @@ -112,29 +112,19 @@ def test_sklearntrainingplanbasicinheritance_02_training_testing_routine(self): [x for x in np.unique(X)] ) - def test_sklearntrainingplanbasicinheritance_03_save(self): + def test_sklearntrainingplanbasicinheritance_03_export_model(self): training_plan = SKLearnTrainingPlan() saved_params = [] def mocked_joblib_dump(obj, *args, **kwargs): saved_params.append(obj) - # Base case where params are not provided to save function with patch('fedbiomed.common.models._sklearn.BaseSkLearnModel.export', side_effect=mocked_joblib_dump): - training_plan.save('filename') + training_plan.export_model('filename') self.assertEqual(saved_params[-1], 'filename') - for param in ({'coef_': 0.42, 'intercept_': 0.42}, {'model_params': {'coef_': 0.42, 'intercept_': 0.42}}): - with ( - patch('fedbiomed.common.models._sklearn.BaseSkLearnModel.export', side_effect=mocked_joblib_dump), - patch('fedbiomed.common.models._sklearn.BaseSkLearnModel.set_weights') as patch_set_weights - ): - training_plan.save('filename', params=param) - self.assertEqual(saved_params[-1], 'filename') - patch_set_weights.assert_called_once_with({'coef_': 0.42, 'intercept_': 0.42}) - - def test_sklearntrainingplanbasicinheritance_04_load(self): + def test_sklearntrainingplanbasicinheritance_04_import_model(self): training_plan = SKLearnTrainingPlan() # Saved object is not the correct type @@ -143,16 +133,16 @@ def test_sklearntrainingplanbasicinheritance_04_load(self): return_value=MagicMock() ): with self.assertRaises(FedbiomedTrainingPlanError): - training_plan.load('filename') + training_plan.import_model('filename') # Option to retrieve model parameters instead of full model from load function - init_params = {'coef_': 0.42, 'intercept_': 0.42} - with (patch('fedbiomed.common.models._sklearn.BaseSkLearnModel.get_weights', return_value=init_params), - patch('fedbiomed.common.models._sklearn.BaseSkLearnModel.reload')): - params = training_plan.load('filename', to_params=True) - self.assertDictEqual(params, {'model_params': {'coef_': 0.42, 'intercept_': 0.42}}) - params = training_plan.after_training_params() - self.assertDictEqual(params, {'coef_': 0.42, 'intercept_': 0.42}) + model = create_autospec(SGDClassifier, instance=True) + with patch( + 'fedbiomed.common.models.BaseSkLearnModel._reload', + return_value=model + ): + training_plan.import_model('filename') + self.assertIs(training_plan._model.model, model) class 
TestSklearnTrainingPlanPartialFit(unittest.TestCase):
@@ -309,20 +299,21 @@ def test_sklearntrainingplancommonfunctionalities_01_model_args(self):
     def test_sklearntrainingplancommonfunctionalities_02_save_and_load(self):
         for training_plan in self.training_plans:
             randomfile = tempfile.NamedTemporaryFile()
-            training_plan.save(randomfile.name)
-            orig_params = deepcopy(training_plan.model().get_params())
+            training_plan.export_model(randomfile.name)
+            orig_params = deepcopy(training_plan.get_model_params())
 
             # ensure file has been created and has size > 0
             self.assertTrue(os.path.exists(randomfile.name) and os.path.getsize(randomfile.name) > 0)
 
             new_tp = self.subclass_types[training_plan.parent_type]()
             new_tp.post_init({'n_classes': 2, 'n_features': 1}, FakeTrainingArgs())
-
-            m = new_tp.load(randomfile.name)
+            new_tp.import_model(randomfile.name)
 
             # ensure output of load is the same as original parameters
-            self.assertDictEqual(m.get_params(), orig_params)
-            # ensure that the newly loaded model has the same params as the original model
-            self.assertDictEqual(training_plan.model().get_params(), new_tp.model().get_params())
+            load_params = new_tp.get_model_params()
+            self.assertEqual(load_params.keys(), orig_params.keys())
+            self.assertTrue(all(
+                np.all(load_params[k] == orig_params[k]) for k in load_params
+            ))
 
     @patch.multiple(SKLearnTrainingPlan, __abstractmethods__=set())
     def test_sklearntrainingplancommonfunctionalities_03_getters(self):
diff --git a/tests/test_job.py b/tests/test_job.py
index 142639302..8a4302913 100644
--- a/tests/test_job.py
+++ b/tests/test_job.py
@@ -1,11 +1,10 @@
 import copy
-import inspect
 import os
 import shutil
-from typing import Dict, Any
 import unittest
-from unittest.mock import patch, MagicMock
 import uuid
+from typing import Any, Dict
+from unittest.mock import MagicMock, patch
 
 import numpy as np
 import torch
@@ -21,11 +20,12 @@
 from testsupport.fake_uuid import FakeUuid
 
 from fedbiomed.common.constants import ErrorNumbers
+from fedbiomed.common.training_args import TrainingArgs
 from fedbiomed.researcher.environ import environ
 from fedbiomed.researcher.job import Job
 from fedbiomed.researcher.requests import Requests
 from fedbiomed.researcher.responses import Responses
-from fedbiomed.common.training_args import TrainingArgs
+
 
 class TestJob(ResearcherTestCase):
@@ -616,8 +616,7 @@ def test_job_16_private_load_training_replies(
         """tests if `_load_training_replies` is loading file content from path file
         and is building a proper training replies structure from breakpoint info
         """
-
-        # first test with a model done with pytorch
+        # Declare mock model parameters for torch and scikit-learn.
pytorch_params = { # dont need other fields 'model_params': torch.Tensor([1, 3, 5, 7]) @@ -630,11 +629,6 @@ def test_job_16_private_load_training_replies( fds = MagicMock() fds.data = MagicMock(return_value={}) - # mock Pytorch model object - model_torch = MagicMock(return_value=None) - model_torch.save = MagicMock(return_value=None) - func_torch_loadparams = MagicMock(return_value=pytorch_params) - # mock Responses # # nota: works fine only with one instance of Response active at a time thus @@ -650,10 +644,12 @@ def side_responses_getitem(arg, *args): patch_responses_init.return_value = None patch_responses_getitem.side_effect = side_responses_getitem - # instantiate job - test_job_torch = Job(training_plan_class=model_torch, - training_args=TrainingArgs({"batch_size": 12}, only_required=False), - data=fds) + # instantiate job with a mock training plan + test_job_torch = Job( + training_plan_class=MagicMock(), + training_args=TrainingArgs({"batch_size": 12}, only_required=False), + data=fds + ) # second create a `training_replies` variable loaded_training_replies_torch = [ [ @@ -661,35 +657,44 @@ def side_responses_getitem(arg, *args): "msg": "", "dataset_id": "dataset_1234", "node_id": "node_1234", - "params_path": "/path/to/file/param.pt", + "params_path": "/path/to/file/param.mpk", "timing": {"time": 0} }, {"success": True, "msg": "", "dataset_id": "dataset_4567", "node_id": "node_4567", - "params_path": "/path/to/file/param2.pt", + "params_path": "/path/to/file/param2.mpk", "timing": {"time": 0} } ] ] # action - torch_training_replies = test_job_torch._load_training_replies( - loaded_training_replies_torch, - func_torch_loadparams + with patch( + "fedbiomed.common.serializer.Serializer.load", return_value=pytorch_params + ) as load_patch: + torch_training_replies = test_job_torch._load_training_replies( + loaded_training_replies_torch + ) + self.assertEqual(load_patch.call_count, 2) + load_patch.assert_called_with( + loaded_training_replies_torch[0][1]["params_path"], ) - - self.assertTrue(type(torch_training_replies) is dict) + self.assertIsInstance(torch_training_replies, dict) # heuristic check `training_replies` for existing field in input self.assertEqual( torch_training_replies[0][0]['node_id'], loaded_training_replies_torch[0][0]['node_id']) # check `training_replies` for pytorch models - self.assertTrue(torch.isclose(torch_training_replies[0][1]['params'], - pytorch_params['model_params']).all()) - self.assertTrue(torch_training_replies[0][1]['params_path'], - "/path/to/file/param2.pt") + self.assertTrue(torch.eq( + torch_training_replies[0][1]['params'], + pytorch_params['model_params'] + ).all()) + self.assertEqual( + torch_training_replies[0][1]['params_path'], + "/path/to/file/param2.mpk" + ) self.assertTrue(isinstance(torch_training_replies[0], Responses)) # #### REPRODUCE TESTS BUT FOR SKLEARN MODELS AND 2 ROUNDS @@ -698,7 +703,7 @@ def side_responses_getitem(arg, *args): [ { # dummy - "params_path": "/path/to/file/param_sklearn.pt" + "params_path": "/path/to/file/param_sklearn.mpk" } ], [ @@ -706,36 +711,43 @@ def side_responses_getitem(arg, *args): "msg": "", "dataset_id": "dataset_8888", "node_id": "node_8888", - "params_path": "/path/to/file/param2_sklearn.pt", + "params_path": "/path/to/file/param2_sklearn.mpk", "timing": {"time": 6} } ] ] - # mock sklearn model object - model_sklearn = MagicMock(return_value=None) - model_sklearn.save = MagicMock(return_value=None) - func_sklearn_loadparams = MagicMock(return_value=sklearn_params) # instantiate job - 
test_job_sklearn = Job(training_plan_class=model_sklearn, - training_args=TrainingArgs({"batch_size": 12}, only_required=False), - data=fds) + test_job_sklearn = Job( + training_plan_class=MagicMock(), + training_args=TrainingArgs({"batch_size": 12}, only_required=False), + data=fds + ) # action - sklearn_training_replies = test_job_sklearn._load_training_replies( - loaded_training_replies_sklearn, - func_sklearn_loadparams + with patch( + "fedbiomed.common.serializer.Serializer.load", return_value=sklearn_params + ) as load_patch: + sklearn_training_replies = test_job_sklearn._load_training_replies( + loaded_training_replies_sklearn + ) + self.assertEqual(load_patch.call_count, 2) + load_patch.assert_called_with( + loaded_training_replies_sklearn[1][0]["params_path"], ) - # heuristic check `training_replies` for existing field in input self.assertEqual( sklearn_training_replies[1][0]['node_id'], loaded_training_replies_sklearn[1][0]['node_id']) # check `training_replies` for sklearn models - self.assertTrue(np.allclose(sklearn_training_replies[1][0]['params'], - sklearn_params['model_params'])) - self.assertTrue(sklearn_training_replies[1][0]['params_path'], - "/path/to/file/param2_sklearn.pt") + self.assertTrue(np.allclose( + sklearn_training_replies[1][0]['params'], + sklearn_params['model_params'] + )) + self.assertEqual( + sklearn_training_replies[1][0]['params_path'], + "/path/to/file/param2_sklearn.mpk" + ) self.assertTrue(isinstance(sklearn_training_replies[0], Responses)) diff --git a/tests/test_local_job.py b/tests/test_local_job.py index c723e74a8..dcc66a05e 100644 --- a/tests/test_local_job.py +++ b/tests/test_local_job.py @@ -127,6 +127,7 @@ def test_local_job_06_start_training(self, mock_logger_error): self.model.training_routine.assert_called_once() # Test failure during training + mock_logger_error.reset_mock() self.model.training_routine.side_effect = Exception self.local_job.start_training() mock_logger_error.assert_called_once() diff --git a/tests/test_torchnn.py b/tests/test_torchnn.py index 8d0d643af..d765c52f9 100644 --- a/tests/test_torchnn.py +++ b/tests/test_torchnn.py @@ -1,25 +1,27 @@ import copy import itertools -import types -import unittest -import os import logging +import os import re -from fedbiomed.common.models import TorchModel +import types +import unittest +from unittest.mock import MagicMock, patch import torch import torch.nn as nn from torch.autograd import Variable - -from unittest.mock import patch, MagicMock -from torch.utils.data import DataLoader, Dataset -from torch.optim import Adam, SGD from torch.nn import Module +from torch.optim import Adam, SGD from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, Dataset + from testsupport.base_fake_training_plan import BaseFakeTrainingPlan + from fedbiomed.common.exceptions import FedbiomedTrainingPlanError from fedbiomed.common.training_plans import TorchTrainingPlan, BaseTrainingPlan from fedbiomed.common.metrics import MetricTypes +from fedbiomed.common.models import TorchModel + # define TP outside of test class to avoid indentation problems when exporting class to file @@ -295,45 +297,6 @@ def test_torch_training_plan_08_getters(self): self.assertDictEqual(r_ta, ta) self.assertDictEqual(r_ip, ip) - def test_torch_training_plan_09_save_and_load_params(self): - """ Test save and load parameters """ - tp1 = TorchTrainingPlan() - tp1._model = TorchModel(torch.nn.Module()) - paramfile = self.tmpdir + '/tmp_params.pt' - - if os.path.isfile(paramfile): - 
os.remove(paramfile) - - # save/load from/to variable - tp1.save(paramfile, self.params) - self.assertTrue(os.path.isfile(paramfile)) - params2 = tp1.load(paramfile, True) - - self.assertTrue(type(params2) is dict) - self.assertEqual(self.params, params2) - - # save/load from/to object params - tp1.save(paramfile) - tp2 = TorchTrainingPlan() - tp2._model = TorchModel(torch.nn.Module()) - tp2.load(paramfile) - self.assertTrue(type(params2) is dict) - - sd1 = tp1.model().state_dict() - sd2 = tp2.model().state_dict() - - # verify we have an equivalent state dict - for key in sd1: - self.assertTrue(key in sd2) - - for key in sd2: - self.assertTrue(key in sd1) - - for (key, value) in sd1.items(): - self.assertTrue(torch.all(torch.isclose(value, sd2[key]))) - - os.remove(paramfile) - @patch('torch.nn.Module.__call__') def test_torch_nn_03_testing_routine(self, patch_model_call): @@ -458,7 +421,7 @@ def test_torch_nn_04_logging_progress_computation(self): num_batches = 3 batch_size = 5 mock_dataset = MagicMock(spec=Dataset) - + tp.training_data_loader = MagicMock(spec=DataLoader(mock_dataset), batch_size=batch_size) tp._training_args = {'batch_size': batch_size, 'optimizer_args': {}, @@ -473,7 +436,7 @@ def test_torch_nn_04_logging_progress_computation(self): custom_dataset = self.CustomDataset() x_train = torch.Tensor(custom_dataset.X_train[:batch_size]) y_train = torch.Tensor(custom_dataset.Y_train[:batch_size]) - + dataset_size = num_batches * batch_size fake_data = {'modality1': x_train, 'modality2': x_train} fake_target = (y_train, y_train) @@ -514,7 +477,7 @@ def test_torchnn_05_num_updates(self): tp.training_step = MagicMock(return_value=Variable(torch.Tensor([0]), requires_grad=True)) tp._log_interval = 1000 # essentially disable logging tp._dry_run = False - + tp._dp_controller = FakeDPController() def setup_tp(tp, num_samples, batch_size, num_updates): @@ -524,7 +487,7 @@ def setup_tp(tp, num_samples, batch_size, num_updates): tp.training_data_loader = MagicMock(spec=DataLoader(MagicMock(spec=Dataset)), dataset=[1,2], batch_size=batch_size) - + tp.training_data_loader.__iter__.return_value = list(itertools.repeat( (MagicMock(spec=torch.Tensor), MagicMock(spec=torch.Tensor)), num_batches_per_epoch)) tp.training_data_loader.__len__.return_value = num_batches_per_epoch @@ -569,20 +532,20 @@ def setup_tp(tp, num_samples, batch_size, num_updates): tp.training_routine(None, None) self.assertEqual(tp._optimizer.step.call_count, 3) - + tp = setup_tp(tp, num_samples=10, batch_size=5, num_updates=6) tp._batch_maxnum = 3 tp.training_routine(None, None) self.assertEqual(tp._optimizer.step.call_count, 6) def test_torch_nn_06_compute_corrected_loss(self): - """test_torch_nn_06_compute_corrected_loss: + """test_torch_nn_06_compute_corrected_loss: checks: that fedavg and scaffold are equivalent if correction states are set to 0 """ def set_training_plan(model, aggregator_name:str, loss_value: float = .0): """Configure a TorchTrainingPlan with a given model. 
- + Args: model: a torch model aggregator_name: name of the aggregator method @@ -597,7 +560,7 @@ def set_training_plan(model, aggregator_name:str, loss_value: float = .0): tp.training_data_loader = MagicMock() tp._log_interval = 1000 # essentially disable logging tp._dry_run = False - + tp.aggregator_name = aggregator_name if aggregator_name == 'scaffold': for name, param in model.named_parameters(): @@ -622,7 +585,7 @@ def training_step(instance, data, target): tp.training_data_loader.dataset.__len__.return_value = dataset_size tp._num_updates = num_batches tp._training_args = {'batch_size': batch_size} - + tp._optimizer_args = {"lr" : 1e-3} tp._optimizer = torch.optim.Adam(model.parameters(), **tp._optimizer_args) tp._dp_controller = FakeDPController() @@ -634,22 +597,22 @@ def training_step(instance, data, target): 'dry_run': False, 'num_updates': None} return tp - + model = torch.nn.Linear(3, 1) tp_fedavg = set_training_plan(model, "fedavg", .1) tp_fedavg.training_routine(None, None) - + tp_scaffold = set_training_plan(model, "scaffold", .1) - + tp_scaffold.training_routine(None, None) - + # test that model trained with scaffold is equivalent to model trained with fedavg for (name, layer_fedavg), (name, layer_scaffold) in zip(tp_fedavg._model.model.state_dict().items(), tp_scaffold._model.model.state_dict().items()): self.assertTrue(torch.isclose(layer_fedavg, layer_scaffold).all()) def test_torch_nn_07_get_learning_rate(self): - """test_torch_nn_08_get_learning_rate: test we retrieve the appropriate + """test_torch_nn_08_get_learning_rate: test we retrieve the appropriate learning rate """ # first test wih basic optimizer (eg without learning rate scheduler) @@ -659,10 +622,10 @@ def test_torch_nn_07_get_learning_rate(self): dataset = torch.Tensor([[1, 2], [1, 1], [2, 2]]) target = torch.Tensor([1, 2, 2]) tp._optimizer = SGD(tp._model.parameters(), lr=lr) - + lr_extracted = tp.get_learning_rate() self.assertListEqual(lr_extracted, [lr]) - + # last test using a pytorch scheduler scheduler = LambdaLR(tp._optimizer, lambda e: 2*e) # this pytorch scheduler increase earning rate by twice its previous value @@ -675,7 +638,7 @@ def test_torch_nn_07_get_learning_rate(self): loss.backward() tp._optimizer.step() scheduler.step() - + # checks lr_extracted = tp.get_learning_rate() self.assertListEqual(lr_extracted, [lr * 2 * (e+1)])
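For reference, every breakpoint artifact touched by this diff (`global_state_*.mpk`, `local_params_*.mpk`, the `aggregated_params_current.mpk` link target, and the per-node `params_path` files) now goes through `Serializer.dump`/`Serializer.load` instead of `torch.save`/`joblib`. A minimal round-trip sketch, with an illustrative payload and temporary path (the dict contents and file location are assumptions, not values from the diff):

```python
# Sketch of the Serializer round-trip that the breakpoint code above relies on.
import os
import tempfile
import uuid

from fedbiomed.common.serializer import Serializer

# Illustrative payload mirroring the {'model_params': ...} structure used
# in the training replies; real values would be numpy arrays or tensors.
state = {"model_params": {"coef_": [0.1, 0.2], "intercept_": [0.0]}}
path = os.path.join(tempfile.gettempdir(), f"global_state_{uuid.uuid4()}.mpk")

Serializer.dump(state, path)      # as in Scaffold.save_state / start_training
restored = Serializer.load(path)  # as in Scaffold.load_state / _load_training_replies
assert restored == state          # plain Python types round-trip unchanged
os.remove(path)
```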