From 99cf0ae539172e21003b5763e1905a0af1255774 Mon Sep 17 00:00:00 2001 From: Georges Berenger Date: Thu, 9 Jan 2025 13:08:30 -0800 Subject: [PATCH] Move all AsyncThreadHandler functionality into AsyncReadHandler Summary: Make AsyncThreadHandler more coherent, by making it fully own the background thread and its queue, renaming it AsyncReadHandler for good measure. This will allow us to make the reader smarter in a follow-up diff. There are no real functional changes, this is only a refactor. Reviewed By: hanghu Differential Revision: D67933571 fbshipit-source-id: f4a97a8afa5aaf08241a8f4ade40c1aac9dd61d5 --- csrc/reader/AsyncVRSReader.cpp | 24 ++++++++++++------------ csrc/reader/AsyncVRSReader.h | 32 ++++++++++++++++++-------------- 2 files changed, 30 insertions(+), 26 deletions(-) diff --git a/csrc/reader/AsyncVRSReader.cpp b/csrc/reader/AsyncVRSReader.cpp index b043000..ef9509a 100644 --- a/csrc/reader/AsyncVRSReader.cpp +++ b/csrc/reader/AsyncVRSReader.cpp @@ -31,28 +31,28 @@ namespace pyvrs { -AwaitableRecord::AwaitableRecord(uint32_t index, AsyncJobQueue& queue) - : index_{index}, queue_{queue} {} +AwaitableRecord::AwaitableRecord(uint32_t index, AsyncReadHandler& readHandler) + : index_{index}, readHandler_{readHandler} {} AwaitableRecord::AwaitableRecord(const AwaitableRecord& other) - : index_{other.index_}, queue_{other.queue_} {} + : index_{other.index_}, readHandler_{other.readHandler_} {} py::object AwaitableRecord::await() const { py::gil_scoped_acquire acquire; unique_ptr job = make_unique(index_); py::object res = job->await(); - queue_.sendJob(std::move(job)); + readHandler_.getQueue().sendJob(std::move(job)); return res; } -void AsyncThreadHandler::cleanup() { +void AsyncReadHandler::cleanup() { shouldEndAsyncThread_ = true; if (asyncThread_.joinable()) { asyncThread_.join(); } } -void AsyncThreadHandler::asyncThreadActivity() { +void AsyncReadHandler::asyncThreadActivity() { std::unique_ptr job; while (!shouldEndAsyncThread_) { if (workerQueue_.waitForJob(job, 1) && !shouldEndAsyncThread_) { @@ -90,18 +90,18 @@ OssAsyncVRSReader::asyncReadRecord(const string& streamId, const string& recordT nextRecordIndex_ = static_cast(reader_.getIndex().size()); throw py::index_error("Invalid record index"); } - return AwaitableRecord(static_cast(record - reader_.getIndex().data()), workerQueue_); + return AwaitableRecord(static_cast(record - reader_.getIndex().data()), readHandler_); } AwaitableRecord OssAsyncVRSReader::asyncReadRecord(int index) { if (static_cast(index) >= reader_.getIndex().size()) { throw py::index_error("No record for this index"); } - return AwaitableRecord(static_cast(index), workerQueue_); + return AwaitableRecord(static_cast(index), readHandler_); } OssAsyncVRSReader::~OssAsyncVRSReader() { - asyncThreadHandler_.cleanup(); + readHandler_.cleanup(); reader_.closeFile(); } @@ -128,18 +128,18 @@ AwaitableRecord OssAsyncMultiVRSReader::asyncReadRecord( nextRecordIndex_ = reader_.getRecordCount(); throw py::index_error("Invalid record index: " + to_string(index)); } - return AwaitableRecord(reader_.getRecordIndex(record), workerQueue_); + return AwaitableRecord(reader_.getRecordIndex(record), readHandler_); } AwaitableRecord OssAsyncMultiVRSReader::asyncReadRecord(int index) { if (static_cast(index) >= reader_.getRecordCount()) { throw py::index_error("No record for this index"); } - return AwaitableRecord(static_cast(index), workerQueue_); + return AwaitableRecord(static_cast(index), readHandler_); } OssAsyncMultiVRSReader::~OssAsyncMultiVRSReader() { - asyncThreadHandler_.cleanup(); + readHandler_.cleanup(); reader_.close(); } diff --git a/csrc/reader/AsyncVRSReader.h b/csrc/reader/AsyncVRSReader.h index 65c3509..ff0af54 100644 --- a/csrc/reader/AsyncVRSReader.h +++ b/csrc/reader/AsyncVRSReader.h @@ -72,6 +72,7 @@ class AsyncReadJob : public AsyncJob { }; using AsyncJobQueue = JobQueue>; +class AsyncReadHandler; /// \brief Python awaitable record /// This class only exposes __await__ method to Python which does the following: @@ -83,30 +84,35 @@ using AsyncJobQueue = JobQueue>; /// - Call future.__await__() and Python side waits until set_result will be called by AsyncReader class AwaitableRecord { public: - AwaitableRecord(uint32_t index, AsyncJobQueue& queue); + AwaitableRecord(uint32_t index, AsyncReadHandler& readHandler); AwaitableRecord(const AwaitableRecord& other); py::object await() const; private: uint32_t index_; - AsyncJobQueue& queue_; + AsyncReadHandler& readHandler_; }; /// \brief Helper class to manage the background async thread -class AsyncThreadHandler { +class AsyncReadHandler { public: - AsyncThreadHandler(VRSReaderBase& reader, AsyncJobQueue& queue) - : reader_{reader}, - workerQueue_(queue), - asyncThread_(&AsyncThreadHandler::asyncThreadActivity, this) {} + AsyncReadHandler(VRSReaderBase& reader) + : reader_{reader}, asyncThread_(&AsyncReadHandler::asyncThreadActivity, this) {} + + VRSReaderBase& getReader() const { + return reader_; + } + AsyncJobQueue& getQueue() { + return workerQueue_; + } void asyncThreadActivity(); void cleanup(); private: VRSReaderBase& reader_; - AsyncJobQueue& workerQueue_; + AsyncJobQueue workerQueue_; atomic shouldEndAsyncThread_ = false; thread asyncThread_; }; @@ -117,7 +123,7 @@ class AsyncThreadHandler { class OssAsyncVRSReader : public OssVRSReader { public: explicit OssAsyncVRSReader(bool autoReadConfigurationRecord) - : OssVRSReader(autoReadConfigurationRecord), asyncThreadHandler_{*this, workerQueue_} {} + : OssVRSReader(autoReadConfigurationRecord), readHandler_{*this} {} ~OssAsyncVRSReader() override; @@ -136,8 +142,7 @@ class OssAsyncVRSReader : public OssVRSReader { AwaitableRecord asyncReadRecord(int index); private: - AsyncJobQueue workerQueue_; - AsyncThreadHandler asyncThreadHandler_; + AsyncReadHandler readHandler_; }; /// \brief The async MultiVRSReader class @@ -146,7 +151,7 @@ class OssAsyncVRSReader : public OssVRSReader { class OssAsyncMultiVRSReader : public OssMultiVRSReader { public: explicit OssAsyncMultiVRSReader(bool autoReadConfigurationRecord) - : OssMultiVRSReader(autoReadConfigurationRecord), asyncThreadHandler_{*this, workerQueue_} {} + : OssMultiVRSReader(autoReadConfigurationRecord), readHandler_{*this} {} ~OssAsyncMultiVRSReader() override; @@ -165,8 +170,7 @@ class OssAsyncMultiVRSReader : public OssMultiVRSReader { AwaitableRecord asyncReadRecord(int index); private: - AsyncJobQueue workerQueue_; - AsyncThreadHandler asyncThreadHandler_; + AsyncReadHandler readHandler_; }; /// Binds methods and classes for AsyncVRSReader.