[torch-frontend] lowering aten.upsample_bilinear to byteir.resize (#489)
as title
qingyunqu authored Dec 3, 2024
1 parent 6152f75 commit 4016d0e
Showing 2 changed files with 38 additions and 19 deletions.
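
For context, a minimal Python sketch (illustrative only, not part of this commit) of the two ATen overloads the change lowers to byteir.resize; the explicit-overload call mirrors the new test case added below, while the .vec call follows the standard ATen schema (the .vec arguments here are an assumption, not taken from this diff):

import torch

x = torch.randn(3, 3, 10, 20)
# explicit overload: output_size, align_corners, optional scales_h / scales_w
y1 = torch.ops.aten.upsample_bilinear2d(x, (11, 25), True, None, None)
# .vec overload: output_size list, align_corners, optional scale_factors list
y2 = torch.ops.aten.upsample_bilinear2d.vec(x, [11, 25], True, None)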
@@ -1150,13 +1150,6 @@ class ConvertAtenUpsampleNearest2dOp : public OpConversionPattern<OP> {
     // TODO: if result have dynamic shape, should lowering to target_mode=scale
     if (!resultType.hasStaticShape())
       return failure();
-    if constexpr (std::is_same_v<OP, AtenUpsampleNearest2dOp>) {
-      if (!isa<Torch::NoneType>(adaptor.getScalesH().getType()) ||
-          !isa<Torch::NoneType>(adaptor.getScalesW().getType())) {
-        // FIXME: check shape inference when scales_h or scales_w is not None.
-        return failure();
-      }
-    }
 
     std::vector<NamedAttribute> byteir_attrs;
     byteir_attrs.emplace_back(rewriter.getStringAttr("target_mode"),
@@ -1183,17 +1176,24 @@ class ConvertAtenUpsampleNearest2dOp : public OpConversionPattern<OP> {
   }
 };
 
-// aten.upsample_bilinear2d.vec
-class ConvertAtenUpsampleBilinear2dVecOp
-    : public OpConversionPattern<AtenUpsampleBilinear2dVecOp> {
+// aten.upsample_bilinear2d.vec && aten.upsample_bilinear2d
+template <typename OP>
+class ConvertAtenUpsampleBilinear2dOp : public OpConversionPattern<OP> {
 public:
-  using OpConversionPattern<AtenUpsampleBilinear2dVecOp>::OpConversionPattern;
+  using OpConversionPattern<OP>::OpConversionPattern;
+  using OpAdaptor = typename OP::Adaptor;
   LogicalResult
-  matchAndRewrite(AtenUpsampleBilinear2dVecOp op, OpAdaptor adaptor,
+  matchAndRewrite(OP op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Value input = adaptor.getInput();
+    Value input;
+    if constexpr (std::is_same_v<OP, AtenUpsampleBilinear2dOp>) {
+      input = adaptor.getSelf();
+    } else {
+      input = adaptor.getInput();
+    }
     RankedTensorType resultType = cast<RankedTensorType>(
-        getTypeConverter()->convertType(op.getResult().getType()));
+        OpConversionPattern<OP>::getTypeConverter()->convertType(
+            op.getResult().getType()));
 
     // TODO: if result have dynamic shape, should lowering to target_mode=scale
     if (!resultType.hasStaticShape())
@@ -1387,8 +1387,13 @@ class ConvertTorchToCustomCall
       target.addIllegalOp<AtenUpsampleNearest2dVecOp>();
       patterns.add<ConvertAtenUpsampleNearest2dOp<AtenUpsampleNearest2dVecOp>>(
          typeConverter, context);
+      target.addIllegalOp<AtenUpsampleBilinear2dOp>();
+      patterns.add<ConvertAtenUpsampleBilinear2dOp<AtenUpsampleBilinear2dOp>>(
+          typeConverter, context);
       target.addIllegalOp<AtenUpsampleBilinear2dVecOp>();
-      patterns.add<ConvertAtenUpsampleBilinear2dVecOp>(typeConverter, context);
+      patterns
+          .add<ConvertAtenUpsampleBilinear2dOp<AtenUpsampleBilinear2dVecOp>>(
+              typeConverter, context);
     }
 
     populateMathToCustomCallPattern(target, typeConverter, patterns,
@@ -287,7 +287,7 @@ def test_resize_nearest():
     model = torch.jit.trace(UpsampleNearest2dModule1(), inputs)
     custom_test_helper(model, inputs, "byteir.resize")
 
-class UpsampleBilinear2dModule(torch.nn.Module):
+class UpsampleBilinear2dVecModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -298,10 +298,10 @@ def forward(self, x):
 @pytest.mark.mhlo_tools
 def test_resize_bilinear():
     inputs = [tu.randn(3, 3, 10, 20)]
-    model = torch.jit.trace(UpsampleBilinear2dModule(), inputs)
+    model = torch.jit.trace(UpsampleBilinear2dVecModule(), inputs)
     custom_test_helper(model, inputs, "byteir.resize")
 
-class UpsampleBilinear2dModule1(torch.nn.Module):
+class UpsampleBilinear2dVecModule1(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -312,7 +312,21 @@ def forward(self, x):
 @pytest.mark.mhlo_tools
 def test_resize_bilinear_half_pixel():
     inputs = [tu.randn(3, 3, 10, 20)]
-    model = torch.jit.trace(UpsampleBilinear2dModule1(), inputs)
+    model = torch.jit.trace(UpsampleBilinear2dVecModule1(), inputs)
     custom_test_helper(model, inputs, "byteir.resize")
 
+class UpsampleBilinear2dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        #FIXME: use torch.nn.interpolate to avoid torch.jit.trace
+        return torch.ops.aten.upsample_bilinear2d(x, (11, 25), True, None, None)
+
+@pytest.mark.mhlo_tools
+def test_resize_bilinear_1():
+    inputs = [tu.randn(3, 3, 10, 20)]
+    model = torch.jit.trace(UpsampleBilinear2dModule(), inputs)
+    custom_test_helper(model, inputs, "byteir.resize")
+
 # ==============================================================================
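
The FIXME in the new test above hints at building the module with torch.nn.functional.interpolate rather than calling the raw ATen op, which it suggests would avoid the need for torch.jit.trace. A minimal sketch of that alternative (hypothetical; the module name is made up and this is not part of the commit):

class UpsampleBilinear2dInterpolateModule(torch.nn.Module):
    def forward(self, x):
        # bilinear resize to the same fixed output size used in the tests above
        return torch.nn.functional.interpolate(
            x, size=(11, 25), mode="bilinear", align_corners=True)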
