diff --git a/examples/vision/image-classification/pets/quantize_vit_model.py b/examples/vision/image-classification/pets/quantize_vit_model.py index 42ad3b70..ad0195e4 100644 --- a/examples/vision/image-classification/pets/quantize_vit_model.py +++ b/examples/vision/image-classification/pets/quantize_vit_model.py @@ -37,7 +37,6 @@ def test(model, device, test_loader): data, target = batch["pixel_values"], batch["labels"] data, target = data.to(device), target.to(device) output = model(data).logits - # print("*****I am after output", output) if isinstance(output, QTensor): output = output.dequantize() test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss @@ -66,8 +65,6 @@ def main(): parser.add_argument("--activations", type=str, default="int8", choices=["none", "int8", "float8"]) args = parser.parse_args() - # torch.manual_seed(args.seed) - dataset_kwargs = {} if args.device is None: @@ -75,8 +72,8 @@ def main(): device = torch.device("cuda") cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True} dataset_kwargs.update(cuda_kwargs) - elif torch.backends.mps.is_available(): - device = torch.device("cpu") + elif all([torch.backends.mps.is_available(), args.weights != "float8", args.activations != "float8"]): + device = torch.device("mps") else: device = torch.device("cpu") else: