# Added BF16 & FP16 models to PTQ tests (#2922)
### Changes

- Added BF16/FP16 models to the PTQ test scope.

### Reason for changes

- Extend the e2e tests to cover models with different base precisions.

### Related tickets

- 147481

### Tests

- manual post_training_quantization/562/ - passed
KodiaqQ authored Nov 28, 2024
1 parent 76e0ffc commit 9c5220a
Showing 3 changed files with 51 additions and 8 deletions.
tests/post_training/data/ptq_reference_data.yaml (8 additions, 0 deletions)
@@ -26,6 +26,14 @@ hf/hf-internal-testing/tiny-random-gpt2_backend_OV:
   metric_value: null
 hf/hf-internal-testing/tiny-random-gpt2_backend_TORCH:
   metric_value: null
+hf/bert-base-uncased_fp16_backend_FP32:
+  metric_value: null
+hf/bert-base-uncased_fp16_backend_OV:
+  metric_value: null
+hf/bert-base-uncased_bf16_backend_FP32:
+  metric_value: null
+hf/bert-base-uncased_bf16_backend_OV:
+  metric_value: null
 torchvision/resnet18_backend_FP32:
   metric_value: 0.6978
 torchvision/resnet18_backend_OV:
tests/post_training/model_scope.py (26 additions, 0 deletions)
@@ -12,6 +12,8 @@
 import copy
 from typing import Dict, List
 
+import torch
+
 import nncf
 from nncf import ModelType
 from nncf import QuantizationPreset
@@ -81,6 +83,30 @@
         },
         "backends": [BackendType.TORCH, BackendType.OV, BackendType.OPTIMUM],
     },
+    {
+        "reported_name": "hf/bert-base-uncased_fp16",
+        "model_id": "bert-base-uncased",
+        "pipeline_cls": MaskedLanguageModelingHF,
+        "compression_params": {
+            "preset": QuantizationPreset.MIXED,
+            "model_type": ModelType.TRANSFORMER,
+            "subset_size": 2,
+        },
+        "backends": [BackendType.OV],
+        "params": {"base_precision": torch.float16},
+    },
+    {
+        "reported_name": "hf/bert-base-uncased_bf16",
+        "model_id": "bert-base-uncased",
+        "pipeline_cls": MaskedLanguageModelingHF,
+        "compression_params": {
+            "preset": QuantizationPreset.MIXED,
+            "model_type": ModelType.TRANSFORMER,
+            "subset_size": 2,
+        },
+        "backends": [BackendType.OV],
+        "params": {"base_precision": torch.bfloat16},
+    },
     # Torchvision models
     {
         "reported_name": "torchvision/resnet18",
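For context, here is a minimal standalone sketch (not the test pipeline itself) of how a scope entry's optional `params["base_precision"]` is consumed, mirroring the `self.params.get(...)` fallback in the masked_language_modeling.py diff below; the `entry` dict is a trimmed, hypothetical copy of a scope entry.

```python
# Minimal sketch: consume a model-scope entry's "params" the way the pipeline does.
import torch
import transformers

entry = {
    "model_id": "bert-base-uncased",
    "params": {"base_precision": torch.bfloat16},
}

# Fall back to FP32 when no base precision is requested, mirroring
# self.params.get("base_precision", torch.float32) in the pipeline.
torch_dtype = entry.get("params", {}).get("base_precision", torch.float32)

model = transformers.AutoModelForSequenceClassification.from_pretrained(
    entry["model_id"], torch_dtype=torch_dtype
)
print(next(model.parameters()).dtype)  # expected: torch.bfloat16
```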
tests/post_training/pipelines/masked_language_modeling.py (17 additions, 8 deletions)
@@ -29,14 +29,26 @@ class MaskedLanguageModelingHF(PTQTestPipeline):
     """Pipeline for masked language models from Hugging Face repository"""
 
     def prepare_model(self) -> None:
+        torch_dtype = self.params.get("base_precision", torch.float32)
         if self.backend in PT_BACKENDS:
-            self.model_hf = transformers.AutoModelForSequenceClassification.from_pretrained(self.model_id)
+            self.model_hf = transformers.AutoModelForSequenceClassification.from_pretrained(
+                self.model_id, torch_dtype=torch_dtype
+            )
             self.model = self.model_hf
             self.model.config.torchscript = True  # Set to export by convert_model via torch.jit.trace
             self.dummy_tensor = self.model_hf.dummy_inputs["input_ids"]
         if self.backend in OV_BACKENDS + [BackendType.FP32]:
-            self.model_hf = OVModelForSequenceClassification.from_pretrained(self.model_id, export=True, compile=False)
-            self.model = self.model_hf.model
+            if torch_dtype != torch.float32:
+                # Workaround: optimum-intel does not export models in custom precisions, so convert the HF model with ov.convert_model.
+                self.model_hf = transformers.AutoModelForSequenceClassification.from_pretrained(
+                    self.model_id, torch_dtype=torch_dtype
+                )
+                self.model = ov.convert_model(self.model_hf, example_input=self.model_hf.dummy_inputs)
+            else:
+                self.model_hf = OVModelForSequenceClassification.from_pretrained(
+                    self.model_id, export=True, compile=False
+                )
+                self.model = self.model_hf.model
 
         if self.backend == BackendType.ONNX:
             self.model_hf = ORTModelForSequenceClassification.from_pretrained(self.model_id, export=True)
@@ -75,13 +87,10 @@ def transform_func(data):
                 return torch.tensor([data["input_ids"]]).type(dtype=torch.LongTensor).to(device)
 
         else:
+            input_names = [p.get_friendly_name() for p in self.model.get_parameters()]
 
             def transform_func(data):
-                return {
-                    "input_ids": np.expand_dims(data["input_ids"], axis=0),
-                    "token_type_ids": np.expand_dims(data["token_type_ids"], axis=0),
-                    "attention_mask": np.expand_dims(data["attention_mask"], axis=0),
-                }
+                return {n: np.expand_dims(data[n], axis=0) for n in input_names}
 
         return transform_func
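For context, a minimal standalone sketch of the generalized transform above, assuming an already-converted OpenVINO model on disk (`model.xml` is a hypothetical path): the calibration inputs are keyed by whatever parameters the model declares instead of a hard-coded BERT input list.

```python
# Minimal sketch (standalone example, not the pipeline code): build the
# calibration transform from the inputs the OpenVINO model actually declares.
import numpy as np
import openvino as ov

model = ov.Core().read_model("model.xml")  # hypothetical path to a converted model

# Friendly names of the model inputs, e.g. input_ids / token_type_ids / attention_mask.
input_names = [p.get_friendly_name() for p in model.get_parameters()]

def transform_func(data):
    # Add a batch dimension to every declared input.
    return {n: np.expand_dims(data[n], axis=0) for n in input_names}

# Usage with one tokenized sample (dummy values here).
sample = {n: np.zeros(128, dtype=np.int64) for n in input_names}
print({k: v.shape for k, v in transform_func(sample).items()})
```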
