From 7b18fe48f6e7444eaf86933ae3f887a3c248af23 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Mon, 16 Oct 2017 16:07:43 +0200 Subject: [PATCH 01/32] Start implementing def_ufunc_derivs --- autograd/numpy/numpy_jvps.py | 48 +++++++++++++++++++++++++++--------- autograd/numpy/numpy_vjps.py | 11 ++------- 2 files changed, 39 insertions(+), 20 deletions(-) diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py index cb38a6ec7..c3b3ed747 100644 --- a/autograd/numpy/numpy_jvps.py +++ b/autograd/numpy/numpy_jvps.py @@ -1,11 +1,38 @@ +from itertools import repeat, count +from functools import partial from . import numpy_wrapper as anp from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero, dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0, - tensordot_adjoint_1) -from autograd.extend import defjvp, defjvp_argnum, def_linear, vspace + tensordot_adjoint_1, unbroadcast_f) +from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace, + defvjp) from ..util import func from .numpy_boxes import ArrayBox + +def def_ufunc_derivs(ufunc, derivs, ops=list(repeat('mul', 2))): + derivs = list(derivs) + ops = list(ops) + + op_map = { + 'mul': lambda g, d: g * d, + 'div': lambda g, d: g / d, + 'id' : lambda g, d: g, + 'neg': lambda g, d: -g + } + + def ufunc_jvp(deriv, op): + return lambda g, ans, *args: broadcast(op(g, deriv(ans, *args)), ans) + + def ufunc_vjp(argnum, deriv, op): + return lambda ans, *args: unbroadcast_f(args[argnum], lambda g: op(g, deriv(ans, *args))) + + defjvp(ufunc, *[ufunc_jvp(deriv, op_map[op]) for deriv, op in zip(derivs, ops)]) + defvjp(ufunc, *[ufunc_vjp(argnum, deriv, op_map[op]) for argnum, (deriv, op) in enumerate(zip(derivs, ops))]) + + +def_ufunc_derivs(anp.exp, (lambda ans, x: ans,)) + defjvp(func(ArrayBox.__getitem__), 'same') defjvp(untake, 'same') @@ -20,14 +47,13 @@ def_linear(anp.multiply) # ----- Binary ufuncs ----- -defjvp(anp.add, lambda g, ans, x, y : broadcast(g, ans), - lambda g, ans, x, y : broadcast(g, ans)) -defjvp(anp.subtract, lambda g, ans, x, y : broadcast(g, ans), - lambda g, ans, x, y : broadcast(-g, ans)) -defjvp(anp.divide, 'same', - lambda g, ans, x, y : - g * x / y**2) -defjvp(anp.maximum, lambda g, ans, x, y : g * balanced_eq(x, ans, y), - lambda g, ans, x, y : g * balanced_eq(y, ans, x)) +def_ufunc_derivs(anp.add, repeat(lambda *args: None, 2), repeat('id', 2)) +def_ufunc_derivs(anp.subtract, repeat(lambda *args: None, 2), ('id', 'neg')) +def_ufunc_derivs(anp.divide, (lambda ans, x, y: y, + lambda ans, x, y: -x/y**2), ('div', 'mul')) +def_ufunc_derivs(anp.maximum, (lambda ans, x, y: balanced_eq(x, ans, y), + lambda ans, x, y: balanced_eq(y, ans, x))) + defjvp(anp.minimum, lambda g, ans, x, y : g * balanced_eq(x, ans, y), lambda g, ans, x, y : g * balanced_eq(y, ans, x)) defjvp(anp.fmax, lambda g, ans, x, y : g * balanced_eq(x, ans, y), @@ -84,7 +110,7 @@ defjvp(anp.fabs, lambda g, ans, x : anp.sign(x) * g) # fabs doesn't take complex numbers. defjvp(anp.absolute, lambda g, ans, x : anp.real(g * anp.conj(x)) / ans) defjvp(anp.reciprocal, lambda g, ans, x : - g / x**2) -defjvp(anp.exp, lambda g, ans, x : ans * g) +# defjvp(anp.exp, lambda g, ans, x : ans * g) defjvp(anp.exp2, lambda g, ans, x : ans * anp.log(2) * g) defjvp(anp.expm1, lambda g, ans, x : (ans + 1) * g) defjvp(anp.log, lambda g, ans, x : g / x) diff --git a/autograd/numpy/numpy_vjps.py b/autograd/numpy/numpy_vjps.py index 2b0ea6288..d7acf5f3e 100644 --- a/autograd/numpy/numpy_vjps.py +++ b/autograd/numpy/numpy_vjps.py @@ -7,22 +7,15 @@ from .numpy_boxes import ArrayBox from autograd.extend import primitive, vspace, defvjp, defvjp_argnum, SparseObject + # ----- Functions that are constant w.r.t. continuous inputs ----- defvjp(anp.nan_to_num, lambda ans, x: lambda g: anp.where(anp.isfinite(x), g, 0.)) # ----- Binary ufuncs ----- -defvjp(anp.add, lambda ans, x, y : unbroadcast_f(x, lambda g: g), - lambda ans, x, y : unbroadcast_f(y, lambda g: g)) defvjp(anp.multiply, lambda ans, x, y : unbroadcast_f(x, lambda g: y * g), lambda ans, x, y : unbroadcast_f(y, lambda g: x * g)) -defvjp(anp.subtract, lambda ans, x, y : unbroadcast_f(x, lambda g: g), - lambda ans, x, y : unbroadcast_f(y, lambda g: -g)) -defvjp(anp.divide, lambda ans, x, y : unbroadcast_f(x, lambda g: g / y), - lambda ans, x, y : unbroadcast_f(y, lambda g: - g * x / y**2)) -defvjp(anp.maximum, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)), - lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x))) defvjp(anp.minimum, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)), lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x))) defvjp(anp.fmax, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)), @@ -51,7 +44,7 @@ defvjp(anp.fabs, lambda ans, x : lambda g: anp.sign(x) * g) # fabs doesn't take complex numbers. defvjp(anp.absolute, lambda ans, x : lambda g: g * anp.conj(x) / ans) defvjp(anp.reciprocal, lambda ans, x : lambda g: - g / x**2) -defvjp(anp.exp, lambda ans, x : lambda g: ans * g) +# defvjp(anp.exp, lambda ans, x : lambda g: ans * g) defvjp(anp.exp2, lambda ans, x : lambda g: ans * anp.log(2) * g) defvjp(anp.expm1, lambda ans, x : lambda g: (ans + 1) * g) defvjp(anp.log, lambda ans, x : lambda g: g / x) From 91e510e41f853b9ef1579084a0821f38015dbb2b Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Mon, 16 Oct 2017 18:23:29 +0200 Subject: [PATCH 02/32] Refactor def_ufunc_derivs --- autograd/numpy/numpy_jvps.py | 55 +++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py index c3b3ed747..267d812a9 100644 --- a/autograd/numpy/numpy_jvps.py +++ b/autograd/numpy/numpy_jvps.py @@ -6,32 +6,42 @@ tensordot_adjoint_1, unbroadcast_f) from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace, defvjp) -from ..util import func +from ..util import func, subval from .numpy_boxes import ArrayBox -def def_ufunc_derivs(ufunc, derivs, ops=list(repeat('mul', 2))): - derivs = list(derivs) - ops = list(ops) +def def_ufunc_derivs(ufunc, *derivs_ops): + derivs_ops = list(derivs_ops) - op_map = { - 'mul': lambda g, d: g * d, - 'div': lambda g, d: g / d, - 'id' : lambda g, d: g, - 'neg': lambda g, d: -g - } - - def ufunc_jvp(deriv, op): - return lambda g, ans, *args: broadcast(op(g, deriv(ans, *args)), ans) + def ufunc_jvp(argnum, deriv, op): + if op == 'same': + return lambda g, ans, *args: ufunc(*subval(args, argnum, g)) + elif op == 'id': + return lambda g, ans, *args: broadcast(g, ans) + elif op == 'neg': + return lambda g, ans, *args: broadcast(-g, ans) + elif op == 'mul': + return lambda g, ans, *args: broadcast(g * deriv(ans, *args), ans) + elif op == 'div': + return lambda g, ans, *args: broadcast(g / deriv(ans, *args), ans) def ufunc_vjp(argnum, deriv, op): - return lambda ans, *args: unbroadcast_f(args[argnum], lambda g: op(g, deriv(ans, *args))) + if op == 'same': + return lambda ans, *args: unbroadcast_f(args[argnum], lambda g: ufunc(*subval(args, argnum, g))) + elif op == 'id': + return lambda ans, *args: unbroadcast_f(args[argnum], lambda g: g) + elif op == 'neg': + return lambda ans, *args: unbroadcast_f(args[argnum], lambda g: -g) + elif op == 'mul': + return lambda ans, *args: unbroadcast_f(args[argnum], lambda g: g * deriv(ans, *args)) + elif op == 'div': + return lambda ans, *args: unbroadcast_f(args[argnum], lambda g: g / deriv(ans, *args)) - defjvp(ufunc, *[ufunc_jvp(deriv, op_map[op]) for deriv, op in zip(derivs, ops)]) - defvjp(ufunc, *[ufunc_vjp(argnum, deriv, op_map[op]) for argnum, (deriv, op) in enumerate(zip(derivs, ops))]) + defjvp(ufunc, *[ufunc_jvp(argnum, deriv, op) for argnum, (deriv, op) in enumerate(derivs_ops)]) + defvjp(ufunc, *[ufunc_vjp(argnum, deriv, op) for argnum, (deriv, op) in enumerate(derivs_ops)]) -def_ufunc_derivs(anp.exp, (lambda ans, x: ans,)) +def_ufunc_derivs(anp.exp, (lambda ans, x: ans, 'mul')) defjvp(func(ArrayBox.__getitem__), 'same') defjvp(untake, 'same') @@ -47,12 +57,11 @@ def ufunc_vjp(argnum, deriv, op): def_linear(anp.multiply) # ----- Binary ufuncs ----- -def_ufunc_derivs(anp.add, repeat(lambda *args: None, 2), repeat('id', 2)) -def_ufunc_derivs(anp.subtract, repeat(lambda *args: None, 2), ('id', 'neg')) -def_ufunc_derivs(anp.divide, (lambda ans, x, y: y, - lambda ans, x, y: -x/y**2), ('div', 'mul')) -def_ufunc_derivs(anp.maximum, (lambda ans, x, y: balanced_eq(x, ans, y), - lambda ans, x, y: balanced_eq(y, ans, x))) +def_ufunc_derivs(anp.add, *repeat((None, 'id'), 2)) +def_ufunc_derivs(anp.subtract, (None, 'id'), (None, 'neg')) +def_ufunc_derivs(anp.divide, (None, 'same'), (lambda ans, x, y: -ans/y, 'mul')) +def_ufunc_derivs(anp.maximum, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), + (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) defjvp(anp.minimum, lambda g, ans, x, y : g * balanced_eq(x, ans, y), lambda g, ans, x, y : g * balanced_eq(y, ans, x)) From 89cf08cf6790aeb890bcf41ecc0543709a695bad Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Mon, 16 Oct 2017 18:40:45 +0200 Subject: [PATCH 03/32] More ufuncs to new format --- autograd/numpy/numpy_jvps.py | 47 ++++++++++++++++-------------------- autograd/numpy/numpy_vjps.py | 19 --------------- 2 files changed, 21 insertions(+), 45 deletions(-) diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py index 267d812a9..329adeaff 100644 --- a/autograd/numpy/numpy_jvps.py +++ b/autograd/numpy/numpy_jvps.py @@ -1,5 +1,4 @@ -from itertools import repeat, count -from functools import partial +from itertools import repeat from . import numpy_wrapper as anp from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero, dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0, @@ -57,30 +56,26 @@ def ufunc_vjp(argnum, deriv, op): def_linear(anp.multiply) # ----- Binary ufuncs ----- -def_ufunc_derivs(anp.add, *repeat((None, 'id'), 2)) -def_ufunc_derivs(anp.subtract, (None, 'id'), (None, 'neg')) -def_ufunc_derivs(anp.divide, (None, 'same'), (lambda ans, x, y: -ans/y, 'mul')) -def_ufunc_derivs(anp.maximum, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), - (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) - -defjvp(anp.minimum, lambda g, ans, x, y : g * balanced_eq(x, ans, y), - lambda g, ans, x, y : g * balanced_eq(y, ans, x)) -defjvp(anp.fmax, lambda g, ans, x, y : g * balanced_eq(x, ans, y), - lambda g, ans, x, y : g * balanced_eq(y, ans, x)) -defjvp(anp.fmin, lambda g, ans, x, y : g * balanced_eq(x, ans, y), - lambda g, ans, x, y : g * balanced_eq(y, ans, x)) -defjvp(anp.logaddexp, lambda g, ans, x, y : g * anp.exp(x-ans), - lambda g, ans, x, y : g * anp.exp(y-ans)) -defjvp(anp.logaddexp2, lambda g, ans, x, y : g * 2**(x-ans), - lambda g, ans, x, y : g * 2**(y-ans)) -defjvp(anp.true_divide,'same', - lambda g, ans, x, y : - g * x / y**2) -defjvp(anp.mod, lambda g, ans, x, y : broadcast(g, ans), - lambda g, ans, x, y : -g * anp.floor(x/y)) -defjvp(anp.remainder, lambda g, ans, x, y : broadcast(g, ans), - lambda g, ans, x, y : -g * anp.floor(x/y)) -defjvp(anp.power, lambda g, ans, x, y : g * y * x ** anp.where(y, y - 1, 1.), - lambda g, ans, x, y : g * anp.log(replace_zero(x, 1.)) * x ** y) +def_ufunc_derivs(anp.add, *repeat((None, 'id'), 2)) +def_ufunc_derivs(anp.subtract, (None, 'id'), (None, 'neg')) +def_ufunc_derivs(anp.divide, (None, 'same'), (lambda ans, x, y: -ans/y, 'mul')) +def_ufunc_derivs(anp.maximum, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), + (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) +def_ufunc_derivs(anp.minimum, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), + (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) +def_ufunc_derivs(anp.fmax, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), + (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) +def_ufunc_derivs(anp.fmin, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), + (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) +def_ufunc_derivs(anp.logaddexp, (lambda ans, x, y: anp.exp(x-ans), 'mul'), + (lambda ans, x, y: anp.exp(y-ans), 'mul')) +def_ufunc_derivs(anp.logaddexp2, (lambda ans, x, y: 2**(x-ans), 'mul'), + (lambda ans, x, y: 2**(y-ans), 'mul')) +def_ufunc_derivs(anp.true_divide, (None, 'same'), (lambda ans, x, y: -ans/y, 'mul')) +def_ufunc_derivs(anp.mod, (None, 'id'), (lambda ans, x, y: -anp.floor(x/y), 'mul')) +def_ufunc_derivs(anp.remainder, (None, 'id'), (lambda ans, x, y: -anp.floor(x/y), 'mul')) +def_ufunc_derivs(anp.power, (lambda ans, x, y: y * x ** anp.where(y, y - 1, 1.), 'mul'), + (lambda ans, x, y: anp.log(replace_zero(x, 1.)) * x ** y, 'mul')) # ----- Simple grads (linear) ----- defjvp(anp.negative, 'same') diff --git a/autograd/numpy/numpy_vjps.py b/autograd/numpy/numpy_vjps.py index d7acf5f3e..c5fe2fd2b 100644 --- a/autograd/numpy/numpy_vjps.py +++ b/autograd/numpy/numpy_vjps.py @@ -16,25 +16,6 @@ defvjp(anp.multiply, lambda ans, x, y : unbroadcast_f(x, lambda g: y * g), lambda ans, x, y : unbroadcast_f(y, lambda g: x * g)) -defvjp(anp.minimum, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)), - lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x))) -defvjp(anp.fmax, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)), - lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x))) -defvjp(anp.fmin, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)), - lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x))) -defvjp(anp.logaddexp, lambda ans, x, y : unbroadcast_f(x, lambda g: g * anp.exp(x-ans)), - lambda ans, x, y : unbroadcast_f(y, lambda g: g * anp.exp(y-ans))) -defvjp(anp.logaddexp2, lambda ans, x, y : unbroadcast_f(x, lambda g: g * 2**(x-ans)), - lambda ans, x, y : unbroadcast_f(y, lambda g: g * 2**(y-ans))) -defvjp(anp.true_divide, lambda ans, x, y : unbroadcast_f(x, lambda g: g / y), - lambda ans, x, y : unbroadcast_f(y, lambda g: - g * x / y**2)) -defvjp(anp.mod, lambda ans, x, y : unbroadcast_f(x, lambda g: g), - lambda ans, x, y : unbroadcast_f(y, lambda g: -g * anp.floor(x/y))) -defvjp(anp.remainder, lambda ans, x, y : unbroadcast_f(x, lambda g: g), - lambda ans, x, y : unbroadcast_f(y, lambda g: -g * anp.floor(x/y))) -defvjp(anp.power, - lambda ans, x, y : unbroadcast_f(x, lambda g: g * y * x ** anp.where(y, y - 1, 1.)), - lambda ans, x, y : unbroadcast_f(y, lambda g: g * anp.log(replace_zero(x, 1.)) * x ** y)) # ----- Simple grads ----- From 94747f53b32f16bb55e997c96d138a7e257a52ba Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Tue, 17 Oct 2017 12:05:01 +0200 Subject: [PATCH 04/32] Unary ufuncs new style --- autograd/numpy/numpy_jvps.py | 183 +++++++++++++++++++---------------- autograd/numpy/numpy_vjps.py | 42 -------- 2 files changed, 97 insertions(+), 128 deletions(-) diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py index 329adeaff..34bc1ba49 100644 --- a/autograd/numpy/numpy_jvps.py +++ b/autograd/numpy/numpy_jvps.py @@ -9,39 +9,49 @@ from .numpy_boxes import ArrayBox -def def_ufunc_derivs(ufunc, *derivs_ops): +def def_ufunc_jps(ufunc, *derivs_ops): derivs_ops = list(derivs_ops) - def ufunc_jvp(argnum, deriv, op): - if op == 'same': - return lambda g, ans, *args: ufunc(*subval(args, argnum, g)) - elif op == 'id': - return lambda g, ans, *args: broadcast(g, ans) - elif op == 'neg': - return lambda g, ans, *args: broadcast(-g, ans) - elif op == 'mul': - return lambda g, ans, *args: broadcast(g * deriv(ans, *args), ans) - elif op == 'div': - return lambda g, ans, *args: broadcast(g / deriv(ans, *args), ans) + unary_ufunc_jps = { + 'same': (lambda deriv: lambda g, ans, x: ufunc(g), + lambda deriv: lambda ans, x: ufunc), + 'mul' : (lambda deriv: lambda g, ans, x: g * deriv(ans, x), + lambda deriv: lambda ans, x: lambda g: g * deriv(ans, x)), + 'div' : (lambda deriv: lambda g, ans, x: g / deriv(ans, x), + lambda deriv: lambda ans, x: lambda g: g / deriv(ans, x)), + 'cmul': (lambda deriv: lambda g, ans, x: match_complex(ans, g * deriv(ans, x)), + lambda deriv: lambda ans, x: lambda g: match_complex(x , g * deriv(ans, x))), + 'cid': (lambda deriv: lambda g, ans, x: match_complex(ans, g), + lambda deriv: lambda ans, x: lambda g: match_complex(x , g)) + } - def ufunc_vjp(argnum, deriv, op): - if op == 'same': - return lambda ans, *args: unbroadcast_f(args[argnum], lambda g: ufunc(*subval(args, argnum, g))) - elif op == 'id': - return lambda ans, *args: unbroadcast_f(args[argnum], lambda g: g) - elif op == 'neg': - return lambda ans, *args: unbroadcast_f(args[argnum], lambda g: -g) - elif op == 'mul': - return lambda ans, *args: unbroadcast_f(args[argnum], lambda g: g * deriv(ans, *args)) - elif op == 'div': - return lambda ans, *args: unbroadcast_f(args[argnum], lambda g: g / deriv(ans, *args)) + if len(derivs_ops) == 1: + deriv, op = derivs_ops[0] + defjvp(ufunc, unary_ufunc_jps[op][0](deriv)) + defvjp(ufunc, unary_ufunc_jps[op][1](deriv)) - defjvp(ufunc, *[ufunc_jvp(argnum, deriv, op) for argnum, (deriv, op) in enumerate(derivs_ops)]) - defvjp(ufunc, *[ufunc_vjp(argnum, deriv, op) for argnum, (deriv, op) in enumerate(derivs_ops)]) + binary_ufunc_jps = { + 'same': (lambda argnum, deriv: lambda g, ans, *args: ufunc(*subval(args, argnum, g)), + lambda argnum, deriv: lambda ans, *args: + unbroadcast_f(args[argnum], lambda g: ufunc(*subval(args, argnum, g)))), + 'id': (lambda argnum, deriv: lambda g, ans, *args: broadcast(g, ans), + lambda argnum, deriv: lambda ans, *args: + unbroadcast_f(args[argnum], lambda g: g)), + 'neg': (lambda argnum, deriv: lambda g, ans, *args: broadcast(-g, ans), + lambda argnum, deriv: lambda ans, *args: + unbroadcast_f(args[argnum], lambda g: -g)), + 'mul': (lambda argnum, deriv: lambda g, ans, *args: g * deriv(ans, *args), + lambda argnum, deriv: lambda ans, *args: + unbroadcast_f(args[argnum], lambda g: g * deriv(ans, *args))), + 'div': (lambda argnum, deriv: lambda g, ans, *args: g / deriv(ans, *args), + lambda argnum, deriv: lambda ans, *args: + unbroadcast_f(args[argnum], lambda g: g / deriv(ans, *args))) + } + if len(derivs_ops) == 2: + defjvp(ufunc, *[binary_ufunc_jps[op][0](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)]) + defvjp(ufunc, *[binary_ufunc_jps[op][1](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)]) -def_ufunc_derivs(anp.exp, (lambda ans, x: ans, 'mul')) - defjvp(func(ArrayBox.__getitem__), 'same') defjvp(untake, 'same') @@ -52,37 +62,70 @@ def ufunc_vjp(argnum, deriv, op): # ----- Functions that are constant w.r.t. continuous inputs ----- defjvp(anp.nan_to_num, lambda g, ans, x: anp.where(anp.isfinite(x), g, 0.)) -# ----- Binary ufuncs (linear) ----- -def_linear(anp.multiply) +# ----- Unary ufuncs ------ +def_ufunc_jps(anp.negative, (None, 'same')) +def_ufunc_jps(anp.rad2deg, (None, 'same')) +def_ufunc_jps(anp.degrees, (None, 'same')) +def_ufunc_jps(anp.deg2rad, (None, 'same')) +def_ufunc_jps(anp.radians, (None, 'same')) +def_ufunc_jps(anp.abs, + (lambda ans, x: replace_zero(anp.conj(x), 0.) / replace_zero(ans, 1.), 'cmul')) +def_ufunc_jps(anp.fabs, (lambda ans, x: anp.sign(x), 'mul')) # fabs doesn't take complex numbers. +def_ufunc_jps(anp.absolute, (lambda ans, x: anp.conj(x) / ans, 'cmul')) +def_ufunc_jps(anp.reciprocal, (lambda ans, x: -ans**2, 'mul' )) +def_ufunc_jps(anp.exp, (lambda ans, x: ans, 'mul' )) +def_ufunc_jps(anp.exp2, (lambda ans, x: ans * anp.log(2), 'mul' )) +def_ufunc_jps(anp.expm1, (lambda ans, x: (ans + 1), 'mul' )) +def_ufunc_jps(anp.log, (lambda ans, x: x, 'div' )) +def_ufunc_jps(anp.log2, (lambda ans, x: x * anp.log(2), 'div' )) +def_ufunc_jps(anp.log10, (lambda ans, x: x * anp.log(10), 'div' )) +def_ufunc_jps(anp.log1p, (lambda ans, x: x + 1, 'div' )) +def_ufunc_jps(anp.sin, (lambda ans, x: anp.cos(x), 'mul' )) +def_ufunc_jps(anp.cos, (lambda ans, x: -anp.sin(x), 'mul' )) +def_ufunc_jps(anp.tan, (lambda ans, x: 1 + ans**2, 'mul' )) +def_ufunc_jps(anp.arcsin, (lambda ans, x: anp.sqrt(1 - x**2), 'div' )) +def_ufunc_jps(anp.arccos, (lambda ans, x:-anp.sqrt(1 - x**2), 'div' )) +def_ufunc_jps(anp.arctan, (lambda ans, x: 1 + x**2, 'div' )) +def_ufunc_jps(anp.sinh, (lambda ans, x: anp.cosh(x), 'mul' )) +def_ufunc_jps(anp.cosh, (lambda ans, x: anp.sinh(x), 'mul' )) +def_ufunc_jps(anp.tanh, (lambda ans, x: 1 - ans**2, 'mul' )) +def_ufunc_jps(anp.arcsinh, (lambda ans, x: anp.sqrt(x**2 + 1), 'div' )) +def_ufunc_jps(anp.arccosh, (lambda ans, x: anp.sqrt(x**2 - 1), 'div' )) +def_ufunc_jps(anp.arctanh, (lambda ans, x: 1 - x**2, 'div' )) +def_ufunc_jps(anp.square, (lambda ans, x: 2 * x, 'mul' )) +def_ufunc_jps(anp.sqrt, (lambda ans, x: 2 * ans, 'div' )) +def_ufunc_jps(anp.sinc, (lambda ans, x: (anp.cos(anp.pi*x)-ans)/x, 'mul' )) +def_ufunc_jps(anp.real_if_close, (None, 'cid')) +def_ufunc_jps(anp.real, (None, 'cid')) +def_ufunc_jps(anp.imag, (lambda ans, x: -1j, 'cmul')) +def_ufunc_jps(anp.conj, (None, 'same')) +def_ufunc_jps(anp.conjugate, (None, 'same')) +def_ufunc_jps(anp.angle, (lambda ans, x: anp.conj(x * 1j)/anp.abs(x)**2, 'cmul')) # ----- Binary ufuncs ----- -def_ufunc_derivs(anp.add, *repeat((None, 'id'), 2)) -def_ufunc_derivs(anp.subtract, (None, 'id'), (None, 'neg')) -def_ufunc_derivs(anp.divide, (None, 'same'), (lambda ans, x, y: -ans/y, 'mul')) -def_ufunc_derivs(anp.maximum, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), - (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) -def_ufunc_derivs(anp.minimum, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), - (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) -def_ufunc_derivs(anp.fmax, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), - (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) -def_ufunc_derivs(anp.fmin, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), - (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) -def_ufunc_derivs(anp.logaddexp, (lambda ans, x, y: anp.exp(x-ans), 'mul'), - (lambda ans, x, y: anp.exp(y-ans), 'mul')) -def_ufunc_derivs(anp.logaddexp2, (lambda ans, x, y: 2**(x-ans), 'mul'), - (lambda ans, x, y: 2**(y-ans), 'mul')) -def_ufunc_derivs(anp.true_divide, (None, 'same'), (lambda ans, x, y: -ans/y, 'mul')) -def_ufunc_derivs(anp.mod, (None, 'id'), (lambda ans, x, y: -anp.floor(x/y), 'mul')) -def_ufunc_derivs(anp.remainder, (None, 'id'), (lambda ans, x, y: -anp.floor(x/y), 'mul')) -def_ufunc_derivs(anp.power, (lambda ans, x, y: y * x ** anp.where(y, y - 1, 1.), 'mul'), - (lambda ans, x, y: anp.log(replace_zero(x, 1.)) * x ** y, 'mul')) +def_ufunc_jps(anp.add, *repeat((None, 'id'), 2)) +def_ufunc_jps(anp.subtract, (None, 'id'), (None, 'neg')) +def_ufunc_jps(anp.multiply, *repeat((None, 'same'), 2)) +def_ufunc_jps(anp.divide, (None, 'same'), (lambda ans, x, y: -ans/y, 'mul')) +def_ufunc_jps(anp.maximum, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), + (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) +def_ufunc_jps(anp.minimum, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), + (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) +def_ufunc_jps(anp.fmax, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), + (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) +def_ufunc_jps(anp.fmin, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), + (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) +def_ufunc_jps(anp.logaddexp, (lambda ans, x, y: anp.exp(x-ans), 'mul'), + (lambda ans, x, y: anp.exp(y-ans), 'mul')) +def_ufunc_jps(anp.logaddexp2, (lambda ans, x, y: 2**(x-ans), 'mul'), + (lambda ans, x, y: 2**(y-ans), 'mul')) +def_ufunc_jps(anp.true_divide, (None, 'same'), (lambda ans, x, y: -ans/y, 'mul')) +def_ufunc_jps(anp.mod, (None, 'id'), (lambda ans, x, y: -anp.floor(x/y), 'mul')) +def_ufunc_jps(anp.remainder, (None, 'id'), (lambda ans, x, y: -anp.floor(x/y), 'mul')) +def_ufunc_jps(anp.power, (lambda ans, x, y: y * x ** anp.where(y, y - 1, 1.), 'mul'), + (lambda ans, x, y: anp.log(replace_zero(x, 1.)) * x ** y, 'mul')) # ----- Simple grads (linear) ----- -defjvp(anp.negative, 'same') -defjvp(anp.rad2deg, 'same') -defjvp(anp.degrees, 'same') -defjvp(anp.deg2rad, 'same') -defjvp(anp.radians, 'same') defjvp(anp.reshape, 'same') defjvp(anp.roll, 'same') defjvp(anp.array_split, 'same') @@ -109,39 +152,7 @@ def ufunc_vjp(argnum, deriv, op): def_linear(anp.cross) # ----- Simple grads ----- -defjvp(anp.abs, - lambda g, ans, x : anp.real(g * replace_zero(anp.conj(x), 0.)) / replace_zero(ans, 1.)) -defjvp(anp.fabs, lambda g, ans, x : anp.sign(x) * g) # fabs doesn't take complex numbers. -defjvp(anp.absolute, lambda g, ans, x : anp.real(g * anp.conj(x)) / ans) -defjvp(anp.reciprocal, lambda g, ans, x : - g / x**2) -# defjvp(anp.exp, lambda g, ans, x : ans * g) -defjvp(anp.exp2, lambda g, ans, x : ans * anp.log(2) * g) -defjvp(anp.expm1, lambda g, ans, x : (ans + 1) * g) -defjvp(anp.log, lambda g, ans, x : g / x) -defjvp(anp.log2, lambda g, ans, x : g / x / anp.log(2)) -defjvp(anp.log10, lambda g, ans, x : g / x / anp.log(10)) -defjvp(anp.log1p, lambda g, ans, x : g / (x + 1)) -defjvp(anp.sin, lambda g, ans, x : g * anp.cos(x)) -defjvp(anp.cos, lambda g, ans, x : - g * anp.sin(x)) -defjvp(anp.tan, lambda g, ans, x : g / anp.cos(x) **2) -defjvp(anp.arcsin, lambda g, ans, x : g / anp.sqrt(1 - x**2)) -defjvp(anp.arccos, lambda g, ans, x :-g / anp.sqrt(1 - x**2)) -defjvp(anp.arctan, lambda g, ans, x : g / (1 + x**2)) -defjvp(anp.sinh, lambda g, ans, x : g * anp.cosh(x)) -defjvp(anp.cosh, lambda g, ans, x : g * anp.sinh(x)) -defjvp(anp.tanh, lambda g, ans, x : g / anp.cosh(x) **2) -defjvp(anp.arcsinh, lambda g, ans, x : g / anp.sqrt(x**2 + 1)) -defjvp(anp.arccosh, lambda g, ans, x : g / anp.sqrt(x**2 - 1)) -defjvp(anp.arctanh, lambda g, ans, x : g / (1 - x**2)) -defjvp(anp.square, lambda g, ans, x : g * 2 * x) -defjvp(anp.sqrt, lambda g, ans, x : g * 0.5 * x**-0.5) -defjvp(anp.sinc, lambda g, ans, x : g * (anp.cos(anp.pi*x)*anp.pi*x - anp.sin(anp.pi*x))/(anp.pi*x**2)) defjvp(anp.clip, lambda g, ans, x, a_min, a_max : g * anp.logical_and(ans != a_min, ans != a_max)) -defjvp(anp.real_if_close, lambda g, ans, x : match_complex(ans, g)) -defjvp(anp.real, lambda g, ans, x : anp.real(g)) -defjvp(anp.imag, lambda g, ans, x : match_complex(ans, -1j * g)) -defjvp(anp.conj, lambda g, ans, x : anp.conj(g)) -defjvp(anp.angle, lambda g, ans, x : match_complex(ans, g * anp.conj(x * 1j) / anp.abs(x)**2)) defjvp(anp.where, None, lambda g, ans, c, x=None, y=None : anp.where(c, g, anp.zeros(anp.shape(g))), lambda g, ans, c, x=None, y=None : anp.where(c, anp.zeros(g.shape), g)) diff --git a/autograd/numpy/numpy_vjps.py b/autograd/numpy/numpy_vjps.py index c5fe2fd2b..f7e6d6e01 100644 --- a/autograd/numpy/numpy_vjps.py +++ b/autograd/numpy/numpy_vjps.py @@ -12,45 +12,9 @@ defvjp(anp.nan_to_num, lambda ans, x: lambda g: anp.where(anp.isfinite(x), g, 0.)) -# ----- Binary ufuncs ----- - -defvjp(anp.multiply, lambda ans, x, y : unbroadcast_f(x, lambda g: y * g), - lambda ans, x, y : unbroadcast_f(y, lambda g: x * g)) # ----- Simple grads ----- -defvjp(anp.negative, lambda ans, x: lambda g: -g) -defvjp(anp.abs, - lambda ans, x : lambda g: g * replace_zero(anp.conj(x), 0.) / replace_zero(ans, 1.)) -defvjp(anp.fabs, lambda ans, x : lambda g: anp.sign(x) * g) # fabs doesn't take complex numbers. -defvjp(anp.absolute, lambda ans, x : lambda g: g * anp.conj(x) / ans) -defvjp(anp.reciprocal, lambda ans, x : lambda g: - g / x**2) -# defvjp(anp.exp, lambda ans, x : lambda g: ans * g) -defvjp(anp.exp2, lambda ans, x : lambda g: ans * anp.log(2) * g) -defvjp(anp.expm1, lambda ans, x : lambda g: (ans + 1) * g) -defvjp(anp.log, lambda ans, x : lambda g: g / x) -defvjp(anp.log2, lambda ans, x : lambda g: g / x / anp.log(2)) -defvjp(anp.log10, lambda ans, x : lambda g: g / x / anp.log(10)) -defvjp(anp.log1p, lambda ans, x : lambda g: g / (x + 1)) -defvjp(anp.sin, lambda ans, x : lambda g: g * anp.cos(x)) -defvjp(anp.cos, lambda ans, x : lambda g: - g * anp.sin(x)) -defvjp(anp.tan, lambda ans, x : lambda g: g / anp.cos(x) **2) -defvjp(anp.arcsin, lambda ans, x : lambda g: g / anp.sqrt(1 - x**2)) -defvjp(anp.arccos, lambda ans, x : lambda g:-g / anp.sqrt(1 - x**2)) -defvjp(anp.arctan, lambda ans, x : lambda g: g / (1 + x**2)) -defvjp(anp.sinh, lambda ans, x : lambda g: g * anp.cosh(x)) -defvjp(anp.cosh, lambda ans, x : lambda g: g * anp.sinh(x)) -defvjp(anp.tanh, lambda ans, x : lambda g: g / anp.cosh(x) **2) -defvjp(anp.arcsinh, lambda ans, x : lambda g: g / anp.sqrt(x**2 + 1)) -defvjp(anp.arccosh, lambda ans, x : lambda g: g / anp.sqrt(x**2 - 1)) -defvjp(anp.arctanh, lambda ans, x : lambda g: g / (1 - x**2)) -defvjp(anp.rad2deg, lambda ans, x : lambda g: g / anp.pi * 180.0) -defvjp(anp.degrees, lambda ans, x : lambda g: g / anp.pi * 180.0) -defvjp(anp.deg2rad, lambda ans, x : lambda g: g * anp.pi / 180.0) -defvjp(anp.radians, lambda ans, x : lambda g: g * anp.pi / 180.0) -defvjp(anp.square, lambda ans, x : lambda g: g * 2 * x) -defvjp(anp.sqrt, lambda ans, x : lambda g: g * 0.5 * x**-0.5) -defvjp(anp.sinc, lambda ans, x : lambda g: g * (anp.cos(anp.pi*x)*anp.pi*x - anp.sin(anp.pi*x))/(anp.pi*x**2)) defvjp(anp.reshape, lambda ans, x, shape, order=None : lambda g: anp.reshape(g, anp.shape(x), order=order)) defvjp(anp.roll, lambda ans, x, shift, axis=None : lambda g: anp.roll(g, -shift, axis=axis)) defvjp(anp.array_split, lambda ans, ary, idxs, axis=0 : lambda g: anp.concatenate(g, axis=axis)) @@ -76,12 +40,6 @@ anp.moveaxis(g, destination, source)) defvjp(anp.rollaxis, lambda ans, a, axis, start=0: lambda g: anp.rollaxis(g, start - 1, axis) if start > axis else anp.rollaxis(g, start, axis + 1)) -defvjp(anp.real_if_close, lambda ans, x : lambda g: match_complex(x, g)) -defvjp(anp.real, lambda ans, x : lambda g: match_complex(x, g)) -defvjp(anp.imag, lambda ans, x : lambda g: match_complex(x, -1j * g)) -defvjp(anp.conj, lambda ans, x : lambda g: anp.conj(g)) -defvjp(anp.conjugate, lambda ans, x: lambda g: anp.conj(g)) -defvjp(anp.angle, lambda ans, x : lambda g: match_complex(x, g * anp.conj(x * 1j) / anp.abs(x)**2)) defvjp(anp.where, None, lambda ans, c, x=None, y=None : lambda g: anp.where(c, g, anp.zeros(g.shape)), lambda ans, c, x=None, y=None : lambda g: anp.where(c, anp.zeros(g.shape), g)) From 0e18a7eef552e4d81d14a03536474d2f1422e53e Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Wed, 18 Oct 2017 11:39:52 +0200 Subject: [PATCH 05/32] Use broadcast_to not broadcast, also refactor unbroadcast. I think this obviates the changes in #292. --- autograd/numpy/numpy_jvps.py | 27 +++++++++------------------ autograd/numpy/numpy_vjps.py | 28 +++++++++++++++++++++++----- 2 files changed, 32 insertions(+), 23 deletions(-) diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py index 34bc1ba49..94a1bb019 100644 --- a/autograd/numpy/numpy_jvps.py +++ b/autograd/numpy/numpy_jvps.py @@ -2,7 +2,7 @@ from . import numpy_wrapper as anp from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero, dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0, - tensordot_adjoint_1, unbroadcast_f) + tensordot_adjoint_1, unbroadcast_f, _broadcast_to_adjoint) from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace, defvjp) from ..util import func, subval @@ -34,12 +34,12 @@ def def_ufunc_jps(ufunc, *derivs_ops): 'same': (lambda argnum, deriv: lambda g, ans, *args: ufunc(*subval(args, argnum, g)), lambda argnum, deriv: lambda ans, *args: unbroadcast_f(args[argnum], lambda g: ufunc(*subval(args, argnum, g)))), - 'id': (lambda argnum, deriv: lambda g, ans, *args: broadcast(g, ans), + 'id': (lambda argnum, deriv: lambda g, ans, *args: match_complex(ans, anp.broadcast_to(g, ans.shape)), lambda argnum, deriv: lambda ans, *args: - unbroadcast_f(args[argnum], lambda g: g)), - 'neg': (lambda argnum, deriv: lambda g, ans, *args: broadcast(-g, ans), + unbroadcast_f(args[argnum], lambda g: match_complex(args[argnum], g))), + 'neg': (lambda argnum, deriv: lambda g, ans, *args: match_complex(ans, anp.broadcast_to(-g, ans.shape)), lambda argnum, deriv: lambda ans, *args: - unbroadcast_f(args[argnum], lambda g: -g)), + unbroadcast_f(args[argnum], lambda g: match_complex(args[argnum], -g))), 'mul': (lambda argnum, deriv: lambda g, ans, *args: g * deriv(ans, *args), lambda argnum, deriv: lambda ans, *args: unbroadcast_f(args[argnum], lambda g: g * deriv(ans, *args))), @@ -52,6 +52,9 @@ def def_ufunc_jps(ufunc, *derivs_ops): defvjp(ufunc, *[binary_ufunc_jps[op][1](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)]) +defjvp(anp.broadcast_to, 'same') +defjvp(_broadcast_to_adjoint, 'same') + defjvp(func(ArrayBox.__getitem__), 'same') defjvp(untake, 'same') @@ -75,7 +78,7 @@ def def_ufunc_jps(ufunc, *derivs_ops): def_ufunc_jps(anp.reciprocal, (lambda ans, x: -ans**2, 'mul' )) def_ufunc_jps(anp.exp, (lambda ans, x: ans, 'mul' )) def_ufunc_jps(anp.exp2, (lambda ans, x: ans * anp.log(2), 'mul' )) -def_ufunc_jps(anp.expm1, (lambda ans, x: (ans + 1), 'mul' )) +def_ufunc_jps(anp.expm1, (lambda ans, x: ans + 1, 'mul' )) def_ufunc_jps(anp.log, (lambda ans, x: x, 'div' )) def_ufunc_jps(anp.log2, (lambda ans, x: x * anp.log(2), 'div' )) def_ufunc_jps(anp.log10, (lambda ans, x: x * anp.log(10), 'div' )) @@ -261,15 +264,3 @@ def jvp(g, ans, *arys): defjvp(anp.atleast_3d, atleast_jvpmaker(anp.atleast_3d)) def_linear(anp.einsum) - -# TODO(mattjj): can we call np.broadcast_to or a related function instead? -def broadcast(x, target): - target_shape, target_ndim, target_dtype, target_iscomplex = anp.metadata(target) - while anp.ndim(x) < target_ndim: - x = anp.expand_dims(x, 0) - for axis, size in enumerate(anp.shape(x)): - if size == 1: - x = anp.repeat(x, target_shape[axis], axis=axis) - if target_iscomplex and not anp.iscomplexobj(x): - x = x + 0j # TODO(mattjj): this might promote the dtype - return x diff --git a/autograd/numpy/numpy_vjps.py b/autograd/numpy/numpy_vjps.py index f7e6d6e01..110c385d7 100644 --- a/autograd/numpy/numpy_vjps.py +++ b/autograd/numpy/numpy_vjps.py @@ -460,11 +460,7 @@ def match_complex(target, x): def unbroadcast(x, target_meta, broadcast_idx=0): target_shape, target_ndim, dtype, target_iscomplex = target_meta - while anp.ndim(x) > target_ndim: - x = anp.sum(x, axis=broadcast_idx) - for axis, size in enumerate(target_shape): - if size == 1: - x = anp.sum(x, axis=axis, keepdims=True) + x = _broadcast_to_adjoint(x, target_shape) if anp.iscomplexobj(x) and not target_iscomplex: x = anp.real(x) return x @@ -483,6 +479,28 @@ def unbroadcast_einsum(x, target_meta, subscript): else: return unbroadcast(x, target_meta, subscript.index(Ellipsis)) +@primitive +def _broadcast_to_adjoint(x, shape): + while anp.ndim(x) > len(shape): + x = onp.sum(x, axis=0) + for axis, size in enumerate(shape): + if size == 1: + x = onp.sum(x, axis=axis, keepdims=True) + return x + +def _broadcast_to_vjpmaker(x_shape): + # Ensure that x can be garbage collected by only passing + # its shape to this closure. + return lambda g: _broadcast_to_adjoint(g, x_shape) + +def _broadcast_to_adjoint_vjpmaker(g_shape): + # Ensure that g can be garbage collected by only passing + # its shape to this closure. + return lambda x: anp.broadcast_to(x, g_shape) + +defvjp(anp.broadcast_to, lambda ans, x, ans_shp: _broadcast_to_vjpmaker(x.shape)) +defvjp(_broadcast_to_adjoint, lambda ans, g, ans_shp: _broadcast_to_adjoint_vjpmaker(g.shape)) + def balanced_eq(x, z, y): return (x == z) / (1.0 + (x == y)) From b753db79944448d0737e12b61fa59cf4655c44c3 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Wed, 18 Oct 2017 12:17:49 +0200 Subject: [PATCH 06/32] Ensure that deriv is evaluated during fwd pass This should minimize memory overhead. --- autograd/numpy/numpy_jvps.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py index 34bc1ba49..ec063e9c3 100644 --- a/autograd/numpy/numpy_jvps.py +++ b/autograd/numpy/numpy_jvps.py @@ -16,11 +16,11 @@ def def_ufunc_jps(ufunc, *derivs_ops): 'same': (lambda deriv: lambda g, ans, x: ufunc(g), lambda deriv: lambda ans, x: ufunc), 'mul' : (lambda deriv: lambda g, ans, x: g * deriv(ans, x), - lambda deriv: lambda ans, x: lambda g: g * deriv(ans, x)), + lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): g * d), 'div' : (lambda deriv: lambda g, ans, x: g / deriv(ans, x), - lambda deriv: lambda ans, x: lambda g: g / deriv(ans, x)), + lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): g / d), 'cmul': (lambda deriv: lambda g, ans, x: match_complex(ans, g * deriv(ans, x)), - lambda deriv: lambda ans, x: lambda g: match_complex(x , g * deriv(ans, x))), + lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): match_complex(x, g * d)), 'cid': (lambda deriv: lambda g, ans, x: match_complex(ans, g), lambda deriv: lambda ans, x: lambda g: match_complex(x , g)) } @@ -42,10 +42,10 @@ def def_ufunc_jps(ufunc, *derivs_ops): unbroadcast_f(args[argnum], lambda g: -g)), 'mul': (lambda argnum, deriv: lambda g, ans, *args: g * deriv(ans, *args), lambda argnum, deriv: lambda ans, *args: - unbroadcast_f(args[argnum], lambda g: g * deriv(ans, *args))), + unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g * d)), 'div': (lambda argnum, deriv: lambda g, ans, *args: g / deriv(ans, *args), lambda argnum, deriv: lambda ans, *args: - unbroadcast_f(args[argnum], lambda g: g / deriv(ans, *args))) + unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g / d)) } if len(derivs_ops) == 2: defjvp(ufunc, *[binary_ufunc_jps[op][0](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)]) From 5104a90722631dad4ff7103358369d09ad1a087f Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Wed, 18 Oct 2017 12:31:40 +0200 Subject: [PATCH 07/32] Re-add accidentally deleted imports --- autograd/numpy/numpy_jvps.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py index f7799cbec..4387236ea 100644 --- a/autograd/numpy/numpy_jvps.py +++ b/autograd/numpy/numpy_jvps.py @@ -6,6 +6,9 @@ from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace, JVPNode, register_notrace, defvjp) +from ..util import func +from .numpy_boxes import ArrayBox + for fun in nograd_functions: register_notrace(JVPNode, fun) From b1e5e5454a56f182c21e32cf7d603bb2a7412d60 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Wed, 18 Oct 2017 17:12:01 +0200 Subject: [PATCH 08/32] Re-add subval import --- autograd/numpy/numpy_jvps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py index 4387236ea..b35848f46 100644 --- a/autograd/numpy/numpy_jvps.py +++ b/autograd/numpy/numpy_jvps.py @@ -6,7 +6,7 @@ from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace, JVPNode, register_notrace, defvjp) -from ..util import func +from ..util import func, subval from .numpy_boxes import ArrayBox for fun in nograd_functions: From cd5662b7cf51ac4121578c8580f3a7673c1cba86 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Wed, 25 Oct 2017 11:34:14 +0100 Subject: [PATCH 09/32] Rm unnecessary brackets --- autograd/numpy/numpy_jvps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py index b35848f46..b06ed2886 100644 --- a/autograd/numpy/numpy_jvps.py +++ b/autograd/numpy/numpy_jvps.py @@ -77,7 +77,7 @@ def def_ufunc_jps(ufunc, *derivs_ops): def_ufunc_jps(anp.reciprocal, (lambda ans, x: -ans**2, 'mul' )) def_ufunc_jps(anp.exp, (lambda ans, x: ans, 'mul' )) def_ufunc_jps(anp.exp2, (lambda ans, x: ans * anp.log(2), 'mul' )) -def_ufunc_jps(anp.expm1, (lambda ans, x: (ans + 1), 'mul' )) +def_ufunc_jps(anp.expm1, (lambda ans, x: ans + 1, 'mul' )) def_ufunc_jps(anp.log, (lambda ans, x: x, 'div' )) def_ufunc_jps(anp.log2, (lambda ans, x: x * anp.log(2), 'div' )) def_ufunc_jps(anp.log10, (lambda ans, x: x * anp.log(10), 'div' )) From d305a45e497f98e1604621338e487dae6b228cf0 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Thu, 26 Oct 2017 13:04:00 +0100 Subject: [PATCH 10/32] Refactor - new numpy.util module for def_ufunc_jps Also move broadcast_to_adjoint into numpy_wrapper. --- autograd/numpy/fft.py | 2 +- autograd/numpy/numpy_jvps.py | 53 ++-------------- autograd/numpy/numpy_vjps.py | 39 ++---------- autograd/numpy/numpy_wrapper.py | 9 +++ autograd/numpy/util.py | 67 +++++++++++++++++++++ autograd/scipy/stats/multivariate_normal.py | 2 +- autograd/scipy/stats/norm.py | 2 +- autograd/scipy/stats/t.py | 2 +- 8 files changed, 90 insertions(+), 86 deletions(-) create mode 100644 autograd/numpy/util.py diff --git a/autograd/numpy/fft.py b/autograd/numpy/fft.py index 0d9d27050..280ccfb28 100644 --- a/autograd/numpy/fft.py +++ b/autograd/numpy/fft.py @@ -1,7 +1,7 @@ from __future__ import absolute_import import numpy.fft as ffto from .numpy_wrapper import wrap_namespace -from .numpy_vjps import match_complex +from .util import match_complex from . import numpy_wrapper as anp from autograd.extend import primitive, defvjp, vspace from builtins import zip diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py index 44688fdc1..7ee7c5f5b 100644 --- a/autograd/numpy/numpy_jvps.py +++ b/autograd/numpy/numpy_jvps.py @@ -1,62 +1,19 @@ from itertools import repeat from . import numpy_wrapper as anp -from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero, - dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0, - tensordot_adjoint_1, nograd_functions, unbroadcast_f, - _broadcast_to_adjoint) +from .numpy_vjps import (untake, balanced_eq, replace_zero, dot_adjoint_0, dot_adjoint_1, + tensordot_adjoint_0, tensordot_adjoint_1, nograd_functions) from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace, JVPNode, register_notrace, defvjp) +from .util import def_ufunc_jps -from ..util import func, subval +from ..util import func from .numpy_boxes import ArrayBox for fun in nograd_functions: register_notrace(JVPNode, fun) -def def_ufunc_jps(ufunc, *derivs_ops): - derivs_ops = list(derivs_ops) - - unary_ufunc_jps = { - 'same': (lambda deriv: lambda g, ans, x: ufunc(g), - lambda deriv: lambda ans, x: ufunc), - 'mul' : (lambda deriv: lambda g, ans, x: g * deriv(ans, x), - lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): g * d), - 'div' : (lambda deriv: lambda g, ans, x: g / deriv(ans, x), - lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): g / d), - 'cmul': (lambda deriv: lambda g, ans, x: match_complex(ans, g * deriv(ans, x)), - lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): match_complex(x, g * d)), - 'cid': (lambda deriv: lambda g, ans, x: match_complex(ans, g), - lambda deriv: lambda ans, x: lambda g: match_complex(x , g)) - } - - if len(derivs_ops) == 1: - deriv, op = derivs_ops[0] - defjvp(ufunc, unary_ufunc_jps[op][0](deriv)) - defvjp(ufunc, unary_ufunc_jps[op][1](deriv)) - - binary_ufunc_jps = { - 'same': (lambda argnum, deriv: lambda g, ans, *args: ufunc(*subval(args, argnum, g)), - lambda argnum, deriv: lambda ans, *args: - unbroadcast_f(args[argnum], lambda g: ufunc(*subval(args, argnum, g)))), - 'id': (lambda argnum, deriv: lambda g, ans, *args: match_complex(ans, anp.broadcast_to(g, ans.shape)), - lambda argnum, deriv: lambda ans, *args: - unbroadcast_f(args[argnum], lambda g: match_complex(args[argnum], g))), - 'neg': (lambda argnum, deriv: lambda g, ans, *args: match_complex(ans, anp.broadcast_to(-g, ans.shape)), - lambda argnum, deriv: lambda ans, *args: - unbroadcast_f(args[argnum], lambda g: match_complex(args[argnum], -g))), - 'mul': (lambda argnum, deriv: lambda g, ans, *args: g * deriv(ans, *args), - lambda argnum, deriv: lambda ans, *args: - unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g * d)), - 'div': (lambda argnum, deriv: lambda g, ans, *args: g / deriv(ans, *args), - lambda argnum, deriv: lambda ans, *args: - unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g / d)) - } - if len(derivs_ops) == 2: - defjvp(ufunc, *[binary_ufunc_jps[op][0](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)]) - defvjp(ufunc, *[binary_ufunc_jps[op][1](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)]) - defjvp(anp.broadcast_to, 'same') -defjvp(_broadcast_to_adjoint, 'same') +defjvp(anp._broadcast_to_adjoint, 'same') defjvp(func(ArrayBox.__getitem__), 'same') defjvp(untake, 'same') diff --git a/autograd/numpy/numpy_vjps.py b/autograd/numpy/numpy_vjps.py index a6fbd2617..886482214 100644 --- a/autograd/numpy/numpy_vjps.py +++ b/autograd/numpy/numpy_vjps.py @@ -4,9 +4,10 @@ import numpy as onp from ..util import func from . import numpy_wrapper as anp +from autograd.numpy.util import unbroadcast from .numpy_boxes import ArrayBox -from autograd.extend import (primitive, vspace, defvjp, defvjp_argnum, - SparseObject, VJPNode, register_notrace) +from autograd.extend import (primitive, vspace, defvjp, defvjp_argnum, SparseObject, VJPNode, + register_notrace) # ----- Non-differentiable functions ----- @@ -464,27 +465,6 @@ def vjp(g): lambda ans, D, offset=0, axis1=0, axis2=1 : lambda g: anp.diagonal(g, offset, axis1, axis2)) -def match_complex(target, x): - target_iscomplex = anp.iscomplexobj(target) - x_iscomplex = anp.iscomplexobj(x) - if x_iscomplex and not target_iscomplex: - return anp.real(x) - elif not x_iscomplex and target_iscomplex: - return x + 0j - else: - return x - -def unbroadcast(x, target_meta, broadcast_idx=0): - target_shape, target_ndim, dtype, target_iscomplex = target_meta - x = _broadcast_to_adjoint(x, target_shape) - if anp.iscomplexobj(x) and not target_iscomplex: - x = anp.real(x) - return x - -def unbroadcast_f(target, f): - target_meta = anp.metadata(target) - return lambda g: unbroadcast(f(g), target_meta) - def unbroadcast_einsum(x, target_meta, subscript): if Ellipsis not in subscript: return x @@ -495,19 +475,10 @@ def unbroadcast_einsum(x, target_meta, subscript): else: return unbroadcast(x, target_meta, subscript.index(Ellipsis)) -@primitive -def _broadcast_to_adjoint(x, shape): - while anp.ndim(x) > len(shape): - x = onp.sum(x, axis=0) - for axis, size in enumerate(shape): - if size == 1: - x = onp.sum(x, axis=axis, keepdims=True) - return x - def _broadcast_to_vjpmaker(x_shape): # Ensure that x can be garbage collected by only passing # its shape to this closure. - return lambda g: _broadcast_to_adjoint(g, x_shape) + return lambda g: anp._broadcast_to_adjoint(g, x_shape) def _broadcast_to_adjoint_vjpmaker(g_shape): # Ensure that g can be garbage collected by only passing @@ -515,7 +486,7 @@ def _broadcast_to_adjoint_vjpmaker(g_shape): return lambda x: anp.broadcast_to(x, g_shape) defvjp(anp.broadcast_to, lambda ans, x, ans_shp: _broadcast_to_vjpmaker(x.shape)) -defvjp(_broadcast_to_adjoint, lambda ans, g, ans_shp: _broadcast_to_adjoint_vjpmaker(g.shape)) +defvjp(anp._broadcast_to_adjoint, lambda ans, g, ans_shp: _broadcast_to_adjoint_vjpmaker(g.shape)) def balanced_eq(x, z, y): return (x == z) / (1.0 + (x == y)) diff --git a/autograd/numpy/numpy_wrapper.py b/autograd/numpy/numpy_wrapper.py index 98f0abb4f..e681b6883 100644 --- a/autograd/numpy/numpy_wrapper.py +++ b/autograd/numpy/numpy_wrapper.py @@ -133,3 +133,12 @@ def metadata(A): @notrace_primitive def parse_einsum_input(*args): return _parse_einsum_input(args) + +@primitive +def _broadcast_to_adjoint(x, shape): + while _np.ndim(x) > len(shape): + x = _np.sum(x, axis=0) + for axis, size in enumerate(shape): + if size == 1: + x = _np.sum(x, axis=axis, keepdims=True) + return x diff --git a/autograd/numpy/util.py b/autograd/numpy/util.py new file mode 100644 index 000000000..fe000f27e --- /dev/null +++ b/autograd/numpy/util.py @@ -0,0 +1,67 @@ +from . import numpy_wrapper as anp +from autograd.core import defjvp, defvjp +from autograd.util import subval + +def match_complex(target, x): + target_iscomplex = anp.iscomplexobj(target) + x_iscomplex = anp.iscomplexobj(x) + if x_iscomplex and not target_iscomplex: + return anp.real(x) + elif not x_iscomplex and target_iscomplex: + return x + 0j + else: + return x + +def unbroadcast(x, target_meta, broadcast_idx=0): + target_shape, target_ndim, dtype, target_iscomplex = target_meta + x = anp._broadcast_to_adjoint(x, target_shape) + if anp.iscomplexobj(x) and not target_iscomplex: + x = anp.real(x) + return x + +def unbroadcast_f(target, f): + target_meta = anp.metadata(target) + return lambda g: unbroadcast(f(g), target_meta) + +def def_ufunc_jps(ufunc, *derivs_ops): + derivs_ops = list(derivs_ops) + + unary_ufunc_jps = { + 'same': (lambda deriv: lambda g, ans, x: ufunc(g), + lambda deriv: lambda ans, x: ufunc), + 'mul' : (lambda deriv: lambda g, ans, x: g * deriv(ans, x), + lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): g * d), + 'div' : (lambda deriv: lambda g, ans, x: g / deriv(ans, x), + lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): g / d), + 'cmul': (lambda deriv: lambda g, ans, x: match_complex(ans, g * deriv(ans, x)), + lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): match_complex(x, g * d)), + 'cid': (lambda deriv: lambda g, ans, x: match_complex(ans, g), + lambda deriv: lambda ans, x: lambda g: match_complex(x , g)) + } + + if len(derivs_ops) == 1: + deriv, op = derivs_ops[0] + defjvp(ufunc, unary_ufunc_jps[op][0](deriv)) + defvjp(ufunc, unary_ufunc_jps[op][1](deriv)) + + binary_ufunc_jps = { + 'same': (lambda argnum, deriv: lambda g, ans, *args: ufunc(*subval(args, argnum, g)), + lambda argnum, deriv: lambda ans, *args: + unbroadcast_f(args[argnum], lambda g: ufunc(*subval(args, argnum, g)))), + 'id': (lambda argnum, deriv: lambda g, ans, *args: match_complex(ans, anp.broadcast_to(g, ans.shape)), + lambda argnum, deriv: lambda ans, *args: + unbroadcast_f(args[argnum], lambda g: match_complex(args[argnum], g))), + 'neg': (lambda argnum, deriv: lambda g, ans, *args: match_complex(ans, anp.broadcast_to(-g, ans.shape)), + lambda argnum, deriv: lambda ans, *args: + unbroadcast_f(args[argnum], lambda g: match_complex(args[argnum], -g))), + 'mul': (lambda argnum, deriv: lambda g, ans, *args: g * deriv(ans, *args), + lambda argnum, deriv: lambda ans, *args: + unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g * d)), + 'div': (lambda argnum, deriv: lambda g, ans, *args: g / deriv(ans, *args), + lambda argnum, deriv: lambda ans, *args: + unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g / d)) + } + if len(derivs_ops) == 2: + defjvp(ufunc, *[binary_ufunc_jps[op][0](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)]) + defvjp(ufunc, *[binary_ufunc_jps[op][1](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)]) + diff --git a/autograd/scipy/stats/multivariate_normal.py b/autograd/scipy/stats/multivariate_normal.py index fc4924971..b3a32ed04 100644 --- a/autograd/scipy/stats/multivariate_normal.py +++ b/autograd/scipy/stats/multivariate_normal.py @@ -2,7 +2,7 @@ import scipy.stats import autograd.numpy as np -from autograd.numpy.numpy_vjps import unbroadcast_f +from autograd.numpy.util import unbroadcast_f from autograd.extend import primitive, defvjp diff --git a/autograd/scipy/stats/norm.py b/autograd/scipy/stats/norm.py index 5fa88a29f..41b2fe717 100644 --- a/autograd/scipy/stats/norm.py +++ b/autograd/scipy/stats/norm.py @@ -3,7 +3,7 @@ import scipy.stats import autograd.numpy as anp from autograd.extend import primitive, defvjp -from autograd.numpy.numpy_vjps import unbroadcast_f +from autograd.numpy.util import unbroadcast_f pdf = primitive(scipy.stats.norm.pdf) cdf = primitive(scipy.stats.norm.cdf) diff --git a/autograd/scipy/stats/t.py b/autograd/scipy/stats/t.py index 66453aa5c..dc5f31c59 100644 --- a/autograd/scipy/stats/t.py +++ b/autograd/scipy/stats/t.py @@ -3,7 +3,7 @@ import scipy.stats import autograd.numpy as np from autograd.extend import primitive, defvjp -from autograd.numpy.numpy_vjps import unbroadcast_f +from autograd.numpy.util import unbroadcast_f from autograd.scipy.special import psi pdf = primitive(scipy.stats.t.pdf) From a86d0b6d6cde541d38d4af0ebbb4886bf4ea27f9 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Thu, 26 Oct 2017 14:33:53 +0100 Subject: [PATCH 11/32] Rename binary->nary ufunc jps --- autograd/numpy/util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/autograd/numpy/util.py b/autograd/numpy/util.py index fe000f27e..564b91e45 100644 --- a/autograd/numpy/util.py +++ b/autograd/numpy/util.py @@ -44,7 +44,7 @@ def def_ufunc_jps(ufunc, *derivs_ops): defjvp(ufunc, unary_ufunc_jps[op][0](deriv)) defvjp(ufunc, unary_ufunc_jps[op][1](deriv)) - binary_ufunc_jps = { + nary_ufunc_jps = { 'same': (lambda argnum, deriv: lambda g, ans, *args: ufunc(*subval(args, argnum, g)), lambda argnum, deriv: lambda ans, *args: unbroadcast_f(args[argnum], lambda g: ufunc(*subval(args, argnum, g)))), @@ -62,6 +62,6 @@ def def_ufunc_jps(ufunc, *derivs_ops): unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g / d)) } if len(derivs_ops) == 2: - defjvp(ufunc, *[binary_ufunc_jps[op][0](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)]) - defvjp(ufunc, *[binary_ufunc_jps[op][1](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)]) + defjvp(ufunc, *[nary_ufunc_jps[op][0](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)]) + defvjp(ufunc, *[nary_ufunc_jps[op][1](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)]) From 21ac8c892817bdfede5a3deed773f7ccf5652f48 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Thu, 26 Oct 2017 15:05:04 +0100 Subject: [PATCH 12/32] stats.norm jps to new format --- autograd/numpy/util.py | 2 +- autograd/scipy/stats/norm.py | 46 +++++++++++++----------------------- tests/test_scipy.py | 18 +++++++------- 3 files changed, 27 insertions(+), 39 deletions(-) diff --git a/autograd/numpy/util.py b/autograd/numpy/util.py index 564b91e45..ce476f4d7 100644 --- a/autograd/numpy/util.py +++ b/autograd/numpy/util.py @@ -61,7 +61,7 @@ def def_ufunc_jps(ufunc, *derivs_ops): lambda argnum, deriv: lambda ans, *args: unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g / d)) } - if len(derivs_ops) == 2: + if len(derivs_ops) >= 2: defjvp(ufunc, *[nary_ufunc_jps[op][0](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)]) defvjp(ufunc, *[nary_ufunc_jps[op][1](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)]) diff --git a/autograd/scipy/stats/norm.py b/autograd/scipy/stats/norm.py index 41b2fe717..ca58de98b 100644 --- a/autograd/scipy/stats/norm.py +++ b/autograd/scipy/stats/norm.py @@ -3,41 +3,29 @@ import scipy.stats import autograd.numpy as anp from autograd.extend import primitive, defvjp -from autograd.numpy.util import unbroadcast_f +from autograd.numpy.util import def_ufunc_jps pdf = primitive(scipy.stats.norm.pdf) cdf = primitive(scipy.stats.norm.cdf) logpdf = primitive(scipy.stats.norm.logpdf) logcdf = primitive(scipy.stats.norm.logcdf) -defvjp(pdf, - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(x, lambda g: -g * ans * (x - loc) / scale**2), - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(loc, lambda g: g * ans * (x - loc) / scale**2), - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(scale, lambda g: g * ans * (((x - loc)/scale)**2 - 1.0)/scale)) +def_ufunc_jps(pdf, + (lambda ans, x, loc=0.0, scale=1.0: -ans * (x - loc) / scale**2, 'mul'), + (lambda ans, x, loc=0.0, scale=1.0: ans * (x - loc) / scale**2, 'mul'), + (lambda ans, x, loc=0.0, scale=1.0: ans * (((x - loc)/scale)**2 - 1.0)/scale, 'mul')) -defvjp(cdf, - lambda ans, x, loc=-1.0, scale=1.0: - unbroadcast_f(x, lambda g: g * pdf(x, loc, scale)) , - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(loc, lambda g: -g * pdf(x, loc, scale)), - lambda ans, x, loc=-1.0, scale=1.0: - unbroadcast_f(scale, lambda g: -g * pdf(x, loc, scale)*(x-loc)/scale)) +def_ufunc_jps(logpdf, + (lambda ans, x, loc=0.0, scale=1.0: -(x - loc) / scale**2, 'mul'), + (lambda ans, x, loc=0.0, scale=1.0: (x - loc) / scale**2, 'mul'), + (lambda ans, x, loc=0.0, scale=1.0: -1.0/scale + (x - loc)**2/scale**3, 'mul')) -defvjp(logpdf, - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(x, lambda g: -g * (x - loc) / scale**2), - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(loc, lambda g: g * (x - loc) / scale**2), - lambda ans, x, loc=-1.0, scale=1.0: - unbroadcast_f(scale, lambda g: g * (-1.0/scale + (x - loc)**2/scale**3))) +def_ufunc_jps(cdf, + (lambda ans, x, loc=0.0, scale=1.0: pdf(x, loc, scale), 'mul'), + (lambda ans, x, loc=0.0, scale=1.0: -pdf(x, loc, scale), 'mul'), + (lambda ans, x, loc=0.0, scale=1.0: -pdf(x, loc, scale)*(x-loc)/scale, 'mul')) -defvjp(logcdf, - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(x, lambda g: g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale))), - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(loc, lambda g:-g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale))), - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(scale, lambda g:-g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale))*(x-loc)/scale)) +def_ufunc_jps(logcdf, + (lambda ans, x, loc=0.0, scale=1.0: anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale)), 'mul'), + (lambda ans, x, loc=0.0, scale=1.0:-anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale)), 'mul'), + (lambda ans, x, loc=0.0, scale=1.0:-anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale))*(x-loc)/scale, 'mul')) diff --git a/tests/test_scipy.py b/tests/test_scipy.py index d483e1840..3f01bcd3a 100644 --- a/tests/test_scipy.py +++ b/tests/test_scipy.py @@ -26,15 +26,15 @@ check_grads = partial(check_grads, modes=['rev']) ### Stats ### -def test_norm_pdf(): combo_check(stats.norm.pdf, [0,1,2])([R(4)], [R(4)], [R(4)**2 + 1.1]) -def test_norm_cdf(): combo_check(stats.norm.cdf, [0,1,2])([R(4)], [R(4)], [R(4)**2 + 1.1]) -def test_norm_logpdf(): combo_check(stats.norm.logpdf, [0,1,2])([R(4)], [R(4)], [R(4)**2 + 1.1]) -def test_norm_logcdf(): combo_check(stats.norm.logcdf, [0,1,2])([R(4)], [R(4)], [R(4)**2 + 1.1]) - -def test_norm_pdf_broadcast(): combo_check(stats.norm.pdf, [0,1,2])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) -def test_norm_cdf_broadcast(): combo_check(stats.norm.cdf, [0,1,2])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) -def test_norm_logpdf_broadcast(): combo_check(stats.norm.logpdf, [0,1,2])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) -def test_norm_logcdf_broadcast(): combo_check(stats.norm.logcdf, [0,1,2])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) +def test_norm_pdf(): combo_check(stats.norm.pdf, [0,1,2], modes=['fwd', 'rev'])([R(4)], [R(4)], [R(4)**2 + 1.1]) +def test_norm_cdf(): combo_check(stats.norm.cdf, [0,1,2], modes=['fwd', 'rev'])([R(4)], [R(4)], [R(4)**2 + 1.1]) +def test_norm_logpdf(): combo_check(stats.norm.logpdf, [0,1,2], modes=['fwd', 'rev'])([R(4)], [R(4)], [R(4)**2 + 1.1]) +def test_norm_logcdf(): combo_check(stats.norm.logcdf, [0,1,2], modes=['fwd', 'rev'])([R(4)], [R(4)], [R(4)**2 + 1.1]) + +def test_norm_pdf_broadcast(): combo_check(stats.norm.pdf, [0,1,2], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) +def test_norm_cdf_broadcast(): combo_check(stats.norm.cdf, [0,1,2], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) +def test_norm_logpdf_broadcast(): combo_check(stats.norm.logpdf, [0,1,2], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) +def test_norm_logcdf_broadcast(): combo_check(stats.norm.logcdf, [0,1,2], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) def test_t_pdf(): combo_check(stats.t.pdf, [0,1,2,3])([R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) def test_t_cdf(): combo_check(stats.t.cdf, [0,2])( [R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) From 62c3e5415e8cacc2c2377d35644c5ca7877bf0d0 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Thu, 26 Oct 2017 15:20:16 +0100 Subject: [PATCH 13/32] Add possibility for None jp to nary ufuncs --- autograd/numpy/util.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/autograd/numpy/util.py b/autograd/numpy/util.py index ce476f4d7..ca58a1db8 100644 --- a/autograd/numpy/util.py +++ b/autograd/numpy/util.py @@ -62,6 +62,8 @@ def def_ufunc_jps(ufunc, *derivs_ops): unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g / d)) } if len(derivs_ops) >= 2: - defjvp(ufunc, *[nary_ufunc_jps[op][0](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)]) - defvjp(ufunc, *[nary_ufunc_jps[op][1](argnum, deriv) for argnum, (deriv, op) in enumerate(derivs_ops)]) + defjvp(ufunc, *[nary_ufunc_jps[deriv_op[1]][0](argnum, deriv_op[0]) + if deriv_op is not None else None for argnum, deriv_op in enumerate(derivs_ops)]) + defvjp(ufunc, *[nary_ufunc_jps[deriv_op[1]][1](argnum, deriv_op[0]) + if deriv_op is not None else None for argnum, deriv_op in enumerate(derivs_ops)]) From 585f3cb365943d6bff590f8f92704b96e440c93d Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Thu, 26 Oct 2017 15:31:06 +0100 Subject: [PATCH 14/32] stats.t jps to new format (and psi, polygamma) --- autograd/scipy/special.py | 5 ++-- autograd/scipy/stats/t.py | 48 ++++++++++++++++----------------------- tests/test_scipy.py | 20 ++++++++-------- 3 files changed, 32 insertions(+), 41 deletions(-) diff --git a/autograd/scipy/special.py b/autograd/scipy/special.py index 9768b0ea8..76194a33a 100644 --- a/autograd/scipy/special.py +++ b/autograd/scipy/special.py @@ -1,6 +1,7 @@ from __future__ import absolute_import import scipy.special import autograd.numpy as np +from autograd.numpy.util import def_ufunc_jps from autograd.extend import primitive, defvjp ### Gamma functions ### @@ -14,8 +15,8 @@ multigammaln = primitive(scipy.special.multigammaln) defvjp(gammasgn, None) -defvjp(polygamma, None, lambda ans, n, x: lambda g: g * polygamma(n + 1, x)) -defvjp(psi, lambda ans, x: lambda g: g * polygamma(1, x)) +def_ufunc_jps(polygamma, None, (lambda ans, n, x: polygamma(n + 1, x), 'mul')) +def_ufunc_jps(psi, (lambda ans, x: polygamma(1, x), 'mul')) defvjp(digamma, lambda ans, x: lambda g: g * polygamma(1, x)) defvjp(gamma, lambda ans, x: lambda g: g * ans * psi(x)) defvjp(gammaln, lambda ans, x: lambda g: g * psi(x)) diff --git a/autograd/scipy/stats/t.py b/autograd/scipy/stats/t.py index dc5f31c59..7bb6b3cf2 100644 --- a/autograd/scipy/stats/t.py +++ b/autograd/scipy/stats/t.py @@ -3,7 +3,7 @@ import scipy.stats import autograd.numpy as np from autograd.extend import primitive, defvjp -from autograd.numpy.util import unbroadcast_f +from autograd.numpy.util import unbroadcast_f, def_ufunc_jps from autograd.scipy.special import psi pdf = primitive(scipy.stats.t.pdf) @@ -24,34 +24,24 @@ def grad_tlogpdf_df(x, df, loc, scale): y = (x - loc)/scale return 0.5 * ((y**2 * (df+1))/(df * (y**2 + df)) - np.log(y**2 / df + 1) - 1.0/df -psi(df/2.0) + psi((df + 1)/2.0)) -defvjp(pdf, lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(x, lambda g: g * ans * grad_tlogpdf_x( x, df, loc, scale)), - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(df, lambda g: g * ans * grad_tlogpdf_df( x, df, loc, scale)), - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(loc, lambda g: g * ans * grad_tlogpdf_loc( x, df, loc, scale)), - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(scale, lambda g: g * ans * grad_tlogpdf_scale(x, df, loc, scale))) +def_ufunc_jps(pdf, + (lambda ans, x, df, loc=0.0, scale=1.0: ans * grad_tlogpdf_x( x, df, loc, scale), 'mul'), + (lambda ans, x, df, loc=0.0, scale=1.0: ans * grad_tlogpdf_df( x, df, loc, scale), 'mul'), + (lambda ans, x, df, loc=0.0, scale=1.0: ans * grad_tlogpdf_loc( x, df, loc, scale), 'mul'), + (lambda ans, x, df, loc=0.0, scale=1.0: ans * grad_tlogpdf_scale(x, df, loc, scale), 'mul')) -defvjp(cdf, - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(x, lambda g: g * pdf(x, df, loc, scale)), - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(loc, lambda g: -g * pdf(x, df, loc, scale)), argnums=(0,2)) +def_ufunc_jps(logpdf, + (lambda ans, x, df, loc=0.0, scale=1.0: grad_tlogpdf_x( x, df, loc, scale), 'mul'), + (lambda ans, x, df, loc=0.0, scale=1.0: grad_tlogpdf_df( x, df, loc, scale), 'mul'), + (lambda ans, x, df, loc=0.0, scale=1.0: grad_tlogpdf_loc( x, df, loc, scale), 'mul'), + (lambda ans, x, df, loc=0.0, scale=1.0: grad_tlogpdf_scale(x, df, loc, scale), 'mul')) -defvjp(logpdf, - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(x, lambda g: g * grad_tlogpdf_x( x, df, loc, scale)), - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(df, lambda g: g * grad_tlogpdf_df( x, df, loc, scale)), - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(loc, lambda g: g * grad_tlogpdf_loc( x, df, loc, scale)), - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(scale, lambda g: g * grad_tlogpdf_scale(x, df, loc, scale))) +def_ufunc_jps(cdf, + (lambda ans, x, df, loc=0.0, scale=1.0: pdf(x, df, loc, scale), 'mul'), + None, + (lambda ans, x, df, loc=0.0, scale=1.0: -pdf(x, df, loc, scale), 'mul')) -defvjp(logcdf, - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(x, lambda g: g * np.exp(logpdf(x, df, loc, scale) - logcdf(x, df, loc, scale))), - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(loc, lambda g: -g * np.exp(logpdf(x, df, loc, scale) - logcdf(x, df, loc, scale))), - argnums=(0,2)) +def_ufunc_jps(logcdf, + (lambda ans, x, df, loc=0.0, scale=1.0: np.exp(logpdf(x, df, loc, scale) - logcdf(x, df, loc, scale)), 'mul'), + None, + (lambda ans, x, df, loc=0.0, scale=1.0: -np.exp(logpdf(x, df, loc, scale) - logcdf(x, df, loc, scale)), 'mul')) diff --git a/tests/test_scipy.py b/tests/test_scipy.py index 3f01bcd3a..1a25204e1 100644 --- a/tests/test_scipy.py +++ b/tests/test_scipy.py @@ -36,15 +36,15 @@ def test_norm_cdf_broadcast(): combo_check(stats.norm.cdf, [0,1,2], modes= def test_norm_logpdf_broadcast(): combo_check(stats.norm.logpdf, [0,1,2], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) def test_norm_logcdf_broadcast(): combo_check(stats.norm.logcdf, [0,1,2], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) -def test_t_pdf(): combo_check(stats.t.pdf, [0,1,2,3])([R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) -def test_t_cdf(): combo_check(stats.t.cdf, [0,2])( [R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) -def test_t_logpdf(): combo_check(stats.t.logpdf, [0,1,2,3])([R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) -def test_t_logcdf(): combo_check(stats.t.logcdf, [0,2])( [R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) +def test_t_pdf(): combo_check(stats.t.pdf, [0,1,2,3], modes=['fwd', 'rev'])([R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) +def test_t_cdf(): combo_check(stats.t.cdf, [0,2], modes=['fwd', 'rev'])( [R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) +def test_t_logpdf(): combo_check(stats.t.logpdf, [0,1,2,3], modes=['fwd', 'rev'])([R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) +def test_t_logcdf(): combo_check(stats.t.logcdf, [0,2], modes=['fwd', 'rev'])( [R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) -def test_t_pdf_broadcast(): combo_check(stats.t.pdf, [0,1,2,3])([R(4,3)], [R(1,3)**2 + 2.1], [R(4,3)], [R(4,1)**2 + 2.1]) -def test_t_cdf_broadcast(): combo_check(stats.t.cdf, [0,2])( [R(4,3)], [R(1,3)**2 + 2.1], [R(4,3)], [R(4,1)**2 + 2.1]) -def test_t_logpdf_broadcast(): combo_check(stats.t.logpdf, [0,1,2,3])([R(4,3)], [R(1,3)**2 + 2.1], [R(4,3)], [R(4,1)**2 + 2.1]) -def test_t_logcdf_broadcast(): combo_check(stats.t.logcdf, [0,2])( [R(4,3)], [R(1,3)**2 + 2.1], [R(4,3)], [R(4,1)**2 + 2.1]) +def test_t_pdf_broadcast(): combo_check(stats.t.pdf, [0,1,2,3], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)**2 + 2.1], [R(4,3)], [R(4,1)**2 + 2.1]) +def test_t_cdf_broadcast(): combo_check(stats.t.cdf, [0,2], modes=['fwd', 'rev'])( [R(4,3)], [R(1,3)**2 + 2.1], [R(4,3)], [R(4,1)**2 + 2.1]) +def test_t_logpdf_broadcast(): combo_check(stats.t.logpdf, [0,1,2,3], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)**2 + 2.1], [R(4,3)], [R(4,1)**2 + 2.1]) +def test_t_logcdf_broadcast(): combo_check(stats.t.logcdf, [0,2], modes=['fwd', 'rev'])( [R(4,3)], [R(1,3)**2 + 2.1], [R(4,3)], [R(4,1)**2 + 2.1]) def make_psd(mat): return np.dot(mat.T, mat) + np.eye(mat.shape[0]) def test_mvn_pdf(): combo_check(mvn.pdf, [0, 1, 2])([R(4)], [R(4)], [make_psd(R(4, 4))], allow_singular=[False]) @@ -129,11 +129,11 @@ def test_convolve_ignore_dot(): axes=[([1],[1])], dot_axes=[([0],[2]), ([0],[0])], mode=['full', 'valid']) ### Special ### -def test_polygamma(): combo_check(special.polygamma, [1])([0], R(4)**2 + 1.3) +def test_polygamma(): combo_check(special.polygamma, [1], modes=['fwd', 'rev'])([0], R(4)**2 + 1.3) def test_jn(): combo_check(special.jn, [1])([2], R(4)**2 + 1.3) def test_yn(): combo_check(special.yn, [1])([2], R(4)**2 + 1.3) -def test_psi(): unary_ufunc_check(special.psi, lims=[0.3, 2.0], test_complex=False) +def test_psi(): unary_ufunc_check(special.psi, lims=[0.3, 2.0], test_complex=False, modes=['fwd', 'rev']) def test_digamma(): unary_ufunc_check(special.digamma, lims=[0.3, 2.0], test_complex=False) def test_gamma(): unary_ufunc_check(special.gamma, lims=[0.3, 2.0], test_complex=False) def test_gammaln(): unary_ufunc_check(special.gammaln, lims=[0.3, 2.0], test_complex=False) From a29c7707ec8f7e17a768da74914ac2c4281f2664 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Thu, 26 Oct 2017 16:03:20 +0100 Subject: [PATCH 15/32] Scipy special to new format --- autograd/numpy/util.py | 10 +++++++--- autograd/scipy/special.py | 41 +++++++++++++++++++-------------------- tests/test_scipy.py | 36 +++++++++++++++++----------------- 3 files changed, 45 insertions(+), 42 deletions(-) diff --git a/autograd/numpy/util.py b/autograd/numpy/util.py index ca58a1db8..6e7b2c207 100644 --- a/autograd/numpy/util.py +++ b/autograd/numpy/util.py @@ -40,9 +40,13 @@ def def_ufunc_jps(ufunc, *derivs_ops): } if len(derivs_ops) == 1: - deriv, op = derivs_ops[0] - defjvp(ufunc, unary_ufunc_jps[op][0](deriv)) - defvjp(ufunc, unary_ufunc_jps[op][1](deriv)) + if derivs_ops[0] is None: + defjvp(ufunc, None) + defvjp(ufunc, None) + else: + deriv, op = derivs_ops[0] + defjvp(ufunc, unary_ufunc_jps[op][0](deriv)) + defvjp(ufunc, unary_ufunc_jps[op][1](deriv)) nary_ufunc_jps = { 'same': (lambda argnum, deriv: lambda g, ans, *args: ufunc(*subval(args, argnum, g)), diff --git a/autograd/scipy/special.py b/autograd/scipy/special.py index 76194a33a..20c4232c2 100644 --- a/autograd/scipy/special.py +++ b/autograd/scipy/special.py @@ -14,15 +14,15 @@ rgamma = primitive(scipy.special.rgamma) multigammaln = primitive(scipy.special.multigammaln) -defvjp(gammasgn, None) +def_ufunc_jps(gammasgn, None) def_ufunc_jps(polygamma, None, (lambda ans, n, x: polygamma(n + 1, x), 'mul')) -def_ufunc_jps(psi, (lambda ans, x: polygamma(1, x), 'mul')) -defvjp(digamma, lambda ans, x: lambda g: g * polygamma(1, x)) -defvjp(gamma, lambda ans, x: lambda g: g * ans * psi(x)) -defvjp(gammaln, lambda ans, x: lambda g: g * psi(x)) -defvjp(rgamma, lambda ans, x: lambda g: g * psi(x) / -gamma(x)) -defvjp(multigammaln,lambda ans, a, d: lambda g: - g * np.sum(digamma(np.expand_dims(a, -1) - np.arange(d)/2.), -1), +def_ufunc_jps(psi, (lambda ans, x: polygamma(1, x), 'mul')) +def_ufunc_jps(digamma, (lambda ans, x: polygamma(1, x), 'mul')) +def_ufunc_jps(gamma, (lambda ans, x: ans * psi(x), 'mul')) +def_ufunc_jps(gammaln, (lambda ans, x: psi(x), 'mul')) +def_ufunc_jps(rgamma, (lambda ans, x: psi(x) / -gamma(x), 'mul')) +def_ufunc_jps(multigammaln, (lambda ans, a, d: + np.sum(digamma(np.expand_dims(a, -1) - np.arange(d)/2.), -1), 'mul'), None) ### Bessel functions ### @@ -33,33 +33,32 @@ jn = primitive(scipy.special.jn) yn = primitive(scipy.special.yn) -defvjp(j0,lambda ans, x: lambda g: -g * j1(x)) -defvjp(y0,lambda ans, x: lambda g: -g * y1(x)) -defvjp(j1,lambda ans, x: lambda g: g * (j0(x) - jn(2, x)) / 2.0) -defvjp(y1,lambda ans, x: lambda g: g * (y0(x) - yn(2, x)) / 2.0) -defvjp(jn, None, lambda ans, n, x: lambda g: g * (jn(n - 1, x) - jn(n + 1, x)) / 2.0) -defvjp(yn, None, lambda ans, n, x: lambda g: g * (yn(n - 1, x) - yn(n + 1, x)) / 2.0) +def_ufunc_jps(j0, (lambda ans, x: -j1(x), 'mul')) +def_ufunc_jps(y0, (lambda ans, x: -y1(x), 'mul')) +def_ufunc_jps(j1, (lambda ans, x: (j0(x) - jn(2, x)) / 2.0, 'mul')) +def_ufunc_jps(y1, (lambda ans, x: (y0(x) - yn(2, x)) / 2.0, 'mul')) +def_ufunc_jps(jn, None, (lambda ans, n, x: (jn(n - 1, x) - jn(n + 1, x)) / 2.0, 'mul')) +def_ufunc_jps(yn, None, (lambda ans, n, x: (yn(n - 1, x) - yn(n + 1, x)) / 2.0, 'mul')) ### Error Function ### inv_root_pi = 0.56418958354775627928 erf = primitive(scipy.special.erf) erfc = primitive(scipy.special.erfc) -defvjp(erf, lambda ans, x: lambda g: 2.*g*inv_root_pi*np.exp(-x**2)) -defvjp(erfc,lambda ans, x: lambda g: -2.*g*inv_root_pi*np.exp(-x**2)) - +def_ufunc_jps(erf, (lambda ans, x: 2.*inv_root_pi*np.exp(-x**2), 'mul')) +def_ufunc_jps(erfc, (lambda ans, x: -2.*inv_root_pi*np.exp(-x**2), 'mul')) ### Inverse error function ### root_pi = 1.7724538509055159 erfinv = primitive(scipy.special.erfinv) erfcinv = primitive(scipy.special.erfcinv) -defvjp(erfinv,lambda ans, x: lambda g: g * root_pi / 2 * np.exp(erfinv(x)**2)) -defvjp(erfcinv,lambda ans, x: lambda g: -g * root_pi / 2 * np.exp(erfcinv(x)**2)) +def_ufunc_jps(erfinv, (lambda ans, x: root_pi / 2 * np.exp(erfinv(x)**2 ), 'mul')) +def_ufunc_jps(erfcinv, (lambda ans, x: -root_pi / 2 * np.exp(erfcinv(x)**2), 'mul')) ### Logit and Expit ### logit = primitive(scipy.special.logit) expit = primitive(scipy.special.expit) -defvjp(logit,lambda ans, x: lambda g: g / ( x * (1 - x))) -defvjp(expit,lambda ans, x: lambda g: g * ans * (1 - ans)) +def_ufunc_jps(logit, (lambda ans, x: x * (1 - x ), 'div')) +def_ufunc_jps(expit, (lambda ans, x: ans * (1 - ans), 'mul')) diff --git a/tests/test_scipy.py b/tests/test_scipy.py index 1a25204e1..ee9aa80bd 100644 --- a/tests/test_scipy.py +++ b/tests/test_scipy.py @@ -130,28 +130,28 @@ def test_convolve_ignore_dot(): ### Special ### def test_polygamma(): combo_check(special.polygamma, [1], modes=['fwd', 'rev'])([0], R(4)**2 + 1.3) -def test_jn(): combo_check(special.jn, [1])([2], R(4)**2 + 1.3) -def test_yn(): combo_check(special.yn, [1])([2], R(4)**2 + 1.3) +def test_jn(): combo_check(special.jn, [1], modes=['fwd', 'rev'])([2], R(4)**2 + 1.3) +def test_yn(): combo_check(special.yn, [1], modes=['fwd', 'rev'])([2], R(4)**2 + 1.3) def test_psi(): unary_ufunc_check(special.psi, lims=[0.3, 2.0], test_complex=False, modes=['fwd', 'rev']) -def test_digamma(): unary_ufunc_check(special.digamma, lims=[0.3, 2.0], test_complex=False) -def test_gamma(): unary_ufunc_check(special.gamma, lims=[0.3, 2.0], test_complex=False) -def test_gammaln(): unary_ufunc_check(special.gammaln, lims=[0.3, 2.0], test_complex=False) -def test_gammasgn(): unary_ufunc_check(special.gammasgn,lims=[0.3, 2.0], test_complex=False) -def test_rgamma() : unary_ufunc_check(special.rgamma, lims=[0.3, 2.0], test_complex=False) -def test_multigammaln(): combo_check(special.multigammaln, [0])([U(4., 5.), U(4., 5., (2,3))], +def test_digamma(): unary_ufunc_check(special.digamma, lims=[0.3, 2.0], test_complex=False, modes=['fwd', 'rev']) +def test_gamma(): unary_ufunc_check(special.gamma, lims=[0.3, 2.0], test_complex=False, modes=['fwd', 'rev']) +def test_gammaln(): unary_ufunc_check(special.gammaln, lims=[0.3, 2.0], test_complex=False, modes=['fwd', 'rev']) +def test_gammasgn(): unary_ufunc_check(special.gammasgn,lims=[0.3, 2.0], test_complex=False, modes=['fwd', 'rev']) +def test_rgamma() : unary_ufunc_check(special.rgamma, lims=[0.3, 2.0], test_complex=False, modes=['fwd', 'rev']) +def test_multigammaln(): combo_check(special.multigammaln, [0], modes=['fwd', 'rev'])([U(4., 5.), U(4., 5., (2,3))], [1, 2, 3]) -def test_j0(): unary_ufunc_check(special.j0, lims=[0.2, 20.0], test_complex=False) -def test_j1(): unary_ufunc_check(special.j1, lims=[0.2, 20.0], test_complex=False) -def test_y0(): unary_ufunc_check(special.y0, lims=[0.2, 20.0], test_complex=False) -def test_y1(): unary_ufunc_check(special.y1, lims=[0.2, 20.0], test_complex=False) +def test_j0(): unary_ufunc_check(special.j0, lims=[0.2, 20.0], test_complex=False, modes=['fwd', 'rev']) +def test_j1(): unary_ufunc_check(special.j1, lims=[0.2, 20.0], test_complex=False, modes=['fwd', 'rev']) +def test_y0(): unary_ufunc_check(special.y0, lims=[0.2, 20.0], test_complex=False, modes=['fwd', 'rev']) +def test_y1(): unary_ufunc_check(special.y1, lims=[0.2, 20.0], test_complex=False, modes=['fwd', 'rev']) -def test_erf(): unary_ufunc_check(special.erf, lims=[-3., 3.], test_complex=True) -def test_erfc(): unary_ufunc_check(special.erfc, lims=[-3., 3.], test_complex=True) +def test_erf(): unary_ufunc_check(special.erf, lims=[-3., 3.], test_complex=True, modes=['fwd', 'rev']) +def test_erfc(): unary_ufunc_check(special.erfc, lims=[-3., 3.], test_complex=True, modes=['fwd', 'rev']) -def test_erfinv(): unary_ufunc_check(special.erfinv, lims=[-0.95, 0.95], test_complex=False) -def test_erfcinv(): unary_ufunc_check(special.erfcinv, lims=[0.05, 1.95], test_complex=False) +def test_erfinv(): unary_ufunc_check(special.erfinv, lims=[-0.95, 0.95], test_complex=False, modes=['fwd', 'rev']) +def test_erfcinv(): unary_ufunc_check(special.erfcinv, lims=[0.05, 1.95], test_complex=False, modes=['fwd', 'rev']) -def test_logit(): unary_ufunc_check(special.logit, lims=[0.05, 0.95], test_complex=False) -def test_expit(): unary_ufunc_check(special.expit, lims=[-4.05, 4.95], test_complex=False) +def test_logit(): unary_ufunc_check(special.logit, lims=[0.05, 0.95], test_complex=False, modes=['fwd', 'rev']) +def test_expit(): unary_ufunc_check(special.expit, lims=[-4.05, 4.95], test_complex=False, modes=['fwd', 'rev']) From e4c719d57213f80c92c474ba1cc7b9d3e869591f Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Thu, 26 Oct 2017 16:59:01 +0100 Subject: [PATCH 16/32] simplify logcdf grads --- autograd/scipy/stats/norm.py | 6 +++--- autograd/scipy/stats/t.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/autograd/scipy/stats/norm.py b/autograd/scipy/stats/norm.py index ca58de98b..ea086ea87 100644 --- a/autograd/scipy/stats/norm.py +++ b/autograd/scipy/stats/norm.py @@ -26,6 +26,6 @@ (lambda ans, x, loc=0.0, scale=1.0: -pdf(x, loc, scale)*(x-loc)/scale, 'mul')) def_ufunc_jps(logcdf, - (lambda ans, x, loc=0.0, scale=1.0: anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale)), 'mul'), - (lambda ans, x, loc=0.0, scale=1.0:-anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale)), 'mul'), - (lambda ans, x, loc=0.0, scale=1.0:-anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale))*(x-loc)/scale, 'mul')) + (lambda ans, x, loc=0.0, scale=1.0: anp.exp(logpdf(x, loc, scale) - ans), 'mul'), + (lambda ans, x, loc=0.0, scale=1.0:-anp.exp(logpdf(x, loc, scale) - ans), 'mul'), + (lambda ans, x, loc=0.0, scale=1.0:-anp.exp(logpdf(x, loc, scale) - ans)*(x-loc)/scale, 'mul')) diff --git a/autograd/scipy/stats/t.py b/autograd/scipy/stats/t.py index 7bb6b3cf2..c68dcf75b 100644 --- a/autograd/scipy/stats/t.py +++ b/autograd/scipy/stats/t.py @@ -42,6 +42,6 @@ def grad_tlogpdf_df(x, df, loc, scale): (lambda ans, x, df, loc=0.0, scale=1.0: -pdf(x, df, loc, scale), 'mul')) def_ufunc_jps(logcdf, - (lambda ans, x, df, loc=0.0, scale=1.0: np.exp(logpdf(x, df, loc, scale) - logcdf(x, df, loc, scale)), 'mul'), + (lambda ans, x, df, loc=0.0, scale=1.0: np.exp(logpdf(x, df, loc, scale) - ans), 'mul'), None, - (lambda ans, x, df, loc=0.0, scale=1.0: -np.exp(logpdf(x, df, loc, scale) - logcdf(x, df, loc, scale)), 'mul')) + (lambda ans, x, df, loc=0.0, scale=1.0: -np.exp(logpdf(x, df, loc, scale) - ans), 'mul')) From 57df4d8d81f8efecc12d399e39710015393b68c4 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Thu, 2 Nov 2017 15:34:51 +0000 Subject: [PATCH 17/32] Re-add missing scipy tests --- tests/test_scipy.py | 37 +++++++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/tests/test_scipy.py b/tests/test_scipy.py index de188adba..4611a1c0a 100644 --- a/tests/test_scipy.py +++ b/tests/test_scipy.py @@ -16,7 +16,9 @@ import autograd.scipy.stats as stats import autograd.scipy.stats.multivariate_normal as mvn import autograd.scipy.special as special + import autograd.scipy.linalg as spla from autograd import grad + from scipy.signal import convolve as sp_convolve from autograd.test_util import combo_check, check_grads from numpy_utils import unary_ufunc_check @@ -30,6 +32,15 @@ unary_ufunc_check = partial(unary_ufunc_check, modes=['rev']) check_grads = partial(check_grads, modes=['rev']) + def symmetrize_matrix_arg(fun, argnum): + def T(X): return np.swapaxes(X, -1, -2) if np.ndim(X) > 1 else X + def symmetrize(X): return 0.5 * (X + T(X)) + def symmetrized_fun(*args, **kwargs): + args = list(args) + args[argnum] = symmetrize(args[argnum]) + return fun(*args, **kwargs) + return symmetrized_fun + ### Stats ### def test_norm_pdf(): combo_check(stats.norm.pdf, [0,1,2], modes=['fwd', 'rev'])([R(4)], [R(4)], [R(4)**2 + 1.1]) def test_norm_cdf(): combo_check(stats.norm.cdf, [0,1,2], modes=['fwd', 'rev'])([R(4)], [R(4)], [R(4)**2 + 1.1]) @@ -41,6 +52,14 @@ def test_norm_cdf_broadcast(): combo_check(stats.norm.cdf, [0,1,2], modes= def test_norm_logpdf_broadcast(): combo_check(stats.norm.logpdf, [0,1,2], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) def test_norm_logcdf_broadcast(): combo_check(stats.norm.logcdf, [0,1,2], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) + def test_poisson_cdf(): combo_check(stats.poisson.cdf, [1])([np.round(R(4)**2)], [R(4)**2 + 1.1]) + def test_poisson_logpmf(): combo_check(stats.poisson.logpmf, [1])([np.round(R(4)**2)], [R(4)**2 + 1.1]) + def test_poisson_pmf(): combo_check(stats.poisson.pmf, [1])([np.round(R(4)**2)], [R(4)**2 + 1.1]) + + def test_poisson_cdf_broadcast(): combo_check(stats.poisson.cdf, [1])([np.round(R(4, 3)**2)], [R(4, 1)**2 + 1.1]) + def test_poisson_logpmf_broadcast(): combo_check(stats.poisson.logpmf, [1])([np.round(R(4, 3)**2)], [R(4, 1)**2 + 1.1]) + def test_poisson_pmf_broadcast(): combo_check(stats.poisson.pmf, [1])([np.round(R(4, 3)**2)], [R(4, 1)**2 + 1.1]) + def test_t_pdf(): combo_check(stats.t.pdf, [0,1,2,3], modes=['fwd', 'rev'])([R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) def test_t_cdf(): combo_check(stats.t.cdf, [0,2], modes=['fwd', 'rev'])( [R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) def test_t_logpdf(): combo_check(stats.t.logpdf, [0,1,2,3], modes=['fwd', 'rev'])([R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) @@ -52,9 +71,9 @@ def test_t_logpdf_broadcast(): combo_check(stats.t.logpdf, [0,1,2,3], modes=['fw def test_t_logcdf_broadcast(): combo_check(stats.t.logcdf, [0,2], modes=['fwd', 'rev'])( [R(4,3)], [R(1,3)**2 + 2.1], [R(4,3)], [R(4,1)**2 + 2.1]) def make_psd(mat): return np.dot(mat.T, mat) + np.eye(mat.shape[0]) - def test_mvn_pdf(): combo_check(mvn.pdf, [0, 1, 2])([R(4)], [R(4)], [make_psd(R(4, 4))], allow_singular=[False]) - def test_mvn_logpdf(): combo_check(mvn.logpdf, [0, 1, 2])([R(4)], [R(4)], [make_psd(R(4, 4))], allow_singular=[False]) - def test_mvn_entropy():combo_check(mvn.entropy,[0, 1])( [R(4)], [make_psd(R(4, 4))]) + def test_mvn_pdf(): combo_check(symmetrize_matrix_arg(mvn.pdf, 2), [0, 1, 2], [R(4)], [R(4)], [make_psd(R(4, 4))], allow_singular=[False]) + def test_mvn_logpdf(): combo_check(symmetrize_matrix_arg(mvn.logpdf, 2), [0, 1, 2], [R(4)], [R(4)], [make_psd(R(4, 4))], allow_singular=[False]) + def test_mvn_entropy():combo_check(mvn.entropy,[0, 1], [R(4)], [make_psd(R(4, 4))]) C = np.zeros((4, 4)) C[0, 0] = C[1, 1] = 1 @@ -62,8 +81,8 @@ def test_mvn_entropy():combo_check(mvn.entropy,[0, 1])( [R(4)], [make_ def test_mvn_pdf_sing_cov(): combo_check(mvn.pdf, [0, 1])([np.concatenate((R(2), np.zeros(2)))], [np.concatenate((R(2), np.zeros(2)))], [C], [True]) def test_mvn_logpdf_sing_cov(): combo_check(mvn.logpdf, [0, 1])([np.concatenate((R(2), np.zeros(2)))], [np.concatenate((R(2), np.zeros(2)))], [C], [True]) - def test_mvn_pdf_broadcast(): combo_check(mvn.pdf, [0, 1, 2])([R(5, 4)], [R(4)], [make_psd(R(4, 4))]) - def test_mvn_logpdf_broadcast(): combo_check(mvn.logpdf, [0, 1, 2])([R(5, 4)], [R(4)], [make_psd(R(4, 4))]) + def test_mvn_pdf_broadcast(): combo_check(symmetrize_matrix_arg(mvn.pdf, 2), [0, 1, 2], [R(5, 4)], [R(4)], [make_psd(R(4, 4))]) + def test_mvn_logpdf_broadcast(): combo_check(symmetrize_matrix_arg(mvn.logpdf, 2), [0, 1, 2], [R(5, 4)], [R(4)], [make_psd(R(4, 4))]) alpha = npr.random(4)**2 + 1.2 x = stats.dirichlet.rvs(alpha, size=1)[0,:] @@ -111,7 +130,7 @@ def test_convolve_generalization(): assert npo.allclose(ag_convolve(A_2543, A_24232, axes=([1, 2],[2, 4]), dot_axes=([0, 3], [0, 3]), mode=mode)[2], sum([sum([sp_convolve(A_2543[i, :, :, j], - A_24232[i, 2, :, j, :], mode) + A_24232[i, 2, :, j, :], mode) for i in range(2)]) for j in range(3)])) def test_convolve(): @@ -134,6 +153,8 @@ def test_convolve_ignore_dot(): axes=[([1],[1])], dot_axes=[([0],[2]), ([0],[0])], mode=['full', 'valid']) ### Special ### + def test_gammainc(): combo_check(special.gammainc, [1])([1], R(4)**2 + 1.3) + def test_gammaincc(): combo_check(special.gammaincc, [1])([1], R(4)**2 + 1.3) def test_polygamma(): combo_check(special.polygamma, [1], modes=['fwd', 'rev'])([0], R(4)**2 + 1.3) def test_jn(): combo_check(special.jn, [1], modes=['fwd', 'rev'])([2], R(4)**2 + 1.3) def test_yn(): combo_check(special.yn, [1], modes=['fwd', 'rev'])([2], R(4)**2 + 1.3) @@ -158,5 +179,5 @@ def test_erfc(): unary_ufunc_check(special.erfc, lims=[-3., 3.], test_complex=Tr def test_erfinv(): unary_ufunc_check(special.erfinv, lims=[-0.95, 0.95], test_complex=False, modes=['fwd', 'rev']) def test_erfcinv(): unary_ufunc_check(special.erfcinv, lims=[0.05, 1.95], test_complex=False, modes=['fwd', 'rev']) - def test_logit(): unary_ufunc_check(special.logit, lims=[0.05, 0.95], test_complex=False, modes=['fwd', 'rev']) - def test_expit(): unary_ufunc_check(special.expit, lims=[-4.05, 4.95], test_complex=False, modes=['fwd', 'rev']) \ No newline at end of file + def test_logit(): unary_ufunc_check(special.logit, lims=[0.05, 0.95], test_complex=False, modes=['fwd', 'rev']) + def test_expit(): unary_ufunc_check(special.expit, lims=[-4.05, 4.95], test_complex=False, modes=['fwd', 'rev']) From 8205a3d7572480c02c68f9b38b2cfd2a28eee030 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Thu, 2 Nov 2017 15:47:46 +0000 Subject: [PATCH 18/32] hypot grad to new format --- autograd/numpy/numpy_jvps.py | 2 ++ autograd/scipy/special.py | 2 +- autograd/scipy/stats/poisson.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py index 7ee7c5f5b..35dec93b6 100644 --- a/autograd/numpy/numpy_jvps.py +++ b/autograd/numpy/numpy_jvps.py @@ -87,6 +87,8 @@ def_ufunc_jps(anp.remainder, (None, 'id'), (lambda ans, x, y: -anp.floor(x/y), 'mul')) def_ufunc_jps(anp.power, (lambda ans, x, y: y * x ** anp.where(y, y - 1, 1.), 'mul'), (lambda ans, x, y: anp.log(replace_zero(x, 1.)) * x ** y, 'mul')) +def_ufunc_jps(anp.hypot, (lambda ans, x, y: x / ans, 'mul'), + (lambda ans, x, y: y / ans, 'mul')) # ----- Simple grads (linear) ----- defjvp(anp.reshape, 'same') diff --git a/autograd/scipy/special.py b/autograd/scipy/special.py index 36871cd03..0c81610f2 100644 --- a/autograd/scipy/special.py +++ b/autograd/scipy/special.py @@ -3,7 +3,7 @@ import autograd.numpy as np from autograd.numpy.util import def_ufunc_jps from autograd.extend import primitive, defvjp -from autograd.numpy.numpy_vjps import unbroadcast_f +from autograd.numpy.util import unbroadcast_f ### Gamma functions ### polygamma = primitive(scipy.special.polygamma) diff --git a/autograd/scipy/stats/poisson.py b/autograd/scipy/stats/poisson.py index 381c32c10..b6c57246f 100644 --- a/autograd/scipy/stats/poisson.py +++ b/autograd/scipy/stats/poisson.py @@ -3,7 +3,7 @@ import autograd.numpy as np import scipy.stats from autograd.extend import primitive, defvjp -from autograd.numpy.numpy_vjps import unbroadcast_f +from autograd.numpy.util import unbroadcast_f cdf = primitive(scipy.stats.poisson.cdf) logpmf = primitive(scipy.stats.poisson.logpmf) From d2c737ad8b30b747a1a82c3c0cd3755ed330bf43 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Thu, 2 Nov 2017 16:00:14 +0000 Subject: [PATCH 19/32] New scipy ufuncs to new format --- autograd/scipy/special.py | 7 +++---- autograd/scipy/stats/poisson.py | 8 ++++---- tests/test_scipy.py | 16 ++++++++-------- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/autograd/scipy/special.py b/autograd/scipy/special.py index 0c81610f2..a963938a2 100644 --- a/autograd/scipy/special.py +++ b/autograd/scipy/special.py @@ -30,11 +30,10 @@ def make_gammainc_vjp_arg1(sign): def gammainc_vjp_arg1(ans, a, x): - coeffs = sign * np.exp(-x) * np.power(x, a - 1) / gamma(a) - return unbroadcast_f(x, lambda g: g * coeffs) + return sign * np.exp(-x) * np.power(x, a - 1) / gamma(a) return gammainc_vjp_arg1 -defvjp(gammainc, make_gammainc_vjp_arg1(1), argnums=[1]) -defvjp(gammaincc, make_gammainc_vjp_arg1(-1), argnums=[1]) +def_ufunc_jps(gammainc, None, (make_gammainc_vjp_arg1(1), 'mul')) +def_ufunc_jps(gammaincc, None, (make_gammainc_vjp_arg1(-1), 'mul')) ### Bessel functions ### j0 = primitive(scipy.special.j0) diff --git a/autograd/scipy/stats/poisson.py b/autograd/scipy/stats/poisson.py index b6c57246f..8546602a1 100644 --- a/autograd/scipy/stats/poisson.py +++ b/autograd/scipy/stats/poisson.py @@ -3,7 +3,7 @@ import autograd.numpy as np import scipy.stats from autograd.extend import primitive, defvjp -from autograd.numpy.util import unbroadcast_f +from autograd.numpy.util import def_ufunc_jps cdf = primitive(scipy.stats.poisson.cdf) logpmf = primitive(scipy.stats.poisson.logpmf) @@ -12,6 +12,6 @@ def grad_poisson_logpmf(k, mu): return np.where(k % 1 == 0, k / mu - 1, 0) -defvjp(cdf, lambda ans, k, mu: unbroadcast_f(mu, lambda g: g * -pmf(np.floor(k), mu)), argnums=[1]) -defvjp(logpmf, lambda ans, k, mu: unbroadcast_f(mu, lambda g: g * grad_poisson_logpmf(k, mu)), argnums=[1]) -defvjp(pmf, lambda ans, k, mu: unbroadcast_f(mu, lambda g: g * ans * grad_poisson_logpmf(k, mu)), argnums=[1]) +def_ufunc_jps(cdf, None, (lambda ans, k, mu: -pmf(np.floor(k), mu), 'mul')) +def_ufunc_jps(logpmf, None, (lambda ans, k, mu: grad_poisson_logpmf(k, mu), 'mul')) +def_ufunc_jps(pmf, None, (lambda ans, k, mu: ans * grad_poisson_logpmf(k, mu), 'mul')) diff --git a/tests/test_scipy.py b/tests/test_scipy.py index 4611a1c0a..c3eac4a85 100644 --- a/tests/test_scipy.py +++ b/tests/test_scipy.py @@ -52,13 +52,13 @@ def test_norm_cdf_broadcast(): combo_check(stats.norm.cdf, [0,1,2], modes= def test_norm_logpdf_broadcast(): combo_check(stats.norm.logpdf, [0,1,2], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) def test_norm_logcdf_broadcast(): combo_check(stats.norm.logcdf, [0,1,2], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) - def test_poisson_cdf(): combo_check(stats.poisson.cdf, [1])([np.round(R(4)**2)], [R(4)**2 + 1.1]) - def test_poisson_logpmf(): combo_check(stats.poisson.logpmf, [1])([np.round(R(4)**2)], [R(4)**2 + 1.1]) - def test_poisson_pmf(): combo_check(stats.poisson.pmf, [1])([np.round(R(4)**2)], [R(4)**2 + 1.1]) + def test_poisson_cdf(): combo_check(stats.poisson.cdf, [1], modes=['fwd', 'rev'])([np.round(R(4)**2)], [R(4)**2 + 1.1]) + def test_poisson_logpmf(): combo_check(stats.poisson.logpmf, [1], modes=['fwd', 'rev'])([np.round(R(4)**2)], [R(4)**2 + 1.1]) + def test_poisson_pmf(): combo_check(stats.poisson.pmf, [1], modes=['fwd', 'rev'])([np.round(R(4)**2)], [R(4)**2 + 1.1]) - def test_poisson_cdf_broadcast(): combo_check(stats.poisson.cdf, [1])([np.round(R(4, 3)**2)], [R(4, 1)**2 + 1.1]) - def test_poisson_logpmf_broadcast(): combo_check(stats.poisson.logpmf, [1])([np.round(R(4, 3)**2)], [R(4, 1)**2 + 1.1]) - def test_poisson_pmf_broadcast(): combo_check(stats.poisson.pmf, [1])([np.round(R(4, 3)**2)], [R(4, 1)**2 + 1.1]) + def test_poisson_cdf_broadcast(): combo_check(stats.poisson.cdf, [1], modes=['fwd', 'rev'])([np.round(R(4, 3)**2)], [R(4, 1)**2 + 1.1]) + def test_poisson_logpmf_broadcast(): combo_check(stats.poisson.logpmf, [1], modes=['fwd', 'rev'])([np.round(R(4, 3)**2)], [R(4, 1)**2 + 1.1]) + def test_poisson_pmf_broadcast(): combo_check(stats.poisson.pmf, [1], modes=['fwd', 'rev'])([np.round(R(4, 3)**2)], [R(4, 1)**2 + 1.1]) def test_t_pdf(): combo_check(stats.t.pdf, [0,1,2,3], modes=['fwd', 'rev'])([R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) def test_t_cdf(): combo_check(stats.t.cdf, [0,2], modes=['fwd', 'rev'])( [R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) @@ -153,8 +153,8 @@ def test_convolve_ignore_dot(): axes=[([1],[1])], dot_axes=[([0],[2]), ([0],[0])], mode=['full', 'valid']) ### Special ### - def test_gammainc(): combo_check(special.gammainc, [1])([1], R(4)**2 + 1.3) - def test_gammaincc(): combo_check(special.gammaincc, [1])([1], R(4)**2 + 1.3) + def test_gammainc(): combo_check(special.gammainc, [1], modes=['fwd', 'rev'])([1], R(4)**2 + 1.3) + def test_gammaincc(): combo_check(special.gammaincc, [1], modes=['fwd', 'rev'])([1], R(4)**2 + 1.3) def test_polygamma(): combo_check(special.polygamma, [1], modes=['fwd', 'rev'])([0], R(4)**2 + 1.3) def test_jn(): combo_check(special.jn, [1], modes=['fwd', 'rev'])([2], R(4)**2 + 1.3) def test_yn(): combo_check(special.yn, [1], modes=['fwd', 'rev'])([2], R(4)**2 + 1.3) From 1ba7015b3692fb890fc3299f18ae0cd082bf88c0 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Thu, 2 Nov 2017 16:28:33 +0000 Subject: [PATCH 20/32] Simplify def_ufunc_jps api --- autograd/numpy/numpy_jvps.py | 86 ++++++++++++++++++------------------ autograd/numpy/util.py | 57 +++++++++++++++--------- 2 files changed, 80 insertions(+), 63 deletions(-) diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py index 35dec93b6..1dcfd2f94 100644 --- a/autograd/numpy/numpy_jvps.py +++ b/autograd/numpy/numpy_jvps.py @@ -26,50 +26,50 @@ defjvp(anp.nan_to_num, lambda g, ans, x: anp.where(anp.isfinite(x), g, 0.)) # ----- Unary ufuncs ------ -def_ufunc_jps(anp.negative, (None, 'same')) -def_ufunc_jps(anp.rad2deg, (None, 'same')) -def_ufunc_jps(anp.degrees, (None, 'same')) -def_ufunc_jps(anp.deg2rad, (None, 'same')) -def_ufunc_jps(anp.radians, (None, 'same')) +def_ufunc_jps(anp.negative, 'same') +def_ufunc_jps(anp.rad2deg, 'same') +def_ufunc_jps(anp.degrees, 'same') +def_ufunc_jps(anp.deg2rad, 'same') +def_ufunc_jps(anp.radians, 'same') def_ufunc_jps(anp.abs, (lambda ans, x: replace_zero(anp.conj(x), 0.) / replace_zero(ans, 1.), 'cmul')) -def_ufunc_jps(anp.fabs, (lambda ans, x: anp.sign(x), 'mul')) # fabs doesn't take complex numbers. -def_ufunc_jps(anp.absolute, (lambda ans, x: anp.conj(x) / ans, 'cmul')) -def_ufunc_jps(anp.reciprocal, (lambda ans, x: -ans**2, 'mul' )) -def_ufunc_jps(anp.exp, (lambda ans, x: ans, 'mul' )) -def_ufunc_jps(anp.exp2, (lambda ans, x: ans * anp.log(2), 'mul' )) -def_ufunc_jps(anp.expm1, (lambda ans, x: ans + 1, 'mul' )) -def_ufunc_jps(anp.log, (lambda ans, x: x, 'div' )) -def_ufunc_jps(anp.log2, (lambda ans, x: x * anp.log(2), 'div' )) -def_ufunc_jps(anp.log10, (lambda ans, x: x * anp.log(10), 'div' )) -def_ufunc_jps(anp.log1p, (lambda ans, x: x + 1, 'div' )) -def_ufunc_jps(anp.sin, (lambda ans, x: anp.cos(x), 'mul' )) -def_ufunc_jps(anp.cos, (lambda ans, x: -anp.sin(x), 'mul' )) -def_ufunc_jps(anp.tan, (lambda ans, x: 1 + ans**2, 'mul' )) -def_ufunc_jps(anp.arcsin, (lambda ans, x: anp.sqrt(1 - x**2), 'div' )) -def_ufunc_jps(anp.arccos, (lambda ans, x:-anp.sqrt(1 - x**2), 'div' )) -def_ufunc_jps(anp.arctan, (lambda ans, x: 1 + x**2, 'div' )) -def_ufunc_jps(anp.sinh, (lambda ans, x: anp.cosh(x), 'mul' )) -def_ufunc_jps(anp.cosh, (lambda ans, x: anp.sinh(x), 'mul' )) -def_ufunc_jps(anp.tanh, (lambda ans, x: 1 - ans**2, 'mul' )) -def_ufunc_jps(anp.arcsinh, (lambda ans, x: anp.sqrt(x**2 + 1), 'div' )) -def_ufunc_jps(anp.arccosh, (lambda ans, x: anp.sqrt(x**2 - 1), 'div' )) -def_ufunc_jps(anp.arctanh, (lambda ans, x: 1 - x**2, 'div' )) -def_ufunc_jps(anp.square, (lambda ans, x: 2 * x, 'mul' )) -def_ufunc_jps(anp.sqrt, (lambda ans, x: 2 * ans, 'div' )) -def_ufunc_jps(anp.sinc, (lambda ans, x: (anp.cos(anp.pi*x)-ans)/x, 'mul' )) -def_ufunc_jps(anp.real_if_close, (None, 'cid')) -def_ufunc_jps(anp.real, (None, 'cid')) -def_ufunc_jps(anp.imag, (lambda ans, x: -1j, 'cmul')) -def_ufunc_jps(anp.conj, (None, 'same')) -def_ufunc_jps(anp.conjugate, (None, 'same')) -def_ufunc_jps(anp.angle, (lambda ans, x: anp.conj(x * 1j)/anp.abs(x)**2, 'cmul')) +def_ufunc_jps(anp.fabs, (lambda ans, x: anp.sign(x), 'mul')) # fabs doesn't take complex numbers. +def_ufunc_jps(anp.absolute, (lambda ans, x: anp.conj(x) / ans, 'cmul')) +def_ufunc_jps(anp.reciprocal, (lambda ans, x: -ans**2, 'mul' )) +def_ufunc_jps(anp.exp, (lambda ans, x: ans, 'mul' )) +def_ufunc_jps(anp.exp2, (lambda ans, x: ans * anp.log(2), 'mul' )) +def_ufunc_jps(anp.expm1, (lambda ans, x: ans + 1, 'mul' )) +def_ufunc_jps(anp.log, (lambda ans, x: x, 'div' )) +def_ufunc_jps(anp.log2, (lambda ans, x: x * anp.log(2), 'div' )) +def_ufunc_jps(anp.log10, (lambda ans, x: x * anp.log(10), 'div' )) +def_ufunc_jps(anp.log1p, (lambda ans, x: x + 1, 'div' )) +def_ufunc_jps(anp.sin, (lambda ans, x: anp.cos(x), 'mul' )) +def_ufunc_jps(anp.cos, (lambda ans, x: -anp.sin(x), 'mul' )) +def_ufunc_jps(anp.tan, (lambda ans, x: 1 + ans**2, 'mul' )) +def_ufunc_jps(anp.arcsin, (lambda ans, x: anp.sqrt(1 - x**2), 'div' )) +def_ufunc_jps(anp.arccos, (lambda ans, x:-anp.sqrt(1 - x**2), 'div' )) +def_ufunc_jps(anp.arctan, (lambda ans, x: 1 + x**2, 'div' )) +def_ufunc_jps(anp.sinh, (lambda ans, x: anp.cosh(x), 'mul' )) +def_ufunc_jps(anp.cosh, (lambda ans, x: anp.sinh(x), 'mul' )) +def_ufunc_jps(anp.tanh, (lambda ans, x: 1 - ans**2, 'mul' )) +def_ufunc_jps(anp.arcsinh, (lambda ans, x: anp.sqrt(x**2 + 1), 'div' )) +def_ufunc_jps(anp.arccosh, (lambda ans, x: anp.sqrt(x**2 - 1), 'div' )) +def_ufunc_jps(anp.arctanh, (lambda ans, x: 1 - x**2, 'div' )) +def_ufunc_jps(anp.square, (lambda ans, x: 2 * x, 'mul' )) +def_ufunc_jps(anp.sqrt, (lambda ans, x: 2 * ans, 'div' )) +def_ufunc_jps(anp.sinc, (lambda ans, x: (anp.cos(anp.pi*x)-ans)/x, 'mul' )) +def_ufunc_jps(anp.real_if_close, 'cid') +def_ufunc_jps(anp.real, 'cid') +def_ufunc_jps(anp.imag, (lambda ans, x: -1j, 'cmul')) +def_ufunc_jps(anp.conj, 'same') +def_ufunc_jps(anp.conjugate, 'same') +def_ufunc_jps(anp.angle, (lambda ans, x: anp.conj(x * 1j)/anp.abs(x)**2, 'cmul')) # ----- Binary ufuncs ----- -def_ufunc_jps(anp.add, *repeat((None, 'id'), 2)) -def_ufunc_jps(anp.subtract, (None, 'id'), (None, 'neg')) -def_ufunc_jps(anp.multiply, *repeat((None, 'same'), 2)) -def_ufunc_jps(anp.divide, (None, 'same'), (lambda ans, x, y: -ans/y, 'mul')) +def_ufunc_jps(anp.add, 'id', 'id') +def_ufunc_jps(anp.subtract, 'id', 'neg') +def_ufunc_jps(anp.multiply, 'same', 'same') +def_ufunc_jps(anp.divide, 'same', (lambda ans, x, y: -ans/y, 'mul')) def_ufunc_jps(anp.maximum, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) def_ufunc_jps(anp.minimum, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), @@ -82,9 +82,9 @@ (lambda ans, x, y: anp.exp(y-ans), 'mul')) def_ufunc_jps(anp.logaddexp2, (lambda ans, x, y: 2**(x-ans), 'mul'), (lambda ans, x, y: 2**(y-ans), 'mul')) -def_ufunc_jps(anp.true_divide, (None, 'same'), (lambda ans, x, y: -ans/y, 'mul')) -def_ufunc_jps(anp.mod, (None, 'id'), (lambda ans, x, y: -anp.floor(x/y), 'mul')) -def_ufunc_jps(anp.remainder, (None, 'id'), (lambda ans, x, y: -anp.floor(x/y), 'mul')) +def_ufunc_jps(anp.true_divide, 'same', (lambda ans, x, y: -ans/y, 'mul')) +def_ufunc_jps(anp.mod, 'id', (lambda ans, x, y: -anp.floor(x/y), 'mul')) +def_ufunc_jps(anp.remainder, 'id', (lambda ans, x, y: -anp.floor(x/y), 'mul')) def_ufunc_jps(anp.power, (lambda ans, x, y: y * x ** anp.where(y, y - 1, 1.), 'mul'), (lambda ans, x, y: anp.log(replace_zero(x, 1.)) * x ** y, 'mul')) def_ufunc_jps(anp.hypot, (lambda ans, x, y: x / ans, 'mul'), diff --git a/autograd/numpy/util.py b/autograd/numpy/util.py index 6e7b2c207..c1d852883 100644 --- a/autograd/numpy/util.py +++ b/autograd/numpy/util.py @@ -27,37 +27,47 @@ def def_ufunc_jps(ufunc, *derivs_ops): derivs_ops = list(derivs_ops) unary_ufunc_jps = { - 'same': (lambda deriv: lambda g, ans, x: ufunc(g), - lambda deriv: lambda ans, x: ufunc), + 'same': (lambda g, ans, x: ufunc(g), + lambda ans, x: ufunc), + 'cid': (lambda g, ans, x: match_complex(ans, g), + lambda ans, x: lambda g: match_complex(x , g)) + } + + unary_ufunc_linops = { 'mul' : (lambda deriv: lambda g, ans, x: g * deriv(ans, x), lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): g * d), 'div' : (lambda deriv: lambda g, ans, x: g / deriv(ans, x), lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): g / d), 'cmul': (lambda deriv: lambda g, ans, x: match_complex(ans, g * deriv(ans, x)), lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): match_complex(x, g * d)), - 'cid': (lambda deriv: lambda g, ans, x: match_complex(ans, g), - lambda deriv: lambda ans, x: lambda g: match_complex(x , g)) } if len(derivs_ops) == 1: - if derivs_ops[0] is None: + deriv_op = derivs_ops[0] + if type(deriv_op) is tuple: + deriv, op = deriv_op + defjvp(ufunc, unary_ufunc_linops[op][0](deriv)) + defvjp(ufunc, unary_ufunc_linops[op][1](deriv)) + elif deriv_op is None: defjvp(ufunc, None) defvjp(ufunc, None) else: - deriv, op = derivs_ops[0] - defjvp(ufunc, unary_ufunc_jps[op][0](deriv)) - defvjp(ufunc, unary_ufunc_jps[op][1](deriv)) + defjvp(ufunc, unary_ufunc_jps[deriv_op][0]) + defvjp(ufunc, unary_ufunc_jps[deriv_op][1]) nary_ufunc_jps = { - 'same': (lambda argnum, deriv: lambda g, ans, *args: ufunc(*subval(args, argnum, g)), - lambda argnum, deriv: lambda ans, *args: + 'same': (lambda argnum: lambda g, ans, *args: ufunc(*subval(args, argnum, g)), + lambda argnum: lambda ans, *args: unbroadcast_f(args[argnum], lambda g: ufunc(*subval(args, argnum, g)))), - 'id': (lambda argnum, deriv: lambda g, ans, *args: match_complex(ans, anp.broadcast_to(g, ans.shape)), - lambda argnum, deriv: lambda ans, *args: + 'id': (lambda argnum: lambda g, ans, *args: match_complex(ans, anp.broadcast_to(g, ans.shape)), + lambda argnum: lambda ans, *args: unbroadcast_f(args[argnum], lambda g: match_complex(args[argnum], g))), - 'neg': (lambda argnum, deriv: lambda g, ans, *args: match_complex(ans, anp.broadcast_to(-g, ans.shape)), - lambda argnum, deriv: lambda ans, *args: - unbroadcast_f(args[argnum], lambda g: match_complex(args[argnum], -g))), + 'neg': (lambda argnum: lambda g, ans, *args: match_complex(ans, anp.broadcast_to(-g, ans.shape)), + lambda argnum: lambda ans, *args: + unbroadcast_f(args[argnum], lambda g: match_complex(args[argnum], -g))) + } + + nary_ufunc_linops = { 'mul': (lambda argnum, deriv: lambda g, ans, *args: g * deriv(ans, *args), lambda argnum, deriv: lambda ans, *args: unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g * d)), @@ -65,9 +75,16 @@ def def_ufunc_jps(ufunc, *derivs_ops): lambda argnum, deriv: lambda ans, *args: unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g / d)) } - if len(derivs_ops) >= 2: - defjvp(ufunc, *[nary_ufunc_jps[deriv_op[1]][0](argnum, deriv_op[0]) - if deriv_op is not None else None for argnum, deriv_op in enumerate(derivs_ops)]) - defvjp(ufunc, *[nary_ufunc_jps[deriv_op[1]][1](argnum, deriv_op[0]) - if deriv_op is not None else None for argnum, deriv_op in enumerate(derivs_ops)]) + def deriv_op_to_ufunc_jp(idx, argnum, deriv_op): + if type(deriv_op) is tuple: + deriv, op = deriv_op + return nary_ufunc_linops[op][idx](argnum, deriv) + elif deriv_op is None: + return None + else: + return nary_ufunc_jps[deriv_op][idx](argnum) + + if len(derivs_ops) >= 2: + defjvp(ufunc, *[deriv_op_to_ufunc_jp(0, argnum, deriv_op) for argnum, deriv_op in enumerate(derivs_ops)]) + defvjp(ufunc, *[deriv_op_to_ufunc_jp(1, argnum, deriv_op) for argnum, deriv_op in enumerate(derivs_ops)]) From 394e28b699da6b4090f638ab6f5c678f0defa129 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Thu, 2 Nov 2017 16:46:44 +0000 Subject: [PATCH 21/32] Refactor def_ufunc_jps --- autograd/numpy/util.py | 57 +++++++++++++++++++++++------------------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/autograd/numpy/util.py b/autograd/numpy/util.py index c1d852883..b9ff85cee 100644 --- a/autograd/numpy/util.py +++ b/autograd/numpy/util.py @@ -23,17 +23,15 @@ def unbroadcast_f(target, f): target_meta = anp.metadata(target) return lambda g: unbroadcast(f(g), target_meta) -def def_ufunc_jps(ufunc, *derivs_ops): - derivs_ops = list(derivs_ops) - - unary_ufunc_jps = { +def def_unary_ufunc_jps(ufunc, deriv_op): + jps = { 'same': (lambda g, ans, x: ufunc(g), lambda ans, x: ufunc), 'cid': (lambda g, ans, x: match_complex(ans, g), lambda ans, x: lambda g: match_complex(x , g)) } - unary_ufunc_linops = { + linops = { 'mul' : (lambda deriv: lambda g, ans, x: g * deriv(ans, x), lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): g * d), 'div' : (lambda deriv: lambda g, ans, x: g / deriv(ans, x), @@ -42,20 +40,19 @@ def def_ufunc_jps(ufunc, *derivs_ops): lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): match_complex(x, g * d)), } - if len(derivs_ops) == 1: - deriv_op = derivs_ops[0] - if type(deriv_op) is tuple: - deriv, op = deriv_op - defjvp(ufunc, unary_ufunc_linops[op][0](deriv)) - defvjp(ufunc, unary_ufunc_linops[op][1](deriv)) - elif deriv_op is None: - defjvp(ufunc, None) - defvjp(ufunc, None) - else: - defjvp(ufunc, unary_ufunc_jps[deriv_op][0]) - defvjp(ufunc, unary_ufunc_jps[deriv_op][1]) + if type(deriv_op) is tuple: + deriv, op = deriv_op + defjvp(ufunc, linops[op][0](deriv)) + defvjp(ufunc, linops[op][1](deriv)) + elif deriv_op is None: + defjvp(ufunc, None) + defvjp(ufunc, None) + else: + defjvp(ufunc, jps[deriv_op][0]) + defvjp(ufunc, jps[deriv_op][1]) - nary_ufunc_jps = { +def def_nary_ufunc_jps(ufunc, derivs_ops): + jps = { 'same': (lambda argnum: lambda g, ans, *args: ufunc(*subval(args, argnum, g)), lambda argnum: lambda ans, *args: unbroadcast_f(args[argnum], lambda g: ufunc(*subval(args, argnum, g)))), @@ -67,24 +64,32 @@ def def_ufunc_jps(ufunc, *derivs_ops): unbroadcast_f(args[argnum], lambda g: match_complex(args[argnum], -g))) } - nary_ufunc_linops = { + linops = { 'mul': (lambda argnum, deriv: lambda g, ans, *args: g * deriv(ans, *args), lambda argnum, deriv: lambda ans, *args: unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g * d)), 'div': (lambda argnum, deriv: lambda g, ans, *args: g / deriv(ans, *args), lambda argnum, deriv: lambda ans, *args: - unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g / d)) + unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g / d)) } - def deriv_op_to_ufunc_jp(idx, argnum, deriv_op): + def deriv_op_to_jp(idx, argnum, deriv_op): if type(deriv_op) is tuple: deriv, op = deriv_op - return nary_ufunc_linops[op][idx](argnum, deriv) + return linops[op][idx](argnum, deriv) elif deriv_op is None: return None else: - return nary_ufunc_jps[deriv_op][idx](argnum) + return jps[deriv_op][idx](argnum) + + defjvp(ufunc, *[deriv_op_to_jp(0, argnum, deriv_op) + for argnum, deriv_op in enumerate(derivs_ops)]) + defvjp(ufunc, *[deriv_op_to_jp(1, argnum, deriv_op) + for argnum, deriv_op in enumerate(derivs_ops)]) - if len(derivs_ops) >= 2: - defjvp(ufunc, *[deriv_op_to_ufunc_jp(0, argnum, deriv_op) for argnum, deriv_op in enumerate(derivs_ops)]) - defvjp(ufunc, *[deriv_op_to_ufunc_jp(1, argnum, deriv_op) for argnum, deriv_op in enumerate(derivs_ops)]) +def def_ufunc_jps(ufunc, *derivs_ops): + derivs_ops = list(derivs_ops) + if len(derivs_ops) == 1: + def_unary_ufunc_jps(ufunc, derivs_ops[0]) + elif len(derivs_ops) > 1: + def_nary_ufunc_jps(ufunc, derivs_ops) From 486eceab319b96808642f432cc6e39cc745e0a14 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Thu, 2 Nov 2017 22:49:43 +0000 Subject: [PATCH 22/32] Add docstring to def_ufunc_jps --- autograd/numpy/util.py | 72 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/autograd/numpy/util.py b/autograd/numpy/util.py index b9ff85cee..f45d63b94 100644 --- a/autograd/numpy/util.py +++ b/autograd/numpy/util.py @@ -88,6 +88,78 @@ def deriv_op_to_jp(idx, argnum, deriv_op): for argnum, deriv_op in enumerate(derivs_ops)]) def def_ufunc_jps(ufunc, *derivs_ops): + """ + Specify the derivatives of ufunc. Once this has been done the ufunc will + support both reverse and forward mode differentiation. + + The derivatives can be specified as follows. + + Unary ufuncs + ------------ + If the ufunc is unary (that is, if it takes one array valued argument), + then a single optional argument is required to specify the ufunc's + derivative. + + In the general case, this is done via a pair (deriv, op), where deriv is a + function taking in the output of the ufunc (ans), and its array argument + (x), and returning the derivative of the ufunc. + + Here 'derivative' means the elementwise derivative of the ufunc w.r.t. it's + input. + + For example, for the ufunc np.sin, this is as simple as + >>> def deriv(ans, x): + ... return np.cos(x) + ... + + Sometimes the output of the ufunc is useful, for example the derivative of + np.exp is np.exp, which is identical to ans, so the derivative of np.exp + can be efficiently implemented as + >>> def deriv(ans, x): + ... return ans + ... + + The other element of the pair is `op`, which should usually be set to + 'mul'. However, if the derivative of the ufunc is of the form + 1 / f(ans, x), then you can save some computation by using the pair + (f, 'div') to specify the derivative. The 'div' flags that the gradients + being propagated through this primitive should be divided by the result of + f, not multiplied. + + Some full examples: + >>> def_ufunc_jps(np.sin, (lambda ans, x: np.cos(x), 'mul')) + >>> def_ufunc_jps(np.exp, (lambda ans, x: ans, 'mul')) + >>> def_ufunc_jps(np.log, (lambda ans, x: x, 'div')) + + Special cases + ------------- + If the derivative of the ufunc is a constant, then you don't need to + specify its derivative and you can use just the string 'same' in place of + the pair (deriv, op). This says that its ok to propagate the gradient + through this primitive by applying the ufunc itself to the gradient, and + neither x nor ans are relevant to this computation. + + For example, the derivative of np.negative (which simply negates its + inputs), is -1, so + >>> def_ufunc_jps(np.negative, 'same') + + will correctly set its derivative. + + N-ary ufuncs + ------------ + For ufuncs which take more than one array argument, the derivatives can be + specified by passing one (deriv, op) pair for each argument (you can use + None as a placeholder for args whose derivative you don't wish to define). + + You can use 'same' in exactly the same way as for unary ufuncs, and + additionally you can use 'id' when the derivative w.r.t. an arg is always + equal to 1, and 'neg' when it's always equal to -1. + + Some examples: + >>> def_ufunc_jps(anp.divide, 'same', (lambda ans, x, y: -ans/y, 'mul')) + >>> def_ufunc_jps(anp.add, 'id', 'id') + >>> def_ufunc_jps(anp.subtract, 'id', 'neg') + """ derivs_ops = list(derivs_ops) if len(derivs_ops) == 1: def_unary_ufunc_jps(ufunc, derivs_ops[0]) From 27e882b7da57ffb6276ec58e99cd2c946e4d101e Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Thu, 2 Nov 2017 23:24:59 +0000 Subject: [PATCH 23/32] New stats grads to new format --- autograd/scipy/stats/beta.py | 25 ++++++++++++++----------- autograd/scipy/stats/chi2.py | 11 ++++++----- autograd/scipy/stats/gamma.py | 21 ++++++++++++--------- tests/test_scipy.py | 20 ++++++++++---------- 4 files changed, 42 insertions(+), 35 deletions(-) diff --git a/autograd/scipy/stats/beta.py b/autograd/scipy/stats/beta.py index a703ae6a9..64a822d46 100644 --- a/autograd/scipy/stats/beta.py +++ b/autograd/scipy/stats/beta.py @@ -2,8 +2,8 @@ import autograd.numpy as np import scipy.stats -from autograd.extend import primitive, defvjp -from autograd.numpy.numpy_vjps import unbroadcast_f +from autograd.extend import primitive +from autograd.numpy.util import def_ufunc_jps from autograd.scipy.special import beta, psi cdf = primitive(scipy.stats.beta.cdf) @@ -19,12 +19,15 @@ def grad_beta_logpdf_arg1(x, a, b): def grad_beta_logpdf_arg2(x, a, b): return np.log1p(-x) - psi(b) + psi(a + b) -defvjp(cdf, lambda ans, x, a, b: unbroadcast_f(x, lambda g: g * np.power(x, a-1) * np.power(1-x, b-1) / beta(a, b)), argnums=[0]) -defvjp(logpdf, - lambda ans, x, a, b: unbroadcast_f(x, lambda g: g * grad_beta_logpdf_arg0(x, a, b)), - lambda ans, x, a, b: unbroadcast_f(a, lambda g: g * grad_beta_logpdf_arg1(x, a, b)), - lambda ans, x, a, b: unbroadcast_f(b, lambda g: g * grad_beta_logpdf_arg2(x, a, b))) -defvjp(pdf, - lambda ans, x, a, b: unbroadcast_f(x, lambda g: g * ans * grad_beta_logpdf_arg0(x, a, b)), - lambda ans, x, a, b: unbroadcast_f(a, lambda g: g * ans * grad_beta_logpdf_arg1(x, a, b)), - lambda ans, x, a, b: unbroadcast_f(b, lambda g: g * ans * grad_beta_logpdf_arg2(x, a, b))) +def_ufunc_jps(cdf, + (lambda ans, x, a, b: np.power(x, a-1) * np.power(1-x, b-1) / beta(a, b), 'mul'), + None, + None) +def_ufunc_jps(logpdf, + (lambda ans, x, a, b: grad_beta_logpdf_arg0(x, a, b), 'mul'), + (lambda ans, x, a, b: grad_beta_logpdf_arg1(x, a, b), 'mul'), + (lambda ans, x, a, b: grad_beta_logpdf_arg2(x, a, b), 'mul')) +def_ufunc_jps(pdf, + (lambda ans, x, a, b: ans * grad_beta_logpdf_arg0(x, a, b), 'mul'), + (lambda ans, x, a, b: ans * grad_beta_logpdf_arg1(x, a, b), 'mul'), + (lambda ans, x, a, b: ans * grad_beta_logpdf_arg2(x, a, b), 'mul')) diff --git a/autograd/scipy/stats/chi2.py b/autograd/scipy/stats/chi2.py index 8555739a9..f7d14b53d 100644 --- a/autograd/scipy/stats/chi2.py +++ b/autograd/scipy/stats/chi2.py @@ -2,8 +2,8 @@ import autograd.numpy as np import scipy.stats -from autograd.extend import primitive, defvjp -from autograd.numpy.numpy_vjps import unbroadcast_f +from autograd.extend import primitive +from autograd.numpy.util import def_ufunc_jps from autograd.scipy.special import gamma cdf = primitive(scipy.stats.chi2.cdf) @@ -13,6 +13,7 @@ def grad_chi2_logpdf(x, df): return np.where(df % 1 == 0, (df - x - 2) / (2 * x), 0) -defvjp(cdf, lambda ans, x, df: unbroadcast_f(x, lambda g: g * np.power(2., -df/2) * np.exp(-x/2) * np.power(x, df/2 - 1) / gamma(df/2)), argnums=[0]) -defvjp(logpdf, lambda ans, x, df: unbroadcast_f(x, lambda g: g * grad_chi2_logpdf(x, df)), argnums=[0]) -defvjp(pdf, lambda ans, x, df: unbroadcast_f(x, lambda g: g * ans * grad_chi2_logpdf(x, df)), argnums=[0]) +def_ufunc_jps(cdf, (lambda ans, x, df: (np.power(2., -df/2) * np.exp(-x/2) * + np.power(x, df/2 - 1) / gamma(df/2)), 'mul'), None) +def_ufunc_jps(logpdf, (lambda ans, x, df: (grad_chi2_logpdf(x, df)), 'mul'), None) +def_ufunc_jps(pdf, (lambda ans, x, df: (ans * grad_chi2_logpdf(x, df)), 'mul'), None) diff --git a/autograd/scipy/stats/gamma.py b/autograd/scipy/stats/gamma.py index 5b595099c..56fd85471 100644 --- a/autograd/scipy/stats/gamma.py +++ b/autograd/scipy/stats/gamma.py @@ -2,8 +2,8 @@ import autograd.numpy as np import scipy.stats -from autograd.extend import primitive, defvjp -from autograd.numpy.numpy_vjps import unbroadcast_f +from autograd.extend import primitive +from autograd.numpy.util import def_ufunc_jps from autograd.scipy.special import gamma, psi cdf = primitive(scipy.stats.gamma.cdf) @@ -16,10 +16,13 @@ def grad_gamma_logpdf_arg0(x, a): def grad_gamma_logpdf_arg1(x, a): return np.log(x) - psi(a) -defvjp(cdf, lambda ans, x, a: unbroadcast_f(x, lambda g: g * np.exp(-x) * np.power(x, a-1) / gamma(a)), argnums=[0]) -defvjp(logpdf, - lambda ans, x, a: unbroadcast_f(x, lambda g: g * grad_gamma_logpdf_arg0(x, a)), - lambda ans, x, a: unbroadcast_f(a, lambda g: g * grad_gamma_logpdf_arg1(x, a))) -defvjp(pdf, - lambda ans, x, a: unbroadcast_f(x, lambda g: g * ans * grad_gamma_logpdf_arg0(x, a)), - lambda ans, x, a: unbroadcast_f(a, lambda g: g * ans * grad_gamma_logpdf_arg1(x, a))) +def_ufunc_jps(cdf, + (lambda ans, x, a: np.exp(-x) * np.power(x, a-1) / gamma(a), 'mul'), + None, + None) +def_ufunc_jps(logpdf, + (lambda ans, x, a: grad_gamma_logpdf_arg0(x, a), 'mul'), + (lambda ans, x, a: grad_gamma_logpdf_arg1(x, a), 'mul')) +def_ufunc_jps(pdf, + (lambda ans, x, a: ans * grad_gamma_logpdf_arg0(x, a), 'mul'), + (lambda ans, x, a: ans * grad_gamma_logpdf_arg1(x, a), 'mul')) diff --git a/tests/test_scipy.py b/tests/test_scipy.py index 6d2b834ba..e0152d8c1 100644 --- a/tests/test_scipy.py +++ b/tests/test_scipy.py @@ -42,17 +42,17 @@ def symmetrized_fun(*args, **kwargs): return symmetrized_fun ### Stats ### - def test_chi2_pdf(): combo_check(stats.chi2.pdf, [0])([R(4)**2 + 1.1], [1, 2, 3]) - def test_chi2_cdf(): combo_check(stats.chi2.cdf, [0])([R(4)**2 + 1.1], [1, 2, 3]) - def test_chi2_logpdf(): combo_check(stats.chi2.logpdf, [0])([R(4)**2 + 1.1], [1, 2, 3]) + def test_chi2_pdf(): combo_check(stats.chi2.pdf, [0], modes=['fwd', 'rev'])([R(4)**2 + 1.1], [1, 2, 3]) + def test_chi2_cdf(): combo_check(stats.chi2.cdf, [0], modes=['fwd', 'rev'])([R(4)**2 + 1.1], [1, 2, 3]) + def test_chi2_logpdf(): combo_check(stats.chi2.logpdf, [0], modes=['fwd', 'rev'])([R(4)**2 + 1.1], [1, 2, 3]) - def test_beta_cdf(): combo_check(stats.beta.cdf, [0]) ([U(0., 1., 4)], [R(4)**2 + 1.1], [R(4)**2 + 1.1]) - def test_beta_pdf(): combo_check(stats.beta.pdf, [0,1,2])([U(0., 1., 4)], [R(4)**2 + 1.1], [R(4)**2 + 1.1]) - def test_beta_logpdf(): combo_check(stats.beta.logpdf, [0,1,2])([U(0., 1., 4)], [R(4)**2 + 1.1], [R(4)**2 + 1.1]) + def test_beta_cdf(): combo_check(stats.beta.cdf, [0], modes=['fwd', 'rev'])([U(0., 1., 4)], [R(4)**2 + 1.1], [R(4)**2 + 1.1]) + def test_beta_pdf(): combo_check(stats.beta.pdf, [0,1,2], modes=['fwd', 'rev'])([U(0., 1., 4)], [R(4)**2 + 1.1], [R(4)**2 + 1.1]) + def test_beta_logpdf(): combo_check(stats.beta.logpdf, [0,1,2], modes=['fwd', 'rev'])([U(0., 1., 4)], [R(4)**2 + 1.1], [R(4)**2 + 1.1]) - def test_gamma_cdf(): combo_check(stats.gamma.cdf, [0]) ([R(4)**2 + 1.1], [R(4)**2 + 1.1]) - def test_gamma_pdf(): combo_check(stats.gamma.pdf, [0,1])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) - def test_gamma_logpdf(): combo_check(stats.gamma.logpdf, [0,1])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) + def test_gamma_cdf(): combo_check(stats.gamma.cdf, [0], modes=['fwd', 'rev'])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) + def test_gamma_pdf(): combo_check(stats.gamma.pdf, [0,1], modes=['fwd', 'rev'])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) + def test_gamma_logpdf(): combo_check(stats.gamma.logpdf, [0,1], modes=['fwd', 'rev'])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) def test_norm_pdf(): combo_check(stats.norm.pdf, [0,1,2], modes=['fwd', 'rev'])([R(4)], [R(4)], [R(4)**2 + 1.1]) def test_norm_cdf(): combo_check(stats.norm.cdf, [0,1,2], modes=['fwd', 'rev'])([R(4)], [R(4)], [R(4)**2 + 1.1]) @@ -196,4 +196,4 @@ def test_erfinv(): unary_ufunc_check(special.erfinv, lims=[-0.95, 0.95], test_c def test_erfcinv(): unary_ufunc_check(special.erfcinv, lims=[0.05, 1.95], test_complex=False, modes=['fwd', 'rev']) def test_logit(): unary_ufunc_check(special.logit, lims=[0.05, 0.95], test_complex=False, modes=['fwd', 'rev']) - def test_expit(): unary_ufunc_check(special.expit, lims=[-4.05, 4.95], test_complex=False, modes=['fwd', 'rev']) \ No newline at end of file + def test_expit(): unary_ufunc_check(special.expit, lims=[-4.05, 4.95], test_complex=False, modes=['fwd', 'rev']) From 204baea9a62b36cf0bbdf33f8a7233bd0056d479 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Thu, 2 Nov 2017 23:45:06 +0000 Subject: [PATCH 24/32] Add inverse pair helper --- autograd/numpy/numpy_jvps.py | 24 +++++++++--------------- autograd/numpy/util.py | 8 ++++++++ autograd/scipy/special.py | 12 ++++-------- 3 files changed, 21 insertions(+), 23 deletions(-) diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py index 1dcfd2f94..602a2252f 100644 --- a/autograd/numpy/numpy_jvps.py +++ b/autograd/numpy/numpy_jvps.py @@ -4,7 +4,7 @@ tensordot_adjoint_0, tensordot_adjoint_1, nograd_functions) from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace, JVPNode, register_notrace, defvjp) -from .util import def_ufunc_jps +from .util import def_ufunc_jps, def_ufunc_jps_inv_pair from ..util import func from .numpy_boxes import ArrayBox @@ -36,27 +36,13 @@ def_ufunc_jps(anp.fabs, (lambda ans, x: anp.sign(x), 'mul')) # fabs doesn't take complex numbers. def_ufunc_jps(anp.absolute, (lambda ans, x: anp.conj(x) / ans, 'cmul')) def_ufunc_jps(anp.reciprocal, (lambda ans, x: -ans**2, 'mul' )) -def_ufunc_jps(anp.exp, (lambda ans, x: ans, 'mul' )) -def_ufunc_jps(anp.exp2, (lambda ans, x: ans * anp.log(2), 'mul' )) -def_ufunc_jps(anp.expm1, (lambda ans, x: ans + 1, 'mul' )) -def_ufunc_jps(anp.log, (lambda ans, x: x, 'div' )) -def_ufunc_jps(anp.log2, (lambda ans, x: x * anp.log(2), 'div' )) def_ufunc_jps(anp.log10, (lambda ans, x: x * anp.log(10), 'div' )) -def_ufunc_jps(anp.log1p, (lambda ans, x: x + 1, 'div' )) def_ufunc_jps(anp.sin, (lambda ans, x: anp.cos(x), 'mul' )) def_ufunc_jps(anp.cos, (lambda ans, x: -anp.sin(x), 'mul' )) -def_ufunc_jps(anp.tan, (lambda ans, x: 1 + ans**2, 'mul' )) def_ufunc_jps(anp.arcsin, (lambda ans, x: anp.sqrt(1 - x**2), 'div' )) def_ufunc_jps(anp.arccos, (lambda ans, x:-anp.sqrt(1 - x**2), 'div' )) -def_ufunc_jps(anp.arctan, (lambda ans, x: 1 + x**2, 'div' )) -def_ufunc_jps(anp.sinh, (lambda ans, x: anp.cosh(x), 'mul' )) def_ufunc_jps(anp.cosh, (lambda ans, x: anp.sinh(x), 'mul' )) -def_ufunc_jps(anp.tanh, (lambda ans, x: 1 - ans**2, 'mul' )) -def_ufunc_jps(anp.arcsinh, (lambda ans, x: anp.sqrt(x**2 + 1), 'div' )) def_ufunc_jps(anp.arccosh, (lambda ans, x: anp.sqrt(x**2 - 1), 'div' )) -def_ufunc_jps(anp.arctanh, (lambda ans, x: 1 - x**2, 'div' )) -def_ufunc_jps(anp.square, (lambda ans, x: 2 * x, 'mul' )) -def_ufunc_jps(anp.sqrt, (lambda ans, x: 2 * ans, 'div' )) def_ufunc_jps(anp.sinc, (lambda ans, x: (anp.cos(anp.pi*x)-ans)/x, 'mul' )) def_ufunc_jps(anp.real_if_close, 'cid') def_ufunc_jps(anp.real, 'cid') @@ -65,6 +51,14 @@ def_ufunc_jps(anp.conjugate, 'same') def_ufunc_jps(anp.angle, (lambda ans, x: anp.conj(x * 1j)/anp.abs(x)**2, 'cmul')) +def_ufunc_jps_inv_pair(anp.exp, anp.log, lambda ans, x: ans) +def_ufunc_jps_inv_pair(anp.exp2, anp.log2, lambda ans, x: ans * anp.log(2)) +def_ufunc_jps_inv_pair(anp.expm1, anp.log1p, lambda ans, x: ans + 1) +def_ufunc_jps_inv_pair(anp.tan, anp.arctan, lambda ans, x: 1 + ans**2) +def_ufunc_jps_inv_pair(anp.tanh, anp.arctanh, lambda ans, x: 1 - ans**2) +def_ufunc_jps_inv_pair(anp.sinh, anp.arcsinh, lambda ans, x: anp.sqrt(ans**2 + 1)) +def_ufunc_jps_inv_pair(anp.square, anp.sqrt, lambda ans, x: 2 * x) + # ----- Binary ufuncs ----- def_ufunc_jps(anp.add, 'id', 'id') def_ufunc_jps(anp.subtract, 'id', 'neg') diff --git a/autograd/numpy/util.py b/autograd/numpy/util.py index f45d63b94..d2a951eff 100644 --- a/autograd/numpy/util.py +++ b/autograd/numpy/util.py @@ -165,3 +165,11 @@ def def_ufunc_jps(ufunc, *derivs_ops): def_unary_ufunc_jps(ufunc, derivs_ops[0]) elif len(derivs_ops) > 1: def_nary_ufunc_jps(ufunc, derivs_ops) + +def def_ufunc_jps_inv_pair(ufunc, ufunc_inv, deriv): + """ + Define the derivatives for an inverse pair of unary ufuncs. deriv must be + the derivative of the first ufunc. + """ + def_ufunc_jps(ufunc, (deriv, 'mul')) + def_ufunc_jps(ufunc_inv, (lambda ans, x: deriv(x, ans), 'div')) diff --git a/autograd/scipy/special.py b/autograd/scipy/special.py index ff7154374..0d21c75b8 100644 --- a/autograd/scipy/special.py +++ b/autograd/scipy/special.py @@ -1,7 +1,7 @@ from __future__ import absolute_import import scipy.special import autograd.numpy as np -from autograd.numpy.util import def_ufunc_jps +from autograd.numpy.util import def_ufunc_jps, def_ufunc_jps_inv_pair from autograd.extend import primitive, defvjp from autograd.numpy.util import unbroadcast_f @@ -70,20 +70,16 @@ def gammainc_vjp_arg1(ans, a, x): erf = primitive(scipy.special.erf) erfc = primitive(scipy.special.erfc) -def_ufunc_jps(erf, (lambda ans, x: 2.*inv_root_pi*np.exp(-x**2), 'mul')) -def_ufunc_jps(erfc, (lambda ans, x: -2.*inv_root_pi*np.exp(-x**2), 'mul')) - ### Inverse error function ### root_pi = 1.7724538509055159 erfinv = primitive(scipy.special.erfinv) erfcinv = primitive(scipy.special.erfcinv) -def_ufunc_jps(erfinv, (lambda ans, x: root_pi / 2 * np.exp(erfinv(x)**2 ), 'mul')) -def_ufunc_jps(erfcinv, (lambda ans, x: -root_pi / 2 * np.exp(erfcinv(x)**2), 'mul')) +def_ufunc_jps_inv_pair(erf, erfinv, lambda ans, x: 2.*inv_root_pi*np.exp(-x**2)) +def_ufunc_jps_inv_pair(erfc, erfcinv, lambda ans, x: -2.*inv_root_pi*np.exp(-x**2)) ### Logit and Expit ### logit = primitive(scipy.special.logit) expit = primitive(scipy.special.expit) -def_ufunc_jps(logit, (lambda ans, x: x * (1 - x ), 'div')) -def_ufunc_jps(expit, (lambda ans, x: ans * (1 - ans), 'mul')) +def_ufunc_jps_inv_pair(expit, logit, lambda ans, x: ans * (1 - ans)) From 98bedda2f1992ea8c3627ea093652ff5f1b197db Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Fri, 3 Nov 2017 10:41:24 +0000 Subject: [PATCH 25/32] Beta fns to new ufunc jp format --- autograd/scipy/special.py | 19 ++++++++++--------- tests/test_scipy.py | 6 +++--- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/autograd/scipy/special.py b/autograd/scipy/special.py index 0d21c75b8..a75d2ff12 100644 --- a/autograd/scipy/special.py +++ b/autograd/scipy/special.py @@ -10,15 +10,16 @@ betainc = primitive(scipy.special.betainc) betaln = primitive(scipy.special.betaln) -defvjp(beta, - lambda ans, a, b: unbroadcast_f(a, lambda g: g * ans * (psi(a) - psi(a + b))), - lambda ans, a, b: unbroadcast_f(b, lambda g: g * ans * (psi(b) - psi(a + b)))) -defvjp(betainc, - lambda ans, a, b, x: unbroadcast_f(x, lambda g: g * np.power(x, a - 1) * np.power(1 - x, b - 1) / beta(a, b)), - argnums=[2]) -defvjp(betaln, - lambda ans, a, b: unbroadcast_f(a, lambda g: g * (psi(a) - psi(a + b))), - lambda ans, a, b: unbroadcast_f(b, lambda g: g * (psi(b) - psi(a + b)))) +def_ufunc_jps(beta, + (lambda ans, a, b: ans * (psi(a) - psi(a + b)), 'mul'), + (lambda ans, a, b: ans * (psi(b) - psi(a + b)), 'mul')) +def_ufunc_jps(betainc, + None, + None, + (lambda ans, a, b, x: np.power(x, a - 1) * np.power(1 - x, b - 1) / beta(a, b), 'mul')) +def_ufunc_jps(betaln, + (lambda ans, a, b: psi(a) - psi(a + b), 'mul'), + (lambda ans, a, b: psi(b) - psi(a + b), 'mul')) ### Gamma functions ### polygamma = primitive(scipy.special.polygamma) diff --git a/tests/test_scipy.py b/tests/test_scipy.py index e0152d8c1..3ade59d19 100644 --- a/tests/test_scipy.py +++ b/tests/test_scipy.py @@ -165,9 +165,9 @@ def test_convolve_ignore_dot(): axes=[([1],[1])], dot_axes=[([0],[2]), ([0],[0])], mode=['full', 'valid']) ### Special ### - def test_beta(): combo_check(special.beta, [0,1])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) - def test_betainc(): combo_check(special.betainc, [2]) ([R(4)**2 + 1.1], [R(4)**2 + 1.1], [U(0., 1., 4)]) - def test_betaln(): combo_check(special.betaln, [0,1])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) + def test_beta(): combo_check(special.beta, [0,1], modes=['fwd', 'rev'])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) + def test_betainc(): combo_check(special.betainc, [2] , modes=['fwd', 'rev'])([R(4)**2 + 1.1], [R(4)**2 + 1.1], [U(0., 1., 4)]) + def test_betaln(): combo_check(special.betaln, [0,1], modes=['fwd', 'rev'])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) def test_gammainc(): combo_check(special.gammainc, [1], modes=['fwd', 'rev'])([1], R(4)**2 + 1.3) def test_gammaincc(): combo_check(special.gammaincc, [1], modes=['fwd', 'rev'])([1], R(4)**2 + 1.3) From b1852af316301e76312d2b2f053ba991b1937cad Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Mon, 6 Nov 2017 14:47:28 +0000 Subject: [PATCH 26/32] Add tanh and 'add' benchmarks --- benchmarks/bench_numpy_vjps.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/benchmarks/bench_numpy_vjps.py b/benchmarks/bench_numpy_vjps.py index 896450932..704ec58c6 100644 --- a/benchmarks/bench_numpy_vjps.py +++ b/benchmarks/bench_numpy_vjps.py @@ -81,3 +81,13 @@ def time_tensordot_1_1(): def time_tensordot_1_2(): tensordot_1_2(A, B, G) +A = np.random.randn(200, 200, 5, 4) +C = np.random.randn(1, 1, 5, 4) +add_0 = lambda A, B, G: make_vjp(np.add, argnum=0)(A, B)[0](G) +tanh_0 = lambda A, G: make_vjp(np.tanh, argnum=0)(A)[0](G) + +def time_add_0(): + add_0(A, C, A) + +def time_tanh_0(): + tanh_0(A, A) From cd5d24f4f284677305b2c84555c8560eb57b715a Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Mon, 6 Nov 2017 14:58:18 +0000 Subject: [PATCH 27/32] rm unnecessary match_complex from ufunc vjps --- autograd/numpy/util.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/autograd/numpy/util.py b/autograd/numpy/util.py index d2a951eff..665f2580e 100644 --- a/autograd/numpy/util.py +++ b/autograd/numpy/util.py @@ -57,11 +57,9 @@ def def_nary_ufunc_jps(ufunc, derivs_ops): lambda argnum: lambda ans, *args: unbroadcast_f(args[argnum], lambda g: ufunc(*subval(args, argnum, g)))), 'id': (lambda argnum: lambda g, ans, *args: match_complex(ans, anp.broadcast_to(g, ans.shape)), - lambda argnum: lambda ans, *args: - unbroadcast_f(args[argnum], lambda g: match_complex(args[argnum], g))), + lambda argnum: lambda ans, *args: unbroadcast_f(args[argnum], lambda g: g)), 'neg': (lambda argnum: lambda g, ans, *args: match_complex(ans, anp.broadcast_to(-g, ans.shape)), - lambda argnum: lambda ans, *args: - unbroadcast_f(args[argnum], lambda g: match_complex(args[argnum], -g))) + lambda argnum: lambda ans, *args: unbroadcast_f(args[argnum], lambda g: -g)) } linops = { @@ -70,7 +68,7 @@ def def_nary_ufunc_jps(ufunc, derivs_ops): unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g * d)), 'div': (lambda argnum, deriv: lambda g, ans, *args: g / deriv(ans, *args), lambda argnum, deriv: lambda ans, *args: - unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g / d)) + unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g / d)) } def deriv_op_to_jp(idx, argnum, deriv_op): From f6dfd737cdfca49e2576fec39f68c55e1a6c1a76 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Mon, 6 Nov 2017 15:30:18 +0000 Subject: [PATCH 28/32] fix numpy vjp benchmarks --- benchmarks/bench_numpy_vjps.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/benchmarks/bench_numpy_vjps.py b/benchmarks/bench_numpy_vjps.py index 704ec58c6..0d1b4f806 100644 --- a/benchmarks/bench_numpy_vjps.py +++ b/benchmarks/bench_numpy_vjps.py @@ -81,13 +81,13 @@ def time_tensordot_1_1(): def time_tensordot_1_2(): tensordot_1_2(A, B, G) -A = np.random.randn(200, 200, 5, 4) -C = np.random.randn(1, 1, 5, 4) -add_0 = lambda A, B, G: make_vjp(np.add, argnum=0)(A, B)[0](G) -tanh_0 = lambda A, G: make_vjp(np.tanh, argnum=0)(A)[0](G) +C = np.random.randn(200, 200, 5, 4) +D = np.random.randn(1, 1, 5, 4) +add_0 = lambda C, D, G: make_vjp(np.add, argnum=0)(C, D)[0](G) +tanh_0 = lambda C, G: make_vjp(np.tanh, argnum=0)(C)[0](G) def time_add_0(): - add_0(A, C, A) + add_0(C, D, C) def time_tanh_0(): - tanh_0(A, A) + tanh_0(C, C) From b5fa23693b5d46d49981eaebddb4ff65b899dbab Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Mon, 20 Nov 2017 10:13:24 +0000 Subject: [PATCH 29/32] Fix arctan2 jps def --- autograd/numpy/numpy_jvps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py index 9fd35a01f..2ed67c50d 100644 --- a/autograd/numpy/numpy_jvps.py +++ b/autograd/numpy/numpy_jvps.py @@ -83,7 +83,7 @@ (lambda ans, x, y: anp.log(replace_zero(x, 1.)) * x ** y, 'mul')) def_ufunc_jps(anp.hypot, (lambda ans, x, y: x / ans, 'mul'), (lambda ans, x, y: y / ans, 'mul')) -defjvp(anp.arctan2, (lambda ans, x, y: y / (x**2 + y**2), 'mul'), +def_ufunc_jps(anp.arctan2, (lambda ans, x, y: y / (x**2 + y**2), 'mul'), (lambda ans, x, y:-x / (x**2 + y**2), 'mul')) # ----- Simple grads (linear) ----- From 57eb1b494365fe11957581a771fe05f38698e909 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Mon, 20 Nov 2017 10:29:50 +0000 Subject: [PATCH 30/32] Rm unused imports --- autograd/numpy/numpy_jvps.py | 9 +++++++-- autograd/numpy/numpy_vjps.py | 2 -- autograd/scipy/special.py | 3 +-- autograd/scipy/stats/norm.py | 2 +- autograd/scipy/stats/poisson.py | 2 +- autograd/scipy/stats/t.py | 4 ++-- 6 files changed, 12 insertions(+), 10 deletions(-) diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py index 2ed67c50d..29d26502b 100644 --- a/autograd/numpy/numpy_jvps.py +++ b/autograd/numpy/numpy_jvps.py @@ -1,9 +1,8 @@ -from itertools import repeat from . import numpy_wrapper as anp from .numpy_vjps import (untake, balanced_eq, replace_zero, dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0, tensordot_adjoint_1, nograd_functions) from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace, JVPNode, - register_notrace, defvjp) + register_notrace) from .util import def_ufunc_jps, def_ufunc_jps_inv_pair from ..util import func @@ -23,9 +22,11 @@ lambda g, ans, args, kwargs, _: anp._array_from_scalar_or_array(args, kwargs, g)) # ----- Functions that are constant w.r.t. continuous inputs ----- + defjvp(anp.nan_to_num, lambda g, ans, x: anp.where(anp.isfinite(x), g, 0.)) # ----- Unary ufuncs ------ + def_ufunc_jps(anp.negative, 'same') def_ufunc_jps(anp.rad2deg, 'same') def_ufunc_jps(anp.degrees, 'same') @@ -60,6 +61,7 @@ def_ufunc_jps_inv_pair(anp.square, anp.sqrt, lambda ans, x: 2 * x) # ----- Binary ufuncs ----- + def_ufunc_jps(anp.add, 'id', 'id') def_ufunc_jps(anp.subtract, 'id', 'neg') def_ufunc_jps(anp.multiply, 'same', 'same') @@ -87,6 +89,7 @@ (lambda ans, x, y:-x / (x**2 + y**2), 'mul')) # ----- Simple grads (linear) ----- + defjvp(anp.reshape, 'same') defjvp(anp.roll, 'same') defjvp(anp.array_split, 'same') @@ -113,12 +116,14 @@ def_linear(anp.cross) # ----- Simple grads ----- + defjvp(anp.clip, lambda g, ans, x, a_min, a_max : g * anp.logical_and(ans != a_min, ans != a_max)) defjvp(anp.where, None, lambda g, ans, c, x=None, y=None : anp.where(c, g, anp.zeros(anp.shape(g))), lambda g, ans, c, x=None, y=None : anp.where(c, anp.zeros(g.shape), g)) # ----- Trickier grads ----- + defjvp(anp.kron, 'same', 'same') defjvp(anp.diff, 'same') defjvp(anp.repeat, 'same') diff --git a/autograd/numpy/numpy_vjps.py b/autograd/numpy/numpy_vjps.py index 0a1412072..e80a60935 100644 --- a/autograd/numpy/numpy_vjps.py +++ b/autograd/numpy/numpy_vjps.py @@ -24,12 +24,10 @@ for fun in nograd_functions: register_notrace(VJPNode, fun) - # ----- Functions that are constant w.r.t. continuous inputs ----- defvjp(anp.nan_to_num, lambda ans, x: lambda g: anp.where(anp.isfinite(x), g, 0.)) - # ----- Simple grads ----- defvjp(anp.reshape, lambda ans, x, shape, order=None : lambda g: anp.reshape(g, anp.shape(x), order=order)) diff --git a/autograd/scipy/special.py b/autograd/scipy/special.py index a75d2ff12..b49fc4a91 100644 --- a/autograd/scipy/special.py +++ b/autograd/scipy/special.py @@ -2,8 +2,7 @@ import scipy.special import autograd.numpy as np from autograd.numpy.util import def_ufunc_jps, def_ufunc_jps_inv_pair -from autograd.extend import primitive, defvjp -from autograd.numpy.util import unbroadcast_f +from autograd.extend import primitive ### Beta function ### beta = primitive(scipy.special.beta) diff --git a/autograd/scipy/stats/norm.py b/autograd/scipy/stats/norm.py index ea086ea87..d999be72b 100644 --- a/autograd/scipy/stats/norm.py +++ b/autograd/scipy/stats/norm.py @@ -2,7 +2,7 @@ from __future__ import absolute_import import scipy.stats import autograd.numpy as anp -from autograd.extend import primitive, defvjp +from autograd.extend import primitive from autograd.numpy.util import def_ufunc_jps pdf = primitive(scipy.stats.norm.pdf) diff --git a/autograd/scipy/stats/poisson.py b/autograd/scipy/stats/poisson.py index 8546602a1..c33d40a4b 100644 --- a/autograd/scipy/stats/poisson.py +++ b/autograd/scipy/stats/poisson.py @@ -2,7 +2,7 @@ import autograd.numpy as np import scipy.stats -from autograd.extend import primitive, defvjp +from autograd.extend import primitive from autograd.numpy.util import def_ufunc_jps cdf = primitive(scipy.stats.poisson.cdf) diff --git a/autograd/scipy/stats/t.py b/autograd/scipy/stats/t.py index c68dcf75b..763d9ac8d 100644 --- a/autograd/scipy/stats/t.py +++ b/autograd/scipy/stats/t.py @@ -2,8 +2,8 @@ from __future__ import absolute_import import scipy.stats import autograd.numpy as np -from autograd.extend import primitive, defvjp -from autograd.numpy.util import unbroadcast_f, def_ufunc_jps +from autograd.extend import primitive +from autograd.numpy.util import def_ufunc_jps from autograd.scipy.special import psi pdf = primitive(scipy.stats.t.pdf) From 93e560127917dc6ea5a676651ca9acd9dc6796d2 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Mon, 20 Nov 2017 10:32:12 +0000 Subject: [PATCH 31/32] fix indentation numpy_wrapper.py --- autograd/numpy/numpy_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autograd/numpy/numpy_wrapper.py b/autograd/numpy/numpy_wrapper.py index fda4693a0..b8feb631e 100644 --- a/autograd/numpy/numpy_wrapper.py +++ b/autograd/numpy/numpy_wrapper.py @@ -164,4 +164,4 @@ def _broadcast_to_adjoint(x, shape): @primitive def _astype(A, dtype, order='K', casting='unsafe', subok=True, copy=True): - return A.astype(dtype, order, casting, subok, copy) \ No newline at end of file + return A.astype(dtype, order, casting, subok, copy) From a9c0e454329d93b315d272f68745c01eaf70eec7 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Fri, 8 Dec 2017 15:59:43 +0000 Subject: [PATCH 32/32] Define derivs for scipy.special.rel_entr --- autograd/scipy/special.py | 5 +++++ tests/test_scipy.py | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/autograd/scipy/special.py b/autograd/scipy/special.py index b49fc4a91..6abef774f 100644 --- a/autograd/scipy/special.py +++ b/autograd/scipy/special.py @@ -83,3 +83,8 @@ def gammainc_vjp_arg1(ans, a, x): expit = primitive(scipy.special.expit) def_ufunc_jps_inv_pair(expit, logit, lambda ans, x: ans * (1 - ans)) + +### Relative entropy ### +rel_entr = primitive(scipy.special.rel_entr) + +def_ufunc_jps(rel_entr, (lambda ans, x, y: np.log(x / y) + 1, 'mul'), (lambda ans, x, y: - x / y, 'mul')) diff --git a/tests/test_scipy.py b/tests/test_scipy.py index 3ade59d19..16f0a0692 100644 --- a/tests/test_scipy.py +++ b/tests/test_scipy.py @@ -21,7 +21,7 @@ from scipy.signal import convolve as sp_convolve from autograd.test_util import combo_check, check_grads - from numpy_utils import unary_ufunc_check + from numpy_utils import unary_ufunc_check, binary_ufunc_check npr.seed(1) R = npr.randn @@ -197,3 +197,5 @@ def test_erfcinv(): unary_ufunc_check(special.erfcinv, lims=[0.05, 1.95], test_c def test_logit(): unary_ufunc_check(special.logit, lims=[0.05, 0.95], test_complex=False, modes=['fwd', 'rev']) def test_expit(): unary_ufunc_check(special.expit, lims=[-4.05, 4.95], test_complex=False, modes=['fwd', 'rev']) + + def test_rel_entr(): binary_ufunc_check(special.rel_entr, lims_A=[0.05, 1], lims_B=[0.05, 1], test_complex=False, modes=['fwd', 'rev'])