Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define ufunc JO and JTO simultaneously #312

Open
wants to merge 39 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
7b18fe4
Start implementing def_ufunc_derivs
j-towns Oct 16, 2017
91e510e
Refactor def_ufunc_derivs
j-towns Oct 16, 2017
89cf08c
More ufuncs to new format
j-towns Oct 16, 2017
94747f5
Unary ufuncs new style
j-towns Oct 17, 2017
0e18a7e
Use broadcast_to not broadcast, also refactor unbroadcast.
j-towns Oct 18, 2017
b753db7
Ensure that deriv is evaluated during fwd pass
j-towns Oct 18, 2017
9015f0f
Merge branch 'dev-1.2' into fwd-rev-in-one-def
j-towns Oct 18, 2017
5104a90
Re-add accidentally deleted imports
j-towns Oct 18, 2017
b1e5e54
Re-add subval import
j-towns Oct 18, 2017
cd5662b
Rm unnecessary brackets
j-towns Oct 25, 2017
760618e
Merge branch 'dev-1.2' into fwd-rev-in-one-def
j-towns Oct 25, 2017
b41a84b
Merge branch 'fwd-rev-onedef-reducs' into fwd-rev-in-one-def
j-towns Oct 26, 2017
d305a45
Refactor - new numpy.util module for def_ufunc_jps
j-towns Oct 26, 2017
a86d0b6
Rename binary->nary ufunc jps
j-towns Oct 26, 2017
21ac8c8
stats.norm jps to new format
j-towns Oct 26, 2017
62c3e54
Add possibility for None jp to nary ufuncs
j-towns Oct 26, 2017
585f3cb
stats.t jps to new format (and psi, polygamma)
j-towns Oct 26, 2017
a29c770
Scipy special to new format
j-towns Oct 26, 2017
e4c719d
simplify logcdf grads
j-towns Oct 26, 2017
fad42ef
Merge branch 'master' into fwd-rev-in-one-def
j-towns Nov 2, 2017
57df4d8
Re-add missing scipy tests
j-towns Nov 2, 2017
8205a3d
hypot grad to new format
j-towns Nov 2, 2017
d2c737a
New scipy ufuncs to new format
j-towns Nov 2, 2017
1ba7015
Simplify def_ufunc_jps api
j-towns Nov 2, 2017
394e28b
Refactor def_ufunc_jps
j-towns Nov 2, 2017
9a00147
Merge branch 'master' into fwd-rev-in-one-def
j-towns Nov 2, 2017
486ecea
Add docstring to def_ufunc_jps
j-towns Nov 2, 2017
27e882b
New stats grads to new format
j-towns Nov 2, 2017
204baea
Add inverse pair helper
j-towns Nov 2, 2017
98bedda
Beta fns to new ufunc jp format
j-towns Nov 3, 2017
b1852af
Add tanh and 'add' benchmarks
j-towns Nov 6, 2017
cd5d24f
rm unnecessary match_complex from ufunc vjps
j-towns Nov 6, 2017
f6dfd73
fix numpy vjp benchmarks
j-towns Nov 6, 2017
70e0bd0
Merge branch 'master' into fwd-rev-in-one-def
j-towns Nov 13, 2017
087880b
Merge branch 'master' into fwd-rev-in-one-def
j-towns Nov 20, 2017
b5fa236
Fix arctan2 jps def
j-towns Nov 20, 2017
57eb1b4
Rm unused imports
j-towns Nov 20, 2017
93e5601
fix indentation numpy_wrapper.py
j-towns Nov 20, 2017
a9c0e45
Define derivs for scipy.special.rel_entr
j-towns Dec 8, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 109 additions & 68 deletions autograd/numpy/numpy_jvps.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,57 @@
from itertools import repeat
from . import numpy_wrapper as anp
from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero,
dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0,
tensordot_adjoint_1)
from autograd.extend import defjvp, defjvp_argnum, def_linear, vspace
from ..util import func
tensordot_adjoint_1, unbroadcast_f)
from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace,
defvjp)
from ..util import func, subval
from .numpy_boxes import ArrayBox


def def_ufunc_jps(ufunc, *derivs_ops):
derivs_ops = list(derivs_ops)

