diff --git a/onetl/_util/java.py b/onetl/_util/java.py
index df88b1a59..dfcfccfa1 100644
--- a/onetl/_util/java.py
+++ b/onetl/_util/java.py
@@ -4,6 +4,9 @@
 from typing import TYPE_CHECKING
 
+from onetl._util.spark import get_spark_version
+from onetl._util.version import Version
+
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaGateway
     from pyspark.sql import SparkSession
@@ -24,3 +27,34 @@ def try_import_java_class(spark_session: SparkSession, name: str):
     klass = getattr(gateway.jvm, name)
     gateway.help(klass, display=False)
     return klass
+
+
+def start_callback_server(spark_session: SparkSession):
+    """
+    Start Py4J callback server. Important to receive Java events on Python side,
+    e.g. in Spark Listener implementations.
+    """
+    gateway = get_java_gateway(spark_session)
+    if get_spark_version(spark_session) >= Version("2.4"):
+        from pyspark.java_gateway import ensure_callback_server_started
+
+        ensure_callback_server_started(gateway)
+        return
+
+    # Spark 2.3
+    if "_callback_server" not in gateway.__dict__ or gateway._callback_server is None:
+        from py4j.java_gateway import JavaObject
+
+        gateway.callback_server_parameters.eager_load = True
+        gateway.callback_server_parameters.daemonize = True
+        gateway.callback_server_parameters.daemonize_connections = True
+        gateway.callback_server_parameters.port = 0
+        gateway.start_callback_server(gateway.callback_server_parameters)
+        cbport = gateway._callback_server.server_socket.getsockname()[1]
+        gateway._callback_server.port = cbport
+        # gateway with real port
+        gateway._python_proxy_port = gateway._callback_server.port
+        # get the GatewayServer object in JVM by ID
+        java_gateway = JavaObject("GATEWAY_SERVER", gateway._gateway_client)
+        # update the port of CallbackClient with real port
+        java_gateway.resetCallbackClient(java_gateway.getCallbackClient().getAddress(), gateway._python_proxy_port)
diff --git a/onetl/_util/scala.py b/onetl/_util/scala.py
index 397a91576..5e6c21bc1 100644
--- a/onetl/_util/scala.py
+++ b/onetl/_util/scala.py
@@ -12,3 +12,10 @@ def get_default_scala_version(spark_version: Version) -> Version:
     if spark_version.major < 3:
         return Version("2.11")
     return Version("2.12")
+
+
+def scala_seq_to_python_list(seq) -> list:
+    result = []
+    for i in range(seq.size()):
+        result.append(seq.apply(i))
+    return result
diff --git a/onetl/base/base_db_connection.py b/onetl/base/base_db_connection.py
index f9c7bcac0..2c427debd 100644
--- a/onetl/base/base_db_connection.py
+++ b/onetl/base/base_db_connection.py
@@ -10,7 +10,7 @@
 if TYPE_CHECKING:
     from etl_entities.hwm import HWM
-    from pyspark.sql import DataFrame
+    from pyspark.sql import DataFrame, SparkSession
     from pyspark.sql.types import StructField, StructType
@@ -106,6 +106,7 @@ class BaseDBConnection(BaseConnection):
     Implements generic methods for reading and writing dataframe from/to database-like source
     """
 
+    spark: SparkSession
     Dialect = BaseDBDialect
 
     @property
diff --git a/onetl/base/base_file_df_connection.py b/onetl/base/base_file_df_connection.py
index c54390ce8..28c57f3c7 100644
--- a/onetl/base/base_file_df_connection.py
+++ b/onetl/base/base_file_df_connection.py
@@ -11,7 +11,7 @@
 from onetl.base.pure_path_protocol import PurePathProtocol
 
 if TYPE_CHECKING:
-    from pyspark.sql import DataFrame, DataFrameReader, DataFrameWriter
+    from pyspark.sql import DataFrame, DataFrameReader, DataFrameWriter, SparkSession
     from pyspark.sql.types import StructType
@@ -72,6 +72,8 @@ class BaseFileDFConnection(BaseConnection):
     .. versionadded:: 0.9.0
     """
 
