Parameters in clickhouse-cpp #394

Open
wants to merge 7 commits into
base: master
78 changes: 78 additions & 0 deletions clickhouse/base/wire_format.cpp
@@ -1,3 +1,4 @@
#include <assert.h>
#include "wire_format.h"

#include "input.h"
@@ -99,4 +100,81 @@ bool WireFormat::SkipString(InputStream& input) {
return false;
}

inline const char* find_quoted_chars(const char* start, const char* end) {
static const char quoted_chars[] = {'\0', '\b', '\t', '\n', '\'', '\\'};
while (start < end) {
Collaborator

Why not std::find_first_of?

Author

@OlegGalizin OlegGalizin Oct 30, 2024

It is more comfortable for me.
I think std::find_first_of needs one extra comparison because it returns 'end' on failure.
Or, after the call, we need to compare with 'end', which is worse than comparing with 0.

Collaborator

Please use functions from the standard library whenever possible, instead of inventing new ones with 99% overlapping functionality. "Comfortable for me" is not a valid excuse...

Author

@OlegGalizin OlegGalizin Nov 1, 2024

What about the extra compare instruction slowing down the solution?
I want to be sure that you understand what you are suggesting.

See, for example, find_first_symbols_sse2 in the main ClickHouse repo, which also uses a double loop and returns nullptr.

Collaborator

@Enmk Enmk Nov 1, 2024

First, find_first_symbols_sse2 was introduced for use in very hot pieces of code, which I believe is not the case here.
Second, there is a very high chance that the compiler can optimize a call to std::find_first_of (e.g. using SSE) way beyond anything you are trying to gain here.
And last, find_first_symbols_sse2 supports two modes: "return nullptr" and "return end" 😉.

This looks like you are trying to prematurely optimize O(N + 1) into O(N) by introducing something that would unnecessarily complicate support.
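
For reference, the standard-library variant discussed in this thread could look roughly like the sketch below. The names `FindQuotedChar` and `CountQuotedChars` are hypothetical and not part of the PR; the character set is copied from `quoted_chars` above. The sketch returns `end` rather than `nullptr` when nothing is found, so callers compare against `end`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <string_view>

// Hypothetical helper (not from the PR): returns a pointer to the first
// character that needs escaping, or `end` if there is none.
inline const char* FindQuotedChar(const char* start, const char* end) {
    static constexpr char kQuotedChars[] = {'\0', '\b', '\t', '\n', '\'', '\\'};
    assert(start <= end);
    return std::find_first_of(start, end, std::begin(kQuotedChars), std::end(kQuotedChars));
}

// Hypothetical helper: counts the characters that would need escaping,
// which is what the PR needs in order to precompute the output length.
inline std::size_t CountQuotedChars(std::string_view value) {
    const char* end = value.data() + value.size();
    std::size_t count = 0;
    for (const char* p = FindQuotedChar(value.data(), end); p != end;
         p = FindQuotedChar(p + 1, end)) {
        ++count;
    }
    return count;
}
```

The only change on the caller side is the loop condition: `p != end` instead of `p != nullptr`.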

char c = *start;
for (unsigned i = 0; i < sizeof(quoted_chars); i++) {
if (quoted_chars[i] == c) return start;
}
start++;
}
return nullptr;
}

void WireFormat::WriteQuotedString(OutputStream& output, std::string_view value) {
auto size = value.size();
const char* start = value.data();
const char* end = start + size;
const char* quoted_char = find_quoted_chars(start, end);
if (quoted_char == nullptr) {
WriteVarint64(output, size + 2);
WriteAll(output, "'", 1);
WriteAll(output, start, size);
WriteAll(output, "'", 1);
return;
}

// calculate quoted chars count
int quoted_count = 1;
const char* next_quoted_char = quoted_char + 1;
while ((next_quoted_char = find_quoted_chars(next_quoted_char, end))) {
quoted_count++;
next_quoted_char++;
}
WriteVarint64(output, size + 2 + 3 * quoted_count); // length

WriteAll(output, "'", 1);

do {
auto write_size = quoted_char - start;
WriteAll(output, start, write_size);
WriteAll(output, "\\", 1);
char c = quoted_char[0];
switch (c) {
case '\0':
WriteAll(output, "x00", 3);
break;
case '\b':
WriteAll(output, "x08", 3);
break;
case '\t':
WriteAll(output, R"(\\t)", 3);
break;
case '\n':
WriteAll(output, R"(\\n)", 3);
break;
case '\'':
WriteAll(output, "x27", 3);
break;
case '\\':
WriteAll(output, R"(\\\)", 3);
break;
default:
break;
}
start = quoted_char + 1;
quoted_char = find_quoted_chars(start, end);
} while (quoted_char);

WriteAll(output, start, end - start);
WriteAll(output, "'", 1);
}

void WireFormat::WriteParamNullRepresentation(OutputStream& output) {
const std::string NULL_REPRESENTATION(R"('\\N')");
WriteVarint64(output, NULL_REPRESENTATION.size());
WriteAll(output, NULL_REPRESENTATION.data(), NULL_REPRESENTATION.size());
}

} // namespace clickhouse
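
To make the escaping in `WriteQuotedString` easier to review, here is a minimal sketch that builds the same quoted text into a `std::string` instead of an `OutputStream`. `EscapeFor` and `QuoteParam` are hypothetical names; the escape table is transcribed from the `switch` in the diff (a leading backslash plus three more bytes), and the varint length prefix is deliberately left out.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <string_view>

// Hypothetical helper mirroring the escape table in the diff.
// Returns the 4-byte replacement for `c`, or an empty view if `c` is copied as-is.
inline std::string_view EscapeFor(char c) {
    switch (c) {
        case '\0': return R"(\x00)";
        case '\b': return R"(\x08)";
        case '\t': return R"(\\\t)";
        case '\n': return R"(\\\n)";
        case '\'': return R"(\x27)";
        case '\\': return R"(\\\\)";
        default:   return {};
    }
}

// Hypothetical helper: returns the quoted parameter text without the length prefix.
inline std::string QuoteParam(std::string_view value) {
    std::string out;
    out.push_back('\'');
    std::size_t escaped = 0;
    for (char c : value) {
        const std::string_view replacement = EscapeFor(c);
        if (replacement.empty()) {
            out.push_back(c);
        } else {
            out.append(replacement);
            ++escaped;
        }
    }
    out.push_back('\'');
    // Same length formula as the diff: two quotes plus three extra bytes
    // for every escaped character.
    assert(out.size() == value.size() + 2 + 3 * escaped);
    return out;
}
```

For example, `QuoteParam("a'b")` yields `'a\x27b'`, and a single backslash becomes four backslashes between the quotes.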
2 changes: 2 additions & 0 deletions clickhouse/base/wire_format.h
@@ -22,6 +22,8 @@ class WireFormat {
static void WriteFixed(OutputStream& output, const T& value);
static void WriteBytes(OutputStream& output, const void* buf, size_t len);
static void WriteString(OutputStream& output, std::string_view value);
static void WriteQuotedString(OutputStream& output, std::string_view value);
static void WriteParamNullRepresentation(OutputStream& output);
static void WriteUInt64(OutputStream& output, const uint64_t value);
static void WriteVarint64(OutputStream& output, uint64_t value);

55 changes: 51 additions & 4 deletions clickhouse/client.cpp
@@ -38,8 +38,13 @@
#define DBMS_MIN_REVISION_WITH_DISTRIBUTED_DEPTH 54448
#define DBMS_MIN_REVISION_WITH_INITIAL_QUERY_START_TIME 54449
#define DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS 54451
#define DBMS_MIN_REVISION_WITH_PARALLEL_REPLICAS 54453
#define DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION 54454 // Client can get some fields in JSON format
#define DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM 54458 // send quota key after handshake
#define DBMS_MIN_PROTOCOL_REVISION_WITH_QUOTA_KEY 54458 // the same
#define DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS 54459

-#define DMBS_PROTOCOL_REVISION DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS
+#define DMBS_PROTOCOL_REVISION DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS

namespace clickhouse {

@@ -433,6 +438,11 @@ bool Client::Impl::Handshake() {
if (!ReceiveHello()) {
return false;
}

if (server_info_.revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM) {
WireFormat::WriteString(*output_, std::string());
}

return true;
}

@@ -502,7 +512,7 @@ bool Client::Impl::ReceivePacket(uint64_t* server_packet) {
return false;
}
}
-if constexpr (DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO)
+if (server_info_.revision >= DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO)
{
if (!WireFormat::ReadUInt64(*input_, &info.written_rows)) {
return false;
@@ -589,7 +599,7 @@ bool Client::Impl::ReceivePacket(uint64_t* server_packet) {

bool Client::Impl::ReadBlock(InputStream& input, Block* block) {
// Additional information about block.
-if constexpr (DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) {
+if (server_info_.revision >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) {
uint64_t num;
BlockInfo info;

@@ -635,6 +645,16 @@ bool Client::Impl::ReadBlock(InputStream& input, Block* block) {
if (!WireFormat::ReadString(input, &type)) {
return false;
}

if (server_info_.revision >= DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION) {
uint8_t custom_format_len;
if (!WireFormat::ReadFixed(input, &custom_format_len)) {
return false;
}
if (custom_format_len > 0) {
throw UnimplementedError(std::string("unsupported custom serialization"));
}
}

if (ColumnRef col = CreateColumnByType(type, create_column_settings)) {
if (num_rows && !col->Load(&input, num_rows)) {
@@ -653,7 +673,7 @@ bool Client::Impl::ReadBlock(InputStream& input, Block* block) {
bool Client::Impl::ReceiveData() {
Block block;

-if constexpr (DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) {
+if (server_info_.revision >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) {
if (!WireFormat::SkipString(*input_)) {
return false;
}
@@ -793,6 +813,12 @@ void Client::Impl::SendQuery(const Query& query) {
throw UnimplementedError(std::string("Can't send open telemetry tracing context to a server, server version is too old"));
}
}
if (server_info_.revision >= DBMS_MIN_REVISION_WITH_PARALLEL_REPLICAS) {
// parallel replicas are not supported by the client
WireFormat::WriteUInt64(*output_, 0);
WireFormat::WriteUInt64(*output_, 0);
WireFormat::WriteUInt64(*output_, 0);
}
}

/// Per query settings
@@ -817,6 +843,22 @@ void Client::Impl::SendQuery(const Query& query) {
WireFormat::WriteUInt64(*output_, Stages::Complete);
WireFormat::WriteUInt64(*output_, compression_);
WireFormat::WriteString(*output_, query.GetText());

// Send params after the query text
if (server_info_.revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS) {
for(const auto& [name, value] : query.GetParams()) {
// params are serialized like query settings
WireFormat::WriteString(*output_, name);
const uint64_t Custom = 2;
WireFormat::WriteVarint64(*output_, Custom);
if (value)
WireFormat::WriteQuotedString(*output_, *value);
else
WireFormat::WriteParamNullRepresentation(*output_);
}
WireFormat::WriteString(*output_, std::string()); // empty string after last param
}

// Send empty block as marker of
// end of data
SendData(Block());
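
The parameter section written by this hunk can be pictured as a flat byte layout: for each parameter a name string, a varint flag (`Custom` = 2), and the quoted value as a string, followed by one empty string as terminator. The sketch below encodes that layout into a `std::string`. All function names are hypothetical, `std::map` is used only to get a deterministic order (the PR uses `std::unordered_map`), and the varint encoding is assumed to be unsigned LEB128 with strings prefixed by their varint length, matching what `WriteVarint64` and `WriteString` are expected to do.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <string_view>

// Assumed varint format: unsigned LEB128, seven payload bits per byte,
// high bit set while more bytes follow.
inline void AppendVarint64(std::string& out, std::uint64_t value) {
    while (value >= 0x80) {
        out.push_back(static_cast<char>((value & 0x7F) | 0x80));
        value >>= 7;
    }
    out.push_back(static_cast<char>(value));
}

// Assumed string format: varint length followed by the raw bytes.
inline void AppendString(std::string& out, std::string_view s) {
    AppendVarint64(out, s.size());
    out.append(s);
}

// Hypothetical helper: values must already be quoted and escaped.
inline std::string EncodeParams(const std::map<std::string, std::string>& quoted_params) {
    constexpr std::uint64_t kCustomFlag = 2;  // same value as `Custom` in the diff
    std::string out;
    for (const auto& [name, quoted_value] : quoted_params) {
        assert(!name.empty());  // an empty name would read as the terminator
        AppendString(out, name);
        AppendVarint64(out, kCustomFlag);
        AppendString(out, quoted_value);
    }
    AppendString(out, {});  // empty string marks the end of the parameter list
    return out;
}
```

Under these assumptions, a single parameter `id` with quoted value `'1'` encodes to nine bytes: `02 'i' 'd' 02 03 ''' '1' ''' 00`.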
@@ -842,6 +884,11 @@ void Client::Impl::WriteBlock(const Block& block, OutputStream& output) {
WireFormat::WriteString(output, bi.Name());
WireFormat::WriteString(output, bi.Type()->GetName());

if (server_info_.revision >= DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION) {
// TODO: custom serialization
WireFormat::WriteFixed<uint8_t>(output, 0);
}

// Empty columns are not serialized and occupy exactly 0 bytes.
// ref https://github.com/ClickHouse/ClickHouse/blob/39b37a3240f74f4871c8c1679910e065af6bea19/src/Formats/NativeWriter.cpp#L163
const bool containsData = block.GetRowCount() > 0;
15 changes: 15 additions & 0 deletions clickhouse/query.h
@@ -26,6 +26,8 @@ struct QuerySettingsField {
};

using QuerySettings = std::unordered_map<std::string, QuerySettingsField>;
using QueryParamValue = std::optional<std::string>;
using QueryParams = std::unordered_map<std::string, QueryParamValue>;

struct Profile {
uint64_t rows = 0;
@@ -115,6 +117,18 @@ class Query : public QueryEvents {
return *this;
}

inline const QueryParams& GetParams() const { return query_params_; }

inline Query& SetParams(QueryParams query_params) {
query_params_ = std::move(query_params);
return *this;
}

inline Query& SetParam(const std::string& name, const QueryParamValue& value) {
query_params_[name] = value;
return *this;
}

inline const std::optional<open_telemetry::TracingContext>& GetTracingContext() const {
return tracing_context_;
}
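
A small stand-alone sketch of how the new parameter storage behaves: `SetParam` overwrites by name, so the same query object can be re-executed with new values, and an empty `QueryParamValue` stands for NULL. `ParamHolder` and `Describe` are hypothetical stand-ins, not part of clickhouse-cpp; only the two type aliases are taken from the diff. The `assert` on the name is an added assumption, motivated by the empty string being used as the list terminator on the wire.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <unordered_map>

using QueryParamValue = std::optional<std::string>;
using QueryParams = std::unordered_map<std::string, QueryParamValue>;

// Hypothetical stand-in for clickhouse::Query, reduced to parameter storage.
class ParamHolder {
public:
    ParamHolder& SetParam(const std::string& name, const QueryParamValue& value) {
        assert(!name.empty());  // assumption: an empty name is reserved as terminator
        params_[name] = value;  // overwrites any previous value for `name`
        return *this;
    }

    const QueryParams& GetParams() const { return params_; }

private:
    QueryParams params_;
};

// Hypothetical helper: describes how a stored value would be treated.
inline std::string Describe(const QueryParamValue& value) {
    return value ? "value: " + *value : std::string("NULL");
}
```

Chaining works as in the PR's tests: `holder.SetParam("id", "1").SetParam("name", QueryParamValue());` stores a string for `id` and NULL for `name`.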
@@ -219,6 +233,7 @@ class Query : public QueryEvents {
const std::string query_id_;
std::optional<open_telemetry::TracingContext> tracing_context_;
QuerySettings query_settings_;
QueryParams query_params_;
ExceptionCallback exception_cb_;
ProgressCallback progress_cb_;
SelectCallback select_cb_;
77 changes: 77 additions & 0 deletions tests/simple/main.cpp
@@ -234,6 +234,81 @@ inline void GenericExample(Client& client) {
client.Execute("DROP TEMPORARY TABLE test_client");
}

inline void ParamExample(Client& client) {
/// Create a table.
client.Execute("CREATE TEMPORARY TABLE IF NOT EXISTS test_client (id UInt64, name String)");

{
Query query("insert into test_client values ({id: UInt64}, {name: String})");

query.SetParam("id", "1").SetParam("name", "NAME");
client.Execute(query);

query.SetParam("id", "123").SetParam("name", "FromParam");
client.Execute(query);

const char FirstPrintable = ' ';
char test_str1[FirstPrintable * 2 + 1];
for (unsigned int i = 0; i < FirstPrintable; i++) {
test_str1[i * 2] = 'A';
test_str1[i * 2 + 1] = i;
}
test_str1[int(FirstPrintable * 2)] = 'A';

query.SetParam("id", "333").SetParam("name", std::string(test_str1, FirstPrintable * 2 + 1));
client.Execute(query);

const char LastPrintable = 127;
unsigned char big_string[LastPrintable - FirstPrintable];
for (unsigned int i = 0; i < sizeof(big_string); i++) big_string[i] = i + FirstPrintable;
query.SetParam("id", "444").SetParam("name", std::string((char*)big_string, sizeof(big_string)));
client.Execute(query);

query.SetParam("id", "555")
.SetParam("name", "utf8Русский");
client.Execute(query);
}

/// Select values inserted in the previous step.
Query query ("SELECT id, name, length(name) FROM test_client where id > {a: Int32}");
query.SetParam("a", "4");
SelectCallback cb([](const Block& block)
{
std::cout << PrettyPrintBlock{block} << std::endl;
});
query.OnData(cb);
client.Select(query);
/// Delete table.
client.Execute("DROP TEMPORARY TABLE test_client");
}

inline void ParamNullExample(Client& client) {
client.Execute("CREATE TEMPORARY TABLE IF NOT EXISTS test_client (id UInt64, name Nullable(String))");

Query query("insert into test_client values ({id: UInt64}, {name: Nullable(String)})");

query.SetParam("id", "123").SetParam("name", QueryParamValue());
client.Execute(query);

query.SetParam("id", "456").SetParam("name", "String Value");
client.Execute(query);

client.Select("SELECT id, name FROM test_client", [](const Block& block) {
for (size_t c = 0; c < block.GetRowCount(); ++c) {
std::cerr << block[0]->As<ColumnUInt64>()->At(c) << " ";

auto col_string = block[1]->As<ColumnNullable>();
if (col_string->IsNull(c)) {
std::cerr << "\\N\n";
} else {
std::cerr << col_string->Nested()->As<ColumnString>()->At(c) << "\n";
}
}
});

client.Execute("DROP TEMPORARY TABLE test_client");
}

inline void NullableExample(Client& client) {
/// Create a table.
client.Execute("CREATE TEMPORARY TABLE IF NOT EXISTS test_client (id Nullable(UInt64), date Nullable(Date))");
@@ -478,6 +553,8 @@ inline void IPExample(Client &client) {
}

static void RunTests(Client& client) {
ParamExample(client);
ParamNullExample(client);
ArrayExample(client);
CancelableExample(client);
DateExample(client);
51 changes: 51 additions & 0 deletions ut/client_ut.cpp
@@ -1488,3 +1488,54 @@ TEST(SimpleClientTest, issue_335_reconnects_count) {
<< "\tThere was no attempt to connect to endpoint " << endpoint;
}
}

TEST_P(ClientCase, QueryParameters) {
const auto & server_info = client_->GetServerInfo();
if (versionNumber(server_info) < versionNumber(24, 7)) {
GTEST_SKIP() << "Test is skipped since server '" << server_info << "' does not support query parameters" << std::endl;
}
const std::string table_name = "test_clickhouse_cpp_query_parameter";
client_->Execute("CREATE TEMPORARY TABLE IF NOT EXISTS " + table_name + " (id UInt64, name String)");
{
Query query("insert into " + table_name + " values ({id: UInt64}, {name: String})");

query.SetParam("id", "1").SetParam("name", "NAME");
client_->Execute(query);

query.SetParam("id", "123").SetParam("name", "FromParam");
client_->Execute(query);

const char FirstPrintable = ' ';
char test_str1[FirstPrintable * 2 + 1];
for (unsigned int i = 0; i < FirstPrintable; i++) {
test_str1[i * 2] = 'A';
test_str1[i * 2 + 1] = i;
}
test_str1[int(FirstPrintable * 2)] = 'A';

query.SetParam("id", "333").SetParam("name", std::string(test_str1, FirstPrintable * 2 + 1));
client_->Execute(query);

const char LastPrintable = 127;
unsigned char big_string[LastPrintable - FirstPrintable];
for (unsigned int i = 0; i < sizeof(big_string); i++) big_string[i] = i + FirstPrintable;
query.SetParam("id", "444").SetParam("name", std::string((char*)big_string, sizeof(big_string)));
client_->Execute(query);

query.SetParam("id", "555").SetParam("name", "utf8Русский");
client_->Execute(query);
}

Query query("SELECT id, name, length(name) FROM " + table_name + " where id > {a: Int32}");
query.SetParam("a", "4");
size_t total_count = 0;
SelectCallback cb([&total_count](const Block& block) {
total_count += block.GetRowCount();
//std::cout << PrettyPrintBlock{block} << std::endl;
});
query.OnData(cb);
client_->Select(query);
EXPECT_EQ(4u, total_count);

client_->Execute("DROP TEMPORARY TABLE " + table_name);
}