From 39174d26ed499ba396369754c2b29aecca3e561d Mon Sep 17 00:00:00 2001 From: peter Date: Tue, 24 Sep 2024 17:06:53 -0700 Subject: [PATCH] windows!!! closes #2996 --- c/tests/test_stats.c | 97 +++++++++++++++++----- c/tskit/trees.c | 102 ++++++++++++------------ python/_tskitmodule.c | 7 +- python/tests/test_relatedness_vector.py | 15 ++-- 4 files changed, 136 insertions(+), 85 deletions(-) diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index d1c42e5625..78adb4e0d9 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -2001,10 +2001,12 @@ test_empty_genetic_relatedness_vector(void) double *weights, *result; tsk_size_t j; tsk_size_t num_weights = 2; + double windows[] = { 0, 0 }; tsk_treeseq_from_text( &ts, 1, single_tree_ex_nodes, "", NULL, NULL, NULL, NULL, NULL, 0); num_samples = tsk_treeseq_get_num_samples(&ts); + windows[1] = tsk_treeseq_get_sequence_length(&ts); weights = tsk_malloc(num_weights * num_samples * sizeof(double)); result = tsk_malloc(num_weights * num_samples * sizeof(double)); for (j = 0; j < num_samples; j++) { @@ -2015,11 +2017,11 @@ test_empty_genetic_relatedness_vector(void) } ret = tsk_treeseq_genetic_relatedness_vector( - &ts, num_weights, weights, 0, NULL, result, 0); + &ts, num_weights, weights, 1, windows, result, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_treeseq_genetic_relatedness_vector( - &ts, num_weights, weights, 0, NULL, result, TSK_STAT_NONCENTRED); + &ts, num_weights, weights, 1, windows, result, TSK_STAT_NONCENTRED); CU_ASSERT_EQUAL_FATAL(ret, 0); tsk_treeseq_free(&ts); @@ -2028,47 +2030,67 @@ test_empty_genetic_relatedness_vector(void) } static void -test_paper_ex_genetic_relatedness_vector(void) +verify_genetic_relatedness_vector( + tsk_treeseq_t *ts, tsk_size_t num_weights, tsk_size_t num_windows) { int ret; - tsk_treeseq_t ts; tsk_size_t num_samples; double *weights, *result; - tsk_size_t j; - tsk_size_t num_weights = 2; - - tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, - paper_ex_mutations, paper_ex_individuals, NULL, 0); - num_samples = tsk_treeseq_get_num_samples(&ts); + tsk_size_t j, k; + double *windows = tsk_malloc((num_windows + 1) * sizeof(*windows)); + double L = tsk_treeseq_get_sequence_length(ts); - weights = tsk_malloc(num_weights * num_samples * sizeof(double)); - result = tsk_malloc(num_weights * num_samples * sizeof(double)); - for (j = 0; j < num_samples; j++) { - weights[j] = 1.0; + windows[0] = 0; + windows[num_windows] = L; + for (j = 1; j < num_windows; j++) { + windows[j] = ((double) j) * L / (double) num_windows; } + num_samples = tsk_treeseq_get_num_samples(ts); + + weights = tsk_malloc(num_weights * num_samples * sizeof(*weights)); + result = tsk_malloc(num_windows * num_weights * num_samples * sizeof(*result)); for (j = 0; j < num_samples; j++) { - weights[j + num_samples] = (float) j; + for (k = 0; k < num_weights; k++) { + weights[j + k * num_samples] = 1.0 + (double) k; + } } ret = tsk_treeseq_genetic_relatedness_vector( - &ts, num_weights, weights, 0, NULL, result, 0); + ts, num_weights, weights, num_windows, windows, result, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_treeseq_genetic_relatedness_vector( - &ts, num_weights, weights, 0, NULL, result, TSK_STAT_NONCENTRED); + ts, num_weights, weights, num_windows, windows, result, TSK_STAT_NONCENTRED); CU_ASSERT_EQUAL_FATAL(ret, 0); tsk_set_debug_stream(_devnull); ret = tsk_treeseq_genetic_relatedness_vector( - &ts, num_weights, weights, 0, NULL, result, TSK_DEBUG); + ts, num_weights, weights, num_windows, windows, result, TSK_DEBUG); CU_ASSERT_EQUAL_FATAL(ret, 0); tsk_set_debug_stream(stdout); - tsk_treeseq_free(&ts); + free(windows); free(weights); free(result); } +static void +test_paper_ex_genetic_relatedness_vector(void) +{ + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + + tsk_size_t j, k; + for (j = 1; j < 3; j++) { + for (k = 1; k < 3; k++) { + verify_genetic_relatedness_vector(&ts, j, k); + } + } + tsk_treeseq_free(&ts); +} + static void test_paper_ex_genetic_relatedness_vector_errors(void) { @@ -2078,6 +2100,7 @@ test_paper_ex_genetic_relatedness_vector_errors(void) double *weights, *result; tsk_size_t j; tsk_size_t num_weights = 2; + double windows[] = { 0, 0, 0 }; tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); @@ -2092,11 +2115,43 @@ test_paper_ex_genetic_relatedness_vector_errors(void) weights[j + num_samples] = (float) j; } + /* Window errors */ + ret = tsk_treeseq_genetic_relatedness_vector( + &ts, 1, weights, 0, windows, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS); + ret = tsk_treeseq_genetic_relatedness_vector( + &ts, 1, weights, 0, NULL, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS); + + ret = tsk_treeseq_genetic_relatedness_vector( + &ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + windows[0] = -1; + ret = tsk_treeseq_genetic_relatedness_vector( + &ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + windows[0] = 10; + ret = tsk_treeseq_genetic_relatedness_vector( + &ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + windows[0] = 0; + windows[2] = 12; + ret = tsk_treeseq_genetic_relatedness_vector( + &ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + /* unsupported mode errors */ + windows[0] = 0.0; + windows[1] = 5.0; + windows[2] = 10.0; ret = tsk_treeseq_genetic_relatedness_vector( - &ts, num_weights, weights, 0, NULL, result, TSK_STAT_SITE); + &ts, num_weights, weights, 2, windows, result, TSK_STAT_SITE); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); ret = tsk_treeseq_genetic_relatedness_vector( - &ts, num_weights, weights, 0, NULL, result, TSK_STAT_NODE); + &ts, num_weights, weights, 2, windows, result, TSK_STAT_NODE); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); tsk_treeseq_free(&ts); diff --git a/c/tskit/trees.c b/c/tskit/trees.c index be2ca5c678..bc1db29c7b 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -9905,11 +9905,12 @@ typedef struct { const tsk_treeseq_t *ts; tsk_size_t num_weights; const double *weights; + tsk_size_t num_windows; + const double *windows; tsk_flags_t options; double *result; /* tree */ double tree_left; - tsk_id_t virtual_root; tsk_size_t num_nodes; tsk_id_t *parent; double *x; @@ -9920,7 +9921,7 @@ typedef struct { static void tsk_matvec_calculator_print_state(const tsk_matvec_calculator_t *self, FILE *out) { - tsk_id_t j, u; + tsk_id_t j; tsk_size_t num_samples = tsk_treeseq_get_num_samples(self->ts); fprintf(out, "Matvec state:\n"); @@ -9932,14 +9933,7 @@ tsk_matvec_calculator_print_state(const tsk_matvec_calculator_t *self, FILE *out fprintf(out, "\n"); for (j = 0; j < (tsk_id_t) self->num_nodes; j++) { - if (j < self->virtual_root) { - fprintf(out, "%lld\t", (long long) j); - } else if (j == self->virtual_root) { - fprintf(out, "VR:%lld\t", (long long) j); - } else { - u = self->ts->samples[j - self->virtual_root - 1]; - fprintf(out, "%lld(%lld)\t", (long long) j, (long long) u); - } + fprintf(out, "%lld\t", (long long) j); fprintf(out, "%lld\t%g\t%g\t%g\n", (long long) self->parent[j], self->x[j], self->v[j], self->w[j]); } @@ -9947,25 +9941,27 @@ tsk_matvec_calculator_print_state(const tsk_matvec_calculator_t *self, FILE *out static int tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *ts, - tsk_size_t num_weights, const double *weights, tsk_flags_t options, double *result) + tsk_size_t num_weights, const double *weights, tsk_size_t num_windows, + const double *windows, tsk_flags_t options, double *result) { int ret = 0; tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts); - const tsk_size_t num_nodes = ts->tables->nodes.num_rows + num_samples + 1; + const tsk_size_t num_nodes = ts->tables->nodes.num_rows; const double *row; double *new_row; tsk_size_t k; - tsk_id_t u, v, j; + tsk_id_t u, j; double *weight_means = tsk_malloc(num_weights * sizeof(*weight_means)); self->ts = ts; self->tree_left = 0.0; self->num_weights = num_weights; self->weights = weights; + self->num_windows = num_windows; + self->windows = windows; self->options = options; self->result = result; self->num_nodes = num_nodes; - self->virtual_root = (tsk_id_t) ts->tables->nodes.num_rows; self->parent = tsk_malloc(num_nodes * sizeof(*self->parent)); self->x = tsk_calloc(num_nodes, sizeof(*self->x)); @@ -9978,7 +9974,7 @@ tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *t goto out; } - tsk_memset(result, 0, num_samples * num_weights * sizeof(*result)); + tsk_memset(result, 0, num_windows * num_samples * num_weights * sizeof(*result)); tsk_memset(self->parent, TSK_NULL, num_nodes * sizeof(*self->parent)); for (k = 0; k < num_weights; k++) { @@ -10003,9 +9999,6 @@ tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *t for (k = 0; k < num_weights; k++) { new_row[k] = row[k] - weight_means[k]; } - // add branch to the virtual sample - v = self->virtual_root + 1 + j; - self->parent[v] = u; } out: tsk_safe_free(weight_means); @@ -10063,10 +10056,8 @@ tsk_matvec_calculator_adjust_path_up( // sign = -1 for removing edges, +1 for adding while (p != TSK_NULL) { - if (p < self->virtual_root) { - tsk_matvec_calculator_add_z( - p, parent[p], tree_left, x, num_weights, w, v, nodes_time); - } + tsk_matvec_calculator_add_z( + p, parent[p], tree_left, x, num_weights, w, v, nodes_time); // do this: self->v[c] -= sign * self->v[p]; p_row = GET_2D_ROW(v, num_weights, p); c_row = GET_2D_ROW(v, num_weights, c); @@ -10094,10 +10085,8 @@ tsk_matvec_calculator_remove_edge(tsk_matvec_calculator_t *self, tsk_id_t p, tsk double *restrict v = self->v; const double *restrict nodes_time = self->ts->tables->nodes.time; - if (c < self->virtual_root) { - tsk_matvec_calculator_add_z( - c, parent[c], tree_left, x, num_weights, w, v, nodes_time); - } + tsk_matvec_calculator_add_z( + c, parent[c], tree_left, x, num_weights, w, v, nodes_time); parent[c] = TSK_NULL; tsk_matvec_calculator_adjust_path_up(self, p, c, -1); } @@ -10113,15 +10102,16 @@ tsk_matvec_calculator_insert_edge(tsk_matvec_calculator_t *self, tsk_id_t p, tsk } static int -tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self) +tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self, double *restrict y) { int ret = 0; - tsk_id_t u, v; + tsk_id_t u; tsk_size_t j, k; tsk_size_t n = tsk_treeseq_get_num_samples(self->ts); - double *restrict y = self->result; - double *v_row, *out_row; + double *u_row, *out_row; double *out_means = tsk_malloc(self->num_weights * sizeof(*out_means)); + const tsk_id_t *restrict parent = self->parent; + const double *restrict nodes_time = self->ts->tables->nodes.time; if (out_means == NULL) { ret = TSK_ERR_NO_MEMORY; @@ -10129,17 +10119,18 @@ tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self) } for (j = 0; j < n; j++) { - u = self->ts->samples[j]; - v = self->virtual_root + 1 + (tsk_id_t) j; - tsk_bug_assert(u == self->parent[v]); - tsk_matvec_calculator_remove_edge(self, u, v); - } - for (j = 0; j < n; j++) { - v = self->virtual_root + 1 + (tsk_id_t) j; - v_row = GET_2D_ROW(self->v, self->num_weights, v); out_row = GET_2D_ROW(y, self->num_weights, j); - for (k = 0; k < self->num_weights; k++) { - out_row[k] = v_row[k]; + u = self->ts->samples[j]; + while (u != TSK_NULL) { + if (self->x[u] != self->tree_left) { + tsk_matvec_calculator_add_z(u, parent[u], self->tree_left, self->x, + self->num_weights, self->w, self->v, nodes_time); + } + u_row = GET_2D_ROW(self->v, self->num_weights, u); + for (k = 0; k < self->num_weights; k++) { + out_row[k] += u_row[k]; + } + u = parent[u]; } } @@ -10163,6 +10154,8 @@ tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self) } } } + /* zero out v */ + tsk_memset(self->v, 0, self->num_nodes * self->num_weights * sizeof(*self->v)); out: tsk_safe_free(out_means); return ret; @@ -10172,8 +10165,9 @@ static int tsk_matvec_calculator_run(tsk_matvec_calculator_t *self) { int ret = 0; - tsk_size_t j, k; + tsk_size_t j, k, m; tsk_id_t e, p, c; + tsk_size_t n = tsk_treeseq_get_num_samples(self->ts); double tree_right; const double sequence_length = self->ts->tables->sequence_length; const tsk_size_t num_edges = self->ts->tables->edges.num_rows; @@ -10183,12 +10177,15 @@ tsk_matvec_calculator_run(tsk_matvec_calculator_t *self) const double *restrict edge_left = self->ts->tables->edges.left; const tsk_id_t *restrict edge_child = self->ts->tables->edges.child; const tsk_id_t *restrict edge_parent = self->ts->tables->edges.parent; + double *restrict out; - tree_right = sequence_length; j = 0; k = 0; + m = 0; + tree_right = sequence_length; - while (k < num_edges || self->tree_left < sequence_length) { + while ( + m < self->num_windows && k < num_edges && self->tree_left <= sequence_length) { while (k < num_edges && edge_right[O[k]] == self->tree_left) { e = O[k]; p = edge_parent[e]; @@ -10204,7 +10201,7 @@ tsk_matvec_calculator_run(tsk_matvec_calculator_t *self) self->x[c] = self->tree_left; j++; } - tree_right = sequence_length; + tree_right = self->windows[m + 1]; if (j < num_edges) { tree_right = TSK_MIN(tree_right, edge_left[I[j]]); } @@ -10212,11 +10209,15 @@ tsk_matvec_calculator_run(tsk_matvec_calculator_t *self) tree_right = TSK_MIN(tree_right, edge_right[O[k]]); } self->tree_left = tree_right; + if (self->tree_left == self->windows[m + 1]) { + out = GET_2D_ROW(self->result, self->num_weights * n, m); + tsk_matvec_calculator_write_output(self, out); + m += 1; + } if (self->options & TSK_DEBUG) { tsk_matvec_calculator_print_state(self, tsk_get_debug_stream()); } } - ret = tsk_matvec_calculator_write_output(self); /* out: */ return ret; @@ -10232,18 +10233,19 @@ tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self, tsk_size_t num bool stat_node = !!(options & TSK_STAT_NODE); tsk_matvec_calculator_t calc; - // TODO add windows - tsk_bug_assert(num_windows == 0); - tsk_bug_assert(windows == NULL); - memset(&calc, 0, sizeof(calc)); if (stat_node || stat_site) { ret = TSK_ERR_UNSUPPORTED_STAT_MODE; goto out; } + ret = tsk_treeseq_check_windows(self, num_windows, windows, TSK_REQUIRE_FULL_SPAN); + if (ret != 0) { + goto out; + } - ret = tsk_matvec_calculator_init(&calc, self, num_weights, weights, options, result); + ret = tsk_matvec_calculator_init( + &calc, self, num_weights, weights, num_windows, windows, options, result); if (ret != 0) { goto out; } diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 176f5ff6b3..6d275a499a 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -9706,12 +9706,7 @@ TreeSequence_weighted_stat_vector_method( goto out; } err = method(self->tree_sequence, w_shape[1], PyArray_DATA(weights_array), - /* num_windows, PyArray_DATA(windows_array), */ - /* FIXME! The C code has a bug_assert on windows being NULL, and - * it looks like you're not testing anything with windows so far - * anyway, so there's no point in these parameters? - */ - 0, NULL, PyArray_DATA(result_array), options); + num_windows, PyArray_DATA(windows_array), PyArray_DATA(result_array), options); if (err != 0) { handle_library_error(err); goto out; diff --git a/python/tests/test_relatedness_vector.py b/python/tests/test_relatedness_vector.py index 6eaecd881a..9bd8fc48f9 100644 --- a/python/tests/test_relatedness_vector.py +++ b/python/tests/test_relatedness_vector.py @@ -186,7 +186,7 @@ def mrca(self, a, b): b = self.parent[b] return b - def output_state(self): + def write_output(self): """ Compute and return the current state, zero-ing out all contributions (used for switching between windows). @@ -258,7 +258,7 @@ def run(self): right = min(right, edges_right[out_order[k]]) self.position = right if self.position == self.windows[m + 1]: - out[m] = self.output_state() + out[m] = self.write_output() m = m + 1 if self.verbosity > 1: @@ -334,22 +334,21 @@ def verify_relatedness_vector( R2 = np.zeros((len(windows) - 1, ts.num_samples, wvec.shape[1])) for k in range(len(windows) - 1): R2[k] = Sigma[k].dot(wvec) - # R3 = ts.genetic_relatedness_vector(w, windows=windows, mode="branch", - # centre=centre) + R3 = ts.genetic_relatedness_vector(w, windows=windows, mode="branch", centre=centre) if verbosity > 0: print(ts.draw_text()) print("weights:", w) print("windows:", windows) print("here:", R1) print("with ts:", R2) - # print("with lib:", R3) + print("with lib:", R3) print("Sigma:", Sigma) if windows is None: assert R1.shape == (ts.num_samples, wvec.shape[1]) else: assert R1.shape == (len(windows) - 1, ts.num_samples, wvec.shape[1]) np.testing.assert_allclose(R1, R2, atol=1e-14) - # np.testing.assert_allclose(R1, R3, atol=1e-14) + np.testing.assert_allclose(R1, R3, atol=1e-14) return R1 @@ -397,7 +396,7 @@ def test_small_internal_checks(self, n, seed, centre, num_windows): @pytest.mark.parametrize("n", [2, 3, 5, 15]) @pytest.mark.parametrize("seed", range(1, 5)) @pytest.mark.parametrize("centre", (True, False)) - @pytest.mark.parametrize("num_windows", (0, 1, 2)) + @pytest.mark.parametrize("num_windows", (0, 1, 3)) def test_simple_sims(self, n, seed, centre, num_windows): ts = msprime.sim_ancestry( n, @@ -409,7 +408,7 @@ def test_simple_sims(self, n, seed, centre, num_windows): ) assert ts.num_trees >= 2 check_relatedness_vector( - ts, num_windows=num_windows, centre=centre, verbosity=0 + ts, num_windows=num_windows, centre=centre, verbosity=1 ) @pytest.mark.parametrize("n", [2, 3, 5, 15])