Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JAX backend of TreeMHN #30

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open

JAX backend of TreeMHN #30

wants to merge 25 commits into from

Conversation

pawel-czyz
Copy link
Member

@pawel-czyz pawel-czyz commented Nov 24, 2023

Premise

The $Q$ rate matrix is built out of rates $Q_{ij} = \lambda_{\pi_{ij}}(\theta)$, where $\pi$ is a lineage, and $\theta$ is the (log-) mutual hazard network.

We can construct the necessary lineages $\pi_{ij}$ in the preprocessing step, so that for a given $\theta$ we can construct the $Q$ matrix using only JAX operations. This allows us to use fast numerical algebra operations as well as to obtain gradient of the loglikelihood automatically.
Additionally, by implementing custom forward substitution algorithm, we can work with the $\log$-values, which helps with numerical stability.

Limitations

  1. As the shapes of the structures built for different trees are different, JIT has to recompile the function for each of the trees. The compilation for 100 trees took about ~45 second on my computer.
  2. We use $O(N)$ memory, where $O(N)$ is the number of subtrees and $O(N^2)$ time.
  3. We use padding, which can induce some additional cost, but makes it JAX-compatible.
  4. To make the exit rates dependent on present mutations one has to construct the paths for exit rates. Currently we use placeholder values, which assume empty paths. (I.e., the rates do not depend on the mutations in this form, although the fix is easy once one decides how the paths should be constructed.)

@laurenzkeller
Copy link
Collaborator

laurenzkeller commented Nov 24, 2023

I'm not sure if this version will be faster. Creating the paths matrix is indeed a good idea in terms of speed (we only need to create it once for each tree), however the preprocessing step will probably consume even more memory this way (I tried something similar). Regarding forward substitution: If you want to have as many calculations as you have non-zero entries, then you would need to iterate through the columns of the V matrix, not through the rows (because we solve the system V_transposed * x = b). However, one of the benefits of iterating through the rows first would be lost: When we determine a diagonal entry we can simultaneously find the off-diagonal entries in that same row (we don't need to recalculate the lambdas on the off-diagonal). If we iterate through the columns on the other hand, then we cannot use the diagonal entry to calculate the off-diagonal entries in the same column. Maybe it is still possible to calculate each distinct lambda exactly once, but you would be jumping around in the paths matrix all the time (so when you calculate an off-diagonal entry you would add the value to the corresponding diagonal entry).

@pawel-czyz pawel-czyz marked this pull request as ready for review December 7, 2023 00:16
@pawel-czyz pawel-czyz added 🚂 enhancement New feature or request 👕 effort M labels Dec 7, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
👕 effort M 🚂 enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants