benchdnn: graph: enhance input displace and shape rewrite for linked attribute and shapes #2354
Conversation
bool ret = true;
I wonder if it makes sense to set this to `false` in the first place, to avoid multiple re-initializations of this value across the logic.
@@ -32,7 +32,7 @@
# Re-written graphs
--reset --dt=f32,bf16,f16 --in-shapes=4:4x16x32x256+5:4x16x256x33+0:4x16x33x256+1:4x1x1x33+3:4x1x32x33 --case=complex_fusion/mha/MHA-GPT-inf-fp32-bs1.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --in-shapes=3:4x32x32x128+4:4x32x128x33+0:4x32x33x128+1:4x1x32x33 --case=complex_fusion/mha/MHA-LLaMa-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=3:20x16x384x64+4:20x16x64x384+0:20x16x384x64+1:20x1x1x384 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --mb=10,20 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=3:10x16x384x64+4:10x1x64x384+0:10x1x384x64+1:10x1x1x384 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
Should this be removed in favor of `--mb=10` right above?
if (attr.find("qtype") == attr.end()
        || attr["qtype"].str_value_ != "per_group")
    continue;
if (attr.find("group_shape") == attr.end()) {
It would be nice to dump the group_size attribute in the graph dump under `-v7` so that it's visible to the user.
        zp_lt.shape_, dgraph.lt_2_mtag_[zp_lt.id_]);
    }
} else if (input_shape_rewrite && !group_shape_rewrite) {
    // if user only rewrite input shapes, upadte the group-shape
Suggested change:
-    // if user only rewrite input shapes, upadte the group-shape
+    // if user only rewrites input shapes, update the group-shape
});
bool group_shape_rewrite = op_attrs_.count(aop.id_)
        && parse_attrs(op_attrs_.at(aop.id_)).count("group_shape");
I see some checks are duplicated; I guess reorganizing the code a little bit can help eliminate those duplicates (pseudo-code):

if (!input_shape_rewrite && !group_shape_rewrite) continue;
if (input_shape_rewrite) {
    checks_for_src_and_scale
}
if (group_shape_rewrite) {
    checks_for_src_and_group
}

And then the bodies of the actions will just do what they need to do.
"Error: the ndims of scale tensor should align "
"with the ndims of zero-point tensor for op "
"with "
"id=\'%zu\'\n",
Clang-format (the version we are using) is bad at formatting strings. The best way is to join the string into a single line and then let clang-format break it.
Description
The PR enhances input displacement and shape rewrite functionality in benchdnn graph in the following aspects:
- Input displacement for `MatMul` and the scale and zp of `DynamicDequantize`, to support SDPA pattern rewriting.
- Linking the `group_shape` attribute and the scale/zp input shapes of `DynamicDequantize`. If the user provides shapes for one of the attributes or input shapes, benchdnn graph will update the other accordingly after performing some checks.

For example: