Skip to content

Commit

Permalink
fix: add nullability check in deltachecker
Browse files Browse the repository at this point in the history
  • Loading branch information
ion-elgreco authored and rtyler committed Nov 29, 2024
1 parent 1083c8c commit a1c37b7
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 2 deletions.
79 changes: 77 additions & 2 deletions crates/core/src/delta_datafusion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1150,11 +1150,12 @@ pub(crate) async fn execute_plan_to_batch(
Ok(concat_batches(&plan.schema(), data.iter())?)
}

/// Responsible for checking batches of data conform to table's invariants.
#[derive(Clone)]
/// Responsible for checking batches of data conform to table's invariants, constraints and nullability.
#[derive(Clone, Default)]
pub struct DeltaDataChecker {
constraints: Vec<Constraint>,
invariants: Vec<Invariant>,
non_nullable_columns: Vec<String>,
ctx: SessionContext,
}

Expand All @@ -1164,6 +1165,7 @@ impl DeltaDataChecker {
Self {
invariants: vec![],
constraints: vec![],
non_nullable_columns: vec![],
ctx: DeltaSessionContext::default().into(),
}
}
Expand All @@ -1173,6 +1175,7 @@ impl DeltaDataChecker {
Self {
invariants,
constraints: vec![],
non_nullable_columns: vec![],
ctx: DeltaSessionContext::default().into(),
}
}
Expand All @@ -1182,6 +1185,7 @@ impl DeltaDataChecker {
Self {
constraints,
invariants: vec![],
non_nullable_columns: vec![],
ctx: DeltaSessionContext::default().into(),
}
}
Expand All @@ -1202,9 +1206,21 @@ impl DeltaDataChecker {
pub fn new(snapshot: &DeltaTableState) -> Self {
let invariants = snapshot.schema().get_invariants().unwrap_or_default();
let constraints = snapshot.table_config().get_constraints();
let non_nullable_columns = snapshot
.schema()
.fields()
.filter_map(|f| {
if !f.is_nullable() {
Some(f.name().clone())
} else {
None
}
})
.collect_vec();
Self {
invariants,
constraints,
non_nullable_columns,
ctx: DeltaSessionContext::default().into(),
}
}
Expand All @@ -1214,10 +1230,35 @@ impl DeltaDataChecker {
/// If it does not, it will return [DeltaTableError::InvalidData] with a list
/// of values that violated each invariant.
pub async fn check_batch(&self, record_batch: &RecordBatch) -> Result<(), DeltaTableError> {
self.check_nullability(record_batch)?;
self.enforce_checks(record_batch, &self.invariants).await?;
self.enforce_checks(record_batch, &self.constraints).await
}

/// Return true if all the nullability checks are valid
fn check_nullability(&self, record_batch: &RecordBatch) -> Result<bool, DeltaTableError> {
let mut violations = Vec::new();
for col in self.non_nullable_columns.iter() {
if let Some(arr) = record_batch.column_by_name(col) {
if arr.null_count() > 0 {
violations.push(format!(
"Non-nullable column violation for {col}, found {} null values",
arr.null_count()
));
}
} else {
violations.push(format!(
"Non-nullable column violation for {col}, not found in batch!"
));
}
}
if !violations.is_empty() {
Err(DeltaTableError::InvalidData { violations })
} else {
Ok(true)
}
}

async fn enforce_checks<C: DataCheck>(
&self,
record_batch: &RecordBatch,
Expand Down Expand Up @@ -2598,4 +2639,38 @@ mod tests {

assert_eq!(actual.len(), 0);
}

#[tokio::test]
async fn test_check_nullability() -> DeltaResult<()> {
use arrow::array::StringArray;

let data_checker = DeltaDataChecker {
non_nullable_columns: vec!["zed".to_string(), "yap".to_string()],
..Default::default()
};

let arr: Arc<dyn Array> = Arc::new(StringArray::from(vec!["s"]));
let nulls: Arc<dyn Array> = Arc::new(StringArray::new_null(1));
let batch = RecordBatch::try_from_iter(vec![("a", arr), ("zed", nulls)]).unwrap();

let result = data_checker.check_nullability(&batch);
assert!(
result.is_err(),
"The result should have errored! {result:?}"
);

let arr: Arc<dyn Array> = Arc::new(StringArray::from(vec!["s"]));
let batch = RecordBatch::try_from_iter(vec![("zed", arr)]).unwrap();
let result = data_checker.check_nullability(&batch);
assert!(
result.is_err(),
"The result should have errored! {result:?}"
);

let arr: Arc<dyn Array> = Arc::new(StringArray::from(vec!["s"]));
let batch = RecordBatch::try_from_iter(vec![("zed", arr.clone()), ("yap", arr)]).unwrap();
let _ = data_checker.check_nullability(&batch)?;

Ok(())
}
}
40 changes: 40 additions & 0 deletions python/tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from deltalake import DeltaTable, write_deltalake
from deltalake.exceptions import DeltaProtocolError
from deltalake.table import CommitProperties


Expand Down Expand Up @@ -1080,3 +1081,42 @@ def test_cdc_merge_planning_union_2908(tmp_path):
assert last_action["operation"] == "MERGE"
assert dt.version() == 1
assert os.path.exists(cdc_path), "_change_data doesn't exist"


@pytest.mark.pandas
def test_merge_non_nullable(tmp_path):
import re

import pandas as pd

from deltalake.schema import Field, PrimitiveType, Schema

schema = Schema(
[
Field("id", PrimitiveType("integer"), nullable=False),
Field("bool", PrimitiveType("boolean"), nullable=False),
]
)

dt = DeltaTable.create(tmp_path, schema=schema)
df = pd.DataFrame(
columns=["id", "bool"],
data=[
[1, True],
[2, None],
[3, False],
],
)

with pytest.raises(
DeltaProtocolError,
match=re.escape(
'Invariant violations: ["Non-nullable column violation for bool, found 1 null values"]'
),
):
dt.merge(
source=df,
source_alias="s",
target_alias="t",
predicate="s.id = t.id",
).when_matched_update_all().when_not_matched_insert_all().execute()

0 comments on commit a1c37b7

Please sign in to comment.