Skip to content

Commit

Permalink
perf(call): minor optimizations with lookups, SNV distance
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlougheed committed Sep 19, 2024
1 parent b694f19 commit 7ca6dee
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 27 deletions.
63 changes: 36 additions & 27 deletions strkit/call/call_locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .snvs import (
SNV_GAP_CHAR,
SNV_OUT_OF_RANGE_CHAR,
SNV_NA_CHARS,
call_and_filter_useful_snvs,
process_read_snvs_for_locus_and_calculate_useful_snvs,
)
Expand Down Expand Up @@ -185,45 +186,46 @@ def calculate_read_distance(
# Initialize a distance matrix for all reads
distance_matrix = np.zeros((n_reads, n_reads), dtype=np.float_)

@functools.cache
def _skip_set(idx: int) -> set:
    """Collect the useful-SNV indices of read ``idx`` that must be left out of the distance calculation.

    An index is skipped when the read's entry at that index carries the out-of-range character, or when it carries
    something other than the gap character together with a quality below ``snv_quality_threshold``.

    The result is memoised by ``functools.cache``, so every call with the same ``idx`` returns the same set object;
    callers should treat it as read-only.
    """
    snv_calls = read_dict_items[idx][1]["snvu"]
    skipped = set()
    for snv_idx in useful_snvs_range:
        base = snv_calls[snv_idx][0]
        if base == SNV_OUT_OF_RANGE_CHAR:
            # No value recorded for this read at this SNV
            skipped.add(snv_idx)
        elif base != SNV_GAP_CHAR and snv_calls[snv_idx][1] < snv_quality_threshold:
            # Non-gap entry whose quality is too low to incorporate in the distance metric
            skipped.add(snv_idx)
    return skipped

# Loop through and compare all vs. all reads. We can skip a few indices since the distance will be symmetrical.
for i in range(n_reads - 1):
r1 = read_dict_items[i][1]
r1_snv_u = r1["snvu"]

r1_skip: set[int] = set(
filter(
lambda y: r1_snv_u[y][0] == SNV_OUT_OF_RANGE_CHAR or
(r1_snv_u[y][0] != SNV_GAP_CHAR and r1_snv_u[y][1] < snv_quality_threshold),
useful_snvs_range
)
)
r1_skip: set[int] = _skip_set(i)

for j in range(i + 1, n_reads):
r2 = read_dict_items[j][1]
r2_snv_u = r2["snvu"]

n_not_equal: int = 0
d: float = 0.0
n_comparable: int = 0

r2_skip = _skip_set(j)

for z in useful_snvs_range:
if z in r1_skip:
continue

r2_b, r2_bq = r2_snv_u[z]
if r2_b == SNV_OUT_OF_RANGE_CHAR:
continue

if r2_b != SNV_GAP_CHAR and r2_bq < snv_quality_threshold:
# too low quality to incorporate in distance metric
if z in r2_skip:
continue

r1_b, r1_bq = r1_snv_u[z]
if r1_b != r2_b:
n_not_equal += 1
if r1_b != r2_snv_u[z][0]:
d += 1.0 # increase distance by 1 for each mismatched SNV

n_comparable += 1

d: float = float(n_not_equal)
if not pure_snv_peak_assignment: # Add in copy number distance
d += abs(r1["cn"] - r2["cn"]) * (
relative_cn_distance_weight_scaling_many if n_comparable >= many_snvs_quantity
Expand Down Expand Up @@ -540,7 +542,7 @@ def call_alleles_with_incorporated_snvs(
1 for _ in filter(
# If we were calling SNVs from scratch, we used to include the gap character. However, it seems to cause
# more issues than not - let's stick to real SNVs...
lambda s: s[0] not in (SNV_OUT_OF_RANGE_CHAR, SNV_GAP_CHAR) and s[1] >= snv_quality_threshold,
lambda s: s[0] not in SNV_NA_CHARS and s[1] >= snv_quality_threshold,
read_useful_snv_bases
)
)
Expand Down Expand Up @@ -776,6 +778,7 @@ def debug_log_flanking_seq(logger_: logging.Logger, locus_log_str: str, rn: str,
def _ndarray_serialize(x: Iterable) -> list[Union[int, np.int_]]:
return list(map(round, x))


def _nested_ndarray_serialize(x: Iterable) -> list[list[Union[int, np.int_]]]:
    """Run ``_ndarray_serialize`` over each inner iterable of ``x`` and return the resulting list of lists."""
    return [_ndarray_serialize(row) for row in x]

Expand Down Expand Up @@ -818,8 +821,10 @@ def call_locus(
flank_size = params.flank_size
realign = params.realign
respect_ref = params.respect_ref
targeted = params.targeted
min_read_align_score = params.min_read_align_score
snv_min_base_qual = params.snv_min_base_qual
use_hp = params.use_hp
# ----------------------------------

rng = np.random.default_rng(seed=seed)
Expand Down Expand Up @@ -983,6 +988,10 @@ def call_locus(

sorted_read_lengths = np.sort(read_lengths)

@functools.cache
def get_read_length_partition_mean(p_idx: int) -> float:
    """Return the mean of ``sorted_read_lengths`` from index ``p_idx`` onward, as a Python scalar.

    Memoised with ``functools.cache`` so that each distinct ``p_idx`` is only averaged once.
    """
    tail_lengths = sorted_read_lengths[p_idx:]
    return np.mean(tail_lengths).item()

# Find candidate SNVs, if we're using SNV data

candidate_snvs: Optional[CandidateSNVs] = None # Lookup dictionary for candidate SNVs by position
Expand Down Expand Up @@ -1257,28 +1266,28 @@ def call_locus(
)
exit(1)

mean_containing_size = read_len if params.targeted else np.mean(sorted_read_lengths[partition_idx:]).item()
mean_containing_size = read_len if targeted else get_read_length_partition_mean(partition_idx)
# TODO: re-examine weighting to possibly incorporate chance of drawing read large enough
read_weight = (mean_containing_size + tr_len_w_flank - 2) / (mean_containing_size - tr_len_w_flank + 1)

crs_cir = chimeric_read_status[rn] == 3 # Chimera within the TR region, indicating a potential large expansion
read_dict[rn] = {
read_dict[rn] = read_dict_entry = {
"s": "-" if segment.is_reverse else "+",
"cn": read_cn,
"w": read_weight,
**({"realn": realigned} if realign and realigned else {}),
**({"chimeric_in_region": crs_cir} if crs_cir else {}),
**({"kmers": dict(read_kmers)} if count_kmers != "none" else {}),
}
read_dict_extra[rn] = {
read_dict_extra[rn] = read_extra_entry = {
"_ref_start": segment_start,
"_ref_end": segment_end,
**({"_tr_seq": tr_read_seq} if consensus else {}),
}

# Reads can show up more than once - TODO - cache this information across loci

if params.use_hp:
if use_hp:
if (hp := segment.hp) is not None and (ps := segment.ps) is not None:
orig_ps = int(ps)

Expand All @@ -1295,17 +1304,17 @@ def call_locus(

phase_set_lock.release()

read_dict[rn]["hp"] = hp # not none inside this if-statement
read_dict[rn]["ps"] = ps_remapped
read_dict_entry["hp"] = hp # not none inside this if-statement
read_dict_entry["ps"] = ps_remapped
haplotags.add(hp)
haplotagged_reads_count += 1
phase_sets[ps_remapped] += 1

if should_incorporate_snvs:
# Store the segment sequence and qualities in the read dict for the next go-around if we've enabled SNV
# incorporation, in order to pass them to the get_read_snvs function with the cached ref string.
read_dict_extra[rn]["_qs"] = qs
read_dict_extra[rn]["_fqqs"] = fqqs
read_extra_entry["_qs"] = qs
read_extra_entry["_fqqs"] = fqqs

# Observed significant increase in annoying, probably false SNVs near the edges of significantly
# clipped reads in CCS data. Figure out if we have large clipping for later use here in the SNV finder.
Expand Down Expand Up @@ -1364,7 +1373,7 @@ def call_locus(

allele_start_time = time.perf_counter()

if params.use_hp:
if use_hp:
top_ps = phase_sets.most_common(1)
if (haplotagged_reads_count >= min_hp_read_coverage and len(haplotags) == n_alleles and top_ps and
top_ps[0][1] >= min_hp_read_coverage):
Expand Down
2 changes: 2 additions & 0 deletions strkit/call/snvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
__all__ = [
"SNV_OUT_OF_RANGE_CHAR",
"SNV_GAP_CHAR",
"SNV_NA_CHARS",
"get_read_snvs",
"call_and_filter_useful_snvs",
"process_read_snvs_for_locus_and_calculate_useful_snvs",
]

# Placeholder stored in place of a base when a read has no value at an SNV position because the position is out of
# range (presumably outside the span the read covers - TODO confirm against get_read_snvs).
SNV_OUT_OF_RANGE_CHAR = "-"
# Placeholder stored in place of a base when the read has a gap at an SNV position (presumably a deletion in the
# alignment - TODO confirm against get_read_snvs).
SNV_GAP_CHAR = "_"
# Both placeholder characters as one tuple, so callers can filter out non-base entries with a single `in` test.
SNV_NA_CHARS = (SNV_OUT_OF_RANGE_CHAR, SNV_GAP_CHAR)


def call_and_filter_useful_snvs(
Expand Down

0 comments on commit 7ca6dee

Please sign in to comment.