From 6d9a0b5c359deddf8075f1896a0ce00c155dc4e7 Mon Sep 17 00:00:00 2001
From: Travis Geiselbrecht <geist@foobox.com>
Date: Thu, 9 Jan 2025 23:11:21 -0800
Subject: [PATCH] [libc][printf] Handle case where snprintf underflows for len
 = 0

The code would write a null pointer always, even if len = 0, or if the
buffer pointer is null.

Test for this condition and add a unit test.

fixes #407
---
 app/tests/printf_tests.c       | 19 ++++++++++++++--
 lib/libc/printf.c              |  9 ++++++--
 lib/libc/test/printf_tests.cpp | 40 ++++++++++++++++++++++++++++++++++
 3 files changed, 64 insertions(+), 4 deletions(-)

diff --git a/app/tests/printf_tests.c b/app/tests/printf_tests.c
index 4e42b39b60..1dda970837 100644
--- a/app/tests/printf_tests.c
+++ b/app/tests/printf_tests.c
@@ -96,16 +96,31 @@ int printf_tests(int argc, const console_cmd_args *argv) {
     /* make sure snprintf terminates at the right spot */
     char buf[32];
 
-    memset(buf, 0, sizeof(buf));
+    memset(buf, 0x99, sizeof(buf));
     err = sprintf(buf, "0123456789abcdef012345678");
     printf("sprintf returns %d\n", err);
     hexdump8(buf, sizeof(buf));
 
-    memset(buf, 0, sizeof(buf));
+    memset(buf, 0x99, sizeof(buf));
     err = snprintf(buf, 15, "0123456789abcdef012345678");
     printf("snprintf returns %d\n", err);
     hexdump8(buf, sizeof(buf));
 
+    memset(buf, 0x99, sizeof(buf));
+    err = snprintf(buf, 1, "0123456789abcdef012345678");
+    printf("snprintf returns %d\n", err);
+    hexdump8(buf, sizeof(buf));
+
+    /* zero length is special case, should not write anything */
+    memset(buf, 0x99, sizeof(buf));
+    err = snprintf(buf, 0, "0123456789abcdef012345678");
+    printf("snprintf returns %d\n", err);
+    hexdump8(buf, sizeof(buf));
+
+    /* shold be able to pass null to the output buffer if zero length */
+    err = snprintf(NULL, 0, "0123456789abcdef012345678");
+    printf("snprintf returns %d\n", err);
+
     return NO_ERROR;
 }
 
diff --git a/lib/libc/printf.c b/lib/libc/printf.c
index 1baf8304d6..18ed97f513 100644
--- a/lib/libc/printf.c
+++ b/lib/libc/printf.c
@@ -66,6 +66,8 @@ static int _vsnprintf_output(const char *str, size_t len, void *state) {
         count++;
     }
 
+    // Return the count of the number of bytes that would be written even if the buffer
+    // wasn't large enough.
     return count;
 }
 
@@ -78,10 +80,13 @@ int vsnprintf(char *str, size_t len, const char *fmt, va_list ap) {
     args.pos = 0;
 
     wlen = _printf_engine(&_vsnprintf_output, (void *)&args, fmt, ap);
-    if (args.pos >= len)
+    if (len == 0) {
+        // do nothing, we can't null terminate the output
+    } else if (args.pos >= len) {
         str[len-1] = '\0';
-    else
+    } else {
         str[wlen] = '\0';
+    }
     return wlen;
 }
 
diff --git a/lib/libc/test/printf_tests.cpp b/lib/libc/test/printf_tests.cpp
index b252c4d425..094c59cf8e 100644
--- a/lib/libc/test/printf_tests.cpp
+++ b/lib/libc/test/printf_tests.cpp
@@ -458,6 +458,44 @@ bool snprintf_truncation_test() {
   END_TEST;
 }
 
+// Test snprintf() with zero length.
+bool snprintf_truncation_test_zero_length() {
+  BEGIN_TEST;
+
+  char buf[32];
+
+  memset(buf, 'x', sizeof(buf));
+  static const char str[26] = "0123456789abcdef012345678";
+
+  // Write with len = 0 a little ways into the buffer (to make sure it doesn't
+  // write to len -1).
+  int result = snprintf(buf + 4, 0, "%s", str);
+
+  // Check that snprintf() returns the length of the string that it would
+  // have written if the buffer was big enough.
+  EXPECT_EQ(result, (int)strlen(str));
+
+  // Check that snprintf() did not write anything.
+  for (auto c : buf)
+    EXPECT_EQ(c, 'x');
+
+  END_TEST;
+}
+
+// Test snprintf() with null pointer and zero length.
+bool snprintf_truncation_test_null_buffer() {
+  BEGIN_TEST;
+
+  static const char str[26] = "0123456789abcdef012345678";
+  int result = snprintf(nullptr, 0, "%s", str);
+
+  // Check that snprintf() returns the length of the string that it would
+  // have written if the buffer was big enough.
+  EXPECT_EQ(result, (int)strlen(str));
+
+  END_TEST;
+}
+
 }  // namespace
 
 BEGIN_TEST_CASE(printf_tests)
@@ -467,4 +505,6 @@ RUN_TEST(alt_and_sign)
 RUN_TEST(formatting)
 //RUN_TEST(printf_field_width_and_precision_test)
 RUN_TEST(snprintf_truncation_test)
+RUN_TEST(snprintf_truncation_test_zero_length)
+RUN_TEST(snprintf_truncation_test_null_buffer)
 END_TEST_CASE(printf_tests)