From 03f3d1b8561163c1b4f89cb162aa01b658996cd8 Mon Sep 17 00:00:00 2001 From: Jover Lee Date: Fri, 10 Jan 2025 14:58:01 -0800 Subject: [PATCH] vdb/download: Add `--prioritized_seqs_file` option MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Allow users to define prioritized sequence accessions for a strain when the duplicate resolve method is `split_passage`. This flag could be extended for other resolve methods, but we are not actively using them in seasonal flu workflows so punting for now. The file is expected to be a TSV file with a `strain` and `accession` column where the `strain` column only contains unique strain names. With the seasonal flu's pattern of downloading each segment separately,¹ the file should contain sequence accessions for a single segment. ¹ --- vdb/download.py | 79 +++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 74 insertions(+), 5 deletions(-) diff --git a/vdb/download.py b/vdb/download.py index 7b339602..e78c385f 100644 --- a/vdb/download.py +++ b/vdb/download.py @@ -1,6 +1,8 @@ import os, json, datetime, sys, re from rethinkdb import r from Bio import SeqIO +from csv import DictReader +from typing import Dict, List, Optional import numpy as np # Enable import from modules in parent directory. @@ -54,6 +56,12 @@ def duplicate_resolver(resolve_method): return method parser.add_argument('--resolve_method', type=duplicate_resolver, default="choose_longest", help="Set method of resolving duplicates for the same locus, options are \'keep_duplicates\', \'choose_longest\', \'choose_genbank\' and \'split_passage\'") + parser.add_argument('--prioritized_seqs_file', type=str, + help="A TSV file with strain and sequence accessions that are prioritized " + + "when using the `split_passage` resolve method. " + + "The file is expected to have the columns `strain` and `accession`, " + + "and the `strain` column is expected to have unique strain names. " + + "Remember to add the `-egg` suffix to the strain name for egg passaged sequences.") return parser @@ -219,7 +227,54 @@ def check_date_format(self, older_date, newer_date): raise Exception("Date interval must be in YYYY-MM-DD format with all values defined", older_date, newer_date) return(older_date.upper(), newer_date.upper()) - def resolve_duplicates(self, sequence_docs, resolve_method=None, **kwargs): + + def parse_prioritized_seqs(self, prioritized_seqs_file: str) -> Dict[str, str]: + ''' + Parse provided TSV file *prioritized_seqs_file* to return a dict with the + `strain` as the key and `accession` as the value. + + Raises an AssertionError if there are duplicate strains in the file. + ''' + prioritized_seqs = {} + + with open(prioritized_seqs_file, 'r', newline='') as handle: + for record in DictReader(handle, delimiter="\t"): + assert record['strain'] not in prioritized_seqs, \ + f"Found duplicate strain {record['strain']} in prioritized-seqs TSV." + + prioritized_seqs[record['strain']] = record['accession'] + + return prioritized_seqs + + + def resolve_prioritized_seqs(self, prioritized_seqs: Dict[str,str], strain_sdocs: List[Dict]) -> Optional[Dict]: + ''' + Returns the prioritized sequence based on sequence accession + for a strain when available. + ''' + strain = strain_sdocs[0]["strain"] + + # All `strain_sdocs` should be for the same strain! + assert all(d["strain"] == strain for d in strain_sdocs), \ + f"Not all strain_sdocs should have the same strain name {strain!r}" + + seq_accession = prioritized_seqs.get(strain, None) + if seq_accession is None: + return None + + prioritized_seq_sdocs = [sdoc for sdoc in strain_sdocs if sdoc['accession'] == seq_accession] + if len(prioritized_seq_sdocs) == 0: + print(f"WARNING: cannot find prioritized accession {seq_accession!r} for strain {strain!r}") + return None + + # There should always only be one matching accession since the sequences table is indexed on accession! + assert len(prioritized_seq_sdocs) == 1, \ + f"More than one sequence matched the prioritized accession {seq_accession!r} for strain {strain!r}" + + return prioritized_seq_sdocs[0] + + + def resolve_duplicates(self, sequence_docs, resolve_method=None, prioritized_seqs_file=None, **kwargs): ''' Takes a list of sequence documents (each one a Dict of key value pairs) And subsets this list to have only 1 sequence document for each 'strain' @@ -255,6 +310,12 @@ def resolve_duplicates(self, sequence_docs, resolve_method=None, **kwargs): print("Resolving duplicate strains by keeping one cell/direct and one egg sequence") print("Appends -egg to egg-passaged sequence") print("Within cell/egg partitions prioritize longest sequence") + + prioritized_seqs = {} + if prioritized_seqs_file is not None: + print(f"Sequence accessions specified in {prioritized_seqs_file!r} will be prioritized regardless of sequence length.") + prioritized_seqs = self.parse_prioritized_seqs(prioritized_seqs_file) + for strain in strains: strain_sdocs = strain_to_sdocs[strain] cell_strain_sdocs = [] @@ -266,11 +327,19 @@ def resolve_duplicates(self, sequence_docs, resolve_method=None, **kwargs): strain_sdoc['strain'] = strain_sdoc['strain'] + "-egg" egg_strain_sdocs.append(strain_sdoc) if len(cell_strain_sdocs) > 0: - sorted_cell_strain_sdocs = sorted(cell_strain_sdocs, key=lambda k: len(k['sequence'].replace('n', '')), reverse=True) - resolved_sequence_docs.append(sorted_cell_strain_sdocs[0]) + prioritized_seq_sdoc = self.resolve_prioritized_seqs(prioritized_seqs, cell_strain_sdocs) + if prioritized_seq_sdoc is not None: + resolved_sequence_docs.append(prioritized_seq_sdoc) + else: + sorted_cell_strain_sdocs = sorted(cell_strain_sdocs, key=lambda k: len(k['sequence'].replace('n', '')), reverse=True) + resolved_sequence_docs.append(sorted_cell_strain_sdocs[0]) if len(egg_strain_sdocs) > 0: - sorted_egg_strain_sdocs = sorted(egg_strain_sdocs, key=lambda k: len(k['sequence'].replace('n', '')), reverse=True) - resolved_sequence_docs.append(sorted_egg_strain_sdocs[0]) + prioritized_seq_sdoc = self.resolve_prioritized_seqs(prioritized_seqs, egg_strain_sdocs) + if prioritized_seq_sdoc is not None: + resolved_sequence_docs.append(prioritized_seq_sdoc) + else: + sorted_egg_strain_sdocs = sorted(egg_strain_sdocs, key=lambda k: len(k['sequence'].replace('n', '')), reverse=True) + resolved_sequence_docs.append(sorted_egg_strain_sdocs[0]) elif resolve_method == "keep_duplicates": print("Keeping duplicate strains") resolved_sequence_docs = sequence_docs