Skip to content

Commit

Permalink
removes fused_rmsnorm
Browse files Browse the repository at this point in the history
  • Loading branch information
allenwang28 committed Feb 6, 2025
1 parent 690f299 commit aa02c15
Show file tree
Hide file tree
Showing 10 changed files with 7 additions and 369 deletions.
11 changes: 0 additions & 11 deletions scripts/estimate/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,6 @@ def estimate_memory(job_config: JobConfig):
# Get the world size
world_size = int(os.environ["WORLD_SIZE"])

# fake tensor doesn't work with fused rmsnorm
if (
job_config.model.norm_type == "fused_rmsnorm"
and not job_config.memory_estimation.disable_fake_mode
):
logger.info(
"Fused RMSNorm is not supported yet under fake estimation mode. "
"Switching to rmsnorm."
)
job_config.model.norm_type = "rmsnorm"

if job_config.model.norm_type == "compiled_rmsnorm":
logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
job_config.model.norm_type = "rmsnorm"
Expand Down
10 changes: 0 additions & 10 deletions tests/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,6 @@ def build_test_list():
"2D compile",
"2d_compile",
),
OverrideDefinitions(
[
[
"--training.tensor_parallel_degree 2",
"--model.norm_type=fused_rmsnorm",
],
],
"2D eager with fused_rmsnorm",
"2d_eager_fused_rmsnorm",
),
OverrideDefinitions(
[
[
Expand Down
72 changes: 0 additions & 72 deletions tests/unit_tests/test_fused_rms_norm_dtensor.py

This file was deleted.

3 changes: 2 additions & 1 deletion torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ def __init__(self):
"--model.norm_type",
type=str,
default="rmsnorm",
help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]",
choices=["layernorm", "np_layernorm", "rmsnorm"],
help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm]",
)
self.parser.add_argument(
"--model.tokenizer_path",
Expand Down
Loading

0 comments on commit aa02c15

Please sign in to comment.