From 4e07d0a2a70294fa34bc87926ebecbd3eb8c9016 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niccol=C3=B2=20Cant=C3=B9?= Date: Fri, 29 Nov 2024 15:03:23 +0100 Subject: [PATCH] use pyarrow for efficient insert --- .gitignore | 1 + pyproject.toml | 3 ++- src/main.py | 28 +++++++++++++++++----------- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 06187fd..d2e2307 100644 --- a/.gitignore +++ b/.gitignore @@ -179,3 +179,4 @@ nina-python-init.py *.db *.parquet +*.wal diff --git a/pyproject.toml b/pyproject.toml index a5f0396..f459af0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,8 @@ authors = [ dependencies = [ "sqlglot", "duckdb", - "tqdm" + "tqdm", + "pyarrow" ] description = "" license = {text = "GPL-3.0+"} diff --git a/src/main.py b/src/main.py index 7c2cd8a..81f3874 100644 --- a/src/main.py +++ b/src/main.py @@ -5,6 +5,7 @@ from collections import OrderedDict import duckdb +import pyarrow as pa import sqlglot import sqlglot.expressions from tqdm import tqdm @@ -25,7 +26,9 @@ def convert(value: str, dtype: sqlglot.expressions.DataType.Type): def get_typed_values(row, schema): - return tuple(convert(value=value, dtype=schema[key]) for key, value in row.items()) + return OrderedDict( + (key, convert(value=value, dtype=schema[key])) for key, value in row.items() + ) def get_typed_insert_query(table, values, columns): @@ -80,14 +83,13 @@ def handle_copy( entries.append(values) if len(entries) == chunks: - query = get_typed_insert_query( - table=table, values=entries, columns=columns - ) - connection.sql(query) - set_processed_line(connection, line_nr) + arrow_table = pa.Table.from_pylist(entries) entries = [] - - logging.debug(connection.sql("FROM duckdb_memory();")) + connection.register(f"_temp_{table}", arrow_table) + connection.sql(f"insert into {table} from _temp_{table}") + connection.unregister(f"_temp_{table}") + arrow_table = None + set_processed_line(connection, line_nr) else: logging.debug(f"skipping {line_nr}") @@ -96,9 +98,13 @@ def handle_copy( progress.update(1) if entries: - query = get_typed_insert_query(table=table, values=entries, columns=columns) - logging.debug(query) - connection.sql(query) + arrow_table = pa.Table.from_pylist(entries) + entries = [] + connection.register(f"_temp_{table}", arrow_table) + connection.sql(f"insert into {table} from _temp_{table}") + connection.unregister(f"_temp_{table}") + arrow_table = None + set_processed_line(connection, line_nr) return line_nr