Skip to content

Commit

Permalink
add a example
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiao-Chenguang committed Oct 22, 2024
1 parent 3b50a2d commit ba4f945
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
52 changes: 52 additions & 0 deletions examples/fedavg_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from functools import reduce

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from fedmind.algs.fedavg import FedAvg
from fedmind.config import get_config
from fedmind.data import ClientDataset


def test_fedavg():
# 0. Prepare necessary arguments
args = get_config("config.yaml")

# 1. Prepare Federated Learning DataSets
org_ds = MNIST("dataset", train=True, download=True, transform=ToTensor())
test_ds = MNIST("dataset", train=False, download=True, transform=ToTensor())

effective_size = len(org_ds) - len(org_ds) % args.NUM_CLIENT # type: ignore
idx_groups = torch.randperm(effective_size).reshape(args.NUM_CLIENT, -1) # type: ignore
fed_dss = [ClientDataset(org_ds, idx) for idx in idx_groups.tolist()]

fed_loader = [DataLoader(ds, batch_size=32, shuffle=True) for ds in fed_dss]
test_loader = DataLoader(test_ds, batch_size=32)

# 2. Prepare Model and Criterion
classes = 10
features = reduce(lambda x, y: x * y, org_ds[0][0].shape)
model = nn.Sequential(
nn.Flatten(),
nn.Linear(features, 32),
nn.ReLU(),
nn.Linear(32, classes),
)

criterion = nn.CrossEntropyLoss()

# 3. Run Federated Learning Simulation
FedAvg(
model=model,
fed_loader=fed_loader,
test_loader=test_loader,
criterion=criterion,
args=args,
).fit(args.NUM_CLIENT, args.ACTIVE_CLIENT, args.SERVER_EPOCHS) # type: ignore


if __name__ == "__main__":
test_fedavg()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "fedmind"
version = "0.1.3a1"
version = "0.1.1"
description = "Federated Learning research framework in your mind"
readme = "README.md"
requires-python = ">=3.12"
Expand Down

0 comments on commit ba4f945

Please sign in to comment.