Define ufunc JVPs and VJPs simultaneously #312

Open

wants to merge 39 commits into master

Commits (39)
7b18fe4
Start implementing def_ufunc_derivs
j-towns Oct 16, 2017
91e510e
Refactor def_ufunc_derivs
j-towns Oct 16, 2017
89cf08c
More ufuncs to new format
j-towns Oct 16, 2017
94747f5
Unary ufuncs new style
j-towns Oct 17, 2017
0e18a7e
Use broadcast_to not broadcast, also refactor unbroadcast.
j-towns Oct 18, 2017
b753db7
Ensure that deriv is evaluated during fwd pass
j-towns Oct 18, 2017
9015f0f
Merge branch 'dev-1.2' into fwd-rev-in-one-def
j-towns Oct 18, 2017
5104a90
Re-add accidentally deleted imports
j-towns Oct 18, 2017
b1e5e54
Re-add subval import
j-towns Oct 18, 2017
cd5662b
Rm unnecessary brackets
j-towns Oct 25, 2017
760618e
Merge branch 'dev-1.2' into fwd-rev-in-one-def
j-towns Oct 25, 2017
b41a84b
Merge branch 'fwd-rev-onedef-reducs' into fwd-rev-in-one-def
j-towns Oct 26, 2017
d305a45
Refactor - new numpy.util module for def_ufunc_jps
j-towns Oct 26, 2017
a86d0b6
Rename binary->nary ufunc jps
j-towns Oct 26, 2017
21ac8c8
stats.norm jps to new format
j-towns Oct 26, 2017
62c3e54
Add possibility for None jp to nary ufuncs
j-towns Oct 26, 2017
585f3cb
stats.t jps to new format (and psi, polygamma)
j-towns Oct 26, 2017
a29c770
Scipy special to new format
j-towns Oct 26, 2017
e4c719d
simplify logcdf grads
j-towns Oct 26, 2017
fad42ef
Merge branch 'master' into fwd-rev-in-one-def
j-towns Nov 2, 2017
57df4d8
Re-add missing scipy tests
j-towns Nov 2, 2017
8205a3d
hypot grad to new format
j-towns Nov 2, 2017
d2c737a
New scipy ufuncs to new format
j-towns Nov 2, 2017
1ba7015
Simplify def_ufunc_jps api
j-towns Nov 2, 2017
394e28b
Refactor def_ufunc_jps
j-towns Nov 2, 2017
9a00147
Merge branch 'master' into fwd-rev-in-one-def
j-towns Nov 2, 2017
486ecea
Add docstring to def_ufunc_jps
j-towns Nov 2, 2017
27e882b
New stats grads to new format
j-towns Nov 2, 2017
204baea
Add inverse pair helper
j-towns Nov 2, 2017
98bedda
Beta fns to new ufunc jp format
j-towns Nov 3, 2017
b1852af
Add tanh and 'add' benchmarks
j-towns Nov 6, 2017
cd5d24f
rm unnecessary match_complex from ufunc vjps
j-towns Nov 6, 2017
f6dfd73
fix numpy vjp benchmarks
j-towns Nov 6, 2017
70e0bd0
Merge branch 'master' into fwd-rev-in-one-def
j-towns Nov 13, 2017
087880b
Merge branch 'master' into fwd-rev-in-one-def
j-towns Nov 20, 2017
b5fa236
Fix arctan2 jps def
j-towns Nov 20, 2017
57eb1b4
Rm unused imports
j-towns Nov 20, 2017
93e5601
fix indentation numpy_wrapper.py
j-towns Nov 20, 2017
a9c0e45
Define derivs for scipy.special.rel_entr
j-towns Dec 8, 2017
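A note on the API this PR introduces (the helper itself is defined in the new autograd/numpy/util.py, which is not part of the hunks shown below): def_ufunc_jps takes, for each ufunc argument, either a shorthand string ('same', 'id', 'neg', 'cid', ...) or a pair (deriv, mode), and registers both the JVP and the VJP of the ufunc from that single specification. Below is a minimal sketch of how such a helper could look for the 'mul' and 'div' modes of a unary ufunc, ignoring broadcasting and complex handling; the name def_unary_ufunc_jps is illustrative, not the actual util.py implementation.

import autograd.numpy as anp
from autograd.extend import defjvp, defvjp

def def_unary_ufunc_jps(fun, spec):
    # spec is (deriv, mode): deriv(ans, x) is the elementwise derivative factor,
    # mode says how it combines with the tangent/cotangent g.
    deriv, mode = spec
    if mode == 'mul':    # derivative multiplies g
        defjvp(fun, lambda g, ans, x: g * deriv(ans, x))
        defvjp(fun, lambda ans, x: lambda g: g * deriv(ans, x))
    elif mode == 'div':  # g is divided by deriv (handy when 1/deriv is the natural expression)
        defjvp(fun, lambda g, ans, x: g / deriv(ans, x))
        defvjp(fun, lambda ans, x: lambda g: g / deriv(ans, x))
    else:
        raise ValueError('mode not handled in this sketch: ' + mode)

# Mirrors def_ufunc_jps(anp.sin, (lambda ans, x: anp.cos(x), 'mul')) from the diff.
def_unary_ufunc_jps(anp.sin, (lambda ans, x: anp.cos(x), 'mul'))

Because the Jacobian of an elementwise ufunc is diagonal, the same derivative factor serves for both the forward (JVP) and reverse (VJP) products, which is what lets a single definition cover both.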
2 changes: 1 addition & 1 deletion autograd/numpy/fft.py
@@ -2,7 +2,7 @@
from builtins import zip
import numpy.fft as ffto
from .numpy_wrapper import wrap_namespace
from .numpy_vjps import match_complex
from .util import match_complex
from . import numpy_wrapper as anp
from autograd.extend import primitive, defvjp, vspace

153 changes: 71 additions & 82 deletions autograd/numpy/numpy_jvps.py
@@ -1,15 +1,19 @@
from . import numpy_wrapper as anp
from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero,
dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0,
tensordot_adjoint_1, nograd_functions)
from .numpy_vjps import (untake, balanced_eq, replace_zero, dot_adjoint_0, dot_adjoint_1,
tensordot_adjoint_0, tensordot_adjoint_1, nograd_functions)
from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace, JVPNode,
register_notrace)
from .util import def_ufunc_jps, def_ufunc_jps_inv_pair

from ..util import func
from .numpy_boxes import ArrayBox

for fun in nograd_functions:
register_notrace(JVPNode, fun)

defjvp(anp.broadcast_to, 'same')
defjvp(anp._broadcast_to_adjoint, 'same')

defjvp(func(ArrayBox.__getitem__), 'same')
defjvp(untake, 'same')

@@ -18,47 +22,74 @@
lambda g, ans, args, kwargs, _: anp._array_from_scalar_or_array(args, kwargs, g))

# ----- Functions that are constant w.r.t. continuous inputs -----

defjvp(anp.nan_to_num, lambda g, ans, x: anp.where(anp.isfinite(x), g, 0.))

# ----- Binary ufuncs (linear) -----
def_linear(anp.multiply)
# ----- Unary ufuncs ------

def_ufunc_jps(anp.negative, 'same')
def_ufunc_jps(anp.rad2deg, 'same')
def_ufunc_jps(anp.degrees, 'same')
def_ufunc_jps(anp.deg2rad, 'same')
def_ufunc_jps(anp.radians, 'same')
def_ufunc_jps(anp.abs,
(lambda ans, x: replace_zero(anp.conj(x), 0.) / replace_zero(ans, 1.), 'cmul'))
def_ufunc_jps(anp.fabs, (lambda ans, x: anp.sign(x), 'mul')) # fabs doesn't take complex numbers.
def_ufunc_jps(anp.absolute, (lambda ans, x: anp.conj(x) / ans, 'cmul'))
def_ufunc_jps(anp.reciprocal, (lambda ans, x: -ans**2, 'mul' ))
def_ufunc_jps(anp.log10, (lambda ans, x: x * anp.log(10), 'div' ))
def_ufunc_jps(anp.sin, (lambda ans, x: anp.cos(x), 'mul' ))
def_ufunc_jps(anp.cos, (lambda ans, x: -anp.sin(x), 'mul' ))
def_ufunc_jps(anp.arcsin, (lambda ans, x: anp.sqrt(1 - x**2), 'div' ))
def_ufunc_jps(anp.arccos, (lambda ans, x:-anp.sqrt(1 - x**2), 'div' ))
def_ufunc_jps(anp.cosh, (lambda ans, x: anp.sinh(x), 'mul' ))
def_ufunc_jps(anp.arccosh, (lambda ans, x: anp.sqrt(x**2 - 1), 'div' ))
def_ufunc_jps(anp.sinc, (lambda ans, x: (anp.cos(anp.pi*x)-ans)/x, 'mul' ))
def_ufunc_jps(anp.real_if_close, 'cid')
def_ufunc_jps(anp.real, 'cid')
def_ufunc_jps(anp.imag, (lambda ans, x: -1j, 'cmul'))
def_ufunc_jps(anp.conj, 'same')
def_ufunc_jps(anp.conjugate, 'same')
def_ufunc_jps(anp.angle, (lambda ans, x: anp.conj(x * 1j)/anp.abs(x)**2, 'cmul'))

def_ufunc_jps_inv_pair(anp.exp, anp.log, lambda ans, x: ans)
def_ufunc_jps_inv_pair(anp.exp2, anp.log2, lambda ans, x: ans * anp.log(2))
def_ufunc_jps_inv_pair(anp.expm1, anp.log1p, lambda ans, x: ans + 1)
def_ufunc_jps_inv_pair(anp.tan, anp.arctan, lambda ans, x: 1 + ans**2)
def_ufunc_jps_inv_pair(anp.tanh, anp.arctanh, lambda ans, x: 1 - ans**2)
def_ufunc_jps_inv_pair(anp.sinh, anp.arcsinh, lambda ans, x: anp.sqrt(ans**2 + 1))
def_ufunc_jps_inv_pair(anp.square, anp.sqrt, lambda ans, x: 2 * x)

# ----- Binary ufuncs -----
defjvp(anp.add, lambda g, ans, x, y : broadcast(g, ans),
lambda g, ans, x, y : broadcast(g, ans))
defjvp(anp.subtract, lambda g, ans, x, y : broadcast(g, ans),
lambda g, ans, x, y : broadcast(-g, ans))
defjvp(anp.divide, 'same',
lambda g, ans, x, y : - g * x / y**2)
defjvp(anp.maximum, lambda g, ans, x, y : g * balanced_eq(x, ans, y),
lambda g, ans, x, y : g * balanced_eq(y, ans, x))
defjvp(anp.minimum, lambda g, ans, x, y : g * balanced_eq(x, ans, y),
lambda g, ans, x, y : g * balanced_eq(y, ans, x))
defjvp(anp.fmax, lambda g, ans, x, y : g * balanced_eq(x, ans, y),
lambda g, ans, x, y : g * balanced_eq(y, ans, x))
defjvp(anp.fmin, lambda g, ans, x, y : g * balanced_eq(x, ans, y),
lambda g, ans, x, y : g * balanced_eq(y, ans, x))
defjvp(anp.logaddexp, lambda g, ans, x, y : g * anp.exp(x-ans),
lambda g, ans, x, y : g * anp.exp(y-ans))
defjvp(anp.logaddexp2, lambda g, ans, x, y : g * 2**(x-ans),
lambda g, ans, x, y : g * 2**(y-ans))
defjvp(anp.true_divide,'same',
lambda g, ans, x, y : - g * x / y**2)
defjvp(anp.mod, lambda g, ans, x, y : broadcast(g, ans),
lambda g, ans, x, y : -g * anp.floor(x/y))
defjvp(anp.remainder, lambda g, ans, x, y : broadcast(g, ans),
lambda g, ans, x, y : -g * anp.floor(x/y))
defjvp(anp.power, lambda g, ans, x, y : g * y * x ** anp.where(y, y - 1, 1.),
lambda g, ans, x, y : g * anp.log(replace_zero(x, 1.)) * x ** y)
defjvp(anp.arctan2, lambda g, ans, x, y : g * y / (x**2 + y**2),
lambda g, ans, x, y : g * -x / (x**2 + y**2))

def_ufunc_jps(anp.add, 'id', 'id')
def_ufunc_jps(anp.subtract, 'id', 'neg')
def_ufunc_jps(anp.multiply, 'same', 'same')
def_ufunc_jps(anp.divide, 'same', (lambda ans, x, y: -ans/y, 'mul'))
def_ufunc_jps(anp.maximum, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'),
(lambda ans, x, y: balanced_eq(y, ans, x), 'mul'))
def_ufunc_jps(anp.minimum, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'),
(lambda ans, x, y: balanced_eq(y, ans, x), 'mul'))
def_ufunc_jps(anp.fmax, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'),
(lambda ans, x, y: balanced_eq(y, ans, x), 'mul'))
def_ufunc_jps(anp.fmin, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'),
(lambda ans, x, y: balanced_eq(y, ans, x), 'mul'))
def_ufunc_jps(anp.logaddexp, (lambda ans, x, y: anp.exp(x-ans), 'mul'),
(lambda ans, x, y: anp.exp(y-ans), 'mul'))
def_ufunc_jps(anp.logaddexp2, (lambda ans, x, y: 2**(x-ans), 'mul'),
(lambda ans, x, y: 2**(y-ans), 'mul'))
def_ufunc_jps(anp.true_divide, 'same', (lambda ans, x, y: -ans/y, 'mul'))
def_ufunc_jps(anp.mod, 'id', (lambda ans, x, y: -anp.floor(x/y), 'mul'))
def_ufunc_jps(anp.remainder, 'id', (lambda ans, x, y: -anp.floor(x/y), 'mul'))
def_ufunc_jps(anp.power, (lambda ans, x, y: y * x ** anp.where(y, y - 1, 1.), 'mul'),
(lambda ans, x, y: anp.log(replace_zero(x, 1.)) * x ** y, 'mul'))
def_ufunc_jps(anp.hypot, (lambda ans, x, y: x / ans, 'mul'),
(lambda ans, x, y: y / ans, 'mul'))
def_ufunc_jps(anp.arctan2, (lambda ans, x, y: y / (x**2 + y**2), 'mul'),
(lambda ans, x, y:-x / (x**2 + y**2), 'mul'))

# ----- Simple grads (linear) -----
defjvp(anp.negative, 'same')
defjvp(anp.rad2deg, 'same')
defjvp(anp.degrees, 'same')
defjvp(anp.deg2rad, 'same')
defjvp(anp.radians, 'same')

defjvp(anp.reshape, 'same')
defjvp(anp.roll, 'same')
defjvp(anp.array_split, 'same')
Expand All @@ -85,44 +116,14 @@
def_linear(anp.cross)

# ----- Simple grads -----
defjvp(anp.abs,
lambda g, ans, x : anp.real(g * replace_zero(anp.conj(x), 0.)) / replace_zero(ans, 1.))
defjvp(anp.fabs, lambda g, ans, x : anp.sign(x) * g) # fabs doesn't take complex numbers.
defjvp(anp.absolute, lambda g, ans, x : anp.real(g * anp.conj(x)) / ans)
defjvp(anp.reciprocal, lambda g, ans, x : - g / x**2)
defjvp(anp.exp, lambda g, ans, x : ans * g)
defjvp(anp.exp2, lambda g, ans, x : ans * anp.log(2) * g)
defjvp(anp.expm1, lambda g, ans, x : (ans + 1) * g)
defjvp(anp.log, lambda g, ans, x : g / x)
defjvp(anp.log2, lambda g, ans, x : g / x / anp.log(2))
defjvp(anp.log10, lambda g, ans, x : g / x / anp.log(10))
defjvp(anp.log1p, lambda g, ans, x : g / (x + 1))
defjvp(anp.sin, lambda g, ans, x : g * anp.cos(x))
defjvp(anp.cos, lambda g, ans, x : - g * anp.sin(x))
defjvp(anp.tan, lambda g, ans, x : g / anp.cos(x) **2)
defjvp(anp.arcsin, lambda g, ans, x : g / anp.sqrt(1 - x**2))
defjvp(anp.arccos, lambda g, ans, x :-g / anp.sqrt(1 - x**2))
defjvp(anp.arctan, lambda g, ans, x : g / (1 + x**2))
defjvp(anp.sinh, lambda g, ans, x : g * anp.cosh(x))
defjvp(anp.cosh, lambda g, ans, x : g * anp.sinh(x))
defjvp(anp.tanh, lambda g, ans, x : g / anp.cosh(x) **2)
defjvp(anp.arcsinh, lambda g, ans, x : g / anp.sqrt(x**2 + 1))
defjvp(anp.arccosh, lambda g, ans, x : g / anp.sqrt(x**2 - 1))
defjvp(anp.arctanh, lambda g, ans, x : g / (1 - x**2))
defjvp(anp.square, lambda g, ans, x : g * 2 * x)
defjvp(anp.sqrt, lambda g, ans, x : g * 0.5 * x**-0.5)
defjvp(anp.sinc, lambda g, ans, x : g * (anp.cos(anp.pi*x)*anp.pi*x - anp.sin(anp.pi*x))/(anp.pi*x**2))

defjvp(anp.clip, lambda g, ans, x, a_min, a_max : g * anp.logical_and(ans != a_min, ans != a_max))
defjvp(anp.real_if_close, lambda g, ans, x : match_complex(ans, g))
defjvp(anp.real, lambda g, ans, x : anp.real(g))
defjvp(anp.imag, lambda g, ans, x : match_complex(ans, -1j * g))
defjvp(anp.conj, lambda g, ans, x : anp.conj(g))
defjvp(anp.angle, lambda g, ans, x : match_complex(ans, g * anp.conj(x * 1j) / anp.abs(x)**2))
defjvp(anp.where, None,
lambda g, ans, c, x=None, y=None : anp.where(c, g, anp.zeros(anp.shape(g))),
lambda g, ans, c, x=None, y=None : anp.where(c, anp.zeros(g.shape), g))

# ----- Trickier grads -----

defjvp(anp.kron, 'same', 'same')
defjvp(anp.diff, 'same')
defjvp(anp.repeat, 'same')
@@ -226,15 +227,3 @@ def jvp(g, ans, *arys):
defjvp(anp.atleast_3d, atleast_jvpmaker(anp.atleast_3d))

def_linear(anp.einsum)

# TODO(mattjj): can we call np.broadcast_to or a related function instead?
def broadcast(x, target):
target_shape, target_ndim, target_dtype, target_iscomplex = anp.metadata(target)
while anp.ndim(x) < target_ndim:
x = anp.expand_dims(x, 0)
for axis, size in enumerate(anp.shape(x)):
if size == 1:
x = anp.repeat(x, target_shape[axis], axis=axis)
if target_iscomplex and not anp.iscomplexobj(x):
x = x + 0j # TODO(mattjj): this might promote the dtype
return x
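A note on the def_ufunc_jps_inv_pair calls above: given a function, its inverse, and the derivative of the forward function written in terms of (ans, x), the derivative of the inverse follows from the inverse function rule, d f^-1(y)/dy = 1 / f'(f^-1(y)). Below is a minimal sketch of such a helper; def_inv_pair_jps is an illustrative name, not the actual util.py code.

import autograd.numpy as anp
from autograd.extend import defjvp, defvjp

def def_inv_pair_jps(fun, inv_fun, deriv):
    # deriv(ans, x) is d(fun)/dx expressed via the output and input of fun.
    defjvp(fun, lambda g, ans, x: g * deriv(ans, x))
    defvjp(fun, lambda ans, x: lambda g: g * deriv(ans, x))
    # For inv_fun, swap the roles of input and output: at input y with
    # ans = inv_fun(y), the derivative is 1 / deriv(y, ans).
    defjvp(inv_fun, lambda g, ans, y: g / deriv(y, ans))
    defvjp(inv_fun, lambda ans, y: lambda g: g / deriv(y, ans))

# Mirrors def_ufunc_jps_inv_pair(anp.exp, anp.log, lambda ans, x: ans):
# d(exp)/dx = exp(x) = ans, so d(log)/dy = 1 / deriv(y, ans) = 1 / y.
def_inv_pair_jps(anp.exp, anp.log, lambda ans, x: ans)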
118 changes: 16 additions & 102 deletions autograd/numpy/numpy_vjps.py
@@ -4,9 +4,10 @@
import numpy as onp
from ..util import func
from . import numpy_wrapper as anp
from autograd.numpy.util import unbroadcast
from .numpy_boxes import ArrayBox
from autograd.extend import (primitive, vspace, defvjp, defvjp_argnum,
SparseObject, VJPNode, register_notrace)
from autograd.extend import (primitive, vspace, defvjp, defvjp_argnum, SparseObject, VJPNode,
register_notrace)

# ----- Non-differentiable functions -----

@@ -27,77 +28,8 @@

defvjp(anp.nan_to_num, lambda ans, x: lambda g: anp.where(anp.isfinite(x), g, 0.))

# ----- Binary ufuncs -----

defvjp(anp.add, lambda ans, x, y : unbroadcast_f(x, lambda g: g),
lambda ans, x, y : unbroadcast_f(y, lambda g: g))
defvjp(anp.multiply, lambda ans, x, y : unbroadcast_f(x, lambda g: y * g),
lambda ans, x, y : unbroadcast_f(y, lambda g: x * g))
defvjp(anp.subtract, lambda ans, x, y : unbroadcast_f(x, lambda g: g),
lambda ans, x, y : unbroadcast_f(y, lambda g: -g))
defvjp(anp.divide, lambda ans, x, y : unbroadcast_f(x, lambda g: g / y),
lambda ans, x, y : unbroadcast_f(y, lambda g: - g * x / y**2))
defvjp(anp.maximum, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)))
defvjp(anp.minimum, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)))
defvjp(anp.fmax, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)))
defvjp(anp.fmin, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)))
defvjp(anp.logaddexp, lambda ans, x, y : unbroadcast_f(x, lambda g: g * anp.exp(x-ans)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * anp.exp(y-ans)))
defvjp(anp.logaddexp2, lambda ans, x, y : unbroadcast_f(x, lambda g: g * 2**(x-ans)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * 2**(y-ans)))
defvjp(anp.true_divide, lambda ans, x, y : unbroadcast_f(x, lambda g: g / y),
lambda ans, x, y : unbroadcast_f(y, lambda g: - g * x / y**2))
defvjp(anp.mod, lambda ans, x, y : unbroadcast_f(x, lambda g: g),
lambda ans, x, y : unbroadcast_f(y, lambda g: -g * anp.floor(x/y)))
defvjp(anp.remainder, lambda ans, x, y : unbroadcast_f(x, lambda g: g),
lambda ans, x, y : unbroadcast_f(y, lambda g: -g * anp.floor(x/y)))
defvjp(anp.power,
lambda ans, x, y : unbroadcast_f(x, lambda g: g * y * x ** anp.where(y, y - 1, 1.)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * anp.log(replace_zero(x, 1.)) * x ** y))
defvjp(anp.arctan2, lambda ans, x, y : unbroadcast_f(x, lambda g: g * y / (x**2 + y**2)),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * -x / (x**2 + y**2)))
defvjp(anp.hypot,
lambda ans, x, y : unbroadcast_f(x, lambda g: g * x / ans),
lambda ans, x, y : unbroadcast_f(y, lambda g: g * y / ans))

# ----- Simple grads -----

defvjp(anp.negative, lambda ans, x: lambda g: -g)
defvjp(anp.abs,
lambda ans, x : lambda g: g * replace_zero(anp.conj(x), 0.) / replace_zero(ans, 1.))
defvjp(anp.fabs, lambda ans, x : lambda g: anp.sign(x) * g) # fabs doesn't take complex numbers.
defvjp(anp.absolute, lambda ans, x : lambda g: g * anp.conj(x) / ans)
defvjp(anp.reciprocal, lambda ans, x : lambda g: - g / x**2)
defvjp(anp.exp, lambda ans, x : lambda g: ans * g)
defvjp(anp.exp2, lambda ans, x : lambda g: ans * anp.log(2) * g)
defvjp(anp.expm1, lambda ans, x : lambda g: (ans + 1) * g)
defvjp(anp.log, lambda ans, x : lambda g: g / x)
defvjp(anp.log2, lambda ans, x : lambda g: g / x / anp.log(2))
defvjp(anp.log10, lambda ans, x : lambda g: g / x / anp.log(10))
defvjp(anp.log1p, lambda ans, x : lambda g: g / (x + 1))
defvjp(anp.sin, lambda ans, x : lambda g: g * anp.cos(x))
defvjp(anp.cos, lambda ans, x : lambda g: - g * anp.sin(x))
defvjp(anp.tan, lambda ans, x : lambda g: g / anp.cos(x) **2)
defvjp(anp.arcsin, lambda ans, x : lambda g: g / anp.sqrt(1 - x**2))
defvjp(anp.arccos, lambda ans, x : lambda g:-g / anp.sqrt(1 - x**2))
defvjp(anp.arctan, lambda ans, x : lambda g: g / (1 + x**2))
defvjp(anp.sinh, lambda ans, x : lambda g: g * anp.cosh(x))
defvjp(anp.cosh, lambda ans, x : lambda g: g * anp.sinh(x))
defvjp(anp.tanh, lambda ans, x : lambda g: g / anp.cosh(x) **2)
defvjp(anp.arcsinh, lambda ans, x : lambda g: g / anp.sqrt(x**2 + 1))
defvjp(anp.arccosh, lambda ans, x : lambda g: g / anp.sqrt(x**2 - 1))
defvjp(anp.arctanh, lambda ans, x : lambda g: g / (1 - x**2))
defvjp(anp.rad2deg, lambda ans, x : lambda g: g / anp.pi * 180.0)
defvjp(anp.degrees, lambda ans, x : lambda g: g / anp.pi * 180.0)
defvjp(anp.deg2rad, lambda ans, x : lambda g: g * anp.pi / 180.0)
defvjp(anp.radians, lambda ans, x : lambda g: g * anp.pi / 180.0)
defvjp(anp.square, lambda ans, x : lambda g: g * 2 * x)
defvjp(anp.sqrt, lambda ans, x : lambda g: g * 0.5 * x**-0.5)
defvjp(anp.sinc, lambda ans, x : lambda g: g * (anp.cos(anp.pi*x)*anp.pi*x - anp.sin(anp.pi*x))/(anp.pi*x**2))
defvjp(anp.reshape, lambda ans, x, shape, order=None : lambda g: anp.reshape(g, anp.shape(x), order=order))
defvjp(anp.roll, lambda ans, x, shift, axis=None : lambda g: anp.roll(g, -shift, axis=axis))
defvjp(anp.array_split, lambda ans, ary, idxs, axis=0 : lambda g: anp.concatenate(g, axis=axis))
@@ -123,12 +55,6 @@
anp.moveaxis(g, destination, source))
defvjp(anp.rollaxis, lambda ans, a, axis, start=0: lambda g: anp.rollaxis(g, start - 1, axis) if start > axis
else anp.rollaxis(g, start, axis + 1))
defvjp(anp.real_if_close, lambda ans, x : lambda g: match_complex(x, g))
defvjp(anp.real, lambda ans, x : lambda g: match_complex(x, g))
defvjp(anp.imag, lambda ans, x : lambda g: match_complex(x, -1j * g))
defvjp(anp.conj, lambda ans, x : lambda g: anp.conj(g))
defvjp(anp.conjugate, lambda ans, x: lambda g: anp.conj(g))
defvjp(anp.angle, lambda ans, x : lambda g: match_complex(x, g * anp.conj(x * 1j) / anp.abs(x)**2))
defvjp(anp.where, None,
lambda ans, c, x=None, y=None : lambda g: anp.where(c, g, anp.zeros(g.shape)),
lambda ans, c, x=None, y=None : lambda g: anp.where(c, anp.zeros(g.shape), g))
@@ -541,31 +467,6 @@ def vjp(g):
lambda ans, D, offset=0, axis1=0, axis2=1 :
lambda g: anp.diagonal(g, offset, axis1, axis2))

def match_complex(target, x):
target_iscomplex = anp.iscomplexobj(target)
x_iscomplex = anp.iscomplexobj(x)
if x_iscomplex and not target_iscomplex:
return anp.real(x)
elif not x_iscomplex and target_iscomplex:
return x + 0j
else:
return x

def unbroadcast(x, target_meta, broadcast_idx=0):
target_shape, target_ndim, dtype, target_iscomplex = target_meta
while anp.ndim(x) > target_ndim:
x = anp.sum(x, axis=broadcast_idx)
for axis, size in enumerate(target_shape):
if size == 1:
x = anp.sum(x, axis=axis, keepdims=True)
if anp.iscomplexobj(x) and not target_iscomplex:
x = anp.real(x)
return x

def unbroadcast_f(target, f):
target_meta = anp.metadata(target)
return lambda g: unbroadcast(f(g), target_meta)

def unbroadcast_einsum(x, target_meta, subscript):
if Ellipsis not in subscript:
return x
@@ -576,6 +477,19 @@ def unbroadcast_einsum(x, target_meta, subscript):
else:
return unbroadcast(x, target_meta, subscript.index(Ellipsis))

def _broadcast_to_vjpmaker(x_shape):
# Ensure that x can be garbage collected by only passing
# its shape to this closure.
return lambda g: anp._broadcast_to_adjoint(g, x_shape)

def _broadcast_to_adjoint_vjpmaker(g_shape):
# Ensure that g can be garbage collected by only passing
# its shape to this closure.
return lambda x: anp.broadcast_to(x, g_shape)

defvjp(anp.broadcast_to, lambda ans, x, ans_shp: _broadcast_to_vjpmaker(x.shape))
defvjp(anp._broadcast_to_adjoint, lambda ans, g, ans_shp: _broadcast_to_adjoint_vjpmaker(g.shape))

def balanced_eq(x, z, y):
return (x == z) / (1.0 + (x == y))

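The _broadcast_to_vjpmaker and _broadcast_to_adjoint_vjpmaker helpers above exist so that the returned closure captures only a shape tuple rather than the array itself, letting the forward-pass array be garbage collected before the backward pass runs. The pattern in isolation (the names below are illustrative, not part of the diff):

import numpy as np

def make_vjp_capturing_array(x):
    # Keeps all of x alive until the backward pass, even though only its shape is needed.
    return lambda g: g.reshape(x.shape)

def make_vjp_capturing_shape(x_shape):
    # Captures only the (tiny) shape tuple, so x itself can be freed.
    return lambda g: g.reshape(x_shape)

x = np.zeros((1000, 1000))
vjp = make_vjp_capturing_shape(x.shape)
del x                                   # the large buffer can now be reclaimed
out = vjp(np.ones(1000 * 1000))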
11 changes: 10 additions & 1 deletion autograd/numpy/numpy_wrapper.py
@@ -153,6 +153,15 @@ def metadata(A):
def parse_einsum_input(*args):
return _parse_einsum_input(args)

@primitive
def _broadcast_to_adjoint(x, shape):
while _np.ndim(x) > len(shape):
x = _np.sum(x, axis=0)
for axis, size in enumerate(shape):
if size == 1:
x = _np.sum(x, axis=axis, keepdims=True)
return x

@primitive
def _astype(A, dtype, order='K', casting='unsafe', subok=True, copy=True):
return A.astype(dtype, order, casting, subok, copy)
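As a sanity check on the _broadcast_to_adjoint primitive added above: summing a cotangent over the axes that broadcasting introduced (or expanded from size 1) is exactly the adjoint of broadcast_to under the standard inner product, i.e. sum(broadcast_to(x, shape) * g) == sum(x * adjoint(g, x.shape)). A small self-contained check using a local copy of the same logic in plain NumPy (the diff's version is the @primitive in numpy_wrapper.py):

import numpy as np

def broadcast_to_adjoint(x, shape):
    # Same logic as the primitive above: collapse extra leading axes, then
    # sum (with keepdims) over axes that were broadcast from size 1.
    while np.ndim(x) > len(shape):
        x = np.sum(x, axis=0)
    for axis, size in enumerate(shape):
        if size == 1:
            x = np.sum(x, axis=axis, keepdims=True)
    return x

rng = np.random.RandomState(0)
x = rng.randn(1, 3)                          # broadcasts to (2, 3)
g = rng.randn(2, 3)                          # cotangent with the broadcast shape
lhs = np.sum(np.broadcast_to(x, (2, 3)) * g)
rhs = np.sum(x * broadcast_to_adjoint(g, x.shape))
assert np.allclose(lhs, rhs)                 # adjoint property holds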