Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-36512: [C++][FlightRPC] Add async GetFlightInfo client call #36517

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cpp/src/arrow/flight/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ else()
add_definitions(-DGRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS=grpc_impl::experimental)
endif()

# Was in a different namespace, or simply not supported, prior to this
if(ARROW_GRPC_VERSION VERSION_GREATER_EQUAL "1.40")
add_definitions(-DGRPC_ENABLE_ASYNC)
endif()

# </KLUDGE> Restore the CXXFLAGS that were modified above
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP}")

Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/flight/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@
#include "arrow/flight/server_middleware.h"
#include "arrow/flight/server_tracing_middleware.h"
#include "arrow/flight/types.h"
#include "arrow/flight/types_async.h"
58 changes: 58 additions & 0 deletions cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,56 @@
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/table.h"
#include "arrow/util/future.h"
#include "arrow/util/logging.h"

#include "arrow/flight/client_auth.h"
#include "arrow/flight/serialization_internal.h"
#include "arrow/flight/transport.h"
#include "arrow/flight/transport/grpc/grpc_client.h"
#include "arrow/flight/types.h"
#include "arrow/flight/types_async.h"

namespace arrow {

namespace flight {

namespace {
template <typename T>
class UnaryUnaryAsyncListener : public AsyncListener<T> {
public:
UnaryUnaryAsyncListener() : future_(arrow::Future<T>::Make()) {}

void OnNext(T result) override {
DCHECK(!result_.ok());
result_ = std::move(result);
}

void OnFinish(Status status) override {
if (status.ok()) {
DCHECK(result_.ok());
} else {
// Default-initialized result is not ok
DCHECK(!result_.ok());
result_ = std::move(status);
}
future_.MarkFinished(std::move(result_));
}

static std::pair<std::shared_ptr<AsyncListener<T>>, arrow::Future<T>> Make() {
auto self = std::make_shared<UnaryUnaryAsyncListener<T>>();
// Keep the listener alive by stashing it in the future
self->future_.AddCallback([self](const arrow::Result<T>&) {});
auto future = self->future_;
return std::make_pair(std::move(self), std::move(future));
}

private:
arrow::Result<T> result_;
arrow::Future<T> future_;
};
} // namespace

const char* kWriteSizeDetailTypeId = "flight::FlightWriteSizeStatusDetail";

FlightCallOptions::FlightCallOptions()
Expand Down Expand Up @@ -584,6 +622,24 @@ arrow::Result<std::unique_ptr<FlightInfo>> FlightClient::GetFlightInfo(
return info;
}

void FlightClient::GetFlightInfoAsync(
const FlightCallOptions& options, const FlightDescriptor& descriptor,
std::shared_ptr<AsyncListener<FlightInfo>> listener) {
if (auto status = CheckOpen(); !status.ok()) {
listener->OnFinish(std::move(status));
return;
}
transport_->GetFlightInfoAsync(options, descriptor, std::move(listener));
}

arrow::Future<FlightInfo> FlightClient::GetFlightInfoAsync(
const FlightCallOptions& options, const FlightDescriptor& descriptor) {
RETURN_NOT_OK(CheckOpen());
auto [listener, future] = UnaryUnaryAsyncListener<FlightInfo>::Make();
transport_->GetFlightInfoAsync(options, descriptor, std::move(listener));
return future;
}

arrow::Result<std::unique_ptr<SchemaResult>> FlightClient::GetSchema(
const FlightCallOptions& options, const FlightDescriptor& descriptor) {
RETURN_NOT_OK(CheckOpen());
Expand Down Expand Up @@ -658,6 +714,8 @@ Status FlightClient::Close() {
return Status::OK();
}

bool FlightClient::supports_async() const { return transport_->supports_async(); }

Status FlightClient::CheckOpen() const {
if (closed_) {
return Status::Invalid("FlightClient is closed");
Expand Down
28 changes: 28 additions & 0 deletions cpp/src/arrow/flight/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,31 @@ class ARROW_FLIGHT_EXPORT FlightClient {
return GetFlightInfo({}, descriptor);
}

/// \brief Asynchronous GetFlightInfo.
/// \param[in] options Per-RPC options
/// \param[in] descriptor the dataset request
/// \param[in] listener Callbacks for response and RPC completion
///
/// This API is EXPERIMENTAL.
void GetFlightInfoAsync(const FlightCallOptions& options,
const FlightDescriptor& descriptor,
std::shared_ptr<AsyncListener<FlightInfo>> listener);
void GetFlightInfoAsync(const FlightDescriptor& descriptor,
std::shared_ptr<AsyncListener<FlightInfo>> listener) {
return GetFlightInfoAsync({}, descriptor, std::move(listener));
}

/// \brief Asynchronous GetFlightInfo returning a Future.
/// \param[in] options Per-RPC options
/// \param[in] descriptor the dataset request
///
/// This API is EXPERIMENTAL.
arrow::Future<FlightInfo> GetFlightInfoAsync(const FlightCallOptions& options,
const FlightDescriptor& descriptor);
arrow::Future<FlightInfo> GetFlightInfoAsync(const FlightDescriptor& descriptor) {
return GetFlightInfoAsync({}, descriptor);
}

/// \brief Request schema for a single flight, which may be an existing
/// dataset or a command to be executed
/// \param[in] options Per-RPC options
Expand Down Expand Up @@ -355,6 +380,9 @@ class ARROW_FLIGHT_EXPORT FlightClient {
/// \since 8.0.0
Status Close();

/// \brief Whether this client supports asynchronous methods.
bool supports_async() const;

private:
FlightClient();
Status CheckOpen() const;
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/arrow/flight/flight_internals_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ void TestRoundtrip(const std::vector<FlightType>& values,
ASSERT_OK(internal::ToProto(values[i], &pb_value));

if constexpr (std::is_same_v<FlightType, FlightInfo>) {
FlightInfo::Data data;
ASSERT_OK(internal::FromProto(pb_value, &data));
FlightInfo value(std::move(data));
ASSERT_OK_AND_ASSIGN(FlightInfo value, internal::FromProto(pb_value));
EXPECT_EQ(values[i], value);
} else if constexpr (std::is_same_v<FlightType, SchemaResult>) {
std::string data;
Expand Down Expand Up @@ -742,5 +740,7 @@ TEST(TransportErrorHandling, ReconstructStatus) {
ASSERT_EQ(detail->extra_info(), "Binary error details");
}

// TODO: test TransportStatusDetail
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we resolve this TODO in this PR or can we defer to a follow-up PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather follow up in a separate PR since it appears the PR is already too large to get reviews


} // namespace flight
} // namespace arrow
25 changes: 24 additions & 1 deletion cpp/src/arrow/flight/flight_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/util.h"
#include "arrow/util/base64.h"
#include "arrow/util/future.h"
#include "arrow/util/logging.h"

#ifdef GRPCPP_GRPCPP_H
Expand Down Expand Up @@ -91,9 +92,16 @@ const char kAuthHeader[] = "authorization";
//------------------------------------------------------------
// Common transport tests

#ifdef GRPC_ENABLE_ASYNC
constexpr bool kGrpcSupportsAsync = true;
#else
constexpr bool kGrpcSupportsAsync = false;
#endif

class GrpcConnectivityTest : public ConnectivityTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
Expand All @@ -102,6 +110,7 @@ ARROW_FLIGHT_TEST_CONNECTIVITY(GrpcConnectivityTest);
class GrpcDataTest : public DataTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
Expand All @@ -110,6 +119,7 @@ ARROW_FLIGHT_TEST_DATA(GrpcDataTest);
class GrpcDoPutTest : public DoPutTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
Expand All @@ -118,6 +128,7 @@ ARROW_FLIGHT_TEST_DO_PUT(GrpcDoPutTest);
class GrpcAppMetadataTest : public AppMetadataTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
Expand All @@ -126,6 +137,7 @@ ARROW_FLIGHT_TEST_APP_METADATA(GrpcAppMetadataTest);
class GrpcIpcOptionsTest : public IpcOptionsTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
Expand All @@ -134,6 +146,7 @@ ARROW_FLIGHT_TEST_IPC_OPTIONS(GrpcIpcOptionsTest);
class GrpcCudaDataTest : public CudaDataTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
Expand All @@ -142,11 +155,21 @@ ARROW_FLIGHT_TEST_CUDA_DATA(GrpcCudaDataTest);
class GrpcErrorHandlingTest : public ErrorHandlingTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
ARROW_FLIGHT_TEST_ERROR_HANDLING(GrpcErrorHandlingTest);

class GrpcAsyncClientTest : public AsyncClientTest, public ::testing::Test {
protected:
std::string transport() const override { return "grpc"; }
bool supports_async() const override { return kGrpcSupportsAsync; }
void SetUp() override { SetUpTest(); }
void TearDown() override { TearDownTest(); }
};
ARROW_FLIGHT_TEST_ASYNC_CLIENT(GrpcAsyncClientTest);

//------------------------------------------------------------
// Ad-hoc gRPC-specific tests

Expand Down Expand Up @@ -443,7 +466,7 @@ class TestTls : public ::testing::Test {
Location location_;
std::unique_ptr<FlightClient> client_;
std::unique_ptr<FlightServerBase> server_;
bool server_is_initialized_;
bool server_is_initialized_ = false;
};

// A server middleware that rejects all calls.
Expand Down
24 changes: 12 additions & 12 deletions cpp/src/arrow/flight/serialization_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,20 +230,21 @@ Status ToProto(const FlightDescriptor& descriptor, pb::FlightDescriptor* pb_desc

// FlightInfo

Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info) {
RETURN_NOT_OK(FromProto(pb_info.flight_descriptor(), &info->descriptor));
arrow::Result<FlightInfo> FromProto(const pb::FlightInfo& pb_info) {
FlightInfo::Data info;
RETURN_NOT_OK(FromProto(pb_info.flight_descriptor(), &info.descriptor));

info->schema = pb_info.schema();
info.schema = pb_info.schema();

info->endpoints.resize(pb_info.endpoint_size());
info.endpoints.resize(pb_info.endpoint_size());
for (int i = 0; i < pb_info.endpoint_size(); ++i) {
RETURN_NOT_OK(FromProto(pb_info.endpoint(i), &info->endpoints[i]));
RETURN_NOT_OK(FromProto(pb_info.endpoint(i), &info.endpoints[i]));
}

info->total_records = pb_info.total_records();
info->total_bytes = pb_info.total_bytes();
info->ordered = pb_info.ordered();
return Status::OK();
info.total_records = pb_info.total_records();
info.total_bytes = pb_info.total_bytes();
info.ordered = pb_info.ordered();
return FlightInfo(std::move(info));
}

Status FromProto(const pb::BasicAuth& pb_basic_auth, BasicAuth* basic_auth) {
Expand Down Expand Up @@ -291,9 +292,8 @@ Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info) {

Status FromProto(const pb::CancelFlightInfoRequest& pb_request,
CancelFlightInfoRequest* request) {
FlightInfo::Data data;
RETURN_NOT_OK(FromProto(pb_request.info(), &data));
request->info = std::make_unique<FlightInfo>(std::move(data));
ARROW_ASSIGN_OR_RAISE(FlightInfo info, FromProto(pb_request.info()));
request->info = std::make_unique<FlightInfo>(std::move(info));
return Status::OK();
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/flight/serialization_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ Status FromProto(const pb::FlightDescriptor& pb_descr, FlightDescriptor* descr);
Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint);
Status FromProto(const pb::RenewFlightEndpointRequest& pb_request,
RenewFlightEndpointRequest* request);
Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info);
arrow::Result<FlightInfo> FromProto(const pb::FlightInfo& pb_info);
lidavidm marked this conversation as resolved.
Show resolved Hide resolved
Status FromProto(const pb::CancelFlightInfoRequest& pb_request,
CancelFlightInfoRequest* request);
Status FromProto(const pb::SchemaResult& pb_result, std::string* result);
Expand Down
Loading
Loading