Skip to content

Commit

Permalink
feat(call): correctly alter/remove alt anchor base in VCF output
Browse files: browse the repository at this point in the history
  • Loading branch information
davidlougheed committed Sep 24, 2024
1 parent 7a3543b commit 9b04b3c
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 16 deletions.
32 changes: 24 additions & 8 deletions strkit/call/call_locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ def call_alleles_with_incorporated_snvs(

# Cluster reads together using the distance matrix, which incorporates SNV and possibly copy number information.
cluster_labels, cluster_indices = _agg_clust_alleles_by_dm(n_alleles, dm)
del dm

cluster_reads: list[tuple[ReadDict, ...]] = []
cns: list[NDArray[np.int32]] = []
Expand Down Expand Up @@ -1276,6 +1277,17 @@ def get_read_length_partition_mean(p_idx: int) -> float:
# TODO: re-examine weighting to possibly incorporate chance of drawing read large enough
read_weight = (mean_containing_size + tr_len_w_flank - 2) / (mean_containing_size - tr_len_w_flank + 1)

# ---

read_start_anchor: str = ""
if consensus:
anchor_pair_idx, anchor_pair_found = find_pair_by_ref_pos(r_coords, left_coord_adj - 1, 0)
if anchor_pair_found:
read_start_anchor = qs[q_coords[anchor_pair_idx]:left_flank_end]
# otherwise, leave as blank - anchor base deleted

# ---

crs_cir = chimeric_read_status[rn] == 3 # Chimera within the TR region, indicating a potential large expansion
read_dict[rn] = read_dict_entry = {
"s": "-" if segment.is_reverse else "+",
Expand All @@ -1288,7 +1300,7 @@ def get_read_length_partition_mean(p_idx: int) -> float:
read_dict_extra[rn] = read_extra_entry = {
"_ref_start": segment_start,
"_ref_end": segment_end,
**({"_tr_seq": tr_read_seq} if consensus else {}),
**({"_start_anchor": read_start_anchor, "_tr_seq": tr_read_seq} if consensus else {}),
}

# Reads can show up more than once - TODO - cache this information across loci
Expand Down Expand Up @@ -1342,8 +1354,7 @@ def get_read_length_partition_mean(p_idx: int) -> float:
n_reads_in_dict: int = len(read_dict)

locus_result.update({
# TODO: alt anchors:
**({"ref_start_anchor": ref_left_flank_seq[-1], "ref_seq": ref_seq} if consensus else {}),
**({"ref_start_anchor": ref_left_flank_seq[-1].upper(), "ref_seq": ref_seq} if consensus else {}),
"reads": read_dict,
})

Expand Down Expand Up @@ -1506,7 +1517,10 @@ def get_read_length_partition_mean(p_idx: int) -> float:
# don't know how re-sampling has occurred.
call_peak_n_reads: list[int] = []
peak_kmers: list[Counter] = [Counter() for _ in range(call_modal_n or 0)]

call_seqs: list[tuple[str, ConsensusMethod]] = []
call_anchor_seqs: list[tuple[str, ConsensusMethod]] = []

if read_peaks_called := call_modal_n and call_modal_n <= 2:
peaks: NDArray[np.float_] = call_peaks[:call_modal_n]
stdevs: NDArray[np.float_] = call_stdevs[:call_modal_n]
Expand Down Expand Up @@ -1576,16 +1590,18 @@ def get_read_length_partition_mean(p_idx: int) -> float:
call_99_cis = None

if call_data and consensus:
call_seqs.extend(
map(
def _consensi_for_key(k: str):
return map(
lambda a: consensus_seq(
list(map(lambda rr: read_dict_extra[rr]["_tr_seq"], a)),
list(map(lambda rr: read_dict_extra[rr][k], a)),
logger_,
max_mdn_poa_length,
),
allele_reads,
)
)

call_seqs.extend(_consensi_for_key("_tr_seq"))
call_anchor_seqs.extend(_consensi_for_key("_start_anchor"))

peak_data = {
"means": call_peaks,
Expand All @@ -1594,7 +1610,7 @@ def get_read_length_partition_mean(p_idx: int) -> float:
"modal_n": call_modal_n,
"n_reads": call_peak_n_reads,
**({"kmers": list(map(dict, peak_kmers))} if count_kmers in ("peak", "both") else {}),
**({"seqs": call_seqs} if consensus else {}),
**({"seqs": call_seqs, "start_anchor_seqs": call_anchor_seqs} if consensus else {}),
} if call_data else None

assign_time = time.perf_counter() - assign_start_time
Expand Down
39 changes: 32 additions & 7 deletions strkit/call/output/vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pathlib
import pysam

from os.path import commonprefix
# from os.path import commonprefix
from typing import Optional

from strkit.utils import cat_strs, is_none
Expand Down Expand Up @@ -120,26 +120,47 @@ def output_contig_vcf_lines(

res_reads = result["reads"]
res_peaks = result["peaks"] or {}

peak_seqs: list[str] = list(map(idx_0_getter, res_peaks.get("seqs", [])))
peak_start_anchor_seqs: list[str] = list(map(idx_0_getter, res_peaks.get("start_anchor_seqs", [])))

if any(map(is_none, peak_seqs)): # Occurs when no consensus for one of the peaks
logger.error(f"Encountered None in results[{result_idx}].peaks.seqs: {peak_seqs}")
continue

if any(map(is_none, peak_start_anchor_seqs)): # Occurs when no consensus for one of the peaks
logger.error(f"Encountered None in results[{result_idx}].peaks.start_anchor_seqs: {peak_start_anchor_seqs}")
continue

seqs = tuple(map(str.upper, peak_seqs))
seqs_with_anchors = tuple(zip(seqs, tuple(map(str.upper, peak_start_anchor_seqs))))

if 0 < len(seqs) < n_alleles:
seqs = tuple([seqs[0]] * n_alleles)
seqs_with_anchors = tuple([seqs_with_anchors[0]] * n_alleles)

seq_alts = sorted(set(filter(lambda c: c != ref_seq, seqs)))
common_suffix_idx = -1 * len(commonprefix(tuple(map(_reversed_str, (ref_seq, *seqs)))))
seq_alts = sorted(
set(filter(lambda c: not (c[0] == ref_seq and c[1] == ref_start_anchor), seqs_with_anchors)),
key=lambda x: x[0]
)

# common_suffix_idx = -1 * len(commonprefix(tuple(map(_reversed_str, (ref_seq, *seqs)))))

call = result["call"]
call_95_cis = result["call_95_cis"]

seq_alleles_raw: tuple[Optional[str], ...] = (ref_seq, *(seq_alts or (None,))) if call is not None else (".",)
seq_alleles: list[str] = [ref_start_anchor + (ref_seq[:common_suffix_idx] if common_suffix_idx else ref_seq)]
seq_alleles_raw: tuple[Optional[str], ...] = (
((ref_seq, ref_start_anchor), *(seq_alts or (None,)))
if call is not None
else ()
)

# seq_alleles: list[str] = [ref_start_anchor + (ref_seq[:common_suffix_idx] if common_suffix_idx else ref_seq)]
seq_alleles: list[str] = [ref_start_anchor + ref_seq]
if call is not None and seq_alts:
seq_alleles.extend(ref_start_anchor + (a[:common_suffix_idx] if common_suffix_idx else a) for a in seq_alts)
# seq_alleles.extend(a[1] + (a[0][:common_suffix_idx] if common_suffix_idx else a[0]) for a in seq_alts)
# If we have a complete deletion, including the anchor, use a symbolic allele meaning "upstream deletion"
seq_alleles.extend((a[1] + a[0] if a[1] or a[0] else "*") for a in seq_alts)
else:
seq_alleles.append(".")

Expand All @@ -155,8 +176,12 @@ def output_contig_vcf_lines(
vr.info[VCF_INFO_MOTIF] = result["motif"]
vr.info[VCF_INFO_REFMC] = result["ref_cn"]

vr.samples[sample_id]["GT"] = tuple(map(seq_alleles_raw.index, seqs)) if call is not None and seqs \
vr.samples[sample_id]["GT"] = (
tuple(map(seq_alleles_raw.index, seqs_with_anchors))
if call is not None and seqs
else _blank_entry(n_alleles)
)
del seq_alleles_raw

if am := result.get("assign_method"):
vr.samples[sample_id]["PM"] = am
Expand Down
5 changes: 4 additions & 1 deletion strkit/call/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ class ReadDictExtra(TypedDict, total=False):
_ref_start: int # Read start in ref coordinates
_ref_end: int # Read end in ref coordinates

_tr_seq: str # Tandem repeat sequence... only added if consensus is being calculated
# BEGIN: only added if consensus is being calculated
_start_anchor: str # Left anchor for calculated allele sequence (usually 1 base)
_tr_seq: str # Tandem repeat sequence
# END: only added if consensus is being calculated

# Below are only added if SNVs are being incorporated:

Expand Down

0 comments on commit 9b04b3c

Please sign in to comment.