Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch-frontend] lowerint aten.upsample_bilinear to byteir.resize #489

Merged
merged 1 commit into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1150,13 +1150,6 @@ class ConvertAtenUpsampleNearest2dOp : public OpConversionPattern<OP> {
// TODO: if result have dynamic shape, should lowering to target_mode=scale
if (!resultType.hasStaticShape())
return failure();
if constexpr (std::is_same_v<OP, AtenUpsampleNearest2dOp>) {
if (!isa<Torch::NoneType>(adaptor.getScalesH().getType()) ||
!isa<Torch::NoneType>(adaptor.getScalesW().getType())) {
// FIXME: check shape inference when scales_h or scales_w is not None.
return failure();
}
}

std::vector<NamedAttribute> byteir_attrs;
byteir_attrs.emplace_back(rewriter.getStringAttr("target_mode"),
Expand All @@ -1183,17 +1176,24 @@ class ConvertAtenUpsampleNearest2dOp : public OpConversionPattern<OP> {
}
};

// aten.upsample_bilinear2d.vec
class ConvertAtenUpsampleBilinear2dVecOp
: public OpConversionPattern<AtenUpsampleBilinear2dVecOp> {
// aten.upsample_bilinear2d.vec && aten.upsample_bilinear2d
template <typename OP>
class ConvertAtenUpsampleBilinear2dOp : public OpConversionPattern<OP> {
public:
using OpConversionPattern<AtenUpsampleBilinear2dVecOp>::OpConversionPattern;
using OpConversionPattern<OP>::OpConversionPattern;
using OpAdaptor = typename OP::Adaptor;
LogicalResult
matchAndRewrite(AtenUpsampleBilinear2dVecOp op, OpAdaptor adaptor,
matchAndRewrite(OP op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getInput();
Value input;
if constexpr (std::is_same_v<OP, AtenUpsampleBilinear2dOp>) {
input = adaptor.getSelf();
} else {
input = adaptor.getInput();
}
RankedTensorType resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op.getResult().getType()));
OpConversionPattern<OP>::getTypeConverter()->convertType(
op.getResult().getType()));

// TODO: if result have dynamic shape, should lowering to target_mode=scale
if (!resultType.hasStaticShape())
Expand Down Expand Up @@ -1387,8 +1387,13 @@ class ConvertTorchToCustomCall
target.addIllegalOp<AtenUpsampleNearest2dVecOp>();
patterns.add<ConvertAtenUpsampleNearest2dOp<AtenUpsampleNearest2dVecOp>>(
typeConverter, context);
target.addIllegalOp<AtenUpsampleBilinear2dOp>();
patterns.add<ConvertAtenUpsampleBilinear2dOp<AtenUpsampleBilinear2dOp>>(
typeConverter, context);
target.addIllegalOp<AtenUpsampleBilinear2dVecOp>();
patterns.add<ConvertAtenUpsampleBilinear2dVecOp>(typeConverter, context);
patterns
.add<ConvertAtenUpsampleBilinear2dOp<AtenUpsampleBilinear2dVecOp>>(
typeConverter, context);
}

populateMathToCustomCallPattern(target, typeConverter, patterns,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def test_resize_nearest():
model = torch.jit.trace(UpsampleNearest2dModule1(), inputs)
custom_test_helper(model, inputs, "byteir.resize")

class UpsampleBilinear2dModule(torch.nn.Module):
class UpsampleBilinear2dVecModule(torch.nn.Module):
def __init__(self):
super().__init__()

Expand All @@ -298,10 +298,10 @@ def forward(self, x):
@pytest.mark.mhlo_tools
def test_resize_bilinear():
inputs = [tu.randn(3, 3, 10, 20)]
model = torch.jit.trace(UpsampleBilinear2dModule(), inputs)
model = torch.jit.trace(UpsampleBilinear2dVecModule(), inputs)
custom_test_helper(model, inputs, "byteir.resize")

class UpsampleBilinear2dModule1(torch.nn.Module):
class UpsampleBilinear2dVecModule1(torch.nn.Module):
def __init__(self):
super().__init__()

Expand All @@ -312,7 +312,21 @@ def forward(self, x):
@pytest.mark.mhlo_tools
def test_resize_bilinear_half_pixel():
inputs = [tu.randn(3, 3, 10, 20)]
model = torch.jit.trace(UpsampleBilinear2dModule1(), inputs)
model = torch.jit.trace(UpsampleBilinear2dVecModule1(), inputs)
custom_test_helper(model, inputs, "byteir.resize")

class UpsampleBilinear2dModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
#FIXME: use torch.nn.interpolate to avoid torch.jit.trace
return torch.ops.aten.upsample_bilinear2d(x, (11, 25), True, None, None)

@pytest.mark.mhlo_tools
def test_resize_bilinear_1():
inputs = [tu.randn(3, 3, 10, 20)]
model = torch.jit.trace(UpsampleBilinear2dModule(), inputs)
custom_test_helper(model, inputs, "byteir.resize")

# ==============================================================================
Expand Down