Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-7860][CORE] In shuffle writer, replace MemoryMappedFile to avoid OOM #7861

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/core/shuffle/Spill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void Spill::insertPayload(

// Opens the spill file for reading on first use: when is_ is already set the
// call does nothing. After opening, rawIs_ caches the raw pointer held by is_.
void Spill::openSpillFile() {
if (!is_) {
// NOTE(review): the two assignments below look like the removed/added pair of
// a diff rendered without +/- markers -- presumably only the MmapFileStream
// one is meant to remain; confirm against the actual file.
GLUTEN_ASSIGN_OR_THROW(is_, arrow::io::MemoryMappedFile::Open(spillFile_, arrow::io::FileMode::READ));
GLUTEN_ASSIGN_OR_THROW(is_, MmapFileStream::open(spillFile_));
rawIs_ = is_.get();
}
}
Expand Down
3 changes: 1 addition & 2 deletions cpp/core/shuffle/Spill.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,8 @@ class Spill final {
};

SpillType type_;
std::shared_ptr<arrow::io::MemoryMappedFile> is_;
std::shared_ptr<gluten::MmapFileStream> is_;
std::list<PartitionPayload> partitionPayloads_{};
std::shared_ptr<arrow::io::MemoryMappedFile> inputStream_{};
std::string spillFile_;
int64_t spillTime_;
int64_t compressTime_;
Expand Down
103 changes: 103 additions & 0 deletions cpp/core/shuffle/Utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
*/

#include "shuffle/Utils.h"
#include <arrow/buffer.h>
#include <arrow/record_batch.h>
#include <boost/uuid/uuid_generators.hpp>
#include <boost/uuid/uuid_io.hpp>
#include <fcntl.h>
#include <glog/logging.h>
#include <sys/mman.h>
#include <iomanip>
#include <iostream>
#include <numeric>
Expand Down Expand Up @@ -212,6 +215,106 @@ arrow::Result<std::shared_ptr<arrow::RecordBatch>> makeUncompressedRecordBatch(
}
return arrow::RecordBatch::Make(writeSchema, 1, {arrays});
}

MmapFileStream::MmapFileStream(arrow::internal::FileDescriptor fd, uint8_t* data, int64_t size)
: fd_(std::move(fd)), data_(data), size_(size){};

// Opens `path` for reading and maps the whole file into memory.
//
// Returns the new stream on success. Any error from resolving the file name,
// opening the file or querying its size is propagated unchanged; a failing
// mmap is reported as an IOError carrying the errno text.
arrow::Result<std::shared_ptr<MmapFileStream>> MmapFileStream::open(const std::string& path) {
  ARROW_ASSIGN_OR_RAISE(auto platformPath, arrow::internal::PlatformFilename::FromString(path));
  ARROW_ASSIGN_OR_RAISE(auto descriptor, arrow::internal::FileOpenReadable(platformPath));
  ARROW_ASSIGN_OR_RAISE(auto fileSize, arrow::internal::FileGetSize(descriptor.fd()));

  // Read-only private mapping that covers the file from offset 0 to its end.
  void* const mapping = mmap(nullptr, fileSize, PROT_READ, MAP_PRIVATE, descriptor.fd(), 0);
  if (mapping != MAP_FAILED) {
    return std::make_shared<MmapFileStream>(std::move(descriptor), static_cast<uint8_t*>(mapping), fileSize);
  }

  return arrow::Status::IOError("Memory mapping file failed: ", ::arrow::internal::ErrnoMessage(errno));
}

