diff --git a/CHANGELOG.md b/CHANGELOG.md index a16914c..67c3b66 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ ## v0.3.1 (in development) +#### Functionality improvements + +* `if_else()` treats string inputs in `true` and `false` as strings and not as columns +* `case_when()` now has syntax closer to `dplyr::case_when()` + ## v0.3.0 * Major refactor to work with `polars>=1.0.0` diff --git a/tests/test_funs.py b/tests/test_funs.py index d7f394b..262f661 100644 --- a/tests/test_funs.py +++ b/tests/test_funs.py @@ -56,9 +56,9 @@ def test_agg_stats(): def test_case_when(): """Can use case_when""" df = tp.tibble(x = range(1, 4)) - actual = df.mutate(case_x = tp.case_when(col('x') < 2).then(0) - .when(col('x') < 3).then(1) - .otherwise(0)) + actual = df.mutate(case_x = tp.case_when(col('x') < 2, 0, + col('x') < 3, 1, + _default = 0)) expected = tp.tibble(x = range(1, 4), case_x = [0, 1, 0]) assert actual.equals(expected), "case_when failed" diff --git a/tidypolars/funs.py b/tidypolars/funs.py index 80a912f..6c26a28 100644 --- a/tidypolars/funs.py +++ b/tidypolars/funs.py @@ -7,7 +7,9 @@ _is_constant, _is_list, _is_iterable, - _is_series + _is_series, + _is_string, + _str_to_lit ) __all__ = [ @@ -169,7 +171,7 @@ def between(x, left, right): x = _col_expr(x) return x.is_between(left, right) -def case_when(expr): +def case_when(*args, _default = pl.Null): """ Case when @@ -182,12 +184,22 @@ def case_when(expr): -------- >>> df = tp.tibble(x = range(1, 4)) >>> df.mutate( - >>> case_x = tp.case_when(col('x') < 2).then(1) - >>> .when(col('x') < 3).then(2) - >>> .otherwise(0) + >>> case_x = tp.case_when(col('x') < 2, 1, + >>> col('x') < 3, 2, + >>> _default = 0) >>> ) """ - return pl.when(expr) + conditions = [args[i] for i in range(0, len(args), 2)] + values = [args[i] for i in range(1, len(args), 2)] + values = [_str_to_lit(value) for value in values] + for i in range(len(conditions)): + if i == 0: + expr = pl.when(conditions[i]).then(values[i]) + else: + expr 
= expr.when(conditions[i]).then(values[i]) + _default = _str_to_lit(_default) + expr = expr.otherwise(_default) + return expr def cast(x, dtype): """ @@ -333,7 +345,7 @@ def if_else(condition, true, false): >>> df = tp.tibble(x = range(1, 4)) >>> df.mutate(if_x = tp.if_else(col('x') < 2, 1, 2)) """ - return pl.when(condition).then(true).otherwise(false) + return case_when(condition, true, _default = false) def is_finite(x): """ diff --git a/tidypolars/utils.py b/tidypolars/utils.py index 38dbc76..5684d82 100644 --- a/tidypolars/utils.py +++ b/tidypolars/utils.py @@ -109,4 +109,9 @@ def _repeat(x, times): def _mutate_cols(df, exprs): for expr in exprs: df = df.with_columns(expr) - return df \ No newline at end of file + return df + +def _str_to_lit(x): + if _is_string(x): + x = pl.lit(x) + return x \ No newline at end of file