Skip to content

Commit

Permalink
[Compiler/Runtime/External Libs] Add flash attention external library…
Browse files Browse the repository at this point in the history
… & pass through mechanism (#99)

This PR moves runtime flash attention kernels to `external_libs/`,
supports flash attention kvcache, and uses it in a pass-through fashion
by declaring `Byre::CustomOp`. `Byre::CustomOp` has three main
components:
1. `StrAttr:$lib_path`, which specifies the path of the library file.
2. `StrAttr:$api_name`, which specifies the symbol name in the library
for this custom op.
3. `ArrayAttr:$extra_args`, which specifies the additional arguments
that need to be passed to the API call.

The following changes are made to support pass-through on flash
attention op:

- Remove runtime flash-attention kernels, update and move it to
`external_libs/` (adds flash attention kvcache support).
- Add runtime `Byre::CustomOp` support.
`runtime/lib/backends/cuda/providers/default/custom/custom.cc`
- Add compiler `HloToByreCustomPass` conversion pass.
`compiler/lib/Conversion/HloToByreTensor/HloToByreCustom.cpp`
  • Loading branch information
zhekunz2 authored Jan 25, 2024
1 parent 0783a59 commit b60d7a6
Show file tree
Hide file tree
Showing 142 changed files with 4,952 additions and 3,133 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//===- HloToByreCustom.h ---------------------------------------*--- C++-*-===//
//
// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#ifndef BYTEIR_CONVERSION_HLOTOBYRETENSOR_HLOTOBYRECUSTOM_H
#define BYTEIR_CONVERSION_HLOTOBYRETENSOR_HLOTOBYRECUSTOM_H

#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/StringRef.h"
#include <functional>
#include <memory>
#include <string>

namespace mlir {
// forward decl
namespace func {
class FuncOp;
} // namespace func
class Operation;

struct ByreCustomConfig {
std::function<llvm::StringRef(llvm::StringRef)> getCustomLibPath;
std::function<llvm::StringRef(llvm::StringRef)> getApiName;
std::function<ArrayAttr(mhlo::CustomCallOp)> getExtraArgs;
};

// Returns the config used for CUDA. Its callbacks recognize the flash
// attention forward/backward custom call names and map them to the flash
// attention library path, API names, and extra arguments.
ByreCustomConfig getCudaByreCustomConfig();

// Creates a pass on func::FuncOp that converts mhlo custom calls to byre
// custom ops; the given ByreCustomConfig decides which calls are converted
// and with which library path, API name, and extra arguments.
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertHloToByreCustomPass(const ByreCustomConfig &);

} // namespace mlir

#endif // BYTEIR_CONVERSION_HLOTOBYRETENSOR_HLOTOBYRECUSTOM_H
47 changes: 47 additions & 0 deletions compiler/include/byteir/Dialect/Byre/ByreOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,51 @@ def Byre_AliasOp
let hasVerifier = 1;
}

// Byre op that forwards its operands to a function looked up by `api_name`
// in the library at `lib_path` (see the description below). It must be nested
// directly in a func::FuncOp and implements ByreInterface::getCalleeName.
def Byre_CustomOp : Byre_Op<"custom",
[HasParent<"func::FuncOp">,
DeclareOpInterfaceMethods<ByreInterface, ["getCalleeName"]>]> {
let summary = "compute custom operation passed by library path and api name. ";
let description = [{
Example:
```mlir
%2 = byre.custom(%0, %1) { lib_path = "xxx.so", api_name = "add", extra_args = [0 : i64, 1 : i64, 2.0 : f32] } : (f32, f32) -> f32
```
During execution, "xxx.so" will be loaded, and "add" function will be called.
}];

// lib_path / api_name: library file and symbol to call.
// operands: values forwarded to the call.
// extra_args: additional arguments passed to the api call.
// memory_effects: optional per-operand effects; the custom builder below
// fills it with one Read entry per input and one Write entry per output.
let arguments = (ins
StrAttr:$lib_path,
StrAttr:$api_name,
Variadic<AnyType>:$operands,
ArrayAttr:$extra_args,
OptionalAttr<ArrayAttr>:$memory_effects
);

let results = (outs
Variadic<AnyType>:$results
);

// Convenience builder taking inputs and outputs separately; its definition
// (CustomOp::build in ByreDialect.cpp) concatenates them into the operand
// list and creates the op without result types.
let builders = [
OpBuilder<(ins "StringRef":$lib_path,
"StringRef":$api_name,
"ValueRange":$inputs,
"ValueRange":$outputs,
"ArrayAttr":$extra_args)>
];

let extraClassDeclaration = [{
FunctionType getType();

/// Get the argument operands to the called function.
operand_range getArgOperands() {
return {arg_operand_begin(), arg_operand_end()};
}

operand_iterator arg_operand_begin() { return operand_begin(); }
operand_iterator arg_operand_end() { return operand_end(); }
}];

// CustomOp::verify() delegates to verifyOpInEntryPointFunc.
let hasVerifier = 1;
}

