Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Training #296

Draft
wants to merge 32 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
95b068e
[compiler] add fuse-nested-forall pass
XG-zheng May 25, 2024
406e5ad
[compiler] add vector-warp-distribute
XG-zheng May 25, 2024
f63fbdd
[compiler] canonicalize for tensor/vector
XG-zheng May 25, 2024
20f16df
[compiler] add block dims hint for mapping-forall
XG-zheng May 25, 2024
9ea79d5
[compiler] optimize reduction codegen
XG-zheng May 25, 2024
d2f1aee
[compiler] move VectorToScf after warp distribute
XG-zheng May 25, 2024
1a6a26a
[compiler] fix bug in RemoveCopy
XG-zheng May 27, 2024
57712a5
[compiler] fix bug in reduction
XG-zheng May 28, 2024
bb8a20b
[compiler] fuse insert slice with elementwise op
XG-zheng May 29, 2024
88deadb
[compiler] convert insert slice to linalg
XG-zheng May 29, 2024
abb3627
[compiler] fuse mhlo sliceOp
XG-zheng May 29, 2024
d4fb8ba
[compiler] element-wise fusion add transpose candidate
YellowHCH May 29, 2024
4ea0ba0
[compiler] chore
YellowHCH May 30, 2024
bb677d9
[dynamo] add debug backend for comparsion
YellowHCH May 30, 2024
cd8aae2
[dynamo] switch to torch2.12
YellowHCH Jun 3, 2024
bd30523
change Unknown to name with hint
qingyunqu Jun 4, 2024
47e4c59
[dynamo] fix fx cache hash key with GraphModule
YellowHCH Jun 5, 2024
66de9ca
[runtime] add BatchTranspose cuda kernel
XG-zheng Jun 5, 2024
f4ae62c
[compiler] fix transpose fusion
XG-zheng Jun 5, 2024
4783844
fix transpose
qingyunqu Jun 5, 2024
fec3cbb
[runtime] add transpose_with_small_inner_dim kernel
XG-zheng Jun 6, 2024
01de85b
[compiler] adjust blockSize/GridSize
XG-zheng Jun 6, 2024
18c90b4
add bacthed bind args
XG-zheng Jun 7, 2024
e503867
[frontend] reduce brt overhead
XG-zheng Jun 7, 2024
f874171
refine contigous inputs && host side overhead
YellowHCH Jun 11, 2024
8dfff7a
[compiler] reorder loop of concat
XG-zheng Jun 12, 2024
5923c01
[dynamo] reduce host overhead
YellowHCH Jun 12, 2024
4675de1
[compiler] convert static concat to index_switch
XG-zheng Jun 12, 2024
0f4b143
[compiler] remove redundant copyOp after ConvertFuncAndCallToByre
XG-zheng Jun 12, 2024
bbb7c00
[dynamo] add outputs info in compile time
YellowHCH Jun 15, 2024
2b47797
[dynamo] update readme
YellowHCH Jun 15, 2024
36ebc1f
[dynamo] refine analysis aliased output info
YellowHCH Jun 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/include/byteir/Dialect/SCF/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#ifndef BYTEIR_DIALECT_SCF_PASSES_H
#define BYTEIR_DIALECT_SCF_PASSES_H

#include "byteir/Dialect/SCF/Transforms/FuseNestedForall.h"
#include "byteir/Dialect/SCF/Transforms/InsertTrivialSCFLoop.h"

namespace mlir {
Expand Down
17 changes: 17 additions & 0 deletions compiler/include/byteir/Dialect/SCF/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,21 @@ def InsertTrivialSCFLoop : Pass<"insert-trivial-scf-loop", "mlir::func::FuncOp">
];
}

//===----------------------------------------------------------------------===//
// FuseNestedForall
//===----------------------------------------------------------------------===//

// Declarative record for the `fuse-nested-forall` pass, anchored on
// `mlir::func::FuncOp`. The pass instance is built by the C++ factory named
// in `constructor`.
def FuseNestedForall : Pass<"fuse-nested-forall", "mlir::func::FuncOp"> {
  let summary = "Fuse nested forall if possible";
  let constructor = "mlir::createFuseNestedForallPass()";
  let dependentDialects = ["scf::SCFDialect"];
  let options = [
    // Empty by default; see the option description for its meaning.
    Option<"anchorTag", "anchor-tag", "std::string", /*default=*/"",
           "Optional unitAttr anchored tag to apply this pass">
  ];
}

