From 06aafb1eeaf32bdc7abce5bb4a9ffb474a9e61ae Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 23 May 2024 08:52:36 +0900 Subject: [PATCH 01/45] [SPARK-48258][PYTHON][CONNECT][FOLLOW-UP] Bind relation ID to the plan instead of DataFrame ### What changes were proposed in this pull request? This PR addresses https://github.com/apache/spark/pull/46683#discussion_r1608527529 comment within Python, by using ID at the plan instead of DataFrame itself. ### Why are the changes needed? Because the DataFrame holds the relation ID, if DataFrame B are derived from DataFrame A, and DataFrame A is garbage-collected, then the cache might not exist anymore. See the example below: ```python df = spark.range(1).localCheckpoint() df2 = df.repartition(10) del df df2.collect() ``` ``` pyspark.errors.exceptions.connect.SparkConnectGrpcException: (org.apache.spark.sql.connect.common.InvalidPlanInput) No DataFrame with id a4efa660-897c-4500-bd4e-bd57cd0263d2 is found in the session cd4764b4-90a9-4249-9140-12a6e4a98cd3 ``` ### Does this PR introduce _any_ user-facing change? No, the main change has not been released out yet. ### How was this patch tested? Manually tested, and added a unittest. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46694 from HyukjinKwon/SPARK-48258-followup. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/conversion.py | 5 +- python/pyspark/sql/connect/dataframe.py | 38 ------------- python/pyspark/sql/connect/plan.py | 54 ++++++++++++++++--- python/pyspark/sql/connect/session.py | 2 +- .../sql/tests/connect/test_connect_basic.py | 53 +++++++++++++++--- 5 files changed, 97 insertions(+), 55 deletions(-) diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py index b1cf88e40a4e8..1c205586d6096 100644 --- a/python/pyspark/sql/connect/conversion.py +++ b/python/pyspark/sql/connect/conversion.py @@ -577,7 +577,8 @@ def proto_to_remote_cached_dataframe(relation: pb2.CachedRemoteRelation) -> "Dat from pyspark.sql.connect.session import SparkSession import pyspark.sql.connect.plan as plan + session = SparkSession.active() return DataFrame( - plan=plan.CachedRemoteRelation(relation.relation_id), - session=SparkSession.active(), + plan=plan.CachedRemoteRelation(relation.relation_id, session), + session=session, ) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 3725bc3ba0e40..510776bb752d3 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -16,7 +16,6 @@ # # mypy: disable-error-code="override" -from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2 from pyspark.errors.exceptions.base import ( SessionNotSameException, PySparkIndexError, @@ -138,41 +137,6 @@ def __init__( # by __repr__ and _repr_html_ while eager evaluation opens. self._support_repr_html = False self._cached_schema: Optional[StructType] = None - self._cached_remote_relation_id: Optional[str] = None - - def __del__(self) -> None: - # If session is already closed, all cached DataFrame should be released. 
- if not self._session.client.is_closed and self._cached_remote_relation_id is not None: - try: - command = plan.RemoveRemoteCachedRelation( - plan.CachedRemoteRelation(relationId=self._cached_remote_relation_id) - ).command(session=self._session.client) - req = self._session.client._execute_plan_request_with_metadata() - if self._session.client._user_id: - req.user_context.user_id = self._session.client._user_id - req.plan.command.CopyFrom(command) - - for attempt in self._session.client._retrying(): - with attempt: - # !!HACK ALERT!! - # unary_stream does not work on Python's exit for an unknown reasons - # Therefore, here we open unary_unary channel instead. - # See also :class:`SparkConnectServiceStub`. - request_serializer = ( - spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString - ) - response_deserializer = ( - spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString - ) - channel = self._session.client._channel.unary_unary( - "/spark.connect.SparkConnectService/ExecutePlan", - request_serializer=request_serializer, - response_deserializer=response_deserializer, - ) - metadata = self._session.client._builder.metadata() - channel(req, metadata=metadata) # type: ignore[arg-type] - except Exception as e: - warnings.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.") def __reduce__(self) -> Tuple: """ @@ -2137,7 +2101,6 @@ def checkpoint(self, eager: bool = True) -> "DataFrame": assert "checkpoint_command_result" in properties checkpointed = properties["checkpoint_command_result"] assert isinstance(checkpointed._plan, plan.CachedRemoteRelation) - checkpointed._cached_remote_relation_id = checkpointed._plan._relationId return checkpointed def localCheckpoint(self, eager: bool = True) -> "DataFrame": @@ -2146,7 +2109,6 @@ def localCheckpoint(self, eager: bool = True) -> "DataFrame": assert "checkpoint_command_result" in properties checkpointed = properties["checkpoint_command_result"] assert isinstance(checkpointed._plan, plan.CachedRemoteRelation) - checkpointed._cached_remote_relation_id = checkpointed._plan._relationId return checkpointed if not is_remote_only(): diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 94c2641bb4d21..868bd4fb57aa4 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -40,6 +40,7 @@ import pickle from threading import Lock from inspect import signature, isclass +import warnings import pyarrow as pa @@ -49,6 +50,7 @@ import pyspark.sql.connect.proto as proto from pyspark.sql.column import Column +from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2 from pyspark.sql.connect.conversion import storage_level_to_proto from pyspark.sql.connect.expressions import Expression from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType @@ -62,6 +64,7 @@ from pyspark.sql.connect.client import SparkConnectClient from pyspark.sql.connect.udf import UserDefinedFunction from pyspark.sql.connect.observation import Observation + from pyspark.sql.connect.session import SparkSession class LogicalPlan: @@ -547,14 +550,49 @@ class CachedRemoteRelation(LogicalPlan): """Logical plan object for a DataFrame reference which represents a DataFrame that's been cached on the server with a given id.""" - def __init__(self, relationId: str): + def __init__(self, relation_id: str, spark_session: "SparkSession"): super().__init__(None) - self._relationId = relationId - - def plan(self, session: "SparkConnectClient") -> 
proto.Relation: - plan = self._create_proto_relation() - plan.cached_remote_relation.relation_id = self._relationId - return plan + self._relation_id = relation_id + # Needs to hold the session to make a request itself. + self._spark_session = spark_session + + def plan(self, session: "SparkConnectClient") -> proto.Relation: + plan = self._create_proto_relation() + plan.cached_remote_relation.relation_id = self._relation_id + return plan + + def __del__(self) -> None: + session = self._spark_session + # If session is already closed, all cached DataFrame should be released. + if session is not None and not session.client.is_closed and self._relation_id is not None: + try: + command = RemoveRemoteCachedRelation(self).command(session=session.client) + req = session.client._execute_plan_request_with_metadata() + if session.client._user_id: + req.user_context.user_id = session.client._user_id + req.plan.command.CopyFrom(command) + + for attempt in session.client._retrying(): + with attempt: + # !!HACK ALERT!! + # unary_stream does not work on Python's exit for an unknown reasons + # Therefore, here we open unary_unary channel instead. + # See also :class:`SparkConnectServiceStub`. + request_serializer = ( + spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString + ) + response_deserializer = ( + spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString + ) + channel = session.client._channel.unary_unary( + "/spark.connect.SparkConnectService/ExecutePlan", + request_serializer=request_serializer, + response_deserializer=response_deserializer, + ) + metadata = session.client._builder.metadata() + channel(req, metadata=metadata) # type: ignore[arg-type] + except Exception as e: + warnings.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.") class Hint(LogicalPlan): @@ -1792,7 +1830,7 @@ def __init__(self, relation: CachedRemoteRelation) -> None: def command(self, session: "SparkConnectClient") -> proto.Command: plan = self._create_proto_relation() - plan.cached_remote_relation.relation_id = self._relation._relationId + plan.cached_remote_relation.relation_id = self._relation._relation_id cmd = proto.Command() cmd.remove_cached_remote_relation_command.relation.CopyFrom(plan.cached_remote_relation) return cmd diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 5e6c5e5587646..f99d298ea1170 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -926,7 +926,7 @@ def _create_remote_dataframe(self, remote_id: str) -> "ParentDataFrame": This is used in ForeachBatch() runner, where the remote DataFrame refers to the output of a micro batch. 
""" - return DataFrame(CachedRemoteRelation(remote_id), self) + return DataFrame(CachedRemoteRelation(remote_id, spark_session=self), self) @staticmethod def _start_connect_server(master: str, opts: Dict[str, Any]) -> None: diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index b144c3b8de208..0648b5ce9925c 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -16,10 +16,10 @@ # import os +import gc import unittest import shutil import tempfile -import time from pyspark.util import is_remote_only from pyspark.errors import PySparkTypeError, PySparkValueError @@ -34,6 +34,7 @@ ArrayType, Row, ) +from pyspark.testing.utils import eventually from pyspark.testing.sqlutils import SQLTestUtils from pyspark.testing.connectutils import ( should_test_connect, @@ -1379,8 +1380,8 @@ def test_garbage_collection_checkpoint(self): # SPARK-48258: Make sure garbage-collecting DataFrame remove the paired state # in Spark Connect server df = self.connect.range(10).localCheckpoint() - self.assertIsNotNone(df._cached_remote_relation_id) - cached_remote_relation_id = df._cached_remote_relation_id + self.assertIsNotNone(df._plan._relation_id) + cached_remote_relation_id = df._plan._relation_id jvm = self.spark._jvm session_holder = getattr( @@ -1397,14 +1398,54 @@ def test_garbage_collection_checkpoint(self): ) del df + gc.collect() - time.sleep(3) # Make sure removing is triggered, and executed in the server. + def condition(): + # Check the state was removed up on garbage-collection. + self.assertIsNone( + session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None) + ) + + eventually(catch_assertions=True)(condition)() + + def test_garbage_collection_derived_checkpoint(self): + # SPARK-48258: Should keep the cached remote relation when derived DataFrames exist + df = self.connect.range(10).localCheckpoint() + self.assertIsNotNone(df._plan._relation_id) + derived = df.repartition(10) + cached_remote_relation_id = df._plan._relation_id - # Check the state was removed up on garbage-collection. - self.assertIsNone( + jvm = self.spark._jvm + session_holder = getattr( + getattr( + jvm.org.apache.spark.sql.connect.service, + "SparkConnectService$", + ), + "MODULE$", + ).getOrCreateIsolatedSession(self.connect.client._user_id, self.connect.client._session_id) + + # Check the state exists. + self.assertIsNotNone( session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None) ) + del df + gc.collect() + + def condition(): + self.assertIsNone( + session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None) + ) + + # Should not remove the cache + with self.assertRaises(AssertionError): + eventually(catch_assertions=True, timeout=5)(condition)() + + del derived + gc.collect() + + eventually(catch_assertions=True)(condition)() + if __name__ == "__main__": from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401 From d96ab44c82519eec88b28df6974ddb5b7f429dbf Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 23 May 2024 09:34:42 +0900 Subject: [PATCH 02/45] [SPARK-48393][PYTHON] Move a group of constants to `pyspark.util` ### What changes were proposed in this pull request? Move a group of constants to `pyspark.util`, move them from connect to pyspark.util, so reusable in both ### Why are the changes needed? code clean up ### Does this PR introduce _any_ user-facing change? 
no, they are internal constants ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #46710 from zhengruifeng/unity_constant. Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/expressions.py | 4 +++- python/pyspark/sql/connect/types.py | 10 ---------- python/pyspark/sql/connect/window.py | 2 +- .../sql/tests/connect/test_connect_column.py | 2 +- python/pyspark/sql/types.py | 16 +++++++--------- python/pyspark/sql/utils.py | 9 +++------ python/pyspark/util.py | 10 ++++++++++ 7 files changed, 25 insertions(+), 28 deletions(-) diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 4dc54793ed81b..8cd386ba03aea 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -61,7 +61,7 @@ ) import pyspark.sql.connect.proto as proto -from pyspark.sql.connect.types import ( +from pyspark.util import ( JVM_BYTE_MIN, JVM_BYTE_MAX, JVM_SHORT_MIN, @@ -70,6 +70,8 @@ JVM_INT_MAX, JVM_LONG_MIN, JVM_LONG_MAX, +) +from pyspark.sql.connect.types import ( UnparsedDataType, pyspark_types_to_proto_types, proto_schema_to_pyspark_data_type, diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py index f058c6390612a..351fa01659657 100644 --- a/python/pyspark/sql/connect/types.py +++ b/python/pyspark/sql/connect/types.py @@ -55,16 +55,6 @@ import pyspark.sql.connect.proto as pb2 -JVM_BYTE_MIN: int = -(1 << 7) -JVM_BYTE_MAX: int = (1 << 7) - 1 -JVM_SHORT_MIN: int = -(1 << 15) -JVM_SHORT_MAX: int = (1 << 15) - 1 -JVM_INT_MIN: int = -(1 << 31) -JVM_INT_MAX: int = (1 << 31) - 1 -JVM_LONG_MIN: int = -(1 << 63) -JVM_LONG_MAX: int = (1 << 63) - 1 - - class UnparsedDataType(DataType): """ Unparsed data type. 
diff --git a/python/pyspark/sql/connect/window.py b/python/pyspark/sql/connect/window.py index 04cf4c91d3207..6fc1a1fac1e3d 100644 --- a/python/pyspark/sql/connect/window.py +++ b/python/pyspark/sql/connect/window.py @@ -27,7 +27,7 @@ Expression, SortOrder, ) -from pyspark.sql.connect.types import ( +from pyspark.util import ( JVM_LONG_MIN, JVM_LONG_MAX, ) diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index a9e3adb972e95..9a850dcae6f53 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -51,7 +51,7 @@ from pyspark.sql.connect import functions as CF from pyspark.sql.connect.column import Column from pyspark.sql.connect.expressions import DistributedSequenceID, LiteralExpression - from pyspark.sql.connect.types import ( + from pyspark.util import ( JVM_BYTE_MIN, JVM_BYTE_MAX, JVM_SHORT_MIN, diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index fa98d09a9af9a..ee0cc9db5c445 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -45,7 +45,7 @@ TYPE_CHECKING, ) -from pyspark.util import is_remote_only +from pyspark.util import is_remote_only, JVM_INT_MAX from pyspark.serializers import CloudPickleSerializer from pyspark.sql.utils import ( has_numpy, @@ -104,8 +104,6 @@ "VariantVal", ] -_JVM_INT_MAX: int = (1 << 31) - 1 - class DataType: """Base class for data types.""" @@ -756,7 +754,7 @@ def _build_formatted_string( self, prefix: str, stringConcat: StringConcat, - maxDepth: int = _JVM_INT_MAX, + maxDepth: int = JVM_INT_MAX, ) -> None: if maxDepth > 0: stringConcat.append( @@ -905,7 +903,7 @@ def _build_formatted_string( self, prefix: str, stringConcat: StringConcat, - maxDepth: int = _JVM_INT_MAX, + maxDepth: int = JVM_INT_MAX, ) -> None: if maxDepth > 0: stringConcat.append(f"{prefix}-- key: {self.keyType.typeName()}\n") @@ -1072,7 +1070,7 @@ def _build_formatted_string( self, prefix: str, stringConcat: StringConcat, - maxDepth: int = _JVM_INT_MAX, + maxDepth: int = JVM_INT_MAX, ) -> None: if maxDepth > 0: stringConcat.append( @@ -1507,16 +1505,16 @@ def _build_formatted_string( self, prefix: str, stringConcat: StringConcat, - maxDepth: int = _JVM_INT_MAX, + maxDepth: int = JVM_INT_MAX, ) -> None: for field in self.fields: field._build_formatted_string(prefix, stringConcat, maxDepth) - def treeString(self, maxDepth: int = _JVM_INT_MAX) -> str: + def treeString(self, maxDepth: int = JVM_INT_MAX) -> str: stringConcat = StringConcat() stringConcat.append("root\n") prefix = " |" - depth = maxDepth if maxDepth > 0 else _JVM_INT_MAX + depth = maxDepth if maxDepth > 0 else JVM_INT_MAX for field in self.fields: field._build_formatted_string(prefix, stringConcat, depth) return stringConcat.toString() diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 171f92e557a12..33e01ba378c49 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -43,7 +43,7 @@ PySparkNotImplementedError, PySparkRuntimeError, ) -from pyspark.util import is_remote_only +from pyspark.util import is_remote_only, JVM_INT_MAX from pyspark.errors.exceptions.captured import CapturedException # noqa: F401 from pyspark.find_spark_home import _find_spark_home @@ -136,11 +136,8 @@ class Java: # Python implementation of 'org.apache.spark.sql.catalyst.util.StringConcat' -_MAX_ROUNDED_ARRAY_LENGTH = (1 << 31) - 1 - 15 - - class StringConcat: - def __init__(self, maxLength: int = 
_MAX_ROUNDED_ARRAY_LENGTH): + def __init__(self, maxLength: int = JVM_INT_MAX - 15): self.maxLength: int = maxLength self.strings: List[str] = [] self.length: int = 0 @@ -156,7 +153,7 @@ def append(self, s: str) -> None: stringToAppend = s if available >= sLen else s[0:available] self.strings.append(stringToAppend) - self.length = min(self.length + sLen, _MAX_ROUNDED_ARRAY_LENGTH) + self.length = min(self.length + sLen, JVM_INT_MAX - 15) def toString(self) -> str: # finalLength = self.maxLength if self.atLimit() else self.length diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 4920ba957c192..49766913e6ee2 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -71,6 +71,16 @@ from pyspark.sql import SparkSession +JVM_BYTE_MIN: int = -(1 << 7) +JVM_BYTE_MAX: int = (1 << 7) - 1 +JVM_SHORT_MIN: int = -(1 << 15) +JVM_SHORT_MAX: int = (1 << 15) - 1 +JVM_INT_MIN: int = -(1 << 31) +JVM_INT_MAX: int = (1 << 31) - 1 +JVM_LONG_MIN: int = -(1 << 63) +JVM_LONG_MAX: int = (1 << 63) - 1 + + def print_exec(stream: TextIO) -> None: ei = sys.exc_info() traceback.print_exception(ei[0], ei[1], ei[2], None, stream) From a48365dd98c9e52b5648d1cc0af203a7290cb1dc Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 23 May 2024 10:27:16 +0800 Subject: [PATCH 03/45] [SPARK-48387][SQL] Postgres: Map TimestampType to TIMESTAMP WITH TIME ZONE ### What changes were proposed in this pull request? Currently, Both TimestampType/TimestampNTZType are mapped to TIMESTAMP WITHOUT TIME ZONE for writing while being differentiated for reading. In this PR, we map TimestampType to TIMESTAMP WITH TIME ZONE to differentiate TimestampType/TimestampNTZType for writing against Postgres. ### Why are the changes needed? TimestampType <-> TIMESTAMP WITHOUT TIME ZONE is incorrect and ambiguous with TimestampNTZType ### Does this PR introduce _any_ user-facing change? Yes migration guide and legacy configuration provided ### How was this patch tested? new tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #46701 from yaooqinn/SPARK-48387. 
Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../sql/jdbc/PostgresIntegrationSuite.scala | 46 +++++++++++++++++++ docs/sql-data-sources-jdbc.md | 4 +- docs/sql-migration-guide.md | 3 +- .../apache/spark/sql/internal/SQLConf.scala | 14 ++++++ .../spark/sql/jdbc/PostgresDialect.scala | 6 ++- 5 files changed, 68 insertions(+), 5 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index dd6f1bfd3b3f4..5ad4f15216b74 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.{Column, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -583,4 +584,49 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(cause.getSQLState === "22003") } } + + test("SPARK-48387: Timestamp write as timestamp with time zone") { + val df = spark.sql("select TIMESTAMP '2018-11-17 13:33:33' as col0") + // write timestamps for preparation + withSQLConf(SQLConf.LEGACY_POSTGRES_DATETIME_MAPPING_ENABLED.key -> "false") { + // write timestamp as timestamp with time zone + df.write.jdbc(jdbcUrl, "ts_with_timezone_copy_false", new Properties) + } + withSQLConf(SQLConf.LEGACY_POSTGRES_DATETIME_MAPPING_ENABLED.key -> "true") { + // write timestamp as timestamp without time zone + df.write.jdbc(jdbcUrl, "ts_with_timezone_copy_true", new Properties) + } + + // read timestamps for test + withSQLConf(SQLConf.LEGACY_POSTGRES_DATETIME_MAPPING_ENABLED.key -> "true") { + val df1 = spark.read.option("preferTimestampNTZ", false) + .jdbc(jdbcUrl, "ts_with_timezone_copy_false", new Properties) + checkAnswer(df1, Row(Timestamp.valueOf("2018-11-17 13:33:33"))) + val df2 = spark.read.option("preferTimestampNTZ", true) + .jdbc(jdbcUrl, "ts_with_timezone_copy_false", new Properties) + checkAnswer(df2, Row(LocalDateTime.of(2018, 11, 17, 13, 33, 33))) + + val df3 = spark.read.option("preferTimestampNTZ", false) + .jdbc(jdbcUrl, "ts_with_timezone_copy_true", new Properties) + checkAnswer(df3, Row(Timestamp.valueOf("2018-11-17 13:33:33"))) + val df4 = spark.read.option("preferTimestampNTZ", true) + .jdbc(jdbcUrl, "ts_with_timezone_copy_true", new Properties) + checkAnswer(df4, Row(LocalDateTime.of(2018, 11, 17, 13, 33, 33))) + } + withSQLConf(SQLConf.LEGACY_POSTGRES_DATETIME_MAPPING_ENABLED.key -> "false") { + Seq("true", "false").foreach { prefer => + val prop = new Properties + prop.setProperty("preferTimestampNTZ", prefer) + val dfCopy = spark.read.jdbc(jdbcUrl, "ts_with_timezone_copy_false", prop) + checkAnswer(dfCopy, Row(Timestamp.valueOf("2018-11-17 13:33:33"))) + } + + val df5 = spark.read.option("preferTimestampNTZ", false) + .jdbc(jdbcUrl, "ts_with_timezone_copy_true", new Properties) + checkAnswer(df5, Row(Timestamp.valueOf("2018-11-17 13:33:33"))) + val df6 = spark.read.option("preferTimestampNTZ", true) + .jdbc(jdbcUrl, "ts_with_timezone_copy_true", new Properties) + checkAnswer(df6, Row(LocalDateTime.of(2018, 11, 17, 13, 33, 33))) + } + } } diff --git 
a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md index 54a8506bff51e..371dc05950717 100644 --- a/docs/sql-data-sources-jdbc.md +++ b/docs/sql-data-sources-jdbc.md @@ -1074,8 +1074,8 @@ the [PostgreSQL JDBC Driver](https://mvnrepository.com/artifact/org.postgresql/p TimestampType - timestamp - + timestamp with time zone + Before Spark 4.0, it was mapped as timestamp. Please refer to the migration guide for more information TimestampNTZType diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index e668a9f9ef754..8f6a415569863 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -41,7 +41,8 @@ license: | - `spark.sql.avro.datetimeRebaseModeInRead` instead of `spark.sql.legacy.avro.datetimeRebaseModeInRead` - Since Spark 4.0, the default value of `spark.sql.orc.compression.codec` is changed from `snappy` to `zstd`. To restore the previous behavior, set `spark.sql.orc.compression.codec` to `snappy`. - Since Spark 4.0, the SQL config `spark.sql.legacy.allowZeroIndexInFormatString` is deprecated. Consider to change `strfmt` of the `format_string` function to use 1-based indexes. The first argument must be referenced by "1$", the second by "2$", etc. -- Since Spark 4.0, JDBC read option `preferTimestampNTZ=true` will not convert Postgres TIMESTAMP WITH TIME ZONE and TIME WITH TIME ZONE data types to TimestampNTZType, which is available in Spark 3.5. +- Since Spark 4.0, Postgres JDBC datasource will read JDBC read TIMESTAMP WITH TIME ZONE as TimestampType regardless of the JDBC read option `preferTimestampNTZ`, while in 3.5 and previous, TimestampNTZType when `preferTimestampNTZ=true`. To restore the previous behavior, set `spark.sql.legacy.postgres.datetimeMapping.enabled` to `true`. +- Since Spark 4.0, Postgres JDBC datasource will write TimestampType as TIMESTAMP WITH TIME ZONE, while in 3.5 and previous, it wrote as TIMESTAMP a.k.a. TIMESTAMP WITHOUT TIME ZONE. To restore the previous behavior, set `spark.sql.legacy.postgres.datetimeMapping.enabled` to `true`. - Since Spark 4.0, MySQL JDBC datasource will read TIMESTAMP as TimestampType regardless of the JDBC read option `preferTimestampNTZ`, while in 3.5 and previous, TimestampNTZType when `preferTimestampNTZ=true`. To restore the previous behavior, set `spark.sql.legacy.mysql.timestampNTZMapping.enabled` to `true`, MySQL DATETIME is not affected. - Since Spark 4.0, MySQL JDBC datasource will read SMALLINT as ShortType, while in Spark 3.5 and previous, it was read as IntegerType. MEDIUMINT UNSIGNED is read as IntegerType, while in Spark 3.5 and previous, it was read as LongType. To restore the previous behavior, you can cast the column to the old type. - Since Spark 4.0, MySQL JDBC datasource will read FLOAT as FloatType, while in Spark 3.5 and previous, it was read as DoubleType. To restore the previous behavior, you can cast the column to the old type. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 545b0a610cdfd..06e0c6eda5896 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4265,6 +4265,17 @@ object SQLConf { .booleanConf .createWithDefault(false) + val LEGACY_POSTGRES_DATETIME_MAPPING_ENABLED = + buildConf("spark.sql.legacy.postgres.datetimeMapping.enabled") + .internal() + .doc("When true, TimestampType maps to TIMESTAMP WITHOUT TIME ZONE in PostgreSQL for " + + "writing; otherwise, TIMESTAMP WITH TIME ZONE. When true, TIMESTAMP WITH TIME ZONE " + + "can be converted to TimestampNTZType when JDBC read option preferTimestampNTZ is " + + "true; otherwise, converted to TimestampType regardless of preferTimestampNTZ.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val CSV_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.csv.filterPushdown.enabled") .doc("When true, enable filter pushdown to CSV datasource.") .version("3.0.0") @@ -5410,6 +5421,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def legacyDB2BooleanMappingEnabled: Boolean = getConf(LEGACY_DB2_BOOLEAN_MAPPING_ENABLED) + def legacyPostgresDatetimeMappingEnabled: Boolean = + getConf(LEGACY_POSTGRES_DATETIME_MAPPING_ENABLED) + override def legacyTimeParserPolicy: LegacyBehaviorPolicy.Value = { LegacyBehaviorPolicy.withName(getConf(SQLConf.LEGACY_TIME_PARSER_POLICY)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index f3fb115c70575..93052a0c37b59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -61,8 +61,8 @@ private case class PostgresDialect() extends JdbcDialect with SQLConfHelper { // money type seems to be broken but one workaround is to handle it as string. // See SPARK-34333 and https://github.com/pgjdbc/pgjdbc/issues/100 Some(StringType) - case Types.TIMESTAMP - if "timestamptz".equalsIgnoreCase(typeName) => + case Types.TIMESTAMP if "timestamptz".equalsIgnoreCase(typeName) && + !conf.legacyPostgresDatetimeMappingEnabled => // timestamptz represents timestamp with time zone, currently it maps to Types.TIMESTAMP. // We need to change to Types.TIMESTAMP_WITH_TIMEZONE if the upstream changes. Some(TimestampType) @@ -149,6 +149,8 @@ private case class PostgresDialect() extends JdbcDialect with SQLConfHelper { case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT)) case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE)) case ShortType | ByteType => Some(JdbcType("SMALLINT", Types.SMALLINT)) + case TimestampType if !conf.legacyPostgresDatetimeMappingEnabled => + Some(JdbcType("TIMESTAMP WITH TIME ZONE", Types.TIMESTAMP)) case t: DecimalType => Some( JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC)) case ArrayType(et, _) if et.isInstanceOf[AtomicType] || et.isInstanceOf[ArrayType] => From 5df9a0866ae60a42d78136a21a82a0b6e58daefa Mon Sep 17 00:00:00 2001 From: panbingkun Date: Thu, 23 May 2024 10:46:08 +0800 Subject: [PATCH 04/45] [SPARK-48386][TESTS] Replace JVM assert with JUnit Assert in tests ### What changes were proposed in this pull request? The pr aims to replace `JVM assert` with `JUnit Assert` in tests. ### Why are the changes needed? 
assert() statements do not produce as useful errors when they fail, and, if they were somehow disabled, would fail to test anything. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - Manually test. - Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46698 from panbingkun/minor_assert. Authored-by: panbingkun Signed-off-by: yangjie01 --- .../EncryptedMessageWithHeaderSuite.java | 2 +- .../shuffle/RetryingBlockTransferorSuite.java | 8 ++--- .../spark/util/SparkLoggerSuiteBase.java | 30 ++++++++++--------- .../spark/sql/TestStatefulProcessor.java | 10 ++++--- ...TestStatefulProcessorWithInitialState.java | 4 ++- .../JavaAdvancedDataSourceV2WithV2Filter.java | 14 +++++---- 6 files changed, 38 insertions(+), 30 deletions(-) diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/EncryptedMessageWithHeaderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/EncryptedMessageWithHeaderSuite.java index 7478fa1db7113..2865d411bf673 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/protocol/EncryptedMessageWithHeaderSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/protocol/EncryptedMessageWithHeaderSuite.java @@ -116,7 +116,7 @@ public void testChunkedStream() throws Exception { // Validate we read data correctly assertEquals(bodyResult.readableBytes(), chunkSize); - assert(bodyResult.readableBytes() < (randomData.length - readIndex)); + assertTrue(bodyResult.readableBytes() < (randomData.length - readIndex)); while (bodyResult.readableBytes() > 0) { assertEquals(bodyResult.readByte(), randomData[readIndex++]); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java index 3725973ae7333..84c8b1b3353f2 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java @@ -288,7 +288,7 @@ public void testRetryOnSaslTimeout() throws IOException, InterruptedException { verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0); verify(listener).getTransferType(); verifyNoMoreInteractions(listener); - assert(_retryingBlockTransferor.getRetryCount() == 0); + assertEquals(0, _retryingBlockTransferor.getRetryCount()); } @Test @@ -310,7 +310,7 @@ public void testRepeatedSaslRetryFailures() throws IOException, InterruptedExcep verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException); verify(listener, times(3)).getTransferType(); verifyNoMoreInteractions(listener); - assert(_retryingBlockTransferor.getRetryCount() == MAX_RETRIES); + assertEquals(MAX_RETRIES, _retryingBlockTransferor.getRetryCount()); } @Test @@ -339,7 +339,7 @@ public void testBlockTransferFailureAfterSasl() throws IOException, InterruptedE // This should be equal to 1 because after the SASL exception is retried, // retryCount should be set back to 0. Then after that b1 encounters an // exception that is retried. 
- assert(_retryingBlockTransferor.getRetryCount() == 1); + assertEquals(1, _retryingBlockTransferor.getRetryCount()); } @Test @@ -368,7 +368,7 @@ public void testIOExceptionFailsConnectionEvenWithSaslException() verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslExceptionFinal); verify(listener, atLeastOnce()).getTransferType(); verifyNoMoreInteractions(listener); - assert(_retryingBlockTransferor.getRetryCount() == MAX_RETRIES); + assertEquals(MAX_RETRIES, _retryingBlockTransferor.getRetryCount()); } @Test diff --git a/common/utils/src/test/java/org/apache/spark/util/SparkLoggerSuiteBase.java b/common/utils/src/test/java/org/apache/spark/util/SparkLoggerSuiteBase.java index 46bfe3415080d..0869f9827324d 100644 --- a/common/utils/src/test/java/org/apache/spark/util/SparkLoggerSuiteBase.java +++ b/common/utils/src/test/java/org/apache/spark/util/SparkLoggerSuiteBase.java @@ -30,6 +30,8 @@ import org.apache.spark.internal.LogKeys; import org.apache.spark.internal.MDC; +import static org.junit.jupiter.api.Assertions.assertTrue; + public abstract class SparkLoggerSuiteBase { abstract SparkLogger logger(); @@ -104,8 +106,8 @@ public void testBasicMsgLogger() { Pair.of(Level.DEBUG, debugFn), Pair.of(Level.TRACE, traceFn)).forEach(pair -> { try { - assert (captureLogOutput(pair.getRight()).matches( - expectedPatternForBasicMsg(pair.getLeft()))); + assertTrue(captureLogOutput(pair.getRight()).matches( + expectedPatternForBasicMsg(pair.getLeft()))); } catch (IOException e) { throw new RuntimeException(e); } @@ -127,8 +129,8 @@ public void testBasicLoggerWithException() { Pair.of(Level.DEBUG, debugFn), Pair.of(Level.TRACE, traceFn)).forEach(pair -> { try { - assert (captureLogOutput(pair.getRight()).matches( - expectedPatternForBasicMsgWithException(pair.getLeft()))); + assertTrue(captureLogOutput(pair.getRight()).matches( + expectedPatternForBasicMsgWithException(pair.getLeft()))); } catch (IOException e) { throw new RuntimeException(e); } @@ -147,8 +149,8 @@ public void testLoggerWithMDC() { Pair.of(Level.WARN, warnFn), Pair.of(Level.INFO, infoFn)).forEach(pair -> { try { - assert (captureLogOutput(pair.getRight()).matches( - expectedPatternForMsgWithMDC(pair.getLeft()))); + assertTrue(captureLogOutput(pair.getRight()).matches( + expectedPatternForMsgWithMDC(pair.getLeft()))); } catch (IOException e) { throw new RuntimeException(e); } @@ -165,8 +167,8 @@ public void testLoggerWithMDCs() { Pair.of(Level.WARN, warnFn), Pair.of(Level.INFO, infoFn)).forEach(pair -> { try { - assert (captureLogOutput(pair.getRight()).matches( - expectedPatternForMsgWithMDCs(pair.getLeft()))); + assertTrue(captureLogOutput(pair.getRight()).matches( + expectedPatternForMsgWithMDCs(pair.getLeft()))); } catch (IOException e) { throw new RuntimeException(e); } @@ -184,8 +186,8 @@ public void testLoggerWithMDCsAndException() { Pair.of(Level.WARN, warnFn), Pair.of(Level.INFO, infoFn)).forEach(pair -> { try { - assert (captureLogOutput(pair.getRight()).matches( - expectedPatternForMsgWithMDCsAndException(pair.getLeft()))); + assertTrue(captureLogOutput(pair.getRight()).matches( + expectedPatternForMsgWithMDCsAndException(pair.getLeft()))); } catch (IOException e) { throw new RuntimeException(e); } @@ -202,8 +204,8 @@ public void testLoggerWithMDCValueIsNull() { Pair.of(Level.WARN, warnFn), Pair.of(Level.INFO, infoFn)).forEach(pair -> { try { - assert (captureLogOutput(pair.getRight()).matches( - expectedPatternForMsgWithMDCValueIsNull(pair.getLeft()))); + assertTrue(captureLogOutput(pair.getRight()).matches( + 
expectedPatternForMsgWithMDCValueIsNull(pair.getLeft()))); } catch (IOException e) { throw new RuntimeException(e); } @@ -220,8 +222,8 @@ public void testLoggerWithExternalSystemCustomLogKey() { Pair.of(Level.WARN, warnFn), Pair.of(Level.INFO, infoFn)).forEach(pair -> { try { - assert (captureLogOutput(pair.getRight()).matches( - expectedPatternForExternalSystemCustomLogKey(pair.getLeft()))); + assertTrue(captureLogOutput(pair.getRight()).matches( + expectedPatternForExternalSystemCustomLogKey(pair.getLeft()))); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java index e53e977da1494..b9841ee0f9735 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java @@ -24,6 +24,8 @@ import org.apache.spark.sql.Encoders; import org.apache.spark.sql.streaming.*; +import static org.junit.jupiter.api.Assertions.*; + /** * A test stateful processor used with transformWithState arbitrary stateful operator in * Structured Streaming. The processor primarily aims to test various functionality of the Java API @@ -74,7 +76,7 @@ public scala.collection.Iterator handleInputRows( } else { keyCountMap.updateValue(value, 1L); } - assert(keyCountMap.containsKey(value)); + assertTrue(keyCountMap.containsKey(value)); keysList.appendValue(value); sb.append(value); } @@ -82,13 +84,13 @@ public scala.collection.Iterator handleInputRows( scala.collection.Iterator keys = keysList.get(); while (keys.hasNext()) { String keyVal = keys.next(); - assert(keyCountMap.containsKey(keyVal)); - assert(keyCountMap.getValue(keyVal) > 0); + assertTrue(keyCountMap.containsKey(keyVal)); + assertTrue(keyCountMap.getValue(keyVal) > 0); } count += numRows; countState.update(count); - assert (countState.get() == count); + assertEquals(count, (long) countState.get()); result.add(sb.toString()); } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java index bfa542e81e354..55046a7c0d3df 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java @@ -24,6 +24,8 @@ import org.apache.spark.sql.Encoders; import org.apache.spark.sql.streaming.*; +import static org.junit.jupiter.api.Assertions.assertFalse; + /** * A test stateful processor concatenates all input rows for a key and emits the result. 
* Primarily used for testing the Java API for arbitrary stateful operator in structured streaming @@ -71,7 +73,7 @@ public scala.collection.Iterator handleInputRows( } testState.clear(); - assert(testState.exists() == false); + assertFalse(testState.exists()); result.add(sb.toString()); } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java index 0e3f6aed3b681..07bef16cdf2da 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java @@ -34,6 +34,8 @@ import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + public class JavaAdvancedDataSourceV2WithV2Filter implements TestingV2Source { @Override @@ -66,9 +68,9 @@ public StructType readSchema() { public Predicate[] pushPredicates(Predicate[] predicates) { Predicate[] supported = Arrays.stream(predicates).filter(f -> { if (f.name().equals(">")) { - assert(f.children()[0] instanceof FieldReference); + assertInstanceOf(FieldReference.class, f.children()[0]); FieldReference column = (FieldReference) f.children()[0]; - assert(f.children()[1] instanceof LiteralValue); + assertInstanceOf(LiteralValue.class, f.children()[1]); Literal value = (Literal) f.children()[1]; return column.describe().equals("i") && value.value() instanceof Integer; } else { @@ -78,9 +80,9 @@ public Predicate[] pushPredicates(Predicate[] predicates) { Predicate[] unsupported = Arrays.stream(predicates).filter(f -> { if (f.name().equals(">")) { - assert(f.children()[0] instanceof FieldReference); + assertInstanceOf(FieldReference.class, f.children()[0]); FieldReference column = (FieldReference) f.children()[0]; - assert(f.children()[1] instanceof LiteralValue); + assertInstanceOf(LiteralValue.class, f.children()[1]); Literal value = (LiteralValue) f.children()[1]; return !column.describe().equals("i") || !(value.value() instanceof Integer); } else { @@ -125,9 +127,9 @@ public InputPartition[] planInputPartitions() { Integer lowerBound = null; for (Predicate predicate : predicates) { if (predicate.name().equals(">")) { - assert(predicate.children()[0] instanceof FieldReference); + assertInstanceOf(FieldReference.class, predicate.children()[0]); FieldReference column = (FieldReference) predicate.children()[0]; - assert(predicate.children()[1] instanceof LiteralValue); + assertInstanceOf(LiteralValue.class, predicate.children()[1]); Literal value = (Literal) predicate.children()[1]; if ("i".equals(column.describe()) && value.value() instanceof Integer integer) { lowerBound = integer; From a393d6cdf00ce95b2a3fb4bd15bfc4d82883d1d2 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 23 May 2024 12:19:07 +0900 Subject: [PATCH 05/45] [SPARK-48370][CONNECT] Checkpoint and localCheckpoint in Scala Spark Connect client ### What changes were proposed in this pull request? This PR adds `Dataset.checkpoint` and `Dataset.localCheckpoint` into Scala Spark Connect client. Python API was implemented at https://github.com/apache/spark/pull/46570 ### Why are the changes needed? For API parity. ### Does this PR introduce _any_ user-facing change? Yes, it adds `Dataset.checkpoint` and `Dataset.localCheckpoint` into Scala Spark Connect client. 
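For example, the new methods can be used like their existing Dataset counterparts. The snippet below is a minimal sketch rather than part of this patch: the connection string is a placeholder, and the `explain`/`count` calls are only there to show the effect.

```scala
import org.apache.spark.sql.SparkSession

// Assumes a Spark Connect server is reachable at the given (placeholder) address.
val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()

// Eager local checkpoint: truncates the logical plan and caches the result on the
// server as a CachedRemoteRelation that later operations on `df` reuse.
val df = spark.range(100).localCheckpoint()
df.explain() // prints the checkpointed plan (ExistingRDD) instead of the original Range

// Reliable checkpoint; with eager = false it is materialized on the first action and
// requires a checkpoint directory to be configured on the server.
val df2 = spark.range(100).checkpoint(eager = false)
df2.count()
```

The server-side cached relation backing a checkpointed Dataset is released by the new `SessionCleaner` once it is no longer referenced on the client.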
### How was this patch tested? Unittests added. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46683 from HyukjinKwon/SPARK-48370. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../scala/org/apache/spark/sql/Dataset.scala | 107 +++++++++++-- .../org/apache/spark/sql/SparkSession.scala | 10 +- .../spark/sql/internal/SessionCleaner.scala | 146 ++++++++++++++++++ .../apache/spark/sql/CheckpointSuite.scala | 117 ++++++++++++++ .../CheckConnectJvmClientCompatibility.scala | 10 ++ .../connect/client/SparkConnectClient.scala | 2 +- 6 files changed, 379 insertions(+), 13 deletions(-) create mode 100644 connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala create mode 100644 connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 37f770319b695..fc9766357cb22 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3402,20 +3402,105 @@ class Dataset[T] private[sql] ( df } - def checkpoint(): Dataset[T] = { - throw new UnsupportedOperationException("checkpoint is not implemented.") - } + /** + * Eagerly checkpoint a Dataset and return the new Dataset. Checkpointing can be used to + * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms + * where the plan may grow exponentially. It will be saved to files inside the checkpoint + * directory set with `SparkContext#setCheckpointDir`. + * + * @group basic + * @since 4.0.0 + */ + def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true) - def checkpoint(eager: Boolean): Dataset[T] = { - throw new UnsupportedOperationException("checkpoint is not implemented.") - } + /** + * Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the + * logical plan of this Dataset, which is especially useful in iterative algorithms where the + * plan may grow exponentially. It will be saved to files inside the checkpoint directory set + * with `SparkContext#setCheckpointDir`. + * + * @param eager + * Whether to checkpoint this dataframe immediately + * + * @note + * When checkpoint is used with eager = false, the final data that is checkpointed after the + * first action may be different from the data that was used during the job due to + * non-determinism of the underlying operation and retries. If checkpoint is used to achieve + * saving a deterministic snapshot of the data, eager = true should be used. Otherwise, it is + * only deterministic after the first execution, after the checkpoint was finalized. + * + * @group basic + * @since 4.0.0 + */ + def checkpoint(eager: Boolean): Dataset[T] = + checkpoint(eager = eager, reliableCheckpoint = true) - def localCheckpoint(): Dataset[T] = { - throw new UnsupportedOperationException("localCheckpoint is not implemented.") - } + /** + * Eagerly locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used + * to truncate the logical plan of this Dataset, which is especially useful in iterative + * algorithms where the plan may grow exponentially. Local checkpoints are written to executor + * storage and despite potentially faster they are unreliable and may compromise job completion. 
+ * + * @group basic + * @since 4.0.0 + */ + def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false) + + /** + * Locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used to + * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms + * where the plan may grow exponentially. Local checkpoints are written to executor storage and + * despite potentially faster they are unreliable and may compromise job completion. + * + * @param eager + * Whether to checkpoint this dataframe immediately + * + * @note + * When checkpoint is used with eager = false, the final data that is checkpointed after the + * first action may be different from the data that was used during the job due to + * non-determinism of the underlying operation and retries. If checkpoint is used to achieve + * saving a deterministic snapshot of the data, eager = true should be used. Otherwise, it is + * only deterministic after the first execution, after the checkpoint was finalized. + * + * @group basic + * @since 4.0.0 + */ + def localCheckpoint(eager: Boolean): Dataset[T] = + checkpoint(eager = eager, reliableCheckpoint = false) - def localCheckpoint(eager: Boolean): Dataset[T] = { - throw new UnsupportedOperationException("localCheckpoint is not implemented.") + /** + * Returns a checkpointed version of this Dataset. + * + * @param eager + * Whether to checkpoint this dataframe immediately + * @param reliableCheckpoint + * Whether to create a reliable checkpoint saved to files inside the checkpoint directory. If + * false creates a local checkpoint using the caching subsystem + */ + private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = { + sparkSession.newDataset(agnosticEncoder) { builder => + val command = sparkSession.newCommand { builder => + builder.getCheckpointCommandBuilder + .setLocal(reliableCheckpoint) + .setEager(eager) + .setRelation(this.plan.getRoot) + } + val responseIter = sparkSession.execute(command) + try { + val response = responseIter + .find(_.hasCheckpointCommandResult) + .getOrElse(throw new RuntimeException("CheckpointCommandResult must be present")) + + val cachedRemoteRelation = response.getCheckpointCommandResult.getRelation + sparkSession.cleaner.registerCachedRemoteRelationForCleanup(cachedRemoteRelation) + + // Update the builder with the values from the result. 
+ builder.setCachedRemoteRelation(cachedRemoteRelation) + } finally { + // consume the rest of the iterator + responseIter.foreach(_ => ()) + } + } } /** diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 1188fba60a2fe..91ee0f52e8bd0 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, Spar import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.connect.client.arrow.ArrowSerializer import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.internal.{CatalogImpl, SqlApiConf} +import org.apache.spark.sql.internal.{CatalogImpl, SessionCleaner, SqlApiConf} import org.apache.spark.sql.streaming.DataStreamReader import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.types.StructType @@ -73,6 +73,11 @@ class SparkSession private[sql] ( with Logging { private[this] val allocator = new RootAllocator() + private var shouldStopCleaner = false + private[sql] lazy val cleaner = { + shouldStopCleaner = true + new SessionCleaner(this) + } // a unique session ID for this session from client. private[sql] def sessionId: String = client.sessionId @@ -714,6 +719,9 @@ class SparkSession private[sql] ( if (releaseSessionOnClose) { client.releaseSession() } + if (shouldStopCleaner) { + cleaner.stop() + } client.shutdown() allocator.close() SparkSession.onSessionClose(this) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala new file mode 100644 index 0000000000000..036ea4a84fa97 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.lang.ref.{ReferenceQueue, WeakReference} +import java.util.Collections +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.connect.proto +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession + +/** + * Classes that represent cleaning tasks. + */ +private sealed trait CleanupTask +private case class CleanupCachedRemoteRelation(dfID: String) extends CleanupTask + +/** + * A WeakReference associated with a CleanupTask. 
+ * + * When the referent object becomes only weakly reachable, the corresponding + * CleanupTaskWeakReference is automatically added to the given reference queue. + */ +private class CleanupTaskWeakReference( + val task: CleanupTask, + referent: AnyRef, + referenceQueue: ReferenceQueue[AnyRef]) + extends WeakReference(referent, referenceQueue) + +/** + * An asynchronous cleaner for objects. + * + * This maintains a weak reference for each CashRemoteRelation, etc. of interest, to be processed + * when the associated object goes out of scope of the application. Actual cleanup is performed in + * a separate daemon thread. + */ +private[sql] class SessionCleaner(session: SparkSession) extends Logging { + + /** + * How often (seconds) to trigger a garbage collection in this JVM. This context cleaner + * triggers cleanups only when weak references are garbage collected. In long-running + * applications with large driver JVMs, where there is little memory pressure on the driver, + * this may happen very occasionally or not at all. Not cleaning at all may lead to executors + * running out of disk space after a while. + */ + private val refQueuePollTimeout: Long = 100 + + /** + * A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they + * have not been handled by the reference queue. + */ + private val referenceBuffer = + Collections.newSetFromMap[CleanupTaskWeakReference](new ConcurrentHashMap) + + private val referenceQueue = new ReferenceQueue[AnyRef] + + private val cleaningThread = new Thread() { override def run(): Unit = keepCleaning() } + + @volatile private var started = false + @volatile private var stopped = false + + /** Start the cleaner. */ + def start(): Unit = { + cleaningThread.setDaemon(true) + cleaningThread.setName("Spark Connect Context Cleaner") + cleaningThread.start() + } + + /** + * Stop the cleaning thread and wait until the thread has finished running its current task. + */ + def stop(): Unit = { + stopped = true + // Interrupt the cleaning thread, but wait until the current task has finished before + // doing so. This guards against the race condition where a cleaning thread may + // potentially clean similarly named variables created by a different SparkSession. + synchronized { + cleaningThread.interrupt() + } + cleaningThread.join() + } + + /** Register a CachedRemoteRelation for cleanup when it is garbage collected. */ + def registerCachedRemoteRelationForCleanup(relation: proto.CachedRemoteRelation): Unit = { + registerForCleanup(relation, CleanupCachedRemoteRelation(relation.getRelationId)) + } + + /** Register an object for cleanup. */ + private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = { + if (!started) { + // Lazily starts when the first cleanup is registered. + start() + started = true + } + referenceBuffer.add(new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)) + } + + /** Keep cleaning objects. 
*/ + private def keepCleaning(): Unit = { + while (!stopped && !session.client.channel.isShutdown) { + try { + val reference = Option(referenceQueue.remove(refQueuePollTimeout)) + .map(_.asInstanceOf[CleanupTaskWeakReference]) + // Synchronize here to avoid being interrupted on stop() + synchronized { + reference.foreach { ref => + logDebug("Got cleaning task " + ref.task) + referenceBuffer.remove(ref) + ref.task match { + case CleanupCachedRemoteRelation(dfID) => + doCleanupCachedRemoteRelation(dfID) + } + } + } + } catch { + case e: Throwable => logError("Error in cleaning thread", e) + } + } + } + + /** Perform CleanupCachedRemoteRelation cleanup. */ + private[spark] def doCleanupCachedRemoteRelation(dfID: String): Unit = { + session.execute { + session.newCommand { builder => + builder.getRemoveCachedRemoteRelationCommandBuilder + .setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfID).build()) + } + } + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala new file mode 100644 index 0000000000000..e57b051890f56 --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.spark.sql
+
+import java.io.{ByteArrayOutputStream, PrintStream}
+
+import scala.concurrent.duration.DurationInt
+
+import org.apache.commons.io.output.TeeOutputStream
+import org.scalatest.concurrent.Eventually.{eventually, interval, timeout}
+import org.scalatest.exceptions.TestFailedDueToTimeoutException
+
+import org.apache.spark.SparkException
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper}
+
+class CheckpointSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelper {
+
+  private def captureStdOut(block: => Unit): String = {
+    val currentOut = Console.out
+    val capturedOut = new ByteArrayOutputStream()
+    val newOut = new PrintStream(new TeeOutputStream(currentOut, capturedOut))
+    Console.withOut(newOut) {
+      block
+    }
+    capturedOut.toString
+  }
+
+  private def checkFragments(result: String, fragmentsToCheck: Seq[String]): Unit = {
+    fragmentsToCheck.foreach { fragment =>
+      assert(result.contains(fragment))
+    }
+  }
+
+  private def testCapturedStdOut(block: => Unit, fragmentsToCheck: String*): Unit = {
+    checkFragments(captureStdOut(block), fragmentsToCheck)
+  }
+
+  test("checkpoint") {
+    val df = spark.range(100).localCheckpoint()
+    testCapturedStdOut(df.explain(), "ExistingRDD")
+  }
+
+  test("checkpoint gc") {
+    val df = spark.range(100).localCheckpoint(eager = true)
+    val encoder = df.agnosticEncoder
+    val dfId = df.plan.getRoot.getCachedRemoteRelation.getRelationId
+    spark.cleaner.doCleanupCachedRemoteRelation(dfId)
+
+    val ex = intercept[SparkException] {
+      spark
+        .newDataset(encoder) { builder =>
+          builder.setCachedRemoteRelation(
+            proto.CachedRemoteRelation
+              .newBuilder()
+              .setRelationId(dfId)
+              .build())
+        }
+        .collect()
+    }
+    assert(ex.getMessage.contains(s"No DataFrame with id $dfId is found"))
+  }
+
+  // This test is flaky because GC cannot be guaranteed.
+  // You can run this locally to verify the behavior.
+  ignore("checkpoint gc derived DataFrame") {
+    var df1 = spark.range(100).localCheckpoint(eager = true)
+    var derived = df1.repartition(10)
+    val encoder = df1.agnosticEncoder
+    val dfId = df1.plan.getRoot.getCachedRemoteRelation.getRelationId
+
+    df1 = null
+    System.gc()
+    Thread.sleep(3000L)
+
+    def condition(): Unit = {
+      val ex = intercept[SparkException] {
+        spark
+          .newDataset(encoder) { builder =>
+            builder.setCachedRemoteRelation(
+              proto.CachedRemoteRelation
+                .newBuilder()
+                .setRelationId(dfId)
+                .build())
+          }
+          .collect()
+      }
+      assert(ex.getMessage.contains(s"No DataFrame with id $dfId is found"))
+    }
+
+    intercept[TestFailedDueToTimeoutException] {
+      eventually(timeout(5.seconds), interval(1.second))(condition())
+    }
+
+    // GC triggers removal of the cached remote relation.
+    derived = null
+    System.gc()
+    Thread.sleep(3000L)
+
+    // Check that the state was removed upon garbage collection.
+ eventually(timeout(60.seconds), interval(1.second))(condition()) + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 374d8464deebf..2e4bbab8d3a41 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -334,6 +334,16 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[ReversedMissingMethodProblem]( "org.apache.spark.sql.SQLImplicits._sqlContext" // protected ), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.internal.SessionCleaner"), + + // private + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.internal.CleanupTask"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.internal.CleanupTaskWeakReference"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.internal.CleanupCachedRemoteRelation"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.internal.CleanupCachedRemoteRelation$"), // Catalyst Refactoring ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.util.SparkCollectionUtils"), diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index 1e7b4e6574ddb..b5eda024bfb3c 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.connect.common.config.ConnectCommon */ private[sql] class SparkConnectClient( private[sql] val configuration: SparkConnectClient.Configuration, - private val channel: ManagedChannel) { + private[sql] val channel: ManagedChannel) { private val userContext: UserContext = configuration.userContext From e8f58a9c4a641b830c5304b34b876e0cd5d3ed8e Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 23 May 2024 13:50:34 +0900 Subject: [PATCH 06/45] [SPARK-48370][SPARK-48258][CONNECT][PYTHON][FOLLOW-UP] Refactor local and eager required fields in CheckpointCommand ### What changes were proposed in this pull request? This PR is a followup of https://github.com/apache/spark/pull/46683 and https://github.com/apache/spark/pull/46570 that refactors `local` and `eager` required fields in `CheckpointCommand` ### Why are the changes needed? To make the code easier to maintain. ### Does this PR introduce _any_ user-facing change? No, the main change has not been released yet. ### How was this patch tested? Manually tested. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46712 from HyukjinKwon/SPARK-48370-SPARK-48258-followup. 
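For context, here is a minimal sketch (not part of this patch) of how the two now-required flags are driven from the Python DataFrame API, assuming an active Spark Connect session named `spark`:

```python
# Sketch only: illustrates which CheckpointCommand flags each API call sets
# after this refactor; both `local` and `eager` are always populated.
df = spark.range(10)

# checkpoint() sends CheckpointCommand(local=False, eager=<eager>)
reliable = df.checkpoint(eager=True)

# localCheckpoint() sends CheckpointCommand(local=True, eager=<eager>),
# i.e. a local checkpoint using the caching subsystem on the driver.
local = df.localCheckpoint(eager=True)
```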
Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../protobuf/spark/connect/commands.proto | 8 ++-- .../connect/planner/SparkConnectPlanner.scala | 12 ++---- python/pyspark/sql/connect/dataframe.py | 2 +- .../pyspark/sql/connect/proto/commands_pb2.py | 10 ++--- .../sql/connect/proto/commands_pb2.pyi | 41 +++---------------- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- 7 files changed, 21 insertions(+), 56 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index fc9766357cb22..5ac07270b22b3 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3481,7 +3481,7 @@ class Dataset[T] private[sql] ( sparkSession.newDataset(agnosticEncoder) { builder => val command = sparkSession.newCommand { builder => builder.getCheckpointCommandBuilder - .setLocal(reliableCheckpoint) + .setLocal(!reliableCheckpoint) .setEager(eager) .setRelation(this.plan.getRoot) } diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto index c526f8d3f65d4..0e0c55fa34f00 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto @@ -497,10 +497,10 @@ message CheckpointCommand { // (Required) The logical plan to checkpoint. Relation relation = 1; - // (Optional) Locally checkpoint using a local temporary + // (Required) Locally checkpoint using a local temporary // directory in Spark Connect server (Spark Driver) - optional bool local = 2; + bool local = 2; - // (Optional) Whether to checkpoint this dataframe immediately. - optional bool eager = 3; + // (Required) Whether to checkpoint this dataframe immediately. 
+ bool eager = 3; } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index cbc60d2873f91..a339469e61cdf 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -3523,15 +3523,9 @@ class SparkConnectPlanner( responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = { val target = Dataset .ofRows(session, transformRelation(checkpointCommand.getRelation)) - val checkpointed = if (checkpointCommand.hasLocal && checkpointCommand.hasEager) { - target.localCheckpoint(eager = checkpointCommand.getEager) - } else if (checkpointCommand.hasLocal) { - target.localCheckpoint() - } else if (checkpointCommand.hasEager) { - target.checkpoint(eager = checkpointCommand.getEager) - } else { - target.checkpoint() - } + val checkpointed = target.checkpoint( + eager = checkpointCommand.getEager, + reliableCheckpoint = !checkpointCommand.getLocal) val dfId = UUID.randomUUID().toString logInfo(log"Caching DataFrame with id ${MDC(DATAFRAME_ID, dfId)}") diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 510776bb752d3..62c73da374bc9 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -2096,7 +2096,7 @@ def offset(self, n: int) -> ParentDataFrame: return DataFrame(plan.Offset(child=self._plan, offset=n), session=self._session) def checkpoint(self, eager: bool = True) -> "DataFrame": - cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager) + cmd = plan.Checkpoint(child=self._plan, local=False, eager=eager) _, properties = self._session.client.execute_command(cmd.command(self._session.client)) assert "checkpoint_command_result" in properties checkpointed = properties["checkpoint_command_result"] diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py index 43673d9707a9b..8f67f817c3f00 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.py +++ b/python/pyspark/sql/connect/proto/commands_pb2.py @@ -35,7 +35,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\xaf\x0c\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\x18\x04 \x01(\x0b\x32\x1f.spark.connect.WriteOperationV2H\x00R\x10writeOperationV2\x12<\n\x0bsql_command\x18\x05 \x01(\x0b\x32\x19.spark.connect.SqlCommandH\x00R\nsqlCommand\x12k\n\x1cwrite_stream_operation_start\x18\x06 \x01(\x0b\x32(.spark.connect.WriteStreamOperationStartH\x00R\x19writeStreamOperationStart\x12^\n\x17streaming_query_command\x18\x07 \x01(\x0b\x32$.spark.connect.StreamingQueryCommandH\x00R\x15streamingQueryCommand\x12X\n\x15get_resources_command\x18\x08 
\x01(\x0b\x32".spark.connect.GetResourcesCommandH\x00R\x13getResourcesCommand\x12t\n\x1fstreaming_query_manager_command\x18\t \x01(\x0b\x32+.spark.connect.StreamingQueryManagerCommandH\x00R\x1cstreamingQueryManagerCommand\x12m\n\x17register_table_function\x18\n \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R\x15registerTableFunction\x12\x81\x01\n$streaming_query_listener_bus_command\x18\x0b \x01(\x0b\x32/.spark.connect.StreamingQueryListenerBusCommandH\x00R streamingQueryListenerBusCommand\x12\x64\n\x14register_data_source\x18\x0c \x01(\x0b\x32\x30.spark.connect.CommonInlineUserDefinedDataSourceH\x00R\x12registerDataSource\x12t\n\x1f\x63reate_resource_profile_command\x18\r \x01(\x0b\x32+.spark.connect.CreateResourceProfileCommandH\x00R\x1c\x63reateResourceProfileCommand\x12Q\n\x12\x63heckpoint_command\x18\x0e \x01(\x0b\x32 .spark.connect.CheckpointCommandH\x00R\x11\x63heckpointCommand\x12\x84\x01\n%remove_cached_remote_relation_command\x18\x0f \x01(\x0b\x32\x30.spark.connect.RemoveCachedRemoteRelationCommandH\x00R!removeCachedRemoteRelationCommand\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x0e\n\x0c\x63ommand_type"\xaa\x04\n\nSqlCommand\x12\x14\n\x03sql\x18\x01 \x01(\tB\x02\x18\x01R\x03sql\x12;\n\x04\x61rgs\x18\x02 \x03(\x0b\x32#.spark.connect.SqlCommand.ArgsEntryB\x02\x18\x01R\x04\x61rgs\x12@\n\x08pos_args\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralB\x02\x18\x01R\x07posArgs\x12Z\n\x0fnamed_arguments\x18\x04 \x03(\x0b\x32-.spark.connect.SqlCommand.NamedArgumentsEntryB\x02\x18\x01R\x0enamedArguments\x12\x42\n\rpos_arguments\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionB\x02\x18\x01R\x0cposArguments\x12-\n\x05input\x18\x06 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01\x1a\\\n\x13NamedArgumentsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value:\x02\x38\x01"\x96\x01\n\x1a\x43reateDataFrameViewCommand\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x1b\n\tis_global\x18\x03 \x01(\x08R\x08isGlobal\x12\x18\n\x07replace\x18\x04 \x01(\x08R\x07replace"\xca\x08\n\x0eWriteOperation\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1b\n\x06source\x18\x02 \x01(\tH\x01R\x06source\x88\x01\x01\x12\x14\n\x04path\x18\x03 \x01(\tH\x00R\x04path\x12?\n\x05table\x18\x04 \x01(\x0b\x32\'.spark.connect.WriteOperation.SaveTableH\x00R\x05table\x12:\n\x04mode\x18\x05 \x01(\x0e\x32&.spark.connect.WriteOperation.SaveModeR\x04mode\x12*\n\x11sort_column_names\x18\x06 \x03(\tR\x0fsortColumnNames\x12\x31\n\x14partitioning_columns\x18\x07 \x03(\tR\x13partitioningColumns\x12\x43\n\tbucket_by\x18\x08 \x01(\x0b\x32&.spark.connect.WriteOperation.BucketByR\x08\x62ucketBy\x12\x44\n\x07options\x18\t \x03(\x0b\x32*.spark.connect.WriteOperation.OptionsEntryR\x07options\x12-\n\x12\x63lustering_columns\x18\n \x03(\tR\x11\x63lusteringColumns\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x82\x02\n\tSaveTable\x12\x1d\n\ntable_name\x18\x01 \x01(\tR\ttableName\x12X\n\x0bsave_method\x18\x02 
\x01(\x0e\x32\x37.spark.connect.WriteOperation.SaveTable.TableSaveMethodR\nsaveMethod"|\n\x0fTableSaveMethod\x12!\n\x1dTABLE_SAVE_METHOD_UNSPECIFIED\x10\x00\x12#\n\x1fTABLE_SAVE_METHOD_SAVE_AS_TABLE\x10\x01\x12!\n\x1dTABLE_SAVE_METHOD_INSERT_INTO\x10\x02\x1a[\n\x08\x42ucketBy\x12.\n\x13\x62ucket_column_names\x18\x01 \x03(\tR\x11\x62ucketColumnNames\x12\x1f\n\x0bnum_buckets\x18\x02 \x01(\x05R\nnumBuckets"\x89\x01\n\x08SaveMode\x12\x19\n\x15SAVE_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10SAVE_MODE_APPEND\x10\x01\x12\x17\n\x13SAVE_MODE_OVERWRITE\x10\x02\x12\x1d\n\x19SAVE_MODE_ERROR_IF_EXISTS\x10\x03\x12\x14\n\x10SAVE_MODE_IGNORE\x10\x04\x42\x0b\n\tsave_typeB\t\n\x07_source"\xdc\x06\n\x10WriteOperationV2\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\ntable_name\x18\x02 \x01(\tR\ttableName\x12\x1f\n\x08provider\x18\x03 \x01(\tH\x00R\x08provider\x88\x01\x01\x12L\n\x14partitioning_columns\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13partitioningColumns\x12\x46\n\x07options\x18\x05 \x03(\x0b\x32,.spark.connect.WriteOperationV2.OptionsEntryR\x07options\x12_\n\x10table_properties\x18\x06 \x03(\x0b\x32\x34.spark.connect.WriteOperationV2.TablePropertiesEntryR\x0ftableProperties\x12\x38\n\x04mode\x18\x07 \x01(\x0e\x32$.spark.connect.WriteOperationV2.ModeR\x04mode\x12J\n\x13overwrite_condition\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x12overwriteCondition\x12-\n\x12\x63lustering_columns\x18\t \x03(\tR\x11\x63lusteringColumns\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\x9f\x01\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0f\n\x0bMODE_CREATE\x10\x01\x12\x12\n\x0eMODE_OVERWRITE\x10\x02\x12\x1d\n\x19MODE_OVERWRITE_PARTITIONS\x10\x03\x12\x0f\n\x0bMODE_APPEND\x10\x04\x12\x10\n\x0cMODE_REPLACE\x10\x05\x12\x1a\n\x16MODE_CREATE_OR_REPLACE\x10\x06\x42\x0b\n\t_provider"\xa0\x06\n\x19WriteStreamOperationStart\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06\x66ormat\x18\x02 \x01(\tR\x06\x66ormat\x12O\n\x07options\x18\x03 \x03(\x0b\x32\x35.spark.connect.WriteStreamOperationStart.OptionsEntryR\x07options\x12:\n\x19partitioning_column_names\x18\x04 \x03(\tR\x17partitioningColumnNames\x12:\n\x18processing_time_interval\x18\x05 \x01(\tH\x00R\x16processingTimeInterval\x12%\n\ravailable_now\x18\x06 \x01(\x08H\x00R\x0c\x61vailableNow\x12\x14\n\x04once\x18\x07 \x01(\x08H\x00R\x04once\x12\x46\n\x1e\x63ontinuous_checkpoint_interval\x18\x08 \x01(\tH\x00R\x1c\x63ontinuousCheckpointInterval\x12\x1f\n\x0boutput_mode\x18\t \x01(\tR\noutputMode\x12\x1d\n\nquery_name\x18\n \x01(\tR\tqueryName\x12\x14\n\x04path\x18\x0b \x01(\tH\x01R\x04path\x12\x1f\n\ntable_name\x18\x0c \x01(\tH\x01R\ttableName\x12N\n\x0e\x66oreach_writer\x18\r \x01(\x0b\x32\'.spark.connect.StreamingForeachFunctionR\rforeachWriter\x12L\n\rforeach_batch\x18\x0e \x01(\x0b\x32\'.spark.connect.StreamingForeachFunctionR\x0c\x66oreachBatch\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07triggerB\x12\n\x10sink_destination"\xb3\x01\n\x18StreamingForeachFunction\x12\x43\n\x0fpython_function\x18\x01 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\x0epythonFunction\x12\x46\n\x0escala_function\x18\x02 
\x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\rscalaFunctionB\n\n\x08\x66unction"\xd4\x01\n\x1fWriteStreamOperationStartResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12<\n\x18query_started_event_json\x18\x03 \x01(\tH\x00R\x15queryStartedEventJson\x88\x01\x01\x42\x1b\n\x19_query_started_event_json"A\n\x18StreamingQueryInstanceId\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x15\n\x06run_id\x18\x02 \x01(\tR\x05runId"\xf8\x04\n\x15StreamingQueryCommand\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x18\n\x06status\x18\x02 \x01(\x08H\x00R\x06status\x12%\n\rlast_progress\x18\x03 \x01(\x08H\x00R\x0clastProgress\x12)\n\x0frecent_progress\x18\x04 \x01(\x08H\x00R\x0erecentProgress\x12\x14\n\x04stop\x18\x05 \x01(\x08H\x00R\x04stop\x12\x34\n\x15process_all_available\x18\x06 \x01(\x08H\x00R\x13processAllAvailable\x12O\n\x07\x65xplain\x18\x07 \x01(\x0b\x32\x33.spark.connect.StreamingQueryCommand.ExplainCommandH\x00R\x07\x65xplain\x12\x1e\n\texception\x18\x08 \x01(\x08H\x00R\texception\x12k\n\x11\x61wait_termination\x18\t \x01(\x0b\x32<.spark.connect.StreamingQueryCommand.AwaitTerminationCommandH\x00R\x10\x61waitTermination\x1a,\n\x0e\x45xplainCommand\x12\x1a\n\x08\x65xtended\x18\x01 \x01(\x08R\x08\x65xtended\x1aL\n\x17\x41waitTerminationCommand\x12"\n\ntimeout_ms\x18\x02 \x01(\x03H\x00R\ttimeoutMs\x88\x01\x01\x42\r\n\x0b_timeout_msB\t\n\x07\x63ommand"\xf5\x08\n\x1bStreamingQueryCommandResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12Q\n\x06status\x18\x02 \x01(\x0b\x32\x37.spark.connect.StreamingQueryCommandResult.StatusResultH\x00R\x06status\x12j\n\x0frecent_progress\x18\x03 \x01(\x0b\x32?.spark.connect.StreamingQueryCommandResult.RecentProgressResultH\x00R\x0erecentProgress\x12T\n\x07\x65xplain\x18\x04 \x01(\x0b\x32\x38.spark.connect.StreamingQueryCommandResult.ExplainResultH\x00R\x07\x65xplain\x12Z\n\texception\x18\x05 \x01(\x0b\x32:.spark.connect.StreamingQueryCommandResult.ExceptionResultH\x00R\texception\x12p\n\x11\x61wait_termination\x18\x06 \x01(\x0b\x32\x41.spark.connect.StreamingQueryCommandResult.AwaitTerminationResultH\x00R\x10\x61waitTermination\x1a\xaa\x01\n\x0cStatusResult\x12%\n\x0estatus_message\x18\x01 \x01(\tR\rstatusMessage\x12*\n\x11is_data_available\x18\x02 \x01(\x08R\x0fisDataAvailable\x12*\n\x11is_trigger_active\x18\x03 \x01(\x08R\x0fisTriggerActive\x12\x1b\n\tis_active\x18\x04 \x01(\x08R\x08isActive\x1aH\n\x14RecentProgressResult\x12\x30\n\x14recent_progress_json\x18\x05 \x03(\tR\x12recentProgressJson\x1a\'\n\rExplainResult\x12\x16\n\x06result\x18\x01 \x01(\tR\x06result\x1a\xc5\x01\n\x0f\x45xceptionResult\x12\x30\n\x11\x65xception_message\x18\x01 \x01(\tH\x00R\x10\x65xceptionMessage\x88\x01\x01\x12$\n\x0b\x65rror_class\x18\x02 \x01(\tH\x01R\nerrorClass\x88\x01\x01\x12$\n\x0bstack_trace\x18\x03 \x01(\tH\x02R\nstackTrace\x88\x01\x01\x42\x14\n\x12_exception_messageB\x0e\n\x0c_error_classB\x0e\n\x0c_stack_trace\x1a\x38\n\x16\x41waitTerminationResult\x12\x1e\n\nterminated\x18\x01 \x01(\x08R\nterminatedB\r\n\x0bresult_type"\xbd\x06\n\x1cStreamingQueryManagerCommand\x12\x18\n\x06\x61\x63tive\x18\x01 \x01(\x08H\x00R\x06\x61\x63tive\x12\x1d\n\tget_query\x18\x02 \x01(\tH\x00R\x08getQuery\x12|\n\x15\x61wait_any_termination\x18\x03 
\x01(\x0b\x32\x46.spark.connect.StreamingQueryManagerCommand.AwaitAnyTerminationCommandH\x00R\x13\x61waitAnyTermination\x12+\n\x10reset_terminated\x18\x04 \x01(\x08H\x00R\x0fresetTerminated\x12n\n\x0c\x61\x64\x64_listener\x18\x05 \x01(\x0b\x32I.spark.connect.StreamingQueryManagerCommand.StreamingQueryListenerCommandH\x00R\x0b\x61\x64\x64Listener\x12t\n\x0fremove_listener\x18\x06 \x01(\x0b\x32I.spark.connect.StreamingQueryManagerCommand.StreamingQueryListenerCommandH\x00R\x0eremoveListener\x12\'\n\x0elist_listeners\x18\x07 \x01(\x08H\x00R\rlistListeners\x1aO\n\x1a\x41waitAnyTerminationCommand\x12"\n\ntimeout_ms\x18\x01 \x01(\x03H\x00R\ttimeoutMs\x88\x01\x01\x42\r\n\x0b_timeout_ms\x1a\xcd\x01\n\x1dStreamingQueryListenerCommand\x12)\n\x10listener_payload\x18\x01 \x01(\x0cR\x0flistenerPayload\x12U\n\x17python_listener_payload\x18\x02 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\x15pythonListenerPayload\x88\x01\x01\x12\x0e\n\x02id\x18\x03 \x01(\tR\x02idB\x1a\n\x18_python_listener_payloadB\t\n\x07\x63ommand"\xb4\x08\n"StreamingQueryManagerCommandResult\x12X\n\x06\x61\x63tive\x18\x01 \x01(\x0b\x32>.spark.connect.StreamingQueryManagerCommandResult.ActiveResultH\x00R\x06\x61\x63tive\x12`\n\x05query\x18\x02 \x01(\x0b\x32H.spark.connect.StreamingQueryManagerCommandResult.StreamingQueryInstanceH\x00R\x05query\x12\x81\x01\n\x15\x61wait_any_termination\x18\x03 \x01(\x0b\x32K.spark.connect.StreamingQueryManagerCommandResult.AwaitAnyTerminationResultH\x00R\x13\x61waitAnyTermination\x12+\n\x10reset_terminated\x18\x04 \x01(\x08H\x00R\x0fresetTerminated\x12#\n\x0c\x61\x64\x64_listener\x18\x05 \x01(\x08H\x00R\x0b\x61\x64\x64Listener\x12)\n\x0fremove_listener\x18\x06 \x01(\x08H\x00R\x0eremoveListener\x12{\n\x0elist_listeners\x18\x07 \x01(\x0b\x32R.spark.connect.StreamingQueryManagerCommandResult.ListStreamingQueryListenerResultH\x00R\rlistListeners\x1a\x7f\n\x0c\x41\x63tiveResult\x12o\n\x0e\x61\x63tive_queries\x18\x01 \x03(\x0b\x32H.spark.connect.StreamingQueryManagerCommandResult.StreamingQueryInstanceR\ractiveQueries\x1as\n\x16StreamingQueryInstance\x12\x37\n\x02id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x02id\x12\x17\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x88\x01\x01\x42\x07\n\x05_name\x1a;\n\x19\x41waitAnyTerminationResult\x12\x1e\n\nterminated\x18\x01 \x01(\x08R\nterminated\x1aK\n\x1eStreamingQueryListenerInstance\x12)\n\x10listener_payload\x18\x01 \x01(\x0cR\x0flistenerPayload\x1a\x45\n ListStreamingQueryListenerResult\x12!\n\x0clistener_ids\x18\x01 \x03(\tR\x0blistenerIdsB\r\n\x0bresult_type"\xad\x01\n StreamingQueryListenerBusCommand\x12;\n\x19\x61\x64\x64_listener_bus_listener\x18\x01 \x01(\x08H\x00R\x16\x61\x64\x64ListenerBusListener\x12\x41\n\x1cremove_listener_bus_listener\x18\x02 \x01(\x08H\x00R\x19removeListenerBusListenerB\t\n\x07\x63ommand"\x83\x01\n\x1bStreamingQueryListenerEvent\x12\x1d\n\nevent_json\x18\x01 \x01(\tR\teventJson\x12\x45\n\nevent_type\x18\x02 \x01(\x0e\x32&.spark.connect.StreamingQueryEventTypeR\teventType"\xcc\x01\n"StreamingQueryListenerEventsResult\x12\x42\n\x06\x65vents\x18\x01 \x03(\x0b\x32*.spark.connect.StreamingQueryListenerEventR\x06\x65vents\x12\x42\n\x1blistener_bus_listener_added\x18\x02 \x01(\x08H\x00R\x18listenerBusListenerAdded\x88\x01\x01\x42\x1e\n\x1c_listener_bus_listener_added"\x15\n\x13GetResourcesCommand"\xd4\x01\n\x19GetResourcesCommandResult\x12U\n\tresources\x18\x01 \x03(\x0b\x32\x37.spark.connect.GetResourcesCommandResult.ResourcesEntryR\tresources\x1a`\n\x0eResourcesEntry\x12\x10\n\x03key\x18\x01 
\x01(\tR\x03key\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32".spark.connect.ResourceInformationR\x05value:\x02\x38\x01"X\n\x1c\x43reateResourceProfileCommand\x12\x38\n\x07profile\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ResourceProfileR\x07profile"C\n"CreateResourceProfileCommandResult\x12\x1d\n\nprofile_id\x18\x01 \x01(\x05R\tprofileId"d\n!RemoveCachedRemoteRelationCommand\x12?\n\x08relation\x18\x01 \x01(\x0b\x32#.spark.connect.CachedRemoteRelationR\x08relation"\x92\x01\n\x11\x43heckpointCommand\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x19\n\x05local\x18\x02 \x01(\x08H\x00R\x05local\x88\x01\x01\x12\x19\n\x05\x65\x61ger\x18\x03 \x01(\x08H\x01R\x05\x65\x61ger\x88\x01\x01\x42\x08\n\x06_localB\x08\n\x06_eager*\x85\x01\n\x17StreamingQueryEventType\x12\x1e\n\x1aQUERY_PROGRESS_UNSPECIFIED\x10\x00\x12\x18\n\x14QUERY_PROGRESS_EVENT\x10\x01\x12\x1a\n\x16QUERY_TERMINATED_EVENT\x10\x02\x12\x14\n\x10QUERY_IDLE_EVENT\x10\x03\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\xaf\x0c\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\x18\x04 \x01(\x0b\x32\x1f.spark.connect.WriteOperationV2H\x00R\x10writeOperationV2\x12<\n\x0bsql_command\x18\x05 \x01(\x0b\x32\x19.spark.connect.SqlCommandH\x00R\nsqlCommand\x12k\n\x1cwrite_stream_operation_start\x18\x06 \x01(\x0b\x32(.spark.connect.WriteStreamOperationStartH\x00R\x19writeStreamOperationStart\x12^\n\x17streaming_query_command\x18\x07 \x01(\x0b\x32$.spark.connect.StreamingQueryCommandH\x00R\x15streamingQueryCommand\x12X\n\x15get_resources_command\x18\x08 \x01(\x0b\x32".spark.connect.GetResourcesCommandH\x00R\x13getResourcesCommand\x12t\n\x1fstreaming_query_manager_command\x18\t \x01(\x0b\x32+.spark.connect.StreamingQueryManagerCommandH\x00R\x1cstreamingQueryManagerCommand\x12m\n\x17register_table_function\x18\n \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R\x15registerTableFunction\x12\x81\x01\n$streaming_query_listener_bus_command\x18\x0b \x01(\x0b\x32/.spark.connect.StreamingQueryListenerBusCommandH\x00R streamingQueryListenerBusCommand\x12\x64\n\x14register_data_source\x18\x0c \x01(\x0b\x32\x30.spark.connect.CommonInlineUserDefinedDataSourceH\x00R\x12registerDataSource\x12t\n\x1f\x63reate_resource_profile_command\x18\r \x01(\x0b\x32+.spark.connect.CreateResourceProfileCommandH\x00R\x1c\x63reateResourceProfileCommand\x12Q\n\x12\x63heckpoint_command\x18\x0e \x01(\x0b\x32 .spark.connect.CheckpointCommandH\x00R\x11\x63heckpointCommand\x12\x84\x01\n%remove_cached_remote_relation_command\x18\x0f \x01(\x0b\x32\x30.spark.connect.RemoveCachedRemoteRelationCommandH\x00R!removeCachedRemoteRelationCommand\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x0e\n\x0c\x63ommand_type"\xaa\x04\n\nSqlCommand\x12\x14\n\x03sql\x18\x01 \x01(\tB\x02\x18\x01R\x03sql\x12;\n\x04\x61rgs\x18\x02 \x03(\x0b\x32#.spark.connect.SqlCommand.ArgsEntryB\x02\x18\x01R\x04\x61rgs\x12@\n\x08pos_args\x18\x03 
\x03(\x0b\x32!.spark.connect.Expression.LiteralB\x02\x18\x01R\x07posArgs\x12Z\n\x0fnamed_arguments\x18\x04 \x03(\x0b\x32-.spark.connect.SqlCommand.NamedArgumentsEntryB\x02\x18\x01R\x0enamedArguments\x12\x42\n\rpos_arguments\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionB\x02\x18\x01R\x0cposArguments\x12-\n\x05input\x18\x06 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01\x1a\\\n\x13NamedArgumentsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value:\x02\x38\x01"\x96\x01\n\x1a\x43reateDataFrameViewCommand\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x1b\n\tis_global\x18\x03 \x01(\x08R\x08isGlobal\x12\x18\n\x07replace\x18\x04 \x01(\x08R\x07replace"\xca\x08\n\x0eWriteOperation\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1b\n\x06source\x18\x02 \x01(\tH\x01R\x06source\x88\x01\x01\x12\x14\n\x04path\x18\x03 \x01(\tH\x00R\x04path\x12?\n\x05table\x18\x04 \x01(\x0b\x32\'.spark.connect.WriteOperation.SaveTableH\x00R\x05table\x12:\n\x04mode\x18\x05 \x01(\x0e\x32&.spark.connect.WriteOperation.SaveModeR\x04mode\x12*\n\x11sort_column_names\x18\x06 \x03(\tR\x0fsortColumnNames\x12\x31\n\x14partitioning_columns\x18\x07 \x03(\tR\x13partitioningColumns\x12\x43\n\tbucket_by\x18\x08 \x01(\x0b\x32&.spark.connect.WriteOperation.BucketByR\x08\x62ucketBy\x12\x44\n\x07options\x18\t \x03(\x0b\x32*.spark.connect.WriteOperation.OptionsEntryR\x07options\x12-\n\x12\x63lustering_columns\x18\n \x03(\tR\x11\x63lusteringColumns\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x82\x02\n\tSaveTable\x12\x1d\n\ntable_name\x18\x01 \x01(\tR\ttableName\x12X\n\x0bsave_method\x18\x02 \x01(\x0e\x32\x37.spark.connect.WriteOperation.SaveTable.TableSaveMethodR\nsaveMethod"|\n\x0fTableSaveMethod\x12!\n\x1dTABLE_SAVE_METHOD_UNSPECIFIED\x10\x00\x12#\n\x1fTABLE_SAVE_METHOD_SAVE_AS_TABLE\x10\x01\x12!\n\x1dTABLE_SAVE_METHOD_INSERT_INTO\x10\x02\x1a[\n\x08\x42ucketBy\x12.\n\x13\x62ucket_column_names\x18\x01 \x03(\tR\x11\x62ucketColumnNames\x12\x1f\n\x0bnum_buckets\x18\x02 \x01(\x05R\nnumBuckets"\x89\x01\n\x08SaveMode\x12\x19\n\x15SAVE_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10SAVE_MODE_APPEND\x10\x01\x12\x17\n\x13SAVE_MODE_OVERWRITE\x10\x02\x12\x1d\n\x19SAVE_MODE_ERROR_IF_EXISTS\x10\x03\x12\x14\n\x10SAVE_MODE_IGNORE\x10\x04\x42\x0b\n\tsave_typeB\t\n\x07_source"\xdc\x06\n\x10WriteOperationV2\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\ntable_name\x18\x02 \x01(\tR\ttableName\x12\x1f\n\x08provider\x18\x03 \x01(\tH\x00R\x08provider\x88\x01\x01\x12L\n\x14partitioning_columns\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13partitioningColumns\x12\x46\n\x07options\x18\x05 \x03(\x0b\x32,.spark.connect.WriteOperationV2.OptionsEntryR\x07options\x12_\n\x10table_properties\x18\x06 \x03(\x0b\x32\x34.spark.connect.WriteOperationV2.TablePropertiesEntryR\x0ftableProperties\x12\x38\n\x04mode\x18\x07 \x01(\x0e\x32$.spark.connect.WriteOperationV2.ModeR\x04mode\x12J\n\x13overwrite_condition\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x12overwriteCondition\x12-\n\x12\x63lustering_columns\x18\t \x03(\tR\x11\x63lusteringColumns\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 
\x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\x9f\x01\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0f\n\x0bMODE_CREATE\x10\x01\x12\x12\n\x0eMODE_OVERWRITE\x10\x02\x12\x1d\n\x19MODE_OVERWRITE_PARTITIONS\x10\x03\x12\x0f\n\x0bMODE_APPEND\x10\x04\x12\x10\n\x0cMODE_REPLACE\x10\x05\x12\x1a\n\x16MODE_CREATE_OR_REPLACE\x10\x06\x42\x0b\n\t_provider"\xa0\x06\n\x19WriteStreamOperationStart\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06\x66ormat\x18\x02 \x01(\tR\x06\x66ormat\x12O\n\x07options\x18\x03 \x03(\x0b\x32\x35.spark.connect.WriteStreamOperationStart.OptionsEntryR\x07options\x12:\n\x19partitioning_column_names\x18\x04 \x03(\tR\x17partitioningColumnNames\x12:\n\x18processing_time_interval\x18\x05 \x01(\tH\x00R\x16processingTimeInterval\x12%\n\ravailable_now\x18\x06 \x01(\x08H\x00R\x0c\x61vailableNow\x12\x14\n\x04once\x18\x07 \x01(\x08H\x00R\x04once\x12\x46\n\x1e\x63ontinuous_checkpoint_interval\x18\x08 \x01(\tH\x00R\x1c\x63ontinuousCheckpointInterval\x12\x1f\n\x0boutput_mode\x18\t \x01(\tR\noutputMode\x12\x1d\n\nquery_name\x18\n \x01(\tR\tqueryName\x12\x14\n\x04path\x18\x0b \x01(\tH\x01R\x04path\x12\x1f\n\ntable_name\x18\x0c \x01(\tH\x01R\ttableName\x12N\n\x0e\x66oreach_writer\x18\r \x01(\x0b\x32\'.spark.connect.StreamingForeachFunctionR\rforeachWriter\x12L\n\rforeach_batch\x18\x0e \x01(\x0b\x32\'.spark.connect.StreamingForeachFunctionR\x0c\x66oreachBatch\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07triggerB\x12\n\x10sink_destination"\xb3\x01\n\x18StreamingForeachFunction\x12\x43\n\x0fpython_function\x18\x01 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\x0epythonFunction\x12\x46\n\x0escala_function\x18\x02 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\rscalaFunctionB\n\n\x08\x66unction"\xd4\x01\n\x1fWriteStreamOperationStartResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12<\n\x18query_started_event_json\x18\x03 \x01(\tH\x00R\x15queryStartedEventJson\x88\x01\x01\x42\x1b\n\x19_query_started_event_json"A\n\x18StreamingQueryInstanceId\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x15\n\x06run_id\x18\x02 \x01(\tR\x05runId"\xf8\x04\n\x15StreamingQueryCommand\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x18\n\x06status\x18\x02 \x01(\x08H\x00R\x06status\x12%\n\rlast_progress\x18\x03 \x01(\x08H\x00R\x0clastProgress\x12)\n\x0frecent_progress\x18\x04 \x01(\x08H\x00R\x0erecentProgress\x12\x14\n\x04stop\x18\x05 \x01(\x08H\x00R\x04stop\x12\x34\n\x15process_all_available\x18\x06 \x01(\x08H\x00R\x13processAllAvailable\x12O\n\x07\x65xplain\x18\x07 \x01(\x0b\x32\x33.spark.connect.StreamingQueryCommand.ExplainCommandH\x00R\x07\x65xplain\x12\x1e\n\texception\x18\x08 \x01(\x08H\x00R\texception\x12k\n\x11\x61wait_termination\x18\t \x01(\x0b\x32<.spark.connect.StreamingQueryCommand.AwaitTerminationCommandH\x00R\x10\x61waitTermination\x1a,\n\x0e\x45xplainCommand\x12\x1a\n\x08\x65xtended\x18\x01 \x01(\x08R\x08\x65xtended\x1aL\n\x17\x41waitTerminationCommand\x12"\n\ntimeout_ms\x18\x02 \x01(\x03H\x00R\ttimeoutMs\x88\x01\x01\x42\r\n\x0b_timeout_msB\t\n\x07\x63ommand"\xf5\x08\n\x1bStreamingQueryCommandResult\x12\x42\n\x08query_id\x18\x01 
\x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12Q\n\x06status\x18\x02 \x01(\x0b\x32\x37.spark.connect.StreamingQueryCommandResult.StatusResultH\x00R\x06status\x12j\n\x0frecent_progress\x18\x03 \x01(\x0b\x32?.spark.connect.StreamingQueryCommandResult.RecentProgressResultH\x00R\x0erecentProgress\x12T\n\x07\x65xplain\x18\x04 \x01(\x0b\x32\x38.spark.connect.StreamingQueryCommandResult.ExplainResultH\x00R\x07\x65xplain\x12Z\n\texception\x18\x05 \x01(\x0b\x32:.spark.connect.StreamingQueryCommandResult.ExceptionResultH\x00R\texception\x12p\n\x11\x61wait_termination\x18\x06 \x01(\x0b\x32\x41.spark.connect.StreamingQueryCommandResult.AwaitTerminationResultH\x00R\x10\x61waitTermination\x1a\xaa\x01\n\x0cStatusResult\x12%\n\x0estatus_message\x18\x01 \x01(\tR\rstatusMessage\x12*\n\x11is_data_available\x18\x02 \x01(\x08R\x0fisDataAvailable\x12*\n\x11is_trigger_active\x18\x03 \x01(\x08R\x0fisTriggerActive\x12\x1b\n\tis_active\x18\x04 \x01(\x08R\x08isActive\x1aH\n\x14RecentProgressResult\x12\x30\n\x14recent_progress_json\x18\x05 \x03(\tR\x12recentProgressJson\x1a\'\n\rExplainResult\x12\x16\n\x06result\x18\x01 \x01(\tR\x06result\x1a\xc5\x01\n\x0f\x45xceptionResult\x12\x30\n\x11\x65xception_message\x18\x01 \x01(\tH\x00R\x10\x65xceptionMessage\x88\x01\x01\x12$\n\x0b\x65rror_class\x18\x02 \x01(\tH\x01R\nerrorClass\x88\x01\x01\x12$\n\x0bstack_trace\x18\x03 \x01(\tH\x02R\nstackTrace\x88\x01\x01\x42\x14\n\x12_exception_messageB\x0e\n\x0c_error_classB\x0e\n\x0c_stack_trace\x1a\x38\n\x16\x41waitTerminationResult\x12\x1e\n\nterminated\x18\x01 \x01(\x08R\nterminatedB\r\n\x0bresult_type"\xbd\x06\n\x1cStreamingQueryManagerCommand\x12\x18\n\x06\x61\x63tive\x18\x01 \x01(\x08H\x00R\x06\x61\x63tive\x12\x1d\n\tget_query\x18\x02 \x01(\tH\x00R\x08getQuery\x12|\n\x15\x61wait_any_termination\x18\x03 \x01(\x0b\x32\x46.spark.connect.StreamingQueryManagerCommand.AwaitAnyTerminationCommandH\x00R\x13\x61waitAnyTermination\x12+\n\x10reset_terminated\x18\x04 \x01(\x08H\x00R\x0fresetTerminated\x12n\n\x0c\x61\x64\x64_listener\x18\x05 \x01(\x0b\x32I.spark.connect.StreamingQueryManagerCommand.StreamingQueryListenerCommandH\x00R\x0b\x61\x64\x64Listener\x12t\n\x0fremove_listener\x18\x06 \x01(\x0b\x32I.spark.connect.StreamingQueryManagerCommand.StreamingQueryListenerCommandH\x00R\x0eremoveListener\x12\'\n\x0elist_listeners\x18\x07 \x01(\x08H\x00R\rlistListeners\x1aO\n\x1a\x41waitAnyTerminationCommand\x12"\n\ntimeout_ms\x18\x01 \x01(\x03H\x00R\ttimeoutMs\x88\x01\x01\x42\r\n\x0b_timeout_ms\x1a\xcd\x01\n\x1dStreamingQueryListenerCommand\x12)\n\x10listener_payload\x18\x01 \x01(\x0cR\x0flistenerPayload\x12U\n\x17python_listener_payload\x18\x02 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\x15pythonListenerPayload\x88\x01\x01\x12\x0e\n\x02id\x18\x03 \x01(\tR\x02idB\x1a\n\x18_python_listener_payloadB\t\n\x07\x63ommand"\xb4\x08\n"StreamingQueryManagerCommandResult\x12X\n\x06\x61\x63tive\x18\x01 \x01(\x0b\x32>.spark.connect.StreamingQueryManagerCommandResult.ActiveResultH\x00R\x06\x61\x63tive\x12`\n\x05query\x18\x02 \x01(\x0b\x32H.spark.connect.StreamingQueryManagerCommandResult.StreamingQueryInstanceH\x00R\x05query\x12\x81\x01\n\x15\x61wait_any_termination\x18\x03 \x01(\x0b\x32K.spark.connect.StreamingQueryManagerCommandResult.AwaitAnyTerminationResultH\x00R\x13\x61waitAnyTermination\x12+\n\x10reset_terminated\x18\x04 \x01(\x08H\x00R\x0fresetTerminated\x12#\n\x0c\x61\x64\x64_listener\x18\x05 \x01(\x08H\x00R\x0b\x61\x64\x64Listener\x12)\n\x0fremove_listener\x18\x06 
\x01(\x08H\x00R\x0eremoveListener\x12{\n\x0elist_listeners\x18\x07 \x01(\x0b\x32R.spark.connect.StreamingQueryManagerCommandResult.ListStreamingQueryListenerResultH\x00R\rlistListeners\x1a\x7f\n\x0c\x41\x63tiveResult\x12o\n\x0e\x61\x63tive_queries\x18\x01 \x03(\x0b\x32H.spark.connect.StreamingQueryManagerCommandResult.StreamingQueryInstanceR\ractiveQueries\x1as\n\x16StreamingQueryInstance\x12\x37\n\x02id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x02id\x12\x17\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x88\x01\x01\x42\x07\n\x05_name\x1a;\n\x19\x41waitAnyTerminationResult\x12\x1e\n\nterminated\x18\x01 \x01(\x08R\nterminated\x1aK\n\x1eStreamingQueryListenerInstance\x12)\n\x10listener_payload\x18\x01 \x01(\x0cR\x0flistenerPayload\x1a\x45\n ListStreamingQueryListenerResult\x12!\n\x0clistener_ids\x18\x01 \x03(\tR\x0blistenerIdsB\r\n\x0bresult_type"\xad\x01\n StreamingQueryListenerBusCommand\x12;\n\x19\x61\x64\x64_listener_bus_listener\x18\x01 \x01(\x08H\x00R\x16\x61\x64\x64ListenerBusListener\x12\x41\n\x1cremove_listener_bus_listener\x18\x02 \x01(\x08H\x00R\x19removeListenerBusListenerB\t\n\x07\x63ommand"\x83\x01\n\x1bStreamingQueryListenerEvent\x12\x1d\n\nevent_json\x18\x01 \x01(\tR\teventJson\x12\x45\n\nevent_type\x18\x02 \x01(\x0e\x32&.spark.connect.StreamingQueryEventTypeR\teventType"\xcc\x01\n"StreamingQueryListenerEventsResult\x12\x42\n\x06\x65vents\x18\x01 \x03(\x0b\x32*.spark.connect.StreamingQueryListenerEventR\x06\x65vents\x12\x42\n\x1blistener_bus_listener_added\x18\x02 \x01(\x08H\x00R\x18listenerBusListenerAdded\x88\x01\x01\x42\x1e\n\x1c_listener_bus_listener_added"\x15\n\x13GetResourcesCommand"\xd4\x01\n\x19GetResourcesCommandResult\x12U\n\tresources\x18\x01 \x03(\x0b\x32\x37.spark.connect.GetResourcesCommandResult.ResourcesEntryR\tresources\x1a`\n\x0eResourcesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32".spark.connect.ResourceInformationR\x05value:\x02\x38\x01"X\n\x1c\x43reateResourceProfileCommand\x12\x38\n\x07profile\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ResourceProfileR\x07profile"C\n"CreateResourceProfileCommandResult\x12\x1d\n\nprofile_id\x18\x01 \x01(\x05R\tprofileId"d\n!RemoveCachedRemoteRelationCommand\x12?\n\x08relation\x18\x01 \x01(\x0b\x32#.spark.connect.CachedRemoteRelationR\x08relation"t\n\x11\x43heckpointCommand\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x14\n\x05local\x18\x02 \x01(\x08R\x05local\x12\x14\n\x05\x65\x61ger\x18\x03 \x01(\x08R\x05\x65\x61ger*\x85\x01\n\x17StreamingQueryEventType\x12\x1e\n\x1aQUERY_PROGRESS_UNSPECIFIED\x10\x00\x12\x18\n\x14QUERY_PROGRESS_EVENT\x10\x01\x12\x1a\n\x16QUERY_TERMINATED_EVENT\x10\x02\x12\x14\n\x10QUERY_IDLE_EVENT\x10\x03\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -71,8 +71,8 @@ _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_options = b"8\001" _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._options = None _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_options = b"8\001" - _STREAMINGQUERYEVENTTYPE._serialized_start = 10549 - _STREAMINGQUERYEVENTTYPE._serialized_end = 10682 + _STREAMINGQUERYEVENTTYPE._serialized_start = 10518 + _STREAMINGQUERYEVENTTYPE._serialized_end = 10651 _COMMAND._serialized_start = 167 _COMMAND._serialized_end = 1750 _SQLCOMMAND._serialized_start = 1753 @@ -167,6 +167,6 @@ _CREATERESOURCEPROFILECOMMANDRESULT._serialized_end = 10295 _REMOVECACHEDREMOTERELATIONCOMMAND._serialized_start = 
10297 _REMOVECACHEDREMOTERELATIONCOMMAND._serialized_end = 10397 - _CHECKPOINTCOMMAND._serialized_start = 10400 - _CHECKPOINTCOMMAND._serialized_end = 10546 + _CHECKPOINTCOMMAND._serialized_start = 10399 + _CHECKPOINTCOMMAND._serialized_end = 10515 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi index 61691abbdd855..04d50d5b5e4f4 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.pyi +++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi @@ -2174,55 +2174,26 @@ class CheckpointCommand(google.protobuf.message.Message): def relation(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: """(Required) The logical plan to checkpoint.""" local: builtins.bool - """(Optional) Locally checkpoint using a local temporary + """(Required) Locally checkpoint using a local temporary directory in Spark Connect server (Spark Driver) """ eager: builtins.bool - """(Optional) Whether to checkpoint this dataframe immediately.""" + """(Required) Whether to checkpoint this dataframe immediately.""" def __init__( self, *, relation: pyspark.sql.connect.proto.relations_pb2.Relation | None = ..., - local: builtins.bool | None = ..., - eager: builtins.bool | None = ..., + local: builtins.bool = ..., + eager: builtins.bool = ..., ) -> None: ... def HasField( - self, - field_name: typing_extensions.Literal[ - "_eager", - b"_eager", - "_local", - b"_local", - "eager", - b"eager", - "local", - b"local", - "relation", - b"relation", - ], + self, field_name: typing_extensions.Literal["relation", b"relation"] ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "_eager", - b"_eager", - "_local", - b"_local", - "eager", - b"eager", - "local", - b"local", - "relation", - b"relation", + "eager", b"eager", "local", b"local", "relation", b"relation" ], ) -> None: ... - @typing.overload - def WhichOneof( - self, oneof_group: typing_extensions.Literal["_eager", b"_eager"] - ) -> typing_extensions.Literal["eager"] | None: ... - @typing.overload - def WhichOneof( - self, oneof_group: typing_extensions.Literal["_local", b"_local"] - ) -> typing_extensions.Literal["local"] | None: ... global___CheckpointCommand = CheckpointCommand diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3e843e64ebbf6..c7511737b2b3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -754,7 +754,7 @@ class Dataset[T] private[sql]( * checkpoint directory. If false creates a local checkpoint using * the caching subsystem */ - private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = { + private[sql] def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = { val actionName = if (reliableCheckpoint) "checkpoint" else "localCheckpoint" withAction(actionName, queryExecution) { physicalPlan => val internalRdd = physicalPlan.execute().map(_.copy()) From 14d3f447360b66663c8979a8cdb4c40c480a1e04 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 23 May 2024 16:12:38 +0800 Subject: [PATCH 07/45] [SPARK-48395][PYTHON] Fix `StructType.treeString` for parameterized types ### What changes were proposed in this pull request? this PR is a follow up of https://github.com/apache/spark/pull/46685. ### Why are the changes needed? 
`StructType.treeString` uses `DataType.typeName` to generate the tree string; however, `typeName` in Python is a class method and cannot return the same string as the Scala side for parameterized types.

```
In [2]: schema = StructType().add("c", CharType(10), True).add("v", VarcharType(10), True).add("d", DecimalType(10, 2), True).add("ym00", YearM
   ...: onthIntervalType(0, 0)).add("ym01", YearMonthIntervalType(0, 1)).add("ym11", YearMonthIntervalType(1, 1))

In [3]: print(schema.treeString())
root
 |-- c: char (nullable = true)
 |-- v: varchar (nullable = true)
 |-- d: decimal (nullable = true)
 |-- ym00: yearmonthinterval (nullable = true)
 |-- ym01: yearmonthinterval (nullable = true)
 |-- ym11: yearmonthinterval (nullable = true)
```

It should be:

```
In [4]: print(schema.treeString())
root
 |-- c: char(10) (nullable = true)
 |-- v: varchar(10) (nullable = true)
 |-- d: decimal(10,2) (nullable = true)
 |-- ym00: interval year (nullable = true)
 |-- ym01: interval year to month (nullable = true)
 |-- ym11: interval month (nullable = true)
```

### Does this PR introduce _any_ user-facing change?

No, this feature was just added and has not been released yet.

### How was this patch tested?

Added tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46711 from zhengruifeng/tree_string_fix.

Authored-by: Ruifeng Zheng
Signed-off-by: Ruifeng Zheng
---
 python/pyspark/sql/tests/test_types.py | 67 ++++++++++++++++++++++++++
 python/pyspark/sql/types.py | 27 +++++++++--
 2 files changed, 90 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index ec07406b11912..6c64a9471363a 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -41,6 +41,7 @@
     FloatType,
     DateType,
     TimestampType,
+    TimestampNTZType,
     DayTimeIntervalType,
     YearMonthIntervalType,
     CalendarIntervalType,
@@ -1411,6 +1412,72 @@ def test_tree_string(self):
             ],
         )

+    def test_tree_string_for_builtin_types(self):
+        schema = (
+            StructType()
+            .add("n", NullType())
+            .add("str", StringType())
+            .add("c", CharType(10))
+            .add("v", VarcharType(10))
+            .add("bin", BinaryType())
+            .add("bool", BooleanType())
+            .add("date", DateType())
+            .add("ts", TimestampType())
+            .add("ts_ntz", TimestampNTZType())
+            .add("dec", DecimalType(10, 2))
+            .add("double", DoubleType())
+            .add("float", FloatType())
+            .add("long", LongType())
+            .add("int", IntegerType())
+            .add("short", ShortType())
+            .add("byte", ByteType())
+            .add("ym_interval_1", YearMonthIntervalType())
+            .add("ym_interval_2", YearMonthIntervalType(YearMonthIntervalType.YEAR))
+            .add(
+                "ym_interval_3",
+                YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH),
+            )
+            .add("dt_interval_1", DayTimeIntervalType())
+            .add("dt_interval_2", DayTimeIntervalType(DayTimeIntervalType.DAY))
+            .add(
+                "dt_interval_3",
+                DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND),
+            )
+            .add("cal_interval", CalendarIntervalType())
+            .add("var", VariantType())
+        )
+        self.assertEqual(
+            schema.treeString().split("\n"),
+            [
+                "root",
+                " |-- n: void (nullable = true)",
+                " |-- str: string (nullable = true)",
+                " |-- c: char(10) (nullable = true)",
+                " |-- v: varchar(10) (nullable = true)",
+                " |-- bin: binary (nullable = true)",
+                " |-- bool: boolean (nullable = true)",
+                " |-- date: date (nullable = true)",
+                " |-- ts: timestamp (nullable = true)",
+                " |-- ts_ntz: timestamp_ntz (nullable = true)",
+                " |-- dec: decimal(10,2) (nullable = true)",
+                " |-- double: double 
(nullable = true)", + " |-- float: float (nullable = true)", + " |-- long: long (nullable = true)", + " |-- int: integer (nullable = true)", + " |-- short: short (nullable = true)", + " |-- byte: byte (nullable = true)", + " |-- ym_interval_1: interval year to month (nullable = true)", + " |-- ym_interval_2: interval year (nullable = true)", + " |-- ym_interval_3: interval year to month (nullable = true)", + " |-- dt_interval_1: interval day to second (nullable = true)", + " |-- dt_interval_2: interval day (nullable = true)", + " |-- dt_interval_3: interval hour to second (nullable = true)", + " |-- cal_interval: interval (nullable = true)", + " |-- var: variant (nullable = true)", + "", + ], + ) + def test_metadata_null(self): schema = StructType( [ diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ee0cc9db5c445..17b019240f826 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -215,6 +215,24 @@ def _data_type_build_formatted_string( if isinstance(dataType, (ArrayType, StructType, MapType)): dataType._build_formatted_string(prefix, stringConcat, maxDepth - 1) + # The method typeName() is not always the same as the Scala side. + # Add this helper method to make TreeString() compatible with Scala side. + @classmethod + def _get_jvm_type_name(cls, dataType: "DataType") -> str: + if isinstance( + dataType, + ( + DecimalType, + CharType, + VarcharType, + DayTimeIntervalType, + YearMonthIntervalType, + ), + ): + return dataType.simpleString() + else: + return dataType.typeName() + # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle @@ -758,7 +776,7 @@ def _build_formatted_string( ) -> None: if maxDepth > 0: stringConcat.append( - f"{prefix}-- element: {self.elementType.typeName()} " + f"{prefix}-- element: {DataType._get_jvm_type_name(self.elementType)} " + f"(containsNull = {str(self.containsNull).lower()})\n" ) DataType._data_type_build_formatted_string( @@ -906,12 +924,12 @@ def _build_formatted_string( maxDepth: int = JVM_INT_MAX, ) -> None: if maxDepth > 0: - stringConcat.append(f"{prefix}-- key: {self.keyType.typeName()}\n") + stringConcat.append(f"{prefix}-- key: {DataType._get_jvm_type_name(self.keyType)}\n") DataType._data_type_build_formatted_string( self.keyType, f"{prefix} |", stringConcat, maxDepth ) stringConcat.append( - f"{prefix}-- value: {self.valueType.typeName()} " + f"{prefix}-- value: {DataType._get_jvm_type_name(self.valueType)} " + f"(valueContainsNull = {str(self.valueContainsNull).lower()})\n" ) DataType._data_type_build_formatted_string( @@ -1074,7 +1092,8 @@ def _build_formatted_string( ) -> None: if maxDepth > 0: stringConcat.append( - f"{prefix}-- {escape_meta_characters(self.name)}: {self.dataType.typeName()} " + f"{prefix}-- {escape_meta_characters(self.name)}: " + + f"{DataType._get_jvm_type_name(self.dataType)} " + f"(nullable = {str(self.nullable).lower()})\n" ) DataType._data_type_build_formatted_string( From 4a471cceebedd938f781eb385162d33058124092 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 23 May 2024 19:46:46 +0800 Subject: [PATCH 08/45] [MINOR][TESTS] Add a helper function for `spark.table` in dsl ### What changes were proposed in this pull request? Add a helper function for `spark.table` in dsl ### Why are the changes needed? to be used in tests ### Does this PR introduce _any_ user-facing change? no, test only ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? 
no Closes #46717 from zhengruifeng/dsl_read. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- .../apache/spark/sql/connect/dsl/package.scala | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/dsl/package.scala index a94bbf9c8f244..3edb63ee8e815 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -332,6 +332,21 @@ package object dsl { def sql(sqlText: String): Relation = { Relation.newBuilder().setSql(SQL.newBuilder().setQuery(sqlText)).build() } + + def table(name: String): Relation = { + proto.Relation + .newBuilder() + .setRead( + proto.Read + .newBuilder() + .setNamedTable( + proto.Read.NamedTable + .newBuilder() + .setUnparsedIdentifier(name) + .build()) + .build()) + .build() + } } implicit class DslNAFunctions(val logicalPlan: Relation) { From 2516fd8439df42b1c161fbd346a0c346cc075f0f Mon Sep 17 00:00:00 2001 From: Andy Lam Date: Thu, 23 May 2024 15:46:09 -0700 Subject: [PATCH 09/45] [SPARK-45009][SQL][FOLLOW UP] Add error class and tests for decorrelation of predicate subqueries in join condition which reference both join child ### What changes were proposed in this pull request? This is a follow up PR for https://github.com/apache/spark/pull/42725, which decorrelates predicate subqueries in join conditions. I forgot to add the error class definition for the case where the subquery references both join children, and test cases for it. ### Why are the changes needed? To show a clear error message when the condition is hit. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added SQL test and golden files. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46708 from andylam-db/follow-up-decorrelate-subqueries-in-join-cond. Authored-by: Andy Lam Signed-off-by: Gengliang Wang --- .../resources/error/error-conditions.json | 6 +++ .../exists-in-join-condition.sql.out | 44 +++++++++++++++++++ .../in-subquery-in-join-condition.sql.out | 44 +++++++++++++++++++ .../exists-in-join-condition.sql | 4 ++ .../in-subquery-in-join-condition.sql | 4 ++ .../exists-in-join-condition.sql.out | 30 +++++++++++++ .../in-subquery-in-join-condition.sql.out | 30 +++++++++++++ 7 files changed, 162 insertions(+) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index c1c0cd6bfb39e..883c51bffadec 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -4731,6 +4731,12 @@ "" ] }, + "UNSUPPORTED_CORRELATED_EXPRESSION_IN_JOIN_CONDITION" : { + "message" : [ + "Correlated subqueries in the join predicate cannot reference both join inputs:", + "" + ] + }, "UNSUPPORTED_CORRELATED_REFERENCE_DATA_TYPE" : { "message" : [ "Correlated column reference '' cannot be type." 
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/exists-subquery/exists-in-join-condition.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/exists-subquery/exists-in-join-condition.sql.out index 1b09e8798a325..3b55a7293bcfa 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/exists-subquery/exists-in-join-condition.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/exists-subquery/exists-in-join-condition.sql.out @@ -1004,3 +1004,47 @@ Sort [x1#x ASC NULLS FIRST, x2#x ASC NULLS FIRST, y1#x ASC NULLS FIRST, y2#x ASC +- View (`y`, [y1#x, y2#x]) +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +- LocalRelation [col1#x, col2#x] + + +-- !query +select * from x join y on x1 = y1 and exists (select * from z where z2 = x2 AND z2 = y2) order by x1, x2, y1, y2 +-- !query analysis +Sort [x1#x ASC NULLS FIRST, x2#x ASC NULLS FIRST, y1#x ASC NULLS FIRST, y2#x ASC NULLS FIRST], true ++- Project [x1#x, x2#x, y1#x, y2#x] + +- Join Inner, ((x1#x = y1#x) AND exists#x [x2#x && y2#x]) + : +- Project [z1#x, z2#x] + : +- Filter ((z2#x = outer(x2#x)) AND (z2#x = outer(y2#x))) + : +- SubqueryAlias z + : +- View (`z`, [z1#x, z2#x]) + : +- Project [cast(col1#x as int) AS z1#x, cast(col2#x as int) AS z2#x] + : +- LocalRelation [col1#x, col2#x] + :- SubqueryAlias x + : +- View (`x`, [x1#x, x2#x]) + : +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias y + +- View (`y`, [y1#x, y2#x]) + +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select * from x join y on x1 = y1 and not exists (select * from z where z2 = x2 AND z2 = y2) order by x1, x2, y1, y2 +-- !query analysis +Sort [x1#x ASC NULLS FIRST, x2#x ASC NULLS FIRST, y1#x ASC NULLS FIRST, y2#x ASC NULLS FIRST], true ++- Project [x1#x, x2#x, y1#x, y2#x] + +- Join Inner, ((x1#x = y1#x) AND NOT exists#x [x2#x && y2#x]) + : +- Project [z1#x, z2#x] + : +- Filter ((z2#x = outer(x2#x)) AND (z2#x = outer(y2#x))) + : +- SubqueryAlias z + : +- View (`z`, [z1#x, z2#x]) + : +- Project [cast(col1#x as int) AS z1#x, cast(col2#x as int) AS z2#x] + : +- LocalRelation [col1#x, col2#x] + :- SubqueryAlias x + : +- View (`x`, [x1#x, x2#x]) + : +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias y + +- View (`y`, [y1#x, y2#x]) + +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] + +- LocalRelation [col1#x, col2#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/in-subquery/in-subquery-in-join-condition.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/in-subquery/in-subquery-in-join-condition.sql.out index 422ac4d5c2cbf..ce6a1a3d7ed53 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/in-subquery/in-subquery-in-join-condition.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/in-subquery/in-subquery-in-join-condition.sql.out @@ -916,3 +916,47 @@ Sort [x1#x ASC NULLS FIRST, x2#x ASC NULLS FIRST, y1#x ASC NULLS FIRST, y2#x ASC +- View (`y`, [y1#x, y2#x]) +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +- LocalRelation [col1#x, col2#x] + + +-- !query +select * from x left join y on x1 = y1 and x2 IN (select z1 from z where z2 = x2 AND z2 = y2) order by x1, x2, y1, y2 +-- !query analysis +Sort [x1#x ASC 
NULLS FIRST, x2#x ASC NULLS FIRST, y1#x ASC NULLS FIRST, y2#x ASC NULLS FIRST], true ++- Project [x1#x, x2#x, y1#x, y2#x] + +- Join LeftOuter, ((x1#x = y1#x) AND x2#x IN (list#x [x2#x && y2#x])) + : +- Project [z1#x] + : +- Filter ((z2#x = outer(x2#x)) AND (z2#x = outer(y2#x))) + : +- SubqueryAlias z + : +- View (`z`, [z1#x, z2#x]) + : +- Project [cast(col1#x as int) AS z1#x, cast(col2#x as int) AS z2#x] + : +- LocalRelation [col1#x, col2#x] + :- SubqueryAlias x + : +- View (`x`, [x1#x, x2#x]) + : +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias y + +- View (`y`, [y1#x, y2#x]) + +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select * from x left join y on x1 = y1 and x2 not IN (select z1 from z where z2 = x2 AND z2 = y2) order by x1, x2, y1, y2 +-- !query analysis +Sort [x1#x ASC NULLS FIRST, x2#x ASC NULLS FIRST, y1#x ASC NULLS FIRST, y2#x ASC NULLS FIRST], true ++- Project [x1#x, x2#x, y1#x, y2#x] + +- Join LeftOuter, ((x1#x = y1#x) AND NOT x2#x IN (list#x [x2#x && y2#x])) + : +- Project [z1#x] + : +- Filter ((z2#x = outer(x2#x)) AND (z2#x = outer(y2#x))) + : +- SubqueryAlias z + : +- View (`z`, [z1#x, z2#x]) + : +- Project [cast(col1#x as int) AS z1#x, cast(col2#x as int) AS z2#x] + : +- LocalRelation [col1#x, col2#x] + :- SubqueryAlias x + : +- View (`x`, [x1#x, x2#x]) + : +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias y + +- View (`y`, [y1#x, y2#x]) + +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] + +- LocalRelation [col1#x, col2#x] diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-in-join-condition.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-in-join-condition.sql index ad2e7ad563e08..bc732cc3d320d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-in-join-condition.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-in-join-condition.sql @@ -89,3 +89,7 @@ select * from x inner join y on x1 = y1 and exists (select * from z where z1 = y select * from x inner join y on x1 = y1 and not exists (select * from z where z1 = y1) order by x1, x2, y1, y2; select * from x left join y on x1 = y1 and exists (select * from z where z1 = y1) order by x1, x2, y1, y2; select * from x left join y on x1 = y1 and not exists (select * from z where z1 = y1) order by x1, x2, y1, y2; + +-- Correlated subquery references both left and right children, errors +select * from x join y on x1 = y1 and exists (select * from z where z2 = x2 AND z2 = y2) order by x1, x2, y1, y2; +select * from x join y on x1 = y1 and not exists (select * from z where z2 = x2 AND z2 = y2) order by x1, x2, y1, y2; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-subquery-in-join-condition.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-subquery-in-join-condition.sql index d519abdbacc05..c906390c99c32 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-subquery-in-join-condition.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-subquery-in-join-condition.sql @@ -84,3 +84,7 @@ select * from x inner join y on x1 = y1 and y2 IN (select z1 from z where z1 = y select * from x inner join y on x1 = y1 and y2 not IN (select z1 
from z where z1 = y1) order by x1, x2, y1, y2; select * from x left join y on x1 = y1 and y2 IN (select z1 from z where z1 = y1) order by x1, x2, y1, y2; select * from x left join y on x1 = y1 and y2 not IN (select z1 from z where z1 = y1) order by x1, x2, y1, y2; + +-- Correlated subquery references both left and right children, errors +select * from x left join y on x1 = y1 and x2 IN (select z1 from z where z2 = x2 AND z2 = y2) order by x1, x2, y1, y2; +select * from x left join y on x1 = y1 and x2 not IN (select z1 from z where z2 = x2 AND z2 = y2) order by x1, x2, y1, y2; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/exists-subquery/exists-in-join-condition.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/exists-subquery/exists-in-join-condition.sql.out index b490704bebc57..c9c68a5f0602b 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/exists-subquery/exists-in-join-condition.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/exists-subquery/exists-in-join-condition.sql.out @@ -472,3 +472,33 @@ struct 1 1 1 4 2 1 NULL NULL 3 4 NULL NULL + + +-- !query +select * from x join y on x1 = y1 and exists (select * from z where z2 = x2 AND z2 = y2) order by x1, x2, y1, y2 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_CORRELATED_EXPRESSION_IN_JOIN_CONDITION", + "sqlState" : "0A000", + "messageParameters" : { + "subqueryExpression" : "exists(x.x2, y.y2, (z.z2 = x.x2), (z.z2 = y.y2))" + } +} + + +-- !query +select * from x join y on x1 = y1 and not exists (select * from z where z2 = x2 AND z2 = y2) order by x1, x2, y1, y2 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_CORRELATED_EXPRESSION_IN_JOIN_CONDITION", + "sqlState" : "0A000", + "messageParameters" : { + "subqueryExpression" : "exists(x.x2, y.y2, (z.z2 = x.x2), (z.z2 = y.y2))" + } +} diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-subquery-in-join-condition.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-subquery-in-join-condition.sql.out index 9f829d522ad25..13af4c81173ae 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-subquery-in-join-condition.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-subquery-in-join-condition.sql.out @@ -434,3 +434,33 @@ struct 1 1 1 4 2 1 NULL NULL 3 4 NULL NULL + + +-- !query +select * from x left join y on x1 = y1 and x2 IN (select z1 from z where z2 = x2 AND z2 = y2) order by x1, x2, y1, y2 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_CORRELATED_EXPRESSION_IN_JOIN_CONDITION", + "sqlState" : "0A000", + "messageParameters" : { + "subqueryExpression" : "(x.x2 IN (listquery(x.x2, y.y2, (z.z2 = x.x2), (z.z2 = y.y2))))" + } +} + + +-- !query +select * from x left join y on x1 = y1 and x2 not IN (select z1 from z where z2 = x2 AND z2 = y2) order by x1, x2, y1, y2 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_CORRELATED_EXPRESSION_IN_JOIN_CONDITION", + "sqlState" : "0A000", + "messageParameters" : { + "subqueryExpression" : "(x.x2 IN (listquery(x.x2, y.y2, (z.z2 
= x.x2), (z.z2 = y.y2))))" + } +} From 6afa6cc3c16e21f94087ebb6adb01bd1ff397086 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 24 May 2024 10:13:49 +0800 Subject: [PATCH 10/45] [SPARK-48399][SQL] Teradata: ByteType should map to BYTEINT instead of BYTE(binary) ### What changes were proposed in this pull request? According to the doc of Teradata and Teradata jdbc, BYTE represents binary type in Teradata, while BYTEINT is used for tinyint. - https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/SQL-Data-Types-and-Literals/Numeric-Data-Types/BYTEINT-Data-Type - https://teradata-docs.s3.amazonaws.com/doc/connectivity/jdbc/reference/current/frameset.html ### Why are the changes needed? Bugfix ### Does this PR introduce _any_ user-facing change? Yes, ByteType used to be stored as binary type in Teradata, now it has become BYTEINT. (The use-case seems rare, the migration guide or legacy config are pending reviewer's comments) ### How was this patch tested? new tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #46715 from yaooqinn/SPARK-48399. Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../apache/spark/sql/jdbc/TeradataDialect.scala | 1 + .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 15 +++++---------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 7acd22a3f10be..95a9f60b64ed8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -42,6 +42,7 @@ private case class TeradataDialect() extends JdbcDialect { override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR)) case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) + case ByteType => Option(JdbcType("BYTEINT", java.sql.Types.TINYINT)) case _ => None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 0a792f44d3e22..e4116b565818e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1477,16 +1477,11 @@ class JDBCSuite extends QueryTest with SharedSparkSession { } } - test("SPARK-15648: teradataDialect StringType data mapping") { - val teradataDialect = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") - assert(teradataDialect.getJDBCType(StringType). - map(_.databaseTypeDefinition).get == "VARCHAR(255)") - } - - test("SPARK-15648: teradataDialect BooleanType data mapping") { - val teradataDialect = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") - assert(teradataDialect.getJDBCType(BooleanType). 
- map(_.databaseTypeDefinition).get == "CHAR(1)") + test("SPARK-48399: TeradataDialect jdbc data mapping") { + val dialect = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") + assert(dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "VARCHAR(255)") + assert(dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)") + assert(dialect.getJDBCType(ByteType).map(_.databaseTypeDefinition).get == "BYTEINT") } test("SPARK-38846: TeradataDialect catalyst type mapping") { From 3b9b52dff6149e499c59bb30641df777bd712d9b Mon Sep 17 00:00:00 2001 From: panbingkun Date: Fri, 24 May 2024 11:52:37 +0800 Subject: [PATCH 11/45] [SPARK-48405][BUILD] Upgrade `commons-compress` to 1.26.2 ### What changes were proposed in this pull request? The pr aims to upgrade `commons-compress` to `1.26.2`. ### Why are the changes needed? The full release notes: https://commons.apache.org/proper/commons-compress/changes-report.html#a1.26.2 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46725 from panbingkun/SPARK-48405. Authored-by: panbingkun Signed-off-by: Kent Yao --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 2 +- pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 79ce883dc672c..35f6103e9fa45 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -42,7 +42,7 @@ commons-codec/1.17.0//commons-codec-1.17.0.jar commons-collections/3.2.2//commons-collections-3.2.2.jar commons-collections4/4.4//commons-collections4-4.4.jar commons-compiler/3.1.9//commons-compiler-3.1.9.jar -commons-compress/1.26.1//commons-compress-1.26.1.jar +commons-compress/1.26.2//commons-compress-1.26.2.jar commons-crypto/1.1.0//commons-crypto-1.1.0.jar commons-dbcp/1.4//commons-dbcp-1.4.jar commons-io/2.16.1//commons-io-2.16.1.jar diff --git a/pom.xml b/pom.xml index 6bbcf05b59e54..ecd05ee996e1b 100644 --- a/pom.xml +++ b/pom.xml @@ -187,7 +187,7 @@ 1.1.10.5 3.0.3 1.17.0 - 1.26.1 + 1.26.2 2.16.1 2.6 From f42ed6c760043b0213ebf0348a22dec7c0bb8244 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 24 May 2024 14:23:23 +0800 Subject: [PATCH 12/45] [SPARK-48406][BUILD] Upgrade commons-cli to 1.8.0 ### What changes were proposed in this pull request? This pr aims to upgrade Apache `commons-cli` from 1.6.0 to 1.8.0. ### Why are the changes needed? The full release notes as follows: - https://commons.apache.org/proper/commons-cli/changes-report.html#a1.7.0 - https://commons.apache.org/proper/commons-cli/changes-report.html#a1.8.0 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #46727 from LuciferYang/commons-cli-180. 
Authored-by: yangjie01 Signed-off-by: Kent Yao --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 2 +- pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 35f6103e9fa45..46c5108e4eba4 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -37,7 +37,7 @@ cats-kernel_2.13/2.8.0//cats-kernel_2.13-2.8.0.jar checker-qual/3.42.0//checker-qual-3.42.0.jar chill-java/0.10.0//chill-java-0.10.0.jar chill_2.13/0.10.0//chill_2.13-0.10.0.jar -commons-cli/1.6.0//commons-cli-1.6.0.jar +commons-cli/1.8.0//commons-cli-1.8.0.jar commons-codec/1.17.0//commons-codec-1.17.0.jar commons-collections/3.2.2//commons-collections-3.2.2.jar commons-collections4/4.4//commons-collections4-4.4.jar diff --git a/pom.xml b/pom.xml index ecd05ee996e1b..e8d47afa1cca5 100644 --- a/pom.xml +++ b/pom.xml @@ -210,7 +210,7 @@ 4.17.0 3.1.0 1.1.0 - 1.6.0 + 1.8.0 1.78 1.13.0 6.0.0 From a29c9653f3d48d97875ae446d82896bdf0de61ca Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Fri, 24 May 2024 14:31:52 +0800 Subject: [PATCH 13/45] [SPARK-46090][SQL] Support plan fragment level SQL configs in AQE ### What changes were proposed in this pull request? This pr introduces `case class AdaptiveRuleContext(isSubquery: Boolean, isFinalStage: Boolean)` which can be used inside adaptive sql extension rules through thread local, so that developers can modify the next plan fragment configs using `AdaptiveRuleContext.get()`. The plan fragment configs can be propagated through multi-phases, e.g., if set a config in `queryPostPlannerStrategyRules` then the config can be gotten in `queryStagePrepRules`, `queryStageOptimizerRules` and `columnarRules`. The configs will be cleanup before going to execute, so in next round the configs will be empty. ### Why are the changes needed? To support modify the plan fragment level SQL configs through AQE rules. ### Does this PR introduce _any_ user-facing change? no, only affect developers. ### How was this patch tested? add new tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #44013 from ulysses-you/rule-context. Lead-authored-by: ulysses-you Co-authored-by: Kent Yao Signed-off-by: youxiduo --- .../adaptive/AdaptiveRuleContext.scala | 89 +++++++++ .../adaptive/AdaptiveSparkPlanExec.scala | 42 ++++- .../adaptive/AdaptiveRuleContextSuite.scala | 176 ++++++++++++++++++ 3 files changed, 299 insertions(+), 8 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContextSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala new file mode 100644 index 0000000000000..fce20b79e1136 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import scala.collection.mutable + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.SQLConfHelper + +/** + * Provide the functionality to modify the next plan fragment configs in AQE rules. + * The configs will be cleanup before going to execute next plan fragment. + * To get instance, use: {{{ AdaptiveRuleContext.get() }}} + * + * @param isSubquery if the input query plan is subquery + * @param isFinalStage if the next stage is final stage + */ +@Experimental +@DeveloperApi +case class AdaptiveRuleContext(isSubquery: Boolean, isFinalStage: Boolean) { + + /** + * Set SQL configs for next plan fragment. The configs will affect all of rules in AQE, + * i.e., the runtime optimizer, planner, queryStagePreparationRules, queryStageOptimizerRules, + * columnarRules. + * This configs will be cleared before going to get the next plan fragment. + */ + private val nextPlanFragmentConf = new mutable.HashMap[String, String]() + + private[sql] def withFinalStage(isFinalStage: Boolean): AdaptiveRuleContext = { + if (this.isFinalStage == isFinalStage) { + this + } else { + val newRuleContext = copy(isFinalStage = isFinalStage) + newRuleContext.setConfigs(this.configs()) + newRuleContext + } + } + + def setConfig(key: String, value: String): Unit = { + nextPlanFragmentConf.put(key, value) + } + + def setConfigs(kvs: Map[String, String]): Unit = { + kvs.foreach(kv => nextPlanFragmentConf.put(kv._1, kv._2)) + } + + private[sql] def configs(): Map[String, String] = nextPlanFragmentConf.toMap + + private[sql] def clearConfigs(): Unit = nextPlanFragmentConf.clear() +} + +object AdaptiveRuleContext extends SQLConfHelper { + private val ruleContextThreadLocal = new ThreadLocal[AdaptiveRuleContext] + + /** + * If a rule is applied inside AQE then the returned value is always defined, else return None. 
+ */ + def get(): Option[AdaptiveRuleContext] = Option(ruleContextThreadLocal.get()) + + private[sql] def withRuleContext[T](ruleContext: AdaptiveRuleContext)(block: => T): T = { + assert(ruleContext != null) + val origin = ruleContextThreadLocal.get() + ruleContextThreadLocal.set(ruleContext) + try { + val conf = ruleContext.configs() + withSQLConf(conf.toSeq: _*) { + block + } + } finally { + ruleContextThreadLocal.set(origin) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index f30ffaf515664..f21960aeedd64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -85,6 +85,25 @@ case class AdaptiveSparkPlanExec( case _ => logDebug(_) } + @transient private var ruleContext = new AdaptiveRuleContext( + isSubquery = isSubquery, + isFinalStage = false) + + private def withRuleContext[T](f: => T): T = + AdaptiveRuleContext.withRuleContext(ruleContext) { f } + + private def applyPhysicalRulesWithRuleContext( + plan: => SparkPlan, + rules: Seq[Rule[SparkPlan]], + loggerAndBatchName: Option[(PlanChangeLogger[SparkPlan], String)] = None): SparkPlan = { + // Apply the last rules if exists before going to apply the next batch of rules, + // so that we can propagate the configs. + val newPlan = plan + withRuleContext { + applyPhysicalRules(newPlan, rules, loggerAndBatchName) + } + } + @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]() // The logical plan optimizer for re-optimizing the current logical plan. @@ -161,7 +180,9 @@ case class AdaptiveSparkPlanExec( collapseCodegenStagesRule ) - private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = { + private def optimizeQueryStage( + plan: SparkPlan, + isFinalStage: Boolean): SparkPlan = withRuleContext { val rules = if (isFinalStage && !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS)) { queryStageOptimizerRules.filterNot(_.isInstanceOf[AQEShuffleReadRule]) @@ -197,7 +218,7 @@ case class AdaptiveSparkPlanExec( } private def applyQueryPostPlannerStrategyRules(plan: SparkPlan): SparkPlan = { - applyPhysicalRules( + applyPhysicalRulesWithRuleContext( plan, context.session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules, Some((planChangeLogger, "AQE Query Post Planner Strategy Rules")) @@ -205,7 +226,7 @@ case class AdaptiveSparkPlanExec( } @transient val initialPlan = context.session.withActive { - applyPhysicalRules( + applyPhysicalRulesWithRuleContext( applyQueryPostPlannerStrategyRules(inputPlan), queryStagePreparationRules, Some((planChangeLogger, "AQE Preparations"))) @@ -282,6 +303,7 @@ case class AdaptiveSparkPlanExec( val errors = new mutable.ArrayBuffer[Throwable]() var stagesToReplace = Seq.empty[QueryStageExec] while (!result.allChildStagesMaterialized) { + ruleContext.clearConfigs() currentPhysicalPlan = result.newPlan if (result.newStages.nonEmpty) { stagesToReplace = result.newStages ++ stagesToReplace @@ -373,11 +395,13 @@ case class AdaptiveSparkPlanExec( result = createQueryStages(currentPhysicalPlan) } + ruleContext = ruleContext.withFinalStage(isFinalStage = true) // Run the final plan when there's no more unfinished stages. 
- currentPhysicalPlan = applyPhysicalRules( + currentPhysicalPlan = applyPhysicalRulesWithRuleContext( optimizeQueryStage(result.newPlan, isFinalStage = true), postStageCreationRules(supportsColumnar), Some((planChangeLogger, "AQE Post Stage Creation"))) + ruleContext.clearConfigs() _isFinalPlan = true executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan))) currentPhysicalPlan @@ -595,7 +619,7 @@ case class AdaptiveSparkPlanExec( val queryStage = plan match { case e: Exchange => val optimized = e.withNewChildren(Seq(optimizeQueryStage(e.child, isFinalStage = false))) - val newPlan = applyPhysicalRules( + val newPlan = applyPhysicalRulesWithRuleContext( optimized, postStageCreationRules(outputsColumnar = plan.supportsColumnar), Some((planChangeLogger, "AQE Post Stage Creation"))) @@ -722,9 +746,11 @@ case class AdaptiveSparkPlanExec( private def reOptimize(logicalPlan: LogicalPlan): Option[(SparkPlan, LogicalPlan)] = { try { logicalPlan.invalidateStatsCache() - val optimized = optimizer.execute(logicalPlan) - val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next() - val newPlan = applyPhysicalRules( + val optimized = withRuleContext { optimizer.execute(logicalPlan) } + val sparkPlan = withRuleContext { + context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next() + } + val newPlan = applyPhysicalRulesWithRuleContext( applyQueryPostPlannerStrategyRules(sparkPlan), preprocessingRules ++ queryStagePreparationRules, Some((planChangeLogger, "AQE Replanning"))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContextSuite.scala new file mode 100644 index 0000000000000..04c9e6c946b45 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContextSuite.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{SparkSession, SparkSessionExtensionsProvider} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{ColumnarRule, RangeExec, SparkPlan, SparkStrategy} +import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec + +class AdaptiveRuleContextSuite extends SparkFunSuite with AdaptiveSparkPlanHelper { + + private def stop(spark: SparkSession): Unit = { + spark.stop() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + + private def withSession( + builders: Seq[SparkSessionExtensionsProvider])(f: SparkSession => Unit): Unit = { + val builder = SparkSession.builder().master("local[1]") + builders.foreach(builder.withExtensions) + val spark = builder.getOrCreate() + try f(spark) finally { + stop(spark) + } + } + + test("test adaptive rule context") { + withSession( + Seq(_.injectRuntimeOptimizerRule(_ => MyRuleContextForRuntimeOptimization), + _.injectPlannerStrategy(_ => MyRuleContextForPlannerStrategy), + _.injectQueryPostPlannerStrategyRule(_ => MyRuleContextForPostPlannerStrategyRule), + _.injectQueryStagePrepRule(_ => MyRuleContextForPreQueryStageRule), + _.injectQueryStageOptimizerRule(_ => MyRuleContextForQueryStageRule), + _.injectColumnar(_ => MyRuleContextForColumnarRule))) { spark => + val df = spark.range(1, 10, 1, 3).selectExpr("id % 3 as c").groupBy("c").count() + df.collect() + assert(collectFirst(df.queryExecution.executedPlan) { + case s: ShuffleExchangeExec if s.numPartitions == 2 => s + }.isDefined) + } + } + + test("test adaptive rule context with subquery") { + withSession( + Seq(_.injectQueryStagePrepRule(_ => MyRuleContextForQueryStageWithSubquery))) { spark => + spark.sql("select (select count(*) from range(10)), id from range(10)").collect() + } + } +} + +object MyRuleContext { + def checkAndGetRuleContext(): AdaptiveRuleContext = { + val ruleContextOpt = AdaptiveRuleContext.get() + assert(ruleContextOpt.isDefined) + ruleContextOpt.get + } + + def checkRuleContextForQueryStage(plan: SparkPlan): SparkPlan = { + val ruleContext = checkAndGetRuleContext() + assert(!ruleContext.isSubquery) + val stage = plan.find(_.isInstanceOf[ShuffleQueryStageExec]) + if (stage.isDefined && stage.get.asInstanceOf[ShuffleQueryStageExec].isMaterialized) { + assert(ruleContext.isFinalStage) + assert(!ruleContext.configs().get("spark.sql.shuffle.partitions").contains("2")) + } else { + assert(!ruleContext.isFinalStage) + assert(ruleContext.configs().get("spark.sql.shuffle.partitions").contains("2")) + } + plan + } +} + +object MyRuleContextForRuntimeOptimization extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + MyRuleContext.checkAndGetRuleContext() + plan + } +} + +object MyRuleContextForPlannerStrategy extends SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = { + plan match { + case _: LogicalQueryStage => + val ruleContext = MyRuleContext.checkAndGetRuleContext() + assert(!ruleContext.configs().get("spark.sql.shuffle.partitions").contains("2")) + Nil + case _ => Nil + } + } +} + +object MyRuleContextForPostPlannerStrategyRule extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = { + val ruleContext = MyRuleContext.checkAndGetRuleContext() + if 
(plan.find(_.isInstanceOf[RangeExec]).isDefined) { + ruleContext.setConfig("spark.sql.shuffle.partitions", "2") + } + plan + } +} + +object MyRuleContextForPreQueryStageRule extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = { + val ruleContext = MyRuleContext.checkAndGetRuleContext() + assert(!ruleContext.isFinalStage) + plan + } +} + +object MyRuleContextForQueryStageRule extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = { + MyRuleContext.checkRuleContextForQueryStage(plan) + } +} + +object MyRuleContextForColumnarRule extends ColumnarRule { + override def preColumnarTransitions: Rule[SparkPlan] = { + plan: SparkPlan => { + if (plan.isInstanceOf[AdaptiveSparkPlanExec]) { + // skip if we are not inside AQE + assert(AdaptiveRuleContext.get().isEmpty) + plan + } else { + MyRuleContext.checkRuleContextForQueryStage(plan) + } + } + } + + override def postColumnarTransitions: Rule[SparkPlan] = { + plan: SparkPlan => { + if (plan.isInstanceOf[AdaptiveSparkPlanExec]) { + // skip if we are not inside AQE + assert(AdaptiveRuleContext.get().isEmpty) + plan + } else { + MyRuleContext.checkRuleContextForQueryStage(plan) + } + } + } +} + +object MyRuleContextForQueryStageWithSubquery extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = { + val ruleContext = MyRuleContext.checkAndGetRuleContext() + if (plan.exists(_.isInstanceOf[HashAggregateExec])) { + assert(ruleContext.isSubquery) + if (plan.exists(_.isInstanceOf[RangeExec])) { + assert(!ruleContext.isFinalStage) + } else { + assert(ruleContext.isFinalStage) + } + } else { + assert(!ruleContext.isSubquery) + } + plan + } +} From 3346afd4b250c3aead5a237666d4942018a463e0 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Fri, 24 May 2024 14:53:26 +0800 Subject: [PATCH 14/45] [SPARK-46090][SQL][FOLLOWUP] Add DeveloperApi import ### What changes were proposed in this pull request? Add DeveloperApi import ### Why are the changes needed? Fix compile issue ### Does this PR introduce _any_ user-facing change? Fix compile issue ### How was this patch tested? pass CI ### Was this patch authored or co-authored using generative AI tooling? no Closes #46730 from ulysses-you/hot-fix. Authored-by: ulysses-you Signed-off-by: Kent Yao --- .../spark/sql/execution/adaptive/AdaptiveRuleContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala index fce20b79e1136..23817be71c89c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.adaptive import scala.collection.mutable -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.sql.catalyst.SQLConfHelper /** From ef43bbbc11638b6ad3f02b9f4d74a6357ef09f13 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Fri, 24 May 2024 16:20:54 +0800 Subject: [PATCH 15/45] [SPARK-48384][BUILD] Exclude `io.netty:netty-tcnative-boringssl-static` from `zookeeper` ### What changes were proposed in this pull request? The pr aims to exclude `io.netty:netty-tcnative-boringssl-static` from `zookeeper`. ### Why are the changes needed? 
1.According to: https://github.com/netty/netty-tcnative/blob/c9b4b6ab62cdbfb4aab6ab3efb8dd7cf73f353ad/boringssl-static/pom.xml#L970-L982 the `io.netty:netty-tcnative-boringssl-static` is composed of: `io.netty:netty-tcnative-boringssl-static:linux-aarch_64` `io.netty:netty-tcnative-boringssl-static:linux-x86_64` `io.netty:netty-tcnative-boringssl-static:osx-aarch_64` `io.netty:netty-tcnative-boringssl-static:osx-x86_64` `io.netty:netty-tcnative-boringssl-staticwindows-x86_64` and our `common/network-common/pom.xml` file already explicitly depends on them. 2.Their versions are different in `dev/deps/spark-deps-hadoop-3-hive-2.3` image Let's keep one version to avoid conflicts. 3.In the `pom.xml` file, zookeeper already has other `netty` related exclusions, eg: https://github.com/apache/spark/blob/master/pom.xml#L1838-L1842 image ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46695 from panbingkun/SPARK-48384. Authored-by: panbingkun Signed-off-by: yangjie01 --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 1 - pom.xml | 4 ++++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 46c5108e4eba4..61d7861f4469b 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -207,7 +207,6 @@ netty-common/4.1.109.Final//netty-common-4.1.109.Final.jar netty-handler-proxy/4.1.109.Final//netty-handler-proxy-4.1.109.Final.jar netty-handler/4.1.109.Final//netty-handler-4.1.109.Final.jar netty-resolver/4.1.109.Final//netty-resolver-4.1.109.Final.jar -netty-tcnative-boringssl-static/2.0.61.Final//netty-tcnative-boringssl-static-2.0.61.Final.jar netty-tcnative-boringssl-static/2.0.65.Final/linux-aarch_64/netty-tcnative-boringssl-static-2.0.65.Final-linux-aarch_64.jar netty-tcnative-boringssl-static/2.0.65.Final/linux-x86_64/netty-tcnative-boringssl-static-2.0.65.Final-linux-x86_64.jar netty-tcnative-boringssl-static/2.0.65.Final/osx-aarch_64/netty-tcnative-boringssl-static-2.0.65.Final-osx-aarch_64.jar diff --git a/pom.xml b/pom.xml index e8d47afa1cca5..eef7237ac12f9 100644 --- a/pom.xml +++ b/pom.xml @@ -1839,6 +1839,10 @@ io.netty netty-transport-native-epoll + + io.netty + netty-tcnative-boringssl-static + com.github.spotbugs spotbugs-annotations From b15b6cf1f537756eafbe8dd31a3b03dc500077f3 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Fri, 24 May 2024 17:04:38 +0800 Subject: [PATCH 16/45] [SPARK-48409][BUILD][TESTS] Upgrade MySQL & Postgres & Mariadb docker image version ### What changes were proposed in this pull request? The pr aims to upgrade some db docker image version, include: - `MySQL` from `8.3.0` to `8.4.0` - `Postgres` from `10.5.12` to `10.5.25` - `Mariadb` from `16.2-alpine` to `16.3-alpine` ### Why are the changes needed? Tests dependencies upgrading. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46704 from panbingkun/db_images_upgrade. 
Authored-by: panbingkun Signed-off-by: Kent Yao --- .../apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala | 6 +++--- .../org/apache/spark/sql/jdbc/MySQLDatabaseOnDocker.scala | 2 +- .../apache/spark/sql/jdbc/PostgresIntegrationSuite.scala | 6 +++--- .../apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala | 6 +++--- .../spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala | 6 +++--- .../spark/sql/jdbc/querytest/PostgreSQLQueryTestSuite.scala | 6 +++--- .../apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala | 6 +++--- .../apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala | 6 +++--- 8 files changed, 22 insertions(+), 22 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala index 6825c001f7670..efb2fa09f6a3f 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.execution.datasources.jdbc.connection.SecureConnecti import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., mariadb:10.5.12): + * To run this test suite for a specific version (e.g., mariadb:10.5.25): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 MARIADB_DOCKER_IMAGE_NAME=mariadb:10.5.12 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MARIADB_DOCKER_IMAGE_NAME=mariadb:10.5.25 * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly org.apache.spark.sql.jdbc.MariaDBKrbIntegrationSuite" * }}} @@ -38,7 +38,7 @@ class MariaDBKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { override protected val keytabFileName = "mariadb.keytab" override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("MARIADB_DOCKER_IMAGE_NAME", "mariadb:10.5.12") + override val imageName = sys.env.getOrElse("MARIADB_DOCKER_IMAGE_NAME", "mariadb:10.5.25") override val env = Map( "MYSQL_ROOT_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLDatabaseOnDocker.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLDatabaseOnDocker.scala index 568eb5f109731..570a81ac3947f 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLDatabaseOnDocker.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLDatabaseOnDocker.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.jdbc class MySQLDatabaseOnDocker extends DatabaseOnDocker { - override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:8.3.0") + override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:8.4.0") override val env = Map( "MYSQL_ROOT_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 5ad4f15216b74..12a71dbd7c7f8 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -32,9 +32,9 @@ import 
org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:16.2): + * To run this test suite for a specific version (e.g., postgres:16.3-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.2 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.3-alpine * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly org.apache.spark.sql.jdbc.PostgresIntegrationSuite" * }}} @@ -42,7 +42,7 @@ import org.apache.spark.tags.DockerTest @DockerTest class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.3-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala index d08be3b5f40e3..af1cd464ad5fe 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.execution.datasources.jdbc.connection.SecureConnecti import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:16.2): + * To run this test suite for a specific version (e.g., postgres:16.3-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.2 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.3-alpine * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly *PostgresKrbIntegrationSuite" * }}} @@ -38,7 +38,7 @@ class PostgresKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { override protected val keytabFileName = "postgres.keytab" override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.3-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala index 7ae03e974845b..8b27e9cb0e0a3 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.tags.DockerTest /** * This suite is used to generate subqueries, and test Spark against Postgres. 
- * To run this test suite for a specific version (e.g., postgres:16.2): + * To run this test suite for a specific version (e.g., postgres:16.3-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.2 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.3-alpine * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly org.apache.spark.sql.jdbc.GeneratedSubquerySuite" * }}} @@ -39,7 +39,7 @@ import org.apache.spark.tags.DockerTest class GeneratedSubquerySuite extends DockerJDBCIntegrationSuite with QueryGeneratorHelper { override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.3-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgreSQLQueryTestSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgreSQLQueryTestSuite.scala index f2a7e14cfc4b9..de28e16b325ce 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgreSQLQueryTestSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgreSQLQueryTestSuite.scala @@ -30,9 +30,9 @@ import org.apache.spark.tags.DockerTest * confidence, and you won't have to manually verify the golden files generated with your test. * 2. Add this line to your .sql file: --ONLY_IF spark * - * Note: To run this test suite for a specific version (e.g., postgres:16.2): + * Note: To run this test suite for a specific version (e.g., postgres:16.3-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.2 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.3-alpine * ./build/sbt -Pdocker-integration-tests * "testOnly org.apache.spark.sql.jdbc.PostgreSQLQueryTestSuite" * }}} @@ -45,7 +45,7 @@ class PostgreSQLQueryTestSuite extends CrossDbmsQueryTestSuite { protected val customInputFilePath: String = new File(inputFilePath, "subquery").getAbsolutePath override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.3-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 7fef3ccd6b3f6..7c439d449d86f 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:16.2) + * To run this test suite for a specific version (e.g., postgres:16.3-alpine) * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.2 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.3-alpine * ./build/sbt -Pdocker-integration-tests 
"testOnly *v2.PostgresIntegrationSuite" * }}} */ @@ -38,7 +38,7 @@ import org.apache.spark.tags.DockerTest class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "postgresql" override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.3-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala index 838de5acab0df..8a2d0ded84381 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala @@ -26,16 +26,16 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:16.2): + * To run this test suite for a specific version (e.g., postgres:16.3-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.2 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.3-alpine * ./build/sbt -Pdocker-integration-tests "testOnly *v2.PostgresNamespaceSuite" * }}} */ @DockerTest class PostgresNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.3-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) From bd95040c3170aaed61ee5e9090d1b8580351ee80 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 24 May 2024 17:36:46 +0800 Subject: [PATCH 17/45] [SPARK-48412][PYTHON] Refactor data type json parse ### What changes were proposed in this pull request? Refactor data type json parse ### Why are the changes needed? the `_all_atomic_types` causes confusions: - it is only used in json parse, so it should use the `jsonValue` instead of `typeName` (and so it causes the `typeName` not consistent with Scala, will fix in separate PR); - not all atomic types are included in it (e.g. `YearMonthIntervalType`); - not all atomic types should be placed in it (e.g. `VarcharType` which has to be excluded here and there) ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? ci, added tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #46733 from zhengruifeng/refactor_json_parse. 
Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- python/pyspark/sql/tests/test_types.py | 42 +++++++++++++++++-- python/pyspark/sql/types.py | 57 ++++++++++++++++++-------- 2 files changed, 79 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 6c64a9471363a..d665053d94904 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -1136,12 +1136,46 @@ def test_struct_type(self): self.assertRaises(IndexError, lambda: struct1[9]) self.assertRaises(TypeError, lambda: struct1[9.9]) + def test_parse_datatype_json_string(self): + from pyspark.sql.types import _parse_datatype_json_string + + for dataType in [ + StringType(), + CharType(5), + VarcharType(10), + BinaryType(), + BooleanType(), + DecimalType(), + DecimalType(10, 2), + FloatType(), + DoubleType(), + ByteType(), + ShortType(), + IntegerType(), + LongType(), + DateType(), + TimestampType(), + TimestampNTZType(), + NullType(), + VariantType(), + YearMonthIntervalType(), + YearMonthIntervalType(YearMonthIntervalType.YEAR), + YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH), + DayTimeIntervalType(), + DayTimeIntervalType(DayTimeIntervalType.DAY), + DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND), + CalendarIntervalType(), + ]: + json_str = dataType.json() + parsed = _parse_datatype_json_string(json_str) + self.assertEqual(dataType, parsed) + def test_parse_datatype_string(self): - from pyspark.sql.types import _all_atomic_types, _parse_datatype_string + from pyspark.sql.types import _all_mappable_types, _parse_datatype_string + + for k, t in _all_mappable_types.items(): + self.assertEqual(t(), _parse_datatype_string(k)) - for k, t in _all_atomic_types.items(): - if k != "varchar" and k != "char": - self.assertEqual(t(), _parse_datatype_string(k)) self.assertEqual(IntegerType(), _parse_datatype_string("int")) self.assertEqual(StringType(), _parse_datatype_string("string")) self.assertEqual(CharType(1), _parse_datatype_string("char(1)")) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 17b019240f826..b9db59e0a58ac 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1756,13 +1756,45 @@ def toJson(self, zone_id: str = "UTC") -> str: TimestampNTZType, NullType, VariantType, + YearMonthIntervalType, + DayTimeIntervalType, ] -_all_atomic_types: Dict[str, Type[DataType]] = dict((t.typeName(), t) for t in _atomic_types) -_complex_types: List[Type[Union[ArrayType, MapType, StructType]]] = [ArrayType, MapType, StructType] -_all_complex_types: Dict[str, Type[Union[ArrayType, MapType, StructType]]] = dict( - (v.typeName(), v) for v in _complex_types -) +_complex_types: List[Type[Union[ArrayType, MapType, StructType]]] = [ + ArrayType, + MapType, + StructType, +] +_all_complex_types: Dict[str, Type[Union[ArrayType, MapType, StructType]]] = { + "array": ArrayType, + "map": MapType, + "struct": StructType, +} + +# Datatypes that can be directly parsed by mapping a json string without regex. +# This dict should be only used in json parsing. 
+# Note that: +# 1, CharType and VarcharType are not listed here, since they need regex; +# 2, DecimalType can be parsed by both mapping ('decimal') and regex ('decimal(10, 2)'); +# 3, CalendarIntervalType is not an atomic type, but can be mapped by 'interval'; +_all_mappable_types: Dict[str, Type[DataType]] = { + "string": StringType, + "binary": BinaryType, + "boolean": BooleanType, + "decimal": DecimalType, + "float": FloatType, + "double": DoubleType, + "byte": ByteType, + "short": ShortType, + "integer": IntegerType, + "long": LongType, + "date": DateType, + "timestamp": TimestampType, + "timestamp_ntz": TimestampNTZType, + "void": NullType, + "variant": VariantType, + "interval": CalendarIntervalType, +} _LENGTH_CHAR = re.compile(r"char\(\s*(\d+)\s*\)") _LENGTH_VARCHAR = re.compile(r"varchar\(\s*(\d+)\s*\)") @@ -1887,11 +1919,8 @@ def _parse_datatype_json_string(json_string: str) -> DataType: ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) ... assert datatype == python_datatype ... - >>> for cls in _all_atomic_types.values(): - ... if cls is not VarcharType and cls is not CharType: - ... check_datatype(cls()) - ... else: - ... check_datatype(cls(1)) + >>> for cls in _all_mappable_types.values(): + ... check_datatype(cls()) >>> # Simple ArrayType. >>> simple_arraytype = ArrayType(StringType(), True) @@ -1938,14 +1967,12 @@ def _parse_datatype_json_value( collationsMap: Optional[Dict[str, str]] = None, ) -> DataType: if not isinstance(json_value, dict): - if json_value in _all_atomic_types.keys(): + if json_value in _all_mappable_types.keys(): if collationsMap is not None and fieldPath in collationsMap: _assert_valid_type_for_collation(fieldPath, json_value, collationsMap) collation_name = collationsMap[fieldPath] return StringType(collation_name) - return _all_atomic_types[json_value]() - elif json_value == "decimal": - return DecimalType() + return _all_mappable_types[json_value]() elif _FIXED_DECIMAL.match(json_value): m = _FIXED_DECIMAL.match(json_value) return DecimalType(int(m.group(1)), int(m.group(2))) # type: ignore[union-attr] @@ -1965,8 +1992,6 @@ def _parse_datatype_json_value( if first_field is not None and second_field is None: return YearMonthIntervalType(first_field) return YearMonthIntervalType(first_field, second_field) - elif json_value == "interval": - return CalendarIntervalType() elif _LENGTH_CHAR.match(json_value): m = _LENGTH_CHAR.match(json_value) return CharType(int(m.group(1))) # type: ignore[union-attr] From 7ae939ae12a68b8664af4e4d9bfe11902ec3494d Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 24 May 2024 18:20:14 +0800 Subject: [PATCH 18/45] [SPARK-48168][SQL] Add bitwise shifting operators support ### What changes were proposed in this pull request? This PR introduces three bitwise shifting operators as aliases for existing shifting functions. ### Why are the changes needed? The bit shifting functions named in alphabet form vary from one platform to anthor. Take our shiftleft as an example, - Hive, shiftleft (where we copied it from) - MsSQL Server LEFT_SHIFT - MySQL, N/A - PostgreSQL, N/A - Presto, bitwise_left_shift The [bit shifting operators](https://en.wikipedia.org/wiki/Bitwise_operations_in_C) share a much more common and consistent way for users to port their queries. For self-consistent with existing bit operators in Spark, `AND &`, `OR |`, `XOR ^` and `NOT ~`, we now add `<<`, `>>` and `>>>`. 
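As a minimal illustrative sketch (not part of this patch's tests; it assumes an active `SparkSession` named `spark` built from this branch), each new operator simply aliases the corresponding existing function:

```python
# Hedged sketch: the three operators are aliases for the existing shift functions.
# `spark` is an assumed, already-created SparkSession.
spark.sql("SELECT 1 << 2 AS sl, 16 >> 2 AS sr, -16 >>> 2 AS sru").show()
# Equivalent to: SELECT shiftleft(1, 2), shiftright(16, 2), shiftrightunsigned(-16, 2)
```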
For other systems that we can refer to: https://learn.microsoft.com/en-us/sql/t-sql/functions/left-shift-transact-sql?view=sql-server-ver16 https://www.postgresql.org/docs/9.4/functions-bitstring.html https://dev.mysql.com/doc/refman/8.0/en/bit-functions.html ### Does this PR introduce _any_ user-facing change? Yes, new operators were added but no behavior change ### How was this patch tested? new tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #46440 from yaooqinn/SPARK-48168. Authored-by: Kent Yao Signed-off-by: youxiduo --- .../function_shiftleft.explain | 2 +- .../function_shiftright.explain | 2 +- .../function_shiftrightunsigned.explain | 2 +- .../grouping_and_grouping_id.explain | 2 +- .../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 40 +++++- .../sql/catalyst/parser/SqlBaseParser.g4 | 8 ++ .../catalyst/analysis/FunctionRegistry.scala | 3 + .../expressions/mathExpressions.scala | 134 +++++++++--------- .../sql/catalyst/parser/AstBuilder.scala | 11 ++ .../spark/sql/catalyst/SQLKeywordSuite.scala | 2 +- .../sql-functions/sql-expression-schema.md | 9 +- .../analyzer-results/bitwise.sql.out | 112 +++++++++++++++ .../analyzer-results/group-analytics.sql.out | 10 +- .../analyzer-results/grouping_set.sql.out | 6 +- .../postgreSQL/groupingsets.sql.out | 44 +++--- .../analyzer-results/postgreSQL/int2.sql.out | 4 +- .../analyzer-results/postgreSQL/int4.sql.out | 4 +- .../analyzer-results/postgreSQL/int8.sql.out | 4 +- .../udf/udf-group-analytics.sql.out | 10 +- .../resources/sql-tests/inputs/bitwise.sql | 12 ++ .../sql-tests/results/bitwise.sql.out | 128 +++++++++++++++++ .../sql-tests/results/postgreSQL/int2.sql.out | 4 +- .../sql-tests/results/postgreSQL/int4.sql.out | 4 +- .../sql-tests/results/postgreSQL/int8.sql.out | 2 +- .../approved-plans-v1_4/q17/explain.txt | 2 +- .../approved-plans-v1_4/q25/explain.txt | 2 +- .../approved-plans-v1_4/q27.sf100/explain.txt | 2 +- .../approved-plans-v1_4/q27/explain.txt | 2 +- .../approved-plans-v1_4/q29/explain.txt | 2 +- .../approved-plans-v1_4/q36.sf100/explain.txt | 2 +- .../approved-plans-v1_4/q36/explain.txt | 2 +- .../approved-plans-v1_4/q39a/explain.txt | 2 +- .../approved-plans-v1_4/q39b/explain.txt | 2 +- .../approved-plans-v1_4/q49/explain.txt | 6 +- .../approved-plans-v1_4/q5/explain.txt | 2 +- .../approved-plans-v1_4/q64/explain.txt | 4 +- .../approved-plans-v1_4/q70.sf100/explain.txt | 2 +- .../approved-plans-v1_4/q70/explain.txt | 2 +- .../approved-plans-v1_4/q72/explain.txt | 2 +- .../approved-plans-v1_4/q85/explain.txt | 2 +- .../approved-plans-v1_4/q86.sf100/explain.txt | 2 +- .../approved-plans-v1_4/q86/explain.txt | 2 +- .../approved-plans-v2_7/q24.sf100/explain.txt | 2 +- .../approved-plans-v2_7/q49/explain.txt | 6 +- .../approved-plans-v2_7/q5a/explain.txt | 2 +- .../approved-plans-v2_7/q64/explain.txt | 4 +- .../approved-plans-v2_7/q72/explain.txt | 2 +- .../sql/expressions/ExpressionInfoSuite.scala | 5 +- 48 files changed, 469 insertions(+), 153 deletions(-) diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftleft.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftleft.explain index f89a8be7ceedb..6d5eb29944d52 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftleft.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftleft.explain @@ -1,2 +1,2 @@ -Project [shiftleft(cast(b#0 as int), 2) AS shiftleft(b, 
2)#0] +Project [(cast(b#0 as int) << 2) AS (b << 2)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftright.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftright.explain index b436f52e912b5..b1c2c35ac2d0e 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftright.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftright.explain @@ -1,2 +1,2 @@ -Project [shiftright(cast(b#0 as int), 2) AS shiftright(b, 2)#0] +Project [(cast(b#0 as int) >> 2) AS (b >> 2)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftrightunsigned.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftrightunsigned.explain index 282ad156b3825..508c78a7f421f 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftrightunsigned.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftrightunsigned.explain @@ -1,2 +1,2 @@ -Project [shiftrightunsigned(cast(b#0 as int), 2) AS shiftrightunsigned(b, 2)#0] +Project [(cast(b#0 as int) >>> 2) AS (b >>> 2)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/grouping_and_grouping_id.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/grouping_and_grouping_id.explain index 3b7d6fb2b7072..f46fa38989ed4 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/grouping_and_grouping_id.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/grouping_and_grouping_id.explain @@ -1,4 +1,4 @@ -Aggregate [a#0, b#0, spark_grouping_id#0L], [a#0, b#0, cast((shiftright(spark_grouping_id#0L, 1) & 1) as tinyint) AS grouping(a)#0, cast((shiftright(spark_grouping_id#0L, 0) & 1) as tinyint) AS grouping(b)#0, spark_grouping_id#0L AS grouping_id(a, b)#0L] +Aggregate [a#0, b#0, spark_grouping_id#0L], [a#0, b#0, cast(((spark_grouping_id#0L >> 1) & 1) as tinyint) AS grouping(a)#0, cast(((spark_grouping_id#0L >> 0) & 1) as tinyint) AS grouping(b)#0, spark_grouping_id#0L AS grouping_id(a, b)#0L] +- Expand [[id#0L, a#0, b#0, a#0, b#0, 0], [id#0L, a#0, b#0, a#0, null, 1], [id#0L, a#0, b#0, null, b#0, 2], [id#0L, a#0, b#0, null, null, 3]], [id#0L, a#0, b#0, a#0, b#0, spark_grouping_id#0L] +- Project [id#0L, a#0, b#0, a#0 AS a#0, b#0 AS b#0] +- LocalRelation , [id#0L, a#0, b#0] diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index 25a06a5b98cf7..a9705c1733df5 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -69,6 +69,35 @@ lexer grammar SqlBaseLexer; public void markUnclosedComment() { has_unclosed_bracketed_comment = true; } + + /** + * When greater than zero, it's in the middle of parsing ARRAY/MAP/STRUCT type. + */ + public int complex_type_level_counter = 0; + + /** + * Increase the counter by one when hits KEYWORD 'ARRAY', 'MAP', 'STRUCT'. 
+ */ + public void incComplexTypeLevelCounter() { + complex_type_level_counter++; + } + + /** + * Decrease the counter by one when hits close tag '>' && the counter greater than zero + * which means we are in the middle of complex type parsing. Otherwise, it's a dangling + * GT token and we do nothing. + */ + public void decComplexTypeLevelCounter() { + if (complex_type_level_counter > 0) complex_type_level_counter--; + } + + /** + * If the counter is zero, it's a shift right operator. It can be closing tags of an complex + * type definition, such as MAP>. + */ + public boolean isShiftRightOperator() { + return complex_type_level_counter == 0 ? true : false; + } } SEMICOLON: ';'; @@ -100,7 +129,7 @@ ANTI: 'ANTI'; ANY: 'ANY'; ANY_VALUE: 'ANY_VALUE'; ARCHIVE: 'ARCHIVE'; -ARRAY: 'ARRAY'; +ARRAY: 'ARRAY' {incComplexTypeLevelCounter();}; AS: 'AS'; ASC: 'ASC'; AT: 'AT'; @@ -259,7 +288,7 @@ LOCKS: 'LOCKS'; LOGICAL: 'LOGICAL'; LONG: 'LONG'; MACRO: 'MACRO'; -MAP: 'MAP'; +MAP: 'MAP' {incComplexTypeLevelCounter();}; MATCHED: 'MATCHED'; MERGE: 'MERGE'; MICROSECOND: 'MICROSECOND'; @@ -362,7 +391,7 @@ STATISTICS: 'STATISTICS'; STORED: 'STORED'; STRATIFY: 'STRATIFY'; STRING: 'STRING'; -STRUCT: 'STRUCT'; +STRUCT: 'STRUCT' {incComplexTypeLevelCounter();}; SUBSTR: 'SUBSTR'; SUBSTRING: 'SUBSTRING'; SYNC: 'SYNC'; @@ -439,8 +468,11 @@ NEQ : '<>'; NEQJ: '!='; LT : '<'; LTE : '<=' | '!>'; -GT : '>'; +GT : '>' {decComplexTypeLevelCounter();}; GTE : '>=' | '!<'; +SHIFT_LEFT: '<<'; +SHIFT_RIGHT: '>>' {isShiftRightOperator()}?; +SHIFT_RIGHT_UNSIGNED: '>>>' {isShiftRightOperator()}?; PLUS: '+'; MINUS: '-'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index a87ecd02fb3a4..f0c0adb881212 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -395,6 +395,7 @@ describeFuncName | comparisonOperator | arithmeticOperator | predicateOperator + | shiftOperator | BANG ; @@ -989,6 +990,13 @@ valueExpression | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary | left=valueExpression comparisonOperator right=valueExpression #comparison + | left=valueExpression shiftOperator right=valueExpression #shiftExpression + ; + +shiftOperator + : SHIFT_LEFT + | SHIFT_RIGHT + | SHIFT_RIGHT_UNSIGNED ; datetimeUnit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 78126ce30af5e..3a418497fa537 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -800,6 +800,9 @@ object FunctionRegistry { expression[BitwiseNot]("~"), expression[BitwiseOr]("|"), expression[BitwiseXor]("^"), + expression[ShiftLeft]("<<", true, Some("4.0.0")), + expression[ShiftRight](">>", true, Some("4.0.0")), + expression[ShiftRightUnsigned](">>>", true, Some("4.0.0")), expression[BitwiseCount]("bit_count"), expression[BitAndAgg]("bit_and"), expression[BitOrAgg]("bit_or"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index dc50c18f2ebbf..4bb0c658eacf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1261,6 +1261,41 @@ case class Pow(left: Expression, right: Expression) newLeft: Expression, newRight: Expression): Expression = copy(left = newLeft, right = newRight) } +sealed trait BitShiftOperation + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { + + def symbol: String + def shiftInt: (Int, Int) => Int + def shiftLong: (Long, Int) => Long + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(IntegerType, LongType), IntegerType) + + override def dataType: DataType = left.dataType + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + defineCodeGen(ctx, ev, (left, right) => s"$left $symbol $right") + } + + override protected def nullSafeEval(input1: Any, input2: Any): Any = input1 match { + case l: jl.Long => shiftLong(l, input2.asInstanceOf[Int]) + case i: jl.Integer => shiftInt(i, input2.asInstanceOf[Int]) + } + + override def toString: String = { + getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse(symbol) match { + case alias if alias == symbol => s"($left $symbol $right)" + case _ => super.toString + } + } + + override def sql: String = { + getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse(symbol) match { + case alias if alias == symbol => s"(${left.sql} $symbol ${right.sql})" + case _ => super.sql + } + } +} /** * Bitwise left shift. @@ -1269,38 +1304,28 @@ case class Pow(left: Expression, right: Expression) * @param right number of bits to left shift. */ @ExpressionDescription( - usage = "_FUNC_(base, expr) - Bitwise left shift.", + usage = "base << exp - Bitwise left shift.", examples = """ Examples: - > SELECT _FUNC_(2, 1); + > SELECT shiftleft(2, 1); + 4 + > SELECT 2 << 1; 4 """, + note = """ + `<<` operator is added in Spark 4.0.0 as an alias for `shiftleft`. + """, since = "1.5.0", group = "bitwise_funcs") -case class ShiftLeft(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { - - override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(IntegerType, LongType), IntegerType) - - override def dataType: DataType = left.dataType - - protected override def nullSafeEval(input1: Any, input2: Any): Any = { - input1 match { - case l: jl.Long => l << input2.asInstanceOf[jl.Integer] - case i: jl.Integer => i << input2.asInstanceOf[jl.Integer] - } - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (left, right) => s"$left << $right") - } - +case class ShiftLeft(left: Expression, right: Expression) extends BitShiftOperation { + override def symbol: String = "<<" + override def shiftInt: (Int, Int) => Int = (x: Int, y: Int) => x << y + override def shiftLong: (Long, Int) => Long = (x: Long, y: Int) => x << y + val shift: (Number, Int) => Any = (x: Number, y: Int) => x.longValue() << y override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): ShiftLeft = copy(left = newLeft, right = newRight) } - /** * Bitwise (signed) right shift. * @@ -1308,38 +1333,27 @@ case class ShiftLeft(left: Expression, right: Expression) * @param right number of bits to right shift. 
*/ @ExpressionDescription( - usage = "_FUNC_(base, expr) - Bitwise (signed) right shift.", + usage = "base >> expr - Bitwise (signed) right shift.", examples = """ Examples: - > SELECT _FUNC_(4, 1); + > SELECT shiftright(4, 1); + 2 + > SELECT 4 >> 1; 2 """, + note = """ + `>>` operator is added in Spark 4.0.0 as an alias for `shiftright`. + """, since = "1.5.0", group = "bitwise_funcs") -case class ShiftRight(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { - - override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(IntegerType, LongType), IntegerType) - - override def dataType: DataType = left.dataType - - protected override def nullSafeEval(input1: Any, input2: Any): Any = { - input1 match { - case l: jl.Long => l >> input2.asInstanceOf[jl.Integer] - case i: jl.Integer => i >> input2.asInstanceOf[jl.Integer] - } - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (left, right) => s"$left >> $right") - } - +case class ShiftRight(left: Expression, right: Expression) extends BitShiftOperation { + override def symbol: String = ">>" + override def shiftInt: (Int, Int) => Int = (x: Int, y: Int) => x >> y + override def shiftLong: (Long, Int) => Long = (x: Long, y: Int) => x >> y override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): ShiftRight = copy(left = newLeft, right = newRight) } - /** * Bitwise unsigned right shift, for integer and long data type. * @@ -1347,33 +1361,23 @@ case class ShiftRight(left: Expression, right: Expression) * @param right the number of bits to right shift. */ @ExpressionDescription( - usage = "_FUNC_(base, expr) - Bitwise unsigned right shift.", + usage = "base >>> expr - Bitwise unsigned right shift.", examples = """ Examples: - > SELECT _FUNC_(4, 1); + > SELECT shiftrightunsigned(4, 1); 2 + > SELECT 4 >>> 1; + 2 + """, + note = """ + `>>>` operator is added in Spark 4.0.0 as an alias for `shiftrightunsigned`. 
""", since = "1.5.0", group = "bitwise_funcs") -case class ShiftRightUnsigned(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { - - override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(IntegerType, LongType), IntegerType) - - override def dataType: DataType = left.dataType - - protected override def nullSafeEval(input1: Any, input2: Any): Any = { - input1 match { - case l: jl.Long => l >>> input2.asInstanceOf[jl.Integer] - case i: jl.Integer => i >>> input2.asInstanceOf[jl.Integer] - } - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (left, right) => s"$left >>> $right") - } - +case class ShiftRightUnsigned(left: Expression, right: Expression) extends BitShiftOperation { + override def symbol: String = ">>>" + override def shiftInt: (Int, Int) => Int = (x: Int, y: Int) => x >>> y + override def shiftLong: (Long, Int) => Long = (x: Long, y: Int) => x >>> y override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): ShiftRightUnsigned = copy(left = newLeft, right = newRight) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index b6816f5bb2925..52c32530f2e92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2196,6 +2196,17 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } } + override def visitShiftExpression(ctx: ShiftExpressionContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + val operator = ctx.shiftOperator().getChild(0).asInstanceOf[TerminalNode] + operator.getSymbol.getType match { + case SqlBaseParser.SHIFT_LEFT => ShiftLeft(left, right) + case SqlBaseParser.SHIFT_RIGHT => ShiftRight(left, right) + case SqlBaseParser.SHIFT_RIGHT_UNSIGNED => ShiftRightUnsigned(left, right) + } + } + /** * Create a unary arithmetic expression. The following arithmetic operators are supported: * - Plus: '+' diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala index 8806431ab4395..9977dcd83d6af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala @@ -98,7 +98,7 @@ trait SQLKeywordUtils extends SparkFunSuite with SQLHelper { } (symbol, literals) :: Nil } else { - val literal = literalDef.replaceAll("'", "").trim + val literal = literalDef.split("\\{")(0).replaceAll("'", "").trim // The case where a symbol string and its literal string are different, // e.g., `SETMINUS: 'MINUS';`. 
if (symbol != literal) { diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index ca864dddf19b1..bf46fe91eb903 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -289,9 +289,12 @@ | org.apache.spark.sql.catalyst.expressions.Sha1 | sha | SELECT sha('Spark') | struct | | org.apache.spark.sql.catalyst.expressions.Sha1 | sha1 | SELECT sha1('Spark') | struct | | org.apache.spark.sql.catalyst.expressions.Sha2 | sha2 | SELECT sha2('Spark', 256) | struct | -| org.apache.spark.sql.catalyst.expressions.ShiftLeft | shiftleft | SELECT shiftleft(2, 1) | struct | -| org.apache.spark.sql.catalyst.expressions.ShiftRight | shiftright | SELECT shiftright(4, 1) | struct | -| org.apache.spark.sql.catalyst.expressions.ShiftRightUnsigned | shiftrightunsigned | SELECT shiftrightunsigned(4, 1) | struct | +| org.apache.spark.sql.catalyst.expressions.ShiftLeft | << | SELECT shiftleft(2, 1) | struct<(2 << 1):int> | +| org.apache.spark.sql.catalyst.expressions.ShiftLeft | shiftleft | SELECT shiftleft(2, 1) | struct<(2 << 1):int> | +| org.apache.spark.sql.catalyst.expressions.ShiftRight | >> | SELECT shiftright(4, 1) | struct<(4 >> 1):int> | +| org.apache.spark.sql.catalyst.expressions.ShiftRight | shiftright | SELECT shiftright(4, 1) | struct<(4 >> 1):int> | +| org.apache.spark.sql.catalyst.expressions.ShiftRightUnsigned | >>> | SELECT shiftrightunsigned(4, 1) | struct<(4 >>> 1):int> | +| org.apache.spark.sql.catalyst.expressions.ShiftRightUnsigned | shiftrightunsigned | SELECT shiftrightunsigned(4, 1) | struct<(4 >>> 1):int> | | org.apache.spark.sql.catalyst.expressions.Shuffle | shuffle | SELECT shuffle(array(1, 20, 3, 5)) | struct> | | org.apache.spark.sql.catalyst.expressions.Signum | sign | SELECT sign(40) | struct | | org.apache.spark.sql.catalyst.expressions.Signum | signum | SELECT signum(40) | struct | diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/bitwise.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/bitwise.sql.out index b8958f4a331a6..fee226c0c3411 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/bitwise.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/bitwise.sql.out @@ -306,3 +306,115 @@ select getbit(11L, 64) -- !query analysis Project [getbit(11, 64) AS getbit(11, 64)#x] +- OneRowRelation + + +-- !query +SELECT 20181117 >> 2 +-- !query analysis +Project [(20181117 >> 2) AS (20181117 >> 2)#x] ++- OneRowRelation + + +-- !query +SELECT 20181117 << 2 +-- !query analysis +Project [(20181117 << 2) AS (20181117 << 2)#x] ++- OneRowRelation + + +-- !query +SELECT 20181117 >>> 2 +-- !query analysis +Project [(20181117 >>> 2) AS (20181117 >>> 2)#x] ++- OneRowRelation + + +-- !query +SELECT 20181117 > > 2 +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'>'", + "hint" : "" + } +} + + +-- !query +SELECT 20181117 < < 2 +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'<'", + "hint" : "" + } +} + + +-- !query +SELECT 20181117 > >> 2 +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + 
"messageParameters" : { + "error" : "'>>'", + "hint" : "" + } +} + + +-- !query +SELECT 20181117 <<< 2 +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'<'", + "hint" : "" + } +} + + +-- !query +SELECT 20181117 >>>> 2 +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'>'", + "hint" : "" + } +} + + +-- !query +select cast(null as array>), 20181117 >> 2 +-- !query analysis +Project [cast(null as array>) AS NULL#x, (20181117 >> 2) AS (20181117 >> 2)#x] ++- OneRowRelation + + +-- !query +select cast(null as array>), 20181117 >>> 2 +-- !query analysis +Project [cast(null as array>) AS NULL#x, (20181117 >>> 2) AS (20181117 >>> 2)#x] ++- OneRowRelation + + +-- !query +select cast(null as map>), 20181117 >> 2 +-- !query analysis +Project [cast(null as map>) AS NULL#x, (20181117 >> 2) AS (20181117 >> 2)#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/group-analytics.sql.out index cdb6372bec099..f8967d7df0b0c 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/group-analytics.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/group-analytics.sql.out @@ -316,7 +316,7 @@ Sort [course#x ASC NULLS FIRST, sum#xL ASC NULLS FIRST], true SELECT course, year, GROUPING(course), GROUPING(year), GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(course, year) -- !query analysis -Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(course)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(year)#x, spark_grouping_id#xL AS grouping_id(course, year)#xL] +Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(course)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(year)#x, spark_grouping_id#xL AS grouping_id(course, year)#xL] +- Expand [[course#x, year#x, earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL] +- Project [course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x] +- SubqueryAlias coursesales @@ -382,7 +382,7 @@ HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 ORDER BY course, yea -- !query analysis Sort [course#x ASC NULLS FIRST, year#x ASC NULLS FIRST], true +- Project [course#x, year#x] - +- Filter ((cast(cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) as int) = 1) AND (spark_grouping_id#xL > cast(0 as bigint))) + +- Filter ((cast(cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) as int) = 1) AND (spark_grouping_id#xL > cast(0 as bigint))) +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, spark_grouping_id#xL, spark_grouping_id#xL] +- Expand [[course#x, year#x, earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL] +- Project 
[course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x] @@ -435,8 +435,8 @@ SELECT course, year, GROUPING(course), GROUPING(year) FROM courseSales GROUP BY ORDER BY GROUPING(course), GROUPING(year), course, year -- !query analysis Project [course#x, year#x, grouping(course)#x, grouping(year)#x] -+- Sort [cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) ASC NULLS FIRST, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) ASC NULLS FIRST, course#x ASC NULLS FIRST, year#x ASC NULLS FIRST], true - +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(course)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(year)#x, spark_grouping_id#xL, spark_grouping_id#xL] ++- Sort [cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) ASC NULLS FIRST, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) ASC NULLS FIRST, course#x ASC NULLS FIRST, year#x ASC NULLS FIRST], true + +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(course)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(year)#x, spark_grouping_id#xL, spark_grouping_id#xL] +- Expand [[course#x, year#x, earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL] +- Project [course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x] +- SubqueryAlias coursesales @@ -452,7 +452,7 @@ SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(co ORDER BY GROUPING(course), GROUPING(year), course, year -- !query analysis Project [course#x, year#x, grouping_id(course, year)#xL] -+- Sort [cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) ASC NULLS FIRST, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) ASC NULLS FIRST, course#x ASC NULLS FIRST, year#x ASC NULLS FIRST], true ++- Sort [cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) ASC NULLS FIRST, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) ASC NULLS FIRST, course#x ASC NULLS FIRST, year#x ASC NULLS FIRST], true +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, spark_grouping_id#xL AS grouping_id(course, year)#xL, spark_grouping_id#xL, spark_grouping_id#xL] +- Expand [[course#x, year#x, earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL] +- Project [course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out index b73ee16c8bdef..cbbcb77325348 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out @@ -72,7 +72,7 @@ Aggregate [c1#x, spark_grouping_id#xL], [c1#x, sum(c2#x) AS sum(c2)#xL] -- !query SELECT c1, sum(c2), grouping(c1) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1) -- !query analysis -Aggregate [c1#x, spark_grouping_id#xL], [c1#x, sum(c2#x) AS sum(c2)#xL, 
cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(c1)#x] +Aggregate [c1#x, spark_grouping_id#xL], [c1#x, sum(c2#x) AS sum(c2)#xL, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(c1)#x] +- Expand [[c1#x, c2#x, c3#x, c1#x, 0]], [c1#x, c2#x, c3#x, c1#x, spark_grouping_id#xL] +- Project [c1#x, c2#x, c3#x, c1#x AS c1#x] +- SubqueryAlias t @@ -98,7 +98,7 @@ Filter (grouping__id#xL > cast(1 as bigint)) -- !query SELECT grouping(c1) FROM (VALUES ('x', 'a', 10), ('y', 'b', 20)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1,c2) -- !query analysis -Aggregate [c1#x, c2#x, spark_grouping_id#xL], [cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(c1)#x] +Aggregate [c1#x, c2#x, spark_grouping_id#xL], [cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(c1)#x] +- Expand [[c1#x, c2#x, c3#x, c1#x, null, 1], [c1#x, c2#x, c3#x, null, c2#x, 2]], [c1#x, c2#x, c3#x, c1#x, c2#x, spark_grouping_id#xL] +- Project [c1#x, c2#x, c3#x, c1#x AS c1#x, c2#x AS c2#x] +- SubqueryAlias t @@ -223,7 +223,7 @@ Aggregate [k1#x, k2#x, spark_grouping_id#xL, _gen_grouping_pos#x], [spark_groupi -- !query SELECT grouping(k1), k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY GROUPING SETS ((k1),(k1,k2),(k2,k1)) -- !query analysis -Aggregate [k1#x, k2#x, spark_grouping_id#xL, _gen_grouping_pos#x], [cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(k1)#x, k1#x, k2#x, avg(v#x) AS avg(v)#x] +Aggregate [k1#x, k2#x, spark_grouping_id#xL, _gen_grouping_pos#x], [cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(k1)#x, k1#x, k2#x, avg(v#x) AS avg(v)#x] +- Expand [[k1#x, k2#x, v#x, k1#x, null, 1, 0], [k1#x, k2#x, v#x, k1#x, k2#x, 0, 1], [k1#x, k2#x, v#x, k1#x, k2#x, 0, 2]], [k1#x, k2#x, v#x, k1#x, k2#x, spark_grouping_id#xL, _gen_grouping_pos#x] +- Project [k1#x, k2#x, v#x, k1#x AS k1#x, k2#x AS k2#x] +- SubqueryAlias t diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/groupingsets.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/groupingsets.sql.out index 27e9707425833..d2a25fabe2059 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/groupingsets.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/groupingsets.sql.out @@ -82,7 +82,7 @@ CreateDataSourceTableCommand `spark_catalog`.`default`.`gstest_empty`, false select a, b, grouping(a), grouping(b), sum(v), count(*), max(v) from gstest1 group by rollup (a,b) -- !query analysis -Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] +Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] +- Expand [[a#x, b#x, v#x, a#x, b#x, 0], [a#x, b#x, v#x, a#x, null, 1], [a#x, b#x, v#x, null, null, 3]], [a#x, b#x, v#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x AS a#x, b#x AS b#x] +- SubqueryAlias gstest1 @@ -96,7 +96,7 @@ select a, b, grouping(a), grouping(b), sum(v), count(*), max(v) from gstest1 group by rollup (a,b) order by a,b -- !query analysis Sort [a#x ASC NULLS FIRST, b#x ASC NULLS FIRST], true -+- Aggregate [a#x, b#x, 
spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] ++- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] +- Expand [[a#x, b#x, v#x, a#x, b#x, 0], [a#x, b#x, v#x, a#x, null, 1], [a#x, b#x, v#x, null, null, 3]], [a#x, b#x, v#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x AS a#x, b#x AS b#x] +- SubqueryAlias gstest1 @@ -110,7 +110,7 @@ select a, b, grouping(a), grouping(b), sum(v), count(*), max(v) from gstest1 group by rollup (a,b) order by b desc, a -- !query analysis Sort [b#x DESC NULLS LAST, a#x ASC NULLS FIRST], true -+- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] ++- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] +- Expand [[a#x, b#x, v#x, a#x, b#x, 0], [a#x, b#x, v#x, a#x, null, 1], [a#x, b#x, v#x, null, null, 3]], [a#x, b#x, v#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x AS a#x, b#x AS b#x] +- SubqueryAlias gstest1 @@ -124,7 +124,7 @@ select a, b, grouping(a), grouping(b), sum(v), count(*), max(v) from gstest1 group by rollup (a,b) order by coalesce(a,0)+coalesce(b,0), a -- !query analysis Sort [(coalesce(a#x, 0) + coalesce(b#x, 0)) ASC NULLS FIRST, a#x ASC NULLS FIRST], true -+- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] ++- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] +- Expand [[a#x, b#x, v#x, a#x, b#x, 0], [a#x, b#x, v#x, a#x, null, 1], [a#x, b#x, v#x, null, null, 3]], [a#x, b#x, v#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x AS a#x, b#x AS b#x] +- SubqueryAlias gstest1 @@ -209,7 +209,7 @@ select t1.a, t2.b, grouping(t1.a), grouping(t2.b), sum(t1.v), max(t2.a) from gstest1 t1, gstest2 t2 group by grouping sets ((t1.a, t2.b), ()) -- !query analysis -Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, max(a#x) AS max(a)#x] +Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, max(a#x) AS max(a)#x] +- Expand [[a#x, b#x, v#x, a#x, b#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x, b#x, 0], [a#x, 
b#x, v#x, a#x, b#x, c#x, d#x, e#x, f#x, g#x, h#x, null, null, 3]], [a#x, b#x, v#x, a#x, b#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x, b#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x AS a#x, b#x AS b#x] +- Join Inner @@ -228,7 +228,7 @@ select t1.a, t2.b, grouping(t1.a), grouping(t2.b), sum(t1.v), max(t2.a) from gstest1 t1 join gstest2 t2 on (t1.a=t2.a) group by grouping sets ((t1.a, t2.b), ()) -- !query analysis -Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, max(a#x) AS max(a)#x] +Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, max(a#x) AS max(a)#x] +- Expand [[a#x, b#x, v#x, a#x, b#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x, b#x, 0], [a#x, b#x, v#x, a#x, b#x, c#x, d#x, e#x, f#x, g#x, h#x, null, null, 3]], [a#x, b#x, v#x, a#x, b#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x, b#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x AS a#x, b#x AS b#x] +- Join Inner, (a#x = a#x) @@ -247,7 +247,7 @@ select a, b, grouping(a), grouping(b), sum(t1.v), max(t2.c) from gstest1 t1 join gstest2 t2 using (a,b) group by grouping sets ((a, b), ()) -- !query analysis -Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, max(c#x) AS max(c)#x] +Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, max(c#x) AS max(c)#x] +- Expand [[a#x, b#x, v#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x, b#x, 0], [a#x, b#x, v#x, c#x, d#x, e#x, f#x, g#x, h#x, null, null, 3]], [a#x, b#x, v#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x AS a#x, b#x AS b#x] +- Project [a#x, b#x, v#x, c#x, d#x, e#x, f#x, g#x, h#x] @@ -402,8 +402,8 @@ order by 2,1 -- !query analysis Sort [grouping(ten)#x ASC NULLS FIRST, ten#x ASC NULLS FIRST], true +- Project [ten#x, grouping(ten)#x] - +- Filter (cast(cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) as int) >= 0) - +- Aggregate [ten#x, spark_grouping_id#xL], [ten#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(ten)#x, spark_grouping_id#xL] + +- Filter (cast(cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) as int) >= 0) + +- Aggregate [ten#x, spark_grouping_id#xL], [ten#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(ten)#x, spark_grouping_id#xL] +- Expand [[unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x, 0]], [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x, spark_grouping_id#xL] +- Project [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x AS ten#x] +- 
SubqueryAlias spark_catalog.default.onek @@ -417,8 +417,8 @@ order by 2,1 -- !query analysis Sort [grouping(ten)#x ASC NULLS FIRST, ten#x ASC NULLS FIRST], true +- Project [ten#x, grouping(ten)#x] - +- Filter (cast(cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) as int) > 0) - +- Aggregate [ten#x, four#x, spark_grouping_id#xL], [ten#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(ten)#x, spark_grouping_id#xL] + +- Filter (cast(cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) as int) > 0) + +- Aggregate [ten#x, four#x, spark_grouping_id#xL], [ten#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(ten)#x, spark_grouping_id#xL] +- Expand [[unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x, null, 1], [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, null, four#x, 2]], [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x, four#x, spark_grouping_id#xL] +- Project [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x AS ten#x, four#x AS four#x] +- SubqueryAlias spark_catalog.default.onek @@ -432,8 +432,8 @@ order by 2,1 -- !query analysis Sort [grouping(ten)#x ASC NULLS FIRST, ten#x ASC NULLS FIRST], true +- Project [ten#x, grouping(ten)#x] - +- Filter (cast(cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) as int) > 0) - +- Aggregate [ten#x, spark_grouping_id#xL], [ten#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(ten)#x, spark_grouping_id#xL] + +- Filter (cast(cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) as int) > 0) + +- Aggregate [ten#x, spark_grouping_id#xL], [ten#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(ten)#x, spark_grouping_id#xL] +- Expand [[unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x, 0], [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, null, 1]], [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x, spark_grouping_id#xL] +- Project [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x AS ten#x] +- SubqueryAlias spark_catalog.default.onek @@ -447,8 +447,8 @@ order by 2,1 -- !query analysis Sort [grouping(ten)#x ASC NULLS FIRST, ten#x ASC NULLS FIRST], true +- Project [ten#x, grouping(ten)#x] - +- Filter (cast(cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) as int) > 0) - +- Aggregate [ten#x, spark_grouping_id#xL], [ten#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(ten)#x, spark_grouping_id#xL] + +- Filter (cast(cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) as int) > 0) + +- Aggregate [ten#x, spark_grouping_id#xL], [ten#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) 
AS grouping(ten)#x, spark_grouping_id#xL] +- Expand [[unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x, 0], [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, null, 1]], [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x, spark_grouping_id#xL] +- Project [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x AS ten#x] +- SubqueryAlias spark_catalog.default.onek @@ -482,7 +482,7 @@ select a, b, grouping(a), grouping(b), sum(v), count(*), max(v) from gstest1 group by grouping sets ((a),(b)) order by 3,4,1,2 /* 3,1,2 */ -- !query analysis Sort [grouping(a)#x ASC NULLS FIRST, grouping(b)#x ASC NULLS FIRST, a#x ASC NULLS FIRST, b#x ASC NULLS FIRST], true -+- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] ++- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] +- Expand [[a#x, b#x, v#x, a#x, null, 1], [a#x, b#x, v#x, null, b#x, 2]], [a#x, b#x, v#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x AS a#x, b#x AS b#x] +- SubqueryAlias gstest1 @@ -496,7 +496,7 @@ select a, b, grouping(a), grouping(b), sum(v), count(*), max(v) from gstest1 group by cube(a,b) order by 3,4,1,2 /* 3,1,2 */ -- !query analysis Sort [grouping(a)#x ASC NULLS FIRST, grouping(b)#x ASC NULLS FIRST, a#x ASC NULLS FIRST, b#x ASC NULLS FIRST], true -+- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] ++- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] +- Expand [[a#x, b#x, v#x, a#x, b#x, 0], [a#x, b#x, v#x, a#x, null, 1], [a#x, b#x, v#x, null, b#x, 2], [a#x, b#x, v#x, null, null, 3]], [a#x, b#x, v#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x AS a#x, b#x AS b#x] +- SubqueryAlias gstest1 @@ -526,7 +526,7 @@ select unhashable_col, unsortable_col, order by 3, 4, 6 /* 3, 5 */ -- !query analysis Sort [grouping(unhashable_col)#x ASC NULLS FIRST, grouping(unsortable_col)#x ASC NULLS FIRST, sum(v)#xL ASC NULLS FIRST], true -+- Aggregate [unhashable_col#x, unsortable_col#x, spark_grouping_id#xL], [unhashable_col#x, unsortable_col#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(unhashable_col)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(unsortable_col)#x, count(1) AS count(1)#xL, sum(v#x) AS sum(v)#xL] ++- 
Aggregate [unhashable_col#x, unsortable_col#x, spark_grouping_id#xL], [unhashable_col#x, unsortable_col#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(unhashable_col)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(unsortable_col)#x, count(1) AS count(1)#xL, sum(v#x) AS sum(v)#xL] +- Expand [[id#x, v#x, unhashable_col#x, unsortable_col#x, unhashable_col#x, null, 1], [id#x, v#x, unhashable_col#x, unsortable_col#x, null, unsortable_col#x, 2]], [id#x, v#x, unhashable_col#x, unsortable_col#x, unhashable_col#x, unsortable_col#x, spark_grouping_id#xL] +- Project [id#x, v#x, unhashable_col#x, unsortable_col#x, unhashable_col#x AS unhashable_col#x, unsortable_col#x AS unsortable_col#x] +- SubqueryAlias spark_catalog.default.gstest4 @@ -541,7 +541,7 @@ select unhashable_col, unsortable_col, order by 3, 4, 6 /* 3,5 */ -- !query analysis Sort [grouping(unhashable_col)#x ASC NULLS FIRST, grouping(unsortable_col)#x ASC NULLS FIRST, sum(v)#xL ASC NULLS FIRST], true -+- Aggregate [v#x, unhashable_col#x, unsortable_col#x, spark_grouping_id#xL], [unhashable_col#x, unsortable_col#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(unhashable_col)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(unsortable_col)#x, count(1) AS count(1)#xL, sum(v#x) AS sum(v)#xL] ++- Aggregate [v#x, unhashable_col#x, unsortable_col#x, spark_grouping_id#xL], [unhashable_col#x, unsortable_col#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(unhashable_col)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(unsortable_col)#x, count(1) AS count(1)#xL, sum(v#x) AS sum(v)#xL] +- Expand [[id#x, v#x, unhashable_col#x, unsortable_col#x, v#x, unhashable_col#x, null, 1], [id#x, v#x, unhashable_col#x, unsortable_col#x, v#x, null, unsortable_col#x, 2]], [id#x, v#x, unhashable_col#x, unsortable_col#x, v#x, unhashable_col#x, unsortable_col#x, spark_grouping_id#xL] +- Project [id#x, v#x, unhashable_col#x, unsortable_col#x, v#x AS v#x, unhashable_col#x AS unhashable_col#x, unsortable_col#x AS unsortable_col#x] +- SubqueryAlias spark_catalog.default.gstest4 @@ -593,7 +593,7 @@ select a, b, grouping(a), grouping(b), sum(v), count(*), max(v) from gstest1 group by grouping sets ((a,b),(a+1,b+1),(a+2,b+2)) order by 3,4,7 /* 3,6 */ -- !query analysis Sort [grouping(a)#x ASC NULLS FIRST, grouping(b)#x ASC NULLS FIRST, max(v)#x ASC NULLS FIRST], true -+- Aggregate [a#x, b#x, (a#x + 1)#x, (b#x + 1)#x, (a#x + 2)#x, (b#x + 2)#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 5) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 4) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] ++- Aggregate [a#x, b#x, (a#x + 1)#x, (b#x + 1)#x, (a#x + 2)#x, (b#x + 2)#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 5) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 4) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] +- Expand [[a#x, b#x, v#x, a#x, b#x, null, null, null, null, 15], [a#x, b#x, v#x, null, null, (a#x + 1)#x, (b#x + 1)#x, null, null, 51], [a#x, b#x, v#x, null, null, null, null, (a#x + 2)#x, (b#x + 2)#x, 60]], [a#x, b#x, v#x, a#x, b#x, (a#x + 1)#x, (b#x + 1)#x, (a#x + 2)#x, (b#x + 2)#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x AS a#x, b#x AS b#x, (a#x + 1) AS (a#x + 1)#x, (b#x + 1) AS (b#x + 1)#x, (a#x + 2) AS (a#x + 2)#x, (b#x + 2) AS 
(b#x + 2)#x] +- SubqueryAlias gstest1 @@ -634,7 +634,7 @@ select v||'a', case grouping(v||'a') when 1 then 1 else 0 end, count(*) group by rollup(i, v||'a') order by 1,3 -- !query analysis Sort [concat(v, a)#x ASC NULLS FIRST, count(1)#xL ASC NULLS FIRST], true -+- Aggregate [i#x, concat(v#x, a)#x, spark_grouping_id#xL], [concat(v#x, a)#x AS concat(v, a)#x, CASE WHEN (cast(cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) as int) = 1) THEN 1 ELSE 0 END AS CASE WHEN (grouping(concat(v, a)) = 1) THEN 1 ELSE 0 END#x, count(1) AS count(1)#xL] ++- Aggregate [i#x, concat(v#x, a)#x, spark_grouping_id#xL], [concat(v#x, a)#x AS concat(v, a)#x, CASE WHEN (cast(cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) as int) = 1) THEN 1 ELSE 0 END AS CASE WHEN (grouping(concat(v, a)) = 1) THEN 1 ELSE 0 END#x, count(1) AS count(1)#xL] +- Expand [[i#x, v#x, i#x, concat(v#x, a)#x, 0], [i#x, v#x, i#x, null, 1], [i#x, v#x, null, null, 3]], [i#x, v#x, i#x, concat(v#x, a)#x, spark_grouping_id#xL] +- Project [i#x, v#x, i#x AS i#x, concat(v#x, a) AS concat(v#x, a)#x] +- SubqueryAlias u @@ -647,7 +647,7 @@ select v||'a', case when grouping(v||'a') = 1 then 1 else 0 end, count(*) group by rollup(i, v||'a') order by 1,3 -- !query analysis Sort [concat(v, a)#x ASC NULLS FIRST, count(1)#xL ASC NULLS FIRST], true -+- Aggregate [i#x, concat(v#x, a)#x, spark_grouping_id#xL], [concat(v#x, a)#x AS concat(v, a)#x, CASE WHEN (cast(cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) as int) = 1) THEN 1 ELSE 0 END AS CASE WHEN (grouping(concat(v, a)) = 1) THEN 1 ELSE 0 END#x, count(1) AS count(1)#xL] ++- Aggregate [i#x, concat(v#x, a)#x, spark_grouping_id#xL], [concat(v#x, a)#x AS concat(v, a)#x, CASE WHEN (cast(cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) as int) = 1) THEN 1 ELSE 0 END AS CASE WHEN (grouping(concat(v, a)) = 1) THEN 1 ELSE 0 END#x, count(1) AS count(1)#xL] +- Expand [[i#x, v#x, i#x, concat(v#x, a)#x, 0], [i#x, v#x, i#x, null, 1], [i#x, v#x, null, null, 3]], [i#x, v#x, i#x, concat(v#x, a)#x, spark_grouping_id#xL] +- Project [i#x, v#x, i#x AS i#x, concat(v#x, a) AS concat(v#x, a)#x] +- SubqueryAlias u diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int2.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int2.sql.out index 9dda3c0dc42d4..3fa919434da79 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int2.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int2.sql.out @@ -274,14 +274,14 @@ Project [ AS five#x, f1#x, (cast(f1#x as double) / cast(cast(2 as int) as double -- !query SELECT string(shiftleft(smallint(-1), 15)) -- !query analysis -Project [cast(shiftleft(cast(cast(-1 as smallint) as int), 15) as string) AS shiftleft(-1, 15)#x] +Project [cast((cast(cast(-1 as smallint) as int) << 15) as string) AS (-1 << 15)#x] +- OneRowRelation -- !query SELECT string(smallint(shiftleft(smallint(-1), 15))+1) -- !query analysis -Project [cast((cast(cast(shiftleft(cast(cast(-1 as smallint) as int), 15) as smallint) as int) + 1) as string) AS (shiftleft(-1, 15) + 1)#x] +Project [cast((cast(cast((cast(cast(-1 as smallint) as int) << 15) as smallint) as int) + 1) as string) AS ((-1 << 15) + 1)#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int4.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int4.sql.out index d261b59a4c5e2..f6a8b24f917d2 100644 --- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int4.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int4.sql.out @@ -411,14 +411,14 @@ Project [(cast((2 + 2) as double) / cast(2 as double)) AS two#x] -- !query SELECT string(shiftleft(int(-1), 31)) -- !query analysis -Project [cast(shiftleft(cast(-1 as int), 31) as string) AS shiftleft(-1, 31)#x] +Project [cast((cast(-1 as int) << 31) as string) AS (-1 << 31)#x] +- OneRowRelation -- !query SELECT string(int(shiftleft(int(-1), 31))+1) -- !query analysis -Project [cast((cast(shiftleft(cast(-1 as int), 31) as int) + 1) as string) AS (shiftleft(-1, 31) + 1)#x] +Project [cast((cast((cast(-1 as int) << 31) as int) + 1) as string) AS ((-1 << 31) + 1)#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int8.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int8.sql.out index 72972469fa6ef..dfc96427b57ed 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int8.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int8.sql.out @@ -659,14 +659,14 @@ Project [id#xL] -- !query SELECT string(shiftleft(bigint(-1), 63)) -- !query analysis -Project [cast(shiftleft(cast(-1 as bigint), 63) as string) AS shiftleft(-1, 63)#x] +Project [cast((cast(-1 as bigint) << 63) as string) AS (-1 << 63)#x] +- OneRowRelation -- !query SELECT string(int(shiftleft(bigint(-1), 63))+1) -- !query analysis -Project [cast((cast(shiftleft(cast(-1 as bigint), 63) as int) + 1) as string) AS (shiftleft(-1, 63) + 1)#x] +Project [cast((cast((cast(-1 as bigint) << 63) as int) + 1) as string) AS ((-1 << 63) + 1)#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-group-analytics.sql.out index fbee3e2c8c89f..7d6eb0a83bf4e 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-group-analytics.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-group-analytics.sql.out @@ -189,7 +189,7 @@ Sort [cast(udf(cast(course#x as string)) as string) ASC NULLS FIRST, sum#xL ASC SELECT udf(course), udf(year), GROUPING(course), GROUPING(year), GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(course, year) -- !query analysis -Aggregate [course#x, year#x, spark_grouping_id#xL], [cast(udf(cast(course#x as string)) as string) AS udf(course)#x, cast(udf(cast(year#x as string)) as int) AS udf(year)#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(course)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(year)#x, spark_grouping_id#xL AS grouping_id(course, year)#xL] +Aggregate [course#x, year#x, spark_grouping_id#xL], [cast(udf(cast(course#x as string)) as string) AS udf(course)#x, cast(udf(cast(year#x as string)) as int) AS udf(year)#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(course)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(year)#x, spark_grouping_id#xL AS grouping_id(course, year)#xL] +- Expand [[course#x, year#x, earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL] +- Project [course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x] +- 
SubqueryAlias coursesales @@ -255,7 +255,7 @@ HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 ORDER BY course, udf -- !query analysis Sort [course#x ASC NULLS FIRST, cast(udf(cast(year#x as string)) as int) ASC NULLS FIRST], true +- Project [course#x, year#x] - +- Filter ((cast(cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) as int) = 1) AND (spark_grouping_id#xL > cast(0 as bigint))) + +- Filter ((cast(cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) as int) = 1) AND (spark_grouping_id#xL > cast(0 as bigint))) +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, spark_grouping_id#xL, spark_grouping_id#xL] +- Expand [[course#x, year#x, earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL] +- Project [course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x] @@ -308,8 +308,8 @@ SELECT course, year, GROUPING(course), GROUPING(year) FROM courseSales GROUP BY ORDER BY GROUPING(course), GROUPING(year), course, udf(year) -- !query analysis Project [course#x, year#x, grouping(course)#x, grouping(year)#x] -+- Sort [cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) ASC NULLS FIRST, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) ASC NULLS FIRST, course#x ASC NULLS FIRST, cast(udf(cast(year#x as string)) as int) ASC NULLS FIRST], true - +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(course)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(year)#x, spark_grouping_id#xL, spark_grouping_id#xL] ++- Sort [cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) ASC NULLS FIRST, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) ASC NULLS FIRST, course#x ASC NULLS FIRST, cast(udf(cast(year#x as string)) as int) ASC NULLS FIRST], true + +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(course)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(year)#x, spark_grouping_id#xL, spark_grouping_id#xL] +- Expand [[course#x, year#x, earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL] +- Project [course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x] +- SubqueryAlias coursesales @@ -325,7 +325,7 @@ SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(co ORDER BY GROUPING(course), GROUPING(year), course, udf(year) -- !query analysis Project [course#x, year#x, grouping_id(course, year)#xL] -+- Sort [cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) ASC NULLS FIRST, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) ASC NULLS FIRST, course#x ASC NULLS FIRST, cast(udf(cast(year#x as string)) as int) ASC NULLS FIRST], true ++- Sort [cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) ASC NULLS FIRST, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) ASC NULLS FIRST, course#x ASC NULLS FIRST, cast(udf(cast(year#x as string)) as int) ASC NULLS FIRST], true +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, spark_grouping_id#xL AS grouping_id(course, year)#xL, 
spark_grouping_id#xL, spark_grouping_id#xL] +- Expand [[course#x, year#x, earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL] +- Project [course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x] diff --git a/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql b/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql index f9dfd161d0c07..5823b22ef6453 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql @@ -75,3 +75,15 @@ select getbit(11L, 2 + 1), getbit(11L, 3 - 1), getbit(10L + 1, 1 * 1), getbit(ca select getbit(11L, 63); select getbit(11L, -1); select getbit(11L, 64); + +SELECT 20181117 >> 2; +SELECT 20181117 << 2; +SELECT 20181117 >>> 2; +SELECT 20181117 > > 2; +SELECT 20181117 < < 2; +SELECT 20181117 > >> 2; +SELECT 20181117 <<< 2; +SELECT 20181117 >>>> 2; +select cast(null as array>), 20181117 >> 2; +select cast(null as array>), 20181117 >>> 2; +select cast(null as map>), 20181117 >> 2; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out b/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out index 2c8b733004aac..a7ebaea293bf9 100644 --- a/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out @@ -322,3 +322,131 @@ org.apache.spark.SparkIllegalArgumentException "upper" : "64" } } + + +-- !query +SELECT 20181117 >> 2 +-- !query schema +struct<(20181117 >> 2):int> +-- !query output +5045279 + + +-- !query +SELECT 20181117 << 2 +-- !query schema +struct<(20181117 << 2):int> +-- !query output +80724468 + + +-- !query +SELECT 20181117 >>> 2 +-- !query schema +struct<(20181117 >>> 2):int> +-- !query output +5045279 + + +-- !query +SELECT 20181117 > > 2 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'>'", + "hint" : "" + } +} + + +-- !query +SELECT 20181117 < < 2 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'<'", + "hint" : "" + } +} + + +-- !query +SELECT 20181117 > >> 2 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'>>'", + "hint" : "" + } +} + + +-- !query +SELECT 20181117 <<< 2 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'<'", + "hint" : "" + } +} + + +-- !query +SELECT 20181117 >>>> 2 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'>'", + "hint" : "" + } +} + + +-- !query +select cast(null as array>), 20181117 >> 2 +-- !query schema +struct>,(20181117 >> 2):int> +-- !query output +NULL 5045279 + + +-- !query +select cast(null as array>), 20181117 >>> 2 +-- !query schema 
+struct>,(20181117 >>> 2):int> +-- !query output +NULL 5045279 + + +-- !query +select cast(null as map>), 20181117 >> 2 +-- !query schema +struct>,(20181117 >> 2):int> +-- !query output +NULL 5045279 diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int2.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int2.sql.out index ca55b6accc665..1c96f8dfa5e54 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int2.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int2.sql.out @@ -289,7 +289,7 @@ struct -- !query SELECT string(shiftleft(smallint(-1), 15)) -- !query schema -struct +struct<(-1 << 15):string> -- !query output -32768 @@ -297,7 +297,7 @@ struct -- !query SELECT string(smallint(shiftleft(smallint(-1), 15))+1) -- !query schema -struct<(shiftleft(-1, 15) + 1):string> +struct<((-1 << 15) + 1):string> -- !query output -32767 diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int4.sql.out index 16c18c86f2919..afe0211bd1947 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int4.sql.out @@ -520,7 +520,7 @@ struct -- !query SELECT string(shiftleft(int(-1), 31)) -- !query schema -struct +struct<(-1 << 31):string> -- !query output -2147483648 @@ -528,7 +528,7 @@ struct -- !query SELECT string(int(shiftleft(int(-1), 31))+1) -- !query schema -struct<(shiftleft(-1, 31) + 1):string> +struct<((-1 << 31) + 1):string> -- !query output -2147483647 diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out index f6e4bd8bd7e08..6e7ca4afab67d 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out @@ -883,7 +883,7 @@ struct -- !query SELECT string(shiftleft(bigint(-1), 63)) -- !query schema -struct +struct<(-1 << 63):string> -- !query output -9223372036854775808 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q17/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q17/explain.txt index 850b20431e487..6908b8137b0c4 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q17/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q17/explain.txt @@ -102,7 +102,7 @@ Condition : (isnotnull(cs_bill_customer_sk#14) AND isnotnull(cs_item_sk#15)) (13) BroadcastExchange Input [4]: [cs_bill_customer_sk#14, cs_item_sk#15, cs_quantity#16, cs_sold_date_sk#17] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[1, int, false] as bigint) & 4294967295))),false), [plan_id=2] +Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, false] as bigint) << 32) | (cast(input[1, int, false] as bigint) & 4294967295))),false), [plan_id=2] (14) BroadcastHashJoin [codegen id : 8] Left keys [2]: [sr_customer_sk#9, sr_item_sk#8] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q25/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q25/explain.txt index e2caa9f171b86..15b74bac0fbec 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q25/explain.txt +++ 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q25/explain.txt @@ -102,7 +102,7 @@ Condition : (isnotnull(cs_bill_customer_sk#14) AND isnotnull(cs_item_sk#15)) (13) BroadcastExchange Input [4]: [cs_bill_customer_sk#14, cs_item_sk#15, cs_net_profit#16, cs_sold_date_sk#17] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[1, int, false] as bigint) & 4294967295))),false), [plan_id=2] +Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, false] as bigint) << 32) | (cast(input[1, int, false] as bigint) & 4294967295))),false), [plan_id=2] (14) BroadcastHashJoin [codegen id : 8] Left keys [2]: [sr_customer_sk#9, sr_item_sk#8] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q27.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q27.sf100/explain.txt index 6cc9c3a4834ee..4b7d4f2f068d6 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q27.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q27.sf100/explain.txt @@ -167,7 +167,7 @@ Input [11]: [i_item_id#19, s_state#20, spark_grouping_id#21, sum#30, count#31, s Keys [3]: [i_item_id#19, s_state#20, spark_grouping_id#21] Functions [4]: [avg(ss_quantity#4), avg(UnscaledValue(ss_list_price#5)), avg(UnscaledValue(ss_coupon_amt#7)), avg(UnscaledValue(ss_sales_price#6))] Aggregate Attributes [4]: [avg(ss_quantity#4)#38, avg(UnscaledValue(ss_list_price#5))#39, avg(UnscaledValue(ss_coupon_amt#7))#40, avg(UnscaledValue(ss_sales_price#6))#41] -Results [7]: [i_item_id#19, s_state#20, cast((shiftright(spark_grouping_id#21, 0) & 1) as tinyint) AS g_state#42, avg(ss_quantity#4)#38 AS agg1#43, cast((avg(UnscaledValue(ss_list_price#5))#39 / 100.0) as decimal(11,6)) AS agg2#44, cast((avg(UnscaledValue(ss_coupon_amt#7))#40 / 100.0) as decimal(11,6)) AS agg3#45, cast((avg(UnscaledValue(ss_sales_price#6))#41 / 100.0) as decimal(11,6)) AS agg4#46] +Results [7]: [i_item_id#19, s_state#20, cast(((spark_grouping_id#21 >> 0) & 1) as tinyint) AS g_state#42, avg(ss_quantity#4)#38 AS agg1#43, cast((avg(UnscaledValue(ss_list_price#5))#39 / 100.0) as decimal(11,6)) AS agg2#44, cast((avg(UnscaledValue(ss_coupon_amt#7))#40 / 100.0) as decimal(11,6)) AS agg3#45, cast((avg(UnscaledValue(ss_sales_price#6))#41 / 100.0) as decimal(11,6)) AS agg4#46] (30) TakeOrderedAndProject Input [7]: [i_item_id#19, s_state#20, g_state#42, agg1#43, agg2#44, agg3#45, agg4#46] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q27/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q27/explain.txt index 6cc9c3a4834ee..4b7d4f2f068d6 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q27/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q27/explain.txt @@ -167,7 +167,7 @@ Input [11]: [i_item_id#19, s_state#20, spark_grouping_id#21, sum#30, count#31, s Keys [3]: [i_item_id#19, s_state#20, spark_grouping_id#21] Functions [4]: [avg(ss_quantity#4), avg(UnscaledValue(ss_list_price#5)), avg(UnscaledValue(ss_coupon_amt#7)), avg(UnscaledValue(ss_sales_price#6))] Aggregate Attributes [4]: [avg(ss_quantity#4)#38, avg(UnscaledValue(ss_list_price#5))#39, avg(UnscaledValue(ss_coupon_amt#7))#40, avg(UnscaledValue(ss_sales_price#6))#41] -Results [7]: [i_item_id#19, s_state#20, cast((shiftright(spark_grouping_id#21, 0) & 1) as tinyint) AS g_state#42, 
avg(ss_quantity#4)#38 AS agg1#43, cast((avg(UnscaledValue(ss_list_price#5))#39 / 100.0) as decimal(11,6)) AS agg2#44, cast((avg(UnscaledValue(ss_coupon_amt#7))#40 / 100.0) as decimal(11,6)) AS agg3#45, cast((avg(UnscaledValue(ss_sales_price#6))#41 / 100.0) as decimal(11,6)) AS agg4#46] +Results [7]: [i_item_id#19, s_state#20, cast(((spark_grouping_id#21 >> 0) & 1) as tinyint) AS g_state#42, avg(ss_quantity#4)#38 AS agg1#43, cast((avg(UnscaledValue(ss_list_price#5))#39 / 100.0) as decimal(11,6)) AS agg2#44, cast((avg(UnscaledValue(ss_coupon_amt#7))#40 / 100.0) as decimal(11,6)) AS agg3#45, cast((avg(UnscaledValue(ss_sales_price#6))#41 / 100.0) as decimal(11,6)) AS agg4#46] (30) TakeOrderedAndProject Input [7]: [i_item_id#19, s_state#20, g_state#42, agg1#43, agg2#44, agg3#45, agg4#46] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q29/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q29/explain.txt index 76a6ab9c7215b..27534390f0a24 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q29/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q29/explain.txt @@ -102,7 +102,7 @@ Condition : (isnotnull(cs_bill_customer_sk#14) AND isnotnull(cs_item_sk#15)) (13) BroadcastExchange Input [4]: [cs_bill_customer_sk#14, cs_item_sk#15, cs_quantity#16, cs_sold_date_sk#17] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[1, int, false] as bigint) & 4294967295))),false), [plan_id=2] +Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, false] as bigint) << 32) | (cast(input[1, int, false] as bigint) & 4294967295))),false), [plan_id=2] (14) BroadcastHashJoin [codegen id : 8] Left keys [2]: [sr_customer_sk#9, sr_item_sk#8] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36.sf100/explain.txt index ea59f2b926c9d..63cb718a827f3 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36.sf100/explain.txt @@ -137,7 +137,7 @@ Input [5]: [i_category#13, i_class#14, spark_grouping_id#15, sum#18, sum#19] Keys [3]: [i_category#13, i_class#14, spark_grouping_id#15] Functions [2]: [sum(UnscaledValue(ss_net_profit#4)), sum(UnscaledValue(ss_ext_sales_price#3))] Aggregate Attributes [2]: [sum(UnscaledValue(ss_net_profit#4))#20, sum(UnscaledValue(ss_ext_sales_price#3))#21] -Results [7]: [(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2) / MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2)) AS gross_margin#22, i_category#13, i_class#14, (cast((shiftright(spark_grouping_id#15, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#15, 0) & 1) as tinyint)) AS lochierarchy#23, (MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2) / MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2)) AS _w0#24, (cast((shiftright(spark_grouping_id#15, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#15, 0) & 1) as tinyint)) AS _w1#25, CASE WHEN (cast((shiftright(spark_grouping_id#15, 0) & 1) as tinyint) = 0) THEN i_category#13 END AS _w2#26] +Results [7]: [(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2) / MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2)) AS gross_margin#22, i_category#13, i_class#14, (cast(((spark_grouping_id#15 
>> 1) & 1) as tinyint) + cast(((spark_grouping_id#15 >> 0) & 1) as tinyint)) AS lochierarchy#23, (MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2) / MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2)) AS _w0#24, (cast(((spark_grouping_id#15 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#15 >> 0) & 1) as tinyint)) AS _w1#25, CASE WHEN (cast(((spark_grouping_id#15 >> 0) & 1) as tinyint) = 0) THEN i_category#13 END AS _w2#26] (24) Exchange Input [7]: [gross_margin#22, i_category#13, i_class#14, lochierarchy#23, _w0#24, _w1#25, _w2#26] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36/explain.txt index 6cc55ab063f68..eb59673575aba 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36/explain.txt @@ -137,7 +137,7 @@ Input [5]: [i_category#13, i_class#14, spark_grouping_id#15, sum#18, sum#19] Keys [3]: [i_category#13, i_class#14, spark_grouping_id#15] Functions [2]: [sum(UnscaledValue(ss_net_profit#4)), sum(UnscaledValue(ss_ext_sales_price#3))] Aggregate Attributes [2]: [sum(UnscaledValue(ss_net_profit#4))#20, sum(UnscaledValue(ss_ext_sales_price#3))#21] -Results [7]: [(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2) / MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2)) AS gross_margin#22, i_category#13, i_class#14, (cast((shiftright(spark_grouping_id#15, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#15, 0) & 1) as tinyint)) AS lochierarchy#23, (MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2) / MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2)) AS _w0#24, (cast((shiftright(spark_grouping_id#15, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#15, 0) & 1) as tinyint)) AS _w1#25, CASE WHEN (cast((shiftright(spark_grouping_id#15, 0) & 1) as tinyint) = 0) THEN i_category#13 END AS _w2#26] +Results [7]: [(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2) / MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2)) AS gross_margin#22, i_category#13, i_class#14, (cast(((spark_grouping_id#15 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#15 >> 0) & 1) as tinyint)) AS lochierarchy#23, (MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2) / MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2)) AS _w0#24, (cast(((spark_grouping_id#15 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#15 >> 0) & 1) as tinyint)) AS _w1#25, CASE WHEN (cast(((spark_grouping_id#15 >> 0) & 1) as tinyint) = 0) THEN i_category#13 END AS _w2#26] (24) Exchange Input [7]: [gross_margin#22, i_category#13, i_class#14, lochierarchy#23, _w0#24, _w1#25, _w2#26] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q39a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q39a/explain.txt index 220598440e092..995b723c6e287 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q39a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q39a/explain.txt @@ -237,7 +237,7 @@ Input [5]: [w_warehouse_sk#32, i_item_sk#31, d_moy#35, stdev#46, mean#47] (41) BroadcastExchange Input [5]: [w_warehouse_sk#32, i_item_sk#31, d_moy#35, mean#47, cov#48] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 
4294967295))),false), [plan_id=5] +Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, true] as bigint) << 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=5] (42) BroadcastHashJoin [codegen id : 10] Left keys [2]: [i_item_sk#6, w_warehouse_sk#7] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q39b/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q39b/explain.txt index 585e748860557..dba61c77b774e 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q39b/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q39b/explain.txt @@ -237,7 +237,7 @@ Input [5]: [w_warehouse_sk#32, i_item_sk#31, d_moy#35, stdev#46, mean#47] (41) BroadcastExchange Input [5]: [w_warehouse_sk#32, i_item_sk#31, d_moy#35, mean#47, cov#48] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=5] +Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, true] as bigint) << 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=5] (42) BroadcastHashJoin [codegen id : 10] Left keys [2]: [i_item_sk#6, w_warehouse_sk#7] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q49/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q49/explain.txt index 9eea658d789e4..93f79d66f0973 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q49/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q49/explain.txt @@ -99,7 +99,7 @@ Input [6]: [ws_item_sk#1, ws_order_number#2, ws_quantity#3, ws_net_paid#4, ws_ne (5) BroadcastExchange Input [5]: [ws_item_sk#1, ws_order_number#2, ws_quantity#3, ws_net_paid#4, ws_sold_date_sk#6] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=1] +Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, true] as bigint) << 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=1] (6) Scan parquet spark_catalog.default.web_returns Output [5]: [wr_item_sk#8, wr_order_number#9, wr_return_quantity#10, wr_return_amt#11, wr_returned_date_sk#12] @@ -209,7 +209,7 @@ Input [6]: [cs_item_sk#36, cs_order_number#37, cs_quantity#38, cs_net_paid#39, c (29) BroadcastExchange Input [5]: [cs_item_sk#36, cs_order_number#37, cs_quantity#38, cs_net_paid#39, cs_sold_date_sk#41] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=4] +Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, true] as bigint) << 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=4] (30) Scan parquet spark_catalog.default.catalog_returns Output [5]: [cr_item_sk#42, cr_order_number#43, cr_return_quantity#44, cr_return_amount#45, cr_returned_date_sk#46] @@ -319,7 +319,7 @@ Input [6]: [ss_item_sk#70, ss_ticket_number#71, ss_quantity#72, ss_net_paid#73, (53) BroadcastExchange Input [5]: [ss_item_sk#70, ss_ticket_number#71, ss_quantity#72, ss_net_paid#73, ss_sold_date_sk#75] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), 
[plan_id=7] +Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, true] as bigint) << 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=7] (54) Scan parquet spark_catalog.default.store_returns Output [5]: [sr_item_sk#76, sr_ticket_number#77, sr_return_quantity#78, sr_return_amt#79, sr_returned_date_sk#80] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt index 313959456c809..93103073d6f85 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt @@ -304,7 +304,7 @@ Input [5]: [wr_item_sk#92, wr_order_number#93, wr_return_amt#94, wr_net_loss#95, (49) BroadcastExchange Input [5]: [wr_item_sk#92, wr_order_number#93, wr_return_amt#94, wr_net_loss#95, wr_returned_date_sk#96] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, true] as bigint), 32) | (cast(input[1, int, true] as bigint) & 4294967295))),false), [plan_id=5] +Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, true] as bigint) << 32) | (cast(input[1, int, true] as bigint) & 4294967295))),false), [plan_id=5] (50) Scan parquet spark_catalog.default.web_sales Output [4]: [ws_item_sk#97, ws_web_site_sk#98, ws_order_number#99, ws_sold_date_sk#100] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q64/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q64/explain.txt index 69023c88202af..3a049ca71e742 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q64/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q64/explain.txt @@ -201,7 +201,7 @@ Condition : (((((((isnotnull(ss_item_sk#1) AND isnotnull(ss_ticket_number#8)) AN (4) BroadcastExchange Input [12]: [ss_item_sk#1, ss_customer_sk#2, ss_cdemo_sk#3, ss_hdemo_sk#4, ss_addr_sk#5, ss_store_sk#6, ss_promo_sk#7, ss_ticket_number#8, ss_wholesale_cost#9, ss_list_price#10, ss_coupon_amt#11, ss_sold_date_sk#12] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[7, int, false] as bigint) & 4294967295))),false), [plan_id=1] +Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, false] as bigint) << 32) | (cast(input[7, int, false] as bigint) & 4294967295))),false), [plan_id=1] (5) Scan parquet spark_catalog.default.store_returns Output [3]: [sr_item_sk#14, sr_ticket_number#15, sr_returned_date_sk#16] @@ -714,7 +714,7 @@ Condition : (((((((isnotnull(ss_item_sk#106) AND isnotnull(ss_ticket_number#113) (115) BroadcastExchange Input [12]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, ss_ticket_number#113, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, ss_sold_date_sk#117] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[7, int, false] as bigint) & 4294967295))),false), [plan_id=16] +Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, false] as bigint) << 32) | (cast(input[7, int, false] as bigint) & 4294967295))),false), [plan_id=16] (116) Scan parquet spark_catalog.default.store_returns Output [3]: [sr_item_sk#119, sr_ticket_number#120, sr_returned_date_sk#121] diff --git 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q70.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q70.sf100/explain.txt index d64f560f144e0..b6b480018aa46 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q70.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q70.sf100/explain.txt @@ -224,7 +224,7 @@ Input [4]: [s_state#20, s_county#21, spark_grouping_id#22, sum#24] Keys [3]: [s_state#20, s_county#21, spark_grouping_id#22] Functions [1]: [sum(UnscaledValue(ss_net_profit#2))] Aggregate Attributes [1]: [sum(UnscaledValue(ss_net_profit#2))#25] -Results [7]: [MakeDecimal(sum(UnscaledValue(ss_net_profit#2))#25,17,2) AS total_sum#26, s_state#20, s_county#21, (cast((shiftright(spark_grouping_id#22, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#22, 0) & 1) as tinyint)) AS lochierarchy#27, MakeDecimal(sum(UnscaledValue(ss_net_profit#2))#25,17,2) AS _w0#28, (cast((shiftright(spark_grouping_id#22, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#22, 0) & 1) as tinyint)) AS _w1#29, CASE WHEN (cast((shiftright(spark_grouping_id#22, 0) & 1) as tinyint) = 0) THEN s_state#20 END AS _w2#30] +Results [7]: [MakeDecimal(sum(UnscaledValue(ss_net_profit#2))#25,17,2) AS total_sum#26, s_state#20, s_county#21, (cast(((spark_grouping_id#22 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#22 >> 0) & 1) as tinyint)) AS lochierarchy#27, MakeDecimal(sum(UnscaledValue(ss_net_profit#2))#25,17,2) AS _w0#28, (cast(((spark_grouping_id#22 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#22 >> 0) & 1) as tinyint)) AS _w1#29, CASE WHEN (cast(((spark_grouping_id#22 >> 0) & 1) as tinyint) = 0) THEN s_state#20 END AS _w2#30] (39) Exchange Input [7]: [total_sum#26, s_state#20, s_county#21, lochierarchy#27, _w0#28, _w1#29, _w2#30] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q70/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q70/explain.txt index dade1b4f55c5f..9495128a50e13 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q70/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q70/explain.txt @@ -224,7 +224,7 @@ Input [4]: [s_state#20, s_county#21, spark_grouping_id#22, sum#24] Keys [3]: [s_state#20, s_county#21, spark_grouping_id#22] Functions [1]: [sum(UnscaledValue(ss_net_profit#2))] Aggregate Attributes [1]: [sum(UnscaledValue(ss_net_profit#2))#25] -Results [7]: [MakeDecimal(sum(UnscaledValue(ss_net_profit#2))#25,17,2) AS total_sum#26, s_state#20, s_county#21, (cast((shiftright(spark_grouping_id#22, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#22, 0) & 1) as tinyint)) AS lochierarchy#27, MakeDecimal(sum(UnscaledValue(ss_net_profit#2))#25,17,2) AS _w0#28, (cast((shiftright(spark_grouping_id#22, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#22, 0) & 1) as tinyint)) AS _w1#29, CASE WHEN (cast((shiftright(spark_grouping_id#22, 0) & 1) as tinyint) = 0) THEN s_state#20 END AS _w2#30] +Results [7]: [MakeDecimal(sum(UnscaledValue(ss_net_profit#2))#25,17,2) AS total_sum#26, s_state#20, s_county#21, (cast(((spark_grouping_id#22 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#22 >> 0) & 1) as tinyint)) AS lochierarchy#27, MakeDecimal(sum(UnscaledValue(ss_net_profit#2))#25,17,2) AS _w0#28, (cast(((spark_grouping_id#22 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#22 >> 0) & 1) as tinyint)) AS _w1#29, CASE WHEN 
(cast(((spark_grouping_id#22 >> 0) & 1) as tinyint) = 0) THEN s_state#20 END AS _w2#30] (39) Exchange Input [7]: [total_sum#26, s_state#20, s_county#21, lochierarchy#27, _w0#28, _w1#29, _w2#30] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q72/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q72/explain.txt index 12ba2db6323e4..31da928a5d7a3 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q72/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q72/explain.txt @@ -264,7 +264,7 @@ Condition : (isnotnull(d_week_seq#26) AND isnotnull(d_date_sk#25)) (42) BroadcastExchange Input [2]: [d_date_sk#25, d_week_seq#26] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, false] as bigint), 32) | (cast(input[0, int, false] as bigint) & 4294967295))),false), [plan_id=6] +Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, false] as bigint) << 32) | (cast(input[0, int, false] as bigint) & 4294967295))),false), [plan_id=6] (43) BroadcastHashJoin [codegen id : 10] Left keys [2]: [d_week_seq#24, inv_date_sk#13] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q85/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q85/explain.txt index af6632f4fb608..31c804c73eaef 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q85/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q85/explain.txt @@ -66,7 +66,7 @@ Condition : ((((isnotnull(ws_item_sk#1) AND isnotnull(ws_order_number#3)) AND is (4) BroadcastExchange Input [7]: [ws_item_sk#1, ws_web_page_sk#2, ws_order_number#3, ws_quantity#4, ws_sales_price#5, ws_net_profit#6, ws_sold_date_sk#7] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[2, int, false] as bigint) & 4294967295))),false), [plan_id=1] +Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, false] as bigint) << 32) | (cast(input[2, int, false] as bigint) & 4294967295))),false), [plan_id=1] (5) Scan parquet spark_catalog.default.web_returns Output [9]: [wr_item_sk#9, wr_refunded_cdemo_sk#10, wr_refunded_addr_sk#11, wr_returning_cdemo_sk#12, wr_reason_sk#13, wr_order_number#14, wr_fee#15, wr_refunded_cash#16, wr_returned_date_sk#17] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q86.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q86.sf100/explain.txt index d1802b2e4a7c6..c496d20204875 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q86.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q86.sf100/explain.txt @@ -98,7 +98,7 @@ Input [4]: [i_category#9, i_class#10, spark_grouping_id#11, sum#13] Keys [3]: [i_category#9, i_class#10, spark_grouping_id#11] Functions [1]: [sum(UnscaledValue(ws_net_paid#2))] Aggregate Attributes [1]: [sum(UnscaledValue(ws_net_paid#2))#14] -Results [7]: [MakeDecimal(sum(UnscaledValue(ws_net_paid#2))#14,17,2) AS total_sum#15, i_category#9, i_class#10, (cast((shiftright(spark_grouping_id#11, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#11, 0) & 1) as tinyint)) AS lochierarchy#16, MakeDecimal(sum(UnscaledValue(ws_net_paid#2))#14,17,2) AS _w0#17, (cast((shiftright(spark_grouping_id#11, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#11, 0) & 
1) as tinyint)) AS _w1#18, CASE WHEN (cast((shiftright(spark_grouping_id#11, 0) & 1) as tinyint) = 0) THEN i_category#9 END AS _w2#19] +Results [7]: [MakeDecimal(sum(UnscaledValue(ws_net_paid#2))#14,17,2) AS total_sum#15, i_category#9, i_class#10, (cast(((spark_grouping_id#11 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#11 >> 0) & 1) as tinyint)) AS lochierarchy#16, MakeDecimal(sum(UnscaledValue(ws_net_paid#2))#14,17,2) AS _w0#17, (cast(((spark_grouping_id#11 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#11 >> 0) & 1) as tinyint)) AS _w1#18, CASE WHEN (cast(((spark_grouping_id#11 >> 0) & 1) as tinyint) = 0) THEN i_category#9 END AS _w2#19] (17) Exchange Input [7]: [total_sum#15, i_category#9, i_class#10, lochierarchy#16, _w0#17, _w1#18, _w2#19] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q86/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q86/explain.txt index d1802b2e4a7c6..c496d20204875 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q86/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q86/explain.txt @@ -98,7 +98,7 @@ Input [4]: [i_category#9, i_class#10, spark_grouping_id#11, sum#13] Keys [3]: [i_category#9, i_class#10, spark_grouping_id#11] Functions [1]: [sum(UnscaledValue(ws_net_paid#2))] Aggregate Attributes [1]: [sum(UnscaledValue(ws_net_paid#2))#14] -Results [7]: [MakeDecimal(sum(UnscaledValue(ws_net_paid#2))#14,17,2) AS total_sum#15, i_category#9, i_class#10, (cast((shiftright(spark_grouping_id#11, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#11, 0) & 1) as tinyint)) AS lochierarchy#16, MakeDecimal(sum(UnscaledValue(ws_net_paid#2))#14,17,2) AS _w0#17, (cast((shiftright(spark_grouping_id#11, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#11, 0) & 1) as tinyint)) AS _w1#18, CASE WHEN (cast((shiftright(spark_grouping_id#11, 0) & 1) as tinyint) = 0) THEN i_category#9 END AS _w2#19] +Results [7]: [MakeDecimal(sum(UnscaledValue(ws_net_paid#2))#14,17,2) AS total_sum#15, i_category#9, i_class#10, (cast(((spark_grouping_id#11 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#11 >> 0) & 1) as tinyint)) AS lochierarchy#16, MakeDecimal(sum(UnscaledValue(ws_net_paid#2))#14,17,2) AS _w0#17, (cast(((spark_grouping_id#11 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#11 >> 0) & 1) as tinyint)) AS _w1#18, CASE WHEN (cast(((spark_grouping_id#11 >> 0) & 1) as tinyint) = 0) THEN i_category#9 END AS _w2#19] (17) Exchange Input [7]: [total_sum#15, i_category#9, i_class#10, lochierarchy#16, _w0#17, _w1#18, _w2#19] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24.sf100/explain.txt index 9d80077e99372..e437dea8ca9a0 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24.sf100/explain.txt @@ -125,7 +125,7 @@ Input [11]: [s_store_sk#1, s_store_name#2, s_state#4, ca_address_sk#6, ca_state# (17) BroadcastExchange Input [7]: [s_store_sk#1, s_store_name#2, s_state#4, ca_state#7, c_customer_sk#10, c_first_name#12, c_last_name#13] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, true] as bigint), 32) | (cast(input[4, int, true] as bigint) & 4294967295))),false), [plan_id=3] +Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, true] as bigint) 
<< 32) | (cast(input[4, int, true] as bigint) & 4294967295))),false), [plan_id=3] (18) Scan parquet spark_catalog.default.store_sales Output [6]: [ss_item_sk#15, ss_customer_sk#16, ss_store_sk#17, ss_ticket_number#18, ss_net_paid#19, ss_sold_date_sk#20] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q49/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q49/explain.txt index fea7a9fe207df..ec609603ea35b 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q49/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q49/explain.txt @@ -99,7 +99,7 @@ Input [6]: [ws_item_sk#1, ws_order_number#2, ws_quantity#3, ws_net_paid#4, ws_ne (5) BroadcastExchange Input [5]: [ws_item_sk#1, ws_order_number#2, ws_quantity#3, ws_net_paid#4, ws_sold_date_sk#6] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=1] +Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, true] as bigint) << 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=1] (6) Scan parquet spark_catalog.default.web_returns Output [5]: [wr_item_sk#8, wr_order_number#9, wr_return_quantity#10, wr_return_amt#11, wr_returned_date_sk#12] @@ -209,7 +209,7 @@ Input [6]: [cs_item_sk#36, cs_order_number#37, cs_quantity#38, cs_net_paid#39, c (29) BroadcastExchange Input [5]: [cs_item_sk#36, cs_order_number#37, cs_quantity#38, cs_net_paid#39, cs_sold_date_sk#41] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=4] +Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, true] as bigint) << 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=4] (30) Scan parquet spark_catalog.default.catalog_returns Output [5]: [cr_item_sk#42, cr_order_number#43, cr_return_quantity#44, cr_return_amount#45, cr_returned_date_sk#46] @@ -319,7 +319,7 @@ Input [6]: [ss_item_sk#70, ss_ticket_number#71, ss_quantity#72, ss_net_paid#73, (53) BroadcastExchange Input [5]: [ss_item_sk#70, ss_ticket_number#71, ss_quantity#72, ss_net_paid#73, ss_sold_date_sk#75] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=7] +Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, true] as bigint) << 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=7] (54) Scan parquet spark_catalog.default.store_returns Output [5]: [sr_item_sk#76, sr_ticket_number#77, sr_return_quantity#78, sr_return_amt#79, sr_returned_date_sk#80] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q5a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q5a/explain.txt index 34c6ecf3cf2fa..1b9bf5123e965 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q5a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q5a/explain.txt @@ -317,7 +317,7 @@ Input [5]: [wr_item_sk#92, wr_order_number#93, wr_return_amt#94, wr_net_loss#95, (49) BroadcastExchange Input [5]: [wr_item_sk#92, wr_order_number#93, wr_return_amt#94, wr_net_loss#95, wr_returned_date_sk#96] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, 
true] as bigint), 32) | (cast(input[1, int, true] as bigint) & 4294967295))),false), [plan_id=5] +Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, true] as bigint) << 32) | (cast(input[1, int, true] as bigint) & 4294967295))),false), [plan_id=5] (50) Scan parquet spark_catalog.default.web_sales Output [4]: [ws_item_sk#97, ws_web_site_sk#98, ws_order_number#99, ws_sold_date_sk#100] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64/explain.txt index 40eddbbacf38a..4579b8bbe8197 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64/explain.txt @@ -201,7 +201,7 @@ Condition : (((((((isnotnull(ss_item_sk#1) AND isnotnull(ss_ticket_number#8)) AN (4) BroadcastExchange Input [12]: [ss_item_sk#1, ss_customer_sk#2, ss_cdemo_sk#3, ss_hdemo_sk#4, ss_addr_sk#5, ss_store_sk#6, ss_promo_sk#7, ss_ticket_number#8, ss_wholesale_cost#9, ss_list_price#10, ss_coupon_amt#11, ss_sold_date_sk#12] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[7, int, false] as bigint) & 4294967295))),false), [plan_id=1] +Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, false] as bigint) << 32) | (cast(input[7, int, false] as bigint) & 4294967295))),false), [plan_id=1] (5) Scan parquet spark_catalog.default.store_returns Output [3]: [sr_item_sk#14, sr_ticket_number#15, sr_returned_date_sk#16] @@ -714,7 +714,7 @@ Condition : (((((((isnotnull(ss_item_sk#106) AND isnotnull(ss_ticket_number#113) (115) BroadcastExchange Input [12]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, ss_ticket_number#113, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, ss_sold_date_sk#117] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[7, int, false] as bigint) & 4294967295))),false), [plan_id=16] +Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, false] as bigint) << 32) | (cast(input[7, int, false] as bigint) & 4294967295))),false), [plan_id=16] (116) Scan parquet spark_catalog.default.store_returns Output [3]: [sr_item_sk#119, sr_ticket_number#120, sr_returned_date_sk#121] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q72/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q72/explain.txt index 13d7d1bc9c4d8..47974c9691023 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q72/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q72/explain.txt @@ -264,7 +264,7 @@ Condition : (isnotnull(d_week_seq#26) AND isnotnull(d_date_sk#25)) (42) BroadcastExchange Input [2]: [d_date_sk#25, d_week_seq#26] -Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, false] as bigint), 32) | (cast(input[0, int, false] as bigint) & 4294967295))),false), [plan_id=6] +Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, false] as bigint) << 32) | (cast(input[0, int, false] as bigint) & 4294967295))),false), [plan_id=6] (43) BroadcastHashJoin [codegen id : 10] Left keys [2]: [d_week_seq#24, inv_date_sk#13] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index 19251330cffe3..e80f4af1dc462 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -141,7 +141,10 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { // Examples demonstrate alternative syntax, see SPARK-45574 "org.apache.spark.sql.catalyst.expressions.Cast", // Examples demonstrate alternative syntax, see SPARK-47012 - "org.apache.spark.sql.catalyst.expressions.Collate" + "org.apache.spark.sql.catalyst.expressions.Collate", + classOf[ShiftLeft].getName, + classOf[ShiftRight].getName, + classOf[ShiftRightUnsigned].getName ) spark.sessionState.functionRegistry.listFunction().foreach { funcId => val info = spark.sessionState.catalog.lookupFunctionInfo(funcId) From 3cb30c2366b27c5a65ec02121c30bd1a4eb20584 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 24 May 2024 09:43:03 -0700 Subject: [PATCH 19/45] [SPARK-47579][SQL][FOLLOWUP] Restore the `--help` print format of spark sql shell ### What changes were proposed in this pull request? Restore the print format of spark sql shell ### Why are the changes needed? bugfix ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? manually ![image](https://github.com/apache/spark/assets/8326978/17b9d009-5d93-4d84-9367-7308b4cda426) ![image](https://github.com/apache/spark/assets/8326978/a5e333bd-0e22-4d5a-83f1-843767f6d5f5) ### Was this patch authored or co-authored using generative AI tooling? no Closes #46735 from yaooqinn/SPARK-47579. Authored-by: Kent Yao Signed-off-by: Gengliang Wang --- .../src/main/scala/org/apache/spark/internal/LogKey.scala | 1 - .../scala/org/apache/spark/deploy/SparkSubmitArguments.scala | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala index 1f67a211c01fa..99fc58b035030 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala @@ -585,7 +585,6 @@ object LogKeys { case object SESSION_KEY extends LogKey case object SET_CLIENT_INFO_REQUEST extends LogKey case object SHARD_ID extends LogKey - case object SHELL_OPTIONS extends LogKey case object SHORT_USER_NAME extends LogKey case object SHUFFLE_BLOCK_INFO extends LogKey case object SHUFFLE_DB_BACKEND_KEY extends LogKey diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 61235a7019070..e47596a6ae430 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -588,7 +588,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S ) if (SparkSubmit.isSqlShell(mainClass)) { - logInfo(log"CLI options:\n${MDC(SHELL_OPTIONS, getSqlShellOptions())}") + logInfo("CLI options:") + logInfo(getSqlShellOptions()) } throw SparkUserAppException(exitCode) From 7d96334902f22a80af63ce1253d5abda63178c4e Mon Sep 17 00:00:00 2001 From: Bo Zhang Date: Fri, 24 May 2024 15:54:21 -0700 Subject: [PATCH 20/45] [SPARK-48325][CORE] Always specify messages in ExecutorRunner.killProcess ### What changes were proposed in this pull request? 
This change is to always specify the message in `ExecutorRunner.killProcess`. ### Why are the changes needed? This is to get the occurrence rate for different cases when killing the executor process, in order to analyze executor running stability. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? N/A ### Was this patch authored or co-authored using generative AI tooling? No Closes #46641 from bozhang2820/spark-48325. Authored-by: Bo Zhang Signed-off-by: Dongjoon Hyun --- .../apache/spark/deploy/worker/ExecutorRunner.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 7bb8b74eb0218..bd98f19cdb605 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -88,7 +88,7 @@ private[deploy] class ExecutorRunner( if (state == ExecutorState.LAUNCHING || state == ExecutorState.RUNNING) { state = ExecutorState.FAILED } - killProcess(Some("Worker shutting down")) } + killProcess("Worker shutting down") } } /** @@ -96,7 +96,7 @@ private[deploy] class ExecutorRunner( * * @param message the exception message which caused the executor's death */ - private def killProcess(message: Option[String]): Unit = { + private def killProcess(message: String): Unit = { var exitCode: Option[Int] = None if (process != null) { logInfo("Killing process!") @@ -113,7 +113,7 @@ private[deploy] class ExecutorRunner( } } try { - worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) + worker.send(ExecutorStateChanged(appId, execId, state, Some(message), exitCode)) } catch { case e: IllegalStateException => logWarning(log"${MDC(ERROR, e.getMessage())}", e) } @@ -206,11 +206,11 @@ private[deploy] class ExecutorRunner( case interrupted: InterruptedException => logInfo("Runner thread for executor " + fullId + " interrupted") state = ExecutorState.KILLED - killProcess(None) + killProcess(s"Runner thread for executor $fullId interrupted") case e: Exception => logError("Error running executor", e) state = ExecutorState.FAILED - killProcess(Some(e.toString)) + killProcess(s"Error running executor: $e") } } } From 1a536f01ead35b770467381c476e093338d81e7c Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 24 May 2024 15:56:19 -0700 Subject: [PATCH 21/45] [SPARK-48407][SQL][DOCS] Teradata: Document Type Conversion rules between Spark SQL and teradata ### What changes were proposed in this pull request? This PR adds documentation for the builtin teradata jdbc dialect's data type conversion rules ### Why are the changes needed? doc improvement ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? ![image](https://github.com/apache/spark/assets/8326978/e1ec0de5-cd83-4339-896a-50c58ad01c4d) ### Was this patch authored or co-authored using generative AI tooling? no Closes #46728 from yaooqinn/SPARK-48407. Authored-by: Kent Yao Signed-off-by: Dongjoon Hyun --- docs/sql-data-sources-jdbc.md | 214 ++++++++++++++++++++++++++++++++++ 1 file changed, 214 insertions(+) diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md index 371dc05950717..9ffd96cd40ee5 100644 --- a/docs/sql-data-sources-jdbc.md +++ b/docs/sql-data-sources-jdbc.md @@ -1991,3 +1991,217 @@ The Spark Catalyst data types below are not supported with suitable DB2 types. 
- NullType
- ObjectType
- VariantType
+
+### Mapping Spark SQL Data Types from Teradata
+
+The below table describes the data type conversions from Teradata data types to Spark SQL Data Types,
+when reading data from a Teradata table using the built-in jdbc data source with the [Teradata JDBC Driver](https://mvnrepository.com/artifact/com.teradata.jdbc/terajdbc)
+as the activated JDBC Driver.
+
+<table>
+  <thead>
+    <tr><th>Teradata Data Type</th><th>Spark SQL Data Type</th><th>Remarks</th></tr>
+  </thead>
+  <tbody>
+    <tr><td>BYTEINT</td><td>ByteType</td><td></td></tr>
+    <tr><td>SMALLINT</td><td>ShortType</td><td></td></tr>
+    <tr><td>INTEGER, INT</td><td>IntegerType</td><td></td></tr>
+    <tr><td>BIGINT</td><td>LongType</td><td></td></tr>
+    <tr><td>REAL, DOUBLE PRECISION, FLOAT</td><td>DoubleType</td><td></td></tr>
+    <tr><td>DECIMAL, NUMERIC, NUMBER</td><td>DecimalType</td><td></td></tr>
+    <tr><td>DATE</td><td>DateType</td><td></td></tr>
+    <tr><td>TIMESTAMP, TIMESTAMP WITH TIME ZONE</td><td>TimestampType</td><td>(Default) preferTimestampNTZ=false or spark.sql.timestampType=TIMESTAMP_LTZ</td></tr>
+    <tr><td>TIMESTAMP, TIMESTAMP WITH TIME ZONE</td><td>TimestampNTZType</td><td>preferTimestampNTZ=true or spark.sql.timestampType=TIMESTAMP_NTZ</td></tr>
+    <tr><td>TIME, TIME WITH TIME ZONE</td><td>TimestampType</td><td>(Default) preferTimestampNTZ=false or spark.sql.timestampType=TIMESTAMP_LTZ</td></tr>
+    <tr><td>TIME, TIME WITH TIME ZONE</td><td>TimestampNTZType</td><td>preferTimestampNTZ=true or spark.sql.timestampType=TIMESTAMP_NTZ</td></tr>
+    <tr><td>CHARACTER(n), CHAR(n), GRAPHIC(n)</td><td>CharType(n)</td><td></td></tr>
+    <tr><td>VARCHAR(n), VARGRAPHIC(n)</td><td>VarcharType(n)</td><td></td></tr>
+    <tr><td>BYTE(n), VARBYTE(n)</td><td>BinaryType</td><td></td></tr>
+    <tr><td>CLOB</td><td>StringType</td><td></td></tr>
+    <tr><td>BLOB</td><td>BinaryType</td><td></td></tr>
+    <tr><td>INTERVAL Data Types</td><td>-</td><td>The INTERVAL data types are unknown yet</td></tr>
+    <tr><td>Period Data Types, ARRAY, UDT</td><td>-</td><td>Not Supported</td></tr>
+  </tbody>
+</table>
+
+### Mapping Spark SQL Data Types to Teradata
+
+The below table describes the data type conversions from Spark SQL Data Types to Teradata data types,
+when creating, altering, or writing data to a Teradata table using the built-in jdbc data source with
+the [Teradata JDBC Driver](https://mvnrepository.com/artifact/com.teradata.jdbc/terajdbc) as the activated JDBC Driver.
+
+<table>
+  <thead>
+    <tr><th>Spark SQL Data Type</th><th>Teradata Data Type</th><th>Remarks</th></tr>
+  </thead>
+  <tbody>
+    <tr><td>BooleanType</td><td>CHAR(1)</td><td></td></tr>
+    <tr><td>ByteType</td><td>BYTEINT</td><td></td></tr>
+    <tr><td>ShortType</td><td>SMALLINT</td><td></td></tr>
+    <tr><td>IntegerType</td><td>INTEGER</td><td></td></tr>
+    <tr><td>LongType</td><td>BIGINT</td><td></td></tr>
+    <tr><td>FloatType</td><td>REAL</td><td></td></tr>
+    <tr><td>DoubleType</td><td>DOUBLE PRECISION</td><td></td></tr>
+    <tr><td>DecimalType(p, s)</td><td>DECIMAL(p,s)</td><td></td></tr>
+    <tr><td>DateType</td><td>DATE</td><td></td></tr>
+    <tr><td>TimestampType</td><td>TIMESTAMP</td><td></td></tr>
+    <tr><td>TimestampNTZType</td><td>TIMESTAMP</td><td></td></tr>
+    <tr><td>StringType</td><td>VARCHAR(255)</td><td></td></tr>
+    <tr><td>BinaryType</td><td>BLOB</td><td></td></tr>
+    <tr><td>CharType(n)</td><td>CHAR(n)</td><td></td></tr>
+    <tr><td>VarcharType(n)</td><td>VARCHAR(n)</td><td></td></tr>
+  </tbody>
+</table>
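For illustration only, and not part of the patch itself: a minimal Scala sketch of how the `preferTimestampNTZ` option referenced in the Remarks column above is used when reading a Teradata table through the built-in JDBC source. The host, table name, and credentials below are placeholders, and the Teradata JDBC driver is assumed to be on the classpath.

```scala
// Minimal sketch: load a Teradata table via Spark's generic JDBC data source.
// Placeholders: JDBC URL host, database/table, user, password (not from the patch).
val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:teradata://teradata-host")  // placeholder JDBC URL
  .option("dbtable", "my_db.my_table")             // placeholder table
  .option("user", "my_user")
  .option("password", "my_password")
  .option("preferTimestampNTZ", "true")            // TIMESTAMP/TIME columns map to TimestampNTZType (see table above)
  .load()

df.printSchema()
```

With `preferTimestampNTZ` left at its default of `false` (or with `spark.sql.timestampType=TIMESTAMP_LTZ`), the same columns map to `TimestampType` instead, as described in the first table.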
+ +The Spark Catalyst data types below are not supported with suitable Teradata types. + +- DayTimeIntervalType +- YearMonthIntervalType +- CalendarIntervalType +- ArrayType +- MapType +- StructType +- UserDefinedType +- NullType +- ObjectType +- VariantType From 6cd1ccc56321dfa52672cd25f4cfdf2bbc86b3ea Mon Sep 17 00:00:00 2001 From: Yi Wu Date: Fri, 24 May 2024 16:01:17 -0700 Subject: [PATCH 22/45] [SPARK-48394][CORE] Cleanup mapIdToMapIndex on mapoutput unregister ### What changes were proposed in this pull request? This PR cleans up `mapIdToMapIndex` when the corresponding mapstatus is unregistered in three places: * `removeMapOutput` * `removeOutputsByFilter` * `addMapOutput` (old mapstatus overwritten) ### Why are the changes needed? There is only one valid mapstatus for the same `mapIndex` at the same time in Spark. `mapIdToMapIndex` should also follows the same rule to avoid chaos. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46706 from Ngone51/SPARK-43043-followup. Lead-authored-by: Yi Wu Co-authored-by: wuyi Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/MapOutputTracker.scala | 26 ++++++--- .../apache/spark/MapOutputTrackerSuite.scala | 55 +++++++++++++++++++ 2 files changed, 72 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index fdc2b0a4c20f0..a660bccd2e68f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -44,7 +44,6 @@ import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId} import org.apache.spark.util._ import org.apache.spark.util.ArrayImplicits._ -import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} /** @@ -153,17 +152,22 @@ private class ShuffleStatus( /** * Mapping from a mapId to the mapIndex, this is required to reduce the searching overhead within * the function updateMapOutput(mapId, bmAddress). + * + * Exposed for testing. */ - private[this] val mapIdToMapIndex = new OpenHashMap[Long, Int]() + private[spark] val mapIdToMapIndex = new HashMap[Long, Int]() /** * Register a map output. If there is already a registered location for the map output then it * will be replaced by the new location. 
*/ def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock { - if (mapStatuses(mapIndex) == null) { + val currentMapStatus = mapStatuses(mapIndex) + if (currentMapStatus == null) { _numAvailableMapOutputs += 1 invalidateSerializedMapOutputStatusCache() + } else { + mapIdToMapIndex.remove(currentMapStatus.mapId) } mapStatuses(mapIndex) = status mapIdToMapIndex(status.mapId) = mapIndex @@ -193,8 +197,8 @@ private class ShuffleStatus( mapStatus.updateLocation(bmAddress) invalidateSerializedMapOutputStatusCache() case None => - if (mapIndex.map(mapStatusesDeleted).exists(_.mapId == mapId)) { - val index = mapIndex.get + val index = mapStatusesDeleted.indexWhere(x => x != null && x.mapId == mapId) + if (index >= 0 && mapStatuses(index) == null) { val mapStatus = mapStatusesDeleted(index) mapStatus.updateLocation(bmAddress) mapStatuses(index) = mapStatus @@ -222,9 +226,11 @@ private class ShuffleStatus( */ def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock { logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}") - if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) { + val currentMapStatus = mapStatuses(mapIndex) + if (currentMapStatus != null && currentMapStatus.location == bmAddress) { _numAvailableMapOutputs -= 1 - mapStatusesDeleted(mapIndex) = mapStatuses(mapIndex) + mapIdToMapIndex.remove(currentMapStatus.mapId) + mapStatusesDeleted(mapIndex) = currentMapStatus mapStatuses(mapIndex) = null invalidateSerializedMapOutputStatusCache() } @@ -290,9 +296,11 @@ private class ShuffleStatus( */ def removeOutputsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock { for (mapIndex <- mapStatuses.indices) { - if (mapStatuses(mapIndex) != null && f(mapStatuses(mapIndex).location)) { + val currentMapStatus = mapStatuses(mapIndex) + if (currentMapStatus != null && f(currentMapStatus.location)) { _numAvailableMapOutputs -= 1 - mapStatusesDeleted(mapIndex) = mapStatuses(mapIndex) + mapIdToMapIndex.remove(currentMapStatus.mapId) + mapStatusesDeleted(mapIndex) = currentMapStatus mapStatuses(mapIndex) = null invalidateSerializedMapOutputStatusCache() } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 7aec8eeaad423..26dc218c30c74 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -1110,4 +1110,59 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { rpcEnv.shutdown() } } + + test( + "SPARK-48394: mapIdToMapIndex should cleanup unused mapIndexes after removeOutputsByFilter" + ) { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + try { + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(0, 1, 1) + tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-1", "hostA", 1000), + Array(2L), 0)) + tracker.removeOutputsOnHost("hostA") + assert(tracker.shuffleStatuses(0).mapIdToMapIndex.filter(_._2 == 0).size == 0) + } finally { + tracker.stop() + rpcEnv.shutdown() + } + } + + test("SPARK-48394: mapIdToMapIndex should cleanup unused mapIndexes after unregisterMapOutput") { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + try { + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new 
MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(0, 1, 1) + tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-1", "hostA", 1000), + Array(2L), 0)) + tracker.unregisterMapOutput(0, 0, BlockManagerId("exec-1", "hostA", 1000)) + assert(tracker.shuffleStatuses(0).mapIdToMapIndex.filter(_._2 == 0).size == 0) + } finally { + tracker.stop() + rpcEnv.shutdown() + } + } + + test("SPARK-48394: mapIdToMapIndex should cleanup unused mapIndexes after registerMapOutput") { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + try { + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(0, 1, 1) + tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-1", "hostA", 1000), + Array(2L), 0)) + // Another task also finished working on partition 0. + tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-2", "hostB", 1000), + Array(2L), 1)) + assert(tracker.shuffleStatuses(0).mapIdToMapIndex.filter(_._2 == 0).size == 1) + } finally { + tracker.stop() + rpcEnv.shutdown() + } + } } From cff7014b337be18f14289ad8c50c3e08c214a9e4 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 24 May 2024 20:38:07 -0700 Subject: [PATCH 23/45] [SPARK-47579][CORE][PART3] Spark core: Migrate logInfo with variables to structured logging framework ### What changes were proposed in this pull request? The PR aims to migrate logInfo in Core module with variables to structured logging framework. ### Why are the changes needed? To enhance Apache Spark's logging system by implementing structured logging. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? GA tests ### Was this patch authored or co-authored using generative AI tooling? Yes, Generated-by: Github Copilot Github Copilot provides a few suggestions. Closes #46739 from gengliangwang/logInfoCoreP3. 
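
As a rough sketch of the migration pattern this patch applies (not part of the patch itself): a plain `s"..."` interpolation is rewritten with the `log"..."` interpolator, and each variable is wrapped in `MDC` with an entry from `LogKeys` so it can be emitted as a structured field. The `ConnectionManager` class and its `masterUrl` parameter below are made up for illustration; the `Logging` trait supplies `logInfo` and the `log` interpolator.

```scala
import org.apache.spark.internal.{Logging, LogKeys, MDC}

// Hypothetical class used only to illustrate the pattern seen in the diffs below.
class ConnectionManager(masterUrl: String) extends Logging {
  def connect(): Unit = {
    // Before: unstructured message built with plain string interpolation.
    // logInfo(s"Connecting to master $masterUrl...")

    // After: the variable is tagged with a LogKey, so the logging framework can
    // emit it as a structured field in addition to the rendered message.
    logInfo(log"Connecting to master ${MDC(LogKeys.MASTER_URL, masterUrl)}...")
  }
}
```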
Authored-by: Gengliang Wang Signed-off-by: Gengliang Wang --- .../shuffle/RetryingBlockTransferor.java | 6 +- .../org/apache/spark/internal/LogKey.scala | 51 ++++- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/deploy/LocalSparkCluster.scala | 5 +- .../apache/spark/deploy/SparkHadoopUtil.scala | 7 +- .../deploy/client/StandaloneAppClient.scala | 19 +- .../deploy/master/RecoveryModeFactory.scala | 7 +- .../deploy/rest/RestSubmissionClient.scala | 30 ++- .../spark/deploy/worker/CommandUtils.scala | 5 +- .../spark/deploy/worker/ExecutorRunner.scala | 2 +- .../apache/spark/deploy/worker/Worker.scala | 96 ++++---- .../spark/memory/ExecutionMemoryPool.scala | 2 +- .../org/apache/spark/rdd/HadoopRDD.scala | 2 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 2 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 2 +- .../spark/rdd/SequenceFileRDDFunctions.scala | 7 +- .../apache/spark/scheduler/DAGScheduler.scala | 210 ++++++++++-------- .../spark/scheduler/HealthTracker.scala | 35 +-- .../scheduler/OutputCommitCoordinator.scala | 16 +- .../spark/scheduler/SchedulableBuilder.scala | 32 ++- .../spark/scheduler/StatsReportListener.scala | 9 +- .../spark/scheduler/TaskSchedulerImpl.scala | 69 +++--- .../spark/scheduler/TaskSetExcludeList.scala | 9 +- .../spark/scheduler/TaskSetManager.scala | 67 +++--- .../CoarseGrainedSchedulerBackend.scala | 48 ++-- .../cluster/StandaloneSchedulerBackend.scala | 25 ++- .../scheduler/dynalloc/ExecutorMonitor.scala | 18 +- .../apache/spark/storage/BlockManager.scala | 30 +-- .../storage/BlockManagerDecommissioner.scala | 60 +++-- .../spark/storage/BlockManagerMaster.scala | 10 +- .../storage/BlockManagerMasterEndpoint.scala | 58 ++--- .../spark/storage/PushBasedFetchHelper.scala | 3 +- .../storage/ShuffleBlockFetcherIterator.scala | 35 +-- .../spark/storage/memory/MemoryStore.scala | 37 +-- 34 files changed, 614 insertions(+), 402 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java index 83be2db5d0b73..31c454f63a92e 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java @@ -183,12 +183,12 @@ private void transferAllOutstanding() { if (numRetries > 0) { logger.error("Exception while beginning {} of {} outstanding blocks (after {} retries)", e, MDC.of(LogKeys.TRANSFER_TYPE$.MODULE$, listener.getTransferType()), - MDC.of(LogKeys.NUM_BLOCK_IDS$.MODULE$, blockIdsToTransfer.length), + MDC.of(LogKeys.NUM_BLOCKS$.MODULE$, blockIdsToTransfer.length), MDC.of(LogKeys.NUM_RETRY$.MODULE$, numRetries)); } else { logger.error("Exception while beginning {} of {} outstanding blocks", e, MDC.of(LogKeys.TRANSFER_TYPE$.MODULE$, listener.getTransferType()), - MDC.of(LogKeys.NUM_BLOCK_IDS$.MODULE$, blockIdsToTransfer.length)); + MDC.of(LogKeys.NUM_BLOCKS$.MODULE$, blockIdsToTransfer.length)); } if (shouldRetry(e) && initiateRetry(e)) { // successfully initiated a retry @@ -219,7 +219,7 @@ synchronized boolean initiateRetry(Throwable e) { MDC.of(LogKeys.TRANSFER_TYPE$.MODULE$, listener.getTransferType()), MDC.of(LogKeys.NUM_RETRY$.MODULE$, retryCount), MDC.of(LogKeys.MAX_ATTEMPTS$.MODULE$, maxRetries), - MDC.of(LogKeys.NUM_BLOCK_IDS$.MODULE$, outstandingBlocksIds.size()), + MDC.of(LogKeys.NUM_BLOCKS$.MODULE$, outstandingBlocksIds.size()), 
MDC.of(LogKeys.RETRY_WAIT_TIME$.MODULE$, retryWaitTime)); try { diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala index 99fc58b035030..534f009119226 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala @@ -72,6 +72,7 @@ object LogKeys { case object CACHE_UNTIL_LAST_PRODUCED_SIZE extends LogKey case object CALL_SITE_LONG_FORM extends LogKey case object CALL_SITE_SHORT_FORM extends LogKey + case object CANCEL_FUTURE_JOBS extends LogKey case object CATALOG_NAME extends LogKey case object CATEGORICAL_FEATURES extends LogKey case object CHECKPOINT_FILE extends LogKey @@ -118,10 +119,10 @@ object LogKeys { case object CONTAINER_ID extends LogKey case object CONTAINER_STATE extends LogKey case object CONTEXT extends LogKey - case object CONTEXT_CREATION_SITE extends LogKey case object COST extends LogKey case object COUNT extends LogKey case object CREATED_POOL_NAME extends LogKey + case object CREATION_SITE extends LogKey case object CREDENTIALS_RENEWAL_INTERVAL_RATIO extends LogKey case object CROSS_VALIDATION_METRIC extends LogKey case object CROSS_VALIDATION_METRICS extends LogKey @@ -132,8 +133,9 @@ object LogKeys { case object CSV_SCHEMA_FIELD_NAMES extends LogKey case object CSV_SOURCE extends LogKey case object CURRENT_BATCH_ID extends LogKey + case object CURRENT_DISK_SIZE extends LogKey case object CURRENT_FILE extends LogKey - case object CURRENT_MEMORY_BYTES extends LogKey + case object CURRENT_MEMORY_SIZE extends LogKey case object CURRENT_PATH extends LogKey case object CURRENT_TIME extends LogKey case object DATA extends LogKey @@ -146,7 +148,6 @@ object LogKeys { case object DEFAULT_COMPACT_INTERVAL extends LogKey case object DEFAULT_ISOLATION_LEVEL extends LogKey case object DEFAULT_NAME extends LogKey - case object DEFAULT_SCHEDULING_MODE extends LogKey case object DEFAULT_VALUE extends LogKey case object DELAY extends LogKey case object DELEGATE extends LogKey @@ -207,6 +208,8 @@ object LogKeys { case object EXPR extends LogKey case object EXPR_TERMS extends LogKey case object EXTENDED_EXPLAIN_GENERATOR extends LogKey + case object FAILED_STAGE extends LogKey + case object FAILED_STAGE_NAME extends LogKey case object FAILURES extends LogKey case object FALLBACK_VERSION extends LogKey case object FEATURE_COLUMN extends LogKey @@ -231,6 +234,7 @@ object LogKeys { case object FINAL_OUTPUT_PATH extends LogKey case object FINAL_PATH extends LogKey case object FINISH_TRIGGER_DURATION extends LogKey + case object FREE_MEMORY_SIZE extends LogKey case object FROM_OFFSET extends LogKey case object FROM_TIME extends LogKey case object FUNCTION_NAME extends LogKey @@ -250,6 +254,7 @@ object LogKeys { case object HIVE_OPERATION_STATE extends LogKey case object HIVE_OPERATION_TYPE extends LogKey case object HOST extends LogKey + case object HOST_LOCAL_BLOCKS_SIZE extends LogKey case object HOST_NAMES extends LogKey case object HOST_PORT extends LogKey case object HOST_PORT2 extends LogKey @@ -265,6 +270,7 @@ object LogKeys { case object INITIAL_HEARTBEAT_INTERVAL extends LogKey case object INIT_MODE extends LogKey case object INPUT extends LogKey + case object INPUT_SPLIT extends LogKey case object INTERVAL extends LogKey case object ISOLATION_LEVEL extends LogKey case object ISSUE_DATE extends LogKey @@ -299,11 +305,14 @@ object LogKeys { case object LOAD_FACTOR extends LogKey case object 
LOAD_TIME extends LogKey case object LOCALE extends LogKey + case object LOCAL_BLOCKS_SIZE extends LogKey case object LOCAL_SCRATCH_DIR extends LogKey case object LOCATION extends LogKey case object LOGICAL_PLAN_COLUMNS extends LogKey case object LOGICAL_PLAN_LEAVES extends LogKey case object LOG_ID extends LogKey + case object LOG_KEY_FILE extends LogKey + case object LOG_LEVEL extends LogKey case object LOG_OFFSET extends LogKey case object LOG_TYPE extends LogKey case object LOWER_BOUND extends LogKey @@ -351,6 +360,7 @@ object LogKeys { case object MIN_SIZE extends LogKey case object MIN_TIME extends LogKey case object MIN_VERSION_NUM extends LogKey + case object MISSING_PARENT_STAGES extends LogKey case object MODEL_WEIGHTS extends LogKey case object MODULE_NAME extends LogKey case object NAMESPACE extends LogKey @@ -368,8 +378,9 @@ object LogKeys { case object NORM extends LogKey case object NUM_ADDED_PARTITIONS extends LogKey case object NUM_APPS extends LogKey + case object NUM_ATTEMPT extends LogKey case object NUM_BIN extends LogKey - case object NUM_BLOCK_IDS extends LogKey + case object NUM_BLOCKS extends LogKey case object NUM_BROADCAST_BLOCK extends LogKey case object NUM_BYTES extends LogKey case object NUM_BYTES_CURRENT extends LogKey @@ -388,12 +399,16 @@ object LogKeys { case object NUM_CORES extends LogKey case object NUM_DATA_FILE extends LogKey case object NUM_DATA_FILES extends LogKey + case object NUM_DECOMMISSIONED extends LogKey case object NUM_DRIVERS extends LogKey case object NUM_DROPPED_PARTITIONS extends LogKey case object NUM_EFFECTIVE_RULE_OF_RUNS extends LogKey case object NUM_ELEMENTS_SPILL_THRESHOLD extends LogKey case object NUM_EVENTS extends LogKey case object NUM_EXAMPLES extends LogKey + case object NUM_EXECUTORS extends LogKey + case object NUM_EXECUTORS_EXITED extends LogKey + case object NUM_EXECUTORS_KILLED extends LogKey case object NUM_EXECUTOR_CORES extends LogKey case object NUM_EXECUTOR_CORES_REMAINING extends LogKey case object NUM_EXECUTOR_CORES_TOTAL extends LogKey @@ -407,6 +422,7 @@ object LogKeys { case object NUM_FILES_FAILED_TO_DELETE extends LogKey case object NUM_FILES_REUSED extends LogKey case object NUM_FREQUENT_ITEMS extends LogKey + case object NUM_HOST_LOCAL_BLOCKS extends LogKey case object NUM_INDEX_FILE extends LogKey case object NUM_INDEX_FILES extends LogKey case object NUM_ITERATIONS extends LogKey @@ -415,8 +431,10 @@ object LogKeys { case object NUM_LEADING_SINGULAR_VALUES extends LogKey case object NUM_LEFT_PARTITION_VALUES extends LogKey case object NUM_LOADED_ENTRIES extends LogKey + case object NUM_LOCAL_BLOCKS extends LogKey case object NUM_LOCAL_DIRS extends LogKey case object NUM_LOCAL_FREQUENT_PATTERN extends LogKey + case object NUM_MERGERS extends LogKey case object NUM_MERGER_LOCATIONS extends LogKey case object NUM_META_FILES extends LogKey case object NUM_NODES extends LogKey @@ -433,7 +451,10 @@ object LogKeys { case object NUM_POINT extends LogKey case object NUM_PREFIXES extends LogKey case object NUM_PRUNED extends LogKey + case object NUM_PUSH_MERGED_LOCAL_BLOCKS extends LogKey case object NUM_RECORDS_READ extends LogKey + case object NUM_REMAINED extends LogKey + case object NUM_REMOTE_BLOCKS extends LogKey case object NUM_REMOVED_WORKERS extends LogKey case object NUM_REPLICAS extends LogKey case object NUM_REQUESTS extends LogKey @@ -449,11 +470,14 @@ object LogKeys { case object NUM_SPILL_INFOS extends LogKey case object NUM_SPILL_WRITERS extends LogKey case object NUM_SUB_DIRS extends LogKey + 
case object NUM_SUCCESSFUL_TASKS extends LogKey case object NUM_TASKS extends LogKey case object NUM_TASK_CPUS extends LogKey case object NUM_TRAIN_WORD extends LogKey + case object NUM_UNFINISHED_DECOMMISSIONED extends LogKey case object NUM_VERSIONS_RETAIN extends LogKey case object NUM_WEIGHTED_EXAMPLES extends LogKey + case object NUM_WORKERS extends LogKey case object OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD extends LogKey case object OBJECT_ID extends LogKey case object OFFSET extends LogKey @@ -470,16 +494,20 @@ object LogKeys { case object OPTIONS extends LogKey case object OP_ID extends LogKey case object OP_TYPE extends LogKey + case object ORIGINAL_DISK_SIZE extends LogKey + case object ORIGINAL_MEMORY_SIZE extends LogKey case object OS_ARCH extends LogKey case object OS_NAME extends LogKey case object OS_VERSION extends LogKey case object OUTPUT extends LogKey case object OVERHEAD_MEMORY_SIZE extends LogKey case object PAGE_SIZE extends LogKey + case object PARENT_STAGES extends LogKey case object PARSE_MODE extends LogKey case object PARTITIONED_FILE_READER extends LogKey case object PARTITIONER extends LogKey case object PARTITION_ID extends LogKey + case object PARTITION_IDS extends LogKey case object PARTITION_SPECIFICATION extends LogKey case object PARTITION_SPECS extends LogKey case object PATH extends LogKey @@ -511,12 +539,14 @@ object LogKeys { case object PROTOCOL_VERSION extends LogKey case object PROVIDER extends LogKey case object PUSHED_FILTERS extends LogKey + case object PUSH_MERGED_LOCAL_BLOCKS_SIZE extends LogKey case object PVC_METADATA_NAME extends LogKey case object PYTHON_EXEC extends LogKey case object PYTHON_PACKAGES extends LogKey case object PYTHON_VERSION extends LogKey case object PYTHON_WORKER_MODULE extends LogKey case object PYTHON_WORKER_RESPONSE extends LogKey + case object QUANTILES extends LogKey case object QUERY_CACHE_VALUE extends LogKey case object QUERY_HINT extends LogKey case object QUERY_ID extends LogKey @@ -542,11 +572,13 @@ object LogKeys { case object REDACTED_STATEMENT extends LogKey case object REDUCE_ID extends LogKey case object REGISTERED_EXECUTOR_FILE extends LogKey + case object REGISTER_MERGE_RESULTS extends LogKey case object RELATION_NAME extends LogKey case object RELATION_OUTPUT extends LogKey case object RELATIVE_TOLERANCE extends LogKey case object REMAINING_PARTITIONS extends LogKey case object REMOTE_ADDRESS extends LogKey + case object REMOTE_BLOCKS_SIZE extends LogKey case object REMOVE_FROM_MASTER extends LogKey case object REPORT_DETAILS extends LogKey case object REQUESTER_SIZE extends LogKey @@ -574,6 +606,7 @@ object LogKeys { case object RUN_ID extends LogKey case object SCALA_VERSION extends LogKey case object SCHEDULER_POOL_NAME extends LogKey + case object SCHEDULING_MODE extends LogKey case object SCHEMA extends LogKey case object SCHEMA2 extends LogKey case object SERVER_NAME extends LogKey @@ -615,13 +648,17 @@ object LogKeys { case object SPILL_TIMES extends LogKey case object SQL_TEXT extends LogKey case object SRC_PATH extends LogKey + case object STAGE extends LogKey + case object STAGES extends LogKey case object STAGE_ATTEMPT extends LogKey case object STAGE_ID extends LogKey + case object STAGE_NAME extends LogKey case object START_INDEX extends LogKey case object STATEMENT_ID extends LogKey case object STATE_STORE_ID extends LogKey case object STATE_STORE_PROVIDER extends LogKey case object STATE_STORE_VERSION extends LogKey + case object STATS extends LogKey case object STATUS extends LogKey 
case object STDERR extends LogKey case object STOP_SITE_SHORT_FORM extends LogKey @@ -647,13 +684,18 @@ object LogKeys { case object TABLE_NAME extends LogKey case object TABLE_TYPE extends LogKey case object TABLE_TYPES extends LogKey + case object TAG extends LogKey case object TARGET_NUM_EXECUTOR extends LogKey case object TARGET_NUM_EXECUTOR_DELTA extends LogKey case object TARGET_PATH extends LogKey case object TASK_ATTEMPT_ID extends LogKey case object TASK_ID extends LogKey + case object TASK_LOCALITY extends LogKey case object TASK_NAME extends LogKey case object TASK_REQUIREMENTS extends LogKey + case object TASK_RESOURCE_ASSIGNMENTS extends LogKey + case object TASK_SET_ID extends LogKey + case object TASK_SET_MANAGER extends LogKey case object TASK_SET_NAME extends LogKey case object TASK_STATE extends LogKey case object TEMP_FILE extends LogKey @@ -685,6 +727,7 @@ object LogKeys { case object TOPIC_PARTITION_OFFSET_RANGE extends LogKey case object TOTAL extends LogKey case object TOTAL_EFFECTIVE_TIME extends LogKey + case object TOTAL_SIZE extends LogKey case object TOTAL_TIME extends LogKey case object TOTAL_TIME_READ extends LogKey case object TO_TIME extends LogKey diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6018c87b01224..c70576b8adc10 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2930,7 +2930,7 @@ object SparkContext extends Logging { log" constructor). This may indicate an error, since only one SparkContext should be" + log" running in this JVM (see SPARK-2243)." + log" The other SparkContext was created at:\n" + - log"${MDC(LogKeys.CONTEXT_CREATION_SITE, otherContextCreationSite)}" + log"${MDC(LogKeys.CREATION_SITE, otherContextCreationSite)}" logWarning(warnMsg) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 9c57269b28f47..263b1a233b808 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkConf import org.apache.spark.deploy.master.Master import org.apache.spark.deploy.worker.Worker -import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.{config, Logging, LogKeys, MDC} import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.Utils @@ -51,7 +51,8 @@ class LocalSparkCluster private ( private val workerDirs = ArrayBuffer[String]() def start(): Array[String] = { - logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") + logInfo(log"Starting a local Spark cluster with " + + log"${MDC(LogKeys.NUM_WORKERS, numWorkers)} workers.") // Disable REST server on Master in this mode unless otherwise specified val _conf = conf.clone() diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 2edd80db2637f..ca932ef5dc05c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -37,7 +37,7 @@ import org.apache.hadoop.security.token.{Token, TokenIdentifier} import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier import 
org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.internal.config.BUFFER_SIZE import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @@ -142,8 +142,9 @@ private[spark] class SparkHadoopUtil extends Logging { if (!new File(keytabFilename).exists()) { throw new SparkException(s"Keytab file: ${keytabFilename} does not exist") } else { - logInfo("Attempting to login to Kerberos " + - s"using principal: ${principalName} and keytab: ${keytabFilename}") + logInfo(log"Attempting to login to Kerberos using principal: " + + log"${MDC(LogKeys.PRINCIPAL, principalName)} and keytab: " + + log"${MDC(LogKeys.KEYTAB, keytabFilename)}") UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala index ec231610b8575..b34e5c408c3be 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala @@ -29,7 +29,7 @@ import org.apache.spark.SparkConf import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.internal.{Logging, MDC} +import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.resource.ResourceProfile import org.apache.spark.rpc._ @@ -105,7 +105,8 @@ private[spark] class StandaloneAppClient( if (registered.get) { return } - logInfo("Connecting to master " + masterAddress.toSparkURL + "...") + logInfo( + log"Connecting to master ${MDC(LogKeys.MASTER_URL, masterAddress.toSparkURL)}...") val masterRef = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME) masterRef.send(RegisterApplication(appDescription, self)) } catch { @@ -175,14 +176,16 @@ private[spark] class StandaloneAppClient( case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = s"$appId/$id" - logInfo("Executor added: %s on %s (%s) with %d core(s)".format(fullId, workerId, hostPort, - cores)) + logInfo(log"Executor added: ${MDC(LogKeys.EXECUTOR_ID, fullId)} on " + + log"${MDC(LogKeys.WORKER_ID, workerId)} (${MDC(LogKeys.HOST_PORT, hostPort)}) " + + log"with ${MDC(LogKeys.NUM_CORES, cores)} core(s)") listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus, workerHost) => val fullId = s"$appId/$id" val messageText = message.map(s => " (" + s + ")").getOrElse("") - logInfo("Executor updated: %s is now %s%s".format(fullId, state, messageText)) + logInfo(log"Executor updated: ${MDC(LogKeys.EXECUTOR_ID, fullId)} is now " + + log"${MDC(LogKeys.EXECUTOR_STATE, state)}${MDC(LogKeys.MESSAGE, messageText)}") if (ExecutorState.isFinished(state)) { listener.executorRemoved(fullId, message.getOrElse(""), exitStatus, workerHost) } else if (state == ExecutorState.DECOMMISSIONED) { @@ -191,11 +194,13 @@ private[spark] class StandaloneAppClient( } case WorkerRemoved(id, host, message) => - logInfo("Master removed worker %s: %s".format(id, message)) + logInfo(log"Master removed worker ${MDC(LogKeys.WORKER_ID, id)}: " + + log"${MDC(LogKeys.MESSAGE, message)}") listener.workerRemoved(id, host, message) case 
MasterChanged(masterRef, masterWebUiUrl) => - logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) + logInfo(log"Master has changed, new master is at " + + log"${MDC(LogKeys.MASTER_URL, masterRef.address.toSparkURL)}") master = Some(masterRef) alreadyDisconnected = false masterRef.send(MasterChangeAcknowledged(appId.get)) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala index 106acc9a79446..964b115865aef 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.master import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.internal.config.Deploy.{RECOVERY_COMPRESSION_CODEC, RECOVERY_DIRECTORY} import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.Serializer @@ -57,7 +57,7 @@ private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: val recoveryDir = conf.get(RECOVERY_DIRECTORY) def createPersistenceEngine(): PersistenceEngine = { - logInfo("Persisting recovery state to directory: " + recoveryDir) + logInfo(log"Persisting recovery state to directory: ${MDC(LogKeys.PATH, recoveryDir)}") val codec = conf.get(RECOVERY_COMPRESSION_CODEC).map(c => CompressionCodec.createCodec(conf, c)) new FileSystemPersistenceEngine(recoveryDir, serializer, codec) } @@ -76,7 +76,8 @@ private[master] class RocksDBRecoveryModeFactory(conf: SparkConf, serializer: Se def createPersistenceEngine(): PersistenceEngine = { val recoveryDir = conf.get(RECOVERY_DIRECTORY) - logInfo("Persisting recovery state to directory: " + recoveryDir) + logInfo(log"Persisting recovery state to directory: " + + log"${MDC(LogKeys.PATH, recoveryDir)}") new RocksDBPersistenceEngine(recoveryDir, serializer) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 247504f5ebbb9..4fb95033cecef 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -79,7 +79,7 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { * it to the user. Otherwise, report the error message provided by the server. */ def createSubmission(request: CreateSubmissionRequest): SubmitRestProtocolResponse = { - logInfo(s"Submitting a request to launch an application in $master.") + logInfo(log"Submitting a request to launch an application in ${MDC(MASTER_URL, master)}.") var handled: Boolean = false var response: SubmitRestProtocolResponse = null for (m <- masters if !handled) { @@ -109,7 +109,9 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { /** Request that the server kill the specified submission. 
*/ def killSubmission(submissionId: String): SubmitRestProtocolResponse = { - logInfo(s"Submitting a request to kill submission $submissionId in $master.") + logInfo(log"Submitting a request to kill submission " + + log"${MDC(SUBMISSION_ID, submissionId)} in " + + log"${MDC(MASTER_URL, master)}.") var handled: Boolean = false var response: SubmitRestProtocolResponse = null for (m <- masters if !handled) { @@ -138,7 +140,7 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { /** Request that the server kill all submissions. */ def killAllSubmissions(): SubmitRestProtocolResponse = { - logInfo(s"Submitting a request to kill all submissions in $master.") + logInfo(log"Submitting a request to kill all submissions in ${MDC(MASTER_URL, master)}.") var handled: Boolean = false var response: SubmitRestProtocolResponse = null for (m <- masters if !handled) { @@ -167,7 +169,7 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { /** Request that the server clears all submissions and applications. */ def clear(): SubmitRestProtocolResponse = { - logInfo(s"Submitting a request to clear $master.") + logInfo(log"Submitting a request to clear ${MDC(MASTER_URL, master)}.") var handled: Boolean = false var response: SubmitRestProtocolResponse = null for (m <- masters if !handled) { @@ -196,7 +198,7 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { /** Check the readiness of Master. */ def readyz(): SubmitRestProtocolResponse = { - logInfo(s"Submitting a request to check the status of $master.") + logInfo(log"Submitting a request to check the status of ${MDC(MASTER_URL, master)}.") var handled: Boolean = false var response: SubmitRestProtocolResponse = new ErrorResponse for (m <- masters if !handled) { @@ -227,7 +229,9 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { def requestSubmissionStatus( submissionId: String, quiet: Boolean = false): SubmitRestProtocolResponse = { - logInfo(s"Submitting a request for the status of submission $submissionId in $master.") + logInfo(log"Submitting a request for the status of submission " + + log"${MDC(SUBMISSION_ID, submissionId)} in " + + log"${MDC(MASTER_URL, master)}.") var handled: Boolean = false var response: SubmitRestProtocolResponse = null @@ -440,7 +444,8 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { if (submitResponse.success) { val submissionId = submitResponse.submissionId if (submissionId != null) { - logInfo(s"Submission successfully created as $submissionId. Polling submission state...") + logInfo(log"Submission successfully created as ${MDC(SUBMISSION_ID, submissionId)}. 
" + + log"Polling submission state...") pollSubmissionStatus(submissionId) } else { // should never happen @@ -470,13 +475,17 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { val exception = Option(statusResponse.message) // Log driver state, if present driverState match { - case Some(state) => logInfo(s"State of driver $submissionId is now $state.") + case Some(state) => + logInfo(log"State of driver ${MDC(SUBMISSION_ID, submissionId)} is now " + + log"${MDC(DRIVER_STATE, state)}.") case _ => logError(log"State of driver ${MDC(SUBMISSION_ID, submissionId)} was not found!") } // Log worker node, if present (workerId, workerHostPort) match { - case (Some(id), Some(hp)) => logInfo(s"Driver is running on worker $id at $hp.") + case (Some(id), Some(hp)) => + logInfo( + log"Driver is running on worker ${MDC(WORKER_ID, id)} at ${MDC(HOST_PORT, hp)}.") case _ => } // Log exception stack trace, if present @@ -490,7 +499,8 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { /** Log the response sent by the server in the REST application submission protocol. */ private def handleRestResponse(response: SubmitRestProtocolResponse): Unit = { - logInfo(s"Server responded with ${response.messageType}:\n${response.toJson}") + logInfo(log"Server responded with ${MDC(CLASS_NAME, response.messageType)}:\n" + + log"${MDC(RESULT, response.toJson)}") } /** Log an appropriate error if the response sent by the server is not of the expected type. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index d1190ca46c2a8..a3e7276fc83e1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -24,7 +24,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.{SecurityManager, SSLOptions} import org.apache.spark.deploy.Command -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.launcher.WorkerCommandBuilder import org.apache.spark.util.Utils @@ -120,7 +120,8 @@ object CommandUtils extends Logging { Utils.copyStream(in, out, true) } catch { case e: IOException => - logInfo("Redirection to " + file + " closed: " + e.getMessage) + logInfo(log"Redirection to ${MDC(LogKeys.FILE_NAME, file)} closed: " + + log"${MDC(LogKeys.ERROR, e.getMessage)}") } } }.start() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index bd98f19cdb605..8d0fb7a54f72a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -204,7 +204,7 @@ private[deploy] class ExecutorRunner( worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))) } catch { case interrupted: InterruptedException => - logInfo("Runner thread for executor " + fullId + " interrupted") + logInfo(log"Runner thread for executor ${MDC(EXECUTOR_ID, fullId)} interrupted") state = ExecutorState.KILLED killProcess(s"Runner thread for executor $fullId interrupted") case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index f030475131d24..7ff7974ab59f6 100755 --- 
a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -38,7 +38,6 @@ import org.apache.spark.deploy.StandaloneResourceUtils._ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.internal.{config, Logging, MDC} -import org.apache.spark.internal.LogKeys import org.apache.spark.internal.LogKeys._ import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.internal.config.UI._ @@ -74,8 +73,8 @@ private[deploy] class Worker( // If worker decommissioning is enabled register a handler on the configured signal to shutdown. if (conf.get(config.DECOMMISSION_ENABLED)) { val signal = conf.get(config.Worker.WORKER_DECOMMISSION_SIGNAL) - logInfo(s"Registering SIG$signal handler to trigger decommissioning.") - SignalUtils.register(signal, log"Failed to register SIG${MDC(LogKeys.SIGNAL, signal)} " + + logInfo(log"Registering SIG${MDC(SIGNAL, signal)} handler to trigger decommissioning.") + SignalUtils.register(signal, log"Failed to register SIG${MDC(SIGNAL, signal)} " + log"handler - disabling worker decommission feature.") { self.send(WorkerDecommissionSigReceived) true @@ -106,8 +105,12 @@ private[deploy] class Worker( private val INITIAL_REGISTRATION_RETRIES = conf.get(WORKER_INITIAL_REGISTRATION_RETRIES) private val TOTAL_REGISTRATION_RETRIES = conf.get(WORKER_MAX_REGISTRATION_RETRIES) if (INITIAL_REGISTRATION_RETRIES > TOTAL_REGISTRATION_RETRIES) { - logInfo(s"${WORKER_INITIAL_REGISTRATION_RETRIES.key} ($INITIAL_REGISTRATION_RETRIES) is " + - s"capped by ${WORKER_MAX_REGISTRATION_RETRIES.key} ($TOTAL_REGISTRATION_RETRIES)") + logInfo( + log"${MDC(CONFIG, WORKER_INITIAL_REGISTRATION_RETRIES.key)} " + + log"(${MDC(VALUE, INITIAL_REGISTRATION_RETRIES)}) is capped by " + + log"${MDC(CONFIG2, WORKER_MAX_REGISTRATION_RETRIES.key)} " + + log"(${MDC(MAX_ATTEMPTS, TOTAL_REGISTRATION_RETRIES)})" + ) } private val FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND = 0.500 private val REGISTRATION_RETRY_FUZZ_MULTIPLIER = { @@ -236,10 +239,11 @@ private[deploy] class Worker( override def onStart(): Unit = { assert(!registered) - logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( - host, port, cores, Utils.megabytesToString(memory))) - logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") - logInfo("Spark home: " + sparkHome) + logInfo(log"Starting Spark worker ${MDC(HOST, host)}:${MDC(PORT, port)} " + + log"with ${MDC(NUM_CORES, cores)} cores, " + + log"${MDC(MEMORY_SIZE, Utils.megabytesToString(memory))} RAM") + logInfo(log"Running Spark version ${MDC(SPARK_VERSION, org.apache.spark.SPARK_VERSION)}") + logInfo(log"Spark home: ${MDC(PATH, sparkHome)}") createWorkDir() startExternalShuffleService() setupWorkerResources() @@ -300,8 +304,9 @@ private[deploy] class Worker( master = Some(masterRef) connected = true if (reverseProxy) { - logInfo("WorkerWebUI is available at %s/proxy/%s".format( - activeMasterWebUiUrl.stripSuffix("/"), workerId)) + logInfo( + log"WorkerWebUI is available at ${MDC(WEB_URL, activeMasterWebUiUrl.stripSuffix("/"))}" + + log"/proxy/${MDC(WORKER_ID, workerId)}") // if reverseProxyUrl is not set, then we continue to generate relative URLs // starting with "/" throughout the UI and do not use activeMasterWebUiUrl val proxyUrl = conf.get(UI_REVERSE_PROXY_URL.key, "").stripSuffix("/") @@ -318,7 +323,7 @@ private[deploy] class Worker( registerMasterThreadPool.submit(new Runnable { override def 
run(): Unit = { try { - logInfo("Connecting to master " + masterAddress + "...") + logInfo(log"Connecting to master ${MDC(MASTER_URL, masterAddress)}...") val masterEndpoint = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME) sendRegisterMessageToMaster(masterEndpoint) } catch { @@ -342,7 +347,8 @@ private[deploy] class Worker( if (registered) { cancelLastRegistrationRetry() } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) { - logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)") + logInfo(log"Retrying connection to master (attempt # " + + log"${MDC(NUM_ATTEMPT, connectionAttemptCount)})") /** * Re-register with the active master this worker has been communicating with. If there * is none, then it means this worker is still bootstrapping and hasn't established a @@ -376,7 +382,7 @@ private[deploy] class Worker( registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable { override def run(): Unit = { try { - logInfo("Connecting to master " + masterAddress + "...") + logInfo(log"Connecting to master ${MDC(MASTER_URL, masterAddress)}...") val masterEndpoint = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME) sendRegisterMessageToMaster(masterEndpoint) } catch { @@ -483,7 +489,7 @@ private[deploy] class Worker( log"${MDC(MASTER_URL, preferredMasterAddress)}") } - logInfo(s"Successfully registered with master $preferredMasterAddress") + logInfo(log"Successfully registered with master ${MDC(MASTER_URL, preferredMasterAddress)}") registered = true changeMaster(masterRef, masterWebUiUrl, masterAddress) forwardMessageScheduler.scheduleAtFixedRate( @@ -491,7 +497,8 @@ private[deploy] class Worker( 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS) if (CLEANUP_ENABLED) { logInfo( - s"Worker cleanup enabled; old application directories will be deleted in: $workDir") + log"Worker cleanup enabled; old application directories will be deleted in: " + + log"${MDC(PATH, workDir)}") forwardMessageScheduler.scheduleAtFixedRate( () => Utils.tryLogNonFatalError { self.send(WorkDirCleanup) }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) @@ -539,7 +546,7 @@ private[deploy] class Worker( dir.isDirectory && !isAppStillRunning && !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECONDS) }.foreach { dir => - logInfo(s"Removing directory: ${dir.getPath}") + logInfo(log"Removing directory: ${MDC(PATH, dir.getPath)}") Utils.deleteRecursively(dir) // Remove some registeredExecutors information of DB in external shuffle service when @@ -562,7 +569,8 @@ private[deploy] class Worker( } case MasterChanged(masterRef, masterWebUiUrl) => - logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) + logInfo(log"Master has changed, new master is at " + + log"${MDC(MASTER_URL, masterRef.address.toSparkURL)}") changeMaster(masterRef, masterWebUiUrl, masterRef.address) val executorResponses = executors.values.map { e => @@ -575,7 +583,8 @@ private[deploy] class Worker( workerId, executorResponses.toList, driverResponses.toSeq)) case ReconnectWorker(masterUrl) => - logInfo(s"Master with url $masterUrl requested this worker to reconnect.") + logInfo( + log"Master with url ${MDC(MASTER_URL, masterUrl)} requested this worker to reconnect.") registerWithMaster() case LaunchExecutor(masterUrl, appId, execId, rpId, appDesc, cores_, memory_, resources_) => @@ -586,7 +595,8 @@ private[deploy] class Worker( logWarning("Asked to launch an executor while decommissioned. 
Not launching executor.") } else { try { - logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) + logInfo(log"Asked to launch executor ${MDC(APP_ID, appId)}/${MDC(EXECUTOR_ID, execId)}" + + log" for ${MDC(APP_DESC, appDesc.name)}") // Create the executor's working directory val executorDir = new File(workDir, appId + "/" + execId) @@ -645,8 +655,8 @@ private[deploy] class Worker( } catch { case e: Exception => logError( - log"Failed to launch executor ${MDC(APP_ID, appId)}/${MDC(EXECUTOR_ID, execId)} " + - log"for ${MDC(APP_DESC, appDesc.name)}.", e) + log"Failed to launch executor ${MDC(APP_ID, appId)}/" + + log"${MDC(EXECUTOR_ID, execId)} for ${MDC(APP_DESC, appDesc.name)}.", e) if (executors.contains(appId + "/" + execId)) { executors(appId + "/" + execId).kill() executors -= appId + "/" + execId @@ -667,15 +677,15 @@ private[deploy] class Worker( val fullId = appId + "/" + execId executors.get(fullId) match { case Some(executor) => - logInfo("Asked to kill executor " + fullId) + logInfo(log"Asked to kill executor ${MDC(EXECUTOR_ID, fullId)}") executor.kill() case None => - logInfo("Asked to kill unknown executor " + fullId) + logInfo(log"Asked to kill unknown executor ${MDC(EXECUTOR_ID, fullId)}") } } case LaunchDriver(driverId, driverDesc, resources_) => - logInfo(s"Asked to launch driver $driverId") + logInfo(log"Asked to launch driver ${MDC(DRIVER_ID, driverId)}") val driver = new DriverRunner( conf, driverId, @@ -695,7 +705,7 @@ private[deploy] class Worker( addResourcesUsed(resources_) case KillDriver(driverId) => - logInfo(s"Asked to kill driver $driverId") + logInfo(log"Asked to kill driver ${MDC(DRIVER_ID, driverId)}") drivers.get(driverId) match { case Some(runner) => runner.kill() @@ -735,7 +745,7 @@ private[deploy] class Worker( override def onDisconnected(remoteAddress: RpcAddress): Unit = { if (master.exists(_.address == remoteAddress) || masterAddressToConnect.contains(remoteAddress)) { - logInfo(s"$remoteAddress Disassociated !") + logInfo(log"${MDC(REMOTE_ADDRESS, remoteAddress)} Disassociated !") masterDisconnected() } } @@ -753,7 +763,7 @@ private[deploy] class Worker( try { appDirectories.remove(id).foreach { dirList => concurrent.Future { - logInfo(s"Cleaning up local directories for application $id") + logInfo(log"Cleaning up local directories for application ${MDC(APP_ID, id)}") dirList.foreach { dir => Utils.deleteRecursively(new File(dir)) } @@ -874,7 +884,7 @@ private[deploy] class Worker( private[deploy] def decommissionSelf(): Unit = { if (conf.get(config.DECOMMISSION_ENABLED) && !decommissioned) { decommissioned = true - logInfo(s"Decommission worker $workerId.") + logInfo(log"Decommission worker ${MDC(WORKER_ID, workerId)}.") } else if (decommissioned) { logWarning(log"Worker ${MDC(WORKER_ID, workerId)} already started decommissioning.") } else { @@ -898,10 +908,10 @@ private[deploy] class Worker( logWarning(log"Driver ${MDC(DRIVER_ID, driverId)} " + log"exited successfully while master is disconnected.") case _ => - logInfo(s"Driver $driverId exited successfully") + logInfo(log"Driver ${MDC(DRIVER_ID, driverId)} exited successfully") } case DriverState.KILLED => - logInfo(s"Driver $driverId was killed by user") + logInfo(log"Driver ${MDC(DRIVER_ID, driverId)} was killed by user") case _ => logDebug(s"Driver $driverId changed state to $state") } @@ -921,13 +931,22 @@ private[deploy] class Worker( if (ExecutorState.isFinished(state)) { val appId = executorStateChanged.appId val fullId = appId + "/" + 
executorStateChanged.execId - val message = executorStateChanged.message - val exitStatus = executorStateChanged.exitStatus + val message = executorStateChanged.message match { + case Some(msg) => + log" message ${MDC(MESSAGE, msg)}" + case None => + log"" + } + val exitStatus = executorStateChanged.exitStatus match { + case Some(status) => + log" exitStatus ${MDC(EXIT_CODE, status)}" + case None => + log"" + } executors.get(fullId) match { case Some(executor) => - logInfo("Executor " + fullId + " finished with state " + state + - message.map(" message " + _).getOrElse("") + - exitStatus.map(" exitStatus " + _).getOrElse("")) + logInfo(log"Executor ${MDC(EXECUTOR_ID, fullId)} finished with state " + + log"${MDC(EXECUTOR_STATE, state)}" + message + exitStatus) executors -= fullId finishedExecutors(fullId) = executor trimFinishedExecutorsIfNecessary() @@ -939,9 +958,8 @@ private[deploy] class Worker( shuffleService.executorRemoved(executorStateChanged.execId.toString, appId) } case None => - logInfo("Unknown Executor " + fullId + " finished with state " + state + - message.map(" message " + _).getOrElse("") + - exitStatus.map(" exitStatus " + _).getOrElse("")) + logInfo(log"Unknown Executor ${MDC(EXECUTOR_ID, fullId)} finished with state " + + log"${MDC(EXECUTOR_STATE, state)}" + message + exitStatus) } maybeCleanupApplication(appId) } diff --git a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala index 0158dd6ba7757..7098961d1649a 100644 --- a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala @@ -156,7 +156,7 @@ private[memory] class ExecutionMemoryPool( val memoryToFree = if (curMem < numBytes) { logWarning( log"Internal error: release called on ${MDC(NUM_BYTES, numBytes)} " + - log"bytes but task only has ${MDC(CURRENT_MEMORY_BYTES, curMem)} bytes " + + log"bytes but task only has ${MDC(CURRENT_MEMORY_SIZE, curMem)} bytes " + log"of memory from the ${MDC(MEMORY_POOL_NAME, poolName)} pool") curMem } else { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index ff899a2e56dc0..cbfce378879ec 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -270,7 +270,7 @@ class HadoopRDD[K, V]( val iter = new NextIterator[(K, V)] { private val split = theSplit.asInstanceOf[HadoopPartition] - logInfo("Input split: " + split.inputSplit) + logInfo(log"Input split: ${MDC(INPUT_SPLIT, split.inputSplit)}") private val jobConf = getJobConf() private val inputMetrics = context.taskMetrics().inputMetrics diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index a87d02287302d..3a1ce4bd1dfde 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -197,7 +197,7 @@ class NewHadoopRDD[K, V]( override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { val iter = new Iterator[(K, V)] { private val split = theSplit.asInstanceOf[NewHadoopPartition] - logInfo("Input split: " + split.serializableHadoopSplit) + logInfo(log"Input split: ${MDC(INPUT_SPLIT, split.serializableHadoopSplit)}") private val conf = getConf private val inputMetrics = context.taskMetrics().inputMetrics 
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 1f44f7e782c48..ac93abf3fe7a0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -211,7 +211,7 @@ abstract class RDD[T: ClassTag]( * @return This RDD. */ def unpersist(blocking: Boolean = false): this.type = { - logInfo(s"Removing RDD $id from persistence list") + logInfo(log"Removing RDD ${MDC(RDD_ID, id)} from persistence list") sc.unpersistRDD(id, blocking) storageLevel = StorageLevel.NONE this diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 2f6ff0acdf024..118660ef69476 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapred.SequenceFileOutputFormat -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{Logging, LogKeys, MDC} /** * Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile, @@ -58,8 +58,9 @@ class SequenceFileRDDFunctions[K: IsWritable: ClassTag, V: IsWritable: ClassTag] val convertKey = self.keyClass != _keyWritableClass val convertValue = self.valueClass != _valueWritableClass - logInfo("Saving as sequence file of type " + - s"(${_keyWritableClass.getSimpleName},${_valueWritableClass.getSimpleName})" ) + logInfo(log"Saving as sequence file of type " + + log"(${MDC(LogKeys.KEY, _keyWritableClass.getSimpleName)}," + + log"${MDC(LogKeys.VALUE, _valueWritableClass.getSimpleName)})") val format = classOf[SequenceFileOutputFormat[Writable, Writable]] val jobConf = new JobConf(self.context.hadoopConfiguration) if (!convertKey && !convertValue) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index cc9ae5eb1ebe5..7c096dd110e50 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -535,8 +535,9 @@ private[spark] class DAGScheduler( if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of partitions is unknown - logInfo(s"Registering RDD ${rdd.id} (${rdd.getCreationSite}) as input to " + - s"shuffle ${shuffleDep.shuffleId}") + logInfo(log"Registering RDD ${MDC(RDD_ID, rdd.id)} " + + log"(${MDC(CREATION_SITE, rdd.getCreationSite)}) as input to " + + log"shuffle ${MDC(SHUFFLE_ID, shuffleDep.shuffleId)}") mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length, shuffleDep.partitioner.numPartitions) } @@ -1097,7 +1098,7 @@ private[spark] class DAGScheduler( * Cancel a job that is running or waiting in the queue. 
*/ def cancelJob(jobId: Int, reason: Option[String]): Unit = { - logInfo("Asked to cancel job " + jobId) + logInfo(log"Asked to cancel job ${MDC(JOB_ID, jobId)}") eventProcessLoop.post(JobCancelled(jobId, reason)) } @@ -1106,7 +1107,8 @@ private[spark] class DAGScheduler( * @param cancelFutureJobs if true, future submitted jobs in this job group will be cancelled */ def cancelJobGroup(groupId: String, cancelFutureJobs: Boolean = false): Unit = { - logInfo(s"Asked to cancel job group $groupId with cancelFutureJobs=$cancelFutureJobs") + logInfo(log"Asked to cancel job group ${MDC(GROUP_ID, groupId)} with " + + log"cancelFutureJobs=${MDC(CANCEL_FUTURE_JOBS, cancelFutureJobs)}") eventProcessLoop.post(JobGroupCancelled(groupId, cancelFutureJobs)) } @@ -1115,7 +1117,7 @@ private[spark] class DAGScheduler( */ def cancelJobsWithTag(tag: String): Unit = { SparkContext.throwIfInvalidTag(tag) - logInfo(s"Asked to cancel jobs with tag $tag") + logInfo(log"Asked to cancel jobs with tag ${MDC(TAG, tag)}") eventProcessLoop.post(JobTagCancelled(tag)) } @@ -1209,7 +1211,7 @@ private[spark] class DAGScheduler( // If cancelFutureJobs is true, store the cancelled job group id into internal states. // When a job belonging to this job group is submitted, skip running it. if (cancelFutureJobs) { - logInfo(s"Add job group $groupId into cancelled job groups") + logInfo(log"Add job group ${MDC(GROUP_ID, groupId)} into cancelled job groups") cancelledJobGroups.add(groupId) } @@ -1314,7 +1316,7 @@ private[spark] class DAGScheduler( if (jobGroupIdOpt.exists(cancelledJobGroups.contains(_))) { listener.jobFailed( SparkCoreErrors.sparkJobCancelledAsPartOfJobGroupError(jobId, jobGroupIdOpt.get)) - logInfo(s"Skip running a job that belongs to the cancelled job group ${jobGroupIdOpt.get}.") + logInfo(log"Skip running a job that belongs to the cancelled job group ${MDC(GROUP_ID, jobGroupIdOpt.get)}") return } @@ -1362,11 +1364,13 @@ private[spark] class DAGScheduler( val job = new ActiveJob(jobId, finalStage, callSite, listener, artifacts, properties) clearCacheLocs() - logInfo("Got job %s (%s) with %d output partitions".format( - job.jobId, callSite.shortForm, partitions.length)) - logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") - logInfo("Parents of final stage: " + finalStage.parents) - logInfo("Missing parents: " + getMissingParentStages(finalStage)) + logInfo( + log"Got job ${MDC(JOB_ID, job.jobId)} (${MDC(CALL_SITE_SHORT_FORM, callSite.shortForm)}) " + + log"with ${MDC(NUM_PARTITIONS, partitions.length)} output partitions") + logInfo(log"Final stage: ${MDC(STAGE_ID, finalStage)} " + + log"(${MDC(STAGE_NAME, finalStage.name)})") + logInfo(log"Parents of final stage: ${MDC(STAGE_ID, finalStage.parents)}") + logInfo(log"Missing parents: ${MDC(MISSING_PARENT_STAGES, getMissingParentStages(finalStage))}") val jobSubmissionTime = clock.getTimeMillis() jobIdToActiveJob(jobId) = job @@ -1403,11 +1407,13 @@ private[spark] class DAGScheduler( val job = new ActiveJob(jobId, finalStage, callSite, listener, artifacts, properties) clearCacheLocs() - logInfo("Got map stage job %s (%s) with %d output partitions".format( - jobId, callSite.shortForm, dependency.rdd.partitions.length)) - logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") - logInfo("Parents of final stage: " + finalStage.parents) - logInfo("Missing parents: " + getMissingParentStages(finalStage)) + logInfo(log"Got map stage job ${MDC(JOB_ID, jobId)} " + + log"(${MDC(CALL_SITE_SHORT_FORM, callSite.shortForm)}) with " + + 
log"${MDC(NUM_PARTITIONS, dependency.rdd.partitions.length)} output partitions") + logInfo(log"Final stage: ${MDC(STAGE_ID, finalStage)} " + + log"(${MDC(STAGE_NAME, finalStage.name)})") + logInfo(log"Parents of final stage: ${MDC(PARENT_STAGES, finalStage.parents.toString)}") + logInfo(log"Missing parents: ${MDC(MISSING_PARENT_STAGES, getMissingParentStages(finalStage))}") val jobSubmissionTime = clock.getTimeMillis() jobIdToActiveJob(jobId) = job @@ -1444,7 +1450,8 @@ private[spark] class DAGScheduler( val missing = getMissingParentStages(stage).sortBy(_.id) logDebug("missing: " + missing) if (missing.isEmpty) { - logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") + logInfo(log"Submitting ${MDC(STAGE_ID, stage)} (${MDC(RDD_ID, stage.rdd)}), " + + log"which has no missing parents") submitMissingTasks(stage, jobId.get) } else { for (parent <- missing) { @@ -1495,13 +1502,16 @@ private[spark] class DAGScheduler( val shuffleId = stage.shuffleDep.shuffleId val shuffleMergeId = stage.shuffleDep.shuffleMergeId if (stage.shuffleDep.shuffleMergeEnabled) { - logInfo(s"Shuffle merge enabled before starting the stage for $stage with shuffle" + - s" $shuffleId and shuffle merge $shuffleMergeId with" + - s" ${stage.shuffleDep.getMergerLocs.size} merger locations") + logInfo(log"Shuffle merge enabled before starting the stage for ${MDC(STAGE_ID, stage)}" + + log" with shuffle ${MDC(SHUFFLE_ID, shuffleId)} and shuffle merge" + + log" ${MDC(SHUFFLE_MERGE_ID, shuffleMergeId)} with" + + log" ${MDC(NUM_MERGER_LOCATIONS, stage.shuffleDep.getMergerLocs.size.toString)} merger locations") } else { - logInfo(s"Shuffle merge disabled for $stage with shuffle $shuffleId" + - s" and shuffle merge $shuffleMergeId, but can get enabled later adaptively" + - s" once enough mergers are available") + logInfo(log"Shuffle merge disabled for ${MDC(STAGE_ID, stage)} with " + + log"shuffle ${MDC(SHUFFLE_ID, shuffleId)} and " + + log"shuffle merge ${MDC(SHUFFLE_MERGE_ID, shuffleMergeId)}, " + + log"but can get enabled later adaptively once enough " + + log"mergers are available") } } @@ -1558,8 +1568,8 @@ private[spark] class DAGScheduler( // merger locations but the corresponding shuffle map stage did not complete // successfully, we would still enable push for its retry. 
s.shuffleDep.setShuffleMergeAllowed(false) - logInfo(s"Push-based shuffle disabled for $stage (${stage.name}) since it" + - " is already shuffle merge finalized") + logInfo(log"Push-based shuffle disabled for ${MDC(STAGE_ID, stage)} " + + log"(${MDC(STAGE_NAME, stage.name)}) since it is already shuffle merge finalized") } } case s: ResultStage => @@ -1681,8 +1691,9 @@ private[spark] class DAGScheduler( } if (tasks.nonEmpty) { - logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " + - s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})") + logInfo(log"Submitting ${MDC(NUM_TASKS, tasks.size)} missing tasks from " + + log"${MDC(STAGE_ID, stage)} (${MDC(RDD_ID, stage.rdd)}) (first 15 tasks are " + + log"for partitions ${MDC(PARTITION_IDS, tasks.take(15).map(_.partitionId))})") val shuffleId = stage match { case s: ShuffleMapStage => Some(s.shuffleDep.shuffleId) case _: ResultStage => None @@ -1751,9 +1762,10 @@ private[spark] class DAGScheduler( case Some(accum) => accum.getClass.getName case None => "Unknown class" } - logError( - log"Failed to update accumulator ${MDC(ACCUMULATOR_ID, id)} (${MDC(CLASS_NAME, accumClassName)}) " + - log"for task ${MDC(PARTITION_ID, task.partitionId)}", e) + logError( + log"Failed to update accumulator ${MDC(ACCUMULATOR_ID, id)} " + + log"(${MDC(CLASS_NAME, accumClassName)}) for task " + + log"${MDC(PARTITION_ID, task.partitionId)}", e) } } } @@ -1926,8 +1938,8 @@ private[spark] class DAGScheduler( try { // killAllTaskAttempts will fail if a SchedulerBackend does not implement // killTask. - logInfo(s"Job ${job.jobId} is finished. Cancelling potential speculative " + - "or zombie tasks for this job") + logInfo(log"Job ${MDC(JOB_ID, job.jobId)} is finished. Cancelling " + + log"potential speculative or zombie tasks for this job") // ResultStage is only used by this job. It's safe to kill speculative or // zombie tasks in this stage. 
taskScheduler.killAllTaskAttempts( @@ -1954,7 +1966,7 @@ private[spark] class DAGScheduler( } } case None => - logInfo("Ignoring result from " + rt + " because its job has finished") + logInfo(log"Ignoring result from ${MDC(RESULT, rt)} because its job has finished") } case smt: ShuffleMapTask => @@ -1969,7 +1981,8 @@ private[spark] class DAGScheduler( logDebug("ShuffleMapTask finished on " + execId) if (executorFailureEpoch.contains(execId) && smt.epoch <= executorFailureEpoch(execId)) { - logInfo(s"Ignoring possibly bogus $smt completion from executor $execId") + logInfo(log"Ignoring possibly bogus ${MDC(STAGE_ID, smt)} completion from " + + log"executor ${MDC(EXECUTOR_ID, execId)}") } else { // The epoch of the task is acceptable (i.e., the task was launched after the most // recent failure we're aware of for the executor), so mark the task's output as @@ -1978,7 +1991,7 @@ private[spark] class DAGScheduler( shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) } } else { - logInfo(s"Ignoring $smt completion from an older attempt of indeterminate stage") + logInfo(log"Ignoring ${MDC(TASK_NAME, smt)} completion from an older attempt of indeterminate stage") } if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) { @@ -1996,17 +2009,22 @@ private[spark] class DAGScheduler( val mapStage = shuffleIdToMapStage(shuffleId) if (failedStage.latestInfo.attemptNumber() != task.stageAttemptId) { - logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + - s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + - s"(attempt ${failedStage.latestInfo.attemptNumber()}) running") + logInfo(log"Ignoring fetch failure from " + + log"${MDC(TASK_ID, task)} as it's from " + + log"${MDC(STAGE_ID, failedStage)} attempt " + + log"${MDC(STAGE_ATTEMPT, task.stageAttemptId)} and there is a more recent attempt for " + + log"that stage (attempt " + + log"${MDC(NUM_ATTEMPT, failedStage.latestInfo.attemptNumber())}) running") } else { val ignoreStageFailure = ignoreDecommissionFetchFailure && isExecutorDecommissioningOrDecommissioned(taskScheduler, bmAddress) if (ignoreStageFailure) { - logInfo(s"Ignoring fetch failure from $task of $failedStage attempt " + - s"${task.stageAttemptId} when count ${config.STAGE_MAX_CONSECUTIVE_ATTEMPTS.key} " + - s"as executor ${bmAddress.executorId} is decommissioned and " + - s" ${config.STAGE_IGNORE_DECOMMISSION_FETCH_FAILURE.key}=true") + logInfo(log"Ignoring fetch failure from ${MDC(TASK_ID, task)} of " + + log"${MDC(STAGE, failedStage)} attempt " + + log"${MDC(STAGE_ATTEMPT, task.stageAttemptId)} when count " + + log"${MDC(MAX_ATTEMPTS, config.STAGE_MAX_CONSECUTIVE_ATTEMPTS.key)} " + + log"as executor ${MDC(EXECUTOR_ID, bmAddress.executorId)} is decommissioned and " + + log"${MDC(CONFIG, config.STAGE_IGNORE_DECOMMISSION_FETCH_FAILURE.key)}=true") } else { failedStage.failedAttemptIds.add(task.stageAttemptId) } @@ -2019,8 +2037,10 @@ private[spark] class DAGScheduler( // multiple tasks running concurrently on different executors). In that case, it is // possible the fetch failure has already been handled by the scheduler. 
if (runningStages.contains(failedStage)) { - logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + - s"due to a fetch failure from $mapStage (${mapStage.name})") + logInfo(log"Marking ${MDC(FAILED_STAGE, failedStage)} " + + log"(${MDC(FAILED_STAGE_NAME, failedStage.name)}) as failed " + + log"due to a fetch failure from ${MDC(STAGE, mapStage)} " + + log"(${MDC(STAGE_NAME, mapStage.name)})") markStageAsFinished(failedStage, errorMessage = Some(failureMessage), willRetry = !shouldAbortStage) } else { @@ -2148,9 +2168,9 @@ private[spark] class DAGScheduler( case _ => } - logInfo(s"The shuffle map stage $mapStage with indeterminate output was failed, " + - s"we will roll back and rerun below stages which include itself and all its " + - s"indeterminate child stages: $rollingBackStages") + logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output was failed, " + + log"we will roll back and rerun below stages which include itself and all its " + + log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}") } // We expect one executor failure to trigger many FetchFailures in rapid succession, @@ -2162,9 +2182,9 @@ private[spark] class DAGScheduler( // producing a resubmit for each failed stage makes debugging and logging a little // simpler while not producing an overwhelming number of scheduler events. logInfo( - s"Resubmitting $mapStage (${mapStage.name}) and " + - s"$failedStage (${failedStage.name}) due to fetch failure" - ) + log"Resubmitting ${MDC(STAGE, mapStage)} " + + log"(${MDC(STAGE_NAME, mapStage.name)}) and ${MDC(FAILED_STAGE, failedStage)} " + + log"(${MDC(FAILED_STAGE_NAME, failedStage.name)}) due to fetch failure") messageScheduler.schedule( new Runnable { override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) @@ -2223,12 +2243,13 @@ private[spark] class DAGScheduler( // Always fail the current stage and retry all the tasks when a barrier task fail. 
val failedStage = stageIdToStage(task.stageId) if (failedStage.latestInfo.attemptNumber() != task.stageAttemptId) { - logInfo(s"Ignoring task failure from $task as it's from $failedStage attempt" + - s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + - s"(attempt ${failedStage.latestInfo.attemptNumber()}) running") + logInfo(log"Ignoring task failure from ${MDC(TASK_ID, task)} as it's from " + + log"${MDC(FAILED_STAGE, failedStage)} attempt ${MDC(STAGE_ATTEMPT, task.stageAttemptId)} " + + log"and there is a more recent attempt for that stage (attempt " + + log"${MDC(NUM_ATTEMPT, failedStage.latestInfo.attemptNumber())}) running") } else { - logInfo(s"Marking $failedStage (${failedStage.name}) as failed due to a barrier task " + - "failed.") + logInfo(log"Marking ${MDC(STAGE_ID, failedStage.id)} (${MDC(STAGE_NAME, failedStage.name)}) " + + log"as failed due to a barrier task failed.") val message = s"Stage failed because barrier task $task finished unsuccessfully.\n" + failure.toErrorString try { @@ -2283,8 +2304,8 @@ private[spark] class DAGScheduler( val noResubmitEnqueued = !failedStages.contains(failedStage) failedStages += failedStage if (noResubmitEnqueued) { - logInfo(s"Resubmitting $failedStage (${failedStage.name}) due to barrier stage " + - "failure.") + logInfo(log"Resubmitting ${MDC(FAILED_STAGE, failedStage)} " + + log"(${MDC(FAILED_STAGE_NAME, failedStage.name)}) due to barrier stage failure.") messageScheduler.schedule(new Runnable { override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) @@ -2361,8 +2382,8 @@ private[spark] class DAGScheduler( // delay should be 0 and registerMergeResults should be true. assert(delay == 0 && registerMergeResults) if (task.getDelay(TimeUnit.NANOSECONDS) > 0 && task.cancel(false)) { - logInfo(s"$stage (${stage.name}) scheduled for finalizing shuffle merge immediately " + - s"after cancelling previously scheduled task.") + logInfo(log"${MDC(STAGE, stage)} (${MDC(STAGE_NAME, stage.name)}) scheduled " + + log"for finalizing shuffle merge immediately after cancelling previously scheduled task.") shuffleDep.setFinalizeTask( shuffleMergeFinalizeScheduler.schedule( new Runnable { @@ -2373,13 +2394,15 @@ private[spark] class DAGScheduler( ) ) } else { - logInfo(s"$stage (${stage.name}) existing scheduled task for finalizing shuffle merge" + - s"would either be in-progress or finished. No need to schedule shuffle merge" + - s" finalization again.") + logInfo( + log"${MDC(STAGE, stage)} (${MDC(STAGE_NAME, stage.name)}) existing scheduled task " + + log"for finalizing shuffle merge would either be in-progress or finished. " + + log"No need to schedule shuffle merge finalization again.") } case None => // If no previous finalization task is scheduled, schedule the finalization task. 
- logInfo(s"$stage (${stage.name}) scheduled for finalizing shuffle merge in $delay s") + logInfo(log"${MDC(STAGE, stage)} (${MDC(STAGE_NAME, stage.name)}) scheduled for " + + log"finalizing shuffle merge in ${MDC(DELAY, delay * 1000L)} ms") shuffleDep.setFinalizeTask( shuffleMergeFinalizeScheduler.schedule( new Runnable { @@ -2408,8 +2431,9 @@ private[spark] class DAGScheduler( private[scheduler] def finalizeShuffleMerge( stage: ShuffleMapStage, registerMergeResults: Boolean = true): Unit = { - logInfo(s"$stage (${stage.name}) finalizing the shuffle merge with registering merge " + - s"results set to $registerMergeResults") + logInfo( + log"${MDC(STAGE, stage)} (${MDC(STAGE_NAME, stage.name)}) finalizing the shuffle merge with" + + log" registering merge results set to ${MDC(REGISTER_MERGE_RESULTS, registerMergeResults)}") val shuffleId = stage.shuffleDep.shuffleId val shuffleMergeId = stage.shuffleDep.shuffleMergeId val numMergers = stage.shuffleDep.getMergerLocs.length @@ -2479,8 +2503,9 @@ private[spark] class DAGScheduler( } catch { case _: TimeoutException => timedOut = true - logInfo(s"Timed out on waiting for merge results from all " + - s"$numMergers mergers for shuffle $shuffleId") + logInfo(log"Timed out on waiting for merge results from all " + + log"${MDC(NUM_MERGERS, numMergers)} mergers for " + + log"shuffle ${MDC(SHUFFLE_ID, shuffleId)}") } finally { if (timedOut || !registerMergeResults) { cancelFinalizeShuffleMergeFutures(scheduledFutures, @@ -2511,9 +2536,9 @@ private[spark] class DAGScheduler( private def processShuffleMapStageCompletion(shuffleStage: ShuffleMapStage): Unit = { markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") - logInfo("running: " + runningStages) - logInfo("waiting: " + waitingStages) - logInfo("failed: " + failedStages) + logInfo(log"running: ${MDC(STAGES, runningStages)}") + logInfo(log"waiting: ${MDC(STAGES, waitingStages)}") + logInfo(log"failed: ${MDC(STAGES, failedStages)}") // This call to increment the epoch may not be strictly necessary, but it is retained // for now in order to minimize the changes in behavior from an earlier version of the @@ -2529,9 +2554,10 @@ private[spark] class DAGScheduler( if (!shuffleStage.isAvailable) { // Some tasks had failed; let's resubmit this shuffleStage. 
       // TODO: Lower-level scheduler should also deal with this
-      logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name +
-        ") because some of its tasks had failed: " +
-        shuffleStage.findMissingPartitions().mkString(", "))
+      logInfo(log"Resubmitting ${MDC(STAGE, shuffleStage)} " +
+        log"(${MDC(STAGE_NAME, shuffleStage.name)}) " +
+        log"because some of its tasks had failed: " +
+        log"${MDC(PARTITION_IDS, shuffleStage.findMissingPartitions().mkString(", "))}")
       submitStage(shuffleStage)
     } else {
       markMapStageJobsAsFinished(shuffleStage)
@@ -2603,7 +2629,7 @@ private[spark] class DAGScheduler(
   }
 
   private def handleResubmittedFailure(task: Task[_], stage: Stage): Unit = {
-    logInfo(s"Resubmitted $task, so marking it as still running.")
+    logInfo(log"Resubmitted ${MDC(TASK_ID, task)}, so marking it as still running.")
     stage match {
       case sms: ShuffleMapStage =>
         sms.pendingPartitions += task.partitionId
@@ -2679,7 +2705,7 @@ private[spark] class DAGScheduler(
     if (!isShuffleMerger &&
      (!executorFailureEpoch.contains(execId) || executorFailureEpoch(execId) < currentEpoch)) {
       executorFailureEpoch(execId) = currentEpoch
-      logInfo(s"Executor lost: $execId (epoch $currentEpoch)")
+      logInfo(log"Executor lost: ${MDC(EXECUTOR_ID, execId)} (epoch ${MDC(EPOCH, currentEpoch)})")
       if (pushBasedShuffleEnabled) {
         // Remove fetchFailed host in the shuffle push merger list for push based shuffle
         hostToUnregisterOutputs.foreach(
@@ -2703,10 +2729,12 @@ private[spark] class DAGScheduler(
     if (remove) {
       hostToUnregisterOutputs match {
         case Some(host) =>
-          logInfo(s"Shuffle files lost for host: $host (epoch $currentEpoch)")
+          logInfo(log"Shuffle files lost for host: ${MDC(HOST, host)} (epoch " +
+            log"${MDC(EPOCH, currentEpoch)})")
           mapOutputTracker.removeOutputsOnHost(host)
         case None =>
-          logInfo(s"Shuffle files lost for executor: $execId (epoch $currentEpoch)")
+          logInfo(log"Shuffle files lost for executor: ${MDC(EXECUTOR_ID, execId)} " +
+            log"(epoch ${MDC(EPOCH, currentEpoch)})")
           mapOutputTracker.removeOutputsOnExecutor(execId)
       }
     }
@@ -2728,7 +2756,8 @@ private[spark] class DAGScheduler(
       workerId: String,
       host: String,
       message: String): Unit = {
-    logInfo("Shuffle files lost for worker %s on host %s".format(workerId, host))
+    logInfo(log"Shuffle files lost for worker ${MDC(WORKER_ID, workerId)} " +
+      log"on host ${MDC(HOST, host)}")
     mapOutputTracker.removeOutputsOnHost(host)
     clearCacheLocs()
   }
@@ -2736,7 +2765,7 @@ private[spark] class DAGScheduler(
   private[scheduler] def handleExecutorAdded(execId: String, host: String): Unit = {
     // remove from executorFailureEpoch(execId) ?
if (executorFailureEpoch.contains(execId)) { - logInfo("Host added was in lost list earlier: " + host) + logInfo(log"Host added was in lost list earlier: ${MDC(HOST, host)}") executorFailureEpoch -= execId } shuffleFileLostEpoch -= execId @@ -2749,10 +2778,10 @@ private[spark] class DAGScheduler( }.foreach { case (_, stage: ShuffleMapStage) => configureShufflePushMergerLocations(stage) if (stage.shuffleDep.getMergerLocs.nonEmpty) { - logInfo(s"Shuffle merge enabled adaptively for $stage with shuffle" + - s" ${stage.shuffleDep.shuffleId} and shuffle merge" + - s" ${stage.shuffleDep.shuffleMergeId} with ${stage.shuffleDep.getMergerLocs.size}" + - s" merger locations") + logInfo(log"Shuffle merge enabled adaptively for ${MDC(STAGE, stage)} with shuffle" + + log" ${MDC(SHUFFLE_ID, stage.shuffleDep.shuffleId)} and shuffle merge" + + log" ${MDC(SHUFFLE_MERGE_ID, stage.shuffleDep.shuffleMergeId)} with " + + log"${MDC(NUM_MERGER_LOCATIONS, stage.shuffleDep.getMergerLocs.size)} merger locations") } } } @@ -2772,7 +2801,7 @@ private[spark] class DAGScheduler( handleJobCancellation(jobId, Option(reasonStr)) } case None => - logInfo("No active jobs to kill for Stage " + stageId) + logInfo(log"No active jobs to kill for Stage ${MDC(STAGE_ID, stageId)}") } } @@ -2795,11 +2824,12 @@ private[spark] class DAGScheduler( errorMessage: Option[String] = None, willRetry: Boolean = false): Unit = { val serviceTime = stage.latestInfo.submissionTime match { - case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) + case Some(t) => clock.getTimeMillis() - t case _ => "Unknown" } if (errorMessage.isEmpty) { - logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) + logInfo(log"${MDC(STAGE, stage)} (${MDC(STAGE_NAME, stage.name)}) " + + log"finished in ${MDC(TIME_UNITS, serviceTime)} ms") stage.latestInfo.completionTime = Some(clock.getTimeMillis()) // Clear failure count for this stage, now that it's succeeded. @@ -2809,7 +2839,8 @@ private[spark] class DAGScheduler( stage.clearFailures() } else { stage.latestInfo.stageFailed(errorMessage.get) - logInfo(s"$stage (${stage.name}) failed in $serviceTime s due to ${errorMessage.get}") + logInfo(log"${MDC(STAGE, stage)} (${MDC(STAGE_NAME, stage.name)}) failed in " + + log"${MDC(TIME_UNITS, serviceTime)} ms due to ${MDC(ERROR, errorMessage.get)}") } updateStageInfoForPushBasedShuffle(stage) if (!willRetry) { @@ -2855,7 +2886,8 @@ private[spark] class DAGScheduler( failJobAndIndependentStages(job, finalException) } if (dependentJobs.isEmpty) { - logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") + logInfo(log"Ignoring failure of ${MDC(FAILED_STAGE, failedStage)} because all jobs " + + log"depending on it are done") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/HealthTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/HealthTracker.scala index cecf5d498ac4b..1606072153906 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/HealthTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/HealthTracker.scala @@ -111,8 +111,8 @@ private[scheduler] class HealthTracker ( val execsToInclude = executorIdToExcludedStatus.filter(_._2.expiryTime < now).keys if (execsToInclude.nonEmpty) { // Include any executors that have been excluded longer than the excludeOnFailure timeout. 
- logInfo(s"Removing executors $execsToInclude from exclude list because the " + - s"the executors have reached the timed out") + logInfo(log"Removing executors ${MDC(EXECUTOR_IDS, execsToInclude)} from " + + log"exclude list because the executors have reached the timed out") execsToInclude.foreach { exec => val status = executorIdToExcludedStatus.remove(exec).get val failedExecsOnNode = nodeToExcludedExecs(status.node) @@ -128,8 +128,8 @@ private[scheduler] class HealthTracker ( val nodesToInclude = nodeIdToExcludedExpiryTime.filter(_._2 < now).keys if (nodesToInclude.nonEmpty) { // Include any nodes that have been excluded longer than the excludeOnFailure timeout. - logInfo(s"Removing nodes $nodesToInclude from exclude list because the " + - s"nodes have reached has timed out") + logInfo(log"Removing nodes ${MDC(NODES, nodesToInclude)} from exclude list because the " + + log"nodes have reached has timed out") nodesToInclude.foreach { node => nodeIdToExcludedExpiryTime.remove(node) // post both to keep backwards compatibility @@ -173,8 +173,8 @@ private[scheduler] class HealthTracker ( force = true) } case None => - logInfo(s"Not attempting to kill excluded executor id $exec " + - s"since allocation client is not defined.") + logInfo(log"Not attempting to kill excluded executor id ${MDC(EXECUTOR_ID, exec)}" + + log" since allocation client is not defined.") } } @@ -196,14 +196,15 @@ private[scheduler] class HealthTracker ( allocationClient match { case Some(a) => if (EXCLUDE_ON_FAILURE_DECOMMISSION_ENABLED) { - logInfo(s"Decommissioning all executors on excluded host $node " + - s"since ${config.EXCLUDE_ON_FAILURE_KILL_ENABLED.key} is set.") + logInfo(log"Decommissioning all executors on excluded host ${MDC(HOST, node)} " + + log"since ${MDC(CONFIG, config.EXCLUDE_ON_FAILURE_KILL_ENABLED.key)} " + + log"is set.") if (!a.decommissionExecutorsOnHost(node)) { logError(log"Decommissioning executors on ${MDC(HOST, node)} failed.") } } else { - logInfo(s"Killing all executors on excluded host $node " + - s"since ${config.EXCLUDE_ON_FAILURE_KILL_ENABLED.key} is set.") + logInfo(log"Killing all executors on excluded host ${MDC(HOST, node)} " + + log"since ${MDC(CONFIG, config.EXCLUDE_ON_FAILURE_KILL_ENABLED.key)} is set.") if (!a.killExecutorsOnHost(node)) { logError(log"Killing executors on node ${MDC(HOST, node)} failed.") } @@ -231,7 +232,8 @@ private[scheduler] class HealthTracker ( if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) { if (!nodeIdToExcludedExpiryTime.contains(host)) { - logInfo(s"excluding node $host due to fetch failure of external shuffle service") + logInfo(log"excluding node ${MDC(HOST, host)} due to fetch failure of " + + log"external shuffle service") nodeIdToExcludedExpiryTime.put(host, expiryTimeForNewExcludes) // post both to keep backwards compatibility @@ -242,7 +244,7 @@ private[scheduler] class HealthTracker ( updateNextExpiryTime() } } else if (!executorIdToExcludedStatus.contains(exec)) { - logInfo(s"Excluding executor $exec due to fetch failure") + logInfo(log"Excluding executor ${MDC(EXECUTOR_ID, exec)} due to fetch failure") executorIdToExcludedStatus.put(exec, ExcludedExecutor(host, expiryTimeForNewExcludes)) // We hardcoded number of failure tasks to 1 for fetch failure, because there's no @@ -280,8 +282,8 @@ private[scheduler] class HealthTracker ( // some of the logic around expiry times a little more confusing. But it also wouldn't be a // problem to re-exclude, with a later expiry time. 
if (newTotal >= MAX_FAILURES_PER_EXEC && !executorIdToExcludedStatus.contains(exec)) { - logInfo(s"Excluding executor id: $exec because it has $newTotal" + - s" task failures in successful task sets") + logInfo(log"Excluding executor id: ${MDC(EXECUTOR_ID, exec)} because it has " + + log"${MDC(TOTAL, newTotal)} task failures in successful task sets") val node = failuresInTaskSet.node executorIdToExcludedStatus.put(exec, ExcludedExecutor(node, expiryTimeForNewExcludes)) // post both to keep backwards compatibility @@ -299,8 +301,9 @@ private[scheduler] class HealthTracker ( // time. if (excludedExecsOnNode.size >= MAX_FAILED_EXEC_PER_NODE && !nodeIdToExcludedExpiryTime.contains(node)) { - logInfo(s"Excluding node $node because it has ${excludedExecsOnNode.size} " + - s"executors excluded: ${excludedExecsOnNode}") + logInfo(log"Excluding node ${MDC(HOST, node)} because it has " + + log"${MDC(NUM_EXECUTORS, excludedExecsOnNode.size)} executors " + + log"excluded: ${MDC(EXECUTOR_IDS, excludedExecsOnNode)}") nodeIdToExcludedExpiryTime.put(node, expiryTimeForNewExcludes) // post both to keep backwards compatibility listenerBus.post(SparkListenerNodeBlacklisted(now, node, excludedExecsOnNode.size)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index cd5d6b8f9c90d..d9020da4cdcb3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler import scala.collection.mutable import org.apache.spark._ -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.util.{RpcUtils, ThreadUtils} @@ -124,7 +124,7 @@ private[spark] class OutputCommitCoordinator( stageStates.get(stage) match { case Some(state) => require(state.authorizedCommitters.length == maxPartitionId + 1) - logInfo(s"Reusing state from previous attempt of stage $stage.") + logInfo(log"Reusing state from previous attempt of stage ${MDC(LogKeys.STAGE_ID, stage)}") case _ => stageStates(stage) = new StageState(maxPartitionId + 1) @@ -151,8 +151,10 @@ private[spark] class OutputCommitCoordinator( case Success => // The task output has been committed successfully case _: TaskCommitDenied => - logInfo(s"Task was denied committing, stage: $stage.$stageAttempt, " + - s"partition: $partition, attempt: $attemptNumber") + logInfo(log"Task was denied committing, stage: ${MDC(LogKeys.STAGE_ID, stage)}." + + log"${MDC(LogKeys.STAGE_ATTEMPT, stageAttempt)}, " + + log"partition: ${MDC(LogKeys.PARTITION_ID, partition)}, " + + log"attempt: ${MDC(LogKeys.NUM_ATTEMPT, attemptNumber)}") case _ => // Mark the attempt as failed to exclude from future commit protocol val taskId = TaskIdentifier(stageAttempt, attemptNumber) @@ -182,8 +184,10 @@ private[spark] class OutputCommitCoordinator( attemptNumber: Int): Boolean = synchronized { stageStates.get(stage) match { case Some(state) if attemptFailed(state, stageAttempt, partition, attemptNumber) => - logInfo(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + - s"task attempt $attemptNumber already marked as failed.") + logInfo(log"Commit denied for stage=${MDC(LogKeys.STAGE_ID, stage)}." 
+ + log"${MDC(LogKeys.STAGE_ATTEMPT, stageAttempt)}, partition=" + + log"${MDC(LogKeys.PARTITION_ID, partition)}: task attempt " + + log"${MDC(LogKeys.NUM_ATTEMPT, attemptNumber)} already marked as failed.") false case Some(state) => val existing = state.authorizedCommitters(partition) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index b5cc6261cea38..6f64dff3f39d6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -80,20 +80,23 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, sc: SparkContext fileData = schedulerAllocFile.map { f => val filePath = new Path(f) val fis = filePath.getFileSystem(sc.hadoopConfiguration).open(filePath) - logInfo(s"Creating Fair Scheduler pools from $f") + logInfo(log"Creating Fair Scheduler pools from ${MDC(LogKeys.FILE_NAME, f)}") Some((fis, f)) }.getOrElse { val is = Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE) if (is != null) { - logInfo(s"Creating Fair Scheduler pools from default file: $DEFAULT_SCHEDULER_FILE") + logInfo(log"Creating Fair Scheduler pools from default file: " + + log"${MDC(LogKeys.FILE_NAME, DEFAULT_SCHEDULER_FILE)}") Some((is, DEFAULT_SCHEDULER_FILE)) } else { val schedulingMode = SchedulingMode.withName(sc.conf.get(SCHEDULER_MODE)) rootPool.addSchedulable(new Pool( DEFAULT_POOL_NAME, schedulingMode, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) - logInfo("Fair scheduler configuration not found, created default pool: " + - "%s, schedulingMode: %s, minShare: %d, weight: %d".format( - DEFAULT_POOL_NAME, schedulingMode, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) + logInfo(log"Fair scheduler configuration not found, created default pool: " + + log"${MDC(LogKeys.DEFAULT_NAME, DEFAULT_POOL_NAME)}, " + + log"schedulingMode: ${MDC(LogKeys.SCHEDULING_MODE, schedulingMode)}, " + + log"minShare: ${MDC(LogKeys.MIN_SHARE, DEFAULT_MINIMUM_SHARE)}, " + + log"weight: ${MDC(LogKeys.WEIGHT, DEFAULT_WEIGHT)}") None } } @@ -122,8 +125,10 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, sc: SparkContext val pool = new Pool(DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) rootPool.addSchedulable(pool) - logInfo("Created default pool: %s, schedulingMode: %s, minShare: %d, weight: %d".format( - DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) + logInfo(log"Created default pool: ${MDC(LogKeys.POOL_NAME, DEFAULT_POOL_NAME)}, " + + log"schedulingMode: ${MDC(LogKeys.SCHEDULING_MODE, DEFAULT_SCHEDULING_MODE)}, " + + log"minShare: ${MDC(LogKeys.MIN_SHARE, DEFAULT_MINIMUM_SHARE)}, " + + log"weight: ${MDC(LogKeys.WEIGHT, DEFAULT_WEIGHT)}") } } @@ -142,8 +147,10 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, sc: SparkContext rootPool.addSchedulable(new Pool(poolName, schedulingMode, minShare, weight)) - logInfo("Created pool: %s, schedulingMode: %s, minShare: %d, weight: %d".format( - poolName, schedulingMode, minShare, weight)) + logInfo(log"Created pool: ${MDC(LogKeys.POOL_NAME, poolName)}, " + + log"schedulingMode: ${MDC(LogKeys.SCHEDULING_MODE, schedulingMode)}, " + + log"minShare: ${MDC(LogKeys.MIN_SHARE, minShare)}, " + + log"weight: ${MDC(LogKeys.WEIGHT, weight)}") } } @@ -159,7 +166,7 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, sc: SparkContext log"${MDC(XML_SCHEDULING_MODE, 
xmlSchedulingMode)} found in " + log"Fair Scheduler configuration file: ${MDC(FILE_NAME, fileName)}, using " + log"the default schedulingMode: " + - log"${MDC(LogKeys.DEFAULT_SCHEDULING_MODE, defaultValue)} for pool: " + + log"${MDC(LogKeys.SCHEDULING_MODE, defaultValue)} for pool: " + log"${MDC(POOL_NAME, poolName)}" try { if (SchedulingMode.withName(xmlSchedulingMode) != SchedulingMode.NONE) { @@ -215,11 +222,12 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, sc: SparkContext log"when that file doesn't contain ${MDC(POOL_NAME, poolName)}. " + log"Created ${MDC(CREATED_POOL_NAME, poolName)} with default " + log"configuration (schedulingMode: " + - log"${MDC(LogKeys.DEFAULT_SCHEDULING_MODE, DEFAULT_SCHEDULING_MODE)}, " + + log"${MDC(LogKeys.SCHEDULING_MODE, DEFAULT_SCHEDULING_MODE)}, " + log"minShare: ${MDC(MIN_SHARE, DEFAULT_MINIMUM_SHARE)}, " + log"weight: ${MDC(WEIGHT, DEFAULT_WEIGHT)}") } parentPool.addSchedulable(manager) - logInfo("Added task set " + manager.name + " tasks to pool " + poolName) + logInfo(log"Added task set ${MDC(LogKeys.TASK_SET_MANAGER, manager.name)} tasks to pool " + + log"${MDC(LogKeys.POOL_NAME, poolName)}") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala index 1f12b46412bc5..e46dde5561a26 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.util.{Distribution, Utils} @@ -46,7 +46,8 @@ class StatsReportListener extends SparkListener with Logging { override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { implicit val sc = stageCompleted - this.logInfo(s"Finished stage: ${getStatusDetail(stageCompleted.stageInfo)}") + this.logInfo( + log"Finished stage: ${MDC(LogKeys.STAGE, getStatusDetail(stageCompleted.stageInfo))}") showMillisDistribution("task runtime:", (info, _) => info.duration, taskInfoMetrics.toSeq) // Shuffle write @@ -111,9 +112,9 @@ private[spark] object StatsReportListener extends Logging { def showDistribution(heading: String, d: Distribution, formatNumber: Double => String): Unit = { val stats = d.statCounter val quantiles = d.getQuantiles(probabilities).map(formatNumber) - logInfo(heading + stats) + logInfo(log"${MDC(LogKeys.DESCRIPTION, heading)}${MDC(LogKeys.STATS, stats)}") logInfo(percentilesHeader) - logInfo("\t" + quantiles.mkString("\t")) + logInfo(log"\t" + log"${MDC(LogKeys.QUANTILES, quantiles.mkString("\t"))}") } def showDistribution( diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 15bdd58288f1e..ad0e0ddb687ee 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -250,8 +250,9 @@ private[spark] class TaskSchedulerImpl( override def submitTasks(taskSet: TaskSet): Unit = { val tasks = taskSet.tasks - logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks " - + "resource profile " + taskSet.resourceProfileId) + logInfo(log"Adding task set ${MDC(LogKeys.TASK_SET_ID, 
taskSet.id)} with " + + log"${MDC(LogKeys.NUM_TASKS, tasks.length)} tasks resource profile " + + log"${MDC(LogKeys.RESOURCE_PROFILE_ID, taskSet.resourceProfileId)}") this.synchronized { val manager = createTaskSetManager(taskSet, maxTaskFailures) val stage = taskSet.stageId @@ -306,9 +307,10 @@ private[spark] class TaskSchedulerImpl( stageId: Int, interruptThread: Boolean, reason: String): Unit = synchronized { - logInfo("Cancelling stage " + stageId) + logInfo(log"Canceling stage ${MDC(LogKeys.STAGE_ID, stageId)}") // Kill all running tasks for the stage. - logInfo(s"Killing all running tasks in stage $stageId: $reason") + logInfo(log"Killing all running tasks in stage ${MDC(LogKeys.STAGE_ID, stageId)}: " + + log"${MDC(LogKeys.REASON, reason)}") taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts => attempts.foreach { case (_, tsm) => // There are two possible cases here: @@ -322,7 +324,8 @@ private[spark] class TaskSchedulerImpl( } } tsm.suspend() - logInfo("Stage %s.%s was cancelled".format(stageId, tsm.taskSet.stageAttemptId)) + logInfo(log"Stage ${MDC(LogKeys.STAGE_ID, stageId)}." + + log"${MDC(LogKeys.STAGE_ATTEMPT, tsm.taskSet.stageAttemptId)} was cancelled") } } } @@ -331,7 +334,7 @@ private[spark] class TaskSchedulerImpl( taskId: Long, interruptThread: Boolean, reason: String): Boolean = synchronized { - logInfo(s"Killing task $taskId: $reason") + logInfo(log"Killing task ${MDC(LogKeys.TASK_ID, taskId)}: ${MDC(LogKeys.REASON, reason)}") val execId = taskIdToExecutorId.get(taskId) if (execId.isDefined) { backend.killTask(taskId, execId.get, interruptThread, reason) @@ -361,8 +364,8 @@ private[spark] class TaskSchedulerImpl( } noRejectsSinceLastReset -= manager.taskSet manager.parent.removeSchedulable(manager) - logInfo(s"Removed TaskSet ${manager.taskSet.id}, whose tasks have all completed, from pool" + - s" ${manager.parent.name}") + logInfo(log"Removed TaskSet ${MDC(LogKeys.TASK_SET_NAME, manager.taskSet.id)}, whose tasks " + + log"have all completed, from pool ${MDC(LogKeys.POOL_NAME, manager.parent.name)}") } /** @@ -559,9 +562,10 @@ private[spark] class TaskSchedulerImpl( // Skip the launch process. // TODO SPARK-24819 If the job requires more slots than available (both busy and free // slots), fail the job on submit. 
- logInfo(s"Skip current round of resource offers for barrier stage ${taskSet.stageId} " + - s"because the barrier taskSet requires ${taskSet.numTasks} slots, while the total " + - s"number of available slots is $numBarrierSlotsAvailable.") + logInfo(log"Skip current round of resource offers for barrier stage " + + log"${MDC(LogKeys.STAGE_ID, taskSet.stageId)} because the barrier taskSet requires " + + log"${MDC(LogKeys.TASK_SET_NAME, taskSet.numTasks)} slots, while the total " + + log"number of available slots is ${MDC(LogKeys.NUM_SLOTS, numBarrierSlotsAvailable)}.") } else { var launchedAnyTask = false var noDelaySchedulingRejects = true @@ -619,18 +623,18 @@ private[spark] class TaskSchedulerImpl( // in order to provision more executors to make them schedulable if (Utils.isDynamicAllocationEnabled(conf)) { if (!unschedulableTaskSetToExpiryTime.contains(taskSet)) { - logInfo("Notifying ExecutorAllocationManager to allocate more executors to" + - " schedule the unschedulable task before aborting" + - s" stage ${taskSet.stageId}.") + logInfo(log"Notifying ExecutorAllocationManager to allocate more executors to" + + log" schedule the unschedulable task before aborting" + + log" stage ${MDC(LogKeys.STAGE_ID, taskSet.stageId)}.") dagScheduler.unschedulableTaskSetAdded(taskSet.taskSet.stageId, taskSet.taskSet.stageAttemptId) updateUnschedulableTaskSetTimeoutAndStartAbortTimer(taskSet, taskIndex) } } else { // Abort Immediately - logInfo("Cannot schedule any task because all executors excluded from " + - "failures. No idle executors can be found to kill. Aborting stage " + - s"${taskSet.stageId}.") + logInfo(log"Cannot schedule any task because all executors excluded from " + + log"failures. No idle executors can be found to kill. Aborting stage " + + log"${MDC(LogKeys.STAGE_ID, taskSet.stageId)}.") taskSet.abortSinceCompletelyExcludedOnFailure(taskIndex) } } @@ -643,8 +647,8 @@ private[spark] class TaskSchedulerImpl( // non-excluded executor and the abort timer doesn't kick in because of a constant // submission of new TaskSets. See the PR for more details. if (unschedulableTaskSetToExpiryTime.nonEmpty) { - logInfo("Clearing the expiry times for all unschedulable taskSets as a task was " + - "recently scheduled.") + logInfo(log"Clearing the expiry times for all unschedulable taskSets as a task " + + log"was recently scheduled.") // Notify ExecutorAllocationManager as well as other subscribers that a task now // recently becomes schedulable dagScheduler.unschedulableTaskSetRemoved(taskSet.taskSet.stageId, @@ -679,8 +683,8 @@ private[spark] class TaskSchedulerImpl( val curTime = clock.getTimeMillis() if (curTime - taskSet.lastResourceOfferFailLogTime > TaskSetManager.BARRIER_LOGGING_INTERVAL) { - logInfo("Releasing the assigned resource offers since only partial tasks can " + - "be launched. Waiting for later round resource offers.") + logInfo(log"Releasing the assigned resource offers since only partial tasks can " + + log"be launched. 
Waiting for later round resource offers.") taskSet.lastResourceOfferFailLogTime = curTime } barrierPendingLaunchTasks.foreach { task => @@ -722,8 +726,8 @@ private[spark] class TaskSchedulerImpl( .mkString(",") addressesWithDescs.foreach(_._2.properties.setProperty("addresses", addressesStr)) - logInfo(s"Successfully scheduled all the ${addressesWithDescs.length} tasks for " + - s"barrier stage ${taskSet.stageId}.") + logInfo(log"Successfully scheduled all the ${MDC(LogKeys.NUM_TASKS, addressesWithDescs.length)} " + + log"tasks for barrier stage ${MDC(LogKeys.STAGE_ID, taskSet.stageId)}.") } taskSet.barrierPendingLaunchTasks.clear() } @@ -743,8 +747,8 @@ private[spark] class TaskSchedulerImpl( taskIndex: Int): Unit = { val timeout = conf.get(config.UNSCHEDULABLE_TASKSET_TIMEOUT) * 1000 unschedulableTaskSetToExpiryTime(taskSet) = clock.getTimeMillis() + timeout - logInfo(s"Waiting for $timeout ms for completely " + - s"excluded task to be schedulable again before aborting stage ${taskSet.stageId}.") + logInfo(log"Waiting for ${MDC(LogKeys.TIMEOUT, timeout)} ms for completely " + + log"excluded task to be schedulable again before aborting stage ${MDC(LogKeys.STAGE_ID, taskSet.stageId)}.") abortTimer.schedule( createUnschedulableTaskSetAbortTimer(taskSet, taskIndex), timeout, TimeUnit.MILLISECONDS) } @@ -756,8 +760,8 @@ private[spark] class TaskSchedulerImpl( override def run(): Unit = TaskSchedulerImpl.this.synchronized { if (unschedulableTaskSetToExpiryTime.contains(taskSet) && unschedulableTaskSetToExpiryTime(taskSet) <= clock.getTimeMillis()) { - logInfo("Cannot schedule any task because all executors excluded due to failures. " + - s"Wait time for scheduling expired. Aborting stage ${taskSet.stageId}.") + logInfo(log"Cannot schedule any task because all executors excluded due to failures. " + + log"Wait time for scheduling expired. 
Aborting stage ${MDC(LogKeys.STAGE_ID, taskSet.stageId)}.") taskSet.abortSinceCompletelyExcludedOnFailure(taskIndex) } else { this.cancel() @@ -1043,7 +1047,8 @@ private[spark] class TaskSchedulerImpl( } override def workerRemoved(workerId: String, host: String, message: String): Unit = { - logInfo(s"Handle removed worker $workerId: $message") + logInfo(log"Handle removed worker ${MDC(LogKeys.WORKER_ID, workerId)}: " + + log"${MDC(LogKeys.MESSAGE, message)}") dagScheduler.workerRemoved(workerId, host, message) } @@ -1054,10 +1059,12 @@ private[spark] class TaskSchedulerImpl( case LossReasonPending => logDebug(s"Executor $executorId on $hostPort lost, but reason not yet known.") case ExecutorKilled => - logInfo(s"Executor $executorId on $hostPort killed by driver.") + logInfo(log"Executor ${MDC(LogKeys.EXECUTOR_ID, executorId)} on " + + log"${MDC(LogKeys.HOST_PORT, hostPort)} killed by driver.") case _: ExecutorDecommission => - logInfo(s"Executor $executorId on $hostPort is decommissioned" + - s"${getDecommissionDuration(executorId)}.") + logInfo(log"Executor ${MDC(LogKeys.EXECUTOR_ID, executorId)} on " + + log"${MDC(LogKeys.HOST_PORT, hostPort)} is decommissioned" + + log"${MDC(DURATION, getDecommissionDuration(executorId))}.") case _ => logError(log"Lost executor ${MDC(LogKeys.EXECUTOR_ID, executorId)} on " + log"${MDC(LogKeys.HOST, hostPort)}: ${MDC(LogKeys.REASON, reason)}") diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetExcludeList.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetExcludeList.scala index f479e5e32bc2f..c9aa74e0852be 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetExcludeList.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetExcludeList.scala @@ -19,8 +19,7 @@ package org.apache.spark.scheduler import scala.collection.mutable.{HashMap, HashSet} import org.apache.spark.SparkConf -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config +import org.apache.spark.internal.{config, Logging, LogKeys, MDC} import org.apache.spark.util.Clock /** @@ -134,7 +133,8 @@ private[scheduler] class TaskSetExcludelist( val numFailures = execFailures.numUniqueTasksWithFailures if (numFailures >= MAX_FAILURES_PER_EXEC_STAGE) { if (excludedExecs.add(exec)) { - logInfo(s"Excluding executor ${exec} for stage $stageId") + logInfo(log"Excluding executor ${MDC(LogKeys.EXECUTOR_ID, exec)} for stage " + + log"${MDC(LogKeys.STAGE_ID, stageId)}") // This executor has been excluded for this stage. Let's check if it // the whole node should be excluded. 
val excludedExecutorsOnNode = @@ -149,7 +149,8 @@ private[scheduler] class TaskSetExcludelist( val numFailExec = excludedExecutorsOnNode.size if (numFailExec >= MAX_FAILED_EXEC_PER_NODE_STAGE) { if (excludedNodes.add(host)) { - logInfo(s"Excluding ${host} for stage $stageId") + logInfo(log"Excluding ${MDC(LogKeys.HOST, host)} for " + + log"stage ${MDC(LogKeys.STAGE_ID, stageId)}") // SparkListenerNodeBlacklistedForStage is deprecated but post both events // to keep backward compatibility listenerBus.post( diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index b7ff443231f30..6573ab2f23d62 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -30,8 +30,7 @@ import org.apache.spark.InternalAccumulator import org.apache.spark.InternalAccumulator.{input, shuffleRead} import org.apache.spark.TaskState.TaskState import org.apache.spark.errors.SparkCoreErrors -import org.apache.spark.internal.{config, Logging, MDC} -import org.apache.spark.internal.LogKeys +import org.apache.spark.internal.{config, Logging, LogKeys, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.internal.config._ import org.apache.spark.scheduler.SchedulingMode._ @@ -280,8 +279,9 @@ private[spark] class TaskSetManager( for (e <- set) { pendingTaskSetToAddTo.forExecutor.getOrElseUpdate(e, new ArrayBuffer) += index } - logInfo(s"Pending task $index has a cached location at ${e.host} " + - ", where there are executors " + set.mkString(",")) + logInfo(log"Pending task ${MDC(INDEX, index)} has a cached location at " + + log"${MDC(HOST, e.host)}, where there are executors " + + log"${MDC(EXECUTOR_IDS, set.mkString(","))}") case None => logDebug(s"Pending task $index has a cached location at ${e.host} " + ", but there are no executors alive there.") } @@ -554,10 +554,16 @@ private[spark] class TaskSetManager( // a good proxy to task serialization time. // val timeTaken = clock.getTime() - startTime val tName = taskName(taskId) - logInfo(s"Starting $tName ($host, executor ${info.executorId}, " + - s"partition ${task.partitionId}, $taskLocality, ${serializedTask.limit()} bytes) " + - (if (taskResourceAssignments.nonEmpty) s"taskResourceAssignments ${taskResourceAssignments}" - else "")) + logInfo(log"Starting ${MDC(TASK_NAME, tName)} (${MDC(HOST, host)}," + + log"executor ${MDC(LogKeys.EXECUTOR_ID, info.executorId)}, " + + log"partition ${MDC(PARTITION_ID, task.partitionId)}, " + + log"${MDC(TASK_LOCALITY, taskLocality)}, " + + log"${MDC(SIZE, serializedTask.limit())} bytes) " + + (if (taskResourceAssignments.nonEmpty) { + log"taskResourceAssignments ${MDC(TASK_RESOURCE_ASSIGNMENTS, taskResourceAssignments)}" + } else { + log"" + })) sched.dagScheduler.taskStarted(task, info) new TaskDescription( @@ -829,8 +835,11 @@ private[spark] class TaskSetManager( // Kill any other attempts for the same task (since those are unnecessary now that one // attempt completed successfully). 
for (attemptInfo <- taskAttempts(index) if attemptInfo.running) { - logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for ${taskName(attemptInfo.taskId)}" + - s" on ${attemptInfo.host} as the attempt ${info.attemptNumber} succeeded on ${info.host}") + logInfo(log"Killing attempt ${MDC(NUM_ATTEMPT, attemptInfo.attemptNumber)} for " + + log"${MDC(TASK_NAME, taskName(attemptInfo.taskId))} on " + + log"${MDC(HOST, attemptInfo.host)} as the attempt " + + log"${MDC(TASK_ATTEMPT_ID, info.attemptNumber)} succeeded on " + + log"${MDC(HOST, info.host)}") killedByOtherAttempt += attemptInfo.taskId sched.backend.killTask( attemptInfo.taskId, @@ -840,8 +849,10 @@ private[spark] class TaskSetManager( } if (!successful(index)) { tasksSuccessful += 1 - logInfo(s"Finished ${taskName(info.taskId)} in ${info.duration} ms " + - s"on ${info.host} (executor ${info.executorId}) ($tasksSuccessful/$numTasks)") + logInfo(log"Finished ${MDC(TASK_NAME, taskName(info.taskId))} in " + + log"${MDC(DURATION, info.duration)} ms on ${MDC(HOST, info.host)} " + + log"(executor ${MDC(LogKeys.EXECUTOR_ID, info.executorId)}) " + + log"(${MDC(NUM_SUCCESSFUL_TASKS, tasksSuccessful)}/${MDC(NUM_TASKS, numTasks)})") // Mark successful and stop if all the tasks have succeeded. successful(index) = true numFailures(index) = 0 @@ -849,8 +860,9 @@ private[spark] class TaskSetManager( isZombie = true } } else { - logInfo(s"Ignoring task-finished event for ${taskName(info.taskId)} " + - s"because it has already completed successfully") + logInfo(log"Ignoring task-finished event for " + + log"${MDC(TASK_NAME, taskName(info.taskId))} " + + log"because it has already completed successfully") } // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not @@ -1007,8 +1019,10 @@ private[spark] class TaskSetManager( logWarning(failureReason) } else { logInfo( - s"Lost $task on ${info.host}, executor ${info.executorId}: " + - s"${ef.className} (${ef.description}) [duplicate $dupCount]") + log"Lost ${MDC(TASK_NAME, task)} on ${MDC(HOST, info.host)}, " + + log"executor ${MDC(LogKeys.EXECUTOR_ID, info.executorId)}: " + + log"${MDC(CLASS_NAME, ef.className)} " + + log"(${MDC(DESCRIPTION, ef.description)}) [duplicate ${MDC(COUNT, dupCount)}]") } ef.exception @@ -1020,9 +1034,9 @@ private[spark] class TaskSetManager( None case e: ExecutorLostFailure if !e.exitCausedByApp => - logInfo(s"${taskName(tid)} failed because while it was being computed, its executor " + - "exited for a reason unrelated to the task. Not counting this failure towards the " + - "maximum number of failures for the task.") + logInfo(log"${MDC(TASK_NAME, taskName(tid))} failed because while it was being computed," + + log" its executor exited for a reason unrelated to the task. 
" + + log"Not counting this failure towards the maximum number of failures for the task.") None case _: TaskFailedReason => // TaskResultLost and others @@ -1052,10 +1066,10 @@ private[spark] class TaskSetManager( } if (successful(index)) { - logInfo(s"${taskName(info.taskId)} failed, but the task will not" + - " be re-executed (either because the task failed with a shuffle data fetch failure," + - " so the previous stage needs to be re-run, or because a different copy of the task" + - " has already succeeded).") + logInfo(log"${MDC(LogKeys.TASK_NAME, taskName(info.taskId))} failed, but the task will not" + + log" be re-executed (either because the task failed with a shuffle data fetch failure," + + log" so the previous stage needs to be re-run, or because a different copy of the task" + + log" has already succeeded).") } else { addPendingTask(index) } @@ -1238,9 +1252,10 @@ private[spark] class TaskSetManager( if (speculated) { addPendingTask(index, speculatable = true) logInfo( - ("Marking task %d in stage %s (on %s) as speculatable because it ran more" + - " than %.0f ms(%d speculatable tasks in this taskset now)") - .format(index, taskSet.id, info.host, threshold, speculatableTasks.size + 1)) + log"Marking task ${MDC(INDEX, index)} in stage ${MDC(STAGE_ID, taskSet.id)} (on " + + log"${MDC(HOST, info.host)}) as speculatable because it ran more than " + + log"${MDC(TIMEOUT, threshold)} ms(${MDC(NUM_TASKS, speculatableTasks.size + 1)}" + + log"speculatable tasks in this taskset now)") speculatableTasks += index sched.dagScheduler.speculativeTaskSubmitted(tasks(index), index) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index d359b65caa931..deaa1b4e47906 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -42,6 +42,7 @@ import org.apache.spark.rpc._ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.ENDPOINT_NAME +import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS import org.apache.spark.status.api.v1.ThreadStackTrace import org.apache.spark.util.{RpcUtils, SerializableBuffer, ThreadUtils, Utils} import org.apache.spark.util.ArrayImplicits._ @@ -258,7 +259,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // If the cluster manager gives us an executor on an excluded node (because it // already started allocating those resources before we informed it of our exclusion, // or if it ignored our exclusion), then we reject that executor immediately. 
- logInfo(s"Rejecting $executorId as it has been excluded.") + logInfo(log"Rejecting ${MDC(LogKeys.EXECUTOR_ID, executorId)} as it has been excluded.") context.sendFailure( new IllegalStateException(s"Executor is excluded due to failures: $executorId")) } else { @@ -269,8 +270,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } else { context.senderAddress } - logInfo(s"Registered executor $executorRef ($executorAddress) with ID $executorId, " + - s" ResourceProfileId $resourceProfileId") + logInfo(log"Registered executor ${MDC(LogKeys.RPC_ENDPOINT_REF, executorRef)} " + + log"(${MDC(LogKeys.RPC_ADDRESS, executorAddress)}) " + + log"with ID ${MDC(LogKeys.EXECUTOR_ID, executorId)}, " + + log"ResourceProfileId ${MDC(LogKeys.RESOURCE_PROFILE_ID, resourceProfileId)}") addressToExecutorId(executorAddress) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) @@ -324,7 +327,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case UpdateExecutorsLogLevel(logLevel) => currentLogLevel = Some(logLevel) - logInfo(s"Asking each executor to refresh the log level to $logLevel") + logInfo(log"Asking each executor to refresh the log level to " + + log"${MDC(LogKeys.LOG_LEVEL, logLevel)}") for ((_, executorData) <- executorDataMap) { executorData.executorEndpoint.send(UpdateExecutorLogLevel(logLevel)) } @@ -497,7 +501,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // forever. Therefore, we should also post `SparkListenerExecutorRemoved` here. listenerBus.post(SparkListenerExecutorRemoved( System.currentTimeMillis(), executorId, reason.toString)) - logInfo(s"Asked to remove non-existent executor $executorId") + logInfo( + log"Asked to remove non-existent executor ${MDC(LogKeys.EXECUTOR_ID, executorId)}") } } @@ -526,7 +531,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } if (shouldDisable) { - logInfo(s"Disabling executor $executorId.") + logInfo(log"Disabling executor ${MDC(LogKeys.EXECUTOR_ID, executorId)}.") scheduler.executorLost(executorId, LossReasonPending) } @@ -570,7 +575,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp return executorsToDecommission.toImmutableArraySeq } - logInfo(s"Decommission executors: ${executorsToDecommission.mkString(", ")}") + logInfo(log"Decommission executors: " + + log"${MDC(LogKeys.EXECUTOR_IDS, executorsToDecommission.mkString(", "))}") // If we don't want to replace the executors we are decommissioning if (adjustTargetNumExecutors) { @@ -589,7 +595,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp if (!triggeredByExecutor) { executorsToDecommission.foreach { executorId => - logInfo(s"Notify executor $executorId to decommission.") + logInfo(log"Notify executor ${MDC(LogKeys.EXECUTOR_ID, executorId)} to decommission.") executorDataMap(executorId).executorEndpoint.send(DecommissionExecutor) } } @@ -601,7 +607,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp executorsToDecommission.filter(executorsPendingDecommission.contains) } if (stragglers.nonEmpty) { - logInfo(s"${stragglers.toList} failed to decommission in ${cleanupInterval}, killing.") + logInfo( + log"${MDC(LogKeys.EXECUTOR_IDS, stragglers.toList)} failed to decommission in " + + log"${MDC(LogKeys.INTERVAL, cleanupInterval)}, killing.") killExecutors(stragglers.toImmutableArraySeq, false, false, true) } } @@ -718,13 +726,14 @@ 
class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def isReady(): Boolean = { if (sufficientResourcesRegistered()) { - logInfo("SchedulerBackend is ready for scheduling beginning after " + - s"reached minRegisteredResourcesRatio: $minRegisteredRatio") + logInfo(log"SchedulerBackend is ready for scheduling beginning after " + + log"reached minRegisteredResourcesRatio: ${MDC(LogKeys.MIN_SIZE, minRegisteredRatio)}") return true } if ((System.nanoTime() - createTimeNs) >= maxRegisteredWaitingTimeNs) { - logInfo("SchedulerBackend is ready for scheduling beginning after waiting " + - s"maxRegisteredResourcesWaitingTime: $maxRegisteredWaitingTimeNs(ns)") + logInfo(log"SchedulerBackend is ready for scheduling beginning after waiting " + + log"maxRegisteredResourcesWaitingTime: " + + log"${MDC(LogKeys.TIMEOUT, maxRegisteredWaitingTimeNs / NANOS_PER_MILLIS.toDouble)}(ms)") return true } false @@ -801,7 +810,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp "Attempted to request a negative number of additional executor(s) " + s"$numAdditionalExecutors from the cluster manager. Please specify a positive number!") } - logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") + logInfo(log"Requesting ${MDC(LogKeys.NUM_EXECUTORS, numAdditionalExecutors)} additional " + + log"executor(s) from the cluster manager") val response = synchronized { val defaultProf = scheduler.sc.resourceProfileManager.defaultResourceProfile @@ -951,7 +961,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp adjustTargetNumExecutors: Boolean, countFailures: Boolean, force: Boolean): Seq[String] = { - logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") + logInfo( + log"Requesting to kill executor(s) ${MDC(LogKeys.EXECUTOR_IDS, executorIds.mkString(", "))}") val response = withLock { val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains) @@ -966,7 +977,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp .filter { id => force || !scheduler.isExecutorBusy(id) } executorsToKill.foreach { id => executorsPendingToRemove(id) = !countFailures } - logInfo(s"Actual list of executor(s) to be killed is ${executorsToKill.mkString(", ")}") + logInfo(log"Actual list of executor(s) to be killed is " + + log"${MDC(LogKeys.EXECUTOR_IDS, executorsToKill.mkString(", "))}") // If we do not wish to replace the executors we kill, sync the target number of executors // with the cluster manager to avoid allocating new ones. When computing the new target, @@ -1007,7 +1019,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * @return whether the decommission request is acknowledged. */ final override def decommissionExecutorsOnHost(host: String): Boolean = { - logInfo(s"Requesting to kill any and all executors on host $host") + logInfo(log"Requesting to kill any and all executors on host ${MDC(LogKeys.HOST, host)}") // A potential race exists if a new executor attempts to register on a host // that is on the exclude list and is no longer valid. To avoid this race, // all executor registration and decommissioning happens in the event loop. This way, either @@ -1023,7 +1035,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * @return whether the kill request is acknowledged. 
*/ final override def killExecutorsOnHost(host: String): Boolean = { - logInfo(s"Requesting to kill any and all executors on host $host") + logInfo(log"Requesting to kill any and all executors on host ${MDC(LogKeys.HOST, host)}") // A potential race exists if a new executor attempts to register on a host // that is on the exclude list and is no longer valid. To avoid this race, // all executor registration and killing happens in the event loop. This way, either diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 8f15dec6739a8..f4caecd7d6741 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -27,8 +27,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{StandaloneAppClient, StandaloneAppClientListener} import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.internal.{config, Logging, MDC} -import org.apache.spark.internal.LogKeys.REASON +import org.apache.spark.internal.{config, Logging, LogKeys, MDC} import org.apache.spark.internal.config.EXECUTOR_REMOVE_DELAY import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} @@ -145,7 +144,7 @@ private[spark] class StandaloneSchedulerBackend( } override def connected(appId: String): Unit = { - logInfo("Connected to Spark cluster with app ID " + appId) + logInfo(log"Connected to Spark cluster with app ID ${MDC(LogKeys.APP_ID, appId)}") this.appId = appId notifyContext() launcherBackend.setAppId(appId) @@ -162,7 +161,7 @@ private[spark] class StandaloneSchedulerBackend( notifyContext() if (!stopping.get) { launcherBackend.setState(SparkAppHandle.State.KILLED) - logError(log"Application has been killed. Reason: ${MDC(REASON, reason)}") + logError(log"Application has been killed. 
Reason: ${MDC(LogKeys.REASON, reason)}") try { scheduler.error(reason) } finally { @@ -174,8 +173,9 @@ private[spark] class StandaloneSchedulerBackend( override def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit = { - logInfo("Granted executor ID %s on hostPort %s with %d core(s), %s RAM".format( - fullId, hostPort, cores, Utils.megabytesToString(memory))) + logInfo(log"Granted executor ID ${MDC(LogKeys.EXECUTOR_ID, fullId)} on hostPort " + + log"${MDC(LogKeys.HOST_PORT, hostPort)} with ${MDC(LogKeys.NUM_CORES, cores)} core(s), " + + log"${MDC(LogKeys.MEMORY_SIZE, Utils.megabytesToString(memory))} RAM") } override def executorRemoved( @@ -192,23 +192,28 @@ private[spark] class StandaloneSchedulerBackend( case Some(code) => ExecutorExited(code, exitCausedByApp = true, message) case None => ExecutorProcessLost(message, workerHost, causedByApp = workerHost.isEmpty) } - logInfo("Executor %s removed: %s".format(fullId, message)) + logInfo( + log"Executor ${MDC(LogKeys.EXECUTOR_ID, fullId)} removed: ${MDC(LogKeys.MESSAGE, message)}") removeExecutor(fullId.split("/")(1), reason) } override def executorDecommissioned(fullId: String, decommissionInfo: ExecutorDecommissionInfo): Unit = { - logInfo(s"Asked to decommission executor $fullId") + logInfo(log"Asked to decommission executor ${MDC(LogKeys.EXECUTOR_ID, fullId)}") val execId = fullId.split("/")(1) decommissionExecutors( Array((execId, decommissionInfo)), adjustTargetNumExecutors = false, triggeredByExecutor = false) - logInfo("Executor %s decommissioned: %s".format(fullId, decommissionInfo)) + logInfo( + log"Executor ${MDC(LogKeys.EXECUTOR_ID, fullId)} " + + log"decommissioned: ${MDC(LogKeys.DESCRIPTION, decommissionInfo)}" + ) } override def workerRemoved(workerId: String, host: String, message: String): Unit = { - logInfo("Worker %s removed: %s".format(workerId, message)) + logInfo(log"Worker ${MDC(LogKeys.WORKER_ID, workerId)} removed: " + + log"${MDC(LogKeys.MESSAGE, message)}") removeWorker(workerId, host, message) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala b/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala index c389b0c988f4d..57505c87f879e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala @@ -25,7 +25,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark._ import org.apache.spark.errors.SparkCoreErrors -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.internal.config._ import org.apache.spark.resource.ResourceProfile.UNKNOWN_RESOURCE_PROFILE_ID import org.apache.spark.scheduler._ @@ -342,7 +342,8 @@ private[spark] class ExecutorMonitor( override def onExecutorAdded(event: SparkListenerExecutorAdded): Unit = { val exec = ensureExecutorIsTracked(event.executorId, event.executorInfo.resourceProfileId) exec.updateRunningTasks(0) - logInfo(s"New executor ${event.executorId} has registered (new total is ${executors.size()})") + logInfo(log"New executor ${MDC(LogKeys.EXECUTOR_ID, event.executorId)} has registered " + + log"(new total is ${MDC(LogKeys.COUNT, executors.size())})") } private def decrementExecResourceProfileCount(rpId: Int): Unit = { @@ -365,11 +366,14 @@ private[spark] class ExecutorMonitor( } else { metrics.exitedUnexpectedly.inc() } - logInfo(s"Executor ${event.executorId} is removed. 
Remove reason statistics: (" + - s"gracefully decommissioned: ${metrics.gracefullyDecommissioned.getCount()}, " + - s"decommision unfinished: ${metrics.decommissionUnfinished.getCount()}, " + - s"driver killed: ${metrics.driverKilled.getCount()}, " + - s"unexpectedly exited: ${metrics.exitedUnexpectedly.getCount()}).") + // scalastyle:off line.size.limit + logInfo(log"Executor ${MDC(LogKeys.EXECUTOR_ID, event.executorId)} is removed. " + + log"Remove reason statistics: (gracefully decommissioned: " + + log"${MDC(LogKeys.NUM_DECOMMISSIONED, metrics.gracefullyDecommissioned.getCount())}, " + + log"decommission unfinished: ${MDC(LogKeys.NUM_UNFINISHED_DECOMMISSIONED, metrics.decommissionUnfinished.getCount())}, " + + log"driver killed: ${MDC(LogKeys.NUM_EXECUTORS_KILLED, metrics.driverKilled.getCount())}, " + + log"unexpectedly exited: ${MDC(LogKeys.NUM_EXECUTORS_EXITED, metrics.exitedUnexpectedly.getCount())}).") + // scalastyle:on line.size.limit if (!removed.pendingRemoval || !removed.decommissioning) { nextTimeout.set(Long.MinValue) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 26e2acf4392ca..8655b72310795 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -535,7 +535,7 @@ private[spark] class BlockManager( val priorityClass = conf.get(config.STORAGE_REPLICATION_POLICY) val clazz = Utils.classForName(priorityClass) val ret = clazz.getConstructor().newInstance().asInstanceOf[BlockReplicationPolicy] - logInfo(s"Using $priorityClass for block replication policy") + logInfo(log"Using ${MDC(CLASS_NAME, priorityClass)} for block replication policy") ret } @@ -547,7 +547,7 @@ private[spark] class BlockManager( // the registration with the ESS. Therefore, this registration should be prior to // the BlockManager registration. See SPARK-39647. if (externalShuffleServiceEnabled) { - logInfo(s"external shuffle service port = $externalShuffleServicePort") + logInfo(log"external shuffle service port = ${MDC(PORT, externalShuffleServicePort)}") shuffleServerId = BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) if (!isDriver && !(Utils.isTesting && conf.get(Tests.TEST_SKIP_ESS_REGISTER))) { @@ -585,7 +585,7 @@ private[spark] class BlockManager( } } - logInfo(s"Initialized BlockManager: $blockManagerId") + logInfo(log"Initialized BlockManager: ${MDC(BLOCK_MANAGER_ID, blockManagerId)}") } def shuffleMetricsSource: Source = { @@ -646,7 +646,7 @@ private[spark] class BlockManager( * will be made then. */ private def reportAllBlocks(): Unit = { - logInfo(s"Reporting ${blockInfoManager.size} blocks to the master.") + logInfo(log"Reporting ${MDC(NUM_BLOCKS, blockInfoManager.size)} blocks to the master.") for ((blockId, info) <- blockInfoManager.entries) { val status = getCurrentBlockStatus(blockId, info) if (info.tellMaster && !tryToReportBlockStatus(blockId, status)) { @@ -664,7 +664,7 @@ private[spark] class BlockManager( */ def reregister(): Unit = { // TODO: We might need to rate limit re-registering. 
- logInfo(s"BlockManager $blockManagerId re-registering with master") + logInfo(log"BlockManager ${MDC(BLOCK_MANAGER_ID, blockManagerId)} re-registering with master") val id = master.registerBlockManager(blockManagerId, diskBlockManager.localDirsString, maxOnHeapMemory, maxOffHeapMemory, storageEndpoint, isReRegister = true) if (id.executorId != BlockManagerId.INVALID_EXECUTOR_ID) { @@ -875,7 +875,7 @@ private[spark] class BlockManager( droppedMemorySize: Long = 0L): Unit = { val needReregister = !tryToReportBlockStatus(blockId, status, droppedMemorySize) if (needReregister) { - logInfo(s"Got told to re-register updating block $blockId") + logInfo(log"Got told to re-register updating block ${MDC(BLOCK_ID, blockId)}") // Re-registering will report our new block for free. asyncReregister() } @@ -1139,8 +1139,9 @@ private[spark] class BlockManager( None } } - logInfo(s"Read $blockId from the disk of a same host executor is " + - (if (res.isDefined) "successful." else "failed.")) + logInfo( + log"Read ${MDC(BLOCK_ID, blockId)} from the disk of a same host executor is " + + log"${MDC(STATUS, if (res.isDefined) "successful." else "failed.")}") res }.orElse { fetchRemoteManagedBuffer(blockId, blockSize, locationsAndStatus).map(bufferTransformer) @@ -1308,12 +1309,12 @@ private[spark] class BlockManager( def get[T: ClassTag](blockId: BlockId): Option[BlockResult] = { val local = getLocalValues(blockId) if (local.isDefined) { - logInfo(s"Found block $blockId locally") + logInfo(log"Found block ${MDC(BLOCK_ID, blockId)} locally") return local } val remote = getRemoteValues[T](blockId) if (remote.isDefined) { - logInfo(s"Found block $blockId remotely") + logInfo(log"Found block ${MDC(BLOCK_ID, blockId)} remotely") return remote } None @@ -1820,7 +1821,8 @@ private[spark] class BlockManager( existingReplicas: Set[BlockManagerId], maxReplicas: Int, maxReplicationFailures: Option[Int] = None): Boolean = { - logInfo(s"Using $blockManagerId to pro-actively replicate $blockId") + logInfo(log"Using ${MDC(BLOCK_MANAGER_ID, blockManagerId)} to pro-actively replicate " + + log"${MDC(BLOCK_ID, blockId)}") blockInfoManager.lockForReading(blockId).forall { info => val data = doGetLocalBytes(blockId, info) val storageLevel = StorageLevel( @@ -1977,14 +1979,14 @@ private[spark] class BlockManager( private[storage] override def dropFromMemory[T: ClassTag]( blockId: BlockId, data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = { - logInfo(s"Dropping block $blockId from memory") + logInfo(log"Dropping block ${MDC(BLOCK_ID, blockId)} from memory") val info = blockInfoManager.assertBlockIsLockedForWriting(blockId) var blockIsUpdated = false val level = info.level // Drop to disk, if storage level requires if (level.useDisk && !diskStore.contains(blockId)) { - logInfo(s"Writing block $blockId to disk") + logInfo(log"Writing block ${MDC(BLOCK_ID, blockId)} to disk") data() match { case Left(elements) => diskStore.put(blockId) { channel => @@ -2028,7 +2030,7 @@ private[spark] class BlockManager( */ def removeRdd(rddId: Int): Int = { // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks. 
- logInfo(s"Removing RDD $rddId") + logInfo(log"Removing RDD ${MDC(RDD_ID, rddId)}") val blocksToRemove = blockInfoManager.entries.flatMap(_._1.asRDDId).filter(_.rddId == rddId) blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } blocksToRemove.size diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerDecommissioner.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerDecommissioner.scala index fc98fbf6e18b3..19807453ee28c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerDecommissioner.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerDecommissioner.scala @@ -73,14 +73,15 @@ private[storage] class BlockManagerDecommissioner( private def allowRetry(shuffleBlock: ShuffleBlockInfo, failureNum: Int): Boolean = { if (failureNum < maxReplicationFailuresForDecommission) { - logInfo(s"Add $shuffleBlock back to migration queue for " + - s"retry ($failureNum / $maxReplicationFailuresForDecommission)") + logInfo(log"Add ${MDC(SHUFFLE_BLOCK_INFO, shuffleBlock)} back to migration queue for " + + log" retry (${MDC(FAILURES, failureNum)} / " + + log"${MDC(MAX_ATTEMPTS, maxReplicationFailuresForDecommission)})") // The block needs to retry so we should not mark it as finished shufflesToMigrate.add((shuffleBlock, failureNum)) } else { logWarning(log"Give up migrating ${MDC(SHUFFLE_BLOCK_INFO, shuffleBlock)} " + log"since it's been failed for " + - log"${MDC(NUM_FAILURES, maxReplicationFailuresForDecommission)} times") + log"${MDC(MAX_ATTEMPTS, maxReplicationFailuresForDecommission)} times") false } } @@ -98,7 +99,7 @@ private[storage] class BlockManagerDecommissioner( } override def run(): Unit = { - logInfo(s"Starting shuffle block migration thread for $peer") + logInfo(log"Starting shuffle block migration thread for ${MDC(PEER, peer)}") // Once a block fails to transfer to an executor stop trying to transfer more blocks while (keepRunning) { try { @@ -107,10 +108,12 @@ private[storage] class BlockManagerDecommissioner( var isTargetDecommissioned = false // We only migrate a shuffle block when both index file and data file exist. if (blocks.isEmpty) { - logInfo(s"Ignore deleted shuffle block $shuffleBlockInfo") + logInfo(log"Ignore deleted shuffle block ${MDC(SHUFFLE_BLOCK_INFO, shuffleBlockInfo)}") } else { - logInfo(s"Got migration sub-blocks $blocks. Trying to migrate $shuffleBlockInfo " + - s"to $peer ($retryCount / $maxReplicationFailuresForDecommission)") + logInfo(log"Got migration sub-blocks ${MDC(BLOCK_IDS, blocks)}. Trying to migrate " + + log"${MDC(SHUFFLE_BLOCK_INFO, shuffleBlockInfo)} to ${MDC(PEER, peer)} " + + log"(${MDC(NUM_RETRY, retryCount)} / " + + log"${MDC(MAX_ATTEMPTS, maxReplicationFailuresForDecommission)}") // Migrate the components of the blocks. 
try { val startTime = System.currentTimeMillis() @@ -130,9 +133,10 @@ private[storage] class BlockManagerDecommissioner( logDebug(s"Migrated sub-block $blockId") } } - logInfo(s"Migrated $shuffleBlockInfo (" + - s"size: ${Utils.bytesToString(blocks.map(b => b._2.size()).sum)}) to $peer " + - s"in ${System.currentTimeMillis() - startTime} ms") + logInfo(log"Migrated ${MDC(SHUFFLE_BLOCK_INFO, shuffleBlockInfo)} (" + + log"size: ${MDC(SIZE, Utils.bytesToString(blocks.map(b => b._2.size()).sum))}) " + + log"to ${MDC(PEER, peer)} in " + + log"${MDC(DURATION, System.currentTimeMillis() - startTime)} ms") } catch { case e @ ( _ : IOException | _ : SparkException) => // If a block got deleted before netty opened the file handle, then trying to @@ -181,7 +185,11 @@ private[storage] class BlockManagerDecommissioner( } } catch { case _: InterruptedException => - logInfo(s"Stop shuffle block migration${if (keepRunning) " unexpectedly"}.") + if (keepRunning) { + logInfo("Stop shuffle block migration unexpectedly.") + } else { + logInfo("Stop shuffle block migration.") + } keepRunning = false case NonFatal(e) => keepRunning = false @@ -234,12 +242,16 @@ private[storage] class BlockManagerDecommissioner( logInfo("Attempting to migrate all cached RDD blocks") rddBlocksLeft = decommissionRddCacheBlocks() lastRDDMigrationTime = startTime - logInfo(s"Finished current round RDD blocks migration, " + - s"waiting for ${sleepInterval}ms before the next round migration.") + logInfo(log"Finished current round RDD blocks migration, " + + log"waiting for ${MDC(SLEEP_TIME, sleepInterval)}ms before the next round migration.") Thread.sleep(sleepInterval) } catch { case _: InterruptedException => - logInfo(s"Stop RDD blocks migration${if (!stopped && !stoppedRDD) " unexpectedly"}.") + if (!stopped && !stoppedRDD) { + logInfo("Stop RDD blocks migration unexpectedly.") + } else { + logInfo("Stop RDD blocks migration.") + } stoppedRDD = true case NonFatal(e) => logError("Error occurred during RDD blocks migration.", e) @@ -265,8 +277,9 @@ private[storage] class BlockManagerDecommissioner( val startTime = System.nanoTime() shuffleBlocksLeft = refreshMigratableShuffleBlocks() lastShuffleMigrationTime = startTime - logInfo(s"Finished current round refreshing migratable shuffle blocks, " + - s"waiting for ${sleepInterval}ms before the next round refreshing.") + logInfo(log"Finished current round refreshing migratable shuffle blocks, " + + log"waiting for ${MDC(SLEEP_TIME, sleepInterval)}ms before the " + + log"next round refreshing.") Thread.sleep(sleepInterval) } catch { case _: InterruptedException if stopped => @@ -302,8 +315,9 @@ private[storage] class BlockManagerDecommissioner( shufflesToMigrate.addAll(newShufflesToMigrate.map(x => (x, 0)).asJava) migratingShuffles ++= newShufflesToMigrate val remainedShuffles = migratingShuffles.size - numMigratedShuffles.get() - logInfo(s"${newShufflesToMigrate.size} of ${localShuffles.size} local shuffles " + - s"are added. In total, $remainedShuffles shuffles are remained.") + logInfo(log"${MDC(COUNT, newShufflesToMigrate.size)} of " + + log"${MDC(TOTAL, localShuffles.size)} local shuffles are added. " + + log"In total, ${MDC(NUM_REMAINED, remainedShuffles)} shuffles are remained.") // Update the threads doing migrations val livePeerSet = bm.getPeers(false).toSet @@ -350,8 +364,9 @@ private[storage] class BlockManagerDecommissioner( // Refresh peers and validate we have somewhere to move blocks. 
if (replicateBlocksInfo.nonEmpty) { - logInfo(s"Need to replicate ${replicateBlocksInfo.size} RDD blocks " + - "for block manager decommissioning") + logInfo( + log"Need to replicate ${MDC(NUM_REPLICAS, replicateBlocksInfo.size)} RDD blocks " + + log"for block manager decommissioning") } else { logWarning("Asked to decommission RDD cache blocks, but no blocks to migrate") return false @@ -378,9 +393,10 @@ private[storage] class BlockManagerDecommissioner( blockToReplicate.maxReplicas, maxReplicationFailures = Some(maxReplicationFailuresForDecommission)) if (replicatedSuccessfully) { - logInfo(s"Block ${blockToReplicate.blockId} migrated successfully, Removing block now") + logInfo(log"Block ${MDC(BLOCK_ID, blockToReplicate.blockId)} migrated " + + log"successfully, Removing block now") bm.removeBlock(blockToReplicate.blockId) - logInfo(s"Block ${blockToReplicate.blockId} removed") + logInfo(log"Block ${MDC(BLOCK_ID, blockToReplicate.blockId)} removed") } else { logWarning(log"Failed to migrate block ${MDC(BLOCK_ID, blockToReplicate.blockId)}") } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 95af44deef93d..276bd63e14237 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -42,7 +42,7 @@ class BlockManagerMaster( /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */ def removeExecutor(execId: String): Unit = { tell(RemoveExecutor(execId)) - logInfo("Removed " + execId + " successfully in removeExecutor") + logInfo(log"Removed ${MDC(EXECUTOR_ID, execId)} successfully in removeExecutor") } /** Decommission block managers corresponding to given set of executors @@ -62,7 +62,7 @@ class BlockManagerMaster( */ def removeExecutorAsync(execId: String): Unit = { driverEndpoint.ask[Boolean](RemoveExecutor(execId)) - logInfo("Removal of executor " + execId + " requested") + logInfo(log"Removal of executor ${MDC(EXECUTOR_ID, execId)} requested") } /** @@ -77,7 +77,7 @@ class BlockManagerMaster( maxOffHeapMemSize: Long, storageEndpoint: RpcEndpointRef, isReRegister: Boolean = false): BlockManagerId = { - logInfo(s"Registering BlockManager $id") + logInfo(log"Registering BlockManager ${MDC(BLOCK_MANAGER_ID, id)}") val updatedId = driverEndpoint.askSync[BlockManagerId]( RegisterBlockManager( id, @@ -90,9 +90,9 @@ class BlockManagerMaster( ) if (updatedId.executorId == BlockManagerId.INVALID_EXECUTOR_ID) { assert(isReRegister, "Got invalid executor id from non re-register case") - logInfo(s"Re-register BlockManager $id failed") + logInfo(log"Re-register BlockManager ${MDC(BLOCK_MANAGER_ID, id)} failed") } else { - logInfo(s"Registered BlockManager $updatedId") + logInfo(log"Registered BlockManager ${MDC(BLOCK_MANAGER_ID, updatedId)}") } updatedId } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index be7082807182b..73f89ea0e86e5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -110,7 +110,7 @@ class BlockManagerMasterEndpoint( val clazz = Utils.classForName(topologyMapperClassName) val mapper = clazz.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[TopologyMapper] - logInfo(s"Using 
$topologyMapperClassName for getting topology information") + logInfo(log"Using ${MDC(CLASS_NAME, topologyMapperClassName)} for getting topology information") mapper } @@ -218,7 +218,8 @@ class BlockManagerMasterEndpoint( // executor is notified(see BlockManager.decommissionSelf), so we don't need to send the // notification here. val bms = executorIds.flatMap(blockManagerIdByExecutor.get) - logInfo(s"Mark BlockManagers (${bms.mkString(", ")}) as being decommissioning.") + logInfo(log"Mark BlockManagers (${MDC(BLOCK_MANAGER_IDS, bms.mkString(", "))}) as " + + log"being decommissioning.") decommissioningBlockManagerSet ++= bms context.reply(true) @@ -535,7 +536,7 @@ class BlockManagerMasterEndpoint( } listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId)) - logInfo(s"Removing block manager $blockManagerId") + logInfo(log"Removing block manager ${MDC(BLOCK_MANAGER_ID, blockManagerId)}") } @@ -551,7 +552,7 @@ class BlockManagerMasterEndpoint( } private def removeExecutor(execId: String): Unit = { - logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") + logInfo(log"Trying to remove executor ${MDC(EXECUTOR_ID, execId)} from BlockManagerMaster.") blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) } @@ -707,8 +708,9 @@ class BlockManagerMasterEndpoint( removeExecutor(id.executorId) case None => } - logInfo("Registering block manager %s with %s RAM, %s".format( - id.hostPort, Utils.bytesToString(maxOnHeapMemSize + maxOffHeapMemSize), id)) + logInfo(log"Registering block manager ${MDC(HOST_PORT, id.hostPort)} with " + + log"${MDC(MEMORY_SIZE, Utils.bytesToString(maxOnHeapMemSize + maxOffHeapMemSize))} RAM, " + + log"${MDC(BLOCK_MANAGER_ID, id)}") blockManagerIdByExecutor(id.executorId) = id @@ -738,8 +740,8 @@ class BlockManagerMasterEndpoint( assert(!blockManagerInfo.contains(id), "BlockManager re-registration shouldn't succeed when the executor is lost") - logInfo(s"BlockManager ($id) re-registration is rejected since " + - s"the executor (${id.executorId}) has been lost") + logInfo(log"BlockManager (${MDC(BLOCK_MANAGER_ID, id)}) re-registration is rejected since " + + log"the executor (${MDC(EXECUTOR_ID, id.executorId)}) has been lost") // Use "invalid" as the return executor id to indicate the block manager that // re-registration failed. 
It's a bit hacky but fine since the returned block @@ -1057,26 +1059,30 @@ private[spark] class BlockManagerInfo( _blocks.put(blockId, blockStatus) _remainingMem -= memSize if (blockExists) { - logInfo(s"Updated $blockId in memory on ${blockManagerId.hostPort}" + - s" (current size: ${Utils.bytesToString(memSize)}," + - s" original size: ${Utils.bytesToString(originalMemSize)}," + - s" free: ${Utils.bytesToString(_remainingMem)})") + logInfo(log"Updated ${MDC(BLOCK_ID, blockId)} in memory on " + + log"${MDC(HOST_PORT, blockManagerId.hostPort)} (current size: " + + log"${MDC(CURRENT_MEMORY_SIZE, Utils.bytesToString(memSize))}, original " + + log"size: ${MDC(ORIGINAL_MEMORY_SIZE, Utils.bytesToString(originalMemSize))}, " + + log"free: ${MDC(FREE_MEMORY_SIZE, Utils.bytesToString(_remainingMem))})") } else { - logInfo(s"Added $blockId in memory on ${blockManagerId.hostPort}" + - s" (size: ${Utils.bytesToString(memSize)}," + - s" free: ${Utils.bytesToString(_remainingMem)})") + logInfo(log"Added ${MDC(BLOCK_ID, blockId)} in memory on " + + log"${MDC(HOST_PORT, blockManagerId.hostPort)} " + + log"(size: ${MDC(CURRENT_MEMORY_SIZE, Utils.bytesToString(memSize))}, " + + log"free: ${MDC(FREE_MEMORY_SIZE, Utils.bytesToString(_remainingMem))})") } } if (storageLevel.useDisk) { blockStatus = BlockStatus(storageLevel, memSize = 0, diskSize = diskSize) _blocks.put(blockId, blockStatus) if (blockExists) { - logInfo(s"Updated $blockId on disk on ${blockManagerId.hostPort}" + - s" (current size: ${Utils.bytesToString(diskSize)}," + - s" original size: ${Utils.bytesToString(originalDiskSize)})") + logInfo(log"Updated ${MDC(BLOCK_ID, blockId)} on disk on " + + log"${MDC(HOST_PORT, blockManagerId.hostPort)} " + + log"(current size: ${MDC(CURRENT_DISK_SIZE, Utils.bytesToString(diskSize))}," + + log" original size: ${MDC(ORIGINAL_DISK_SIZE, Utils.bytesToString(originalDiskSize))})") } else { - logInfo(s"Added $blockId on disk on ${blockManagerId.hostPort}" + - s" (size: ${Utils.bytesToString(diskSize)})") + logInfo(log"Added ${MDC(BLOCK_ID, blockId)} on disk on " + + log"${MDC(HOST_PORT, blockManagerId.hostPort)} (size: " + + log"${MDC(CURRENT_DISK_SIZE, Utils.bytesToString(diskSize))})") } } @@ -1092,13 +1098,15 @@ private[spark] class BlockManagerInfo( blockStatus.remove(blockId) } if (originalLevel.useMemory) { - logInfo(s"Removed $blockId on ${blockManagerId.hostPort} in memory" + - s" (size: ${Utils.bytesToString(originalMemSize)}," + - s" free: ${Utils.bytesToString(_remainingMem)})") + logInfo(log"Removed ${MDC(BLOCK_ID, blockId)} on " + + log"${MDC(HOST_PORT, blockManagerId.hostPort)} in memory " + + log"(size: ${MDC(ORIGINAL_MEMORY_SIZE, Utils.bytesToString(originalMemSize))}, " + + log"free: ${MDC(FREE_MEMORY_SIZE, Utils.bytesToString(_remainingMem))})") } if (originalLevel.useDisk) { - logInfo(s"Removed $blockId on ${blockManagerId.hostPort} on disk" + - s" (size: ${Utils.bytesToString(originalDiskSize)})") + logInfo(log"Removed ${MDC(BLOCK_ID, blockId)} on " + + log"${MDC(HOST_PORT, blockManagerId.hostPort)} on disk" + + log" (size: ${MDC(ORIGINAL_DISK_SIZE, Utils.bytesToString(originalDiskSize))})") } } } diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala index 059bb52714106..8a3ca3066961c 100644 --- a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala +++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala @@ -342,7 +342,8 @@ private class 
PushBasedFetchHelper( // Fallback for all the pending fetch requests val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address) pendingShuffleChunks.foreach { pendingBlockId => - logInfo(s"Falling back immediately for shuffle chunk $pendingBlockId") + logInfo( + log"Falling back immediately for shuffle chunk ${MDC(BLOCK_ID, pendingBlockId)}") shuffleMetrics.incMergedFetchFallbackCount(1) val bitmapOfPendingChunk: RoaringBitmap = chunksMetaMap.remove(pendingBlockId).get chunkBitmap.or(bitmapOfPendingChunk) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index b89d24f4c6ddc..ff1799d8ff3e1 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -342,8 +342,8 @@ final class ShuffleBlockFetcherIterator( if (isNettyOOMOnShuffle.compareAndSet(false, true)) { // The fetcher can fail remaining blocks in batch for the same error. So we only // log the warning once to avoid flooding the logs. - logInfo(s"Block $blockId has failed $failureTimes times " + - s"due to Netty OOM, will retry") + logInfo(log"Block ${MDC(BLOCK_ID, blockId)} has failed " + + log"${MDC(FAILURES, failureTimes)} times due to Netty OOM, will retry") } remainingBlocks -= blockId deferredBlocks += blockId @@ -448,14 +448,17 @@ final class ShuffleBlockFetcherIterator( s"the number of host-local blocks ${numHostLocalBlocks} " + s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " + s"+ the number of remote blocks ${numRemoteBlocks} ") - logInfo(s"Getting $blocksToFetchCurrentIteration " + - s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " + - s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " + - s"${numHostLocalBlocks} (${Utils.bytesToString(hostLocalBlockBytes)}) " + - s"host-local and ${pushMergedLocalBlocks.size} " + - s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " + - s"push-merged-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " + - s"remote blocks") + logInfo( + log"Getting ${MDC(NUM_BLOCKS, blocksToFetchCurrentIteration)} " + + log"(${MDC(TOTAL_SIZE, Utils.bytesToString(totalBytes))}) non-empty blocks including " + + log"${MDC(NUM_LOCAL_BLOCKS, localBlocks.size)} " + + log"(${MDC(LOCAL_BLOCKS_SIZE, Utils.bytesToString(localBlockBytes))}) local and " + + log"${MDC(NUM_HOST_LOCAL_BLOCKS, numHostLocalBlocks)} " + + log"(${MDC(HOST_LOCAL_BLOCKS_SIZE, Utils.bytesToString(hostLocalBlockBytes))}) " + + log"host-local and ${MDC(NUM_PUSH_MERGED_LOCAL_BLOCKS, pushMergedLocalBlocks.size)} " + + log"(${MDC(PUSH_MERGED_LOCAL_BLOCKS_SIZE, Utils.bytesToString(pushMergedLocalBlockBytes))})" + + log" push-merged-local and ${MDC(NUM_REMOTE_BLOCKS, numRemoteBlocks)} " + + log"(${MDC(REMOTE_BLOCKS_SIZE, Utils.bytesToString(remoteBlockBytes))}) remote blocks") this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values .flatMap { infos => infos.map(info => (info._1, info._3)) } collectedRemoteRequests @@ -719,8 +722,10 @@ final class ShuffleBlockFetcherIterator( val numDeferredRequest = deferredFetchRequests.values.map(_.size).sum val numFetches = remoteRequests.size - fetchRequests.size - numDeferredRequest - logInfo(s"Started $numFetches remote fetches in ${Utils.getUsedTimeNs(startTimeNs)}" + - (if (numDeferredRequest > 0 ) s", deferred $numDeferredRequest requests" else "")) + 
logInfo(log"Started ${MDC(COUNT, numFetches)} remote fetches in " + + log"${MDC(DURATION, Utils.getUsedTimeNs(startTimeNs))}" + + (if (numDeferredRequest > 0) log", deferred ${MDC(NUM_REQUESTS, numDeferredRequest)} requests" + else log"")) // Get Local Blocks fetchLocalBlocks(localBlocks) @@ -1141,7 +1146,8 @@ final class ShuffleBlockFetcherIterator( case otherCause => s"Block $blockId is corrupted due to $otherCause" } - logInfo(s"Finished corruption diagnosis in $duration ms. $diagnosisResponse") + logInfo(log"Finished corruption diagnosis in ${MDC(DURATION, duration)} ms. " + + log"${MDC(STATUS, diagnosisResponse)}") diagnosisResponse case shuffleBlockChunk: ShuffleBlockChunkId => // TODO SPARK-36284 Add shuffle checksum support for push-based shuffle @@ -1277,7 +1283,8 @@ final class ShuffleBlockFetcherIterator( originalLocalBlocks, originalHostLocalBlocksByExecutor, originalMergedLocalBlocks) // Add the remote requests into our queue in a random order fetchRequests ++= Utils.randomize(originalRemoteReqs) - logInfo(s"Created ${originalRemoteReqs.size} fallback remote requests for push-merged") + logInfo(log"Created ${MDC(NUM_REQUESTS, originalRemoteReqs.size)} fallback remote requests " + + log"for push-merged") // fetch all the fallback blocks that are local. fetchLocalBlocks(originalLocalBlocks) // Merged local blocks should be empty during fallback diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 1283b9340a455..6746bbd490c42 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -117,7 +117,8 @@ private[spark] class MemoryStore( log"needed to store a block in memory. Please configure Spark with more memory.") } - logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory))) + logInfo(log"MemoryStore started with capacity " + + log"${MDC(MEMORY_SIZE, Utils.bytesToString(maxMemory))}") /** Total storage memory used including unroll memory, in bytes. */ private def memoryUsed: Long = memoryManager.storageMemoryUsed @@ -158,8 +159,9 @@ private[spark] class MemoryStore( entries.synchronized { entries.put(blockId, entry) } - logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format( - blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) + logInfo(log"Block ${MDC(BLOCK_ID, blockId)} stored as bytes in memory " + + log"(estimated size ${MDC(SIZE, Utils.bytesToString(size))}, " + + log"free ${MDC(MEMORY_SIZE, Utils.bytesToString(maxMemory - blocksMemoryUsed))})") true } else { false @@ -250,7 +252,8 @@ private[spark] class MemoryStore( // SPARK-45025 - if a thread interrupt was received, we log a warning and return used memory // to avoid getting killed by task reaper eventually. if (shouldCheckThreadInterruption && Thread.currentThread().isInterrupted) { - logInfo(s"Failed to unroll block=$blockId since thread interrupt was received") + logInfo( + log"Failed to unroll block=${MDC(BLOCK_ID, blockId)} since thread interrupt was received") Left(unrollMemoryUsedByThisBlock) } else if (keepUnrolling) { // Make sure that we have enough memory to store the block. 
By this point, it is possible that @@ -279,8 +282,9 @@ private[spark] class MemoryStore( entries.put(blockId, entry) } - logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format(blockId, - Utils.bytesToString(entry.size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) + logInfo(log"Block ${MDC(BLOCK_ID, blockId)} stored as values in memory " + + log"(estimated size ${MDC(MEMORY_SIZE, Utils.bytesToString(entry.size))}, free " + + log"${MDC(FREE_MEMORY_SIZE, Utils.bytesToString(maxMemory - blocksMemoryUsed))})") Right(entry.size) } else { // We ran out of space while unrolling the values for this block @@ -521,8 +525,8 @@ private[spark] class MemoryStore( if (freedMemory >= space) { var lastSuccessfulBlock = -1 try { - logInfo(s"${selectedBlocks.size} blocks selected for dropping " + - s"(${Utils.bytesToString(freedMemory)} bytes)") + logInfo(log"${MDC(NUM_BLOCKS, selectedBlocks.size)} blocks selected for dropping " + + log"(${MDC(MEMORY_SIZE, Utils.bytesToString(freedMemory))} bytes)") selectedBlocks.indices.foreach { idx => val blockId = selectedBlocks(idx) val entry = entries.synchronized { @@ -537,8 +541,9 @@ private[spark] class MemoryStore( } lastSuccessfulBlock = idx } - logInfo(s"After dropping ${selectedBlocks.size} blocks, " + - s"free memory is ${Utils.bytesToString(maxMemory - blocksMemoryUsed)}") + logInfo( + log"After dropping ${MDC(NUM_BLOCKS, selectedBlocks.size)} blocks, free memory is" + + log"${MDC(FREE_MEMORY_SIZE, Utils.bytesToString(maxMemory - blocksMemoryUsed))}") freedMemory } finally { // like BlockManager.doPut, we use a finally rather than a catch to avoid having to deal @@ -553,7 +558,7 @@ private[spark] class MemoryStore( } } else { blockId.foreach { id => - logInfo(s"Will not store $id") + logInfo(log"Will not store ${MDC(BLOCK_ID, id)}") } selectedBlocks.foreach { id => blockInfoManager.unlock(id) @@ -649,11 +654,11 @@ private[spark] class MemoryStore( */ private def logMemoryUsage(): Unit = { logInfo( - s"Memory use = ${Utils.bytesToString(blocksMemoryUsed)} (blocks) + " + - s"${Utils.bytesToString(currentUnrollMemory)} (scratch space shared across " + - s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(memoryUsed)}. " + - s"Storage limit = ${Utils.bytesToString(maxMemory)}." - ) + log"Memory use = ${MDC(CURRENT_MEMORY_SIZE, Utils.bytesToString(blocksMemoryUsed))} " + + log"(blocks) + ${MDC(FREE_MEMORY_SIZE, Utils.bytesToString(currentUnrollMemory))} " + + log"(scratch space shared across ${MDC(NUM_TASKS, numTasksUnrolling)} " + + log"tasks(s)) = ${MDC(STORAGE_MEMORY_SIZE, Utils.bytesToString(memoryUsed))}. " + + log"Storage limit = ${MDC(MAX_MEMORY_SIZE, Utils.bytesToString(maxMemory))}.") } /** From 416d7f24fc354e912773ceb160210ad6a0c5fe99 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 24 May 2024 20:53:00 -0700 Subject: [PATCH 24/45] [SPARK-48239][INFRA][FOLLOWUP] install the missing `jq` library ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/46534 . We missed the `jq` library which is needed to create git tags. ### Why are the changes needed? fix bug ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? manual ### Was this patch authored or co-authored using generative AI tooling? no Closes #46743 from cloud-fan/script. 
Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- dev/create-release/release-util.sh | 3 +++ dev/create-release/spark-rm/Dockerfile | 1 + 2 files changed, 4 insertions(+) diff --git a/dev/create-release/release-util.sh b/dev/create-release/release-util.sh index 0394fb49c2fa0..b5edbf40d487d 100755 --- a/dev/create-release/release-util.sh +++ b/dev/create-release/release-util.sh @@ -128,6 +128,9 @@ function get_release_info { RC_COUNT=1 fi + if [ "$GIT_BRANCH" = "master" ]; then + RELEASE_VERSION="$RELEASE_VERSION-preview1" + fi export NEXT_VERSION export RELEASE_VERSION=$(read_config "Release" "$RELEASE_VERSION") diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile index adaa4df3f5791..5fdaf58feee2e 100644 --- a/dev/create-release/spark-rm/Dockerfile +++ b/dev/create-release/spark-rm/Dockerfile @@ -58,6 +58,7 @@ RUN apt-get update && apt-get install -y \ texinfo \ texlive-latex-extra \ qpdf \ + jq \ r-base \ ruby \ ruby-dev \ From 468aa842c6435b3c3ff49df30e8958d08ec2edb0 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Sat, 25 May 2024 09:43:01 -0700 Subject: [PATCH 25/45] [SPARK-48320][CORE][DOCS] Add structured logging guide to the scala and java doc ### What changes were proposed in this pull request? The pr aims to add `external third-party ecosystem access` guide to the `scala/java` doc. The external third-party ecosystem is very extensive. Currently, the document covers two scenarios: - Pure java (for example, an application only uses the java language - many of our internal production applications are like this) - java + scala ### Why are the changes needed? Provide instructions for external third-party ecosystem access to the structured log framework. ### Does this PR introduce _any_ user-facing change? Yes, When an external third-party ecosystem wants to access the structured log framework, developers can get help through this document. ### How was this patch tested? - Add new UT. - Manually test. - Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46634 from panbingkun/SPARK-48320. Lead-authored-by: panbingkun Co-authored-by: panbingkun Signed-off-by: Gengliang Wang --- .../apache/spark/internal/SparkLogger.java | 43 +++++++++++++++ .../org/apache/spark/internal/LogKey.scala | 30 +++++++++- .../org/apache/spark/internal/Logging.scala | 38 +++++++++++++ .../scala/org/apache/spark/internal/README.md | 47 ---------------- .../spark/util/PatternSparkLoggerSuite.java | 9 ++- .../spark/util/SparkLoggerSuiteBase.java | 55 +++++++++++++++---- .../util/StructuredSparkLoggerSuite.java | 20 ++++++- .../spark/util/PatternLoggingSuite.scala | 4 +- .../spark/util/StructuredLoggingSuite.scala | 29 +++++----- 9 files changed, 193 insertions(+), 82 deletions(-) delete mode 100644 common/utils/src/main/scala/org/apache/spark/internal/README.md diff --git a/common/utils/src/main/java/org/apache/spark/internal/SparkLogger.java b/common/utils/src/main/java/org/apache/spark/internal/SparkLogger.java index bf8adb70637e2..32dd8f1f26b58 100644 --- a/common/utils/src/main/java/org/apache/spark/internal/SparkLogger.java +++ b/common/utils/src/main/java/org/apache/spark/internal/SparkLogger.java @@ -28,6 +28,49 @@ import org.slf4j.Logger; // checkstyle.on: RegexpSinglelineJava +// checkstyle.off: RegexpSinglelineJava +/** + * Guidelines for the Structured Logging Framework - Java Logging + *

+ * + * Use the `org.apache.spark.internal.SparkLoggerFactory` to get the logger instance in Java code: + * Getting Logger Instance: + * Instead of using `org.slf4j.LoggerFactory`, use `org.apache.spark.internal.SparkLoggerFactory` + * to ensure structured logging. + *

+ * + * import org.apache.spark.internal.SparkLogger; + * import org.apache.spark.internal.SparkLoggerFactory; + * private static final SparkLogger logger = SparkLoggerFactory.getLogger(JavaUtils.class); + *

+ * + * Logging Messages with Variables: + * When logging messages with variables, wrap all the variables with `MDC`s and they will be + * automatically added to the Mapped Diagnostic Context (MDC). + *

+ * + * import org.apache.spark.internal.LogKeys; + * import org.apache.spark.internal.MDC; + * logger.error("Unable to delete file for partition {}", MDC.of(LogKeys.PARTITION_ID$.MODULE$, i)); + *

+ * + * Constant String Messages: + * For logging constant string messages, use the standard logging methods. + *

+ * + * logger.error("Failed to abort the writer after failing to write map output.", e); + *

+ * + * If you want to output logs in `java code` through the structured log framework, + * you can define `custom LogKey` and use it in `java` code as follows: + *

+ * + * // To add a `custom LogKey`, implement `LogKey` + * public static class CUSTOM_LOG_KEY implements LogKey { } + * import org.apache.spark.internal.MDC; + * logger.error("Unable to delete key {} for cache", MDC.of(CUSTOM_LOG_KEY, "key")); + */ +// checkstyle.on: RegexpSinglelineJava public class SparkLogger { private static final MessageFactory MESSAGE_FACTORY = ParameterizedMessageFactory.INSTANCE; diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala index 534f009119226..1366277827f75 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala @@ -20,9 +20,37 @@ import java.util.Locale /** * All structured logging `keys` used in `MDC` must be extends `LogKey` + *

+ * + * `LogKey`s serve as identifiers for mapped diagnostic contexts (MDC) within logs. + * Follow these guidelines when adding a new LogKey: + *

    + *
  • + * Define all structured logging keys in `LogKey.scala`, and sort them alphabetically for + * ease of search. + *
  • + *
  • + * Use `UPPER_SNAKE_CASE` for key names. + *
  • + *
  • + * Key names should be both simple and broad, yet include specific identifiers like `STAGE_ID`, + * `TASK_ID`, and `JOB_ID` when needed for clarity. For instance, use `MAX_ATTEMPTS` as a + * general key instead of creating separate keys for each scenario such as + * `EXECUTOR_STATE_SYNC_MAX_ATTEMPTS` and `MAX_TASK_FAILURES`. + * This balances simplicity with the detail needed for effective logging. + *
  • + *
  • + * Use abbreviations in names if they are widely understood, + * such as `APP_ID` for APPLICATION_ID, and `K8S` for KUBERNETES. + *
  • + *
  • + * For time-related keys, use milliseconds as the unit of time. + *
  • + *
*/ trait LogKey { - val name: String = this.toString.toLowerCase(Locale.ROOT) + private lazy val _name: String = getClass.getSimpleName.stripSuffix("$").toLowerCase(Locale.ROOT) + def name: String = _name } /** diff --git a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala index 2ea61358b6add..72c7cdfa62362 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala @@ -29,6 +29,44 @@ import org.slf4j.{Logger, LoggerFactory} import org.apache.spark.internal.Logging.SparkShellLoggingFilter import org.apache.spark.util.SparkClassUtils +/** + * Guidelines for the Structured Logging Framework - Scala Logging + *

+ * + * Use the `org.apache.spark.internal.Logging` trait for logging in Scala code: + * Logging Messages with Variables: + * When logging a message with variables, wrap all the variables with `MDC`s and they will be + * automatically added to the Mapped Diagnostic Context (MDC). + * This allows for structured logging and better log analysis. + *

+ * + * logInfo(log"Trying to recover app: ${MDC(LogKeys.APP_ID, app.id)}") + *

+ * + * Constant String Messages: + * If you are logging a constant string message, use the log methods that accept a constant + * string. + *

+ * + * logInfo("StateStore stopped") + *

+ * + * Exceptions: + * To ensure logs are compatible with Spark SQL and log analysis tools, avoid + * `Exception.printStackTrace()`. Use `logError`, `logWarning`, and `logInfo` methods from + * the `Logging` trait to log exceptions, maintaining structured and parsable logs. + *

+ * + * If you want to output logs in `scala code` through the structured log framework, + * you can define `custom LogKey` and use it in `scala` code as follows: + *

+ * + * // To add a `custom LogKey`, implement `LogKey` + * case object CUSTOM_LOG_KEY extends LogKey + * import org.apache.spark.internal.MDC; + * logInfo(log"${MDC(CUSTOM_LOG_KEY, "key")}") + */ + /** * Mapped Diagnostic Context (MDC) that will be used in log messages. * The values of the MDC will be inline in the log message, while the key-value pairs will be diff --git a/common/utils/src/main/scala/org/apache/spark/internal/README.md b/common/utils/src/main/scala/org/apache/spark/internal/README.md deleted file mode 100644 index 28d2794851870..0000000000000 --- a/common/utils/src/main/scala/org/apache/spark/internal/README.md +++ /dev/null @@ -1,47 +0,0 @@ -# Guidelines for the Structured Logging Framework - -## Scala Logging -Use the `org.apache.spark.internal.Logging` trait for logging in Scala code: -* **Logging Messages with Variables**: When logging a message with variables, wrap all the variables with `MDC`s and they will be automatically added to the Mapped Diagnostic Context (MDC). This allows for structured logging and better log analysis. -```scala -logInfo(log"Trying to recover app: ${MDC(LogKeys.APP_ID, app.id)}") -``` -* **Constant String Messages**: If you are logging a constant string message, use the log methods that accept a constant string. -```scala -logInfo("StateStore stopped") -``` - -## Java Logging -Use the `org.apache.spark.internal.SparkLoggerFactory` to get the logger instance in Java code: -* **Getting Logger Instance**: Instead of using `org.slf4j.LoggerFactory`, use `org.apache.spark.internal.SparkLoggerFactory` to ensure structured logging. -```java -import org.apache.spark.internal.SparkLogger; -import org.apache.spark.internal.SparkLoggerFactory; - -private static final SparkLogger logger = SparkLoggerFactory.getLogger(JavaUtils.class); -``` -* **Logging Messages with Variables**: When logging messages with variables, wrap all the variables with `MDC`s and they will be automatically added to the Mapped Diagnostic Context (MDC). -```java -import org.apache.spark.internal.LogKeys; -import org.apache.spark.internal.MDC; - -logger.error("Unable to delete file for partition {}", MDC.of(LogKeys.PARTITION_ID$.MODULE$, i)); -``` - -* **Constant String Messages**: For logging constant string messages, use the standard logging methods. -```java -logger.error("Failed to abort the writer after failing to write map output.", e); -``` - -## LogKey - -`LogKey`s serve as identifiers for mapped diagnostic contexts (MDC) within logs. Follow these guidelines when adding a new LogKey: -* Define all structured logging keys in `LogKey.scala`, and sort them alphabetically for ease of search. -* Use `UPPER_SNAKE_CASE` for key names. -* Key names should be both simple and broad, yet include specific identifiers like `STAGE_ID`, `TASK_ID`, and `JOB_ID` when needed for clarity. For instance, use `MAX_ATTEMPTS` as a general key instead of creating separate keys for each scenario such as `EXECUTOR_STATE_SYNC_MAX_ATTEMPTS` and `MAX_TASK_FAILURES`. This balances simplicity with the detail needed for effective logging. -* Use abbreviations in names if they are widely understood, such as `APP_ID` for APPLICATION_ID, and `K8S` for KUBERNETES. -* For time-related keys, use milliseconds as the unit of time. - -## Exceptions - -To ensure logs are compatible with Spark SQL and log analysis tools, avoid `Exception.printStackTrace()`. Use `logError`, `logWarning`, and `logInfo` methods from the `Logging` trait to log exceptions, maintaining structured and parsable logs. 
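
Putting the guidelines above together with the call-site changes earlier in this patch set, a single migrated log call looks roughly like the sketch below. This is only an illustration, not part of the patch: the class and method names are hypothetical, while `Logging`, `MDC`, the `log"..."` interpolator, and the `LogKeys.EXECUTOR_ID` / `LogKeys.REASON` entries are the ones actually used throughout these diffs.

```scala
import org.apache.spark.internal.{Logging, LogKeys, MDC}

// Illustrative class name only; any class that mixes in the Logging trait works the same way.
class ExampleBackend extends Logging {
  def reportRemoval(executorId: String, reason: String): Unit = {
    // Before: plain string interpolation, which carries no structured context.
    // logInfo(s"Executor $executorId removed: $reason")

    // After: the log"..." interpolator wraps each variable in an MDC, so the values
    // are emitted as key/value pairs when structured logging is enabled, while the
    // human-readable message stays the same.
    logInfo(log"Executor ${MDC(LogKeys.EXECUTOR_ID, executorId)} removed: " +
      log"${MDC(LogKeys.REASON, reason)}")
  }
}
```

The Java-side equivalent follows the `SparkLogger` javadoc above: obtain a logger from `SparkLoggerFactory` and pass `MDC.of(LogKeys...., value)` arguments instead of formatting the values into the message string.
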
diff --git a/common/utils/src/test/java/org/apache/spark/util/PatternSparkLoggerSuite.java b/common/utils/src/test/java/org/apache/spark/util/PatternSparkLoggerSuite.java index 2d370bad4cc80..1d2e6d76a7590 100644 --- a/common/utils/src/test/java/org/apache/spark/util/PatternSparkLoggerSuite.java +++ b/common/utils/src/test/java/org/apache/spark/util/PatternSparkLoggerSuite.java @@ -84,7 +84,12 @@ String expectedPatternForMsgWithMDCValueIsNull(Level level) { } @Override - String expectedPatternForExternalSystemCustomLogKey(Level level) { - return toRegexPattern(level, ".* : External system custom log message.\n"); + String expectedPatternForScalaCustomLogKey(Level level) { + return toRegexPattern(level, ".* : Scala custom log message.\n"); + } + + @Override + String expectedPatternForJavaCustomLogKey(Level level) { + return toRegexPattern(level, ".* : Java custom log message.\n"); } } diff --git a/common/utils/src/test/java/org/apache/spark/util/SparkLoggerSuiteBase.java b/common/utils/src/test/java/org/apache/spark/util/SparkLoggerSuiteBase.java index 0869f9827324d..90677b521640f 100644 --- a/common/utils/src/test/java/org/apache/spark/util/SparkLoggerSuiteBase.java +++ b/common/utils/src/test/java/org/apache/spark/util/SparkLoggerSuiteBase.java @@ -26,9 +26,10 @@ import org.apache.logging.log4j.Level; import org.junit.jupiter.api.Test; -import org.apache.spark.internal.SparkLogger; +import org.apache.spark.internal.LogKey; import org.apache.spark.internal.LogKeys; import org.apache.spark.internal.MDC; +import org.apache.spark.internal.SparkLogger; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -68,8 +69,11 @@ private String basicMsg() { private final MDC executorIDMDCValueIsNull = MDC.of(LogKeys.EXECUTOR_ID$.MODULE$, null); - private final MDC externalSystemCustomLog = - MDC.of(CustomLogKeys.CUSTOM_LOG_KEY$.MODULE$, "External system custom log message."); + private final MDC scalaCustomLogMDC = + MDC.of(CustomLogKeys.CUSTOM_LOG_KEY$.MODULE$, "Scala custom log message."); + + private final MDC javaCustomLogMDC = + MDC.of(JavaCustomLogKeys.CUSTOM_LOG_KEY, "Java custom log message."); // test for basic message (without any mdc) abstract String expectedPatternForBasicMsg(Level level); @@ -89,8 +93,11 @@ private String basicMsg() { // test for message (with mdc - the value is null) abstract String expectedPatternForMsgWithMDCValueIsNull(Level level); - // test for external system custom LogKey - abstract String expectedPatternForExternalSystemCustomLogKey(Level level); + // test for scala custom LogKey + abstract String expectedPatternForScalaCustomLogKey(Level level); + + // test for java custom LogKey + abstract String expectedPatternForJavaCustomLogKey(Level level); @Test public void testBasicMsgLogger() { @@ -142,8 +149,6 @@ public void testLoggerWithMDC() { Runnable errorFn = () -> logger().error(msgWithMDC, executorIDMDC); Runnable warnFn = () -> logger().warn(msgWithMDC, executorIDMDC); Runnable infoFn = () -> logger().info(msgWithMDC, executorIDMDC); - Runnable debugFn = () -> logger().debug(msgWithMDC, executorIDMDC); - Runnable traceFn = () -> logger().trace(msgWithMDC, executorIDMDC); List.of( Pair.of(Level.ERROR, errorFn), Pair.of(Level.WARN, warnFn), @@ -213,20 +218,46 @@ public void testLoggerWithMDCValueIsNull() { } @Test - public void testLoggerWithExternalSystemCustomLogKey() { - Runnable errorFn = () -> logger().error("{}", externalSystemCustomLog); - Runnable warnFn = () -> logger().warn("{}", externalSystemCustomLog); - Runnable infoFn = () -> 
logger().info("{}", externalSystemCustomLog); + public void testLoggerWithScalaCustomLogKey() { + Runnable errorFn = () -> logger().error("{}", scalaCustomLogMDC); + Runnable warnFn = () -> logger().warn("{}", scalaCustomLogMDC); + Runnable infoFn = () -> logger().info("{}", scalaCustomLogMDC); + List.of( + Pair.of(Level.ERROR, errorFn), + Pair.of(Level.WARN, warnFn), + Pair.of(Level.INFO, infoFn)).forEach(pair -> { + try { + assertTrue(captureLogOutput(pair.getRight()).matches( + expectedPatternForScalaCustomLogKey(pair.getLeft()))); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void testLoggerWithJavaCustomLogKey() { + Runnable errorFn = () -> logger().error("{}", javaCustomLogMDC); + Runnable warnFn = () -> logger().warn("{}", javaCustomLogMDC); + Runnable infoFn = () -> logger().info("{}", javaCustomLogMDC); List.of( Pair.of(Level.ERROR, errorFn), Pair.of(Level.WARN, warnFn), Pair.of(Level.INFO, infoFn)).forEach(pair -> { try { assertTrue(captureLogOutput(pair.getRight()).matches( - expectedPatternForExternalSystemCustomLogKey(pair.getLeft()))); + expectedPatternForJavaCustomLogKey(pair.getLeft()))); } catch (IOException e) { throw new RuntimeException(e); } }); } } + +class JavaCustomLogKeys { + // Custom `LogKey` must be `implements LogKey` + public static class CUSTOM_LOG_KEY implements LogKey { } + + // Singleton + public static final CUSTOM_LOG_KEY CUSTOM_LOG_KEY = new CUSTOM_LOG_KEY(); +} diff --git a/common/utils/src/test/java/org/apache/spark/util/StructuredSparkLoggerSuite.java b/common/utils/src/test/java/org/apache/spark/util/StructuredSparkLoggerSuite.java index 416f0b6172c00..ec19014e117ce 100644 --- a/common/utils/src/test/java/org/apache/spark/util/StructuredSparkLoggerSuite.java +++ b/common/utils/src/test/java/org/apache/spark/util/StructuredSparkLoggerSuite.java @@ -149,14 +149,28 @@ String expectedPatternForMsgWithMDCValueIsNull(Level level) { } @Override - String expectedPatternForExternalSystemCustomLogKey(Level level) { + String expectedPatternForScalaCustomLogKey(Level level) { return compactAndToRegexPattern(level, """ { "ts": "", "level": "", - "msg": "External system custom log message.", + "msg": "Scala custom log message.", "context": { - "custom_log_key": "External system custom log message." + "custom_log_key": "Scala custom log message." + }, + "logger": "" + }"""); + } + + @Override + String expectedPatternForJavaCustomLogKey(Level level) { + return compactAndToRegexPattern(level, """ + { + "ts": "", + "level": "", + "msg": "Java custom log message.", + "context": { + "custom_log_key": "Java custom log message." 
}, "logger": "" }"""); diff --git a/common/utils/src/test/scala/org/apache/spark/util/PatternLoggingSuite.scala b/common/utils/src/test/scala/org/apache/spark/util/PatternLoggingSuite.scala index 3baa720f38a90..ab9803d83bf62 100644 --- a/common/utils/src/test/scala/org/apache/spark/util/PatternLoggingSuite.scala +++ b/common/utils/src/test/scala/org/apache/spark/util/PatternLoggingSuite.scala @@ -47,8 +47,8 @@ class PatternLoggingSuite extends LoggingSuiteBase with BeforeAndAfterAll { override def expectedPatternForMsgWithMDCAndException(level: Level): String = s""".*$level $className: Error in executor 1.\njava.lang.RuntimeException: OOM\n[\\s\\S]*""" - override def expectedPatternForExternalSystemCustomLogKey(level: Level): String = { - s""".*$level $className: External system custom log message.\n""" + override def expectedPatternForCustomLogKey(level: Level): String = { + s""".*$level $className: Custom log message.\n""" } override def verifyMsgWithConcat(level: Level, logOutput: String): Unit = { diff --git a/common/utils/src/test/scala/org/apache/spark/util/StructuredLoggingSuite.scala b/common/utils/src/test/scala/org/apache/spark/util/StructuredLoggingSuite.scala index 2152b57524d72..694f06706421a 100644 --- a/common/utils/src/test/scala/org/apache/spark/util/StructuredLoggingSuite.scala +++ b/common/utils/src/test/scala/org/apache/spark/util/StructuredLoggingSuite.scala @@ -78,8 +78,8 @@ trait LoggingSuiteBase // test for message and exception def expectedPatternForMsgWithMDCAndException(level: Level): String - // test for external system custom LogKey - def expectedPatternForExternalSystemCustomLogKey(level: Level): String + // test for custom LogKey + def expectedPatternForCustomLogKey(level: Level): String def verifyMsgWithConcat(level: Level, logOutput: String): Unit @@ -146,18 +146,17 @@ trait LoggingSuiteBase } } - private val externalSystemCustomLog = - log"${MDC(CustomLogKeys.CUSTOM_LOG_KEY, "External system custom log message.")}" - test("Logging with external system custom LogKey") { + private val customLog = log"${MDC(CustomLogKeys.CUSTOM_LOG_KEY, "Custom log message.")}" + test("Logging with custom LogKey") { Seq( - (Level.ERROR, () => logError(externalSystemCustomLog)), - (Level.WARN, () => logWarning(externalSystemCustomLog)), - (Level.INFO, () => logInfo(externalSystemCustomLog)), - (Level.DEBUG, () => logDebug(externalSystemCustomLog)), - (Level.TRACE, () => logTrace(externalSystemCustomLog))).foreach { + (Level.ERROR, () => logError(customLog)), + (Level.WARN, () => logWarning(customLog)), + (Level.INFO, () => logInfo(customLog)), + (Level.DEBUG, () => logDebug(customLog)), + (Level.TRACE, () => logTrace(customLog))).foreach { case (level, logFunc) => val logOutput = captureLogOutput(logFunc) - assert(expectedPatternForExternalSystemCustomLogKey(level).r.matches(logOutput)) + assert(expectedPatternForCustomLogKey(level).r.matches(logOutput)) } } @@ -261,15 +260,15 @@ class StructuredLoggingSuite extends LoggingSuiteBase { }""") } - override def expectedPatternForExternalSystemCustomLogKey(level: Level): String = { + override def expectedPatternForCustomLogKey(level: Level): String = { compactAndToRegexPattern( s""" { "ts": "", "level": "$level", - "msg": "External system custom log message.", + "msg": "Custom log message.", "context": { - "custom_log_key": "External system custom log message." + "custom_log_key": "Custom log message." 
}, "logger": "$className" }""" ) } @@ -307,6 +306,6 @@ class StructuredLoggingSuite extends LoggingSuiteBase { } object CustomLogKeys { - // External system custom LogKey must be `extends LogKey` + // Custom `LogKey` must be `extends LogKey` case object CUSTOM_LOG_KEY extends LogKey } From 541158fe03529d5a28eaeb61d801d065ff4ef664 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Sun, 26 May 2024 08:35:45 -0700 Subject: [PATCH 26/45] [SPARK-47579][CORE][PART3][FOLLOWUP] Fix KubernetesSuite ### What changes were proposed in this pull request? The pr is following up https://github.com/apache/spark/pull/46739, and aims to fix `KubernetesSuite`. 1. Unfortunately, after correcting the typo from `decommision` to `decommission`, it seems that GA has been broken. 2. https://github.com/panbingkun/spark/actions/runs/9232744348/job/25406127982 ### Why are the changes needed? Only fix `KubernetesSuite`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46746 from panbingkun/fix_KubernetesSuite. Authored-by: panbingkun Signed-off-by: Gengliang Wang --- .../spark/deploy/k8s/integrationtest/DecommissionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 1b9b5310c2ee2..ae5f037c6b7d4 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -175,7 +175,7 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => expectedDriverLogOnCompletion = Seq( "Finished waiting, stopping Spark", "Decommission executors", - "Remove reason statistics: (gracefully decommissioned: 1, decommision unfinished: 0, " + + "Remove reason statistics: (gracefully decommissioned: 1, decommission unfinished: 0, " + "driver killed: 0, unexpectedly exited: 0)."), appArgs = Array.empty[String], driverPodChecker = doBasicDriverPyPodCheck, From 4ef5ec92ef70fffa231b422c7da17c4438e95d0d Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Mon, 27 May 2024 10:28:06 +0900 Subject: [PATCH 27/45] [SPARK-48424][INFRA] Make dev/is-changed.py return true if it fails ### What changes were proposed in this pull request? This PR proposes to make dev/is-changed.py return true if it fails. ### Why are the changes needed? To make the tests robust. GitHub Actions sometimes fails to set the commit hash properly, e.g., https://github.com/apache/spark/actions/runs/9244026522/job/25435224163?pr=46747 ### Does this PR introduce _any_ user-facing change? ### How was this patch tested?
Manually tested: ```bash GITHUB_SHA=a29c9653f3d48d97875ae446d82896bdf0de61ca GITHUB_PREV_SHA=0000000000000000000000000000000000000000 ./dev/is-changed.py -m root ``` ```bash a=`GITHUB_SHA=a29c9653f3d48d97875ae446d82896bdf0de61ca GITHUB_PREV_SHA=0000000000000000000000000000000000000000 ./dev/is-changed.py -m root` echo $a ``` ```bash GITHUB_SHA=a29c9653f3d48d97875ae446d82896bdf0de61ca GITHUB_PREV_SHA=3346afd4b250c3aead5a237666d4942018a463e0 ./dev/is-changed.py -m root ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46749 from HyukjinKwon/SPARK-48424. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- dev/is-changed.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dev/is-changed.py b/dev/is-changed.py index 85f0d3cda6df4..1962e244d5dd7 100755 --- a/dev/is-changed.py +++ b/dev/is-changed.py @@ -17,6 +17,8 @@ # limitations under the License. # +import warnings +import traceback import os import sys from argparse import ArgumentParser @@ -82,4 +84,8 @@ def main(): if __name__ == "__main__": - main() + try: + main() + except Exception: + warnings.warn(f"Ignored exception:\n\n{traceback.format_exc()}") + print("true") From 8f678a402e901666e4757b4520962ce358ec7a9c Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Mon, 27 May 2024 12:34:52 +0900 Subject: [PATCH 28/45] [SPARK-48425][PYTHON][BUILD] Replace pyspark-connect with pyspark_connect for its output name ### What changes were proposed in this pull request? This PR proposes to replace `pyspark-connect` with `pyspark_connect` for its output name. ### Why are the changes needed? `setuptools` from 69.X.X changes the output name: it replaces dashes in the package name with underscores (`pyspark_connect-4.0.0.dev1.tar.gz` vs `pyspark-connect-4.0.0.dev1.tar.gz`); this appears to be https://github.com/pypa/setuptools/issues/4214. ### Does this PR introduce _any_ user-facing change? No, this package has not been released yet. ### How was this patch tested? Manually tested. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46751 from HyukjinKwon/SPARK-48425. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- dev/create-release/release-build.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 0fb16aafcbaad..cd0220db75b1a 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -126,12 +126,12 @@ if [[ "$1" == "finalize" ]]; then --repository-url https://upload.pypi.org/legacy/ \ "pyspark-$RELEASE_VERSION.tar.gz" \ "pyspark-$RELEASE_VERSION.tar.gz.asc" - svn update "pyspark-connect-$RELEASE_VERSION.tar.gz" - svn update "pyspark-connect-$RELEASE_VERSION.tar.gz.asc" + svn update "pyspark_connect-$RELEASE_VERSION.tar.gz" + svn update "pyspark_connect-$RELEASE_VERSION.tar.gz.asc" TWINE_USERNAME=spark-upload TWINE_PASSWORD="$PYPI_PASSWORD" twine upload \ --repository-url https://upload.pypi.org/legacy/ \ - "pyspark-connect-$RELEASE_VERSION.tar.gz" \ - "pyspark-connect-$RELEASE_VERSION.tar.gz.asc" + "pyspark_connect-$RELEASE_VERSION.tar.gz" \ + "pyspark_connect-$RELEASE_VERSION.tar.gz.asc" cd ..
rm -rf svn-spark echo "PySpark uploaded" @@ -314,7 +314,7 @@ if [[ "$1" == "package" ]]; then --detach-sig $PYTHON_DIST_NAME shasum -a 512 $PYTHON_DIST_NAME > $PYTHON_DIST_NAME.sha512 - PYTHON_CONNECT_DIST_NAME=pyspark-connect-$PYSPARK_VERSION.tar.gz + PYTHON_CONNECT_DIST_NAME=pyspark_connect-$PYSPARK_VERSION.tar.gz cp spark-$SPARK_VERSION-bin-$NAME/python/dist/$PYTHON_CONNECT_DIST_NAME . echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ From 90c163753426916c9dc9f4ad09baa829d8df210c Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Mon, 27 May 2024 13:57:24 +0900 Subject: [PATCH 29/45] [MINOR][INFRA] Make pure Python build compatible with other setuptools versions ### What changes were proposed in this pull request? This PR is related to https://github.com/apache/spark/pull/46751. For now, we're using low `setuptools` version so the build is not affected, but it will be broken if we happen to upgrade it at some point. This PR makes the build compatible with `_` and `-`. ### Why are the changes needed? This is a proactive action to prevent breaking the build. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manually ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46752 from HyukjinKwon/minor-build. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .github/workflows/build_python_connect.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_python_connect.yml b/.github/workflows/build_python_connect.yml index 639b0d0843142..2ad56f5c90003 100644 --- a/.github/workflows/build_python_connect.yml +++ b/.github/workflows/build_python_connect.yml @@ -70,7 +70,7 @@ jobs: cd python python packaging/connect/setup.py sdist cd dist - pip install pyspark-connect-*.tar.gz + pip install pyspark*connect-*.tar.gz pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' torch torchvision torcheval deepspeed unittest-xml-reporting - name: Run tests env: From 49da3a43f7ab41dad59e7deb810974902625c41c Mon Sep 17 00:00:00 2001 From: Angerszhuuuu Date: Mon, 27 May 2024 17:32:46 +0900 Subject: [PATCH 30/45] [SPARK-48340][PYTHON][FOLLOWUP] Support TimestampNTZ infer schema miss prefer_timestamp_ntz ### What changes were proposed in this pull request? Add UT ### Why are the changes needed? ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added UT ### Was this patch authored or co-authored using generative AI tooling? No Closes #46750 from AngersZhuuuu/SPARK-48340-FOLLOWUP. 
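A hedged sketch of the behavior the added unit test covers, runnable against a build that includes this fix (the SparkSession setup here is illustrative, not part of the patch): with `TIMESTAMP_NTZ` as the session timestamp type, datetimes nested inside `Row` objects should also be inferred as `timestamp_ntz`.

```python
import datetime
from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.timestampType", "TIMESTAMP_NTZ")

# Nested datetime: SPARK-48340 made schema inference propagate
# prefer_timestamp_ntz into nested structures as well.
df = spark.createDataFrame([Row(b=Row(c=datetime.datetime(1970, 1, 1, 0, 0)))])
print(df.schema["b"].dataType.simpleString())  # expected: struct<c:timestamp_ntz>
```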
Lead-authored-by: Angerszhuuuu Co-authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/tests/test_types.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index d665053d94904..80f2c0fcbc033 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -192,6 +192,7 @@ def __init__(self): Row(a=1), Row("a")(1), A(), + Row(b=Row(c=datetime.datetime(1970, 1, 1, 0, 0))), ] df = self.spark.createDataFrame([data]) @@ -214,6 +215,7 @@ def __init__(self): "struct", "struct", "struct", + "struct>", ] self.assertEqual(actual, expected) @@ -236,14 +238,25 @@ def __init__(self): Row(a=1), Row(a=1), Row(a=1), + Row(b=Row(c=datetime.datetime(1970, 1, 1, 0, 0))), ] self.assertEqual(actual, expected) with self.sql_conf({"spark.sql.timestampType": "TIMESTAMP_NTZ"}): with self.sql_conf({"spark.sql.session.timeZone": "America/Sao_Paulo"}): - df = self.spark.createDataFrame([(datetime.datetime(1970, 1, 1, 0, 0),)]) + data = [ + ( + datetime.datetime(1970, 1, 1, 0, 0), + Row(a=Row(a=datetime.datetime(1970, 1, 1, 0, 0))), + ) + ] + df = self.spark.createDataFrame(data) self.assertEqual(list(df.schema)[0].dataType.simpleString(), "timestamp_ntz") self.assertEqual(df.first()[0], datetime.datetime(1970, 1, 1, 0, 0)) + self.assertEqual( + list(df.schema)[1].dataType.simpleString(), "struct>" + ) + self.assertEqual(df.first()[1], Row(a=Row(a=datetime.datetime(1970, 1, 1, 0, 0)))) df = self.spark.createDataFrame( [ From 48a4bdb9eacb4c7a5c56812171a9093d120b98b7 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Mon, 27 May 2024 17:53:53 +0800 Subject: [PATCH 31/45] [SPARK-48427][BUILD] Upgrade `scala-parser-combinators` to 2.4 ### What changes were proposed in this pull request? This pr aims to upgrade `scala-parser-combinators` from 2.3 to 2.4 ### Why are the changes needed? This version begins to validate the build and testing for Java 21. The full release note as follows: - https://github.com/scala/scala-parser-combinators/releases/tag/v2.4.0 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #46754 from LuciferYang/SPARK-48427. 
Authored-by: yangjie01 Signed-off-by: Kent Yao --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 2 +- pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 61d7861f4469b..10d812c9fd8a4 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -250,7 +250,7 @@ scala-collection-compat_2.13/2.7.0//scala-collection-compat_2.13-2.7.0.jar scala-compiler/2.13.14//scala-compiler-2.13.14.jar scala-library/2.13.14//scala-library-2.13.14.jar scala-parallel-collections_2.13/1.0.4//scala-parallel-collections_2.13-1.0.4.jar -scala-parser-combinators_2.13/2.3.0//scala-parser-combinators_2.13-2.3.0.jar +scala-parser-combinators_2.13/2.4.0//scala-parser-combinators_2.13-2.4.0.jar scala-reflect/2.13.14//scala-reflect-2.13.14.jar scala-xml_2.13/2.2.0//scala-xml_2.13-2.2.0.jar slf4j-api/2.0.13//slf4j-api-2.0.13.jar diff --git a/pom.xml b/pom.xml index eef7237ac12f9..5b088db7b20b5 100644 --- a/pom.xml +++ b/pom.xml @@ -1151,7 +1151,7 @@ org.scala-lang.modules scala-parser-combinators_${scala.binary.version} - 2.3.0 + 2.4.0 jline From b52645652eff35345c868dc47e50b3970f3a7002 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Mon, 27 May 2024 17:55:17 +0800 Subject: [PATCH 32/45] [SPARK-48168][SQL][FOLLOWUP] Fix bitwise shifting operator's precedence ### What changes were proposed in this pull request? After consulting both the C and MySQL operator precedence docs (https://en.cppreference.com/w/c/language/operator_precedence, https://dev.mysql.com/doc/refman/8.0/en/operator-precedence.html) and doing some experiments in the Scala shell ```scala scala> 1 & 2 >> 1 val res0: Int = 1 scala> 2 >> 1 << 1 val res1: Int = 2 scala> 1 << 1 + 2 val res2: Int = 8 ``` The suitable precedence for `<<`, `>>` and `>>>` is between `+`/`-` and `&`, with left-to-right associativity. ### Why are the changes needed? bugfix ### Does this PR introduce _any_ user-facing change? No, not released yet ### How was this patch tested? new tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #46753 from yaooqinn/SPARK-48168-F.
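The corrected precedence can also be sanity-checked from PySpark; this is an illustrative sketch (assuming a build that includes this fix) mirroring the new golden-file tests.

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# '+' binds tighter than shifts, and shifts bind tighter than '&',
# evaluated left to right, matching C and MySQL.
spark.sql("SELECT 1 << 1 + 2 AS plus_over_shift").show()       # expected 8, not 4
spark.sql("SELECT 2 >> 1 << 1 AS left_to_right").show()        # expected 2, not 0
spark.sql("SELECT 1 & 2 >> 1 AS shift_over_ampersand").show()  # expected 1, not 0
```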
Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../sql/catalyst/parser/SqlBaseParser.g4 | 2 +- .../analyzer-results/bitwise.sql.out | 21 ++++++++++++++++ .../resources/sql-tests/inputs/bitwise.sql | 6 ++++- .../sql-tests/results/bitwise.sql.out | 24 +++++++++++++++++++ 4 files changed, 51 insertions(+), 2 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index f0c0adb881212..4552c17e0cf14 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -986,11 +986,11 @@ valueExpression | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary | left=valueExpression operator=(PLUS | MINUS | CONCAT_PIPE) right=valueExpression #arithmeticBinary + | left=valueExpression shiftOperator right=valueExpression #shiftExpression | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary | left=valueExpression comparisonOperator right=valueExpression #comparison - | left=valueExpression shiftOperator right=valueExpression #shiftExpression ; shiftOperator diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/bitwise.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/bitwise.sql.out index fee226c0c3411..1267a984565ad 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/bitwise.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/bitwise.sql.out @@ -418,3 +418,24 @@ select cast(null as map>), 20181117 >> 2 -- !query analysis Project [cast(null as map>) AS NULL#x, (20181117 >> 2) AS (20181117 >> 2)#x] +- OneRowRelation + + +-- !query +select 1 << 1 + 2 as plus_over_shift +-- !query analysis +Project [(1 << (1 + 2)) AS plus_over_shift#x] ++- OneRowRelation + + +-- !query +select 2 >> 1 << 1 as left_to_right +-- !query analysis +Project [((2 >> 1) << 1) AS left_to_right#x] ++- OneRowRelation + + +-- !query +select 1 & 2 >> 1 as shift_over_ampersand +-- !query analysis +Project [(1 & (2 >> 1)) AS shift_over_ampersand#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql b/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql index 5823b22ef6453..e080fdd32a4aa 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql @@ -86,4 +86,8 @@ SELECT 20181117 <<< 2; SELECT 20181117 >>>> 2; select cast(null as array>), 20181117 >> 2; select cast(null as array>), 20181117 >>> 2; -select cast(null as map>), 20181117 >> 2; \ No newline at end of file +select cast(null as map>), 20181117 >> 2; + +select 1 << 1 + 2 as plus_over_shift; -- if correct, the result is 8. otherwise, 4 +select 2 >> 1 << 1 as left_to_right; -- if correct, the result is 2. otherwise, 0 +select 1 & 2 >> 1 as shift_over_ampersand; -- if correct, the result is 1. 
otherwise, 0 diff --git a/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out b/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out index a7ebaea293bf9..7233b0d0ae499 100644 --- a/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out @@ -450,3 +450,27 @@ select cast(null as map>), 20181117 >> 2 struct>,(20181117 >> 2):int> -- !query output NULL 5045279 + + +-- !query +select 1 << 1 + 2 as plus_over_shift +-- !query schema +struct +-- !query output +8 + + +-- !query +select 2 >> 1 << 1 as left_to_right +-- !query schema +struct +-- !query output +2 + + +-- !query +select 1 & 2 >> 1 as shift_over_ampersand +-- !query schema +struct +-- !query output +1 From cc3bf36c9f22d54606f858f0f90008cff792c59d Mon Sep 17 00:00:00 2001 From: Vladimir Golubev Date: Tue, 28 May 2024 08:55:39 +0900 Subject: [PATCH 33/45] [SPARK-48432][SQL] Avoid unboxing integers in UnivocityParser ### What changes were proposed in this pull request? `tokenIndexArr` is created as an array of `java.lang.Integers`. However, it is used not only for the wrapped java parser, but also during parsing to identify the correct token index. ### Why are the changes needed? This noticeably improves CSV parsing performance ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? `testOnly org.apache.spark.sql.catalyst.csv.UnivocityParserSuite` ### Was this patch authored or co-authored using generative AI tooling? No Closes #46759 from vladimirg-db/vladimirg-db/avoid-unboxing-in-univocity-parser. Authored-by: Vladimir Golubev Signed-off-by: Hyukjin Kwon --- .../org/apache/spark/sql/catalyst/csv/UnivocityParser.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 4d95097e16816..61c2f7a5926b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -63,8 +63,7 @@ class UnivocityParser( private type ValueConverter = String => Any // This index is used to reorder parsed tokens - private val tokenIndexArr = - requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))).toArray + private val tokenIndexArr = requiredSchema.map(f => dataSchema.indexOf(f)).toArray // True if we should inform the Univocity CSV parser to select which fields to read by their // positions. Generally assigned by input configuration options, except when input column(s) have @@ -81,7 +80,8 @@ class UnivocityParser( // When to-be-parsed schema is shorter than the to-be-read data schema, we let Univocity CSV // parser select a sequence of fields for reading by their positions. if (parsedSchema.length < dataSchema.length) { - parserSetting.selectIndexes(tokenIndexArr: _*) + // Box into Integer here to avoid unboxing where `tokenIndexArr` is used during parsing + parserSetting.selectIndexes(tokenIndexArr.map(java.lang.Integer.valueOf(_)): _*) } new CsvParser(parserSetting) } From de8d96892f9212a1bd7cd1b4dfad172d0cb8cd35 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 28 May 2024 09:33:20 +0900 Subject: [PATCH 34/45] [SPARK-48370][CONNECT][FOLLOW-UP] Use JDK's Cleaner instead ### What changes were proposed in this pull request? 
This PR is a followup of https://github.com/apache/spark/pull/46683 that replaces our custom cleaner to JDK's cleaner. ### Why are the changes needed? Reuse the standard builtin library. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? I manually tested via reenabling `CheckpointSuite.checkpoint gc derived DataFrame` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46726 from HyukjinKwon/SPARK-48370-followup. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../org/apache/spark/sql/SparkSession.scala | 9 +- .../spark/sql/internal/SessionCleaner.scala | 125 ++---------------- 3 files changed, 16 insertions(+), 120 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 5ac07270b22b3..204d3985cf4bf 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3492,7 +3492,7 @@ class Dataset[T] private[sql] ( .getOrElse(throw new RuntimeException("CheckpointCommandResult must be present")) val cachedRemoteRelation = response.getCheckpointCommandResult.getRelation - sparkSession.cleaner.registerCachedRemoteRelationForCleanup(cachedRemoteRelation) + sparkSession.cleaner.register(cachedRemoteRelation) // Update the builder with the values from the result. builder.setCachedRemoteRelation(cachedRemoteRelation) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 91ee0f52e8bd0..19c5a3f14c64f 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -73,11 +73,7 @@ class SparkSession private[sql] ( with Logging { private[this] val allocator = new RootAllocator() - private var shouldStopCleaner = false - private[sql] lazy val cleaner = { - shouldStopCleaner = true - new SessionCleaner(this) - } + private[sql] lazy val cleaner = new SessionCleaner(this) // a unique session ID for this session from client. private[sql] def sessionId: String = client.sessionId @@ -719,9 +715,6 @@ class SparkSession private[sql] ( if (releaseSessionOnClose) { client.releaseSession() } - if (shouldStopCleaner) { - cleaner.stop() - } client.shutdown() allocator.close() SparkSession.onSessionClose(this) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala index 036ea4a84fa97..21e4f4d141a89 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala @@ -17,130 +17,33 @@ package org.apache.spark.sql.internal -import java.lang.ref.{ReferenceQueue, WeakReference} -import java.util.Collections -import java.util.concurrent.ConcurrentHashMap +import java.lang.ref.Cleaner import org.apache.spark.connect.proto import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -/** - * Classes that represent cleaning tasks. 
- */ -private sealed trait CleanupTask -private case class CleanupCachedRemoteRelation(dfID: String) extends CleanupTask - -/** - * A WeakReference associated with a CleanupTask. - * - * When the referent object becomes only weakly reachable, the corresponding - * CleanupTaskWeakReference is automatically added to the given reference queue. - */ -private class CleanupTaskWeakReference( - val task: CleanupTask, - referent: AnyRef, - referenceQueue: ReferenceQueue[AnyRef]) - extends WeakReference(referent, referenceQueue) - -/** - * An asynchronous cleaner for objects. - * - * This maintains a weak reference for each CashRemoteRelation, etc. of interest, to be processed - * when the associated object goes out of scope of the application. Actual cleanup is performed in - * a separate daemon thread. - */ private[sql] class SessionCleaner(session: SparkSession) extends Logging { - - /** - * How often (seconds) to trigger a garbage collection in this JVM. This context cleaner - * triggers cleanups only when weak references are garbage collected. In long-running - * applications with large driver JVMs, where there is little memory pressure on the driver, - * this may happen very occasionally or not at all. Not cleaning at all may lead to executors - * running out of disk space after a while. - */ - private val refQueuePollTimeout: Long = 100 - - /** - * A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they - * have not been handled by the reference queue. - */ - private val referenceBuffer = - Collections.newSetFromMap[CleanupTaskWeakReference](new ConcurrentHashMap) - - private val referenceQueue = new ReferenceQueue[AnyRef] - - private val cleaningThread = new Thread() { override def run(): Unit = keepCleaning() } - - @volatile private var started = false - @volatile private var stopped = false - - /** Start the cleaner. */ - def start(): Unit = { - cleaningThread.setDaemon(true) - cleaningThread.setName("Spark Connect Context Cleaner") - cleaningThread.start() - } - - /** - * Stop the cleaning thread and wait until the thread has finished running its current task. - */ - def stop(): Unit = { - stopped = true - // Interrupt the cleaning thread, but wait until the current task has finished before - // doing so. This guards against the race condition where a cleaning thread may - // potentially clean similarly named variables created by a different SparkSession. - synchronized { - cleaningThread.interrupt() - } - cleaningThread.join() - } + private val cleaner = Cleaner.create() /** Register a CachedRemoteRelation for cleanup when it is garbage collected. */ - def registerCachedRemoteRelationForCleanup(relation: proto.CachedRemoteRelation): Unit = { - registerForCleanup(relation, CleanupCachedRemoteRelation(relation.getRelationId)) - } - - /** Register an object for cleanup. */ - private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = { - if (!started) { - // Lazily starts when the first cleanup is registered. - start() - started = true - } - referenceBuffer.add(new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)) + def register(relation: proto.CachedRemoteRelation): Unit = { + val dfID = relation.getRelationId + cleaner.register(relation, () => doCleanupCachedRemoteRelation(dfID)) } - /** Keep cleaning objects. 
*/ - private def keepCleaning(): Unit = { - while (!stopped && !session.client.channel.isShutdown) { - try { - val reference = Option(referenceQueue.remove(refQueuePollTimeout)) - .map(_.asInstanceOf[CleanupTaskWeakReference]) - // Synchronize here to avoid being interrupted on stop() - synchronized { - reference.foreach { ref => - logDebug("Got cleaning task " + ref.task) - referenceBuffer.remove(ref) - ref.task match { - case CleanupCachedRemoteRelation(dfID) => - doCleanupCachedRemoteRelation(dfID) - } + private[sql] def doCleanupCachedRemoteRelation(dfID: String): Unit = { + try { + if (!session.client.channel.isShutdown) { + session.execute { + session.newCommand { builder => + builder.getRemoveCachedRemoteRelationCommandBuilder + .setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfID).build()) } } - } catch { - case e: Throwable => logError("Error in cleaning thread", e) - } - } - } - - /** Perform CleanupCachedRemoteRelation cleanup. */ - private[spark] def doCleanupCachedRemoteRelation(dfID: String): Unit = { - session.execute { - session.newCommand { builder => - builder.getRemoveCachedRemoteRelationCommandBuilder - .setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfID).build()) } + } catch { + case e: Throwable => logError("Error in cleaning thread", e) } } } From fc1435d14d090b792a0f19372d6b11c7ff026372 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 28 May 2024 08:39:28 +0800 Subject: [PATCH 35/45] [SPARK-48415][PYTHON] Refactor `TypeName` to support parameterized datatypes ### What changes were proposed in this pull request? 1, refactor instance method `TypeName` to support parameterized datatypes 2, remove redundant simpleString/jsonValue methods, since they are type name by default. ### Why are the changes needed? to be consistent with the Scala side ### Does this PR introduce _any_ user-facing change? type names changes: `CharType(10)`: `char` -> `char(10)` `VarcharType(10)`: `varchar` -> `varchar(10)` `DecimalType(10, 2)`: `decimal` -> `decimal(10,2)` `DayTimeIntervalType(DAY, HOUR)`: `daytimeinterval` -> `interval day to hour` `YearMonthIntervalType(YEAR, MONTH)`: `yearmonthinterval` -> `interval year to month` ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #46738 from zhengruifeng/py_type_name. 
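A short sketch of the user-facing difference described above (assuming a build with this refactor): the instance-level `typeName()` now carries the type parameters, while the classmethod form keeps the short name.

```python
from pyspark.sql.types import CharType, DayTimeIntervalType, DecimalType

# Instance typeName() now matches simpleString()/jsonValue() and the Scala side.
print(DecimalType(10, 2).typeName())     # decimal(10,2)
print(CharType(5).typeName())            # char(5)
print(DayTimeIntervalType().typeName())  # interval day to second

# The classmethod form is unchanged.
print(DecimalType.typeName())            # decimal
```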
Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- python/pyspark/sql/tests/test_types.py | 133 +++++++++++++++++++++++++ python/pyspark/sql/types.py | 74 +++++--------- 2 files changed, 160 insertions(+), 47 deletions(-) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 80f2c0fcbc033..cc482b886e3a9 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -81,6 +81,139 @@ class TypesTestsMixin: + def test_class_method_type_name(self): + for dataType, expected in [ + (StringType, "string"), + (CharType, "char"), + (VarcharType, "varchar"), + (BinaryType, "binary"), + (BooleanType, "boolean"), + (DecimalType, "decimal"), + (FloatType, "float"), + (DoubleType, "double"), + (ByteType, "byte"), + (ShortType, "short"), + (IntegerType, "integer"), + (LongType, "long"), + (DateType, "date"), + (TimestampType, "timestamp"), + (TimestampNTZType, "timestamp_ntz"), + (NullType, "void"), + (VariantType, "variant"), + (YearMonthIntervalType, "yearmonthinterval"), + (DayTimeIntervalType, "daytimeinterval"), + (CalendarIntervalType, "interval"), + ]: + self.assertEqual(dataType.typeName(), expected) + + def test_instance_method_type_name(self): + for dataType, expected in [ + (StringType(), "string"), + (CharType(5), "char(5)"), + (VarcharType(10), "varchar(10)"), + (BinaryType(), "binary"), + (BooleanType(), "boolean"), + (DecimalType(), "decimal(10,0)"), + (DecimalType(10, 2), "decimal(10,2)"), + (FloatType(), "float"), + (DoubleType(), "double"), + (ByteType(), "byte"), + (ShortType(), "short"), + (IntegerType(), "integer"), + (LongType(), "long"), + (DateType(), "date"), + (TimestampType(), "timestamp"), + (TimestampNTZType(), "timestamp_ntz"), + (NullType(), "void"), + (VariantType(), "variant"), + (YearMonthIntervalType(), "interval year to month"), + (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval year"), + ( + YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH), + "interval year to month", + ), + (DayTimeIntervalType(), "interval day to second"), + (DayTimeIntervalType(DayTimeIntervalType.DAY), "interval day"), + ( + DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND), + "interval hour to second", + ), + (CalendarIntervalType(), "interval"), + ]: + self.assertEqual(dataType.typeName(), expected) + + def test_simple_string(self): + for dataType, expected in [ + (StringType(), "string"), + (CharType(5), "char(5)"), + (VarcharType(10), "varchar(10)"), + (BinaryType(), "binary"), + (BooleanType(), "boolean"), + (DecimalType(), "decimal(10,0)"), + (DecimalType(10, 2), "decimal(10,2)"), + (FloatType(), "float"), + (DoubleType(), "double"), + (ByteType(), "tinyint"), + (ShortType(), "smallint"), + (IntegerType(), "int"), + (LongType(), "bigint"), + (DateType(), "date"), + (TimestampType(), "timestamp"), + (TimestampNTZType(), "timestamp_ntz"), + (NullType(), "void"), + (VariantType(), "variant"), + (YearMonthIntervalType(), "interval year to month"), + (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval year"), + ( + YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH), + "interval year to month", + ), + (DayTimeIntervalType(), "interval day to second"), + (DayTimeIntervalType(DayTimeIntervalType.DAY), "interval day"), + ( + DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND), + "interval hour to second", + ), + (CalendarIntervalType(), "interval"), + ]: + 
self.assertEqual(dataType.simpleString(), expected) + + def test_json_value(self): + for dataType, expected in [ + (StringType(), "string"), + (CharType(5), "char(5)"), + (VarcharType(10), "varchar(10)"), + (BinaryType(), "binary"), + (BooleanType(), "boolean"), + (DecimalType(), "decimal(10,0)"), + (DecimalType(10, 2), "decimal(10,2)"), + (FloatType(), "float"), + (DoubleType(), "double"), + (ByteType(), "byte"), + (ShortType(), "short"), + (IntegerType(), "integer"), + (LongType(), "long"), + (DateType(), "date"), + (TimestampType(), "timestamp"), + (TimestampNTZType(), "timestamp_ntz"), + (NullType(), "void"), + (VariantType(), "variant"), + (YearMonthIntervalType(), "interval year to month"), + (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval year"), + ( + YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH), + "interval year to month", + ), + (DayTimeIntervalType(), "interval day to second"), + (DayTimeIntervalType(DayTimeIntervalType.DAY), "interval day"), + ( + DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND), + "interval hour to second", + ), + (CalendarIntervalType(), "interval"), + ]: + self.assertEqual(dataType.jsonValue(), expected) + def test_apply_schema_to_row(self): df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""])) df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index b9db59e0a58ac..563c63f5dfb1a 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -115,7 +115,11 @@ def __hash__(self) -> int: return hash(str(self)) def __eq__(self, other: Any) -> bool: - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + if isinstance(other, self.__class__): + self_dict = {k: v for k, v in self.__dict__.items() if k != "typeName"} + other_dict = {k: v for k, v in other.__dict__.items() if k != "typeName"} + return self_dict == other_dict + return False def __ne__(self, other: Any) -> bool: return not self.__eq__(other) @@ -124,6 +128,12 @@ def __ne__(self, other: Any) -> bool: def typeName(cls) -> str: return cls.__name__[:-4].lower() + # The classmethod 'typeName' is not always consistent with the Scala side, e.g. + # DecimalType(10, 2): 'decimal' vs 'decimal(10, 2)' + # This method is used in subclass initializer to replace 'typeName' if they are different. + def _type_name(self) -> str: + return self.__class__.__name__.removesuffix("Type").removesuffix("UDT").lower() + def simpleString(self) -> str: return self.typeName() @@ -215,24 +225,6 @@ def _data_type_build_formatted_string( if isinstance(dataType, (ArrayType, StructType, MapType)): dataType._build_formatted_string(prefix, stringConcat, maxDepth - 1) - # The method typeName() is not always the same as the Scala side. - # Add this helper method to make TreeString() compatible with Scala side. 
- @classmethod - def _get_jvm_type_name(cls, dataType: "DataType") -> str: - if isinstance( - dataType, - ( - DecimalType, - CharType, - VarcharType, - DayTimeIntervalType, - YearMonthIntervalType, - ), - ): - return dataType.simpleString() - else: - return dataType.typeName() - # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle @@ -294,6 +286,7 @@ class StringType(AtomicType): providers = [providerSpark, providerICU] def __init__(self, collation: Optional[str] = None): + self.typeName = self._type_name # type: ignore[method-assign] self.collationId = 0 if collation is None else self.collationNameToId(collation) @classmethod @@ -315,7 +308,7 @@ def collationProvider(cls, collationName: str) -> str: return StringType.providerSpark return StringType.providerICU - def simpleString(self) -> str: + def _type_name(self) -> str: if self.isUTF8BinaryCollation(): return "string" @@ -348,12 +341,10 @@ class CharType(AtomicType): """ def __init__(self, length: int): + self.typeName = self._type_name # type: ignore[method-assign] self.length = length - def simpleString(self) -> str: - return "char(%d)" % (self.length) - - def jsonValue(self) -> str: + def _type_name(self) -> str: return "char(%d)" % (self.length) def __repr__(self) -> str: @@ -370,12 +361,10 @@ class VarcharType(AtomicType): """ def __init__(self, length: int): + self.typeName = self._type_name # type: ignore[method-assign] self.length = length - def simpleString(self) -> str: - return "varchar(%d)" % (self.length) - - def jsonValue(self) -> str: + def _type_name(self) -> str: return "varchar(%d)" % (self.length) def __repr__(self) -> str: @@ -474,14 +463,12 @@ class DecimalType(FractionalType): """ def __init__(self, precision: int = 10, scale: int = 0): + self.typeName = self._type_name # type: ignore[method-assign] self.precision = precision self.scale = scale self.hasPrecisionInfo = True # this is a public API - def simpleString(self) -> str: - return "decimal(%d,%d)" % (self.precision, self.scale) - - def jsonValue(self) -> str: + def _type_name(self) -> str: return "decimal(%d,%d)" % (self.precision, self.scale) def __repr__(self) -> str: @@ -556,6 +543,7 @@ class DayTimeIntervalType(AnsiIntervalType): _inverted_fields = dict(zip(_fields.values(), _fields.keys())) def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None): + self.typeName = self._type_name # type: ignore[method-assign] if startField is None and endField is None: # Default matched to scala side. startField = DayTimeIntervalType.DAY @@ -572,7 +560,7 @@ def __init__(self, startField: Optional[int] = None, endField: Optional[int] = N self.startField = startField self.endField = endField - def _str_repr(self) -> str: + def _type_name(self) -> str: fields = DayTimeIntervalType._fields start_field_name = fields[self.startField] end_field_name = fields[self.endField] @@ -581,10 +569,6 @@ def _str_repr(self) -> str: else: return "interval %s to %s" % (start_field_name, end_field_name) - simpleString = _str_repr - - jsonValue = _str_repr - def __repr__(self) -> str: return "%s(%d, %d)" % (type(self).__name__, self.startField, self.endField) @@ -614,6 +598,7 @@ class YearMonthIntervalType(AnsiIntervalType): _inverted_fields = dict(zip(_fields.values(), _fields.keys())) def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None): + self.typeName = self._type_name # type: ignore[method-assign] if startField is None and endField is None: # Default matched to scala side. 
startField = YearMonthIntervalType.YEAR @@ -630,7 +615,7 @@ def __init__(self, startField: Optional[int] = None, endField: Optional[int] = N self.startField = startField self.endField = endField - def _str_repr(self) -> str: + def _type_name(self) -> str: fields = YearMonthIntervalType._fields start_field_name = fields[self.startField] end_field_name = fields[self.endField] @@ -639,10 +624,6 @@ def _str_repr(self) -> str: else: return "interval %s to %s" % (start_field_name, end_field_name) - simpleString = _str_repr - - jsonValue = _str_repr - def __repr__(self) -> str: return "%s(%d, %d)" % (type(self).__name__, self.startField, self.endField) @@ -776,7 +757,7 @@ def _build_formatted_string( ) -> None: if maxDepth > 0: stringConcat.append( - f"{prefix}-- element: {DataType._get_jvm_type_name(self.elementType)} " + f"{prefix}-- element: {self.elementType.typeName()} " + f"(containsNull = {str(self.containsNull).lower()})\n" ) DataType._data_type_build_formatted_string( @@ -924,12 +905,12 @@ def _build_formatted_string( maxDepth: int = JVM_INT_MAX, ) -> None: if maxDepth > 0: - stringConcat.append(f"{prefix}-- key: {DataType._get_jvm_type_name(self.keyType)}\n") + stringConcat.append(f"{prefix}-- key: {self.keyType.typeName()}\n") DataType._data_type_build_formatted_string( self.keyType, f"{prefix} |", stringConcat, maxDepth ) stringConcat.append( - f"{prefix}-- value: {DataType._get_jvm_type_name(self.valueType)} " + f"{prefix}-- value: {self.valueType.typeName()} " + f"(valueContainsNull = {str(self.valueContainsNull).lower()})\n" ) DataType._data_type_build_formatted_string( @@ -1092,8 +1073,7 @@ def _build_formatted_string( ) -> None: if maxDepth > 0: stringConcat.append( - f"{prefix}-- {escape_meta_characters(self.name)}: " - + f"{DataType._get_jvm_type_name(self.dataType)} " + f"{prefix}-- {escape_meta_characters(self.name)}: {self.dataType.typeName()} " + f"(nullable = {str(self.nullable).lower()})\n" ) DataType._data_type_build_formatted_string( From d863503e8737937fc90c68583a3762fa67f53401 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 28 May 2024 13:37:25 +0900 Subject: [PATCH 36/45] [SPARK-48434][PYTHON][CONNECT] Make `printSchema` use the cached schema ### What changes were proposed in this pull request? Make `printSchema` use the cached schema ### Why are the changes needed? to avoid extra RPCs ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? added tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #46764 from zhengruifeng/connect_print_schema. 
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/dataframe.py | 5 +++- .../sql/tests/connect/test_connect_basic.py | 26 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 62c73da374bc9..354cf60c20144 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -1811,7 +1811,10 @@ def _tree_string(self, level: Optional[int] = None) -> str: return result def printSchema(self, level: Optional[int] = None) -> None: - print(self._tree_string(level)) + if level: + print(self.schema.treeString(level)) + else: + print(self.schema.treeString()) def inputFiles(self) -> List[str]: query = self._plan.to_proto(self._session.client) diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 0648b5ce9925c..eb5cb18d11e63 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -20,6 +20,8 @@ import unittest import shutil import tempfile +import io +from contextlib import redirect_stdout from pyspark.util import is_remote_only from pyspark.errors import PySparkTypeError, PySparkValueError @@ -352,6 +354,24 @@ def test_simple_explain_string(self): result = df._explain_string() self.assertGreater(len(result), 0) + def _check_print_schema(self, query: str): + with io.StringIO() as buf, redirect_stdout(buf): + self.spark.sql(query).printSchema() + print1 = buf.getvalue() + with io.StringIO() as buf, redirect_stdout(buf): + self.connect.sql(query).printSchema() + print2 = buf.getvalue() + self.assertEqual(print1, print2, query) + + for level in [-1, 0, 1, 2, 3, 4]: + with io.StringIO() as buf, redirect_stdout(buf): + self.spark.sql(query).printSchema(level) + print1 = buf.getvalue() + with io.StringIO() as buf, redirect_stdout(buf): + self.connect.sql(query).printSchema(level) + print2 = buf.getvalue() + self.assertEqual(print1, print2, query) + def test_schema(self): schema = self.connect.read.table(self.tbl_name).schema self.assertEqual( @@ -373,6 +393,7 @@ def test_schema(self): self.spark.sql(query).schema, self.connect.sql(query).schema, ) + self._check_print_schema(query) # test TimestampType, DateType query = """ @@ -386,6 +407,7 @@ def test_schema(self): self.spark.sql(query).schema, self.connect.sql(query).schema, ) + self._check_print_schema(query) # test DayTimeIntervalType query = """ SELECT INTERVAL '100 10:30' DAY TO MINUTE AS interval """ @@ -393,6 +415,7 @@ def test_schema(self): self.spark.sql(query).schema, self.connect.sql(query).schema, ) + self._check_print_schema(query) # test MapType query = """ @@ -406,6 +429,7 @@ def test_schema(self): self.spark.sql(query).schema, self.connect.sql(query).schema, ) + self._check_print_schema(query) # test ArrayType query = """ @@ -419,6 +443,7 @@ def test_schema(self): self.spark.sql(query).schema, self.connect.sql(query).schema, ) + self._check_print_schema(query) # test StructType query = """ @@ -432,6 +457,7 @@ def test_schema(self): self.spark.sql(query).schema, self.connect.sql(query).schema, ) + self._check_print_schema(query) def test_to(self): # SPARK-41464: test DataFrame.to() From a88cc1ad9319bd0f4a14e2d6094865229449c8cb Mon Sep 17 00:00:00 2001 From: panbingkun Date: Tue, 28 May 2024 13:09:39 +0800 Subject: [PATCH 37/45] [SPARK-48420][BUILD] Upgrade netty to `4.1.110.Final` ### What changes were proposed 
in this pull request? The pr aims to upgrade `netty` from `4.1.109.Final` to `4.1.110.Final`. ### Why are the changes needed? - https://netty.io/news/2024/05/22/4-1-110-Final.html This version has brought some bug fixes and improvements, such as: Fix Zstd throws Exception on read-only volumes (https://github.com/netty/netty/pull/13982) Add unix domain socket transport in netty 4.x via JDK16+ ([#13965](https://github.com/netty/netty/pull/13965)) Backport #13075: Add the AdaptivePoolingAllocator ([#13976](https://github.com/netty/netty/pull/13976)) Add no-value key handling only for form body ([#13998](https://github.com/netty/netty/pull/13998)) Add support for specifying SecureRandom in SSLContext initialization ([#14058](https://github.com/netty/netty/pull/14058)) - https://github.com/netty/netty/issues?q=milestone%3A4.1.110.Final+is%3Aclosed ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46744 from panbingkun/SPARK-48420. Authored-by: panbingkun Signed-off-by: yangjie01 --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 38 +++++++++++++-------------- pom.xml | 2 +- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 10d812c9fd8a4..e854bd0e804a3 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -197,31 +197,31 @@ metrics-jmx/4.2.25//metrics-jmx-4.2.25.jar metrics-json/4.2.25//metrics-json-4.2.25.jar metrics-jvm/4.2.25//metrics-jvm-4.2.25.jar minlog/1.3.0//minlog-1.3.0.jar -netty-all/4.1.109.Final//netty-all-4.1.109.Final.jar -netty-buffer/4.1.109.Final//netty-buffer-4.1.109.Final.jar -netty-codec-http/4.1.109.Final//netty-codec-http-4.1.109.Final.jar -netty-codec-http2/4.1.109.Final//netty-codec-http2-4.1.109.Final.jar -netty-codec-socks/4.1.109.Final//netty-codec-socks-4.1.109.Final.jar -netty-codec/4.1.109.Final//netty-codec-4.1.109.Final.jar -netty-common/4.1.109.Final//netty-common-4.1.109.Final.jar -netty-handler-proxy/4.1.109.Final//netty-handler-proxy-4.1.109.Final.jar -netty-handler/4.1.109.Final//netty-handler-4.1.109.Final.jar -netty-resolver/4.1.109.Final//netty-resolver-4.1.109.Final.jar +netty-all/4.1.110.Final//netty-all-4.1.110.Final.jar +netty-buffer/4.1.110.Final//netty-buffer-4.1.110.Final.jar +netty-codec-http/4.1.110.Final//netty-codec-http-4.1.110.Final.jar +netty-codec-http2/4.1.110.Final//netty-codec-http2-4.1.110.Final.jar +netty-codec-socks/4.1.110.Final//netty-codec-socks-4.1.110.Final.jar +netty-codec/4.1.110.Final//netty-codec-4.1.110.Final.jar +netty-common/4.1.110.Final//netty-common-4.1.110.Final.jar +netty-handler-proxy/4.1.110.Final//netty-handler-proxy-4.1.110.Final.jar +netty-handler/4.1.110.Final//netty-handler-4.1.110.Final.jar +netty-resolver/4.1.110.Final//netty-resolver-4.1.110.Final.jar netty-tcnative-boringssl-static/2.0.65.Final/linux-aarch_64/netty-tcnative-boringssl-static-2.0.65.Final-linux-aarch_64.jar netty-tcnative-boringssl-static/2.0.65.Final/linux-x86_64/netty-tcnative-boringssl-static-2.0.65.Final-linux-x86_64.jar netty-tcnative-boringssl-static/2.0.65.Final/osx-aarch_64/netty-tcnative-boringssl-static-2.0.65.Final-osx-aarch_64.jar netty-tcnative-boringssl-static/2.0.65.Final/osx-x86_64/netty-tcnative-boringssl-static-2.0.65.Final-osx-x86_64.jar netty-tcnative-boringssl-static/2.0.65.Final/windows-x86_64/netty-tcnative-boringssl-static-2.0.65.Final-windows-x86_64.jar 
netty-tcnative-classes/2.0.65.Final//netty-tcnative-classes-2.0.65.Final.jar -netty-transport-classes-epoll/4.1.109.Final//netty-transport-classes-epoll-4.1.109.Final.jar -netty-transport-classes-kqueue/4.1.109.Final//netty-transport-classes-kqueue-4.1.109.Final.jar -netty-transport-native-epoll/4.1.109.Final/linux-aarch_64/netty-transport-native-epoll-4.1.109.Final-linux-aarch_64.jar -netty-transport-native-epoll/4.1.109.Final/linux-riscv64/netty-transport-native-epoll-4.1.109.Final-linux-riscv64.jar -netty-transport-native-epoll/4.1.109.Final/linux-x86_64/netty-transport-native-epoll-4.1.109.Final-linux-x86_64.jar -netty-transport-native-kqueue/4.1.109.Final/osx-aarch_64/netty-transport-native-kqueue-4.1.109.Final-osx-aarch_64.jar -netty-transport-native-kqueue/4.1.109.Final/osx-x86_64/netty-transport-native-kqueue-4.1.109.Final-osx-x86_64.jar -netty-transport-native-unix-common/4.1.109.Final//netty-transport-native-unix-common-4.1.109.Final.jar -netty-transport/4.1.109.Final//netty-transport-4.1.109.Final.jar +netty-transport-classes-epoll/4.1.110.Final//netty-transport-classes-epoll-4.1.110.Final.jar +netty-transport-classes-kqueue/4.1.110.Final//netty-transport-classes-kqueue-4.1.110.Final.jar +netty-transport-native-epoll/4.1.110.Final/linux-aarch_64/netty-transport-native-epoll-4.1.110.Final-linux-aarch_64.jar +netty-transport-native-epoll/4.1.110.Final/linux-riscv64/netty-transport-native-epoll-4.1.110.Final-linux-riscv64.jar +netty-transport-native-epoll/4.1.110.Final/linux-x86_64/netty-transport-native-epoll-4.1.110.Final-linux-x86_64.jar +netty-transport-native-kqueue/4.1.110.Final/osx-aarch_64/netty-transport-native-kqueue-4.1.110.Final-osx-aarch_64.jar +netty-transport-native-kqueue/4.1.110.Final/osx-x86_64/netty-transport-native-kqueue-4.1.110.Final-osx-x86_64.jar +netty-transport-native-unix-common/4.1.110.Final//netty-transport-native-unix-common-4.1.110.Final.jar +netty-transport/4.1.110.Final//netty-transport-4.1.110.Final.jar objenesis/3.3//objenesis-3.3.jar okhttp/3.12.12//okhttp-3.12.12.jar okio/1.15.0//okio-1.15.0.jar diff --git a/pom.xml b/pom.xml index 5b088db7b20b5..36d12b9b52814 100644 --- a/pom.xml +++ b/pom.xml @@ -214,7 +214,7 @@ 1.78 1.13.0 6.0.0 - 4.1.109.Final + 4.1.110.Final 2.0.65.Final 72.1 5.9.3 From af1ac1edc2a96c9aba949e3100ddae37b6f0e5b2 Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Mon, 27 May 2024 22:40:13 -0700 Subject: [PATCH 38/45] [SPARK-41049][SQL][FOLLOW-UP] Mark map related expressions as stateful expressions ### What changes were proposed in this pull request? MapConcat contains a state so it is stateful: ``` private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) ``` Similarly `MapFromEntries, CreateMap, MapFromArrays, StringToMap, and TransformKeys` need the same change. ### Why are the changes needed? Stateful expression should be marked as stateful. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? N/A ### Was this patch authored or co-authored using generative AI tooling? No Closes #46721 from amaliujia/statefulexpr. 
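For context, the user-visible pattern the new `DataFrameSuite` case exercises looks roughly like this from PySpark (an illustrative sketch only; the actual fix is the `stateful` overrides in the Catalyst expressions above). Each evaluator needs its own copy of the expression's internal map builder when the same map expression appears several times in one projection.

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

spark = SparkSession.builder.getOrCreate()
df = spark.range(3)

# The same map-producing expression reused across columns, plus a
# CodegenFallback consumer (to_csv), mirrors the scenario in the new test.
m = sf.create_map(sf.lit("key"), df.id.cast("string"))
df.select(
    m.alias("m1"),
    m.alias("m2"),
    sf.to_csv(sf.struct(m.alias("a"))).alias("csv"),
).show()
```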
Authored-by: Rui Wang Signed-off-by: Wenchen Fan --- .../catalyst/expressions/collectionOperations.scala | 3 +++ .../sql/catalyst/expressions/complexTypeCreator.scala | 6 ++++++ .../catalyst/expressions/higherOrderFunctions.scala | 2 ++ .../scala/org/apache/spark/sql/DataFrameSuite.scala | 10 +++++++++- 4 files changed, 20 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 632e2f3d3e973..ea117f876550e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -713,6 +713,7 @@ case class MapConcat(children: Seq[Expression]) } } + override def stateful: Boolean = true override def nullable: Boolean = children.exists(_.nullable) private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) @@ -828,6 +829,8 @@ case class MapFromEntries(child: Expression) override def nullable: Boolean = child.nullable || nullEntries + override def stateful: Boolean = true + @transient override lazy val dataType: MapType = dataTypeDetails.get._1 override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 4c0d005340606..167c02c0bafc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -245,6 +245,8 @@ case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean) private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) + override def stateful: Boolean = true + override def eval(input: InternalRow): Any = { var i = 0 while (i < keys.length) { @@ -320,6 +322,8 @@ case class MapFromArrays(left: Expression, right: Expression) valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull) } + override def stateful: Boolean = true + private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) override def nullSafeEval(keyArray: Any, valueArray: Any): Any = { @@ -568,6 +572,8 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E this(child, Literal(","), Literal(":")) } + override def stateful: Boolean = true + override def first: Expression = text override def second: Expression = pairDelim override def third: Expression = keyValueDelim diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 896f3e9774f37..80bcf156133ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -920,6 +920,8 @@ case class TransformKeys( override def dataType: MapType = MapType(function.dataType, valueType, valueContainsNull) + override def stateful: Boolean = true + override def checkInputDataTypes(): TypeCheckResult = { 
TypeUtils.checkForMapKeyType(function.dataType) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f11ad230ec160..760ee80260808 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery, Uuid} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, CreateMap, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery, Uuid} import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LocalRelation, LogicalPlan, OneRowRelation} @@ -2504,6 +2504,14 @@ class DataFrameSuite extends QueryTest assert(row.getInt(0).toString == row.getString(2)) assert(row.getInt(0).toString == row.getString(3)) } + + val v3 = Column(CreateMap(Seq(Literal("key"), Literal("value")))) + val v4 = to_csv(struct(v3.as("a"))) // to_csv is CodegenFallback + df.select(v3, v3, v4, v4).collect().foreach { row => + assert(row.getMap(0).toString() == row.getMap(1).toString()) + assert(row.getString(2) == s"{key -> ${row.getMap(0).get("key").get}}") + assert(row.getString(3) == s"{key -> ${row.getMap(0).get("key").get}}") + } } test("SPARK-45216: Non-deterministic functions with seed") { From df5ef91f44bacf60b39bccecc049eb8cb5714b39 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 28 May 2024 15:43:25 +0900 Subject: [PATCH 39/45] [SPARK-48425][INFRA][FOLLOWUP] copy spark connect release tarballs ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/46751 . Now the spark connect tarball is named `pyspark_connect...`, and we need to update the release scripts as well. ### Why are the changes needed? fix release scripts ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? manual ### Was this patch authored or co-authored using generative AI tooling? no Closes #46769 from cloud-fan/script. 
Authored-by: Wenchen Fan Signed-off-by: Hyukjin Kwon --- dev/create-release/release-build.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index cd0220db75b1a..cf9c146e081a7 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -382,9 +382,9 @@ if [[ "$1" == "package" ]]; then mkdir -p "svn-spark/${DEST_DIR_NAME}-bin" echo "Copying release tarballs" - cp spark-* "svn-spark/${DEST_DIR_NAME}-bin/" - cp pyspark-* "svn-spark/${DEST_DIR_NAME}-bin/" - cp SparkR_* "svn-spark/${DEST_DIR_NAME}-bin/" + cp spark* "svn-spark/${DEST_DIR_NAME}-bin/" + cp pyspark* "svn-spark/${DEST_DIR_NAME}-bin/" + cp SparkR* "svn-spark/${DEST_DIR_NAME}-bin/" svn add "svn-spark/${DEST_DIR_NAME}-bin" cd svn-spark From f164e4ae53ca333c49baa65b9f4d332a0a7ec4c8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 28 May 2024 16:48:39 +0900 Subject: [PATCH 40/45] [SPARK-48425][INFRA][FOLLOWUP] Do not copy the base spark folder ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/46769 to fix a mistake. The release scripts use a base `spark` folder to build tarballs and it shouldn't be copied with the tarballs. ### Why are the changes needed? fix mistake ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? manual ### Was this patch authored or co-authored using generative AI tooling? no Closes #46771 from cloud-fan/script. Authored-by: Wenchen Fan Signed-off-by: Hyukjin Kwon --- dev/create-release/release-build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index cf9c146e081a7..0435960c93cd0 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -382,7 +382,7 @@ if [[ "$1" == "package" ]]; then mkdir -p "svn-spark/${DEST_DIR_NAME}-bin" echo "Copying release tarballs" - cp spark* "svn-spark/${DEST_DIR_NAME}-bin/" + cp spark-* "svn-spark/${DEST_DIR_NAME}-bin/" cp pyspark* "svn-spark/${DEST_DIR_NAME}-bin/" cp SparkR* "svn-spark/${DEST_DIR_NAME}-bin/" svn add "svn-spark/${DEST_DIR_NAME}-bin" From a78ef738af0233f0e6abff74645f8d4b19c29e28 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 28 May 2024 17:31:05 +0800 Subject: [PATCH 41/45] [SPARK-48168][SQL][FOLLOWUP] Match expression strings of shift operators & functions with user inputs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Followup of SPARK-48168: restores shiftleft/shiftright/shiftrightunsigned in plans if the user used them. ### Why are the changes needed? bugfix ### Does this PR introduce _any_ user-facing change? no, the affected cases are unreleased ### How was this patch tested? modified tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #46767 from yaooqinn/SPARK-48168-FF.
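For example, the expression string now follows whichever spelling the user wrote; a minimal sketch, assuming an existing `spark` session on a build that already includes SPARK-48168 (so the `<<`/`>>`/`>>>` operators are available):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Writing the operator keeps the operator spelling in the resulting column name.
print(spark.sql("SELECT 2 << 1").columns)           # expected: ['(2 << 1)']

# Calling the function keeps the function-call spelling instead of being
# rewritten into the operator form.
print(spark.sql("SELECT shiftleft(2, 1)").columns)  # expected: ['shiftleft(2, 1)']
```

The golden-file updates below follow the same rule: internally generated expressions (for example the `spark_grouping_id` computations) now print as `shiftright(...)` because they are built from the function, not from a user-written `>>`.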
Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../function_shiftleft.explain | 2 +- .../function_shiftright.explain | 2 +- .../function_shiftrightunsigned.explain | 2 +- .../grouping_and_grouping_id.explain | 2 +- .../expressions/mathExpressions.scala | 8 ++-- .../sql/catalyst/parser/AstBuilder.scala | 5 ++- .../sql-functions/sql-expression-schema.md | 12 ++--- .../analyzer-results/group-analytics.sql.out | 10 ++--- .../analyzer-results/grouping_set.sql.out | 6 +-- .../postgreSQL/groupingsets.sql.out | 44 +++++++++---------- .../analyzer-results/postgreSQL/int2.sql.out | 4 +- .../analyzer-results/postgreSQL/int4.sql.out | 4 +- .../analyzer-results/postgreSQL/int8.sql.out | 4 +- .../udf/udf-group-analytics.sql.out | 10 ++--- .../sql-tests/results/postgreSQL/int2.sql.out | 4 +- .../sql-tests/results/postgreSQL/int4.sql.out | 4 +- .../sql-tests/results/postgreSQL/int8.sql.out | 2 +- .../approved-plans-v1_4/q17/explain.txt | 2 +- .../approved-plans-v1_4/q25/explain.txt | 2 +- .../approved-plans-v1_4/q27.sf100/explain.txt | 2 +- .../approved-plans-v1_4/q27/explain.txt | 2 +- .../approved-plans-v1_4/q29/explain.txt | 2 +- .../approved-plans-v1_4/q36.sf100/explain.txt | 2 +- .../approved-plans-v1_4/q36/explain.txt | 2 +- .../approved-plans-v1_4/q39a/explain.txt | 2 +- .../approved-plans-v1_4/q39b/explain.txt | 2 +- .../approved-plans-v1_4/q49/explain.txt | 6 +-- .../approved-plans-v1_4/q5/explain.txt | 2 +- .../approved-plans-v1_4/q64/explain.txt | 4 +- .../approved-plans-v1_4/q70.sf100/explain.txt | 2 +- .../approved-plans-v1_4/q70/explain.txt | 2 +- .../approved-plans-v1_4/q72/explain.txt | 2 +- .../approved-plans-v1_4/q85/explain.txt | 2 +- .../approved-plans-v1_4/q86.sf100/explain.txt | 2 +- .../approved-plans-v1_4/q86/explain.txt | 2 +- .../approved-plans-v2_7/q24.sf100/explain.txt | 2 +- .../approved-plans-v2_7/q49/explain.txt | 6 +-- .../approved-plans-v2_7/q5a/explain.txt | 2 +- .../approved-plans-v2_7/q64/explain.txt | 4 +- .../approved-plans-v2_7/q72/explain.txt | 2 +- 40 files changed, 93 insertions(+), 90 deletions(-) diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftleft.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftleft.explain index 6d5eb29944d52..f89a8be7ceedb 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftleft.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftleft.explain @@ -1,2 +1,2 @@ -Project [(cast(b#0 as int) << 2) AS (b << 2)#0] +Project [shiftleft(cast(b#0 as int), 2) AS shiftleft(b, 2)#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftright.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftright.explain index b1c2c35ac2d0e..b436f52e912b5 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftright.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftright.explain @@ -1,2 +1,2 @@ -Project [(cast(b#0 as int) >> 2) AS (b >> 2)#0] +Project [shiftright(cast(b#0 as int), 2) AS shiftright(b, 2)#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftrightunsigned.explain
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftrightunsigned.explain index 508c78a7f421f..282ad156b3825 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftrightunsigned.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_shiftrightunsigned.explain @@ -1,2 +1,2 @@ -Project [(cast(b#0 as int) >>> 2) AS (b >>> 2)#0] +Project [shiftrightunsigned(cast(b#0 as int), 2) AS shiftrightunsigned(b, 2)#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/grouping_and_grouping_id.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/grouping_and_grouping_id.explain index f46fa38989ed4..3b7d6fb2b7072 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/grouping_and_grouping_id.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/grouping_and_grouping_id.explain @@ -1,4 +1,4 @@ -Aggregate [a#0, b#0, spark_grouping_id#0L], [a#0, b#0, cast(((spark_grouping_id#0L >> 1) & 1) as tinyint) AS grouping(a)#0, cast(((spark_grouping_id#0L >> 0) & 1) as tinyint) AS grouping(b)#0, spark_grouping_id#0L AS grouping_id(a, b)#0L] +Aggregate [a#0, b#0, spark_grouping_id#0L], [a#0, b#0, cast((shiftright(spark_grouping_id#0L, 1) & 1) as tinyint) AS grouping(a)#0, cast((shiftright(spark_grouping_id#0L, 0) & 1) as tinyint) AS grouping(b)#0, spark_grouping_id#0L AS grouping_id(a, b)#0L] +- Expand [[id#0L, a#0, b#0, a#0, b#0, 0], [id#0L, a#0, b#0, a#0, null, 1], [id#0L, a#0, b#0, null, b#0, 2], [id#0L, a#0, b#0, null, null, 3]], [id#0L, a#0, b#0, a#0, b#0, spark_grouping_id#0L] +- Project [id#0L, a#0, b#0, a#0 AS a#0, b#0 AS b#0] +- LocalRelation <empty>, [id#0L, a#0, b#0] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 4bb0c658eacf1..8df46500ddcf0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1283,15 +1283,15 @@ sealed trait BitShiftOperation } override def toString: String = { - getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse(symbol) match { - case alias if alias == symbol => s"($left $symbol $right)" + getTagValue(FunctionRegistry.FUNC_ALIAS) match { + case Some(alias) if alias == symbol => s"($left $symbol $right)" case _ => super.toString } } override def sql: String = { - getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse(symbol) match { - case alias if alias == symbol => s"(${left.sql} $symbol ${right.sql})" + getTagValue(FunctionRegistry.FUNC_ALIAS) match { + case Some(alias) if alias == symbol => s"(${left.sql} $symbol ${right.sql})" case _ => super.sql } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 52c32530f2e92..e2c975433ebdc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -35,6 +35,7 @@ import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.PARTITION_SPECIFICATION import
org.apache.spark.sql.catalyst.{FunctionIdentifier, SQLConfHelper, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FUNC_ALIAS import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, ClusterBySpec} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AnyValue, First, Last} @@ -2200,11 +2201,13 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { val left = expression(ctx.left) val right = expression(ctx.right) val operator = ctx.shiftOperator().getChild(0).asInstanceOf[TerminalNode] - operator.getSymbol.getType match { + val shift = operator.getSymbol.getType match { case SqlBaseParser.SHIFT_LEFT => ShiftLeft(left, right) case SqlBaseParser.SHIFT_RIGHT => ShiftRight(left, right) case SqlBaseParser.SHIFT_RIGHT_UNSIGNED => ShiftRightUnsigned(left, right) } + shift.setTagValue(FUNC_ALIAS, operator.getText) + shift } /** diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index bf46fe91eb903..65b513264598b 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -289,12 +289,12 @@ | org.apache.spark.sql.catalyst.expressions.Sha1 | sha | SELECT sha('Spark') | struct | | org.apache.spark.sql.catalyst.expressions.Sha1 | sha1 | SELECT sha1('Spark') | struct | | org.apache.spark.sql.catalyst.expressions.Sha2 | sha2 | SELECT sha2('Spark', 256) | struct | -| org.apache.spark.sql.catalyst.expressions.ShiftLeft | << | SELECT shiftleft(2, 1) | struct<(2 << 1):int> | -| org.apache.spark.sql.catalyst.expressions.ShiftLeft | shiftleft | SELECT shiftleft(2, 1) | struct<(2 << 1):int> | -| org.apache.spark.sql.catalyst.expressions.ShiftRight | >> | SELECT shiftright(4, 1) | struct<(4 >> 1):int> | -| org.apache.spark.sql.catalyst.expressions.ShiftRight | shiftright | SELECT shiftright(4, 1) | struct<(4 >> 1):int> | -| org.apache.spark.sql.catalyst.expressions.ShiftRightUnsigned | >>> | SELECT shiftrightunsigned(4, 1) | struct<(4 >>> 1):int> | -| org.apache.spark.sql.catalyst.expressions.ShiftRightUnsigned | shiftrightunsigned | SELECT shiftrightunsigned(4, 1) | struct<(4 >>> 1):int> | +| org.apache.spark.sql.catalyst.expressions.ShiftLeft | << | SELECT shiftleft(2, 1) | struct<shiftleft(2, 1):int> | +| org.apache.spark.sql.catalyst.expressions.ShiftLeft | shiftleft | SELECT shiftleft(2, 1) | struct<shiftleft(2, 1):int> | +| org.apache.spark.sql.catalyst.expressions.ShiftRight | >> | SELECT shiftright(4, 1) | struct<shiftright(4, 1):int> | +| org.apache.spark.sql.catalyst.expressions.ShiftRight | shiftright | SELECT shiftright(4, 1) | struct<shiftright(4, 1):int> | +| org.apache.spark.sql.catalyst.expressions.ShiftRightUnsigned | >>> | SELECT shiftrightunsigned(4, 1) | struct<shiftrightunsigned(4, 1):int> | +| org.apache.spark.sql.catalyst.expressions.ShiftRightUnsigned | shiftrightunsigned | SELECT shiftrightunsigned(4, 1) | struct<shiftrightunsigned(4, 1):int> | | org.apache.spark.sql.catalyst.expressions.Shuffle | shuffle | SELECT shuffle(array(1, 20, 3, 5)) | struct> | | org.apache.spark.sql.catalyst.expressions.Signum | sign | SELECT sign(40) | struct | | org.apache.spark.sql.catalyst.expressions.Signum | signum | SELECT signum(40) | struct | diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/group-analytics.sql.out index f8967d7df0b0c..cdb6372bec099 100644 ---
a/sql/core/src/test/resources/sql-tests/analyzer-results/group-analytics.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/group-analytics.sql.out @@ -316,7 +316,7 @@ Sort [course#x ASC NULLS FIRST, sum#xL ASC NULLS FIRST], true SELECT course, year, GROUPING(course), GROUPING(year), GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(course, year) -- !query analysis -Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(course)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(year)#x, spark_grouping_id#xL AS grouping_id(course, year)#xL] +Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(course)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(year)#x, spark_grouping_id#xL AS grouping_id(course, year)#xL] +- Expand [[course#x, year#x, earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL] +- Project [course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x] +- SubqueryAlias coursesales @@ -382,7 +382,7 @@ HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 ORDER BY course, yea -- !query analysis Sort [course#x ASC NULLS FIRST, year#x ASC NULLS FIRST], true +- Project [course#x, year#x] - +- Filter ((cast(cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) as int) = 1) AND (spark_grouping_id#xL > cast(0 as bigint))) + +- Filter ((cast(cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) as int) = 1) AND (spark_grouping_id#xL > cast(0 as bigint))) +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, spark_grouping_id#xL, spark_grouping_id#xL] +- Expand [[course#x, year#x, earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL] +- Project [course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x] @@ -435,8 +435,8 @@ SELECT course, year, GROUPING(course), GROUPING(year) FROM courseSales GROUP BY ORDER BY GROUPING(course), GROUPING(year), course, year -- !query analysis Project [course#x, year#x, grouping(course)#x, grouping(year)#x] -+- Sort [cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) ASC NULLS FIRST, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) ASC NULLS FIRST, course#x ASC NULLS FIRST, year#x ASC NULLS FIRST], true - +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(course)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(year)#x, spark_grouping_id#xL, spark_grouping_id#xL] ++- Sort [cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) ASC NULLS FIRST, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) ASC NULLS FIRST, course#x ASC NULLS FIRST, year#x ASC NULLS FIRST], true + +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(course)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(year)#x, spark_grouping_id#xL, spark_grouping_id#xL] +- Expand [[course#x, year#x, 
earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL] +- Project [course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x] +- SubqueryAlias coursesales @@ -452,7 +452,7 @@ SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(co ORDER BY GROUPING(course), GROUPING(year), course, year -- !query analysis Project [course#x, year#x, grouping_id(course, year)#xL] -+- Sort [cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) ASC NULLS FIRST, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) ASC NULLS FIRST, course#x ASC NULLS FIRST, year#x ASC NULLS FIRST], true ++- Sort [cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) ASC NULLS FIRST, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) ASC NULLS FIRST, course#x ASC NULLS FIRST, year#x ASC NULLS FIRST], true +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, spark_grouping_id#xL AS grouping_id(course, year)#xL, spark_grouping_id#xL, spark_grouping_id#xL] +- Expand [[course#x, year#x, earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL] +- Project [course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out index cbbcb77325348..b73ee16c8bdef 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out @@ -72,7 +72,7 @@ Aggregate [c1#x, spark_grouping_id#xL], [c1#x, sum(c2#x) AS sum(c2)#xL] -- !query SELECT c1, sum(c2), grouping(c1) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1) -- !query analysis -Aggregate [c1#x, spark_grouping_id#xL], [c1#x, sum(c2#x) AS sum(c2)#xL, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(c1)#x] +Aggregate [c1#x, spark_grouping_id#xL], [c1#x, sum(c2#x) AS sum(c2)#xL, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(c1)#x] +- Expand [[c1#x, c2#x, c3#x, c1#x, 0]], [c1#x, c2#x, c3#x, c1#x, spark_grouping_id#xL] +- Project [c1#x, c2#x, c3#x, c1#x AS c1#x] +- SubqueryAlias t @@ -98,7 +98,7 @@ Filter (grouping__id#xL > cast(1 as bigint)) -- !query SELECT grouping(c1) FROM (VALUES ('x', 'a', 10), ('y', 'b', 20)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1,c2) -- !query analysis -Aggregate [c1#x, c2#x, spark_grouping_id#xL], [cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(c1)#x] +Aggregate [c1#x, c2#x, spark_grouping_id#xL], [cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(c1)#x] +- Expand [[c1#x, c2#x, c3#x, c1#x, null, 1], [c1#x, c2#x, c3#x, null, c2#x, 2]], [c1#x, c2#x, c3#x, c1#x, c2#x, spark_grouping_id#xL] +- Project [c1#x, c2#x, c3#x, c1#x AS c1#x, c2#x AS c2#x] +- SubqueryAlias t @@ -223,7 +223,7 @@ Aggregate [k1#x, k2#x, spark_grouping_id#xL, _gen_grouping_pos#x], [spark_groupi -- !query SELECT grouping(k1), k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY GROUPING SETS ((k1),(k1,k2),(k2,k1)) -- !query analysis -Aggregate [k1#x, k2#x, 
spark_grouping_id#xL, _gen_grouping_pos#x], [cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(k1)#x, k1#x, k2#x, avg(v#x) AS avg(v)#x] +Aggregate [k1#x, k2#x, spark_grouping_id#xL, _gen_grouping_pos#x], [cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(k1)#x, k1#x, k2#x, avg(v#x) AS avg(v)#x] +- Expand [[k1#x, k2#x, v#x, k1#x, null, 1, 0], [k1#x, k2#x, v#x, k1#x, k2#x, 0, 1], [k1#x, k2#x, v#x, k1#x, k2#x, 0, 2]], [k1#x, k2#x, v#x, k1#x, k2#x, spark_grouping_id#xL, _gen_grouping_pos#x] +- Project [k1#x, k2#x, v#x, k1#x AS k1#x, k2#x AS k2#x] +- SubqueryAlias t diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/groupingsets.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/groupingsets.sql.out index d2a25fabe2059..27e9707425833 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/groupingsets.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/groupingsets.sql.out @@ -82,7 +82,7 @@ CreateDataSourceTableCommand `spark_catalog`.`default`.`gstest_empty`, false select a, b, grouping(a), grouping(b), sum(v), count(*), max(v) from gstest1 group by rollup (a,b) -- !query analysis -Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] +Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] +- Expand [[a#x, b#x, v#x, a#x, b#x, 0], [a#x, b#x, v#x, a#x, null, 1], [a#x, b#x, v#x, null, null, 3]], [a#x, b#x, v#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x AS a#x, b#x AS b#x] +- SubqueryAlias gstest1 @@ -96,7 +96,7 @@ select a, b, grouping(a), grouping(b), sum(v), count(*), max(v) from gstest1 group by rollup (a,b) order by a,b -- !query analysis Sort [a#x ASC NULLS FIRST, b#x ASC NULLS FIRST], true -+- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] ++- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] +- Expand [[a#x, b#x, v#x, a#x, b#x, 0], [a#x, b#x, v#x, a#x, null, 1], [a#x, b#x, v#x, null, null, 3]], [a#x, b#x, v#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x AS a#x, b#x AS b#x] +- SubqueryAlias gstest1 @@ -110,7 +110,7 @@ select a, b, grouping(a), grouping(b), sum(v), count(*), max(v) from gstest1 group by rollup (a,b) order by b desc, a -- !query analysis Sort [b#x DESC NULLS LAST, a#x ASC NULLS FIRST], true -+- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] ++- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, 
cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] +- Expand [[a#x, b#x, v#x, a#x, b#x, 0], [a#x, b#x, v#x, a#x, null, 1], [a#x, b#x, v#x, null, null, 3]], [a#x, b#x, v#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x AS a#x, b#x AS b#x] +- SubqueryAlias gstest1 @@ -124,7 +124,7 @@ select a, b, grouping(a), grouping(b), sum(v), count(*), max(v) from gstest1 group by rollup (a,b) order by coalesce(a,0)+coalesce(b,0), a -- !query analysis Sort [(coalesce(a#x, 0) + coalesce(b#x, 0)) ASC NULLS FIRST, a#x ASC NULLS FIRST], true -+- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] ++- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] +- Expand [[a#x, b#x, v#x, a#x, b#x, 0], [a#x, b#x, v#x, a#x, null, 1], [a#x, b#x, v#x, null, null, 3]], [a#x, b#x, v#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x AS a#x, b#x AS b#x] +- SubqueryAlias gstest1 @@ -209,7 +209,7 @@ select t1.a, t2.b, grouping(t1.a), grouping(t2.b), sum(t1.v), max(t2.a) from gstest1 t1, gstest2 t2 group by grouping sets ((t1.a, t2.b), ()) -- !query analysis -Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, max(a#x) AS max(a)#x] +Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, max(a#x) AS max(a)#x] +- Expand [[a#x, b#x, v#x, a#x, b#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x, b#x, 0], [a#x, b#x, v#x, a#x, b#x, c#x, d#x, e#x, f#x, g#x, h#x, null, null, 3]], [a#x, b#x, v#x, a#x, b#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x, b#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x AS a#x, b#x AS b#x] +- Join Inner @@ -228,7 +228,7 @@ select t1.a, t2.b, grouping(t1.a), grouping(t2.b), sum(t1.v), max(t2.a) from gstest1 t1 join gstest2 t2 on (t1.a=t2.a) group by grouping sets ((t1.a, t2.b), ()) -- !query analysis -Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, max(a#x) AS max(a)#x] +Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, max(a#x) AS max(a)#x] +- Expand [[a#x, b#x, v#x, a#x, b#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x, b#x, 0], [a#x, b#x, v#x, a#x, b#x, c#x, d#x, e#x, f#x, g#x, h#x, null, null, 3]], [a#x, b#x, v#x, a#x, b#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x, b#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x AS a#x, b#x AS b#x] +- Join 
Inner, (a#x = a#x) @@ -247,7 +247,7 @@ select a, b, grouping(a), grouping(b), sum(t1.v), max(t2.c) from gstest1 t1 join gstest2 t2 using (a,b) group by grouping sets ((a, b), ()) -- !query analysis -Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, max(c#x) AS max(c)#x] +Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, max(c#x) AS max(c)#x] +- Expand [[a#x, b#x, v#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x, b#x, 0], [a#x, b#x, v#x, c#x, d#x, e#x, f#x, g#x, h#x, null, null, 3]], [a#x, b#x, v#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, c#x, d#x, e#x, f#x, g#x, h#x, a#x AS a#x, b#x AS b#x] +- Project [a#x, b#x, v#x, c#x, d#x, e#x, f#x, g#x, h#x] @@ -402,8 +402,8 @@ order by 2,1 -- !query analysis Sort [grouping(ten)#x ASC NULLS FIRST, ten#x ASC NULLS FIRST], true +- Project [ten#x, grouping(ten)#x] - +- Filter (cast(cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) as int) >= 0) - +- Aggregate [ten#x, spark_grouping_id#xL], [ten#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(ten)#x, spark_grouping_id#xL] + +- Filter (cast(cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) as int) >= 0) + +- Aggregate [ten#x, spark_grouping_id#xL], [ten#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(ten)#x, spark_grouping_id#xL] +- Expand [[unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x, 0]], [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x, spark_grouping_id#xL] +- Project [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x AS ten#x] +- SubqueryAlias spark_catalog.default.onek @@ -417,8 +417,8 @@ order by 2,1 -- !query analysis Sort [grouping(ten)#x ASC NULLS FIRST, ten#x ASC NULLS FIRST], true +- Project [ten#x, grouping(ten)#x] - +- Filter (cast(cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) as int) > 0) - +- Aggregate [ten#x, four#x, spark_grouping_id#xL], [ten#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(ten)#x, spark_grouping_id#xL] + +- Filter (cast(cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) as int) > 0) + +- Aggregate [ten#x, four#x, spark_grouping_id#xL], [ten#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(ten)#x, spark_grouping_id#xL] +- Expand [[unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x, null, 1], [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, null, four#x, 2]], [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x, four#x, spark_grouping_id#xL] +- Project 
[unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x AS ten#x, four#x AS four#x] +- SubqueryAlias spark_catalog.default.onek @@ -432,8 +432,8 @@ order by 2,1 -- !query analysis Sort [grouping(ten)#x ASC NULLS FIRST, ten#x ASC NULLS FIRST], true +- Project [ten#x, grouping(ten)#x] - +- Filter (cast(cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) as int) > 0) - +- Aggregate [ten#x, spark_grouping_id#xL], [ten#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(ten)#x, spark_grouping_id#xL] + +- Filter (cast(cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) as int) > 0) + +- Aggregate [ten#x, spark_grouping_id#xL], [ten#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(ten)#x, spark_grouping_id#xL] +- Expand [[unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x, 0], [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, null, 1]], [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x, spark_grouping_id#xL] +- Project [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x AS ten#x] +- SubqueryAlias spark_catalog.default.onek @@ -447,8 +447,8 @@ order by 2,1 -- !query analysis Sort [grouping(ten)#x ASC NULLS FIRST, ten#x ASC NULLS FIRST], true +- Project [ten#x, grouping(ten)#x] - +- Filter (cast(cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) as int) > 0) - +- Aggregate [ten#x, spark_grouping_id#xL], [ten#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(ten)#x, spark_grouping_id#xL] + +- Filter (cast(cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) as int) > 0) + +- Aggregate [ten#x, spark_grouping_id#xL], [ten#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(ten)#x, spark_grouping_id#xL] +- Expand [[unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x, 0], [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, null, 1]], [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x, spark_grouping_id#xL] +- Project [unique1#x, unique2#x, two#x, four#x, ten#x, twenty#x, hundred#x, thousand#x, twothousand#x, fivethous#x, tenthous#x, odd#x, even#x, stringu1#x, stringu2#x, string4#x, ten#x AS ten#x] +- SubqueryAlias spark_catalog.default.onek @@ -482,7 +482,7 @@ select a, b, grouping(a), grouping(b), sum(v), count(*), max(v) from gstest1 group by grouping sets ((a),(b)) order by 3,4,1,2 /* 3,1,2 */ -- !query analysis Sort [grouping(a)#x ASC NULLS FIRST, grouping(b)#x ASC NULLS FIRST, a#x ASC NULLS FIRST, b#x ASC NULLS FIRST], true -+- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, 
cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] ++- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] +- Expand [[a#x, b#x, v#x, a#x, null, 1], [a#x, b#x, v#x, null, b#x, 2]], [a#x, b#x, v#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x AS a#x, b#x AS b#x] +- SubqueryAlias gstest1 @@ -496,7 +496,7 @@ select a, b, grouping(a), grouping(b), sum(v), count(*), max(v) from gstest1 group by cube(a,b) order by 3,4,1,2 /* 3,1,2 */ -- !query analysis Sort [grouping(a)#x ASC NULLS FIRST, grouping(b)#x ASC NULLS FIRST, a#x ASC NULLS FIRST, b#x ASC NULLS FIRST], true -+- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] ++- Aggregate [a#x, b#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] +- Expand [[a#x, b#x, v#x, a#x, b#x, 0], [a#x, b#x, v#x, a#x, null, 1], [a#x, b#x, v#x, null, b#x, 2], [a#x, b#x, v#x, null, null, 3]], [a#x, b#x, v#x, a#x, b#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x AS a#x, b#x AS b#x] +- SubqueryAlias gstest1 @@ -526,7 +526,7 @@ select unhashable_col, unsortable_col, order by 3, 4, 6 /* 3, 5 */ -- !query analysis Sort [grouping(unhashable_col)#x ASC NULLS FIRST, grouping(unsortable_col)#x ASC NULLS FIRST, sum(v)#xL ASC NULLS FIRST], true -+- Aggregate [unhashable_col#x, unsortable_col#x, spark_grouping_id#xL], [unhashable_col#x, unsortable_col#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(unhashable_col)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(unsortable_col)#x, count(1) AS count(1)#xL, sum(v#x) AS sum(v)#xL] ++- Aggregate [unhashable_col#x, unsortable_col#x, spark_grouping_id#xL], [unhashable_col#x, unsortable_col#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(unhashable_col)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(unsortable_col)#x, count(1) AS count(1)#xL, sum(v#x) AS sum(v)#xL] +- Expand [[id#x, v#x, unhashable_col#x, unsortable_col#x, unhashable_col#x, null, 1], [id#x, v#x, unhashable_col#x, unsortable_col#x, null, unsortable_col#x, 2]], [id#x, v#x, unhashable_col#x, unsortable_col#x, unhashable_col#x, unsortable_col#x, spark_grouping_id#xL] +- Project [id#x, v#x, unhashable_col#x, unsortable_col#x, unhashable_col#x AS unhashable_col#x, unsortable_col#x AS unsortable_col#x] +- SubqueryAlias spark_catalog.default.gstest4 @@ -541,7 +541,7 @@ select unhashable_col, unsortable_col, order by 3, 4, 6 /* 3,5 */ -- !query analysis Sort [grouping(unhashable_col)#x ASC NULLS FIRST, grouping(unsortable_col)#x ASC NULLS FIRST, sum(v)#xL ASC NULLS FIRST], true -+- Aggregate [v#x, unhashable_col#x, unsortable_col#x, spark_grouping_id#xL], [unhashable_col#x, unsortable_col#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(unhashable_col)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS 
grouping(unsortable_col)#x, count(1) AS count(1)#xL, sum(v#x) AS sum(v)#xL] ++- Aggregate [v#x, unhashable_col#x, unsortable_col#x, spark_grouping_id#xL], [unhashable_col#x, unsortable_col#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(unhashable_col)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(unsortable_col)#x, count(1) AS count(1)#xL, sum(v#x) AS sum(v)#xL] +- Expand [[id#x, v#x, unhashable_col#x, unsortable_col#x, v#x, unhashable_col#x, null, 1], [id#x, v#x, unhashable_col#x, unsortable_col#x, v#x, null, unsortable_col#x, 2]], [id#x, v#x, unhashable_col#x, unsortable_col#x, v#x, unhashable_col#x, unsortable_col#x, spark_grouping_id#xL] +- Project [id#x, v#x, unhashable_col#x, unsortable_col#x, v#x AS v#x, unhashable_col#x AS unhashable_col#x, unsortable_col#x AS unsortable_col#x] +- SubqueryAlias spark_catalog.default.gstest4 @@ -593,7 +593,7 @@ select a, b, grouping(a), grouping(b), sum(v), count(*), max(v) from gstest1 group by grouping sets ((a,b),(a+1,b+1),(a+2,b+2)) order by 3,4,7 /* 3,6 */ -- !query analysis Sort [grouping(a)#x ASC NULLS FIRST, grouping(b)#x ASC NULLS FIRST, max(v)#x ASC NULLS FIRST], true -+- Aggregate [a#x, b#x, (a#x + 1)#x, (b#x + 1)#x, (a#x + 2)#x, (b#x + 2)#x, spark_grouping_id#xL], [a#x, b#x, cast(((spark_grouping_id#xL >> 5) & 1) as tinyint) AS grouping(a)#x, cast(((spark_grouping_id#xL >> 4) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] ++- Aggregate [a#x, b#x, (a#x + 1)#x, (b#x + 1)#x, (a#x + 2)#x, (b#x + 2)#x, spark_grouping_id#xL], [a#x, b#x, cast((shiftright(spark_grouping_id#xL, 5) & 1) as tinyint) AS grouping(a)#x, cast((shiftright(spark_grouping_id#xL, 4) & 1) as tinyint) AS grouping(b)#x, sum(v#x) AS sum(v)#xL, count(1) AS count(1)#xL, max(v#x) AS max(v)#x] +- Expand [[a#x, b#x, v#x, a#x, b#x, null, null, null, null, 15], [a#x, b#x, v#x, null, null, (a#x + 1)#x, (b#x + 1)#x, null, null, 51], [a#x, b#x, v#x, null, null, null, null, (a#x + 2)#x, (b#x + 2)#x, 60]], [a#x, b#x, v#x, a#x, b#x, (a#x + 1)#x, (b#x + 1)#x, (a#x + 2)#x, (b#x + 2)#x, spark_grouping_id#xL] +- Project [a#x, b#x, v#x, a#x AS a#x, b#x AS b#x, (a#x + 1) AS (a#x + 1)#x, (b#x + 1) AS (b#x + 1)#x, (a#x + 2) AS (a#x + 2)#x, (b#x + 2) AS (b#x + 2)#x] +- SubqueryAlias gstest1 @@ -634,7 +634,7 @@ select v||'a', case grouping(v||'a') when 1 then 1 else 0 end, count(*) group by rollup(i, v||'a') order by 1,3 -- !query analysis Sort [concat(v, a)#x ASC NULLS FIRST, count(1)#xL ASC NULLS FIRST], true -+- Aggregate [i#x, concat(v#x, a)#x, spark_grouping_id#xL], [concat(v#x, a)#x AS concat(v, a)#x, CASE WHEN (cast(cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) as int) = 1) THEN 1 ELSE 0 END AS CASE WHEN (grouping(concat(v, a)) = 1) THEN 1 ELSE 0 END#x, count(1) AS count(1)#xL] ++- Aggregate [i#x, concat(v#x, a)#x, spark_grouping_id#xL], [concat(v#x, a)#x AS concat(v, a)#x, CASE WHEN (cast(cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) as int) = 1) THEN 1 ELSE 0 END AS CASE WHEN (grouping(concat(v, a)) = 1) THEN 1 ELSE 0 END#x, count(1) AS count(1)#xL] +- Expand [[i#x, v#x, i#x, concat(v#x, a)#x, 0], [i#x, v#x, i#x, null, 1], [i#x, v#x, null, null, 3]], [i#x, v#x, i#x, concat(v#x, a)#x, spark_grouping_id#xL] +- Project [i#x, v#x, i#x AS i#x, concat(v#x, a) AS concat(v#x, a)#x] +- SubqueryAlias u @@ -647,7 +647,7 @@ select v||'a', case when grouping(v||'a') = 1 then 1 else 0 end, count(*) group by rollup(i, v||'a') order by 1,3 -- !query analysis Sort [concat(v, 
a)#x ASC NULLS FIRST, count(1)#xL ASC NULLS FIRST], true -+- Aggregate [i#x, concat(v#x, a)#x, spark_grouping_id#xL], [concat(v#x, a)#x AS concat(v, a)#x, CASE WHEN (cast(cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) as int) = 1) THEN 1 ELSE 0 END AS CASE WHEN (grouping(concat(v, a)) = 1) THEN 1 ELSE 0 END#x, count(1) AS count(1)#xL] ++- Aggregate [i#x, concat(v#x, a)#x, spark_grouping_id#xL], [concat(v#x, a)#x AS concat(v, a)#x, CASE WHEN (cast(cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) as int) = 1) THEN 1 ELSE 0 END AS CASE WHEN (grouping(concat(v, a)) = 1) THEN 1 ELSE 0 END#x, count(1) AS count(1)#xL] +- Expand [[i#x, v#x, i#x, concat(v#x, a)#x, 0], [i#x, v#x, i#x, null, 1], [i#x, v#x, null, null, 3]], [i#x, v#x, i#x, concat(v#x, a)#x, spark_grouping_id#xL] +- Project [i#x, v#x, i#x AS i#x, concat(v#x, a) AS concat(v#x, a)#x] +- SubqueryAlias u diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int2.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int2.sql.out index 3fa919434da79..9dda3c0dc42d4 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int2.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int2.sql.out @@ -274,14 +274,14 @@ Project [ AS five#x, f1#x, (cast(f1#x as double) / cast(cast(2 as int) as double -- !query SELECT string(shiftleft(smallint(-1), 15)) -- !query analysis -Project [cast((cast(cast(-1 as smallint) as int) << 15) as string) AS (-1 << 15)#x] +Project [cast(shiftleft(cast(cast(-1 as smallint) as int), 15) as string) AS shiftleft(-1, 15)#x] +- OneRowRelation -- !query SELECT string(smallint(shiftleft(smallint(-1), 15))+1) -- !query analysis -Project [cast((cast(cast((cast(cast(-1 as smallint) as int) << 15) as smallint) as int) + 1) as string) AS ((-1 << 15) + 1)#x] +Project [cast((cast(cast(shiftleft(cast(cast(-1 as smallint) as int), 15) as smallint) as int) + 1) as string) AS (shiftleft(-1, 15) + 1)#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int4.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int4.sql.out index f6a8b24f917d2..d261b59a4c5e2 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int4.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int4.sql.out @@ -411,14 +411,14 @@ Project [(cast((2 + 2) as double) / cast(2 as double)) AS two#x] -- !query SELECT string(shiftleft(int(-1), 31)) -- !query analysis -Project [cast((cast(-1 as int) << 31) as string) AS (-1 << 31)#x] +Project [cast(shiftleft(cast(-1 as int), 31) as string) AS shiftleft(-1, 31)#x] +- OneRowRelation -- !query SELECT string(int(shiftleft(int(-1), 31))+1) -- !query analysis -Project [cast((cast((cast(-1 as int) << 31) as int) + 1) as string) AS ((-1 << 31) + 1)#x] +Project [cast((cast(shiftleft(cast(-1 as int), 31) as int) + 1) as string) AS (shiftleft(-1, 31) + 1)#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int8.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int8.sql.out index dfc96427b57ed..72972469fa6ef 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int8.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/int8.sql.out @@ -659,14 +659,14 @@ Project [id#xL] -- !query SELECT string(shiftleft(bigint(-1), 63)) -- !query analysis -Project [cast((cast(-1 as bigint) << 63) as 
string) AS (-1 << 63)#x] +Project [cast(shiftleft(cast(-1 as bigint), 63) as string) AS shiftleft(-1, 63)#x] +- OneRowRelation -- !query SELECT string(int(shiftleft(bigint(-1), 63))+1) -- !query analysis -Project [cast((cast((cast(-1 as bigint) << 63) as int) + 1) as string) AS ((-1 << 63) + 1)#x] +Project [cast((cast(shiftleft(cast(-1 as bigint), 63) as int) + 1) as string) AS (shiftleft(-1, 63) + 1)#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-group-analytics.sql.out index 7d6eb0a83bf4e..fbee3e2c8c89f 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-group-analytics.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-group-analytics.sql.out @@ -189,7 +189,7 @@ Sort [cast(udf(cast(course#x as string)) as string) ASC NULLS FIRST, sum#xL ASC SELECT udf(course), udf(year), GROUPING(course), GROUPING(year), GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(course, year) -- !query analysis -Aggregate [course#x, year#x, spark_grouping_id#xL], [cast(udf(cast(course#x as string)) as string) AS udf(course)#x, cast(udf(cast(year#x as string)) as int) AS udf(year)#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(course)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(year)#x, spark_grouping_id#xL AS grouping_id(course, year)#xL] +Aggregate [course#x, year#x, spark_grouping_id#xL], [cast(udf(cast(course#x as string)) as string) AS udf(course)#x, cast(udf(cast(year#x as string)) as int) AS udf(year)#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(course)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(year)#x, spark_grouping_id#xL AS grouping_id(course, year)#xL] +- Expand [[course#x, year#x, earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL] +- Project [course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x] +- SubqueryAlias coursesales @@ -255,7 +255,7 @@ HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 ORDER BY course, udf -- !query analysis Sort [course#x ASC NULLS FIRST, cast(udf(cast(year#x as string)) as int) ASC NULLS FIRST], true +- Project [course#x, year#x] - +- Filter ((cast(cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) as int) = 1) AND (spark_grouping_id#xL > cast(0 as bigint))) + +- Filter ((cast(cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) as int) = 1) AND (spark_grouping_id#xL > cast(0 as bigint))) +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, spark_grouping_id#xL, spark_grouping_id#xL] +- Expand [[course#x, year#x, earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL] +- Project [course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x] @@ -308,8 +308,8 @@ SELECT course, year, GROUPING(course), GROUPING(year) FROM courseSales GROUP BY ORDER BY GROUPING(course), GROUPING(year), course, udf(year) -- !query analysis Project [course#x, year#x, grouping(course)#x, grouping(year)#x] -+- Sort [cast(((spark_grouping_id#xL >> 1) & 1) as 
tinyint) ASC NULLS FIRST, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) ASC NULLS FIRST, course#x ASC NULLS FIRST, cast(udf(cast(year#x as string)) as int) ASC NULLS FIRST], true - +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) AS grouping(course)#x, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) AS grouping(year)#x, spark_grouping_id#xL, spark_grouping_id#xL] ++- Sort [cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) ASC NULLS FIRST, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) ASC NULLS FIRST, course#x ASC NULLS FIRST, cast(udf(cast(year#x as string)) as int) ASC NULLS FIRST], true + +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) AS grouping(course)#x, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) AS grouping(year)#x, spark_grouping_id#xL, spark_grouping_id#xL] +- Expand [[course#x, year#x, earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL] +- Project [course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x] +- SubqueryAlias coursesales @@ -325,7 +325,7 @@ SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(co ORDER BY GROUPING(course), GROUPING(year), course, udf(year) -- !query analysis Project [course#x, year#x, grouping_id(course, year)#xL] -+- Sort [cast(((spark_grouping_id#xL >> 1) & 1) as tinyint) ASC NULLS FIRST, cast(((spark_grouping_id#xL >> 0) & 1) as tinyint) ASC NULLS FIRST, course#x ASC NULLS FIRST, cast(udf(cast(year#x as string)) as int) ASC NULLS FIRST], true ++- Sort [cast((shiftright(spark_grouping_id#xL, 1) & 1) as tinyint) ASC NULLS FIRST, cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) ASC NULLS FIRST, course#x ASC NULLS FIRST, cast(udf(cast(year#x as string)) as int) ASC NULLS FIRST], true +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, spark_grouping_id#xL AS grouping_id(course, year)#xL, spark_grouping_id#xL, spark_grouping_id#xL] +- Expand [[course#x, year#x, earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL] +- Project [course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x] diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int2.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int2.sql.out index 1c96f8dfa5e54..ca55b6accc665 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int2.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int2.sql.out @@ -289,7 +289,7 @@ struct -- !query SELECT string(shiftleft(smallint(-1), 15)) -- !query schema -struct<(-1 << 15):string> +struct<shiftleft(-1, 15):string> -- !query output -32768 @@ -297,7 +297,7 @@ struct<(-1 << 15):string> -- !query SELECT string(smallint(shiftleft(smallint(-1), 15))+1) -- !query schema -struct<((-1 << 15) + 1):string> +struct<(shiftleft(-1, 15) + 1):string> -- !query output -32767 diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int4.sql.out index
afe0211bd1947..16c18c86f2919 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int4.sql.out @@ -520,7 +520,7 @@ struct -- !query SELECT string(shiftleft(int(-1), 31)) -- !query schema -struct<(-1 << 31):string> +struct<shiftleft(-1, 31):string> -- !query output -2147483648 @@ -528,7 +528,7 @@ struct<(-1 << 31):string> -- !query SELECT string(int(shiftleft(int(-1), 31))+1) -- !query schema -struct<((-1 << 31) + 1):string> +struct<(shiftleft(-1, 31) + 1):string> -- !query output -2147483647 diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out index 6e7ca4afab67d..f6e4bd8bd7e08 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out @@ -883,7 +883,7 @@ struct -- !query SELECT string(shiftleft(bigint(-1), 63)) -- !query schema -struct<(-1 << 63):string> +struct<shiftleft(-1, 63):string> -- !query output -9223372036854775808 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q17/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q17/explain.txt index 6908b8137b0c4..850b20431e487 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q17/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q17/explain.txt @@ -102,7 +102,7 @@ Condition : (isnotnull(cs_bill_customer_sk#14) AND isnotnull(cs_item_sk#15)) (13) BroadcastExchange Input [4]: [cs_bill_customer_sk#14, cs_item_sk#15, cs_quantity#16, cs_sold_date_sk#17] -Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, false] as bigint) << 32) | (cast(input[1, int, false] as bigint) & 4294967295))),false), [plan_id=2] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[1, int, false] as bigint) & 4294967295))),false), [plan_id=2] (14) BroadcastHashJoin [codegen id : 8] Left keys [2]: [sr_customer_sk#9, sr_item_sk#8] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q25/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q25/explain.txt index 15b74bac0fbec..e2caa9f171b86 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q25/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q25/explain.txt @@ -102,7 +102,7 @@ Condition : (isnotnull(cs_bill_customer_sk#14) AND isnotnull(cs_item_sk#15)) (13) BroadcastExchange Input [4]: [cs_bill_customer_sk#14, cs_item_sk#15, cs_net_profit#16, cs_sold_date_sk#17] -Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, false] as bigint) << 32) | (cast(input[1, int, false] as bigint) & 4294967295))),false), [plan_id=2] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[1, int, false] as bigint) & 4294967295))),false), [plan_id=2] (14) BroadcastHashJoin [codegen id : 8] Left keys [2]: [sr_customer_sk#9, sr_item_sk#8] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q27.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q27.sf100/explain.txt index 4b7d4f2f068d6..6cc9c3a4834ee 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q27.sf100/explain.txt +++
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q27.sf100/explain.txt @@ -167,7 +167,7 @@ Input [11]: [i_item_id#19, s_state#20, spark_grouping_id#21, sum#30, count#31, s Keys [3]: [i_item_id#19, s_state#20, spark_grouping_id#21] Functions [4]: [avg(ss_quantity#4), avg(UnscaledValue(ss_list_price#5)), avg(UnscaledValue(ss_coupon_amt#7)), avg(UnscaledValue(ss_sales_price#6))] Aggregate Attributes [4]: [avg(ss_quantity#4)#38, avg(UnscaledValue(ss_list_price#5))#39, avg(UnscaledValue(ss_coupon_amt#7))#40, avg(UnscaledValue(ss_sales_price#6))#41] -Results [7]: [i_item_id#19, s_state#20, cast(((spark_grouping_id#21 >> 0) & 1) as tinyint) AS g_state#42, avg(ss_quantity#4)#38 AS agg1#43, cast((avg(UnscaledValue(ss_list_price#5))#39 / 100.0) as decimal(11,6)) AS agg2#44, cast((avg(UnscaledValue(ss_coupon_amt#7))#40 / 100.0) as decimal(11,6)) AS agg3#45, cast((avg(UnscaledValue(ss_sales_price#6))#41 / 100.0) as decimal(11,6)) AS agg4#46] +Results [7]: [i_item_id#19, s_state#20, cast((shiftright(spark_grouping_id#21, 0) & 1) as tinyint) AS g_state#42, avg(ss_quantity#4)#38 AS agg1#43, cast((avg(UnscaledValue(ss_list_price#5))#39 / 100.0) as decimal(11,6)) AS agg2#44, cast((avg(UnscaledValue(ss_coupon_amt#7))#40 / 100.0) as decimal(11,6)) AS agg3#45, cast((avg(UnscaledValue(ss_sales_price#6))#41 / 100.0) as decimal(11,6)) AS agg4#46] (30) TakeOrderedAndProject Input [7]: [i_item_id#19, s_state#20, g_state#42, agg1#43, agg2#44, agg3#45, agg4#46] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q27/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q27/explain.txt index 4b7d4f2f068d6..6cc9c3a4834ee 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q27/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q27/explain.txt @@ -167,7 +167,7 @@ Input [11]: [i_item_id#19, s_state#20, spark_grouping_id#21, sum#30, count#31, s Keys [3]: [i_item_id#19, s_state#20, spark_grouping_id#21] Functions [4]: [avg(ss_quantity#4), avg(UnscaledValue(ss_list_price#5)), avg(UnscaledValue(ss_coupon_amt#7)), avg(UnscaledValue(ss_sales_price#6))] Aggregate Attributes [4]: [avg(ss_quantity#4)#38, avg(UnscaledValue(ss_list_price#5))#39, avg(UnscaledValue(ss_coupon_amt#7))#40, avg(UnscaledValue(ss_sales_price#6))#41] -Results [7]: [i_item_id#19, s_state#20, cast(((spark_grouping_id#21 >> 0) & 1) as tinyint) AS g_state#42, avg(ss_quantity#4)#38 AS agg1#43, cast((avg(UnscaledValue(ss_list_price#5))#39 / 100.0) as decimal(11,6)) AS agg2#44, cast((avg(UnscaledValue(ss_coupon_amt#7))#40 / 100.0) as decimal(11,6)) AS agg3#45, cast((avg(UnscaledValue(ss_sales_price#6))#41 / 100.0) as decimal(11,6)) AS agg4#46] +Results [7]: [i_item_id#19, s_state#20, cast((shiftright(spark_grouping_id#21, 0) & 1) as tinyint) AS g_state#42, avg(ss_quantity#4)#38 AS agg1#43, cast((avg(UnscaledValue(ss_list_price#5))#39 / 100.0) as decimal(11,6)) AS agg2#44, cast((avg(UnscaledValue(ss_coupon_amt#7))#40 / 100.0) as decimal(11,6)) AS agg3#45, cast((avg(UnscaledValue(ss_sales_price#6))#41 / 100.0) as decimal(11,6)) AS agg4#46] (30) TakeOrderedAndProject Input [7]: [i_item_id#19, s_state#20, g_state#42, agg1#43, agg2#44, agg3#45, agg4#46] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q29/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q29/explain.txt index 27534390f0a24..76a6ab9c7215b 100644 --- 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q29/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q29/explain.txt @@ -102,7 +102,7 @@ Condition : (isnotnull(cs_bill_customer_sk#14) AND isnotnull(cs_item_sk#15)) (13) BroadcastExchange Input [4]: [cs_bill_customer_sk#14, cs_item_sk#15, cs_quantity#16, cs_sold_date_sk#17] -Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, false] as bigint) << 32) | (cast(input[1, int, false] as bigint) & 4294967295))),false), [plan_id=2] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[1, int, false] as bigint) & 4294967295))),false), [plan_id=2] (14) BroadcastHashJoin [codegen id : 8] Left keys [2]: [sr_customer_sk#9, sr_item_sk#8] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36.sf100/explain.txt index 63cb718a827f3..ea59f2b926c9d 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36.sf100/explain.txt @@ -137,7 +137,7 @@ Input [5]: [i_category#13, i_class#14, spark_grouping_id#15, sum#18, sum#19] Keys [3]: [i_category#13, i_class#14, spark_grouping_id#15] Functions [2]: [sum(UnscaledValue(ss_net_profit#4)), sum(UnscaledValue(ss_ext_sales_price#3))] Aggregate Attributes [2]: [sum(UnscaledValue(ss_net_profit#4))#20, sum(UnscaledValue(ss_ext_sales_price#3))#21] -Results [7]: [(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2) / MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2)) AS gross_margin#22, i_category#13, i_class#14, (cast(((spark_grouping_id#15 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#15 >> 0) & 1) as tinyint)) AS lochierarchy#23, (MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2) / MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2)) AS _w0#24, (cast(((spark_grouping_id#15 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#15 >> 0) & 1) as tinyint)) AS _w1#25, CASE WHEN (cast(((spark_grouping_id#15 >> 0) & 1) as tinyint) = 0) THEN i_category#13 END AS _w2#26] +Results [7]: [(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2) / MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2)) AS gross_margin#22, i_category#13, i_class#14, (cast((shiftright(spark_grouping_id#15, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#15, 0) & 1) as tinyint)) AS lochierarchy#23, (MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2) / MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2)) AS _w0#24, (cast((shiftright(spark_grouping_id#15, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#15, 0) & 1) as tinyint)) AS _w1#25, CASE WHEN (cast((shiftright(spark_grouping_id#15, 0) & 1) as tinyint) = 0) THEN i_category#13 END AS _w2#26] (24) Exchange Input [7]: [gross_margin#22, i_category#13, i_class#14, lochierarchy#23, _w0#24, _w1#25, _w2#26] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36/explain.txt index eb59673575aba..6cc55ab063f68 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q36/explain.txt @@ -137,7 +137,7 @@ Input [5]: [i_category#13, 
i_class#14, spark_grouping_id#15, sum#18, sum#19] Keys [3]: [i_category#13, i_class#14, spark_grouping_id#15] Functions [2]: [sum(UnscaledValue(ss_net_profit#4)), sum(UnscaledValue(ss_ext_sales_price#3))] Aggregate Attributes [2]: [sum(UnscaledValue(ss_net_profit#4))#20, sum(UnscaledValue(ss_ext_sales_price#3))#21] -Results [7]: [(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2) / MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2)) AS gross_margin#22, i_category#13, i_class#14, (cast(((spark_grouping_id#15 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#15 >> 0) & 1) as tinyint)) AS lochierarchy#23, (MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2) / MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2)) AS _w0#24, (cast(((spark_grouping_id#15 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#15 >> 0) & 1) as tinyint)) AS _w1#25, CASE WHEN (cast(((spark_grouping_id#15 >> 0) & 1) as tinyint) = 0) THEN i_category#13 END AS _w2#26] +Results [7]: [(MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2) / MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2)) AS gross_margin#22, i_category#13, i_class#14, (cast((shiftright(spark_grouping_id#15, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#15, 0) & 1) as tinyint)) AS lochierarchy#23, (MakeDecimal(sum(UnscaledValue(ss_net_profit#4))#20,17,2) / MakeDecimal(sum(UnscaledValue(ss_ext_sales_price#3))#21,17,2)) AS _w0#24, (cast((shiftright(spark_grouping_id#15, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#15, 0) & 1) as tinyint)) AS _w1#25, CASE WHEN (cast((shiftright(spark_grouping_id#15, 0) & 1) as tinyint) = 0) THEN i_category#13 END AS _w2#26] (24) Exchange Input [7]: [gross_margin#22, i_category#13, i_class#14, lochierarchy#23, _w0#24, _w1#25, _w2#26] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q39a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q39a/explain.txt index 995b723c6e287..220598440e092 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q39a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q39a/explain.txt @@ -237,7 +237,7 @@ Input [5]: [w_warehouse_sk#32, i_item_sk#31, d_moy#35, stdev#46, mean#47] (41) BroadcastExchange Input [5]: [w_warehouse_sk#32, i_item_sk#31, d_moy#35, mean#47, cov#48] -Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, true] as bigint) << 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=5] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=5] (42) BroadcastHashJoin [codegen id : 10] Left keys [2]: [i_item_sk#6, w_warehouse_sk#7] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q39b/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q39b/explain.txt index dba61c77b774e..585e748860557 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q39b/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q39b/explain.txt @@ -237,7 +237,7 @@ Input [5]: [w_warehouse_sk#32, i_item_sk#31, d_moy#35, stdev#46, mean#47] (41) BroadcastExchange Input [5]: [w_warehouse_sk#32, i_item_sk#31, d_moy#35, mean#47, cov#48] -Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, true] as bigint) << 32) | (cast(input[0, int, true] as bigint) & 
4294967295))),false), [plan_id=5] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=5] (42) BroadcastHashJoin [codegen id : 10] Left keys [2]: [i_item_sk#6, w_warehouse_sk#7] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q49/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q49/explain.txt index 93f79d66f0973..9eea658d789e4 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q49/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q49/explain.txt @@ -99,7 +99,7 @@ Input [6]: [ws_item_sk#1, ws_order_number#2, ws_quantity#3, ws_net_paid#4, ws_ne (5) BroadcastExchange Input [5]: [ws_item_sk#1, ws_order_number#2, ws_quantity#3, ws_net_paid#4, ws_sold_date_sk#6] -Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, true] as bigint) << 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=1] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=1] (6) Scan parquet spark_catalog.default.web_returns Output [5]: [wr_item_sk#8, wr_order_number#9, wr_return_quantity#10, wr_return_amt#11, wr_returned_date_sk#12] @@ -209,7 +209,7 @@ Input [6]: [cs_item_sk#36, cs_order_number#37, cs_quantity#38, cs_net_paid#39, c (29) BroadcastExchange Input [5]: [cs_item_sk#36, cs_order_number#37, cs_quantity#38, cs_net_paid#39, cs_sold_date_sk#41] -Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, true] as bigint) << 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=4] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=4] (30) Scan parquet spark_catalog.default.catalog_returns Output [5]: [cr_item_sk#42, cr_order_number#43, cr_return_quantity#44, cr_return_amount#45, cr_returned_date_sk#46] @@ -319,7 +319,7 @@ Input [6]: [ss_item_sk#70, ss_ticket_number#71, ss_quantity#72, ss_net_paid#73, (53) BroadcastExchange Input [5]: [ss_item_sk#70, ss_ticket_number#71, ss_quantity#72, ss_net_paid#73, ss_sold_date_sk#75] -Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, true] as bigint) << 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=7] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=7] (54) Scan parquet spark_catalog.default.store_returns Output [5]: [sr_item_sk#76, sr_ticket_number#77, sr_return_quantity#78, sr_return_amt#79, sr_returned_date_sk#80] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt index 93103073d6f85..313959456c809 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt @@ -304,7 +304,7 @@ Input [5]: [wr_item_sk#92, wr_order_number#93, wr_return_amt#94, wr_net_loss#95, (49) BroadcastExchange Input [5]: [wr_item_sk#92, wr_order_number#93, wr_return_amt#94, wr_net_loss#95, wr_returned_date_sk#96] -Arguments: 
HashedRelationBroadcastMode(List(((cast(input[0, int, true] as bigint) << 32) | (cast(input[1, int, true] as bigint) & 4294967295))),false), [plan_id=5] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, true] as bigint), 32) | (cast(input[1, int, true] as bigint) & 4294967295))),false), [plan_id=5] (50) Scan parquet spark_catalog.default.web_sales Output [4]: [ws_item_sk#97, ws_web_site_sk#98, ws_order_number#99, ws_sold_date_sk#100] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q64/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q64/explain.txt index 3a049ca71e742..69023c88202af 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q64/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q64/explain.txt @@ -201,7 +201,7 @@ Condition : (((((((isnotnull(ss_item_sk#1) AND isnotnull(ss_ticket_number#8)) AN (4) BroadcastExchange Input [12]: [ss_item_sk#1, ss_customer_sk#2, ss_cdemo_sk#3, ss_hdemo_sk#4, ss_addr_sk#5, ss_store_sk#6, ss_promo_sk#7, ss_ticket_number#8, ss_wholesale_cost#9, ss_list_price#10, ss_coupon_amt#11, ss_sold_date_sk#12] -Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, false] as bigint) << 32) | (cast(input[7, int, false] as bigint) & 4294967295))),false), [plan_id=1] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[7, int, false] as bigint) & 4294967295))),false), [plan_id=1] (5) Scan parquet spark_catalog.default.store_returns Output [3]: [sr_item_sk#14, sr_ticket_number#15, sr_returned_date_sk#16] @@ -714,7 +714,7 @@ Condition : (((((((isnotnull(ss_item_sk#106) AND isnotnull(ss_ticket_number#113) (115) BroadcastExchange Input [12]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, ss_ticket_number#113, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, ss_sold_date_sk#117] -Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, false] as bigint) << 32) | (cast(input[7, int, false] as bigint) & 4294967295))),false), [plan_id=16] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[7, int, false] as bigint) & 4294967295))),false), [plan_id=16] (116) Scan parquet spark_catalog.default.store_returns Output [3]: [sr_item_sk#119, sr_ticket_number#120, sr_returned_date_sk#121] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q70.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q70.sf100/explain.txt index b6b480018aa46..d64f560f144e0 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q70.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q70.sf100/explain.txt @@ -224,7 +224,7 @@ Input [4]: [s_state#20, s_county#21, spark_grouping_id#22, sum#24] Keys [3]: [s_state#20, s_county#21, spark_grouping_id#22] Functions [1]: [sum(UnscaledValue(ss_net_profit#2))] Aggregate Attributes [1]: [sum(UnscaledValue(ss_net_profit#2))#25] -Results [7]: [MakeDecimal(sum(UnscaledValue(ss_net_profit#2))#25,17,2) AS total_sum#26, s_state#20, s_county#21, (cast(((spark_grouping_id#22 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#22 >> 0) & 1) as tinyint)) AS lochierarchy#27, MakeDecimal(sum(UnscaledValue(ss_net_profit#2))#25,17,2) AS _w0#28, (cast(((spark_grouping_id#22 >> 1) 
& 1) as tinyint) + cast(((spark_grouping_id#22 >> 0) & 1) as tinyint)) AS _w1#29, CASE WHEN (cast(((spark_grouping_id#22 >> 0) & 1) as tinyint) = 0) THEN s_state#20 END AS _w2#30] +Results [7]: [MakeDecimal(sum(UnscaledValue(ss_net_profit#2))#25,17,2) AS total_sum#26, s_state#20, s_county#21, (cast((shiftright(spark_grouping_id#22, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#22, 0) & 1) as tinyint)) AS lochierarchy#27, MakeDecimal(sum(UnscaledValue(ss_net_profit#2))#25,17,2) AS _w0#28, (cast((shiftright(spark_grouping_id#22, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#22, 0) & 1) as tinyint)) AS _w1#29, CASE WHEN (cast((shiftright(spark_grouping_id#22, 0) & 1) as tinyint) = 0) THEN s_state#20 END AS _w2#30] (39) Exchange Input [7]: [total_sum#26, s_state#20, s_county#21, lochierarchy#27, _w0#28, _w1#29, _w2#30] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q70/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q70/explain.txt index 9495128a50e13..dade1b4f55c5f 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q70/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q70/explain.txt @@ -224,7 +224,7 @@ Input [4]: [s_state#20, s_county#21, spark_grouping_id#22, sum#24] Keys [3]: [s_state#20, s_county#21, spark_grouping_id#22] Functions [1]: [sum(UnscaledValue(ss_net_profit#2))] Aggregate Attributes [1]: [sum(UnscaledValue(ss_net_profit#2))#25] -Results [7]: [MakeDecimal(sum(UnscaledValue(ss_net_profit#2))#25,17,2) AS total_sum#26, s_state#20, s_county#21, (cast(((spark_grouping_id#22 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#22 >> 0) & 1) as tinyint)) AS lochierarchy#27, MakeDecimal(sum(UnscaledValue(ss_net_profit#2))#25,17,2) AS _w0#28, (cast(((spark_grouping_id#22 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#22 >> 0) & 1) as tinyint)) AS _w1#29, CASE WHEN (cast(((spark_grouping_id#22 >> 0) & 1) as tinyint) = 0) THEN s_state#20 END AS _w2#30] +Results [7]: [MakeDecimal(sum(UnscaledValue(ss_net_profit#2))#25,17,2) AS total_sum#26, s_state#20, s_county#21, (cast((shiftright(spark_grouping_id#22, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#22, 0) & 1) as tinyint)) AS lochierarchy#27, MakeDecimal(sum(UnscaledValue(ss_net_profit#2))#25,17,2) AS _w0#28, (cast((shiftright(spark_grouping_id#22, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#22, 0) & 1) as tinyint)) AS _w1#29, CASE WHEN (cast((shiftright(spark_grouping_id#22, 0) & 1) as tinyint) = 0) THEN s_state#20 END AS _w2#30] (39) Exchange Input [7]: [total_sum#26, s_state#20, s_county#21, lochierarchy#27, _w0#28, _w1#29, _w2#30] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q72/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q72/explain.txt index 31da928a5d7a3..12ba2db6323e4 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q72/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q72/explain.txt @@ -264,7 +264,7 @@ Condition : (isnotnull(d_week_seq#26) AND isnotnull(d_date_sk#25)) (42) BroadcastExchange Input [2]: [d_date_sk#25, d_week_seq#26] -Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, false] as bigint) << 32) | (cast(input[0, int, false] as bigint) & 4294967295))),false), [plan_id=6] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, false] as bigint), 32) | 
(cast(input[0, int, false] as bigint) & 4294967295))),false), [plan_id=6] (43) BroadcastHashJoin [codegen id : 10] Left keys [2]: [d_week_seq#24, inv_date_sk#13] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q85/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q85/explain.txt index 31c804c73eaef..af6632f4fb608 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q85/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q85/explain.txt @@ -66,7 +66,7 @@ Condition : ((((isnotnull(ws_item_sk#1) AND isnotnull(ws_order_number#3)) AND is (4) BroadcastExchange Input [7]: [ws_item_sk#1, ws_web_page_sk#2, ws_order_number#3, ws_quantity#4, ws_sales_price#5, ws_net_profit#6, ws_sold_date_sk#7] -Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, false] as bigint) << 32) | (cast(input[2, int, false] as bigint) & 4294967295))),false), [plan_id=1] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[2, int, false] as bigint) & 4294967295))),false), [plan_id=1] (5) Scan parquet spark_catalog.default.web_returns Output [9]: [wr_item_sk#9, wr_refunded_cdemo_sk#10, wr_refunded_addr_sk#11, wr_returning_cdemo_sk#12, wr_reason_sk#13, wr_order_number#14, wr_fee#15, wr_refunded_cash#16, wr_returned_date_sk#17] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q86.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q86.sf100/explain.txt index c496d20204875..d1802b2e4a7c6 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q86.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q86.sf100/explain.txt @@ -98,7 +98,7 @@ Input [4]: [i_category#9, i_class#10, spark_grouping_id#11, sum#13] Keys [3]: [i_category#9, i_class#10, spark_grouping_id#11] Functions [1]: [sum(UnscaledValue(ws_net_paid#2))] Aggregate Attributes [1]: [sum(UnscaledValue(ws_net_paid#2))#14] -Results [7]: [MakeDecimal(sum(UnscaledValue(ws_net_paid#2))#14,17,2) AS total_sum#15, i_category#9, i_class#10, (cast(((spark_grouping_id#11 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#11 >> 0) & 1) as tinyint)) AS lochierarchy#16, MakeDecimal(sum(UnscaledValue(ws_net_paid#2))#14,17,2) AS _w0#17, (cast(((spark_grouping_id#11 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#11 >> 0) & 1) as tinyint)) AS _w1#18, CASE WHEN (cast(((spark_grouping_id#11 >> 0) & 1) as tinyint) = 0) THEN i_category#9 END AS _w2#19] +Results [7]: [MakeDecimal(sum(UnscaledValue(ws_net_paid#2))#14,17,2) AS total_sum#15, i_category#9, i_class#10, (cast((shiftright(spark_grouping_id#11, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#11, 0) & 1) as tinyint)) AS lochierarchy#16, MakeDecimal(sum(UnscaledValue(ws_net_paid#2))#14,17,2) AS _w0#17, (cast((shiftright(spark_grouping_id#11, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#11, 0) & 1) as tinyint)) AS _w1#18, CASE WHEN (cast((shiftright(spark_grouping_id#11, 0) & 1) as tinyint) = 0) THEN i_category#9 END AS _w2#19] (17) Exchange Input [7]: [total_sum#15, i_category#9, i_class#10, lochierarchy#16, _w0#17, _w1#18, _w2#19] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q86/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q86/explain.txt index c496d20204875..d1802b2e4a7c6 100644 --- 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q86/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q86/explain.txt @@ -98,7 +98,7 @@ Input [4]: [i_category#9, i_class#10, spark_grouping_id#11, sum#13] Keys [3]: [i_category#9, i_class#10, spark_grouping_id#11] Functions [1]: [sum(UnscaledValue(ws_net_paid#2))] Aggregate Attributes [1]: [sum(UnscaledValue(ws_net_paid#2))#14] -Results [7]: [MakeDecimal(sum(UnscaledValue(ws_net_paid#2))#14,17,2) AS total_sum#15, i_category#9, i_class#10, (cast(((spark_grouping_id#11 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#11 >> 0) & 1) as tinyint)) AS lochierarchy#16, MakeDecimal(sum(UnscaledValue(ws_net_paid#2))#14,17,2) AS _w0#17, (cast(((spark_grouping_id#11 >> 1) & 1) as tinyint) + cast(((spark_grouping_id#11 >> 0) & 1) as tinyint)) AS _w1#18, CASE WHEN (cast(((spark_grouping_id#11 >> 0) & 1) as tinyint) = 0) THEN i_category#9 END AS _w2#19] +Results [7]: [MakeDecimal(sum(UnscaledValue(ws_net_paid#2))#14,17,2) AS total_sum#15, i_category#9, i_class#10, (cast((shiftright(spark_grouping_id#11, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#11, 0) & 1) as tinyint)) AS lochierarchy#16, MakeDecimal(sum(UnscaledValue(ws_net_paid#2))#14,17,2) AS _w0#17, (cast((shiftright(spark_grouping_id#11, 1) & 1) as tinyint) + cast((shiftright(spark_grouping_id#11, 0) & 1) as tinyint)) AS _w1#18, CASE WHEN (cast((shiftright(spark_grouping_id#11, 0) & 1) as tinyint) = 0) THEN i_category#9 END AS _w2#19] (17) Exchange Input [7]: [total_sum#15, i_category#9, i_class#10, lochierarchy#16, _w0#17, _w1#18, _w2#19] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24.sf100/explain.txt index e437dea8ca9a0..9d80077e99372 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q24.sf100/explain.txt @@ -125,7 +125,7 @@ Input [11]: [s_store_sk#1, s_store_name#2, s_state#4, ca_address_sk#6, ca_state# (17) BroadcastExchange Input [7]: [s_store_sk#1, s_store_name#2, s_state#4, ca_state#7, c_customer_sk#10, c_first_name#12, c_last_name#13] -Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, true] as bigint) << 32) | (cast(input[4, int, true] as bigint) & 4294967295))),false), [plan_id=3] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, true] as bigint), 32) | (cast(input[4, int, true] as bigint) & 4294967295))),false), [plan_id=3] (18) Scan parquet spark_catalog.default.store_sales Output [6]: [ss_item_sk#15, ss_customer_sk#16, ss_store_sk#17, ss_ticket_number#18, ss_net_paid#19, ss_sold_date_sk#20] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q49/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q49/explain.txt index ec609603ea35b..fea7a9fe207df 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q49/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q49/explain.txt @@ -99,7 +99,7 @@ Input [6]: [ws_item_sk#1, ws_order_number#2, ws_quantity#3, ws_net_paid#4, ws_ne (5) BroadcastExchange Input [5]: [ws_item_sk#1, ws_order_number#2, ws_quantity#3, ws_net_paid#4, ws_sold_date_sk#6] -Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, true] as bigint) << 32) | (cast(input[0, int, true] 
as bigint) & 4294967295))),false), [plan_id=1] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=1] (6) Scan parquet spark_catalog.default.web_returns Output [5]: [wr_item_sk#8, wr_order_number#9, wr_return_quantity#10, wr_return_amt#11, wr_returned_date_sk#12] @@ -209,7 +209,7 @@ Input [6]: [cs_item_sk#36, cs_order_number#37, cs_quantity#38, cs_net_paid#39, c (29) BroadcastExchange Input [5]: [cs_item_sk#36, cs_order_number#37, cs_quantity#38, cs_net_paid#39, cs_sold_date_sk#41] -Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, true] as bigint) << 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=4] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=4] (30) Scan parquet spark_catalog.default.catalog_returns Output [5]: [cr_item_sk#42, cr_order_number#43, cr_return_quantity#44, cr_return_amount#45, cr_returned_date_sk#46] @@ -319,7 +319,7 @@ Input [6]: [ss_item_sk#70, ss_ticket_number#71, ss_quantity#72, ss_net_paid#73, (53) BroadcastExchange Input [5]: [ss_item_sk#70, ss_ticket_number#71, ss_quantity#72, ss_net_paid#73, ss_sold_date_sk#75] -Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, true] as bigint) << 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=7] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295))),false), [plan_id=7] (54) Scan parquet spark_catalog.default.store_returns Output [5]: [sr_item_sk#76, sr_ticket_number#77, sr_return_quantity#78, sr_return_amt#79, sr_returned_date_sk#80] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q5a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q5a/explain.txt index 1b9bf5123e965..34c6ecf3cf2fa 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q5a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q5a/explain.txt @@ -317,7 +317,7 @@ Input [5]: [wr_item_sk#92, wr_order_number#93, wr_return_amt#94, wr_net_loss#95, (49) BroadcastExchange Input [5]: [wr_item_sk#92, wr_order_number#93, wr_return_amt#94, wr_net_loss#95, wr_returned_date_sk#96] -Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, true] as bigint) << 32) | (cast(input[1, int, true] as bigint) & 4294967295))),false), [plan_id=5] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, true] as bigint), 32) | (cast(input[1, int, true] as bigint) & 4294967295))),false), [plan_id=5] (50) Scan parquet spark_catalog.default.web_sales Output [4]: [ws_item_sk#97, ws_web_site_sk#98, ws_order_number#99, ws_sold_date_sk#100] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64/explain.txt index 4579b8bbe8197..40eddbbacf38a 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q64/explain.txt @@ -201,7 +201,7 @@ Condition : (((((((isnotnull(ss_item_sk#1) AND isnotnull(ss_ticket_number#8)) AN (4) BroadcastExchange Input [12]: [ss_item_sk#1, ss_customer_sk#2, ss_cdemo_sk#3, 
ss_hdemo_sk#4, ss_addr_sk#5, ss_store_sk#6, ss_promo_sk#7, ss_ticket_number#8, ss_wholesale_cost#9, ss_list_price#10, ss_coupon_amt#11, ss_sold_date_sk#12] -Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, false] as bigint) << 32) | (cast(input[7, int, false] as bigint) & 4294967295))),false), [plan_id=1] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[7, int, false] as bigint) & 4294967295))),false), [plan_id=1] (5) Scan parquet spark_catalog.default.store_returns Output [3]: [sr_item_sk#14, sr_ticket_number#15, sr_returned_date_sk#16] @@ -714,7 +714,7 @@ Condition : (((((((isnotnull(ss_item_sk#106) AND isnotnull(ss_ticket_number#113) (115) BroadcastExchange Input [12]: [ss_item_sk#106, ss_customer_sk#107, ss_cdemo_sk#108, ss_hdemo_sk#109, ss_addr_sk#110, ss_store_sk#111, ss_promo_sk#112, ss_ticket_number#113, ss_wholesale_cost#114, ss_list_price#115, ss_coupon_amt#116, ss_sold_date_sk#117] -Arguments: HashedRelationBroadcastMode(List(((cast(input[0, int, false] as bigint) << 32) | (cast(input[7, int, false] as bigint) & 4294967295))),false), [plan_id=16] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[7, int, false] as bigint) & 4294967295))),false), [plan_id=16] (116) Scan parquet spark_catalog.default.store_returns Output [3]: [sr_item_sk#119, sr_ticket_number#120, sr_returned_date_sk#121] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q72/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q72/explain.txt index 47974c9691023..13d7d1bc9c4d8 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q72/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q72/explain.txt @@ -264,7 +264,7 @@ Condition : (isnotnull(d_week_seq#26) AND isnotnull(d_date_sk#25)) (42) BroadcastExchange Input [2]: [d_date_sk#25, d_week_seq#26] -Arguments: HashedRelationBroadcastMode(List(((cast(input[1, int, false] as bigint) << 32) | (cast(input[0, int, false] as bigint) & 4294967295))),false), [plan_id=6] +Arguments: HashedRelationBroadcastMode(List((shiftleft(cast(input[1, int, false] as bigint), 32) | (cast(input[0, int, false] as bigint) & 4294967295))),false), [plan_id=6] (43) BroadcastHashJoin [codegen id : 10] Left keys [2]: [d_week_seq#24, inv_date_sk#13] From 7fe1b93884aa8e9ba20f19351b8537c687b8f59c Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Tue, 28 May 2024 09:56:16 -0700 Subject: [PATCH 42/45] [SPARK-46841][SQL] Add collation support for ICU locales and collation specifiers ### What changes were proposed in this pull request? Languages and localization for collations are supported by the ICU library. The collation naming format is as follows: ``` <2-letter language code>[_<4-letter script>][_<3-letter country code>][_specifier_specifier...] ``` The locale specifier consists of the first part of the collation name (language + script + country). Locale specifiers need to be stable across ICU versions; to keep existing ids and names invariant, we introduce a golden file with the locale table, which should cause a CI failure on any silent changes.
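For illustration, a minimal Scala sketch (not part of this patch) of how such locale-based names resolve through the new `CollationFactory` API; the valid names below come from this change, while `en_FOO` is a made-up name used only to show the failure path:

```scala
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.util.CollationFactory.fetchCollation

// Locale-only, language + script, and language + script + 3-letter country names all
// resolve to predefined collations; matching is case-insensitive and the canonical
// collation name is returned.
Seq("en", "sr_Cyrl", "sr_Cyrl_SRB", "unicode_ci").foreach { name =>
  val collation = fetchCollation(name)
  println(s"$name -> ${collation.collationName}")
}

// A name that does not start with a known locale, or that uses an unknown specifier,
// fails with the COLLATION_INVALID_NAME error class (assuming en_FOO is not an
// installed ICU locale).
try {
  fetchCollation("en_FOO")
} catch {
  case e: SparkException => println(e.getErrorClass)
}
```

Round-tripping a name through `fetchCollation` and reading back `collationName` is also how the added tests check the normalization rules.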
Currently supported optional specifiers: - `CS`/`CI` - case sensitivity, default is case-sensitive; supported by configuring ICU collation levels - `AS`/`AI` - accent sensitivity, default is accent-sensitive; supported by configuring ICU collation levels User can use collation specifiers in any order except of locale which is mandatory and must go first. There is a one-to-one mapping between collation ids and collation names defined in `CollationFactory`. ### Why are the changes needed? To add languages and localization support for collations. ### Does this PR introduce _any_ user-facing change? Yes, it adds new predefined collations. ### How was this patch tested? Added checks to `CollationFactorySuite` and ICU locale map golden file. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46180 from nikolamand-db/SPARK-46841. Authored-by: Nikola Mandic Signed-off-by: Wenchen Fan --- .../sql/catalyst/util/CollationFactory.java | 678 ++++++++++++++---- .../unsafe/types/CollationFactorySuite.scala | 323 ++++++++- .../resources/error/error-conditions.json | 4 +- .../spark/sql/PlanGenerationTestSuite.scala | 4 +- .../main/protobuf/spark/connect/types.proto | 2 +- .../common/DataTypeProtoConverter.scala | 9 +- .../query-tests/queries/csv_from_dataset.json | 2 +- .../queries/csv_from_dataset.proto.bin | Bin 158 -> 169 bytes .../queries/function_lit_array.json | 4 +- .../queries/function_lit_array.proto.bin | Bin 889 -> 911 bytes .../queries/function_typedLit.json | 32 +- .../queries/function_typedLit.proto.bin | Bin 1199 -> 1381 bytes .../queries/json_from_dataset.json | 2 +- .../queries/json_from_dataset.proto.bin | Bin 169 -> 180 bytes python/pyspark/sql/connect/proto/types_pb2.py | 78 +- .../pyspark/sql/connect/proto/types_pb2.pyi | 11 +- python/pyspark/sql/connect/types.py | 5 +- python/pyspark/sql/types.py | 27 +- .../apache/spark/sql/internal/SQLConf.scala | 15 +- .../CollationExpressionSuite.scala | 33 +- .../collations/ICU-collations-map.md | 143 ++++ .../analyzer-results/collations.sql.out | 77 ++ .../resources/sql-tests/inputs/collations.sql | 13 + .../sql-tests/results/collations.sql.out | 88 +++ .../org/apache/spark/sql/CollationSuite.scala | 2 +- .../spark/sql/ICUCollationsMapSuite.scala | 69 ++ .../spark/sql/internal/SQLConfSuite.scala | 3 +- 27 files changed, 1388 insertions(+), 236 deletions(-) create mode 100644 sql/core/src/test/resources/collations/ICU-collations-map.md create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ICUCollationsMapSuite.scala diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 0133c3feb611a..fce12510afaf5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -19,6 +19,7 @@ import java.text.CharacterIterator; import java.text.StringCharacterIterator; import java.util.*; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.BiFunction; import java.util.function.ToLongFunction; @@ -173,26 +174,546 @@ public Collation( } /** - * Constructor with comparators that are inherited from the given collator. + * Collation ID is defined as 32-bit integer. We specify binary layouts for different classes of + * collations. 
Classes of collations are differentiated by most significant 3 bits (bit 31, 30 + * and 29), bit 31 being most significant and bit 0 being least significant. + * --- + * General collation ID binary layout: + * bit 31: 1 for INDETERMINATE (requires all other bits to be 1 as well), 0 otherwise. + * bit 30: 0 for predefined, 1 for user-defined. + * Following bits are specified for predefined collations: + * bit 29: 0 for UTF8_BINARY, 1 for ICU collations. + * bit 28-24: Reserved. + * bit 23-22: Reserved for version. + * bit 21-18: Reserved for space trimming. + * bit 17-0: Depend on collation family. + * --- + * INDETERMINATE collation ID binary layout: + * bit 31-0: 1 + * INDETERMINATE collation ID is equal to -1. + * --- + * User-defined collation ID binary layout: + * bit 31: 0 + * bit 30: 1 + * bit 29-0: Undefined, reserved for future use. + * --- + * UTF8_BINARY collation ID binary layout: + * bit 31-24: Zeroes. + * bit 23-22: Zeroes, reserved for version. + * bit 21-18: Zeroes, reserved for space trimming. + * bit 17-3: Zeroes. + * bit 2: 0, reserved for accent sensitivity. + * bit 1: 0, reserved for uppercase and case-insensitive. + * bit 0: 0 = case-sensitive, 1 = lowercase. + * --- + * ICU collation ID binary layout: + * bit 31-30: Zeroes. + * bit 29: 1 + * bit 28-24: Zeroes. + * bit 23-22: Zeroes, reserved for version. + * bit 21-18: Zeroes, reserved for space trimming. + * bit 17: 0 = case-sensitive, 1 = case-insensitive. + * bit 16: 0 = accent-sensitive, 1 = accent-insensitive. + * bit 15-14: Zeroes, reserved for punctuation sensitivity. + * bit 13-12: Zeroes, reserved for first letter preference. + * bit 11-0: Locale ID as specified in `ICULocaleToId` mapping. + * --- + * Some illustrative examples of collation name to ID mapping: + * - UTF8_BINARY -> 0 + * - UTF8_BINARY_LCASE -> 1 + * - UNICODE -> 0x20000000 + * - UNICODE_AI -> 0x20010000 + * - UNICODE_CI -> 0x20020000 + * - UNICODE_CI_AI -> 0x20030000 + * - af -> 0x20000001 + * - af_CI_AI -> 0x20030001 */ - public Collation( - String collationName, - String provider, - Collator collator, - String version, - boolean supportsBinaryEquality, - boolean supportsBinaryOrdering, - boolean supportsLowercaseEquality) { - this( - collationName, - provider, - collator, - (s1, s2) -> collator.compare(s1.toString(), s2.toString()), - version, - s -> (long)collator.getCollationKey(s.toString()).hashCode(), - supportsBinaryEquality, - supportsBinaryOrdering, - supportsLowercaseEquality); + private abstract static class CollationSpec { + + /** + * Bit 30 in collation ID having value 0 for predefined and 1 for user-defined collation. + */ + private enum DefinitionOrigin { + PREDEFINED, USER_DEFINED + } + + /** + * Bit 29 in collation ID having value 0 for UTF8_BINARY family and 1 for ICU family of + * collations. + */ + protected enum ImplementationProvider { + UTF8_BINARY, ICU + } + + /** + * Offset in binary collation ID layout. + */ + private static final int DEFINITION_ORIGIN_OFFSET = 30; + + /** + * Bitmask corresponding to width in bits in binary collation ID layout. + */ + private static final int DEFINITION_ORIGIN_MASK = 0b1; + + /** + * Offset in binary collation ID layout. + */ + protected static final int IMPLEMENTATION_PROVIDER_OFFSET = 29; + + /** + * Bitmask corresponding to width in bits in binary collation ID layout. 
+ */ + protected static final int IMPLEMENTATION_PROVIDER_MASK = 0b1; + + private static final int INDETERMINATE_COLLATION_ID = -1; + + /** + * Thread-safe cache mapping collation IDs to corresponding `Collation` instances. + * We add entries to this cache lazily as new `Collation` instances are requested. + */ + private static final Map collationMap = new ConcurrentHashMap<>(); + + /** + * Utility function to retrieve `ImplementationProvider` enum instance from collation ID. + */ + private static ImplementationProvider getImplementationProvider(int collationId) { + return ImplementationProvider.values()[SpecifierUtils.getSpecValue(collationId, + IMPLEMENTATION_PROVIDER_OFFSET, IMPLEMENTATION_PROVIDER_MASK)]; + } + + /** + * Utility function to retrieve `DefinitionOrigin` enum instance from collation ID. + */ + private static DefinitionOrigin getDefinitionOrigin(int collationId) { + return DefinitionOrigin.values()[SpecifierUtils.getSpecValue(collationId, + DEFINITION_ORIGIN_OFFSET, DEFINITION_ORIGIN_MASK)]; + } + + /** + * Main entry point for retrieving `Collation` instance from collation ID. + */ + private static Collation fetchCollation(int collationId) { + // User-defined collations and INDETERMINATE collations cannot produce a `Collation` + // instance. + assert (collationId >= 0 && getDefinitionOrigin(collationId) + == DefinitionOrigin.PREDEFINED); + if (collationId == UTF8_BINARY_COLLATION_ID) { + // Skip cache. + return CollationSpecUTF8Binary.UTF8_BINARY_COLLATION; + } else if (collationMap.containsKey(collationId)) { + // Already in cache. + return collationMap.get(collationId); + } else { + // Build `Collation` instance and put into cache. + CollationSpec spec; + ImplementationProvider implementationProvider = getImplementationProvider(collationId); + if (implementationProvider == ImplementationProvider.UTF8_BINARY) { + spec = CollationSpecUTF8Binary.fromCollationId(collationId); + } else { + spec = CollationSpecICU.fromCollationId(collationId); + } + Collation collation = spec.buildCollation(); + collationMap.put(collationId, collation); + return collation; + } + } + + protected static SparkException collationInvalidNameException(String collationName) { + return new SparkException("COLLATION_INVALID_NAME", + SparkException.constructMessageParams(Map.of("collationName", collationName)), null); + } + + private static int collationNameToId(String collationName) throws SparkException { + // Collation names provided by user are treated as case-insensitive. + String collationNameUpper = collationName.toUpperCase(); + if (collationNameUpper.startsWith("UTF8_BINARY")) { + return CollationSpecUTF8Binary.collationNameToId(collationName, collationNameUpper); + } else { + return CollationSpecICU.collationNameToId(collationName, collationNameUpper); + } + } + + protected abstract Collation buildCollation(); + } + + private static class CollationSpecUTF8Binary extends CollationSpec { + + /** + * Bit 0 in collation ID having value 0 for plain UTF8_BINARY and 1 for UTF8_BINARY_LCASE + * collation. + */ + private enum CaseSensitivity { + UNSPECIFIED, LCASE + } + + /** + * Offset in binary collation ID layout. + */ + private static final int CASE_SENSITIVITY_OFFSET = 0; + + /** + * Bitmask corresponding to width in bits in binary collation ID layout. 
+ */ + private static final int CASE_SENSITIVITY_MASK = 0b1; + + private static final int UTF8_BINARY_COLLATION_ID = + new CollationSpecUTF8Binary(CaseSensitivity.UNSPECIFIED).collationId; + private static final int UTF8_BINARY_LCASE_COLLATION_ID = + new CollationSpecUTF8Binary(CaseSensitivity.LCASE).collationId; + protected static Collation UTF8_BINARY_COLLATION = + new CollationSpecUTF8Binary(CaseSensitivity.UNSPECIFIED).buildCollation(); + protected static Collation UTF8_BINARY_LCASE_COLLATION = + new CollationSpecUTF8Binary(CaseSensitivity.LCASE).buildCollation(); + + private final int collationId; + + private CollationSpecUTF8Binary(CaseSensitivity caseSensitivity) { + this.collationId = + SpecifierUtils.setSpecValue(0, CASE_SENSITIVITY_OFFSET, caseSensitivity); + } + + private static int collationNameToId(String originalName, String collationName) + throws SparkException { + if (UTF8_BINARY_COLLATION.collationName.equals(collationName)) { + return UTF8_BINARY_COLLATION_ID; + } else if (UTF8_BINARY_LCASE_COLLATION.collationName.equals(collationName)) { + return UTF8_BINARY_LCASE_COLLATION_ID; + } else { + // Throw exception with original (before case conversion) collation name. + throw collationInvalidNameException(originalName); + } + } + + private static CollationSpecUTF8Binary fromCollationId(int collationId) { + // Extract case sensitivity from collation ID. + int caseConversionOrdinal = SpecifierUtils.getSpecValue(collationId, + CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK); + // Verify only case sensitivity bits were set settable in UTF8_BINARY family of collations. + assert (SpecifierUtils.removeSpec(collationId, + CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK) == 0); + return new CollationSpecUTF8Binary(CaseSensitivity.values()[caseConversionOrdinal]); + } + + @Override + protected Collation buildCollation() { + if (collationId == UTF8_BINARY_COLLATION_ID) { + return new Collation( + "UTF8_BINARY", + PROVIDER_SPARK, + null, + UTF8String::binaryCompare, + "1.0", + s -> (long) s.hashCode(), + /* supportsBinaryEquality = */ true, + /* supportsBinaryOrdering = */ true, + /* supportsLowercaseEquality = */ false); + } else { + return new Collation( + "UTF8_BINARY_LCASE", + PROVIDER_SPARK, + null, + UTF8String::compareLowerCase, + "1.0", + s -> (long) s.toLowerCase().hashCode(), + /* supportsBinaryEquality = */ false, + /* supportsBinaryOrdering = */ false, + /* supportsLowercaseEquality = */ true); + } + } + } + + private static class CollationSpecICU extends CollationSpec { + + /** + * Bit 17 in collation ID having value 0 for case-sensitive and 1 for case-insensitive + * collation. + */ + private enum CaseSensitivity { + CS, CI + } + + /** + * Bit 16 in collation ID having value 0 for accent-sensitive and 1 for accent-insensitive + * collation. + */ + private enum AccentSensitivity { + AS, AI + } + + /** + * Offset in binary collation ID layout. + */ + private static final int CASE_SENSITIVITY_OFFSET = 17; + + /** + * Bitmask corresponding to width in bits in binary collation ID layout. + */ + private static final int CASE_SENSITIVITY_MASK = 0b1; + + /** + * Offset in binary collation ID layout. + */ + private static final int ACCENT_SENSITIVITY_OFFSET = 16; + + /** + * Bitmask corresponding to width in bits in binary collation ID layout. + */ + private static final int ACCENT_SENSITIVITY_MASK = 0b1; + + /** + * Array of locale names, each locale ID corresponds to the index in this array. 
+ */ + private static final String[] ICULocaleNames; + + /** + * Mapping of locale names to corresponding `ULocale` instance. + */ + private static final Map ICULocaleMap = new HashMap<>(); + + /** + * Used to parse user input collation names which are converted to uppercase. + */ + private static final Map ICULocaleMapUppercase = new HashMap<>(); + + /** + * Reverse mapping of `ICULocaleNames`. + */ + private static final Map ICULocaleToId = new HashMap<>(); + + /** + * ICU library Collator version passed to `Collation` instance. + */ + private static final String ICU_COLLATOR_VERSION = "153.120.0.0"; + + static { + ICULocaleMap.put("UNICODE", ULocale.ROOT); + // ICU-implemented `ULocale`s which have corresponding `Collator` installed. + ULocale[] locales = Collator.getAvailableULocales(); + // Build locale names in format: language["_" optional script]["_" optional country code]. + // Examples: en, en_USA, sr_Cyrl_SRB + for (ULocale locale : locales) { + // Skip variants. + if (locale.getVariant().isEmpty()) { + String language = locale.getLanguage(); + // Require non-empty language as first component of locale name. + assert (!language.isEmpty()); + StringBuilder builder = new StringBuilder(language); + // Script tag. + String script = locale.getScript(); + if (!script.isEmpty()) { + builder.append('_'); + builder.append(script); + } + // 3-letter country code. + String country = locale.getISO3Country(); + if (!country.isEmpty()) { + builder.append('_'); + builder.append(country); + } + String localeName = builder.toString(); + // Verify locale names are unique. + assert (!ICULocaleMap.containsKey(localeName)); + ICULocaleMap.put(localeName, locale); + } + } + // Construct uppercase-normalized locale name mapping. + for (String localeName : ICULocaleMap.keySet()) { + String localeUppercase = localeName.toUpperCase(); + // Locale names are unique case-insensitively. + assert (!ICULocaleMapUppercase.containsKey(localeUppercase)); + ICULocaleMapUppercase.put(localeUppercase, localeName); + } + // Construct locale name to ID mapping. Locale ID is defined as index in `ICULocaleNames`. + ICULocaleNames = ICULocaleMap.keySet().toArray(new String[0]); + Arrays.sort(ICULocaleNames); + // Maximum number of locale IDs as defined by binary layout. + assert (ICULocaleNames.length <= (1 << 12)); + for (int i = 0; i < ICULocaleNames.length; ++i) { + ICULocaleToId.put(ICULocaleNames[i], i); + } + } + + private static final int UNICODE_COLLATION_ID = + new CollationSpecICU("UNICODE", CaseSensitivity.CS, AccentSensitivity.AS).collationId; + private static final int UNICODE_CI_COLLATION_ID = + new CollationSpecICU("UNICODE", CaseSensitivity.CI, AccentSensitivity.AS).collationId; + + private final CaseSensitivity caseSensitivity; + private final AccentSensitivity accentSensitivity; + private final String locale; + private final int collationId; + + private CollationSpecICU(String locale, CaseSensitivity caseSensitivity, + AccentSensitivity accentSensitivity) { + this.locale = locale; + this.caseSensitivity = caseSensitivity; + this.accentSensitivity = accentSensitivity; + // Construct collation ID from locale, case-sensitivity and accent-sensitivity specifiers. + int collationId = ICULocaleToId.get(locale); + // Mandatory ICU implementation provider. 
+ collationId = SpecifierUtils.setSpecValue(collationId, IMPLEMENTATION_PROVIDER_OFFSET, + ImplementationProvider.ICU); + collationId = SpecifierUtils.setSpecValue(collationId, CASE_SENSITIVITY_OFFSET, + caseSensitivity); + collationId = SpecifierUtils.setSpecValue(collationId, ACCENT_SENSITIVITY_OFFSET, + accentSensitivity); + this.collationId = collationId; + } + + private static int collationNameToId( + String originalName, String collationName) throws SparkException { + // Search for the longest locale match because specifiers are designed to be different from + // script tag and country code, meaning the only valid locale name match can be the longest + // one. + int lastPos = -1; + for (int i = 1; i <= collationName.length(); i++) { + String localeName = collationName.substring(0, i); + if (ICULocaleMapUppercase.containsKey(localeName)) { + lastPos = i; + } + } + if (lastPos == -1) { + throw collationInvalidNameException(originalName); + } else { + String locale = collationName.substring(0, lastPos); + int collationId = ICULocaleToId.get(ICULocaleMapUppercase.get(locale)); + + // Try all combinations of AS/AI and CS/CI. + CaseSensitivity caseSensitivity; + AccentSensitivity accentSensitivity; + if (collationName.equals(locale) || + collationName.equals(locale + "_AS") || + collationName.equals(locale + "_CS") || + collationName.equals(locale + "_AS_CS") || + collationName.equals(locale + "_CS_AS") + ) { + caseSensitivity = CaseSensitivity.CS; + accentSensitivity = AccentSensitivity.AS; + } else if (collationName.equals(locale + "_CI") || + collationName.equals(locale + "_AS_CI") || + collationName.equals(locale + "_CI_AS")) { + caseSensitivity = CaseSensitivity.CI; + accentSensitivity = AccentSensitivity.AS; + } else if (collationName.equals(locale + "_AI") || + collationName.equals(locale + "_CS_AI") || + collationName.equals(locale + "_AI_CS")) { + caseSensitivity = CaseSensitivity.CS; + accentSensitivity = AccentSensitivity.AI; + } else if (collationName.equals(locale + "_AI_CI") || + collationName.equals(locale + "_CI_AI")) { + caseSensitivity = CaseSensitivity.CI; + accentSensitivity = AccentSensitivity.AI; + } else { + throw collationInvalidNameException(originalName); + } + + // Build collation ID from computed specifiers. + collationId = SpecifierUtils.setSpecValue(collationId, + IMPLEMENTATION_PROVIDER_OFFSET, ImplementationProvider.ICU); + collationId = SpecifierUtils.setSpecValue(collationId, + CASE_SENSITIVITY_OFFSET, caseSensitivity); + collationId = SpecifierUtils.setSpecValue(collationId, + ACCENT_SENSITIVITY_OFFSET, accentSensitivity); + return collationId; + } + } + + private static CollationSpecICU fromCollationId(int collationId) { + // Parse specifiers from collation ID. + int caseSensitivityOrdinal = SpecifierUtils.getSpecValue(collationId, + CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK); + int accentSensitivityOrdinal = SpecifierUtils.getSpecValue(collationId, + ACCENT_SENSITIVITY_OFFSET, ACCENT_SENSITIVITY_MASK); + collationId = SpecifierUtils.removeSpec(collationId, + IMPLEMENTATION_PROVIDER_OFFSET, IMPLEMENTATION_PROVIDER_MASK); + collationId = SpecifierUtils.removeSpec(collationId, + CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK); + collationId = SpecifierUtils.removeSpec(collationId, + ACCENT_SENSITIVITY_OFFSET, ACCENT_SENSITIVITY_MASK); + // Locale ID remains after removing all other specifiers. + int localeId = collationId; + // Verify locale ID is valid against `ICULocaleNames` array. 
+ assert (localeId < ICULocaleNames.length); + CaseSensitivity caseSensitivity = CaseSensitivity.values()[caseSensitivityOrdinal]; + AccentSensitivity accentSensitivity = AccentSensitivity.values()[accentSensitivityOrdinal]; + String locale = ICULocaleNames[localeId]; + return new CollationSpecICU(locale, caseSensitivity, accentSensitivity); + } + + @Override + protected Collation buildCollation() { + ULocale.Builder builder = new ULocale.Builder(); + builder.setLocale(ICULocaleMap.get(locale)); + // Compute unicode locale keyword for all combinations of case/accent sensitivity. + if (caseSensitivity == CaseSensitivity.CS && + accentSensitivity == AccentSensitivity.AS) { + builder.setUnicodeLocaleKeyword("ks", "level3"); + } else if (caseSensitivity == CaseSensitivity.CS && + accentSensitivity == AccentSensitivity.AI) { + builder + .setUnicodeLocaleKeyword("ks", "level1") + .setUnicodeLocaleKeyword("kc", "true"); + } else if (caseSensitivity == CaseSensitivity.CI && + accentSensitivity == AccentSensitivity.AS) { + builder.setUnicodeLocaleKeyword("ks", "level2"); + } else if (caseSensitivity == CaseSensitivity.CI && + accentSensitivity == AccentSensitivity.AI) { + builder.setUnicodeLocaleKeyword("ks", "level1"); + } + ULocale resultLocale = builder.build(); + Collator collator = Collator.getInstance(resultLocale); + // Freeze ICU collator to ensure thread safety. + collator.freeze(); + return new Collation( + collationName(), + PROVIDER_ICU, + collator, + (s1, s2) -> collator.compare(s1.toString(), s2.toString()), + ICU_COLLATOR_VERSION, + s -> (long) collator.getCollationKey(s.toString()).hashCode(), + /* supportsBinaryEquality = */ collationId == UNICODE_COLLATION_ID, + /* supportsBinaryOrdering = */ false, + /* supportsLowercaseEquality = */ false); + } + + /** + * Compute normalized collation name. Components of collation name are given in order: + * - Locale name + * - Optional case sensitivity when non-default preceded by underscore + * - Optional accent sensitivity when non-default preceded by underscore + * Examples: en, en_USA_CI_AI, sr_Cyrl_SRB_AI. + */ + private String collationName() { + StringBuilder builder = new StringBuilder(); + builder.append(locale); + if (caseSensitivity != CaseSensitivity.CS) { + builder.append('_'); + builder.append(caseSensitivity.toString()); + } + if (accentSensitivity != AccentSensitivity.AS) { + builder.append('_'); + builder.append(accentSensitivity.toString()); + } + return builder.toString(); + } + } + + /** + * Utility class for manipulating conversions between collation IDs and specifier enums/locale + * IDs. Scope bitwise operations here to avoid confusion. + */ + private static class SpecifierUtils { + private static int getSpecValue(int collationId, int offset, int mask) { + return (collationId >> offset) & mask; + } + + private static int removeSpec(int collationId, int offset, int mask) { + return collationId & ~(mask << offset); + } + + private static int setSpecValue(int collationId, int offset, Enum spec) { + return collationId | (spec.ordinal() << offset); + } } /** Returns the collation identifier. 
*/ @@ -201,75 +722,20 @@ public CollationIdentifier identifier() { } } - private static final Collation[] collationTable = new Collation[4]; - private static final HashMap collationNameToIdMap = new HashMap<>(); - - public static final int UTF8_BINARY_COLLATION_ID = 0; - public static final int UTF8_BINARY_LCASE_COLLATION_ID = 1; - public static final String PROVIDER_SPARK = "spark"; public static final String PROVIDER_ICU = "icu"; public static final List SUPPORTED_PROVIDERS = List.of(PROVIDER_SPARK, PROVIDER_ICU); - static { - // Binary comparison. This is the default collation. - // No custom comparators will be used for this collation. - // Instead, we rely on byte for byte comparison. - collationTable[0] = new Collation( - "UTF8_BINARY", - PROVIDER_SPARK, - null, - UTF8String::binaryCompare, - "1.0", - s -> (long)s.hashCode(), - true, - true, - false); - - // Case-insensitive UTF8 binary collation. - // TODO: Do in place comparisons instead of creating new strings. - collationTable[1] = new Collation( - "UTF8_BINARY_LCASE", - PROVIDER_SPARK, - null, - UTF8String::compareLowerCase, - "1.0", - (s) -> (long)s.toLowerCase().hashCode(), - false, - false, - true); - - // UNICODE case sensitive comparison (ROOT locale, in ICU). - collationTable[2] = new Collation( - "UNICODE", - PROVIDER_ICU, - Collator.getInstance(ULocale.ROOT), - "153.120.0.0", - true, - false, - false - ); - - collationTable[2].collator.setStrength(Collator.TERTIARY); - collationTable[2].collator.freeze(); - - // UNICODE case-insensitive comparison (ROOT locale, in ICU + Secondary strength). - collationTable[3] = new Collation( - "UNICODE_CI", - PROVIDER_ICU, - Collator.getInstance(ULocale.ROOT), - "153.120.0.0", - false, - false, - false - ); - collationTable[3].collator.setStrength(Collator.SECONDARY); - collationTable[3].collator.freeze(); - - for (int i = 0; i < collationTable.length; i++) { - collationNameToIdMap.put(collationTable[i].collationName, i); - } - } + public static final int UTF8_BINARY_COLLATION_ID = + Collation.CollationSpecUTF8Binary.UTF8_BINARY_COLLATION_ID; + public static final int UTF8_BINARY_LCASE_COLLATION_ID = + Collation.CollationSpecUTF8Binary.UTF8_BINARY_LCASE_COLLATION_ID; + public static final int UNICODE_COLLATION_ID = + Collation.CollationSpecICU.UNICODE_COLLATION_ID; + public static final int UNICODE_CI_COLLATION_ID = + Collation.CollationSpecICU.UNICODE_CI_COLLATION_ID; + public static final int INDETERMINATE_COLLATION_ID = + Collation.CollationSpec.INDETERMINATE_COLLATION_ID; /** * Returns a StringSearch object for the given pattern and target strings, under collation @@ -297,23 +763,6 @@ public static StringSearch getStringSearch( return new StringSearch(patternString, target, (RuleBasedCollator) collator); } - /** - * Returns if the given collationName is valid one. - */ - public static boolean isValidCollation(String collationName) { - return collationNameToIdMap.containsKey(collationName.toUpperCase()); - } - - /** - * Returns closest valid name to collationName - */ - public static String getClosestCollation(String collationName) { - Collation suggestion = Collections.min(List.of(collationTable), Comparator.comparingInt( - c -> UTF8String.fromString(c.collationName).levenshteinDistance( - UTF8String.fromString(collationName.toUpperCase())))); - return suggestion.collationName; - } - /** * Returns a collation-unaware StringSearch object for the given pattern and target strings. 
* While this object does not respect collation, it can be used to find occurrences of the pattern @@ -326,24 +775,10 @@ public static StringSearch getStringSearch( } /** - * Returns the collation id for the given collation name. + * Returns the collation ID for the given collation name. */ public static int collationNameToId(String collationName) throws SparkException { - String normalizedName = collationName.toUpperCase(); - if (collationNameToIdMap.containsKey(normalizedName)) { - return collationNameToIdMap.get(normalizedName); - } else { - Collation suggestion = Collections.min(List.of(collationTable), Comparator.comparingInt( - c -> UTF8String.fromString(c.collationName).levenshteinDistance( - UTF8String.fromString(normalizedName)))); - - Map params = new HashMap<>(); - params.put("collationName", collationName); - params.put("proposal", suggestion.collationName); - - throw new SparkException( - "COLLATION_INVALID_NAME", SparkException.constructMessageParams(params), null); - } + return Collation.CollationSpec.collationNameToId(collationName); } public static void assertValidProvider(String provider) throws SparkException { @@ -359,12 +794,15 @@ public static void assertValidProvider(String provider) throws SparkException { } public static Collation fetchCollation(int collationId) { - return collationTable[collationId]; + return Collation.CollationSpec.fetchCollation(collationId); } public static Collation fetchCollation(String collationName) throws SparkException { - int collationId = collationNameToId(collationName); - return collationTable[collationId]; + return fetchCollation(collationNameToId(collationName)); + } + + public static String[] getICULocaleNames() { + return Collation.CollationSpecICU.ICULocaleNames; } public static UTF8String getCollationKey(UTF8String input, int collationId) { diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index 768d26bf0e11e..69104dea0e992 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -20,7 +20,10 @@ package org.apache.spark.unsafe.types import scala.collection.parallel.immutable.ParSeq import scala.jdk.CollectionConverters.MapHasAsScala +import com.ibm.icu.util.ULocale + import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.util.CollationFactory.fetchCollation // scalastyle:off import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.must.Matchers @@ -30,31 +33,95 @@ import org.apache.spark.unsafe.types.UTF8String.{fromString => toUTF8} class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ignore funsuite test("collationId stability") { - val utf8Binary = fetchCollation(0) + assert(INDETERMINATE_COLLATION_ID == -1) + + assert(UTF8_BINARY_COLLATION_ID == 0) + val utf8Binary = fetchCollation(UTF8_BINARY_COLLATION_ID) assert(utf8Binary.collationName == "UTF8_BINARY") assert(utf8Binary.supportsBinaryEquality) - val utf8BinaryLcase = fetchCollation(1) + assert(UTF8_BINARY_LCASE_COLLATION_ID == 1) + val utf8BinaryLcase = fetchCollation(UTF8_BINARY_LCASE_COLLATION_ID) assert(utf8BinaryLcase.collationName == "UTF8_BINARY_LCASE") assert(!utf8BinaryLcase.supportsBinaryEquality) - val unicode = fetchCollation(2) + assert(UNICODE_COLLATION_ID == (1 << 29)) + val unicode = fetchCollation(UNICODE_COLLATION_ID) 
assert(unicode.collationName == "UNICODE") - assert(unicode.supportsBinaryEquality); + assert(unicode.supportsBinaryEquality) - val unicodeCi = fetchCollation(3) + assert(UNICODE_CI_COLLATION_ID == ((1 << 29) | (1 << 17))) + val unicodeCi = fetchCollation(UNICODE_CI_COLLATION_ID) assert(unicodeCi.collationName == "UNICODE_CI") assert(!unicodeCi.supportsBinaryEquality) } - test("fetch invalid collation name") { - val error = intercept[SparkException] { - fetchCollation("UTF8_BS") + test("UTF8_BINARY and ICU root locale collation names") { + // Collation name already normalized. + Seq( + "UTF8_BINARY", + "UTF8_BINARY_LCASE", + "UNICODE", + "UNICODE_CI", + "UNICODE_AI", + "UNICODE_CI_AI" + ).foreach(collationName => { + val col = fetchCollation(collationName) + assert(col.collationName == collationName) + }) + // Collation name normalization. + Seq( + // ICU root locale. + ("UNICODE_CS", "UNICODE"), + ("UNICODE_CS_AS", "UNICODE"), + ("UNICODE_CI_AS", "UNICODE_CI"), + ("UNICODE_AI_CS", "UNICODE_AI"), + ("UNICODE_AI_CI", "UNICODE_CI_AI"), + // Randomized case collation names. + ("utf8_binary", "UTF8_BINARY"), + ("UtF8_binARy_LcasE", "UTF8_BINARY_LCASE"), + ("unicode", "UNICODE"), + ("UnICoDe_cs_aI", "UNICODE_AI") + ).foreach{ + case (name, normalized) => + val col = fetchCollation(name) + assert(col.collationName == normalized) } + } + + test("fetch invalid UTF8_BINARY and ICU root locale collation names") { + Seq( + "UTF8_BINARY_CS", + "UTF8_BINARY_AS", + "UTF8_BINARY_CS_AS", + "UTF8_BINARY_AS_CS", + "UTF8_BINARY_CI", + "UTF8_BINARY_AI", + "UTF8_BINARY_CI_AI", + "UTF8_BINARY_AI_CI", + "UTF8_BS", + "BINARY_UTF8", + "UTF8_BINARY_A", + "UNICODE_X", + "UNICODE_CI_X", + "UNICODE_LCASE_X", + "UTF8_UNICODE", + "UTF8_BINARY_UNICODE", + "CI_UNICODE", + "LCASE_UNICODE", + "UNICODE_UNSPECIFIED", + "UNICODE_CI_UNSPECIFIED", + "UNICODE_UNSPECIFIED_CI_UNSPECIFIED", + "UNICODE_INDETERMINATE", + "UNICODE_CI_INDETERMINATE" + ).foreach(collationName => { + val error = intercept[SparkException] { + fetchCollation(collationName) + } - assert(error.getErrorClass === "COLLATION_INVALID_NAME") - assert(error.getMessageParameters.asScala === - Map("proposal" -> "UTF8_BINARY", "collationName" -> "UTF8_BS")) + assert(error.getErrorClass === "COLLATION_INVALID_NAME") + assert(error.getMessageParameters.asScala === Map("collationName" -> collationName)) + }) } case class CollationTestCase[R](collationName: String, s1: String, s2: String, expectedResult: R) @@ -152,4 +219,238 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig } }) } + + test("test collation caching") { + Seq( + "UTF8_BINARY", + "UTF8_BINARY_LCASE", + "UNICODE", + "UNICODE_CI", + "UNICODE_AI", + "UNICODE_CI_AI", + "UNICODE_AI_CI" + ).foreach(collationId => { + val col1 = fetchCollation(collationId) + val col2 = fetchCollation(collationId) + assert(col1 eq col2) // Check for reference equality. + }) + } + + test("collations with ICU non-root localization") { + Seq( + // Language only. + "en", + "en_CS", + "en_CI", + "en_AS", + "en_AI", + // Language + 3-letter country code. + "en_USA", + "en_USA_CS", + "en_USA_CI", + "en_USA_AS", + "en_USA_AI", + // Language + script code. + "sr_Cyrl", + "sr_Cyrl_CS", + "sr_Cyrl_CI", + "sr_Cyrl_AS", + "sr_Cyrl_AI", + // Language + script code + 3-letter country code. 
+ "sr_Cyrl_SRB", + "sr_Cyrl_SRB_CS", + "sr_Cyrl_SRB_CI", + "sr_Cyrl_SRB_AS", + "sr_Cyrl_SRB_AI" + ).foreach(collationICU => { + val col = fetchCollation(collationICU) + assert(col.collator.getLocale(ULocale.VALID_LOCALE) != ULocale.ROOT) + }) + } + + test("invalid names of collations with ICU non-root localization") { + Seq( + "en_US", // Must use 3-letter country code + "enn", + "en_AAA", + "en_Something", + "en_Something_USA", + "en_LCASE", + "en_UCASE", + "en_CI_LCASE", + "en_CI_UCASE", + "en_CI_UNSPECIFIED", + "en_USA_UNSPECIFIED", + "en_USA_UNSPECIFIED_CI", + "en_INDETERMINATE", + "en_USA_INDETERMINATE", + "en_Latn_USA", // Use en_USA instead. + "en_Cyrl_USA", + "en_USA_AAA", + "sr_Cyrl_SRB_AAA", + // Invalid ordering of language, script and country code. + "USA_en", + "sr_SRB_Cyrl", + "SRB_sr", + "SRB_sr_Cyrl", + "SRB_Cyrl_sr", + "Cyrl_sr", + "Cyrl_sr_SRB", + "Cyrl_SRB_sr", + // Collation specifiers in the middle of locale. + "CI_en", + "USA_CI_en", + "en_CI_USA", + "CI_sr_Cyrl_SRB", + "sr_CI_Cyrl_SRB", + "sr_Cyrl_CI_SRB", + "CI_Cyrl_sr", + "Cyrl_CI_sr", + "Cyrl_CI_sr_SRB", + "Cyrl_sr_CI_SRB" + ).foreach(collationName => { + val error = intercept[SparkException] { + fetchCollation(collationName) + } + + assert(error.getErrorClass === "COLLATION_INVALID_NAME") + assert(error.getMessageParameters.asScala === Map("collationName" -> collationName)) + }) + } + + test("collations name normalization for ICU non-root localization") { + Seq( + ("en_USA", "en_USA"), + ("en_CS", "en"), + ("en_AS", "en"), + ("en_CS_AS", "en"), + ("en_AS_CS", "en"), + ("en_CI", "en_CI"), + ("en_AI", "en_AI"), + ("en_AI_CI", "en_CI_AI"), + ("en_CI_AI", "en_CI_AI"), + ("en_CS_AI", "en_AI"), + ("en_AI_CS", "en_AI"), + ("en_CI_AS", "en_CI"), + ("en_AS_CI", "en_CI"), + ("en_USA_AI_CI", "en_USA_CI_AI"), + // Randomized case. + ("EN_USA", "en_USA"), + ("SR_CYRL", "sr_Cyrl"), + ("sr_cyrl_srb", "sr_Cyrl_SRB"), + ("sR_cYRl_sRb", "sr_Cyrl_SRB") + ).foreach { + case (name, normalized) => + val col = fetchCollation(name) + assert(col.collationName == normalized) + } + } + + test("invalid collationId") { + val badCollationIds = Seq( + INDETERMINATE_COLLATION_ID, // Indeterminate collation. + 1 << 30, // User-defined collation range. + (1 << 30) | 1, // User-defined collation range. + (1 << 30) | (1 << 29), // User-defined collation range. + 1 << 1, // UTF8_BINARY mandatory zero bit 1 breach. + 1 << 2, // UTF8_BINARY mandatory zero bit 2 breach. + 1 << 3, // UTF8_BINARY mandatory zero bit 3 breach. + 1 << 4, // UTF8_BINARY mandatory zero bit 4 breach. + 1 << 5, // UTF8_BINARY mandatory zero bit 5 breach. + 1 << 6, // UTF8_BINARY mandatory zero bit 6 breach. + 1 << 7, // UTF8_BINARY mandatory zero bit 7 breach. + 1 << 8, // UTF8_BINARY mandatory zero bit 8 breach. + 1 << 9, // UTF8_BINARY mandatory zero bit 9 breach. + 1 << 10, // UTF8_BINARY mandatory zero bit 10 breach. + 1 << 11, // UTF8_BINARY mandatory zero bit 11 breach. + 1 << 12, // UTF8_BINARY mandatory zero bit 12 breach. + 1 << 13, // UTF8_BINARY mandatory zero bit 13 breach. + 1 << 14, // UTF8_BINARY mandatory zero bit 14 breach. + 1 << 15, // UTF8_BINARY mandatory zero bit 15 breach. + 1 << 16, // UTF8_BINARY mandatory zero bit 16 breach. + 1 << 17, // UTF8_BINARY mandatory zero bit 17 breach. + 1 << 18, // UTF8_BINARY mandatory zero bit 18 breach. + 1 << 19, // UTF8_BINARY mandatory zero bit 19 breach. + 1 << 20, // UTF8_BINARY mandatory zero bit 20 breach. + 1 << 23, // UTF8_BINARY mandatory zero bit 23 breach. 
+ 1 << 24, // UTF8_BINARY mandatory zero bit 24 breach. + 1 << 25, // UTF8_BINARY mandatory zero bit 25 breach. + 1 << 26, // UTF8_BINARY mandatory zero bit 26 breach. + 1 << 27, // UTF8_BINARY mandatory zero bit 27 breach. + 1 << 28, // UTF8_BINARY mandatory zero bit 28 breach. + (1 << 29) | (1 << 12), // ICU mandatory zero bit 12 breach. + (1 << 29) | (1 << 13), // ICU mandatory zero bit 13 breach. + (1 << 29) | (1 << 14), // ICU mandatory zero bit 14 breach. + (1 << 29) | (1 << 15), // ICU mandatory zero bit 15 breach. + (1 << 29) | (1 << 18), // ICU mandatory zero bit 18 breach. + (1 << 29) | (1 << 19), // ICU mandatory zero bit 19 breach. + (1 << 29) | (1 << 20), // ICU mandatory zero bit 20 breach. + (1 << 29) | (1 << 21), // ICU mandatory zero bit 21 breach. + (1 << 29) | (1 << 22), // ICU mandatory zero bit 22 breach. + (1 << 29) | (1 << 23), // ICU mandatory zero bit 23 breach. + (1 << 29) | (1 << 24), // ICU mandatory zero bit 24 breach. + (1 << 29) | (1 << 25), // ICU mandatory zero bit 25 breach. + (1 << 29) | (1 << 26), // ICU mandatory zero bit 26 breach. + (1 << 29) | (1 << 27), // ICU mandatory zero bit 27 breach. + (1 << 29) | (1 << 28), // ICU mandatory zero bit 28 breach. + (1 << 29) | 0xFFFF // ICU with invalid locale id. + ) + badCollationIds.foreach(collationId => { + // Assumptions about collation id will break and assert statement will fail. + intercept[AssertionError](fetchCollation(collationId)) + }) + } + + test("repeated and/or incompatible specifiers in collation name") { + Seq( + "UTF8_BINARY_LCASE_LCASE", + "UNICODE_CS_CS", + "UNICODE_CI_CI", + "UNICODE_CI_CS", + "UNICODE_CS_CI", + "UNICODE_AS_AS", + "UNICODE_AI_AI", + "UNICODE_AS_AI", + "UNICODE_AI_AS", + "UNICODE_AS_CS_AI", + "UNICODE_CS_AI_CI", + "UNICODE_CS_AS_CI_AI" + ).foreach(collationName => { + val error = intercept[SparkException] { + fetchCollation(collationName) + } + + assert(error.getErrorClass === "COLLATION_INVALID_NAME") + assert(error.getMessageParameters.asScala === Map("collationName" -> collationName)) + }) + } + + test("basic ICU collator checks") { + Seq( + CollationTestCase("UNICODE_CI", "a", "A", true), + CollationTestCase("UNICODE_CI", "a", "å", false), + CollationTestCase("UNICODE_CI", "a", "Å", false), + CollationTestCase("UNICODE_AI", "a", "A", false), + CollationTestCase("UNICODE_AI", "a", "å", true), + CollationTestCase("UNICODE_AI", "a", "Å", false), + CollationTestCase("UNICODE_CI_AI", "a", "A", true), + CollationTestCase("UNICODE_CI_AI", "a", "å", true), + CollationTestCase("UNICODE_CI_AI", "a", "Å", true) + ).foreach(testCase => { + val collation = fetchCollation(testCase.collationName) + assert(collation.equalsFunction(toUTF8(testCase.s1), toUTF8(testCase.s2)) == + testCase.expectedResult) + }) + Seq( + CollationTestCase("en", "a", "A", -1), + CollationTestCase("en_CI", "a", "A", 0), + CollationTestCase("en_AI", "a", "å", 0), + CollationTestCase("sv", "Kypper", "Köpfe", -1), + CollationTestCase("de", "Kypper", "Köpfe", 1) + ).foreach(testCase => { + val collation = fetchCollation(testCase.collationName) + val result = collation.comparator.compare(toUTF8(testCase.s1), toUTF8(testCase.s2)) + assert(Integer.signum(result) == testCase.expectedResult) + }) + } } diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 883c51bffadec..b19b05859f786 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -469,7 
+469,7 @@ }, "COLLATION_INVALID_NAME" : { "message" : [ - "The value does not represent a correct collation name. Suggested valid collation name: []." + "The value does not represent a correct collation name." ], "sqlState" : "42704" }, @@ -1921,7 +1921,7 @@ "subClass" : { "DEFAULT_COLLATION" : { "message" : [ - "Cannot resolve the given default collation. Did you mean ''?" + "Cannot resolve the given default collation." ] }, "TIME_ZONE" : { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 49b1a5312fdab..e0ad8f7078caf 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.{functions => fn} import org.apache.spark.sql.avro.{functions => avroFn} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder +import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.connect.client.SparkConnectClient import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit @@ -699,7 +700,8 @@ class PlanGenerationTestSuite } test("select collated string") { - val schema = StructType(StructField("s", StringType(1)) :: Nil) + val schema = StructType( + StructField("s", StringType(CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID)) :: Nil) createLocalRelation(schema.catalogString).select("s") } diff --git a/connector/connect/common/src/main/protobuf/spark/connect/types.proto b/connector/connect/common/src/main/protobuf/spark/connect/types.proto index 48f7385330c86..4f768f201575b 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/types.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/types.proto @@ -101,7 +101,7 @@ message DataType { message String { uint32 type_variation_reference = 1; - uint32 collation_id = 2; + string collation = 2; } message Binary { diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala index 1f580a0ffc0a3..f63692717947a 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.common import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto +import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.types._ import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SparkClassUtils @@ -80,7 +81,7 @@ object DataTypeProtoConverter { } private def toCatalystStringType(t: proto.DataType.String): StringType = - StringType(t.getCollationId) + StringType(if (t.getCollation.nonEmpty) t.getCollation else "UTF8_BINARY") private def toCatalystYearMonthIntervalType(t: proto.DataType.YearMonthInterval) = { (t.hasStartField, t.hasEndField) match { @@ -177,7 +178,11 @@ object DataTypeProtoConverter { case s: StringType => proto.DataType .newBuilder() - .setString(proto.DataType.String.newBuilder().setCollationId(s.collationId).build()) + 
.setString( + proto.DataType.String + .newBuilder() + .setCollation(CollationFactory.fetchCollation(s.collationId).collationName) + .build()) .build() case CharType(length) => diff --git a/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.json b/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.json index 33f6007ec68a1..e4b31258f984a 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.json @@ -18,7 +18,7 @@ "name": "c1", "dataType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "nullable": true diff --git a/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.proto.bin index da4ad9bf9a4ed1611479dc6f4c75f33df9c561d6..c39243a10a8e4b06c04e2e3505669e780d4e9cd9 100644 GIT binary patch delta 46 zcmbQoxRQ~bi%Ed-6XU{(?3vOUs~AQM C1Pb;5 delta 35 rcmZ3Xo#CdypyM&V^HK|8zvQTBQAqY NjJgK@ld_I-@vI$T=27}dG>qj-h5Lqpsw;+;JG9D^b!Ut|QaOPq@--cSfb*?-16mP;AA7BOyOoX5x|v58R>!zdvpBL*Q{mWYYrkXp!$ zQ_7ZE+@*z)tC4XRV;vXBaGcgkF&QudF_RK*>v2deWEQVK%*1to=@8={Cawm?RgAS< z{NNy9U&P1)vQ>hSO93cggxhwY3J;(PVLz-Mqm&FJp&J7c34!% uu*itw^gxAB*&)VaF7m8^+91Iw!6?L>#h3_*3lvvCmB2$1q)Z7&Bmn?x!hNs+ delta 515 zcmaFLwVsomi%Eb{Y6a&;_I-?2id^!$7-hIvqL>62gqVyNq?of96G23h5|aTV5Hm?R zaM>MVwB}M>#i+!^!^ISDCJd\n\x08unparsed\x18\x18 \x01(\x0b\x32 .spark.connect.DataType.UnparsedH\x00R\x08unparsed\x1a\x43\n\x07\x42oolean\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04\x42yte\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x41\n\x05Short\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x43\n\x07Integer\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04Long\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x41\n\x05\x46loat\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x42\n\x06\x44ouble\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x65\n\x06String\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12!\n\x0c\x63ollation_id\x18\x02 \x01(\rR\x0b\x63ollationId\x1a\x42\n\x06\x42inary\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04NULL\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x45\n\tTimestamp\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04\x44\x61te\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1aH\n\x0cTimestampNTZ\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1aL\n\x10\x43\x61lendarInterval\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\xb3\x01\n\x11YearMonthInterval\x12$\n\x0bstart_field\x18\x01 \x01(\x05H\x00R\nstartField\x88\x01\x01\x12 \n\tend_field\x18\x02 \x01(\x05H\x01R\x08\x65ndField\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReferenceB\x0e\n\x0c_start_fieldB\x0c\n\n_end_field\x1a\xb1\x01\n\x0f\x44\x61yTimeInterval\x12$\n\x0bstart_field\x18\x01 \x01(\x05H\x00R\nstartField\x88\x01\x01\x12 \n\tend_field\x18\x02 
\x01(\x05H\x01R\x08\x65ndField\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReferenceB\x0e\n\x0c_start_fieldB\x0c\n\n_end_field\x1aX\n\x04\x43har\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a[\n\x07VarChar\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a\x99\x01\n\x07\x44\x65\x63imal\x12\x19\n\x05scale\x18\x01 \x01(\x05H\x00R\x05scale\x88\x01\x01\x12!\n\tprecision\x18\x02 \x01(\x05H\x01R\tprecision\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReferenceB\x08\n\x06_scaleB\x0c\n\n_precision\x1a\xa1\x01\n\x0bStructField\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x34\n\tdata_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x08\x64\x61taType\x12\x1a\n\x08nullable\x18\x03 \x01(\x08R\x08nullable\x12\x1f\n\x08metadata\x18\x04 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x7f\n\x06Struct\x12;\n\x06\x66ields\x18\x01 \x03(\x0b\x32#.spark.connect.DataType.StructFieldR\x06\x66ields\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a\xa2\x01\n\x05\x41rray\x12:\n\x0c\x65lement_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x0b\x65lementType\x12#\n\rcontains_null\x18\x02 \x01(\x08R\x0c\x63ontainsNull\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReference\x1a\xdb\x01\n\x03Map\x12\x32\n\x08key_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x07keyType\x12\x36\n\nvalue_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\tvalueType\x12.\n\x13value_contains_null\x18\x03 \x01(\x08R\x11valueContainsNull\x12\x38\n\x18type_variation_reference\x18\x04 \x01(\rR\x16typeVariationReference\x1a\x43\n\x07Variant\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x8f\x02\n\x03UDT\x12\x12\n\x04type\x18\x01 \x01(\tR\x04type\x12 \n\tjvm_class\x18\x02 \x01(\tH\x00R\x08jvmClass\x88\x01\x01\x12&\n\x0cpython_class\x18\x03 \x01(\tH\x01R\x0bpythonClass\x88\x01\x01\x12;\n\x17serialized_python_class\x18\x04 \x01(\tH\x02R\x15serializedPythonClass\x88\x01\x01\x12\x32\n\x08sql_type\x18\x05 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x07sqlTypeB\x0c\n\n_jvm_classB\x0f\n\r_python_classB\x1a\n\x18_serialized_python_class\x1a\x34\n\x08Unparsed\x12(\n\x10\x64\x61ta_type_string\x18\x01 \x01(\tR\x0e\x64\x61taTypeStringB\x06\n\x04kindB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3" + b"\n\x19spark/connect/types.proto\x12\rspark.connect\"\xe7!\n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01(\x0b\x32\x1d.spark.connect.DataType.ShortH\x00R\x05short\x12;\n\x07integer\x18\x06 \x01(\x0b\x32\x1f.spark.connect.DataType.IntegerH\x00R\x07integer\x12\x32\n\x04long\x18\x07 \x01(\x0b\x32\x1c.spark.connect.DataType.LongH\x00R\x04long\x12\x35\n\x05\x66loat\x18\x08 \x01(\x0b\x32\x1d.spark.connect.DataType.FloatH\x00R\x05\x66loat\x12\x38\n\x06\x64ouble\x18\t \x01(\x0b\x32\x1e.spark.connect.DataType.DoubleH\x00R\x06\x64ouble\x12;\n\x07\x64\x65\x63imal\x18\n 
\x01(\x0b\x32\x1f.spark.connect.DataType.DecimalH\x00R\x07\x64\x65\x63imal\x12\x38\n\x06string\x18\x0b \x01(\x0b\x32\x1e.spark.connect.DataType.StringH\x00R\x06string\x12\x32\n\x04\x63har\x18\x0c \x01(\x0b\x32\x1c.spark.connect.DataType.CharH\x00R\x04\x63har\x12<\n\x08var_char\x18\r \x01(\x0b\x32\x1f.spark.connect.DataType.VarCharH\x00R\x07varChar\x12\x32\n\x04\x64\x61te\x18\x0e \x01(\x0b\x32\x1c.spark.connect.DataType.DateH\x00R\x04\x64\x61te\x12\x41\n\ttimestamp\x18\x0f \x01(\x0b\x32!.spark.connect.DataType.TimestampH\x00R\ttimestamp\x12K\n\rtimestamp_ntz\x18\x10 \x01(\x0b\x32$.spark.connect.DataType.TimestampNTZH\x00R\x0ctimestampNtz\x12W\n\x11\x63\x61lendar_interval\x18\x11 \x01(\x0b\x32(.spark.connect.DataType.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12[\n\x13year_month_interval\x18\x12 \x01(\x0b\x32).spark.connect.DataType.YearMonthIntervalH\x00R\x11yearMonthInterval\x12U\n\x11\x64\x61y_time_interval\x18\x13 \x01(\x0b\x32'.spark.connect.DataType.DayTimeIntervalH\x00R\x0f\x64\x61yTimeInterval\x12\x35\n\x05\x61rray\x18\x14 \x01(\x0b\x32\x1d.spark.connect.DataType.ArrayH\x00R\x05\x61rray\x12\x38\n\x06struct\x18\x15 \x01(\x0b\x32\x1e.spark.connect.DataType.StructH\x00R\x06struct\x12/\n\x03map\x18\x16 \x01(\x0b\x32\x1b.spark.connect.DataType.MapH\x00R\x03map\x12;\n\x07variant\x18\x19 \x01(\x0b\x32\x1f.spark.connect.DataType.VariantH\x00R\x07variant\x12/\n\x03udt\x18\x17 \x01(\x0b\x32\x1b.spark.connect.DataType.UDTH\x00R\x03udt\x12>\n\x08unparsed\x18\x18 \x01(\x0b\x32 .spark.connect.DataType.UnparsedH\x00R\x08unparsed\x1a\x43\n\x07\x42oolean\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04\x42yte\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x41\n\x05Short\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x43\n\x07Integer\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04Long\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x41\n\x05\x46loat\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x42\n\x06\x44ouble\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a`\n\x06String\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x1c\n\tcollation\x18\x02 \x01(\tR\tcollation\x1a\x42\n\x06\x42inary\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04NULL\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x45\n\tTimestamp\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04\x44\x61te\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1aH\n\x0cTimestampNTZ\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1aL\n\x10\x43\x61lendarInterval\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\xb3\x01\n\x11YearMonthInterval\x12$\n\x0bstart_field\x18\x01 \x01(\x05H\x00R\nstartField\x88\x01\x01\x12 \n\tend_field\x18\x02 \x01(\x05H\x01R\x08\x65ndField\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReferenceB\x0e\n\x0c_start_fieldB\x0c\n\n_end_field\x1a\xb1\x01\n\x0f\x44\x61yTimeInterval\x12$\n\x0bstart_field\x18\x01 \x01(\x05H\x00R\nstartField\x88\x01\x01\x12 \n\tend_field\x18\x02 \x01(\x05H\x01R\x08\x65ndField\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 
\x01(\rR\x16typeVariationReferenceB\x0e\n\x0c_start_fieldB\x0c\n\n_end_field\x1aX\n\x04\x43har\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a[\n\x07VarChar\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a\x99\x01\n\x07\x44\x65\x63imal\x12\x19\n\x05scale\x18\x01 \x01(\x05H\x00R\x05scale\x88\x01\x01\x12!\n\tprecision\x18\x02 \x01(\x05H\x01R\tprecision\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReferenceB\x08\n\x06_scaleB\x0c\n\n_precision\x1a\xa1\x01\n\x0bStructField\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x34\n\tdata_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x08\x64\x61taType\x12\x1a\n\x08nullable\x18\x03 \x01(\x08R\x08nullable\x12\x1f\n\x08metadata\x18\x04 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x7f\n\x06Struct\x12;\n\x06\x66ields\x18\x01 \x03(\x0b\x32#.spark.connect.DataType.StructFieldR\x06\x66ields\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a\xa2\x01\n\x05\x41rray\x12:\n\x0c\x65lement_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x0b\x65lementType\x12#\n\rcontains_null\x18\x02 \x01(\x08R\x0c\x63ontainsNull\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReference\x1a\xdb\x01\n\x03Map\x12\x32\n\x08key_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x07keyType\x12\x36\n\nvalue_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\tvalueType\x12.\n\x13value_contains_null\x18\x03 \x01(\x08R\x11valueContainsNull\x12\x38\n\x18type_variation_reference\x18\x04 \x01(\rR\x16typeVariationReference\x1a\x43\n\x07Variant\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x8f\x02\n\x03UDT\x12\x12\n\x04type\x18\x01 \x01(\tR\x04type\x12 \n\tjvm_class\x18\x02 \x01(\tH\x00R\x08jvmClass\x88\x01\x01\x12&\n\x0cpython_class\x18\x03 \x01(\tH\x01R\x0bpythonClass\x88\x01\x01\x12;\n\x17serialized_python_class\x18\x04 \x01(\tH\x02R\x15serializedPythonClass\x88\x01\x01\x12\x32\n\x08sql_type\x18\x05 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x07sqlTypeB\x0c\n\n_jvm_classB\x0f\n\r_python_classB\x1a\n\x18_serialized_python_class\x1a\x34\n\x08Unparsed\x12(\n\x10\x64\x61ta_type_string\x18\x01 \x01(\tR\x0e\x64\x61taTypeStringB\x06\n\x04kindB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3" ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -42,7 +42,7 @@ b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated" ) _DATATYPE._serialized_start = 45 - _DATATYPE._serialized_end = 4377 + _DATATYPE._serialized_end = 4372 _DATATYPE_BOOLEAN._serialized_start = 1595 _DATATYPE_BOOLEAN._serialized_end = 1662 _DATATYPE_BYTE._serialized_start = 1664 @@ -58,41 +58,41 @@ _DATATYPE_DOUBLE._serialized_start = 1999 _DATATYPE_DOUBLE._serialized_end = 2065 _DATATYPE_STRING._serialized_start = 2067 - _DATATYPE_STRING._serialized_end = 2168 - _DATATYPE_BINARY._serialized_start = 2170 - _DATATYPE_BINARY._serialized_end = 2236 - _DATATYPE_NULL._serialized_start = 2238 - _DATATYPE_NULL._serialized_end = 2302 - _DATATYPE_TIMESTAMP._serialized_start = 2304 - _DATATYPE_TIMESTAMP._serialized_end = 2373 - _DATATYPE_DATE._serialized_start = 2375 - _DATATYPE_DATE._serialized_end = 2439 - _DATATYPE_TIMESTAMPNTZ._serialized_start = 2441 - _DATATYPE_TIMESTAMPNTZ._serialized_end = 2513 - _DATATYPE_CALENDARINTERVAL._serialized_start = 2515 - 
_DATATYPE_CALENDARINTERVAL._serialized_end = 2591 - _DATATYPE_YEARMONTHINTERVAL._serialized_start = 2594 - _DATATYPE_YEARMONTHINTERVAL._serialized_end = 2773 - _DATATYPE_DAYTIMEINTERVAL._serialized_start = 2776 - _DATATYPE_DAYTIMEINTERVAL._serialized_end = 2953 - _DATATYPE_CHAR._serialized_start = 2955 - _DATATYPE_CHAR._serialized_end = 3043 - _DATATYPE_VARCHAR._serialized_start = 3045 - _DATATYPE_VARCHAR._serialized_end = 3136 - _DATATYPE_DECIMAL._serialized_start = 3139 - _DATATYPE_DECIMAL._serialized_end = 3292 - _DATATYPE_STRUCTFIELD._serialized_start = 3295 - _DATATYPE_STRUCTFIELD._serialized_end = 3456 - _DATATYPE_STRUCT._serialized_start = 3458 - _DATATYPE_STRUCT._serialized_end = 3585 - _DATATYPE_ARRAY._serialized_start = 3588 - _DATATYPE_ARRAY._serialized_end = 3750 - _DATATYPE_MAP._serialized_start = 3753 - _DATATYPE_MAP._serialized_end = 3972 - _DATATYPE_VARIANT._serialized_start = 3974 - _DATATYPE_VARIANT._serialized_end = 4041 - _DATATYPE_UDT._serialized_start = 4044 - _DATATYPE_UDT._serialized_end = 4315 - _DATATYPE_UNPARSED._serialized_start = 4317 - _DATATYPE_UNPARSED._serialized_end = 4369 + _DATATYPE_STRING._serialized_end = 2163 + _DATATYPE_BINARY._serialized_start = 2165 + _DATATYPE_BINARY._serialized_end = 2231 + _DATATYPE_NULL._serialized_start = 2233 + _DATATYPE_NULL._serialized_end = 2297 + _DATATYPE_TIMESTAMP._serialized_start = 2299 + _DATATYPE_TIMESTAMP._serialized_end = 2368 + _DATATYPE_DATE._serialized_start = 2370 + _DATATYPE_DATE._serialized_end = 2434 + _DATATYPE_TIMESTAMPNTZ._serialized_start = 2436 + _DATATYPE_TIMESTAMPNTZ._serialized_end = 2508 + _DATATYPE_CALENDARINTERVAL._serialized_start = 2510 + _DATATYPE_CALENDARINTERVAL._serialized_end = 2586 + _DATATYPE_YEARMONTHINTERVAL._serialized_start = 2589 + _DATATYPE_YEARMONTHINTERVAL._serialized_end = 2768 + _DATATYPE_DAYTIMEINTERVAL._serialized_start = 2771 + _DATATYPE_DAYTIMEINTERVAL._serialized_end = 2948 + _DATATYPE_CHAR._serialized_start = 2950 + _DATATYPE_CHAR._serialized_end = 3038 + _DATATYPE_VARCHAR._serialized_start = 3040 + _DATATYPE_VARCHAR._serialized_end = 3131 + _DATATYPE_DECIMAL._serialized_start = 3134 + _DATATYPE_DECIMAL._serialized_end = 3287 + _DATATYPE_STRUCTFIELD._serialized_start = 3290 + _DATATYPE_STRUCTFIELD._serialized_end = 3451 + _DATATYPE_STRUCT._serialized_start = 3453 + _DATATYPE_STRUCT._serialized_end = 3580 + _DATATYPE_ARRAY._serialized_start = 3583 + _DATATYPE_ARRAY._serialized_end = 3745 + _DATATYPE_MAP._serialized_start = 3748 + _DATATYPE_MAP._serialized_end = 3967 + _DATATYPE_VARIANT._serialized_start = 3969 + _DATATYPE_VARIANT._serialized_end = 4036 + _DATATYPE_UDT._serialized_start = 4039 + _DATATYPE_UDT._serialized_end = 4310 + _DATATYPE_UNPARSED._serialized_start = 4312 + _DATATYPE_UNPARSED._serialized_end = 4364 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/types_pb2.pyi b/python/pyspark/sql/connect/proto/types_pb2.pyi index e6b34d3485c2f..b376211045377 100644 --- a/python/pyspark/sql/connect/proto/types_pb2.pyi +++ b/python/pyspark/sql/connect/proto/types_pb2.pyi @@ -178,22 +178,19 @@ class DataType(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int - COLLATION_ID_FIELD_NUMBER: builtins.int + COLLATION_FIELD_NUMBER: builtins.int type_variation_reference: builtins.int - collation_id: builtins.int + collation: builtins.str def __init__( self, *, type_variation_reference: builtins.int = ..., - collation_id: builtins.int = ..., + collation: 
builtins.str = ..., ) -> None: ... def ClearField( self, field_name: typing_extensions.Literal[ - "collation_id", - b"collation_id", - "type_variation_reference", - b"type_variation_reference", + "collation", b"collation", "type_variation_reference", b"type_variation_reference" ], ) -> None: ... diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py index 351fa01659657..885ce62e7db6f 100644 --- a/python/pyspark/sql/connect/types.py +++ b/python/pyspark/sql/connect/types.py @@ -129,7 +129,7 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType: if isinstance(data_type, NullType): ret.null.CopyFrom(pb2.DataType.NULL()) elif isinstance(data_type, StringType): - ret.string.collation_id = data_type.collationId + ret.string.collation = data_type.collation elif isinstance(data_type, BooleanType): ret.boolean.CopyFrom(pb2.DataType.Boolean()) elif isinstance(data_type, BinaryType): @@ -229,7 +229,8 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType: s = schema.decimal.scale if schema.decimal.HasField("scale") else 0 return DecimalType(precision=p, scale=s) elif schema.HasField("string"): - return StringType.fromCollationId(schema.string.collation_id) + collation = schema.string.collation if schema.string.collation != "" else "UTF8_BINARY" + return StringType(collation) elif schema.HasField("char"): return CharType(schema.char.length) elif schema.HasField("var_char"): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 563c63f5dfb1a..c72ff72ce426b 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -280,26 +280,13 @@ class StringType(AtomicType): name of the collation, default is UTF8_BINARY. """ - collationNames = ["UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI"] providerSpark = "spark" providerICU = "icu" providers = [providerSpark, providerICU] - def __init__(self, collation: Optional[str] = None): + def __init__(self, collation: str = "UTF8_BINARY"): self.typeName = self._type_name # type: ignore[method-assign] - self.collationId = 0 if collation is None else self.collationNameToId(collation) - - @classmethod - def fromCollationId(self, collationId: int) -> "StringType": - return StringType(StringType.collationNames[collationId]) - - @classmethod - def collationIdToName(cls, collationId: int) -> str: - return StringType.collationNames[collationId] - - @classmethod - def collationNameToId(cls, collationName: str) -> int: - return StringType.collationNames.index(collationName) + self.collation = collation @classmethod def collationProvider(cls, collationName: str) -> str: @@ -312,7 +299,7 @@ def _type_name(self) -> str: if self.isUTF8BinaryCollation(): return "string" - return f"string collate ${self.collationIdToName(self.collationId)}" + return f"string collate ${self.collation}" # For backwards compatibility and compatibility with other readers all string types # are serialized in json as regular strings and the collation info is written to @@ -322,13 +309,11 @@ def jsonValue(self) -> str: def __repr__(self) -> str: return ( - "StringType('%s')" % StringType.collationNames[self.collationId] - if self.collationId != 0 - else "StringType()" + "StringType()" if self.isUTF8BinaryCollation() else "StringType('%s')" % self.collation ) def isUTF8BinaryCollation(self) -> bool: - return self.collationId == 0 + return self.collation == "UTF8_BINARY" class CharType(AtomicType): @@ -1046,7 +1031,7 @@ def _isCollatedString(self, dt: DataType) -> bool: def 
schemaCollationValue(self, dt: DataType) -> str: assert isinstance(dt, StringType) - collationName = StringType.collationIdToName(dt.collationId) + collationName = dt.collation provider = StringType.collationProvider(collationName) return f"{provider}.{collationName}" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 06e0c6eda5896..f6f5b23b7f10a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -772,12 +772,17 @@ object SQLConf { " produced by a builtin function such as to_char or CAST") .version("4.0.0") .stringConf - .checkValue(CollationFactory.isValidCollation, + .checkValue( + collationName => { + try { + CollationFactory.fetchCollation(collationName) + true + } catch { + case e: SparkException if e.getErrorClass == "COLLATION_INVALID_NAME" => false + } + }, "DEFAULT_COLLATION", - name => - Map( - "proposal" -> CollationFactory.getClosestCollation(name) - )) + _ => Map()) .createWithDefault("UTF8_BINARY") val FETCH_SHUFFLE_BLOCKS_IN_BATCH = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala index 537bac9aae9b4..c3495a0c112c3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala @@ -62,7 +62,7 @@ class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { exception = intercept[SparkException] { Collate(Literal("abc"), "UTF8_BS") }, errorClass = "COLLATION_INVALID_NAME", sqlState = "42704", - parameters = Map("proposal" -> "UTF8_BINARY", "collationName" -> "UTF8_BS")) + parameters = Map("collationName" -> "UTF8_BS")) } test("collation on non-explicit default collation") { @@ -71,7 +71,8 @@ class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { test("collation on explicitly collated string") { checkEvaluation( - Collation(Literal.create("abc", StringType(1))).replacement, + Collation(Literal.create("abc", + StringType(CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID))).replacement, "UTF8_BINARY_LCASE") checkEvaluation( Collation(Collate(Literal("abc"), "UTF8_BINARY_LCASE")).replacement, @@ -161,4 +162,32 @@ class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ArrayExcept(left, right), out) } } + + test("collation name normalization in collation expression") { + Seq( + ("en_USA", "en_USA"), + ("en_CS", "en"), + ("en_AS", "en"), + ("en_CS_AS", "en"), + ("en_AS_CS", "en"), + ("en_CI", "en_CI"), + ("en_AI", "en_AI"), + ("en_AI_CI", "en_CI_AI"), + ("en_CI_AI", "en_CI_AI"), + ("en_CS_AI", "en_AI"), + ("en_AI_CS", "en_AI"), + ("en_CI_AS", "en_CI"), + ("en_AS_CI", "en_CI"), + ("en_USA_AI_CI", "en_USA_CI_AI"), + // randomized case + ("EN_USA", "en_USA"), + ("SR_CYRL", "sr_Cyrl"), + ("sr_cyrl_srb", "sr_Cyrl_SRB"), + ("sR_cYRl_sRb", "sr_Cyrl_SRB") + ).foreach { + case (collation, normalized) => + checkEvaluation(Collation(Literal.create("abc", StringType(collation))).replacement, + normalized) + } + } } diff --git a/sql/core/src/test/resources/collations/ICU-collations-map.md b/sql/core/src/test/resources/collations/ICU-collations-map.md new file mode 
100644 index 0000000000000..598c3c4b40240 --- /dev/null +++ b/sql/core/src/test/resources/collations/ICU-collations-map.md @@ -0,0 +1,143 @@ + +## ICU locale ids to name map +| Locale id | Locale name | +| --------- | ----------- | +| 0 | UNICODE | +| 1 | af | +| 2 | am | +| 3 | ar | +| 4 | ar_SAU | +| 5 | as | +| 6 | az | +| 7 | be | +| 8 | bg | +| 9 | bn | +| 10 | bo | +| 11 | br | +| 12 | bs | +| 13 | bs_Cyrl | +| 14 | ca | +| 15 | ceb | +| 16 | chr | +| 17 | cs | +| 18 | cy | +| 19 | da | +| 20 | de | +| 21 | de_AUT | +| 22 | dsb | +| 23 | dz | +| 24 | ee | +| 25 | el | +| 26 | en | +| 27 | en_USA | +| 28 | eo | +| 29 | es | +| 30 | et | +| 31 | fa | +| 32 | fa_AFG | +| 33 | ff | +| 34 | ff_Adlm | +| 35 | fi | +| 36 | fil | +| 37 | fo | +| 38 | fr | +| 39 | fr_CAN | +| 40 | fy | +| 41 | ga | +| 42 | gl | +| 43 | gu | +| 44 | ha | +| 45 | haw | +| 46 | he | +| 47 | he_ISR | +| 48 | hi | +| 49 | hr | +| 50 | hsb | +| 51 | hu | +| 52 | hy | +| 53 | id | +| 54 | id_IDN | +| 55 | ig | +| 56 | is | +| 57 | it | +| 58 | ja | +| 59 | ka | +| 60 | kk | +| 61 | kl | +| 62 | km | +| 63 | kn | +| 64 | ko | +| 65 | kok | +| 66 | ku | +| 67 | ky | +| 68 | lb | +| 69 | lkt | +| 70 | ln | +| 71 | lo | +| 72 | lt | +| 73 | lv | +| 74 | mk | +| 75 | ml | +| 76 | mn | +| 77 | mr | +| 78 | ms | +| 79 | mt | +| 80 | my | +| 81 | nb | +| 82 | nb_NOR | +| 83 | ne | +| 84 | nl | +| 85 | nn | +| 86 | no | +| 87 | om | +| 88 | or | +| 89 | pa | +| 90 | pa_Guru | +| 91 | pa_Guru_IND | +| 92 | pl | +| 93 | ps | +| 94 | pt | +| 95 | ro | +| 96 | ru | +| 97 | sa | +| 98 | se | +| 99 | si | +| 100 | sk | +| 101 | sl | +| 102 | smn | +| 103 | sq | +| 104 | sr | +| 105 | sr_Cyrl | +| 106 | sr_Cyrl_BIH | +| 107 | sr_Cyrl_MNE | +| 108 | sr_Cyrl_SRB | +| 109 | sr_Latn | +| 110 | sr_Latn_BIH | +| 111 | sr_Latn_SRB | +| 112 | sv | +| 113 | sw | +| 114 | ta | +| 115 | te | +| 116 | th | +| 117 | tk | +| 118 | to | +| 119 | tr | +| 120 | ug | +| 121 | uk | +| 122 | ur | +| 123 | uz | +| 124 | vi | +| 125 | wae | +| 126 | wo | +| 127 | xh | +| 128 | yi | +| 129 | yo | +| 130 | zh | +| 131 | zh_Hans | +| 132 | zh_Hans_CHN | +| 133 | zh_Hans_SGP | +| 134 | zh_Hant | +| 135 | zh_Hant_HKG | +| 136 | zh_Hant_MAC | +| 137 | zh_Hant_TWN | +| 138 | zu | diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out index d242a60a17c18..9a1f4ed1f8e57 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out @@ -312,3 +312,80 @@ select array_except(array('aaa' collate utf8_binary_lcase), array('AAA' collate -- !query analysis Project [array_except(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS array_except(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation + + +-- !query +select 'a' collate unicode < 'A' +-- !query analysis +Project [(collate(a, unicode) < cast(A as string collate UNICODE)) AS (collate(a) < A)#x] ++- OneRowRelation + + +-- !query +select 'a' collate unicode_ci = 'A' +-- !query analysis +Project [(collate(a, unicode_ci) = cast(A as string collate UNICODE_CI)) AS (collate(a) = A)#x] ++- OneRowRelation + + +-- !query +select 'a' collate unicode_ai = 'å' +-- !query analysis +Project [(collate(a, unicode_ai) = cast(å as string collate UNICODE_AI)) AS (collate(a) = å)#x] ++- OneRowRelation + + +-- !query +select 'a' collate unicode_ci_ai = 'Å' +-- !query analysis +Project 
[(collate(a, unicode_ci_ai) = cast(Å as string collate UNICODE_CI_AI)) AS (collate(a) = Å)#x] ++- OneRowRelation + + +-- !query +select 'a' collate en < 'A' +-- !query analysis +Project [(collate(a, en) < cast(A as string collate en)) AS (collate(a) < A)#x] ++- OneRowRelation + + +-- !query +select 'a' collate en_ci = 'A' +-- !query analysis +Project [(collate(a, en_ci) = cast(A as string collate en_CI)) AS (collate(a) = A)#x] ++- OneRowRelation + + +-- !query +select 'a' collate en_ai = 'å' +-- !query analysis +Project [(collate(a, en_ai) = cast(å as string collate en_AI)) AS (collate(a) = å)#x] ++- OneRowRelation + + +-- !query +select 'a' collate en_ci_ai = 'Å' +-- !query analysis +Project [(collate(a, en_ci_ai) = cast(Å as string collate en_CI_AI)) AS (collate(a) = Å)#x] ++- OneRowRelation + + +-- !query +select 'Kypper' collate sv < 'Köpfe' +-- !query analysis +Project [(collate(Kypper, sv) < cast(Köpfe as string collate sv)) AS (collate(Kypper) < Köpfe)#x] ++- OneRowRelation + + +-- !query +select 'Kypper' collate de > 'Köpfe' +-- !query analysis +Project [(collate(Kypper, de) > cast(Köpfe as string collate de)) AS (collate(Kypper) > Köpfe)#x] ++- OneRowRelation + + +-- !query +select 'I' collate tr_ci = 'ı' +-- !query analysis +Project [(collate(I, tr_ci) = cast(ı as string collate tr_CI)) AS (collate(I) = ı)#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/inputs/collations.sql b/sql/core/src/test/resources/sql-tests/inputs/collations.sql index 619eb4470e9ad..6bb0a0163443a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/collations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/collations.sql @@ -77,3 +77,16 @@ select array_distinct(array('aaa' collate utf8_binary_lcase, 'AAA' collate utf8_ select array_union(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)); select array_intersect(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)); select array_except(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)); + +-- ICU collations (all statements return true) +select 'a' collate unicode < 'A'; +select 'a' collate unicode_ci = 'A'; +select 'a' collate unicode_ai = 'å'; +select 'a' collate unicode_ci_ai = 'Å'; +select 'a' collate en < 'A'; +select 'a' collate en_ci = 'A'; +select 'a' collate en_ai = 'å'; +select 'a' collate en_ci_ai = 'Å'; +select 'Kypper' collate sv < 'Köpfe'; +select 'Kypper' collate de > 'Köpfe'; +select 'I' collate tr_ci = 'ı'; diff --git a/sql/core/src/test/resources/sql-tests/results/collations.sql.out b/sql/core/src/test/resources/sql-tests/results/collations.sql.out index 4485191ba1f3b..96c875306d358 100644 --- a/sql/core/src/test/resources/sql-tests/results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/collations.sql.out @@ -339,3 +339,91 @@ select array_except(array('aaa' collate utf8_binary_lcase), array('AAA' collate struct> -- !query output [] + + +-- !query +select 'a' collate unicode < 'A' +-- !query schema +struct<(collate(a) < A):boolean> +-- !query output +true + + +-- !query +select 'a' collate unicode_ci = 'A' +-- !query schema +struct<(collate(a) = A):boolean> +-- !query output +true + + +-- !query +select 'a' collate unicode_ai = 'å' +-- !query schema +struct<(collate(a) = å):boolean> +-- !query output +true + + +-- !query +select 'a' collate unicode_ci_ai = 'Å' +-- !query schema +struct<(collate(a) = Å):boolean> +-- !query output +true + + +-- !query +select 'a' collate en < 'A' +-- !query schema 
+struct<(collate(a) < A):boolean> +-- !query output +true + + +-- !query +select 'a' collate en_ci = 'A' +-- !query schema +struct<(collate(a) = A):boolean> +-- !query output +true + + +-- !query +select 'a' collate en_ai = 'å' +-- !query schema +struct<(collate(a) = å):boolean> +-- !query output +true + + +-- !query +select 'a' collate en_ci_ai = 'Å' +-- !query schema +struct<(collate(a) = Å):boolean> +-- !query output +true + + +-- !query +select 'Kypper' collate sv < 'Köpfe' +-- !query schema +struct<(collate(Kypper) < Köpfe):boolean> +-- !query output +true + + +-- !query +select 'Kypper' collate de > 'Köpfe' +-- !query schema +struct<(collate(Kypper) > Köpfe):boolean> +-- !query output +true + + +-- !query +select 'I' collate tr_ci = 'ı' +-- !query schema +struct<(collate(I) = ı):boolean> +-- !query output +true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 657fd4504cac1..4f8587395b3e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -152,7 +152,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[SparkException] { sql("select 'aaa' collate UTF8_BS") }, errorClass = "COLLATION_INVALID_NAME", sqlState = "42704", - parameters = Map("proposal" -> "UTF8_BINARY", "collationName" -> "UTF8_BS")) + parameters = Map("collationName" -> "UTF8_BS")) } test("disable bucketing on collated string column") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ICUCollationsMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ICUCollationsMapSuite.scala new file mode 100644 index 0000000000000..42d486bd75454 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ICUCollationsMapSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile, CollationFactory} + +// scalastyle:off line.size.limit +/** + * Guard against breaking changes in ICU locale names and codes supported by Collator class and provider by CollationFactory. 
+ * Map is in form of rows of pairs (locale name, locale id); locale name consists of three parts: + * - 2-letter lowercase language code + * - 4-letter script code (optional) + * - 3-letter uppercase country code + * + * To re-generate collations map golden file, run: + * {{{ + * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly org.apache.spark.sql.ICUCollationsMapSuite" + * }}} + */ +// scalastyle:on line.size.limit +class ICUCollationsMapSuite extends SparkFunSuite { + + private val collationsMapFile = { + getWorkspaceFilePath("sql", "core", "src", "test", "resources", + "collations", "ICU-collations-map.md").toFile + } + + if (regenerateGoldenFiles) { + val map = CollationFactory.getICULocaleNames + val mapOutput = map.zipWithIndex.map { + case (localeName, idx) => s"| $idx | $localeName |" }.mkString("\n") + val goldenOutput = { + s"\n" + + "## ICU locale ids to name map\n" + + "| Locale id | Locale name |\n" + + "| --------- | ----------- |\n" + + mapOutput + "\n" + } + val parent = collationsMapFile.getParentFile + if (!parent.exists()) { + assert(parent.mkdirs(), "Could not create directory: " + parent) + } + stringToFile(collationsMapFile, goldenOutput) + } + + test("ICU locales map breaking change") { + val goldenLines = fileToString(collationsMapFile).split('\n') + val goldenRelevantLines = goldenLines.slice(4, goldenLines.length) // skip header + val input = goldenRelevantLines.map( + s => (s.split('|')(2).strip(), s.split('|')(1).strip().toInt)) + assert(input sameElements CollationFactory.getICULocaleNames.zipWithIndex) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 213dfd32c8698..8d291591c5f41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -518,8 +518,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { errorClass = "INVALID_CONF_VALUE.DEFAULT_COLLATION", parameters = Map( "confValue" -> "UNICODE_C", - "confName" -> "spark.sql.session.collation.default", - "proposal" -> "UNICODE_CI" + "confName" -> "spark.sql.session.collation.default" )) withSQLConf(SQLConf.COLLATION_ENABLED.key -> "false") { From 731a2cfcffaeeeb1f1c107080ca77000330d79b5 Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Tue, 28 May 2024 09:59:53 -0700 Subject: [PATCH 43/45] [SPARK-48273][SQL] Fix late rewrite of PlanWithUnresolvedIdentifier ### What changes were proposed in this pull request? `PlanWithUnresolvedIdentifier` is rewritten later in analysis which causes rules like `SubstituteUnresolvedOrdinals` to miss the new plan. This causes following queries to fail: ``` create temporary view identifier('v1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1); -- cache table identifier('t1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1); -- create table identifier('t2') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1); insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1; ``` Fix this by explicitly applying rules after plan rewrite. ### Why are the changes needed? To fix the described bug. ### Does this PR introduce _any_ user-facing change? Yes, it fixes the mentioned problematic queries. ### How was this patch tested? Updated existing `identifier-clause.sql` golden file. ### Was this patch authored or co-authored using generative AI tooling? No. 
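The shape of the fix, sketched below in simplified form: `ResolveIdentifierClause` becomes a class that receives the analyzer's early rule batches and replays them, through a nested `RuleExecutor`, on the plan produced by the deferred plan builder, so rules such as `SubstituteUnresolvedOrdinals` still get to see that plan. This sketch is for illustration only: it omits the `ExpressionWithUnresolvedIdentifier` branch, and `evalIdentifierExpr` is a rough stand-in for the existing private helper (no validation or error handling); see the actual `ResolveIdentifierClause.scala` change below for the real code.

```scala
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_IDENTIFIER

// Simplified sketch of the patched rule, not the code in this patch.
class ResolveIdentifierClauseSketch(earlyBatches: Seq[RuleExecutor[LogicalPlan]#Batch])
  extends Rule[LogicalPlan] {

  // Nested executor that replays only the early analyzer batches (Substitution,
  // Simple Sanity Check, Keep Legacy Outputs) on a plan that is built after those
  // batches have already run over the outer query.
  private val executor = new RuleExecutor[LogicalPlan] {
    override def batches: Seq[Batch] = earlyBatches.asInstanceOf[Seq[Batch]]
  }

  // Rough stand-in for the existing private helper that folds the identifier
  // expression into a multi-part name (the real helper validates the expression).
  private def evalIdentifierExpr(expr: Expression): Seq[String] =
    CatalystSqlParser.parseMultipartIdentifier(expr.eval().toString)

  override def apply(plan: LogicalPlan): LogicalPlan =
    plan.resolveOperatorsUpWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
      case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved =>
        // Before this patch the freshly built plan was returned as-is, so early
        // rules such as SubstituteUnresolvedOrdinals never ran on it.
        executor.execute(p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr)))
    }
}
```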
Closes #46580 from nikolamand-db/SPARK-48273. Authored-by: Nikola Mandic Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 9 ++- .../analysis/ResolveIdentifierClause.scala | 11 +++- .../sql/catalyst/rules/RuleExecutor.scala | 2 +- .../identifier-clause.sql.out | 59 +++++++++++++++++++ .../sql-tests/inputs/identifier-clause.sql | 9 +++ .../results/identifier-clause.sql.out | 56 ++++++++++++++++++ 6 files changed, 139 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 55b6f1af7fd8b..a233161713c3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -254,7 +254,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor TypeCoercion.typeCoercionRules } - override def batches: Seq[Batch] = Seq( + private def earlyBatches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, new SubstituteExecuteImmediate(catalogManager), // This rule optimizes `UpdateFields` expression chains so looks more like optimization rule. @@ -274,7 +274,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor Batch("Simple Sanity Check", Once, LookupFunctions), Batch("Keep Legacy Outputs", Once, - KeepLegacyOutputs), + KeepLegacyOutputs) + ) + + override def batches: Seq[Batch] = earlyBatches ++ Seq( Batch("Resolution", fixedPoint, new ResolveCatalogs(catalogManager) :: ResolveInsertInto :: @@ -319,7 +322,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveTimeZone :: ResolveRandomSeed :: ResolveBinaryArithmetic :: - ResolveIdentifierClause :: + new ResolveIdentifierClause(earlyBatches) :: ResolveUnion :: ResolveRowLevelCommandAssignments :: RewriteDeleteFromTable :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala index ced7123dfcc14..f04b7799e35ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala @@ -20,19 +20,24 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_IDENTIFIER import org.apache.spark.sql.types.StringType /** * Resolves the identifier expressions and builds the original plans/expressions. 
*/ -object ResolveIdentifierClause extends Rule[LogicalPlan] with AliasHelper with EvalHelper { +class ResolveIdentifierClause(earlyBatches: Seq[RuleExecutor[LogicalPlan]#Batch]) + extends Rule[LogicalPlan] with AliasHelper with EvalHelper { + + private val executor = new RuleExecutor[LogicalPlan] { + override def batches: Seq[Batch] = earlyBatches.asInstanceOf[Seq[Batch]] + } override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( _.containsAnyPattern(UNRESOLVED_IDENTIFIER)) { case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved => - p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr)) + executor.execute(p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr))) case other => other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) { case e: ExpressionWithUnresolvedIdentifier if e.identifierExpr.resolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 0aa01e4f5c517..c8b3f224a3129 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -147,7 +147,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { override val maxIterationsSetting: String = null) extends Strategy /** A batch of rules. */ - protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*) + protected[catalyst] case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*) /** Defines a sequence of rule batches, to be overridden by the implementation. */ protected def batches: Seq[Batch] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out index 7389c7be87af7..f799c19a3bb8e 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out @@ -926,6 +926,65 @@ org.apache.spark.sql.catalyst.parser.ParseException } +-- !query +create temporary view identifier('v1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1) +-- !query analysis +CreateViewCommand `v1`, (select my_col from (values (1), (2), (1) as (my_col)) group by 1), false, false, LocalTempView, UNSUPPORTED, true + +- Aggregate [my_col#x], [my_col#x] + +- SubqueryAlias __auto_generated_subquery_name + +- SubqueryAlias as + +- LocalRelation [my_col#x] + + +-- !query +cache table identifier('t1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1) +-- !query analysis +CacheTableAsSelect t1, (select my_col from (values (1), (2), (1) as (my_col)) group by 1), false, true + +- Aggregate [my_col#x], [my_col#x] + +- SubqueryAlias __auto_generated_subquery_name + +- SubqueryAlias as + +- LocalRelation [my_col#x] + + +-- !query +create table identifier('t2') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1) +-- !query analysis +CreateDataSourceTableAsSelectCommand `spark_catalog`.`default`.`t2`, ErrorIfExists, [my_col] + +- Aggregate [my_col#x], [my_col#x] + +- SubqueryAlias __auto_generated_subquery_name + +- SubqueryAlias as + +- LocalRelation [my_col#x] + + +-- !query +insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1 +-- !query analysis 
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t2, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t2], Append, `spark_catalog`.`default`.`t2`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t2), [my_col] ++- Aggregate [my_col#x], [my_col#x] + +- SubqueryAlias __auto_generated_subquery_name + +- SubqueryAlias as + +- LocalRelation [my_col#x] + + +-- !query +drop view v1 +-- !query analysis +DropTempViewCommand v1 + + +-- !query +drop table t1 +-- !query analysis +DropTempViewCommand t1 + + +-- !query +drop table t2 +-- !query analysis +DropTable false, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t2 + + -- !query SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1) -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql b/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql index fd53f44d3c33c..978b82c331feb 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql @@ -132,6 +132,15 @@ CREATE TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') AS 'test.org.a DROP TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg'); CREATE TEMPORARY VIEW IDENTIFIER('default.v')(c1) AS VALUES(1); +-- SPARK-48273: Aggregation operation in statements using identifier clause for table name +create temporary view identifier('v1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1); +cache table identifier('t1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1); +create table identifier('t2') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1); +insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1; +drop view v1; +drop table t1; +drop table t2; + -- Not supported SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1); SELECT T1.c1 FROM VALUES(1) AS T1(c1) JOIN VALUES(1) AS T2(c1) USING (IDENTIFIER('c1')); diff --git a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out index 9dfc6a66b0782..68aa5956a91c1 100644 --- a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out @@ -1059,6 +1059,62 @@ org.apache.spark.sql.catalyst.parser.ParseException } +-- !query +create temporary view identifier('v1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1) +-- !query schema +struct<> +-- !query output + + + +-- !query +cache table identifier('t1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1) +-- !query schema +struct<> +-- !query output + + + +-- !query +create table identifier('t2') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1) +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +drop view v1 +-- !query schema +struct<> +-- !query output + + + +-- !query +drop table t1 +-- !query schema +struct<> +-- !query output + + + +-- !query +drop table t2 +-- !query schema +struct<> +-- !query output + + + -- !query SELECT row_number() OVER 
IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1) -- !query schema From e9a3ed857954c21ca639f07f2621ac8cebc30ad3 Mon Sep 17 00:00:00 2001 From: Nebojsa Savic Date: Tue, 28 May 2024 10:01:42 -0700 Subject: [PATCH 44/45] [SPARK-48159][SQL] Extending support for collated strings on datetime expressions ### What changes were proposed in this pull request? This PR introduces changes that will allow for collated strings to be passed to various datetime expressions or return value as collated string from those expressions. Impacted datetime expressions: - current_timezone - to_unix_timestamp - from_unixtime - next_day - from_utc_timestamp - to_utc_timestamp - to_date - to_timestamp - trunc - date_trunc - make_timestamp - date_part - convert_timezone ### Why are the changes needed? This PR is part of ongoing effort to support collated strings on SparkSQL. ### Does this PR introduce _any_ user-facing change? Yes, users will be able to use collated strings for datetime expressions. ### How was this patch tested? Added corresponding tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46618 from nebojsa-db/SPARK-48159. Authored-by: Nebojsa Savic Signed-off-by: Wenchen Fan --- .../expressions/datetimeExpressions.scala | 38 +-- .../sql/CollationSQLExpressionsSuite.scala | 234 ++++++++++++++++++ 2 files changed, 254 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 8caf8c5d48c2b..808ad54f8ecad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -105,7 +105,7 @@ trait TimestampFormatterHelper extends TimeZoneAwareExpression { since = "3.1.0") case class CurrentTimeZone() extends LeafExpression with Unevaluable { override def nullable: Boolean = false - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "current_timezone" final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE) } @@ -924,7 +924,7 @@ case class DayName(child: Expression) extends GetDateField { override val funcName = "getDayName" override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override protected def withNewChildInternal(newChild: Expression): DayName = copy(child = newChild) } @@ -1262,7 +1262,8 @@ abstract class ToTimestamp override def forTimestampNTZ: Boolean = left.dataType == TimestampNTZType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringType, DateType, TimestampType, TimestampNTZType), StringType) + Seq(TypeCollection(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType), + StringTypeAnyCollation) override def dataType: DataType = LongType override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true @@ -1284,7 +1285,7 @@ abstract class ToTimestamp daysToMicros(t.asInstanceOf[Int], zoneId) / downScaleFactor case TimestampType | TimestampNTZType => t.asInstanceOf[Long] / downScaleFactor - case StringType => + case _: StringType => val fmt = right.eval(input) if (fmt == null) { null @@ -1327,7 +1328,7 @@ abstract class 
ToTimestamp } left.dataType match { - case StringType => formatterOption.map { fmt => + case _: StringType => formatterOption.map { fmt => val df = classOf[TimestampFormatter].getName val formatterName = ctx.addReferenceObj("formatter", fmt, df) nullSafeCodeGen(ctx, ev, (datetimeStr, _) => @@ -1430,10 +1431,10 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ this(unix, Literal(TimestampFormatter.defaultPattern())) } - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringTypeAnyCollation) override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) @@ -1541,7 +1542,7 @@ case class NextDay( def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringTypeAnyCollation) override def dataType: DataType = DateType override def nullable: Boolean = true @@ -1752,7 +1753,7 @@ sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes w val func: (Long, String) => Long val funcName: String - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringTypeAnyCollation) override def dataType: DataType = TimestampType override def nullSafeEval(time: Any, timezone: Any): Any = { @@ -2092,8 +2093,8 @@ case class ParseToDate( override def inputTypes: Seq[AbstractDataType] = { // Note: ideally this function should only take string input, but we allow more types here to // be backward compatible. - TypeCollection(StringType, DateType, TimestampType, TimestampNTZType) +: - format.map(_ => StringType).toSeq + TypeCollection(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType) +: + format.map(_ => StringTypeAnyCollation).toSeq } override protected def withNewChildrenInternal( @@ -2164,10 +2165,10 @@ case class ParseToTimestamp( override def inputTypes: Seq[AbstractDataType] = { // Note: ideally this function should only take string input, but we allow more types here to // be backward compatible. 
- val types = Seq(StringType, DateType, TimestampType, TimestampNTZType) + val types = Seq(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType) TypeCollection( (if (dataType.isInstanceOf[TimestampType]) types :+ NumericType else types): _* - ) +: format.map(_ => StringType).toSeq + ) +: format.map(_ => StringTypeAnyCollation).toSeq } override protected def withNewChildrenInternal( @@ -2297,7 +2298,7 @@ case class TruncDate(date: Expression, format: Expression) override def left: Expression = date override def right: Expression = format - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringTypeAnyCollation) override def dataType: DataType = DateType override def prettyName: String = "trunc" override val instant = date @@ -2366,7 +2367,7 @@ case class TruncTimestamp( override def left: Expression = format override def right: Expression = timestamp - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, TimestampType) override def dataType: TimestampType = TimestampType override def prettyName: String = "date_trunc" override val instant = timestamp @@ -2667,7 +2668,7 @@ case class MakeTimestamp( // casted into decimal safely, we use DecimalType(16, 6) which is wider than DecimalType(10, 0). override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, IntegerType, IntegerType, IntegerType, DecimalType(16, 6)) ++ - timezone.map(_ => StringType) + timezone.map(_ => StringTypeAnyCollation) override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = @@ -2939,7 +2940,7 @@ case class Extract(field: Expression, source: Expression, replacement: Expressio object Extract { def createExpr(funcName: String, field: Expression, source: Expression): Expression = { // both string and null literals are allowed. 
- if ((field.dataType == StringType || field.dataType == NullType) && field.foldable) { + if ((field.dataType.isInstanceOf[StringType] || field.dataType == NullType) && field.foldable) { val fieldStr = field.eval().asInstanceOf[UTF8String] if (fieldStr == null) { Literal(null, DoubleType) @@ -3114,7 +3115,8 @@ case class ConvertTimezone( override def second: Expression = targetTz override def third: Expression = sourceTs - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, TimestampNTZType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, + StringTypeAnyCollation, TimestampNTZType) override def dataType: DataType = TimestampNTZType override def nullSafeEval(srcTz: Any, tgtTz: Any, micros: Any): Any = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index f3d07ba47b715..525ef02f949a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import scala.collection.immutable.Seq @@ -31,6 +32,8 @@ class CollationSQLExpressionsSuite extends QueryTest with SharedSparkSession { + private val testSuppCollations = Seq("UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI") + test("Support Md5 hash expression with collation") { case class Md5TestCase( input: String, @@ -1632,6 +1635,237 @@ class CollationSQLExpressionsSuite } } + test("CurrentTimeZone expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = "select current_timezone()" + // Data type check + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { + val testQuery = sql(query) + val dataType = StringType(collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("DayName expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = "select dayname(current_date())" + // Data type check + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { + val testQuery = sql(query) + val dataType = StringType(collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("ToUnixTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select to_unix_timestamp(collate('2021-01-01 00:00:00', '${collationName}'), + |collate('yyyy-MM-dd HH:mm:ss', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = LongType + val expectedResult = 1609488000L + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(expectedResult)) + }) + } + + test("FromUnixTime expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select from_unixtime(1609488000, collate('yyyy-MM-dd HH:mm:ss', '${collationName}')) + |""".stripMargin + // Result & data type check + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { + val testQuery = sql(query) + val dataType = StringType(collationName) + val expectedResult = "2021-01-01 00:00:00" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) 
+ checkAnswer(testQuery, Row(expectedResult)) + } + }) + } + + test("NextDay expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select next_day('2015-01-14', collate('TU', '${collationName}')) + |""".stripMargin + // Result & data type check + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { + val testQuery = sql(query) + val dataType = DateType + val expectedResult = "2015-01-20" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Date.valueOf(expectedResult))) + } + }) + } + + test("FromUTCTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select from_utc_timestamp(collate('2016-08-31', '${collationName}'), + |collate('Asia/Seoul', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = TimestampType + val expectedResult = "2016-08-31 09:00:00.0" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Timestamp.valueOf(expectedResult))) + }) + } + + test("ToUTCTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select to_utc_timestamp(collate('2016-08-31 09:00:00', '${collationName}'), + |collate('Asia/Seoul', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = TimestampType + val expectedResult = "2016-08-31 00:00:00.0" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Timestamp.valueOf(expectedResult))) + }) + } + + test("ParseToDate expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select to_date(collate('2016-12-31', '${collationName}'), + |collate('yyyy-MM-dd', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = DateType + val expectedResult = "2016-12-31" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Date.valueOf(expectedResult))) + }) + } + + test("ParseToTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select to_timestamp(collate('2016-12-31 23:59:59', '${collationName}'), + |collate('yyyy-MM-dd HH:mm:ss', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = TimestampType + val expectedResult = "2016-12-31 23:59:59.0" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Timestamp.valueOf(expectedResult))) + }) + } + + test("TruncDate expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select trunc(collate('2016-12-31 23:59:59', '${collationName}'), 'MM') + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = DateType + val expectedResult = "2016-12-01" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Date.valueOf(expectedResult))) + }) + } + + test("TruncTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select date_trunc(collate('HOUR', 
'${collationName}'), + |collate('2015-03-05T09:32:05.359', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = TimestampType + val expectedResult = "2015-03-05 09:00:00.0" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Timestamp.valueOf(expectedResult))) + }) + } + + test("MakeTimestamp expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select make_timestamp(2014, 12, 28, 6, 30, 45.887, collate('CET', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = TimestampType + val expectedResult = "2014-12-27 21:30:45.887" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Timestamp.valueOf(expectedResult))) + }) + } + + test("DatePart expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select date_part(collate('Week', '${collationName}'), + |collate('2019-08-12 01:00:00.123456', '${collationName}')) + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = IntegerType + val expectedResult = 33 + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(expectedResult)) + }) + } + + test("ConvertTimezone expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = + s""" + |select date_format(convert_timezone(collate('America/Los_Angeles', '${collationName}'), + |collate('UTC', '${collationName}'), collate('2021-12-06 00:00:00', '${collationName}')), + |'yyyy-MM-dd HH:mm:ss.S') + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = StringType + val expectedResult = "2021-12-06 08:00:00.0" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(expectedResult)) + }) + } + // TODO: Add more tests for other SQL expressions } From 249390017ef4a045037213dec386e16cca125080 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Tue, 28 May 2024 10:05:12 -0700 Subject: [PATCH 45/45] [SPARK-48221][SQL] Alter string search logic for UTF8_BINARY_LCASE collation (Contains, StartsWith, EndsWith, StringLocate) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? String searching in UTF8_BINARY_LCASE now works on character-level, rather than on byte-level. For example: `contains("İ", "i");` now returns **false**, because there exists no `start, len` such that `lowercase(substring("İ", start, len)) == "i"`. ### Why are the changes needed? Fix functions that give unusable results due to one-to-many case mapping when performing string search under UTF8_BINARY_LCASE (see example above). ### Does this PR introduce _any_ user-facing change? Yes, behaviour of `contains`, `startswith`, `endswith`, and `locate`/`position` expressions is changed for edge cases with one-to-many case mapping. ### How was this patch tested? New unit tests in `CollationSupportSuite`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46511 from uros-db/alter-lcase-impl. 
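For reference, a rough, self-contained Scala sketch of the character-level matching described above (using plain JVM strings and a naive quadratic scan, not the actual `CollationAwareUTF8String` implementation added in the diff below): a pattern matches at a position only if some substring starting there lowercases to the already-lowercased pattern, which is exactly what makes `contains("İ", "i")` false.

```scala
object LowercaseMatchSketch {
  // Root-locale lowercasing, so that 'İ' (U+0130) maps to "i\u0307" as in the example above.
  private def lc(s: String): String = s.toLowerCase(java.util.Locale.ROOT)

  // True if some prefix of `target` starting at `startPos` lowercases to `lowercasePattern`.
  // Indexes by Java chars, which is enough for the BMP examples used here.
  def lowercaseMatchFrom(target: String, lowercasePattern: String, startPos: Int): Boolean =
    (0 to target.length - startPos).exists { len =>
      lc(target.substring(startPos, startPos + len)) == lowercasePattern
    }

  // Character-level "lowercase contains": try every start position in `target`.
  def lowercaseContains(target: String, pattern: String): Boolean =
    (0 to target.length).exists(i => lowercaseMatchFrom(target, lc(pattern), i))

  def main(args: Array[String]): Unit = {
    println(lowercaseContains("İ", "i"))            // false: "İ" lowercases to "i\u0307", not "i"
    println(lowercaseContains("adi\u0307os", "İo")) // true:  "i\u0307o" matches lowercase("İo")
  }
}
```

The `CollationAwareUTF8String` methods added below implement the same idea over `UTF8String`, additionally tracking match lengths and providing reverse-direction variants used by `endsWith` and `locate`.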
Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Signed-off-by: Wenchen Fan --- .../util/CollationAwareUTF8String.java | 169 ++++++++++++++++++ .../sql/catalyst/util/CollationSupport.java | 8 +- .../apache/spark/unsafe/types/UTF8String.java | 118 ------------ .../unsafe/types/CollationSupportSuite.java | 129 ++++++++++--- .../spark/unsafe/types/UTF8StringSuite.java | 105 ----------- 5 files changed, 278 insertions(+), 251 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index ee0d611d7e652..0d0094d8d0a03 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -34,6 +34,155 @@ * Utility class for collation-aware UTF8String operations. */ public class CollationAwareUTF8String { + + /** + * The constant value to indicate that the match is not found when searching for a pattern + * string in a target string. + */ + private static final int MATCH_NOT_FOUND = -1; + + /** + * Returns whether the target string starts with the specified prefix, starting from the + * specified position (0-based index referring to character position in UTF8String), with respect + * to the UTF8_BINARY_LCASE collation. The method assumes that the prefix is already lowercased + * prior to method call to avoid the overhead of calling .toLowerCase() multiple times on the + * same prefix string. + * + * @param target the string to be searched in + * @param lowercasePattern the string to be searched for + * @param startPos the start position for searching (in the target string) + * @return whether the target string starts with the specified prefix in UTF8_BINARY_LCASE + */ + public static boolean lowercaseMatchFrom( + final UTF8String target, + final UTF8String lowercasePattern, + int startPos) { + return lowercaseMatchLengthFrom(target, lowercasePattern, startPos) != MATCH_NOT_FOUND; + } + + /** + * Returns the length of the substring of the target string that starts with the specified + * prefix, starting from the specified position (0-based index referring to character position + * in UTF8String), with respect to the UTF8_BINARY_LCASE collation. The method assumes that the + * prefix is already lowercased. The method only considers the part of target string that + * starts from the specified (inclusive) position (that is, the method does not look at UTF8 + * characters of the target string at or after position `endPos`). If the prefix is not found, + * MATCH_NOT_FOUND is returned. 
+ * + * @param target the string to be searched in + * @param lowercasePattern the string to be searched for + * @param startPos the start position for searching (in the target string) + * @return length of the target substring that starts with the specified prefix in lowercase + */ + private static int lowercaseMatchLengthFrom( + final UTF8String target, + final UTF8String lowercasePattern, + int startPos) { + assert startPos >= 0; + for (int len = 0; len <= target.numChars() - startPos; ++len) { + if (target.substring(startPos, startPos + len).toLowerCase().equals(lowercasePattern)) { + return len; + } + } + return MATCH_NOT_FOUND; + } + + /** + * Returns the position of the first occurrence of the pattern string in the target string, + * starting from the specified position (0-based index referring to character position in + * UTF8String), with respect to the UTF8_BINARY_LCASE collation. The method assumes that the + * pattern string is already lowercased prior to call. If the pattern is not found, + * MATCH_NOT_FOUND is returned. + * + * @param target the string to be searched in + * @param lowercasePattern the string to be searched for + * @param startPos the start position for searching (in the target string) + * @return the position of the first occurrence of pattern in target + */ + private static int lowercaseFind( + final UTF8String target, + final UTF8String lowercasePattern, + int startPos) { + assert startPos >= 0; + for (int i = startPos; i <= target.numChars(); ++i) { + if (lowercaseMatchFrom(target, lowercasePattern, i)) { + return i; + } + } + return MATCH_NOT_FOUND; + } + + /** + * Returns whether the target string ends with the specified suffix, ending at the specified + * position (0-based index referring to character position in UTF8String), with respect to the + * UTF8_BINARY_LCASE collation. The method assumes that the suffix is already lowercased prior + * to method call to avoid the overhead of calling .toLowerCase() multiple times on the same + * suffix string. + * + * @param target the string to be searched in + * @param lowercasePattern the string to be searched for + * @param endPos the end position for searching (in the target string) + * @return whether the target string ends with the specified suffix in lowercase + */ + public static boolean lowercaseMatchUntil( + final UTF8String target, + final UTF8String lowercasePattern, + int endPos) { + return lowercaseMatchLengthUntil(target, lowercasePattern, endPos) != MATCH_NOT_FOUND; + } + + /** + * Returns the length of the substring of the target string that ends with the specified + * suffix, ending at the specified position (0-based index referring to character position in + * UTF8String), with respect to the UTF8_BINARY_LCASE collation. The method assumes that the + * suffix is already lowercased. The method only considers the part of target string that ends + * at the specified (non-inclusive) position (that is, the method does not look at UTF8 + * characters of the target string at or after position `endPos`). If the suffix is not found, + * MATCH_NOT_FOUND is returned. 
+ * + * @param target the string to be searched in + * @param lowercasePattern the string to be searched for + * @param endPos the end position for searching (in the target string) + * @return length of the target substring that ends with the specified suffix in lowercase + */ + private static int lowercaseMatchLengthUntil( + final UTF8String target, + final UTF8String lowercasePattern, + int endPos) { + assert endPos <= target.numChars(); + for (int len = 0; len <= endPos; ++len) { + if (target.substring(endPos - len, endPos).toLowerCase().equals(lowercasePattern)) { + return len; + } + } + return MATCH_NOT_FOUND; + } + + /** + * Returns the position of the last occurrence of the pattern string in the target string, + * ending at the specified position (0-based index referring to character position in + * UTF8String), with respect to the UTF8_BINARY_LCASE collation. The method assumes that the + * pattern string is already lowercased prior to call. If the pattern is not found, + * MATCH_NOT_FOUND is returned. + * + * @param target the string to be searched in + * @param lowercasePattern the string to be searched for + * @param endPos the end position for searching (in the target string) + * @return the position of the last occurrence of pattern in target + */ + private static int lowercaseRFind( + final UTF8String target, + final UTF8String lowercasePattern, + int endPos) { + assert endPos <= target.numChars(); + for (int i = endPos; i >= 0; --i) { + if (lowercaseMatchUntil(target, lowercasePattern, i)) { + return i; + } + } + return MATCH_NOT_FOUND; + } + public static UTF8String replace(final UTF8String src, final UTF8String search, final UTF8String replace, final int collationId) { // This collation aware implementation is based on existing implementation on UTF8String @@ -183,6 +332,23 @@ public static int findInSet(final UTF8String match, final UTF8String set, int co return 0; } + /** + * Returns the position of the first occurrence of the pattern string in the target string, + * starting from the specified position (0-based index referring to character position in + * UTF8String), with respect to the UTF8_BINARY_LCASE collation. If the pattern is not found, + * MATCH_NOT_FOUND is returned. + * + * @param target the string to be searched in + * @param pattern the string to be searched for + * @param start the start position for searching (in the target string) + * @return the position of the first occurrence of pattern in target + */ + public static int lowercaseIndexOf(final UTF8String target, final UTF8String pattern, + final int start) { + if (pattern.numChars() == 0) return 0; + return lowercaseFind(target, pattern.toLowerCase(), start); + } + public static int indexOf(final UTF8String target, final UTF8String pattern, final int start, final int collationId) { if (pattern.numBytes() == 0) { @@ -467,4 +633,7 @@ public static UTF8String lowercaseTrimRight( } return srcString.copyUTF8String(0, trimByteIdx); } + + // TODO: Add more collation-aware UTF8String operations here. 
+ } diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index bea3dc08b4489..8f7aed30464cc 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -118,7 +118,7 @@ public static boolean execBinary(final UTF8String l, final UTF8String r) { return l.contains(r); } public static boolean execLowercase(final UTF8String l, final UTF8String r) { - return l.containsInLowerCase(r); + return CollationAwareUTF8String.lowercaseIndexOf(l, r, 0) >= 0; } public static boolean execICU(final UTF8String l, final UTF8String r, final int collationId) { @@ -156,7 +156,7 @@ public static boolean execBinary(final UTF8String l, final UTF8String r) { return l.startsWith(r); } public static boolean execLowercase(final UTF8String l, final UTF8String r) { - return l.startsWithInLowerCase(r); + return CollationAwareUTF8String.lowercaseMatchFrom(l, r.toLowerCase(), 0); } public static boolean execICU(final UTF8String l, final UTF8String r, final int collationId) { @@ -193,7 +193,7 @@ public static boolean execBinary(final UTF8String l, final UTF8String r) { return l.endsWith(r); } public static boolean execLowercase(final UTF8String l, final UTF8String r) { - return l.endsWithInLowerCase(r); + return CollationAwareUTF8String.lowercaseMatchUntil(l, r.toLowerCase(), l.numChars()); } public static boolean execICU(final UTF8String l, final UTF8String r, final int collationId) { @@ -430,7 +430,7 @@ public static int execBinary(final UTF8String string, final UTF8String substring } public static int execLowercase(final UTF8String string, final UTF8String substring, final int start) { - return string.toLowerCase().indexOf(substring.toLowerCase(), start); + return CollationAwareUTF8String.lowercaseIndexOf(string, substring, start); } public static int execICU(final UTF8String string, final UTF8String substring, final int start, final int collationId) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 20b26b6ebc5a5..03286e0635287 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -341,44 +341,6 @@ public boolean contains(final UTF8String substring) { return false; } - /** - * Returns whether `this` contains `substring` in a lowercase unicode-aware manner - * - * This function is written in a way which avoids excessive allocations in case if we work with - * bare ASCII-character strings. - */ - public boolean containsInLowerCase(final UTF8String substring) { - if (substring.numBytes == 0) { - return true; - } - - // Both `this` and the `substring` are checked for non-ASCII characters, otherwise we would - // have to use `startsWithLowerCase(...)` in a loop, and it would frequently allocate - // (e.g. 
in case of `containsInLowerCase("1大1大1大...", "11")`) - if (!substring.isFullAscii()) { - return toLowerCase().contains(substring.toLowerCaseSlow()); - } - if (!isFullAscii()) { - return toLowerCaseSlow().contains(substring.toLowerCaseAscii()); - } - - if (numBytes < substring.numBytes) { - return false; - } - - final var firstLower = Character.toLowerCase(substring.getByte(0)); - for (var i = 0; i <= (numBytes - substring.numBytes); i++) { - if (Character.toLowerCase(getByte(i)) == firstLower) { - final var rest = UTF8String.fromAddress(base, offset + i, numBytes - i); - if (rest.matchAtInLowerCaseAscii(substring, 0)) { - return true; - } - } - } - - return false; - } - /** * Returns the byte at position `i`. */ @@ -393,94 +355,14 @@ public boolean matchAt(final UTF8String s, int pos) { return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset, s.numBytes); } - private boolean matchAtInLowerCaseAscii(final UTF8String s, int pos) { - if (s.numBytes + pos > numBytes || pos < 0) { - return false; - } - - for (var i = 0; i < s.numBytes; i++) { - if (Character.toLowerCase(getByte(pos + i)) != Character.toLowerCase(s.getByte(i))) { - return false; - } - } - - return true; - } - public boolean startsWith(final UTF8String prefix) { return matchAt(prefix, 0); } - /** - * Checks whether `prefix` is a prefix of `this` in a lowercase unicode-aware manner - * - * This function is written in a way which avoids excessive allocations in case if we work with - * bare ASCII-character strings. - */ - public boolean startsWithInLowerCase(final UTF8String prefix) { - // No way to match sizes of strings for early return, since single grapheme can be expanded - // into several independent ones in lowercase - if (prefix.numBytes == 0) { - return true; - } - if (numBytes == 0) { - return false; - } - - if (!prefix.isFullAscii()) { - return toLowerCase().startsWith(prefix.toLowerCaseSlow()); - } - - final var part = prefix.numBytes >= numBytes ? this : UTF8String.fromAddress( - base, offset, prefix.numBytes); - if (!part.isFullAscii()) { - return toLowerCaseSlow().startsWith(prefix.toLowerCaseAscii()); - } - - if (numBytes < prefix.numBytes) { - return false; - } - - return matchAtInLowerCaseAscii(prefix, 0); - } - public boolean endsWith(final UTF8String suffix) { return matchAt(suffix, numBytes - suffix.numBytes); } - /** - * Checks whether `suffix` is a suffix of `this` in a lowercase unicode-aware manner - * - * This function is written in a way which avoids excessive allocations in case if we work with - * bare ASCII-character strings. - */ - public boolean endsWithInLowerCase(final UTF8String suffix) { - // No way to match sizes of strings for early return, since single grapheme can be expanded - // into several independent ones in lowercase - if (suffix.numBytes == 0) { - return true; - } - if (numBytes == 0) { - return false; - } - - if (!suffix.isFullAscii()) { - return toLowerCase().endsWith(suffix.toLowerCaseSlow()); - } - - final var part = suffix.numBytes >= numBytes ? 
this : UTF8String.fromAddress( - base, offset + numBytes - suffix.numBytes, suffix.numBytes); - if (!part.isFullAscii()) { - return toLowerCaseSlow().endsWith(suffix.toLowerCaseAscii()); - } - - if (numBytes < suffix.numBytes) { - return false; - } - - return matchAtInLowerCaseAscii(suffix, numBytes - suffix.numBytes); - } - /** * Returns the upper case of this string */ diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 7fc3c4e349c3b..eb18d7665b092 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -23,7 +23,7 @@ import static org.junit.jupiter.api.Assertions.*; - +// checkstyle.off: AvoidEscapedUnicodeCharacters public class CollationSupportSuite { /** @@ -101,14 +101,6 @@ public void testContains() throws SparkException { assertContains("ab世De", "AB世dE", "UNICODE_CI", true); assertContains("äbćδe", "ÄbćδE", "UNICODE_CI", true); assertContains("äbćδe", "ÄBcΔÉ", "UNICODE_CI", false); - // Case-variable character length - assertContains("abİo12", "i̇o", "UNICODE_CI", true); - assertContains("abi̇o12", "İo", "UNICODE_CI", true); - assertContains("the İodine", "the i̇odine", "UTF8_BINARY_LCASE", true); - assertContains("the i̇odine", "the İodine", "UTF8_BINARY_LCASE", true); - assertContains("The İodiNe", " i̇oDin", "UTF8_BINARY_LCASE", true); - assertContains("İodiNe", "i̇oDine", "UTF8_BINARY_LCASE", true); - assertContains("İodiNe", " i̇oDin", "UTF8_BINARY_LCASE", false); // Characters with the same binary lowercase representation assertContains("The Kelvin.", "Kelvin", "UTF8_BINARY_LCASE", true); assertContains("The Kelvin.", "Kelvin", "UTF8_BINARY_LCASE", true); @@ -116,6 +108,33 @@ public void testContains() throws SparkException { assertContains("2 Kelvin.", "2 Kelvin", "UTF8_BINARY_LCASE", true); assertContains("2 Kelvin.", "2 Kelvin", "UTF8_BINARY_LCASE", true); assertContains("The KKelvin.", "KKelvin,", "UTF8_BINARY_LCASE", false); + // Case-variable character length + assertContains("i̇", "i", "UNICODE_CI", false); + assertContains("i̇", "\u0307", "UNICODE_CI", false); + assertContains("i̇", "İ", "UNICODE_CI", true); + assertContains("İ", "i", "UNICODE_CI", false); + assertContains("adi̇os", "io", "UNICODE_CI", false); + assertContains("adi̇os", "Io", "UNICODE_CI", false); + assertContains("adi̇os", "i̇o", "UNICODE_CI", true); + assertContains("adi̇os", "İo", "UNICODE_CI", true); + assertContains("adİos", "io", "UNICODE_CI", false); + assertContains("adİos", "Io", "UNICODE_CI", false); + assertContains("adİos", "i̇o", "UNICODE_CI", true); + assertContains("adİos", "İo", "UNICODE_CI", true); + assertContains("i̇", "i", "UTF8_BINARY_LCASE", true); // != UNICODE_CI + assertContains("İ", "\u0307", "UTF8_BINARY_LCASE", false); + assertContains("İ", "i", "UTF8_BINARY_LCASE", false); + assertContains("i̇", "\u0307", "UTF8_BINARY_LCASE", true); // != UNICODE_CI + assertContains("i̇", "İ", "UTF8_BINARY_LCASE", true); + assertContains("İ", "i", "UTF8_BINARY_LCASE", false); + assertContains("adi̇os", "io", "UTF8_BINARY_LCASE", false); + assertContains("adi̇os", "Io", "UTF8_BINARY_LCASE", false); + assertContains("adi̇os", "i̇o", "UTF8_BINARY_LCASE", true); + assertContains("adi̇os", "İo", "UTF8_BINARY_LCASE", true); + assertContains("adİos", "io", "UTF8_BINARY_LCASE", false); + assertContains("adİos", "Io", 
"UTF8_BINARY_LCASE", false); + assertContains("adİos", "i̇o", "UTF8_BINARY_LCASE", true); + assertContains("adİos", "İo", "UTF8_BINARY_LCASE", true); } private void assertStartsWith( @@ -191,13 +210,6 @@ public void testStartsWith() throws SparkException { assertStartsWith("ab世De", "AB世dE", "UNICODE_CI", true); assertStartsWith("äbćδe", "ÄbćδE", "UNICODE_CI", true); assertStartsWith("äbćδe", "ÄBcΔÉ", "UNICODE_CI", false); - // Case-variable character length - assertStartsWith("İonic", "i̇o", "UNICODE_CI", true); - assertStartsWith("i̇onic", "İo", "UNICODE_CI", true); - assertStartsWith("the İodine", "the i̇odine", "UTF8_BINARY_LCASE", true); - assertStartsWith("the i̇odine", "the İodine", "UTF8_BINARY_LCASE", true); - assertStartsWith("İodiNe", "i̇oDin", "UTF8_BINARY_LCASE", true); - assertStartsWith("The İodiNe", "i̇oDin", "UTF8_BINARY_LCASE", false); // Characters with the same binary lowercase representation assertStartsWith("Kelvin.", "Kelvin", "UTF8_BINARY_LCASE", true); assertStartsWith("Kelvin.", "Kelvin", "UTF8_BINARY_LCASE", true); @@ -205,6 +217,37 @@ public void testStartsWith() throws SparkException { assertStartsWith("2 Kelvin.", "2 Kelvin", "UTF8_BINARY_LCASE", true); assertStartsWith("2 Kelvin.", "2 Kelvin", "UTF8_BINARY_LCASE", true); assertStartsWith("KKelvin.", "KKelvin,", "UTF8_BINARY_LCASE", false); + // Case-variable character length + assertStartsWith("i̇", "i", "UNICODE_CI", false); + assertStartsWith("i̇", "İ", "UNICODE_CI", true); + assertStartsWith("İ", "i", "UNICODE_CI", false); + assertStartsWith("İİİ", "i̇i̇", "UNICODE_CI", true); + assertStartsWith("İİİ", "i̇i", "UNICODE_CI", false); + assertStartsWith("İi̇İ", "i̇İ", "UNICODE_CI", true); + assertStartsWith("i̇İi̇i̇", "İi̇İi", "UNICODE_CI", false); + assertStartsWith("i̇onic", "io", "UNICODE_CI", false); + assertStartsWith("i̇onic", "Io", "UNICODE_CI", false); + assertStartsWith("i̇onic", "i̇o", "UNICODE_CI", true); + assertStartsWith("i̇onic", "İo", "UNICODE_CI", true); + assertStartsWith("İonic", "io", "UNICODE_CI", false); + assertStartsWith("İonic", "Io", "UNICODE_CI", false); + assertStartsWith("İonic", "i̇o", "UNICODE_CI", true); + assertStartsWith("İonic", "İo", "UNICODE_CI", true); + assertStartsWith("i̇", "i", "UTF8_BINARY_LCASE", true); // != UNICODE_CI + assertStartsWith("i̇", "İ", "UTF8_BINARY_LCASE", true); + assertStartsWith("İ", "i", "UTF8_BINARY_LCASE", false); + assertStartsWith("İİİ", "i̇i̇", "UTF8_BINARY_LCASE", true); + assertStartsWith("İİİ", "i̇i", "UTF8_BINARY_LCASE", false); + assertStartsWith("İi̇İ", "i̇İ", "UTF8_BINARY_LCASE", true); + assertStartsWith("i̇İi̇i̇", "İi̇İi", "UTF8_BINARY_LCASE", true); // != UNICODE_CI + assertStartsWith("i̇onic", "io", "UTF8_BINARY_LCASE", false); + assertStartsWith("i̇onic", "Io", "UTF8_BINARY_LCASE", false); + assertStartsWith("i̇onic", "i̇o", "UTF8_BINARY_LCASE", true); + assertStartsWith("i̇onic", "İo", "UTF8_BINARY_LCASE", true); + assertStartsWith("İonic", "io", "UTF8_BINARY_LCASE", false); + assertStartsWith("İonic", "Io", "UTF8_BINARY_LCASE", false); + assertStartsWith("İonic", "i̇o", "UTF8_BINARY_LCASE", true); + assertStartsWith("İonic", "İo", "UTF8_BINARY_LCASE", true); } private void assertEndsWith(String pattern, String suffix, String collationName, boolean expected) @@ -279,13 +322,6 @@ public void testEndsWith() throws SparkException { assertEndsWith("ab世De", "AB世dE", "UNICODE_CI", true); assertEndsWith("äbćδe", "ÄbćδE", "UNICODE_CI", true); assertEndsWith("äbćδe", "ÄBcΔÉ", "UNICODE_CI", false); - // Case-variable character length - 
assertEndsWith("The İo", "i̇o", "UNICODE_CI", true); - assertEndsWith("The i̇o", "İo", "UNICODE_CI", true); - assertEndsWith("the İodine", "the i̇odine", "UTF8_BINARY_LCASE", true); - assertEndsWith("the i̇odine", "the İodine", "UTF8_BINARY_LCASE", true); - assertEndsWith("The İodiNe", "i̇oDine", "UTF8_BINARY_LCASE", true); - assertEndsWith("The İodiNe", "i̇oDin", "UTF8_BINARY_LCASE", false); // Characters with the same binary lowercase representation assertEndsWith("The Kelvin", "Kelvin", "UTF8_BINARY_LCASE", true); assertEndsWith("The Kelvin", "Kelvin", "UTF8_BINARY_LCASE", true); @@ -293,6 +329,38 @@ public void testEndsWith() throws SparkException { assertEndsWith("The 2 Kelvin", "2 Kelvin", "UTF8_BINARY_LCASE", true); assertEndsWith("The 2 Kelvin", "2 Kelvin", "UTF8_BINARY_LCASE", true); assertEndsWith("The KKelvin", "KKelvin,", "UTF8_BINARY_LCASE", false); + // Case-variable character length + assertEndsWith("i̇", "\u0307", "UNICODE_CI", false); + assertEndsWith("i̇", "İ", "UNICODE_CI", true); + assertEndsWith("İ", "i", "UNICODE_CI", false); + assertEndsWith("İİİ", "i̇i̇", "UNICODE_CI", true); + assertEndsWith("İİİ", "ii̇", "UNICODE_CI", false); + assertEndsWith("İi̇İ", "İi̇", "UNICODE_CI", true); + assertEndsWith("i̇İi̇i̇", "\u0307İi̇İ", "UNICODE_CI", false); + assertEndsWith("the i̇o", "io", "UNICODE_CI", false); + assertEndsWith("the i̇o", "Io", "UNICODE_CI", false); + assertEndsWith("the i̇o", "i̇o", "UNICODE_CI", true); + assertEndsWith("the i̇o", "İo", "UNICODE_CI", true); + assertEndsWith("the İo", "io", "UNICODE_CI", false); + assertEndsWith("the İo", "Io", "UNICODE_CI", false); + assertEndsWith("the İo", "i̇o", "UNICODE_CI", true); + assertEndsWith("the İo", "İo", "UNICODE_CI", true); + assertEndsWith("i̇", "\u0307", "UTF8_BINARY_LCASE", true); // != UNICODE_CI + assertEndsWith("i̇", "İ", "UTF8_BINARY_LCASE", true); + assertEndsWith("İ", "\u0307", "UTF8_BINARY_LCASE", false); + assertEndsWith("İİİ", "i̇i̇", "UTF8_BINARY_LCASE", true); + assertEndsWith("İİİ", "ii̇", "UTF8_BINARY_LCASE", false); + assertEndsWith("İi̇İ", "İi̇", "UTF8_BINARY_LCASE", true); + assertEndsWith("i̇İi̇i̇", "\u0307İi̇İ", "UTF8_BINARY_LCASE", true); // != UNICODE_CI + assertEndsWith("i̇İi̇i̇", "\u0307İİ", "UTF8_BINARY_LCASE", false); + assertEndsWith("the i̇o", "io", "UTF8_BINARY_LCASE", false); + assertEndsWith("the i̇o", "Io", "UTF8_BINARY_LCASE", false); + assertEndsWith("the i̇o", "i̇o", "UTF8_BINARY_LCASE", true); + assertEndsWith("the i̇o", "İo", "UTF8_BINARY_LCASE", true); + assertEndsWith("the İo", "io", "UTF8_BINARY_LCASE", false); + assertEndsWith("the İo", "Io", "UTF8_BINARY_LCASE", false); + assertEndsWith("the İo", "i̇o", "UTF8_BINARY_LCASE", true); + assertEndsWith("the İo", "İo", "UTF8_BINARY_LCASE", true); } private void assertStringSplitSQL(String str, String delimiter, String collationName, @@ -709,12 +777,24 @@ public void testLocate() throws SparkException { assertLocate("大千", "test大千世界大千世界", 9, "UNICODE_CI", 9); assertLocate("大千", "大千世界大千世界", 1, "UNICODE_CI", 1); // Case-variable character length + assertLocate("\u0307", "i̇", 1, "UTF8_BINARY", 2); + assertLocate("\u0307", "İ", 1, "UTF8_BINARY_LCASE", 0); // != UTF8_BINARY + assertLocate("i", "i̇", 1, "UNICODE_CI", 0); + assertLocate("\u0307", "i̇", 1, "UNICODE_CI", 0); + assertLocate("i̇", "i", 1, "UNICODE_CI", 0); + assertLocate("İ", "i̇", 1, "UNICODE_CI", 1); + assertLocate("İ", "i", 1, "UNICODE_CI", 0); + assertLocate("i", "i̇", 1, "UTF8_BINARY_LCASE", 1); // != UNICODE_CI + assertLocate("\u0307", "i̇", 1, "UTF8_BINARY_LCASE", 
2); // != UNICODE_CI + assertLocate("i̇", "i", 1, "UTF8_BINARY_LCASE", 0); + assertLocate("İ", "i̇", 1, "UTF8_BINARY_LCASE", 1); + assertLocate("İ", "i", 1, "UTF8_BINARY_LCASE", 0); assertLocate("i̇o", "İo世界大千世界", 1, "UNICODE_CI", 1); assertLocate("i̇o", "大千İo世界大千世界", 1, "UNICODE_CI", 3); assertLocate("i̇o", "世界İo大千世界大千İo", 4, "UNICODE_CI", 11); assertLocate("İo", "i̇o世界大千世界", 1, "UNICODE_CI", 1); assertLocate("İo", "大千i̇o世界大千世界", 1, "UNICODE_CI", 3); - assertLocate("İo", "世界i̇o大千世界大千i̇o", 4, "UNICODE_CI", 12); // 12 instead of 11 + assertLocate("İo", "世界i̇o大千世界大千i̇o", 4, "UNICODE_CI", 12); } private void assertSubstringIndex(String string, String delimiter, Integer count, @@ -1008,3 +1088,4 @@ public void testStringTrim() throws SparkException { // TODO: Test other collation-aware expressions. } +// checkstyle.on: AvoidEscapedUnicodeCharacters diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index a1aba86cfbc56..0188297fd05a2 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -215,43 +215,6 @@ public void contains() { assertFalse(fromString("大千世界").contains(fromString("大千世界好"))); } - @Test - public void containsInLowerCase() { - // Corner cases - assertTrue(EMPTY_UTF8.containsInLowerCase(EMPTY_UTF8)); - assertTrue(fromString("a").containsInLowerCase(EMPTY_UTF8)); - assertTrue(fromString("A").containsInLowerCase(fromString("a"))); - assertTrue(fromString("a").containsInLowerCase(fromString("A"))); - assertFalse(EMPTY_UTF8.containsInLowerCase(fromString("a"))); - // ASCII - assertTrue(fromString("hello").containsInLowerCase(fromString("ello"))); - assertFalse(fromString("hello").containsInLowerCase(fromString("vello"))); - assertFalse(fromString("hello").containsInLowerCase(fromString("hellooo"))); - // Unicode - assertTrue(fromString("大千世界").containsInLowerCase(fromString("千世界"))); - assertFalse(fromString("大千世界").containsInLowerCase(fromString("世千"))); - assertFalse(fromString("大千世界").containsInLowerCase(fromString("大千世界好"))); - // ASCII lowercase - assertTrue(fromString("HeLlO").containsInLowerCase(fromString("ElL"))); - assertFalse(fromString("HeLlO").containsInLowerCase(fromString("ElLoO"))); - // Unicode lowercase - assertTrue(fromString("ЯбЛоКо").containsInLowerCase(fromString("БлОк"))); - assertFalse(fromString("ЯбЛоКо").containsInLowerCase(fromString("лОкБ"))); - // Characters with the same binary lowercase representation - assertTrue(fromString("The Kelvin.").containsInLowerCase(fromString("Kelvin"))); - assertTrue(fromString("The Kelvin.").containsInLowerCase(fromString("Kelvin"))); - assertTrue(fromString("The KKelvin.").containsInLowerCase(fromString("KKelvin"))); - assertTrue(fromString("2 Kelvin.").containsInLowerCase(fromString("2 Kelvin"))); - assertTrue(fromString("2 Kelvin.").containsInLowerCase(fromString("2 Kelvin"))); - assertFalse(fromString("The KKelvin.").containsInLowerCase(fromString("KKelvin,"))); - // Characters with longer binary lowercase representation - assertTrue(fromString("the İodine").containsInLowerCase(fromString("the i̇odine"))); - assertTrue(fromString("the i̇odine").containsInLowerCase(fromString("the İodine"))); - assertTrue(fromString("The İodiNe").containsInLowerCase(fromString(" i̇oDin"))); - assertTrue(fromString("İodiNe").containsInLowerCase(fromString("i̇oDin"))); - 
assertFalse(fromString("İodiNe").containsInLowerCase(fromString(" i̇oDin"))); - } - @Test public void startsWith() { assertTrue(EMPTY_UTF8.startsWith(EMPTY_UTF8)); @@ -263,40 +226,6 @@ public void startsWith() { assertFalse(fromString("大千世界").startsWith(fromString("大千世界好"))); } - @Test - public void startsWithInLowerCase() { - // Corner cases - assertTrue(EMPTY_UTF8.startsWithInLowerCase(EMPTY_UTF8)); - assertTrue(fromString("a").startsWithInLowerCase(EMPTY_UTF8)); - assertTrue(fromString("A").startsWithInLowerCase(fromString("a"))); - assertTrue(fromString("a").startsWithInLowerCase(fromString("A"))); - assertFalse(EMPTY_UTF8.startsWithInLowerCase(fromString("a"))); - // ASCII - assertTrue(fromString("hello").startsWithInLowerCase(fromString("hell"))); - assertFalse(fromString("hello").startsWithInLowerCase(fromString("ell"))); - // Unicode - assertTrue(fromString("大千世界").startsWithInLowerCase(fromString("大千"))); - assertFalse(fromString("大千世界").startsWithInLowerCase(fromString("世千"))); - // ASCII lowercase - assertTrue(fromString("HeLlO").startsWithInLowerCase(fromString("hElL"))); - assertFalse(fromString("HeLlO").startsWithInLowerCase(fromString("ElL"))); - // Unicode lowercase - assertTrue(fromString("ЯбЛоКо").startsWithInLowerCase(fromString("яБлОк"))); - assertFalse(fromString("ЯбЛоКо").startsWithInLowerCase(fromString("БлОк"))); - // Characters with the same binary lowercase representation - assertTrue(fromString("Kelvin.").startsWithInLowerCase(fromString("Kelvin"))); - assertTrue(fromString("Kelvin.").startsWithInLowerCase(fromString("Kelvin"))); - assertTrue(fromString("KKelvin.").startsWithInLowerCase(fromString("KKelvin"))); - assertTrue(fromString("2 Kelvin.").startsWithInLowerCase(fromString("2 Kelvin"))); - assertTrue(fromString("2 Kelvin.").startsWithInLowerCase(fromString("2 Kelvin"))); - assertFalse(fromString("KKelvin.").startsWithInLowerCase(fromString("KKelvin,"))); - // Characters with longer binary lowercase representation - assertTrue(fromString("the İodine").startsWithInLowerCase(fromString("the i̇odine"))); - assertTrue(fromString("the i̇odine").startsWithInLowerCase(fromString("the İodine"))); - assertTrue(fromString("İodiNe").startsWithInLowerCase(fromString("i̇oDin"))); - assertFalse(fromString("The İodiNe").startsWithInLowerCase(fromString("i̇oDin"))); - } - @Test public void endsWith() { assertTrue(EMPTY_UTF8.endsWith(EMPTY_UTF8)); @@ -308,40 +237,6 @@ public void endsWith() { assertFalse(fromString("数据砖头").endsWith(fromString("我的数据砖头"))); } - @Test - public void endsWithInLowerCase() { - // Corner cases - assertTrue(EMPTY_UTF8.endsWithInLowerCase(EMPTY_UTF8)); - assertTrue(fromString("a").endsWithInLowerCase(EMPTY_UTF8)); - assertTrue(fromString("A").endsWithInLowerCase(fromString("a"))); - assertTrue(fromString("a").endsWithInLowerCase(fromString("A"))); - assertFalse(EMPTY_UTF8.endsWithInLowerCase(fromString("a"))); - // ASCII - assertTrue(fromString("hello").endsWithInLowerCase(fromString("ello"))); - assertFalse(fromString("hello").endsWithInLowerCase(fromString("hell"))); - // Unicode - assertTrue(fromString("大千世界").endsWithInLowerCase(fromString("世界"))); - assertFalse(fromString("大千世界").endsWithInLowerCase(fromString("大千"))); - // ASCII lowercase - assertTrue(fromString("HeLlO").endsWithInLowerCase(fromString("ElLo"))); - assertFalse(fromString("HeLlO").endsWithInLowerCase(fromString("hElL"))); - // Unicode lowercase - assertTrue(fromString("ЯбЛоКо").endsWithInLowerCase(fromString("БлОкО"))); - 
assertFalse(fromString("ЯбЛоКо").endsWithInLowerCase(fromString("яБлОк"))); - // Characters with the same binary lowercase representation - assertTrue(fromString("The Kelvin").endsWithInLowerCase(fromString("Kelvin"))); - assertTrue(fromString("The Kelvin").endsWithInLowerCase(fromString("Kelvin"))); - assertTrue(fromString("The KKelvin").endsWithInLowerCase(fromString("KKelvin"))); - assertTrue(fromString("The 2 Kelvin").endsWithInLowerCase(fromString("2 Kelvin"))); - assertTrue(fromString("The 2 Kelvin").endsWithInLowerCase(fromString("2 Kelvin"))); - assertFalse(fromString("The KKelvin").endsWithInLowerCase(fromString("KKelvin,"))); - // Characters with longer binary lowercase representation - assertTrue(fromString("the İodine").endsWithInLowerCase(fromString("the i̇odine"))); - assertTrue(fromString("the i̇odine").endsWithInLowerCase(fromString("the İodine"))); - assertTrue(fromString("The İodiNe").endsWithInLowerCase(fromString("i̇oDine"))); - assertFalse(fromString("The İodiNe").endsWithInLowerCase(fromString("i̇oDin"))); - } - @Test public void substring() { assertEquals(EMPTY_UTF8, fromString("hello").substring(0, 0));