From 493cf0c2ef36716533698fabf85e3f4140024b58 Mon Sep 17 00:00:00 2001
From: Ben Kirk <benkirk@ucar.edu>
Date: Thu, 12 Sep 2024 11:20:09 -0600
Subject: [PATCH 1/3] support MPI launching through MPICH & variants

---
 modulus/distributed/manager.py | 33 +++++++++++++++++++++++++++++++--
 1 file changed, 31 insertions(+), 2 deletions(-)

diff --git a/modulus/distributed/manager.py b/modulus/distributed/manager.py
index 61bc2687e9..93123ce468 100644
--- a/modulus/distributed/manager.py
+++ b/modulus/distributed/manager.py
@@ -287,6 +287,28 @@ def initialize_open_mpi(addr, port):
             method="openmpi",
         )
 
+    @staticmethod
+    def initialize_mpich(addr, port):
+        """Setup method using MPICH initialization"""
+        rank = int(os.environ.get("PMI_RANK"))
+        world_size = int(os.environ.get("PMI_SIZE"))
+        try:
+            # cray-mpich
+            local_rank = int(os.environ.get("PMI_LOCAL_RANK"))
+        except:
+            # mpich-4.2.1 / hydra
+            local_rank = int(os.environ.get("MPI_LOCALRANKID"))
+
+        DistributedManager.setup(
+            rank=rank,
+            world_size=world_size,
+            local_rank=local_rank,
+            addr=addr,
+            port=port,
+            backend=DistributedManager.get_available_backend(),
+            method="mpich",
+        )
+
     @staticmethod
     def initialize_slurm(port):
         """Setup method using SLURM initialization"""
@@ -319,6 +341,9 @@ def initialize():
             `OPENMPI`: Initialization for OpenMPI launchers.
                      Uses `OMPI_COMM_WORLD_RANK`, `OMPI_COMM_WORLD_SIZE` and
                      `OMPI_COMM_WORLD_LOCAL_RANK` environment variables.
+            `MPICH`: Initialization for MPICH-based MPI launchers.
+                     Uses `PMI_RANK`, `PMI_SIZE` and
+                     either `PMI_LOCAL_RANK` or `MPI_LOCALRANKID` environment variables.
 
         Initialization by default is done using the first valid method in the order
         listed above. Initialization method can also be explicitly controlled using the
@@ -342,9 +367,11 @@ def initialize():
                     DistributedManager.initialize_slurm(port)
                 elif "OMPI_COMM_WORLD_RANK" in os.environ:
                     DistributedManager.initialize_open_mpi(addr, port)
+                elif "PMI_RANK" in os.environ:
+                    DistributedManager.initialize_mpich(addr, port)
                 else:
                     warn(
-                        "Could not initialize using ENV, SLURM or OPENMPI methods. Assuming this is a single process job"
+                        "Could not initialize using ENV, SLURM, OPENMPI or MPICH methods. Assuming this is a single process job"
                     )
                     DistributedManager._shared_state["_is_initialized"] = True
         elif initialization_method == "ENV":
@@ -353,13 +380,15 @@ def initialize():
             DistributedManager.initialize_slurm(port)
         elif initialization_method == "OPENMPI":
             DistributedManager.initialize_open_mpi(addr, port)
+        elif initialization_method == "MPICH":
+            DistributedManager.initialize_mpich(addr, port)
         else:
             raise RuntimeError(
                 "Unknown initialization method "
                 f"{initialization_method}. "
                 "Supported values for "
                 "MODULUS_DISTRIBUTED_INITIALIZATION_METHOD are "
-                "ENV, SLURM and OPENMPI"
+                "ENV, SLURM, OPENMPI, and MPICH"
             )
 
         # Set per rank numpy random seed for data sampling

From c64e96c57b9fbfca4ee4ba628c6410672cf2e901 Mon Sep 17 00:00:00 2001
From: Ben Kirk <benkirk@ucar.edu>
Date: Thu, 12 Sep 2024 11:34:53 -0600
Subject: [PATCH 2/3] Support for MPICH-based MPI launching.

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 39af7586ff..d3f5bded48 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   spectrum.
 - Added gradient clipping to StaticCapture utilities.
 - Bistride Multiscale MeshGraphNet example.
+- Support for MPICH-based MPI launching.
 
 ### Changed
 

From 1cd7e67efe28c04f3b9ab03f4e6fc20a1413bfcc Mon Sep 17 00:00:00 2001
From: Ben Kirk <benkirk@ucar.edu>
Date: Thu, 12 Sep 2024 13:27:15 -0600
Subject: [PATCH 3/3] infer MASTER_ADDR from rank-0 when mpi4py is available

---
 modulus/distributed/manager.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/modulus/distributed/manager.py b/modulus/distributed/manager.py
index 93123ce468..3af4380922 100644
--- a/modulus/distributed/manager.py
+++ b/modulus/distributed/manager.py
@@ -292,13 +292,24 @@ def initialize_mpich(addr, port):
         """Setup method using MPICH initialization"""
         rank = int(os.environ.get("PMI_RANK"))
         world_size = int(os.environ.get("PMI_SIZE"))
-        try:
-            # cray-mpich
+
+        # cray-mpich
+        if "PMI_LOCAL_RANK" in os.environ:
             local_rank = int(os.environ.get("PMI_LOCAL_RANK"))
-        except:
-            # mpich-4.2.1 / hydra
+        # mpich-4.2.1 / hydra
+        else:
             local_rank = int(os.environ.get("MPI_LOCALRANKID"))
 
+        # For multi-node jobs, resolve "addr" to the address of global rank 0.
+        if addr == "localhost":
+            try:
+                import socket
+                from mpi4py import MPI
+                comm = MPI.COMM_WORLD
+                addr = comm.bcast(socket.gethostbyname(socket.gethostname()), root=0)
+            except ImportError:
+                pass
+
         DistributedManager.setup(
             rank=rank,
             world_size=world_size,