Merge pull request #176 from nextstrain/download-resolve-method
vdb/download: Add `--prioritized_seqs_file` option
joverlee521 authored Jan 27, 2025
2 parents b4b672b + 03f3d1b commit 6cad83e
Showing 1 changed file with 74 additions and 5 deletions.
79 changes: 74 additions & 5 deletions vdb/download.py
@@ -1,6 +1,8 @@
 import os, json, datetime, sys, re
 from rethinkdb import r
 from Bio import SeqIO
+from csv import DictReader
+from typing import Dict, List, Optional
 import numpy as np
 
 # Enable import from modules in parent directory.
@@ -54,6 +56,12 @@ def duplicate_resolver(resolve_method):
         return method
 
     parser.add_argument('--resolve_method', type=duplicate_resolver, default="choose_longest", help="Set method of resolving duplicates for the same locus, options are \'keep_duplicates\', \'choose_longest\', \'choose_genbank\' and \'split_passage\'")
+    parser.add_argument('--prioritized_seqs_file', type=str,
+        help="A TSV file of strain names and sequence accessions to prioritize " +
+             "when using the `split_passage` resolve method. " +
+             "The file is expected to have the columns `strain` and `accession`, " +
+             "and the `strain` column is expected to contain unique strain names. " +
+             "Remember to add the `-egg` suffix to strain names for egg-passaged sequences.")
     return parser
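
For reference, a minimal `--prioritized_seqs_file` TSV might look like the following (the strain names and accessions here are hypothetical; columns are tab-separated, and the egg-passaged entry carries the `-egg` suffix):

    strain	accession
    A/Wisconsin/588/2019	MK123456
    A/Wisconsin/588/2019-egg	MK654321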


@@ -219,7 +227,54 @@ def check_date_format(self, older_date, newer_date):
             raise Exception("Date interval must be in YYYY-MM-DD format with all values defined", older_date, newer_date)
         return(older_date.upper(), newer_date.upper())
 
-    def resolve_duplicates(self, sequence_docs, resolve_method=None, **kwargs):
+
+    def parse_prioritized_seqs(self, prioritized_seqs_file: str) -> Dict[str, str]:
+        '''
+        Parse the provided TSV file *prioritized_seqs_file* and return a dict
+        with `strain` as the key and `accession` as the value.
+        Raises an AssertionError if there are duplicate strains in the file.
+        '''
+        prioritized_seqs = {}
+
+        with open(prioritized_seqs_file, 'r', newline='') as handle:
+            for record in DictReader(handle, delimiter="\t"):
+                assert record['strain'] not in prioritized_seqs, \
+                    f"Found duplicate strain {record['strain']} in prioritized-seqs TSV."
+
+                prioritized_seqs[record['strain']] = record['accession']
+
+        return prioritized_seqs
+
+
+    def resolve_prioritized_seqs(self, prioritized_seqs: Dict[str, str], strain_sdocs: List[Dict]) -> Optional[Dict]:
+        '''
+        Return the prioritized sequence document for a strain, selected by
+        sequence accession, when one is available; otherwise return None.
+        '''
+        strain = strain_sdocs[0]["strain"]
+
+        # All `strain_sdocs` should be for the same strain!
+        assert all(d["strain"] == strain for d in strain_sdocs), \
+            f"Not all strain_sdocs have the same strain name {strain!r}"
+
+        seq_accession = prioritized_seqs.get(strain, None)
+        if seq_accession is None:
+            return None
+
+        prioritized_seq_sdocs = [sdoc for sdoc in strain_sdocs if sdoc['accession'] == seq_accession]
+        if len(prioritized_seq_sdocs) == 0:
+            print(f"WARNING: cannot find prioritized accession {seq_accession!r} for strain {strain!r}")
+            return None
+
+        # There should only ever be one matching accession since the sequences table is indexed on accession!
+        assert len(prioritized_seq_sdocs) == 1, \
+            f"More than one sequence matched the prioritized accession {seq_accession!r} for strain {strain!r}"
+
+        return prioritized_seq_sdocs[0]
+
+
+    def resolve_duplicates(self, sequence_docs, resolve_method=None, prioritized_seqs_file=None, **kwargs):
         '''
         Takes a list of sequence documents (each one a Dict of key value pairs)
         and subsets this list to have only 1 sequence document for each 'strain'
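
To make the behavior of the two new helpers concrete, here is a small sketch (the documents and accessions are hypothetical, and `downloader` stands in for an instance of the download class):

    # Two candidate sequence documents for the same strain (hypothetical data).
    strain_sdocs = [
        {'strain': 'A/Wisconsin/588/2019', 'accession': 'MK123456', 'sequence': 'atgaac'},
        {'strain': 'A/Wisconsin/588/2019', 'accession': 'MK999999', 'sequence': 'atgaacgctaca'},
    ]

    # The entry parsed from the prioritized-seqs TSV.
    prioritized_seqs = {'A/Wisconsin/588/2019': 'MK123456'}

    # The accession match wins even though the other sequence is longer; with no
    # TSV entry (or no matching accession) the helper returns None and
    # resolve_duplicates falls back to choosing the longest sequence.
    sdoc = downloader.resolve_prioritized_seqs(prioritized_seqs, strain_sdocs)
    assert sdoc['accession'] == 'MK123456'
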
@@ -255,6 +310,12 @@ def resolve_duplicates(self, sequence_docs, resolve_method=None, **kwargs):
             print("Resolving duplicate strains by keeping one cell/direct and one egg sequence")
             print("Appends -egg to egg-passaged sequence")
             print("Within cell/egg partitions prioritize longest sequence")
+
+            prioritized_seqs = {}
+            if prioritized_seqs_file is not None:
+                print(f"Sequence accessions specified in {prioritized_seqs_file!r} will be prioritized regardless of sequence length.")
+                prioritized_seqs = self.parse_prioritized_seqs(prioritized_seqs_file)
+
             for strain in strains:
                 strain_sdocs = strain_to_sdocs[strain]
                 cell_strain_sdocs = []
@@ -266,11 +327,19 @@ def resolve_duplicates(self, sequence_docs, resolve_method=None, **kwargs):
                         strain_sdoc['strain'] = strain_sdoc['strain'] + "-egg"
                         egg_strain_sdocs.append(strain_sdoc)
                 if len(cell_strain_sdocs) > 0:
-                    sorted_cell_strain_sdocs = sorted(cell_strain_sdocs, key=lambda k: len(k['sequence'].replace('n', '')), reverse=True)
-                    resolved_sequence_docs.append(sorted_cell_strain_sdocs[0])
+                    prioritized_seq_sdoc = self.resolve_prioritized_seqs(prioritized_seqs, cell_strain_sdocs)
+                    if prioritized_seq_sdoc is not None:
+                        resolved_sequence_docs.append(prioritized_seq_sdoc)
+                    else:
+                        sorted_cell_strain_sdocs = sorted(cell_strain_sdocs, key=lambda k: len(k['sequence'].replace('n', '')), reverse=True)
+                        resolved_sequence_docs.append(sorted_cell_strain_sdocs[0])
                 if len(egg_strain_sdocs) > 0:
-                    sorted_egg_strain_sdocs = sorted(egg_strain_sdocs, key=lambda k: len(k['sequence'].replace('n', '')), reverse=True)
-                    resolved_sequence_docs.append(sorted_egg_strain_sdocs[0])
+                    prioritized_seq_sdoc = self.resolve_prioritized_seqs(prioritized_seqs, egg_strain_sdocs)
+                    if prioritized_seq_sdoc is not None:
+                        resolved_sequence_docs.append(prioritized_seq_sdoc)
+                    else:
+                        sorted_egg_strain_sdocs = sorted(egg_strain_sdocs, key=lambda k: len(k['sequence'].replace('n', '')), reverse=True)
+                        resolved_sequence_docs.append(sorted_egg_strain_sdocs[0])
             elif resolve_method == "keep_duplicates":
                 print("Keeping duplicate strains")
                 resolved_sequence_docs = sequence_docs
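
As a usage sketch, the new option is meant to be combined with the `split_passage` resolve method along these lines (other required flags, e.g. the database and virus selection, are elided here; see the repository README for a full download invocation):

    python vdb/download.py ... \
        --resolve_method split_passage \
        --prioritized_seqs_file prioritized_seqs.tsv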
