diff --git a/khmer/kfile.py b/khmer/kfile.py
index a3a8170627..7d0ccd72de 100755
--- a/khmer/kfile.py
+++ b/khmer/kfile.py
@@ -35,6 +35,7 @@
 
 """File handling/checking utilities for command-line scripts."""
 
+import contextlib
 import os
 import sys
 import errno
@@ -246,3 +247,49 @@ def get_file_writer(file_handle, do_gzip, do_bzip):
         ofile = file_handle
 
     return ofile
+
+
+@contextlib.contextmanager
+def FileWriter(file_handle, do_gzip, do_bzip, *, steal_ownership=False):
+    """Yield a writer for file_handle; clean up when the with block exits.
+
+    Alternative to get_file_writer() that must be used in a with block.
+
+    get_file_writer() is difficult to use as a context manager: when called
+    with both do_gzip=False and do_bzip=False it returns the underlying file
+    handle itself. As a consequence, doing:
+
+        with get_file_writer(sys.stdout, False, False) as fh:
+            pass
+
+    ends up closing sys.stdout when the with block is exited. Using the
+    function without a context manager avoids that issue, but then leaks the
+    compression wrapper when either do_gzip or do_bzip is true.
+
+    FileWriter closes the compression wrapper (if any) on exit and leaves
+    file_handle open, unless steal_ownership is true, in which case
+    file_handle is closed as well -- even if the with block, or the
+    argument check, raises.
+
+    Raises ValueError (on entering the with block) if both do_gzip and
+    do_bzip are true.
+    """
+    try:
+        if do_gzip and do_bzip:
+            raise ValueError("Cannot specify both bzip and gzip compression!")
+
+        if do_gzip:
+            ofile = gzip.GzipFile(fileobj=file_handle, mode='w')
+        elif do_bzip:
+            ofile = bz2.open(file_handle, mode='w')
+        else:
+            # nullcontext hands back file_handle without closing it on exit.
+            ofile = contextlib.nullcontext(enter_result=file_handle)
+
+        with ofile as x:
+            yield x
+    finally:
+        # Runs on normal exit and on any exception, so a handle opened inline
+        # by the caller (steal_ownership=True) is never leaked.
+        if steal_ownership:
+            file_handle.close()
diff --git a/scripts/abundance-dist-single.py b/scripts/abundance-dist-single.py
index 7fcf580276..409c7e7827 100755
--- a/scripts/abundance-dist-single.py
+++ b/scripts/abundance-dist-single.py
@@ -218,6 +218,10 @@ def __do_abundance_dist__(read_parser):
 
     log_info('wrote to: {output}', output=args.output_histogram_filename)
 
+    # Ensure that the output files are properly written. Python 3.12 seems to
+    # be less forgiving here ..
+    hist_fp.close()
+
 
 if __name__ == '__main__':
     main()
diff --git a/scripts/abundance-dist.py b/scripts/abundance-dist.py
index af8b6c5e8e..7076f6e253 100755
--- a/scripts/abundance-dist.py
+++ b/scripts/abundance-dist.py
@@ -42,6 +42,7 @@
 
 Use '-h' for parameter help.
