Skip to content

Commit

Permalink
Fix property audits_with_args in model definition.py (#3904)
Browse files Browse the repository at this point in the history
Co-authored-by: Iaroslav Zeigerman <[email protected]>
  • Loading branch information
blecourt-private and izeigerman authored Feb 27, 2025
1 parent b9e8e1e commit 57ddadf
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 5 deletions.
12 changes: 7 additions & 5 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,16 +1203,18 @@ def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expression]]
from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS

audits_by_name = {**BUILT_IN_AUDITS, **self.audit_definitions}
audits_with_args = {}
audits_with_args = []
added_audits = set()

for audit_name, audit_args in self.audits:
audits_with_args[audit_name] = (audits_by_name[audit_name], audit_args.copy())
audits_with_args.append((audits_by_name[audit_name], audit_args.copy()))
added_audits.add(audit_name)

for audit_name in self.audit_definitions:
if audit_name not in audits_with_args:
audits_with_args[audit_name] = (audits_by_name[audit_name], {})
if audit_name not in added_audits:
audits_with_args.append((audits_by_name[audit_name], {}))

return list(audits_with_args.values())
return audits_with_args

@property
def _is_time_column_in_partitioned_by(self) -> bool:
Expand Down
38 changes: 38 additions & 0 deletions tests/core/test_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,3 +921,41 @@ def test_rendered_diff():
WHERE
- TRUE > 2
+ FALSE > 2""" in audit1.text_diff(audit2, rendered=True)


def test_multiple_audits_with_same_name():
expressions = parse(
"""
MODEL (
name db.table,
dialect spark,
audits(
does_not_exceed_threshold(column := id, threshold := 1000),
does_not_exceed_threshold(column := price, threshold := 100),
does_not_exceed_threshold(column := price, threshold := 100)
)
);
SELECT id, price FROM tbl;
AUDIT (
name does_not_exceed_threshold,
);
SELECT * FROM @this_model
WHERE @column >= @threshold;
"""
)
model = load_sql_based_model(expressions)
assert len(model.audits) == 3
assert len(model.audits_with_args) == 3
assert len(model.audit_definitions) == 1

# Testing that audit names are identical
assert model.audits[0][0] == model.audits[1][0] == model.audits[2][0]

# Testing that audit arguments are different for first and second audit
assert model.audits[0][1] != model.audits[1][1]

# Testing that audit arguments are identical for second and third audit
# This establishes that identical audits are preserved
assert model.audits[1][1] == model.audits[2][1]

0 comments on commit 57ddadf

Please sign in to comment.