diff --git a/python/aitemplate/compiler/transform/fuse_parallel_gemms.py b/python/aitemplate/compiler/transform/fuse_parallel_gemms.py
index baf298e2f..b592350d4 100644
--- a/python/aitemplate/compiler/transform/fuse_parallel_gemms.py
+++ b/python/aitemplate/compiler/transform/fuse_parallel_gemms.py
@@ -23,6 +23,10 @@
 from aitemplate.compiler.ops.gemm_universal.gemm_common import default_align_ab
 from aitemplate.compiler.tensor_accessor import TensorAccessor
 from aitemplate.compiler.transform import transform_utils
+from aitemplate.compiler.transform.fuse_mm_elementwise_patterns import (
+    get_gemm_rcr_bias_patterns,
+)
+from aitemplate.compiler.transform.fuse_utils import transform_simple_fusion_patterns
 from aitemplate.compiler.transform.toposort import toposort
 from aitemplate.compiler.transform.transform_strided_ops import _is_supported_op
 
@@ -465,3 +469,110 @@ def fuse_parallel_gemms(
     for func in funcs:
         sorted_graph = func(sorted_graph)
     return sorted_graph
+
+
+def _fuse_single_source_parallel_gemms(
+    sorted_graph: List[Tensor],
+) -> Tuple[bool, List[Tensor]]:
+    _fusing_ops = {"gemm_rcr", "gemm_rcr_bias"}
+
+    for tensor in sorted_graph:
+        fusion_groups = {}
+        for dst in tensor.dst_ops():
+            op_type = dst._attrs["op"]
+            if op_type in _fusing_ops:
+                if dst._attrs["outputs"][0]._attrs["is_output"]:
+                    # Skip gemms whose outputs are graph outputs.
+                    continue
+
+                if (
+                    tensor == dst._attrs["inputs"][1]
+                    or dst._attrs["inputs"][1].src_ops()
+                    or dst._attrs["inputs"][1]._attrs["is_input"]
+                ):
+                    # Skip if this tensor is the weight, or the weight is non-constant.
+                    continue
+                elif len(dst._attrs["inputs"]) > 2 and (
+                    tensor == dst._attrs["inputs"][2]
+                    or dst._attrs["inputs"][2].src_ops()
+                    or dst._attrs["inputs"][2]._attrs["is_input"]
+                ):
+                    # Skip if this tensor is the bias, or the bias is non-constant.
+                    continue
+
+                if op_type in fusion_groups:
+                    fusion_groups[op_type].append(dst)
+                else:
+                    fusion_groups[op_type] = [dst]
+
+        for op_type, fusion_group in fusion_groups.items():
+            if len(fusion_group) < 2:
+                continue
+
+            bias = "bias" in op_type
+            W = []  # Weights of all gemms in the group.
+            B = []  # Biases of all gemms in the group.
+            N = []  # All n from (m x k) x (n x k) of gemm_rcr.
+            for gemm_op in fusion_group:
+                w = gemm_op._attrs["inputs"][1]
+                W.append(w)
+                if bias:
+                    B.append(gemm_op._attrs["inputs"][2])
+                N.append(w.shape()[0].value())
+            W_concat = ops.concatenate()(W, dim=0)
+            if bias:
+                B_concat = ops.concatenate()(B)
+                fused_gemm = ops.gemm_rcr_bias()(tensor, W_concat, B_concat)
+            else:
+                fused_gemm = ops.gemm_rcr()(tensor, W_concat)
+
+            split_result = ops.split()(fused_gemm, N, dim=-1)
+            for old_op, new_tensor in zip(fusion_group, split_result):
+                transform_utils.replace_tensor(old_op._attrs["outputs"][0], new_tensor)
+
+            sorted_graph = toposort(sorted_graph)
+            return True, transform_utils.sanitize_sorted_graph(sorted_graph)
+
+    return False, sorted_graph
+
+
+def fuse_single_source_parallel_gemms(
+    sorted_graph: List[Tensor], workdir: str = None
+) -> List[Tensor]:
+    """This pass fuses patterns like
+    # x: [m, k], w_i: [n_i, k], b_i: [n_i]
+    y1 = gemm_rcr_bias()(x, w1, b1)
+    y2 = gemm_rcr_bias()(x, w2, b2)
+    ...
+
+    into:
+    # x: [m, k], w: [sum(n_i), k], b: [sum(n_i)]
+    w = concatenate()([w1, w2], dim=0)
+    b = concatenate()([b1, b2], dim=0)
+    y = gemm_rcr_bias()(x, w, b)
+    y1, y2 = split()(y, [n_1, n_2], dim=-1)
+
+    For w and b, we rely on constant folding to preprocess them.
+    y1 and y2 are written directly from y's op.
+    All of the fused gemm ops must have the same layout.
+
+    Note on pass ordering: this pass must run before any other pass that
+    modifies the input/output TensorAccessors of gemm and concatenate ops.
+
+    Args:
+        sorted_graph (List[Tensor]): a sorted list of tensors
+
+    Returns:
+        List[Tensor]: the transformed graph with all ops sorted
+    """
+
+    # Fuse gemm_rcr + elementwise add into gemm_rcr_bias first.
+    sorted_graph = transform_simple_fusion_patterns(
+        sorted_graph, get_gemm_rcr_bias_patterns()
+    )
+
+    applied = True
+    while applied:
+        applied, sorted_graph = _fuse_single_source_parallel_gemms(sorted_graph)
+
+    return sorted_graph
diff --git a/python/aitemplate/compiler/transform/optimize_graph.py b/python/aitemplate/compiler/transform/optimize_graph.py
index 3eab2ad9d..4e73cc3d6 100644
--- a/python/aitemplate/compiler/transform/optimize_graph.py
+++ b/python/aitemplate/compiler/transform/optimize_graph.py
@@ -36,7 +36,10 @@
     fuse_ops,
     process_singleton_elementwise,
 )
-from aitemplate.compiler.transform.fuse_parallel_gemms import fuse_parallel_gemms
+from aitemplate.compiler.transform.fuse_parallel_gemms import (
+    fuse_parallel_gemms,
+    fuse_single_source_parallel_gemms,
+)
 from aitemplate.compiler.transform.fuse_permute_bmm_and_gemm import (
     fuse_permute_bmm_and_gemm,
 )
@@ -104,6 +107,7 @@ def optimize_graph(
         fuse_expand_bmm,
         transform_odd_alignment,
         fuse_conv_elementwise,
+        fuse_single_source_parallel_gemms,
         fuse_mm_elementwise,
         fuse_mm_reshape_permute,
         # make sure we run move_view_op_before_concat before transform_memory_ops
diff --git a/tests/unittest/compiler/test_parallel_gemm_fusions.py b/tests/unittest/compiler/test_parallel_gemm_fusions.py
index 929d06150..2458cbcc4 100644
--- a/tests/unittest/compiler/test_parallel_gemm_fusions.py
+++ b/tests/unittest/compiler/test_parallel_gemm_fusions.py
@@ -21,7 +21,10 @@
 from aitemplate.compiler import compile_model, ops
 from aitemplate.compiler.ops.common.epilogue import FuncEnum
-from aitemplate.compiler.transform.fuse_parallel_gemms import fuse_parallel_gemms
+from aitemplate.compiler.transform.fuse_parallel_gemms import (
+    fuse_parallel_gemms,
+    fuse_single_source_parallel_gemms,
+)
 from aitemplate.compiler.transform.toposort import toposort
 from aitemplate.frontend import IntImm, IntVar, Tensor
 from aitemplate.testing import detect_target
@@ -732,7 +735,248 @@ def test_skip_parallel_gemm_cat_groups(self):
         )
 
+
+class SingleSourceParallelGemmFusionTestCase(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super(SingleSourceParallelGemmFusionTestCase, self).__init__(*args, **kwargs)
+        self._test_id = 0
+
+    def test_simple_gemm_rcr(self, dtype: str = "float16"):
+        M = 1024
+        N = [256, 32, 128]
+        K = 256
+
+        X = Tensor(
+            shape=[IntImm(M), IntImm(K)],
+            name="X",
+            dtype=dtype,
+            is_input=True,
+        )
+        Ws = []
+        for i in range(len(N)):
+            W = Tensor(
+                shape=[IntImm(N[i]), IntImm(K)],
+                dtype=dtype,
+                name=f"W{i}",
+            )
+            Ws.append(W)
+
+        Ys = []
+        for i in range(len(N)):
+            Ys.append(ops.elementwise(FuncEnum.RELU)(ops.gemm_rcr()(X, Ws[i])))
+            Ys[-1]._attrs["name"] = f"output{i}"
+            Ys[-1]._attrs["is_output"] = True
+
+        constants = {}
+        for i in range(len(N)):
+            constants[f"W{i}"] = get_random_torch_tensor([N[i], K], dtype)
+
+        # Test that the graph pass is correct.
+        sorted_graph = toposort(Ys)
+        new_sorted_graph = fuse_single_source_parallel_gemms(sorted_graph)
+
+        sorted_ops = graph_utils.get_sorted_ops(new_sorted_graph)
+        self.assertEqual(count_ops(sorted_ops, "gemm_rcr"), 1)
+
+        # Test e2e results.
+        x_pt = get_random_torch_tensor([M, K], dtype)
+
+        Ys_pt = []
+        for i in range(len(N)):
+            y_pt = torch.nn.functional.linear(x_pt, constants[f"W{i}"])
+            Ys_pt.append(torch.nn.functional.relu(y_pt))
+
+        # Run AITemplate module.
+        target = detect_target()
+        with compile_model(
+            Ys,
+            target,
+            "./tmp",
+            f"fuse_single_source_parallel_gemm_simple_{dtype}",
+            dll_name=f"test_{self._test_id}.so",
+            constants=constants,
+        ) as module:
+            self._test_id += 1
+
+            outs = []
+            for n in N:
+                outs.append(get_torch_empty_tensor([M, n], dtype))
+            module.run_with_tensors([x_pt], outs)
+
+            # Do comparisons.
+            for (out_ait, out_pt) in zip(outs, Ys_pt):
+                self.assertTrue(torch.allclose(out_ait, out_pt, atol=5e-2, rtol=5e-2))
+
+    def test_gemm_rcr_bias(self, dtype: str = "float16"):
+        M = 1024
+        N = [36, 24, 256]
+        K = 256
+
+        X = Tensor(
+            shape=[IntImm(M), IntImm(K)],
+            name="X",
+            dtype=dtype,
+            is_input=True,
+        )
+        Ws = []
+        Bs = []
+        for i in range(len(N)):
+            W = Tensor(
+                shape=[IntImm(N[i]), IntImm(K)],
+                dtype=dtype,
+                name=f"W{i}",
+            )
+            B = Tensor(
+                shape=[IntImm(N[i])],
+                dtype=dtype,
+                name=f"B{i}",
+            )
+            Ws.append(W)
+            Bs.append(B)
+
+        Ys = []
+        for i in range(len(N)):
+            Ys.append(
+                ops.elementwise(FuncEnum.RELU)(ops.gemm_rcr_bias()(X, Ws[i], Bs[i]))
+            )
+            Ys[-1]._attrs["name"] = f"output{i}"
+            Ys[-1]._attrs["is_output"] = True
+
+        constants = {}
+        for i in range(len(N)):
+            constants[f"W{i}"] = get_random_torch_tensor([N[i], K], dtype)
+            constants[f"B{i}"] = get_random_torch_tensor([N[i]], dtype)
+
+        # Test that the graph pass is correct.
+        sorted_graph = toposort(Ys)
+        new_sorted_graph = fuse_single_source_parallel_gemms(sorted_graph)
+
+        sorted_ops = graph_utils.get_sorted_ops(new_sorted_graph)
+        self.assertEqual(count_ops(sorted_ops, "gemm_rcr_bias"), 1)
+
+        # Test e2e results.
+        x_pt = get_random_torch_tensor([M, K], dtype)
+
+        Ys_pt = []
+        for i in range(len(N)):
+            y_pt = torch.nn.functional.linear(
+                x_pt, constants[f"W{i}"], constants[f"B{i}"]
+            )
+            Ys_pt.append(torch.nn.functional.relu(y_pt))
+
+        # Run AITemplate module.
+        target = detect_target()
+        with compile_model(
+            Ys,
+            target,
+            "./tmp",
+            f"fuse_single_source_parallel_gemm_rcr_bias_{dtype}",
+            dll_name=f"test_{self._test_id}.so",
+            constants=constants,
+        ) as module:
+            self._test_id += 1
+
+            outs = []
+            for n in N:
+                outs.append(get_torch_empty_tensor([M, n], dtype))
+            module.run_with_tensors([x_pt], outs)
+
+            # Do comparisons.
+            for (out_ait, out_pt) in zip(outs, Ys_pt):
+                self.assertTrue(torch.allclose(out_ait, out_pt, atol=5e-2, rtol=5e-2))
+
+    def test_mix_gemm(self, dtype: str = "float16"):
+        M = 1024
+        N1 = [512, 128, 32]
+        N2 = [32, 128, 256]
+        K = 256
+
+        X = Tensor(
+            shape=[IntImm(M), IntImm(K)],
+            name="X",
+            dtype=dtype,
+            is_input=True,
+        )
+        Ws = []
+        Bs = []
+        for i, n in enumerate(N1 + N2):
+            W = Tensor(
+                shape=[IntImm(n), IntImm(K)],
+                dtype=dtype,
+                name=f"W{i}",
+            )
+            B = Tensor(
+                shape=[IntImm(n)],
+                dtype=dtype,
+                name=f"B{i}",
+            )
+            Ws.append(W)
+            Bs.append(B)
+        Bs = Bs[: len(N1)]  # Only the first len(N1) gemms take a bias.
+
+        Ys = []
+        for i in range(len(N1)):
+            Ys.append(
+                ops.elementwise(FuncEnum.RELU)(ops.gemm_rcr_bias()(X, Ws[i], Bs[i]))
+            )
+            Ys[-1]._attrs["name"] = f"output{i}"
+            Ys[-1]._attrs["is_output"] = True
+        for i in range(len(N1), len(N1 + N2)):
+            Ys.append(ops.elementwise(FuncEnum.RELU)(ops.gemm_rcr()(X, Ws[i])))
+            Ys[-1]._attrs["name"] = f"output{i}"
+            Ys[-1]._attrs["is_output"] = True
+
+        constants = {}
+        for i, n in enumerate(N1):
+            constants[f"W{i}"] = get_random_torch_tensor([n, K], dtype)
+            constants[f"B{i}"] = get_random_torch_tensor([n], dtype)
+        for i, n in enumerate(N2):
+            constants[f"W{i + len(N1)}"] = get_random_torch_tensor([n, K], dtype)
+
+        # Test that the graph pass is correct.
+        sorted_graph = toposort(Ys)
+        new_sorted_graph = fuse_single_source_parallel_gemms(sorted_graph)
+
+        sorted_ops = graph_utils.get_sorted_ops(new_sorted_graph)
+        self.assertEqual(count_ops(sorted_ops, "gemm_rcr"), 1)
+        self.assertEqual(count_ops(sorted_ops, "gemm_rcr_bias"), 1)
+
+        # Test e2e results.
+        x_pt = get_random_torch_tensor([M, K], dtype)
+
+        Ys_pt = []
+        for i in range(len(N1)):
+            y_pt = torch.nn.functional.linear(
+                x_pt, constants[f"W{i}"], constants[f"B{i}"]
+            )
+            Ys_pt.append(torch.nn.functional.relu(y_pt))
+        for i in range(len(N2)):
+            y_pt = torch.nn.functional.linear(x_pt, constants[f"W{i + len(N1)}"])
+            Ys_pt.append(torch.nn.functional.relu(y_pt))
+
+        # Run AITemplate module.
+        target = detect_target()
+        with compile_model(
+            Ys,
+            target,
+            "./tmp",
+            f"fuse_single_source_parallel_mix_gemm_{dtype}",
+            dll_name=f"test_{self._test_id}.so",
+            constants=constants,
+        ) as module:
+            self._test_id += 1
+
+            outs = []
+            for n in N1 + N2:
+                outs.append(get_torch_empty_tensor([M, n], dtype))
+            module.run_with_tensors([x_pt], outs)
+
+            # Do comparisons.
+            for (out_ait, out_pt) in zip(outs, Ys_pt):
+                self.assertTrue(torch.allclose(out_ait, out_pt, atol=5e-2, rtol=5e-2))
+
+
 filter_test_cases_by_test_env(ParallelGemmCatFusionTestCase)
+filter_test_cases_by_test_env(SingleSourceParallelGemmFusionTestCase)
 
 if __name__ == "__main__":
     torch.manual_seed(0)
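
Reviewer note, outside the patch: the pass is sound because a gemm_rcr over row-wise concatenated weights computes exactly the column-wise concatenation of the individual gemm outputs. Below is a minimal PyTorch sketch of that identity; the shapes, tolerances, and variable names are illustrative assumptions, and torch.nn.functional.linear(x, w, b) = x @ w.T + b matches the (m x k) x (n x k) layout of gemm_rcr_bias.

# Minimal sketch (assumed shapes): one fused gemm replaces three parallel ones.
import torch
import torch.nn.functional as F

m, k = 1024, 256
ns = [256, 32, 128]
x = torch.randn(m, k)
ws = [torch.randn(n, k) for n in ns]  # one [n_i, k] weight per gemm_rcr
bs = [torch.randn(n) for n in ns]  # one [n_i] bias per gemm_rcr_bias

# Unfused form: one linear (gemm_rcr_bias) per weight/bias pair.
ys = [F.linear(x, w, b) for w, b in zip(ws, bs)]

# Fused form: concatenate along n, run a single gemm, split the output.
w_cat = torch.cat(ws, dim=0)  # [sum(ns), k]
b_cat = torch.cat(bs, dim=0)  # [sum(ns)]
y_cat = F.linear(x, w_cat, b_cat)  # [m, sum(ns)]

# Each split chunk matches the corresponding unfused result.
for y_ref, y_fused in zip(ys, torch.split(y_cat, ns, dim=-1)):
    assert torch.allclose(y_ref, y_fused, atol=1e-5, rtol=1e-5)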