diff --git a/snowflake_utils/__main__.py b/snowflake_utils/__main__.py
index 360a46f..dd3b660 100644
--- a/snowflake_utils/__main__.py
+++ b/snowflake_utils/__main__.py
@@ -1,10 +1,11 @@
+import logging
+import os
+
 import typer
 from typing_extensions import Annotated
 
-from .models import FileFormat, InlineFileFormat, Table, Schema, Column
-from .queries import connect
-import logging
-import os
+from .models import Column, FileFormat, InlineFileFormat, Schema, Table
+from .settings import SnowflakeSettings
 
 app = typer.Typer()
 
@@ -42,7 +43,7 @@ def mass_single_column_update(
     new_column = Column(name=new_column, data_type=data_type)
     log_level = os.getenv("LOG_LEVEL", "INFO")
     logging.getLogger("snowflake-utils").setLevel(log_level)
-    with connect() as conn, conn.cursor() as cursor:
+    with SnowflakeSettings().connect() as conn, conn.cursor() as cursor:
         tables = db_schema.get_tables(cursor=cursor)
         for table in tables:
             columns = table.get_columns(cursor=cursor)
diff --git a/snowflake_utils/models/__init__.py b/snowflake_utils/models/__init__.py
new file mode 100644
index 0000000..473fdf3
--- /dev/null
+++ b/snowflake_utils/models/__init__.py
@@ -0,0 +1,17 @@
+from .column import Column
+from .enums import MatchByColumnName, TagLevel
+from .file_format import FileFormat, InlineFileFormat
+from .schema import Schema
+from .table import Table
+from .table_structure import TableStructure
+
+__all__ = [
+    "Column",
+    "MatchByColumnName",
+    "TagLevel",
+    "Schema",
+    "Table",
+    "TableStructure",
+    "FileFormat",
+    "InlineFileFormat",
+]
diff --git a/snowflake_utils/models/column.py b/snowflake_utils/models/column.py
new file mode 100644
index 0000000..1e737f4
--- /dev/null
+++ b/snowflake_utils/models/column.py
@@ -0,0 +1,43 @@
+from datetime import date, datetime
+
+from pydantic import BaseModel, Field
+
+
+class Column(BaseModel):
+    name: str
+    data_type: str
+    tags: dict[str, str] = Field(default_factory=dict)
+
+
+def _possibly_cast(s: str, old_column_type:
str, new_column_type: str) -> str: + if old_column_type == "VARIANT" and new_column_type != "VARIANT": + return f"PARSE_JSON({s})" + return s + + +def _matched(columns: list[Column], old_columns: dict[str, str]): + def tmp(x: str) -> str: + return f'tmp."{x}"' + + return ",".join( + f'dest."{c.name}" = {_possibly_cast(tmp(c.name), old_columns.get(c.name), c.data_type)}' + for c in columns + ) + + +def _inserts(columns: list[Column], old_columns: dict[str, str]) -> str: + return ",".join( + _possibly_cast(f'tmp."{c.name}"', old_columns.get(c.name), c.data_type) + for c in columns + ) + + +def _type_cast(s: any) -> any: + if isinstance(s, (int, float)): + return str(s) + elif isinstance(s, str): + return f"'{s}'" + elif isinstance(s, (datetime, date)): + return f"'{s.isoformat()}'" + else: + return f"'{s}'" diff --git a/snowflake_utils/models/enums.py b/snowflake_utils/models/enums.py new file mode 100644 index 0000000..caaad11 --- /dev/null +++ b/snowflake_utils/models/enums.py @@ -0,0 +1,12 @@ +from enum import Enum + + +class MatchByColumnName(Enum): + CASE_SENSITIVE = "CASE_SENSITIVE" + CASE_INSENSITIVE = "CASE_INSENSITIVE" + NONE = "NONE" + + +class TagLevel(Enum): + COLUMN = "column" + TABLE = "table" diff --git a/snowflake_utils/models/file_format.py b/snowflake_utils/models/file_format.py new file mode 100644 index 0000000..962e9e1 --- /dev/null +++ b/snowflake_utils/models/file_format.py @@ -0,0 +1,31 @@ +from typing import Self + +from pydantic import BaseModel + + +class InlineFileFormat(BaseModel): + definition: str + + +class FileFormat(BaseModel): + database: str | None = None + schema_: str | None = None + name: str + + def __str__(self) -> str: + return ".".join( + s for s in [self.database, self.schema_, self.name] if s is not None + ) + + @classmethod + def from_string(cls, s: str) -> Self: + s = s.split(".") + match s: + case [database, schema, name]: + return cls(database=database, schema_=schema, name=name) + case [schema, name]: + return 
cls(schema_=schema, name=name) + case [name]: + return cls(name=name) + case _: + raise ValueError("Cannot parse file format") diff --git a/snowflake_utils/models/schema.py b/snowflake_utils/models/schema.py new file mode 100644 index 0000000..c0597af --- /dev/null +++ b/snowflake_utils/models/schema.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel +from snowflake.connector.cursor import SnowflakeCursor + +from .table import Table + + +class Schema(BaseModel): + name: str + database: str | None = None + + @property + def fully_qualified_name(self): + if self.database: + return f"{self.database}.{self.name}" + else: + return self.name + + def get_tables(self, cursor: SnowflakeCursor): + cursor.execute(f"show tables in schema {self.fully_qualified_name};") + data = cursor.execute( + 'select "name", "database_name", "schema_name" FROM TABLE(RESULT_SCAN(LAST_QUERY_ID()));' + ).fetchall() + return [ + Table(name=name, schema_=schema, database=database) + for (name, database, schema, *_) in data + ] diff --git a/snowflake_utils/models.py b/snowflake_utils/models/table.py similarity index 82% rename from snowflake_utils/models.py rename to snowflake_utils/models/table.py index 8d2aba5..0dae23f 100644 --- a/snowflake_utils/models.py +++ b/snowflake_utils/models/table.py @@ -1,103 +1,16 @@ import logging from collections import defaultdict -from datetime import date, datetime -from enum import Enum from functools import partial -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from snowflake.connector.cursor import SnowflakeCursor -from typing_extensions import Self -from .queries import connect, execute_statement -from .settings import governance_settings - - -class MatchByColumnName(Enum): - CASE_SENSITIVE = "CASE_SENSITIVE" - CASE_INSENSITIVE = "CASE_INSENSITIVE" - NONE = "NONE" - - -class InlineFileFormat(BaseModel): - definition: str - - -class FileFormat(BaseModel): - database: str | None = None - schema_: str | None = 
None - name: str - - def __str__(self) -> str: - return ".".join( - s for s in [self.database, self.schema_, self.name] if s is not None - ) - - @classmethod - def from_string(cls, s: str) -> Self: - s = s.split(".") - match s: - case [database, schema, name]: - return cls(database=database, schema_=schema, name=name) - case [schema, name]: - return cls(schema_=schema, name=name) - case [name]: - return cls(name=name) - case _: - raise ValueError("Cannot parse file format") - - -class Column(BaseModel): - name: str - data_type: str - tags: dict[str, str] = Field(default_factory=dict) - - -class TableStructure(BaseModel): - columns: dict = [str, Column] - tags: dict[str, str] = Field(default_factory=dict) - - @property - def parsed_columns(self, replace_chars: bool = False) -> str: - if replace_chars: - return ", ".join( - f'"{str.upper(k).strip().replace("-","_")}" {v.data_type}' - for k, v in self.columns.items() - ) - else: - return ", ".join( - f'"{str.upper(k).strip()}" {v.data_type}' - for k, v in self.columns.items() - ) - - def parse_from_json(self): - raise NotImplementedError("Not implemented yet") - - @field_validator("columns") - @classmethod - def force_columns_to_casefold(cls, value) -> dict: - return {k.casefold(): v for k, v in value.items()} - - -class Schema(BaseModel): - name: str - database: str | None = None - - @property - def fully_qualified_name(self): - if self.database: - return f"{self.database}.{self.name}" - else: - return self.name - - def get_tables(self, cursor: SnowflakeCursor): - cursor.execute(f"show tables in schema {self.fully_qualified_name};") - data = cursor.execute( - 'select "name", "database_name", "schema_name" FROM TABLE(RESULT_SCAN(LAST_QUERY_ID()));' - ).fetchall() - return [ - Table(name=name, schema_=schema, database=database) - for (name, database, schema, *_) in data - ] +from ..queries import execute_statement +from ..settings import connect, governance_settings +from .column import Column, _inserts, _matched, 
_type_cast +from .enums import MatchByColumnName, TagLevel +from .file_format import FileFormat, InlineFileFormat +from .table_structure import TableStructure class Table(BaseModel): @@ -431,13 +344,13 @@ def single_column_update( f"UPDATE {self.fqn} SET {target_column.name} = {new_column.name};" ) - def _current_tags(self, level: str) -> list[tuple[str, str, str]]: + def _current_tags(self, level: TagLevel) -> list[tuple[str, str, str]]: with connect() as connection: cursor = connection.cursor() cursor.execute( f"""select lower(column_name) as column_name, lower(tag_name) as tag_name, tag_value from table(information_schema.tag_references_all_columns('{self.fqn}', 'table')) - where lower(level) = '{level}' + where lower(level) = '{level.value}' """ ) return cursor.fetchall() @@ -445,14 +358,14 @@ def _current_tags(self, level: str) -> list[tuple[str, str, str]]: def current_column_tags(self) -> dict[str, dict[str, str]]: tags = defaultdict(dict) - for column_name, tag_name, tag_value in self._current_tags("column"): + for column_name, tag_name, tag_value in self._current_tags(TagLevel.COLUMN): tags[column_name][tag_name] = tag_value return tags def current_table_tags(self) -> dict[str, str]: return { tag_name.casefold(): tag_value - for _, tag_name, tag_value in self._current_tags("table") + for _, tag_name, tag_value in self._current_tags(TagLevel.TABLE) } def sync_tags_table(self, cursor: SnowflakeCursor) -> None: @@ -586,37 +499,3 @@ def setup_connection( ) return _execute_statement - - -def _possibly_cast(s: str, old_column_type: str, new_column_type: str) -> str: - if old_column_type == "VARIANT" and new_column_type != "VARIANT": - return f"PARSE_JSON({s})" - return s - - -def _matched(columns: list[Column], old_columns: dict[str, str]): - def tmp(x: str) -> str: - return f'tmp."{x}"' - - return ",".join( - f'dest."{c.name}" = {_possibly_cast(tmp(c.name), old_columns.get(c.name), c.data_type)}' - for c in columns - ) - - -def _inserts(columns: list[Column], 
old_columns: dict[str, str]) -> str: - return ",".join( - _possibly_cast(f'tmp."{c.name}"', old_columns.get(c.name), c.data_type) - for c in columns - ) - - -def _type_cast(s: any) -> any: - if isinstance(s, (int, float)): - return str(s) - elif isinstance(s, str): - return f"'{s}'" - elif isinstance(s, (datetime, date)): - return f"'{s.isoformat()}'" - else: - return f"'{s}'" diff --git a/snowflake_utils/models/table_structure.py b/snowflake_utils/models/table_structure.py new file mode 100644 index 0000000..d1e383e --- /dev/null +++ b/snowflake_utils/models/table_structure.py @@ -0,0 +1,29 @@ +from pydantic import BaseModel, Field, field_validator + +from .column import Column + + +class TableStructure(BaseModel): + columns: dict = [str, Column] + tags: dict[str, str] = Field(default_factory=dict) + + @property + def parsed_columns(self, replace_chars: bool = False) -> str: + if replace_chars: + return ", ".join( + f'"{str.upper(k).strip().replace("-","_")}" {v.data_type}' + for k, v in self.columns.items() + ) + else: + return ", ".join( + f'"{str.upper(k).strip()}" {v.data_type}' + for k, v in self.columns.items() + ) + + def parse_from_json(self): + raise NotImplementedError("Not implemented yet") + + @field_validator("columns") + @classmethod + def force_columns_to_casefold(cls, value) -> dict: + return {k.casefold(): v for k, v in value.items()} diff --git a/snowflake_utils/queries.py b/snowflake_utils/queries.py index e336639..5c5bd9a 100644 --- a/snowflake_utils/queries.py +++ b/snowflake_utils/queries.py @@ -1,15 +1,8 @@ +import logging from typing import no_type_check from snowflake import connector -from .settings import SnowflakeSettings -import logging - - -def connect() -> connector.SnowflakeConnection: - settings = SnowflakeSettings() - return connector.connect(**settings.creds()) - @no_type_check def execute_statement( diff --git a/snowflake_utils/settings.py b/snowflake_utils/settings.py index 7d958cb..27c2610 100644 --- 
a/snowflake_utils/settings.py +++ b/snowflake_utils/settings.py @@ -1,4 +1,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict +from snowflake.connector import SnowflakeConnection +from snowflake.connector import connect as _connect class SnowflakeSettings(BaseSettings): @@ -23,6 +25,13 @@ def creds(self) -> dict[str, str | None]: "warehouse": self.warehouse, } + def connect(self) -> SnowflakeConnection: + return _connect(**self.creds()) + + +def connect() -> SnowflakeConnection: + return SnowflakeSettings().connect() + class GovernanceSettings(BaseSettings): governance_database: str = "governance" diff --git a/tests/test_models.py b/tests/test_models.py index 2411c06..c3ad6aa 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -12,7 +12,7 @@ Table, TableStructure, ) -from snowflake_utils.queries import connect +from snowflake_utils.settings import connect test_table_schema = TableStructure( columns={