diff --git a/src/printf.cpp b/src/printf.cpp index 104c620d..4eb4ff9b 100644 --- a/src/printf.cpp +++ b/src/printf.cpp @@ -43,7 +43,7 @@ std::string get_vector_fmt(std::string fmt, int& vector_size, int& element_size, // Consume precision and field width pos = fmt.find_first_not_of("123456789.", pos); - if (fmt.at(pos) != 'v') { + if (pos == std::string::npos || fmt.at(pos) != 'v') { vector_size = 1; return std::string{fmt}; } @@ -165,14 +165,20 @@ void process_printf(char*& data, const printf_descriptor_map_t& descs, std::stringstream printf_out{}; - // Firstly print the part of the format string up to the first '%' + // Firstly print the part of the format string up to the first '%' if any + // otherwise print the whole string as is and move the data pointer to the + // end. size_t next_part = format_string.find_first_of('%'); + if (next_part == std::string::npos) { + next_part = format_string.size(); + data = data_end; + } printf_out << format_string.substr(0, next_part); // Decompose the remaining format string into individual strings with // one format specifier each, handle each one individually size_t arg_idx = 0; - while (next_part < format_string.size() - 1) { + while (next_part < format_string.size()) { // Get the part of the format string before the next format specifier size_t part_start = next_part; size_t part_end = format_string.find_first_of('%', part_start + 1); @@ -181,7 +187,15 @@ void process_printf(char*& data, const printf_descriptor_map_t& descs, // Handle special cases if (part_end == part_start + 1) { printf_out << "%"; - next_part = part_end + 1; + // We also need to print the literals between '%%' and the next '%' + next_part = part_start = part_end + 1; + part_end = format_string.find_first_of('%', part_start); + if (part_end != std::string::npos && part_end > part_start) { + part_fmt = + format_string.substr(part_start, part_end - part_start); + printf_out << part_fmt; + next_part = part_end; + } continue; } else if (part_end == std::string::npos && arg_idx >= descs.at(printf_id).arg_sizes.size()) { diff --git a/tests/api/printf.cpp b/tests/api/printf.cpp index 570336fd..e9db7093 100644 --- a/tests/api/printf.cpp +++ b/tests/api/printf.cpp @@ -23,6 +23,23 @@ TEST_F(WithCommandQueueAndPrintf, SimplePrintf) { char source[512]; sprintf(source, "kernel void test_printf() { printf(\"%s\");}", message); auto kernel = CreateKernel(source, "test_printf"); + size_t gws = 1; + size_t lws = 1; + EnqueueNDRangeKernel(kernel, 1, nullptr, &gws, &lws, 0, nullptr, nullptr); + Finish(); + ASSERT_STREQ(m_printf_output.c_str(), message); +} + +TEST_F(WithCommandQueueAndPrintf, SimplePrintfPercent) { + // Tests that the while loop inside the process_printf func + // goes all the way to the end of the string to be printed. + const char message[] = "The value is: 123%"; + char source[512]; + sprintf(source, + "kernel void test_printf() { printf(\"The value is: %d%%\"); }", + 123); + + auto kernel = CreateKernel(source, "test_printf"); size_t gws = 1; size_t lws = 1; @@ -32,6 +49,50 @@ TEST_F(WithCommandQueueAndPrintf, SimplePrintf) { ASSERT_STREQ(m_printf_output.c_str(), message); } +TEST_F(WithCommandQueueAndPrintf, SimplePrintfBetweenPercents) { + const char message[] = "%Hello%World%!"; + const char message_test[] = "%%Hello%%World%%!"; + + char source[512]; + sprintf(source, "kernel void test_printf() { printf(\"%s\");}", + message_test); + + auto kernel = CreateKernel(source, "test_printf"); + + size_t gws = 1; + size_t lws = 1; + EnqueueNDRangeKernel(kernel, 1, nullptr, &gws, &lws, 0, nullptr, nullptr); + Finish(); + + ASSERT_STREQ(m_printf_output.c_str(), message); +} + +TEST_F(WithCommandQueueAndPrintf, SimpleFormatedPrintf) { + const char* source = "kernel void test_printf() { printf(\"%s\", \"\"); }"; + auto kernel = CreateKernel(source, "test_printf"); + + size_t gws = 1; + size_t lws = 1; + EnqueueNDRangeKernel(kernel, 1, nullptr, &gws, &lws, 0, nullptr, nullptr); + Finish(); + + ASSERT_STREQ(m_printf_output.c_str(), ""); +} +TEST_F(WithCommandQueueAndPrintf, PrintfWithNoFormatSpecifier) { + const char* source = + "kernel void test_printf() { printf(\"\\n\", \"foo\");}"; + auto kernel = CreateKernel(source, "test_printf"); + + size_t gws = 1; + size_t lws = 1; + EnqueueNDRangeKernel(kernel, 1, nullptr, &gws, &lws, 0, nullptr, nullptr); + Finish(); + + // The expected output is just a newline character since there's no + // format specifier to consume the "foo" argument. + ASSERT_STREQ(m_printf_output.c_str(), "\n"); +} + TEST_F(WithCommandQueueAndPrintf, TooLongPrintf) { // each print takes 12 bytes (4 for the printf_id, and 2*4 for the 2 integer // to print) + 4 for the byte written counter