-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Apple MPS acceleration #1129
base: main
Are you sure you want to change the base?
Conversation
Thank you for the PR! |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1129 +/- ##
==========================================
- Coverage 92.26% 91.99% -0.28%
==========================================
Files 84 84
Lines 12445 12535 +90
==========================================
+ Hits 11482 11531 +49
- Misses 963 1004 +41
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
Thank you for the PR! |
Thank you for the PR! |
Thank you for the PR! |
Thank you for the PR! |
Thank you for the PR! |
Thank you for the PR! |
…:helmholtz-analytics/heat into features/1053-support-Apple-silicon-GPUs
Thanks, sadly it looks like an actual problem, conveniently without error message. I'll debug it, will probably get to it next week. |
Thank you for the PR! |
Thank you for the PR! |
Thank you for the PR! |
Thank you for the PR! |
Thank you for the PR! |
Thank you for the PR! |
Thank you for the PR! |
Thank you for the PR! |
if self.is_mps: | ||
dtypes = [ht.float32] | ||
else: | ||
dtypes = [ht.float32, ht.float64] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This (and all subsequent tests that have to filter by system) would be a great target for parametrization (now that we talked about introducing hypothesis and parametrized tests).
A good example on how to skip certain possible parameters based on the os is here
@@ -339,7 +339,7 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: | |||
else: # A not split, b.split == -2 | |||
b_lshapes_cum = torch.hstack( | |||
[ | |||
torch.zeros(1, dtype=torch.int32, device=tdev), | |||
torch.zeros(1, dtype=torch.int64, device=tdev), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason for change? Why not use the default dtype?
@@ -2154,19 +2168,20 @@ def test_triu(self): | |||
self.assertTrue(result.larray[0, -1] == 1) | |||
|
|||
def test_vdot(self): | |||
a = ht.array([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]], split=0) | |||
b = ht.array([[1 + 2j, 3 + 4j], [5 + 6j, 7 + 8j]], split=0) | |||
if not self.is_mps: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test should be skipped using unittest.skipIf or pytest.mark.skipif
ht.allclose(q.transpose([0, 1, 3, 2]) @ q, batched_id, atol=1e-6, rtol=1e-6) | ||
) | ||
self.assertTrue(ht.allclose(q @ r, x, atol=1e-6, rtol=1e-6)) | ||
# skip float64 tests on MPS |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test should be skipped using unittest.skipIf or pytest.mark.skipif
] | ||
rtols = [1e-1, 1e-2, 1e-3] | ||
ranks = [5, 10, 15] | ||
# not testing on MPS for now as torch.norm() is unstable |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test should be skipped using unittest.skipIf or pytest.mark.skipif
is_mps = x.larray.is_mps or y.larray.is_mps | ||
if is_mps and result_type is types.float64: | ||
result_type = types.float32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of checking every time after calling types.result_type
, the check could be done inside types.result_type()
. This would save a lot of extra if statements and less chance of possibly forgetting to add that.
if a.larray.is_mps and promoted_type == float64: | ||
# cannot cast to float64 on MPS | ||
promoted_type = float32 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same with promote_types
.
LAST EDITED DEC 12 2025
[Note from human: the most important changes in this PR are:
device.py
. Bothht.array(..., device="gpu")
andht.array(..., device="mps")
are allowed.device
attribute introduced forht.random.permutation
Still to do:
update READMEactually, it's probably best to update the README just before the next releaseBelow a copilot summary]
This pull request includes several changes to improve compatibility with Apple's Metal Performance Shaders (MPS) and correct some minor issues. The most important changes include modifications to handle unsupported data types on MPS, updates to unit tests, and minor corrections in documentation.
MPS Compatibility Improvements:
heat/core/_operations.py
: Added checks to handle unsupportedfloat64
data type on MPS and cast tofloat32
with appropriate warnings. [1] [2] [3]heat/core/arithmetics.py
: Updatedhypot
andhypot_
functions to raise errors for unsupportedint64
data type on MPS. [1] [2]heat/core/dndarray.py
: Modified thesize
method to avoid usingfloat64
on MPS.heat/core/linalg/basics.py
: Added a check to raise aRuntimeError
if matrix inversion fails on MPS.Unit Test Updates:
heat/cluster/tests/test_batchparallelclustering.py
,heat/cluster/tests/test_kmeans.py
,heat/cluster/tests/test_kmedians.py
,heat/cluster/tests/test_kmedoids.py
: Updated tests to handle unsupportedfloat64
data type on MPS. [1] [2] [3] [4]heat/cluster/tests/test_spectral.py
: Added a condition to skip tests on MPS due to unsupportedComplexFloat
operations.heat/core/linalg/tests/test_basics.py
: Updated tests to avoid usingfloat64
on MPS. [1] [2] [3]Minor Corrections:
heat/core/devices.py
: Corrected documentation to use consistent naming for Heat. [1] [2]heat/core/linalg/solver.py
: Changed tensor creation to useint64
instead ofint32
for cumulative sum operations. [1] [2]Reference
Issue/s resolved: #1053
Changes proposed:
Type of change
Memory requirements
Performance
Due Diligence
All split configurations testeddoes not applyDoes this change modify the behaviour of other functions? If so, which?
no