#endif // BYTEIR_DIALECT_BYRE_BYRE_OPS
1 change: 1 addition & 0 deletions compiler/lib/Conversion/HloToByreTensor/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_byteir_conversion_library(ByteIRHloToByreTensor
HloToByreCustom.cpp
HloToByreTensor.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
236 changes: 236 additions & 0 deletions compiler/lib/Conversion/HloToByreTensor/HloToByreCustom.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
//===- HloToByreCustom.cpp ------------------------------------*--- C++ -*-===//
//
// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#include "byteir/Conversion/HloToByreTensor/HloToByreCustom.h"
#include "byteir/Dialect/Byre/ByreDialect.h"
#include "byteir/Dialect/Byre/Common.h"
#include "byteir/Dialect/mhlo/Util/CustomCallUtil.h"
#include "byteir/Utils/Utils.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "../PassDetail.h"

using namespace mlir;
using namespace llvm;

namespace {
// Path of the flash attention shared library handed to byre.custom.
constexpr StringRef getFlashAttnLibPath() {
  return StringRef("external_libs/libs/libflash_attn.so");
}

// Symbol names exported by the flash attention library.
constexpr StringRef getFlashAttnFwdAPI() {
  return StringRef("run_flash_attn_fwd");
}
constexpr StringRef getFlashAttnBwdAPI() {
  return StringRef("run_flash_attn_bwd");
}
constexpr StringRef getFlashAttnKVCacheAPI() {
  return StringRef("run_flash_attn_kvcache");
}
} // namespace

ByreCustomConfig mlir::getCudaByreCustomConfig() {
ByreCustomConfig config;
config.getCustomLibPath = [=](StringRef callee) {
if (callee == getFlashAttnFwdName() || callee == getFlashAttnBwdName()) {
return getFlashAttnLibPath();
}
return StringRef("");
};
config.getApiName = [=](StringRef callee) {
if (callee == getFlashAttnFwdName()) {
return getFlashAttnFwdAPI();
} else if (callee == getFlashAttnBwdName()) {
return getFlashAttnBwdAPI();
}
return StringRef("");
};
config.getExtraArgs = [=](mhlo::CustomCallOp op) {
SmallVector<Attribute> extraArgs;
auto callee = op.getCallTargetName();
if (callee == getFlashAttnFwdName() || callee == getFlashAttnBwdName()) {
OpBuilder rewriter(op);
ShapedType qShapeTy;
ShapedType kShapeTy;
ShapedType vShapeTy;
ShapedType oShapeTy;
if (callee == getFlashAttnFwdName()) {
qShapeTy = op.getOperand(0).getType().dyn_cast<ShapedType>();
kShapeTy = op.getOperand(1).getType().dyn_cast<ShapedType>();
vShapeTy = op.getOperand(2).getType().dyn_cast<ShapedType>();
oShapeTy = op.getResult(0).getType().dyn_cast<ShapedType>();
} else {
qShapeTy = op.getOperand(1).getType().dyn_cast<ShapedType>();
kShapeTy = op.getOperand(2).getType().dyn_cast<ShapedType>();
vShapeTy = op.getOperand(3).getType().dyn_cast<ShapedType>();
oShapeTy = op.getOperand(4).getType().dyn_cast<ShapedType>();
}
if (!qShapeTy || !qShapeTy.hasStaticShape() || !kShapeTy ||
!kShapeTy.hasStaticShape() || !vShapeTy ||
!vShapeTy.hasStaticShape() || !oShapeTy || !oShapeTy.hasStaticShape())
assert(false && "unexpected flash attention shape!");

auto qShape = qShapeTy.getShape();
auto kShape = kShapeTy.getShape();
auto vShape = vShapeTy.getShape();
auto oShape = oShapeTy.getShape();
int64_t batchSizeQ = qShape[0];
int64_t seqlenQ = qShape[1];
int64_t numHeadsQ = qShape[2];
int64_t headSizeQ = qShape[3];
int64_t batchSizeK = kShape[0];
int64_t seqlenK = kShape[1];
int64_t numHeadsK = kShape[2];
int64_t headSizeK = kShape[3];
assert(headSizeQ == headSizeK && batchSizeQ == batchSizeK);
assert(headSizeQ % 8 == 0);

auto roundMultiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int headSize = roundMultiple(headSizeQ, 8);
const int headSizeRounded = roundMultiple(headSize, 32);
const int seqlenQRounded = roundMultiple(seqlenQ, 128);
const int seqlenKRounded = roundMultiple(seqlenK, 128);

uint32_t qBatchStride = qShape[1] * qShape[2] * qShape[3];
uint32_t kBatchStride = kShape[1] * kShape[2] * kShape[3];
uint32_t vBatchStride = vShape[1] * vShape[2] * vShape[3];
uint32_t oBatchStride = oShape[1] * oShape[2] * oShape[3];
uint32_t qRowStride = qShape[2] * qShape[3];
uint32_t kRowStride = kShape[2] * kShape[3];
uint32_t vRowStride = vShape[2] * vShape[3];
uint32_t oRowStride = oShape[2] * oShape[3];
uint32_t qHeadStride = qShape[3];
uint32_t kHeadStride = kShape[3];
uint32_t vHeadStride = vShape[3];
uint32_t oHeadStride = oShape[3];

DictionaryAttr byteirAttrs =
op->getAttr(getCustomCallAttrName()).cast<DictionaryAttr>();
if (!byteirAttrs)
assert(false && "byteir attribute not found!");
bool causal = byteirAttrs.get("causal").cast<BoolAttr>().getValue();
float softmaxScale = byteirAttrs.get("softmax_scale")
.cast<FloatAttr>()
.getValue()
.convertToDouble();
float dropoutP = byteirAttrs.get("dropout_p")
.cast<FloatAttr>()
.getValue()
.convertToDouble();
int windowSizeLeft = -1;
int windowSizeRight = -1;
// causal=true is the same as causal=false in this case
if (seqlenQ == 1)
causal = false;
if (causal)
windowSizeRight = 0;

// extra args should match kernel api call
extraArgs.push_back(rewriter.getI64IntegerAttr(qBatchStride));
extraArgs.push_back(rewriter.getI64IntegerAttr(kBatchStride));
extraArgs.push_back(rewriter.getI64IntegerAttr(vBatchStride));
extraArgs.push_back(rewriter.getI64IntegerAttr(oBatchStride));
extraArgs.push_back(rewriter.getI64IntegerAttr(qRowStride));
extraArgs.push_back(rewriter.getI64IntegerAttr(kRowStride));
extraArgs.push_back(rewriter.getI64IntegerAttr(vRowStride));
extraArgs.push_back(rewriter.getI64IntegerAttr(oRowStride));
extraArgs.push_back(rewriter.getI64IntegerAttr(qHeadStride));
extraArgs.push_back(rewriter.getI64IntegerAttr(kHeadStride));
extraArgs.push_back(rewriter.getI64IntegerAttr(vHeadStride));
extraArgs.push_back(rewriter.getI64IntegerAttr(oHeadStride));

extraArgs.push_back(rewriter.getI64IntegerAttr(batchSizeQ));
extraArgs.push_back(rewriter.getI64IntegerAttr(numHeadsQ));
extraArgs.push_back(rewriter.getI64IntegerAttr(numHeadsK));
extraArgs.push_back(rewriter.getI64IntegerAttr(headSize));
extraArgs.push_back(rewriter.getI64IntegerAttr(headSizeRounded));
extraArgs.push_back(rewriter.getF32FloatAttr(softmaxScale));
extraArgs.push_back(rewriter.getI64IntegerAttr(seqlenQ));
extraArgs.push_back(rewriter.getI64IntegerAttr(seqlenK));
extraArgs.push_back(rewriter.getI64IntegerAttr(seqlenQRounded));
extraArgs.push_back(rewriter.getI64IntegerAttr(seqlenKRounded));
extraArgs.push_back(rewriter.getF32FloatAttr(dropoutP));
extraArgs.push_back(rewriter.getI64IntegerAttr(windowSizeLeft));
extraArgs.push_back(rewriter.getI64IntegerAttr(windowSizeRight));
return ArrayAttr::get(rewriter.getContext(), extraArgs);
}
return ArrayAttr({});
};
return config;
}

struct ConvertCustomCallOpToByreCustom : public RewritePattern {
ConvertCustomCallOpToByreCustom(MLIRContext *context,
const ByreCustomConfig &converter)
: RewritePattern(MatchAnyOpTypeTag(), 1, context), converter(converter) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (!isa<mhlo::CustomCallOp>(op))
return failure();
auto customCallOp = cast<mhlo::CustomCallOp>(op);
auto callee = customCallOp.getCallTargetName();
auto libPath = converter.getCustomLibPath(callee);
if (libPath == "")
return failure();
auto apiName = converter.getApiName(callee);
auto extraArgs = converter.getExtraArgs(customCallOp);

auto newOp = rewriter.create<byre::CustomOp>(
customCallOp.getLoc(), customCallOp.getResultTypes(), libPath, apiName,
customCallOp.getOperands(), extraArgs, /*memEffects*/ ArrayAttr{});
rewriter.replaceOp(op, newOp.getResults());
return success();
}

protected:
ByreCustomConfig converter;
};

