Skip to content

Commit

Permalink
Ensure that Python scripts close files that they open for writing
Browse files Browse the repository at this point in the history
Python scripts under scripts/ in the source tree do not consistently
close files that they open for writing. While some of the scripts use
context managers, most of them do not (or do so inconsistently).  In
previous releases of Python, this apparently was not much of a concern.
However, Python 3.12 seems to be much less forgiving when files are not
properly closed. When running the test suite, many of the files that are
not explicitly closed appear truncated. This leads to various tests
failing or hanging.

Furthermore, khmer defines the get_file_writer() function, but it cannot
be consistently used as a context manager because it sometimes closes
the underlying file descriptor, and sometimes does not, depending on the
arguments.

Fixed by defining a new FileWriter context manager and ensuring that
each call to open() / get_file_writer() frees up resources properly.

Signed-off-by: Olivier Gayot <[email protected]>
  • Loading branch information
ogayot committed Nov 26, 2023
1 parent d71d576 commit 982832b
Show file tree
Hide file tree
Showing 17 changed files with 396 additions and 317 deletions.
35 changes: 35 additions & 0 deletions khmer/kfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"""File handling/checking utilities for command-line scripts."""


import contextlib
import os
import sys
import errno
Expand Down Expand Up @@ -246,3 +247,37 @@ def get_file_writer(file_handle, do_gzip, do_bzip):
ofile = file_handle

return ofile


@contextlib.contextmanager
def FileWriter(file_handle, do_gzip, do_bzip, *, steal_ownership=False):
    """Alternative to get_file_writer() that must be used in a with block.

    The intent is to address an inherent problem with get_file_writer() that
    makes it difficult to use as a context manager. When get_file_writer() is
    called with both do_gzip=False and do_bzip=False, the underlying file
    handle is returned directly. As a consequence, doing:

    > with get_file_writer(sys.stdout, False, False) as fh:
    >     pass

    ends up closing sys.stdout when the with block is exited. Using the
    function without a context manager avoids the issue, but then it results
    in leaked open files when either do_bzip=True or do_gzip=True.

    FileWriter must be used as a context manager; it guarantees that the
    compression wrapper (if any) is closed upon exiting the with block.

    :param file_handle: binary file-like object to write to.
    :param do_gzip: wrap file_handle in a gzip compressor.
    :param do_bzip: wrap file_handle in a bzip2 compressor.
    :param steal_ownership: also close file_handle itself on exit.
    :raises ValueError: if both do_gzip and do_bzip are requested.
    """
    if do_gzip and do_bzip:
        raise ValueError("Cannot specify both bzip and gzip compression!")

    if do_gzip:
        # Closing the GzipFile flushes the stream but leaves file_handle open.
        ofile = gzip.GzipFile(fileobj=file_handle, mode='w')
    elif do_bzip:
        ofile = bz2.open(file_handle, mode='w')
    else:
        # No wrapper: use a null context so exiting does not close
        # file_handle (it might be sys.stdout).
        ofile = contextlib.nullcontext(enter_result=file_handle)

    # The finally clause ensures the underlying handle is released even when
    # the caller's with body raises; a plain statement after the yield would
    # be skipped in that case, leaking the file.
    try:
        with ofile as writer:
            yield writer
    finally:
        if steal_ownership:
            file_handle.close()
4 changes: 4 additions & 0 deletions scripts/abundance-dist-single.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ def __do_abundance_dist__(read_parser):

log_info('wrote to: {output}', output=args.output_histogram_filename)

# Ensure that the output files are properly written. Python 3.12 seems to
# be less forgiving here ..
hist_fp.close()


