Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix support for query cancellation #23

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "wherobots-python-dbapi"
version = "0.9.0"
version = "0.9.1"
description = "Python DB-API driver for Wherobots DB"
authors = ["Maxime Petazzoni <[email protected]>"]
license = "Apache 2.0"
Expand All @@ -22,6 +22,7 @@ pandas = "^2.1.0"
StrEnum = "^0.4.15"
# pyarrow 14.0.2 doesn't limit numpy < 2, but it should, we do it here
numpy = "<2"
types-requests = "^2.32.0.20241016"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this used?


[tool.poetry.group.dev.dependencies]
mypy = "^1.8.0"
Expand Down
88 changes: 49 additions & 39 deletions wherobots/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
GeometryRepresentation,
)
from wherobots.db.cursor import Cursor
from wherobots.db.errors import NotSupportedError, OperationalError
from wherobots.db.errors import (
NotSupportedError,
OperationalError,
QueryCancelledError,
)


@dataclass
Expand Down Expand Up @@ -74,19 +78,19 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def close(self):
def close(self) -> None:
self.__ws.close()

def commit(self):
def commit(self) -> None:
raise NotSupportedError

def rollback(self):
def rollback(self) -> None:
raise NotSupportedError

def cursor(self) -> Cursor:
return Cursor(self.__execute_sql, self.__cancel_query)

def __main_loop(self):
def __main_loop(self) -> None:
"""Main background loop listening for messages from the SQL session."""
logging.info("Starting background connection handling loop...")
while self.__ws.protocol.state < websockets.protocol.State.CLOSING:
Expand All @@ -101,7 +105,7 @@ def __main_loop(self):
except Exception as e:
logging.exception("Error handling message from SQL session", exc_info=e)

def __listen(self):
def __listen(self) -> None:
"""Waits for the next message from the SQL session and processes it.

The code in this method is purposefully defensive to avoid unexpected situations killing the thread.
Expand All @@ -120,61 +124,67 @@ def __listen(self):
)
return

if kind == EventKind.STATE_UPDATED:
# Incoming state transitions are handled here.
if kind == EventKind.STATE_UPDATED or kind == EventKind.EXECUTION_RESULT:
try:
query.state = ExecutionState[message["state"].upper()]
logging.info("Query %s is now %s.", execution_id, query.state)
except KeyError:
logging.warning("Invalid state update message for %s", execution_id)
return

# Incoming state transitions are handled here.
if query.state == ExecutionState.SUCCEEDED:
self.__request_results(execution_id)
# On a state_updated event telling us the query succeeded,
# ask for results.
if kind == EventKind.STATE_UPDATED:
self.__request_results(execution_id)
return

# Otherwise, process the results from the execution_result event.
results = message.get("results")
if not results or not isinstance(results, dict):
logging.warning("Got no results back from %s.", execution_id)
return

query.state = ExecutionState.COMPLETED
query.handler(self._handle_results(execution_id, results))
elif query.state == ExecutionState.CANCELLED:
logging.info("Query %s has been cancelled.", execution_id)
query.handler(QueryCancelledError())
self.__queries.pop(execution_id)
elif query.state == ExecutionState.FAILED:
# Don't do anything here; the ERROR event is coming with more
# details.
pass

elif kind == EventKind.EXECUTION_RESULT:
results = message.get("results")
if not results or not isinstance(results, dict):
logging.warning("Got no results back from %s.", execution_id)
return

result_bytes = results.get("result_bytes")
result_format = results.get("format")
result_compression = results.get("compression")
logging.info(
"Received %d bytes of %s-compressed %s results from %s.",
len(result_bytes),
result_compression,
result_format,
execution_id,
)

query.state = ExecutionState.COMPLETED
if result_format == ResultsFormat.JSON:
query.handler(json.loads(result_bytes.decode("utf-8")))
elif result_format == ResultsFormat.ARROW:
buffer = pyarrow.py_buffer(result_bytes)
stream = pyarrow.input_stream(buffer, result_compression)
with pyarrow.ipc.open_stream(stream) as reader:
query.handler(reader.read_pandas())
else:
query.handler(
OperationalError(f"Unsupported results format {result_format}")
)
elif kind == EventKind.ERROR:
query.state = ExecutionState.FAILED
error = message.get("message")
query.handler(OperationalError(error))
else:
logging.warning("Received unknown %s event!", kind)

def _handle_results(self, execution_id: str, results: dict[str, Any]) -> Any:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hard +1 on this refactor

result_bytes = results.get("result_bytes")
result_format = results.get("format")
result_compression = results.get("compression")
logging.info(
"Received %d bytes of %s-compressed %s results from %s.",
len(result_bytes),
result_compression,
result_format,
execution_id,
)

if result_format == ResultsFormat.JSON:
return json.loads(result_bytes.decode("utf-8"))
elif result_format == ResultsFormat.ARROW:
buffer = pyarrow.py_buffer(result_bytes)
stream = pyarrow.input_stream(buffer, result_compression)
with pyarrow.ipc.open_stream(stream) as reader:
return reader.read_pandas()
else:
return OperationalError(f"Unsupported results format {result_format}")

def __send(self, message: dict[str, Any]) -> None:
request = json.dumps(message)
logging.debug("Request: %s", request)
Expand Down
6 changes: 3 additions & 3 deletions wherobots/db/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class ExecutionState(LowercaseStrEnum):
COMPLETED = auto()
"The driver has completed processing the query results."

def is_terminal_state(self):
def is_terminal_state(self) -> bool:
return self in (
ExecutionState.COMPLETED,
ExecutionState.CANCELLED,
Expand Down Expand Up @@ -97,7 +97,7 @@ class AppStatus(StrEnum):
DESTROY_FAILED = auto()
DESTROYED = auto()

def is_starting(self):
def is_starting(self) -> bool:
return self in (
AppStatus.PENDING,
AppStatus.PREPARING,
Expand All @@ -107,7 +107,7 @@ def is_starting(self):
AppStatus.INITIALIZING,
)

def is_terminal_state(self):
def is_terminal_state(self) -> bool:
return self in (
AppStatus.PREPARE_FAILED,
AppStatus.DEPLOY_FAILED,
Expand Down
30 changes: 17 additions & 13 deletions wherobots/db/cursor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import queue
from typing import Any, Optional, List, Tuple

from .errors import ProgrammingError, DatabaseError
from .errors import DatabaseError, ProgrammingError, QueryCancelledError

_TYPE_MAP = {
"object": "STRING",
Expand All @@ -16,7 +16,7 @@

class Cursor:

def __init__(self, exec_fn, cancel_fn):
def __init__(self, exec_fn, cancel_fn) -> None:
self.__exec_fn = exec_fn
self.__cancel_fn = cancel_fn

Expand Down Expand Up @@ -51,7 +51,9 @@ def __get_results(self) -> Optional[List[Tuple[Any, ...]]]:
return self.__results

result = self.__queue.get()
if isinstance(result, DatabaseError):
if isinstance(result, QueryCancelledError):
return None
elif isinstance(result, DatabaseError):
raise result

self.__rowcount = len(result)
Expand All @@ -72,7 +74,7 @@ def __get_results(self) -> Optional[List[Tuple[Any, ...]]]:

return self.__results

def execute(self, operation: str, parameters: dict[str, Any] = None):
def execute(self, operation: str, parameters: dict[str, Any] = None) -> None:
if self.__current_execution_id:
self.__cancel_fn(self.__current_execution_id)

Expand All @@ -84,38 +86,40 @@ def execute(self, operation: str, parameters: dict[str, Any] = None):
sql = operation.format(**(parameters or {}))
self.__current_execution_id = self.__exec_fn(sql, self.__on_execution_result)

def executemany(self, operation: str, seq_of_parameters: list[dict[str, Any]]):
def executemany(
self, operation: str, seq_of_parameters: list[dict[str, Any]]
) -> None:
raise NotImplementedError

def fetchone(self):
def fetchone(self) -> Any:
results = self.__get_results()[self.__current_row :]
if len(results) == 0:
return None
self.__current_row += 1
return results[0]

def fetchmany(self, size: int = None):
def fetchmany(self, size: int = None) -> list[Any]:
size = size or self.arraysize
results = self.__get_results()[self.__current_row : self.__current_row + size]
self.__current_row += size
return results

def fetchall(self):
def fetchall(self) -> list[Any]:
return self.__get_results()[self.__current_row :]

def close(self):
def close(self) -> None:
"""Close the cursor."""
if self.__results is None and self.__current_execution_id:
self.__cancel_fn(self.__current_execution_id)

def __iter__(self):
def __iter__(self) -> Cursor:
return self

def __next__(self):
def __next__(self) -> None:
raise StopIteration

def __enter__(self):
def __enter__(self) -> Cursor:
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.close()
4 changes: 4 additions & 0 deletions wherobots/db/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ class InterfaceError(Error):
pass


class QueryCancelledError(Error):
pass


class DatabaseError(Error):
pass

Expand Down