unary_ufunc_jps = {
'same': (lambda deriv: lambda g, ans, x: ufunc(g),
lambda deriv: lambda ans, x: ufunc),
'mul' : (lambda deriv: lambda g, ans, x: g * deriv(ans, x),
lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): g * d),
'div' : (lambda deriv: lambda g, ans, x: g / deriv(ans, x),
lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): g / d),
'cmul': (lambda deriv: lambda g, ans, x: match_complex(ans, g * deriv(ans, x)),
lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): match_complex(x, g * d)),
'cid': (lambda deriv: lambda g, ans, x: match_complex(ans, g),
lambda deriv: lambda ans, x: lambda g: match_complex(x , g))
}

if len(derivs_ops) == 1:
deriv, op = derivs_ops[0]
defjvp(ufunc, unary_ufunc_jps[op][0](deriv))
defvjp(ufunc, unary_ufunc_jps[op][1](deriv))

binary_ufunc_jps = {
'same': (lambda argnum, deriv: lambda g, ans, *args: ufunc(*subval(args, argnum, g)),
lambda argnum, deriv: lambda ans, *args:
unbroadcast_f(args[argnum], lambda g: ufunc(*subval(args, argnum, g)))),
'id': (lambda argnum, deriv: lambda g, ans, *args: broadcast(g, ans),
lambda argnum, deriv: lambda ans, *args:
unbroadcast_f(args[argnum], lambda g: g)),
'neg': (lambda argnum, deriv: lambda g, ans, *args: broadcast(-g, ans),
lambda argnum, deriv: lambda ans, *args:
unbroadcast_f(args[argnum], lambda g: -g)),
'mul': (lambda argnum, deriv: lambda g, ans, *args: g * deriv(ans, *args),
lambda argnum, deriv: lambda ans, *args:
unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g * d)),
Copy link
Collaborator Author

@j-towns j-towns Oct 18, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the vjps I've used this slightly weird d=deriv(ans, *args) default argument syntax to ensure that deriv is evaluated during the forward pass, allowing *args and ans to potentially be garbage collected.

Any objections? I could also have done this using a kind of helper closure to evaluate deriv, which would have been a bit more explicit.

'div': (lambda argnum, deriv: lambda g, ans, *args: g / deriv(ans, *args),
lambda argnum, deriv: lambda ans, *args:
unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g / d))
}
if len(derivs_ops) == 2:
defjvp(ufunc, *[binary_ufunc_jps[op][0](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)])
defvjp(ufunc, *[binary_ufunc_jps[op][1](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)])


defjvp(func(ArrayBox.__getitem__), 'same')
defjvp(untake, 'same')

Expand All @@ -16,43 +62,70 @@
# ----- Functions that are constant w.r.t. continuous inputs -----
defjvp(anp.nan_to_num, lambda g, ans, x: anp.where(anp.isfinite(x), g, 0.))

