diff --git a/openfl-tutorials/experimental/Privacy_Meter/cifar10_PM.py b/openfl-tutorials/experimental/Privacy_Meter/cifar10_PM.py index 33ba6c15ab..a6fdb0524e 100644 --- a/openfl-tutorials/experimental/Privacy_Meter/cifar10_PM.py +++ b/openfl-tutorials/experimental/Privacy_Meter/cifar10_PM.py @@ -115,7 +115,11 @@ def FedAvg(models): # NOQA: N802 state_dict[key] = np.sum( [state[key] for state in state_dicts], axis=0 ) / len(models) - new_model.load_state_dict(state_dict) + # Convert numpy arrays within the state dictionary to PyTorch tensors + state_dict_tensors = {key: torch.tensor(value, dtype=torch.float32).cpu() for key, value in state_dict.items()} + + # Load the converted state dictionary into the model + new_model.load_state_dict(state_dict_tensors) return new_model