[Pallas:MGPU] Add support for passing in WGMMA lhs from registers
PiperOrigin-RevId: 688117316
apaszke authored and Google-ML-Automation committed Oct 21, 2024
1 parent f08801b commit f833891
Showing 2 changed files with 63 additions and 26 deletions.
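
In short, this change lets the lhs operand of plgpu.wgmma be passed as a value already held in registers (lowered to an mgpu.FragmentedArray), instead of requiring a transformed SMEM reference; the rhs is still passed as an SMEM ref. A minimal kernel-body sketch of the new call pattern, mirroring the test_wgmma_registers test added below (the shapes, the accumulator size, and the surrounding pl.pallas_call setup come from that test and are only illustrative here):

def kernel(a_ref, b_ref, o_ref):
  def scope(acc_ref):
    # New in this commit: load the lhs from SMEM into registers (a_ref[...])
    # and pass it to wgmma as a value rather than as a ref.
    plgpu.wgmma(acc_ref, a_ref[...], b_ref)
    return acc_ref[...]
  # Accumulator shape (64, 192) matches the (64, 128) x (128, 192) operands in the test.
  o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))

When the lhs comes from registers, the lhs swizzle/tiling checks are skipped and the lowering takes the WGMMA swizzle from the rhs (swizzle=rhs_swizzle in the primitives.py diff below).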
66 changes: 40 additions & 26 deletions jax/_src/pallas/mosaic_gpu/primitives.py
@@ -364,7 +364,7 @@ class _WGMMAPipelineEffect(effects.Effect):
 
 def wgmma(
     acc: gpu_core.WGMMAAbstractAccumulatorRef,
-    a: pallas_core.TransformedRef,
+    a,
     b: pallas_core.TransformedRef,
 ) -> None:
   """Performs an asynchronous warp group matmul-accumulate on the given references.
@@ -395,12 +395,16 @@ def wgmma(
   if a.dtype != b.dtype:
     raise ValueError(f"Mixed input dtypes for matrix multiplication unsupported: lhs={a.dtype}, rhs={b.dtype}")
 
-  a_transforms_leaves, a_transforms_tree = jax.tree.flatten(a.transforms)
+  if isinstance(a, pallas_core.TransformedRef):
+    a_transforms_leaves, a_transforms_tree = jax.tree.flatten(a.transforms)
+    a = a.ref
+  else:
+    a_transforms_leaves, a_transforms_tree = [], None
   b_transforms_leaves, b_transforms_tree = jax.tree.flatten(b.transforms)
 
   wgmma_ref_p.bind(
       acc,
-      a.ref,
+      a,
       b.ref,
       *a_transforms_leaves,
       *b_transforms_leaves,
@@ -411,15 +415,15 @@
 
 @wgmma_ref_p.def_effectful_abstract_eval
 def _wgmma_ref_effectful_abstract_eval(acc_aval, a_aval, b_aval, *_, **params):
-  del a_aval, b_aval, params
+  del b_aval, params
   if not isinstance(acc_aval, gpu_core.WGMMAAbstractAccumulatorRef):
     raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc_aval}")
   return (), {
       _wgmma_pipeline_effect,
       state.WriteEffect(0),
       state.ReadEffect(0),
-      state.ReadEffect(1),
       state.ReadEffect(2),
+      *([state.ReadEffect(1)] if isinstance(a_aval, state.AbstractRef) else [])
   }
 
 
@@ -444,23 +448,31 @@ def _wgmma_lowering(
     b_transforms_tree,
 ):
   _, a_aval, *_ = ctx.avals_in
-  a_transforms_leaves, b_transforms_leaves = util.split_list(
-      transforms_leaves, [a_transforms_tree.num_leaves]
-  )
-  a_transforms = a_transforms_tree.unflatten(a_transforms_leaves)
-  b_transforms = b_transforms_tree.unflatten(b_transforms_leaves)
+  lhs_swizzle = None
+  if a_transforms_tree is not None:
+    a_transforms_leaves, b_transforms_leaves = util.split_list(
+        transforms_leaves, [a_transforms_tree.num_leaves]
+    )
+    a_transforms = a_transforms_tree.unflatten(a_transforms_leaves)
+    a, a_transforms = lowering._handle_indexing(a, a_transforms)
+    match a_transforms:
+      case (gpu_core.UnswizzleRef(lhs_swizzle), gpu_core.UntileRef(tiling)):
+        swizzle_elems = lhs_swizzle // a_aval.dtype.itemsize
+        if tiling != (64, swizzle_elems):
+          raise NotImplementedError("WGMMA lhs tiling does not fit swizzle")
+      case _:
+        raise ValueError(f"WGMMA lhs has unsupported transforms: {a_transforms}.")
+  else:
+    b_transforms_leaves = transforms_leaves  # type: ignore
+    if not isinstance(a, mgpu.FragmentedArray):
+      raise ValueError(
+          "When WGMMA lhs is passed in as a ref, it must be transformed by"
+          " swizzling and tiling appropriately."
+      )
 
-  a, a_transforms = lowering._handle_indexing(a, a_transforms)
+  b_transforms = b_transforms_tree.unflatten(b_transforms_leaves)
   b, b_transforms = lowering._handle_indexing(b, b_transforms)
 
-  match a_transforms:
-    case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)):
-      swizzle_elems = swizzle // a_aval.dtype.itemsize
-      if tiling != (64, swizzle_elems):
-        raise NotImplementedError("WGMMA lhs tiling does not fit swizzle")
-    case _:
-      raise ValueError(f"WGMMA lhs has unsupported transforms: {a_transforms}.")
-
   match b_transforms:
     case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)):
       rhs_transpose = False
@@ -474,16 +486,18 @@ def _wgmma_lowering(
     case _:
       raise ValueError(f"WGMMA rhs has unsupported transforms: {b_transforms}.")
 
-  if rhs_swizzle != swizzle:
-    raise NotImplementedError("WGMMA rhs swizzle must match lhs swizzle")
-  if rhs_tiling != (swizzle_elems, swizzle_elems):
-    raise NotImplementedError("WGMMA rhs tiling does not fit swizzle")
+  if lhs_swizzle is not None:
+    swizzle_elems = rhs_swizzle // a_aval.dtype.itemsize
+    if rhs_swizzle != lhs_swizzle:
+      raise NotImplementedError("WGMMA rhs swizzle must match lhs swizzle")
+    if rhs_tiling != (swizzle_elems, swizzle_elems):
+      raise NotImplementedError("WGMMA rhs tiling does not fit swizzle")
 
   new_acc = mgpu.wgmma(
       acc,
       a,
       b,
-      swizzle=swizzle,
+      swizzle=rhs_swizzle,
       b_order=mgpu.WGMMALayout.COL_MAJOR
       if rhs_transpose
       else mgpu.WGMMALayout.ROW_MAJOR,
@@ -493,12 +507,12 @@
 
 
 @wgmma_p.def_effectful_abstract_eval
-def _wgmma_effectful_abstract_eval(acc, *args, **kwargs):
+def _wgmma_effectful_abstract_eval(acc, lhs_ref, *args, **kwargs):
   del args, kwargs
   return acc, {
       _wgmma_pipeline_effect,
-      state.ReadEffect(1),
       state.ReadEffect(2),
+      *([state.ReadEffect(1)] if isinstance(lhs_ref, state.AbstractRef) else [])
   }
 
 wgmma_wait_p = jax_core.Primitive("wgmma_wait")
23 changes: 23 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
@@ -705,6 +705,29 @@ def scope(acc_ref):
         res, a @ (b.T if rhs_transpose else b), rtol=1e-3
     )
 
+  def test_wgmma_registers(self):
+    def kernel(a_ref, b_ref, o_ref):
+      def scope(acc_ref):
+        plgpu.wgmma(acc_ref, a_ref[...], b_ref)
+        return acc_ref[...]
+      o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))
+
+    key1, key2 = jax.random.split(jax.random.key(42), 2)
+    a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16)
+    b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16)
+
+    transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128))
+    res = pl.pallas_call(
+        kernel,
+        in_specs=[
+            plgpu.GPUBlockSpec((64, 128), lambda: (0, 0), transforms=transforms),
+            plgpu.GPUBlockSpec((128, 192), lambda: (0, 0), transforms=transforms),
+        ],
+        out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)),
+        out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
+    )(a, b)
+    np.testing.assert_allclose(res, a @ b, rtol=1e-3)
+
   def test_wgmma_sliced_ref(self):
     def kernel(a_ref, b_ref, o_ref):
       def scope(acc_ref):
