diff --git a/poetry.lock b/poetry.lock index 0f0f883..edc0f48 100644 --- a/poetry.lock +++ b/poetry.lock @@ -935,6 +935,20 @@ files = [ {file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"}, ] +[[package]] +name = "types-requests" +version = "2.32.0.20241016" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-requests-2.32.0.20241016.tar.gz", hash = "sha256:0d9cad2f27515d0e3e3da7134a1b6f28fb97129d86b867f24d9c726452634d95"}, + {file = "types_requests-2.32.0.20241016-py3-none-any.whl", hash = "sha256:4195d62d6d3e043a4eaaf08ff8a62184584d2e8684e9d2aa178c7915a7da3747"}, +] + +[package.dependencies] +urllib3 = ">=2" + [[package]] name = "typing-extensions" version = "4.12.2" @@ -1092,4 +1106,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "0a1acf9eed47a80c983407b1a26e987abc803e9d5f5a86ab9ea79bab7a548e6a" +content-hash = "49d40f2cdd93ba901a15c9e0a73b5dcf2a767ef90eab25758cb461c0702234cc" diff --git a/pyproject.toml b/pyproject.toml index a8678b6..52b7084 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "wherobots-python-dbapi" -version = "0.9.0" +version = "0.9.1" description = "Python DB-API driver for Wherobots DB" authors = ["Maxime Petazzoni "] license = "Apache 2.0" @@ -22,6 +22,7 @@ pandas = "^2.1.0" StrEnum = "^0.4.15" # pyarrow 14.0.2 doesn't limit numpy < 2, but it should, we do it here numpy = "<2" +types-requests = "^2.32.0.20241016" [tool.poetry.group.dev.dependencies] mypy = "^1.8.0" diff --git a/wherobots/db/connection.py b/wherobots/db/connection.py index 5bb1c3d..a0aff53 100644 --- a/wherobots/db/connection.py +++ b/wherobots/db/connection.py @@ -22,7 +22,11 @@ GeometryRepresentation, ) from wherobots.db.cursor import Cursor -from wherobots.db.errors import NotSupportedError, OperationalError +from wherobots.db.errors import ( + NotSupportedError, + OperationalError, + QueryCancelledError, +) @dataclass @@ -74,19 +78,19 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() - def close(self): + def close(self) -> None: self.__ws.close() - def commit(self): + def commit(self) -> None: raise NotSupportedError - def rollback(self): + def rollback(self) -> None: raise NotSupportedError def cursor(self) -> Cursor: return Cursor(self.__execute_sql, self.__cancel_query) - def __main_loop(self): + def __main_loop(self) -> None: """Main background loop listening for messages from the SQL session.""" logging.info("Starting background connection handling loop...") while self.__ws.protocol.state < websockets.protocol.State.CLOSING: @@ -101,7 +105,7 @@ def __main_loop(self): except Exception as e: logging.exception("Error handling message from SQL session", exc_info=e) - def __listen(self): + def __listen(self) -> None: """Waits for the next message from the SQL session and processes it. The code in this method is purposefully defensive to avoid unexpected situations killing the thread. @@ -120,7 +124,8 @@ def __listen(self): ) return - if kind == EventKind.STATE_UPDATED: + # Incoming state transitions are handled here. + if kind == EventKind.STATE_UPDATED or kind == EventKind.EXECUTION_RESULT: try: query.state = ExecutionState[message["state"].upper()] logging.info("Query %s is now %s.", execution_id, query.state) @@ -128,46 +133,29 @@ def __listen(self): logging.warning("Invalid state update message for %s", execution_id) return - # Incoming state transitions are handled here. if query.state == ExecutionState.SUCCEEDED: - self.__request_results(execution_id) + # On a state_updated event telling us the query succeeded, + # ask for results. + if kind == EventKind.STATE_UPDATED: + self.__request_results(execution_id) + return + + # Otherwise, process the results from the execution_result event. + results = message.get("results") + if not results or not isinstance(results, dict): + logging.warning("Got no results back from %s.", execution_id) + return + + query.state = ExecutionState.COMPLETED + query.handler(self._handle_results(execution_id, results)) elif query.state == ExecutionState.CANCELLED: logging.info("Query %s has been cancelled.", execution_id) + query.handler(QueryCancelledError()) self.__queries.pop(execution_id) elif query.state == ExecutionState.FAILED: # Don't do anything here; the ERROR event is coming with more # details. pass - - elif kind == EventKind.EXECUTION_RESULT: - results = message.get("results") - if not results or not isinstance(results, dict): - logging.warning("Got no results back from %s.", execution_id) - return - - result_bytes = results.get("result_bytes") - result_format = results.get("format") - result_compression = results.get("compression") - logging.info( - "Received %d bytes of %s-compressed %s results from %s.", - len(result_bytes), - result_compression, - result_format, - execution_id, - ) - - query.state = ExecutionState.COMPLETED - if result_format == ResultsFormat.JSON: - query.handler(json.loads(result_bytes.decode("utf-8"))) - elif result_format == ResultsFormat.ARROW: - buffer = pyarrow.py_buffer(result_bytes) - stream = pyarrow.input_stream(buffer, result_compression) - with pyarrow.ipc.open_stream(stream) as reader: - query.handler(reader.read_pandas()) - else: - query.handler( - OperationalError(f"Unsupported results format {result_format}") - ) elif kind == EventKind.ERROR: query.state = ExecutionState.FAILED error = message.get("message") @@ -175,6 +163,28 @@ def __listen(self): else: logging.warning("Received unknown %s event!", kind) + def _handle_results(self, execution_id: str, results: dict[str, Any]) -> Any: + result_bytes = results.get("result_bytes") + result_format = results.get("format") + result_compression = results.get("compression") + logging.info( + "Received %d bytes of %s-compressed %s results from %s.", + len(result_bytes), + result_compression, + result_format, + execution_id, + ) + + if result_format == ResultsFormat.JSON: + return json.loads(result_bytes.decode("utf-8")) + elif result_format == ResultsFormat.ARROW: + buffer = pyarrow.py_buffer(result_bytes) + stream = pyarrow.input_stream(buffer, result_compression) + with pyarrow.ipc.open_stream(stream) as reader: + return reader.read_pandas() + else: + return OperationalError(f"Unsupported results format {result_format}") + def __send(self, message: dict[str, Any]) -> None: request = json.dumps(message) logging.debug("Request: %s", request) diff --git a/wherobots/db/constants.py b/wherobots/db/constants.py index 6228757..808b27a 100644 --- a/wherobots/db/constants.py +++ b/wherobots/db/constants.py @@ -44,7 +44,7 @@ class ExecutionState(LowercaseStrEnum): COMPLETED = auto() "The driver has completed processing the query results." - def is_terminal_state(self): + def is_terminal_state(self) -> bool: return self in ( ExecutionState.COMPLETED, ExecutionState.CANCELLED, @@ -97,7 +97,7 @@ class AppStatus(StrEnum): DESTROY_FAILED = auto() DESTROYED = auto() - def is_starting(self): + def is_starting(self) -> bool: return self in ( AppStatus.PENDING, AppStatus.PREPARING, @@ -107,7 +107,7 @@ def is_starting(self): AppStatus.INITIALIZING, ) - def is_terminal_state(self): + def is_terminal_state(self) -> bool: return self in ( AppStatus.PREPARE_FAILED, AppStatus.DEPLOY_FAILED, diff --git a/wherobots/db/cursor.py b/wherobots/db/cursor.py index ef47759..1d8b35e 100644 --- a/wherobots/db/cursor.py +++ b/wherobots/db/cursor.py @@ -1,7 +1,7 @@ import queue from typing import Any, Optional, List, Tuple -from .errors import ProgrammingError, DatabaseError +from .errors import DatabaseError, ProgrammingError, QueryCancelledError _TYPE_MAP = { "object": "STRING", @@ -16,7 +16,7 @@ class Cursor: - def __init__(self, exec_fn, cancel_fn): + def __init__(self, exec_fn, cancel_fn) -> None: self.__exec_fn = exec_fn self.__cancel_fn = cancel_fn @@ -51,7 +51,9 @@ def __get_results(self) -> Optional[List[Tuple[Any, ...]]]: return self.__results result = self.__queue.get() - if isinstance(result, DatabaseError): + if isinstance(result, QueryCancelledError): + return None + elif isinstance(result, DatabaseError): raise result self.__rowcount = len(result) @@ -72,7 +74,7 @@ def __get_results(self) -> Optional[List[Tuple[Any, ...]]]: return self.__results - def execute(self, operation: str, parameters: dict[str, Any] = None): + def execute(self, operation: str, parameters: dict[str, Any] = None) -> None: if self.__current_execution_id: self.__cancel_fn(self.__current_execution_id) @@ -84,38 +86,40 @@ def execute(self, operation: str, parameters: dict[str, Any] = None): sql = operation.format(**(parameters or {})) self.__current_execution_id = self.__exec_fn(sql, self.__on_execution_result) - def executemany(self, operation: str, seq_of_parameters: list[dict[str, Any]]): + def executemany( + self, operation: str, seq_of_parameters: list[dict[str, Any]] + ) -> None: raise NotImplementedError - def fetchone(self): + def fetchone(self) -> Any: results = self.__get_results()[self.__current_row :] if len(results) == 0: return None self.__current_row += 1 return results[0] - def fetchmany(self, size: int = None): + def fetchmany(self, size: int = None) -> list[Any]: size = size or self.arraysize results = self.__get_results()[self.__current_row : self.__current_row + size] self.__current_row += size return results - def fetchall(self): + def fetchall(self) -> list[Any]: return self.__get_results()[self.__current_row :] - def close(self): + def close(self) -> None: """Close the cursor.""" if self.__results is None and self.__current_execution_id: self.__cancel_fn(self.__current_execution_id) - def __iter__(self): + def __iter__(self) -> Cursor: return self - def __next__(self): + def __next__(self) -> None: raise StopIteration - def __enter__(self): + def __enter__(self) -> Cursor: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.close() diff --git a/wherobots/db/errors.py b/wherobots/db/errors.py index 4e18110..0b577fc 100644 --- a/wherobots/db/errors.py +++ b/wherobots/db/errors.py @@ -6,6 +6,10 @@ class InterfaceError(Error): pass +class QueryCancelledError(Error): + pass + + class DatabaseError(Error): pass