/// Function pass that greedily applies ConvertCustomCallOpToByreCustom with
/// the config supplied at construction.
class ConvertHloToByreCustomPass
    : public PassWrapper<ConvertHloToByreCustomPass,
                         OperationPass<func::FuncOp>> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertHloToByreCustomPass)

  ConvertHloToByreCustomPass(const ByreCustomConfig &rule) : converter(rule) {}

  /// Registers the dialects this pass needs loaded: mhlo, func and byre.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    registry.insert<mhlo::MhloDialect, func::FuncDialect, byre::ByreDialect>();
  }

  /// Runs the conversion pattern over the function; signals pass failure if
  /// the greedy driver reports failure.
  void runOnOperation() override {
    func::FuncOp func = getOperation();
    MLIRContext *context = &getContext();

    RewritePatternSet patternSet(context);
    patternSet.add<ConvertCustomCallOpToByreCustom>(context, converter);

    FrozenRewritePatternSet frozen(std::move(patternSet));
    if (succeeded(applyPatternsAndFoldGreedily(func, frozen)))
      return;
    signalPassFailure();
  }

protected:
  ByreCustomConfig converter;
};

/// Factory for the HLO-to-byre-custom conversion pass; `converter` is copied
/// into the returned pass.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createConvertHloToByreCustomPass(const ByreCustomConfig &converter) {
  std::unique_ptr<OperationPass<func::FuncOp>> pass =
      std::make_unique<ConvertHloToByreCustomPass>(converter);
  return pass;
}
24 changes: 24 additions & 0 deletions compiler/lib/Dialect/Byre/IR/ByreDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,30 @@ std::string AliasOp::getCalleeName() { return "AliasOp"; }

