[SPARK-47500][PYTHON][CONNECT] Factor column name handling out of `plan.py`

### What changes were proposed in this pull request?
Factor column name handling out of `plan.py`

### Why are the changes needed?
There is too much parameter preprocessing in `plan.py`, e.g. the column name handling, and there are multiple duplicated helper functions scattered here and there, which sometimes makes the code hard to follow.
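The diff below funnels this conversion through a single `F._to_col` helper. Its implementation is not shown on this page; the following is a minimal sketch of what such a helper presumably looks like, reconstructed only from the str-to-`Column` pattern visible in the removed code, and it may differ from the real one in `pyspark.sql.connect.functions`:

```python
# Hypothetical sketch of the shared F._to_col helper this diff relies on;
# reconstructed from the Column(ColumnReference(...)) pattern in the old code.
def _to_col(col: "ColumnOrName") -> Column:
    if isinstance(col, Column):
        return col  # pass Column objects through unchanged
    return Column(ColumnReference(col))  # wrap a bare name in a column reference
```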

### Does this PR introduce _any_ user-facing change?
No, just a code refactor.

### How was this patch tested?
CI.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#45636 from zhengruifeng/plan_clean_up.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
zhengruifeng authored and HyukjinKwon committed Mar 22, 2024
1 parent 47bce8e commit aea13fc
Showing 3 changed files with 71 additions and 141 deletions.
70 changes: 37 additions & 33 deletions python/pyspark/sql/connect/dataframe.py
@@ -182,8 +182,10 @@ def isEmpty(self) -> bool:
     def select(self, *cols: "ColumnOrName") -> "DataFrame":
         if len(cols) == 1 and isinstance(cols[0], list):
             cols = cols[0]
 
-        return DataFrame(plan.Project(self._plan, *cols), session=self._session)
+        return DataFrame(
+            plan.Project(self._plan, [F._to_col(c) for c in cols]),
+            session=self._session,
+        )
 
     select.__doc__ = PySparkDataFrame.select.__doc__
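As an illustrative usage sketch (not part of the PR), all three call styles accepted by `select` are now normalized through `F._to_col` before reaching `plan.Project`, assuming some hypothetical DataFrame `df`:

```python
df.select("id", "name")                 # bare column names
df.select(F.col("id"), F.col("name"))   # Column objects
df.select(["id", "name"])               # a single list argument
```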

@@ -197,7 +199,7 @@ def selectExpr(self, *expr: Union[str, List[str]]) -> "DataFrame":
             else:
                 sql_expr.extend([F.expr(e) for e in element])
 
-        return DataFrame(plan.Project(self._plan, *sql_expr), session=self._session)
+        return DataFrame(plan.Project(self._plan, sql_expr), session=self._session)
 
     selectExpr.__doc__ = PySparkDataFrame.selectExpr.__doc__
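The only change here is passing `sql_expr` as a list instead of unpacking it, matching `plan.Project`'s new list-based parameter. An illustrative call with a hypothetical `df`:

```python
# Each string is parsed into an expression via F.expr before projection.
df.selectExpr("id + 1 AS next_id", "upper(name) AS name_upper")
```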

@@ -309,18 +311,20 @@ def repartition(  # type: ignore[misc]
             )
             if len(cols) == 0:
                 return DataFrame(
-                    plan.Repartition(self._plan, num_partitions=numPartitions, shuffle=True),
+                    plan.Repartition(self._plan, numPartitions, shuffle=True),
                     self._session,
                 )
             else:
                 return DataFrame(
-                    plan.RepartitionByExpression(self._plan, numPartitions, list(cols)),
+                    plan.RepartitionByExpression(
+                        self._plan, numPartitions, [F._to_col(c) for c in cols]
+                    ),
                     self.sparkSession,
                 )
         elif isinstance(numPartitions, (str, Column)):
             cols = (numPartitions,) + cols
             return DataFrame(
-                plan.RepartitionByExpression(self._plan, None, list(cols)),
+                plan.RepartitionByExpression(self._plan, None, [F._to_col(c) for c in cols]),
                 self.sparkSession,
             )
         else:
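For reference, the three call shapes the branches above handle, as a usage sketch with a hypothetical `df`:

```python
df.repartition(8)                   # count only: a plain shuffling Repartition
df.repartition(8, "year", "month")  # explicit count plus partitioning columns
df.repartition("year", "month")     # columns only: count left for Spark to pick
```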
@@ -345,14 +349,14 @@ def repartitionByRange(self, *cols: "ColumnOrName") -> "DataFrame":
     def repartitionByRange(  # type: ignore[misc]
         self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
     ) -> "DataFrame":
-        def _convert_col(col: "ColumnOrName") -> "ColumnOrName":
+        def _convert_col(col: "ColumnOrName") -> Column:
             if isinstance(col, Column):
                 if isinstance(col._expr, SortOrder):
                     return col
                 else:
-                    return Column(SortOrder(col._expr))
+                    return col.asc()
             else:
-                return Column(SortOrder(ColumnReference(col)))
+                return F.col(col).asc()
 
         if isinstance(numPartitions, int):
             if not numPartitions > 0:
@@ -369,18 +373,17 @@ def _convert_col(col: "ColumnOrName") -> "ColumnOrName":
                     message_parameters={"item": "cols"},
                 )
             else:
-                sort = []
-                sort.extend([_convert_col(c) for c in cols])
                 return DataFrame(
-                    plan.RepartitionByExpression(self._plan, numPartitions, sort),
+                    plan.RepartitionByExpression(
+                        self._plan, numPartitions, [_convert_col(c) for c in cols]
+                    ),
                     self.sparkSession,
                 )
         elif isinstance(numPartitions, (str, Column)):
-            cols = (numPartitions,) + cols
-            sort = []
-            sort.extend([_convert_col(c) for c in cols])
             return DataFrame(
-                plan.RepartitionByExpression(self._plan, None, sort),
+                plan.RepartitionByExpression(
+                    self._plan, None, [_convert_col(c) for c in [numPartitions] + list(cols)]
+                ),
                 self.sparkSession,
             )
         else:
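The rewritten `_convert_col` builds the same ascending `SortOrder` the old code assembled by hand; `.asc()` is simply the public Column API for that wrapping. A hedged equivalence sketch:

```python
# Presumably equivalent ways to produce an ascending sort order over "ts";
# the new code prefers the public API over manual expression wrapping.
manual = Column(SortOrder(ColumnReference("ts")))
via_api = F.col("ts").asc()
```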
@@ -648,12 +651,18 @@ def _joinAsOf(
         if tolerance is not None:
             assert isinstance(tolerance, Column), "tolerance should be Column"
 
+        def _convert_col(df: "DataFrame", col: "ColumnOrName") -> Column:
+            if isinstance(col, Column):
+                return col
+            else:
+                return Column(ColumnReference(col, df._plan._plan_id))
+
         return DataFrame(
             plan.AsOfJoin(
                 left=self._plan,
                 right=other._plan,
-                left_as_of=leftAsOfColumn,
-                right_as_of=rightAsOfColumn,
+                left_as_of=_convert_col(self, leftAsOfColumn),
+                right_as_of=_convert_col(other, rightAsOfColumn),
                 on=on,
                 how=how,
                 tolerance=tolerance,
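The new local `_convert_col` resolves a bare column name against a specific side's plan id, so string as-of columns bind to the correct input DataFrame. An illustrative call, assuming hypothetical `left` and `right` DataFrames that share a `ts` column (other parameters of this private method are omitted):

```python
# "ts" on each side now resolves against that side's own plan.
joined = left._joinAsOf(right, leftAsOfColumn="ts", rightAsOfColumn="ts")
```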
@@ -940,24 +949,21 @@ def unpivot(
     ) -> "DataFrame":
         assert ids is not None, "ids must not be None"
 
-        def to_jcols(
+        def _convert_cols(
             cols: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]]
-        ) -> List["ColumnOrName"]:
+        ) -> List[Column]:
             if cols is None:
-                lst = []
-            elif isinstance(cols, tuple):
-                lst = list(cols)
-            elif isinstance(cols, list):
-                lst = cols
+                return []
+            elif isinstance(cols, (tuple, list)):
+                return [F._to_col(c) for c in cols]
             else:
-                lst = [cols]
-            return lst
+                return [F._to_col(cols)]
 
         return DataFrame(
             plan.Unpivot(
                 self._plan,
-                to_jcols(ids),
-                to_jcols(values) if values is not None else None,
+                _convert_cols(ids),
+                _convert_cols(values) if values is not None else None,
                 variableColumnName,
                 valueColumnName,
             ),
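`_convert_cols` normalizes every accepted shape for `ids` and `values` into a plain `List[Column]`: `None` becomes `[]`, a list or tuple is converted element-wise, and a single item becomes a one-element list. An illustrative call with a hypothetical `df` and columns:

```python
# ids, values, variable column name, value column name
df.unpivot("id", ["q1", "q2"], "quarter", "sales")
```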
Expand Down Expand Up @@ -1645,9 +1651,7 @@ def freqItems(
def sampleBy(
self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None
) -> "DataFrame":
if isinstance(col, str):
col = Column(ColumnReference(col))
elif not isinstance(col, Column):
if not isinstance(col, (str, Column)):
raise PySparkTypeError(
error_class="NOT_COLUMN_OR_STR",
message_parameters={"arg_name": "col", "arg_type": type(col).__name__},
@@ -1671,7 +1675,7 @@ def sampleBy(
                 fractions[k] = float(v)
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
         return DataFrame(
-            plan.StatSampleBy(child=self._plan, col=col, fractions=fractions, seed=seed),
+            plan.StatSampleBy(child=self._plan, col=F._to_col(col), fractions=fractions, seed=seed),
             session=self._session,
         )
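With the eager str-to-Column conversion gone from `sampleBy` itself, the type check accepts both forms and `F._to_col` performs the conversion when the plan is built. Illustrative calls with a hypothetical `df`:

```python
df.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)         # name form
df.sampleBy(F.col("key"), fractions={0: 0.1, 1: 0.2}, seed=0)  # Column form
```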
