Skip to content

Commit

Permalink
fix(typing): misc assignment/redef errors
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned committed Feb 13, 2025
1 parent 51ad255 commit 658c963
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 13 deletions.
9 changes: 3 additions & 6 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,17 +747,14 @@ def unique(

agg_func = agg_func_map[keep]
col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
keep_idx = (
keep_idx_native = (
df.append_column(col_token, pa.array(np.arange(len(self))))
.group_by(subset)
.aggregate([(col_token, agg_func)])
.column(f"{col_token}_{agg_func}")
)

return self._from_native_frame(
pc.take(df, keep_idx), # type: ignore[call-overload, unused-ignore]
validate_column_names=False,
)
indices = cast("Indices", keep_idx_native)
return self._from_native_frame(df.take(indices), validate_column_names=False)

keep_idx = self.simple_select(*subset).is_unique()
plx = self.__narwhals_namespace__()
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_arrow/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:

function_name = re.sub(r"(\w+->)", "", expr._function_name)
if function_name in {"std", "var"}:
option = pc.VarianceOptions(ddof=expr._kwargs["ddof"])
option: Any = pc.VarianceOptions(ddof=expr._kwargs["ddof"])
elif function_name in {"len", "n_unique"}:
option = pc.CountOptions(mode="all")
elif function_name == "count":
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,10 +433,10 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries[Any]]:
if self._otherwise_value is None:
# NOTE: Casting just to match *some overload*, as the series type isn't known statically
null_value = cast("pa.NullScalar", lit(None, type=value_series_native.type))
otherwise_native = pa.repeat(null_value, len(condition_native))
otherwise_null = pa.repeat(null_value, len(condition_native))
return [
value_series._from_native_series(
pc.if_else(condition_native, value_series_native, otherwise_native)
pc.if_else(condition_native, value_series_native, otherwise_null)
)
]
if isinstance(self._otherwise_value, ArrowExpr):
Expand Down
12 changes: 8 additions & 4 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,9 +1094,11 @@ def rank(
sort_keys = "descending" if descending else "ascending"
tiebreaker = "first" if method == "ordinal" else method

native_series = self._native_series
native_series: pa.ChunkedArray[_ScalarT_co] | pa.Array[_ScalarT_co]
if self._backend_version < (14, 0, 0): # pragma: no cover
native_series = native_series.combine_chunks()
native_series = self._native_series.combine_chunks()
else:
native_series = self._native_series

null_mask = pc.is_null(native_series)

Expand Down Expand Up @@ -1141,13 +1143,15 @@ def _hist_from_bin_count(bin_count: int): # type: ignore[no-untyped-def] # noqa
),
width,
)
bin_indices = cast("pa.ChunkedArray[Any]", pc.floor(bin_proportions))
bin_indices: pa.ChunkedArray[Any] = cast(
"pa.ChunkedArray[Any]", pc.floor(bin_proportions)
)

# NOTE: stubs leave unannotated
if_else: Incomplete = pc.if_else

# shift bins so they are right-closed
bin_indices: pa.ChunkedArray[Any] = if_else(
bin_indices = if_else(
pc.and_(
pc.equal(bin_indices, bin_proportions),
pc.greater(bin_indices, 0),
Expand Down

0 comments on commit 658c963

Please sign in to comment.