""" +import contextlib import sys import csv import khmer @@ -143,26 +144,28 @@ def main(): sys.exit(1) if args.output_histogram_filename in ('-', '/dev/stdout'): - countgraph_fp = sys.stdout + countgraph_ctx = contextlib.nullcontext(enter_result=sys.stdout) else: - countgraph_fp = open(args.output_histogram_filename, 'w') - countgraph_fp_csv = csv.writer(countgraph_fp) - # write headers: - countgraph_fp_csv.writerow(['abundance', 'count', 'cumulative', - 'cumulative_fraction']) + countgraph_ctx = open(args.output_histogram_filename, 'w') - sofar = 0 - for _, i in enumerate(abundances): - if i == 0 and not args.output_zero: - continue + with countgraph_ctx as countgraph_fp: + countgraph_fp_csv = csv.writer(countgraph_fp) + # write headers: + countgraph_fp_csv.writerow(['abundance', 'count', 'cumulative', + 'cumulative_fraction']) - sofar += i - frac = sofar / float(total) + sofar = 0 + for _, i in enumerate(abundances): + if i == 0 and not args.output_zero: + continue - countgraph_fp_csv.writerow([_, i, sofar, round(frac, 3)]) + sofar += i + frac = sofar / float(total) - if sofar == total: - break + countgraph_fp_csv.writerow([_, i, sofar, round(frac, 3)]) + + if sofar == total: + break if __name__ == '__main__': diff --git a/scripts/do-partition.py b/scripts/do-partition.py index 0027270b58..53552be442 100755 --- a/scripts/do-partition.py +++ b/scripts/do-partition.py @@ -168,8 +168,8 @@ def main(): # pylint: disable=too-many-locals,too-many-statements worker_q.put((nodegraph, _, start, end)) print('enqueued %d subset tasks' % n_subsets, file=sys.stderr) - open('%s.info' % args.graphbase, 'w').write('%d subsets total\n' - % (n_subsets)) + with open('%s.info' % args.graphbase, 'w') as info_fp: + info_fp.write('%d subsets total\n' % (n_subsets)) if n_subsets < args.threads: args.threads = n_subsets diff --git a/scripts/extract-long-sequences.py b/scripts/extract-long-sequences.py index 7526c4aedf..c84e5d5d68 100755 --- a/scripts/extract-long-sequences.py +++ 
b/scripts/extract-long-sequences.py @@ -52,7 +52,7 @@ import sys from khmer import __version__ from khmer.utils import write_record -from khmer.kfile import add_output_compression_type, get_file_writer +from khmer.kfile import add_output_compression_type, FileWriter from khmer.khmer_args import sanitize_help, KhmerArgumentParser @@ -81,12 +81,12 @@ def get_parser(): def main(): args = sanitize_help(get_parser()).parse_args() - outfp = get_file_writer(args.output, args.gzip, args.bzip) - for filename in args.input_filenames: - for record in screed.open(filename): - if len(record['sequence']) >= args.length: - write_record(record, outfp) - print('wrote to: ' + args.output.name, file=sys.stderr) + with FileWriter(args.output, args.gzip, args.bzip) as outfp: + for filename in args.input_filenames: + for record in screed.open(filename): + if len(record['sequence']) >= args.length: + write_record(record, outfp) + print('wrote to: ' + args.output.name, file=sys.stderr) if __name__ == '__main__': diff --git a/scripts/extract-paired-reads.py b/scripts/extract-paired-reads.py index 29d7cbe3cb..8335cf893d 100755 --- a/scripts/extract-paired-reads.py +++ b/scripts/extract-paired-reads.py @@ -44,6 +44,7 @@ Reads FASTQ and FASTA input, retains format for output. 
""" +from contextlib import nullcontext import sys import os.path import textwrap @@ -53,7 +54,7 @@ from khmer.khmer_args import sanitize_help, KhmerArgumentParser from khmer.khmer_args import FileType as khFileType from khmer.kfile import add_output_compression_type -from khmer.kfile import get_file_writer +from khmer.kfile import FileWriter from khmer.utils import broken_paired_reader, write_record, write_record_pair @@ -132,39 +133,40 @@ def main(): # OVERRIDE default output file locations with -p, -s if args.output_paired: - paired_fp = get_file_writer(args.output_paired, args.gzip, args.bzip) - out2 = paired_fp.name + paired_ctx = FileWriter(args.output_paired, args.gzip, args.bzip) + out2 = args.output_paired.name else: # Don't override, just open the default filename from above - paired_fp = get_file_writer(open(out2, 'wb'), args.gzip, args.bzip) + paired_ctx = FileWriter(open(out2, 'wb'), args.gzip, args.bzip, + steal_ownership=True) if args.output_single: - single_fp = get_file_writer(args.output_single, args.gzip, args.bzip) + single_ctx = FileWriter(args.output_single, args.gzip, args.bzip) out1 = args.output_single.name else: # Don't override, just open the default filename from above - single_fp = get_file_writer(open(out1, 'wb'), args.gzip, args.bzip) - - print('reading file "%s"' % infile, file=sys.stderr) - print('outputting interleaved pairs to "%s"' % out2, file=sys.stderr) - print('outputting orphans to "%s"' % out1, file=sys.stderr) - - n_pe = 0 - n_se = 0 - - reads = ReadParser(infile) - for index, is_pair, read1, read2 in broken_paired_reader(reads): - if index % 100000 == 0 and index > 0: - print('...', index, file=sys.stderr) - - if is_pair: - write_record_pair(read1, read2, paired_fp) - n_pe += 1 - else: - write_record(read1, single_fp) - n_se += 1 - - single_fp.close() - paired_fp.close() + single_ctx = FileWriter(open(out1, 'wb'), args.gzip, args.bzip, + steal_ownership=True) + + with paired_ctx as paired_fp, single_ctx as single_fp: + 
print('reading file "%s"' % infile, file=sys.stderr) + print('outputting interleaved pairs to "%s"' % out2, + file=sys.stderr) + print('outputting orphans to "%s"' % out1, file=sys.stderr) + + n_pe = 0 + n_se = 0 + + reads = ReadParser(infile) + for index, is_pair, read1, read2 in broken_paired_reader(reads): + if index % 100000 == 0 and index > 0: + print('...', index, file=sys.stderr) + + if is_pair: + write_record_pair(read1, read2, paired_fp) + n_pe += 1 + else: + write_record(read1, single_fp) + n_se += 1 if n_pe == 0: raise Exception("no paired reads!? check file formats...") diff --git a/scripts/extract-partitions.py b/scripts/extract-partitions.py index a1d25dcdd6..b196c100c4 100755 --- a/scripts/extract-partitions.py +++ b/scripts/extract-partitions.py @@ -301,10 +301,10 @@ def main(): args.max_size) if args.output_unassigned: - ofile = open('%s.unassigned.%s' % (args.prefix, suffix), 'wb') - unassigned_fp = get_file_writer(ofile, args.gzip, args.bzip) - extractor.process_unassigned(unassigned_fp) - unassigned_fp.close() + with open('%s.unassigned.%s' % (args.prefix, suffix), 'wb') as ofile: + unassigned_fp = get_file_writer(ofile, args.gzip, args.bzip) + extractor.process_unassigned(unassigned_fp) + unassigned_fp.close() else: extractor.process_unassigned() @@ -320,13 +320,21 @@ def main(): print('nothing to output; exiting!', file=sys.stderr) return + to_close = [] # open a bunch of output files for the different groups group_fps = {} for index in range(extractor.group_n): fname = '%s.group%04d.%s' % (args.prefix, index, suffix) - group_fp = get_file_writer(open(fname, 'wb'), args.gzip, + back_fp = open(fname, 'wb') + group_fp = get_file_writer(back_fp, args.gzip, args.bzip) group_fps[index] = group_fp + # It feels more natural to close the writer before closing the + # underlying file. fp.close() is theoretically idempotent, so it should + # be fine even though sometimes get_file_writer "steals" ownership of + # the underlying stream. 
+ to_close.append(group_fp) + to_close.append(back_fp) # write 'em all out! # refresh the generator @@ -351,6 +359,9 @@ def main(): args.prefix, suffix), file=sys.stderr) + for fp in to_close: + fp.close() + if __name__ == '__main__': main() diff --git a/scripts/fastq-to-fasta.py b/scripts/fastq-to-fasta.py index ef597d9c39..c41d1e7ab6 100755 --- a/scripts/fastq-to-fasta.py +++ b/scripts/fastq-to-fasta.py @@ -45,7 +45,7 @@ import sys import screed from khmer import __version__ -from khmer.kfile import (add_output_compression_type, get_file_writer, +from khmer.kfile import (add_output_compression_type, FileWriter, describe_file_handle) from khmer.utils import write_record from khmer.khmer_args import sanitize_help, KhmerArgumentParser @@ -74,21 +74,21 @@ def main(): args = sanitize_help(get_parser()).parse_args() print('fastq from ', args.input_sequence, file=sys.stderr) - outfp = get_file_writer(args.output, args.gzip, args.bzip) - n_count = 0 - for n, record in enumerate(screed.open(args.input_sequence)): - if n % 10000 == 0: - print('...', n, file=sys.stderr) + with FileWriter(args.output, args.gzip, args.bzip) as outfp: + n_count = 0 + for n, record in enumerate(screed.open(args.input_sequence)): + if n % 10000 == 0: + print('...', n, file=sys.stderr) - sequence = record['sequence'] + sequence = record['sequence'] - if 'N' in sequence: - if not args.n_keep: - n_count += 1 - continue + if 'N' in sequence: + if not args.n_keep: + n_count += 1 + continue - del record['quality'] - write_record(record, outfp) + del record['quality'] + write_record(record, outfp) print('\n' + 'lines from ' + args.input_sequence, file=sys.stderr) diff --git a/scripts/filter-abund-single.py b/scripts/filter-abund-single.py index 3edcef86ec..14e2944add 100755 --- a/scripts/filter-abund-single.py +++ b/scripts/filter-abund-single.py @@ -60,7 +60,7 @@ from khmer.kfile import (check_input_files, check_space, check_space_for_graph, add_output_compression_type, - get_file_writer) + 
FileWriter) from khmer.khmer_logger import (configure_logging, log_info, log_error, log_warn) from khmer.trimming import (trim_record) @@ -160,22 +160,23 @@ def main(): outfile = os.path.basename(args.datafile) + '.abundfilt' else: outfile = args.outfile - outfp = open(outfile, 'wb') - outfp = get_file_writer(outfp, args.gzip, args.bzip) - - paired_iter = broken_paired_reader(ReadParser(args.datafile), - min_length=graph.ksize(), - force_single=True) - - for n, is_pair, read1, read2 in paired_iter: - assert not is_pair - assert read2 is None - - trimmed_record, _ = trim_record(graph, read1, args.cutoff, - args.variable_coverage, - args.normalize_to) - if trimmed_record: - write_record(trimmed_record, outfp) + + with FileWriter(open(outfile, 'wb'), args.gzip, args.bzip, + steal_ownership=True) as outfp: + + paired_iter = broken_paired_reader(ReadParser(args.datafile), + min_length=graph.ksize(), + force_single=True) + + for n, is_pair, read1, read2 in paired_iter: + assert not is_pair + assert read2 is None + + trimmed_record, _ = trim_record(graph, read1, args.cutoff, + args.variable_coverage, + args.normalize_to) + if trimmed_record: + write_record(trimmed_record, outfp) log_info('output in {outfile}', outfile=outfile) diff --git a/scripts/filter-abund.py b/scripts/filter-abund.py index cb729c9b77..1d32455edb 100755 --- a/scripts/filter-abund.py +++ b/scripts/filter-abund.py @@ -44,6 +44,7 @@ Use '-h' for parameter help. 
""" +from contextlib import nullcontext import sys import os import textwrap @@ -56,7 +57,7 @@ sanitize_help, check_argument_range) from khmer.khmer_args import FileType as khFileType from khmer.kfile import (check_input_files, check_space, - add_output_compression_type, get_file_writer) + add_output_compression_type, FileWriter) from khmer.khmer_logger import (configure_logging, log_info, log_error, log_warn) from khmer.trimming import (trim_record) @@ -137,31 +138,38 @@ def main(): if args.single_output_file: outfile = args.single_output_file.name - outfp = get_file_writer(args.single_output_file, args.gzip, args.bzip) - - # the filtering loop - for infile in infiles: - log_info('filtering {infile}', infile=infile) - if not args.single_output_file: - outfile = os.path.basename(infile) + '.abundfilt' - outfp = open(outfile, 'wb') - outfp = get_file_writer(outfp, args.gzip, args.bzip) - - paired_iter = broken_paired_reader(ReadParser(infile), - min_length=ksize, - force_single=True) - - for n, is_pair, read1, read2 in paired_iter: - assert not is_pair - assert read2 is None - - trimmed_record, _ = trim_record(countgraph, read1, args.cutoff, - args.variable_coverage, - args.normalize_to) - if trimmed_record: - write_record(trimmed_record, outfp) - - log_info('output in {outfile}', outfile=outfile) + out_single_ctx = FileWriter(args.single_output_file, args.gzip, + args.bzip) + else: + out_single_ctx = nullcontext() + + with out_single_ctx as out_single_fp: + # the filtering loop + for infile in infiles: + log_info('filtering {infile}', infile=infile) + if not args.single_output_file: + outfile = os.path.basename(infile) + '.abundfilt' + out_ctx = FileWriter(open(outfile, 'wb'), args.gzip, + args.bzip, steal_ownership=True) + else: + out_ctx = nullcontext(enter_result=out_single_fp) + + paired_iter = broken_paired_reader(ReadParser(infile), + min_length=ksize, + force_single=True) + + with out_ctx as outfp: + for n, is_pair, read1, read2 in paired_iter: + assert not 
is_pair + assert read2 is None + + trimmed_record, _ = trim_record(countgraph, read1, args.cutoff, + args.variable_coverage, + args.normalize_to) + if trimmed_record: + write_record(trimmed_record, outfp) + + log_info('output in {outfile}', outfile=outfile) if __name__ == '__main__': diff --git a/scripts/filter-stoptags.py b/scripts/filter-stoptags.py index a2bc48f170..b97910eafb 100755 --- a/scripts/filter-stoptags.py +++ b/scripts/filter-stoptags.py @@ -108,10 +108,9 @@ def process_fn(record): print('filtering', infile, file=sys.stderr) outfile = os.path.basename(infile) + '.stopfilt' - outfp = open(outfile, 'w') - - tsp = ThreadedSequenceProcessor(process_fn) - tsp.start(verbose_loader(infile), outfp) + with open(outfile, 'w') as outfp: + tsp = ThreadedSequenceProcessor(process_fn) + tsp.start(verbose_loader(infile), outfp) print('output in', outfile, file=sys.stderr) diff --git a/scripts/interleave-reads.py b/scripts/interleave-reads.py index 65c557d5f2..4c02591157 100755 --- a/scripts/interleave-reads.py +++ b/scripts/interleave-reads.py @@ -52,7 +52,7 @@ from khmer.kfile import check_input_files, check_space from khmer.khmer_args import sanitize_help, KhmerArgumentParser from khmer.khmer_args import FileType as khFileType -from khmer.kfile import (add_output_compression_type, get_file_writer, +from khmer.kfile import (add_output_compression_type, FileWriter, describe_file_handle) from khmer.utils import (write_record_pair, check_is_left, check_is_right, check_is_pair) @@ -109,42 +109,41 @@ def main(): print("Interleaving:\n\t%s\n\t%s" % (s1_file, s2_file), file=sys.stderr) - outfp = get_file_writer(args.output, args.gzip, args.bzip) - - counter = 0 - screed_iter_1 = screed.open(s1_file) - screed_iter_2 = screed.open(s2_file) - for read1, read2 in zip_longest(screed_iter_1, screed_iter_2): - if read1 is None or read2 is None: - print(("ERROR: Input files contain different number" - " of records."), file=sys.stderr) - sys.exit(1) + with FileWriter(args.output, 
args.gzip, args.bzip) as outfp: + counter = 0 + screed_iter_1 = screed.open(s1_file) + screed_iter_2 = screed.open(s2_file) + for read1, read2 in zip_longest(screed_iter_1, screed_iter_2): + if read1 is None or read2 is None: + print(("ERROR: Input files contain different number" + " of records."), file=sys.stderr) + sys.exit(1) - if counter % 100000 == 0: - print('...', counter, 'pairs', file=sys.stderr) - counter += 1 + if counter % 100000 == 0: + print('...', counter, 'pairs', file=sys.stderr) + counter += 1 - name1 = read1.name - name2 = read2.name + name1 = read1.name + name2 = read2.name - if not args.no_reformat: - if not check_is_left(name1): - name1 += '/1' - if not check_is_right(name2): - name2 += '/2' + if not args.no_reformat: + if not check_is_left(name1): + name1 += '/1' + if not check_is_right(name2): + name2 += '/2' - read1.name = name1 - read2.name = name2 + read1.name = name1 + read2.name = name2 - if not check_is_pair(read1, read2): - print("ERROR: This doesn't look like paired data! " - "%s %s" % (read1.name, read2.name), file=sys.stderr) - sys.exit(1) + if not check_is_pair(read1, read2): + print("ERROR: This doesn't look like paired data! " + "%s %s" % (read1.name, read2.name), file=sys.stderr) + sys.exit(1) - write_record_pair(read1, read2, outfp) + write_record_pair(read1, read2, outfp) - print('final: interleaved %d pairs' % counter, file=sys.stderr) - print('output written to', describe_file_handle(outfp), file=sys.stderr) + print('final: interleaved %d pairs' % counter, file=sys.stderr) + print('output written to', describe_file_handle(outfp), file=sys.stderr) if __name__ == '__main__': diff --git a/scripts/normalize-by-median.py b/scripts/normalize-by-median.py index 39e387663e..ac3c84540f 100755 --- a/scripts/normalize-by-median.py +++ b/scripts/normalize-by-median.py @@ -46,6 +46,7 @@ Use '-h' for parameter help. 
""" +from contextlib import nullcontext import sys import screed import os @@ -60,7 +61,7 @@ import argparse from khmer.kfile import (check_space, check_space_for_graph, check_valid_file_exists, add_output_compression_type, - get_file_writer, describe_file_handle) + FileWriter, describe_file_handle) from khmer.utils import (write_record, broken_paired_reader, ReadBundle, clean_input_reads) from khmer.khmer_logger import (configure_logging, log_info, log_error) @@ -360,39 +361,43 @@ def main(): # pylint: disable=too-many-branches,too-many-statements output_name = None if args.single_output_file: - outfp = get_file_writer(args.single_output_file, args.gzip, args.bzip) + out_single_ctx = FileWriter(args.single_output_file, args.gzip, args.bzip) else: + out_single_ctx = nullcontext() if '-' in filenames or '/dev/stdin' in filenames: print("Accepting input from stdin; output filename must " "be provided with '-o'.", file=sys.stderr) sys.exit(1) - # - # main loop: iterate over all files given, do diginorm. - # - - for filename, require_paired in files: - if not args.single_output_file: - output_name = os.path.basename(filename) + '.keep' - outfp = open(output_name, 'wb') - outfp = get_file_writer(outfp, args.gzip, args.bzip) - - # failsafe context manager in case an input file breaks - with catch_io_errors(filename, outfp, args.single_output_file, - args.force, corrupt_files): - screed_iter = clean_input_reads(screed.open(filename)) - reader = broken_paired_reader(screed_iter, min_length=args.ksize, - force_single=force_single, - require_paired=require_paired) - - # actually do diginorm - for record in with_diagnostics(reader, filename): - if record is not None: - write_record(record, outfp) - - log_info('output in {name}', name=describe_file_handle(outfp)) + with out_single_ctx as out_single_fp: + # + # main loop: iterate over all files given, do diginorm. 
+ # + for filename, require_paired in files: if not args.single_output_file: - outfp.close() + output_name = os.path.basename(filename) + '.keep' + out_ctx = FileWriter(open(output_name, 'wb'), args.gzip, + args.bzip, steal_ownership=True) + else: + out_ctx = nullcontext(enter_result=out_single_fp) + + with out_ctx as outfp: + # failsafe context manager in case an input file breaks + with catch_io_errors(filename, outfp, args.single_output_file, + args.force, corrupt_files): + screed_iter = clean_input_reads(screed.open(filename)) + reader = broken_paired_reader(screed_iter, + min_length=args.ksize, + force_single=force_single, + require_paired=require_paired) + + # actually do diginorm + for record in with_diagnostics(reader, filename): + if record is not None: + write_record(record, outfp) + + log_info('output in {name}', + name=describe_file_handle(outfp)) # finished - print out some diagnostics. diff --git a/scripts/partition-graph.py b/scripts/partition-graph.py index f841fe4848..82e2bf705e 100755 --- a/scripts/partition-graph.py +++ b/scripts/partition-graph.py @@ -143,7 +143,8 @@ def main(): worker_q.put((nodegraph, _, start, end)) print('enqueued %d subset tasks' % n_subsets, file=sys.stderr) - open('%s.info' % basename, 'w').write('%d subsets total\n' % (n_subsets)) + with open('%s.info' % basename, 'w') as info_fp: + info_fp.write('%d subsets total\n' % (n_subsets)) n_threads = args.threads if n_subsets < n_threads: diff --git a/scripts/sample-reads-randomly.py b/scripts/sample-reads-randomly.py index 79b02d764e..c9ded92600 100755 --- a/scripts/sample-reads-randomly.py +++ b/scripts/sample-reads-randomly.py @@ -47,6 +47,7 @@ """ import argparse +from contextlib import nullcontext import os.path import random import textwrap @@ -55,7 +56,7 @@ from khmer import __version__ from khmer import ReadParser from khmer.kfile import (check_input_files, add_output_compression_type, - get_file_writer) + FileWriter) from khmer.khmer_args import sanitize_help, 
KhmerArgumentParser from khmer.utils import write_record, broken_paired_reader @@ -201,27 +202,27 @@ def main(): print('Writing %d sequences to %s' % (len(reads[0]), output_filename), file=sys.stderr) - output_file = args.output_file - if not output_file: - output_file = open(output_filename, 'wb') + output_back_ctx = nullcontext(args.output_file) + if not args.output_file: + output_back_ctx = open(output_filename, 'wb') - output_file = get_file_writer(output_file, args.gzip, args.bzip) - - for records in reads[0]: - write_record(records[0], output_file) - if records[1] is not None: - write_record(records[1], output_file) + with output_back_ctx as output_back_fp: + with FileWriter(output_back_fp, args.gzip, args.bzip) as output_fp: + for records in reads[0]: + write_record(records[0], output_fp) + if records[1] is not None: + write_record(records[1], output_fp) else: for n in range(num_samples): n_filename = output_filename + '.%d' % n print('Writing %d sequences to %s' % (len(reads[n]), n_filename), file=sys.stderr) - output_file = get_file_writer(open(n_filename, 'wb'), args.gzip, - args.bzip) - for records in reads[n]: - write_record(records[0], output_file) - if records[1] is not None: - write_record(records[1], output_file) + with FileWriter(open(n_filename, 'wb'), args.gzip, args.bzip, + steal_ownership=True) as output_fp: + for records in reads[n]: + write_record(records[0], output_fp) + if records[1] is not None: + write_record(records[1], output_fp) if __name__ == '__main__': diff --git a/scripts/split-paired-reads.py b/scripts/split-paired-reads.py index 5750100312..cd3dd06bc1 100755 --- a/scripts/split-paired-reads.py +++ b/scripts/split-paired-reads.py @@ -44,6 +44,7 @@ Reads FASTQ and FASTA input, retains format for output. 
""" +from contextlib import nullcontext import sys import os import textwrap @@ -56,7 +57,7 @@ UnpairedReadsError) from khmer.kfile import (check_input_files, check_space, add_output_compression_type, - get_file_writer, describe_file_handle) + FileWriter, describe_file_handle) def get_parser(): @@ -145,22 +146,26 @@ def main(): # OVERRIDE output file locations with -1, -2 if args.output_first: - fp_out1 = get_file_writer(args.output_first, args.gzip, args.bzip) - out1 = fp_out1.name + out1_ctx = FileWriter(args.output_first, args.gzip, args.bzip) + out1 = args.output_first.name else: # Use default filename created above - fp_out1 = get_file_writer(open(out1, 'wb'), args.gzip, args.bzip) + out1_ctx = FileWriter(open(out1, 'wb'), args.gzip, args.bzip, + steal_ownership=True) if args.output_second: - fp_out2 = get_file_writer(args.output_second, args.gzip, args.bzip) - out2 = fp_out2.name + out2_ctx = FileWriter(args.output_second, args.gzip, args.bzip) + out2 = args.output_second.name else: # Use default filename created above - fp_out2 = get_file_writer(open(out2, 'wb'), args.gzip, args.bzip) + out2_ctx = FileWriter(open(out2, 'wb'), args.gzip, args.bzip, + steal_ownership=True) # put orphaned reads here, if -0! 
if args.output_orphaned: - fp_out0 = get_file_writer(args.output_orphaned, args.gzip, args.bzip) + out0_ctx = FileWriter(args.output_orphaned, args.gzip, args.bzip) out0 = describe_file_handle(args.output_orphaned) + else: + out0_ctx = nullcontext() counter1 = 0 counter2 = 0 @@ -171,23 +176,24 @@ def main(): paired_iter = broken_paired_reader(ReadParser(infile), require_paired=not args.output_orphaned) - try: - for index, is_pair, record1, record2 in paired_iter: - if index % 10000 == 0: - print('...', index, file=sys.stderr) - - if is_pair: - write_record(record1, fp_out1) - counter1 += 1 - write_record(record2, fp_out2) - counter2 += 1 - elif args.output_orphaned: - write_record(record1, fp_out0) - counter3 += 1 - except UnpairedReadsError as e: - print("Unpaired reads found starting at {name}; exiting".format( - name=e.read1.name), file=sys.stderr) - sys.exit(1) + with out0_ctx as fp_out0, out1_ctx as fp_out1, out2_ctx as fp_out2: + try: + for index, is_pair, record1, record2 in paired_iter: + if index % 10000 == 0: + print('...', index, file=sys.stderr) + + if is_pair: + write_record(record1, fp_out1) + counter1 += 1 + write_record(record2, fp_out2) + counter2 += 1 + elif args.output_orphaned: + write_record(record1, fp_out0) + counter3 += 1 + except UnpairedReadsError as e: + print("Unpaired reads found starting at {name}; exiting".format( + name=e.read1.name), file=sys.stderr) + sys.exit(1) print("DONE; split %d sequences (%d left, %d right, %d orphans)" % (counter1 + counter2, counter1, counter2, counter3), file=sys.stderr) diff --git a/scripts/trim-low-abund.py b/scripts/trim-low-abund.py index 8572575111..48b1fa296b 100755 --- a/scripts/trim-low-abund.py +++ b/scripts/trim-low-abund.py @@ -43,6 +43,7 @@ Use -h for parameter help. 
""" +from contextlib import nullcontext import csv import sys import os @@ -63,7 +64,7 @@ from khmer.utils import write_record, broken_paired_reader, ReadBundle from khmer.kfile import (check_space, check_space_for_graph, check_valid_file_exists, add_output_compression_type, - get_file_writer) + get_file_writer, FileWriter) from khmer.khmer_logger import configure_logging, log_info, log_error from khmer.trimming import trim_record @@ -374,108 +375,111 @@ def main(): # only create the file writer once if outfp is specified; otherwise, # create it for each file. if args.output: - trimfp = get_file_writer(args.output, args.gzip, args.bzip) + trim_ctx = FileWriter(args.output, args.gzip, args.bzip) + else: + trim_ctx = nullcontext() pass2list = [] - for filename in args.input_filenames: - # figure out temporary filename for 2nd pass - pass2filename = filename.replace(os.path.sep, '-') + '.pass2' - pass2filename = os.path.join(tempdir, pass2filename) - pass2fp = open(pass2filename, 'w') - - # construct output filenames - if args.output is None: - # note: this will be saved in trimfp. - outfp = open(os.path.basename(filename) + '.abundtrim', 'wb') - - # get file handle w/gzip, bzip - trimfp = get_file_writer(outfp, args.gzip, args.bzip) - - # record all this info - pass2list.append((filename, pass2filename, trimfp)) - - # input file stuff: get a broken_paired reader. - paired_iter = broken_paired_reader(ReadParser(filename), min_length=K, - force_single=args.ignore_pairs) - - # main loop through the file. - n_start = trimmer.n_reads - save_start = trimmer.n_saved - - watermark = REPORT_EVERY_N_READS - for read in trimmer.pass1(paired_iter, pass2fp): - if (trimmer.n_reads - n_start) > watermark: - log_info("... 
{filename} {n_saved} {n_reads} {n_bp} " - "{w_reads} {w_bp}", filename=filename, - n_saved=trimmer.n_saved, n_reads=trimmer.n_reads, - n_bp=trimmer.n_bp, w_reads=written_reads, - w_bp=written_bp) - watermark += REPORT_EVERY_N_READS - - # write out the trimmed/etc sequences that AREN'T going to be - # revisited in a 2nd pass. - write_record(read, trimfp) - written_bp += len(read) - written_reads += 1 - pass2fp.close() - - log_info("{filename}: kept aside {kept} of {total} from first pass", - filename=filename, kept=trimmer.n_saved - save_start, - total=trimmer.n_reads - n_start) - - # first pass goes across all the data, so record relevant stats... - n_reads = trimmer.n_reads - n_bp = trimmer.n_bp - n_skipped = trimmer.n_skipped - bp_skipped = trimmer.bp_skipped - save_pass2_total = trimmer.n_saved - - # ### SECOND PASS. ### - - # nothing should have been skipped yet! - assert trimmer.n_skipped == 0 - assert trimmer.bp_skipped == 0 - - if args.single_pass: - pass2list = [] - - # go back through all the files again. - for _, pass2filename, trimfp in pass2list: - log_info('second pass: looking at sequences kept aside in {pass2}', - pass2=pass2filename) - - # note that for this second pass, we don't care about paired - # reads - they will be output in the same order they're read in, - # so pairs will stay together if not orphaned. This is in contrast - # to the first loop. Hence, force_single=True below. - - read_parser = ReadParser(pass2filename) - paired_iter = broken_paired_reader(read_parser, - min_length=K, - force_single=True) - - watermark = REPORT_EVERY_N_READS - for read in trimmer.pass2(paired_iter): - if (trimmer.n_reads - n_start) > watermark: - log_info('... 
x 2 {a} {b} {c} {d} {e} {f} {g}', - a=trimmer.n_reads - n_start, - b=pass2filename, c=trimmer.n_saved, - d=trimmer.n_reads, e=trimmer.n_bp, - f=written_reads, g=written_bp) - watermark += REPORT_EVERY_N_READS - - write_record(read, trimfp) - written_reads += 1 - written_bp += len(read) - - read_parser.close() - - log_info('removing {pass2}', pass2=pass2filename) - os.unlink(pass2filename) - - # if we created our own trimfps, close 'em. - if not args.output: - trimfp.close() + with trim_ctx as trimfp: + for filename in args.input_filenames: + # figure out temporary filename for 2nd pass + pass2filename = filename.replace(os.path.sep, '-') + '.pass2' + pass2filename = os.path.join(tempdir, pass2filename) + pass2fp = open(pass2filename, 'w') + + # construct output filenames + if args.output is None: + # note: this will be saved in trimfp. + outfp = open(os.path.basename(filename) + '.abundtrim', 'wb') + + # get file handle w/gzip, bzip + trimfp = get_file_writer(outfp, args.gzip, args.bzip) + + # record all this info + pass2list.append((filename, pass2filename, trimfp)) + + # input file stuff: get a broken_paired reader. + paired_iter = broken_paired_reader(ReadParser(filename), min_length=K, + force_single=args.ignore_pairs) + + # main loop through the file. + n_start = trimmer.n_reads + save_start = trimmer.n_saved + + watermark = REPORT_EVERY_N_READS + for read in trimmer.pass1(paired_iter, pass2fp): + if (trimmer.n_reads - n_start) > watermark: + log_info("... {filename} {n_saved} {n_reads} {n_bp} " + "{w_reads} {w_bp}", filename=filename, + n_saved=trimmer.n_saved, n_reads=trimmer.n_reads, + n_bp=trimmer.n_bp, w_reads=written_reads, + w_bp=written_bp) + watermark += REPORT_EVERY_N_READS + + # write out the trimmed/etc sequences that AREN'T going to be + # revisited in a 2nd pass. 
+ write_record(read, trimfp) + written_bp += len(read) + written_reads += 1 + pass2fp.close() + + log_info("{filename}: kept aside {kept} of {total} from first pass", + filename=filename, kept=trimmer.n_saved - save_start, + total=trimmer.n_reads - n_start) + + # first pass goes across all the data, so record relevant stats... + n_reads = trimmer.n_reads + n_bp = trimmer.n_bp + n_skipped = trimmer.n_skipped + bp_skipped = trimmer.bp_skipped + save_pass2_total = trimmer.n_saved + + # ### SECOND PASS. ### + + # nothing should have been skipped yet! + assert trimmer.n_skipped == 0 + assert trimmer.bp_skipped == 0 + + if args.single_pass: + pass2list = [] + + # go back through all the files again. + for _, pass2filename, trimfp in pass2list: + log_info('second pass: looking at sequences kept aside in {pass2}', + pass2=pass2filename) + + # note that for this second pass, we don't care about paired + # reads - they will be output in the same order they're read in, + # so pairs will stay together if not orphaned. This is in contrast + # to the first loop. Hence, force_single=True below. + + read_parser = ReadParser(pass2filename) + paired_iter = broken_paired_reader(read_parser, + min_length=K, + force_single=True) + + watermark = REPORT_EVERY_N_READS + for read in trimmer.pass2(paired_iter): + if (trimmer.n_reads - n_start) > watermark: + log_info('... x 2 {a} {b} {c} {d} {e} {f} {g}', + a=trimmer.n_reads - n_start, + b=pass2filename, c=trimmer.n_saved, + d=trimmer.n_reads, e=trimmer.n_bp, + f=written_reads, g=written_bp) + watermark += REPORT_EVERY_N_READS + + write_record(read, trimfp) + written_reads += 1 + written_bp += len(read) + + read_parser.close() + + log_info('removing {pass2}', pass2=pass2filename) + os.unlink(pass2filename) + + # if we created our own trimfps, close 'em. + if not args.output: + trimfp.close() try: log_info('removing temp directory & contents ({temp})', temp=tempdir)