Skip to content

Commit

Permalink
numpy2 (#234)
Browse files Browse the repository at this point in the history
* uprev

* numpy2

* no np2

* test nmpy 2

* np2
  • Loading branch information
andrewgsavage authored Jun 15, 2024
1 parent 474bfbe commit 7125454
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
numpy: ["numpy>=1.20.3,<2.0.0"]
numpy: ["numpy>=1.20.3,<2.0.0", "numpy==2.0.0.rc2"]
pandas: ["pandas==2.2.2"]
pint: ["pint==0.24"]

Expand Down
6 changes: 4 additions & 2 deletions pint_pandas/pint_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,8 @@ def __setitem__(self, key, value):
except IndexError as e:
msg = "Mask is wrong length. {}".format(e)
raise IndexError(msg)
except TypeError as e:
raise ValueError(e)

def _formatter(self, boxed=False):
"""Formatting function for scalar values.
Expand Down Expand Up @@ -833,7 +835,7 @@ def __array__(self, dtype=None, copy=False):
return self._to_array_of_quantity(copy=copy)
if is_string_dtype(dtype):
return np.array([str(x) for x in self.quantity], dtype=str)
return np.array(self._data, dtype=dtype, copy=copy)
return np.array(self._data, dtype=dtype)

def _to_array_of_quantity(self, copy=False):
qtys = [
Expand All @@ -843,7 +845,7 @@ def _to_array_of_quantity(self, copy=False):
for item in self._data
]
with warnings.catch_warnings(record=True):
return np.array(qtys, dtype="object", copy=copy)
return np.array(qtys, dtype="object")

def searchsorted(self, value, side="left", sorter=None):
"""
Expand Down
3 changes: 2 additions & 1 deletion pint_pandas/testsuite/test_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ def test_reductions(self, reduction):
tp = self._timeit(getattr(s_pint, reduction)).to("ms")
t = self._timeit(getattr(s, reduction)).to("ms")

assert tp <= 5 * t
if t > 0:
assert tp <= 5 * t


def test_issue_86():
Expand Down

0 comments on commit 7125454

Please sign in to comment.