Skip to content

Commit

Permalink
Fix typing
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Aug 1, 2024
1 parent beafc54 commit fac3f73
Show file tree
Hide file tree
Showing 18 changed files with 115 additions and 136 deletions.
14 changes: 6 additions & 8 deletions py-polars/polars/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from functools import reduce
from operator import or_
from typing import TYPE_CHECKING, Callable, TypeVar
from typing import TYPE_CHECKING, Callable, Generic, TypeVar
from warnings import warn

import polars._reexport as pl
Expand All @@ -11,8 +11,6 @@
if TYPE_CHECKING:
from polars import DataFrame, Expr, LazyFrame, Series

NS = TypeVar("NS")


__all__ = [
"register_expr_namespace",
Expand All @@ -24,14 +22,14 @@
# do not allow override of polars' own namespaces (as registered by '_accessors')
_reserved_namespaces: set[str] = reduce(
or_,
(
cls._accessors # type: ignore[attr-defined]
for cls in (pl.DataFrame, pl.Expr, pl.LazyFrame, pl.Series)
),
(cls._accessors for cls in (pl.DataFrame, pl.Expr, pl.LazyFrame, pl.Series)),
)


class NameSpace:
NS = TypeVar("NS")


class NameSpace(Generic[NS]):
"""Establish property-like namespace object for user-defined functionality."""

def __init__(self, name: str, namespace: type[NS]) -> None:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/convert/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ def _from_dataframe_repr(m: re.Match[str]) -> DataFrame:
if el in headers:
idx = headers.index(el)
for table_elem in (headers, dtypes):
table_elem.pop(idx) # type: ignore[attr-defined]
table_elem.pop(idx)
if coldata:
coldata.pop(idx)

Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,7 @@ def __add__(
other = _prepare_other_arg(other)
return self._from_pydf(self._df.add(other._s))

def __radd__( # type: ignore[misc]
def __radd__(
self, other: DataFrame | Series | int | float | bool | str
) -> DataFrame:
if isinstance(other, str):
Expand Down
5 changes: 1 addition & 4 deletions py-polars/polars/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,7 @@


@overload
def is_selector(obj: _selector_proxy_) -> Literal[True]: # type: ignore[overload-overlap]
...


def is_selector(obj: _selector_proxy_) -> Literal[True]: ...
@overload
def is_selector(obj: Any) -> Literal[False]: ...

Expand Down
65 changes: 22 additions & 43 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series:
return self._from_pyseries(f(other))

@overload # type: ignore[override]
def __eq__(self, other: Expr) -> Expr: ... # type: ignore[overload-overlap]
def __eq__(self, other: Expr) -> Expr: ...

@overload
def __eq__(self, other: Any) -> Series: ...
Expand All @@ -773,8 +773,7 @@ def __eq__(self, other: Any) -> Series | Expr:
return self._comp(other, "eq")

@overload # type: ignore[override]
def __ne__(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def __ne__(self, other: Expr) -> Expr: ...

@overload
def __ne__(self, other: Any) -> Series: ...
Expand All @@ -786,8 +785,7 @@ def __ne__(self, other: Any) -> Series | Expr:
return self._comp(other, "neq")

@overload
def __gt__(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def __gt__(self, other: Expr) -> Expr: ...

@overload
def __gt__(self, other: Any) -> Series: ...
Expand All @@ -799,8 +797,7 @@ def __gt__(self, other: Any) -> Series | Expr:
return self._comp(other, "gt")

@overload
def __lt__(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def __lt__(self, other: Expr) -> Expr: ...

@overload
def __lt__(self, other: Any) -> Series: ...
Expand All @@ -812,8 +809,7 @@ def __lt__(self, other: Any) -> Series | Expr:
return self._comp(other, "lt")

@overload
def __ge__(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def __ge__(self, other: Expr) -> Expr: ...

@overload
def __ge__(self, other: Any) -> Series: ...
Expand All @@ -825,8 +821,7 @@ def __ge__(self, other: Any) -> Series | Expr:
return self._comp(other, "gt_eq")

@overload
def __le__(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def __le__(self, other: Expr) -> Expr: ...

@overload
def __le__(self, other: Any) -> Series: ...
Expand All @@ -838,8 +833,7 @@ def __le__(self, other: Any) -> Series | Expr:
return self._comp(other, "lt_eq")

@overload
def le(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def le(self, other: Expr) -> Expr: ...

@overload
def le(self, other: Any) -> Series: ...
Expand All @@ -849,8 +843,7 @@ def le(self, other: Any) -> Series | Expr:
return self.__le__(other)

@overload
def lt(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def lt(self, other: Expr) -> Expr: ...

@overload
def lt(self, other: Any) -> Series: ...
Expand All @@ -860,8 +853,7 @@ def lt(self, other: Any) -> Series | Expr:
return self.__lt__(other)

@overload
def eq(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def eq(self, other: Expr) -> Expr: ...

@overload
def eq(self, other: Any) -> Series: ...
Expand All @@ -871,8 +863,7 @@ def eq(self, other: Any) -> Series | Expr:
return self.__eq__(other)

@overload
def eq_missing(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def eq_missing(self, other: Expr) -> Expr: ...

@overload
def eq_missing(self, other: Any) -> Series: ...
Expand Down Expand Up @@ -919,8 +910,7 @@ def eq_missing(self, other: Any) -> Series | Expr:
return self.to_frame().select(F.col(self.name).eq_missing(other)).to_series()

@overload
def ne(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def ne(self, other: Expr) -> Expr: ...

@overload
def ne(self, other: Any) -> Series: ...
Expand All @@ -930,8 +920,7 @@ def ne(self, other: Any) -> Series | Expr:
return self.__ne__(other)

@overload
def ne_missing(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def ne_missing(self, other: Expr) -> Expr: ...

@overload
def ne_missing(self, other: Any) -> Series: ...
Expand Down Expand Up @@ -978,8 +967,7 @@ def ne_missing(self, other: Any) -> Series | Expr:
return self.to_frame().select(F.col(self.name).ne_missing(other)).to_series()

@overload
def ge(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def ge(self, other: Expr) -> Expr: ...

@overload
def ge(self, other: Any) -> Series: ...
Expand All @@ -989,8 +977,7 @@ def ge(self, other: Any) -> Series | Expr:
return self.__ge__(other)

@overload
def gt(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def gt(self, other: Expr) -> Expr: ...

@overload
def gt(self, other: Any) -> Series: ...
Expand Down Expand Up @@ -1043,12 +1030,10 @@ def _arithmetic(self, other: Any, op_s: str, op_ffi: str) -> Self:
return self._from_pyseries(f(other))

@overload
def __add__(self, other: DataFrame) -> DataFrame: # type: ignore[overload-overlap]
...
def __add__(self, other: DataFrame) -> DataFrame: ...

@overload
def __add__(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def __add__(self, other: Expr) -> Expr: ...

@overload
def __add__(self, other: Any) -> Self: ...
Expand All @@ -1063,8 +1048,7 @@ def __add__(self, other: Any) -> Self | DataFrame | Expr:
return self._arithmetic(other, "add", "add_<>")

@overload
def __sub__(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def __sub__(self, other: Expr) -> Expr: ...

@overload
def __sub__(self, other: Any) -> Self: ...
Expand All @@ -1075,8 +1059,7 @@ def __sub__(self, other: Any) -> Self | Expr:
return self._arithmetic(other, "sub", "sub_<>")

@overload
def __truediv__(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def __truediv__(self, other: Expr) -> Expr: ...

@overload
def __truediv__(self, other: Any) -> Series: ...
Expand All @@ -1095,8 +1078,7 @@ def __truediv__(self, other: Any) -> Series | Expr:
return self.cast(Float64) / other

@overload
def __floordiv__(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def __floordiv__(self, other: Expr) -> Expr: ...

@overload
def __floordiv__(self, other: Any) -> Series: ...
Expand All @@ -1116,12 +1098,10 @@ def __invert__(self) -> Series:
return self.not_()

@overload
def __mul__(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def __mul__(self, other: Expr) -> Expr: ...

@overload
def __mul__(self, other: DataFrame) -> DataFrame: # type: ignore[overload-overlap]
...
def __mul__(self, other: DataFrame) -> DataFrame: ...

@overload
def __mul__(self, other: Any) -> Series: ...
Expand All @@ -1138,8 +1118,7 @@ def __mul__(self, other: Any) -> Series | DataFrame | Expr:
return self._arithmetic(other, "mul", "mul_<>")

@overload
def __mod__(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
def __mod__(self, other: Expr) -> Expr: ...

@overload
def __mod__(self, other: Any) -> Series: ...
Expand Down
10 changes: 5 additions & 5 deletions py-polars/tests/unit/constructors/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,10 @@ class TradeNT(NamedTuple):
columns = ["timestamp", "ticker", "price", "size"]

for TradeClass in (TradeDC, TradeNT, TradePD):
trades = [TradeClass(**dict(zip(columns, values))) for values in raw_data]
trades = [TradeClass(**dict(zip(columns, values))) for values in raw_data] # type: ignore[arg-type]

for DF in (pl.DataFrame, pl.from_records):
df = DF(data=trades) # type: ignore[operator]
df = DF(data=trades)
assert df.schema == {
"timestamp": pl.Datetime("us"),
"ticker": pl.String,
Expand All @@ -229,7 +229,7 @@ class TradeNT(NamedTuple):
assert df.rows() == raw_data

# partial dtypes override
df = DF( # type: ignore[operator]
df = DF(
data=trades,
schema_overrides={"timestamp": pl.Datetime("ms"), "size": pl.Int32},
)
Expand Down Expand Up @@ -1041,13 +1041,13 @@ def test_init_records_schema_order() -> None:
shuffle(data)
shuffle(cols)

df = constructor(data, schema=cols) # type: ignore[operator]
df = constructor(data, schema=cols)
for col in df.columns:
assert all(value in (None, lookup[col]) for value in df[col].to_list())

# have schema override inferred types, omit some columns, add a new one
schema = {"a": pl.Int8, "c": pl.Int16, "e": pl.Int32}
df = constructor(data, schema=schema) # type: ignore[operator]
df = constructor(data, schema=schema)

assert df.schema == schema
for col in df.columns:
Expand Down
23 changes: 12 additions & 11 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,15 @@ def test_null_count() -> None:
assert df.null_count().row(np.int64(0)) == (0, 1) # type: ignore[call-overload]


def test_init_empty() -> None:
@pytest.mark.parametrize("input", [None, (), [], {}, pa.Table.from_arrays([])])
def test_init_empty(input: Any) -> None:
# test various flavours of empty init
for empty in (None, (), [], {}, pa.Table.from_arrays([])):
df = pl.DataFrame(empty)
assert df.shape == (0, 0)
assert df.is_empty()
df = pl.DataFrame(input)
assert df.shape == (0, 0)
assert df.is_empty()


# note: cannot use df (empty or otherwise) in boolean context
def test_df_bool_ambiguous() -> None:
empty_df = pl.DataFrame()
with pytest.raises(TypeError, match="ambiguous"):
not empty_df
Expand Down Expand Up @@ -1185,7 +1186,7 @@ def test_from_rows_of_dicts() -> None:
{"id": 2, "value": 101, "_meta": "b"},
]
df_init: Callable[..., Any]
for df_init in (pl.from_dicts, pl.DataFrame): # type:ignore[assignment]
for df_init in (pl.from_dicts, pl.DataFrame):
df1 = df_init(records)
assert df1.rows() == [(1, 100, "a"), (2, 101, "b")]

Expand Down Expand Up @@ -2168,12 +2169,12 @@ def test_selection_misc() -> None:

# literal values (as scalar/list)
for zero in (0, [0]):
assert df.select(zero)["literal"].to_list() == [0] # type: ignore[arg-type]
assert df.select(zero)["literal"].to_list() == [0]
assert df.select(literal=0)["literal"].to_list() == [0]

# expect string values to be interpreted as cols
for x in ("x", ["x"], pl.col("x")):
assert df.select(x).rows() == [("abc",)] # type: ignore[arg-type]
assert df.select(x).rows() == [("abc",)]

# string col + lit
assert df.with_columns(["x", 0]).to_dicts() == [{"x": "abc", "literal": 0}]
Expand Down Expand Up @@ -2476,7 +2477,7 @@ def test_init_datetimes_with_timezone() -> None:
}
},
):
result = pl.DataFrame( # type: ignore[arg-type]
result = pl.DataFrame(
data={
"d1": [dtm.replace(tzinfo=ZoneInfo(tz_us))],
"d2": [dtm.replace(tzinfo=ZoneInfo(tz_europe))],
Expand Down Expand Up @@ -2739,7 +2740,7 @@ def test_unstack() -> None:
assert df.unstack(
step=3,
how="horizontal",
columns=column_subset, # type: ignore[arg-type]
columns=column_subset,
).to_dict(as_series=False) == {
"col2_0": [0, 3, 6],
"col2_1": [1, 4, 7],
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/datatypes/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class Y:
row_data = [(d,) for d in data]
for cls in (X, Y):
for ctor in (pl.DataFrame, pl.from_records):
df = ctor(data=list(map(cls, data))) # type: ignore[operator]
df = ctor(data=list(map(cls, data)))
assert df.schema == {
"a": pl.Decimal(scale=7),
}
Expand Down
Loading

0 comments on commit fac3f73

Please sign in to comment.