Skip to content

Commit

Permalink
Merge pull request #275 from markfairbanks/refactor-case_when
Browse files Browse the repository at this point in the history
Refactor `case_when()`
  • Loading branch information
markfairbanks authored Oct 26, 2024
2 parents c43709f + 9ae9692 commit 2224255
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 11 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

## v0.3.1 (in development)

#### Functionality improvements

* `if_else()` treats string inputs in `true` and `false` as strings and not as columns
* `case_when()` no has syntax closer to `dplyr::case_when()`

## v0.3.0

* Major refactor to work with `polars>=1.0.0`
Expand Down
6 changes: 3 additions & 3 deletions tests/test_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def test_agg_stats():
def test_case_when():
"""Can use case_when"""
df = tp.tibble(x = range(1, 4))
actual = df.mutate(case_x = tp.case_when(col('x') < 2).then(0)
.when(col('x') < 3).then(1)
.otherwise(0))
actual = df.mutate(case_x = tp.case_when(col('x') < 2, 0,
col('x') < 3, 1,
_default = 0))
expected = tp.tibble(x = range(1, 4), case_x = [0, 1, 0])
assert actual.equals(expected), "case_when failed"

Expand Down
26 changes: 19 additions & 7 deletions tidypolars/funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
_is_constant,
_is_list,
_is_iterable,
_is_series
_is_series,
_is_string,
_str_to_lit
)

__all__ = [
Expand Down Expand Up @@ -169,7 +171,7 @@ def between(x, left, right):
x = _col_expr(x)
return x.is_between(left, right)

def case_when(expr):
def case_when(*args, _default = pl.Null):
"""
Case when
Expand All @@ -182,12 +184,22 @@ def case_when(expr):
--------
>>> df = tp.tibble(x = range(1, 4))
>>> df.mutate(
>>> case_x = tp.case_when(col('x') < 2).then(1)
>>> .when(col('x') < 3).then(2)
>>> .otherwise(0)
>>> case_x = tp.case_when(col('x') < 2, 1,
>>> col('x') < 3, 2,
>>> _default = 0)
>>> )
"""
return pl.when(expr)
conditions = [args[i] for i in range(0, len(args), 2)]
values = [args[i] for i in range(1, len(args), 2)]
values = [_str_to_lit(value) for value in values]
for i in range(len(conditions)):
if i == 0:
expr = pl.when(conditions[i]).then(values[i])
else:
expr = expr.when(conditions[i]).then(values[i])
_default = _str_to_lit(_default)
expr = expr.otherwise(_default)
return expr

def cast(x, dtype):
"""
Expand Down Expand Up @@ -333,7 +345,7 @@ def if_else(condition, true, false):
>>> df = tp.tibble(x = range(1, 4))
>>> df.mutate(if_x = tp.if_else(col('x') < 2, 1, 2))
"""
return pl.when(condition).then(true).otherwise(false)
return case_when(condition, true, _default = false)

def is_finite(x):
"""
Expand Down
7 changes: 6 additions & 1 deletion tidypolars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,9 @@ def _repeat(x, times):
def _mutate_cols(df, exprs):
for expr in exprs:
df = df.with_columns(expr)
return df
return df

def _str_to_lit(x):
if _is_string(x):
x = pl.lit(x)
return x

0 comments on commit 2224255

Please sign in to comment.