Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Output argument with non-zero domain start #1780

Merged
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
aed4d1e
Support for calling a program with field arguments whose domain does …
tehrengruber Dec 10, 2024
f722c14
Add test for input arg with different domain
tehrengruber Dec 11, 2024
c5a61e9
Fix format
tehrengruber Dec 11, 2024
9e09c86
Merge branch 'main' into field_arg_with_non_zero_domain_start
tehrengruber Dec 11, 2024
9deb814
update dace backend
edopao Dec 11, 2024
61feb99
Fix failing tests
tehrengruber Jan 10, 2025
30a4911
Merge remote-tracking branch 'origin_tehrengruber/field_arg_with_non_…
tehrengruber Jan 10, 2025
052c54b
Merge branch 'main' into field_arg_with_non_zero_domain_start
tehrengruber Jan 10, 2025
0d903cc
Address review comments
tehrengruber Jan 10, 2025
7b77c9f
Merge remote-tracking branch 'origin_tehrengruber/field_arg_with_non_…
tehrengruber Jan 10, 2025
a6cf988
Merge origin/main
tehrengruber Jan 10, 2025
858a573
Merge branch 'main' into field_arg_with_non_zero_domain_start
tehrengruber Jan 14, 2025
da65ca1
Merge remote-tracking branch 'origin/main' into field_arg_with_non_ze…
edopao Jan 15, 2025
e6e640c
dace support for domain range and field origin
edopao Jan 15, 2025
a9f67f9
minor edit
edopao Jan 15, 2025
b97232a
Revert "minor edit"
edopao Jan 16, 2025
56ec88d
Revert "dace support for domain range and field origin"
edopao Jan 16, 2025
ad68fac
Merge remote-tracking branch 'origin/main' into field_arg_with_non_ze…
edopao Jan 16, 2025
a28fbf3
skip dace orchestration tests
edopao Jan 16, 2025
9637866
skip dace test_halo_exchange_helper_attrs
edopao Jan 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions src/gt4py/next/ffront/past_process_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,40 +83,46 @@ def _process_args(
# TODO(tehrengruber): Previously this function was called with the actual arguments
# not their type. The check using the shape here is not functional anymore and
# should instead be placed in a proper location.
egparedes marked this conversation as resolved.
Show resolved Hide resolved
shapes_and_dims = [*_field_constituents_shape_and_dims(args[param_idx], param.type)]
ranges_and_dims = [*_field_constituents_range_and_dims(args[param_idx], param.type)]
# check that all non-scalar like constituents have the same shape and dimension, e.g.
# for `(scalar, (field1, field2))` the two fields need to have the same shape and
# dimension
egparedes marked this conversation as resolved.
Show resolved Hide resolved
if shapes_and_dims:
shape, dims = shapes_and_dims[0]
if ranges_and_dims:
range_, dims = ranges_and_dims[0]
if not all(
el_shape == shape and el_dims == dims for (el_shape, el_dims) in shapes_and_dims
el_range == range_ and el_dims == dims
for (el_range, el_dims) in ranges_and_dims
):
raise ValueError(
"Constituents of composite arguments (e.g. the elements of a"
" tuple) need to have the same shape and dimensions."
)
index_type = ts.ScalarType(kind=ts.ScalarKind.INT32)
size_args.extend(
shape if shape else [ts.ScalarType(kind=ts.ScalarKind.INT32)] * len(dims) # type: ignore[arg-type] # shape is always empty
range_ if range_ else [ts.TupleType(types=[index_type, index_type])] * len(dims) # type: ignore[arg-type] # shape is always empty
)
return tuple(rewritten_args), tuple(size_args), kwargs


def _field_constituents_shape_and_dims(
def _field_constituents_range_and_dims(
arg: Any, # TODO(havogt): improve typing
arg_type: ts.DataType,
) -> Iterator[tuple[tuple[int, ...], list[common.Dimension]]]:
) -> Iterator[tuple[tuple[tuple[int, int], ...], list[common.Dimension]]]:
match arg_type:
case ts.TupleType():
for el, el_type in zip(arg, arg_type.types):
yield from _field_constituents_shape_and_dims(el, el_type)
yield from _field_constituents_range_and_dims(el, el_type)
case ts.FieldType():
dims = type_info.extract_dims(arg_type)
if isinstance(arg, ts.TypeSpec): # TODO
yield (tuple(), dims)
elif dims:
assert hasattr(arg, "shape") and len(arg.shape) == len(dims)
yield (arg.shape, dims)
assert (
hasattr(arg, "domain")
and isinstance(arg.domain, common.Domain)
and len(arg.domain.dims) == len(dims)
egparedes marked this conversation as resolved.
Show resolved Hide resolved
)
yield (tuple((r.start, r.stop) for r in arg.domain.ranges), dims)
else:
yield from [] # ignore 0-dim fields
egparedes marked this conversation as resolved.
Show resolved Hide resolved
case ts.ScalarType():
egparedes marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
34 changes: 18 additions & 16 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension]
return iter(scanops_per_axis.keys()).__next__()


