
Remove general_stat caching #1937

Closed · wants to merge 1 commit

Conversation

@brieuclehmann (Contributor) commented Nov 24, 2021

Here's a first pass at removing caching to reduce memory consumption in the general_stat framework. I've only tweaked the branch C function, and the Python tests are passing.

Without caching, however, ts.genetic_relatedness is noticeably slower. For the script below, the call takes about 10s with caching and about 20s without. When I increased n_samples to 1000, it took 58s with caching and 110s without.

I may have introduced some unnecessary computation, so I'd appreciate a sense-check!

import itertools
import time

import msprime
import numpy as np
import tskit
from memory_profiler import profile

# Line-by-line memory profiles are appended to this log file.
fp = open("memprof_cache.log", "a+")

@profile(stream=fp)
def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    return x

seed = 42
n_samples = 500
ts = msprime.simulate(
    n_samples,
    length=1e6,
    recombination_rate=1e-8,
    Ne=1e4,
    mutation_rate=1e-7,
    random_seed=seed,
)

# Pair consecutive samples into diploid individuals.
n_ind = int(ts.num_samples / 2)
sample_sets = [(2 * i, (2 * i) + 1) for i in range(n_ind)]

# All pairs of individuals, including self-pairs (the upper triangle).
n = len(sample_sets)
indexes = [
    (n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2)
]

start_time = time.time()
x = genetic_relatedness_matrix(ts, sample_sets, indexes, 'branch')
end_time = time.time()
print(end_time - start_time)

@codecov (bot) commented Nov 24, 2021

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 79.57%. Comparing base (315e47a) to head (0b96818).
Report is 562 commits behind head on main.

❗ There is a different number of reports uploaded between BASE (315e47a) and HEAD (0b96818):

Flag           BASE (315e47a)   HEAD (0b96818)
python-tests   2                0
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1937       +/-   ##
===========================================
- Coverage   93.15%   79.57%   -13.59%     
===========================================
  Files          27       27               
  Lines       25085    24959      -126     
  Branches     1109     1107        -2     
===========================================
- Hits        23369    19862     -3507     
- Misses       1682     5036     +3354     
- Partials       34       61       +27     
Flag             Coverage            Δ
c-tests          92.20% <100.00%>    (+<0.01%) ⬆️
lwt-tests        89.14% <ø>          (ø)
python-c-tests   68.29% <ø>          (ø)
python-tests     ?

Flags with carried forward coverage won't be shown.

Files with missing lines   Coverage            Δ
c/tskit/trees.c            94.19% <100.00%>    (+0.01%) ⬆️

... and 12 files with indirect coverage changes

@molpopgen (Member)

Any idea how much this drops RAM use? For me, the case that motivated opening the issue was computing the pairwise distance matrix: it blew up to 7+ GB for thousands of nodes.

@jeromekelleher (Member) commented Nov 24, 2021

We had a discussion about this earlier, and @brieuclehmann is going to run it through memory-profiler. Assuming we're on the right track memory-wise, the plan is to add an option to general_stat that disables caching like this.
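
For illustration, a hypothetical sketch of what such an option might look like at the Python level; the cache keyword is invented here for discussion and is not part of the general_stat API:

import msprime
import numpy as np

ts = msprime.simulate(10, random_seed=1)
W = np.ones((ts.num_samples, 1))  # one weight column per sample

# Current API: the per-node summary caching strategy is fixed.
sigma = ts.general_stat(W, lambda x: x, 1, mode="branch", strict=False)

# The proposal would add something like a cache flag (name hypothetical)
# to trade extra recomputation for lower memory:
# sigma = ts.general_stat(W, lambda x: x, 1, mode="branch", strict=False,
#                         cache=False)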

@brieuclehmann (Contributor, Author) commented Nov 25, 2021

I've used memory-profiler to check RAM usage (see the updated script above). Oddly, there doesn't appear to be much of a difference with and without caching, but I'm not 100% sure I'm reading the output correctly (or that I've set up the profiling properly).

memprof_cache.log
memprof_nocache.log

@molpopgen (Member)

Is memory-profiler measuring actual process memory, or just memory occupied by Python objects? The caching is happening on the C side, and so may not be visible. If you are on Linux, then this may be more relevant:

/usr/bin/time -f "%e %M" python script.py

The second number in the output will be the peak RAM use in KB.
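
As a cross-check from inside the script itself, the standard-library resource module reports the same peak figure; a minimal sketch, assuming Linux semantics:

import resource

# ru_maxrss is the process's peak resident set size; Linux reports it
# in kilobytes (macOS reports bytes instead).
peak_kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
print(f"peak RSS: {peak_kb / 1024:.1f} MiB")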

@jeromekelleher (Member)

> Is memory-profiler measuring actual process memory, or just memory occupied by Python objects?

Yes, it takes periodic snapshots using OS routines, so it's the "real" process footprint.

@jeromekelleher (Member) commented Nov 26, 2021

Pasting in the profiles for ease:

Filename: test_time.py
  
Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9     65.5 MiB     65.5 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, mode):
    11     65.5 MiB      0.0 MiB           1       n = len(sample_sets)
    12     67.6 MiB    -46.8 MiB       31379       indexes = [
    13     67.6 MiB    -44.6 MiB       31376           (n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2)
    14                                             ]
    15     67.6 MiB      0.0 MiB           1       K = np.zeros((n, n))
    16     74.9 MiB      7.3 MiB           2       K[np.triu_indices(n)] = ts.genetic_relatedness(
    17     67.6 MiB      0.0 MiB           1           sample_sets, indexes, mode=mode, proportion=False, span_normalise=False
    18                                             )
    19     74.9 MiB      0.0 MiB           1       K = K + np.triu(K, 1).transpose()
    20     74.9 MiB      0.0 MiB           1       return K

Filename: test_time.py
  
Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9     64.6 MiB     64.6 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, mode):
    11     64.6 MiB      0.0 MiB           1       n = len(sample_sets)
    12     66.7 MiB    -69.3 MiB       31379       indexes = [
    13     66.7 MiB    -67.1 MiB       31376           (n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2)
    14                                             ]
    15     66.7 MiB      0.0 MiB           1       K = np.zeros((n, n))
    16     66.7 MiB    -18.0 MiB           2       K[np.triu_indices(n)] = ts.genetic_relatedness(
    17     66.7 MiB      0.0 MiB           1           sample_sets, indexes, mode=mode, proportion=False, span_normalise=False
    18                                             )
    19     48.8 MiB    -17.9 MiB           1       K = K + np.triu(K, 1).transpose()
    20     48.8 MiB      0.0 MiB           1       return K

@jeromekelleher (Member)

There isn't a huge difference either way here - maybe run this on a larger example? Also, do just one thing on the line where we call genetic_relatedness; the assignment stuff on the LHS will complicate matters (so x = ts.genetic_related...).
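
That is, the profiled line reduces to a bare call, as in the updated script above:

x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)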

@brieuclehmann (Contributor, Author)

I simplified the code slightly (see edits above) so that it only profiles the ts.genetic_relatedness call. Here are the results with n_samples = 500, across three runs of the same script. Overall usage is pretty similar between the two (though I'm not sure what's going on in line 11 of each?). I'm going to run again now with n_samples = 2000.
With caching:

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9     67.3 MiB     67.3 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     58.3 MiB     -9.0 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     58.3 MiB      0.0 MiB           1       return x

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9     67.3 MiB     67.3 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     74.5 MiB      7.2 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     74.5 MiB      0.0 MiB           1       return x

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9     67.3 MiB     67.3 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     21.0 MiB    -46.3 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     21.0 MiB      0.0 MiB           1       return x

Without caching:

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9     68.6 MiB     68.6 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     56.6 MiB    -12.0 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     56.7 MiB      0.0 MiB           1       return x

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9     68.1 MiB     68.1 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     52.1 MiB    -16.0 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     52.1 MiB      0.0 MiB           1       return x

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9     68.3 MiB     68.3 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     75.4 MiB      7.1 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     75.4 MiB      0.0 MiB           1       return x

@petrelharp (Contributor)

Hm - notes:

  • update_node_summary and update_running_sum can be combined into one (I think), eliminating some extra copying around of results
  • it's a bit hard to know how well it's possible to do, since for mode="branch" we don't have anything convenient to compare it to (although we could use Caoqi's eGRM, I guess?) - for mode="site" we could compare to the timings that Gregor reported
  • have you computed how much less memory we expect it to be using? Like, what's the difference in memory usage for the two calloc()'s with num_samples = 1000 and num_nodes equal to whatever it is? (see the back-of-envelope sketch below)
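
As a back-of-envelope aid for that last question, a small calculator. The buffer shapes are assumptions chosen for illustration (a cached per-node summary array of num_nodes × result_dim doubles versus a single running sum of result_dim doubles), not read from the C source, and num_nodes is a placeholder:

def array_mib(*shape):
    """Size in MiB of a C array of doubles with the given shape."""
    n = 1
    for s in shape:
        n *= s
    return n * 8 / 2**20

num_nodes = 4000  # placeholder value; depends on the simulation
n_sets = 500      # num_samples = 1000 paired into 500 diploid individuals
result_dim = n_sets * (n_sets + 1) // 2  # upper-triangle index pairs: 125250

print(array_mib(num_nodes, result_dim))  # ~3800 MiB if summaries are cached per node
print(array_mib(result_dim))             # ~1 MiB for a single running sum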

@brieuclehmann (Contributor, Author)

Thanks @petrelharp! I agree with your point about combining update_node_summary and update_running_sum. Is it worth doing this now, or should I wait until we're sure we want to include a no-caching option?

For the expected change in memory usage, I guess it should be a num_nodes-fold decrease between the two callocs, so I'm pretty surprised that they are so similar. Is there a way to check where in the C code memory is being used? For context, when n_samples = 2000 in the above examples, we have num_nodes = 6068.

@brieuclehmann (Contributor, Author)

And here are the memory-profiler results for n_samples = 2000; TL;DR, largely the same again.

With caching:

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9    111.1 MiB    111.1 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     31.0 MiB    -80.1 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     31.1 MiB      0.1 MiB           1       return x

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9    110.9 MiB    110.9 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     33.9 MiB    -77.0 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     34.0 MiB      0.1 MiB           1       return x

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9    111.3 MiB    111.3 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     27.8 MiB    -83.4 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     27.9 MiB      0.1 MiB           1       return x

Without caching:

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9    109.9 MiB    109.9 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     56.3 MiB    -53.6 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     56.4 MiB      0.1 MiB           1       return x

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9    110.0 MiB    110.0 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     54.0 MiB    -56.0 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     54.1 MiB      0.1 MiB           1       return x

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9    110.8 MiB    110.8 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     31.6 MiB    -79.2 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     31.7 MiB      0.1 MiB           1       return x

@jeromekelleher (Member)

> For context, when n_samples = 2000 in the above examples, we have num_nodes = 6068.

Ah, I see what's happening here - your tree sequence is quite short, so there aren't that many nodes. Try increasing the sequence length or recombination rate so that you have at least 100K trees. Then you should see some difference.
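
For instance, a minimal tweak to the simulation in the script above (this matches the recombination rate used in the follow-up below):

# Ten-fold higher recombination rate -> many more trees and nodes.
ts = msprime.simulate(
    n_samples,
    length=1e6,
    recombination_rate=1e-7,  # was 1e-8
    Ne=1e4,
    mutation_rate=1e-7,
    random_seed=seed,
)
print(ts.num_trees, ts.num_nodes)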

@brieuclehmann (Contributor, Author)

Ah OK, here are the results with n_samples = 500 and recombination_rate = 1e-7, corresponding to ~250K trees and ~150K nodes. We now see a slight improvement without caching.

With caching:

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9    215.7 MiB    215.7 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     26.6 MiB   -189.1 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     26.7 MiB      0.1 MiB           1       return x

Without caching:

Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
     9    197.3 MiB    197.3 MiB           1   @profile(stream=fp)
    10                                         def genetic_relatedness_matrix(ts, sample_sets, indexes, mode):
    11     36.3 MiB   -161.0 MiB           1       x = ts.genetic_relatedness(sample_sets, indexes, mode=mode, proportion=False, span_normalise=False)
    12     36.4 MiB      0.1 MiB           1       return x

@jeromekelleher (Member)

Still surprisingly little...

@petrelharp (Contributor)

ping @jeromekelleher

@jeromekelleher (Member)

Ah - I don't think this is computing the right value @brieuclehmann. Tests are failing because we're getting different numbers.

I'll think about how to do this again.

@benjeffery (Member)

Closing this for inactivity - reopen if needed!

@benjeffery closed this Sep 23, 2024
@petrelharp (Contributor)

Nope - this is superseded by #2980.
