Skip to content

Commit

Permalink
Merge pull request #72 from Xiao-Chenguang/subprocess-exception-handling
Browse files Browse the repository at this point in the history
Subprocess exception handling for train and test
  • Loading branch information
Xiao-Chenguang authored Dec 8, 2024
2 parents 5150f80 + a8a8334 commit 3cf8c29
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions fedmind/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,13 +260,30 @@ def fit(self, pool: int, num_clients: int, num_rounds: int):
for cid in clients:
updates.append(self._result_queue.get())

# train subprocess exception handling
if self.config.NUM_PROCESS > 0:
sp_error = False
for u in updates:
if "error" in u:
sp_error = True
self.logger.error(f"Subprocess Error: {u['error']}")
break
if sp_error:
del updates
break

# 3. Aggregate updates to new model
train_metrics = self._aggregate_updates(updates)
del updates # Fix shared cuda tensors issue (release tensor generated by child process)

# 4. Evaluate the new model
test_metrics = self._evaluate()

# test subprocess exception handling
if self.config.TEST_SUBPROCESS and "error" in test_metrics:
self.logger.error(f"Subprocess Error: {test_metrics['error']}")
break

# 5. Log metrics
self._wb_run.log(train_metrics | test_metrics)

Expand Down Expand Up @@ -392,10 +409,16 @@ def _create_worker_process(
if task == "STOP":
break
elif task == "TRAIN":
result = train_func(**(fix_args | dyn_args))
result_queue.put(result)
try:
result = train_func(**(fix_args | dyn_args))
result_queue.put(result)
except Exception as e:
result_queue.put({"error": e})
elif task == "TEST":
result = test_func(**(fix_args | dyn_args))
test_queue.put(result)
try:
result = test_func(**(fix_args | dyn_args))
test_queue.put(result)
except Exception as e:
test_queue.put({"error": e})
else:
raise ValueError(f"Unknown task type {task}")

0 comments on commit 3cf8c29

Please sign in to comment.