
Avoid deepspeed plugin converting the whole model #20543

I found that you should implement the dtype conversion yourself in your LightningModule and prevent DeepSpeedStrategy from converting your module.
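Here is a rough sketch of doing the conversion inside the model itself. It is a plain PyTorch module that keeps a frozen backbone in bfloat16 and a trainable head in float32. The names MixedDtypeModel, backbone and head are hypothetical. In a real LightningModule the same casts would go in __init__. This split is just one example of a mixed-dtype layout that a blanket module.to(dtype) would undo.

```python
import torch
from torch import nn


class MixedDtypeModel(nn.Module):
    """Hypothetical model: frozen bf16 backbone, trainable fp32 head."""

    def __init__(self, in_features: int = 16, hidden: int = 32, out_features: int = 4):
        super().__init__()
        self.backbone = nn.Linear(in_features, hidden)
        self.head = nn.Linear(hidden, out_features)
        # Freeze the backbone and store it in bfloat16 to save memory.
        self.backbone.requires_grad_(False)
        self.backbone.to(dtype=torch.bfloat16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Run the backbone in its own dtype, then upcast features for the fp32 head.
        features = self.backbone(x.to(torch.bfloat16))
        return self.head(features.float())


model = MixedDtypeModel()
out = model(torch.randn(2, 16))
print(model.backbone.weight.dtype, model.head.weight.dtype, out.dtype)
```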

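To stop the strategy from converting the module, subclass DeepSpeedPrecision so that its convert_module hook returns the module unchanged:

```python
from lightning.pytorch.plugins import DeepSpeedPrecision
from lightning.pytorch.strategies import DeepSpeedStrategy
from typing_extensions import override


class DeepSpeedPrecisionWithoutModuleConversion(DeepSpeedPrecision):
    @override
    def convert_module(self, module):
        # Skip the plugin's module-wide dtype cast; the LightningModule
        # is responsible for setting its own parameter dtypes.
        return module
```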
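The following plain-PyTorch emulation shows why the override matters. It assumes that the stock DeepSpeedPrecision.convert_module casts the entire module with module.to(dtype=...) for the "-true" precisions. That is my reading of the Lightning 2.x source, so check it against your installed version. stock_convert, no_convert and build are hypothetical stand-ins, not Lightning APIs.

```python
import torch
from torch import nn


def stock_convert(module: nn.Module, dtype: torch.dtype = torch.float32) -> nn.Module:
    # Assumed stock behavior for "*-true" precisions:
    # cast every floating-point parameter and buffer to a single dtype.
    return module.to(dtype=dtype)


def no_convert(module: nn.Module) -> nn.Module:
    # Mirrors the override: hand the module back untouched.
    return module


def build() -> nn.Module:
    # Hypothetical two-layer model whose first layer was deliberately cast to bf16.
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
    model[0].to(dtype=torch.bfloat16)
    return model


converted = stock_convert(build())
kept = no_convert(build())
print(converted[0].weight.dtype, kept[0].weight.dtype)
```

With the assumed stock behavior, the bf16 layer is silently upcast to float32. With the override, it keeps the dtype the model assigned.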

Then pass it to the Trainer through the strategy:

```python
from lightning.pytorch import Trainer

trainer = Trainer(
    ...,  # your other Trainer arguments
    strategy=DeepSpeedStrategy(
        stage=2,
        precision_plugin=DeepSpeedPrecisionWithoutModuleConversion("32-true"),
    ),
)
```
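You can confirm that the override took effect with a small helper that counts parameter elements per dtype. param_dtype_summary is hypothetical. Calling it from a LightningModule hook such as on_fit_start, which runs after the strategy has been set up, should show your mixed dtypes rather than a single one. This assumes ZeRO stage 2 as above. Under stage 3 the parameters are partitioned, so local counts would not reflect the full model.

```python
from collections import Counter

import torch
from torch import nn


def param_dtype_summary(module: nn.Module) -> dict[str, int]:
    """Hypothetical helper: number of parameter elements per dtype."""
    counts: Counter = Counter()
    for param in module.parameters():
        counts[str(param.dtype)] += param.numel()
    return dict(counts)


# Example: first layer (4*4 weights + 4 biases) in bf16, second (4*2 + 2) in fp32.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
model[0].to(dtype=torch.bfloat16)
print(param_dtype_summary(model))
```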

Answer selected by Boltzmachine