
Commit

Add 'BaseTrainingPlan.(export|import)_model' and remove old load/save methods.
pandrey-fr committed Mar 20, 2023
1 parent db70025 commit 219b9cd
Showing 9 changed files with 213 additions and 304 deletions.
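In short, the commit replaces the old `save`/`load` pair (which mixed local model dumps with parameter exchange) by explicit `export_model`/`import_model` methods, and routes exchanged state through `Serializer`. A rough before/after sketch, with a hypothetical path and an already-initialised training plan `plan` (neither is part of this commit):

# Before this commit (methods removed below):
#   plan.save("model_dump")                   # joblib/torch dump, also used for parameter handling
#   plan.load("model_dump", to_params=True)   # reload, optionally as a parameters dict
# After this commit:
plan.export_model("model_dump")               # local, possibly pickle-based dump only
plan.import_model("model_dump")               # re-load a locally produced dump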
51 changes: 50 additions & 1 deletion fedbiomed/common/training_plans/_base_training_plan.py
@@ -14,7 +14,9 @@
from fedbiomed.common import utils
from fedbiomed.common.constants import ErrorNumbers, ProcessTypes
from fedbiomed.common.data import NPDataLoader
from fedbiomed.common.exceptions import FedbiomedError, FedbiomedTrainingPlanError
from fedbiomed.common.exceptions import (
FedbiomedError, FedbiomedModelError, FedbiomedTrainingPlanError
)
from fedbiomed.common.logger import logger
from fedbiomed.common.metrics import Metrics, MetricTypes
from fedbiomed.common.models import Model
@@ -532,3 +534,50 @@ def after_training_params(self) -> Dict[str, Any]:
The trained parameters to aggregate.
"""
return self.get_model_params()

def export_model(self, filename: str) -> None:
"""Export the wrapped model to a dump file.
Args:
filename: path to the file where the model will be saved.
!!! info "Notes"
This method is designed to save the model to a local dump
file for easy re-use by the same user, possibly outside of
Fed-BioMed. It is not designed to produce trustworthy data
dumps and is not used to exchange models and their weights
as part of the federated learning process.
To save the model parameters for sharing as part of the FL process,
use the `after_training_params` method (or `get_model_params` one
outside of a training context) and export results using
[`Serializer`][fedbiomed.common.serializer.Serializer].
"""
self._model.export(filename)
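A minimal usage sketch (the `plan` instance and file name are hypothetical, not part of this commit):

# Sketch: 'plan' is an already-initialised training plan instance.
plan.export_model("model_dump")   # local, possibly pickle-based dump of the wrapped model
# The dump is meant for local re-use only, not for exchange with other FL peers.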

def import_model(self, filename: str) -> None:
"""Import and replace the wrapped model from a dump file.
Args:
filename: path to the file where the model has been exported.
!!! info "Notes"
This method is designed to load the model from a local dump
file, that might not be in a trustworthy format. It should
therefore only be used to re-load data exported locally and
not received from someone else, including other FL peers.
To load model parameters shared as part of the FL process, use the
[`Serializer`][fedbiomed.common.serializer.Serializer] to read the
network-exchanged file, and the `set_model_params` method to assign
the loaded values into the wrapped model.
"""
try:
self._model.reload(filename)
except FedbiomedModelError as exc:
msg = (
f"{ErrorNumbers.FB304.value}: failed to import a model from "
f"a dump file: {exc}"
)
logger.critical(msg)
raise FedbiomedTrainingPlanError(msg) from exc
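By contrast, sharing weights as part of the FL process goes through `after_training_params` and `Serializer`; a hedged sketch (file names and the `plan` instance are hypothetical):

from fedbiomed.common.serializer import Serializer

# Sender side: extract and serialize the trainable parameters.
weights = plan.after_training_params()     # plain {name: array} mapping
Serializer.dump(weights, "weights.mpk")    # MessagePack-based dump for network exchange

# Receiver side: read the exchanged file and assign the values to the wrapped model.
plan.set_model_params(Serializer.load("weights.mpk"))

# Re-loading a *local* dump produced by export_model, by contrast, uses:
plan.import_model("model_dump")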
84 changes: 6 additions & 78 deletions fedbiomed/common/training_plans/_sklearn_training_plan.py
@@ -16,7 +16,7 @@

from fedbiomed.common.constants import ErrorNumbers, TrainingPlans
from fedbiomed.common.data import NPDataLoader
from fedbiomed.common.exceptions import FedbiomedModelError, FedbiomedTrainingPlanError
from fedbiomed.common.exceptions import FedbiomedTrainingPlanError
from fedbiomed.common.logger import logger
from fedbiomed.common.metrics import MetricTypes
from fedbiomed.common.models import SkLearnModel
@@ -42,6 +42,11 @@ class SKLearnTrainingPlan(BaseTrainingPlan, metaclass=ABCMeta):
training data at the beginning of the training routine.
training_data_loader: Data loader used in the training routine.
testing_data_loader: Data loader used in the validation routine.
!!! info "Notes"
The trained model may be exported via the `export_model` method,
resulting in a dump file that may be reloaded using `joblib.load`
outside of Fed-BioMed.
"""

_model_cls: Type[BaseEstimator] # wrapped model class
@@ -266,83 +271,6 @@ def _classes_from_concatenated_train_test(self) -> np.ndarray:
"""
return np.unique([t for loader in (self.training_data_loader, self.testing_data_loader) for d, t in loader])

def save(
self,
filename: str,
params: Union[None, Dict[str, np.ndarray], Dict[str, Any]] = None
) -> None:
"""Save the wrapped model and its trainable parameters.
This method is designed for parameter communication. It
uses the joblib.dump function, which in turn uses pickle
to serialize the model. Note that unpickling objects can
lead to arbitrary code execution; hence use with care.
Args:
filename: Path to the output file.
params: Model parameters to enforce and save.
This may either be a {name: array} parameters dict, or a
nested dict that stores such a parameters dict under the
'model_params' key (in the context of the Round class).
Notes:
Save can be called from Job or Round.
* From [`Round`][fedbiomed.node.round.Round] it is called with params (as a complex dict).
* From [`Job`][fedbiomed.researcher.job.Job] it is called with no params in constructor, and
with params in update_parameters.
"""
# Optionally overwrite the wrapped model's weights.
if params:
if isinstance(params.get('model_params'), dict): # in a Round
params = params["model_params"]
# for key, val in params.items():
# setattr(self._model, key, val)
self._model.set_weights(params)
# Save the wrapped model (using joblib, hence pickle).
self._model.export(filename)

def load(
self,
filename: str,
to_params: bool = False
) -> Union[BaseEstimator, Dict[str, Dict[str, np.ndarray]]]:
"""Load a scikit-learn model dump, overwriting the wrapped model.
This method uses the joblib.load function, which in turn uses
pickle to deserialize the model. Note that unpickling objects
can lead to arbitrary code execution; hence use with care.
This function updates the `_model` private attribute with the
loaded instance, and returns either that same model or a dict
wrapping its trainable parameters.
Args:
filename: The path to the pickle file to load.
to_params: Whether to return the model's parameters
wrapped as a dict rather than the model instance.
Notes:
Load can be called from a Job or Round:
* From [`Round`][fedbiomed.node.round.Round] it is called to return the model.
* From [`Job`][fedbiomed.researcher.job.Job] it is called with to return its parameters dict.
Returns:
Dictionary with the loaded parameters.
"""
# Deserialize the dump, type-check the instance and assign it.
try:
self._model.reload(filename)
except FedbiomedModelError as exc:
msg = f"{ErrorNumbers.FB304.value}: failed to reload wrapped model: {exc}"
logger.critical(msg)
raise FedbiomedTrainingPlanError(msg) from exc

# Optionally return the model's pseudo state dict instead of it.
if to_params:
params = self._model.get_weights()
return {"model_params": params}
return self.model()

def type(self) -> TrainingPlans:
"""Getter for training plan type """
return self.__type
38 changes: 6 additions & 32 deletions fedbiomed/common/training_plans/_torchnn.py
@@ -31,12 +31,11 @@ class TorchTrainingPlan(BaseTrainingPlan, metaclass=ABCMeta):
An abstraction over pytorch module to run pytorch models and scripts on node side. Researcher model (resp. params)
will be:
1. saved on a '*.py' (resp. '*.pt') files,
1. saved on a '*.py' (resp. '*.mpk') files,
2. uploaded on a HTTP server (network layer),
3. then Downloaded from the HTTP server on node side,
4. finally, read and executed on node side.
Researcher must define/override:
- a `training_data()` function
- a `training_step()` function
@@ -52,6 +51,11 @@ class TorchTrainingPlan(BaseTrainingPlan, metaclass=ABCMeta):
correction_state: an OrderedDict of {'parameter name': torch.Tensor} where the keys correspond to the names of
the model parameters contained in self._model.named_parameters(), and the values correspond to the
correction to be applied to that parameter.
!!! info "Notes"
The trained model may be exported via the `export_model` method,
resulting in a dump file that may be reloaded using `torch.load`
outside of Fed-BioMed.
"""

def __init__(self):
@@ -575,36 +579,6 @@ def testing_routine(
finally:
self.model().train() # restore training behaviors

# provided by fedbiomed
def save(self, filename: str, params: dict = None) -> None:
"""Save the torch training parameters from this training plan or from given `params` to a file
Args:
filename (str): Path to the destination file
params (dict): Parameters to save to a file, should be structured as a torch state_dict()
"""
if params is not None:
return torch.save(params, filename)
else:
return self._model.export(filename)

# provided by fedbiomed
def load(self, filename: str, to_params: bool = False) -> dict:
"""Load the torch training parameters to this training plan or to a data structure from a file
Args:
filename: path to the source file
to_params: if False, load params to this pytorch object; if True load params to a data structure
Returns:
Contains parameters
"""
params = torch.load(filename)
if to_params is False:
self._model.reload(filename)
return params

def set_aggregator_args(self, aggregator_args: Dict[str, Any]):
"""Handles and loads aggregators arguments sent through MQTT and
file exchanged system. If sent through file exchanged system, loads the arguments.
26 changes: 14 additions & 12 deletions fedbiomed/researcher/aggregators/scaffold.py
@@ -10,9 +10,11 @@

import numpy as np
import torch

from fedbiomed.common.logger import logger
from fedbiomed.common.constants import TrainingPlans
from fedbiomed.common.exceptions import FedbiomedAggregatorError
from fedbiomed.common.serializer import Serializer
from fedbiomed.common.training_plans import BaseTrainingPlan

from fedbiomed.researcher.aggregators.aggregator import Aggregator
@@ -96,7 +98,7 @@ def __init__(self, server_lr: float = 1., fds: Optional[FederatedDataSet] = None
self.nodes_lr: Dict[str, List[float]] = {}
if fds is not None:
self.set_fds(fds)

self._aggregator_args = {} # we need `_aggregator_args` to be not None
# self.update_aggregator_params()

@@ -153,7 +155,7 @@ def aggregate(self,

# Compute the new aggregated model parameters.
aggregated_parameters = self.scaling(model_params, global_model)

# At round 0, initialize zero-valued correction states.
if n_round == 0:
self.init_correction_states(global_model, node_ids)
Expand All @@ -175,7 +177,7 @@ def create_aggregator_args(self,
Returns:
Tuple[Dict, Dict]: first dictionary contains parameters that will be sent through MQTT message
service, second dictionary parameters that will be sent through file exchange message.
Aggregators args are dictionary mapping node_id to SCAFFOLD parameters specific to
Aggregators args are dictionary mapping node_id to SCAFFOLD parameters specific to
each `Nodes`.
"""

@@ -208,7 +210,7 @@ def check_values(self, n_updates: int, training_plan: BaseTrainingPlan) -> True:
Args:
n_updates (int): number of updates. Must be non-zero and an integer.
training_plan (BaseTrainingPlan): training plan. used for checking if optimizer is SGD, otherwise,
training_plan (BaseTrainingPlan): training plan. used for checking if optimizer is SGD, otherwise,
triggers warning.
Raises:
Expand All @@ -217,7 +219,7 @@ def check_values(self, n_updates: int, training_plan: BaseTrainingPlan) -> True:
FedbiomedAggregatorError: triggered if number of updates equals 0 or is not an integer
FedbiomedAggregatorError: triggered if [FederatedDataset][fedbiomed.researcher.datasets.FederatedDataset]
has not been set.
"""
if n_updates is None:
raise FedbiomedAggregatorError("Cannot perform Scaffold: missing 'num_updates' entry in the training_args")
@@ -353,8 +355,8 @@ def update_correction_states(self,
eta_l: nodes' learning rate
x: global model before updates
y_i: local model updates
Remark:
Remark:
c^{t=0} = 0
Args:
@@ -374,7 +376,7 @@
total_nb_nodes = len(self._fds.node_ids())
# Compute the node-wise average of corrected gradients (ACG_i).
# i.e. (x^t - y_i^t}) / (K * eta_l)
local_state_updates: Dict[str, Mapping[str, Union[torch.Tensor, np.ndarray]]] = {}
local_state_updates: Dict[str, Mapping[str, Union[torch.Tensor, np.ndarray]]] = {}
for node_id, params in local_models.items():
local_state_updates[node_id] = {
key: (global_model[key] - val) / (self.nodes_lr[node_id][idx] * n_updates)
@@ -445,10 +447,10 @@ def save_state(
) -> Dict[str, Any]:
# adding aggregator parameters to the breakpoint that wont be sent to nodes
self._aggregator_args['server_lr'] = self.server_lr

# saving global state variable into a file
filename = os.path.join(breakpoint_path, 'global_state_' + str(uuid.uuid4()) + '.pt')
training_plan.save(filename, self.global_state)
filename = os.path.join(breakpoint_path, f"global_state_{uuid.uuid4()}.mpk")
Serializer.dump(self.global_state, filename)
self._aggregator_args['global_state_filename'] = filename
# adding aggregator parameters that will be sent to nodes afterwards
return super().save_state(
Expand All @@ -462,7 +464,7 @@ def load_state(self, state: Dict[str, Any] = None):

# loading global state
global_state_filename = self._aggregator_args['global_state_filename']
self.global_state = training_plan.load(global_state_filename, to_params=True)
self.global_state = Serializer.load(global_state_filename)

for node_id in self._aggregator_args['aggregator_correction'].keys():
arg_filename = self._aggregator_args['aggregator_correction'][node_id]
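The breakpoint round-trip now relies on `Serializer` instead of the training plan's removed `save`/`load`; a minimal sketch with a hypothetical path:

from fedbiomed.common.serializer import Serializer

# Hypothetical breakpoint file; mirrors the save_state / load_state change above.
state_file = "/tmp/breakpoint/global_state_0000.mpk"
Serializer.dump({"layer.weight": [[0.0, 1.0]]}, state_file)  # write the aggregator's global state
global_state = Serializer.load(state_file)                   # read it back when resuming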