diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 87a687cf3..3d35a49dc 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -380,6 +380,10 @@ def main(): val_loader = generate_dataloader_with_transform( args.validation_dir, args.batch_size_validation, args.workers, transform) + if args.validate_before_quantize is True: + print("Starting validation of unquantized model") + validate(val_loader, model, stable=dtype != torch.bfloat16) + # Preprocess the model for quantization if args.target_backend == 'flexml': # flexml requires static shapes, pass a representative input in @@ -480,10 +484,6 @@ def main(): print("Starting validation:") validate(val_loader, quant_model, stable=dtype != torch.bfloat16) - if args.validate_before_quantize == True: - print("Starting validation of unquantized model") - validate(val_loader, model, stable=dtype != torch.bfloat16) - if args.export_onnx_qcdq or args.export_torch_qcdq: # Generate reference input tensor to drive the export process model_config = get_model_config(args.model_name)