Merge branch 'main' into 2987-fix_truncated_parquet_files
LeonLuttenberger authored Oct 10, 2024
2 parents f69e058 + 066134f commit 51def0d
Showing 4 changed files with 86 additions and 23 deletions.
awswrangler/athena/_cache.py (23 changes: 10 additions & 13 deletions)
@@ -7,7 +7,7 @@
import re
import threading
from heapq import heappop, heappush
-from typing import TYPE_CHECKING, Any, Match, NamedTuple
+from typing import TYPE_CHECKING, Match, NamedTuple

import boto3

@@ -23,23 +23,23 @@ class _CacheInfo(NamedTuple):
has_valid_cache: bool
file_format: str | None = None
query_execution_id: str | None = None
-query_execution_payload: dict[str, Any] | None = None
+query_execution_payload: "QueryExecutionTypeDef" | None = None


class _LocalMetadataCacheManager:
def __init__(self) -> None:
self._lock: threading.Lock = threading.Lock()
-self._cache: dict[str, Any] = {}
+self._cache: dict[str, "QueryExecutionTypeDef"] = {}
self._pqueue: list[tuple[datetime.datetime, str]] = []
self._max_cache_size = 100

-def update_cache(self, items: list[dict[str, Any]]) -> None:
+def update_cache(self, items: list["QueryExecutionTypeDef"]) -> None:
"""
Update the local metadata cache with new query metadata.
Parameters
----------
-items : List[Dict[str, Any]]
+items
List of query execution metadata which is returned by boto3 `batch_get_query_execution()`.
"""
with self._lock:
@@ -62,18 +62,17 @@ def update_cache(self, items: list[dict[str, Any]]) -> None:
heappush(self._pqueue, (item["Status"]["SubmissionDateTime"], item["QueryExecutionId"]))
self._cache[item["QueryExecutionId"]] = item

-def sorted_successful_generator(self) -> list[dict[str, Any]]:
+def sorted_successful_generator(self) -> list["QueryExecutionTypeDef"]:
"""
Sorts the entries in the local cache based on query Completion DateTime.
This is useful to guarantee LRU caching rules.
Returns
-------
-List[Dict[str, Any]]
Returns successful DDL and DML queries sorted by query completion time.
"""
-filtered: list[dict[str, Any]] = []
+filtered: list["QueryExecutionTypeDef"] = []
for query in self._cache.values():
if (query["Status"].get("State") == "SUCCEEDED") and (query.get("StatementType") in ["DDL", "DML"]):
filtered.append(query)
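
For reference, here is a simplified sketch of the eviction pattern used by `_LocalMetadataCacheManager` (a bounded dict plus a heapq ordered by `SubmissionDateTime`); this is an illustration of the idea, not the library's exact code:

```python
import datetime
from heapq import heappop, heappush

class BoundedQueryCache:
    """Illustrative bounded cache that evicts the oldest entry by submission time."""

    def __init__(self, max_size: int = 100) -> None:
        self._cache: dict[str, dict] = {}
        self._pqueue: list[tuple[datetime.datetime, str]] = []
        self._max_size = max_size

    def add(self, item: dict) -> None:
        query_id = item["QueryExecutionId"]
        if query_id in self._cache:
            return
        if len(self._cache) >= self._max_size:
            # Pop the entry with the smallest (i.e. oldest) SubmissionDateTime.
            _, oldest_id = heappop(self._pqueue)
            self._cache.pop(oldest_id, None)
        heappush(self._pqueue, (item["Status"]["SubmissionDateTime"], query_id))
        self._cache[query_id] = item
```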
@@ -115,9 +114,7 @@ def _compare_query_string(sql: str, other: str) -> bool:
comparison_query = _prepare_query_string_for_comparison(query_string=other)
_logger.debug("sql: %s", sql)
_logger.debug("comparison_query: %s", comparison_query)
-if sql == comparison_query:
-    return True
-return False
+return sql == comparison_query


def _prepare_query_string_for_comparison(query_string: str) -> str:
@@ -135,7 +132,7 @@ def _get_last_query_infos(
max_remote_cache_entries: int,
boto3_session: boto3.Session | None = None,
workgroup: str | None = None,
-) -> list[dict[str, Any]]:
+) -> list["QueryExecutionTypeDef"]:
"""Return an iterator of `query_execution_info`s run by the workgroup in Athena."""
client_athena = _utils.client(service_name="athena", session=boto3_session)
page_size = 50
@@ -160,7 +157,7 @@
QueryExecutionIds=uncached_ids[i : i + page_size],
).get("QueryExecutions")
)
-_cache_manager.update_cache(new_execution_data)  # type: ignore[arg-type]
+_cache_manager.update_cache(new_execution_data)
return _cache_manager.sorted_successful_generator()


awswrangler/athena/_read.py (5 changes: 4 additions & 1 deletion)
@@ -1045,7 +1045,10 @@ def read_sql_query(
# Substitute query parameters if applicable
sql, execution_params = _apply_formatter(sql, params, paramstyle)

-if not client_request_token:
+if not client_request_token and paramstyle != "qmark":
+    # For paramstyle=="qmark", we will need to use Athena's caching option.
+    # The issue is that when describing an Athena execution, the API does not return
+    # the parameters that were used.
cache_info: _CacheInfo = _check_for_cached_results(
sql=sql,
boto3_session=boto3_session,
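
The skip is needed because a local cache keyed on the SQL text alone cannot distinguish qmark executions that differ only in their parameters. A hypothetical illustration of the collision (table and parameter names are made up):

```python
# With paramstyle="qmark", the SQL string sent to Athena is identical no
# matter which parameters are supplied, so a text-keyed cache would treat
# the second execution as a hit for the first one.
cached_sql = "SELECT * FROM my_table WHERE string = ?"  # executed with ["Washington"]
new_sql = "SELECT * FROM my_table WHERE string = ?"     # about to run with ["Seattle"]

assert cached_sql == new_sql  # same text, different results: a false cache hit
```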
awswrangler/athena/_utils.py (22 changes: 13 additions & 9 deletions)
@@ -35,6 +35,7 @@
from ._cache import _cache_manager, _LocalMetadataCacheManager

if TYPE_CHECKING:
+from mypy_boto3_athena.type_defs import QueryExecutionTypeDef
from mypy_boto3_glue.type_defs import ColumnOutputTypeDef

_QUERY_FINAL_STATES: list[str] = ["FAILED", "SUCCEEDED", "CANCELLED"]
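
The new import lives under `if TYPE_CHECKING:` so that `mypy-boto3-athena` stays a type-checking-only dependency; at runtime the name appears solely inside string annotations. A minimal sketch of the pattern:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only imported by type checkers; never at runtime.
    from mypy_boto3_athena.type_defs import QueryExecutionTypeDef

def query_id(payload: "QueryExecutionTypeDef") -> str:
    # The quoted annotation defers evaluation, so the stub package
    # does not need to be installed to run this code.
    return payload["QueryExecutionId"]
```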
@@ -53,7 +54,7 @@ class _QueryMetadata(NamedTuple):
binaries: list[str]
output_location: str | None
manifest_location: str | None
-raw_payload: dict[str, Any]
+raw_payload: "QueryExecutionTypeDef"


class _WorkGroupConfig(NamedTuple):
@@ -214,7 +215,7 @@ def _get_query_metadata(
query_execution_id: str,
boto3_session: boto3.Session | None = None,
categories: list[str] | None = None,
-query_execution_payload: dict[str, Any] | None = None,
+query_execution_payload: "QueryExecutionTypeDef" | None = None,
metadata_cache_manager: _LocalMetadataCacheManager | None = None,
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
execution_params: list[str] | None = None,
@@ -225,12 +226,15 @@
if query_execution_payload["Status"]["State"] != "SUCCEEDED":
reason: str = query_execution_payload["Status"]["StateChangeReason"]
raise exceptions.QueryFailed(f"Query error: {reason}")
-_query_execution_payload: dict[str, Any] = query_execution_payload
+_query_execution_payload = query_execution_payload
else:
-_query_execution_payload = _executions.wait_query(
-    query_execution_id=query_execution_id,
-    boto3_session=boto3_session,
-    athena_query_wait_polling_delay=athena_query_wait_polling_delay,
+_query_execution_payload = cast(
+    "QueryExecutionTypeDef",
+    _executions.wait_query(
+        query_execution_id=query_execution_id,
+        boto3_session=boto3_session,
+        athena_query_wait_polling_delay=athena_query_wait_polling_delay,
+    ),
)
cols_types: dict[str, str] = get_query_columns_types(
query_execution_id=query_execution_id, boto3_session=boto3_session
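
Worth noting: `typing.cast` performs no conversion at runtime; it only tells the type checker to treat `wait_query`'s return value as a `QueryExecutionTypeDef`. For example:

```python
from typing import cast

value: object = {"QueryExecutionId": "abc123"}
payload = cast(dict, value)  # same object back, new static type
assert payload is value
```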
@@ -266,8 +270,8 @@ def _get_query_metadata(
if "ResultConfiguration" in _query_execution_payload:
output_location = _query_execution_payload["ResultConfiguration"].get("OutputLocation")

-athena_statistics: dict[str, int | str] = _query_execution_payload.get("Statistics", {})
-manifest_location: str | None = str(athena_statistics.get("DataManifestLocation"))
+athena_statistics = _query_execution_payload.get("Statistics", {})
+manifest_location: str | None = athena_statistics.get("DataManifestLocation")

if metadata_cache_manager is not None and query_execution_id not in metadata_cache_manager:
metadata_cache_manager.update_cache(items=[_query_execution_payload])
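
This last change fixes a real bug rather than just the annotation: wrapping the lookup in `str()` turned a missing `DataManifestLocation` into the literal (and truthy) string `"None"`. A quick demonstration:

```python
stats: dict = {}  # no "DataManifestLocation" key
old_value = str(stats.get("DataManifestLocation"))  # -> "None", a truthy string
new_value = stats.get("DataManifestLocation")       # -> None
assert old_value == "None" and new_value is None
```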
tests/unit/test_athena.py (59 changes: 59 additions & 0 deletions)
@@ -461,6 +461,65 @@ def test_athena_paramstyle_qmark_parameters(
assert len(df_out) == 1


@pytest.mark.parametrize(
"ctas_approach,unload_approach",
[
pytest.param(False, False, id="regular"),
pytest.param(True, False, id="ctas"),
pytest.param(False, True, id="unload"),
],
)
def test_athena_paramstyle_qmark_skip_caching(
path: str,
path2: str,
glue_database: str,
glue_table: str,
workgroup0: str,
ctas_approach: bool,
unload_approach: bool,
) -> None:
wr.s3.to_parquet(
df=get_df(),
path=path,
index=False,
dataset=True,
mode="overwrite",
database=glue_database,
table=glue_table,
partition_cols=["par0", "par1"],
)

df_out = wr.athena.read_sql_query(
sql=f"SELECT * FROM {glue_table} WHERE string = ?",
database=glue_database,
ctas_approach=ctas_approach,
unload_approach=unload_approach,
workgroup=workgroup0,
params=["Washington"],
paramstyle="qmark",
keep_files=False,
s3_output=path2,
athena_cache_settings={"max_cache_seconds": 300},
)

assert len(df_out) == 1 and df_out.iloc[0]["string"] == "Washington"

df_out = wr.athena.read_sql_query(
sql=f"SELECT * FROM {glue_table} WHERE string = ?",
database=glue_database,
ctas_approach=ctas_approach,
unload_approach=unload_approach,
workgroup=workgroup0,
params=["Seattle"],
paramstyle="qmark",
keep_files=False,
s3_output=path2,
athena_cache_settings={"max_cache_seconds": 300},
)

assert len(df_out) == 1 and df_out.iloc[0]["string"] == "Seattle"
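
The two assertions above are the heart of the regression test: with a text-keyed local cache, the second call would have returned the cached "Washington" rows instead of "Seattle". If result reuse is still wanted for qmark queries, the server-side option the code comment points to can be configured at the Athena API level; a hedged boto3 sketch (table name, workgroup, and reuse window are illustrative):

```python
import boto3

athena = boto3.client("athena")
athena.start_query_execution(
    QueryString="SELECT * FROM my_table WHERE string = ?",
    ExecutionParameters=["Washington"],  # qmark parameters travel separately from the SQL text
    ResultReuseConfiguration={
        "ResultReuseByAgeConfiguration": {"Enabled": True, "MaxAgeInMinutes": 5}
    },
    WorkGroup="primary",
)
```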


def test_read_sql_query_parameter_formatting_respects_prefixes(path, glue_database, glue_table, workgroup0):
wr.s3.to_parquet(
df=get_df(),
