Revert "[SPARK-48322][SPARK-42965][SQL][CONNECT][PYTHON] Drop interna…
Browse files Browse the repository at this point in the history
…l metadata in `DataFrame.schema`"

revert apache#46636

apache#46636 (comment)

Closes apache#46790 from zhengruifeng/revert_metadata_drop.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
zhengruifeng committed May 29, 2024
1 parent 8bbbde7 commit cfbed99
Showing 5 changed files with 50 additions and 13 deletions.
37 changes: 29 additions & 8 deletions python/pyspark/pandas/internal.py
@@ -33,6 +33,7 @@
Window,
)
from pyspark.sql.types import ( # noqa: F401
_drop_metadata,
BooleanType,
DataType,
LongType,
@@ -756,10 +757,20 @@ def __init__(

if is_testing():
struct_fields = spark_frame.select(index_spark_columns).schema.fields
assert all(
index_field.struct_field == struct_field
for index_field, struct_field in zip(index_fields, struct_fields)
), (index_fields, struct_fields)
if is_remote():
# TODO(SPARK-42965): For some reason, the metadata of StructField is different
# in a few tests when using Spark Connect. However, the function works properly.
# Therefore, we temporarily perform Spark Connect tests by excluding metadata
# until the issue is resolved.
assert all(
_drop_metadata(index_field.struct_field) == _drop_metadata(struct_field)
for index_field, struct_field in zip(index_fields, struct_fields)
), (index_fields, struct_fields)
else:
assert all(
index_field.struct_field == struct_field
for index_field, struct_field in zip(index_fields, struct_fields)
), (index_fields, struct_fields)

self._index_fields: List[InternalField] = index_fields

@@ -774,10 +785,20 @@ def __init__(

if is_testing():
struct_fields = spark_frame.select(data_spark_columns).schema.fields
assert all(
data_field.struct_field == struct_field
for data_field, struct_field in zip(data_fields, struct_fields)
), (data_fields, struct_fields)
if is_remote():
# TODO(SPARK-42965): For some reason, the metadata of StructField is different
# in a few tests when using Spark Connect. However, the function works properly.
# Therefore, we temporarily perform Spark Connect tests by excluding metadata
# until the issue is resolved.
assert all(
_drop_metadata(data_field.struct_field) == _drop_metadata(struct_field)
for data_field, struct_field in zip(data_fields, struct_fields)
), (data_fields, struct_fields)
else:
assert all(
data_field.struct_field == struct_field
for data_field, struct_field in zip(data_fields, struct_fields)
), (data_fields, struct_fields)

self._data_fields: List[InternalField] = data_fields

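Note (illustrative, not part of the commit): the restored Connect branch compares fields only after stripping metadata. A minimal sketch of why, using the `_drop_metadata` helper restored below in `types.py`; the field name and metadata value are invented for the example:

from pyspark.sql.types import LongType, StructField, _drop_metadata

# Two fields that differ only in metadata compare unequal, which is what
# trips the plain assert under Spark Connect.
expected = StructField("id", LongType(), nullable=True)
actual = StructField(
    "id", LongType(), nullable=True, metadata={"__autoGeneratedAlias": "true"}
)

assert expected != actual                                  # metadata differs
assert _drop_metadata(expected) == _drop_metadata(actual)  # equal once dropped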
4 changes: 3 additions & 1 deletion python/pyspark/sql/tests/connect/test_connect_function.py
@@ -22,6 +22,7 @@
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.sql import SparkSession as PySparkSession
from pyspark.sql.types import (
_drop_metadata,
StringType,
StructType,
StructField,
@@ -1673,7 +1674,8 @@ def test_nested_lambda_function(self):
)
)

self.assertEqual(cdf.schema, sdf.schema)
# TODO: 'cdf.schema' has an extra metadata '{'__autoGeneratedAlias': 'true'}'
self.assertEqual(_drop_metadata(cdf.schema), _drop_metadata(sdf.schema))
self.assertEqual(cdf.collect(), sdf.collect())

def test_csv_functions(self):
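Note (illustrative): the same pattern at whole-schema level. In the real test the schemas come from `cdf` (Connect) and `sdf` (classic); they are hand-built here only to show the comparison:

from pyspark.sql.types import StringType, StructField, StructType, _drop_metadata

cdf_schema = StructType(
    [StructField("v", StringType(), True, {"__autoGeneratedAlias": "true"})]
)
sdf_schema = StructType([StructField("v", StringType(), True)])

# Direct equality fails on the extra metadata; dropping it makes them match.
assert cdf_schema != sdf_schema
assert _drop_metadata(cdf_schema) == _drop_metadata(sdf_schema)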
13 changes: 13 additions & 0 deletions python/pyspark/sql/types.py
@@ -1770,6 +1770,19 @@ def toJson(self, zone_id: str = "UTC") -> str:
_COLLATIONS_METADATA_KEY = "__COLLATIONS"


def _drop_metadata(d: Union[DataType, StructField]) -> Union[DataType, StructField]:
assert isinstance(d, (DataType, StructField))
if isinstance(d, StructField):
return StructField(d.name, _drop_metadata(d.dataType), d.nullable, None)
elif isinstance(d, StructType):
return StructType([cast(StructField, _drop_metadata(f)) for f in d.fields])
elif isinstance(d, ArrayType):
return ArrayType(_drop_metadata(d.elementType), d.containsNull)
elif isinstance(d, MapType):
return MapType(_drop_metadata(d.keyType), _drop_metadata(d.valueType), d.valueContainsNull)
return d


def _parse_datatype_string(s: str) -> DataType:
"""
Parses the given data type string to a :class:`DataType`. The data type string format equals
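Note (illustrative): `_drop_metadata` recurses through nested container types. A quick sketch with invented field names:

from pyspark.sql.types import (
    ArrayType,
    MapType,
    StringType,
    StructField,
    StructType,
    _drop_metadata,
)

nested = StructType([
    StructField("tags", ArrayType(StringType()), True,
                {"__autoGeneratedAlias": "true"}),
    StructField("props", MapType(StringType(), StringType()), True),
])

cleaned = _drop_metadata(nested)
# Metadata is stripped at every level; types and nullability are preserved.
assert all(f.metadata == {} for f in cleaned.fields)
assert cleaned == _drop_metadata(cleaned)  # idempotent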
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -49,7 +49,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution._
@@ -561,7 +561,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def schema: StructType = sparkSession.withActive {
removeInternalMetadata(queryExecution.analyzed.schema)
queryExecution.analyzed.schema
}

/**
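Note (illustrative): with the revert, `Dataset.schema` no longer calls `removeInternalMetadata`, so internal keys can surface to callers. A hedged PySpark sketch, assuming a local classic (non-Connect) session; whether a given column actually carries the key depends on the plan:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.range(3).agg(F.sum("id"))
# After the revert this may include {'__autoGeneratedAlias': 'true'};
# before it, DataFrame.schema stripped such internal metadata.
print(df.schema.fields[0].metadata)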
5 changes: 3 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -24,6 +24,7 @@ import scala.util.Random
import org.scalatest.matchers.must.Matchers.the

import org.apache.spark.{SparkArithmeticException, SparkRuntimeException}
import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -1464,7 +1465,7 @@ class DataFrameAggregateSuite extends QueryTest
Duration.ofSeconds(14)) ::
Nil)
assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
val metadata = Metadata.empty
val metadata = new MetadataBuilder().putString(AUTO_GENERATED_ALIAS, "true").build()
assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
StructField("sum(year-month)", YearMonthIntervalType(), metadata = metadata),
StructField("sum(year)", YearMonthIntervalType(YEAR), metadata = metadata),
@@ -1598,7 +1599,7 @@ class DataFrameAggregateSuite extends QueryTest
Duration.ofMinutes(4).plusSeconds(20),
Duration.ofSeconds(7)) :: Nil)
assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
val metadata = Metadata.empty
val metadata = new MetadataBuilder().putString(AUTO_GENERATED_ALIAS, "true").build()
assert(avgDF2.schema == StructType(Seq(
StructField("class", IntegerType, false),
StructField("avg(year-month)", YearMonthIntervalType(), metadata = metadata),
