Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into update-api-tests-with…
Browse files Browse the repository at this point in the history
…out-complex
  • Loading branch information
cbourjau committed Sep 3, 2024
2 parents 9f5054d + 28f5773 commit a8be313
Show file tree
Hide file tree
Showing 21 changed files with 693 additions and 161 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ jobs:
name: artifact
path: dist
- name: Publish package on TestPyPi
uses: pypa/gh-action-pypi-publish@ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0
uses: pypa/gh-action-pypi-publish@8a08d616893759ef8e1aa1f2785787c0b97e20d6
with:
repository-url: https://test.pypi.org/legacy/
- name: Publish package on PyPi
uses: pypa/gh-action-pypi-publish@ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0
uses: pypa/gh-action-pypi-publish@8a08d616893759ef8e1aa1f2785787c0b97e20d6
16 changes: 15 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,23 @@
Changelog
=========

0.9.0 (unreleased)

0.9.0 (2024-08-30)
------------------

**New features**

- User defined data types can now define how arrays with that dtype are constructed by implementing the ``make_array`` function.
- User defined data types can now define how they are indexed (via ``__getitem__``) by implementing the ``getitem`` function.
- :class:`ndonnx.NullableCore` is now public, encapsulating nullable variants of `CoreType`s exported by ndonnx.

**Bug fixes**

- Various operations that depend on the array's shape have been updated to work correctly with lazy arrays.
- :func:`ndonnx.cumulative_sum` now correctly applies the ``include_initial`` parameter and works around missing onnxruntime kernels for unsigned integral types.
- :func:`ndonnx.additional.make_nullable` applies broadcasting to the provided null array (instead of reshape like it did previously). This allows writing ``make_nullable(x, False)`` to turn an array into nullable.
- User-defined data types that implement :class:`ndonnx._core.UniformShapeOperations` may now implement :func:`ndonnx.where` without requiring both data types be promotable.

**Breaking change**

- Iterating over dynamic dimensions of :class:`~ndonnx.Array` is no longer allowed since it commonly lead to infinite loops when used without an explicit break condition.
Expand Down
2 changes: 2 additions & 0 deletions ndonnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Floating,
Integral,
Nullable,
NullableCore,
NullableFloating,
NullableIntegral,
NullableNumerical,
Expand Down Expand Up @@ -323,6 +324,7 @@
"Floating",
"NullableIntegral",
"Nullable",
"NullableCore",
"Integral",
"CoreType",
"CastError",
Expand Down
21 changes: 8 additions & 13 deletions ndonnx/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ndonnx as ndx
import ndonnx._data_types as dtypes
from ndonnx.additional import shape
from ndonnx.additional._additional import _getitem as getitem
from ndonnx.additional._additional import _static_shape as static_shape

from ._corearray import _CoreArray
Expand Down Expand Up @@ -47,7 +48,11 @@ def array(
out : Array
The new array. This represents an ONNX model input.
"""
return Array._construct(shape=shape, dtype=dtype)
if (out := dtype._ops.make_array(shape, dtype)) is not NotImplemented:
return out
raise ndx.UnsupportedOperationError(
f"No implementation of `make_array` for {dtype}"
)


def from_spox_var(
Expand Down Expand Up @@ -154,17 +159,7 @@ def astype(self, to: CoreType | StructType) -> Array:
return ndx.astype(self, to)

def __getitem__(self, index: IndexType) -> Array:
if isinstance(index, Array) and not (
isinstance(index.dtype, dtypes.Integral) or index.dtype == dtypes.bool
):
raise TypeError(
f"Index must be an integral or boolean 'Array', not `{index.dtype}`"
)

if isinstance(index, Array):
index = index._core()

return self._transmute(lambda corearray: corearray[index])
return getitem(self, index)

def __setitem__(
self, index: IndexType | Self, updates: int | bool | float | Array
Expand Down Expand Up @@ -517,7 +512,7 @@ def size(self) -> ndx.Array:
out: Array
Scalar ``Array`` instance whose value is the number of elements in the original array.
"""
return ndx.prod(self.shape)
return ndx.prod(shape(self))

@property
def T(self) -> ndx.Array: # noqa: N802
Expand Down
24 changes: 8 additions & 16 deletions ndonnx/_core/_boolimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
import ndonnx._data_types as dtypes
import ndonnx._opset_extensions as opx

from ._coreimpl import CoreOperationsImpl
from ._interface import OperationsBlock
from ._nullableimpl import NullableOperationsImpl
from ._shapeimpl import UniformShapeOperations
from ._utils import binary_op, unary_op, validate_core

if TYPE_CHECKING:
from ndonnx import Array


class BooleanOperationsImpl(UniformShapeOperations):
class _BooleanOperationsImpl(OperationsBlock):
@validate_core
def equal(self, x, y) -> Array:
return binary_op(x, y, opx.equal)
Expand Down Expand Up @@ -99,7 +100,7 @@ def can_cast(self, from_, to) -> bool:

@validate_core
def all(self, x, *, axis=None, keepdims: bool = False):
if isinstance(x.dtype, dtypes._NullableCore):
if isinstance(x.dtype, dtypes.NullableCore):
x = ndx.where(x.null, True, x.values)
if functools.reduce(operator.mul, x._static_shape, 1) == 0:
return ndx.asarray(True, dtype=ndx.bool)
Expand All @@ -109,7 +110,7 @@ def all(self, x, *, axis=None, keepdims: bool = False):

@validate_core
def any(self, x, *, axis=None, keepdims: bool = False):
if isinstance(x.dtype, dtypes._NullableCore):
if isinstance(x.dtype, dtypes.NullableCore):
x = ndx.where(x.null, False, x.values)
if functools.reduce(operator.mul, x._static_shape, 1) == 0:
return ndx.asarray(False, dtype=ndx.bool)
Expand Down Expand Up @@ -162,17 +163,8 @@ def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
def nonzero(self, x) -> tuple[Array, ...]:
return ndx.nonzero(x.astype(ndx.int8))

@validate_core
def make_nullable(self, x, null):
if null.dtype != dtypes.bool:
raise TypeError("'null' must be a boolean array")
return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.reshape(null, x.shape),
)

