Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Output argument with non-zero domain start #1780

Merged
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
aed4d1e
Support for calling a program with field arguments whose domain does …
tehrengruber Dec 10, 2024
f722c14
Add test for input arg with different domain
tehrengruber Dec 11, 2024
c5a61e9
Fix format
tehrengruber Dec 11, 2024
9e09c86
Merge branch 'main' into field_arg_with_non_zero_domain_start
tehrengruber Dec 11, 2024
9deb814
update dace backend
edopao Dec 11, 2024
61feb99
Fix failing tests
tehrengruber Jan 10, 2025
30a4911
Merge remote-tracking branch 'origin_tehrengruber/field_arg_with_non_…
tehrengruber Jan 10, 2025
052c54b
Merge branch 'main' into field_arg_with_non_zero_domain_start
tehrengruber Jan 10, 2025
0d903cc
Address review comments
tehrengruber Jan 10, 2025
7b77c9f
Merge remote-tracking branch 'origin_tehrengruber/field_arg_with_non_…
tehrengruber Jan 10, 2025
a6cf988
Merge origin/main
tehrengruber Jan 10, 2025
858a573
Merge branch 'main' into field_arg_with_non_zero_domain_start
tehrengruber Jan 14, 2025
da65ca1
Merge remote-tracking branch 'origin/main' into field_arg_with_non_ze…
edopao Jan 15, 2025
e6e640c
dace support for domain range and field origin
edopao Jan 15, 2025
a9f67f9
minor edit
edopao Jan 15, 2025
b97232a
Revert "minor edit"
edopao Jan 16, 2025
56ec88d
Revert "dace support for domain range and field origin"
edopao Jan 16, 2025
ad68fac
Merge remote-tracking branch 'origin/main' into field_arg_with_non_ze…
edopao Jan 16, 2025
a28fbf3
skip dace orchestration tests
edopao Jan 16, 2025
9637866
skip dace test_halo_exchange_helper_attrs
edopao Jan 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions src/gt4py/next/ffront/past_process_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,41 +83,47 @@ def _process_args(
# TODO(tehrengruber): Previously this function was called with the actual arguments
# not their type. The check using the shape here is not functional anymore and
# should instead be placed in a proper location.
egparedes marked this conversation as resolved.
Show resolved Hide resolved
shapes_and_dims = [*_field_constituents_shape_and_dims(args[param_idx], param.type)]
ranges_and_dims = [*_field_constituents_range_and_dims(args[param_idx], param.type)]
# check that all non-scalar like constituents have the same shape and dimension, e.g.
# for `(scalar, (field1, field2))` the two fields need to have the same shape and
# dimension
egparedes marked this conversation as resolved.
Show resolved Hide resolved
if shapes_and_dims:
shape, dims = shapes_and_dims[0]
if ranges_and_dims:
range_, dims = ranges_and_dims[0]
if not all(
el_shape == shape and el_dims == dims for (el_shape, el_dims) in shapes_and_dims
el_range == range_ and el_dims == dims
for (el_range, el_dims) in ranges_and_dims
):
raise ValueError(
"Constituents of composite arguments (e.g. the elements of a"
" tuple) need to have the same shape and dimensions."
)
index_type = ts.ScalarType(kind=ts.ScalarKind.INT32)
size_args.extend(
shape if shape else [ts.ScalarType(kind=ts.ScalarKind.INT32)] * len(dims) # type: ignore[arg-type] # shape is always empty
range_ if range_ else [ts.TupleType(types=[index_type, index_type])] * len(dims) # type: ignore[arg-type] # shape is always empty
)
return tuple(rewritten_args), tuple(size_args), kwargs


def _field_constituents_shape_and_dims(
def _field_constituents_range_and_dims(
arg: Any, # TODO(havogt): improve typing
arg_type: ts.DataType,
) -> Iterator[tuple[tuple[int, ...], list[common.Dimension]]]:
) -> Iterator[tuple[tuple[tuple[int, int], ...], list[common.Dimension]]]:
match arg_type:
case ts.TupleType():
for el, el_type in zip(arg, arg_type.types):
assert isinstance(el_type, ts.DataType)
yield from _field_constituents_shape_and_dims(el, el_type)
yield from _field_constituents_range_and_dims(el, el_type)
case ts.FieldType():
dims = type_info.extract_dims(arg_type)
if isinstance(arg, ts.TypeSpec): # TODO
yield (tuple(), dims)
elif dims:
assert hasattr(arg, "shape") and len(arg.shape) == len(dims)
yield (arg.shape, dims)
assert (
hasattr(arg, "domain")
and isinstance(arg.domain, common.Domain)
and len(arg.domain.dims) == len(dims)
egparedes marked this conversation as resolved.
Show resolved Hide resolved
)
yield (tuple((r.start, r.stop) for r in arg.domain.ranges), dims)
else:
yield from [] # ignore 0-dim fields
egparedes marked this conversation as resolved.
Show resolved Hide resolved
case ts.ScalarType():
egparedes marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
37 changes: 21 additions & 16 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension]
return iter(scanops_per_axis.keys()).__next__()


