feat: SQLAlchemy 2.0 support #115

Open · wants to merge 13 commits into main
7 changes: 3 additions & 4 deletions docker-compose.yml
@@ -1,9 +1,8 @@
 ---
 # docker compose -f docker-compose.yml up -d
-version: "2.1"
 services:
   mysqldb:
-    image: mysql
+    image: mysql:8.0
     restart: always
     command: --default-authentication-plugin=mysql_native_password --bind-address=0.0.0.0
     environment:
@@ -13,7 +12,7 @@ services:
     ports:
       - 3306:3306
   mysqldb_ssh:
-    image: mysql
+    image: mysql:8.0
     restart: always
     command: --default-authentication-plugin=mysql_native_password --bind-address=0.0.0.0
     environment:
@@ -34,7 +33,7 @@ services:
       - PASSWORD_ACCESS=false
       - USER_NAME=melty
     volumes:
-      - ./ssh_tunnel/ssh-server-config:/config/ssh_host_keys:ro
+      - ./ssh_tunnel/ssh-server-config:/config/ssh_host_keys:rw
     ports:
       - "127.0.0.1:2223:2222"
     networks:
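With the image now pinned to mysql:8.0, a connectivity sanity check against the compose service can be as simple as the sketch below. The credentials are placeholders; use whatever the (collapsed) environment section of docker-compose.yml defines.

import pymysql

# Placeholder credentials; substitute the values from docker-compose.yml.
conn = pymysql.connect(host="127.0.0.1", port=3306, user="root", password="password")
try:
    with conn.cursor() as cur:
        cur.execute("SELECT VERSION()")
        print(cur.fetchone())  # expect an 8.0.x version string
finally:
    conn.close()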
382 changes: 220 additions & 162 deletions poetry.lock

Large diffs are not rendered by default.

13 changes: 8 additions & 5 deletions pyproject.toml
@@ -33,24 +33,27 @@ packages = [

 [tool.poetry.dependencies]
 python = ">=3.9"
-fs-s3fs = { version = "==1.1.1", optional = true }
-singer-sdk = { version="~=0.43.1" }
+singer-sdk = { version="~=0.44.3" }
 pymysql = "==1.1.1"
-sqlalchemy = "<2"
+sqlalchemy = "<3"
 sshtunnel = "0.4.0"
 
 # Binary client for MySQL
 mysqlclient = { version = "==2.2.4", optional = true }
 
+# S3 client
+urllib3 = "<2"
+fs-s3fs = { version = "==1.1.1", optional = true }
+
 [tool.poetry.group.dev.dependencies]
 faker = ">=20"
 pytest = ">=7.3.2"
-singer-sdk = { version="~=0.43.1", extras = ["testing"] }
+singer-sdk = { version="~=0.44.3", extras = ["testing"] }
 remote-pdb=">=2.1.0"
 
 [tool.poetry.group.typing.dependencies]
 mypy = ">=1.8.0"
-sqlalchemy = { version = "<2", extras = ["mypy"] }
+sqlalchemy = { version = "<3", extras = ["mypy"] }
 types-paramiko = ">=3.4.0.20240120"
 
 [tool.poetry.extras]
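For reference, the loosened constraint admits both the SQLAlchemy 1.4 and 2.x lines, so a trivial runtime check confirms which one actually resolved:

# ~=0.44.3 is a PEP 440 "compatible release" specifier: >=0.44.3, <0.45.0.
# sqlalchemy = "<3" allows both 1.4 and 2.x.
import sqlalchemy

print(sqlalchemy.__version__)  # expect 2.x on a fresh install for this PR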
Empty file modified: ssh_tunnel/ssh-server-config/ssh_host_rsa_key (mode 100755 → 100644)
61 changes: 41 additions & 20 deletions tap_mysql/client.py
@@ -11,12 +11,14 @@
 from singer_sdk import typing as th
 from singer_sdk._singerlib import CatalogEntry, MetadataMapping, Schema
 from singer_sdk.helpers._typing import TypeConformanceLevel
+from sqlalchemy import text
 
 if TYPE_CHECKING:
     from collections.abc import Iterable
 
     from sqlalchemy.engine import Engine
-    from sqlalchemy.engine.reflection import Inspector
+    from sqlalchemy.engine.reflection import Inspector, ReflectedPrimaryKeyConstraint
+
 
 unpatched_conform = (
     singer_sdk.helpers._typing._conform_primitive_property  # noqa: SLF001
@@ -58,19 +60,21 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
"""
super().__init__(*args, **kwargs)
self.is_vitess = self.config.get("is_vitess")
self._table_cols_cache: dict[str, dict[str, sqlalchemy.Column]] = {}

if self.is_vitess is None:
self.logger.info(
"No is_vitess configuration provided, dynamically checking if "
"we are using a Vitess instance."
)
with self._connect() as conn:
output = conn.execute(
query = text(
"select variable_value from "
"performance_schema.global_variables where "
"variable_name='version_comment' and variable_value like "
"'PlanetScale%%'"
)
output = conn.execute(query)
rows = output.fetchall()
if len(rows) > 0:
self.logger.info(
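This hunk tracks SQLAlchemy 2.0's removal of implicit string execution: Connection.execute() no longer accepts a plain SQL string, so textual statements must be wrapped in text(). A minimal standalone sketch (the URL is a placeholder):

from sqlalchemy import create_engine, text

engine = create_engine("mysql+pymysql://user:pass@localhost/db")  # placeholder URL

with engine.connect() as conn:
    # 1.x allowed conn.execute("SELECT VERSION()"); in 2.0 a plain string
    # raises ObjectNotExecutableError, so wrap it in text().
    result = conn.execute(text("SELECT VERSION()"))
    print(result.scalar_one())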
@@ -214,45 +218,62 @@ def get_schema_names(self, engine: Engine, inspected: Inspector) -> list[str]:
             return self.config["filter_schemas"]
         return super().get_schema_names(engine, inspected)

-    def discover_catalog_entry(
+    def discover_catalog_entry(  # noqa: PLR0913
         self,
         engine: Engine,
         inspected: Inspector,
         schema_name: str,
         table_name: str,
         is_view: bool,  # noqa: FBT001
+        *,
+        reflected_columns: list[Any] | None = None,
+        reflected_pk: ReflectedPrimaryKeyConstraint | None = None,
+        reflected_indices: list[Any] | None = None,
     ) -> CatalogEntry:
-        """Overrode to support Vitess as DESCRIBE is not supported for views.
-
-        Create `CatalogEntry` object for the given table or a view.
+        """Create `CatalogEntry` object for the given table or a view.
 
         Args:
             engine: SQLAlchemy engine
             inspected: SQLAlchemy inspector instance for engine
             schema_name: Schema name to inspect
             table_name: Name of the table or a view
-            is_view: Flag whether this object is a view, returned by `get_object_names`
+            is_view: Flag whether this object is a view
+            reflected_columns: Pre-reflected columns (optional)
+            reflected_pk: Pre-reflected primary key info (optional)
+            reflected_indices: Pre-reflected indices info (optional)
 
         Returns:
             `CatalogEntry` object for the given table or a view
         """
         if self.is_vitess is False or is_view is False:
             return super().discover_catalog_entry(
-                engine, inspected, schema_name, table_name, is_view
+                engine,
+                inspected,
+                schema_name,
+                table_name,
+                is_view,
+                reflected_columns=reflected_columns,
+                reflected_pk=reflected_pk,
+                reflected_indices=reflected_indices,
             )
         # For vitess views, we can't use DESCRIBE as it's not supported for
         # views so we do the below.
-        unique_stream_id = self.get_fully_qualified_name(
-            db_name=None,
-            schema_name=schema_name,
-            table_name=table_name,
-            delimiter="-",
+        unique_stream_id = str(
+            self.get_fully_qualified_name(
+                db_name=None,
+                schema_name=schema_name,
+                table_name=table_name,
+                delimiter="-",
+            )
         )
 
         # Initialize columns list
         table_schema = th.PropertiesList()
         with self._connect() as conn:
-            columns = conn.execute(f"SHOW columns from `{schema_name}`.`{table_name}`")
+            result = conn.execute(
+                text(f"SHOW columns from `{schema_name}`.`{table_name}`")
+            )
+            columns = result.mappings()
            for column in columns:
                 column_name = column["Field"]
                 is_nullable = column["Null"] == "YES"
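The switch to result.mappings() is needed because 2.0-style result rows behave like named tuples, and the dict-style access used here (column["Field"]) now lives on RowMapping. A small self-contained sketch with hypothetical schema and table names:

from sqlalchemy import create_engine, text

engine = create_engine("mysql+pymysql://user:pass@localhost/db")  # placeholder URL

with engine.connect() as conn:
    # `my_schema`.`my_table` is hypothetical, just to show the access pattern.
    result = conn.execute(text("SHOW columns FROM `my_schema`.`my_table`"))
    for row in result.mappings():  # RowMapping supports row["Field"]-style access
        print(row["Field"], row["Type"], row["Null"])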
@@ -358,9 +379,10 @@ def get_table_columns(
         if full_table_name not in self._table_cols_cache:
             _, schema_name, table_name = self.parse_full_table_name(full_table_name)
             with self._connect() as conn:
-                columns = conn.execute(
-                    f"SHOW columns from `{schema_name}`.`{table_name}`"
+                result = conn.execute(
+                    text(f"SHOW columns from `{schema_name}`.`{table_name}`")
                 )
+                columns = result.mappings()
                 self._table_cols_cache[full_table_name] = {
                     col_meta["Field"]: sqlalchemy.Column(
                         col_meta["Field"],
@@ -405,9 +427,7 @@ def get_records(self, context: dict | None) -> Iterable[dict[str, Any]]:
"""
if context:
msg = f"Stream '{self.name}' does not support partitioning."
raise NotImplementedError(
msg,
)
raise NotImplementedError(msg)

# pulling rows with only selected columns from stream
selected_column_names = list(self.get_selected_schema()["properties"])
@@ -429,5 +449,6 @@ def get_records(self, context: dict | None) -> Iterable[dict[str, Any]]:
                 conn.exec_driver_sql(
                     "set workload=olap"
                 )  # See https://github.com/planetscale/discussion/discussions/190
-            for row in conn.execute(query):
+            result = conn.execute(query)
+            for row in result.mappings():
                 yield dict(row)
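Note the asymmetry this hunk preserves: exec_driver_sql() hands the string directly to the DBAPI driver, bypassing SQLAlchemy's compilation, so it is exempt from the text() requirement, while anything routed through execute() is not. A sketch, assuming an Engine as in the earlier examples:

from sqlalchemy import create_engine, text

engine = create_engine("mysql+pymysql://user:pass@localhost/db")  # placeholder URL

with engine.connect() as conn:
    # No text() needed: the string goes straight to the driver.
    conn.exec_driver_sql("set workload=olap")
    # text() is still required on the execute() path in 2.0.
    rows = conn.execute(text("SELECT 1")).all()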
5 changes: 3 additions & 2 deletions tap_mysql/tap.py
@@ -277,7 +277,7 @@ def ssh_tunnel_connect(self, *, ssh_config: dict[str, Any], url: URL) -> URL:
         self.ssh_tunnel: SSHTunnelForwarder = SSHTunnelForwarder(
             ssh_address_or_host=(ssh_config["host"], ssh_config["port"]),
             ssh_username=ssh_config["username"],
-            ssh_private_key=self.guess_key_type(ssh_config["private_key"]),
+            ssh_pkey=self.guess_key_type(ssh_config["private_key"]),
             ssh_private_key_password=ssh_config.get("private_key_password"),
             remote_bind_address=(url.host, url.port),
         )
@@ -297,7 +297,8 @@ def ssh_tunnel_connect(self, *, ssh_config: dict[str, Any], url: URL) -> URL:

     def clean_up(self) -> None:
         """Stop the SSH Tunnel."""
-        self.logger.info("Shutting down SSH Tunnel")
+        if self.logger and self.logger.handlers:
+            self.logger.info("Shutting down SSH Tunnel")
         self.ssh_tunnel.stop()
 
     def catch_signal(self, signum, frame) -> None:  # noqa: ANN001 ARG002
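The rename to ssh_pkey matches the parameter sshtunnel currently documents for supplying a private key (a path or a paramiko key object); ssh_private_key survives only as a deprecated alias. A standalone sketch with placeholder hosts and paths:

from sshtunnel import SSHTunnelForwarder

tunnel = SSHTunnelForwarder(
    ssh_address_or_host=("bastion.example.com", 22),  # placeholder bastion
    ssh_username="melty",
    ssh_pkey="/path/to/id_rsa",  # path or paramiko.PKey object
    remote_bind_address=("mysql.internal", 3306),  # placeholder DB host
)
tunnel.start()
print(tunnel.local_bind_port)  # point the MySQL client at this local port
tunnel.stop()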
78 changes: 43 additions & 35 deletions tests/test_core.py
@@ -10,7 +10,7 @@
 from faker import Faker
 from singer_sdk.testing import get_tap_test_class, suites
 from singer_sdk.testing.runners import TapTestRunner
-from sqlalchemy import Column, DateTime, Integer, MetaData, Numeric, String, Table
+from sqlalchemy import Column, DateTime, Integer, MetaData, Numeric, String, Table, text
 from sqlalchemy.dialects.mysql import DATE, DATETIME, JSON, TIME
 
 from tap_mysql.tap import TapMySQL
@@ -52,39 +52,42 @@ def setup_test_table(table_name, sqlalchemy_url):
Column("name", String(length=100)),
)
with engine.connect() as conn:
metadata_obj.create_all(conn)
conn.execute(f"TRUNCATE TABLE {table_name}")
for _ in range(1000):
insert = test_replication_key_table.insert().values(
updated_at=fake.date_between(date1, date2),
name=fake.name(),
)
conn.execute(insert)
with conn.begin(): # Ensure transaction is committed
metadata_obj.create_all(conn)
conn.execute(text(f"TRUNCATE TABLE {table_name}"))
for _ in range(5):
insert = test_replication_key_table.insert().values(
updated_at=fake.date_time_between(date1, date2),
name=fake.name(),
)
conn.execute(insert)


 def teardown_test_table(table_name, sqlalchemy_url):
     engine = sqlalchemy.create_engine(sqlalchemy_url)
     with engine.connect() as conn:
-        conn.execute(f"DROP TABLE {table_name}")
+        conn.execute(text(f"DROP TABLE {table_name}"))


 custom_test_replication_key = suites.TestSuite(
     kind="tap",
     tests=[TapTestReplicationKey],
 )
 
+with open("tests/resources/data.json") as f:
+    catalog_dict = json.load(f)
+
 TapMySQLTest = get_tap_test_class(
     tap_class=TapMySQL,
     config=SAMPLE_CONFIG,
-    catalog="tests/resources/data.json",
+    catalog=catalog_dict,
     custom_suites=[custom_test_replication_key],
 )
 
 TapMySQLTestNOSQLALCHEMY = get_tap_test_class(
     tap_class=TapMySQL,
     config=NO_SQLALCHEMY_CONFIG,
-    catalog="tests/resources/data.json",
+    catalog=catalog_dict,
     custom_suites=[custom_test_replication_key],
 )

Expand Down Expand Up @@ -129,15 +132,18 @@ def test_temporal_datatypes():
Column("column_timestamp", DATETIME),
)
with engine.connect() as conn:
if table.exists(conn):
table.drop(conn)
metadata_obj.create_all(conn)
insert = table.insert().values(
column_date="2022-03-19",
column_time="06:04:19.222",
column_timestamp="1918-02-03 13:00:01",
)
conn.execute(insert)
# Start a transaction
with conn.begin():
table.drop(engine, checkfirst=True)
metadata_obj.create_all(engine)
insert = table.insert().values(
column_date="2022-03-19",
column_time="06:04:19.222",
column_timestamp="1918-02-03 13:00:01",
)
conn.execute(insert)
# Transaction will be automatically committed here

     tap = TapMySQL(config=SAMPLE_CONFIG)
     tap_catalog = json.loads(tap.catalog_json_text)
     altered_table_name = f"melty-{table_name}"
@@ -178,6 +184,7 @@ def test_temporal_datatypes():


 def test_jsonb_json():
+    """Test JSON type handling."""
     table_name = "test_jsonb_json"
     engine = sqlalchemy.create_engine(SAMPLE_CONFIG["sqlalchemy_url"])

@@ -188,13 +195,16 @@
Column("column_json", JSON),
)
with engine.connect() as conn:
if table.exists(conn):
table.drop(conn)
metadata_obj.create_all(conn)
insert = table.insert().values(
column_json={"baz": "foo"},
)
conn.execute(insert)
# Start a transaction
with conn.begin():
table.drop(engine, checkfirst=True)
metadata_obj.create_all(engine)
insert = table.insert().values(
column_json={"baz": "foo"},
)
conn.execute(insert)
# Transaction will be automatically committed here

     tap = TapMySQL(config=SAMPLE_CONFIG)
     tap_catalog = json.loads(tap.catalog_json_text)
     altered_table_name = f"melty-{table_name}"
@@ -238,9 +248,8 @@ def test_decimal():
Column("column", Numeric()),
)
with engine.connect() as conn:
if table.exists(conn):
table.drop(conn)
metadata_obj.create_all(conn)
table.drop(engine, checkfirst=True)
metadata_obj.create_all(engine)
insert = table.insert().values(column=decimal.Decimal("3.14"))
conn.execute(insert)
insert = table.insert().values(column=decimal.Decimal("12"))
@@ -283,10 +292,9 @@ def test_filter_schemas():
     table = Table(table_name, metadata_obj, Column("id", Integer), schema="new_schema")
 
     with engine.connect() as conn:
-        conn.execute("CREATE SCHEMA IF NOT EXISTS new_schema")
-        if table.exists(conn):
-            table.drop(conn)
-        metadata_obj.create_all(conn)
+        conn.execute(text("CREATE SCHEMA IF NOT EXISTS new_schema"))
+        table.drop(engine, checkfirst=True)
+        metadata_obj.create_all(engine)
     filter_schemas_config = copy.deepcopy(SAMPLE_CONFIG)
     filter_schemas_config.update({"filter_schemas": ["new_schema"]})
     tap = TapMySQL(config=filter_schemas_config)