feat[next]: Embedded support for skip value connectivities #1441

Merged on Feb 23, 2024 · 24 commits · Changes shown from 7 commits
120 changes: 77 additions & 43 deletions src/gt4py/next/embedded/nd_array_field.py
@@ -26,8 +26,9 @@
from gt4py._core import definitions as core_defs
from gt4py.eve.extended_typing import Any, Never, Optional, ParamSpec, TypeAlias, TypeVar
from gt4py.next import common
from gt4py.next.embedded import common as embedded_common
from gt4py.next.embedded import common as embedded_common, context as embedded_context
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import embedded as itir_embedded


try:
@@ -159,6 +160,8 @@ def from_array(
def remap(
self: NdArrayField, connectivity: common.ConnectivityField | fbuiltins.FieldOffset
) -> NdArrayField:
# TODO(skip values): if the skip_value is -1, we don't need special treatment; the gather just selects an arbitrary value (the wrapped-around one)

# For neighbor reductions, a FieldOffset is passed instead of an actual ConnectivityField
if not common.is_connectivity_field(connectivity):
assert isinstance(connectivity, fbuiltins.FieldOffset)
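
The TODO above leans on NumPy's wrap-around indexing: using the skip value -1 directly as a gather index simply selects the last element along that axis, and the resulting garbage entry is masked out later during reductions. A minimal standalone sketch (plain NumPy; the toy names are hypothetical):

import numpy as np

edge_vals = np.array([10.0, 20.0, 30.0])
v2e_row = np.array([2, 0, -1])  # -1 marks a missing neighbor
gathered = edge_vals[v2e_row]   # -1 wraps around to the last element
print(gathered)                 # [30. 10. 30.] -- the skip slot holds a harmless duplicate

The wrapped-around value never reaches the result because _make_reduction below masks it against the connectivity table before reducing.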
@@ -387,47 +390,22 @@ def inverse_image(
assert isinstance(image_range, common.UnitRange)

assert common.UnitRange.is_finite(image_range)

restricted_mask = (self._ndarray >= image_range.start) & (
self._ndarray < image_range.stop
)
# indices of non-zero elements in each dimension
nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(restricted_mask)

new_dims = []
non_contiguous_dims = []

for i, dim_nnz_indices in enumerate(nnz):
# Check if the indices are contiguous
first_data_index = dim_nnz_indices[0]
assert isinstance(first_data_index, core_defs.INTEGRAL_TYPES)
last_data_index = dim_nnz_indices[-1]
assert isinstance(last_data_index, core_defs.INTEGRAL_TYPES)
indices, counts = xp.unique(dim_nnz_indices, return_counts=True)
dim_range = self._domain[i]

if len(xp.unique(counts)) == 1 and (
len(indices) == last_data_index - first_data_index + 1
):
idx_offset = dim_range[1].start
start = idx_offset + first_data_index
assert common.is_int_index(start)
stop = idx_offset + last_data_index + 1
assert common.is_int_index(stop)
new_dims.append(
common.named_range(
(
dim_range[0],
(start, stop),
)
)
)
else:
non_contiguous_dims.append(dim_range[0])

if non_contiguous_dims:
raise ValueError(
f"Restriction generates non-contiguous dimensions '{non_contiguous_dims}'."
)
relative_ranges = _hypercube(
restricted_mask, xp, ignore_mask=self._ndarray == common.SKIP_VALUE
)

if relative_ranges is None:
raise ValueError("Restriction generates non-contiguous dimensions.")

new_dims = [
common.named_range((d, rr + ar.start))
for d, ar, rr in zip(self.domain.dims, self.domain.ranges, relative_ranges)
]

self._cache[cache_key] = new_dims

@@ -449,6 +427,30 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Integ
__getitem__ = restrict


def _hypercube(
select: core_defs.NDArrayObject,
xp: ModuleType,
ignore_mask: Optional[core_defs.NDArrayObject] = None,
) -> Optional[list[common.UnitRange]]:
"""
Return the hypercube that contains all True values and no False values or `None` if no such hypercube exists.

If `ignore_mask` is given, the selected values are ignored.
"""
nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(select)

slices = tuple(
slice(xp.min(dim_nnz_indices), xp.max(dim_nnz_indices) + 1) for dim_nnz_indices in nnz
)
hcube = select[tuple(slices)]
if ignore_mask is not None:
hcube |= ignore_mask[tuple(slices)]
if not xp.all(hcube):
return None

return [common.UnitRange(s.start, s.stop) for s in slices]
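
A minimal standalone sketch of the same check, assuming plain NumPy and returning (start, stop) tuples instead of common.UnitRange (the helper name is hypothetical); this is the contiguity test that inverse_image above now delegates to:

import numpy as np

def hypercube_bounds(select, ignore_mask=None):
    # Bounding box of all True entries: one slice per dimension.
    nnz = np.nonzero(select)
    slices = tuple(slice(int(d.min()), int(d.max()) + 1) for d in nnz)
    box = select[slices]
    if ignore_mask is not None:
        box = box | ignore_mask[slices]  # ignored positions may be False
    return [(s.start, s.stop) for s in slices] if box.all() else None

select = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype=bool)
ignore = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=bool)
assert hypercube_bounds(select) is None                      # hole -> not a hypercube
assert hypercube_bounds(select, ignore) == [(1, 3), (0, 3)]  # hole excused by a skip value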


# -- Specialized implementations for builtin operations on array fields --

NdArrayField.register_builtin_func(
@@ -480,7 +482,9 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Integ
NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where"))


def _make_reduction(builtin_name: str, array_builtin_name: str) -> Callable[
def _make_reduction(
builtin_name: str, array_builtin_name: str, initial_value_op: Callable
) -> Callable[
...,
NdArrayField[common.DimsT, core_defs.ScalarT],
]:
@@ -491,20 +495,50 @@ def _builtin_op(
raise ValueError("Can only reduce local dimensions.")
if axis not in field.domain.dims:
raise ValueError(f"Field can not be reduced as it doesn't have dimension '{axis}'.")
if len([d for d in field.domain.dims if d.kind is common.DimensionKind.LOCAL]) > 1:
raise NotImplementedError(
"Reducing a field with more than one local dimension is not supported."
)
reduce_dim_index = field.domain.dims.index(axis)
current_offset_provider = embedded_context.offset_provider.get(None)
assert current_offset_provider is not None
offset_definition = current_offset_provider[
axis.value
] # assumes offset and local dimension have same name
assert isinstance(offset_definition, itir_embedded.NeighborTableOffsetProvider)
new_domain = common.Domain(*[nr for nr in field.domain if nr[0] != axis])

broadcast_slice = tuple(
slice(None) if d in [axis, offset_definition.origin_axis] else None
for d in field.domain.dims
)
masked_array = field.array_ns.where(
field.array_ns.asarray(offset_definition.table[broadcast_slice]) != common.SKIP_VALUE,
field.ndarray,
initial_value_op(field),
)

return field.__class__.from_array(
getattr(field.array_ns, array_builtin_name)(field.ndarray, axis=reduce_dim_index),
getattr(field.array_ns, array_builtin_name)(
masked_array,
axis=reduce_dim_index,
),
domain=new_domain,
)

_builtin_op.__name__ = builtin_name
return _builtin_op
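
The masking strategy above can be reproduced in plain NumPy: gather the neighbor values, then overwrite the slots whose table entry equals the skip value with the reduction's identity element, which is exactly what initial_value_op supplies per builtin. A sketch on a hypothetical toy mesh:

import numpy as np

SKIP_VALUE = -1  # sentinel assumed to match common.SKIP_VALUE

v2e = np.array([[0, 1, SKIP_VALUE],  # vertex 0 has only two edges
                [1, 2, 3]])
edge_vals = np.array([10.0, 20.0, 30.0, 40.0])

gathered = edge_vals[v2e]  # -1 wraps around; that slot is garbage for now

# neighbor_sum: fill skipped slots with 0, the additive identity.
print(np.where(v2e != SKIP_VALUE, gathered, 0.0).sum(axis=1))              # [30. 90.]
# max_over: fill with the field's global minimum so the fill value never wins.
print(np.where(v2e != SKIP_VALUE, gathered, edge_vals.min()).max(axis=1))  # [20. 40.]

This is why the registrations below pair neighbor_sum with 0, max_over with the global minimum of the field, and min_over with the global maximum: the fill value can never dominate a real neighbor.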


NdArrayField.register_builtin_func(fbuiltins.neighbor_sum, _make_reduction("neighbor_sum", "sum"))
NdArrayField.register_builtin_func(fbuiltins.max_over, _make_reduction("max_over", "max"))
NdArrayField.register_builtin_func(fbuiltins.min_over, _make_reduction("min_over", "min"))
NdArrayField.register_builtin_func(
fbuiltins.neighbor_sum, _make_reduction("neighbor_sum", "sum", lambda x: x.dtype.scalar_type(0))
)
NdArrayField.register_builtin_func(
fbuiltins.max_over, _make_reduction("max_over", "max", lambda x: x.array_ns.min(x._ndarray))
)
NdArrayField.register_builtin_func(
fbuiltins.min_over, _make_reduction("min_over", "min", lambda x: x.array_ns.max(x._ndarray))
)


# -- Concrete array implementations --
8 changes: 8 additions & 0 deletions src/gt4py/next/ffront/decorator.py
@@ -221,6 +221,10 @@ def __post_init__(self):
f"The following closure variables are undefined: {', '.join(undefined_symbols)}."
)

@property
def __name__(self) -> str:
return self.definition.__name__

@functools.cached_property
def __gt_allocator__(
self,
@@ -601,6 +605,10 @@ def from_function(
operator_attributes=operator_attributes,
)

@property
def __name__(self) -> str:
return self.definition.__name__

def __gt_type__(self) -> ts.CallableType:
type_ = self.foast_node.type
assert isinstance(type_, ts.CallableType)
1 change: 0 additions & 1 deletion src/gt4py/next/ffront/fbuiltins.py
@@ -360,7 +360,6 @@ def as_connectivity_field(self):
if common.is_connectivity_field(offset_definition):
connectivity = offset_definition
elif isinstance(offset_definition, gtx.NeighborTableOffsetProvider):
assert not offset_definition.has_skip_values
connectivity = gtx.as_connectivity(
domain=self.target,
codomain=self.source,
11 changes: 10 additions & 1 deletion src/gt4py/next/iterator/atlas_utils.py
@@ -29,7 +29,7 @@ def __getitem__(self, indices):
if neigh_index < self.atlas_connectivity.cols(primary_index):
return self.atlas_connectivity[primary_index, neigh_index]
else:
return None
return -1
else:
if neigh_index < 2:
return self.atlas_connectivity[primary_index, neigh_index]
@@ -53,3 +53,12 @@ def max(self): # noqa: A003
if v is not None:
maximum = max(maximum, v)
return maximum

def asnumpy(self):
import numpy as np

res = np.empty(self.shape, dtype=self.dtype)
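# __getitem__ above returns -1 (the skip value) for missing neighbors,
# so the dense copy is padded with -1 where rows are ragged.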
for i in range(self.shape[0]):
for j in range(self.shape[1]):
res[i, j] = self[i, j]
return res
1 change: 0 additions & 1 deletion tests/next_tests/definitions.py
@@ -171,7 +171,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
XFAIL,
UNSUPPORTED_MESSAGE,
), # we can't extract the field type from scan args
(USES_MESH_WITH_SKIP_VALUES, XFAIL, UNSUPPORTED_MESSAGE),
]
GTFN_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [
# floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136
@@ -30,6 +30,7 @@
Joff,
KDim,
V2EDim,
Vertex,
cartesian_case,
unstructured_case,
)
@@ -85,26 +86,60 @@ def minover(edge_f: cases.EField) -> cases.VField:
)


@pytest.mark.uses_unstructured_shift
def test_reduction_execution(unstructured_case):
@gtx.field_operator
def reduction(edge_f: cases.EField) -> cases.VField:
return neighbor_sum(edge_f(V2E), axis=V2EDim)
@gtx.field_operator
def reduction_e_field(edge_f: cases.EField) -> cases.VField:
return neighbor_sum(edge_f(V2E), axis=V2EDim)


@gtx.field_operator
def reduction_ek_field(
edge_f: common.Field[[Edge, KDim], np.int32]
) -> common.Field[[Vertex, KDim], np.int32]:
return neighbor_sum(edge_f(V2E), axis=V2EDim)

@gtx.program
def fencil(edge_f: cases.EField, out: cases.VField):
reduction(edge_f, out=out)

@gtx.field_operator
def reduction_ke_field(
edge_f: common.Field[[KDim, Edge], np.int32]
) -> common.Field[[KDim, Vertex], np.int32]:
return neighbor_sum(edge_f(V2E), axis=V2EDim)


@pytest.mark.uses_unstructured_shift
@pytest.mark.parametrize(
"fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__
)
def test_neighbor_sum(unstructured_case, fop):
v2e_table = unstructured_case.offset_provider["V2E"].table
cases.verify_with_default_data(

edge_f = cases.allocate(unstructured_case, fop, "edge_f")()

local_dim_idx = edge_f.domain.dims.index(Edge)
adv_indexing = tuple(
slice(None) if dim is not Edge else v2e_table for dim in edge_f.domain.dims
)

broadcast_slice = []
for dim in edge_f.domain.dims:
if dim is Edge:
broadcast_slice.append(slice(None))
broadcast_slice.append(slice(None))
else:
broadcast_slice.append(None)

broadcasted_table = v2e_table[tuple(broadcast_slice)]
ref = np.sum(
edge_f.asnumpy()[adv_indexing],
axis=local_dim_idx + 1,
initial=0,
where=broadcasted_table != common.SKIP_VALUE,
)
cases.verify(
unstructured_case,
fencil,
ref=lambda edge_f: np.sum(
edge_f[v2e_table],
axis=1,
initial=0,
where=v2e_table != common.SKIP_VALUE,
),
fop,
edge_f,
out=cases.allocate(unstructured_case, fop, cases.RETURN)(),
ref=ref,
)
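
The reference computation above relies on NumPy advanced indexing: indexing the Edge axis with the 2-D v2e_table replaces that single axis with two axes (vertex, neighbor), which is why the reduction runs over local_dim_idx + 1. A toy illustration (hypothetical shapes):

import numpy as np

edge_f = np.arange(8).reshape(4, 2)      # (Edge, KDim)
v2e = np.array([[0, 1, 3], [2, 3, 1]])   # (Vertex, neighbor)
gathered = edge_f[v2e]                   # Edge axis -> (Vertex, neighbor)
assert gathered.shape == (2, 3, 2)       # reduce over axis 1, i.e. local_dim_idx + 1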

