Skip to content

Commit

Permalink
add make_faster_guard
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Feb 3, 2025
1 parent 7ca88d1 commit 7ae36b8
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ...psdb import NO_FALLBACK_CODES
from ...utils import (
ENV_SOT_ALLOW_DYNAMIC_SHAPE,
ENV_SOT_ENABLE_GUARD_TREE,
BreakGraphError,
CompileCountInfo,
FallbackError,
Expand Down Expand Up @@ -135,7 +136,8 @@ def lookup(
f"[Cache]: Cache hit, Guard is \n{getattr(guard_fn, 'expr', 'None')}\n",
)
return custom_code
else:
elif not ENV_SOT_ENABLE_GUARD_TREE.get():
# TODO(zrr1999): remove condition after faster guard tree support error analysis
log_do(
4,
self.analyse_guard_global_object(guard_fn),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
from ...symbolic.symbolic_context import SymbolicTraceContext
from ...utils import (
ENV_SOT_ALLOW_DYNAMIC_SHAPE,
ENV_SOT_ENABLE_FASTER_GUARD,
ENV_SOT_ENABLE_GUARD_TREE,
NameGenerator,
SotUndefinedVar,
inner_error_default_handler,
Expand All @@ -54,7 +56,7 @@
)
from ...utils.exceptions import BreakGraphError, SotExtraInfo
from ..instruction_utils import get_instructions
from .guard import Guard, StringifiedExpression, make_guard
from .guard import Guard, StringifiedExpression, make_faster_guard, make_guard
from .mutable_data import MutationDel, MutationNew, MutationSet
from .pycode_generator import PyCodeGen
from .side_effects import (
Expand Down Expand Up @@ -316,6 +318,19 @@ def collect(inp):
@property
@event_register("guard_fn")
def guard_fn(self) -> Guard:
if (
ENV_SOT_ENABLE_FASTER_GUARD.get()
and ENV_SOT_ENABLE_GUARD_TREE.get()
):
guard_nodes: list[paddle.framework.core.GuardNode] = []
with EventGuard("guard_fn: find vars and make faster guard"):
for variable in find_traceable_vars(
self.input_variables + list(self._global_guarded_variables)
):
guard_nodes.extend(variable.make_faster_guard())

return make_faster_guard(guard_nodes)

with switch_symbol_registry():
guards: list[StringifiedExpression] = []
with EventGuard("guard_fn: find vars and make stringified guard"):
Expand Down
35 changes: 35 additions & 0 deletions python/paddle/jit/sot/opcode_translator/executor/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,19 @@ def make_guard(stringified_guards: list[StringifiedExpression]) -> Guard:
return guard


def make_faster_guard(
guard_nodes: list[paddle.framework.core.GuardNode],
) -> Guard:
with EventGuard("make_guard"):
num_guards = len(guard_nodes)
if not num_guards:
guard = lambda frame: True
return guard
guard_tree = paddle.framework.core.GuardTree([guard_nodes])
guard = lambda frame: guard_tree.check(frame) is not None
return guard


def support_weak_ref(obj):
if isinstance(obj, types.FunctionType):
return True
Expand All @@ -184,6 +197,28 @@ def guard_log():
return wrapper


def check_faster_guard(
fn: Callable[[CheckGuardInputT], list[paddle.framework.core.GuardNode]],
) -> Callable[[CheckGuardInputT], list[paddle.framework.core.GuardNode]]:
def wrapper(
self: CheckGuardInputT,
) -> list[paddle.framework.core.GuardNode]:
assert (
self.tracker.is_traceable()
), "Cannot make guard from a non-tracable guard variable."

def guard_log():
frame_value_tracer = self.tracker.trace_value_from_frame()
print(
f"[Guard]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.registered_expr}"
)

log_do(4, guard_log)
return fn(self)

return wrapper


@check_guard
def object_equal_stringified_guard(self) -> list[StringifiedExpression]:
frame_value_tracer = self.tracker.trace_value_from_frame()
Expand Down
11 changes: 11 additions & 0 deletions python/paddle/jit/sot/opcode_translator/executor/variables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ..guard import (
FasterStringifiedExpression,
StringifiedExpression,
check_faster_guard,
check_guard,
union_free_vars,
)
Expand Down Expand Up @@ -362,6 +363,16 @@ def debug_name(self, name):
def __hash__(self):
return hash(self.id)

@check_faster_guard
def make_faster_guard(self) -> list[paddle.framework.core.GuardNode]:
frame_value_tracer = self.tracker.guard_tree_expr_node()
return [
paddle.framework.core.GuardNode(
paddle.framework.core.ValueMatchGuard(self.get_py_value()),
frame_value_tracer,
)
]

@check_guard
def make_stringified_guard(self) -> list[StringifiedExpression]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from ..guard import (
FasterStringifiedExpression,
StringifiedExpression,
check_faster_guard,
check_guard,
object_equal_stringified_guard,
stringify_pyobject,
Expand Down Expand Up @@ -451,6 +452,31 @@ def out_var_name(self):
def _reconstruct(self, codegen: PyCodeGen):
codegen.gen_load_fast(self.out_var_name)

@check_faster_guard
def make_faster_guard(self) -> list[paddle.framework.core.GuardNode]:
assert paddle.framework.use_pir_api(), "Only support PIR"
expr_node = self.tracker.guard_tree_expr_node()
meta = self.origin_meta
return [
# Check shape
paddle.framework.core.GuardNode(
paddle.framework.core.ShapeMatchGuard(meta.shape),
expr_node,
),
# Check dtype
paddle.framework.core.GuardNode(
paddle.framework.core.DtypeMatchGuard(meta.dtype),
expr_node,
),
# Check stop_gradient
paddle.framework.core.GuardNode(
paddle.framework.core.ValueMatchGuard(meta.stop_gradient),
paddle.framework.core.AttributeExprNode(
expr_node, "stop_gradient"
),
),
]

@check_guard
def make_stringified_guard(self) -> list[StringifiedExpression]:
frame_value_tracer = self.tracker.trace_value_from_frame()
Expand Down

0 comments on commit 7ae36b8

Please sign in to comment.