Skip to content

Commit

Permalink
[TorchFX] Conformance test init (#2841)
Browse files Browse the repository at this point in the history
### Changes

Conformance test for resnet18

### Reason for changes

To extend testing scope for the TorchFX backend

### Related tickets

#2766

### Tests

post_training_quantization/442 is successful
  • Loading branch information
daniil-lyakhov authored Aug 7, 2024
1 parent b108455 commit eb91af2
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 80 deletions.
22 changes: 12 additions & 10 deletions tests/post_training/data/ptq_reference_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ hf/hf-internal-testing/tiny-random-gpt2_backend_OV:
metric_value: null
hf/hf-internal-testing/tiny-random-gpt2_backend_TORCH:
metric_value: null
torchvision/resnet18_backend_FP32:
metric_value: 0.6978
torchvision/resnet18_backend_OV:
metric_value: 0.6948
torchvision/resnet18_backend_ONNX:
metric_value: 0.6948
torchvision/resnet18_backend_TORCH:
metric_value: 0.69152
torchvision/resnet18_backend_CUDA_TORCH:
metric_value: 0.69152
torchvision/resnet18_backend_FX_TORCH:
metric_value: 0.6946
timm/crossvit_9_240_backend_CUDA_TORCH:
metric_value: 0.689
timm/crossvit_9_240_backend_FP32:
Expand Down Expand Up @@ -180,16 +192,6 @@ timm/resnest14d_backend_OV:
metric_value: 0.75
timm/resnest14d_backend_TORCH:
metric_value: 0.7485
timm/resnet18_backend_CUDA_TORCH:
metric_value: 0.69748
timm/resnet18_backend_FP32:
metric_value: 0.71502
timm/resnet18_backend_ONNX:
metric_value: 0.71102
timm/resnet18_backend_OV:
metric_value: 0.71116
timm/resnet18_backend_TORCH:
metric_value: 0.70982
timm/swin_base_patch4_window7_224_backend_FP32:
metric_value: 0.85274
timm/swin_base_patch4_window7_224_backend_OV:
Expand Down
18 changes: 10 additions & 8 deletions tests/post_training/model_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tests.post_training.pipelines.causal_language_model import CausalLMHF
from tests.post_training.pipelines.gpt import GPT
from tests.post_training.pipelines.image_classification_timm import ImageClassificationTimm
from tests.post_training.pipelines.image_classification_torchvision import ImageClassificationTorchvision
from tests.post_training.pipelines.lm_weight_compression import LMWeightCompression
from tests.post_training.pipelines.masked_language_modeling import MaskedLanguageModelingHF

Expand Down Expand Up @@ -65,6 +66,15 @@
},
"backends": [BackendType.TORCH, BackendType.OV, BackendType.OPTIMUM],
},
# Torchvision models
{
"reported_name": "torchvision/resnet18",
"model_id": "resnet18",
"pipeline_cls": ImageClassificationTorchvision,
"compression_params": {},
"backends": [BackendType.FX_TORCH, BackendType.TORCH, BackendType.CUDA_TORCH, BackendType.OV, BackendType.ONNX],
"batch_size": 128,
},
# Timm models
{
"reported_name": "timm/crossvit_9_240",
Expand Down Expand Up @@ -246,14 +256,6 @@
"backends": ALL_PTQ_BACKENDS,
"batch_size": 128,
},
{
"reported_name": "timm/resnet18",
"model_id": "resnet18",
"pipeline_cls": ImageClassificationTimm,
"compression_params": {},
"backends": ALL_PTQ_BACKENDS,
"batch_size": 128,
},
{
"reported_name": "timm/swin_base_patch4_window7_224",
"model_id": "swin_base_patch4_window7_224",
Expand Down
6 changes: 6 additions & 0 deletions tests/post_training/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class BackendType(Enum):
FP32 = "FP32"
TORCH = "TORCH"
CUDA_TORCH = "CUDA_TORCH"
FX_TORCH = "FX_TORCH"
ONNX = "ONNX"
OV = "OV"
OPTIMUM = "OPTIMUM"
Expand Down Expand Up @@ -367,6 +368,11 @@ def save_compressed_model(self) -> None:
)
self.path_compressed_ir = self.output_model_dir / "model.xml"
ov.serialize(ov_model, self.path_compressed_ir)
elif self.backend == BackendType.FX_TORCH:
exported_model = torch.export.export(self.compressed_model, (self.dummy_tensor,))
ov_model = ov.convert_model(exported_model, example_input=self.dummy_tensor.cpu(), input=self.input_size)
self.path_compressed_ir = self.output_model_dir / "model.xml"
ov.serialize(ov_model, self.path_compressed_ir)
elif self.backend == BackendType.ONNX:
onnx_path = self.output_model_dir / "model.onnx"
onnx.save(self.compressed_model, str(onnx_path))
Expand Down
80 changes: 80 additions & 0 deletions tests/post_training/pipelines/image_classification_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os

