Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

relational.write - Improve ConnectionError #370

1 change: 1 addition & 0 deletions core/src/datayoga_core/blocks/relational/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def get_engine(connection_name: str, context: Context, autocommit: bool = True)
query=query_args),
echo=connection_details.get("debug", False),
connect_args=connect_args,
pool_pre_ping=True,
**extra)

return engine, db_type
Expand Down
152 changes: 63 additions & 89 deletions core/src/datayoga_core/blocks/relational/write/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from datayoga_core.opcode import OpCode
from datayoga_core.result import BlockResult, Result, Status
from sqlalchemy import text
from sqlalchemy.exc import (DatabaseError, OperationalError,
PendingRollbackError)
from sqlalchemy.sql.expression import ColumnCollection

logger = logging.getLogger("dy")
Expand All @@ -30,69 +28,49 @@ def init(self, context: Optional[Context] = None):
self.setup_engine()

def setup_engine(self):
    """Sets up the SQLAlchemy engine and configures it.

    Idempotent: returns immediately if an engine already exists. Builds the
    engine from the configured connection, reflects the target table, and —
    when an opcode field is configured — validates the mapped columns and
    pre-builds the delete and upsert statements.

    Raises:
        ValueError: if a mapped column does not exist in the target table.
    """
    if self.engine:
        return

    self.engine, self.db_type = relational_utils.get_engine(self.properties["connection"], self.context)

    # Disable the new MySQL 8.0.17+ default behavior of requiring an alias for ON DUPLICATE KEY UPDATE
    # This behavior is not supported by pymysql driver
    if self.engine.driver == "pymysql":
        self.engine.dialect._requires_alias_for_on_duplicate_key = False

    self.schema = self.properties.get("schema")
    self.table = self.properties.get("table")
    self.opcode_field = self.properties.get("opcode_field")
    self.load_strategy = self.properties.get("load_strategy")
    self.keys = self.properties.get("keys")
    self.mapping = self.properties.get("mapping")
    self.foreach = self.properties.get("foreach")
    # Reflect the target table's metadata from the live database
    self.tbl = sa.Table(self.table, sa.MetaData(schema=self.schema), autoload_with=self.engine)

    if self.opcode_field:
        self.business_key_columns = [column["column"] for column in write_utils.get_column_mapping(self.keys)]
        self.mapping_columns = [column["column"] for column in write_utils.get_column_mapping(self.mapping)]

        # Business keys first, then any mapped columns not already present
        self.columns = self.business_key_columns + [x for x in self.mapping_columns
                                                    if x not in self.business_key_columns]

        for column in self.columns:
            if not any(col.name.lower() == column.lower() for col in self.tbl.columns):
                raise ValueError(f"{column} column does not exist in {self.tbl.fullname} table")

        # Build the WHERE clause with a case-insensitive match of each
        # business key to its table column, bound by parameter name
        conditions = []
        for business_key_column in self.business_key_columns:
            for tbl_column in self.tbl.columns:
                if tbl_column.name.lower() == business_key_column.lower():
                    conditions.append(tbl_column == sa.bindparam(business_key_column))
                    break

        self.delete_stmt = self.tbl.delete().where(sa.and_(*conditions))
        self.upsert_stmt = self.generate_upsert_stmt()

async def run(self, data: List[Dict[str, Any]]) -> BlockResult:
"""Runs the block with provided data and return the result."""
logger.debug(f"Running {self.get_block_name()}")
rejected_records: List[Result] = []

Expand Down Expand Up @@ -190,38 +168,27 @@ def generate_upsert_stmt(self) -> Any:
))

def execute(self, statement: Any, records: List[Dict[str, Any]]):
    """Executes a SQL statement with the given records.

    Opens a fresh connection per call (connection reuse is handled by the
    engine's pool). Failures that occur before a connection is established
    are raised as ConnectionError so callers can distinguish connectivity
    problems from statement-level database errors.
    """
    if isinstance(statement, str):
        statement = text(statement)

    logger.debug(f"Executing {statement} on {records}")
    connected = False
    try:
        with self.engine.connect() as connection:
            connected = True
            connection.execute(statement, records)
            # Explicit commit is only needed when not in autocommit isolation
            if not connection._is_autocommit_isolation():
                connection.commit()
    except Exception as e:
        if not connected:
            raise ConnectionError(e) from e
        # A connection was established, so this is a statement-level database
        # error rather than a connectivity failure — propagate it unchanged
        # instead of silently swallowing it.
        raise

def execute_upsert(self, records: List[Dict[str, Any]]):
"""Upserts records into the table."""
if records:
logger.debug(f"Upserting {len(records)} record(s) to {self.table} table")
records_to_upsert = []
Expand All @@ -232,6 +199,7 @@ def execute_upsert(self, records: List[Dict[str, Any]]):
self.execute(self.upsert_stmt, records_to_upsert)

def execute_delete(self, records: List[Dict[str, Any]]):
"""Deletes records from the table."""
if records:
logger.debug(f"Deleting {len(records)} record(s) from {self.table} table")
records_to_delete = []
Expand All @@ -242,4 +210,10 @@ def execute_delete(self, records: List[Dict[str, Any]]):
self.execute(self.delete_stmt, records_to_delete)

def stop(self):
    """Disposes of the engine and cleans up resources.

    Best-effort: disposal errors are suppressed so shutdown never fails.
    """
    with suppress(Exception):
        if self.engine:
            self.engine.dispose()

    # Reset all engine-related attributes so the block can be re-initialized
    for attr in self._engine_fields:
        setattr(self, attr, None)
12 changes: 12 additions & 0 deletions integration-tests/common/redis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ def add_to_emp_stream(redis_client: Redis):
],
"__$$opcode": "u"
},
# gender value is too long, so the write should fail (except for Cassandra)
{
"_id": 11,
"fname": "jane",
"lname": "doe",
"country_code": 972,
"country_name": "israel",
"credit_card": "1000-2000-3000-4000",
"gender": "FF",
"addresses": [],
"__$$opcode": "u"
},
{
"_id": 12,
"fname": "john",
Expand Down
2 changes: 1 addition & 1 deletion integration-tests/test_redis_to_cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def prepare_db():
def test_total_records(prepare_db):
    """Verifies the total number of employee records written to Cassandra."""
    session = cassandra_utils.get_cassandra_session(["localhost"])
    total_employees = session.execute(f"select count(*) as total from {TABLE}").one()
    # 4 records expected: the record with the oversized gender value is
    # accepted by Cassandra (it is only rejected by the relational targets)
    assert total_employees.total == 4


def test_filtered_record(prepare_db):
Expand Down
4 changes: 2 additions & 2 deletions integration-tests/test_redis_to_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
("mysql", "hr"),
("pg", "hr"),
("oracle", "hr"),
pytest.param("sqlserver", "dbo", marks=pytest.mark.xfail)
pytest.param("sqlserver", "dbo", marks=pytest.mark.skip(reason="SQLServer test fails"))
])
def test_redis_to_relational_db(db_type: str, schema_name: Optional[str]):
"""Reads data from a Redis stream and writes it to a relational database."""
Expand Down Expand Up @@ -58,7 +58,7 @@ def test_redis_to_relational_db(db_type: str, schema_name: Optional[str]):
with suppress(Exception):
redis_container.stop()
with suppress(Exception):
database_container.stop()
db_container.stop()


def check_results(engine: Engine, schema_name: Optional[str]):
Expand Down
Loading
Loading