diff --git a/clickhouse/base/wire_format.cpp b/clickhouse/base/wire_format.cpp index 62a21833..b3ca7c98 100644 --- a/clickhouse/base/wire_format.cpp +++ b/clickhouse/base/wire_format.cpp @@ -1,3 +1,4 @@ +#include #include "wire_format.h" #include "input.h" @@ -99,4 +100,81 @@ bool WireFormat::SkipString(InputStream& input) { return false; } +inline const char* find_quoted_chars(const char* start, const char* end) { + static const char quoted_chars[] = {'\0', '\b', '\t', '\n', '\'', '\\'}; + while (start < end) { + char c = *start; + for (unsigned i = 0; i < sizeof(quoted_chars); i++) { + if (quoted_chars[i] == c) return start; + } + start++; + } + return nullptr; +} + +void WireFormat::WriteQuotedString(OutputStream& output, std::string_view value) { + auto size = value.size(); + const char* start = value.data(); + const char* end = start + size; + const char* quoted_char = find_quoted_chars(start, end); + if (quoted_char == nullptr) { + WriteVarint64(output, size + 2); + WriteAll(output, "'", 1); + WriteAll(output, start, size); + WriteAll(output, "'", 1); + return; + } + + // calculate quoted chars count + int quoted_count = 1; + const char* next_quoted_char = quoted_char + 1; + while ((next_quoted_char = find_quoted_chars(next_quoted_char, end))) { + quoted_count++; + next_quoted_char++; + } + WriteVarint64(output, size + 2 + 3 * quoted_count); // length + + WriteAll(output, "'", 1); + + do { + auto write_size = quoted_char - start; + WriteAll(output, start, write_size); + WriteAll(output, "\\", 1); + char c = quoted_char[0]; + switch (c) { + case '\0': + WriteAll(output, "x00", 3); + break; + case '\b': + WriteAll(output, "x08", 3); + break; + case '\t': + WriteAll(output, R"(\\t)", 3); + break; + case '\n': + WriteAll(output, R"(\\n)", 3); + break; + case '\'': + WriteAll(output, "x27", 3); + break; + case '\\': + WriteAll(output, R"(\\\)", 3); + break; + default: + break; + } + start = quoted_char + 1; + quoted_char = find_quoted_chars(start, end); + } while (quoted_char); + + WriteAll(output, start, end - start); + WriteAll(output, "'", 1); } + +void WireFormat::WriteParamNullRepresentation(OutputStream& output) { + const std::string NULL_REPRESENTATION(R"('\\N')"); + WriteVarint64(output, NULL_REPRESENTATION.size()); + WriteAll(output, NULL_REPRESENTATION.data(), NULL_REPRESENTATION.size()); +} + +} // namespace clickhouse diff --git a/clickhouse/base/wire_format.h b/clickhouse/base/wire_format.h index 6ff53528..be58d7a9 100644 --- a/clickhouse/base/wire_format.h +++ b/clickhouse/base/wire_format.h @@ -22,6 +22,8 @@ class WireFormat { static void WriteFixed(OutputStream& output, const T& value); static void WriteBytes(OutputStream& output, const void* buf, size_t len); static void WriteString(OutputStream& output, std::string_view value); + static void WriteQuotedString(OutputStream& output, std::string_view value); + static void WriteParamNullRepresentation(OutputStream& output); static void WriteUInt64(OutputStream& output, const uint64_t value); static void WriteVarint64(OutputStream& output, uint64_t value); diff --git a/clickhouse/client.cpp b/clickhouse/client.cpp index 01ee70b4..a8ad3400 100644 --- a/clickhouse/client.cpp +++ b/clickhouse/client.cpp @@ -38,8 +38,13 @@ #define DBMS_MIN_REVISION_WITH_DISTRIBUTED_DEPTH 54448 #define DBMS_MIN_REVISION_WITH_INITIAL_QUERY_START_TIME 54449 #define DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS 54451 +#define DBMS_MIN_REVISION_WITH_PARALLEL_REPLICAS 54453 +#define DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION 54454 // Client can get some fields in JSon format +#define DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM 54458 // send quota key after handshake +#define DBMS_MIN_PROTOCOL_REVISION_WITH_QUOTA_KEY 54458 // the same +#define DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS 54459 -#define DMBS_PROTOCOL_REVISION DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS +#define DMBS_PROTOCOL_REVISION DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS namespace clickhouse { @@ -433,6 +438,11 @@ bool Client::Impl::Handshake() { if (!ReceiveHello()) { return false; } + + if (server_info_.revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM) { + WireFormat::WriteString(*output_, std::string()); + } + return true; } @@ -502,7 +512,7 @@ bool Client::Impl::ReceivePacket(uint64_t* server_packet) { return false; } } - if constexpr (DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO) + if (server_info_.revision >= DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO) { if (!WireFormat::ReadUInt64(*input_, &info.written_rows)) { return false; @@ -589,7 +599,7 @@ bool Client::Impl::ReceivePacket(uint64_t* server_packet) { bool Client::Impl::ReadBlock(InputStream& input, Block* block) { // Additional information about block. - if constexpr (DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) { + if (server_info_.revision >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) { uint64_t num; BlockInfo info; @@ -635,6 +645,16 @@ bool Client::Impl::ReadBlock(InputStream& input, Block* block) { if (!WireFormat::ReadString(input, &type)) { return false; } + + if (server_info_.revision >= DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION) { + uint8_t custom_format_len; + if (!WireFormat::ReadFixed(input, &custom_format_len)) { + return false; + } + if (custom_format_len > 0) { + throw UnimplementedError(std::string("unsupported custom serialization")); + } + } if (ColumnRef col = CreateColumnByType(type, create_column_settings)) { if (num_rows && !col->Load(&input, num_rows)) { @@ -653,7 +673,7 @@ bool Client::Impl::ReadBlock(InputStream& input, Block* block) { bool Client::Impl::ReceiveData() { Block block; - if constexpr (DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) { + if (server_info_.revision >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) { if (!WireFormat::SkipString(*input_)) { return false; } @@ -793,6 +813,12 @@ void Client::Impl::SendQuery(const Query& query) { throw UnimplementedError(std::string("Can't send open telemetry tracing context to a server, server version is too old")); } } + if (server_info_.revision >= DBMS_MIN_REVISION_WITH_PARALLEL_REPLICAS) { + // replica dont supported by client + WireFormat::WriteUInt64(*output_, 0); + WireFormat::WriteUInt64(*output_, 0); + WireFormat::WriteUInt64(*output_, 0); + } } /// Per query settings @@ -817,6 +843,22 @@ void Client::Impl::SendQuery(const Query& query) { WireFormat::WriteUInt64(*output_, Stages::Complete); WireFormat::WriteUInt64(*output_, compression_); WireFormat::WriteString(*output_, query.GetText()); + + //Send params after query text + if (server_info_.revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS) { + for(const auto& [name, value] : query.GetParams()) { + // params is like query settings + WireFormat::WriteString(*output_, name); + const uint64_t Custom = 2; + WireFormat::WriteVarint64(*output_, Custom); + if (value) + WireFormat::WriteQuotedString(*output_, *value); + else + WireFormat::WriteParamNullRepresentation(*output_); + } + WireFormat::WriteString(*output_, std::string()); // empty string after last param + } + // Send empty block as marker of // end of data SendData(Block()); @@ -842,6 +884,11 @@ void Client::Impl::WriteBlock(const Block& block, OutputStream& output) { WireFormat::WriteString(output, bi.Name()); WireFormat::WriteString(output, bi.Type()->GetName()); + if (server_info_.revision >= DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION) { + // TODO: custom serialization + WireFormat::WriteFixed(output, 0); + } + // Empty columns are not serialized and occupy exactly 0 bytes. // ref https://github.com/ClickHouse/ClickHouse/blob/39b37a3240f74f4871c8c1679910e065af6bea19/src/Formats/NativeWriter.cpp#L163 const bool containsData = block.GetRowCount() > 0; diff --git a/clickhouse/query.h b/clickhouse/query.h index b6551803..0dc82dd1 100644 --- a/clickhouse/query.h +++ b/clickhouse/query.h @@ -26,6 +26,8 @@ struct QuerySettingsField { }; using QuerySettings = std::unordered_map; +using QueryParamValue = std::optional; +using QueryParams = std::unordered_map; struct Profile { uint64_t rows = 0; @@ -115,6 +117,18 @@ class Query : public QueryEvents { return *this; } + inline const QueryParams& GetParams() const { return query_params_; } + + inline Query& SetParams(QueryParams query_params) { + query_params_ = std::move(query_params); + return *this; + } + + inline Query& SetParam(const std::string& name, const QueryParamValue& value) { + query_params_[name] = value; + return *this; + } + inline const std::optional& GetTracingContext() const { return tracing_context_; } @@ -219,6 +233,7 @@ class Query : public QueryEvents { const std::string query_id_; std::optional tracing_context_; QuerySettings query_settings_; + QueryParams query_params_; ExceptionCallback exception_cb_; ProgressCallback progress_cb_; SelectCallback select_cb_; diff --git a/tests/simple/main.cpp b/tests/simple/main.cpp index 2911e559..aa45bb11 100644 --- a/tests/simple/main.cpp +++ b/tests/simple/main.cpp @@ -234,6 +234,81 @@ inline void GenericExample(Client& client) { client.Execute("DROP TEMPORARY TABLE test_client"); } +inline void ParamExample(Client& client) { + /// Create a table. + client.Execute("CREATE TEMPORARY TABLE IF NOT EXISTS test_client (id UInt64, name String)"); + + { + Query query("insert into test_client values ({id: UInt64}, {name: String})"); + + query.SetParam("id", "1").SetParam("name", "NAME"); + client.Execute(query); + + query.SetParam("id", "123").SetParam("name", "FromParam"); + client.Execute(query); + + const char FirstPrintable = ' '; + char test_str1[FirstPrintable * 2 + 1]; + for (unsigned int i = 0; i < FirstPrintable; i++) { + test_str1[i * 2] = 'A'; + test_str1[i * 2 + 1] = i; + } + test_str1[int(FirstPrintable * 2)] = 'A'; + + query.SetParam("id", "333").SetParam("name", std::string(test_str1, FirstPrintable * 2 + 1)); + client.Execute(query); + + const char LastPrintable = 127; + unsigned char big_string[LastPrintable - FirstPrintable]; + for (unsigned int i = 0; i < sizeof(big_string); i++) big_string[i] = i + FirstPrintable; + query.SetParam("id", "444").SetParam("name", std::string((char*)big_string, sizeof(big_string))); + client.Execute(query); + + query.SetParam("id", "555") + .SetParam("name", "utf8Русский"); + client.Execute(query); + } + + /// Select values inserted in the previous step. + Query query ("SELECT id, name, length(name) FROM test_client where id > {a: Int32}"); + query.SetParam("a", "4"); + SelectCallback cb([](const Block& block) + { + std::cout << PrettyPrintBlock{block} << std::endl; + }); + query.OnData(cb); + client.Select(query); + /// Delete table. + client.Execute("DROP TEMPORARY TABLE test_client"); +} + +inline void ParamNullExample(Client& client) { + client.Execute("CREATE TEMPORARY TABLE IF NOT EXISTS test_client (id UInt64, name Nullable(String))"); + + Query query("insert into test_client values ({id: UInt64}, {name: Nullable(String)})"); + + query.SetParam("id", "123").SetParam("name", QueryParamValue()); + client.Execute(query); + + query.SetParam("id", "456").SetParam("name", "String Value"); + client.Execute(query); + + client.Select("SELECT id, name FROM test_client", [](const Block& block) { + for (size_t c = 0; c < block.GetRowCount(); ++c) { + std::cerr << block[0]->As()->At(c) << " "; + + auto col_string = block[1]->As(); + if (col_string->IsNull(c)) { + std::cerr << "\\N\n"; + } else { + std::cerr << col_string->Nested()->As()->At(c) << "\n"; + } + } + }); + + client.Execute("DROP TEMPORARY TABLE test_client"); +} + inline void NullableExample(Client& client) { /// Create a table. client.Execute("CREATE TEMPORARY TABLE IF NOT EXISTS test_client (id Nullable(UInt64), date Nullable(Date))"); @@ -478,6 +553,8 @@ inline void IPExample(Client &client) { } static void RunTests(Client& client) { + ParamExample(client); + ParamNullExample(client); ArrayExample(client); CancelableExample(client); DateExample(client); diff --git a/ut/client_ut.cpp b/ut/client_ut.cpp index 3e08b17d..43b8a104 100644 --- a/ut/client_ut.cpp +++ b/ut/client_ut.cpp @@ -1488,3 +1488,54 @@ TEST(SimpleClientTest, issue_335_reconnects_count) { << "\tThere was no attempt to connect to endpoint " << endpoint; } } + +TEST_P(ClientCase, QueryParameters) { + const auto & server_info = client_->GetServerInfo(); + if (versionNumber(server_info) < versionNumber(24, 7)) { + GTEST_SKIP() << "Test is skipped since server '" << server_info << "' does not support query parameters" << std::endl; + } + const std::string table_name = "test_clickhouse_cpp_query_parameter"; + client_->Execute("CREATE TEMPORARY TABLE IF NOT EXISTS " + table_name + " (id UInt64, name String)"); + { + Query query("insert into " + table_name + " values ({id: UInt64}, {name: String})"); + + query.SetParam("id", "1").SetParam("name", "NAME"); + client_->Execute(query); + + query.SetParam("id", "123").SetParam("name", "FromParam"); + client_->Execute(query); + + const char FirstPrintable = ' '; + char test_str1[FirstPrintable * 2 + 1]; + for (unsigned int i = 0; i < FirstPrintable; i++) { + test_str1[i * 2] = 'A'; + test_str1[i * 2 + 1] = i; + } + test_str1[int(FirstPrintable * 2)] = 'A'; + + query.SetParam("id", "333").SetParam("name", std::string(test_str1, FirstPrintable * 2 + 1)); + client_->Execute(query); + + const char LastPrintable = 127; + unsigned char big_string[LastPrintable - FirstPrintable]; + for (unsigned int i = 0; i < sizeof(big_string); i++) big_string[i] = i + FirstPrintable; + query.SetParam("id", "444").SetParam("name", std::string((char*)big_string, sizeof(big_string))); + client_->Execute(query); + + query.SetParam("id", "555").SetParam("name", "utf8Русский"); + client_->Execute(query); + } + + Query query("SELECT id, name, length(name) FROM " + table_name + " where id > {a: Int32}"); + query.SetParam("a", "4"); + size_t total_count = 0; + SelectCallback cb([&total_count](const Block& block) { + total_count += block.GetRowCount(); + //std::cout << PrettyPrintBlock{block} << std::endl; + }); + query.OnData(cb); + client_->Select(query); + EXPECT_EQ(4u, total_count); + + client_->Execute("DROP TEMPORARY TABLE " + table_name); +}