Skip to content

Commit

Permalink
use pyarrow for efficient insert
Browse files (browse the repository at this point in the history)
  • Loading branch information
nicokant committed Nov 29, 2024
1 parent d5ab1cb commit 4e07d0a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,4 @@ nina-python-init.py

*.db
*.parquet
*.wal
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ authors = [
dependencies = [
"sqlglot",
"duckdb",
"tqdm"
"tqdm",
"pyarrow"
]
description = ""
license = {text = "GPL-3.0+"}
Expand Down
28 changes: 17 additions & 11 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import OrderedDict

import duckdb
import pyarrow as pa
import sqlglot
import sqlglot.expressions
from tqdm import tqdm
Expand All @@ -25,7 +26,9 @@ def convert(value: str, dtype: sqlglot.expressions.DataType.Type):


def get_typed_values(row, schema):
return tuple(convert(value=value, dtype=schema[key]) for key, value in row.items())
return OrderedDict(
(key, convert(value=value, dtype=schema[key])) for key, value in row.items()
)


def get_typed_insert_query(table, values, columns):
Expand Down Expand Up @@ -80,14 +83,13 @@ def handle_copy(
entries.append(values)

if len(entries) == chunks:
query = get_typed_insert_query(
table=table, values=entries, columns=columns
)
connection.sql(query)
set_processed_line(connection, line_nr)
arrow_table = pa.Table.from_pylist(entries)
entries = []

logging.debug(connection.sql("FROM duckdb_memory();"))
connection.register(f"_temp_{table}", arrow_table)
connection.sql(f"insert into {table} from _temp_{table}")
connection.unregister(f"_temp_{table}")
arrow_table = None
set_processed_line(connection, line_nr)
else:
logging.debug(f"skipping {line_nr}")

Expand All @@ -96,9 +98,13 @@ def handle_copy(
progress.update(1)

if entries:
query = get_typed_insert_query(table=table, values=entries, columns=columns)
logging.debug(query)
connection.sql(query)
arrow_table = pa.Table.from_pylist(entries)
entries = []
connection.register(f"_temp_{table}", arrow_table)
connection.sql(f"insert into {table} from _temp_{table}")
connection.unregister(f"_temp_{table}")
arrow_table = None
set_processed_line(connection, line_nr)

return line_nr

Expand Down

0 comments on commit 4e07d0a

Please sign in to comment.