import numpy as np
import openvino as ov
import torch
from sklearn.metrics import accuracy_score
from torchvision import datasets

import nncf
from nncf.common.logging.track_progress import track
from tests.post_training.pipelines.base import DEFAULT_VAL_THREADS
from tests.post_training.pipelines.base import PTQTestPipeline


class ImageClassificationBase(PTQTestPipeline):
    """Base pipeline for Image Classification models.

    Implements the two steps shared by the image classification pipelines:
    building the calibration dataset and measuring top-1 accuracy of the
    compressed OpenVINO IR. Both steps read images from
    ``self.data_dir / "imagenet" / "val"``.

    The class reads ``self.transform`` and calls
    ``self.get_transform_calibration_fn()``; neither is defined here
    (presumably they are supplied by the concrete subclasses).
    """

    def prepare_calibration_dataset(self):
        """Create ``self.calibration_dataset`` from the ``imagenet/val`` image folder.

        Images are preprocessed with ``self.transform``, batched by
        ``self.batch_size`` and wrapped in ``nncf.Dataset`` together with the
        function returned by ``self.get_transform_calibration_fn()``.
        """
        dataset = datasets.ImageFolder(root=self.data_dir / "imagenet" / "val", transform=self.transform)
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=2, shuffle=False)

        self.calibration_dataset = nncf.Dataset(loader, self.get_transform_calibration_fn())

    def _validate(self):
        """Compute top-1 accuracy of the compressed IR and store it in ``self.run_info``.

        The model at ``self.path_compressed_ir`` is compiled for the OpenVINO
        CPU device and fed one image per asynchronous request.

        Environment variables:
            INFERENCE_NUM_THREADS: if set and non-empty, forwarded to the
                OpenVINO CPU plugin property of the same name.
            NUM_VAL_THREADS: number of parallel infer requests
                (falls back to ``DEFAULT_VAL_THREADS``).
        """
        val_dataset = datasets.ImageFolder(root=self.data_dir / "imagenet" / "val", transform=self.transform)
        # batch_size=1 -> each loader item (and each infer request) holds exactly one image.
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=2, shuffle=False)

        dataset_size = len(val_loader)

        # Result buffers are pre-allocated because the async callbacks fill them by index.
        predictions = np.zeros(dataset_size)
        references = np.full(dataset_size, -1.0)

        core = ov.Core()

        # Read the variable once; an unset or empty value leaves the plugin defaults untouched.
        inference_num_threads = os.environ.get("INFERENCE_NUM_THREADS")
        if inference_num_threads:
            core.set_property("CPU", properties={"INFERENCE_NUM_THREADS": inference_num_threads})

        ov_model = core.read_model(self.path_compressed_ir)
        compiled_model = core.compile_model(ov_model, device_name="CPU")

        jobs = int(os.environ.get("NUM_VAL_THREADS", DEFAULT_VAL_THREADS))
        infer_queue = ov.AsyncInferQueue(compiled_model, jobs)

        with track(total=dataset_size, description="Validation") as pbar:

            def process_result(request, userdata):
                # userdata is the sample index passed to start_async below.
                output_data = request.get_output_tensor().data
                predicted_label = np.argmax(output_data, axis=1)
                # Extract the scalar explicitly: assigning a size-1 array to a single
                # element relies on a conversion NumPy has deprecated (1.25+).
                predictions[userdata] = predicted_label.item()
                pbar.progress.update(pbar.task, advance=1)

            infer_queue.set_callback(process_result)

            for i, (images, target) in enumerate(val_loader):
                # W/A for memory leaks when using torch DataLoader and OpenVINO
                image_copies = copy.deepcopy(images.numpy())
                infer_queue.start_async(image_copies, userdata=i)
                references[i] = target.item()

            # Block until every pending request has run its callback.
            infer_queue.wait_all()

        # accuracy_score expects (y_true, y_pred).
        acc_top1 = accuracy_score(references, predictions)

        self.run_info.metric_name = "Acc@1"
        self.run_info.metric_value = acc_top1
