Skip to content

Commit

Permalink
fix: cast generated col exprs always, don'fetch nested metadata
Browse files Browse the repository at this point in the history
Signed-off-by: Ion Koutsouris <[email protected]>
  • Loading branch information
ion-elgreco committed Jan 12, 2025
1 parent 0a8bea6 commit 173659c
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 52 deletions.
45 changes: 2 additions & 43 deletions crates/core/src/kernel/models/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,56 +56,15 @@ pub trait StructTypeExt {
}

impl StructTypeExt for StructType {
/// Get all invariants in the schemas
/// Get all get_generated_columns in the schemas
fn get_generated_columns(&self) -> Result<Vec<GeneratedColumn>, Error> {
let mut remaining_fields: Vec<(String, StructField)> = self
.fields()
.map(|field| (field.name.clone(), field.clone()))
.collect();
let mut generated_cols: Vec<GeneratedColumn> = Vec::new();

let add_segment = |prefix: &str, segment: &str| -> String {
if prefix.is_empty() {
segment.to_owned()
} else {
format!("{prefix}.{segment}")
}
};

while let Some((field_path, field)) = remaining_fields.pop() {
match field.data_type() {
DataType::Struct(inner) => {
remaining_fields.extend(
inner
.fields()
.map(|field| {
let new_prefix = add_segment(&field_path, &field.name);
(new_prefix, field.clone())
})
.collect::<Vec<(String, StructField)>>(),
);
}
DataType::Array(inner) => {
let element_field_name = add_segment(&field_path, "element");
remaining_fields.push((
element_field_name,
StructField::new("".to_string(), inner.element_type.clone(), false),
));
}
DataType::Map(inner) => {
let key_field_name = add_segment(&field_path, "key");
remaining_fields.push((
key_field_name,
StructField::new("".to_string(), inner.key_type.clone(), false),
));
let value_field_name = add_segment(&field_path, "value");
remaining_fields.push((
value_field_name,
StructField::new("".to_string(), inner.value_type.clone(), false),
));
}
_ => {}
}
if let Some(MetadataValue::String(generated_col_string)) = field
.metadata
.get(ColumnMetadataKey::GenerationExpression.as_ref())
Expand All @@ -117,7 +76,7 @@ impl StructTypeExt for StructType {
}
})?;
if let Value::String(sql) = json {
generated_cols.push(GeneratedColumn::new(&field_path, &sql));
generated_cols.push(GeneratedColumn::new(&field_path, &sql, field.data_type()));
}
}
}
Expand Down
10 changes: 8 additions & 2 deletions crates/core/src/operations/merge/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{Column, DFSchema, ScalarValue, TableReference};
use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr, JoinType};
use datafusion_expr::{
Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, UNNAMED_TABLE,
ExprSchemable, Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode,
UNNAMED_TABLE,
};

use filter::try_construct_early_filter;
Expand Down Expand Up @@ -805,7 +806,12 @@ async fn execute(

df = df.clone().with_column(
generated_col.get_name(),
when(col(col_name).is_null(), generation_expr).otherwise(col(col_name))?,
when(col(col_name).is_null(), generation_expr)
.otherwise(col(col_name))?
.cast_to(
&arrow_schema::DataType::try_from(&generated_col.data_type)?,
df.schema(),
)?,
)?
}
Ok(df)
Expand Down
4 changes: 2 additions & 2 deletions crates/core/src/operations/transaction/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use once_cell::sync::Lazy;
use tracing::log::*;

use super::{TableReference, TransactionError};
use crate::kernel::{contains_timestampntz, Action, DataType, EagerSnapshot, Schema, StructField};
use crate::kernel::{contains_timestampntz, Action, EagerSnapshot, Schema};
use crate::protocol::DeltaOperation;
use crate::table::state::DeltaTableState;
use delta_kernel::table_features::{ReaderFeatures, WriterFeatures};
Expand Down Expand Up @@ -599,6 +599,6 @@ mod tests {
let eager_5 = table
.snapshot()
.expect("Failed to get snapshot from test table");
assert!(checker_5.can_write_to(eager_5).is_err());
assert!(checker_5.can_write_to(eager_5).is_ok());
}
}
8 changes: 6 additions & 2 deletions crates/core/src/operations/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use arrow_cast::can_cast_types;
use arrow_schema::{ArrowError, DataType, Fields, SchemaRef as ArrowSchemaRef};
use datafusion::execution::context::{SessionContext, SessionState, TaskContext};
use datafusion_common::DFSchema;
use datafusion_expr::{col, lit, when, Expr};
use datafusion_expr::{col, lit, when, Expr, ExprSchemable};
use datafusion_physical_expr::expressions::{self};
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_plan::filter::FilterExec;
Expand Down Expand Up @@ -1050,7 +1050,11 @@ impl std::future::IntoFuture for WriteBuilder {
dfschema,
)?,
)
.otherwise(col(generated_col.get_name()))?,
.otherwise(col(generated_col.get_name()))?
.cast_to(
&arrow_schema::DataType::try_from(&generated_col.data_type)?,
dfschema,
)?,
dfschema,
)?;
Ok((generation_expr, field.name().to_owned()))
Expand Down
9 changes: 6 additions & 3 deletions crates/core/src/table/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,24 +158,27 @@ impl DataCheck for Constraint {
}

/// A generated column
#[derive(Eq, PartialEq, Debug, Default, Clone)]
#[derive(Eq, PartialEq, Debug, Clone)]
pub struct GeneratedColumn {
/// The full path to the field.
pub name: String,
/// The SQL string that generate the column value.
pub generation_expr: String,
/// The SQL string that must always evaluate to true.
pub validation_expr: String,
/// Data Type
pub data_type: DataType,
}

impl GeneratedColumn {
/// Create a new invariant
pub fn new(field_name: &str, sql_generation: &str) -> Self {
pub fn new(field_name: &str, sql_generation: &str, data_type: &DataType) -> Self {
Self {
name: field_name.to_string(),
generation_expr: sql_generation.to_string(),
validation_expr: format!("{field_name} = {sql_generation} OR ({field_name} IS NULL AND {sql_generation} IS NULL)"),
// validation_expr: format!("{} <=> {}", field_name, sql_generation),
// validation_expr: format!("{} <=> {}", field_name, sql_generation), // update to
data_type: data_type.clone()
}
}

Expand Down

0 comments on commit 173659c

Please sign in to comment.