
[BUG] How to improve the training performance in MLX compared to pytorch and keras? #1542

Open
thegodone opened this issue Oct 30, 2024 · 11 comments


@thegodone

thegodone commented Oct 30, 2024

Describe the bug
I have a major issue that I have seen in a lot of cases across other trials. MLX training rarely gives a good result, while torch and keras are more stable and better. This is a real bottleneck for using MLX: you need to train your model 10 to 20 times to get a good result, while torch and keras systematically land in a good range (RMSE: 0.50-0.55).
Important: the models (tf/keras, torch and mlx) have the same number of trainable parameters, and we use the same train, val and test split for the 3 methods.

To Reproduce

Run the following notebooks several times: only the best MLX runs reach the range of the pytorch and tf/keras results.
https://github.com/thegodone/apple_ai_model/blob/main/AttFP_mlx_faster.ipynb
https://github.com/thegodone/apple_ai_model/blob/main/AttFP_torch.ipynb
https://github.com/thegodone/apple_ai_model/blob/main/AttFP_tf.ipynb

Expected behavior
I don't know whether it is the weight initialization or the optimizer that causes this large difference between the 3 packages.

Desktop (please complete the following information):
see #1531

@thegodone
Author

thegodone commented Nov 8, 2024

Could it be related to this comment: #1153 (comment)?

from typing import Union, List, Callable

import mlx.core as mx
from mlx.optimizers import Optimizer

class Adam(Optimizer):
    def __init__(self, learning_rate: Union[float, Callable], betas: List[float] = [0.9, 0.999], eps: float = 1e-8):
        super().__init__()
        self._maybe_schedule("learning_rate", learning_rate)
        self.betas = betas
        self.eps = eps

    def init_single(self, parameter: mx.array, state: dict):
        state["m"] = mx.zeros_like(parameter)  # first moment (momentum)
        state["v"] = mx.zeros_like(parameter)  # second moment (velocity)
        state["step"] = 0                      # step counter used for bias correction

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        lr = self.learning_rate.astype(gradient.dtype)
        beta1, beta2 = self.betas
        eps = self.eps

        state["step"] += 1
        step = state["step"]

        # Update biased first moment estimate
        state["m"] = beta1 * state["m"] + (1 - beta1) * gradient
        # Update biased second moment estimate
        state["v"] = beta2 * state["v"] + (1 - beta2) * mx.square(gradient)

        # Bias-corrected estimates, matching the standard Adam update
        # used by pytorch / tensorflow
        m_hat = state["m"] / (1 - beta1**step)
        v_hat = state["v"] / (1 - beta2**step)

        # Parameter update
        return parameter - lr * m_hat / (mx.sqrt(v_hat) + eps)

class AdamW(Adam):
    def __init__(self, learning_rate: Union[float, Callable], betas: List[float] = [0.9, 0.999], eps: float = 1e-8, weight_decay: float = 0.01):
        super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
        self.weight_decay = weight_decay

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        lr = self.learning_rate.astype(gradient.dtype)
        # Apply weight decay before applying Adam update
        parameter = parameter * (1 - lr * self.weight_decay)
        # Call the parent Adam's apply_single() for the core Adam update
        return super().apply_single(gradient, parameter, state)
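
For reference, here is a minimal sketch of how this custom AdamW plugs into an MLX training step; the tiny nn.Linear model and random data are stand-ins for illustration, not the AttentiveFP model from the notebooks:

import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(16, 1)
optimizer = AdamW(learning_rate=1e-3, betas=[0.9, 0.999], eps=1e-8, weight_decay=0.01)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y, reduction="mean")

loss_and_grad = nn.value_and_grad(model, loss_fn)

x = mx.random.normal((32, 16))
y = mx.random.normal((32, 1))
loss, grads = loss_and_grad(model, x, y)
optimizer.update(model, grads)  # calls init_single / apply_single above
mx.eval(model.parameters(), optimizer.state)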

@awni
Member

awni commented Nov 9, 2024

If you can try either of those and report back, that would be useful to know.

> Is it possible to add a parameter for the Adam optimizer to be strictly identical to pytorch / tensorflow?

It's possible.. it would be good to know whether it fixes your issue first, though.

@thegodone
Author

thegodone commented Nov 9, 2024

Here are the differences between the 3 trainings; I am now using the bias-corrected AdamW in mlx:

  • train_loss is also different, I need to check that
  • the LRs are not identical, which may explain the difference

[screenshot: comparing]

One remark: torch and mlx now have a similar speed.

@awni
Member

awni commented Nov 9, 2024

The 0 training loss in MLX seems incorrect, particularly given that the training MSE seems reasonable. I would double-check that you are averaging the loss in MLX correctly.

Otherwise it mostly looks reasonable.. fine-tuning learning rates, warmups, initializations, etc. could all help.
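
One common way the reported training loss goes wrong is mixing per-batch mean losses with a division by the number of samples (or only logging the last batch). Below is a minimal, self-contained sketch of averaging per-batch mean losses over an epoch; the tiny model and random data are placeholders:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(16, 1)
optimizer = optim.Adam(learning_rate=1e-3)

def loss_fn(model, x, y):
    # reduction="mean" already averages over the batch
    return nn.losses.mse_loss(model(x), y, reduction="mean")

loss_and_grad = nn.value_and_grad(model, loss_fn)

x_all = mx.random.normal((128, 16))
y_all = mx.random.normal((128, 1))

epoch_loss, num_batches = 0.0, 0
for i in range(0, 128, 32):
    x, y = x_all[i:i + 32], y_all[i:i + 32]
    loss, grads = loss_and_grad(model, x, y)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)
    epoch_loss += loss.item()  # per-batch mean loss
    num_batches += 1

# average of per-batch means, comparable to what torch/keras report
train_loss = epoch_loss / num_batches
print(train_loss)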

@thegodone
Author

thegodone commented Nov 9, 2024

I will have a synchronized LR scheduler tomorrow to be sure this is not the part that causes the model deviation. Yes, I will look at the loss; it is strange.

@thegodone
Author

thegodone commented Nov 10, 2024

OK, I fixed the loss and now I get this kind of result; interestingly, pytorch is more efficient than mlx or tensorflow.

[screenshot: comparing (1)]

@thegodone
Author

I now use this LR scheduler in mlx. One potential issue is that in pytorch/tensorflow the LR scheduler steps per epoch, while in mlx it steps per optimizer update; is it possible to have an epoch equivalent?

import mlx.optimizers as optim

def cosineannealingwarmrestartfactor_(initial_lr, restart, decay_steps, warmup_factor, Tmin):
    # Build `restart` cosine-decay cycles of `decay_steps` steps each; after
    # every restart, both the peak LR and the floor Tmin are scaled by
    # `warmup_factor`.
    schedules = [optim.cosine_decay(initial_lr, decay_steps, Tmin)]
    boundaries = []  # join_schedules takes one boundary fewer than schedules
    for i in range(restart - 1):
        Tmin *= warmup_factor
        initial_lr *= warmup_factor
        schedules.append(optim.cosine_decay(initial_lr, decay_steps, Tmin))
        boundaries.append(decay_steps * (i + 1))
    return optim.join_schedules(schedules, boundaries)
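
Regarding the per-epoch question: MLX schedules are plain functions of the optimizer step count, so one possible workaround (not an MLX-provided helper; the per_epoch name and the numeric values below are assumptions for illustration) is to wrap the per-step schedule so that it only advances once per epoch:

import mlx.optimizers as optim

def per_epoch(schedule, steps_per_epoch):
    # Feed the wrapped schedule the epoch index instead of the raw step, so
    # the learning rate stays constant within an epoch, like an epoch-based
    # scheduler in pytorch/keras. `decay_steps` is then counted in epochs.
    def wrapped(step):
        return schedule(step // steps_per_epoch)
    return wrapped

lr_schedule = per_epoch(
    cosineannealingwarmrestartfactor_(1e-3, restart=4, decay_steps=50, warmup_factor=0.5, Tmin=1e-5),
    steps_per_epoch=100,  # number of optimizer updates per epoch in the training loop
)
optimizer = optim.Adam(learning_rate=lr_schedule)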


@thegodone
Author

This is the mlx version with the official Adam (without bias correction): it clearly does not perform as well as the bias-corrected version used in the previous posts.

[screenshot: comparing (2)]

@thegodone
Author

Is there a way to fix a seed in mlx, similar to torch.manual_seed(int)?

@awni
Member

awni commented Nov 12, 2024

mx.random.seed
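
For example (the seed value is arbitrary):

import mlx.core as mx

mx.random.seed(42)  # seeds the global PRNG, analogous to torch.manual_seed(42)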

@thegodone
Author

thegodone commented Nov 12, 2024 via email