64 changes: 2 additions & 62 deletions tests/post_training/pipelines/image_classification_timm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os

import numpy as np
import onnx
import openvino as ov
import timm
import torch
from sklearn.metrics import accuracy_score
from timm.data.transforms_factory import transforms_imagenet_eval
from timm.layers.config import set_fused_attn
from torchvision import datasets

import nncf
from nncf.common.logging.track_progress import track
from tests.post_training.pipelines.base import DEFAULT_VAL_THREADS
from tests.post_training.pipelines.base import OV_BACKENDS
from tests.post_training.pipelines.base import PT_BACKENDS
from tests.post_training.pipelines.base import BackendType
from tests.post_training.pipelines.base import PTQTestPipeline
from tests.post_training.pipelines.image_classification_base import ImageClassificationBase

# Disable using aten::scaled_dot_product_attention
set_fused_attn(False, False)


class ImageClassificationTimm(PTQTestPipeline):
class ImageClassificationTimm(ImageClassificationBase):
"""Pipeline for Image Classification model from timm repository"""

def prepare_model(self) -> None:
Expand Down Expand Up @@ -113,55 +105,3 @@ def transform_fn(data_item):
return {self.input_name: np.array(images, dtype=np.float32)}

return transform_fn

def prepare_calibration_dataset(self):
    """Create ``self.calibration_dataset`` from the ``imagenet/val`` image folder.

    Images are preprocessed with ``self.transform``, batched by
    ``self.batch_size`` and wrapped in ``nncf.Dataset`` together with the
    function returned by ``self.get_transform_calibration_fn()``.
    """
    val_dir = self.data_dir / "imagenet" / "val"
    image_folder = datasets.ImageFolder(root=val_dir, transform=self.transform)
    batches = torch.utils.data.DataLoader(
        image_folder,
        batch_size=self.batch_size,
        num_workers=2,
        shuffle=False,
    )
    to_model_input = self.get_transform_calibration_fn()
    self.calibration_dataset = nncf.Dataset(batches, to_model_input)

