Skip to content

Commit

Permalink
Skip fuse_mm_elementwise fusion with model output in the middle (#1000)
Browse files Browse the repository at this point in the history
Summary:

`fuse_mm_elementwise` transformation fuses linear patterns into single ops. As a result, all intermediate outputs in the pattern are eliminated. In a special case when one or more of those intermediate outputs are model outputs, this leads to those model outputs vanishing after the fusion. Here we add skipping the fusion when one of the intermediate outputs in the detected pattern is a model output.

Reviewed By: mengyingdu

Differential Revision: D56340320
  • Loading branch information
aakhundov authored and facebook-github-bot committed Apr 19, 2024
1 parent d6142ef commit 8583c7d
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 17 deletions.
6 changes: 6 additions & 0 deletions python/aitemplate/compiler/transform/fuse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ def _find_fusion_root(tensor: Tensor, fusion_patterns: List[Any]) -> int:
fusion_idx = idx
break

if curr_tensor._attrs["is_output"]:
# if we don't break here, the curr_tensor will be
# eliminated as an intermediate tensor in the linear
# op pattern, but we can't eliminate a graph output
break

dst_op = extract_only_one_op(curr_tensor._attrs["dst_ops"])
if dst_op is None:
break
Expand Down
61 changes: 44 additions & 17 deletions tests/unittest/compiler/test_fuse_mm_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,7 +947,14 @@ def _test_gemm_rcr_bias_activation(
self.assertTrue(torch.allclose(Y_pt, y, atol=1e-1, rtol=1e-1))

def _test_gemm_rcr_bias_sigmoid_mul(
self, Ms, N, K, decomposed, testname, dtype="float16"
self,
Ms,
N,
K,
decomposed,
testname,
dtype="float16",
output_in_the_middle=False,
):
m_dim = shape_utils.gen_int_var_min_max(Ms, name="M_size")
D_shape = [m_dim, N]
Expand All @@ -963,28 +970,34 @@ def _test_gemm_rcr_bias_sigmoid_mul(
output._attrs["name"] = "output_0"
output._attrs["is_output"] = True

outputs = [output]
if output_in_the_middle:
sigmoid_tensor._attrs["name"] = "output_1"
sigmoid_tensor._attrs["is_output"] = True
outputs.append(sigmoid_tensor)

# Check value correctness
target = detect_target()
module = compile_model(output, target, "./tmp", testname)
module = compile_model(outputs, target, "./tmp", testname)

check_tensor = None
for tensor in module.debug_sorted_graph:
if tensor._attrs["name"] == "final_tensor":
check_tensor = tensor
break
self.assertIsNotNone(check_tensor)
self.assertEqual(len(check_tensor.src_ops()), 1)
src_op = list(check_tensor.src_ops())[0]
self.assertEqual(src_op._attrs["op"], "gemm_rcr_bias_sigmoid_mul")
if not output_in_the_middle:
check_tensor = None
for tensor in module.debug_sorted_graph:
if tensor._attrs["name"] == "final_tensor":
check_tensor = tensor
break
self.assertIsNotNone(check_tensor)
self.assertEqual(len(check_tensor.src_ops()), 1)
src_op = list(check_tensor.src_ops())[0]
self.assertEqual(src_op._attrs["op"], "gemm_rcr_bias_sigmoid_mul")

for M in Ms:
X_pt = get_random_torch_tensor([M, K], dtype)
W_pt = get_random_torch_tensor([N, K], dtype)
B_pt = get_random_torch_tensor([N], dtype)
D_pt = get_random_torch_tensor([M, N], dtype)
Y_pt = torch.cos(
torch.sigmoid(torch.nn.functional.linear(X_pt, W_pt, B_pt)) * D_pt
)
sigmoid_pt = torch.sigmoid(torch.nn.functional.linear(X_pt, W_pt, B_pt))
Y_pt = [torch.cos(sigmoid_pt * D_pt)]

input_name_to_index = module.get_input_name_to_index_map()
inputs = [0, 0, 0, 0]
Expand All @@ -993,9 +1006,15 @@ def _test_gemm_rcr_bias_sigmoid_mul(
inputs[input_name_to_index["input_2"]] = B_pt
inputs[input_name_to_index["input_3"]] = D_pt

y = get_torch_empty_tensor([M, N], dtype)
module.run_with_tensors(inputs, [y])
self.assertTrue(torch.allclose(Y_pt, y, atol=1e-1, rtol=1e-1))
y = [get_torch_empty_tensor([M, N], dtype)]

if output_in_the_middle:
# add another tensor to capture sigmoid output from AIT
y.append(get_torch_empty_tensor([M, N], dtype))
Y_pt.append(sigmoid_pt)

module.run_with_tensors(inputs, y)
torch.testing.assert_close(Y_pt, y, atol=1e-1, rtol=1e-1)

def _test_gemm_rcr_bias_sigmoid_mul_tanh(
self, Ms, N, K, decomposed, testname, dtype="float16"
Expand Down Expand Up @@ -1135,6 +1154,14 @@ def test_gemm_rcr_bias_sigmoid_mul(self):
self._test_gemm_rcr_bias_sigmoid_mul(
[8], 16, 3, False, "gemm_rcr_bias_sigmoid_mul_need_align"
)
self._test_gemm_rcr_bias_sigmoid_mul(
[8],
16,
3,
False,
"gemm_rcr_bias_sigmoid_mul_output_in_the_middle",
output_in_the_middle=True,
)

def test_gemm_rcr_bias_sigmoid_mul_tanh(self):
self._test_gemm_rcr_bias_sigmoid_mul_tanh(
Expand Down

0 comments on commit 8583c7d

Please sign in to comment.