[exir] Enable dict for sym shape eval pass
Differential Revision: D61728068

Pull Request resolved: #4872
larryliu0820 authored Aug 23, 2024
1 parent 26e921e commit 48f4eee
Showing 4 changed files with 75 additions and 19 deletions.
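
In short: ExecutorchBackendConfig.sym_shape_eval_pass now accepts either a single pass or a dict keyed by method name, so different methods of a multi-method program can use different symbolic-shape evaluation strategies. A minimal sketch of the per-method form, distilled from the export_llava.py change below (edge_manager is a hypothetical stand-in for the EdgeProgramManager produced by to_edge_transform_and_lower):

    from executorch.exir import ExecutorchBackendConfig
    from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

    executorch_program = edge_manager.to_executorch(
        ExecutorchBackendConfig(
            # "image_encoder" gets the constraint-based pass; any method not
            # listed falls back to the default HintBasedSymShapeEvalPass().
            sym_shape_eval_pass={"image_encoder": ConstraintBasedSymShapeEvalPass()},
        )
    )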
35 changes: 26 additions & 9 deletions examples/models/llava/export_llava.py
@@ -23,7 +23,15 @@
     replace_sdpa_with_custom_op,
 )
 from executorch.examples.models.llava.model import LlavaModel
-from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
+from executorch.exir import (
+    EdgeCompileConfig,
+    ExecutorchBackendConfig,
+    to_edge_transform_and_lower,
+)
+
+from executorch.exir.passes import MemoryPlanningPass
+from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
+from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 
 from executorch.extension.llm.export.builder import DType, LLMEdgeManager
 from executorch.extension.llm.tokenizer.tokenizer import Tokenizer
@@ -199,7 +207,23 @@ def export_all(llava_model: LlavaModel):
         compile_config=EdgeCompileConfig(_check_ir_validity=False),
     )
 
-    executorch_program = lowered_and_edge.to_executorch()
+    executorch_program = lowered_and_edge.to_executorch(
+        ExecutorchBackendConfig(
+            extract_constant_segment=True,
+            extract_delegate_segments=True,
+            passes=[
+                QuantFusionPass(),
+            ],
+            memory_planning_pass=MemoryPlanningPass("greedy", alloc_graph_input=False),
+            sym_shape_eval_pass={
+                "image_encoder": ConstraintBasedSymShapeEvalPass(),
+            },
+        )
+    )
+    for execution_plan in executorch_program._emitter_output.program.execution_plan:
+        logging.info(
+            f"Required memory for activation in bytes: {execution_plan.non_const_buffer_sizes}"
+        )
     return executorch_program


@@ -253,13 +277,6 @@ def main():
 
     with open(args.pte_name, "wb") as f:
         executorch_program.write_to_file(f)
-        logging.info(
-            "Required memory for activation in bytes: {}".format(
-                executorch_program._emitter_output.program.execution_plan[
-                    0
-                ].non_const_buffer_sizes
-            ),
-        )
     logging.info(f"Exported ExecuTorch program to {args.pte_name}")
 
     # artifacts
11 changes: 11 additions & 0 deletions examples/models/llava/test/test_pte.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 import sys
 
 import torch
@@ -17,6 +18,10 @@
 from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa
 
 
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.DEBUG, format=FORMAT)
+
+
 def main():
     args = sys.argv[1:]
     llava_module = _load_for_executorch(args[0])
@@ -41,26 +46,32 @@ def main():
     start_pos += pte_prefill_before_img.shape[1]
 
     # pte prefill image
+    logging.warning("Image encoder started")
     pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
+    logging.warning("Image encoder finished")
+    logging.warning("Image token prefill started")
     pte_prefill_img = llava_module.run_method(
         "text_model",
         (
             torch.tensor([start_pos], dtype=torch.int64),
             pte_embeds_img,
         ),
     )[0]
+    logging.warning("Image token prefill finished")
     print(pte_prefill_img)
 
     start_pos += pte_prefill_img.shape[1]
 
     # pte prefill prompt after img
+    logging.warning("Text token prefill started")
     pte_embeds_after_img = llava_module.run_method(
         "token_embedding", (prompt_after_image,)
     )[0]
     pte_prefill_after_img = llava_module.run_method(
         "text_model",
         (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
     )[0]
+    logging.warning("Text token prefill finished")
     print(pte_prefill_after_img)
 
     # being tested, using llama_transformer
7 changes: 6 additions & 1 deletion exir/capture/_config.py
@@ -82,7 +82,12 @@ class ExecutorchBackendConfig:
     # If provided, the minimum alignment of delegate data in the program. Must
     # be a power of 2. If not provided, uses the value in the schema file.
     delegate_alignment: Optional[int] = None
-    sym_shape_eval_pass: PassType = HintBasedSymShapeEvalPass()
+
+    # A single sym shape eval pass can be defined for all the programs in the
+    # EdgeProgramManager or can be defined per program.
+    sym_shape_eval_pass: Union[PassType, Dict[str, PassType]] = (
+        HintBasedSymShapeEvalPass()
+    )
 
     # If set to true, view_copy operations will be converted to lightweight
     # view operations in the ET runtime
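
Both forms the new Union type allows, as a sketch (imports match those used elsewhere in this commit):

    from executorch.exir import ExecutorchBackendConfig
    from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

    # One pass applied to every method in the program (the previous behavior):
    config_single = ExecutorchBackendConfig(
        sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass()
    )

    # Per-method dict; methods without an entry fall back to the default
    # HintBasedSymShapeEvalPass():
    config_per_method = ExecutorchBackendConfig(
        sym_shape_eval_pass={"image_encoder": ConstraintBasedSymShapeEvalPass()}
    )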
41 changes: 32 additions & 9 deletions exir/program/_program.py
@@ -13,7 +13,6 @@
 
 import torch
 import torch._export
-
 from executorch.exir._serialize import _serialize_pte_binary
 from executorch.exir._serialize._cord import Cord
 from executorch.exir.backend.backend_api import to_backend
@@ -23,6 +22,7 @@
 from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
 from executorch.exir.error import ExportError
 from executorch.exir.graph_module import get_control_flow_submodules
+from executorch.exir.pass_base import PassBase
 from executorch.exir.pass_manager import PassType
 from executorch.exir.passes import (
     base_post_op_replace_passes,
@@ -641,25 +641,48 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
     return new_ep
 
 
-def pre_memory_planning_passes(config: ExecutorchBackendConfig) -> List[PassType]:
+def pre_memory_planning_passes(
+    config: ExecutorchBackendConfig, name: Optional[str] = None
+) -> List[PassType]:
     """
     Returns a list of passes to run before memory planning.
+    Gets the sym shape eval pass based on the method name; if the name is not in the dict, uses the default pass.
     """
+    # Handle symbolic shape eval pass
+    if isinstance(config.sym_shape_eval_pass, dict):
+        default_pass = ExecutorchBackendConfig().sym_shape_eval_pass
+        if not name:
+            sym_shape_eval_pass = default_pass
+        # pyre-ignore: Undefined attribute [16]
+        sym_shape_eval_pass = config.sym_shape_eval_pass.get(name, default_pass)
+    elif isinstance(config.sym_shape_eval_pass, PassBase):
+        sym_shape_eval_pass = config.sym_shape_eval_pass
+    else:
+        raise RuntimeError(
+            f"sym_shape_eval_pass must be a dict or a PassBase, got {config.sym_shape_eval_pass}"
+        )
     if config.remove_view_copy:
         # pyre-ignore
         return [
             NormalizeViewCopyBasePass(),
             dead_code_elimination_pass,
             ReplaceViewCopyWithViewPass(),
-            config.sym_shape_eval_pass,
+            sym_shape_eval_pass,
             config.to_out_var_pass,
         ]
     else:
         # pyre-ignore
         return [
-            config.sym_shape_eval_pass,
+            sym_shape_eval_pass,
             config.to_out_var_pass,
         ]
 
 
-def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]:
+def edge_to_executorch_passes(
+    config: ExecutorchBackendConfig, name: Optional[str] = None
+) -> List[PassType]:
     """
     Returns a list of passes to lower from edge to executorch.
+    Gets the pre memory planning passes based on the method name; if the name is not in the dict, uses the default pass.
     """
     passes: List[PassType] = [
         *config.passes,
         SpecPropPass(),
@@ -668,7 +691,7 @@ def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]:
         # there exists an unbacked symint operation.
         EdgeToBackendOpsPass(),
         RemoveGraphAssertsPass(),
-    ] + pre_memory_planning_passes(config)
+    ] + pre_memory_planning_passes(config, name)
 
     return passes

@@ -1234,7 +1257,7 @@ def to_executorch(
         program = unsafe_remove_auto_functionalized_pass(program)
         gm, new_signature = insert_write_back_for_buffers_pass(program)
         new_gm = program.graph_module
-        for p in edge_to_executorch_passes(config):
+        for p in edge_to_executorch_passes(config, name):
             new_gm_res = p(new_gm)
             assert new_gm_res is not None
             new_gm = new_gm_res.graph_module
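
The fallback semantics above boil down to dict.get with the default pass. A standalone sketch, illustrative only: strings stand in for pass instances, and resolve is a hypothetical helper mirroring pre_memory_planning_passes:

    from typing import Dict, Optional, Union

    def resolve(configured: Union[str, Dict[str, str]], name: Optional[str], default: str) -> str:
        if isinstance(configured, dict):
            # A missing (or None) method name falls back to the default pass.
            return configured.get(name, default)
        return configured

    assert resolve({"image_encoder": "constraint"}, "image_encoder", "hint") == "constraint"
    assert resolve({"image_encoder": "constraint"}, "text_model", "hint") == "hint"
    assert resolve("constraint", "text_model", "hint") == "constraint"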