Skip to content

Commit

Permalink
feat[next]: Slicing field to 0d to return field not scalar (#1427)
Browse files Browse the repository at this point in the history
* return array for 0d field from slicing instead of scalar
  • Loading branch information
nfarabullini authored Feb 9, 2024
1 parent 374f043 commit 1d305e1
Show file tree
Hide file tree
Showing 12 changed files with 84 additions and 38 deletions.
2 changes: 2 additions & 0 deletions src/gt4py/_core/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,8 @@ def shape(self) -> tuple[int, ...]: ...
@property
def dtype(self) -> Any: ...

def item(self) -> Any: ...

def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: ...

def __getitem__(self, item: Any) -> NDArrayObject: ...
Expand Down
14 changes: 9 additions & 5 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,14 +623,17 @@ def asnumpy(self) -> np.ndarray: ...
def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ...

@abc.abstractmethod
def restrict(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ...
def restrict(self, item: AnyIndexSpec) -> Field: ...

@abc.abstractmethod
def as_scalar(self) -> core_defs.ScalarT: ...

# Operators
@abc.abstractmethod
def __call__(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ...

@abc.abstractmethod
def __getitem__(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ...
def __getitem__(self, item: AnyIndexSpec) -> Field: ...

@abc.abstractmethod
def __abs__(self) -> Field: ...
Expand Down Expand Up @@ -896,6 +899,9 @@ def ndarray(self) -> Never:
def asnumpy(self) -> Never:
raise NotImplementedError()

def as_scalar(self) -> Never:
    # Not meaningful for this field type; raises unconditionally, matching the
    # `ndarray`/`asnumpy` accessors of this class which also raise.
    raise NotImplementedError()

@functools.cached_property
def domain(self) -> Domain:
return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),))
Expand Down Expand Up @@ -947,9 +953,7 @@ def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Conne

__call__ = remap

def restrict(self, index: AnyIndexSpec) -> Never:
    """Restriction is not supported for this connectivity.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()  # we could possibly implement with a FunctionField, but we don't have a use-case

__getitem__ = restrict
Expand Down
19 changes: 11 additions & 8 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ def asnumpy(self) -> np.ndarray:
else:
return np.asarray(self._ndarray)

def as_scalar(self) -> core_defs.ScalarT:
    """Return the single value of a 0-dimensional field as a Python scalar.

    Raises:
        ValueError: if the field is not 0-dimensional.
    """
    if self.domain.ndim != 0:
        # f-prefix added: without it the "{self.domain.ndim}" placeholder was emitted literally.
        raise ValueError(
            f"'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'."
        )
    return self.ndarray.item()

@property
def codomain(self) -> type[core_defs.ScalarT]:
return self.dtype.scalar_type
Expand Down Expand Up @@ -204,15 +211,11 @@ def remap(

__call__ = remap # type: ignore[assignment]

def restrict(self, index: common.AnyIndexSpec) -> common.Field:
    """Return a new field restricted to ``index``.

    Even a point-wise restriction returns a (0-dimensional) field, never a
    bare scalar; use ``as_scalar`` to extract the value of a 0-d field.
    """
    new_domain, buffer_slice = self._slice(index)
    new_buffer = self.ndarray[buffer_slice]
    # Normalize through the array namespace so 0-d results stay array-backed.
    new_buffer = self.__class__.array_ns.asarray(new_buffer)
    return self.__class__.from_array(new_buffer, domain=new_domain)

__getitem__ = restrict

Expand Down Expand Up @@ -433,7 +436,7 @@ def inverse_image(

return new_dims

def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.IntegralScalar:
def restrict(self, index: common.AnyIndexSpec) -> common.Field:
cache_key = (id(self.ndarray), self.domain, index)

if (restricted_connectivity := self._cache.get(cache_key, None)) is None:
Expand Down
3 changes: 1 addition & 2 deletions src/gt4py/next/embedded/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,7 @@ def _tuple_at(
) -> core_defs.Scalar | tuple[core_defs.ScalarT | tuple, ...]:
@utils.tree_map
def impl(field: common.Field | core_defs.Scalar) -> core_defs.Scalar:
    # Fields are restricted at `pos` (closure variable) and collapsed to a
    # scalar via `as_scalar`; plain scalars pass through unchanged.
    res = field[pos].as_scalar() if common.is_field(field) else field
    assert core_defs.is_scalar_type(res)
    return res

Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/func_to_foast.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def visit_Subscript(self, node: ast.Subscript, **kwargs) -> foast.Subscript:
index = self._match_index(node.slice)
except ValueError:
raise errors.DSLError(
self.get_location(node.slice), "eXpected an integral index."
self.get_location(node.slice), "Expected an integral index."
) from None

return foast.Subscript(
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/ffront/past_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,12 @@ def visit_Call(self, node: past.Call, **kwargs):
f"'{new_kwargs['out'].type}'."
)
elif new_func.id in ["minimum", "maximum"]:
if new_args[0].type != new_args[1].type:
if arg_types[0] != arg_types[1]:
raise ValueError(
f"First and second argument in '{new_func.id}' must be of the same type."
f"Got '{new_args[0].type}' and '{new_args[1].type}'."
f"Got '{arg_types[0]}' and '{arg_types[1]}'."
)
return_type = new_args[0].type
return_type = arg_types[0]
else:
raise AssertionError(
"Only calls to 'FieldOperator', 'ScanOperator' or 'minimum' and 'maximum' builtins allowed."
Expand Down
9 changes: 4 additions & 5 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def _visit_stencil_call_out_arg(
) -> tuple[itir.Expr, itir.FunCall]:
if isinstance(out_arg, past.Subscript):
# as the ITIR does not support slicing a field we have to do a deeper
# inspection of the PAST to emulate the behaviour
# inspection of the PAST to emulate the behaviour
out_field_name: past.Name = out_arg.value
return (
self._construct_itir_out_arg(out_field_name),
Expand Down Expand Up @@ -382,12 +382,11 @@ def visit_BinOp(self, node: past.BinOp, **kwargs) -> itir.FunCall:
)

def visit_Call(self, node: past.Call, **kwargs) -> itir.FunCall:
    """Lower a PAST call to ITIR; only the 'minimum'/'maximum' builtins are supported."""
    if node.func.id in ["maximum", "minimum"]:
        assert len(node.args) == 2
        return itir.FunCall(
            fun=itir.SymRef(id=node.func.id),
            args=[self.visit(node.args[0]), self.visit(node.args[1])],
        )
    # message grammar fixed: no comma before "and" in a two-item list
    raise NotImplementedError("Only 'minimum' and 'maximum' builtins supported currently.")
28 changes: 22 additions & 6 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ def _translate_named_indices(
return tuple(domain_slice)

def field_getitem(self, named_indices: NamedFieldIndices) -> Any:
    # Restricting the field now yields a 0-d Field, so extract the value
    # explicitly with `as_scalar`.
    return self._ndarrayfield[self._translate_named_indices(named_indices)].as_scalar()

def field_setitem(self, named_indices: NamedFieldIndices, value: Any):
if common.is_mutable_field(self._ndarrayfield):
Expand Down Expand Up @@ -1040,6 +1040,7 @@ class IndexField(common.Field):
"""

_dimension: common.Dimension
_cur_index: Optional[core_defs.IntegralScalar] = None

@property
def __gt_domain__(self) -> common.Domain:
Expand All @@ -1055,7 +1056,10 @@ def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override

@property
def domain(self) -> common.Domain:
    # Once `restrict` has pinned an index (`_cur_index` set), the field is
    # 0-dimensional; otherwise it spans the whole (infinite) dimension.
    if self._cur_index is None:
        return common.Domain((self._dimension, common.UnitRange.infinite()))
    else:
        return common.Domain()

@property
def codomain(self) -> type[core_defs.int32]:
Expand All @@ -1072,16 +1076,24 @@ def ndarray(self) -> core_defs.NDArrayObject:
def asnumpy(self) -> np.ndarray:
raise NotImplementedError()

def as_scalar(self) -> core_defs.IntegralScalar:
    """Return the pinned index value; only valid once restricted to 0-d.

    Raises:
        ValueError: if the field is not 0-dimensional.
    """
    if self.domain.ndim != 0:
        # f-prefix added: without it the "{self.domain.ndim}" placeholder was emitted literally.
        raise ValueError(
            f"'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'."
        )
    assert self._cur_index is not None
    return self._cur_index

def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -> common.Field:
    # TODO: can be implemented by constructing an ndarray (but do we know of which kind?)
    raise NotImplementedError()

def restrict(self, item: common.AnyIndexSpec) -> common.Field:
    """Restrict the index field; pinning to a point returns a 0-d IndexField."""
    if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item):  # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code # fmt: off
        d, r = item[0]
        assert d == self._dimension
        assert isinstance(r, core_defs.INTEGRAL_TYPES)
        # Keep the result a Field (with the index pinned) instead of a scalar.
        return self.__class__(self._dimension, r)  # type: ignore[arg-type] # not sure why the assert above does not work
    # TODO set a domain...
    raise NotImplementedError()

