From aa5af1be6773abadea9f674f6120a7c43042d3db Mon Sep 17 00:00:00 2001
From: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
Date: Fri, 24 May 2024 17:10:01 -0700
Subject: [PATCH 1/2] enable TritonFusedRMSNorm with local_map annotation

[ghstack-poisoned]
---
 test/test_fused_rms_norm.py                  | 55 ++++++++++++++++++++
 torchtitan/models/norms.py                   | 22 ++++++++
 torchtitan/parallelisms/parallelize_llama.py |  5 --
 train_configs/debug_model.toml               |  2 +-
 4 files changed, 78 insertions(+), 6 deletions(-)
 create mode 100644 test/test_fused_rms_norm.py

diff --git a/test/test_fused_rms_norm.py b/test/test_fused_rms_norm.py
new file mode 100644
index 00000000..f0e4ebef
--- /dev/null
+++ b/test/test_fused_rms_norm.py
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch.distributed._tensor import (
+    distribute_tensor,
+    init_device_mesh,
+    Replicate,
+    Shard,
+)
+from torch.distributed._tensor.debug import CommDebugMode
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+)
+
+from torchtitan.models.norms import fused_rms_norm_fn
+
+
+class TestFusedRMSNorm(DTensorTestBase):
+    @property
+    def world_size(self):
+        return 4
+
+    @with_comms
+    def test_fused_rms_norm(self):
+        mesh = init_device_mesh(
+            device_type=self.device_type, mesh_shape=(self.world_size,)
+        )
+        x = torch.randn(4, 4, 4, device=self.device_type)  # Shard(1)
+        w = torch.randn(4, device=self.device_type, requires_grad=True)  # Replicate
+
+        dx = distribute_tensor(x, mesh, [Shard(1)])
+        dw = distribute_tensor(w, mesh, [Replicate()])
+
+        comm_mode = CommDebugMode()
+        # fused rmsnorm
+        with comm_mode:
+            out = fused_rms_norm_fn(dx, dw)
+
+        self.assertEqual(comm_mode.get_total_counts(), 0)
+
+        with comm_mode:
+            grad_out = torch.ones_like(out)
+            out.backward(grad_out)
+
+        self.assertEqual(comm_mode.get_total_counts(), 0)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py
index e29338d9..b15a8519 100644
--- a/torchtitan/models/norms.py
+++ b/torchtitan/models/norms.py
@@ -6,12 +6,18 @@
 
 import math
 
+from functools import partial
+
 import torch
 import torch.nn as nn
 
 import triton
 import triton.language as tl
 
+from torch.distributed._functional_collectives import AsyncCollectiveTensor
+from torch.distributed._tensor.experimental import local_map
+from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
+
 
 def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
     """
@@ -214,8 +220,16 @@ def _rms_norm_bwd_kernel_sm(
 
 
 class TritonFusedRMSNorm(torch.autograd.Function):
+    @partial(
+        local_map,
+        out_placements=[Shard(1)],
+        in_placements=(None, [Shard(1)], [Replicate()], None),
+    )
     @staticmethod
     def forward(ctx, x, weight, eps):
+        if isinstance(x, AsyncCollectiveTensor):
+            x = x.wait()
+
         x_shape_start = x.shape
 
         # Flatten input
@@ -256,8 +270,16 @@ def forward(ctx, x, weight, eps):
         y = y.reshape(x_shape_start)
         return y
 
+    @partial(
+        local_map,
+        out_placements=([Shard(1)], [_Partial()], None),
+        in_placements=(None, [Shard(1)]),
+    )
     @staticmethod
     def backward(ctx, dy):
+        if isinstance(dy, AsyncCollectiveTensor):
+            dy = dy.wait()
+
         x, weight, rstd = ctx.saved_tensors
         eps = ctx.eps
         x_shape_start = ctx.x_shape_start
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 3617eb23..179ef8ee 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -300,11 +300,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
 
     if parallel_dims.tp_enabled:
-        if job_config.model.norm_type == "fused_rmsnorm":
-            raise NotImplementedError(
-                "fused_rmsnorm not yet compatible with TP. Please use layernorm or rmsnorm."
-            )
-
         tp_mesh = world_mesh["tp"]
         row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
             job_config
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 009348b5..f0b04298 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -35,7 +35,7 @@ warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
 max_norm = 1.0  # grad norm clipping
 steps = 10
 data_parallel_degree = -1
-tensor_parallel_degree = 1
+tensor_parallel_degree = 2
 fp8_linear = ""
 compile = false
 dataset = "c4_mini"  # supported datasets: c4_mini (45K), c4 (177M)

From 71659de492ae262efcdaf2860d4d16db9ee3715a Mon Sep 17 00:00:00 2001
From: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
Date: Tue, 4 Jun 2024 14:41:20 -0700
Subject: [PATCH 2/2] Update on "enable TritonFusedRMSNorm with local_map
 annotation"

[ghstack-poisoned]
---
 torchtitan/models/norms.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py
index b15a8519..dc312f71 100644
--- a/torchtitan/models/norms.py
+++ b/torchtitan/models/norms.py
@@ -14,7 +14,6 @@
 import triton
 import triton.language as tl
 
-from torch.distributed._functional_collectives import AsyncCollectiveTensor
 from torch.distributed._tensor.experimental import local_map
 from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
 
@@ -227,9 +226,6 @@ class TritonFusedRMSNorm(torch.autograd.Function):
     )
     @staticmethod
     def forward(ctx, x, weight, eps):
-        if isinstance(x, AsyncCollectiveTensor):
-            x = x.wait()
-
         x_shape_start = x.shape
 
         # Flatten input
@@ -277,9 +273,6 @@ def forward(ctx, x, weight, eps):
     )
     @staticmethod
     def backward(ctx, dy):
-        if isinstance(dy, AsyncCollectiveTensor):
-            dy = dy.wait()
-
         x, weight, rstd = ctx.saved_tensors
         eps = ctx.eps
         x_shape_start = ctx.x_shape_start
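
Note (not part of the patches): below is a minimal, self-contained sketch of the `local_map` pattern these patches apply to `TritonFusedRMSNorm`. A function written against plain local tensors is wrapped so it accepts and returns DTensors, with `in_placements`/`out_placements` declaring how each argument and result is laid out on the mesh. The function name `rmsnorm_local` and its eager RMSNorm body are illustrative stand-ins for the Triton kernel; only the decorator shape and the placements mirror the patch. Assumes 4 GPUs, e.g. `torchrun --nproc_per_node=4`.

```python
from functools import partial

import torch
from torch.distributed._tensor import (
    distribute_tensor,
    init_device_mesh,
    Replicate,
    Shard,
)
from torch.distributed._tensor.experimental import local_map


# Mirrors the forward annotation in the patch: input sharded on dim 1,
# weight replicated, eps a plain non-DTensor scalar (placement None),
# output sharded on dim 1.
@partial(
    local_map,
    out_placements=[Shard(1)],
    in_placements=([Shard(1)], [Replicate()], None),
)
def rmsnorm_local(x, weight, eps):
    # Inside the wrapper, x and weight are ordinary local tensors, so any
    # single-device implementation (eager here, Triton in the patch) works
    # unchanged.
    rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * rstd * weight


mesh = init_device_mesh("cuda", (4,))
dx = distribute_tensor(torch.randn(4, 8, 16), mesh, [Shard(1)])
dw = distribute_tensor(torch.ones(16), mesh, [Replicate()])
out = rmsnorm_local(dx, dw, 1e-6)  # DTensor in, DTensor out
```

Because RMSNorm reduces only over the last dimension, sharding on dim 1 keeps every reduction local to its rank, which is why the test in patch 1 can assert zero communication in both forward and backward.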