diff --git a/ydb/core/formats/arrow/program.cpp b/ydb/core/formats/arrow/program.cpp index e07f76ed3b49..60e59749bb7a 100644 --- a/ydb/core/formats/arrow/program.cpp +++ b/ydb/core/formats/arrow/program.cpp @@ -88,7 +88,7 @@ class TConstFunction : public IStepFunction { using TBase = IStepFunction; public: using TBase::TBase; - arrow::Result Call(const TAssign& assign, const TDatumBatch& batch) const override { + arrow::Result Call(const TAssign& assign, const TDatumBatch& batch) const override { Y_UNUSED(batch); return assign.GetConstant(); } @@ -531,7 +531,7 @@ class TFilterVisitor : public arrow::ArrayVisitor { arrow::Status TDatumBatch::AddColumn(const std::string& name, arrow::Datum&& column) { - if (Schema->GetFieldIndex(name) != -1) { + if (HasColumn(name)) { return arrow::Status::Invalid("Trying to add duplicate column '" + name + "'"); } @@ -543,20 +543,27 @@ arrow::Status TDatumBatch::AddColumn(const std::string& name, arrow::Datum&& col return arrow::Status::Invalid("Wrong column length."); } - Schema = *Schema->AddField(Schema->num_fields(), field); + NewColumnIds.emplace(name, NewColumnsPtr.size()); + NewColumnsPtr.emplace_back(field); + Datums.emplace_back(column); return arrow::Status::OK(); } arrow::Result TDatumBatch::GetColumnByName(const std::string& name) const { - auto i = Schema->GetFieldIndex(name); + auto it = NewColumnIds.find(name); + if (it != NewColumnIds.end()) { + AFL_VERIFY(SchemaBase->num_fields() + it->second < Datums.size()); + return Datums[SchemaBase->num_fields() + it->second]; + } + auto i = SchemaBase->GetFieldIndex(name); if (i < 0) { return arrow::Status::Invalid("Not found column '" + name + "' or duplicate"); } return Datums[i]; } -std::shared_ptr TDatumBatch::ToTable() const { +std::shared_ptr TDatumBatch::ToTable() { std::vector> columns; columns.reserve(Datums.size()); for (auto col : Datums) { @@ -576,10 +583,10 @@ std::shared_ptr TDatumBatch::ToTable() const { AFL_VERIFY(false); } } - return arrow::Table::Make(Schema, columns, Rows); + return arrow::Table::Make(GetSchema(), columns, Rows); } -std::shared_ptr TDatumBatch::ToRecordBatch() const { +std::shared_ptr TDatumBatch::ToRecordBatch() { std::vector> columns; columns.reserve(Datums.size()); for (auto col : Datums) { @@ -594,7 +601,7 @@ std::shared_ptr TDatumBatch::ToRecordBatch() const { AFL_VERIFY(false); } } - return arrow::RecordBatch::Make(Schema, Rows, columns); + return arrow::RecordBatch::Make(GetSchema(), Rows, columns); } std::shared_ptr TDatumBatch::FromRecordBatch(const std::shared_ptr& batch) { @@ -603,12 +610,7 @@ std::shared_ptr TDatumBatch::FromRecordBatch(const std::shared_ptr< for (int64_t i = 0; i < batch->num_columns(); ++i) { datums.push_back(arrow::Datum(batch->column(i))); } - return std::make_shared( - TProgramStep::TDatumBatch{ - .Schema = std::make_shared(*batch->schema()), - .Datums = std::move(datums), - .Rows = batch->num_rows() - }); + return std::make_shared(std::make_shared(*batch->schema()), std::move(datums), batch->num_rows()); } std::shared_ptr TDatumBatch::FromTable(const std::shared_ptr& batch) { @@ -617,12 +619,15 @@ std::shared_ptr TDatumBatch::FromTable(const std::shared_ptrnum_columns(); ++i) { datums.push_back(arrow::Datum(batch->column(i))); } - return std::make_shared( - TProgramStep::TDatumBatch{ - .Schema = std::make_shared(*batch->schema()), - .Datums = std::move(datums), - .Rows = batch->num_rows() - }); + return std::make_shared(std::make_shared(*batch->schema()), std::move(datums), batch->num_rows()); +} + +TDatumBatch::TDatumBatch(const std::shared_ptr& schema, std::vector&& datums, const i64 rows) + : SchemaBase(schema) + , Rows(rows) + , Datums(std::move(datums)) { + AFL_VERIFY(SchemaBase); + AFL_VERIFY(Datums.size() == (ui32)SchemaBase->num_fields()); } TAssign TAssign::MakeTimestamp(const TColumnInfo& column, ui64 value) { @@ -680,7 +685,7 @@ arrow::Status TProgramStep::ApplyAssignes(TDatumBatch& batch, arrow::compute::Ex } batch.Datums.reserve(batch.Datums.size() + Assignes.size()); for (auto& assign : Assignes) { - if (batch.GetColumnByName(assign.GetName()).ok()) { + if (batch.HasColumn(assign.GetName())) { return arrow::Status::Invalid("Assign to existing column '" + assign.GetName() + "'."); } @@ -703,8 +708,9 @@ arrow::Status TProgramStep::ApplyAggregates(TDatumBatch& batch, arrow::compute:: } ui32 numResultColumns = GroupBy.size() + GroupByKeys.size(); - TDatumBatch res; - res.Datums.reserve(numResultColumns); + std::vector datums; + datums.reserve(numResultColumns); + std::optional resultRecordsCount; arrow::FieldVector fields; fields.reserve(numResultColumns); @@ -715,13 +721,13 @@ arrow::Status TProgramStep::ApplyAggregates(TDatumBatch& batch, arrow::compute:: if (!funcResult.ok()) { return funcResult.status(); } - res.Datums.push_back(*funcResult); - fields.emplace_back(std::make_shared(assign.GetName(), res.Datums.back().type())); + datums.push_back(*funcResult); + fields.emplace_back(std::make_shared(assign.GetName(), datums.back().type())); } - res.Rows = 1; + resultRecordsCount = 1; } else { CH::GroupByOptions funcOpts; - funcOpts.schema = batch.Schema; + funcOpts.schema = batch.GetSchema(); funcOpts.assigns.reserve(numResultColumns); funcOpts.has_nullable_key = false; @@ -759,19 +765,18 @@ arrow::Status TProgramStep::ApplyAggregates(TDatumBatch& batch, arrow::compute:: return arrow::Status::Invalid("No expected column in GROUP BY result."); } fields.emplace_back(std::make_shared(assign.result_column, column->type())); - res.Datums.push_back(column); + datums.push_back(column); } - res.Rows = gbBatch->num_rows(); + resultRecordsCount = gbBatch->num_rows(); } - - res.Schema = std::make_shared(std::move(fields)); - batch = std::move(res); + AFL_VERIFY(resultRecordsCount); + batch = TDatumBatch(std::make_shared(std::move(fields)), std::move(datums), *resultRecordsCount); return arrow::Status::OK(); } arrow::Status TProgramStep::MakeCombinedFilter(TDatumBatch& batch, NArrow::TColumnFilter& result) const { - TFilterVisitor filterVisitor(batch.Rows); + TFilterVisitor filterVisitor(batch.GetRecordsCount()); for (auto& colName : Filters) { auto column = batch.GetColumnByName(colName.GetColumnName()); if (!column.ok()) { @@ -821,13 +826,13 @@ arrow::Status TProgramStep::ApplyFilters(TDatumBatch& batch) const { } } std::vector filterDatums; - for (int64_t i = 0; i < batch.Schema->num_fields(); ++i) { - if (batch.Datums[i].is_arraylike() && (allColumns || neededColumns.contains(batch.Schema->field(i)->name()))) { + for (int64_t i = 0; i < batch.GetSchema()->num_fields(); ++i) { + if (batch.Datums[i].is_arraylike() && (allColumns || neededColumns.contains(batch.GetSchema()->field(i)->name()))) { filterDatums.emplace_back(&batch.Datums[i]); } } - bits.Apply(batch.Rows, filterDatums); - batch.Rows = bits.GetFilteredCount().value_or(batch.Rows); + bits.Apply(batch.GetRecordsCount(), filterDatums); + batch.SetRecordsCount(bits.GetFilteredCount().value_or(batch.GetRecordsCount())); return arrow::Status::OK(); } @@ -838,15 +843,14 @@ arrow::Status TProgramStep::ApplyProjection(TDatumBatch& batch) const { std::vector> newFields; std::vector newDatums; for (size_t i = 0; i < Projection.size(); ++i) { - int schemaFieldIndex = batch.Schema->GetFieldIndex(Projection[i].GetColumnName()); + int schemaFieldIndex = batch.GetSchema()->GetFieldIndex(Projection[i].GetColumnName()); if (schemaFieldIndex == -1) { return arrow::Status::Invalid("Could not find column " + Projection[i].GetColumnName() + " in record batch schema."); } - newFields.push_back(batch.Schema->field(schemaFieldIndex)); + newFields.push_back(batch.GetSchema()->field(schemaFieldIndex)); newDatums.push_back(batch.Datums[schemaFieldIndex]); } - batch.Schema = std::make_shared(std::move(newFields)); - batch.Datums = std::move(newDatums); + batch = TDatumBatch(std::make_shared(std::move(newFields)), std::move(newDatums), batch.GetRecordsCount()); return arrow::Status::OK(); } @@ -919,14 +923,10 @@ std::set TProgramStep::GetColumnsInUsage(const bool originalOnly/* } arrow::Result> TProgramStep::BuildFilter(const std::shared_ptr& t) const { - return BuildFilter(t->BuildTableVerified(GetColumnsInUsage(true))); -} - -arrow::Result> TProgramStep::BuildFilter(const std::shared_ptr& t) const { if (Filters.empty()) { return nullptr; } - std::vector> batches = NArrow::SliceToRecordBatches(t); + std::vector> batches = NArrow::SliceToRecordBatches(t->BuildTableVerified(GetColumnsInUsage(true))); NArrow::TColumnFilter fullLocal = NArrow::TColumnFilter::BuildAllowFilter(); for (auto&& rb : batches) { auto datumBatch = TDatumBatch::FromRecordBatch(rb); @@ -938,7 +938,7 @@ arrow::Result> TProgramStep::BuildFilter( } NArrow::TColumnFilter local = NArrow::TColumnFilter::BuildAllowFilter(); NArrow::TStatusValidator::Validate(MakeCombinedFilter(*datumBatch, local)); - AFL_VERIFY(local.Size() == datumBatch->Rows)("local", local.Size())("datum", datumBatch->Rows); + AFL_VERIFY(local.Size() == datumBatch->GetRecordsCount())("local", local.Size())("datum", datumBatch->GetRecordsCount()); fullLocal.Append(local); } AFL_VERIFY(fullLocal.Size() == t->num_rows())("filter", fullLocal.Size())("t", t->num_rows()); diff --git a/ydb/core/formats/arrow/program.h b/ydb/core/formats/arrow/program.h index dfb22116158b..e3f9943e6c13 100644 --- a/ydb/core/formats/arrow/program.h +++ b/ydb/core/formats/arrow/program.h @@ -37,15 +37,47 @@ const char * GetHouseFunctionName(EAggregate op); inline const char * GetHouseGroupByName() { return "ch.group_by"; } EOperation ValidateOperation(EOperation op, ui32 argsSize); -struct TDatumBatch { - std::shared_ptr Schema; - std::vector Datums; +class TDatumBatch { +private: + std::shared_ptr SchemaBase; + THashMap NewColumnIds; + std::vector> NewColumnsPtr; int64_t Rows = 0; +public: + std::vector Datums; + + ui64 GetRecordsCount() const { + return Rows; + } + + void SetRecordsCount(const ui64 value) { + Rows = value; + } + + TDatumBatch(const std::shared_ptr& schema, std::vector&& datums, const i64 rows); + + const std::shared_ptr& GetSchema() { + if (NewColumnIds.size()) { + std::vector> fields = SchemaBase->fields(); + fields.insert(fields.end(), NewColumnsPtr.begin(), NewColumnsPtr.end()); + SchemaBase = std::make_shared(fields); + NewColumnIds.clear(); + NewColumnsPtr.clear(); + } + return SchemaBase; + } + arrow::Status AddColumn(const std::string& name, arrow::Datum&& column); arrow::Result GetColumnByName(const std::string& name) const; - std::shared_ptr ToTable() const; - std::shared_ptr ToRecordBatch() const; + bool HasColumn(const std::string& name) const { + if (NewColumnIds.contains(name)) { + return true; + } + return SchemaBase->GetFieldIndex(name) > -1; + } + std::shared_ptr ToTable(); + std::shared_ptr ToRecordBatch(); static std::shared_ptr FromRecordBatch(const std::shared_ptr& batch); static std::shared_ptr FromTable(const std::shared_ptr& batch); }; @@ -405,7 +437,6 @@ class TProgramStep { return Filters.size() && (!GroupBy.size() && !GroupByKeys.size()); } - [[nodiscard]] arrow::Result> BuildFilter(const std::shared_ptr& t) const; [[nodiscard]] arrow::Result> BuildFilter(const std::shared_ptr& t) const; };