Define ufunc JO and JTO simultaneously #312

Open
wants to merge 39 commits into
base: master

Collaborator

@j-towns j-towns commented Oct 16, 2017

To possibly do:

Summary of the changes in this pr

  1. New helper function def_ufunc_jps for defining the jvp and vjp of a ufunc in one shot. The lines
    defjvp(anp.sin, lambda g, ans, x : g * anp.cos(x))
    defvjp(anp.sin, lambda ans, x : lambda g: g * anp.cos(x))
    can be replaced with a single line
    #                      ('derivative'             , linear operator to apply)
    def_ufunc_jps(anp.sin, (lambda ans, x: anp.cos(x), 'mul'                   ))
    I've added a docstring to def_ufunc_jps explaining how to use it.
  2. Convert all numpy and scipy ufuncs and ufunc-like functions to the new format, which enables forward mode for many scipy primitives. Enable forward mode tests for the newly supported primitives.
  3. Make broadcast_to into a primitive, define its adjoint in numpy_wrapper.py and set up its derivatives. This is roughly the same as "make internal broadcast and unbroadcast both primitives" (#292).
  4. New helper function def_ufunc_jps_inv_pair for defining the jps of an inverse pair of ufuncs in one shot. So for example the four defs
    defvjp(anp.square, lambda ans, x : lambda g: g * 2 * x)
    defvjp(anp.sqrt,   lambda ans, x : lambda g: g * 0.5 * x**-0.5)

    defjvp(anp.square, lambda g, ans, x : g * 2 * x)
    defjvp(anp.sqrt,   lambda g, ans, x : g * 0.5 * x**-0.5)
    become
    def_ufunc_jps_inv_pair(anp.square, anp.sqrt, lambda ans, x: 2 * x)
    Implement this for the 10 or so inverse pairs I spotted. This could also make implementing the grads of inverse cdf functions (which exist for most scipy.stats distributions) very straightforward.
  5. Move match_complex, unbroadcast and unbroadcast_f into newly created autograd.numpy.util alongside the new helper functions (I think this rearrangement makes sense).
  6. (Bonus) I've added the derivatives and tests for scipy.special.rel_entr.
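The idea behind items 1 and 4 can be sketched in plain Python. This is only an illustration under simplifying assumptions, not the implementation in this PR: the registries `JVPS` and `VJPS` and the `*_sketch` helpers are hypothetical names, only unary scalar functions and the 'mul' operator are handled, and broadcasting (unbroadcast_f) is left out.

```python
import math

# Hypothetical registries standing in for the tables behind defjvp/defvjp.
JVPS, VJPS = {}, {}

def def_ufunc_jps_sketch(fun, deriv):
    """Register a jvp and a vjp for a unary scalar `fun` from one derivative.

    `deriv(ans, x)` returns d fun / d x. With the 'mul' operator, both modes
    multiply the incoming tangent or cotangent `g` by that derivative.
    """
    JVPS[fun] = lambda g, ans, x: g * deriv(ans, x)

    def vjp_maker(ans, x):
        d = deriv(ans, x)  # evaluated once, when the vjp is built
        return lambda g: g * d

    VJPS[fun] = vjp_maker

def def_ufunc_jps_inv_pair_sketch(fun, inv, deriv):
    """Register jps for `fun` and its inverse `inv` from the derivative of `fun`."""
    def_ufunc_jps_sketch(fun, deriv)
    # If y = fun(x), then d inv / d y = 1 / (d fun / d x). Inside the rule for
    # `inv`, `x` is y and `ans` is the original x, so the arguments are swapped.
    def_ufunc_jps_sketch(inv, lambda ans, x: 1.0 / deriv(x, ans))

def square(x):
    return x * x

def_ufunc_jps_sketch(math.sin, lambda ans, x: math.cos(x))
def_ufunc_jps_inv_pair_sketch(square, math.sqrt, lambda ans, x: 2 * x)

x = 0.3
jvp_sin = JVPS[math.sin](1.0, math.sin(x), x)    # forward mode
vjp_sin = VJPS[math.sin](math.sin(x), x)(1.0)    # reverse mode
y = 4.0
vjp_sqrt = VJPS[math.sqrt](math.sqrt(y), y)(1.0)  # derivative of sqrt at y
```

The inverse-pair rule relies on the derivative of `fun` being nonzero at the point of evaluation, so for square/sqrt it applies for positive inputs.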

Notes

We could reduce the number of lines of code for other primitive defs in this way. In particular I've got my eye on the reductions (sum, mean, var, std) to potentially do next. I also think, at least in the case of ufuncs, that this style is clearer. I guess we'd better check that there's little or no harm to performance. I think any overhead introduced could potentially be optimized away by carefully handling different special cases in def_ufunc_jps.

During higher order derivatives, the computation of what I've called 'derivative' in the snippet above could be cached and reused. This computation is currently being redone, because the same primitive's jvp/vjp is called at each layer of unboxing with numerically the same ans and *args inputs (although g will be different for each layer). I might write a little blog post about this, as I suspect this optimization may not yet be implemented in e.g. Theano or TensorFlow. Implementing this in Autograd might require something similar to what was discussed in #188.

@j-towns j-towns changed the title [Experiment][WIP] Define JO and JTO simultaneously [Experiment][WIP] Define ufunc JO and JTO simultaneously Oct 17, 2017
@j-towns j-towns changed the title [Experiment][WIP] Define ufunc JO and JTO simultaneously Define ufunc JO and JTO simultaneously Oct 18, 2017
This should minimize memory overhead.
                unbroadcast_f(args[argnum], lambda g: -g)),
    'mul': (lambda argnum, deriv: lambda g, ans, *args: g * deriv(ans, *args),
            lambda argnum, deriv: lambda ans, *args:
                unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g * d)),

@j-towns j-towns Oct 18, 2017

For the vjps I've used this slightly weird d=deriv(ans, *args) default argument syntax to ensure that deriv is evaluated during the forward pass, allowing *args and ans to potentially be garbage collected.

Any objections? I could also have done this using a kind of helper closure to evaluate deriv, which would have been a bit more explicit.
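The timing difference between the two options can be seen in a small standalone sketch. The function names here (`vjp_lazy`, `vjp_default_arg`, `vjp_helper`) are hypothetical and the derivative is a toy one; the point is only that a default argument value is evaluated when the lambda is created, whereas a call inside the lambda body is evaluated when the lambda is called.

```python
calls = []

def deriv(ans, x):
    calls.append("deriv")  # record each evaluation of the derivative
    return 2 * x

def vjp_lazy(ans, x):
    # deriv runs only when the returned function is called (backward pass),
    # so the closure has to keep `ans` and `x` alive until then.
    return lambda g: g * deriv(ans, x)

def vjp_default_arg(ans, x):
    # The default value is evaluated when the lambda is created (forward
    # pass); the lambda holds only `d`, not `ans` or `x`.
    return lambda g, d=deriv(ans, x): g * d

def vjp_helper(ans, x):
    # Same timing as the default-argument form, written more explicitly.
    d = deriv(ans, x)
    def vjp(g):
        return g * d
    return vjp

lazy = vjp_lazy(9.0, 3.0)
n_after_lazy = len(calls)       # deriv not yet evaluated
eager = vjp_default_arg(9.0, 3.0)
n_after_eager = len(calls)      # deriv evaluated once, at construction
explicit = vjp_helper(9.0, 3.0)
n_after_explicit = len(calls)   # deriv evaluated once more, at construction
```

Both eager forms give the same result as the lazy one when called; they differ only in when `deriv` runs and in what the returned function keeps a reference to.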

@j-towns j-towns changed the title Define ufunc JO and JTO simultaneously [WIP] Define ufunc JO and JTO simultaneously Oct 26, 2017
@j-towns j-towns changed the base branch from dev-1.2 to master October 30, 2017 17:29
@j-towns j-towns changed the title [WIP] Define ufunc JO and JTO simultaneously Define ufunc JO and JTO simultaneously Nov 2, 2017

j-towns commented Nov 2, 2017

@mattjj / @dougalm can you review this?


j-towns commented Nov 6, 2017

I've added a couple of benchmarks and run a bench compare. Almost everything is unchanged, but there are some differences caused by computation which I've shifted from the backward pass to the forward pass:

    before     after       ratio
  [1b96990a] [98bedda2]
+  541.78μs   725.65μs      1.34  bench_core.time_long_forward_pass
+   34.93μs    44.31μs      1.27  bench_core.time_short_forward_pass
+  697.56ms   879.13ms      1.26  bench_core.time_fan_out_fan_in_grad
-   30.90ms    21.06ms      0.68  bench_numpy_vjps.time_tanh_0
-   17.07μs    10.21μs      0.60  bench_core.time_short_backward_pass
-  305.48μs   120.86μs      0.40  bench_core.time_long_backward_pass

Also, something is going wrong in the needless nodes case; I will check that out. Edit: fixed both the garbage collection of needless nodes and the slowdown in the vjp of add.


j-towns commented Nov 13, 2017

Ping @mattjj + @dougalm! Can you take a look at this pr?


j-towns commented Nov 20, 2017

I'm wondering whether, for consistency, all of the extra numpy-ish primitives that we define (things like the dot and tensordot adjoints) should be in numpy_wrapper, alongside things like make_diagonal and broadcast_to_adjoint.

They can be viewed as extra primitives that we want to add to numpy (primitives which happen to be useful for calculating derivatives), so perhaps it makes more sense for them to be there.
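As an illustration of what such an adjoint primitive computes, here is a minimal NumPy sketch of an adjoint for broadcast_to. The name `broadcast_to_adjoint_sketch` is hypothetical and this is not claimed to match the code in numpy_wrapper; it only assumes the usual NumPy broadcasting rules (new leading axes, and size-1 axes stretched to a larger size).

```python
import numpy as np

def broadcast_to_adjoint_sketch(g, shape):
    """Sum `g` down to `shape`, the adjoint of np.broadcast_to(x, g.shape)."""
    g = np.asarray(g)
    # Sum away the leading axes that broadcasting prepended.
    extra = g.ndim - len(shape)
    if extra > 0:
        g = g.sum(axis=tuple(range(extra)))
    # Sum, keeping dims, over axes that were stretched from size 1.
    stretched = tuple(i for i, n in enumerate(shape)
                      if n == 1 and g.shape[i] != 1)
    if stretched:
        g = g.sum(axis=stretched, keepdims=True)
    return g

# Adjoint check: <broadcast_to(x), y> should equal <x, adjoint(y)>.
rng = np.random.RandomState(0)
x = rng.randn(3, 1)
y = rng.randn(2, 3, 4)
lhs = np.sum(np.broadcast_to(x, y.shape) * y)
rhs = np.sum(x * broadcast_to_adjoint_sketch(y, x.shape))
```

The inner-product identity in the last lines is the defining property of an adjoint, which is why a function like this is what the vjp of a broadcast needs.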

@j-towns j-towns mentioned this pull request Nov 29, 2017

1 participant