# ----- Binary ufuncs (linear) -----
def_linear(anp.multiply)
# ----- Unary ufuncs ------
def_ufunc_jps(anp.negative, (None, 'same'))
def_ufunc_jps(anp.rad2deg, (None, 'same'))
def_ufunc_jps(anp.degrees, (None, 'same'))
def_ufunc_jps(anp.deg2rad, (None, 'same'))
def_ufunc_jps(anp.radians, (None, 'same'))
def_ufunc_jps(anp.abs,
(lambda ans, x: replace_zero(anp.conj(x), 0.) / replace_zero(ans, 1.), 'cmul'))
def_ufunc_jps(anp.fabs, (lambda ans, x: anp.sign(x), 'mul')) # fabs doesn't take complex numbers.
def_ufunc_jps(anp.absolute, (lambda ans, x: anp.conj(x) / ans, 'cmul'))
def_ufunc_jps(anp.reciprocal, (lambda ans, x: -ans**2, 'mul' ))
def_ufunc_jps(anp.exp, (lambda ans, x: ans, 'mul' ))
def_ufunc_jps(anp.exp2, (lambda ans, x: ans * anp.log(2), 'mul' ))
def_ufunc_jps(anp.expm1, (lambda ans, x: (ans + 1), 'mul' ))
def_ufunc_jps(anp.log, (lambda ans, x: x, 'div' ))
def_ufunc_jps(anp.log2, (lambda ans, x: x * anp.log(2), 'div' ))
def_ufunc_jps(anp.log10, (lambda ans, x: x * anp.log(10), 'div' ))
def_ufunc_jps(anp.log1p, (lambda ans, x: x + 1, 'div' ))
def_ufunc_jps(anp.sin, (lambda ans, x: anp.cos(x), 'mul' ))
def_ufunc_jps(anp.cos, (lambda ans, x: -anp.sin(x), 'mul' ))
def_ufunc_jps(anp.tan, (lambda ans, x: 1 + ans**2, 'mul' ))
def_ufunc_jps(anp.arcsin, (lambda ans, x: anp.sqrt(1 - x**2), 'div' ))
def_ufunc_jps(anp.arccos, (lambda ans, x:-anp.sqrt(1 - x**2), 'div' ))
def_ufunc_jps(anp.arctan, (lambda ans, x: 1 + x**2, 'div' ))
def_ufunc_jps(anp.sinh, (lambda ans, x: anp.cosh(x), 'mul' ))
def_ufunc_jps(anp.cosh, (lambda ans, x: anp.sinh(x), 'mul' ))
def_ufunc_jps(anp.tanh, (lambda ans, x: 1 - ans**2, 'mul' ))
def_ufunc_jps(anp.arcsinh, (lambda ans, x: anp.sqrt(x**2 + 1), 'div' ))
def_ufunc_jps(anp.arccosh, (lambda ans, x: anp.sqrt(x**2 - 1), 'div' ))
def_ufunc_jps(anp.arctanh, (lambda ans, x: 1 - x**2, 'div' ))
def_ufunc_jps(anp.square, (lambda ans, x: 2 * x, 'mul' ))
def_ufunc_jps(anp.sqrt, (lambda ans, x: 2 * ans, 'div' ))
def_ufunc_jps(anp.sinc, (lambda ans, x: (anp.cos(anp.pi*x)-ans)/x, 'mul' ))
def_ufunc_jps(anp.real_if_close, (None, 'cid'))
def_ufunc_jps(anp.real, (None, 'cid'))
def_ufunc_jps(anp.imag, (lambda ans, x: -1j, 'cmul'))
def_ufunc_jps(anp.conj, (None, 'same'))
def_ufunc_jps(anp.conjugate, (None, 'same'))
def_ufunc_jps(anp.angle, (lambda ans, x: anp.conj(x * 1j)/anp.abs(x)**2, 'cmul'))

# ----- Binary ufuncs -----
defjvp(anp.add, lambda g, ans, x, y : broadcast(g, ans),
lambda g, ans, x, y : broadcast(g, ans))
defjvp(anp.subtract, lambda g, ans, x, y : broadcast(g, ans),
lambda g, ans, x, y : broadcast(-g, ans))
defjvp(anp.divide, 'same',
lambda g, ans, x, y : - g * x / y**2)
defjvp(anp.maximum, lambda g, ans, x, y : g * balanced_eq(x, ans, y),
lambda g, ans, x, y : g * balanced_eq(y, ans, x))
defjvp(anp.minimum, lambda g, ans, x, y : g * balanced_eq(x, ans, y),
lambda g, ans, x, y : g * balanced_eq(y, ans, x))
defjvp(anp.fmax, lambda g, ans, x, y : g * balanced_eq(x, ans, y),
lambda g, ans, x, y : g * balanced_eq(y, ans, x))
defjvp(anp.fmin, lambda g, ans, x, y : g * balanced_eq(x, ans, y),
lambda g, ans, x, y : g * balanced_eq(y, ans, x))
defjvp(anp.logaddexp, lambda g, ans, x, y : g * anp.exp(x-ans),
lambda g, ans, x, y : g * anp.exp(y-ans))
defjvp(anp.logaddexp2, lambda g, ans, x, y : g * 2**(x-ans),
lambda g, ans, x, y : g * 2**(y-ans))
defjvp(anp.true_divide,'same',
lambda g, ans, x, y : - g * x / y**2)
defjvp(anp.mod, lambda g, ans, x, y : broadcast(g, ans),
lambda g, ans, x, y : -g * anp.floor(x/y))
defjvp(anp.remainder, lambda g, ans, x, y : broadcast(g, ans),
lambda g, ans, x, y : -g * anp.floor(x/y))
defjvp(anp.power, lambda g, ans, x, y : g * y * x ** anp.where(y, y - 1, 1.),
lambda g, ans, x, y : g * anp.log(replace_zero(x, 1.)) * x ** y)
def_ufunc_jps(anp.add, *repeat((None, 'id'), 2))
def_ufunc_jps(anp.subtract, (None, 'id'), (None, 'neg'))
def_ufunc_jps(anp.multiply, *repeat((None, 'same'), 2))
def_ufunc_jps(anp.divide, (None, 'same'), (lambda ans, x, y: -ans/y, 'mul'))
def_ufunc_jps(anp.maximum, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'),
(lambda ans, x, y: balanced_eq(y, ans, x), 'mul'))
def_ufunc_jps(anp.minimum, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'),
(lambda ans, x, y: balanced_eq(y, ans, x), 'mul'))
def_ufunc_jps(anp.fmax, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'),
(lambda ans, x, y: balanced_eq(y, ans, x), 'mul'))
def_ufunc_jps(anp.fmin, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'),
(lambda ans, x, y: balanced_eq(y, ans, x), 'mul'))
def_ufunc_jps(anp.logaddexp, (lambda ans, x, y: anp.exp(x-ans), 'mul'),
(lambda ans, x, y: anp.exp(y-ans), 'mul'))
def_ufunc_jps(anp.logaddexp2, (lambda ans, x, y: 2**(x-ans), 'mul'),
(lambda ans, x, y: 2**(y-ans), 'mul'))
def_ufunc_jps(anp.true_divide, (None, 'same'), (lambda ans, x, y: -ans/y, 'mul'))
def_ufunc_jps(anp.mod, (None, 'id'), (lambda ans, x, y: -anp.floor(x/y), 'mul'))
def_ufunc_jps(anp.remainder, (None, 'id'), (lambda ans, x, y: -anp.floor(x/y), 'mul'))
def_ufunc_jps(anp.power, (lambda ans, x, y: y * x ** anp.where(y, y - 1, 1.), 'mul'),
(lambda ans, x, y: anp.log(replace_zero(x, 1.)) * x ** y, 'mul'))

