diff --git a/fedmind/server.py b/fedmind/server.py index 46b17c9..4123859 100644 --- a/fedmind/server.py +++ b/fedmind/server.py @@ -260,6 +260,18 @@ def fit(self, pool: int, num_clients: int, num_rounds: int): for cid in clients: updates.append(self._result_queue.get()) + # train subprocess exception handling + if self.config.NUM_PROCESS > 0: + sp_error = False + for u in updates: + if "error" in u: + sp_error = True + self.logger.error(f"Subprocess Error: {u['error']}") + break + if sp_error: + del updates + break + # 3. Aggregate updates to new model train_metrics = self._aggregate_updates(updates) del updates # Fix shared cuda tensors issue (release tensor generated by child process) @@ -267,6 +279,11 @@ def fit(self, pool: int, num_clients: int, num_rounds: int): # 4. Evaluate the new model test_metrics = self._evaluate() + # test subprocess exception handling + if self.config.TEST_SUBPROCESS and "error" in test_metrics: + self.logger.error(f"Subprocess Error: {test_metrics['error']}") + break + # 5. Log metrics self._wb_run.log(train_metrics | test_metrics) @@ -392,10 +409,16 @@ def _create_worker_process( if task == "STOP": break elif task == "TRAIN": - result = train_func(**(fix_args | dyn_args)) - result_queue.put(result) + try: + result = train_func(**(fix_args | dyn_args)) + result_queue.put(result) + except Exception as e: + result_queue.put({"error": e}) elif task == "TEST": - result = test_func(**(fix_args | dyn_args)) - test_queue.put(result) + try: + result = test_func(**(fix_args | dyn_args)) + test_queue.put(result) + except Exception as e: + test_queue.put({"error": e}) else: raise ValueError(f"Unknown task type {task}")