This repository is the official implementation of *Personalized Federated Learning through Local Memorization*.
Federated learning allows clients to collaboratively learn statistical models while keeping their data local. Federated learning was originally used to train a unique global model to be served to all clients, but this approach might be sub-optimal when clients' local data distributions are heterogeneous. To tackle this limitation, recent personalized federated learning methods train a separate model for each client while still leveraging the knowledge available at other clients. In this work, we exploit the ability of deep neural networks to extract high-quality vectorial representations (embeddings) from non-tabular data, e.g., images and text, to propose a personalization mechanism based on local memorization. Personalization is obtained by interpolating a pre-trained global model with a k-nearest neighbors (kNN) model based on the shared representation provided by the global model. We provide generalization bounds for the proposed approach, and we show on a suite of federated datasets that this approach achieves significantly higher accuracy and fairness than state-of-the-art methods.
To install requirements:

```
pip install -r requirements.txt
```
Additionally, FAISS should be installed. Instructions for installing FAISS can be found here.
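For example, the CPU build can typically be installed via conda (this follows the FAISS installation guide; pick the variant matching your platform, e.g., `faiss-gpu` for CUDA):

```
conda install -c pytorch faiss-cpu
```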
We provide code to simulate the federated training of machine learning models.
The core objects are `Aggregator` and `Client`; different federated learning algorithms can be implemented by overriding the local update method `Client.step()` and/or the aggregation protocol defined in `Aggregator.mix()` and `Aggregator.update_client()`.
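As a rough illustration of how these pieces interact (a minimal sketch only; these are not the repository's actual class definitions, and all constructor arguments and attributes below are simplified assumptions), a FedAvg-style round could look like:

```python
# Conceptual sketch of the Aggregator/Client interplay. NOT the
# repository's actual classes: signatures and attributes are
# simplified assumptions for illustration only.
import copy
import torch
import torch.nn as nn

class Client:
    """Holds a local model and performs local updates in `step()`."""
    def __init__(self, model, loader, lr=0.05):
        self.model = model
        self.loader = loader
        self.optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    def step(self):
        # One local pass over the data; personalized methods override this.
        criterion = nn.CrossEntropyLoss()
        for x, y in self.loader:
            self.optimizer.zero_grad()
            criterion(self.model(x), y).backward()
            self.optimizer.step()

class Aggregator:
    """Owns the global model; `mix()` runs a communication round."""
    def __init__(self, clients, global_model):
        self.clients = clients
        self.global_model = global_model

    def mix(self):
        for client in self.clients:
            client.step()
        # FedAvg-style (unweighted) parameter averaging.
        state = self.global_model.state_dict()
        for key in state:
            state[key] = torch.stack(
                [c.model.state_dict()[key].float() for c in self.clients]
            ).mean(dim=0)
        self.global_model.load_state_dict(state)
        for client in self.clients:
            self.update_client(client)

    def update_client(self, client):
        # Broadcast the new global model back to the client.
        client.model.load_state_dict(copy.deepcopy(self.global_model.state_dict()))
```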
In addition to the trivial baseline consisting of training models locally without any collaboration, this repository supports the following federated learning algorithms:
- FedAvg (McMahan et al. 2017)
- FedProx (Li et al. 2018)
- q-FFL (Li et al. 2020)
- AFL (Mohri et al. 2019)
- Clustered FL (Sattler et al. 2019)
- L2SGD (Hanzely et al. 2020)
- Ditto (Li et al. 2021)
- FedRep (Collins et al. 2021)
- APFL (Deng et al. 2020)
- PerFedAvg (Fallah et al. 2020)
Different algorithms can be obtained through different combinations of `Aggregator` (see `aggregator.py`), `Client` (see `client.py`), and `Optimizer` (see `utils/optim.py`).
The following table summarizes the combinations needed for each algorithm (an illustrative rendering in code follows the table):

| Algorithm | Aggregator | Client | Optimizer |
|---|---|---|---|
| Local | `NoCommunicationAggregator` | `Client` | `sgd` |
| FedAvg | `CentralizedAggregator` | `Client` | `sgd` |
| FedProx | `CentralizedAggregator` | `Client` | `prox-sgd` |
| AFL | `AgnosticAggregator` | `AgnosticFLClient` | `sgd` |
| q-FFL | `FFLAggregator` | `FFLClient` | `sgd` |
| FedAvg+ | `CentralizedAggregator` | `Client` | `sgd` |
| Clustered FL | `ClusteredAggregator` | `Client` | `sgd` |
| L2SGD | `LoopLessLocalSGDAggregator` | `Client` | `sgd` |
| Ditto | `PersonalizedAggregator` | `Client` | `prox-sgd` |
| FedRep | `CentralizedAggregator` | `FedRepClient` | `sgd` |
| APFL | `APFLAggregator` | `Client` | `sgd` |
| PerFedAvg | `APFLAggregator` | `PerFedAvgClient` | `sgd` |
| kNN-Per | `CentralizedAggregator` | `KNNPerClient` | `sgd` |
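The same mapping, written out programmatically (only the class and optimizer names come from the table above; the `get_combination` helper itself is hypothetical and does not exist in the repository):

```python
# Illustrative, programmatic form of the table above.
COMBINATIONS = {
    "local":        ("NoCommunicationAggregator", "Client", "sgd"),
    "fedavg":       ("CentralizedAggregator", "Client", "sgd"),
    "fedprox":      ("CentralizedAggregator", "Client", "prox-sgd"),
    "afl":          ("AgnosticAggregator", "AgnosticFLClient", "sgd"),
    "q-ffl":        ("FFLAggregator", "FFLClient", "sgd"),
    "fedavg+":      ("CentralizedAggregator", "Client", "sgd"),
    "clustered-fl": ("ClusteredAggregator", "Client", "sgd"),
    "l2sgd":        ("LoopLessLocalSGDAggregator", "Client", "sgd"),
    "ditto":        ("PersonalizedAggregator", "Client", "prox-sgd"),
    "fedrep":       ("CentralizedAggregator", "FedRepClient", "sgd"),
    "apfl":         ("APFLAggregator", "Client", "sgd"),
    "perfedavg":    ("APFLAggregator", "PerFedAvgClient", "sgd"),
    "knn-per":      ("CentralizedAggregator", "KNNPerClient", "sgd"),
}

def get_combination(algorithm: str) -> tuple:
    """Look up the (aggregator, client, optimizer) names for an algorithm."""
    return COMBINATIONS[algorithm.lower()]
```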
A detailed example of simulating federated training with FedAvg is provided in `examples/fed-avg.md`.
This repository implements kNN-Per, described in *Personalized Federated Learning through Local Memorization*. The object `KNNPerClient` represents a client with a local memory, implemented as a `Datastore` object. An example of kNN-Per on the CIFAR-10 dataset is provided in `examples/cifar-10.md`.
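To give a feel for the mechanism, the sketch below interpolates the global model's class probabilities with a kNN estimate computed over embeddings stored in a FAISS index. It is not the repository's `Datastore`/`KNNPerClient` code: the function names, the exponential kernel over distances, and the mixing weight `lam` are illustrative assumptions based on the paper's description.

```python
# Sketch of kNN-Per-style prediction: lam * p_kNN + (1 - lam) * p_global,
# where p_kNN is estimated from the client's local datastore of
# (embedding, label) pairs. Shapes and the distance kernel are assumptions.
import faiss
import numpy as np

def build_datastore(embeddings: np.ndarray) -> faiss.IndexFlatL2:
    """Index the client's training embeddings (float32 array, shape [n, d])."""
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index

def knn_per_predict(index, labels, n_classes, query, p_global,
                    k=10, lam=0.5, scale=1.0):
    """Return the interpolated class distribution for one query embedding."""
    dists, ids = index.search(query[None, :], k)   # squared L2 distances, [1, k]
    weights = np.exp(-dists[0] / scale)            # kernel weights for the k neighbors
    p_knn = np.zeros(n_classes)
    for w, i in zip(weights, ids[0]):
        p_knn[labels[i]] += w                      # accumulate mass per class label
    p_knn /= p_knn.sum()
    return lam * p_knn + (1.0 - lam) * p_global
```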
We provide four federated benchmark datasets spanning a wide range of machine learning tasks: image classification (CIFAR10 and CIFAR100), handwritten character recognition (FEMNIST), and language modelling (Shakespeare), in addition to a synthetic dataset.
The Shakespeare dataset (resp. FEMNIST) was naturally partitioned by assigning all lines from the same character (resp. all images from the same writer) to the same client. We created federated versions of CIFAR10 by distributing samples with the same label across the clients according to a symmetric Dirichlet distribution with parameter 0.3. For CIFAR100, we exploited the availability of "coarse" and "fine" labels, using a two-stage Pachinko allocation method to assign 600 samples to each of the 100 clients.
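For intuition, the label-based Dirichlet split can be sketched as follows (a simplified illustration with `alpha = 0.3` as above; the actual generation scripts under `data/` also handle train/test splits and the CIFAR100 Pachinko variant):

```python
# Sketch of symmetric Dirichlet label partitioning across clients.
import numpy as np

def dirichlet_partition(labels: np.ndarray, n_clients: int,
                        alpha: float = 0.3, seed: int = 1234):
    """Assign sample indices to clients, label by label."""
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(n_clients)]
    for label in np.unique(labels):
        idx = rng.permutation(np.where(labels == label)[0])
        # Fraction of this label's samples assigned to each client.
        proportions = rng.dirichlet(alpha * np.ones(n_clients))
        splits = (np.cumsum(proportions) * len(idx)).astype(int)[:-1]
        for client, chunk in enumerate(np.split(idx, splits)):
            client_indices[client].extend(chunk.tolist())
    return client_indices
```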
The following table summarizes the datasets and models:
Dataset | Task | Model |
---|---|---|
FEMNIST | Handwritten character recognition | MobileNet-v2 |
CIFAR10 | Image classification | MobileNet-v2 |
CIFAR100 | Image classification | MobileNet-v2 |
Shakespeare | Next character prediction | Stacked LSTM |
See the `README.md` files of the respective datasets, i.e., `data/$DATASET`, for instructions on generating data.
To train the base models used by kNN-Per, run this command:

```
python train.py \
    <dataset_name> \
    --aggregator_type centralized \
    --n_rounds 100 \
    --bz 128 \
    --lr 0.05 \
    --lr_scheduler multi_step \
    --log_freq 10 \
    --device cuda \
    --optimizer sgd \
    --seed 1234 \
    --logs_dir ./logs \
    --chkpts_dir ./chkpts/cifar10_fedavg \
    --verbose 1
```
To evaluate the score (accuracy) of kNN-Per, run this command:

```
python eval_knnper.py \
    <dataset_name> \
    random \
    <chkpts_path> \
    <n_neighbors> \
    --capacities_grid_resolution 0.01 \
    --weights_grid_resolution 0.01 \
    --bz 256 \
    --device cuda \
    --verbose 1 \
    --results_dir <results_dir> \
    --seed 12345
```
This script will create an array (saved as an `.npy` file) of shape (101, 101); each entry corresponds to the score (accuracy) of kNN-Per for one pair of datastore capacity and mixing weight values on the grids defined by `--capacities_grid_resolution` and `--weights_grid_resolution`.
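For example, the best-performing (capacity, weight) pair can be read off the saved array as follows (the file path below is hypothetical, and the axis order is an assumption to be checked against the script):

```python
# Inspect the (101, 101) accuracy grid written by eval_knnper.py.
import numpy as np

scores = np.load("results/cifar10/all_scores.npy")  # hypothetical path
cap_idx, w_idx = np.unravel_index(scores.argmax(), scores.shape)
print(f"best accuracy {scores[cap_idx, w_idx]:.3f} "
      f"at capacity={cap_idx * 0.01:.2f}, weight={w_idx * 0.01:.2f}")
```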
You can download pretrained models here:
- Models trained using FedAvg on CIFAR-10 for different levels of heterogeneity can be found here.
- Models trained using FedAvg on CIFAR-100 for different levels of heterogeneity can be found here.
- The model trained using FedAvg on Shakespeare can be found here.
To download all the pretrained models, run:

```
mkdir chkpts
cd chkpts
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1g1qQsGWFPBb5yDWOXro9i8Gwd_XRohzj' -O 'cifar10-fedavg-alpha-1.0.pth'
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1HAaVBzolYPGmmvSJr5jgkugR5pjgC3L4' -O 'cifar10-fedavg-alpha-0.7.pth'
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1SPhDq9-SVmAHS_XQQIa4QFk2fnyjYA7M' -O 'cifar10-fedavg-alpha-0.5.pth'
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=14SsQ5uXNa7kvjR01g1eufOAR9t5sMfMy' -O 'cifar10-fedavg-alpha-0.3.pth'
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1yxSMq7m6e2-Pm8bKRJWW1AaF96LkDq_9' -O 'cifar10-fedavg-alpha-0.1.pth'
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1lo3t4suq7gF2jFQWq8hCbC5lHk2zODyY' -O 'cifar100-fedavg-alpha-1.0.pth'
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1ms4mvRPuKVFIfl3D7Ka0apnekaW8wxxa' -O 'cifar100-fedavg-alpha-0.7.pth'
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1mzIfbYElguv0hIbYPP-0-LzEudsp66bI' -O 'cifar100-fedavg-alpha-0.5.pth'
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1disK-4yUPNpVVul8YR4bhpSRxwy9GClf' -O 'cifar100-fedavg-alpha-0.3.pth'
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1JVyhhnBacFvhJ6q93lBn3P12A9IQT85m' -O 'cifar100-fedavg-alpha-0.1.pth'
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1oKBBHmMKT7aHKDXAhYynjOXDlNtKYjUu' -O 'shakespeare-fedavg.pth'
cd ../
```
The performance of each personalized model (which coincides with the global one in the case of FedAvg) is evaluated on the local test dataset (unseen during training). The table below shows the average weighted accuracy with weights proportional to local dataset sizes. kNN-Per consistently achieves the highest accuracy across all datasets.
Dataset | Local | FedAvg | FedAvg+ | Clustered FL | Ditto | FedRep | APFL | kNN-Per (Ours) |
---|---|---|---|---|---|---|---|---|
FEMNIST | 71.0 | 83.4 | 84.3 | 83.7 | 84.3 | 85.3 | 84.1 | 88.2 |
CIFAR10 | 57.6 | 72.8 | 75.2 | 73.3 | 80.0 | 77.7 | 78.9 | 83.0 |
CIFAR100 | 31.5 | 47.4 | 51.4 | 47.2 | 52.0 | 53.2 | 51.7 | 55.0 |
Shakespeare | 32.0 | 48.1 | 47.0 | 46.7 | 47.9 | 47.2 | 45.9 | 51.4 |
To plot the effect of the datastore capacity on the accuracy obtained by kNN-Per, run:

```
python make_plots.py capacity_effect <results_dir> --save_path <save_path>
```
To plot the effect of the mixing weight on the accuracy obtained by kNN-Per, run:

```
python make_plots.py weight_effect <results_dir> --save_path <save_path>
```
To plot the effect of data heterogeneity on the score obtained by kNN-Per, run this command:

```
cd scripts/<dataset_name>
chmod +x heterogeneity_effect.sh
./heterogeneity_effect.sh
```
If you use our code or wish to refer to our results, please use the following BibTeX entry:

```
@article{marfoq2021personalized,
  title={Personalized Federated Learning through Local Memorization},
  author={Marfoq, Othmane and Neglia, Giovanni and Kameni, Laetitia and Vidal, Richard},
  journal={arXiv preprint arXiv:2111.09360},
  year={2021}
}
```