def _validate(self):
    """Compute top-1 accuracy of the compressed IR and store it in ``self.run_info``.

    The model at ``self.path_compressed_ir`` is compiled for the OpenVINO
    CPU device and fed one image per asynchronous request.

    Environment variables:
        INFERENCE_NUM_THREADS: if set and non-empty, forwarded to the
            OpenVINO CPU plugin property of the same name.
        NUM_VAL_THREADS: number of parallel infer requests
            (falls back to ``DEFAULT_VAL_THREADS``).
    """
    val_dataset = datasets.ImageFolder(root=self.data_dir / "imagenet" / "val", transform=self.transform)
    # batch_size=1 -> each loader item (and each infer request) holds exactly one image.
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=2, shuffle=False)

    dataset_size = len(val_loader)

    # Result buffers are pre-allocated because the async callbacks fill them by index.
    predictions = np.zeros(dataset_size)
    references = np.full(dataset_size, -1.0)

    core = ov.Core()

    # Read the variable once; an unset or empty value leaves the plugin defaults untouched.
    inference_num_threads = os.environ.get("INFERENCE_NUM_THREADS")
    if inference_num_threads:
        core.set_property("CPU", properties={"INFERENCE_NUM_THREADS": inference_num_threads})

    ov_model = core.read_model(self.path_compressed_ir)
    compiled_model = core.compile_model(ov_model, device_name="CPU")

    jobs = int(os.environ.get("NUM_VAL_THREADS", DEFAULT_VAL_THREADS))
    infer_queue = ov.AsyncInferQueue(compiled_model, jobs)

    with track(total=dataset_size, description="Validation") as pbar:

        def process_result(request, userdata):
            # userdata is the sample index passed to start_async below.
            output_data = request.get_output_tensor().data
            predicted_label = np.argmax(output_data, axis=1)
            # Extract the scalar explicitly: assigning a size-1 array to a single
            # element relies on a conversion NumPy has deprecated (1.25+).
            predictions[userdata] = predicted_label.item()
            pbar.progress.update(pbar.task, advance=1)

        infer_queue.set_callback(process_result)

        for i, (images, target) in enumerate(val_loader):
            # W/A for memory leaks when using torch DataLoader and OpenVINO
            image_copies = copy.deepcopy(images.numpy())
            infer_queue.start_async(image_copies, userdata=i)
            references[i] = target.item()

        # Block until every pending request has run its callback.
        infer_queue.wait_all()

    # accuracy_score expects (y_true, y_pred).
    acc_top1 = accuracy_score(references, predictions)

    self.run_info.metric_name = "Acc@1"
    self.run_info.metric_value = acc_top1
121 changes: 121 additions & 0 deletions tests/post_training/pipelines/image_classification_torchvision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import onnx
import openvino as ov
import torch
from torch._export import capture_pre_autograd_graph
from torchvision import models

from nncf.torch import disable_patching
from tests.post_training.pipelines.base import PT_BACKENDS
from tests.post_training.pipelines.base import BackendType
from tests.post_training.pipelines.image_classification_base import ImageClassificationBase


