MPS support #789

Merged
dirkgr merged 25 commits into main from amanr/mps_support
Feb 11, 2025

Conversation

aman-17
Member

@aman-17 aman-17 commented Jan 22, 2025

Added MPS support.
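
For readers following along: MPS is PyTorch's Metal backend for Apple-silicon GPUs. Below is a minimal sketch of the kind of device fallback that MPS support usually involves. It is not the code from this PR. `pick_device` and `resolve_device` are hypothetical names, and the CUDA-then-MPS-then-CPU ordering is an assumption.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Pure decision logic: prefer CUDA, then MPS, then CPU (assumed ordering)."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def resolve_device() -> str:
    """Ask PyTorch which backends are usable; fall back to CPU if torch is missing."""
    try:
        import torch
    except ImportError:
        return "cpu"
    # Guard the attribute lookup in case a build ships without the MPS backend module.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = mps_backend is not None and mps_backend.is_available()
    return pick_device(torch.cuda.is_available(), mps_ok)


if __name__ == "__main__":
    print(resolve_device())
```

Keeping the decision in a pure function makes the fallback order easy to test on machines that have neither a CUDA nor an MPS device.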

@aman-17 aman-17 added the type/feature An issue or pull request that introduces a new feature label Jan 22, 2025
@aman-17 aman-17 requested a review from dirkgr January 22, 2025 21:41
@aman-17 aman-17 self-assigned this Jan 22, 2025
@aman-17 aman-17 removed the type/feature An issue or pull request that introduces a new feature label Jan 22, 2025
@aman-17 aman-17 added the type/feature An issue or pull request that introduces a new feature label Jan 22, 2025
@dirkgr
Member

dirkgr commented Jan 29, 2025

Is this all good to go then?

@aman-17
Member Author

aman-17 commented Jan 29, 2025

Yeah, I just removed precision and batch_size from the main function, as you mentioned in last week's meeting.

olmo/checkpoint.py (outdated, resolved)
olmo/config.py (outdated, resolved)
olmo/model.py (outdated, resolved)
olmo/model.py (outdated, resolved)
olmo/torch_util.py (outdated, resolved)
olmo/train.py (outdated, resolved)
olmo/train.py (outdated, resolved)
scripts/train.py (outdated, resolved)
scripts/train.py (outdated, resolved)
scripts/train.py (outdated, resolved)
@aman-17 aman-17 requested a review from dirkgr February 5, 2025 00:36
Member

@dirkgr dirkgr left a comment

Did you try this to make sure it still works on MPS, and also on GPU?

olmo/train.py (outdated, resolved)
scripts/train.py (resolved)
olmo/train.py (outdated, resolved)
@aman-17
Member Author

aman-17 commented Feb 5, 2025

Did you try this to make sure it still works on MPS, and also on GPU?

It runs fine on both MPS and GPU.
Our .toml file installs PyTorch 2.6.0, but we need to downgrade to 2.5.1 for training.

@dirkgr
Member

dirkgr commented Feb 5, 2025

Our .toml file installs PyTorch 2.6.0, but we need to downgrade to 2.5.1 for training.

Why is that?

@aman-17
Member Author

aman-17 commented Feb 5, 2025

Our .toml file installs PyTorch 2.6.0, but we need to downgrade to 2.5.1 for training.

Why is that?

PyTorch 2.6 changed the default of weights_only in torch.load to True, and the restricted unpickler rejects PosixPath objects when loading the models.

Our .toml file, for your reference:

dependencies = [
    "torch>=2.1",
    ...
    ]

And the error with torch==2.6.0:

(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message. WeightsUnpickler error: Unsupported global: GLOBAL pathlib.PosixPath was not an allowed global by default. Please use `torch.serialization.add_safe_globals([PosixPath])` or the `torch.serialization.safe_globals([PosixPath])` context manager to allowlist this global if you trust this class/function.

olmo/model.py (outdated, resolved)
@dirkgr
Member

dirkgr commented Feb 5, 2025

we changed the default value of the weights_only

Can you just change it where we call the function? Pass in the old value?

@aman-17
Member Author

aman-17 commented Feb 5, 2025

we changed the default value of the weights_only

Can you just change it where we call the function? Pass in the old value?

I passed weights_only=False explicitly. It works with both torch versions (2.5.x and 2.6.0).
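
A minimal sketch of the two options the error message describes, for anyone hitting the same failure. This is not the exact change made in this PR. `load_checkpoint` and its `trusted` flag are hypothetical names. The allowlist branch assumes a torch version that provides the `torch.serialization.safe_globals` context manager, which older releases may lack.

```python
from pathlib import PosixPath


def load_checkpoint(path, *, trusted=False, map_location="cpu"):
    """Load a checkpoint that may contain pickled PosixPath objects.

    trusted=True  -> pass the pre-2.6 default explicitly (full unpickling).
    trusted=False -> keep weights_only=True and allowlist PosixPath only.
    """
    import torch  # imported lazily so the module can be imported without torch

    if trusted:
        # Full unpickling can execute arbitrary code; use only for files you trust.
        return torch.load(path, map_location=map_location, weights_only=False)

    # Assumption: safe_globals exists in the installed torch version.
    with torch.serialization.safe_globals([PosixPath]):
        return torch.load(path, map_location=map_location, weights_only=True)
```

Passing `weights_only` explicitly in either branch keeps the behavior the same regardless of which default the installed torch version uses.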

Member

@dirkgr dirkgr left a comment

One more thing

olmo/model.py (outdated, resolved)
@dirkgr dirkgr merged commit b394700 into main Feb 11, 2025
12 checks passed
@dirkgr dirkgr deleted the amanr/mps_support branch February 11, 2025 22:50
@dirkgr
Member

dirkgr commented Feb 11, 2025

Thank you!

Labels
type/feature An issue or pull request that introduces a new feature

2 participants