From f828c843b7887c0b36e070a9187404e16b1c2343 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9C=D0=B0=D1=80=D1=82=D1=8B=D0=BD=D0=BE=D0=B2=20=D0=9C?= =?UTF-8?q?=D0=B0=D0=BA=D1=81=D0=B8=D0=BC=20=D0=A1=D0=B5=D1=80=D0=B3=D0=B5?= =?UTF-8?q?=D0=B5=D0=B2=D0=B8=D1=87?= Date: Thu, 8 Aug 2024 16:21:10 +0000 Subject: [PATCH] DOP-18743] Set default jobDescription --- docs/changelog/next_release/304.breaking.rst | 3 + docs/changelog/next_release/304.feature.rst | 6 ++ onetl/_util/hadoop.py | 4 +- onetl/_util/java.py | 2 +- onetl/_util/spark.py | 23 ++++ onetl/base/base_db_connection.py | 3 +- onetl/base/base_file_df_connection.py | 4 +- .../db_connection/clickhouse/connection.py | 3 + .../db_connection/greenplum/connection.py | 3 + .../db_connection/hive/connection.py | 18 +++- .../jdbc_connection/connection.py | 7 +- .../db_connection/jdbc_mixin/connection.py | 54 ++++++---- .../db_connection/kafka/connection.py | 7 +- .../db_connection/mongodb/connection.py | 20 ++-- .../db_connection/mssql/connection.py | 9 ++ .../db_connection/mysql/connection.py | 3 + .../db_connection/oracle/connection.py | 43 +++----- .../db_connection/postgres/connection.py | 3 + .../db_connection/teradata/connection.py | 3 + onetl/connection/file_connection/ftp.py | 5 +- onetl/connection/file_connection/ftps.py | 4 - .../file_connection/hdfs/connection.py | 5 + onetl/connection/file_connection/s3.py | 5 +- onetl/connection/file_connection/samba.py | 5 +- onetl/connection/file_connection/sftp.py | 5 +- onetl/connection/file_connection/webdav.py | 5 +- .../spark_hdfs/connection.py | 3 + .../file_df_connection/spark_local_fs.py | 4 + .../file_df_connection/spark_s3/connection.py | 5 +- onetl/db/db_reader/db_reader.py | 92 ++++++++-------- onetl/db/db_writer/db_writer.py | 27 +++-- onetl/file/file_df_reader/file_df_reader.py | 27 +++-- onetl/file/file_df_writer/file_df_writer.py | 27 +++-- tests/fixtures/spark.py | 3 +- .../test_clickhouse_unit.py | 4 +- .../test_greenplum_unit.py | 5 +- .../test_kafka_unit.py | 4 + .../test_mongodb_unit.py | 22 ++-- .../test_mssql_unit.py | 6 +- .../test_mysql_unit.py | 6 +- .../test_oracle_unit.py | 6 +- .../test_postgres_unit.py | 5 +- .../test_teradata_unit.py | 5 +- .../test_ftp_unit.py | 33 +++--- .../test_ftps_unit.py | 33 +++--- .../test_hdfs_unit.py | 101 +++++++++--------- .../test_s3_unit.py | 56 +++++----- .../test_samba_unit.py | 39 +++---- .../test_sftp_unit.py | 46 ++++---- .../test_webdav_unit.py | 37 ++++--- .../test_spark_hdfs_unit.py | 39 +++---- .../test_spark_local_fs_unit.py | 1 + .../test_spark_s3_unit.py | 46 ++++---- 53 files changed, 548 insertions(+), 386 deletions(-) create mode 100644 docs/changelog/next_release/304.breaking.rst create mode 100644 docs/changelog/next_release/304.feature.rst diff --git a/docs/changelog/next_release/304.breaking.rst b/docs/changelog/next_release/304.breaking.rst new file mode 100644 index 000000000..605983210 --- /dev/null +++ b/docs/changelog/next_release/304.breaking.rst @@ -0,0 +1,3 @@ +Change connection URL used for generating HWM names of S3 and Samba sources: +* ``smb://host:port`` -> ``smb://host:port/share`` +* ``s3://host:port`` -> ``s3://host:port/bucket`` diff --git a/docs/changelog/next_release/304.feature.rst b/docs/changelog/next_release/304.feature.rst new file mode 100644 index 000000000..975603547 --- /dev/null +++ b/docs/changelog/next_release/304.feature.rst @@ -0,0 +1,6 @@ +Generate default ``jobDescription`` based on currently executed method. 
Examples: +* ``DBWriter() -> Postgres[host:5432/database]`` +* ``MongoDB[localhost:27017/admin] -> DBReader.run()`` +* ``Hive[cluster].execute()`` + +If user already set custom ``jobDescription``, it will left intact. diff --git a/onetl/_util/hadoop.py b/onetl/_util/hadoop.py index fdf275de7..aed572e06 100644 --- a/onetl/_util/hadoop.py +++ b/onetl/_util/hadoop.py @@ -14,7 +14,7 @@ def get_hadoop_version(spark_session: SparkSession) -> Version: """ Get version of Hadoop libraries embedded to Spark """ - jvm = spark_session._jvm # noqa: WPS437 + jvm = spark_session._jvm # noqa: WPS437 # type: ignore[attr-defined] version_info = jvm.org.apache.hadoop.util.VersionInfo # type: ignore[union-attr] hadoop_version: str = version_info.getVersion() return Version(hadoop_version) @@ -24,4 +24,4 @@ def get_hadoop_config(spark_session: SparkSession): """ Get ``org.apache.hadoop.conf.Configuration`` object """ - return spark_session.sparkContext._jsc.hadoopConfiguration() + return spark_session.sparkContext._jsc.hadoopConfiguration() # type: ignore[attr-defined] diff --git a/onetl/_util/java.py b/onetl/_util/java.py index df88b1a59..4d413d66e 100644 --- a/onetl/_util/java.py +++ b/onetl/_util/java.py @@ -13,7 +13,7 @@ def get_java_gateway(spark_session: SparkSession) -> JavaGateway: """ Get py4j Java gateway object """ - return spark_session._sc._gateway # noqa: WPS437 # type: ignore + return spark_session._sc._gateway # noqa: WPS437 # type: ignore[attr-defined] def try_import_java_class(spark_session: SparkSession, name: str): diff --git a/onetl/_util/spark.py b/onetl/_util/spark.py index f172b1c98..108f33148 100644 --- a/onetl/_util/spark.py +++ b/onetl/_util/spark.py @@ -19,6 +19,9 @@ from pyspark.sql import SparkSession from pyspark.sql.conf import RuntimeConfig +SPARK_JOB_DESCRIPTION_PROPERTY = "spark.job.description" +SPARK_JOB_GROUP_PROPERTY = "spark.jobGroup.id" + def stringify(value: Any, quote: bool = False) -> Any: # noqa: WPS212 """ @@ -185,3 +188,23 @@ def get_executor_total_cores(spark_session: SparkSession, include_driver: bool = expected_cores += 1 return expected_cores, config + + +@contextmanager +def override_job_description(spark_session: SparkSession, job_description: str): + """ + Override Spark job description. + + Unlike ``spark_session.sparkContext.setJobDescription``, this method resets job description + before exiting the context manager, instead of keeping it. + + If user set custom description, it will be left intact. 
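+
+    A minimal usage sketch (assuming ``spark`` is an already created ``SparkSession``;
+    the job label below is arbitrary)::
+
+        from onetl._util.spark import override_job_description
+
+        with override_job_description(spark, "MyJob: load users"):
+            # Spark actions triggered here are shown in the Spark UI with
+            # description "MyJob: load users", unless the user already set
+            # spark.job.description explicitly
+            spark.range(10).count()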
+ """ + spark_context = spark_session.sparkContext + original_description = spark_context.getLocalProperty(SPARK_JOB_DESCRIPTION_PROPERTY) + + try: + spark_context.setLocalProperty(SPARK_JOB_DESCRIPTION_PROPERTY, original_description or job_description) + yield + finally: + spark_context.setLocalProperty(SPARK_JOB_DESCRIPTION_PROPERTY, original_description) # type: ignore[arg-type] diff --git a/onetl/base/base_db_connection.py b/onetl/base/base_db_connection.py index f9c7bcac0..2c427debd 100644 --- a/onetl/base/base_db_connection.py +++ b/onetl/base/base_db_connection.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from etl_entities.hwm import HWM - from pyspark.sql import DataFrame + from pyspark.sql import DataFrame, SparkSession from pyspark.sql.types import StructField, StructType @@ -106,6 +106,7 @@ class BaseDBConnection(BaseConnection): Implements generic methods for reading and writing dataframe from/to database-like source """ + spark: SparkSession Dialect = BaseDBDialect @property diff --git a/onetl/base/base_file_df_connection.py b/onetl/base/base_file_df_connection.py index c54390ce8..28c57f3c7 100644 --- a/onetl/base/base_file_df_connection.py +++ b/onetl/base/base_file_df_connection.py @@ -11,7 +11,7 @@ from onetl.base.pure_path_protocol import PurePathProtocol if TYPE_CHECKING: - from pyspark.sql import DataFrame, DataFrameReader, DataFrameWriter + from pyspark.sql import DataFrame, DataFrameReader, DataFrameWriter, SparkSession from pyspark.sql.types import StructType @@ -72,6 +72,8 @@ class BaseFileDFConnection(BaseConnection): .. versionadded:: 0.9.0 """ + spark: SparkSession + @abstractmethod def check_if_format_supported( self, diff --git a/onetl/connection/db_connection/clickhouse/connection.py b/onetl/connection/db_connection/clickhouse/connection.py index 0ca6d0ce7..482cc941b 100644 --- a/onetl/connection/db_connection/clickhouse/connection.py +++ b/onetl/connection/db_connection/clickhouse/connection.py @@ -196,6 +196,9 @@ def jdbc_params(self) -> dict: def instance_url(self) -> str: return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}" + def __str__(self): + return f"{self.__class__.__name__}[{self.host}:{self.port}]" + @staticmethod def _build_statement( statement: str, diff --git a/onetl/connection/db_connection/greenplum/connection.py b/onetl/connection/db_connection/greenplum/connection.py index 7ed60539b..0f40436fc 100644 --- a/onetl/connection/db_connection/greenplum/connection.py +++ b/onetl/connection/db_connection/greenplum/connection.py @@ -267,6 +267,9 @@ def package_spark_3_2(cls) -> str: def instance_url(self) -> str: return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}/{self.database}" + def __str__(self): + return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.database}]" + @property def jdbc_url(self) -> str: return f"jdbc:postgresql://{self.host}:{self.port}/{self.database}" diff --git a/onetl/connection/db_connection/hive/connection.py b/onetl/connection/db_connection/hive/connection.py index 81c50e87e..d0c218b36 100644 --- a/onetl/connection/db_connection/hive/connection.py +++ b/onetl/connection/db_connection/hive/connection.py @@ -13,7 +13,7 @@ except (ImportError, AttributeError): from pydantic import validator # type: ignore[no-redef, assignment] -from onetl._util.spark import inject_spark_param +from onetl._util.spark import inject_spark_param, override_job_description from onetl._util.sql import clear_statement from onetl.connection.db_connection.db_connection import DBConnection from 
onetl.connection.db_connection.hive.dialect import HiveDialect @@ -158,6 +158,9 @@ def get_current(cls, spark: SparkSession): def instance_url(self) -> str: return self.cluster + def __str__(self): + return f"{self.__class__.__name__}[{self.cluster}]" + @slot def check(self): log.debug("|%s| Detecting current cluster...", self.__class__.__name__) @@ -210,7 +213,11 @@ def sql( log.info("|%s| Executing SQL query:", self.__class__.__name__) log_lines(log, query) - df = self._execute_sql(query) + with override_job_description( + self.spark, + f"{self}.sql()", + ): + df = self._execute_sql(query) log.info("|Spark| DataFrame successfully created from SQL statement") return df @@ -236,7 +243,12 @@ def execute( log.info("|%s| Executing statement:", self.__class__.__name__) log_lines(log, statement) - self._execute_sql(statement).collect() + with override_job_description( + self.spark, + f"{self}.execute()", + ): + self._execute_sql(statement).collect() + log.info("|%s| Call succeeded", self.__class__.__name__) @slot diff --git a/onetl/connection/db_connection/jdbc_connection/connection.py b/onetl/connection/db_connection/jdbc_connection/connection.py index 5b0aebeb8..7d9af47a5 100644 --- a/onetl/connection/db_connection/jdbc_connection/connection.py +++ b/onetl/connection/db_connection/jdbc_connection/connection.py @@ -7,6 +7,7 @@ import warnings from typing import TYPE_CHECKING, Any +from onetl._util.spark import override_job_description from onetl._util.sql import clear_statement from onetl.connection.db_connection.db_connection import DBConnection from onetl.connection.db_connection.jdbc_connection.dialect import JDBCDialect @@ -92,7 +93,11 @@ def sql( log.info("|%s| Executing SQL query (on executor):", self.__class__.__name__) log_lines(log, query) - df = self._query_on_executor(query, self.SQLOptions.parse(options)) + with override_job_description( + self.spark, + f"{self}.sql()", + ): + df = self._query_on_executor(query, self.SQLOptions.parse(options)) log.info("|Spark| DataFrame successfully created from SQL statement ") return df diff --git a/onetl/connection/db_connection/jdbc_mixin/connection.py b/onetl/connection/db_connection/jdbc_mixin/connection.py index e8c19e38b..fd0f7c8a8 100644 --- a/onetl/connection/db_connection/jdbc_mixin/connection.py +++ b/onetl/connection/db_connection/jdbc_mixin/connection.py @@ -17,7 +17,7 @@ from pydantic import Field, PrivateAttr, SecretStr, validator # type: ignore[no-redef, assignment] from onetl._util.java import get_java_gateway, try_import_java_class -from onetl._util.spark import get_spark_version, stringify +from onetl._util.spark import get_spark_version, override_job_description, stringify from onetl._util.sql import clear_statement from onetl._util.version import Version from onetl.connection.db_connection.jdbc_mixin.options import ( @@ -204,20 +204,23 @@ def fetch( log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__) log_lines(log, query) - df = self._query_on_driver( - query, - ( - self.FetchOptions.parse(options.dict()) # type: ignore - if isinstance(options, JDBCMixinOptions) - else self.FetchOptions.parse(options) - ), + call_options = ( + self.FetchOptions.parse(options.dict()) # type: ignore + if isinstance(options, JDBCMixinOptions) + else self.FetchOptions.parse(options) ) - log.info( - "|%s| Query succeeded, resulting in-memory dataframe contains %d rows", - self.__class__.__name__, - df.count(), - ) + with override_job_description( + self.spark, + f"{self}.fetch()", + ): + df = self._query_on_driver(query, 
call_options) + + log.info( + "|%s| Query succeeded, resulting in-memory dataframe contains %d rows", + self.__class__.__name__, + df.count(), + ) return df @slot @@ -273,17 +276,22 @@ def execute( if isinstance(options, JDBCMixinOptions) else self.ExecuteOptions.parse(options) ) - df = self._call_on_driver(statement, call_options) - if df is not None: - rows_count = df.count() - log.info( - "|%s| Execution succeeded, resulting in-memory dataframe contains %d rows", - self.__class__.__name__, - rows_count, - ) - else: - log.info("|%s| Execution succeeded, nothing returned", self.__class__.__name__) + with override_job_description( + self.spark, + f"{self}.execute()", + ): + df = self._call_on_driver(statement, call_options) + + if df is not None: + rows_count = df.count() + log.info( + "|%s| Execution succeeded, resulting in-memory dataframe contains %d rows", + self.__class__.__name__, + rows_count, + ) + else: + log.info("|%s| Execution succeeded, nothing returned", self.__class__.__name__) return df @validator("spark") diff --git a/onetl/connection/db_connection/kafka/connection.py b/onetl/connection/db_connection/kafka/connection.py index ce3829e49..b404eafba 100644 --- a/onetl/connection/db_connection/kafka/connection.py +++ b/onetl/connection/db_connection/kafka/connection.py @@ -497,7 +497,7 @@ def get_min_max_values( # https://kafka.apache.org/22/javadoc/org/apache/kafka/clients/consumer/KafkaConsumer.html#partitionsFor-java.lang.String- partition_infos = consumer.partitionsFor(source) - jvm = self.spark._jvm + jvm = self.spark._jvm # type: ignore[attr-defined] topic_partitions = [ jvm.org.apache.kafka.common.TopicPartition(source, p.partition()) # type: ignore[union-attr] for p in partition_infos @@ -542,6 +542,9 @@ def get_min_max_values( def instance_url(self): return "kafka://" + self.cluster + def __str__(self): + return f"{self.__class__.__name__}[{self.cluster}]" + @root_validator(pre=True) def _get_addresses_by_cluster(cls, values): cluster = values.get("cluster") @@ -639,7 +642,7 @@ def _get_java_consumer(self): return consumer_class(connection_properties) def _get_topics(self, timeout: int = 10) -> set[str]: - jvm = self.spark._jvm + jvm = self.spark._jvm # type: ignore[attr-defined] # Maybe we should not pass explicit timeout at all, # and instead use default.api.timeout.ms which is configurable via self.extra. 
# Think about this next time if someone see issues in real use diff --git a/onetl/connection/db_connection/mongodb/connection.py b/onetl/connection/db_connection/mongodb/connection.py index 568cd9537..f81a3bf8b 100644 --- a/onetl/connection/db_connection/mongodb/connection.py +++ b/onetl/connection/db_connection/mongodb/connection.py @@ -18,7 +18,7 @@ from onetl._util.classproperty import classproperty from onetl._util.java import try_import_java_class from onetl._util.scala import get_default_scala_version -from onetl._util.spark import get_spark_version +from onetl._util.spark import get_spark_version, override_job_description from onetl._util.version import Version from onetl.connection.db_connection.db_connection import DBConnection from onetl.connection.db_connection.mongodb.dialect import MongoDBDialect @@ -347,17 +347,25 @@ def pipeline( if pipeline: read_options["aggregation.pipeline"] = json.dumps(pipeline) read_options["connection.uri"] = self.connection_url - spark_reader = self.spark.read.format("mongodb").options(**read_options) - if df_schema: - spark_reader = spark_reader.schema(df_schema) + with override_job_description( + self.spark, + f"{self}.pipeline()", + ): + spark_reader = self.spark.read.format("mongodb").options(**read_options) - return spark_reader.load() + if df_schema: + spark_reader = spark_reader.schema(df_schema) + + return spark_reader.load() @property def instance_url(self) -> str: return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}/{self.database}" + def __str__(self): + return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.database}]" + @slot def check(self): log.info("|%s| Checking connection availability...", self.__class__.__name__) @@ -532,7 +540,7 @@ def _check_java_class_imported(cls, spark): return spark def _collection_exists(self, source: str) -> bool: - jvm = self.spark._jvm + jvm = self.spark._jvm # type: ignore[attr-defined] client = jvm.com.mongodb.client.MongoClients.create(self.connection_url) # type: ignore collections = set(client.getDatabase(self.database).listCollectionNames().iterator()) if source in collections: diff --git a/onetl/connection/db_connection/mssql/connection.py b/onetl/connection/db_connection/mssql/connection.py index 556cb4cb3..f2a29b448 100644 --- a/onetl/connection/db_connection/mssql/connection.py +++ b/onetl/connection/db_connection/mssql/connection.py @@ -268,3 +268,12 @@ def instance_url(self) -> str: # for backward compatibility keep port number in legacy HWM instance url port = self.port or 1433 return f"{self.__class__.__name__.lower()}://{self.host}:{port}/{self.database}" + + def __str__(self): + extra_dict = self.extra.dict(by_alias=True) + instance_name = extra_dict.get("instanceName") + if instance_name: + return rf"{self.__class__.__name__}[{self.host}\{instance_name}/{self.database}]" + + port = self.port or 1433 + return f"{self.__class__.__name__}[{self.host}:{port}/{self.database}]" diff --git a/onetl/connection/db_connection/mysql/connection.py b/onetl/connection/db_connection/mysql/connection.py index 72090d585..e3c911964 100644 --- a/onetl/connection/db_connection/mysql/connection.py +++ b/onetl/connection/db_connection/mysql/connection.py @@ -175,3 +175,6 @@ def jdbc_params(self) -> dict: @property def instance_url(self) -> str: return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}" + + def __str__(self): + return f"{self.__class__.__name__}[{self.host}:{self.port}]" diff --git a/onetl/connection/db_connection/oracle/connection.py 
b/onetl/connection/db_connection/oracle/connection.py index 043989500..75bc71d8c 100644 --- a/onetl/connection/db_connection/oracle/connection.py +++ b/onetl/connection/db_connection/oracle/connection.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from decimal import Decimal from textwrap import indent -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from typing import Any, ClassVar, Optional try: from pydantic.v1 import root_validator @@ -20,14 +20,12 @@ from etl_entities.instance import Host from onetl._util.classproperty import classproperty -from onetl._util.sql import clear_statement from onetl._util.version import Version from onetl.connection.db_connection.jdbc_connection import JDBCConnection from onetl.connection.db_connection.jdbc_connection.options import JDBCReadOptions from onetl.connection.db_connection.jdbc_mixin.options import ( JDBCExecuteOptions, JDBCFetchOptions, - JDBCOptions, ) from onetl.connection.db_connection.oracle.dialect import OracleDialect from onetl.connection.db_connection.oracle.options import ( @@ -42,12 +40,6 @@ from onetl.impl import GenericOptions from onetl.log import BASE_LOG_INDENT, log_lines -# do not import PySpark here, as we allow user to use `Oracle.get_packages()` for creating Spark session - - -if TYPE_CHECKING: - from pyspark.sql import DataFrame - log = logging.getLogger(__name__) # CREATE ... PROCEDURE name ... @@ -266,6 +258,12 @@ def instance_url(self) -> str: return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}/{self.service_name}" + def __str__(self): + if self.sid: + return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.sid}]" + + return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.service_name}]" + @slot def get_min_max_values( self, @@ -290,31 +288,14 @@ def get_min_max_values( max_value = int(max_value) return min_value, max_value - @slot - def execute( + def _call_on_driver( self, statement: str, - options: JDBCOptions | JDBCExecuteOptions | dict | None = None, # noqa: WPS437 - ) -> DataFrame | None: - statement = clear_statement(statement) - - log.info("|%s| Executing statement (on driver):", self.__class__.__name__) - log_lines(log, statement) - - call_options = self.ExecuteOptions.parse(options) - df = self._call_on_driver(statement, call_options) + call_options: JDBCExecuteOptions, + ): + result = super()._call_on_driver(statement, call_options) self._handle_compile_errors(statement.strip(), call_options) - - if df is not None: - rows_count = df.count() - log.info( - "|%s| Execution succeeded, resulting in-memory dataframe contains %d rows", - self.__class__.__name__, - rows_count, - ) - else: - log.info("|%s| Execution succeeded, nothing returned", self.__class__.__name__) - return df + return result @root_validator def _only_one_of_sid_or_service_name(cls, values): diff --git a/onetl/connection/db_connection/postgres/connection.py b/onetl/connection/db_connection/postgres/connection.py index 132d9727f..1c11d9e3e 100644 --- a/onetl/connection/db_connection/postgres/connection.py +++ b/onetl/connection/db_connection/postgres/connection.py @@ -182,6 +182,9 @@ def jdbc_params(self) -> dict[str, str]: def instance_url(self) -> str: return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}/{self.database}" + def __str__(self): + return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.database}]" + def _options_to_connection_properties( self, options: JDBCFetchOptions | JDBCExecuteOptions, diff --git a/onetl/connection/db_connection/teradata/connection.py 
b/onetl/connection/db_connection/teradata/connection.py index 6ef2637b4..9c8f073c5 100644 --- a/onetl/connection/db_connection/teradata/connection.py +++ b/onetl/connection/db_connection/teradata/connection.py @@ -208,3 +208,6 @@ def jdbc_url(self) -> str: @property def instance_url(self) -> str: return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}" + + def __str__(self): + return f"{self.__class__.__name__}[{self.host}:{self.port}]" diff --git a/onetl/connection/file_connection/ftp.py b/onetl/connection/file_connection/ftp.py index b457b966f..d5ff5216c 100644 --- a/onetl/connection/file_connection/ftp.py +++ b/onetl/connection/file_connection/ftp.py @@ -105,7 +105,10 @@ class FTP(FileConnection, RenameDirMixin): @property def instance_url(self) -> str: - return f"ftp://{self.host}:{self.port}" + return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}" + + def __str__(self): + return f"{self.__class__.__name__}[{self.host}:{self.port}]" @slot def path_exists(self, path: os.PathLike | str) -> bool: diff --git a/onetl/connection/file_connection/ftps.py b/onetl/connection/file_connection/ftps.py index 8cf9aa8fc..0180edf40 100644 --- a/onetl/connection/file_connection/ftps.py +++ b/onetl/connection/file_connection/ftps.py @@ -95,10 +95,6 @@ class FTPS(FTP): ) """ - @property - def instance_url(self) -> str: - return f"ftps://{self.host}:{self.port}" - def _get_client(self) -> FTPHost: """ Returns a FTPS connection object diff --git a/onetl/connection/file_connection/hdfs/connection.py b/onetl/connection/file_connection/hdfs/connection.py index 056622fbe..89c0ec961 100644 --- a/onetl/connection/file_connection/hdfs/connection.py +++ b/onetl/connection/file_connection/hdfs/connection.py @@ -264,6 +264,11 @@ def instance_url(self) -> str: return self.cluster return f"hdfs://{self.host}:{self.webhdfs_port}" + def __str__(self): + if self.cluster: + return f"{self.__class__.__name__}[{self.cluster}]" + return f"{self.__class__.__name__}[{self.host}:{self.webhdfs_port}]" + @slot def path_exists(self, path: os.PathLike | str) -> bool: return self.client.status(os.fspath(path), strict=False) diff --git a/onetl/connection/file_connection/s3.py b/onetl/connection/file_connection/s3.py index f8f584dcc..0f411c85d 100644 --- a/onetl/connection/file_connection/s3.py +++ b/onetl/connection/file_connection/s3.py @@ -131,7 +131,10 @@ def validate_port(cls, values): @property def instance_url(self) -> str: - return f"s3://{self.host}:{self.port}" + return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}/{self.bucket}" + + def __str__(self): + return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.bucket}]" @slot def create_dir(self, path: os.PathLike | str) -> RemoteDirectory: diff --git a/onetl/connection/file_connection/samba.py b/onetl/connection/file_connection/samba.py index 9fc0857fa..430e15a71 100644 --- a/onetl/connection/file_connection/samba.py +++ b/onetl/connection/file_connection/samba.py @@ -125,7 +125,10 @@ class Samba(FileConnection): @property def instance_url(self) -> str: - return f"smb://{self.host}:{self.port}" + return f"smb://{self.host}:{self.port}/{self.share}" + + def __str__(self): + return f"{self.__class__.__name__}[{self.host}:{self.port}/{self.share}]" @slot def check(self): diff --git a/onetl/connection/file_connection/sftp.py b/onetl/connection/file_connection/sftp.py index 8cd2ac1ed..92db2adce 100644 --- a/onetl/connection/file_connection/sftp.py +++ b/onetl/connection/file_connection/sftp.py @@ -120,7 +120,10 @@ class 
SFTP(FileConnection, RenameDirMixin): @property def instance_url(self) -> str: - return f"sftp://{self.host}:{self.port}" + return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}" + + def __str__(self): + return f"{self.__class__.__name__}[{self.host}:{self.port}]" @slot def path_exists(self, path: os.PathLike | str) -> bool: diff --git a/onetl/connection/file_connection/webdav.py b/onetl/connection/file_connection/webdav.py index aa540567e..44ac766a5 100644 --- a/onetl/connection/file_connection/webdav.py +++ b/onetl/connection/file_connection/webdav.py @@ -130,7 +130,10 @@ def check_port(cls, values): @property def instance_url(self) -> str: - return f"webdav://{self.host}:{self.port}" + return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}" + + def __str__(self): + return f"{self.__class__.__name__}[{self.host}:{self.port}]" @slot def path_exists(self, path: os.PathLike | str) -> bool: diff --git a/onetl/connection/file_df_connection/spark_hdfs/connection.py b/onetl/connection/file_df_connection/spark_hdfs/connection.py index 26c1416eb..10ff10058 100644 --- a/onetl/connection/file_df_connection/spark_hdfs/connection.py +++ b/onetl/connection/file_df_connection/spark_hdfs/connection.py @@ -164,6 +164,9 @@ def path_from_string(self, path: os.PathLike | str) -> Path: def instance_url(self): return self.cluster + def __str__(self): + return f"HDFS[{self.cluster}]" + def __enter__(self): return self diff --git a/onetl/connection/file_df_connection/spark_local_fs.py b/onetl/connection/file_df_connection/spark_local_fs.py index 839cbdaec..71c704145 100644 --- a/onetl/connection/file_df_connection/spark_local_fs.py +++ b/onetl/connection/file_df_connection/spark_local_fs.py @@ -74,6 +74,10 @@ def instance_url(self): fqdn = socket.getfqdn() return f"file://{fqdn}" + def __str__(self): + # str should not make network requests + return "LocalFS" + @validator("spark") def _validate_spark(cls, spark): master = spark.conf.get("spark.master") diff --git a/onetl/connection/file_df_connection/spark_s3/connection.py b/onetl/connection/file_df_connection/spark_s3/connection.py index 1efe39d4e..eb74d6981 100644 --- a/onetl/connection/file_df_connection/spark_s3/connection.py +++ b/onetl/connection/file_df_connection/spark_s3/connection.py @@ -256,7 +256,10 @@ def path_from_string(self, path: os.PathLike | str) -> RemotePath: @property def instance_url(self): - return f"s3://{self.host}:{self.port}" + return f"s3://{self.host}:{self.port}/{self.bucket}" + + def __str__(self): + return f"S3[{self.host}:{self.port}/{self.bucket}]" def __enter__(self): return self diff --git a/onetl/db/db_reader/db_reader.py b/onetl/db/db_reader/db_reader.py index 91b3f21b1..f560104d4 100644 --- a/onetl/db/db_reader/db_reader.py +++ b/onetl/db/db_reader/db_reader.py @@ -17,7 +17,7 @@ except (ImportError, AttributeError): from pydantic import Field, PrivateAttr, root_validator, validator # type: ignore[no-redef, assignment] -from onetl._util.spark import try_import_pyspark +from onetl._util.spark import override_job_description, try_import_pyspark from onetl.base import ( BaseDBConnection, ContainsGetDFSchemaMethod, @@ -542,26 +542,30 @@ def has_data(self) -> bool: """ self._check_strategy() - if not self._connection_checked: - self._log_parameters() - self.connection.check() - - window, limit = self._calculate_window_and_limit() - if limit == 0: - return False - - df = self.connection.read_source_as_df( - source=str(self.source), - columns=self.columns, - hint=self.hint, - where=self.where, - 
df_schema=self.df_schema, - window=window, - limit=1, - **self._get_read_kwargs(), - ) + with override_job_description( + self.connection.spark, + f"{self.connection} -> {self.__class__.__name__}.has_data()", + ): + if not self._connection_checked: + self._log_parameters() + self.connection.check() + + window, limit = self._calculate_window_and_limit() + if limit == 0: + return False + + df = self.connection.read_source_as_df( + source=str(self.source), + columns=self.columns, + hint=self.hint, + where=self.where, + df_schema=self.df_schema, + window=window, + limit=1, + **self._get_read_kwargs(), + ) - return bool(df.take(1)) + return bool(df.take(1)) @slot def raise_if_no_data(self) -> None: @@ -633,28 +637,32 @@ def run(self) -> DataFrame: self._check_strategy() - if not self._connection_checked: - self._log_parameters() - self.connection.check() - self._connection_checked = True - - window, limit = self._calculate_window_and_limit() - - # update the HWM with the stop value - if self.hwm and window: - strategy: HWMStrategy = StrategyManager.get_current() # type: ignore[assignment] - strategy.update_hwm(window.stop_at.value) - - df = self.connection.read_source_as_df( - source=str(self.source), - columns=self.columns, - hint=self.hint, - where=self.where, - df_schema=self.df_schema, - window=window, - limit=limit, - **self._get_read_kwargs(), - ) + with override_job_description( + self.connection.spark, + f"{self.connection} -> {self.__class__.__name__}.run()", + ): + if not self._connection_checked: + self._log_parameters() + self.connection.check() + self._connection_checked = True + + window, limit = self._calculate_window_and_limit() + + # update the HWM with the stop value + if self.hwm and window: + strategy: HWMStrategy = StrategyManager.get_current() # type: ignore[assignment] + strategy.update_hwm(window.stop_at.value) + + df = self.connection.read_source_as_df( + source=str(self.source), + columns=self.columns, + hint=self.hint, + where=self.where, + df_schema=self.df_schema, + window=window, + limit=limit, + **self._get_read_kwargs(), + ) entity_boundary_log(log, msg=f"{self.__class__.__name__}.run() ends", char="-") return df diff --git a/onetl/db/db_writer/db_writer.py b/onetl/db/db_writer/db_writer.py index 666fce87e..3ef510674 100644 --- a/onetl/db/db_writer/db_writer.py +++ b/onetl/db/db_writer/db_writer.py @@ -10,6 +10,7 @@ except (ImportError, AttributeError): from pydantic import Field, PrivateAttr, validator # type: ignore[no-redef, assignment] +from onetl._util.spark import override_job_description from onetl.base import BaseDBConnection from onetl.hooks import slot, support_hooks from onetl.impl import FrozenModel, GenericOptions @@ -199,17 +200,21 @@ def run(self, df: DataFrame): entity_boundary_log(log, msg=f"{self.__class__.__name__}.run() starts") - if not self._connection_checked: - self._log_parameters() - log_dataframe_schema(log, df) - self.connection.check() - self._connection_checked = True - - self.connection.write_df_to_target( - df=df, - target=str(self.target), - **self._get_write_kwargs(), - ) + with override_job_description( + self.connection.spark, + f"{self.__class__.__name__}.run() -> {self.connection}", + ): + if not self._connection_checked: + self._log_parameters() + log_dataframe_schema(log, df) + self.connection.check() + self._connection_checked = True + + self.connection.write_df_to_target( + df=df, + target=str(self.target), + **self._get_write_kwargs(), + ) entity_boundary_log(log, msg=f"{self.__class__.__name__}.run() ends", char="-") 
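For illustration, a sketch of the default job descriptions produced by the DBReader and DBWriter changes above. The connection parameters, table names and the ``spark`` session variable are assumptions for the example; onETL only sets the description, it does not print it:

    from onetl.connection import Hive, Postgres
    from onetl.db import DBReader, DBWriter

    postgres = Postgres(host="pg.example.com", user="app", password="***", database="appdb", spark=spark)
    hive = Hive(cluster="rnd-dwh", spark=spark)

    # Spark jobs started inside run() get description "Postgres[pg.example.com:5432/appdb] -> DBReader.run()"
    df = DBReader(connection=postgres, source="public.users").run()

    # Spark jobs started inside run() get description "DBWriter.run() -> Hive[rnd-dwh]"
    DBWriter(connection=hive, target="dm.users").run(df)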
diff --git a/onetl/file/file_df_reader/file_df_reader.py b/onetl/file/file_df_reader/file_df_reader.py index b18fc1792..f1e2f01eb 100644 --- a/onetl/file/file_df_reader/file_df_reader.py +++ b/onetl/file/file_df_reader/file_df_reader.py @@ -13,7 +13,7 @@ except (ImportError, AttributeError): from pydantic import PrivateAttr, validator # type: ignore[no-redef, assignment] -from onetl._util.spark import try_import_pyspark +from onetl._util.spark import override_job_description, try_import_pyspark from onetl.base import BaseFileDFConnection, BaseReadableFileFormat, PurePathProtocol from onetl.file.file_df_reader.options import FileDFReaderOptions from onetl.file.file_set import FileSet @@ -211,18 +211,23 @@ def run(self, files: Iterable[str | os.PathLike] | None = None) -> DataFrame: if not self._connection_checked: self._log_parameters(files) - paths: FileSet[PurePathProtocol] = FileSet() - if files is not None: - paths = FileSet(self._validate_files(files)) - elif self.source_path: - paths = FileSet([self.source_path]) + with override_job_description( + self.connection.spark, + f"{self.connection} -> {self.__class__.__name__}.run()", + ): + paths: FileSet[PurePathProtocol] = FileSet() + if files is not None: + paths = FileSet(self._validate_files(files)) + elif self.source_path: + paths = FileSet([self.source_path]) - if not self._connection_checked: - self.connection.check() - log_with_indent(log, "") - self._connection_checked = True + if not self._connection_checked: + self.connection.check() + log_with_indent(log, "") + self._connection_checked = True + + df = self._read_files(paths) - df = self._read_files(paths) entity_boundary_log(log, msg=f"{self.__class__.__name__}.run() ends", char="-") return df diff --git a/onetl/file/file_df_writer/file_df_writer.py b/onetl/file/file_df_writer/file_df_writer.py index a80f54801..87ea13179 100644 --- a/onetl/file/file_df_writer/file_df_writer.py +++ b/onetl/file/file_df_writer/file_df_writer.py @@ -10,6 +10,7 @@ except (ImportError, AttributeError): from pydantic import PrivateAttr, validator # type: ignore[no-redef, assignment] +from onetl._util.spark import override_job_description from onetl.base import BaseFileDFConnection, BaseWritableFileFormat, PurePathProtocol from onetl.file.file_df_writer.options import FileDFWriterOptions from onetl.hooks import slot, support_hooks @@ -120,17 +121,21 @@ def run(self, df: DataFrame) -> None: if df.isStreaming: raise ValueError(f"DataFrame is streaming. 
{self.__class__.__name__} supports only batch DataFrames.") - if not self._connection_checked: - self._log_parameters(df) - self.connection.check() - self._connection_checked = True - - self.connection.write_df_as_files( - df=df, - path=self.target_path, - format=self.format, - options=self.options, - ) + with override_job_description( + self.connection.spark, + f"{self.__class__.__name__}.run() -> {self.connection}", + ): + if not self._connection_checked: + self._log_parameters(df) + self.connection.check() + self._connection_checked = True + + self.connection.write_df_as_files( + df=df, + path=self.target_path, + format=self.format, + options=self.options, + ) entity_boundary_log(log, f"{self.__class__.__name__}.run() ends", char="-") diff --git a/tests/fixtures/spark.py b/tests/fixtures/spark.py index e7248e84f..7a9b812a4 100644 --- a/tests/fixtures/spark.py +++ b/tests/fixtures/spark.py @@ -123,12 +123,11 @@ def excluded_packages(): @pytest.fixture( scope="session", - name="spark", params=[ pytest.param("real-spark", marks=[pytest.mark.db_connection, pytest.mark.connection]), ], ) -def get_spark_session(warehouse_dir, spark_metastore_dir, ivysettings_path, maven_packages, excluded_packages): +def spark(warehouse_dir, spark_metastore_dir, ivysettings_path, maven_packages, excluded_packages): from pyspark.sql import SparkSession spark = ( diff --git a/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py b/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py index ff36e0a66..287061d25 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py @@ -128,10 +128,10 @@ def test_clickhouse(spark_mock): "url": "jdbc:clickhouse://some_host:8123/database", } - assert "password='passwd'" not in str(conn) - assert "password='passwd'" not in repr(conn) + assert "passwd" not in repr(conn) assert conn.instance_url == "clickhouse://some_host:8123" + assert str(conn) == "Clickhouse[some_host:8123]" def test_clickhouse_with_port(spark_mock): diff --git a/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py b/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py index 0d382d44d..47821642b 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py @@ -129,10 +129,10 @@ def test_greenplum(spark_mock): "tcpKeepAlive": "true", } - assert "password='passwd'" not in str(conn) - assert "password='passwd'" not in repr(conn) + assert "passwd" not in repr(conn) assert conn.instance_url == "greenplum://some_host:5432/database" + assert str(conn) == "Greenplum[some_host:5432/database]" def test_greenplum_with_port(spark_mock): @@ -156,6 +156,7 @@ def test_greenplum_with_port(spark_mock): } assert conn.instance_url == "greenplum://some_host:5000/database" + assert str(conn) == "Greenplum[some_host:5000/database]" def test_greenplum_without_database_error(spark_mock): diff --git a/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py b/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py index 2e0ccd1a0..741013885 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py @@ -181,6 +181,7 @@ def test_kafka_basic_auth_get_jaas_conf(spark_mock): assert conn.addresses == ["192.168.1.1"] assert conn.instance_url == "kafka://some_cluster" + assert str(conn) == "Kafka[some_cluster]" def 
test_kafka_anon_auth(spark_mock): @@ -194,6 +195,7 @@ def test_kafka_anon_auth(spark_mock): assert conn.addresses == ["192.168.1.1"] assert conn.instance_url == "kafka://some_cluster" + assert str(conn) == "Kafka[some_cluster]" @pytest.mark.parametrize("digest", ["SHA-256", "SHA-512"]) @@ -217,6 +219,7 @@ def test_kafka_scram_auth(spark_mock, digest): assert conn.addresses == ["192.168.1.1"] assert conn.instance_url == "kafka://some_cluster" + assert str(conn) == "Kafka[some_cluster]" def test_kafka_auth_keytab(spark_mock, create_keytab): @@ -235,6 +238,7 @@ def test_kafka_auth_keytab(spark_mock, create_keytab): assert conn.addresses == ["192.168.1.1"] assert conn.instance_url == "kafka://some_cluster" + assert str(conn) == "Kafka[some_cluster]" def test_kafka_empty_addresses(spark_mock): diff --git a/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py index f494e3deb..9142848e1 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py @@ -126,9 +126,10 @@ def test_mongodb(spark_mock): assert conn.database == "database" assert conn.connection_url == "mongodb://user:password@host:27017/database" + assert conn.instance_url == "mongodb://host:27017/database" + assert str(conn) == "MongoDB[host:27017/database]" - assert "password='passwd'" not in str(conn) - assert "password='passwd'" not in repr(conn) + assert "passwd" not in repr(conn) @pytest.mark.parametrize( @@ -150,7 +151,7 @@ def test_mongodb_options_hint(): def test_mongodb_with_port(spark_mock): - mongo = MongoDB( + conn = MongoDB( host="host", user="user", password="password", @@ -159,14 +160,15 @@ def test_mongodb_with_port(spark_mock): spark=spark_mock, ) - assert mongo.host == "host" - assert mongo.port == 12345 - assert mongo.user == "user" - assert mongo.password != "password" - assert mongo.password.get_secret_value() == "password" - assert mongo.database == "database" + assert conn.host == "host" + assert conn.port == 12345 + assert conn.user == "user" + assert conn.password != "password" + assert conn.password.get_secret_value() == "password" + assert conn.database == "database" - assert mongo.connection_url == "mongodb://user:password@host:12345/database" + assert conn.connection_url == "mongodb://user:password@host:12345/database" + assert conn.instance_url == "mongodb://host:12345/database" def test_mongodb_without_mandatory_args(spark_mock): diff --git a/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py index e1a18aa9d..d9f3cfdab 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py @@ -101,10 +101,10 @@ def test_mssql(spark_mock): "databaseName": "database", } - assert "password='passwd'" not in str(conn) - assert "password='passwd'" not in repr(conn) + assert "passwd" not in repr(conn) assert conn.instance_url == "mssql://some_host:1433/database" + assert str(conn) == "MSSQL[some_host:1433/database]" def test_mssql_with_custom_port(spark_mock): @@ -127,6 +127,7 @@ def test_mssql_with_custom_port(spark_mock): } assert conn.instance_url == "mssql://some_host:5000/database" + assert str(conn) == "MSSQL[some_host:5000/database]" def test_mssql_with_instance_name(spark_mock): @@ -157,6 +158,7 @@ def test_mssql_with_instance_name(spark_mock): } assert conn.instance_url == 
"mssql://some_host\\myinstance/database" + assert str(conn) == "MSSQL[some_host\\myinstance/database]" def test_mssql_without_database_error(spark_mock): diff --git a/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py index f2c68d939..0d57da488 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py @@ -89,10 +89,10 @@ def test_mysql(spark_mock): "useUnicode": "yes", } - assert "password='passwd'" not in str(conn) - assert "password='passwd'" not in repr(conn) + assert "passwd" not in repr(conn) assert conn.instance_url == "mysql://some_host:3306" + assert str(conn) == "MySQL[some_host:3306]" def test_mysql_with_port(spark_mock): @@ -116,6 +116,7 @@ def test_mysql_with_port(spark_mock): } assert conn.instance_url == "mysql://some_host:5000" + assert str(conn) == "MySQL[some_host:5000]" def test_mysql_without_database(spark_mock): @@ -139,6 +140,7 @@ def test_mysql_without_database(spark_mock): } assert conn.instance_url == "mysql://some_host:3306" + assert str(conn) == "MySQL[some_host:3306]" def test_mysql_with_extra(spark_mock): diff --git a/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py b/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py index ae7bf87cd..dd02b5c95 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py @@ -110,10 +110,10 @@ def test_oracle(spark_mock): "url": "jdbc:oracle:thin:@some_host:1521:sid", } - assert "password='passwd'" not in str(conn) - assert "password='passwd'" not in repr(conn) + assert "passwd" not in repr(conn) assert conn.instance_url == "oracle://some_host:1521/sid" + assert str(conn) == "Oracle[some_host:1521/sid]" def test_oracle_with_port(spark_mock): @@ -135,6 +135,7 @@ def test_oracle_with_port(spark_mock): } assert conn.instance_url == "oracle://some_host:5000/sid" + assert str(conn) == "Oracle[some_host:5000/sid]" def test_oracle_uri_with_service_name(spark_mock): @@ -149,6 +150,7 @@ def test_oracle_uri_with_service_name(spark_mock): } assert conn.instance_url == "oracle://some_host:1521/service" + assert str(conn) == "Oracle[some_host:1521/service]" def test_oracle_without_sid_and_service_name(spark_mock): diff --git a/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py b/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py index 6e37417aa..2b0080bf8 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py @@ -90,10 +90,10 @@ def test_postgres(spark_mock): "stringtype": "unspecified", } - assert "password='passwd'" not in str(conn) - assert "password='passwd'" not in repr(conn) + assert "passwd" not in repr(conn) assert conn.instance_url == "postgres://some_host:5432/database" + assert str(conn) == "Postgres[some_host:5432/database]" def test_postgres_with_port(spark_mock): @@ -118,6 +118,7 @@ def test_postgres_with_port(spark_mock): } assert conn.instance_url == "postgres://some_host:5000/database" + assert str(conn) == "Postgres[some_host:5000/database]" def test_postgres_without_database_error(spark_mock): diff --git a/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py b/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py index bef65a554..557d4f29b 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py +++ 
b/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py @@ -89,10 +89,10 @@ def test_teradata(spark_mock): "url": conn.jdbc_url, } - assert "password='passwd'" not in str(conn) - assert "password='passwd'" not in repr(conn) + assert "passwd" not in repr(conn) assert conn.instance_url == "teradata://some_host:1025" + assert str(conn) == "Teradata[some_host:1025]" def test_teradata_with_port(spark_mock): @@ -117,6 +117,7 @@ def test_teradata_with_port(spark_mock): } assert conn.instance_url == "teradata://some_host:5000" + assert str(conn) == "Teradata[some_host:1025]" def test_teradata_without_database(spark_mock): diff --git a/tests/tests_unit/tests_file_connection_unit/test_ftp_unit.py b/tests/tests_unit/tests_file_connection_unit/test_ftp_unit.py index ab47c2482..33f6b29f7 100644 --- a/tests/tests_unit/tests_file_connection_unit/test_ftp_unit.py +++ b/tests/tests_unit/tests_file_connection_unit/test_ftp_unit.py @@ -8,35 +8,34 @@ def test_ftp_connection(): from onetl.connection import FTP - ftp = FTP(host="some_host", user="some_user", password="pwd") - assert isinstance(ftp, FileConnection) - assert ftp.host == "some_host" - assert ftp.user == "some_user" - assert ftp.password != "pwd" - assert ftp.password.get_secret_value() == "pwd" - assert ftp.port == 21 + conn = FTP(host="some_host", user="some_user", password="pwd") + assert isinstance(conn, FileConnection) + assert conn.host == "some_host" + assert conn.user == "some_user" + assert conn.password != "pwd" + assert conn.password.get_secret_value() == "pwd" + assert conn.port == 21 - assert "password='pwd'" not in str(ftp) - assert "password='pwd'" not in repr(ftp) + assert str(conn) == "FTP[some_host:21]" + assert "pwd" not in repr(conn) def test_ftp_connection_anonymous(): from onetl.connection import FTP - ftp = FTP(host="some_host") - - assert isinstance(ftp, FileConnection) - assert ftp.host == "some_host" - assert ftp.user is None - assert ftp.password is None + conn = FTP(host="some_host") + assert conn.host == "some_host" + assert conn.user is None + assert conn.password is None def test_ftp_connection_with_port(): from onetl.connection import FTP - ftp = FTP(host="some_host", user="some_user", password="pwd", port=500) + conn = FTP(host="some_host", user="some_user", password="pwd", port=500) - assert ftp.port == 500 + assert conn.port == 500 + assert str(conn) == "FTP[some_host:500]" def test_ftp_connection_without_mandatory_args(): diff --git a/tests/tests_unit/tests_file_connection_unit/test_ftps_unit.py b/tests/tests_unit/tests_file_connection_unit/test_ftps_unit.py index aa63de1ef..c0b201e6c 100644 --- a/tests/tests_unit/tests_file_connection_unit/test_ftps_unit.py +++ b/tests/tests_unit/tests_file_connection_unit/test_ftps_unit.py @@ -8,35 +8,36 @@ def test_ftps_connection(): from onetl.connection import FTPS - ftps = FTPS(host="some_host", user="some_user", password="pwd") - assert isinstance(ftps, FileConnection) - assert ftps.host == "some_host" - assert ftps.user == "some_user" - assert ftps.password != "pwd" - assert ftps.password.get_secret_value() == "pwd" - assert ftps.port == 21 + conn = FTPS(host="some_host", user="some_user", password="pwd") + assert isinstance(conn, FileConnection) + assert conn.host == "some_host" + assert conn.user == "some_user" + assert conn.password != "pwd" + assert conn.password.get_secret_value() == "pwd" + assert conn.port == 21 - assert "password='pwd'" not in str(ftps) - assert "password='pwd'" not in repr(ftps) + assert str(conn) == "FTPS[some_host:21]" + assert "pwd" 
not in repr(conn) def test_ftps_connection_anonymous(): from onetl.connection import FTPS - ftps = FTPS(host="some_host") + conn = FTPS(host="some_host") - assert isinstance(ftps, FileConnection) - assert ftps.host == "some_host" - assert ftps.user is None - assert ftps.password is None + assert isinstance(conn, FileConnection) + assert conn.host == "some_host" + assert conn.user is None + assert conn.password is None def test_ftps_connection_with_port(): from onetl.connection import FTPS - ftps = FTPS(host="some_host", user="some_user", password="pwd", port=500) + conn = FTPS(host="some_host", user="some_user", password="pwd", port=500) - assert ftps.port == 500 + assert conn.port == 500 + assert str(conn) == "FTPS[some_host:500]" def test_ftps_connection_without_mandatory_args(): diff --git a/tests/tests_unit/tests_file_connection_unit/test_hdfs_unit.py b/tests/tests_unit/tests_file_connection_unit/test_hdfs_unit.py index 2249e2373..3a450208c 100644 --- a/tests/tests_unit/tests_file_connection_unit/test_hdfs_unit.py +++ b/tests/tests_unit/tests_file_connection_unit/test_hdfs_unit.py @@ -15,73 +15,74 @@ def test_hdfs_connection_with_host(): from onetl.connection import HDFS - hdfs = HDFS(host="some-host.domain.com") - assert isinstance(hdfs, FileConnection) - assert hdfs.host == "some-host.domain.com" - assert hdfs.webhdfs_port == 50070 - assert not hdfs.user - assert not hdfs.password - assert not hdfs.keytab - assert hdfs.instance_url == "hdfs://some-host.domain.com:50070" + conn = HDFS(host="some-host.domain.com") + assert isinstance(conn, FileConnection) + assert conn.host == "some-host.domain.com" + assert conn.webhdfs_port == 50070 + assert not conn.user + assert not conn.password + assert not conn.keytab + assert conn.instance_url == "hdfs://some-host.domain.com:50070" + assert str(conn) == "HDFS[some-host.domain.com:50070]" def test_hdfs_connection_with_cluster(): from onetl.connection import HDFS - hdfs = HDFS(cluster="rnd-dwh") - assert isinstance(hdfs, FileConnection) - assert hdfs.cluster == "rnd-dwh" - assert hdfs.webhdfs_port == 50070 - assert not hdfs.user - assert not hdfs.password - assert not hdfs.keytab - assert hdfs.instance_url == "rnd-dwh" + conn = HDFS(cluster="rnd-dwh") + assert conn.cluster == "rnd-dwh" + assert conn.webhdfs_port == 50070 + assert not conn.user + assert not conn.password + assert not conn.keytab + assert conn.instance_url == "rnd-dwh" + assert str(conn) == "HDFS[rnd-dwh]" def test_hdfs_connection_with_cluster_and_host(): from onetl.connection import HDFS - hdfs = HDFS(cluster="rnd-dwh", host="some-host.domain.com") - assert isinstance(hdfs, FileConnection) - assert hdfs.cluster == "rnd-dwh" - assert hdfs.host == "some-host.domain.com" - assert hdfs.instance_url == "rnd-dwh" + conn = HDFS(cluster="rnd-dwh", host="some-host.domain.com") + assert conn.cluster == "rnd-dwh" + assert conn.host == "some-host.domain.com" + assert conn.instance_url == "rnd-dwh" + assert str(conn) == "HDFS[rnd-dwh]" def test_hdfs_connection_with_port(): from onetl.connection import HDFS - hdfs = HDFS(host="some-host.domain.com", port=9080) - assert isinstance(hdfs, FileConnection) - assert hdfs.host == "some-host.domain.com" - assert hdfs.webhdfs_port == 9080 - assert hdfs.instance_url == "hdfs://some-host.domain.com:9080" + conn = HDFS(host="some-host.domain.com", port=9080) + assert conn.host == "some-host.domain.com" + assert conn.webhdfs_port == 9080 + assert conn.instance_url == "hdfs://some-host.domain.com:9080" + assert str(conn) == "HDFS[some-host.domain.com:9080]" 
def test_hdfs_connection_with_user(): from onetl.connection import HDFS - hdfs = HDFS(host="some-host.domain.com", user="some_user") - assert hdfs.host == "some-host.domain.com" - assert hdfs.webhdfs_port == 50070 - assert hdfs.user == "some_user" - assert not hdfs.password - assert not hdfs.keytab + conn = HDFS(host="some-host.domain.com", user="some_user") + assert conn.host == "some-host.domain.com" + assert conn.webhdfs_port == 50070 + assert conn.user == "some_user" + assert not conn.password + assert not conn.keytab def test_hdfs_connection_with_password(): from onetl.connection import HDFS - hdfs = HDFS(host="some-host.domain.com", user="some_user", password="pwd") - assert hdfs.host == "some-host.domain.com" - assert hdfs.webhdfs_port == 50070 - assert hdfs.user == "some_user" - assert hdfs.password != "pwd" - assert hdfs.password.get_secret_value() == "pwd" - assert not hdfs.keytab + conn = HDFS(host="some-host.domain.com", user="some_user", password="pwd") + assert conn.host == "some-host.domain.com" + assert conn.webhdfs_port == 50070 + assert conn.user == "some_user" + assert conn.password != "pwd" + assert conn.password.get_secret_value() == "pwd" + assert not conn.keytab + assert str(conn) == "HDFS[some-host.domain.com:50070]" - assert "password='pwd'" not in str(hdfs) - assert "password='pwd'" not in repr(hdfs) + assert "pwd" not in repr(conn) def test_hdfs_connection_with_keytab(request, tmp_path_factory): @@ -91,15 +92,15 @@ def test_hdfs_connection_with_keytab(request, tmp_path_factory): folder.mkdir(exist_ok=True, parents=True) keytab = folder / "user.keytab" keytab.touch() - hdfs = HDFS(host="some-host.domain.com", user="some_user", keytab=keytab) + conn = HDFS(host="some-host.domain.com", user="some_user", keytab=keytab) def finalizer(): shutil.rmtree(folder) request.addfinalizer(finalizer) - assert hdfs.user == "some_user" - assert not hdfs.password + assert conn.user == "some_user" + assert not conn.password def test_hdfs_connection_keytab_does_not_exist(): @@ -242,12 +243,12 @@ def get_webhdfs_port(cluster: str) -> int | None: assert HDFS(host="some-node.domain.com", cluster="rnd-dwh").webhdfs_port == 9080 -def test_hdfs_known_get_current(request, mocker): +def test_hdfs_known_get_current(request): from onetl.connection import HDFS - - # no hooks bound to HDFS.Slots.get_current_cluster + + # no hooks bound to conn.Slots.get_current_cluster error_msg = re.escape( - "HDFS.get_current() can be used only if there are some hooks bound to HDFS.Slots.get_current_cluster", + "conn.get_current() can be used only if there are some hooks bound to conn.Slots.get_current_cluster", ) with pytest.raises(RuntimeError, match=error_msg): HDFS.get_current() @@ -259,5 +260,5 @@ def get_current_cluster() -> str: request.addfinalizer(get_current_cluster.disable) - hdfs = HDFS.get_current() - assert hdfs.cluster == "rnd-dwh" + conn = HDFS.get_current() + assert conn.cluster == "rnd-dwh" diff --git a/tests/tests_unit/tests_file_connection_unit/test_s3_unit.py b/tests/tests_unit/tests_file_connection_unit/test_s3_unit.py index e652c24ed..f63b58e6a 100644 --- a/tests/tests_unit/tests_file_connection_unit/test_s3_unit.py +++ b/tests/tests_unit/tests_file_connection_unit/test_s3_unit.py @@ -6,29 +6,29 @@ def test_s3_connection(): from onetl.connection import S3 - s3 = S3( + conn = S3( host="some_host", access_key="access key", secret_key="some key", bucket="bucket", ) - assert s3.host == "some_host" - assert s3.access_key == "access key" - assert s3.secret_key != "some key" - assert 
s3.secret_key.get_secret_value() == "some key"
-    assert s3.protocol == "https"
-    assert s3.port == 443
-    assert s3.instance_url == "s3://some_host:443"
+    assert conn.host == "some_host"
+    assert conn.access_key == "access key"
+    assert conn.secret_key != "some key"
+    assert conn.secret_key.get_secret_value() == "some key"
+    assert conn.protocol == "https"
+    assert conn.port == 443
+    assert conn.instance_url == "s3://some_host:443/bucket"
+    assert str(conn) == "S3[some_host:443/bucket]"
 
-    assert "some key" not in str(s3)
-    assert "some key" not in repr(s3)
+    assert "some key" not in repr(conn)
 
 
 def test_s3_connection_with_session_token():
     from onetl.connection import S3
 
-    s3 = S3(
+    conn = S3(
         host="some_host",
         access_key="access_key",
         secret_key="some key",
@@ -36,17 +36,16 @@ def test_s3_connection_with_session_token():
         bucket="bucket",
     )
 
-    assert s3.session_token != "some token"
-    assert s3.session_token.get_secret_value() == "some token"
+    assert conn.session_token != "some token"
+    assert conn.session_token.get_secret_value() == "some token"
 
-    assert "some token" not in str(s3)
-    assert "some token" not in repr(s3)
+    assert "some token" not in repr(conn)
 
 
 def test_s3_connection_https():
     from onetl.connection import S3
 
-    s3 = S3(
+    conn = S3(
         host="some_host",
         access_key="access_key",
         secret_key="secret_key",
@@ -54,15 +53,16 @@ def test_s3_connection_https():
         protocol="https",
     )
 
-    assert s3.protocol == "https"
-    assert s3.port == 443
-    assert s3.instance_url == "s3://some_host:443"
+    assert conn.protocol == "https"
+    assert conn.port == 443
+    assert conn.instance_url == "s3://some_host:443/bucket"
+    assert str(conn) == "S3[some_host:443/bucket]"
 
 
 def test_s3_connection_http():
     from onetl.connection import S3
 
-    s3 = S3(
+    conn = S3(
         host="some_host",
         access_key="access_key",
         secret_key="secret_key",
@@ -70,16 +70,17 @@ def test_s3_connection_http():
         protocol="http",
    )
 
-    assert s3.protocol == "http"
-    assert s3.port == 80
-    assert s3.instance_url == "s3://some_host:80"
+    assert conn.protocol == "http"
+    assert conn.port == 80
+    assert conn.instance_url == "s3://some_host:80/bucket"
+    assert str(conn) == "S3[some_host:80/bucket]"
 
 
 @pytest.mark.parametrize("protocol", ["http", "https"])
 def test_s3_connection_with_port(protocol):
     from onetl.connection import S3
 
-    s3 = S3(
+    conn = S3(
         host="some_host",
         port=9000,
         access_key="access_key",
@@ -88,6 +89,7 @@ def test_s3_connection_with_port(protocol):
         protocol=protocol,
     )
 
-    assert s3.protocol == protocol
-    assert s3.port == 9000
-    assert s3.instance_url == "s3://some_host:9000"
+    assert conn.protocol == protocol
+    assert conn.port == 9000
+    assert conn.instance_url == "s3://some_host:9000/bucket"
+    assert str(conn) == "S3[some_host:9000/bucket]"
diff --git a/tests/tests_unit/tests_file_connection_unit/test_samba_unit.py b/tests/tests_unit/tests_file_connection_unit/test_samba_unit.py
index 42f95b368..2dfd06e6b 100644
--- a/tests/tests_unit/tests_file_connection_unit/test_samba_unit.py
+++ b/tests/tests_unit/tests_file_connection_unit/test_samba_unit.py
@@ -8,36 +8,39 @@
 def test_samba_connection():
     from onetl.connection import Samba
 
-    samba = Samba(host="some_host", share="share_name", user="some_user", password="pwd")
-    assert isinstance(samba, FileConnection)
-    assert samba.host == "some_host"
-    assert samba.protocol == "SMB"
-    assert samba.domain == ""
-    assert samba.auth_type == "NTLMv2"
-    assert samba.port == 445
-    assert samba.user == "some_user"
-    assert samba.password != "pwd"
-    assert samba.password.get_secret_value() == "pwd"
+    conn = 
Samba(host="some_host", share="share_name", user="some_user", password="pwd") + assert isinstance(conn, FileConnection) + assert conn.host == "some_host" + assert conn.port == 445 + assert conn.share == "share_name" + assert conn.protocol == "SMB" + assert conn.domain == "" + assert conn.auth_type == "NTLMv2" + assert conn.user == "some_user" + assert conn.password != "pwd" + assert conn.password.get_secret_value() == "pwd" - assert "password='pwd'" not in str(samba) - assert "password='pwd'" not in repr(samba) + assert conn.instance_url == "smb://some_host:445/share_name" + assert str(conn) == "Samba[some_host:445/share_name]" + + assert "pwd" not in repr(conn) def test_samba_connection_with_net_bios(): from onetl.connection import Samba - samba = Samba(host="some_host", share="share_name", user="some_user", password="pwd", protocol="NetBIOS") - assert samba.protocol == "NetBIOS" - assert samba.port == 139 + conn = Samba(host="some_host", share="share_name", user="some_user", password="pwd", protocol="NetBIOS") + assert conn.protocol == "NetBIOS" + assert conn.port == 139 @pytest.mark.parametrize("protocol", ["SMB", "NetBIOS"]) def test_samba_connection_with_custom_port(protocol): from onetl.connection import Samba - samba = Samba(host="some_host", share="share_name", user="some_user", password="pwd", protocol=protocol, port=444) - assert samba.protocol == protocol - assert samba.port == 444 + conn = Samba(host="some_host", share="share_name", user="some_user", password="pwd", protocol=protocol, port=444) + assert conn.protocol == protocol + assert conn.port == 444 def test_samba_connection_without_mandatory_args(): diff --git a/tests/tests_unit/tests_file_connection_unit/test_sftp_unit.py b/tests/tests_unit/tests_file_connection_unit/test_sftp_unit.py index 11f6cfbd2..d2e02b755 100644 --- a/tests/tests_unit/tests_file_connection_unit/test_sftp_unit.py +++ b/tests/tests_unit/tests_file_connection_unit/test_sftp_unit.py @@ -7,35 +7,41 @@ def test_sftp_connection_anonymous(): - from onetl.connection import SFTP + from onetl.connection import SFTP, FileConnection - sftp = SFTP(host="some_host") - assert sftp.host == "some_host" - assert sftp.port == 22 - assert not sftp.user - assert not sftp.password - assert not sftp.key_file + conn = SFTP(host="some_host") + assert isinstance(conn, FileConnection) + assert conn.host == "some_host" + assert conn.port == 22 + assert not conn.user + assert not conn.password + assert not conn.key_file + assert conn.instance_url == "sftp://some_host:22" + assert str(conn) == "SFTP[some_host:22]" def test_sftp_connection_with_port(): from onetl.connection import SFTP - sftp = SFTP(host="some_host", port=500) + conn = SFTP(host="some_host", port=500) - assert sftp.port == 500 + assert conn.port == 500 + assert conn.instance_url == "sftp://some_host:500" + assert str(conn) == "SFTP[some_host:500]" def test_sftp_connection_with_password(): from onetl.connection import SFTP - sftp = SFTP(host="some_host", user="some_user", password="pwd") - assert sftp.user == "some_user" - assert sftp.password != "pwd" - assert sftp.password.get_secret_value() == "pwd" - assert not sftp.key_file + conn = SFTP(host="some_host", user="some_user", password="pwd") + assert conn.user == "some_user" + assert conn.password != "pwd" + assert conn.password.get_secret_value() == "pwd" + assert not conn.key_file + assert conn.instance_url == "sftp://some_host:22" + assert str(conn) == "SFTP[some_host:22]" - assert "password='pwd'" not in str(sftp) - assert "password='pwd'" not in repr(sftp) 
+ assert "pwd" not in repr(conn) def test_sftp_connection_with_key_file(request, tmp_path_factory): @@ -51,10 +57,10 @@ def finalizer(): request.addfinalizer(finalizer) - sftp = SFTP(host="some_host", user="some_user", key_file=key_file) - assert sftp.user == "some_user" - assert not sftp.password - assert sftp.key_file == key_file + conn = SFTP(host="some_host", user="some_user", key_file=key_file) + assert conn.user == "some_user" + assert not conn.password + assert conn.key_file == key_file def test_sftp_connection_key_file_does_not_exist(): diff --git a/tests/tests_unit/tests_file_connection_unit/test_webdav_unit.py b/tests/tests_unit/tests_file_connection_unit/test_webdav_unit.py index 7d92d494c..7f4586780 100644 --- a/tests/tests_unit/tests_file_connection_unit/test_webdav_unit.py +++ b/tests/tests_unit/tests_file_connection_unit/test_webdav_unit.py @@ -8,34 +8,39 @@ def test_webdav_connection(): from onetl.connection import WebDAV - webdav = WebDAV(host="some_host", user="some_user", password="pwd") - assert isinstance(webdav, FileConnection) - assert webdav.host == "some_host" - assert webdav.protocol == "https" - assert webdav.port == 443 - assert webdav.user == "some_user" - assert webdav.password != "pwd" - assert webdav.password.get_secret_value() == "pwd" + conn = WebDAV(host="some_host", user="some_user", password="pwd") + assert isinstance(conn, FileConnection) + assert conn.host == "some_host" + assert conn.protocol == "https" + assert conn.port == 443 + assert conn.user == "some_user" + assert conn.password != "pwd" + assert conn.password.get_secret_value() == "pwd" + assert conn.instance_url == "webdav://some_host:443" + assert str(conn) == "WebDAV[some_host:443]" - assert "password='pwd'" not in str(webdav) - assert "password='pwd'" not in repr(webdav) + assert "pwd" not in repr(conn) def test_webdav_connection_with_http(): from onetl.connection import WebDAV - webdav = WebDAV(host="some_host", user="some_user", password="pwd", protocol="http") - assert webdav.protocol == "http" - assert webdav.port == 80 + conn = WebDAV(host="some_host", user="some_user", password="pwd", protocol="http") + assert conn.protocol == "http" + assert conn.port == 80 + assert conn.instance_url == "webdav://some_host:80" + assert str(conn) == "WebDAV[some_host:80]" @pytest.mark.parametrize("protocol", ["http", "https"]) def test_webdav_connection_with_custom_port(protocol): from onetl.connection import WebDAV - webdav = WebDAV(host="some_host", user="some_user", password="pwd", port=500, protocol=protocol) - assert webdav.protocol == protocol - assert webdav.port == 500 + conn = WebDAV(host="some_host", user="some_user", password="pwd", port=500, protocol=protocol) + assert conn.protocol == protocol + assert conn.port == 500 + assert conn.instance_url == "webdav://some_host:500" + assert str(conn) == "WebDAV[some_host:500]" def test_webdav_connection_without_mandatory_args(): diff --git a/tests/tests_unit/tests_file_df_connection_unit/test_spark_hdfs_unit.py b/tests/tests_unit/tests_file_df_connection_unit/test_spark_hdfs_unit.py index 08ca6c1f4..0d392c8d9 100644 --- a/tests/tests_unit/tests_file_df_connection_unit/test_spark_hdfs_unit.py +++ b/tests/tests_unit/tests_file_df_connection_unit/test_spark_hdfs_unit.py @@ -12,28 +12,31 @@ def test_spark_hdfs_with_cluster(spark_mock): - hdfs = SparkHDFS(cluster="rnd-dwh", spark=spark_mock) - assert isinstance(hdfs, BaseFileDFConnection) - assert hdfs.cluster == "rnd-dwh" - assert hdfs.host is None - assert hdfs.ipc_port == 8020 - assert 
hdfs.instance_url == "rnd-dwh"
+    conn = SparkHDFS(cluster="rnd-dwh", spark=spark_mock)
+    assert isinstance(conn, BaseFileDFConnection)
+    assert conn.cluster == "rnd-dwh"
+    assert conn.host is None
+    assert conn.ipc_port == 8020
+    assert conn.instance_url == "rnd-dwh"
+    assert str(conn) == "HDFS[rnd-dwh]"
 
 
 def test_spark_hdfs_with_cluster_and_host(spark_mock):
-    hdfs = SparkHDFS(cluster="rnd-dwh", host="some-host.domain.com", spark=spark_mock)
-    assert isinstance(hdfs, BaseFileDFConnection)
-    assert hdfs.cluster == "rnd-dwh"
-    assert hdfs.host == "some-host.domain.com"
-    assert hdfs.instance_url == "rnd-dwh"
+    conn = SparkHDFS(cluster="rnd-dwh", host="some-host.domain.com", spark=spark_mock)
+    assert isinstance(conn, BaseFileDFConnection)
+    assert conn.cluster == "rnd-dwh"
+    assert conn.host == "some-host.domain.com"
+    assert conn.instance_url == "rnd-dwh"
+    assert str(conn) == "HDFS[rnd-dwh]"
 
 
 def test_spark_hdfs_with_port(spark_mock):
-    hdfs = SparkHDFS(cluster="rnd-dwh", port=9020, spark=spark_mock)
-    assert isinstance(hdfs, BaseFileDFConnection)
-    assert hdfs.cluster == "rnd-dwh"
-    assert hdfs.ipc_port == 9020
-    assert hdfs.instance_url == "rnd-dwh"
+    conn = SparkHDFS(cluster="rnd-dwh", port=9020, spark=spark_mock)
+    assert isinstance(conn, BaseFileDFConnection)
+    assert conn.cluster == "rnd-dwh"
+    assert conn.ipc_port == 9020
+    assert conn.instance_url == "rnd-dwh"
+    assert str(conn) == "HDFS[rnd-dwh]"
 
 
 def test_spark_hdfs_without_cluster(spark_mock):
@@ -143,5 +146,5 @@ def get_current_cluster() -> str:
 
     request.addfinalizer(get_current_cluster.disable)
 
-    hdfs = SparkHDFS.get_current(spark=spark_mock)
-    assert hdfs.cluster == "rnd-dwh"
+    conn = SparkHDFS.get_current(spark=spark_mock)
+    assert conn.cluster == "rnd-dwh"
diff --git a/tests/tests_unit/tests_file_df_connection_unit/test_spark_local_fs_unit.py b/tests/tests_unit/tests_file_df_connection_unit/test_spark_local_fs_unit.py
index e98c986cf..ac41f7f8d 100644
--- a/tests/tests_unit/tests_file_df_connection_unit/test_spark_local_fs_unit.py
+++ b/tests/tests_unit/tests_file_df_connection_unit/test_spark_local_fs_unit.py
@@ -13,6 +13,7 @@ def test_spark_local_fs_spark_local(spark_mock):
     conn = SparkLocalFS(spark=spark_mock)
     assert conn.spark == spark_mock
     assert conn.instance_url == f"file://{socket.getfqdn()}"
+    assert str(conn) == "LocalFS"
 
 
 @pytest.mark.parametrize("master", ["k8s", "yarn"])
diff --git a/tests/tests_unit/tests_file_df_connection_unit/test_spark_s3_unit.py b/tests/tests_unit/tests_file_df_connection_unit/test_spark_s3_unit.py
index 99a20633c..0c369e6c9 100644
--- a/tests/tests_unit/tests_file_df_connection_unit/test_spark_s3_unit.py
+++ b/tests/tests_unit/tests_file_df_connection_unit/test_spark_s3_unit.py
@@ -84,7 +84,7 @@ def spark_mock_hadoop_3(spark_mock):
 
 
 def test_spark_s3(spark_mock_hadoop_3):
-    s3 = SparkS3(
+    conn = SparkS3(
         host="some_host",
         access_key="access key",
         secret_key="some key",
@@ -92,20 +92,20 @@ def test_spark_s3(spark_mock_hadoop_3):
         spark=spark_mock_hadoop_3,
     )
 
-    assert s3.host == "some_host"
-    assert s3.access_key == "access key"
-    assert s3.secret_key != "some key"
-    assert s3.secret_key.get_secret_value() == "some key"
-    assert s3.protocol == "https"
-    assert s3.port == 443
-    assert s3.instance_url == "s3://some_host:443"
+    assert conn.host == "some_host"
+    assert conn.access_key == "access key"
+    assert conn.secret_key != "some key"
+    assert conn.secret_key.get_secret_value() == "some key"
+    assert conn.protocol == "https"
+    assert conn.port == 443
+    assert conn.instance_url == "s3://some_host:443/bucket"
+    assert str(conn) == "S3[some_host:443/bucket]"
 
-    assert "some key" not in str(s3)
-    assert "some key" not in repr(s3)
+    assert "some key" not in repr(conn)
 
 
 def test_spark_s3_with_protocol_https(spark_mock_hadoop_3):
-    s3 = SparkS3(
+    conn = SparkS3(
         host="some_host",
         access_key="access_key",
         secret_key="secret_key",
@@ -114,13 +114,14 @@ def test_spark_s3_with_protocol_https(spark_mock_hadoop_3):
         spark=spark_mock_hadoop_3,
    )
 
-    assert s3.protocol == "https"
-    assert s3.port == 443
-    assert s3.instance_url == "s3://some_host:443"
+    assert conn.protocol == "https"
+    assert conn.port == 443
+    assert conn.instance_url == "s3://some_host:443/bucket"
+    assert str(conn) == "S3[some_host:443/bucket]"
 
 
 def test_spark_s3_with_protocol_http(spark_mock_hadoop_3):
-    s3 = SparkS3(
+    conn = SparkS3(
         host="some_host",
         access_key="access_key",
         secret_key="secret_key",
@@ -129,14 +130,15 @@ def test_spark_s3_with_protocol_http(spark_mock_hadoop_3):
         spark=spark_mock_hadoop_3,
     )
 
-    assert s3.protocol == "http"
-    assert s3.port == 80
-    assert s3.instance_url == "s3://some_host:80"
+    assert conn.protocol == "http"
+    assert conn.port == 80
+    assert conn.instance_url == "s3://some_host:80/bucket"
+    assert str(conn) == "S3[some_host:80/bucket]"
 
 
 @pytest.mark.parametrize("protocol", ["http", "https"])
 def test_spark_s3_with_port(spark_mock_hadoop_3, protocol):
-    s3 = SparkS3(
+    conn = SparkS3(
         host="some_host",
         port=9000,
         access_key="access_key",
@@ -146,9 +148,9 @@ def test_spark_s3_with_port(spark_mock_hadoop_3, protocol):
         spark=spark_mock_hadoop_3,
     )
 
-    assert s3.protocol == protocol
-    assert s3.port == 9000
-    assert s3.instance_url == "s3://some_host:9000"
+    assert conn.protocol == protocol
+    assert conn.port == 9000
+    assert conn.instance_url == "s3://some_host:9000/bucket"
 
 
 @pytest.mark.parametrize(