Skip to content

Commit

Permalink
remove nolonger needed type ignores
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiao-Chenguang committed Oct 23, 2024
1 parent c2a5997 commit fcf1647
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 23 deletions.
6 changes: 3 additions & 3 deletions examples/fedavg_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def test_fedavg():
org_ds = MNIST("dataset", train=True, download=True, transform=ToTensor())
test_ds = MNIST("dataset", train=False, download=True, transform=ToTensor())

effective_size = len(org_ds) - len(org_ds) % args.NUM_CLIENT # type: ignore
idx_groups = torch.randperm(effective_size).reshape(args.NUM_CLIENT, -1) # type: ignore
effective_size = len(org_ds) - len(org_ds) % args.NUM_CLIENT
idx_groups = torch.randperm(effective_size).reshape(args.NUM_CLIENT, -1)
fed_dss = [ClientDataset(org_ds, idx) for idx in idx_groups.tolist()]

fed_loader = [DataLoader(ds, batch_size=32, shuffle=True) for ds in fed_dss]
Expand All @@ -45,7 +45,7 @@ def test_fedavg():
test_loader=test_loader,
criterion=criterion,
args=args,
).fit(args.NUM_CLIENT, args.ACTIVE_CLIENT, args.SERVER_EPOCHS) # type: ignore
).fit(args.NUM_CLIENT, args.ACTIVE_CLIENT, args.SERVER_EPOCHS)


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions examples/fedprox_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def test_fedavg():
org_ds = MNIST("dataset", train=True, download=True, transform=ToTensor())
test_ds = MNIST("dataset", train=False, download=True, transform=ToTensor())

effective_size = len(org_ds) - len(org_ds) % args.NUM_CLIENT # type: ignore
idx_groups = torch.randperm(effective_size).reshape(args.NUM_CLIENT, -1) # type: ignore
effective_size = len(org_ds) - len(org_ds) % args.NUM_CLIENT
idx_groups = torch.randperm(effective_size).reshape(args.NUM_CLIENT, -1)
fed_dss = [ClientDataset(org_ds, idx) for idx in idx_groups.tolist()]

fed_loader = [DataLoader(ds, batch_size=32, shuffle=True) for ds in fed_dss]
Expand All @@ -46,7 +46,7 @@ def test_fedavg():
test_loader=test_loader,
criterion=criterion,
args=args,
).fit(args.NUM_CLIENT, args.ACTIVE_CLIENT, args.SERVER_EPOCHS) # type: ignore
).fit(args.NUM_CLIENT, args.ACTIVE_CLIENT, args.SERVER_EPOCHS)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion fedmind/algs/fedprox.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _train_client(
Returns:
A dictionary containing the trained model parameters.
"""
mu = args.PROX_MU # type: ignore
mu = args.PROX_MU

# Train the model
model.load_state_dict(gm_params)
Expand Down
22 changes: 11 additions & 11 deletions fedmind/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
self.args = args

self.gm_params = self.model.state_dict(destination=StateDict())
optim: dict = self.args.OPTIM # type: ignore
optim: dict = self.args.OPTIM
if optim["NAME"] == "SGD":
self.optimizer = SGD(self.model.parameters(), lr=optim["LR"])
else:
Expand All @@ -57,13 +57,13 @@ def __init__(
)

logging.basicConfig(
level=args.LOG_LEVEL, # type: ignore
level=args.LOG_LEVEL,
format="%(asctime)s %(levelname)s [%(processName)s] %(message)s",
)
self.logger = logging.getLogger("Server")
self.logger.info(f"Get following configs:\n{yaml.dump(args.to_dict())}")

if self.args.NUM_PROCESS > 0: # type: ignore
if self.args.NUM_PROCESS > 0:
self.__init_mp__()

def __init_mp__(self):
Expand All @@ -78,17 +78,17 @@ def __init_mp__(self):

# Start client processes
self.processes = []
for worker_id in range(self.args.NUM_PROCESS): # type: ignore
for worker_id in range(self.args.NUM_PROCESS):
args = (
worker_id,
self.task_queue,
self.result_queue,
self._train_client,
self.model,
self.args.OPTIM, # type: ignore
self.args.OPTIM,
self.criterion,
self.args.CLIENT_EPOCHS, # type: ignore
self.args.LOG_LEVEL, # type: ignore
self.args.CLIENT_EPOCHS,
self.args.LOG_LEVEL,
self.args,
)
p = mp.Process(target=self._create_worker_process, args=args)
Expand All @@ -99,7 +99,7 @@ def __del_mp__(self):
"""Terminate multi-process environment."""

# Terminate all client processes
for _ in range(self.args.NUM_PROCESS): # type: ignore
for _ in range(self.args.NUM_PROCESS):
self.task_queue.put("STOP")

# Wait for all client processes to finish
Expand Down Expand Up @@ -176,7 +176,7 @@ def fit(self, pool: int, num_clients: int, num_rounds: int):

# 2. Synchornous clients training
updates = []
if self.args.NUM_PROCESS == 0: # type: ignore
if self.args.NUM_PROCESS == 0:
# Serial simulation instead of parallel
for cid in clients:
updates.append(
Expand All @@ -186,7 +186,7 @@ def fit(self, pool: int, num_clients: int, num_rounds: int):
self.fed_loader[cid],
self.optimizer,
self.criterion,
self.args.CLIENT_EPOCHS, # type: ignore
self.args.CLIENT_EPOCHS,
self.logger,
self.args,
)
Expand All @@ -208,7 +208,7 @@ def fit(self, pool: int, num_clients: int, num_rounds: int):
self.wb_run.log(train_metrics | test_metrics)

# Terminate multi-process environment
if self.args.NUM_PROCESS > 0: # type: ignore
if self.args.NUM_PROCESS > 0:
self.__del_mp__()

# Finish wandb run and sync
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "fedmind"
version = "0.1.4"
version = "0.1.5"
description = "Federated Learning research framework in your mind"
readme = "README.md"
requires-python = ">=3.12"
Expand Down
6 changes: 3 additions & 3 deletions test/test_fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def test_fedavg():
org_ds = MNIST("dataset", train=True, download=True, transform=ToTensor())
test_ds = MNIST("dataset", train=False, download=True, transform=ToTensor())

effective_size = len(org_ds) - len(org_ds) % args.NUM_CLIENT # type: ignore
idx_groups = torch.randperm(effective_size).reshape(args.NUM_CLIENT, -1) # type: ignore
effective_size = len(org_ds) - len(org_ds) % args.NUM_CLIENT
idx_groups = torch.randperm(effective_size).reshape(args.NUM_CLIENT, -1)
fed_dss = [ClientDataset(org_ds, idx) for idx in idx_groups.tolist()]

fed_loader = [DataLoader(ds, batch_size=32, shuffle=True) for ds in fed_dss]
Expand All @@ -45,7 +45,7 @@ def test_fedavg():
test_loader=test_loader,
criterion=criterion,
args=args,
).fit(args.NUM_CLIENT, args.ACTIVE_CLIENT, args.SERVER_EPOCHS) # type: ignore
).fit(args.NUM_CLIENT, args.ACTIVE_CLIENT, args.SERVER_EPOCHS)

assert True

Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit fcf1647

Please sign in to comment.