Skip to content

Commit

Permalink
Merge pull request #73 from peopledoc/psql
Browse files Browse the repository at this point in the history
Use psql to run migrations
  • Loading branch information
k4nar authored Jul 17, 2020
2 parents 71a4691 + faa427e commit 4b04905
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 257 deletions.
3 changes: 2 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ tool would be django-north_.
But maybe you're not using Django. You would like a standalone migration tool. You're
looking for Septentrion. Congratulations, you've found it.

Septentrion supports PostgreSQL 9.6+ and Python 3.6+.
Septentrion supports PostgreSQL 9.6+ and Python 3.6+, and requires the ``psql``
executable to be available on the ``PATH``.

.. _South: https://bitbucket.org/andrewgodwin/south/src
.. _django-north: https://github.com/peopledoc/django-north
Expand Down
2 changes: 1 addition & 1 deletion septentrion/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
folder.expanduser().resolve() / DEDICATED_CONFIGURATION_FILENAME
for folder in CONFIGURATION_PATHS
]
print(DEDICATED_CONFIGURATION_FILES)

# These are the files that can contain septentrion configuration, but
# it's also ok if they exist and they don't configure septentrion.
COMMON_CONFIGURATION_FILES = [pathlib.Path("./setup.cfg")]
Expand Down
3 changes: 1 addition & 2 deletions septentrion/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,5 +188,4 @@ def run_script(settings: configuration.Settings, path: pathlib.Path) -> None:
logger.info("Running SQL file %s", path)
with io.open(path, "r", encoding="utf8") as f:
script = runner.Script(settings=settings, file_handler=f, path=path)
with db.get_connection(settings=settings) as connection:
script.run(connection=connection)
script.run()
170 changes: 53 additions & 117 deletions septentrion/runner.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import logging
import pathlib
import subprocess
from typing import Iterable

import sqlparse
from psycopg2.extensions import cursor as Cursor

from septentrion import configuration, files
from septentrion import configuration

logger = logging.getLogger(__name__)

Expand All @@ -14,124 +12,62 @@ class SQLRunnerException(Exception):
pass


def clean_sql_code(code: str) -> str:
    r"""Return *code* without psql ``\timing`` lines and ``--`` comment lines.

    The input is split on ``"\n"``. Every line is stripped of surrounding
    whitespace; a line that is exactly ``\timing`` or that starts with
    ``--`` once stripped is dropped. Every remaining line (including empty
    ones) is emitted stripped and followed by a newline, so the result
    always ends with ``"\n"`` when at least one line is kept.
    """
    kept = []
    for line in code.split("\n"):
        stripped = line.strip()
        if stripped == "\\timing" or stripped.startswith("--"):
            continue
        kept.append(stripped + "\n")
    # Join once instead of growing a string with += on every iteration.
    return "".join(kept)


class Block(object):
    """A chunk of SQL text that is accumulated line by line, then executed.

    A block is open while lines are appended to it; once closed it rejects
    any further ``append_line`` or ``close`` call.
    """

    def __init__(self):
        self.content = ""
        self.closed = False

    def _ensure_open(self) -> None:
        """Raise ``SQLRunnerException`` when the block is already closed."""
        if self.closed:
            raise SQLRunnerException("Block closed !")

    def append_line(self, line: str) -> None:
        """Append *line* verbatim to the block content."""
        self._ensure_open()
        self.content = self.content + line

    def close(self) -> None:
        """Mark the block as closed."""
        self._ensure_open()
        self.closed = True

    def run(self, cursor: Cursor) -> int:
        """Execute the statements of the block one by one on *cursor*.

        The content is split with ``sqlparse``; the split is rejected when
        re-joining the statements does not give back the original content.
        Returns the sum of ``cursor.rowcount`` over the executed statements.
        """
        statements = sqlparse.parse(self.content)

        rebuilt = "".join([str(stmt) for stmt in statements])
        if rebuilt != self.content:
            raise SQLRunnerException("sqlparse failed to properly split input")

        total = 0
        for statement in statements:
            text = str(statement)
            # Sometimes sqlparse keeps the empty lines here,
            # this could negatively affect libpq
            is_noop = clean_sql_code(text).strip() in ("", ";")
            if not is_noop:
                logger.debug("Running one statement... <<%s>>", text)
                cursor.execute(text.replace("\\timing\n", ""))
                logger.debug("Affected %s rows", cursor.rowcount)
                total += cursor.rowcount
        return total


class SimpleBlock(Block):
    """Block whose cleaned content is sent to the cursor in a single call."""

    def run(self, cursor):
        """Clean the whole content with ``clean_sql_code`` and execute it."""
        cursor.execute(clean_sql_code(self.content))


class MetaBlock(Block):
    """Block that is executed repeatedly until a pass affects no rows.

    Only the ``do-until-0`` command is accepted.
    """

    def __init__(self, command: str):
        super().__init__()
        self.command = command
        if command == "do-until-0":
            return
        raise SQLRunnerException("Unexpected command {}".format(command))

    def run(self, cursor: Cursor) -> int:
        """Run the block in a loop and return the total of affected rows.

        Each pass delegates to ``Block.run``; the loop ends after the first
        pass whose row count is not strictly positive.
        """
        grand_total = 0
        while True:
            logger.debug("Running one block in a loop")
            affected = super().run(cursor)
            # A non-positive count is treated as "nothing left to do".
            batch_delta = affected if affected > 0 else 0
            grand_total += batch_delta
            logger.debug("Batch delta done : %s", batch_delta)
            if batch_delta == 0:
                return grand_total


