Skip to content

Commit

Permalink
[Codegen][Tuner]: remove qk/pv attrs inside decompositionConfig for attention op
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Feb 24, 2025
1 parent f85a780 commit bb8be9f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,19 @@ struct StripAttentionOpCompilationInfo final
eraseLoweringConfig(attentionOp);
}

if (attentionOp.getDecompositionConfigAttr()) {
attentionOp.removeDecompositionConfigAttr();
DictionaryAttr decompositionConfig =
attentionOp.getDecompositionConfigAttr();
if (decompositionConfig) {
decompositionConfig = DictionaryAttr::get(
decompositionConfig.getContext(),
llvm::to_vector(llvm::make_filter_range(
decompositionConfig, [&](NamedAttribute attr) {
return attr.getName() !=
IREE::LinalgExt::AttentionOp::getQKAttrStr() &&
attr.getName() !=
IREE::LinalgExt::AttentionOp::getPVAttrStr();
})));
attentionOp.setDecompositionConfigAttr(decompositionConfig);
}
return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func.func @matmul_128x1024x256_1(%lhs : tensor<128x256xf32>, %rhs: tensor<256x10
#config = #iree_codegen.lowering_config<tile_sizes = [[128, 256], [16, 16]]>
func.func @attention(%arg0: tensor<2x10x6x4xf16>, %arg1 : tensor<2x10x4x4xf16>, %arg2 : tensor<2x10x4x4xf16>, %arg3 : f16) -> tensor<2x10x6x4xf16> attributes {translation_info = #iree_codegen.translation_info<pipeline = None subgroup_size = 32>} {
%init = tensor.empty() : tensor<2x10x6x4xf16>
%result = iree_linalg_ext.attention {decomposition_config = {x}, indexing_maps = [#map, #map1, #map2, #map3, #map4], lowering_config = #config} ins(%arg0, %arg1, %arg2, %arg3 : tensor<2x10x6x4xf16>, tensor<2x10x4x4xf16>, tensor<2x10x4x4xf16>, f16) outs(%init : tensor<2x10x6x4xf16>) {
%result = iree_linalg_ext.attention {decomposition_config = {pv_attrs = {x}, qk_attrs = {y}, z}, indexing_maps = [#map, #map1, #map2, #map3, #map4], lowering_config = #config} ins(%arg0, %arg1, %arg2, %arg3 : tensor<2x10x6x4xf16>, tensor<2x10x4x4xf16>, tensor<2x10x4x4xf16>, f16) outs(%init : tensor<2x10x6x4xf16>) {
^bb0(%arg: f32):
iree_linalg_ext.yield %arg : f32
} -> tensor<2x10x6x4xf16>
Expand All @@ -82,13 +82,13 @@ func.func @attention(%arg0: tensor<2x10x6x4xf16>, %arg1 : tensor<2x10x4x4xf16>,

// CHECK-LABEL: func.func @attention
// CHECK: iree_linalg_ext.attention
// CHECK-SAME: decomposition_config = {z}
// CHECK-NOT: iree_codegen.translation_info
// CHECK-NOT: iree_codegen.lowering_config
// CHECK-NOT: translation_info =
// CHECK-NOT: lowering_config =
// CHECK-NOT: decomposition_config =


// -----

#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
Expand All @@ -103,7 +103,7 @@ func.func @attention(%arg0: tensor<2x10x6x4xf16>, %arg1 : tensor<2x10x4x4xf16>,
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
func.func @attention_1(%arg0: tensor<2x10x6x4xf16>, %arg1 : tensor<2x10x4x4xf16>, %arg2 : tensor<2x10x4x4xf16>, %arg3 : f16) -> tensor<2x10x6x4xf16> attributes {translation_info = #iree_codegen.translation_info<pipeline = None subgroup_size = 32>} {
%init = tensor.empty() : tensor<2x10x6x4xf16>
%result = iree_linalg_ext.attention {decomposition_config = {x}, indexing_maps = [#map, #map1, #map2, #map3, #map4], compilation_info = #compilation} ins(%arg0, %arg1, %arg2, %arg3 : tensor<2x10x6x4xf16>, tensor<2x10x4x4xf16>, tensor<2x10x4x4xf16>, f16) outs(%init : tensor<2x10x6x4xf16>) {
%result = iree_linalg_ext.attention {decomposition_config = {pv_attrs = {x}, qk_attrs = {y}}, indexing_maps = [#map, #map1, #map2, #map3, #map4], compilation_info = #compilation} ins(%arg0, %arg1, %arg2, %arg3 : tensor<2x10x6x4xf16>, tensor<2x10x4x4xf16>, tensor<2x10x4x4xf16>, f16) outs(%init : tensor<2x10x6x4xf16>) {
^bb0(%arg: f32):
iree_linalg_ext.yield %arg : f32
} -> tensor<2x10x6x4xf16>
Expand All @@ -112,6 +112,7 @@ func.func @attention_1(%arg0: tensor<2x10x6x4xf16>, %arg1 : tensor<2x10x4x4xf16>

// CHECK-LABEL: func.func @attention_1
// CHECK: iree_linalg_ext.attention
// CHECK-SAME: decomposition_config = {}
// CHECK-NOT: iree_codegen.compilation_info
// CHECK-NOT: iree_codegen.lowering_config
// CHECK-NOT: iree_codegen.translation_info
Expand Down

0 comments on commit bb8be9f

Please sign in to comment.