Skip to content

Commit

Permalink
Add RMSNormToL2Norm graph surgeon (#1594)
Browse files Browse the repository at this point in the history
## Describe your changes
New graph surgeon to simplify RMSNorm subgraph using L2Norm subgraph.
`rmsnorm = x/sqrt(mean(x**2)) = x/sqrt(1/N*sum(x**2)) =
sqrt(N)*x/sqrt(sum(x**2)) = sqrt(N)*L2Norm`

This is useful when quantizing the model for NPU since there are fewer
activations to quantize.

## Checklist before requesting a review
- [x] Add unit tests for this change.
- [x] Make sure all tests can pass.
- [x] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link
  • Loading branch information
jambayk authored Feb 4, 2025
1 parent ea64800 commit 29758a5
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 1 deletion.
41 changes: 41 additions & 0 deletions docs/source/how-to/configure-workflows/onnx-graph-surgeon.md
Original file line number Diff line number Diff line change
Expand Up @@ -656,3 +656,44 @@ graph {
output: "zero_point"
}
```


### RMSNormToL2Norm

#### Description

Replace RMSNorm subgraph with L2Norm subgraph.

#### Example

Initial model graph:

```
RMSNorm pattern:
+-----------------------------------------------+
| |
| v
[Root] --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul
(y=2) (axis=-1) (B=E-6)
```

After applying:

```json
{
"type": "GraphSurgeries",
"surgeries": [
{
"surgeon": "RMSNormToL2Norm"
}
]
}
```


Transformed model graph:

```
[Root] --> LpNormalization --> Mul
(p=2, axis=-1)
```
144 changes: 143 additions & 1 deletion olive/passes/onnx/graph_surgeries.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import inspect
import logging
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Type
from typing import Any, ClassVar, Dict, List, Optional, Type

import numpy as np
import onnx
Expand All @@ -20,6 +20,7 @@
from olive.model.utils import resolve_onnx_path
from olive.passes import Pass
from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model
from olive.passes.onnx.onnx_dag import OnnxDAG
from olive.passes.pass_config import PassConfigParam

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -455,6 +456,147 @@ def __call__(self, model: ModelProto):
return self._add_zero_point(model, zero_point_value, zero_point_onnx_dtype, zero_point_np_dtype)


class RMSNormToL2Norm(Surgeon):
"""Replace RMSNorm subgraph with L2Norm subgraph.
RMSNorm pattern:
+-----------------------------------------------+
| |
| v
[Root] --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul
(y=2) (axis=-1) (B=E-6)
Also handles the case where [Root] is multiplied with rsqrt leading to Div -> Mul
instead of dividing by sqrt.
L2Norm pattern:
[Root] --> LpNormalization --> Mul
(p=2, axis=-1)
Where the weight of the Mul node is multiplied by sqrt(N) where N is equal to the reduced axis size.
If the weight is all 1s, it is replaced with a 1D array of sqrt(N).
"""

def __init__(self):
pass

def __call__(self, model: ModelProto):
dag = OnnxDAG(model)

modified = 0
removed_nodes = set()
replaced_initializers = set()
for node_name in dag.get_node_names():
if node_name in removed_nodes or dag.get_node_op_type(node_name) != "Pow":
continue

rmsnorm_nodes = self.get_rmsnorm_nodes(node_name, dag)
if not rmsnorm_nodes:
continue

graph_idx = dag.get_graph_idx(node_name)

# name for the new L2Norm node
l2norm_node_name = node_name.replace("Pow", "L2Norm") if "Pow" in node_name else f"{node_name}_L2Norm"
node_output_name = dag.get_node_outputs(node_name)[0]
l2norm_node_output_name = (
node_output_name.replace("Pow", "L2Norm") if "Pow" in node_output_name else f"{node_output_name}_L2Norm"
)

# create L2Norm node
l2norm_node = onnx.helper.make_node(
"LpNormalization",
inputs=[dag.get_node_inputs(node_name)[0]],
outputs=[l2norm_node_output_name],
name=l2norm_node_name,
axis=-1,
p=2,
)

# Muliply weight by sqrt(N)
final_node_name = rmsnorm_nodes[-1]
final_node_children = dag.get_consumers(final_node_name)
# can be Cast or Mul
if len(final_node_children) != 1 or dag.get_node_op_type(final_node_children[0]) not in ["Cast", "Mul"]:
logger.debug("RMSNorm Pattern does not end with Cast or Mul. Found %s", final_node_children)
continue
final_node_output_name = dag.get_node_outputs(final_node_name)[0]

# Get the weight Mul node
if dag.get_node_op_type(final_node_children[0]) == "Mul":
rmsnorm_mul_node_name = final_node_children[0]
else:
cast_node_children = dag.get_consumers(final_node_children[0])
if len(cast_node_children) != 1 or dag.get_node_op_type(cast_node_children[0]) != "Mul":
logger.debug("RMSNorm Pattern does not end with Cast -> Mul. Found %s", cast_node_children)
continue
rmsnorm_mul_node_name = cast_node_children[0]

# Get the weight Mul node inputs
rmsnorm_weight_name = None
for input_name in dag.get_node_inputs(rmsnorm_mul_node_name):
if dag.is_initializer(input_name):
rmsnorm_weight_name = input_name
break
if rmsnorm_weight_name is None:
logger.debug("RMSNorm Mul node does not have initializer input")
continue

# update weight and replace initializer
if rmsnorm_weight_name not in replaced_initializers:
# rotated models have all 1s and might share initializer
# don't want to multiply by sqrt(N) multiple times even though it is fine in all 1s case
rmsnorm_weight = dag.get_initializer_np_array(rmsnorm_weight_name)
sqrt_n = np.sqrt(rmsnorm_weight.shape[-1])
if np.all(rmsnorm_weight == 1):
# this is possible in a quarot/spinquant rotated model
# Multiplying by 1D is probably faster
rmsnorm_weight = np.array([1], dtype=rmsnorm_weight.dtype)
rmsnorm_weight = sqrt_n * rmsnorm_weight

dag.replace_initializer(
onnx.numpy_helper.from_array(rmsnorm_weight, name=rmsnorm_weight_name), graph_idx
)
replaced_initializers.add(rmsnorm_weight_name)

# add and replace nodes
dag.add_node(l2norm_node, graph_idx)
dag.replace_node_input(final_node_children[0], final_node_output_name, l2norm_node_output_name)
for rms_node_name in rmsnorm_nodes[::-1]:
dag.remove_node(rms_node_name)
removed_nodes.add(rms_node_name)

modified += 1

if modified > 0:
logger.debug("Replaced %d RMSNorm nodes with L2Norm nodes", modified)

dag.update()
return dag.model

@staticmethod
def get_rmsnorm_nodes(pow_node: str, dag: OnnxDAG) -> Optional[List[str]]:
# Two possible patterns:
# x / sqrt(mean(x^2) + epsilon): Pow -> ReduceMean -> Add -> Sqrt -> Div
# x * 1 / sqrt(mean(x^2) + epsilon): Pow -> ReduceMean -> Add -> Sqrt -> Div -> Mul
pattern = ["Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul"]
pow_node_input = dag.get_node_inputs(pow_node)[0]
current_node = pow_node
rmsnorm_nodes = [current_node]
for op_type in pattern[1:]:
child_nodes = dag.get_consumers(current_node)
if len(child_nodes) != 1 or dag.get_node_op_type(child_nodes[0]) != op_type:
return []
current_node = child_nodes[0]
rmsnorm_nodes.append(current_node)
if pow_node_input in dag.get_node_inputs(current_node):
# this can happen either at Div or Mul
# early stopping if it is Div
break

return rmsnorm_nodes if len(rmsnorm_nodes) >= (len(pattern) - 1) else []


class GraphSurgeries(Pass):
"""ONNX graph surgeries collections.
Expand Down
12 changes: 12 additions & 0 deletions olive/passes/onnx/onnx_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,18 @@ def add_initializer(self, initializer: TensorProto, graph_idx: int, keep_input:
"""
self._add_special_input(initializer, graph_idx, SpecialInput.INITIALIZER, keep_input)

def replace_initializer(self, initializer: TensorProto, graph_idx: int):
"""Replace an initializer in the graph.
:param initializer: TensorProto of the initializer.
:param graph_idx: index of the graph in the model.
"""
name = initializer.name
if not self.is_initializer(name):
raise ValueError(f"{name} is not an initializer.")
proto_list = self.ios[name].proto[:-1] + [initializer]
self.ios[name].proto = proto_list

def add_value_info(self, value_info: ValueInfoProto, graph_idx: int, overwrite: bool = False):
"""Add a value info to the graph.
Expand Down
77 changes: 77 additions & 0 deletions test/unit_test/passes/onnx/test_graph_surgeries.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@

import numpy as np
import onnx
import pytest
import torch
from onnx import TensorProto, helper, numpy_helper
from onnxruntime import InferenceSession

from olive.model.handler.onnx import ONNXModelHandler
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.onnx.graph_surgeries import GraphSurgeries
from olive.passes.onnx.onnx_dag import OnnxDAG


def get_onnx_model(model_path):
Expand Down Expand Up @@ -351,3 +355,76 @@ def test_expose_quantized_output(tmp_path):
assert np.allclose(
numpy_helper.to_array(zero_point_initializer), np.array([original_zero_point_value], dtype=zero_point_dtype)
), "Zero point value mismatch."


class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6, use_rsqrt=True, use_cast=True, all_ones=False):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size) if all_ones else torch.randn(hidden_size))
self.variance_epsilon = eps
self.use_rsqrt = use_rsqrt
self.use_cast = use_cast

def forward(self, hidden_states):
input_dtype = hidden_states.dtype
if self.use_cast:
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
if self.use_rsqrt:
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
else:
hidden_states = hidden_states / torch.sqrt(variance + self.variance_epsilon)
return self.weight * (hidden_states.to(input_dtype) if self.use_cast else hidden_states)


@pytest.mark.parametrize("use_rsqrt", [True, False])
@pytest.mark.parametrize("use_cast", [True, False])
@pytest.mark.parametrize("all_ones", [True, False])
def test_rmsnorm_to_l2norm(tmp_path, use_rsqrt, use_cast, all_ones):
# setup
hidden_size = 3
module = RMSNorm(hidden_size, use_rsqrt=use_rsqrt, use_cast=use_cast, all_ones=all_ones)
input_model_path = tmp_path / "input_model.onnx"
torch.onnx.export(
module, torch.randn(1, hidden_size), input_model_path, input_names=["x"], output_names=["y"], opset_version=20
)
input_model = ONNXModelHandler(input_model_path)

output_folder = str(tmp_path / "output")

p = create_pass_from_dict(
GraphSurgeries,
{"surgeries": [{"surgeon": "RMSNormToL2Norm"}]},
disable_search=True,
)

# execute
onnx_model = p.run(input_model, output_folder)

# assert
# check output values match
input_session = InferenceSession(input_model_path)
output_session = InferenceSession(onnx_model.model_path)
input_feed = {"x": np.random.randn(1, hidden_size).astype(np.float32)}
input_result = input_session.run(None, input_feed)
output_result = output_session.run(None, input_feed)
np.testing.assert_allclose(input_result[0], output_result[0], rtol=1e-5, atol=1e-5)
# count nodes
dag = OnnxDAG.from_model_path(onnx_model.model_path)
expected_num_nodes = 2 + 2 * int(use_cast)
assert len(dag.nodes) == expected_num_nodes
# check all ones case
if all_ones:
mul_name = None
for node in dag.get_node_names():
if dag.get_node_op_type(node) == "Mul":
mul_name = node
break
mul_weight_name = None
for input_name in dag.get_node_inputs(mul_name):
if dag.is_initializer(input_name):
mul_weight_name = input_name
break
mul_weight = dag.get_initializer_np_array(mul_weight_name)
assert mul_weight.shape == (1,)
assert np.allclose(mul_weight, np.sqrt(hidden_size))

0 comments on commit 29758a5

Please sign in to comment.