Skip to content

Commit

Permalink
added value error and test for permute_seq_by_k
Browse files Browse the repository at this point in the history
  • Loading branch information
PSmaruj committed Jul 30, 2024
1 parent 90320ad commit 3b2ecca
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
4 changes: 3 additions & 1 deletion akita_utils/dna_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ def permute_seq_k(seq_1hot, k=2):
seq_length = len(seq_1hot)
if seq_length % k != 0:
raise ValueError("Sequence length must be divisible by k")

if seq_length < k:
raise ValueError("Sequence length must be greater than k")

seq_1hot_perm = np.zeros_like(seq_1hot)

num_permutations = seq_length // k
Expand Down
46 changes: 45 additions & 1 deletion tests/test_dna_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from akita_utils.dna_utils import dna_1hot, dna_1hot_to_seq, dna_seq_rc
import numpy as np
from akita_utils.dna_utils import dna_1hot, dna_1hot_to_seq, dna_seq_rc, test_permute_seq_k


def test_dna_1hot_to_seq():
Expand All @@ -9,3 +10,46 @@ def test_dna_seq_rc():
seq = "ACTG"
rc_seq = "CAGT"
assert dna_seq_rc(seq) == rc_seq


def test_permute_seq_k():
# Test 1: Basic functionality
seq_1hot = np.array([[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
[1, 0, 0, 0],
[0, 1, 0, 0]])
k = 2
result = permute_seq_k(seq_1hot, k)
assert result.shape == seq_1hot.shape, "Shape mismatch"
assert not np.array_equal(result, seq_1hot), "Permutation did not change sequence"

# Check that the sequence is a permutation of k-mers
permuted_k_mers = [result[i:i + k] for i in range(0, len(result), k)]
original_k_mers = [seq_1hot[i:i + k] for i in range(0, len(seq_1hot), k)]
assert sorted(map(tuple, original_k_mers)) == sorted(map(tuple, permuted_k_mers)), "Permuted k-mers are not as expected"

# Test 2: Edge case where sequence length is not divisible by k
try:
permute_seq_k(seq_1hot, k=3)
except ValueError as e:
assert str(e) == "Sequence length must be divisible by k", "Incorrect error message for length not divisible by k"

# Test 3: Edge case where k > sequence length
short_seq_1hot = np.array([[1, 0, 0, 0],
[0, 1, 0, 0]])
k_large = 3
try:
permute_seq_k(short_seq_1hot, k=k_large)
except ValueError as e:
assert str(e) == "Sequence length must be divisible by k", "Incorrect error message for k > sequence length"

# Test 4: Randomness check (manual validation needed)
np.random.seed(0) # For reproducibility
result1 = permute_seq_k(seq_1hot, k)
np.random.seed(0) # Reset seed to verify same permutation
result2 = permute_seq_k(seq_1hot, k)
assert np.array_equal(result1, result2), "Results should be consistent with the same seed"

print("All tests passed!")

0 comments on commit 3b2ecca

Please sign in to comment.