From 866e9b84d46bc83ae184ee627893f3bb55839260 Mon Sep 17 00:00:00 2001
From: Tarun Karuturi <58826100+tarun292@users.noreply.github.com>
Date: Thu, 5 Sep 2024 09:35:39 -0700
Subject: [PATCH] Switch XNNPack tests to use export_for_training

Differential Revision: D61684468

Pull Request resolved: https://github.com/pytorch/executorch/pull/4867
---
 backends/xnnpack/test/test_xnnpack_utils.py | 6 ++++--
 backends/xnnpack/test/tester/tester.py      | 8 ++++----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/backends/xnnpack/test/test_xnnpack_utils.py b/backends/xnnpack/test/test_xnnpack_utils.py
index 3f5359a3f4..ea9217e04a 100644
--- a/backends/xnnpack/test/test_xnnpack_utils.py
+++ b/backends/xnnpack/test/test_xnnpack_utils.py
@@ -72,6 +72,7 @@
     get_symmetric_quantization_config,
     XNNPACKQuantizer,
 )
+from torch.export import export_for_training
 from torch.testing import FileCheck
 
 
@@ -315,10 +316,11 @@ def quantize_and_test_model_with_quantizer(
 ):
     module.eval()
     # program capture
-    m = torch._export.capture_pre_autograd_graph(
+
+    m = export_for_training(
         module,
         example_inputs,
-    )
+    ).module()
 
     quantizer = XNNPACKQuantizer()
     quantization_config = get_symmetric_quantization_config()
diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py
index eb25a14cfe..7586c4f231 100644
--- a/backends/xnnpack/test/tester/tester.py
+++ b/backends/xnnpack/test/tester/tester.py
@@ -14,7 +14,6 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 import torch
-import torch.export._trace as export_trace
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.backends.xnnpack.passes import XNNPACKPassManager
 from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
@@ -31,6 +30,7 @@
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 from executorch.exir.print_program import pretty_print, print_program
 
+from torch.export import export_for_training
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -157,10 +157,10 @@ def __init__(
     def run(
         self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
     ) -> None:
-        captured_graph = export_trace._export(
-            artifact, inputs, pre_dispatch=True
-        ).module()
+        assert inputs is not None
+        captured_graph = export_for_training(artifact, inputs).module()
 
+        assert isinstance(captured_graph, torch.fx.GraphModule)
         prepared = prepare_pt2e(captured_graph, self.quantizer)
 
         if self.calibrate:
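
Note: the migration this patch makes is from private capture entry points
(torch._export.capture_pre_autograd_graph and torch.export._trace._export
with pre_dispatch=True) to the public torch.export.export_for_training API.
A minimal sketch of the resulting capture-and-quantize flow is below; the
ToyModel and the quantizer.set_global(...) call are assumptions for
illustration (the diff truncates before the quantizer is configured), and
the quantizer import path is the torch.ao one implied by the names in the
first hunk.

    import torch
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        get_symmetric_quantization_config,
        XNNPACKQuantizer,
    )
    from torch.export import export_for_training


    class ToyModel(torch.nn.Module):  # hypothetical module, illustration only
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 8)

        def forward(self, x):
            return torch.relu(self.linear(x))


    module = ToyModel().eval()
    example_inputs = (torch.randn(1, 8),)

    # Before this patch: torch._export.capture_pre_autograd_graph(module, example_inputs).
    # After: export_for_training returns an ExportedProgram, and .module() yields
    # the torch.fx.GraphModule that prepare_pt2e expects.
    m = export_for_training(module, example_inputs).module()
    assert isinstance(m, torch.fx.GraphModule)

    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())  # assumed config step
    prepared = prepare_pt2e(m, quantizer)
    prepared(*example_inputs)  # calibration, as Quantize.run does when self.calibrate is set

The .module() call and the isinstance assertion mirror the updated
Quantize.run in tester.py: prepare_pt2e operates on a GraphModule, not on
the ExportedProgram itself.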