diff --git a/target_postgres/connector.py b/target_postgres/connector.py index c9d250f6..0a4ed9a4 100644 --- a/target_postgres/connector.py +++ b/target_postgres/connector.py @@ -31,6 +31,7 @@ VARCHAR, TypeDecorator, ) +from singer_sdk.helpers.capabilities import TargetLoadMethods from sshtunnel import SSHTunnelForwarder @@ -41,6 +42,7 @@ class PostgresConnector(SQLConnector): allow_column_rename: bool = True # Whether RENAME COLUMN is supported. allow_column_alter: bool = False # Whether altering column types is supported. allow_merge_upsert: bool = True # Whether MERGE UPSERT is supported. + allow_overwrite: bool = True # Whether overwrite load method is supported. allow_temp_tables: bool = True # Whether temp tables are supported. def __init__(self, config: dict) -> None: @@ -92,6 +94,24 @@ def interpret_content_encoding(self) -> bool: """ return self.config.get("interpret_content_encoding", False) + def get_table_from_metadata( + self, + full_table_name: str, + connection: sa.engine.Connection + ) -> sa.Table: + """Returns an existing table object from the database + + Args: + full_table_name: the fully qualified table name. + + Returns: + The table object. + """ + _, schema_name, table_name = self.parse_full_table_name(full_table_name) + meta = sa.MetaData(schema=schema_name) + meta.reflect(connection, only=[table_name]) + return meta.tables[full_table_name] + def prepare_table( # type: ignore[override] self, full_table_name: str, @@ -100,7 +120,7 @@ def prepare_table( # type: ignore[override] connection: sa.engine.Connection, partition_keys: list[str] | None = None, as_temp_table: bool = False, - ) -> sa.Table: + ) -> None: """Adapt target table to provided schema if possible. Args: @@ -117,9 +137,23 @@ def prepare_table( # type: ignore[override] _, schema_name, table_name = self.parse_full_table_name(full_table_name) meta = sa.MetaData(schema=schema_name) table: sa.Table + if not self.table_exists(full_table_name=full_table_name): - table = self.create_empty_table( - table_name=table_name, + self.create_empty_table( + full_table_name=full_table_name, + meta=meta, + schema=schema, + primary_keys=primary_keys, + partition_keys=partition_keys, + as_temp_table=as_temp_table, + connection=connection, + ) + return + + if self.config["load_method"] == TargetLoadMethods.OVERWRITE: + self.get_table(full_table_name=full_table_name).drop(self._engine) + self.create_empty_table( + full_table_name=full_table_name, meta=meta, schema=schema, primary_keys=primary_keys, @@ -127,16 +161,15 @@ def prepare_table( # type: ignore[override] as_temp_table=as_temp_table, connection=connection, ) - return table + return + meta.reflect(connection, only=[table_name]) table = meta.tables[ full_table_name ] # So we don't mess up the casing of the Table reference columns = self.get_table_columns( - schema_name=cast(str, schema_name), - table_name=table_name, - connection=connection, + full_table_name=full_table_name, ) for property_name, property_def in schema["properties"].items(): @@ -151,8 +184,6 @@ def prepare_table( # type: ignore[override] column_object=column_object, ) - return meta.tables[full_table_name] - def copy_table_structure( self, full_table_name: str, @@ -331,7 +362,7 @@ def pick_best_sql_type(sql_type_array: list): def create_empty_table( # type: ignore[override] self, - table_name: str, + full_table_name: str, meta: sa.MetaData, schema: dict, connection: sa.engine.Connection, @@ -357,6 +388,9 @@ def create_empty_table( # type: ignore[override] NotImplementedError: if temp tables are unsupported and as_temp_table=True. RuntimeError: if a variant schema is passed with no properties defined. """ + + _, schema_name, table_name = self.parse_full_table_name(full_table_name) + columns: list[sa.Column] = [] primary_keys = primary_keys or [] try: @@ -410,66 +444,31 @@ def prepare_column( _, schema_name, table_name = self.parse_full_table_name(full_table_name) column_exists = column_object is not None or self.column_exists( - full_table_name, column_name, connection=connection + full_table_name, column_name, ) if not column_exists: self._create_empty_column( # We should migrate every function to use sa.Table # instead of having to know what the function wants - table_name=table_name, + full_table_name=full_table_name, column_name=column_name, sql_type=sql_type, - schema_name=cast(str, schema_name), - connection=connection, ) return self._adapt_column_type( - schema_name=cast(str, schema_name), - table_name=table_name, + full_table_name=full_table_name, column_name=column_name, sql_type=sql_type, connection=connection, column_object=column_object, ) - def _create_empty_column( # type: ignore[override] - self, - schema_name: str, - table_name: str, - column_name: str, - sql_type: sa.types.TypeEngine, - connection: sa.engine.Connection, - ) -> None: - """Create a new column. - - Args: - schema_name: The schema name. - table_name: The table name. - column_name: The name of the new column. - sql_type: SQLAlchemy type engine to be used in creating the new column. - connection: The database connection. - - Raises: - NotImplementedError: if adding columns is not supported. - """ - if not self.allow_column_add: - msg = "Adding columns is not supported." - raise NotImplementedError(msg) - - column_add_ddl = self.get_column_add_ddl( - schema_name=schema_name, - table_name=table_name, - column_name=column_name, - column_type=sql_type, - ) - connection.execute(column_add_ddl) def get_column_add_ddl( # type: ignore[override] self, table_name: str, - schema_name: str, column_name: str, column_type: sa.types.TypeEngine, ) -> sa.DDL: @@ -484,6 +483,8 @@ def get_column_add_ddl( # type: ignore[override] Returns: A sqlalchemy DDL instance. """ + _, schema_name, table_name = self.parse_full_table_name(table_name) + column = sa.Column(column_name, column_type) return sa.DDL( @@ -501,12 +502,11 @@ def get_column_add_ddl( # type: ignore[override] def _adapt_column_type( # type: ignore[override] self, - schema_name: str, - table_name: str, + full_table_name: str, column_name: str, sql_type: sa.types.TypeEngine, - connection: sa.engine.Connection, - column_object: sa.Column | None, + connection: sa.engine.Connection | None = None, + column_object: sa.Column | None = None, ) -> None: """Adapt table column type to support the new JSON schema type. @@ -521,15 +521,21 @@ def _adapt_column_type( # type: ignore[override] Raises: NotImplementedError: if altering columns is not supported. """ + if connection is None: + super()._adapt_column_type( + full_table_name=full_table_name, + column_name=column_name, + sql_type=sql_type, + ) + return + current_type: sa.types.TypeEngine if column_object is not None: current_type = t.cast(sa.types.TypeEngine, column_object.type) else: current_type = self._get_column_type( - schema_name=schema_name, - table_name=table_name, + full_table_name=full_table_name, column_name=column_name, - connection=connection, ) # remove collation if present and save it @@ -556,14 +562,13 @@ def _adapt_column_type( # type: ignore[override] if not self.allow_column_alter: msg = ( "Altering columns is not supported. Could not convert column " - f"'{schema_name}.{table_name}.{column_name}' from '{current_type}' to " + f"'{full_table_name}.{column_name}' from '{current_type}' to " f"'{compatible_sql_type}'." ) raise NotImplementedError(msg) alter_column_ddl = self.get_column_alter_ddl( - schema_name=schema_name, - table_name=table_name, + table_name=full_table_name, column_name=column_name, column_type=compatible_sql_type, ) @@ -571,7 +576,6 @@ def _adapt_column_type( # type: ignore[override] def get_column_alter_ddl( # type: ignore[override] self, - schema_name: str, table_name: str, column_name: str, column_type: sa.types.TypeEngine, @@ -589,6 +593,7 @@ def get_column_alter_ddl( # type: ignore[override] Returns: A sqlalchemy DDL instance. """ + _, schema_name, _ = self.parse_full_table_name(table_name) column = sa.Column(column_name, column_type) return sa.DDL( ( @@ -736,98 +741,6 @@ def catch_signal(self, signum, frame) -> None: """ exit(1) # Calling this to be sure atexit is called, so clean_up gets called - def _get_column_type( # type: ignore[override] - self, - schema_name: str, - table_name: str, - column_name: str, - connection: sa.engine.Connection, - ) -> sa.types.TypeEngine: - """Get the SQL type of the declared column. - - Args: - schema_name: The schema name. - table_name: The table name. - column_name: The name of the column. - connection: The database connection. - - Returns: - The type of the column. - - Raises: - KeyError: If the provided column name does not exist. - """ - try: - column = self.get_table_columns( - schema_name=schema_name, - table_name=table_name, - connection=connection, - )[column_name] - except KeyError as ex: - msg = ( - f"Column `{column_name}` does not exist in table" - "`{schema_name}.{table_name}`." - ) - raise KeyError(msg) from ex - - return t.cast(sa.types.TypeEngine, column.type) - - def get_table_columns( # type: ignore[override] - self, - schema_name: str, - table_name: str, - connection: sa.engine.Connection, - column_names: list[str] | None = None, - ) -> dict[str, sa.Column]: - """Return a list of table columns. - - Overrode to support schema_name - - Args: - schema_name: schema name. - table_name: table name to get columns for. - connection: database connection. - column_names: A list of column names to filter to. - - Returns: - An ordered list of column objects. - """ - inspector = sa.inspect(connection) - columns = inspector.get_columns(table_name, schema_name) - - return { - col_meta["name"]: sa.Column( - col_meta["name"], - col_meta["type"], - nullable=col_meta.get("nullable", False), - ) - for col_meta in columns - if not column_names - or col_meta["name"].casefold() in {col.casefold() for col in column_names} - } - - def column_exists( # type: ignore[override] - self, - full_table_name: str, - column_name: str, - connection: sa.engine.Connection, - ) -> bool: - """Determine if the target column already exists. - - Args: - full_table_name: the target table name. - column_name: the target column name. - connection: the database connection. - - Returns: - True if table exists, False if not. - """ - _, schema_name, table_name = self.parse_full_table_name(full_table_name) - assert schema_name is not None - assert table_name is not None - return column_name in self.get_table_columns( - schema_name=schema_name, table_name=table_name, connection=connection - ) class NOTYPE(TypeDecorator): diff --git a/target_postgres/sinks.py b/target_postgres/sinks.py index ea8b8df3..8f4e5da2 100644 --- a/target_postgres/sinks.py +++ b/target_postgres/sinks.py @@ -73,15 +73,12 @@ def process_batch(self, context: dict) -> None: """ # Use one connection so we do this all in a single transaction with self.connector._connect() as connection, connection.begin(): - # Check structure of table - table: sa.Table = self.connector.prepare_table( - full_table_name=self.full_table_name, - schema=self.schema, - primary_keys=self.key_properties, - as_temp_table=False, - connection=connection, - ) + + table = self.connector.get_table_from_metadata(self.full_table_name, connection) + # Create a temp table (Creates from the table above) + # TODO: maybe we should not even create the table copying the structure + # but just create the temp table using the schema temp_table: sa.Table = self.connector.copy_table_structure( full_table_name=self.temp_table_name, from_table=table, @@ -338,7 +335,6 @@ def activate_version(self, new_version: int) -> None: if not self.connector.column_exists( full_table_name=self.full_table_name, column_name=self.version_column_name, - connection=connection, ): raise RuntimeError( f"{self.version_column_name} is required for activate version " @@ -349,7 +345,6 @@ def activate_version(self, new_version: int) -> None: or self.connector.column_exists( full_table_name=self.full_table_name, column_name=self.soft_delete_column_name, - connection=connection, ) ): raise RuntimeError(