Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable TritonFusedRMSNorm with local_map annotation #364

Merged
merged 16 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .ci/docker/dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
expecttest
pytest
pytest-cov
pre-commit
72 changes: 72 additions & 0 deletions test/test_fused_rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.distributed._tensor import (
distribute_tensor,
init_device_mesh,
Replicate,
Shard,
)
from torch.distributed._tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
)

from torchtitan.models.norms import fused_rms_norm_fn


class TestFusedRMSNorm(DTensorTestBase):
@property
def world_size(self):
return 4

@skip_if_lt_x_gpu(4)
@with_comms
def test_fused_rms_norm(self):
XilunWu marked this conversation as resolved.
Show resolved Hide resolved
mesh = init_device_mesh(
device_type=self.device_type, mesh_shape=(self.world_size,)
)
x = torch.randn(4, 4, 4, device=self.device_type) # Shard(1)
w = torch.randn(4, device=self.device_type, requires_grad=True) # Replicate

dist_x = distribute_tensor(x, mesh, [Shard(1)])
dist_w = distribute_tensor(w, mesh, [Replicate()])

x = x.clone().detach()
w = w.clone().detach().requires_grad_()

self.assertEqual(dist_x.full_tensor(), x)
self.assertEqual(dist_w.full_tensor(), w)

# fused rmsnorm on DTensor
comm_mode = CommDebugMode()
# fused rmsnorm
with comm_mode:
dist_out = fused_rms_norm_fn(dist_x, dist_w)

self.assertEqual(comm_mode.get_total_counts(), 0)

with comm_mode:
dist_grad_out = torch.ones_like(dist_out)
dist_out.backward(dist_grad_out)

self.assertEqual(comm_mode.get_total_counts(), 0)

# fused rmsnorm on Tensor
out = fused_rms_norm_fn(x, w)
grad_out = torch.ones_like(out)
out.backward(grad_out)

self.assertEqual(dist_out.full_tensor(), out)
self.assertEqual(dist_grad_out.full_tensor(), grad_out)


if __name__ == "__main__":
run_tests()
13 changes: 11 additions & 2 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,17 @@ def build_test_list():
"--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
],
],
"Eager mode 2DParallel",
"eager_2d",
"Eager mode 2DParallel with rmsnorm",
"eager_2d_rmsnorm",
),
OverrideDefinitions(
[
[
"--training.tensor_parallel_degree 2 --model.norm_type=fused_rmsnorm",
],
],
"Eager mode 2DParallel with fused_rmsnorm",
"eager_2d_fused_rmsnorm",
),
OverrideDefinitions(
[
Expand Down
15 changes: 15 additions & 0 deletions torchtitan/models/norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@

import math

from functools import partial

import torch
import torch.nn as nn

import triton
import triton.language as tl

from torch.distributed._tensor import Partial, Replicate, Shard
from torch.distributed._tensor.experimental import local_map


def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
"""
Expand Down Expand Up @@ -214,6 +219,11 @@ def _rms_norm_bwd_kernel_sm(


class TritonFusedRMSNorm(torch.autograd.Function):
@partial(
XilunWu marked this conversation as resolved.
Show resolved Hide resolved
local_map,
out_placements=[Shard(1)],
in_placements=(None, [Shard(1)], [Replicate()], None),
)
@staticmethod
def forward(ctx, x, weight, eps):
x_shape_start = x.shape
Expand Down Expand Up @@ -256,6 +266,11 @@ def forward(ctx, x, weight, eps):
y = y.reshape(x_shape_start)
return y

@partial(
local_map,
out_placements=([Shard(1)], [Partial()], None),
in_placements=(None, [Shard(1)]),
)
@staticmethod
def backward(ctx, dy):
x, weight, rstd = ctx.saved_tensors
Expand Down
5 changes: 0 additions & 5 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"""

if parallel_dims.tp_enabled:
if job_config.model.norm_type == "fused_rmsnorm":
raise NotImplementedError(
"fused_rmsnorm not yet compatible with TP. Please use layernorm or rmsnorm."
)

tp_mesh = world_mesh["tp"]
(
row_parallel_strategy,
Expand Down
Loading