diff --git a/src/lib.rs b/src/lib.rs index ab89657..d1d0cf1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,6 @@ -use pyo3::prelude::*; use external_sort::{ExternalSorter, ExternallySortable}; -use itertools::Itertools; -use std::io::{BufRead, BufReader, Read, Write}; +use pyo3::prelude::*; +use std::io::{BufRead, BufReader, BufWriter, Read, SeekFrom, Write}; // Define a string structure that can be sorted externally #[derive(Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord)] @@ -38,6 +37,34 @@ fn sort_lines(lines: Vec) -> Vec { lines } +fn streaming_sort_until<'a, IN, OUT, STR>( + input: IN, + mut output: OUT, + end: &str, +) -> std::io::Result<()> +where + // can take &str or &String + STR: AsRef, + IN: Iterator, + OUT: Write, +{ + let input = input + .take_while(|l| l.as_ref() != end) + .map(|l| TsvLine::new(l.as_ref())); + // Do the external sort + let sorted = ExternalSorter::new(1000000, None) + .sort_by(input, |a, b| { + tsv_cmp(a.the_line.as_str(), b.the_line.as_str()) + }) + .unwrap(); + // Write the sorted lines to the output file + for line in sorted { + writeln!(&mut output, "{}", line.unwrap().the_line)?; + } + writeln!(&mut output, "{end}")?; + Ok(()) +} + /// Merge sort a range of lines from an input file and write the result to another file. /// /// The function `sort_file_lines` seeks to the given start position in the input file, reads @@ -76,31 +103,22 @@ fn sort_lines(lines: Vec) -> Vec { /// #[pyfunction] fn sort_file_lines(input: &str, output: &str, start: u64, end: &str) -> PyResult { - // Open the input file and seek to the start position - let mut input_file = std::fs::File::open(input)?; - input_file.seek(std::io::SeekFrom::Start(start))?; - // Wrap the input file in a buffered reader - let mut input = BufReader::new(&mut input_file); - // Create an iterator which reads lines until the end marker and doesn't consume the end marker - let mut binding = input.by_ref().lines().peekable(); - let lines = binding - .peeking_take_while(|line| line.as_ref().map(|l| l != end).unwrap_or(false)) - .map(|line| TsvLine::new(&line.unwrap())); - // Do the external sort - let iter = ExternalSorter::new(1000000, None).sort_by( - lines, - |a, b| tsv_cmp(a.the_line.as_str(), b.the_line.as_str()), - ).unwrap(); - // Write the sorted lines to the output file - let output_file = std::fs::File::create(output)?; - let mut output = std::io::BufWriter::new(output_file); - for line in iter { - writeln!(output, "{}", line.unwrap().the_line)?; - } - // Write the end marker (which was not consumed by peeking_take_while) - writeln!(output, "{}", binding.next().unwrap().unwrap())?; + // Open the input file with a buffer + let mut input = BufReader::new(File::open(input)?); + + // Seek to the start position + input.seek(SeekFrom::Start(start))?; + + // Open the output file + let mut output = BufWriter::new(File::create(output)?); + + // Create the lines iterator + let lines = input.by_ref().lines().map(|l| l.unwrap()); + streaming_sort_until(lines, output, end)?; + // sort_until(input.as_ref().lines(), &mut output)?; + // return the stream position from the counting reader object - Ok(input.stream_position().unwrap()) + Ok(input.stream_position().unwrap() - (end.bytes().len() as u64)) } /// A Python module implemented in Rust. @@ -413,4 +431,10 @@ mod tests { expected as i8, ); } + #[test] + fn streaming_sort_smoke() { + let mut res = Vec::new(); + streaming_sort_until("1\n3\n2\nEND".lines(), &mut res, "END").unwrap(); + assert_eq!(std::str::from_utf8(&res).unwrap(), "1\n2\n3\nEND\n"); + } }