diff --git a/databricks/koalas/indexes.py b/databricks/koalas/indexes.py index 8a3f53d725..2cafcaafec 100644 --- a/databricks/koalas/indexes.py +++ b/databricks/koalas/indexes.py @@ -1558,8 +1558,11 @@ def putmask(self, mask, value): >>> kidx Index(['a', 'b', 'c', 'd', 'e'], dtype='object') - >>> kidx.putmask([True if x < 2 else False for x in range(5)], "Koalas").sort_values() + >>> kidx.putmask(kidx < 'c', "Koalas").sort_values() Index(['Koalas', 'Koalas', 'c', 'd', 'e'], dtype='object') + + >>> kidx.putmask(kidx < 'c', ks.Index([100, 200, 300, 400, 500])).sort_values() + Index(['100', '200', 'c', 'd', 'e'], dtype='object') """ scol_name = self._internal.index_spark_column_names[0] sdf = self._internal.spark_frame.select(self.spark.column) @@ -1569,22 +1572,38 @@ def putmask(self, mask, value): sdf, column_name=dist_sequence_col_name ) + replace_col = verify_temp_column_name(sdf, "__replace_column__") masking_col = verify_temp_column_name(sdf, "__masking_column__") - masking_udf = udf(lambda x: mask[x], BooleanType()) + if isinstance(value, (list, tuple)): + replace_udf = udf(lambda x: value[x]) + sdf = sdf.withColumn(replace_col, replace_udf(dist_sequence_col_name)) + elif isinstance(value, (Index, Series)): + value = value.to_numpy().tolist() + replace_udf = udf(lambda x: value[x]) + sdf = sdf.withColumn(replace_col, replace_udf(dist_sequence_col_name)) + else: + sdf = sdf.withColumn(replace_col, F.lit(value)) + + if isinstance(mask, (Index, Series)): + mask = mask.to_numpy().tolist() + elif not isinstance(mask, list) and not isinstance(mask, tuple): + raise TypeError("Mask data doesn't support type " "{0}".format(type(mask).__name__)) + + masking_udf = udf(lambda x: mask[x], BooleanType()) sdf = sdf.withColumn(masking_col, masking_udf(dist_sequence_col_name)) # spark_frame here looks like below - # +-------------------------------+-----------------+------------------+ - # |__distributed_sequence_column__|__index_level_0__|__masking_column__| - # +-------------------------------+-----------------+------------------+ - # | 0| a| true| - # | 3| d| false| - # | 1| b| true| - # | 2| c| false| - # | 4| e| false| - # +-------------------------------+-----------------+------------------+ - - cond = F.when(sdf[masking_col], value).otherwise(sdf[scol_name]) + # +-------------------------------+-----------------+------------------+------------------+ + # |__distributed_sequence_column__|__index_level_0__|__replace_column__|__masking_column__| + # +-------------------------------+-----------------+------------------+------------------+ + # | 0| a| 100| true| + # | 3| d| 400| false| + # | 1| b| 200| true| + # | 2| c| 300| false| + # | 4| e| 500| false| + # +-------------------------------+-----------------+------------------+------------------+ + + cond = F.when(sdf[masking_col], sdf[replace_col]).otherwise(sdf[scol_name]) sdf = sdf.select(cond.alias(scol_name)) internal = InternalFrame(spark_frame=sdf, index_map=self._internal.index_map) diff --git a/databricks/koalas/tests/test_indexes.py b/databricks/koalas/tests/test_indexes.py index 3d5b64c792..363b80dc33 100644 --- a/databricks/koalas/tests/test_indexes.py +++ b/databricks/koalas/tests/test_indexes.py @@ -279,14 +279,27 @@ def test_dropna(self): self.assert_eq((kidx + 1).dropna(), (pidx + 1).dropna()) def test_putmask(self): - pidx = pd.Index(["a", "b", "c", "d", "e"]) + pidx = pd.Index([1, 2, 3, 4, 5]) kidx = ks.from_pandas(pidx) - mask = [True if x < 2 else False for x in range(5)] - value = "Koalas" - self.assert_eq( - kidx.putmask(mask, value).sort_values(), pidx.putmask(mask, value).sort_values() + kidx.putmask(kidx < 3, 100).sort_values(), pidx.putmask(pidx < 3, 100).sort_values() + ) + self.assert_eq( + kidx.putmask(kidx < 3, [100, 200, 300, 400, 500]).sort_values(), + pidx.putmask(pidx < 3, [100, 200, 300, 400, 500]).sort_values(), + ) + self.assert_eq( + kidx.putmask(kidx < 3, (100, 200, 300, 400, 500)).sort_values(), + pidx.putmask(pidx < 3, (100, 200, 300, 400, 500)).sort_values(), + ) + self.assert_eq( + kidx.putmask(kidx < 3, ks.Index([100, 200, 300, 400, 500])).sort_values(), + pidx.putmask(pidx < 3, pd.Index([100, 200, 300, 400, 500])).sort_values(), + ) + self.assert_eq( + kidx.putmask(kidx < 3, ks.Series([100, 200, 300, 400, 500])).sort_values(), + pidx.putmask(pidx < 3, pd.Series([100, 200, 300, 400, 500])).sort_values(), ) def test_index_symmetric_difference(self):