Skip to content

Commit

Permalink
fix: RDS Data API - allow ANSI-compatible identifiers. (#2391)
Browse files Browse the repository at this point in the history
  • Loading branch information
kukushking authored Jul 11, 2023
1 parent f8590a1 commit 6c0f65b
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 7 deletions.
33 changes: 26 additions & 7 deletions awswrangler/data_api/rds.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""RDS Data API Connector."""
import datetime as dt
import logging
import re
import time
import uuid
from decimal import Decimal
Expand Down Expand Up @@ -227,6 +228,19 @@ def _get_statement_result(self, request_id: str) -> pd.DataFrame:
return dataframe


def escape_identifier(identifier: str, sql_mode: str = "mysql") -> str:
    """Escape an SQL identifier. Uses MySQL-compatible backticks by default.

    Parameters
    ----------
    identifier : str
        Table or column name to quote.
    sql_mode : str
        ``"mysql"`` wraps the identifier in backticks; ``"ansi"`` wraps it in
        double quotes (ANSI-compatible, e.g. PostgreSQL or MySQL ANSI_QUOTES).

    Raises
    ------
    TypeError
        If ``identifier`` is not a string, is empty, or contains characters
        outside the word-character class.
    ValueError
        If ``sql_mode`` is not ``"mysql"`` or ``"ansi"``.
    """
    if not isinstance(identifier, str):
        raise TypeError("SQL identifier must be a string")
    # Reject empty names too: quoting "" would still yield invalid SQL.
    # \W blocks quotes/backticks/semicolons, preventing injection via names.
    if not identifier or re.search(r"\W", identifier):
        raise TypeError(f"SQL identifier contains invalid characters: {identifier}")
    if sql_mode == "mysql":
        return f"`{identifier}`"
    if sql_mode == "ansi":
        return f'"{identifier}"'
    raise ValueError(f"Unknown SQL MODE: {sql_mode}")


def connect(
resource_arn: str, database: str, secret_arn: str = "", boto3_session: Optional[boto3.Session] = None, **kwargs: Any
) -> RdsDataApi:
Expand Down Expand Up @@ -271,8 +285,8 @@ def read_sql_query(sql: str, con: RdsDataApi, database: Optional[str] = None) ->
return con.execute(sql, database=database)


def _drop_table(con: RdsDataApi, table: str, database: str, transaction_id: str, sql_mode: str) -> None:
    """Drop ``table`` (if it exists) inside the given transaction."""
    query = f"DROP TABLE IF EXISTS {escape_identifier(table, sql_mode=sql_mode)}"
    _logger.debug("Drop table query:\n%s", query)
    con.execute(query, database=database, transaction_id=transaction_id)

Expand All @@ -292,9 +306,10 @@ def _create_table(
index: bool,
dtype: Optional[Dict[str, str]],
varchar_lengths: Optional[Dict[str, int]],
sql_mode: str,
) -> None:
if mode == "overwrite":
_drop_table(con=con, table=table, database=database, transaction_id=transaction_id)
_drop_table(con=con, table=table, database=database, transaction_id=transaction_id, sql_mode=sql_mode)
elif _does_table_exist(con=con, table=table, database=database, transaction_id=transaction_id):
return

Expand All @@ -306,8 +321,8 @@ def _create_table(
varchar_lengths=varchar_lengths,
converter_func=_data_types.pyarrow2mysql,
)
cols_str: str = "".join([f"`{k}` {v},\n" for k, v in mysql_types.items()])[:-2]
sql = f"CREATE TABLE IF NOT EXISTS `{table}` (\n{cols_str})"
cols_str: str = "".join([f"{escape_identifier(k, sql_mode=sql_mode)} {v},\n" for k, v in mysql_types.items()])[:-2]
sql = f"CREATE TABLE IF NOT EXISTS {escape_identifier(table, sql_mode=sql_mode)} (\n{cols_str})"

_logger.debug("Create table query:\n%s", sql)
con.execute(sql, database=database, transaction_id=transaction_id)
Expand Down Expand Up @@ -388,6 +403,7 @@ def to_sql(
varchar_lengths: Optional[Dict[str, int]] = None,
use_column_names: bool = False,
chunksize: int = 200,
sql_mode: str = "mysql",
) -> None:
"""
Insert data using an SQL query on a Data API connection.
Expand Down Expand Up @@ -439,19 +455,22 @@ def to_sql(
index=index,
dtype=dtype,
varchar_lengths=varchar_lengths,
sql_mode=sql_mode,
)

if index:
df = df.reset_index(level=df.index.names)

if use_column_names:
insertion_columns = "(" + ", ".join([f"`{col}`" for col in df.columns]) + ")"
insertion_columns = (
"(" + ", ".join([f"{escape_identifier(col, sql_mode=sql_mode)}" for col in df.columns]) + ")"
)
else:
insertion_columns = ""

placeholders = ", ".join([f":{col}" for col in df.columns])

sql = f"""INSERT INTO `{table}` {insertion_columns} VALUES ({placeholders})"""
sql = f"INSERT INTO {escape_identifier(table, sql_mode=sql_mode)} {insertion_columns} VALUES ({placeholders})"
parameter_sets = _generate_parameter_sets(df)

for parameter_sets_chunk in _utils.chunkify(parameter_sets, max_length=chunksize):
Expand Down
57 changes: 57 additions & 0 deletions tests/unit/test_data_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ def mysql_serverless_connector(databases_parameters: Dict[str, Any]) -> "RdsData
yield con


@pytest.fixture
def postgresql_serverless_connector(databases_parameters: Dict[str, Any]) -> "RdsDataApi":
    """Yield an RDS Data API connector for the PostgreSQL serverless test cluster."""
    connector = create_rds_connector("postgresql_serverless", databases_parameters)
    with connector:
        yield connector


def test_connect_redshift_serverless_iam_role(databases_parameters: Dict[str, Any]) -> None:
workgroup_name = databases_parameters["redshift_serverless"]["workgroup"]
database = databases_parameters["redshift_serverless"]["database"]
Expand Down Expand Up @@ -68,6 +75,16 @@ def mysql_serverless_table(mysql_serverless_connector: "RdsDataApi") -> Iterator
mysql_serverless_connector.execute(f"DROP TABLE IF EXISTS test.{name}")


@pytest.fixture(scope="function")
def postgresql_serverless_table(postgresql_serverless_connector: "RdsDataApi") -> Iterator[str]:
    """Yield a unique table name; drop the table on teardown."""
    table_name = f"tbl_{get_time_str_with_random_suffix()}"
    print(f"Table name: {table_name}")
    try:
        yield table_name
    finally:
        postgresql_serverless_connector.execute(f"DROP TABLE IF EXISTS test.{table_name}")


def test_data_api_redshift_columnless_query(redshift_connector: "RedshiftDataApi") -> None:
dataframe = wr.data_api.redshift.read_sql_query("SELECT 1", con=redshift_connector)
unknown_column_indicator = "?column?"
Expand Down Expand Up @@ -223,3 +240,43 @@ def test_data_api_mysql_to_sql_mode(
def test_data_api_exception(mysql_serverless_connector: "RdsDataApi", mysql_serverless_table: str) -> None:
with pytest.raises(boto3.client("rds-data").exceptions.BadRequestException):
wr.data_api.rds.read_sql_query("CUPCAKE", con=mysql_serverless_connector)


def test_data_api_mysql_ansi(mysql_serverless_connector: "RdsDataApi", mysql_serverless_table: str) -> None:
    """Round-trip a dataframe through MySQL using ANSI_QUOTES identifier quoting."""
    database = "test"
    input_df = pd.DataFrame([[42, "test"]], columns=["id", "name"])

    # Switch the session so double-quoted (ANSI) identifiers are accepted.
    mysql_serverless_connector.execute("SET SESSION sql_mode='ANSI_QUOTES';")

    wr.data_api.rds.to_sql(
        df=input_df,
        con=mysql_serverless_connector,
        table=mysql_serverless_table,
        database=database,
        sql_mode="ansi",
    )

    result_df = wr.data_api.rds.read_sql_query(
        f"SELECT name FROM {mysql_serverless_table} WHERE id = 42", con=mysql_serverless_connector
    )
    assert_pandas_equals(result_df, pd.DataFrame([["test"]], columns=["name"]))


def test_data_api_postgresql(postgresql_serverless_connector: "RdsDataApi", postgresql_serverless_table: str) -> None:
    """Round-trip a dataframe through PostgreSQL using ANSI identifier quoting."""
    database = "test"
    input_df = pd.DataFrame([[42, "test"]], columns=["id", "name"])

    wr.data_api.rds.to_sql(
        df=input_df,
        con=postgresql_serverless_connector,
        table=postgresql_serverless_table,
        database=database,
        sql_mode="ansi",
    )

    result_df = wr.data_api.rds.read_sql_query(
        f"SELECT name FROM {postgresql_serverless_table} WHERE id = 42", con=postgresql_serverless_connector
    )
    assert_pandas_equals(result_df, pd.DataFrame([["test"]], columns=["name"]))

0 comments on commit 6c0f65b

Please sign in to comment.