From 48a09c457cf5854d956138d3881d2c45e15b291d Mon Sep 17 00:00:00 2001
From: allisonwang-db
Date: Tue, 2 Jan 2024 16:20:11 +0900
Subject: [PATCH] [SPARK-46540][PYTHON] Respect column names when Python data source read function outputs named Row objects

### What changes were proposed in this pull request?

This PR fixes an issue when the `read` method of a Python DataSourceReader yields named `Row` objects. Currently, it ignores the field names in the `Row` objects:
```Python
def read(self, ...):
    yield Row(a=1, b=2)
    yield Row(b=3, a=2)
```
The result should be `[Row(a=1, b=2), Row(a=2, b=3)]`, instead of `[Row(a=1, b=2), Row(a=3, b=2)]`.

### Why are the changes needed?

To fix an incorrect behavior.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit test

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #44531 from allisonwang-db/spark-46540-named-rows.

Authored-by: allisonwang-db
Signed-off-by: Hyukjin Kwon
---
 .../sql/tests/test_python_datasource.py     | 16 ++++++++++++
 .../sql/worker/plan_data_source_read.py     | 24 +++++++++++++++--
 .../python/PythonDataSourceSuite.scala      | 26 +++++++++++++++++++
 3 files changed, 64 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 32333a8ccee91..8517d8f36382b 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -135,6 +135,22 @@ def test_data_source_read_output_row(self):
         df = self.spark.read.format("test").load()
         assertDataFrameEqual(df, [Row(0, 1)])
 
+    def test_data_source_read_output_named_row(self):
+        self.register_data_source(
+            read_func=lambda schema, partition: iter([Row(j=1, i=0), Row(i=1, j=2)])
+        )
+        df = self.spark.read.format("test").load()
+        assertDataFrameEqual(df, [Row(0, 1), Row(1, 2)])
+
+    def test_data_source_read_output_named_row_with_wrong_schema(self):
+        self.register_data_source(
+            read_func=lambda schema, partition: iter([Row(i=1, j=2), Row(j=3, k=4)])
+        )
+        with self.assertRaisesRegex(
+            PythonException, "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH"
+        ):
+            self.spark.read.format("test").load().show()
+
     def test_data_source_read_output_none(self):
         self.register_data_source(read_func=lambda schema, partition: None)
         df = self.spark.read.format("test").load()
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py
index d2fcb5096ae2a..d4693f5ff7be2 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -29,6 +29,7 @@
     write_int,
     SpecialLengths,
 )
+from pyspark.sql import Row
 from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion
 from pyspark.sql.datasource import DataSource, InputPartition
 from pyspark.sql.pandas.types import to_arrow_schema
@@ -234,6 +235,8 @@ def batched(iterator: Iterator, n: int) -> Iterator:
 
         # Convert the results from the `reader.read` method to an iterator of arrow batches.
         num_cols = len(column_names)
+        col_mapping = {name: i for i, name in enumerate(column_names)}
+        col_name_set = set(column_names)
         for batch in batched(output_iter, max_arrow_batch_size):
             pylist: List[List] = [[] for _ in range(num_cols)]
             for result in batch:
@@ -258,8 +261,25 @@ def batched(iterator: Iterator, n: int) -> Iterator:
                         },
                     )
 
-                for col in range(num_cols):
-                    pylist[col].append(column_converters[col](result[col]))
+                # Assign output values by name of the field, not position, if the result is a
+                # named `Row` object.
+                if isinstance(result, Row) and hasattr(result, "__fields__"):
+                    # Check if the names are the same as the schema.
+                    if set(result.__fields__) != col_name_set:
+                        raise PySparkRuntimeError(
+                            error_class="PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH",
+                            message_parameters={
+                                "expected": str(column_names),
+                                "actual": str(result.__fields__),
+                            },
+                        )
+                    # Assign the values by name.
+                    for name in column_names:
+                        idx = col_mapping[name]
+                        pylist[idx].append(column_converters[idx](result[name]))
+                else:
+                    for col in range(num_cols):
+                        pylist[col].append(column_converters[col](result[col]))
 
             yield pa.RecordBatch.from_arrays(pylist, schema=pa_schema)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index 080f57aa08a04..49fb2e859fff5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -455,6 +455,32 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     }
   }
 
+  test("SPARK-46540: data source read output named rows") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource, DataSourceReader
+         |class SimpleDataSourceReader(DataSourceReader):
+         |    def read(self, partition):
+         |        from pyspark.sql import Row
+         |        yield Row(x = 0, y = 1)
+         |        yield Row(y = 2, x = 1)
+         |        yield Row(2, 3)
+         |        yield (3, 4)
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "x int, y int"
+         |
+         |    def reader(self, schema):
+         |        return SimpleDataSourceReader()
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val df = spark.read.format(dataSourceName).load()
+    checkAnswer(df, Seq(Row(0, 1), Row(1, 2), Row(2, 3), Row(3, 4)))
+  }
+
   test("SPARK-46424: Support Python metrics") {
     assume(shouldTestPandasUDFs)
     val dataSourceScript =
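
As a quick illustration of the worker-side change, the following is a minimal standalone sketch of the by-name assignment, not the worker code itself: `column_names`, `results`, and the plain Python lists below stand in for the schema, the `reader.read` output, and the `column_converters`/`pylist` machinery in `plan_data_source_read.py`, and a plain `ValueError` stands in for `PySparkRuntimeError` with the `PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH` error class.

```python
from pyspark.sql import Row

# Stand-ins for the schema-derived values built in plan_data_source_read.py.
column_names = ["i", "j"]
col_mapping = {name: idx for idx, name in enumerate(column_names)}
col_name_set = set(column_names)

# Reader output mixing named Rows (in any field order) with a plain tuple.
results = [Row(j=1, i=0), Row(i=1, j=2), (2, 3)]

# One Python list per output column, as the worker builds before creating
# an Arrow record batch.
pylist = [[] for _ in column_names]

for result in results:
    if isinstance(result, Row) and hasattr(result, "__fields__"):
        # Named Row: the field names must match the schema exactly...
        if set(result.__fields__) != col_name_set:
            raise ValueError(
                f"expected fields {column_names}, got {list(result.__fields__)}"
            )
        # ...and values are assigned by name rather than by position.
        for name in column_names:
            pylist[col_mapping[name]].append(result[name])
    else:
        # Tuples and unnamed Rows keep the previous positional behavior.
        for idx in range(len(column_names)):
            pylist[idx].append(result[idx])

print(pylist)  # [[0, 1, 2], [1, 2, 3]]
```

Running the sketch prints `[[0, 1, 2], [1, 2, 3]]`: both named `Row`s map onto the schema columns `i` and `j` regardless of field order, while the tuple is still consumed positionally.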
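
The new Scala test drives this behavior through a Python script string; for readers who want to try it directly from Python, here is a hedged end-to-end sketch under two assumptions: a PySpark build that ships the Python data source API (`pyspark.sql.datasource`, Spark 4.0+), and registration via `spark.dataSource.register`. The `SimpleNamedRowDataSource` / `simple_named_rows` names are illustrative and not part of this patch.

```python
from pyspark.sql import Row, SparkSession
from pyspark.sql.datasource import DataSource, DataSourceReader


class SimpleNamedRowReader(DataSourceReader):
    def read(self, partition):
        # Named Rows may list their fields in any order; with this patch the
        # values are matched to the schema columns by name.
        yield Row(x=0, y=1)
        yield Row(y=2, x=1)
        # Unnamed Rows and plain tuples are still consumed positionally.
        yield Row(2, 3)
        yield (3, 4)


class SimpleNamedRowDataSource(DataSource):
    @classmethod
    def name(cls):
        return "simple_named_rows"

    def schema(self):
        return "x int, y int"

    def reader(self, schema):
        return SimpleNamedRowReader()


if __name__ == "__main__":
    spark = SparkSession.builder.getOrCreate()
    spark.dataSource.register(SimpleNamedRowDataSource)
    df = spark.read.format("simple_named_rows").load()
    df.show()  # expected rows: (0, 1), (1, 2), (2, 3), (3, 4)
```

With the fix, `Row(y=2, x=1)` comes back as `x=1, y=2`; before it, the values were taken positionally and would have surfaced as `x=2, y=1`.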