diff --git a/README.md b/README.md
index e84bcd1..66c8e1e 100644
--- a/README.md
+++ b/README.md
@@ -80,6 +80,7 @@ here.
 | `max_buffer_size` | `["integer", "null"]` | `104857600` (100MB in bytes) | The maximum number of bytes to buffer in memory before writing to the destination table in Redshift |
 | `batch_detection_threshold` | `["integer", "null"]` | `5000`, or 1/40th `max_batch_rows` | How often, in rows received, to count the buffered rows and bytes to check if a flush is necessary. There's a slight performance penalty to checking the buffered records count or bytesize, so this controls how often this is polled in order to mitigate the penalty. This value is usually not necessary to set as the default is dynamically adjusted to check reasonably often. |
 | `persist_empty_tables` | `["boolean", "null"]` | `False` | Whether the Target should create tables which have no records present in Remote. |
+| `redshift_copy_options` | `["list"]` | `[]` | Allows adding additional options to the [COPY](https://docs.aws.amazon.com/redshift/latest/dg/r_COPY.html) statement sent to Redshift. A list of available parameters can be found [here](https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-data-conversion.html). For example, this could be set to `["TRUNCATECOLUMNS"]` to enable the [`TRUNCATECOLUMNS` data conversion parameter](https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-data-conversion.html#copy-truncatecolumns). |
 | `default_column_length` | `["integer", "null"]` | `1000` | All columns with the VARCHAR (CHARACTER VARYING) type will have this length. Range: 1-65535. |
 | `state_support` | `["boolean", "null"]` | `True` | Whether the Target should emit `STATE` messages to stdout for further consumption. In this mode, which is on by default, STATE messages are buffered in memory until all the records that occurred before them are flushed according to the batch flushing schedule the target is configured with. |
 | `target_s3` | `["object"]` | `N/A` | See `S3` below |
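For context, a minimal sketch of how the `redshift_copy_options` row documented above might appear in a target config, shown in the Python dict form the tests pass to `main(CONFIG, ...)`. The other keys and values here are illustrative placeholders, not part of this change:

```python
# Hypothetical config fragment; connection and S3 settings are omitted,
# and the values shown are placeholders for illustration only.
CONFIG = {
    'redshift_schema': 'public',
    'default_column_length': 1000,
    # Appended verbatim to the COPY statement; see the redshift.py change below.
    'redshift_copy_options': ['TRUNCATECOLUMNS'],
}
```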
diff --git a/target_redshift/__init__.py b/target_redshift/__init__.py
index a962d66..4142d50 100644
--- a/target_redshift/__init__.py
+++ b/target_redshift/__init__.py
@@ -40,7 +40,8 @@ def main(config, input_stream=None):
             redshift_schema=config.get('redshift_schema', 'public'),
             logging_level=config.get('logging_level'),
             default_column_length=config.get('default_column_length', 1000),
-            persist_empty_tables=config.get('persist_empty_tables')
+            persist_empty_tables=config.get('persist_empty_tables'),
+            redshift_copy_options=config.get('redshift_copy_options')
         )
 
         if input_stream:
diff --git a/target_redshift/redshift.py b/target_redshift/redshift.py
index 2ddbb35..830a0db 100644
--- a/target_redshift/redshift.py
+++ b/target_redshift/redshift.py
@@ -50,6 +50,7 @@ def __init__(self, connection, s3, *args,
                  logging_level=None,
                  default_column_length=DEFAULT_COLUMN_LENGTH,
                  persist_empty_tables=False,
+                 redshift_copy_options=[],
                  **kwargs):
 
         self.LOGGER.info(
@@ -58,6 +59,12 @@
         self.s3 = s3
         self.default_column_length = default_column_length
+
+        if isinstance(redshift_copy_options, list):
+            self.redshift_copy_options = redshift_copy_options
+        else:
+            self.redshift_copy_options = []
+
         PostgresTarget.__init__(self, connection, postgres_schema=redshift_schema, logging_level=logging_level,
                                 persist_empty_tables=persist_empty_tables, add_upsert_indexes=False)
@@ -155,7 +162,7 @@ def persist_csv_rows(self,
         aws_secret_access_key = credentials.get('aws_secret_access_key')
         aws_session_token = credentials.get('aws_session_token')
 
-        copy_sql = sql.SQL('COPY {}.{} ({}) FROM {} CREDENTIALS {} FORMAT AS CSV NULL AS {}').format(
+        copy_sql = sql.SQL('COPY {}.{} ({}) FROM {} CREDENTIALS {} FORMAT AS CSV NULL AS {} {}').format(
             sql.Identifier(self.postgres_schema),
             sql.Identifier(temp_table_name),
             sql.SQL(', ').join(map(sql.Identifier, columns)),
@@ -165,7 +172,8 @@
                 aws_secret_access_key,
                 ";token={}".format(aws_session_token) if aws_session_token else '',
             )),
-            sql.Literal(RESERVED_NULL_DEFAULT))
+            sql.Literal(RESERVED_NULL_DEFAULT),
+            sql.SQL(' '.join(self.redshift_copy_options)))
 
         cur.execute(copy_sql)
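To make the effect of the change above concrete, here is a minimal plain-Python sketch of how the option list ends up in the generated COPY statement. It deliberately skips the `psycopg2.sql` composition used in `persist_csv_rows`; the schema, table, columns, S3 key, credentials, and NULL sentinel below are placeholders:

```python
# Simplified stand-in for the COPY statement built in persist_csv_rows().
# All identifiers and credential strings here are illustrative placeholders.
redshift_copy_options = ['TRUNCATECOLUMNS']

copy_sql = 'COPY {}.{} ({}) FROM {} CREDENTIALS {} FORMAT AS CSV NULL AS {} {}'.format(
    '"public"',                       # schema
    '"tmp_cats"',                     # temp table
    '"id", "description"',            # columns
    "'s3://example-bucket/key.csv'",  # S3 location
    "'aws_access_key_id=...;aws_secret_access_key=...'",
    "'NULL_SENTINEL'",                # NULL default literal
    ' '.join(redshift_copy_options),  # -> 'TRUNCATECOLUMNS'
)

print(copy_sql)
```

With the default empty list, `' '.join([])` is an empty string, so the statement is effectively unchanged from the previous behavior.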
diff --git a/tests/fixtures.py b/tests/fixtures.py
index 73e3b0a..aafcca8 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -65,6 +65,9 @@
         'age': {
             'type': ['null', 'integer']
         },
+        'description': {
+            'type': ['null', 'string']
+        },
         'adoption': {
             'type': ['object', 'null'],
             'properties': {
@@ -217,6 +220,16 @@ def generate_record(self):
         }
 
 
+class LongCatStream(CatStream):
+    def generate_record(self):
+        record = CatStream.generate_record(self)
+
+        # add some seriously long text
+        record['description'] = fake.paragraph(nb_sentences=1000)
+
+        return record
+
+
 class InvalidCatStream(CatStream):
     def generate_record(self):
         record = CatStream.generate_record(self)
diff --git a/tests/test_target_redshift.py b/tests/test_target_redshift.py
index b2f1b67..01be183 100644
--- a/tests/test_target_redshift.py
+++ b/tests/test_target_redshift.py
@@ -6,7 +6,7 @@
 import psycopg2.extras
 import pytest
 
-from fixtures import CatStream, CONFIG, db_prep, MultiTypeStream, NestedStream, TEST_DB
+from fixtures import CatStream, CONFIG, db_prep, MultiTypeStream, NestedStream, TEST_DB, LongCatStream
 from target_postgres import singer_stream
 from target_postgres.target_tools import TargetError
@@ -711,3 +711,38 @@ def test_deduplication_existing_new_rows(db_prep):
 
     assert len(sequences) == 1
     assert sequences[0][0] == original_sequence
+
+
+def test_truncate_columns(db_prep):
+    stream = LongCatStream(100, version=1, nested_count=2)
+
+    # this is what we're testing for
+    CONFIG['redshift_copy_options'] = ['TRUNCATECOLUMNS']
+    CONFIG['default_column_length'] = 1000
+
+    main(CONFIG, input_stream=stream)
+
+    with psycopg2.connect(**TEST_DB) as conn:
+        with conn.cursor() as cur:
+            cur.execute(get_count_sql('cats'))
+            table_count = cur.fetchone()[0]
+
+            cur.execute(sql.SQL('SELECT {}, {} FROM {}.{}').format(
+                sql.SQL('MAX(LEN(description))'),
+                sql.SQL('MIN(LEN(description))'),
+                sql.Identifier(CONFIG['redshift_schema']),
+                sql.Identifier('cats')
+            ))
+
+            result = cur.fetchone()
+            max_length = result[0]
+            min_length = result[1]
+
+    # check if all records were inserted
+    assert table_count == 100
+
+    # check if they were truncated properly.
+    # LongCats' description is definitely longer than 1000 bytes,
+    # so it should always end up at exactly 1000
+    assert max_length == CONFIG['default_column_length']
+    assert min_length == CONFIG['default_column_length']