diff --git a/docs/about/features_index/workflowinterface.rst b/docs/about/features_index/workflowinterface.rst index 4f567b4342..574d0e9e2d 100644 --- a/docs/about/features_index/workflowinterface.rst +++ b/docs/about/features_index/workflowinterface.rst @@ -18,14 +18,25 @@ A new OpenFL interface that gives significantly more flexility to researchers in There are several modifications we make in our reimagined version of this interface that are necessary for federated learning: 1. *Placement*: Metaflow's :code:`@step` decorator is replaced by placement decorators that specify where a task will run. In horizontal federated learning, there are server (or aggregator) and client (or collaborator) nodes. Tasks decorated by :code:`@aggregator` will run on the aggregator node, and :code:`@collaborator` will run on the collaborator node. These placement decorators are interpreted by *Runtime* implementations: these do the heavy lifting of figuring out how to get the state of the current task to another process or node. -2. *Runtime*: Each flow has a :code:`.runtime` attribute. The runtime encapsulates the details of the infrastucture where the flow will run. We support the LocalRuntime for simulating experiments on local node and FederatedRuntime to launch experiments on distributed infrastructure. +2. *Runtime*: The runtime encapsulates the details of the infrastucture where the flow will run. We support the LocalRuntime for simulating experiments on local node and FederatedRuntime to launch experiments on distributed infrastructure. 3. *Conditional branches*: Perform different tasks if a criteria is met 4. *Loops*: Internal loops are within a flow; this is necessary to support rounds of training where the same sequence of tasks is performed repeatedly. How to use it? ============== -Let's start with the basics. A flow is intended to define the entirety of federated learning experiment. Every flow begins with the :code:`start` task and concludes with the :code:`end` task. At each step in the flow, attributes can be defined, modified, or deleted. Attributes get passed forward to the next step in the flow, which is defined by the name of the task passed to the :code:`next` function. In the line before each task, there is a **placement decorator**. The placement decorator defines where that task will be run. The OpenFL Workflow Interface adopts the conventions set by Metaflow, that every workflow begins with start and concludes with the end task. In the following example, the aggregator begins with an optionally passed in model and optimizer. The aggregator begins the flow with the start task, where the list of collaborators is extracted from the runtime (:code:`self.collaborators = self.runtime.collaborators`) and is then used as the list of participants to run the task listed in self.next, aggregated_model_validation. The model, optimizer, and anything that is not explicitly excluded from the next function will be passed from the start function on the aggregator to the aggregated_model_validation task on the collaborator. Where the tasks run is determined by the placement decorator that precedes each task definition (:code:`@aggregator` or :code:`@collaborator`). Once each of the collaborators (defined in the runtime) complete the aggregated_model_validation task, they pass their current state onto the train task, from train to local_model_validation, and then finally to join at the aggregator. It is in join that an average is taken of the model weights, and the next round can begin. +Let's start with the basics. A flow is intended to define the entirety of federated learning experiment. Every flow begins with the :code:`start` task and concludes with the +:code:`end` task. At each step in the flow, attributes can be defined, modified, or deleted. Attributes get passed forward to the next step in the flow, which is defined by +the name of the task passed to the :code:`next` function. +In the line before each task, there is a **placement decorator**. The placement decorator defines where that task will be run (:code:`@aggregator` or :code:`@collaborator`). +The OpenFL Workflow Interface adopts the conventions set by Metaflow, that every workflow begins with start andconcludes with the end task. In the following example, the +aggregator begins the flow with :code:`start` task and optionally passed in model and optimizer. The list of collaborators in the federation, :code:`self.collaborators`, +is automatically populated by the Runtime infrastructure. It serves as the participant list for executing tasks listed in :code:`self.next` and :code:`aggregated_model_validation`. +The model, optimizer, and anything that is not explicitly excluded from the next function will be passed from the start function on the aggregator to the +aggregated_model_validation task on the collaborator. +Once each of the collaborators (defined in the runtime) complete the :code:`aggregated_model_validation` task, they +pass their current state onto the :code:`train` task, from :code:`train` to :code:`local_model_validation`, and then finally to :code:`join` at the aggregator. +It is in :code:`join` that an average is taken of the model weights, and the next round can begin. .. code-block:: python @@ -45,9 +56,9 @@ Let's start with the basics. A flow is intended to define the entirety of federa @aggregator def start(self): print(f'Performing initialization for model') - self.collaborators = self.runtime.collaborators self.private = 10 self.current_round = 0 + print(f'Collaborators participating in federation: {self.collaborators}') self.next(self.aggregated_model_validation,foreach='collaborators',exclude=['private']) @collaborator @@ -237,20 +248,19 @@ Some important points to remember while creating callback function and private a - In above example multiple collaborators have the same callback function or private attributes. Depending on the Federated Learning requirements, user can specify unique callback function or private attributes for each Participant - *Private attributes* needs to be set after instantiating the participant. -Now let's see how the runtime for a flow is assigned, and the flow gets run: +To run the flow, simply pass the instance of the flow to the :code:`run()` method of runtime: .. code-block:: python flow = FederatedFlow() - flow.runtime = local_runtime - flow.run() + local_runtime.run(flow) And that's it! This will run an instance of the :code:`FederatedFlow` on a single node in a single process. LocalRuntime Backends --------------------- -The Runtime defines where code will run, but the Runtime has a :code:`Backend` - which defines the underlying implementation of *how* the flow will be executed. :code:`single_process` is the default in the :code:`LocalRuntime`: it executes all code sequentially within a single python process, and is well suited to run both on high spec and low spec hardware +The Runtime defines where code will run, but the Runtime has a :code:`backend` - which defines the underlying implementation of *how* the flow will be executed. :code:`single_process` is the default in the :code:`LocalRuntime`: it executes all code sequentially within a single python process, and is well suited to run both on high spec and low spec hardware For users with large servers or multiple GPUs they wish to take advantage of, we also provide a :code:`ray` `` backend. The Ray backend enables parallel task execution for collaborators, and optionally allows users to request dedicated CPU / GPUs for Participants by using the :code:`num_cpus` and :code:`num_gpus` arguments while instantiating the Participant in following manner: @@ -428,13 +438,12 @@ Below is an example of how to set up and instantiate a :code:`FederatedRuntime`: tls=False ) -To distribute the experiment on the Federation, we now need to assign the federated_runtime to the flow and execute it. +To distribute the experiment on the Federation, we simply pass the instance of flow to :code:`run()` method of :code:`FederatedRuntime` .. code-block:: python flow = FederatedFlow() - flow.runtime = federated_runtime - flow.run() + federated_runtime.run(flow) This will export the Jupyter notebook to an workspace and deploy it to the federation. The Director receives the experiment, distributes it to the Envoys, and initiates the execution of the experiment. diff --git a/openfl-tutorials/experimental/workflow/101_MNIST.ipynb b/openfl-tutorials/experimental/workflow/101_MNIST.ipynb index 7393d16a46..239d2c53be 100644 --- a/openfl-tutorials/experimental/workflow/101_MNIST.ipynb +++ b/openfl-tutorials/experimental/workflow/101_MNIST.ipynb @@ -224,7 +224,12 @@ "scrolled": true }, "source": [ - "Now we come to the flow definition. The OpenFL Workflow Interface adopts the conventions set by Metaflow, that every workflow begins with `start` and concludes with the `end` task. The aggregator begins with an optionally passed in model and optimizer. The aggregator begins the flow with the `start` task, where the list of collaborators is extracted from the runtime (`self.collaborators = self.runtime.collaborators`) and is then used as the list of participants to run the task listed in `self.next`, `aggregated_model_validation`. The model, optimizer, and anything that is not explicitly excluded from the next function will be passed from the `start` function on the aggregator to the `aggregated_model_validation` task on the collaborator. Where the tasks run is determined by the placement decorator that precedes each task definition (`@aggregator` or `@collaborator`). Once each of the collaborators (defined in the runtime) complete the `aggregated_model_validation` task, they pass their current state onto the `train` task, from `train` to `local_model_validation`, and then finally to `join` at the aggregator. It is in `join` that an average is taken of the model weights, and the next round can begin.\n", + "Now we come to the flow definition. The OpenFL Workflow Interface adopts the conventions set by Metaflow, that every workflow begins with `start` and concludes with the `end` task. Task placement (i.e. where the tasks run) is determined by the placement decorator that precedes each task definition (`@aggregator` or `@collaborator`)\n", + "\n", + "The aggregator begins the flow with `start` task and optionally passed in model and optimizer. The list of collaborators in federation (`self.collaborators`) is automatically populated by LocalRuntime infrastructure and is then used as the list of participants to run the task listed in `self.next`, `aggregated_model_validation`. The model, optimizer, and anything that is not explicitly excluded from the next function will be passed from the `start` function on the aggregator to the `aggregated_model_validation` task on the collaborator\n", + "\n", + "Once each of the collaborators (defined in the runtime) complete the `aggregated_model_validation` task, they pass their current state onto the `train` task, from `train` to `local_model_validation`, and then finally to `join` at the aggregator. It is in `join` that an average is taken of the model weights, and the next round can begin.\n", + "\n", "\n", "![image.png](attachment:image.png)" ] @@ -252,9 +257,9 @@ " @aggregator\n", " def start(self):\n", " print(f'Performing initialization for model')\n", - " self.collaborators = self.runtime.collaborators\n", " self.private = 10\n", " self.current_round = 0\n", + " print(f'Collaborators participating in federation: {self.collaborators}')\n", " self.next(self.aggregated_model_validation, foreach='collaborators', exclude=['private'])\n", "\n", " @collaborator\n", @@ -382,8 +387,7 @@ "best_model = None\n", "optimizer = None\n", "flflow = FederatedFlow(model, optimizer, rounds=2, checkpoint=True)\n", - "flflow.runtime = local_runtime\n", - "flflow.run()" + "local_runtime.run(flflow)" ] }, { @@ -425,8 +429,7 @@ "outputs": [], "source": [ "flflow2 = FederatedFlow(model=flflow.model, optimizer=flflow.optimizer, rounds=2, checkpoint=True)\n", - "flflow2.runtime = local_runtime\n", - "flflow2.run()" + "local_runtime.run(flflow2)" ] }, { diff --git a/openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/workspace/101_MNIST_FederatedRuntime.ipynb b/openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/workspace/101_MNIST_FederatedRuntime.ipynb index 00daa8095e..8e1a01f130 100644 --- a/openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/workspace/101_MNIST_FederatedRuntime.ipynb +++ b/openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/workspace/101_MNIST_FederatedRuntime.ipynb @@ -339,9 +339,8 @@ " \"\"\"\n", " print(f\"Initializing Workflow .... \")\n", "\n", - " self.collaborators = self.runtime.collaborators\n", " self.current_round = 0\n", - "\n", + " print(f'Collaborators participating in federation: {self.collaborators}')\n", " self.next(self.aggregated_model_validation, foreach=\"collaborators\")\n", "\n", " @collaborator\n", @@ -521,8 +520,7 @@ "model = None\n", "optimizer = None\n", "flflow = FederatedFlow_TorchMNIST(model, optimizer, learning_rate, momentum, rounds=2, checkpoint=True)\n", - "flflow.runtime = local_runtime\n", - "flflow.run()" + "local_runtime.run(flflow)" ] }, { @@ -635,7 +633,7 @@ "id": "87c487cb", "metadata": {}, "source": [ - "Now that we have our distributed infrastructure ready, let us modify the flow runtime to `FederatedRuntime` instance and deploy the experiment. \n", + "Now that we have our distributed infrastructure ready, the experiment is deployed onto the federation by providing the same `flflow` instance to `FederatedRuntime`.\n", "\n", "Progress of the flow is available on \n", "1. Jupyter notebook: if `checkpoint` attribute of the flow object is set to `True`\n", @@ -650,8 +648,7 @@ "outputs": [], "source": [ "flflow.results = [] # clear results from previous run\n", - "flflow.runtime = federated_runtime\n", - "flflow.run()" + "federated_runtime.run(flflow)" ] }, { diff --git a/openfl-tutorials/experimental/workflow/FederatedRuntime/301_MNIST_Watermarking/workspace/MNIST_Watermarking.ipynb b/openfl-tutorials/experimental/workflow/FederatedRuntime/301_MNIST_Watermarking/workspace/MNIST_Watermarking.ipynb index eb71fa92af..cc2317c807 100644 --- a/openfl-tutorials/experimental/workflow/FederatedRuntime/301_MNIST_Watermarking/workspace/MNIST_Watermarking.ipynb +++ b/openfl-tutorials/experimental/workflow/FederatedRuntime/301_MNIST_Watermarking/workspace/MNIST_Watermarking.ipynb @@ -279,8 +279,7 @@ " This is the start of the Flow.\n", " \"\"\"\n", " print(\": Start of flow ... \")\n", - " self.collaborators = self.runtime.collaborators\n", - "\n", + " print(f'Collaborators participating in federation: {self.collaborators}')\n", " self.next(self.watermark_pretrain)\n", "\n", " @aggregator\n", @@ -558,8 +557,7 @@ " watermark_retrain_optimizer,\n", " checkpoint=True,\n", ")\n", - "flflow.runtime = federated_runtime\n", - "flflow.run()" + "federated_runtime.run(flflow)" ] } ], diff --git a/openfl/experimental/workflow/component/aggregator/aggregator.py b/openfl/experimental/workflow/component/aggregator/aggregator.py index 717436f8a0..c2964b5634 100644 --- a/openfl/experimental/workflow/component/aggregator/aggregator.py +++ b/openfl/experimental/workflow/component/aggregator/aggregator.py @@ -15,9 +15,7 @@ import dill from openfl.experimental.workflow.interface import FLSpec -from openfl.experimental.workflow.runtime import FederatedRuntime from openfl.experimental.workflow.utilities import aggregator_to_collaborator, checkpoint -from openfl.experimental.workflow.utilities.metaflow_utils import MetaflowInterface logger = getLogger(__name__) @@ -125,13 +123,7 @@ def __init__( self.flow = flow self.checkpoint = checkpoint - self.flow._foreach_methods = [] - logger.info("MetaflowInterface creation.") - self.flow._metaflow_interface = MetaflowInterface(self.flow.__class__, "single_process") - self.flow._run_id = self.flow._metaflow_interface.create_run() - self.flow.runtime = FederatedRuntime() self.name = "aggregator" - self.flow.runtime.collaborators = self.authorized_cols self.__private_attrs_callable = private_attributes_callable self.__private_attrs = private_attributes @@ -200,10 +192,8 @@ async def run_flow(self) -> FLSpec: """ # Start function will be the first step if any flow f_name = "start" - # Creating a clones from the flow object - FLSpec._reset_clones() - FLSpec._create_clones(self.flow, self.flow.runtime.collaborators) - + # Initialize the flow state + self.flow.initialize_flow_state(self.authorized_cols) logger.info(f"Starting round {self.current_round}...") while True: next_step = self.do_task(f_name) diff --git a/openfl/experimental/workflow/interface/fl_spec.py b/openfl/experimental/workflow/interface/fl_spec.py index 3e8365458b..d77935bd70 100644 --- a/openfl/experimental/workflow/interface/fl_spec.py +++ b/openfl/experimental/workflow/interface/fl_spec.py @@ -8,16 +8,11 @@ import inspect from copy import deepcopy -from typing import TYPE_CHECKING, Callable, List, Type, Union - -if TYPE_CHECKING: - from openfl.experimental.workflow.runtime import FederatedRuntime, LocalRuntime, Runtime +from typing import Callable, List, Type from openfl.experimental.workflow.utilities import ( MetaflowInterface, - SerializationError, aggregator_to_collaborator, - checkpoint, collaborator_to_aggregator, filter_attributes, generate_artifacts, @@ -36,7 +31,9 @@ class FLSpec: _initial_state (FLSpec or None): The saved initial state of the FLSpec instance. _foreach_methods (list): A list of methods to be applied iteratively. _checkpoint (bool): A flag indicating whether checkpointing is enabled. - _runtime (RuntimeType): The runtime of the flow. + _collaborators (list): A list of collaborators associated with the runtime. + _metaflow_interface (MetaflowInterface): The interface to the Metaflow runtime. + _run_id (str): The ID of the current run. """ _clones = [] @@ -53,20 +50,15 @@ def __init__(self, checkpoint: bool = False) -> None: self._checkpoint = checkpoint @classmethod - def _create_clones(cls, instance: Type[FLSpec], names: List[str]) -> None: - """Creates clones for instance for each collaborator in names. + def _reset_and_create_clones(cls, instance: Type[FLSpec], names: List[str]) -> None: + """Resets and creates clones for instance for each collaborator in names. Args: instance (Type[FLSpec]): The instance to be cloned. names (List[str]): The list of names for the clones. """ - cls._clones = {name: deepcopy(instance) for name in names} - - @classmethod - def _reset_clones(cls) -> None: - """Resets the clones of the class.""" - cls._clones = [] + cls._clones = {name: deepcopy(instance) for name in names} @classmethod def save_initial_state(cls, instance: Type[FLSpec]) -> None: @@ -102,102 +94,22 @@ def checkpoint(self, value: bool) -> None: self._checkpoint = value @property - def runtime(self) -> Type[Union[LocalRuntime, FederatedRuntime]]: - """Returns flow runtime. + def collaborators(self) -> List: + """Get the list of collaborators. Returns: - Type[Runtime]: The runtime of the flow. + _collaborators: A list of collaborators """ - return self._runtime + return self._collaborators - @runtime.setter - def runtime(self, runtime: Type[Runtime]) -> None: - """Sets flow runtime. + @collaborators.setter + def collaborators(self, collaborators: List) -> None: + """Set the list of collaborators. Args: - runtime (Type[Runtime]): The runtime to be set. - - Raises: - TypeError: If the provided runtime is not a valid OpenFL Runtime. + collaborators (List): A list of collaborators to be assigned. """ - if str(runtime) not in ["LocalRuntime", "FederatedRuntime"]: - raise TypeError(f"{runtime} is not a valid OpenFL Runtime") - self._runtime = runtime - - def run(self) -> None: - """Starts the execution of the flow.""" - # Submit flow to Runtime - if str(self._runtime) == "LocalRuntime": - self._run_local() - elif str(self._runtime) == "FederatedRuntime": - self._run_federated() - else: - raise Exception("Runtime not implemented") - - def _run_local(self) -> None: - """Executes the flow using LocalRuntime.""" - self._setup_initial_state() - try: - # Execute all Participant (Aggregator & Collaborator) tasks and - # retrieve the final attributes - # start step is the first task & invoked on aggregator through - # runtime.execute_task - final_attributes = self.runtime.execute_task( - self, - self.start, - ) - except Exception as e: - if "cannot pickle" in str(e) or "Failed to unpickle" in str(e): - msg = ( - "\nA serialization error was encountered that could not" - "\nbe handled by the ray backend." - "\nTry rerunning the flow without ray as follows:\n" - "\nLocalRuntime(...,backend='single_process')\n" - "\n or for more information about the original error," - "\nPlease see the official Ray documentation" - "\nhttps://docs.ray.io/en/releases-2.2.0/ray-core/\ - objects/serialization.html" - ) - raise SerializationError(str(e) + msg) - else: - raise e - for name, attr in final_attributes: - setattr(self, name, attr) - - def _setup_initial_state(self) -> None: - """ - Sets up the flow's initial state, initializing private attributes for - collaborators and aggregators. - """ - self._metaflow_interface = MetaflowInterface(self.__class__, self.runtime.backend) - self._run_id = self._metaflow_interface.create_run() - # Initialize aggregator private attributes - self.runtime.initialize_aggregator() - self._foreach_methods = [] - FLSpec._reset_clones() - FLSpec._create_clones(self, self.runtime.collaborators) - # Initialize collaborator private attributes - self.runtime.initialize_collaborators() - if self._checkpoint: - print(f"Created flow {self.__class__.__name__}") - - def _run_federated(self) -> None: - """Executes the flow using FederatedRuntime.""" - try: - # Prepare workspace and submit it for the FederatedRuntime - archive_path, exp_name = self.runtime.prepare_workspace_archive() - self.runtime.submit_experiment(archive_path, exp_name) - # Stream the experiment's stdout if the checkpoint is enabled - if self._checkpoint: - self.runtime.stream_experiment_stdout(exp_name) - # Retrieve the flspec object to update the experiment state - flspec_obj = self._get_flow_state() - # Update state of self - self._update_from_flspec_obj(flspec_obj) - except Exception as e: - raise Exception( - f"FederatedRuntime: Experiment {exp_name} failed to run due to error: {e}" - ) + self._collaborators = collaborators def _update_from_flspec_obj(self, flspec_obj: FLSpec) -> None: """Update self with attributes from the updated flspec instance. @@ -211,22 +123,6 @@ def _update_from_flspec_obj(self, flspec_obj: FLSpec) -> None: self._foreach_methods = flspec_obj._foreach_methods - def _get_flow_state(self) -> Union[FLSpec, None]: - """ - Gets the updated flow state. - - Returns: - flspec_obj (Union[FLSpec, None]): An updated FLSpec instance if the experiment - runs successfully. None if the experiment could not run. - """ - status, flspec_obj = self.runtime.get_flow_state() - if status: - print("Experiment ran successfully") - return flspec_obj - else: - print("Experiment could not run") - return None - def _capture_instance_snapshot(self, kwargs) -> List: """Takes backup of self before exclude or include filtering. @@ -279,6 +175,23 @@ def _display_transition_logs(self, f: Callable, parent_func: Callable) -> None: elif collaborator_to_aggregator(f, parent_func): print("Sending state from collaborator to aggregator") + def initialize_flow_state(self, collaborators: List, backend: str = "single_process") -> None: + """ + Sets up the flow's initial state + + Args: + collaborators (list): A list of collaborators + backend (str): The runtime backend + """ + self.collaborators = collaborators + print("MetaflowInterface creation.") + self._metaflow_interface = MetaflowInterface(self.__class__, backend) + self._run_id = self._metaflow_interface.create_run() + self._foreach_methods = [] + FLSpec._reset_and_create_clones(self, self.collaborators) + if self._checkpoint: + print(f"Created flow {self.__class__.__name__}") + def filter_exclude_include(self, f, **kwargs) -> None: """Filters exclude/include attributes for a given task within the flow. @@ -334,10 +247,6 @@ def next(self, f, **kwargs) -> None: parent = inspect.stack()[1][3] parent_func = getattr(self, parent) - if str(self._runtime) == "LocalRuntime": - # Checkpoint current attributes (if checkpoint==True) - checkpoint(self, parent_func) - # Take back-up of current state of self agg_to_collab_ss = None if aggregator_to_collaborator(f, parent_func): @@ -346,19 +255,15 @@ def next(self, f, **kwargs) -> None: # Remove included / excluded attributes from next task filter_attributes(self, f, **kwargs) - if str(self._runtime) == "FederatedRuntime": - if f.collaborator_step and not f.aggregator_step: - self._foreach_methods.append(f.__name__) - - self.execute_task_args = ( - self, - f, - parent_func, - FLSpec._clones, - agg_to_collab_ss, - kwargs, - ) - - elif str(self._runtime) == "LocalRuntime": - # update parameters required to execute execute_task function - self.execute_task_args = [f, parent_func, agg_to_collab_ss, kwargs] + if f.collaborator_step and not f.aggregator_step: + self._foreach_methods.append(f.__name__) + + # update parameters required to execute next steps + self.execute_task_args = ( + self, + f, + parent_func, + FLSpec._clones, + agg_to_collab_ss, + kwargs, + ) diff --git a/openfl/experimental/workflow/runtime/federated_runtime.py b/openfl/experimental/workflow/runtime/federated_runtime.py index 6420b6d8dc..5cb85ad8af 100644 --- a/openfl/experimental/workflow/runtime/federated_runtime.py +++ b/openfl/experimental/workflow/runtime/federated_runtime.py @@ -11,11 +11,14 @@ import sys from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import dill from tabulate import tabulate +if TYPE_CHECKING: + from openfl.experimental.workflow.interface.fl_spec import FLSpec + from openfl.experimental.workflow.runtime.runtime import Runtime from openfl.experimental.workflow.transport.grpc.director_client import RuntimeDirectorClient from openfl.experimental.workflow.workspace_export import WorkspaceExport @@ -54,7 +57,7 @@ def __init__( tls (bool): Whether to use TLS for the connection. """ super().__init__() - self.__collaborators = collaborators + self.collaborators = collaborators self.tls = tls if director: @@ -138,6 +141,32 @@ def _create_runtime_dir_client(self) -> RuntimeDirectorClient: certificate=self.certificate, ) + def run(self, flspec: Type[FLSpec]) -> None: + """Executes the flow using FederatedRuntime. + + Args: + flspec (Type[FLSpec]): Reference to the FLSpec (flow) object. + """ + exp_name = None + try: + # Prepare workspace and submit it for the FederatedRuntime + archive_path, exp_name = self.prepare_workspace_archive() + self.submit_experiment(archive_path, exp_name) + # Stream the experiment's stdout if the checkpoint is enabled + if flspec._checkpoint: + self.stream_experiment_stdout(exp_name) + # Retrieve the flspec object to update the experiment state + updated_flspec = self.get_flow_state() + # Update state of self + flspec._update_from_flspec_obj(updated_flspec) + except Exception as e: + error_msg = ( + "FederatedRuntime: Failed to prepare workspace archive" + if exp_name is None + else f"FederatedRuntime: Experiment {exp_name} failed" + ) + raise Exception(f"{error_msg} due to error: {e}") + def prepare_workspace_archive(self) -> Tuple[Path, str]: """ Prepare workspace archive using WorkspaceExport. @@ -164,7 +193,7 @@ def submit_experiment(self, archive_path, exp_name) -> None: """ try: response = self._runtime_dir_client.set_new_experiment( - archive_path=archive_path, experiment_name=exp_name, col_names=self.__collaborators + archive_path=archive_path, experiment_name=exp_name, col_names=self.collaborators ) self.experiment_submitted = response.status @@ -178,22 +207,24 @@ def submit_experiment(self, archive_path, exp_name) -> None: finally: self.remove_workspace_archive(archive_path) - def get_flow_state(self) -> Tuple[bool, Any]: + def get_flow_state(self) -> Optional[Any]: """ Retrieve the updated flow status and deserialized flow object. Returns: - status (bool): The flow status. - flow_object: The deserialized flow object. + Optional[Any]: The deserialized flow object if successful, otherwise None """ status, flspec_obj = self._runtime_dir_client.get_flow_state() - - # Append generated workspace path to sys.path - # to allow unpickling of flspec_obj - sys.path.append(str(self.generated_workspace_path)) - flow_object = dill.loads(flspec_obj) - - return status, flow_object + if status: + print("Experiment ran successfully") + # Append generated workspace path to sys.path + # to allow unpickling of flspec_obj + sys.path.append(str(self.generated_workspace_path)) + flow_object = dill.loads(flspec_obj) + return flow_object + else: + print("Experiment could not run") + return None def get_envoys(self) -> List[str]: """ diff --git a/openfl/experimental/workflow/runtime/local_runtime.py b/openfl/experimental/workflow/runtime/local_runtime.py index 7a7aa3f7a2..dae8b49582 100644 --- a/openfl/experimental/workflow/runtime/local_runtime.py +++ b/openfl/experimental/workflow/runtime/local_runtime.py @@ -21,6 +21,7 @@ from openfl.experimental.workflow.runtime.runtime import Runtime from openfl.experimental.workflow.utilities import ( ResourcesNotAvailableError, + SerializationError, aggregator_to_collaborator, check_resource_allocation, checkpoint, @@ -536,6 +537,23 @@ def get_collab_name(collab): get_collab_name(collaborator): collaborator for collaborator in collaborators } + def __handle_execution_exception(self, e: Exception): + """Handles exceptions encountered during flow execution.""" + if "cannot pickle" in str(e) or "Failed to unpickle" in str(e): + msg = ( + "\nA serialization error was encountered that could not" + "\nbe handled by the ray backend." + "\nTry rerunning the flow without ray as follows:\n" + "\nLocalRuntime(...,backend='single_process')\n" + "\n or for more information about the original error," + "\nPlease see the official Ray documentation" + "\nhttps://docs.ray.io/en/releases-2.2.0/ray-core/" + "objects/serialization.html" + ) + raise SerializationError(str(e) + msg) + else: + raise e + def get_collaborator_kwargs(self, collaborator_name: str): """Returns kwargs of collaborator. @@ -556,14 +574,14 @@ def get_collaborator_kwargs(self, collaborator_name: str): return kwargs - def initialize_aggregator(self): + def _initialize_aggregator(self): """Initialize aggregator private attributes.""" if self.backend == "single_process": self._aggregator.initialize_private_attributes() else: ray.get(self._aggregator.initialize_private_attributes.remote()) - def initialize_collaborators(self): + def _initialize_collaborators(self): """Initialize collaborator private attributes.""" if self.backend == "single_process": @@ -578,7 +596,12 @@ def init_private_attrs(collab): for collaborator in self.__collaborators.values(): init_private_attrs(collaborator) - def restore_instance_snapshot(self, ctx: Type[FLSpec], instance_snapshot: List[Type[FLSpec]]): + def _initialize_private_attributes(self): + """Initializes private attributes for aggregator and collaborators.""" + self._initialize_aggregator() + self._initialize_collaborators() + + def _restore_instance_snapshot(self, ctx: Type[FLSpec], instance_snapshot: List[Type[FLSpec]]): """Restores attributes from backup (in instance snapshot) to context (ctx). @@ -602,20 +625,20 @@ def execute_agg_steps(self, ctx: Any, f_name: str, clones: Optional[Any] = None) f_name (str): The name of the function to be executed. clones (Optional[Any], optional): Clones if any. Defaults to None. """ + f = getattr(ctx, f_name) + # Join step if clones is not None: - f = getattr(ctx, f_name) f(clones) - else: - not_at_transition_point = True - while not_at_transition_point: - f = getattr(ctx, f_name) - f() - - f, parent_func = ctx.execute_task_args[:2] - if aggregator_to_collaborator(f, parent_func) or f.__name__ == "end": - not_at_transition_point = False - - f_name = f.__name__ + checkpoint(ctx, f) + return + not_at_transition_point = True + while not_at_transition_point: + f() + checkpoint(ctx, f) + f, parent_func = ctx.execute_task_args[1:3] + if aggregator_to_collaborator(f, parent_func) or f.__name__ == "end": + not_at_transition_point = False + f_name = f.__name__ def execute_collab_steps(self, ctx: Any, f_name: str): """Execute collaborator steps until at transition point. @@ -624,18 +647,17 @@ def execute_collab_steps(self, ctx: Any, f_name: str): ctx (Any): The context in which the function is executed. f_name (str): The name of the function to be executed. """ + f = getattr(ctx, f_name) not_at_transition_point = True while not_at_transition_point: - f = getattr(ctx, f_name) f() - - f, parent_func = ctx.execute_task_args[:2] + checkpoint(ctx, f) + f, parent_func = ctx.execute_task_args[1:3] if ctx._is_at_transition_point(f, parent_func): not_at_transition_point = False - f_name = f.__name__ - def execute_task(self, flspec_obj: Type[FLSpec], f: Callable, **kwargs): + def _execute_task(self, flspec_obj: Type[FLSpec], f: Callable, **kwargs): """Defines which function to be executed based on name and kwargs. Updates the arguments and executes until end is not reached. @@ -654,21 +676,19 @@ def execute_task(self, flspec_obj: Type[FLSpec], f: Callable, **kwargs): while f.__name__ != "end": if "foreach" in kwargs: - flspec_obj = self.execute_collab_task( + flspec_obj = self._execute_collab_task( flspec_obj, f, parent_func, instance_snapshot, **kwargs ) else: - flspec_obj = self.execute_agg_task(flspec_obj, f) - f, parent_func, instance_snapshot, kwargs = flspec_obj.execute_task_args + flspec_obj = self._execute_agg_task(flspec_obj, f) + _, f, parent_func, _, instance_snapshot, kwargs = flspec_obj.execute_task_args else: - flspec_obj = self.execute_agg_task(flspec_obj, f) - f = flspec_obj.execute_task_args[0] + flspec_obj = self._execute_agg_task(flspec_obj, f) - checkpoint(flspec_obj, f) artifacts_iter, _ = generate_artifacts(ctx=flspec_obj) return artifacts_iter() - def execute_agg_task(self, flspec_obj, f): + def _execute_agg_task(self, flspec_obj, f): """Performs execution of aggregator task. Args: @@ -703,7 +723,7 @@ def execute_agg_task(self, flspec_obj, f): gc.collect() return flspec_obj - def execute_collab_task(self, flspec_obj, f, parent_func, instance_snapshot, **kwargs): + def _execute_collab_task(self, flspec_obj, f, parent_func, instance_snapshot, **kwargs): """ Performs 1. Filter include/exclude @@ -721,13 +741,11 @@ def execute_collab_task(self, flspec_obj, f, parent_func, instance_snapshot, **k Returns: flspec_obj: updated FLSpec (flow) object """ - - flspec_obj._foreach_methods.append(f.__name__) selected_collaborators = getattr(flspec_obj, kwargs["foreach"]) self.selected_collaborators = selected_collaborators # filter exclude/include attributes for clone - self.filter_exclude_include(flspec_obj, f, selected_collaborators, **kwargs) + self._filter_exclude_include(flspec_obj, f, selected_collaborators, **kwargs) if self.backend == "ray": ray_executor = RayExecutor() @@ -763,7 +781,7 @@ def execute_collab_task(self, flspec_obj, f, parent_func, instance_snapshot, **k flspec_obj.execute_task_args = clone.execute_task_args # Restore the flspec_obj state if back-up is taken - self.restore_instance_snapshot(flspec_obj, instance_snapshot) + self._restore_instance_snapshot(flspec_obj, instance_snapshot) del instance_snapshot gc.collect() @@ -771,7 +789,7 @@ def execute_collab_task(self, flspec_obj, f, parent_func, instance_snapshot, **k self.join_step = True return flspec_obj - def filter_exclude_include(self, flspec_obj, f, selected_collaborators, **kwargs): + def _filter_exclude_include(self, flspec_obj, f, selected_collaborators, **kwargs): """ This function filters exclude/include attributes Args: @@ -792,6 +810,39 @@ def filter_exclude_include(self, flspec_obj, f, selected_collaborators, **kwargs setattr(clone, name, deepcopy(attr)) clone._foreach_methods = flspec_obj._foreach_methods + def _execute_flow(self, flspec_obj: Type[FLSpec]): + """Executes and updates the flow with the final attributes. + + Args: + flspec_obj: Reference to the FLSpec (flow) object. + """ + try: + # Execute all Participant (Aggregator & Collaborator) tasks and + # retrieve the final attributes. + # start step is the first task & invoked on aggregator through + # self._execute_task + final_attributes = self._execute_task( + flspec_obj, + flspec_obj.start, + ) + except Exception as e: + self.__handle_execution_exception(e) + + # Updating the flow state with the final attributes + for name, attr in final_attributes: + setattr(flspec_obj, name, attr) + + def run(self, flspec_obj: Type[FLSpec]): + """Runs the flow using the LocalRuntime. + + Args: + flspec_obj: Reference to the FLSpec (flow) object. Contains + information about task sequence, flow attributes. + """ + self._initialize_private_attributes() + flspec_obj.initialize_flow_state(self.collaborators, self.backend) + self._execute_flow(flspec_obj) + def __repr__(self): """Returns the string representation of the LocalRuntime object. diff --git a/openfl/experimental/workflow/utilities/runtime_utils.py b/openfl/experimental/workflow/utilities/runtime_utils.py index ec0121da80..61df095396 100644 --- a/openfl/experimental/workflow/utilities/runtime_utils.py +++ b/openfl/experimental/workflow/utilities/runtime_utils.py @@ -45,14 +45,14 @@ def parse_attrs(ctx, exclude=[], reserved_words=["next", "runtime", "input"]): return cls_attrs, valid_artifacts -def generate_artifacts(ctx, reserved_words=["next", "runtime", "input", "checkpoint"]): +def generate_artifacts(ctx, reserved_words=["next", "input", "checkpoint", "collaborators"]): """Generates artifacts from the given context, excluding specified reserved words. Args: ctx (any): The context to generate artifacts from. reserved_words (list, optional): A list of reserved words to exclude. - Defaults to ["next", "runtime", "input", "checkpoint"]. + Defaults to ["next", "input", "checkpoint", "collaborators"]. Returns: tuple: A tuple containing a generator of artifacts and a list of diff --git a/openfl/experimental/workflow/workspace_export/export.py b/openfl/experimental/workflow/workspace_export/export.py index 14750de2ff..1a65649b59 100644 --- a/openfl/experimental/workflow/workspace_export/export.py +++ b/openfl/experimental/workflow/workspace_export/export.py @@ -13,7 +13,7 @@ from logging import getLogger from pathlib import Path from shutil import copytree -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import nbformat import yaml @@ -127,15 +127,45 @@ def __convert_to_python(self, notebook_path: Path, output_path: Path, export_fil return Path(output_path).joinpath(export_filename).resolve() + def __extract_runtime_instance_names(self) -> List[str]: + """ + Identifies instances of given classes in the script and returns + their variable names. + + Returns: + List[str]: A list of runtime instances variable names. + """ + instance_names = [] + class_names = ["LocalRuntime", "FederatedRuntime"] + + # Open the script file and read its contents while filtering out shell-style commands (!, %) + with open(self.script_path, "r") as file: + code = "".join(line for line in file if not line.lstrip().startswith(("!", "%"))) + tree = ast.parse(code) + for node in ast.walk(tree): + # Check if the node represents a variable assignment where a class is being instantiated + if isinstance(node, ast.Assign) and isinstance(node.value, ast.Call): + # Ensure the function being called is a direct class instantiation by name + if isinstance(node.value.func, ast.Name) and node.value.func.id in class_names: + # Extract the variable name(s) the instance is assigned to + for target in node.targets: + if isinstance(target, ast.Name): + instance_names.append(target.id) + return instance_names + def __comment_flow_execution(self) -> None: - """In the python script search for ".run()" and comment it.""" + """Search and comment runtime_instance.run(...) in python script. + runtime_instance could be an instance of either LocalRuntime or FederatedRuntime. + """ + runtime_instance_names = self.__extract_runtime_instance_names() with open(self.script_path, "r") as f: data = f.readlines() - for idx, line in enumerate(data): - if ".run()" in line: - data[idx] = f"# {line}" - with open(self.script_path, "w") as f: - f.writelines(data) + for runtime_instance_name in runtime_instance_names: + for idx, line in enumerate(data): + if f"{runtime_instance_name}.run(" in line: + data[idx] = f"# {line}" + with open(self.script_path, "w") as f: + f.writelines(data) def __change_runtime(self) -> None: """Change the LocalRuntime backend from ray to single_process.""" diff --git a/tests/github/experimental/workflow/FederatedRuntime/testcase_datastore_cli/workspace/testflow_datastore_cli.ipynb b/tests/github/experimental/workflow/FederatedRuntime/testcase_datastore_cli/workspace/testflow_datastore_cli.ipynb index 9e518febcb..0b69d5c69f 100644 --- a/tests/github/experimental/workflow/FederatedRuntime/testcase_datastore_cli/workspace/testflow_datastore_cli.ipynb +++ b/tests/github/experimental/workflow/FederatedRuntime/testcase_datastore_cli/workspace/testflow_datastore_cli.ipynb @@ -145,7 +145,6 @@ " print(\n", " \"Testing FederatedFlow - Starting Test for Dataflow and CLI Functionality\"\n", " )\n", - " self.collaborators = self.runtime.collaborators\n", " self.private = 10\n", " self.next(\n", " self.aggregated_model_validation,\n", @@ -389,8 +388,7 @@ "source": [ "#| export\n", "\n", - "flflow = TestFlowDatastoreAndCli(checkpoint=True)\n", - "flflow.runtime = federated_runtime" + "flflow = TestFlowDatastoreAndCli(checkpoint=True)" ] }, { @@ -400,7 +398,7 @@ "metadata": {}, "outputs": [], "source": [ - "flflow.run()" + "federated_runtime.run(flflow)" ] }, { @@ -416,7 +414,7 @@ ], "metadata": { "kernelspec": { - "display_name": "fed_run", + "display_name": "refactor", "language": "python", "name": "python3" }, @@ -430,7 +428,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.10.16" } }, "nbformat": 4, diff --git a/tests/github/experimental/workflow/FederatedRuntime/testcase_include_exclude/workspace/testflow_include_exclude.ipynb b/tests/github/experimental/workflow/FederatedRuntime/testcase_include_exclude/workspace/testflow_include_exclude.ipynb index 77b876ff93..e9bafce19e 100644 --- a/tests/github/experimental/workflow/FederatedRuntime/testcase_include_exclude/workspace/testflow_include_exclude.ipynb +++ b/tests/github/experimental/workflow/FederatedRuntime/testcase_include_exclude/workspace/testflow_include_exclude.ipynb @@ -108,7 +108,6 @@ " f\"{bcolors.OKBLUE}Testing FederatedFlow - Starting Test for Include and Exclude \"\n", " + f\"Attributes {bcolors.ENDC}\"\n", " )\n", - " self.collaborators = self.runtime.collaborators\n", "\n", " self.exclude_agg_to_agg = 10\n", " self.include_agg_to_agg = 100\n", @@ -141,7 +140,7 @@ " self.next(\n", " self.test_include_exclude_agg_to_collab,\n", " foreach=\"collaborators\",\n", - " include=[\"include_agg_to_collab\", \"collaborators\"],\n", + " include=[\"include_agg_to_collab\"],\n", " )\n", "\n", " @collaborator\n", @@ -315,8 +314,7 @@ "source": [ "#| export\n", "\n", - "flflow = TestFlowIncludeExclude(checkpoint=True)\n", - "flflow.runtime = federated_runtime\n" + "flflow = TestFlowIncludeExclude(checkpoint=True)" ] }, { @@ -326,7 +324,7 @@ "metadata": {}, "outputs": [], "source": [ - "flflow.run()" + "federated_runtime.run(flflow)" ] }, { @@ -342,7 +340,7 @@ ], "metadata": { "kernelspec": { - "display_name": "dir-wip", + "display_name": "refactor", "language": "python", "name": "python3" }, @@ -356,7 +354,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.19" + "version": "3.10.16" } }, "nbformat": 4, diff --git a/tests/github/experimental/workflow/FederatedRuntime/testcase_internalloop/workspace/testflow_internal_loop.ipynb b/tests/github/experimental/workflow/FederatedRuntime/testcase_internalloop/workspace/testflow_internal_loop.ipynb index 190a6eac0e..7e89122f32 100644 --- a/tests/github/experimental/workflow/FederatedRuntime/testcase_internalloop/workspace/testflow_internal_loop.ipynb +++ b/tests/github/experimental/workflow/FederatedRuntime/testcase_internalloop/workspace/testflow_internal_loop.ipynb @@ -108,7 +108,6 @@ " + f\" of Training Rounds: {self.training_rounds}{bcolors.ENDC}\"\n", " )\n", " self.model = np.zeros((10, 10, 10)) # Test model\n", - " self.collaborators = self.runtime.collaborators\n", " self.next(self.agg_model_mean, foreach=\"collaborators\")\n", "\n", " @collaborator\n", @@ -336,8 +335,7 @@ "source": [ "#| export\n", "\n", - "flflow = TestFlowInternalLoop(checkpoint=True)\n", - "flflow.runtime = federated_runtime\n" + "flflow = TestFlowInternalLoop(checkpoint=True)" ] }, { @@ -347,7 +345,7 @@ "metadata": {}, "outputs": [], "source": [ - "flflow.run()" + "federated_runtime.run(flflow)" ] }, { @@ -363,7 +361,7 @@ ], "metadata": { "kernelspec": { - "display_name": "fed_run", + "display_name": "refactor", "language": "python", "name": "python3" }, @@ -377,7 +375,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.10.16" } }, "nbformat": 4, diff --git a/tests/github/experimental/workflow/FederatedRuntime/testcase_private_attributes/workspace/testflow_privateattributes.ipynb b/tests/github/experimental/workflow/FederatedRuntime/testcase_private_attributes/workspace/testflow_privateattributes.ipynb index bbf5cf2f34..897a1c0387 100644 --- a/tests/github/experimental/workflow/FederatedRuntime/testcase_private_attributes/workspace/testflow_privateattributes.ipynb +++ b/tests/github/experimental/workflow/FederatedRuntime/testcase_private_attributes/workspace/testflow_privateattributes.ipynb @@ -107,7 +107,6 @@ " f\"{bcolors.OKBLUE}Testing FederatedFlow - Starting Test for accessibility of private \"\n", " + f\"attributes {bcolors.ENDC}\"\n", " )\n", - " self.collaborators = self.runtime.collaborators\n", "\n", " validate_agg_private_attr(self, \"start\", aggr=[\"test_loader_agg\"], collabs=[\"train_loader\", \"test_loader\"])\n", "\n", @@ -251,11 +250,7 @@ " )\n", " for idx, collab in enumerate(self.collaborators):\n", " # Collaborator attributes should not be accessible in aggregator step\n", - " if (\n", - " type(self.collaborators[idx]) is not str\n", - " or hasattr(self.runtime, \"_collaborators\")\n", - " or hasattr(self.runtime, \"__collaborators\")\n", - " ):\n", + " if type(self.collaborators[idx]) is not str:\n", " # Error - we are able to access collaborator attributes\n", " TestFlowPrivateAttributes.ERROR_LIST.append(\n", " step_name + \"_collaborator_attributes_found\"\n", @@ -301,19 +296,7 @@ " print(\n", " f\"{bcolors.FAIL} ... Attribute test failed in {step_name} - Aggregator\"\n", " + f\" private attributes accessible: {','.join(breached_agg_attr)} {bcolors.ENDC}\"\n", - " )\n", - "\n", - " # Aggregator attributes should not be accessible in collaborator step\n", - " if hasattr(self.runtime, \"_aggregator\") and isinstance(self.runtime._aggregator, Aggregator):\n", - " # Error - we are able to access aggregator attributes\n", - " TestFlowPrivateAttributes.ERROR_LIST.append(\n", - " step_name + \"_aggregator_attributes_found\"\n", - " )\n", - " print(\n", - " f\"{bcolors.FAIL} ... Attribute test failed in {step_name} - Aggregator\"\n", - " + f\" private attributes accessible {bcolors.ENDC}\"\n", - " )\n", - "\n" + " )\n" ] }, { @@ -362,8 +345,7 @@ "source": [ "#| export\n", "\n", - "flflow = TestFlowPrivateAttributes(checkpoint=True)\n", - "flflow.runtime = federated_runtime\n" + "flflow = TestFlowPrivateAttributes(checkpoint=True)\n" ] }, { @@ -373,7 +355,7 @@ "metadata": {}, "outputs": [], "source": [ - "flflow.run()" + "federated_runtime.run(flflow)" ] }, { @@ -389,7 +371,7 @@ ], "metadata": { "kernelspec": { - "display_name": "fed_run", + "display_name": "refactor", "language": "python", "name": "python3" }, @@ -403,7 +385,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.10.16" } }, "nbformat": 4, diff --git a/tests/github/experimental/workflow/FederatedRuntime/testcase_private_attributes_initialization_with_both_options/workspace/testflow_privateattributes.ipynb b/tests/github/experimental/workflow/FederatedRuntime/testcase_private_attributes_initialization_with_both_options/workspace/testflow_privateattributes.ipynb index 87a6ac01e9..943d9ae82a 100644 --- a/tests/github/experimental/workflow/FederatedRuntime/testcase_private_attributes_initialization_with_both_options/workspace/testflow_privateattributes.ipynb +++ b/tests/github/experimental/workflow/FederatedRuntime/testcase_private_attributes_initialization_with_both_options/workspace/testflow_privateattributes.ipynb @@ -107,7 +107,6 @@ " f\"{bcolors.OKBLUE}Testing FederatedFlow - Starting Test for accessibility of private \"\n", " + f\"attributes {bcolors.ENDC}\"\n", " )\n", - " self.collaborators = self.runtime.collaborators\n", "\n", " validate_agg_private_attr(self, \"start\", aggr=[\"test_loader_agg_via_callable\"], collabs=[\"train_loader_via_callable\", \"test_loader_via_callable\"])\n", "\n", @@ -251,11 +250,7 @@ " )\n", " for idx, collab in enumerate(self.collaborators):\n", " # Collaborator attributes should not be accessible in aggregator step\n", - " if (\n", - " type(self.collaborators[idx]) is not str\n", - " or hasattr(self.runtime, \"_collaborators\")\n", - " or hasattr(self.runtime, \"__collaborators\")\n", - " ):\n", + " if type(self.collaborators[idx]) is not str:\n", " # Error - we are able to access collaborator attributes\n", " TestFlowPrivateAttributes.ERROR_LIST.append(\n", " step_name + \"_collaborator_attributes_found\"\n", @@ -301,19 +296,7 @@ " print(\n", " f\"{bcolors.FAIL} ... Attribute test failed in {step_name} - Aggregator\"\n", " + f\" private attributes accessible: {','.join(breached_agg_attr)} {bcolors.ENDC}\"\n", - " )\n", - "\n", - " # Aggregator attributes should not be accessible in collaborator step\n", - " if hasattr(self.runtime, \"_aggregator\") and isinstance(self.runtime._aggregator, Aggregator):\n", - " # Error - we are able to access aggregator attributes\n", - " TestFlowPrivateAttributes.ERROR_LIST.append(\n", - " step_name + \"_aggregator_attributes_found\"\n", - " )\n", - " print(\n", - " f\"{bcolors.FAIL} ... Attribute test failed in {step_name} - Aggregator\"\n", - " + f\" private attributes accessible {bcolors.ENDC}\"\n", - " )\n", - "\n" + " )\n" ] }, { @@ -362,8 +345,7 @@ "source": [ "#| export\n", "\n", - "flflow = TestFlowPrivateAttributes(checkpoint=True)\n", - "flflow.runtime = federated_runtime\n" + "flflow = TestFlowPrivateAttributes(checkpoint=True)" ] }, { @@ -373,7 +355,7 @@ "metadata": {}, "outputs": [], "source": [ - "flflow.run()" + "federated_runtime.run(flflow)" ] }, { @@ -389,7 +371,7 @@ ], "metadata": { "kernelspec": { - "display_name": "dir_shift", + "display_name": "refactor", "language": "python", "name": "python3" }, @@ -403,7 +385,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.20" + "version": "3.10.16" } }, "nbformat": 4, diff --git a/tests/github/experimental/workflow/FederatedRuntime/testcase_private_attributes_initialization_without_callable/workspace/testflow_private_attributes.ipynb b/tests/github/experimental/workflow/FederatedRuntime/testcase_private_attributes_initialization_without_callable/workspace/testflow_private_attributes.ipynb index bdbf091bee..58e6f9f2dd 100644 --- a/tests/github/experimental/workflow/FederatedRuntime/testcase_private_attributes_initialization_without_callable/workspace/testflow_private_attributes.ipynb +++ b/tests/github/experimental/workflow/FederatedRuntime/testcase_private_attributes_initialization_without_callable/workspace/testflow_private_attributes.ipynb @@ -98,7 +98,6 @@ " f\"{bcolors.OKBLUE}Testing FederatedFlow - Starting Test for accessibility of private \"\n", " + f\"attributes {bcolors.ENDC}\"\n", " )\n", - " self.collaborators = self.runtime.collaborators\n", "\n", " validate_agg_private_attr(self, \"start\", aggr=[\"test_loader_agg\"], collabs=[\"train_loader\", \"test_loader\"])\n", "\n", @@ -242,11 +241,7 @@ " )\n", " for idx, collab in enumerate(self.collaborators):\n", " # Collaborator attributes should not be accessible in aggregator step\n", - " if (\n", - " type(self.collaborators[idx]) is not str\n", - " or hasattr(self.runtime, \"_collaborators\")\n", - " or hasattr(self.runtime, \"__collaborators\")\n", - " ):\n", + " if type(self.collaborators[idx]) is not str:\n", " # Error - we are able to access collaborator attributes\n", " TestFlowPrivateAttributes.ERROR_LIST.append(\n", " step_name + \"_collaborator_attributes_found\"\n", @@ -291,17 +286,6 @@ " print(\n", " f\"{bcolors.FAIL} ... Attribute test failed in {step_name} - Aggregator\"\n", " + f\" private attributes accessible: {','.join(breached_agg_attr)} {bcolors.ENDC}\"\n", - " )\n", - "\n", - " # Aggregator attributes should not be accessible in collaborator step\n", - " if hasattr(self.runtime, \"_aggregator\") and isinstance(self.runtime._aggregator, Aggregator):\n", - " # Error - we are able to access aggregator attributes\n", - " TestFlowPrivateAttributes.ERROR_LIST.append(\n", - " step_name + \"_aggregator_attributes_found\"\n", - " )\n", - " print(\n", - " f\"{bcolors.FAIL} ... Attribute test failed in {step_name} - Aggregator\"\n", - " + f\" private attributes accessible {bcolors.ENDC}\"\n", " )\n" ] }, @@ -358,8 +342,7 @@ "source": [ "#| export\n", "\n", - "flflow = TestFlowPrivateAttributes(checkpoint=True)\n", - "flflow.runtime = federated_runtime\n" + "flflow = TestFlowPrivateAttributes(checkpoint=True)" ] }, { @@ -369,7 +352,7 @@ "metadata": {}, "outputs": [], "source": [ - "flflow.run()" + "federated_runtime.run(flflow)" ] }, { @@ -385,7 +368,7 @@ ], "metadata": { "kernelspec": { - "display_name": "dir-wip", + "display_name": "refactor", "language": "python", "name": "python3" }, @@ -399,7 +382,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.19" + "version": "3.10.16" } }, "nbformat": 4, diff --git a/tests/github/experimental/workflow/FederatedRuntime/testcase_reference/workspace/testflow_reference.ipynb b/tests/github/experimental/workflow/FederatedRuntime/testcase_reference/workspace/testflow_reference.ipynb index 311a443f09..d06b525c0e 100644 --- a/tests/github/experimental/workflow/FederatedRuntime/testcase_reference/workspace/testflow_reference.ipynb +++ b/tests/github/experimental/workflow/FederatedRuntime/testcase_reference/workspace/testflow_reference.ipynb @@ -142,7 +142,6 @@ " self.agg_attr_optimizer = optim.SGD(\n", " self.agg_attr_model.parameters(), lr=1e-3, momentum=1e-2\n", " )\n", - " self.collaborators = self.runtime.collaborators\n", "\n", " # get aggregator attributes\n", " agg_attr_list = filter_attrs(inspect.getmembers(self))\n", @@ -438,8 +437,7 @@ "source": [ "#| export\n", "\n", - "flflow = TestFlowReference(checkpoint=True)\n", - "flflow.runtime = federated_runtime\n" + "flflow = TestFlowReference(checkpoint=True)" ] }, { @@ -449,7 +447,7 @@ "metadata": {}, "outputs": [], "source": [ - "flflow.run()" + "federated_runtime.run(flflow)" ] }, { @@ -465,7 +463,7 @@ ], "metadata": { "kernelspec": { - "display_name": "dir_shift", + "display_name": "refactor", "language": "python", "name": "python3" }, @@ -479,7 +477,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.20" + "version": "3.10.16" } }, "nbformat": 4, diff --git a/tests/github/experimental/workflow/FederatedRuntime/testcase_reference_with_include_exclude/workspace/testflow_reference_with_include_exclude.ipynb b/tests/github/experimental/workflow/FederatedRuntime/testcase_reference_with_include_exclude/workspace/testflow_reference_with_include_exclude.ipynb index b584487489..4cb48cc05b 100644 --- a/tests/github/experimental/workflow/FederatedRuntime/testcase_reference_with_include_exclude/workspace/testflow_reference_with_include_exclude.ipynb +++ b/tests/github/experimental/workflow/FederatedRuntime/testcase_reference_with_include_exclude/workspace/testflow_reference_with_include_exclude.ipynb @@ -135,11 +135,10 @@ " self.agg_attr_optimizer = optim.SGD(\n", " self.agg_attr_model.parameters(), lr=1e-3, momentum=1e-2\n", " )\n", - " self.collaborators = self.runtime.collaborators\n", " self.next(\n", " self.test_create_collab_attr,\n", " foreach=\"collaborators\",\n", - " include=[\"collaborators\", \"agg_attr_list\"],\n", + " include=[\"agg_attr_list\"],\n", " )\n", "\n", " @collaborator\n", @@ -348,8 +347,7 @@ "source": [ "#| export\n", "\n", - "flflow = TestFlowReferenceWithIncludeExclude(checkpoint=True)\n", - "flflow.runtime = federated_runtime\n" + "flflow = TestFlowReferenceWithIncludeExclude(checkpoint=True)" ] }, { @@ -359,7 +357,7 @@ "metadata": {}, "outputs": [], "source": [ - "flflow.run()" + "federated_runtime.run(flflow)" ] }, { @@ -375,7 +373,7 @@ ], "metadata": { "kernelspec": { - "display_name": "dir-wip", + "display_name": "refactor", "language": "python", "name": "python3" }, @@ -389,7 +387,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.19" + "version": "3.10.16" } }, "nbformat": 4, diff --git a/tests/github/experimental/workflow/FederatedRuntime/testcase_subset_of_collaborators/workspace/testflow_subset_of_collaborators.ipynb b/tests/github/experimental/workflow/FederatedRuntime/testcase_subset_of_collaborators/workspace/testflow_subset_of_collaborators.ipynb index 3620e8e4a1..2ba82fc0e0 100644 --- a/tests/github/experimental/workflow/FederatedRuntime/testcase_subset_of_collaborators/workspace/testflow_subset_of_collaborators.ipynb +++ b/tests/github/experimental/workflow/FederatedRuntime/testcase_subset_of_collaborators/workspace/testflow_subset_of_collaborators.ipynb @@ -100,7 +100,6 @@ " f\"{bcolors.OKBLUE}Testing FederatedFlow - Starting Test for \"\n", " + f\"validating Subset of collaborators {bcolors.ENDC}\"\n", " )\n", - " self.collaborators = self.runtime.collaborators\n", "\n", " # select subset of collaborators\n", " self.subset_collabrators = self.collaborators[:2]\n", @@ -226,7 +225,7 @@ "\n", "federated_runtime = FederatedRuntime(\n", " collaborators= ['envoy_one', 'envoy_two', 'envoy_three', 'envoy_four'], \n", - " director=director_info, \n", + " director=director_info,\n", " notebook_path='./testflow_subset_of_collaborators.ipynb'\n", ")" ] @@ -250,8 +249,7 @@ "source": [ "#| export\n", "\n", - "flflow = TestFlowSubsetCollaborators(checkpoint=True)\n", - "flflow.runtime = federated_runtime" + "flflow = TestFlowSubsetCollaborators(checkpoint=True)" ] }, { @@ -261,7 +259,7 @@ "metadata": {}, "outputs": [], "source": [ - "flflow.run()" + "federated_runtime.run(flflow)" ] }, { @@ -277,7 +275,7 @@ ], "metadata": { "kernelspec": { - "display_name": "dir-wip", + "display_name": "refactor", "language": "python", "name": "python3" }, @@ -291,7 +289,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.19" + "version": "3.10.16" } }, "nbformat": 4, diff --git a/tests/github/experimental/workflow/LocalRuntime/testflow_datastore_cli.py b/tests/github/experimental/workflow/LocalRuntime/testflow_datastore_cli.py index 1dc792a413..6be491ac30 100644 --- a/tests/github/experimental/workflow/LocalRuntime/testflow_datastore_cli.py +++ b/tests/github/experimental/workflow/LocalRuntime/testflow_datastore_cli.py @@ -109,7 +109,6 @@ def start(self): print( "Testing FederatedFlow - Starting Test for Dataflow and CLI Functionality" ) - self.collaborators = self.runtime.collaborators self.private = 10 self.next( self.aggregated_model_validation, @@ -272,8 +271,8 @@ def validate_datastore_cli(flow_obj, expected_flow_steps, num_rounds): No issues found and below are the tests that ran successfully 1. Datastore steps and expected steps are matching 2. Task stdout and task stderr verified through metaflow cli is as expected - 3. Number of tasks are aligned with number of rounds and number\ - of collaborators {Bcolors.ENDC}""") + 3. Number of tasks are aligned with number of rounds and number \ + of collaborators {Bcolors.ENDC}""") def display_validate_errors(validate_flow_error): @@ -331,8 +330,7 @@ def callable_to_initialize_collaborator_private_attributes( model = None optimizer = None flflow = TestFlowDatastoreAndCli(model, optimizer, num_rounds, checkpoint=True) - flflow.runtime = local_runtime - flflow.run() + local_runtime.run(flflow) expected_flow_steps = [ "start", diff --git a/tests/github/experimental/workflow/LocalRuntime/testflow_exclude.py b/tests/github/experimental/workflow/LocalRuntime/testflow_exclude.py index 306960d1ca..805d91edeb 100644 --- a/tests/github/experimental/workflow/LocalRuntime/testflow_exclude.py +++ b/tests/github/experimental/workflow/LocalRuntime/testflow_exclude.py @@ -35,7 +35,6 @@ def start(self): f"{bcolors.OKBLUE}Testing FederatedFlow - Starting Test for Exclude Attributes " + f"{bcolors.ENDC}" ) - self.collaborators = self.runtime.collaborators self.exclude_agg_to_agg = 10 self.include_agg_to_agg = 100 @@ -205,9 +204,8 @@ def end(self): print(f"Local runtime collaborators = {local_runtime.collaborators}") flflow = TestFlowExclude(checkpoint=True) - flflow.runtime = local_runtime for i in range(5): print(f"Starting round {i}...") - flflow.run() + local_runtime.run(flflow) print(f"{bcolors.OKBLUE}End of Testing FederatedFlow {bcolors.ENDC}") diff --git a/tests/github/experimental/workflow/LocalRuntime/testflow_include.py b/tests/github/experimental/workflow/LocalRuntime/testflow_include.py index 8b5ce29c6f..aa6052a6c5 100644 --- a/tests/github/experimental/workflow/LocalRuntime/testflow_include.py +++ b/tests/github/experimental/workflow/LocalRuntime/testflow_include.py @@ -35,13 +35,12 @@ def start(self): f"{bcolors.OKBLUE}Testing FederatedFlow - Starting Test for Include Attributes " + f"{bcolors.ENDC}" ) - self.collaborators = self.runtime.collaborators self.exclude_agg_to_agg = 10 self.include_agg_to_agg = 100 self.next( self.test_include_agg_to_agg, - include=["include_agg_to_agg", "collaborators"], + include=["include_agg_to_agg"], ) @aggregator @@ -68,7 +67,7 @@ def test_include_agg_to_agg(self): self.next( self.test_include_agg_to_collab, foreach="collaborators", - include=["include_agg_to_collab", "collaborators"], + include=["include_agg_to_collab"], ) @collaborator @@ -208,9 +207,8 @@ def end(self): print(f"Local runtime collaborators = {local_runtime.collaborators}") flflow = TestFlowInclude(checkpoint=True) - flflow.runtime = local_runtime for i in range(5): print(f"Starting round {i}...") - flflow.run() + local_runtime.run(flflow) print(f"{bcolors.OKBLUE}End of Testing FederatedFlow {bcolors.ENDC}") diff --git a/tests/github/experimental/workflow/LocalRuntime/testflow_include_exclude.py b/tests/github/experimental/workflow/LocalRuntime/testflow_include_exclude.py index 4754a48426..44ead301e0 100644 --- a/tests/github/experimental/workflow/LocalRuntime/testflow_include_exclude.py +++ b/tests/github/experimental/workflow/LocalRuntime/testflow_include_exclude.py @@ -35,7 +35,6 @@ def start(self): f"{bcolors.OKBLUE}Testing FederatedFlow - Starting Test for Include and Exclude " + f"Attributes {bcolors.ENDC}" ) - self.collaborators = self.runtime.collaborators self.exclude_agg_to_agg = 10 self.include_agg_to_agg = 100 @@ -68,7 +67,7 @@ def test_include_exclude_agg_to_agg(self): self.next( self.test_include_exclude_agg_to_collab, foreach="collaborators", - include=["include_agg_to_collab", "collaborators"], + include=["include_agg_to_collab"], ) @collaborator @@ -224,9 +223,8 @@ def end(self): print(f"Local runtime collaborators = {local_runtime.collaborators}") flflow = TestFlowIncludeExclude(checkpoint=True) - flflow.runtime = local_runtime for i in range(5): print(f"Starting round {i}...") - flflow.run() + local_runtime.run(flflow) print(f"{bcolors.OKBLUE}End of Testing FederatedFlow {bcolors.ENDC}") diff --git a/tests/github/experimental/workflow/LocalRuntime/testflow_internalloop.py b/tests/github/experimental/workflow/LocalRuntime/testflow_internalloop.py index a27af8c3ad..6ca28942f8 100644 --- a/tests/github/experimental/workflow/LocalRuntime/testflow_internalloop.py +++ b/tests/github/experimental/workflow/LocalRuntime/testflow_internalloop.py @@ -40,7 +40,6 @@ def start(self): + f" of Training Rounds: {self.training_rounds}{bcolors.ENDC}" ) self.model = np.zeros((10, 10, 10)) # Test model - self.collaborators = self.runtime.collaborators self.next(self.agg_model_mean, foreach="collaborators") @collaborator @@ -242,8 +241,7 @@ def display_validate_errors(validate_flow_error): top_model_accuracy = 0 flflow = TestFlowInternalLoop(model, optimizer, 5, checkpoint=True) - flflow.runtime = local_runtime - flflow.run() + local_runtime.run(flflow) # Flow Test Begins expected_flow_steps = [ diff --git a/tests/github/experimental/workflow/LocalRuntime/testflow_privateattributes.py b/tests/github/experimental/workflow/LocalRuntime/testflow_privateattributes.py index 2b4be06f15..985ba1ddf4 100644 --- a/tests/github/experimental/workflow/LocalRuntime/testflow_privateattributes.py +++ b/tests/github/experimental/workflow/LocalRuntime/testflow_privateattributes.py @@ -37,7 +37,6 @@ def start(self): f"{bcolors.OKBLUE}Testing FederatedFlow - Starting Test for accessibility of private " + f"attributes {bcolors.ENDC}" ) - self.collaborators = self.runtime.collaborators validate_collab_private_attr(self, "test_loader", "start") @@ -157,11 +156,7 @@ def validate_collab_private_attr(self, private_attr, step_name): for idx, collab in enumerate(self.collaborators): # Collaborator private attributes should not be accessible - if ( - type(self.collaborators[idx]) is not str - or hasattr(self.runtime, "_collaborators") - or hasattr(self.runtime, "__collaborators") - ): + if type(self.collaborators[idx]) is not str: # Error - we are able to access collaborator attributes TestFlowPrivateAttributes.error_list.append( step_name + "_collaborator_attributes_found" @@ -183,16 +178,6 @@ def validate_agg_private_attrs(self, private_attr_1, private_attr_2, step_name): + f"private attributes not accessible {bcolors.ENDC}" ) - if hasattr(self.runtime, "_aggregator"): - # Error - we are able to access aggregator attributes - TestFlowPrivateAttributes.error_list.append( - step_name + "_aggregator_attributes_found" - ) - print( - f"{bcolors.FAIL} ... Attribute test failed in {step_name} - Aggregator" - + f" private attributes accessible {bcolors.ENDC}" - ) - if __name__ == "__main__": # Setup Aggregator with private attributes via callable function @@ -221,8 +206,8 @@ def callable_to_initialize_aggregator_private_attributes(): def callable_to_initialize_collaborator_private_attributes(index): return { - "train_loader": np.random.rand(idx * 50, 28, 28), - "test_loader": np.random.rand(idx * 10, 28, 28), + "train_loader": np.random.rand(index * 50, 28, 28), + "test_loader": np.random.rand(index * 10, 28, 28), } collaborators = [] @@ -245,9 +230,8 @@ def callable_to_initialize_collaborator_private_attributes(index): print(f"Local runtime collaborators = {local_runtime.collaborators}") flflow = TestFlowPrivateAttributes(checkpoint=True) - flflow.runtime = local_runtime for i in range(5): print(f"Starting round {i}...") - flflow.run() + local_runtime.run(flflow) print(f"{bcolors.OKBLUE}End of Testing FederatedFlow {bcolors.ENDC}") diff --git a/tests/github/experimental/workflow/LocalRuntime/testflow_privateattributes_initialization_with_both_options.py b/tests/github/experimental/workflow/LocalRuntime/testflow_privateattributes_initialization_with_both_options.py index 168a4953c2..a177c7e283 100644 --- a/tests/github/experimental/workflow/LocalRuntime/testflow_privateattributes_initialization_with_both_options.py +++ b/tests/github/experimental/workflow/LocalRuntime/testflow_privateattributes_initialization_with_both_options.py @@ -37,7 +37,6 @@ def start(self): f"{bcolors.OKBLUE}Testing FederatedFlow - Starting Test for accessibility of private " + f"attributes {bcolors.ENDC}" ) - self.collaborators = self.runtime.collaborators validate_collab_private_attr(self, "test_loader_via_callable", "start") @@ -157,11 +156,7 @@ def validate_collab_private_attr(self, private_attr, step_name): for idx, collab in enumerate(self.collaborators): # Collaborator private attributes should not be accessible - if ( - type(self.collaborators[idx]) is not str - or hasattr(self.runtime, "_collaborators") - or hasattr(self.runtime, "__collaborators") - ): + if type(self.collaborators[idx]) is not str: # Error - we are able to access collaborator attributes TestFlowPrivateAttributes.error_list.append( step_name + "_collaborator_attributes_found" @@ -183,16 +178,6 @@ def validate_agg_private_attrs(self, private_attr_1, private_attr_2, step_name): + f"private attributes not accessible {bcolors.ENDC}" ) - if hasattr(self.runtime, "_aggregator"): - # Error - we are able to access aggregator attributes - TestFlowPrivateAttributes.error_list.append( - step_name + "_aggregator_attributes_found" - ) - print( - f"{bcolors.FAIL} ... Attribute test failed in {step_name} - Aggregator" - + f" private attributes accessible {bcolors.ENDC}" - ) - if __name__ == "__main__": # Setup Aggregator with private attributes via callable function @@ -222,8 +207,8 @@ def callable_to_initialize_aggregator_private_attributes(): def callable_to_initialize_collaborator_private_attributes(index): return { - "train_loader_via_callable": np.random.rand(idx * 50, 28, 28), - "test_loader_via_callable": np.random.rand(idx * 10, 28, 28), + "train_loader_via_callable": np.random.rand(index * 50, 28, 28), + "test_loader_via_callable": np.random.rand(index * 10, 28, 28), } collaborators = [] @@ -250,9 +235,8 @@ def callable_to_initialize_collaborator_private_attributes(index): print(f"Local runtime collaborators = {local_runtime.collaborators}") flflow = TestFlowPrivateAttributes(checkpoint=True) - flflow.runtime = local_runtime for i in range(5): print(f"Starting round {i}...") - flflow.run() + local_runtime.run(flflow) print(f"{bcolors.OKBLUE}End of Testing FederatedFlow {bcolors.ENDC}") diff --git a/tests/github/experimental/workflow/LocalRuntime/testflow_privateattributes_initialization_without_callable.py b/tests/github/experimental/workflow/LocalRuntime/testflow_privateattributes_initialization_without_callable.py index 3f0da6060e..21b8020cd9 100644 --- a/tests/github/experimental/workflow/LocalRuntime/testflow_privateattributes_initialization_without_callable.py +++ b/tests/github/experimental/workflow/LocalRuntime/testflow_privateattributes_initialization_without_callable.py @@ -37,7 +37,6 @@ def start(self): f"{bcolors.OKBLUE}Testing FederatedFlow - Starting Test for accessibility of private " + f"attributes {bcolors.ENDC}" ) - self.collaborators = self.runtime.collaborators validate_collab_private_attr(self, "test_loader", "start") @@ -157,11 +156,7 @@ def validate_collab_private_attr(self, private_attr, step_name): for idx, collab in enumerate(self.collaborators): # Collaborator private attributes should not be accessible - if ( - type(self.collaborators[idx]) is not str - or hasattr(self.runtime, "_collaborators") - or hasattr(self.runtime, "__collaborators") - ): + if type(self.collaborators[idx]) is not str: # Error - we are able to access collaborator attributes TestFlowPrivateAttributes.error_list.append( step_name + "_collaborator_attributes_found" @@ -183,16 +178,6 @@ def validate_agg_private_attrs(self, private_attr_1, private_attr_2, step_name): + f"private attributes not accessible {bcolors.ENDC}" ) - if hasattr(self.runtime, "_aggregator"): - # Error - we are able to access aggregator attributes - TestFlowPrivateAttributes.error_list.append( - step_name + "_aggregator_attributes_found" - ) - print( - f"{bcolors.FAIL} ... Attribute test failed in {step_name} - Aggregator" - + f" private attributes accessible {bcolors.ENDC}" - ) - if __name__ == "__main__": # Setup aggregator private attributes @@ -231,9 +216,8 @@ def validate_agg_private_attrs(self, private_attr_1, private_attr_2, step_name): print(f"Local runtime collaborators = {local_runtime.collaborators}") flflow = TestFlowPrivateAttributes(checkpoint=True) - flflow.runtime = local_runtime for i in range(5): print(f"Starting round {i}...") - flflow.run() + local_runtime.run(flflow) print(f"{bcolors.OKBLUE}End of Testing FederatedFlow {bcolors.ENDC}") diff --git a/tests/github/experimental/workflow/LocalRuntime/testflow_reference.py b/tests/github/experimental/workflow/LocalRuntime/testflow_reference.py index de9b100d50..163cac1385 100644 --- a/tests/github/experimental/workflow/LocalRuntime/testflow_reference.py +++ b/tests/github/experimental/workflow/LocalRuntime/testflow_reference.py @@ -81,7 +81,6 @@ def test_create_agg_attr(self): self.agg_attr_optimizer = optim.SGD( self.agg_attr_model.parameters(), lr=1e-3, momentum=1e-2 ) - self.collaborators = self.runtime.collaborators # get aggregator attributes agg_attr_list = filter_attrs(inspect.getmembers(self)) @@ -367,8 +366,7 @@ def callable_to_initialize_collaborator_private_attributes(index): print(f"Local runtime collaborators = {local_runtime.collaborators}") testflow = TestFlowReference(checkpoint=True) - testflow.runtime = local_runtime for i in range(2): print(f"Starting round {i}...") - testflow.run() + local_runtime.run(testflow) diff --git a/tests/github/experimental/workflow/LocalRuntime/testflow_reference_with_exclude.py b/tests/github/experimental/workflow/LocalRuntime/testflow_reference_with_exclude.py index ce6aeb1701..29709652db 100644 --- a/tests/github/experimental/workflow/LocalRuntime/testflow_reference_with_exclude.py +++ b/tests/github/experimental/workflow/LocalRuntime/testflow_reference_with_exclude.py @@ -75,7 +75,6 @@ def test_create_agg_attr(self): self.agg_attr_optimizer = optim.SGD( self.agg_attr_model.parameters(), lr=1e-3, momentum=1e-2 ) - self.collaborators = self.runtime.collaborators self.next( self.test_create_collab_attr, @@ -269,8 +268,7 @@ def validate_references(matched_ref_dict): print(f"Local runtime collaborators = {local_runtime.collaborators}") testflow = TestFlowReferenceWithExclude(checkpoint=True) - testflow.runtime = local_runtime for i in range(5): print(f"Starting round {i}...") - testflow.run() + local_runtime.run(testflow) diff --git a/tests/github/experimental/workflow/LocalRuntime/testflow_reference_with_include.py b/tests/github/experimental/workflow/LocalRuntime/testflow_reference_with_include.py index c3757d9f96..f937668314 100644 --- a/tests/github/experimental/workflow/LocalRuntime/testflow_reference_with_include.py +++ b/tests/github/experimental/workflow/LocalRuntime/testflow_reference_with_include.py @@ -75,11 +75,10 @@ def test_create_agg_attr(self): self.agg_attr_optimizer = optim.SGD( self.agg_attr_model.parameters(), lr=1e-3, momentum=1e-2 ) - self.collaborators = self.runtime.collaborators self.next( self.test_create_collab_attr, foreach="collaborators", - include=["collaborators", "agg_attr_list"], + include=["agg_attr_list"], ) @collaborator @@ -266,8 +265,7 @@ def validate_references(matched_ref_dict): print(f"Local runtime collaborators = {local_runtime.collaborators}") testflow = TestFlowReferenceWithInclude(checkpoint=True) - testflow.runtime = local_runtime for i in range(5): print(f"Starting round {i}...") - testflow.run() + local_runtime.run(testflow) diff --git a/tests/github/experimental/workflow/LocalRuntime/testflow_subset_of_collaborators.py b/tests/github/experimental/workflow/LocalRuntime/testflow_subset_of_collaborators.py index cb031900c1..a39625cbec 100644 --- a/tests/github/experimental/workflow/LocalRuntime/testflow_subset_of_collaborators.py +++ b/tests/github/experimental/workflow/LocalRuntime/testflow_subset_of_collaborators.py @@ -44,7 +44,6 @@ def start(self): f"{bcolors.OKBLUE}Testing FederatedFlow - Starting Test for " + f"validating Subset of collaborators {bcolors.ENDC}" ) - self.collaborators = self.runtime.collaborators # select subset of collaborators self.subset_collabrators = self.collaborators[: random.choice(self.random_ints)] @@ -135,8 +134,7 @@ def callable_to_initialize_collaborator_private_attributes(collab_name): testflow_subset_collaborators = TestFlowSubsetCollaborators( checkpoint=True, random_ints=random_ints ) - testflow_subset_collaborators.runtime = local_runtime - testflow_subset_collaborators.run() + local_runtime.run(testflow_subset_collaborators) subset_collaborators = testflow_subset_collaborators.subset_collabrators collaborators_ran = testflow_subset_collaborators.collaborators_ran