[BUGFIX] Issue with TransformerBlock parallel plan and residual connections. #62

Open
wants to merge 1 commit into base: main
Conversation

@sirluk commented on Dec 4, 2024

Hi, while testing tensor parallelism I noticed an issue with the parallel plan used for TransformerBlocks: RowwiseParallel does not correctly convert its local torch.Tensor outputs back to DTensors. This conversion is needed for the additive residual connection in the attention module and the multiplicative gate in the FeedForward module.

With the current plan, the outputs of attention.wo and feed_forward.w2 are instances of torch.distributed._functional_collectives.AsyncCollectiveTensor, which is a subclass of torch.Tensor but not a DTensor. I therefore get this error: "got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!"

At least for me, the fix was simply to set use_local_output=False for those two layers. I also added a minimal script below to reproduce the issue; note that I only tested it with 2 processes, so maybe that is an edge case.
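For reference, the change amounts to the following two plan entries (the repro script below applies the same change via layer_plan_new):

layer_plan["attention.wo"] = RowwiseParallel(output_layouts=Shard(1), use_local_output=False)
layer_plan["feed_forward.w2"] = RowwiseParallel(output_layouts=Shard(1), use_local_output=False)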

torchrun --nproc-per-node=2 myscript.py

# myscript.py
import os
from copy import deepcopy
import torch
import torch.distributed as dist
from torch.distributed._tensor import Replicate, Shard, DTensor
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    PrepareModuleInput,
    parallelize_module,
)
from torch.distributed.device_mesh import init_device_mesh

from lingua.transformer import TransformerBlock, BaseTransformerArgs, RotaryEmbedding

def layer_plan():
    layer_plan = {}

    layer_plan["attention"] = PrepareModuleInput(
        input_layouts=(Shard(1), None),
        desired_input_layouts=(Replicate(), None),
    )
    layer_plan["attention_norm"] = SequenceParallel()
    layer_plan["attention.wq"] = ColwiseParallel()
    layer_plan["attention.wk"] = ColwiseParallel()
    layer_plan["attention.wv"] = ColwiseParallel()
    layer_plan["attention.wo"] = RowwiseParallel(output_layouts=Shard(1))

    # Feedforward layers tp
    layer_plan["feed_forward"] = PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    )
    layer_plan["ffn_norm"] = SequenceParallel()
    layer_plan["feed_forward.w1"] = ColwiseParallel()
    layer_plan["feed_forward.w3"] = ColwiseParallel()
    layer_plan["feed_forward.w2"] = RowwiseParallel(output_layouts=Shard(1))
    return layer_plan


def main(model, parallel_plan, x, transformer_layer_kwargs, tp_mesh):
    model = parallelize_module(model, tp_mesh, parallel_plan)
    return model(x, **transformer_layer_kwargs)


if __name__ == "__main__":

    DIM = 512
    N_HEADS = 2
    MAX_SEQLEN = 256

    # Initialize distributed environment
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    world_size = dist.get_world_size()

    tp_mesh = init_device_mesh("cuda", (world_size,))

    base_transformer_args = BaseTransformerArgs(dim=DIM, n_heads=N_HEADS)
    model = TransformerBlock(base_transformer_args)
    
    # Adjust head counts to the per-rank values, since ColwiseParallel
    # shards the wq/wk/wv projections across attention heads.
    model.attention.n_heads = model.attention.n_heads // world_size
    model.attention.n_kv_heads = model.attention.n_kv_heads // world_size

    layer_plan_old = layer_plan()

    layer_plan_new = deepcopy(layer_plan_old)
    layer_plan_new["attention.wo"] = RowwiseParallel(output_layouts=Shard(1), use_local_output=False)
    layer_plan_new["feed_forward.w2"] = RowwiseParallel(output_layouts=Shard(1), use_local_output=False)

    rope_embeddings = RotaryEmbedding(
        theta=1000,
        head_dim=DIM // N_HEADS,
        max_seqlen=MAX_SEQLEN,
    )
    freq_cis = rope_embeddings(seqlen=MAX_SEQLEN, tok_idx=None)
    freq_cis = freq_cis.to(local_rank)

    transformer_layer_kwargs = {
        "freq_cis": freq_cis,
        "tok_idx": None,
        "mask": None,
        "attn_impl": "sdpa",
    }

    # input
    x = torch.randn(1, MAX_SEQLEN//world_size, DIM, device=local_rank)
    x = DTensor.from_local(x, device_mesh=tp_mesh, placements=[Shard(1)])


    print("forward pass with new parallel plan")
    o_new = main(model, layer_plan_new, x, transformer_layer_kwargs, tp_mesh)

    print("forward pass with old parallel plan, will throw an error")
    o_old = main(model, layer_plan_old, x, transformer_layer_kwargs, tp_mesh)

@facebook-github-bot added the CLA Signed label on Dec 4, 2024
@mathuvu (Contributor) commented on Jan 10, 2025

Thank you for your PR. There is something incorrect in your minimal example: the line x = DTensor.from_local(x, device_mesh=tp_mesh, placements=[Shard(1)]) is something we don't do. If you remove this line, the old plan works as intended and the new plan no longer works. Our code does not use from_local, if I'm not mistaken, so your issue may come from something else.
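For context, here is a minimal sketch (an assumption based on the comment above, reusing the names from the repro script) of passing the input without the DTensor.from_local wrapping: the input stays a plain local tensor that is already sharded along the sequence dimension on each rank, and the plan's SequenceParallel / PrepareModuleInput entries handle the DTensor conversion internally.

# Sketch, not part of the PR: pass the local shard directly instead of
# wrapping it with DTensor.from_local; per the comment above, the old
# plan then works as intended with this input.
x_local = torch.randn(1, MAX_SEQLEN // world_size, DIM, device=local_rank)
o_old = main(model, layer_plan_old, x_local, transformer_layer_kwargs, tp_mesh)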
