Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
ita9naiwa committed Jan 27, 2025
1 parent 078a72b commit 8c25391
Showing 1 changed file with 8 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,8 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap,

// Compute output = exp2(output - input)
static Value computeSubAndExp(OpBuilder &builder, Location loc,
AffineMap inputMap, AffineMap outputMap,
Value input, Value output,
bool useExp2) {
AffineMap inputMap, AffineMap outputMap,
Value input, Value output, bool useExp2) {
SmallVector<AffineMap> compressedMaps =
compressUnusedDims(SmallVector<AffineMap>{inputMap, outputMap});
inputMap = compressedMaps[0];
Expand Down Expand Up @@ -415,7 +414,8 @@ FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
if (config) {
qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr());
pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
if (mlir::BoolAttr useExp2Attr = mlir::dyn_cast_or_null<mlir::BoolAttr>(config.get(getUseExp2AttrStr()))) {
if (mlir::BoolAttr useExp2Attr = mlir::dyn_cast_or_null<mlir::BoolAttr>(
config.get(getUseExp2AttrStr()))) {
useExp2 = useExp2Attr.getValue();
}
}
Expand Down Expand Up @@ -538,7 +538,8 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
if (config) {
qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr());
pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
if (mlir::BoolAttr useExp2Attr = mlir::dyn_cast_or_null<mlir::BoolAttr>(config.get(getUseExp2AttrStr()))) {
if (mlir::BoolAttr useExp2Attr = mlir::dyn_cast_or_null<mlir::BoolAttr>(
config.get(getUseExp2AttrStr()))) {
useExp2 = useExp2Attr.getValue();
}
}
Expand Down Expand Up @@ -574,7 +575,8 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
// norm = exp2(oldMax - newMax)
// normMap = maxMap
AffineMap normMap = getMaxMap();
Value norm = computeSubAndExp(b, loc, maxMap, normMap, newMax, oldMax, useExp2);
Value norm =
computeSubAndExp(b, loc, maxMap, normMap, newMax, oldMax, useExp2);

// normSum = norm * oldSum
AffineMap sumMap = getSumMap();
Expand Down

0 comments on commit 8c25391

Please sign in to comment.