diff --git a/cpp/src/arrow/buffer.h b/cpp/src/arrow/buffer.h
index 08a3bd749e25d..7cc2d2c9cc8c4 100644
--- a/cpp/src/arrow/buffer.h
+++ b/cpp/src/arrow/buffer.h
@@ -346,6 +346,8 @@ class ARROW_EXPORT Buffer {
   static Result<std::shared_ptr<Buffer>> ViewOrCopy(
       std::shared_ptr<Buffer> source, const std::shared_ptr<MemoryManager>& to);
 
+  virtual std::shared_ptr<Device::SyncEvent> device_sync_event() { return NULLPTR; }
+
  protected:
   bool is_mutable_;
   bool is_cpu_;
diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc
index 13355dd6d05ae..b967af28e4aec 100644
--- a/cpp/src/arrow/c/bridge.cc
+++ b/cpp/src/arrow/c/bridge.cc
@@ -521,8 +521,7 @@ struct ExportedArrayPrivateData : PoolAllocationMixin<ExportedArrayPrivateData>
   SmallVector<struct ArrowArray*, 4> child_pointers_;
 
   std::shared_ptr<ArrayData> data_;
-
-  RawSyncEvent sync_event_;
+  std::shared_ptr<Device::SyncEvent> sync_;
 
   ExportedArrayPrivateData() = default;
   ARROW_DEFAULT_MOVE_AND_ASSIGN(ExportedArrayPrivateData);
@@ -547,10 +546,6 @@ void ReleaseExportedArray(struct ArrowArray* array) {
   }
   DCHECK_NE(array->private_data, nullptr);
   auto* pdata = reinterpret_cast<ExportedArrayPrivateData*>(array->private_data);
-  if (pdata->sync_event_.sync_event != nullptr &&
-      pdata->sync_event_.release_func != nullptr) {
-    pdata->sync_event_.release_func(pdata->sync_event_.sync_event);
-  }
   delete pdata;
 
   ArrowArrayMarkReleased(array);
@@ -591,7 +586,7 @@ struct ArrayExporter {
 
     // Store owning pointer to ArrayData
     export_.data_ = data;
-    export_.sync_event_ = RawSyncEvent();
+    export_.sync_ = nullptr;
 
     return Status::OK();
   }
@@ -714,12 +709,9 @@ Result<std::pair<DeviceAllocationType, int64_t>> ValidateDeviceIn
   return std::make_pair(device_type, device_id);
 }
 
-Status ExportDeviceArray(const Array& array, RawSyncEvent sync_event,
+Status ExportDeviceArray(const Array& array, std::shared_ptr<Device::SyncEvent> sync,
                          struct ArrowDeviceArray* out, struct ArrowSchema* out_schema) {
-  if (sync_event.sync_event != nullptr && sync_event.release_func) {
-    return Status::Invalid(
-        "Must provide a release event function if providing a non-null event");
-  }
+  void* sync_event = sync ? sync->get_raw() : nullptr;
 
   SchemaExportGuard guard(out_schema);
   if (out_schema != nullptr) {
@@ -739,19 +731,20 @@ Status ExportDeviceArray(const Array& array, RawSyncEvent sync_event,
   exporter.Finish(&out->array);
 
   auto* pdata = reinterpret_cast<ExportedArrayPrivateData*>(out->array.private_data);
-  pdata->sync_event_ = sync_event;
-  out->sync_event = sync_event.sync_event;
+  pdata->sync_ = std::move(sync);
+  out->sync_event = sync_event;
 
   guard.Detach();
   return Status::OK();
 }
 
-Status ExportDeviceRecordBatch(const RecordBatch& batch, RawSyncEvent sync_event,
+Status ExportDeviceRecordBatch(const RecordBatch& batch,
+                               std::shared_ptr<Device::SyncEvent> sync,
                                struct ArrowDeviceArray* out,
                                struct ArrowSchema* out_schema) {
-  if (sync_event.sync_event != nullptr && sync_event.release_func == nullptr) {
-    return Status::Invalid(
-        "Must provide a release event function if providing a non-null event");
+  void* sync_event{nullptr};
+  if (sync) {
+    sync_event = sync->get_raw();
   }
 
   // XXX perhaps bypass ToStructArray for speed?
@@ -776,8 +769,8 @@ Status ExportDeviceRecordBatch(const RecordBatch& batch, RawSyncEvent sync_event
   exporter.Finish(&out->array);
 
   auto* pdata = reinterpret_cast<ExportedArrayPrivateData*>(out->array.private_data);
-  pdata->sync_event_ = sync_event;
-  out->sync_event = sync_event.sync_event;
+  pdata->sync_ = std::move(sync);
+  out->sync_event = sync_event;
 
   guard.Detach();
   return Status::OK();
@@ -1362,7 +1355,7 @@ namespace {
 
 // The ArrowArray is released on destruction.
 struct ImportedArrayData {
   struct ArrowArray array_;
-  void* sync_event_;
+  std::shared_ptr<Device::SyncEvent> device_sync_;
 
   ImportedArrayData() {
     ArrowArrayMarkReleased(&array_);  // Initially released
@@ -1395,6 +1388,10 @@ class ImportedBuffer : public Buffer {
 
   ~ImportedBuffer() override {}
 
+  std::shared_ptr<Device::SyncEvent> device_sync_event() override {
+    return import_->device_sync_;
+  }
+
  protected:
   std::shared_ptr<ImportedArrayData> import_;
 };
@@ -1409,7 +1406,10 @@ struct ArrayImporter {
     ARROW_ASSIGN_OR_RAISE(memory_mgr_, mapper(src->device_type, src->device_id));
     device_type_ = static_cast<DeviceAllocationType>(src->device_type);
     RETURN_NOT_OK(Import(&src->array));
-    import_->sync_event_ = src->sync_event;
+    if (src->sync_event != nullptr) {
+      ARROW_ASSIGN_OR_RAISE(import_->device_sync_, memory_mgr_->WrapDeviceSyncEvent(
+                                                       src->sync_event, [](void*) {}));
+    }
     // reset internal state before next import
     memory_mgr_.reset();
     device_type_ = DeviceAllocationType::kCPU;
diff --git a/cpp/src/arrow/c/bridge.h b/cpp/src/arrow/c/bridge.h
index 92707a59729fc..45583109a761f 100644
--- a/cpp/src/arrow/c/bridge.h
+++ b/cpp/src/arrow/c/bridge.h
@@ -22,6 +22,7 @@
 #include <string>
 
 #include "arrow/c/abi.h"
+#include "arrow/device.h"
 #include "arrow/result.h"
 #include "arrow/status.h"
 #include "arrow/type_fwd.h"
@@ -172,17 +173,6 @@ Result<std::shared_ptr<RecordBatch>> ImportRecordBatch(struct ArrowArray* array,
 ///
 /// @{
 
-/// \brief EXPERIMENTAL: Type for freeing a sync event
-///
-/// If synchronization is necessary for accessing the data on a device,
-/// a pointer to an event needs to be passed when exporting the device
-/// array. It's the responsibility of the release function for the array
-/// to release the event. Both can be null if no sync'ing is necessary.
-struct RawSyncEvent {
-  void* sync_event = NULL;
-  std::function<void(void*)> release_func;
-};
-
 /// \brief EXPERIMENTAL: Export C++ Array as an ArrowDeviceArray.
 ///
 /// The resulting ArrowDeviceArray struct keeps the array data and buffers alive
@@ -190,15 +180,15 @@ struct RawSyncEvent {
 /// the provided array MUST have the same device_type, otherwise an error
 /// will be returned.
 ///
-/// If a non-null sync_event is provided, then the sync_release func must also be
-/// non-null. If the sync_event is null, then the sync_release parameter is not called.
+/// If sync is non-null, get_event will be called on it in order to
+/// potentially provide an event for consumers to synchronize on.
 ///
 /// \param[in] array Array object to export
-/// \param[in] sync_event A struct containing what is needed for syncing if necessary
+/// \param[in] sync shared_ptr to object derived from Device::SyncEvent or null
 /// \param[out] out C struct to export the array to
 /// \param[out] out_schema optional C struct to export the array type to
 ARROW_EXPORT
-Status ExportDeviceArray(const Array& array, RawSyncEvent sync_event,
+Status ExportDeviceArray(const Array& array, std::shared_ptr<Device::SyncEvent> sync,
                          struct ArrowDeviceArray* out,
                          struct ArrowSchema* out_schema = NULLPTR);
 
@@ -212,15 +202,16 @@ Status ExportDeviceArray(const Array& array, RawSyncEvent sync_event,
 /// otherwise an error will be returned. If columns are on different devices,
 /// they should be exported using different ArrowDeviceArray instances.
 ///
-/// If a non-null sync_event is provided, then the sync_release func must also be
-/// non-null. If the sync_event is null, then the sync_release parameter is ignored.
+/// If sync is non-null, get_event will be called on it in order to
+/// potentially provide an event for consumers to synchronize on.
 ///
 /// \param[in] batch Record batch to export
-/// \param[in] sync_event A struct containing what is needed for syncing if necessary
+/// \param[in] sync shared_ptr to object derived from Device::SyncEvent or null
 /// \param[out] out C struct where to export the record batch
 /// \param[out] out_schema optional C struct where to export the record batch schema
 ARROW_EXPORT
-Status ExportDeviceRecordBatch(const RecordBatch& batch, RawSyncEvent sync_event,
+Status ExportDeviceRecordBatch(const RecordBatch& batch,
+                               std::shared_ptr<Device::SyncEvent> sync,
                                struct ArrowDeviceArray* out,
                                struct ArrowSchema* out_schema = NULLPTR);
 
diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc
index 5c7de8e4a0783..9727403163e58 100644
--- a/cpp/src/arrow/c/bridge_test.cc
+++ b/cpp/src/arrow/c/bridge_test.cc
@@ -1135,12 +1135,49 @@ TEST_F(TestArrayExport, ExportRecordBatch) {
 
 static const char kMyDeviceTypeName[] = "arrowtest::MyDevice";
 static const ArrowDeviceType kMyDeviceType = ARROW_DEVICE_EXT_DEV;
+static const void* kMyEventPtr = reinterpret_cast<const void*>(uintptr_t(0xBAADF00D));
 
 class MyBuffer final : public MutableBuffer {
  public:
   using MutableBuffer::MutableBuffer;
 
   ~MyBuffer() { default_memory_pool()->Free(const_cast<uint8_t*>(data_), size_); }
+
+  std::shared_ptr<Device::SyncEvent> device_sync_event() override { return device_sync_; }
+
+ protected:
+  std::shared_ptr<Device::SyncEvent> device_sync_;
+};
+
+class MyDevice : public Device {
+ public:
+  explicit MyDevice(int64_t value) : Device(true), value_(value) {}
+  const char* type_name() const override { return kMyDeviceTypeName; }
+  std::string ToString() const override { return kMyDeviceTypeName; }
+  bool Equals(const Device& other) const override {
+    if (other.type_name() != kMyDeviceTypeName || other.device_type() != device_type()) {
+      return false;
+    }
+    return checked_cast<const MyDevice&>(other).value_ == value_;
+  }
+  DeviceAllocationType device_type() const override {
+    return static_cast<DeviceAllocationType>(kMyDeviceType);
+  }
+  int64_t device_id() const override { return value_; }
+  std::shared_ptr<MemoryManager> default_memory_manager() override;
+
+  class MySyncEvent final : public Device::SyncEvent {
+   public:
+    explicit MySyncEvent(void* sync_event, release_fn_t release_sync_event)
+        : Device::SyncEvent(sync_event, release_sync_event) {}
+
+    virtual ~MySyncEvent() = default;
+    Status Wait() override { return Status::OK(); }
+    Status Record(const Device::Stream&) override { return Status::OK(); }
+  };
+
+ protected:
+  int64_t value_;
 };
 
 class MyMemoryManager : public CPUMemoryManager {
@@ -1154,6 +1191,16 @@ class MyMemoryManager : public CPUMemoryManager {
     return std::make_unique<MyBuffer>(data, size, shared_from_this());
   }
 
+  Result<std::shared_ptr<Device::SyncEvent>> MakeDeviceSyncEvent() override {
+    return std::make_shared<MyDevice::MySyncEvent>(const_cast<void*>(kMyEventPtr),
+                                                   [](void*) {});
+  }
+
+  Result<std::shared_ptr<Device::SyncEvent>> WrapDeviceSyncEvent(
+      void* sync_event, Device::SyncEvent::release_fn_t release_sync_event) override {
+    return std::make_shared<MyDevice::MySyncEvent>(sync_event, release_sync_event);
+  }
+
  protected:
   Result<std::shared_ptr<Buffer>> CopyBufferFrom(
       const std::shared_ptr<Buffer>& buf,
@@ -1174,28 +1221,9 @@ class MyMemoryManager : public CPUMemoryManager {
   }
 };
 
-class MyDevice : public Device {
- public:
-  explicit MyDevice(int value) : Device(true), value_(value) {}
-  const char* type_name() const override { return kMyDeviceTypeName; }
-  std::string ToString() const override { return kMyDeviceTypeName; }
-  bool Equals(const Device& other) const override {
-    if (other.type_name() != kMyDeviceTypeName || other.device_type() != device_type()) {
-      return false;
-    }
-    return checked_cast<const MyDevice&>(other).value_ == value_;
-  }
-  DeviceAllocationType device_type() const override {
-    return static_cast<DeviceAllocationType>(kMyDeviceType);
-  }
-  int64_t device_id() const override { return value_; }
-  std::shared_ptr<MemoryManager> default_memory_manager() override {
-    return std::make_shared<MyMemoryManager>(shared_from_this());
-  }
-
- protected:
-  int value_;
-};
+std::shared_ptr<MemoryManager> MyDevice::default_memory_manager() {
+  return std::make_shared<MyMemoryManager>(shared_from_this());
+}
 
 class TestDeviceArrayExport : public ::testing::Test {
  public:
@@ -1251,7 +1279,8 @@ class TestDeviceArrayExport : public ::testing::Test {
                                  ", array data = ", arr->ToString());
     const ArrayData& data = *arr->data();  // non-owning reference
     struct ArrowDeviceArray c_export;
-    ASSERT_OK(ExportDeviceArray(*arr, {nullptr, nullptr}, &c_export));
+    std::shared_ptr<Device::SyncEvent> sync{nullptr};
+    ASSERT_OK(ExportDeviceArray(*arr, sync, &c_export));
     ArrayExportGuard guard(&c_export.array);
 
     auto new_bytes = pool_->bytes_allocated();
@@ -1455,7 +1484,8 @@ TEST_F(TestDeviceArrayExport, ExportArrayAndType) {
   ArrayExportGuard array_guard(&c_array.array);
 
   auto array = ToDevice(mm, *ArrayFromJSON(int8(), "[1, 2, 3]")->data()).ValueOrDie();
-  ASSERT_OK(ExportDeviceArray(*array, {nullptr, nullptr}, &c_array, &c_schema));
+  auto sync = mm->MakeDeviceSyncEvent().ValueOrDie();
+  ASSERT_OK(ExportDeviceArray(*array, sync, &c_array, &c_schema));
   const ArrayData& data = *array->data();
   array.reset();
   ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema));
@@ -1463,7 +1493,7 @@
   ASSERT_EQ(c_schema.format, std::string("c"));
   ASSERT_EQ(c_schema.n_children, 0);
   ArrayExportChecker checker{};
-  checker(&c_array, data, kMyDeviceType, 1, nullptr);
+  checker(&c_array, data, kMyDeviceType, 1, kMyEventPtr);
 }
 
 TEST_F(TestDeviceArrayExport, ExportRecordBatch) {
@@ -1481,25 +1511,25 @@ TEST_F(TestDeviceArrayExport, ExportRecordBatch) {
           .ValueOrDie();
 
   auto batch_factory = [&]() { return RecordBatch::Make(schema, 3, {arr0, arr1}); };
-
+  auto sync = mm->MakeDeviceSyncEvent().ValueOrDie();
   {
     auto batch = batch_factory();
-    ASSERT_OK(ExportDeviceRecordBatch(*batch, {nullptr, nullptr}, &c_array, &c_schema));
+    ASSERT_OK(ExportDeviceRecordBatch(*batch, sync, &c_array, &c_schema));
     SchemaExportGuard schema_guard(&c_schema);
     ArrayExportGuard array_guard(&c_array.array);
 
     RecordBatchExportChecker checker{};
-    checker(&c_array, *batch, kMyDeviceType, 1, nullptr);
+    checker(&c_array, *batch, kMyDeviceType, 1, kMyEventPtr);
 
     // create batch anew, with the same buffer pointers
     batch = batch_factory();
-    checker(&c_array, *batch, kMyDeviceType, 1, nullptr);
+    checker(&c_array, *batch, kMyDeviceType, 1, kMyEventPtr);
   }
 
   {
     // Check one can export both schema and record batch at once
     auto batch = batch_factory();
-    ASSERT_OK(ExportDeviceRecordBatch(*batch, {nullptr, nullptr}, &c_array, &c_schema));
+    ASSERT_OK(ExportDeviceRecordBatch(*batch, sync, &c_array, &c_schema));
     SchemaExportGuard schema_guard(&c_schema);
     ArrayExportGuard array_guard(&c_array.array);
     ASSERT_EQ(c_schema.format, std::string("+s"));
@@ -1508,11 +1538,11 @@ TEST_F(TestDeviceArrayExport, ExportRecordBatch) {
     ASSERT_EQ(kEncodedMetadata2,
               std::string(c_schema.metadata, kEncodedMetadata2.size()));
 
     RecordBatchExportChecker checker{};
-    checker(&c_array, *batch, kMyDeviceType, 1, nullptr);
+    checker(&c_array, *batch, kMyDeviceType, 1, kMyEventPtr);
 
     // Create batch anew, with the same buffer pointers
     batch = batch_factory();
-    checker(&c_array, *batch, kMyDeviceType, 1, nullptr);
+    checker(&c_array, *batch, kMyDeviceType, 1, kMyEventPtr);
   }
 }
@@ -3552,6 +3582,190 @@ TEST_F(TestArrayRoundtrip, RecordBatch) {
   }
 }
 
+class TestDeviceArrayRoundtrip : public ::testing::Test {
+ public:
+  using ArrayFactory = std::function<Result<std::shared_ptr<Array>>()>;
+
+  void SetUp() override { pool_ = default_memory_pool(); }
+
+  static Result<std::shared_ptr<MemoryManager>> DeviceMapper(ArrowDeviceType type,
+                                                             int64_t id) {
+    if (type != kMyDeviceType) {
+      return Status::NotImplemented("should only be MyDevice");
+    }
+
+    std::shared_ptr<MyDevice> device = std::make_shared<MyDevice>(id);
+    return device->default_memory_manager();
+  }
+
+  static Result<std::shared_ptr<ArrayData>> ToDeviceData(
+      const std::shared_ptr<MemoryManager>& mm, const ArrayData& data) {
+    arrow::BufferVector buffers;
+    for (const auto& buf : data.buffers) {
+      if (buf) {
+        ARROW_ASSIGN_OR_RAISE(auto dest, mm->CopyBuffer(buf, mm));
+        buffers.push_back(dest);
+      } else {
+        buffers.push_back(nullptr);
+      }
+    }
+
+    arrow::ArrayDataVector children;
+    for (const auto& child : data.child_data) {
+      ARROW_ASSIGN_OR_RAISE(auto dest, ToDeviceData(mm, *child));
+      children.push_back(dest);
+    }
+
+    return ArrayData::Make(data.type, data.length, buffers, children, data.null_count,
+                           data.offset);
+  }
+
+  static Result<std::shared_ptr<Array>> ToDevice(const std::shared_ptr<MemoryManager>& mm,
+                                                 const ArrayData& data) {
+    ARROW_ASSIGN_OR_RAISE(auto result, ToDeviceData(mm, data));
+    return MakeArray(result);
+  }
+
+  static ArrayFactory ToDeviceFactory(const std::shared_ptr<MemoryManager>& mm,
+                                      ArrayFactory&& factory) {
+    return [&]() -> Result<std::shared_ptr<Array>> {
+      ARROW_ASSIGN_OR_RAISE(auto arr, factory());
+      return ToDevice(mm, *arr->data());
+    };
+  }
+
+  static ArrayFactory JSONArrayFactory(const std::shared_ptr<MemoryManager>& mm,
+                                       std::shared_ptr<DataType> type, const char* json) {
+    return [=]() { return ToDevice(mm, *ArrayFromJSON(type, json)->data()); };
+  }
+
+  static ArrayFactory SlicedArrayFactory(ArrayFactory factory) {
+    return [=]() -> Result<std::shared_ptr<Array>> {
+      ARROW_ASSIGN_OR_RAISE(auto arr, factory());
+      DCHECK_GE(arr->length(), 2);
+      return arr->Slice(1, arr->length() - 2);
+    };
+  }
+
+  template <typename ArrayFactory>
+  void TestWithArrayFactory(ArrayFactory&& factory) {
+    TestWithArrayFactory(factory, factory);
+  }
+
+  template <typename ArrayFactory, typename ExpectedArrayFactory>
+  void TestWithArrayFactory(ArrayFactory&& factory,
+                            ExpectedArrayFactory&& factory_expected) {
+    std::shared_ptr<Array> array;
+    struct ArrowDeviceArray c_array {};
+    struct ArrowSchema c_schema {};
+    ArrayExportGuard array_guard(&c_array.array);
+    SchemaExportGuard schema_guard(&c_schema);
+
+    auto orig_bytes = pool_->bytes_allocated();
+
+    ASSERT_OK_AND_ASSIGN(array, ToResult(factory()));
+    ASSERT_OK(ExportType(*array->type(), &c_schema));
+    std::shared_ptr<Device::SyncEvent> sync{nullptr};
+    ASSERT_OK(ExportDeviceArray(*array, sync, &c_array));
+
+    auto new_bytes = pool_->bytes_allocated();
+    if (array->type_id() != Type::NA) {
+      ASSERT_GT(new_bytes, orig_bytes);
+    }
+
+    array.reset();
+    ASSERT_EQ(pool_->bytes_allocated(), new_bytes);
+    ASSERT_OK_AND_ASSIGN(array, ImportDeviceArray(&c_array, &c_schema, DeviceMapper));
+    ASSERT_OK(array->ValidateFull());
+    ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
+    ASSERT_TRUE(ArrowArrayIsReleased(&c_array.array));
+
+    // Re-export and re-import, now both at once
+    ASSERT_OK(ExportDeviceArray(*array, sync, &c_array, &c_schema));
+    array.reset();
+    ASSERT_OK_AND_ASSIGN(array, ImportDeviceArray(&c_array, &c_schema, DeviceMapper));
+    ASSERT_OK(array->ValidateFull());
+    ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
+    ASSERT_TRUE(ArrowArrayIsReleased(&c_array.array));
+
+    // Check value of imported array
+    {
+      std::shared_ptr<Array> expected;
+      ASSERT_OK_AND_ASSIGN(expected, ToResult(factory_expected()));
+      AssertTypeEqual(*expected->type(), *array->type());
+      AssertArraysEqual(*expected, *array, true);
+    }
+    array.reset();
+    ASSERT_EQ(pool_->bytes_allocated(), orig_bytes);
+  }
+
+  template <typename BatchFactory>
+  void TestWithBatchFactory(BatchFactory&& factory) {
+    std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+    auto mm = device->default_memory_manager();
+
+    std::shared_ptr<RecordBatch> batch;
+    struct ArrowDeviceArray c_array {};
+    struct ArrowSchema c_schema {};
+    ArrayExportGuard array_guard(&c_array.array);
+    SchemaExportGuard schema_guard(&c_schema);
+
+    auto orig_bytes = pool_->bytes_allocated();
+    ASSERT_OK_AND_ASSIGN(batch, ToResult(factory()));
+    ASSERT_OK(ExportSchema(*batch->schema(), &c_schema));
+    ASSERT_OK_AND_ASSIGN(auto sync, mm->MakeDeviceSyncEvent());
+    ASSERT_OK(ExportDeviceRecordBatch(*batch, sync, &c_array));
+
+    auto new_bytes = pool_->bytes_allocated();
+    batch.reset();
+    ASSERT_EQ(pool_->bytes_allocated(), new_bytes);
+    ASSERT_OK_AND_ASSIGN(batch,
+                         ImportDeviceRecordBatch(&c_array, &c_schema, DeviceMapper));
+    ASSERT_OK(batch->ValidateFull());
+    ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
+    ASSERT_TRUE(ArrowArrayIsReleased(&c_array.array));
+
+    // Re-export and re-import, now both at once
+    ASSERT_OK(ExportDeviceRecordBatch(*batch, sync, &c_array, &c_schema));
+    batch.reset();
+    ASSERT_OK_AND_ASSIGN(batch,
+                         ImportDeviceRecordBatch(&c_array, &c_schema, DeviceMapper));
+    ASSERT_OK(batch->ValidateFull());
+    ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
+    ASSERT_TRUE(ArrowArrayIsReleased(&c_array.array));
+
+    // Check value of imported record batch
+    {
+      std::shared_ptr<RecordBatch> expected;
+      ASSERT_OK_AND_ASSIGN(expected, ToResult(factory()));
+      AssertSchemaEqual(*expected->schema(), *batch->schema());
+      AssertBatchesEqual(*expected, *batch);
+    }
+    batch.reset();
+    ASSERT_EQ(pool_->bytes_allocated(), orig_bytes);
+  }
+
+  void TestWithJSON(const std::shared_ptr<MemoryManager>& mm,
+                    std::shared_ptr<DataType> type, const char* json) {
+    TestWithArrayFactory(JSONArrayFactory(mm, type, json));
+  }
+
+  void TestWithJSONSliced(const std::shared_ptr<MemoryManager>& mm,
+                          std::shared_ptr<DataType> type, const char* json) {
+    TestWithArrayFactory(SlicedArrayFactory(JSONArrayFactory(mm, type, json)));
+  }
+
+ protected:
+  MemoryPool* pool_;
+};
+
+TEST_F(TestDeviceArrayRoundtrip, Primitive) {
+  std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+  auto mm = device->default_memory_manager();
+
+  TestWithJSON(mm, int32(), "[4, 5, null]");
+}
+
 // TODO C -> C++ -> C roundtripping tests?
 
 ////////////////////////////////////////////////////////////////////////////
diff --git a/cpp/src/arrow/device.cc b/cpp/src/arrow/device.cc
index fbb0c3e1a4a39..14d3bac0af1b7 100644
--- a/cpp/src/arrow/device.cc
+++ b/cpp/src/arrow/device.cc
@@ -29,6 +29,15 @@ namespace arrow {
 
 MemoryManager::~MemoryManager() {}
 
+Result<std::shared_ptr<Device::SyncEvent>> MemoryManager::MakeDeviceSyncEvent() {
+  return nullptr;
+}
+
+Result<std::shared_ptr<Device::SyncEvent>> MemoryManager::WrapDeviceSyncEvent(
+    void* sync_event, Device::SyncEvent::release_fn_t release_sync_event) {
+  return nullptr;
+}
+
 Device::~Device() {}
 
 #define COPY_BUFFER_SUCCESS(maybe_buffer) \
diff --git a/cpp/src/arrow/device.h b/cpp/src/arrow/device.h
index 9cc68fe8c82ce..55037ac418808 100644
--- a/cpp/src/arrow/device.h
+++ b/cpp/src/arrow/device.h
@@ -22,6 +22,8 @@
 #include <string>
 
 #include "arrow/io/type_fwd.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
 #include "arrow/type_fwd.h"
 #include "arrow/util/compare.h"
 #include "arrow/util/macros.h"
@@ -98,6 +100,54 @@ class ARROW_EXPORT Device : public std::enable_shared_from_this<Device>,
   /// \brief Return the DeviceAllocationType of this device
   virtual DeviceAllocationType device_type() const = 0;
 
+  class SyncEvent;
+
+  /// \brief EXPERIMENTAL: An opaque wrapper for Device-specific streams
+  ///
+  /// In essence this is just a wrapper around a void* to represent the
+  /// standard concept of a stream/queue on a device. Derived classes
+  /// should be trivially constructible from its device-specific counterparts.
+  class ARROW_EXPORT Stream {
+   public:
+    virtual const void* get_raw() const { return NULLPTR; }
+
+    /// \brief Make the stream wait on the provided event.
+    ///
+    /// Tells the stream that it should wait until the synchronization
+    /// event is completed without blocking the CPU.
+    virtual Status WaitEvent(const SyncEvent&) = 0;
+
+   protected:
+    Stream() = default;
+    virtual ~Stream() = default;
+  };
+
+  /// \brief EXPERIMENTAL: An object that provides event/stream sync primitives
+  class ARROW_EXPORT SyncEvent {
+   public:
+    using release_fn_t = void (*)(void*);
+
+    virtual ~SyncEvent() = default;
+
+    void* get_raw() { return sync_event_.get(); }
+
+    /// @brief Block until sync event is completed.
+    virtual Status Wait() = 0;
+
+    /// @brief Record the wrapped event on the stream so it triggers
+    /// the event when the stream gets to that point in its queue.
+    virtual Status Record(const Stream&) = 0;
+
+   protected:
+    /// If creating this with a passed in event, the caller must ensure
+    /// that the event lives until clear_event is called on this as it
+    /// won't own it.
+    explicit SyncEvent(void* sync_event, release_fn_t release_sync_event)
+        : sync_event_{sync_event, release_sync_event} {}
+
+    std::unique_ptr<void, release_fn_t> sync_event_;
+  };
+
  protected:
   ARROW_DISALLOW_COPY_AND_ASSIGN(Device);
   explicit Device(bool is_cpu = false) : is_cpu_(is_cpu) {}
@@ -165,6 +215,22 @@ class ARROW_EXPORT MemoryManager : public std::enable_shared_from_this<MemoryManager> {
   static Result<std::shared_ptr<Buffer>> ViewBuffer(
       const std::shared_ptr<Buffer>& source, const std::shared_ptr<MemoryManager>& to);
 
+  /// \brief Create a new SyncEvent.
+  ///
+  /// This version should construct the appropriate event for the device and
+  /// provide the unique_ptr with the correct deleter for the event type.
+  /// If the device does not require or work with any synchronization, it is
+  /// allowed for it to return a nullptr.
+  virtual Result<std::shared_ptr<Device::SyncEvent>> MakeDeviceSyncEvent();
+
+  /// \brief Wrap an event into a SyncEvent.
+  ///
+  /// @param sync_event passed in sync_event from the imported device array.
+  /// @param release_sync_event destructor to free sync_event. `nullptr` may be
+  /// passed to indicate that no destruction/freeing is necessary
+  virtual Result<std::shared_ptr<Device::SyncEvent>> WrapDeviceSyncEvent(
+      void* sync_event, Device::SyncEvent::release_fn_t release_sync_event);
+
  protected:
   ARROW_DISALLOW_COPY_AND_ASSIGN(MemoryManager);
 
diff --git a/cpp/src/arrow/gpu/cuda_context.cc b/cpp/src/arrow/gpu/cuda_context.cc
index 869ea6453ccda..3e1af26cac39b 100644
--- a/cpp/src/arrow/gpu/cuda_context.cc
+++ b/cpp/src/arrow/gpu/cuda_context.cc
@@ -293,6 +293,13 @@ std::shared_ptr<CudaDevice> CudaMemoryManager::cuda_device() const {
   return checked_pointer_cast<CudaDevice>(device_);
 }
 
+Result<std::shared_ptr<Device::SyncEvent>> CudaMemoryManager::WrapDeviceSyncEvent(
+    void* sync_event, Device::SyncEvent::release_fn_t release_sync_event) {
+  return nullptr;
+  // auto ev = reinterpret_cast(sync_event);
+  // return std::make_shared(ev);
+}
+
 Result<std::shared_ptr<io::BufferReader>> CudaMemoryManager::GetBufferReader(
     std::shared_ptr<Buffer> buf) {
   if (*buf->device() != *device_) {
diff --git a/cpp/src/arrow/gpu/cuda_context.h b/cpp/src/arrow/gpu/cuda_context.h
index a1b95c7b4181d..79a2ec9f97581 100644
--- a/cpp/src/arrow/gpu/cuda_context.h
+++ b/cpp/src/arrow/gpu/cuda_context.h
@@ -179,6 +179,9 @@ class ARROW_EXPORT CudaMemoryManager : public MemoryManager {
   /// having to cast the `device()` result.
   std::shared_ptr<CudaDevice> cuda_device() const;
 
+  Result<std::shared_ptr<Device::SyncEvent>> WrapDeviceSyncEvent(
+      void* sync_event, Device::SyncEvent::release_fn_t release_sync_event) override;
+
  protected:
   using MemoryManager::MemoryManager;
   static std::shared_ptr<CudaMemoryManager> Make(const std::shared_ptr<Device>& device);
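
Usage sketch (illustration only, not part of the patch): the snippet below shows how the pieces introduced above are meant to fit together for a producer exporting device data and a consumer importing it. It assumes `mm` is a MemoryManager whose MakeDeviceSyncEvent()/WrapDeviceSyncEvent() overrides return real events (the base implementations added in device.cc simply return nullptr), `mapper` is a DeviceMemoryMapper such as the test's DeviceMapper, and `RoundtripWithSync` is a hypothetical helper name used only for this example.

// Illustrative sketch -- RoundtripWithSync is a hypothetical helper; it assumes
// `mm` provides real sync events and `mapper` maps device type/id to a
// MemoryManager (see TestDeviceArrayRoundtrip::DeviceMapper above).
#include <memory>

#include "arrow/array.h"
#include "arrow/c/abi.h"
#include "arrow/c/bridge.h"
#include "arrow/device.h"
#include "arrow/result.h"
#include "arrow/status.h"

arrow::Status RoundtripWithSync(const arrow::Array& array,
                                const std::shared_ptr<arrow::MemoryManager>& mm,
                                const arrow::DeviceMemoryMapper& mapper) {
  // Producer side: create an event and hand it to the exporter. The exported
  // ArrowDeviceArray::sync_event is sync->get_raw(), and the shared_ptr is kept
  // alive by the array's private data until the C release callback runs.
  ARROW_ASSIGN_OR_RAISE(auto sync, mm->MakeDeviceSyncEvent());

  struct ArrowDeviceArray c_array;
  struct ArrowSchema c_schema;
  ARROW_RETURN_NOT_OK(arrow::ExportDeviceArray(array, sync, &c_array, &c_schema));

  // Consumer side: the importer wraps c_array.sync_event via WrapDeviceSyncEvent,
  // and every imported buffer exposes it through Buffer::device_sync_event().
  ARROW_ASSIGN_OR_RAISE(auto imported,
                        arrow::ImportDeviceArray(&c_array, &c_schema, mapper));
  for (const auto& buf : imported->data()->buffers) {
    if (buf) {
      if (auto event = buf->device_sync_event()) {
        // Block until the producer's device work is complete before reading.
        ARROW_RETURN_NOT_OK(event->Wait());
      }
      break;  // all buffers of one import share the same event
    }
  }
  return arrow::Status::OK();
}

On a CPU-only MemoryManager both the event and ArrowDeviceArray::sync_event stay null, so the Wait() branch is simply skipped and the call degenerates to the existing non-device roundtrip.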