feat: refine record batch project #1014

Open · wants to merge 2 commits into main
999 changes: 805 additions & 194 deletions crates/iceberg/src/arrow/record_batch_projector.rs

Large diffs are not rendered by default.

319 changes: 21 additions & 298 deletions crates/iceberg/src/arrow/record_batch_transformer.rs
@@ -15,63 +15,16 @@
// specific language governing permissions and limitations
// under the License.

use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::{
Array as ArrowArray, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array,
Int32Array, Int64Array, NullArray, RecordBatch, RecordBatchOptions, StringArray,
};
use arrow_cast::cast;
use arrow_schema::{
DataType, FieldRef, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, SchemaRef,
};
use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
use arrow_array::RecordBatch;
use arrow_schema::{Schema as ArrowSchema, SchemaRef as ArrowSchemaRef};

use super::get_field_id;
use super::record_batch_projector::{DefaultValueGenerator, RecordBatchProjector};
use crate::arrow::schema_to_arrow_schema;
use crate::spec::{Literal, PrimitiveLiteral, Schema as IcebergSchema};
use crate::{Error, ErrorKind, Result};

/// Indicates how a particular column in a processed RecordBatch should
/// be sourced.
#[derive(Debug)]
pub(crate) enum ColumnSource {
// signifies that a column should be passed through unmodified
// from the file's RecordBatch
PassThrough {
source_index: usize,
},

// signifies that a column from the file's RecordBatch has undergone
// type promotion so the source column with the given index needs
// to be promoted to the specified type
Promote {
target_type: DataType,
source_index: usize,
},

// Signifies that a new column has been inserted before the column
// with index `index`. (we choose "before" rather than "after" so
// that we can use usize; if we insert after, then we need to
// be able to store -1 here to signify that a new
// column is to be added at the front of the column list).
// If multiple columns need to be inserted at a given
// location, they should all be given the same index, as the index
// here refers to the original RecordBatch, not the interim state after
// a preceding operation.
Add {
target_type: DataType,
value: Option<PrimitiveLiteral>,
},
// The iceberg spec refers to other permissible schema evolution actions
// (see https://iceberg.apache.org/spec/#schema-evolution):
// renaming fields, deleting fields and reordering fields.
// Renames only affect the schema of the RecordBatch rather than the
// columns themselves, so a single updated cached schema can
// be re-used and no per-column actions are required.
// Deletion and Reorder can be achieved without needing this
// post-processing step by using the projection mask.
}
use crate::spec::Schema as IcebergSchema;
use crate::Result;
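For readers skimming the removal above: the three ColumnSource variants mapped directly onto arrow-rs primitives. A minimal standalone sketch of those three operations (column names and values are illustrative, not from this PR):

```rust
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, Int32Array, Int64Array, StringArray};
use arrow_cast::cast;
use arrow_schema::{ArrowError, DataType};

fn main() -> Result<(), ArrowError> {
    let source: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));

    // PassThrough: reuse the source column unchanged.
    let passed = source.clone();

    // Promote: widen the column with arrow_cast (here int -> long).
    let promoted = cast(&source, &DataType::Int64)?;

    // Add: synthesize a column filled with a default value (or nulls).
    let added: ArrayRef = Arc::new(StringArray::from(vec![Some("default"); source.len()]));

    assert_eq!(passed.len(), 3);
    assert_eq!(promoted.as_any().downcast_ref::<Int64Array>().unwrap().value(0), 1);
    assert_eq!(added.len(), 3);
    Ok(())
}
```

All of this bookkeeping now lives behind RecordBatchProjector, which is what the rest of the diff wires in.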

#[derive(Debug)]
enum BatchTransform {
@@ -81,14 +34,7 @@ enum BatchTransform {
PassThrough,

Modify {
// Every transformed RecordBatch will have the same schema. We create the
// target just once and cache it here. Helpfully, Arc<Schema> is needed in
// the constructor for RecordBatch, so we don't need an expensive copy
// each time we build a new RecordBatch
target_schema: Arc<ArrowSchema>,

// Indicates how each column in the target schema is derived.
operations: Vec<ColumnSource>,
record_batch_projector: RecordBatchProjector,
},

// Sometimes only the schema will need modifying, for example when
@@ -137,21 +83,11 @@ impl RecordBatchTransformer {
&mut self,
record_batch: RecordBatch,
) -> Result<RecordBatch> {
Ok(match &self.batch_transform {
Ok(match &mut self.batch_transform {
Some(BatchTransform::PassThrough) => record_batch,
Some(BatchTransform::Modify {
target_schema,
operations,
}) => {
let options = RecordBatchOptions::default()
.with_match_field_names(false)
.with_row_count(Some(record_batch.num_rows()));
RecordBatch::try_new_with_options(
target_schema.clone(),
self.transform_columns(record_batch.columns(), operations)?,
&options,
)?
}
record_batch_projector,
}) => record_batch_projector.project_batch(record_batch)?,
Some(BatchTransform::ModifySchema { target_schema }) => {
record_batch.with_schema(target_schema.clone())?
}
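The removed Modify arm rebuilt each batch by hand with RecordBatchOptions; project_batch now owns that step. For reference, a self-contained sketch of the old pattern (schema and values are made up):

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, RecordBatch, RecordBatchOptions};
use arrow_schema::{ArrowError, DataType, Field, Schema};

fn main() -> Result<(), ArrowError> {
    let target_schema = Arc::new(Schema::new(vec![Field::new(
        "renamed",
        DataType::Int32,
        false,
    )]));
    let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1, 2, 3]))];

    // match_field_names(false) relaxes nested field-name checks during column
    // validation; the explicit row count keeps zero-column batches well-defined.
    let options = RecordBatchOptions::default()
        .with_match_field_names(false)
        .with_row_count(Some(3));
    let batch = RecordBatch::try_new_with_options(target_schema, columns, &options)?;

    assert_eq!(batch.schema().field(0).name(), "renamed");
    Ok(())
}
```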
@@ -179,36 +115,22 @@
snapshot_schema: &IcebergSchema,
projected_iceberg_field_ids: &[i32],
) -> Result<BatchTransform> {
let mapped_unprojected_arrow_schema = Arc::new(schema_to_arrow_schema(snapshot_schema)?);
let field_id_to_mapped_schema_map =
Self::build_field_id_to_arrow_schema_map(&mapped_unprojected_arrow_schema)?;

// Create a new arrow schema by selecting fields from mapped_unprojected,
// in the order of the field ids in projected_iceberg_field_ids
let fields: Result<Vec<_>> = projected_iceberg_field_ids
.iter()
.map(|field_id| {
Ok(field_id_to_mapped_schema_map
.get(field_id)
.ok_or(Error::new(ErrorKind::Unexpected, "field not found"))?
.0
.clone())
})
.collect();

let target_schema = Arc::new(ArrowSchema::new(fields?));
let projected_iceberg_schema = snapshot_schema.project(projected_iceberg_field_ids)?;
let target_schema = Arc::new(schema_to_arrow_schema(&projected_iceberg_schema)?);

match Self::compare_schemas(source_schema, &target_schema) {
SchemaComparison::Equivalent => Ok(BatchTransform::PassThrough),
SchemaComparison::NameChangesOnly => Ok(BatchTransform::ModifySchema { target_schema }),
SchemaComparison::Different => Ok(BatchTransform::Modify {
operations: Self::generate_transform_operations(
source_schema,
snapshot_schema,
projected_iceberg_field_ids,
field_id_to_mapped_schema_map,
)?,
target_schema,
record_batch_projector: RecordBatchProjector::new(
&projected_iceberg_schema,
source_schema,
get_field_id,
Some(DefaultValueGenerator),
)?,
}),
}
}
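The target schema is now derived in two steps: project the Iceberg schema to the requested field ids, then convert it to Arrow once. A hedged sketch, assuming the Schema::project signature used in this diff and the public schema_to_arrow_schema helper:

```rust
use std::sync::Arc;

use iceberg::arrow::schema_to_arrow_schema;
use iceberg::spec::{NestedField, PrimitiveType, Schema, Type};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let snapshot_schema = Schema::builder()
        .with_schema_id(1)
        .with_fields(vec![
            NestedField::required(1, "id", Type::Primitive(PrimitiveType::Long)).into(),
            NestedField::optional(2, "name", Type::Primitive(PrimitiveType::String)).into(),
            NestedField::optional(3, "ts", Type::Primitive(PrimitiveType::Timestamp)).into(),
        ])
        .build()?;

    // Project to the requested field ids, then convert once to Arrow: the two
    // steps that replace the old hand-built field-id -> field map.
    let projected = snapshot_schema.project(&[1, 2])?;
    let target_arrow = Arc::new(schema_to_arrow_schema(&projected)?);

    assert_eq!(target_arrow.fields().len(), 2);
    Ok(())
}
```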
@@ -257,187 +179,6 @@ impl RecordBatchTransformer {
SchemaComparison::Equivalent
}
}

fn generate_transform_operations(
source_schema: &ArrowSchemaRef,
snapshot_schema: &IcebergSchema,
projected_iceberg_field_ids: &[i32],
field_id_to_mapped_schema_map: HashMap<i32, (FieldRef, usize)>,
) -> Result<Vec<ColumnSource>> {
let field_id_to_source_schema_map =
Self::build_field_id_to_arrow_schema_map(source_schema)?;

projected_iceberg_field_ids.iter().map(|field_id|{
let (target_field, _) = field_id_to_mapped_schema_map.get(field_id).ok_or(
Error::new(ErrorKind::Unexpected, "could not find field in schema")
)?;
let target_type = target_field.data_type();

Ok(if let Some((source_field, source_index)) = field_id_to_source_schema_map.get(field_id) {
// column present in source

if source_field.data_type().equals_datatype(target_type) {
// no promotion required
ColumnSource::PassThrough {
source_index: *source_index
}
} else {
// promotion required
ColumnSource::Promote {
target_type: target_type.clone(),
source_index: *source_index,
}
}
} else {
// column must be added
let iceberg_field = snapshot_schema.field_by_id(*field_id).ok_or(
Error::new(ErrorKind::Unexpected, "Field not found in snapshot schema")
)?;

let default_value = if let Some(iceberg_default_value) =
&iceberg_field.initial_default
{
let Literal::Primitive(primitive_literal) = iceberg_default_value else {
return Err(Error::new(
ErrorKind::Unexpected,
format!("Default value for column must be primitive type, but encountered {:?}", iceberg_default_value)
));
};
Some(primitive_literal.clone())
} else {
None
};

ColumnSource::Add {
value: default_value,
target_type: target_type.clone(),
}
})
}).collect()
}
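One detail worth keeping in mind from the removed code above: the PassThrough/Promote decision hinged on DataType::equals_datatype, which ignores field names and metadata inside nested types. A tiny illustration:

```rust
use arrow_schema::DataType;

fn main() {
    // Identical types pass through unmodified...
    assert!(DataType::Int32.equals_datatype(&DataType::Int32));
    // ...while a widened type triggers the Promote (cast) path.
    assert!(!DataType::Int32.equals_datatype(&DataType::Int64));
}
```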

fn build_field_id_to_arrow_schema_map(
source_schema: &SchemaRef,
) -> Result<HashMap<i32, (FieldRef, usize)>> {
let mut field_id_to_source_schema = HashMap::new();
for (source_field_idx, source_field) in source_schema.fields.iter().enumerate() {
let this_field_id = source_field
.metadata()
.get(PARQUET_FIELD_ID_META_KEY)
.ok_or_else(|| {
Error::new(
ErrorKind::DataInvalid,
"field ID not present in parquet metadata",
)
})?
.parse()
.map_err(|e| {
Error::new(
ErrorKind::DataInvalid,
format!("field id not parseable as an i32: {}", e),
)
})?;

field_id_to_source_schema
.insert(this_field_id, (source_field.clone(), source_field_idx));
}

Ok(field_id_to_source_schema)
}
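The map builder above (and the retained get_field_id helper in schema.rs) both rely on the Parquet field-id key stored in Arrow field metadata. A minimal sketch of that round trip (field name and id are illustrative):

```rust
use std::collections::HashMap;

use arrow_schema::{DataType, Field};
use parquet::arrow::PARQUET_FIELD_ID_META_KEY;

fn main() {
    // Attach an Iceberg field id to an Arrow field via metadata...
    let field = Field::new("id", DataType::Int64, false).with_metadata(HashMap::from([(
        PARQUET_FIELD_ID_META_KEY.to_string(),
        "1".to_string(),
    )]));

    // ...and parse it back out, as the lookup code does.
    let field_id: i32 = field
        .metadata()
        .get(PARQUET_FIELD_ID_META_KEY)
        .expect("field id present")
        .parse()
        .expect("field id parses as i32");

    assert_eq!(field_id, 1);
}
```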

fn transform_columns(
&self,
columns: &[Arc<dyn ArrowArray>],
operations: &[ColumnSource],
) -> Result<Vec<Arc<dyn ArrowArray>>> {
if columns.is_empty() {
return Ok(columns.to_vec());
}
let num_rows = columns[0].len();

operations
.iter()
.map(|op| {
Ok(match op {
ColumnSource::PassThrough { source_index } => columns[*source_index].clone(),

ColumnSource::Promote {
target_type,
source_index,
} => cast(&*columns[*source_index], target_type)?,

ColumnSource::Add { target_type, value } => {
Self::create_column(target_type, value, num_rows)?
}
})
})
.collect()
}

fn create_column(
target_type: &DataType,
prim_lit: &Option<PrimitiveLiteral>,
num_rows: usize,
) -> Result<ArrayRef> {
Ok(match (target_type, prim_lit) {
(DataType::Boolean, Some(PrimitiveLiteral::Boolean(value))) => {
Arc::new(BooleanArray::from(vec![*value; num_rows]))
}
(DataType::Boolean, None) => {
let vals: Vec<Option<bool>> = vec![None; num_rows];
Arc::new(BooleanArray::from(vals))
}
(DataType::Int32, Some(PrimitiveLiteral::Int(value))) => {
Arc::new(Int32Array::from(vec![*value; num_rows]))
}
(DataType::Int32, None) => {
let vals: Vec<Option<i32>> = vec![None; num_rows];
Arc::new(Int32Array::from(vals))
}
(DataType::Int64, Some(PrimitiveLiteral::Long(value))) => {
Arc::new(Int64Array::from(vec![*value; num_rows]))
}
(DataType::Int64, None) => {
let vals: Vec<Option<i64>> = vec![None; num_rows];
Arc::new(Int64Array::from(vals))
}
(DataType::Float32, Some(PrimitiveLiteral::Float(value))) => {
Arc::new(Float32Array::from(vec![value.0; num_rows]))
}
(DataType::Float32, None) => {
let vals: Vec<Option<f32>> = vec![None; num_rows];
Arc::new(Float32Array::from(vals))
}
(DataType::Float64, Some(PrimitiveLiteral::Double(value))) => {
Arc::new(Float64Array::from(vec![value.0; num_rows]))
}
(DataType::Float64, None) => {
let vals: Vec<Option<f64>> = vec![None; num_rows];
Arc::new(Float64Array::from(vals))
}
(DataType::Utf8, Some(PrimitiveLiteral::String(value))) => {
Arc::new(StringArray::from(vec![value.clone(); num_rows]))
}
(DataType::Utf8, None) => {
let vals: Vec<Option<String>> = vec![None; num_rows];
Arc::new(StringArray::from(vals))
}
(DataType::Binary, Some(PrimitiveLiteral::Binary(value))) => {
Arc::new(BinaryArray::from_vec(vec![value; num_rows]))
}
(DataType::Binary, None) => {
let vals: Vec<Option<&[u8]>> = vec![None; num_rows];
Arc::new(BinaryArray::from_opt_vec(vals))
}
(DataType::Null, _) => Arc::new(NullArray::new(num_rows)),
(dt, _) => {
return Err(Error::new(
ErrorKind::Unexpected,
format!("unexpected target column type {}", dt),
))
}
})
}
}
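An aside on the removed (DataType, None) arms: each one built an all-null array of a concrete type, which arrow-rs can also do generically. A small sketch of that alternative (not something this PR uses):

```rust
use arrow_array::{new_null_array, Array, Int64Array};
use arrow_schema::DataType;

fn main() {
    // new_null_array collapses all the per-type `None` branches into one call.
    let nulls = new_null_array(&DataType::Int64, 4);
    assert_eq!(nulls.null_count(), 4);

    // A literal default still needs the typed constructor, e.g. a repeated i64.
    let filled = Int64Array::from(vec![42_i64; 4]);
    assert_eq!(filled.value(3), 42);
}
```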

#[cfg(test)]
@@ -454,24 +195,6 @@ mod test {
use crate::arrow::record_batch_transformer::RecordBatchTransformer;
use crate::spec::{Literal, NestedField, PrimitiveType, Schema, Type};

#[test]
fn build_field_id_to_source_schema_map_works() {
let arrow_schema = arrow_schema_already_same_as_target();

let result =
RecordBatchTransformer::build_field_id_to_arrow_schema_map(&arrow_schema).unwrap();

let expected = HashMap::from_iter([
(10, (arrow_schema.fields()[0].clone(), 0)),
(11, (arrow_schema.fields()[1].clone(), 1)),
(12, (arrow_schema.fields()[2].clone(), 2)),
(14, (arrow_schema.fields()[3].clone(), 3)),
(15, (arrow_schema.fields()[4].clone(), 4)),
]);

assert!(result.eq(&expected));
}

#[test]
fn processor_returns_properly_shaped_record_batch_when_no_schema_migration_required() {
let snapshot_schema = Arc::new(iceberg_table_schema());
2 changes: 1 addition & 1 deletion crates/iceberg/src/arrow/schema.rs
@@ -226,7 +226,7 @@ pub fn arrow_type_to_type(ty: &DataType) -> Result<Type> {

const ARROW_FIELD_DOC_KEY: &str = "doc";

pub(super) fn get_field_id(field: &Field) -> Result<i32> {
pub(crate) fn get_field_id(field: &Field) -> Result<i32> {
if let Some(value) = field.metadata().get(PARQUET_FIELD_ID_META_KEY) {
return value.parse::<i32>().map_err(|e| {
Error::new(