diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py
index 44523fa71..81ea10b8d 100644
--- a/scripts/estimate/estimation.py
+++ b/scripts/estimate/estimation.py
@@ -33,17 +33,6 @@ def estimate_memory(job_config: JobConfig):
     # Get the world size
     world_size = int(os.environ["WORLD_SIZE"])
 
-    # fake tensor doesn't work with fused rmsnorm
-    if (
-        job_config.model.norm_type == "fused_rmsnorm"
-        and not job_config.memory_estimation.disable_fake_mode
-    ):
-        logger.info(
-            "Fused RMSNorm is not supported yet under fake estimation mode. "
-            "Switching to rmsnorm."
-        )
-        job_config.model.norm_type = "rmsnorm"
-
     if job_config.model.norm_type == "compiled_rmsnorm":
         logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
         job_config.model.norm_type = "rmsnorm"
diff --git a/tests/integration_tests.py b/tests/integration_tests.py
index 7048439c9..53de7f833 100755
--- a/tests/integration_tests.py
+++ b/tests/integration_tests.py
@@ -94,16 +94,6 @@ def build_test_list():
             "2D compile",
             "2d_compile",
         ),
-        OverrideDefinitions(
-            [
-                [
-                    "--training.tensor_parallel_degree 2",
-                    "--model.norm_type=fused_rmsnorm",
-                ],
-            ],
-            "2D eager with fused_rmsnorm",
-            "2d_eager_fused_rmsnorm",
-        ),
         OverrideDefinitions(
             [
                 [
diff --git a/tests/unit_tests/test_fused_rms_norm_dtensor.py b/tests/unit_tests/test_fused_rms_norm_dtensor.py
deleted file mode 100644
index d5c353c2f..000000000
--- a/tests/unit_tests/test_fused_rms_norm_dtensor.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-from torch.distributed._tensor import (
-    distribute_tensor,
-    init_device_mesh,
-    Replicate,
-    Shard,
-)
-from torch.distributed.tensor.debug import CommDebugMode
-from torch.testing._internal.common_utils import run_tests
-from torch.testing._internal.distributed._tensor.common_dtensor import (
-    DTensorTestBase,
-    skip_if_lt_x_gpu,
-    with_comms,
-)
-
-from torchtitan.models.norms import fused_rms_norm_fn
-
-
-class TestFusedRMSNorm(DTensorTestBase):
-    @property
-    def world_size(self):
-        return 4
-
-    @skip_if_lt_x_gpu(4)
-    @with_comms
-    def test_fused_rms_norm(self):
-        mesh = init_device_mesh(
-            device_type=self.device_type, mesh_shape=(self.world_size,)
-        )
-        x = torch.randn(4, 4, 4, device=self.device_type)  # Shard(1)
-        w = torch.randn(4, device=self.device_type, requires_grad=True)  # Replicate
-
-        dist_x = distribute_tensor(x, mesh, [Shard(1)])
-        dist_w = distribute_tensor(w, mesh, [Replicate()])
-
-        x = x.clone().detach()
-        w = w.clone().detach().requires_grad_()
-
-        self.assertEqual(dist_x.full_tensor(), x)
-        self.assertEqual(dist_w.full_tensor(), w)
-
-        # fused rmsnorm on DTensor
-        comm_mode = CommDebugMode()
-        # fused rmsnorm
-        with comm_mode:
-            dist_out = fused_rms_norm_fn(dist_x, dist_w)
-
-        self.assertEqual(comm_mode.get_total_counts(), 0)
-
-        with comm_mode:
-            dist_grad_out = torch.ones_like(dist_out)
-            dist_out.backward(dist_grad_out)
-
-        self.assertEqual(comm_mode.get_total_counts(), 0)
-
-        # fused rmsnorm on Tensor
-        out = fused_rms_norm_fn(x, w)
-        grad_out = torch.ones_like(out)
-        out.backward(grad_out)
-
-        self.assertEqual(dist_out.full_tensor(), out)
-        self.assertEqual(dist_grad_out.full_tensor(), grad_out)
-
-
-if __name__ == "__main__":
-    run_tests()
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 5ca4e46f7..864c5878e 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -174,7 +174,8 @@ def __init__(self):
             "--model.norm_type",
             type=str,
             default="rmsnorm",
-            help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]",
+            choices=["layernorm", "np_layernorm", "rmsnorm"],
+            help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm]",
         )
         self.parser.add_argument(
             "--model.tokenizer_path",
diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py
index 62dbc6abe..17dfb270f 100644
--- a/torchtitan/models/norms.py
+++ b/torchtitan/models/norms.py
@@ -4,20 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import math
-
-from functools import partial
-
 import torch
 import torch.nn as nn
 
-import triton
-import triton.language as tl
-
-from torch.distributed._tensor import Partial, Replicate, Shard
-from torch.distributed._tensor.experimental import local_map
-from torchtitan.utils import device_module
-
 
 def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
     """
@@ -25,7 +14,7 @@ def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
 
     Args:
         norm_type (str): The type of normalization layer to build.
-            Supported types: layernorm, np_layernorm, rmsnorm, fused_rmsnorm
+            Supported types: layernorm, np_layernorm, rmsnorm
        dim (int): The dimension of the normalization layer.
        eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
 
@@ -43,37 +32,10 @@ def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
         return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
     elif norm_type == "rmsnorm":
         return RMSNorm(dim, eps=eps)
-    elif norm_type == "fused_rmsnorm":
-        return FusedRMSNorm(dim, eps=eps)
     else:
         raise NotImplementedError(f"Unknown norm_type: '{norm_type}'")
 
 
-class FusedRMSNorm(nn.Module):
-    """Fused RMS Norm, wraps a fused Triton Kernel"""
-
-    def __init__(
-        self,
-        dim: int,
-        eps: float = 1e-6,
-    ):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-        self.fused_rms_norm_fn = fused_rms_norm_fn
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """leverages Triton Fused RMS Norm kernel"""
-        return self.fused_rms_norm_fn(
-            x,
-            self.weight,
-            eps=self.eps,
-        )
-
-    def reset_parameters(self):
-        torch.nn.init.ones_(self.weight)  # type: ignore
-
-
 class RMSNorm(nn.Module):
     """
     Initialize the RMSNorm normalization layer.
@@ -102,230 +64,3 @@ def forward(self, x: torch.Tensor):
 
     def reset_parameters(self):
         torch.nn.init.ones_(self.weight)  # type: ignore
-
-
-# FusedRMSNorm in Triton
-
-# Credit
-# Tri Dao's Triton LayerNorm: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
-# Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
-
-
-@triton.autotune(
-    configs=[
-        triton.Config({}, num_warps=1),
-        triton.Config({}, num_warps=2),
-        triton.Config({}, num_warps=4),
-        triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
-    ],
-    key=["N"],
-)
-@triton.jit
-def _rms_norm_fwd_kernel(
-    X,
-    stride_x,
-    Y,
-    stride_y,
-    W,
-    Rstd,
-    eps,
-    M,  # num rows
-    N,  # num cols
-    block_N: tl.constexpr,
-):
-    row = tl.program_id(0)
-    cols = tl.arange(0, block_N)
-
-    # Load input data and weights
-    mask = cols < N
-    x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
-    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
-
-    # Compute mean and variance
-    xbar = tl.where(cols < N, x, 0.0)
-    var = tl.sum(xbar * xbar, axis=0) / N
-    rstd = 1 / tl.sqrt(var + eps)
-
-    # Store the reciprocal standard deviation
-    tl.store(Rstd + row, rstd)
-
-    # Normalize and apply linear transformation
-    x_hat = x * rstd
-    y = x_hat * w
-
-    # Write output
-    tl.store(Y + row * stride_y + cols, y, mask=mask)
-
-
-@triton.autotune(
-    configs=[
-        triton.Config({}, num_warps=1),
-        triton.Config({}, num_warps=2),
-        triton.Config({}, num_warps=4),
-        triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
-    ],
-    key=["N"],
-)
-@triton.jit
-def _rms_norm_bwd_kernel_sm(
-    X,
-    stride_x,
-    W,
-    DY,
-    stride_dy,
-    DX,
-    stride_dx,
-    Rstd,
-    DW,
-    eps,
-    M,  # num rows
-    N,  # num cols
-    rows_per_program,
-    block_N: tl.constexpr,
-):
-    row_block_id = tl.program_id(0)
-    row_start = row_block_id * rows_per_program
-    cols = tl.arange(0, block_N)
-    mask = cols < N
-
-    # Load weights
-    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
-
-    # Accumulate gradients for weights
-    dw = tl.zeros((block_N,), dtype=tl.float32)
-
-    row_end = min(row_start + rows_per_program, M)
-    for row in range(row_start, row_end):
-        # Load input, output gradient, and reciprocal standard deviation
-        x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
-        dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32)
-        rstd = tl.load(Rstd + row)
-
-        # Compute normalized input and gradients
-        x_hat = x * rstd
-        wdy = w * dy
-        dw += dy * x_hat
-        c1 = tl.sum(x_hat * wdy, axis=0) / N
-        dx = (wdy - x_hat * c1) * rstd
-
-        # Store input gradient
-        tl.store(DX + row * stride_dx + cols, dx, mask=mask)
-
-    # Store weight gradients
-    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
-
-
-class TritonFusedRMSNorm(torch.autograd.Function):
-    @partial(
-        local_map,
-        out_placements=[Shard(1)],
-        in_placements=(None, [Shard(1)], [Replicate()], None),
-    )
-    @staticmethod
-    def forward(ctx, x, weight, eps):
-        x_shape_start = x.shape
-
-        # Flatten input
-        x = x.view(-1, x.shape[-1])
-        if x.stride(-1) != 1:
-            x = x.contiguous()
-        if weight.stride(-1) != 1:
-            weight = weight.contiguous()
-
-        M, N = x.shape
-        y = torch.empty_like(x)
-        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
-
-        max_size = 65536 // x.element_size()
-        block_N = min(max_size, triton.next_power_of_2(N))
-
-        if N > block_N:
-            raise ValueError(f"N {N} must be <= {block_N=}")
-
-        grid = lambda meta: (M,)
-        _rms_norm_fwd_kernel[grid](
-            x,
-            x.stride(0),
-            y,
-            y.stride(0),
-            weight,
-            rstd,
-            eps,
-            M,
-            N,
-            block_N,
-        )
-
-        ctx.eps = eps
-        ctx.save_for_backward(x, weight, rstd)
-        ctx.x_shape_start = x_shape_start
-
-        y = y.reshape(x_shape_start)
-        return y
-
-    @partial(
-        local_map,
-        out_placements=([Shard(1)], [Partial()], None),
-        in_placements=(None, [Shard(1)]),
-    )
-    @staticmethod
-    def backward(ctx, dy):
-        x, weight, rstd = ctx.saved_tensors
-        eps = ctx.eps
-        x_shape_start = ctx.x_shape_start
-
-        # Flatten input and output gradients
-        dy = dy.view(-1, dy.shape[-1])
-        if dy.stride(-1) != 1:
-            dy = dy.contiguous()
-
-        M, N = dy.shape
-        dx = torch.empty_like(x)
-
-        sm_count = device_module.get_device_properties(x.device).multi_processor_count
-        _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
-
-        max_size = 65536 // x.element_size()
-        block_N = min(max_size, triton.next_power_of_2(N))
-        rows_per_sm = math.ceil(M / sm_count)
-
-        if N > block_N:
-            raise ValueError(f"N {N} must be <= {block_N=}")
-
-        grid = lambda meta: (sm_count,)
-        _rms_norm_bwd_kernel_sm[grid](
-            x,
-            x.stride(0),
-            weight,
-            dy,
-            dy.stride(0),
-            dx,
-            dx.stride(0),
-            rstd,
-            _dw,
-            eps,
-            M,
-            N,
-            rows_per_sm,
-            block_N,
-        )
-        dw = _dw.sum(0).to(weight.dtype)
-        dx = dx.view(x_shape_start)
-        return dx, dw, None
-
-
-# expose fusedRMSNorm as a function
-def fused_rms_norm_fn(
-    x,
-    weight,
-    eps=1e-6,
-):
-    return TritonFusedRMSNorm.apply(
-        x,
-        weight,
-        eps,
-    )
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 9728569ab..0d70adeb5 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -69,11 +69,6 @@ def parallelize_llama(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if job_config.training.compile:
-        if job_config.model.norm_type == "fused_rmsnorm":
-            raise NotImplementedError(
-                "fused_rmsnorm is not compatible with torch.compile yet. "
-                "Please use rmsnorm or layernorm."
-            )
         apply_compile(model)
 
     if (
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 733bc0ae4..33d8193b3 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -23,7 +23,7 @@ enable_wandb = false
 [model]
 name = "llama3"
 flavor = "debugmodel"
-norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
+norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
 # test tokenizer.model, for debug purpose only
 tokenizer_path = "./tests/assets/test_tiktoken.model"
 
diff --git a/train_configs/llama3_405b.toml b/train_configs/llama3_405b.toml
index e52beb366..26405603d 100644
--- a/train_configs/llama3_405b.toml
+++ b/train_configs/llama3_405b.toml
@@ -18,7 +18,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama3"
 flavor = "405B"
-norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
+norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
 
 [optimizer]
diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml
index 2d55a36d3..e73e4b945 100644
--- a/train_configs/llama3_70b.toml
+++ b/train_configs/llama3_70b.toml
@@ -18,7 +18,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama3"
 flavor = "70B"
-norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
+norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
 
 [optimizer]
diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
index 3001ec748..c61640362 100644
--- a/train_configs/llama3_8b.toml
+++ b/train_configs/llama3_8b.toml
@@ -18,7 +18,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama3"
 flavor = "8B"
-norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
+norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
 
 [optimizer]