[PECO-1803] Splitting the PySql connector into the core and the non core part (#417)

* Implemented ColumnQueue to test fetchall without pyarrow

Removed token

* Corrected the order of fields in row

* Changed the folder structure and verified the basic setup works

* Refactored the code to make the connector work

* Basic Setup of connector, core and sqlalchemy is working

* Basic integration of core, connect and sqlalchemy is working

* Set up working dynamic switching from ColumnQueue to ArrowQueue

* Refactored the test code and moved it to the respective folders

* Added the unit test for column_queue

Fixed __version__

Fix

* Added venv_main to .gitignore

* Added code for merging columnar table

* Merging code for columnar

* Fixed the retry_close session test issue with logging

* Fixed the databricks_sqlalchemy tests and introduced pytest.ini for the sqla_testing

* Added pyarrow_test mark on pytest

* Fixed databricks.sqlalchemy to databricks_sqlalchemy imports

* Added poetry.lock

* Added dist folder

* Changed the pyproject.toml

* Minor Fix

* Added the pyarrow skip tag on unit tests and verified they work

* Fixed the Decimal and timestamp conversion issue in the non-arrow pipeline

* Removed unneeded files and reformatted

* Fixed test_retry error

* Changed the folder structure to src / databricks

* Moved the columnar non-arrow flow to another PR

* Moved the README to the root

* Removed ColumnQueue instance

* Removed databricks_sqlalchemy dependency in core

* Changed the pysql_supports_arrow predicate, introduced changes in the pyproject.toml

* Ran the black formatter with the original version

* Removed the extra .py from all the __init__.py file names

* Undo formatting check

* Check

* Check

* Check

* Check

* Check

* Check

* Check

* Check

* Check

* Check

* Check

* Check

* Check

* Check

* BIG UPDATE

* Refactor code

* Refactor

* Fixed versioning

* Minor refactoring

* Minor refactoring
jprakash-db authored Sep 24, 2024
1 parent 9cb1ea3 commit 4099939
Showing 89 changed files with 162 additions and 4,232 deletions.
File renamed without changes.
23 changes: 23 additions & 0 deletions databricks_sql_connector/pyproject.toml
@@ -0,0 +1,23 @@
[tool.poetry]
name = "databricks-sql-connector"
version = "3.5.0"
description = "Databricks SQL Connector for Python"
authors = ["Databricks <[email protected]>"]
license = "Apache-2.0"


[tool.poetry.dependencies]
databricks_sql_connector_core = { version = ">=1.0.0", extras=["all"]}
databricks_sqlalchemy = { version = ">=1.0.0", optional = true }

[tool.poetry.extras]
databricks_sqlalchemy = ["databricks_sqlalchemy"]

[tool.poetry.urls]
"Homepage" = "https://github.com/databricks/databricks-sql-python"
"Bug Tracker" = "https://github.com/databricks/databricks-sql-python/issues"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
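
The new top-level databricks-sql-connector package is now a thin wrapper: it depends on databricks-sql-connector-core and exposes databricks_sqlalchemy only as an optional extra. As a rough illustration of what that split means for downstream code (the helper below is hypothetical and not part of this commit), an application can probe for the optional plugin at runtime:

# Hypothetical helper, not part of this commit: report whether the optional
# databricks_sqlalchemy extra was installed alongside the connector.
def sqlalchemy_plugin_available() -> bool:
    try:
        import databricks_sqlalchemy  # present only when the extra is installed
    except ImportError:
        return False
    return True


if __name__ == "__main__":
    print("databricks_sqlalchemy installed:", sqlalchemy_plugin_available())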

File renamed without changes.
24 changes: 6 additions & 18 deletions pyproject.toml → databricks_sql_connector_core/pyproject.toml
@@ -1,36 +1,26 @@
[tool.poetry]
name = "databricks-sql-connector"
version = "3.3.0"
description = "Databricks SQL Connector for Python"
name = "databricks-sql-connector-core"
version = "1.0.0"
description = "Databricks SQL Connector core for Python"
authors = ["Databricks <[email protected]>"]
license = "Apache-2.0"
readme = "README.md"
packages = [{ include = "databricks", from = "src" }]
include = ["CHANGELOG.md"]

[tool.poetry.dependencies]
python = "^3.8.0"
thrift = ">=0.16.0,<0.21.0"
pandas = [
{ version = ">=1.2.5,<2.3.0", python = ">=3.8" }
]
pyarrow = ">=14.0.1,<17"

lz4 = "^4.0.2"
requests = "^2.18.1"
oauthlib = "^3.1.0"
numpy = [
{ version = "^1.16.6", python = ">=3.8,<3.11" },
{ version = "^1.23.4", python = ">=3.11" },
]
sqlalchemy = { version = ">=2.0.21", optional = true }
openpyxl = "^3.0.10"
alembic = { version = "^1.0.11", optional = true }
urllib3 = ">=1.26"
pyarrow = {version = ">=14.0.1,<17", optional = true}

[tool.poetry.extras]
sqlalchemy = ["sqlalchemy"]
alembic = ["sqlalchemy", "alembic"]
pyarrow = ["pyarrow"]

[tool.poetry.dev-dependencies]
pytest = "^7.1.2"
@@ -43,8 +33,6 @@ pytest-dotenv = "^0.5.2"
"Homepage" = "https://github.com/databricks/databricks-sql-python"
"Bug Tracker" = "https://github.com/databricks/databricks-sql-python/issues"

[tool.poetry.plugins."sqlalchemy.dialects"]
"databricks" = "databricks.sqlalchemy:DatabricksDialect"

[build-system]
requires = ["poetry-core>=1.0.0"]
@@ -62,5 +50,5 @@ markers = {"reviewed" = "Test case has been reviewed by Databricks"}
minversion = "6.0"
log_cli = "false"
log_cli_level = "INFO"
testpaths = ["tests", "src/databricks/sqlalchemy/test_local"]
testpaths = ["tests", "databricks_sql_connector_core/tests"]
env_files = ["test.env"]
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -1,7 +1,6 @@
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence

import pandas
import pyarrow
import requests
import json
import os
@@ -43,6 +42,10 @@
TSparkParameter,
)

try:
import pyarrow
except ImportError:
pyarrow = None

logger = logging.getLogger(__name__)

@@ -977,14 +980,14 @@ def fetchmany(self, size: int) -> List[Row]:
else:
raise Error("There is no active result set")

def fetchall_arrow(self) -> pyarrow.Table:
def fetchall_arrow(self) -> "pyarrow.Table":
self._check_not_closed()
if self.active_result_set:
return self.active_result_set.fetchall_arrow()
else:
raise Error("There is no active result set")

def fetchmany_arrow(self, size) -> pyarrow.Table:
def fetchmany_arrow(self, size) -> "pyarrow.Table":
self._check_not_closed()
if self.active_result_set:
return self.active_result_set.fetchmany_arrow(size)
@@ -1171,7 +1174,7 @@ def _convert_arrow_table(self, table):
def rownumber(self):
return self._next_row_index

def fetchmany_arrow(self, size: int) -> pyarrow.Table:
def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
"""
Fetch the next set of rows of a query result, returning a PyArrow table.
@@ -1196,7 +1199,7 @@ def fetchmany_arrow(self, size: int) -> pyarrow.Table:

return results

def fetchall_arrow(self) -> pyarrow.Table:
def fetchall_arrow(self) -> "pyarrow.Table":
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
results = self.results.remaining_rows()
self._next_row_index += results.num_rows
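
The hunks above make pyarrow optional in the client module: the unconditional import is replaced by a guarded import that falls back to None, and the Arrow return annotations become string literals so they are never evaluated when pyarrow is absent. A minimal sketch of that pattern, using an illustrative class rather than the connector's real Cursor:

try:
    import pyarrow  # optional dependency of databricks-sql-connector-core
except ImportError:
    pyarrow = None


class ExampleCursor:
    """Illustrative stand-in only; the real logic lives in databricks.sql.client."""

    def fetchall_arrow(self) -> "pyarrow.Table":
        # The quoted annotation keeps this module importable without pyarrow;
        # the failure is deferred to the call site instead of import time.
        if pyarrow is None:
            raise ImportError("pyarrow is required for fetchall_arrow()")
        return pyarrow.table({"col": [1, 2, 3]})
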
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -8,7 +8,6 @@
from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
from typing import List, Union

import pyarrow
import thrift.transport.THttpClient
import thrift.protocol.TBinaryProtocol
import thrift.transport.TSocket
@@ -37,6 +36,11 @@
convert_column_based_set_to_arrow_table,
)

try:
import pyarrow
except ImportError:
pyarrow = None

logger = logging.getLogger(__name__)

unsafe_logger = logging.getLogger("databricks.sql.unsafe")
@@ -652,6 +656,12 @@ def _get_metadata_resp(self, op_handle):

@staticmethod
def _hive_schema_to_arrow_schema(t_table_schema):

if pyarrow is None:
raise ImportError(
"pyarrow is required to convert Hive schema to Arrow schema"
)

def map_type(t_type_entry):
if t_type_entry.primitiveEntry:
return {
@@ -858,7 +868,7 @@ def execute_command(
getDirectResults=ttypes.TSparkGetDirectResults(
maxRows=max_rows, maxBytes=max_bytes
),
canReadArrowResult=True,
canReadArrowResult=True if pyarrow else False,
canDecompressLZ4Result=lz4_compression,
canDownloadResult=use_cloud_fetch,
confOverlay={
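
In the Thrift backend, the same guarded import drives two behaviours shown above: the execute request only advertises Arrow support when pyarrow is actually importable (canReadArrowResult), and _hive_schema_to_arrow_schema now raises an explicit ImportError when pyarrow is missing. A hedged sketch of that idea with made-up names:

try:
    import pyarrow
except ImportError:
    pyarrow = None


def build_capabilities() -> dict:
    """Illustrative only: advertise Arrow results only if we can decode them."""
    return {"canReadArrowResult": pyarrow is not None}


def hive_schema_to_arrow_schema(column_names):
    """Sketch of the guard this commit adds before any Arrow-only conversion."""
    if pyarrow is None:
        raise ImportError("pyarrow is required to convert Hive schema to Arrow schema")
    # Map every column to a string field purely for demonstration purposes.
    return pyarrow.schema([(name, pyarrow.string()) for name in column_names])
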
File renamed without changes.
@@ -12,7 +12,6 @@
from ssl import SSLContext

import lz4.frame
import pyarrow

from databricks.sql import OperationalError, exc
from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
@@ -28,16 +27,21 @@

import logging

try:
import pyarrow
except ImportError:
pyarrow = None

logger = logging.getLogger(__name__)


class ResultSetQueue(ABC):
@abstractmethod
def next_n_rows(self, num_rows: int) -> pyarrow.Table:
def next_n_rows(self, num_rows: int):
pass

@abstractmethod
def remaining_rows(self) -> pyarrow.Table:
def remaining_rows(self):
pass


@@ -100,7 +104,7 @@ def build_queue(
class ArrowQueue(ResultSetQueue):
def __init__(
self,
arrow_table: pyarrow.Table,
arrow_table: "pyarrow.Table",
n_valid_rows: int,
start_row_index: int = 0,
):
@@ -115,7 +119,7 @@ def __init__(
self.arrow_table = arrow_table
self.n_valid_rows = n_valid_rows

def next_n_rows(self, num_rows: int) -> pyarrow.Table:
def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
"""Get upto the next n rows of the Arrow dataframe"""
length = min(num_rows, self.n_valid_rows - self.cur_row_index)
# Note that the table.slice API is not the same as Python's slice
@@ -124,7 +128,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table:
self.cur_row_index += slice.num_rows
return slice

def remaining_rows(self) -> pyarrow.Table:
def remaining_rows(self) -> "pyarrow.Table":
slice = self.arrow_table.slice(
self.cur_row_index, self.n_valid_rows - self.cur_row_index
)
@@ -184,7 +188,7 @@ def __init__(
self.table = self._create_next_table()
self.table_row_index = 0

def next_n_rows(self, num_rows: int) -> pyarrow.Table:
def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
"""
Get up to the next n rows of the cloud fetch Arrow dataframes.
@@ -216,7 +220,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table:
logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
return results

def remaining_rows(self) -> pyarrow.Table:
def remaining_rows(self) -> "pyarrow.Table":
"""
Get all remaining rows of the cloud fetch Arrow dataframes.
@@ -237,7 +241,7 @@ def remaining_rows(self) -> pyarrow.Table:
self.table_row_index = 0
return results

def _create_next_table(self) -> Union[pyarrow.Table, None]:
def _create_next_table(self) -> Union["pyarrow.Table", None]:
logger.debug(
"CloudFetchQueue: Trying to get downloaded file for row {}".format(
self.start_row_index
@@ -276,7 +280,7 @@ def _create_next_table(self) -> Union[pyarrow.Table, None]:

return arrow_table

def _create_empty_table(self) -> pyarrow.Table:
def _create_empty_table(self) -> "pyarrow.Table":
# Create a 0-row table with just the schema bytes
return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)

@@ -515,7 +519,7 @@ def transform_paramstyle(
return output


def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> pyarrow.Table:
def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> "pyarrow.Table":
arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes)
return convert_decimals_in_arrow_table(arrow_table, description)

@@ -542,7 +546,7 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema
return arrow_table, n_rows


def convert_decimals_in_arrow_table(table, description) -> pyarrow.Table:
def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
for i, col in enumerate(table.itercolumns()):
if description[i][1] == "decimal":
decimal_col = col.to_pandas().apply(
@@ -0,0 +1,6 @@
try:
from databricks_sqlalchemy import *
except:
import warnings

warnings.warn("Install databricks-sqlalchemy plugin before using this")
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -1,11 +1,20 @@
from decimal import Decimal

import pyarrow
import pytest

try:
import pyarrow
except ImportError:
pyarrow = None

class DecimalTestsMixin:
decimal_and_expected_results = [
from tests.e2e.common.predicates import pysql_supports_arrow

def decimal_and_expected_results():

if pyarrow is None:
return []

return [
("100.001 AS DECIMAL(6, 3)", Decimal("100.001"), pyarrow.decimal128(6, 3)),
("1000000.0000 AS DECIMAL(11, 4)", Decimal("1000000.0000"), pyarrow.decimal128(11, 4)),
("-10.2343 AS DECIMAL(10, 6)", Decimal("-10.234300"), pyarrow.decimal128(10, 6)),
@@ -17,7 +26,12 @@ class DecimalTestsMixin:
("1e-3 AS DECIMAL(38, 3)", Decimal("0.001"), pyarrow.decimal128(38, 3)),
]

multi_decimals_and_expected_results = [
def multi_decimals_and_expected_results():

if pyarrow is None:
return []

return [
(
["1 AS DECIMAL(6, 3)", "100.001 AS DECIMAL(6, 3)", "NULL AS DECIMAL(6, 3)"],
[Decimal("1.00"), Decimal("100.001"), None],
@@ -30,7 +44,9 @@ class DecimalTestsMixin:
),
]

@pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results)
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
class DecimalTestsMixin:
@pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results())
def test_decimals(self, decimal, expected_value, expected_type):
with self.cursor({}) as cursor:
query = "SELECT CAST ({})".format(decimal)
@@ -39,9 +55,7 @@ def test_decimals(self, decimal, expected_value, expected_type):
assert table.field(0).type == expected_type
assert table.to_pydict().popitem()[1][0] == expected_value

@pytest.mark.parametrize(
"decimals, expected_values, expected_type", multi_decimals_and_expected_results
)
@pytest.mark.parametrize("decimals, expected_values, expected_type", multi_decimals_and_expected_results())
def test_multi_decimals(self, decimals, expected_values, expected_type):
with self.cursor({}) as cursor:
union_str = " UNION ".join(["(SELECT CAST ({}))".format(dec) for dec in decimals])
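
The decimal test data moves from class attributes into functions that return an empty list when pyarrow is missing, so pytest can still collect the module, while the skipif marker on DecimalTestsMixin removes the tests themselves. A small self-contained sketch of this collection-safe parametrization (hypothetical test, not from the suite):

import pytest

try:
    import pyarrow
except ImportError:
    pyarrow = None


def arrow_cases():
    # Building the cases lazily keeps collection from touching pyarrow types
    # when the library is not installed.
    if pyarrow is None:
        return []
    return [(1, pyarrow.int64()), ("x", pyarrow.string())]


@pytest.mark.skipif(pyarrow is None, reason="pyarrow is not installed")
@pytest.mark.parametrize("value, expected_type", arrow_cases())
def test_arrow_type_mapping(value, expected_type):
    table = pyarrow.table({"col": [value]})
    assert table.field(0).type == expected_type
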
@@ -1,6 +1,10 @@
import logging
import math
import time
from unittest import skipUnless

import pytest
from tests.e2e.common.predicates import pysql_supports_arrow

log = logging.getLogger(__name__)

@@ -40,6 +44,7 @@ def fetch_rows(self, cursor, row_count, fetchmany_size):
+ "assuming 10K fetch size."
)

@pytest.mark.skipif(not pysql_supports_arrow(), "Without pyarrow lz4 compression is not supported")
def test_query_with_large_wide_result_set(self):
resultSize = 300 * 1000 * 1000 # 300 MB
width = 8192 # B
@@ -8,9 +8,13 @@


def pysql_supports_arrow():
"""Import databricks.sql and test whether Cursor has fetchall_arrow."""
from databricks.sql.client import Cursor
return hasattr(Cursor, 'fetchall_arrow')
"""Checks if the pyarrow library is installed or not"""
try:
import pyarrow

return True
except ImportError:
return False


def pysql_has_version(compare, version):
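
pysql_supports_arrow() now simply reports whether pyarrow can be imported, and the e2e suites shown above use it to skip Arrow-dependent tests. A brief usage sketch under that assumption (the import path follows the layout this commit creates):

import pytest

from tests.e2e.common.predicates import pysql_supports_arrow


@pytest.mark.skipif(not pysql_supports_arrow(), reason="pyarrow is not installed")
def test_requires_arrow():
    import pyarrow  # safe: this test only runs when pyarrow is importable

    assert pyarrow.table({"n": [1]}).num_rows == 1
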
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.