Skip to content

Commit

Permalink
[AutoBump] Merge with fixes of c6876b4 (Sep 26)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgehre-amd committed Jan 9, 2025
2 parents f78fcd7 + c6876b4 commit b28576b
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 20 deletions.
8 changes: 4 additions & 4 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,7 @@ def Tosa_TableOp : Tosa_InferShapedTypeOp<"table"> {
}];

let arguments = (ins
Tosa_Tensor: $input,
Tosa_Tensor: $input1,
Tosa_Tensor1D: $table
);

Expand All @@ -896,7 +896,7 @@ def Tosa_TableOp : Tosa_InferShapedTypeOp<"table"> {
);

let assemblyFormat = [{
$input `,` $table attr-dict `:` `(` type($input) `,` type($table) `)` `->` type($output)
$input1 `,` $table attr-dict `:` `(` type($input1) `,` type($table) `)` `->` type($output)
}];

let hasVerifier = 1;
Expand Down Expand Up @@ -1664,7 +1664,7 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
}];

let arguments = (ins
Tosa_Tensor:$input,
Tosa_Tensor:$input1,
I32Attr:$axis
);

Expand All @@ -1691,7 +1691,7 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
}];

let arguments = (ins
Tosa_Tensor:$input,
Tosa_Tensor:$input1,
DenseI64ArrayAttr:$start,
DenseI64ArrayAttr:$size
);
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1957,7 +1957,7 @@ class ReverseConverter : public OpConversionPattern<tosa::ReverseOp> {
matchAndRewrite(tosa::ReverseOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op.getLoc();
Value input = operands.getInput();
Value input = operands.getInput1();
auto inputTy = cast<ShapedType>(input.getType());
auto resultTy =
cast_or_null<ShapedType>(getTypeConverter()->convertType(op.getType()));
Expand Down Expand Up @@ -2300,7 +2300,7 @@ class TableConverter : public OpConversionPattern<tosa::TableOp> {
matchAndRewrite(tosa::TableOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op.getLoc();
Value input = operands.getInput();
Value input = operands.getInput1();
Value table = operands.getTable();
auto inputTy = cast<ShapedType>(input.getType());
auto tableTy = cast<ShapedType>(table.getType());
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
for (; currRhsDim < rhsShape.size(); currRhsDim++) {
assert(rhsShape[currRhsDim] == 1);
}

return lhsType.clone(intermediateShape);
}

Expand Down Expand Up @@ -264,7 +264,7 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Location loc = sliceOp.getLoc();
Value input = adaptor.getInput();
Value input = adaptor.getInput1();
ShapedType resultType = cast<ShapedType>(sliceOp.getType());
if (llvm::isa<UnrankedTensorType>(resultType))
return failure();
Expand Down
14 changes: 7 additions & 7 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {

LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
PatternRewriter &rewriter) const override {
Value sliceInput = sliceOp.getInput();
Value sliceInput = sliceOp.getInput1();
auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
if (!concatOp)
return rewriter.notifyMatchFailure(
Expand Down Expand Up @@ -1212,11 +1212,11 @@ OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
}

OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
auto operand = getInput();
auto operand = getInput1();
auto operandTy = llvm::cast<ShapedType>(operand.getType());
auto axis = getAxis();
auto operandAttr =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
if (operandAttr)
return operandAttr;

Expand All @@ -1229,24 +1229,24 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
}

OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

if (!inputTy || !outputTy)
return {};

if (inputTy == outputTy && inputTy.hasStaticShape())
return getInput();
return getInput1();

if (!adaptor.getInput())
if (!adaptor.getInput1())
return {};

// Cannot create an ElementsAttr from non-int/float/index types
if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
!outputTy.getElementType().isIntOrIndexOrFloat())
return {};

auto operand = llvm::cast<ElementsAttr>(adaptor.getInput());
auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
if (operand.isSplat() && outputTy.hasStaticShape()) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,7 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
}

LogicalResult tosa::SliceOp::verify() {
auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
auto outputType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !outputType)
return success();
Expand Down Expand Up @@ -931,7 +931,7 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TableOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape(adaptor.getInput().getType());
ShapeAdaptor inputShape(adaptor.getInput1().getType());

if (!inputShape.hasRank()) {
inferredReturnShapes.push_back(ShapedTypeComponents());
Expand All @@ -944,7 +944,7 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents(
}

LogicalResult tosa::TableOp::verify() {
TensorType inputType = getInput().getType();
TensorType inputType = getInput1().getType();
TensorType outputType = getOutput().getType();

if (inputType.hasRank() && outputType.hasRank() &&
Expand Down Expand Up @@ -2052,7 +2052,7 @@ void IfOp::print(OpAsmPrinter &p) {
}

LogicalResult ReverseOp::verify() {
TensorType inputType = getInput().getType();
TensorType inputType = getInput1().getType();
TensorType outputType = getOutput().getType();
int32_t reverseAxis = getAxis();

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1798,7 +1798,7 @@ struct TosaFoldConstantSlice : public TosaFoldConstantBase<tosa::SliceOp> {
return failure();

auto start = op.getStart();
auto input = op.getInput();
auto input = op.getInput1();
ElementsAttr inputValues;
if (!matchPattern(input, m_Constant(&inputValues)))
return failure();
Expand Down

0 comments on commit b28576b

Please sign in to comment.