Skip to content

Commit

Permalink
speed up filters construction (ydb-platform#7934)
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanmorozov333 authored Aug 17, 2024
1 parent 15fd269 commit c03e8b2
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 53 deletions.
94 changes: 47 additions & 47 deletions ydb/core/formats/arrow/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class TConstFunction : public IStepFunction<TAssign> {
using TBase = IStepFunction<TAssign>;
public:
using TBase::TBase;
arrow::Result<arrow::Datum> Call(const TAssign& assign, const TDatumBatch& batch) const override {
arrow::Result<arrow::Datum> Call(const TAssign& assign, const TDatumBatch& batch) const override {
Y_UNUSED(batch);
return assign.GetConstant();
}
Expand Down Expand Up @@ -531,7 +531,7 @@ class TFilterVisitor : public arrow::ArrayVisitor {


arrow::Status TDatumBatch::AddColumn(const std::string& name, arrow::Datum&& column) {
if (Schema->GetFieldIndex(name) != -1) {
if (HasColumn(name)) {
return arrow::Status::Invalid("Trying to add duplicate column '" + name + "'");
}

Expand All @@ -543,20 +543,27 @@ arrow::Status TDatumBatch::AddColumn(const std::string& name, arrow::Datum&& col
return arrow::Status::Invalid("Wrong column length.");
}

Schema = *Schema->AddField(Schema->num_fields(), field);
NewColumnIds.emplace(name, NewColumnsPtr.size());
NewColumnsPtr.emplace_back(field);

Datums.emplace_back(column);
return arrow::Status::OK();
}

arrow::Result<arrow::Datum> TDatumBatch::GetColumnByName(const std::string& name) const {
auto i = Schema->GetFieldIndex(name);
auto it = NewColumnIds.find(name);
if (it != NewColumnIds.end()) {
AFL_VERIFY(SchemaBase->num_fields() + it->second < Datums.size());
return Datums[SchemaBase->num_fields() + it->second];
}
auto i = SchemaBase->GetFieldIndex(name);
if (i < 0) {
return arrow::Status::Invalid("Not found column '" + name + "' or duplicate");
}
return Datums[i];
}

std::shared_ptr<arrow::Table> TDatumBatch::ToTable() const {
std::shared_ptr<arrow::Table> TDatumBatch::ToTable() {
std::vector<std::shared_ptr<arrow::ChunkedArray>> columns;
columns.reserve(Datums.size());
for (auto col : Datums) {
Expand All @@ -576,10 +583,10 @@ std::shared_ptr<arrow::Table> TDatumBatch::ToTable() const {
AFL_VERIFY(false);
}
}
return arrow::Table::Make(Schema, columns, Rows);
return arrow::Table::Make(GetSchema(), columns, Rows);
}

std::shared_ptr<arrow::RecordBatch> TDatumBatch::ToRecordBatch() const {
std::shared_ptr<arrow::RecordBatch> TDatumBatch::ToRecordBatch() {
std::vector<std::shared_ptr<arrow::Array>> columns;
columns.reserve(Datums.size());
for (auto col : Datums) {
Expand All @@ -594,7 +601,7 @@ std::shared_ptr<arrow::RecordBatch> TDatumBatch::ToRecordBatch() const {
AFL_VERIFY(false);
}
}
return arrow::RecordBatch::Make(Schema, Rows, columns);
return arrow::RecordBatch::Make(GetSchema(), Rows, columns);
}

std::shared_ptr<TDatumBatch> TDatumBatch::FromRecordBatch(const std::shared_ptr<arrow::RecordBatch>& batch) {
Expand All @@ -603,12 +610,7 @@ std::shared_ptr<TDatumBatch> TDatumBatch::FromRecordBatch(const std::shared_ptr<
for (int64_t i = 0; i < batch->num_columns(); ++i) {
datums.push_back(arrow::Datum(batch->column(i)));
}
return std::make_shared<TProgramStep::TDatumBatch>(
TProgramStep::TDatumBatch{
.Schema = std::make_shared<arrow::Schema>(*batch->schema()),
.Datums = std::move(datums),
.Rows = batch->num_rows()
});
return std::make_shared<TDatumBatch>(std::make_shared<arrow::Schema>(*batch->schema()), std::move(datums), batch->num_rows());
}

std::shared_ptr<TDatumBatch> TDatumBatch::FromTable(const std::shared_ptr<arrow::Table>& batch) {
Expand All @@ -617,12 +619,15 @@ std::shared_ptr<TDatumBatch> TDatumBatch::FromTable(const std::shared_ptr<arrow:
for (int64_t i = 0; i < batch->num_columns(); ++i) {
datums.push_back(arrow::Datum(batch->column(i)));
}
return std::make_shared<TProgramStep::TDatumBatch>(
TProgramStep::TDatumBatch{
.Schema = std::make_shared<arrow::Schema>(*batch->schema()),
.Datums = std::move(datums),
.Rows = batch->num_rows()
});
return std::make_shared<TDatumBatch>(std::make_shared<arrow::Schema>(*batch->schema()), std::move(datums), batch->num_rows());
}

TDatumBatch::TDatumBatch(const std::shared_ptr<arrow::Schema>& schema, std::vector<arrow::Datum>&& datums, const i64 rows)
: SchemaBase(schema)
, Rows(rows)
, Datums(std::move(datums)) {
AFL_VERIFY(SchemaBase);
AFL_VERIFY(Datums.size() == (ui32)SchemaBase->num_fields());
}

TAssign TAssign::MakeTimestamp(const TColumnInfo& column, ui64 value) {
Expand Down Expand Up @@ -680,7 +685,7 @@ arrow::Status TProgramStep::ApplyAssignes(TDatumBatch& batch, arrow::compute::Ex
}
batch.Datums.reserve(batch.Datums.size() + Assignes.size());
for (auto& assign : Assignes) {
if (batch.GetColumnByName(assign.GetName()).ok()) {
if (batch.HasColumn(assign.GetName())) {
return arrow::Status::Invalid("Assign to existing column '" + assign.GetName() + "'.");
}

Expand All @@ -703,8 +708,9 @@ arrow::Status TProgramStep::ApplyAggregates(TDatumBatch& batch, arrow::compute::
}

ui32 numResultColumns = GroupBy.size() + GroupByKeys.size();
TDatumBatch res;
res.Datums.reserve(numResultColumns);
std::vector<arrow::Datum> datums;
datums.reserve(numResultColumns);
std::optional<ui32> resultRecordsCount;

arrow::FieldVector fields;
fields.reserve(numResultColumns);
Expand All @@ -715,13 +721,13 @@ arrow::Status TProgramStep::ApplyAggregates(TDatumBatch& batch, arrow::compute::
if (!funcResult.ok()) {
return funcResult.status();
}
res.Datums.push_back(*funcResult);
fields.emplace_back(std::make_shared<arrow::Field>(assign.GetName(), res.Datums.back().type()));
datums.push_back(*funcResult);
fields.emplace_back(std::make_shared<arrow::Field>(assign.GetName(), datums.back().type()));
}
res.Rows = 1;
resultRecordsCount = 1;
} else {
CH::GroupByOptions funcOpts;
funcOpts.schema = batch.Schema;
funcOpts.schema = batch.GetSchema();
funcOpts.assigns.reserve(numResultColumns);
funcOpts.has_nullable_key = false;

Expand Down Expand Up @@ -759,19 +765,18 @@ arrow::Status TProgramStep::ApplyAggregates(TDatumBatch& batch, arrow::compute::
return arrow::Status::Invalid("No expected column in GROUP BY result.");
}
fields.emplace_back(std::make_shared<arrow::Field>(assign.result_column, column->type()));
res.Datums.push_back(column);
datums.push_back(column);
}

res.Rows = gbBatch->num_rows();
resultRecordsCount = gbBatch->num_rows();
}

res.Schema = std::make_shared<arrow::Schema>(std::move(fields));
batch = std::move(res);
AFL_VERIFY(resultRecordsCount);
batch = TDatumBatch(std::make_shared<arrow::Schema>(std::move(fields)), std::move(datums), *resultRecordsCount);
return arrow::Status::OK();
}

arrow::Status TProgramStep::MakeCombinedFilter(TDatumBatch& batch, NArrow::TColumnFilter& result) const {
TFilterVisitor filterVisitor(batch.Rows);
TFilterVisitor filterVisitor(batch.GetRecordsCount());
for (auto& colName : Filters) {
auto column = batch.GetColumnByName(colName.GetColumnName());
if (!column.ok()) {
Expand Down Expand Up @@ -821,13 +826,13 @@ arrow::Status TProgramStep::ApplyFilters(TDatumBatch& batch) const {
}
}
std::vector<arrow::Datum*> filterDatums;
for (int64_t i = 0; i < batch.Schema->num_fields(); ++i) {
if (batch.Datums[i].is_arraylike() && (allColumns || neededColumns.contains(batch.Schema->field(i)->name()))) {
for (int64_t i = 0; i < batch.GetSchema()->num_fields(); ++i) {
if (batch.Datums[i].is_arraylike() && (allColumns || neededColumns.contains(batch.GetSchema()->field(i)->name()))) {
filterDatums.emplace_back(&batch.Datums[i]);
}
}
bits.Apply(batch.Rows, filterDatums);
batch.Rows = bits.GetFilteredCount().value_or(batch.Rows);
bits.Apply(batch.GetRecordsCount(), filterDatums);
batch.SetRecordsCount(bits.GetFilteredCount().value_or(batch.GetRecordsCount()));
return arrow::Status::OK();
}

Expand All @@ -838,15 +843,14 @@ arrow::Status TProgramStep::ApplyProjection(TDatumBatch& batch) const {
std::vector<std::shared_ptr<arrow::Field>> newFields;
std::vector<arrow::Datum> newDatums;
for (size_t i = 0; i < Projection.size(); ++i) {
int schemaFieldIndex = batch.Schema->GetFieldIndex(Projection[i].GetColumnName());
int schemaFieldIndex = batch.GetSchema()->GetFieldIndex(Projection[i].GetColumnName());
if (schemaFieldIndex == -1) {
return arrow::Status::Invalid("Could not find column " + Projection[i].GetColumnName() + " in record batch schema.");
}
newFields.push_back(batch.Schema->field(schemaFieldIndex));
newFields.push_back(batch.GetSchema()->field(schemaFieldIndex));
newDatums.push_back(batch.Datums[schemaFieldIndex]);
}
batch.Schema = std::make_shared<arrow::Schema>(std::move(newFields));
batch.Datums = std::move(newDatums);
batch = TDatumBatch(std::make_shared<arrow::Schema>(std::move(newFields)), std::move(newDatums), batch.GetRecordsCount());
return arrow::Status::OK();
}

Expand Down Expand Up @@ -919,14 +923,10 @@ std::set<std::string> TProgramStep::GetColumnsInUsage(const bool originalOnly/*
}

arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> TProgramStep::BuildFilter(const std::shared_ptr<NArrow::TGeneralContainer>& t) const {
return BuildFilter(t->BuildTableVerified(GetColumnsInUsage(true)));
}

arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> TProgramStep::BuildFilter(const std::shared_ptr<arrow::Table>& t) const {
if (Filters.empty()) {
return nullptr;
}
std::vector<std::shared_ptr<arrow::RecordBatch>> batches = NArrow::SliceToRecordBatches(t);
std::vector<std::shared_ptr<arrow::RecordBatch>> batches = NArrow::SliceToRecordBatches(t->BuildTableVerified(GetColumnsInUsage(true)));
NArrow::TColumnFilter fullLocal = NArrow::TColumnFilter::BuildAllowFilter();
for (auto&& rb : batches) {
auto datumBatch = TDatumBatch::FromRecordBatch(rb);
Expand All @@ -938,7 +938,7 @@ arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> TProgramStep::BuildFilter(
}
NArrow::TColumnFilter local = NArrow::TColumnFilter::BuildAllowFilter();
NArrow::TStatusValidator::Validate(MakeCombinedFilter(*datumBatch, local));
AFL_VERIFY(local.Size() == datumBatch->Rows)("local", local.Size())("datum", datumBatch->Rows);
AFL_VERIFY(local.Size() == datumBatch->GetRecordsCount())("local", local.Size())("datum", datumBatch->GetRecordsCount());
fullLocal.Append(local);
}
AFL_VERIFY(fullLocal.Size() == t->num_rows())("filter", fullLocal.Size())("t", t->num_rows());
Expand Down
43 changes: 37 additions & 6 deletions ydb/core/formats/arrow/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,47 @@ const char * GetHouseFunctionName(EAggregate op);
inline const char * GetHouseGroupByName() { return "ch.group_by"; }
EOperation ValidateOperation(EOperation op, ui32 argsSize);

struct TDatumBatch {
std::shared_ptr<arrow::Schema> Schema;
std::vector<arrow::Datum> Datums;
class TDatumBatch {
private:
std::shared_ptr<arrow::Schema> SchemaBase;
THashMap<std::string, ui32> NewColumnIds;
std::vector<std::shared_ptr<arrow::Field>> NewColumnsPtr;
int64_t Rows = 0;

public:
std::vector<arrow::Datum> Datums;

ui64 GetRecordsCount() const {
return Rows;
}

void SetRecordsCount(const ui64 value) {
Rows = value;
}

TDatumBatch(const std::shared_ptr<arrow::Schema>& schema, std::vector<arrow::Datum>&& datums, const i64 rows);

const std::shared_ptr<arrow::Schema>& GetSchema() {
if (NewColumnIds.size()) {
std::vector<std::shared_ptr<arrow::Field>> fields = SchemaBase->fields();
fields.insert(fields.end(), NewColumnsPtr.begin(), NewColumnsPtr.end());
SchemaBase = std::make_shared<arrow::Schema>(fields);
NewColumnIds.clear();
NewColumnsPtr.clear();
}
return SchemaBase;
}

arrow::Status AddColumn(const std::string& name, arrow::Datum&& column);
arrow::Result<arrow::Datum> GetColumnByName(const std::string& name) const;
std::shared_ptr<arrow::Table> ToTable() const;
std::shared_ptr<arrow::RecordBatch> ToRecordBatch() const;
bool HasColumn(const std::string& name) const {
if (NewColumnIds.contains(name)) {
return true;
}
return SchemaBase->GetFieldIndex(name) > -1;
}
std::shared_ptr<arrow::Table> ToTable();
std::shared_ptr<arrow::RecordBatch> ToRecordBatch();
static std::shared_ptr<TDatumBatch> FromRecordBatch(const std::shared_ptr<arrow::RecordBatch>& batch);
static std::shared_ptr<TDatumBatch> FromTable(const std::shared_ptr<arrow::Table>& batch);
};
Expand Down Expand Up @@ -405,7 +437,6 @@ class TProgramStep {
return Filters.size() && (!GroupBy.size() && !GroupByKeys.size());
}

[[nodiscard]] arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> BuildFilter(const std::shared_ptr<arrow::Table>& t) const;
[[nodiscard]] arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> BuildFilter(const std::shared_ptr<NArrow::TGeneralContainer>& t) const;
};

Expand Down

0 comments on commit c03e8b2

Please sign in to comment.