Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Gtir concat where #1713

Draft
wants to merge 75 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
75 commits
Select commit Hold shift + click to select a range
d6e8732
Add concat_where frontend and domain inference
SF-N Oct 25, 2024
69f6b11
Finish domain inference for (nested) concat_where and transform to as…
SF-N Oct 25, 2024
05e74c2
fix merge conflicts
havogt Jan 20, 2025
c3a18c4
Merge origin/main
tehrengruber Jan 29, 2025
ba8343b
Extend concat_where, now also working for nested concat_wheres and ex…
SF-N Jan 30, 2025
f90329e
Some fixes, tuples still not supported
SF-N Jan 31, 2025
401d9dd
Merge branch 'main' into GTIR_concat_where
SF-N Jan 31, 2025
b49a82d
Some updates for concat where, which were necessary when using it in …
SF-N Feb 5, 2025
2219314
Merge branch 'main' into GTIR_concat_where
SF-N Feb 5, 2025
9eb428a
Merge origin/main
tehrengruber Feb 14, 2025
d16bbd5
ITIR type inference: store param type in Lambda
tehrengruber Feb 15, 2025
aca4824
Merge branch 'main' into store_lambda_param_type
tehrengruber Feb 17, 2025
813f328
Flatten as_fieldop tuple arguments
tehrengruber Feb 18, 2025
3745461
Add support for scan and nested tuples
tehrengruber Feb 19, 2025
1f23e17
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 19, 2025
8bec9ab
Merge branch 'store_lambda_param_type' into GTIR_concat_where
tehrengruber Feb 19, 2025
06806fb
Preserve annex on new nodes
tehrengruber Feb 19, 2025
bab4fe1
Fix unnecessary import
tehrengruber Feb 19, 2025
6257a2b
Merge branch 'eve_annex_preserve_new_node' into GTIR_concat_where
tehrengruber Feb 19, 2025
14b4bf3
Cleanup
tehrengruber Feb 19, 2025
fc20d7c
Fix doctest
tehrengruber Feb 19, 2025
c5fba83
Fix failing tests
tehrengruber Feb 19, 2025
fa17228
Merge branch 'store_lambda_param_type' into collapse_tuple_as_fieldop…
tehrengruber Feb 19, 2025
04ae430
Fix tests
tehrengruber Feb 19, 2025
5136adc
Fix tests
tehrengruber Feb 19, 2025
5939618
Cleanup frontend type deduction
tehrengruber Feb 19, 2025
157b0e2
Cleanup frontend type deduction
tehrengruber Feb 19, 2025
435d057
Cleanup concat where:
tehrengruber Feb 20, 2025
5e5c66e
Merge branch 'eve_annex_preserve_new_node' into GTIR_concat_where
tehrengruber Feb 20, 2025
bd8dbaa
Fix iterator tests
tehrengruber Feb 20, 2025
2c14648
Fix infer domain ops
tehrengruber Feb 20, 2025
a7f3cac
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 20, 2025
1200803
Cleanup
tehrengruber Feb 20, 2025
cf0ffb2
Fix format
tehrengruber Feb 20, 2025
335e932
Fix broken scan (e.g. test_tuple_scalar_scan)
tehrengruber Feb 20, 2025
7518b9c
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 20, 2025
39652de
Merge remote-tracking branch 'origin/main' into GTIR_concat_where
tehrengruber Feb 20, 2025
ba03c7e
Merge remote-tracking branch 'origin/main' into collapse_tuple_as_fie…
tehrengruber Feb 20, 2025
71980af
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 20, 2025
c18b7ad
Fix failing tests
tehrengruber Feb 20, 2025
d399c65
Fix format
tehrengruber Feb 20, 2025
5ad7701
Fix failing tests
tehrengruber Feb 20, 2025
d3957bd
Fix format
tehrengruber Feb 20, 2025
e95fdf0
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 20, 2025
b52a07c
Cleanup
tehrengruber Feb 20, 2025
f8703b2
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 20, 2025
c5c3e5f
Fix pyproject.toml test marker
tehrengruber Feb 20, 2025
f59fabf
Remove unnecessary visits
tehrengruber Feb 20, 2025
c8e06bd
Cleanup trace shifts
tehrengruber Feb 20, 2025
f748da7
Fix type inference
tehrengruber Feb 20, 2025
45f8b09
Add concat_where transforms to field view transforms
tehrengruber Feb 20, 2025
b3647bf
Fix typo
tehrengruber Feb 20, 2025
6ea11e5
Add support for tuples
tehrengruber Feb 20, 2025
60d0d9a
Fixes
tehrengruber Feb 20, 2025
93a6d33
Merge remote-tracking branch 'origin/main' into GTIR_concat_where
tehrengruber Feb 21, 2025
132e576
Improve docs
tehrengruber Feb 21, 2025
e469075
Improve docs
tehrengruber Feb 21, 2025
24e2f57
Fix typo
tehrengruber Feb 21, 2025
d14fb21
Cleanup & improve test coverage
tehrengruber Feb 24, 2025
1e3ced5
Cleanup
tehrengruber Feb 24, 2025
595b675
Cleanup
tehrengruber Feb 24, 2025
59a1226
Improve type inference for concat_where tuple case
tehrengruber Feb 28, 2025
f832a19
Fix typo
tehrengruber Feb 28, 2025
75cc4f2
Fix bug in infer domain ops
tehrengruber Mar 2, 2025
6e85bd0
Address review comments
tehrengruber Mar 2, 2025
a8b9736
Merge remote-tracking branch 'origin_tehrengruber/store_lambda_param_…
tehrengruber Mar 2, 2025
9978a43
Address review comments
tehrengruber Mar 2, 2025
232d4b8
Address review comments
tehrengruber Mar 2, 2025
57abfaf
Merge remote-tracking branch 'origin/main' into store_lambda_param_type
tehrengruber Mar 2, 2025
f488b1a
Merge remote-tracking branch 'origin/main' into GTIR_concat_where
tehrengruber Mar 3, 2025
55dc611
Merge branch 'store_lambda_param_type' into GTIR_concat_where
tehrengruber Mar 3, 2025
2674f11
Merge origin/main
tehrengruber Mar 3, 2025
d0f93be
Fix deferred type in concat_where
tehrengruber Mar 3, 2025
cf50a37
Fix tuple concat_where (not fully done yet)
tehrengruber Mar 3, 2025
5fc42ce
Fix tuple concat_where (not fully done yet)
tehrengruber Mar 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,9 @@ markers = [
'uses_unstructured_shift: tests that use a unstructured connectivity',
'uses_max_over: tests that use the max_over builtin',
'uses_mesh_with_skip_values: tests that use a mesh with skip values',
'checks_specific_error: tests that rely on the backend to produce a specific error message'
'checks_specific_error: tests that rely on the backend to produce a specific error message',
'uses_frontend_concat_where: tests that use the frontend concat_where builtin',
'uses_gtir_concat_where: tests that use the GTIR concat_where builtin'
]
norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*']
testpaths = 'tests'
Expand Down
31 changes: 24 additions & 7 deletions src/gt4py/eve/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,21 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
return None


