From 493cf0c2ef36716533698fabf85e3f4140024b58 Mon Sep 17 00:00:00 2001
From: Ben Kirk <benkirk@ucar.edu>
Date: Thu, 12 Sep 2024 11:20:09 -0600
Subject: [PATCH 1/3] support MPI launching through MPICH & variants

---
 modulus/distributed/manager.py | 33 +++++++++++++++++++++++++++++++--
 1 file changed, 31 insertions(+), 2 deletions(-)

diff --git a/modulus/distributed/manager.py b/modulus/distributed/manager.py
index 61bc2687e9..93123ce468 100644
--- a/modulus/distributed/manager.py
+++ b/modulus/distributed/manager.py
@@ -287,6 +287,28 @@ def initialize_open_mpi(addr, port):
             method="openmpi",
         )
 
+    @staticmethod
+    def initialize_mpich(addr, port):
+        """Setup method using MPICH initialization"""
+        rank = int(os.environ.get("PMI_RANK"))
+        world_size = int(os.environ.get("PMI_SIZE"))
+        try:
+            # cray-mpich
+            local_rank = int(os.environ.get("PMI_LOCAL_RANK"))
+        except:
+            # mpich-4.2.1 / hydra
+            local_rank = int(os.environ.get("MPI_LOCALRANKID"))
+
+        DistributedManager.setup(
+            rank=rank,
+            world_size=world_size,
+            local_rank=local_rank,
+            addr=addr,
+            port=port,
+            backend=DistributedManager.get_available_backend(),
+            method="mpich",
+        )
+
     @staticmethod
     def initialize_slurm(port):
         """Setup method using SLURM initialization"""
@@ -319,6 +341,9 @@ def initialize():
         `OPENMPI`: Initialization for OpenMPI launchers.
             Uses `OMPI_COMM_WORLD_RANK`, `OMPI_COMM_WORLD_SIZE` and
             `OMPI_COMM_WORLD_LOCAL_RANK` environment variables.
+        `MPICH`: Initialization for MPICH-based MPI launchers.
+            Uses `PMI_RANK`, `PMI_SIZE` and
+            either `PMI_LOCAL_RANK` or `MPI_LOCALRANKID` environment variables.
 
         Initialization by default is done using the first valid method in the order
         listed above. Initialization method can also be explicitly controlled using the
@@ -342,9 +367,11 @@ def initialize():
                 DistributedManager.initialize_slurm(port)
             elif "OMPI_COMM_WORLD_RANK" in os.environ:
                 DistributedManager.initialize_open_mpi(addr, port)
+            elif "PMI_RANK" in os.environ:
+                DistributedManager.initialize_mpich(addr, port)
             else:
                 warn(
-                    "Could not initialize using ENV, SLURM or OPENMPI methods. Assuming this is a single process job"
+                    "Could not initialize using ENV, SLURM, OPENMPI or MPICH methods. Assuming this is a single process job"
                 )
                 DistributedManager._shared_state["_is_initialized"] = True
         elif initialization_method == "ENV":
@@ -353,13 +380,15 @@ def initialize():
             DistributedManager.initialize_slurm(port)
         elif initialization_method == "OPENMPI":
             DistributedManager.initialize_open_mpi(addr, port)
+        elif initialization_method == "MPICH":
+            DistributedManager.initialize_mpich(addr, port)
         else:
             raise RuntimeError(
                 "Unknown initialization method "
                 f"{initialization_method}. "
                 "Supported values for "
                 "MODULUS_DISTRIBUTED_INITIALIZATION_METHOD are "
-                "ENV, SLURM and OPENMPI"
+                "ENV, SLURM, OPENMPI, and MPICH"
             )
 
         # Set per rank numpy random seed for data sampling

From c64e96c57b9fbfca4ee4ba628c6410672cf2e901 Mon Sep 17 00:00:00 2001
From: Ben Kirk <benkirk@ucar.edu>
Date: Thu, 12 Sep 2024 11:34:53 -0600
Subject: [PATCH 2/3] Support for MPICH-based MPI launching.

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 39af7586ff..d3f5bded48 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   spectrum.
 - Added gradient clipping to StaticCapture utilities.
 - Bistride Multiscale MeshGraphNet example.
+- Support for MPICH-based MPI launching.
 
 ### Changed
 
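For context between patches: the detection added in PATCH 1/3 keys off `PMI_RANK`, which MPICH-derived launchers (e.g. Hydra's `mpiexec`) export to every process. Below is a minimal usage sketch, not part of the patch, assuming the public `DistributedManager` accessors (`rank`, `world_size`, `local_rank`) that the existing SLURM and OpenMPI paths already populate:

    # Launched as, e.g.:  mpiexec -n 8 python train.py
    # Hydra exports PMI_RANK/PMI_SIZE, so initialize() falls through the
    # ENV and SLURM checks and selects the new "mpich" method; it can also
    # be forced via MODULUS_DISTRIBUTED_INITIALIZATION_METHOD=MPICH.
    from modulus.distributed import DistributedManager

    DistributedManager.initialize()  # auto-detects "mpich" via PMI_RANK
    dist = DistributedManager()
    print(f"rank {dist.rank}/{dist.world_size}, local rank {dist.local_rank}")
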
From 1cd7e67efe28c04f3b9ab03f4e6fc20a1413bfcc Mon Sep 17 00:00:00 2001
From: Ben Kirk <benkirk@ucar.edu>
Date: Thu, 12 Sep 2024 13:27:15 -0600
Subject: [PATCH 3/3] infer MASTER_ADDR from rank-0 when mpi4py is available

---
 modulus/distributed/manager.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/modulus/distributed/manager.py b/modulus/distributed/manager.py
index 93123ce468..3af4380922 100644
--- a/modulus/distributed/manager.py
+++ b/modulus/distributed/manager.py
@@ -292,13 +292,24 @@ def initialize_mpich(addr, port):
         """Setup method using MPICH initialization"""
         rank = int(os.environ.get("PMI_RANK"))
         world_size = int(os.environ.get("PMI_SIZE"))
-        try:
-            # cray-mpich
+
+        # cray-mpich
+        if "PMI_LOCAL_RANK" in os.environ:
             local_rank = int(os.environ.get("PMI_LOCAL_RANK"))
-        except:
-            # mpich-4.2.1 / hydra
+        # mpich-4.2.1 / hydra
+        else:
             local_rank = int(os.environ.get("MPI_LOCALRANKID"))
 
+        # for multi-node MPI jobs, determine "addr" as the
+        # address of global rank 0.
+        if "localhost" == addr:
+            try:
+                import socket
+                from mpi4py import MPI
+                comm = MPI.COMM_WORLD
+                addr = comm.bcast(socket.gethostbyname(socket.gethostname()), root=0)
+            except ImportError: pass
+
         DistributedManager.setup(
             rank=rank,
             world_size=world_size,
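For context: when `MASTER_ADDR` is unset, `addr` still carries the `"localhost"` default, which breaks rendezvous for multi-node jobs, so PATCH 3/3 broadcasts rank 0's address to all ranks. A standalone sketch of that bcast technique follows; `bcast_addr.py` is a hypothetical test script, and mpi4py must be installed:

    # bcast_addr.py -- run as, e.g.:  mpiexec -n 4 python bcast_addr.py
    # Every rank resolves its own IP, but comm.bcast(..., root=0) returns
    # rank 0's value on all ranks, so the whole job agrees on one
    # MASTER_ADDR for the torch.distributed rendezvous.
    import socket

    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    master_addr = comm.bcast(socket.gethostbyname(socket.gethostname()), root=0)
    print(f"rank {comm.Get_rank()}: MASTER_ADDR={master_addr}")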