Skip to content

Commit

Permalink
unxfail for dask, but is it a good idea?
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Feb 13, 2025
1 parent caf8bf9 commit 0bf0103
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
16 changes: 10 additions & 6 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
from narwhals._expression_parsing import combine_evaluate_output_names
from narwhals.typing import CompliantNamespace

if TYPE_CHECKING:
try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError:
import dask_expr as dx
try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError:
import dask_expr as dx

if TYPE_CHECKING:
from typing_extensions import Self

from narwhals.dtypes import DType
Expand Down Expand Up @@ -157,7 +157,11 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:

def sum_horizontal(self: Self, *exprs: DaskExpr) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series = [s.fillna(0) for _expr in exprs for s in _expr(df)]
series = [
s.fillna(0) if isinstance(s, dx.Series) else 0 if s is None else s
for _expr in exprs
for s in _expr(df)
]
return [reduce(operator.add, series)]

return DaskExpr(
Expand Down
8 changes: 2 additions & 6 deletions tests/expr_and_series/sum_horizontal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,7 @@ def test_sumh_all(constructor: Constructor) -> None:
assert_equal_data(result, expected)


def test_sumh_aggregations(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_sumh_aggregations(constructor: Constructor) -> None:
data = {"a": [1, 2, 3], "b": [10, 20, 30]}
df = nw.from_native(constructor(data))
result = df.select(nw.sum_horizontal(nw.all().mean().name.suffix("_foo")))
Expand All @@ -64,7 +60,7 @@ def test_sumh_aggregations(
def test_sumh_transformations(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if any(x in str(constructor) for x in ("dask", "duckdb")):
if any(x in str(constructor) for x in ("duckdb",)):
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}
df = nw.from_native(constructor(data))
Expand Down

0 comments on commit 0bf0103

Please sign in to comment.