diff --git a/run_llama_train.sh b/run_llama_train.sh index 33aaf79b..9b046d95 100755 --- a/run_llama_train.sh +++ b/run_llama_train.sh @@ -29,6 +29,7 @@ if [ $# -ne 0 ]; then overrides="$*" fi +CUDA_LAUNCH_BLOCKING=1 \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ train.py --job.config_file ${CONFIG_FILE} $overrides diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py index e29338d9..9e51f9d7 100644 --- a/torchtitan/models/norms.py +++ b/torchtitan/models/norms.py @@ -5,9 +5,11 @@ # LICENSE file in the root directory of this source tree. import math +from typing import Tuple import torch import torch.nn as nn +from torch import Tensor import triton import triton.language as tl @@ -213,47 +215,95 @@ def _rms_norm_bwd_kernel_sm( tl.store(DW + row_block_id * N + cols, dw, mask=mask) -class TritonFusedRMSNorm(torch.autograd.Function): - @staticmethod - def forward(ctx, x, weight, eps): - x_shape_start = x.shape +def fused_rmsnorm_forward(x, weight, eps): + x_shape_start = x.shape - # Flatten input - x = x.view(-1, x.shape[-1]) - if x.stride(-1) != 1: - x = x.contiguous() - if weight.stride(-1) != 1: - weight = weight.contiguous() + # Flatten input + x = x.view(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if weight.stride(-1) != 1: + weight = weight.contiguous() - M, N = x.shape - y = torch.empty_like(x) - rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + M, N = x.shape + y = torch.empty_like(x) + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) - max_size = 65536 // x.element_size() - block_N = min(max_size, triton.next_power_of_2(N)) + max_size = 65536 // x.element_size() + block_N = min(max_size, triton.next_power_of_2(N)) - if N > block_N: - raise ValueError(f"N {N} must be <= {block_N=}") + if N > block_N: + raise ValueError(f"N {N} must be <= {block_N=}") - grid = lambda meta: (M,) - _rms_norm_fwd_kernel[grid]( - x, - x.stride(0), - y, - y.stride(0), - weight, - rstd, - eps, - M, - N, - block_N, - ) + grid = lambda meta: (M,) + _rms_norm_fwd_kernel[grid]( + x, + x.stride(0), + y, + y.stride(0), + weight, + rstd, + eps, + M, + N, + block_N, + ) + + y = y.reshape(x_shape_start) + return y, rstd + + +def fused_rmsnorm_backward(dy, x, weight, eps, rstd, x_shape_start): + # Flatten input and output gradients + dy = dy.view(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + + M, N = dy.shape + dx = torch.empty_like(x) + dw = torch.empty_like(weight) + + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + + max_size = 65536 // x.element_size() + block_N = min(max_size, triton.next_power_of_2(N)) + rows_per_sm = math.ceil(M / sm_count) + + if N > block_N: + raise ValueError(f"N {N} must be <= {block_N=}") + + grid = lambda meta: (sm_count,) + _rms_norm_bwd_kernel_sm[grid]( + x, + x.stride(0), + weight, + dy, + dy.stride(0), + dx, + dx.stride(0), + rstd, + _dw, + eps, + M, + N, + rows_per_sm, + block_N, + ) + dw = _dw.sum(0).to(weight.dtype) + dx = dx.view(x_shape_start) + return dx, dw, None + + +class TritonFusedRMSNorm(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, eps): + y, rstd = fused_rmsnorm_forward(x, weight, eps) ctx.eps = eps + ctx.x_shape_start = x.shape ctx.save_for_backward(x, weight, rstd) - ctx.x_shape_start = x_shape_start - y = y.reshape(x_shape_start) return y @staticmethod @@ -262,55 +312,144 @@ def backward(ctx, dy): eps = ctx.eps x_shape_start = ctx.x_shape_start - # Flatten input and output gradients - dy = dy.view(-1, dy.shape[-1]) - if dy.stride(-1) != 1: - dy = dy.contiguous() - - M, N = dy.shape - dx = torch.empty_like(x) - dw = torch.empty_like(weight) - - sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count - _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) - - max_size = 65536 // x.element_size() - block_N = min(max_size, triton.next_power_of_2(N)) - rows_per_sm = math.ceil(M / sm_count) - - if N > block_N: - raise ValueError(f"N {N} must be <= {block_N=}") - - grid = lambda meta: (sm_count,) - _rms_norm_bwd_kernel_sm[grid]( - x, - x.stride(0), - weight, - dy, - dy.stride(0), - dx, - dx.stride(0), - rstd, - _dw, - eps, - M, - N, - rows_per_sm, - block_N, - ) - dw = _dw.sum(0).to(weight.dtype) - dx = dx.view(x_shape_start) - return dx, dw, None + return fused_rmsnorm_backward(dy, x, weight, eps, rstd, x_shape_start) # expose fusedRMSNorm as a function def fused_rms_norm_fn( x, weight, - eps=1e-6, + eps, ): + # option 1: register forward and backward separately + return TritonFusedRMSNorm.apply( x, weight, eps, ) + + # option 2: register forward only, and register backward using torch.library.register_autograd + + # args = (x, weight, eps,) + # torch.library.opcheck(fused_rmsnorm_forward, args) + # return fused_rmsnorm_forward(x, weight, eps)[0] + + +# @torch.library.custom_op("torchtitan::fused_rmsnorm", mutates_args=()) +# def fused_rmsnorm_forward(x: Tensor, weight: Tensor, eps: float) -> Tuple[Tensor, Tensor]: +# x_shape_start = x.shape + +# # Flatten input +# x = x.view(-1, x.shape[-1]) +# if x.stride(-1) != 1: +# x = x.contiguous() +# if weight.stride(-1) != 1: +# weight = weight.contiguous() + +# M, N = x.shape +# y = torch.empty_like(x) +# rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + +# max_size = 65536 // x.element_size() +# block_N = min(max_size, triton.next_power_of_2(N)) + +# if N > block_N: +# raise ValueError(f"N {N} must be <= {block_N=}") + +# grid = lambda meta: (M,) +# _rms_norm_fwd_kernel[grid]( +# x, +# x.stride(0), +# y, +# y.stride(0), +# weight, +# rstd, +# eps, +# M, +# N, +# block_N, +# ) + +# y = y.reshape(x_shape_start) +# # return y +# return y, rstd + + +# def setup_context(ctx, inputs, output) -> Tensor: +# x, weight, eps = inputs +# y, rstd = output + +# x_shape_start = x.shape + +# ctx.eps = eps +# ctx.save_for_backward(x, weight, rstd) +# ctx.x_shape_start = x_shape_start + + +# def fused_rmsnorm_backward(ctx, dy, drstd): +# x, weight, rstd = ctx.saved_tensors +# eps = ctx.eps +# x_shape_start = ctx.x_shape_start + +# # Flatten input and output gradients +# dy = dy.view(-1, dy.shape[-1]) +# if dy.stride(-1) != 1: +# dy = dy.contiguous() + +# M, N = dy.shape +# dx = torch.empty_like(x) +# dw = torch.empty_like(weight) + +# sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count +# _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + +# max_size = 65536 // x.element_size() +# block_N = min(max_size, triton.next_power_of_2(N)) +# rows_per_sm = math.ceil(M / sm_count) + +# if N > block_N: +# raise ValueError(f"N {N} must be <= {block_N=}") + +# grid = lambda meta: (sm_count,) +# _rms_norm_bwd_kernel_sm[grid]( +# x, +# x.stride(0), +# weight, +# dy, +# dy.stride(0), +# dx, +# dx.stride(0), +# rstd, +# _dw, +# eps, +# M, +# N, +# rows_per_sm, +# block_N, +# ) +# dw = _dw.sum(0).to(weight.dtype) +# dx = dx.view(x_shape_start) +# return dx, dw, None + + +# torch.library.register_autograd("torchtitan::fused_rmsnorm", fused_rmsnorm_backward, setup_context=setup_context) + + +# @torch.library.register_fake("torchtitan::fused_rmsnorm") +# def _(x, weight, eps): +# x_shape_start = x.shape + +# # Flatten input +# x = x.view(-1, x.shape[-1]) +# if x.stride(-1) != 1: +# x = x.contiguous() +# # if weight.stride(-1) != 1: +# # weight = weight.contiguous() + +# M, N = x.shape +# y = torch.empty_like(x) +# rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + +# y = y.reshape(x_shape_start) +# return y, rstd