Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fuse_single_source_parallel_gemm #977

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions python/aitemplate/compiler/transform/fuse_parallel_gemms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
from aitemplate.compiler.ops.gemm_universal.gemm_common import default_align_ab
from aitemplate.compiler.tensor_accessor import TensorAccessor
from aitemplate.compiler.transform import transform_utils
from aitemplate.compiler.transform.fuse_mm_elementwise_patterns import (
get_gemm_rcr_bias_patterns,
)
from aitemplate.compiler.transform.fuse_utils import transform_simple_fusion_patterns
from aitemplate.compiler.transform.toposort import toposort
from aitemplate.compiler.transform.transform_strided_ops import _is_supported_op

Expand Down Expand Up @@ -465,3 +469,110 @@ def fuse_parallel_gemms(
for func in funcs:
sorted_graph = func(sorted_graph)
return sorted_graph


def _fuse_single_source_parallel_gemms(
sorted_graph: List[Tensor],
) -> Tuple[bool, List[Tensor]]:
_fusing_ops = {"gemm_rcr", "gemm_rcr_bias"}

for tensor in sorted_graph:
fusion_groups = {}
for dst in tensor.dst_ops():
op_type = dst._attrs["op"]
if op_type in _fusing_ops:
if dst._attrs["outputs"][0]._attrs["is_output"]:
# Skip for outputs.
continue

if (
tensor == dst._attrs["inputs"][1]
or dst._attrs["inputs"][1].src_ops()
or dst._attrs["inputs"][1]._attrs["is_input"]
):
# Skip if weight or non-const
continue
elif len(dst._attrs["inputs"]) > 2 and (
tensor == dst._attrs["inputs"][2]
or dst._attrs["inputs"][2].src_ops()
or dst._attrs["inputs"][2]._attrs["is_input"]
):
# Skip if bias or non-const
continue

if op_type in fusion_groups:
fusion_groups[op_type].append(dst)
else:
fusion_groups[op_type] = [dst]

for op_type, fusion_group in fusion_groups.items():
if len(fusion_group) < 2:
continue

bias = "bias" in op_type
W = [] # This stores all weights
B = [] # This stores all biases
N = [] # This stores all n from (m x k) x (n x k) of gemm_rcr
for gemm_op in fusion_group:
w = gemm_op._attrs["inputs"][1]
W.append(w)
if bias:
B.append(gemm_op._attrs["inputs"][2])
N.append(w.shape()[0].value())
W_concat = ops.concatenate()(W, dim=0)
if bias:
B_concat = ops.concatenate()(B)
fused_gemm = ops.gemm_rcr_bias()(tensor, W_concat, B_concat)
else:
fused_gemm = ops.gemm_rcr()(tensor, W_concat)

split_result = ops.split()(fused_gemm, N, dim=-1)
for old_op, new_tensor in zip(fusion_group, split_result):
transform_utils.replace_tensor(old_op._attrs["outputs"][0], new_tensor)

sorted_graph = toposort(sorted_graph)
return True, transform_utils.sanitize_sorted_graph(sorted_graph)

return False, sorted_graph


def fuse_single_source_parallel_gemms(
sorted_graph: List[Tensor], workdir: str = None
) -> List[Tensor]:
"""This pass fuses patterns like
# x: [m, k], w_i: [n_i, k], b_i: [n_i]
y1 = gemm_rcr_bias()(x, w1, b1)
y2 = gemm_rcr_bias()(x, w2, b2)
...

into:
# x: [m, k], w: [sum(n_i), k], b: [sum(n_i)]
w = concatenate()([w1, w2], dim=0)
b = concatenate()([b1, b2], dim=0)
y = gemm_rcr_bias()(x, w, b)
y1, y2 = split()(y, n)

For w and b, we rely on constant folding to preprocess them.
y1 and y2 would be written directly from y's op.
It is required that all the gemm ops have the same layouts.

On graph pass ordering, we need to make sure this pass runs before
any other pass that modifies gemm and concat input/output TensorAccessors.

Args:
sorted_graph (List[Tensor]): a sorted list of tensors

Returns:
List[Tensor]: the transformed graph with all ops sorted
"""

# Extract gemm_rcr_bias pattern first.
sorted_graph = transform_simple_fusion_patterns(
sorted_graph, get_gemm_rcr_bias_patterns()
)

applied = True
while applied:
applied, sorted_graph = _fuse_single_source_parallel_gemms(sorted_graph)

return sorted_graph
6 changes: 5 additions & 1 deletion python/aitemplate/compiler/transform/optimize_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@
fuse_ops,
process_singleton_elementwise,
)
from aitemplate.compiler.transform.fuse_parallel_gemms import fuse_parallel_gemms
from aitemplate.compiler.transform.fuse_parallel_gemms import (
fuse_parallel_gemms,
fuse_single_source_parallel_gemms,
)
from aitemplate.compiler.transform.fuse_permute_bmm_and_gemm import (
fuse_permute_bmm_and_gemm,
)
Expand Down Expand Up @@ -104,6 +107,7 @@ def optimize_graph(
fuse_expand_bmm,
transform_odd_alignment,
fuse_conv_elementwise,
fuse_single_source_parallel_gemms,
fuse_mm_elementwise,
fuse_mm_reshape_permute,
# make sure we run move_view_op_before_concat before transform_memory_ops
Expand Down
Loading
Loading