diff --git a/docs/about/features.rst b/docs/about/features.rst index b632cbd1f3..da6a80c8c0 100644 --- a/docs/about/features.rst +++ b/docs/about/features.rst @@ -97,3 +97,16 @@ Quantitatively audit data privacy in statistical and machine learning algorithms features_index/privacy_meter +.. _secure_aggregation: + +--------------------- +Secure Aggregation +--------------------- + +In Federated Learning (FL), Secure Aggregation (SecAgg) is a technique that allows the participants to collaborate on the central model without revealing their individual contributions (local model updates). For more info see :doc:`features_index/secure_aggregation` + +.. toctree:: + :hidden: + + features_index/secure_aggregation + diff --git a/docs/about/features_index/secure_aggregation.rst b/docs/about/features_index/secure_aggregation.rst new file mode 100644 index 0000000000..b80465d011 --- /dev/null +++ b/docs/about/features_index/secure_aggregation.rst @@ -0,0 +1,120 @@ +.. # Copyright (C) 2020-2024 Intel Corporation +.. # SPDX-License-Identifier: Apache-2.0 + +Secure Aggregation +======================================= + +In Federated Learning (FL), Secure Aggregation (SecAgg) restricts the aggregator to only learn the summation or average of the updates from collaborators. + +OpenFL integrates `SecAgg `_ into TaskRunner API as well as the Workflow API. + +TaskRunner API +------------------------------------- + +OpenFL treats SecAgg as a core security feature and can be enabled for any experiment by simply modifying the plan. + +The following plan shows secure aggregation being enabled on `keras/mnist `_ workspace by simply modifying the plan. + +.. code-block:: yaml + :emphasize-lines: 10 + + aggregator: + settings: + best_state_path: save/best.pbuf + db_store_rounds: 2 + init_state_path: save/init.pbuf + last_state_path: save/last.pbuf + persist_checkpoint: true + persistent_db_path: local_state/tensor.db + rounds_to_train: 1 + secure_aggregation: true + template: openfl.component.Aggregator + assigner: + settings: + task_groups: + - name: learning + percentage: 1.0 + tasks: + - aggregated_model_validation + - train + - locally_tuned_model_validation + - name: evaluation + percentage: 0 + tasks: + - aggregated_model_validation + template: openfl.component.RandomGroupedAssigner + collaborator: + settings: + db_store_rounds: 1 + delta_updates: false + opt_treatment: RESET + template: openfl.component.Collaborator + compression_pipeline: + settings: {} + template: openfl.pipelines.NoCompressionPipeline + data_loader: + settings: + batch_size: 256 + collaborator_count: 2 + data_group_name: mnist + template: src.dataloader.KerasMNISTInMemory + network: + settings: + agg_addr: localhost + agg_port: 53788 + cert_folder: cert + client_reconnect_interval: 5 + hash_salt: auto + require_client_auth: true + use_tls: true + template: openfl.federation.Network + task_runner: + settings: {} + template: src.taskrunner.KerasCNN + tasks: + aggregated_model_validation: + function: validate_task + kwargs: + apply: global + batch_size: 32 + metrics: + - accuracy + locally_tuned_model_validation: + function: validate_task + kwargs: + apply: local + batch_size: 32 + metrics: + - accuracy + settings: {} + train: + function: train_task + kwargs: + batch_size: 32 + epochs: 1 + metrics: + - loss + +As can be seen in the above plan, by only enabling ``aggregator.settings.secure_aggregation`` in the workspace plan, one can enable SecAgg. 
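For quick reference, that flag is the only deviation from the stock `keras/mnist` plan; a minimal override sketch (every other setting keeps its workspace default) looks like:

.. code-block:: yaml

    aggregator:
      settings:
        secure_aggregation: true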
+ +After the flags have been set in plan.yml and the setup for the experiment is completed, one can verify that SecAgg was enabled by looking at the aggregator logs + +.. code-block:: bash + + [21:55:01] INFO SecAgg: recreated secrets successfully setup.py:281 + INFO SecAgg: setup completed, saved required tensors to db + +Similarly, in the collaborator logs + +.. code-block:: bash + + INFO Secure aggregation is enabled, starting setup... secure_aggregation.py:48 + [21:55:01] INFO SecAgg: setup completed, saved required tensors to db. + + +Workflow API +------------------------------------- + +OpenFL provides `utility functions `_ that can be utilised to perform SecAgg in Workflow API. + +An example notebook can be found `here `_ that showcases how the secure aggregation flow can be achieved in Workflow API using both, LocalRuntime and FederatedRuntime. \ No newline at end of file diff --git a/openfl/callbacks/__init__.py b/openfl/callbacks/__init__.py index 8cd3f911ba..b8a8cec0a3 100644 --- a/openfl/callbacks/__init__.py +++ b/openfl/callbacks/__init__.py @@ -5,3 +5,4 @@ from openfl.callbacks.lambda_callback import LambdaCallback from openfl.callbacks.memory_profiler import MemoryProfiler from openfl.callbacks.metric_writer import MetricWriter +from openfl.callbacks.secure_aggregation import SecAggBootstrapping diff --git a/openfl/callbacks/callback_list.py b/openfl/callbacks/callback_list.py index 29661fff6a..75dc48952a 100644 --- a/openfl/callbacks/callback_list.py +++ b/openfl/callbacks/callback_list.py @@ -29,7 +29,7 @@ def __init__( **params, ): super().__init__() - self.callbacks = _flatten(callbacks) if callbacks else [] + self.callbacks = list(_flatten(callbacks)) if callbacks else [] self._add_default_callbacks(add_memory_profiler, add_metric_writer) @@ -77,7 +77,7 @@ def on_round_end(self, round_num: int, logs=None): def on_experiment_begin(self, logs=None): for callback in self.callbacks: - callback.on_experiment_begin(logs) + callback.on_experiment_begin(logs=logs) def on_experiment_end(self, logs=None): for callback in self.callbacks: diff --git a/openfl/callbacks/secure_aggregation.py b/openfl/callbacks/secure_aggregation.py new file mode 100644 index 0000000000..7a74a8b899 --- /dev/null +++ b/openfl/callbacks/secure_aggregation.py @@ -0,0 +1,280 @@ +# Copyright 2020-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +""" +This file contains callback that help setup for secure aggregation for the +collaborator. +""" + +import json +import logging +import struct + +import numpy as np + +from openfl.callbacks.callback import Callback +from openfl.protocols import utils +from openfl.utilities import TensorKey +from openfl.utilities.secagg import ( + calculate_shared_mask, + create_ciphertext, + create_secret_shares, + decipher_ciphertext, + generate_agreed_key, + generate_key_pair, + pseudo_random_generator, +) + +logger = logging.getLogger(__name__) + + +class SecAggBootstrapping(Callback): + """ + This callback is used by the collaborator to perform secure aggregation + bootstrapping. + + Required params include: + - origin: Name of the collaborator using the callback. + - client: AggregatorGRPCClient to communicate with the aggregator server. + + It also requires the tensor-db client to be set. + """ + + def on_experiment_begin(self, logs=None): + """ + Used to perform secure aggregation setup before experiment begins. 
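        The bootstrapping sequence mirrors the private helpers called below:
        generate a key pair, fetch every collaborator's public keys from the
        aggregator, exchange ciphertexts carrying seed and key shares, decrypt
        the ciphertexts addressed to this collaborator, and finally derive and
        cache the private and shared masks in the tensor database.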
+ """ + self.name = self.params["origin"] + self.client = self.params["client"] + logger.info("Secure aggregation is enabled, starting setup...") + # Generate private public key pair used for secure aggregation. + self._generate_keys() + # Fetch public keys for all collaborators from the aggregator. + collaborator_keys = self._fetch_public_keys() + # Generate ciphertexts for each collaborator and share them with the + # aggregator. + self._generate_ciphertexts(collaborator_keys) + # Decrypt the addressed ciphertexts and share them with the + # aggregator. + self._decrypt_ciphertexts(collaborator_keys) + # Save the tensors which are required for masking of gradients. + self._save_mask_tensors() + + def _generate_keys(self): + """ + Generates a pair of private and public keys, along with a private seed, + and updates the local and global results. + + This method performs the following steps: + 1. Generates two pairs of private and public keys. + 2. Creates a local result dictionary containing the private keys, + public keys, and a private seed. + 3. Creates a global result dictionary containing the public keys. + 4. Sends the global results to the participant via the aggregator's + secure aggregation mechanism. + 5. Updates the instance parameters with the local result. + """ + private_key1, public_key1 = generate_key_pair() + private_key2, public_key2 = generate_key_pair() + + local_result = { + "private_key": [private_key1, private_key2], + "public_key": [public_key1, public_key2], + "private_seed": np.random.random(), + } + global_results = { + "public_key": [public_key1, public_key2], + } + + self._send_to_aggregator(global_results, "generate_keys") + # Update callback params as the results for this step are reused at a + # later stage. + self.params.update(local_result) + logger.debug("SecAgg: Generate key-pair generation successful") + + def _fetch_public_keys(self): + """ + Fetches collaborators' public keys from the aggregator and identifies + the index of the current collaborator using it's public key. + + Returns: + dict: A dictionary containing the public keys of all collaborators, + where the keys are the collaborator indices and the values are + the public keys. + """ + public_keys = {} + public_keys_tensor = self._fetch_from_aggregator("public_keys") + for tensor in public_keys_tensor: + # Creating a dictionary of the received public keys. + public_keys[int(tensor[0])] = [tensor[1], tensor[2]] + # Finding the index of the current collaborator by matching the + # first public key. + if tensor[1] == self.params["public_key"][0]: + self.index = int(tensor[0]) + + return public_keys + + def _generate_ciphertexts(self, public_keys): + """ + Generate ciphertexts for secure aggregation. + + This method generates ciphertexts for each collaborator using their + public keys. It creates secret shares for the private seed and private + key, then uses these shares to generate agreed keys and ciphertexts + for secure communication between collaborators. + + Args: + public_keys (dict): A dictionary where keys are collaborator + indices and values are lists containing public keys of the + collaborators. + """ + logger.debug("SecAgg: Generating ciphertexts to be shared with other collaborators") + collaborator_count = len(public_keys) + + private_seed = self.params["private_seed"] + seed_shares = create_secret_shares( + # Converts the floating-point number private_seed into an 8-byte + # binary representation. 
+ struct.pack("d", private_seed), + collaborator_count, + collaborator_count, + ) + + private_keys = self.params["private_key"] + # Create secret shares for the private key. + key_shares = create_secret_shares( + str.encode(private_keys[0]), + collaborator_count, + collaborator_count, + ) + + global_results = {"ciphertext": []} + local_result = {"ciphertext_verification": {}, "agreed_keys": []} + # Create cipher-texts for each collaborator. + for collab_index in public_keys: + agreed_key = generate_agreed_key(private_keys[0], public_keys[collab_index][0]) + ciphertext, mac, nonce = create_ciphertext( + agreed_key, # agreed key + self.index, # source collaborator index + collab_index, # destination collaborator index + seed_shares[collab_index], # seed share from source to dest + key_shares[collab_index], # key share from source to dest + ) + global_results["ciphertext"].append((self.index, collab_index, str(ciphertext))) + local_result["ciphertext_verification"][collab_index] = [ciphertext, mac, nonce] + local_result["agreed_keys"].append([self.index, collab_index, agreed_key]) + + self._send_to_aggregator(global_results, "generate_ciphertexts") + # Update callback params as the results for this step are reused at a + # later stage. + self.params.update(local_result) + + logger.debug("SecAgg: Ciphertexts shared with the aggregator successfully") + + def _decrypt_ciphertexts(self, public_keys): + """ + Decrypts the ciphertexts received from collaborators using the provided + public keys. + + This method fetches the ciphertexts from the aggregator, decrypts them + using the collaborator's private key and the provided public keys, and + then sends the decrypted seed shares and key shares back to the + aggregator. + + Args: + public_keys (dict): A dictionary containing the public keys of the + collaborators. + """ + logger.debug("SecAgg: fetching addressed ciphertexts from the aggregator") + + ciphertexts = self._fetch_from_aggregator("ciphertexts") + private_keys = self.params["private_key"] + ciphertext_verification = self.params["ciphertext_verification"] + + global_results = {"seed_share": [], "key_share": []} + + for cipher in ciphertexts: + source_index = int(cipher[0]) + if int(cipher[1]) == self.index: + _, _, seed_share, key_share = decipher_ciphertext( + generate_agreed_key(private_keys[0], public_keys[source_index][0]), + ciphertext_verification[source_index][0], + ciphertext_verification[source_index][1], + ciphertext_verification[source_index][2], + ) + global_results["seed_share"].append((source_index, self.index, str(seed_share))) + global_results["key_share"].append((source_index, self.index, str(key_share))) + + self._send_to_aggregator(global_results, "decrypt_ciphertexts") + + logger.debug("SecAgg: decrypted ciphertexts shared with the aggregator") + + def _generate_masks(self): + """ + Use the private seed and agreed keys to calculate the masks to be + added to the gradients. + """ + private_mask = pseudo_random_generator(self.params.get("private_seed")) + shared_mask = calculate_shared_mask(self.params.get("agreed_keys")) + + return private_mask, shared_mask + + def _save_mask_tensors(self): + """ + Generates private and shared masks, stores them in a local tensor + dictionary, and caches the dictionary in the tensor database. + + These tensors are then added to the gradient before sharing them + with the aggregator during training task. 
+ """ + private_mask, shared_mask = self._generate_masks() + local_tensor_dict = { + TensorKey("private_mask", self.name, -1, False, ("secagg",)): [private_mask], + TensorKey("shared_mask", self.name, -1, False, ("secagg",)): [shared_mask], + } + self.tensor_db.cache_tensor(local_tensor_dict) + logger.info("SecAgg: setup completed, saved required tensors to db.") + + def _send_to_aggregator(self, tensor_dict: dict, stage: str): + """ + Sends the provided tensor dictionary to the aggregator after + compressing it. + + Args: + tensor_dict (dict): A dictionary where keys are tensor names and + values are numpy arrays. + stage (str): The current stage of the secure aggregation process. + """ + named_tensors = [] + # Covert python dict to tensor dict. + for key, nparray in tensor_dict.items(): + tensor_key = TensorKey( + key, + self.name, + -1, + False, + ( + self.name, + "secagg", + ), + ) + named_tensor = utils.construct_named_tensor( + tensor_key, str.encode(json.dumps(nparray)), {}, lossless=True + ) + named_tensors.append(named_tensor) + + self.client.send_local_task_results(self.name, -1, f"secagg_{stage}", -1, named_tensors) + + def _fetch_from_aggregator(self, key_name): + """ + Fetches the aggregated tensor data from a aggregator. + + Args: + key_name (str): The name of the key to fetch the tensor for. + + Returns: + bytes: The aggregated tensor data in bytes. + """ + tensor = self.client.get_aggregated_tensor( + self.name, key_name, -1, False, ("secagg",), True + ) + return json.loads(tensor.data_bytes) diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index 965fae11c5..c04149d035 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -3,6 +3,7 @@ """Aggregator module.""" +import json import logging import queue import time @@ -12,11 +13,12 @@ import openfl.callbacks as callbacks_module from openfl.component.aggregator.straggler_handling import CutoffTimePolicy, StragglerPolicy from openfl.databases import PersistentTensorDB, TensorDB -from openfl.interface.aggregation_functions import WeightedAverage +from openfl.interface.aggregation_functions import SecureWeightedAverage, WeightedAverage from openfl.pipelines import NoCompressionPipeline, TensorCodec from openfl.protocols import base_pb2, utils from openfl.protocols.base_pb2 import NamedTensor from openfl.utilities import TaskResultKey, TensorKey, change_tags +from openfl.utilities.secagg.setup import Setup as secagg_setup logger = logging.getLogger(__name__) @@ -81,9 +83,10 @@ def __init__( initial_tensor_dict=None, log_memory_usage=False, write_logs=False, - callbacks: Optional[List] = None, + callbacks: Optional[List] = [], persist_checkpoint=True, persistent_db_path=None, + secure_aggregation=False, ): """Initializes the Aggregator. 
@@ -203,6 +206,21 @@ def __init__( self.model: base_pb2.ModelProto = utils.load_proto(self.init_state_path) self._load_initial_tensors() # keys are TensorKeys + self.collaborator_tensor_results = {} # {TensorKey: nparray}} + self._secure_aggregation_enabled = secure_aggregation + if self._secure_aggregation_enabled: + self.secagg = secagg_setup(self.uuid, self.authorized_cols, self.tensor_db) + + # Callbacks + self.callbacks = callbacks_module.CallbackList( + callbacks, + add_memory_profiler=log_memory_usage, + add_metric_writer=write_logs, + tensor_db=self.tensor_db, + origin="aggregator", + collaborators=self.authorized_cols, + aggregator_uuid=self.uuid, + ) if self.persistent_db and self._recover(): logger.info("Recovered state of aggregator") @@ -634,6 +652,24 @@ def _nparray_to_named_tensor(self, tensor_key, nparray, send_model_deltas, compr """ tensor_name, origin, round_number, report, tags = tensor_key + # Secure aggregation setup tensor. + if "secagg" in tags: + import numpy as np + + class NumpyEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + return super().default(obj) + + named_tensor = utils.construct_named_tensor( + tensor_key, + str.encode(json.dumps(nparray, cls=NumpyEncoder)), + {}, + lossless=True, + ) + + return named_tensor # if we have an aggregated tensor, we can make a delta if "aggregated" in tags and send_model_deltas: # Should get the pretrained model to create the delta. If training @@ -716,9 +752,17 @@ def send_local_task_results( Returns: None """ - # Save task and its metadata for recovery - serialized_tensors = [tensor.SerializeToString() for tensor in named_tensors] + # Check if secure aggregation is enabled. + if self._secure_aggregation_enabled: + secagg_setup = self._secure_aggregation_setup(collaborator_name, named_tensors) + # Task results processing is not required if the tensors belong to + # secure aggregation setup stage. + if secagg_setup: + return + if self.persistent_db: + # Save task and its metadata for recovery + serialized_tensors = [tensor.SerializeToString() for tensor in named_tensors] self.persistent_db.save_task_results( collaborator_name, round_number, task_name, data_size, serialized_tensors ) @@ -859,6 +903,14 @@ def _process_named_tensor(self, named_tensor, collaborator_name): tuple(named_tensor.tags), ) tensor_name, origin, round_number, report, tags = tensor_key + # Secure aggregation setup stage key + if "secagg" in tags: + nparray = json.loads(raw_bytes) + self.tensor_db.cache_tensor({tensor_key: nparray}) + logger.debug("Created TensorKey: %s", tensor_key) + + return tensor_key, nparray + assert "compressed" in tags or "lossy_compressed" in tags, ( f"Named tensor {tensor_key} is not compressed" ) @@ -1049,7 +1101,14 @@ def _compute_validation_related_task_metrics(self, task_name) -> dict: # Strip the collaborator label, and lookup aggregated tensor new_tags = change_tags(tags, remove_field=collaborators_for_task[0]) agg_tensor_key = TensorKey(tensor_name, origin, round_number, report, new_tags) - agg_function = WeightedAverage() if "metric" in tags else task_agg_function + # Check if secure aggregation is enabled, set aggregation function. 
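            # When secure aggregation is enabled, "metric" tensors arrive masked
            # and must be combined with SecureWeightedAverage so the masks cancel
            # out; non-metric tensors keep the task-configured aggregation function.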
+ agg_function = ( + task_agg_function + if "metric" not in tags + else SecureWeightedAverage() + if self._secure_aggregation_enabled + else WeightedAverage() + ) agg_results = self.tensor_db.get_aggregated_tensor( agg_tensor_key, collaborator_weight_dict, @@ -1200,3 +1259,45 @@ def stop(self, failed_collaborator: str = None) -> None: collaborator_name, ) self.quit_job_sent_to.append(collaborator_name) + + def _secure_aggregation_setup(self, collaborator_name, named_tensors): + """ + Set up secure aggregation for the given collaborator and named tensors. + + This method processes named tensors that are part of the secure + aggregation setup stages. It saves the processed tensors to the local + tensor database and checks if all collaborators have sent their data + for the current key. If all collaborators have sent their data, it + proceeds with aggregation for the key. + + Args: + collaborator_name (str): The name of the collaborator sending the + tensors. + named_tensors (list): A list of named tensors to be processed. + + Returns: + bool: True if the setup is complete or if the tensor does not + belong to secure aggregation setup, otherwise waits for all + collaborators. + """ + secagg_setup = False + for named_tensor in named_tensors: + # Check if the tensor belongs to one from secure aggregation + # setup stages. + if "secagg" not in tuple(named_tensor.tags): + continue + else: + secagg_setup = True + # Process and save tensor to local tensor db. + self._process_named_tensor(named_tensor, collaborator_name) + tensor_name = named_tensor.name + # Check if all collaborators have sent their data for the + # current key. + all_collaborators_sent = self.secagg.check_tensors_received(tensor_name) + if not all_collaborators_sent: + continue + # If all collaborators have sent their data, proceed with + # aggregation for the key. + self.secagg.aggregate_tensor(tensor_name) + + return secagg_setup diff --git a/openfl/component/collaborator/collaborator.py b/openfl/component/collaborator/collaborator.py index 4a5a78329a..d152ac7ca6 100644 --- a/openfl/component/collaborator/collaborator.py +++ b/openfl/component/collaborator/collaborator.py @@ -85,7 +85,8 @@ def __init__( db_store_rounds=1, log_memory_usage=False, write_logs=False, - callbacks: Optional[List] = None, + callbacks: Optional[List] = [], + secure_aggregation=False, ): """Initialize the Collaborator object. @@ -148,12 +149,26 @@ def __init__( self.task_runner.set_optimizer_treatment(self.opt_treatment.name) + self._secure_aggregation_enabled = secure_aggregation + if self._secure_aggregation_enabled: + self._private_mask = None + self._shared_mask = None + secure_aggregation_callback = callbacks_module.SecAggBootstrapping() + if isinstance(callbacks, callbacks_module.Callback): + callbacks = [callbacks, secure_aggregation_callback] + elif isinstance(callbacks, list): + callbacks.append(secure_aggregation_callback) + else: + callbacks = [secure_aggregation_callback] + # Callbacks self.callbacks = callbacks_module.CallbackList( callbacks, add_memory_profiler=log_memory_usage, add_metric_writer=write_logs, + tensor_db=self.tensor_db, origin=self.collaborator_name, + client=self.client, ) def set_available_devices(self, cuda: Tuple[str] = ()): @@ -327,6 +342,10 @@ def do_task(self, task, round_number) -> dict: input_tensor_dict=input_tensor_dict, **kwargs, ) + # If secure aggregation is enabled, add masks to the dict to be shared + # with the aggregator. 
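        # Only tensors tagged "metric" are masked; the masks themselves were
        # derived and cached in the tensor DB by the SecAggBootstrapping
        # callback when the experiment began.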
+ if self._secure_aggregation_enabled: + self._secure_aggregation_masking(global_output_tensor_dict) # Save global and local output_tensor_dicts to TensorDB self.tensor_db.cache_tensor(global_output_tensor_dict) @@ -629,3 +648,42 @@ def named_tensor_to_nparray(self, named_tensor): self.tensor_db.cache_tensor({decompressed_tensor_key: decompressed_nparray}) return decompressed_nparray + + def _secure_aggregation_masking(self, global_output_tensor_dict): + """ + Apply secure aggregation masking to the global output tensor + dictionary. + + This method modifies the provided global output tensor dictionary by + applying secure aggregation masking if secure aggregation is enabled. + It fetches the private and shared masks from the tensor database and + applies them to the tensors in the global output tensor dictionary + that have the "metric" tag. + + Args: + global_output_tensor_dict (dict): A dictionary where keys are + tensor keys and values are the corresponding tensors. + + Returns: + None: The method modifies the global_output_tensor_dict in place. + """ + import numpy as np + + # Storing the masks as class attributes to reduce the number of + # lookups in the database. + # Fetch private mask from tensor db if not already fetched. + if not self._private_mask: + self._private_mask = self.tensor_db.get_tensor_from_cache( + TensorKey("private_mask", self.collaborator_name, -1, False, ("secagg",)) + )[0] + # Fetch shared mask from tensor db if not alreday fetched. + if not self._shared_mask: + self._shared_mask = self.tensor_db.get_tensor_from_cache( + TensorKey("shared_mask", self.collaborator_name, -1, False, ("secagg",)) + )[0] + + for tensor_key in global_output_tensor_dict: + _, _, _, _, tags = tensor_key + if "metric" in tags: + masked_metric = np.add(self._private_mask, global_output_tensor_dict[tensor_key]) + global_output_tensor_dict[tensor_key] = np.add(masked_metric, self._shared_mask) diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index 3b229e8dd5..24c721eeeb 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -513,6 +513,10 @@ def get_collaborator( defaults[SETTINGS]["compression_pipeline"] = self.get_tensor_pipe() defaults[SETTINGS]["task_config"] = self.config.get("tasks", {}) + # Check if secure aggregation is enabled. 
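        # The switch lives under aggregator.settings in the plan, so the
        # collaborator inherits it from there instead of requiring a second,
        # separately maintained flag.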
+ defaults[SETTINGS]["secure_aggregation"] = ( + self.config.get("aggregator", {}).get(SETTINGS, {}).get("secure_aggregation", False) + ) if client is not None: defaults[SETTINGS]["client"] = client else: diff --git a/openfl/interface/aggregation_functions/__init__.py b/openfl/interface/aggregation_functions/__init__.py index 0ee32655c6..b13962fe0d 100644 --- a/openfl/interface/aggregation_functions/__init__.py +++ b/openfl/interface/aggregation_functions/__init__.py @@ -11,5 +11,6 @@ from openfl.interface.aggregation_functions.fedcurv_weighted_average import FedCurvWeightedAverage from openfl.interface.aggregation_functions.geometric_median import GeometricMedian from openfl.interface.aggregation_functions.median import Median +from openfl.interface.aggregation_functions.secure_weighted_average import SecureWeightedAverage from openfl.interface.aggregation_functions.weighted_average import WeightedAverage from openfl.interface.aggregation_functions.yogi_adaptive_aggregation import YogiAdaptiveAggregation diff --git a/openfl/interface/aggregation_functions/secure_weighted_average.py b/openfl/interface/aggregation_functions/secure_weighted_average.py new file mode 100644 index 0000000000..85fe806859 --- /dev/null +++ b/openfl/interface/aggregation_functions/secure_weighted_average.py @@ -0,0 +1,185 @@ +# Copyright 2020-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +"""Federated averaging with secure aggregation module.""" + +import numpy as np + +from openfl.interface.aggregation_functions.weighted_average import WeightedAverage +from openfl.utilities import LocalTensor +from openfl.utilities.secagg import ( + calculate_shared_mask, + pseudo_random_generator, +) + + +class SecureWeightedAverage(WeightedAverage): + """Weighted average with secure aggregation.""" + + def __init__(self): + super().__init__() + self._private_masks, self._shared_masks = None, None + + def call(self, local_tensors, db_iterator, *_) -> np.ndarray: + """Aggregate tensors. + + Args: + local_tensors (list[openfl.utilities.LocalTensor]): List of local + tensors to aggregate. + db_iterator: iterator over history of all tensors. Columns: + - 'tensor_name': name of the tensor. + Examples for `torch.nn.Module`s: 'conv1.weight','fc2.bias'. + - 'round': 0-based number of round corresponding to this + tensor. + - 'tags': tuple of tensor tags. Tags that can appear: + - 'model' indicates that the tensor is a model parameter. + - 'trained' indicates that tensor is a part of a training + result. + These tensors are passed to the aggregator node after + local learning. + - 'aggregated' indicates that tensor is a result of + aggregation. + These tensors are sent to collaborators for the next + round. + - 'delta' indicates that value is a difference between + rounds for a specific tensor. + also one of the tags is a collaborator name + if it corresponds to a result of a local task. + + - 'nparray': value of the tensor. + tensor_name: name of the tensor + fl_round: round number + tags: tuple of tags for this tensor + Returns: + np.ndarray: aggregated tensor + """ + # Generate masks for the collaborators if not laready done. + self._generate_masks(db_iterator) + # Calaculate the weighted avreage of collaborator masks. + weighted_mask = self._calculcate_weighted_mask_average(self._private_masks, local_tensors) + # Get weighted average for shared tensors. + tensor_avg = super().call(local_tensors) + # Subtract weighted average of masks from the tensor average. 
+ return np.subtract(np.subtract(tensor_avg, weighted_mask), self._shared_masks) + + def _generate_masks(self, db_iterator): + """ + Generate shared and private masks for secure aggregation. + + This method processes a database iterator to extract private seeds, + agreed keys, and column indices, which are then used to generate + shared and private masks. + + Args: + db_iterator (iterator): An iterator over the database items + containing tensors with tags, tensor names, and numpy arrays. + + Raises: + KeyError: If the required keys are not found in the database items. + + Notes: + - The shared masks are calculated using the agreed keys. + - The private masks are generated for each collaborator using + their private seeds. + - The private masks are stored in a dictionary with the + collaborator's name as the key. + """ + if self._shared_masks and self._private_masks: + return + + private_seeds = [] + agreed_keys = [] + col_indices = [] + # Get all required values from tensor db. + private_seeds, agreed_keys, col_indices = self._get_secagg_items_from_db(db_iterator) + if not self._shared_masks: + # Calculate shared mask + self._shared_masks = calculate_shared_mask(agreed_keys) + + if not self._private_masks: + # Create a dict with collaborator index and their name. + # This dict is used to map private masks to the collaborator name + # as they are stored with collaborator index in the db. + col_idx = {} + for col in col_indices: + col_idx[col[1]] = col[0] + + del col_indices + + # Generate private masks for each collaborator. + self._private_masks = {} + for seed in private_seeds: + # col_name: col_private_mask + self._private_masks[col_idx[seed[0]]] = pseudo_random_generator(seed[1]) + + del col_idx + del private_seeds + + def _calculcate_weighted_mask_average( + self, private_masks: dict, local_tensors: list[LocalTensor] + ): + """ + Calculate the weighted mask average for the given local tensors using + their private masks and weight for their respective tensors. + + Args: + private_masks (dict): A dictionary where keys are collaborator + names and values are tuples, with the second element being the + mask for that colaborator. + local_tensors (list): A list of tensors, where each tensor has + attributes 'col_name' and 'weight'. + + Returns: + numpy.ndarray: The average mask calculated as the weighted + average of the masks. + """ + weights = [] + masks = [] + # Create a list of private masks and weights for the collaborators + # whose tensors are being aggregated. + for tensor in local_tensors: + col_name = tensor.col_name + weights.append(tensor.weight) + masks.append(private_masks[col_name]) + + # Calculate weighted mask using the masks and weights where each index + # in the lists represents a single collaborator. + weighted_mask = np.average(masks, weights=weights, axis=0) + + del weights + del masks + + return weighted_mask + + def _get_secagg_items_from_db(self, db_iterator): + """ + Extracts secure aggregation items from a database iterator. + It retrieves the private seeds, agreed keys, and column indices from + the database items. + + Args: + db_iterator (iterable): An iterator that yields database items. + Each item is expected to be a dictionary with keys "tags", + "tensor_name", and "nparray". + + Returns: + tuple: A tuple containing three elements: + - private_seeds (numpy.ndarray): The private seeds array. + - agreed_keys (numpy.ndarray): The agreed keys array. + - col_indices (numpy.ndarray): The column indices array. 
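            These are the arrays cached by the aggregator-side
            ``Setup._save_tensors`` under the ``("secagg",)`` tag, which is why
            the iterator is filtered on that exact tag below.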
+ + Raises: + KeyError: If any of the required keys ("tags", "tensor_name", + "nparray") are missing in an item. + """ + for item in db_iterator: + if "tags" in item and item["tags"] == ("secagg",): + if item["tensor_name"] == "private_seeds": + private_seeds = item["nparray"] + elif item["tensor_name"] == "agreed_keys": + agreed_keys = item["nparray"] + elif item["tensor_name"] == "indices": + col_indices = item["nparray"] + + return private_seeds, agreed_keys, col_indices diff --git a/openfl/utilities/secagg/setup.py b/openfl/utilities/secagg/setup.py new file mode 100644 index 0000000000..009c4d0eb4 --- /dev/null +++ b/openfl/utilities/secagg/setup.py @@ -0,0 +1,270 @@ +# Copyright 2020-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +""" +This file contains the Setup class used on the server side for secure +aggregation setup. +""" + +import logging + +from openfl.utilities import TensorKey +from openfl.utilities.secagg import ( + calculate_shared_mask, + generate_agreed_key, + reconstruct_secret, +) + +logger = logging.getLogger(__name__) + + +class Setup: + """ + Used by the aggregator for the setup stage of secure aggregation. + """ + + def __init__(self, aggregator_uuid, collaborator_list, tensor_db): + self._aggregator_uuid = aggregator_uuid + self._collaborator_list = collaborator_list + self._tensor_db = tensor_db + self._results = {} + + def check_tensors_received(self, tensor_name): + """ + Checks if the tensor with the given name has been received from all + collaborators. + + Args: + tensor_name (str): The name of the tensor to check. + + Returns: + bool: True if the tensor has been received from all collaborators, + False otherwise. + """ + logger.debug("Checking if received {} from all collaborators".format(tensor_name)) + all_received = True + for collaborator in self._collaborator_list: + nparray = self._tensor_db.get_tensor_from_cache( + TensorKey( + tensor_name, + self._aggregator_uuid, + -1, + False, + ( + collaborator, + "secagg", + ), + ) + ) + if nparray is None: + all_received = False + + return all_received + + def aggregate_tensor(self, tensor_name): + """ + Aggregates the specified tensor based on its name and performs + subsequent operations if necessary. + + Args: + tensor_name (str): The name of the tensor to aggregate. + It can be one of the following: "public_key", "ciphertext", + "seed_share", "key_share". + + Raises: + ValueError: If the tensor_name is not one of the expected values. + + Operations: + - Aggregates public keys if tensor_name is "public_key". + - Aggregates ciphertexts if tensor_name is "ciphertext". + - Aggregates seed shares if tensor_name is "seed_share". + - Aggregates key shares if tensor_name is "key_share". + - If both "seed_shares" and "key_shares" are present in the + results, it: + - Reconstructs secrets. + - Generates agreed keys between all pairs of collaborators. + - Saves the local tensors to the tensor database. + """ + if tensor_name == "public_key": + self._aggregate_public_keys() + elif tensor_name == "ciphertext": + self._aggregate_ciphertexts() + elif tensor_name in ["seed_share", "key_share"]: + self._aggregate_secret_shares(tensor_name) + + if "seed_shares" in self._results and "key_shares" in self._results: + self._reconstruct_secrets() + # Generate agreed keys between all pairs of collaborators. + self._generate_agreed_keys() + # Save the local tensors to the tensor database. 
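            # This persists "indices", "private_seeds", "agreed_keys" and
            # "shared_mask_sum" under the ("secagg",) tag; SecureWeightedAverage
            # reads the first three when removing the masks during aggregation.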
+ self._save_tensors() + + def _aggregate_public_keys(self): + """ + Sorts the public keys received from collaborators and updates the + results. + """ + aggregated_tensor = [] + self._results["public_keys"] = {} + self._results["index"] = {} + index = 1 + for collaborator in self._collaborator_list: + # Fetching public key for each collaborator from tensor db. + nparray = self._tensor_db.get_tensor_from_cache( + TensorKey( + "public_key", + self._aggregator_uuid, + -1, + False, + ( + collaborator, + "secagg", + ), + ) + ) + aggregated_tensor.append([index, nparray[0], nparray[1]]) + # Creating a map for local use. + self._results["public_keys"][index] = [nparray[0], nparray[1]] + self._results["index"][collaborator] = index + index += 1 + + # Storing the aggregated result in tensor db which is fetched by the + # collaborators in subsequent steps. + self._tensor_db.cache_tensor( + { + TensorKey( + "public_keys", self._aggregator_uuid, -1, False, ("secagg",) + ): aggregated_tensor + } + ) + + def _aggregate_ciphertexts(self): + """ + Sorts the ciphertexts received from collaborators and updates the + results. + """ + aggregated_tensor = [] + self._results["ciphertexts"] = [] + + for collaborator in self._collaborator_list: + # Fetching ciphertext for each collaborator from tensor db. + nparray = self._tensor_db.get_tensor_from_cache( + TensorKey( + "ciphertext", + self._aggregator_uuid, + -1, + False, + ( + collaborator, + "secagg", + ), + ) + ) + for ciphertext in nparray: + aggregated_tensor.append(ciphertext) + # Creating a map for local use. + self._results["ciphertexts"].append(ciphertext) + # Storing the aggregated result in tensor db which is fetched by the + # collaborators in subsequent steps. + self._tensor_db.cache_tensor( + { + TensorKey( + "ciphertexts", self._aggregator_uuid, -1, False, ("secagg",) + ): aggregated_tensor + } + ) + + def _aggregate_secret_shares(self, key_name): + """ + Aggregates secret shares for a given key name from the tensor database. + + This method fetches seed shares for each collaborator from the tensor + database and organizes them into a dictionary for local use. + + Args: + key_name (str): The name of the key for which secret shares are to + be aggregated. + """ + self._results[f"{key_name}s"] = {} + + for collaborator in self._collaborator_list: + # Fetching seed shares for each collaborator from tensor db. + nparray = self._tensor_db.get_tensor_from_cache( + TensorKey( + key_name, + self._aggregator_uuid, + -1, + False, + ( + collaborator, + "secagg", + ), + ) + ) + for share in nparray: + # Creating a map for local use. + if int(share[1]) not in self._results[f"{key_name}s"]: + self._results[f"{key_name}s"][int(share[1])] = {} + self._results[f"{key_name}s"][int(share[1])][int(share[0])] = share[2][2:-1] + + def _reconstruct_secrets(self): + """ + Reconstructs the private seeds and private keys from the secret shares. + """ + self._results["private_seeds"] = {} + self._results["private_keys"] = {} + + for source_id in self._results["seed_shares"]: + self._results["private_seeds"][source_id] = reconstruct_secret( + self._results["seed_shares"][source_id] + ) + self._results["private_keys"][source_id] = reconstruct_secret( + self._results["key_shares"][source_id] + ) + logger.info("SecAgg: recreated secrets successfully") + + def _generate_agreed_keys(self): + """ + Generates agreed keys between all pairs of collaborators using their + private keys and public keys. 
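        For every ordered (source, destination) pair with distinct indices,
        the key is derived from the reconstructed private key of the source
        and the first public key of the destination, matching the agreed
        keys each collaborator computed locally during bootstrapping.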
+ """ + self._results["agreed_keys"] = [] + for source_index in self._results["index"].values(): + for dest_index in self._results["index"].values(): + if source_index == dest_index: + continue + self._results["agreed_keys"].append( + [ + source_index, + dest_index, + generate_agreed_key( + self._results["private_keys"][source_index], + self._results["public_keys"][dest_index][0], + ), + ] + ) + + def _save_tensors(self): + """ + Generate and save tensors required for secure aggregation. + + This method generates private and shared masks by calling the + `_generate_masks` method. It then creates a dictionary of tensors + to be saved, which includes the sum of private and shared masks. + The tensors are cached in the tensor database. + + These tensors are then added to the gradient before to get the + actual aggregate after removing the masks. + """ + shared_mask_sum = calculate_shared_mask(self._results["agreed_keys"]) + local_tensor_dict = { + TensorKey("indices", "agg", -1, False, ("secagg",)): [ + [collaborator, index] for collaborator, index in self._results["index"].items() + ], + TensorKey("private_seeds", "agg", -1, False, ("secagg",)): [ + [index, seed] for index, seed in self._results["private_seeds"].items() + ], + TensorKey("agreed_keys", "agg", -1, False, ("secagg",)): self._results["agreed_keys"], + TensorKey("shared_mask_sum", "agg", -1, False, ("secagg",)): [shared_mask_sum], + } + self._tensor_db.cache_tensor(local_tensor_dict) + logger.info("SecAgg: setup completed, saved required tensors to db.") diff --git a/setup.py b/setup.py index f1de927a72..abc09cf5bd 100644 --- a/setup.py +++ b/setup.py @@ -94,6 +94,7 @@ def run(self): 'tensorboardX', 'protobuf>=4.22,<6.0.0', 'grpcio>=1.56.2,<1.66.0', + 'pycryptodome' ], python_requires='>=3.10, <3.13', project_urls={