
sketching out what TJPs might look like #280

Status: Open · wants to merge 8 commits into dougal-dev
Conversation

@mattjj (Contributor) commented Aug 26, 2017

This is very preliminary and probably not ready to merge, even into a dev branch, but I think it might be ready for some high-level feedback.

In particular, the amount of code redundancy is interesting. In tjp.py, which is like core.py, the functions/classes make_tjp, tjp_backward_pass, and TJPNode are all very similar to the VJP versions. I did have to change assert_vspace_match to assert_vspace_compatible. The definition of jacobian, which I put at the end of tjp.py, is also a little interesting.

In numpy_tjps.py, there are two main cases for the simple TJPs I've implemented so far:

  1. for the basic binary ufuncs, I had to change the implementation of unbroadcast but otherwise the code is the same, and
  2. for basic unary ufuncs, the VJPs already handle the broadcasting needed to be TJPs.
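To illustrate case 2 with plain numpy (a sketch outside this PR's code, not the actual numpy_tjps.py implementation): the VJP of an elementwise op like np.sin just multiplies the incoming cotangent by cos(x), and numpy's broadcasting lets that cotangent carry extra leading dimensions for free, so the same code already works as a TJP.

```python
import numpy as np

def sin_vjp(g, x):
    # VJP of np.sin: multiply the incoming cotangent by cos(x).
    # Because * broadcasts, g may carry extra leading dimensions,
    # so this VJP is already a TJP with no changes.
    return g * np.cos(x)

x = np.ones(4)
v = np.ones(4)         # ordinary VJP input, shape (4,)
V = np.ones((5, 4))    # "stacked" TJP input, shape (5, 4)
assert sin_vjp(v, x).shape == (4,)
assert sin_vjp(V, x).shape == (5, 4)
```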

The final interesting thing to look at is the additions to ArrayVSpace, which is related to #262 (and indeed I had that code open in a tab when I wrote these vspace functions!).

@mattjj requested a review from dougalm August 26, 2017 21:46
@mattjj force-pushed the tjps branch 3 times, most recently from 23b09bd to 8149360 on August 27, 2017 00:44
@mattjj (Contributor, Author) commented Aug 27, 2017

@j-towns do you have any thoughts?

Right now this TJP business is totally separated from the VJPs, but since it looks like we can adapt many of the VJPs to broadcast along leading dimensions without too much fuss, we might want to merge this TJP functionality with the VJP functionality, rather than keep things separate.

It would be interesting to take a look at whether we can easily get JTPs from the JVP definitions, though I think they would need to broadcast along trailing dimensions, which is a bit weird compared to numpy's convention of broadcasting along leading dimensions (because of the row-major default).

@j-towns (Collaborator) commented Aug 27, 2017

Can't the JTP primitives broadcast along leading dimensions, just like the TJPs? I can't see why we'd have them the other way round (we don't require anything to be transposed for VJPs/JVPs).

I can't see any major reason why this generalisation would be bad (and I can see use cases where you want to quickly calculate a Jacobian or Hessian say). My only very slight concern on first looking at this is assert_vspace_compatible – it needs to do more than it does in the above implementation, checking that dtypes match etc, and this might affect performance.

@j-towns (Collaborator) commented Aug 27, 2017

That is supposing that we implemented this generalisation as the default behaviour, which maybe is not what you were suggesting?

@mattjj (Contributor, Author) commented Aug 27, 2017

Yes, I think we should consider making this behavior the default. Maybe the interface could involve two jacobian functions: one which does this TJP thing (but might raise an error if it encounters a VJP implementation that can't handle broadcasting along leading dimensions, or an intermediate value for which we don't know how to prepend leading dimensions), and one which does the current thing (calling the VJP repeatedly on standard basis vectors of the same shape as the output). That is, the main (EDIT) code changes would be

  1. we would try to make most VJP implementations broadcast along leading dimensions,
  2. there would only be one backward_pass function, and it would handle leading dimensions like the version in this PR,
  3. our vspaces for array types would gain the _contract, _product, and _kronecker_tensor methods.

(EDIT) The main interface changes would be

  1. the VJP functions returned by make_vjp would (attempt to) broadcast along leading dimensions,
  2. we would add a second jacobian function, basically the one implemented in this PR, which calls the VJP function on a Kronecker tensor, possibly under a different name.
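As a rough illustration of that second jacobian function (the PR's actual _kronecker_tensor isn't shown in this thread, so this is a guess at its behavior, not the real implementation): for a plain array vspace, the Kronecker tensor can be an identity map reshaped to shape + shape, so that contracting it along trailing dimensions acts as the identity, and feeding it to a broadcasting VJP yields the whole Jacobian in one call.

```python
import numpy as np

def kronecker_tensor(shape):
    # Hypothetical sketch of what an ArrayVSpace._kronecker_tensor
    # might compute: the identity on this vspace, reshaped so that
    # tjp(kronecker_tensor(f(x).shape)) is the full Jacobian.
    size = int(np.prod(shape))
    return np.eye(size).reshape(shape + shape)

K = kronecker_tensor((2, 3))
assert K.shape == (2, 3, 2, 3)

# Contracting K against any tensor along its trailing dims is the identity:
A = np.arange(6.).reshape(2, 3)
assert np.allclose(np.tensordot(K, A, 2), A)
```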

Re: performance, assert_vspace_match does the same dtype checks that we'd want to add to assert_vspace_compatible when we do the equality check, right? Or are they more complicated for TJPs somehow?

The reason I suggested broadcasting along trailing dimensions for JTPs was that I was thinking in terms of tensordot: that is, just like np.tensordot(J, V, dims) would contract the trailing dimensions of J with the leading dimensions of V, so too could our JTPs. More precisely, we could have

make_jtp(f)(x)[0](V) === np.tensordot(jacobian(f)(x), V, np.ndim(x))

hence the trailing dimensions of V would come along for the ride. I think this matches how make_tjp works, as in

make_tjp(f)(x)[0](V) === np.tensordot(V, jacobian(f)(x), np.ndim(f(x)))

However, you're right that we could choose a different convention, and just say the contraction works along trailing dimensions of V for both TJPs and JTPs. It seems slightly messier to specify in terms of tensordot, though.
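The two tensordot identities above can be sanity-checked with plain numpy on a linear function f(x) = W @ x, whose Jacobian is just W (no autograd code needed, so this is only a shape check of the proposed conventions):

```python
import numpy as np

# f(x) = W @ x with x: (4,), f(x): (3,), so jacobian(f)(x) = W: (3, 4).
W = np.arange(12.).reshape(3, 4)
x = np.ones(4)
J = W

# TJP convention: contract V's trailing dims with J's leading (output)
# dims, so V's leading dims come along for the ride.
V = np.ones((5, 3))               # a stack of 5 output cotangents
tjp_out = np.tensordot(V, J, 1)   # axes = np.ndim(f(x)) == 1
assert tjp_out.shape == (5, 4)

# JTP convention from the comment above: contract J's trailing (input)
# dims with U's leading dims, so U's trailing dims come along.
U = np.ones((4, 5))               # a stack along the trailing dim
jtp_out = np.tensordot(J, U, 1)   # axes = np.ndim(x) == 1
assert jtp_out.shape == (3, 5)
```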

@mattjj (Contributor, Author) commented Aug 27, 2017

Separate but related, since we're talking about VJPs, JVPs, TJPs, and JTPs: what do you think about changing the naming convention to "Jacobian operator" or "JO" for forward-mode and "Jacobian adjoint operator" or "JAO" for reverse-mode? I still constantly mess up VJP vs JVP when typing things.

@dougalm (Contributor) commented Aug 27, 2017

I agree that we should change vjp/jvp to a more distinct pair. Now that we use both, it's easy to confuse them.

@j-towns (Collaborator) commented Aug 27, 2017

Cool, I didn't realise you had tensordot in mind. I guess the JTP I'm suggesting in terms of tensordot would be

make_jtp(f)(x)[0](V) === np.tensordot(V, np.transpose(jacobian(f)(x)), np.ndim(x))

That isn't pretty, but I don't think we should let tensordot dictate our interface.

Re implementing primitive TJPs, I reckon that'll be straightforward, since all the VJPs are linear in v (or g, as it's still referred to in the code), and in the worst case I think we'll always be able to do the broadcasting with einsum.
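A sketch of that einsum fallback (hypothetical, not code from this PR): the ordinary VJP of np.dot(A, B) with respect to A is g @ B.T, and an ellipsis in the einsum subscripts lets g carry arbitrary leading stack dimensions.

```python
import numpy as np

A = np.ones((3, 4))
B = np.ones((4, 2))

# Vanilla VJP of np.dot(A, B) w.r.t. A: the cotangent g has shape (3, 2).
g = np.ones((3, 2))
vjp = np.dot(g, B.T)
assert vjp.shape == (3, 4)

# TJP version via einsum: '...' lets G carry leading stack dimensions.
G = np.ones((7, 3, 2))
tjp = np.einsum('...ij,kj->...ik', G, B)
assert tjp.shape == (7, 3, 4)

# The stacked result agrees with applying the vanilla VJP slice-by-slice.
assert np.allclose(tjp[0], vjp)
```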

I agree the jvp/vjp names are quite annoying to type and don't generalise to tensors. Maybe JO/JAO is a little better. What do others think, @dougalm @duvenaud?

@j-towns (Collaborator) commented Aug 27, 2017

Also, surely it wouldn't be too difficult to have just one jacobian interface which attempts the new TJP approach and, if it encounters a primitive where that doesn't work, falls back locally to the old/current method.

@mattjj (Contributor, Author) commented Aug 27, 2017

Tensordot was written that way for a reason, though, namely that it generalizes matrix multiplication. I don't think your transpose works quite right, because it would reverse the dimensions being contracted over: if the Jacobian is of a function with input shape (3, 4) and output shape (3, 4), I would expect to be able to contract the Jacobian against a tensor of shape (3, 4, 5) on the input side (i.e. a JTP), but that transpose would give it shape (4, 3, 4, 3).
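The shape argument is easy to check numerically (zeros stand in for an actual Jacobian, since only shapes matter here):

```python
import numpy as np

# For f with input shape (3, 4) and output shape (3, 4), the Jacobian
# has shape (3, 4, 3, 4): output dims first, then input dims.
J = np.zeros((3, 4, 3, 4))

# np.transpose reverses ALL axes, so the dims to be contracted come
# out in the wrong order:
assert np.transpose(J).shape == (4, 3, 4, 3)

# A JTP wants to contract J's input dims (3, 4) against the leading
# dims of a (3, 4, 5) tensor, which tensordot does directly:
V = np.zeros((3, 4, 5))
assert np.tensordot(J, V, 2).shape == (3, 4, 5)
```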

@duvenaud (Contributor)

I agree that jvp/vjp is confusing, although if we change those it should probably be to something more distinct than jo/jao. Here are a couple of possibilities:

fwd/rev (not very classy or descriptive, but at least memorable)
drd/adj (directional derivative and adjoint)

By the way, I'm still not clear on the motivation for JTPs. Is it just for things like jacobian? In principle, if every jvp supported broadcasting along its first dimension, would we need jtps?

@mattjj (Contributor, Author) commented Aug 27, 2017

That's right, TJPs (and JTPs) mean broadcasting along leading (or trailing) dimensions. One application is to compute Jacobian matrices/tensors without having an outer Python for loop, which is useful for the same reason we don't write matrix-matrix multiplication as a Python for loop over matrix-vector multiplication. The hope here is that we can add this broadcasting support with no cost to the VJP/JVP use case (though possibly at the cost of a couple more lines of code in core.py).
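A toy version of the two jacobian strategies, using a hand-written VJP for a linear map in place of make_vjp (so this is a sketch of the idea, not the PR's code):

```python
import numpy as np

W = np.arange(12.).reshape(3, 4)
f = lambda x: W @ x                # f: (4,) -> (3,)
x = np.ones(4)
vjp = lambda v: v @ W              # v -> v J; broadcasts over leading dims

# Current strategy: a Python loop of VJP calls over the standard basis
# vectors of the output space.
loop_jac = np.stack([vjp(e) for e in np.eye(3)])

# TJP strategy: one broadcasted call on the stacked identity, no loop.
batched_jac = vjp(np.eye(3))

assert np.allclose(loop_jac, batched_jac)
assert batched_jac.shape == (3, 4)
```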

I like the descriptiveness of VJP/JVP and JAO/JO. Actually drd is pretty good, though 'adjoint' seems unclear.

@mattjj (Contributor, Author) commented Aug 27, 2017

@j-towns re: having just one jacobian function, I prefer not doing magic and instead making things explicit, though it's an option we could consider.

@dougalm (Contributor) commented Aug 27, 2017

I'm going to throw "JTO" out there, as in "Jacobian transpose operator". It's less awkward to pronounce than JAO, with its diaeresis. And "Jacobian transpose" might even come closer to capturing what we mean: our VJPs of holomorphic primitives actually multiply complex vectors by the Jacobian's transpose, not by its Hermitian adjoint.

@mattjj (Contributor, Author) commented Aug 28, 2017

Oh nice, good insight about the complex case. Adjoint is still nice because it suggests a matrix-free linear map (to me at least), but overall I like JTO better.

@j-towns (Collaborator) commented Aug 28, 2017

@mattjj I think you're right about not doing implicit magic stuff. On the subject of which way the jtp should broadcast, I would do it as follows, explained without referring to tensordot (since I think you're right that my version would actually look even more messy than what I wrote above):


Both jvp(f)(x) and vjp(f)(x) are linear operators which broadcast in a similar way to some of numpy's matrix linear algebra routines (see here for explanation). They treat their input as a 'stack' of arrays to which the linear operator should be applied.

In particular vjp(f)(x) maps from shape (..., *f(x).shape) tensors to shape (..., *x.shape), while jvp(f)(x) maps from shape (..., *x.shape) to shape (..., *f(x).shape). Or to use a more concise notation:

jvp(f)(x): (..., *x.shape   ) -> (..., *f(x).shape)
vjp(f)(x): (..., *f(x).shape) -> (..., *x.shape   )

That will be the simplest to implement and I think it's pretty straightforward to explain/understand what's going on too.
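For comparison, numpy's matmul already uses exactly this "stack of arrays" convention, and a broadcasting VJP for a linear map follows the proposed signature (sketch with a hand-written vjp, not this PR's code):

```python
import numpy as np

# np.matmul treats leading dimensions as a stack of matrices:
A = np.ones((7, 3, 3))
B = np.ones((7, 3, 5))
assert (A @ B).shape == (7, 3, 5)

# The proposed signatures mirror this. For the linear map f(x) = W @ x
# with W of shape (3, 4), a broadcasting vjp maps (..., 3) -> (..., 4):
W = np.ones((3, 4))
vjp = lambda v: v @ W                  # broadcasts over leading stack dims
assert vjp(np.ones(3)).shape == (4,)                # plain VJP
assert vjp(np.ones((7, 2, 3))).shape == (7, 2, 4)   # stacked TJP
```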

@mattjj (Contributor, Author) commented Aug 28, 2017

@j-towns I like your definition and explanation! After reading it I realized that's more like how np.matmul works, and np.matmul is probably the best thing to model behavior on.

I think my next step on this PR is to see how well we can generalize the core's VJP logic to be the TJP logic. If that looks good (i.e. if that doesn't have much complexity cost), we could merge it into dougal-dev and then be free to add broadcasting support to VJPs/JVPs as we get to them.

EDIT: by the way, anyone looking at a redesign of the test code might want to keep in mind the potential for broadcasting in some VJPs/JVPs.

@j-towns mentioned this pull request Oct 14, 2017
@j-towns (Collaborator) commented Oct 17, 2017

Another point re: renaming jvp/vjp: it would be handy if there were a collective noun and abbreviation to refer to both at the same time. For example, if we called them Left Jacobian Operator (LJO) and Right Jacobian Operator (RJO), then we could refer to both as the Jacobian Operators (JOs). I'm not saying that's a good pair to go with, though.

I don't think we really get this with JO/JTO: collectively we can't refer to them as JOs, because that could imply a collection of just JOs (with no JTOs). It would be handy to have a word which covers both for this PR. I've been using derivs, but it feels clumsy.

@mattjj (Contributor, Author) commented Oct 17, 2017

I like that point. There's a nice symmetry with LJO and RJO, though I'm not sure if they're as descriptive.
