Skip to content

Commit

Permalink
Merge pull request #70 from peopledoc/customize-columns
Browse files Browse the repository at this point in the history
  • Loading branch information
Joachim Jablon authored Jul 3, 2020
2 parents ae8202f + 15f1e8b commit 71a4691
Show file tree
Hide file tree
Showing 12 changed files with 155 additions and 149 deletions.
24 changes: 24 additions & 0 deletions septentrion/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,24 @@ def split_envvar_value(self, rv: str):
"immediately if it doesn't exist (env: SEPTENTRION_TABLE)",
default=configuration.DEFAULTS["table"],
)
@click.option(
"--version_column",
help="Name of the column describing the migration version in the migrations table. "
"(env: SEPTENTRION_VERSION_COLUMN)",
default=configuration.DEFAULTS["version_column"],
)
@click.option(
"--name_column",
help="Name of the column describing the migration name in the migrations table. "
"(env: SEPTENTRION_NAME_COLUMN)",
default=configuration.DEFAULTS["name_column"],
)
@click.option(
"--applied_at_column",
help="Name of the column describing the date at which the migration was applied "
"in the migrations table. (env: SEPTENTRION_APPLIED_AT_COLUMN)",
default=configuration.DEFAULTS["applied_at_column"],
)
@click.option(
"--migrations-root",
help="Path to the migration files (env: SEPTENTRION_MIGRATION_ROOT)",
Expand Down Expand Up @@ -180,6 +198,12 @@ def split_envvar_value(self, rv: str):
default=configuration.DEFAULTS["ignore_symlinks"],
help="Ignore migration files that are symlinks",
)
@click.option(
"--create-table/--no-create-table",
default=configuration.DEFAULTS["create_table"],
help="Controls whether the migrations table should be created if it doesn't exist. "
"(env: SEPTENTRION_CREATE_TABLE)",
)
def cli(ctx: click.Context, **kwargs):
if kwargs.pop("password_flag"):
password = click.prompt("Database password", hide_input=True)
Expand Down
4 changes: 4 additions & 0 deletions septentrion/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@
ALL_CONFIGURATION_FILES = DEDICATED_CONFIGURATION_FILES + COMMON_CONFIGURATION_FILES

