From f5fda832619f520cdb6d33f3cd11666a0ea8be43 Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Tue, 7 Jan 2020 20:09:39 -0500 Subject: [PATCH 1/7] MAINT: remove hard dependency on odo Odo is currently a hard dependency of warp_prism to convert the sqlalchemy types into a numpy dtypes. Odo is no longer actively maintained and breaks with newer versions of pandas. This change reimplements the needed functionality in warp_prism directly without using odo. This PR does leave the odo edge registration code so that existing users don't see a change in functionality. --- setup.py | 6 +-- warp_prism/__init__.py | 120 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 110 insertions(+), 16 deletions(-) diff --git a/setup.py b/setup.py index 175244b..1f4ca48 100644 --- a/setup.py +++ b/setup.py @@ -42,17 +42,17 @@ ), ], install_requires=[ - 'datashape', 'numpy', 'pandas', 'sqlalchemy', 'psycopg2', - 'odo', 'toolz', - 'networkx<=1.11', ], extras_require={ 'dev': [ + 'odo', + 'pandas==0.18.1', + 'networkx<=1.11', 'flake8==3.3.0', 'pycodestyle==2.3.1', 'pyflakes==1.5.0', diff --git a/warp_prism/__init__.py b/warp_prism/__init__.py index 7a560a5..52e8fcf 100644 --- a/warp_prism/__init__.py +++ b/warp_prism/__init__.py @@ -1,13 +1,11 @@ from io import BytesIO +import numbers -from datashape import discover -from datashape.predicates import istabular import numpy as np -from odo import convert import pandas as pd import sqlalchemy as sa +from sqlalchemy.dialects import postgresql as _postgresql from sqlalchemy.ext.compiler import compiles -from toolz import keymap from ._warp_prism import ( raw_to_arrays as _raw_to_arrays, @@ -18,7 +16,7 @@ __version__ = '0.1.1' -_typeid_map = keymap(np.dtype, _raw_typeid_map) +_typeid_map = {np.dtype(k): v for k, v in _raw_typeid_map.items()} _object_type_id = _raw_typeid_map['object'] @@ -66,14 +64,107 @@ def _compile_copy_to_binary_postgres(element, compiler, **kwargs): ) +types = {np.dtype(k): v for k, v in { + 'i8': sa.BigInteger, + 'i4': sa.Integer, + 'i2': sa.SmallInteger, + 'f4': sa.REAL, + 'f8': sa.FLOAT, + 'O': sa.Text, + 'M8[D]': sa.Date, + 'M8[us]': sa.DateTime, + '?': sa.Boolean, + "m8[D]": sa.Interval(second_precision=0, day_precision=9), + "m8[h]": sa.Interval(second_precision=0, day_precision=0), + "m8[m]": sa.Interval(second_precision=0, day_precision=0), + "m8[s]": sa.Interval(second_precision=0, day_precision=0), + "m8[ms]": sa.Interval(second_precision=3, day_precision=0), + "m8[us]": sa.Interval(second_precision=6, day_precision=0), + "m8[ns]": sa.Interval(second_precision=9, day_precision=0), +}.items()} + +_revtypes = dict(map(reversed, types.items())) +_revtypes.update({ + sa.DATETIME: np.dtype('M8[us]'), + sa.TIMESTAMP: np.dtype('M8[us]'), + sa.FLOAT: np.dtype('f8'), + sa.DATE: np.dtype('M8[D]'), + sa.BIGINT: np.dtype('i8'), + sa.INTEGER: np.dtype('i4'), + sa.BIGINT: np.dtype('i8'), + sa.types.NullType: np.dtype('O'), + sa.REAL: np.dtype('f4'), + sa.Float: np.dtype('f8'), +}) + +_precision_types = { + sa.Float, + _postgresql.base.DOUBLE_PRECISION, +} + + +def _precision_to_dtype(precision): + if isinstance(precision, numbers.Integral): + if 1 <= precision <= 24: + return np.dtype('f4') + elif 25 <= precision <= 53: + return np.dtype('f8') + raise ValueError('%s is not a supported precision' % precision) + + +_units_of_power = { + 0: 's', + 3: 'ms', + 6: 'us', + 9: 'ns' +} + + +def _discover_type(type_): + if isinstance(type_, sa.Interval): + if type_.second_precision is None and type_.day_precision is None: + return np.dtype('m8[us]') + 
elif type_.second_precision == 0 and type_.day_precision == 0: + return np.dtype('m8[s]') + + if (type_.second_precision in _units_of_power and + not type_.day_precision): + unit = _units_of_power[type_.second_precision] + elif type_.day_precision > 0: + unit = 'D' + else: + raise ValueError( + 'Cannot infer INTERVAL type_e with parameters' + 'second_precision=%d, day_precision=%d' % + (type_.second_precision, type_.day_precision), + ) + return np.dtype('m8[%s]' % unit) + if type(type_) in _precision_types and type_.precision is not None: + return _precision_to_dtype(type_.precision) + if type_ in _revtypes: + return _revtypes[type_] + if type(type_) in _revtypes: + return _revtypes[type(type_)] + if isinstance(type_, sa.Numeric): + raise ValueError('Cannot adapt numeric type to numpy dtype') + if isinstance(type_, (sa.String, sa.Unicode)): + return np.dtype('O') + else: + for k, v in _revtypes.items(): + if isinstance(k, type) and (isinstance(type_, k) or + hasattr(type_, 'impl') and + isinstance(type_.impl, k)): + return v + if k == type_: + return v + raise NotImplementedError('No SQL-numpy match for type %s' % type_) + + def _warp_prism_types(query): - for name, dtype in discover(query).measure.fields: + for col in query.columns: + dtype = _discover_type(col.type) try: - np_dtype = getattr(dtype, 'ty', dtype).to_numpy_dtype() - if np_dtype.kind == 'U': - yield _object_type_id - else: - yield _typeid_map[np_dtype] + yield _typeid_map[dtype] except KeyError: raise TypeError( 'warp_prism cannot query columns of type %s' % dtype, @@ -136,7 +227,7 @@ def to_arrays(query, *, bind=None): return {column_names[n]: v for n, v in enumerate(out)} -null_values = keymap(np.dtype, { +null_values = {np.dtype(k): v for k, v in { 'float32': np.nan, 'float64': np.nan, 'int16': np.nan, @@ -145,7 +236,7 @@ def to_arrays(query, *, bind=None): 'bool': np.nan, 'datetime64[ns]': np.datetime64('nat', 'ns'), 'object': None, -}) +}.items()} # alias because ``to_dataframe`` shadows this name _default_null_values_for_type = null_values @@ -216,6 +307,9 @@ def register_odo_dataframe_edge(): If the selectable is not in a postgres database, it will fallback to the default odo edge. 
""" + from odo import convert + from datashape.predicates import istabular + # estimating 8 times faster df_cost = convert.graph.edge[sa.sql.Select][pd.DataFrame]['cost'] / 8 From 389a349dda8cd16c526e654592a6eae93163f1cd Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Fri, 24 Jan 2020 00:29:38 -0500 Subject: [PATCH 2/7] MAINT: remove hard dependencies on sqlalchemy and pandas --- .gitignore | 1 + setup.py | 3 - warp_prism/__init__.py | 348 +--------------------------- warp_prism/odo.py | 47 ++++ warp_prism/query.py | 125 ++++++++++ warp_prism/sa.py | 12 + warp_prism/sql.py | 60 +++++ warp_prism/tests/__init__.py | 81 ++----- warp_prism/tests/test_warp_prism.py | 153 +++++++----- warp_prism/types.py | 128 ++++++++++ 10 files changed, 490 insertions(+), 468 deletions(-) create mode 100644 warp_prism/odo.py create mode 100644 warp_prism/query.py create mode 100644 warp_prism/sa.py create mode 100644 warp_prism/sql.py create mode 100644 warp_prism/types.py diff --git a/.gitignore b/.gitignore index 036cec9..51fb095 100644 --- a/.gitignore +++ b/.gitignore @@ -63,3 +63,4 @@ benchmarks.db .dir-locals.el TAGS +.gdb_history diff --git a/setup.py b/setup.py index 1f4ca48..083247b 100644 --- a/setup.py +++ b/setup.py @@ -43,10 +43,7 @@ ], install_requires=[ 'numpy', - 'pandas', - 'sqlalchemy', 'psycopg2', - 'toolz', ], extras_require={ 'dev': [ diff --git a/warp_prism/__init__.py b/warp_prism/__init__.py index 52e8fcf..6bb2897 100644 --- a/warp_prism/__init__.py +++ b/warp_prism/__init__.py @@ -1,345 +1,9 @@ -from io import BytesIO -import numbers +from .query import to_arrays, to_dataframe +from .odo import register_odo_dataframe_edge -import numpy as np -import pandas as pd -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql as _postgresql -from sqlalchemy.ext.compiler import compiles +__version__ = '0.2.1' -from ._warp_prism import ( - raw_to_arrays as _raw_to_arrays, - typeid_map as _raw_typeid_map, -) - -__version__ = '0.1.1' - - -_typeid_map = {np.dtype(k): v for k, v in _raw_typeid_map.items()} -_object_type_id = _raw_typeid_map['object'] - - -class _CopyToBinary(sa.sql.expression.Executable, sa.sql.ClauseElement): - - def __init__(self, element, bind): - self.element = element - self._bind = bind = bind - - @property - def bind(self): - return self._bind - - -def literal_compile(s): - """Compile a sql expression with bind params inlined as literals. - - Parameters - ---------- - s : Selectable - The expression to compile. - - Returns - ------- - cs : str - An equivalent sql string. 
- """ - return str(s.compile(compile_kwargs={'literal_binds': True})) - - -@compiles(_CopyToBinary, 'postgresql') -def _compile_copy_to_binary_postgres(element, compiler, **kwargs): - selectable = element.element - return compiler.process( - sa.text( - 'COPY {stmt} TO STDOUT (FORMAT BINARY)'.format( - stmt=( - compiler.preparer.format_table(selectable) - if isinstance(selectable, sa.Table) else - '({})'.format(literal_compile(selectable)) - ), - ) - ), - **kwargs - ) - - -types = {np.dtype(k): v for k, v in { - 'i8': sa.BigInteger, - 'i4': sa.Integer, - 'i2': sa.SmallInteger, - 'f4': sa.REAL, - 'f8': sa.FLOAT, - 'O': sa.Text, - 'M8[D]': sa.Date, - 'M8[us]': sa.DateTime, - '?': sa.Boolean, - "m8[D]": sa.Interval(second_precision=0, day_precision=9), - "m8[h]": sa.Interval(second_precision=0, day_precision=0), - "m8[m]": sa.Interval(second_precision=0, day_precision=0), - "m8[s]": sa.Interval(second_precision=0, day_precision=0), - "m8[ms]": sa.Interval(second_precision=3, day_precision=0), - "m8[us]": sa.Interval(second_precision=6, day_precision=0), - "m8[ns]": sa.Interval(second_precision=9, day_precision=0), -}.items()} - -_revtypes = dict(map(reversed, types.items())) -_revtypes.update({ - sa.DATETIME: np.dtype('M8[us]'), - sa.TIMESTAMP: np.dtype('M8[us]'), - sa.FLOAT: np.dtype('f8'), - sa.DATE: np.dtype('M8[D]'), - sa.BIGINT: np.dtype('i8'), - sa.INTEGER: np.dtype('i4'), - sa.BIGINT: np.dtype('i8'), - sa.types.NullType: np.dtype('O'), - sa.REAL: np.dtype('f4'), - sa.Float: np.dtype('f8'), -}) - -_precision_types = { - sa.Float, - _postgresql.base.DOUBLE_PRECISION, -} - - -def _precision_to_dtype(precision): - if isinstance(precision, numbers.Integral): - if 1 <= precision <= 24: - return np.dtype('f4') - elif 25 <= precision <= 53: - return np.dtype('f8') - raise ValueError('%s is not a supported precision' % precision) - - -_units_of_power = { - 0: 's', - 3: 'ms', - 6: 'us', - 9: 'ns' -} - - -def _discover_type(type_): - if isinstance(type_, sa.Interval): - if type_.second_precision is None and type_.day_precision is None: - return np.dtype('m8[us]') - elif type_.second_precision == 0 and type_.day_precision == 0: - return np.dtype('m8[s]') - - if (type_.second_precision in _units_of_power and - not type_.day_precision): - unit = _units_of_power[type_.second_precision] - elif type_.day_precision > 0: - unit = 'D' - else: - raise ValueError( - 'Cannot infer INTERVAL type_e with parameters' - 'second_precision=%d, day_precision=%d' % - (type_.second_precision, type_.day_precision), - ) - return np.dtype('m8[%s]' % unit) - if type(type_) in _precision_types and type_.precision is not None: - return _precision_to_dtype(type_.precision) - if type_ in _revtypes: - return _revtypes[type_] - if type(type_) in _revtypes: - return _revtypes[type(type_)] - if isinstance(type_, sa.Numeric): - raise ValueError('Cannot adapt numeric type to numpy dtype') - if isinstance(type_, (sa.String, sa.Unicode)): - return np.dtype('O') - else: - for k, v in _revtypes.items(): - if isinstance(k, type) and (isinstance(type_, k) or - hasattr(type_, 'impl') and - isinstance(type_.impl, k)): - return v - if k == type_: - return v - raise NotImplementedError('No SQL-numpy match for type %s' % type_) - - -def _warp_prism_types(query): - for col in query.columns: - dtype = _discover_type(col.type) - try: - yield _typeid_map[dtype] - except KeyError: - raise TypeError( - 'warp_prism cannot query columns of type %s' % dtype, - ) - - -def _getbind(selectable, bind): - """Return an explicitly passed connection or infer 
the connection from - the selectable. - - Parameters - ---------- - selectable : sa.sql.Selectable - The selectable object being queried. - bind : bind or None - The explicit connection or engine to use to execute the query. - - Returns - ------- - bind : The bind which should be used to execute the query. - """ - if bind is None: - return selectable.bind - - if isinstance(bind, sa.engine.base.Engine): - return bind - - return sa.create_engine(bind) - - -def to_arrays(query, *, bind=None): - """Run the query returning a the results as np.ndarrays. - - Parameters - ---------- - query : sa.sql.Selectable - The query to run. This can be a select or a table. - bind : sa.Engine, optional - The engine used to create the connection. If not provided - ``query.bind`` will be used. - - Returns - ------- - arrays : dict[str, (np.ndarray, np.ndarray)] - A map from column name to the result arrays. The first array holds the - values and the second array is a boolean mask for NULLs. The values - where the mask is False are 0 interpreted by the type. - """ - # check types before doing any work - types = tuple(_warp_prism_types(query)) - - buf = BytesIO() - bind = _getbind(query, bind) - - stmt = _CopyToBinary(query, bind) - with bind.connect() as conn: - conn.connection.cursor().copy_expert(literal_compile(stmt), buf) - out = _raw_to_arrays(buf.getbuffer(), types) - column_names = query.c.keys() - return {column_names[n]: v for n, v in enumerate(out)} - - -null_values = {np.dtype(k): v for k, v in { - 'float32': np.nan, - 'float64': np.nan, - 'int16': np.nan, - 'int32': np.nan, - 'int64': np.nan, - 'bool': np.nan, - 'datetime64[ns]': np.datetime64('nat', 'ns'), - 'object': None, -}.items()} - -# alias because ``to_dataframe`` shadows this name -_default_null_values_for_type = null_values - - -def to_dataframe(query, *, bind=None, null_values=None): - """Run the query returning a the results as a pd.DataFrame. - - Parameters - ---------- - query : sa.sql.Selectable - The query to run. This can be a select or a table. - bind : sa.Engine, optional - The engine used to create the connection. If not provided - ``query.bind`` will be used. - null_values : dict[str, any] - The null values to use for each column. This falls back to - ``warp_prism.null_values`` for columns that are not specified. - - Returns - ------- - df : pd.DataFrame - A pandas DataFrame holding the results of the query. The columns - of the DataFrame will be named the same and be in the same order as the - query. - """ - arrays = to_arrays(query, bind=bind) - - if null_values is None: - null_values = {} - - for name, (array, mask) in arrays.items(): - if array.dtype.kind == 'i': - if not mask.all(): - try: - null = null_values[name] - except KeyError: - # no explicit override, cast to float and use NaN as null - array = array.astype('float64') - null = np.nan - - array[~mask] = null - - arrays[name] = array - continue - - if array.dtype.kind == 'M': - # pandas needs datetime64[ns], not ``us`` or ``D`` - array = array.astype('datetime64[ns]') - - try: - null = null_values[name] - except KeyError: - null = _default_null_values_for_type[array.dtype] - - array[~mask] = null - arrays[name] = array - - return pd.DataFrame(arrays, columns=[column.name for column in query.c]) - - -def register_odo_dataframe_edge(): - """Register an odo edge for sqlalchemy selectable objects to dataframe. - - This edge will have a lower cost that the default edge so it will be - selected as the fasted path. 
- - If the selectable is not in a postgres database, it will fallback to the - default odo edge. - """ - from odo import convert - from datashape.predicates import istabular - - # estimating 8 times faster - df_cost = convert.graph.edge[sa.sql.Select][pd.DataFrame]['cost'] / 8 - - @convert.register( - pd.DataFrame, - (sa.sql.Select, sa.sql.Selectable), - cost=df_cost, - ) - def select_or_selectable_to_frame(el, bind=None, dshape=None, **kwargs): - bind = _getbind(el, bind) - - if bind.dialect.name != 'postgresql': - # fall back to the general edge - raise NotImplementedError() - - return to_dataframe(el, bind=bind) - - # higher priority than df edge so that - # ``odo('select one_column from ...', list)`` returns a list of scalars - # instead of a list of tuples of length 1 - @convert.register( - pd.Series, - (sa.sql.Select, sa.sql.Selectable), - cost=df_cost - 1, - ) - def select_or_selectable_to_series(el, bind=None, dshape=None, **kwargs): - bind = _getbind(el, bind) - - if istabular(dshape) or bind.dialect.name != 'postgresql': - # fall back to the general edge - raise NotImplementedError() - - return to_dataframe(el, bind=bind).iloc[:, 0] +__all__ = [ + 'to_arrays', 'to_dataframe', 'register_odo_dataframe_edge', +] diff --git a/warp_prism/odo.py b/warp_prism/odo.py new file mode 100644 index 0000000..3704a00 --- /dev/null +++ b/warp_prism/odo.py @@ -0,0 +1,47 @@ +def register_odo_dataframe_edge(): + """Register an odo edge for sqlalchemy selectable objects to dataframe. + + This edge will have a lower cost that the default edge so it will be + selected as the fasted path. + + If the selectable is not in a postgres database, it will fallback to the + default odo edge. + """ + from datashape.predicates import istabular + from odo import convert + import pandas as pd + import sqlalchemy as sa + + # estimating 8 times faster + df_cost = convert.graph.edge[sa.sql.Select][pd.DataFrame]['cost'] / 8 + + @convert.register( + pd.DataFrame, + (sa.sql.Select, sa.sql.Selectable), + cost=df_cost, + ) + def select_or_selectable_to_frame(el, bind=None, dshape=None, **kwargs): + bind = _getbind(el, bind) + + if bind.dialect.name != 'postgresql': + # fall back to the general edge + raise NotImplementedError() + + return to_dataframe(el, bind=bind) + + # higher priority than df edge so that + # ``odo('select one_column from ...', list)`` returns a list of scalars + # instead of a list of tuples of length 1 + @convert.register( + pd.Series, + (sa.sql.Select, sa.sql.Selectable), + cost=df_cost - 1, + ) + def select_or_selectable_to_series(el, bind=None, dshape=None, **kwargs): + bind = _getbind(el, bind) + + if istabular(dshape) or bind.dialect.name != 'postgresql': + # fall back to the general edge + raise NotImplementedError() + + return to_dataframe(el, bind=bind).iloc[:, 0] diff --git a/warp_prism/query.py b/warp_prism/query.py new file mode 100644 index 0000000..ebba412 --- /dev/null +++ b/warp_prism/query.py @@ -0,0 +1,125 @@ +from functools import wraps +import io + +try: + import pandas as pd +except ImportError: + pd = None +import numpy as np + +from .sql import getbind, mogrify +from .types import query_typeids +from ._warp_prism import raw_to_arrays as _raw_to_arrays + + +def to_arrays(query, params=None, *, bind=None): + """Run the query returning a the results as np.ndarrays. + + Parameters + ---------- + query : str or sa.sql.Selectable + The query to run. This can be a select or a table. + params : dict or tuple or None + Bind parameters for ``query``. 
+ bind : psycopg2.connection, sa.Engine, or sa.Connection, optional + The engine used to create the connection. If not provided + ``query.bind`` will be used. + + Returns + ------- + arrays : dict[str, (np.ndarray, np.ndarray)] + A map from column name to the result arrays. The first array holds the + values and the second array is a boolean mask for NULLs. The values + where the mask is False are 0 interpreted by the type. + """ + + buf = io.BytesIO() + bind = getbind(query, bind) + + with bind.cursor() as cur: + bound_query = mogrify(cur, query, params) + column_names, typeids = query_typeids(cur, bound_query) + cur.copy_expert('copy (%s) to stdout binary' % bound_query, buf) + + out = _raw_to_arrays(buf.getbuffer(), typeids) + + return {column_names[n]: v for n, v in enumerate(out)} + + +null_values = {np.dtype(k): v for k, v in { + 'float32': np.nan, + 'float64': np.nan, + 'int16': np.nan, + 'int32': np.nan, + 'int64': np.nan, + 'bool': np.nan, + 'datetime64[ns]': np.datetime64('nat', 'ns'), + 'object': None, +}.items()} + +# alias because ``to_dataframe`` shadows this name +_default_null_values_for_type = null_values + + +def to_dataframe(query, params=None, *, bind=None, null_values=None): + """Run the query returning a the results as a pd.DataFrame. + + Parameters + ---------- + query : str or sa.sql.Selectable + The query to run. This can be a select or a table. + params : dict or tuple or None + Bind parameters for ``query``. + bind : psycopg2.connection, sa.Engine, or sa.Connection, optional + The engine used to create the connection. If not provided + ``query.bind`` will be used. + null_values : dict[str, any] + The null values to use for each column. This falls back to + ``warp_prism.null_values`` for columns that are not specified. + + Returns + ------- + df : pd.DataFrame + A pandas DataFrame holding the results of the query. The columns + of the DataFrame will be named the same and be in the same order as the + query. + """ + arrays = to_arrays(query, bind=bind) + + if null_values is None: + null_values = {} + + for name, (array, mask) in arrays.items(): + if array.dtype.kind == 'i': + if not mask.all(): + try: + null = null_values[name] + except KeyError: + # no explicit override, cast to float and use NaN as null + array = array.astype('float64') + null = np.nan + + array[~mask] = null + + arrays[name] = array + continue + + if array.dtype.kind == 'M': + # pandas needs datetime64[ns], not ``us`` or ``D`` + array = array.astype('datetime64[ns]') + + try: + null = null_values[name] + except KeyError: + null = _default_null_values_for_type[array.dtype] + + array[~mask] = null + arrays[name] = array + + return pd.DataFrame(arrays) + + +if pd is None: + @wraps(to_dataframe) + def to_dataframe(*args, **kwargs): + raise NotImplementedError('to_dataframe requires pandas') diff --git a/warp_prism/sa.py b/warp_prism/sa.py new file mode 100644 index 0000000..55b5fa3 --- /dev/null +++ b/warp_prism/sa.py @@ -0,0 +1,12 @@ +def literal_compile(s): + """Compile a sql expression with bind params inlined as literals. + Parameters + ---------- + s : Selectable + The expression to compile. + Returns + ------- + cs : str + An equivalent sql string. 
+ """ + return str(s.compile(compile_kwargs={'literal_binds': True})) diff --git a/warp_prism/sql.py b/warp_prism/sql.py new file mode 100644 index 0000000..4102345 --- /dev/null +++ b/warp_prism/sql.py @@ -0,0 +1,60 @@ +import psycopg2 +try: + import sqlalchemy as sa +except ImportError: + sa = None + + +def _sa_literal_compile(s): + """Compile a sql expression with variables inlined as literals. + + Parameters + ---------- + s : sa.sql.Selectable + The expression to compile. + + Returns + ------- + cs : str + An equivalent sql string. + """ + return str(s.compile(compile_kwargs={'literal_binds': True})) + + +def mogrify(cursor, query, params): + if sa is not None: + if isinstance(query, sa.Table): + query = _sa_literal_compile(sa.select(query.c)) + elif isinstance(query, sa.sql.Selectable): + query = _sa_literal_compile(query) + + return cursor.mogrify(query, params).decode('utf-8') + + +def getbind(query, bind): + """Get the connection to use for a query. + + Parameters + ---------- + query : str or sa.sql.Selectable + The query to run. + bind : psycopg2.extensions.connection or sa.engine.base.Engine or None + The explicitly provided bind. + + Returns + ------- + bind : psycopg2.extensions.connection + The connection to use for the query. + """ + if bind is not None: + if sa is None or isinstance(bind, psycopg2.extensions.connection): + return bind + + if isinstance(bind, sa.engine.base.Engine): + return bind.connect().connection.connection + + return sa.create_engine(bind).connect().connection.connection + elif sa is None or not isinstance(query, sa.sql.Selectable): + raise TypeError("missing 1 required argument: 'bind'") + else: + return query.bind.connect().connection.connection diff --git a/warp_prism/tests/__init__.py b/warp_prism/tests/__init__.py index 26b3c26..06e0911 100644 --- a/warp_prism/tests/__init__.py +++ b/warp_prism/tests/__init__.py @@ -2,40 +2,12 @@ from uuid import uuid4 import warnings -from odo import resource -import sqlalchemy as sa +import psycopg2 -def _dropdb(root_conn, db_name): - root_conn.execute('COMMIT') - root_conn.execute('DROP DATABASE %s' % db_name) - - -@contextmanager -def disposable_engine(uri): - """An engine which is disposed on exit. - - Parameters - ---------- - uri : str - The uri to the db. - - Yields - ------ - engine : sa.engine.Engine - """ - engine = resource(uri) - try: - yield engine - finally: - engine.dispose() - - -_pg_stat_activity = sa.Table( - 'pg_stat_activity', - sa.MetaData(), - sa.Column('pid', sa.Integer), -) +def _dropdb(cur, db_name): + cur.execute('COMMIT') + cur.execute('DROP DATABASE %s' % db_name) @contextmanager @@ -45,35 +17,26 @@ def tmp_db_uri(): db_name = '_warp_prism_test_' + uuid4().hex root = 'postgresql://localhost/' uri = root + db_name - with disposable_engine(root + 'postgres') as e, e.connect() as root_conn: - root_conn.execute('COMMIT') - root_conn.execute('CREATE DATABASE %s' % db_name) + with psycopg2.connect(root + 'postgres') as conn, conn.cursor() as cur: + cur.execute('COMMIT') + cur.execute('CREATE DATABASE %s' % db_name) try: yield uri finally: - resource(uri).dispose() try: - _dropdb(root_conn, db_name) - except sa.exc.OperationalError: - # We couldn't drop the db. The most likely cause is that there - # are active queries. Even more likely is that these are - # rollbacks because there was an exception somewhere inside the - # tests. We will cancel all the running queries and try to drop - # the database again. 
- pid = _pg_stat_activity.c.pid - root_conn.execute( - sa.select( - (sa.func.pg_terminate_backend(pid),), - ).where( - pid != sa.func.pg_backend_pid(), - ) + cur.execute(""" + select + pg_terminate_backend(pid) + from + pg_stat_activity + where + pid != pg_backend_pid() + """) + _dropdb(cur, db_name) + except: # pragma: no cover # noqa + # The database wasn't cleaned up. Just tell the user to deal + # with this manually. + warnings.warn( + "leaking database '%s', please manually delete this" % + db_name, ) - try: - _dropdb(root_conn, db_name) - except sa.exc.OperationalError: # pragma: no cover - # The database STILL wasn't cleaned up. Just tell the user - # to deal with this manually. - warnings.warn( - "leaking database '%s', please manually delete this" % - db_name, - ) diff --git a/warp_prism/tests/test_warp_prism.py b/warp_prism/tests/test_warp_prism.py index 9ec9532..e67e4e8 100644 --- a/warp_prism/tests/test_warp_prism.py +++ b/warp_prism/tests/test_warp_prism.py @@ -2,24 +2,19 @@ import struct from uuid import uuid4 -from datashape import var, R, Option, dshape import numpy as np -from odo import resource, odo import pandas as pd +import psycopg2 import pytest -import sqlalchemy as sa from warp_prism._warp_prism import ( postgres_signature, raw_to_arrays, test_overflow_operations as _test_overflow_operations, ) -from warp_prism import ( - to_arrays, - to_dataframe, - null_values as null_values_for_type, - _typeid_map, -) +from warp_prism import to_arrays, to_dataframe +from warp_prism.query import null_values as null_values_for_type +from warp_prism.types import dtype_to_typeid from warp_prism.tests import tmp_db_uri as tmp_db_uri_ctx @@ -34,6 +29,25 @@ def tmp_table_uri(tmp_db_uri): return '%s::%s%s' % (tmp_db_uri, 'table_', uuid4().hex) +def item(a): + """Convert a value to a Python built in type (not a numpy type). + + Parameters + ---------- + a : any + The value to convert. + + Returns + ------- + item : any + The base Python type equivalent value. + """ + try: + return a.item() + except AttributeError: + return a + + def check_roundtrip_nonnull(table_uri, data, dtype, sqltype): """Check the data roundtrip through postgres using warp_prism to read the data @@ -41,37 +55,41 @@ def check_roundtrip_nonnull(table_uri, data, dtype, sqltype): Parameters ---------- table_uri : str - The uri to a unique table. + The uri for the table. data : np.array The input data. dtype : str The dtype of the data. - sqltype : type - The sqlalchemy type of the data. + sqltype : str + The sql type of the data. """ - input_dataframe = pd.DataFrame({'a': data}) - table = odo(input_dataframe, table_uri, dshape=var * R['a': dtype]) - # Ensure that odo created the table correctly. If these fail the other - # tests are not well defined. 
- assert table.columns.keys() == ['a'] - assert isinstance(table.columns['a'].type, sqltype) - - arrays = to_arrays(table) + db, table = table_uri.split('::') + with psycopg2.connect(db) as conn, conn.cursor() as cur: + cur.execute('create table %s (a %s)' % (table, sqltype)) + cur.executemany( + 'insert into {} values (%s)'.format(table), + [(item(v),) for v in data], + ) + + query = 'select * from %s' % table + arrays = to_arrays(query, bind=conn) + output_dataframe = to_dataframe(query, bind=conn) + assert len(arrays) == 1 array, mask = arrays['a'] assert (array == data).all() assert mask.all() - output_dataframe = to_dataframe(table) - pd.util.testing.assert_frame_equal(output_dataframe, input_dataframe) + expected_dataframe = pd.DataFrame({'a': data}) + pd.util.testing.assert_frame_equal(output_dataframe, expected_dataframe) @pytest.mark.parametrize('dtype,sqltype,start,stop,step', ( - ('int16', sa.SmallInteger, 0, 5000, 1), - ('int32', sa.Integer, 0, 5000, 1), - ('int64', sa.BigInteger, 0, 5000, 1), - ('float32', sa.REAL, 0, 2500, 0.5), - ('float64', sa.FLOAT, 0, 2500, 0.5), + ('int16', 'int2', 0, 5000, 1), + ('int32', 'int4', 0, 5000, 1), + ('int64', 'int8', 0, 5000, 1), + ('float32', 'float4', 0, 2500, 0.5), + ('float64', 'float8', 0, 2500, 0.5), )) def test_numeric_type_nonnull(tmp_table_uri, dtype, @@ -85,12 +103,12 @@ def test_numeric_type_nonnull(tmp_table_uri, def test_bool_type_nonnull(tmp_table_uri): data = np.array([True] * 2500 + [False] * 2500, dtype=bool) - check_roundtrip_nonnull(tmp_table_uri, data, 'bool', sa.Boolean) + check_roundtrip_nonnull(tmp_table_uri, data, 'bool', 'bool') def test_string_type_nonnull(tmp_table_uri): data = np.array(list(ascii_letters) * 200, dtype='object') - check_roundtrip_nonnull(tmp_table_uri, data, 'string', sa.String) + check_roundtrip_nonnull(tmp_table_uri, data, 'object', 'text') def test_datetime_type_nonnull(tmp_table_uri): @@ -98,7 +116,7 @@ def test_datetime_type_nonnull(tmp_table_uri): '2000', '2016', ).values.astype('datetime64[us]') - check_roundtrip_nonnull(tmp_table_uri, data, 'datetime', sa.DateTime) + check_roundtrip_nonnull(tmp_table_uri, data, 'datetime64[us]', 'timestamp') def test_date_type_nonnull(tmp_table_uri): @@ -106,7 +124,7 @@ def test_date_type_nonnull(tmp_table_uri): '2000', '2016', ).values.astype('datetime64[D]') - check_roundtrip_nonnull(tmp_table_uri, data, 'date', sa.Date) + check_roundtrip_nonnull(tmp_table_uri, data, 'datetime64[D]', 'date') def check_roundtrip_null_values(table_uri, @@ -128,30 +146,37 @@ def check_roundtrip_null_values(table_uri, The input data. dtype : str The dtype of the data. - sqltype : type - The sqlalchemy type of the data. + sqltype : str + The sql type of the data. null_values : dict[str, any] The value to coerce ``NULL`` to. astype : bool, optional Coerce the input data to the given dtype before making assertions about the output data. """ - table = resource(table_uri, dshape=var * R['a': Option(dtype)]) - # Ensure that odo created the table correctly. If these fail the other - # tests are not well defined. 
- assert table.columns.keys() == ['a'] - assert isinstance(table.columns['a'].type, sqltype) - table.insert().values([{'a': v} for v in data]).execute() - - arrays = to_arrays(table) + db, table = table_uri.split('::') + with psycopg2.connect(db) as conn, conn.cursor() as cur: + cur.execute('create table %s (a %s)' % (table, sqltype)) + cur.executemany( + 'insert into {} values (%s)'.format(table), + [(item(v),) for v in data], + ) + + query = 'select * from %s' % table + arrays = to_arrays(query, bind=conn) + output_dataframe = to_dataframe( + query, + null_values=null_values, + bind=conn, + ) + assert len(arrays) == 1 array, actual_mask = arrays['a'] assert (actual_mask == mask).all() assert (array[mask] == data[mask]).all() - output_dataframe = to_dataframe(table, null_values=null_values) if astype: - data = data.astype(dshape(dtype).measure.to_numpy_dtype()) + data = data.astype(dtype, copy=False) expected_dataframe = pd.DataFrame({'a': data}) expected_dataframe[~mask] = null_values.get( 'a', @@ -187,8 +212,8 @@ def check_roundtrip_null(table_uri, The input data. dtype : str The dtype of the data. - sqltype : type - The sqlalchemy type of the data. + sqltype : str + The sql type of the data. null : any The value to coerce ``NULL`` to. astype : bool, optional @@ -207,11 +232,11 @@ def check_roundtrip_null(table_uri, @pytest.mark.parametrize('dtype,sqltype,start,stop,step,null', ( - ('int16', sa.SmallInteger, 0, 5000, 1, -1), - ('int32', sa.Integer, 0, 5000, 1, -1), - ('int64', sa.BigInteger, 0, 5000, 1, -1), - ('float32', sa.REAL, 0, 2500, 0.5, -1.0), - ('float64', sa.FLOAT, 0, 2500, 0.5, -1.0), + ('int16', 'int2', 0, 5000, 1, -1), + ('int32', 'int4', 0, 5000, 1, -1), + ('int64', 'int8', 0, 5000, 1, -1), + ('float32', 'float4', 0, 2500, 0.5, -1.0), + ('float64', 'float8', 0, 2500, 0.5, -1.0), )) def test_numeric_type_null(tmp_table_uri, dtype, @@ -227,9 +252,9 @@ def test_numeric_type_null(tmp_table_uri, @pytest.mark.parametrize('dtype,sqltype', ( - ('int16', sa.SmallInteger), - ('int32', sa.Integer), - ('int64', sa.BigInteger), + ('int16', 'int2'), + ('int32', 'int4'), + ('int64', 'int8'), )) def test_numeric_default_null_promote(tmp_table_uri, dtype, sqltype): data = np.arange(0, 100, dtype=dtype).astype(object) @@ -242,7 +267,7 @@ def test_bool_type_null(tmp_table_uri): data = np.array([True] * 2500 + [False] * 2500, dtype=bool).astype(object) mask = np.tile(np.array([True, False]), len(data) // 2) data[~mask] = None - check_roundtrip_null(tmp_table_uri, data, 'bool', sa.Boolean, False, mask) + check_roundtrip_null(tmp_table_uri, data, 'bool', 'bool', False, mask) def test_string_type_null(tmp_table_uri): @@ -252,8 +277,8 @@ def test_string_type_null(tmp_table_uri): check_roundtrip_null( tmp_table_uri, data, - 'string', - sa.String, + 'object', + 'text', 'ayy lmao', mask, ) @@ -272,8 +297,8 @@ def test_datetime_type_null(tmp_table_uri): check_roundtrip_null( tmp_table_uri, data, - 'datetime', - sa.DateTime, + 'datetime64[us]', + 'timestamp', pd.Timestamp('1995-12-13').to_datetime64(), mask, ) @@ -290,9 +315,9 @@ def test_date_type_null(tmp_table_uri): check_roundtrip_null( tmp_table_uri, data, + 'datetime64[D]', 'date', - sa.Date, - pd.Timestamp('1995-12-13').to_datetime64(), + np.datetime64('1995-12-13', 'ns'), mask, astype=True, ) @@ -341,7 +366,7 @@ def test_invalid_numeric_size(dtype): ) with pytest.raises(ValueError) as e: - raw_to_arrays(input_data, (_typeid_map[dtype],)) + raw_to_arrays(input_data, (dtype_to_typeid(dtype),)) assert str(e.value) == 'mismatched %s size: %s' % ( 
dtype.name, @@ -363,7 +388,7 @@ def test_invalid_datetime_size(): dtype = np.dtype('datetime64[us]') with pytest.raises(ValueError) as e: - raw_to_arrays(input_data, (_typeid_map[dtype],)) + raw_to_arrays(input_data, (dtype_to_typeid(dtype),)) assert str(e.value) == 'mismatched datetime size: 7' @@ -377,7 +402,7 @@ def test_invalid_date_size(): dtype = np.dtype('datetime64[D]') with pytest.raises(ValueError) as e: - raw_to_arrays(input_data, (_typeid_map[dtype],)) + raw_to_arrays(input_data, (dtype_to_typeid(dtype),)) assert str(e.value) == 'mismatched date size: 3' @@ -405,7 +430,7 @@ def test_invalid_text(): # we put the invalid unicode as the first column to test that we can clean # up the cell in the second column before we have written a string there - str_typeid = _typeid_map[np.dtype(object)] + str_typeid = dtype_to_typeid(np.dtype(object)) with pytest.raises(UnicodeDecodeError): raw_to_arrays(input_data, (str_typeid, str_typeid)) diff --git a/warp_prism/types.py b/warp_prism/types.py new file mode 100644 index 0000000..9f99d6c --- /dev/null +++ b/warp_prism/types.py @@ -0,0 +1,128 @@ +import numpy as np + +from ._warp_prism import typeid_map as _raw_typeid_map + +_typeid_map = {np.dtype(k): v for k, v in _raw_typeid_map.items()} + + +def dtype_to_typeid(dtype): + """Convert a numpy dtype to a warp_prism type id. + + Parameters + ---------- + dtype : np.dtype + The numpy dtype to convert. + + Returns + ------- + typeid : int + The type id for ``dtype``. + """ + try: + return _typeid_map[dtype] + except KeyError: + raise ValueError('no warp_prism type id for dtype %s' % dtype) + + +_oid_map = { + 16: np.dtype('?'), + + # text + 17: np.dtype('O'), + 18: np.dtype('S1'), + 19: np.dtype('O'), + 25: np.dtype('O'), + + # int + 20: np.dtype('i8'), + 21: np.dtype('i2'), + 23: np.dtype('i4'), + 1042: np.dtype('O'), + 1043: np.dtype('O'), + + # float + 700: np.dtype('f4'), + 701: np.dtype('f8'), + + # date(time) + 1082: np.dtype('M8[D]'), + 1114: np.dtype('M8[us]'), + 1184: np.dtype('M8[us]'), +} + + +def oid_to_dtype(oid): + """Get a numpy dtype from postgres oid. + + Parameters + ---------- + oid : int + The oid to convert. + + Returns + ------- + dtype : np.dtype + The corresponding numpy dtype. + """ + try: + return _oid_map[oid] + except KeyError: + raise ValueError('cannot convert oid %s to numpy dtype' % oid) + + +def query_dtypes(cursor, bound_query): + """Get the numpy dtypes for each column returned by a query. + + Parameters + ---------- + cursor : psycopg2.cursor + The psycopg2 cursor to use to get the type information. + bound_query : str + The query to check the types of with all parameters bound. + + Returns + ------- + names : tuple[str] + The column names. + dtypes : tuple[np.dtype] + The column dtypes. + """ + cursor.execute('select * from (%s) a limit 0' % bound_query) + invalid = [] + names = [] + dtypes = [] + for c in cursor.description: + try: + dtypes.append(oid_to_dtype(c.type_code)) + except ValueError: + invalid.append(c) + else: + names.append(c.name) + + if invalid: + raise ValueError( + 'columns cannot be converted to numpy dtype: %s' % invalid + ) + + return tuple(names), tuple(dtypes) + + +def query_typeids(cursor, bound_query): + """Get the warp_prism typeid for each column returned by a query. + + Parameters + ---------- + cursor : psycopg2.cursor + The psycopg2 cursor to use to get the type information. + bound_query : str + The query to check the types of with all parameters bound. + + Returns + ------- + names : tuple[str] + The column names. 
+ typeids : tuple[int] + The warp_prism typeid for each column.. + """ + names, dtypes = query_dtypes(cursor, bound_query) + return names, tuple(dtype_to_typeid(dtype) for dtype in dtypes) From f856748bba78b64d349742d606339ce4195e771d Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Tue, 4 Feb 2020 12:12:19 -0500 Subject: [PATCH 3/7] TST: make pandas tests optional --- .travis.yml | 7 +++ warp_prism/tests/test_warp_prism.py | 90 +++++++++++++++++------------ 2 files changed, 59 insertions(+), 38 deletions(-) diff --git a/.travis.yml b/.travis.yml index 6b100df..b362570 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,6 +4,9 @@ python: - "3.4" - "3.5" - "3.6" +env: + - PANDAS=0 + - PANDAS=1 env: - CC=gcc CXX=g++ @@ -21,6 +24,10 @@ addons: install: - ${CC} --version - pip install numpy + - if [ $PANDAS == "1" ]; + pip install pandas; + python -c "import pandas;print(pandas.__version__)"; + fi - python -c "import numpy;print(numpy.__version__)" - pip install -e .[dev] diff --git a/warp_prism/tests/test_warp_prism.py b/warp_prism/tests/test_warp_prism.py index e67e4e8..12505de 100644 --- a/warp_prism/tests/test_warp_prism.py +++ b/warp_prism/tests/test_warp_prism.py @@ -3,7 +3,6 @@ from uuid import uuid4 import numpy as np -import pandas as pd import psycopg2 import pytest @@ -17,6 +16,11 @@ from warp_prism.types import dtype_to_typeid from warp_prism.tests import tmp_db_uri as tmp_db_uri_ctx +try: + import pandas as pd +except ImportError: + pd = None + @pytest.fixture(scope='module') def tmp_db_uri(): @@ -73,15 +77,21 @@ def check_roundtrip_nonnull(table_uri, data, dtype, sqltype): query = 'select * from %s' % table arrays = to_arrays(query, bind=conn) - output_dataframe = to_dataframe(query, bind=conn) + + if pd is not None: + output_dataframe = to_dataframe(query, bind=conn) assert len(arrays) == 1 array, mask = arrays['a'] assert (array == data).all() assert mask.all() - expected_dataframe = pd.DataFrame({'a': data}) - pd.util.testing.assert_frame_equal(output_dataframe, expected_dataframe) + if pd is not None: + expected_dataframe = pd.DataFrame({'a': data}) + pd.util.testing.assert_frame_equal( + output_dataframe, + expected_dataframe, + ) @pytest.mark.parametrize('dtype,sqltype,start,stop,step', ( @@ -112,17 +122,19 @@ def test_string_type_nonnull(tmp_table_uri): def test_datetime_type_nonnull(tmp_table_uri): - data = pd.date_range( + data = np.arange( '2000', '2016', - ).values.astype('datetime64[us]') + dtype='M8[D]', + ).astype('datetime64[us]') check_roundtrip_nonnull(tmp_table_uri, data, 'datetime64[us]', 'timestamp') def test_date_type_nonnull(tmp_table_uri): - data = pd.date_range( + data = np.arange( '2000', '2016', + dtype='M8[D]', ).values.astype('datetime64[D]') check_roundtrip_nonnull(tmp_table_uri, data, 'datetime64[D]', 'date') @@ -164,33 +176,37 @@ def check_roundtrip_null_values(table_uri, query = 'select * from %s' % table arrays = to_arrays(query, bind=conn) - output_dataframe = to_dataframe( - query, - null_values=null_values, - bind=conn, - ) + if pd is not None: + output_dataframe = to_dataframe( + query, + null_values=null_values, + bind=conn, + ) assert len(arrays) == 1 array, actual_mask = arrays['a'] assert (actual_mask == mask).all() + assert (array[mask] == data[mask]).all() - if astype: - data = data.astype(dtype, copy=False) - expected_dataframe = pd.DataFrame({'a': data}) - expected_dataframe[~mask] = null_values.get( - 'a', - null_values_for_type[ - array.dtype - if array.dtype.kind != 'M' else - np.dtype('datetime64[ns]') - ], - ) - 
pd.util.testing.assert_frame_equal( - output_dataframe, - expected_dataframe, - check_dtype=False, - ) + if pd is not None: + if astype: + data = data.astype(dtype, copy=False) + + expected_dataframe = pd.DataFrame({'a': data}) + expected_dataframe[~mask] = null_values.get( + 'a', + null_values_for_type[ + array.dtype + if array.dtype.kind != 'M' else + np.dtype('datetime64[ns]') + ], + ) + pd.util.testing.assert_frame_equal( + output_dataframe, + expected_dataframe, + check_dtype=False, + ) def check_roundtrip_null(table_uri, @@ -285,13 +301,12 @@ def test_string_type_null(tmp_table_uri): def test_datetime_type_null(tmp_table_uri): - data = np.array( - list(pd.date_range( - '2000', - '2016', - )), - dtype=object, - )[:-1] # slice the last element off to have an even number + data = np.arange( + '2000', + '2016', + dtype='M8[D]', + ).astype('O')[:-1] # slice the last element off to have an even number + mask = np.tile(np.array([True, False]), len(data) // 2) data[~mask] = None check_roundtrip_null( @@ -299,7 +314,7 @@ def test_datetime_type_null(tmp_table_uri): data, 'datetime64[us]', 'timestamp', - pd.Timestamp('1995-12-13').to_datetime64(), + np.datetime64('1995-12-13', 'ns'), mask, ) @@ -382,8 +397,7 @@ def test_invalid_datetime_size(): input_data = _pack_as_invalid_size_postgres_binary_data( 'q', # int64_t (quadword) 8, - (pd.Timestamp('2014-01-01').to_datetime64().astype('datetime64[us]') + - _epoch_offset).view('int64'), + (np.datetime64('2014-01-01', 'us') + _epoch_offset).view('int64'), ) dtype = np.dtype('datetime64[us]') From 61b65ad31500479e4c64d03878c580b3ab55f146 Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Tue, 4 Feb 2020 13:15:50 -0500 Subject: [PATCH 4/7] TST: test sqlalchemy --- .travis.yml | 11 +- warp_prism/tests/test_warp_prism.py | 188 ++++++++++++++++++++++------ 2 files changed, 154 insertions(+), 45 deletions(-) diff --git a/.travis.yml b/.travis.yml index b362570..e0fd184 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,9 +4,6 @@ python: - "3.4" - "3.5" - "3.6" -env: - - PANDAS=0 - - PANDAS=1 env: - CC=gcc CXX=g++ @@ -23,12 +20,10 @@ addons: install: - ${CC} --version - - pip install numpy - - if [ $PANDAS == "1" ]; - pip install pandas; - python -c "import pandas;print(pandas.__version__)"; - fi + - pip install numpy pandas sqlalchemy - python -c "import numpy;print(numpy.__version__)" + - python -c "import pandas;print(pandas.__version__)"; + - python -c "import sqlalchemy;print(sqlalchemy.__version__)"; - pip install -e .[dev] script: diff --git a/warp_prism/tests/test_warp_prism.py b/warp_prism/tests/test_warp_prism.py index 12505de..653c55c 100644 --- a/warp_prism/tests/test_warp_prism.py +++ b/warp_prism/tests/test_warp_prism.py @@ -21,6 +21,14 @@ except ImportError: pd = None +try: + import sqlalchemy as sa + + use_sqlalchemy = pytest.mark.parametrize('use_sqlalchemy', [False, True]) +except ImportError: + sa = None + use_sqlalchemy = pytest.mark.parametrize('use_sqlalchemy', [False]) + @pytest.fixture(scope='module') def tmp_db_uri(): @@ -52,7 +60,7 @@ def item(a): return a -def check_roundtrip_nonnull(table_uri, data, dtype, sqltype): +def check_roundtrip_nonnull(table_uri, data, dtype, sqltype, use_sqlalchemy): """Check the data roundtrip through postgres using warp_prism to read the data @@ -66,17 +74,29 @@ def check_roundtrip_nonnull(table_uri, data, dtype, sqltype): The dtype of the data. sqltype : str The sql type of the data. + use_sqlalchemy : bool + Use sqlalchemy for the query instead of psycopg2. 
""" db, table = table_uri.split('::') - with psycopg2.connect(db) as conn, conn.cursor() as cur: - cur.execute('create table %s (a %s)' % (table, sqltype)) - cur.executemany( - 'insert into {} values (%s)'.format(table), - [(item(v),) for v in data], - ) + with psycopg2.connect(db) as conn: + with conn.cursor() as cur: + cur.execute('create table %s (a %s)' % (table, sqltype)) + cur.executemany( + 'insert into {} values (%s)'.format(table), + [(item(v),) for v in data], + ) + cur.execute('commit') - query = 'select * from %s' % table - arrays = to_arrays(query, bind=conn) + if use_sqlalchemy: + bind = sa.create_engine(db) + meta = sa.MetaData(bind) + t = sa.Table(table, meta, autoload=True) + query = sa.select(t.c) + else: + bind = conn + query = 'select * from %s' % table + + arrays = to_arrays(query, bind=bind) if pd is not None: output_dataframe = to_dataframe(query, bind=conn) @@ -94,6 +114,7 @@ def check_roundtrip_nonnull(table_uri, data, dtype, sqltype): ) +@use_sqlalchemy @pytest.mark.parametrize('dtype,sqltype,start,stop,step', ( ('int16', 'int2', 0, 5000, 1), ('int32', 'int4', 0, 5000, 1), @@ -106,37 +127,72 @@ def test_numeric_type_nonnull(tmp_table_uri, sqltype, start, stop, - step): + step, + use_sqlalchemy): data = np.arange(start, stop, step, dtype=dtype) - check_roundtrip_nonnull(tmp_table_uri, data, dtype, sqltype) + check_roundtrip_nonnull( + tmp_table_uri, + data, + dtype, + sqltype, + use_sqlalchemy, + ) -def test_bool_type_nonnull(tmp_table_uri): +@use_sqlalchemy +def test_bool_type_nonnull(tmp_table_uri, use_sqlalchemy): data = np.array([True] * 2500 + [False] * 2500, dtype=bool) - check_roundtrip_nonnull(tmp_table_uri, data, 'bool', 'bool') + check_roundtrip_nonnull( + tmp_table_uri, + data, + 'bool', + 'bool', + use_sqlalchemy, + ) -def test_string_type_nonnull(tmp_table_uri): +@use_sqlalchemy +def test_string_type_nonnull(tmp_table_uri, use_sqlalchemy): data = np.array(list(ascii_letters) * 200, dtype='object') - check_roundtrip_nonnull(tmp_table_uri, data, 'object', 'text') + check_roundtrip_nonnull( + tmp_table_uri, + data, + 'object', + 'text', + use_sqlalchemy, + ) -def test_datetime_type_nonnull(tmp_table_uri): +@use_sqlalchemy +def test_datetime_type_nonnull(tmp_table_uri, use_sqlalchemy): data = np.arange( '2000', '2016', dtype='M8[D]', ).astype('datetime64[us]') - check_roundtrip_nonnull(tmp_table_uri, data, 'datetime64[us]', 'timestamp') + check_roundtrip_nonnull( + tmp_table_uri, + data, + 'datetime64[us]', + 'timestamp', + use_sqlalchemy, + ) -def test_date_type_nonnull(tmp_table_uri): +@use_sqlalchemy +def test_date_type_nonnull(tmp_table_uri, use_sqlalchemy): data = np.arange( '2000', '2016', dtype='M8[D]', - ).values.astype('datetime64[D]') - check_roundtrip_nonnull(tmp_table_uri, data, 'datetime64[D]', 'date') + ).astype('datetime64[D]') + check_roundtrip_nonnull( + tmp_table_uri, + data, + 'datetime64[D]', + 'date', + use_sqlalchemy, + ) def check_roundtrip_null_values(table_uri, @@ -145,6 +201,7 @@ def check_roundtrip_null_values(table_uri, sqltype, null_values, mask, + use_sqlalchemy, *, astype=False): """Check the data roundtrip through postgres using warp_prism to read the @@ -162,19 +219,33 @@ def check_roundtrip_null_values(table_uri, The sql type of the data. null_values : dict[str, any] The value to coerce ``NULL`` to. + mask : np.ndarray[bool] + A mask indicating which values are non-null. + use_sqlalchemy : bool + Use sqlalchemy for the query instead of psycopg2. 
astype : bool, optional Coerce the input data to the given dtype before making assertions about the output data. """ db, table = table_uri.split('::') - with psycopg2.connect(db) as conn, conn.cursor() as cur: - cur.execute('create table %s (a %s)' % (table, sqltype)) - cur.executemany( - 'insert into {} values (%s)'.format(table), - [(item(v),) for v in data], - ) + with psycopg2.connect(db) as conn: + with conn.cursor() as cur: + cur.execute('create table %s (a %s)' % (table, sqltype)) + cur.executemany( + 'insert into {} values (%s)'.format(table), + [(item(v),) for v in data], + ) + cur.execute('commit') + + if use_sqlalchemy: + bind = sa.create_engine(db) + meta = sa.MetaData(bind) + t = sa.Table(table, meta, autoload=True) + query = sa.select(t.c) + else: + bind = conn + query = 'select * from %s' % table - query = 'select * from %s' % table arrays = to_arrays(query, bind=conn) if pd is not None: output_dataframe = to_dataframe( @@ -215,6 +286,7 @@ def check_roundtrip_null(table_uri, sqltype, null, mask, + use_sqlalchemy, *, astype=False): """Check the data roundtrip through postgres using warp_prism to read the @@ -232,6 +304,10 @@ def check_roundtrip_null(table_uri, The sql type of the data. null : any The value to coerce ``NULL`` to. + mask : np.ndarray[bool] + A mask indicating which values are non-null. + use_sqlalchemy : bool + Use sqlalchemy for the query instead of psycopg2. astype : bool, optional Coerce the input data to the given dtype before making assertions about the output data. @@ -243,10 +319,12 @@ def check_roundtrip_null(table_uri, sqltype, {'a': null}, mask, + use_sqlalchemy, astype=astype, ) +@use_sqlalchemy @pytest.mark.parametrize('dtype,sqltype,start,stop,step,null', ( ('int16', 'int2', 0, 5000, 1, -1), ('int32', 'int4', 0, 5000, 1, -1), @@ -260,33 +338,64 @@ def test_numeric_type_null(tmp_table_uri, start, stop, step, - null): + null, + use_sqlalchemy): data = np.arange(start, stop, step, dtype=dtype).astype(object) mask = np.tile(np.array([True, False]), len(data) // 2) data[~mask] = None - check_roundtrip_null(tmp_table_uri, data, dtype, sqltype, null, mask) + check_roundtrip_null( + tmp_table_uri, + data, + dtype, + sqltype, + null, + mask, + use_sqlalchemy, + ) +@use_sqlalchemy @pytest.mark.parametrize('dtype,sqltype', ( ('int16', 'int2'), ('int32', 'int4'), ('int64', 'int8'), )) -def test_numeric_default_null_promote(tmp_table_uri, dtype, sqltype): +def test_numeric_default_null_promote(tmp_table_uri, + dtype, + sqltype, + use_sqlalchemy): data = np.arange(0, 100, dtype=dtype).astype(object) mask = np.tile(np.array([True, False]), len(data) // 2) data[~mask] = None - check_roundtrip_null_values(tmp_table_uri, data, dtype, sqltype, {}, mask) + check_roundtrip_null_values( + tmp_table_uri, + data, + dtype, + sqltype, + {}, + mask, + use_sqlalchemy, + ) -def test_bool_type_null(tmp_table_uri): +@use_sqlalchemy +def test_bool_type_null(tmp_table_uri, use_sqlalchemy): data = np.array([True] * 2500 + [False] * 2500, dtype=bool).astype(object) mask = np.tile(np.array([True, False]), len(data) // 2) data[~mask] = None - check_roundtrip_null(tmp_table_uri, data, 'bool', 'bool', False, mask) + check_roundtrip_null( + tmp_table_uri, + data, + 'bool', + 'bool', + False, + mask, + use_sqlalchemy, + ) -def test_string_type_null(tmp_table_uri): +@use_sqlalchemy +def test_string_type_null(tmp_table_uri, use_sqlalchemy): data = np.array(list(ascii_letters) * 200, dtype='object') mask = np.tile(np.array([True, False]), len(data) // 2) data[~mask] = None @@ -297,15 +406,17 
@@ def test_string_type_null(tmp_table_uri): 'text', 'ayy lmao', mask, + use_sqlalchemy, ) -def test_datetime_type_null(tmp_table_uri): +@use_sqlalchemy +def test_datetime_type_null(tmp_table_uri, use_sqlalchemy): data = np.arange( '2000', '2016', dtype='M8[D]', - ).astype('O')[:-1] # slice the last element off to have an even number + ).astype('M8[us]').astype('O') mask = np.tile(np.array([True, False]), len(data) // 2) data[~mask] = None @@ -316,10 +427,12 @@ def test_datetime_type_null(tmp_table_uri): 'timestamp', np.datetime64('1995-12-13', 'ns'), mask, + use_sqlalchemy, ) -def test_date_type_null(tmp_table_uri): +@use_sqlalchemy +def test_date_type_null(tmp_table_uri, use_sqlalchemy): data = np.arange( '2000', '2016', @@ -334,6 +447,7 @@ def test_date_type_null(tmp_table_uri): 'date', np.datetime64('1995-12-13', 'ns'), mask, + use_sqlalchemy, astype=True, ) From ca26dc0bef2611b19912de2a1f4d2d2ab6c597cc Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Tue, 4 Feb 2020 13:41:32 -0500 Subject: [PATCH 5/7] BUG: fix imports --- warp_prism/odo.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/warp_prism/odo.py b/warp_prism/odo.py index 3704a00..336df31 100644 --- a/warp_prism/odo.py +++ b/warp_prism/odo.py @@ -1,3 +1,7 @@ +from .query import to_dataframe +from .sql import getbind + + def register_odo_dataframe_edge(): """Register an odo edge for sqlalchemy selectable objects to dataframe. @@ -21,7 +25,7 @@ def register_odo_dataframe_edge(): cost=df_cost, ) def select_or_selectable_to_frame(el, bind=None, dshape=None, **kwargs): - bind = _getbind(el, bind) + bind = getbind(el, bind) if bind.dialect.name != 'postgresql': # fall back to the general edge @@ -38,7 +42,7 @@ def select_or_selectable_to_frame(el, bind=None, dshape=None, **kwargs): cost=df_cost - 1, ) def select_or_selectable_to_series(el, bind=None, dshape=None, **kwargs): - bind = _getbind(el, bind) + bind = getbind(el, bind) if istabular(dshape) or bind.dialect.name != 'postgresql': # fall back to the general edge From ba0cc966eb2483d9da48464267756dd538c6584c Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Tue, 4 Feb 2020 13:46:48 -0500 Subject: [PATCH 6/7] TST: remove old python --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index e0fd184..9e22c04 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,9 +1,9 @@ language: python sudo: false python: - - "3.4" - "3.5" - "3.6" + - "3.7" env: - CC=gcc CXX=g++ From 831c5832d22be04f3256fa170e08c56aa035a08f Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Tue, 4 Feb 2020 14:54:52 -0500 Subject: [PATCH 7/7] BLD: no 3.7 for now --- .travis.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 9e22c04..2c22faa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,7 +3,6 @@ sudo: false python: - "3.5" - "3.6" - - "3.7" env: - CC=gcc CXX=g++
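
Usage sketch (not part of the patches above): after PATCH 2 the query API accepts a plain SQL
string and a psycopg2 connection, with sqlalchemy and pandas as optional extras. The database
name, table, and column names below are hypothetical; any postgres table works.

import psycopg2

from warp_prism import to_arrays, to_dataframe

# hypothetical local database and table
conn = psycopg2.connect('postgresql://localhost/example_db')

# ``query`` may be a plain SQL string (or a sqlalchemy selectable when
# sqlalchemy is installed); ``bind`` may be a psycopg2 connection, a
# sqlalchemy Engine/Connection, or a database URI.
arrays = to_arrays(
    'select sym, price from trades where venue = %s',
    ('NYSE',),  # bind parameters are mogrified into the query before the COPY
    bind=conn,
)
price_values, price_mask = arrays['price']  # value array plus boolean NULL mask

# pandas is optional; to_dataframe raises NotImplementedError when it is missing
df = to_dataframe('select sym, price from trades', bind=conn)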