diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs
index 6aad9ab3f7..a89aa710d9 100644
--- a/crates/core/src/delta_datafusion/mod.rs
+++ b/crates/core/src/delta_datafusion/mod.rs
@@ -81,7 +81,7 @@ use crate::kernel::{Add, DataCheck, EagerSnapshot, Invariant, Snapshot, StructTy
 use crate::logstore::LogStoreRef;
 use crate::table::builder::ensure_table_uri;
 use crate::table::state::DeltaTableState;
-use crate::table::Constraint;
+use crate::table::{Constraint, GeneratedColumn};
 use crate::{open_table, open_table_with_storage_options, DeltaTable};
 
 pub(crate) const PATH_COLUMN: &str = "__delta_rs_path";
@@ -1159,6 +1159,7 @@ pub(crate) async fn execute_plan_to_batch(
 pub struct DeltaDataChecker {
     constraints: Vec<Constraint>,
     invariants: Vec<Invariant>,
+    generated_columns: Vec<GeneratedColumn>,
     non_nullable_columns: Vec<String>,
     ctx: SessionContext,
 }
@@ -1169,6 +1170,7 @@ impl DeltaDataChecker {
         Self {
             invariants: vec![],
             constraints: vec![],
+            generated_columns: vec![],
             non_nullable_columns: vec![],
             ctx: DeltaSessionContext::default().into(),
         }
     }
@@ -1179,6 +1181,7 @@ impl DeltaDataChecker {
         Self {
             invariants,
             constraints: vec![],
+            generated_columns: vec![],
             non_nullable_columns: vec![],
             ctx: DeltaSessionContext::default().into(),
         }
     }
@@ -1189,6 +1192,18 @@ impl DeltaDataChecker {
         Self {
             constraints,
             invariants: vec![],
+            generated_columns: vec![],
             non_nullable_columns: vec![],
             ctx: DeltaSessionContext::default().into(),
         }
     }
+
+    /// Create a new DeltaDataChecker with a specified set of generated columns
+    pub fn new_with_generated_columns(generated_columns: Vec<GeneratedColumn>) -> Self {
+        Self {
+            constraints: vec![],
+            invariants: vec![],
+            generated_columns,
+            non_nullable_columns: vec![],
+            ctx: DeltaSessionContext::default().into(),
+        }
+    }
@@ -1209,6 +1224,10 @@
     /// Create a new DeltaDataChecker
     pub fn new(snapshot: &DeltaTableState) -> Self {
         let invariants = snapshot.schema().get_invariants().unwrap_or_default();
+        let generated_columns = snapshot
+            .schema()
+            .get_generated_columns()
+            .unwrap_or_default();
         let constraints = snapshot.table_config().get_constraints();
         let non_nullable_columns = snapshot
             .schema()
@@ -1224,6 +1243,7 @@
         Self {
             invariants,
             constraints,
+            generated_columns,
             non_nullable_columns,
             ctx: DeltaSessionContext::default().into(),
         }
@@ -1236,7 +1256,9 @@
     pub async fn check_batch(&self, record_batch: &RecordBatch) -> Result<(), DeltaTableError> {
         self.check_nullability(record_batch)?;
         self.enforce_checks(record_batch, &self.invariants).await?;
-        self.enforce_checks(record_batch, &self.constraints).await
+        self.enforce_checks(record_batch, &self.constraints).await?;
+        self.enforce_checks(record_batch, &self.generated_columns)
+            .await
     }
 
     /// Return true if all the nullability checks are valid
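Note: the new `new_with_generated_columns` constructor is what write.rs (below) uses on a first write, when no snapshot exists yet. A minimal usage sketch, not part of the diff — it assumes a `schema: StructType` and a `batch: RecordBatch` in scope, with `StructTypeExt` imported, and elides error handling:

    // Derive the generated-column checks from the schema and validate a batch
    // against them, the same way invariants and constraints are enforced.
    let generated = schema.get_generated_columns().unwrap_or_default();
    let checker = DeltaDataChecker::new_with_generated_columns(generated);
    checker.check_batch(&batch).await?;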
+        /// Generation expression.
+        line: String,
+    },
+
     #[error("Table metadata is invalid: {0}")]
     MetadataError(String),
diff --git a/crates/core/src/kernel/models/actions.rs b/crates/core/src/kernel/models/actions.rs
index 3812dc4838..d825d5bec4 100644
--- a/crates/core/src/kernel/models/actions.rs
+++ b/crates/core/src/kernel/models/actions.rs
@@ -2,12 +2,14 @@ use std::collections::{HashMap, HashSet};
 use std::fmt::{self, Display};
 use std::str::FromStr;
 
+use delta_kernel::schema::{DataType, StructField};
 use maplit::hashset;
 use serde::{Deserialize, Serialize};
 use tracing::warn;
 use url::Url;
 
 use super::schema::StructType;
+use super::StructTypeExt;
 use crate::kernel::{error::Error, DeltaResult};
 use crate::TableProperty;
 use delta_kernel::table_features::{ReaderFeatures, WriterFeatures};
@@ -115,6 +117,19 @@ impl Metadata {
     }
 }
 
+/// Checks if a table contains timestamp_ntz in any field, including nested fields.
+pub fn contains_timestampntz<'a>(mut fields: impl Iterator<Item = &'a StructField>) -> bool {
+    fn _check_type(dtype: &DataType) -> bool {
+        match dtype {
+            &DataType::TIMESTAMP_NTZ => true,
+            DataType::Array(inner) => _check_type(inner.element_type()),
+            DataType::Struct(inner) => inner.fields().any(|f| _check_type(f.data_type())),
+            _ => false,
+        }
+    }
+    fields.any(|f| _check_type(f.data_type()))
+}
+
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)]
 #[serde(rename_all = "camelCase")]
 /// Defines a protocol action
@@ -146,8 +161,8 @@ impl Protocol {
         }
     }
 
-    /// set the reader features in the protocol action, automatically bumps min_reader_version
-    pub fn with_reader_features(
+    /// Append reader features to the protocol action; automatically bumps min_reader_version
+    pub fn append_reader_features(
         mut self,
         reader_features: impl IntoIterator<Item = impl Into<ReaderFeatures>>,
     ) -> Self {
@@ -156,14 +171,20 @@
             .map(Into::into)
             .collect::<HashSet<_>>();
         if !all_reader_features.is_empty() {
-            self.min_reader_version = 3
+            self.min_reader_version = 3;
+            match self.reader_features {
+                Some(mut features) => {
+                    features.extend(all_reader_features);
+                    self.reader_features = Some(features);
+                }
+                None => self.reader_features = Some(all_reader_features),
+            };
         }
-        self.reader_features = Some(all_reader_features);
         self
     }
 
-    /// set the writer features in the protocol action, automatically bumps min_writer_version
-    pub fn with_writer_features(
+    /// Append writer features to the protocol action; automatically bumps min_writer_version
+    pub fn append_writer_features(
         mut self,
         writer_features: impl IntoIterator<Item = impl Into<WriterFeatures>>,
     ) -> Self {
@@ -172,9 +193,16 @@
             .map(|c| c.into())
             .collect::<HashSet<_>>();
         if !all_writer_feautures.is_empty() {
-            self.min_writer_version = 7
+            self.min_writer_version = 7;
+
+            match self.writer_features {
+                Some(mut features) => {
+                    features.extend(all_writer_feautures);
+                    self.writer_features = Some(features);
+                }
+                None => self.writer_features = Some(all_writer_feautures),
+            };
         }
-        self.writer_features = Some(all_writer_feautures);
         self
     }
@@ -255,6 +283,32 @@ impl Protocol {
         }
         self
     }
+
+    /// Will apply the column metadata to the protocol by either bumping the version or setting
+    /// features
+    pub fn apply_column_metadata_to_protocol(
+        mut self,
+        schema: &StructType,
+    ) -> DeltaResult<Protocol> {
+        let generated_cols = schema.get_generated_columns()?;
+        let invariants = schema.get_invariants()?;
+        let contains_timestamp_ntz = self.contains_timestampntz(schema.fields());
+
+        if contains_timestamp_ntz {
+            self = self.enable_timestamp_ntz()
+        }
+
+        if !generated_cols.is_empty() {
+            self = self.enable_generated_columns()
+        }
+
+        if !invariants.is_empty() {
+            self = self.enable_invariants()
+        }
+
+        Ok(self)
+    }
+
     /// Will apply the properties to the protocol by either bumping the version or setting
     /// features
     pub fn apply_properties_to_protocol(
@@ -391,10 +445,35 @@ impl Protocol {
         }
         Ok(self)
     }
+
+    /// Checks if a table contains timestamp_ntz in any field, including nested fields.
+    fn contains_timestampntz<'a>(&self, fields: impl Iterator<Item = &'a StructField>) -> bool {
+        contains_timestampntz(fields)
+    }
+
     /// Enable timestamp_ntz in the protocol
-    pub fn enable_timestamp_ntz(mut self) -> Protocol {
-        self = self.with_reader_features(vec![ReaderFeatures::TimestampWithoutTimezone]);
-        self = self.with_writer_features(vec![WriterFeatures::TimestampWithoutTimezone]);
+    fn enable_timestamp_ntz(mut self) -> Self {
+        self = self.append_reader_features([ReaderFeatures::TimestampWithoutTimezone]);
+        self = self.append_writer_features([WriterFeatures::TimestampWithoutTimezone]);
+        self
+    }
+
+    /// Enable generated columns
+    fn enable_generated_columns(mut self) -> Self {
+        if self.min_writer_version < 4 {
+            self.min_writer_version = 4;
+        }
+        if self.min_writer_version >= 7 {
+            self = self.append_writer_features([WriterFeatures::GeneratedColumns]);
+        }
+        self
+    }
+
+    /// Enable invariants
+    fn enable_invariants(mut self) -> Self {
+        if self.min_writer_version >= 7 {
+            self = self.append_writer_features([WriterFeatures::Invariants]);
+        }
         self
     }
 }
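Note: a sketch of the resulting protocol bumps — illustrative assertions, not part of the diff; it assumes `schema` contains at least one generated column and that `WriterFeatures` is in scope, and elides error handling:

    // Legacy protocol: the writer version is bumped to 4; no feature flags involved.
    let protocol = Protocol::new(1, 2).apply_column_metadata_to_protocol(&schema)?;
    assert_eq!(protocol.min_writer_version, 4);

    // Table-features protocol (v7): the version stays, the feature is appended instead.
    let protocol = Protocol::new(3, 7).apply_column_metadata_to_protocol(&schema)?;
    assert!(protocol
        .writer_features
        .unwrap()
        .contains(&WriterFeatures::GeneratedColumns));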
diff --git a/crates/core/src/kernel/models/schema.rs b/crates/core/src/kernel/models/schema.rs
index 3a88564f1d..bd76f0b3e9 100644
--- a/crates/core/src/kernel/models/schema.rs
+++ b/crates/core/src/kernel/models/schema.rs
@@ -10,6 +10,7 @@ use serde_json::Value;
 
 use crate::kernel::error::Error;
 use crate::kernel::DataCheck;
+use crate::table::GeneratedColumn;
 
 /// Type alias for a top level schema
 pub type Schema = StructType;
@@ -49,9 +50,59 @@ impl DataCheck for Invariant {
 pub trait StructTypeExt {
     /// Get all invariants in the schemas
     fn get_invariants(&self) -> Result<Vec<Invariant>, Error>;
+
+    /// Get all generated column expressions
+    fn get_generated_columns(&self) -> Result<Vec<GeneratedColumn>, Error>;
 }
 
 impl StructTypeExt for StructType {
+    /// Get all generated columns in the schema
+    fn get_generated_columns(&self) -> Result<Vec<GeneratedColumn>, Error> {
+        let mut remaining_fields: Vec<(String, StructField)> = self
+            .fields()
+            .map(|field| (field.name.clone(), field.clone()))
+            .collect();
+        let mut generated_cols: Vec<GeneratedColumn> = Vec::new();
+
+        while let Some((field_path, field)) = remaining_fields.pop() {
+            if let Some(MetadataValue::String(generated_col_string)) = field
+                .metadata
+                .get(ColumnMetadataKey::GenerationExpression.as_ref())
+            {
+                let json: Value = serde_json::from_str(generated_col_string).map_err(|e| {
+                    Error::InvalidGenerationExpressionJson {
+                        json_err: e,
+                        line: generated_col_string.to_string(),
+                    }
+                })?;
+                match json {
+                    Value::String(sql) => generated_cols.push(GeneratedColumn::new(
+                        &field_path,
+                        &sql,
+                        field.data_type(),
+                    )),
+                    Value::Number(sql) => generated_cols.push(GeneratedColumn::new(
+                        &field_path,
+                        &format!("{}", sql),
+                        field.data_type(),
+                    )),
+                    Value::Bool(sql) => generated_cols.push(GeneratedColumn::new(
+                        &field_path,
+                        &format!("{}", sql),
+                        field.data_type(),
+                    )),
+                    Value::Array(sql) => generated_cols.push(GeneratedColumn::new(
+                        &field_path,
+                        &format!("{:?}", sql),
+                        field.data_type(),
+                    )),
+                    _ => (), // Null and Object values are not expected as generation expressions; skip them
+                };
+            }
+        }
+        Ok(generated_cols)
+    }
+
     /// Get all invariants in the schemas
     fn get_invariants(&self) -> Result<Vec<Invariant>, Error> {
         let mut remaining_fields: Vec<(String, StructField)> = self
@@ -131,6 +182,49 @@ mod tests {
     use serde_json;
     use serde_json::json;
 
+    #[test]
+    fn test_get_generated_columns() {
+        let schema: StructType = serde_json::from_value(json!(
+            {
+                "type":"struct",
+                "fields":[
+                    {"name":"id","type":"integer","nullable":true,"metadata":{}},
+                    {"name":"gc","type":"integer","nullable":true,"metadata":{}}]
+            }
+        ))
+        .unwrap();
+        let cols = schema.get_generated_columns().unwrap();
+        assert_eq!(cols.len(), 0);
+
+        let schema: StructType = serde_json::from_value(json!(
+            {
+                "type":"struct",
+                "fields":[
+                    {"name":"id","type":"integer","nullable":true,"metadata":{}},
+                    {"name":"gc","type":"integer","nullable":true,"metadata":{"delta.generationExpression":"\"5\""}}]
+            }
+        )).unwrap();
+        let cols = schema.get_generated_columns().unwrap();
+        assert_eq!(cols.len(), 1);
+        assert_eq!(cols[0].data_type, DataType::INTEGER);
+        assert_eq!(
+            cols[0].validation_expr,
+            "gc = 5 OR (gc IS NULL AND 5 IS NULL)"
+        );
+
+        let schema: StructType = serde_json::from_value(json!(
+            {
+                "type":"struct",
+                "fields":[
+                    {"name":"id","type":"integer","nullable":true,"metadata":{}},
+                    {"name":"gc","type":"integer","nullable":true,"metadata":{"delta.generationExpression":"\"5\""}},
+                    {"name":"id2","type":"integer","nullable":true,"metadata":{"delta.generationExpression":"\"id * 10\""}}]
+            }
+        )).unwrap();
+        let cols = schema.get_generated_columns().unwrap();
+        assert_eq!(cols.len(), 2);
+    }
+
     #[test]
     fn test_get_invariants() {
         let schema: StructType = serde_json::from_value(json!({
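Note: `delta.generationExpression` metadata values are themselves JSON-encoded, so a SQL expression arrives double-quoted — which is why the parser goes through `serde_json` before building the `GeneratedColumn`. A small round-trip sketch with illustrative values:

    let raw = "\"id * 10\""; // as stored in the column metadata
    let json: Value = serde_json::from_str(raw)?;
    // The inner string becomes the generation_expr of the GeneratedColumn.
    assert_eq!(json, Value::String("id * 10".to_string()));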
diff --git a/crates/core/src/operations/add_column.rs b/crates/core/src/operations/add_column.rs
index 2b6d9de7df..9da5b86111 100644
--- a/crates/core/src/operations/add_column.rs
+++ b/crates/core/src/operations/add_column.rs
@@ -8,7 +8,7 @@ use itertools::Itertools;
 
 use super::transaction::{CommitBuilder, CommitProperties, PROTOCOL};
 use super::{CustomExecuteHandler, Operation};
-use crate::kernel::StructField;
+use crate::kernel::{StructField, StructTypeExt};
 use crate::logstore::LogStoreRef;
 use crate::operations::cast::merge_schema::merge_delta_struct;
 use crate::protocol::DeltaOperation;
@@ -85,27 +85,26 @@ impl std::future::IntoFuture for AddColumnBuilder {
             this.pre_execute(operation_id).await?;
 
             let fields_right = &StructType::new(fields.clone());
+
+            if !fields_right
+                .get_generated_columns()
+                .unwrap_or_default()
+                .is_empty()
+            {
+                return Err(DeltaTableError::Generic(
+                    "New columns cannot be a generated column".to_string(),
+                ));
+            }
+
             let table_schema = this.snapshot.schema();
             let new_table_schema = merge_delta_struct(table_schema, fields_right)?;
 
-            // TODO(ion): Think of a way how we can simply this checking through the API or centralize some checks.
-            let contains_timestampntz = PROTOCOL.contains_timestampntz(fields.iter());
-            let protocol = this.snapshot.protocol();
-
-            let maybe_new_protocol = if contains_timestampntz {
-                let updated_protocol = protocol.clone().enable_timestamp_ntz();
-                if !(protocol.min_reader_version == 3 && protocol.min_writer_version == 7) {
-                    // Convert existing properties to features since we advanced the protocol to v3,7
-                    Some(
-                        updated_protocol
-                            .move_table_properties_into_features(&metadata.configuration),
-                    )
-                } else {
-                    Some(updated_protocol)
-                }
-            } else {
-                None
-            };
+            let current_protocol = this.snapshot.protocol();
+
+            let new_protocol = current_protocol
+                .clone()
+                .apply_column_metadata_to_protocol(&new_table_schema)?
+                .move_table_properties_into_features(&metadata.configuration);
 
             let operation = DeltaOperation::AddColumn {
                 fields: fields.into_iter().collect_vec(),
@@ -115,7 +114,7 @@
 
             let mut actions = vec![metadata.into()];
 
-            if let Some(new_protocol) = maybe_new_protocol {
+            if current_protocol != &new_protocol {
                 actions.push(new_protocol.into())
             }
diff --git a/crates/core/src/operations/add_feature.rs b/crates/core/src/operations/add_feature.rs
index 0e7f88ee7f..31dbb928bf 100644
--- a/crates/core/src/operations/add_feature.rs
+++ b/crates/core/src/operations/add_feature.rs
@@ -123,8 +123,8 @@ impl std::future::IntoFuture for AddTableFeatureBuilder {
                 }
             }
 
-            protocol = protocol.with_reader_features(reader_features);
-            protocol = protocol.with_writer_features(writer_features);
+            protocol = protocol.append_reader_features(reader_features);
+            protocol = protocol.append_writer_features(writer_features);
 
             let operation = DeltaOperation::AddFeature {
                 name: name.to_vec(),
diff --git a/crates/core/src/operations/cast/merge_schema.rs b/crates/core/src/operations/cast/merge_schema.rs
index 64fe2b7ed6..b57c29b2e8 100644
--- a/crates/core/src/operations/cast/merge_schema.rs
+++ b/crates/core/src/operations/cast/merge_schema.rs
@@ -7,6 +7,7 @@ use arrow_schema::{
     ArrowError, DataType, Field as ArrowField, Fields, Schema as ArrowSchema,
     SchemaRef as ArrowSchemaRef,
 };
+use delta_kernel::schema::ColumnMetadataKey;
 
 use crate::kernel::{ArrayType, DataType as DeltaDataType, MapType, StructField, StructType};
@@ -23,7 +24,16 @@ fn try_merge_metadata(
             )));
         }
     } else {
-        left.insert(k.clone(), v.clone());
+        // It is unclear whether updating the schema metadata in place is even valid?
+        if k != ColumnMetadataKey::GenerationExpression.as_ref() {
+            // At the very least, a new generation expression may not be inserted into existing column metadata!
+            left.insert(k.clone(), v.clone());
+        } else {
+            return Err(ArrowError::SchemaError(format!(
+                "Cannot add generated expressions to existing columns {}",
+                k
+            )));
+        }
     }
 }
 Ok(())
@@ -322,6 +332,10 @@ fn merge_arrow_vec_fields(
                     Err(e)
                 }
                 Ok(mut f) => {
+                    // UNDO the implicit schema merging of batch fields into table fields that is done by
+                    // field.try_merge
+                    f.set_metadata(right_field.metadata().clone());
+
                     let mut field_matadata = f.metadata().clone();
                     try_merge_metadata(&mut field_matadata, right_field.metadata())?;
                     f.set_metadata(field_matadata);
@@ -338,7 +352,15 @@ fn merge_arrow_vec_fields(
     if preserve_new_fields {
         for field in batch_fields.into_iter() {
             if table_fields.find(field.name()).is_none() {
-                fields.push(field.as_ref().clone());
+                if !field
+                    .metadata()
+                    .contains_key(ColumnMetadataKey::GenerationExpression.as_ref())
+                {
+                    fields.push(field.as_ref().clone());
+                } else {
+                    errors.push("Schema evolved fields cannot have generated expressions. Recreate the table to achieve this.".to_string());
+                    return Err(ArrowError::SchemaError(errors.join("\n")));
+                }
             }
         }
     }
diff --git a/crates/core/src/operations/cdc.rs b/crates/core/src/operations/cdc.rs
index c9d0ca0665..5e950402b8 100644
--- a/crates/core/src/operations/cdc.rs
+++ b/crates/core/src/operations/cdc.rs
@@ -175,7 +175,7 @@ mod tests {
     #[tokio::test]
     async fn test_should_write_cdc_v7_table_with_writer_feature() {
         let protocol =
-            Protocol::new(1, 7).with_writer_features(vec![WriterFeatures::ChangeDataFeed]);
+            Protocol::new(1, 7).append_writer_features(vec![WriterFeatures::ChangeDataFeed]);
         let actions = vec![Action::Protocol(protocol)];
         let mut table: DeltaTable = DeltaOps::new_in_memory()
             .create()
diff --git a/crates/core/src/operations/create.rs b/crates/core/src/operations/create.rs
index 5f6ef47bc0..ea051c58ef 100644
--- a/crates/core/src/operations/create.rs
+++ b/crates/core/src/operations/create.rs
@@ -6,7 +6,6 @@ use std::sync::Arc;
 
 use delta_kernel::schema::MetadataValue;
 use futures::future::BoxFuture;
-use maplit::hashset;
 use serde_json::Value;
 use tracing::log::*;
 use uuid::Uuid;
@@ -20,7 +19,6 @@ use crate::protocol::{DeltaOperation, SaveMode};
 use crate::table::builder::ensure_table_uri;
 use crate::table::config::TableProperty;
 use crate::{DeltaTable, DeltaTableBuilder};
-use delta_kernel::table_features::{ReaderFeatures, WriterFeatures};
 
 #[derive(thiserror::Error, Debug)]
 enum CreateError {
@@ -289,24 +287,12 @@ impl CreateBuilder {
         self.pre_execute(operation_id).await?;
 
         let configuration = self.configuration;
-        let contains_timestampntz = PROTOCOL.contains_timestampntz(self.columns.iter());
-        // TODO configure more permissive versions based on configuration. Also how should this ideally be handled?
-        // We set the lowest protocol we can, and if subsequent writes use newer features we update metadata?
-
-        let current_protocol = if contains_timestampntz {
-            Protocol {
-                min_reader_version: 3,
-                min_writer_version: 7,
-                writer_features: Some(hashset! {WriterFeatures::TimestampWithoutTimezone}),
-                reader_features: Some(hashset! {ReaderFeatures::TimestampWithoutTimezone}),
-            }
-        } else {
-            Protocol {
-                min_reader_version: PROTOCOL.default_reader_version(),
-                min_writer_version: PROTOCOL.default_writer_version(),
-                reader_features: None,
-                writer_features: None,
-            }
+
+        let current_protocol = Protocol {
+            min_reader_version: PROTOCOL.default_reader_version(),
+            min_writer_version: PROTOCOL.default_writer_version(),
+            reader_features: None,
+            writer_features: None,
         };
 
         let protocol = self
@@ -319,18 +305,21 @@ impl CreateBuilder {
             })
             .unwrap_or_else(|| current_protocol);
 
-        let protocol = protocol.apply_properties_to_protocol(
-            &configuration
-                .iter()
-                .map(|(k, v)| (k.clone(), v.clone().unwrap()))
-                .collect::<HashMap<String, String>>(),
-            self.raise_if_key_not_exists,
-        )?;
+        let schema = StructType::new(self.columns);
 
-        let protocol = protocol.move_table_properties_into_features(&configuration);
+        let protocol = protocol
+            .apply_properties_to_protocol(
+                &configuration
+                    .iter()
+                    .map(|(k, v)| (k.clone(), v.clone().unwrap()))
+                    .collect::<HashMap<String, String>>(),
+                self.raise_if_key_not_exists,
+            )?
+            .apply_column_metadata_to_protocol(&schema)?
+            .move_table_properties_into_features(&configuration);
 
         let mut metadata = Metadata::try_new(
-            StructType::new(self.columns),
+            schema,
             self.partition_columns.unwrap_or_default(),
             configuration,
         )?
diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs
index 8a4640b9a3..e58fd22664 100644
--- a/crates/core/src/operations/merge/mod.rs
+++ b/crates/core/src/operations/merge/mod.rs
@@ -50,7 +50,8 @@ use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::{Column, DFSchema, ScalarValue, TableReference};
 use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr, JoinType};
 use datafusion_expr::{
-    Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, UNNAMED_TABLE,
+    ExprSchemable, Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode,
+    UNNAMED_TABLE,
 };
 use filter::try_construct_early_filter;
@@ -74,7 +75,7 @@ use crate::delta_datafusion::{
     register_store, DataFusionMixins, DeltaColumn, DeltaScan, DeltaScanConfigBuilder,
     DeltaSessionConfig, DeltaTableProvider,
 };
-use crate::kernel::Action;
+use crate::kernel::{Action, DataCheck, StructTypeExt};
 use crate::logstore::LogStoreRef;
 use crate::operations::cdc::*;
 use crate::operations::merge::barrier::find_node;
@@ -82,6 +83,7 @@ use crate::operations::transaction::CommitBuilder;
 use crate::operations::write::{write_execution_plan, write_execution_plan_cdc, WriterStatsConfig};
 use crate::protocol::{DeltaOperation, MergePredicate};
 use crate::table::state::DeltaTableState;
+use crate::table::GeneratedColumn;
 use crate::{DeltaResult, DeltaTable, DeltaTableError};
 
 mod barrier;
@@ -750,6 +752,79 @@ async fn execute(
         None => TableReference::bare(UNNAMED_TABLE),
     };
 
+    /// Add null placeholders for generated columns that are missing from the source dataframe
+    fn add_missing_generated_columns(
+        mut df: DataFrame,
+        generated_cols: &Vec<GeneratedColumn>,
+    ) -> DeltaResult<(DataFrame, Vec<String>)> {
+        let mut missing_cols = vec![];
+        for generated_col in generated_cols {
+            let col_name = generated_col.get_name();
+
+            if df
+                .clone()
+                .schema()
+                .field_with_unqualified_name(&col_name.to_string())
+                .is_err()
+            // implies it doesn't exist
+            {
+                debug!(
+                    "Adding missing generated column {} in source as placeholder",
+                    col_name
+                );
+                // If the column doesn't exist, add a null column; the values are generated later,
+                // after the full merge has been projected. Generated columns that were provided
+                // up front are only validated during write.
+                missing_cols.push(col_name.to_string());
+                df = df
+                    .clone()
+                    .with_column(col_name, Expr::Literal(ScalarValue::Null))?;
+            }
+        }
+        Ok((df, missing_cols))
+    }
+
+    /// Add generated column expressions to a dataframe
+    fn add_generated_columns(
+        mut df: DataFrame,
+        generated_cols: &Vec<GeneratedColumn>,
+        generated_cols_missing_in_source: &Vec<String>,
+        state: &SessionState,
+    ) -> DeltaResult<DataFrame> {
+        debug!("Generating columns in dataframe");
+        for generated_col in generated_cols {
+            // Only generate values for columns that were missing from the source;
+            // generated columns that were provided at runtime are left untouched.
+            if !generated_cols_missing_in_source.contains(&generated_col.name) {
+                continue;
+            }
+
+            let generation_expr = state.create_logical_expr(
+                generated_col.get_generation_expression(),
+                df.clone().schema(),
+            )?;
+            let col_name = generated_col.get_name();
+
+            df = df.clone().with_column(
+                generated_col.get_name(),
+                when(col(col_name).is_null(), generation_expr)
+                    .otherwise(col(col_name))?
+                    .cast_to(
+                        &arrow_schema::DataType::try_from(&generated_col.data_type)?,
+                        df.schema(),
+                    )?,
+            )?
+        }
+        Ok(df)
+    }
+
+    let generated_col_expressions = snapshot
+        .schema()
+        .get_generated_columns()
+        .unwrap_or_default();
+
+    let (source, missing_generated_columns) =
+        add_missing_generated_columns(source, &generated_col_expressions)?;
+
     // This is only done to provide the source columns with a correct table reference. Just renaming the columns does not work
     let source = LogicalPlanBuilder::scan(
         source_name.clone(),
@@ -1160,26 +1235,40 @@
             lit(5),
         ))?;
 
-        change_data.push(
-            cdc_projection
-                .clone()
-                .filter(
-                    col(SOURCE_COLUMN)
-                        .is_true()
-                        .and(col(TARGET_COLUMN).is_null()),
-                )?
-                .select(write_projection.clone())?
-                .with_column(CDC_COLUMN_NAME, lit("insert"))?,
-        );
+        let mut cdc_insert_df = cdc_projection
+            .clone()
+            .filter(
+                col(SOURCE_COLUMN)
+                    .is_true()
+                    .and(col(TARGET_COLUMN).is_null()),
+            )?
+            .select(write_projection.clone())?
+            .with_column(CDC_COLUMN_NAME, lit("insert"))?;
+
+        cdc_insert_df = add_generated_columns(
+            cdc_insert_df,
+            &generated_col_expressions,
+            &missing_generated_columns,
+            &state,
+        )?;
 
-        let after = cdc_projection
+        change_data.push(cdc_insert_df);
+
+        let mut after = cdc_projection
             .clone()
             .filter(col(TARGET_COLUMN).is_true())?
             .select(write_projection.clone())?;
 
+        after = add_generated_columns(
+            after,
+            &generated_col_expressions,
+            &missing_generated_columns,
+            &state,
+        )?;
+
         // Extra select_columns is required so that before and after have same schema order
         // DataFusion doesn't have UnionByName yet, see https://github.com/apache/datafusion/issues/12650
-        let before = cdc_projection
+        let mut before = cdc_projection
             .clone()
             .filter(col(crate::delta_datafusion::PATH_COLUMN).is_not_null())?
             .select(
@@ -1199,11 +1288,24 @@
                     .collect::<Vec<_>>(),
             )?;
 
+        before = add_generated_columns(
+            before,
+            &generated_col_expressions,
+            &missing_generated_columns,
+            &state,
+        )?;
+
         let tracker = CDCTracker::new(before, after);
         change_data.push(tracker.collect()?);
     }
 
-    let project = filtered.clone().select(write_projection)?;
+    let mut project = filtered.clone().select(write_projection)?;
+    project = add_generated_columns(
+        project,
+        &generated_col_expressions,
+        &missing_generated_columns,
+        &state,
+    )?;
 
     let merge_final = &project.into_unoptimized_plan();
     let write = state.create_physical_plan(merge_final).await?;
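Note: taken together, the two helpers give merge the following behavior for a column generated as `5` — a sketch using the names introduced in this diff, not additional code:

    // A source {id: [1, 2]} that lacks "gc" first gains a null placeholder ...
    let (source, missing) = add_missing_generated_columns(source, &generated_col_expressions)?;
    // ... and after the merge projection the nulls are filled in by the generation
    // expression, so the written rows are {id: [1, 2], gc: [5, 5]}. Values for "gc"
    // supplied by the source are kept and only validated at write time.
    let project = add_generated_columns(project, &generated_col_expressions, &missing, &state)?;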
diff --git a/crates/core/src/operations/transaction/protocol.rs b/crates/core/src/operations/transaction/protocol.rs
index ef88fbf8e6..5a6ef6a3cf 100644
--- a/crates/core/src/operations/transaction/protocol.rs
+++ b/crates/core/src/operations/transaction/protocol.rs
@@ -2,10 +2,9 @@ use std::collections::HashSet;
 
 use lazy_static::lazy_static;
 use once_cell::sync::Lazy;
-use tracing::log::*;
 
 use super::{TableReference, TransactionError};
-use crate::kernel::{Action, DataType, EagerSnapshot, Schema, StructField};
+use crate::kernel::{contains_timestampntz, Action, EagerSnapshot, Schema};
 use crate::protocol::DeltaOperation;
 use crate::table::state::DeltaTableState;
 use delta_kernel::table_features::{ReaderFeatures, WriterFeatures};
@@ -79,29 +78,13 @@ impl ProtocolChecker {
         Ok(())
     }
 
-    /// checks if table contains timestamp_ntz in any field including nested fields.
-    pub fn contains_timestampntz<'a>(
-        &self,
-        mut fields: impl Iterator<Item = &'a StructField>,
-    ) -> bool {
-        fn _check_type(dtype: &DataType) -> bool {
-            match dtype {
-                &DataType::TIMESTAMP_NTZ => true,
-                DataType::Array(inner) => _check_type(inner.element_type()),
-                DataType::Struct(inner) => inner.fields().any(|f| _check_type(f.data_type())),
-                _ => false,
-            }
-        }
-        fields.any(|f| _check_type(f.data_type()))
-    }
-
     /// Check can write_timestamp_ntz
     pub fn check_can_write_timestamp_ntz(
         &self,
         snapshot: &DeltaTableState,
         schema: &Schema,
     ) -> Result<(), TransactionError> {
-        let contains_timestampntz = self.contains_timestampntz(schema.fields());
+        let contains_timestampntz = contains_timestampntz(schema.fields());
+
         let required_features: Option<&HashSet<WriterFeatures>> =
             match snapshot.protocol().min_writer_version {
                 0..=6 => None,
@@ -159,22 +142,6 @@ impl ProtocolChecker {
             _ => snapshot.protocol().writer_features.as_ref(),
         };
 
-        if (4..7).contains(&min_writer_version) {
-            debug!("min_writer_version is less 4-6, checking for unsupported table features");
-            if let Ok(schema) = snapshot.metadata().schema() {
-                for field in schema.fields() {
-                    if field.metadata.contains_key(
-                        crate::kernel::ColumnMetadataKey::GenerationExpression.as_ref(),
-                    ) {
-                        error!("The table contains `delta.generationExpression` settings on columns which mean this table cannot be currently written to by delta-rs");
-                        return Err(TransactionError::UnsupportedWriterFeatures(vec![
-                            WriterFeatures::GeneratedColumns,
-                        ]));
-                    }
-                }
-            }
-        }
-
         if let Some(features) = required_features {
             let mut diff = features.difference(&self.writer_features).peekable();
             if diff.peek().is_some() {
@@ -246,15 +213,13 @@ pub static INSTANCE: Lazy<ProtocolChecker> = Lazy::new(|| {
     #[cfg(feature = "cdf")]
     {
         writer_features.insert(WriterFeatures::ChangeDataFeed);
-        writer_features.insert(WriterFeatures::GeneratedColumns);
     }
     #[cfg(feature = "datafusion")]
     {
         writer_features.insert(WriterFeatures::Invariants);
        writer_features.insert(WriterFeatures::CheckConstraints);
+        writer_features.insert(WriterFeatures::GeneratedColumns);
     }
-    // writer_features.insert(WriterFeatures::ChangeDataFeed);
-    // writer_features.insert(WriterFeatures::GeneratedColumns);
     // writer_features.insert(WriterFeatures::ColumnMapping);
     // writer_features.insert(WriterFeatures::IdentityColumns);
@@ -584,7 +549,7 @@ mod tests {
         let checker_5 = ProtocolChecker::new(READER_V2.clone(), WRITER_V4.clone());
         let actions = vec![
            Action::Protocol(
-                Protocol::new(2, 4).with_writer_features(vec![WriterFeatures::ChangeDataFeed]),
+                Protocol::new(2, 4).append_writer_features(vec![WriterFeatures::ChangeDataFeed]),
            ),
            metadata_action(None).into(),
        ];
@@ -601,7 +566,7 @@ mod tests {
         let checker_5 = ProtocolChecker::new(READER_V2.clone(), WRITER_V4.clone());
         let actions = vec![
            Action::Protocol(
-                Protocol::new(2, 4).with_writer_features(vec![WriterFeatures::GeneratedColumns]),
+                Protocol::new(2, 4).append_writer_features([WriterFeatures::GeneratedColumns]),
            ),
            metadata_action(None).into(),
        ];
@@ -633,6 +598,6 @@ mod tests {
         let eager_5 = table
             .snapshot()
             .expect("Failed to get snapshot from test table");
-        assert!(checker_5.can_write_to(eager_5).is_err());
+        assert!(checker_5.can_write_to(eager_5).is_ok());
     }
 }
diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs
index 0be11f05cf..7004df74c4 100644
--- a/crates/core/src/operations/write.rs
+++ b/crates/core/src/operations/write.rs
@@ -35,7 +35,7 @@ use arrow_cast::can_cast_types;
 use arrow_schema::{ArrowError, DataType, Fields, SchemaRef as ArrowSchemaRef};
 use datafusion::execution::context::{SessionContext, SessionState, TaskContext};
 use datafusion_common::DFSchema;
-use datafusion_expr::{lit, Expr};
+use datafusion_expr::{col, lit, when, Expr, ExprSchemable};
 use datafusion_physical_expr::expressions::{self};
 use datafusion_physical_expr::PhysicalExpr;
 use datafusion_physical_plan::filter::FilterExec;
@@ -63,14 +63,15 @@ use crate::delta_datafusion::{
 use crate::delta_datafusion::{DataFusionMixins, DeltaDataChecker};
 use crate::errors::{DeltaResult, DeltaTableError};
 use crate::kernel::{
-    Action, ActionType, Add, AddCDCFile, Metadata, PartitionsExt, Remove, StructType,
+    Action, ActionType, Add, AddCDCFile, DataCheck, Metadata, PartitionsExt, Remove, StructType,
+    StructTypeExt,
 };
 use crate::logstore::LogStoreRef;
 use crate::operations::cast::{cast_record_batch, merge_schema::merge_arrow_schema};
 use crate::protocol::{DeltaOperation, SaveMode};
 use crate::storage::ObjectStoreRef;
 use crate::table::state::DeltaTableState;
-use crate::table::Constraint as DeltaConstraint;
+use crate::table::{Constraint as DeltaConstraint, GeneratedColumn};
 use crate::writer::record_batch::divide_by_partition_values;
 use crate::DeltaTable;
@@ -437,7 +438,11 @@ async fn write_execution_plan_with_predicate(
     let checker = if let Some(snapshot) = snapshot {
         DeltaDataChecker::new(snapshot)
     } else {
-        DeltaDataChecker::empty()
+        debug!("Using plan schema to derive generated columns, since no snapshot was provided. Implies first write.");
+        let delta_schema: StructType = schema.as_ref().try_into()?;
+        DeltaDataChecker::new_with_generated_columns(
+            delta_schema.get_generated_columns().unwrap_or_default(),
+        )
     };
     let checker = match predicate {
         Some(pred) => {
@@ -852,8 +857,14 @@ impl std::future::IntoFuture for WriteBuilder {
             } else {
                 Ok(this.partition_columns.unwrap_or_default())
             }?;
+
+            let generated_col_expressions = this
+                .snapshot
+                .as_ref()
+                .map(|v| v.schema().get_generated_columns().unwrap_or_default())
+                .unwrap_or_default();
             let mut schema_drift = false;
-            let plan = if let Some(plan) = this.input {
+            let mut plan = if let Some(plan) = this.input {
                 if this.schema_mode == Some(SchemaMode::Merge) {
                     return Err(DeltaTableError::Generic(
                         "Schema merge not supported yet for Datafusion".to_string(),
                     ));
                 }
                 if batches.is_empty() {
                     Err(WriteError::MissingData)
                 } else {
-                    let schema = batches[0].schema();
+                    let mut schema = batches[0].schema();
+                    // Schema merging code should be aware of columns that can be generated during write,
+                    // so they might be empty in the batch but they will exist in the input_schema();
+                    // in this case we have to insert the generated column and its type into the schema of the batch
                     let mut new_schema = None;
                     if let Some(snapshot) = &this.snapshot {
                         let table_schema = snapshot.input_schema()?;
+
+                        // Do an initial schema-merge round when there are generated column expressions.
+                        // This is to have the batch schema be the same as the input schema without adding new fields
+                        // from the incoming batch
+                        if !generated_col_expressions.is_empty() {
+                            schema = merge_arrow_schema(table_schema.clone(), schema, true)?;
+                        }
+
                         if let Err(schema_err) =
                             try_cast_batch(schema.fields(), table_schema.fields())
                         {
@@ -876,7 +898,11 @@
                             if this.mode == SaveMode::Overwrite
                                 && this.schema_mode == Some(SchemaMode::Overwrite)
                             {
-                                new_schema = None // we overwrite anyway, so no need to cast
+                                if generated_col_expressions.is_empty() {
+                                    new_schema = None // we overwrite anyway, so no need to cast
+                                } else {
+                                    new_schema = Some(schema.clone()) // we need to cast the batch to include the generated col as empty null
+                                }
                             } else if this.schema_mode == Some(SchemaMode::Merge) {
                                 new_schema = Some(merge_arrow_schema(
                                     table_schema.clone(),
                                     schema.clone(),
                                     schema_drift,
                                 )?);
                             }
                         } else if this.mode == SaveMode::Overwrite
                             && this.schema_mode == Some(SchemaMode::Overwrite)
                         {
-                            new_schema = None // we overwrite anyway, so no need to cast
+                            if generated_col_expressions.is_empty() {
+                                new_schema = None // we overwrite anyway, so no need to cast
+                            } else {
+                                new_schema = Some(schema.clone()) // we need to cast the batch to include the generated col as empty null
+                            }
                         } else {
                             // Schema needs to be merged so that utf8/binary/list types are preserved from the batch side if both table
                             // and batch contains such type. Other types are preserved from the table side.
@@ -912,7 +942,7 @@
                             &batch,
                             new_schema,
                             this.safe_cast,
-                            schema_drift, // Schema drifted so we have to add the missing columns/structfields.
+                            schema_drift || !generated_col_expressions.is_empty(), // Schema drifted so we have to add the missing columns/structfields or missing generated cols.
                         )?,
                         None => batch,
                     };
@@ -949,7 +979,7 @@
                                     &batch,
                                     new_schema.clone(),
                                     this.safe_cast,
-                                    schema_drift, // Schema drifted so we have to add the missing columns/structfields.
+                                    schema_drift || !generated_col_expressions.is_empty(), // Schema drifted so we have to add the missing columns/structfields or missing generated cols.
                                 )?);
                                 num_added_rows += batch.num_rows();
                             }
@@ -972,40 +1002,25 @@
             } else {
                 Err(WriteError::MissingData)
             }?;
+
            let schema = plan.schema();
             if this.schema_mode == Some(SchemaMode::Merge) && schema_drift {
                 if let Some(snapshot) = &this.snapshot {
                     let schema_struct: StructType = schema.clone().try_into()?;
                     let current_protocol = snapshot.protocol();
                     let configuration = snapshot.metadata().configuration.clone();
-                    let maybe_new_protocol = if PROTOCOL
-                        .contains_timestampntz(schema_struct.fields())
-                        && !current_protocol
-                            .reader_features
-                            .clone()
-                            .unwrap_or_default()
-                            .contains(&delta_kernel::table_features::ReaderFeatures::TimestampWithoutTimezone)
-                    // We can check only reader features, as reader and writer timestampNtz
-                    // should be always enabled together
-                    {
-                        let new_protocol = current_protocol.clone().enable_timestamp_ntz();
-                        if !(current_protocol.min_reader_version == 3
-                            && current_protocol.min_writer_version == 7)
-                        {
-                            Some(new_protocol.move_table_properties_into_features(&configuration))
-                        } else {
-                            Some(new_protocol)
-                        }
-                    } else {
-                        None
-                    };
+                    let new_protocol = current_protocol
+                        .clone()
+                        .apply_column_metadata_to_protocol(&schema_struct)?
+                        .move_table_properties_into_features(&configuration);
+
                     let schema_action = Action::Metadata(Metadata::try_new(
                         schema_struct,
                         partition_columns.clone(),
                         configuration,
                     )?);
                     actions.push(schema_action);
-                    if let Some(new_protocol) = maybe_new_protocol {
+                    if current_protocol != &new_protocol {
                         actions.push(new_protocol.into())
                     }
                 }
@@ -1019,6 +1034,59 @@
             }
            };
 
+            // Add when/then expressions for generated columns
+            if !generated_col_expressions.is_empty() {
+                fn create_field(
+                    idx: usize,
+                    field: &arrow_schema::Field,
+                    generated_cols_map: &HashMap<String, GeneratedColumn>,
+                    state: &datafusion::execution::session_state::SessionState,
+                    dfschema: &DFSchema,
+                ) -> DeltaResult<(Arc<dyn PhysicalExpr>, String)> {
+                    match generated_cols_map.get(field.name()) {
+                        Some(generated_col) => {
+                            let generation_expr = state.create_physical_expr(
+                                when(
+                                    col(generated_col.get_name()).is_null(),
+                                    state.create_logical_expr(
+                                        generated_col.get_generation_expression(),
+                                        dfschema,
+                                    )?,
+                                )
+                                .otherwise(col(generated_col.get_name()))?
+                                .cast_to(
+                                    &arrow_schema::DataType::try_from(&generated_col.data_type)?,
+                                    dfschema,
+                                )?,
+                                dfschema,
+                            )?;
+                            Ok((generation_expr, field.name().to_owned()))
+                        }
+                        None => Ok((
+                            Arc::new(expressions::Column::new(field.name(), idx)),
+                            field.name().to_owned(),
+                        )),
+                    }
+                }
+
+                let dfschema: DFSchema = schema.as_ref().clone().try_into()?;
+                let generated_cols_map = generated_col_expressions
+                    .into_iter()
+                    .map(|v| (v.name.clone(), v))
+                    .collect::<HashMap<String, GeneratedColumn>>();
+                let current_fields: DeltaResult<Vec<(Arc<dyn PhysicalExpr>, String)>> = plan
+                    .schema()
+                    .fields()
+                    .into_iter()
+                    .enumerate()
+                    .map(|(idx, field)| {
+                        create_field(idx, field, &generated_cols_map, &state, &dfschema)
+                    })
+                    .collect();
+
+                plan = Arc::new(ProjectionExec::try_new(current_fields?, plan.clone())?);
+            };
+
             let (predicate_str, predicate) = match this.predicate {
                 Some(predicate) => {
                     let pred = match predicate {
@@ -1074,43 +1142,25 @@
                 // Update metadata with new schema
                 let table_schema = snapshot.input_schema()?;
-                let configuration = snapshot.metadata().configuration.clone();
-                let current_protocol = snapshot.protocol();
-                let maybe_new_protocol = if PROTOCOL.contains_timestampntz(
-                    TryInto::<StructType>::try_into(schema.clone())?.fields(),
-                ) && !current_protocol
-                    .reader_features
-                    .clone()
-                    .unwrap_or_default()
-                    .contains(
-                        &delta_kernel::table_features::ReaderFeatures::TimestampWithoutTimezone,
-                    )
-                // We can check only reader features, as reader and writer timestampNtz
-                // should be always enabled together
-                {
-                    let new_protocol = current_protocol.clone().enable_timestamp_ntz();
-                    if !(current_protocol.min_reader_version == 3
-                        && current_protocol.min_writer_version == 7)
-                    {
-                        Some(new_protocol.move_table_properties_into_features(&configuration))
-                    } else {
-                        Some(new_protocol)
-                    }
-                } else {
-                    None
-                };
-
-                if let Some(protocol) = maybe_new_protocol {
-                    actions.push(protocol.into())
-                }
-
+                let delta_schema: StructType = schema.as_ref().try_into()?;
                 if schema != table_schema {
                     let mut metadata = snapshot.metadata().clone();
-                    let delta_schema: StructType = schema.as_ref().try_into()?;
+
                     metadata.schema_string = serde_json::to_string(&delta_schema)?;
                     actions.push(Action::Metadata(metadata));
                 }
+                let configuration = snapshot.metadata().configuration.clone();
+                let current_protocol = snapshot.protocol();
+                let new_protocol = current_protocol
+                    .clone()
+                    .apply_column_metadata_to_protocol(&delta_schema)?
+                    .move_table_properties_into_features(&configuration);
+
+                if current_protocol != &new_protocol {
+                    actions.push(new_protocol.into())
+                }
+
                 let deletion_timestamp = SystemTime::now()
                     .duration_since(UNIX_EPOCH)
                     .unwrap()
diff --git a/crates/core/src/table/mod.rs b/crates/core/src/table/mod.rs
index 1409c498c2..10dc9ae0a2 100644
--- a/crates/core/src/table/mod.rs
+++ b/crates/core/src/table/mod.rs
@@ -157,6 +157,46 @@ impl DataCheck for Constraint {
     }
 }
 
+/// A generated column
+#[derive(Eq, PartialEq, Debug, Clone)]
+pub struct GeneratedColumn {
+    /// The full path to the field.
+    pub name: String,
+    /// The SQL string that generates the column value.
+    pub generation_expr: String,
+    /// The SQL string that must always evaluate to true.
+    pub validation_expr: String,
+    /// Data Type
+    pub data_type: DataType,
+}
+
+impl GeneratedColumn {
+    /// Create a new generated column
+    pub fn new(field_name: &str, sql_generation: &str, data_type: &DataType) -> Self {
+        Self {
+            name: field_name.to_string(),
+            generation_expr: sql_generation.to_string(),
+            validation_expr: format!("{field_name} = {sql_generation} OR ({field_name} IS NULL AND {sql_generation} IS NULL)"),
+            // validation_expr: format!("{} <=> {}", field_name, sql_generation), // update to
+            data_type: data_type.clone()
+        }
+    }
+
+    /// Get the SQL expression that generates the column value
+    pub fn get_generation_expression(&self) -> &str {
+        &self.generation_expr
+    }
+}
+
+impl DataCheck for GeneratedColumn {
+    fn get_name(&self) -> &str {
+        &self.name
+    }
+
+    fn get_expression(&self) -> &str {
+        &self.validation_expr
+    }
+}
+
 /// Return partition fields along with their data type from the current schema.
 pub(crate) fn get_partition_col_data_types<'a>(
     schema: &'a StructType,
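Note: the validation expression is a null-safe equality check between the stored value and the generation expression. A worked example with illustrative values:

    let gc = GeneratedColumn::new("gc", "id * 10", &DataType::INTEGER);
    assert_eq!(
        gc.validation_expr,
        "gc = id * 10 OR (gc IS NULL AND id * 10 IS NULL)"
    );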
diff --git a/python/tests/test_generated_columns.py b/python/tests/test_generated_columns.py
new file mode 100644
index 0000000000..b329c948be
--- /dev/null
+++ b/python/tests/test_generated_columns.py
+import pyarrow as pa
+import pytest
+
+from deltalake import DeltaTable, Field, Schema, write_deltalake
+from deltalake.exceptions import DeltaError, SchemaMismatchError
+from deltalake.schema import PrimitiveType
+
+
+@pytest.fixture
+def gc_schema() -> Schema:
+    return Schema(
+        [
+            Field(name="id", type=PrimitiveType("integer")),
+            Field(
+                name="gc",
+                type=PrimitiveType("integer"),
+                metadata={"delta.generationExpression": "'5'"},
+            ),
+        ]
+    )
+
+
+@pytest.fixture
+def valid_gc_data() -> pa.Table:
+    id_col = pa.field("id", pa.int32())
+    gc = pa.field("gc", pa.int32()).with_metadata({"delta.generationExpression": "10"})
+    data = pa.Table.from_pydict(
+        {"id": [1, 2], "gc": [10, 10]}, schema=pa.schema([id_col, gc])
+    )
+    return data
+
+
+@pytest.fixture
+def data_without_gc() -> pa.Table:
+    id_col = pa.field("id", pa.int32())
+    data = pa.Table.from_pydict({"id": [1, 2]}, schema=pa.schema([id_col]))
+    return data
+
+
+@pytest.fixture
+def invalid_gc_data() -> pa.Table:
+    id_col = pa.field("id", pa.int32())
+    gc = pa.field("gc", pa.int32()).with_metadata({"delta.generationExpression": "10"})
+    data = pa.Table.from_pydict(
+        {"id": [1, 2], "gc": [5, 10]}, schema=pa.schema([id_col, gc])
+    )
+    return data
+
+
+@pytest.fixture
+def table_with_gc(tmp_path, gc_schema) -> DeltaTable:
+    dt = DeltaTable.create(
+        tmp_path,
+        schema=gc_schema,
+    )
+    return dt
+
+
+def test_create_table_with_generated_columns(tmp_path, gc_schema: Schema):
+    dt = DeltaTable.create(
+        tmp_path,
+        schema=gc_schema,
+    )
+    protocol = dt.protocol()
+    assert protocol.min_writer_version == 4
+
+    dt = DeltaTable.create(
+        tmp_path,
+        schema=gc_schema,
+        mode="overwrite",
+        configuration={"delta.minWriterVersion": "7"},
+    )
+    protocol = dt.protocol()
+
+    assert dt.version() == 1
+    assert protocol.writer_features is not None
+    assert "generatedColumns" in protocol.writer_features
+
+
+def test_write_with_gc(tmp_path, valid_gc_data):
+    write_deltalake(tmp_path, mode="append", data=valid_gc_data)
+    dt = DeltaTable(tmp_path)
+
+    assert dt.protocol().min_writer_version == 4
+    assert dt.to_pyarrow_table() == valid_gc_data
+
+
+def test_write_with_gc_higher_writer_version(tmp_path, valid_gc_data):
+    write_deltalake(
+        tmp_path,
+        mode="append",
+        data=valid_gc_data,
+        configuration={"delta.minWriterVersion": "7"},
+    )
+    dt = DeltaTable(tmp_path)
+    protocol = dt.protocol()
+    assert protocol.min_writer_version == 7
+    assert protocol.writer_features is not None
+    assert "generatedColumns" in protocol.writer_features
+    assert dt.to_pyarrow_table() == valid_gc_data
+
+
+def test_write_with_invalid_gc(tmp_path, invalid_gc_data):
+    import re
+
+    with pytest.raises(
+        DeltaError,
+        match=re.escape(
+            'Invariant violations: ["Check or Invariant (gc = 10 OR (gc IS NULL AND 10 IS NULL)) violated by value in row: [5]"]'
+        ),
+    ):
+        write_deltalake(tmp_path, mode="append", data=invalid_gc_data)
+
+
+def test_write_with_invalid_gc_to_table(table_with_gc, invalid_gc_data):
+    import re
+
+    with pytest.raises(
+        DeltaError,
+        match=re.escape(
+            "Invariant violations: [\"Check or Invariant (gc = '5' OR (gc IS NULL AND '5' IS NULL)) violated by value in row: [10]\"]"
+        ),
+    ):
+        write_deltalake(table_with_gc, mode="append", data=invalid_gc_data)
+
+
+def test_write_to_table_generating_data(table_with_gc: DeltaTable):
+    id_col = pa.field("id", pa.int32())
+    data = pa.Table.from_pydict({"id": [1, 2]}, schema=pa.schema([id_col]))
+    write_deltalake(table_with_gc, mode="append", data=data)
+
+    id_col = pa.field("id", pa.int32())
+    gc = pa.field("gc", pa.int32())
+    expected_data = pa.Table.from_pydict(
+        {"id": [1, 2], "gc": [5, 5]}, schema=pa.schema([id_col, gc])
+    )
+
+    assert table_with_gc.version() == 1
+    assert table_with_gc.to_pyarrow_table() == expected_data
+
+
+def test_raise_when_gc_passed_during_schema_evolution(
+    tmp_path, data_without_gc, valid_gc_data
+):
+    write_deltalake(
+        tmp_path,
+        mode="append",
+        data=data_without_gc,
+    )
+    dt = DeltaTable(tmp_path)
+    assert dt.protocol().min_writer_version == 2
+
+    with pytest.raises(
+        SchemaMismatchError,
+        match="Schema evolved fields cannot have generated expressions. Recreate the table to achieve this.",
+    ):
+        write_deltalake(
+            dt,
+            mode="append",
+            data=valid_gc_data,
+            schema_mode="merge",
+        )
+
+
+def test_raise_when_gc_passed_during_adding_new_columns(tmp_path, data_without_gc):
+    write_deltalake(
+        tmp_path,
+        mode="append",
+        data=data_without_gc,
+    )
+    dt = DeltaTable(tmp_path)
+    assert dt.protocol().min_writer_version == 2
+
+    with pytest.raises(DeltaError, match="New columns cannot be a generated column"):
+        dt.alter.add_columns(
+            fields=[
+                Field(
+                    name="gc",
+                    type=PrimitiveType("integer"),
+                    metadata={"delta.generationExpression": "'5'"},
+                )
+            ]
+        )
+
+
+def test_merge_with_gc(table_with_gc: DeltaTable, data_without_gc):
+    (
+        table_with_gc.merge(
+            data_without_gc, predicate="s.id = t.id", source_alias="s", target_alias="t"
+        )
+        .when_not_matched_insert_all()
+        .execute()
+    )
+    id_col = pa.field("id", pa.int32())
+    gc = pa.field("gc", pa.int32())
+    expected_data = pa.Table.from_pydict(
+        {"id": [1, 2], "gc": [5, 5]}, schema=pa.schema([id_col, gc])
+    )
+    assert table_with_gc.to_pyarrow_table() == expected_data
+
+
+def test_merge_with_gc_invalid(table_with_gc: DeltaTable, invalid_gc_data):
+    import re
+
+    with pytest.raises(
+        DeltaError,
+        match=re.escape(
+            "Invariant violations: [\"Check or Invariant (gc = '5' OR (gc IS NULL AND '5' IS NULL)) violated by value in row: [10]\"]"
+        ),
+    ):
+        (
+            table_with_gc.merge(
+                invalid_gc_data,
+                predicate="s.id = t.id",
+                source_alias="s",
+                target_alias="t",
+            )
+            .when_not_matched_insert_all()
+            .execute()
+        )