-
Notifications
You must be signed in to change notification settings - Fork 912
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Define ufunc JO and JTO simultaneously #312
base: master
Are you sure you want to change the base?
Conversation
I think this obviates the changes in HIPS#292.
This should minimize memory overhead.
autograd/numpy/numpy_jvps.py
Outdated
unbroadcast_f(args[argnum], lambda g: -g)), | ||
'mul': (lambda argnum, deriv: lambda g, ans, *args: g * deriv(ans, *args), | ||
lambda argnum, deriv: lambda ans, *args: | ||
unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g * d)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the vjps I've used this slightly weird d=deriv(ans, *args)
default argument syntax to ensure that deriv
is evaluated during the forward pass, allowing *args
and ans
to potentially be garbage collected.
Any objections? I could also have done this using a kind of helper closure to evaluate deriv, which would have been a bit more explicit.
Also move broadcast_to_adjoint into numpy_wrapper.
Have added a couple of benchmarks and run a bench compare. Almost everything is the same but there are some differences from compute which I've shifted from the backward pass to the forward pass:
|
I'm wondering whether, for consistency, all of the extra numpy-ish primitives that we define (things like the dot and tensordot adjoints) should be in numpy_wrapper, alongside things like make_diagonal and broadcast_to_adjoint. They can be viewed as extra primitives that we want to add to numpy (primitives which happen to be useful for calculating derivatives), so perhaps it makes more sense for them to be there. |
86820fd
to
2f6cc22
Compare
To possibly do:
Summary of the changes in this pr
def_ufunc_jps
for defining the jvp and vjp of a ufunc in one shot. The linesdef_ufunc_jps
explaining how to use it.broadcast_to
into a primitive, define its adjoint in numpy_wrapper.py and setup derivatives. This is roughly the same as make internal broadcast and unbroadcast both primitives #292.def_ufunc_jps_inv_pair
for defining the jps of an inverse pair of ufuncs in one shot. So for example the four defsmatch_complex
,unbroadcast
andunbroadcast_f
into newly createdautograd.numpy.util
alongside the new helper functions (I think this rearrangement makes sense).scipy.special.rel_entr
.Notes
We could reduce the number of lines of code for other primitive defs in this way. In particular I've got my eye on the reductions (sum, mean, var, std) to potentially do next. I also think, at least in the case of ufuncs, that this style is clearer. I guess we'd better check that there's little or no harm to performance. I think any overhead introduced could potentially be optimized away by carefully handling different special cases in def_ufunc_jps.
During higher order derivatives, the computation of what I've called
'derivative'
in the snippet above could be cached and reused. This computation is currently being re-done, because the same primitive's jvp/vjp is being called at each layer of unboxing, with numerically the same ans and *args inputs (although g will be different for each layer). I might write a little blog post about this as I suspect this is an optimization that may not yet be implemented in e.g. Theano or Tensorflow. Implementing this in Autograd might require something similar to what was discussed in #188.