diff --git a/cpp/core/shuffle/Spill.cc b/cpp/core/shuffle/Spill.cc index d8b9bc7ebf99..8efa323945d7 100644 --- a/cpp/core/shuffle/Spill.cc +++ b/cpp/core/shuffle/Spill.cc @@ -73,7 +73,7 @@ void Spill::insertPayload( void Spill::openSpillFile() { if (!is_) { - GLUTEN_ASSIGN_OR_THROW(is_, arrow::io::MemoryMappedFile::Open(spillFile_, arrow::io::FileMode::READ)); + GLUTEN_ASSIGN_OR_THROW(is_, MmapFileStream::open(spillFile_)); rawIs_ = is_.get(); } } diff --git a/cpp/core/shuffle/Spill.h b/cpp/core/shuffle/Spill.h index c82a60f562b4..2a88177d9756 100644 --- a/cpp/core/shuffle/Spill.h +++ b/cpp/core/shuffle/Spill.h @@ -69,9 +69,8 @@ class Spill final { }; SpillType type_; - std::shared_ptr is_; + std::shared_ptr is_; std::list partitionPayloads_{}; - std::shared_ptr inputStream_{}; std::string spillFile_; int64_t spillTime_; int64_t compressTime_; diff --git a/cpp/core/shuffle/Utils.cc b/cpp/core/shuffle/Utils.cc index 6854c1978370..8965288beeed 100644 --- a/cpp/core/shuffle/Utils.cc +++ b/cpp/core/shuffle/Utils.cc @@ -16,10 +16,13 @@ */ #include "shuffle/Utils.h" +#include #include #include #include #include +#include +#include #include #include #include @@ -212,6 +215,96 @@ arrow::Result> makeUncompressedRecordBatch( } return arrow::RecordBatch::Make(writeSchema, 1, {arrays}); } + +arrow::Result> MmapFileStream::open(const std::string& path) { + ARROW_ASSIGN_OR_RAISE(auto fileName, arrow::internal::PlatformFilename::FromString(path)); + + ARROW_ASSIGN_OR_RAISE(auto fd, arrow::internal::FileOpenReadable(fileName)); + ARROW_ASSIGN_OR_RAISE(auto size, arrow::internal::FileGetSize(fd.fd())); + + void* result = mmap(nullptr, size, PROT_READ, MAP_PRIVATE, fd.fd(), 0); + if (result == MAP_FAILED) { + return arrow::Status::IOError("Memory mapping file failed: ", ::arrow::internal::ErrnoMessage(errno)); + } + + auto fstream = std::shared_ptr(new MmapFileStream()); + fstream->fd_ = std::move(fd); + fstream->data_ = static_cast(result); + fstream->size_ = size; + return fstream; +} + +void MmapFileStream::advance(int64_t length) { + static auto pageSize = static_cast(arrow::internal::GetPageSize()); + static auto pageMask = ~(pageSize - 1); + DCHECK_GT(pageSize, 0); + DCHECK_EQ(pageMask & pageSize, pageSize); + + auto purgeLength = (pos_ - posRetain_) & pageMask; + if (purgeLength > 0) { + int ret = madvise(data_ + posRetain_, purgeLength, MADV_DONTNEED); + if (ret != 0) { + LOG(WARNING) << "fadvise failed " << ::arrow::internal::ErrnoMessage(errno); + } + posRetain_ += purgeLength; + } + + pos_ += length; +} + +void MmapFileStream::willNeed(int64_t length) { + static auto pageSize = static_cast(arrow::internal::GetPageSize()); + static auto pageMask = ~(pageSize - 1); + DCHECK_GT(pageSize, 0); + DCHECK_EQ(pageMask & pageSize, pageSize); + + auto willNeedPos = pos_ & pageMask; + auto willNeedLen = pos_ + length - willNeedPos; + int ret = madvise(data_ + willNeedPos, willNeedLen, MADV_WILLNEED); + if (ret != 0) { + LOG(WARNING) << "madvise willneed failed: " << ::arrow::internal::ErrnoMessage(errno); + } +} + +arrow::Status MmapFileStream::Close() { + if (data_ != nullptr) { + int result = munmap(data_, size_); + if (result != 0) { + LOG(WARNING) << "munmap failed"; + } + data_ = nullptr; + } + + return fd_.Close(); +} + +arrow::Result MmapFileStream::Tell() const { + return pos_; +} + +arrow::Result MmapFileStream::Read(int64_t nbytes, void* out) { + ARROW_ASSIGN_OR_RAISE(nbytes, actualReadSize(nbytes)); + + if (nbytes > 0) { + memcpy(out, data_ + pos_, nbytes); + advance(nbytes); + } + + return nbytes; +} + +arrow::Result> MmapFileStream::Read(int64_t nbytes) { + ARROW_ASSIGN_OR_RAISE(nbytes, actualReadSize(nbytes)); + + if (nbytes > 0) { + auto buffer = std::make_shared(data_ + pos_, nbytes); + willNeed(nbytes); + advance(nbytes); + return buffer; + } else { + return std::make_shared(nullptr, 0); + } +} } // namespace gluten std::string gluten::generateUuid() { diff --git a/cpp/core/shuffle/Utils.h b/cpp/core/shuffle/Utils.h index c4e2409d2da0..5d6f07707a73 100644 --- a/cpp/core/shuffle/Utils.h +++ b/cpp/core/shuffle/Utils.h @@ -72,4 +72,34 @@ arrow::Result> makeUncompressedRecordBatch( std::shared_ptr zeroLengthNullBuffer(); +class MmapFileStream : public arrow::io::InputStream { + public: + static arrow::Result> open(const std::string& path); + arrow::Result Tell() const override; + arrow::Status Close() override; + arrow::Result Read(int64_t nbytes, void* out) override; + arrow::Result> Read(int64_t nbytes) override; + bool closed() const override { + return data_ == nullptr; + }; + + private: + arrow::Result actualReadSize(int64_t nbytes) { + if (nbytes < 0 || pos_ > size_) { + return arrow::Status::IOError("Read out of range. Offset: ", pos_, " Size: ", nbytes, " File Size: ", size_); + } + return std::min(size_ - pos_, nbytes); + } + + void advance(int64_t length); + void willNeed(int64_t length); + + arrow::internal::FileDescriptor fd_; + uint8_t* data_ = nullptr; + int64_t size_; + int64_t pos_ = 0; + int64_t posRetain_ = 0; + MmapFileStream() = default; +}; + } // namespace gluten