-
Notifications
You must be signed in to change notification settings - Fork 912
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
sketching out what TJPs might look like #280
base: dougal-dev
Are you sure you want to change the base?
Conversation
23b09bd
to
8149360
Compare
@j-towns do you have any thoughts? Right now this TJP business is totally separated from the VJPs, but since it looks like we can adapt many of the VJPs to broadcast along leading dimensions without too much fuss, we might want to merge this TJP functionality with the VJP functionality, rather than keep things separate. It would be interesting to take a look at whether we can easily get JTPs from the JVP definitions, though I think they would need to broadcast along trailing dimensions, which is a bit weird compared to numpy's convention of broadcasting along leading dimensions (because of the row-major default). |
Can't the JTP primitives broadcast along leading dimensions, just like the TJPs? I can't see why we'd have them the other way round (we don't require anything to be transposed for VJPs/JVPs). I can't see any major reason why this generalisation would be bad (and I can see use cases where you want to quickly calculate a Jacobian or Hessian say). My only very slight concern on first looking at this is assert_vspace_compatible – it needs to do more than it does in the above implementation, checking that dtypes match etc, and this might affect performance. |
That is supposing that we implemented this generalisation as the default behaviour, which maybe is not what you were suggesting? |
Yes I think we should consider implementing this behavior as a default. Maybe the interface could involve two
(EDIT) The main interface changes would be
Re: performance, The reason I suggested broadcasting along trailing dimensions for JTPs was that I was thinking in terms of tensordot: that is, just like
hence the trailing dimensions of
However, you're right that we could choose a different convention, and just say the contraction works along trailing dimensions of |
Separate but related, since we're talking about VJPs, JVPs, TJPs, and JTPs: what do you think about changing the naming convention to "Jacobian operator" or "JO" for forward-mode and "Jacobian adjoint operator" or "JAO" for reverse-mode? I still constantly mess up VJP vs JVP when typing things. |
I agree that we should change vjp/jvp to a more distinct pair. Now that we use both, it's easy to confuse them. |
Cool, I didn't realise you had tensordot in mind. I guess the JTP I'm suggesting in terms of tensordot would be
Which isn't pretty, but I don't think we should let tensordot dictate our interface. Re implementing primitive tjps, I reckon that'll be straightforward, since all the vjps are linear in v (or I agree the jvp/vjp names are quite annoying to type and don't generalise to tensors. Maybe JO/JAO is a little better, what do others think |
Also surely it wouldn't be too difficult to have just one |
Tensordot was written that way for a reason though, namely that it generalizes matrix multiplication. I don't think your transpose works quite right because it would reverse the dimensions being contracted over, so if the jacobian is of a function with input shape (3,4) and output with shape (3,4), I would expect to be able to contract the Jacobian against a tensor of shape (3,4,5) on the input side (i.e. a JTP), but that transpose would give it shape (4,3,4,3). |
I agree that jvp/vjp is confusing, although if we change those it should probably be to something more distinct than jo/jao. Here are a couple of possibilities: fwd/rev (not very classy or descriptive, but at least memorable) By the way, I'm still not clear on the motivation for JTPs. Is it just for things like jacobian? In principle, if every jvp supported broadcasting along its first dimension, would we need jtps? |
That's right, TJPs (and JTPs) mean broadcasting along leading (or trailing) dimensions. One application is to compute Jacobian matrices/tensors without having an outer Python for loop, which is useful for the same reason we don't write matrix-matrix multiplication as a Python for loop over matrix-vector multiplication. The hope here is that we can add this broadcasting support with no cost to the VJP/JVP use case (though possibly at the cost of a couple more lines of code in core.py). I like the descriptiveness of VJP/JVP and JAO/JO. Actually drd is pretty good, though 'adjoint' seems unclear. |
@j-towns re: having just one jacobian function, I prefer not doing magic and instead making things explicit, though it's an option we could consider. |
I'm going to throw "JTO" out there, as in "Jacobian transpose operator". It's less awkward to pronounce that JAO with its diaeresis. And "Jacobian transpose" might even come closer to capturing what we mean: our VJPs of holomorphic primitives actually multiply complex vectors by the Jacobian's transpose, not by its Hermitian adjoint. |
Oh nice, good insight about the complex case. Adjoint is still nice because it suggests a matrix-free linear map (to me at least), but overall I like JTO better. |
@mattjj I think you're right about not doing implicit magic stuff. On the subject of which way the jtp should broadcast, I would do it as follows, explained without referring to tensordot (since I think you're right that my version would actually look even more messy than what I wrote above): Both In particular
That will be the simplest to implement and I think it's pretty straightforward to explain/understand what's going on too. |
@j-towns I like your definition and explanation! After reading it I realized that's more like how I think my next step on this PR is to see how well we can generalize the core's VJP logic to be the TJP logic. If that looks good (i.e. if that doesn't have much complexity cost), we could merge it into dougal-dev and then be free to add broadcasting support to VJPs/JVPs as we get to them. EDIT: by the way, anyone looking at a redesign of the test code might want to keep in mind the potential for broadcasting in some VJPs/JVPs. |
Another point re: renaming jvp/vjp. It would be handy if there was a collective noun and abreviation to refer to both at the same time. For example, if we called them Left Jacobian Operator (LJO) and Right Jacobian Operator (RJO) then we could refer to both as the Jacobian Operators (JOs) or something, I'm not saying that's a good one to go with though. I don't think we really get this with JO/JTO. Collectively we can't refer to them as JOs because that could imply a collection of just JOs (with no JTOs)... It would be handy to have a word which covers both for this pr. I've been using |
I like that point. There's a nice symmetry with LJO and RJO, though I'm not sure if they're as descriptive. |
This is very preliminary and probably not ready to merge, even into a dev branch, but I think it might be ready for some high-level feedback.
In particular, the amount of code redundancy is interesting. In tjp.py, which is like core.py, the functions/classes
make_tjp
,tjp_backward_pass
, andTJPNode
are all very similar to the VJP versions. I did have to changeassert_vspace_match
toassert_vspace_compatible
. The definition ofjacobian
, which I put at the end of tjp.py, is also a little interesting.In numpy_tjps.py, there are two main cases for the simple TJPs I've implemented so far:
The final interesting thing to look at is the additions to ArrayVSpace, which is related to #262 (and indeed I had that code open in a tab when I wrote these vspace functions!).