# ----- Simple grads (linear) -----
defjvp(anp.negative, 'same')
defjvp(anp.rad2deg, 'same')
defjvp(anp.degrees, 'same')
defjvp(anp.deg2rad, 'same')
defjvp(anp.radians, 'same')
defjvp(anp.reshape, 'same')
defjvp(anp.roll, 'same')
defjvp(anp.array_split, 'same')
Expand All @@ -79,39 +152,7 @@
def_linear(anp.cross)

# ----- Simple grads -----
defjvp(anp.abs,
lambda g, ans, x : anp.real(g * replace_zero(anp.conj(x), 0.)) / replace_zero(ans, 1.))
defjvp(anp.fabs, lambda g, ans, x : anp.sign(x) * g) # fabs doesn't take complex numbers.
defjvp(anp.absolute, lambda g, ans, x : anp.real(g * anp.conj(x)) / ans)
defjvp(anp.reciprocal, lambda g, ans, x : - g / x**2)
defjvp(anp.exp, lambda g, ans, x : ans * g)
defjvp(anp.exp2, lambda g, ans, x : ans * anp.log(2) * g)
defjvp(anp.expm1, lambda g, ans, x : (ans + 1) * g)
defjvp(anp.log, lambda g, ans, x : g / x)
defjvp(anp.log2, lambda g, ans, x : g / x / anp.log(2))
defjvp(anp.log10, lambda g, ans, x : g / x / anp.log(10))
defjvp(anp.log1p, lambda g, ans, x : g / (x + 1))
defjvp(anp.sin, lambda g, ans, x : g * anp.cos(x))
defjvp(anp.cos, lambda g, ans, x : - g * anp.sin(x))
defjvp(anp.tan, lambda g, ans, x : g / anp.cos(x) **2)
defjvp(anp.arcsin, lambda g, ans, x : g / anp.sqrt(1 - x**2))
defjvp(anp.arccos, lambda g, ans, x :-g / anp.sqrt(1 - x**2))
defjvp(anp.arctan, lambda g, ans, x : g / (1 + x**2))
defjvp(anp.sinh, lambda g, ans, x : g * anp.cosh(x))
defjvp(anp.cosh, lambda g, ans, x : g * anp.sinh(x))
defjvp(anp.tanh, lambda g, ans, x : g / anp.cosh(x) **2)
defjvp(anp.arcsinh, lambda g, ans, x : g / anp.sqrt(x**2 + 1))
defjvp(anp.arccosh, lambda g, ans, x : g / anp.sqrt(x**2 - 1))
defjvp(anp.arctanh, lambda g, ans, x : g / (1 - x**2))
defjvp(anp.square, lambda g, ans, x : g * 2 * x)
defjvp(anp.sqrt, lambda g, ans, x : g * 0.5 * x**-0.5)
defjvp(anp.sinc, lambda g, ans, x : g * (anp.cos(anp.pi*x)*anp.pi*x - anp.sin(anp.pi*x))/(anp.pi*x**2))
defjvp(anp.clip, lambda g, ans, x, a_min, a_max : g * anp.logical_and(ans != a_min, ans != a_max))
defjvp(anp.real_if_close, lambda g, ans, x : match_complex(ans, g))
defjvp(anp.real, lambda g, ans, x : anp.real(g))
defjvp(anp.imag, lambda g, ans, x : match_complex(ans, -1j * g))
defjvp(anp.conj, lambda g, ans, x : anp.conj(g))
defjvp(anp.angle, lambda g, ans, x : match_complex(ans, g * anp.conj(x * 1j) / anp.abs(x)**2))
defjvp(anp.where, None,
lambda g, ans, c, x=None, y=None : anp.where(c, g, anp.zeros(anp.shape(g))),
lambda g, ans, c, x=None, y=None : anp.where(c, anp.zeros(g.shape), g))
Expand Down
70 changes: 1 addition & 69 deletions autograd/numpy/numpy_vjps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,76 +7,14 @@
from .numpy_boxes import ArrayBox
from autograd.extend import primitive, vspace, defvjp, defvjp_argnum, SparseObject


