From e8caac3429834082f4390d29adce6df9d972800a Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sun, 12 Jan 2025 17:47:51 +0100 Subject: [PATCH 1/6] feat: generated columns Signed-off-by: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> --- crates/core/src/delta_datafusion/mod.rs | 15 +- crates/core/src/kernel/error.rs | 9 + crates/core/src/kernel/models/actions.rs | 101 +++++++++-- crates/core/src/kernel/models/schema.rs | 72 ++++++++ crates/core/src/operations/add_column.rs | 26 +-- crates/core/src/operations/add_feature.rs | 4 +- crates/core/src/operations/cdc.rs | 2 +- crates/core/src/operations/create.rs | 45 ++--- .../src/operations/transaction/protocol.rs | 44 +---- crates/core/src/operations/write.rs | 167 +++++++++++------- crates/core/src/table/mod.rs | 36 ++++ 11 files changed, 358 insertions(+), 163 deletions(-) diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs index 6aad9ab3f7..08c38ddff3 100644 --- a/crates/core/src/delta_datafusion/mod.rs +++ b/crates/core/src/delta_datafusion/mod.rs @@ -81,7 +81,7 @@ use crate::kernel::{Add, DataCheck, EagerSnapshot, Invariant, Snapshot, StructTy use crate::logstore::LogStoreRef; use crate::table::builder::ensure_table_uri; use crate::table::state::DeltaTableState; -use crate::table::Constraint; +use crate::table::{Constraint, GeneratedColumn}; use crate::{open_table, open_table_with_storage_options, DeltaTable}; pub(crate) const PATH_COLUMN: &str = "__delta_rs_path"; @@ -1159,6 +1159,7 @@ pub(crate) async fn execute_plan_to_batch( pub struct DeltaDataChecker { constraints: Vec, invariants: Vec, + generated_columns: Vec, non_nullable_columns: Vec, ctx: SessionContext, } @@ -1169,6 +1170,7 @@ impl DeltaDataChecker { Self { invariants: vec![], constraints: vec![], + generated_columns: vec![], non_nullable_columns: vec![], ctx: DeltaSessionContext::default().into(), } @@ -1179,6 +1181,7 @@ impl DeltaDataChecker { Self { invariants, constraints: vec![], + generated_columns: vec![], non_nullable_columns: vec![], ctx: DeltaSessionContext::default().into(), } @@ -1189,6 +1192,7 @@ impl DeltaDataChecker { Self { constraints, invariants: vec![], + generated_columns: vec![], non_nullable_columns: vec![], ctx: DeltaSessionContext::default().into(), } @@ -1209,6 +1213,10 @@ impl DeltaDataChecker { /// Create a new DeltaDataChecker pub fn new(snapshot: &DeltaTableState) -> Self { let invariants = snapshot.schema().get_invariants().unwrap_or_default(); + let generated_columns = snapshot + .schema() + .get_generated_columns() + .unwrap_or_default(); let constraints = snapshot.table_config().get_constraints(); let non_nullable_columns = snapshot .schema() @@ -1224,6 +1232,7 @@ impl DeltaDataChecker { Self { invariants, constraints, + generated_columns, non_nullable_columns, ctx: DeltaSessionContext::default().into(), } @@ -1236,7 +1245,9 @@ impl DeltaDataChecker { pub async fn check_batch(&self, record_batch: &RecordBatch) -> Result<(), DeltaTableError> { self.check_nullability(record_batch)?; self.enforce_checks(record_batch, &self.invariants).await?; - self.enforce_checks(record_batch, &self.constraints).await + self.enforce_checks(record_batch, &self.constraints).await?; + self.enforce_checks(record_batch, &self.generated_columns) + .await } /// Return true if all the nullability checks are valid diff --git a/crates/core/src/kernel/error.rs b/crates/core/src/kernel/error.rs index cefe81bf9d..fe34b1d7e4 100644 --- a/crates/core/src/kernel/error.rs +++ b/crates/core/src/kernel/error.rs @@ -65,6 +65,15 @@ pub enum Error { line: String, }, + /// Error returned when the log contains invalid stats JSON. + #[error("Invalid JSON in generation expression, line=`{line}`, err=`{json_err}`")] + InvalidGenerationExpressionJson { + /// JSON error details returned when parsing the generation expression JSON. + json_err: serde_json::error::Error, + /// Generation expression. + line: String, + }, + #[error("Table metadata is invalid: {0}")] MetadataError(String), diff --git a/crates/core/src/kernel/models/actions.rs b/crates/core/src/kernel/models/actions.rs index 3812dc4838..d825d5bec4 100644 --- a/crates/core/src/kernel/models/actions.rs +++ b/crates/core/src/kernel/models/actions.rs @@ -2,12 +2,14 @@ use std::collections::{HashMap, HashSet}; use std::fmt::{self, Display}; use std::str::FromStr; +use delta_kernel::schema::{DataType, StructField}; use maplit::hashset; use serde::{Deserialize, Serialize}; use tracing::warn; use url::Url; use super::schema::StructType; +use super::StructTypeExt; use crate::kernel::{error::Error, DeltaResult}; use crate::TableProperty; use delta_kernel::table_features::{ReaderFeatures, WriterFeatures}; @@ -115,6 +117,19 @@ impl Metadata { } } +/// checks if table contains timestamp_ntz in any field including nested fields. +pub fn contains_timestampntz<'a>(mut fields: impl Iterator) -> bool { + fn _check_type(dtype: &DataType) -> bool { + match dtype { + &DataType::TIMESTAMP_NTZ => true, + DataType::Array(inner) => _check_type(inner.element_type()), + DataType::Struct(inner) => inner.fields().any(|f| _check_type(f.data_type())), + _ => false, + } + } + fields.any(|f| _check_type(f.data_type())) +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] #[serde(rename_all = "camelCase")] /// Defines a protocol action @@ -146,8 +161,8 @@ impl Protocol { } } - /// set the reader features in the protocol action, automatically bumps min_reader_version - pub fn with_reader_features( + /// Append the reader features in the protocol action, automatically bumps min_reader_version + pub fn append_reader_features( mut self, reader_features: impl IntoIterator>, ) -> Self { @@ -156,14 +171,20 @@ impl Protocol { .map(Into::into) .collect::>(); if !all_reader_features.is_empty() { - self.min_reader_version = 3 + self.min_reader_version = 3; + match self.reader_features { + Some(mut features) => { + features.extend(all_reader_features); + self.reader_features = Some(features); + } + None => self.reader_features = Some(all_reader_features), + }; } - self.reader_features = Some(all_reader_features); self } - /// set the writer features in the protocol action, automatically bumps min_writer_version - pub fn with_writer_features( + /// Append the writer features in the protocol action, automatically bumps min_writer_version + pub fn append_writer_features( mut self, writer_features: impl IntoIterator>, ) -> Self { @@ -172,9 +193,16 @@ impl Protocol { .map(|c| c.into()) .collect::>(); if !all_writer_feautures.is_empty() { - self.min_writer_version = 7 + self.min_writer_version = 7; + + match self.writer_features { + Some(mut features) => { + features.extend(all_writer_feautures); + self.writer_features = Some(features); + } + None => self.writer_features = Some(all_writer_feautures), + }; } - self.writer_features = Some(all_writer_feautures); self } @@ -255,6 +283,32 @@ impl Protocol { } self } + + /// Will apply the column metadata to the protocol by either bumping the version or setting + /// features + pub fn apply_column_metadata_to_protocol( + mut self, + schema: &StructType, + ) -> DeltaResult { + let generated_cols = schema.get_generated_columns()?; + let invariants = schema.get_invariants()?; + let contains_timestamp_ntz = self.contains_timestampntz(schema.fields()); + + if contains_timestamp_ntz { + self = self.enable_timestamp_ntz() + } + + if !generated_cols.is_empty() { + self = self.enable_generated_columns() + } + + if !invariants.is_empty() { + self = self.enable_invariants() + } + + Ok(self) + } + /// Will apply the properties to the protocol by either bumping the version or setting /// features pub fn apply_properties_to_protocol( @@ -391,10 +445,35 @@ impl Protocol { } Ok(self) } + + /// checks if table contains timestamp_ntz in any field including nested fields. + fn contains_timestampntz<'a>(&self, fields: impl Iterator) -> bool { + contains_timestampntz(fields) + } + /// Enable timestamp_ntz in the protocol - pub fn enable_timestamp_ntz(mut self) -> Protocol { - self = self.with_reader_features(vec![ReaderFeatures::TimestampWithoutTimezone]); - self = self.with_writer_features(vec![WriterFeatures::TimestampWithoutTimezone]); + fn enable_timestamp_ntz(mut self) -> Self { + self = self.append_reader_features([ReaderFeatures::TimestampWithoutTimezone]); + self = self.append_writer_features([WriterFeatures::TimestampWithoutTimezone]); + self + } + + /// Enabled generated columns + fn enable_generated_columns(mut self) -> Self { + if self.min_writer_version < 4 { + self.min_writer_version = 4; + } + if self.min_writer_version >= 7 { + self = self.append_writer_features([WriterFeatures::GeneratedColumns]); + } + self + } + + /// Enabled generated columns + fn enable_invariants(mut self) -> Self { + if self.min_writer_version >= 7 { + self = self.append_writer_features([WriterFeatures::Invariants]); + } self } } diff --git a/crates/core/src/kernel/models/schema.rs b/crates/core/src/kernel/models/schema.rs index 3a88564f1d..976fe467ef 100644 --- a/crates/core/src/kernel/models/schema.rs +++ b/crates/core/src/kernel/models/schema.rs @@ -10,6 +10,7 @@ use serde_json::Value; use crate::kernel::error::Error; use crate::kernel::DataCheck; +use crate::table::GeneratedColumn; /// Type alias for a top level schema pub type Schema = StructType; @@ -49,9 +50,80 @@ impl DataCheck for Invariant { pub trait StructTypeExt { /// Get all invariants in the schemas fn get_invariants(&self) -> Result, Error>; + + /// Get all generated column expressions + fn get_generated_columns(&self) -> Result, Error>; } impl StructTypeExt for StructType { + /// Get all invariants in the schemas + fn get_generated_columns(&self) -> Result, Error> { + let mut remaining_fields: Vec<(String, StructField)> = self + .fields() + .map(|field| (field.name.clone(), field.clone())) + .collect(); + let mut generated_cols: Vec = Vec::new(); + + let add_segment = |prefix: &str, segment: &str| -> String { + if prefix.is_empty() { + segment.to_owned() + } else { + format!("{prefix}.{segment}") + } + }; + + while let Some((field_path, field)) = remaining_fields.pop() { + match field.data_type() { + DataType::Struct(inner) => { + remaining_fields.extend( + inner + .fields() + .map(|field| { + let new_prefix = add_segment(&field_path, &field.name); + (new_prefix, field.clone()) + }) + .collect::>(), + ); + } + DataType::Array(inner) => { + let element_field_name = add_segment(&field_path, "element"); + remaining_fields.push(( + element_field_name, + StructField::new("".to_string(), inner.element_type.clone(), false), + )); + } + DataType::Map(inner) => { + let key_field_name = add_segment(&field_path, "key"); + remaining_fields.push(( + key_field_name, + StructField::new("".to_string(), inner.key_type.clone(), false), + )); + let value_field_name = add_segment(&field_path, "value"); + remaining_fields.push(( + value_field_name, + StructField::new("".to_string(), inner.value_type.clone(), false), + )); + } + _ => {} + } + if let Some(MetadataValue::String(generated_col_string)) = field + .metadata + .get(ColumnMetadataKey::GenerationExpression.as_ref()) + { + let json: Value = serde_json::from_str(generated_col_string).map_err(|e| { + Error::InvalidGenerationExpressionJson { + json_err: e, + line: generated_col_string.to_string(), + } + })?; + if let Value::String(sql) = json { + generated_cols.push(GeneratedColumn::new(&field_path, &sql)); + } + } + } + Ok(generated_cols) + } + /// Get all invariants in the schemas fn get_invariants(&self) -> Result, Error> { let mut remaining_fields: Vec<(String, StructField)> = self diff --git a/crates/core/src/operations/add_column.rs b/crates/core/src/operations/add_column.rs index 2b6d9de7df..a3477405af 100644 --- a/crates/core/src/operations/add_column.rs +++ b/crates/core/src/operations/add_column.rs @@ -88,24 +88,12 @@ impl std::future::IntoFuture for AddColumnBuilder { let table_schema = this.snapshot.schema(); let new_table_schema = merge_delta_struct(table_schema, fields_right)?; - // TODO(ion): Think of a way how we can simply this checking through the API or centralize some checks. - let contains_timestampntz = PROTOCOL.contains_timestampntz(fields.iter()); - let protocol = this.snapshot.protocol(); - - let maybe_new_protocol = if contains_timestampntz { - let updated_protocol = protocol.clone().enable_timestamp_ntz(); - if !(protocol.min_reader_version == 3 && protocol.min_writer_version == 7) { - // Convert existing properties to features since we advanced the protocol to v3,7 - Some( - updated_protocol - .move_table_properties_into_features(&metadata.configuration), - ) - } else { - Some(updated_protocol) - } - } else { - None - }; + let current_protocol = this.snapshot.protocol(); + + let new_protocol = current_protocol + .clone() + .apply_column_metadata_to_protocol(&new_table_schema)? + .move_table_properties_into_features(&metadata.configuration); let operation = DeltaOperation::AddColumn { fields: fields.into_iter().collect_vec(), @@ -115,7 +103,7 @@ impl std::future::IntoFuture for AddColumnBuilder { let mut actions = vec![metadata.into()]; - if let Some(new_protocol) = maybe_new_protocol { + if current_protocol != &new_protocol { actions.push(new_protocol.into()) } diff --git a/crates/core/src/operations/add_feature.rs b/crates/core/src/operations/add_feature.rs index 0e7f88ee7f..31dbb928bf 100644 --- a/crates/core/src/operations/add_feature.rs +++ b/crates/core/src/operations/add_feature.rs @@ -123,8 +123,8 @@ impl std::future::IntoFuture for AddTableFeatureBuilder { } } - protocol = protocol.with_reader_features(reader_features); - protocol = protocol.with_writer_features(writer_features); + protocol = protocol.append_reader_features(reader_features); + protocol = protocol.append_writer_features(writer_features); let operation = DeltaOperation::AddFeature { name: name.to_vec(), diff --git a/crates/core/src/operations/cdc.rs b/crates/core/src/operations/cdc.rs index c9d0ca0665..5e950402b8 100644 --- a/crates/core/src/operations/cdc.rs +++ b/crates/core/src/operations/cdc.rs @@ -175,7 +175,7 @@ mod tests { #[tokio::test] async fn test_should_write_cdc_v7_table_with_writer_feature() { let protocol = - Protocol::new(1, 7).with_writer_features(vec![WriterFeatures::ChangeDataFeed]); + Protocol::new(1, 7).append_writer_features(vec![WriterFeatures::ChangeDataFeed]); let actions = vec![Action::Protocol(protocol)]; let mut table: DeltaTable = DeltaOps::new_in_memory() .create() diff --git a/crates/core/src/operations/create.rs b/crates/core/src/operations/create.rs index 5f6ef47bc0..bcf79650cf 100644 --- a/crates/core/src/operations/create.rs +++ b/crates/core/src/operations/create.rs @@ -289,24 +289,12 @@ impl CreateBuilder { self.pre_execute(operation_id).await?; let configuration = self.configuration; - let contains_timestampntz = PROTOCOL.contains_timestampntz(self.columns.iter()); - // TODO configure more permissive versions based on configuration. Also how should this ideally be handled? - // We set the lowest protocol we can, and if subsequent writes use newer features we update metadata? - - let current_protocol = if contains_timestampntz { - Protocol { - min_reader_version: 3, - min_writer_version: 7, - writer_features: Some(hashset! {WriterFeatures::TimestampWithoutTimezone}), - reader_features: Some(hashset! {ReaderFeatures::TimestampWithoutTimezone}), - } - } else { - Protocol { - min_reader_version: PROTOCOL.default_reader_version(), - min_writer_version: PROTOCOL.default_writer_version(), - reader_features: None, - writer_features: None, - } + + let current_protocol = Protocol { + min_reader_version: PROTOCOL.default_reader_version(), + min_writer_version: PROTOCOL.default_writer_version(), + reader_features: None, + writer_features: None, }; let protocol = self @@ -319,18 +307,21 @@ impl CreateBuilder { }) .unwrap_or_else(|| current_protocol); - let protocol = protocol.apply_properties_to_protocol( - &configuration - .iter() - .map(|(k, v)| (k.clone(), v.clone().unwrap())) - .collect::>(), - self.raise_if_key_not_exists, - )?; + let schema = StructType::new(self.columns); - let protocol = protocol.move_table_properties_into_features(&configuration); + let protocol = protocol + .apply_properties_to_protocol( + &configuration + .iter() + .map(|(k, v)| (k.clone(), v.clone().unwrap())) + .collect::>(), + self.raise_if_key_not_exists, + )? + .apply_column_metadata_to_protocol(&schema)? + .move_table_properties_into_features(&configuration); let mut metadata = Metadata::try_new( - StructType::new(self.columns), + schema, self.partition_columns.unwrap_or_default(), configuration, )? diff --git a/crates/core/src/operations/transaction/protocol.rs b/crates/core/src/operations/transaction/protocol.rs index ef88fbf8e6..bb49e0fae9 100644 --- a/crates/core/src/operations/transaction/protocol.rs +++ b/crates/core/src/operations/transaction/protocol.rs @@ -5,7 +5,7 @@ use once_cell::sync::Lazy; use tracing::log::*; use super::{TableReference, TransactionError}; -use crate::kernel::{Action, DataType, EagerSnapshot, Schema, StructField}; +use crate::kernel::{contains_timestampntz, Action, DataType, EagerSnapshot, Schema, StructField}; use crate::protocol::DeltaOperation; use crate::table::state::DeltaTableState; use delta_kernel::table_features::{ReaderFeatures, WriterFeatures}; @@ -79,29 +79,13 @@ impl ProtocolChecker { Ok(()) } - /// checks if table contains timestamp_ntz in any field including nested fields. - pub fn contains_timestampntz<'a>( - &self, - mut fields: impl Iterator, - ) -> bool { - fn _check_type(dtype: &DataType) -> bool { - match dtype { - &DataType::TIMESTAMP_NTZ => true, - DataType::Array(inner) => _check_type(inner.element_type()), - DataType::Struct(inner) => inner.fields().any(|f| _check_type(f.data_type())), - _ => false, - } - } - fields.any(|f| _check_type(f.data_type())) - } - /// Check can write_timestamp_ntz pub fn check_can_write_timestamp_ntz( &self, snapshot: &DeltaTableState, schema: &Schema, ) -> Result<(), TransactionError> { - let contains_timestampntz = self.contains_timestampntz(schema.fields()); + let contains_timestampntz = contains_timestampntz(schema.fields()); let required_features: Option<&HashSet> = match snapshot.protocol().min_writer_version { 0..=6 => None, @@ -159,22 +143,6 @@ impl ProtocolChecker { _ => snapshot.protocol().writer_features.as_ref(), }; - if (4..7).contains(&min_writer_version) { - debug!("min_writer_version is less 4-6, checking for unsupported table features"); - if let Ok(schema) = snapshot.metadata().schema() { - for field in schema.fields() { - if field.metadata.contains_key( - crate::kernel::ColumnMetadataKey::GenerationExpression.as_ref(), - ) { - error!("The table contains `delta.generationExpression` settings on columns which mean this table cannot be currently written to by delta-rs"); - return Err(TransactionError::UnsupportedWriterFeatures(vec![ - WriterFeatures::GeneratedColumns, - ])); - } - } - } - } - if let Some(features) = required_features { let mut diff = features.difference(&self.writer_features).peekable(); if diff.peek().is_some() { @@ -246,15 +214,13 @@ pub static INSTANCE: Lazy = Lazy::new(|| { #[cfg(feature = "cdf")] { writer_features.insert(WriterFeatures::ChangeDataFeed); - writer_features.insert(WriterFeatures::GeneratedColumns); } #[cfg(feature = "datafusion")] { writer_features.insert(WriterFeatures::Invariants); writer_features.insert(WriterFeatures::CheckConstraints); + writer_features.insert(WriterFeatures::GeneratedColumns); } - // writer_features.insert(WriterFeatures::ChangeDataFeed); - // writer_features.insert(WriterFeatures::GeneratedColumns); // writer_features.insert(WriterFeatures::ColumnMapping); // writer_features.insert(WriterFeatures::IdentityColumns); @@ -584,7 +550,7 @@ mod tests { let checker_5 = ProtocolChecker::new(READER_V2.clone(), WRITER_V4.clone()); let actions = vec![ Action::Protocol( - Protocol::new(2, 4).with_writer_features(vec![WriterFeatures::ChangeDataFeed]), + Protocol::new(2, 4).append_writer_features(vec![WriterFeatures::ChangeDataFeed]), ), metadata_action(None).into(), ]; @@ -601,7 +567,7 @@ mod tests { let checker_5 = ProtocolChecker::new(READER_V2.clone(), WRITER_V4.clone()); let actions = vec![ Action::Protocol( - Protocol::new(2, 4).with_writer_features(vec![WriterFeatures::GeneratedColumns]), + Protocol::new(2, 4).append_writer_features([WriterFeatures::GeneratedColumns]), ), metadata_action(None).into(), ]; diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 0be11f05cf..1a1cc9d11f 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -25,6 +25,7 @@ //! ```` use std::collections::HashMap; +use std::hash::Hash; use std::str::FromStr; use std::sync::Arc; use std::time::{Instant, SystemTime, UNIX_EPOCH}; @@ -35,7 +36,7 @@ use arrow_cast::can_cast_types; use arrow_schema::{ArrowError, DataType, Fields, SchemaRef as ArrowSchemaRef}; use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; use datafusion_common::DFSchema; -use datafusion_expr::{lit, Expr}; +use datafusion_expr::{col, lit, when, Expr}; use datafusion_physical_expr::expressions::{self}; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::filter::FilterExec; @@ -63,14 +64,15 @@ use crate::delta_datafusion::{ use crate::delta_datafusion::{DataFusionMixins, DeltaDataChecker}; use crate::errors::{DeltaResult, DeltaTableError}; use crate::kernel::{ - Action, ActionType, Add, AddCDCFile, Metadata, PartitionsExt, Remove, StructType, + Action, ActionType, Add, AddCDCFile, DataCheck, Metadata, PartitionsExt, Remove, StructType, + StructTypeExt, }; use crate::logstore::LogStoreRef; use crate::operations::cast::{cast_record_batch, merge_schema::merge_arrow_schema}; use crate::protocol::{DeltaOperation, SaveMode}; use crate::storage::ObjectStoreRef; use crate::table::state::DeltaTableState; -use crate::table::Constraint as DeltaConstraint; +use crate::table::{Constraint as DeltaConstraint, GeneratedColumn}; use crate::writer::record_batch::divide_by_partition_values; use crate::DeltaTable; @@ -852,8 +854,14 @@ impl std::future::IntoFuture for WriteBuilder { } else { Ok(this.partition_columns.unwrap_or_default()) }?; + + let generated_col_expressions = this + .snapshot + .as_ref() + .map(|v| v.schema().get_generated_columns().unwrap_or_default()) + .unwrap_or_default(); let mut schema_drift = false; - let plan = if let Some(plan) = this.input { + let mut plan = if let Some(plan) = this.input { if this.schema_mode == Some(SchemaMode::Merge) { return Err(DeltaTableError::Generic( "Schema merge not supported yet for Datafusion".to_string(), @@ -864,11 +872,22 @@ impl std::future::IntoFuture for WriteBuilder { if batches.is_empty() { Err(WriteError::MissingData) } else { - let schema = batches[0].schema(); + let mut schema = batches[0].schema(); + // Schema merging code should be aware of columns that can be generated during write + // so they might be empty in the batch, but the will exist in the input_schema() + // in this case we have to insert the generated column and it's type in the schema of the batch let mut new_schema = None; if let Some(snapshot) = &this.snapshot { let table_schema = snapshot.input_schema()?; + + // Merge schema's initial round when there are generated columns expressions + // This is to have the batch schema be the same as the input schema without adding new fields + // from the incoming batch + if !generated_col_expressions.is_empty() { + schema = merge_arrow_schema(table_schema.clone(), schema, true)?; + } + if let Err(schema_err) = try_cast_batch(schema.fields(), table_schema.fields()) { @@ -876,7 +895,11 @@ impl std::future::IntoFuture for WriteBuilder { if this.mode == SaveMode::Overwrite && this.schema_mode == Some(SchemaMode::Overwrite) { - new_schema = None // we overwrite anyway, so no need to cast + if generated_col_expressions.is_empty() { + new_schema = None // we overwrite anyway, so no need to cast + } else { + new_schema = Some(schema.clone()) // we need to cast the batch to include the generated col as empty null + } } else if this.schema_mode == Some(SchemaMode::Merge) { new_schema = Some(merge_arrow_schema( table_schema.clone(), @@ -889,7 +912,11 @@ impl std::future::IntoFuture for WriteBuilder { } else if this.mode == SaveMode::Overwrite && this.schema_mode == Some(SchemaMode::Overwrite) { - new_schema = None // we overwrite anyway, so no need to cast + if generated_col_expressions.is_empty() { + new_schema = None // we overwrite anyway, so no need to cast + } else { + new_schema = Some(schema.clone()) // we need to cast the batch to include the generated col as empty null + } } else { // Schema needs to be merged so that utf8/binary/list types are preserved from the batch side if both table // and batch contains such type. Other types are preserved from the table side. @@ -912,7 +939,7 @@ impl std::future::IntoFuture for WriteBuilder { &batch, new_schema, this.safe_cast, - schema_drift, // Schema drifted so we have to add the missing columns/structfields. + schema_drift || !generated_col_expressions.is_empty(), // Schema drifted so we have to add the missing columns/structfields or missing generated cols.. )?, None => batch, }; @@ -949,7 +976,7 @@ impl std::future::IntoFuture for WriteBuilder { &batch, new_schema.clone(), this.safe_cast, - schema_drift, // Schema drifted so we have to add the missing columns/structfields. + schema_drift || !generated_col_expressions.is_empty(), // Schema drifted so we have to add the missing columns/structfields or missing generated cols. )?); num_added_rows += batch.num_rows(); } @@ -972,40 +999,25 @@ impl std::future::IntoFuture for WriteBuilder { } else { Err(WriteError::MissingData) }?; + let schema = plan.schema(); if this.schema_mode == Some(SchemaMode::Merge) && schema_drift { if let Some(snapshot) = &this.snapshot { let schema_struct: StructType = schema.clone().try_into()?; let current_protocol = snapshot.protocol(); let configuration = snapshot.metadata().configuration.clone(); - let maybe_new_protocol = if PROTOCOL - .contains_timestampntz(schema_struct.fields()) - && !current_protocol - .reader_features - .clone() - .unwrap_or_default() - .contains(&delta_kernel::table_features::ReaderFeatures::TimestampWithoutTimezone) - // We can check only reader features, as reader and writer timestampNtz - // should be always enabled together - { - let new_protocol = current_protocol.clone().enable_timestamp_ntz(); - if !(current_protocol.min_reader_version == 3 - && current_protocol.min_writer_version == 7) - { - Some(new_protocol.move_table_properties_into_features(&configuration)) - } else { - Some(new_protocol) - } - } else { - None - }; + let new_protocol = current_protocol + .clone() + .apply_column_metadata_to_protocol(&schema_struct)? + .move_table_properties_into_features(&configuration); + let schema_action = Action::Metadata(Metadata::try_new( schema_struct, partition_columns.clone(), configuration, )?); actions.push(schema_action); - if let Some(new_protocol) = maybe_new_protocol { + if current_protocol != &new_protocol { actions.push(new_protocol.into()) } } @@ -1019,6 +1031,55 @@ impl std::future::IntoFuture for WriteBuilder { } }; + // Add when.then expr for generated columns + if !generated_col_expressions.is_empty() { + fn create_field( + idx: usize, + field: &arrow_schema::Field, + generated_cols_map: &HashMap, + state: &datafusion::execution::session_state::SessionState, + dfschema: &DFSchema, + ) -> DeltaResult<(Arc, String)> { + match generated_cols_map.get(field.name()) { + Some(generated_col) => { + let generation_expr = state.create_physical_expr( + when( + col(generated_col.get_name()).is_null(), + state.create_logical_expr( + generated_col.get_generation_expression(), + dfschema, + )?, + ) + .otherwise(col(generated_col.get_name()))?, + dfschema, + )?; + Ok((generation_expr, field.name().to_owned())) + } + None => Ok(( + Arc::new(expressions::Column::new(field.name(), idx)), + field.name().to_owned(), + )), + } + } + + let dfschema: DFSchema = schema.as_ref().clone().try_into()?; + let generated_cols_map = generated_col_expressions + .into_iter() + .map(|v| (v.name.clone(), v)) + .collect::>(); + let current_fields: DeltaResult, String)>> = plan + .schema() + .fields() + .into_iter() + .enumerate() + .map(|(idx, field)| { + create_field(idx, field, &generated_cols_map, &state, &dfschema) + }) + .collect(); + + plan = Arc::new(ProjectionExec::try_new(current_fields?, plan.clone())?); + }; + let (predicate_str, predicate) = match this.predicate { Some(predicate) => { let pred = match predicate { @@ -1074,43 +1135,25 @@ impl std::future::IntoFuture for WriteBuilder { // Update metadata with new schema let table_schema = snapshot.input_schema()?; - let configuration = snapshot.metadata().configuration.clone(); - let current_protocol = snapshot.protocol(); - let maybe_new_protocol = if PROTOCOL.contains_timestampntz( - TryInto::::try_into(schema.clone())?.fields(), - ) && !current_protocol - .reader_features - .clone() - .unwrap_or_default() - .contains( - &delta_kernel::table_features::ReaderFeatures::TimestampWithoutTimezone, - ) - // We can check only reader features, as reader and writer timestampNtz - // should be always enabled together - { - let new_protocol = current_protocol.clone().enable_timestamp_ntz(); - if !(current_protocol.min_reader_version == 3 - && current_protocol.min_writer_version == 7) - { - Some(new_protocol.move_table_properties_into_features(&configuration)) - } else { - Some(new_protocol) - } - } else { - None - }; - - if let Some(protocol) = maybe_new_protocol { - actions.push(protocol.into()) - } - + let delta_schema: StructType = schema.as_ref().try_into()?; if schema != table_schema { let mut metadata = snapshot.metadata().clone(); - let delta_schema: StructType = schema.as_ref().try_into()?; + metadata.schema_string = serde_json::to_string(&delta_schema)?; actions.push(Action::Metadata(metadata)); } + let configuration = snapshot.metadata().configuration.clone(); + let current_protocol = snapshot.protocol(); + let new_protocol = current_protocol + .clone() + .apply_column_metadata_to_protocol(&delta_schema)? + .move_table_properties_into_features(&configuration); + + if current_protocol != &new_protocol { + actions.push(new_protocol.into()) + } + let deletion_timestamp = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() diff --git a/crates/core/src/table/mod.rs b/crates/core/src/table/mod.rs index 1409c498c2..56bfd664cb 100644 --- a/crates/core/src/table/mod.rs +++ b/crates/core/src/table/mod.rs @@ -157,6 +157,42 @@ impl DataCheck for Constraint { } } +/// A generated column +#[derive(Eq, PartialEq, Debug, Default, Clone)] +pub struct GeneratedColumn { + /// The full path to the field. + pub name: String, + /// The SQL string that generate the column value. + pub generation_expr: String, + /// The SQL string that must always evaluate to true. + pub validation_expr: String, +} + +impl GeneratedColumn { + /// Create a new invariant + pub fn new(field_name: &str, sql_generation: &str) -> Self { + Self { + name: field_name.to_string(), + generation_expr: sql_generation.to_string(), + validation_expr: format!("{} <=> {}", field_name, sql_generation), + } + } + + pub fn get_generation_expression(&self) -> &str { + &self.generation_expr + } +} + +impl DataCheck for GeneratedColumn { + fn get_name(&self) -> &str { + &self.name + } + + fn get_expression(&self) -> &str { + &self.validation_expr + } +} + /// Return partition fields along with their data type from the current schema. pub(crate) fn get_partition_col_data_types<'a>( schema: &'a StructType, From eeed42b568b1319f681c1565073d47ad400cd8ad Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sun, 12 Jan 2025 19:49:40 +0100 Subject: [PATCH 2/6] feat: enable generated columns merge Signed-off-by: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> --- crates/core/src/operations/merge/mod.rs | 125 +++++++++++++++++++++--- crates/core/src/table/mod.rs | 3 +- 2 files changed, 112 insertions(+), 16 deletions(-) diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index 8a4640b9a3..b9b0124189 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -74,7 +74,7 @@ use crate::delta_datafusion::{ register_store, DataFusionMixins, DeltaColumn, DeltaScan, DeltaScanConfigBuilder, DeltaSessionConfig, DeltaTableProvider, }; -use crate::kernel::Action; +use crate::kernel::{Action, DataCheck, StructTypeExt}; use crate::logstore::LogStoreRef; use crate::operations::cdc::*; use crate::operations::merge::barrier::find_node; @@ -82,6 +82,7 @@ use crate::operations::transaction::CommitBuilder; use crate::operations::write::{write_execution_plan, write_execution_plan_cdc, WriterStatsConfig}; use crate::protocol::{DeltaOperation, MergePredicate}; use crate::table::state::DeltaTableState; +use crate::table::GeneratedColumn; use crate::{DeltaResult, DeltaTable, DeltaTableError}; mod barrier; @@ -750,6 +751,73 @@ async fn execute( None => TableReference::bare(UNNAMED_TABLE), }; + /// Add generated column expressions to a dataframe + fn add_missing_generated_columns( + mut df: DataFrame, + generated_cols: &Vec, + ) -> DeltaResult<(DataFrame, Vec)> { + let mut missing_cols = vec![]; + for generated_col in generated_cols { + let col_name = generated_col.get_name(); + + if !df + .clone() + .schema() + .field_names() + .contains(&col_name.to_string()) + { + debug!( + "Adding missing generated column {} in source as placeholder", + col_name + ); + // If column doesn't exist, we add a null column, later we will generate the values after + // all the merge is projected. + // Other generated columns that were provided upon the start we only validate during write + missing_cols.push(col_name.to_string()); + df = df + .clone() + .with_column(col_name, Expr::Literal(ScalarValue::Null))?; + } + } + Ok((df, missing_cols)) + } + + /// Add generated column expressions to a dataframe + fn add_generated_columns( + mut df: DataFrame, + generated_cols: &Vec, + generated_cols_missing_in_source: &Vec, + state: &SessionState, + ) -> DeltaResult { + debug!("Generating columns in dataframe"); + for generated_col in generated_cols { + // We only validate columns that were missing from the start. We don't update + // update generated columns that were provided during runtime + if !generated_cols_missing_in_source.contains(&generated_col.name) { + continue; + } + + let generation_expr = state.create_logical_expr( + generated_col.get_generation_expression(), + df.clone().schema(), + )?; + let col_name = generated_col.get_name(); + + df = df.clone().with_column( + generated_col.get_name(), + when(col(col_name).is_null(), generation_expr).otherwise(col(col_name))?, + )? + } + Ok(df) + } + + let generated_col_expressions = snapshot + .schema() + .get_generated_columns() + .unwrap_or_default(); + + let (source, missing_generated_columns) = + add_missing_generated_columns(source, &generated_col_expressions)?; // This is only done to provide the source columns with a correct table reference. Just renaming the columns does not work let source = LogicalPlanBuilder::scan( source_name.clone(), @@ -1160,26 +1228,40 @@ async fn execute( lit(5), ))?; - change_data.push( - cdc_projection - .clone() - .filter( - col(SOURCE_COLUMN) - .is_true() - .and(col(TARGET_COLUMN).is_null()), - )? - .select(write_projection.clone())? - .with_column(CDC_COLUMN_NAME, lit("insert"))?, - ); + let mut cdc_insert_df = cdc_projection + .clone() + .filter( + col(SOURCE_COLUMN) + .is_true() + .and(col(TARGET_COLUMN).is_null()), + )? + .select(write_projection.clone())? + .with_column(CDC_COLUMN_NAME, lit("insert"))?; + + cdc_insert_df = add_generated_columns( + cdc_insert_df, + &generated_col_expressions, + &missing_generated_columns, + &state, + )?; - let after = cdc_projection + change_data.push(cdc_insert_df); + + let mut after = cdc_projection .clone() .filter(col(TARGET_COLUMN).is_true())? .select(write_projection.clone())?; + after = add_generated_columns( + after, + &generated_col_expressions, + &missing_generated_columns, + &state, + )?; + // Extra select_columns is required so that before and after have same schema order // DataFusion doesn't have UnionByName yet, see https://github.com/apache/datafusion/issues/12650 - let before = cdc_projection + let mut before = cdc_projection .clone() .filter(col(crate::delta_datafusion::PATH_COLUMN).is_not_null())? .select( @@ -1199,11 +1281,24 @@ async fn execute( .collect::>(), )?; + before = add_generated_columns( + before, + &generated_col_expressions, + &missing_generated_columns, + &state, + )?; + let tracker = CDCTracker::new(before, after); change_data.push(tracker.collect()?); } - let project = filtered.clone().select(write_projection)?; + let mut project = filtered.clone().select(write_projection)?; + project = add_generated_columns( + project, + &generated_col_expressions, + &missing_generated_columns, + &state, + )?; let merge_final = &project.into_unoptimized_plan(); let write = state.create_physical_plan(merge_final).await?; diff --git a/crates/core/src/table/mod.rs b/crates/core/src/table/mod.rs index 56bfd664cb..56a1f9d1e7 100644 --- a/crates/core/src/table/mod.rs +++ b/crates/core/src/table/mod.rs @@ -174,7 +174,8 @@ impl GeneratedColumn { Self { name: field_name.to_string(), generation_expr: sql_generation.to_string(), - validation_expr: format!("{} <=> {}", field_name, sql_generation), + validation_expr: format!("{field_name} = {sql_generation} OR ({field_name} IS NULL AND {sql_generation} IS NULL)"), + // validation_expr: format!("{} <=> {}", field_name, sql_generation), } } From 04c911d52de9f09d221303a19eedfe30023f243c Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sun, 12 Jan 2025 20:22:17 +0100 Subject: [PATCH 3/6] fix: cast generated col exprs always, don'fetch nested metadata Signed-off-by: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> --- crates/core/src/kernel/models/schema.rs | 45 +------------------ crates/core/src/operations/merge/mod.rs | 10 ++++- .../src/operations/transaction/protocol.rs | 4 +- crates/core/src/operations/write.rs | 8 +++- crates/core/src/table/mod.rs | 9 ++-- 5 files changed, 24 insertions(+), 52 deletions(-) diff --git a/crates/core/src/kernel/models/schema.rs b/crates/core/src/kernel/models/schema.rs index 976fe467ef..e084fc0c0e 100644 --- a/crates/core/src/kernel/models/schema.rs +++ b/crates/core/src/kernel/models/schema.rs @@ -56,7 +56,7 @@ pub trait StructTypeExt { } impl StructTypeExt for StructType { - /// Get all invariants in the schemas + /// Get all get_generated_columns in the schemas fn get_generated_columns(&self) -> Result, Error> { let mut remaining_fields: Vec<(String, StructField)> = self .fields() @@ -64,48 +64,7 @@ impl StructTypeExt for StructType { .collect(); let mut generated_cols: Vec = Vec::new(); - let add_segment = |prefix: &str, segment: &str| -> String { - if prefix.is_empty() { - segment.to_owned() - } else { - format!("{prefix}.{segment}") - } - }; - while let Some((field_path, field)) = remaining_fields.pop() { - match field.data_type() { - DataType::Struct(inner) => { - remaining_fields.extend( - inner - .fields() - .map(|field| { - let new_prefix = add_segment(&field_path, &field.name); - (new_prefix, field.clone()) - }) - .collect::>(), - ); - } - DataType::Array(inner) => { - let element_field_name = add_segment(&field_path, "element"); - remaining_fields.push(( - element_field_name, - StructField::new("".to_string(), inner.element_type.clone(), false), - )); - } - DataType::Map(inner) => { - let key_field_name = add_segment(&field_path, "key"); - remaining_fields.push(( - key_field_name, - StructField::new("".to_string(), inner.key_type.clone(), false), - )); - let value_field_name = add_segment(&field_path, "value"); - remaining_fields.push(( - value_field_name, - StructField::new("".to_string(), inner.value_type.clone(), false), - )); - } - _ => {} - } if let Some(MetadataValue::String(generated_col_string)) = field .metadata .get(ColumnMetadataKey::GenerationExpression.as_ref()) @@ -117,7 +76,7 @@ impl StructTypeExt for StructType { } })?; if let Value::String(sql) = json { - generated_cols.push(GeneratedColumn::new(&field_path, &sql)); + generated_cols.push(GeneratedColumn::new(&field_path, &sql, field.data_type())); } } } diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index b9b0124189..0e2541349c 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -50,7 +50,8 @@ use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Column, DFSchema, ScalarValue, TableReference}; use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr, JoinType}; use datafusion_expr::{ - Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, UNNAMED_TABLE, + ExprSchemable, Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, + UNNAMED_TABLE, }; use filter::try_construct_early_filter; @@ -805,7 +806,12 @@ async fn execute( df = df.clone().with_column( generated_col.get_name(), - when(col(col_name).is_null(), generation_expr).otherwise(col(col_name))?, + when(col(col_name).is_null(), generation_expr) + .otherwise(col(col_name))? + .cast_to( + &arrow_schema::DataType::try_from(&generated_col.data_type)?, + df.schema(), + )?, )? } Ok(df) diff --git a/crates/core/src/operations/transaction/protocol.rs b/crates/core/src/operations/transaction/protocol.rs index bb49e0fae9..3fe94a5653 100644 --- a/crates/core/src/operations/transaction/protocol.rs +++ b/crates/core/src/operations/transaction/protocol.rs @@ -5,7 +5,7 @@ use once_cell::sync::Lazy; use tracing::log::*; use super::{TableReference, TransactionError}; -use crate::kernel::{contains_timestampntz, Action, DataType, EagerSnapshot, Schema, StructField}; +use crate::kernel::{contains_timestampntz, Action, EagerSnapshot, Schema}; use crate::protocol::DeltaOperation; use crate::table::state::DeltaTableState; use delta_kernel::table_features::{ReaderFeatures, WriterFeatures}; @@ -599,6 +599,6 @@ mod tests { let eager_5 = table .snapshot() .expect("Failed to get snapshot from test table"); - assert!(checker_5.can_write_to(eager_5).is_err()); + assert!(checker_5.can_write_to(eager_5).is_ok()); } } diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 1a1cc9d11f..c9813831ae 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -36,7 +36,7 @@ use arrow_cast::can_cast_types; use arrow_schema::{ArrowError, DataType, Fields, SchemaRef as ArrowSchemaRef}; use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; use datafusion_common::DFSchema; -use datafusion_expr::{col, lit, when, Expr}; +use datafusion_expr::{col, lit, when, Expr, ExprSchemable}; use datafusion_physical_expr::expressions::{self}; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::filter::FilterExec; @@ -1050,7 +1050,11 @@ impl std::future::IntoFuture for WriteBuilder { dfschema, )?, ) - .otherwise(col(generated_col.get_name()))?, + .otherwise(col(generated_col.get_name()))? + .cast_to( + &arrow_schema::DataType::try_from(&generated_col.data_type)?, + dfschema, + )?, dfschema, )?; Ok((generation_expr, field.name().to_owned())) diff --git a/crates/core/src/table/mod.rs b/crates/core/src/table/mod.rs index 56a1f9d1e7..10dc9ae0a2 100644 --- a/crates/core/src/table/mod.rs +++ b/crates/core/src/table/mod.rs @@ -158,7 +158,7 @@ impl DataCheck for Constraint { } /// A generated column -#[derive(Eq, PartialEq, Debug, Default, Clone)] +#[derive(Eq, PartialEq, Debug, Clone)] pub struct GeneratedColumn { /// The full path to the field. pub name: String, @@ -166,16 +166,19 @@ pub struct GeneratedColumn { pub generation_expr: String, /// The SQL string that must always evaluate to true. pub validation_expr: String, + /// Data Type + pub data_type: DataType, } impl GeneratedColumn { /// Create a new invariant - pub fn new(field_name: &str, sql_generation: &str) -> Self { + pub fn new(field_name: &str, sql_generation: &str, data_type: &DataType) -> Self { Self { name: field_name.to_string(), generation_expr: sql_generation.to_string(), validation_expr: format!("{field_name} = {sql_generation} OR ({field_name} IS NULL AND {sql_generation} IS NULL)"), - // validation_expr: format!("{} <=> {}", field_name, sql_generation), + // validation_expr: format!("{} <=> {}", field_name, sql_generation), // update to + data_type: data_type.clone() } } From 823d823aefff47226ac15c39355d598645b999e1 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Mon, 13 Jan 2025 09:35:21 +0100 Subject: [PATCH 4/6] fix: disallow new generated columns Signed-off-by: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> --- crates/core/src/operations/add_column.rs | 13 +++++++++- .../core/src/operations/cast/merge_schema.rs | 26 +++++++++++++++++-- crates/core/src/operations/create.rs | 1 - 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/crates/core/src/operations/add_column.rs b/crates/core/src/operations/add_column.rs index a3477405af..9da5b86111 100644 --- a/crates/core/src/operations/add_column.rs +++ b/crates/core/src/operations/add_column.rs @@ -8,7 +8,7 @@ use itertools::Itertools; use super::transaction::{CommitBuilder, CommitProperties, PROTOCOL}; use super::{CustomExecuteHandler, Operation}; -use crate::kernel::StructField; +use crate::kernel::{StructField, StructTypeExt}; use crate::logstore::LogStoreRef; use crate::operations::cast::merge_schema::merge_delta_struct; use crate::protocol::DeltaOperation; @@ -85,6 +85,17 @@ impl std::future::IntoFuture for AddColumnBuilder { this.pre_execute(operation_id).await?; let fields_right = &StructType::new(fields.clone()); + + if !fields_right + .get_generated_columns() + .unwrap_or_default() + .is_empty() + { + return Err(DeltaTableError::Generic( + "New columns cannot be a generated column".to_string(), + )); + } + let table_schema = this.snapshot.schema(); let new_table_schema = merge_delta_struct(table_schema, fields_right)?; diff --git a/crates/core/src/operations/cast/merge_schema.rs b/crates/core/src/operations/cast/merge_schema.rs index 64fe2b7ed6..b57c29b2e8 100644 --- a/crates/core/src/operations/cast/merge_schema.rs +++ b/crates/core/src/operations/cast/merge_schema.rs @@ -7,6 +7,7 @@ use arrow_schema::{ ArrowError, DataType, Field as ArrowField, Fields, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, }; +use delta_kernel::schema::ColumnMetadataKey; use crate::kernel::{ArrayType, DataType as DeltaDataType, MapType, StructField, StructType}; @@ -23,7 +24,16 @@ fn try_merge_metadata( ))); } } else { - left.insert(k.clone(), v.clone()); + // I'm not sure if updating the schema metadata is even valid? + if k != ColumnMetadataKey::GenerationExpression.as_ref() { + // At least new generated expression may not be insert into existing column metadata! + left.insert(k.clone(), v.clone()); + } else { + return Err(ArrowError::SchemaError(format!( + "Cannot add generated expressions to exists columns {}", + k + ))); + } } } Ok(()) @@ -322,6 +332,10 @@ fn merge_arrow_vec_fields( Err(e) } Ok(mut f) => { + // UNDO the implicit schema merging of batch fields into table fields that is done by + // field.try_merge + f.set_metadata(right_field.metadata().clone()); + let mut field_matadata = f.metadata().clone(); try_merge_metadata(&mut field_matadata, right_field.metadata())?; f.set_metadata(field_matadata); @@ -338,7 +352,15 @@ fn merge_arrow_vec_fields( if preserve_new_fields { for field in batch_fields.into_iter() { if table_fields.find(field.name()).is_none() { - fields.push(field.as_ref().clone()); + if !field + .metadata() + .contains_key(ColumnMetadataKey::GenerationExpression.as_ref()) + { + fields.push(field.as_ref().clone()); + } else { + errors.push("Schema evolved fields cannot have generated expressions. Recreate the table to achieve this.".to_string()); + return Err(ArrowError::SchemaError(errors.join("\n"))); + } } } } diff --git a/crates/core/src/operations/create.rs b/crates/core/src/operations/create.rs index bcf79650cf..a4006c692c 100644 --- a/crates/core/src/operations/create.rs +++ b/crates/core/src/operations/create.rs @@ -20,7 +20,6 @@ use crate::protocol::{DeltaOperation, SaveMode}; use crate::table::builder::ensure_table_uri; use crate::table::config::TableProperty; use crate::{DeltaTable, DeltaTableBuilder}; -use delta_kernel::table_features::{ReaderFeatures, WriterFeatures}; #[derive(thiserror::Error, Debug)] enum CreateError { From 136779148ef73eaa616fe4b1b9e859ff90d22822 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Mon, 13 Jan 2025 11:13:38 +0100 Subject: [PATCH 5/6] chore: generated expression parsing improvement, support on first write/create Signed-off-by: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> --- crates/core/src/delta_datafusion/mod.rs | 11 ++++++++ crates/core/src/kernel/models/schema.rs | 27 ++++++++++++++++--- crates/core/src/operations/create.rs | 1 - .../src/operations/transaction/protocol.rs | 1 - crates/core/src/operations/write.rs | 7 +++-- 5 files changed, 40 insertions(+), 7 deletions(-) diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs index 08c38ddff3..a89aa710d9 100644 --- a/crates/core/src/delta_datafusion/mod.rs +++ b/crates/core/src/delta_datafusion/mod.rs @@ -1198,6 +1198,17 @@ impl DeltaDataChecker { } } + /// Create a new DeltaDataChecker with a specified set of generated columns + pub fn new_with_generated_columns(generated_columns: Vec) -> Self { + Self { + constraints: vec![], + invariants: vec![], + generated_columns, + non_nullable_columns: vec![], + ctx: DeltaSessionContext::default().into(), + } + } + /// Specify the Datafusion context pub fn with_session_context(mut self, context: SessionContext) -> Self { self.ctx = context; diff --git a/crates/core/src/kernel/models/schema.rs b/crates/core/src/kernel/models/schema.rs index e084fc0c0e..7c6be8a7dd 100644 --- a/crates/core/src/kernel/models/schema.rs +++ b/crates/core/src/kernel/models/schema.rs @@ -75,11 +75,32 @@ impl StructTypeExt for StructType { line: generated_col_string.to_string(), } })?; - if let Value::String(sql) = json { - generated_cols.push(GeneratedColumn::new(&field_path, &sql, field.data_type())); - } + match json { + Value::String(sql) => generated_cols.push(GeneratedColumn::new( + &field_path, + &sql, + field.data_type(), + )), + Value::Number(sql) => generated_cols.push(GeneratedColumn::new( + &field_path, + &format!("{}", sql), + field.data_type(), + )), + Value::Bool(sql) => generated_cols.push(GeneratedColumn::new( + &field_path, + &format!("{}", sql), + field.data_type(), + )), + Value::Array(sql) => generated_cols.push(GeneratedColumn::new( + &field_path, + &format!("{:?}", sql), + field.data_type(), + )), + _ => (), // Other types not sure what to do then + }; } } + dbg!(generated_cols.clone()); Ok(generated_cols) } diff --git a/crates/core/src/operations/create.rs b/crates/core/src/operations/create.rs index a4006c692c..ea051c58ef 100644 --- a/crates/core/src/operations/create.rs +++ b/crates/core/src/operations/create.rs @@ -6,7 +6,6 @@ use std::sync::Arc; use delta_kernel::schema::MetadataValue; use futures::future::BoxFuture; -use maplit::hashset; use serde_json::Value; use tracing::log::*; use uuid::Uuid; diff --git a/crates/core/src/operations/transaction/protocol.rs b/crates/core/src/operations/transaction/protocol.rs index 3fe94a5653..5a6ef6a3cf 100644 --- a/crates/core/src/operations/transaction/protocol.rs +++ b/crates/core/src/operations/transaction/protocol.rs @@ -2,7 +2,6 @@ use std::collections::HashSet; use lazy_static::lazy_static; use once_cell::sync::Lazy; -use tracing::log::*; use super::{TableReference, TransactionError}; use crate::kernel::{contains_timestampntz, Action, EagerSnapshot, Schema}; diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index c9813831ae..7004df74c4 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -25,7 +25,6 @@ //! ```` use std::collections::HashMap; -use std::hash::Hash; use std::str::FromStr; use std::sync::Arc; use std::time::{Instant, SystemTime, UNIX_EPOCH}; @@ -439,7 +438,11 @@ async fn write_execution_plan_with_predicate( let checker = if let Some(snapshot) = snapshot { DeltaDataChecker::new(snapshot) } else { - DeltaDataChecker::empty() + debug!("Using plan schema to derive generated columns, since no shapshot was provided. Implies first write."); + let delta_schema: StructType = schema.as_ref().try_into()?; + DeltaDataChecker::new_with_generated_columns( + delta_schema.get_generated_columns().unwrap_or_default(), + ) }; let checker = match predicate { Some(pred) => { From d54116025ff7765c3859f16adad1027396c42c63 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Mon, 13 Jan 2025 17:04:24 +0100 Subject: [PATCH 6/6] chore: tests Signed-off-by: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> --- crates/core/src/kernel/models/schema.rs | 44 ++++- crates/core/src/operations/merge/mod.rs | 7 +- python/tests/test_generated_columns.py | 220 ++++++++++++++++++++++++ 3 files changed, 267 insertions(+), 4 deletions(-) create mode 100644 python/tests/test_generated_columns.py diff --git a/crates/core/src/kernel/models/schema.rs b/crates/core/src/kernel/models/schema.rs index 7c6be8a7dd..bd76f0b3e9 100644 --- a/crates/core/src/kernel/models/schema.rs +++ b/crates/core/src/kernel/models/schema.rs @@ -100,7 +100,6 @@ impl StructTypeExt for StructType { }; } } - dbg!(generated_cols.clone()); Ok(generated_cols) } @@ -183,6 +182,49 @@ mod tests { use serde_json; use serde_json::json; + #[test] + fn test_get_generated_columns() { + let schema: StructType = serde_json::from_value(json!( + { + "type":"struct", + "fields":[ + {"name":"id","type":"integer","nullable":true,"metadata":{}}, + {"name":"gc","type":"integer","nullable":true,"metadata":{}}] + } + )) + .unwrap(); + let cols = schema.get_generated_columns().unwrap(); + assert_eq!(cols.len(), 0); + + let schema: StructType = serde_json::from_value(json!( + { + "type":"struct", + "fields":[ + {"name":"id","type":"integer","nullable":true,"metadata":{}}, + {"name":"gc","type":"integer","nullable":true,"metadata":{"delta.generationExpression":"\"5\""}}] + } + )).unwrap(); + let cols = schema.get_generated_columns().unwrap(); + assert_eq!(cols.len(), 1); + assert_eq!(cols[0].data_type, DataType::INTEGER); + assert_eq!( + cols[0].validation_expr, + "gc = 5 OR (gc IS NULL AND 5 IS NULL)" + ); + + let schema: StructType = serde_json::from_value(json!( + { + "type":"struct", + "fields":[ + {"name":"id","type":"integer","nullable":true,"metadata":{}}, + {"name":"gc","type":"integer","nullable":true,"metadata":{"delta.generationExpression":"\"5\""}}, + {"name":"id2","type":"integer","nullable":true,"metadata":{"delta.generationExpression":"\"id * 10\""}},] + } + )).unwrap(); + let cols = schema.get_generated_columns().unwrap(); + assert_eq!(cols.len(), 2); + } + #[test] fn test_get_invariants() { let schema: StructType = serde_json::from_value(json!({ diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index 0e2541349c..e58fd22664 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -761,11 +761,12 @@ async fn execute( for generated_col in generated_cols { let col_name = generated_col.get_name(); - if !df + if df .clone() .schema() - .field_names() - .contains(&col_name.to_string()) + .field_with_unqualified_name(&col_name.to_string()) + .is_err() + // implies it doesn't exist { debug!( "Adding missing generated column {} in source as placeholder", diff --git a/python/tests/test_generated_columns.py b/python/tests/test_generated_columns.py new file mode 100644 index 0000000000..b329c948be --- /dev/null +++ b/python/tests/test_generated_columns.py @@ -0,0 +1,220 @@ +import pyarrow as pa +import pytest + +from deltalake import DeltaTable, Field, Schema, write_deltalake +from deltalake.exceptions import DeltaError, SchemaMismatchError +from deltalake.schema import PrimitiveType + + +@pytest.fixture +def gc_schema() -> Schema: + return Schema( + [ + Field(name="id", type=PrimitiveType("integer")), + Field( + name="gc", + type=PrimitiveType("integer"), + metadata={"delta.generationExpression": "'5'"}, + ), + ] + ) + + +@pytest.fixture +def valid_gc_data() -> pa.Table: + id_col = pa.field("id", pa.int32()) + gc = pa.field("gc", pa.int32()).with_metadata({"delta.generationExpression": "10"}) + data = pa.Table.from_pydict( + {"id": [1, 2], "gc": [10, 10]}, schema=pa.schema([id_col, gc]) + ) + return data + + +@pytest.fixture +def data_without_gc() -> pa.Table: + id_col = pa.field("id", pa.int32()) + data = pa.Table.from_pydict({"id": [1, 2]}, schema=pa.schema([id_col])) + return data + + +@pytest.fixture +def invalid_gc_data() -> pa.Table: + id_col = pa.field("id", pa.int32()) + gc = pa.field("gc", pa.int32()).with_metadata({"delta.generationExpression": "10"}) + data = pa.Table.from_pydict( + {"id": [1, 2], "gc": [5, 10]}, schema=pa.schema([id_col, gc]) + ) + return data + + +@pytest.fixture +def table_with_gc(tmp_path, gc_schema) -> DeltaTable: + dt = DeltaTable.create( + tmp_path, + schema=gc_schema, + ) + return dt + + +def test_create_table_with_generated_columns(tmp_path, gc_schema: Schema): + dt = DeltaTable.create( + tmp_path, + schema=gc_schema, + ) + protocol = dt.protocol() + assert protocol.min_writer_version == 4 + + dt = DeltaTable.create( + tmp_path, + schema=gc_schema, + mode="overwrite", + configuration={"delta.minWriterVersion": "7"}, + ) + protocol = dt.protocol() + + assert dt.version() == 1 + assert protocol.writer_features is not None + assert "generatedColumns" in protocol.writer_features + + +def test_write_with_gc(tmp_path, valid_gc_data): + write_deltalake(tmp_path, mode="append", data=valid_gc_data) + dt = DeltaTable(tmp_path) + + assert dt.protocol().min_writer_version == 4 + assert dt.to_pyarrow_table() == valid_gc_data + + +def test_write_with_gc_higher_writer_version(tmp_path, valid_gc_data): + write_deltalake( + tmp_path, + mode="append", + data=valid_gc_data, + configuration={"delta.minWriterVersion": "7"}, + ) + dt = DeltaTable(tmp_path) + protocol = dt.protocol() + assert protocol.min_writer_version == 7 + assert protocol.writer_features is not None + assert "generatedColumns" in protocol.writer_features + assert dt.to_pyarrow_table() == valid_gc_data + + +def test_write_with_invalid_gc(tmp_path, invalid_gc_data): + import re + + with pytest.raises( + DeltaError, + match=re.escape( + 'Invariant violations: ["Check or Invariant (gc = 10 OR (gc IS NULL AND 10 IS NULL)) violated by value in row: [5]"]' + ), + ): + write_deltalake(tmp_path, mode="append", data=invalid_gc_data) + + +def test_write_with_invalid_gc_to_table(table_with_gc, invalid_gc_data): + import re + + with pytest.raises( + DeltaError, + match=re.escape( + "Invariant violations: [\"Check or Invariant (gc = '5' OR (gc IS NULL AND '5' IS NULL)) violated by value in row: [10]\"]" + ), + ): + write_deltalake(table_with_gc, mode="append", data=invalid_gc_data) + + +def test_write_to_table_generating_data(table_with_gc: DeltaTable): + id_col = pa.field("id", pa.int32()) + data = pa.Table.from_pydict({"id": [1, 2]}, schema=pa.schema([id_col])) + write_deltalake(table_with_gc, mode="append", data=data) + + id_col = pa.field("id", pa.int32()) + gc = pa.field("gc", pa.int32()) + expected_data = pa.Table.from_pydict( + {"id": [1, 2], "gc": [5, 5]}, schema=pa.schema([id_col, gc]) + ) + + assert table_with_gc.version() == 1 + assert table_with_gc.to_pyarrow_table() == expected_data + + +def test_raise_when_gc_passed_during_schema_evolution( + tmp_path, data_without_gc, valid_gc_data +): + write_deltalake( + tmp_path, + mode="append", + data=data_without_gc, + ) + dt = DeltaTable(tmp_path) + assert dt.protocol().min_writer_version == 2 + + with pytest.raises( + SchemaMismatchError, + match="Schema evolved fields cannot have generated expressions. Recreate the table to achieve this.", + ): + write_deltalake( + dt, + mode="append", + data=valid_gc_data, + schema_mode="merge", + ) + + +def test_raise_when_gc_passed_during_adding_new_columns(tmp_path, data_without_gc): + write_deltalake( + tmp_path, + mode="append", + data=data_without_gc, + ) + dt = DeltaTable(tmp_path) + assert dt.protocol().min_writer_version == 2 + + with pytest.raises(DeltaError, match="New columns cannot be a generated column"): + dt.alter.add_columns( + fields=[ + Field( + name="gc", + type=PrimitiveType("integer"), + metadata={"delta.generationExpression": "'5'"}, + ) + ] + ) + + +def test_merge_with_gc(table_with_gc: DeltaTable, data_without_gc): + ( + table_with_gc.merge( + data_without_gc, predicate="s.id = t.id", source_alias="s", target_alias="t" + ) + .when_not_matched_insert_all() + .execute() + ) + id_col = pa.field("id", pa.int32()) + gc = pa.field("gc", pa.int32()) + expected_data = pa.Table.from_pydict( + {"id": [1, 2], "gc": [5, 5]}, schema=pa.schema([id_col, gc]) + ) + assert table_with_gc.to_pyarrow_table() == expected_data + + +def test_merge_with_gc_invalid(table_with_gc: DeltaTable, invalid_gc_data): + import re + + with pytest.raises( + DeltaError, + match=re.escape( + "Invariant violations: [\"Check or Invariant (gc = '5' OR (gc IS NULL AND '5' IS NULL)) violated by value in row: [10]\"]" + ), + ): + ( + table_with_gc.merge( + invalid_gc_data, + predicate="s.id = t.id", + source_alias="s", + target_alias="t", + ) + .when_not_matched_insert_all() + .execute() + )