Skip to content

Commit

Permalink
feat: Add DecodedVector::sharedBase() (facebookincubator#12249)
Browse files Browse the repository at this point in the history
Summary:

Sometimes we need to take shared ownership of the base value vector of a dictionary.  The current `DecodedVector` only keeps reference to a raw pointer so there is no way to get hold of the `shared_ptr`.  We add `DecodedVector::sharedBase()` and overload of `DecodedVector::decode` to take `shared_ptr` so that we can get the shared ownership.

Differential Revision: D69081492
  • Loading branch information
Yuhta authored and facebook-github-bot committed Feb 4, 2025
1 parent 5d13c13 commit 22ad3c0
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 104 deletions.
2 changes: 1 addition & 1 deletion velox/expression/PeeledEncoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ void PeeledEncoding::setDictionaryWrapping(
wrapNulls_ = firstWrapper.nulls();
return;
}
auto wrapping = decoded.dictionaryWrapping(firstWrapper, rows.end());
auto wrapping = decoded.dictionaryWrapping(*firstWrapper.pool(), rows.end());
wrap_ = std::move(wrapping.indices);
wrapNulls_ = std::move(wrapping.nulls);
}
Expand Down
192 changes: 132 additions & 60 deletions velox/vector/DecodedVector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,33 @@ namespace facebook::velox {
uint64_t DecodedVector::constantNullMask_{0};

namespace {

std::vector<vector_size_t> makeConsecutiveIndices(size_t size) {
std::vector<vector_size_t> consecutiveIndices(size);
for (vector_size_t i = 0; i < consecutiveIndices.size(); ++i) {
consecutiveIndices[i] = i;
}
return consecutiveIndices;
}

template <typename T>
auto getLoadedVector(const T& vector) {
if constexpr (std::is_same_v<T, VectorPtr>) {
return BaseVector::loadedVectorShared(vector);
} else {
return vector->loadedVector();
}
}

template <typename T>
auto getValueVector(const T& vector) {
if constexpr (std::is_same_v<T, VectorPtr>) {
return vector->valueVector();
} else {
return vector->valueVector().get();
}
}

} // namespace

const std::vector<vector_size_t>& DecodedVector::consecutiveIndices() {
Expand All @@ -44,21 +64,22 @@ const std::vector<vector_size_t>& DecodedVector::zeroIndices() {
return indices;
}

void DecodedVector::decode(
const BaseVector& vector,
template <typename T>
VectorPtr DecodedVector::decodeImpl(
const T& vector,
const SelectivityVector* rows,
bool loadLazy) {
reset(end(vector.size(), rows));
reset(end(vector->size(), rows));
partialRowsDecoded_ = rows != nullptr;
loadLazy_ = loadLazy;
const bool isTopLevelLazyAndLoaded =
vector.isLazy() && vector.asUnchecked<LazyVector>()->isLoaded();
if (isTopLevelLazyAndLoaded || (loadLazy_ && isLazyNotLoaded(vector))) {
decode(*vector.loadedVector(), rows, loadLazy);
return;
const bool isTopLevelLazyAndLoaded = vector->isLazy() &&
vector->template asUnchecked<LazyVector>()->isLoaded();
if (isTopLevelLazyAndLoaded || (loadLazy_ && isLazyNotLoaded(*vector))) {
return decodeImpl(getLoadedVector(vector), rows, loadLazy);
}

const auto encoding = vector.encoding();
VectorPtr sharedBase;
const auto encoding = vector->encoding();
switch (encoding) {
case VectorEncoding::Simple::FLAT:
case VectorEncoding::Simple::BIASED:
Expand All @@ -67,29 +88,63 @@ void DecodedVector::decode(
case VectorEncoding::Simple::MAP:
case VectorEncoding::Simple::LAZY:
isIdentityMapping_ = true;
setBaseData(vector, rows);
return;
setBaseData(vector, rows, sharedBase);
break;
case VectorEncoding::Simple::CONSTANT: {
isConstantMapping_ = true;
if (isLazyNotLoaded(vector)) {
baseVector_ = vector.valueVector().get();
constantIndex_ = vector.wrapInfo()->as<vector_size_t>()[0];
if (isLazyNotLoaded(*vector)) {
if constexpr (std::is_same_v<T, VectorPtr>) {
sharedBase = vector->valueVector();
}
baseVector_ = vector->valueVector().get();
constantIndex_ = vector->wrapInfo()->template as<vector_size_t>()[0];
mayHaveNulls_ = true;
} else {
setBaseData(vector, rows);
setBaseData(vector, rows, sharedBase);
}
break;
}
case VectorEncoding::Simple::DICTIONARY:
case VectorEncoding::Simple::SEQUENCE: {
combineWrappers(&vector, rows);
combineWrappers(vector, rows, sharedBase);
break;
}
default:
VELOX_FAIL(
"Unsupported vector encoding: {}",
VectorEncoding::mapSimpleToName(encoding));
}
return sharedBase;
}

DecodedVector::DecodedVector(
const BaseVector& vector,
const SelectivityVector& rows,
bool loadLazy) {
decodeImpl(&vector, &rows, loadLazy);
}

DecodedVector::DecodedVector(const BaseVector& vector, bool loadLazy) {
decodeImpl(&vector, nullptr, loadLazy);
}

void DecodedVector::decode(
const BaseVector& vector,
const SelectivityVector& rows,
bool loadLazy) {
decodeImpl(&vector, &rows, loadLazy);
}

void DecodedVector::decode(const BaseVector& vector, bool loadLazy) {
decodeImpl(&vector, nullptr, loadLazy);
}

VectorPtr DecodedVector::decodeAndGetBase(
const VectorPtr& vector,
bool loadLazy) {
auto sharedBase = decodeImpl(vector, nullptr, loadLazy);
VELOX_CHECK(sharedBase.get() == baseVector_);
return sharedBase;
}

void DecodedVector::makeIndices(
Expand All @@ -101,7 +156,8 @@ void DecodedVector::makeIndices(
}

reset(end(vector.size(), rows));
combineWrappers(&vector, rows, numLevels);
VectorPtr sharedPtr;
combineWrappers(&vector, rows, sharedPtr, numLevels);
}

void DecodedVector::reset(vector_size_t size) {
Expand Down Expand Up @@ -133,15 +189,17 @@ void DecodedVector::copyNulls(vector_size_t size) {
nulls_ = copiedNulls_.data();
}

template <typename T>
void DecodedVector::combineWrappers(
const BaseVector* vector,
const T& vector,
const SelectivityVector* rows,
VectorPtr& sharedBase,
int numLevels) {
auto topEncoding = vector->encoding();
BaseVector* values = nullptr;
T values;
if (topEncoding == VectorEncoding::Simple::DICTIONARY) {
indices_ = vector->wrapInfo()->as<vector_size_t>();
values = vector->valueVector().get();
indices_ = vector->wrapInfo()->template as<vector_size_t>();
values = getValueVector(vector);
nulls_ = vector->rawNulls();
if (nulls_) {
hasExtraNulls_ = true;
Expand All @@ -155,14 +213,19 @@ void DecodedVector::combineWrappers(
int32_t levelCounter = 0;
for (;;) {
if (numLevels != -1 && ++levelCounter == numLevels) {
baseVector_ = values;
if constexpr (std::is_same_v<T, VectorPtr>) {
// We get the shared base vector only in case numLevels == -1.
VELOX_UNREACHABLE();
} else {
baseVector_ = values;
}
return;
}

auto encoding = values->encoding();
if (isLazy(encoding) &&
(loadLazy_ || values->asUnchecked<LazyVector>()->isLoaded())) {
values = values->loadedVector();
(loadLazy_ || values->template asUnchecked<LazyVector>()->isLoaded())) {
values = getLoadedVector(values);
encoding = values->encoding();
}

Expand All @@ -174,13 +237,12 @@ void DecodedVector::combineWrappers(
case VectorEncoding::Simple::ROW:
case VectorEncoding::Simple::ARRAY:
case VectorEncoding::Simple::MAP:
setBaseData(*values, rows);
setBaseData(values, rows, sharedBase);
return;
case VectorEncoding::Simple::DICTIONARY: {
case VectorEncoding::Simple::DICTIONARY:
applyDictionaryWrapper(*values, rows);
values = values->valueVector().get();
values = getValueVector(values);
break;
}
default:
VELOX_CHECK(false, "Unsupported vector encoding");
}
Expand Down Expand Up @@ -226,7 +288,7 @@ void DecodedVector::applyDictionaryWrapper(
});
}

void DecodedVector::fillInIndices() {
void DecodedVector::fillInIndices() const {
if (isConstantMapping_) {
if (size_ > zeroIndices().size() || constantIndex_ != 0) {
copiedIndices_.resize(size_);
Expand Down Expand Up @@ -284,60 +346,72 @@ void DecodedVector::setFlatNulls(
}
}

template <typename T>
void DecodedVector::setBaseData(
const BaseVector& vector,
const SelectivityVector* rows) {
auto encoding = vector.encoding();
baseVector_ = &vector;
const T& vector,
const SelectivityVector* rows,
VectorPtr& sharedBase) {
auto encoding = vector->encoding();
if constexpr (std::is_same_v<T, VectorPtr>) {
sharedBase = vector;
baseVector_ = vector.get();
} else {
baseVector_ = vector;
}
switch (encoding) {
case VectorEncoding::Simple::LAZY:
break;
case VectorEncoding::Simple::FLAT: {
case VectorEncoding::Simple::FLAT:
// values() may be nullptr if 'vector' is all nulls.
data_ = vector.values() ? vector.values()->as<void>() : nullptr;
setFlatNulls(vector, rows);
data_ =
vector->values() ? vector->values()->template as<void>() : nullptr;
setFlatNulls(*vector, rows);
break;
}
case VectorEncoding::Simple::ROW:
case VectorEncoding::Simple::ARRAY:
case VectorEncoding::Simple::MAP: {
setFlatNulls(vector, rows);
case VectorEncoding::Simple::MAP:
setFlatNulls(*vector, rows);
break;
}
case VectorEncoding::Simple::CONSTANT: {
setBaseDataForConstant(vector, rows);
case VectorEncoding::Simple::CONSTANT:
setBaseDataForConstant(vector, rows, sharedBase);
break;
}
default:
VELOX_UNREACHABLE();
}
}

template <typename T>
void DecodedVector::setBaseDataForConstant(
const BaseVector& vector,
const SelectivityVector* rows) {
if (!vector.isScalar()) {
baseVector_ = vector.wrappedVector();
constantIndex_ = vector.wrappedIndex(0);
const T& vector,
const SelectivityVector* rows,
VectorPtr& sharedBase) {
if (!vector->isScalar()) {
if constexpr (std::is_same_v<T, VectorPtr>) {
sharedBase = BaseVector::wrappedVectorShared(vector);
baseVector_ = sharedBase.get();
} else {
baseVector_ = vector->wrappedVector();
}
constantIndex_ = vector->wrappedIndex(0);
}
if (!hasExtraNulls_ || vector.isNullAt(0)) {
if (!hasExtraNulls_ || vector->isNullAt(0)) {
// A mapping over a constant is constant except if the
// mapping adds nulls and the constant is not null.
isConstantMapping_ = true;
hasExtraNulls_ = false;
indices_ = nullptr;
nulls_ = vector.isNullAt(0) ? &constantNullMask_ : nullptr;
nulls_ = vector->isNullAt(0) ? &constantNullMask_ : nullptr;
} else {
makeIndicesMutable();

applyToRows(rows, [this](vector_size_t row) {
copiedIndices_[row] = constantIndex_;
});
setFlatNulls(vector, rows);
setFlatNulls(*vector, rows);
}
data_ = vector.valuesAsVoid();
data_ = vector->valuesAsVoid();
if (!nulls_) {
nulls_ = vector.isNullAt(0) ? &constantNullMask_ : nullptr;
nulls_ = vector->isNullAt(0) ? &constantNullMask_ : nullptr;
}
mayHaveNulls_ = hasExtraNulls_ || nulls_;
}
Expand Down Expand Up @@ -374,25 +448,23 @@ BufferPtr copyNullsBuffer(
} // namespace

DecodedVector::DictionaryWrapping DecodedVector::dictionaryWrapping(
const BaseVector& wrapper,
memory::MemoryPool& pool,
vector_size_t size) const {
VELOX_CHECK(!isIdentityMapping_);
VELOX_CHECK(!isConstantMapping_);
VELOX_CHECK_LE(size, size_);

// Make a copy of the indices and nulls buffers.
BufferPtr indices = copyIndicesBuffer(indices_, size, wrapper.pool());
BufferPtr indices = copyIndicesBuffer(this->indices(), size, &pool);
// Only copy nulls if we have nulls coming from one of the wrappers, don't
// do it if nulls are missing or from the base vector.
// TODO: remove the check for hasExtraNulls_ after #3553 is merged.
BufferPtr nulls =
hasExtraNulls_ ? copyNullsBuffer(nulls_, size, wrapper.pool()) : nullptr;
hasExtraNulls_ ? copyNullsBuffer(nulls_, size, &pool) : nullptr;
return {std::move(indices), std::move(nulls)};
}

VectorPtr DecodedVector::wrap(
VectorPtr data,
const BaseVector& wrapper,
memory::MemoryPool& pool,
vector_size_t size) {
if (isConstantMapping_) {
if (isNullAt(0)) {
Expand All @@ -406,7 +478,7 @@ VectorPtr DecodedVector::wrap(
return BaseVector::wrapInConstant(size, constantIndex_, data);
}

auto wrapping = dictionaryWrapping(wrapper, size);
auto wrapping = dictionaryWrapping(pool, size);
return BaseVector::wrapInDictionary(
std::move(wrapping.nulls),
std::move(wrapping.indices),
Expand Down
Loading

0 comments on commit 22ad3c0

Please sign in to comment.