Skip to content

Commit

Permalink
Some more improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
raulcd committed Jun 19, 2024
1 parent a6371f3 commit b0617b9
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 95 deletions.
6 changes: 2 additions & 4 deletions python/pyarrow/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@
from pyarrow import Codec
from pyarrow import fs

try:
import numpy as np
except ImportError:
pass

groups = [
'acero',
Expand Down Expand Up @@ -306,6 +302,7 @@ def unary_agg_func_fixture():
Register a unary aggregate function (mean)
"""
from pyarrow import compute as pc
import numpy as np

def func(ctx, x):
return pa.scalar(np.nanmean(x))
Expand All @@ -331,6 +328,7 @@ def varargs_agg_func_fixture():
Register a unary aggregate function
"""
from pyarrow import compute as pc
import numpy as np

def func(ctx, *args):
sum = 0.0
Expand Down
80 changes: 47 additions & 33 deletions python/pyarrow/pandas_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,17 @@
import re
import warnings

import numpy as np

try:
import numpy as np
except ImportError:
np = None
import pyarrow as pa
from pyarrow.lib import _pandas_api, frombytes # noqa


_logical_type_map = {}
_numpy_logical_type_map = {}
_pandas_logical_type_map = {}


def get_logical_type_map():
Expand Down Expand Up @@ -85,27 +89,32 @@ def get_logical_type(arrow_type):
return 'object'


_numpy_logical_type_map = {
np.bool_: 'bool',
np.int8: 'int8',
np.int16: 'int16',
np.int32: 'int32',
np.int64: 'int64',
np.uint8: 'uint8',
np.uint16: 'uint16',
np.uint32: 'uint32',
np.uint64: 'uint64',
np.float32: 'float32',
np.float64: 'float64',
'datetime64[D]': 'date',
np.str_: 'string',
np.bytes_: 'bytes',
}
def get_numpy_logical_type_map():
global _numpy_logical_type_map
if not _numpy_logical_type_map:
_numpy_logical_type_map.update({
np.bool_: 'bool',
np.int8: 'int8',
np.int16: 'int16',
np.int32: 'int32',
np.int64: 'int64',
np.uint8: 'uint8',
np.uint16: 'uint16',
np.uint32: 'uint32',
np.uint64: 'uint64',
np.float32: 'float32',
np.float64: 'float64',
'datetime64[D]': 'date',
np.str_: 'string',
np.bytes_: 'bytes',
})
return _numpy_logical_type_map


def get_logical_type_from_numpy(pandas_collection):
numpy_logical_type_map = get_numpy_logical_type_map()
try:
return _numpy_logical_type_map[pandas_collection.dtype.type]
return numpy_logical_type_map[pandas_collection.dtype.type]
except KeyError:
if hasattr(pandas_collection.dtype, 'tz'):
return 'datetimetz'
Expand Down Expand Up @@ -1019,19 +1028,23 @@ def _is_generated_index_name(name):
pattern = r'^__index_level_\d+__$'
return re.match(pattern, name) is not None


_pandas_logical_type_map = {
'date': 'datetime64[D]',
'datetime': 'datetime64[ns]',
'datetimetz': 'datetime64[ns]',
'unicode': np.str_,
'bytes': np.bytes_,
'string': np.str_,
'integer': np.int64,
'floating': np.float64,
'decimal': np.object_,
'empty': np.object_,
}
def get_pandas_logical_type_map():
global _pandas_logical_type_map

if not _pandas_logical_type_map:
_pandas_logical_type_map.update({
'date': 'datetime64[D]',
'datetime': 'datetime64[ns]',
'datetimetz': 'datetime64[ns]',
'unicode': np.str_,
'bytes': np.bytes_,
'string': np.str_,
'integer': np.int64,
'floating': np.float64,
'decimal': np.object_,
'empty': np.object_,
})
return _pandas_logical_type_map


def _pandas_type_to_numpy_type(pandas_type):
Expand All @@ -1047,8 +1060,9 @@ def _pandas_type_to_numpy_type(pandas_type):
dtype : np.dtype
The dtype that corresponds to `pandas_type`.
"""
pandas_logical_type_map = get_pandas_logical_type_map()
try:
return _pandas_logical_type_map[pandas_type]
return pandas_logical_type_map[pandas_type]
except KeyError:
if 'mixed' in pandas_type:
# catching 'mixed', 'mixed-integer' and 'mixed-integer-float'
Expand Down
7 changes: 2 additions & 5 deletions python/pyarrow/tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import datetime
import math
import sys

import pytest
Expand All @@ -38,10 +39,6 @@
import tzdata # noqa:F401
except ImportError:
zoneinfo = None
try:
import numpy as np
except ImportError:
pass

import pyarrow as pa

Expand Down Expand Up @@ -282,7 +279,7 @@ def arrays(draw, type, size=None, nullable=True):
values = draw(npst.arrays(ty.to_pandas_dtype(), shape=(size,)))
# Workaround ARROW-4952: no easy way to assert array equality
# in a NaN-tolerant way.
values[np.isnan(values)] = -42.0
values[math.isnan(values)] = -42.0
return pa.array(values, type=ty)
elif pa.types.is_decimal(ty):
# TODO(kszucs): properly limit the precision
Expand Down
17 changes: 2 additions & 15 deletions python/pyarrow/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@
import contextlib
import decimal
import gc
try:
import numpy as np
except ImportError:
pass
import os
import random
import re
Expand Down Expand Up @@ -113,29 +109,20 @@ def randdecimal(precision, scale):


def random_ascii(length):
import numpy as np
return bytes(np.random.randint(65, 123, size=length, dtype='i1'))


def rands(nchars):
"""
Generate one random string.
"""
import numpy as np
RANDS_CHARS = np.array(
list(string.ascii_letters + string.digits), dtype=(np.str_, 1))
return "".join(np.random.choice(RANDS_CHARS, nchars))


def make_dataframe():
import pandas as pd

N = 30
df = pd.DataFrame(
{col: np.random.randn(N) for col in string.ascii_uppercase[:4]},
index=pd.Index([rands(10) for _ in range(N)])
)
return df


def memory_leak_check(f, metric='rss', threshold=1 << 17, iterations=10,
check_interval=1):
"""
Expand Down
80 changes: 42 additions & 38 deletions python/pyarrow/types.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -33,45 +33,48 @@ from cython import sizeof

# These are imprecise because the type (in pandas 0.x) depends on the presence
# of nulls
cdef dict _pandas_type_map = {}


def _get_pandas_type_map():
cdef dict _pandas_type_map = {
_Type_NA: np.object_, # NaNs
_Type_BOOL: np.bool_,
_Type_INT8: np.int8,
_Type_INT16: np.int16,
_Type_INT32: np.int32,
_Type_INT64: np.int64,
_Type_UINT8: np.uint8,
_Type_UINT16: np.uint16,
_Type_UINT32: np.uint32,
_Type_UINT64: np.uint64,
_Type_HALF_FLOAT: np.float16,
_Type_FLOAT: np.float32,
_Type_DOUBLE: np.float64,
# Pandas does not support [D]ay, so default to [ms] for date32
_Type_DATE32: np.dtype('datetime64[ms]'),
_Type_DATE64: np.dtype('datetime64[ms]'),
_Type_TIMESTAMP: {
's': np.dtype('datetime64[s]'),
'ms': np.dtype('datetime64[ms]'),
'us': np.dtype('datetime64[us]'),
'ns': np.dtype('datetime64[ns]'),
},
_Type_DURATION: {
's': np.dtype('timedelta64[s]'),
'ms': np.dtype('timedelta64[ms]'),
'us': np.dtype('timedelta64[us]'),
'ns': np.dtype('timedelta64[ns]'),
},
_Type_BINARY: np.object_,
_Type_FIXED_SIZE_BINARY: np.object_,
_Type_STRING: np.object_,
_Type_LIST: np.object_,
_Type_MAP: np.object_,
_Type_DECIMAL128: np.object_,
}
global _pandas_type_map
if not _pandas_type_map:
_pandas_type_map.update({
_Type_NA: np.object_, # NaNs
_Type_BOOL: np.bool_,
_Type_INT8: np.int8,
_Type_INT16: np.int16,
_Type_INT32: np.int32,
_Type_INT64: np.int64,
_Type_UINT8: np.uint8,
_Type_UINT16: np.uint16,
_Type_UINT32: np.uint32,
_Type_UINT64: np.uint64,
_Type_HALF_FLOAT: np.float16,
_Type_FLOAT: np.float32,
_Type_DOUBLE: np.float64,
# Pandas does not support [D]ay, so default to [ms] for date32
_Type_DATE32: np.dtype('datetime64[ms]'),
_Type_DATE64: np.dtype('datetime64[ms]'),
_Type_TIMESTAMP: {
's': np.dtype('datetime64[s]'),
'ms': np.dtype('datetime64[ms]'),
'us': np.dtype('datetime64[us]'),
'ns': np.dtype('datetime64[ns]'),
},
_Type_DURATION: {
's': np.dtype('timedelta64[s]'),
'ms': np.dtype('timedelta64[ms]'),
'us': np.dtype('timedelta64[us]'),
'ns': np.dtype('timedelta64[ns]'),
},
_Type_BINARY: np.object_,
_Type_FIXED_SIZE_BINARY: np.object_,
_Type_STRING: np.object_,
_Type_LIST: np.object_,
_Type_MAP: np.object_,
_Type_DECIMAL128: np.object_,
})
return _pandas_type_map


Expand Down Expand Up @@ -154,14 +157,15 @@ def _is_primitive(Type type):

def _get_pandas_type(arrow_type, coerce_to_ns=False):
cdef Type type_id = arrow_type.id
if type_id not in _get_pandas_type_map():
cdef dict pandas_type_map = _get_pandas_type_map()
if type_id not in pandas_type_map:
return None
if coerce_to_ns:
# ARROW-3789: Coerce date/timestamp types to datetime64[ns]
if type_id == _Type_DURATION:
return np.dtype('timedelta64[ns]')
return np.dtype('datetime64[ns]')
pandas_type = _get_pandas_type_map()[type_id]
pandas_type = pandas_type_map[type_id]
if isinstance(pandas_type, dict):
unit = getattr(arrow_type, 'unit', None)
pandas_type = pandas_type.get(unit, None)
Expand Down

0 comments on commit b0617b9

Please sign in to comment.