From eb91af2df5dd404d913503309c3902611e681404 Mon Sep 17 00:00:00 2001
From: Daniil Lyakhov
Date: Wed, 7 Aug 2024 15:15:41 +0200
Subject: [PATCH] [TorchFX] Conformance test init (#2841)

### Changes

Conformance test for `torchvision/resnet18`. The calibration and validation
logic shared by image-classification pipelines is factored out of the timm
pipeline into a new `ImageClassificationBase`, and the duplicated
`timm/resnet18` scope entry is replaced by the torchvision one.

### Reason for changes

To extend the testing scope for the TorchFX backend

### Related tickets

https://github.com/openvinotoolkit/nncf/issues/2766

### Tests

post_training_quantization/442 is successful
---
 .../data/ptq_reference_data.yaml              |  22 ++--
 tests/post_training/model_scope.py            |  18 +--
 tests/post_training/pipelines/base.py         |   6 +
 .../pipelines/image_classification_base.py    |  80 ++++++++++++
 .../pipelines/image_classification_timm.py    |  64 +--------
 .../image_classification_torchvision.py       | 121 ++++++++++++++++++
 6 files changed, 231 insertions(+), 80 deletions(-)
 create mode 100644 tests/post_training/pipelines/image_classification_base.py
 create mode 100644 tests/post_training/pipelines/image_classification_torchvision.py

diff --git a/tests/post_training/data/ptq_reference_data.yaml b/tests/post_training/data/ptq_reference_data.yaml
index 746bd6027b6..490cb7e73da 100644
--- a/tests/post_training/data/ptq_reference_data.yaml
+++ b/tests/post_training/data/ptq_reference_data.yaml
@@ -22,6 +22,18 @@ hf/hf-internal-testing/tiny-random-gpt2_backend_OV:
   metric_value: null
 hf/hf-internal-testing/tiny-random-gpt2_backend_TORCH:
   metric_value: null
+torchvision/resnet18_backend_FP32:
+  metric_value: 0.6978
+torchvision/resnet18_backend_OV:
+  metric_value: 0.6948
+torchvision/resnet18_backend_ONNX:
+  metric_value: 0.6948
+torchvision/resnet18_backend_TORCH:
+  metric_value: 0.69152
+torchvision/resnet18_backend_CUDA_TORCH:
+  metric_value: 0.69152
+torchvision/resnet18_backend_FX_TORCH:
+  metric_value: 0.6946
 timm/crossvit_9_240_backend_CUDA_TORCH:
   metric_value: 0.689
 timm/crossvit_9_240_backend_FP32:
@@ -180,16 +192,6 @@ timm/resnest14d_backend_OV:
   metric_value: 0.75
 timm/resnest14d_backend_TORCH:
   metric_value: 0.7485
-timm/resnet18_backend_CUDA_TORCH:
-  metric_value: 0.69748
-timm/resnet18_backend_FP32:
-  metric_value: 0.71502
-timm/resnet18_backend_ONNX:
-  metric_value: 0.71102
-timm/resnet18_backend_OV:
-  metric_value: 0.71116
-timm/resnet18_backend_TORCH:
-  metric_value: 0.70982
 timm/swin_base_patch4_window7_224_backend_FP32:
   metric_value: 0.85274
 timm/swin_base_patch4_window7_224_backend_OV:
diff --git a/tests/post_training/model_scope.py b/tests/post_training/model_scope.py
index e4d2694e7d3..7f78d70528c 100644
--- a/tests/post_training/model_scope.py
+++ b/tests/post_training/model_scope.py
@@ -27,6 +27,7 @@
 from tests.post_training.pipelines.causal_language_model import CausalLMHF
 from tests.post_training.pipelines.gpt import GPT
 from tests.post_training.pipelines.image_classification_timm import ImageClassificationTimm
+from tests.post_training.pipelines.image_classification_torchvision import ImageClassificationTorchvision
 from tests.post_training.pipelines.lm_weight_compression import LMWeightCompression
 from tests.post_training.pipelines.masked_language_modeling import MaskedLanguageModelingHF
 
@@ -65,6 +66,15 @@
         },
         "backends": [BackendType.TORCH, BackendType.OV, BackendType.OPTIMUM],
     },
+    # Torchvision models
+    {
+        "reported_name": "torchvision/resnet18",
+        "model_id": "resnet18",
+        "pipeline_cls": ImageClassificationTorchvision,
+        "compression_params": {},
+        "backends": [BackendType.FX_TORCH, BackendType.TORCH, BackendType.CUDA_TORCH, BackendType.OV, BackendType.ONNX],
+        "batch_size": 128,
+    },
     # Timm models
     {
         "reported_name": "timm/crossvit_9_240",
@@ -246,14 +256,6 @@
         "backends": ALL_PTQ_BACKENDS,
         "batch_size": 128,
     },
-    {
-        "reported_name": "timm/resnet18",
-        "model_id": "resnet18",
-        "pipeline_cls": ImageClassificationTimm,
-        "compression_params": {},
-        "backends": ALL_PTQ_BACKENDS,
-        "batch_size": 128,
-    },
     {
         "reported_name": "timm/swin_base_patch4_window7_224",
         "model_id": "swin_base_patch4_window7_224",
diff --git a/tests/post_training/pipelines/base.py b/tests/post_training/pipelines/base.py
index b1f7b0a558b..c74569da3ee 100644
--- a/tests/post_training/pipelines/base.py
+++ b/tests/post_training/pipelines/base.py
@@ -38,6 +38,7 @@ class BackendType(Enum):
     FP32 = "FP32"
     TORCH = "TORCH"
     CUDA_TORCH = "CUDA_TORCH"
+    FX_TORCH = "FX_TORCH"
     ONNX = "ONNX"
     OV = "OV"
     OPTIMUM = "OPTIMUM"
@@ -367,6 +368,11 @@ def save_compressed_model(self) -> None:
             )
             self.path_compressed_ir = self.output_model_dir / "model.xml"
             ov.serialize(ov_model, self.path_compressed_ir)
+        elif self.backend == BackendType.FX_TORCH:
+            exported_model = torch.export.export(self.compressed_model, (self.dummy_tensor,))
+            ov_model = ov.convert_model(exported_model, example_input=self.dummy_tensor.cpu(), input=self.input_size)
+            self.path_compressed_ir = self.output_model_dir / "model.xml"
+            ov.serialize(ov_model, self.path_compressed_ir)
         elif self.backend == BackendType.ONNX:
             onnx_path = self.output_model_dir / "model.onnx"
             onnx.save(self.compressed_model, str(onnx_path))
diff --git a/tests/post_training/pipelines/image_classification_base.py b/tests/post_training/pipelines/image_classification_base.py
new file mode 100644
index 00000000000..22e60a5ae3b
--- /dev/null
+++ b/tests/post_training/pipelines/image_classification_base.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#      http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import os
+
+import numpy as np
+import openvino as ov
+import torch
+from sklearn.metrics import accuracy_score
+from torchvision import datasets
+
+import nncf
+from nncf.common.logging.track_progress import track
+from tests.post_training.pipelines.base import DEFAULT_VAL_THREADS
+from tests.post_training.pipelines.base import PTQTestPipeline
+
+
+class ImageClassificationBase(PTQTestPipeline):
+    """Base pipeline for Image Classification models"""
+
+    def prepare_calibration_dataset(self):
+        dataset = datasets.ImageFolder(root=self.data_dir / "imagenet" / "val", transform=self.transform)
+        loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=2, shuffle=False)
+
+        self.calibration_dataset = nncf.Dataset(loader, self.get_transform_calibration_fn())
+
+    def _validate(self):
+        val_dataset = datasets.ImageFolder(root=self.data_dir / "imagenet" / "val", transform=self.transform)
+        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=2, shuffle=False)
+
+        dataset_size = len(val_loader)
+
+        # Initialize result tensors for async inference support.
+        predictions = np.zeros((dataset_size))
+        references = -1 * np.ones((dataset_size))
+
+        core = ov.Core()
+
+        if os.environ.get("INFERENCE_NUM_THREADS"):
+            # Set the number of CPU inference threads for OpenVINO
+            inference_num_threads = os.environ.get("INFERENCE_NUM_THREADS")
+            core.set_property("CPU", properties={"INFERENCE_NUM_THREADS": str(inference_num_threads)})
+
+        ov_model = core.read_model(self.path_compressed_ir)
+        compiled_model = core.compile_model(ov_model, device_name="CPU")
+
+        jobs = int(os.environ.get("NUM_VAL_THREADS", DEFAULT_VAL_THREADS))
+        infer_queue = ov.AsyncInferQueue(compiled_model, jobs)
+
+        with track(total=dataset_size, description="Validation") as pbar:
+
+            def process_result(request, userdata):
+                output_data = request.get_output_tensor().data
+                predicted_label = np.argmax(output_data, axis=1)
+                predictions[userdata] = predicted_label
+                pbar.progress.update(pbar.task, advance=1)
+
+            infer_queue.set_callback(process_result)
+
+            for i, (images, target) in enumerate(val_loader):
+                # W/A for memory leaks when using torch DataLoader and OpenVINO
+                image_copies = copy.deepcopy(images.numpy())
+                infer_queue.start_async(image_copies, userdata=i)
+                references[i] = target
+
+            infer_queue.wait_all()
+
+        acc_top1 = accuracy_score(predictions, references)
+
+        self.run_info.metric_name = "Acc@1"
+        self.run_info.metric_value = acc_top1
diff --git a/tests/post_training/pipelines/image_classification_timm.py b/tests/post_training/pipelines/image_classification_timm.py
index 601ec01d28b..3f4e159331c 100644
--- a/tests/post_training/pipelines/image_classification_timm.py
+++ b/tests/post_training/pipelines/image_classification_timm.py
@@ -9,32 +9,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
-import os
-
 import numpy as np
 import onnx
 import openvino as ov
 import timm
 import torch
-from sklearn.metrics import accuracy_score
 from timm.data.transforms_factory import transforms_imagenet_eval
 from timm.layers.config import set_fused_attn
-from torchvision import datasets
 
-import nncf
-from nncf.common.logging.track_progress import track
-from tests.post_training.pipelines.base import DEFAULT_VAL_THREADS
 from tests.post_training.pipelines.base import OV_BACKENDS
 from tests.post_training.pipelines.base import PT_BACKENDS
 from tests.post_training.pipelines.base import BackendType
-from tests.post_training.pipelines.base import PTQTestPipeline
+from tests.post_training.pipelines.image_classification_base import ImageClassificationBase
 
 # Disable using aten::scaled_dot_product_attention
 set_fused_attn(False, False)
 
 
-class ImageClassificationTimm(PTQTestPipeline):
+class ImageClassificationTimm(ImageClassificationBase):
     """Pipeline for Image Classification model from timm repository"""
 
     def prepare_model(self) -> None:
@@ -113,55 +105,3 @@ def transform_fn(data_item):
             return {self.input_name: np.array(images, dtype=np.float32)}
 
         return transform_fn
-
-    def prepare_calibration_dataset(self):
-        dataset = datasets.ImageFolder(root=self.data_dir / "imagenet" / "val", transform=self.transform)
-        loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=2, shuffle=False)
-
-        self.calibration_dataset = nncf.Dataset(loader, self.get_transform_calibration_fn())
-
-    def _validate(self):
-        val_dataset = datasets.ImageFolder(root=self.data_dir / "imagenet" / "val", transform=self.transform)
-        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=2, shuffle=False)
-
-        dataset_size = len(val_loader)
-
-        # Initialize result tensors for async inference support.
-        predictions = np.zeros((dataset_size))
-        references = -1 * np.ones((dataset_size))
-
-        core = ov.Core()
-
-        if os.environ.get("INFERENCE_NUM_THREADS"):
-            # Set CPU_THREADS_NUM for OpenVINO inference
-            inference_num_threads = os.environ.get("INFERENCE_NUM_THREADS")
-            core.set_property("CPU", properties={"INFERENCE_NUM_THREADS": str(inference_num_threads)})
-
-        ov_model = core.read_model(self.path_compressed_ir)
-        compiled_model = core.compile_model(ov_model, device_name="CPU")
-
-        jobs = int(os.environ.get("NUM_VAL_THREADS", DEFAULT_VAL_THREADS))
-        infer_queue = ov.AsyncInferQueue(compiled_model, jobs)
-
-        with track(total=dataset_size, description="Validation") as pbar:
-
-            def process_result(request, userdata):
-                output_data = request.get_output_tensor().data
-                predicted_label = np.argmax(output_data, axis=1)
-                predictions[userdata] = predicted_label
-                pbar.progress.update(pbar.task, advance=1)
-
-            infer_queue.set_callback(process_result)
-
-            for i, (images, target) in enumerate(val_loader):
-                # W/A for memory leaks when using torch DataLoader and OpenVINO
-                image_copies = copy.deepcopy(images.numpy())
-                infer_queue.start_async(image_copies, userdata=i)
-                references[i] = target
-
-            infer_queue.wait_all()
-
-        acc_top1 = accuracy_score(predictions, references)
-
-        self.run_info.metric_name = "Acc@1"
-        self.run_info.metric_value = acc_top1
diff --git a/tests/post_training/pipelines/image_classification_torchvision.py b/tests/post_training/pipelines/image_classification_torchvision.py
new file mode 100644
index 00000000000..91e586605cb
--- /dev/null
+++ b/tests/post_training/pipelines/image_classification_torchvision.py
@@ -0,0 +1,121 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#      http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import onnx
+import openvino as ov
+import torch
+from torch._export import capture_pre_autograd_graph
+from torchvision import models
+
+from nncf.torch import disable_patching
+from tests.post_training.pipelines.base import PT_BACKENDS
+from tests.post_training.pipelines.base import BackendType
+from tests.post_training.pipelines.image_classification_base import ImageClassificationBase
+
+
+class ImageClassificationTorchvision(ImageClassificationBase):
+    """Pipeline for Image Classification model from torchvision repository"""
+
+    models_vs_imagenet_weights = {
+        models.resnet18: models.ResNet18_Weights.DEFAULT,
+        models.mobilenet_v3_small: models.MobileNet_V3_Small_Weights.DEFAULT,
+        models.vit_b_16: models.ViT_B_16_Weights.DEFAULT,
+        models.swin_v2_s: models.Swin_V2_S_Weights.DEFAULT,
+    }
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.model_weights: models.WeightsEnum = None
+        self.input_name: str = None
+
+    def prepare_model(self) -> None:
+        model_cls = models.__dict__.get(self.model_id)
+        self.model_weights = self.models_vs_imagenet_weights[model_cls]
+        model = model_cls(weights=self.model_weights)
+        model.eval()
+
+        self.static_input_size = [self.batch_size, 3, 224, 224]
+        self.input_size = self.static_input_size.copy()
+        if self.batch_size > 1:  # Dynamic batch_size shape export
+            self.input_size[0] = -1
+
+        self.dummy_tensor = torch.rand(self.static_input_size)
+
+        if self.backend == BackendType.FX_TORCH:
+            with torch.no_grad():
+                with disable_patching():
+                    self.model = capture_pre_autograd_graph(model, (self.dummy_tensor,))
+
+        elif self.backend in PT_BACKENDS:
+            self.model = model
+
+        if self.backend == BackendType.ONNX:
+            onnx_path = self.fp32_model_dir / "model_fp32.onnx"
+            additional_kwargs = {}
+            if self.batch_size > 1:
+                additional_kwargs["input_names"] = ["image"]
+                additional_kwargs["dynamic_axes"] = {"image": {0: "batch"}}
+            torch.onnx.export(
+                model, self.dummy_tensor, onnx_path, export_params=True, opset_version=13, **additional_kwargs
+            )
+            self.model = onnx.load(onnx_path)
+            self.input_name = self.model.graph.input[0].name
+
+        elif self.backend in [BackendType.OV, BackendType.FP32]:
+            with torch.no_grad():
+                self.model = ov.convert_model(model, example_input=self.dummy_tensor, input=self.input_size)
+            self.input_name = list(inp.get_any_name() for inp in self.model.inputs)[0]
+
+        self._dump_model_fp32()
+
+        # Set device after dump fp32 model
+        if self.backend == BackendType.CUDA_TORCH:
+            self.model.cuda()
+            self.dummy_tensor = self.dummy_tensor.cuda()
+
+    def _dump_model_fp32(self) -> None:
+        """Dump IRs of fp32 models, to help debugging."""
+        if self.backend in PT_BACKENDS:
+            with disable_patching():
+                ov_model = ov.convert_model(
+                    torch.export.export(self.model, args=(self.dummy_tensor,)),
+                    example_input=self.dummy_tensor,
+                    input=self.input_size,
+                )
+                ov.serialize(ov_model, self.fp32_model_dir / "model_fp32.xml")
+
+        if self.backend == BackendType.FX_TORCH:
+            exported_model = torch.export.export(self.model, (self.dummy_tensor,))
+            ov_model = ov.convert_model(exported_model, example_input=self.dummy_tensor, input=self.input_size)
+            ov.serialize(ov_model, self.fp32_model_dir / "fx_model_fp32.xml")
+
+        if self.backend in [BackendType.FP32, BackendType.OV]:
+            ov.serialize(self.model, self.fp32_model_dir / "model_fp32.xml")
+
+    def prepare_preprocessor(self) -> None:
+        self.transform = self.model_weights.transforms()
+
+    def get_transform_calibration_fn(self):
+        if self.backend in [BackendType.FX_TORCH] + PT_BACKENDS:
+            device = torch.device("cuda" if self.backend == BackendType.CUDA_TORCH else "cpu")
+
+            def transform_fn(data_item):
+                images, _ = data_item
+                return images.to(device)
+
+        else:
+
+            def transform_fn(data_item):
+                images, _ = data_item
+                return {self.input_name: np.array(images, dtype=np.float32)}
+
+        return transform_fn
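
### Notes

For review, the FX_TORCH flow exercised by this patch reduces to the sketch below. It is a minimal illustration assembled from the calls that appear in this diff (`capture_pre_autograd_graph`, `torch.export.export`, `ov.convert_model`); the synthetic calibration loader and the `nncf.quantize` call are assumptions about how the pipeline runner drives quantization, not part of this diff.

```python
import openvino as ov
import torch
from torch._export import capture_pre_autograd_graph
from torchvision import models

import nncf
from nncf.torch import disable_patching

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).eval()
dummy = torch.rand(1, 3, 224, 224)

# Stand-in for the ImageNet DataLoader used by the real pipeline.
loader = [(torch.rand(1, 3, 224, 224), 0) for _ in range(4)]
calibration_dataset = nncf.Dataset(loader, lambda item: item[0])

# 1. Capture the FX graph before autograd, as prepare_model() does for FX_TORCH.
with torch.no_grad():
    with disable_patching():
        fx_model = capture_pre_autograd_graph(model, (dummy,))

# 2. Quantize the captured graph (assumed entry point; the actual call lives in
#    the shared pipeline runner, outside this diff).
quantized_model = nncf.quantize(fx_model, calibration_dataset)

# 3. Export and convert to OpenVINO IR, mirroring save_compressed_model().
exported_model = torch.export.export(quantized_model, (dummy,))
ov_model = ov.convert_model(exported_model, example_input=dummy)
ov.serialize(ov_model, "model.xml")
```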