Skip to content

Commit

Permalink
Merge branch 'constant_folding_main' of github.com:SF-N/gt4py into co…
Browse files Browse the repository at this point in the history
…nstant_folding_main
  • Loading branch information
SF-N committed Jan 28, 2025
2 parents 9676334 + 9e14480 commit 1543688
Show file tree
Hide file tree
Showing 20 changed files with 196 additions and 267 deletions.
14 changes: 7 additions & 7 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def visit_ScanOperator(

definition = itir.Lambda(params=func_definition.params, expr=new_body)

body = im.as_fieldop(im.call("scan")(definition, forward, init))(*stencil_args)
body = im.as_fieldop(im.scan(definition, forward, init))(*stencil_args)

return itir.FunctionDefinition(id=node.id, params=definition.params[1:], expr=body)

Expand Down Expand Up @@ -360,7 +360,7 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
obj, new_type = self.visit(node.args[0], **kwargs), node.args[1].id

def create_cast(expr: itir.Expr, t: tuple[ts.TypeSpec]) -> itir.FunCall:
return _map(im.lambda_("val")(im.call("cast_")("val", str(new_type))), (expr,), t)
return _map(im.lambda_("val")(im.cast_("val", str(new_type))), (expr,), t)

if not isinstance(node.type, ts.TupleType): # to keep the IR simpler
return create_cast(obj, (node.args[0].type,))
Expand Down Expand Up @@ -409,7 +409,7 @@ def _make_reduction_expr(
# TODO(havogt): deal with nested reductions of the form neighbor_sum(neighbor_sum(field(off1)(off2)))
it = self.visit(node.args[0], **kwargs)
assert isinstance(node.kwargs["axis"].type, ts.DimensionType)
val = im.call(im.call("reduce")(op, init_expr))
val = im.reduce(op, init_expr)
return im.op_as_fieldop(val)(it)

def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
Expand Down Expand Up @@ -462,14 +462,14 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr:
def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr:
return self._make_literal(node.value, node.type)

def _lower_and_map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall:
def _lower_and_map(self, op: itir.Lambda | str, *args: Any, **kwargs: Any) -> itir.FunCall:
return _map(
op, tuple(self.visit(arg, **kwargs) for arg in args), tuple(arg.type for arg in args)
)


def _map(
op: itir.Expr | str,
op: itir.Lambda | str,
lowered_args: tuple,
original_arg_types: tuple[ts.TypeSpec, ...],
) -> itir.FunCall:
Expand All @@ -487,9 +487,9 @@ def _map(
promote_to_list(arg_type)(larg)
for arg_type, larg in zip(original_arg_types, lowered_args)
)
op = im.call("map_")(op)
op = im.map_(op)

return im.op_as_fieldop(im.call(op))(*lowered_args)
return im.op_as_fieldop(op)(*lowered_args)


class FieldOperatorLoweringError(Exception): ...
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def is_call_to(node: Any, fun: str | Iterable[str]) -> TypeGuard[itir.FunCall]:
attribute which can be anything.
>>> from gt4py.next.iterator.ir_utils import ir_makers as im
>>> node = im.call("plus")(1, 2)
>>> node = im.plus(1, 2)
>>> is_call_to(node, "plus")
True
>>> is_call_to(node, "minus")
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/ir_utils/domain_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,11 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain:
assert all(domain.ranges.keys() == domains[0].ranges.keys() for domain in domains)
for dim in domains[0].ranges.keys():
start = functools.reduce(
lambda current_expr, el_expr: im.call("minimum")(current_expr, el_expr),
lambda current_expr, el_expr: im.minimum(current_expr, el_expr),
[domain.ranges[dim].start for domain in domains],
)
stop = functools.reduce(
lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr),
lambda current_expr, el_expr: im.maximum(current_expr, el_expr),
[domain.ranges[dim].stop for domain in domains],
)
# constant fold expression to keep the tree small
Expand Down
56 changes: 42 additions & 14 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,18 +169,6 @@ def divides_(left, right):
return call("divides")(left, right)


def floordiv_(left, right):
"""Create a floor division FunCall, shorthand for ``call("floordiv")(left, right)``."""
# TODO(tehrengruber): Use int(floor(left/right)) as soon as we support integer casting
# and remove the `floordiv` builtin again.
return call("floordiv")(left, right)


def mod(left, right):
"""Create a modulo FunCall, shorthand for ``call("mod")(left, right)``."""
return call("mod")(left, right)


def and_(left, right):
"""Create an and_ FunCall, shorthand for ``call("and_")(left, right)``."""
return call("and_")(left, right)
Expand Down Expand Up @@ -302,7 +290,10 @@ def shift(offset, value=None):
offset = ensure_offset(offset)
args = [offset]
if value is not None:
value = ensure_offset(value)
if isinstance(value, int):
value = ensure_offset(value)
elif isinstance(value, str):
value = ref(value)
args.append(value)
return call(call("shift")(*args))

Expand Down Expand Up @@ -469,7 +460,7 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> cal


def op_as_fieldop(
op: str | itir.SymRef | Callable, domain: Optional[itir.FunCall] = None
op: str | itir.SymRef | itir.Lambda | Callable, domain: Optional[itir.FunCall] = None
) -> Callable[..., itir.FunCall]:
"""
Promotes a function `op` to a field_operator.
Expand Down Expand Up @@ -536,3 +527,40 @@ def index(dim: common.Dimension) -> itir.FunCall:
def map_(op):
"""Create a `map_` call."""
return call(call("map_")(op))


def reduce(op, expr):
"""Create a `reduce` call."""
return call(call("reduce")(op, expr))


def scan(expr, forward, init):
"""Create a `scan` call."""
return call("scan")(expr, forward, init)


def list_get(list_idx, list_):
"""Create a `list_get` call."""
return call("list_get")(list_idx, list_)


def maximum(expr1, expr2):
"""Create a `maximum` call."""
return call("maximum")(expr1, expr2)


def minimum(expr1, expr2):
"""Create a `minimum` call."""
return call("minimum")(expr1, expr2)


def cast_(expr, dtype: ts.ScalarType | str):
"""Create a `cast_` call."""
if isinstance(dtype, ts.ScalarType):
dtype = dtype.kind.name.lower()
return call("cast_")(expr, dtype)


def can_deref(expr):
"""Create a `can_deref` call."""
return call("can_deref")(expr)
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/collapse_list_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node:
cond, true_val, false_val = node.args[1].args
return im.if_(
cond,
self.visit(im.call("list_get")(list_idx, true_val)),
self.visit(im.call("list_get")(list_idx, false_val)),
self.visit(im.list_get(list_idx, true_val)),
self.visit(im.list_get(list_idx, false_val)),
)
if cpm.is_call_to(node.args[1], "neighbors"):
offset_tag = node.args[1].args[0]
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/inline_fundefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ def prune_unreferenced_fundefs(program: itir.Program) -> itir.Program:
>>> fun1 = itir.FunctionDefinition(
... id="fun1",
... params=[im.sym("a")],
... expr=im.call("deref")("a"),
... expr=im.deref("a"),
... )
>>> fun2 = itir.FunctionDefinition(
... id="fun2",
... params=[im.sym("a")],
... expr=im.call("deref")("a"),
... expr=im.deref("a"),
... )
>>> program = itir.Program(
... id="testee",
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/inline_lifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,11 @@ def visit_FunCall(
if len(args) == 0:
return im.literal_from_value(True)

res = ir.FunCall(fun=ir.SymRef(id="can_deref"), args=[args[0]])
res = im.can_deref(args[0])
for arg in args[1:]:
res = ir.FunCall(
fun=ir.SymRef(id="and_"),
args=[res, ir.FunCall(fun=ir.SymRef(id="can_deref"), args=[arg])],
args=[res, im.can_deref(arg)],
)
return res
elif (
Expand Down
52 changes: 11 additions & 41 deletions src/gt4py/next/iterator/transforms/unroll_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from gt4py.eve.utils import UIDGenerator
from gt4py.next import common
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift


Expand Down Expand Up @@ -85,34 +85,6 @@ def _get_connectivity(
return connectivities[0]


def _make_shift(offsets: list[itir.Expr], iterator: itir.Expr) -> itir.FunCall:
return itir.FunCall(
fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=offsets),
args=[iterator],
location=iterator.location,
)


def _make_deref(iterator: itir.Expr) -> itir.FunCall:
return itir.FunCall(fun=itir.SymRef(id="deref"), args=[iterator], location=iterator.location)


def _make_can_deref(iterator: itir.Expr) -> itir.FunCall:
return itir.FunCall(
fun=itir.SymRef(id="can_deref"), args=[iterator], location=iterator.location
)


def _make_if(cond: itir.Expr, true_expr: itir.Expr, false_expr: itir.Expr) -> itir.FunCall:
return itir.FunCall(
fun=itir.SymRef(id="if_"), args=[cond, true_expr, false_expr], location=cond.location
)


def _make_list_get(offset: itir.Expr, expr: itir.Expr) -> itir.FunCall:
return itir.FunCall(fun=itir.SymRef(id="list_get"), args=[offset, expr], location=expr.location)


@dataclasses.dataclass(frozen=True)
class UnrollReduce(PreserveLocationVisitor, NodeTranslator):
# we use one UID generator per instance such that the generated ids are
Expand All @@ -130,27 +102,25 @@ def _visit_reduce(
max_neighbors = connectivity_type.max_neighbors
has_skip_values = connectivity_type.has_skip_values

acc = itir.SymRef(id=self.uids.sequential_id(prefix="_acc"))
offset = itir.SymRef(id=self.uids.sequential_id(prefix="_i"))
step = itir.SymRef(id=self.uids.sequential_id(prefix="_step"))
acc: str = self.uids.sequential_id(prefix="_acc")
offset: str = self.uids.sequential_id(prefix="_i")
step: str = self.uids.sequential_id(prefix="_step")

assert isinstance(node.fun, itir.FunCall)
fun, init = node.fun.args

elems = [_make_list_get(offset, arg) for arg in node.args]
step_fun: itir.Expr = itir.FunCall(fun=fun, args=[acc, *elems])
elems = [im.list_get(offset, arg) for arg in node.args]
step_fun: itir.Expr = im.call(fun)(acc, *elems)
if has_skip_values:
check_arg = next(_get_neighbors_args(node.args))
offset_tag, it = check_arg.args
can_deref = _make_can_deref(_make_shift([offset_tag, offset], it))
step_fun = _make_if(can_deref, step_fun, acc)
step_fun = itir.Lambda(params=[itir.Sym(id=acc.id), itir.Sym(id=offset.id)], expr=step_fun)
can_deref = im.can_deref(im.shift(offset_tag, offset)(it))
step_fun = im.if_(can_deref, step_fun, acc)
step_fun = im.lambda_(acc, offset)(step_fun)
expr = init
for i in range(max_neighbors):
expr = itir.FunCall(fun=step, args=[expr, itir.OffsetLiteral(value=i)])
expr = itir.FunCall(
fun=itir.Lambda(params=[itir.Sym(id=step.id)], expr=expr), args=[step_fun]
)
expr = im.call(step)(expr, itir.OffsetLiteral(value=i))
expr = im.let(step, step_fun)(expr)

return expr

Expand Down
44 changes: 18 additions & 26 deletions tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def foo(bar: int64, alpha: int64) -> int64:
parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)

reference = im.call("multiplies")("alpha", "bar")
reference = im.multiplies_("alpha", "bar")

assert lowered.expr == reference

Expand Down Expand Up @@ -297,7 +297,7 @@ def foo(a: gtx.Field[gtx.Dims[Vertex, V2EDim], float64]):
parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)

reference = im.op_as_fieldop(im.map_(im.lambda_("val")(im.call("cast_")("val", "int32"))))("a")
reference = im.op_as_fieldop(im.map_(im.lambda_("val")(im.cast_("val", "int32"))))("a")

assert lowered.expr == reference

Expand All @@ -310,7 +310,7 @@ def foo(a: float64):
lowered = FieldOperatorLowering.apply(parsed)
lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered)

reference = im.call("cast_")("a", "int32")
reference = im.cast_("a", "int32")

assert lowered_inlined.expr == reference

Expand Down Expand Up @@ -341,7 +341,7 @@ def foo(a: tuple[gtx.Field[[TDim], float64], float64]):

reference = im.make_tuple(
im.cast_as_fieldop("int32")(im.tuple_get(0, "a")),
im.call("cast_")(im.tuple_get(1, "a"), "int32"),
im.cast_(im.tuple_get(1, "a"), "int32"),
)

assert lowered_inlined.expr == reference
Expand Down Expand Up @@ -551,7 +551,7 @@ def foo(a: gtx.Field[[TDim], "int32"]) -> gtx.Field[[TDim], "int32"]:

reference = im.let(
ssa.unique_name("tmp", 0),
im.call("plus")(
im.plus(
im.literal("1", "int32"),
im.literal("1", "int32"),
),
Expand Down Expand Up @@ -656,7 +656,7 @@ def foo() -> bool:
parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)

reference = im.call("greater")(
reference = im.greater(
im.literal("3", "int32"),
im.literal("4", "int32"),
)
Expand Down Expand Up @@ -761,11 +761,9 @@ def foo(edge_f: gtx.Field[[Edge], float64]):
lowered = FieldOperatorLowering.apply(parsed)

reference = im.op_as_fieldop(
im.call(
im.call("reduce")(
"plus",
im.literal(value="0", typename="float64"),
)
im.reduce(
"plus",
im.literal(value="0", typename="float64"),
)
)(im.as_fieldop_neighbors("V2E", "edge_f"))

Expand All @@ -780,11 +778,9 @@ def foo(edge_f: gtx.Field[[Edge], float64]):
lowered = FieldOperatorLowering.apply(parsed)

reference = im.op_as_fieldop(
im.call(
im.call("reduce")(
"maximum",
im.literal(value=str(np.finfo(np.float64).min), typename="float64"),
)
im.reduce(
"maximum",
im.literal(value=str(np.finfo(np.float64).min), typename="float64"),
)
)(im.as_fieldop_neighbors("V2E", "edge_f"))

Expand All @@ -799,11 +795,9 @@ def foo(edge_f: gtx.Field[[Edge], float64]):
lowered = FieldOperatorLowering.apply(parsed)

reference = im.op_as_fieldop(
im.call(
im.call("reduce")(
"minimum",
im.literal(value=str(np.finfo(np.float64).max), typename="float64"),
)
im.reduce(
"minimum",
im.literal(value=str(np.finfo(np.float64).max), typename="float64"),
)
)(im.as_fieldop_neighbors("V2E", "edge_f"))

Expand All @@ -828,11 +822,9 @@ def foo(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], float64]
im.as_fieldop_neighbors("V2E", "e1"),
)(
im.op_as_fieldop(
im.call(
im.call("reduce")(
"plus",
im.literal(value="0", typename="float64"),
)
im.reduce(
"plus",
im.literal(value="0", typename="float64"),
)
)(mapped)
)
Expand Down
Loading

0 comments on commit 1543688

Please sign in to comment.