progress bar, skip already copied data
nicokant committed Nov 29, 2024
1 parent 16e94c0 commit d5ab1cb
Showing 3 changed files with 148 additions and 30 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -6,7 +6,8 @@ authors = [
# See https://www.python.org/dev/peps/pep-0621/
dependencies = [
"sqlglot",
"duckdb"
"duckdb",
"tqdm"
]
description = ""
license = {text = "GPL-3.0+"}
@@ -23,7 +24,7 @@ tools = [
]

[project.scripts]
"pg_dedump" = "main:start"
"pg_dedump" = "main:cli"

[tool.ruff]
fix = true
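
The new tqdm dependency provides the progress bar mentioned in the commit message. Below is a minimal, self-contained sketch of the usage pattern the code relies on: a manually advanced bar with an optional known total. The in-memory stream stands in for a real dump file and is purely illustrative.

import io
from tqdm import tqdm

# A bar advanced one tick per input line; passing total= (when known)
# lets tqdm show a percentage instead of a bare counter.
stream = io.StringIO("line 1\nline 2\nline 3\n")  # stand-in for a dump file
progress = tqdm(unit=" lines", total=3)
for line in stream:
    progress.update(1)
progress.close()
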
10 changes: 7 additions & 3 deletions src/helpers.py
@@ -1,19 +1,23 @@
import sqlglot
import sqlglot.expressions


def get_sql_block(input): # noqa: A002
def get_sql_block(input, progress): # noqa: A002
text_block = ""
for line in input:
for index, line in enumerate(input):
progress.update(1)
if line.lstrip().startswith("--"):
continue
text_block += line.rstrip()
if text_block.endswith(";"):
yield text_block
yield text_block, index
text_block = ""


def remove_schema(expression):
def transformer(node):
if isinstance(node, sqlglot.expressions.Create):
node.args["exists"] = True
if node.key == "table":
del node.args["db"]
if (
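
For reference, a self-contained sketch of how the updated generator can be driven. The sample statements are illustrative, and the import assumes src/ is on the path. Each yielded item is now a (statement_text, line_index) pair, and the bar advances once per raw input line.

import io

from tqdm import tqdm

from helpers import get_sql_block  # assumes src/ is on the import path

sample = io.StringIO(
    "-- comments are skipped\n"
    "CREATE TABLE t (id integer);\n"
    "COPY t (id) FROM stdin;\n"
)
progress = tqdm(unit=" lines")
for block, index in get_sql_block(sample, progress):
    print(index, block)  # e.g. 1 CREATE TABLE t (id integer);
progress.close()
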
163 changes: 138 additions & 25 deletions src/main.py
@@ -1,17 +1,18 @@
import argparse
import fileinput
import logging
import pathlib
from collections import OrderedDict

import duckdb
import sqlglot
import sqlglot.expressions
from tqdm import tqdm

from helpers import get_sql_block, remove_schema

DEBUG = False

logging.basicConfig(level=(logging.DEBUG if DEBUG else logging.INFO))


def convert(value: str, dtype: sqlglot.expressions.DataType.Type):
if value == "\\N":
@@ -52,65 +53,177 @@ def handle_create(statement, connection):


def handle_copy(
statement: sqlglot.Expression, stream, connection, table, schema, chunks: int
statement: sqlglot.Expression,
stream,
connection,
table,
schema,
chunks: int,
index: int,
progress: tqdm,
):
columns = [i.this for i in statement.this.expressions]
entries = []
connection.sql("begin;")

line_nr = index

line = next(stream).rstrip()

while line != r"\.":
row = OrderedDict(zip(columns, line.rstrip().split("\t"), strict=False))
values = get_typed_values(row, schema)
entries.append(values)
last_line_nr = get_last_processed_line(connection)

if len(entries) == chunks:
query = get_typed_insert_query(table=table, values=entries, columns=columns)
logging.debug(query)
connection.sql(query)
entries = []
while line != r"\.":
if line_nr > last_line_nr:
values = get_typed_values(
OrderedDict(zip(columns, line.rstrip().split("\t"), strict=False)),
schema,
)
entries.append(values)

if len(entries) == chunks:
query = get_typed_insert_query(
table=table, values=entries, columns=columns
)
connection.sql(query)
set_processed_line(connection, line_nr)
entries = []

logging.debug(connection.sql("FROM duckdb_memory();"))
else:
logging.debug(f"skipping {line_nr}")

line = next(stream).rstrip()
line_nr += 1
progress.update(1)

if entries:
query = get_typed_insert_query(table=table, values=entries, columns=columns)
logging.debug(query)
connection.sql(query)
connection.sql("commit;")

return line_nr
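
The rewritten loop batches rows and records the last imported line number so that a re-run can skip data it has already copied. A stripped-down sketch of that pattern against an in-memory DuckDB database follows; the table, columns and sample rows are illustrative, and the sketch uses parameterised executemany inserts rather than the project's generated INSERT statements.

import duckdb

conn = duckdb.connect()  # a file path here would make the checkpoint persistent
conn.sql("create table if not exists items (id integer, name varchar)")
conn.sql("create table if not exists processed_line as (select 0 as line_nr)")

rows = [(1, "a"), (2, "b"), (3, "c")]  # stand-in for COPY data lines
chunks = 2
entries = []

# line_nr 0 means nothing has been imported yet.
(last_line_nr,) = conn.sql("select line_nr from processed_line limit 1").fetchone()

conn.sql("begin;")
for line_nr, row in enumerate(rows, start=1):
    if line_nr <= last_line_nr:
        continue  # already copied on a previous run
    entries.append(row)
    if len(entries) == chunks:
        conn.executemany("insert into items values (?, ?)", entries)
        conn.sql(f"update processed_line set line_nr = {line_nr}")
        entries = []
if entries:  # flush the final partial chunk
    conn.executemany("insert into items values (?, ?)", entries)
conn.sql("commit;")
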


def sql_parsing_iterator(stream):
for _block in get_sql_block(stream):
def sql_parsing_iterator(stream, progress):
for _block, index in get_sql_block(stream, progress):
parsed = sqlglot.parse_one(_block)

if parsed.key in [
"create",
"copy",
]:
yield parsed
yield parsed, index


def set_processed_line(connection, line_nr):
connection.sql(
sqlglot.expressions.update("_stats.processed_line", {"line_nr": line_nr}).sql(
dialect="duckdb"
)
)


def get_last_processed_line(connection):
(processed_line,) = connection.sql(
"select line_nr from _stats.processed_line limit 1"
).fetchone()
return processed_line
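
set_processed_line builds its UPDATE statement through sqlglot's expression builder instead of string formatting. A quick sketch of what that call renders; the exact casing of the generated SQL may differ.

import sqlglot.expressions

# For line_nr=42 this prints something like:
#   UPDATE _stats.processed_line SET line_nr = 42
query = sqlglot.expressions.update(
    "_stats.processed_line", {"line_nr": 42}
).sql(dialect="duckdb")
print(query)
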

def start() -> None:
connection = duckdb.connect()
chunks = 100000

def start(*args, chunks, db, output_type, files, drop_db, total, **kwargs) -> None:
if drop_db:
pathlib.Path(db).unlink(missing_ok=True)
connection = duckdb.connect(db)
TABLE_REGISTRY = {}

with fileinput.input(encoding="utf-8") as f:
for statement in sql_parsing_iterator(f):
connection.sql("""
create schema if not exists _stats;
create table if not exists _stats.processed_line as (
select 0 as line_nr
)
""")

tqdm_params = {}
if total:
tqdm_params["total"] = total

with fileinput.input(
files=files if len(files) > 0 else ("-",), encoding="utf-8"
) as stream:
progress = tqdm(unit=" lines", **tqdm_params)
for statement, index in sql_parsing_iterator(stream, progress):
if statement.key == "create":
table, schema = handle_create(statement, connection)
TABLE_REGISTRY[table] = schema
elif statement.key == "copy":
table = statement.this.this.this.this
schema = TABLE_REGISTRY[table]
handle_copy(statement, f, connection, table, schema, chunks=chunks)
handle_copy(
statement,
stream,
connection,
table,
schema,
chunks=chunks,
index=index,
progress=progress,
)

for table in TABLE_REGISTRY.keys():
connection.sql(f"from {table}").write_parquet(f"{table}.parquet")
if output_type == "parquet" or db == ":memory:":
connection.sql(f"from {table}").write_parquet(f"{table}.parquet")
else:
raise Exception("output format not supported")

connection.close()
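
The export loop above turns every registered table into a Parquet file named after it. A minimal sketch of the same DuckDB calls with an illustrative table:

import duckdb

conn = duckdb.connect()
conn.sql("create table users as select 1 as id, 'alice' as name")

# "from users" is DuckDB's FROM-first shorthand for "select * from users";
# the resulting relation is streamed straight to a Parquet file.
for table in ["users"]:
    conn.sql(f"from {table}").write_parquet(f"{table}.parquet")
conn.close()
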


def cli():
parser = argparse.ArgumentParser(
prog="pg_dedump",
description="extract tables from postgres dumps",
add_help=True,
)
parser.add_argument(
"files",
metavar="FILE",
nargs="*",
help="files to read, if empty, stdin is used",
)
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument("-r", "--drop-db", action="store_true")
parser.add_argument(
"-c",
"--chunks",
type=int,
default=10000,
help="Chunk insert size",
required=False,
)
parser.add_argument(
"-t",
"--total",
type=int,
help="Total lines",
required=False,
)
parser.add_argument(
"-d",
"--db",
default=":memory:",
help="Name of the dump database",
required=False,
)
parser.add_argument(
"-f",
"--output-type",
default="parquet",
help="Format of the tables output",
required=False,
)
args = parser.parse_args()
logging.basicConfig(level=(logging.DEBUG if args.verbose else logging.INFO))
start(**vars(args))
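
For completeness, a hedged sketch of the equivalent direct call once the arguments are parsed; the values below are illustrative rather than project defaults.

# Roughly equivalent to: pg_dedump -r -d dump.duckdb -c 50000 dump.sql
# ("verbose" is simply absorbed by **kwargs in start()).
start(
    files=["dump.sql"],
    drop_db=True,
    db="dump.duckdb",
    chunks=50_000,
    total=None,
    output_type="parquet",
    verbose=False,
)
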


if __name__ == "__main__":
start()
cli()
