Skip to content

Commit

Permalink
Merge pull request #96 from ngoldbaum/pyint-promoter
Browse files Browse the repository at this point in the history
add a promoter for multiplying with a python int
  • Loading branch information
ngoldbaum authored Nov 28, 2023
2 parents 01b2245 + 35948ca commit 1f2c42e
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 106 deletions.
184 changes: 84 additions & 100 deletions stringdtype/stringdtype/src/umath.c
Original file line number Diff line number Diff line change
Expand Up @@ -1081,44 +1081,12 @@ string_isnan_resolve_descriptors(
* Copied from NumPy, because NumPy doesn't always use it :)
*/
static int
ufunc_promoter_internal(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
PyArray_DTypeMeta *signature[],
PyArray_DTypeMeta *new_op_dtypes[],
PyArray_DTypeMeta *final_dtype)
string_inputs_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
PyArray_DTypeMeta *signature[],
PyArray_DTypeMeta *new_op_dtypes[],
PyArray_DTypeMeta *final_dtype)
{
/* If nin < 2 promotion is a no-op, so it should not be registered */
assert(ufunc->nin > 1);
if (op_dtypes[0] == NULL) {
assert(ufunc->nin == 2 && ufunc->nout == 1); /* must be reduction */
Py_INCREF(op_dtypes[1]);
new_op_dtypes[0] = op_dtypes[1];
Py_INCREF(op_dtypes[1]);
new_op_dtypes[1] = op_dtypes[1];
Py_INCREF(op_dtypes[1]);
new_op_dtypes[2] = op_dtypes[1];
return 0;
}
PyArray_DTypeMeta *common = NULL;
/*
* If a signature is used and homogeneous in its outputs use that
* (Could/should likely be rather applied to inputs also, although outs
* only could have some advantage and input dtypes are rarely enforced.)
*/
for (int i = ufunc->nin; i < ufunc->nargs; i++) {
if (signature[i] != NULL) {
if (common == NULL) {
Py_INCREF(signature[i]);
common = signature[i];
}
else if (common != signature[i]) {
Py_CLEAR(common); /* Not homogeneous, unset common */
break;
}
}
}
Py_XDECREF(common);

/* Otherwise, set all input operands to final_dtype */
/* set all input operands to final_dtype */
for (int i = 0; i < ufunc->nargs; i++) {
PyArray_DTypeMeta *tmp = final_dtype;
if (signature[i]) {
Expand All @@ -1127,6 +1095,7 @@ ufunc_promoter_internal(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
Py_INCREF(tmp);
new_op_dtypes[i] = tmp;
}
/* don't touch output dtypes */
for (int i = ufunc->nin; i < ufunc->nargs; i++) {
Py_XINCREF(op_dtypes[i]);
new_op_dtypes[i] = op_dtypes[i];
Expand All @@ -1140,19 +1109,50 @@ string_object_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
PyArray_DTypeMeta *signature[],
PyArray_DTypeMeta *new_op_dtypes[])
{
return ufunc_promoter_internal((PyUFuncObject *)ufunc, op_dtypes,
signature, new_op_dtypes,
(PyArray_DTypeMeta *)&PyArray_ObjectDType);
return string_inputs_promoter((PyUFuncObject *)ufunc, op_dtypes, signature,
new_op_dtypes,
(PyArray_DTypeMeta *)&PyArray_ObjectDType);
}

static int
string_unicode_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
PyArray_DTypeMeta *signature[],
PyArray_DTypeMeta *new_op_dtypes[])
{
return ufunc_promoter_internal((PyUFuncObject *)ufunc, op_dtypes,
signature, new_op_dtypes,
(PyArray_DTypeMeta *)&StringDType);
return string_inputs_promoter((PyUFuncObject *)ufunc, op_dtypes, signature,
new_op_dtypes,
(PyArray_DTypeMeta *)&StringDType);
}

static int
string_multiply_promoter(PyObject *ufunc_obj, PyArray_DTypeMeta *op_dtypes[],
PyArray_DTypeMeta *signature[],
PyArray_DTypeMeta *new_op_dtypes[])
{
PyUFuncObject *ufunc = (PyUFuncObject *)ufunc_obj;
for (int i = 0; i < ufunc->nargs; i++) {
PyArray_DTypeMeta *tmp = NULL;
if (signature[i]) {
tmp = signature[i];
}
else if (op_dtypes[i] == &PyArray_PyIntAbstractDType) {
tmp = &PyArray_Int64DType;
}
else if (op_dtypes[i]) {
tmp = op_dtypes[i];
}
else {
tmp = (PyArray_DTypeMeta *)&StringDType;
}
Py_INCREF(tmp);
new_op_dtypes[i] = tmp;
}
/* don't touch output dtypes */
for (int i = ufunc->nin; i < ufunc->nargs; i++) {
Py_XINCREF(op_dtypes[i]);
new_op_dtypes[i] = op_dtypes[i];
}
return 0;
}

// Register a ufunc.
Expand All @@ -1161,14 +1161,18 @@ string_unicode_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
int
init_ufunc(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta **dtypes,
resolve_descriptors_function *resolve_func,
PyArrayMethod_StridedLoop *loop_func, const char *loop_name,
int nin, int nout, NPY_CASTING casting, NPY_ARRAYMETHOD_FLAGS flags)
PyArrayMethod_StridedLoop *loop_func, int nin, int nout,
NPY_CASTING casting, NPY_ARRAYMETHOD_FLAGS flags)
{
PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
if (ufunc == NULL) {
return -1;
}

char loop_name[256] = {0};

snprintf(loop_name, sizeof(loop_name), "string_%s", ufunc_name);

PyArrayMethod_Spec spec = {
.name = loop_name,
.nin = nin,
Expand Down Expand Up @@ -1208,7 +1212,7 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
PyArray_DTypeMeta *ldtype, PyArray_DTypeMeta *rdtype,
PyArray_DTypeMeta *edtype, promoter_function *promoter_impl)
{
PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
PyObject *ufunc = PyObject_GetAttrString((PyObject *)numpy, ufunc_name);

if (ufunc == NULL) {
return -1;
Expand Down Expand Up @@ -1251,8 +1255,8 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
\
if (init_ufunc(numpy, "multiply", multiply_right_##shortname##_types, \
&multiply_resolve_descriptors, \
&multiply_right_##shortname##_strided_loop, \
"string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \
&multiply_right_##shortname##_strided_loop, 2, 1, \
NPY_NO_CASTING, 0) < 0) { \
goto error; \
} \
\
Expand All @@ -1262,8 +1266,8 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
\
if (init_ufunc(numpy, "multiply", multiply_left_##shortname##_types, \
&multiply_resolve_descriptors, \
&multiply_left_##shortname##_strided_loop, \
"string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \
&multiply_left_##shortname##_strided_loop, 2, 1, \
NPY_NO_CASTING, 0) < 0) { \
goto error; \
}

Expand All @@ -1279,53 +1283,23 @@ init_ufuncs(void)
"greater", "greater_equal",
"less", "less_equal"};

static PyArrayMethod_StridedLoop *strided_loops[6] = {
&string_equal_strided_loop, &string_not_equal_strided_loop,
&string_greater_strided_loop, &string_greater_equal_strided_loop,
&string_less_strided_loop, &string_less_equal_strided_loop,
};

PyArray_DTypeMeta *comparison_dtypes[] = {
(PyArray_DTypeMeta *)&StringDType,
(PyArray_DTypeMeta *)&StringDType, &PyArray_BoolDType};

if (init_ufunc(numpy, "equal", comparison_dtypes,
&string_comparison_resolve_descriptors,
&string_equal_strided_loop, "string_equal", 2, 1,
NPY_NO_CASTING, 0) < 0) {
goto error;
}

if (init_ufunc(numpy, "not_equal", comparison_dtypes,
&string_comparison_resolve_descriptors,
&string_not_equal_strided_loop, "string_not_equal", 2, 1,
NPY_NO_CASTING, 0) < 0) {
goto error;
}

if (init_ufunc(numpy, "greater", comparison_dtypes,
&string_comparison_resolve_descriptors,
&string_greater_strided_loop, "string_greater", 2, 1,
NPY_NO_CASTING, 0) < 0) {
goto error;
}

if (init_ufunc(numpy, "greater_equal", comparison_dtypes,
&string_comparison_resolve_descriptors,
&string_greater_equal_strided_loop, "string_greater_equal",
2, 1, NPY_NO_CASTING, 0) < 0) {
goto error;
}

if (init_ufunc(numpy, "less", comparison_dtypes,
&string_comparison_resolve_descriptors,
&string_less_strided_loop, "string_less", 2, 1,
NPY_NO_CASTING, 0) < 0) {
goto error;
}

if (init_ufunc(numpy, "less_equal", comparison_dtypes,
&string_comparison_resolve_descriptors,
&string_less_equal_strided_loop, "string_less_equal", 2, 1,
NPY_NO_CASTING, 0) < 0) {
goto error;
}

for (int i = 0; i < 6; i++) {
if (init_ufunc(numpy, comparison_ufunc_names[i], comparison_dtypes,
&string_comparison_resolve_descriptors,
strided_loops[i], 2, 1, NPY_NO_CASTING, 0) < 0) {
goto error;
}

if (add_promoter(numpy, comparison_ufunc_names[i],
(PyArray_DTypeMeta *)&StringDType,
&PyArray_UnicodeDType, &PyArray_BoolDType,
Expand Down Expand Up @@ -1360,8 +1334,7 @@ init_ufuncs(void)

if (init_ufunc(numpy, "isnan", isnan_dtypes,
&string_isnan_resolve_descriptors,
&string_isnan_strided_loop, "string_isnan", 1, 1,
NPY_NO_CASTING, 0) < 0) {
&string_isnan_strided_loop, 1, 1, NPY_NO_CASTING, 0) < 0) {
goto error;
}

Expand All @@ -1372,20 +1345,17 @@ init_ufuncs(void)
};

if (init_ufunc(numpy, "maximum", binary_dtypes, binary_resolve_descriptors,
&maximum_strided_loop, "string_maximum", 2, 1,
NPY_NO_CASTING, 0) < 0) {
&maximum_strided_loop, 2, 1, NPY_NO_CASTING, 0) < 0) {
goto error;
}

if (init_ufunc(numpy, "minimum", binary_dtypes, binary_resolve_descriptors,
&minimum_strided_loop, "string_minimum", 2, 1,
NPY_NO_CASTING, 0) < 0) {
&minimum_strided_loop, 2, 1, NPY_NO_CASTING, 0) < 0) {
goto error;
}

if (init_ufunc(numpy, "add", binary_dtypes, binary_resolve_descriptors,
&add_strided_loop, "string_add", 2, 1, NPY_NO_CASTING,
0) < 0) {
&add_strided_loop, 2, 1, NPY_NO_CASTING, 0) < 0) {
goto error;
}

Expand Down Expand Up @@ -1414,6 +1384,20 @@ init_ufuncs(void)
INIT_MULTIPLY(ULongLong, ulonglong);
#endif

if (add_promoter(numpy, "multiply", (PyArray_DTypeMeta *)&StringDType,
&PyArray_PyIntAbstractDType,
(PyArray_DTypeMeta *)&StringDType,
string_multiply_promoter) < 0) {
goto error;
}

if (add_promoter(numpy, "multiply", &PyArray_PyIntAbstractDType,
(PyArray_DTypeMeta *)&StringDType,
(PyArray_DTypeMeta *)&StringDType,
string_multiply_promoter) < 0) {
goto error;
}

Py_DECREF(numpy);
return 0;

Expand Down
18 changes: 12 additions & 6 deletions stringdtype/tests/test_stringdtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,7 @@ def test_ufunc_add(dtype, string_list, other_strings, use_out):
@pytest.mark.parametrize(
"other_dtype",
[
None,
"int8",
"int16",
"int32",
Expand All @@ -666,13 +667,17 @@ def test_ufunc_add(dtype, string_list, other_strings, use_out):
def test_ufunc_multiply(dtype, string_list, other, other_dtype, use_out):
"""Test the two-argument ufuncs match python builtin behavior."""
arr = np.array(string_list, dtype=dtype)
other_dtype = np.dtype(other_dtype)
if other_dtype is not None:
other_dtype = np.dtype(other_dtype)
try:
len(other)
result = [s * o for s, o in zip(string_list, other)]
other = np.array(other, dtype=other_dtype)
other = np.array(other)
if other_dtype is not None:
other = other.astype(other_dtype)
except TypeError:
other = other_dtype.type(other)
if other_dtype is not None:
other = other_dtype.type(other)
result = [s * other for s in string_list]

if use_out:
Expand Down Expand Up @@ -702,7 +707,9 @@ def test_ufunc_multiply(dtype, string_list, other, other_dtype, use_out):

try:
len(other)
other = np.append(other, 3).astype(other_dtype)
other = np.append(other, 3)
if other_dtype is not None:
other = other.astype(other_dtype)
except TypeError:
pass

Expand All @@ -714,7 +721,7 @@ def test_ufunc_multiply(dtype, string_list, other, other_dtype, use_out):
else:
try:
assert res[-1] == dtype.na_object * other[-1]
except IndexError:
except (IndexError, TypeError):
assert res[-1] == dtype.na_object * other
else:
with pytest.raises(TypeError):
Expand Down Expand Up @@ -776,7 +783,6 @@ def test_null_roundtripping(dtype):
assert data[1] == arr[1]


@pytest.mark.xfail(strict=True)
def test_string_too_large_error():
arr = np.array(["a", "b", "c"], dtype=StringDType())
with pytest.raises(MemoryError):
Expand Down

0 comments on commit 1f2c42e

Please sign in to comment.