def _size_arg_from_field(field_name: str, dim: int) -> str:
return f"__{field_name}_size_{dim}"
def _range_arg_from_field(field_name: str, dim: int) -> str:
return f"__{field_name}_{dim}_range"


def _flatten_tuple_expr(node: past.Expr) -> list[past.Name | past.Subscript]:
Expand Down Expand Up @@ -217,13 +217,14 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]:
)
if len(fields_dims) > 0: # otherwise `param` has no constituent which is of `FieldType`
assert all(field_dims == fields_dims[0] for field_dims in fields_dims)
index_type = ts.ScalarType(
kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
)
for dim_idx in range(len(fields_dims[0])):
size_params.append(
itir.Sym(
id=_size_arg_from_field(param.id, dim_idx),
type=ts.ScalarType(
kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
),
id=_range_arg_from_field(param.id, dim_idx),
type=ts.TupleType(types=[index_type, index_type]),
)
)

Expand Down Expand Up @@ -286,7 +287,8 @@ def _visit_slice_bound(
self,
slice_bound: Optional[past.Constant],
default_value: itir.Expr,
dim_size: itir.Expr,
start_idx: itir.Expr,
stop_idx: itir.Expr,
**kwargs: Any,
) -> itir.Expr:
if slice_bound is None:
Expand All @@ -296,11 +298,9 @@ def _visit_slice_bound(
slice_bound.type
)
if slice_bound.value < 0:
lowered_bound = itir.FunCall(
fun=itir.SymRef(id="plus"), args=[dim_size, self.visit(slice_bound, **kwargs)]
)
lowered_bound = im.plus(stop_idx, self.visit(slice_bound, **kwargs))
else:
lowered_bound = self.visit(slice_bound, **kwargs)
lowered_bound = im.plus(start_idx, self.visit(slice_bound, **kwargs))
else:
raise AssertionError("Expected 'None' or 'past.Constant'.")
if slice_bound:
Expand Down Expand Up @@ -348,8 +348,9 @@ def _construct_itir_domain_arg(
domain_args = []
domain_args_kind = []
for dim_i, dim in enumerate(out_dims):
# an expression for the size of a dimension
dim_size = itir.SymRef(id=_size_arg_from_field(out_field.id, dim_i))
# an expression for the range of a dimension
dim_range = itir.SymRef(id=_range_arg_from_field(out_field.id, dim_i))
dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range)
# bounds
lower: itir.Expr
upper: itir.Expr
Expand All @@ -359,11 +360,12 @@ def _construct_itir_domain_arg(
else:
lower = self._visit_slice_bound(
slices[dim_i].lower if slices else None,
im.literal("0", itir.INTEGER_INDEX_BUILTIN),
dim_size,
dim_start,
dim_start,
dim_stop,
)
upper = self._visit_slice_bound(
slices[dim_i].upper if slices else None, dim_size, dim_size
slices[dim_i].upper if slices else None, dim_stop, dim_start, dim_stop
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
)

if dim.kind == common.DimensionKind.LOCAL:
Expand Down
9 changes: 6 additions & 3 deletions src/gt4py/next/otf/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]:
return None