class ImageClassificationTorchvision(ImageClassificationBase):
    """Pipeline for Image Classification model from torchvision repository"""

    # Supported torchvision model builders and the pretrained weights loaded for each.
    models_vs_imagenet_weights = {
        models.resnet18: models.ResNet18_Weights.DEFAULT,
        models.mobilenet_v3_small: models.MobileNet_V3_Small_Weights.DEFAULT,
        models.vit_b_16: models.ViT_B_16_Weights.DEFAULT,
        models.swin_v2_s: models.Swin_V2_S_Weights.DEFAULT,
    }

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base pipeline and reset model-specific state."""
        super().__init__(*args, **kwargs)
        # Both attributes are populated by prepare_model().
        self.model_weights: models.WeightsEnum = None
        self.input_name: str = None

    def prepare_model(self) -> None:
        """Instantiate the torchvision model and convert it for ``self.backend``.

        Sets ``self.model_weights``, ``self.static_input_size``,
        ``self.input_size``, ``self.dummy_tensor`` and ``self.model``
        (plus ``self.input_name`` for the ONNX/OV/FP32 backends), then dumps the
        FP32 model via ``_dump_model_fp32``.

        :raises KeyError: if ``self.model_id`` does not name a model listed in
            ``models_vs_imagenet_weights``.
        """
        model_cls = models.__dict__.get(self.model_id)
        try:
            self.model_weights = self.models_vs_imagenet_weights[model_cls]
        except KeyError:
            # A missing name resolves to None above; report the requested id instead of "KeyError: None".
            supported = ", ".join(sorted(builder.__name__ for builder in self.models_vs_imagenet_weights))
            raise KeyError(
                f"Unsupported torchvision model '{self.model_id}'; supported models: {supported}"
            ) from None
        model = model_cls(weights=self.model_weights)
        model.eval()

        self.static_input_size = [self.batch_size, 3, 224, 224]
        self.input_size = self.static_input_size.copy()
        if self.batch_size > 1:  # Dynamic batch_size shape export
            self.input_size[0] = -1

        self.dummy_tensor = torch.rand(self.static_input_size)

        if self.backend == BackendType.FX_TORCH:
            with torch.no_grad(), disable_patching():
                self.model = capture_pre_autograd_graph(model, (self.dummy_tensor,))

        elif self.backend in PT_BACKENDS:
            self.model = model

        if self.backend == BackendType.ONNX:
            onnx_path = self.fp32_model_dir / "model_fp32.onnx"
            additional_kwargs = {}
            if self.batch_size > 1:
                # Name the input so that its batch axis can be declared dynamic.
                additional_kwargs["input_names"] = ["image"]
                additional_kwargs["dynamic_axes"] = {"image": {0: "batch"}}
            torch.onnx.export(
                model, self.dummy_tensor, onnx_path, export_params=True, opset_version=13, **additional_kwargs
            )
            self.model = onnx.load(onnx_path)
            self.input_name = self.model.graph.input[0].name

        elif self.backend in [BackendType.OV, BackendType.FP32]:
            with torch.no_grad():
                self.model = ov.convert_model(model, example_input=self.dummy_tensor, input=self.input_size)
            self.input_name = self.model.inputs[0].get_any_name()

        self._dump_model_fp32()

        # Set device after dump fp32 model
        if self.backend == BackendType.CUDA_TORCH:
            self.model.cuda()
            self.dummy_tensor = self.dummy_tensor.cuda()

    def _dump_model_fp32(self) -> None:
        """Dump IRs of fp32 models, to help debugging."""
        if self.backend in PT_BACKENDS:
            with disable_patching():
                ov_model = ov.convert_model(
                    torch.export.export(self.model, args=(self.dummy_tensor,)),
                    example_input=self.dummy_tensor,
                    input=self.input_size,
                )
            ov.serialize(ov_model, self.fp32_model_dir / "model_fp32.xml")

        if self.backend == BackendType.FX_TORCH:
            exported_model = torch.export.export(self.model, (self.dummy_tensor,))
            ov_model = ov.convert_model(exported_model, example_input=self.dummy_tensor, input=self.input_size)
            ov.serialize(ov_model, self.fp32_model_dir / "fx_model_fp32.xml")

        if self.backend in [BackendType.FP32, BackendType.OV]:
            # For these backends self.model is already the converted OpenVINO model.
            ov.serialize(self.model, self.fp32_model_dir / "model_fp32.xml")

    def prepare_preprocessor(self) -> None:
        """Use the preprocessing transforms bundled with the selected weights."""
        self.transform = self.model_weights.transforms()

    def get_transform_calibration_fn(self):
        """Return a function that converts a ``(images, target)`` item into model input.

        For the FX_TORCH and ``PT_BACKENDS`` backends the images are moved to the
        target device ("cuda" for CUDA_TORCH, "cpu" otherwise). For all other
        backends the images are returned as a float32 array in a dict keyed by
        ``self.input_name``.
        """
        if self.backend in [BackendType.FX_TORCH] + PT_BACKENDS:
            device = torch.device("cuda" if self.backend == BackendType.CUDA_TORCH else "cpu")

            def transform_fn(data_item):
                images, _ = data_item
                return images.to(device)

        else:

            def transform_fn(data_item):
                images, _ = data_item
                return {self.input_name: np.array(images, dtype=np.float32)}

        return transform_fn

0 comments on commit eb91af2

Please sign in to comment.