diff --git a/dbms/src/Debug/MockComputeServerManager.cpp b/dbms/src/Debug/MockComputeServerManager.cpp new file mode 100644 index 00000000000..6b67eb76b9b --- /dev/null +++ b/dbms/src/Debug/MockComputeServerManager.cpp @@ -0,0 +1,101 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ +extern const int IP_ADDRESS_NOT_ALLOWED; +} // namespace ErrorCodes +namespace tests +{ +void MockComputeServerManager::addServer(String addr) +{ + MockServerConfig config; + for (const auto & server : server_config_map) + { + RUNTIME_CHECK_MSG( + server.second.addr != addr, + "Already register mock compute server with addr = {}", + addr); + } + config.partition_id = server_config_map.size(); + config.addr = addr; + server_config_map[config.partition_id] = config; +} + +void MockComputeServerManager::startServers(const LoggerPtr & log_ptr, Context & global_context) +{ + global_context.setMPPTest(); + for (const auto & server_config : server_config_map) + { + TiFlashSecurityConfig security_config; + TiFlashRaftConfig raft_config; + raft_config.flash_server_addr = server_config.second.addr; + Poco::AutoPtr config = new Poco::Util::LayeredConfiguration; + addServer(server_config.first, std::make_unique(global_context, *config, security_config, raft_config, log_ptr)); + } + + prepareMockMPPServerInfo(); +} + +void MockComputeServerManager::setMockStorage(MockStorage & mock_storage) +{ + for (const auto & server : server_map) + { + server.second->setMockStorage(mock_storage); + } +} + +void MockComputeServerManager::reset() +{ + server_map.clear(); + server_config_map.clear(); +} + +MockMPPServerInfo MockComputeServerManager::getMockMPPServerInfo(size_t partition_id) +{ + return {server_config_map[partition_id].partition_id, server_config_map.size()}; +} + +std::unordered_map & MockComputeServerManager::getServerConfigMap() +{ + return server_config_map; +} + +void MockComputeServerManager::prepareMockMPPServerInfo() +{ + for (const auto & server : server_map) + { + server.second->setMockMPPServerInfo(getMockMPPServerInfo(server.first)); + } +} + +void MockComputeServerManager::resetMockMPPServerInfo(size_t partition_num) +{ + size_t i = 0; + for (const auto & server : server_map) + { + server.second->setMockMPPServerInfo({i++, partition_num}); + } +} + +void MockComputeServerManager::addServer(size_t partition_id, std::unique_ptr server) +{ + server_map[partition_id] = std::move(server); +} +} // namespace tests +} // namespace DB \ No newline at end of file diff --git a/dbms/src/Debug/MockComputeServerManager.h b/dbms/src/Debug/MockComputeServerManager.h new file mode 100644 index 00000000000..ea88934fa31 --- /dev/null +++ b/dbms/src/Debug/MockComputeServerManager.h @@ -0,0 +1,57 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include + +namespace DB::tests +{ + +/** Hold Mock Compute Server to manage the lifetime of them. + * Maintains Mock Compute Server info. + */ +class MockComputeServerManager : public ext::Singleton +{ +public: + /// register an server to run. + void addServer(String addr); + + /// call startServers to run all servers in current test. + void startServers(const LoggerPtr & log_ptr, Context & global_context); + + /// set MockStorage for Compute Server in order to mock input columns. + void setMockStorage(MockStorage & mock_storage); + + /// stop all servers. + void reset(); + + MockMPPServerInfo getMockMPPServerInfo(size_t partition_id); + + std::unordered_map & getServerConfigMap(); + + void resetMockMPPServerInfo(size_t partition_num); + +private: + void addServer(size_t partition_id, std::unique_ptr server); + void prepareMockMPPServerInfo(); + +private: + std::unordered_map> server_map; + std::unordered_map server_config_map; +}; +} // namespace DB::tests \ No newline at end of file diff --git a/dbms/src/Debug/MockServerInfo.h b/dbms/src/Debug/MockServerInfo.h new file mode 100644 index 00000000000..945ef6837c7 --- /dev/null +++ b/dbms/src/Debug/MockServerInfo.h @@ -0,0 +1,31 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +namespace DB::tests +{ +struct MockServerConfig +{ + String addr; + size_t partition_id; +}; + +struct MockMPPServerInfo +{ + size_t partition_id; + size_t partition_num; +}; +} // namespace DB::tests \ No newline at end of file diff --git a/dbms/src/Debug/MockStorage.cpp b/dbms/src/Debug/MockStorage.cpp new file mode 100644 index 00000000000..834a494311c --- /dev/null +++ b/dbms/src/Debug/MockStorage.cpp @@ -0,0 +1,153 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +namespace DB::tests +{ +void MockStorage::addTableSchema(const String & name, const MockColumnInfoVec & columnInfos) +{ + name_to_id_map[name] = MockTableIdGenerator::instance().nextTableId(); + table_schema[getTableId(name)] = columnInfos; +} + +void MockStorage::addTableData(const String & name, const ColumnsWithTypeAndName & columns) +{ + table_columns[getTableId(name)] = columns; +} + +Int64 MockStorage::getTableId(const String & name) +{ + if (name_to_id_map.find(name) != name_to_id_map.end()) + { + return name_to_id_map[name]; + } + throw Exception(fmt::format("Failed to get table id by table name '{}'", name)); +} + +bool MockStorage::tableExists(Int64 table_id) +{ + return table_schema.find(table_id) != table_schema.end(); +} + +ColumnsWithTypeAndName MockStorage::getColumns(Int64 table_id) +{ + if (tableExists(table_id)) + { + return table_columns[table_id]; + } + throw Exception(fmt::format("Failed to get columns by table_id '{}'", table_id)); +} + +MockColumnInfoVec MockStorage::getTableSchema(const String & name) +{ + if (tableExists(getTableId(name))) + { + return table_schema[getTableId(name)]; + } + throw Exception(fmt::format("Failed to get table schema by table name '{}'", name)); +} + +/// for exchange receiver +void MockStorage::addExchangeSchema(const String & exchange_name, const MockColumnInfoVec & columnInfos) +{ + exchange_schemas[exchange_name] = columnInfos; +} + +void MockStorage::addExchangeData(const String & exchange_name, const ColumnsWithTypeAndName & columns) +{ + exchange_columns[exchange_name] = columns; +} + +bool MockStorage::exchangeExists(const String & executor_id) +{ + return exchange_schemas.find(executor_id_to_name_map[executor_id]) != exchange_schemas.end(); +} + +bool MockStorage::exchangeExistsWithName(const String & name) +{ + return exchange_schemas.find(name) != exchange_schemas.end(); +} + +ColumnsWithTypeAndName MockStorage::getExchangeColumns(const String & executor_id) +{ + if (exchangeExists(executor_id)) + { + return exchange_columns[executor_id_to_name_map[executor_id]]; + } + throw Exception(fmt::format("Failed to get exchange columns by executor_id '{}'", executor_id)); +} + +void MockStorage::addExchangeRelation(const String & executor_id, const String & exchange_name) +{ + executor_id_to_name_map[executor_id] = exchange_name; +} + +MockColumnInfoVec MockStorage::getExchangeSchema(const String & exchange_name) +{ + if (exchangeExistsWithName(exchange_name)) + { + return exchange_schemas[exchange_name]; + } + throw Exception(fmt::format("Failed to get exchange schema by exchange name '{}'", exchange_name)); +} + +// use this function to determine where to cut the columns, +// and how many rows are needed for each partition of MPP task. +CutColumnInfo getCutColumnInfo(size_t rows, Int64 partition_id, Int64 partition_num) +{ + int start, per_rows, rows_left, cur_rows; + per_rows = rows / partition_num; + rows_left = rows - per_rows * partition_num; + if (partition_id >= rows_left) + { + start = (per_rows + 1) * rows_left + (partition_id - rows_left) * per_rows; + cur_rows = per_rows; + } + else + { + start = (per_rows + 1) * partition_id; + cur_rows = per_rows + 1; + } + return {start, cur_rows}; +} + +ColumnsWithTypeAndName MockStorage::getColumnsForMPPTableScan(Int64 table_id, Int64 partition_id, Int64 partition_num) +{ + if (tableExists(table_id)) + { + auto columns_with_type_and_name = table_columns[table_id]; + size_t rows = 0; + for (const auto & col : columns_with_type_and_name) + { + if (rows == 0) + rows = col.column->size(); + assert(rows == col.column->size()); + } + + CutColumnInfo cut_info = getCutColumnInfo(rows, partition_id, partition_num); + + ColumnsWithTypeAndName res; + for (const auto & column_with_type_and_name : columns_with_type_and_name) + { + res.push_back( + ColumnWithTypeAndName( + column_with_type_and_name.column->cut(cut_info.first, cut_info.second), + column_with_type_and_name.type, + column_with_type_and_name.name)); + } + return res; + } + throw Exception(fmt::format("Failed to get table columns by table_id '{}'", table_id)); +} +} // namespace DB::tests diff --git a/dbms/src/Debug/MockStorage.h b/dbms/src/Debug/MockStorage.h new file mode 100644 index 00000000000..5d92c9bf3ec --- /dev/null +++ b/dbms/src/Debug/MockStorage.h @@ -0,0 +1,87 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +#include +#include + +#include +#include +namespace DB::tests +{ +using MockColumnInfo = std::pair; +using MockColumnInfoVec = std::vector; +using CutColumnInfo = std::pair; // + +class MockTableIdGenerator : public ext::Singleton +{ +public: + Int64 nextTableId() + { + return ++current_id; + } + +private: + std::atomic current_id = 0; +}; + +/** Responsible for mock data for executor tests and mpp tests. + * 1. Use this class to add mock table schema and table column data. + * 2. Use this class to add mock exchange schema and exchange column data. + */ +class MockStorage +{ +public: + /// for table scan + void addTableSchema(const String & name, const MockColumnInfoVec & columnInfos); + + void addTableData(const String & name, const ColumnsWithTypeAndName & columns); + + Int64 getTableId(const String & name); + + bool tableExists(Int64 table_id); + + ColumnsWithTypeAndName getColumns(Int64 table_id); + + MockColumnInfoVec getTableSchema(const String & name); + + /// for exchange receiver + void addExchangeSchema(const String & exchange_name, const MockColumnInfoVec & columnInfos); + + void addExchangeData(const String & exchange_name, const ColumnsWithTypeAndName & columns); + + bool exchangeExists(const String & executor_id); + bool exchangeExistsWithName(const String & name); + + ColumnsWithTypeAndName getExchangeColumns(const String & executor_id); + + void addExchangeRelation(const String & executor_id, const String & exchange_name); + + MockColumnInfoVec getExchangeSchema(const String & exchange_name); + + /// for MPP Tasks, it will split data by partition num, then each MPP service will have a subset of mock data. + ColumnsWithTypeAndName getColumnsForMPPTableScan(Int64 table_id, Int64 partition_id, Int64 partition_num); + +private: + /// for mock table scan + std::unordered_map name_to_id_map; /// + std::unordered_map table_schema; /// + std::unordered_map table_columns; /// + + /// for mock exchange receiver + std::unordered_map executor_id_to_name_map; /// + std::unordered_map exchange_schemas; /// + std::unordered_map exchange_columns; /// +}; +} // namespace DB::tests diff --git a/dbms/src/Debug/astToExecutor.cpp b/dbms/src/Debug/astToExecutor.cpp index e1d22b328ac..f3c8348c376 100644 --- a/dbms/src/Debug/astToExecutor.cpp +++ b/dbms/src/Debug/astToExecutor.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -31,6 +32,7 @@ namespace DB { using ASTPartitionByElement = ASTOrderByElement; +using MockComputeServerManager = tests::MockComputeServerManager; void literalFieldToTiPBExpr(const ColumnInfo & ci, const Field & val_field, tipb::Expr * expr, Int32 collator_id) { *(expr->mutable_field_type()) = columnInfoToFieldType(ci); @@ -769,6 +771,7 @@ void compileFilter(const DAGSchema & input, ASTPtr ast, std::vector & co namespace Debug { String LOCAL_HOST = "127.0.0.1:3930"; + void setServiceAddr(const std::string & addr) { LOCAL_HOST = addr; @@ -820,13 +823,16 @@ bool ExchangeSender::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t coll tipb_type.set_collate(collator_id); *exchange_sender->add_types() = tipb_type; } + for (auto task_id : mpp_info.sender_target_task_ids) { mpp::TaskMeta meta; meta.set_start_ts(mpp_info.start_ts); meta.set_task_id(task_id); meta.set_partition_id(mpp_info.partition_id); - meta.set_address(Debug::LOCAL_HOST); + auto addr = context.isMPPTest() ? MockComputeServerManager::instance().getServerConfigMap()[mpp_info.partition_id].addr : Debug::LOCAL_HOST; + meta.set_address(addr); + auto * meta_string = exchange_sender->add_encoded_task_meta(); meta.AppendToString(meta_string); } @@ -843,7 +849,7 @@ bool ExchangeSender::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t coll return children[0]->toTiPBExecutor(child_executor, collator_id, mpp_info, context); } -bool ExchangeReceiver::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t collator_id, const MPPInfo & mpp_info, const Context &) +bool ExchangeReceiver::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t collator_id, const MPPInfo & mpp_info, const Context & context) { tipb_executor->set_tp(tipb::ExecType::TypeExchangeReceiver); tipb_executor->set_executor_id(name); @@ -860,13 +866,16 @@ bool ExchangeReceiver::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t co auto it = mpp_info.receiver_source_task_ids_map.find(name); if (it == mpp_info.receiver_source_task_ids_map.end()) throw Exception("Can not found mpp receiver info"); - for (size_t i = 0; i < it->second.size(); i++) + + auto size = it->second.size(); + for (size_t i = 0; i < size; ++i) { mpp::TaskMeta meta; meta.set_start_ts(mpp_info.start_ts); meta.set_task_id(it->second[i]); meta.set_partition_id(i); - meta.set_address(Debug::LOCAL_HOST); + auto addr = context.isMPPTest() ? MockComputeServerManager::instance().getServerConfigMap()[mpp_info.partition_id].addr : Debug::LOCAL_HOST; + meta.set_address(addr); auto * meta_string = exchange_receiver->add_encoded_task_meta(); meta.AppendToString(meta_string); } @@ -875,8 +884,19 @@ bool ExchangeReceiver::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t co void TableScan::columnPrune(std::unordered_set & used_columns) { - output_schema.erase(std::remove_if(output_schema.begin(), output_schema.end(), [&](const auto & field) { return used_columns.count(field.first) == 0; }), - output_schema.end()); + DAGSchema new_schema; + for (const auto & col : output_schema) + { + for (const auto & used_col : used_columns) + { + if (splitQualifiedName(used_col).column_name == splitQualifiedName(col.first).column_name && splitQualifiedName(used_col).table_name == splitQualifiedName(col.first).table_name) + { + new_schema.push_back({used_col, col.second}); + } + } + } + + output_schema = new_schema; } bool TableScan::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t, const MPPInfo &, const Context &) @@ -1181,19 +1201,28 @@ void Join::columnPrune(std::unordered_set & used_columns) std::unordered_set right_columns; for (auto & field : children[0]->output_schema) - left_columns.emplace(field.first); - for (auto & field : children[1]->output_schema) - right_columns.emplace(field.first); + { + auto [db_name, table_name, column_name] = splitQualifiedName(field.first); + left_columns.emplace(table_name + "." + column_name); + } + for (auto & field : children[1]->output_schema) + { + auto [db_name, table_name, column_name] = splitQualifiedName(field.first); + right_columns.emplace(table_name + "." + column_name); + } std::unordered_set left_used_columns; std::unordered_set right_used_columns; for (const auto & s : used_columns) { - if (left_columns.find(s) != left_columns.end()) - left_used_columns.emplace(s); - else - right_used_columns.emplace(s); + auto [db_name, table_name, col_name] = splitQualifiedName(s); + auto t = table_name + "." + col_name; + if (left_columns.find(t) != left_columns.end()) + left_used_columns.emplace(t); + + if (right_columns.find(t) != right_columns.end()) + right_used_columns.emplace(t); } for (const auto & child : join_cols) @@ -1203,17 +1232,19 @@ void Join::columnPrune(std::unordered_set & used_columns) auto col_name = identifier->getColumnName(); for (auto & field : children[0]->output_schema) { - if (col_name == splitQualifiedName(field.first).column_name) + auto [db_name, table_name, column_name] = splitQualifiedName(field.first); + if (col_name == column_name) { - left_used_columns.emplace(field.first); + left_used_columns.emplace(table_name + "." + column_name); break; } } for (auto & field : children[1]->output_schema) { - if (col_name == splitQualifiedName(field.first).column_name) + auto [db_name, table_name, column_name] = splitQualifiedName(field.first); + if (col_name == column_name) { - right_used_columns.emplace(field.first); + right_used_columns.emplace(table_name + "." + column_name); break; } } @@ -1229,6 +1260,7 @@ void Join::columnPrune(std::unordered_set & used_columns) /// update output schema output_schema.clear(); + for (auto & field : children[0]->output_schema) { if (tp == tipb::TypeRightOuterJoin && field.second.hasNotNullFlag()) @@ -1758,6 +1790,7 @@ ExecutorPtr compileJoin(size_t & executor_index, ExecutorPtr left, ExecutorPtr r return compileJoin(executor_index, left, right, tp, join_cols); } + ExecutorPtr compileExchangeSender(ExecutorPtr input, size_t & executor_index, tipb::ExchangeType exchange_type) { ExecutorPtr exchange_sender = std::make_shared(executor_index, input->output_schema, exchange_type); diff --git a/dbms/src/Debug/astToExecutor.h b/dbms/src/Debug/astToExecutor.h index b97577f1e55..caa9116f2e1 100644 --- a/dbms/src/Debug/astToExecutor.h +++ b/dbms/src/Debug/astToExecutor.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include diff --git a/dbms/src/Debug/dbgFuncCoprocessor.cpp b/dbms/src/Debug/dbgFuncCoprocessor.cpp index fd2f9606015..7be2b808748 100644 --- a/dbms/src/Debug/dbgFuncCoprocessor.cpp +++ b/dbms/src/Debug/dbgFuncCoprocessor.cpp @@ -184,27 +184,76 @@ void setTipbRegionInfo(coprocessor::RegionInfo * tipb_region_info, const std::pa range->set_end(RecordKVFormat::genRawKey(table_id, handle_range.second.handle_id)); } +BlockInputStreamPtr prepareRootExchangeReceiver(Context & context, const DAGProperties & properties, std::vector & root_task_ids, DAGSchema & root_task_schema, bool enable_local_tunnel) +{ + tipb::ExchangeReceiver tipb_exchange_receiver; + for (const auto root_task_id : root_task_ids) + { + mpp::TaskMeta tm; + tm.set_start_ts(properties.start_ts); + tm.set_address(Debug::LOCAL_HOST); + tm.set_task_id(root_task_id); + tm.set_partition_id(-1); + auto * tm_string = tipb_exchange_receiver.add_encoded_task_meta(); + tm.AppendToString(tm_string); + } + for (auto & field : root_task_schema) + { + auto tipb_type = TiDB::columnInfoToFieldType(field.second); + tipb_type.set_collate(properties.collator); + auto * field_type = tipb_exchange_receiver.add_field_types(); + *field_type = tipb_type; + } + mpp::TaskMeta root_tm; + root_tm.set_start_ts(properties.start_ts); + root_tm.set_address(Debug::LOCAL_HOST); + root_tm.set_task_id(-1); + root_tm.set_partition_id(-1); + std::shared_ptr exchange_receiver + = std::make_shared( + std::make_shared( + tipb_exchange_receiver, + root_tm, + context.getTMTContext().getKVCluster(), + context.getTMTContext().getMPPTaskManager(), + enable_local_tunnel, + context.getSettingsRef().enable_async_grpc_client), + tipb_exchange_receiver.encoded_task_meta_size(), + 10, + /*req_id=*/"", + /*executor_id=*/"", + /*fine_grained_shuffle_stream_count=*/0); + BlockInputStreamPtr ret = std::make_shared(exchange_receiver, /*req_id=*/"", /*executor_id=*/"", /*stream_id*/ 0); + return ret; +} + +void prepareDispatchTaskRequest(QueryTask & task, std::shared_ptr req, const DAGProperties & properties, std::vector & root_task_ids, DAGSchema & root_task_schema, String & addr) +{ + if (task.is_root_task) + { + root_task_ids.push_back(task.task_id); + root_task_schema = task.result_schema; + } + auto * tm = req->mutable_meta(); + tm->set_start_ts(properties.start_ts); + tm->set_partition_id(task.partition_id); + tm->set_address(addr); + tm->set_task_id(task.task_id); + auto * encoded_plan = req->mutable_encoded_plan(); + task.dag_request->AppendToString(encoded_plan); + req->set_timeout(properties.mpp_timeout); + req->set_schema_ver(DEFAULT_UNSPECIFIED_SCHEMA_VERSION); +} + +// execute MPP Query in one service BlockInputStreamPtr executeMPPQuery(Context & context, const DAGProperties & properties, QueryTasks & query_tasks) { DAGSchema root_task_schema; std::vector root_task_ids; for (auto & task : query_tasks) { - if (task.is_root_task) - { - root_task_ids.push_back(task.task_id); - root_task_schema = task.result_schema; - } auto req = std::make_shared(); - auto * tm = req->mutable_meta(); - tm->set_start_ts(properties.start_ts); - tm->set_partition_id(task.partition_id); - tm->set_address(Debug::LOCAL_HOST); - tm->set_task_id(task.task_id); - auto * encoded_plan = req->mutable_encoded_plan(); - task.dag_request->AppendToString(encoded_plan); - req->set_timeout(properties.mpp_timeout); - req->set_schema_ver(DEFAULT_UNSPECIFIED_SCHEMA_VERSION); + prepareDispatchTaskRequest(task, req, properties, root_task_ids, root_task_schema, Debug::LOCAL_HOST); auto table_id = task.table_id; if (table_id != -1) { @@ -250,59 +299,29 @@ BlockInputStreamPtr executeMPPQuery(Context & context, const DAGProperties & pro } } - if (context.isMPPTest()) - { - MockComputeClient client( - grpc::CreateChannel(Debug::LOCAL_HOST, grpc::InsecureChannelCredentials())); - client.runDispatchMPPTask(req); - } - else - { - pingcap::kv::RpcCall call(req); - context.getTMTContext().getCluster()->rpc_client->sendRequest(Debug::LOCAL_HOST, call, 1000); - if (call.getResp()->has_error()) - throw Exception("Meet error while dispatch mpp task: " + call.getResp()->error().msg()); - } - } - tipb::ExchangeReceiver tipb_exchange_receiver; - for (const auto root_task_id : root_task_ids) - { - mpp::TaskMeta tm; - tm.set_start_ts(properties.start_ts); - tm.set_address(Debug::LOCAL_HOST); - tm.set_task_id(root_task_id); - tm.set_partition_id(-1); - auto * tm_string = tipb_exchange_receiver.add_encoded_task_meta(); - tm.AppendToString(tm_string); + pingcap::kv::RpcCall call(req); + context.getTMTContext().getCluster()->rpc_client->sendRequest(Debug::LOCAL_HOST, call, 1000); + if (call.getResp()->has_error()) + throw Exception("Meet error while dispatch mpp task: " + call.getResp()->error().msg()); } - for (auto & field : root_task_schema) + return prepareRootExchangeReceiver(context, properties, root_task_ids, root_task_schema, context.getSettingsRef().enable_local_tunnel); +} + +// execute MPP Query across multiple service +BlockInputStreamPtr executeMPPQuery(Context & context, const DAGProperties & properties, QueryTasks & query_tasks, std::unordered_map & server_config_map) +{ + DAGSchema root_task_schema; + std::vector root_task_ids; + for (auto & task : query_tasks) { - auto tipb_type = TiDB::columnInfoToFieldType(field.second); - tipb_type.set_collate(properties.collator); - auto * field_type = tipb_exchange_receiver.add_field_types(); - *field_type = tipb_type; + auto req = std::make_shared(); + auto addr = server_config_map[task.partition_id].addr; + prepareDispatchTaskRequest(task, req, properties, root_task_ids, root_task_schema, addr); + MockComputeClient client( + grpc::CreateChannel(addr, grpc::InsecureChannelCredentials())); + client.runDispatchMPPTask(req); } - mpp::TaskMeta root_tm; - root_tm.set_start_ts(properties.start_ts); - root_tm.set_address(Debug::LOCAL_HOST); - root_tm.set_task_id(-1); - root_tm.set_partition_id(-1); - std::shared_ptr exchange_receiver - = std::make_shared( - std::make_shared( - tipb_exchange_receiver, - root_tm, - context.getTMTContext().getKVCluster(), - context.getTMTContext().getMPPTaskManager(), - context.getSettingsRef().enable_local_tunnel, - context.getSettingsRef().enable_async_grpc_client), - tipb_exchange_receiver.encoded_task_meta_size(), - 10, - /*req_id=*/"", - /*executor_id=*/"", - /*fine_grained_shuffle_stream_count=*/0); - BlockInputStreamPtr ret = std::make_shared(exchange_receiver, /*req_id=*/"", /*executor_id=*/"", /*stream_id*/ 0); - return ret; + return prepareRootExchangeReceiver(context, properties, root_task_ids, root_task_schema, true); } BlockInputStreamPtr executeNonMPPQuery(Context & context, RegionID region_id, const DAGProperties & properties, QueryTasks & query_tasks, MakeResOutputStream & func_wrap_output_stream) @@ -645,7 +664,7 @@ QueryFragments mppQueryToQueryFragments( { mpp_ctx->sender_target_task_ids = current_task_ids; auto sub_fragments = mppQueryToQueryFragments(exchange.second.second, executor_index, properties, false, mpp_ctx); - receiver_source_task_ids_map[exchange.first] = sub_fragments.cbegin()->task_ids; + receiver_source_task_ids_map[exchange.first] = sub_fragments[sub_fragments.size() - 1].task_ids; fragments.insert(fragments.end(), sub_fragments.begin(), sub_fragments.end()); } fragments.emplace_back(root_executor, table_id, for_root_fragment, std::move(sender_target_task_ids), std::move(receiver_source_task_ids_map), std::move(current_task_ids)); diff --git a/dbms/src/Debug/dbgFuncCoprocessor.h b/dbms/src/Debug/dbgFuncCoprocessor.h index 41456e54ac4..5a1ccc669d2 100644 --- a/dbms/src/Debug/dbgFuncCoprocessor.h +++ b/dbms/src/Debug/dbgFuncCoprocessor.h @@ -24,6 +24,7 @@ namespace DB { class Context; +using MockServerConfig = tests::MockServerConfig; // Coprocessor debug tools @@ -84,8 +85,9 @@ QueryTasks queryPlanToQueryTasks( const Context & context); BlockInputStreamPtr executeQuery(Context & context, RegionID region_id, const DAGProperties & properties, QueryTasks & query_tasks, MakeResOutputStream & func_wrap_output_stream); - BlockInputStreamPtr executeMPPQuery(Context & context, const DAGProperties & properties, QueryTasks & query_tasks); +BlockInputStreamPtr executeMPPQuery(Context & context, const DAGProperties & properties, QueryTasks & query_tasks, std::unordered_map & server_config_map); + namespace Debug { void setServiceAddr(const std::string & addr); diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp index 8ae4310eb50..b4a7511197e 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp @@ -160,7 +160,7 @@ AnalysisResult analyzeExpressions( // for tests, we need to mock tableScan blockInputStream as the source stream. void DAGQueryBlockInterpreter::handleMockTableScan(const TiDBTableScan & table_scan, DAGPipeline & pipeline) { - if (context.columnsForTestEmpty() || context.columnsForTest(table_scan.getTableScanExecutorID()).empty()) + if (!context.mockStorage().tableExists(table_scan.getLogicalTableID())) { auto names_and_types = genNamesAndTypes(table_scan, "mock_table_scan"); auto columns_with_type_and_name = getColumnWithTypeAndName(names_and_types); @@ -173,12 +173,13 @@ void DAGQueryBlockInterpreter::handleMockTableScan(const TiDBTableScan & table_s } else { - auto [names_and_types, mock_table_scan_streams] = mockSourceStream(context, max_streams, log, table_scan.getTableScanExecutorID()); + auto [names_and_types, mock_table_scan_streams] = mockSourceStream(context, max_streams, log, table_scan.getTableScanExecutorID(), table_scan.getLogicalTableID()); analyzer = std::make_unique(std::move(names_and_types), context); pipeline.streams.insert(pipeline.streams.end(), mock_table_scan_streams.begin(), mock_table_scan_streams.end()); } } + void DAGQueryBlockInterpreter::handleTableScan(const TiDBTableScan & table_scan, DAGPipeline & pipeline) { const auto push_down_filter = PushDownFilter::pushDownFilterFrom(query_block.selection_name, query_block.selection); @@ -493,7 +494,7 @@ void DAGQueryBlockInterpreter::handleExchangeReceiver(DAGPipeline & pipeline) // for tests, we need to mock ExchangeReceiver blockInputStream as the source stream. void DAGQueryBlockInterpreter::handleMockExchangeReceiver(DAGPipeline & pipeline) { - if (context.columnsForTestEmpty() || context.columnsForTest(query_block.source_name).empty()) + if (!context.mockStorage().exchangeExists(query_block.source_name)) { for (size_t i = 0; i < max_streams; ++i) { diff --git a/dbms/src/Flash/Coprocessor/MockSourceStream.h b/dbms/src/Flash/Coprocessor/MockSourceStream.h index 5b69630b40f..3cecfc0996f 100644 --- a/dbms/src/Flash/Coprocessor/MockSourceStream.h +++ b/dbms/src/Flash/Coprocessor/MockSourceStream.h @@ -16,19 +16,27 @@ #include #include #include +#include #include #include namespace DB { + template -std::pair>> mockSourceStream(Context & context, size_t max_streams, DB::LoggerPtr log, String executor_id) +std::pair>> mockSourceStream(Context & context, size_t max_streams, DB::LoggerPtr log, String executor_id, Int64 table_id = 0) { ColumnsWithTypeAndName columns_with_type_and_name; NamesAndTypes names_and_types; size_t rows = 0; std::vector> mock_source_streams; - columns_with_type_and_name = context.columnsForTest(executor_id); + if constexpr (std::is_same_v) + columns_with_type_and_name = context.mockStorage().getExchangeColumns(executor_id); + else if (context.isMPPTest()) + columns_with_type_and_name = context.mockStorage().getColumnsForMPPTableScan(table_id, context.mockMPPServerInfo().partition_id, context.mockMPPServerInfo().partition_num); + else + columns_with_type_and_name = context.mockStorage().getColumns(table_id); + for (const auto & col : columns_with_type_and_name) { if (rows == 0) diff --git a/dbms/src/Flash/FlashService.cpp b/dbms/src/Flash/FlashService.cpp index ce9934d7d7f..f4768a347ae 100644 --- a/dbms/src/Flash/FlashService.cpp +++ b/dbms/src/Flash/FlashService.cpp @@ -186,6 +186,8 @@ ::grpc::Status FlashService::DispatchMPPTask( { return status; } + context->setMockStorage(mock_storage); + context->setMockMPPServerInfo(mpp_test_info); MPPHandler mpp_handler(*request); return mpp_handler.execute(context, response); @@ -231,7 +233,8 @@ ::grpc::Status FlashService::establishMPPConnectionSyncOrAsync(::grpc::ServerCon // We need to find it out and bind the grpc stream with it. LOG_FMT_DEBUG(log, "Handling establish mpp connection request: {}", request->DebugString()); - if (!security_config.checkGrpcContext(grpc_context)) + // For MPP test, we don't care about security config. + if (!context.isMPPTest() && !security_config.checkGrpcContext(grpc_context)) { return returnStatus(calldata, grpc::Status(grpc::PERMISSION_DENIED, tls_err_msg)); } @@ -431,4 +434,13 @@ ::grpc::Status FlashService::Compact(::grpc::ServerContext * grpc_context, const return manual_compact_manager->handleRequest(request, response); } +void FlashService::setMockStorage(MockStorage & mock_storage_) +{ + mock_storage = mock_storage_; +} + +void FlashService::setMockMPPServerInfo(MockMPPServerInfo & mpp_test_info_) +{ + mpp_test_info = mpp_test_info_; +} } // namespace DB \ No newline at end of file diff --git a/dbms/src/Flash/FlashService.h b/dbms/src/Flash/FlashService.h index 7a25aae4fa2..5ac694288e8 100644 --- a/dbms/src/Flash/FlashService.h +++ b/dbms/src/Flash/FlashService.h @@ -36,6 +36,10 @@ namespace DB class IServer; class CallExecPool; class EstablishCallData; + +using MockStorage = tests::MockStorage; +using MockMPPServerInfo = tests::MockMPPServerInfo; + namespace Management { class ManualCompactManager; @@ -80,6 +84,9 @@ class FlashService : public tikvpb::Tikv::Service ::grpc::Status Compact(::grpc::ServerContext * context, const ::kvrpcpb::CompactRequest * request, ::kvrpcpb::CompactResponse * response) override; + void setMockStorage(MockStorage & mock_storage_); + void setMockMPPServerInfo(MockMPPServerInfo & mpp_test_info_); + protected: std::tuple createDBContext(const grpc::ServerContext * grpc_context) const; @@ -92,6 +99,11 @@ class FlashService : public tikvpb::Tikv::Service std::unique_ptr manual_compact_manager; + + /// for mpp unit test. + MockStorage mock_storage; + MockMPPServerInfo mpp_test_info{}; + // Put thread pool member(s) at the end so that ensure it will be destroyed firstly. std::unique_ptr cop_pool, batch_cop_pool; }; diff --git a/dbms/src/Flash/Mpp/MPPTask.cpp b/dbms/src/Flash/Mpp/MPPTask.cpp index 5ea7b527e0f..6e81dec1aff 100644 --- a/dbms/src/Flash/Mpp/MPPTask.cpp +++ b/dbms/src/Flash/Mpp/MPPTask.cpp @@ -147,7 +147,7 @@ void MPPTask::registerTunnels(const mpp::DispatchTaskRequest & task_request) bool is_local = context->getSettingsRef().enable_local_tunnel && meta.address() == task_meta.address(); bool is_async = !is_local && context->getSettingsRef().enable_async_server; MPPTunnelPtr tunnel = std::make_shared(task_meta, task_request.meta(), timeout, context->getSettingsRef().max_threads, is_local, is_async, log->identifier()); - LOG_FMT_DEBUG(log, "begin to register the tunnel {}", tunnel->id()); + LOG_FMT_DEBUG(log, "begin to register the tunnel {}, is_local: {}, is_async: {}", tunnel->id(), is_local, is_async); if (status != INITIALIZING) throw Exception(fmt::format("The tunnel {} can not be registered, because the task is not in initializing state", tunnel->id())); tunnel_set_local->registerTunnel(MPPTaskId{task_meta.start_ts(), task_meta.task_id()}, tunnel); diff --git a/dbms/src/Flash/Planner/plans/PhysicalMockExchangeReceiver.cpp b/dbms/src/Flash/Planner/plans/PhysicalMockExchangeReceiver.cpp index 7a4566eefd0..4cc3262ba18 100644 --- a/dbms/src/Flash/Planner/plans/PhysicalMockExchangeReceiver.cpp +++ b/dbms/src/Flash/Planner/plans/PhysicalMockExchangeReceiver.cpp @@ -37,7 +37,7 @@ std::pair mockSchemaAndStreams( size_t max_streams = dag_context.initialize_concurrency; assert(max_streams > 0); - if (context.columnsForTestEmpty() || context.columnsForTest(executor_id).empty()) + if (!context.mockStorage().exchangeExists(executor_id)) { /// build with default blocks. for (size_t i = 0; i < max_streams; ++i) diff --git a/dbms/src/Flash/Planner/plans/PhysicalMockTableScan.cpp b/dbms/src/Flash/Planner/plans/PhysicalMockTableScan.cpp index d390d6c5e05..26f31bf4400 100644 --- a/dbms/src/Flash/Planner/plans/PhysicalMockTableScan.cpp +++ b/dbms/src/Flash/Planner/plans/PhysicalMockTableScan.cpp @@ -38,7 +38,7 @@ std::pair mockSchemaAndStreams( size_t max_streams = dag_context.initialize_concurrency; assert(max_streams > 0); - if (context.columnsForTestEmpty() || context.columnsForTest(executor_id).empty()) + if (!context.mockStorage().tableExists(table_scan.getLogicalTableID())) { /// build with default blocks. schema = genNamesAndTypes(table_scan, "mock_table_scan"); @@ -49,7 +49,7 @@ std::pair mockSchemaAndStreams( else { /// build from user input blocks. - auto [names_and_types, mock_table_scan_streams] = mockSourceStream(context, max_streams, log, executor_id); + auto [names_and_types, mock_table_scan_streams] = mockSourceStream(context, max_streams, log, executor_id, table_scan.getLogicalTableID()); schema = std::move(names_and_types); mock_streams.insert(mock_streams.end(), mock_table_scan_streams.begin(), mock_table_scan_streams.end()); } diff --git a/dbms/src/Flash/Planner/tests/gtest_physical_plan.cpp b/dbms/src/Flash/Planner/tests/gtest_physical_plan.cpp index 1d0704f692b..c26e4c9f298 100644 --- a/dbms/src/Flash/Planner/tests/gtest_physical_plan.cpp +++ b/dbms/src/Flash/Planner/tests/gtest_physical_plan.cpp @@ -89,10 +89,9 @@ class PhysicalPlanTestRunner : public DB::tests::ExecutorTest // TODO support multi-streams. size_t max_streams = 1; - context.context.setColumnsForTest(context.executorIdColumnsMap()); - DAGContext dag_context(*request, "executor_test", max_streams); context.context.setDAGContext(&dag_context); + context.context.setMockStorage(context.mockStorage()); PhysicalPlan physical_plan{context.context, log->identifier()}; assert(request); @@ -112,7 +111,7 @@ class PhysicalPlanTestRunner : public DB::tests::ExecutorTest ASSERT_EQ(Poco::trim(expected_streams), Poco::trim(fb.toString())); } - ASSERT_COLUMNS_EQ_R(expect_columns, readBlock(final_stream)); + ASSERT_COLUMNS_EQ_UR(expect_columns, readBlock(final_stream)); } std::tuple multiTestScan() diff --git a/dbms/src/Flash/tests/bench_window.cpp b/dbms/src/Flash/tests/bench_window.cpp index 75dc53b065b..9f68ba9beb9 100644 --- a/dbms/src/Flash/tests/bench_window.cpp +++ b/dbms/src/Flash/tests/bench_window.cpp @@ -39,7 +39,7 @@ class WindowFunctionBench : public ExchangeBench size_t executor_index = 0; DAGRequestBuilder builder(executor_index); builder - .mockTable("test", "t1", columns) + .mockTable("test", "t1", 0 /*table_id=*/, columns) .sort({{"c1", false}, {"c2", false}, {"c3", false}}, true, fine_grained_shuffle_stream_count) .window(RowNumber(), {{"c1", false}, {"c2", false}, {"c3", false}}, diff --git a/dbms/src/Flash/tests/gtest_compute_server.cpp b/dbms/src/Flash/tests/gtest_compute_server.cpp index 5dac37d207f..a5c5bc18bf0 100644 --- a/dbms/src/Flash/tests/gtest_compute_server.cpp +++ b/dbms/src/Flash/tests/gtest_compute_server.cpp @@ -28,7 +28,7 @@ class ComputeServerRunner : public DB::tests::MPPTaskTestUtils context.addMockTable( {"test_db", "test_table_1"}, {{"s1", TiDB::TP::TypeLong}, {"s2", TiDB::TP::TypeString}, {"s3", TiDB::TP::TypeString}}, - {toNullableVec("s1", {1, {}, 10000000}), toNullableVec("s2", {"apple", {}, "banana"}), toNullableVec("s3", {"apple", {}, "banana"})}); + {toNullableVec("s1", {1, {}, 10000000, 10000000}), toNullableVec("s2", {"apple", {}, "banana", "test"}), toNullableVec("s3", {"apple", {}, "banana", "test"})}); /// for join context.addMockTable( @@ -45,29 +45,79 @@ class ComputeServerRunner : public DB::tests::MPPTaskTestUtils TEST_F(ComputeServerRunner, runAggTasks) try { + startServers(4); { - auto tasks = context.scan("test_db", "test_table_1") - .aggregation({Max(col("s1"))}, {col("s2"), col("s3")}) - .project({"max(s1)"}) - .buildMPPTasks(context); + std::vector expected_strings = { + R"(exchange_sender_5 | type:Hash, {<0, Long>, <1, String>, <2, String>} + aggregation_4 | group_by: {<1, String>, <2, String>}, agg_func: {max(<0, Long>)} + table_scan_0 | {<0, Long>, <1, String>, <2, String>} +)", + R"(exchange_sender_5 | type:Hash, {<0, Long>, <1, String>, <2, String>} + aggregation_4 | group_by: {<1, String>, <2, String>}, agg_func: {max(<0, Long>)} + table_scan_0 | {<0, Long>, <1, String>, <2, String>} +)", + R"(exchange_sender_5 | type:Hash, {<0, Long>, <1, String>, <2, String>} + aggregation_4 | group_by: {<1, String>, <2, String>}, agg_func: {max(<0, Long>)} + table_scan_0 | {<0, Long>, <1, String>, <2, String>} +)", + R"(exchange_sender_5 | type:Hash, {<0, Long>, <1, String>, <2, String>} + aggregation_4 | group_by: {<1, String>, <2, String>}, agg_func: {max(<0, Long>)} + table_scan_0 | {<0, Long>, <1, String>, <2, String>} +)", + R"(exchange_sender_3 | type:PassThrough, {<0, Long>} + project_2 | {<0, Long>} + aggregation_1 | group_by: {<1, String>, <2, String>}, agg_func: {max(<0, Long>)} + exchange_receiver_6 | type:PassThrough, {<0, Long>, <1, String>, <2, String>} +)", + R"(exchange_sender_3 | type:PassThrough, {<0, Long>} + project_2 | {<0, Long>} + aggregation_1 | group_by: {<1, String>, <2, String>}, agg_func: {max(<0, Long>)} + exchange_receiver_6 | type:PassThrough, {<0, Long>, <1, String>, <2, String>} +)", + R"( +exchange_sender_3 | type:PassThrough, {<0, Long>} + project_2 | {<0, Long>} + aggregation_1 | group_by: {<1, String>, <2, String>}, agg_func: {max(<0, Long>)} + exchange_receiver_6 | type:PassThrough, {<0, Long>, <1, String>, <2, String>} +)", + R"(exchange_sender_3 | type:PassThrough, {<0, Long>} + project_2 | {<0, Long>} + aggregation_1 | group_by: {<1, String>, <2, String>}, agg_func: {max(<0, Long>)} + exchange_receiver_6 | type:PassThrough, {<0, Long>, <1, String>, <2, String>} +)"}; + auto expected_cols = {toNullableVec({1, {}, 10000000, 10000000})}; - size_t task_size = tasks.size(); + ASSERT_MPPTASK_EQUAL_PLAN_AND_RESULT( + context + .scan("test_db", "test_table_1") + .aggregation({Max(col("s1"))}, {col("s2"), col("s3")}) + .project({"max(s1)"}), + expected_strings, + expected_cols); + } + { + auto properties = getDAGPropertiesForTest(1); + auto tasks = context + .scan("test_db", "test_table_1") + .aggregation({Count(col("s1"))}, {}) + .project({"count(s1)"}) + .buildMPPTasks(context, properties); std::vector expected_strings = { - "exchange_sender_5 | type:Hash, {<0, Long>, <1, String>, <2, String>}\n" - " aggregation_4 | group_by: {<1, String>, <2, String>}, agg_func: {max(<0, Long>)}\n" - " table_scan_0 | {<0, Long>, <1, String>, <2, String>}\n", - "exchange_sender_3 | type:PassThrough, {<0, Long>}\n" - " project_2 | {<0, Long>}\n" - " aggregation_1 | group_by: {<1, String>, <2, String>}, agg_func: {max(<0, Long>)}\n" - " exchange_receiver_6 | type:PassThrough, {<0, Long>, <1, String>, <2, String>}\n"}; + R"(exchange_sender_5 | type:PassThrough, {<0, Longlong>} + aggregation_4 | group_by: {}, agg_func: {count(<0, Long>)} + table_scan_0 | {<0, Long>} + )", + R"(exchange_sender_3 | type:PassThrough, {<0, Longlong>} + project_2 | {<0, Longlong>} + aggregation_1 | group_by: {}, agg_func: {sum(<0, Longlong>)} + exchange_receiver_6 | type:PassThrough, {<0, Longlong>})"}; + + size_t task_size = tasks.size(); for (size_t i = 0; i < task_size; ++i) { ASSERT_DAGREQUEST_EQAUL(expected_strings[i], tasks[i].dag_request); } - - auto expected_cols = {toNullableVec({1, {}, 10000000})}; - ASSERT_MPPTASK_EQUAL(tasks, expected_cols); } } CATCH @@ -75,34 +125,164 @@ CATCH TEST_F(ComputeServerRunner, runJoinTasks) try { - auto tasks = context - .scan("test_db", "l_table") - .join(context.scan("test_db", "r_table"), tipb::JoinType::TypeLeftOuterJoin, {col("join_c")}) - .topN("join_c", false, 2) - .buildMPPTasks(context); - - size_t task_size = tasks.size(); - std::vector expected_strings = { - "exchange_sender_6 | type:Hash, {<0, String>, <1, String>}\n" - " table_scan_1 | {<0, String>, <1, String>}", - "exchange_sender_5 | type:Hash, {<0, String>, <1, String>}\n" - " table_scan_0 | {<0, String>, <1, String>}", - "exchange_sender_4 | type:PassThrough, {<0, String>, <1, String>, <2, String>, <3, String>}\n" - " topn_3 | order_by: {(<1, String>, desc: false)}, limit: 2\n" - " Join_2 | LeftOuterJoin, HashJoin. left_join_keys: {<0, String>}, right_join_keys: {<0, String>}\n" - " exchange_receiver_7 | type:PassThrough, {<0, String>, <1, String>}\n" - " exchange_receiver_8 | type:PassThrough, {<0, String>, <1, String>}"}; - for (size_t i = 0; i < task_size; ++i) + startServers(3); { - ASSERT_DAGREQUEST_EQAUL(expected_strings[i], tasks[i].dag_request); + auto expected_cols = { + toNullableVec({{}, "banana", "banana"}), + toNullableVec({{}, "apple", "banana"}), + toNullableVec({{}, "banana", "banana"}), + toNullableVec({{}, "apple", "banana"})}; + + std::vector expected_strings = { + R"(exchange_sender_5 | type:Hash, {<0, String>, <1, String>} + table_scan_1 | {<0, String>, <1, String>})", + R"(exchange_sender_5 | type:Hash, {<0, String>, <1, String>} + table_scan_1 | {<0, String>, <1, String>})", + R"(exchange_sender_5 | type:Hash, {<0, String>, <1, String>} + table_scan_1 | {<0, String>, <1, String>})", + R"(exchange_sender_4 | type:Hash, {<0, String>, <1, String>} + table_scan_0 | {<0, String>, <1, String>})", + R"(exchange_sender_4 | type:Hash, {<0, String>, <1, String>} + table_scan_0 | {<0, String>, <1, String>})", + R"(exchange_sender_4 | type:Hash, {<0, String>, <1, String>} + table_scan_0 | {<0, String>, <1, String>})", + R"(exchange_sender_3 | type:PassThrough, {<0, String>, <1, String>, <2, String>, <3, String>} + Join_2 | LeftOuterJoin, HashJoin. left_join_keys: {<0, String>}, right_join_keys: {<0, String>} + exchange_receiver_6 | type:PassThrough, {<0, String>, <1, String>} + exchange_receiver_7 | type:PassThrough, {<0, String>, <1, String>})", + R"(exchange_sender_3 | type:PassThrough, {<0, String>, <1, String>, <2, String>, <3, String>} + Join_2 | LeftOuterJoin, HashJoin. left_join_keys: {<0, String>}, right_join_keys: {<0, String>} + exchange_receiver_6 | type:PassThrough, {<0, String>, <1, String>} + exchange_receiver_7 | type:PassThrough, {<0, String>, <1, String>})", + R"(exchange_sender_3 | type:PassThrough, {<0, String>, <1, String>, <2, String>, <3, String>} + Join_2 | LeftOuterJoin, HashJoin. left_join_keys: {<0, String>}, right_join_keys: {<0, String>} + exchange_receiver_6 | type:PassThrough, {<0, String>, <1, String>} + exchange_receiver_7 | type:PassThrough, {<0, String>, <1, String>})"}; + + ASSERT_MPPTASK_EQUAL_PLAN_AND_RESULT(context + .scan("test_db", "l_table") + .join(context.scan("test_db", "r_table"), tipb::JoinType::TypeLeftOuterJoin, {col("join_c")}), + expected_strings, + expect_cols); } - auto expected_cols = { - toNullableVec({{}, "banana"}), - toNullableVec({{}, "apple"}), - toNullableVec({{}, "banana"}), - toNullableVec({{}, "apple"})}; - ASSERT_MPPTASK_EQUAL(tasks, expected_cols); + { + auto properties = getDAGPropertiesForTest(1); + auto tasks = context + .scan("test_db", "l_table") + .join(context.scan("test_db", "r_table"), tipb::JoinType::TypeLeftOuterJoin, {col("join_c")}) + .buildMPPTasks(context, properties); + + std::vector expected_strings = { + R"(exchange_sender_5 | type:Hash, {<0, String>, <1, String>} + table_scan_1 | {<0, String>, <1, String>})", + R"(exchange_sender_4 | type:Hash, {<0, String>, <1, String>} + table_scan_0 | {<0, String>, <1, String>})", + R"(exchange_sender_3 | type:PassThrough, {<0, String>, <1, String>, <2, String>, <3, String>} + Join_2 | LeftOuterJoin, HashJoin. left_join_keys: {<0, String>}, right_join_keys: {<0, String>} + exchange_receiver_6 | type:PassThrough, {<0, String>, <1, String>} + exchange_receiver_7 | type:PassThrough, {<0, String>, <1, String>})"}; + + size_t task_size = tasks.size(); + for (size_t i = 0; i < task_size; ++i) + { + ASSERT_DAGREQUEST_EQAUL(expected_strings[i], tasks[i].dag_request); + } + } +} +CATCH + +TEST_F(ComputeServerRunner, runJoinThenAggTasks) +try +{ + startServers(3); + { + std::vector expected_strings = { + R"(exchange_sender_10 | type:Hash, {<0, String>} + table_scan_1 | {<0, String>})", + R"(exchange_sender_10 | type:Hash, {<0, String>} + table_scan_1 | {<0, String>})", + R"(exchange_sender_10 | type:Hash, {<0, String>} + table_scan_1 | {<0, String>})", + R"(exchange_sender_9 | type:Hash, {<0, String>, <1, String>} + table_scan_0 | {<0, String>, <1, String>})", + R"(exchange_sender_9 | type:Hash, {<0, String>, <1, String>} + table_scan_0 | {<0, String>, <1, String>})", + R"(exchange_sender_9 | type:Hash, {<0, String>, <1, String>} + table_scan_0 | {<0, String>, <1, String>})", + R"(exchange_sender_7 | type:Hash, {<0, String>, <1, String>} + aggregation_6 | group_by: {<0, String>}, agg_func: {max(<0, String>)} + Join_2 | LeftOuterJoin, HashJoin. left_join_keys: {<0, String>}, right_join_keys: {<0, String>} + exchange_receiver_11 | type:PassThrough, {<0, String>, <1, String>} + exchange_receiver_12 | type:PassThrough, {<0, String>})", + R"(exchange_sender_7 | type:Hash, {<0, String>, <1, String>} + aggregation_6 | group_by: {<0, String>}, agg_func: {max(<0, String>)} + Join_2 | LeftOuterJoin, HashJoin. left_join_keys: {<0, String>}, right_join_keys: {<0, String>} + exchange_receiver_11 | type:PassThrough, {<0, String>, <1, String>} + exchange_receiver_12 | type:PassThrough, {<0, String>})", + R"(exchange_sender_7 | type:Hash, {<0, String>, <1, String>} + aggregation_6 | group_by: {<0, String>}, agg_func: {max(<0, String>)} + Join_2 | LeftOuterJoin, HashJoin. left_join_keys: {<0, String>}, right_join_keys: {<0, String>} + exchange_receiver_11 | type:PassThrough, {<0, String>, <1, String>} + exchange_receiver_12 | type:PassThrough, {<0, String>})", + R"(exchange_sender_5 | type:PassThrough, {<0, String>, <1, String>} + project_4 | {<0, String>, <1, String>} + aggregation_3 | group_by: {<1, String>}, agg_func: {max(<0, String>)} + exchange_receiver_8 | type:PassThrough, {<0, String>, <1, String>})", + R"(exchange_sender_5 | type:PassThrough, {<0, String>, <1, String>} + project_4 | {<0, String>, <1, String>} + aggregation_3 | group_by: {<1, String>}, agg_func: {max(<0, String>)} + exchange_receiver_8 | type:PassThrough, {<0, String>, <1, String>})", + R"(exchange_sender_5 | type:PassThrough, {<0, String>, <1, String>} + project_4 | {<0, String>, <1, String>} + aggregation_3 | group_by: {<1, String>}, agg_func: {max(<0, String>)} + exchange_receiver_8 | type:PassThrough, {<0, String>, <1, String>})"}; + + auto expected_cols = { + toNullableVec({{}, "banana"}), + toNullableVec({{}, "banana"})}; + + ASSERT_MPPTASK_EQUAL_PLAN_AND_RESULT( + context + .scan("test_db", "l_table") + .join(context.scan("test_db", "r_table"), tipb::JoinType::TypeLeftOuterJoin, {col("join_c")}) + .aggregation({Max(col("l_table.s"))}, {col("l_table.s")}) + .project({col("max(l_table.s)"), col("l_table.s")}), + expected_strings, + expect_cols); + } + + { + auto properties = getDAGPropertiesForTest(1); + auto tasks = context + .scan("test_db", "l_table") + .join(context.scan("test_db", "r_table"), tipb::JoinType::TypeLeftOuterJoin, {col("join_c")}) + .aggregation({Max(col("l_table.s"))}, {col("l_table.s")}) + .project({col("max(l_table.s)"), col("l_table.s")}) + .buildMPPTasks(context, properties); + + std::vector expected_strings = { + R"(exchange_sender_10 | type:Hash, {<0, String>} + table_scan_1 | {<0, String>})", + R"(exchange_sender_9 | type:Hash, {<0, String>, <1, String>} + table_scan_0 | {<0, String>, <1, String>})", + R"(exchange_sender_7 | type:Hash, {<0, String>, <1, String>} + aggregation_6 | group_by: {<0, String>}, agg_func: {max(<0, String>)} + Join_2 | LeftOuterJoin, HashJoin. left_join_keys: {<0, String>}, right_join_keys: {<0, String>} + exchange_receiver_11 | type:PassThrough, {<0, String>, <1, String>} + exchange_receiver_12 | type:PassThrough, {<0, String>})", + R"(exchange_sender_5 | type:PassThrough, {<0, String>, <1, String>} + project_4 | {<0, String>, <1, String>} + aggregation_3 | group_by: {<1, String>}, agg_func: {max(<0, String>)} + exchange_receiver_8 | type:PassThrough, {<0, String>, <1, String>})", + }; + + size_t task_size = tasks.size(); + for (size_t i = 0; i < task_size; ++i) + { + ASSERT_DAGREQUEST_EQAUL(expected_strings[i], tasks[i].dag_request); + } + } } CATCH diff --git a/dbms/src/Flash/tests/gtest_window_executor.cpp b/dbms/src/Flash/tests/gtest_window_executor.cpp index 2a2260ada03..a3224d089ca 100644 --- a/dbms/src/Flash/tests/gtest_window_executor.cpp +++ b/dbms/src/Flash/tests/gtest_window_executor.cpp @@ -71,13 +71,14 @@ class WindowExecutorTestRunner : public DB::tests::ExecutorTest WRAP_FOR_DIS_ENABLE_PLANNER_END } - void executeWithTableScanAndConcurrency(const std::shared_ptr & request, const ColumnsWithTypeAndName & source_columns, const ColumnsWithTypeAndName & expect_columns) + void executeWithTableScanAndConcurrency(const std::shared_ptr & request, const String & db, const String & table_name, const ColumnsWithTypeAndName & source_columns, const ColumnsWithTypeAndName & expect_columns) { + context.addMockTableColumnData(db, table_name, source_columns); + ASSERT_COLUMNS_EQ_R(expect_columns, executeStreams(request)); WRAP_FOR_DIS_ENABLE_PLANNER_BEGIN - ASSERT_COLUMNS_EQ_R(expect_columns, executeStreamsWithSingleSource(request, source_columns, SourceType::TableScan)); for (size_t i = 2; i <= max_concurrency_level; ++i) { - ASSERT_COLUMNS_EQ_UR(expect_columns, executeStreamsWithSingleSource(request, source_columns, SourceType::TableScan)); + ASSERT_COLUMNS_EQ_UR(expect_columns, executeStreams(request, i)); } WRAP_FOR_DIS_ENABLE_PLANNER_END } @@ -101,12 +102,16 @@ try // null input executeWithTableScanAndConcurrency(request, + "test_db", + "test_table", {toNullableVec("partition", {}), toNullableVec("order", {})}, createColumns({})); // nullable executeWithTableScanAndConcurrency( request, + "test_db", + "test_table", {toNullableVec("partition", {{}, 1, 1, 1, 1, 2, 2, 2, 2}), {toNullableVec("order", {{}, 1, 1, 2, 2, 1, 1, 2, 2})}}, createColumns({toNullableVec("partition", {{}, 1, 1, 1, 1, 2, 2, 2, 2}), @@ -127,6 +132,8 @@ try // nullable executeWithTableScanAndConcurrency(request, + "test_db", + "test_table_string", {toNullableVec("partition", {"banana", "banana", "banana", "banana", {}, "apple", "apple", "apple", "apple"}), toNullableVec("order", {"apple", "apple", "banana", "banana", {}, "apple", "apple", "banana", "banana"})}, createColumns({toNullableVec("partition", {{}, "apple", "apple", "apple", "apple", "banana", "banana", "banana", "banana"}), @@ -147,6 +154,8 @@ try // nullable executeWithTableScanAndConcurrency(request, + "test_db", + "test_table_float64", {toNullableVec("partition", {{}, 1.00, 1.00, 1.00, 1.00, 2.00, 2.00, 2.00, 2.00}), toNullableVec("order", {{}, 1.00, 1.00, 2.00, 2.00, 1.00, 1.00, 2.00, 2.00})}, createColumns({toNullableVec("partition", {{}, 1.00, 1.00, 1.00, 1.00, 2.00, 2.00, 2.00, 2.00}), @@ -160,6 +169,8 @@ try .window(RowNumber(), {"order", false}, {"partition", false}, buildDefaultRowsFrame()) .build(context); executeWithTableScanAndConcurrency(request, + "test_db", + "test_table_datetime", {toNullableDatetimeVec("partition", {"20220101010102", "20220101010102", "20220101010102", "20220101010102", "20220101010101", "20220101010101", "20220101010101", "20220101010101"}, 0), toDatetimeVec("order", {"20220101010101", "20220101010101", "20220101010102", "20220101010102", "20220101010101", "20220101010101", "20220101010102", "20220101010102"}, 0)}, createColumns({toNullableDatetimeVec("partition", {"20220101010101", "20220101010101", "20220101010101", "20220101010101", "20220101010102", "20220101010102", "20220101010102", "20220101010102"}, 0), @@ -168,6 +179,8 @@ try // nullable executeWithTableScanAndConcurrency(request, + "test_db", + "test_table_datetime", {toNullableDatetimeVec("partition", {"20220101010102", {}, "20220101010102", "20220101010102", "20220101010102", "20220101010101", "20220101010101", "20220101010101", "20220101010101"}, 0), toNullableDatetimeVec("order", {"20220101010101", {}, "20220101010101", "20220101010102", "20220101010102", "20220101010101", "20220101010101", "20220101010102", "20220101010102"}, 0)}, createColumns({toNullableDatetimeVec("partition", {{}, "20220101010101", "20220101010101", "20220101010101", "20220101010101", "20220101010102", "20220101010102", "20220101010102", "20220101010102"}, 0), @@ -199,6 +212,8 @@ try // nullable executeWithTableScanAndConcurrency(request, + "test_db", + "test_table_for_rank", {toNullableVec("partition", {{}, 1, 1, 1, 1, 2, 2, 2, 2}), toNullableVec("order", {{}, 1, 1, 2, 2, 1, 1, 2, 2})}, createColumns({toNullableVec("partition", {{}, 1, 1, 1, 1, 2, 2, 2, 2}), @@ -208,6 +223,8 @@ try executeWithTableScanAndConcurrency( request, + "test_db", + "test_table_for_rank", {toNullableVec("partition", {{}, {}, 1, 1, 1, 1, 2, 2, 2, 2}), toNullableVec("order", {{}, 1, 1, 1, 2, 2, 1, 1, 2, 2})}, createColumns({toNullableVec("partition", {{}, {}, 1, 1, 1, 1, 2, 2, 2, 2}), diff --git a/dbms/src/Interpreters/Context.cpp b/dbms/src/Interpreters/Context.cpp index beffc7d0928..8dfa37c4ce0 100644 --- a/dbms/src/Interpreters/Context.cpp +++ b/dbms/src/Interpreters/Context.cpp @@ -1859,29 +1859,24 @@ bool Context::isTest() const return test_mode != non_test; } -void Context::setColumnsForTest(std::unordered_map & columns_for_test_map_) +void Context::setMockStorage(MockStorage & mock_storage_) { - columns_for_test_map = columns_for_test_map_; + mock_storage = mock_storage_; } -std::unordered_map & Context::getColumnsForTestMap() +MockStorage Context::mockStorage() const { - return columns_for_test_map; + return mock_storage; } -ColumnsWithTypeAndName Context::columnsForTest(String executor_id) +MockMPPServerInfo Context::mockMPPServerInfo() const { - auto it = columns_for_test_map.find(executor_id); - if (unlikely(it == columns_for_test_map.end())) - { - throw DB::Exception("Don't have columns for mock source executors"); - } - return it->second; + return mpp_server_info; } -bool Context::columnsForTestEmpty() +void Context::setMockMPPServerInfo(MockMPPServerInfo & info) { - return columns_for_test_map.empty(); + mpp_server_info = info; } SessionCleaner::~SessionCleaner() diff --git a/dbms/src/Interpreters/Context.h b/dbms/src/Interpreters/Context.h index 66942ea709b..6236c214607 100644 --- a/dbms/src/Interpreters/Context.h +++ b/dbms/src/Interpreters/Context.h @@ -16,6 +16,8 @@ #include #include +#include +#include #include #include #include @@ -96,6 +98,8 @@ class WriteLimiter; using WriteLimiterPtr = std::shared_ptr; class ReadLimiter; using ReadLimiterPtr = std::shared_ptr; +using MockMPPServerInfo = DB::tests::MockMPPServerInfo; +using MockStorage = DB::tests::MockStorage; enum class PageStorageRunMode : UInt8; namespace DM @@ -161,11 +165,12 @@ class Context }; TestMode test_mode = non_test; + MockStorage mock_storage; + MockMPPServerInfo mpp_server_info{}; + TimezoneInfo timezone_info; DAGContext * dag_context = nullptr; - // TODO: add MockStorage. - std::unordered_map columns_for_test_map; /// , for multiple sources using DatabasePtr = std::shared_ptr; using Databases = std::map>; @@ -471,10 +476,11 @@ class Context bool isExecutorTest() const; void setExecutorTest(); bool isTest() const; - void setColumnsForTest(std::unordered_map & columns_for_test_map_); - std::unordered_map & getColumnsForTestMap(); - ColumnsWithTypeAndName columnsForTest(String executor_id); - bool columnsForTestEmpty(); + + void setMockStorage(MockStorage & mock_storage_); + MockStorage mockStorage() const; + MockMPPServerInfo mockMPPServerInfo() const; + void setMockMPPServerInfo(MockMPPServerInfo & info); private: /** Check if the current client has access to the specified database. diff --git a/dbms/src/Server/FlashGrpcServerHolder.cpp b/dbms/src/Server/FlashGrpcServerHolder.cpp index c82f79976e8..1190985004d 100644 --- a/dbms/src/Server/FlashGrpcServerHolder.cpp +++ b/dbms/src/Server/FlashGrpcServerHolder.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #include + namespace DB { namespace ErrorCodes @@ -181,6 +182,8 @@ FlashGrpcServerHolder::~FlashGrpcServerHolder() flash_grpc_server.reset(); if (GRPCCompletionQueuePool::global_instance) GRPCCompletionQueuePool::global_instance->markShutdown(); + + GRPCCompletionQueuePool::global_instance = nullptr; LOG_FMT_INFO(log, "Shut down flash grpc server"); /// Close flash service. @@ -195,4 +198,14 @@ FlashGrpcServerHolder::~FlashGrpcServerHolder() std::terminate(); } } + +void FlashGrpcServerHolder::setMockStorage(MockStorage & mock_storage) +{ + flash_service->setMockStorage(mock_storage); +} + +void FlashGrpcServerHolder::setMockMPPServerInfo(MockMPPServerInfo info) +{ + flash_service->setMockMPPServerInfo(info); +} } // namespace DB \ No newline at end of file diff --git a/dbms/src/Server/FlashGrpcServerHolder.h b/dbms/src/Server/FlashGrpcServerHolder.h index 81c50dc609b..57146f40aae 100644 --- a/dbms/src/Server/FlashGrpcServerHolder.h +++ b/dbms/src/Server/FlashGrpcServerHolder.h @@ -22,6 +22,9 @@ namespace DB { +using MockStorage = tests::MockStorage; +using MockMPPServerInfo = tests::MockMPPServerInfo; + class FlashGrpcServerHolder { public: @@ -33,6 +36,9 @@ class FlashGrpcServerHolder const LoggerPtr & log_); ~FlashGrpcServerHolder(); + void setMockStorage(MockStorage & mock_storage); + void setMockMPPServerInfo(MockMPPServerInfo info); + private: const LoggerPtr & log; std::shared_ptr> is_shutdown; diff --git a/dbms/src/TestUtils/ExecutorTestUtils.cpp b/dbms/src/TestUtils/ExecutorTestUtils.cpp index 682b783e2af..85d81742a0d 100644 --- a/dbms/src/TestUtils/ExecutorTestUtils.cpp +++ b/dbms/src/TestUtils/ExecutorTestUtils.cpp @@ -14,6 +14,8 @@ #include #include +#include +#include #include #include #include @@ -122,35 +124,19 @@ void ExecutorTest::enablePlanner(bool is_enable) context.context.setSetting("enable_planner", is_enable ? "true" : "false"); } -DB::ColumnsWithTypeAndName ExecutorTest::executeStreams(const std::shared_ptr & request, std::unordered_map & source_columns_map, size_t concurrency) +DB::ColumnsWithTypeAndName ExecutorTest::executeStreams(const std::shared_ptr & request, size_t concurrency) { DAGContext dag_context(*request, "executor_test", concurrency); context.context.setExecutorTest(); - context.context.setColumnsForTest(source_columns_map); + context.context.setMockStorage(context.mockStorage()); context.context.setDAGContext(&dag_context); // Currently, don't care about regions information in tests. return readBlock(executeQuery(context.context).in); } -DB::ColumnsWithTypeAndName ExecutorTest::executeStreams(const std::shared_ptr & request, size_t concurrency) -{ - return executeStreams(request, context.executorIdColumnsMap(), concurrency); -} - -DB::ColumnsWithTypeAndName ExecutorTest::executeStreamsWithSingleSource(const std::shared_ptr & request, const ColumnsWithTypeAndName & source_columns, SourceType type, size_t concurrency) -{ - std::unordered_map source_columns_map; - source_columns_map[getSourceName(type)] = source_columns; - return executeStreams(request, source_columns_map, concurrency); -} - -DB::ColumnsWithTypeAndName ExecutorTest::executeMPPTasks(QueryTasks & tasks) +DB::ColumnsWithTypeAndName ExecutorTest::executeMPPTasks(QueryTasks & tasks, const DAGProperties & properties, std::unordered_map & server_config_map) { - DAGProperties properties; - // enable mpp - properties.is_mpp_query = true; - context.context.setMPPTest(); - auto res = executeMPPQuery(context.context, properties, tasks); + auto res = executeMPPQuery(context.context, properties, tasks, server_config_map); return readBlock(res); } diff --git a/dbms/src/TestUtils/ExecutorTestUtils.h b/dbms/src/TestUtils/ExecutorTestUtils.h index 48f2cc68513..9ca73216464 100644 --- a/dbms/src/TestUtils/ExecutorTestUtils.h +++ b/dbms/src/TestUtils/ExecutorTestUtils.h @@ -52,6 +52,7 @@ class ExecutorTest : public ::testing::Test ExecutorTest() : context(TiFlashTestEnv::getContext()) {} + static void SetUpTestCase(); virtual void initializeContext(); @@ -90,20 +91,9 @@ class ExecutorTest : public ::testing::Test ColumnsWithTypeAndName executeStreams( const std::shared_ptr & request, - std::unordered_map & source_columns_map, - size_t concurrency = 1); - - ColumnsWithTypeAndName executeStreams( - const std::shared_ptr & request, - size_t concurrency = 1); - - ColumnsWithTypeAndName executeStreamsWithSingleSource( - const std::shared_ptr & request, - const ColumnsWithTypeAndName & source_columns, - SourceType type = TableScan, size_t concurrency = 1); - ColumnsWithTypeAndName executeMPPTasks(QueryTasks & tasks); + ColumnsWithTypeAndName executeMPPTasks(QueryTasks & tasks, const DAGProperties & properties, std::unordered_map & server_config_map); protected: MockDAGRequestContext context; @@ -112,9 +102,5 @@ class ExecutorTest : public ::testing::Test #define ASSERT_DAGREQUEST_EQAUL(str, request) dagRequestEqual((str), (request)); #define ASSERT_BLOCKINPUTSTREAM_EQAUL(str, request, concurrency) executeInterpreter((str), (request), (concurrency)) -#define ASSERT_MPPTASK_EQUAL(tasks, expect_cols) \ - TiFlashTestEnv::getGlobalContext().setColumnsForTest(context.executorIdColumnsMap()); \ - TiFlashTestEnv::getGlobalContext().setMPPTest(); \ - ASSERT_COLUMNS_EQ_UR(executeMPPTasks(tasks), expected_cols); } // namespace DB::tests diff --git a/dbms/src/TestUtils/MPPTaskTestUtils.h b/dbms/src/TestUtils/MPPTaskTestUtils.h index 9e710c6d00f..ec81a1c8b52 100644 --- a/dbms/src/TestUtils/MPPTaskTestUtils.h +++ b/dbms/src/TestUtils/MPPTaskTestUtils.h @@ -14,11 +14,56 @@ #pragma once +#include +#include #include #include namespace DB::tests { +class MockTimeStampGenerator : public ext::Singleton +{ +public: + Int64 nextTs() + { + return ++current_ts; + } + +private: + std::atomic current_ts = 0; +}; + +class MockServerAddrGenerator : public ext::Singleton +{ +public: + String nextAddr() + { + if (port >= port_upper_bound) + { + throw Exception("Failed to get next server addr"); + } + return "0.0.0.0:" + std::to_string(port++); + } + + void reset() + { + port = 3931; + } + +private: + const Int64 port_upper_bound = 65536; + std::atomic port = 3931; +}; + +DAGProperties getDAGPropertiesForTest(int server_num) +{ + DAGProperties properties; + // enable mpp + properties.is_mpp_query = true; + properties.mpp_partition_num = server_num; + properties.start_ts = MockTimeStampGenerator::instance().nextTs(); + return properties; +} class MPPTaskTestUtils : public ExecutorTest { @@ -26,38 +71,80 @@ class MPPTaskTestUtils : public ExecutorTest static void SetUpTestCase() { ExecutorTest::SetUpTestCase(); - TiFlashSecurityConfig security_config; - TiFlashRaftConfig raft_config; - raft_config.flash_server_addr = "0.0.0.0:3930"; // TODO:: each FlashGrpcServer should have unique addr. - Poco::AutoPtr config = new Poco::Util::LayeredConfiguration; log_ptr = Logger::get("compute_test"); - compute_server_ptr = std::make_unique(TiFlashTestEnv::getGlobalContext(), *config, security_config, raft_config, log_ptr); + server_num = 1; + } + + static void TearDownTestCase() // NOLINT(readability-identifier-naming)) + { + MockComputeServerManager::instance().reset(); + } + + void startServers() + { + startServers(server_num); + } + + void startServers(size_t server_num_) + { + server_num = server_num_; + MockComputeServerManager::instance().reset(); + auto size = std::thread::hardware_concurrency(); + GRPCCompletionQueuePool::global_instance = std::make_unique(size); + for (size_t i = 0; i < server_num; ++i) + { + MockComputeServerManager::instance().addServer(MockServerAddrGenerator::instance().nextAddr()); + } + MockComputeServerManager::instance().startServers(log_ptr, TiFlashTestEnv::getGlobalContext()); + MockServerAddrGenerator::instance().reset(); } - static void TearDownTestCase() + size_t serverNum() const { - compute_server_ptr.reset(); + return server_num; } protected: - // TODO: Mock a simple storage layer to store test input. - // Currently the lifetime of a server is held in this scope. - // TODO: Add ComputeServerManager to maintain the lifetime of a bunch of servers. - // Note: May go through GRPC fail number 14 --> socket closed, - // if you start a server, send a request to the server using pingcap::kv::RpcClient, - // then close the server and start the server using the same addr, - // then send a request to the new server using pingcap::kv::RpcClient. - static std::unique_ptr compute_server_ptr; static LoggerPtr log_ptr; + static size_t server_num; }; -std::unique_ptr MPPTaskTestUtils::compute_server_ptr = nullptr; LoggerPtr MPPTaskTestUtils::log_ptr = nullptr; +size_t MPPTaskTestUtils::server_num = 0; +#define ASSERT_MPPTASK_EQUAL(tasks, properties, expect_cols) \ + do \ + { \ + TiFlashTestEnv::getGlobalContext().setMPPTest(); \ + MockComputeServerManager::instance().setMockStorage(context.mockStorage()); \ + ASSERT_COLUMNS_EQ_UR(executeMPPTasks(tasks, properties, MockComputeServerManager::instance().getServerConfigMap()), expected_cols); \ + } while (0) -#define ASSERT_MPPTASK_EQUAL(tasks, expect_cols) \ - TiFlashTestEnv::getGlobalContext().setColumnsForTest(context.executorIdColumnsMap()); \ - TiFlashTestEnv::getGlobalContext().setMPPTest(); \ - ASSERT_COLUMNS_EQ_UR(executeMPPTasks(tasks), expected_cols); +#define ASSERT_MPPTASK_EQUAL_WITH_SERVER_NUM(builder, properties, expect_cols) \ + do \ + { \ + for (size_t i = 1; i <= serverNum(); ++i) \ + { \ + (properties).mpp_partition_num = i; \ + MockComputeServerManager::instance().resetMockMPPServerInfo(i); \ + auto tasks = (builder).buildMPPTasks(context, properties); \ + ASSERT_MPPTASK_EQUAL(tasks, properties, expect_cols); \ + } \ + } while (0) +#define ASSERT_MPPTASK_EQUAL_PLAN_AND_RESULT(builder, expected_strings, expected_cols) \ + do \ + { \ + auto properties = getDAGPropertiesForTest(serverNum()); \ + auto tasks = (builder).buildMPPTasks(context, properties); \ + size_t task_size = tasks.size(); \ + for (size_t i = 0; i < task_size; ++i) \ + { \ + ASSERT_DAGREQUEST_EQAUL((expected_strings)[i], tasks[i].dag_request); \ + } \ + ASSERT_MPPTASK_EQUAL_WITH_SERVER_NUM( \ + builder, \ + properties, \ + expect_cols); \ + } while (0) } // namespace DB::tests diff --git a/dbms/src/TestUtils/mockExecutor.cpp b/dbms/src/TestUtils/mockExecutor.cpp index 2b12eeb9c18..960c686ae8b 100644 --- a/dbms/src/TestUtils/mockExecutor.cpp +++ b/dbms/src/TestUtils/mockExecutor.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include @@ -108,24 +109,35 @@ void columnPrune(ExecutorPtr executor) // Split a DAGRequest into multiple QueryTasks which can be dispatched to multiple Compute nodes. // Currently we don't support window functions -// and MPPTask with multiple partitions. +QueryTasks DAGRequestBuilder::buildMPPTasks(MockDAGRequestContext & mock_context, const DAGProperties & properties) +{ + columnPrune(root); + mock_context.context.setMPPTest(); + auto query_tasks = queryPlanToQueryTasks(properties, root, executor_index, mock_context.context); + root.reset(); + executor_index = 0; + return query_tasks; +} + QueryTasks DAGRequestBuilder::buildMPPTasks(MockDAGRequestContext & mock_context) { columnPrune(root); - // enable mpp + DAGProperties properties; properties.is_mpp_query = true; - // TODO find a way to record service info. + properties.mpp_partition_num = 1; + mock_context.context.setMPPTest(); auto query_tasks = queryPlanToQueryTasks(properties, root, executor_index, mock_context.context); root.reset(); executor_index = 0; return query_tasks; } -DAGRequestBuilder & DAGRequestBuilder::mockTable(const String & db, const String & table, const MockColumnInfoVec & columns) +DAGRequestBuilder & DAGRequestBuilder::mockTable(const String & db, const String & table, Int64 table_id, const MockColumnInfoVec & columns) { assert(!columns.empty()); TableInfo table_info; table_info.name = db + "." + table; + table_info.id = table_id; int i = 0; for (const auto & column : columns) { @@ -142,9 +154,9 @@ DAGRequestBuilder & DAGRequestBuilder::mockTable(const String & db, const String return *this; } -DAGRequestBuilder & DAGRequestBuilder::mockTable(const MockTableName & name, const MockColumnInfoVec & columns) +DAGRequestBuilder & DAGRequestBuilder::mockTable(const MockTableName & name, Int64 table_id, const MockColumnInfoVec & columns) { - return mockTable(name.first, name.second, columns); + return mockTable(name.first, name.second, table_id, columns); } DAGRequestBuilder & DAGRequestBuilder::exchangeReceiver(const MockColumnInfoVec & columns, uint64_t fine_grained_shuffle_stream_count) @@ -333,32 +345,32 @@ DAGRequestBuilder & DAGRequestBuilder::sort(MockOrderByItemVec order_by_vec, boo void MockDAGRequestContext::addMockTable(const String & db, const String & table, const MockColumnInfoVec & columnInfos) { - mock_tables[db + "." + table] = columnInfos; + mock_storage.addTableSchema(db + "." + table, columnInfos); } void MockDAGRequestContext::addMockTable(const MockTableName & name, const MockColumnInfoVec & columnInfos) { - mock_tables[name.first + "." + name.second] = columnInfos; + mock_storage.addTableSchema(name.first + "." + name.second, columnInfos); } void MockDAGRequestContext::addExchangeRelationSchema(String name, const MockColumnInfoVec & columnInfos) { - exchange_schemas[name] = columnInfos; + mock_storage.addExchangeSchema(name, columnInfos); } void MockDAGRequestContext::addMockTableColumnData(const String & db, const String & table, ColumnsWithTypeAndName columns) { - mock_table_columns[db + "." + table] = columns; + mock_storage.addTableData(db + "." + table, columns); } void MockDAGRequestContext::addMockTableColumnData(const MockTableName & name, ColumnsWithTypeAndName columns) { - mock_table_columns[name.first + "." + name.second] = columns; + mock_storage.addTableData(name.first + "." + name.second, columns); } void MockDAGRequestContext::addExchangeReceiverColumnData(const String & name, ColumnsWithTypeAndName columns) { - mock_exchange_columns[name] = columns; + mock_storage.addExchangeData(name, columns); } void MockDAGRequestContext::addMockTable(const String & db, const String & table, const MockColumnInfoVec & columnInfos, ColumnsWithTypeAndName columns) @@ -379,28 +391,17 @@ void MockDAGRequestContext::addExchangeReceiver(const String & name, MockColumnI addExchangeReceiverColumnData(name, columns); } -DAGRequestBuilder MockDAGRequestContext::scan(String db_name, String table_name) +DAGRequestBuilder MockDAGRequestContext::scan(const String & db_name, const String & table_name) { - auto builder = DAGRequestBuilder(index, collation).mockTable({db_name, table_name}, mock_tables[db_name + "." + table_name]); - // If don't have related columns, user must pass input columns as argument of executeStreams in order to run Executors Tests. - // If user don't want to test executors, it will be safe to run Interpreter Tests. - if (mock_table_columns.find(db_name + "." + table_name) != mock_table_columns.end()) - { - executor_id_columns_map[builder.getRoot()->name] = mock_table_columns[db_name + "." + table_name]; - } - return builder; + auto table_id = mock_storage.getTableId(db_name + "." + table_name); + return DAGRequestBuilder(index, collation).mockTable({db_name, table_name}, table_id, mock_storage.getTableSchema(db_name + "." + table_name)); } -DAGRequestBuilder MockDAGRequestContext::receive(String exchange_name, uint64_t fine_grained_shuffle_stream_count) +DAGRequestBuilder MockDAGRequestContext::receive(const String & exchange_name, uint64_t fine_grained_shuffle_stream_count) { - auto builder = DAGRequestBuilder(index, collation).exchangeReceiver(exchange_schemas[exchange_name], fine_grained_shuffle_stream_count); + auto builder = DAGRequestBuilder(index, collation).exchangeReceiver(mock_storage.getExchangeSchema(exchange_name), fine_grained_shuffle_stream_count); receiver_source_task_ids_map[builder.getRoot()->name] = {}; - // If don't have related columns, user must pass input columns as argument of executeStreams in order to run Executors Tests. - // If user don't want to test executors, it will be safe to run Interpreter Tests. - if (mock_exchange_columns.find(exchange_name) != mock_exchange_columns.end()) - { - executor_id_columns_map[builder.getRoot()->name] = mock_exchange_columns[exchange_name]; - } + mock_storage.addExchangeRelation(builder.getRoot()->name, exchange_name); return builder; } } // namespace DB::tests diff --git a/dbms/src/TestUtils/mockExecutor.h b/dbms/src/TestUtils/mockExecutor.h index 84ddcc29802..58d9d1806ab 100644 --- a/dbms/src/TestUtils/mockExecutor.h +++ b/dbms/src/TestUtils/mockExecutor.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -71,9 +72,10 @@ class DAGRequestBuilder std::shared_ptr build(MockDAGRequestContext & mock_context); QueryTasks buildMPPTasks(MockDAGRequestContext & mock_context); + QueryTasks buildMPPTasks(MockDAGRequestContext & mock_context, const DAGProperties & properties); - DAGRequestBuilder & mockTable(const String & db, const String & table, const MockColumnInfoVec & columns); - DAGRequestBuilder & mockTable(const MockTableName & name, const MockColumnInfoVec & columns); + DAGRequestBuilder & mockTable(const String & db, const String & table, Int64 table_id, const MockColumnInfoVec & columns); + DAGRequestBuilder & mockTable(const MockTableName & name, Int64 table_id, const MockColumnInfoVec & columns); DAGRequestBuilder & exchangeReceiver(const MockColumnInfoVec & columns, uint64_t fine_grained_shuffle_stream_count = 0); @@ -162,21 +164,17 @@ class MockDAGRequestContext void addExchangeReceiverColumnData(const String & name, ColumnsWithTypeAndName columns); void addExchangeReceiver(const String & name, MockColumnInfoVec columnInfos, ColumnsWithTypeAndName columns); - std::unordered_map & executorIdColumnsMap() { return executor_id_columns_map; } - - DAGRequestBuilder scan(String db_name, String table_name); - DAGRequestBuilder receive(String exchange_name, uint64_t fine_grained_shuffle_stream_count = 0); + DAGRequestBuilder scan(const String & db_name, const String & table_name); + DAGRequestBuilder receive(const String & exchange_name, uint64_t fine_grained_shuffle_stream_count = 0); void setCollation(Int32 collation_) { collation = convertToTiDBCollation(collation_); } Int32 getCollation() const { return abs(collation); } + MockStorage & mockStorage() { return mock_storage; } + private: size_t index; - std::unordered_map mock_tables; - std::unordered_map exchange_schemas; - std::unordered_map mock_table_columns; - std::unordered_map mock_exchange_columns; - std::unordered_map executor_id_columns_map; /// + MockStorage mock_storage; public: // Currently don't support task_id, so the following to structure is useless, diff --git a/libs/libdaemon/src/tests/CMakeLists.txt b/libs/libdaemon/src/tests/CMakeLists.txt index 3fbbcc282e2..2cf5eb74dd4 100644 --- a/libs/libdaemon/src/tests/CMakeLists.txt +++ b/libs/libdaemon/src/tests/CMakeLists.txt @@ -15,7 +15,7 @@ include (${TiFlash_SOURCE_DIR}/cmake/add_check.cmake) add_executable (gtests_libdaemon EXCLUDE_FROM_ALL gtest_daemon_config.cpp) -target_link_libraries (gtests_libdaemon gtest_main daemon) +target_link_libraries (gtests_libdaemon gtest_main daemon tipb) #add for libcctz used by BaseDaemon if (APPLE) set_target_properties(gtests_libdaemon PROPERTIES LINK_FLAGS "-framework CoreFoundation")