
benchdnn: graph: enhance input displace and shape rewrite for linked attribute and shapes #2354

Open · wants to merge 4 commits into main from zhitao/enhance-shape-rewrite

Conversation

@wzt1997 (Contributor) commented Jan 8, 2025

Description

This PR enhances the input displacement and shape-rewrite functionality in benchdnn graph in the following ways:

  1. Support mb rewriting on SRC1 of MatMul and on the scale and zero-point inputs of DynamicDequantize, which enables rewriting SDPA patterns.
  2. Support shape rewriting for linked attributes and shapes, such as the group_shape attribute and the scale/zp inputs of DynamicDequantize. If the user provides shapes for either the attribute or the inputs, benchdnn graph updates the other accordingly after performing consistency checks (see the sketch after the examples below).
  3. Fix the data type setting for input displacement when the primitive cannot be created. This resolves two specific cases: f8_e4m3 cases, where primitive creation may fail because f8_e4m3:f8_e4m3:f8_e5m2 is not supported, and bf16:int4:bf16 MatMul cases, where f32:int4:bf16 MatMul is not supported.

For example:

# Rewrite input shape only
0:PASSED __REPRO: --graph --in-shapes=7:1x16x128x1+8:1x16x128x1 --case=/home/wangzhitao/oneDNN-src/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-compressed-v-int8-gs32.json
# Rewrite group-shape only
1:PASSED __REPRO: --graph --op-attrs=34107656704:group_shape:1x1x128x1 --case=/home/wangzhitao/oneDNN-src/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-compressed-k-int8-gs32.json
# Rewrite mb size
2:PASSED __REPRO: --graph --mb=10 --case=/home/wangzhitao/oneDNN-src/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-compressed-v-int8-gs32.json
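
As a rough illustration of the linked check behind item 2, below is a minimal sketch of how the source shape, the group_shape attribute, and the scale/zp shapes must relate. The helper name and the exact rule (scale dims equal src dims divided by group dims, following oneDNN's grouped quantization semantics) are assumptions for illustration, not benchdnn code.

#include <cstddef>
#include <cstdint>
#include <vector>

using dims_t = std::vector<int64_t>;

// Hypothetical helper: returns true when group_shape evenly tiles src_shape
// and scale_shape carries exactly one element per group along each dimension.
bool group_shapes_consistent(const dims_t &src_shape,
        const dims_t &group_shape, const dims_t &scale_shape) {
    // All three tensors must share the same rank (cf. the ndims check
    // discussed in one of the review comments below).
    if (src_shape.size() != group_shape.size()
            || src_shape.size() != scale_shape.size())
        return false;
    for (std::size_t d = 0; d < src_shape.size(); ++d) {
        // Each group must divide its source dimension evenly...
        if (group_shape[d] <= 0 || src_shape[d] % group_shape[d] != 0)
            return false;
        // ...and scale/zp hold one value per group.
        if (scale_shape[d] != src_shape[d] / group_shape[d]) return false;
    }
    return true;
}

With a relation like this in place, rewriting either side determines the other: new scale/zp input shapes fix the group_shape by the division, and a new group_shape likewise fixes the scale/zp shapes.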

@wzt1997 wzt1997 self-assigned this Jan 8, 2025
@github-actions bot added the component:graph-api (Codeowner: @oneapi-src/onednn-graph) and component:tests (Codeowner: @oneapi-src/onednn-arch) labels on Jan 8, 2025
@wzt1997 wzt1997 force-pushed the zhitao/enhance-shape-rewrite branch from 4af44a1 to d65e846 on January 8, 2025 06:09
@wzt1997 wzt1997 changed the title from "[WIP]benchdnn: graph: enhance mb and shape rewrite" to "[WIP]benchdnn: graph: enhance input displace and shape rewrite for linked attribute and shapes" on Jan 8, 2025
@wzt1997 wzt1997 force-pushed the zhitao/enhance-shape-rewrite branch from 50714b2 to c82a899 on January 8, 2025 08:02
@wzt1997 wzt1997 force-pushed the zhitao/enhance-shape-rewrite branch from c82a899 to b6a63b5 on January 10, 2025 05:56
@wzt1997 wzt1997 changed the title from "[WIP]benchdnn: graph: enhance input displace and shape rewrite for linked attribute and shapes" to "benchdnn: graph: enhance input displace and shape rewrite for linked attribute and shapes" on Jan 10, 2025
@wzt1997 wzt1997 marked this pull request as ready for review January 10, 2025 06:50
@wzt1997 wzt1997 requested review from a team as code owners January 10, 2025 06:50
@wzt1997 (Contributor, Author) commented Jan 10, 2025

make test
enable benchdnn_nightly
disable benchdnn_all
enable benchdnn_graph


bool ret = true;
Contributor:

I wonder if it makes sense to set this to false in the first place, to avoid multiple re-initializations of this value across the logic.
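
A tiny self-contained sketch of the suggested pattern (the surrounding context is hypothetical, not the actual benchdnn logic):

#include <vector>

// Starting from false, only the success path assigns the flag, so the
// failure branches never have to re-initialize it.
bool any_rewrite_applied(const std::vector<bool> &step_results) {
    bool ret = false;
    for (bool ok : step_results) {
        if (!ok) continue; // failure paths leave ret untouched
        ret = true; // set once per successful step
    }
    return ret;
}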

@@ -32,7 +32,7 @@
# Re-written graphs
--reset --dt=f32,bf16,f16 --in-shapes=4:4x16x32x256+5:4x16x256x33+0:4x16x33x256+1:4x1x1x33+3:4x1x32x33 --case=complex_fusion/mha/MHA-GPT-inf-fp32-bs1.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --in-shapes=3:4x32x32x128+4:4x32x128x33+0:4x32x33x128+1:4x1x32x33 --case=complex_fusion/mha/MHA-LLaMa-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=3:20x16x384x64+4:20x16x64x384+0:20x16x384x64+1:20x1x1x384 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --mb=10,20 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=3:10x16x384x64+4:10x1x64x384+0:10x1x384x64+1:10x1x1x384 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
Contributor:

Should this be removed in favor of mb=10 right above?

if (attr.find("qtype") == attr.end()
|| attr["qtype"].str_value_ != "per_group")
continue;
if (attr.find("group_shape") == attr.end()) {
Contributor:

It would be nice to dump the group_size attribute in the graph dump under -v7 so that it's visible to the user.

zp_lt.shape_, dgraph.lt_2_mtag_[zp_lt.id_]);
}
} else if (input_shape_rewrite && !group_shape_rewrite) {
// if user only rewrite input shapes, upadte the group-shape
Contributor:

Suggested change
// if user only rewrite input shapes, upadte the group-shape
// if user only rewrites input shapes, update the group-shape

});
bool group_shape_rewrite = op_attrs_.count(aop.id_)
&& parse_attrs(op_attrs_.at(aop.id_)).count("group_shape");

Contributor:

I see some checks are duplicated; I guess reorganizing the code a little bit can help eliminate those duplicates (pseudo-code):

if (!input_shape_rewrite && !group_shape_rewrite) continue;

if (input_shape_rewrite) {
    checks_for_src_and_scale
}
if (group_shape_rewrite) {
    checks_for_src_and_group
}

And then the action bodies will just do what they need to do.
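
Fleshed out a little as a self-contained sketch (all types and helpers here are stand-ins echoing the pseudo-code above, not benchdnn identifiers):

#include <cstddef>
#include <vector>

struct op_t { std::size_t id_; }; // stand-in for the graph op type

// Stand-ins for the real predicates and checks named in the pseudo-code.
bool has_input_shape_rewrite(const op_t &) { return true; }
bool has_group_shape_rewrite(const op_t &) { return false; }
void checks_for_src_and_scale(const op_t &) {}
void checks_for_src_and_group(const op_t &) {}

void rewrite_linked_shapes(std::vector<op_t> &ops) {
    for (auto &aop : ops) {
        const bool input_shape_rewrite = has_input_shape_rewrite(aop);
        const bool group_shape_rewrite = has_group_shape_rewrite(aop);
        // One shared early-out replaces the duplicated per-branch checks.
        if (!input_shape_rewrite && !group_shape_rewrite) continue;
        if (input_shape_rewrite) checks_for_src_and_scale(aop);
        if (group_shape_rewrite) checks_for_src_and_group(aop);
        // The action bodies that follow then only perform the rewrite.
    }
}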

"Error: the ndims of scale tensor should align "
"with the ndims of zero-point tensor for op "
"with "
"id=\'%zu\'\n",
Contributor:

Clang-format (the version we are using) is bad at formatting strings. The best way is to update the string to a single line and then let clang-format break it.
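
Concretely, the workflow would look like this (a sketch: the diff only shows the literal, so the printf-style call site and op_id are assumed):

#include <cstdio>

// Step 1: collapse the message into a single over-long line, as below.
// Step 2: run clang-format and accept whatever line breaks it picks.
void report_ndims_mismatch(size_t op_id) {
    fprintf(stderr, "Error: the ndims of scale tensor should align with the ndims of zero-point tensor for op with id=\'%zu\'\n", op_id);
}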
