From 28e8868c43fddd969c516b5a65fdbba8b17a6854 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Wed, 1 May 2024 11:15:02 +0100 Subject: [PATCH] Move pre-quantization validation step to earlier in the script --- .../imagenet_classification/ptq/ptq_evaluate.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 87a687cf3..3d35a49dc 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -380,6 +380,10 @@ def main(): val_loader = generate_dataloader_with_transform( args.validation_dir, args.batch_size_validation, args.workers, transform) + if args.validate_before_quantize is True: + print("Starting validation of unquantized model") + validate(val_loader, model, stable=dtype != torch.bfloat16) + # Preprocess the model for quantization if args.target_backend == 'flexml': # flexml requires static shapes, pass a representative input in @@ -480,10 +484,6 @@ def main(): print("Starting validation:") validate(val_loader, quant_model, stable=dtype != torch.bfloat16) - if args.validate_before_quantize == True: - print("Starting validation of unquantized model") - validate(val_loader, model, stable=dtype != torch.bfloat16) - if args.export_onnx_qcdq or args.export_torch_qcdq: # Generate reference input tensor to drive the export process model_config = get_model_config(args.model_name)