diff --git a/my_profiles/example_multiple_inputs/builds.yaml b/my_profiles/example_multiple_inputs/builds.yaml
index 3c309f847..09d501f4a 100644
--- a/my_profiles/example_multiple_inputs/builds.yaml
+++ b/my_profiles/example_multiple_inputs/builds.yaml
@@ -8,11 +8,11 @@
 inputs:
   - name: "aus"
-    metadata: "data/example_metadata_aus.tsv"
-    sequences: "data/example_sequences_aus.fasta"
+    metadata: "data/example_metadata_aus.tsv.xz"
+    sequences: "data/example_sequences_aus.fasta.xz"
   - name: "worldwide"
-    metadata: "data/example_metadata_worldwide.tsv"
-    sequences: "data/example_sequences_worldwide.fasta"
+    metadata: "data/example_metadata_worldwide.tsv.xz"
+    sequences: "data/example_sequences_worldwide.fasta.xz"

 builds:
   multiple-inputs:
diff --git a/scripts/get_distance_to_focal_set.py b/scripts/get_distance_to_focal_set.py
index 075f631a2..ca4db55b7 100644
--- a/scripts/get_distance_to_focal_set.py
+++ b/scripts/get_distance_to_focal_set.py
@@ -134,41 +134,51 @@ def calculate_distance_matrix(sparse_matrix_A, sparse_matrix_B, consensus):

     return d

-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(
-        description="generate priorities files based on genetic proximity to focal sample",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
-    parser.add_argument("--alignment", type=str, required=True, help="FASTA file of alignment")
-    parser.add_argument("--reference", type = str, required=True, help="reference sequence (FASTA)")
-    parser.add_argument("--ignore-seqs", type = str, nargs='+', help="sequences to ignore in distance calculation")
-    parser.add_argument("--focal-alignment", type = str, required=True, help="focal sample of sequences")
-    parser.add_argument("--chunk-size", type=int, default=10000, help="number of samples in the global alignment to process at once. Reduce this number to reduce memory usage at the cost of increased run-time.")
-    parser.add_argument("--output", type=str, required=True, help="FASTA file of output alignment")
-    args = parser.parse_args()
+
+def get_distance_to_focal_set(alignment, reference, focal_alignment, output, ignore_seqs=[], chunk_size=10000):
+    """
+    Calculate minimal distances between sequences in an alignment and a set of focal sequences
+
+    Parameters
+    ----------
+    alignment : string
+        Path to FASTA file of alignment
+    reference : string
+        Path to reference sequence (FASTA)
+    focal_alignment : string
+        Path to FASTA of focal sample of sequences
+    output : string
+        Path to the output TSV of distances to the closest focal sequence
+    ignore_seqs : list[string], optional
+        Sequences to ignore in the distance calculation
+    chunk_size : int, default: 10000
+        Number of samples in the global alignment to process at once. Reduce this number to
+        reduce memory usage at the cost of increased run-time.
+
+    Returns
+    -------
+    None
+    """
     # load entire alignment and the alignment of focal sequences (upper case -- probably not necessary)
-    ref = sequence_to_int_array(SeqIO.read(args.reference, 'fasta').seq)
+    ref = sequence_to_int_array(SeqIO.read(reference, 'fasta').seq)
     alignment_length = len(ref)

-    focal_seqs = read_sequences(args.focal_alignment)
-    focal_seqs_dict = calculate_snp_matrix(focal_seqs, consensus = ref, ignore_seqs=args.ignore_seqs)
+    focal_seqs = read_sequences(focal_alignment)
+    focal_seqs_dict = calculate_snp_matrix(focal_seqs, consensus = ref, ignore_seqs=ignore_seqs)

     if focal_seqs_dict is None:
         print(
-            f"ERROR: There are no valid sequences in the focal alignment, '{args.focal_alignment}', to compare against the full alignment.",
+            f"ERROR: There are no valid sequences in the focal alignment, '{focal_alignment}', to compare against the full alignment.",
             "Check your subsampling settings for the focal alignment or consider disabling proximity-based subsampling.",
             file=sys.stderr
         )
         sys.exit(1)

-    seqs = read_sequences(args.alignment)
+    seqs = read_sequences(alignment)

     # export priorities
-    fh_out = open(args.output, 'w')
+    fh_out = open(output, 'w')
     fh_out.write('strain\tclosest strain\tdistance\n')

-    chunk_size=args.chunk_size
     chunk_count = 0
     while True:
         context_seqs_dict = calculate_snp_matrix(seqs, consensus=ref, chunk_size=chunk_size)
@@ -196,3 +206,24 @@ def calculate_distance_matrix(sparse_matrix_A, sparse_matrix_B, consensus):
         chunk_count += 1

     fh_out.close()
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description="generate priorities files based on genetic proximity to focal sample",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument("--alignment", type=str, required=True, help="FASTA file of alignment")
+    parser.add_argument("--reference", type = str, required=True, help="reference sequence (FASTA)")
+    parser.add_argument("--ignore-seqs", type = str, nargs='+', help="sequences to ignore in distance calculation")
+    parser.add_argument("--focal-alignment", type = str, required=True, help="focal sample of sequences")
+    parser.add_argument("--chunk-size", type=int, default=10000, help="number of samples in the global alignment to process at once. Reduce this number to reduce memory usage at the cost of increased run-time.")
+    parser.add_argument("--output", type=str, required=True, help="TSV file of distances to the closest focal sequence")
+    args = parser.parse_args()
+    get_distance_to_focal_set(
+        args.alignment,
+        args.reference,
+        args.focal_alignment,
+        args.output,
+        args.ignore_seqs,
+        args.chunk_size
+    )
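+
+# Illustrative usage of the function form (the file paths below are hypothetical, not part of
+# this repo):
+#
+#     from get_distance_to_focal_set import get_distance_to_focal_set
+#     get_distance_to_focal_set(
+#         alignment="results/aligned.fasta",
+#         reference="defaults/reference_seq.fasta",
+#         focal_alignment="results/sample-focal.fasta",
+#         output="results/proximity_focal.tsv",
+#         ignore_seqs=["Wuhan/Hu-1/2019"],
+#     )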
diff --git a/scripts/priorities.py b/scripts/priorities.py
index 8470d5b9d..5e6130cee 100644
--- a/scripts/priorities.py
+++ b/scripts/priorities.py
@@ -7,20 +7,28 @@
 import numpy as np
 import pandas as pd

-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(
-        description="generate priorities files based on genetic proximity to focal sample",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
-    parser.add_argument("--sequence-index", type=str, required=True, help="sequence index file")
-    parser.add_argument("--proximities", type = str, required=True, help="tsv file with proximities")
-    parser.add_argument("--Nweight", type = float, default=0.003, required=False, help="parameterizes de-prioritization of incomplete sequences")
-    parser.add_argument("--crowding-penalty", type = float, default=0.01, required=False, help="parameterizes how priorities decrease when there is many very similar sequences")
-    parser.add_argument("--output", type=str, required=True, help="tsv file with the priorities")
-    args = parser.parse_args()
+def create_priorities(sequence_index_path, proximities_path, output_path, Nweight=0.003, crowding_penalty=0.01):
+    """
+    Calculate priorities from the sequence index and proximities
+
+    Parameters
+    ----------
+    sequence_index_path : string
+        Path to the sequence index file
+    proximities_path : string
+        Path to the TSV file with proximities
+    output_path : string
+        Path to the output TSV file with the priorities
+    Nweight : float, default: 0.003
+        Parameterizes de-prioritization of incomplete sequences
+    crowding_penalty : float, default: 0.01
+        Parameterizes how priorities decrease when there are many very similar sequences
+
+    Returns
+    -------
+    None
+    """

-    proximities = pd.read_csv(args.proximities, sep='\t', index_col=0)
-    index = pd.read_csv(args.sequence_index, sep='\t', index_col=0)
+    proximities = pd.read_csv(proximities_path, sep='\t', index_col=0)
+    index = pd.read_csv(sequence_index_path, sep='\t', index_col=0)

     combined = pd.concat([proximities, index], axis=1)
     closest_matches = combined.groupby('closest strain')
@@ -28,16 +36,36 @@
     for focal_seq, seqs in closest_matches.groups.items():
         tmp = combined.loc[seqs, ["distance", "N"]]
-        # penalize larger distances and more undetermined sites. 1/args.Nweight are 'as bad' as one extra mutation
-        tmp["priority"] = -tmp.distance - tmp.N*args.Nweight
+        # penalize larger distances and more undetermined sites. 1/Nweight N sites are 'as bad' as one extra mutation
+        tmp["priority"] = -tmp.distance - tmp.N*Nweight
         name_prior = [(name, d.priority) for name, d in tmp.iterrows()]
         shuffle(name_prior)
         candidates[focal_seq] = sorted(name_prior, key=lambda x:x[1])

     # export priorities
-    crowding = args.crowding_penalty
-    with open(args.output, 'w') as fh:
+    crowding = crowding_penalty
+    with open(output_path, 'w') as fh:
         # loop over lists of sequences that are closest to particular focal sequences
         for cs in candidates.values():
             # these sets have been shuffled -- reduce priorities in this shuffled random order
             for i, (name, pr) in enumerate(cs):
                 fh.write(f"{name}\t{pr-i*crowding:1.2f}\n")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description="generate priorities files based on genetic proximity to focal sample",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument("--sequence-index", type=str, required=True, help="sequence index file")
+    parser.add_argument("--proximities", type = str, required=True, help="tsv file with proximities")
+    parser.add_argument("--Nweight", type = float, default=0.003, required=False, help="parameterizes de-prioritization of incomplete sequences")
+    parser.add_argument("--crowding-penalty", type = float, default=0.01, required=False, help="parameterizes how priorities decrease when there are many very similar sequences")
+    parser.add_argument("--output", type=str, required=True, help="tsv file with the priorities")
+    args = parser.parse_args()
+    create_priorities(
+        args.sequence_index,
+        args.proximities,
+        args.output,
+        args.Nweight,
+        args.crowding_penalty
+    )
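+
+# Worked example (illustrative numbers): with the default Nweight of 0.003, a sequence at
+# distance 2 from its closest focal strain with 500 undetermined sites gets priority
+# -2 - 500*0.003 = -3.5, i.e. roughly 333 N sites cost as much as one extra mutation. Within
+# each group of sequences closest to the same focal strain, the shuffled candidates are then
+# further penalised by crowding_penalty (default 0.01) per rank when written out.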
diff --git a/scripts/subsample.py b/scripts/subsample.py
new file mode 100644
index 000000000..8b13568b5
--- /dev/null
+++ b/scripts/subsample.py
@@ -0,0 +1,351 @@
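+r"""
+Subsample an alignment and its metadata according to a YAML "scheme" consisting of one or more
+named samples, each of which is essentially a set of `augur filter` parameters. A sample may
+reference another sample via `priorities`, in which case proximity-based priorities are
+calculated from that (focal) sample first.
+
+Illustrative command (file names are hypothetical, not part of this repo):
+
+    python3 scripts/subsample.py \
+        --scheme scheme.yaml --metadata metadata.tsv --alignment aligned.fasta \
+        --reference reference.fasta --output-dir results/subsamples/ \
+        --output-fasta subsampled.fasta --output-metadata subsampled.tsv
+
+Illustrative scheme (sample names and values are made up; see subsample_schema.yaml):
+
+    focal:
+      group-by: ["division", "year", "month"]
+      subsample-max-sequences: 2000
+      query: region == 'Europe'
+    context:
+      group-by: ["country", "year", "month"]
+      subsample-max-sequences: 1000
+      priorities:
+        type: proximity
+        focus: focal
+"""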
+from augur.utils import AugurException
+from augur.filter import run as augur_filter, register_arguments as register_filter_arguments
+from augur.index import index_sequences
+from augur.io import write_sequences, open_file, read_sequences, read_metadata
+from get_distance_to_focal_set import get_distance_to_focal_set # eventually from augur.priorities (or similar)
+from priorities import create_priorities # eventually from augur.priorities (or similar)
+import yaml
+from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
+from os import path
+import pandas as pd
+from tempfile import NamedTemporaryFile
+import jsonschema
+# from pkg_resources import resource_string
+
+DESCRIPTION = "Subsample sequences based on user-defined YAML configuration"
+
+def register_arguments(parser):
+    parser.add_argument('--scheme', required=True, metavar="YAML", help="subsampling scheme")
+    parser.add_argument('--output-dir', required=True, metavar="PATH", help="directory to save intermediate results")
+    parser.add_argument('--metadata', required=True, metavar="TSV", help="metadata")
+    parser.add_argument('--alignment', required=True, metavar="FASTA", help="alignment to subsample")
+    parser.add_argument('--alignment-index', required=False, metavar="INDEX", help="sequence index of alignment")
+    parser.add_argument('--reference', required=True, metavar="FASTA", help="reference (which was used for alignment)")
+    parser.add_argument('--include-strains-file', required=False, nargs="+", default=None, metavar="TXT", help="strains to force include")
+    parser.add_argument('--exclude-strains-file', required=False, nargs="+", default=None, metavar="TXT", help="strains to force exclude")
+    parser.add_argument('--output-fasta', required=True, metavar="FASTA", help="output subsampled sequences")
+    parser.add_argument('--output-metadata', required=True, metavar="TSV", help="output subsampled metadata")
+    parser.add_argument('--output-log', required=False, metavar="TSV", help="log file explaining why strains were excluded / included")
+    parser.add_argument('--use-existing-outputs', required=False, action="store_true", help="use intermediate files, if they exist")
+
+def run(args):
+
+    config = parse_scheme(args.scheme)
+
+    generate_sequence_index(args)
+
+    samples = [Sample(name, data, args) for name, data in config.items()]
+
+    graph = make_graph(samples)
+
+    traverse_graph(
+        graph,
+        lambda s: s.filter()
+    )
+
+    combine_samples(args, samples)
+
+def parse_scheme(filename):
+    with open(filename) as fh:
+        try:
+            data = yaml.safe_load(fh)
+        except yaml.YAMLError as exc:
+            print(exc)
+            raise AugurException(f"Error parsing subsampling scheme {filename}")
+    validate_scheme(data)
+    return data
+
+
+def validate_scheme(scheme):
+    try:
+        # When we move this to `augur subsample`, load the schema via:
+        # schema = yaml.safe_load(resource_string(__package__, path.join("data", "schema-subsampling.yaml")))
+        with open(path.join(path.dirname(path.realpath(__file__)), "subsample_schema.yaml")) as fh:
+            schema = yaml.safe_load(fh)
+    except yaml.YAMLError as err:
+        raise AugurException("Subsampling schema definition is not valid YAML. Error: {}".format(err))
+    # check loaded schema is itself valid -- see http://python-jsonschema.readthedocs.io/en/latest/errors/
+    try:
+        jsonschema.Draft6Validator.check_schema(schema)
+    except jsonschema.exceptions.SchemaError as err:
+        raise AugurException("Subsampling schema definition is not valid. Error: {}".format(err))
+
+    try:
+        jsonschema.Draft6Validator(schema).validate(scheme)
+    except jsonschema.exceptions.ValidationError as err:
+        print(err)
+        raise AugurException("Subsampling scheme failed validation")
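+
+# For example (illustrative), validate_scheme would reject
+#     sample_a:
+#       seq_per_group: 10
+# because "seq_per_group" is ncov builds.yaml syntax; the schema sets additionalProperties to
+# false and only accepts the `augur filter`-style key "sequences-per-group".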
+
+class Sample():
+    """
+    A class to hold information about a sample. A subsampling scheme will consist of multiple
+    samples. Each sample may depend on priorities calculated from another sample.
+    """
+    def __init__(self, name, config, cmd_args):
+        self.name = name
+        self.tmp_dir = cmd_args.output_dir
+        self.alignment = cmd_args.alignment
+        self.alignment_index = cmd_args.alignment_index
+        self.reference = cmd_args.reference
+        self.metadata = cmd_args.metadata
+        self.initialise_filter_args(config, cmd_args)
+        self.priorities = config.get("priorities", None)
+        self.use_existing_outputs = cmd_args.use_existing_outputs
+
+    def initialise_filter_args(self, config, subsample_args):
+        """
+        Currently this method is needed as we need to call `augur filter`'s `run()` with an
+        argparse instance. An improvement here would be to expose appropriate filtering
+        functions and call them as needed, with the output being returned rather than
+        written to disk.
+        """
+        # create the appropriate command-line arguments for the augur filter run we want
+        arg_list = [
+            "--metadata", self.metadata,
+            "--sequences", self.alignment,
+            "--sequence-index", self.alignment_index,
+            "--output", path.join(self.tmp_dir, f"sample.{self.name}.fasta"),          # filtered sequences in FASTA format
+            "--output-metadata", path.join(self.tmp_dir, f"sample.{self.name}.tsv"),   # metadata for strains that passed filters
+            "--output-strains", path.join(self.tmp_dir, f"sample.{self.name}.txt"),    # list of strains that passed filters (no header)
+            "--output-log", path.join(self.tmp_dir, f"sample.{self.name}.log.tsv")
+        ]
+        # convert the YAML config into the command-line arguments for augur filter
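+        # For example (illustrative values), a sample defined in the scheme as
+        #     {"group-by": ["division", "year", "month"], "subsample-max-sequences": 500, "probabilistic-sampling": True}
+        # adds, on top of the input/output arguments above,
+        #     --group-by division year month --subsample-max-sequences 500 --probabilistic-sampling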
+        for name, value in config.items():
+            if isinstance(value, dict):
+                pass # we explicitly ignore dictionary config entries
+            elif isinstance(value, list):
+                arg_list.append(f"--{name}")
+                arg_list.extend([str(v) for v in value])
+            elif isinstance(value, bool):
+                if value:
+                    arg_list.append(f"--{name}")
+            else:
+                arg_list.append(f"--{name}")
+                arg_list.append(str(value))
+        # mock an ArgumentParser so that we can use augur filter's interface, avoiding the need to duplicate logic
+        parser = ArgumentParser(prog="Mock_Augur_Filter")
+        register_filter_arguments(parser)
+        self.filter_args, unused_args = parser.parse_known_args(arg_list)
+        if unused_args:
+            print(f"Warning - the following config parameters are not part of augur filter and may be ignored:")
+            print(' '.join(unused_args))
+
+    def calculate_required_priorities(self):
+        """
+        If computation of this sample requires priority information of another sample
+        (the "focus"), then this function will compute those priorities by calling
+        a method on the focal sample object.
+        """
+        if not self.priorities:
+            return
+        focal_sample = self.priorities.get('sample', None)
+        if not focal_sample:
+            raise AugurException(f"Cannot calculate priorities needed for {self.name} as the {self.get_priority_focus_name()} sample wasn't linked")
+        print(f"Calculating priorities of {focal_sample.name}, as required by {self.name}")
+        priorities_file = focal_sample.calculate_priorities()
+        print(f"\tSetting {self.name} filter priority file to {priorities_file}")
+        self.filter_args.priority = priorities_file
+
+    def calculate_priorities(self):
+        """
+        Calculate the priorities TSV file for samples in the alignment vs this sample
+
+        Returns the filename of the priorities file (TSV)
+        """
+
+        proximity_output_file = path.join(self.tmp_dir, f"proximity_{self.name}.tsv")
+        if self.use_existing_outputs and check_outputs_exist(proximity_output_file):
+            print(f"Using existing proximity scores for {self.name}")
+        else:
+            print(f"Calculating proximity of {self.name}")
+            get_distance_to_focal_set(
+                self.alignment,
+                self.reference,
+                self.filter_args.output,
+                proximity_output_file,
+                ignore_seqs=["Wuhan/Hu-1/2019"] # TODO - use the config to define this?
+            )
+
+        priorities_path = path.join(self.tmp_dir, f"priorities_{self.name}.tsv")
+        if self.use_existing_outputs and check_outputs_exist(priorities_path):
+            print(f"Using existing priorities for {self.name}")
+        else:
+            print(f"Calculating priorities of {self.name}")
+            create_priorities(
+                self.alignment_index,
+                proximity_output_file,
+                priorities_path
+            )
+        return priorities_path
+
+    def get_priority_focus_name(self):
+        if not self.priorities:
+            return None
+        return self.priorities['focus']
+
+    def set_priority_sample(self, sample):
+        if not self.priorities:
+            raise AugurException(f"No priorities set for {self.name}")
+        self.priorities['sample'] = sample
+
+    def filter(self):
+        print("\n---------------------------------\nCONSTRUCTING SAMPLE FOR", self.name, "\n---------------------------------")
+        self.calculate_required_priorities()
+        if self.use_existing_outputs and check_outputs_exist(self.filter_args.output_metadata, self.filter_args.output_strains, self.filter_args.output_log):
+            print(f"Using existing filtering results for {self.name}")
+        else:
+            print("Calling augur filter")
+            print("Filter arguments:")
+            for k,v in self.filter_args.__dict__.items():
+                if v is not None:
+                    print(f"\t{k: <30}{v}")
+            augur_filter(self.filter_args)
+
+        # In the future, instead of `augur_filter` saving data to disk, it would return
+        # data to the calling process. In lieu of that, we read the data just written.
+        try:
+            self.sampled_strains = set(pd.read_csv(self.filter_args.output_strains, header=None)[0])
+        except pd.errors.EmptyDataError:
+            self.sampled_strains = set()
+        self.filter_log = pd.read_csv(
+            self.filter_args.output_log,
+            header=0,
+            sep="\t",
+            index_col="strain"
+        )
+
+
+def make_graph(samples):
+    """
+    Given the list of samples, construct a graph describing the order in which they must be
+    processed, such that a sample which uses priorities is only created after its focal sample.
+    This is a DAG, however an extremely simple one which we can construct ourselves rather than relying on
+    extra libraries.
+    Constraints:
+     * Each sample can only use priorities of one other sample
+     * Acyclic
+    Structure:
+     tuple: (sample, list of descendant entries) where a "descendant" sample requires the linked sample to be
+     created prior to its creation. Each entry in the list has this tuple structure.
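+    Example (illustrative): for a scheme with a sample "focal" (no priorities) and a sample
+    "context" whose priorities name "focal" as its focus, the graph is
+        (None, [(<Sample "focal">, [(<Sample "context">, [])])])
+    i.e. "focal" is filtered first, then "context" using priorities computed against "focal".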
+    """
+
+    included = set() # set of sample names added to graph
+    graph = (None, [])
+
+    # add all the samples which don't require priorities to the graph
+    for sample in samples:
+        if not sample.get_priority_focus_name():
+            graph[1].append((sample, []))
+            included.add(sample.name)
+
+    def add_descendants(level):
+        parent_sample = level[0]
+        descendants = level[1]
+        for sample in samples:
+            if sample.name in included:
+                continue
+            if sample.get_priority_focus_name() == parent_sample.name:
+                sample.set_priority_sample(parent_sample)
+                descendants.append((sample, []))
+                included.add(sample.name)
+        for inner_level in descendants:
+            add_descendants(inner_level)
+
+    for level in graph[1]:
+        add_descendants(level)
+
+    # from pprint import pprint
+    # print("\ngraph"); pprint(graph);print("\n")
+
+    if len(samples)!=len(included):
+        raise AugurException("Incomplete graph construction")
+
+    return graph
+
+def traverse_graph(level, callback):
+    this_sample, descendants = level
+    if this_sample:
+        callback(this_sample)
+    for child in descendants:
+        traverse_graph(child, callback)
+
+def generate_sequence_index(args):
+    if args.alignment_index:
+        print("Skipping sequence index creation as an index was provided")
+        return
+    print("Creating ephemeral sequence index file")
+    with NamedTemporaryFile(delete=False) as sequence_index_file:
+        sequence_index_path = sequence_index_file.name
+    index_sequences(args.alignment, sequence_index_path)
+    args.alignment_index = sequence_index_path
+
+
+def combine_samples(args, samples):
+    """Collect the union of strains which are included in each sample and write them to disk.
+
+    Parameters
+    ----------
+    args : argparse.Namespace
+        Parsed arguments from argparse
+    samples : list[Sample]
+        list of samples
+    """
+    print("\n\n")
+    ### Form a union of each sample set, which is the subsampled strains list
+    sampled_strains = set()
+    for sample in samples:
+        print(f"Sample \"{sample.name}\" included {len(sample.sampled_strains)} strains")
+        sampled_strains.update(sample.sampled_strains)
+    print(f"In total, {len(sampled_strains)} strains are included in the resulting subsampled dataset")
+
+    ## Iterate through the input sequences, streaming a subsampled version to disk.
+    sequences = read_sequences(args.alignment)
+    sequences_written_to_disk = 0
+    with open_file(args.output_fasta, "wt") as output_handle:
+        for sequence in sequences:
+            if sequence.id in sampled_strains:
+                sequences_written_to_disk += 1
+                write_sequences(sequence, output_handle, 'fasta')
+    print(f"{sequences_written_to_disk} sequences written to {args.output_fasta}")
+
+    ## Iterate through the metadata in chunks, writing out those entries which are in the subsample
+    metadata_reader = read_metadata(
+        args.metadata,
+        id_columns=["strain", "name"], # TODO - this should be an argument
+        chunk_size=10000               # TODO - argument
+    )
+    metadata_header = True
+    metadata_mode = "w"
+    metadata_written_to_disk = 0
+    for metadata in metadata_reader:
+        df = metadata.loc[metadata.index.intersection(sampled_strains)]
+        df.to_csv(
+            args.output_metadata,
+            sep="\t",
+            header=metadata_header,
+            mode=metadata_mode,
+        )
+        metadata_written_to_disk += df.shape[0]
+        metadata_header = False
+        metadata_mode = "a"
+    print(f"{metadata_written_to_disk} metadata entries written to {args.output_metadata}")
+
+    ## Combine the log files (from augur filter) for each sample into a larger log file
+    ## Format TBD
+    ## TODO
+
+def check_outputs_exist(*paths):
+    for p in paths:
+        if not (path.exists(p) and path.isfile(p)):
+            return False
+    return True
+
+if __name__ == "__main__":
+    # the format of this block is designed specifically for future transfer of this script
+    # into augur in the form of `augur subsample`
+    parser = ArgumentParser(
+        description=DESCRIPTION,
+        formatter_class=ArgumentDefaultsHelpFormatter,
+    )
+    register_arguments(parser)
+    args = parser.parse_args()
+    run(args)
\ No newline at end of file
diff --git a/scripts/subsample_schema.yaml b/scripts/subsample_schema.yaml
new file mode 100644
index 000000000..2c4b6fbb0
--- /dev/null
+++ b/scripts/subsample_schema.yaml
@@ -0,0 +1,45 @@
+
+type: object
+title: YAML Schema for subsampling configuration to be consumed by a subsample script / command
+patternProperties:
+  "^[a-zA-Z0-9*_-]+$":
+    type: object
+    title: description of a sample
+    additionalProperties: false
+    properties:
+      group-by:
+        type: array
+        minItems: 1
+        items:
+          type: string
+      sequences-per-group:
+        type: integer
+      subsample-max-sequences:
+        type: integer
+      exclude-ambiguous-dates-by:
+        type: string
+        enum: ["any", "day", "month", "year"]
+      min-date:
+        type: ["number", "string"]
+        pattern: ^\d{4}-\d{2}-\d{2}$
+      max-date:
+        type: ["number", "string"]
+        pattern: ^\d{4}-\d{2}-\d{2}$
+      exclude-where:
+        type: array
+        minItems: 1
+        items:
+          type: string
+      include-where:
+        type: array
+        minItems: 1
+        items:
+          type: string
+      query:
+        type: string
+      probabilistic-sampling:
+        type: boolean
+      no-probabilistic-sampling:
+        type: boolean
+      priorities:
+        type: object
diff --git a/workflow/snakemake_rules/main_workflow.smk b/workflow/snakemake_rules/main_workflow.smk
index d2f808e2b..71f8375b2 100644
--- a/workflow/snakemake_rules/main_workflow.smk
+++ b/workflow/snakemake_rules/main_workflow.smk
@@ -250,74 +250,6 @@ def _get_subsampling_settings(wildcards):

     return subsampling_settings

-def get_priorities(wildcards):
-    subsampling_settings = _get_subsampling_settings(wildcards)
-
-    if "priorities" in subsampling_settings and subsampling_settings["priorities"]["type"] == "proximity":
-        return f"results/{wildcards.build_name}/priorities_{subsampling_settings['priorities']['focus']}.tsv"
-    else:
-        # TODO: find a way to make the list of input files depend on config
-        return config["files"]["include"]
-
-
-def get_priority_argument(wildcards):
-    subsampling_settings = _get_subsampling_settings(wildcards)
-    if "priorities" not in subsampling_settings:
-        return ""
-
-    if subsampling_settings["priorities"]["type"] == "proximity":
-        return "--priority " + get_priorities(wildcards)
-    elif subsampling_settings["priorities"]["type"] == "file" and "file" in subsampling_settings["priorities"]:
-        return "--priority " + subsampling_settings["priorities"]["file"]
-    else:
-        return ""
-
-
-def _get_specific_subsampling_setting(setting, optional=False):
-    # Note -- this function contains a lot of conditional logic because
-    # we have the situation where some config options must define the
-    # augur argument in their value, and some must not. For instance:
-    # subsamplingScheme -> sampleName -> group_by: year (`--group-by` is _not_ part of this value)
-    #                                 -> exclude: "--exclude-where 'country=USA'" (`--exclude-where` IS part of this value)
-    # Since there are a lot of subsampling schemes out there, backwards compatability
-    # is important! james hadfield, feb 2021
-    def _get_setting(wildcards):
-        if optional:
-            value = _get_subsampling_settings(wildcards).get(setting, "")
-        else:
-            value = _get_subsampling_settings(wildcards)[setting]
-
-        if isinstance(value, str):
-            # Load build attributes including geographic details about the
-            # build's region, country, division, etc. as needed for subsampling.
-            build = config["builds"][wildcards.build_name]
-            value = value.format(**build)
-            if value !="":
-                if setting == 'exclude_ambiguous_dates_by':
-                    value = f"--exclude-ambiguous-dates-by {value}"
-                elif setting == 'group_by':
-                    value = f"--group-by {value}"
-        elif value is not None:
-            # If is 'seq_per_group' or 'max_sequences' build subsampling setting,
-            # need to return the 'argument' for augur
-            if setting == 'seq_per_group':
-                value = f"--sequences-per-group {value}"
-            elif setting == 'max_sequences':
-                value = f"--subsample-max-sequences {value}"
-
-            return value
-        else:
-            value = ""
-
-        # Check format strings that haven't been resolved.
-        if re.search(r'\{.+\}', value):
-            raise Exception(f"The parameters for the subsampling scheme '{wildcards.subsample}' of build '{wildcards.build_name}' reference build attributes that are not defined in the configuration file: '{value}'. Add these build attributes to the appropriate configuration file and try again.")
-
-        return value
-
-    return _get_setting
-
-
 rule combine_sequences_for_subsampling:
     # Similar to rule combine_input_metadata, this rule should only be run if multiple inputs are being used (i.e. multiple origins)
     message:
@@ -365,162 +297,124 @@ rule index_sequences:
         --output {output.sequence_index} 2>&1 | tee {log}
         """

-rule subsample:
+rule extract_subsampling_scheme:
     message:
         """
-        Subsample all sequences by '{wildcards.subsample}' scheme for build '{wildcards.build_name}' with the following parameters:
-
-         - group by: {params.group_by}
-         - sequences per group: {params.sequences_per_group}
-         - subsample max sequences: {params.subsample_max_sequences}
-         - min-date: {params.min_date}
-         - max-date: {params.max_date}
-         - {params.exclude_ambiguous_dates_argument}
-         - exclude: {params.exclude_argument}
-         - include: {params.include_argument}
-         - query: {params.query_argument}
-         - priority: {params.priority_argument}
+        Extracting subsampling scheme for build "{wildcards.build_name}" into its own YAML file
         """
-    input:
-        sequences = _get_unified_alignment,
-        metadata = _get_unified_metadata,
-        sequence_index = rules.index_sequences.output.sequence_index,
-        include = config["files"]["include"],
-        priorities = get_priorities,
-        exclude = config["files"]["exclude"]
     output:
-        sequences = "results/{build_name}/sample-{subsample}.fasta",
-        strains="results/{build_name}/sample-{subsample}.txt",
-    log:
-        "logs/subsample_{build_name}_{subsample}.txt"
-    benchmark:
-        "benchmarks/subsample_{build_name}_{subsample}.txt"
-    params:
-        group_by = _get_specific_subsampling_setting("group_by", optional=True),
-        sequences_per_group = _get_specific_subsampling_setting("seq_per_group", optional=True),
-        subsample_max_sequences = _get_specific_subsampling_setting("max_sequences", optional=True),
-        sampling_scheme = _get_specific_subsampling_setting("sampling_scheme", optional=True),
-        exclude_argument = _get_specific_subsampling_setting("exclude", optional=True),
-        include_argument = _get_specific_subsampling_setting("include", optional=True),
-        query_argument = _get_specific_subsampling_setting("query", optional=True),
-        exclude_ambiguous_dates_argument = _get_specific_subsampling_setting("exclude_ambiguous_dates_by", optional=True),
-        min_date = _get_specific_subsampling_setting("min_date", optional=True),
-        max_date = _get_specific_subsampling_setting("max_date", optional=True),
-        priority_argument = get_priority_argument
-    resources:
-        # Memory use scales primarily with the size of the metadata file.
-        mem_mb=12000
-    conda: config["conda_environment"]
-    shell:
-        """
-        augur filter \
-            --sequences {input.sequences} \
-            --metadata {input.metadata} \
-            --sequence-index {input.sequence_index} \
-            --include {input.include} \
-            --exclude {input.exclude} \
-            {params.min_date} \
-            {params.max_date} \
-            {params.exclude_argument} \
-            {params.include_argument} \
-            {params.query_argument} \
-            {params.exclude_ambiguous_dates_argument} \
-            {params.priority_argument} \
-            {params.group_by} \
-            {params.sequences_per_group} \
-            {params.subsample_max_sequences} \
-            {params.sampling_scheme} \
-            --output {output.sequences} \
-            --output-strains {output.strains} 2>&1 | tee {log}
-        """
-
-rule proximity_score:
-    message:
-        """
-        determine priority for inclusion in as phylogenetic context by
-        genetic similiarity to sequences in focal set for build '{wildcards.build_name}'.
-        """
-    input:
-        alignment = _get_unified_alignment,
-        reference = config["files"]["alignment_reference"],
-        focal_alignment = "results/{build_name}/sample-{focus}.fasta"
-    output:
-        proximities = "results/{build_name}/proximity_{focus}.tsv"
-    log:
-        "logs/subsampling_proximity_{build_name}_{focus}.txt"
-    benchmark:
-        "benchmarks/proximity_score_{build_name}_{focus}.txt"
-    params:
-        chunk_size=10000,
-        ignore_seqs = config['refine']['root']
-    resources:
-        # Memory scales at ~0.15 MB * chunk_size (e.g., 0.15 MB * 10000 = 1.5GB).
-        mem_mb=4000
-    conda: config["conda_environment"]
-    shell:
-        """
-        python3 scripts/get_distance_to_focal_set.py \
-            --reference {input.reference} \
-            --alignment {input.alignment} \
-            --focal-alignment {input.focal_alignment} \
-            --ignore-seqs {params.ignore_seqs} \
-            --chunk-size {params.chunk_size} \
-            --output {output.proximities} 2>&1 | tee {log}
-        """
-
-rule priority_score:
-    input:
-        proximity = rules.proximity_score.output.proximities,
-        sequence_index = rules.index_sequences.output.sequence_index,
-    output:
-        priorities = "results/{build_name}/priorities_{focus}.tsv"
-    benchmark:
-        "benchmarks/priority_score_{build_name}_{focus}.txt"
-    conda: config["conda_environment"]
-    shell:
-        """
-        python3 scripts/priorities.py \
-            --sequence-index {input.sequence_index} \
-            --proximities {input.proximity} \
-            --output {output.priorities} 2>&1 | tee {log}
-        """
-
-
-def _get_subsampled_files(wildcards):
-    subsampling_settings = _get_subsampling_settings(wildcards)
-
-    return [
-        f"results/{wildcards.build_name}/sample-{subsample}.txt"
-        for subsample in subsampling_settings
-    ]
+        scheme="results/{build_name}/subsampling_scheme.yaml",
+    run:
+        # Note that the syntax is slightly different between the YAML which `./scripts/subsample.py`
+        # expects (this script is a precursor for `augur subsample`) and the way in which we write
+        # subsampling definitions in our `builds.yaml` file.
+        # (1) wildcards must be filled in
+        # (2) we use `augur filter`-like syntax, e.g. "sequences-per-group" not "seq_per_group"
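+        # For example (illustrative), a builds.yaml sample such as
+        #     country:
+        #       group_by: "division year month"
+        #       seq_per_group: 20
+        #       exclude: "--exclude-where 'country!={country}'"
+        # is rewritten (for a build where country=Spain) to
+        #     country:
+        #       group-by: ["division", "year", "month"]
+        #       sequences-per-group: 20
+        #       exclude-where: ["country!=Spain"]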
+        import yaml
+        import json
+        import shlex
+        scheme = _get_subsampling_settings(wildcards)
+
+        # fill in templates, e.g `{country}`, with build-specific data
+        # (prior art existed in the function _get_specific_subsampling_setting)
+        build = config["builds"][wildcards.build_name]
+        for sample_description in scheme.values():
+            for key,value in sample_description.items():
+                if isinstance(value, str):
+                    sample_description[key]=value.format(**build)
+
+        # we allow "no_subsampling" via a boolean in _each_ sample definition
+        # and pass each sample through the subsampling process, as we used to
+        # (this should be optimised via snakemake logic)
+        for sample_name, sample_dict in scheme.items():
+            if sample_dict.get("no_subsampling", False) == True:
+                scheme[sample_name] = {}
+
+        # map ncov-specific syntax into general subsampling syntax
+        # (prior art existed in the function _get_specific_subsampling_setting)
+        for sample in scheme.values():
+            if "group_by" in sample:
+                sample["group-by"] = shlex.split(sample['group_by'])
+                del sample["group_by"]
+            if "seq_per_group" in sample:
+                sample["sequences-per-group"] = sample['seq_per_group']
+                del sample["seq_per_group"]
+            if "max_sequences" in sample:
+                sample["subsample-max-sequences"] = sample['max_sequences']
+                del sample["max_sequences"]
+            if "exclude_ambiguous_dates_by" in sample:
+                sample["exclude-ambiguous-dates-by"] = sample["exclude_ambiguous_dates_by"]
+                del sample["exclude_ambiguous_dates_by"]
+
+            # Certain settings required that the argument was defined in the value,
+            # e.g. max_date: "--max-date 2020-01-02", which we remove here
+            if "query" in sample: # ncov syntax included the `--query` argument
+                sample["query"] = sample['query'].lstrip("--query ").strip('"').strip("'")
+            if "min_date" in sample: # ncov syntax included `--min-date` in the value
+                sample["min-date"] = sample["min_date"].lstrip("--min-date ")
+                del sample["min_date"]
+            if "max_date" in sample: # ncov syntax included `--max-date` in the value
+                sample["max-date"] = sample["max_date"].lstrip("--max-date ")
+                del sample["max_date"]
+
+            # the include/exclude ncov subsampling settings meant {include,exclude}-where
+            # and also included the argument in their value, which we remove here
+            if "exclude" in sample:
+                sample["exclude-where"] = shlex.split(sample['exclude'])[1:]
+                del sample["exclude"]
+            if "include" in sample:
+                sample["include-where"] = shlex.split(sample['include'])[1:]
+                del sample["include"]
+
+            if "sampling_scheme" in sample: # ncov syntax for expressing additional filter arguments
+                if sample["sampling_scheme"]=="--probabilistic-sampling":
+                    sample["probabilistic-sampling"] = True
+                elif sample["sampling_scheme"]=="--no-probabilistic-sampling":
+                    sample["no-probabilistic-sampling"] = True
+                del sample["sampling_scheme"]
+
+        with open(output.scheme, 'w', encoding='utf-8') as fh:
+            # to avoid a YAML with type annotations, we roundtrip through JSON
+            yaml.dump(json.loads(json.dumps(scheme)), fh, default_flow_style=False)

-rule combine_samples:
+rule subsample:
     message:
         """
-        Combine and deduplicate FASTAs
+        Subsampling sequences and metadata using scripts/subsample.py
         """
     input:
-        sequences=_get_unified_alignment,
-        sequence_index=rules.index_sequences.output.sequence_index,
+        scheme="results/{build_name}/subsampling_scheme.yaml",
         metadata=_get_unified_metadata,
-        include=_get_subsampled_files,
+        alignment=_get_unified_alignment,
+        alignment_index=rules.index_sequences.output.sequence_index,
+        reference = config["files"]["alignment_reference"],
+        include = config["files"]["include"],
+        exclude = config["files"]["exclude"]
     output:
+        intermediates=directory("results/{build_name}/subsamples/"),
         sequences = "results/{build_name}/{build_name}_subsampled_sequences.fasta.xz",
-        metadata = "results/{build_name}/{build_name}_subsampled_metadata.tsv.xz"
+        metadata = "results/{build_name}/{build_name}_subsampled_metadata.tsv.xz",
+        # output_log = "results/{build_name}/subsamples/subsampled.tsv",
     log:
-        "logs/subsample_regions_{build_name}.txt"
+        "logs/subsample_{build_name}.txt"
     benchmark:
-        "benchmarks/subsample_regions_{build_name}.txt"
+        "benchmarks/subsample_{build_name}.txt"
     conda: config["conda_environment"]
+    threads: 4
+    resources:
+        mem_mb=12000
     shell:
         """
-        augur filter \
-            --sequences {input.sequences} \
-            --sequence-index {input.sequence_index} \
+        python3 scripts/subsample.py \
+            --scheme {input.scheme} \
            --metadata {input.metadata} \
-            --exclude-all \
-            --include {input.include} \
-            --output-sequences {output.sequences} \
+            --alignment {input.alignment} \
+            --alignment-index {input.alignment_index} \
+            --reference {input.reference} \
+            --include-strains-file {input.include} \
+            --exclude-strains-file {input.exclude} \
+            --output-dir {output.intermediates} \
+            --output-fasta {output.sequences} \
            --output-metadata {output.metadata} 2>&1 | tee {log}
         """

@@ -531,7 +425,7 @@ rule build_align:
         - gaps relative to reference are considered real
         """
     input:
-        sequences = rules.combine_samples.output.sequences,
+        sequences = rules.subsample.output.sequences,
         genemap = config["files"]["annotation"],
         reference = config["files"]["alignment_reference"]
     output: