-
Notifications
You must be signed in to change notification settings - Fork 241
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
### Changes * The Torch SDPA pattern is updated * The concat node stores its input nodes in the format `args=([inp_1, ..., inp_n], dim)`, so it must be treated differently from other nodes. Retrieving concat inputs by input port id is now supported in each TorchFX transformation ### Reason for changes * To support quantization of ultralytics/yolo11n in the TorchFX backend ### Related tickets #2766 157032 ### Tests * `tests/torch/fx/test_model_transformer.py` and `tests/torch/fx/test_compress_weights.py` are updated to check all cases with the concat node. All `.dot` / `.json` files were checked manually. * `tests/torch/fx/test_models.py` is updated with the `YOLO11N_SDPABlock` synthetic model to check the correctness of SDPA pattern matching
- Loading branch information
1 parent
0c22b38
commit 7ea17f2
Showing
34 changed files
with
1,356 additions
and
735 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
63 changes: 63 additions & 0 deletions
63
tests/torch/data/reference_graphs/fx/post_quantization_compressed/yolo11n_sdpa_block.dot
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
strict digraph { | ||
"0 x" [id=0, type=input]; | ||
"1 x_0_0_nncf_smooth_quant_0" [id=1, type=call_module]; | ||
"2 quantize_per_tensor_default" [id=2, type=quantize_per_tensor]; | ||
"3 dequantize_per_tensor_default" [id=3, type=dequantize_per_tensor]; | ||
"4 scale_updated_constant0" [id=4, type=get_attr]; | ||
"5 compressed_weight_updated_constant0" [id=5, type=get_attr]; | ||
"6 mul_tensor" [id=6, type=mul]; | ||
"7 zero_point_updated_constant0" [id=7, type=get_attr]; | ||
"8 sub_tensor" [id=8, type=sub]; | ||
"9 linear" [id=9, type=linear]; | ||
"10 quantize_per_tensor_default_1" [id=10, type=quantize_per_tensor]; | ||
"11 dequantize_per_tensor_default_1" [id=11, type=dequantize_per_tensor]; | ||
"12 slice_1" [id=12, type=slice]; | ||
"13 slice_2" [id=13, type=slice]; | ||
"14 slice_3" [id=14, type=slice]; | ||
"15 quantize_per_tensor_default_2" [id=15, type=quantize_per_tensor]; | ||
"16 dequantize_per_tensor_default_2" [id=16, type=dequantize_per_tensor]; | ||
"17 slice_4" [id=17, type=slice]; | ||
"18 slice_5" [id=18, type=slice]; | ||
"19 slice_6" [id=19, type=slice]; | ||
"20 slice_7" [id=20, type=slice]; | ||
"21 slice_8" [id=21, type=slice]; | ||
"22 slice_9" [id=22, type=slice]; | ||
"23 transpose" [id=23, type=transpose]; | ||
"24 matmul" [id=24, type=matmul]; | ||
"25 div_" [id=25, type=div_]; | ||
"26 softmax" [id=26, type=softmax]; | ||
"27 transpose_1" [id=27, type=transpose]; | ||
"28 matmul_1" [id=28, type=matmul]; | ||
"29 output" [id=29, type=output]; | ||
"0 x" -> "1 x_0_0_nncf_smooth_quant_0" [label="(1, 2, 4)", style=solid]; | ||
"1 x_0_0_nncf_smooth_quant_0" -> "2 quantize_per_tensor_default" [label="(1, 2, 4)", style=solid]; | ||
"2 quantize_per_tensor_default" -> "3 dequantize_per_tensor_default" [label="(1, 2, 4)", style=solid]; | ||
"3 dequantize_per_tensor_default" -> "9 linear" [label="(1, 2, 4)", style=solid]; | ||
"4 scale_updated_constant0" -> "6 mul_tensor" [label="(12, 1)", style=solid]; | ||
"5 compressed_weight_updated_constant0" -> "6 mul_tensor" [label="(12, 4)", style=solid]; | ||
"6 mul_tensor" -> "8 sub_tensor" [label="(12, 4)", style=solid]; | ||
"7 zero_point_updated_constant0" -> "8 sub_tensor" [label="(12, 1)", style=solid]; | ||
"8 sub_tensor" -> "9 linear" [label="(12, 4)", style=solid]; | ||
"9 linear" -> "10 quantize_per_tensor_default_1" [label="(1, 2, 12)", style=solid]; | ||
"9 linear" -> "15 quantize_per_tensor_default_2" [label="(1, 2, 12)", style=solid]; | ||
"9 linear" -> "20 slice_7" [label="(1, 2, 12)", style=solid]; | ||
"10 quantize_per_tensor_default_1" -> "11 dequantize_per_tensor_default_1" [label="(1, 2, 12)", style=solid]; | ||
"11 dequantize_per_tensor_default_1" -> "12 slice_1" [label="(1, 2, 12)", style=solid]; | ||
"12 slice_1" -> "13 slice_2" [label="(1, 2, 12)", style=solid]; | ||
"13 slice_2" -> "14 slice_3" [label="(1, 2, 12)", style=solid]; | ||
"14 slice_3" -> "24 matmul" [label="(1, 2, 4)", style=solid]; | ||
"15 quantize_per_tensor_default_2" -> "16 dequantize_per_tensor_default_2" [label="(1, 2, 12)", style=solid]; | ||
"16 dequantize_per_tensor_default_2" -> "17 slice_4" [label="(1, 2, 12)", style=solid]; | ||
"17 slice_4" -> "18 slice_5" [label="(1, 2, 12)", style=solid]; | ||
"18 slice_5" -> "19 slice_6" [label="(1, 2, 12)", style=solid]; | ||
"19 slice_6" -> "23 transpose" [label="(1, 2, 4)", style=solid]; | ||
"20 slice_7" -> "21 slice_8" [label="(1, 2, 12)", style=solid]; | ||
"21 slice_8" -> "22 slice_9" [label="(1, 2, 12)", style=solid]; | ||
"22 slice_9" -> "28 matmul_1" [label="(1, 2, 4)", style=solid]; | ||
"23 transpose" -> "24 matmul" [label="(1, 4, 2)", style=solid]; | ||
"24 matmul" -> "25 div_" [label="(1, 2, 2)", style=solid]; | ||
"25 div_" -> "26 softmax" [label="(1, 2, 2)", style=solid]; | ||
"26 softmax" -> "27 transpose_1" [label="(1, 2, 2)", style=solid]; | ||
"27 transpose_1" -> "28 matmul_1" [label="(1, 2, 2)", style=solid]; | ||
"28 matmul_1" -> "29 output" [label="(1, 2, 4)", style=solid]; | ||
} |
65 changes: 65 additions & 0 deletions
65
tests/torch/data/reference_graphs/fx/quantized/yolo11n_sdpa_block.dot
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
strict digraph { | ||
"0 x" [id=0, type=input]; | ||
"1 x_0_0_nncf_smooth_quant_0" [id=1, type=call_module]; | ||
"2 quantize_per_tensor_default" [id=2, type=quantize_per_tensor]; | ||
"3 dequantize_per_tensor_default" [id=3, type=dequantize_per_tensor]; | ||
"4 linear_scale_0" [id=4, type=get_attr]; | ||
"5 linear_zero_point_0" [id=5, type=get_attr]; | ||
"6 compressed_weight_updated_constant0" [id=6, type=get_attr]; | ||
"7 quantize_per_channel_default" [id=7, type=quantize_per_channel]; | ||
"8 dequantize_per_channel_default" [id=8, type=dequantize_per_channel]; | ||
"9 linear" [id=9, type=linear]; | ||
"10 quantize_per_tensor_default_1" [id=10, type=quantize_per_tensor]; | ||
"11 dequantize_per_tensor_default_1" [id=11, type=dequantize_per_tensor]; | ||
"12 slice_1" [id=12, type=slice]; | ||
"13 slice_2" [id=13, type=slice]; | ||
"14 slice_3" [id=14, type=slice]; | ||
"15 quantize_per_tensor_default_2" [id=15, type=quantize_per_tensor]; | ||
"16 dequantize_per_tensor_default_2" [id=16, type=dequantize_per_tensor]; | ||
"17 slice_4" [id=17, type=slice]; | ||
"18 slice_5" [id=18, type=slice]; | ||
"19 slice_6" [id=19, type=slice]; | ||
"20 slice_7" [id=20, type=slice]; | ||
"21 slice_8" [id=21, type=slice]; | ||
"22 slice_9" [id=22, type=slice]; | ||
"23 transpose" [id=23, type=transpose]; | ||
"24 matmul" [id=24, type=matmul]; | ||
"25 div_" [id=25, type=div_]; | ||
"26 softmax" [id=26, type=softmax]; | ||
"27 transpose_1" [id=27, type=transpose]; | ||
"28 matmul_1" [id=28, type=matmul]; | ||
"29 output" [id=29, type=output]; | ||
"0 x" -> "1 x_0_0_nncf_smooth_quant_0" [label="(1, 2, 4)", style=solid]; | ||
"1 x_0_0_nncf_smooth_quant_0" -> "2 quantize_per_tensor_default" [label="(1, 2, 4)", style=solid]; | ||
"2 quantize_per_tensor_default" -> "3 dequantize_per_tensor_default" [label="(1, 2, 4)", style=solid]; | ||
"3 dequantize_per_tensor_default" -> "9 linear" [label="(1, 2, 4)", style=solid]; | ||
"4 linear_scale_0" -> "7 quantize_per_channel_default" [label="(12,)", style=solid]; | ||
"4 linear_scale_0" -> "8 dequantize_per_channel_default" [label="(12,)", style=solid]; | ||
"5 linear_zero_point_0" -> "7 quantize_per_channel_default" [label="(12,)", style=solid]; | ||
"5 linear_zero_point_0" -> "8 dequantize_per_channel_default" [label="(12,)", style=solid]; | ||
"6 compressed_weight_updated_constant0" -> "7 quantize_per_channel_default" [label="(12, 4)", style=solid]; | ||
"7 quantize_per_channel_default" -> "8 dequantize_per_channel_default" [label="(12, 4)", style=solid]; | ||
"8 dequantize_per_channel_default" -> "9 linear" [label="(12, 4)", style=solid]; | ||
"9 linear" -> "10 quantize_per_tensor_default_1" [label="(1, 2, 12)", style=solid]; | ||
"9 linear" -> "15 quantize_per_tensor_default_2" [label="(1, 2, 12)", style=solid]; | ||
"9 linear" -> "20 slice_7" [label="(1, 2, 12)", style=solid]; | ||
"10 quantize_per_tensor_default_1" -> "11 dequantize_per_tensor_default_1" [label="(1, 2, 12)", style=solid]; | ||
"11 dequantize_per_tensor_default_1" -> "12 slice_1" [label="(1, 2, 12)", style=solid]; | ||
"12 slice_1" -> "13 slice_2" [label="(1, 2, 12)", style=solid]; | ||
"13 slice_2" -> "14 slice_3" [label="(1, 2, 12)", style=solid]; | ||
"14 slice_3" -> "24 matmul" [label="(1, 2, 4)", style=solid]; | ||
"15 quantize_per_tensor_default_2" -> "16 dequantize_per_tensor_default_2" [label="(1, 2, 12)", style=solid]; | ||
"16 dequantize_per_tensor_default_2" -> "17 slice_4" [label="(1, 2, 12)", style=solid]; | ||
"17 slice_4" -> "18 slice_5" [label="(1, 2, 12)", style=solid]; | ||
"18 slice_5" -> "19 slice_6" [label="(1, 2, 12)", style=solid]; | ||
"19 slice_6" -> "23 transpose" [label="(1, 2, 4)", style=solid]; | ||
"20 slice_7" -> "21 slice_8" [label="(1, 2, 12)", style=solid]; | ||
"21 slice_8" -> "22 slice_9" [label="(1, 2, 12)", style=solid]; | ||
"22 slice_9" -> "28 matmul_1" [label="(1, 2, 4)", style=solid]; | ||
"23 transpose" -> "24 matmul" [label="(1, 4, 2)", style=solid]; | ||
"24 matmul" -> "25 div_" [label="(1, 2, 2)", style=solid]; | ||
"25 div_" -> "26 softmax" [label="(1, 2, 2)", style=solid]; | ||
"26 softmax" -> "27 transpose_1" [label="(1, 2, 2)", style=solid]; | ||
"27 transpose_1" -> "28 matmul_1" [label="(1, 2, 2)", style=solid]; | ||
"28 matmul_1" -> "29 output" [label="(1, 2, 4)", style=solid]; | ||
} |
21 changes: 21 additions & 0 deletions
21
tests/torch/data/reference_graphs/fx/reference_metatypes/yolo11n_sdpa_block.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
{ | ||
"kqv_weight": "PTConstNoopMetatype", | ||
"x": "PTInputNoopMetatype", | ||
"linear": "PTLinearMetatype", | ||
"slice_1": "PTGatherMetatype", | ||
"slice_2": "PTGatherMetatype", | ||
"slice_3": "PTGatherMetatype", | ||
"slice_4": "PTGatherMetatype", | ||
"slice_5": "PTGatherMetatype", | ||
"slice_6": "PTGatherMetatype", | ||
"slice_7": "PTGatherMetatype", | ||
"slice_8": "PTGatherMetatype", | ||
"slice_9": "PTGatherMetatype", | ||
"transpose": "PTTransposeMetatype", | ||
"matmul": "PTMatMulMetatype", | ||
"div_": "PTDivMetatype", | ||
"softmax": "PTSoftmaxMetatype", | ||
"transpose_1": "PTTransposeMetatype", | ||
"matmul_1": "PTMatMulMetatype", | ||
"output_1": "PTOutputNoopMetatype" | ||
} |
38 changes: 38 additions & 0 deletions
38
tests/torch/data/reference_graphs/fx/transformed/cat_constant_update.dot
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
strict digraph { | ||
"0 conv_a_weight" [id=0, type=get_attr]; | ||
"1 conv_a_bias" [id=1, type=get_attr]; | ||
"2 conv_b_weight" [id=2, type=get_attr]; | ||
"3 conv_b_bias" [id=3, type=get_attr]; | ||
"4 conv_c_weight" [id=4, type=get_attr]; | ||
"5 conv_c_bias" [id=5, type=get_attr]; | ||
"6 bias" [id=6, type=get_attr]; | ||
"7 x" [id=7, type=input]; | ||
"8 conv2d" [id=8, type=conv2d]; | ||
"9 conv2d_1" [id=9, type=conv2d]; | ||
"10 add_" [id=10, type=add_]; | ||
"11 add__1" [id=11, type=add_]; | ||
"12 const_updated_constant0" [id=12, type=get_attr]; | ||
"13 cat" [id=13, type=cat]; | ||
"14 conv2d_2" [id=14, type=conv2d]; | ||
"15 add" [id=15, type=add]; | ||
"16 output_1" [id=16, type=output]; | ||
"0 conv_a_weight" -> "8 conv2d" [label="(3, 3, 1, 1)", style=solid]; | ||
"1 conv_a_bias" -> "8 conv2d" [label="(3,)", style=solid]; | ||
"2 conv_b_weight" -> "9 conv2d_1" [label="(3, 3, 1, 1)", style=solid]; | ||
"3 conv_b_bias" -> "9 conv2d_1" [label="(3,)", style=solid]; | ||
"4 conv_c_weight" -> "14 conv2d_2" [label="(3, 9, 1, 1)", style=solid]; | ||
"5 conv_c_bias" -> "14 conv2d_2" [label="(3,)", style=solid]; | ||
"6 bias" -> "10 add_" [label="(1,)", style=solid]; | ||
"6 bias" -> "11 add__1" [label="(1,)", style=solid]; | ||
"6 bias" -> "15 add" [label="(1,)", style=solid]; | ||
"7 x" -> "8 conv2d" [label="(1, 3, 3, 3)", style=solid]; | ||
"8 conv2d" -> "9 conv2d_1" [label="(1, 3, 3, 3)", style=solid]; | ||
"8 conv2d" -> "10 add_" [label="(1, 3, 3, 3)", style=solid]; | ||
"9 conv2d_1" -> "11 add__1" [label="(1, 3, 3, 3)", style=solid]; | ||
"10 add_" -> "13 cat" [label="(1, 3, 3, 3)", style=solid]; | ||
"11 add__1" -> "13 cat" [label="(1, 3, 3, 3)", style=solid]; | ||
"12 const_updated_constant0" -> "13 cat" [label="(1,)", style=solid]; | ||
"13 cat" -> "14 conv2d_2" [label="(1, 9, 3, 3)", style=solid]; | ||
"14 conv2d_2" -> "15 add" [label="(1, 3, 3, 3)", style=solid]; | ||
"15 add" -> "16 output_1" [label="(1, 3, 3, 3)", style=solid]; | ||
} |
70 changes: 36 additions & 34 deletions
70
tests/torch/data/reference_graphs/fx/transformed/constant_update.dot
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,38 @@ | ||
strict digraph { | ||
"0 conv_a_weight" [id=0, type=get_attr]; | ||
"1 conv_a_bias" [id=1, type=get_attr]; | ||
"2 conv_b_weight" [id=2, type=get_attr]; | ||
"3 conv_b_bias" [id=3, type=get_attr]; | ||
"4 conv_c_weight" [id=4, type=get_attr]; | ||
"5 conv_c_bias" [id=5, type=get_attr]; | ||
"6 x" [id=6, type=input]; | ||
"7 conv2d" [id=7, type=conv2d]; | ||
"8 conv2d_1" [id=8, type=conv2d]; | ||
"9 add__updated_constant0" [id=9, type=get_attr]; | ||
"10 add_" [id=10, type=add_]; | ||
"11 add__1" [id=11, type=add_]; | ||
"12 add" [id=12, type=add]; | ||
"13 conv2d_2" [id=13, type=conv2d]; | ||
"14 add_1" [id=14, type=add]; | ||
"15 output_1" [id=15, type=output]; | ||
"0 conv_a_weight" -> "7 conv2d" [label="(3, 3, 1, 1)", style=solid]; | ||
"1 conv_a_bias" -> "7 conv2d" [label="(3,)", style=solid]; | ||
"2 conv_b_weight" -> "8 conv2d_1" [label="(3, 3, 1, 1)", style=solid]; | ||
"3 conv_b_bias" -> "8 conv2d_1" [label="(3,)", style=solid]; | ||
"4 conv_c_weight" -> "13 conv2d_2" [label="(3, 3, 1, 1)", style=solid]; | ||
"5 conv_c_bias" -> "13 conv2d_2" [label="(3,)", style=solid]; | ||
"6 x" -> "7 conv2d" [label="(1, 3, 3, 3)", style=solid]; | ||
"7 conv2d" -> "8 conv2d_1" [label="(1, 3, 3, 3)", style=solid]; | ||
"7 conv2d" -> "10 add_" [label="(1, 3, 3, 3)", style=solid]; | ||
"8 conv2d_1" -> "11 add__1" [label="(1, 3, 3, 3)", style=solid]; | ||
"9 add__updated_constant0" -> "10 add_" [label="(1,)", style=solid]; | ||
"9 add__updated_constant0" -> "11 add__1" [label="(1,)", style=solid]; | ||
"9 add__updated_constant0" -> "14 add_1" [label="(1,)", style=solid]; | ||
"10 add_" -> "12 add" [label="(1, 3, 3, 3)", style=solid]; | ||
"11 add__1" -> "12 add" [label="(1, 3, 3, 3)", style=solid]; | ||
"12 add" -> "13 conv2d_2" [label="(1, 3, 3, 3)", style=solid]; | ||
"13 conv2d_2" -> "14 add_1" [label="(1, 3, 3, 3)", style=solid]; | ||
"14 add_1" -> "15 output_1" [label="(1, 3, 3, 3)", style=solid]; | ||
"0 const" [id=0, type=get_attr]; | ||
"1 conv_a_weight" [id=1, type=get_attr]; | ||
"2 conv_a_bias" [id=2, type=get_attr]; | ||
"3 conv_b_weight" [id=3, type=get_attr]; | ||
"4 conv_b_bias" [id=4, type=get_attr]; | ||
"5 conv_c_weight" [id=5, type=get_attr]; | ||
"6 conv_c_bias" [id=6, type=get_attr]; | ||
"7 x" [id=7, type=input]; | ||
"8 conv2d" [id=8, type=conv2d]; | ||
"9 conv2d_1" [id=9, type=conv2d]; | ||
"10 bias_updated_constant0" [id=10, type=get_attr]; | ||
"11 add_" [id=11, type=add_]; | ||
"12 add__1" [id=12, type=add_]; | ||
"13 cat" [id=13, type=cat]; | ||
"14 conv2d_2" [id=14, type=conv2d]; | ||
"15 add" [id=15, type=add]; | ||
"16 output_1" [id=16, type=output]; | ||
"0 const" -> "13 cat" [label="(1, 3, 3, 3)", style=solid]; | ||
"1 conv_a_weight" -> "8 conv2d" [label="(3, 3, 1, 1)", style=solid]; | ||
"2 conv_a_bias" -> "8 conv2d" [label="(3,)", style=solid]; | ||
"3 conv_b_weight" -> "9 conv2d_1" [label="(3, 3, 1, 1)", style=solid]; | ||
"4 conv_b_bias" -> "9 conv2d_1" [label="(3,)", style=solid]; | ||
"5 conv_c_weight" -> "14 conv2d_2" [label="(3, 9, 1, 1)", style=solid]; | ||
"6 conv_c_bias" -> "14 conv2d_2" [label="(3,)", style=solid]; | ||
"7 x" -> "8 conv2d" [label="(1, 3, 3, 3)", style=solid]; | ||
"8 conv2d" -> "9 conv2d_1" [label="(1, 3, 3, 3)", style=solid]; | ||
"8 conv2d" -> "11 add_" [label="(1, 3, 3, 3)", style=solid]; | ||
"9 conv2d_1" -> "12 add__1" [label="(1, 3, 3, 3)", style=solid]; | ||
"10 bias_updated_constant0" -> "11 add_" [label="(1,)", style=solid]; | ||
"10 bias_updated_constant0" -> "12 add__1" [label="(1,)", style=solid]; | ||
"10 bias_updated_constant0" -> "15 add" [label="(1,)", style=solid]; | ||
"11 add_" -> "13 cat" [label="(1, 3, 3, 3)", style=solid]; | ||
"12 add__1" -> "13 cat" [label="(1, 3, 3, 3)", style=solid]; | ||
"13 cat" -> "14 conv2d_2" [label="(1, 9, 3, 3)", style=solid]; | ||
"14 conv2d_2" -> "15 add" [label="(1, 3, 3, 3)", style=solid]; | ||
"15 add" -> "16 output_1" [label="(1, 3, 3, 3)", style=solid]; | ||
} |
Oops, something went wrong.