Switch XNNPack tests to use export_for_training
Differential Revision: D61684468

Pull Request resolved: #4867
tarun292 authored Sep 5, 2024
1 parent e854967 commit 866e9b8
Showing 2 changed files with 8 additions and 6 deletions.
6 changes: 4 additions & 2 deletions backends/xnnpack/test/test_xnnpack_utils.py
@@ -72,6 +72,7 @@
     get_symmetric_quantization_config,
     XNNPACKQuantizer,
 )
+from torch.export import export_for_training

 from torch.testing import FileCheck

@@ -315,10 +316,11 @@ def quantize_and_test_model_with_quantizer(
     ):
         module.eval()
         # program capture
-        m = torch._export.capture_pre_autograd_graph(
+
+        m = export_for_training(
             module,
             example_inputs,
-        )
+        ).module()

         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config()
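The new capture call returns an ExportedProgram, and its .module() accessor yields the torch.fx.GraphModule that the PT2E quantization APIs expect, which is why the test now appends .module() to the capture. A minimal standalone sketch of this pattern follows; the toy module and example inputs are illustrative, not taken from the test suite:

    import torch
    from torch.export import export_for_training


    class ToyLinear(torch.nn.Module):
        """Illustrative module; not part of the XNNPack test suite."""

        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 4)

        def forward(self, x):
            return self.linear(x)


    module = ToyLinear().eval()
    example_inputs = (torch.randn(1, 8),)

    # export_for_training produces an ExportedProgram; .module() unwraps the
    # torch.fx.GraphModule that prepare_pt2e consumes.
    m = export_for_training(module, example_inputs).module()
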
8 changes: 4 additions & 4 deletions backends/xnnpack/test/tester/tester.py
@@ -14,7 +14,6 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

 import torch
-import torch.export._trace as export_trace
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.backends.xnnpack.passes import XNNPACKPassManager
 from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
@@ -31,6 +30,7 @@
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

 from executorch.exir.print_program import pretty_print, print_program
+from torch.export import export_for_training

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -157,10 +157,10 @@ def __init__(
     def run(
         self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
     ) -> None:
-        captured_graph = export_trace._export(
-            artifact, inputs, pre_dispatch=True
-        ).module()
+        assert inputs is not None
+        captured_graph = export_for_training(artifact, inputs).module()
+
         assert isinstance(captured_graph, torch.fx.GraphModule)
         prepared = prepare_pt2e(captured_graph, self.quantizer)

         if self.calibrate:
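For reference, the tester's Quantize stage above drives the standard PT2E flow around this capture: annotate with XNNPACKQuantizer, calibrate, then convert. The sketch below assumes the torch.ao.quantization import paths current at the time of this commit, and the helper name and calibration loop are illustrative rather than the tester's exact code:

    import torch
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )
    from torch.export import export_for_training


    def quantize_for_xnnpack(module, example_inputs, calibration_batches):
        """Illustrative helper mirroring the Quantize stage: capture, prepare, calibrate, convert."""
        # Capture to a GraphModule, as in Quantize.run above.
        captured = export_for_training(module.eval(), example_inputs).module()

        # Annotate the graph with XNNPACK's symmetric quantization config.
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(get_symmetric_quantization_config())
        prepared = prepare_pt2e(captured, quantizer)

        # Calibration: run representative inputs through the observed graph.
        for batch in calibration_batches:
            prepared(*batch)

        # Replace observers with quantize/dequantize ops.
        return convert_pt2e(prepared)
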
