Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: check top level nullability during write #3026

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 77 additions & 2 deletions crates/core/src/delta_datafusion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1150,11 +1150,12 @@ pub(crate) async fn execute_plan_to_batch(
Ok(concat_batches(&plan.schema(), data.iter())?)
}

/// Responsible for checking batches of data conform to table's invariants.
#[derive(Clone)]
/// Responsible for checking batches of data conform to table's invariants, constraints and nullability.
#[derive(Clone, Default)]
pub struct DeltaDataChecker {
    // CHECK constraints enforced against every incoming batch via `check_batch`.
    constraints: Vec<Constraint>,
    // Column invariants (legacy delta invariants) enforced via `check_batch`.
    invariants: Vec<Invariant>,
    // Names of top-level columns the table schema declares non-nullable;
    // `check_nullability` rejects batches with nulls in (or missing) these columns.
    non_nullable_columns: Vec<String>,
    // DataFusion session context — presumably used to evaluate the
    // invariant/constraint expressions in `enforce_checks` (body not in view).
    ctx: SessionContext,
}

Expand All @@ -1164,6 +1165,7 @@ impl DeltaDataChecker {
Self {
invariants: vec![],
constraints: vec![],
non_nullable_columns: vec![],
ctx: DeltaSessionContext::default().into(),
}
}
Expand All @@ -1173,6 +1175,7 @@ impl DeltaDataChecker {
Self {
invariants,
constraints: vec![],
non_nullable_columns: vec![],
ctx: DeltaSessionContext::default().into(),
}
}
Expand All @@ -1182,6 +1185,7 @@ impl DeltaDataChecker {
Self {
constraints,
invariants: vec![],
non_nullable_columns: vec![],
ctx: DeltaSessionContext::default().into(),
}
}
Expand All @@ -1202,9 +1206,21 @@ impl DeltaDataChecker {
/// Create a new [DeltaDataChecker] seeded from a table snapshot.
///
/// Collects the snapshot's invariants, constraints, and the names of all
/// top-level columns whose schema marks them as non-nullable.
pub fn new(snapshot: &DeltaTableState) -> Self {
    let invariants = snapshot.schema().get_invariants().unwrap_or_default();
    let constraints = snapshot.table_config().get_constraints();
    // Keep only the non-nullable fields; record their names for batch checks.
    let non_nullable_columns: Vec<String> = snapshot
        .schema()
        .fields()
        .filter(|f| !f.is_nullable())
        .map(|f| f.name().clone())
        .collect();
    Self {
        invariants,
        constraints,
        non_nullable_columns,
        ctx: DeltaSessionContext::default().into(),
    }
}
Expand All @@ -1214,10 +1230,35 @@ impl DeltaDataChecker {
/// If it does not, it will return [DeltaTableError::InvalidData] with a list
/// of values that violated each invariant.
pub async fn check_batch(&self, record_batch: &RecordBatch) -> Result<(), DeltaTableError> {
    // Nullability is validated first and synchronously: a batch with nulls in
    // (or missing) a non-nullable column fails before any expression checks run.
    self.check_nullability(record_batch)?;
    // Then enforce the expression-based checks: invariants, then constraints.
    self.enforce_checks(record_batch, &self.invariants).await?;
    self.enforce_checks(record_batch, &self.constraints).await
}

/// Return true if all the nullability checks are valid.
///
/// A violation is recorded for every non-nullable column that is either
/// absent from the batch or contains one or more null values. All
/// violations are reported together in a single [DeltaTableError::InvalidData].
fn check_nullability(&self, record_batch: &RecordBatch) -> Result<bool, DeltaTableError> {
    let violations: Vec<String> = self
        .non_nullable_columns
        .iter()
        .filter_map(|col| match record_batch.column_by_name(col) {
            // Column missing from the batch entirely.
            None => Some(format!(
                "Non-nullable column violation for {col}, not found in batch!"
            )),
            // Column present but carries nulls.
            Some(arr) if arr.null_count() > 0 => Some(format!(
                "Non-nullable column violation for {col}, found {} null values",
                arr.null_count()
            )),
            // Column present and fully populated — no violation.
            Some(_) => None,
        })
        .collect();
    if violations.is_empty() {
        Ok(true)
    } else {
        Err(DeltaTableError::InvalidData { violations })
    }
}

async fn enforce_checks<C: DataCheck>(
&self,
record_batch: &RecordBatch,
Expand Down Expand Up @@ -2598,4 +2639,38 @@ mod tests {

assert_eq!(actual.len(), 0);
}

#[tokio::test]
async fn test_check_nullability() -> DeltaResult<()> {
    use arrow::array::StringArray;

    // Checker configured with two required (non-nullable) columns.
    let checker = DeltaDataChecker {
        non_nullable_columns: vec!["zed".to_string(), "yap".to_string()],
        ..Default::default()
    };

    // Case 1: "zed" is present but all-null -> must be rejected.
    let populated: Arc<dyn Array> = Arc::new(StringArray::from(vec!["s"]));
    let all_null: Arc<dyn Array> = Arc::new(StringArray::new_null(1));
    let batch = RecordBatch::try_from_iter(vec![("a", populated), ("zed", all_null)]).unwrap();
    let result = checker.check_nullability(&batch);
    assert!(
        result.is_err(),
        "The result should have errored! {result:?}"
    );

    // Case 2: "yap" is missing from the batch entirely -> must be rejected.
    let populated: Arc<dyn Array> = Arc::new(StringArray::from(vec!["s"]));
    let batch = RecordBatch::try_from_iter(vec![("zed", populated)]).unwrap();
    let result = checker.check_nullability(&batch);
    assert!(
        result.is_err(),
        "The result should have errored! {result:?}"
    );

    // Case 3: both required columns present with no nulls -> passes.
    let populated: Arc<dyn Array> = Arc::new(StringArray::from(vec!["s"]));
    let batch =
        RecordBatch::try_from_iter(vec![("zed", populated.clone()), ("yap", populated)]).unwrap();
    let _ = checker.check_nullability(&batch)?;

    Ok(())
}
}
40 changes: 40 additions & 0 deletions python/tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from deltalake import DeltaTable, write_deltalake
from deltalake.exceptions import DeltaProtocolError
from deltalake.table import CommitProperties


Expand Down Expand Up @@ -1080,3 +1081,42 @@ def test_cdc_merge_planning_union_2908(tmp_path):
assert last_action["operation"] == "MERGE"
assert dt.version() == 1
assert os.path.exists(cdc_path), "_change_data doesn't exist"


@pytest.mark.pandas
def test_merge_non_nullable(tmp_path):
    """Merging a source with NULLs in a non-nullable column must raise."""
    import re

    import pandas as pd

    from deltalake.schema import Field, PrimitiveType, Schema

    # Table schema where both columns are declared non-nullable.
    schema = Schema(
        [
            Field("id", PrimitiveType("integer"), nullable=False),
            Field("bool", PrimitiveType("boolean"), nullable=False),
        ]
    )

    dt = DeltaTable.create(tmp_path, schema=schema)
    # Row with id=2 violates the non-nullable "bool" column.
    source = pd.DataFrame(
        {
            "id": [1, 2, 3],
            "bool": [True, None, False],
        }
    )

    with pytest.raises(
        DeltaProtocolError,
        match=re.escape(
            'Invariant violations: ["Non-nullable column violation for bool, found 1 null values"]'
        ),
    ):
        (
            dt.merge(
                source=source,
                source_alias="s",
                target_alias="t",
                predicate="s.id = t.id",
            )
            .when_matched_update_all()
            .when_not_matched_insert_all()
            .execute()
        )
Loading