
Question on relation to univariate Taylor series propagation for higher-order derivatives #1

thisiscam opened this issue Dec 19, 2024 · 6 comments


@thisiscam

Hi,

Thank you for your interesting work! I’ve been reading your paper and found the idea of stochastically evaluating higher-order derivatives quite intriguing.

I was wondering if you could comment on how your technique relates to the known method of evaluating higher-order derivative tensors by propagating a set of univariate Taylor series. This idea is discussed in this paper and the Evaluating Derivatives book (Chapter 13).

Specifically, could your method be interpreted (or extended) as sampling a set of univariate Taylor series to propagate, with the aim of approximating the derivative tensor? I'd love to hear your thoughts on this connection or any fundamental differences.

Thanks in advance for your response!

@zekun-shi
Collaborator

Thanks for bringing up this highly relevant work! Upon reading the paper, I think the idea of evaluating arbitrary derivative tensor elements via forward propagation of univariate Taylor series is similar. I wasn't aware of this work, and based on my experience talking to people at NeurIPS, it is not well known within the ML community, even among people who work on AD in the ML context. The JAX team, who wrote the Taylor-mode AD I used in the paper, weren't aware of this technique and were quite surprised that one could do this.

So yes, your interpretation is right: STDE approximates the derivative tensor with randomized univariate Taylor series propagation.
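To make the mechanism concrete, here is a minimal stdlib sketch (illustrative only, not the paper's implementation; the cubic test function is my own choice) of recovering a higher-order derivative by pushing a truncated univariate Taylor series through a function:

```python
# Minimal truncated univariate Taylor-series ("jet") arithmetic.
# coeffs[i] is the coefficient of t**i; products are truncated at the
# series length, so + and * suffice for polynomial test functions.
import math

class Jet:
    def __init__(self, coeffs):
        self.coeffs = list(coeffs)

    def __add__(self, other):
        return Jet([a + b for a, b in zip(self.coeffs, other.coeffs)])

    def __mul__(self, other):
        n = len(self.coeffs)
        out = [0.0] * n
        for i, a in enumerate(self.coeffs):
            for j, b in enumerate(other.coeffs):
                if i + j < n:
                    out[i + j] += a * b
        return Jet(out)

def f(x):          # illustrative test function f(x) = x**3
    return x * x * x

# Push forward the order-2 jet x(t) = 2 + t around x0 = 2.
x = Jet([2.0, 1.0, 0.0])
y = f(x)

# The degree-k Taylor coefficient of f(x0 + t) is f^(k)(x0) / k!,
# so the second derivative is k! times the degree-2 coefficient.
second_derivative = math.factorial(2) * y.coeffs[2]  # f''(x) = 6x -> 12 at x0 = 2
```

Taylor-mode AD generalizes this beyond polynomials by supplying truncation-aware propagation rules for each primitive.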

@thisiscam
Author

thisiscam commented Jan 10, 2025

Thank you for your detailed response! I wanted to follow up with some additional thoughts and questions:

  1. I agree that the method of propagating univariate Taylor series to compute higher-order derivative tensors is not widely known in the ML community, which is indeed unfortunate given its mention in the classic Evaluating Derivatives book. It's great to see that your work brings attention to this approach, even if from a different perspective!

  2. That said, I think there's a key distinction between your method and GUW98. Specifically, GUW98 considers only first-order perturbations, whereas your work (as seen in Eq. (10) and Appendix F.1) includes higher-order perturbations.
    Specifically, (it appears to me that) GUW98 only considers series inputs of the form
    $$[v, 0 \ldots 0]$$ where $$v$$ is an integer linear combination of the basis vectors $e_i$,
    while your work (please correct me if I'm wrong) considers
    $$[\ldots 0 e_1 , 0, \ldots 0, e_2, 0, \ldots, 0, e_k]$$ that scatters the basis vector in the higher-order perturbations.
    In Eq. (10) of your paper, the coefficient $$C_{d_1 \dots d_k}$$ corresponds to the coefficient associated with each directional derivative. However, I couldn't find a clear algorithm or explicit description in your paper for computing $$C$$. Could you clarify if you have a formula or algorithm for computing $$C$$ that generalizes Eq. (13) in GUW98? [screenshot: Eq. (13) of GUW98] It would be interesting to explore whether including higher-order perturbations, as in your work, might lead to better algorithms or approximations in some contexts.

  3. I found your comment in the paper on fully mixed partial derivatives $$\frac{\partial^k f}{\partial x_1 \dots \partial x_k}$$ (i.e., maximally non-diagonal operators) particularly interesting. Your paper mentions that the minimum $$l$$ required for these cases is $$(1 + k)k/2$$. However, when generalizing your example, I noticed that calculating the fully mixed partial derivative using a single higher-order univariate derivative (without cancellation steps) seems to require an exponential $$l = O(k \cdot 2^k)$$-th order derivative. For instance:

    • When $$k = 2$$, positions $$[0, e_1, e_2]$$ work.
    • When $$k = 3$$, positions $$[0, 0, 0, e_1, 0, e_2, e_3]$$ work.
    • In general, the perturbation $$e_i$$ appears at index $$2^{k-1} + \ldots + 2^{k-i}$$ , leading to an overall length $$l = k 2^{k-1} - 1$$. Could you clarify if the quadratic $$l = (1 + k)k/2$$ formula in your paper offers a more efficient method or insight into improving this computation?

Looking forward to your thoughts!

@zekun-shi
Collaborator

Thanks again for your insightful questions!

  1. Upon further reading of GUW98, I realized that it uses a very different approach for evaluating partial derivatives, where only first-order perturbations are used, whereas my approach uses high-order perturbations, as you have pointed out. First I would like to clarify that the $C$ tensor in Eq. (10) of my paper corresponds to the coefficients of a linear derivative operator and is known. I guess what you are asking is: given a multi-index $i=(i_{1}, \dots, i_{N})$, how does one compute the partial derivative $\frac{\partial^{|i|}f}{\partial x_{1}^{i_{1}} \dots \partial x_{N}^{i_{N}}}$ by pushforwards of some univariate Taylor series with high-order perturbations? You are right that there isn't any description of the algorithm for computing the set of series inputs required for a given $i$. I did not include this in the paper as I thought it was out of scope, but I intend to publish the description and the code for that somewhere. Stay tuned!

  2. To achieve the minimum order $(1+k)k / 2$, one has to push forward multiple different jets and perform cancellations to remove the extra terms. I didn't include a formal proof of this in the paper, but it follows directly from Faà di Bruno's formula (Eq. 38 in my paper). Here is a short proof. Suppose we want to compute a fully mixed partial of order $M$: $\frac{\partial^{M}f}{\partial x_{1} \dots \partial x_{M}}$. We want to extract the term in the summation of Eq. (38) where $\prod_{j} \left(\frac{1}{j!} g^{(j)}(x) \right)^{p_{j}}$ contains the product of all the different perturbations $g^{(1)} \dots g^{(M)}$. The lowest $k$ where this term appears is $M(M+1)/2 = 1 + 2 + \dots + M$, where $p_{i}=1$ for all $i\in [1,M]$. Of course, there are extra terms, but one can always perform cancellations to remove them. For example, taking your example of $M=2$, the correct cancellation steps are:
    $$\partial^{3}u(a,e_{1},e_{2},0)-\partial^{3}u(a,e_{1},0,0)=\frac{\partial^{2} u}{\partial x_{1} \partial x_{2}} \Big\vert_{x=a}$$
    The case of $M=3$ is slightly more complicated:
    $$\frac{1}{60}[\partial^{6}u(a,e_{1},e_{2},e_{3}, 0, 0, 0)-\partial^{6}u(a,e_{1},e_{2},0,0,0,0) - \partial^{6}u(a,e_{1},0,e_{3},0,0,0) + \partial^{6}u(a,e_{1},0,0,0,0,0)]$$
    These are a special case of the algorithm I mentioned above, whose purpose corresponds to Eq. (13) in GUW98.
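To check the $M=2$ identity numerically, here is a small self-contained sketch (interpreting $\partial^{3}u$ as the degree-3 Taylor coefficient of the pushforward; the test function $f = x_1^3 x_2$ is my own illustrative choice, picked so that the unwanted $\partial^{3}f/\partial x_1^3$ term is nonzero and the cancellation actually matters):

```python
# Verify the M = 2 cancellation: push forward the jets [a, e1, e2, 0]
# and [a, e1, 0, 0]; subtracting their degree-3 Taylor coefficients
# leaves exactly d^2 f / (dx1 dx2) at a.

ORDER = 3  # truncation order of the jets

def pmul(p, q):
    """Product of coefficient lists, truncated at t**ORDER."""
    out = [0.0] * (ORDER + 1)
    for i, a in enumerate(p):
        for j, b in enumerate(q):
            if i + j <= ORDER:
                out[i + j] += a * b
    return out

def f(x1, x2):     # illustrative test function f = x1**3 * x2
    return pmul(pmul(pmul(x1, x1), x1), x2)

a1, a2 = 2.0, 5.0

# Jet with both perturbations: x(t) = a + e1*t + e2*t**2.
full = f([a1, 1.0, 0.0, 0.0], [a2, 0.0, 1.0, 0.0])
# Jet with e2 zeroed: its t**3 coefficient carries only the unwanted
# pure term (1/6) * d^3 f / dx1^3, which the subtraction cancels.
pure = f([a1, 1.0, 0.0, 0.0], [a2, 0.0, 0.0, 0.0])

mixed_partial = full[3] - pure[3]   # d^2(x1^3 * x2)/(dx1 dx2) = 3*x1^2 = 12
```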

I'm unable to understand why the series you gave would work though (e.g. $[0,e_1,e_2]$). Could you explain your solution a bit more?

Also, feel free to continue the discussion if you have further questions or ideas!

@thisiscam
Author

thisiscam commented Feb 7, 2025

First I would like to clarify that, the $C$ tensor in Eq. (10) in my paper corresponds to the coefficients of a linear derivative operator and is known.

What is $C$ here, and could you clarify what you mean by "is known"? I probably have a misunderstanding, but I thought in your Eq. (21), 1/330 and 1/200200 are entries of $C$? Is that different from interpreting $C$ as the coefficient of Eq. (13) in GUW98 above, i.e. the term $$\binom{i}{k}(-1)^{|i - k|}$$?

but one can always perform cancellation to remove all the unnecessary terms.

Thanks for clarifying! If cancellation is allowed, you can do it with just first-order perturbations using Eq. (13) of GUW98, right? So the minimum order could be $k$ (i.e., $$|i|$$ in Eq. (13) of GUW98)? What am I missing here? It seems to me that there is a tradeoff between the order of the jet and the number of jet evaluations (i.e., required cancellation steps): the lower the order of the jet, the more cancellation steps one has to do, and vice versa.

I'm unable to understand why the series you gave would work though

It's been a while and I forgot how I came up with that construction! I remember the idea was to construct it so that lower-order perturbations contribute once and only once at the max order.
But here is the code that illustrates the idea: https://gist.github.com/thisiscam/5f6e16c1629408d2fca5542ccec11c91.
As you can see, with my construction, the higher order jet is a constant multiple of the desired mixed partial derivative (the constant multiple doesn't change with varying x0). I think one should be able to derive a formula for this constant multiple, but I have not tried.
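For readers who don't want to follow the link, the idea can be re-implemented in a few lines of stdlib Python (my own sketch, not the linked gist; the test function is my own choice). For $k=2$, place $e_1$ at $t^2$ and $e_2$ at $t^3$: the only way powers from $\{2,3\}$ can sum to 5 is $2+3$, so the mixed partial surfaces alone at the top order, with no cancellation passes:

```python
# Single-jet extraction of a fully mixed partial, k = 2.
# x1(t) = a1 + t**2 and x2(t) = a2 + t**3; the t**5 coefficient of
# f(x1(t), x2(t)) isolates d^2 f / (dx1 dx2) at a.

ORDER = 5

def pmul(p, q):
    """Product of coefficient lists, truncated at t**ORDER."""
    out = [0.0] * (ORDER + 1)
    for i, a in enumerate(p):
        for j, b in enumerate(q):
            if i + j <= ORDER:
                out[i + j] += a * b
    return out

def f(x1, x2):     # illustrative test function f = x1**2 * x2
    return pmul(pmul(x1, x1), x2)

a1, a2 = 3.0, 5.0
x1 = [a1, 0.0, 1.0, 0.0, 0.0, 0.0]   # x1(t) = a1 + t**2
x2 = [a2, 0.0, 0.0, 1.0, 0.0, 0.0]   # x2(t) = a2 + t**3

out = f(x1, x2)
mixed_partial = out[5]   # d^2(x1^2 * x2)/(dx1 dx2) = 2*x1 = 6 at a1 = 3
```

In this Taylor-coefficient convention the constant multiple happens to be 1 at the top order; with derivative-normalized jet outputs it becomes a fixed combinatorial factor independent of x0, consistent with the observation above.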

@zekun-shi
Collaborator

What is C here and could you clarify what you mean by "is known"?

This C corresponds to the linear derivative operators. For example, for the Laplacian operator, C is simply the identity matrix. The 1/330 and 1/200200 in Eq. 21 in my paper are the "constant_ratio" in your gist. The formula for this ratio is given by Faà di Bruno's formula.

It seems to me that there is a tradeoff between the order of the jet and the number of jet evaluations (i.e., required cancellation steps): the lower the order of jet, the more cancellation steps one have to do, and vice versa.

Yes, this is my observation as well. The $l=k(k+1) / 2$ bound I reported is for the scheme where only standard basis $e_i$ are used for perturbations. With the approach of GUW98, $l=k$ but more cancellation steps, i.e. more jet forwards, are needed. For example, when $M=3$, my approach (see my previous reply) requires 4 jet forwards, while the GUW98 approach would require 7 jet forwards. I will do some analysis later to see exactly how different they are.

One further comment: I think going for higher jets seems to be slightly better. Here's the sketch:

  • both GUW98 and my cancellation algorithm require an exponential number of jet forwards
  • most jet rules use recursion to compute the higher-order jet output from the lower order where the complexity for each step is of roughly the same order
  • therefore the compute scaling is roughly $O(2^k)k$ vs $O(2^k)$; the case for memory scaling is similar

It's been a while and I forgot how I come up with that construction! I remember the idea was to construct it so that lower order perturbations will contribute once and only once at the max order.

Thanks for posting the gist, this clears up a lot! I thought you were using $[0,e_1,e_2]$ for $M=2$, but in fact the perturbations you used were $[0,e_1,e_2,0,0]$, which totally makes sense. As mentioned, the constant multiple in your gist is given by Faà di Bruno's formula. See my gist for an example implementation.

@thisiscam
Author

I'd like to add one further comment regarding the use of higher jets, which seems to be slightly better based on the initial analysis:

The complexity analysis should also include the order of jet forward. The complexity in implementations like jax.experimental.jet scales quadratically with the jet dimension ($$O(n^2)$$); the textbook reference I shared earlier describes a method to achieve a $$O(n \log(n))$$ complexity using FFT (which I happen to have preliminarily implemented and plan to make public). Nonetheless, either complexity could be factored in when trading off jet order vs. number of jet forwards.
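To illustrate the FFT route (a from-scratch stdlib sketch, not the implementation I mentioned): a truncated jet product is just a truncated polynomial product, so a radix-2 FFT brings each multiplication down to $O(n \log n)$:

```python
# Truncated polynomial (jet) multiplication two ways: the direct
# O(n^2) convolution and an O(n log n) FFT-based product.
import cmath

def fft(a, invert=False):
    """Recursive radix-2 FFT; len(a) must be a power of two."""
    n = len(a)
    if n == 1:
        return list(a)
    even = fft(a[0::2], invert)
    odd = fft(a[1::2], invert)
    sign = 1 if invert else -1
    out = [0j] * n
    for k in range(n // 2):
        w = cmath.exp(sign * 2j * cmath.pi * k / n)
        out[k] = even[k] + w * odd[k]
        out[k + n // 2] = even[k] - w * odd[k]
    return out

def pmul_naive(p, q, order):
    """Direct convolution, truncated at t**order."""
    out = [0.0] * (order + 1)
    for i, a in enumerate(p):
        for j, b in enumerate(q):
            if i + j <= order:
                out[i + j] += a * b
    return out

def pmul_fft(p, q, order):
    """Same truncated product via pointwise multiplication in Fourier space."""
    n = 1
    while n < len(p) + len(q) - 1:
        n *= 2
    fp = fft([complex(c) for c in p] + [0j] * (n - len(p)))
    fq = fft([complex(c) for c in q] + [0j] * (n - len(q)))
    full = fft([x * y for x, y in zip(fp, fq)], invert=True)
    return [(c / n).real for c in full[:order + 1]]

p, q = [1.0, 2.0, 3.0], [4.0, 5.0]
direct = pmul_naive(p, q, 3)   # (1 + 2t + 3t^2)(4 + 5t) = 4 + 13t + 22t^2 + 15t^3
fast = pmul_fft(p, q, 3)       # same, up to floating-point error
```

In practice one would of course use a library FFT rather than this recursive sketch; the point is only that the per-multiplication cost drops from quadratic to quasilinear in the jet order.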

Additionally, it's also desirable to consider the simultaneous computation of multiple mixed partial derivatives. GUW98 offers an optimized approach for sharing jet forward computations (see Eq 17) that computes all entries of the high-order derivatives tensor with a better complexity.
It would be interesting to understand the complexity-optimal way of computing any subset of higher-order derivative entries through multiple jet forwards, in the most general setting. By "the most general setting," I mean one that subsumes both approaches: GUW98 uses first-order perturbations in a non-standard basis, whereas your approach places standard basis vectors at higher jet orders. Ideally, a general method would integrate both.

the constant multiple in your gist is given by Faa di Bruno's formula.

I see! Thanks!
