[SPARK-46540][PYTHON] Respect column names when Python data source read function outputs named Row objects

### What changes were proposed in this pull request?

This PR fixes an issue where the `read` method of a Python `DataSourceReader` yields named `Row` objects.
Currently, the field names in those `Row` objects are ignored and values are assigned to output columns by position:
```Python
def read(self, ...):
    yield Row(a=1, b=2)
    yield Row(b=3, a=2)
```
The result should be `[Row(a=1, b=2), Row(a=2, b=3)]`, but it is currently `[Row(a=1, b=2), Row(a=3, b=2)]` because the values are taken positionally.
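Conceptually, the fix assigns values by field name when a yielded row is a named `Row`, and by position otherwise. Below is a minimal sketch of that mapping, assuming a two-column `a, b` schema; `to_positional` is a hypothetical helper for illustration, not the actual worker code:
```Python
from pyspark.sql import Row

# Output schema column order (assumed for this sketch).
column_names = ["a", "b"]

def to_positional(row):
    # Named Row: match values to schema columns by field name.
    if isinstance(row, Row) and hasattr(row, "__fields__"):
        if set(row.__fields__) != set(column_names):
            raise ValueError(
                f"expected fields {column_names}, got {list(row.__fields__)}"
            )
        return tuple(row[name] for name in column_names)
    # Unnamed Row or plain tuple: keep positional order.
    return tuple(row)

assert to_positional(Row(a=1, b=2)) == (1, 2)
assert to_positional(Row(b=3, a=2)) == (2, 3)  # field names win over position
```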

### Why are the changes needed?

To fix incorrect behavior: field names in rows returned by a Python data source's `read` method were silently ignored.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

New unit tests in `test_python_datasource.py` and `PythonDataSourceSuite`.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#44531 from allisonwang-db/spark-46540-named-rows.

Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
allisonwang-db authored and HyukjinKwon committed Jan 2, 2024
1 parent 3fd9876 commit 48a09c4
Showing 3 changed files with 64 additions and 2 deletions.
16 changes: 16 additions & 0 deletions python/pyspark/sql/tests/test_python_datasource.py
@@ -135,6 +135,22 @@ def test_data_source_read_output_row(self):
df = self.spark.read.format("test").load()
assertDataFrameEqual(df, [Row(0, 1)])

def test_data_source_read_output_named_row(self):
self.register_data_source(
read_func=lambda schema, partition: iter([Row(j=1, i=0), Row(i=1, j=2)])
)
df = self.spark.read.format("test").load()
assertDataFrameEqual(df, [Row(0, 1), Row(1, 2)])

def test_data_source_read_output_named_row_with_wrong_schema(self):
self.register_data_source(
read_func=lambda schema, partition: iter([Row(i=1, j=2), Row(j=3, k=4)])
)
with self.assertRaisesRegex(
PythonException, "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH"
):
self.spark.read.format("test").load().show()

def test_data_source_read_output_none(self):
self.register_data_source(read_func=lambda schema, partition: None)
df = self.spark.read.format("test").load()
24 changes: 22 additions & 2 deletions python/pyspark/sql/worker/plan_data_source_read.py
@@ -29,6 +29,7 @@
write_int,
SpecialLengths,
)
from pyspark.sql import Row
from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion
from pyspark.sql.datasource import DataSource, InputPartition
from pyspark.sql.pandas.types import to_arrow_schema
@@ -234,6 +235,8 @@ def batched(iterator: Iterator, n: int) -> Iterator:

# Convert the results from the `reader.read` method to an iterator of arrow batches.
num_cols = len(column_names)
col_mapping = {name: i for i, name in enumerate(column_names)}
col_name_set = set(column_names)
for batch in batched(output_iter, max_arrow_batch_size):
pylist: List[List] = [[] for _ in range(num_cols)]
for result in batch:
@@ -258,8 +261,25 @@
},
)

for col in range(num_cols):
pylist[col].append(column_converters[col](result[col]))
# Assign output values by name of the field, not position, if the result is a
# named `Row` object.
if isinstance(result, Row) and hasattr(result, "__fields__"):
# Check if the names are the same as the schema.
if set(result.__fields__) != col_name_set:
raise PySparkRuntimeError(
error_class="PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH",
message_parameters={
"expected": str(column_names),
"actual": str(result.__fields__),
},
)
# Assign the values by name.
for name in column_names:
idx = col_mapping[name]
pylist[idx].append(column_converters[idx](result[name]))
else:
for col in range(num_cols):
pylist[col].append(column_converters[col](result[col]))

yield pa.RecordBatch.from_arrays(pylist, schema=pa_schema)
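For context, a minimal end-to-end sketch of a data source whose reader yields named rows; the class and format names here are illustrative, and registration assumes the `spark.dataSource.register` API:
```Python
from pyspark.sql import Row
from pyspark.sql.datasource import DataSource, DataSourceReader

class NamedRowReader(DataSourceReader):
    def read(self, partition):
        # Field order differs from the schema; values are matched by name.
        yield Row(y=1, x=0)
        yield Row(x=1, y=2)

class NamedRowDataSource(DataSource):
    @classmethod
    def name(cls):
        return "named_rows"  # illustrative format name

    def schema(self):
        return "x int, y int"

    def reader(self, schema):
        return NamedRowReader()

# spark.dataSource.register(NamedRowDataSource)
# spark.read.format("named_rows").load() now yields Row(x=0, y=1), Row(x=1, y=2).
```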

26 changes: 26 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -455,6 +455,32 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
}
}

test("SPARK-46540: data source read output named rows") {
assume(shouldTestPandasUDFs)
val dataSourceScript =
s"""
|from pyspark.sql.datasource import DataSource, DataSourceReader
|class SimpleDataSourceReader(DataSourceReader):
| def read(self, partition):
| from pyspark.sql import Row
| yield Row(x = 0, y = 1)
| yield Row(y = 2, x = 1)
| yield Row(2, 3)
| yield (3, 4)
|
|class $dataSourceName(DataSource):
| def schema(self) -> str:
| return "x int, y int"
|
| def reader(self, schema):
| return SimpleDataSourceReader()
|""".stripMargin
val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
spark.dataSource.registerPython(dataSourceName, dataSource)
val df = spark.read.format(dataSourceName).load()
checkAnswer(df, Seq(Row(0, 1), Row(1, 2), Row(2, 3), Row(3, 4)))
}

test("SPARK-46424: Support Python metrics") {
assume(shouldTestPandasUDFs)
val dataSourceScript =
