diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index c3bfefb47..0b0615a3d 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -1065,6 +1065,7 @@ def metadata_hash(self) -> str: str(self.allow_partials), gen(self.session_properties_) if self.session_properties_ else None, str(self.validate_query) if self.validate_query is not None else None, + *[gen(g) for g in self.grains], ] for audit_name, audit_args in sorted(self.audits, key=lambda a: a[0]): diff --git a/sqlmesh/dbt/basemodel.py b/sqlmesh/dbt/basemodel.py index 4222409e3..74b90b844 100644 --- a/sqlmesh/dbt/basemodel.py +++ b/sqlmesh/dbt/basemodel.py @@ -110,6 +110,7 @@ class BaseModelConfig(GeneralConfig): dependencies: Dependencies = Dependencies() tests: t.List[TestConfig] = [] dialect_: t.Optional[str] = Field(None, alias="dialect") + grain: t.Union[str, t.List[str]] = [] # DBT configuration fields name: str = "" @@ -294,6 +295,7 @@ def sqlmesh_model_kwargs( ) -> t.Dict[str, t.Any]: """Get common sqlmesh model parameters""" self.check_for_circular_test_refs(context) + model_dialect = self.dialect(context) model_context = context.context_for_dependencies( self.dependencies.union(self.tests_ref_source_dependencies) ) @@ -337,6 +339,7 @@ def sqlmesh_model_kwargs( "tags": self.tags, "physical_schema_mapping": context.sqlmesh_config.physical_schema_mapping, "default_catalog": context.target.database, + "grain": [d.parse_one(g, dialect=model_dialect) for g in ensure_list(self.grain)], **self.sqlmesh_config_kwargs, } diff --git a/sqlmesh/migrations/v0070_include_grains_in_metadata_hash.py b/sqlmesh/migrations/v0070_include_grains_in_metadata_hash.py new file mode 100644 index 000000000..dc75ac333 --- /dev/null +++ b/sqlmesh/migrations/v0070_include_grains_in_metadata_hash.py @@ -0,0 +1,5 @@ +"""Include grains in the metadata hash.""" + + +def migrate(state_sync, **kwargs): # type: ignore + pass diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index 00c3ec573..8c0b0631c 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -1482,3 +1482,29 @@ def test_dbt_incremental_allow_partials_by_default(): model.allow_partials = False assert not model.to_sqlmesh(context).allow_partials + + +def test_grain(): + context = DbtContext() + context._target = SnowflakeConfig( + name="target", + schema="test", + database="test", + account="account", + user="user", + password="password", + ) + + model = ModelConfig( + name="model", + alias="model", + package_name="package", + target_schema="test", + sql="SELECT * FROM baz", + materialized=Materialization.TABLE.value, + grain=["id_a", "id_b"], + ) + assert model.to_sqlmesh(context).grains == [exp.to_column("id_a"), exp.to_column("id_b")] + + model.grain = "id_a" + assert model.to_sqlmesh(context).grains == [exp.to_column("id_a")]