[DOP-18743] Set default jobDescription #304

Merged · 1 commit · Aug 14, 2024
3 changes: 3 additions & 0 deletions docs/changelog/next_release/304.breaking.rst
@@ -0,0 +1,3 @@
Change the connection URL used for generating HWM names of S3 and Samba sources:
* ``smb://host:port`` -> ``smb://host:port/share``
* ``s3://host:port`` -> ``s3://host:port/bucket``
6 changes: 6 additions & 0 deletions docs/changelog/next_release/304.feature.rst
@@ -0,0 +1,6 @@
Generate a default ``jobDescription`` based on the currently executed method. Examples:
* ``DBWriter() -> Postgres[host:5432/database]``
* ``MongoDB[localhost:27017/admin] -> DBReader.run()``
* ``Hive[cluster].execute()``

If the user has already set a custom ``jobDescription``, it will be left intact.
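
To make the new behaviour concrete, here is a minimal illustrative sketch (editorial, not part of this diff). It assumes the usual onetl ``DBReader`` API, an existing ``spark`` session, and made-up connection parameters; the descriptions in the comments follow the examples above.

# Illustrative sketch only; connection details and table names are hypothetical.
from onetl.connection import Postgres
from onetl.db import DBReader

postgres = Postgres(
    host="host",
    port=5432,
    database="database",
    user="user",
    password="***",
    spark=spark,  # an existing SparkSession
)
reader = DBReader(connection=postgres, source="schema.table")

# With no user-set description, jobs started here would show something like
# "Postgres[host:5432/database] -> DBReader.run()" in the Spark UI.
df = reader.run()

# A description set explicitly by the user takes precedence and is left intact:
spark.sparkContext.setJobDescription("Nightly load of schema.table")
df = reader.run()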
4 changes: 2 additions & 2 deletions onetl/_util/hadoop.py
@@ -14,7 +14,7 @@ def get_hadoop_version(spark_session: SparkSession) -> Version:
"""
Get version of Hadoop libraries embedded to Spark
"""
jvm = spark_session._jvm # noqa: WPS437
jvm = spark_session._jvm # noqa: WPS437 # type: ignore[attr-defined]
version_info = jvm.org.apache.hadoop.util.VersionInfo # type: ignore[union-attr]
hadoop_version: str = version_info.getVersion()
return Version(hadoop_version)
@@ -24,4 +24,4 @@ def get_hadoop_config(spark_session: SparkSession):
"""
Get ``org.apache.hadoop.conf.Configuration`` object
"""
return spark_session.sparkContext._jsc.hadoopConfiguration()
return spark_session.sparkContext._jsc.hadoopConfiguration() # type: ignore[attr-defined]
2 changes: 1 addition & 1 deletion onetl/_util/java.py
@@ -16,7 +16,7 @@ def get_java_gateway(spark_session: SparkSession) -> JavaGateway:
"""
Get py4j Java gateway object
"""
return spark_session._sc._gateway # noqa: WPS437 # type: ignore
return spark_session._sc._gateway # noqa: WPS437 # type: ignore[attr-defined]


def try_import_java_class(spark_session: SparkSession, name: str):
23 changes: 23 additions & 0 deletions onetl/_util/spark.py
@@ -19,6 +19,9 @@
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.conf import RuntimeConfig

SPARK_JOB_DESCRIPTION_PROPERTY = "spark.job.description"
SPARK_JOB_GROUP_PROPERTY = "spark.jobGroup.id"


def stringify(value: Any, quote: bool = False) -> Any: # noqa: WPS212
"""
@@ -200,3 +203,23 @@ def get_executor_total_cores(spark_session: SparkSession, include_driver: bool =
expected_cores += 1

return expected_cores, config


@contextmanager
def override_job_description(spark_session: SparkSession, job_description: str):
"""
Override Spark job description.

Unlike ``spark_session.sparkContext.setJobDescription``, this method resets the job description
before exiting the context manager, instead of keeping it.

If the user has already set a custom description, it is left intact.
"""
spark_context = spark_session.sparkContext
original_description = spark_context.getLocalProperty(SPARK_JOB_DESCRIPTION_PROPERTY)

try:
spark_context.setLocalProperty(SPARK_JOB_DESCRIPTION_PROPERTY, original_description or job_description)
yield
finally:
spark_context.setLocalProperty(SPARK_JOB_DESCRIPTION_PROPERTY, original_description) # type: ignore[arg-type]
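
For context, a small usage sketch of the ``override_job_description`` helper added above (editorial illustration; the ``spark`` variable and the description string are assumptions, not taken from this diff):

# Illustrative usage; assumes an existing SparkSession named `spark`.
from onetl._util.spark import override_job_description

with override_job_description(spark, "Postgres[host:5432/database] -> DBReader.run()"):
    # Any Spark job triggered inside the block carries this description,
    # unless the user had already set one via spark.sparkContext.setJobDescription().
    spark.sql("SELECT 1").collect()
# On exit, the "spark.job.description" local property is restored to its previous value.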
3 changes: 3 additions & 0 deletions onetl/connection/db_connection/clickhouse/connection.py
@@ -196,6 +196,9 @@ def jdbc_params(self) -> dict:
def instance_url(self) -> str:
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}"

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}]"

@staticmethod
def _build_statement(
statement: str,
3 changes: 3 additions & 0 deletions onetl/connection/db_connection/greenplum/connection.py
@@ -267,6 +267,9 @@ def package_spark_3_2(cls) -> str:
def instance_url(self) -> str:
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}/{self.database}"

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.database}]"

@property
def jdbc_url(self) -> str:
return f"jdbc:postgresql://{self.host}:{self.port}/{self.database}"
14 changes: 10 additions & 4 deletions onetl/connection/db_connection/hive/connection.py
@@ -14,7 +14,7 @@
from pydantic import validator # type: ignore[no-redef, assignment]

from onetl._metrics.recorder import SparkMetricsRecorder
from onetl._util.spark import inject_spark_param
from onetl._util.spark import inject_spark_param, override_job_description
from onetl._util.sql import clear_statement
from onetl.connection.db_connection.db_connection import DBConnection
from onetl.connection.db_connection.hive.dialect import HiveDialect
@@ -159,6 +159,9 @@ def get_current(cls, spark: SparkSession):
def instance_url(self) -> str:
return self.cluster

def __str__(self):
return f"{self.__class__.__name__}[{self.cluster}]"

@slot
def check(self):
log.debug("|%s| Detecting current cluster...", self.__class__.__name__)
@@ -173,7 +176,8 @@ def check(self):
log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)

try:
self._execute_sql(self._CHECK_QUERY).limit(1).collect()
with override_job_description(self.spark, f"{self}.check()"):
self._execute_sql(self._CHECK_QUERY).limit(1).collect()
log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
@@ -213,7 +217,8 @@ def sql(

with SparkMetricsRecorder(self.spark) as recorder:
try:
df = self._execute_sql(query)
with override_job_description(self.spark, f"{self}.sql()"):
df = self._execute_sql(query)
except Exception:
log.error("|%s| Query failed", self.__class__.__name__)

@@ -260,7 +265,8 @@ def execute(

with SparkMetricsRecorder(self.spark) as recorder:
try:
self._execute_sql(statement).collect()
with override_job_description(self.spark, f"{self}.execute()"):
self._execute_sql(statement).collect()
except Exception:
log.error("|%s| Execution failed", self.__class__.__name__)
metrics = recorder.metrics()
4 changes: 3 additions & 1 deletion onetl/connection/db_connection/jdbc_connection/connection.py
@@ -7,6 +7,7 @@
import warnings
from typing import TYPE_CHECKING, Any

from onetl._util.spark import override_job_description
from onetl._util.sql import clear_statement
from onetl.connection.db_connection.db_connection import DBConnection
from onetl.connection.db_connection.jdbc_connection.dialect import JDBCDialect
@@ -93,7 +94,8 @@ def sql(
log_lines(log, query)

try:
df = self._query_on_executor(query, self.SQLOptions.parse(options))
with override_job_description(self.spark, f"{self}.sql()"):
df = self._query_on_executor(query, self.SQLOptions.parse(options))
except Exception:
log.error("|%s| Query failed!", self.__class__.__name__)
raise
77 changes: 42 additions & 35 deletions onetl/connection/db_connection/jdbc_mixin/connection.py
@@ -16,7 +16,12 @@

from onetl._metrics.command import SparkCommandMetrics
from onetl._util.java import get_java_gateway, try_import_java_class
from onetl._util.spark import estimate_dataframe_size, get_spark_version, stringify
from onetl._util.spark import (
estimate_dataframe_size,
get_spark_version,
override_job_description,
stringify,
)
from onetl._util.sql import clear_statement
from onetl._util.version import Version
from onetl.connection.db_connection.jdbc_mixin.options import (
@@ -209,21 +214,22 @@ def fetch(
else self.FetchOptions.parse(options)
)

try:
df = self._query_on_driver(query, call_options)
except Exception:
log.error("|%s| Query failed!", self.__class__.__name__)
raise

log.info("|%s| Query succeeded, created in-memory dataframe.", self.__class__.__name__)

# as we don't actually use Spark for this method, SparkMetricsRecorder is useless.
# Just create metrics by hand, and fill them up using information based on dataframe content.
metrics = SparkCommandMetrics()
metrics.input.read_rows = df.count()
metrics.driver.in_memory_bytes = estimate_dataframe_size(self.spark, df)
log.info("|%s| Recorded metrics:", self.__class__.__name__)
log_lines(log, str(metrics))
with override_job_description(self.spark, f"{self}.fetch()"):
try:
df = self._query_on_driver(query, call_options)
except Exception:
log.error("|%s| Query failed!", self.__class__.__name__)
raise

log.info("|%s| Query succeeded, created in-memory dataframe.", self.__class__.__name__)

# as we don't actually use Spark for this method, SparkMetricsRecorder is useless.
# Just create metrics by hand, and fill them up using information based on dataframe content.
metrics = SparkCommandMetrics()
metrics.input.read_rows = df.count()
metrics.driver.in_memory_bytes = estimate_dataframe_size(self.spark, df)
log.info("|%s| Recorded metrics:", self.__class__.__name__)
log_lines(log, str(metrics))
return df

@slot
@@ -280,25 +286,26 @@ def execute(
else self.ExecuteOptions.parse(options)
)

try:
df = self._call_on_driver(statement, call_options)
except Exception:
log.error("|%s| Execution failed!", self.__class__.__name__)
raise

if not df:
log.info("|%s| Execution succeeded, nothing returned.", self.__class__.__name__)
return None

log.info("|%s| Execution succeeded, created in-memory dataframe.", self.__class__.__name__)
# as we don't actually use Spark for this method, SparkMetricsRecorder is useless.
# Just create metrics by hand, and fill them up using information based on dataframe content.
metrics = SparkCommandMetrics()
metrics.input.read_rows = df.count()
metrics.driver.in_memory_bytes = estimate_dataframe_size(self.spark, df)

log.info("|%s| Recorded metrics:", self.__class__.__name__)
log_lines(log, str(metrics))
with override_job_description(self.spark, f"{self}.execute()"):
try:
df = self._call_on_driver(statement, call_options)
except Exception:
log.error("|%s| Execution failed!", self.__class__.__name__)
raise

if not df:
log.info("|%s| Execution succeeded, nothing returned.", self.__class__.__name__)
return None

log.info("|%s| Execution succeeded, created in-memory dataframe.", self.__class__.__name__)
# as we don't actually use Spark for this method, SparkMetricsRecorder is useless.
# Just create metrics by hand, and fill them up using information based on dataframe content.
metrics = SparkCommandMetrics()
metrics.input.read_rows = df.count()
metrics.driver.in_memory_bytes = estimate_dataframe_size(self.spark, df)

log.info("|%s| Recorded metrics:", self.__class__.__name__)
log_lines(log, str(metrics))
return df

@validator("spark")
7 changes: 5 additions & 2 deletions onetl/connection/db_connection/kafka/connection.py
@@ -497,7 +497,7 @@ def get_min_max_values(
# https://kafka.apache.org/22/javadoc/org/apache/kafka/clients/consumer/KafkaConsumer.html#partitionsFor-java.lang.String-
partition_infos = consumer.partitionsFor(source)

jvm = self.spark._jvm
jvm = self.spark._jvm # type: ignore[attr-defined]
topic_partitions = [
jvm.org.apache.kafka.common.TopicPartition(source, p.partition()) # type: ignore[union-attr]
for p in partition_infos
@@ -542,6 +542,9 @@ def get_min_max_values(
def instance_url(self):
return "kafka://" + self.cluster

def __str__(self):
return f"{self.__class__.__name__}[{self.cluster}]"

@root_validator(pre=True)
def _get_addresses_by_cluster(cls, values):
cluster = values.get("cluster")
@@ -639,7 +642,7 @@ def _get_java_consumer(self):
return consumer_class(connection_properties)

def _get_topics(self, timeout: int = 10) -> set[str]:
jvm = self.spark._jvm
jvm = self.spark._jvm # type: ignore[attr-defined]
# Maybe we should not pass explicit timeout at all,
# and instead use default.api.timeout.ms which is configurable via self.extra.
# Think about this next time if someone see issues in real use
20 changes: 14 additions & 6 deletions onetl/connection/db_connection/mongodb/connection.py
@@ -18,7 +18,7 @@
from onetl._util.classproperty import classproperty
from onetl._util.java import try_import_java_class
from onetl._util.scala import get_default_scala_version
from onetl._util.spark import get_spark_version
from onetl._util.spark import get_spark_version, override_job_description
from onetl._util.version import Version
from onetl.connection.db_connection.db_connection import DBConnection
from onetl.connection.db_connection.mongodb.dialect import MongoDBDialect
@@ -347,17 +347,25 @@
if pipeline:
read_options["aggregation.pipeline"] = json.dumps(pipeline)
read_options["connection.uri"] = self.connection_url
spark_reader = self.spark.read.format("mongodb").options(**read_options)

if df_schema:
spark_reader = spark_reader.schema(df_schema)
with override_job_description(
self.spark,
f"{self}.pipeline()",
):
spark_reader = self.spark.read.format("mongodb").options(**read_options)

return spark_reader.load()
if df_schema:
spark_reader = spark_reader.schema(df_schema)

Codecov / codecov/patch warning: added line onetl/connection/db_connection/mongodb/connection.py#L358 was not covered by tests.

return spark_reader.load()

@property
def instance_url(self) -> str:
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}/{self.database}"

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.database}]"

@slot
def check(self):
log.info("|%s| Checking connection availability...", self.__class__.__name__)
@@ -532,7 +540,7 @@
return spark

def _collection_exists(self, source: str) -> bool:
jvm = self.spark._jvm
jvm = self.spark._jvm # type: ignore[attr-defined]
client = jvm.com.mongodb.client.MongoClients.create(self.connection_url) # type: ignore
collections = set(client.getDatabase(self.database).listCollectionNames().iterator())
if source in collections:
9 changes: 9 additions & 0 deletions onetl/connection/db_connection/mssql/connection.py
@@ -268,3 +268,12 @@ def instance_url(self) -> str:
# for backward compatibility keep port number in legacy HWM instance url
port = self.port or 1433
return f"{self.__class__.__name__.lower()}://{self.host}:{port}/{self.database}"

def __str__(self):
extra_dict = self.extra.dict(by_alias=True)
instance_name = extra_dict.get("instanceName")
if instance_name:
return rf"{self.__class__.__name__}[{self.host}\{instance_name}/{self.database}]"

port = self.port or 1433
return f"{self.__class__.__name__}[{self.host}:{port}/{self.database}]"
3 changes: 3 additions & 0 deletions onetl/connection/db_connection/mysql/connection.py
@@ -175,3 +175,6 @@ def jdbc_params(self) -> dict:
@property
def instance_url(self) -> str:
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}"

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}]"
6 changes: 6 additions & 0 deletions onetl/connection/db_connection/oracle/connection.py
@@ -262,6 +262,12 @@ def instance_url(self) -> str:

return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}/{self.service_name}"

def __str__(self):
if self.sid:
return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.sid}]"

return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.service_name}]"

@slot
def get_min_max_values(
self,
3 changes: 3 additions & 0 deletions onetl/connection/db_connection/postgres/connection.py
@@ -182,6 +182,9 @@ def jdbc_params(self) -> dict[str, str]:
def instance_url(self) -> str:
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}/{self.database}"

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.database}]"

def _options_to_connection_properties(
self,
options: JDBCFetchOptions | JDBCExecuteOptions,
3 changes: 3 additions & 0 deletions onetl/connection/db_connection/teradata/connection.py
@@ -208,3 +208,6 @@ def jdbc_url(self) -> str:
@property
def instance_url(self) -> str:
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}"

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}]"
5 changes: 4 additions & 1 deletion onetl/connection/file_connection/ftp.py
@@ -105,7 +105,10 @@

@property
def instance_url(self) -> str:
return f"ftp://{self.host}:{self.port}"
return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}"

Codecov / codecov/patch warning: added line onetl/connection/file_connection/ftp.py#L108 was not covered by tests.

def __str__(self):
return f"{self.__class__.__name__}[{self.host}:{self.port}]"

@slot
def path_exists(self, path: os.PathLike | str) -> bool: