Skip to content

Commit

Permalink
Fix mask and value to support Series and Index
Browse files Browse the repository at this point in the history
  • Loading branch information
beobest2 committed Jun 15, 2020
1 parent 5c31eb3 commit 3771611
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 18 deletions.
45 changes: 32 additions & 13 deletions databricks/koalas/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1558,8 +1558,11 @@ def putmask(self, mask, value):
>>> kidx
Index(['a', 'b', 'c', 'd', 'e'], dtype='object')
>>> kidx.putmask([True if x < 2 else False for x in range(5)], "Koalas").sort_values()
>>> kidx.putmask(kidx < 'c', "Koalas").sort_values()
Index(['Koalas', 'Koalas', 'c', 'd', 'e'], dtype='object')
>>> kidx.putmask(kidx < 'c', ks.Index([100, 200, 300, 400, 500])).sort_values()
Index(['100', '200', 'c', 'd', 'e'], dtype='object')
"""
scol_name = self._internal.index_spark_column_names[0]
sdf = self._internal.spark_frame.select(self.spark.column)
Expand All @@ -1569,22 +1572,38 @@ def putmask(self, mask, value):
sdf, column_name=dist_sequence_col_name
)

replace_col = verify_temp_column_name(sdf, "__replace_column__")
masking_col = verify_temp_column_name(sdf, "__masking_column__")
masking_udf = udf(lambda x: mask[x], BooleanType())

if isinstance(value, (list, tuple)):
replace_udf = udf(lambda x: value[x])
sdf = sdf.withColumn(replace_col, replace_udf(dist_sequence_col_name))
elif isinstance(value, (Index, Series)):
value = value.to_numpy().tolist()
replace_udf = udf(lambda x: value[x])
sdf = sdf.withColumn(replace_col, replace_udf(dist_sequence_col_name))
else:
sdf = sdf.withColumn(replace_col, F.lit(value))

if isinstance(mask, (Index, Series)):
mask = mask.to_numpy().tolist()
elif not isinstance(mask, list) and not isinstance(mask, tuple):
raise TypeError("Mask data doesn't support type " "{0}".format(type(mask).__name__))

masking_udf = udf(lambda x: mask[x], BooleanType())
sdf = sdf.withColumn(masking_col, masking_udf(dist_sequence_col_name))
# spark_frame here looks like below
# +-------------------------------+-----------------+------------------+
# |__distributed_sequence_column__|__index_level_0__|__masking_column__|
# +-------------------------------+-----------------+------------------+
# | 0| a| true|
# | 3| d| false|
# | 1| b| true|
# | 2| c| false|
# | 4| e| false|
# +-------------------------------+-----------------+------------------+

cond = F.when(sdf[masking_col], value).otherwise(sdf[scol_name])
# +-------------------------------+-----------------+------------------+------------------+
# |__distributed_sequence_column__|__index_level_0__|__replace_column__|__masking_column__|
# +-------------------------------+-----------------+------------------+------------------+
# | 0| a| 100| true|
# | 3| d| 400| false|
# | 1| b| 200| true|
# | 2| c| 300| false|
# | 4| e| 500| false|
# +-------------------------------+-----------------+------------------+------------------+

cond = F.when(sdf[masking_col], sdf[replace_col]).otherwise(sdf[scol_name])
sdf = sdf.select(cond.alias(scol_name))

internal = InternalFrame(spark_frame=sdf, index_map=self._internal.index_map)
Expand Down
23 changes: 18 additions & 5 deletions databricks/koalas/tests/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,14 +279,27 @@ def test_dropna(self):
self.assert_eq((kidx + 1).dropna(), (pidx + 1).dropna())

def test_putmask(self):
pidx = pd.Index(["a", "b", "c", "d", "e"])
pidx = pd.Index([1, 2, 3, 4, 5])
kidx = ks.from_pandas(pidx)

mask = [True if x < 2 else False for x in range(5)]
value = "Koalas"

self.assert_eq(
kidx.putmask(mask, value).sort_values(), pidx.putmask(mask, value).sort_values()
kidx.putmask(kidx < 3, 100).sort_values(), pidx.putmask(pidx < 3, 100).sort_values()
)
self.assert_eq(
kidx.putmask(kidx < 3, [100, 200, 300, 400, 500]).sort_values(),
pidx.putmask(pidx < 3, [100, 200, 300, 400, 500]).sort_values(),
)
self.assert_eq(
kidx.putmask(kidx < 3, (100, 200, 300, 400, 500)).sort_values(),
pidx.putmask(pidx < 3, (100, 200, 300, 400, 500)).sort_values(),
)
self.assert_eq(
kidx.putmask(kidx < 3, ks.Index([100, 200, 300, 400, 500])).sort_values(),
pidx.putmask(pidx < 3, pd.Index([100, 200, 300, 400, 500])).sort_values(),
)
self.assert_eq(
kidx.putmask(kidx < 3, ks.Series([100, 200, 300, 400, 500])).sort_values(),
pidx.putmask(pidx < 3, pd.Series([100, 200, 300, 400, 500])).sort_values(),
)

def test_index_symmetric_difference(self):
Expand Down

0 comments on commit 3771611

Please sign in to comment.