def _preserve_annex(
node: concepts.Node, new_node: concepts.Node, preserved_annex_attrs: tuple[str, ...]
) -> None:
if preserved_annex_attrs and (old_annex := getattr(node, "__node_annex__", None)):
# access to `new_node.annex` implicitly creates the `__node_annex__` attribute in the property getter
new_annex_dict = new_node.annex.__dict__
for key in preserved_annex_attrs:
if (value := getattr(old_annex, key, NOTHING)) is not NOTHING:
# Note: The annex value of the new node might not be equal
# (in the sense that an equality comparison returns false),
# but in the context of the pass, they are equivalent.
# Therefore, we don't assert equality here.
new_annex_dict[key] = value


class NodeTranslator(NodeVisitor):
"""Special `NodeVisitor` to translate nodes and trees.

Expand Down Expand Up @@ -158,13 +173,7 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
if (new_child := self.visit(child, **kwargs)) is not NOTHING
}
)
if self.PRESERVED_ANNEX_ATTRS and (old_annex := getattr(node, "__node_annex__", None)):
# access to `new_node.annex` implicitly creates the `__node_annex__` attribute in the property getter
new_annex_dict = new_node.annex.__dict__
for key in self.PRESERVED_ANNEX_ATTRS:
if (value := getattr(old_annex, key, NOTHING)) is not NOTHING:
assert key not in new_annex_dict
new_annex_dict[key] = value
_preserve_annex(node, new_node, self.PRESERVED_ANNEX_ATTRS)

return new_node

Expand All @@ -189,3 +198,11 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
)

return copy.deepcopy(node, memo=memo)

def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
new_node = super().visit(node, **kwargs)

if isinstance(node, concepts.Node) and isinstance(new_node, concepts.Node):
_preserve_annex(node, new_node, self.PRESERVED_ANNEX_ATTRS)

return new_node
18 changes: 14 additions & 4 deletions src/gt4py/next/ffront/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,31 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Tuple
from typing import Tuple, TypeVar

from gt4py._core import definitions as core_defs
from gt4py.next import common
from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset, WhereBuiltinFunction
from gt4py.next.ffront.fbuiltins import (
BuiltInFunction,
FieldOffset,
FieldT,
WhereLikeBuiltinFunction,
)


@BuiltInFunction
def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivity:
raise NotImplementedError()


@WhereBuiltinFunction
_R = TypeVar("_R")
DomainT = TypeVar("DomainT", bound=common.Field)
ConcatWhereBuiltinFunction = WhereLikeBuiltinFunction[_R, DomainT, FieldT]


@ConcatWhereBuiltinFunction
def concat_where(
mask: common.Field,
mask: common.Domain,
true_field: common.Field | core_defs.ScalarT | Tuple,
false_field: common.Field | core_defs.ScalarT | Tuple,
/,
Expand Down
15 changes: 11 additions & 4 deletions src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp
return ts.OffsetType
elif t is core_defs.ScalarT:
return ts.ScalarType
elif t is common.Domain:
return ts.DomainType
elif t is type:
return (
ts.FunctionType
Expand Down Expand Up @@ -135,14 +137,15 @@ def __gt_type__(self) -> ts.FunctionType:
)


MaskT = TypeVar("MaskT", bound=common.Field)
MaskLikeT = TypeVar("MaskLikeT", bound=common.Field)
FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple])


class WhereBuiltinFunction(
BuiltInFunction[_R, [MaskT, FieldT, FieldT]], Generic[_R, MaskT, FieldT]
class WhereLikeBuiltinFunction(
BuiltInFunction[_R, [MaskLikeT, FieldT, FieldT]],
Generic[_R, MaskLikeT, FieldT],
):
def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R:
def __call__(self, mask: MaskLikeT, true_field: FieldT, false_field: FieldT) -> _R:
if isinstance(true_field, tuple) or isinstance(false_field, tuple):
if not (isinstance(true_field, tuple) and isinstance(false_field, tuple)):
raise ValueError(
Expand All @@ -157,6 +160,10 @@ def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R:
return super().__call__(mask, true_field, false_field)


MaskT = TypeVar("MaskT", bound=common.Field)
WhereBuiltinFunction = WhereLikeBuiltinFunction[_R, MaskT, FieldT]


@BuiltInFunction
def neighbor_sum(field: common.Field, /, axis: common.Dimension) -> common.Field:
raise NotImplementedError()
Expand Down
155 changes: 118 additions & 37 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

import gt4py.next.ffront.field_operator_ast as foast
from gt4py.eve import NodeTranslator, NodeVisitor, traits
from gt4py.next import errors
from gt4py.next.common import DimensionKind
from gt4py.next import errors, utils
from gt4py.next.common import DimensionKind, promote_dims
from gt4py.next.ffront import ( # noqa
dialect_ast_enums,
experimental,
Expand All @@ -20,6 +20,7 @@
type_specifications as ts_ffront,
)
from gt4py.next.ffront.foast_passes.utils import compute_assign_indices
from gt4py.next.iterator import builtins
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation


Expand Down Expand Up @@ -566,16 +567,10 @@ def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> foast.Compare:
op=node.op, left=new_left, right=new_right, location=node.location, type=new_type
)

def _deduce_compare_type(
def _deduce_arithmetic_compare_type(
self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any
) -> Optional[ts.TypeSpec]:
# check both types compatible
for arg in (left, right):
if not type_info.is_arithmetic(arg.type):
raise errors.DSLError(
arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'."
)

# e.g. `1 < 2`
self._check_operand_dtypes_match(node, left=left, right=right)

try:
Expand All @@ -592,6 +587,48 @@ def _deduce_compare_type(
f" in call to '{node.op}'.",
) from ex

def _deduce_dimension_compare_type(
self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any
) -> Optional[ts.TypeSpec]:
# e.g. `IDim > 1`
index_type = ts.ScalarType(
kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())
)

if isinstance(left.type, ts.DimensionType):
if not right.type == index_type:
raise errors.DSLError(
right.location,
f"Expected an {index_type}, but got '{right.type}' instead.",
)
return ts.DomainType(dims=[left.type.dim])
elif isinstance(right.type, ts.DimensionType):
if not left.type == index_type:
raise errors.DSLError(
left.location,
f"Expected an {index_type}, but got '{right.type}' instead.",
)
return ts.DomainType(dims=[right.type.dim])
else:
raise AssertionError()

def _deduce_compare_type(
self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any
) -> Optional[ts.TypeSpec]:
# e.g. `1 < 1`
if all(type_info.is_arithmetic(arg) for arg in (left.type, right.type)):
return self._deduce_arithmetic_compare_type(node, left=left, right=right)
# e.g. `IDim > 1`
if any(isinstance(arg, ts.DimensionType) for arg in (left.type, right.type)):
return self._deduce_dimension_compare_type(node, left=left, right=right)

raise errors.DSLError(
left.location,
"Comparison operators can only be used between arithmetic types "
"(scalars, fields) or between a dimension and an index type "
"({builtins.INTEGER_INDEX_BUILTIN}).",
)

def _deduce_binop_type(
self, node: foast.BinOp, *, left: foast.Expr, right: foast.Expr, **kwargs: Any
) -> Optional[ts.TypeSpec]:
Expand All @@ -612,37 +649,48 @@ def _deduce_binop_type(
dialect_ast_enums.BinaryOperator.BIT_OR,
dialect_ast_enums.BinaryOperator.BIT_XOR,
}
is_compatible = type_info.is_logical if node.op in logical_ops else type_info.is_arithmetic

# check both types compatible
for arg in (left, right):
if not is_compatible(arg.type):
raise errors.DSLError(
arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'."
)

left_type = cast(ts.FieldType | ts.ScalarType, left.type)
right_type = cast(ts.FieldType | ts.ScalarType, right.type)

if node.op == dialect_ast_enums.BinaryOperator.POW:
return left_type
err_msg = f"Unsupported operand type(s) for {node.op}: '{left.type}' and '{right.type}'."

if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral(
right_type
if isinstance(left.type, (ts.ScalarType, ts.FieldType)) and isinstance(
right.type, (ts.ScalarType, ts.FieldType)
):
raise errors.DSLError(
arg.location,
f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.",
is_compatible = (
type_info.is_logical if node.op in logical_ops else type_info.is_arithmetic
)
for arg in (left, right):
if not is_compatible(arg.type):
raise errors.DSLError(arg.location, err_msg)

try:
return type_info.promote(left_type, right_type)
except ValueError as ex:
raise errors.DSLError(
node.location,
f"Could not promote '{left_type}' and '{right_type}' to common type"
f" in call to '{node.op}'.",
) from ex
if node.op == dialect_ast_enums.BinaryOperator.POW:
return left.type

if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral(
right.type
):
raise errors.DSLError(
arg.location,
f"Type '{right.type}' can not be used in operator '{node.op}', it only accepts 'int'.",
)

try:
return type_info.promote(left.type, right.type)
except ValueError as ex:
raise errors.DSLError(
node.location,
f"Could not promote '{left.type}' and '{right.type}' to common type"
f" in call to '{node.op}'.",
) from ex
elif isinstance(left.type, ts.DomainType) and isinstance(right.type, ts.DomainType):
if node.op not in logical_ops:
raise errors.DSLError(
node.location,
f"{err_msg} Operator "
f"must be one of {', '.join((str(op) for op in logical_ops))}.",
)
return ts.DomainType(dims=promote_dims(left.type.dims, right.type.dims))
else:
raise errors.DSLError(node.location, err_msg)

def _check_operand_dtypes_match(
self, node: foast.BinOp | foast.Compare, left: foast.Expr, right: foast.Expr
Expand Down Expand Up @@ -908,6 +956,7 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
)

try:
# TODO(tehrengruber): the construct_tuple_type function doesn't look correct
if isinstance(true_branch_type, ts.TupleType) and isinstance(
false_branch_type, ts.TupleType
):
Expand Down Expand Up @@ -943,7 +992,39 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
location=node.location,
)

_visit_concat_where = _visit_where
def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
mask_type, true_branch_type, false_branch_type = (arg.type for arg in node.args)

assert isinstance(mask_type, ts.DomainType)
assert all(
isinstance(el, (ts.FieldType, ts.ScalarType))
for arg in (true_branch_type, false_branch_type)
for el in type_info.primitive_constituents(arg)
)

@utils.tree_map(
collection_type=ts.TupleType,
result_collection_constructor=lambda el: ts.TupleType(types=list(el)),
)
def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType):
if (t_dtype := type_info.extract_dtype(tb)) != (f_dtype := type_info.extract_dtype(fb)):
raise errors.DSLError(
node.location,
f"Field arguments must be of same dtype, got '{t_dtype}' != '{f_dtype}'.",
)
return_dims = promote_dims(mask_type.dims, type_info.promote(tb, fb).dims)
return_type = ts.FieldType(dims=return_dims, dtype=type_info.promote(t_dtype, f_dtype))
return return_type

return_type = deduce_return_type(true_branch_type, false_branch_type)

return foast.Call(
func=node.func,
args=node.args,
kwargs=node.kwargs,
type=return_type,
location=node.location,
)

def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> foast.Call:
arg_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type)
Expand Down
15 changes: 12 additions & 3 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,16 @@ def visit_Assign(
def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym:
return im.sym(node.id)

def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef:
def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef | itir.AxisLiteral:
if isinstance(node.type, ts.DimensionType):
return itir.AxisLiteral(value=node.type.dim.value, kind=node.type.dim.kind)
return im.ref(node.id)

def visit_Attribute(self, node: foast.Attribute, **kwargs):
if isinstance(node.type, ts.DimensionType):
return itir.AxisLiteral(value=node.type.dim.value, kind=node.type.dim.kind)
raise AssertionError()

def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr:
return im.tuple_get(node.index, self.visit(node.value, **kwargs))

Expand Down Expand Up @@ -394,7 +401,9 @@ def create_if(

return im.let(cond_symref_name, cond_)(result)

_visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where
def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
domain, true_branch, false_branch = self.visit(node.args)
return im.concat_where(domain, true_branch, false_branch)

def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
expr = self.visit(node.args[0], **kwargs)
Expand Down Expand Up @@ -477,7 +486,7 @@ def _map(
Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists.
"""
if all(
isinstance(t, ts.ScalarType)
isinstance(t, (ts.ScalarType, ts.DimensionType, ts.DomainType))
for arg_type in original_arg_types
for t in type_info.primitive_constituents(arg_type)
):
Expand Down
Loading
Loading