-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added wrapper for simple unary delayed operations. (#3)
This includes the usual log, trig and rounding functions; we do a test run in the constructor to automatically determine the data type of the output array, so that we can correctly report it as the operation's dtype.
- Loading branch information
Showing
3 changed files
with
182 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
from typing import Literal, Tuple, Union | ||
|
||
import numpy | ||
|
||
from .interface import extract_dense_array, extract_sparse_array, is_sparse | ||
from .SparseNdarray import SparseNdarray | ||
from .utils import sanitize_indices, sanitize_single_index | ||
|
||
__author__ = "ltla" | ||
__copyright__ = "ltla" | ||
__license__ = "MIT" | ||
|
||
OP = Literal[ | ||
"log1p", "log2", "log10", | ||
"exp", "expm1", | ||
"sqrt", "abs", | ||
"sin", "cos", "tan", | ||
"sinh", "cosh", "tanh", | ||
"arcsin", "arccos", "arctan", | ||
"arcsinh", "arccosh", "arctanh", | ||
"ceil", "floor", "trunc", | ||
"sign" | ||
] | ||
|
||
def _choose_operator(op: OP): | ||
return getattr(numpy, op) | ||
|
||
class UnaryIsometricOpSimple: | ||
"""Unary isometric operation involving an n-dimensional seed array with no additional arguments. | ||
Attributes: | ||
seed: | ||
An array-like object. | ||
op (OP): | ||
String specifying the unary operation. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
seed, | ||
op: OP | ||
): | ||
f = _choose_operator(op) | ||
dummy = f(numpy.zeros(1, dtype=seed.dtype)) | ||
|
||
self._seed = seed | ||
self._op = op | ||
self._preserves_sparse = (dummy[0] == 0) | ||
self._dtype = dummy.dtype | ||
|
||
@property | ||
def shape(self) -> Tuple[int, ...]: | ||
return self._seed.shape | ||
|
||
@property | ||
def dtype(self) -> numpy.dtype: | ||
return self._dtype | ||
|
||
|
||
@is_sparse.register | ||
def _is_sparse_UnaryIsometricOpSimple(x: UnaryIsometricOpSimple) -> bool: | ||
return x._preserves_sparse and is_sparse(x._seed) | ||
|
||
|
||
@extract_dense_array.register | ||
def _extract_dense_array_UnaryIsometricOpSimple( | ||
x: UnaryIsometricOpSimple, idx | ||
) -> numpy.ndarray: | ||
base = extract_dense_array(x._seed, idx) | ||
opfun = _choose_operator(x._op) | ||
return opfun(base).astype(x._dtype, copy=False) | ||
|
||
|
||
def _recursive_apply_op_with_arg_to_sparse_array(contents, at, ndim, op): | ||
if len(at) == ndim - 2: | ||
for i in range(len(contents)): | ||
if contents[i] is not None: | ||
idx, val = contents[i] | ||
contents[i] = (idx, op(idx, val, (*at, i))) | ||
else: | ||
for i in range(len(contents)): | ||
if contents[i] is not None: | ||
_recursive_apply_op_with_arg_to_sparse_array( | ||
contents[i], (*at, i), ndim, op | ||
) | ||
|
||
|
||
@extract_sparse_array.register | ||
def _extract_sparse_array_UnaryIsometricOpSimple( | ||
x: UnaryIsometricOpSimple, idx | ||
) -> SparseNdarray: | ||
sparse = extract_sparse_array(x._seed, idx) | ||
|
||
opfun = _choose_operator(x._op) | ||
def execute(indices, values, at): | ||
return opfun(values) | ||
|
||
if isinstance(sparse._contents, list): | ||
_recursive_apply_op_with_arg_to_sparse_array( | ||
sparse._contents, (), len(sparse.shape), execute | ||
) | ||
elif sparse._contents is not None: | ||
idx, val = sparse._contents | ||
sparse._contents = (idx, execute(idx, val, ())) | ||
|
||
sparse._dtype = x._dtype | ||
return sparse |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import warnings | ||
|
||
import delayedarray | ||
import numpy | ||
import pytest | ||
from utils import * | ||
|
||
|
||
def test_UnaryIsometricOpSimple_dense(): | ||
test_shape = (10, 15, 20) | ||
y = numpy.random.rand(*test_shape) | ||
full_index = (slice(None), slice(None), slice(None)) | ||
|
||
op = delayedarray.UnaryIsometricOpSimple(y, "exp") | ||
assert not delayedarray.is_sparse(op) | ||
assert (delayedarray.extract_dense_array(op, full_index) == numpy.exp(y)).all() | ||
|
||
contents = mock_SparseNdarray_contents(test_shape) | ||
ys = delayedarray.SparseNdarray(test_shape, contents) | ||
ops = delayedarray.UnaryIsometricOpSimple(ys, "exp") | ||
assert not delayedarray.is_sparse(ops) | ||
assert (delayedarray.extract_dense_array(ops, full_index) == numpy.exp(delayedarray.extract_dense_array(ys, full_index))).all() | ||
|
||
# Works with a slice. | ||
sub_index = (slice(1, 9), slice(2, 14), slice(0, 20, 2)) | ||
assert (delayedarray.extract_dense_array(op, sub_index) == numpy.exp(y[(..., *sub_index)])).all() | ||
assert (delayedarray.extract_dense_array(ops, sub_index) == numpy.exp(delayedarray.extract_dense_array(ys, sub_index))).all() | ||
|
||
|
||
def test_UnaryIsometricOpSimple_sparse(): | ||
test_shape = (50, 20) | ||
y = numpy.random.rand(*test_shape) | ||
full_index = (slice(None), slice(None)) | ||
|
||
op = delayedarray.UnaryIsometricOpSimple(y, "expm1") | ||
assert not delayedarray.is_sparse(op) | ||
assert (delayedarray.extract_dense_array(op, full_index) == numpy.expm1(y)).all() | ||
|
||
contents = mock_SparseNdarray_contents(test_shape) | ||
ys = delayedarray.SparseNdarray(test_shape, contents) | ||
ops = delayedarray.UnaryIsometricOpSimple(ys, "abs") | ||
assert delayedarray.is_sparse(ops) | ||
assert (delayedarray.extract_dense_array(ops, full_index) == numpy.abs(delayedarray.extract_dense_array(ys, full_index))).all() | ||
|
||
# Works with a slice. | ||
sub_index = (slice(10, 40, 3), slice(2, 18)) | ||
assert (delayedarray.extract_dense_array(op, sub_index) == numpy.expm1(y[(..., *sub_index)])).all() | ||
assert (delayedarray.extract_dense_array(ops, sub_index) == numpy.abs(delayedarray.extract_dense_array(ys, sub_index))).all() | ||
|
||
|
||
def test_UnaryIsometricOpSimple_int_promotion(): | ||
test_shape = (20, 10) | ||
contents = mock_SparseNdarray_contents(test_shape, density1=0) | ||
for i in range(len(contents)): | ||
if contents[i] is not None: | ||
contents[i] = (contents[i][0], (contents[i][1]*10).astype(numpy.int32)) | ||
|
||
y = delayedarray.SparseNdarray(test_shape, contents) | ||
assert y.dtype == numpy.int32 | ||
full_index = (slice(None), slice(None)) | ||
|
||
op = delayedarray.UnaryIsometricOpSimple(y, "sin") | ||
assert delayedarray.is_sparse(op) | ||
assert op.dtype == numpy.float64 # correctly promoted | ||
|
||
out = delayedarray.extract_dense_array(op, full_index) | ||
assert out.dtype == numpy.float64 | ||
ref = numpy.sin(delayedarray.extract_dense_array(y, full_index)) | ||
assert (out == ref).all() | ||
|
||
spout = delayedarray.extract_sparse_array(op, full_index) | ||
assert spout.dtype == numpy.float64 | ||
assert (delayedarray.extract_dense_array(spout, full_index) == ref).all() |