class Script(object):
class Script:
def __init__(
self,
settings: configuration.Settings,
file_handler: Iterable[str],
path: pathlib.Path,
):
file_lines = list(file_handler)
is_manual = files.is_manual_migration(
migration_path=path, migration_contents=file_lines
)
self.settings = settings
if is_manual:
self.block_list = [Block()]
elif self.contains_non_transactional_keyword(file_lines):
self.block_list = [Block()]
else:
self.block_list = [SimpleBlock()]
for line in file_lines:
if line.startswith("--meta-psql:") and is_manual:
self.block_list[-1].close()
command = line.split(":")[1].strip()
if command == "done":
# create a new basic block
self.block_list.append(Block())
else:
# create a new meta block
self.block_list.append(MetaBlock(command))
else:
self.block_list[-1].append_line(line)
self.block_list[-1].close()

def run(self, connection):
with connection.cursor() as cursor:
for block in self.block_list:
block.run(cursor)
self.file_lines = list(file_handler)
self.path = path

def contains_non_transactional_keyword(self, file_lines: Iterable[str]) -> bool:
keywords = self.settings.NON_TRANSACTIONAL_KEYWORD
for line in file_lines:
for kw in keywords:
if kw.lower() in line.lower():
return True

return False
def run(self):
if any("--meta-psql:" in line for line in self.file_lines):
self._run_with_meta_loop()
else:
self._run_simple()

def _run_simple(self):
creds = []
for name in ["HOST", "PORT", "DBNAME", "USERNAME", "PASSWORD"]:
value = getattr(self.settings, name)
if value:
creds += ["--" + name.lower(), value]

try:
cmd = subprocess.run(
["psql", *creds, "--set", "ON_ERROR_STOP=on", "-f", self.path],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=True,
)
except FileNotFoundError:
raise RuntimeError(
"Septentrion requires the 'psql' executable to be present in "
"the PATH."
)
except subprocess.CalledProcessError as e:
msg = "Error during migration: {}".format(e.stderr.decode("utf-8"))
raise SQLRunnerException(msg) from e

return cmd.stdout.decode("utf-8")

def _run_with_meta_loop(self):
KEYWORDS = ["INSERT", "UPDATE", "DELETE"]
rows_remaining = True

while rows_remaining:
out = self._run_simple()

# we can stop once all the write operations return 0 rows
for line in out.split("\n"):
rows_remaining = any(
keyword in line and keyword + " 0" not in line
for keyword in KEYWORDS
)

# we still have work to do, we can go back to the main loop
if rows_remaining:
break
100 changes: 100 additions & 0 deletions tests/integration/test_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import io
import os

import pytest

from septentrion.db import Query
from septentrion.runner import Script, SQLRunnerException


@pytest.fixture()
def run_script(db, settings_factory, tmp_path):
    """Provide a callable that writes SQL to a temporary file and runs it.

    The SQL is written to ``script.sql`` under ``tmp_path`` and executed
    through ``Script.run`` with settings built from the ``db`` fixture.
    """
    settings = settings_factory(**db)
    sql_file = tmp_path / "script.sql"

    def _run_script(script):
        sql_file.write_text(script)

        with io.open(sql_file, "r", encoding="utf8") as handle:
            Script(settings, handle, sql_file).run()

    return _run_script


@pytest.fixture()
def env():
    """Yield ``os.environ`` and restore its initial content afterwards."""
    snapshot = dict(os.environ)
    yield os.environ
    # Drop whatever the test changed, then put the saved variables back.
    os.environ.clear()
    os.environ.update(snapshot)


def test_run_simple(db, settings_factory, run_script):
    """A plain script run through psql creates the requested table."""
    run_script("CREATE TABLE foo ();")

    settings = settings_factory(**db)
    query = "SELECT COUNT(*) FROM pg_catalog.pg_tables WHERE tablename = 'foo'"
    with Query(settings, query) as cur:
        counts = [row[0] for row in cur]
        assert counts == [1]


def test_run_simple_error(run_script):
    """Invalid SQL surfaces as SQLRunnerException carrying psql's error."""
    expected = 'ERROR: syntax error at or near "???"'

    with pytest.raises(SQLRunnerException) as err:
        run_script("CREATE TABLE ???")

    message = str(err.value)
    assert expected in message


def test_run_psql_not_found(run_script, env):
    """With an empty PATH, running a script reports that psql is missing."""
    expected = "Septentrion requires the 'psql' executable to be present in the PATH."
    env["PATH"] = ""

    with pytest.raises(RuntimeError) as err:
        run_script("SELECT 1;")

    assert str(err.value) == expected


def test_run_with_meta_loop(db, settings_factory, run_script):
    """A ``do-until-0`` script is repeated until every row has been updated."""
    settings = settings_factory(**db)
    select_all = "SELECT * FROM foo ORDER BY value"

    # create a table with 10 rows
    run_script(
        """
CREATE TABLE foo(value int);
INSERT INTO foo SELECT generate_series(1, 10);
"""
    )

    with Query(settings, select_all) as cur:
        initial = [row[0] for row in cur]
        assert initial == list(range(1, 11))

    # update the rows 3 by 3 to multiply them by 100
    run_script(
        """
--meta-psql:do-until-0
WITH to_update AS (
SELECT value FROM foo WHERE value < 100 LIMIT 3
)
UPDATE foo SET value = foo.value * 100
FROM to_update WHERE foo.value = to_update.value
--meta-psql:done
"""
    )

    with Query(settings, select_all) as cur:
        updated = [row[0] for row in cur]
        assert updated == [n * 100 for n in range(1, 11)]
Loading

0 comments on commit 4b04905

Please sign in to comment.