Add fuse_single_source_parallel_gemm (#977)
Summary:
Pull Request resolved: #977

Add a fusion pass for single-source parallel GEMMs.
This is an implementation of the fuse_parallel_linear DPER pass.
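
In pseudo-ops (mirroring the new pass's docstring below; shapes here are illustrative only), the rewrite turns

    # before: several GEMMs read the same source x
    y1 = gemm_rcr_bias()(x, w1, b1)   # x: [8, 16], w1: [4, 16], b1: [4]
    y2 = gemm_rcr_bias()(x, w2, b2)   # w2: [6, 16], b2: [6]

into

    # after: one GEMM on concatenated constant weights/biases, then a split
    y = gemm_rcr_bias()(x, concatenate()([w1, w2], dim=0), concatenate()([b1, b2]))
    y1, y2 = split()(y, [4, 6], dim=-1)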

Reviewed By: khabinov, chenyang78

Differential Revision: D52087558

fbshipit-source-id: e1a0f0c0b2597adf96c02cad2f229cea78c10268
muchulee8 authored and facebook-github-bot committed Dec 14, 2023
1 parent bd5ba64 commit b29432a
Showing 3 changed files with 361 additions and 2 deletions.
111 changes: 111 additions & 0 deletions python/aitemplate/compiler/transform/fuse_parallel_gemms.py
@@ -23,6 +23,10 @@
from aitemplate.compiler.ops.gemm_universal.gemm_common import default_align_ab
from aitemplate.compiler.tensor_accessor import TensorAccessor
from aitemplate.compiler.transform import transform_utils
from aitemplate.compiler.transform.fuse_mm_elementwise_patterns import (
    get_gemm_rcr_bias_patterns,
)
from aitemplate.compiler.transform.fuse_utils import transform_simple_fusion_patterns
from aitemplate.compiler.transform.toposort import toposort
from aitemplate.compiler.transform.transform_strided_ops import _is_supported_op

@@ -465,3 +469,110 @@ def fuse_parallel_gemms(
    for func in funcs:
        sorted_graph = func(sorted_graph)
    return sorted_graph


def _fuse_single_source_parallel_gemms(
    sorted_graph: List[Tensor],
) -> Tuple[bool, List[Tensor]]:
    _fusing_ops = {"gemm_rcr", "gemm_rcr_bias"}

    for tensor in sorted_graph:
        fusion_groups = {}
        for dst in tensor.dst_ops():
            op_type = dst._attrs["op"]
            if op_type in _fusing_ops:
                if dst._attrs["outputs"][0]._attrs["is_output"]:
                    # Skip for outputs.
                    continue

                if (
                    tensor == dst._attrs["inputs"][1]
                    or dst._attrs["inputs"][1].src_ops()
                    or dst._attrs["inputs"][1]._attrs["is_input"]
                ):
                    # Skip if weight or non-const
                    continue
                elif len(dst._attrs["inputs"]) > 2 and (
                    tensor == dst._attrs["inputs"][2]
                    or dst._attrs["inputs"][2].src_ops()
                    or dst._attrs["inputs"][2]._attrs["is_input"]
                ):
                    # Skip if bias or non-const
                    continue

                if op_type in fusion_groups:
                    fusion_groups[op_type].append(dst)
                else:
                    fusion_groups[op_type] = [dst]

        for op_type, fusion_group in fusion_groups.items():
            if len(fusion_group) < 2:
                continue

            bias = "bias" in op_type
            W = []  # This stores all weights
            B = []  # This stores all biases
            N = []  # This stores all n from (m x k) x (n x k) of gemm_rcr
            for gemm_op in fusion_group:
                w = gemm_op._attrs["inputs"][1]
                W.append(w)
                if bias:
                    B.append(gemm_op._attrs["inputs"][2])
                N.append(w.shape()[0].value())
            W_concat = ops.concatenate()(W, dim=0)
            if bias:
                B_concat = ops.concatenate()(B)
                fused_gemm = ops.gemm_rcr_bias()(tensor, W_concat, B_concat)
            else:
                fused_gemm = ops.gemm_rcr()(tensor, W_concat)

            split_result = ops.split()(fused_gemm, N, dim=-1)
            for old_op, new_tensor in zip(fusion_group, split_result):
                transform_utils.replace_tensor(old_op._attrs["outputs"][0], new_tensor)

            sorted_graph = toposort(sorted_graph)
            return True, transform_utils.sanitize_sorted_graph(sorted_graph)

    return False, sorted_graph


def fuse_single_source_parallel_gemms(
    sorted_graph: List[Tensor], workdir: str = None
) -> List[Tensor]:
    """This pass fuses patterns like
    # x: [m, k], w_i: [n_i, k], b_i: [n_i]
    y1 = gemm_rcr_bias()(x, w1, b1)
    y2 = gemm_rcr_bias()(x, w2, b2)
    ...

    into:
    # x: [m, k], w: [sum(n_i), k], b: [sum(n_i)]
    w = concatenate()([w1, w2], dim=0)
    b = concatenate()([b1, b2], dim=0)
    y = gemm_rcr_bias()(x, w, b)
    y1, y2 = split()(y, n)

    For w and b, we rely on constant folding to preprocess them.
    y1 and y2 would be written directly from y's op.

    It is required that all the gemm ops have the same layouts.

    On graph pass ordering, we need to make sure this pass runs before
    any other pass that modifies gemm and concat input/output TensorAccessors.

    Args:
        sorted_graph (List[Tensor]): a sorted list of tensors

    Returns:
        List[Tensor]: the transformed graph with all ops sorted
    """

    # Extract gemm_rcr_bias pattern first.
    sorted_graph = transform_simple_fusion_patterns(
        sorted_graph, get_gemm_rcr_bias_patterns()
    )

    applied = True
    while applied:
        applied, sorted_graph = _fuse_single_source_parallel_gemms(sorted_graph)

    return sorted_graph
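
Editorial sketch (not part of this diff): a minimal NumPy check of the algebra the pass relies on, assuming gemm_rcr(x, w) computes x @ w.T and gemm_rcr_bias additionally adds a row-broadcast bias:

    import numpy as np

    # Illustrative shapes: x: [m, k], w_i: [n_i, k], b_i: [n_i]
    m, k, n1, n2 = 8, 16, 4, 6
    rng = np.random.default_rng(0)
    x = rng.standard_normal((m, k))
    w1, b1 = rng.standard_normal((n1, k)), rng.standard_normal(n1)
    w2, b2 = rng.standard_normal((n2, k)), rng.standard_normal(n2)

    # Parallel form: two gemm_rcr_bias ops reading the same source x.
    y1 = x @ w1.T + b1
    y2 = x @ w2.T + b2

    # Fused form: concat weights/biases along dim 0, one GEMM, split the last dim.
    w = np.concatenate([w1, w2], axis=0)   # [n1 + n2, k]
    b = np.concatenate([b1, b2], axis=0)   # [n1 + n2]
    y = x @ w.T + b                        # [m, n1 + n2]
    z1, z2 = np.split(y, [n1], axis=-1)

    assert np.allclose(y1, z1) and np.allclose(y2, z2)

The pass only rewrites the graph; the concatenated w and b are left for constant folding, which is why the code above skips weights and biases that have src_ops or are graph inputs.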
6 changes: 5 additions & 1 deletion python/aitemplate/compiler/transform/optimize_graph.py
@@ -36,7 +36,10 @@
    fuse_ops,
    process_singleton_elementwise,
)
from aitemplate.compiler.transform.fuse_parallel_gemms import fuse_parallel_gemms
from aitemplate.compiler.transform.fuse_parallel_gemms import (
    fuse_parallel_gemms,
    fuse_single_source_parallel_gemms,
)
from aitemplate.compiler.transform.fuse_permute_bmm_and_gemm import (
    fuse_permute_bmm_and_gemm,
)
@@ -104,6 +107,7 @@ def optimize_graph(
        fuse_expand_bmm,
        transform_odd_alignment,
        fuse_conv_elementwise,
        fuse_single_source_parallel_gemms,
        fuse_mm_elementwise,
        fuse_mm_reshape_permute,
        # make sure we run move_view_op_before_concat before transform_memory_ops
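
For context, optimize_graph applies its pass list sequentially, in the same style as the loop at the end of fuse_parallel_gemms shown above, so registering fuse_single_source_parallel_gemms before fuse_mm_elementwise satisfies the docstring's requirement that it run before passes that modify gemm/concat TensorAccessors. A minimal sketch of that driver shape (assumed; optimize_graph's body is not shown in this diff):

    from typing import Callable, List

    def run_passes(sorted_graph: List, funcs: List[Callable]) -> List:
        # Each pass maps a sorted graph to a new sorted graph,
        # so list order is execution order.
        for func in funcs:
            sorted_graph = func(sorted_graph)
        return sorted_graph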