// The viewed value of an alias is its `source` operand.
Value AliasOp::getViewSource() {
  Value source = getSource();
  return source;
}

//===----------------------------------------------------------------------===//
// CustomOp
//===----------------------------------------------------------------------===//

/// Builds a result-less byre.custom op whose operand list is `inputs`
/// followed by `outputs`. The `memory_effects` attribute gets one entry per
/// operand: Read for every input, then Write for every output.
void CustomOp::build(OpBuilder &builder, OperationState &result,
                     StringRef lib_path, StringRef api_name, ValueRange inputs,
                     ValueRange outputs, ArrayAttr extra_args) {
  auto readEffect = builder.getAttr<MemoryEffectAttr>(MemoryEffect::Read);
  auto writeEffect = builder.getAttr<MemoryEffectAttr>(MemoryEffect::Write);
  SmallVector<Attribute> effects(inputs.size(), readEffect);
  effects.append(outputs.size(), writeEffect);

  SmallVector<Value> allOperands(inputs.begin(), inputs.end());
  allOperands.append(outputs.begin(), outputs.end());

  build(builder, result, TypeRange{}, lib_path, api_name, allOperands,
        extra_args, builder.getArrayAttr(effects));
}

// Callee name reported through ByreInterface for every byre.custom op.
std::string CustomOp::getCalleeName() {
  return std::string("custom");
}

// Verification is delegated entirely to verifyOpInEntryPointFunc.
LogicalResult CustomOp::verify() {
  Operation *op = getOperation();
  return verifyOpInEntryPointFunc(op);
}

// LWC: ignore Async for now
//
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit b60d7a6

Please sign in to comment.