Skip to content

Commit

Permalink
feat[next]: Embedded support for skip value connectivities (#1441)
Browse files Browse the repository at this point in the history
- Introduces `common._DEFAULT_SKIP_VALUE = -1`. Currently, execution with skip values other than this default is not supported.
- `inverse_image` ignores the skip value when computing the inverse image. The current result is the smallest hypercube that contains all requested non-skip values.

Additional changes:
- Adds a ffront test for the fvm nabla example
- Fixes a bug in the lowering from PAST to ITIR where the domain did not respect the order of the horizontal and vertical dimensions
  • Loading branch information
havogt authored Feb 23, 2024
1 parent d39c36f commit 117de0a
Show file tree
Hide file tree
Showing 18 changed files with 449 additions and 116 deletions.
13 changes: 11 additions & 2 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@ class ConnectivityKind(enum.Flag):


@extended_runtime_checkable
# type: ignore[misc] # DimT should be covariant, but break in another place
# type: ignore[misc] # DimT should be covariant, but breaks in another place
class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]):
@property
@abc.abstractmethod
Expand All @@ -749,6 +749,10 @@ def kind(self) -> ConnectivityKind:
@abc.abstractmethod
def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: ...

# Abstract read-only property: the value used to mark skipped entries in this
# connectivity. The return type also allows `None`; presumably that means the
# connectivity has no skipped entries -- confirm against the implementers.
@property
@abc.abstractmethod
def skip_value(self) -> Optional[core_defs.IntegralScalar]: ...

# Operators
def __abs__(self) -> Never:
raise TypeError("'ConnectivityField' does not support this operation.")
Expand Down Expand Up @@ -840,6 +844,7 @@ def _connectivity(
*,
domain: Optional[DomainLike] = None,
dtype: Optional[core_defs.DType] = None,
skip_value: Optional[core_defs.IntegralScalar] = None,
) -> ConnectivityField:
raise NotImplementedError

Expand Down Expand Up @@ -918,6 +923,10 @@ def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]:
def codomain(self) -> DimT:
return self.dimension

# This connectivity never reports a skip value: the property is always `None`.
@property
def skip_value(self) -> None:
return None

@functools.cached_property
def kind(self) -> ConnectivityKind:
return ConnectivityKind(0)
Expand Down Expand Up @@ -1083,4 +1092,4 @@ def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Call
#: Numeric value used to represent missing values in connectivities.
#: Equivalent to the `_FillValue` attribute in the UGRID Conventions
#: (see: http://ugrid-conventions.github.io/ugrid-conventions/).
SKIP_VALUE: Final[int] = -1
_DEFAULT_SKIP_VALUE: Final[int] = -1
6 changes: 5 additions & 1 deletion src/gt4py/next/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ def as_connectivity(
*,
allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None,
device: Optional[core_defs.Device] = None,
skip_value: Optional[core_defs.IntegralScalar] = None,
# copy=False, TODO
) -> common.ConnectivityField:
"""
Expand All @@ -330,6 +331,9 @@ def as_connectivity(
Raises:
ValueError: If the domain or codomain is invalid, or if the shape of the data does not match the domain shape.
"""
assert (
skip_value is None or skip_value == common._DEFAULT_SKIP_VALUE
) # TODO(havogt): not yet configurable
if isinstance(domain, Sequence) and all(isinstance(dim, common.Dimension) for dim in domain):
domain = cast(Sequence[common.Dimension], domain)
if len(domain) != data.ndim:
Expand Down Expand Up @@ -359,7 +363,7 @@ def as_connectivity(
# TODO(havogt): consider adding MutableNDArrayObject
buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index]
connectivity_field = common._connectivity(
buffer.ndarray, codomain=codomain, domain=actual_domain
buffer.ndarray, codomain=codomain, domain=actual_domain, skip_value=skip_value
)
assert isinstance(connectivity_field, nd_array_field.NdArrayConnectivityField)

