[BUGFIX] Issue with TransformerBlock parallel plan and residual connections. #62
Hi, when testing tensor parallel I noticed that there seems to be an issue with the parallel plan used for `TransformerBlock`s. Specifically, `RowwiseParallel` does not correctly convert local `torch.Tensor` instances to `DTensor`s, which is needed for the additive residual connection in the attention module and the multiplicative gate in the `FeedForward` module.
With the current plan, the outputs of `attention.wo` as well as `feed_forward.w2` are instances of `torch.distributed._functional_collectives.AsyncCollectiveTensor`, which is a subclass of `torch.Tensor`. Therefore I get this error: `got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!`
At least for me, the solution was to simply set `use_local_output=False` for those layers. I also added a minimal script to reproduce the issue. I only tested it with 2 processes; maybe that is an edge case. Run it with `torchrun --nproc-per-node=2 myscript.py`.