Skip to content

Commit

Permalink
Adapting some tests with new time_windows parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
tforest committed Oct 30, 2024
1 parent da5f205 commit 6b3ab4f
Showing 1 changed file with 166 additions and 8 deletions.
174 changes: 166 additions & 8 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2239,33 +2239,57 @@ def test_basic_example(self):
ts = self.get_example_tree_sequence()
n = ts.get_num_samples()
result = ts.allele_frequency_spectrum(
[n], ts.get_samples(), [0, ts.get_sequence_length()]
[n],
ts.get_samples(),
[0, ts.get_sequence_length()],
mode="branch",
time_windows=[0, np.inf],
)
assert result.shape == (1, n + 1)
assert result.shape == (1, 1, n + 1)
result = ts.allele_frequency_spectrum(
[n], ts.get_samples(), [0, ts.get_sequence_length()], polarised=True
[n],
ts.get_samples(),
[0, ts.get_sequence_length()],
mode="branch",
time_windows=[0, np.inf],
polarised=True,
)
assert result.shape == (1, n + 1)
assert result.shape == (1, 1, n + 1)

def test_output_dims(self):
ts = self.get_example_tree_sequence()
samples = ts.get_samples()
L = ts.get_sequence_length()
n = len(samples)
time_windows = [0, np.inf]

for mode in ["site", "branch"]:
for s in [[n], [n - 2, 2], [n - 4, 2, 2], [1] * n]:
s = np.array(s, dtype=np.uint32)
windows = [0, L]
for windows in [[0, L], [0, L / 2, L], np.linspace(0, L, num=10)]:
jafs = ts.allele_frequency_spectrum(
s, samples, windows, mode=mode, polarised=True
s,
samples,
windows,
mode=mode,
time_windows=time_windows,
polarised=True,
)
assert jafs.shape == tuple(
[len(windows) - 1] + [len(time_windows) - 1] + list(s + 1)
)
assert jafs.shape == tuple([len(windows) - 1] + list(s + 1))
jafs = ts.allele_frequency_spectrum(
s, samples, windows, mode=mode, polarised=False
s,
samples,
windows,
mode=mode,
time_windows=time_windows,
polarised=False,
)
assert jafs.shape == tuple(
[len(windows) - 1] + [len(time_windows) - 1] + list(s + 1)
)
assert jafs.shape == tuple([len(windows) - 1] + list(s + 1))

def test_node_mode_not_supported(self):
ts = self.get_example_tree_sequence()
Expand All @@ -2275,8 +2299,142 @@ def test_node_mode_not_supported(self):
ts.get_samples(),
[0, ts.get_sequence_length()],
mode="node",
time_windows=[0, np.inf],
)

def test_polarised(self):
"""
Temporary duplicate from class OneWaySampleStatsMixin
used to provide the time_windows argument.
"""
# TODO move this to the top level.
ts, method = self.get_method()
samples = ts.get_samples()
n = len(samples)
windows = [0, ts.get_sequence_length()]
method(
[n],
samples,
windows,
time_windows=[0, np.inf],
mode="branch",
polarised=True,
)
method(
[n],
samples,
windows,
time_windows=[0, np.inf],
mode="branch",
polarised=False,
)

def test_polarisation(self):
ts, f, params = self.get_example()
with pytest.raises(TypeError):
f(polarised="sdf", time_windows=[0, np.inf], mode="branch", **params)
x1 = f(polarised=False, time_windows=[0, np.inf], mode="branch", **params)
x2 = f(polarised=True, time_windows=[0, np.inf], mode="branch", **params)
# Basic check just to run both code paths
assert x1.shape == x2.shape

def test_mode_errors(self):
_, f, params = self.get_example()
for bad_mode in ["", "not a mode", "SITE", "x" * 8192]:
with pytest.raises(ValueError):
f(mode=bad_mode, time_windows=[0, np.inf], **params)

for bad_type in [123, {}, None, [[]]]:
with pytest.raises(TypeError):
f(mode=bad_type, time_windows=[0, np.inf], **params)

def test_window_errors(self):
ts, f, params = self.get_example()
del params["windows"]
for bad_array in ["asdf", None, [[[[]], [[]]]], np.zeros((10, 3, 4))]:
with pytest.raises(ValueError):
f(windows=bad_array, time_windows=[0, np.inf], mode="branch", **params)

for bad_windows in [[], [0]]:
with pytest.raises(ValueError):
f(
windows=bad_windows,
time_windows=[0, np.inf],
mode="branch",
**params,
)
L = ts.get_sequence_length()
bad_windows = [
[L, 0],
[0.1, L],
[-1, L],
[0, L + 0.1],
[0, 0.1, 0.1, L],
[0, -1, L],
[0, 0.1, 0.05, 0.2, L],
]
for bad_window in bad_windows:
with pytest.raises(_tskit.LibraryError):
f(windows=bad_window, time_windows=[0, np.inf], mode="branch", **params)

def test_windows_output(self):
ts, f, params = self.get_example()
del params["windows"]
for num_windows in range(1, 10):
windows = np.linspace(0, ts.get_sequence_length(), num=num_windows + 1)
assert windows.shape[0] == num_windows + 1
sigma = f(
windows=windows, time_windows=[0, np.inf], mode="branch", **params
)
assert sigma.shape[0] == num_windows

def test_bad_sample_sets(self):
ts, f, params = self.get_example()
del params["sample_set_sizes"]
del params["sample_sets"]

with pytest.raises(_tskit.LibraryError):
f(
sample_sets=[],
sample_set_sizes=[],
time_windows=[0, np.inf],
mode="branch",
**params,
)

n = ts.get_num_samples()
samples = ts.get_samples()
for bad_set_sizes in [[], [1], [n - 1], [n + 1], [n - 3, 1, 1], [1, n - 2]]:
with pytest.raises(ValueError):
f(
sample_set_sizes=bad_set_sizes,
sample_sets=samples,
time_windows=[0, np.inf],
mode="branch",
**params,
)

N = ts.get_num_nodes()
for bad_node in [-1, N, N + 1, -N]:
with pytest.raises(_tskit.LibraryError):
f(
sample_set_sizes=[2],
sample_sets=[0, bad_node],
time_windows=[0, np.inf],
mode="branch",
**params,
)

for bad_sample in [n, n + 1, N - 1]:
with pytest.raises(_tskit.LibraryError):
f(
sample_set_sizes=[2],
sample_sets=[0, bad_sample],
time_windows=[0, np.inf],
mode="branch",
**params,
)


class TwoWaySampleStatsMixin(SampleSetMixin):
"""
Expand Down

0 comments on commit 6b3ab4f

Please sign in to comment.