Skip to content

Commit

Permalink
Move all AsyncThreadHandler functionality into AsyncReadHandler
Browse files Browse the repository at this point in the history
Summary:
Make AsyncThreadHandler more coherent, by making it fully own the background thread and its queue, renaming it AsyncReadHandler for good measure. This will allow us to make the reader smarter in a follow-up diff.
There are no real functional changes, this is only a refactor.

Reviewed By: hanghu

Differential Revision: D67933571

fbshipit-source-id: f4a97a8afa5aaf08241a8f4ade40c1aac9dd61d5
  • Loading branch information
Georges Berenger authored and facebook-github-bot committed Jan 9, 2025
1 parent 636dff3 commit 99cf0ae
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 26 deletions.
24 changes: 12 additions & 12 deletions csrc/reader/AsyncVRSReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,28 +31,28 @@

namespace pyvrs {

AwaitableRecord::AwaitableRecord(uint32_t index, AsyncJobQueue& queue)
: index_{index}, queue_{queue} {}
AwaitableRecord::AwaitableRecord(uint32_t index, AsyncReadHandler& readHandler)
: index_{index}, readHandler_{readHandler} {}

AwaitableRecord::AwaitableRecord(const AwaitableRecord& other)
: index_{other.index_}, queue_{other.queue_} {}
: index_{other.index_}, readHandler_{other.readHandler_} {}

py::object AwaitableRecord::await() const {
py::gil_scoped_acquire acquire;
unique_ptr<AsyncJob> job = make_unique<AsyncReadJob>(index_);
py::object res = job->await();
queue_.sendJob(std::move(job));
readHandler_.getQueue().sendJob(std::move(job));
return res;
}

void AsyncThreadHandler::cleanup() {
void AsyncReadHandler::cleanup() {
shouldEndAsyncThread_ = true;
if (asyncThread_.joinable()) {
asyncThread_.join();
}
}

void AsyncThreadHandler::asyncThreadActivity() {
void AsyncReadHandler::asyncThreadActivity() {
std::unique_ptr<AsyncJob> job;
while (!shouldEndAsyncThread_) {
if (workerQueue_.waitForJob(job, 1) && !shouldEndAsyncThread_) {
Expand Down Expand Up @@ -90,18 +90,18 @@ OssAsyncVRSReader::asyncReadRecord(const string& streamId, const string& recordT
nextRecordIndex_ = static_cast<uint32_t>(reader_.getIndex().size());
throw py::index_error("Invalid record index");
}
return AwaitableRecord(static_cast<uint32_t>(record - reader_.getIndex().data()), workerQueue_);
return AwaitableRecord(static_cast<uint32_t>(record - reader_.getIndex().data()), readHandler_);
}

AwaitableRecord OssAsyncVRSReader::asyncReadRecord(int index) {
if (static_cast<size_t>(index) >= reader_.getIndex().size()) {
throw py::index_error("No record for this index");
}
return AwaitableRecord(static_cast<uint32_t>(index), workerQueue_);
return AwaitableRecord(static_cast<uint32_t>(index), readHandler_);
}

OssAsyncVRSReader::~OssAsyncVRSReader() {
asyncThreadHandler_.cleanup();
readHandler_.cleanup();
reader_.closeFile();
}

Expand All @@ -128,18 +128,18 @@ AwaitableRecord OssAsyncMultiVRSReader::asyncReadRecord(
nextRecordIndex_ = reader_.getRecordCount();
throw py::index_error("Invalid record index: " + to_string(index));
}
return AwaitableRecord(reader_.getRecordIndex(record), workerQueue_);
return AwaitableRecord(reader_.getRecordIndex(record), readHandler_);
}

AwaitableRecord OssAsyncMultiVRSReader::asyncReadRecord(int index) {
if (static_cast<uint32_t>(index) >= reader_.getRecordCount()) {
throw py::index_error("No record for this index");
}
return AwaitableRecord(static_cast<uint32_t>(index), workerQueue_);
return AwaitableRecord(static_cast<uint32_t>(index), readHandler_);
}

OssAsyncMultiVRSReader::~OssAsyncMultiVRSReader() {
asyncThreadHandler_.cleanup();
readHandler_.cleanup();
reader_.close();
}

Expand Down
32 changes: 18 additions & 14 deletions csrc/reader/AsyncVRSReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class AsyncReadJob : public AsyncJob {
};

using AsyncJobQueue = JobQueue<std::unique_ptr<AsyncJob>>;
class AsyncReadHandler;

/// \brief Python awaitable record
/// This class only exposes __await__ method to Python which does the following:
Expand All @@ -83,30 +84,35 @@ using AsyncJobQueue = JobQueue<std::unique_ptr<AsyncJob>>;
/// - Call future.__await__() and Python side waits until set_result will be called by AsyncReader
class AwaitableRecord {
public:
AwaitableRecord(uint32_t index, AsyncJobQueue& queue);
AwaitableRecord(uint32_t index, AsyncReadHandler& readHandler);
AwaitableRecord(const AwaitableRecord& other);

py::object await() const;

private:
uint32_t index_;
AsyncJobQueue& queue_;
AsyncReadHandler& readHandler_;
};

/// \brief Helper class to manage the background async thread
class AsyncThreadHandler {
class AsyncReadHandler {
public:
AsyncThreadHandler(VRSReaderBase& reader, AsyncJobQueue& queue)
: reader_{reader},
workerQueue_(queue),
asyncThread_(&AsyncThreadHandler::asyncThreadActivity, this) {}
AsyncReadHandler(VRSReaderBase& reader)
: reader_{reader}, asyncThread_(&AsyncReadHandler::asyncThreadActivity, this) {}

VRSReaderBase& getReader() const {
return reader_;
}
AsyncJobQueue& getQueue() {
return workerQueue_;
}

void asyncThreadActivity();
void cleanup();

private:
VRSReaderBase& reader_;
AsyncJobQueue& workerQueue_;
AsyncJobQueue workerQueue_;
atomic<bool> shouldEndAsyncThread_ = false;
thread asyncThread_;
};
Expand All @@ -117,7 +123,7 @@ class AsyncThreadHandler {
class OssAsyncVRSReader : public OssVRSReader {
public:
explicit OssAsyncVRSReader(bool autoReadConfigurationRecord)
: OssVRSReader(autoReadConfigurationRecord), asyncThreadHandler_{*this, workerQueue_} {}
: OssVRSReader(autoReadConfigurationRecord), readHandler_{*this} {}

~OssAsyncVRSReader() override;

Expand All @@ -136,8 +142,7 @@ class OssAsyncVRSReader : public OssVRSReader {
AwaitableRecord asyncReadRecord(int index);

private:
AsyncJobQueue workerQueue_;
AsyncThreadHandler asyncThreadHandler_;
AsyncReadHandler readHandler_;
};

/// \brief The async MultiVRSReader class
Expand All @@ -146,7 +151,7 @@ class OssAsyncVRSReader : public OssVRSReader {
class OssAsyncMultiVRSReader : public OssMultiVRSReader {
public:
explicit OssAsyncMultiVRSReader(bool autoReadConfigurationRecord)
: OssMultiVRSReader(autoReadConfigurationRecord), asyncThreadHandler_{*this, workerQueue_} {}
: OssMultiVRSReader(autoReadConfigurationRecord), readHandler_{*this} {}

~OssAsyncMultiVRSReader() override;

Expand All @@ -165,8 +170,7 @@ class OssAsyncMultiVRSReader : public OssMultiVRSReader {
AwaitableRecord asyncReadRecord(int index);

private:
AsyncJobQueue workerQueue_;
AsyncThreadHandler asyncThreadHandler_;
AsyncReadHandler readHandler_;
};

/// Binds methods and classes for AsyncVRSReader.
Expand Down

0 comments on commit 99cf0ae

Please sign in to comment.