# ----- Functions that are constant w.r.t. continuous inputs -----

defvjp(anp.nan_to_num, lambda ans, x: lambda g: anp.where(anp.isfinite(x), g, 0.))

# ----- Binary ufuncs -----

defvjp(anp.add, lambda ans, x, y : unbroadcast_f(x, lambda g: g),
lambda ans, x, y : unbroadcast_f(y, lambda g: g))
defvjp(anp.multiply, lambda ans, x, y : unbroadcast_f(x, lambda g: y * g),
lambda ans, x, y : unbroadcast_f(y, lambda g: x * g))
defvjp(anp.subtract, lambda ans, x, y : unbroadcast_f(x, lambda g: g),
lambda ans, x, y : unbroadcast_f(y, lambda g: -g))
defvjp(anp.divide, lambda ans, x, y : unbroadcast_f(x, lambda g: g / y),
lambda ans, x, y : unbroadcast_f(y, lambda g: - g * x / y**2))
defvjp(anp.maximum, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)))
defvjp(anp.minimum, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)))
defvjp(anp.fmax, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)))
defvjp(anp.fmin, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)))
defvjp(anp.logaddexp, lambda ans, x, y : unbroadcast_f(x, lambda g: g * anp.exp(x-ans)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * anp.exp(y-ans)))
defvjp(anp.logaddexp2, lambda ans, x, y : unbroadcast_f(x, lambda g: g * 2**(x-ans)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * 2**(y-ans)))
defvjp(anp.true_divide, lambda ans, x, y : unbroadcast_f(x, lambda g: g / y),
lambda ans, x, y : unbroadcast_f(y, lambda g: - g * x / y**2))
defvjp(anp.mod, lambda ans, x, y : unbroadcast_f(x, lambda g: g),
lambda ans, x, y : unbroadcast_f(y, lambda g: -g * anp.floor(x/y)))
defvjp(anp.remainder, lambda ans, x, y : unbroadcast_f(x, lambda g: g),
lambda ans, x, y : unbroadcast_f(y, lambda g: -g * anp.floor(x/y)))
defvjp(anp.power,
lambda ans, x, y : unbroadcast_f(x, lambda g: g * y * x ** anp.where(y, y - 1, 1.)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * anp.log(replace_zero(x, 1.)) * x ** y))

# ----- Simple grads -----

