Skip to content

Commit

Permalink
feat: generated columns
Browse files Browse the repository at this point in the history
Signed-off-by: Ion Koutsouris <[email protected]>
  • Loading branch information
ion-elgreco committed Jan 12, 2025
1 parent c56d6c0 commit be67624
Show file tree
Hide file tree
Showing 11 changed files with 357 additions and 162 deletions.
15 changes: 13 additions & 2 deletions crates/core/src/delta_datafusion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ use crate::kernel::{Add, DataCheck, EagerSnapshot, Invariant, Snapshot, StructTy
use crate::logstore::LogStoreRef;
use crate::table::builder::ensure_table_uri;
use crate::table::state::DeltaTableState;
use crate::table::Constraint;
use crate::table::{Constraint, GeneratedColumn};
use crate::{open_table, open_table_with_storage_options, DeltaTable};

pub(crate) const PATH_COLUMN: &str = "__delta_rs_path";
Expand Down Expand Up @@ -1159,6 +1159,7 @@ pub(crate) async fn execute_plan_to_batch(
pub struct DeltaDataChecker {
    // CHECK-style constraints read from the table configuration.
    constraints: Vec<Constraint>,
    // Column invariants declared in the table schema.
    invariants: Vec<Invariant>,
    // Generated-column expressions declared in the table schema.
    generated_columns: Vec<GeneratedColumn>,
    // Names of columns that must not contain null values.
    non_nullable_columns: Vec<String>,
    // Session used when evaluating the check expressions against batches.
    ctx: SessionContext,
}
Expand All @@ -1169,6 +1170,7 @@ impl DeltaDataChecker {
Self {
invariants: vec![],
constraints: vec![],
generated_columns: vec![],
non_nullable_columns: vec![],
ctx: DeltaSessionContext::default().into(),
}
Expand All @@ -1179,6 +1181,7 @@ impl DeltaDataChecker {
Self {
invariants,
constraints: vec![],
generated_columns: vec![],
non_nullable_columns: vec![],
ctx: DeltaSessionContext::default().into(),
}
Expand All @@ -1189,6 +1192,7 @@ impl DeltaDataChecker {
Self {
constraints,
invariants: vec![],
generated_columns: vec![],
non_nullable_columns: vec![],
ctx: DeltaSessionContext::default().into(),
}
Expand All @@ -1209,6 +1213,10 @@ impl DeltaDataChecker {
/// Create a new DeltaDataChecker
pub fn new(snapshot: &DeltaTableState) -> Self {
let invariants = snapshot.schema().get_invariants().unwrap_or_default();
let generated_columns = snapshot
.schema()
.get_generated_columns()
.unwrap_or_default();
let constraints = snapshot.table_config().get_constraints();
let non_nullable_columns = snapshot
.schema()
Expand All @@ -1224,6 +1232,7 @@ impl DeltaDataChecker {
Self {
invariants,
constraints,
generated_columns,
non_nullable_columns,
ctx: DeltaSessionContext::default().into(),
}
Expand All @@ -1236,7 +1245,9 @@ impl DeltaDataChecker {
pub async fn check_batch(&self, record_batch: &RecordBatch) -> Result<(), DeltaTableError> {
    // Nullability is validated first; expression-based checks follow in
    // order: invariants, constraints, then generated-column expressions.
    // The first failing check short-circuits with an error.
    self.check_nullability(record_batch)?;
    self.enforce_checks(record_batch, &self.invariants).await?;
    self.enforce_checks(record_batch, &self.constraints).await?;
    self.enforce_checks(record_batch, &self.generated_columns)
        .await
}

/// Return true if all the nullability checks are valid
Expand Down
9 changes: 9 additions & 0 deletions crates/core/src/kernel/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ pub enum Error {
line: String,
},

/// Error returned when a field's generation expression metadata contains
/// invalid JSON.
#[error("Invalid JSON in generation expression, line=`{line}`, err=`{json_err}`")]
InvalidGenerationExpressionJson {
    /// JSON error details returned when parsing the generation expression JSON.
    json_err: serde_json::error::Error,
    /// Generation expression.
    line: String,
},

#[error("Table metadata is invalid: {0}")]
MetadataError(String),

Expand Down
99 changes: 89 additions & 10 deletions crates/core/src/kernel/models/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ use std::collections::{HashMap, HashSet};
use std::fmt::{self, Display};
use std::str::FromStr;

use delta_kernel::schema::{DataType, StructField};
use maplit::hashset;
use serde::{Deserialize, Serialize};
use tracing::warn;
use url::Url;

use super::schema::StructType;
use super::StructTypeExt;
use crate::kernel::{error::Error, DeltaResult};
use crate::TableProperty;
use delta_kernel::table_features::{ReaderFeatures, WriterFeatures};
Expand Down Expand Up @@ -115,6 +117,19 @@ impl Metadata {
}
}

/// Checks if the table contains timestamp_ntz in any field, including
/// fields nested inside structs, arrays, and maps.
pub fn contains_timestampntz<'a>(mut fields: impl Iterator<Item = &'a StructField>) -> bool {
    // Recursively inspect a data type for a TIMESTAMP_NTZ leaf.
    fn _check_type(dtype: &DataType) -> bool {
        match dtype {
            &DataType::TIMESTAMP_NTZ => true,
            DataType::Array(inner) => _check_type(inner.element_type()),
            // A map can nest a timestamp in either its key or its value type.
            DataType::Map(inner) => {
                _check_type(&inner.key_type) || _check_type(&inner.value_type)
            }
            DataType::Struct(inner) => inner.fields().any(|f| _check_type(f.data_type())),
            _ => false,
        }
    }
    fields.any(|f| _check_type(f.data_type()))
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)]
#[serde(rename_all = "camelCase")]
/// Defines a protocol action
Expand Down Expand Up @@ -146,8 +161,8 @@ impl Protocol {
}
}

/// set the reader features in the protocol action, automatically bumps min_reader_version
pub fn with_reader_features(
/// Append the reader features in the protocol action, automatically bumps min_reader_version
pub fn append_reader_features(
mut self,
reader_features: impl IntoIterator<Item = impl Into<ReaderFeatures>>,
) -> Self {
Expand All @@ -156,14 +171,20 @@ impl Protocol {
.map(Into::into)
.collect::<HashSet<_>>();
if !all_reader_features.is_empty() {
self.min_reader_version = 3
self.min_reader_version = 3;
match self.reader_features {
Some(mut features) => {
features.extend(all_reader_features);
self.reader_features = Some(features);
}
None => self.reader_features = Some(all_reader_features),
};
}
self.reader_features = Some(all_reader_features);
self
}

/// set the writer features in the protocol action, automatically bumps min_writer_version
pub fn with_writer_features(
pub fn append_writer_features(
mut self,
writer_features: impl IntoIterator<Item = impl Into<WriterFeatures>>,
) -> Self {
Expand All @@ -172,9 +193,16 @@ impl Protocol {
.map(|c| c.into())
.collect::<HashSet<_>>();
if !all_writer_feautures.is_empty() {
self.min_writer_version = 7
self.min_writer_version = 7;

match self.writer_features {
Some(mut features) => {
features.extend(all_writer_feautures);
self.writer_features = Some(features);
}
None => self.writer_features = Some(all_writer_feautures),
};
}
self.writer_features = Some(all_writer_feautures);
self
}

Expand Down Expand Up @@ -255,6 +283,32 @@ impl Protocol {
}
self
}

/// Will apply the column metadata to the protocol by either bumping the
/// minimum versions or setting the corresponding reader/writer features.
pub fn apply_column_metadata_to_protocol(
    mut self,
    schema: &StructType,
) -> DeltaResult<Protocol> {
    // Timestamp-without-timezone columns require dedicated features.
    if self.contains_timestampntz(schema.fields()) {
        self = self.enable_timestamp_ntz();
    }

    // Generated columns bump the writer version (or add a writer feature).
    if !schema.get_generated_columns()?.is_empty() {
        self = self.enable_generated_columns();
    }

    // Invariants are flagged as a writer feature on v7+ protocols.
    if !schema.get_invariants()?.is_empty() {
        self = self.enable_invariants();
    }

    Ok(self)
}

/// Will apply the properties to the protocol by either bumping the version or setting
/// features
pub fn apply_properties_to_protocol(
Expand Down Expand Up @@ -391,10 +445,35 @@ impl Protocol {
}
Ok(self)
}

/// Checks if the table contains timestamp_ntz in any field, including
/// nested fields. Thin method wrapper delegating to the free function of
/// the same name so callers can invoke it on a `Protocol` value.
fn contains_timestampntz<'a>(&self, fields: impl Iterator<Item = &'a StructField>) -> bool {
    contains_timestampntz(fields)
}

/// Enable timestamp_ntz in the protocol
pub fn enable_timestamp_ntz(mut self) -> Protocol {
self = self.with_reader_features(vec![ReaderFeatures::TimestampWithoutTimezone]);
self = self.with_writer_features(vec![WriterFeatures::TimestampWithoutTimezone]);
fn enable_timestamp_ntz(mut self) -> Self {
self = self.append_reader_features([ReaderFeatures::TimestampWithoutTimezone]);
self = self.append_writer_features([WriterFeatures::TimestampWithoutTimezone]);
self
}

/// Enable generated columns: requires at least writer version 4; on
/// feature-based protocols (writer v7+) the writer feature is flagged too.
fn enable_generated_columns(mut self) -> Self {
    self.min_writer_version = self.min_writer_version.max(4);
    if self.min_writer_version >= 7 {
        self = self.append_writer_features([WriterFeatures::GeneratedColumns]);
    }
    self
}

/// Enable invariants: on feature-based protocols (writer v7+) the
/// `Invariants` writer feature is flagged; lower versions are unchanged.
fn enable_invariants(mut self) -> Self {
    let feature_based = self.min_writer_version > 6;
    if feature_based {
        self = self.append_writer_features([WriterFeatures::Invariants]);
    }
    self
}
}
Expand Down
72 changes: 72 additions & 0 deletions crates/core/src/kernel/models/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use serde_json::Value;

use crate::kernel::error::Error;
use crate::kernel::DataCheck;
use crate::table::GeneratedColumn;

/// Type alias for a top level schema
pub type Schema = StructType;
Expand Down Expand Up @@ -49,9 +50,80 @@ impl DataCheck for Invariant {
pub trait StructTypeExt {
    /// Get all invariants declared in the schema's field metadata.
    fn get_invariants(&self) -> Result<Vec<Invariant>, Error>;

    /// Get all generated column expressions declared in the schema's
    /// field metadata.
    fn get_generated_columns(&self) -> Result<Vec<GeneratedColumn>, Error>;
}

impl StructTypeExt for StructType {
/// Get all generated column expressions declared in the schema.
///
/// Walks the schema iteratively (including struct, array, and map
/// members), collecting every field whose metadata carries a
/// `ColumnMetadataKey::GenerationExpression` entry. Nested paths are
/// dot-separated; array elements and map entries use the synthetic
/// segments `element`, `key`, and `value`.
///
/// # Errors
/// Returns `Error::InvalidGenerationExpressionJson` when a generation
/// expression's metadata value is not valid JSON.
fn get_generated_columns(&self) -> Result<Vec<GeneratedColumn>, Error> {
    // Work stack of (dotted path, field) pairs still to inspect.
    let mut remaining_fields: Vec<(String, StructField)> = self
        .fields()
        .map(|field| (field.name.clone(), field.clone()))
        .collect();
    let mut generated_cols: Vec<GeneratedColumn> = Vec::new();

    // Join a new segment onto a (possibly empty) dotted path prefix.
    let add_segment = |prefix: &str, segment: &str| -> String {
        if prefix.is_empty() {
            segment.to_owned()
        } else {
            format!("{prefix}.{segment}")
        }
    };

    while let Some((field_path, field)) = remaining_fields.pop() {
        match field.data_type() {
            DataType::Struct(inner) => {
                // Queue every nested struct member under the extended path.
                remaining_fields.extend(
                    inner
                        .fields()
                        .map(|field| {
                            let new_prefix = add_segment(&field_path, &field.name);
                            (new_prefix, field.clone())
                        })
                        .collect::<Vec<(String, StructField)>>(),
                );
            }
            DataType::Array(inner) => {
                // Synthetic `element` segment; the wrapper field carries no
                // name or metadata of its own.
                let element_field_name = add_segment(&field_path, "element");
                remaining_fields.push((
                    element_field_name,
                    StructField::new("".to_string(), inner.element_type.clone(), false),
                ));
            }
            DataType::Map(inner) => {
                // Synthetic `key` / `value` segments for map entry types.
                let key_field_name = add_segment(&field_path, "key");
                remaining_fields.push((
                    key_field_name,
                    StructField::new("".to_string(), inner.key_type.clone(), false),
                ));
                let value_field_name = add_segment(&field_path, "value");
                remaining_fields.push((
                    value_field_name,
                    StructField::new("".to_string(), inner.value_type.clone(), false),
                ));
            }
            _ => {}
        }
        // Every popped field, regardless of type, may itself carry a
        // generation expression in its metadata.
        if let Some(MetadataValue::String(generated_col_string)) = field
            .metadata
            .get(ColumnMetadataKey::GenerationExpression.as_ref())
        {
            let json: Value = serde_json::from_str(generated_col_string).map_err(|e| {
                Error::InvalidGenerationExpressionJson {
                    json_err: e,
                    line: generated_col_string.to_string(),
                }
            })?;
            // Only a plain JSON string is accepted as a SQL generation
            // expression; other JSON values are silently skipped here.
            if let Value::String(sql) = json {
                generated_cols.push(GeneratedColumn::new(&field_path, &sql));
            }
        }
    }
    Ok(generated_cols)
}

/// Get all invariants in the schemas
fn get_invariants(&self) -> Result<Vec<Invariant>, Error> {
let mut remaining_fields: Vec<(String, StructField)> = self
Expand Down
26 changes: 7 additions & 19 deletions crates/core/src/operations/add_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,24 +88,12 @@ impl std::future::IntoFuture for AddColumnBuilder {
let table_schema = this.snapshot.schema();
let new_table_schema = merge_delta_struct(table_schema, fields_right)?;

// TODO(ion): Think of a way how we can simply this checking through the API or centralize some checks.
let contains_timestampntz = PROTOCOL.contains_timestampntz(fields.iter());
let protocol = this.snapshot.protocol();

let maybe_new_protocol = if contains_timestampntz {
let updated_protocol = protocol.clone().enable_timestamp_ntz();
if !(protocol.min_reader_version == 3 && protocol.min_writer_version == 7) {
// Convert existing properties to features since we advanced the protocol to v3,7
Some(
updated_protocol
.move_table_properties_into_features(&metadata.configuration),
)
} else {
Some(updated_protocol)
}
} else {
None
};
let current_protocol = this.snapshot.protocol();

let new_protocol = current_protocol
.clone()
.apply_column_metadata_to_protocol(&new_table_schema)?
.move_table_properties_into_features(&metadata.configuration);

let operation = DeltaOperation::AddColumn {
fields: fields.into_iter().collect_vec(),
Expand All @@ -115,7 +103,7 @@ impl std::future::IntoFuture for AddColumnBuilder {

let mut actions = vec![metadata.into()];

if let Some(new_protocol) = maybe_new_protocol {
if current_protocol != &new_protocol {
actions.push(new_protocol.into())
}

Expand Down
4 changes: 2 additions & 2 deletions crates/core/src/operations/add_feature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ impl std::future::IntoFuture for AddTableFeatureBuilder {
}
}

protocol = protocol.with_reader_features(reader_features);
protocol = protocol.with_writer_features(writer_features);
protocol = protocol.append_reader_features(reader_features);
protocol = protocol.append_writer_features(writer_features);

let operation = DeltaOperation::AddFeature {
name: name.to_vec(),
Expand Down
2 changes: 1 addition & 1 deletion crates/core/src/operations/cdc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ mod tests {
#[tokio::test]
async fn test_should_write_cdc_v7_table_with_writer_feature() {
let protocol =
Protocol::new(1, 7).with_writer_features(vec![WriterFeatures::ChangeDataFeed]);
Protocol::new(1, 7).append_writer_features(vec![WriterFeatures::ChangeDataFeed]);
let actions = vec![Action::Protocol(protocol)];
let mut table: DeltaTable = DeltaOps::new_in_memory()
.create()
Expand Down
Loading

0 comments on commit be67624

Please sign in to comment.