diff --git a/snowflake_utils/models.py b/snowflake_utils/models.py index 3bbdb10..23e2417 100644 --- a/snowflake_utils/models.py +++ b/snowflake_utils/models.py @@ -4,7 +4,7 @@ from enum import Enum from functools import partial -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from snowflake.connector.cursor import SnowflakeCursor from typing_extensions import Self @@ -65,6 +65,11 @@ def parsed_columns(self): def parse_from_json(self): raise NotImplementedError("Not implemented yet") + @field_validator("columns") + @classmethod + def force_columns_to_casefold(cls, value) -> dict: + return {k.casefold(): v for k, v in value.items()} + class Schema(BaseModel): name: str @@ -397,12 +402,15 @@ def sync_tags(self, cursor: SnowflakeCursor) -> None: desired_tags = { f"{column}.{tag_name}.{tag_value}".casefold(): (column, tag_name, tag_value) for column in self.table_structure.columns - for tag_name, tag_value in self.table_structure.columns[column].tags.items() + for tag_name, tag_value in self.table_structure.columns[ + column.casefold() + ].tags.items() } for tag in existing_tags: if tag not in desired_tags: - self._unset_tag(cursor, *existing_tags[tag]) + column, tag_name, _value = existing_tags[tag] + self._unset_tag(cursor, column, tag_name) for tag in desired_tags: if tag not in existing_tags: