diff --git a/examples/models/flamingo/export_preprocess_lib.py b/examples/models/flamingo/export_preprocess_lib.py index 736116de8b..082c306ea3 100644 --- a/examples/models/flamingo/export_preprocess_lib.py +++ b/examples/models/flamingo/export_preprocess_lib.py @@ -15,10 +15,6 @@ from torch.export import Dim, ExportedProgram from torchtune.models.clip.inference._transforms import _CLIPImageTransform -from .passes.replace_custom_ops_with_aten_ops_pass import ( - ReplaceCustomOpsWithAtenOpsPass, -) - def get_example_inputs() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image = torch.ones(3, 800, 600) @@ -59,7 +55,6 @@ def export_preprocess( ) # Replace non-exportable ops with custom ops. - image_transform_model.pad = torch.ops.preprocess.pad.default image_transform_model.tile_crop = torch.ops.preprocess.tile_crop.default # Export. @@ -80,8 +75,6 @@ def lower_to_executorch_preprocess( edge_program = to_edge( exported_program, compile_config=EdgeCompileConfig(_check_ir_validity=False) ) - # Replace custom ops with aten ops. - edge_program = edge_program.transform([ReplaceCustomOpsWithAtenOpsPass()]) et_program = edge_program.to_executorch(ExecutorchBackendConfig()) return et_program diff --git a/examples/models/flamingo/passes/__init__.py b/examples/models/flamingo/passes/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/models/flamingo/passes/replace_custom_ops_with_aten_ops_pass.py b/examples/models/flamingo/passes/replace_custom_ops_with_aten_ops_pass.py deleted file mode 100644 index 8c31cf512c..0000000000 --- a/examples/models/flamingo/passes/replace_custom_ops_with_aten_ops_pass.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -import torch -from executorch.exir.pass_base import ExportPass -from executorch.extension.llm.custom_ops import preprocess_custom_ops # noqa - - -class ReplaceCustomOpsWithAtenOpsPass(ExportPass): - """ - Goes through all ops and replaces custom ops with aten ops. In some cases - aten ops cannot be exported due to dynamism, eg. pad in flamingo preprocess. - Use a custom op to pass export, and replace it with the aten op post-export, - which avoids re-writing the op in C++. - """ - - def __init__(self) -> None: - super().__init__() - - def call_operator(self, op, args, kwargs, meta): - if op._name == "preprocess::pad": - return super().call_operator( - torch.ops.aten.constant_pad_nd.default, args, kwargs, meta - ) - - return super().call_operator(op, args, kwargs, meta) diff --git a/examples/models/flamingo/passes/test_passes.py b/examples/models/flamingo/passes/test_passes.py deleted file mode 100644 index d0a90f2e34..0000000000 --- a/examples/models/flamingo/passes/test_passes.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -import unittest - -from typing import List - -import torch -from executorch.exir import EdgeCompileConfig, to_edge - -from .replace_custom_ops_with_aten_ops_pass import ReplaceCustomOpsWithAtenOpsPass - - -class TestPasses(unittest.TestCase): - def test_replace_custom_ops_with_aten_ops_pass(self) -> None: - from executorch.extension.llm.custom_ops import preprocess_custom_ops # noqa - - class Pad(torch.nn.Module): - def forward(self, x: torch.Tensor, padding: List[int]) -> torch.Tensor: - return torch.ops.preprocess.pad.default(x, padding) - - pad = Pad() - - image_tensor = torch.ones([3, 4, 5]) - padding = [0, 2, 0, 1] - - edge_prog = to_edge( - torch.export.export(pad, (image_tensor, padding), strict=False), - compile_config=EdgeCompileConfig(_check_ir_validity=False), - ) - - # Check that the custom op exists in the graph, and aten op does not. - edge_nodes = [node.name for node in edge_prog.exported_program().graph.nodes] - assert "constant_pad_nd" not in edge_nodes - assert "preprocess_pad_default" in edge_nodes - - edge_prog = edge_prog.transform([ReplaceCustomOpsWithAtenOpsPass()]) - - # After running replace_custom_ops_with_aten_ops pass, the custom op - # should be replaced with aten op. - post_transform_nodes = [ - node.name for node in edge_prog.exported_program().graph.nodes - ] - assert "constant_pad_nd" in post_transform_nodes - assert "preprocess_pad_default" not in post_transform_nodes diff --git a/extension/llm/custom_ops/preprocess_custom_ops.py b/extension/llm/custom_ops/preprocess_custom_ops.py index aea8c09b0e..e49721ffd3 100644 --- a/extension/llm/custom_ops/preprocess_custom_ops.py +++ b/extension/llm/custom_ops/preprocess_custom_ops.py @@ -7,61 +7,12 @@ # pyre-unsafe -from typing import List - import torch from torch.library import impl, Library preprocess_op_lib = Library("preprocess", "DEF") -# Register and define pad and out variant. -# Note: pad doesn't require an explicit meta kernel because -# CompositeExplicitAutograd automatically registers the implementation to meta, -# and meta kernels do not go through functionalization. The implementation -# does not export due to issues during functionalization. -# See: https://github.com/pytorch/pytorch/issues/120288 -preprocess_op_lib.define("pad(Tensor image, SymInt[] padding) -> Tensor") - - -@impl(preprocess_op_lib, "pad", dispatch_key="CompositeExplicitAutograd") -def pad_impl( - image: torch.Tensor, - padding: List[int], -) -> torch.Tensor: - output = torch.empty( - [image.shape[0], image.shape[1] + padding[3], image.shape[2] + padding[1]], - dtype=image.dtype, - device=image.device, - requires_grad=False, - ) - output = torch.fill(output, 0) - output.narrow(1, 0, image.shape[1]).narrow(2, 0, image.shape[2]).copy_(image) - return output - - -preprocess_op_lib.define( - "pad.out(Tensor image, SymInt[] padding, *, Tensor(a!) out) -> Tensor(a!)" -) - - -@impl(preprocess_op_lib, "pad.out", dispatch_key="CompositeExplicitAutograd") -def pad_out_impl( - image: torch.Tensor, - padding: List[int], - out: torch.Tensor, -) -> torch.Tensor: - out = torch.empty( - [image.shape[0], image.shape[1] + padding[3], image.shape[2] + padding[1]], - dtype=image.dtype, - device=image.device, - requires_grad=False, - ) - out = torch.fill(out, 0) - out.narrow(1, 0, image.shape[1]).narrow(2, 0, image.shape[2]).copy_(image) - return out - - # Register and define tile_crop and out variant. preprocess_op_lib.define("tile_crop(Tensor input, int tile_size) -> Tensor")