From cb9acd4073e6772c77465261db965cd02dd58244 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 24 Sep 2024 14:28:54 -0600 Subject: [PATCH] Add a comment --- array_api_compat/torch/_info.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py index a85e684e..264caa9e 100644 --- a/array_api_compat/torch/_info.py +++ b/array_api_compat/torch/_info.py @@ -147,6 +147,10 @@ def default_dtypes(self, *, device=None): 'indexing': torch.int64} """ + # Note: if the default is set to float64, the devices like MPS that + # don't support float64 will error. We still return the default_dtype + # value here because this error doesn't represent a different default + # per-device. default_floating = torch.get_default_dtype() default_complex = torch.complex64 if default_floating == torch.float32 else torch.complex128 default_integral = torch.int64