class BooleanOperationsImpl(CoreOperationsImpl, _BooleanOperationsImpl): ...

class NullableBooleanOperationsImpl(BooleanOperationsImpl, NullableOperationsImpl):
def make_nullable(self, x, null):
return NotImplemented

class NullableBooleanOperationsImpl(NullableOperationsImpl, _BooleanOperationsImpl): ...
58 changes: 58 additions & 0 deletions ndonnx/_core/_coreimpl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from spox import Tensor, argument

import ndonnx as ndx
import ndonnx._data_types as dtypes
import ndonnx.additional as nda
from ndonnx._corearray import _CoreArray

from ._shapeimpl import UniformShapeOperations
from ._utils import validate_core

if TYPE_CHECKING:
from ndonnx._array import Array
from ndonnx._data_types import Dtype


class CoreOperationsImpl(UniformShapeOperations):
def make_array(
self,
shape: tuple[int | None | str, ...],
dtype: Dtype,
eager_value: np.ndarray | None = None,
) -> Array:
if not isinstance(dtype, dtypes.CoreType):
return NotImplemented
return ndx.Array._from_fields(
dtype,
data=_CoreArray(
dtype._parse_input(eager_value)["data"]
if eager_value is not None
else argument(Tensor(dtype.to_numpy_dtype(), shape))
),
)

@validate_core
def make_nullable(self, x: Array, null: Array) -> Array:
if null.dtype != ndx.bool:
raise TypeError("'null' must be a boolean array")

return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.broadcast_to(null, nda.shape(x)),
)

@validate_core
def where(self, condition, x, y):
if x.dtype != y.dtype:
target_dtype = ndx.result_type(x, y)
x = ndx.astype(x, target_dtype)
y = ndx.astype(y, target_dtype)
return super().where(condition, x, y)
39 changes: 29 additions & 10 deletions ndonnx/_core/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@

from __future__ import annotations

from typing import Literal
from typing import TYPE_CHECKING, Literal

import numpy as np

import ndonnx as ndx
import ndonnx._data_types as dtypes

if TYPE_CHECKING:
from ndonnx._array import IndexType
from ndonnx._data_types import Dtype


class OperationsBlock:
"""Interface for data types to implement top-level functions exported by ndonnx."""
Expand Down Expand Up @@ -251,7 +257,7 @@ def cumulative_sum(
x,
*,
axis: int | None = None,
dtype: ndx.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
include_initial: bool = False,
):
return NotImplemented
Expand All @@ -270,7 +276,7 @@ def prod(
x,
*,
axis=None,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
keepdims: bool = False,
) -> ndx.Array:
return NotImplemented
Expand All @@ -293,7 +299,7 @@ def sum(
x,
*,
axis=None,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
keepdims: bool = False,
) -> ndx.Array:
return NotImplemented
Expand All @@ -305,7 +311,7 @@ def var(
axis=None,
keepdims: bool = False,
correction=0.0,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
) -> ndx.Array:
return NotImplemented

Expand Down Expand Up @@ -352,7 +358,7 @@ def full_like(self, x, fill_value, dtype=None, device=None) -> ndx.Array:
def ones(
self,
shape,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
device=None,
):
return NotImplemented
Expand All @@ -365,14 +371,12 @@ def ones_like(
def zeros(
self,
shape,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
device=None,
):
return NotImplemented

def zeros_like(
self, x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None
):
def zeros_like(self, x, dtype: Dtype | None = None, device=None):
return NotImplemented

def empty(self, shape, dtype=None, device=None) -> ndx.Array:
Expand Down Expand Up @@ -413,3 +417,18 @@ def can_cast(self, from_, to) -> bool:

def static_shape(self, x) -> tuple[int | None, ...]:
return NotImplemented

def make_array(
self,
shape: tuple[int | None | str, ...],
dtype: Dtype,
eager_value: np.ndarray | None = None,
) -> ndx.Array:
return NotImplemented

def getitem(
self,
x: ndx.Array,
index: IndexType,
) -> ndx.Array:
return NotImplemented
27 changes: 24 additions & 3 deletions ndonnx/_core/_nullableimpl.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,37 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

from typing import TYPE_CHECKING, Union

import ndonnx as ndx

from ._interface import OperationsBlock
from ._shapeimpl import UniformShapeOperations
from ._utils import validate_core

if TYPE_CHECKING:
from ndonnx._array import Array
from ndonnx._data_types import CoreType, StructType

Dtype = Union[CoreType, StructType]

class NullableOperationsImpl(OperationsBlock):

class NullableOperationsImpl(UniformShapeOperations):
@validate_core
def fill_null(self, x, value):
def fill_null(self, x: Array, value) -> Array:
value = ndx.asarray(value)
if value.dtype != x.values.dtype:
value = value.astype(x.values.dtype)
return ndx.where(x.null, value, x.values)

@validate_core
def make_nullable(self, x: Array, null: Array) -> Array:
return NotImplemented

@validate_core
def where(self, condition, x, y):
if x.dtype != y.dtype:
target_dtype = ndx.result_type(x, y)
x = ndx.astype(x, target_dtype)
y = ndx.astype(y, target_dtype)
return super().where(condition, x, y)
Loading

0 comments on commit a8be313

Please sign in to comment.