if __name__ == '__main__':
main()
Expand Down
33 changes: 18 additions & 15 deletions scripts/abundance-dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
Use '-h' for parameter help.
"""

import contextlib
import sys
import csv
import khmer
Expand Down Expand Up @@ -143,26 +144,28 @@ def main():
sys.exit(1)

if args.output_histogram_filename in ('-', '/dev/stdout'):
countgraph_fp = sys.stdout
countgraph_ctx = contextlib.nullcontext(enter_result=sys.stdout)
else:
countgraph_fp = open(args.output_histogram_filename, 'w')
countgraph_fp_csv = csv.writer(countgraph_fp)
# write headers:
countgraph_fp_csv.writerow(['abundance', 'count', 'cumulative',
'cumulative_fraction'])
countgraph_ctx = open(args.output_histogram_filename, 'w')

sofar = 0
for _, i in enumerate(abundances):
if i == 0 and not args.output_zero:
continue
with countgraph_ctx as countgraph_fp:
countgraph_fp_csv = csv.writer(countgraph_fp)
# write headers:
countgraph_fp_csv.writerow(['abundance', 'count', 'cumulative',
'cumulative_fraction'])

sofar += i
frac = sofar / float(total)
sofar = 0
for _, i in enumerate(abundances):
if i == 0 and not args.output_zero:
continue

countgraph_fp_csv.writerow([_, i, sofar, round(frac, 3)])
sofar += i
frac = sofar / float(total)

if sofar == total:
break
countgraph_fp_csv.writerow([_, i, sofar, round(frac, 3)])

if sofar == total:
break


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions scripts/do-partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ def main(): # pylint: disable=too-many-locals,too-many-statements
worker_q.put((nodegraph, _, start, end))

print('enqueued %d subset tasks' % n_subsets, file=sys.stderr)
open('%s.info' % args.graphbase, 'w').write('%d subsets total\n'
% (n_subsets))
with open('%s.info' % args.graphbase, 'w') as info_fp:
info_fp.write('%d subsets total\n' % (n_subsets))

if n_subsets < args.threads:
args.threads = n_subsets
Expand Down
14 changes: 7 additions & 7 deletions scripts/extract-long-sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
import sys
from khmer import __version__
from khmer.utils import write_record
from khmer.kfile import add_output_compression_type, get_file_writer
from khmer.kfile import add_output_compression_type, FileWriter
from khmer.khmer_args import sanitize_help, KhmerArgumentParser


Expand Down Expand Up @@ -81,12 +81,12 @@ def get_parser():

def main():
args = sanitize_help(get_parser()).parse_args()
outfp = get_file_writer(args.output, args.gzip, args.bzip)
for filename in args.input_filenames:
for record in screed.open(filename):
if len(record['sequence']) >= args.length:
write_record(record, outfp)
print('wrote to: ' + args.output.name, file=sys.stderr)
with FileWriter(args.output, args.gzip, args.bzip) as outfp:
for filename in args.input_filenames:
for record in screed.open(filename):
if len(record['sequence']) >= args.length:
write_record(record, outfp)
print('wrote to: ' + args.output.name, file=sys.stderr)


if __name__ == '__main__':
Expand Down
58 changes: 30 additions & 28 deletions scripts/extract-paired-reads.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
Reads FASTQ and FASTA input, retains format for output.
"""
from contextlib import nullcontext
import sys
import os.path
import textwrap
Expand All @@ -53,7 +54,7 @@
from khmer.khmer_args import sanitize_help, KhmerArgumentParser
from khmer.khmer_args import FileType as khFileType
from khmer.kfile import add_output_compression_type
from khmer.kfile import get_file_writer
from khmer.kfile import FileWriter

from khmer.utils import broken_paired_reader, write_record, write_record_pair

Expand Down Expand Up @@ -132,39 +133,40 @@ def main():

# OVERRIDE default output file locations with -p, -s
if args.output_paired:
paired_fp = get_file_writer(args.output_paired, args.gzip, args.bzip)
out2 = paired_fp.name
paired_ctx = FileWriter(args.output_paired, args.gzip, args.bzip)
out2 = args.output_paired.name
else:
# Don't override, just open the default filename from above
paired_fp = get_file_writer(open(out2, 'wb'), args.gzip, args.bzip)
paired_ctx = FileWriter(open(out2, 'wb'), args.gzip, args.bzip,
steal_ownership=True)
if args.output_single:
single_fp = get_file_writer(args.output_single, args.gzip, args.bzip)
single_ctx = FileWriter(args.output_single, args.gzip, args.bzip)
out1 = args.output_single.name
else:
# Don't override, just open the default filename from above
single_fp = get_file_writer(open(out1, 'wb'), args.gzip, args.bzip)

print('reading file "%s"' % infile, file=sys.stderr)
print('outputting interleaved pairs to "%s"' % out2, file=sys.stderr)
print('outputting orphans to "%s"' % out1, file=sys.stderr)

n_pe = 0
n_se = 0

reads = ReadParser(infile)
for index, is_pair, read1, read2 in broken_paired_reader(reads):
if index % 100000 == 0 and index > 0:
print('...', index, file=sys.stderr)

if is_pair:
write_record_pair(read1, read2, paired_fp)
n_pe += 1
else:
write_record(read1, single_fp)
n_se += 1

single_fp.close()
paired_fp.close()
single_ctx = FileWriter(open(out1, 'wb'), args.gzip, args.bzip,
steal_ownership=True)

with paired_ctx as paired_fp, single_ctx as single_fp:
print('reading file "%s"' % infile, file=sys.stderr)
print('outputting interleaved pairs to "%s"' % out2,
file=sys.stderr)
print('outputting orphans to "%s"' % out1, file=sys.stderr)

n_pe = 0
n_se = 0

reads = ReadParser(infile)
for index, is_pair, read1, read2 in broken_paired_reader(reads):
if index % 100000 == 0 and index > 0:
print('...', index, file=sys.stderr)

if is_pair:
write_record_pair(read1, read2, paired_fp)
n_pe += 1
else:
write_record(read1, single_fp)
n_se += 1

