Skip to content

Commit

Permalink
enable manuseed from config for mainprocess and dataloaders
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiao-Chenguang committed Oct 28, 2024
1 parent fb4127f commit ffcfd5a
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 4 deletions.
1 change: 1 addition & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ OPTIM:
LR: 0.1
MOMENTUM: 0.9
NAME: SGD
SEED: 0
SERVER_EPOCHS: 20
TEST_SUBPROCESS: false
WB_ENTITY: example_entity
Expand Down
12 changes: 11 additions & 1 deletion examples/cuda_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ def test_fedavg():
args.NUM_PROCESS = 10
args.SERVER_EPOCHS = 20

if args.SEED >= 0:
torch.manual_seed(args.SEED)

assert torch.cuda.is_available(), "CUDA is not available"

# 1. Prepare Federated Learning DataSets
Expand All @@ -27,7 +30,14 @@ def test_fedavg():
idx_groups = torch.randperm(effective_size).reshape(args.NUM_CLIENT, -1)
fed_dss = [ClientDataset(org_ds, idx) for idx in idx_groups.tolist()]

fed_loader = [DataLoader(ds, args.BATCH_SIZE, shuffle=True) for ds in fed_dss]
genetors = [
torch.Generator().manual_seed(args.SEED + i) if args.SEED >= 0 else None
for i in range(args.NUM_CLIENT)
]
fed_loader = [
DataLoader(ds, args.BATCH_SIZE, shuffle=True, generator=gtr)
for ds, gtr in zip(fed_dss, genetors)
]
test_loader = DataLoader(test_ds, args.BATCH_SIZE * 4)

# 2. Prepare Model and Criterion
Expand Down
11 changes: 10 additions & 1 deletion examples/fedavg_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
def test_fedavg():
# 0. Prepare necessary arguments
args = get_config("config.yaml")
if args.SEED >= 0:
torch.manual_seed(args.SEED)

# 1. Prepare Federated Learning DataSets
org_ds = MNIST("dataset", train=True, download=True, transform=ToTensor())
Expand All @@ -23,7 +25,14 @@ def test_fedavg():
idx_groups = torch.randperm(effective_size).reshape(args.NUM_CLIENT, -1)
fed_dss = [ClientDataset(org_ds, idx) for idx in idx_groups.tolist()]

fed_loader = [DataLoader(ds, args.BATCH_SIZE, shuffle=True) for ds in fed_dss]
genetors = [
torch.Generator().manual_seed(args.SEED + i) if args.SEED >= 0 else None
for i in range(args.NUM_CLIENT)
]
fed_loader = [
DataLoader(ds, args.BATCH_SIZE, shuffle=True, generator=gtr)
for ds, gtr in zip(fed_dss, genetors)
]
test_loader = DataLoader(test_ds, args.BATCH_SIZE * 4)

# 2. Prepare Model and Criterion
Expand Down
13 changes: 11 additions & 2 deletions examples/fedprox_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def test_fedavg():
# 0. Prepare necessary arguments
args = get_config("config.yaml")
args.PROX_MU = 0.1
if args.SEED >= 0:
torch.manual_seed(args.SEED)

# 1. Prepare Federated Learning DataSets
org_ds = MNIST("dataset", train=True, download=True, transform=ToTensor())
Expand All @@ -24,8 +26,15 @@ def test_fedavg():
idx_groups = torch.randperm(effective_size).reshape(args.NUM_CLIENT, -1)
fed_dss = [ClientDataset(org_ds, idx) for idx in idx_groups.tolist()]

fed_loader = [DataLoader(ds, batch_size=32, shuffle=True) for ds in fed_dss]
test_loader = DataLoader(test_ds, batch_size=32)
genetors = [
torch.Generator().manual_seed(args.SEED + i) if args.SEED >= 0 else None
for i in range(args.NUM_CLIENT)
]
fed_loader = [
DataLoader(ds, args.BATCH_SIZE, shuffle=True, generator=gtr)
for ds, gtr in zip(fed_dss, genetors)
]
test_loader = DataLoader(test_ds, args.BATCH_SIZE * 4)

# 2. Prepare Model and Criterion
classes = 10
Expand Down

0 comments on commit ffcfd5a

Please sign in to comment.