// Computes how many bytes a read of `nbytes` can deliver from the current
// position.
//
// Returns min(nbytes, bytes remaining before the end of the mapping), which is
// 0 once the position has reached the end.
// Errors:
//   - Invalid when the stream has been closed (data_ is null).
//   - IOError when `nbytes` is negative or the position lies past the end.
arrow::Result<int64_t> MmapFileStream::actualReadSize(int64_t nbytes) {
  // Close() resets data_ to null. Both Read() overloads address memory through
  // data_ + pos_, so a read on a closed stream must be rejected here rather
  // than touching an unmapped address.
  if (data_ == nullptr) {
    return arrow::Status::Invalid("Invalid operation on closed file");
  }
  if (nbytes < 0 || pos_ > size_) {
    return arrow::Status::IOError("Read out of range. Offset: ", pos_, " Size: ", nbytes, " File Size: ", size_);
  }
  return std::min(size_ - pos_, nbytes);
}

// Reports whether the stream is closed. Close() resets data_ to null, so a
// null mapping pointer is what marks the closed state.
bool MmapFileStream::closed() const {
  return !data_;
}

// Moves the read position forward by `length` bytes and, before doing so,
// tells the kernel that the whole pages between posRetain_ and the current
// position are no longer needed.
//
// The releasable span is measured before pos_ is incremented, so the `length`
// bytes belonging to the current read are not part of the range released by
// this call. A failing madvise is only logged; the stream keeps working.
void MmapFileStream::advance(int64_t length) {
  static const auto pageSize = static_cast<size_t>(arrow::internal::GetPageSize());
  static const auto pageMask = ~(pageSize - 1);
  DCHECK_GT(pageSize, 0);
  DCHECK_EQ(pageMask & pageSize, pageSize);

  // Round the span down to a multiple of the page size. posRetain_ starts at 0
  // and only ever grows by such multiples, so the address handed to madvise
  // stays page-aligned.
  auto purgeLength = (pos_ - posRetain_) & pageMask;
  if (purgeLength > 0) {
    int ret = madvise(data_ + posRetain_, purgeLength, MADV_DONTNEED);
    if (ret != 0) {
      // The call made here is madvise(MADV_DONTNEED); name it correctly in the log.
      LOG(WARNING) << "madvise dontneed failed: " << ::arrow::internal::ErrnoMessage(errno);
    }
    posRetain_ += purgeLength;
  }

  pos_ += length;
}

// Hints to the kernel that the next `length` bytes from the current position
// are about to be accessed. The hinted range starts at the page boundary at or
// below pos_ and ends at pos_ + length. A failing madvise is only logged.
void MmapFileStream::willNeed(int64_t length) {
  static auto pageSize = static_cast<size_t>(arrow::internal::GetPageSize());
  static auto pageMask = ~(pageSize - 1);
  DCHECK_GT(pageSize, 0);
  DCHECK_EQ(pageMask & pageSize, pageSize);

  // Align the start down to a page boundary and stretch the length so the
  // range still reaches pos_ + length.
  auto alignedStart = pos_ & pageMask;
  auto hintedBytes = pos_ + length - alignedStart;

  if (madvise(data_ + alignedStart, hintedBytes, MADV_WILLNEED) != 0) {
    LOG(WARNING) << "madvise willneed failed: " << ::arrow::internal::ErrnoMessage(errno);
  }
}

// Releases the mapping (if it is still present) and closes the descriptor.
// A failing munmap is logged and otherwise ignored; data_ is reset to null in
// either case. The returned status is the result of closing the descriptor.
arrow::Status MmapFileStream::Close() {
  const bool stillMapped = data_ != nullptr;
  if (stillMapped) {
    if (munmap(data_, size_) != 0) {
      LOG(WARNING) << "munmap failed";
    }
    data_ = nullptr;
  }

  return fd_.Close();
}

// Returns the current read position, i.e. the byte offset of the next read.
arrow::Result<int64_t> MmapFileStream::Tell() const {
  return arrow::Result<int64_t>(pos_);
}

