From 5efd99764a48c54eb108517c3ef68898bd24c90d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pau=20Tallada=20Cresp=C3=AD?=
Date: Tue, 8 Feb 2022 18:31:23 +0100
Subject: [PATCH] Produce rows as slices of pyarrow.Table

---
 TCLIService/ttypes.py |  25 ++-----
 pyhive/common.py      |   6 +-
 pyhive/hive.py        | 158 ++++++++++++++++++++++++++++++++++++------
 pyhive/schema.py      |  97 ++++++++++++++++++++++++++
 pyproject.toml        |   3 +
 5 files changed, 247 insertions(+), 42 deletions(-)
 create mode 100644 pyhive/schema.py
 create mode 100644 pyproject.toml

diff --git a/TCLIService/ttypes.py b/TCLIService/ttypes.py
index 573bd043..47b4cec1 100644
--- a/TCLIService/ttypes.py
+++ b/TCLIService/ttypes.py
@@ -8,6 +8,7 @@
 
 from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException
 from thrift.protocol.TProtocol import TProtocolException
+import numpy as np
 import sys
 
 from thrift.transport import TTransport
@@ -2013,9 +2014,7 @@ def read(self, iprot):
             if ftype == TType.LIST:
                 self.values = []
                 (_etype51, _size48) = iprot.readListBegin()
-                for _i52 in range(_size48):
-                    _elem53 = iprot.readBool()
-                    self.values.append(_elem53)
+                self.values = np.frombuffer(iprot.trans.readAll(1 * _size48), dtype='>?')
                 iprot.readListEnd()
             else:
                 iprot.skip(ftype)
@@ -2097,9 +2096,7 @@ def read(self, iprot):
             if ftype == TType.LIST:
                 self.values = []
                 (_etype58, _size55) = iprot.readListBegin()
-                for _i59 in range(_size55):
-                    _elem60 = iprot.readByte()
-                    self.values.append(_elem60)
+                self.values = np.frombuffer(iprot.trans.readAll(1 * _size55), dtype='>i1')
                 iprot.readListEnd()
             else:
                 iprot.skip(ftype)
@@ -2181,9 +2178,7 @@ def read(self, iprot):
             if ftype == TType.LIST:
                 self.values = []
                 (_etype65, _size62) = iprot.readListBegin()
-                for _i66 in range(_size62):
-                    _elem67 = iprot.readI16()
-                    self.values.append(_elem67)
+                self.values = np.frombuffer(iprot.trans.readAll(2 * _size62), dtype='>i2')
                 iprot.readListEnd()
             else:
                 iprot.skip(ftype)
@@ -2265,9 +2260,7 @@ def read(self, iprot):
             if ftype == TType.LIST:
                 self.values = []
                 (_etype72, _size69) = iprot.readListBegin()
-                for _i73 in range(_size69):
-                    _elem74 = iprot.readI32()
-                    self.values.append(_elem74)
+                self.values = np.frombuffer(iprot.trans.readAll(4 * _size69), dtype='>i4')
                 iprot.readListEnd()
             else:
                 iprot.skip(ftype)
@@ -2349,9 +2342,7 @@ def read(self, iprot):
             if ftype == TType.LIST:
                 self.values = []
                 (_etype79, _size76) = iprot.readListBegin()
-                for _i80 in range(_size76):
-                    _elem81 = iprot.readI64()
-                    self.values.append(_elem81)
+                self.values = np.frombuffer(iprot.trans.readAll(8 * _size76), dtype='>i8')
                 iprot.readListEnd()
             else:
                 iprot.skip(ftype)
@@ -2433,9 +2424,7 @@ def read(self, iprot):
             if ftype == TType.LIST:
                 self.values = []
                 (_etype86, _size83) = iprot.readListBegin()
-                for _i87 in range(_size83):
-                    _elem88 = iprot.readDouble()
-                    self.values.append(_elem88)
+                self.values = np.frombuffer(iprot.trans.readAll(8 * _size83), dtype='>f8')
                 iprot.readListEnd()
             else:
                 iprot.skip(ftype)
diff --git a/pyhive/common.py b/pyhive/common.py
index 51692b97..9adbc5d7 100644
--- a/pyhive/common.py
+++ b/pyhive/common.py
@@ -43,12 +43,12 @@ def _reset_state(self):
         # Internal helper state
         self._state = self._STATE_NONE
-        self._data = collections.deque()
+        self._data = None
         self._columns = None
 
-    def _fetch_while(self, fn):
+    def _fetch_while(self, fn, schema):
         while fn():
-            self._fetch_more()
+            self._fetch_more(schema)
             if fn():
                 time.sleep(self._poll_interval)
 
diff --git a/pyhive/hive.py b/pyhive/hive.py
index c1287488..9e7e1387 100644
--- a/pyhive/hive.py
+++ b/pyhive/hive.py
@@ -10,6 +10,11 @@
 
 import base64
 import datetime
+import io
+import itertools
+import numpy as np
+import pyarrow as pa
+import pyarrow.json
 import re
 from decimal import Decimal
 from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context
@@ -40,7 +45,8 @@
 
 _logger = logging.getLogger(__name__)
 
-_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)')
+_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,9})?)')
+_INTERVAL_DAY_TIME_PATTERN = re.compile(r'(\d+) (\d+):(\d+):(\d+(?:\.\d+)?)')
 
 ssl_cert_parameter_map = {
     "none": CERT_NONE,
@@ -106,9 +112,36 @@ def _parse_timestamp(value):
         value = None
     return value
 
+def _parse_date(value):
+    if value:
+        format = '%Y-%m-%d'
+        value = datetime.datetime.strptime(value, format).date()
+    else:
+        value = None
+    return value
 
-TYPES_CONVERTER = {"DECIMAL_TYPE": Decimal,
-                   "TIMESTAMP_TYPE": _parse_timestamp}
+def _parse_interval_day_time(value):
+    if value:
+        match = _INTERVAL_DAY_TIME_PATTERN.match(value)
+        if match:
+            days = int(match.group(1))
+            hours = int(match.group(2))
+            minutes = int(match.group(3))
+            seconds = float(match.group(4))
+            value = datetime.timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
+        else:
+            raise Exception(
+                'Cannot convert "{}" into an interval_day_time'.format(value))
+    else:
+        value = None
+    return value
+
+TYPES_CONVERTER = {
+    "DECIMAL_TYPE": Decimal,
+    "TIMESTAMP_TYPE": _parse_timestamp,
+    "DATE_TYPE": _parse_date,
+    "INTERVAL_DAY_TIME_TYPE": _parse_interval_day_time,
+}
 
 
 class HiveParamEscaper(common.ParamEscaper):
@@ -488,7 +521,50 @@ def cancel(self):
         response = self._connection.client.CancelOperation(req)
         _check_status(response)
 
-    def _fetch_more(self):
+    def fetchone(self, schema=[]):
+        return self.fetchmany(1, schema)
+
+    def fetchall(self, schema=[]):
+        return self.fetchmany(-1, schema)
+
+    def fetchmany(self, size=None, schema=[]):
+        if size is None:
+            size = self.arraysize
+
+        if self._state == self._STATE_NONE:
+            raise exc.ProgrammingError("No query yet")
+
+        if size == -1:
+            # Fetch everything
+            self._fetch_while(lambda: self._state != self._STATE_FINISHED, schema)
+        else:
+            self._fetch_while(lambda:
+                (self._state != self._STATE_FINISHED) and
+                (self._data is None or self._data.num_rows < size),
+                schema
+            )
+
+        if self._data is None:
+            return None
+
+        if size == -1:
+            # Fetch everything
+            size = self._data.num_rows
+        else:
+            size = min(size, self._data.num_rows)
+
+        self._rownumber += size
+        rows = self._data[:size]
+
+        if size == self._data.num_rows:
+            # Fetch everything
+            self._data = None
+        else:
+            self._data = self._data[size:]
+
+        return rows
+
+    def _fetch_more(self, ext_schema):
         """Send another TFetchResultsReq and update state"""
         assert(self._state == self._STATE_RUNNING), "Should be running when in _fetch_more"
         assert(self._operationHandle is not None), "Should have an op handle in _fetch_more"
@@ -503,15 +579,21 @@ def _fetch_more(self):
         _check_status(response)
         schema = self.description
         assert not response.results.rows, 'expected data in columnar format'
-        columns = [_unwrap_column(col, col_schema[1]) for col, col_schema in
-                   zip(response.results.columns, schema)]
-        new_data = list(zip(*columns))
-        self._data += new_data
+        columns = [_unwrap_column(col, col_schema[1], e_schema) for col, col_schema, e_schema in
+                   itertools.zip_longest(response.results.columns, schema, ext_schema)]
+        names = [col[0] for col in schema]
+        new_data = pa.Table.from_batches([pa.RecordBatch.from_arrays(columns, names=names)])
         # response.hasMoreRows seems to always be False, so we instead check the number of rows
         # https://github.com/apache/hive/blob/release-1.2.1/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java#L678
         # if not response.hasMoreRows:
-        if not new_data:
+        if new_data.num_rows == 0:
             self._state = self._STATE_FINISHED
+            return
+
+        if self._data is None:
+            self._data = new_data
+        else:
+            self._data = pa.concat_tables([self._data, new_data])
 
     def poll(self, get_progress_update=True):
         """Poll for and return the raw status data provided by the Hive Thrift REST API.
@@ -585,21 +667,55 @@ def fetch_logs(self):
 #
 
 
-def _unwrap_column(col, type_=None):
+def _unwrap_column(col, type_=None, schema=None):
     """Return a list of raw values from a TColumn instance."""
     for attr, wrapper in iteritems(col.__dict__):
         if wrapper is not None:
-            result = wrapper.values
-            nulls = wrapper.nulls  # bit set describing what's null
-            assert isinstance(nulls, bytes)
-            for i, char in enumerate(nulls):
-                byte = ord(char) if sys.version_info[0] == 2 else char
-                for b in range(8):
-                    if byte & (1 << b):
-                        result[i * 8 + b] = None
-            converter = TYPES_CONVERTER.get(type_, None)
-            if converter and type_:
-                result = [converter(row) if row else row for row in result]
+            if attr in ['boolVal', 'byteVal', 'i16Val', 'i32Val', 'i64Val', 'doubleVal']:
+                values = wrapper.values
+                # unpack the null bitmask into a boolean array (bits are LSB-first within each byte)
+                nulls = np.unpackbits(np.frombuffer(wrapper.nulls, dtype='uint8'), bitorder='little').view(bool)
+                # build a full-length mask; trailing False values are not sent
+                mask = np.zeros(values.shape, dtype='?')
+                end = min(len(mask), len(nulls))
+                mask[:end] = nulls[:end]
+
+                # float values are transferred as double
+                if type_ == 'FLOAT_TYPE':
+                    values = values.astype('>f4')
+
+                result = pa.array(values.byteswap().view(values.dtype.newbyteorder()), mask=mask)
+
+            else:
+                result = wrapper.values
+                nulls = wrapper.nulls  # bit set describing what's null
+                if len(result) == 0:
+                    return pa.array([])
+                assert isinstance(nulls, bytes)
+                for i, char in enumerate(nulls):
+                    byte = ord(char) if sys.version_info[0] == 2 else char
+                    for b in range(8):
+                        if byte & (1 << b):
+                            result[i * 8 + b] = None
+                converter = TYPES_CONVERTER.get(type_, None)
+                if converter and type_:
+                    result = [converter(row) if row else row for row in result]
+
+                if type_ in ['ARRAY_TYPE', 'MAP_TYPE', 'STRUCT_TYPE']:
+                    fd = io.BytesIO()
+                    for row in result:
+                        if row is None:
+                            row = 'null'
+                        fd.write(f'{{"c":{row}}}\n'.encode('utf8'))
+                    fd.seek(0)
+
+                    if schema is None:
+                        # NOTE: JSON map conversion (from the original struct) is not supported
+                        result = pa.json.read_json(fd, parse_options=None)[0].combine_chunks()
+                    else:
+                        sch = pa.schema([('c', schema)])
+                        opts = pa.json.ParseOptions(explicit_schema=sch)
+                        result = pa.json.read_json(fd, parse_options=opts)[0].combine_chunks()
             return result
     raise DataError("Got empty column value {}".format(col))  # pragma: no cover
 
diff --git a/pyhive/schema.py b/pyhive/schema.py
new file mode 100644
index 00000000..cb16876b
--- /dev/null
+++ b/pyhive/schema.py
@@ -0,0 +1,97 @@
+"""
+This module attempts to reconstruct an Arrow schema from the info dumped at the beginning of a Hive query log.
+
+SUPPORTS:
+  * All primitive types _except_ INTERVAL.
+  * STRUCT and ARRAY types.
+  * Composition of any combination of previous types.
+
+LIMITATIONS:
+  * PyHive does not support INTERVAL types yet. A converter needs to be implemented.
+  * Hive always sends complex types as strings, formatted as something _similar_ to JSON.
+    * Arrow can parse most of this pseudo-JSON excluding:
+      * MAP and INTERVAL types
+    * A custom parser would be needed to implement support for all types and their composition.
+"""
+
+import pyparsing as pp
+import pyarrow as pa
+
+def a_type(s, loc, toks):
+    m_basic = {
+        'tinyint'   : pa.int8(),
+        'smallint'  : pa.int16(),
+        'int'       : pa.int32(),
+        'bigint'    : pa.int64(),
+        'float'     : pa.float32(),
+        'double'    : pa.float64(),
+        'boolean'   : pa.bool_(),
+        'string'    : pa.string(),
+        'char'      : pa.string(),
+        'varchar'   : pa.string(),
+        'binary'    : pa.binary(),
+        'timestamp' : pa.timestamp('ns'),
+        'date'      : pa.date32(),
+        #'interval_year_month' : pa.month_day_nano_interval(),
+        #'interval_day_time'   : pa.month_day_nano_interval(),
+    }
+
+    typ, args = toks[0], toks[1:]
+
+    if typ in m_basic:
+        return m_basic[typ]
+    if typ == 'decimal':
+        return pa.decimal128(*map(int, args))
+    if typ == 'array':
+        return pa.list_(args[0])
+    #if typ == 'map':
+    #    return pa.map_(args[0], args[1])
+    if typ == 'struct':
+        return pa.struct(args)
+    raise NotImplementedError(f"Type {typ} is not supported")
+
+def a_field(s, loc, toks):
+    return pa.field(toks[0], toks[1])
+
+LB, RB, LP, RP, LT, RT, COMMA, COLON = map(pp.Suppress, "[]()<>,:")
+
+def t_args(n):
+    return LP + pp.delimitedList(pp.Word(pp.nums), ",", min=n, max=n) + RP
+
+t_basic = pp.one_of(
+    "tinyint smallint int bigint float double boolean string binary timestamp date decimal",
+    caseless=True, as_keyword=True
+)
+t_interval = pp.one_of(
+    "interval_year_month interval_day_time",
+    caseless=True, as_keyword=True
+)
+t_char = pp.one_of("char varchar", caseless=True, as_keyword=True) + t_args(1)
+t_decimal = pp.CaselessKeyword("decimal") + t_args(2)
+t_primitive = (t_basic ^ t_char ^ t_decimal).set_parse_action(a_type)
+
+t_type = pp.Forward()
+
+t_label = pp.Word(pp.alphas + "_", pp.alphanums + "_")
+t_array = pp.CaselessKeyword('array') + LT + t_type + RT
+t_map = pp.CaselessKeyword('map') + LT + t_primitive + COMMA + t_type + RT
+t_struct = pp.CaselessKeyword('struct') + LT + pp.delimitedList((t_label + COLON + t_type).set_parse_action(a_field), ",") + RT
+t_complex = (t_array ^ t_map ^ t_struct).set_parse_action(a_type)
+
+t_type <<= t_primitive ^ t_complex
+t_top_type = t_type ^ t_interval
+
+l_schema, l_fieldschemas, l_fieldschema, l_name, l_type, l_comment, l_properties, l_null = map(
+    lambda x: pp.Keyword(x).suppress(), "Schema fieldSchemas FieldSchema name type comment properties null".split(' ')
+)
+t_fieldschema = l_fieldschema + LP + l_name + COLON + t_label.suppress() + COMMA + l_type + COLON + t_top_type + COMMA + l_comment + COLON + l_null + RP
+t_schema = l_schema + LP + l_fieldschemas + COLON + LB + pp.delimitedList(t_fieldschema, ',') + RB + COMMA + l_properties + COLON + l_null + RP
+
+def parse_schema(logs):
+    prefix = 'INFO : Returning Hive schema: '
+
+    for l in logs:
+        if l.startswith(prefix):
+            str_schema = l[len(prefix):]
+
+            return t_schema.parse_string(str_schema).as_list()
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..fed528d4
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools"]
+build-backend = "setuptools.build_meta"
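
Usage sketch (not part of the patch): with it applied, cursors hand back
pyarrow.Table slices, and pyhive/schema.py can recover the Arrow types of
complex columns from the query log. Host, port, and table name below are
placeholders, and reading the schema from the log assumes HiveServer2 is
configured to return operation logs containing the 'Returning Hive schema:'
line:

    from pyhive import hive, schema

    conn = hive.connect(host='localhost', port=10000)  # placeholder connection
    cur = conn.cursor()
    cur.execute('SELECT * FROM some_table')            # placeholder table

    # One Arrow type per result column, reconstructed from the query log;
    # passing it lets ARRAY/STRUCT columns be parsed with an explicit schema
    # instead of being inferred from the pseudo-JSON.
    ext_schema = schema.parse_schema(cur.fetch_logs())

    table = cur.fetchall(schema=ext_schema)  # a pyarrow.Table, or None when empty
    if table is not None:
        print(table.num_rows, table.schema)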
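The log-schema parser can also be exercised on its own; the log line below is a
hand-written example of the shape parse_schema expects, not captured server
output:

    import pyarrow as pa
    from pyhive.schema import parse_schema

    logs = [
        'INFO : Returning Hive schema: '
        'Schema(fieldSchemas:[FieldSchema(name:id, type:bigint, comment:null), '
        'FieldSchema(name:tags, type:array<string>, comment:null)], properties:null)',
    ]

    # Expected: one Arrow type per column, i.e. [pa.int64(), pa.list_(pa.string())]
    print(parse_schema(logs))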