def iter_size_args(args: tuple[Any, ...]) -> Iterator[int]:
def iter_size_args(args: tuple[Any, ...]) -> Iterator[tuple[int, int]]:
"""
Yield the size of each field argument in each dimension.

Expand All @@ -136,7 +136,9 @@ def iter_size_args(args: tuple[Any, ...]) -> Iterator[int]:
if first_field:
yield from iter_size_args((first_field,))
case common.Field():
yield from arg.ndarray.shape
for range_ in arg.domain.ranges:
assert isinstance(range_, common.UnitRange)
yield (range_.start, range_.stop)
case _:
pass

Expand All @@ -156,6 +158,7 @@ def iter_size_compile_args(
)
if field_constituents:
# we only need the first field, because all fields in a tuple must have the same dims and sizes
index_type = ts.ScalarType(kind=ts.ScalarKind.INT32)
yield from [
ts.ScalarType(kind=ts.ScalarKind.INT32) for dim in field_constituents[0].dims
ts.TupleType(types=[index_type, index_type]) for dim in field_constituents[0].dims
]
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


# regex to match the symbols for field shape and strides
FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile(r"__.+_(size|stride)_\d+")
FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile(r"^__.+_(range_[01]|((size|stride)_\d+))$")
egparedes marked this conversation as resolved.
Show resolved Hide resolved


def as_dace_type(type_: ts.ScalarType) -> dace.typeclass:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def builtin_if(*args: Any) -> str:
return f"{true_val} if {cond} else {false_val}"


def builtin_tuple_get(*args: Any) -> str:
edopao marked this conversation as resolved.
Show resolved Hide resolved
index, tuple_name = args
return f"{tuple_name}_{index}"


def make_const_list(arg: str) -> str:
"""
Takes a single scalar argument and broadcasts this value on the local dimension
Expand All @@ -97,6 +102,7 @@ def make_const_list(arg: str) -> str:
"cast_": builtin_cast,
"if_": builtin_if,
"make_const_list": make_const_list,
"tuple_get": builtin_tuple_get,
}