DEFAULTS = {
"create_table": True,
"table": "septentrion_migrations",
"version_column": "version",
"name_column": "name",
"applied_at_column": "applied_at",
"migrations_root": ".",
"schema_template": "schema_{}.sql",
"fixtures_template": "fixtures_{}.sql",
Expand Down
4 changes: 2 additions & 2 deletions septentrion/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def initialize(**kwargs):

# All other commands will need the table to be created
logger.info("Ensuring migration table exists")
# TODO: this probably deserves an option
db.create_table(settings=settings) # idempotent
if settings.CREATE_TABLE:
db.create_table(settings=settings) # idempotent

return settings

Expand Down
51 changes: 33 additions & 18 deletions septentrion/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Iterable, Optional, Tuple

import psycopg2
import psycopg2.sql
from psycopg2.extensions import connection as Connection
from psycopg2.extras import DictCursor

Expand All @@ -15,6 +16,7 @@
logger = logging.getLogger(__name__)


@contextmanager
def get_connection(settings: configuration.Settings) -> Connection:
"""
Opens a PostgreSQL connection using psycopg2.
Expand Down Expand Up @@ -42,10 +44,13 @@ def get_connection(settings: configuration.Settings) -> Connection:
# default settings or its own environment variables (PGHOST, PGUSER, ...)
connection = psycopg2.connect(dsn="", **kwargs)

# Autocommit=true means we'll have more control over when the code is commited
# (even if this sounds strange)
connection.set_session(autocommit=True)
return connection
try:
# Autocommit=true means we'll have more control over when the code is commited
# (even if this sounds strange)
connection.set_session(autocommit=True)
yield connection
finally:
connection.close()


@contextmanager
Expand All @@ -55,7 +60,12 @@ def execute(
args: Tuple = tuple(),
commit: bool = False,
) -> Any:
query = " ".join(query.format(table=settings.TABLE).split())
query = psycopg2.sql.SQL(query).format(
table=psycopg2.sql.Identifier(settings.TABLE),
version_column=psycopg2.sql.Identifier(settings.VERSION_COLUMN),
name_column=psycopg2.sql.Identifier(settings.NAME_COLUMN),
applied_at_column=psycopg2.sql.Identifier(settings.APPLIED_AT_COLUMN),
)
with get_connection(settings=settings) as conn:
with conn.cursor(cursor_factory=DictCursor) as cur:
logger.debug("Executing %s -- Args: %s", query, args)
Expand Down Expand Up @@ -89,27 +99,27 @@ def __call__(self):


query_create_table = """
CREATE TABLE IF NOT EXISTS "{table}" (
CREATE TABLE IF NOT EXISTS {table} (
id BIGSERIAL PRIMARY KEY,
version TEXT,
name TEXT,
applied_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
{version_column} TEXT,
{name_column} TEXT,
{applied_at_column} TIMESTAMP WITH TIME ZONE DEFAULT NOW()
)
"""

query_max_version = """SELECT DISTINCT "version" FROM "{table}" """
query_max_version = """SELECT DISTINCT {version_column} FROM {table} """

query_write_migration = """
INSERT INTO "{table}" ("version", "name")
INSERT INTO {table} ({version_column}, {name_column})
VALUES (%s, %s)
"""

query_get_applied_migrations = """
SELECT name FROM "{table}" WHERE "version" = %s
SELECT {name_column} FROM {table} WHERE {version_column} = %s
"""

query_is_schema_initialized = """
SELECT TRUE FROM "{table}" LIMIT 1
SELECT TRUE FROM {table} LIMIT 1
"""


Expand Down Expand Up @@ -141,11 +151,16 @@ def get_applied_migrations(


def is_schema_initialized(settings: configuration.Settings) -> bool:
with Query(settings=settings, query=query_is_schema_initialized) as cur:
try:
return next(cur)
except StopIteration:
return False

try:
with Query(settings=settings, query=query_is_schema_initialized) as cur:
try:
return next(cur)
except StopIteration:
return False
except psycopg2.errors.UndefinedTable:
# If table doesn't exist
return False


def create_table(settings: configuration.Settings) -> None:
Expand Down
6 changes: 2 additions & 4 deletions septentrion/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,9 @@ def show_migrations(**settings_kwargs):
core.describe_migration_plan(**lib_kwargs)


def migrate(*, migration_applied_callback=None, **settings_kwargs):
def migrate(**settings_kwargs):
lib_kwargs = initialize(settings_kwargs)
migration.migrate(
migration_applied_callback=migration_applied_callback, **lib_kwargs,
)
migration.migrate(**lib_kwargs,)


def is_schema_initialized(**settings_kwargs):
Expand Down
9 changes: 3 additions & 6 deletions septentrion/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@


def migrate(
settings: configuration.Settings,
stylist: style.Stylist = style.noop_stylist,
migration_applied_callback=None,
settings: configuration.Settings, stylist: style.Stylist = style.noop_stylist,
) -> None:

logger.info("Starting migrations")
Expand Down Expand Up @@ -63,8 +61,6 @@ def migrate(
run_script(settings=settings, path=path)
logger.info("Saving operation in the database")
db.write_migration(settings=settings, version=version, name=mig)
if migration_applied_callback is not None:
migration_applied_callback(version.original_string, mig)


def _load_schema_files(settings: configuration.Settings, schema_files: List[str]):
Expand Down Expand Up @@ -192,4 +188,5 @@ def run_script(settings: configuration.Settings, path: pathlib.Path) -> None:
logger.info("Running SQL file %s", path)
with io.open(path, "r", encoding="utf8") as f:
script = runner.Script(settings=settings, file_handler=f, path=path)
script.run(connection=db.get_connection(settings=settings))
with db.get_connection(settings=settings) as connection:
script.run(connection=connection)
45 changes: 45 additions & 0 deletions tests/acceptance/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from septentrion import core


def test_initialize(db):

settings_kwargs = {
# database connection settings
"host": db["host"],
"port": db["port"],
"username": db["user"],
"dbname": db["dbname"],
# migrate settings
"target_version": "1.1",
"migrations_root": "example_migrations",
}

# create table with no error
core.initialize(**settings_kwargs)
# action is idempotent, no error either
core.initialize(**settings_kwargs)


def test_initialize_customize_names(db):

settings_kwargs = {
# database connection settings
"host": db["host"],
"port": db["port"],
"username": db["user"],
"dbname": db["dbname"],
# migrate settings
"target_version": "1.1",
"migrations_root": "example_migrations",
# customize table
"table": "my_own_table",
# customize columns
"name_column": "name_custo",
"version_column": "version_custo",
"applied_at_column": "applied_custo",
}

# create table with no error
core.initialize(**settings_kwargs)
# action is idempotent, no error either
core.initialize(**settings_kwargs)
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import psycopg2
import pytest

from septentrion import configuration


@pytest.fixture
def db():
Expand Down Expand Up @@ -39,3 +41,11 @@ def fake_db(mocker):
def temporary_directory(tmpdir):
with tmpdir.as_cwd():
yield


@pytest.fixture()
def settings_factory():
def _(**kwargs):
return configuration.Settings(**kwargs)

return _
21 changes: 21 additions & 0 deletions tests/integration/test_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import psycopg2.errors
import pytest

from septentrion import db as db_module


def test_execute(db, settings_factory):
settings = settings_factory(**db)
with db_module.execute(settings=settings, query="SELECT 1;") as cursor:
assert cursor.fetchall() == [[1]]


def test_execute_sql_injection(db, settings_factory):
settings = settings_factory(**db, table='"pg_enum"; -- SQLi')
with pytest.raises(psycopg2.errors.UndefinedTable) as exc_info:
with db_module.execute(
settings=settings, query="SELECT * FROM {table};"
) as cursor:
assert cursor.fetchall() == [[1]]

assert 'relation ""pg_enum"; -- SQLi" does not exist' in str(exc_info.value)
Loading

0 comments on commit 71a4691

Please sign in to comment.