From 88682736e95a34b041a9ed78177c0d3cc24c2cdb Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Thu, 18 Jul 2024 20:07:20 +0200 Subject: [PATCH] updating JAX_FFT cuFFTMP to work with JAX 0.4.30 --- .../JAX_FFT/src/cufftmp_jax/CMakeLists.txt | 9 ++++- .../JAX_FFT/src/cufftmp_jax/cufftmp_jax.py | 34 ++++++++++--------- cuFFTMp/JAX_FFT/src/fft_common/utils.py | 2 +- cuFFTMp/JAX_FFT/src/xfft/xfft.py | 2 +- cuFFTMp/JAX_FFT/tests/fft_test.py | 18 ++++++---- 5 files changed, 40 insertions(+), 25 deletions(-) diff --git a/cuFFTMp/JAX_FFT/src/cufftmp_jax/CMakeLists.txt b/cuFFTMp/JAX_FFT/src/cufftmp_jax/CMakeLists.txt index bf620c9a..66e032b3 100644 --- a/cuFFTMp/JAX_FFT/src/cufftmp_jax/CMakeLists.txt +++ b/cuFFTMp/JAX_FFT/src/cufftmp_jax/CMakeLists.txt @@ -6,8 +6,15 @@ find_package(pybind11 CONFIG REQUIRED) include_directories(${CMAKE_CURRENT_LIST_DIR}/lib) +set(NVSHMEM_HOME $ENV{NVHPC_ROOT}/comm_libs/12.2/nvshmem_cufftmp_compat) +set(CUFFTMP_HOME $ENV{NVHPC_ROOT}/math_libs) message(STATUS "Using ${NVSHMEM_HOME} for NVSHMEM_HOME and ${CUFFTMP_HOME} for CUFFTMP_HOME") -include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CUFFTMP_HOME}/include ${NVSHMEM_HOME}/include) + + +include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} + ${CUFFTMP_HOME}/include ${NVSHMEM_HOME}/include + $ENV{CUFFT_INC} +) link_directories(${CUFFTMP_HOME}/lib ${NVSHMEM_HOME}/lib) pybind11_add_module(gpu_ops diff --git a/cuFFTMp/JAX_FFT/src/cufftmp_jax/cufftmp_jax.py b/cuFFTMp/JAX_FFT/src/cufftmp_jax/cufftmp_jax.py index 70bb1331..fcc4a4b0 100644 --- a/cuFFTMp/JAX_FFT/src/cufftmp_jax/cufftmp_jax.py +++ b/cuFFTMp/JAX_FFT/src/cufftmp_jax/cufftmp_jax.py @@ -9,8 +9,8 @@ from jax.lib import xla_client from jax import core, dtypes from jax.interpreters import xla, mlir -from jax.abstract_arrays import ShapedArray -from jax._src.sharding import NamedSharding +from jax.core import ShapedArray +from jax.sharding import NamedSharding from jax.experimental.custom_partitioning import custom_partitioning from jaxlib.hlo_helpers import custom_call @@ -30,7 +30,7 @@ def _cufftmp_bind(input, num_parts, dist, dir): # param=val means it's a static parameter - (output,) = _cufftmp_prim.bind(input, + output = _cufftmp_prim.bind(input, num_parts=num_parts, dist=dist, dir=dir) @@ -110,7 +110,7 @@ def cufftmp(x, dist, dir): @custom_partitioning def _cufftmp_(x): - return _cufftmp_bind(x, num_parts=1, dist=dist, dir=dir) + return _cufftmp_bind(x, num_parts=jax.device_count(), dist=dist, dir=dir) _cufftmp_.def_partition( infer_sharding_from_operands=partial( @@ -180,18 +180,20 @@ def _cufftmp_translation(ctx, input, num_parts, dist, dir): else: raise ValueError("Unsupported tensor rank; must be 2 or 3") - return [custom_call( - "gpu_cufftmp", - # Output types - out_types=[output_type], - # The inputs: - operands=[input,], - # Layout specification: - operand_layouts=[layout,], - result_layouts=[layout,], - # GPU specific additional data - backend_config=opaque - )] + out = custom_call( + "gpu_cufftmp", + # Output types + result_types=[output_type], + # The inputs: + operands=[input,], + # Layout specification: + operand_layouts=[layout,], + result_layouts=[layout,], + # GPU specific additional data + backend_config=opaque + ) + + return out.results # ********************************************* diff --git a/cuFFTMp/JAX_FFT/src/fft_common/utils.py b/cuFFTMp/JAX_FFT/src/fft_common/utils.py index d2d893e1..6b488d04 100644 --- a/cuFFTMp/JAX_FFT/src/fft_common/utils.py +++ b/cuFFTMp/JAX_FFT/src/fft_common/utils.py @@ -1,7 +1,7 @@ from enum import Enum import jax -from jax.experimental import PartitionSpec +from jax.sharding import PartitionSpec class Dist(Enum): diff --git a/cuFFTMp/JAX_FFT/src/xfft/xfft.py b/cuFFTMp/JAX_FFT/src/xfft/xfft.py index 63afa7c8..3e7a09e1 100644 --- a/cuFFTMp/JAX_FFT/src/xfft/xfft.py +++ b/cuFFTMp/JAX_FFT/src/xfft/xfft.py @@ -1,7 +1,7 @@ from functools import partial import jax -from jax._src.sharding import NamedSharding +from jax.sharding import NamedSharding from jax.experimental.custom_partitioning import custom_partitioning from fft_common import Dir diff --git a/cuFFTMp/JAX_FFT/tests/fft_test.py b/cuFFTMp/JAX_FFT/tests/fft_test.py index e72757ef..78cce2bd 100644 --- a/cuFFTMp/JAX_FFT/tests/fft_test.py +++ b/cuFFTMp/JAX_FFT/tests/fft_test.py @@ -13,8 +13,10 @@ from fft_common import Dist, Dir from cufftmp_jax import cufftmp from xfft import xfft - +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P import helpers +from jax.experimental import mesh_utils, multihost_utils def main(): @@ -45,10 +47,17 @@ def main(): raise ValueError(f"Wrong implementation: got {impl}, expected cufftmp or xfft") dist = Dist.create(opt['dist']) + if dist == Dist.SLABS_X: + pdims = [jax.device_count(), 1] + axis_names = ('gpus', None) + elif dist == Dist.SLABS_Y: + pdims = [1, jax.device_count()] + axis_names = (None, 'gpus') input_shape = dist.slab_shape(fft_dims) dtype = jnp.complex64 - mesh = maps.Mesh(np.asarray(jax.devices()), ('gpus',)) + devices = mesh_utils.create_device_mesh(pdims) + mesh = Mesh(devices, axis_names=axis_names) with jax.spmd_mode('allow_all'): @@ -60,10 +69,7 @@ def main(): with mesh: - fft = pjit(dist_fft, - in_axis_resources=None, - out_axis_resources=None, - static_argnums=[1, 2]) + fft = jax.jit(dist_fft,static_argnums=[1, 2]) output = fft(input, dist, Dir.FWD)