Deduplicate and standardize deserialization logic for streams (#13412)
* Add BatchDeserializer

* Fix formatting

* Remove unused enum value

* Update datafusion/core/src/datasource/file_format/mod.rs

---------

Co-authored-by: Mehmet Ozan Kabak <[email protected]>
alihan-synnada and ozankabak authored Nov 16, 2024
1 parent a09814a commit 06db9ed
Showing 5 changed files with 547 additions and 72 deletions.
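
For context, the `BatchDeserializer` machinery this commit adds to `datafusion/core/src/datasource/file_format/mod.rs` is used throughout the diff below. The following is a rough sketch of those interfaces, reconstructed from their call sites in this diff; the exact definitions (trait bounds, and details such as whether `digest` reports the number of bytes consumed) may differ in the actual module.

use arrow::array::RecordBatch;
use arrow_schema::ArrowError;

/// Progress report returned by `BatchDeserializer::next` (sketch).
#[derive(Debug, PartialEq)]
pub enum DeserializerOutput {
    /// A fully decoded batch is ready.
    RecordBatch(RecordBatch),
    /// More input must be digested before a batch can be produced.
    RequiresMoreData,
    /// `finish` was called and all buffered input has been drained.
    InputExhausted,
}

/// Incrementally turns chunks of raw input (e.g. `bytes::Bytes`) into
/// Arrow record batches (sketch).
pub trait BatchDeserializer<T>: Send {
    /// Buffer another chunk of input for deserialization.
    fn digest(&mut self, message: T);
    /// Attempt to produce the next batch from the buffered input.
    fn next(&mut self) -> Result<DeserializerOutput, ArrowError>;
    /// Signal that no more input will arrive.
    fn finish(&mut self);
}

/// Format-specific decoding step; `DecoderDeserializer` wraps a `Decoder`
/// (such as the `CsvDecoder` below) to implement `BatchDeserializer<Bytes>`.
pub trait Decoder: Send {
    fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError>;
    fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError>;
    fn can_flush_early(&self) -> bool;
}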
235 changes: 229 additions & 6 deletions datafusion/core/src/datasource/file_format/csv.rs
@@ -23,7 +23,10 @@ use std::fmt::{self, Debug};
use std::sync::Arc;

use super::write::orchestration::stateless_multipart_put;
use super::{FileFormat, FileFormatFactory, DEFAULT_SCHEMA_INFER_MAX_RECORD};
use super::{
Decoder, DecoderDeserializer, FileFormat, FileFormatFactory,
DEFAULT_SCHEMA_INFER_MAX_RECORD,
};
use crate::datasource::file_format::file_compression_type::FileCompressionType;
use crate::datasource::file_format::write::BatchSerializer;
use crate::datasource::physical_plan::{
@@ -38,8 +41,8 @@ use crate::physical_plan::{

use arrow::array::RecordBatch;
use arrow::csv::WriterBuilder;
use arrow::datatypes::SchemaRef;
use arrow::datatypes::{DataType, Field, Fields, Schema};
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
use arrow_schema::ArrowError;
use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions};
use datafusion_common::file_options::csv_writer::CsvWriterOptions;
use datafusion_common::{
@@ -293,6 +296,45 @@ impl CsvFormat {
}
}

#[derive(Debug)]
pub(crate) struct CsvDecoder {
inner: arrow::csv::reader::Decoder,
}

impl CsvDecoder {
pub(crate) fn new(decoder: arrow::csv::reader::Decoder) -> Self {
Self { inner: decoder }
}
}

impl Decoder for CsvDecoder {
fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
self.inner.decode(buf)
}

fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
self.inner.flush()
}

fn can_flush_early(&self) -> bool {
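// `capacity()` reports how many more rows fit in the decoder's
// in-progress batch; zero means a full batch is buffered, so it can
// be flushed without waiting for more input.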
self.inner.capacity() == 0
}
}

impl Debug for CsvSerializer {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("CsvSerializer")
.field("header", &self.header)
.finish()
}
}

impl From<arrow::csv::reader::Decoder> for DecoderDeserializer<CsvDecoder> {
fn from(decoder: arrow::csv::reader::Decoder) -> Self {
DecoderDeserializer::new(CsvDecoder::new(decoder))
}
}

#[async_trait]
impl FileFormat for CsvFormat {
fn as_any(&self) -> &dyn Any {
@@ -692,23 +734,28 @@ impl DataSink for CsvSink {
mod tests {
use super::super::test_util::scan_format;
use super::*;
use crate::arrow::util::pretty;
use crate::assert_batches_eq;
use crate::datasource::file_format::file_compression_type::FileCompressionType;
use crate::datasource::file_format::test_util::VariableStream;
use crate::datasource::file_format::{
BatchDeserializer, DecoderDeserializer, DeserializerOutput,
};
use crate::datasource::listing::ListingOptions;
use crate::execution::session_state::SessionStateBuilder;
use crate::physical_plan::collect;
use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext};
use crate::test_util::arrow_test_data;

use arrow::compute::concat_batches;
use arrow::csv::ReaderBuilder;
use arrow::util::pretty::pretty_format_batches;
use arrow_array::{BooleanArray, Float64Array, Int32Array, StringArray};
use datafusion_common::cast::as_string_array;
use datafusion_common::internal_err;
use datafusion_common::stats::Precision;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_expr::{col, lit};

use crate::execution::session_state::SessionStateBuilder;
use chrono::DateTime;
use object_store::local::LocalFileSystem;
use object_store::path::Path;
@@ -1097,7 +1144,7 @@ mod tests {
) -> Result<usize> {
let df = ctx.sql(&format!("EXPLAIN {sql}")).await?;
let result = df.collect().await?;
let plan = format!("{}", &pretty::pretty_format_batches(&result)?);
let plan = format!("{}", &pretty_format_batches(&result)?);

let re = Regex::new(r"CsvExec: file_groups=\{(\d+) group").unwrap();

@@ -1464,4 +1511,180 @@

Ok(())
}

#[rstest]
fn test_csv_deserializer_with_finish(
#[values(1, 5, 17)] batch_size: usize,
#[values(0, 5, 93)] line_count: usize,
) -> Result<()> {
let schema = csv_schema();
let generator = CsvBatchGenerator::new(batch_size, line_count);
let mut deserializer = csv_deserializer(batch_size, &schema);

for data in generator {
deserializer.digest(data);
}
deserializer.finish();

let batch_count = line_count.div_ceil(batch_size);

let mut all_batches = RecordBatch::new_empty(schema.clone());
for _ in 0..batch_count {
let output = deserializer.next()?;
let DeserializerOutput::RecordBatch(batch) = output else {
panic!("Expected RecordBatch, got {:?}", output);
};
all_batches = concat_batches(&schema, &[all_batches, batch])?;
}
assert_eq!(deserializer.next()?, DeserializerOutput::InputExhausted);

let expected = csv_expected_batch(schema, line_count)?;

assert_eq!(
expected.clone(),
all_batches.clone(),
"Expected:\n{}\nActual:\n{}",
pretty_format_batches(&[expected])?,
pretty_format_batches(&[all_batches])?,
);

Ok(())
}

#[rstest]
fn test_csv_deserializer_without_finish(
#[values(1, 5, 17)] batch_size: usize,
#[values(0, 5, 93)] line_count: usize,
) -> Result<()> {
let schema = csv_schema();
let generator = CsvBatchGenerator::new(batch_size, line_count);
let mut deserializer = csv_deserializer(batch_size, &schema);

for data in generator {
deserializer.digest(data);
}

let batch_count = line_count / batch_size;

let mut all_batches = RecordBatch::new_empty(schema.clone());
for _ in 0..batch_count {
let output = deserializer.next()?;
let DeserializerOutput::RecordBatch(batch) = output else {
panic!("Expected RecordBatch, got {:?}", output);
};
all_batches = concat_batches(&schema, &[all_batches, batch])?;
}
assert_eq!(deserializer.next()?, DeserializerOutput::RequiresMoreData);

let expected = csv_expected_batch(schema, batch_count * batch_size)?;

assert_eq!(
expected.clone(),
all_batches.clone(),
"Expected:\n{}\nActual:\n{}",
pretty_format_batches(&[expected])?,
pretty_format_batches(&[all_batches])?,
);

Ok(())
}

struct CsvBatchGenerator {
batch_size: usize,
line_count: usize,
offset: usize,
}

impl CsvBatchGenerator {
fn new(batch_size: usize, line_count: usize) -> Self {
Self {
batch_size,
line_count,
offset: 0,
}
}
}

impl Iterator for CsvBatchGenerator {
type Item = Bytes;

fn next(&mut self) -> Option<Self::Item> {
// Return `batch_size` rows per batch:
let mut buffer = Vec::new();
for _ in 0..self.batch_size {
if self.offset >= self.line_count {
break;
}
buffer.extend_from_slice(&csv_line(self.offset));
self.offset += 1;
}

(!buffer.is_empty()).then(|| buffer.into())
}
}

fn csv_expected_batch(
schema: SchemaRef,
line_count: usize,
) -> Result<RecordBatch, DataFusionError> {
let mut c1 = Vec::with_capacity(line_count);
let mut c2 = Vec::with_capacity(line_count);
let mut c3 = Vec::with_capacity(line_count);
let mut c4 = Vec::with_capacity(line_count);

for i in 0..line_count {
let (int_value, float_value, bool_value, char_value) = csv_values(i);
c1.push(int_value);
c2.push(float_value);
c3.push(bool_value);
c4.push(char_value);
}

let expected = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(c1)),
Arc::new(Float64Array::from(c2)),
Arc::new(BooleanArray::from(c3)),
Arc::new(StringArray::from(c4)),
],
)?;
Ok(expected)
}

fn csv_line(line_number: usize) -> Bytes {
let (int_value, float_value, bool_value, char_value) = csv_values(line_number);
format!(
"{},{},{},{}\n",
int_value, float_value, bool_value, char_value
)
.into()
}

fn csv_values(line_number: usize) -> (i32, f64, bool, String) {
let int_value = line_number as i32;
let float_value = line_number as f64;
let bool_value = line_number % 2 == 0;
let char_value = format!("{}-string", line_number);
(int_value, float_value, bool_value, char_value)
}

fn csv_schema() -> Arc<Schema> {
Arc::new(Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", DataType::Float64, true),
Field::new("c3", DataType::Boolean, true),
Field::new("c4", DataType::Utf8, true),
]))
}

fn csv_deserializer(
batch_size: usize,
schema: &Arc<Schema>,
) -> impl BatchDeserializer<Bytes> {
let decoder = ReaderBuilder::new(schema.clone())
.with_batch_size(batch_size)
.build_decoder();
DecoderDeserializer::new(CsvDecoder::new(decoder))
}
}
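
Putting the pieces together, driving the new deserializer over a chunked byte stream might look like the sketch below. It is illustrative only, based on the trait shapes inferred earlier: `drain_csv` is a hypothetical helper, imports of the DataFusion-internal types are omitted, and since `CsvDecoder` is `pub(crate)`, code outside the crate would construct the deserializer via the `From<arrow::csv::reader::Decoder>` impl shown in this diff.

use std::sync::Arc;
use arrow::array::RecordBatch;
use arrow::csv::ReaderBuilder;
use arrow_schema::{ArrowError, Schema};
use bytes::Bytes;

// Hypothetical helper: decode a sequence of byte chunks into batches.
fn drain_csv(
    schema: Arc<Schema>,
    chunks: Vec<Bytes>,
) -> Result<Vec<RecordBatch>, ArrowError> {
    let decoder = ReaderBuilder::new(schema)
        .with_batch_size(1024)
        .build_decoder();
    let mut deserializer = DecoderDeserializer::new(CsvDecoder::new(decoder));

    let mut batches = Vec::new();
    for chunk in chunks {
        deserializer.digest(chunk);
        // Emit any batches that became complete after this chunk; the loop
        // exits on `RequiresMoreData`.
        while let DeserializerOutput::RecordBatch(batch) = deserializer.next()? {
            batches.push(batch);
        }
    }

    // No more input: flush whatever remains buffered, then stop on
    // `InputExhausted`.
    deserializer.finish();
    while let DeserializerOutput::RecordBatch(batch) = deserializer.next()? {
        batches.push(batch);
    }
    Ok(batches)
}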