Skip to content

Commit

Permalink
Test _construct_log_Q_offdiag
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Dec 6, 2023
1 parent ac6b007 commit feb935f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
3 changes: 1 addition & 2 deletions src/pmhn/_trees/_backend_jax/_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,14 @@ def _construct_log_transtion_rate(

def _construct_log_exit_rate(
traj: Int[Array, " n_events"],
extended_omega: Float[Array, " n+1"],
extended_omega: Float[Array, " G+1"],
) -> Float:
return jnp.sum(extended_omega[traj - 1])


def _construct_log_Q_offdiag(
paths: DoublyIndexedPaths, extended_theta: Float[Array, "G+1 G+1"]
) -> Values:
# TODO(Pawel): UNTESTED
return Values(
start=paths.start,
end=paths.end,
Expand Down
42 changes: 42 additions & 0 deletions tests/trees/backend_jax/test_rates.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import jax
import jax.numpy as jnp
import numpy as np
import numpy.testing as npt
import pmhn._trees._backend_jax._private_api as api
import pmhn._trees._backend_jax._rates as rates
import pytest

Expand Down Expand Up @@ -67,6 +69,46 @@ def test_construct_log_exit_rate_3() -> None:
npt.assert_allclose(0.0, rates._construct_log_exit_rate(path, extended_omega))


# *** _construct_log_Q_offdiag ***


@pytest.mark.parametrize("n_genes", [4, 5, 21])
def test_construct_log_Q_offdiag(n_genes: int, seed: int = 42) -> None:
rng = np.random.default_rng(seed)

paths = api.DoublyIndexedPaths(
start=jnp.arange(7),
end=jnp.arange(8, 35, 2)[:7],
path=jnp.asarray(
[
[0, 0, rng.integers(1, n_genes)],
[0, 0, 2],
[0, 2, 1],
[0, 1, 1 + rng.integers(1, n_genes - 1)],
[0, 1, 3],
[1, 3, 2],
[0, 1, 3],
]
),
)
theta = jnp.asarray(rng.normal(size=(n_genes, n_genes)))
extended_theta = rates._extend_theta(theta)

offdiag = rates._construct_log_Q_offdiag(paths=paths, extended_theta=extended_theta)

npt.assert_allclose(offdiag.start, paths.start)
npt.assert_allclose(offdiag.end, paths.end)
expected = jnp.asarray(
[
rates._construct_log_transtion_rate(
traj=traj, extended_theta=extended_theta
)
for traj in paths.path
]
)
npt.assert_allclose(offdiag.value, expected)


# *** _construct_log_U ***


Expand Down

0 comments on commit feb935f

Please sign in to comment.