// Copies up to `nbytes` bytes from the current position into `out` and moves
// the position past them.
//
// Returns the number of bytes copied, which is smaller than `nbytes` when
// fewer bytes remain. Errors from actualReadSize() are propagated.
arrow::Result<int64_t> MmapFileStream::Read(int64_t nbytes, void* out) {
  ARROW_ASSIGN_OR_RAISE(auto available, actualReadSize(nbytes));

  // Nothing to copy: leave `out` and the position untouched.
  if (available <= 0) {
    return available;
  }

  memcpy(out, data_ + pos_, available);
  advance(available);
  return available;
}

// Returns a buffer over up to `nbytes` bytes at the current position and moves
// the position past them.
//
// The buffer is constructed directly on the mapped address (data_ + pos_), so
// no bytes are copied. When nothing can be read, a zero-length buffer with a
// null data pointer is returned. Errors from actualReadSize() are propagated.
arrow::Result<std::shared_ptr<arrow::Buffer>> MmapFileStream::Read(int64_t nbytes) {
  ARROW_ASSIGN_OR_RAISE(auto available, actualReadSize(nbytes));

  if (available <= 0) {
    return std::make_shared<arrow::Buffer>(nullptr, 0);
  }

  auto view = std::make_shared<arrow::Buffer>(data_ + pos_, available);
  // Ask for the pages backing the returned range before moving on.
  willNeed(available);
  advance(available);
  return view;
}
} // namespace gluten

std::string gluten::generateUuid() {
Expand Down
32 changes: 32 additions & 0 deletions cpp/core/shuffle/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,36 @@ arrow::Result<std::shared_ptr<arrow::RecordBatch>> makeUncompressedRecordBatch(

std::shared_ptr<arrow::Buffer> zeroLengthNullBuffer();

// MmapFileStream is an arrow::io::InputStream that reads a file sequentially
// through a memory mapping. While reading it calls madvise to announce pages
// that are about to be accessed and to mark pages that lie before the read
// position as no longer needed.
//
// Usage: create an instance with open(path), call Read() repeatedly, then
// Close().
// NOTE(review): no destructor is declared in this class, so the mapping is
// presumably only released by an explicit Close() -- confirm that every owner
// calls it.
class MmapFileStream : public arrow::io::InputStream {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add some comments to explain the usage/functionality for this class?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may contribute MmapFileStream to Apache Arrow in future.

public:
// Wraps an open descriptor and the address/length of its existing mapping.
MmapFileStream(arrow::internal::FileDescriptor fd, uint8_t* data, int64_t size);

// Opens `path` for reading, maps the whole file and returns the stream, or
// the error that prevented it.
static arrow::Result<std::shared_ptr<MmapFileStream>> open(const std::string& path);

// Returns the current read position in bytes.
arrow::Result<int64_t> Tell() const override;

// Unmaps the file (if still mapped) and closes the descriptor.
arrow::Status Close() override;

// Copies up to `nbytes` bytes into `out`; returns the number of bytes copied.
arrow::Result<int64_t> Read(int64_t nbytes, void* out) override;

// Returns a buffer built directly on the mapped memory for up to `nbytes`
// bytes (no copy is made).
arrow::Result<std::shared_ptr<arrow::Buffer>> Read(int64_t nbytes) override;

// True once the mapping pointer has been reset, which Close() does.
bool closed() const override;

private:
// Clamps a requested read size to the bytes remaining; IOError when `nbytes`
// is negative or the position is past the end.
arrow::Result<int64_t> actualReadSize(int64_t nbytes);

// Marks whole pages in [posRetain_, pos_) as not needed, then moves pos_
// forward by `length`.
void advance(int64_t length);

// Advises the kernel that `length` bytes starting at pos_ will be needed.
void willNeed(int64_t length);

// Descriptor of the mapped file.
arrow::internal::FileDescriptor fd_;
// Start of the mapping; null after Close().
uint8_t* data_ = nullptr;
// Length of the mapping in bytes.
int64_t size_;
// Offset of the next byte to read.
int64_t pos_ = 0;
// Offset up to which pages have already been passed to madvise(MADV_DONTNEED).
int64_t posRetain_ = 0;
};

} // namespace gluten
Loading