From 3b4a6b122a29c1d40623ba22e0c231c5540d84b7 Mon Sep 17 00:00:00 2001
From: Francis <455954986@qq.com>
Date: Fri, 8 Sep 2023 00:10:24 +0800
Subject: [PATCH] GH-37170: [C++] Support schema rewriting of RecordBatch. (#37171)

### Rationale for this change

We have the following scenario. There is a plan in pg that looks like the one below. Under the Append node there are two scans running in parallel. Each scan produces one column of data, but the column names are different, so when they are mapped to an Arrow schema they become different fields.

The Append node therefore receives two batches: the first batch comes from the first scan and the second batch comes from the second scan. Because each column is constructed from its own scan, the schemas of the two batches are different. When we construct the slot returned by the Append node, we use the schema of the first batch. When we then put the data of the second batch into it, verification fails because of the inconsistent schema.

The problem can therefore be simplified to: for a node with n child nodes, the schemas of all child nodes must be consistent. If they are not, the schemas of the other n-1 child nodes must be made the same as the first child's schema, which requires logic to rewrite the schema of the batch data.

```
-> Vec Append
  -> Vec Seq Scan on public.tenk1
       Output: tenk1.unique1
  -> Vec Seq Scan on public.tenk1 tenk1_1
       Output: tenk1_1.fivethous
```

However, the record batch code only exposes the read-only interface schema(), so this PR adds an interface to rewrite the schema. Only columns whose types are the same can be rewritten; if the types are not the same, an Invalid status is returned.

Background: https://github.com/apache/arrow/issues/37170

### What changes are included in this PR?

- record_batch.h
- record_batch.cc
- record_batch_test.cc

### Are these changes tested?

Yes, see record_batch_test.cc. The gtest filter is:

```
TestRecordBatch.ReplaceSchema
```

### Are there any user-facing changes?

yes: see background in issue.

* Closes: #37170

Authored-by: light-city <455954986@qq.com>
Signed-off-by: David Li
---
 cpp/src/arrow/record_batch.cc      | 22 ++++++++++++++++
 cpp/src/arrow/record_batch.h       | 10 ++++++++
 cpp/src/arrow/record_batch_test.cc | 40 ++++++++++++++++++++++++++++++
 3 files changed, 72 insertions(+)

diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc
index 1c5c8912e5a0b..f0ee295c6347d 100644
--- a/cpp/src/arrow/record_batch.cc
+++ b/cpp/src/arrow/record_batch.cc
@@ -283,6 +283,28 @@ bool RecordBatch::ApproxEquals(const RecordBatch& other, const EqualOptions& opt
   return true;
 }
 
+Result<std::shared_ptr<RecordBatch>> RecordBatch::ReplaceSchema(
+    std::shared_ptr<Schema> schema) const {
+  // Only names and metadata may change: the number of fields must stay the same.
+  if (schema_->num_fields() != schema->num_fields()) {
+    return Status::Invalid("RecordBatch schema fields: ", schema_->num_fields(),
+                           ", did not match new schema fields: ", schema->num_fields());
+  }
+  const int n_fields = schema_->num_fields();
+  for (int i = 0; i < n_fields; i++) {
+    // Bind by reference so no shared_ptr is copied while comparing the types.
+    const auto& old_type = schema_->field(i)->type();
+    const auto& replace_type = schema->field(i)->type();
+    if (!old_type->Equals(replace_type)) {
+      return Status::Invalid(
+          "RecordBatch schema field index ", i, " type is ", old_type->ToString(),
+          ", did not match new schema field type: ", replace_type->ToString());
+    }
+  }
+  // Reuse this batch's row count and columns as-is; only the schema differs.
+  return RecordBatch::Make(std::move(schema), num_rows(), columns());
+}
+
 Result<std::shared_ptr<RecordBatch>> RecordBatch::SelectColumns(
     const std::vector<int>& indices) const {
   int n = static_cast<int>(indices.size());
diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h
index d728d5eb0da2f..cb1f6d54f7cff 100644
--- a/cpp/src/arrow/record_batch.h
+++ b/cpp/src/arrow/record_batch.h
@@ -114,6 +114,16 @@ class ARROW_EXPORT RecordBatch {
   /// \return the record batch's schema
   const std::shared_ptr<Schema>& schema() const { return schema_; }
 
+  /// \brief Replace the schema with another schema with the same types, but potentially
+  /// different field names and/or metadata.
+  ///
+  /// \param[in] schema the new schema; it must have the same number of fields as the
+  /// current schema, with an equal type at every field index
+  /// \return a new record batch with the given schema and this batch's columns, or
+  /// Status::Invalid if the number of fields or any field type differs
+  Result<std::shared_ptr<RecordBatch>> ReplaceSchema(
+      std::shared_ptr<Schema> schema) const;
+
   /// \brief Retrieve all columns at once
   virtual const std::vector<std::shared_ptr<Array>>& columns() const = 0;
 
diff --git a/cpp/src/arrow/record_batch_test.cc b/cpp/src/arrow/record_batch_test.cc
index e8180c6740879..bc923a1444160 100644
--- a/cpp/src/arrow/record_batch_test.cc
+++ b/cpp/src/arrow/record_batch_test.cc
@@ -521,4 +521,44 @@ TEST_F(TestRecordBatchReader, ToTable) {
   ASSERT_EQ(table->column(0)->chunks().size(), 0);
 }
 
+TEST_F(TestRecordBatch, ReplaceSchema) {
+  const int length = 10;
+
+  auto f0 = field("f0", int32());
+  auto f1 = field("f1", uint8());
+  auto f2 = field("f2", int16());
+  auto f3 = field("f3", int8());
+
+  auto schema = ::arrow::schema({f0, f1, f2});
+
+  random::RandomArrayGenerator gen(42);
+
+  auto a0 = gen.ArrayOf(int32(), length);
+  auto a1 = gen.ArrayOf(uint8(), length);
+  auto a2 = gen.ArrayOf(int16(), length);
+
+  auto b1 = RecordBatch::Make(schema, length, {a0, a1, a2});
+
+  // Same types, different names: the replacement succeeds
+  f0 = field("fd0", int32());
+  f1 = field("fd1", uint8());
+  f2 = field("fd2", int16());
+
+  schema = ::arrow::schema({f0, f1, f2});
+  ASSERT_OK_AND_ASSIGN(auto mutated, b1->ReplaceSchema(schema));
+  auto expected = RecordBatch::Make(schema, length, b1->columns());
+  ASSERT_TRUE(mutated->Equals(*expected));
+  // The result must carry the new schema and the source batch must be untouched
+  ASSERT_TRUE(mutated->schema()->Equals(*schema));
+  ASSERT_EQ(b1->schema()->field(0)->name(), "f0");
+
+  // A differing field type is rejected
+  schema = ::arrow::schema({f0, f1, f3});
+  ASSERT_RAISES(Invalid, b1->ReplaceSchema(schema));
+
+  // A differing number of fields is rejected
+  schema = ::arrow::schema({f0, f1});
+  ASSERT_RAISES(Invalid, b1->ReplaceSchema(schema));
+}
+
 }  // namespace arrow