diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index f8dd4ba549ff3..25da5e8007660 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -722,7 +722,9 @@ Status FlightClient::Close() { return Status::OK(); } -bool FlightClient::supports_async() const { return transport_->supports_async(); } +bool FlightClient::supports_async() const { return transport_->CheckAsyncSupport().ok(); } + +Status FlightClient::CheckAsyncSupport() const { return transport_->CheckAsyncSupport(); } Status FlightClient::CheckOpen() const { if (closed_) { diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 705b36c23cebe..e26a821359781 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -396,6 +396,13 @@ class ARROW_FLIGHT_EXPORT FlightClient { /// \brief Whether this client supports asynchronous methods. bool supports_async() const; + /// \brief Check whether this client supports asynchronous methods. + /// + /// This is like supports_async(), except that a detailed error message + /// is returned if async support is not available. If async support is + /// available, this function returns successfully. + Status CheckAsyncSupport() const; + private: FlightClient(); Status CheckOpen() const; diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc index 55be3244fbde4..c84c5a18ff468 100644 --- a/cpp/src/arrow/flight/test_definitions.cc +++ b/cpp/src/arrow/flight/test_definitions.cc @@ -1832,7 +1832,9 @@ void AsyncClientTest::SetUpTest() { ASSERT_OK_AND_ASSIGN(client_, FlightClient::Connect(real_location, client_options)); ASSERT_TRUE(client_->supports_async()); + ASSERT_OK(client_->CheckAsyncSupport()); } + void AsyncClientTest::TearDownTest() { if (supports_async()) { ASSERT_OK(client_->Close()); diff --git a/cpp/src/arrow/flight/transport.h b/cpp/src/arrow/flight/transport.h index 8cb2479d113fc..ee7bd01720730 100644 --- a/cpp/src/arrow/flight/transport.h +++ b/cpp/src/arrow/flight/transport.h @@ -201,7 +201,11 @@ class ARROW_FLIGHT_EXPORT ClientTransport { virtual Status DoExchange(const FlightCallOptions& options, std::unique_ptr* stream); - virtual bool supports_async() const { return false; } + bool supports_async() const { return CheckAsyncSupport().ok(); } + virtual Status CheckAsyncSupport() const { + return Status::NotImplemented( + "this Flight transport does not support async operations"); + } static void SetAsyncRpc(AsyncListenerBase* listener, std::unique_ptr&& rpc); static AsyncRpc* GetAsyncRpc(AsyncListenerBase* listener); diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc index 5baf687cd83dc..f7612759e8de6 100644 --- a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc +++ b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc @@ -1034,16 +1034,17 @@ class GrpcClientImpl : public internal::ClientTransport { ->StartCall(); } - bool supports_async() const override { return true; } + Status CheckAsyncSupport() const override { return Status::OK(); } #else void GetFlightInfoAsync(const FlightCallOptions& options, const FlightDescriptor& descriptor, std::shared_ptr> listener) override { - listener->OnFinish( - Status::NotImplemented("gRPC 1.40 or newer is required to use async")); + listener->OnFinish(CheckAsyncSupport()); } - bool supports_async() const override { return false; } + Status CheckAsyncSupport() const override { + return Status::NotImplemented("gRPC 1.40 or newer is required to use async"); + } #endif private: diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index 0572ed77b40ef..42b221ed72a1b 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -837,6 +837,12 @@ cdef class FlightInfo(_Weakrefable): cdef: unique_ptr[CFlightInfo] info + @staticmethod + cdef wrap(CFlightInfo c_info): + cdef FlightInfo obj = FlightInfo.__new__(FlightInfo) + obj.info.reset(new CFlightInfo(move(c_info))) + return obj + def __init__(self, Schema schema, FlightDescriptor descriptor, endpoints, total_records, total_bytes): """Create a FlightInfo object from a schema, descriptor, and endpoints. @@ -1219,6 +1225,66 @@ cdef class FlightMetadataWriter(_Weakrefable): check_flight_status(self.writer.get().WriteMetadata(deref(buf))) +class AsyncioCall: + """State for an async RPC using asyncio.""" + + def __init__(self) -> None: + import asyncio + self._future = asyncio.get_running_loop().create_future() + + def as_awaitable(self) -> object: + return self._future + + def wakeup(self, result_or_exception) -> None: + # Mark the Future done from within its loop (asyncio + # objects are generally not thread-safe) + loop = self._future.get_loop() + if isinstance(result_or_exception, BaseException): + loop.call_soon_threadsafe( + self._future.set_exception, result_or_exception) + else: + loop.call_soon_threadsafe( + self._future.set_result, result_or_exception) + + +cdef class AsyncioFlightClient: + """ + A FlightClient with an asyncio-based async interface. + + This interface is EXPERIMENTAL. + """ + + cdef: + FlightClient _client + + def __init__(self, FlightClient client) -> None: + self._client = client + + async def get_flight_info( + self, + descriptor: FlightDescriptor, + *, + options: FlightCallOptions = None, + ): + call = AsyncioCall() + self._get_flight_info(call, descriptor, options) + return await call.as_awaitable() + + cdef _get_flight_info(self, call, descriptor, options): + cdef: + CFlightCallOptions* c_options = \ + FlightCallOptions.unwrap(options) + CFlightDescriptor c_descriptor = \ + FlightDescriptor.unwrap(descriptor) + CFuture[CFlightInfo] c_future + + with nogil: + c_future = self._client.client.get().GetFlightInfoAsync( + deref(c_options), c_descriptor) + + BindFuture(move(c_future), call.wakeup, FlightInfo.wrap) + + cdef class FlightClient(_Weakrefable): """A client to a Flight service. @@ -1320,6 +1386,14 @@ cdef class FlightClient(_Weakrefable): check_flight_status(CFlightClient.Connect(c_location, c_options ).Value(&self.client)) + @property + def supports_async(self): + return self.client.get().supports_async() + + def as_async(self) -> None: + check_status(self.client.get().CheckAsyncSupport()) + return AsyncioFlightClient(self) + def wait_for_available(self, timeout=5): """Block until the server can be contacted. diff --git a/python/pyarrow/error.pxi b/python/pyarrow/error.pxi index 46ea021ebf634..4357cde32c31d 100644 --- a/python/pyarrow/error.pxi +++ b/python/pyarrow/error.pxi @@ -77,8 +77,8 @@ class ArrowCancelled(ArrowException): ArrowIOError = IOError -# This function could be written directly in C++ if we didn't -# define Arrow-specific subclasses (ArrowInvalid etc.) +# check_status() and convert_status() could be written directly in C++ +# if we didn't define Arrow-specific subclasses (ArrowInvalid etc.) cdef int check_status(const CStatus& status) except -1 nogil: if status.ok(): return 0 @@ -88,61 +88,74 @@ cdef int check_status(const CStatus& status) except -1 nogil: RestorePyError(status) return -1 - # We don't use Status::ToString() as it would redundantly include - # the C++ class name. - message = frombytes(status.message(), safe=True) - detail = status.detail() - if detail != nullptr: - message += ". Detail: " + frombytes(detail.get().ToString(), - safe=True) - - if status.IsInvalid(): - raise ArrowInvalid(message) - elif status.IsIOError(): - # Note: OSError constructor is - # OSError(message) - # or - # OSError(errno, message, filename=None) - # or (on Windows) - # OSError(errno, message, filename, winerror) - errno = ErrnoFromStatus(status) - winerror = WinErrorFromStatus(status) - if winerror != 0: - raise IOError(errno, message, None, winerror) - elif errno != 0: - raise IOError(errno, message) - else: - raise IOError(message) - elif status.IsOutOfMemory(): - raise ArrowMemoryError(message) - elif status.IsKeyError(): - raise ArrowKeyError(message) - elif status.IsNotImplemented(): - raise ArrowNotImplementedError(message) - elif status.IsTypeError(): - raise ArrowTypeError(message) - elif status.IsCapacityError(): - raise ArrowCapacityError(message) - elif status.IsIndexError(): - raise ArrowIndexError(message) - elif status.IsSerializationError(): - raise ArrowSerializationError(message) - elif status.IsCancelled(): - signum = SignalFromStatus(status) - if signum > 0: - raise ArrowCancelled(message, signum) - else: - raise ArrowCancelled(message) + raise convert_status(status) + + +cdef object convert_status(const CStatus& status): + if IsPyError(status): + try: + RestorePyError(status) + except BaseException as e: + return e + + # We don't use Status::ToString() as it would redundantly include + # the C++ class name. + message = frombytes(status.message(), safe=True) + detail = status.detail() + if detail != nullptr: + message += ". Detail: " + frombytes(detail.get().ToString(), + safe=True) + + if status.IsInvalid(): + return ArrowInvalid(message) + elif status.IsIOError(): + # Note: OSError constructor is + # OSError(message) + # or + # OSError(errno, message, filename=None) + # or (on Windows) + # OSError(errno, message, filename, winerror) + errno = ErrnoFromStatus(status) + winerror = WinErrorFromStatus(status) + if winerror != 0: + return IOError(errno, message, None, winerror) + elif errno != 0: + return IOError(errno, message) + else: + return IOError(message) + elif status.IsOutOfMemory(): + return ArrowMemoryError(message) + elif status.IsKeyError(): + return ArrowKeyError(message) + elif status.IsNotImplemented(): + return ArrowNotImplementedError(message) + elif status.IsTypeError(): + return ArrowTypeError(message) + elif status.IsCapacityError(): + return ArrowCapacityError(message) + elif status.IsIndexError(): + return ArrowIndexError(message) + elif status.IsSerializationError(): + return ArrowSerializationError(message) + elif status.IsCancelled(): + signum = SignalFromStatus(status) + if signum > 0: + return ArrowCancelled(message, signum) else: - message = frombytes(status.ToString(), safe=True) - raise ArrowException(message) + return ArrowCancelled(message) + else: + message = frombytes(status.ToString(), safe=True) + return ArrowException(message) -# This is an API function for C++ PyArrow +# These are API functions for C++ PyArrow cdef api int pyarrow_internal_check_status(const CStatus& status) \ except -1 nogil: return check_status(status) +cdef api object pyarrow_internal_convert_status(const CStatus& status): + return convert_status(status) + cdef class StopToken: cdef void init(self, CStopToken stop_token): diff --git a/python/pyarrow/includes/common.pxd b/python/pyarrow/includes/common.pxd index 9e139be3e5918..044dd0333f323 100644 --- a/python/pyarrow/includes/common.pxd +++ b/python/pyarrow/includes/common.pxd @@ -149,6 +149,20 @@ cdef extern from "arrow/result.h" namespace "arrow" nogil: T operator*() +cdef extern from "arrow/util/future.h" namespace "arrow" nogil: + cdef cppclass CFuture "arrow::Future"[T]: + CFuture() + + +cdef extern from "arrow/python/async.h" namespace "arrow::py" nogil: + # BindFuture's third argument is really a C++ callable with + # the signature `object(T*)`, but Cython does not allow declaring that. + # We use an ellipsis as a workaround. + # Another possibility is to type-erase the argument by making it + # `object(void*)`, but it would lose compile-time C++ type safety. + void BindFuture[T](CFuture[T], object cb, ...) + + cdef extern from "arrow/python/common.h" namespace "arrow::py" nogil: T GetResultValue[T](CResult[T]) except * cdef function[F] BindFunction[F](void* unbound, object bound, ...) diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index 624904ed77a69..4bddd2d080f5f 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -382,6 +382,9 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CResult[unique_ptr[CFlightClient]] Connect(const CLocation& location, const CFlightClientOptions& options) + c_bool supports_async() + CStatus CheckAsyncSupport() + CStatus Authenticate(CFlightCallOptions& options, unique_ptr[CClientAuthHandler] auth_handler) @@ -396,6 +399,8 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CResult[unique_ptr[CFlightListing]] ListFlights(CFlightCallOptions& options, CCriteria criteria) CResult[unique_ptr[CFlightInfo]] GetFlightInfo(CFlightCallOptions& options, CFlightDescriptor& descriptor) + CFuture[CFlightInfo] GetFlightInfoAsync(CFlightCallOptions& options, + CFlightDescriptor& descriptor) CResult[unique_ptr[CSchemaResult]] GetSchema(CFlightCallOptions& options, CFlightDescriptor& descriptor) CResult[unique_ptr[CFlightStreamReader]] DoGet(CFlightCallOptions& options, CTicket& ticket) diff --git a/python/pyarrow/includes/libarrow_python.pxd b/python/pyarrow/includes/libarrow_python.pxd index f08fcaa40d104..4d109fc660e08 100644 --- a/python/pyarrow/includes/libarrow_python.pxd +++ b/python/pyarrow/includes/libarrow_python.pxd @@ -258,7 +258,7 @@ cdef extern from "arrow/python/pyarrow.h" namespace "arrow::py": cdef extern from "arrow/python/common.h" namespace "arrow::py": c_bool IsPyError(const CStatus& status) - void RestorePyError(const CStatus& status) + void RestorePyError(const CStatus& status) except * cdef extern from "arrow/python/inference.h" namespace "arrow::py": diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index e8c89cf0d56dc..63ebe6aea8233 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -65,6 +65,7 @@ cdef extern from "Python.h": cdef int check_status(const CStatus& status) except -1 nogil +cdef object convert_status(const CStatus& status) cdef class _Weakrefable: diff --git a/python/pyarrow/src/arrow/python/async.h b/python/pyarrow/src/arrow/python/async.h new file mode 100644 index 0000000000000..1568d21938e6e --- /dev/null +++ b/python/pyarrow/src/arrow/python/async.h @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/python/common.h" +#include "arrow/status.h" +#include "arrow/util/future.h" + +namespace arrow::py { + +/// \brief Bind a Python callback to an arrow::Future. +/// +/// If the Future finishes successfully, py_wrapper is called with its +/// result value and should return a PyObject*. If py_wrapper is successful, +/// py_cb is called with its return value. +/// +/// If either the Future or py_wrapper fails, py_cb is called with the +/// associated Python exception. +/// +/// \param future The future to bind to. +/// \param py_cb The Python callback function. Will be passed the result of +/// py_wrapper, or a Python exception if the future failed or one was +/// raised by py_wrapper. +/// \param py_wrapper A function (likely defined in Cython) to convert the C++ +/// result of the future to a Python object. +template +void BindFuture(Future future, PyObject* py_cb, PyWrapper py_wrapper) { + Py_INCREF(py_cb); + OwnedRefNoGIL cb_ref(py_cb); + + auto future_cb = [cb_ref = std::move(cb_ref), + py_wrapper = std::move(py_wrapper)](Result result) { + SafeCallIntoPythonVoid([&]() { + OwnedRef py_value_or_exc{WrapResult(std::move(result), std::move(py_wrapper))}; + Py_XDECREF( + PyObject_CallFunctionObjArgs(cb_ref.obj(), py_value_or_exc.obj(), NULLPTR)); + ARROW_WARN_NOT_OK(CheckPyError(), "Internal error in async call"); + }); + }; + future.AddCallback(std::move(future_cb)); +} + +} // namespace arrow::py diff --git a/python/pyarrow/src/arrow/python/common.h b/python/pyarrow/src/arrow/python/common.h index bfd11ba702a84..e36c0834fd424 100644 --- a/python/pyarrow/src/arrow/python/common.h +++ b/python/pyarrow/src/arrow/python/common.h @@ -72,6 +72,37 @@ T GetResultValue(Result result) { } } +/// \brief Wrap a Result and return the corresponding Python object. +/// +/// If the Result is successful, py_wrapper is called with its result value +/// and should return a PyObject*. If py_wrapper is successful (returns +/// a non-NULL value), its return value is returned. +/// +/// If either the Result or py_wrapper fails, the associated Python exception +/// is raised and NULL is returned. +// +/// \param result The Result whose value to wrap in a Python object. +/// \param py_wrapper A function (likely defined in Cython) to convert the C++ +/// value of the Result to a Python object. +/// \return A new Python reference, or NULL if an exception occurred +template +PyObject* WrapResult(Result result, PyWrapper&& py_wrapper) { + static_assert(std::is_same_v()))>, + "PyWrapper argument to WrapResult should return a PyObject* " + "when called with a T*"); + Status st = result.status(); + if (st.ok()) { + PyObject* py_value = py_wrapper(result.MoveValueUnsafe()); + st = CheckPyError(); + if (st.ok()) { + return py_value; + } + Py_XDECREF(py_value); // should be null, but who knows + } + // Status is an error, convert it to an exception. + return internal::convert_status(st); +} + // A RAII-style helper that ensures the GIL is acquired inside a lexical block. class ARROW_PYTHON_EXPORT PyAcquireGIL { public: @@ -131,6 +162,19 @@ auto SafeCallIntoPython(Function&& func) -> decltype(func()) { return maybe_status; } +template +auto SafeCallIntoPythonVoid(Function&& func) -> decltype(func()) { + PyAcquireGIL lock; + PyObject* exc_type; + PyObject* exc_value; + PyObject* exc_traceback; + PyErr_Fetch(&exc_type, &exc_value, &exc_traceback); + func(); + if (exc_type != NULLPTR) { + PyErr_Restore(exc_type, exc_value, exc_traceback); + } +} + // A RAII primitive that DECREFs the underlying PyObject* when it // goes out of scope. class ARROW_PYTHON_EXPORT OwnedRef { diff --git a/python/pyarrow/src/arrow/python/pyarrow.cc b/python/pyarrow/src/arrow/python/pyarrow.cc index 30d1f04f12317..af0fbbad1f74b 100644 --- a/python/pyarrow/src/arrow/python/pyarrow.cc +++ b/python/pyarrow/src/arrow/python/pyarrow.cc @@ -24,6 +24,7 @@ #include "arrow/table.h" #include "arrow/tensor.h" #include "arrow/type.h" +#include "arrow/util/logging.h" #include "arrow/python/common.h" #include "arrow/python/datetime.h" @@ -89,6 +90,11 @@ namespace internal { int check_status(const Status& status) { return ::pyarrow_internal_check_status(status); } +PyObject* convert_status(const Status& status) { + DCHECK(!status.ok()); + return ::pyarrow_internal_convert_status(status); +} + } // namespace internal } // namespace py } // namespace arrow diff --git a/python/pyarrow/src/arrow/python/pyarrow.h b/python/pyarrow/src/arrow/python/pyarrow.h index 4c365081d70ca..113035500c005 100644 --- a/python/pyarrow/src/arrow/python/pyarrow.h +++ b/python/pyarrow/src/arrow/python/pyarrow.h @@ -75,8 +75,13 @@ DECLARE_WRAP_FUNCTIONS(table, Table) namespace internal { +// If status is ok, return 0. +// If status is not ok, set Python error indicator and return -1. ARROW_PYTHON_EXPORT int check_status(const Status& status); +// Convert status to a Python exception object. Status must not be ok. +ARROW_PYTHON_EXPORT PyObject* convert_status(const Status& status); + } // namespace internal } // namespace py } // namespace arrow diff --git a/python/pyarrow/tests/test_flight_async.py b/python/pyarrow/tests/test_flight_async.py new file mode 100644 index 0000000000000..f3cd1bbb58e2f --- /dev/null +++ b/python/pyarrow/tests/test_flight_async.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import asyncio + +import pytest + +import pyarrow + +flight = pytest.importorskip("pyarrow.flight") +pytestmark = pytest.mark.flight + + +class ExampleServer(flight.FlightServerBase): + simple_info = flight.FlightInfo( + pyarrow.schema([("a", "int32")]), + flight.FlightDescriptor.for_command(b"simple"), + [], + -1, + -1, + ) + + def get_flight_info(self, context, descriptor): + if descriptor.command == b"simple": + return self.simple_info + elif descriptor.command == b"unknown": + raise NotImplementedError("Unknown command") + + raise NotImplementedError("Unknown descriptor") + + +def async_or_skip(client): + if not client.supports_async: + # Use async error message as skip message + with pytest.raises(NotImplementedError) as e: + client.as_async() + pytest.skip(str(e.value)) + + +@pytest.fixture(scope="module") +def flight_client(): + with ExampleServer() as server: + with flight.connect(f"grpc://localhost:{server.port}") as client: + yield client + + +@pytest.fixture(scope="module") +def async_client(flight_client): + async_or_skip(flight_client) + yield flight_client.as_async() + + +def test_async_support_property(flight_client): + assert isinstance(flight_client.supports_async, bool) + if flight_client.supports_async: + flight_client.as_async() + else: + with pytest.raises(NotImplementedError): + flight_client.as_async() + + +def test_get_flight_info(async_client): + async def _test(): + descriptor = flight.FlightDescriptor.for_command(b"simple") + info = await async_client.get_flight_info(descriptor) + assert info == ExampleServer.simple_info + + asyncio.run(_test()) + + +def test_get_flight_info_error(async_client): + async def _test(): + descriptor = flight.FlightDescriptor.for_command(b"unknown") + with pytest.raises(NotImplementedError) as excinfo: + await async_client.get_flight_info(descriptor) + + assert "Unknown command" in repr(excinfo.value) + + asyncio.run(_test())