+    spark: SparkSession
+
     @abstractmethod
     def check_if_format_supported(
         self,
diff --git a/onetl/db/db_writer/__init__.py b/onetl/db/db_writer/__init__.py
index b181c7f04..c82d2e200 100644
--- a/onetl/db/db_writer/__init__.py
+++ b/onetl/db/db_writer/__init__.py
@@ -1,3 +1,4 @@
 # SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
 from onetl.db.db_writer.db_writer import DBWriter
+from onetl.db.db_writer.result import DBWriterResult
diff --git a/onetl/db/db_writer/db_writer.py b/onetl/db/db_writer/db_writer.py
index 666fce87e..da3519128 100644
--- a/onetl/db/db_writer/db_writer.py
+++ b/onetl/db/db_writer/db_writer.py
@@ -5,6 +5,9 @@
 from logging import getLogger
 from typing import TYPE_CHECKING, Optional
 
+from onetl.db.db_writer.result import DBWriterResult
+from onetl.metrics.collector import SparkMetricsCollector
+
 try:
     from pydantic.v1 import Field, PrivateAttr, validator
 except (ImportError, AttributeError):
@@ -16,6 +19,7 @@
 from onetl.log import (
     entity_boundary_log,
     log_dataframe_schema,
+    log_lines,
     log_options,
     log_with_indent,
 )
@@ -172,7 +176,7 @@ def validate_options(cls, options, values):
         return None
 
     @slot
-    def run(self, df: DataFrame):
+    def run(self, df: DataFrame) -> DBWriterResult:
         """
         Method for writing your df to specified target. |support_hooks|
 
@@ -185,33 +189,50 @@
         df : pyspark.sql.dataframe.DataFrame
             Spark dataframe
 
+        Returns
+        -------
+        :obj:`DBWriterResult <onetl.db.db_writer.result.DBWriterResult>`
+
+            DBWriter result object
+
         Examples
         --------
 
-        Write df to target:
+        Write dataframe to target:
 
        .. code:: python

-            writer.run(df)
+            result = writer.run(df)
         """
         if df.isStreaming:
             raise ValueError(f"DataFrame is streaming. {self.__class__.__name__} supports only batch DataFrames.")
 
         entity_boundary_log(log, msg=f"{self.__class__.__name__}.run() starts")
-
         if not self._connection_checked:
             self._log_parameters()
             log_dataframe_schema(log, df)
             self.connection.check()
             self._connection_checked = True
 
-        self.connection.write_df_to_target(
-            df=df,
-            target=str(self.target),
-            **self._get_write_kwargs(),
-        )
+        with SparkMetricsCollector(self.connection.spark) as collector:
+            try:
+                self.connection.write_df_to_target(
+                    df=df,
+                    target=str(self.target),
+                    **self._get_write_kwargs(),
+                )
+            except Exception:
+                log.exception(
+                    "|%s| Error while writing dataframe. Target may contain partially written data!",
+                    self.__class__.__name__,
+                )
+                raise
+            finally:
+                result = DBWriterResult(metrics=collector.recorded_metrics)
+                self._log_result(result)
 
         entity_boundary_log(log, msg=f"{self.__class__.__name__}.run() ends", char="-")
+        return result
 
     def _log_parameters(self) -> None:
         log.info("|Spark| -> |%s| Writing DataFrame to target using parameters:", self.connection.__class__.__name__)
@@ -225,3 +246,8 @@ def _get_write_kwargs(self) -> dict:
             return {"options": self.options}
 
         return {}
+
+    def _log_result(self, result: DBWriterResult) -> None:
+        log_with_indent(log, "")
+        log.info("|%s| Write result:", self.__class__.__name__)
+        log_lines(log, str(result))
diff --git a/onetl/db/db_writer/result.py b/onetl/db/db_writer/result.py
new file mode 100644
index 000000000..1b6a49c24
--- /dev/null
+++ b/onetl/db/db_writer/result.py
@@ -0,0 +1,116 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+import os
+import textwrap
+
+from onetl.impl import BaseModel
+from onetl.metrics.metrics import SparkMetrics
+
+INDENT = " " * 4
+
+
+class DBWriterResult(BaseModel):
+    """
+    Representation of DBWriter result.
+
+    .. versionadded:: 0.12.0
+
+    Examples
+    --------
+
+    >>> from onetl.db import DBWriter
+    >>> writer = DBWriter(...)
+    >>> result = writer.run(df)
+    >>> result
+    DBWriterResult(
+        metrics=SparkMetrics(
+            input=SparkInputMetrics(
+                read_rows=1_000,
+                read_files=10,
+                read_bytes=1_000_000,
+                scan_bytes=2_000_000,
+                read_partitions=3,
+                dynamic_partition_pruning=True,
+            ),
+            output=SparkOutputMetrics(
+                written_rows=1_000,
+                created_files=10,
+                written_bytes=1_000_000,
+                created_dynamic_partitions=1,
+            ),
+            executor=SparkExecutorMetrics(
+                run_time_milliseconds=1_000,
+                cpu_time_nanoseconds=2_000_000_000,
+                peak_memory_bytes=1_000_000_000,
+            ),
+        )
+    )
+    """
+
+    metrics: SparkMetrics
+
+    @property
+    def details(self) -> str:
+        """
+        Return summarized information about the result object.
+
+        Examples
+        --------
+
+        >>> from onetl.db.db_writer import DBWriterResult
+        >>> from onetl.metrics import SparkMetrics, SparkOutputMetrics, SparkInputMetrics, SparkExecutorMetrics
+        >>> result1 = DBWriterResult(
+        ...     metrics=SparkMetrics(
+        ...         input=SparkInputMetrics(
+        ...             read_rows=1_000,
+        ...             read_files=10,
+        ...             read_bytes=1_000_000,
+        ...             scan_bytes=2_000_000,
+        ...             read_partitions=3,
+        ...             dynamic_partition_pruning=True,
+        ...         ),
+        ...         output=SparkOutputMetrics(
+        ...             written_bytes=1_000_000,
+        ...             written_rows=1_000,
+        ...             created_files=10,
+        ...             created_dynamic_partitions=1,
+        ...         ),
+        ...         executor=SparkExecutorMetrics(
+        ...             run_time_milliseconds=1_000,
+        ...             cpu_time_nanoseconds=2_000_000_000,
+        ...             peak_memory_bytes=1_000_000_000,
+        ...         ),
+        ...     )
+        ... )
+        >>> print(result1.details)
+        Metrics:
+            Input:
+                Read rows: 1000
+                Read files: 10
+                Read size: 1.0 MB
+                Scan size: 2.0 MB
+                Dynamic partition pruning: True
+                Read partitions: 3
+            Output:
+                Written rows: 1000
+                Written files: 10
+                Created dynamic partitions: 1
+                Written size: 1.0 MB
+            Executor:
+                Run time: 1 second
+                CPU time: 2 seconds
+                Peak memory: 1.0 GB
+        >>> result2 = DBWriterResult(metrics=SparkMetrics())
+        >>> print(result2.details)
+        Metrics: No data
+        """
+        if self.metrics.is_empty:
+            return "Metrics: No data"
+
+        return "Metrics:" + os.linesep + textwrap.indent(self.metrics.details, INDENT)
+
+    def __str__(self):
+        """Same as :obj:`onetl.db.db_writer.result.DBWriterResult.details`"""
+        return self.details
diff --git a/onetl/file/file_df_writer/file_df_writer.py b/onetl/file/file_df_writer/file_df_writer.py
index a80f54801..f1e5c4ade 100644
--- a/onetl/file/file_df_writer/file_df_writer.py
+++ b/onetl/file/file_df_writer/file_df_writer.py
@@ -5,6 +5,9 @@
 import logging
 from typing import TYPE_CHECKING
 
+from onetl.db.db_writer.result import DBWriterResult
+from onetl.metrics.collector import SparkMetricsCollector
+
 try:
     from pydantic.v1 import PrivateAttr, validator
 except (ImportError, AttributeError):
@@ -17,6 +20,7 @@
 from onetl.log import (
     entity_boundary_log,
     log_dataframe_schema,
+    log_lines,
     log_options,
     log_with_indent,
 )
@@ -93,7 +97,7 @@ class FileDFWriter(FrozenModel):
     _connection_checked: bool = PrivateAttr(default=False)
 
     @slot
-    def run(self, df: DataFrame) -> None:
+    def run(self, df: DataFrame) -> DBWriterResult:
         """
         Method for writing DataFrame as files. |support_hooks|
 
@@ -125,14 +129,26 @@ def run(self, df: DataFrame) -> DBWriterResult:
             self.connection.check()
             self._connection_checked = True
 
-        self.connection.write_df_as_files(
-            df=df,
-            path=self.target_path,
-            format=self.format,
-            options=self.options,
-        )
+        with SparkMetricsCollector(self.connection.spark) as collector:
+            try:
+                self.connection.write_df_as_files(
+                    df=df,
+                    path=self.target_path,
+                    format=self.format,
+                    options=self.options,
+                )
+            except Exception:
+                log.exception(
+                    "|%s| Error while writing dataframe. Target may contain partially written data!",
+                    self.__class__.__name__,
+                )
+                raise
+            finally:
+                result = DBWriterResult(metrics=collector.recorded_metrics)
+                self._log_result(result)
 
         entity_boundary_log(log, f"{self.__class__.__name__}.run() ends", char="-")
+        return result
 
     def _log_parameters(self, df: DataFrame) -> None:
         log.info("|Spark| -> |%s| Writing dataframe using parameters:", self.connection.__class__.__name__)
@@ -143,6 +158,11 @@ def _log_parameters(self, df: DataFrame) -> None:
         log_options(log, options_dict)
         log_dataframe_schema(log, df)
 
+    def _log_result(self, result: DBWriterResult) -> None:
+        log_with_indent(log, "")
+        log.info("|%s| Write result:", self.__class__.__name__)
+        log_lines(log, str(result))
+
     @validator("target_path", pre=True)
     def _validate_target_path(cls, target_path, values):
         connection: BaseFileDFConnection = values["connection"]
diff --git a/onetl/metrics/__init__.py b/onetl/metrics/__init__.py
new file mode 100644
index 000000000..39eac747c
--- /dev/null
+++ b/onetl/metrics/__init__.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from onetl.metrics.collector import SparkMetricsCollector
+from onetl.metrics.executor import SparkExecutorMetrics
+from onetl.metrics.input import SparkInputMetrics
+from onetl.metrics.metrics import SparkMetrics
+from onetl.metrics.output import SparkOutputMetrics
+
+__all__ = [
+    "SparkMetrics",
+    "SparkMetricsCollector",
+    "SparkExecutorMetrics",
+    "SparkInputMetrics",
+    "SparkOutputMetrics",
+]
diff --git a/onetl/metrics/_listener/__init__.py b/onetl/metrics/_listener/__init__.py
new file mode 100644
index 000000000..47c7b892c
--- /dev/null
+++ b/onetl/metrics/_listener/__init__.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from onetl.metrics._listener.execution import (
+    SparkListenerExecution,
+    SparkListenerExecutionStatus,
+    SparkSQLMetricNames,
+)
+from onetl.metrics._listener.job import SparkListenerJob, SparkListenerJobStatus
+from onetl.metrics._listener.listener import SparkMetricsListener
+from onetl.metrics._listener.stage import SparkListenerStage, SparkListenerStageStatus
+from onetl.metrics._listener.task import (
+    SparkListenerTask,
+    SparkListenerTaskMetrics,
+    SparkListenerTaskStatus,
+)
+
+__all__ = [
+    "SparkListenerTask",
+    "SparkListenerTaskStatus",
+    "SparkListenerTaskMetrics",
+    "SparkListenerStage",
+    "SparkListenerStageStatus",
+    "SparkListenerJob",
+    "SparkListenerJobStatus",
+    "SparkListenerExecution",
+    "SparkListenerExecutionStatus",
+    "SparkSQLMetricNames",
+    "SparkMetricsListener",
+]
diff --git a/onetl/metrics/_listener/base.py b/onetl/metrics/_listener/base.py
new file mode 100644
index 000000000..26d605801
--- /dev/null
+++ b/onetl/metrics/_listener/base.py
@@ -0,0 +1,167 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+
+from onetl._util.java import get_java_gateway, start_callback_server
+
+
+@dataclass
+class BaseSparkListener:
+    """Base no-op SparkListener implementation.
+
+    See `SparkListener `_ interface.
+    """
+
+    spark: SparkSession
+
+    def __post_init__(self):
+        # passing python listener object directly to addSparkListener or removeSparkListener leads to creating new java object each time.
+        # But removeSparkListener call has effect only on the same Java object passed to removeSparkListener.
+        # So we need to explicitly create Java object, and then pass it to both calls.
+        gateway = get_java_gateway(self.spark)
+        java_list = gateway.jvm.java.util.ArrayList()
+        java_list.append(self)
+        self._java_listener = java_list[0]
+
+    def activate(self):
+        start_callback_server(self.spark)
+
+        spark_context = self.spark.sparkContext._jsc.sc()  # noqa: WPS437
+        spark_context.addSparkListener(self._java_listener)
+
+    def deactivate(self):
+        spark_context = self.spark.sparkContext._jsc.sc()  # noqa: WPS437
+        spark_context.removeSparkListener(self._java_listener)
+
+    def __enter__(self):
+        self.activate()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.deactivate()
+
+    # method names are important for Java interface compatibility!
+    def onApplicationEnd(self, application):
+        pass
+
+    def onApplicationStart(self, application):
+        pass
+
+    def onBlockManagerAdded(self, block_manager):
+        pass
+
+    def onBlockManagerRemoved(self, block_manager):
+        pass
+
+    def onBlockUpdated(self, block):
+        pass
+
+    def onEnvironmentUpdate(self, environment):
+        pass
+
+    def onExecutorAdded(self, executor):
+        pass
+
+    def onExecutorMetricsUpdate(self, executor):
+        pass
+
+    def onExecutorRemoved(self, executor):
+        pass
+
+    def onExecutorBlacklisted(self, event):
+        pass
+
+    def onExecutorBlacklistedForStage(self, event):
+        pass
+
+    def onExecutorExcluded(self, event):
+        pass
+
+    def onExecutorExcludedForStage(self, event):
+        pass
+
+    def onExecutorUnblacklisted(self, event):
+        pass
+
+    def onExecutorUnexcluded(self, event):
+        pass
+
+    def onJobStart(self, event):
+        pass
+
+    def onJobEnd(self, event):
+        pass
+
+    def onNodeBlacklisted(self, node):
+        pass
+
+    def onNodeBlacklistedForStage(self, stage):
+        pass
+
+    def onNodeExcluded(self, node):
+        pass
+
+    def onNodeExcludedForStage(self, node):
+        pass
+
+    def onNodeUnblacklisted(self, node):
+        pass
+
+    def onNodeUnexcluded(self, node):
+        pass
+
+    def onOtherEvent(self, event):
+        pass
+
+    def onResourceProfileAdded(self, resource_profile):
+        pass
+
+    def onSpeculativeTaskSubmitted(self, task):
+        pass
+
+    def onStageCompleted(self, event):
+        pass
+
+    def onStageExecutorMetrics(self, metrics):
+        pass
+
+    def onStageSubmitted(self, event):
+        pass
+
+    def onTaskEnd(self, event):
+        pass
+
+    def onTaskGettingResult(self, task):
+        pass
+
+    def onTaskStart(self, event):
+        pass
+
+    def onUnpersistRDD(self, rdd):
+        pass
+
+    def onUnschedulableTaskSetAdded(self, task_set):
+        pass
+
+    def onUnschedulableTaskSetRemoved(self, task_set):
+        pass
+
+    def equals(self, other):
+        # Java does not provide proper way to get object id for comparison,
+        # so we compare string representation
+        return other.toString() == self._java_listener.toString()
+
+    def toString(self):
+        return type(self).__qualname__ + "@" + hex(id(self))
+
+    def hashCode(self):
+        return hash(self)
+
+    class Java:
+        implements = ["org.apache.spark.scheduler.SparkListenerInterface"]
diff --git a/onetl/metrics/_listener/execution.py b/onetl/metrics/_listener/execution.py
new file mode 100644
index 000000000..04d43d7d3
--- /dev/null
+++ b/onetl/metrics/_listener/execution.py
@@ -0,0 +1,99 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+from collections import defaultdict
+from dataclasses import dataclass, field
+from enum import Enum
+
+from onetl.metrics._listener.job import SparkListenerJob, SparkListenerJobStatus
+
+
+class SparkListenerExecutionStatus(str, Enum):
+    STARTED = "STARTED"
+    COMPLETE = "COMPLETE"
+    FAILED = "FAILED"
+
+    def __str__(self):
+        return self.value
+
+
+class SparkSQLMetricNames(str, Enum):  # noqa: WPS338
+    # Metric names passed to SQLMetrics.createMetric(...)
+    # But only those we're interested in.
+    NUMBER_OF_DYNAMIC_PART = "number of dynamic part"
+    NUMBER_OF_FILES_READ = "number of files read"
+    NUMBER_OF_OUTPUT_ROWS = "number of output rows"
+    NUMBER_OF_PARTITIONS_READ = "number of partitions read"
+    NUMBER_OF_WRITTEN_FILES = "number of written files"
+    SIZE_OF_FILES_READ = "size of files read"
+    STATIC_NUMBER_OF_FILES_READ = "static number of files read"
+
+    def __str__(self):
+        return self.value
+
+
+@dataclass
+class SparkListenerExecution:
+    id: int
+    description: str | None = None
+    external_id: str | None = None
+    status: SparkListenerExecutionStatus = SparkListenerExecutionStatus.STARTED
+
+    # These metrics are emitted by any command performed within this execution, so we can have multiple values.
+    # Some metrics can be summarized, but some not, so we store a list.
+    metrics: dict[SparkSQLMetricNames, list[str]] = field(default_factory=lambda: defaultdict(list), repr=False)
+
+    _jobs: dict[int, SparkListenerJob] = field(default_factory=dict, repr=False, init=False)
+
+    @property
+    def jobs(self) -> list[SparkListenerJob]:
+        result = []
+        for job_id in sorted(self._jobs.keys()):
+            result.append(self._jobs[job_id])
+        return result
+
+    def on_execution_start(self, event):
+        # https://github.com/apache/spark/blob/v3.5.1/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala#L44-L58
+        self.status = SparkListenerExecutionStatus.STARTED
+
+    def on_execution_end(self, event):
+        # https://github.com/apache/spark/blob/v3.5.1/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala#L61-L83
+        for job in self._jobs.values():
+            if job.status == SparkListenerJobStatus.FAILED:
+                self.status = SparkListenerExecutionStatus.FAILED
+                break
+        else:
+            self.status = SparkListenerExecutionStatus.COMPLETE
+
+    def on_job_start(self, event):
+        job_id = event.jobId()
+        self._jobs[job_id] = SparkListenerJob.create(event)
+
+    def on_job_end(self, event):
+        job_id = event.jobId()
+        job = self._jobs.get(job_id)
+
+        if job:
+            job.on_job_end(event)
+
+        # in some cases Execution consists of just one job with same id
+        if job_id == self.id:
+            self.on_execution_end(event)
+
+    # push down events
+    def on_stage_start(self, event):
+        for job in self._jobs.values():
+            job.on_stage_start(event)
+
+    def on_stage_end(self, event):
+        for job in self._jobs.values():
+            job.on_stage_end(event)
+
+    def on_task_start(self, event):
+        for job in self._jobs.values():
+            job.on_task_start(event)
+
+    def on_task_end(self, event):
+        for job in self._jobs.values():
+            job.on_task_end(event)
diff --git a/onetl/metrics/_listener/job.py b/onetl/metrics/_listener/job.py
new file mode 100644
index 000000000..fcfb6a317
--- /dev/null
+++ b/onetl/metrics/_listener/job.py
@@ -0,0 +1,87 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from enum import Enum
+
+from onetl._util.scala import scala_seq_to_python_list
+from onetl.metrics._listener.stage import SparkListenerStage, SparkListenerStageStatus
+
+
+class SparkListenerJobStatus(str, Enum):
+    RUNNING = "RUNNING"
+    SUCCEEDED = "SUCCEEDED"
+    FAILED = "FAILED"
+    UNKNOWN = "UNKNOWN"
+
+    def __str__(self):
+        return self.value
+
+
+@dataclass
+class SparkListenerJob:
+    id: int
+    description: str | None = None
+    group_id: str | None = None
+    call_site: str | None = None
+    status: SparkListenerJobStatus = SparkListenerJobStatus.UNKNOWN
+
+    _stages: dict[int, SparkListenerStage] = field(default_factory=dict, repr=False, init=False)
+
+    @property
+    def stages(self) -> list[SparkListenerStage]:
+        result = []
+        for stage_id in sorted(self._stages.keys()):
+            result.append(self._stages[stage_id])
+        return result
+
+    @classmethod
+    def create(cls, event):
+        # https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/scheduler/SparkListenerJobSubmitted.html
+        # https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/scheduler/SparkListenerJobCompleted.html
+        result = cls(
+            id=event.jobId(),
+            description=event.properties().get("spark.job.description"),
+            group_id=event.properties().get("spark.jobGroup.id"),
+            call_site=event.properties().get("callSite.short"),
+        )
+
+        stage_ids = scala_seq_to_python_list(event.stageIds())
+        stage_infos = scala_seq_to_python_list(event.stageInfos())
+        for stage_id, stage_info in zip(stage_ids, stage_infos):
+            result._stages[stage_id] = SparkListenerStage.create(stage_info)  # noqa: WPS437
+
+        return result
+
+    def on_job_start(self, event):
+        self.status = SparkListenerJobStatus.RUNNING
+
+    def on_job_end(self, event):
+        for stage in self._stages.values():
+            if stage.status == SparkListenerStageStatus.FAILED:
+                self.status = SparkListenerJobStatus.FAILED
+                break
+        else:
+            self.status = SparkListenerJobStatus.SUCCEEDED
+
+    def on_stage_start(self, event):
+        stage_id = event.stageInfo().stageId()
+        stage = self._stages.get(stage_id)
+        if stage:
+            stage.on_stage_start(event)
+
+    def on_stage_end(self, event):
+        stage_id = event.stageInfo().stageId()
+        stage = self._stages.get(stage_id)
+        if stage:
+            stage.on_stage_end(event)
+
+    # push down events
+    def on_task_start(self, event):
+        for stage in self._stages.values():
+            stage.on_task_start(event)
+
+    def on_task_end(self, event):
+        for stage in self._stages.values():
+            stage.on_task_end(event)
diff --git a/onetl/metrics/_listener/listener.py b/onetl/metrics/_listener/listener.py
new file mode 100644
index 000000000..3d1d63d22
--- /dev/null
+++ b/onetl/metrics/_listener/listener.py
@@ -0,0 +1,133 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from threading import current_thread
+from typing import ClassVar
+
+from onetl.metrics._listener.base import BaseSparkListener
+from onetl.metrics._listener.execution import (
+    SparkListenerExecution,
+    SparkSQLMetricNames,
+)
+
+
+@dataclass
+class SparkMetricsListener(BaseSparkListener):
+    THREAD_ID_KEY = "python.thread.id"
+    SQL_START_CLASS_NAME: ClassVar[str] = "org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart"
+    SQL_STOP_CLASS_NAME: ClassVar[str] = "org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd"
+
+    _thread_id: str = field(default_factory=lambda: str(current_thread().ident), repr=False, init=False)
+    _recorded_executions: dict[int, SparkListenerExecution] = field(default_factory=dict, repr=False, init=False)
+
+    def activate(self):
+        # we cannot override execution_id property as it is set by Spark
+        # we also cannot use job tags, as they were implemented only in Spark 3.5+
+        self.spark.sparkContext.setLocalProperty(self.THREAD_ID_KEY, self._thread_id)
+        return super().activate()
+
+    def reset(self):
+        self._recorded_executions.clear()
+        return self
+
+    @property
+    def executions(self):
+        return [
+            execution for execution in self._recorded_executions.values() if execution.external_id == self._thread_id
+        ]
+
+    def __enter__(self):
+        """Record only executions performed by current Spark thread.
+
+        It is important to use this method only in combination with
+        :obj:`pyspark.util.InheritableThread` to preserve thread-local variables
+        between Python thread and Java thread.
+        """
+        self.reset()
+        return super().__enter__()
+
+    def onOtherEvent(self, event):
+        class_name = event.getClass().getName()
+        if class_name == self.SQL_START_CLASS_NAME:
+            self.onExecutionStart(event)
+        elif class_name == self.SQL_STOP_CLASS_NAME:
+            self.onExecutionEnd(event)
+
+    def onExecutionStart(self, event):
+        execution_id = event.executionId()
+        description = event.description()
+        execution = SparkListenerExecution(
+            id=execution_id,
+            description=description,
+        )
+        self._recorded_executions[execution_id] = execution
+        execution.on_execution_start(event)
+
+    def onExecutionEnd(self, event):
+        execution_id = event.executionId()
+        execution = self._recorded_executions.get(execution_id)
+        if execution:
+            execution.on_execution_end(event)
+
+            # Get execution metrics from SQLAppStatusStore,
+            # as SparkListenerSQLExecutionEnd event does not provide them:
+            # https://github.com/apache/spark/blob/v3.5.1/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
+            session_status_store = self.spark._jsparkSession.sharedState().statusStore()  # noqa: WPS437
+            raw_execution = session_status_store.execution(execution.id).get()
+            metrics = raw_execution.metrics()
+            metric_values = session_status_store.executionMetrics(execution.id)
+            for i in range(metrics.size()):
+                metric = metrics.apply(i)
+                metric_name = metric.name()
+                if metric_name not in SparkSQLMetricNames:
+                    continue
+                metric_value = metric_values.get(metric.accumulatorId())
+                if not metric_value.isDefined():
+                    continue
+                execution.metrics[SparkSQLMetricNames(metric_name)].append(metric_value.get())
+
+    def onJobStart(self, event):
+        execution_id = event.properties().get("spark.sql.execution.id")
+        execution_thread_id = event.properties().get(self.THREAD_ID_KEY)
+        if execution_id is None:
+            # single job execution
+            job_id = event.jobId()
+            execution = SparkListenerExecution(
+                id=job_id,
+                description=event.properties().get("spark.job.description"),
+                external_id=execution_thread_id,
+            )
+            self._recorded_executions[job_id] = execution
+        else:
+            execution = self._recorded_executions.get(int(execution_id))
+            if execution is None:
+                return
+
+            if execution_thread_id:
+                # SparkListenerSQLExecutionStart does not have properties, but SparkListenerJobStart does,
+                # use it as a source of external_id
+                execution.external_id = execution_thread_id
+
+        execution.on_job_start(event)
+
+    def onJobEnd(self, event):
+        for execution in self._recorded_executions.values():
+            execution.on_job_end(event)
+
+    def onStageSubmitted(self, event):
+        for execution in self._recorded_executions.values():
+            execution.on_stage_start(event)
+
+    def onStageCompleted(self, event):
+        for execution in self._recorded_executions.values():
+            execution.on_stage_end(event)
+
+    def onTaskStart(self, event):
+        for execution in self._recorded_executions.values():
+            execution.on_task_start(event)
+
+    def onTaskEnd(self, event):
+        for execution in self._recorded_executions.values():
+            execution.on_task_end(event)
diff --git a/onetl/metrics/_listener/stage.py b/onetl/metrics/_listener/stage.py
new file mode 100644
index 000000000..426328a4b
--- /dev/null
+++ b/onetl/metrics/_listener/stage.py
@@ -0,0 +1,68 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from enum import Enum
+
+from onetl.metrics._listener.task import SparkListenerTask, SparkListenerTaskMetrics
+
+
+class SparkListenerStageStatus(str, Enum):
+    ACTIVE = "ACTIVE"
+    COMPLETE = "COMPLETE"
+    FAILED = "FAILED"
+    PENDING = "PENDING"
+    SKIPPED = "SKIPPED"
+
+    def __str__(self):
+        return self.value
+
+
+@dataclass
+class SparkListenerStage:
+    # https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/scheduler/StageInfo.html
+    id: int
+    status: SparkListenerStageStatus = SparkListenerStageStatus.PENDING
+    metrics: SparkListenerTaskMetrics = field(default_factory=SparkListenerTaskMetrics, repr=False, init=False)
+    _tasks: dict[int, SparkListenerTask] = field(default_factory=dict, repr=False, init=False)
+
+    @property
+    def tasks(self) -> list[SparkListenerTask]:
+        result = []
+        for task_id in sorted(self._tasks.keys()):
+            result.append(self._tasks[task_id])
+        return result
+
+    @classmethod
+    def create(cls, stage_info):
+        return cls(id=stage_info.stageId())
+
+    def on_stage_start(self, event):
+        # https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/scheduler/SparkListenerStageSubmitted.html
+        self.status = SparkListenerStageStatus.ACTIVE
+
+    def on_stage_end(self, event):
+        # https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/scheduler/SparkListenerStageCompleted.html
+        stage_info = event.stageInfo()
+        if stage_info.failureReason().isDefined():
+            self.status = SparkListenerStageStatus.FAILED
+        elif not self.tasks:
+            self.status = SparkListenerStageStatus.SKIPPED
+        else:
+            self.status = SparkListenerStageStatus.COMPLETE
+
+        metrics = stage_info.taskMetrics()
+        if metrics:
+            self.metrics = SparkListenerTaskMetrics.create(metrics)
+
+    def on_task_start(self, event):
+        task_info = event.taskInfo()
+        task_id = task_info.taskId()
+        self._tasks[task_id] = SparkListenerTask.create(task_info)
+
+    def on_task_end(self, event):
+        task_id = event.taskInfo().taskId()
+        task = self._tasks.get(task_id)
+        if task:
+            task.on_task_end(event)
diff --git a/onetl/metrics/_listener/task.py b/onetl/metrics/_listener/task.py
new file mode 100644
index 000000000..cce25330c
--- /dev/null
+++ b/onetl/metrics/_listener/task.py
@@ -0,0 +1,88 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from enum import Enum
+
+
+class SparkListenerTaskStatus(str, Enum):
+    PENDING = "PENDING"
+    RUNNING = "RUNNING"
+    SUCCESS = "SUCCESS"
+    FAILED = "FAILED"
+    KILLED = "KILLED"
+
+    def __str__(self):
+        return self.value
+
+
+@dataclass
+class SparkListenerTaskInputMetrics:
+    bytes_read: int = 0
+    records_read: int = 0
+
+    @classmethod
+    def create(cls, task_input_metrics):
+        return cls(
+            bytes_read=task_input_metrics.bytesRead(),
+            records_read=task_input_metrics.recordsRead(),
+        )
+
+
+@dataclass
+class SparkListenerTaskOutputMetrics:
+    bytes_written: int = 0
+    records_written: int = 0
+
+    @classmethod
+    def create(cls, task_output_metrics):
+        return cls(
+            bytes_written=task_output_metrics.bytesWritten(),
+            records_written=task_output_metrics.recordsWritten(),
+        )
+
+
+@dataclass
+class SparkListenerTaskMetrics:
+    """Python representation of Spark TaskMetrics object.
+
+    See `documentation `_.
+    """
+
+    executor_run_time_milliseconds: int = 0
+    executor_cpu_time_nanoseconds: int = 0
+    peak_execution_memory_bytes: int = 0
+    input_metrics: SparkListenerTaskInputMetrics = field(default_factory=SparkListenerTaskInputMetrics)
+    output_metrics: SparkListenerTaskOutputMetrics = field(default_factory=SparkListenerTaskOutputMetrics)
+
+    @classmethod
+    def create(cls, task_metrics):
+        return cls(
+            executor_run_time_milliseconds=task_metrics.executorRunTime(),
+            executor_cpu_time_nanoseconds=task_metrics.executorCpuTime(),
+            peak_execution_memory_bytes=task_metrics.peakExecutionMemory(),
+            input_metrics=SparkListenerTaskInputMetrics.create(task_metrics.inputMetrics()),
+            output_metrics=SparkListenerTaskOutputMetrics.create(task_metrics.outputMetrics()),
+        )
+
+
+@dataclass
+class SparkListenerTask:
+    id: int
+    status: SparkListenerTaskStatus = SparkListenerTaskStatus.PENDING
+    metrics: SparkListenerTaskMetrics | None = field(default=None, repr=False, init=False)
+
+    @classmethod
+    def create(cls, task_info):
+        # https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/scheduler/TaskInfo.html
+        return cls(id=task_info.taskId())
+
+    def on_task_start(self, event):
+        # https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/scheduler/SparkListenerTaskStart.html
+        self.status = SparkListenerTaskStatus(event.taskInfo().status())
+
+    def on_task_end(self, event):
+        # https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/scheduler/SparkListenerTaskEnd.html
+        self.status = SparkListenerTaskStatus(event.taskInfo().status())
+        self.metrics = SparkListenerTaskMetrics.create(event.taskMetrics())
diff --git a/onetl/metrics/collector.py b/onetl/metrics/collector.py
new file mode 100644
index 000000000..27dcbc629
--- /dev/null
+++ b/onetl/metrics/collector.py
@@ -0,0 +1,107 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from onetl.metrics._listener.execution import (
+    SparkListenerExecution,
+    SparkSQLMetricNames,
+)
+from onetl.metrics._listener.listener import SparkMetricsListener
+from onetl.metrics._listener.stage import SparkListenerStageStatus
+from onetl.metrics.executor import SparkExecutorMetrics
+from onetl.metrics.input import SparkInputMetrics
+from onetl.metrics.metrics import SparkMetrics
+from onetl.metrics.output import SparkOutputMetrics
+
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+
+
+def _get_first_int(data: dict, key: Any) -> int | None:
+    if key not in data:
+        return None
+
+    return int(data[key][0])
+
+
+class SparkMetricsCollector:
+    def __init__(self, spark: SparkSession):
+        self._listener = SparkMetricsListener(spark=spark)
+
+    def __enter__(self) -> SparkMetricsCollector:
+        self._listener.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._listener.__exit__(exc_type, exc_val, exc_tb)
+
+    @property
+    def recorded_metrics(self) -> SparkMetrics:
+        result = SparkMetrics()
+        for execution in self._listener.executions:
+            result = result.merge(self._execution_to_metrics(execution))
+        return result
+
+    def _execution_to_metrics(self, execution: SparkListenerExecution) -> SparkMetrics:
+        run_time_milliseconds: int = 0
+        cpu_time_nanoseconds: int = 0
+        peak_memory_bytes: int = 0
+
+        input_read_bytes: int = 0
+        input_read_rows: int = 0
+        output_bytes: int = 0
+        output_rows: int = 0
+
+        for job in execution.jobs:
+            for stage in job.stages:
+                run_time_milliseconds += stage.metrics.executor_run_time_milliseconds
+                cpu_time_nanoseconds += stage.metrics.executor_cpu_time_nanoseconds
+                peak_memory_bytes = max(peak_memory_bytes, stage.metrics.peak_execution_memory_bytes)
+
+                if stage.status == SparkListenerStageStatus.COMPLETE:
+                    input_read_bytes += stage.metrics.input_metrics.bytes_read
+                    input_read_rows += stage.metrics.input_metrics.records_read
+                    output_bytes += stage.metrics.output_metrics.bytes_written
+                    output_rows += stage.metrics.output_metrics.records_written
+
+        input_read_files_dynamic = _get_first_int(execution.metrics, SparkSQLMetricNames.NUMBER_OF_FILES_READ) or 0
+        input_read_files_static = (
+            _get_first_int(execution.metrics, SparkSQLMetricNames.STATIC_NUMBER_OF_FILES_READ) or 0
+        )
+        input_read_files = input_read_files_dynamic + input_read_files_static
+
+        input_dynamic_partition_pruning = not bool(input_read_files_static)
+        input_read_partitions = _get_first_int(execution.metrics, SparkSQLMetricNames.NUMBER_OF_PARTITIONS_READ) or 0
+
+        input_scan_files_dynamic = _get_first_int(execution.metrics, SparkSQLMetricNames.SIZE_OF_FILES_READ) or 0
+        input_scan_files_static = (
+            _get_first_int(execution.metrics, SparkSQLMetricNames.STATIC_NUMBER_OF_FILES_READ) or 0
+        )
+        input_scan_bytes = input_scan_files_dynamic + input_scan_files_static
+
+        output_files = _get_first_int(execution.metrics, SparkSQLMetricNames.NUMBER_OF_WRITTEN_FILES) or 0
+        output_dynamic_partitions = _get_first_int(execution.metrics, SparkSQLMetricNames.NUMBER_OF_DYNAMIC_PART) or 0
+
+        return SparkMetrics(
+            input=SparkInputMetrics(
+                read_rows=input_read_rows,
+                read_files=input_read_files,
+                read_bytes=input_read_bytes,
+                scan_bytes=input_scan_bytes,
+                read_partitions=input_read_partitions,
+                dynamic_partition_pruning=input_dynamic_partition_pruning,
+            ),
+            output=SparkOutputMetrics(
+                written_bytes=output_bytes,
+                written_rows=output_rows,
+                created_files=output_files,
+                created_dynamic_partitions=output_dynamic_partitions,
+            ),
+            executor=SparkExecutorMetrics(
+                run_time_milliseconds=run_time_milliseconds,
+                cpu_time_nanoseconds=cpu_time_nanoseconds,
+                peak_memory_bytes=peak_memory_bytes,
+            ),
+        )
diff --git a/onetl/metrics/executor.py b/onetl/metrics/executor.py
new file mode 100644
index 000000000..3b6d83c73
--- /dev/null
+++ b/onetl/metrics/executor.py
@@ -0,0 +1,46 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+import os
+from datetime import timedelta
+
+from humanize import naturalsize, precisedelta
+
+from onetl.impl import BaseModel
+
+
+class SparkExecutorMetrics(BaseModel):
+    run_time_milliseconds: int = 0
+    cpu_time_nanoseconds: int = 0
+    peak_memory_bytes: int = 0
+
+    @property
+    def is_empty(self) -> bool:
+        return not any([self.run_time_milliseconds, self.cpu_time_nanoseconds, self.peak_memory_bytes])
+
+    def merge(self, other: SparkExecutorMetrics) -> SparkExecutorMetrics:
+        self.run_time_milliseconds += other.run_time_milliseconds
+        self.cpu_time_nanoseconds += other.cpu_time_nanoseconds
+        self.peak_memory_bytes = max(self.peak_memory_bytes, other.peak_memory_bytes)
+        return self
+
+    @property
+    def details(self) -> str:
+        if self.is_empty:
+            return "No data"
+
+        result = []
+        if self.run_time_milliseconds:
+            result.append(f"Run time: {precisedelta(timedelta(milliseconds=self.run_time_milliseconds))}")
+
+        if self.cpu_time_nanoseconds:
+            result.append(f"CPU time: {precisedelta(timedelta(microseconds=self.cpu_time_nanoseconds / 1000))}")
+
+        if self.peak_memory_bytes:
+            result.append(f"Peak memory: {naturalsize(self.peak_memory_bytes)}")
+
+        return os.linesep.join(result)
+
+    def __str__(self):
+        return self.details
diff --git a/onetl/metrics/input.py b/onetl/metrics/input.py
new file mode 100644
index 000000000..ae9d3c5d8
--- /dev/null
+++ b/onetl/metrics/input.py
@@ -0,0 +1,62 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+import os
+
+from humanize import naturalsize
+
+from onetl.impl import BaseModel
+
+
+class SparkInputMetrics(BaseModel):
+    read_rows: int = 0
+    read_files: int = 0
+    read_bytes: int = 0
+    scan_bytes: int = 0
+    read_partitions: int = 0
+    dynamic_partition_pruning: bool = False
+
+    @property
+    def is_empty(self) -> bool:
+        return not any([self.read_bytes, self.read_files, self.read_rows])
+
+    def merge(self, other: SparkInputMetrics) -> SparkInputMetrics:
+        self.read_rows += other.read_rows
+        self.read_files += other.read_files
+        self.read_bytes += other.read_bytes
+        self.scan_bytes += other.scan_bytes
+        self.read_partitions = max([self.read_partitions, other.read_partitions])
+        if not self.dynamic_partition_pruning:
+            self.dynamic_partition_pruning = other.dynamic_partition_pruning
+        return self
+
+    @property
+    def details(self) -> str:
+        if self.is_empty:
+            return "No data"
+
+        result = []
+        if self.read_rows:
+            result.append(f"Read rows: {self.read_rows}")
+
+        if self.read_files:
+            result.append(f"Read files: {self.read_files}")
+
+        if self.read_bytes and self.scan_bytes and self.read_bytes != self.scan_bytes:
+            result.append(f"Read size: {naturalsize(self.read_bytes)}")
+            result.append(f"Scan size: {naturalsize(self.scan_bytes)}")
+            result.append(f"Dynamic partition pruning: {self.dynamic_partition_pruning}")
+            if self.read_partitions:
+                result.append(f"Read partitions: {self.read_partitions}")
+        elif self.read_bytes:
+            result.append(f"Read size: {naturalsize(self.read_bytes)}")
+            if self.scan_bytes:
+                result.append(f"Dynamic partition pruning: {self.dynamic_partition_pruning}")
+            if self.read_partitions:
+                result.append(f"Read partitions: {self.read_partitions}")
+
+        return os.linesep.join(result)
+
+    def __str__(self):
+        return self.details
diff --git a/onetl/metrics/metrics.py b/onetl/metrics/metrics.py
new file mode 100644
index 000000000..33e9e4c2f
--- /dev/null
+++ b/onetl/metrics/metrics.py
@@ -0,0 +1,49 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+import os
+import textwrap
+
+from pydantic import Field
+
+from onetl.impl import BaseModel
+from onetl.metrics.executor import SparkExecutorMetrics
+from onetl.metrics.input import SparkInputMetrics
+from onetl.metrics.output import SparkOutputMetrics
+
+INDENT = " " * 4
+
+
+class SparkMetrics(BaseModel):
+    input: SparkInputMetrics = Field(default_factory=SparkInputMetrics)
+    output: SparkOutputMetrics = Field(default_factory=SparkOutputMetrics)
+    executor: SparkExecutorMetrics = Field(default_factory=SparkExecutorMetrics)
+
+    @property
+    def is_empty(self) -> bool:
+        return all([self.input.is_empty, self.output.is_empty, self.executor.is_empty])
+
+    def merge(self, other: SparkMetrics) -> SparkMetrics:
+        self.input.merge(other.input)
+        self.output.merge(other.output)
+        self.executor.merge(other.executor)
+        return self
+
+    @property
+    def details(self) -> str:
+        if self.is_empty:
+ return "No data" + + result = [] + if not self.input.is_empty: + result.append(f"Input:{os.linesep}{textwrap.indent(self.input.details, INDENT)}") + if not self.output.is_empty: + result.append(f"Output:{os.linesep}{textwrap.indent(self.output.details, INDENT)}") + if not self.executor.is_empty: + result.append(f"Executor:{os.linesep}{textwrap.indent(self.executor.details, INDENT)}") + + return os.linesep.join(result) + + def __str__(self): + return self.details diff --git a/onetl/metrics/output.py b/onetl/metrics/output.py new file mode 100644 index 000000000..be3bb6801 --- /dev/null +++ b/onetl/metrics/output.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems) +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import os + +from humanize import naturalsize + +from onetl.impl import BaseModel + + +class SparkOutputMetrics(BaseModel): + written_bytes: int = 0 + written_rows: int = 0 + created_files: int = 0 + created_dynamic_partitions: int = 0 + + @property + def is_empty(self) -> bool: + return not any([self.written_bytes, self.written_rows, self.created_files]) + + def merge(self, other: SparkOutputMetrics) -> SparkOutputMetrics: + self.written_bytes += other.written_bytes + self.written_rows += other.written_rows + self.created_files += other.created_files + self.created_dynamic_partitions = max([self.created_dynamic_partitions, other.created_dynamic_partitions]) + return self + + @property + def details(self) -> str: + if self.is_empty: + return "No data" + + result = [] + if self.written_rows: + result.append(f"Written rows: {self.written_rows}") + + if self.created_files: + result.append(f"Written files: {self.created_files}") + + if self.created_dynamic_partitions: + result.append(f"Created dynamic partitions: {self.created_dynamic_partitions}") + + if self.written_bytes: + result.append(f"Written size: {naturalsize(self.written_bytes)}") + + return os.linesep.join(result) + + def __str__(self): + return self.details diff --git a/onetl/strategy/hwm_store/__init__.py b/onetl/strategy/hwm_store/__init__.py index 0b931301e..7a0338d31 100644 --- a/onetl/strategy/hwm_store/__init__.py +++ b/onetl/strategy/hwm_store/__init__.py @@ -23,7 +23,7 @@ register_spark_type_to_hwm_type_mapping, ) -__all__ = [ # noqa: WPS410 +__all__ = [ "BaseHWMStore", "SparkTypeToHWM", "register_spark_type_to_hwm_type_mapping", diff --git a/onetl/version.py b/onetl/version.py index dada22dd7..1a3c6cecc 100644 --- a/onetl/version.py +++ b/onetl/version.py @@ -8,4 +8,4 @@ VERSION_FILE = Path(__file__).parent / "VERSION" -__version__ = VERSION_FILE.read_text().strip() # noqa: WPS410 +__version__ = VERSION_FILE.read_text().strip() diff --git a/setup.cfg b/setup.cfg index d12261ed9..b03476fb3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -275,7 +275,9 @@ ignore = # WPS474 Found import object collision WPS474, # WPS318 Found extra indentation - WPS318 + WPS318, +# WPS410 Found wrong metadata variable: __all__ + WPS410 # http://flake8.pycqa.org/en/latest/user/options.html?highlight=per-file-ignores#cmdoption-flake8-per-file-ignores per-file-ignores = @@ -350,6 +352,9 @@ per-file-ignores = onetl/hooks/slot.py: # WPS210 Found too many local variables WPS210, + onetl/metrics/_listener/*: +# N802 function name 'onJobStart' should be lowercase + N802, tests/*: # Found too many empty lines in `def` WPS473,