[DOP-18743] Set default jobDescription
dolfinus committed Aug 9, 2024
1 parent 3c25405 commit e34d9e5
Showing 51 changed files with 545 additions and 362 deletions.
3 changes: 3 additions & 0 deletions docs/changelog/next_release/304.breaking.rst
@@ -0,0 +1,3 @@
Change the connection URL used for generating HWM names of S3 and Samba sources:
* ``smb://host:port`` -> ``smb://host:port/share``
* ``s3://host:port`` -> ``s3://host:port/bucket``
6 changes: 6 additions & 0 deletions docs/changelog/next_release/304.feature.rst
@@ -0,0 +1,6 @@
Generate a default ``jobDescription`` based on the currently executed method. Examples:
* ``DBWriter() -> Postgres[host:5432/database]``
* ``MongoDB[localhost:27017/admin] -> DBReader.run()``
* ``Hive[cluster].execute()``

If the user already set a custom ``jobDescription``, it will be left intact.
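
A rough usage sketch of the new behaviour (host, credentials and table names below are placeholders, and a reachable Postgres instance plus its JDBC package on the Spark session are assumed):

    from pyspark.sql import SparkSession

    from onetl.connection import Postgres
    from onetl.db import DBReader

    # Spark session with the Postgres JDBC driver on the classpath
    spark = (
        SparkSession.builder
        .appName("onetl-job-description-demo")
        .config("spark.jars.packages", ",".join(Postgres.get_packages()))
        .getOrCreate()
    )

    postgres = Postgres(
        host="postgres.example.com",  # placeholder
        user="reader",
        password="***",               # placeholder
        database="appdb",
        spark=spark,
    )
    reader = DBReader(connection=postgres, source="public.orders")

    # No description set by the user: the Spark UI shows something like
    # "Postgres[postgres.example.com:5432/appdb] -> DBReader.run()" for this job.
    df = reader.run()

    # A description set by the user beforehand is kept as-is.
    spark.sparkContext.setJobDescription("nightly orders load")
    df = reader.run()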
4 changes: 2 additions & 2 deletions onetl/_util/hadoop.py
@@ -14,7 +14,7 @@ def get_hadoop_version(spark_session: SparkSession) -> Version:
"""
Get version of Hadoop libraries embedded to Spark
"""
jvm = spark_session._jvm # noqa: WPS437
jvm = spark_session._jvm # noqa: WPS437 # type: ignore[attr-defined]
version_info = jvm.org.apache.hadoop.util.VersionInfo # type: ignore[union-attr]
hadoop_version: str = version_info.getVersion()
return Version(hadoop_version)
@@ -24,4 +24,4 @@ def get_hadoop_config(spark_session: SparkSession):
"""
Get ``org.apache.hadoop.conf.Configuration`` object
"""
return spark_session.sparkContext._jsc.hadoopConfiguration()
return spark_session.sparkContext._jsc.hadoopConfiguration() # type: ignore[attr-defined]
2 changes: 1 addition & 1 deletion onetl/_util/java.py
@@ -16,7 +16,7 @@ def get_java_gateway(spark_session: SparkSession) -> JavaGateway:
"""
Get py4j Java gateway object
"""
return spark_session._sc._gateway # noqa: WPS437 # type: ignore
return spark_session._sc._gateway # noqa: WPS437 # type: ignore[attr-defined]


def try_import_java_class(spark_session: SparkSession, name: str):
23 changes: 23 additions & 0 deletions onetl/_util/spark.py
@@ -19,6 +19,9 @@
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.conf import RuntimeConfig

SPARK_JOB_DESCRIPTION_PROPERTY = "spark.job.description"
SPARK_JOB_GROUP_PROPERTY = "spark.jobGroup.id"


def stringify(value: Any, quote: bool = False) -> Any: # noqa: WPS212
"""
@@ -200,3 +203,23 @@ def get_executor_total_cores(spark_session: SparkSession, include_driver: bool =
expected_cores += 1

return expected_cores, config


@contextmanager
def override_job_description(spark_session: SparkSession, job_description: str):
"""
Override Spark job description.
Unlike ``spark_session.sparkContext.setJobDescription``, this method resets the job description
before exiting the context manager, instead of keeping it.
If the user has already set a custom description, it is left intact.
"""
spark_context = spark_session.sparkContext
original_description = spark_context.getLocalProperty(SPARK_JOB_DESCRIPTION_PROPERTY)

try:
spark_context.setLocalProperty(SPARK_JOB_DESCRIPTION_PROPERTY, original_description or job_description)
yield
finally:
spark_context.setLocalProperty(SPARK_JOB_DESCRIPTION_PROPERTY, original_description) # type: ignore[arg-type]
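
A minimal standalone sketch of how this helper behaves (``spark`` is assumed to be an existing SparkSession):

    from onetl._util.spark import override_job_description

    # No user-provided description: the temporary one is applied, then reset on exit.
    with override_job_description(spark, "Hive[cluster].execute()"):
        spark.sql("SELECT 1").collect()  # shown as "Hive[cluster].execute()" in the Spark UI

    # A description set by the user beforehand wins and survives the block.
    spark.sparkContext.setJobDescription("my custom description")
    with override_job_description(spark, "Hive[cluster].execute()"):
        spark.sql("SELECT 1").collect()  # still shown as "my custom description"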
3 changes: 3 additions & 0 deletions onetl/connection/db_connection/clickhouse/connection.py
@@ -196,6 +196,9 @@ def jdbc_params(self) -> dict:
def instance_url(self) -> str:
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}"

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}]"

@staticmethod
def _build_statement(
statement: str,
3 changes: 3 additions & 0 deletions onetl/connection/db_connection/greenplum/connection.py
@@ -267,6 +267,9 @@ def package_spark_3_2(cls) -> str:
def instance_url(self) -> str:
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}/{self.database}"

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.database}]"

@property
def jdbc_url(self) -> str:
return f"jdbc:postgresql://{self.host}:{self.port}/{self.database}"
14 changes: 10 additions & 4 deletions onetl/connection/db_connection/hive/connection.py
@@ -14,7 +14,7 @@
from pydantic import validator # type: ignore[no-redef, assignment]

from onetl._metrics.recorder import SparkMetricsRecorder
from onetl._util.spark import inject_spark_param
from onetl._util.spark import inject_spark_param, override_job_description
from onetl._util.sql import clear_statement
from onetl.connection.db_connection.db_connection import DBConnection
from onetl.connection.db_connection.hive.dialect import HiveDialect
@@ -159,6 +159,9 @@ def get_current(cls, spark: SparkSession):
def instance_url(self) -> str:
return self.cluster

def __str__(self):
return f"{self.__class__.__name__}[{self.cluster}]"

@slot
def check(self):
log.debug("|%s| Detecting current cluster...", self.__class__.__name__)
@@ -173,7 +176,8 @@ def check(self):
log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)

try:
self._execute_sql(self._CHECK_QUERY).limit(1).collect()
with override_job_description(self.spark, f"{self}.check()"):
self._execute_sql(self._CHECK_QUERY).limit(1).collect()
log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
@@ -213,7 +217,8 @@ def sql(

with SparkMetricsRecorder(self.spark) as recorder:
try:
df = self._execute_sql(query)
with override_job_description(self.spark, f"{self}.sql()"):
df = self._execute_sql(query)
except Exception:
log.error("|%s| Query failed", self.__class__.__name__)

@@ -260,7 +265,8 @@ def execute(

with SparkMetricsRecorder(self.spark) as recorder:
try:
self._execute_sql(statement).collect()
with override_job_description(self.spark, f"{self}.execute()"):
self._execute_sql(statement).collect()
except Exception:
log.error("|%s| Execution failed", self.__class__.__name__)
metrics = recorder.metrics()
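
For illustration, the wrapped Hive calls above end up tagged roughly as follows in the Spark UI (the cluster name, table names, and the ``spark`` session are assumptions):

    from onetl.connection import Hive

    hive = Hive(cluster="rnd-dwh", spark=spark)

    hive.check()                                              # "Hive[rnd-dwh].check()"
    df = hive.sql("SELECT id, amount FROM schema.orders")     # "Hive[rnd-dwh].sql()"
    hive.execute("DROP TABLE IF EXISTS schema.orders_tmp")    # "Hive[rnd-dwh].execute()"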
4 changes: 3 additions & 1 deletion onetl/connection/db_connection/jdbc_connection/connection.py
@@ -7,6 +7,7 @@
import warnings
from typing import TYPE_CHECKING, Any

from onetl._util.spark import override_job_description
from onetl._util.sql import clear_statement
from onetl.connection.db_connection.db_connection import DBConnection
from onetl.connection.db_connection.jdbc_connection.dialect import JDBCDialect
@@ -93,7 +94,8 @@ def sql(
log_lines(log, query)

try:
df = self._query_on_executor(query, self.SQLOptions.parse(options))
with override_job_description(self.spark, f"{self}.sql()"):
df = self._query_on_executor(query, self.SQLOptions.parse(options))
except Exception:
log.error("|%s| Query failed!", self.__class__.__name__)
raise
77 changes: 42 additions & 35 deletions onetl/connection/db_connection/jdbc_mixin/connection.py
@@ -16,7 +16,12 @@

from onetl._metrics.command import SparkCommandMetrics
from onetl._util.java import get_java_gateway, try_import_java_class
from onetl._util.spark import estimate_dataframe_size, get_spark_version, stringify
from onetl._util.spark import (
estimate_dataframe_size,
get_spark_version,
override_job_description,
stringify,
)
from onetl._util.sql import clear_statement
from onetl._util.version import Version
from onetl.connection.db_connection.jdbc_mixin.options import (
@@ -209,21 +214,22 @@ def fetch(
else self.FetchOptions.parse(options)
)

try:
df = self._query_on_driver(query, call_options)
except Exception:
log.error("|%s| Query failed!", self.__class__.__name__)
raise

log.info("|%s| Query succeeded, created in-memory dataframe.", self.__class__.__name__)

# as we don't actually use Spark for this method, SparkMetricsRecorder is useless.
# Just create metrics by hand, and fill them up using information based on dataframe content.
metrics = SparkCommandMetrics()
metrics.input.read_rows = df.count()
metrics.driver.in_memory_bytes = estimate_dataframe_size(self.spark, df)
log.info("|%s| Recorded metrics:", self.__class__.__name__)
log_lines(log, str(metrics))
with override_job_description(self.spark, f"{self}.fetch()"):
try:
df = self._query_on_driver(query, call_options)
except Exception:
log.error("|%s| Query failed!", self.__class__.__name__)
raise

log.info("|%s| Query succeeded, created in-memory dataframe.", self.__class__.__name__)

# as we don't actually use Spark for this method, SparkMetricsRecorder is useless.
# Just create metrics by hand, and fill them up using information based on dataframe content.
metrics = SparkCommandMetrics()
metrics.input.read_rows = df.count()
metrics.driver.in_memory_bytes = estimate_dataframe_size(self.spark, df)
log.info("|%s| Recorded metrics:", self.__class__.__name__)
log_lines(log, str(metrics))
return df

@slot
@@ -280,25 +286,26 @@ def execute(
else self.ExecuteOptions.parse(options)
)

try:
df = self._call_on_driver(statement, call_options)
except Exception:
log.error("|%s| Execution failed!", self.__class__.__name__)
raise

if not df:
log.info("|%s| Execution succeeded, nothing returned.", self.__class__.__name__)
return None

log.info("|%s| Execution succeeded, created in-memory dataframe.", self.__class__.__name__)
# as we don't actually use Spark for this method, SparkMetricsRecorder is useless.
# Just create metrics by hand, and fill them up using information based on dataframe content.
metrics = SparkCommandMetrics()
metrics.input.read_rows = df.count()
metrics.driver.in_memory_bytes = estimate_dataframe_size(self.spark, df)

log.info("|%s| Recorded metrics:", self.__class__.__name__)
log_lines(log, str(metrics))
with override_job_description(self.spark, f"{self}.execute()"):
try:
df = self._call_on_driver(statement, call_options)
except Exception:
log.error("|%s| Execution failed!", self.__class__.__name__)
raise

if not df:
log.info("|%s| Execution succeeded, nothing returned.", self.__class__.__name__)
return None

log.info("|%s| Execution succeeded, created in-memory dataframe.", self.__class__.__name__)
# as we don't actually use Spark for this method, SparkMetricsRecorder is useless.
# Just create metrics by hand, and fill them up using information based on dataframe content.
metrics = SparkCommandMetrics()
metrics.input.read_rows = df.count()
metrics.driver.in_memory_bytes = estimate_dataframe_size(self.spark, df)

log.info("|%s| Recorded metrics:", self.__class__.__name__)
log_lines(log, str(metrics))
return df

@validator("spark")
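
Continuing the earlier Postgres sketch, the driver-side helpers get the same treatment (the ``postgres`` connection and tables are assumed from that sketch):

    # Spark jobs triggered while materializing the in-memory result are
    # tagged like "Postgres[postgres.example.com:5432/appdb].fetch()".
    df = postgres.fetch("SELECT COUNT(*) AS cnt FROM public.orders")

    # Tagged like "Postgres[postgres.example.com:5432/appdb].execute()".
    postgres.execute("TRUNCATE TABLE public.orders_tmp")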
7 changes: 5 additions & 2 deletions onetl/connection/db_connection/kafka/connection.py
@@ -497,7 +497,7 @@ def get_min_max_values(
# https://kafka.apache.org/22/javadoc/org/apache/kafka/clients/consumer/KafkaConsumer.html#partitionsFor-java.lang.String-
partition_infos = consumer.partitionsFor(source)

jvm = self.spark._jvm
jvm = self.spark._jvm # type: ignore[attr-defined]
topic_partitions = [
jvm.org.apache.kafka.common.TopicPartition(source, p.partition()) # type: ignore[union-attr]
for p in partition_infos
Expand Down Expand Up @@ -542,6 +542,9 @@ def get_min_max_values(
def instance_url(self):
return "kafka://" + self.cluster

def __str__(self):
return f"{self.__class__.__name__}[{self.cluster}]"

@root_validator(pre=True)
def _get_addresses_by_cluster(cls, values):
cluster = values.get("cluster")
@@ -639,7 +642,7 @@ def _get_java_consumer(self):
return consumer_class(connection_properties)

def _get_topics(self, timeout: int = 10) -> set[str]:
jvm = self.spark._jvm
jvm = self.spark._jvm # type: ignore[attr-defined]
# Maybe we should not pass explicit timeout at all,
# and instead use default.api.timeout.ms which is configurable via self.extra.
# Think about this next time if someone see issues in real use
20 changes: 14 additions & 6 deletions onetl/connection/db_connection/mongodb/connection.py
@@ -18,7 +18,7 @@
from onetl._util.classproperty import classproperty
from onetl._util.java import try_import_java_class
from onetl._util.scala import get_default_scala_version
from onetl._util.spark import get_spark_version
from onetl._util.spark import get_spark_version, override_job_description
from onetl._util.version import Version
from onetl.connection.db_connection.db_connection import DBConnection
from onetl.connection.db_connection.mongodb.dialect import MongoDBDialect
@@ -347,17 +347,25 @@ def pipeline(
if pipeline:
read_options["aggregation.pipeline"] = json.dumps(pipeline)
read_options["connection.uri"] = self.connection_url
spark_reader = self.spark.read.format("mongodb").options(**read_options)

if df_schema:
spark_reader = spark_reader.schema(df_schema)
with override_job_description(
self.spark,
f"{self}.pipeline()",
):
spark_reader = self.spark.read.format("mongodb").options(**read_options)

return spark_reader.load()
if df_schema:
spark_reader = spark_reader.schema(df_schema)

return spark_reader.load()

@property
def instance_url(self) -> str:
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}/{self.database}"

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.database}]"

@slot
def check(self):
log.info("|%s| Checking connection availability...", self.__class__.__name__)
Expand Down Expand Up @@ -532,7 +540,7 @@ def _check_java_class_imported(cls, spark):
return spark

def _collection_exists(self, source: str) -> bool:
jvm = self.spark._jvm
jvm = self.spark._jvm # type: ignore[attr-defined]
client = jvm.com.mongodb.client.MongoClients.create(self.connection_url) # type: ignore
collections = set(client.getDatabase(self.database).listCollectionNames().iterator())
if source in collections:
9 changes: 9 additions & 0 deletions onetl/connection/db_connection/mssql/connection.py
@@ -268,3 +268,12 @@ def instance_url(self) -> str:
# for backward compatibility keep port number in legacy HWM instance url
port = self.port or 1433
return f"{self.__class__.__name__.lower()}://{self.host}:{port}/{self.database}"

def __str__(self):
extra_dict = self.extra.dict(by_alias=True)
instance_name = extra_dict.get("instanceName")
if instance_name:
return rf"{self.__class__.__name__}[{self.host}\{instance_name}/{self.database}]"

port = self.port or 1433
return f"{self.__class__.__name__}[{self.host}:{port}/{self.database}]"
3 changes: 3 additions & 0 deletions onetl/connection/db_connection/mysql/connection.py
@@ -175,3 +175,6 @@ def jdbc_params(self) -> dict:
@property
def instance_url(self) -> str:
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}"

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}]"
6 changes: 6 additions & 0 deletions onetl/connection/db_connection/oracle/connection.py
@@ -262,6 +262,12 @@ def instance_url(self) -> str:

return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}/{self.service_name}"

def __str__(self):
if self.sid:
return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.sid}]"

return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.service_name}]"

@slot
def get_min_max_values(
self,
3 changes: 3 additions & 0 deletions onetl/connection/db_connection/postgres/connection.py
@@ -182,6 +182,9 @@ def jdbc_params(self) -> dict[str, str]:
def instance_url(self) -> str:
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}/{self.database}"

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.database}]"

def _options_to_connection_properties(
self,
options: JDBCFetchOptions | JDBCExecuteOptions,
3 changes: 3 additions & 0 deletions onetl/connection/db_connection/teradata/connection.py
@@ -208,3 +208,6 @@ def jdbc_url(self) -> str:
@property
def instance_url(self) -> str:
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}"

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}]"
5 changes: 4 additions & 1 deletion onetl/connection/file_connection/ftp.py
@@ -105,7 +105,10 @@ class FTP(FileConnection, RenameDirMixin):

@property
def instance_url(self) -> str:
return f"ftp://{self.host}:{self.port}"
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}"

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}]"

@slot
def path_exists(self, path: os.PathLike | str) -> bool: