
Relatedness matrix-vector product #2980

Merged
merged 1 commit into tskit-dev:main on Sep 25, 2024

Conversation

petrelharp
Contributor

@petrelharp petrelharp commented Aug 29, 2024

Here's an implementation of the "relatedness matrix-vector" product operation. (WIP; currently just in Python.) It is a simplification of the code in #2710.

This could be generalized to the following: given a set of sample weights $w$ and a summary function $f(\cdot)$, for a node $n$ in tree $T$ let $w_T(n)$ be the sum of the weights of all samples below $n$, and then for each sample $u$ compute
$$S(u) = \sum_T \ell_T \sum_{n \ge_T u} (t_{p(n)} - t_n) f(w_T(n)) ,$$
i.e., the sum, over all trees and all nodes above the sample, of the summary function of the node's weight multiplied by the area (branch length times tree span) of the edge above the node.
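
For reference, here is a naive per-tree version of $S(u)$ (a quadratic-time sketch for checking only, written against the standard tskit Python tree-traversal API; not the incremental algorithm described below):

import numpy as np

import tskit


def naive_S(ts, w, f=lambda x: x):
    # w has one weight per sample; returns S(u) for every sample u
    samples = ts.samples()
    index = {s: j for j, s in enumerate(samples)}
    S = np.zeros(len(samples))
    for tree in ts.trees():
        for j, u in enumerate(samples):
            n = u
            while tree.parent(n) != tskit.NULL:
                # w_T(n): summed weight of the samples below n
                wn = sum(w[index[s]] for s in tree.samples(n))
                # area of the edge above n: branch length times tree span
                S[j] += tree.span * (tree.time(tree.parent(n)) - tree.time(n)) * f(wn)
                n = tree.parent(n)
    return S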

Besides the left/right child/sib arrays, it just needs to keep track of three (num nodes)-length vectors as it iterates over trees: the sum of the weights below the node (w); the last position at which the node's contribution was flushed (x); and the total contribution to the output for all samples below the node (stack). Every time we add or remove an edge we need to move all the contributions on the path above the edge down to the children.
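
In pseudocode, the per-node state and the "flush" step look roughly like this (a sketch using the array names from the implementation below; f is the summary function, and the arguments are passed explicitly just for illustration):

def flush(u, pos, w, x, stack, time, parent, left_child, right_sib, f):
    # add the area of the edge above u accumulated since x[u] ...
    stack[u] += (time[parent[u]] - time[u]) * (pos - x[u]) * f(w[u])
    x[u] = pos
    # ... then hand u's accumulated contribution down to its children
    c = left_child[u]
    while c != -1:
        stack[c] += stack[u]
        c = right_sib[c]
    stack[u] = 0.0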


codecov bot commented Aug 29, 2024

Codecov Report

Attention: Patch coverage is 96.41577% with 10 lines in your changes missing coverage. Please review.

Project coverage is 89.78%. Comparing base (394b84b) to head (75ba0c3).
Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
c/tskit/trees.c 96.26% 5 Missing and 3 partials ⚠️
python/_tskitmodule.c 96.29% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2980      +/-   ##
==========================================
+ Coverage   89.73%   89.78%   +0.05%     
==========================================
  Files          29       29              
  Lines       31600    31879     +279     
  Branches     6122     6170      +48     
==========================================
+ Hits        28355    28624     +269     
- Misses       1853     1859       +6     
- Partials     1392     1396       +4     
Flag Coverage Δ
c-tests 86.68% <96.26%> (+0.13%) ⬆️
lwt-tests 80.78% <ø> (ø)
python-c-tests 89.03% <96.29%> (+0.05%) ⬆️
python-tests 99.02% <100.00%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.

Files with missing lines Coverage Δ
python/tskit/trees.py 98.80% <100.00%> (+<0.01%) ⬆️
python/_tskitmodule.c 89.03% <96.29%> (+0.05%) ⬆️
c/tskit/trees.c 90.66% <96.26%> (+0.21%) ⬆️

@jeromekelleher
Member

Is it worth doing a numba version of this to see how it performs before we get into C implementation?

@hanbin973
Contributor

hanbin973 commented Sep 1, 2024

I figured out why the following expression

$$ S(u) = \sum_T \ell_T \sum_{n \ge_T u} (t_{p(n)} - t_n) f(w_T(n)) $$

is the $u$-th entry of $\mathrm{GRM} \cdot w$ when $f(x)=x$.

For the uncentered GRM, the $u$-th entry of $v = \mathrm{GRM} \cdot w$ is

$$ \sum_T \ell_T \sum_{s:samples} t_{T, us} w_s $$

where $t_{T, us}$ is the time from the most recent common ancestor of $u$ and $s$ to the root in tree $T$.
All we have to show is that

$$ \sum_{s:samples} t_{T,us} w_s = \sum_{n \ge_T u} (t_{p(n)} - t_n) w_T(n) $$

To see this, substitute the following into the above:

$$ t_{T,us} = \sum_{n \ge_T u,s} (t_{p(n)}-t_n) $$

and change the order of summation:

$$ \sum_s t_{T,us} w_s = \sum_s \left[\sum_{n \ge_T u,s} (t_{p(n)}-t_n) \right] w_s = \sum_{n \ge_T u} (t_{p(n)} -t_n) \sum_{s:n \ge_T s} w_s = \sum_{n \ge_T u} (t_{p(n)} - t_n) w_T(n) $$

where the last equality uses the definition $w_T(n) = \sum_{s : n \ge_T s} w_s$.

Was this kind of derivation what you had first in mind? I was only able to do this after knowing the results.
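
For instance, the identity can be checked numerically on a single simulated tree (a sketch; the simulation parameters are arbitrary):

import msprime
import numpy as np

import tskit

ts = msprime.sim_ancestry(5, random_seed=1)
tree = ts.first()
samples = ts.samples()
index = {s: j for j, s in enumerate(samples)}
w = np.random.default_rng(1).normal(size=len(samples))
root_time = tree.time(tree.root)

for u in samples:
    # left-hand side: t_{T,us} is the time from mrca(u, s) up to the root
    lhs = sum((root_time - tree.time(tree.mrca(u, s))) * w[index[s]] for s in samples)
    # right-hand side: branches above u, weighted by the subtree weight w_T(n)
    rhs = 0.0
    n = u
    while tree.parent(n) != tskit.NULL:
        wn = sum(w[index[s]] for s in tree.samples(n))
        rhs += (tree.time(tree.parent(n)) - tree.time(n)) * wn
        n = tree.parent(n)
    assert np.isclose(lhs, rhs)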

@petrelharp
Contributor Author

Is it worth doing a numba version of this to see how it performs before we get into C implementation?

My feeling is "no" because we've already basically got the C implementation in #2710. However, FYI:

  1. I've got to finalize some other relatedness-related things before finishing this off (e.g., the 'centering' options)
  2. I have a possibly more efficient and simpler method in my head

@petrelharp
Contributor Author

Was this kind of derivation what you had first in mind? I was only able to do this after knowing the results.

Exactly! (Well, you've written it out in more detail than I had, which is helpful!)

@hanbin973
Contributor

Is it worth doing a numba version of this to see how it performs before we get into C implementation?

# MIT License
#
# Copyright (c) 2024 Tskit Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Test cases for matrix-vector product stats
"""
import msprime
import numpy as np
import numba
from numba import i4, f8
from numba.experimental import jitclass

import tskit

# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when
# we can remove this.


# Implementation note: the class structure here, where we pass in all the
# needed arrays through the constructor was determined by an older version
# in which we used numba acceleration. We could just pass in a reference to
# the tree sequence now, but it is useful to keep track of exactly what we
# require, so leaving it as it is for now.

spec = [
        ('parent', i4[:]),
        ('left_sib', i4[:]),
        ('right_sib', i4[:]),
        ('left_child', i4[:]),
        ('right_child', i4[:]),
        ('num_samples', i4[:]),
        ('edges_left', f8[:]),
        ('edges_right', f8[:]),
        ('edges_parent', i4[:]),
        ('edges_child', i4[:]),
        ('edge_insertion_order', i4[:]),
        ('edge_removal_order', i4[:]),
        ('sequence_length', f8),
        ('nodes_time', f8[:]),
        ('samples', i4[:]),
        ('position', f8),
        ('virtual_root', i4),
        ('x', f8[:]),
        ('stack', f8[:]),
        ('NULL', i4)
       ] 

@jitclass(spec)
class RelatednessVector:
    def __init__(
        self,
        num_nodes,
        samples,
        nodes_time,
        edges_left,
        edges_right,
        edges_parent,
        edges_child,
        edge_insertion_order,
        edge_removal_order,
        sequence_length
    ):
        # virtual root is at num_nodes; virtual samples are beyond that
        N = num_nodes + 1 + len(samples)
        # Quintuply linked tree
        self.parent = np.full(N, -1, dtype=np.int32)
        self.left_sib = np.full(N, -1, dtype=np.int32)
        self.right_sib = np.full(N, -1, dtype=np.int32)
        self.left_child = np.full(N, -1, dtype=np.int32)
        self.right_child = np.full(N, -1, dtype=np.int32)
        # Sample lists refer to sample *index*
        self.num_samples = np.full(N, 0, dtype=np.int32)
        # Edges and indexes
        self.edges_left = edges_left
        self.edges_right = edges_right
        self.edges_parent = edges_parent
        self.edges_child = edges_child
        self.edge_insertion_order = edge_insertion_order
        self.edge_removal_order = edge_removal_order
        self.sequence_length = sequence_length
        self.nodes_time = nodes_time
        self.samples = samples
        self.position = 0
        self.virtual_root = num_nodes
        self.x = np.zeros(N, dtype=np.float64)
        self.stack = np.zeros(N, dtype=np.float64)
        self.NULL = -1

        for j, u in enumerate(samples):
            self.num_samples[u] = 1
            v = num_nodes + 1 + j
            self.insert_branch(u, v)
            self.num_samples[v] = 1

    def remove_branch(self, p, c):
        lsib = self.left_sib[c]
        rsib = self.right_sib[c]
        if lsib == -1:
            self.left_child[p] = rsib
        else:
            self.right_sib[lsib] = rsib
        if rsib == -1:
            self.right_child[p] = lsib
        else:
            self.left_sib[rsib] = lsib
        self.parent[c] = -1
        self.left_sib[c] = -1
        self.right_sib[c] = -1

    def insert_branch(self, p, c):
        self.parent[c] = p
        u = self.right_child[p]
        if u == -1:
            self.left_child[p] = c
            self.left_sib[c] = -1
            self.right_sib[c] = -1
        else:
            self.right_sib[u] = c
            self.left_sib[c] = u
            self.right_sib[c] = -1
        self.right_child[p] = c

    def remove_edge(self, p, c):
        assert p != -1
        self.stack[c] += self.get_z(c)
        self.remove_branch(p, c)

    def insert_edge(self, p, c):
        assert p != -1
        assert self.parent[c] == -1, "contradictory edges"
        self.insert_branch(p, c)

    def get_z(self, u):
        # For trait simulation: the edge above u contributes an
        # independent Gaussian with variance equal to the area of the
        # edge (branch length times the span since the last flush).
        p = self.parent[u]
        if p == self.NULL or u >= self.virtual_root:
            return 0.0
        time = self.nodes_time[p] - self.nodes_time[u]
        span = self.position - self.x[u]
        return np.sqrt(time * span) * np.random.normal()

    def get_root_path(self, u):
        """
        Returns the list of nodes back to the virtual root.
        """
        root_path = []
        p = u
        while p != self.NULL:
            root_path.append(p)
            p = self.parent[p]
        return root_path

    def push_down(self, u):
        """
        Add the edge above u to its stack, and then move u's stack to its
        children.
        """
        self.stack[u] += self.get_z(u)
        self.x[u] = self.position

        c = self.left_child[u]
        while c != self.NULL:
            self.stack[c] += self.stack[u]
            c = self.right_sib[c]

        self.stack[u] = 0
        
    def flush_root_path(self, root_path):
        """
        Clears all nodes on the path from the virtual root down to u
        by pushing the contributions of all their branches to the stack
        and pushing their stacks to their children.
        """
        
        j = len(root_path) - 1
        while j >= 0:
            p = root_path[j]
            self.push_down(p)
            j -= 1

    def run(self):
        sequence_length = self.sequence_length
        M = self.edges_left.shape[0]
        in_order = self.edge_insertion_order
        out_order = self.edge_removal_order
        edges_left = self.edges_left
        edges_right = self.edges_right
        edges_parent = self.edges_parent
        edges_child = self.edges_child

        j = 0
        k = 0
        # TODO: self.position is redundant with left
        left = 0
        self.position = left

        while k < M and left <= self.sequence_length:
            while k < M and edges_right[out_order[k]] == left:
                p = edges_parent[out_order[k]]
                c = edges_child[out_order[k]]
                root_path = self.get_root_path(p)
                self.flush_root_path(root_path)
                self.remove_edge(p, c)
                k += 1
            while j < M and edges_left[in_order[j]] == left:
                p = edges_parent[in_order[j]]
                c = edges_child[in_order[j]]
                if self.position > 0:
                    root_path = self.get_root_path(p)
                    self.flush_root_path(root_path)
                assert self.parent[p] == self.NULL or self.x[p] == self.position
                self.insert_edge(p, c)
                self.x[c] = self.position
                j += 1
            right = sequence_length
            if j < M:
                right = min(right, edges_left[in_order[j]])
            if k < M:
                right = min(right, edges_right[out_order[k]])
            left = right
            self.position = left

        # clear remaining things down to virtual samples
        for j, u in enumerate(self.samples):
            self.push_down(u)
            v = self.virtual_root + 1 + j
            self.remove_edge(u, v)

        out = np.zeros(len(self.samples))
        for out_i in range(len(self.samples)):
            i = out_i + self.virtual_root + 1
            out[out_i] = self.stack[i]
        return out


def relatedness_vector(ts, **kwargs):
    rv = RelatednessVector(
        ts.num_nodes,
        samples=ts.samples(),
        nodes_time=ts.nodes_time,
        edges_left=ts.edges_left,
        edges_right=ts.edges_right,
        edges_parent=ts.edges_parent,
        edges_child=ts.edges_child,
        edge_insertion_order=ts.indexes_edge_insertion_order,
        edge_removal_order=ts.indexes_edge_removal_order,
        sequence_length=ts.sequence_length,
        **kwargs,
    )
    return rv.run()

This is a quick numba version for trait simulation; I removed some debugging features and the weight updates that aren't needed for that use case.

For a 1-megabase region simulated for 10k people, it takes 0.4s.
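
For reference, a usage sketch (the simulation parameters here are illustrative, not the exact benchmark setup):

import msprime

ts = msprime.sim_ancestry(
    10_000,
    sequence_length=1e6,
    recombination_rate=1e-8,
    population_size=10_000,
    random_seed=42,
)
z = relatedness_vector(ts)  # one simulated trait value per sample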

@hanbin973
Contributor

Was this kind of derivation what you had first in mind? I was only able to do this after knowing the results.

Exactly! (Well, you've written it out in more detail than I had, which is helpful!)

I added a slightly more detailed version on Overleaf. Check it out if it sounds useful.

@petrelharp
Contributor Author

Hah, love it. Let's see - what do we compare the numba version to? Maybe the best thing to do is to make a scaling plot - say, for a 1e8-long sequence, runtime against number of samples, to verify it's scaling not-much-worse-than linear and extrapolate to what sample sizes are do-able?

@petrelharp
Contributor Author

I realized there is a simpler algorithm - see tests/test_relatedness_vector2.py. It relies on a lot of cancellation, so has the potential for floating point error in practice. @nspope thinks it's probably fine, based on similarity to other algorithms?

@petrelharp
Contributor Author

Here's the overview. Recall that what we want to compute is: given a vector of weights, for each sample s, the sum over all trees and over all edges above s of the area of the edge multiplied by the summed weight of the samples below it. In either algorithm:

  • Each node has x, the position at which the subtree below it last changed; stack, the accumulated contribution to all samples below it; and w, the total weight of all samples below it.
  • At any point in the algorithm, at position pos, summing stack[u] and (time[parent[u]] - time[u]) * (pos - x[u]) * w[u] over all ancestors u of a sample gives the output value so far for that sample - i.e., what would be output if the tree sequence ended at that position.

And, they differ in that:

  • In test_relatedness_vector.py, we maintain stack[u] as the (nonnegative) contribution to everything below u not accounted for elsewhere; it can increase but also decrease, since when the subtree of u changes we zero out stack[u] and add its value to all its children.
  • In test_relatedness_vector2.py, instead of zeroing out the stack and adding to the children, we subtract its value from any new children; see the sketch below. This creates the possibility for floating point error, because if we insert node c below node p then we update stack[c] -= stack[p], and then we rely on cancellation.
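
Schematically, the two update rules differ like this (a sketch using the names above, not the actual test code):

def flush_v1(u, stack, z, left_child, right_sib):
    # version 1: flush the edge above u, push the total down to u's
    # current children, and zero out stack[u]
    stack[u] += z(u)
    c = left_child[u]
    while c != -1:
        stack[c] += stack[u]
        c = right_sib[c]
    stack[u] = 0.0

def insert_v2(p, c, stack):
    # version 2: offset a newly inserted child against its parent's
    # running total; this -= is cancelled by a matching += when the edge
    # is later removed, which is where floating point error can creep in
    stack[c] -= stack[p]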

@nspope
Contributor

nspope commented Sep 11, 2024

@nspope thinks it's probably fine, based on similarity to other algorithms?

The "cancellation" strategy is similar to what is done for pair coalescence counts and for some of the tsdate internals, which seem to work well enough with large tree sequences; but we should probably do some systematic testing.

@hanbin973
Contributor

hanbin973 commented Sep 13, 2024

[timing plot]

Here are the timing results of the new algorithms (jitted).

@jeromekelleher
Member

This looks great! I'm all for the new cancellation-based algorithm. We do this sort of thing in loads of places, and it's fine, right?

I've not grokked the full details of how all this is working, but the shape of the algorithm looks just right. I think part of what's getting in my way is the "stack". In my mind a "stack" is the list of nodes that you traverse on the way back to root, or the list of functions on the way back to "main". Can we use a different name for this, if this is not a good analogy? It would really help me and maybe other people too.

@petrelharp
Contributor Author

I'll be working on this next week. Thanks for the timing plots! They look great, but the new one should be even better. Do we know how many mat-vecs we need in practice (order of magnitude)? Extrapolating from the plots above, maybe a 1e8 chromosome will take <= 5s with that many samples?

@hanbin973
Contributor

I'll be working on this next week. Thanks for the timing plots! They look great, but the new one should be even better. Do we know how many mat-vecs we need in practice (order of magnitude)? Extrapolating from the plots above, maybe a 1e8 chromosome will take <= 5s with that many samples?

It takes 4.4s on my machine.
[timing plot]

@petrelharp
Contributor Author

petrelharp commented Sep 16, 2024

I'm thinking this will get the basics in and then we can follow up with PRs to do:

  • allow windows
  • allow centring
  • allow mode="mutation": AFAICT, as with divergence_matrix, there is no algorithm of the same sort to calculate the site-mode value; rather, the natural corresponding thing is mutation mode ("mutation" mode for statistics #2982). And this should probably be the default mode, instead of "branch", to mirror other stats (at least, divergence_matrix does)?

@petrelharp
Contributor Author

I've just rebased this on top of #1623, so I can remove the centring from the testing.

@jeromekelleher
Member

Can you rebase this please @petrelharp? I'll try and do my bit tomorrow or Friday.

@petrelharp
Contributor Author

Done!

@jeromekelleher
Member

I got this more working than it was, @petrelharp. There are a few issues:

  • I hacked in the "no windows please" bit into the Python C interface. You were hitting a tsk_bug_assert before
  • There's some issue with the shape of the output, I'm not sure which is correct? I hacked in something quick to see if the tests would pass otherwise.
  • I think the "centre=True" tests are all failing. Again, I don't know which one is correct here, so I'll have to pass it back to you. There you go!

@benjeffery benjeffery added this to the Python 0.5.9 milestone Sep 23, 2024
@benjeffery
Member

I've added this work to the next release milestone. Hoping to get a release out in a week or two; if that is too ambitious for this, let me know.

@petrelharp
Contributor Author

I was a bit worried about how to efficiently deal with clearing out state between windows, but there's a nice way. Naively, we could start at each sample and traverse up the tree to the root to compute the state for that sample. That'd be O(N log N). However, if we clear the state of each node we pass through on the way, that lets us know "we've already been here and everywhere above this"; so we just have to traverse up until we find a node with x == position. This means we hit each edge once; so it's O(N).

But I think I'll do that in another PR?
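
Schematically, the clearing step would look like this (a sketch; flush is the push-down operation from the algorithm, and the names are illustrative):

def clear_state(samples, parent, x, position, flush):
    # Traverse up from each sample, flushing as we go; stop at the first
    # node already flushed at this position, since it and everything
    # above it were handled by an earlier traversal.
    for u in samples:
        n = u
        while n != -1 and x[n] != position:
            flush(n, position)  # updates stack[n] and sets x[n] = position
            n = parent[n]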

@jeromekelleher
Member

I'll fix up the windows issues here and push an update

Member

@jeromekelleher jeromekelleher left a comment

Very nice! I've made a few perf-oriented comments. Some are easy and should be done; others less so (making the _z function inline), so feel free to ignore that one. There are a few iffy restrict pointers declared, which should be removed as well.

Happy to merge and log issues for required follow-ups after that?

c/tskit/trees.c (resolved)
c/tskit/trees.c Outdated
}

static void
tsk_matvec_calculator_add_z(const tsk_matvec_calculator_t *self, tsk_id_t u)
Member

Could you restructure this so that any pointers needed are passed in as const restrict parameters? You can then mark it as a static inline void function. I think it could be a bottleneck otherwise.

Contributor Author

I've had a go at this, but I am more or less guessing what you mean - have a look?

c/tskit/trees.c Outdated
static void
tsk_matvec_calculator_remove_edge(tsk_matvec_calculator_t *self, tsk_id_t p, tsk_id_t c)
{
tsk_id_t *restrict parent = self->parent;
Member

Hmm - I'm iffy about the use of restrict here, as parent is accessed in functions that are called within the scope of the restrict declaration. I'd remove it to be on the safe side.

c/tskit/trees.c Outdated
static void
tsk_matvec_calculator_insert_edge(tsk_matvec_calculator_t *self, tsk_id_t p, tsk_id_t c)
{
tsk_id_t *restrict parent = self->parent;
Member

Similarly, not sure about that restrict.

@petrelharp petrelharp marked this pull request as ready for review September 24, 2024 16:29
@petrelharp
Contributor Author

I've made those changes - I think? If they look good, go ahead and merge!

@petrelharp
Contributor Author

I see two remaining lines not hit by tests; I'll get those in the follow-up PR with windows.

@petrelharp
Contributor Author

NEVER MIND, I've just put the code for windows in this PR - it changes the logic a bit (we don't need the virtual samples!) - but in a separate commit so you can have a separate look if you like.

@petrelharp
Contributor Author

Okay - this is ready!

Member

@jeromekelleher jeromekelleher left a comment

This is beautiful @petrelharp, such an elegant algorithm!

I spotted a few minor things; it just needs a tiny bit more testing on the CPython side. It's ready for a squash and merge then. (Multiple commits are fine, but I think it's probably good to squash out the intermediate "let's get this working" stuff.)

return 0;
}

static inline void
Member

Nice - that's exactly what I had in mind.

// do this: self->v[c] -= sign * self->v[p];
p_row = GET_2D_ROW(v, num_weights, p);
c_row = GET_2D_ROW(v, num_weights, c);
for (j = 0; j < num_weights; j++) {
Member

Just an observation here that if num_weights is O(100) these loops could be usefully vectorised using AVX instructions and so on.

}
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|sii", kwlist, &weights, &windows,
&mode, &span_normalise, &centre)) {
goto out;
Member

Looks like we're not doing bad-parameter testing in test_lowlevel.py. It doesn't have to be comprehensive, but we do need to make sure the error path is covered (it's a common and hard-to-track-down source of segfaults).

Contributor Author

Ah - I got confused about that.

python/_tskitmodule.c (resolved)
weights=np.ones(bad_weight_shape), mode=mode, **params
)

@pytest.mark.skip(reason="Haven't implemented windows.")
Member

This is implemented now, right?

Contributor Author

Hah I knew I wrote tests for those!

python/tests/test_lowlevel.py (outdated, resolved)
python/tests/test_lowlevel.py (outdated, resolved)
@petrelharp
Contributor Author

All done!

@jeromekelleher jeromekelleher added the AUTOMERGE-REQUESTED Ask Mergify to merge this PR label Sep 25, 2024
@jeromekelleher
Member

Looks like your error-checking tests are overly specific.

@jeromekelleher jeromekelleher removed the AUTOMERGE-REQUESTED Ask Mergify to merge this PR label Sep 25, 2024
@petrelharp
Contributor Author

Whoops, not quite.

@petrelharp
Contributor Author

Looks like error strings changed across python versions.

@petrelharp petrelharp merged commit e724f33 into tskit-dev:main Sep 25, 2024
19 checks passed