def _size_arg_from_field(field_name: str, dim: int) -> str:
return f"__{field_name}_size_{dim}"
def _range_arg_from_field(field_name: str, dim: int) -> str:
return f"__{field_name}_{dim}_range"


def _flatten_tuple_expr(node: past.Expr) -> list[past.Name | past.Subscript]:
Expand Down Expand Up @@ -217,13 +217,14 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]:
)
if len(fields_dims) > 0: # otherwise `param` has no constituent which is of `FieldType`
assert all(field_dims == fields_dims[0] for field_dims in fields_dims)
index_type = ts.ScalarType(
kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
)
for dim_idx in range(len(fields_dims[0])):
size_params.append(
itir.Sym(
id=_size_arg_from_field(param.id, dim_idx),
type=ts.ScalarType(
kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
),
id=_range_arg_from_field(param.id, dim_idx),
type=ts.TupleType(types=[index_type, index_type]),
)
)

Expand Down Expand Up @@ -286,7 +287,8 @@ def _visit_slice_bound(
self,
slice_bound: Optional[past.Constant],
default_value: itir.Expr,
dim_size: itir.Expr,
start_idx: itir.Expr,
stop_idx: itir.Expr,
**kwargs: Any,
) -> itir.Expr:
if slice_bound is None:
Expand All @@ -296,11 +298,9 @@ def _visit_slice_bound(
slice_bound.type
)
if slice_bound.value < 0:
lowered_bound = itir.FunCall(
fun=itir.SymRef(id="plus"), args=[dim_size, self.visit(slice_bound, **kwargs)]
)
lowered_bound = im.plus(stop_idx, self.visit(slice_bound, **kwargs))
else:
lowered_bound = self.visit(slice_bound, **kwargs)
lowered_bound = im.plus(start_idx, self.visit(slice_bound, **kwargs))
else:
raise AssertionError("Expected 'None' or 'past.Constant'.")
if slice_bound:
Expand Down Expand Up @@ -348,8 +348,9 @@ def _construct_itir_domain_arg(
domain_args = []
domain_args_kind = []
for dim_i, dim in enumerate(out_dims):
# an expression for the size of a dimension
dim_size = itir.SymRef(id=_size_arg_from_field(out_field.id, dim_i))
# an expression for the range of a dimension
dim_range = itir.SymRef(id=_range_arg_from_field(out_field.id, dim_i))
dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range)
# bounds
lower: itir.Expr
upper: itir.Expr
Expand All @@ -359,11 +360,15 @@ def _construct_itir_domain_arg(
else:
lower = self._visit_slice_bound(
slices[dim_i].lower if slices else None,
im.literal("0", itir.INTEGER_INDEX_BUILTIN),
dim_size,
dim_start,
dim_start,
dim_stop,
)
upper = self._visit_slice_bound(
slices[dim_i].upper if slices else None, dim_size, dim_size
slices[dim_i].upper if slices else None,
dim_stop,
dim_start,
dim_stop,
)

if dim.kind == common.DimensionKind.LOCAL:
Expand Down
9 changes: 6 additions & 3 deletions src/gt4py/next/otf/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]:
return None


def iter_size_args(args: tuple[Any, ...]) -> Iterator[int]:
def iter_size_args(args: tuple[Any, ...]) -> Iterator[tuple[int, int]]:
"""
Yield the size of each field argument in each dimension.

Expand All @@ -136,7 +136,9 @@ def iter_size_args(args: tuple[Any, ...]) -> Iterator[int]:
if first_field:
yield from iter_size_args((first_field,))
case common.Field():
yield from arg.ndarray.shape
for range_ in arg.domain.ranges:
assert isinstance(range_, common.UnitRange)
yield (range_.start, range_.stop)
case _:
pass

Expand All @@ -156,6 +158,7 @@ def iter_size_compile_args(
)
if field_constituents:
# we only need the first field, because all fields in a tuple must have the same dims and sizes
index_type = ts.ScalarType(kind=ts.ScalarKind.INT32)
yield from [
ts.ScalarType(kind=ts.ScalarKind.INT32) for dim in field_constituents[0].dims
ts.TupleType(types=[index_type, index_type]) for dim in field_constituents[0].dims
]
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


# regex to match the symbols for field shape and strides
FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile(r"__.+_(size|stride)_\d+")
FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile(r"^__.+_((\d+_range_[01])|((size|stride)_\d+))$")


def as_dace_type(type_: ts.ScalarType) -> dace.typeclass:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@
from gt4py.next.program_processors.runners.dace_common import dace_backend, utility as dace_utils


class CompiledDaceProgram(stages.CompiledProgram):
class CompiledDaceProgram(stages.ExtendedCompiledProgram):
sdfg_program: dace.CompiledSDFG

# Sorted list of SDFG arguments as they appear in program ABI and corresponding data type;
# scalar arguments that are not used in the SDFG will not be present.
sdfg_arglist: list[tuple[str, dace.dtypes.Data]]

def __init__(self, program: dace.CompiledSDFG):
def __init__(self, program: dace.CompiledSDFG, implicit_domain: bool):
self.sdfg_program = program
self.implicit_domain = implicit_domain
# `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument
# name to its data type, in the same order as arguments appear in the program ABI.
# This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`.
Expand Down Expand Up @@ -88,7 +89,7 @@ def __call__(
dace.config.Config.set("compiler", "cpu", "args", value=compiler_args)
sdfg_program = sdfg.compile(validate=False)

return CompiledDaceProgram(sdfg_program)
return CompiledDaceProgram(sdfg_program, inp.program_source.implicit_domain)


class DaCeCompilationStepFactory(factory.Factory):
Expand All @@ -113,9 +114,11 @@ def decorated_program(
if out is not None:
args = (*args, out)
flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(tuple(args))
if len(sdfg.arg_names) > len(flat_args):
# The Ahead-of-Time (AOT) workflow for FieldView programs requires domain size arguments.
flat_args = (*flat_args, *arguments.iter_size_args(args))
if inp.implicit_domain:
# generate implicit domain size arguments only if necessary
size_args = arguments.iter_size_args(args)
flat_size_args: Sequence[int] = gtx_utils.flatten_nested_tuple(tuple(size_args))
flat_args = (*flat_args, *flat_size_args)

if sdfg_program._lastargs:
kwargs = dict(zip(sdfg.arg_names, flat_args, strict=True))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def builtin_if(*args: Any) -> str:
return f"{true_val} if {cond} else {false_val}"


def builtin_tuple_get(*args: Any) -> str:
edopao marked this conversation as resolved.
Show resolved Hide resolved
index, tuple_name = args
return f"{tuple_name}_{index}"


def make_const_list(arg: str) -> str:
"""
Takes a single scalar argument and broadcasts this value on the local dimension
Expand All @@ -97,6 +102,7 @@ def make_const_list(arg: str) -> str:
"cast_": builtin_cast,
"if_": builtin_if,
"make_const_list": make_const_list,
"tuple_get": builtin_tuple_get,
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def test_sdfgConvertible_laplap(cartesian_case): # noqa: F811
if not cartesian_case.backend or "dace" not in cartesian_case.backend.name:
pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs")

# TODO(edopao): add support for range symbols in field domain and re-enable this test
pytest.skip("Requires support for field domain range.")

backend = cartesian_case.backend

in_field = cases.allocate(cartesian_case, laplap_program, "in_field")()
Expand Down Expand Up @@ -87,6 +90,9 @@ def test_sdfgConvertible_connectivities(unstructured_case): # noqa: F811
if not unstructured_case.backend or "dace" not in unstructured_case.backend.name:
pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs")

# TODO(edopao): add support for range symbols in field domain and re-enable this test
pytest.skip("Requires support for field domain range.")

allocator, backend = unstructured_case.allocator, unstructured_case.backend

if gtx_allocators.is_field_allocator_for(allocator, gtx_allocators.CUPY_DEVICE):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -998,7 +998,7 @@ def program_domain(a: cases.IField, out: cases.IField):
a = cases.allocate(cartesian_case, program_domain, "a")()
out = cases.allocate(cartesian_case, program_domain, "out")()

ref = out.asnumpy().copy() # ensure we are not overwriting out outside of the domain
ref = out.asnumpy().copy() # ensure we are not writing to out outside the domain
ref[1:9] = a.asnumpy()[1:9] * 2

cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pytest

import gt4py.next as gtx
from gt4py.next import errors
from gt4py.next import errors, constructors, common

from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import (
Expand Down Expand Up @@ -251,3 +251,42 @@ def empty_domain_program(a: cases.IJField, out_field: cases.IJField):
ValueError, match=(r"Dimensions in out field and field domain are not equivalent")
):
cases.run(cartesian_case, empty_domain_program, a, out_field, offset_provider={})


@pytest.mark.uses_origin
def test_out_field_arg_with_non_zero_domain_start(cartesian_case, copy_program_def):
copy_program = gtx.program(copy_program_def, backend=cartesian_case.backend)

size = cartesian_case.default_sizes[IDim]

inp = cases.allocate(cartesian_case, copy_program, "in_field").unique()()
out = constructors.empty(
common.domain({IDim: (1, size - 2)}),
allocator=cartesian_case.allocator,
)
ref = inp.ndarray[1:-2]

cases.verify(cartesian_case, copy_program, inp, out=out, ref=ref)


@pytest.mark.uses_origin
def test_in_field_arg_with_non_zero_domain_start(cartesian_case, copy_program_def):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This already worked before, I've just added a test.

@gtx.field_operator
def identity(a: cases.IField) -> cases.IField:
return a

@gtx.program
def copy_program(a: cases.IField, out: cases.IField):
identity(a, out=out, domain={IDim: (1, 9)})

inp = constructors.empty(
common.domain({IDim: (1, 9)}),
dtype=np.int32,
allocator=cartesian_case.allocator,
)
inp.ndarray[...] = 42
out = cases.allocate(cartesian_case, copy_program, "out", sizes={IDim: 10})()
ref = out.asnumpy().copy() # ensure we are not writing to `out` outside the domain
ref[1:9] = inp.asnumpy()

cases.verify(cartesian_case, copy_program, inp, out=out, ref=ref)
Loading
Loading