Skip to content

Commit

Permalink
moved pre quantize validate step to earlier in the script
Browse files Browse the repository at this point in the history
  • Loading branch information
costigt-dev committed May 1, 2024
1 parent c0f596b commit 28e8868
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,10 @@ def main():
val_loader = generate_dataloader_with_transform(
args.validation_dir, args.batch_size_validation, args.workers, transform)

if args.validate_before_quantize is True:
print("Starting validation of unquantized model")
validate(val_loader, model, stable=dtype != torch.bfloat16)

# Preprocess the model for quantization
if args.target_backend == 'flexml':
# flexml requires static shapes, pass a representative input in
Expand Down Expand Up @@ -480,10 +484,6 @@ def main():
print("Starting validation:")
validate(val_loader, quant_model, stable=dtype != torch.bfloat16)

if args.validate_before_quantize == True:
print("Starting validation of unquantized model")
validate(val_loader, model, stable=dtype != torch.bfloat16)

if args.export_onnx_qcdq or args.export_torch_qcdq:
# Generate reference input tensor to drive the export process
model_config = get_model_config(args.model_name)
Expand Down

0 comments on commit 28e8868

Please sign in to comment.