Implement array-valued signatures #56
Comments
It's definitely a shortcoming, because the corresponding I would say that it doesn't render the library useless, though. Anyway, I took a quick shot at using |
do you have an example, perchance? |
modified   numba_scipy/special/overloads.py
@@ -10,7 +10,12 @@ def choose_kernel(name, all_signatures):
         for signature in all_signatures:
             if args == signature:
                 f = signatures.name_and_types_to_pointer[(name, *signature)]
-                return lambda *args: f(*args)
+
+                @numba.vectorize
+                def _f(*args):
+                    return f(*args)
+
+                return _f
     return choice_function
results in the following error:
E numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
E No implementation of function Function(<ufunc 'agm'>) found for signature:
E
E >>> agm(float64, float64)
E
E There are 2 candidate implementations:
E - Of which 2 did not match due to:
E Overload in function 'choose_kernel.<locals>.choice_function': File: ../code/python/numba-scipy/numba_scipy/special/overloads.py: Line 9.
E With argument(s): '(float64, float64)':
E Rejected as the implementation raised a specific error:
E AssertionError: Implementator function returned by `@overload` has an unexpected type. Got <numba._DUFunc '_f'>
E raised from ~/envs/numba-scipy-env/lib/python3.7/site-packages/numba/core/typing/templates.py:742
E
E During: resolving callee type: Function(<ufunc 'agm'>)
E During: typing of call at ~/code/python/numba-scipy/numba_scipy/tests/test_special.py (76)
E
E
E File "numba_scipy/tests/test_special.py", line 76:
E def numba_func(*args):
E return scipy_func(*args)
E ^ Is |
The varargs could also be a problem. |
I have a hack to get this working in my If anyone knows how to get past this varargs issue without creating functions in this fashion—or any other fundamentally AST-based approach—please tell me, it would really help with the work we're doing in Aesara, as well. |
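For readers unfamiliar with the kind of workaround being described, here is a minimal sketch of generating a fixed-arity function from source text so that no `*args` appears in the signature. It is an assumption about the general technique, not the code referred to in the comment above; `make_fixed_arity_wrapper` and `agm_like` are hypothetical names, and the sketch is plain Python with no Numba involved.

```python
import inspect


def make_fixed_arity_wrapper(func, n_args, name="wrapper"):
    """Build a function with exactly `n_args` positional parameters that forwards to `func`."""
    if n_args < 0:
        raise ValueError("n_args must be non-negative")
    # Spell out the parameters explicitly: arg0, arg1, ...
    arg_names = ", ".join(f"arg{i}" for i in range(n_args))
    source = f"def {name}({arg_names}):\n    return _func({arg_names})\n"
    # Execute the generated source in a namespace that exposes the target as `_func`.
    namespace = {"_func": func}
    exec(source, namespace)
    return namespace[name]


def agm_like(a, b):
    # Hypothetical two-argument scalar kernel standing in for a real special function.
    return 0.5 * (a + b)


wrapped = make_fixed_arity_wrapper(agm_like, 2)
print(inspect.signature(wrapped))  # the signature lists explicit parameters, no *args
print(wrapped(1.0, 3.0))
```

The generated function has an explicit signature, which is the property a decorator that cannot type varargs would need. The cost is exactly the one the comment complains about: the function is built from source text rather than declared directly.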
There's no public extension API in Numba for declaring this in a simple manner; this sort of thing could be a workaround:
from numba import njit, vectorize, types
from numba.extending import overload
import numpy as np
from scipy import special

x = np.linspace(-10, 10, 1000)

# this is just a dummy scalar function cf. those in numba-scipy's wrapper for
# scipy.special.*, now #54 is in the standard overload for scalar j0 should
# just work.
@njit
def pretend_j0_from_cython(x):
    return x + 12.34

@vectorize
def vectorize_j0(x):
    return pretend_j0_from_cython(x)

# This gets the vectorization mechanics but will end up "hiding" the NumPy ufunc
@overload(special.j0)
def ol_j0(x):
    # Accept both scalars and arrays; the vectorized function handles either.
    if isinstance(x, (types.Array, types.Number)):
        def impl(x):
            return vectorize_j0(x)
        return impl

@njit
def jitted_j0(x):
    res1 = special.j0(x[0])  # scalar call
    res2 = special.j0(x)     # array call
    return res1, res2

print(jitted_j0(x))
The issue I ran into above is the signature for the |
Ah, I see, I misinterpreted this as not being able to register an overload with Opened numba/numba#6954 to track. |
Thanks for that; it's a problem that shows up in at least a couple places where we're trying to use Numba as a backend (e.g. here). |
Hello, I have been able to use the workaround by @stuartarchibald. Is there any plan to add this, so that there is no need to write the vectorized version of every function? |
@PabloRdrRbl I think a PR has already been opened: #58 |
Is it possible to extend it to a function like |
As of #54 the simplest scalar calls to jitted special functions should work.
However there's no support yet for array-valued inputs:
This is not obviously a shortcoming, since looping in jitted functions should be alright. So this is just a mild suggestion to consider adding support for array-valued signatures. (This should probably be preceded by some benchmarks to see whether it would help performance at all.)
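The explicit-loop alternative mentioned above can be sketched roughly as follows. This is only an illustration under stated assumptions: `scalar_kernel` is a hypothetical stand-in for a scalar special function, not one of the numba-scipy overloads, and the `njit` fallback exists only so the snippet still runs where Numba is not installed.

```python
import math

import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for illustration: without Numba, run the same code as plain Python.
    def njit(func):
        return func


@njit
def scalar_kernel(x):
    # Hypothetical scalar function standing in for a scalar special-function call.
    return math.exp(-x * x)


@njit
def apply_elementwise(values):
    # Loop over a 1-D array and call the scalar kernel once per element.
    out = np.empty_like(values)
    for i in range(values.size):
        out[i] = scalar_kernel(values[i])
    return out


x = np.linspace(-1.0, 1.0, 5)
result = apply_elementwise(x)
print(result)
```

Whether a loop like this performs comparably to an array-valued signature is exactly what the suggested benchmarks would need to establish.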