From 9be2d2d22df9cb45c8cf1b75f78d72dbeda87cf5 Mon Sep 17 00:00:00 2001 From: SF-N Date: Tue, 28 Jan 2025 12:50:33 +0100 Subject: [PATCH] refactor[next]: new ir.makers for common builtins (#1827) Using `im.call("...")` inside the transformations is cumbersome. This PR adds new helpers to the `ir_makers` for all commonly used builtins used inside of the transformations, namely: `reduce`, `scan`, `list_get`, `maximum`, `minimum`, `cast_`, and `can_deref`. The helpers for `floordiv_` and `mod` were removed since they weren't used and are rather uncommon anyway. --------- Co-authored-by: Till Ehrengruber --- src/gt4py/next/ffront/foast_to_gtir.py | 14 +- .../ir_utils/common_pattern_matcher.py | 2 +- .../next/iterator/ir_utils/domain_utils.py | 4 +- src/gt4py/next/iterator/ir_utils/ir_makers.py | 56 ++++++-- .../iterator/transforms/collapse_list_get.py | 4 +- .../iterator/transforms/inline_fundefs.py | 4 +- .../next/iterator/transforms/inline_lifts.py | 4 +- .../next/iterator/transforms/unroll_reduce.py | 52 ++------ .../ffront_tests/test_foast_to_gtir.py | 44 +++--- .../iterator_tests/test_type_inference.py | 126 ++++++++---------- .../transforms_tests/test_constant_folding.py | 6 +- .../transforms_tests/test_cse.py | 8 +- .../transforms_tests/test_domain_inference.py | 4 +- .../transforms_tests/test_inline_lifts.py | 10 +- .../transforms_tests/test_prune_casts.py | 4 +- .../transforms_tests/test_trace_shifts.py | 2 +- .../transforms_tests/test_unroll_reduce.py | 39 ++---- .../gtfn_tests/test_gtfn_module.py | 2 +- .../gtfn_tests/test_itir_to_gtfn_ir.py | 2 +- .../dace_tests/test_gtir_to_sdfg.py | 78 ++++------- 20 files changed, 197 insertions(+), 268 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 007e195f3e..f884ec555d 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -138,7 +138,7 @@ def visit_ScanOperator( definition = itir.Lambda(params=func_definition.params, expr=new_body) - body = im.as_fieldop(im.call("scan")(definition, forward, init))(*stencil_args) + body = im.as_fieldop(im.scan(definition, forward, init))(*stencil_args) return itir.FunctionDefinition(id=node.id, params=definition.params[1:], expr=body) @@ -360,7 +360,7 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: obj, new_type = self.visit(node.args[0], **kwargs), node.args[1].id def create_cast(expr: itir.Expr, t: tuple[ts.TypeSpec]) -> itir.FunCall: - return _map(im.lambda_("val")(im.call("cast_")("val", str(new_type))), (expr,), t) + return _map(im.lambda_("val")(im.cast_("val", str(new_type))), (expr,), t) if not isinstance(node.type, ts.TupleType): # to keep the IR simpler return create_cast(obj, (node.args[0].type,)) @@ -409,7 +409,7 @@ def _make_reduction_expr( # TODO(havogt): deal with nested reductions of the form neighbor_sum(neighbor_sum(field(off1)(off2))) it = self.visit(node.args[0], **kwargs) assert isinstance(node.kwargs["axis"].type, ts.DimensionType) - val = im.call(im.call("reduce")(op, init_expr)) + val = im.reduce(op, init_expr) return im.op_as_fieldop(val)(it) def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr: @@ -462,14 +462,14 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr: return self._make_literal(node.value, node.type) - def _lower_and_map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: + def _lower_and_map(self, op: itir.Lambda | str, *args: Any, **kwargs: Any) -> itir.FunCall: return _map( op, tuple(self.visit(arg, **kwargs) for arg in args), tuple(arg.type for arg in args) ) def _map( - op: itir.Expr | str, + op: itir.Lambda | str, lowered_args: tuple, original_arg_types: tuple[ts.TypeSpec, ...], ) -> itir.FunCall: @@ -487,9 +487,9 @@ def _map( promote_to_list(arg_type)(larg) for arg_type, larg in zip(original_arg_types, lowered_args) ) - op = im.call("map_")(op) + op = im.map_(op) - return im.op_as_fieldop(im.call(op))(*lowered_args) + return im.op_as_fieldop(op)(*lowered_args) class FieldOperatorLoweringError(Exception): ... diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 19d0802f4b..c16b9f2b48 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -72,7 +72,7 @@ def is_call_to(node: Any, fun: str | Iterable[str]) -> TypeGuard[itir.FunCall]: attribute which can be anything. >>> from gt4py.next.iterator.ir_utils import ir_makers as im - >>> node = im.call("plus")(1, 2) + >>> node = im.plus(1, 2) >>> is_call_to(node, "plus") True >>> is_call_to(node, "minus") diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index c84e2c0228..27900b6db6 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -162,11 +162,11 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: assert all(domain.ranges.keys() == domains[0].ranges.keys() for domain in domains) for dim in domains[0].ranges.keys(): start = functools.reduce( - lambda current_expr, el_expr: im.call("minimum")(current_expr, el_expr), + lambda current_expr, el_expr: im.minimum(current_expr, el_expr), [domain.ranges[dim].start for domain in domains], ) stop = functools.reduce( - lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), + lambda current_expr, el_expr: im.maximum(current_expr, el_expr), [domain.ranges[dim].stop for domain in domains], ) # constant fold expression to keep the tree small diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index c5cf2efa5a..24842ad3be 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -169,18 +169,6 @@ def divides_(left, right): return call("divides")(left, right) -def floordiv_(left, right): - """Create a floor division FunCall, shorthand for ``call("floordiv")(left, right)``.""" - # TODO(tehrengruber): Use int(floor(left/right)) as soon as we support integer casting - # and remove the `floordiv` builtin again. - return call("floordiv")(left, right) - - -def mod(left, right): - """Create a modulo FunCall, shorthand for ``call("mod")(left, right)``.""" - return call("mod")(left, right) - - def and_(left, right): """Create an and_ FunCall, shorthand for ``call("and_")(left, right)``.""" return call("and_")(left, right) @@ -302,7 +290,10 @@ def shift(offset, value=None): offset = ensure_offset(offset) args = [offset] if value is not None: - value = ensure_offset(value) + if isinstance(value, int): + value = ensure_offset(value) + elif isinstance(value, str): + value = ref(value) args.append(value) return call(call("shift")(*args)) @@ -469,7 +460,7 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> cal def op_as_fieldop( - op: str | itir.SymRef | Callable, domain: Optional[itir.FunCall] = None + op: str | itir.SymRef | itir.Lambda | Callable, domain: Optional[itir.FunCall] = None ) -> Callable[..., itir.FunCall]: """ Promotes a function `op` to a field_operator. @@ -536,3 +527,40 @@ def index(dim: common.Dimension) -> itir.FunCall: def map_(op): """Create a `map_` call.""" return call(call("map_")(op)) + + +def reduce(op, expr): + """Create a `reduce` call.""" + return call(call("reduce")(op, expr)) + + +def scan(expr, forward, init): + """Create a `scan` call.""" + return call("scan")(expr, forward, init) + + +def list_get(list_idx, list_): + """Create a `list_get` call.""" + return call("list_get")(list_idx, list_) + + +def maximum(expr1, expr2): + """Create a `maximum` call.""" + return call("maximum")(expr1, expr2) + + +def minimum(expr1, expr2): + """Create a `minimum` call.""" + return call("minimum")(expr1, expr2) + + +def cast_(expr, dtype: ts.ScalarType | str): + """Create a `cast_` call.""" + if isinstance(dtype, ts.ScalarType): + dtype = dtype.kind.name.lower() + return call("cast_")(expr, dtype) + + +def can_deref(expr): + """Create a `can_deref` call.""" + return call("can_deref")(expr) diff --git a/src/gt4py/next/iterator/transforms/collapse_list_get.py b/src/gt4py/next/iterator/transforms/collapse_list_get.py index 4a354879ca..b0a0c1e1dc 100644 --- a/src/gt4py/next/iterator/transforms/collapse_list_get.py +++ b/src/gt4py/next/iterator/transforms/collapse_list_get.py @@ -27,8 +27,8 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node: cond, true_val, false_val = node.args[1].args return im.if_( cond, - self.visit(im.call("list_get")(list_idx, true_val)), - self.visit(im.call("list_get")(list_idx, false_val)), + self.visit(im.list_get(list_idx, true_val)), + self.visit(im.list_get(list_idx, false_val)), ) if cpm.is_call_to(node.args[1], "neighbors"): offset_tag = node.args[1].args[0] diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index e4cae978da..03b20d14fe 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -36,12 +36,12 @@ def prune_unreferenced_fundefs(program: itir.Program) -> itir.Program: >>> fun1 = itir.FunctionDefinition( ... id="fun1", ... params=[im.sym("a")], - ... expr=im.call("deref")("a"), + ... expr=im.deref("a"), ... ) >>> fun2 = itir.FunctionDefinition( ... id="fun2", ... params=[im.sym("a")], - ... expr=im.call("deref")("a"), + ... expr=im.deref("a"), ... ) >>> program = itir.Program( ... id="testee", diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index f27dbbb74c..7724aa86f6 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -197,11 +197,11 @@ def visit_FunCall( if len(args) == 0: return im.literal_from_value(True) - res = ir.FunCall(fun=ir.SymRef(id="can_deref"), args=[args[0]]) + res = im.can_deref(args[0]) for arg in args[1:]: res = ir.FunCall( fun=ir.SymRef(id="and_"), - args=[res, ir.FunCall(fun=ir.SymRef(id="can_deref"), args=[arg])], + args=[res, im.can_deref(arg)], ) return res elif ( diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 042a86cd8e..6e993a2ed7 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -14,7 +14,7 @@ from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift @@ -85,34 +85,6 @@ def _get_connectivity( return connectivities[0] -def _make_shift(offsets: list[itir.Expr], iterator: itir.Expr) -> itir.FunCall: - return itir.FunCall( - fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=offsets), - args=[iterator], - location=iterator.location, - ) - - -def _make_deref(iterator: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="deref"), args=[iterator], location=iterator.location) - - -def _make_can_deref(iterator: itir.Expr) -> itir.FunCall: - return itir.FunCall( - fun=itir.SymRef(id="can_deref"), args=[iterator], location=iterator.location - ) - - -def _make_if(cond: itir.Expr, true_expr: itir.Expr, false_expr: itir.Expr) -> itir.FunCall: - return itir.FunCall( - fun=itir.SymRef(id="if_"), args=[cond, true_expr, false_expr], location=cond.location - ) - - -def _make_list_get(offset: itir.Expr, expr: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="list_get"), args=[offset, expr], location=expr.location) - - @dataclasses.dataclass(frozen=True) class UnrollReduce(PreserveLocationVisitor, NodeTranslator): # we use one UID generator per instance such that the generated ids are @@ -130,27 +102,25 @@ def _visit_reduce( max_neighbors = connectivity_type.max_neighbors has_skip_values = connectivity_type.has_skip_values - acc = itir.SymRef(id=self.uids.sequential_id(prefix="_acc")) - offset = itir.SymRef(id=self.uids.sequential_id(prefix="_i")) - step = itir.SymRef(id=self.uids.sequential_id(prefix="_step")) + acc: str = self.uids.sequential_id(prefix="_acc") + offset: str = self.uids.sequential_id(prefix="_i") + step: str = self.uids.sequential_id(prefix="_step") assert isinstance(node.fun, itir.FunCall) fun, init = node.fun.args - elems = [_make_list_get(offset, arg) for arg in node.args] - step_fun: itir.Expr = itir.FunCall(fun=fun, args=[acc, *elems]) + elems = [im.list_get(offset, arg) for arg in node.args] + step_fun: itir.Expr = im.call(fun)(acc, *elems) if has_skip_values: check_arg = next(_get_neighbors_args(node.args)) offset_tag, it = check_arg.args - can_deref = _make_can_deref(_make_shift([offset_tag, offset], it)) - step_fun = _make_if(can_deref, step_fun, acc) - step_fun = itir.Lambda(params=[itir.Sym(id=acc.id), itir.Sym(id=offset.id)], expr=step_fun) + can_deref = im.can_deref(im.shift(offset_tag, offset)(it)) + step_fun = im.if_(can_deref, step_fun, acc) + step_fun = im.lambda_(acc, offset)(step_fun) expr = init for i in range(max_neighbors): - expr = itir.FunCall(fun=step, args=[expr, itir.OffsetLiteral(value=i)]) - expr = itir.FunCall( - fun=itir.Lambda(params=[itir.Sym(id=step.id)], expr=expr), args=[step_fun] - ) + expr = im.call(step)(expr, itir.OffsetLiteral(value=i)) + expr = im.let(step, step_fun)(expr) return expr diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index d2d5404cb5..c0d762efc8 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -91,7 +91,7 @@ def foo(bar: int64, alpha: int64) -> int64: parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.call("multiplies")("alpha", "bar") + reference = im.multiplies_("alpha", "bar") assert lowered.expr == reference @@ -297,7 +297,7 @@ def foo(a: gtx.Field[gtx.Dims[Vertex, V2EDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.op_as_fieldop(im.map_(im.lambda_("val")(im.call("cast_")("val", "int32"))))("a") + reference = im.op_as_fieldop(im.map_(im.lambda_("val")(im.cast_("val", "int32"))))("a") assert lowered.expr == reference @@ -310,7 +310,7 @@ def foo(a: float64): lowered = FieldOperatorLowering.apply(parsed) lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) - reference = im.call("cast_")("a", "int32") + reference = im.cast_("a", "int32") assert lowered_inlined.expr == reference @@ -341,7 +341,7 @@ def foo(a: tuple[gtx.Field[[TDim], float64], float64]): reference = im.make_tuple( im.cast_as_fieldop("int32")(im.tuple_get(0, "a")), - im.call("cast_")(im.tuple_get(1, "a"), "int32"), + im.cast_(im.tuple_get(1, "a"), "int32"), ) assert lowered_inlined.expr == reference @@ -551,7 +551,7 @@ def foo(a: gtx.Field[[TDim], "int32"]) -> gtx.Field[[TDim], "int32"]: reference = im.let( ssa.unique_name("tmp", 0), - im.call("plus")( + im.plus( im.literal("1", "int32"), im.literal("1", "int32"), ), @@ -656,7 +656,7 @@ def foo() -> bool: parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.call("greater")( + reference = im.greater( im.literal("3", "int32"), im.literal("4", "int32"), ) @@ -761,11 +761,9 @@ def foo(edge_f: gtx.Field[[Edge], float64]): lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop( - im.call( - im.call("reduce")( - "plus", - im.literal(value="0", typename="float64"), - ) + im.reduce( + "plus", + im.literal(value="0", typename="float64"), ) )(im.as_fieldop_neighbors("V2E", "edge_f")) @@ -780,11 +778,9 @@ def foo(edge_f: gtx.Field[[Edge], float64]): lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop( - im.call( - im.call("reduce")( - "maximum", - im.literal(value=str(np.finfo(np.float64).min), typename="float64"), - ) + im.reduce( + "maximum", + im.literal(value=str(np.finfo(np.float64).min), typename="float64"), ) )(im.as_fieldop_neighbors("V2E", "edge_f")) @@ -799,11 +795,9 @@ def foo(edge_f: gtx.Field[[Edge], float64]): lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop( - im.call( - im.call("reduce")( - "minimum", - im.literal(value=str(np.finfo(np.float64).max), typename="float64"), - ) + im.reduce( + "minimum", + im.literal(value=str(np.finfo(np.float64).max), typename="float64"), ) )(im.as_fieldop_neighbors("V2E", "edge_f")) @@ -828,11 +822,9 @@ def foo(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], float64] im.as_fieldop_neighbors("V2E", "e1"), )( im.op_as_fieldop( - im.call( - im.call("reduce")( - "plus", - im.literal(value="0", typename="float64"), - ) + im.reduce( + "plus", + im.literal(value="0", typename="float64"), ) )(mapped) ) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index d4d7c60d69..a39fe3c6d8 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -77,10 +77,10 @@ def expression_test_cases(): (im.plus(1, 2), int_type), (im.eq(1, 2), bool_type), (im.deref(im.ref("it", it_on_e_of_e_type)), it_on_e_of_e_type.element_type), - (im.call("can_deref")(im.ref("it", it_on_e_of_e_type)), bool_type), + (im.can_deref(im.ref("it", it_on_e_of_e_type)), bool_type), (im.if_(True, 1, 2), int_type), (im.call("make_const_list")(True), ts.ListType(element_type=bool_type)), - (im.call("list_get")(0, im.ref("l", ts.ListType(element_type=bool_type))), bool_type), + (im.list_get(0, im.ref("l", ts.ListType(element_type=bool_type))), bool_type), ( im.call("named_range")( itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 @@ -119,7 +119,7 @@ def expression_test_cases(): ts.ListType(element_type=it_on_e_of_e_type.element_type), ), # cast - (im.call("cast_")(1, "int32"), int_type), + (im.cast_(1, int_type), int_type), # TODO: lift # TODO: scan # map @@ -128,18 +128,16 @@ def expression_test_cases(): int_list_type, ), # reduce - (im.call(im.call("reduce")("plus", 0))(im.ref("l", int_list_type)), int_type), + (im.reduce("plus", 0)(im.ref("l", int_list_type)), int_type), ( - im.call( - im.call("reduce")( - im.lambda_("acc", "a", "b")( - im.make_tuple( - im.plus(im.tuple_get(0, "acc"), "a"), - im.plus(im.tuple_get(1, "acc"), "b"), - ) - ), - im.make_tuple(0, 0.0), - ) + im.reduce( + im.lambda_("acc", "a", "b")( + im.make_tuple( + im.plus(im.tuple_get(0, "acc"), "a"), + im.plus(im.tuple_get(1, "acc"), "b"), + ) + ), + im.make_tuple(0, 0.0), )(im.ref("la", int_list_type), im.ref("lb", float64_list_type)), ts.TupleType(types=[int_type, float64_type]), ), @@ -148,42 +146,36 @@ def expression_test_cases(): (im.shift("Ioff", 1)(im.ref("it", it_ijk_type)), it_ijk_type), # as_fieldop ( - im.call( - im.call("as_fieldop")( - "deref", - im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) - ), - ) + im.as_fieldop( + "deref", + im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ), )(im.ref("inp", float_i_field)), float_i_field, ), ( - im.call( - im.call("as_fieldop")( - im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), - im.call("unstructured_domain")( - im.call("named_range")( - itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), - 0, - 1, - ), - im.call("named_range")( - itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 - ), + im.as_fieldop( + im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), + im.call("unstructured_domain")( + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), + 0, + 1, ), - ) + im.call("named_range")( + itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 + ), + ), )(im.ref("inp", float_edge_k_field)), float_vertex_k_field, ), ( - im.call( - im.call("as_fieldop")( - im.lambda_("a", "b")(im.make_tuple(im.deref("a"), im.deref("b"))), - im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) - ), - ) + im.as_fieldop( + im.lambda_("a", "b")(im.make_tuple(im.deref("a"), im.deref("b"))), + im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ), )(im.ref("inp1", float_i_field), im.ref("inp2", float_i_field)), ts.TupleType(types=[float_i_field, float_i_field]), ), @@ -197,21 +189,17 @@ def expression_test_cases(): ( im.if_( False, - im.call( - im.call("as_fieldop")( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) - ), - ) + im.as_fieldop( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ), )(im.ref("inp", float_i_field), 1.0), - im.call( - im.call("as_fieldop")( - "deref", - im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) - ), - ) + im.as_fieldop( + "deref", + im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ), )(im.ref("inp", float_i_field)), ), float_i_field, @@ -276,9 +264,7 @@ def test_cast_first_arg_inference(): # since cast_ is a grammar builtin whose return type is given by its second argument it is # easy to forget inferring the types of the first argument and its children. Simply check # if the first argument has a type inferred correctly here. - testee = im.call("cast_")( - im.plus(im.literal_from_value(1), im.literal_from_value(2)), "float64" - ) + testee = im.cast_(im.plus(im.literal_from_value(1), im.literal_from_value(2)), "float64") result = itir_type_inference.infer( testee, offset_provider_type={}, allow_undeclared_symbols=True ) @@ -299,9 +285,7 @@ def test_cartesian_fencil_definition(): declarations=[], body=[ itir.SetAt( - expr=im.call(im.call("as_fieldop")(im.ref("deref"), cartesian_domain))( - im.ref("inp") - ), + expr=im.as_fieldop(im.ref("deref"), cartesian_domain)(im.ref("inp")), domain=cartesian_domain, target=im.ref("out"), ), @@ -336,10 +320,8 @@ def test_unstructured_fencil_definition(): declarations=[], body=[ itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), unstructured_domain - ) + expr=im.as_fieldop( + im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), unstructured_domain )(im.ref("inp")), domain=unstructured_domain, target=im.ref("out"), @@ -375,7 +357,7 @@ def test_function_definition(): body=[ itir.SetAt( domain=cartesian_domain, - expr=im.call(im.call("as_fieldop")(im.ref("bar"), cartesian_domain))(im.ref("inp")), + expr=im.as_fieldop(im.ref("bar"), cartesian_domain)(im.ref("inp")), target=im.ref("out"), ), ], @@ -408,11 +390,9 @@ def test_fencil_with_nb_field_input(): body=[ itir.SetAt( domain=unstructured_domain, - expr=im.call( - im.call("as_fieldop")( - im.lambda_("it")(im.call(im.call("reduce")("plus", 0.0))(im.deref("it"))), - unstructured_domain, - ) + expr=im.as_fieldop( + im.lambda_("it")(im.reduce("plus", 0.0)(im.deref("it"))), + unstructured_domain, )(im.ref("inp")), target=im.ref("out"), ), @@ -438,9 +418,7 @@ def test_program_tuple_setat_short_target(): declarations=[], body=[ itir.SetAt( - expr=im.call( - im.call("as_fieldop")(im.lambda_()(im.make_tuple(1.0, 2.0)), cartesian_domain) - )(), + expr=im.as_fieldop(im.lambda_()(im.make_tuple(1.0, 2.0)), cartesian_domain)(), domain=cartesian_domain, target=im.make_tuple("out"), ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 0bf8dcb65d..cf325c2daa 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -31,7 +31,7 @@ def test_constant_folding_math_op(): def test_constant_folding_if(): - expected = im.call("plus")("a", 2) + expected = im.plus("a", 2) testee = im.if_( im.literal_from_value(True), im.plus(im.ref("a"), im.literal_from_value(2)), @@ -42,7 +42,7 @@ def test_constant_folding_if(): def test_constant_folding_minimum(): - testee = im.call("minimum")("a", "a") + testee = im.minimum("a", "a") expected = im.ref("a") actual = ConstantFolding.apply(testee) assert actual == expected @@ -56,7 +56,7 @@ def test_constant_folding_literal(): def test_constant_folding_literal_maximum(): - testee = im.call("maximum")(im.literal_from_value(1), im.literal_from_value(2)) + testee = im.maximum(im.literal_from_value(1), im.literal_from_value(2)) expected = im.literal_from_value(2) actual = ConstantFolding.apply(testee) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index 14860d9bdd..3909c6f26a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -135,7 +135,7 @@ def test_if_can_deref_no_extraction(offset_provider_type): # if can_deref(⟪Iₒ, 1ₒ⟫(it)) then ·⟪Iₒ, 1ₒ⟫(it) + ·⟪Iₒ, 1ₒ⟫(it) else 1 testee = im.if_( - im.call("can_deref")(im.shift("I", 1)("it")), + im.can_deref(im.shift("I", 1)("it")), im.plus(im.deref(im.shift("I", 1)("it")), im.deref(im.shift("I", 1)("it"))), # use something more involved where a subexpression can still be eliminated im.literal("1", "int32"), @@ -143,7 +143,7 @@ def test_if_can_deref_no_extraction(offset_provider_type): # (λ(_cs_1) → if can_deref(_cs_1) then (λ(_cs_2) → _cs_2 + _cs_2)(·_cs_1) else 1)(⟪Iₒ, 1ₒ⟫(it)) expected = im.let("_cs_1", im.shift("I", 1)("it"))( im.if_( - im.call("can_deref")("_cs_1"), + im.can_deref("_cs_1"), im.let("_cs_2", im.deref("_cs_1"))(im.plus("_cs_2", "_cs_2")), im.literal("1", "int32"), ) @@ -159,14 +159,14 @@ def test_if_can_deref_eligible_extraction(offset_provider_type): # if can_deref(⟪Iₒ, 1ₒ⟫(it)) then ·⟪Iₒ, 1ₒ⟫(it) else ·⟪Iₒ, 1ₒ⟫(it) + ·⟪Iₒ, 1ₒ⟫(it) testee = im.if_( - im.call("can_deref")(im.shift("I", 1)("it")), + im.can_deref(im.shift("I", 1)("it")), im.deref(im.shift("I", 1)("it")), im.plus(im.deref(im.shift("I", 1)("it")), im.deref(im.shift("I", 1)("it"))), ) # (λ(_cs_3) → (λ(_cs_1) → if can_deref(_cs_3) then _cs_1 else _cs_1 + _cs_1)(·_cs_3))(⟪Iₒ, 1ₒ⟫(it)) expected = im.let("_cs_3", im.shift("I", 1)("it"))( im.let("_cs_1", im.deref("_cs_3"))( - im.if_(im.call("can_deref")("_cs_3"), "_cs_1", im.plus("_cs_1", "_cs_1")) + im.if_(im.can_deref("_cs_3"), "_cs_1", im.plus("_cs_1", "_cs_1")) ) ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 779ab738cb..4a2a441510 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -1028,10 +1028,10 @@ def test_arithmetic_builtin(offset_provider): def test_scan(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) testee = im.as_fieldop( - im.call("scan")(im.lambda_("init", "it")(im.deref(im.shift("Ioff", 1)("it"))), True, 0.0) + im.scan(im.lambda_("init", "it")(im.deref(im.shift("Ioff", 1)("it"))), True, 0.0) )("a") expected = im.as_fieldop( - im.call("scan")(im.lambda_("init", "it")(im.deref(im.shift("Ioff", 1)("it"))), True, 0.0), + im.scan(im.lambda_("init", "it")(im.deref(im.shift("Ioff", 1)("it"))), True, 0.0), domain, )("a") diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py index f81ca5a666..957e7ffe63 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py @@ -27,15 +27,15 @@ def inline_lift_test_data(): ), ( # can_deref(lift(f)(args...)) -> and(can_deref(arg[0]), and(can_deref(arg[1]), ...)) - im.call("can_deref")(im.lift("f")("arg1", "arg2")), - im.and_(im.call("can_deref")("arg1"), im.call("can_deref")("arg2")), + im.can_deref(im.lift("f")("arg1", "arg2")), + im.and_(im.can_deref("arg1"), im.can_deref("arg2")), ), ( # can_deref(shift(...)(lift(f)(args...)) -> and(can_deref(shift(...)(arg[0])), and(can_deref(shift(...)(arg[1])), ...)) - im.call("can_deref")(im.shift("I", 1)(im.lift("f")("arg1", "arg2"))), + im.can_deref(im.shift("I", 1)(im.lift("f")("arg1", "arg2"))), im.and_( - im.call("can_deref")(im.shift("I", 1)("arg1")), - im.call("can_deref")(im.shift("I", 1)("arg2")), + im.can_deref(im.shift("I", 1)("arg1")), + im.can_deref(im.shift("I", 1)("arg2")), ), ), ( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py index 77d3323fb4..b1a18ddab8 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py @@ -16,10 +16,10 @@ def test_prune_casts_simple(): x_ref = im.ref("x", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) y_ref = im.ref("y", ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) - testee = im.call("plus")(im.call("cast_")(x_ref, "float64"), im.call("cast_")(y_ref, "float64")) + testee = im.plus(im.cast_(x_ref, "float64"), im.cast_(y_ref, "float64")) testee = type_inference.infer(testee, offset_provider_type={}, allow_undeclared_symbols=True) - expected = im.call("plus")(im.call("cast_")(x_ref, "float64"), y_ref) + expected = im.plus(im.cast_(x_ref, "float64"), y_ref) actual = PruneCasts.apply(testee) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py index 1cf662e221..dd7a8f4d43 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py @@ -44,7 +44,7 @@ def test_neighbors(): def test_reduce(): # λ(inp) → reduce(plus, 0.)(·inp) - testee = im.lambda_("inp")(im.call(im.call("reduce")("plus", 0.0))(im.deref("inp"))) + testee = im.lambda_("inp")(im.reduce("plus", 0.0)(im.deref("inp"))) expected = [{()}] actual = TraceShifts.trace_stencil(testee) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py index 0760247996..2415a42267 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py @@ -35,27 +35,25 @@ def has_skip_values(request): @pytest.fixture def basic_reduction(): UIDs.reset_sequence() - return im.call(im.call("reduce")("foo", 0.0))(im.neighbors("Dim", "x")) + return im.reduce("foo", 0.0)(im.neighbors("Dim", "x")) @pytest.fixture def reduction_with_shift_on_second_arg(): UIDs.reset_sequence() - return im.call(im.call("reduce")("foo", 0.0))("x", im.neighbors("Dim", "y")) + return im.reduce("foo", 0.0)("x", im.neighbors("Dim", "y")) @pytest.fixture def reduction_with_incompatible_shifts(): UIDs.reset_sequence() - return im.call(im.call("reduce")("foo", 0.0))( - im.neighbors("Dim", "x"), im.neighbors("Dim2", "y") - ) + return im.reduce("foo", 0.0)(im.neighbors("Dim", "x"), im.neighbors("Dim2", "y")) @pytest.fixture def reduction_with_irrelevant_full_shift(): UIDs.reset_sequence() - return im.call(im.call("reduce")("foo", 0.0))( + return im.reduce("foo", 0.0)( im.neighbors("Dim", im.shift("IrrelevantDim", 0)("x")), im.neighbors("Dim", "y") ) @@ -63,7 +61,7 @@ def reduction_with_irrelevant_full_shift(): @pytest.fixture def reduction_if(): UIDs.reset_sequence() - return im.call(im.call("reduce")("foo", 0.0))(im.if_(True, im.neighbors("Dim", "x"), "y")) + return im.reduce("foo", 0.0)(im.if_(True, im.neighbors("Dim", "x"), "y")) @pytest.mark.parametrize( @@ -83,35 +81,26 @@ def test_get_partial_offsets(reduction, request): def _expected(red, dim, max_neighbors, has_skip_values, shifted_arg=0): - acc = ir.SymRef(id="_acc_1") - offset = ir.SymRef(id="_i_2") - step = ir.SymRef(id="_step_3") + acc, offset, step = "_acc_1", "_i_2", "_step_3" red_fun, red_init = red.fun.args - elements = [ir.FunCall(fun=ir.SymRef(id="list_get"), args=[offset, arg]) for arg in red.args] + elements = [im.list_get(offset, arg) for arg in red.args] - step_expr = ir.FunCall(fun=red_fun, args=[acc] + elements) + step_expr = im.call(red_fun)(acc, *elements) if has_skip_values: neighbors_offset = red.args[shifted_arg].args[0] neighbors_it = red.args[shifted_arg].args[1] - can_deref = ir.FunCall( - fun=ir.SymRef(id="can_deref"), - args=[ - ir.FunCall( - fun=ir.FunCall(fun=ir.SymRef(id="shift"), args=[neighbors_offset, offset]), - args=[neighbors_it], - ) - ], - ) - step_expr = ir.FunCall(fun=ir.SymRef(id="if_"), args=[can_deref, step_expr, acc]) - step_fun = ir.Lambda(params=[ir.Sym(id=acc.id), ir.Sym(id=offset.id)], expr=step_expr) + can_deref = im.can_deref(im.shift(neighbors_offset, offset)(neighbors_it)) + + step_expr = im.if_(can_deref, step_expr, acc) + step_fun = im.lambda_(acc, offset)(step_expr) step_app = red_init for i in range(max_neighbors): - step_app = ir.FunCall(fun=step, args=[step_app, ir.OffsetLiteral(value=i)]) + step_app = im.call(step)(step_app, ir.OffsetLiteral(value=i)) - return ir.FunCall(fun=ir.Lambda(params=[ir.Sym(id=step.id)], expr=step_app), args=[step_fun]) + return im.let(step, step_fun)(step_app) def test_basic(basic_reduction, has_skip_values): diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index 53e463c6c7..e7053d3317 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -60,7 +60,7 @@ def program_example(): declarations=[], body=[ itir.SetAt( - expr=im.call(im.call("as_fieldop")(itir.SymRef(id="stencil"), domain))( + expr=im.as_fieldop(itir.SymRef(id="stencil"), domain)( itir.SymRef(id="buf"), itir.SymRef(id="sc") ), domain=domain, diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py index 97591122e5..50e8fa43f0 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py @@ -47,7 +47,7 @@ def test_get_domains(): declarations=[], body=[ itir.SetAt( - expr=im.call(im.call("as_fieldop")("deref"))(), + expr=im.as_fieldop("deref")(), domain=domain, target=itir.SymRef(id="bar"), ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index faf611878d..bfde179e33 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1183,9 +1183,7 @@ def test_gtir_neighbors_as_input(): gtir.SetAt( expr=im.as_fieldop( im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) + im.reduce("plus", im.literal_from_value(init_value))(im.deref("it")) ), vertex_domain, )( @@ -1283,25 +1281,15 @@ def test_gtir_neighbors_as_output(): def test_gtir_reduce(): init_value = np.random.rand() vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) - stencil_inlined = im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.neighbors("V2E", "it") - ) - ), - vertex_domain, - ) + stencil_inlined = im.as_fieldop( + im.lambda_("it")( + im.reduce("plus", im.literal_from_value(init_value))(im.neighbors("V2E", "it")) + ), + vertex_domain, )("edges") - stencil_fieldview = im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) - ), - vertex_domain, - ) + stencil_fieldview = im.as_fieldop( + im.lambda_("it")(im.reduce("plus", im.literal_from_value(init_value))(im.deref("it"))), + vertex_domain, )(im.as_fieldop_neighbors("V2E", "edges", vertex_domain)) connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] @@ -1349,25 +1337,15 @@ def test_gtir_reduce(): def test_gtir_reduce_with_skip_values(): init_value = np.random.rand() vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) - stencil_inlined = im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.neighbors("V2E", "it") - ) - ), - vertex_domain, - ) + stencil_inlined = im.as_fieldop( + im.lambda_("it")( + im.reduce("plus", im.literal_from_value(init_value))(im.neighbors("V2E", "it")) + ), + vertex_domain, )("edges") - stencil_fieldview = im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) - ), - vertex_domain, - ) + stencil_fieldview = im.as_fieldop( + im.lambda_("it")(im.reduce("plus", im.literal_from_value(init_value))(im.deref("it"))), + vertex_domain, )(im.as_fieldop_neighbors("V2E", "edges", vertex_domain)) connectivity_V2E = SKIP_VALUE_MESH.offset_provider["V2E"] @@ -1450,15 +1428,11 @@ def test_gtir_reduce_dot_product(): declarations=[], body=[ gtir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) - ), - vertex_domain, - ) + expr=im.as_fieldop( + im.lambda_("it")( + im.reduce("plus", im.literal_from_value(init_value))(im.deref("it")) + ), + vertex_domain, )( im.op_as_fieldop(im.map_("plus"), vertex_domain)( im.op_as_fieldop(im.map_("multiplies"), vertex_domain)( @@ -1508,9 +1482,7 @@ def test_gtir_reduce_with_cond_neighbors(): gtir.SetAt( expr=im.as_fieldop( im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) + im.reduce("plus", im.literal_from_value(init_value))(im.deref("it")) ), vertex_domain, )( @@ -1958,8 +1930,8 @@ def test_gtir_if_scalars(): "f", im.if_( "pred", - im.call("cast_")("y_0", "float64"), - im.call("cast_")("y_1", "float64"), + im.cast_("y_0", "float64"), + im.cast_("y_1", "float64"), ), ) )