Expand Down
151 changes: 102 additions & 49 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
from gt4py._core import definitions as core_defs
from gt4py.eve.extended_typing import Any, Never, Optional, ParamSpec, TypeAlias, TypeVar
from gt4py.next import common
from gt4py.next.embedded import common as embedded_common
from gt4py.next.embedded import common as embedded_common, context as embedded_context
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import embedded as itir_embedded


try:
Expand Down Expand Up @@ -170,9 +171,12 @@ def remap(
if not common.is_connectivity_field(connectivity):
assert isinstance(connectivity, fbuiltins.FieldOffset)
connectivity = connectivity.as_connectivity_field()

assert common.is_connectivity_field(connectivity)

# Current implementation relies on skip_value == -1:
# if we assume the indexed array has at least one element, we wrap around without out of bounds
assert connectivity.skip_value is None or connectivity.skip_value == -1

# Compute the new domain
dim = connectivity.codomain
dim_idx = self.domain.dim_index(dim)
Expand Down Expand Up @@ -315,6 +319,7 @@ class NdArrayConnectivityField( # type: ignore[misc] # for __ne__, __eq__
NdArrayField[common.DimsT, core_defs.IntegralScalar],
):
_codomain: common.DimT
_skip_value: Optional[core_defs.IntegralScalar]

@functools.cached_property
def _cache(self) -> dict:
Expand All @@ -329,6 +334,10 @@ def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ig
def codomain(self) -> common.DimT:
return self._codomain

# Read-only accessor for the `_skip_value` attribute of this field
# (an integral scalar, or `None`).
@property
def skip_value(self) -> Optional[core_defs.IntegralScalar]:
return self._skip_value

@functools.cached_property
def kind(self) -> common.ConnectivityKind:
kind = common.ConnectivityKind.MODIFY_STRUCTURE
Expand All @@ -349,6 +358,7 @@ def from_array( # type: ignore[override]
*,
domain: common.DomainLike,
dtype: Optional[core_defs.DTypeLike] = None,
skip_value: Optional[core_defs.IntegralScalar] = None,
) -> NdArrayConnectivityField:
domain = common.domain(domain)
xp = cls.array_ns
Expand All @@ -367,7 +377,12 @@ def from_array( # type: ignore[override]

assert isinstance(codomain, common.Dimension)

return cls(domain, array, codomain)
return cls(
domain,
array,
codomain,
_skip_value=skip_value,
)

def inverse_image(
self, image_range: common.UnitRange | common.NamedRange
Expand All @@ -390,47 +405,16 @@ def inverse_image(
assert isinstance(image_range, common.UnitRange)

assert common.UnitRange.is_finite(image_range)
restricted_mask = (self._ndarray >= image_range.start) & (
self._ndarray < image_range.stop
)
# indices of non-zero elements in each dimension
nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(restricted_mask)

new_dims = []
non_contiguous_dims = []

for i, dim_nnz_indices in enumerate(nnz):
# Check if the indices are contiguous
first_data_index = dim_nnz_indices[0]
assert isinstance(first_data_index, core_defs.INTEGRAL_TYPES)
last_data_index = dim_nnz_indices[-1]
assert isinstance(last_data_index, core_defs.INTEGRAL_TYPES)
indices, counts = xp.unique(dim_nnz_indices, return_counts=True)
dim_range = self._domain[i]

if len(xp.unique(counts)) == 1 and (
len(indices) == last_data_index - first_data_index + 1
):
idx_offset = dim_range[1].start
start = idx_offset + first_data_index
assert common.is_int_index(start)
stop = idx_offset + last_data_index + 1
assert common.is_int_index(stop)
new_dims.append(
common.named_range(
(
dim_range[0],
(start, stop),
)
)
)
else:
non_contiguous_dims.append(dim_range[0])

if non_contiguous_dims:
raise ValueError(
f"Restriction generates non-contiguous dimensions '{non_contiguous_dims}'."
)
relative_ranges = _hypercube(self._ndarray, image_range, xp, self.skip_value)

if relative_ranges is None:
raise ValueError("Restriction generates non-contiguous dimensions.")

new_dims = [
common.named_range((d, rr + ar.start))
for d, ar, rr in zip(self.domain.dims, self.domain.ranges, relative_ranges)
]

self._cache[cache_key] = new_dims

Expand All @@ -444,14 +428,49 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field:
xp = cls.array_ns
new_domain, buffer_slice = self._slice(index)
new_buffer = xp.asarray(self.ndarray[buffer_slice])
restricted_connectivity = cls(new_domain, new_buffer, self.codomain)
restricted_connectivity = cls(new_domain, new_buffer, self.codomain, self.skip_value)
self._cache[cache_key] = restricted_connectivity

return restricted_connectivity

__getitem__ = restrict


def _hypercube(
index_array: core_defs.NDArrayObject,
image_range: common.UnitRange,
xp: ModuleType,
skip_value: Optional[core_defs.IntegralScalar] = None,
) -> Optional[list[common.UnitRange]]:
"""
Return the hypercube that contains all indices in `index_array` that are within `image_range`, or `None` if no such hypercube exists.
If `skip_value` is given, the selected values are ignored. It returns the smallest hypercube.
A bigger hypercube could be constructed by adding lines that contain only `skip_value`s.
Example:
index_array = 0 1 -1
3 4 -1
-1 -1 -1
skip_value = -1
would currently select the 2x2 range [0,2], [0,2], but could also select the 3x3 range [0,3], [0,3].
"""
select_mask = (index_array >= image_range.start) & (index_array < image_range.stop)

nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(select_mask)

slices = tuple(
slice(xp.min(dim_nnz_indices), xp.max(dim_nnz_indices) + 1) for dim_nnz_indices in nnz
)
hcube = select_mask[tuple(slices)]
if skip_value is not None:
ignore_mask = index_array == skip_value
hcube |= ignore_mask[tuple(slices)]
if not xp.all(hcube):
return None

return [common.UnitRange(s.start, s.stop) for s in slices]


# -- Specialized implementations for builtin operations on array fields --

NdArrayField.register_builtin_func(
Expand Down Expand Up @@ -483,31 +502,65 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field:
NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where"))


def _make_reduction(builtin_name: str, array_builtin_name: str) -> Callable[
def _make_reduction(
builtin_name: str, array_builtin_name: str, initial_value_op: Callable
) -> Callable[
...,
NdArrayField[common.DimsT, core_defs.ScalarT],
]:
def _builtin_op(
field: NdArrayField[common.DimsT, core_defs.ScalarT], axis: common.Dimension
) -> NdArrayField[common.DimsT, core_defs.ScalarT]:
xp = field.array_ns

if not axis.kind == common.DimensionKind.LOCAL:
raise ValueError("Can only reduce local dimensions.")
if axis not in field.domain.dims:
raise ValueError(f"Field can not be reduced as it doesn't have dimension '{axis}'.")
if len([d for d in field.domain.dims if d.kind is common.DimensionKind.LOCAL]) > 1:
raise NotImplementedError(
"Reducing a field with more than one local dimension is not supported."
)
reduce_dim_index = field.domain.dims.index(axis)
current_offset_provider = embedded_context.offset_provider.get(None)
assert current_offset_provider is not None
offset_definition = current_offset_provider[
axis.value
] # assumes offset and local dimension have same name
assert isinstance(offset_definition, itir_embedded.NeighborTableOffsetProvider)
new_domain = common.Domain(*[nr for nr in field.domain if nr[0] != axis])

broadcast_slice = tuple(
slice(None) if d in [axis, offset_definition.origin_axis] else xp.newaxis
for d in field.domain.dims
)
masked_array = xp.where(
xp.asarray(offset_definition.table[broadcast_slice]) != common._DEFAULT_SKIP_VALUE,
field.ndarray,
initial_value_op(field),
)

return field.__class__.from_array(
getattr(field.array_ns, array_builtin_name)(field.ndarray, axis=reduce_dim_index),
getattr(xp, array_builtin_name)(
masked_array,
axis=reduce_dim_index,
),
domain=new_domain,
)

_builtin_op.__name__ = builtin_name
return _builtin_op


NdArrayField.register_builtin_func(fbuiltins.neighbor_sum, _make_reduction("neighbor_sum", "sum"))
NdArrayField.register_builtin_func(fbuiltins.max_over, _make_reduction("max_over", "max"))
NdArrayField.register_builtin_func(fbuiltins.min_over, _make_reduction("min_over", "min"))
NdArrayField.register_builtin_func(
fbuiltins.neighbor_sum, _make_reduction("neighbor_sum", "sum", lambda x: x.dtype.scalar_type(0))
)
NdArrayField.register_builtin_func(
fbuiltins.max_over, _make_reduction("max_over", "max", lambda x: x.array_ns.min(x._ndarray))
)
NdArrayField.register_builtin_func(
fbuiltins.min_over, _make_reduction("min_over", "min", lambda x: x.array_ns.max(x._ndarray))
)


# -- Concrete array implementations --
Expand Down
8 changes: 8 additions & 0 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ def __post_init__(self):
f"The following closure variables are undefined: {', '.join(undefined_symbols)}."
)

# Expose the name of the wrapped `definition` as this object's `__name__`.
@property
def __name__(self) -> str:
return self.definition.__name__

@functools.cached_property
def __gt_allocator__(
self,
Expand Down Expand Up @@ -603,6 +607,10 @@ def from_function(
operator_attributes=operator_attributes,
)

# Expose the name of the wrapped `definition` as this object's `__name__`.
@property
def __name__(self) -> str:
return self.definition.__name__

def __gt_type__(self) -> ts.CallableType:
type_ = self.foast_node.type
assert isinstance(type_, ts.CallableType)
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,12 +360,14 @@ def as_connectivity_field(self):
if common.is_connectivity_field(offset_definition):
connectivity = offset_definition
elif isinstance(offset_definition, gtx.NeighborTableOffsetProvider):
assert not offset_definition.has_skip_values
connectivity = gtx.as_connectivity(
domain=self.target,
codomain=self.source,
data=offset_definition.table,
dtype=offset_definition.index_type,
skip_value=(
common._DEFAULT_SKIP_VALUE if offset_definition.has_skip_values else None
),
)
else:
raise NotImplementedError()
Expand Down
9 changes: 8 additions & 1 deletion src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ def _construct_itir_domain_arg(
node_domain: Optional[past.Expr],
slices: Optional[list[past.Slice]] = None,
) -> itir.FunCall:
domain_args = []

assert isinstance(out_field.type, ts.TypeSpec)
out_field_types = type_info.primitive_constituents(out_field.type).to_list()
Expand All @@ -246,6 +245,8 @@ def _construct_itir_domain_arg(
" caught in type deduction already."
)

domain_args = []
domain_args_kind = []
for dim_i, dim in enumerate(out_dims):
# an expression for the size of a dimension
dim_size = itir.SymRef(id=_size_arg_from_field(out_field.id, dim_i))
Expand All @@ -271,11 +272,17 @@ def _construct_itir_domain_arg(
args=[itir.AxisLiteral(value=dim.value), lower, upper],
)
)
domain_args_kind.append(dim.kind)

if self.grid_type == GridType.CARTESIAN:
domain_builtin = "cartesian_domain"
elif self.grid_type == GridType.UNSTRUCTURED:
domain_builtin = "unstructured_domain"
# for no good reason, the domain arguments for unstructured need to be in order (horizontal, vertical)
if domain_args_kind[0] == DimensionKind.VERTICAL:
assert len(domain_args) == 2
assert domain_args_kind[1] == DimensionKind.HORIZONTAL
domain_args[0], domain_args[1] = domain_args[1], domain_args[0]
else:
raise AssertionError()

Expand Down
Loading

0 comments on commit 117de0a

Please sign in to comment.