diff --git a/openfl-workspace/torch_llm/src/pt_model.py b/openfl-workspace/torch_llm/src/pt_model.py index 38df8e4613..19eb16c198 100644 --- a/openfl-workspace/torch_llm/src/pt_model.py +++ b/openfl-workspace/torch_llm/src/pt_model.py @@ -169,7 +169,7 @@ def save_modelstate(self, col_name, round_num, func_name, kwargs): ) return state_path, out_path, data_path - def launch_horovod(self, data_path, state_path, out_path, horovod_kwags): + def launch_horovod(self, data_path, state_path, out_path, function_name, horovod_kwags): result = subprocess.run( [ "horovodrun", @@ -190,7 +190,7 @@ def launch_horovod(self, data_path, state_path, out_path, horovod_kwags): "--kwargs", json.dumps(horovod_kwags), "--func", - "validate", + function_name, "--out_path", out_path, ], @@ -225,7 +225,7 @@ def validate( "input_tensor_dict": None, "use_tqdm": use_tqdm, } - result = self.launch_horovod(data_path, state_path, out_path, horovod_kwags) + result = self.launch_horovod(data_path, state_path, out_path, 'validate', horovod_kwags, ) if result.returncode != 0: raise RuntimeError(result.stderr) @@ -276,7 +276,7 @@ def train_batches( "input_tensor_dict": None, "use_tqdm": use_tqdm, } - result = self.launch_horovod(data_path, state_path, out_path, horovod_kwags) + result = self.launch_horovod(data_path, state_path, out_path, 'train_batches', horovod_kwags) if result.returncode != 0: raise RuntimeError(result.stderr)