#endif // BYTEIR_DIALECT_SCF_PASSES
34 changes: 34 additions & 0 deletions compiler/include/byteir/Dialect/SCF/Transforms/FuseNestedForall.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//===- FuseNestedForall.h ---------------------------------------*- C++ -*-===//
//
// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#ifndef BYTEIR_DIALECT_SCF_TRANSFORMS_FUSENESTEDFORALL_H
#define BYTEIR_DIALECT_SCF_TRANSFORMS_FUSENESTEDFORALL_H

#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {
namespace func {
class FuncOp;
} // namespace func

/// Factory for the `fuse-nested-forall` pass ("Fuse nested forall if
/// possible"), which operates on `func::FuncOp`.
///
/// \param anchorTag Optional unitAttr anchored tag to apply this pass;
///        defaults to the empty string.
/// \return An owning pointer to the newly created pass.
std::unique_ptr<OperationPass<func::FuncOp>>
createFuseNestedForallPass(llvm::StringRef anchorTag = "");

} // namespace mlir

#endif // BYTEIR_DIALECT_SCF_TRANSFORMS_FUSENESTEDFORALL_H
3 changes: 3 additions & 0 deletions compiler/include/byteir/Dialect/Transform/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def DetensorizeTransformInsertion : Pass<"insert-detensorize-transform", "Module
let summary = "Insert detensorize transform IR to functions.";
let constructor = "mlir::createDetensorizeTransformInsertionPass()";
let options = [
Option<"usingVectorizeOp", "using-vectorize-op", "bool",
/*default=*/"false",
"using vectorizeOp to detensorize linalg op.">,
Option<"funcAnchorAttr", "func-anchor", "std::string",
/*default=*/"",
"An optional Unit attribute anchoring on target functions.">,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ createGenericTransformInsertionPass(const TransformInsertionConfig &config);

std::unique_ptr<OperationPass<ModuleOp>>
createDetensorizeTransformInsertionPass(
const std::string &funcAnchor = "",
const bool usingVectorizeOp = false, const std::string &funcAnchor = "",
const std::string &matchPrefix = "__byteir_detensorize");

std::unique_ptr<OperationPass<ModuleOp>> createFuseExtTransformInsertionPass(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//===- MoveForallRegionIntoWarpOp.h ---------------------------*--- C++ -*-===//
//
// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

// NOTE: this header is included as
// "byteir/Dialect/Vector/Transforms/MoveForallRegionIntoWarpOp.h", so the
// include guard is spelled after the Vector dialect directory rather than SCF.
#ifndef BYTEIR_DIALECT_VECTOR_TRANSFORMS_MOVEFORALLREGIONINTOWARPOP_H
#define BYTEIR_DIALECT_VECTOR_TRANSFORMS_MOVEFORALLREGIONINTOWARPOP_H

#include "mlir/Pass/Pass.h"
#include <cstdint>
#include <memory>

namespace mlir {
namespace func {
class FuncOp;
} // namespace func

/// Returns the attribute name string associated with the
/// move-forall-region-into-warp-op transformation. How the attribute is
/// attached and consumed is defined by the pass implementation, which is not
/// part of this header.
constexpr llvm::StringRef getMoveForallRegionIntoWarpOpAttrName() {
  return "__byteir_move_forall_region_into_warp_execute_on_lane0";
}

/// Factory for the `move-forall-region-into-warp-op` pass, which operates on
/// `func::FuncOp`.
///
/// \param warpSize Warp size forwarded to the pass; defaults to 32.
/// \param anchorTag Optional unitAttr anchored tag to apply this pass;
///        defaults to the empty string.
/// \return An owning pointer to the newly created pass.
std::unique_ptr<OperationPass<func::FuncOp>>
createMoveForallRegionIntoWarpOpPass(int64_t warpSize = 32,
                                     llvm::StringRef anchorTag = "");

} // namespace mlir

#endif // BYTEIR_DIALECT_VECTOR_TRANSFORMS_MOVEFORALLREGIONINTOWARPOP_H
7 changes: 7 additions & 0 deletions compiler/include/byteir/Dialect/Vector/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,20 @@
#ifndef BYTEIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H
#define BYTEIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H

#include "byteir/Dialect/Vector/Transforms/MoveForallRegionIntoWarpOp.h"
#include "byteir/Dialect/Vector/Transforms/VectorWarpDistribute.h"
#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {
namespace func {
class FuncOp;
} // namespace func

/// Generate the code for registering transforms passes.
#define GEN_PASS_DECL_VECTORTRANSPOSELOWERINGPASS
#define GEN_PASS_DECL_MOVEFORALLREGIONINTOWARPOPPASS
#define GEN_PASS_DECL_SCALARVECTORLOWERINGPASS
#define GEN_PASS_REGISTRATION
#include "byteir/Dialect/Vector/Transforms/Passes.h.inc"

Expand Down
67 changes: 67 additions & 0 deletions compiler/include/byteir/Dialect/Vector/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,72 @@ def VectorTransposeLoweringPass : Pass<"vector-transpose-lowering", "func::FuncO
];
}

