Skip to content

Commit

Permalink
feat(examples): don't use mps for float8 dtype.
Browse files Browse the repository at this point in the history
  • Loading branch information
shovan777 authored and dacorvo committed Sep 17, 2024
1 parent ffae30c commit ec1f85e
Showing 1 changed file with 2 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def test(model, device, test_loader):
data, target = batch["pixel_values"], batch["labels"]
data, target = data.to(device), target.to(device)
output = model(data).logits
# print("*****I am after output", output)
if isinstance(output, QTensor):
output = output.dequantize()
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
Expand Down Expand Up @@ -66,17 +65,15 @@ def main():
parser.add_argument("--activations", type=str, default="int8", choices=["none", "int8", "float8"])
args = parser.parse_args()

# torch.manual_seed(args.seed)

dataset_kwargs = {}

if args.device is None:
if torch.cuda.is_available():
device = torch.device("cuda")
cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
dataset_kwargs.update(cuda_kwargs)
elif torch.backends.mps.is_available():
device = torch.device("cpu")
elif all([torch.backends.mps.is_available(), args.weights != "float8", args.activations != "float8"]):
device = torch.device("mps")
else:
device = torch.device("cpu")
else:
Expand Down

0 comments on commit ec1f85e

Please sign in to comment.