Expand Down Expand Up @@ -1195,8 +1207,12 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -
# TODO can be implemented by constructing and ndarray (but do we know of which kind?)
raise NotImplementedError()

def restrict(self, item: common.AnyIndexSpec) -> common.Field:
    """A constant field has the same value everywhere, so any restriction is itself."""
    # TODO set a domain...
    return self

def as_scalar(self) -> core_defs.ScalarT:
    """Return the constant value; only valid on a 0-dimensional field.

    Raises:
        ValueError: if the field is not 0-dimensional (consistent with the
            other `as_scalar` implementations; `assert` would be stripped
            under `python -O`).
    """
    if self.domain.ndim != 0:
        raise ValueError(
            f"'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'."
        )
    return self._value

__call__ = remap
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,21 @@ def testee(inp: gtx.Field[[KDim], float]) -> gtx.Field[[KDim], float]:
cases.verify(cartesian_case, testee, inp, out=out, ref=expected)


def test_single_value_field(cartesian_case):
    # Writing through a 1x1 slice (`a[1:2, 3:4]`) should update exactly the
    # single point `a[1, 3]`; the identity field operator makes the expected
    # value the point's own original value.
    @gtx.field_operator
    def testee_fo(a: cases.IKField) -> cases.IKField:
        return a

    @gtx.program
    def testee_prog(a: cases.IKField):
        testee_fo(a, out=a[1:2, 3:4])

    a = cases.allocate(cartesian_case, testee_prog, "a")()
    ref = a[1, 3]  # expected value at the written point (identity operator)

    cases.verify(cartesian_case, testee_prog, a, inout=a[1, 3], ref=ref)