//===----------------------------------------------------------------------===//
// Move Forall Region Into WarpOp
//===----------------------------------------------------------------------===//

// Declarative record for the `move-forall-region-into-warp-op` pass, anchored
// on `mlir::func::FuncOp`.
def MoveForallRegionIntoWarpOpPass
    : Pass<"move-forall-region-into-warp-op", "mlir::func::FuncOp"> {
  let summary = "move region of forall into warp_execute_on_lane_0 op";
  let constructor = "mlir::createMoveForallRegionIntoWarpOpPass()";
  let dependentDialects = [
    "memref::MemRefDialect", "vector::VectorDialect", "gpu::GPUDialect"
  ];
  let options = [
    Option<"warpSize", "warp-size", "int64_t", /*default=*/"32", "warp size">,
    Option<"anchorTag", "anchor-tag", "std::string", /*default=*/"",
           "Optional unitAttr anchored tag to apply this pass">
  ];
}

//===----------------------------------------------------------------------===//
// Vector Warp Distribute
//===----------------------------------------------------------------------===//

// Declarative record for the `vector-warp-distribute` pass, anchored on
// `mlir::func::FuncOp`. Every boolean option defaults to false, so each
// transformation step has to be requested explicitly.
def VectorWarpDistributePass
    : Pass<"vector-warp-distribute", "mlir::func::FuncOp"> {
  let summary = "vector warp distribute transformation";
  let constructor = "mlir::createVectorWarpDistributePass()";
  let dependentDialects = [
    "scf::SCFDialect",
    "memref::MemRefDialect",
    "vector::VectorDialect",
    "gpu::GPUDialect",
    "affine::AffineDialect"
  ];
  let options = [
    // The op is spelled `vector.warp_execute_on_lane_0` (with an underscore
    // before the 0), matching the summary of MoveForallRegionIntoWarpOpPass.
    Option<"warpOpToSCF", "rewrite-warp-ops-to-scf-if", "bool",
           /*default=*/"false",
           "Lower vector.warp_execute_on_lane_0 to scf.if op">,

    Option<"distributeTransferWriteOps", "distribute-transfer-write", "bool",
           /*default=*/"false",
           "distribution of transfer write">,

    Option<"hoistUniform", "hoist-uniform", "bool",
           /*default=*/"false",
           "hoist-uniform">,

    Option<"propagateDistribution", "propagate-distribution", "bool",
           /*default=*/"false",
           "distribution propagation">,

    Option<"maxTransferWriteElements", "max-transfer-write-elements",
           "int64_t", /*default=*/"1",
           "Maximum number of transfer write elements to distribute">
  ];
}

//===----------------------------------------------------------------------===//
// Scalar Vector Lowering
//===----------------------------------------------------------------------===//

// Declarative record for the `scalar-vector-lowering` pass, anchored on
// `func::FuncOp`. No `constructor` is given in this record.
def ScalarVectorLoweringPass
    : Pass<"scalar-vector-lowering", "func::FuncOp"> {
  let summary = "Pass to lower scalar vector";
  let dependentDialects = ["memref::MemRefDialect", "vector::VectorDialect"];
}
#endif // BYTEIR_DIALECT_VECTOR_TRANSFORMS_PASSES
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===- VectorWarpDistribute.h ---------------------------*--- C++ -*-===//
//
// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

