diff --git a/.github/workflows/data/file-df/tracked.txt b/.github/workflows/data/file-df/tracked.txt index 880912b14..c12307374 100644 --- a/.github/workflows/data/file-df/tracked.txt +++ b/.github/workflows/data/file-df/tracked.txt @@ -1,6 +1,4 @@ .github/workflows/data/file-df/** -onetl/file_df_connection/spark_file_df_connection.py -onetl/file/file_df_reader/** -onetl/file/file_df_writer/** onetl/file/__init__.py -tests/resources/file_df_connection/** +**/*file_df* +**/*file_df*/** diff --git a/onetl/_metrics/__init__.py b/onetl/_metrics/__init__.py new file mode 100644 index 000000000..5d7482b69 --- /dev/null +++ b/onetl/_metrics/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems) +# SPDX-License-Identifier: Apache-2.0 +from onetl._metrics.command import SparkCommandMetrics +from onetl._metrics.driver import SparkDriverMetrics +from onetl._metrics.executor import SparkExecutorMetrics +from onetl._metrics.input import SparkInputMetrics +from onetl._metrics.output import SparkOutputMetrics +from onetl._metrics.recorder import SparkMetricsRecorder + +__all__ = [ + "SparkCommandMetrics", + "SparkDriverMetrics", + "SparkMetricsRecorder", + "SparkExecutorMetrics", + "SparkInputMetrics", + "SparkOutputMetrics", +] diff --git a/onetl/_metrics/command.py b/onetl/_metrics/command.py new file mode 100644 index 000000000..2a8a53c64 --- /dev/null +++ b/onetl/_metrics/command.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems) +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import os +import textwrap + +try: + from pydantic.v1 import Field +except (ImportError, AttributeError): + from pydantic import Field # type: ignore[no-redef, assignment] + +from onetl._metrics.driver import SparkDriverMetrics +from onetl._metrics.executor import SparkExecutorMetrics +from onetl._metrics.input import SparkInputMetrics +from onetl._metrics.output import SparkOutputMetrics +from onetl.impl import BaseModel + +INDENT = " " * 4 + + +class SparkCommandMetrics(BaseModel): + input: SparkInputMetrics = Field(default_factory=SparkInputMetrics) + output: SparkOutputMetrics = Field(default_factory=SparkOutputMetrics) + driver: SparkDriverMetrics = Field(default_factory=SparkDriverMetrics) + executor: SparkExecutorMetrics = Field(default_factory=SparkExecutorMetrics) + + @property + def is_empty(self) -> bool: + return all([self.input.is_empty, self.output.is_empty]) + + def update(self, other: SparkCommandMetrics) -> SparkCommandMetrics: + self.input.update(other.input) + self.output.update(other.output) + self.driver.update(other.driver) + self.executor.update(other.executor) + return self + + @property + def details(self) -> str: + if self.is_empty: + return "No data" + + result = [] + if not self.input.is_empty: + result.append(f"Input:{os.linesep}{textwrap.indent(self.input.details, INDENT)}") + if not self.output.is_empty: + result.append(f"Output:{os.linesep}{textwrap.indent(self.output.details, INDENT)}") + if not self.driver.is_empty: + result.append(f"Driver:{os.linesep}{textwrap.indent(self.driver.details, INDENT)}") + if not self.executor.is_empty: + result.append(f"Executor:{os.linesep}{textwrap.indent(self.executor.details, INDENT)}") + + return os.linesep.join(result) + + def __str__(self): + return self.details diff --git a/onetl/_metrics/driver.py b/onetl/_metrics/driver.py new file mode 100644 index 000000000..4e6857192 --- /dev/null +++ b/onetl/_metrics/driver.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: 
2021-2024 MTS (Mobile Telesystems) +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import os + +from humanize import naturalsize + +from onetl.impl import BaseModel + +# Metrics themselves are considered a part of driver result, +# ignore if result is smaller than 1MB +MIN_DRIVER_BYTES = 1_000_000 + + +class SparkDriverMetrics(BaseModel): + in_memory_bytes: int = 0 + + @property + def is_empty(self) -> bool: + return self.in_memory_bytes < MIN_DRIVER_BYTES + + def update(self, other: SparkDriverMetrics) -> SparkDriverMetrics: + self.in_memory_bytes += other.in_memory_bytes + return self + + @property + def details(self) -> str: + if self.is_empty: + return "No data" + + result = [] + if self.in_memory_bytes >= MIN_DRIVER_BYTES: + result.append(f"In-memory data (approximate): {naturalsize(self.in_memory_bytes)}") + + return os.linesep.join(result) + + def __str__(self): + return self.details diff --git a/onetl/_metrics/executor.py b/onetl/_metrics/executor.py new file mode 100644 index 000000000..3fd6f3fc6 --- /dev/null +++ b/onetl/_metrics/executor.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems) +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import os +from datetime import timedelta + +from humanize import naturalsize, precisedelta + +from onetl.impl import BaseModel + + +class SparkExecutorMetrics(BaseModel): + total_run_time: timedelta = timedelta() + total_cpu_time: timedelta = timedelta() + peak_memory_bytes: int = 0 + memory_spilled_bytes: int = 0 + disk_spilled_bytes: int = 0 + + @property + def is_empty(self) -> bool: + return not self.total_run_time + + def update(self, other: SparkExecutorMetrics) -> SparkExecutorMetrics: + self.total_run_time += other.total_run_time + self.total_cpu_time += other.total_cpu_time + self.peak_memory_bytes += other.peak_memory_bytes + self.memory_spilled_bytes += other.memory_spilled_bytes + self.disk_spilled_bytes += other.disk_spilled_bytes + return self + + @property + def details(self) -> str: + if self.is_empty: + return "No data" + + result = [ + f"Total run time: {precisedelta(self.total_run_time)}", + f"Total CPU time: {precisedelta(self.total_cpu_time)}", + ] + + if self.peak_memory_bytes: + result.append(f"Peak memory: {naturalsize(self.peak_memory_bytes)}") + + if self.memory_spilled_bytes: + result.append(f"Memory spilled: {naturalsize(self.memory_spilled_bytes)}") + + if self.disk_spilled_bytes: + result.append(f"Disk spilled: {naturalsize(self.disk_spilled_bytes)}") + + return os.linesep.join(result) + + def __str__(self): + return self.details diff --git a/onetl/_metrics/extract.py b/onetl/_metrics/extract.py new file mode 100644 index 000000000..42fd56232 --- /dev/null +++ b/onetl/_metrics/extract.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems) +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import re +from datetime import timedelta +from typing import Any + +try: + from pydantic.v1 import ByteSize +except (ImportError, AttributeError): + from pydantic import ByteSize # type: ignore[no-redef, assignment] + +from onetl._metrics.command import SparkCommandMetrics +from onetl._metrics.driver import SparkDriverMetrics +from onetl._metrics.executor import SparkExecutorMetrics +from onetl._metrics.input import SparkInputMetrics +from onetl._metrics.listener.execution import ( + SparkListenerExecution, + SparkSQLMetricNames, +) +from onetl._metrics.output import SparkOutputMetrics + 
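+# Metric values are collected from Spark's SQLAppStatusStore as display strings
+# (e.g. "1024" or "10.5 MiB"), so the helpers below parse them back into integers.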
+NON_DIGIT = re.compile(r"[^\d.]") + + +def _get_int(data: dict[SparkSQLMetricNames, list[str]], key: Any) -> int | None: + if key not in data: + return None + + items = data[key] + if not items: + return None + + return int(items[0]) + + +def _get_bytes(data: dict[SparkSQLMetricNames, list[str]], key: Any) -> int | None: + if key not in data: + return None + + items = data[key] + if not items: + return None + + return int(ByteSize.validate(items[0])) + + +def extract_metrics_from_execution(execution: SparkListenerExecution) -> SparkCommandMetrics: + input_read_bytes: int = 0 + input_read_rows: int = 0 + output_bytes: int = 0 + output_rows: int = 0 + + run_time_milliseconds: int = 0 + cpu_time_nanoseconds: int = 0 + peak_memory_bytes: int = 0 + memory_spilled_bytes: int = 0 + disk_spilled_bytes: int = 0 + result_size_bytes: int = 0 + + # some metrics are per-stage, and have to be summed, others are per-execution + for job in execution.jobs: + for stage in job.stages: + input_read_bytes += stage.metrics.input_metrics.bytes_read + input_read_rows += stage.metrics.input_metrics.records_read + output_bytes += stage.metrics.output_metrics.bytes_written + output_rows += stage.metrics.output_metrics.records_written + + run_time_milliseconds += stage.metrics.executor_run_time_milliseconds + cpu_time_nanoseconds += stage.metrics.executor_cpu_time_nanoseconds + peak_memory_bytes += stage.metrics.peak_execution_memory_bytes + memory_spilled_bytes += stage.metrics.memory_spilled_bytes + disk_spilled_bytes += stage.metrics.disk_spilled_bytes + result_size_bytes += stage.metrics.result_size_bytes + + # https://github.com/apache/spark/blob/v3.5.1/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala#L467-L473 + input_file_count = ( + _get_int(execution.metrics, SparkSQLMetricNames.NUMBER_OF_FILES_READ) + or _get_int(execution.metrics, SparkSQLMetricNames.STATIC_NUMBER_OF_FILES_READ) + or 0 + ) + input_raw_file_bytes = ( + _get_bytes(execution.metrics, SparkSQLMetricNames.SIZE_OF_FILES_READ) + or _get_bytes(execution.metrics, SparkSQLMetricNames.STATIC_SIZE_OF_FILES_READ) + or 0 + ) + input_read_partitions = _get_int(execution.metrics, SparkSQLMetricNames.NUMBER_OF_PARTITIONS_READ) or 0 + + output_files = _get_int(execution.metrics, SparkSQLMetricNames.NUMBER_OF_WRITTEN_FILES) or 0 + output_dynamic_partitions = _get_int(execution.metrics, SparkSQLMetricNames.NUMBER_OF_DYNAMIC_PART) or 0 + + return SparkCommandMetrics( + input=SparkInputMetrics( + read_rows=input_read_rows, + read_files=input_file_count, + read_bytes=input_read_bytes, + raw_file_bytes=input_raw_file_bytes, + read_partitions=input_read_partitions, + ), + output=SparkOutputMetrics( + written_rows=output_rows, + written_bytes=output_bytes, + created_files=output_files, + created_partitions=output_dynamic_partitions, + ), + driver=SparkDriverMetrics( + in_memory_bytes=result_size_bytes, + ), + executor=SparkExecutorMetrics( + total_run_time=timedelta(milliseconds=run_time_milliseconds), + total_cpu_time=timedelta(microseconds=cpu_time_nanoseconds / 1000), + peak_memory_bytes=peak_memory_bytes, + memory_spilled_bytes=memory_spilled_bytes, + disk_spilled_bytes=disk_spilled_bytes, + ), + ) diff --git a/onetl/_metrics/input.py b/onetl/_metrics/input.py new file mode 100644 index 000000000..390613114 --- /dev/null +++ b/onetl/_metrics/input.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems) +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import os +from pprint 
import pformat
+
+from humanize import naturalsize
+
+from onetl.impl import BaseModel
+
+
+class SparkInputMetrics(BaseModel):
+    read_rows: int = 0
+    read_files: int = 0
+    read_partitions: int = 0
+    read_bytes: int = 0
+    raw_file_bytes: int = 0
+
+    @property
+    def is_empty(self) -> bool:
+        return not any([self.read_bytes, self.read_files, self.read_rows])
+
+    def update(self, other: SparkInputMetrics) -> SparkInputMetrics:
+        self.read_rows += other.read_rows
+        self.read_files += other.read_files
+        self.read_partitions += other.read_partitions
+        self.read_bytes += other.read_bytes
+        self.raw_file_bytes += other.raw_file_bytes
+        return self
+
+    @property
+    def details(self) -> str:
+        if self.is_empty:
+            return "No data"
+
+        result = []
+        result.append(f"Read rows: {pformat(self.read_rows)}")
+
+        if self.read_partitions:
+            result.append(f"Read partitions: {pformat(self.read_partitions)}")
+
+        if self.read_files:
+            result.append(f"Read files: {pformat(self.read_files)}")
+
+        if self.read_bytes:
+            result.append(f"Read size: {naturalsize(self.read_bytes)}")
+
+        if self.raw_file_bytes and self.read_bytes != self.raw_file_bytes:
+            result.append(f"Raw files size: {naturalsize(self.raw_file_bytes)}")
+
+        return os.linesep.join(result)
+
+    def __str__(self):
+        return self.details
diff --git a/onetl/_metrics/listener/__init__.py b/onetl/_metrics/listener/__init__.py
new file mode 100644
index 000000000..112e4fba7
--- /dev/null
+++ b/onetl/_metrics/listener/__init__.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from onetl._metrics.listener.execution import (
+    SparkListenerExecution,
+    SparkListenerExecutionStatus,
+    SparkSQLMetricNames,
+)
+from onetl._metrics.listener.job import SparkListenerJob, SparkListenerJobStatus
+from onetl._metrics.listener.listener import SparkMetricsListener
+from onetl._metrics.listener.stage import SparkListenerStage, SparkListenerStageStatus
+from onetl._metrics.listener.task import (
+    SparkListenerTask,
+    SparkListenerTaskMetrics,
+    SparkListenerTaskStatus,
+)
+
+__all__ = [
+    "SparkListenerTask",
+    "SparkListenerTaskStatus",
+    "SparkListenerTaskMetrics",
+    "SparkListenerStage",
+    "SparkListenerStageStatus",
+    "SparkListenerJob",
+    "SparkListenerJobStatus",
+    "SparkListenerExecution",
+    "SparkListenerExecutionStatus",
+    "SparkSQLMetricNames",
+    "SparkMetricsListener",
+]
diff --git a/onetl/_metrics/listener/base.py b/onetl/_metrics/listener/base.py
new file mode 100644
index 000000000..90432c7c3
--- /dev/null
+++ b/onetl/_metrics/listener/base.py
@@ -0,0 +1,178 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+from contextlib import suppress
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from onetl._util.java import get_java_gateway, start_callback_server
+
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+
+
+@dataclass
+class BaseSparkListener:
+    """Base no-op SparkListener implementation.
+
+    See the ``org.apache.spark.scheduler.SparkListener`` interface in the Spark Java API.
+    """
+
+    spark: SparkSession
+
+    def activate(self):
+        start_callback_server(self.spark)
+
+        # Passing the Python listener object directly to addSparkListener or removeSparkListener
+        # creates a new Java object on each call, but removeSparkListener only takes effect
+        # when it receives the very same Java object that was passed to addSparkListener.
+        # So we explicitly create the Java object once, and pass it to both calls.
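+        # Appending the Python object to a Java ArrayList makes py4j wrap it in
+        # a Java proxy once; reading element 0 back returns that same proxy, giving
+        # a stable reference that can be passed to add- and removeSparkListener.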
+        gateway = get_java_gateway(self.spark)
+        java_list = gateway.jvm.java.util.ArrayList()
+        java_list.append(self)
+        self._java_listener = java_list[0]
+
+        spark_context = self.spark.sparkContext._jsc.sc()  # noqa: WPS437
+        spark_context.addSparkListener(self._java_listener)
+
+    def deactivate(self):
+        with suppress(Exception):
+            spark_context = self.spark.sparkContext._jsc.sc()  # noqa: WPS437
+            spark_context.removeSparkListener(self._java_listener)
+
+        with suppress(Exception):
+            del self._java_listener
+
+    def __enter__(self):
+        self.activate()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.deactivate()
+
+    def __del__(self):  # noqa: WPS603
+        # If the current object is collected by GC, deactivate the listener
+        # and free the bound Java object
+        self.deactivate()
+
+    def equals(self, other):
+        # Java does not provide a proper way to get an object id for comparison,
+        # so we compare string representations, which should contain some form of id
+        return other.toString() == self._java_listener.toString()
+
+    def toString(self):
+        return type(self).__qualname__ + "@" + hex(id(self))
+
+    def hashCode(self):
+        return hash(self)
+
+    # no cover: start
+    # method names are important for Java interface compatibility!
+    def onApplicationEnd(self, application):
+        pass
+
+    def onApplicationStart(self, application):
+        pass
+
+    def onBlockManagerAdded(self, block_manager):
+        pass
+
+    def onBlockManagerRemoved(self, block_manager):
+        pass
+
+    def onBlockUpdated(self, block):
+        pass
+
+    def onEnvironmentUpdate(self, environment):
+        pass
+
+    def onExecutorAdded(self, executor):
+        pass
+
+    def onExecutorMetricsUpdate(self, executor):
+        pass
+
+    def onExecutorRemoved(self, executor):
+        pass
+
+    def onExecutorBlacklisted(self, event):
+        pass
+
+    def onExecutorBlacklistedForStage(self, event):
+        pass
+
+    def onExecutorExcluded(self, event):
+        pass
+
+    def onExecutorExcludedForStage(self, event):
+        pass
+
+    def onExecutorUnblacklisted(self, event):
+        pass
+
+    def onExecutorUnexcluded(self, event):
+        pass
+
+    def onJobStart(self, event):
+        pass
+
+    def onJobEnd(self, event):
+        pass
+
+    def onNodeBlacklisted(self, node):
+        pass
+
+    def onNodeBlacklistedForStage(self, stage):
+        pass
+
+    def onNodeExcluded(self, node):
+        pass
+
+    def onNodeExcludedForStage(self, node):
+        pass
+
+    def onNodeUnblacklisted(self, node):
+        pass
+
+    def onNodeUnexcluded(self, node):
+        pass
+
+    def onOtherEvent(self, event):
+        pass
+
+    def onResourceProfileAdded(self, resource_profile):
+        pass
+
+    def onSpeculativeTaskSubmitted(self, task):
+        pass
+
+    def onStageCompleted(self, event):
+        pass
+
+    def onStageExecutorMetrics(self, metrics):
+        pass
+
+    def onStageSubmitted(self, event):
+        pass
+
+    def onTaskEnd(self, event):
+        pass
+
+    def onTaskGettingResult(self, task):
+        pass
+
+    def onTaskStart(self, event):
+        pass
+
+    def onUnpersistRDD(self, rdd):
+        pass
+
+    def onUnschedulableTaskSetAdded(self, task_set):
+        pass
+
+    def onUnschedulableTaskSetRemoved(self, task_set):
+        pass
+
+    # no cover: stop
+    class Java:
+        implements = ["org.apache.spark.scheduler.SparkListenerInterface"]
diff --git a/onetl/_metrics/listener/execution.py b/onetl/_metrics/listener/execution.py
new file mode 100644
index 000000000..728c4c2ca
--- /dev/null
+++ b/onetl/_metrics/listener/execution.py
@@ -0,0 +1,109 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+from collections import defaultdict
+from dataclasses import dataclass, field
+from enum import Enum
+
+from onetl._metrics.listener.job import SparkListenerJob, SparkListenerJobStatus
+
+
+class SparkListenerExecutionStatus(str, Enum):
+    STARTED = "STARTED"
+    COMPLETE = "COMPLETE"
+    FAILED = "FAILED"
+
+    def __str__(self):
+        return self.value
+
+
+class SparkSQLMetricNames(str, Enum):  # noqa: WPS338
+    # Metric names passed to SQLMetrics.createMetric(...),
+    # but only those we're interested in.
+
+    # https://github.com/apache/spark/blob/v3.5.1/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala#L233C55-L233C87
+    NUMBER_OF_PARTITIONS_READ = "number of partitions read"
+
+    # https://github.com/apache/spark/blob/v3.5.1/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala#L225-L227
+    NUMBER_OF_FILES_READ = "number of files read"
+    SIZE_OF_FILES_READ = "size of files read"
+
+    # https://github.com/apache/spark/blob/v3.5.1/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala#L455-L456
+    STATIC_NUMBER_OF_FILES_READ = "static number of files read"
+    STATIC_SIZE_OF_FILES_READ = "static size of files read"
+
+    # https://github.com/apache/spark/blob/v3.5.1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala#L241-L246
+    NUMBER_OF_DYNAMIC_PART = "number of dynamic part"
+    NUMBER_OF_WRITTEN_FILES = "number of written files"
+
+    def __str__(self):
+        return self.value
+
+
+@dataclass
+class SparkListenerExecution:
+    id: int
+    description: str | None = None
+    external_id: str | None = None
+    status: SparkListenerExecutionStatus = SparkListenerExecutionStatus.STARTED
+
+    # These metrics are emitted by any command performed within this execution, so there can be multiple values.
+    # Some metrics can be summed up, but some cannot, so we store a list.
+    metrics: dict[SparkSQLMetricNames, list[str]] = field(default_factory=lambda: defaultdict(list), repr=False)
+
+    _jobs: dict[int, SparkListenerJob] = field(default_factory=dict, repr=False, init=False)
+
+    @property
+    def jobs(self) -> list[SparkListenerJob]:
+        result = []
+        for job_id in sorted(self._jobs.keys()):
+            result.append(self._jobs[job_id])
+        return result
+
+    def on_execution_start(self, event):
+        # https://github.com/apache/spark/blob/v3.5.1/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala#L44-L58
+        self.status = SparkListenerExecutionStatus.STARTED
+
+    def on_execution_end(self, event):
+        # https://github.com/apache/spark/blob/v3.5.1/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala#L61-L83
+        for job in self._jobs.values():
+            if job.status == SparkListenerJobStatus.FAILED:
+                self.status = SparkListenerExecutionStatus.FAILED
+                break
+        else:
+            self.status = SparkListenerExecutionStatus.COMPLETE
+
+    def on_job_start(self, event):
+        job_id = event.jobId()
+        job = SparkListenerJob.create(event)
+        self._jobs[job_id] = job
+        job.on_job_start(event)
+
+    def on_job_end(self, event):
+        job_id = event.jobId()
+        job = self._jobs.get(job_id)
+
+        if job:
+            job.on_job_end(event)
+
+        # in some cases an Execution consists of just one job with the same id
+        if job_id == self.id:
+            self.on_execution_end(event)
+
+    # push down events
+    def on_stage_start(self, event):
+        for job in self._jobs.values():
+            job.on_stage_start(event)
+
+    def on_stage_end(self, event):
+        for job in self._jobs.values():
+            job.on_stage_end(event)
+
+    def on_task_start(self, event):
+        for job in self._jobs.values():
+            job.on_task_start(event)
+
+    def on_task_end(self, event):
+        for job in self._jobs.values():
+            job.on_task_end(event)
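SparkListenerExecution above mirrors Spark's own hierarchy: an execution owns jobs, a job owns stages, a stage owns tasks, and lifecycle events are pushed down until the object with the matching id handles them. As a minimal sketch of how a populated hierarchy is consumed (the same per-stage loop that extract_metrics_from_execution in onetl/_metrics/extract.py uses for the full metric set):

    from onetl._metrics.listener import SparkListenerExecution

    def total_records_read(execution: SparkListenerExecution) -> int:
        # jobs and stages are exposed as lists sorted by id
        total = 0
        for job in execution.jobs:
            for stage in job.stages:
                total += stage.metrics.input_metrics.records_read
        return total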
diff --git a/onetl/_metrics/listener/job.py b/onetl/_metrics/listener/job.py new file mode 100644 index 000000000..b3abbd061 --- /dev/null +++ b/onetl/_metrics/listener/job.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems) +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum + +from onetl._metrics.listener.stage import SparkListenerStage, SparkListenerStageStatus +from onetl._util.scala import scala_seq_to_python_list + + +class SparkListenerJobStatus(str, Enum): + RUNNING = "RUNNING" + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + UNKNOWN = "UNKNOWN" + + def __str__(self): + return self.value + + +@dataclass +class SparkListenerJob: + id: int + description: str | None = None + group_id: str | None = None + call_site: str | None = None + status: SparkListenerJobStatus = SparkListenerJobStatus.UNKNOWN + + _stages: dict[int, SparkListenerStage] = field(default_factory=dict, repr=False, init=False) + + @property + def stages(self) -> list[SparkListenerStage]: + result = [] + for stage_id in sorted(self._stages.keys()): + result.append(self._stages[stage_id]) + return result + + @classmethod + def create(cls, event): + # https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/scheduler/SparkListenerJobSubmitted.html + # https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/scheduler/SparkListenerJobCompleted.html + result = cls( + id=event.jobId(), + description=event.properties().get("spark.job.description"), + group_id=event.properties().get("spark.jobGroup.id"), + call_site=event.properties().get("callSite.short"), + ) + + stage_ids = scala_seq_to_python_list(event.stageIds()) + stage_infos = scala_seq_to_python_list(event.stageInfos()) + for stage_id, stage_info in zip(stage_ids, stage_infos): + result._stages[stage_id] = SparkListenerStage.create(stage_info) # noqa: WPS437 + + return result + + def on_job_start(self, event): + self.status = SparkListenerJobStatus.RUNNING + + def on_job_end(self, event): + for stage in self._stages.values(): + if stage.status == SparkListenerStageStatus.FAILED: + self.status = SparkListenerJobStatus.FAILED + break + else: + self.status = SparkListenerJobStatus.SUCCEEDED + + def on_stage_start(self, event): + stage_id = event.stageInfo().stageId() + stage = self._stages.get(stage_id) + if stage: + stage.on_stage_start(event) + + def on_stage_end(self, event): + stage_id = event.stageInfo().stageId() + stage = self._stages.get(stage_id) + if stage: + stage.on_stage_end(event) + + # push down events + def on_task_start(self, event): + for stage in self._stages.values(): + stage.on_task_start(event) + + def on_task_end(self, event): + for stage in self._stages.values(): + stage.on_task_end(event) diff --git a/onetl/_metrics/listener/listener.py b/onetl/_metrics/listener/listener.py new file mode 100644 index 000000000..3421e5ae0 --- /dev/null +++ b/onetl/_metrics/listener/listener.py @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems) +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from dataclasses import dataclass, field +from threading import current_thread +from typing import ClassVar + +from onetl._metrics.listener.base import BaseSparkListener +from onetl._metrics.listener.execution import ( + SparkListenerExecution, + SparkSQLMetricNames, +) + + +@dataclass +class SparkMetricsListener(BaseSparkListener): + THREAD_ID_KEY = "python.thread.id" + SQL_START_CLASS_NAME: 
ClassVar[str] = "org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart" + SQL_STOP_CLASS_NAME: ClassVar[str] = "org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd" + + _thread_id: str = field(default_factory=lambda: str(current_thread().ident), repr=False, init=False) + _recorded_executions: dict[int, SparkListenerExecution] = field(default_factory=dict, repr=False, init=False) + + def activate(self): + # we cannot override execution_id property as it set by Spark + # we also cannot use job tags, as they were implemented only in Spark 3.5+ + self.spark.sparkContext.setLocalProperty(self.THREAD_ID_KEY, self._thread_id) + return super().activate() + + def reset(self): + self._recorded_executions.clear() + return self + + @property + def executions(self): + return [ + execution for execution in self._recorded_executions.values() if execution.external_id == self._thread_id + ] + + def __enter__(self): + """Record only executions performed by current Spark thread. + + It is important to use this method only in combination with + :obj:`pyspark.util.InheritableThread` to preserve thread-local variables + between Python thread and Java thread. + """ + self.reset() + return super().__enter__() + + def onOtherEvent(self, event): + class_name = event.getClass().getName() + if class_name == self.SQL_START_CLASS_NAME: + self.onExecutionStart(event) + elif class_name == self.SQL_STOP_CLASS_NAME: + self.onExecutionEnd(event) + + def onExecutionStart(self, event): + execution_id = event.executionId() + description = event.description() + execution = SparkListenerExecution( + id=execution_id, + description=description, + ) + self._recorded_executions[execution_id] = execution + execution.on_execution_start(event) + + def onExecutionEnd(self, event): + execution_id = event.executionId() + execution = self._recorded_executions.get(execution_id) + if execution: + execution.on_execution_end(event) + + # Get execution metrics from SQLAppStatusStore, + # as SparkListenerSQLExecutionEnd event does not provide them: + # https://github.com/apache/spark/blob/v3.5.1/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala + session_status_store = self.spark._jsparkSession.sharedState().statusStore() # noqa: WPS437 + raw_execution = session_status_store.execution(execution.id).get() + metrics = raw_execution.metrics() + metric_values = session_status_store.executionMetrics(execution.id) + for i in range(metrics.size()): + metric = metrics.apply(i) + metric_name = metric.name() + if metric_name not in SparkSQLMetricNames: + continue + metric_value = metric_values.get(metric.accumulatorId()) + if not metric_value.isDefined(): + continue + execution.metrics[SparkSQLMetricNames(metric_name)].append(metric_value.get()) + + def onJobStart(self, event): + execution_id = event.properties().get("spark.sql.execution.id") + execution_thread_id = event.properties().get(self.THREAD_ID_KEY) + if execution_id is None: + # single job execution + job_id = event.jobId() + execution = SparkListenerExecution( + id=job_id, + description=event.properties().get("spark.job.description"), + external_id=execution_thread_id, + ) + self._recorded_executions[job_id] = execution + else: + execution = self._recorded_executions.get(int(execution_id)) + if execution is None: + return + + if execution_thread_id: + # SparkListenerSQLExecutionStart does not have properties, but SparkListenerJobStart does, + # use it as a source of external_id + execution.external_id = execution_thread_id + + 
execution.on_job_start(event) + + def onJobEnd(self, event): + for execution in self._recorded_executions.values(): + execution.on_job_end(event) + + def onStageSubmitted(self, event): + for execution in self._recorded_executions.values(): + execution.on_stage_start(event) + + def onStageCompleted(self, event): + for execution in self._recorded_executions.values(): + execution.on_stage_end(event) + + def onTaskStart(self, event): + for execution in self._recorded_executions.values(): + execution.on_task_start(event) + + def onTaskEnd(self, event): + for execution in self._recorded_executions.values(): + execution.on_task_end(event) diff --git a/onetl/_metrics/listener/stage.py b/onetl/_metrics/listener/stage.py new file mode 100644 index 000000000..4bf4dffb0 --- /dev/null +++ b/onetl/_metrics/listener/stage.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems) +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum + +from onetl._metrics.listener.task import SparkListenerTask, SparkListenerTaskMetrics + + +class SparkListenerStageStatus(str, Enum): + ACTIVE = "ACTIVE" + COMPLETE = "COMPLETE" + FAILED = "FAILED" + PENDING = "PENDING" + SKIPPED = "SKIPPED" + + def __str__(self): + return self.value + + +@dataclass +class SparkListenerStage: + # https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/scheduler/StageInfo.html + id: int + status: SparkListenerStageStatus = SparkListenerStageStatus.PENDING + metrics: SparkListenerTaskMetrics = field(default_factory=SparkListenerTaskMetrics, repr=False, init=False) + _tasks: dict[int, SparkListenerTask] = field(default_factory=dict, repr=False, init=False) + + @property + def tasks(self) -> list[SparkListenerTask]: + result = [] + for task_id in sorted(self._tasks.keys()): + result.append(self._tasks[task_id]) + return result + + @classmethod + def create(cls, stage_info): + return cls(id=stage_info.stageId()) + + def on_stage_start(self, event): + # https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/scheduler/SparkListenerStageSubmitted.html + self.status = SparkListenerStageStatus.ACTIVE + + def on_stage_end(self, event): + # https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/scheduler/SparkListenerStageCompleted.html + stage_info = event.stageInfo() + if stage_info.failureReason().isDefined(): + self.status = SparkListenerStageStatus.FAILED + elif not self.tasks: + self.status = SparkListenerStageStatus.SKIPPED + else: + self.status = SparkListenerStageStatus.COMPLETE + + self.metrics = SparkListenerTaskMetrics.create(stage_info.taskMetrics()) + + def on_task_start(self, event): + task_info = event.taskInfo() + task_id = task_info.taskId() + self._tasks[task_id] = SparkListenerTask.create(task_info) + + def on_task_end(self, event): + task_id = event.taskInfo().taskId() + task = self._tasks.get(task_id) + if task: + task.on_task_end(event) diff --git a/onetl/_metrics/listener/task.py b/onetl/_metrics/listener/task.py new file mode 100644 index 000000000..4b27ffcfa --- /dev/null +++ b/onetl/_metrics/listener/task.py @@ -0,0 +1,94 @@ +# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems) +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum + + +class SparkListenerTaskStatus(str, Enum): + PENDING = "PENDING" + RUNNING = "RUNNING" + SUCCESS = "SUCCESS" + FAILED = "FAILED" + KILLED = "KILLED" + + def __str__(self): 
+        return self.value
+
+
+@dataclass
+class SparkListenerTaskInputMetrics:
+    bytes_read: int = 0
+    records_read: int = 0
+
+    @classmethod
+    def create(cls, task_input_metrics):
+        return cls(
+            bytes_read=task_input_metrics.bytesRead(),
+            records_read=task_input_metrics.recordsRead(),
+        )
+
+
+@dataclass
+class SparkListenerTaskOutputMetrics:
+    bytes_written: int = 0
+    records_written: int = 0
+
+    @classmethod
+    def create(cls, task_output_metrics):
+        return cls(
+            bytes_written=task_output_metrics.bytesWritten(),
+            records_written=task_output_metrics.recordsWritten(),
+        )
+
+
+@dataclass
+class SparkListenerTaskMetrics:
+    """Python representation of the Spark TaskMetrics object.
+
+    See the ``TaskMetrics`` documentation in the Spark Java API.
+    """
+
+    executor_run_time_milliseconds: int = 0
+    executor_cpu_time_nanoseconds: int = 0
+    peak_execution_memory_bytes: int = 0
+    memory_spilled_bytes: int = 0
+    disk_spilled_bytes: int = 0
+    result_size_bytes: int = 0
+    input_metrics: SparkListenerTaskInputMetrics = field(default_factory=SparkListenerTaskInputMetrics)
+    output_metrics: SparkListenerTaskOutputMetrics = field(default_factory=SparkListenerTaskOutputMetrics)
+
+    @classmethod
+    def create(cls, task_metrics):
+        return cls(
+            executor_run_time_milliseconds=task_metrics.executorRunTime(),
+            executor_cpu_time_nanoseconds=task_metrics.executorCpuTime(),
+            peak_execution_memory_bytes=task_metrics.peakExecutionMemory(),
+            memory_spilled_bytes=task_metrics.memoryBytesSpilled(),
+            disk_spilled_bytes=task_metrics.diskBytesSpilled(),
+            result_size_bytes=task_metrics.resultSize(),
+            input_metrics=SparkListenerTaskInputMetrics.create(task_metrics.inputMetrics()),
+            output_metrics=SparkListenerTaskOutputMetrics.create(task_metrics.outputMetrics()),
+        )
+
+
+@dataclass
+class SparkListenerTask:
+    id: int
+    status: SparkListenerTaskStatus = SparkListenerTaskStatus.PENDING
+    metrics: SparkListenerTaskMetrics | None = field(default=None, repr=False, init=False)
+
+    @classmethod
+    def create(cls, task_info):
+        # https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/scheduler/TaskInfo.html
+        return cls(id=task_info.taskId())
+
+    def on_task_start(self, event):
+        # https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/scheduler/SparkListenerTaskStart.html
+        self.status = SparkListenerTaskStatus(event.taskInfo().status())
+
+    def on_task_end(self, event):
+        # https://spark.apache.org/docs/3.5.1/api/java/org/apache/spark/scheduler/SparkListenerTaskEnd.html
+        self.status = SparkListenerTaskStatus(event.taskInfo().status())
+        self.metrics = SparkListenerTaskMetrics.create(event.taskMetrics())
diff --git a/onetl/_metrics/output.py b/onetl/_metrics/output.py
new file mode 100644
index 000000000..8600bb68d
--- /dev/null
+++ b/onetl/_metrics/output.py
@@ -0,0 +1,50 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+import os
+from pprint import pformat
+
+from humanize import naturalsize
+
+from onetl.impl import BaseModel
+
+
+class SparkOutputMetrics(BaseModel):
+    written_bytes: int = 0
+    written_rows: int = 0
+    created_files: int = 0
+    created_partitions: int = 0
+
+    @property
+    def is_empty(self) -> bool:
+        return not any([self.written_bytes, self.written_rows, self.created_files])
+
+    def update(self, other: SparkOutputMetrics) -> SparkOutputMetrics:
+        self.written_bytes += other.written_bytes
+        self.written_rows += other.written_rows
+        self.created_files += other.created_files
+        self.created_partitions = max([self.created_partitions, other.created_partitions])
+        return self
+
+    @property
+    def details(self) -> str:
+        if self.is_empty:
+            return "No data"
+
+        result = []
+        result.append(f"Written rows: {pformat(self.written_rows)}")
+
+        if self.written_bytes:
+            result.append(f"Written size: {naturalsize(self.written_bytes)}")
+
+        if self.created_files:
+            result.append(f"Created files: {pformat(self.created_files)}")
+
+        if self.created_partitions:
+            result.append(f"Created partitions: {pformat(self.created_partitions)}")
+
+        return os.linesep.join(result)
+
+    def __str__(self):
+        return self.details
diff --git a/onetl/_metrics/recorder.py b/onetl/_metrics/recorder.py
new file mode 100644
index 000000000..4cc5745bb
--- /dev/null
+++ b/onetl/_metrics/recorder.py
@@ -0,0 +1,30 @@
+# SPDX-FileCopyrightText: 2021-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from onetl._metrics.command import SparkCommandMetrics
+from onetl._metrics.extract import extract_metrics_from_execution
+from onetl._metrics.listener import SparkMetricsListener
+
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+
+
+class SparkMetricsRecorder:
+    def __init__(self, spark: SparkSession):
+        self._listener = SparkMetricsListener(spark=spark)
+
+    def __enter__(self):
+        self._listener.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._listener.__exit__(exc_type, exc_val, exc_tb)
+
+    def metrics(self) -> SparkCommandMetrics:
+        result = SparkCommandMetrics()
+        for execution in self._listener.executions:
+            result = result.update(extract_metrics_from_execution(execution))
+        return result
diff --git a/onetl/_util/java.py b/onetl/_util/java.py
index df88b1a59..451114324 100644
--- a/onetl/_util/java.py
+++ b/onetl/_util/java.py
@@ -4,6 +4,9 @@
 
 from typing import TYPE_CHECKING
 
+from onetl._util.spark import get_spark_version
+from onetl._util.version import Version
+
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaGateway
     from pyspark.sql import SparkSession
@@ -24,3 +27,34 @@ def try_import_java_class(spark_session: SparkSession, name: str):
     klass = getattr(gateway.jvm, name)
     gateway.help(klass, display=False)
     return klass
+
+
+def start_callback_server(spark_session: SparkSession):
+    """
+    Start the Py4J callback server. It is required to receive Java events
+    on the Python side, e.g. in SparkListener implementations.
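+
+    Py4J can deliver JVM -> Python calls (such as SparkListener callbacks) only while
+    the callback server is running. PySpark 2.4+ ships ``ensure_callback_server_started``
+    for this; on PySpark 2.3 the server is started manually and the JVM-side
+    GatewayServer is pointed at the actually bound port.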
+ """ + gateway = get_java_gateway(spark_session) + if get_spark_version(spark_session) >= Version("2.4"): + from pyspark.java_gateway import ensure_callback_server_started + + ensure_callback_server_started(gateway) + return + + # PySpark 2.3 + if "_callback_server" not in gateway.__dict__ or gateway._callback_server is None: + from py4j.java_gateway import JavaObject + + gateway.callback_server_parameters.eager_load = True + gateway.callback_server_parameters.daemonize = True + gateway.callback_server_parameters.daemonize_connections = True + gateway.callback_server_parameters.port = 0 + gateway.start_callback_server(gateway.callback_server_parameters) + cbport = gateway._callback_server.server_socket.getsockname()[1] + gateway._callback_server.port = cbport + # gateway with real port + gateway._python_proxy_port = gateway._callback_server.port + # get the GatewayServer object in JVM by ID + java_gateway = JavaObject("GATEWAY_SERVER", gateway._gateway_client) + # update the port of CallbackClient with real port + java_gateway.resetCallbackClient(java_gateway.getCallbackClient().getAddress(), gateway._python_proxy_port) diff --git a/onetl/_util/scala.py b/onetl/_util/scala.py index 397a91576..5e6c21bc1 100644 --- a/onetl/_util/scala.py +++ b/onetl/_util/scala.py @@ -12,3 +12,10 @@ def get_default_scala_version(spark_version: Version) -> Version: if spark_version.major < 3: return Version("2.11") return Version("2.12") + + +def scala_seq_to_python_list(seq) -> list: + result = [] + for i in range(seq.size()): + result.append(seq.apply(i)) + return result diff --git a/onetl/strategy/hwm_store/__init__.py b/onetl/strategy/hwm_store/__init__.py index 0b931301e..7a0338d31 100644 --- a/onetl/strategy/hwm_store/__init__.py +++ b/onetl/strategy/hwm_store/__init__.py @@ -23,7 +23,7 @@ register_spark_type_to_hwm_type_mapping, ) -__all__ = [ # noqa: WPS410 +__all__ = [ "BaseHWMStore", "SparkTypeToHWM", "register_spark_type_to_hwm_type_mapping", diff --git a/onetl/version.py b/onetl/version.py index dada22dd7..1a3c6cecc 100644 --- a/onetl/version.py +++ b/onetl/version.py @@ -8,4 +8,4 @@ VERSION_FILE = Path(__file__).parent / "VERSION" -__version__ = VERSION_FILE.read_text().strip() # noqa: WPS410 +__version__ = VERSION_FILE.read_text().strip() diff --git a/setup.cfg b/setup.cfg index d12261ed9..7ddb67b0d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -275,7 +275,9 @@ ignore = # WPS474 Found import object collision WPS474, # WPS318 Found extra indentation - WPS318 + WPS318, +# WPS410 Found wrong metadata variable: __all__ + WPS410 # http://flake8.pycqa.org/en/latest/user/options.html?highlight=per-file-ignores#cmdoption-flake8-per-file-ignores per-file-ignores = @@ -350,6 +352,9 @@ per-file-ignores = onetl/hooks/slot.py: # WPS210 Found too many local variables WPS210, + onetl/_metrics/listener/*: +# N802 function name 'onJobStart' should be lowercase + N802, tests/*: # Found too many empty lines in `def` WPS473, diff --git a/tests/.coveragerc b/tests/.coveragerc index 08633e6cc..55af8c092 100644 --- a/tests/.coveragerc +++ b/tests/.coveragerc @@ -7,6 +7,7 @@ data_file = reports/.coverage [report] exclude_lines = pragma: no cover + no cover: start(?s:.)*?no cover: stop def __repr__ if self.debug: if settings.DEBUG diff --git a/tests/fixtures/global_hwm_store.py b/tests/fixtures/global_hwm_store.py index f10a0089d..2e006b923 100644 --- a/tests/fixtures/global_hwm_store.py +++ b/tests/fixtures/global_hwm_store.py @@ -5,7 +5,7 @@ @pytest.fixture(scope="function", autouse=True) def global_hwm_store(request): # 
noqa: WPS325 test_function = request.function - entities = test_function.__name__.split("_") if test_function else [] + entities = set(test_function.__name__.split("_")) if test_function else set() if "strategy" in entities: with MemoryHWMStore(): diff --git a/tests/fixtures/processing/fixtures.py b/tests/fixtures/processing/fixtures.py index 3f541f692..9bb62689e 100644 --- a/tests/fixtures/processing/fixtures.py +++ b/tests/fixtures/processing/fixtures.py @@ -21,10 +21,14 @@ def processing(request, spark): "kafka": ("tests.fixtures.processing.kafka", "KafkaProcessing"), } - db_storage_name = request.function.__name__.split("_")[1] - if db_storage_name not in processing_classes: - raise ValueError(f"Wrong name. Please use one of: {list(processing_classes.keys())}") + test_name_parts = set(request.function.__name__.split("_")) + matches = set(processing_classes.keys()) & test_name_parts + if not matches or len(matches) > 1: + raise ValueError( + f"Test name {request.function.__name__} should have one of these components: {list(processing_classes.keys())}", + ) + db_storage_name = matches.pop() module_name, class_name = processing_classes[db_storage_name] module = import_module(module_name) db_processing = getattr(module, class_name) diff --git a/tests/tests_integration/test_metrics/test_spark_metrics_recorder_files.py b/tests/tests_integration/test_metrics/test_spark_metrics_recorder_files.py new file mode 100644 index 000000000..e9e221e71 --- /dev/null +++ b/tests/tests_integration/test_metrics/test_spark_metrics_recorder_files.py @@ -0,0 +1,164 @@ +import time +from contextlib import suppress +from pathlib import Path + +import pytest + +from onetl._metrics.recorder import SparkMetricsRecorder +from onetl.file import FileDFReader, FileDFWriter +from onetl.file.format import CSV, JSON + +pytestmark = [ + pytest.mark.local_fs, + pytest.mark.file_df_connection, + pytest.mark.connection, + pytest.mark.csv, + # SparkListener does not give guarantees of delivering execution metrics in time + pytest.mark.flaky(reruns=5), +] + + +def test_spark_metrics_recorder_files_read( + spark, + local_fs_file_df_connection_with_path_and_files, +): + local_fs, source_path, _ = local_fs_file_df_connection_with_path_and_files + files_path: Path = source_path / "csv/with_header" + + reader = FileDFReader( + connection=local_fs, + format=CSV(header=True), + source_path=files_path, + ) + + with SparkMetricsRecorder(spark) as recorder: + df = reader.run() + df.collect() + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert metrics.input.read_rows + assert metrics.input.read_bytes + # file related metrics are too flaky to assert + + +def test_spark_metrics_recorder_files_read_no_files( + spark, + local_fs_file_df_connection_with_path, + file_df_schema, +): + local_fs, source_path = local_fs_file_df_connection_with_path + + reader = FileDFReader( + connection=local_fs, + format=CSV(), + source_path=source_path, + df_schema=file_df_schema, + ) + + with SparkMetricsRecorder(spark) as recorder: + df = reader.run() + df.collect() + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert not metrics.input.read_rows + assert not metrics.input.read_files + + +def test_spark_metrics_recorder_files_read_no_data_after_filter( + spark, + local_fs_file_df_connection_with_path_and_files, + file_df_schema, +): + local_fs, source_path, _ = local_fs_file_df_connection_with_path_and_files + files_path = source_path / 
"csv/with_header" + + reader = FileDFReader( + connection=local_fs, + format=CSV(header=True), + source_path=files_path, + df_schema=file_df_schema, + ) + + with SparkMetricsRecorder(spark) as recorder: + df = reader.run().where("str_value = 'unknown'") + df.collect() + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert not metrics.input.raw_file_bytes + # some files _may_ be scanned, but such assertions are too flaky to use + + +def test_spark_metrics_recorder_files_read_error( + spark, + local_fs_file_df_connection_with_path_and_files, +): + local_fs, source_path, _ = local_fs_file_df_connection_with_path_and_files + files_path: Path = source_path / "csv/with_header" + + reader = FileDFReader( + connection=local_fs, + format=JSON(), + source_path=files_path, + ) + + with SparkMetricsRecorder(spark) as recorder: + with suppress(Exception): + df = reader.run() + df.collect() + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + # some files metadata may be scanned, but file content was not read + assert not metrics.input.raw_file_bytes + + +def test_spark_metrics_recorder_files_write( + spark, + local_fs_file_df_connection_with_path, + file_df_dataframe, +): + local_fs, target_path = local_fs_file_df_connection_with_path + + writer = FileDFWriter( + connection=local_fs, + format=CSV(), + target_path=target_path, + options=FileDFWriter.Options(if_exists="append"), + ) + + with SparkMetricsRecorder(spark) as recorder: + writer.run(file_df_dataframe) + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert metrics.output.written_rows == file_df_dataframe.count() + assert metrics.output.written_bytes + # file related metrics are too flaky to assert + + +def test_spark_metrics_recorder_files_write_empty_input( + spark, + local_fs_file_df_connection_with_path, + file_df_dataframe, +): + local_fs, target_path = local_fs_file_df_connection_with_path + + df = file_df_dataframe.limit(0) + + writer = FileDFWriter( + connection=local_fs, + format=CSV(), + target_path=target_path, + options=FileDFWriter.Options(if_exists="append"), + ) + + with SparkMetricsRecorder(spark) as recorder: + writer.run(df) + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert not metrics.output.written_rows + assert not metrics.output.written_bytes diff --git a/tests/tests_integration/test_metrics/test_spark_metrics_recorder_hive.py b/tests/tests_integration/test_metrics/test_spark_metrics_recorder_hive.py new file mode 100644 index 000000000..7e8dc2187 --- /dev/null +++ b/tests/tests_integration/test_metrics/test_spark_metrics_recorder_hive.py @@ -0,0 +1,159 @@ +import time + +import pytest + +from onetl._metrics.recorder import SparkMetricsRecorder +from onetl.connection import Hive +from onetl.db import DBReader, DBWriter +from tests.util.rand import rand_str + +pytestmark = [ + pytest.mark.hive, + pytest.mark.db_connection, + pytest.mark.connection, + # SparkListener does not give guarantees of delivering execution metrics in time + pytest.mark.flaky(reruns=5), +] + + +def test_spark_metrics_recorder_hive_read_count(spark, load_table_data): + hive = Hive(cluster="rnd-dwh", spark=spark) + reader = DBReader( + connection=hive, + source=load_table_data.full_name, + ) + + with SparkMetricsRecorder(spark) as recorder: + df = reader.run() + rows = df.count() + + time.sleep(0.1) # sleep to fetch late metrics from 
SparkListener + metrics = recorder.metrics() + assert metrics.input.read_rows == rows + assert metrics.input.read_bytes + # in some cases files are read, in some cases only metastore statistics is used + + +def test_spark_metrics_recorder_hive_read_collect(spark, load_table_data): + hive = Hive(cluster="rnd-dwh", spark=spark) + reader = DBReader( + connection=hive, + source=load_table_data.full_name, + ) + + with SparkMetricsRecorder(spark) as recorder: + df = reader.run() + rows = len(df.collect()) + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert metrics.input.read_rows == rows + assert metrics.input.read_bytes + # file related metrics are too flaky to assert + + +def test_spark_metrics_recorder_hive_read_empty_source(spark, prepare_schema_table): + hive = Hive(cluster="rnd-dwh", spark=spark) + reader = DBReader( + connection=hive, + source=prepare_schema_table.full_name, + ) + + with SparkMetricsRecorder(spark) as recorder: + df = reader.run() + df.collect() + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert not metrics.input.read_rows + assert not metrics.input.read_bytes + + +def test_spark_metrics_recorder_hive_read_no_data_after_filter(spark, load_table_data): + hive = Hive(cluster="rnd-dwh", spark=spark) + reader = DBReader( + connection=hive, + source=load_table_data.full_name, + where="1=0", + ) + + with SparkMetricsRecorder(spark) as recorder: + df = reader.run() + df.collect() + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert not metrics.input.read_rows + assert not metrics.input.read_bytes + + +def test_spark_metrics_recorder_hive_sql(spark, load_table_data): + hive = Hive(cluster="rnd-dwh", spark=spark) + + with SparkMetricsRecorder(spark) as recorder: + df = hive.sql(f"SELECT * FROM {load_table_data.full_name}") + rows = len(df.collect()) + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert metrics.input.read_rows == rows + assert metrics.input.read_bytes + # file related metrics are too flaky to assert + + +def test_spark_metrics_recorder_hive_write(spark, processing, get_schema_table): + df = processing.create_spark_df(spark) + + hive = Hive(cluster="rnd-dwh", spark=spark) + writer = DBWriter( + connection=hive, + target=get_schema_table.full_name, + ) + + with SparkMetricsRecorder(spark) as recorder: + writer.run(df) + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert metrics.output.written_rows == df.count() + assert metrics.output.written_bytes + # file related metrics are too flaky to assert + + +def test_spark_metrics_recorder_hive_write_empty(spark, processing, get_schema_table): + df = processing.create_spark_df(spark).limit(0) + + hive = Hive(cluster="rnd-dwh", spark=spark) + writer = DBWriter( + connection=hive, + target=get_schema_table.full_name, + ) + + with SparkMetricsRecorder(spark) as recorder: + writer.run(df) + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert not metrics.output.written_rows + + +def test_spark_metrics_recorder_hive_execute(request, spark, processing, get_schema_table): + df = processing.create_spark_df(spark) + view_name = rand_str() + df.createOrReplaceTempView(view_name) + + def finalizer(): + spark.sql(f"DROP VIEW IF EXISTS {view_name}") + + request.addfinalizer(finalizer) + + hive = 
Hive(cluster="rnd-dwh", spark=spark) + + with SparkMetricsRecorder(spark) as recorder: + hive.execute(f"CREATE TABLE {get_schema_table.full_name} AS SELECT * FROM {view_name}") + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert metrics.output.written_rows == df.count() + assert metrics.output.written_bytes + # file related metrics are too flaky to assert diff --git a/tests/tests_integration/test_metrics/test_spark_metrics_recorder_postgres.py b/tests/tests_integration/test_metrics/test_spark_metrics_recorder_postgres.py new file mode 100644 index 000000000..a979f1028 --- /dev/null +++ b/tests/tests_integration/test_metrics/test_spark_metrics_recorder_postgres.py @@ -0,0 +1,198 @@ +import time + +import pytest + +from onetl._metrics.recorder import SparkMetricsRecorder +from onetl.connection import Postgres +from onetl.db import DBReader, DBWriter + +pytestmark = [ + pytest.mark.postgres, + pytest.mark.db_connection, + pytest.mark.connection, + # SparkListener does not give guarantees of delivering execution metrics in time + pytest.mark.flaky(reruns=5), +] + + +def test_spark_metrics_recorder_postgres_read(spark, processing, load_table_data): + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=postgres, + source=load_table_data.full_name, + ) + + with SparkMetricsRecorder(spark) as recorder: + df = reader.run() + rows = len(df.collect()) + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert metrics.input.read_rows == rows + # JDBC does not provide information about data size + assert not metrics.input.read_bytes + + +def test_spark_metrics_recorder_postgres_read_empty_source(spark, processing, prepare_schema_table): + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + ) + + with SparkMetricsRecorder(spark) as recorder: + df = reader.run() + df.collect() + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert not metrics.input.read_rows + + +def test_spark_metrics_recorder_postgres_read_no_data_after_filter(spark, processing, load_table_data): + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=postgres, + source=load_table_data.full_name, + where="1=0", + ) + + with SparkMetricsRecorder(spark) as recorder: + df = reader.run() + df.collect() + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert not metrics.input.read_rows + + +def test_spark_metrics_recorder_postgres_sql(spark, processing, load_table_data): + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + with SparkMetricsRecorder(spark) as recorder: + df = postgres.sql(f"SELECT * FROM {load_table_data.full_name}") + rows = len(df.collect()) + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert 
metrics.input.read_rows == rows + + +def test_spark_metrics_recorder_postgres_write(spark, processing, get_schema_table): + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + df = processing.create_spark_df(spark) + + writer = DBWriter( + connection=postgres, + target=get_schema_table.full_name, + ) + + with SparkMetricsRecorder(spark) as recorder: + writer.run(df) + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert metrics.output.written_rows == df.count() + # JDBC does not provide information about data size + assert not metrics.output.written_bytes + + +def test_spark_metrics_recorder_postgres_write_empty(spark, processing, get_schema_table): + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + df = processing.create_spark_df(spark).limit(0) + + writer = DBWriter( + connection=postgres, + target=get_schema_table.full_name, + ) + + with SparkMetricsRecorder(spark) as recorder: + writer.run(df) + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert not metrics.output.written_rows + + +def test_spark_metrics_recorder_postgres_fetch(spark, processing, load_table_data): + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + with SparkMetricsRecorder(spark) as recorder: + postgres.fetch(f"SELECT * FROM {load_table_data.full_name}") + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert not metrics.input.read_rows + + +def test_spark_metrics_recorder_postgres_execute(spark, processing, load_table_data): + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + new_table = load_table_data.full_name + "_new" + + with SparkMetricsRecorder(spark) as recorder: + postgres.execute(f"CREATE TABLE {new_table} AS SELECT * FROM {load_table_data.full_name}") + + time.sleep(0.1) # sleep to fetch late metrics from SparkListener + metrics = recorder.metrics() + assert not metrics.input.read_rows diff --git a/tests/tests_unit/test_metrics/test_spark_command_metrics.py b/tests/tests_unit/test_metrics/test_spark_command_metrics.py new file mode 100644 index 000000000..f4da30703 --- /dev/null +++ b/tests/tests_unit/test_metrics/test_spark_command_metrics.py @@ -0,0 +1,70 @@ +import textwrap +from datetime import timedelta + +from onetl._metrics.command import SparkCommandMetrics +from onetl._metrics.driver import SparkDriverMetrics +from onetl._metrics.executor import SparkExecutorMetrics +from onetl._metrics.input import SparkInputMetrics +from onetl._metrics.output import SparkOutputMetrics + + +def test_spark_metrics_command_is_empty(): + empty_metrics = SparkCommandMetrics() + assert empty_metrics.is_empty + + no_input_output = SparkCommandMetrics( + driver=SparkDriverMetrics(in_memory_bytes=1_000_000), + executor=SparkExecutorMetrics(total_run_time=timedelta(microseconds=1)), + ) + assert no_input_output.is_empty + + with_input = SparkCommandMetrics( + input=SparkInputMetrics(read_rows=1), + ) + assert not with_input.is_empty + + with_output = 
+
+
+def test_spark_metrics_command_details():
+    empty_metrics = SparkCommandMetrics()
+    assert empty_metrics.details == "No data"
+    assert str(empty_metrics) == empty_metrics.details
+
+    jdbc_fetch_metrics = SparkCommandMetrics(
+        input=SparkInputMetrics(read_rows=1_000),
+        driver=SparkDriverMetrics(in_memory_bytes=1_000_000),
+    )
+
+    expected = textwrap.dedent(
+        """
+        Input:
+            Read rows: 1000
+        Driver:
+            In-memory data (approximate): 1.0 MB
+        """,
+    )
+    assert jdbc_fetch_metrics.details == expected.strip()
+    assert str(jdbc_fetch_metrics) == jdbc_fetch_metrics.details
+
+    jdbc_write_metrics = SparkCommandMetrics(
+        output=SparkOutputMetrics(written_rows=1_000),
+        executor=SparkExecutorMetrics(
+            total_run_time=timedelta(seconds=2),
+            total_cpu_time=timedelta(seconds=1),
+        ),
+    )
+    expected = textwrap.dedent(
+        """
+        Output:
+            Written rows: 1000
+        Executor:
+            Total run time: 2 seconds
+            Total CPU time: 1 second
+        """,
+    )
+    assert jdbc_write_metrics.details == expected.strip()
+    assert str(jdbc_write_metrics) == jdbc_write_metrics.details
diff --git a/tests/tests_unit/test_metrics/test_spark_driver_metrics.py b/tests/tests_unit/test_metrics/test_spark_driver_metrics.py
new file mode 100644
index 000000000..cd4c5dc93
--- /dev/null
+++ b/tests/tests_unit/test_metrics/test_spark_driver_metrics.py
@@ -0,0 +1,22 @@
+from onetl._metrics.driver import SparkDriverMetrics
+
+
+def test_spark_metrics_driver_is_empty():
+    empty_metrics = SparkDriverMetrics()
+    assert empty_metrics.is_empty
+
+    metrics1 = SparkDriverMetrics(in_memory_bytes=1_000)
+    assert metrics1.is_empty
+
+    metrics2 = SparkDriverMetrics(in_memory_bytes=1_000_000)
+    assert not metrics2.is_empty
+
+
+def test_spark_metrics_driver_details():
+    empty_metrics = SparkDriverMetrics()
+    assert empty_metrics.details == "No data"
+    assert str(empty_metrics) == empty_metrics.details
+
+    jdbc_metrics = SparkDriverMetrics(in_memory_bytes=1_000_000)
+    assert jdbc_metrics.details == "In-memory data (approximate): 1.0 MB"
+    assert str(jdbc_metrics) == jdbc_metrics.details
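+
+# The 1.0 MB boundary above is not arbitrary: SparkDriverMetrics counts the
+# metrics payload itself as part of the driver result and ignores anything
+# below MIN_DRIVER_BYTES (1_000_000 bytes), so `is_empty` flips exactly there.
+# For reference, humanize renders values around the threshold as follows:
+#
+#     naturalsize(1_000)      # -> "1.0 kB", still considered empty
+#     naturalsize(1_000_000)  # -> "1.0 MB", first non-empty value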
diff --git a/tests/tests_unit/test_metrics/test_spark_executor_metrics.py b/tests/tests_unit/test_metrics/test_spark_executor_metrics.py
new file mode 100644
index 000000000..3acd71904
--- /dev/null
+++ b/tests/tests_unit/test_metrics/test_spark_executor_metrics.py
@@ -0,0 +1,58 @@
+import textwrap
+from datetime import timedelta
+
+from onetl._metrics.executor import SparkExecutorMetrics
+
+
+def test_spark_metrics_executor_is_empty():
+    empty_metrics = SparkExecutorMetrics()
+    assert empty_metrics.is_empty
+
+    run_metrics = SparkExecutorMetrics(
+        total_run_time=timedelta(microseconds=1),
+    )
+    assert not run_metrics.is_empty
+
+
+def test_spark_metrics_executor_details():
+    empty_metrics = SparkExecutorMetrics()
+    assert empty_metrics.details == "No data"
+    assert str(empty_metrics) == empty_metrics.details
+
+    full_metrics = SparkExecutorMetrics(
+        total_run_time=timedelta(hours=2),
+        total_cpu_time=timedelta(hours=1),
+        peak_memory_bytes=1_000_000_000,
+        memory_spilled_bytes=2_000_000_000,
+        disk_spilled_bytes=3_000_000_000,
+    )
+
+    assert (
+        full_metrics.details
+        == textwrap.dedent(
+            """
+            Total run time: 2 hours
+            Total CPU time: 1 hour
+            Peak memory: 1.0 GB
+            Memory spilled: 2.0 GB
+            Disk spilled: 3.0 GB
+            """,
+        ).strip()
+    )
+    assert str(full_metrics) == full_metrics.details
+
+    minimal_metrics = SparkExecutorMetrics(
+        total_run_time=timedelta(seconds=1),
+        total_cpu_time=timedelta(seconds=1),
+    )
+
+    assert (
+        minimal_metrics.details
+        == textwrap.dedent(
+            """
+            Total run time: 1 second
+            Total CPU time: 1 second
+            """,
+        ).strip()
+    )
+    assert str(minimal_metrics) == minimal_metrics.details
diff --git a/tests/tests_unit/test_metrics/test_spark_input_metrics.py b/tests/tests_unit/test_metrics/test_spark_input_metrics.py
new file mode 100644
index 000000000..0de1a57a4
--- /dev/null
+++ b/tests/tests_unit/test_metrics/test_spark_input_metrics.py
@@ -0,0 +1,50 @@
+import textwrap
+
+from onetl._metrics.input import SparkInputMetrics
+
+
+def test_spark_metrics_input_is_empty():
+    empty_metrics = SparkInputMetrics()
+    assert empty_metrics.is_empty
+
+    metrics1 = SparkInputMetrics(read_rows=1)
+    assert not metrics1.is_empty
+
+    metrics2 = SparkInputMetrics(read_files=1)
+    assert not metrics2.is_empty
+
+    metrics3 = SparkInputMetrics(read_bytes=1)
+    assert not metrics3.is_empty
+
+
+def test_spark_metrics_input_details():
+    empty_metrics = SparkInputMetrics()
+    assert empty_metrics.details == "No data"
+    assert str(empty_metrics) == empty_metrics.details
+
+    file_df_metrics = SparkInputMetrics(
+        read_rows=1_000,
+        read_partitions=4,
+        read_files=4,
+        read_bytes=2_000_000,
+        raw_file_bytes=5_000_000,
+    )
+
+    expected = textwrap.dedent(
+        """
+        Read rows: 1000
+        Read partitions: 4
+        Read files: 4
+        Read size: 2.0 MB
+        Raw files size: 5.0 MB
+        """,
+    )
+    assert file_df_metrics.details == expected.strip()
+    assert str(file_df_metrics) == file_df_metrics.details
+
+    jdbc_metrics = SparkInputMetrics(
+        read_rows=1_000,
+    )
+
+    assert jdbc_metrics.details == "Read rows: 1000"
+    assert str(jdbc_metrics) == jdbc_metrics.details
diff --git a/tests/tests_unit/test_metrics/test_spark_output_metrics.py b/tests/tests_unit/test_metrics/test_spark_output_metrics.py
new file mode 100644
index 000000000..e8cb9ae71
--- /dev/null
+++ b/tests/tests_unit/test_metrics/test_spark_output_metrics.py
@@ -0,0 +1,46 @@
+import textwrap
+
+from onetl._metrics.output import SparkOutputMetrics
+
+
+def test_spark_metrics_output_is_empty():
+    empty_metrics = SparkOutputMetrics()
+    assert empty_metrics.is_empty
+
+    metric1 = SparkOutputMetrics(written_rows=1)
+    assert not metric1.is_empty
+
+    metric2 = SparkOutputMetrics(written_bytes=1)
+    assert not metric2.is_empty
+
+    metric3 = SparkOutputMetrics(created_files=1)
+    assert not metric3.is_empty
+
+
+def test_spark_metrics_output_details():
+    empty_metrics = SparkOutputMetrics()
+    assert empty_metrics.details == "No data"
+    assert str(empty_metrics) == empty_metrics.details
+
+    hive_metrics = SparkOutputMetrics(
+        written_rows=1_000,
+        written_bytes=2_000_000,
+        created_files=4,
+        created_partitions=4,
+    )
+
+    expected = textwrap.dedent(
+        """
+        Written rows: 1000
+        Written size: 2.0 MB
+        Created files: 4
+        Created partitions: 4
+        """,
+    )
+    assert hive_metrics.details == expected.strip()
+    assert str(hive_metrics) == hive_metrics.details
+
+    jdbc_metrics = SparkOutputMetrics(written_rows=1_000)
+
+    assert jdbc_metrics.details == "Written rows: 1000"
+    assert str(jdbc_metrics) == jdbc_metrics.details
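+
+# A closing sketch of how these pieces compose at runtime, assuming `writer` is
+# a configured DBWriter and `df` an existing DataFrame (both placeholders):
+#
+#     with SparkMetricsRecorder(spark) as recorder:
+#         writer.run(df)
+#         metrics = recorder.metrics()   # SparkCommandMetrics
+#     print(metrics.details)
+#     # Output:
+#     #     Written rows: 1000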