def test_astype_int(cartesian_case): # noqa: F811 # fixtures
@gtx.field_operator
def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_simple_indirection(program_processor):

ref = np.zeros(shape, dtype=inp.dtype)
for i in range(shape[0]):
ref[i] = inp.ndarray[i + 1 - 1] if cond[i] < 0.0 else inp.ndarray[i + 1 + 1]
ref[i] = inp.asnumpy()[i + 1 - 1] if cond.asnumpy()[i] < 0.0 else inp.asnumpy()[i + 1 + 1]

run_processor(
conditional_indirection[cartesian_domain(named_range(IDim, 0, shape[0]))],
Expand Down Expand Up @@ -101,7 +101,7 @@ def test_direct_offset_for_indirection(program_processor):

ref = np.zeros(shape)
for i in range(shape[0]):
ref[i] = inp[i + cond[i]]
ref[i] = inp.asnumpy()[i + cond.asnumpy()[i]]

run_processor(
direct_indirection[cartesian_domain(named_range(IDim, 0, shape[0]))],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,15 @@ def fencil(x, y, z, out, inp):
def naive_lap(inp):
    """Reference 5-point Laplacian over the (i, j) plane, computed per k level.

    `inp` is a field-like object exposing `shape` and `asnumpy()`; the result
    is a plain ndarray shrunk by one halo cell on each horizontal side.
    """
    shape = [inp.shape[0] - 2, inp.shape[1] - 2, inp.shape[2]]
    out = np.zeros(shape)
    inp_data = inp.asnumpy()  # convert once, outside the loops
    for i in range(1, shape[0] + 1):
        for j in range(1, shape[1] + 1):
            for k in range(0, shape[2]):
                out[i - 1, j - 1, k] = -4 * inp_data[i, j, k] + (
                    inp_data[i + 1, j, k]
                    + inp_data[i - 1, j, k]
                    + inp_data[i, j + 1, k]
                    + inp_data[i, j - 1, k]
                )
    return out

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -468,10 +468,11 @@ def test_absolute_indexing_value_return():
field = common._field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain)

named_index = ((IDim, 12), (JDim, 6))
assert common.is_field(field)
value = field[named_index]

assert isinstance(value, np.int32)
assert value == 21
assert common.is_field(value)
assert value.as_scalar() == 21


@pytest.mark.parametrize(
Expand Down Expand Up @@ -568,14 +569,17 @@ def test_relative_indexing_slice_3D(index, expected_shape, expected_domain):

@pytest.mark.parametrize(
    "index, expected_value",
    [
        ((1, 0), 10),
        ((0, 1), 1),
    ],
)
def test_relative_indexing_value_return(index, expected_value):
    # Point-wise relative indexing now returns a 0-d Field; the value is
    # extracted explicitly with `as_scalar`.
    domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(2, 12)))
    field = common._field(np.reshape(np.arange(100, dtype=int), (10, 10)), domain=domain)
    indexed_field = field[index]

    assert indexed_field.as_scalar() == expected_value


@pytest.mark.parametrize("lazy_slice", [lambda f: f[13], lambda f: f[:5, :3, :2]])
Expand Down

0 comments on commit 1d305e1

Please sign in to comment.