// NOTE: this header is included as
// "byteir/Dialect/Vector/Transforms/VectorWarpDistribute.h", so the include
// guard is spelled after the Vector dialect directory rather than SCF.
#ifndef BYTEIR_DIALECT_VECTOR_TRANSFORMS_VECTORWARPDISTRIBUTE_H
#define BYTEIR_DIALECT_VECTOR_TRANSFORMS_VECTORWARPDISTRIBUTE_H

#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {
namespace func {
class FuncOp;
} // namespace func

// Pull in the tablegen-generated declarations for VectorWarpDistributePass
// only; this is expected to provide the VectorWarpDistributePassOptions type
// used in the factory signature below.
#define GEN_PASS_DECL_VECTORWARPDISTRIBUTEPASS
#include "byteir/Dialect/Vector/Transforms/Passes.h.inc"

/// Factory for the `vector-warp-distribute` pass, which operates on
/// `func::FuncOp`.
///
/// \param options Pass options; a default-constructed options object is used
///        when the argument is omitted.
/// \return An owning pointer to the newly created pass.
std::unique_ptr<OperationPass<func::FuncOp>>
createVectorWarpDistributePass(const VectorWarpDistributePassOptions &options =
                                   VectorWarpDistributePassOptions());

} // namespace mlir

#endif // BYTEIR_DIALECT_VECTOR_TRANSFORMS_VECTORWARPDISTRIBUTE_H
4 changes: 4 additions & 0 deletions compiler/include/byteir/Dialect/mhlo/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ inline void registerByteIRMhloPassesExt() {
return mlir::createConcatSliceFusionPass();
});

// register createInsertSliceWithElemwiseFusionPass
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::createInsertSliceWithElemwiseFusionPass();
});

// register createCatFusionPass
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::createCatFusionPass();
Expand Down
3 changes: 3 additions & 0 deletions compiler/include/byteir/Dialect/mhlo/Transforms/HloFuser.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ createElementFusionPass(bool clusterSingleElemwiseOp = false,

std::unique_ptr<OperationPass<func::FuncOp>> createConcatSliceFusionPass();

std::unique_ptr<OperationPass<func::FuncOp>>
createInsertSliceWithElemwiseFusionPass();

std::unique_ptr<OperationPass<func::FuncOp>> createMatmulEpilogueFusionPass();

std::unique_ptr<OperationPass<func::FuncOp>> createIOConvertFusionPass();
Expand Down
7 changes: 7 additions & 0 deletions compiler/include/byteir/Pipelines/GPU/MappingForall.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#ifndef BYTEIR_PIPELINES_GPU_MAPPING_FORALL_H
#define BYTEIR_PIPELINES_GPU_MAPPING_FORALL_H

#include "byteir/Utils/OptionUtils.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassOptions.h"
#include "mlir/Pass/PassRegistry.h"
Expand All @@ -34,6 +35,12 @@ struct GPUMappingForallOptions
*this, "annotate-prefix",
llvm::cl::desc("An optional annotate prefix attribute on target ops."),
llvm::cl::init("__byteir_gpu_split_grid_reduction")};
Option<int64_t> warpSize{*this, "warp-size", llvm::cl::desc("warp size."),
llvm::cl::init(32)};
Option<llvm::cl::KernelDims> blockDimsHint{
*this, "block-size-hint",
llvm::cl::desc("block dims hint for dynamic shape."),
llvm::cl::init(llvm::cl::KernelDims{1024, 1, 1})};
// TODO: option for grid dims hint (block dims hint is `block-size-hint`)
};

Expand Down
58 changes: 52 additions & 6 deletions compiler/include/byteir/Pipelines/GPU/ReductionCodegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ struct GPUTileGridReductionOptions
llvm::cl::init(32)};
Option<int64_t> blockSize{*this, "block-size", llvm::cl::desc("block size"),
llvm::cl::init(256)};
Option<bool> usingForall{*this, "using-forall",
llvm::cl::desc("using forall"),
llvm::cl::init(true)};
};

struct GPUSplitBlockReductionOptions
Expand Down Expand Up @@ -92,9 +89,44 @@ struct GPUTileBlockReductionOptions
llvm::cl::init(32)};
Option<int64_t> blockSize{*this, "block-size", llvm::cl::desc("block size"),
llvm::cl::init(256)};
Option<bool> usingForall{*this, "using-forall",
llvm::cl::desc("using forall"),
llvm::cl::init(true)};
};