Expand Down
1 change: 1 addition & 0 deletions tests/next_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
]
DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [
(USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE),
(USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE),
(USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE),
(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,7 @@ def program_domain(a: cases.IField, out: cases.IField):
a = cases.allocate(cartesian_case, program_domain, "a")()
out = cases.allocate(cartesian_case, program_domain, "out")()

ref = out.asnumpy().copy() # ensure we are not overwriting out outside of the domain
ref = out.asnumpy().copy() # ensure we are not overwriting out outside the domain
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
ref[1:9] = a.asnumpy()[1:9] * 2

cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pytest

import gt4py.next as gtx
from gt4py.next import errors
from gt4py.next import errors, constructors, common

from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import (
Expand Down Expand Up @@ -251,3 +251,42 @@ def empty_domain_program(a: cases.IJField, out_field: cases.IJField):
ValueError, match=(r"Dimensions in out field and field domain are not equivalent")
):
cases.run(cartesian_case, empty_domain_program, a, out_field, offset_provider={})


@pytest.mark.uses_origin
def test_out_field_arg_with_non_zero_domain_start(cartesian_case, copy_program_def):
copy_program = gtx.program(copy_program_def, backend=cartesian_case.backend)

size = cartesian_case.default_sizes[IDim]

inp = cases.allocate(cartesian_case, copy_program, "in_field").unique()()
out = constructors.empty(
common.domain({IDim: (1, size - 2)}),
allocator=cartesian_case.allocator,
)
ref = inp.ndarray[1:-2]

cases.verify(cartesian_case, copy_program, inp, out=out, ref=ref)


@pytest.mark.uses_origin
def test_in_field_arg_with_non_zero_domain_start(cartesian_case, copy_program_def):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This already worked before, I've just added a test.

@gtx.field_operator
def identity(a: cases.IField) -> cases.IField:
return a

@gtx.program
def copy_program(a: cases.IField, out: cases.IField):
identity(a, out=out, domain={IDim: (1, 9)})

inp = constructors.empty(
common.domain({IDim: (1, 9)}),
dtype=np.int32,
allocator=cartesian_case.allocator,
)
inp.ndarray[...] = 42
out = cases.allocate(cartesian_case, copy_program, "out", sizes={IDim: 10})()
ref = out.asnumpy().copy() # ensure we are not overwriting `out` outside the domain
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
ref[1:9] = inp.asnumpy()

cases.verify(cartesian_case, copy_program, inp, out=out, ref=ref)
94 changes: 78 additions & 16 deletions tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,30 @@ def test_copy_lowering(copy_program_def, gtir_identity_fundef):
fun=P(itir.SymRef, id=eve.SymbolRef("named_range")),
args=[
P(itir.AxisLiteral, value="IDim"),
P(itir.Literal, value="0", type=ts.ScalarType(kind=ts.ScalarKind.INT32)),
P(itir.SymRef, id=eve.SymbolRef("__out_size_0")),
P(
itir.FunCall,
fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")),
args=[
P(
itir.Literal,
value="0",
type=ts.ScalarType(kind=ts.ScalarKind.INT32),
),
P(itir.SymRef, id=eve.SymbolRef("__out_0_range")),
],
),
P(
itir.FunCall,
fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")),
args=[
P(
itir.Literal,
value="1",
type=ts.ScalarType(kind=ts.ScalarKind.INT32),
),
P(itir.SymRef, id=eve.SymbolRef("__out_0_range")),
],
),
],
)
],
Expand All @@ -77,8 +99,8 @@ def test_copy_lowering(copy_program_def, gtir_identity_fundef):
params=[
P(itir.Sym, id=eve.SymbolName("in_field")),
P(itir.Sym, id=eve.SymbolName("out")),
P(itir.Sym, id=eve.SymbolName("__in_field_size_0")),
P(itir.Sym, id=eve.SymbolName("__out_size_0")),
P(itir.Sym, id=eve.SymbolName("__in_field_0_range")),
P(itir.Sym, id=eve.SymbolName("__out_0_range")),
],
body=[set_at_pattern],
)
Expand All @@ -105,18 +127,58 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef)
args=[
P(itir.AxisLiteral, value="IDim"),
P(
itir.Literal,
value="1",
type=ts.ScalarType(
kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
),
itir.FunCall,
fun=P(itir.SymRef, id=eve.SymbolRef("plus")),
args=[
P(
itir.FunCall,
fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")),
args=[
P(
itir.Literal,
value="0",
type=ts.ScalarType(kind=ts.ScalarKind.INT32),
),
P(itir.SymRef, id=eve.SymbolRef("__out_0_range")),
],
),
P(
itir.Literal,
value="1",
type=ts.ScalarType(
kind=getattr(
ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()
)
),
),
],
),
P(
itir.Literal,
value="2",
type=ts.ScalarType(
kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
),
itir.FunCall,
fun=P(itir.SymRef, id=eve.SymbolRef("plus")),
args=[
P(
itir.FunCall,
fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")),
args=[
P(
itir.Literal,
value="0",
type=ts.ScalarType(kind=ts.ScalarKind.INT32),
),
P(itir.SymRef, id=eve.SymbolRef("__out_0_range")),
],
),
P(
itir.Literal,
value="2",
type=ts.ScalarType(
kind=getattr(
ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()
)
),
),
],
),
],
)
Expand All @@ -129,8 +191,8 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef)
params=[
P(itir.Sym, id=eve.SymbolName("in_field")),
P(itir.Sym, id=eve.SymbolName("out")),
P(itir.Sym, id=eve.SymbolName("__in_field_size_0")),
P(itir.Sym, id=eve.SymbolName("__out_size_0")),
P(itir.Sym, id=eve.SymbolName("__in_field_0_range")),
P(itir.Sym, id=eve.SymbolName("__out_0_range")),
],
body=[set_at_pattern],
)
Expand Down
Loading