diff --git a/modelrunner/model/base.py b/modelrunner/model/base.py index 97a3614..3adf7e1 100644 --- a/modelrunner/model/base.py +++ b/modelrunner/model/base.py @@ -16,6 +16,15 @@ from ..storage import ModeType, StorageGroup, open_storage from .parameters import DeprecatedParameter, HideParameter, Parameterized +# read state of the current MPI node +try: + from mpi4py import MPI +except ImportError: + mpi_rank: int = 0 # no MPI -> current process is main process +else: + mpi_rank = MPI.COMM_WORLD.rank + + if TYPE_CHECKING: from ..run.results import Result # @UnusedImport @@ -223,13 +232,13 @@ def run_from_command_line( # run the model result = mdl.get_result() - if mdl.output: - # write the results to a file + + # write the results + if mdl.output and mpi_rank == 0: + # Write the results to a file if `output` is specified and if we are on the + # root node of an MPI run (or a serial program). The second check is a + # safe-guard against writing data on sub-nodes during an MPI program. mdl.write_result(result=result) - # else: - # # display the results on stdout - # storage = MemoryStorage() - # result.to_file(storage) # close the output file mdl.close()