diff --git a/modulus/distributed/manager.py b/modulus/distributed/manager.py
index 61bc2687e9..a266fdd0c5 100644
--- a/modulus/distributed/manager.py
+++ b/modulus/distributed/manager.py
@@ -332,7 +332,11 @@ def initialize():
     addr = os.getenv("MASTER_ADDR", "localhost")
     port = os.getenv("MASTER_PORT", "12355")
     # https://pytorch.org/docs/master/notes/cuda.html#id5
-    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
+    # NCCL_ASYNC_ERROR_HANDLING was renamed to TORCH_NCCL_ASYNC_ERROR_HANDLING in PyTorch 2.2
+    if torch.__version__ < (2, 2):
+        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
+    else:
+        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
     initialization_method = os.getenv("MODULUS_DISTRIBUTED_INITIALIZATION_METHOD")
     if initialization_method is None:
         try: