diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 4d63de75a5..4115fa9ab4 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -272,18 +272,19 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) - val COMET_EXEC_SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] = conf( - s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.codec") - .doc( - "The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported. " + - "Compression can be disabled by setting spark.shuffle.compress=false.") - .stringConf - .checkValues(Set("zstd")) - .createWithDefault("zstd") + val COMET_EXEC_SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.codec") + .doc( + "The codec of Comet native shuffle used to compress shuffle data. lz4, zstd, and " + + "snappy are supported. Compression can be disabled by setting " + + "spark.shuffle.compress=false.") + .stringConf + .checkValues(Set("zstd", "lz4", "snappy")) + .createWithDefault("lz4") - val COMET_EXEC_SHUFFLE_COMPRESSION_LEVEL: ConfigEntry[Int] = - conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.level") - .doc("The compression level to use when compression shuffle files.") + val COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL: ConfigEntry[Int] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.zstd.level") + .doc("The compression level to use when compressing shuffle files with zstd.") .intConf .createWithDefault(1) diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 20923b93ae..d78e6111df 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -50,8 +50,8 @@ Comet provides the following configuration settings. | spark.comet.exec.memoryPool | The type of memory pool to be used for Comet native execution. Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', 'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global', By default, this config is 'greedy_task_shared'. | greedy_task_shared | | spark.comet.exec.project.enabled | Whether to enable project by default. | true | | spark.comet.exec.replaceSortMergeJoin | Experimental feature to force Spark to replace SortMergeJoin with ShuffledHashJoin for improved performance. This feature is not stable yet. For more information, refer to the Comet Tuning Guide (https://datafusion.apache.org/comet/user-guide/tuning.html). | false | -| spark.comet.exec.shuffle.compression.codec | The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported. Compression can be disabled by setting spark.shuffle.compress=false. | zstd | -| spark.comet.exec.shuffle.compression.level | The compression level to use when compression shuffle files. | 1 | +| spark.comet.exec.shuffle.compression.codec | The codec of Comet native shuffle used to compress shuffle data. lz4, zstd, and snappy are supported. Compression can be disabled by setting spark.shuffle.compress=false. | lz4 | +| spark.comet.exec.shuffle.compression.zstd.level | The compression level to use when compressing shuffle files with zstd. | 1 | | spark.comet.exec.shuffle.enabled | Whether to enable Comet native shuffle. Note that this requires setting 'spark.shuffle.manager' to 'org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager'. 
'spark.shuffle.manager' must be set before starting the Spark application and cannot be changed during the application. | true | | spark.comet.exec.sort.enabled | Whether to enable sort by default. | true | | spark.comet.exec.sortMergeJoin.enabled | Whether to enable sortMergeJoin by default. | true | diff --git a/native/Cargo.lock b/native/Cargo.lock index bbc0ff97a9..1c44e3cc52 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 4 +version = 3 [[package]] name = "addr2line" @@ -346,7 +346,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", ] [[package]] @@ -903,6 +903,7 @@ dependencies = [ "lazy_static", "log", "log4rs", + "lz4_flex", "mimalloc", "num", "once_cell", @@ -914,6 +915,7 @@ dependencies = [ "regex", "serde", "simd-adler32", + "snap", "tempfile", "thiserror", "tokio", @@ -1168,7 +1170,7 @@ version = "44.0.0" source = "git+https://github.com/apache/datafusion.git?rev=44.0.0-rc2#3cc3fca31e6edc2d953e663bfd7f856bcb70d8c4" dependencies = [ "quote", - "syn 2.0.92", + "syn 2.0.93", ] [[package]] @@ -1333,7 +1335,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", ] [[package]] @@ -1473,7 +1475,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", ] [[package]] @@ -1746,7 +1748,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", ] [[package]] @@ -2556,7 +2558,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", ] [[package]] @@ -2778,7 +2780,7 @@ checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", ] [[package]] @@ -2868,7 +2870,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", ] [[package]] @@ -2895,7 +2897,7 @@ checksum = "da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", ] [[package]] @@ -2932,7 +2934,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.92", + "syn 2.0.93", ] [[package]] @@ -2977,9 +2979,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.92" +version = "2.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ae51629bf965c5c098cc9e87908a3df5301051a9e087d6f9bef5c9771ed126" +checksum = "9c786062daee0d6db1132800e623df74274a0a87322d8e183338e01b3d98d058" dependencies = [ "proc-macro2", "quote", @@ -2994,7 +2996,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", ] [[package]] @@ -3027,7 +3029,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", ] [[package]] @@ -3100,7 +3102,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", ] [[package]] @@ -3122,7 +3124,7 @@ checksum = 
"395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", ] [[package]] @@ -3276,7 +3278,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", "wasm-bindgen-shared", ] @@ -3298,7 +3300,7 @@ checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3561,7 +3563,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", "synstructure", ] @@ -3583,7 +3585,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", ] [[package]] @@ -3603,7 +3605,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", "synstructure", ] @@ -3626,7 +3628,7 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.92", + "syn 2.0.93", ] [[package]] diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 5089e67a03..8937236dda 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -52,6 +52,9 @@ serde = { version = "1", features = ["derive"] } lazy_static = "1.4.0" prost = "0.12.1" jni = "0.21" +snap = "1.1" +# we disable default features in lz4_flex to force the use of the faster unsafe encoding and decoding implementation +lz4_flex = { version = "0.11.3", default-features = false } zstd = "0.11" rand = { workspace = true} num = { workspace = true } diff --git a/native/core/benches/row_columnar.rs b/native/core/benches/row_columnar.rs index 60b41330e3..a62574111b 100644 --- a/native/core/benches/row_columnar.rs +++ b/native/core/benches/row_columnar.rs @@ -19,6 +19,7 @@ use arrow::datatypes::DataType as ArrowDataType; use comet::execution::shuffle::row::{ process_sorted_row_partition, SparkUnsafeObject, SparkUnsafeRow, }; +use comet::execution::shuffle::CompressionCodec; use criterion::{criterion_group, criterion_main, Criterion}; use tempfile::Builder; @@ -77,6 +78,7 @@ fn benchmark(c: &mut Criterion) { false, 0, None, + &CompressionCodec::Zstd(1), ) .unwrap(); }); diff --git a/native/core/benches/shuffle_writer.rs b/native/core/benches/shuffle_writer.rs index 865ca73b4a..0d22c62cc2 100644 --- a/native/core/benches/shuffle_writer.rs +++ b/native/core/benches/shuffle_writer.rs @@ -35,23 +35,52 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function("shuffle_writer: encode (no compression))", |b| { let batch = create_batch(8192, true); let mut buffer = vec![]; - let mut cursor = Cursor::new(&mut buffer); let ipc_time = Time::default(); - b.iter(|| write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::None, &ipc_time)); + b.iter(|| { + buffer.clear(); + let mut cursor = Cursor::new(&mut buffer); + write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::None, &ipc_time) + }); + }); + group.bench_function("shuffle_writer: encode and compress (snappy)", |b| { + let batch = create_batch(8192, true); + let mut buffer = vec![]; + let ipc_time = Time::default(); + b.iter(|| { + buffer.clear(); + let mut cursor = Cursor::new(&mut buffer); + write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::Snappy, &ipc_time) + }); + }); + 
group.bench_function("shuffle_writer: encode and compress (lz4)", |b| { + let batch = create_batch(8192, true); + let mut buffer = vec![]; + let ipc_time = Time::default(); + b.iter(|| { + buffer.clear(); + let mut cursor = Cursor::new(&mut buffer); + write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::Lz4Frame, &ipc_time) + }); }); group.bench_function("shuffle_writer: encode and compress (zstd level 1)", |b| { let batch = create_batch(8192, true); let mut buffer = vec![]; - let mut cursor = Cursor::new(&mut buffer); let ipc_time = Time::default(); - b.iter(|| write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::Zstd(1), &ipc_time)); + b.iter(|| { + buffer.clear(); + let mut cursor = Cursor::new(&mut buffer); + write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::Zstd(1), &ipc_time) + }); }); group.bench_function("shuffle_writer: encode and compress (zstd level 6)", |b| { let batch = create_batch(8192, true); let mut buffer = vec![]; - let mut cursor = Cursor::new(&mut buffer); let ipc_time = Time::default(); - b.iter(|| write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::Zstd(6), &ipc_time)); + b.iter(|| { + buffer.clear(); + let mut cursor = Cursor::new(&mut buffer); + write_ipc_compressed(&batch, &mut cursor, &CompressionCodec::Zstd(6), &ipc_time) + }); }); group.bench_function("shuffle_writer: end to end", |b| { let ctx = SessionContext::new(); diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 7d8d577fe5..aaac7ec8ca 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -17,6 +17,7 @@ //! Define JNI APIs which can be called from Java/Scala. +use super::{serde, utils::SparkArrowConvert, CometMemoryPool}; use arrow::datatypes::DataType as ArrowDataType; use arrow_array::RecordBatch; use datafusion::{ @@ -40,8 +41,6 @@ use jni::{ use std::time::{Duration, Instant}; use std::{collections::HashMap, sync::Arc, task::Poll}; -use super::{serde, utils::SparkArrowConvert, CometMemoryPool}; - use crate::{ errors::{try_unwrap_or_throw, CometError, CometResult}, execution::{ @@ -54,6 +53,7 @@ use datafusion_comet_proto::spark_operator::Operator; use datafusion_common::ScalarValue; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use futures::stream::StreamExt; +use jni::objects::JByteBuffer; use jni::sys::JNI_FALSE; use jni::{ objects::GlobalRef, @@ -64,6 +64,7 @@ use std::sync::Mutex; use tokio::runtime::Runtime; use crate::execution::operators::ScanExec; +use crate::execution::shuffle::{read_ipc_compressed, CompressionCodec}; use crate::execution::spark_plan::SparkPlan; use log::info; use once_cell::sync::{Lazy, OnceCell}; @@ -147,7 +148,7 @@ impl PerTaskMemoryPool { /// Accept serialized query plan and return the address of the native query plan. /// # Safety -/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. +/// This function is inherently unsafe since it deals with raw pointers passed from JNI. #[no_mangle] pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( e: JNIEnv, @@ -444,7 +445,7 @@ fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(), CometEr /// Accept serialized query plan and the addresses of Arrow Arrays from Spark, /// then execute the query. Return addresses of arrow vector. /// # Safety -/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. +/// This function is inherently unsafe since it deals with raw pointers passed from JNI. 
#[no_mangle] pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( e: JNIEnv, @@ -618,7 +619,7 @@ fn get_execution_context<'a>(id: i64) -> &'a mut ExecutionContext { /// Used by Comet shuffle external sorter to write sorted records to disk. /// # Safety -/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. +/// This function is inherently unsafe since it deals with raw pointers passed from JNI. #[no_mangle] pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative( e: JNIEnv, @@ -632,6 +633,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative checksum_enabled: jboolean, checksum_algo: jint, current_checksum: jlong, + compression_codec: jstring, + compression_level: jint, ) -> jlongArray { try_unwrap_or_throw(&e, |mut env| unsafe { let data_types = convert_datatype_arrays(&mut env, serialized_datatypes)?; @@ -659,6 +662,18 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative Some(current_checksum as u32) }; + let compression_codec: String = env + .get_string(&JString::from_raw(compression_codec)) + .unwrap() + .into(); + + let compression_codec = match compression_codec.as_str() { + "zstd" => CompressionCodec::Zstd(compression_level), + "lz4" => CompressionCodec::Lz4Frame, + "snappy" => CompressionCodec::Snappy, + _ => CompressionCodec::Lz4Frame, + }; + let (written_bytes, checksum) = process_sorted_row_partition( row_num, batch_size as usize, @@ -670,6 +685,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative checksum_enabled, checksum_algo, current_checksum, + &compression_codec, )?; let checksum = if let Some(checksum) = checksum { @@ -703,3 +719,24 @@ pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative( Ok(()) }) } + +#[no_mangle] +/// Used by Comet native shuffle reader +/// # Safety +/// This function is inherently unsafe since it deals with raw pointers passed from JNI. 
+pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock( + e: JNIEnv, + _class: JClass, + byte_buffer: JByteBuffer, + length: jint, + array_addrs: jlongArray, + schema_addrs: jlongArray, +) -> jlong { + try_unwrap_or_throw(&e, |mut env| { + let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?; + let length = length as usize; + let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; + let batch = read_ipc_compressed(slice)?; + prepare_output(&mut env, array_addrs, schema_addrs, batch, false) + }) +} diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index da452c2f15..294922f2f1 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1067,9 +1067,11 @@ impl PhysicalPlanner { let codec = match writer.codec.try_into() { Ok(SparkCompressionCodec::None) => Ok(CompressionCodec::None), + Ok(SparkCompressionCodec::Snappy) => Ok(CompressionCodec::Snappy), Ok(SparkCompressionCodec::Zstd) => { Ok(CompressionCodec::Zstd(writer.compression_level)) } + Ok(SparkCompressionCodec::Lz4) => Ok(CompressionCodec::Lz4Frame), _ => Err(ExecutionError::GeneralError(format!( "Unsupported shuffle compression codec: {:?}", writer.codec diff --git a/native/core/src/execution/shuffle/mod.rs b/native/core/src/execution/shuffle/mod.rs index 8111f5eede..178aff1fad 100644 --- a/native/core/src/execution/shuffle/mod.rs +++ b/native/core/src/execution/shuffle/mod.rs @@ -19,4 +19,6 @@ mod list; mod map; pub mod row; mod shuffle_writer; -pub use shuffle_writer::{write_ipc_compressed, CompressionCodec, ShuffleWriterExec}; +pub use shuffle_writer::{ + read_ipc_compressed, write_ipc_compressed, CompressionCodec, ShuffleWriterExec, +}; diff --git a/native/core/src/execution/shuffle/row.rs b/native/core/src/execution/shuffle/row.rs index 405f642163..9037bd7943 100644 --- a/native/core/src/execution/shuffle/row.rs +++ b/native/core/src/execution/shuffle/row.rs @@ -3297,6 +3297,7 @@ pub fn process_sorted_row_partition( // this is the initial checksum for this method, as it also gets updated iteratively // inside the loop within the method across batches. initial_checksum: Option<u32>, + codec: &CompressionCodec, ) -> Result<(i64, Option<u32>), CometError> { // TODO: We can tune this parameter automatically based on row size and cache size.
let row_step = 10; @@ -3359,9 +3360,7 @@ pub fn process_sorted_row_partition( // we do not collect metrics in Native_writeSortedFileNative let ipc_time = Time::default(); - // compression codec is not configurable for CometBypassMergeSortShuffleWriter - let codec = CompressionCodec::Zstd(1); - written += write_ipc_compressed(&batch, &mut cursor, &codec, &ipc_time)?; + written += write_ipc_compressed(&batch, &mut cursor, codec, &ipc_time)?; if let Some(checksum) = &mut current_checksum { checksum.update(&mut cursor)?; diff --git a/native/core/src/execution/shuffle/shuffle_writer.rs b/native/core/src/execution/shuffle/shuffle_writer.rs index f3fa685b88..e6679d13d4 100644 --- a/native/core/src/execution/shuffle/shuffle_writer.rs +++ b/native/core/src/execution/shuffle/shuffle_writer.rs @@ -21,6 +21,7 @@ use crate::{ common::bit::ceil, errors::{CometError, CometResult}, }; +use arrow::ipc::reader::StreamReader; use arrow::{datatypes::*, ipc::writer::StreamWriter}; use async_trait::async_trait; use bytes::Buf; @@ -312,7 +313,7 @@ impl PartitionBuffer { repart_timer.stop(); if self.num_active_rows >= self.batch_size { - let flush = self.flush(&metrics.ipc_time); + let flush = self.flush(metrics); if let Err(e) = flush { return AppendRowStatus::MemDiff(Err(e)); } @@ -330,7 +331,7 @@ impl PartitionBuffer { } /// flush active data into frozen bytes - fn flush(&mut self, ipc_time: &Time) -> Result<isize> { + fn flush(&mut self, metrics: &ShuffleRepartitionerMetrics) -> Result<isize> { if self.num_active_rows == 0 { return Ok(0); } @@ -340,14 +341,24 @@ impl PartitionBuffer { let active = std::mem::take(&mut self.active); let num_rows = self.num_active_rows; self.num_active_rows = 0; + + let mut mempool_timer = metrics.mempool_time.timer(); self.reservation.try_shrink(self.active_slots_mem_size)?; + mempool_timer.stop(); + let mut repart_timer = metrics.repart_time.timer(); let frozen_batch = make_batch(Arc::clone(&self.schema), active, num_rows)?; + repart_timer.stop(); let frozen_capacity_old = self.frozen.capacity(); let mut cursor = Cursor::new(&mut self.frozen); cursor.seek(SeekFrom::End(0))?; - write_ipc_compressed(&frozen_batch, &mut cursor, &self.codec, ipc_time)?; + write_ipc_compressed( + &frozen_batch, + &mut cursor, + &self.codec, + &metrics.encode_time, + )?; mem_diff += (self.frozen.capacity() - frozen_capacity_old) as isize; Ok(mem_diff) @@ -652,7 +663,7 @@ struct ShuffleRepartitionerMetrics { mempool_time: Time, /// Time encoding batches to IPC format - ipc_time: Time, + encode_time: Time, /// Time spent writing to disk. Maps to "shuffleWriteTime" in Spark SQL Metrics.
write_time: Time, @@ -676,7 +687,7 @@ impl ShuffleRepartitionerMetrics { baseline: BaselineMetrics::new(metrics, partition), repart_time: MetricBuilder::new(metrics).subset_time("repart_time", partition), mempool_time: MetricBuilder::new(metrics).subset_time("mempool_time", partition), - ipc_time: MetricBuilder::new(metrics).subset_time("ipc_time", partition), + encode_time: MetricBuilder::new(metrics).subset_time("encode_time", partition), write_time: MetricBuilder::new(metrics).subset_time("write_time", partition), input_batches: MetricBuilder::new(metrics).counter("input_batches", partition), spill_count: MetricBuilder::new(metrics).spill_count(partition), @@ -790,6 +801,8 @@ impl ShuffleRepartitioner { Partitioning::Hash(exprs, _) => { let (partition_starts, shuffled_partition_ids): (Vec<u32>, Vec<u32>) = { let mut timer = self.metrics.repart_time.timer(); + + // evaluate partition expressions let arrays = exprs .iter() .map(|expr| expr.evaluate(&input)?.into_array(input.num_rows())) @@ -923,7 +936,7 @@ impl ShuffleRepartitioner { let mut output_batches: Vec<Vec<u8>> = vec![vec![]; num_output_partitions]; let mut offsets = vec![0; num_output_partitions + 1]; for i in 0..num_output_partitions { - buffered_partitions[i].flush(&self.metrics.ipc_time)?; + buffered_partitions[i].flush(&self.metrics)?; output_batches[i] = std::mem::take(&mut buffered_partitions[i].frozen); } @@ -1023,20 +1036,19 @@ impl ShuffleRepartitioner { } let mut timer = self.metrics.write_time.timer(); - let spillfile = self .runtime .disk_manager .create_tmp_file("shuffle writer spill")?; + timer.stop(); + let offsets = spill_into( &mut self.buffered_partitions, spillfile.path(), self.num_output_partitions, - &self.metrics.ipc_time, + &self.metrics, )?; - timer.stop(); - let mut spills = self.spills.lock().await; let used = self.reservation.size(); self.metrics.spill_count.add(1); @@ -1107,16 +1119,18 @@ fn spill_into( buffered_partitions: &mut [PartitionBuffer], path: &Path, num_output_partitions: usize, - ipc_time: &Time, + metrics: &ShuffleRepartitionerMetrics, ) -> Result<Vec<u64>> { let mut output_batches: Vec<Vec<u8>> = vec![vec![]; num_output_partitions]; for i in 0..num_output_partitions { - buffered_partitions[i].flush(ipc_time)?; + buffered_partitions[i].flush(metrics)?; output_batches[i] = std::mem::take(&mut buffered_partitions[i].frozen); } let path = path.to_owned(); + let mut write_timer = metrics.write_time.timer(); + let mut offsets = vec![0; num_output_partitions + 1]; let mut spill_data = OpenOptions::new() .write(true) @@ -1130,6 +1144,8 @@ fn spill_into( spill_data.write_all(&output_batches[i])?; output_batches[i].clear(); } + write_timer.stop(); + // add one extra offset at last to ease partition length computation offsets[num_output_partitions] = spill_data.stream_position()?; Ok(offsets) @@ -1549,7 +1565,9 @@ impl Checksum { #[derive(Debug, Clone)] pub enum CompressionCodec { None, + Lz4Frame, Zstd(i32), + Snappy, } /// Writes given record batch as Arrow IPC bytes into given writer.
@@ -1567,17 +1585,41 @@ pub fn write_ipc_compressed( let mut timer = ipc_time.timer(); let start_pos = output.stream_position()?; - // write ipc_length placeholder - output.write_all(&[0u8; 8])?; + // seek past ipc_length placeholder + output.seek_relative(8)?; + + // write number of columns because JVM side needs to know how many addresses to allocate + let field_count = batch.schema().fields().len(); + output.write_all(&field_count.to_le_bytes())?; let output = match codec { CompressionCodec::None => { + output.write_all(b"NONE")?; let mut arrow_writer = StreamWriter::try_new(output, &batch.schema())?; arrow_writer.write(batch)?; arrow_writer.finish()?; arrow_writer.into_inner()? } + CompressionCodec::Snappy => { + output.write_all(b"SNAP")?; + let mut wtr = snap::write::FrameEncoder::new(output); + let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?; + arrow_writer.write(batch)?; + arrow_writer.finish()?; + wtr.into_inner() + .map_err(|e| DataFusionError::Execution(format!("snappy compression error: {}", e)))? + } + CompressionCodec::Lz4Frame => { + output.write_all(b"LZ4_")?; + let mut wtr = lz4_flex::frame::FrameEncoder::new(output); + let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?; + arrow_writer.write(batch)?; + arrow_writer.finish()?; + wtr.finish() + .map_err(|e| DataFusionError::Execution(format!("lz4 compression error: {}", e)))? + } CompressionCodec::Zstd(level) => { + output.write_all(b"ZSTD")?; let encoder = zstd::Encoder::new(output, *level)?; let mut arrow_writer = StreamWriter::try_new(encoder, &batch.schema())?; arrow_writer.write(batch)?; @@ -1590,6 +1632,13 @@ // fill ipc length let end_pos = output.stream_position()?; let ipc_length = end_pos - start_pos - 8; + let max_size = i32::MAX as u64; + if ipc_length > max_size { + return Err(DataFusionError::Execution(format!( + "Shuffle block size {ipc_length} exceeds maximum size of {max_size}. \ + Try reducing batch size or increasing compression level" + ))); + } // fill ipc length output.seek(SeekFrom::Start(start_pos))?; @@ -1601,6 +1650,33 @@ Ok((end_pos - start_pos) as usize) } +pub fn read_ipc_compressed(bytes: &[u8]) -> Result<RecordBatch> { + match &bytes[0..4] { + b"SNAP" => { + let decoder = snap::read::FrameDecoder::new(&bytes[4..]); + let mut reader = StreamReader::try_new(decoder, None)?; + reader.next().unwrap().map_err(|e| e.into()) + } + b"LZ4_" => { + let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]); + let mut reader = StreamReader::try_new(decoder, None)?; + reader.next().unwrap().map_err(|e| e.into()) + } + b"ZSTD" => { + let decoder = zstd::Decoder::new(&bytes[4..])?; + let mut reader = StreamReader::try_new(decoder, None)?; + reader.next().unwrap().map_err(|e| e.into()) + } + b"NONE" => { + let mut reader = StreamReader::try_new(&bytes[4..], None)?; + reader.next().unwrap().map_err(|e| e.into()) + } + _ => Err(DataFusionError::Execution( + "Failed to decode batch: invalid compression codec".to_string(), + )), + } +} + /// A stream that yields no record batches which represent end of output.
pub struct EmptyStream { /// Schema representing the data @@ -1650,18 +1726,24 @@ mod test { #[test] #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` - fn write_ipc_zstd() { + fn roundtrip_ipc() { let batch = create_batch(8192); - let mut output = vec![]; - let mut cursor = Cursor::new(&mut output); - write_ipc_compressed( - &batch, - &mut cursor, - &CompressionCodec::Zstd(1), - &Time::default(), - ) - .unwrap(); - assert_eq!(40218, output.len()); + for codec in &[ + CompressionCodec::None, + CompressionCodec::Zstd(1), + CompressionCodec::Snappy, + CompressionCodec::Lz4Frame, + ] { + let mut output = vec![]; + let mut cursor = Cursor::new(&mut output); + let length = + write_ipc_compressed(&batch, &mut cursor, codec, &Time::default()).unwrap(); + assert_eq!(length, output.len()); + + let ipc_without_length_prefix = &output[16..]; + let batch2 = read_ipc_compressed(ipc_without_length_prefix).unwrap(); + assert_eq!(batch, batch2); + } } #[test] diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 5cb2802da8..a3480086c7 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -85,6 +85,8 @@ message Limit { enum CompressionCodec { None = 0; Zstd = 1; + Lz4 = 2; + Snappy = 3; } message ShuffleWriter { diff --git a/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java index cc44955705..1e3762a6c4 100644 --- a/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java +++ b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java @@ -107,6 +107,8 @@ public final class CometShuffleExternalSorter implements CometShuffleChecksumSup private final long[] partitionChecksums; private final String checksumAlgorithm; + private final String compressionCodec; + private final int compressionLevel; // The memory allocator for this sorter. It is used to allocate/free memory pages for this sorter. // Because we need to allocate off-heap memory regardless of configured Spark memory mode @@ -153,6 +155,9 @@ public CometShuffleExternalSorter( this.peakMemoryUsedBytes = getMemoryUsage(); this.partitionChecksums = createPartitionChecksums(numPartitions, conf); this.checksumAlgorithm = getChecksumAlgorithm(conf); + this.compressionCodec = CometConf$.MODULE$.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC().get(); + this.compressionLevel = + (int) CometConf$.MODULE$.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL().get(); this.initialSize = initialSize; @@ -556,7 +561,9 @@ public void writeSortedFileNative(boolean isLastFile) throws IOException { spillInfo.file, rowPartition, writeMetricsToUse, - preferDictionaryRatio); + preferDictionaryRatio, + compressionCodec, + compressionLevel); spillInfo.partitionLengths[currentPartition] = written; // Store the checksum for the current partition. 
@@ -578,7 +585,13 @@ public void writeSortedFileNative(boolean isLastFile) throws IOException { if (currentPartition != -1) { long written = doSpilling( - dataTypes, spillInfo.file, rowPartition, writeMetricsToUse, preferDictionaryRatio); + dataTypes, + spillInfo.file, + rowPartition, + writeMetricsToUse, + preferDictionaryRatio, + compressionCodec, + compressionLevel); spillInfo.partitionLengths[currentPartition] = written; synchronized (spills) { diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java index dcb9d99d37..006e8ce971 100644 --- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java @@ -103,6 +103,8 @@ public final class CometDiskBlockWriter { private long totalWritten = 0L; private boolean initialized = false; private final int columnarBatchSize; + private final String compressionCodec; + private final int compressionLevel; private final boolean isAsync; private final int asyncThreadNum; private final ExecutorService threadPool; @@ -153,6 +155,9 @@ public final class CometDiskBlockWriter { this.threadPool = threadPool; this.columnarBatchSize = (int) CometConf$.MODULE$.COMET_COLUMNAR_SHUFFLE_BATCH_SIZE().get(); + this.compressionCodec = CometConf$.MODULE$.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC().get(); + this.compressionLevel = + (int) CometConf$.MODULE$.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL().get(); this.numElementsForSpillThreshold = (int) CometConf$.MODULE$.COMET_COLUMNAR_SHUFFLE_SPILL_THRESHOLD().get(); @@ -397,7 +402,14 @@ long doSpilling(boolean isLast) throws IOException { synchronized (file) { outputRecords += rowPartition.getNumRows(); written = - doSpilling(dataTypes, file, rowPartition, writeMetricsToUse, preferDictionaryRatio); + doSpilling( + dataTypes, + file, + rowPartition, + writeMetricsToUse, + preferDictionaryRatio, + compressionCodec, + compressionLevel); } // Update metrics diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java index 3dc86b05bb..a4f09b4158 100644 --- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java @@ -171,7 +171,9 @@ protected long doSpilling( File file, RowPartition rowPartition, ShuffleWriteMetricsReporter writeMetricsToUse, - double preferDictionaryRatio) { + double preferDictionaryRatio, + String compressionCodec, + int compressionLevel) { long[] addresses = rowPartition.getRowAddresses(); int[] sizes = rowPartition.getRowSizes(); @@ -190,7 +192,9 @@ protected long doSpilling( batchSize, checksumEnabled, checksumAlgo, - currentChecksum); + currentChecksum, + compressionCodec, + compressionLevel); long written = results[0]; checksum = results[1]; diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index e5728009e4..dbcab15b4f 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -19,6 +19,8 @@ package org.apache.comet +import java.nio.ByteBuffer + import org.apache.spark.CometTaskMemoryManager import org.apache.spark.sql.comet.CometMetricNode @@ -118,9 +120,14 @@ class Native extends 
NativeBase { * @param currentChecksum * the current checksum of the file. As the checksum is computed incrementally, this is used * to resume the computation of checksum for previous written data. + * @param compressionCodec + * the compression codec + * @param compressionLevel + * the compression level * @return * [the number of bytes written to disk, the checksum] */ + // scalastyle:off @native def writeSortedFileNative( addresses: Array[Long], rowSizes: Array[Int], @@ -130,7 +137,10 @@ class Native extends NativeBase { batchSize: Int, checksumEnabled: Boolean, checksumAlgo: Int, - currentChecksum: Long): Array[Long] + currentChecksum: Long, + compressionCodec: String, + compressionLevel: Int): Array[Long] + // scalastyle:on /** * Sorts partition ids of Spark unsafe rows in place. Used by Comet shuffle external sorter. @@ -141,4 +151,22 @@ class Native extends NativeBase { * the size of the array. */ @native def sortRowPartitionsNative(addr: Long, size: Long): Unit + + /** + * Decompress and decode a native shuffle block. + * @param shuffleBlock + * the encoded and compressed shuffle block. + * @param length + * the limit of the byte buffer. + * @param arrayAddrs + * the addresses of the Arrow Array structures to populate with the decoded batch. + * @param schemaAddrs + * the addresses of the Arrow Schema structures to populate with the decoded batch. + */ + @native def decodeShuffleBlock( + shuffleBlock: ByteBuffer, + length: Int, + arrayAddrs: Array[Long], + schemaAddrs: Array[Long]): Long + } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala index a26fa28c8b..53370a03b7 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala @@ -132,10 +132,11 @@ object CometMetricNode { def shuffleMetrics(sc: SparkContext): Map[String, SQLMetric] = { Map( - "elapsed_compute" -> SQLMetrics.createNanoTimingMetric(sc, "native shuffle time"), + "elapsed_compute" -> SQLMetrics.createNanoTimingMetric(sc, "native shuffle writer time"), "mempool_time" -> SQLMetrics.createNanoTimingMetric(sc, "memory pool time"), "repart_time" -> SQLMetrics.createNanoTimingMetric(sc, "repartition time"), - "ipc_time" -> SQLMetrics.createNanoTimingMetric(sc, "encoding and compression time"), + "encode_time" -> SQLMetrics.createNanoTimingMetric(sc, "encoding and compression time"), + "decode_time" -> SQLMetrics.createNanoTimingMetric(sc, "decoding and decompression time"), "spill_count" -> SQLMetrics.createMetric(sc, "number of spills"), "spilled_bytes" -> SQLMetrics.createMetric(sc, "spilled bytes"), "input_batches" -> SQLMetrics.createMetric(sc, "number of input batches")) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala index 74c6559504..1283a745a6 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala @@ -25,8 +25,14 @@ import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, Task import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.SerializerManager -import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader, ShuffleReadMetricsReporter} -import org.apache.spark.storage.{BlockId, 
BlockManager, BlockManagerId, ShuffleBlockFetcherIterator} +import org.apache.spark.shuffle.BaseShuffleHandle +import org.apache.spark.shuffle.ShuffleReader +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.storage.BlockId +import org.apache.spark.storage.BlockManager +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.ShuffleBlockFetcherIterator import org.apache.spark.util.CompletionIterator /** @@ -79,7 +85,7 @@ class CometBlockStoreShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - var currentReadIterator: ArrowReaderIterator = null + var currentReadIterator: NativeBatchDecoderIterator = null // Closes last read iterator after the task is finished. // We need to close read iterator during iterating input streams, @@ -91,18 +97,16 @@ class CometBlockStoreShuffleReader[K, C]( } } - val recordIter = fetchIterator - .flatMap { case (_, inputStream) => - IpcInputStreamIterator(inputStream, decompressingNeeded = true, context) - .flatMap { channel => - if (currentReadIterator != null) { - // Closes previous read iterator. - currentReadIterator.close() - } - currentReadIterator = new ArrowReaderIterator(channel, this.getClass.getSimpleName) - currentReadIterator.map((0, _)) // use 0 as key since it's not used - } - } + val recordIter: Iterator[(Int, ColumnarBatch)] = fetchIterator + .flatMap(blockIdAndStream => { + if (currentReadIterator != null) { + currentReadIterator.close() + } + currentReadIterator = + NativeBatchDecoderIterator(blockIdAndStream._2, context, dep.decodeTime) + currentReadIterator + }) + .map(b => (0, b)) // Update the context task metrics for each record read. 
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala index 7b1d1f1271..8c8aed28ee 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala @@ -25,6 +25,7 @@ import org.apache.spark.{Aggregator, Partitioner, ShuffleDependency, SparkEnv} import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleWriteProcessor +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType /** @@ -39,7 +40,8 @@ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( override val mapSideCombine: Boolean = false, override val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor, val shuffleType: ShuffleType = CometNativeShuffle, - val schema: Option[StructType] = None) + val schema: Option[StructType] = None, + val decodeTime: SQLMetric) extends ShuffleDependency[K, V, C]( _rdd, partitioner, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 3a11b8b28c..041411b3f0 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -238,7 +238,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec { partitioner = new Partitioner { override def numPartitions: Int = outputPartitioning.numPartitions override def getPartition(key: Any): Int = key.asInstanceOf[Int] - }) + }, + decodeTime = metrics("decode_time")) dependency } @@ -435,7 +436,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec { serializer, shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics), shuffleType = CometColumnarShuffle, - schema = Some(fromAttributes(outputAttributes))) + schema = Some(fromAttributes(outputAttributes)), + decodeTime = writeMetrics("decode_time")) dependency } @@ -481,7 +483,7 @@ class CometShuffleWriteProcessor( val detailedMetrics = Seq( "elapsed_compute", - "ipc_time", + "encode_time", "repart_time", "mempool_time", "input_batches", @@ -557,13 +559,16 @@ class CometShuffleWriteProcessor( if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) { val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match { case "zstd" => CompressionCodec.Zstd + case "lz4" => CompressionCodec.Lz4 + case "snappy" => CompressionCodec.Snappy case other => throw new UnsupportedOperationException(s"invalid codec: $other") } shuffleWriterBuilder.setCodec(codec) } else { shuffleWriterBuilder.setCodec(CompressionCodec.None) } - shuffleWriterBuilder.setCompressionLevel(CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_LEVEL.get) + shuffleWriterBuilder.setCompressionLevel( + CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get) outputPartitioning match { case _: HashPartitioning => diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala deleted file mode 100644 index aa40550488..0000000000 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.spark.sql.comet.execution.shuffle - -import java.io.{EOFException, InputStream} -import java.nio.{ByteBuffer, ByteOrder} -import java.nio.channels.{Channels, ReadableByteChannel} - -import org.apache.spark.TaskContext -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.LimitedInputStream - -case class IpcInputStreamIterator( - var in: InputStream, - decompressingNeeded: Boolean, - taskContext: TaskContext) - extends Iterator[ReadableByteChannel] - with Logging { - - private[execution] val channel: ReadableByteChannel = if (in != null) { - Channels.newChannel(in) - } else { - null - } - - private val ipcLengthsBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) - - // NOTE: - // since all ipcs are sharing the same input stream and channel, the second - // hasNext() must be called after the first ipc has been completely processed. - - private[execution] var consumed = true - private var finished = false - private var currentIpcLength = 0L - private var currentLimitedInputStream: LimitedInputStream = _ - - taskContext.addTaskCompletionListener[Unit](_ => { - closeInputStream() - }) - - override def hasNext: Boolean = { - if (in == null || finished) { - return false - } - - // If we've read the length of the next IPC, we don't need to read it again. - if (!consumed) { - return true - } - - if (currentLimitedInputStream != null) { - currentLimitedInputStream.skip(Int.MaxValue) - currentLimitedInputStream = null - } - - // Reads the length of IPC bytes - ipcLengthsBuf.clear() - while (ipcLengthsBuf.hasRemaining && channel.read(ipcLengthsBuf) >= 0) {} - - // If we reach the end of the stream, we are done, or if we read partial length - // then the stream is corrupted. 
- if (ipcLengthsBuf.hasRemaining) { - if (ipcLengthsBuf.position() == 0) { - finished = true - closeInputStream() - return false - } - throw new EOFException("Data corrupt: unexpected EOF while reading compressed ipc lengths") - } - - ipcLengthsBuf.flip() - currentIpcLength = ipcLengthsBuf.getLong - - // Skips empty IPC - if (currentIpcLength == 0) { - return hasNext - } - consumed = false - return true - } - - override def next(): ReadableByteChannel = { - if (!hasNext) { - throw new NoSuchElementException - } - assert(!consumed) - consumed = true - - val is = new LimitedInputStream(Channels.newInputStream(channel), currentIpcLength, false) - currentLimitedInputStream = is - - if (decompressingNeeded) { - ShuffleUtils.compressionCodecForShuffling match { - case Some(codec) => Channels.newChannel(codec.compressedInputStream(is)) - case _ => Channels.newChannel(is) - } - } else { - Channels.newChannel(is) - } - } - - private def closeInputStream(): Unit = - synchronized { - if (in != null) { - in.close() - in = null - } - } -} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala new file mode 100644 index 0000000000..2839c9bd8c --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle + +import java.io.{EOFException, InputStream} +import java.nio.{ByteBuffer, ByteOrder} +import java.nio.channels.{Channels, ReadableByteChannel} + +import org.apache.spark.TaskContext +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.Native +import org.apache.comet.vector.NativeUtil + +/** + * This iterator wraps a Spark input stream that is reading shuffle blocks generated by the Comet + * native ShuffleWriterExec and then calls native code to decompress and decode the shuffle blocks + * and use Arrow FFI to return the Arrow record batch. 
+ */ +case class NativeBatchDecoderIterator( + var in: InputStream, + taskContext: TaskContext, + decodeTime: SQLMetric) + extends Iterator[ColumnarBatch] { + + private var isClosed = false + private val longBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) + private val native = new Native() + private val nativeUtil = new NativeUtil() + private var currentBatch: ColumnarBatch = null + private var batch = fetchNext() + + import NativeBatchDecoderIterator.threadLocalDataBuf + + if (taskContext != null) { + taskContext.addTaskCompletionListener[Unit](_ => { + close() + }) + } + + private val channel: ReadableByteChannel = if (in != null) { + Channels.newChannel(in) + } else { + null + } + + def hasNext(): Boolean = { + if (channel == null || isClosed) { + return false + } + if (batch.isDefined) { + return true + } + + // Release the previous batch. + if (currentBatch != null) { + currentBatch.close() + currentBatch = null + } + + batch = fetchNext() + if (batch.isEmpty) { + close() + return false + } + true + } + + def next(): ColumnarBatch = { + if (!hasNext) { + throw new NoSuchElementException + } + + val nextBatch = batch.get + + currentBatch = nextBatch + batch = None + currentBatch + } + + private def fetchNext(): Option[ColumnarBatch] = { + if (channel == null || isClosed) { + return None + } + + // read compressed batch size from header + try { + longBuf.clear() + while (longBuf.hasRemaining && channel.read(longBuf) >= 0) {} + } catch { + case _: EOFException => + close() + return None + } + + // If we reach the end of the stream, we are done, or if we read partial length + // then the stream is corrupted. + if (longBuf.hasRemaining) { + if (longBuf.position() == 0) { + close() + return None + } + throw new EOFException("Data corrupt: unexpected EOF while reading compressed ipc lengths") + } + + // get compressed length (including headers) + longBuf.flip() + val compressedLength = longBuf.getLong + + // read field count from header + longBuf.clear() + while (longBuf.hasRemaining && channel.read(longBuf) >= 0) {} + if (longBuf.hasRemaining) { + throw new EOFException("Data corrupt: unexpected EOF while reading field count") + } + longBuf.flip() + val fieldCount = longBuf.getLong.toInt + + // read body + val bytesToRead = compressedLength - 8 + if (bytesToRead > Integer.MAX_VALUE) { + // very unlikely that shuffle block will reach 2GB + throw new IllegalStateException( + s"Native shuffle block size of $bytesToRead exceeds " + + s"maximum of ${Integer.MAX_VALUE}. 
Try reducing shuffle batch size.") + } + var dataBuf = threadLocalDataBuf.get() + if (dataBuf.capacity() < bytesToRead) { + val newCapacity = (bytesToRead * 2L).min(Integer.MAX_VALUE).toInt + dataBuf = ByteBuffer.allocateDirect(newCapacity) + threadLocalDataBuf.set(dataBuf) + } + dataBuf.clear() + dataBuf.limit(bytesToRead.toInt) + while (dataBuf.hasRemaining && channel.read(dataBuf) >= 0) {} + if (dataBuf.hasRemaining) { + throw new EOFException("Data corrupt: unexpected EOF while reading compressed batch") + } + + // make native call to decode batch + val startTime = System.nanoTime() + val batch = nativeUtil.getNextBatch( + fieldCount, + (arrayAddrs, schemaAddrs) => { + native.decodeShuffleBlock(dataBuf, bytesToRead.toInt, arrayAddrs, schemaAddrs) + }) + decodeTime.add(System.nanoTime() - startTime) + + batch + } + + def close(): Unit = { + synchronized { + if (!isClosed) { + if (currentBatch != null) { + currentBatch.close() + currentBatch = null + } + in.close() + isClosed = true + } + } + } +} + +object NativeBatchDecoderIterator { + private val threadLocalDataBuf: ThreadLocal[ByteBuffer] = ThreadLocal.withInitial(() => { + ByteBuffer.allocateDirect(128 * 1024) + }) +} diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index 6130e4cd58..13344c0ed5 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -1132,7 +1132,8 @@ class CometShuffleManagerSuite extends CometTestBase { partitioner = new Partitioner { override def numPartitions: Int = 50 override def getPartition(key: Any): Int = key.asInstanceOf[Int] - }) + }, + decodeTime = null) assert(CometShuffleManager.shouldBypassMergeSort(conf, dependency))
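
The configuration changes above can be exercised end to end from a Spark application. The sketch below is illustrative only: the configuration keys, the accepted codec values, and the CometShuffleManager class name are taken from this diff, while the SparkSession setup (application name, builder usage) is an assumption and presumes the Comet jars are already on the classpath.

// Minimal sketch: selecting a Comet shuffle compression codec (hypothetical application setup).
import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .appName("comet-shuffle-compression-example") // hypothetical app name
  .config(
    "spark.shuffle.manager",
    "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
  .config("spark.comet.exec.shuffle.enabled", "true")
  // lz4 is the new default; zstd and snappy are also accepted
  .config("spark.comet.exec.shuffle.compression.codec", "zstd")
  // only consulted when the codec is zstd
  .config("spark.comet.exec.shuffle.compression.zstd.level", "3")
  .getOrCreate()

// Setting spark.shuffle.compress=false disables shuffle compression entirely,
// regardless of the codec configured above.

As the diff shows, each block produced by write_ipc_compressed is framed as an 8-byte little-endian length, an 8-byte field count, a 4-byte codec tag (NONE, SNAP, LZ4_, or ZSTD), and the (optionally compressed) Arrow IPC stream; read_ipc_compressed and NativeBatchDecoderIterator consume the same framing on the read path.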