/// Command-line options of the
/// `insert-gpu-tile-split-warp-reduction-transform` pipeline.
struct GPUTileSplitWarpReductionOptions
    : PassPipelineOptions<GPUTileSplitWarpReductionOptions> {
  /// Unit attribute used to select target functions; empty by default.
  Option<std::string> funcAnchor{
      *this, "func-anchor",
      llvm::cl::desc(
          "An optional Unit attribute anchoring on target functions."),
      llvm::cl::init("")};

  /// Annotate prefix attribute placed on target ops.
  Option<std::string> annotatePrefix{
      *this, "annotate-prefix",
      llvm::cl::desc("An optional annotate prefix attribute on target ops."),
      llvm::cl::init("__byteir_gpu_split_warp_reduction")};

  /// Block size; defaults to 256.
  Option<int64_t> blockSize{*this, "block-size",
                            llvm::cl::desc("block size"),
                            llvm::cl::init(256)};

  /// Warp size; defaults to 32.
  Option<int64_t> warpSize{*this, "warp-size",
                           llvm::cl::desc("warp size"),
                           llvm::cl::init(32)};
};

/// Command-line options of the `insert-gpu-tile-warp-reduction-transform`
/// pipeline.
struct GPUTileWarpReductionOptions
    : PassPipelineOptions<GPUTileWarpReductionOptions> {
  /// Unit attribute used to select target functions; empty by default.
  Option<std::string> funcAnchor{
      *this, "func-anchor",
      llvm::cl::desc(
          "An optional Unit attribute anchoring on target functions."),
      llvm::cl::init("")};

  /// Annotate prefix attribute placed on target ops.
  Option<std::string> annotatePrefix{
      *this, "annotate-prefix",
      llvm::cl::desc("An optional annotate prefix attribute on target ops."),
      llvm::cl::init("__byteir_gpu_warp_reduction")};

  /// Split factor; defaults to 32.
  Option<int64_t> splitFactor{*this, "split-factor",
                              llvm::cl::desc("split factor"),
                              llvm::cl::init(32)};

  /// Warp size; defaults to 32.
  Option<int64_t> warpSize{*this, "warp-size",
                           llvm::cl::desc("warp size"),
                           llvm::cl::init(32)};

  /// Whether to use gpu shuffle; enabled by default.
  Option<bool> usingGPUShuffle{*this, "using-gpu-shuffle",
                               llvm::cl::desc("using gpu shuffle"),
                               llvm::cl::init(true)};
};

struct GPUTileThreadReductionOptions
Expand All @@ -118,6 +150,10 @@ void createGPUSplitBlockReductionTransform(
OpPassManager &pm, const GPUSplitBlockReductionOptions &options);
void createGPUTileBlockReductionTransform(
OpPassManager &pm, const GPUTileBlockReductionOptions &options);
void createGPUTileSplitWarpReductionTransform(
OpPassManager &pm, const GPUTileSplitWarpReductionOptions &options);
void createGPUTileWarpReductionTransform(
OpPassManager &pm, const GPUTileWarpReductionOptions &options);
void createGPUTileThreadReductionTransform(
OpPassManager &pm, const GPUTileThreadReductionOptions &options);

Expand All @@ -142,6 +178,16 @@ inline void registerGPUReductionCodegenPipelines() {
"Insert transformation IR to tile linalg reduction op",
createGPUTileBlockReductionTransform);

PassPipelineRegistration<GPUTileSplitWarpReductionOptions>(
"insert-gpu-tile-split-warp-reduction-transform",
"Insert transformation IR to split block reduction to warp",
createGPUTileSplitWarpReductionTransform);

// Registers the `insert-gpu-tile-warp-reduction-transform` pipeline, built by
// createGPUTileWarpReductionTransform with GPUTileWarpReductionOptions.
PassPipelineRegistration<GPUTileWarpReductionOptions>(
    "insert-gpu-tile-warp-reduction-transform",
    "Insert transformation IR to vectorize warp reduction",
    createGPUTileWarpReductionTransform);

PassPipelineRegistration<GPUTileThreadReductionOptions>(
"insert-gpu-tile-thread-reduction-transform",
"Insert transformation IR to tile linalg reduction op",
Expand Down
Loading