diff --git a/strkit/call/call_locus.py b/strkit/call/call_locus.py index f04b3b9..93b9dde 100644 --- a/strkit/call/call_locus.py +++ b/strkit/call/call_locus.py @@ -37,6 +37,7 @@ from .snvs import ( SNV_GAP_CHAR, SNV_OUT_OF_RANGE_CHAR, + SNV_NA_CHARS, call_and_filter_useful_snvs, process_read_snvs_for_locus_and_calculate_useful_snvs, ) @@ -185,45 +186,46 @@ def calculate_read_distance( # Initialize a distance matrix for all reads distance_matrix = np.zeros((n_reads, n_reads), dtype=np.float_) + @functools.cache + def _skip_set(idx: int) -> set: + r_snv_u = read_dict_items[idx][1]["snvu"] + return set( + filter( + lambda y: r_snv_u[y][0] == SNV_OUT_OF_RANGE_CHAR or + (r_snv_u[y][0] != SNV_GAP_CHAR and r_snv_u[y][1] < snv_quality_threshold), + useful_snvs_range + ) + ) + # Loop through and compare all vs. all reads. We can skip a few indices since the distance will be symmetrical. for i in range(n_reads - 1): r1 = read_dict_items[i][1] r1_snv_u = r1["snvu"] - r1_skip: set[int] = set( - filter( - lambda y: r1_snv_u[y][0] == SNV_OUT_OF_RANGE_CHAR or - (r1_snv_u[y][0] != SNV_GAP_CHAR and r1_snv_u[y][1] < snv_quality_threshold), - useful_snvs_range - ) - ) + r1_skip: set[int] = _skip_set(i) for j in range(i + 1, n_reads): r2 = read_dict_items[j][1] r2_snv_u = r2["snvu"] - n_not_equal: int = 0 + d: float = 0.0 n_comparable: int = 0 + r2_skip = _skip_set(j) + for z in useful_snvs_range: if z in r1_skip: continue - r2_b, r2_bq = r2_snv_u[z] - if r2_b == SNV_OUT_OF_RANGE_CHAR: - continue - - if r2_b != SNV_GAP_CHAR and r2_bq < snv_quality_threshold: - # too low quality to incorporate in distance metric + if z in r2_skip: continue r1_b, r1_bq = r1_snv_u[z] - if r1_b != r2_b: - n_not_equal += 1 + if r1_b != r2_snv_u[z][0]: + d += 1.0 # increase distance by 1 for each mismatched SNV n_comparable += 1 - d: float = float(n_not_equal) if not pure_snv_peak_assignment: # Add in copy number distance d += abs(r1["cn"] - r2["cn"]) * ( relative_cn_distance_weight_scaling_many if n_comparable >= many_snvs_quantity @@ -540,7 +542,7 @@ def call_alleles_with_incorporated_snvs( 1 for _ in filter( # If we were calling SNVs from scratch, we used to include the gap character. However, it seems to cause # more issues than not - let's stick to real SNVs... - lambda s: s[0] not in (SNV_OUT_OF_RANGE_CHAR, SNV_GAP_CHAR) and s[1] >= snv_quality_threshold, + lambda s: s[0] not in SNV_NA_CHARS and s[1] >= snv_quality_threshold, read_useful_snv_bases ) ) @@ -776,6 +778,7 @@ def debug_log_flanking_seq(logger_: logging.Logger, locus_log_str: str, rn: str, def _ndarray_serialize(x: Iterable) -> list[Union[int, np.int_]]: return list(map(round, x)) + def _nested_ndarray_serialize(x: Iterable) -> list[list[Union[int, np.int_]]]: return list(map(_ndarray_serialize, x)) @@ -818,8 +821,10 @@ def call_locus( flank_size = params.flank_size realign = params.realign respect_ref = params.respect_ref + targeted = params.targeted min_read_align_score = params.min_read_align_score snv_min_base_qual = params.snv_min_base_qual + use_hp = params.use_hp # ---------------------------------- rng = np.random.default_rng(seed=seed) @@ -983,6 +988,10 @@ def call_locus( sorted_read_lengths = np.sort(read_lengths) + @functools.cache + def get_read_length_partition_mean(p_idx: int) -> float: + return np.mean(sorted_read_lengths[p_idx:]).item() + # Find candidate SNVs, if we're using SNV data candidate_snvs: Optional[CandidateSNVs] = None # Lookup dictionary for candidate SNVs by position @@ -1257,12 +1266,12 @@ def call_locus( ) exit(1) - mean_containing_size = read_len if params.targeted else np.mean(sorted_read_lengths[partition_idx:]).item() + mean_containing_size = read_len if targeted else get_read_length_partition_mean(partition_idx) # TODO: re-examine weighting to possibly incorporate chance of drawing read large enough read_weight = (mean_containing_size + tr_len_w_flank - 2) / (mean_containing_size - tr_len_w_flank + 1) crs_cir = chimeric_read_status[rn] == 3 # Chimera within the TR region, indicating a potential large expansion - read_dict[rn] = { + read_dict[rn] = read_dict_entry = { "s": "-" if segment.is_reverse else "+", "cn": read_cn, "w": read_weight, @@ -1270,7 +1279,7 @@ def call_locus( **({"chimeric_in_region": crs_cir} if crs_cir else {}), **({"kmers": dict(read_kmers)} if count_kmers != "none" else {}), } - read_dict_extra[rn] = { + read_dict_extra[rn] = read_extra_entry = { "_ref_start": segment_start, "_ref_end": segment_end, **({"_tr_seq": tr_read_seq} if consensus else {}), @@ -1278,7 +1287,7 @@ def call_locus( # Reads can show up more than once - TODO - cache this information across loci - if params.use_hp: + if use_hp: if (hp := segment.hp) is not None and (ps := segment.ps) is not None: orig_ps = int(ps) @@ -1295,8 +1304,8 @@ def call_locus( phase_set_lock.release() - read_dict[rn]["hp"] = hp # not none inside this if-statement - read_dict[rn]["ps"] = ps_remapped + read_dict_entry["hp"] = hp # not none inside this if-statement + read_dict_entry["ps"] = ps_remapped haplotags.add(hp) haplotagged_reads_count += 1 phase_sets[ps_remapped] += 1 @@ -1304,8 +1313,8 @@ def call_locus( if should_incorporate_snvs: # Store the segment sequence and qualities in the read dict for the next go-around if we've enabled SNV # incorporation, in order to pass them to the get_read_snvs function with the cached ref string. - read_dict_extra[rn]["_qs"] = qs - read_dict_extra[rn]["_fqqs"] = fqqs + read_extra_entry["_qs"] = qs + read_extra_entry["_fqqs"] = fqqs # Observed significant increase in annoying, probably false SNVs near the edges of significantly # clipped reads in CCS data. Figure out if we have large clipping for later use here in the SNV finder. @@ -1364,7 +1373,7 @@ def call_locus( allele_start_time = time.perf_counter() - if params.use_hp: + if use_hp: top_ps = phase_sets.most_common(1) if (haplotagged_reads_count >= min_hp_read_coverage and len(haplotags) == n_alleles and top_ps and top_ps[0][1] >= min_hp_read_coverage): diff --git a/strkit/call/snvs.py b/strkit/call/snvs.py index 2f21b0c..e9d3f42 100644 --- a/strkit/call/snvs.py +++ b/strkit/call/snvs.py @@ -13,6 +13,7 @@ __all__ = [ "SNV_OUT_OF_RANGE_CHAR", "SNV_GAP_CHAR", + "SNV_NA_CHARS", "get_read_snvs", "call_and_filter_useful_snvs", "process_read_snvs_for_locus_and_calculate_useful_snvs", @@ -20,6 +21,7 @@ SNV_OUT_OF_RANGE_CHAR = "-" SNV_GAP_CHAR = "_" +SNV_NA_CHARS = (SNV_OUT_OF_RANGE_CHAR, SNV_GAP_CHAR) def call_and_filter_useful_snvs(