From eeaee969db373f4ae6ac34b343557c33d22f29a0 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Fri, 13 Sep 2024 13:48:08 +0530 Subject: [PATCH 1/6] Implemented the columnar flow for non arrow users --- src/databricks/sql/client.py | 92 +++++++++++++++++-- src/databricks/sql/thrift_backend.py | 30 +++++-- src/databricks/sql/utils.py | 126 +++++++++++++++++++++++---- src/databricks/sqlalchemy/pytest.ini | 4 + tests/unit/test_column_queue.py | 20 +++++ 5 files changed, 240 insertions(+), 32 deletions(-) create mode 100644 src/databricks/sqlalchemy/pytest.ini create mode 100644 tests/unit/test_column_queue.py diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index addc340e..3023adbe 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1,7 +1,10 @@ from typing import Dict, Tuple, List, Optional, Any, Union, Sequence import pandas -import pyarrow +try: + import pyarrow +except ImportError: + pyarrow = None import requests import json import os @@ -22,6 +25,8 @@ ParamEscaper, inject_parameters, transform_paramstyle, + ArrowQueue, + ColumnQueue ) from databricks.sql.parameters.native import ( DbsqlParameterBase, @@ -991,14 +996,14 @@ def fetchmany(self, size: int) -> List[Row]: else: raise Error("There is no active result set") - def fetchall_arrow(self) -> pyarrow.Table: + def fetchall_arrow(self) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: return self.active_result_set.fetchall_arrow() else: raise Error("There is no active result set") - def fetchmany_arrow(self, size) -> pyarrow.Table: + def fetchmany_arrow(self, size) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: return self.active_result_set.fetchmany_arrow(size) @@ -1143,6 +1148,18 @@ def _fill_results_buffer(self): self.results = results self.has_more_rows = has_more_rows + def _convert_columnar_table(self, table): + column_names = [c[0] for c in self.description] + ResultRow = Row(*column_names) + result = [] + for row_index in range(len(table[0])): + curr_row = [] + for col_index in range(len(table)): + curr_row.append(table[col_index][row_index]) + result.append(ResultRow(*curr_row)) + + return result + def _convert_arrow_table(self, table): column_names = [c[0] for c in self.description] ResultRow = Row(*column_names) @@ -1185,7 +1202,7 @@ def _convert_arrow_table(self, table): def rownumber(self): return self._next_row_index - def fetchmany_arrow(self, size: int) -> pyarrow.Table: + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """ Fetch the next set of rows of a query result, returning a PyArrow table. @@ -1210,7 +1227,42 @@ def fetchmany_arrow(self, size: int) -> pyarrow.Table: return results - def fetchall_arrow(self) -> pyarrow.Table: + def merge_columnar(self, result1, result2): + """ + Function to merge / combining the columnar results into a single result + :param result1: + :param result2: + :return: + """ + merged_result = [result1[i] + result2[i] for i in range(len(result1))] + return merged_result + + def fetchmany_columnar(self, size: int): + """ + Fetch the next set of rows of a query result, returning a Columnar Table. + An empty sequence is returned when no more rows are available. 
+ """ + if size < 0: + raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + + results = self.results.next_n_rows(size) + n_remaining_rows = size - len(results[0]) + self._next_row_index += len(results[0]) + + while ( + n_remaining_rows > 0 + and not self.has_been_closed_server_side + and self.has_more_rows + ): + self._fill_results_buffer() + partial_results = self.results.next_n_rows(n_remaining_rows) + results = self.merge_columnar(results, partial_results) + n_remaining_rows -= len(partial_results[0]) + self._next_row_index += len(partial_results[0]) + + return results + + def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" results = self.results.remaining_rows() self._next_row_index += results.num_rows @@ -1223,12 +1275,30 @@ def fetchall_arrow(self) -> pyarrow.Table: return results + def fetchall_columnar(self): + """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" + results = self.results.remaining_rows() + self._next_row_index += len(results[0]) + + while not self.has_been_closed_server_side and self.has_more_rows: + self._fill_results_buffer() + partial_results = self.results.remaining_rows() + results = self.merge_columnar(results, partial_results) + self._next_row_index += len(partial_results[0]) + + return results + def fetchone(self) -> Optional[Row]: """ Fetch the next row of a query result set, returning a single sequence, or None when no more data is available. """ - res = self._convert_arrow_table(self.fetchmany_arrow(1)) + + if isinstance(self.results, ColumnQueue): + res = self._convert_columnar_table(self.fetchmany_columnar(1)) + else: + res = self._convert_arrow_table(self.fetchmany_arrow(1)) + if len(res) > 0: return res[0] else: @@ -1238,7 +1308,10 @@ def fetchall(self) -> List[Row]: """ Fetch all (remaining) rows of a query result, returning them as a list of rows. """ - return self._convert_arrow_table(self.fetchall_arrow()) + if isinstance(self.results, ColumnQueue): + return self._convert_columnar_table(self.fetchall_columnar()) + else: + return self._convert_arrow_table(self.fetchall_arrow()) def fetchmany(self, size: int) -> List[Row]: """ @@ -1246,7 +1319,10 @@ def fetchmany(self, size: int) -> List[Row]: An empty sequence is returned when no more rows are available. 
""" - return self._convert_arrow_table(self.fetchmany_arrow(size)) + if isinstance(self.results, ColumnQueue): + return self._convert_columnar_table(self.fetchmany_columnar(size)) + else: + return self._convert_arrow_table(self.fetchmany_arrow(size)) def close(self) -> None: """ diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index e89bff26..e60b9db0 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -7,7 +7,10 @@ import threading from typing import List, Union -import pyarrow +try: + import pyarrow +except ImportError: + pyarrow = None import thrift.transport.THttpClient import thrift.protocol.TBinaryProtocol import thrift.transport.TSocket @@ -621,6 +624,12 @@ def _get_metadata_resp(self, op_handle): @staticmethod def _hive_schema_to_arrow_schema(t_table_schema): + + if pyarrow is None: + raise ImportError( + "pyarrow is required to convert Hive schema to Arrow schema" + ) + def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -726,12 +735,17 @@ def _results_message_to_execute_response(self, resp, operation_state): description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) - schema_bytes = ( - t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) - .serialize() - .to_pybytes() - ) + + if pyarrow: + schema_bytes = ( + t_result_set_metadata_resp.arrowSchema + or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + .serialize() + .to_pybytes() + ) + else: + schema_bytes = None + lz4_compressed = t_result_set_metadata_resp.lz4Compressed is_staging_operation = t_result_set_metadata_resp.isStagingOperation if direct_results and direct_results.resultSet: @@ -827,7 +841,7 @@ def execute_command( getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), - canReadArrowResult=True, + canReadArrowResult=True if pyarrow else False, canDecompressLZ4Result=lz4_compression, canDownloadResult=use_cloud_fetch, confOverlay={ diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 2807bd2b..f63b3c00 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import pytz import datetime import decimal from abc import ABC, abstractmethod @@ -11,7 +12,10 @@ import re import lz4.frame -import pyarrow +try: + import pyarrow +except ImportError: + pyarrow = None from databricks.sql import OperationalError, exc from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager @@ -33,11 +37,11 @@ class ResultSetQueue(ABC): @abstractmethod - def next_n_rows(self, num_rows: int) -> pyarrow.Table: + def next_n_rows(self, num_rows: int): pass @abstractmethod - def remaining_rows(self) -> pyarrow.Table: + def remaining_rows(self): pass @@ -76,13 +80,15 @@ def build_queue( ) return ArrowQueue(converted_arrow_table, n_valid_rows) elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET: - arrow_table, n_valid_rows = convert_column_based_set_to_arrow_table( + column_table, column_names = convert_column_based_set_to_column_table( t_row_set.columns, description ) - converted_arrow_table = convert_decimals_in_arrow_table( - arrow_table, description + + converted_column_table = convert_to_assigned_datatypes_in_column_table( + column_table, description ) - return ArrowQueue(converted_arrow_table, n_valid_rows) + + return ColumnQueue(converted_column_table, column_names) elif row_set_type == 
TSparkRowSetType.URL_BASED_SET: return CloudFetchQueue( schema_bytes=arrow_schema_bytes, @@ -96,11 +102,33 @@ def build_queue( else: raise AssertionError("Row set type is not valid") +class ColumnQueue(ResultSetQueue): + def __init__(self, columnar_table, column_names): + self.columnar_table = columnar_table + self.cur_row_index = 0 + self.n_valid_rows = len(columnar_table[0]) + self.column_names = column_names + + def next_n_rows(self, num_rows): + length = min(num_rows, self.n_valid_rows - self.cur_row_index) + # Slicing using the default python slice + next_data = [ + column[self.cur_row_index : self.cur_row_index + length] + for column in self.columnar_table + ] + self.cur_row_index += length + return next_data + + def remaining_rows(self): + next_data = [column[self.cur_row_index :] for column in self.columnar_table] + self.cur_row_index += len(next_data[0]) + return next_data + class ArrowQueue(ResultSetQueue): def __init__( self, - arrow_table: pyarrow.Table, + arrow_table: "pyarrow.Table", n_valid_rows: int, start_row_index: int = 0, ): @@ -115,7 +143,7 @@ def __init__( self.arrow_table = arrow_table self.n_valid_rows = n_valid_rows - def next_n_rows(self, num_rows: int) -> pyarrow.Table: + def next_n_rows(self, num_rows: int) -> "pyarrow.Table": """Get upto the next n rows of the Arrow dataframe""" length = min(num_rows, self.n_valid_rows - self.cur_row_index) # Note that the table.slice API is not the same as Python's slice @@ -124,7 +152,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table: self.cur_row_index += slice.num_rows return slice - def remaining_rows(self) -> pyarrow.Table: + def remaining_rows(self) -> "pyarrow.Table": slice = self.arrow_table.slice( self.cur_row_index, self.n_valid_rows - self.cur_row_index ) @@ -184,7 +212,7 @@ def __init__( self.table = self._create_next_table() self.table_row_index = 0 - def next_n_rows(self, num_rows: int) -> pyarrow.Table: + def next_n_rows(self, num_rows: int) -> "pyarrow.Table": """ Get up to the next n rows of the cloud fetch Arrow dataframes. @@ -216,7 +244,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table: logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows)) return results - def remaining_rows(self) -> pyarrow.Table: + def remaining_rows(self) -> "pyarrow.Table": """ Get all remaining rows of the cloud fetch Arrow dataframes. 
@@ -237,7 +265,7 @@ def remaining_rows(self) -> pyarrow.Table: self.table_row_index = 0 return results - def _create_next_table(self) -> Union[pyarrow.Table, None]: + def _create_next_table(self) -> Union["pyarrow.Table", None]: logger.debug( "CloudFetchQueue: Trying to get downloaded file for row {}".format( self.start_row_index @@ -276,7 +304,7 @@ def _create_next_table(self) -> Union[pyarrow.Table, None]: return arrow_table - def _create_empty_table(self) -> pyarrow.Table: + def _create_empty_table(self) -> "pyarrow.Table": # Create a 0-row table with just the schema bytes return create_arrow_table_from_arrow_file(self.schema_bytes, self.description) @@ -515,7 +543,7 @@ def transform_paramstyle( return output -def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> pyarrow.Table: +def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> "pyarrow.Table": arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes) return convert_decimals_in_arrow_table(arrow_table, description) @@ -542,7 +570,7 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema return arrow_table, n_rows -def convert_decimals_in_arrow_table(table, description) -> pyarrow.Table: +def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table": for i, col in enumerate(table.itercolumns()): if description[i][1] == "decimal": decimal_col = col.to_pandas().apply( @@ -560,6 +588,29 @@ def convert_decimals_in_arrow_table(table, description) -> pyarrow.Table: return table +def convert_to_assigned_datatypes_in_column_table(column_table, description): + for i, col in enumerate(column_table): + if description[i][1] == "decimal": + column_table[i] = tuple(v if v is None else Decimal(v) for v in col) + elif description[i][1] == "date": + column_table[i] = tuple( + v if v is None else datetime.date.fromisoformat(v) for v in col + ) + elif description[i][1] == "timestamp": + column_table[i] = tuple( + ( + v + if v is None + else datetime.datetime.strptime(v, "%Y-%m-%d %H:%M:%S.%f").replace( + tzinfo=pytz.UTC + ) + ) + for v in col + ) + + return column_table + + def convert_column_based_set_to_arrow_table(columns, description): arrow_table = pyarrow.Table.from_arrays( [_convert_column_to_arrow_array(c) for c in columns], @@ -571,6 +622,13 @@ def convert_column_based_set_to_arrow_table(columns, description): return arrow_table, arrow_table.num_rows +def convert_column_based_set_to_column_table(columns, description): + column_names = [c[0] for c in description] + column_table = [_covert_column_to_list(c) for c in columns] + + return column_table, column_names + + def _convert_column_to_arrow_array(t_col): """ Return a pyarrow array from the values in a TColumn instance. 
@@ -595,6 +653,26 @@ def _convert_column_to_arrow_array(t_col): raise OperationalError("Empty TColumn instance {}".format(t_col)) +def _covert_column_to_list(t_col): + supported_field_types = ( + "boolVal", + "byteVal", + "i16Val", + "i32Val", + "i64Val", + "doubleVal", + "stringVal", + "binaryVal", + ) + + for field in supported_field_types: + wrapper = getattr(t_col, field) + if wrapper: + return _create_python_tuple(wrapper) + + raise OperationalError("Empty TColumn instance {}".format(t_col)) + + def _create_arrow_array(t_col_value_wrapper, arrow_type): result = t_col_value_wrapper.values nulls = t_col_value_wrapper.nulls # bitfield describing which values are null @@ -609,3 +687,19 @@ def _create_arrow_array(t_col_value_wrapper, arrow_type): result[i] = None return pyarrow.array(result, type=arrow_type) + + +def _create_python_tuple(t_col_value_wrapper): + result = t_col_value_wrapper.values + nulls = t_col_value_wrapper.nulls # bitfield describing which values are null + assert isinstance(nulls, bytes) + + # The number of bits in nulls can be both larger or smaller than the number of + # elements in result, so take the minimum of both to iterate over. + length = min(len(result), len(nulls) * 8) + + for i in range(length): + if nulls[i >> 3] & BIT_MASKS[i & 0x7]: + result[i] = None + + return tuple(result) \ No newline at end of file diff --git a/src/databricks/sqlalchemy/pytest.ini b/src/databricks/sqlalchemy/pytest.ini new file mode 100644 index 00000000..ab89d17d --- /dev/null +++ b/src/databricks/sqlalchemy/pytest.ini @@ -0,0 +1,4 @@ + +[sqla_testing] +requirement_cls=databricks.sqlalchemy.requirements:Requirements +profile_file=profiles.txt diff --git a/tests/unit/test_column_queue.py b/tests/unit/test_column_queue.py new file mode 100644 index 00000000..b2dead3b --- /dev/null +++ b/tests/unit/test_column_queue.py @@ -0,0 +1,20 @@ +import pytest +from databricks.sql.utils import ColumnQueue + + +class TestColumnQueueSuite: + @pytest.fixture(scope="function") + def setup(self): + columnar_table = [[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]] + column_names = [f"col_{i}" for i in range(len(columnar_table))] + return ColumnQueue(columnar_table, column_names) + + def test_fetchmany_respects_n_rows(self, setup): + column_queue = setup + assert column_queue.next_n_rows(2) == [[0, 3], [1, 4], [2, 5]] + assert column_queue.next_n_rows(2) == [[6, 9], [7, 10], [8, 11]] + + def test_fetch_remaining_rows_respects_n_rows(self, setup): + column_queue = setup + assert column_queue.next_n_rows(2) == [[0, 3], [1, 4], [2, 5]] + assert column_queue.remaining_rows() == [[6, 9], [7, 10], [8, 11]] From 6a8646d31bd48dd85282d037c71e5f77d44387b7 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Fri, 20 Sep 2024 10:22:52 +0530 Subject: [PATCH 2/6] Minor fixes --- src/databricks/sql/client.py | 18 +++++++++++------- src/databricks/sql/utils.py | 24 ++++++++++++++---------- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 3023adbe..fc1db0e9 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1227,13 +1227,17 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": return results - def merge_columnar(self, result1, result2): - """ - Function to merge / combining the columnar results into a single result - :param result1: - :param result2: - :return: - """ + def merge_columnar(self, result1, result2): + """ + Function to merge / combining the columnar results into a single result + :param result1: 
+ :param result2: + :return: + """ + + if len(result1) != len(result2): + raise ValueError("The number of columns in both results must be the same") + merged_result = [result1[i] + result2[i] for i in range(len(result1))] return merged_result diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index f63b3c00..fd84fc0f 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -589,15 +589,17 @@ def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table": def convert_to_assigned_datatypes_in_column_table(column_table, description): + + converted_column_table = [] for i, col in enumerate(column_table): if description[i][1] == "decimal": - column_table[i] = tuple(v if v is None else Decimal(v) for v in col) + converted_column_table.append(tuple(v if v is None else Decimal(v) for v in col)) elif description[i][1] == "date": - column_table[i] = tuple( + converted_column_table[i].append(tuple( v if v is None else datetime.date.fromisoformat(v) for v in col - ) + )) elif description[i][1] == "timestamp": - column_table[i] = tuple( + converted_column_table[i].append(tuple( ( v if v is None @@ -606,9 +608,11 @@ def convert_to_assigned_datatypes_in_column_table(column_table, description): ) ) for v in col - ) + )) + else: + converted_column_table.append(col) - return column_table + return converted_column_table def convert_column_based_set_to_arrow_table(columns, description): @@ -624,7 +628,7 @@ def convert_column_based_set_to_arrow_table(columns, description): def convert_column_based_set_to_column_table(columns, description): column_names = [c[0] for c in description] - column_table = [_covert_column_to_list(c) for c in columns] + column_table = [_convert_column_to_list(c) for c in columns] return column_table, column_names @@ -653,8 +657,8 @@ def _convert_column_to_arrow_array(t_col): raise OperationalError("Empty TColumn instance {}".format(t_col)) -def _covert_column_to_list(t_col): - supported_field_types = ( +def _convert_column_to_list(t_col): + SUPPORTED_FIELD_TYPES = ( "boolVal", "byteVal", "i16Val", @@ -665,7 +669,7 @@ def _covert_column_to_list(t_col): "binaryVal", ) - for field in supported_field_types: + for field in SUPPORTED_FIELD_TYPES: wrapper = getattr(t_col, field) if wrapper: return _create_python_tuple(wrapper) From a87e2cb79d9d664c5eb8e38968ce13fdcb08aaf5 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Fri, 20 Sep 2024 11:29:55 +0530 Subject: [PATCH 3/6] Introduced the Column Table structure --- src/databricks/sql/client.py | 28 ++++++++++---------- src/databricks/sql/utils.py | 50 +++++++++++++++++++++++++----------- 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index fc1db0e9..4df67a08 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -25,7 +25,7 @@ ParamEscaper, inject_parameters, transform_paramstyle, - ArrowQueue, + ColumnTable, ColumnQueue ) from databricks.sql.parameters.native import ( @@ -1152,10 +1152,10 @@ def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] ResultRow = Row(*column_names) result = [] - for row_index in range(len(table[0])): + for row_index in range(table.num_rows): curr_row = [] - for col_index in range(len(table)): - curr_row.append(table[col_index][row_index]) + for col_index in range(table.num_columns): + curr_row.append(table.get_item(col_index, row_index)) result.append(ResultRow(*curr_row)) return result @@ -1235,11 +1235,11 @@ def 
merge_columnar(self, result1, result2): :return: """ - if len(result1) != len(result2): - raise ValueError("The number of columns in both results must be the same") + if result1.column_names != result2.column_names: + raise ValueError("The columns in the results don't match") - merged_result = [result1[i] + result2[i] for i in range(len(result1))] - return merged_result + merged_result = [result1.column_table[i] + result2.column_table[i] for i in range(result1.num_columns)] + return ColumnTable(merged_result, result1.column_names) def fetchmany_columnar(self, size: int): """ @@ -1250,8 +1250,8 @@ def fetchmany_columnar(self, size: int): raise ValueError("size argument for fetchmany is %s but must be >= 0", size) results = self.results.next_n_rows(size) - n_remaining_rows = size - len(results[0]) - self._next_row_index += len(results[0]) + n_remaining_rows = size - results.num_rows + self._next_row_index += results.num_rows while ( n_remaining_rows > 0 @@ -1261,8 +1261,8 @@ def fetchmany_columnar(self, size: int): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) results = self.merge_columnar(results, partial_results) - n_remaining_rows -= len(partial_results[0]) - self._next_row_index += len(partial_results[0]) + n_remaining_rows -= partial_results.num_rows + self._next_row_index += partial_results.num_rows return results @@ -1282,13 +1282,13 @@ def fetchall_arrow(self) -> "pyarrow.Table": def fetchall_columnar(self): """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" results = self.results.remaining_rows() - self._next_row_index += len(results[0]) + self._next_row_index += results.num_rows while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) - self._next_row_index += len(partial_results[0]) + self._next_row_index += partial_results.num_rows return results diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index fd84fc0f..321384f1 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -88,7 +88,7 @@ def build_queue( column_table, description ) - return ColumnQueue(converted_column_table, column_names) + return ColumnQueue(ColumnTable(converted_column_table, column_names)) elif row_set_type == TSparkRowSetType.URL_BASED_SET: return CloudFetchQueue( schema_bytes=arrow_schema_bytes, @@ -102,27 +102,47 @@ def build_queue( else: raise AssertionError("Row set type is not valid") +class ColumnTable: + def __init__(self, column_table, column_names): + self.column_table = column_table + self.column_names = column_names + + @property + def num_rows(self): + if len(self.column_table) == 0: + return 0 + else: + return len(self.column_table[0]) + + @property + def num_columns(self): + return len(self.column_names) + + def get_item(self, col_index, row_index): + return self.column_table[col_index][row_index] + + def slice(self, curr_index, length): + sliced_column_table = [column[curr_index : curr_index + length] for column in self.column_table] + return ColumnTable(sliced_column_table, self.column_names) + + class ColumnQueue(ResultSetQueue): - def __init__(self, columnar_table, column_names): - self.columnar_table = columnar_table + def __init__(self, column_table: ColumnTable): + self.column_table = column_table self.cur_row_index = 0 - self.n_valid_rows = len(columnar_table[0]) - self.column_names = column_names + self.n_valid_rows = 
column_table.num_rows def next_n_rows(self, num_rows): length = min(num_rows, self.n_valid_rows - self.cur_row_index) - # Slicing using the default python slice - next_data = [ - column[self.cur_row_index : self.cur_row_index + length] - for column in self.columnar_table - ] - self.cur_row_index += length - return next_data + + slice = self.column_table.slice(self.cur_row_index, length) + self.cur_row_index += slice.num_rows + return slice def remaining_rows(self): - next_data = [column[self.cur_row_index :] for column in self.columnar_table] - self.cur_row_index += len(next_data[0]) - return next_data + slice = self.column_table.slice(self.cur_row_index, self.n_valid_rows - self.cur_row_index) + self.cur_row_index += slice.num_rows + return slice class ArrowQueue(ResultSetQueue): From 146b8c7bfc7a76610b97d88f073b45e87c5ac82e Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Fri, 20 Sep 2024 12:02:18 +0530 Subject: [PATCH 4/6] Added test for the new column table --- src/databricks/sql/utils.py | 2 ++ tests/unit/test_column_queue.py | 32 +++++++++++++++++--------------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 321384f1..5f7c6d16 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -125,6 +125,8 @@ def slice(self, curr_index, length): sliced_column_table = [column[curr_index : curr_index + length] for column in self.column_table] return ColumnTable(sliced_column_table, self.column_names) + def __eq__(self, other): + return self.column_table == other.column_table and self.column_names == other.column_names class ColumnQueue(ResultSetQueue): def __init__(self, column_table: ColumnTable): diff --git a/tests/unit/test_column_queue.py b/tests/unit/test_column_queue.py index b2dead3b..130b589b 100644 --- a/tests/unit/test_column_queue.py +++ b/tests/unit/test_column_queue.py @@ -1,20 +1,22 @@ -import pytest -from databricks.sql.utils import ColumnQueue +from databricks.sql.utils import ColumnQueue, ColumnTable class TestColumnQueueSuite: - @pytest.fixture(scope="function") - def setup(self): - columnar_table = [[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]] - column_names = [f"col_{i}" for i in range(len(columnar_table))] - return ColumnQueue(columnar_table, column_names) + @staticmethod + def make_column_table(table): + n_cols = len(table) if table else 0 + return ColumnTable(table, [f"col_{i}" for i in range(n_cols)]) - def test_fetchmany_respects_n_rows(self, setup): - column_queue = setup - assert column_queue.next_n_rows(2) == [[0, 3], [1, 4], [2, 5]] - assert column_queue.next_n_rows(2) == [[6, 9], [7, 10], [8, 11]] + def test_fetchmany_respects_n_rows(self): + column_table = self.make_column_table([[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]]) + column_queue = ColumnQueue(column_table) - def test_fetch_remaining_rows_respects_n_rows(self, setup): - column_queue = setup - assert column_queue.next_n_rows(2) == [[0, 3], [1, 4], [2, 5]] - assert column_queue.remaining_rows() == [[6, 9], [7, 10], [8, 11]] + assert column_queue.next_n_rows(2) == column_table.slice(0, 2) + assert column_queue.next_n_rows(2) == column_table.slice(2, 2) + + def test_fetch_remaining_rows_respects_n_rows(self): + column_table = self.make_column_table([[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]]) + column_queue = ColumnQueue(column_table) + + assert column_queue.next_n_rows(2) == column_table.slice(0, 2) + assert column_queue.remaining_rows() == column_table.slice(2, 2) From 
37015c650e5425f58651dbc2fedd50584bb29676 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Fri, 20 Sep 2024 12:30:48 +0530 Subject: [PATCH 5/6] Minor fix --- src/databricks/sql/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 5f7c6d16..97df6d4d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -617,11 +617,11 @@ def convert_to_assigned_datatypes_in_column_table(column_table, description): if description[i][1] == "decimal": converted_column_table.append(tuple(v if v is None else Decimal(v) for v in col)) elif description[i][1] == "date": - converted_column_table[i].append(tuple( + converted_column_table.append(tuple( v if v is None else datetime.date.fromisoformat(v) for v in col )) elif description[i][1] == "timestamp": - converted_column_table[i].append(tuple( + converted_column_table.append(tuple( ( v if v is None From 7be3edd4d5ae4968dc478b0bbf568c929fa33087 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Thu, 3 Oct 2024 17:26:30 +0530 Subject: [PATCH 6/6] Removed unnecessory fikes --- src/databricks/sql/thrift_backend.py | 5 ----- src/databricks/sqlalchemy/pytest.ini | 4 ---- 2 files changed, 9 deletions(-) delete mode 100644 src/databricks/sqlalchemy/pytest.ini diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index e60b9db0..7f6ada9d 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -625,11 +625,6 @@ def _get_metadata_resp(self, op_handle): @staticmethod def _hive_schema_to_arrow_schema(t_table_schema): - if pyarrow is None: - raise ImportError( - "pyarrow is required to convert Hive schema to Arrow schema" - ) - def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { diff --git a/src/databricks/sqlalchemy/pytest.ini b/src/databricks/sqlalchemy/pytest.ini deleted file mode 100644 index ab89d17d..00000000 --- a/src/databricks/sqlalchemy/pytest.ini +++ /dev/null @@ -1,4 +0,0 @@ - -[sqla_testing] -requirement_cls=databricks.sqlalchemy.requirements:Requirements -profile_file=profiles.txt
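
As a quick reference for reviewers, below is a minimal sketch of how the new ColumnTable / ColumnQueue structures behave once this series is applied. The column values and names mirror the unit test above and are purely illustrative; the snippet assumes the patched databricks.sql.utils module is importable.

    from databricks.sql.utils import ColumnQueue, ColumnTable

    # Column-wise storage: one list per column, three columns of four rows each.
    table = ColumnTable(
        [[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]],
        ["col_0", "col_1", "col_2"],
    )
    queue = ColumnQueue(table)

    print(table.num_rows, table.num_columns)  # 4 3
    print(table.get_item(1, 2))               # col_1 at row index 2 -> 7

    # next_n_rows / remaining_rows hand back ColumnTable slices,
    # which is what the unit tests assert against.
    first_two = queue.next_n_rows(2)   # == table.slice(0, 2)
    rest = queue.remaining_rows()      # == table.slice(2, 2)
    print(first_two.column_table)      # [[0, 3], [1, 4], [2, 5]]
    print(rest.column_table)           # [[6, 9], [7, 10], [8, 11]]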