defvjp(anp.negative, lambda ans, x: lambda g: -g)
defvjp(anp.abs,
lambda ans, x : lambda g: g * replace_zero(anp.conj(x), 0.) / replace_zero(ans, 1.))
defvjp(anp.fabs, lambda ans, x : lambda g: anp.sign(x) * g) # fabs doesn't take complex numbers.
defvjp(anp.absolute, lambda ans, x : lambda g: g * anp.conj(x) / ans)
defvjp(anp.reciprocal, lambda ans, x : lambda g: - g / x**2)
defvjp(anp.exp, lambda ans, x : lambda g: ans * g)
defvjp(anp.exp2, lambda ans, x : lambda g: ans * anp.log(2) * g)
defvjp(anp.expm1, lambda ans, x : lambda g: (ans + 1) * g)
defvjp(anp.log, lambda ans, x : lambda g: g / x)
defvjp(anp.log2, lambda ans, x : lambda g: g / x / anp.log(2))
defvjp(anp.log10, lambda ans, x : lambda g: g / x / anp.log(10))
defvjp(anp.log1p, lambda ans, x : lambda g: g / (x + 1))
defvjp(anp.sin, lambda ans, x : lambda g: g * anp.cos(x))
defvjp(anp.cos, lambda ans, x : lambda g: - g * anp.sin(x))
defvjp(anp.tan, lambda ans, x : lambda g: g / anp.cos(x) **2)
defvjp(anp.arcsin, lambda ans, x : lambda g: g / anp.sqrt(1 - x**2))
defvjp(anp.arccos, lambda ans, x : lambda g:-g / anp.sqrt(1 - x**2))
defvjp(anp.arctan, lambda ans, x : lambda g: g / (1 + x**2))
defvjp(anp.sinh, lambda ans, x : lambda g: g * anp.cosh(x))
defvjp(anp.cosh, lambda ans, x : lambda g: g * anp.sinh(x))
defvjp(anp.tanh, lambda ans, x : lambda g: g / anp.cosh(x) **2)
defvjp(anp.arcsinh, lambda ans, x : lambda g: g / anp.sqrt(x**2 + 1))
defvjp(anp.arccosh, lambda ans, x : lambda g: g / anp.sqrt(x**2 - 1))
defvjp(anp.arctanh, lambda ans, x : lambda g: g / (1 - x**2))
defvjp(anp.rad2deg, lambda ans, x : lambda g: g / anp.pi * 180.0)
defvjp(anp.degrees, lambda ans, x : lambda g: g / anp.pi * 180.0)
defvjp(anp.deg2rad, lambda ans, x : lambda g: g * anp.pi / 180.0)
defvjp(anp.radians, lambda ans, x : lambda g: g * anp.pi / 180.0)
defvjp(anp.square, lambda ans, x : lambda g: g * 2 * x)
defvjp(anp.sqrt, lambda ans, x : lambda g: g * 0.5 * x**-0.5)
defvjp(anp.sinc, lambda ans, x : lambda g: g * (anp.cos(anp.pi*x)*anp.pi*x - anp.sin(anp.pi*x))/(anp.pi*x**2))
defvjp(anp.reshape, lambda ans, x, shape, order=None : lambda g: anp.reshape(g, anp.shape(x), order=order))
defvjp(anp.roll, lambda ans, x, shift, axis=None : lambda g: anp.roll(g, -shift, axis=axis))
defvjp(anp.array_split, lambda ans, ary, idxs, axis=0 : lambda g: anp.concatenate(g, axis=axis))
Expand All @@ -102,12 +40,6 @@
anp.moveaxis(g, destination, source))
defvjp(anp.rollaxis, lambda ans, a, axis, start=0: lambda g: anp.rollaxis(g, start - 1, axis) if start > axis
else anp.rollaxis(g, start, axis + 1))
defvjp(anp.real_if_close, lambda ans, x : lambda g: match_complex(x, g))
defvjp(anp.real, lambda ans, x : lambda g: match_complex(x, g))
defvjp(anp.imag, lambda ans, x : lambda g: match_complex(x, -1j * g))
defvjp(anp.conj, lambda ans, x : lambda g: anp.conj(g))
defvjp(anp.conjugate, lambda ans, x: lambda g: anp.conj(g))
defvjp(anp.angle, lambda ans, x : lambda g: match_complex(x, g * anp.conj(x * 1j) / anp.abs(x)**2))
defvjp(anp.where, None,
lambda ans, c, x=None, y=None : lambda g: anp.where(c, g, anp.zeros(g.shape)),
lambda ans, c, x=None, y=None : lambda g: anp.where(c, anp.zeros(g.shape), g))
Expand Down