if n_pe == 0:
raise Exception("no paired reads!? check file formats...")
Expand Down
21 changes: 16 additions & 5 deletions scripts/extract-partitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,10 @@ def main():
args.max_size)

if args.output_unassigned:
ofile = open('%s.unassigned.%s' % (args.prefix, suffix), 'wb')
unassigned_fp = get_file_writer(ofile, args.gzip, args.bzip)
extractor.process_unassigned(unassigned_fp)
unassigned_fp.close()
with open('%s.unassigned.%s' % (args.prefix, suffix), 'wb') as ofile:
unassigned_fp = get_file_writer(ofile, args.gzip, args.bzip)
extractor.process_unassigned(unassigned_fp)
unassigned_fp.close()
else:
extractor.process_unassigned()

Expand All @@ -320,13 +320,21 @@ def main():
print('nothing to output; exiting!', file=sys.stderr)
return

to_close = []
# open a bunch of output files for the different groups
group_fps = {}
for index in range(extractor.group_n):
fname = '%s.group%04d.%s' % (args.prefix, index, suffix)
group_fp = get_file_writer(open(fname, 'wb'), args.gzip,
back_fp = open(fname, 'wb')
group_fp = get_file_writer(back_fp, args.gzip,
args.bzip)
group_fps[index] = group_fp
# It feels more natural to close the writer before closing the
# underlying file. fp.close() is theoretically idempotent, so it should
# be fine even though sometimes get_file_writer "steals" ownership of
# the underlying stream.
to_close.append(group_fp)
to_close.append(back_fp)

# write 'em all out!
# refresh the generator
Expand All @@ -351,6 +359,9 @@ def main():
args.prefix,
suffix), file=sys.stderr)

for fp in to_close:
fp.close()


if __name__ == '__main__':
main()
26 changes: 13 additions & 13 deletions scripts/fastq-to-fasta.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
import sys
import screed
from khmer import __version__
from khmer.kfile import (add_output_compression_type, get_file_writer,
from khmer.kfile import (add_output_compression_type, FileWriter,
describe_file_handle)
from khmer.utils import write_record
from khmer.khmer_args import sanitize_help, KhmerArgumentParser
Expand Down Expand Up @@ -74,21 +74,21 @@ def main():
args = sanitize_help(get_parser()).parse_args()

print('fastq from ', args.input_sequence, file=sys.stderr)
outfp = get_file_writer(args.output, args.gzip, args.bzip)
n_count = 0
for n, record in enumerate(screed.open(args.input_sequence)):
if n % 10000 == 0:
print('...', n, file=sys.stderr)
with FileWriter(args.output, args.gzip, args.bzip) as outfp:
n_count = 0
for n, record in enumerate(screed.open(args.input_sequence)):
if n % 10000 == 0:
print('...', n, file=sys.stderr)

sequence = record['sequence']
sequence = record['sequence']

if 'N' in sequence:
if not args.n_keep:
n_count += 1
continue
if 'N' in sequence:
if not args.n_keep:
n_count += 1
continue

del record['quality']
write_record(record, outfp)
del record['quality']
write_record(record, outfp)

print('\n' + 'lines from ' + args.input_sequence, file=sys.stderr)

Expand Down
35 changes: 18 additions & 17 deletions scripts/filter-abund-single.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
from khmer.kfile import (check_input_files, check_space,
check_space_for_graph,
add_output_compression_type,
get_file_writer)
FileWriter)
from khmer.khmer_logger import (configure_logging, log_info, log_error,
log_warn)
from khmer.trimming import (trim_record)
Expand Down Expand Up @@ -160,22 +160,23 @@ def main():
outfile = os.path.basename(args.datafile) + '.abundfilt'
else:
outfile = args.outfile
outfp = open(outfile, 'wb')
outfp = get_file_writer(outfp, args.gzip, args.bzip)

paired_iter = broken_paired_reader(ReadParser(args.datafile),
min_length=graph.ksize(),
force_single=True)

for n, is_pair, read1, read2 in paired_iter:
assert not is_pair
assert read2 is None

trimmed_record, _ = trim_record(graph, read1, args.cutoff,
args.variable_coverage,
args.normalize_to)
if trimmed_record:
write_record(trimmed_record, outfp)

with FileWriter(open(outfile, 'wb'), args.gzip, args.bzip,
steal_ownership=True) as outfp:

paired_iter = broken_paired_reader(ReadParser(args.datafile),
min_length=graph.ksize(),
force_single=True)

for n, is_pair, read1, read2 in paired_iter:
assert not is_pair
assert read2 is None

trimmed_record, _ = trim_record(graph, read1, args.cutoff,
args.variable_coverage,
args.normalize_to)
if trimmed_record:
write_record(trimmed_record, outfp)

log_info('output in {outfile}', outfile=outfile)

Expand Down
Loading

0 comments on commit 982832b

Please sign in to comment.