From 29758a5f77d2fc7090f6b9d2f05186eb7b04813d Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Tue, 4 Feb 2025 11:41:43 -0800 Subject: [PATCH] Add `RMSNormToL2Norm` graph surgeon (#1594) ## Describe your changes New graph surgeon to simplify RMSNorm subgraph using L2Norm subgraph. `rmsnorm = x/sqrt(mean(x**2)) = x/sqrt(1/N*sum(x**2)) = sqrt(N)*x/sqrt(sum(x**2)) = sqrt(N)*L2Norm` This is useful when quantizing the model for NPU since there are fewer activations to quantize. ## Checklist before requesting a review - [x] Add unit tests for this change. - [x] Make sure all tests can pass. - [x] Update documents if necessary. - [ ] Lint and apply fixes to your code by running `lintrunner -a` - [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes. - [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR. ## (Optional) Issue link --- .../configure-workflows/onnx-graph-surgeon.md | 41 +++++ olive/passes/onnx/graph_surgeries.py | 144 +++++++++++++++++- olive/passes/onnx/onnx_dag.py | 12 ++ .../passes/onnx/test_graph_surgeries.py | 77 ++++++++++ 4 files changed, 273 insertions(+), 1 deletion(-) diff --git a/docs/source/how-to/configure-workflows/onnx-graph-surgeon.md b/docs/source/how-to/configure-workflows/onnx-graph-surgeon.md index 2cac249c4..33246ca64 100644 --- a/docs/source/how-to/configure-workflows/onnx-graph-surgeon.md +++ b/docs/source/how-to/configure-workflows/onnx-graph-surgeon.md @@ -656,3 +656,44 @@ graph { output: "zero_point" } ``` + + +### RMSNormToL2Norm + +#### Description + +Replace RMSNorm subgraph with L2Norm subgraph. + +#### Example + +Initial model graph: + +``` +RMSNorm pattern: + +-----------------------------------------------+ + | | + | v +[Root] --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul + (y=2) (axis=-1) (B=E-6) +``` + +After applying: + +```json +{ + "type": "GraphSurgeries", + "surgeries": [ + { + "surgeon": "RMSNormToL2Norm" + } + ] +} +``` + + +Transformed model graph: + +``` +[Root] --> LpNormalization --> Mul + (p=2, axis=-1) +``` diff --git a/olive/passes/onnx/graph_surgeries.py b/olive/passes/onnx/graph_surgeries.py index 3df770667..7ada8c07f 100644 --- a/olive/passes/onnx/graph_surgeries.py +++ b/olive/passes/onnx/graph_surgeries.py @@ -8,7 +8,7 @@ import inspect import logging from pathlib import Path -from typing import Any, ClassVar, Dict, List, Type +from typing import Any, ClassVar, Dict, List, Optional, Type import numpy as np import onnx @@ -20,6 +20,7 @@ from olive.model.utils import resolve_onnx_path from olive.passes import Pass from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model +from olive.passes.onnx.onnx_dag import OnnxDAG from olive.passes.pass_config import PassConfigParam logger = logging.getLogger(__name__) @@ -455,6 +456,147 @@ def __call__(self, model: ModelProto): return self._add_zero_point(model, zero_point_value, zero_point_onnx_dtype, zero_point_np_dtype) +class RMSNormToL2Norm(Surgeon): + """Replace RMSNorm subgraph with L2Norm subgraph. + + RMSNorm pattern: + +-----------------------------------------------+ + | | + | v + [Root] --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul + (y=2) (axis=-1) (B=E-6) + + Also handles the case where [Root] is multiplied with rsqrt leading to Div -> Mul + instead of dividing by sqrt. + + L2Norm pattern: + [Root] --> LpNormalization --> Mul + (p=2, axis=-1) + + Where the weight of the Mul node is multiplied by sqrt(N) where N is equal to the reduced axis size. + If the weight is all 1s, it is replaced with a 1D array of sqrt(N). + """ + + def __init__(self): + pass + + def __call__(self, model: ModelProto): + dag = OnnxDAG(model) + + modified = 0 + removed_nodes = set() + replaced_initializers = set() + for node_name in dag.get_node_names(): + if node_name in removed_nodes or dag.get_node_op_type(node_name) != "Pow": + continue + + rmsnorm_nodes = self.get_rmsnorm_nodes(node_name, dag) + if not rmsnorm_nodes: + continue + + graph_idx = dag.get_graph_idx(node_name) + + # name for the new L2Norm node + l2norm_node_name = node_name.replace("Pow", "L2Norm") if "Pow" in node_name else f"{node_name}_L2Norm" + node_output_name = dag.get_node_outputs(node_name)[0] + l2norm_node_output_name = ( + node_output_name.replace("Pow", "L2Norm") if "Pow" in node_output_name else f"{node_output_name}_L2Norm" + ) + + # create L2Norm node + l2norm_node = onnx.helper.make_node( + "LpNormalization", + inputs=[dag.get_node_inputs(node_name)[0]], + outputs=[l2norm_node_output_name], + name=l2norm_node_name, + axis=-1, + p=2, + ) + + # Muliply weight by sqrt(N) + final_node_name = rmsnorm_nodes[-1] + final_node_children = dag.get_consumers(final_node_name) + # can be Cast or Mul + if len(final_node_children) != 1 or dag.get_node_op_type(final_node_children[0]) not in ["Cast", "Mul"]: + logger.debug("RMSNorm Pattern does not end with Cast or Mul. Found %s", final_node_children) + continue + final_node_output_name = dag.get_node_outputs(final_node_name)[0] + + # Get the weight Mul node + if dag.get_node_op_type(final_node_children[0]) == "Mul": + rmsnorm_mul_node_name = final_node_children[0] + else: + cast_node_children = dag.get_consumers(final_node_children[0]) + if len(cast_node_children) != 1 or dag.get_node_op_type(cast_node_children[0]) != "Mul": + logger.debug("RMSNorm Pattern does not end with Cast -> Mul. Found %s", cast_node_children) + continue + rmsnorm_mul_node_name = cast_node_children[0] + + # Get the weight Mul node inputs + rmsnorm_weight_name = None + for input_name in dag.get_node_inputs(rmsnorm_mul_node_name): + if dag.is_initializer(input_name): + rmsnorm_weight_name = input_name + break + if rmsnorm_weight_name is None: + logger.debug("RMSNorm Mul node does not have initializer input") + continue + + # update weight and replace initializer + if rmsnorm_weight_name not in replaced_initializers: + # rotated models have all 1s and might share initializer + # don't want to multiply by sqrt(N) multiple times even though it is fine in all 1s case + rmsnorm_weight = dag.get_initializer_np_array(rmsnorm_weight_name) + sqrt_n = np.sqrt(rmsnorm_weight.shape[-1]) + if np.all(rmsnorm_weight == 1): + # this is possible in a quarot/spinquant rotated model + # Multiplying by 1D is probably faster + rmsnorm_weight = np.array([1], dtype=rmsnorm_weight.dtype) + rmsnorm_weight = sqrt_n * rmsnorm_weight + + dag.replace_initializer( + onnx.numpy_helper.from_array(rmsnorm_weight, name=rmsnorm_weight_name), graph_idx + ) + replaced_initializers.add(rmsnorm_weight_name) + + # add and replace nodes + dag.add_node(l2norm_node, graph_idx) + dag.replace_node_input(final_node_children[0], final_node_output_name, l2norm_node_output_name) + for rms_node_name in rmsnorm_nodes[::-1]: + dag.remove_node(rms_node_name) + removed_nodes.add(rms_node_name) + + modified += 1 + + if modified > 0: + logger.debug("Replaced %d RMSNorm nodes with L2Norm nodes", modified) + + dag.update() + return dag.model + + @staticmethod + def get_rmsnorm_nodes(pow_node: str, dag: OnnxDAG) -> Optional[List[str]]: + # Two possible patterns: + # x / sqrt(mean(x^2) + epsilon): Pow -> ReduceMean -> Add -> Sqrt -> Div + # x * 1 / sqrt(mean(x^2) + epsilon): Pow -> ReduceMean -> Add -> Sqrt -> Div -> Mul + pattern = ["Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul"] + pow_node_input = dag.get_node_inputs(pow_node)[0] + current_node = pow_node + rmsnorm_nodes = [current_node] + for op_type in pattern[1:]: + child_nodes = dag.get_consumers(current_node) + if len(child_nodes) != 1 or dag.get_node_op_type(child_nodes[0]) != op_type: + return [] + current_node = child_nodes[0] + rmsnorm_nodes.append(current_node) + if pow_node_input in dag.get_node_inputs(current_node): + # this can happen either at Div or Mul + # early stopping if it is Div + break + + return rmsnorm_nodes if len(rmsnorm_nodes) >= (len(pattern) - 1) else [] + + class GraphSurgeries(Pass): """ONNX graph surgeries collections. diff --git a/olive/passes/onnx/onnx_dag.py b/olive/passes/onnx/onnx_dag.py index 94f08d517..c0f73a4e9 100644 --- a/olive/passes/onnx/onnx_dag.py +++ b/olive/passes/onnx/onnx_dag.py @@ -243,6 +243,18 @@ def add_initializer(self, initializer: TensorProto, graph_idx: int, keep_input: """ self._add_special_input(initializer, graph_idx, SpecialInput.INITIALIZER, keep_input) + def replace_initializer(self, initializer: TensorProto, graph_idx: int): + """Replace an initializer in the graph. + + :param initializer: TensorProto of the initializer. + :param graph_idx: index of the graph in the model. + """ + name = initializer.name + if not self.is_initializer(name): + raise ValueError(f"{name} is not an initializer.") + proto_list = self.ios[name].proto[:-1] + [initializer] + self.ios[name].proto = proto_list + def add_value_info(self, value_info: ValueInfoProto, graph_idx: int, overwrite: bool = False): """Add a value info to the graph. diff --git a/test/unit_test/passes/onnx/test_graph_surgeries.py b/test/unit_test/passes/onnx/test_graph_surgeries.py index d2db03fbf..daf1611bc 100644 --- a/test/unit_test/passes/onnx/test_graph_surgeries.py +++ b/test/unit_test/passes/onnx/test_graph_surgeries.py @@ -6,11 +6,15 @@ import numpy as np import onnx +import pytest +import torch from onnx import TensorProto, helper, numpy_helper +from onnxruntime import InferenceSession from olive.model.handler.onnx import ONNXModelHandler from olive.passes.olive_pass import create_pass_from_dict from olive.passes.onnx.graph_surgeries import GraphSurgeries +from olive.passes.onnx.onnx_dag import OnnxDAG def get_onnx_model(model_path): @@ -351,3 +355,76 @@ def test_expose_quantized_output(tmp_path): assert np.allclose( numpy_helper.to_array(zero_point_initializer), np.array([original_zero_point_value], dtype=zero_point_dtype) ), "Zero point value mismatch." + + +class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6, use_rsqrt=True, use_cast=True, all_ones=False): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size) if all_ones else torch.randn(hidden_size)) + self.variance_epsilon = eps + self.use_rsqrt = use_rsqrt + self.use_cast = use_cast + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + if self.use_cast: + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + if self.use_rsqrt: + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + else: + hidden_states = hidden_states / torch.sqrt(variance + self.variance_epsilon) + return self.weight * (hidden_states.to(input_dtype) if self.use_cast else hidden_states) + + +@pytest.mark.parametrize("use_rsqrt", [True, False]) +@pytest.mark.parametrize("use_cast", [True, False]) +@pytest.mark.parametrize("all_ones", [True, False]) +def test_rmsnorm_to_l2norm(tmp_path, use_rsqrt, use_cast, all_ones): + # setup + hidden_size = 3 + module = RMSNorm(hidden_size, use_rsqrt=use_rsqrt, use_cast=use_cast, all_ones=all_ones) + input_model_path = tmp_path / "input_model.onnx" + torch.onnx.export( + module, torch.randn(1, hidden_size), input_model_path, input_names=["x"], output_names=["y"], opset_version=20 + ) + input_model = ONNXModelHandler(input_model_path) + + output_folder = str(tmp_path / "output") + + p = create_pass_from_dict( + GraphSurgeries, + {"surgeries": [{"surgeon": "RMSNormToL2Norm"}]}, + disable_search=True, + ) + + # execute + onnx_model = p.run(input_model, output_folder) + + # assert + # check output values match + input_session = InferenceSession(input_model_path) + output_session = InferenceSession(onnx_model.model_path) + input_feed = {"x": np.random.randn(1, hidden_size).astype(np.float32)} + input_result = input_session.run(None, input_feed) + output_result = output_session.run(None, input_feed) + np.testing.assert_allclose(input_result[0], output_result[0], rtol=1e-5, atol=1e-5) + # count nodes + dag = OnnxDAG.from_model_path(onnx_model.model_path) + expected_num_nodes = 2 + 2 * int(use_cast) + assert len(dag.nodes) == expected_num_nodes + # check all ones case + if all_ones: + mul_name = None + for node in dag.get_node_names(): + if dag.get_node_op_type(node) == "Mul": + mul_name = node + break + mul_weight_name = None + for input_name in dag.get_node_inputs(mul_name): + if dag.is_initializer(input_name): + mul_weight_name = input_name + break + mul_weight = dag.get_initializer_np_array(mul_weight_name) + assert mul_weight.shape == (1,) + assert np.allclose(mul_weight, np.sqrt(hidden_size))