From 868d5a85b9f448edf49f9d60952139272a6c6c62 Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 13 Nov 2024 22:04:40 +1100 Subject: [PATCH] c --- .../src/plans/optimizer/slice_pushdown_lp.rs | 24 + .../src/async_primitives/wait_group.rs | 28 + crates/polars-stream/src/nodes/csv_source.rs | 675 ++++++++++++++++++ crates/polars-stream/src/nodes/mod.rs | 1 + .../src/nodes/parquet_source/init.rs | 40 +- .../src/nodes/parquet_source/mod.rs | 6 +- .../src/physical_plan/lower_ir.rs | 109 ++- .../src/physical_plan/to_graph.rs | 18 + py-polars/polars/_utils/various.py | 6 +- py-polars/polars/io/csv/functions.py | 60 +- py-polars/tests/unit/io/test_csv.py | 15 +- py-polars/tests/unit/io/test_scan.py | 7 + 12 files changed, 897 insertions(+), 92 deletions(-) create mode 100644 crates/polars-stream/src/nodes/csv_source.rs diff --git a/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs index a5ff806abae9..6ecae038f49f 100644 --- a/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs +++ b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs @@ -166,6 +166,30 @@ impl SlicePushDown { Ok(lp) } #[cfg(feature = "csv")] + (Scan { + sources, + file_info, + hive_parts, + output_schema, + mut file_options, + predicate, + scan_type: FileScan::Csv { options, cloud_options }, + }, Some(state)) if predicate.is_none() && self.new_streaming => { + file_options.slice = Some((state.offset, state.len as usize)); + + let lp = Scan { + sources, + file_info, + hive_parts, + output_schema, + scan_type: FileScan::Csv { options, cloud_options }, + file_options, + predicate, + }; + + Ok(lp) + }, + #[cfg(feature = "csv")] (Scan { sources, file_info, diff --git a/crates/polars-stream/src/async_primitives/wait_group.rs b/crates/polars-stream/src/async_primitives/wait_group.rs index e08f556d3b95..8cfa150747a2 100644 --- a/crates/polars-stream/src/async_primitives/wait_group.rs +++ b/crates/polars-stream/src/async_primitives/wait_group.rs @@ -38,6 +38,34 @@ impl WaitGroup { } } +// Wait group with an associated index. 
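+//
+// A minimal usage sketch (hypothetical, not part of this patch): one wait
+// group per channel index, collected in a `FuturesUnordered` so that an
+// index is polled again only once all of its outstanding tokens have been
+// dropped. `send_to_channel` is an illustrative stand-in:
+//
+//     use futures::stream::{FuturesUnordered, StreamExt};
+//
+//     let mut groups: FuturesUnordered<_> =
+//         (0..num_pipelines).map(|i| IndexedWaitGroup::new(i).wait()).collect();
+//     while let Some(group) = groups.next().await {
+//         let token = group.token(); // attach to a morsel; dropped on consume
+//         send_to_channel(group.index(), token);
+//         groups.push(group.wait()); // resolves after the token is dropped
+//     }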
+pub struct IndexedWaitGroup { + index: usize, + wait_group: WaitGroup, +} + +impl IndexedWaitGroup { + pub fn new(index: usize) -> Self { + Self { + index, + wait_group: Default::default(), + } + } + + pub fn index(&self) -> usize { + self.index + } + + pub fn token(&self) -> WaitToken { + self.wait_group.token() + } + + pub async fn wait(self) -> Self { + self.wait_group.wait().await; + self + } +} + struct WaitGroupFuture<'a> { inner: &'a Arc, } diff --git a/crates/polars-stream/src/nodes/csv_source.rs b/crates/polars-stream/src/nodes/csv_source.rs new file mode 100644 index 000000000000..d6d8388d44e9 --- /dev/null +++ b/crates/polars-stream/src/nodes/csv_source.rs @@ -0,0 +1,675 @@ +use std::future::Future; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; + +use futures::stream::FuturesUnordered; +use futures::StreamExt; +use polars_core::prelude::{AnyValue, DataType, Field}; +use polars_core::scalar::Scalar; +use polars_core::schema::{SchemaExt, SchemaRef}; +use polars_core::{config, StringCacheHolder}; +use polars_error::{polars_bail, PolarsResult}; +use polars_io::prelude::_csv_read_internal::{ + cast_columns, find_starting_point, prepare_csv_schema, read_chunk, CountLines, + NullValuesCompiled, +}; +use polars_io::prelude::buffer::validate_utf8; +use polars_io::prelude::{CommentPrefix, CsvEncoding, CsvReadOptions}; +use polars_io::utils::compression::maybe_decompress_bytes; +use polars_io::utils::slice::SplitSlicePosition; +use polars_io::RowIndex; +use polars_plan::plans::{FileInfo, ScanSources}; +use polars_plan::prelude::FileScanOptions; +use polars_utils::mmap::MemSlice; +use polars_utils::pl_str::PlSmallStr; +use polars_utils::IdxSize; + +use super::compute_node_prelude::*; +use super::{MorselSeq, TaskPriority}; +use crate::async_executor; +use crate::async_primitives::connector::connector; +use crate::async_primitives::wait_group::{IndexedWaitGroup, WaitToken}; +use crate::morsel::SourceToken; + +struct LineBatch { + bytes: MemSlice, + n_lines: usize, + slice: (usize, usize), + row_offset: usize, + morsel_seq: MorselSeq, + wait_token: WaitToken, + path_name: Option, +} + +type AsyncTaskData = ( + Vec>, + Arc, + async_executor::AbortOnDropHandle>, +); + +pub struct CsvSourceNode { + scan_sources: ScanSources, + file_info: FileInfo, + file_options: FileScanOptions, + options: CsvReadOptions, + schema: Option, + num_pipelines: usize, + async_task_data: Arc>>, + is_finished: Arc, + verbose: bool, +} + +impl CsvSourceNode { + pub fn new( + scan_sources: ScanSources, + file_info: FileInfo, + file_options: FileScanOptions, + options: CsvReadOptions, + ) -> Self { + let verbose = config::verbose(); + + Self { + scan_sources, + file_info, + file_options, + options, + schema: None, + num_pipelines: 0, + async_task_data: Arc::new(tokio::sync::Mutex::new(None)), + is_finished: Arc::new(AtomicBool::new(false)), + verbose, + } + } +} + +impl ComputeNode for CsvSourceNode { + fn name(&self) -> &str { + "csv_source" + } + + fn initialize(&mut self, num_pipelines: usize) { + self.num_pipelines = num_pipelines; + + if self.verbose { + eprintln!("[CsvSource]: initialize"); + } + + self.schema = Some(self.file_info.reader_schema.take().unwrap().unwrap_right()); + } + + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { + use std::sync::atomic::Ordering; + + assert!(recv.is_empty()); + assert_eq!(send.len(), 1); + + if self.is_finished.load(Ordering::Relaxed) { + send[0] = PortState::Done; + assert!( + 
self.async_task_data.try_lock().unwrap().is_none(), + "should have already been shut down" + ); + } else if send[0] == PortState::Done { + { + // Early shutdown - our port state was set to `Done` by the downstream nodes. + self.shutdown_in_background(); + }; + self.is_finished.store(true, Ordering::Relaxed); + } else { + send[0] = PortState::Ready + } + + Ok(()) + } + + fn spawn<'env, 's>( + &'env mut self, + scope: &'s TaskScope<'s, 'env>, + recv_ports: &mut [Option>], + send_ports: &mut [Option>], + _state: &'s ExecutionState, + join_handles: &mut Vec>>, + ) { + use std::sync::atomic::Ordering; + + assert!(recv_ports.is_empty()); + assert_eq!(send_ports.len(), 1); + assert!(!self.is_finished.load(Ordering::Relaxed)); + + let morsel_senders = send_ports[0].take().unwrap().parallel(); + + let mut async_task_data_guard = { + let guard = self.async_task_data.try_lock().unwrap(); + + if guard.is_some() { + guard + } else { + drop(guard); + let v = self.init_line_batch_source(); + let mut guard = self.async_task_data.try_lock().unwrap(); + guard.replace(v); + guard + } + }; + + let (line_batch_receivers, chunk_reader, _) = async_task_data_guard.as_mut().unwrap(); + + assert_eq!(line_batch_receivers.len(), morsel_senders.len()); + + let is_finished = self.is_finished.clone(); + let source_token = SourceToken::new(); + + let task_handles = line_batch_receivers + .drain(..) + .zip(morsel_senders) + .map(|(mut line_batch_rx, mut morsel_tx)| { + let is_finished = is_finished.clone(); + let chunk_reader = chunk_reader.clone(); + let source_token = source_token.clone(); + + scope.spawn_task(TaskPriority::Low, async move { + loop { + let Ok(LineBatch { + bytes, + n_lines, + slice: (offset, len), + row_offset, + morsel_seq, + wait_token, + mut path_name, + }) = line_batch_rx.recv().await + else { + is_finished.store(true, Ordering::Relaxed); + break; + }; + + let mut df = + chunk_reader.read_chunk(&bytes, n_lines, (offset, len), row_offset)?; + + if let Some(path_name) = path_name.take() { + unsafe { + df.with_column_unchecked( + Scalar::new(DataType::String, AnyValue::StringOwned(path_name)) + .into_column( + chunk_reader.include_file_paths.clone().unwrap(), + ) + .new_from_index(0, df.height()), + ) + }; + } + + let mut morsel = Morsel::new(df, morsel_seq, source_token.clone()); + morsel.set_consume_token(wait_token); + + if morsel_tx.send(morsel).await.is_err() { + break; + } + + if source_token.stop_requested() { + break; + } + } + + PolarsResult::Ok(line_batch_rx) + }) + }) + .collect::>(); + + drop(async_task_data_guard); + + let async_task_data = self.async_task_data.clone(); + + join_handles.push(scope.spawn_task(TaskPriority::Low, async move { + { + let mut async_task_data_guard = async_task_data.try_lock().unwrap(); + let (line_batch_receivers, ..) 
= async_task_data_guard.as_mut().unwrap(); + + for handle in task_handles { + line_batch_receivers.push(handle.await?); + } + } + + if self.is_finished.load(Ordering::Relaxed) { + self.shutdown().await?; + } + + Ok(()) + })) + } +} + +impl CsvSourceNode { + fn init_line_batch_source(&mut self) -> AsyncTaskData { + let verbose = self.verbose; + + let (mut line_batch_senders, line_batch_receivers): (Vec<_>, Vec<_>) = + (0..self.num_pipelines).map(|_| connector()).unzip(); + + let scan_sources = self.scan_sources.clone(); + let run_async = scan_sources.is_cloud_url() || config::force_async(); + let num_pipelines = self.num_pipelines; + + let schema_len = self.schema.as_ref().unwrap().len(); + + let options = &self.options; + let parse_options = self.options.parse_options.as_ref(); + + let quote_char = parse_options.quote_char; + let eol_char = parse_options.eol_char; + + let skip_rows_before_header = options.skip_rows; + let skip_rows_after_header = options.skip_rows_after_header; + let comment_prefix = parse_options.comment_prefix.clone(); + let has_header = options.has_header; + let global_slice = self.file_options.slice; + let include_file_paths = self.file_options.include_file_paths.is_some(); + + if verbose { + eprintln!( + "[CsvSource]: slice: {:?}, row_index: {:?}", + global_slice, &self.file_options.row_index + ) + } + + let line_batch_source_task_handle = async_executor::AbortOnDropHandle::new( + async_executor::spawn(TaskPriority::Low, async move { + let global_slice = if let Some((offset, len)) = global_slice { + if offset < 0 { + polars_bail!( + ComputeError: + "unsupported negative slice offset {} for CSV source", + offset + ); + } + Some(offset as usize..offset as usize + len) + } else { + None + }; + + let mut wait_groups = (0..num_pipelines) + .map(|index| IndexedWaitGroup::new(index).wait()) + .collect::>(); + let morsel_seq_ref = &mut MorselSeq::default(); + let current_row_offset_ref = &mut 0usize; + + let n_parts_hint = num_pipelines * 16; + + let line_counter = CountLines::new(quote_char, eol_char); + + let comment_prefix = comment_prefix.as_ref(); + + 'main: for (i, v) in scan_sources + .iter() + .map(|x| { + let bytes = x.to_memslice_async_assume_latest(run_async)?; + PolarsResult::Ok(( + bytes, + include_file_paths.then(|| x.to_include_path_name().into()), + )) + }) + .enumerate() + { + if verbose { + eprintln!( + "[CsvSource]: Start line splitting for file {} / {}", + 1 + i, + scan_sources.len() + ); + } + let (mem_slice, path_name) = v?; + let mem_slice = { + let mut out = vec![]; + maybe_decompress_bytes(&mem_slice, &mut out)?; + + if out.is_empty() { + mem_slice + } else { + MemSlice::from_vec(out) + } + }; + + let bytes = mem_slice.as_ref(); + + let i = find_starting_point( + bytes, + quote_char, + eol_char, + schema_len, + skip_rows_before_header, + skip_rows_after_header, + comment_prefix, + has_header, + )?; + + let mut bytes = &bytes[i..]; + + let mut chunk_size = { + let max_chunk_size = 16 * 1024 * 1024; + let chunk_size = if global_slice.is_some() { + max_chunk_size + } else { + std::cmp::min(bytes.len() / n_parts_hint, max_chunk_size) + }; + + // Use a small min chunk size to catch failures in tests. 
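+                        // Illustrative arithmetic (not from the original patch):
+                        // with n_parts_hint = num_pipelines * 16, a 1 GiB file on
+                        // 8 pipelines targets 1 GiB / 128 = 8 MiB chunks, under
+                        // the 16 MiB cap; the floor below guards tiny inputs.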
+ #[cfg(debug_assertions)] + let min_chunk_size = 64; + #[cfg(not(debug_assertions))] + let min_chunk_size = 1024 * 4; + std::cmp::max(chunk_size, min_chunk_size) + }; + + loop { + if bytes.is_empty() { + break; + } + + let (count, position) = line_counter.find_next(bytes, &mut chunk_size); + let (count, position) = if count == 0 { + (1, bytes.len()) + } else { + let pos = (position + 1).min(bytes.len()); // +1 for '\n' + (count, pos) + }; + + let slice_start = bytes.as_ptr() as usize - mem_slice.as_ptr() as usize; + + bytes = &bytes[position..]; + + let current_row_offset = *current_row_offset_ref; + *current_row_offset_ref += count; + + let slice = if let Some(global_slice) = &global_slice { + match SplitSlicePosition::split_slice_at_file( + current_row_offset, + count, + global_slice.clone(), + ) { + // Note that we don't check that the skipped line batches actually contain this many + // lines. + SplitSlicePosition::Before => continue, + SplitSlicePosition::Overlapping(offset, len) => (offset, len), + SplitSlicePosition::After => break 'main, + } + } else { + // (0, 0) is interpreted as no slicing + (0, 0) + }; + + let mut mem_slice_this_chunk = + mem_slice.slice(slice_start..slice_start + position); + + let morsel_seq = *morsel_seq_ref; + *morsel_seq_ref = morsel_seq.successor(); + + let Some(mut indexed_wait_group) = wait_groups.next().await else { + break; + }; + + let mut path_name = path_name.clone(); + + loop { + use crate::async_primitives::connector::SendError; + + let channel_index = indexed_wait_group.index(); + let wait_token = indexed_wait_group.token(); + + match line_batch_senders[channel_index].try_send(LineBatch { + bytes: mem_slice_this_chunk, + n_lines: count, + slice, + row_offset: current_row_offset, + morsel_seq, + wait_token, + path_name, + }) { + Ok(_) => { + wait_groups.push(indexed_wait_group.wait()); + break; + }, + Err(SendError::Closed(v)) => { + mem_slice_this_chunk = v.bytes; + path_name = v.path_name; + }, + Err(SendError::Full(_)) => unreachable!(), + } + + let Some(v) = wait_groups.next().await else { + break 'main; // All channels closed + }; + + indexed_wait_group = v; + } + } + } + + Ok(()) + }), + ); + + ( + line_batch_receivers, + // TODO: Refactor so that we don't unwrap, it's currently hard because + // `ComputeNode::{initialize, spawn}` doesn't return a `PolarsResult` + Arc::new(self.try_init_chunk_reader().unwrap()), + line_batch_source_task_handle, + ) + } + + fn try_init_chunk_reader(&mut self) -> PolarsResult { + let with_columns = self + .file_options + .with_columns + .clone() + // Interpret selecting no columns as selecting all columns. + .filter(|columns| !columns.is_empty()); + + ChunkReader::try_new( + &mut self.options, + self.schema.as_ref().unwrap(), + with_columns.as_deref(), + self.file_options.row_index.clone(), + self.file_options.include_file_paths.clone(), + ) + } + + /// # Panics + /// Panics if called more than once. + async fn shutdown_impl( + async_task_data: Arc>>, + verbose: bool, + ) -> PolarsResult<()> { + if verbose { + eprintln!("[CsvSource]: Shutting down"); + } + + let (line_batch_receivers, _chunk_reader, task_handle) = + async_task_data.try_lock().unwrap().take().unwrap(); + + drop(line_batch_receivers); + // Join on the producer handle to catch errors/panics. + // Safety + // * We dropped the receivers on the line above + // * This function is only called once. 
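+        //   (Dropping the receivers makes the producer's `try_send` observe
+        //   closed channels, so its task winds down and the join below
+        //   cannot hang.)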
+ task_handle.await + } + + fn shutdown(&self) -> impl Future> { + if self.verbose { + eprintln!("[CsvSource]: Shutdown via `shutdown()`"); + } + Self::shutdown_impl(self.async_task_data.clone(), self.verbose) + } + + fn shutdown_in_background(&self) { + if self.verbose { + eprintln!("[CsvSource]: Shutdown via `shutdown_in_background()`"); + } + let async_task_data = self.async_task_data.clone(); + polars_io::pl_async::get_runtime() + .spawn(Self::shutdown_impl(async_task_data, self.verbose)); + } +} + +struct ChunkReader { + reader_schema: SchemaRef, + fields_to_cast: Vec, + _cat_lock: Option, + separator: u8, + ignore_errors: bool, + projection: Vec, + quote_char: Option, + eol_char: u8, + comment_prefix: Option, + encoding: CsvEncoding, + null_values: Option, + missing_is_null: bool, + truncate_ragged_lines: bool, + decimal_comma: bool, + validate_utf8: bool, + row_index: Option, + include_file_paths: Option, +} + +impl ChunkReader { + fn try_new( + options: &mut CsvReadOptions, + reader_schema: &SchemaRef, + with_columns: Option<&[PlSmallStr]>, + row_index: Option, + include_file_paths: Option, + ) -> PolarsResult { + let mut reader_schema = reader_schema.clone(); + // Logic from `CsvReader::finish()` + let mut fields_to_cast = std::mem::take(&mut options.fields_to_cast); + + if let Some(dtypes) = options.dtype_overwrite.as_deref() { + let mut s = Arc::unwrap_or_clone(reader_schema); + for (i, dtype) in dtypes.iter().enumerate() { + s.set_dtype_at_index(i, dtype.clone()); + } + reader_schema = s.into(); + } + + let has_categorical = prepare_csv_schema(&mut reader_schema, &mut fields_to_cast)?; + + let _cat_lock = has_categorical.then(polars_core::StringCacheHolder::hold); + + let parse_options = &*options.parse_options; + + // Logic from `CoreReader::new()` + let separator = parse_options.separator; + + let null_values = parse_options + .null_values + .clone() + .map(|nv| nv.compile(&reader_schema)) + .transpose()?; + + let projection = if let Some(cols) = with_columns { + let mut v = Vec::with_capacity(cols.len()); + for col in cols { + v.push(reader_schema.try_index_of(col)?); + } + v.sort_unstable(); + v + } else if let Some(v) = options.projection.clone() { + let mut v = Arc::unwrap_or_clone(v); + v.sort_unstable(); + v + } else { + (0..reader_schema.len()).collect::>() + }; + + let validate_utf8 = matches!(parse_options.encoding, CsvEncoding::Utf8) + && reader_schema.iter_fields().any(|f| f.dtype().is_string()); + + Ok(Self { + reader_schema, + fields_to_cast, + _cat_lock, + separator, + ignore_errors: options.ignore_errors, + projection, + quote_char: parse_options.quote_char, + eol_char: parse_options.eol_char, + comment_prefix: parse_options.comment_prefix.clone(), + encoding: parse_options.encoding, + null_values, + missing_is_null: parse_options.missing_is_null, + truncate_ragged_lines: parse_options.truncate_ragged_lines, + decimal_comma: parse_options.decimal_comma, + validate_utf8, + row_index, + include_file_paths, + }) + } + + fn read_chunk( + &self, + chunk: &[u8], + n_lines: usize, + slice: (usize, usize), + chunk_row_offset: usize, + ) -> PolarsResult { + if self.validate_utf8 && !validate_utf8(chunk) { + polars_bail!(ComputeError: "invalid utf-8 sequence") + } + + read_chunk( + chunk, + self.separator, + &self.reader_schema, + self.ignore_errors, + &self.projection, + 0, // bytes_offset_thread + self.quote_char, + self.eol_char, + self.comment_prefix.as_ref(), + n_lines, // capacity + self.encoding, + self.null_values.as_ref(), + self.missing_is_null, + 
self.truncate_ragged_lines, + usize::MAX, // chunk_size + chunk.len(), // stop_at_nbytes + Some(0), // starting_point_offset + self.decimal_comma, + ) + .and_then(|mut df| { + let n_lines_is_correct = df.height() == n_lines; + + if slice != (0, 0) { + assert!(n_lines_is_correct); + + df = df.slice(slice.0 as i64, slice.1); + } + + cast_columns(&mut df, &self.fields_to_cast, false, self.ignore_errors)?; + + if let Some(ri) = &self.row_index { + assert!(n_lines_is_correct); + + let offset = ri.offset; + + let Some(offset) = (|| { + let offset = offset.checked_add((chunk_row_offset + slice.0) as IdxSize)?; + offset.checked_add(df.height() as IdxSize)?; + + Some(offset) + })() else { + let msg = format!( + "adding a row index column with offset {} overflows at {} rows", + offset, + chunk_row_offset + slice.0 + ); + polars_bail!(ComputeError: msg) + }; + + df.with_row_index_mut(ri.name.clone(), Some(offset as IdxSize)); + } + + Ok(df) + }) + } +} diff --git a/crates/polars-stream/src/nodes/mod.rs b/crates/polars-stream/src/nodes/mod.rs index effebe67c34b..16ac5dab7a98 100644 --- a/crates/polars-stream/src/nodes/mod.rs +++ b/crates/polars-stream/src/nodes/mod.rs @@ -1,3 +1,4 @@ +pub mod csv_source; pub mod filter; pub mod group_by; pub mod in_memory_map; diff --git a/crates/polars-stream/src/nodes/parquet_source/init.rs b/crates/polars-stream/src/nodes/parquet_source/init.rs index 7aad10ddde6c..f1cb20e004d6 100644 --- a/crates/polars-stream/src/nodes/parquet_source/init.rs +++ b/crates/polars-stream/src/nodes/parquet_source/init.rs @@ -14,7 +14,7 @@ use super::row_group_decode::RowGroupDecoder; use super::{AsyncTaskData, ParquetSourceNode}; use crate::async_executor; use crate::async_primitives::connector::connector; -use crate::async_primitives::wait_group::{WaitGroup, WaitToken}; +use crate::async_primitives::wait_group::IndexedWaitGroup; use crate::morsel::get_ideal_morsel_size; use crate::nodes::{MorselSeq, TaskPriority}; use crate::utils::task_handles_ext; @@ -23,7 +23,7 @@ impl ParquetSourceNode { /// # Panics /// Panics if called more than once. async fn shutdown_impl( - async_task_data: Arc>, + async_task_data: Arc>>, verbose: bool, ) -> PolarsResult<()> { if verbose { @@ -65,12 +65,7 @@ impl ParquetSourceNode { /// Constructs the task that distributes morsels across the engine pipelines. #[allow(clippy::type_complexity)] - pub(super) fn init_raw_morsel_distributor( - &mut self, - ) -> ( - Vec>, - task_handles_ext::AbortOnDropHandle>, - ) { + pub(super) fn init_raw_morsel_distributor(&mut self) -> AsyncTaskData { let verbose = self.verbose; let io_runtime = polars_io::pl_async::get_runtime(); @@ -140,33 +135,10 @@ impl ParquetSourceNode { row_group_data_fetcher.slice_range = slice_range; - // Pins a wait group to a channel index. - struct IndexedWaitGroup { - index: usize, - wait_group: WaitGroup, - } - - impl IndexedWaitGroup { - async fn wait(self) -> Self { - self.wait_group.wait().await; - self - } - } - // Ensure proper backpressure by only polling the buffered iterator when a wait group // is free. 
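+        // (Each raw morsel is sent together with a consume token from its
+        // wait group; the group's future resolves only after downstream drops
+        // that token, bounding the number of unconsumed morsels per pipeline.)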
let mut wait_groups = (0..num_pipelines) - .map(|index| { - let wait_group = WaitGroup::default(); - { - let _prime_this_wait_group = wait_group.token(); - } - IndexedWaitGroup { - index, - wait_group: WaitGroup::default(), - } - .wait() - }) + .map(|index| IndexedWaitGroup::new(index).wait()) .collect::>(); let mut df_stream = row_group_data_fetcher @@ -229,8 +201,8 @@ impl ParquetSourceNode { loop { use crate::async_primitives::connector::SendError; - let channel_index = indexed_wait_group.index; - let wait_token = indexed_wait_group.wait_group.token(); + let channel_index = indexed_wait_group.index(); + let wait_token = indexed_wait_group.token(); match raw_morsel_senders[channel_index].try_send((df, morsel_seq, wait_token)) { Ok(_) => { diff --git a/crates/polars-stream/src/nodes/parquet_source/mod.rs b/crates/polars-stream/src/nodes/parquet_source/mod.rs index 6427d08e2696..eb4bf3cd330a 100644 --- a/crates/polars-stream/src/nodes/parquet_source/mod.rs +++ b/crates/polars-stream/src/nodes/parquet_source/mod.rs @@ -30,10 +30,10 @@ mod metadata_utils; mod row_group_data_fetch; mod row_group_decode; -type AsyncTaskData = Option<( +type AsyncTaskData = ( Vec>, task_handles_ext::AbortOnDropHandle>, -)>; +); #[allow(clippy::type_complexity)] pub struct ParquetSourceNode { @@ -61,7 +61,7 @@ pub struct ParquetSourceNode { // This permit blocks execution until the first morsel is requested. morsel_stream_starter: Option>, // This is behind a Mutex so that we can call `shutdown()` asynchronously. - async_task_data: Arc>, + async_task_data: Arc>>, is_finished: Arc, } diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index 063c94081dbc..bf548bc4da5c 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -318,7 +318,7 @@ pub fn lower_ir( output_schema: scan_output_schema, scan_type, mut predicate, - file_options, + mut file_options, } = v.clone() else { unreachable!(); @@ -342,39 +342,86 @@ pub fn lower_ir( } } - // If the node itself would just filter on the whole output then there is no real - // reason to do it in the source node itself. - let do_filter_in_separate_node = - predicate.is_some() && matches!(scan_type, FileScan::Ipc { .. }); - - if do_filter_in_separate_node { - assert!(file_options.slice.is_none()); // Invariant of the scan - let predicate = predicate.take().unwrap(); - - let input = phys_sm.insert(PhysNode::new( - output_schema.clone(), - PhysNodeKind::FileScan { - scan_sources, - file_info, - hive_parts, - output_schema: scan_output_schema, - scan_type, - predicate: None, - file_options, - }, - )); + // Operation ordering: + // * with_row_index() -> slice() -> filter() + + // Some scans have built-in support for applying these operations in an optimized manner. + let opt_rewrite_to_nodes = match &scan_type { + FileScan::Parquet { .. } => (None, None, None), + FileScan::Csv { options, .. } => { + if file_options.slice.map_or(true, |(offset, _)| { + // Only really makes sense to push this if we are skipping rows + offset > 0 + }) && options.parse_options.comment_prefix.is_none() + && std::env::var("POLARS_DISABLE_EXPERIMENTAL_CSV_SLICE").as_deref() + != Ok("1") + { + // Note: This relies on `CountLines` being exact. + (None, None, predicate.take()) + } else { + // There can be comments in the middle of the file, then `CountLines` won't + // return an accurate line count :'(. 
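+                        // Fall back to lifting row-index, slice and predicate
+                        // out of the scan into separate WithRowIndex /
+                        // StreamingSlice / Filter nodes (built below).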
+ ( + file_options.row_index.take(), + file_options.slice.take(), + predicate.take(), + ) + } + }, + FileScan::Ipc { .. } => (None, None, predicate.take()), + _ => todo!(), + }; - PhysNodeKind::Filter { input, predicate } + let phys_node = PhysNodeKind::FileScan { + scan_sources, + file_info, + hive_parts, + output_schema: scan_output_schema, + scan_type, + predicate, + file_options, + }; + + let (row_index, slice, predicate) = opt_rewrite_to_nodes; + + let phys_node = if let Some(ri) = row_index { + let mut schema = Arc::unwrap_or_clone(output_schema.clone()); + + let v = schema.shift_remove_index(0).unwrap().0; + assert_eq!(v, ri.name); + let input = phys_sm.insert(PhysNode::new(Arc::new(schema), phys_node)); + + PhysNodeKind::WithRowIndex { + input, + name: ri.name, + offset: Some(ri.offset), + } } else { - PhysNodeKind::FileScan { - scan_sources, - file_info, - hive_parts, - output_schema: scan_output_schema, - scan_type, - predicate, - file_options, + phys_node + }; + + let phys_node = if let Some((offset, length)) = slice { + let input = phys_sm.insert(PhysNode::new(output_schema.clone(), phys_node)); + + if offset < 0 { + todo!() + } + + PhysNodeKind::StreamingSlice { + input, + offset: offset as usize, + length, } + } else { + phys_node + }; + + if let Some(predicate) = predicate { + let input = phys_sm.insert(PhysNode::new(output_schema.clone(), phys_node)); + + PhysNodeKind::Filter { input, predicate } + } else { + phys_node } } }, diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index b701696972a9..d0c723a66a79 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -384,6 +384,24 @@ fn to_graph_rec<'a>( )?, [], ), + FileScan::Csv { options, .. 
} => { + if options.parse_options.comment_prefix.is_some() { + // Should have been re-written to separate streaming nodes + assert!(file_options.row_index.is_none()); + assert!(file_options.slice.is_none()); + assert!(predicate.is_none()); + } + + ctx.graph.add_node( + nodes::csv_source::CsvSourceNode::new( + scan_sources, + file_info, + file_options, + options, + ), + [], + ) + }, _ => todo!(), } } diff --git a/py-polars/polars/_utils/various.py b/py-polars/polars/_utils/various.py index 343ebc587708..3f66c4a26a65 100644 --- a/py-polars/polars/_utils/various.py +++ b/py-polars/polars/_utils/various.py @@ -103,7 +103,11 @@ def is_path_or_str_sequence( return np.issubdtype(val.dtype, np.str_) elif include_series and isinstance(val, pl.Series): return val.dtype == pl.String - return isinstance(val, Sequence) and _is_iterable_of(val, (Path, str)) + return ( + not isinstance(val, bytes) + and isinstance(val, Sequence) + and _is_iterable_of(val, (Path, str)) + ) def is_bool_sequence( diff --git a/py-polars/polars/io/csv/functions.py b/py-polars/polars/io/csv/functions.py index 0483d058180b..a8bad651c47b 100644 --- a/py-polars/polars/io/csv/functions.py +++ b/py-polars/polars/io/csv/functions.py @@ -271,6 +271,12 @@ def read_csv( ) raise ValueError(msg) + if schema_overrides is not None and not isinstance( + schema_overrides, (dict, Sequence) + ): + msg = "`schema_overrides` should be of type list or dict" + raise TypeError(msg) + if ( use_pyarrow and schema_overrides is None @@ -435,32 +441,42 @@ def read_csv( schema_overrides_is_list = isinstance(schema_overrides, Sequence) encoding_supported_in_lazy = encoding in {"utf8", "utf8-lossy"} - if ( + force_new_streaming = os.getenv("POLARS_FORCE_NEW_STREAMING") == "1" + + if force_new_streaming or ( # Check that it is not a BytesIO object isinstance(v := source, (str, Path)) - ) and ( - # HuggingFace only for now ⊂( ◜◒◝ )⊃ - str(v).startswith("hf://") - # Also dispatch on FORCE_ASYNC, so that this codepath gets run - # through by our test suite during CI. - or ( - os.getenv("POLARS_FORCE_ASYNC") == "1" - and not schema_overrides_is_list - and encoding_supported_in_lazy + and ( + # HuggingFace only for now ⊂( ◜◒◝ )⊃ + str(v).startswith("hf://") + # Also dispatch on FORCE_ASYNC, so that this codepath gets run + # through by our test suite during CI. 
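+        # (For instance, a test run with POLARS_FORCE_ASYNC=1 set in the
+        # environment can take this branch.)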
+ or ( + os.getenv("POLARS_FORCE_ASYNC") == "1" + and not schema_overrides_is_list + and encoding_supported_in_lazy + ) + # TODO: We can't dispatch this for all paths due to a few reasons: + # * `scan_csv` does not support compressed files + # * The `storage_options` configuration keys are different between + # fsspec and object_store (would require a breaking change) ) - # TODO: We can't dispatch this for all paths due to a few reasons: - # * `scan_csv` does not support compressed files - # * The `storage_options` configuration keys are different between - # fsspec and object_store (would require a breaking change) ): - source = normalize_filepath(v, check_not_directory=False) - - if schema_overrides_is_list: - msg = "passing a list to `schema_overrides` is unsupported for hf:// paths" - raise ValueError(msg) - if not encoding_supported_in_lazy: - msg = f"unsupported encoding {encoding} for hf:// paths" - raise ValueError(msg) + if isinstance(source, (str, Path)): + source = normalize_filepath(source, check_not_directory=False) + elif is_path_or_str_sequence(source, allow_str=False): + source = [ # type: ignore[assignment] + normalize_filepath(source, check_not_directory=False) + for source in source + ] + + if not force_new_streaming: + if schema_overrides_is_list: + msg = "passing a list to `schema_overrides` is unsupported for hf:// paths" + raise ValueError(msg) + if not encoding_supported_in_lazy: + msg = f"unsupported encoding {encoding} for hf:// paths" + raise ValueError(msg) lf = _scan_csv_impl( source, diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py index 226a4e31e2b9..c1e3285ea3a6 100644 --- a/py-polars/tests/unit/io/test_csv.py +++ b/py-polars/tests/unit/io/test_csv.py @@ -125,6 +125,7 @@ def test_infer_schema_false() -> None: assert df.dtypes == [pl.String, pl.String, pl.String] +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch def test_csv_null_values() -> None: csv = textwrap.dedent( """\ @@ -362,6 +363,7 @@ def test_datetime_parsing_default_formats() -> None: assert df.dtypes == [pl.Datetime, pl.Datetime, pl.Datetime] +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch def test_partial_dtype_overwrite() -> None: csv = textwrap.dedent( """\ @@ -375,6 +377,7 @@ def test_partial_dtype_overwrite() -> None: assert df.dtypes == [pl.String, pl.Int64, pl.Int64] +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch def test_dtype_overwrite_with_column_name_selection() -> None: csv = textwrap.dedent( """\ @@ -388,6 +391,7 @@ def test_dtype_overwrite_with_column_name_selection() -> None: assert df.dtypes == [pl.String, pl.Int32, pl.Int64] +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch def test_dtype_overwrite_with_column_idx_selection() -> None: csv = textwrap.dedent( """\ @@ -440,6 +444,7 @@ def test_read_csv_columns_argument( assert df.columns == col_out +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch def test_read_csv_buffer_ownership() -> None: bts = b"\xf0\x9f\x98\x80,5.55,333\n\xf0\x9f\x98\x86,-5.0,666" buf = io.BytesIO(bts) @@ -455,6 +460,7 @@ def test_read_csv_buffer_ownership() -> None: assert buf.read() == bts +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch @pytest.mark.write_disk def test_read_csv_encoding(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -485,6 +491,7 @@ def test_read_csv_encoding(tmp_path: Path) -> None: ) +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch def test_column_rename_and_dtype_overwrite() -> None: 
csv = textwrap.dedent( """\ @@ -861,6 +868,7 @@ def test_csv_date_dtype_ignore_errors() -> None: assert_frame_equal(out, expected) +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch def test_csv_globbing(io_files_path: Path) -> None: path = io_files_path / "foods*.csv" df = pl.read_csv(path) @@ -1481,6 +1489,7 @@ def test_csv_categorical_lifetime() -> None: assert (df["a"] == df["b"]).to_list() == [False, False, None] +@pytest.mark.may_fail_auto_streaming def test_csv_categorical_categorical_merge() -> None: N = 50 f = io.BytesIO() @@ -2174,6 +2183,7 @@ def test_csv_float_decimal() -> None: pl.read_csv(floats, decimal_comma=True) +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch def test_fsspec_not_available() -> None: with pytest.MonkeyPatch.context() as mp: mp.setenv("POLARS_FORCE_ASYNC", "0") @@ -2188,6 +2198,7 @@ def test_fsspec_not_available() -> None: ) +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch def test_read_csv_dtypes_deprecated() -> None: csv = textwrap.dedent( """\ @@ -2245,6 +2256,7 @@ def test_write_csv_raise_on_non_utf8_17328( df_no_lists.write_csv((tmp_path / "dangling.csv").open("w", encoding="gbk")) +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch @pytest.mark.write_disk def test_write_csv_appending_17543(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -2333,7 +2345,8 @@ def test_csv_read_time_dtype(tmp_path: Path) -> None: ) -def test_csv_read_time_dtype_overwrite(tmp_path: Path) -> None: +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch +def test_csv_read_time_dtype_overwrite() -> None: df = pl.Series("time", [0]).cast(pl.Time()).to_frame() assert_frame_equal( diff --git a/py-polars/tests/unit/io/test_scan.py b/py-polars/tests/unit/io/test_scan.py index cf78190f2380..e85b63ceb303 100644 --- a/py-polars/tests/unit/io/test_scan.py +++ b/py-polars/tests/unit/io/test_scan.py @@ -1,6 +1,7 @@ from __future__ import annotations import io +import os from dataclasses import dataclass from datetime import datetime from functools import partial @@ -30,6 +31,12 @@ def _enable_force_async(monkeypatch: pytest.MonkeyPatch) -> None: def _assert_force_async(capfd: Any, data_file_extension: str) -> None: + if ( + os.getenv("POLARS_AUTO_NEW_STREAMING", os.getenv("POLARS_FORCE_NEW_STREAMING")) + == "1" + ): + return + """Calls `capfd.readouterr`, consuming the captured output so far.""" if data_file_extension == ".ndjson": return