From 7639307e23b89f522bc578c9c5e3a9768df2bcce Mon Sep 17 00:00:00 2001 From: Kevin Wilfong Date: Thu, 30 Jan 2025 15:16:58 -0800 Subject: [PATCH] fix: Remove StringWriter that modifies input in place (#12219) Summary: StringWriter supports two flavors, one that writes out new strings to a result Vector, and one that modifies existing strings in place. The latter is only used in two places, providing fast paths for upper/lower and replace/replace_first. I found a related bug with the fuzzer in replace/replace_first that occurs when the StringViews in the first argument's Vector point to overlapping ranges of the same string buffers. In this case the changes to one row are accidentally applied to multiple rows resulting in incorrect results. Since we explicitly allow operations that produce a substring of the original string to do so with a no-copy implementation, StringViews with overlapping ranges can occur. E.g. SimpleFunctions that apply this no-copy optimization on a string argument which could be constant, with other arguments that determine the range of the original string to take that are non-constant (like trim). I discussed this offline with a few folks and since the above is allowed and this in-place optimization is so rarely used, the consensus was to treat the string buffers in a FlatVector as immutable. Therefore in this change, I remove the flavor of StringWriter that modifies strings in place. This fixes the bug in replace/replace_first. upper/lower was using it correctly because the optimization was only applied to ASCII strings, the function takes a single argument, and the modification is idempotent (the value of a byte in the string doesn't depend on any other bytes in the string or any other arguments, and can be reapplied without consequences). 
Given how precarious this optimization is (if any of those conditions changed it would result in difficult-to-detect bugs), and given that allowing upper/lower to mutate the string in place would invite others to do so in the future (potentially leading to more bugs like the one in replace/replace_first), I think losing this fast path is worth the added safety. I also updated the documentation to clarify that the string buffers in a FlatVector should be treated as immutable. Differential Revision: D68924324 --- velox/docs/develop/vectors.rst | 6 + velox/docs/develop/view-and-writer-types.rst | 2 +- velox/docs/functions/presto/json.rst | 2 +- velox/expression/CastExpr-inl.h | 2 +- velox/expression/CastExpr.cpp | 4 +- velox/expression/StringWriter.h | 54 +----- velox/expression/UdfTypeResolver.h | 5 +- velox/expression/VectorWriters.h | 2 +- velox/expression/tests/StringWriterTest.cpp | 18 +- velox/expression/tests/VariadicViewTest.cpp | 2 +- velox/functions/lib/Re2Functions.cpp | 4 +- velox/functions/lib/ToHex.h | 4 +- velox/functions/lib/string/StringImpl.h | 16 -- velox/functions/prestosql/FromUtf8.cpp | 6 +- velox/functions/prestosql/Reverse.cpp | 4 +- velox/functions/prestosql/StringFunctions.cpp | 77 +------- .../prestosql/tests/StringFunctionsTest.cpp | 178 ++++++++---------- .../prestosql/types/IPAddressType.cpp | 4 +- .../prestosql/types/IPPrefixType.cpp | 2 +- velox/functions/prestosql/types/JsonType.cpp | 8 +- .../types/TimestampWithTimeZoneType.cpp | 2 +- velox/functions/prestosql/types/UuidType.cpp | 2 +- velox/functions/sparksql/String.h | 6 +- 23 files changed, 122 insertions(+), 288 deletions(-) diff --git a/velox/docs/develop/vectors.rst b/velox/docs/develop/vectors.rst index 963642151748..292afe17c599 100644 --- a/velox/docs/develop/vectors.rst +++ b/velox/docs/develop/vectors.rst @@ -256,6 +256,12 @@ After applying substr(s, 2) function string in position 1 became short enough to fit inside the StringView, hence, it no longer contains a pointer to a position in the 
string buffer. +Allowing these zero-copy implementations of functions that simply change the +starting position/length of a string, means that we may end up with StringViews +pointing to overlapping ranges within `stringBuffers_`. For this reason the +Buffers in `stringBuffers_` should be treated as immutable to prevent +modifications from unintentionally cascading. + Flat vectors of type TIMESTAMP are represented by FlatVector. Timestamp struct consists of two 64-bit integers: seconds and nanoseconds. Each entry uses 16 bytes. diff --git a/velox/docs/develop/view-and-writer-types.rst b/velox/docs/develop/view-and-writer-types.rst index 5c8bf89ee116..8a478d99ad72 100644 --- a/velox/docs/develop/view-and-writer-types.rst +++ b/velox/docs/develop/view-and-writer-types.rst @@ -310,7 +310,7 @@ When a given Ti is primitive, the following is valid. Assignable to std::optional allows writing null or value to the primitive. Returned by complex writers when writing nullable primitives. -**StringWriter<>** +**StringWriter** - void **reserve** (size_t newCapacity) : Reserve a space for the output string with size of at least newCapacity. - void **resize** (size_t newCapacity) : Set the size of the string. diff --git a/velox/docs/functions/presto/json.rst b/velox/docs/functions/presto/json.rst index 36cc09ce8255..a7d50f9e170d 100644 --- a/velox/docs/functions/presto/json.rst +++ b/velox/docs/functions/presto/json.rst @@ -211,4 +211,4 @@ are similar. To create a JSON-typed vector, one can use ``BaseVector::create(JSON(), size, pool)`` that creates a flat vector of StringViews, i.e. FlatVector. Reading and writing to a JSON-typed vector are the same as those for VARCHAR vectors, e.g., via -VectorReader and StringWriter<>. +VectorReader and StringWriter. 
diff --git a/velox/expression/CastExpr-inl.h b/velox/expression/CastExpr-inl.h index f660b992fca7..38a08868124e 100644 --- a/velox/expression/CastExpr-inl.h +++ b/velox/expression/CastExpr-inl.h @@ -341,7 +341,7 @@ void CastExpr::applyCastKernel( if constexpr ( ToKind == TypeKind::VARCHAR || ToKind == TypeKind::VARBINARY) { // Write the result output to the output vector - auto writer = exec::StringWriter<>(result, row); + auto writer = exec::StringWriter(result, row); writer.copy_from(output); writer.finalize(); } else { diff --git a/velox/expression/CastExpr.cpp b/velox/expression/CastExpr.cpp index 4e3ba162c5d3..c5c8ecb0bb79 100644 --- a/velox/expression/CastExpr.cpp +++ b/velox/expression/CastExpr.cpp @@ -175,7 +175,7 @@ VectorPtr CastExpr::castFromDate( try { // TODO Optimize to avoid creating an intermediate string. auto output = DATE()->toString(inputFlatVector->valueAt(row)); - auto writer = exec::StringWriter<>(resultFlatVector, row); + auto writer = exec::StringWriter(resultFlatVector, row); writer.resize(output.size()); ::memcpy(writer.data(), output.data(), output.size()); writer.finalize(); @@ -298,7 +298,7 @@ VectorPtr CastExpr::castFromIntervalDayTime( // TODO Optimize to avoid creating an intermediate string. 
auto output = INTERVAL_DAY_TIME()->valueToString(inputFlatVector->valueAt(row)); - auto writer = exec::StringWriter<>(resultFlatVector, row); + auto writer = exec::StringWriter(resultFlatVector, row); writer.resize(output.size()); ::memcpy(writer.data(), output.data(), output.size()); writer.finalize(); diff --git a/velox/expression/StringWriter.h b/velox/expression/StringWriter.h index 22502bc27432..006fa1f8e1fe 100644 --- a/velox/expression/StringWriter.h +++ b/velox/expression/StringWriter.h @@ -23,11 +23,7 @@ #include "velox/vector/FlatVector.h" namespace facebook::velox::exec { -template -class StringWriter; - -template <> -class StringWriter : public UDFOutputString { +class StringWriter : public UDFOutputString { public: // Used to initialize top-level strings and allow zero-copy writes. StringWriter(FlatVector* vector, int32_t offset) @@ -156,52 +152,4 @@ class StringWriter : public UDFOutputString { template friend struct VectorWriter; }; - -// A string writer with UDFOutputString semantics that utilizes a pre-allocated -// input string for the output allocation, if inPlace is true in the constructor -// the string will be initialized with the input string value. -template <> -class StringWriter : public UDFOutputString { - public: - StringWriter() : vector_(nullptr), offset_(-1) {} - - StringWriter( - FlatVector* vector, - int32_t offset, - const StringView& stringToReuse, - bool inPlace = false) - : vector_(vector), offset_(offset), stringToReuse_(stringToReuse) { - setData(const_cast(stringToReuse_.data())); - setCapacity(stringToReuse_.size()); - - if (inPlace) { - // The string should be intialized with the input value - setSize(stringToReuse_.size()); - } - } - - void reserve(size_t newCapacity) override { - VELOX_CHECK( - newCapacity <= capacity() && "String writer max capacity extended"); - } - - /// Not called by the UDF Implementation. 
Should be called at the end to - /// finalize the allocation and the string writing - void finalize() { - VELOX_DCHECK(size() == 0 || data()); - vector_->setNoCopy(offset_, StringView(data(), size())); - } - - private: - /// The output vector that this string is being written to - FlatVector* vector_; - - /// The offset the string writes to within vector_ - int32_t offset_; - - /// The input string that is reused, held locally to assert the validity of - /// the data pointer throughout the proxy lifetime. More specifically when - /// the string is inlined. - StringView stringToReuse_; -}; } // namespace facebook::velox::exec diff --git a/velox/expression/UdfTypeResolver.h b/velox/expression/UdfTypeResolver.h index 72d14b1a481b..b27c197dbdb4 100644 --- a/velox/expression/UdfTypeResolver.h +++ b/velox/expression/UdfTypeResolver.h @@ -21,7 +21,6 @@ namespace facebook::velox::exec { -template class StringWriter; template @@ -120,14 +119,14 @@ template <> struct resolver { using in_type = StringView; using null_free_in_type = in_type; - using out_type = StringWriter; + using out_type = StringWriter; }; template <> struct resolver { using in_type = StringView; using null_free_in_type = in_type; - using out_type = StringWriter; + using out_type = StringWriter; }; template <> diff --git a/velox/expression/VectorWriters.h b/velox/expression/VectorWriters.h index 8da7312ce700..c2d845144c61 100644 --- a/velox/expression/VectorWriters.h +++ b/velox/expression/VectorWriters.h @@ -353,7 +353,7 @@ struct VectorWriter< std::enable_if_t | std::is_same_v>> : public VectorWriterBase { using vector_t = typename TypeToFlatVector::type; - using exec_out_t = StringWriter<>; + using exec_out_t = StringWriter; void init(vector_t& vector, bool uniqueAndMutable = false) { proxy_.vector_ = &vector; diff --git a/velox/expression/tests/StringWriterTest.cpp b/velox/expression/tests/StringWriterTest.cpp index aec9d9cacb88..d0c8327e6d31 100644 --- a/velox/expression/tests/StringWriterTest.cpp +++ 
b/velox/expression/tests/StringWriterTest.cpp @@ -28,7 +28,7 @@ class StringWriterTest : public functions::test::FunctionBaseTest {}; TEST_F(StringWriterTest, append) { auto vector = makeFlatVector(2); - auto writer = exec::StringWriter<>(vector.get(), 0); + auto writer = exec::StringWriter(vector.get(), 0); writer.append("1 "_sv); writer.append(std::string_view("2 ")); writer.append("3 "_sv); @@ -42,7 +42,7 @@ TEST_F(StringWriterTest, append) { TEST_F(StringWriterTest, plusOperator) { auto vector = makeFlatVector(1); - auto writer = exec::StringWriter<>(vector.get(), 0); + auto writer = exec::StringWriter(vector.get(), 0); writer += "1 "_sv; writer += "2 "; writer += std::string_view("3 "); @@ -57,19 +57,19 @@ TEST_F(StringWriterTest, plusOperator) { TEST_F(StringWriterTest, assignment) { auto vector = makeFlatVector(4); - auto writer0 = exec::StringWriter<>(vector.get(), 0); + auto writer0 = exec::StringWriter(vector.get(), 0); writer0 = "string0"_sv; writer0.finalize(); - auto writer1 = exec::StringWriter<>(vector.get(), 1); + auto writer1 = exec::StringWriter(vector.get(), 1); writer1 = std::string("string1"); writer1.finalize(); - auto writer2 = exec::StringWriter<>(vector.get(), 2); + auto writer2 = exec::StringWriter(vector.get(), 2); writer2 = std::string_view("string2"); writer2.finalize(); - auto writer3 = exec::StringWriter<>(vector.get(), 3); + auto writer3 = exec::StringWriter(vector.get(), 3); writer3 = folly::StringPiece("string3"); writer3.finalize(); @@ -81,7 +81,7 @@ TEST_F(StringWriterTest, assignment) { TEST_F(StringWriterTest, copyFromStringView) { auto vector = makeFlatVector(1); - auto writer = exec::StringWriter<>(vector.get(), 0); + auto writer = exec::StringWriter(vector.get(), 0); writer.copy_from("1 2 3 4 5 "_sv); writer.finalize(); @@ -90,7 +90,7 @@ TEST_F(StringWriterTest, copyFromStringView) { TEST_F(StringWriterTest, copyFromStdString) { auto vector = makeFlatVector(1); - auto writer = exec::StringWriter<>(vector.get(), 0); + auto 
writer = exec::StringWriter(vector.get(), 0); writer.copy_from(std::string("1 2 3 4 5 ")); writer.finalize(); @@ -99,7 +99,7 @@ TEST_F(StringWriterTest, copyFromStdString) { TEST_F(StringWriterTest, copyFromCString) { auto vector = makeFlatVector(4); - auto writer = exec::StringWriter<>(vector.get(), 0); + auto writer = exec::StringWriter(vector.get(), 0); writer.copy_from("1 2 3 4 5 "); writer.finalize(); diff --git a/velox/expression/tests/VariadicViewTest.cpp b/velox/expression/tests/VariadicViewTest.cpp index 271a2459984c..346da8a43fde 100644 --- a/velox/expression/tests/VariadicViewTest.cpp +++ b/velox/expression/tests/VariadicViewTest.cpp @@ -377,7 +377,7 @@ const auto callNullablePrefix = "callNullable "_sv; const auto callAsciiPrefix = "callAscii "_sv; void writeInputToOutput( - StringWriter<>& out, + StringWriter& out, const VariadicView* inputs) { for (const auto& input : *inputs) { out += input.has_value() ? input.value() : null; diff --git a/velox/functions/lib/Re2Functions.cpp b/velox/functions/lib/Re2Functions.cpp index ace6d0a6fa3f..d0438d671bad 100644 --- a/velox/functions/lib/Re2Functions.cpp +++ b/velox/functions/lib/Re2Functions.cpp @@ -1397,11 +1397,11 @@ class RegexpReplaceWithLambdaFunction : public exec::VectorFunction { // Sections being replaced should not overlap. 
struct Replacer { const StringView& original; - exec::StringWriter& writer; + exec::StringWriter& writer; char* result; size_t start = 0; - Replacer(const StringView& _original, exec::StringWriter& _writer) + Replacer(const StringView& _original, exec::StringWriter& _writer) : original{_original}, writer{_writer}, result{writer.data()} {} void replace(size_t offset, size_t size, const StringView& replacement) { diff --git a/velox/functions/lib/ToHex.h b/velox/functions/lib/ToHex.h index 14b714f81157..04602bca7bc8 100644 --- a/velox/functions/lib/ToHex.h +++ b/velox/functions/lib/ToHex.h @@ -23,7 +23,7 @@ namespace facebook::velox::functions { struct ToHexUtil { FOLLY_ALWAYS_INLINE static void toHex( StringView input, - exec::StringWriter& result) { + exec::StringWriter& result) { // Lookup table to translate unsigned char to its hexadecimal format. static const char* const kHexTable = "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F" @@ -49,7 +49,7 @@ struct ToHexUtil { FOLLY_ALWAYS_INLINE static void toHex( uint64_t input, - exec::StringWriter& result) { + exec::StringWriter& result) { static const char* const kHexTable = "0123456789ABCDEF"; if (input == 0) { result = "0"; diff --git a/velox/functions/lib/string/StringImpl.h b/velox/functions/lib/string/StringImpl.h index e0a13ced7e5d..f343fda590b2 100644 --- a/velox/functions/lib/string/StringImpl.h +++ b/velox/functions/lib/string/StringImpl.h @@ -21,8 +21,6 @@ #include #include #include -#include -#include #include #include #include "folly/CPortability.h" @@ -65,20 +63,6 @@ FOLLY_ALWAYS_INLINE bool lower(TOutString& output, const TInString& input) { return true; } -/// Inplace ascii lower -template -FOLLY_ALWAYS_INLINE bool lowerAsciiInPlace(T& str) { - lowerAscii(str.data(), str.data(), str.size()); - return true; -} - -/// Inplace ascii upper -template -FOLLY_ALWAYS_INLINE bool upperAsciiInPlace(T& str) { - upperAscii(str.data(), str.data(), str.size()); - return true; -} - /// Apply a set of 
appenders on an output string, an appender is a lambda /// that takes an output string and append a string to it. This can be used by /// code-gen to reduce copying in concat by evaluating nested expressions diff --git a/velox/functions/prestosql/FromUtf8.cpp b/velox/functions/prestosql/FromUtf8.cpp index b93df3744d5b..70d09643bfa6 100644 --- a/velox/functions/prestosql/FromUtf8.cpp +++ b/velox/functions/prestosql/FromUtf8.cpp @@ -95,7 +95,7 @@ class FromUtf8Function : public exec::VectorFunction { if (constantReplacement) { rows.applyToSelected([&](auto row) { - exec::StringWriter writer(flatResult, row); + exec::StringWriter writer(flatResult, row); auto value = decodedInput.valueAt(row); if (row < firstInvalidRow) { writer.append(value); @@ -109,7 +109,7 @@ class FromUtf8Function : public exec::VectorFunction { context.applyToSelectedNoThrow(rows, [&](auto row) { auto replacement = getReplacementCharacter(args[1]->type(), decodedReplacement, row); - exec::StringWriter writer(flatResult, row); + exec::StringWriter writer(flatResult, row); auto value = decodedInput.valueAt(row); if (row < firstInvalidRow) { writer.append(value); @@ -261,7 +261,7 @@ class FromUtf8Function : public exec::VectorFunction { void fixInvalidUtf8( StringView input, const std::string& replacement, - exec::StringWriter& fixedWriter) const { + exec::StringWriter& fixedWriter) const { if (input.empty()) { fixedWriter.setEmpty(); return; diff --git a/velox/functions/prestosql/Reverse.cpp b/velox/functions/prestosql/Reverse.cpp index 0d97ba0114cd..cc7b7e7d8b13 100644 --- a/velox/functions/prestosql/Reverse.cpp +++ b/velox/functions/prestosql/Reverse.cpp @@ -42,7 +42,7 @@ class ReverseFunction : public exec::VectorFunction { const FlatVector* input, FlatVector* result) { rows.applyToSelected([&](int row) { - auto proxy = exec::StringWriter<>(result, row); + auto proxy = exec::StringWriter(result, row); stringImpl::reverse(proxy, input->valueAt(row).getString()); proxy.finalize(); }); @@ -100,7 
+100,7 @@ class ReverseFunction : public exec::VectorFunction { if (originalArg->isConstantEncoding()) { auto value = originalArg->as>()->valueAt(0); - auto proxy = exec::StringWriter<>(flatResult, rows.begin()); + auto proxy = exec::StringWriter(flatResult, rows.begin()); if (isAscii) { stringImpl::reverse(proxy, value.str()); } else { diff --git a/velox/functions/prestosql/StringFunctions.cpp b/velox/functions/prestosql/StringFunctions.cpp index d2f23b2e1bc4..22207ad5274e 100644 --- a/velox/functions/prestosql/StringFunctions.cpp +++ b/velox/functions/prestosql/StringFunctions.cpp @@ -19,7 +19,6 @@ #include "velox/expression/StringWriter.h" #include "velox/expression/VectorFunction.h" #include "velox/functions/lib/StringEncodingUtils.h" -#include "velox/functions/lib/string/StringCore.h" #include "velox/functions/lib/string/StringImpl.h" #include "velox/vector/FlatVector.h" @@ -43,7 +42,7 @@ class UpperLowerTemplateFunction : public exec::VectorFunction { const DecodedVector* decodedInput, FlatVector* results) { rows.applyToSelected([&](int row) { - auto proxy = exec::StringWriter<>(results, row); + auto proxy = exec::StringWriter(results, row); if constexpr (isLower) { stringImpl::lower( proxy, decodedInput->valueAt(row)); @@ -56,25 +55,6 @@ class UpperLowerTemplateFunction : public exec::VectorFunction { } }; - void applyInternalInPlace( - const SelectivityVector& rows, - DecodedVector* decodedInput, - FlatVector* results) const { - rows.applyToSelected([&](int row) { - auto proxy = exec::StringWriter( - results, - row, - decodedInput->valueAt(row) /*reusedInput*/, - true /*inPlace*/); - if constexpr (isLower) { - stringImpl::lowerAsciiInPlace(proxy); - } else { - stringImpl::upperAsciiInPlace(proxy); - } - proxy.finalize(); - }); - } - public: void apply( const SelectivityVector& rows, @@ -92,19 +72,6 @@ class UpperLowerTemplateFunction : public exec::VectorFunction { auto ascii = isAscii(inputStringsVector, rows); - bool tryInplace = ascii && - 
(inputStringsVector->encoding() == VectorEncoding::Simple::FLAT); - - // If tryInplace, then call prepareFlatResultsVector(). If the latter - // returns true, note that the input arg was moved to result, so that the - // buffer can be reused as output. - if (tryInplace && - prepareFlatResultsVector(result, rows, context, args.at(0))) { - auto* resultFlatVector = result->as>(); - applyInternalInPlace(rows, decodedInput, resultFlatVector); - return; - } - // Not in place path. VectorPtr emptyVectorPtr; prepareFlatResultsVector(result, rows, context, emptyVectorPtr); @@ -308,7 +275,7 @@ class Replace : public exec::VectorFunction { const SelectivityVector& rows, FlatVector* results) const { rows.applyToSelected([&](int row) { - auto proxy = exec::StringWriter<>(results, row); + auto proxy = exec::StringWriter(results, row); stringImpl::replace( proxy, stringReader(row), @@ -319,25 +286,6 @@ class Replace : public exec::VectorFunction { }); } - template < - typename StringReader, - typename SearchReader, - typename ReplaceReader> - void applyInPlace( - StringReader stringReader, - SearchReader searchReader, - ReplaceReader replaceReader, - const SelectivityVector& rows, - FlatVector* results) const { - rows.applyToSelected([&](int row) { - auto proxy = exec::StringWriter( - results, row, stringReader(row) /*reusedInput*/, true /*inPlace*/); - stringImpl::replaceInPlace( - proxy, searchReader(row), replaceReader(row), replaceFirst_); - proxy.finalize(); - }); - } - const bool replaceFirst_; public: @@ -392,27 +340,6 @@ class Replace : public exec::VectorFunction { } }; - // Right now we enable the inplace if 'search' and 'replace' are constants - // and 'search' size is larger than or equal to 'replace' and if the input - // vector is reused. - - // TODO: analyze other options for enabling inplace i.e.: - // 1. Decide per row. - // 2. Scan inputs for max lengths and decide based on that. 
..etc - bool tryInplace = replaceArgValue.has_value() && - searchArgValue.has_value() && - (searchArgValue.value().size() >= replaceArgValue.value().size()) && - (args.at(0)->encoding() == VectorEncoding::Simple::FLAT); - - if (tryInplace) { - if (prepareFlatResultsVector(result, rows, context, args.at(0))) { - auto* resultFlatVector = result->as>(); - applyInPlace( - stringReader, searchReader, replaceReader, rows, resultFlatVector); - return; - } - } - // Not in place path VectorPtr emptyVectorPtr; prepareFlatResultsVector(result, rows, context, emptyVectorPtr); diff --git a/velox/functions/prestosql/tests/StringFunctionsTest.cpp b/velox/functions/prestosql/tests/StringFunctionsTest.cpp index acdd6135ac53..c46b09069887 100644 --- a/velox/functions/prestosql/tests/StringFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/StringFunctionsTest.cpp @@ -314,13 +314,6 @@ class StringFunctionsTest : public FunctionBaseTest { bool withReplaceArgument, bool replaceFirst = false); - void testReplaceInPlace( - const std::vector>& tests, - const std::string& search, - const std::string& replace, - bool multiReferenced, - bool replaceFirst = false); - using replace_first_input_test_t = std::vector, std::string>>; @@ -1370,75 +1363,6 @@ TEST_F(StringFunctionsTest, invalidLevenshteinDistance) { "The combined inputs size exceeded max Levenshtein distance combined input size"); } -void StringFunctionsTest::testReplaceInPlace( - const std::vector>& tests, - const std::string& search, - const std::string& replace, - bool multiReferenced, - bool replaceFirst) { - auto makeInput = [&]() { - auto stringVector = makeFlatVector(tests.size()); - - for (int i = 0; i < tests.size(); i++) { - stringVector->set(i, StringView(tests[i].first)); - } - auto crossRefVector = makeFlatVector(1); - - if (multiReferenced) { - crossRefVector->acquireSharedStringBuffers(stringVector.get()); - } - return stringVector; - }; - - auto testResults = [&](const FlatVector* results) { - for (int32_t i = 0; 
i < tests.size(); ++i) { - ASSERT_EQ(results->valueAt(i), StringView(tests[i].second)); - } - }; - - auto result = evaluate>( - fmt::format( - "{}(c0, '{}', '{}')", - replaceFirst ? "replace_first" : "replace", - search, - replace), - makeRowVector({makeInput()})); - testResults(result.get()); - - // Test in place optimization. If in-place is expected, make sure it happened. - // If its not expected make sure it did not happen. - auto applyReplaceFunction = [&](std::vector& functionInputs, - VectorPtr& resultPtr) { - core::QueryConfig config({}); - auto replaceFunction = replaceFirst - ? exec::getVectorFunction( - "replace_first", {VARCHAR(), VARCHAR(), VARCHAR()}, {}, config) - : exec::getVectorFunction( - "replace", {VARCHAR(), VARCHAR()}, {}, config); - SelectivityVector rows(tests.size()); - ExprSet exprSet({}, &execCtx_); - RowVectorPtr inputRows = makeRowVector({}); - exec::EvalCtx evalCtx(&execCtx_, &exprSet, inputRows.get()); - replaceFunction->apply(rows, functionInputs, VARCHAR(), evalCtx, resultPtr); - }; - - std::vector functionInputs = { - makeInput(), - makeConstant(search.c_str(), tests.size()), - makeConstant(replace.c_str(), tests.size())}; - - VectorPtr resultPtr; - applyReplaceFunction(functionInputs, resultPtr); - testResults(resultPtr->asFlatVector()); - - if (!multiReferenced && search >= replace) { - // Expected in-place. 
- ASSERT_TRUE(resultPtr == functionInputs[0]); - } else { - ASSERT_FALSE(resultPtr == functionInputs[0]); - } -} - void StringFunctionsTest::testReplaceFlatVector( const replace_input_test_t& tests, bool withReplaceArgument, @@ -1521,19 +1445,14 @@ TEST_F(StringFunctionsTest, replaceFirst) { true, /*replaceFirst*/ true); - // Test in place path - std::vector> testsInplace = { - {"foobar", "fttbar"}, {"oooooo", "ttoooo"}}; - testReplaceInPlace(testsInplace, "oo", "tt", false, /*replaceFirst*/ true); - testReplaceInPlace(testsInplace, "oo", "tt", true, /*replaceFirst*/ true); - // Test in place path with unicode - std::vector> testsInplaceUnicode = { - {"αβγδεζηθικλμνξοπρςστυφχψ", "αβγδεζηψκλμνξοπρςστυφχψ"}, - {"θιбвгдежз", "ψбвгдежз"}}; - testReplaceInPlace( - testsInplaceUnicode, "θι", "ψ", false, /*replaceFirst*/ true); - testReplaceInPlace( - testsInplaceUnicode, "θι", "ψ", true, /*replaceFirst*/ true); + testReplaceFlatVector({{{"foobar", "oo", "tt"}, {"fttbar"}}}, true, true); + testReplaceFlatVector({{{"oooooo", "oo", "tt"}, {"ttoooo"}}}, true, true); + + testReplaceFlatVector( + {{{"αβγδεζηθικλμνξοπρςστυφχψ", "θι", "ψ"}, {"αβγδεζηψκλμνξοπρςστυφχψ"}}}, + true, + true); + testReplaceFlatVector({{{"θιбвгдежз", "θι", "ψ"}, {"ψбвгдежз"}}}, true, true); // Test constant vectors auto rows = makeRowVector(makeRowType({BIGINT()}), 10); @@ -1565,19 +1484,21 @@ TEST_F(StringFunctionsTest, replace) { testReplaceFlatVector(testsTwoArgs, false); - // Test in place path - std::vector> testsInplace = { - {"aaa", "bbb"}, - {"aba", "bbb"}, - {"qwertyuiowertyuioqwertyuiopwertyuiopwertyuiopwertyuiopertyuioqwertyuiopwertyuiowertyuio", - "qwertyuiowertyuioqwertyuiopwertyuiopwertyuiopwertyuiopertyuioqwertyuiopwertyuiowertyuio"}, - {"qwertyuiowertyuioqwertyuiopwertyuiopwertyuiopwertyuiopertyuioqwertyuiopwertyuiowertaaaa", - "qwertyuiowertyuioqwertyuiopwertyuiopwertyuiopwertyuiopertyuioqwertyuiopwertyuiowertbbbb"}, - }; - - testReplaceInPlace(testsInplace, "a", "b", true); - 
testReplaceInPlace(testsInplace, "a", "b", false); - testReplaceInPlace({{"a", "bb"}, {"aa", "bbbb"}}, "a", "bb", false); + replace_input_test_t moreTests = { + {{"aaa", "a", "b"}, {"bbb"}}, + {{"aba", "a", "b"}, {"bbb"}}, + {{"qwertyuiowertyuioqwertyuiopwertyuiopwertyuiopwertyuiopertyuioqwertyuiopwertyuiowertyuio", + "a", + "b"}, + {"qwertyuiowertyuioqwertyuiopwertyuiopwertyuiopwertyuiopertyuioqwertyuiopwertyuiowertyuio"}}, + {{"qwertyuiowertyuioqwertyuiopwertyuiopwertyuiopwertyuiopertyuioqwertyuiopwertyuiowertaaaa", + "a", + "b"}, + {"qwertyuiowertyuioqwertyuiopwertyuiopwertyuiopwertyuiopertyuioqwertyuiopwertyuiowertbbbb"}}, + {{"a", "a", "bb"}, {"bb"}}, + {{"aa", "a", "bb"}, {"bbbb"}}}; + + testReplaceFlatVector(moreTests, true); // Test constant vectors auto rows = makeRowVector(makeRowType({BIGINT()}), 10); @@ -1588,7 +1509,7 @@ TEST_F(StringFunctionsTest, replace) { } } -TEST_F(StringFunctionsTest, replaceWithReusableInputButNoInplace) { +TEST_F(StringFunctionsTest, replaceWithReusableInput) { auto c0 = ({ auto values = makeFlatVector({"foo"}); auto indices = allocateIndices(100, execCtx_.pool()); @@ -1612,6 +1533,57 @@ TEST_F(StringFunctionsTest, replaceWithReusableInputButNoInplace) { } } +TEST_F(StringFunctionsTest, replaceOverlappingStringViews) { + auto test = [&](const std::string& function) { + BufferPtr stringData = AlignedBuffer::allocate(15, pool_.get()); + memcpy(stringData->asMutable(), "abcdefghijklmno", 15); + const char* str = stringData->as(); + + BufferPtr values = AlignedBuffer::allocate(3, pool_.get()); + auto* valuesMutable = values->asMutable(); + // Make the strings large enough that they are not inlined. + // Note that only the first string contains the substring "abc", though the + // other two contain portions of it. This test verifies that "abc" is not + // replaced with "def" in the original string buffer (which would cause + // visible changes in the other two strings). 
+ valuesMutable[0] = StringView(str, 13); // abcdefghijklm + valuesMutable[1] = StringView(str + 1, 13); // bcdefghijklmn + valuesMutable[2] = StringView(str + 2, 13); // cdefghijklmno + + auto inputVector = std::make_shared>( + pool_.get(), + VARCHAR(), + nullptr, + 3, + std::move(values), + std::vector{std::move(stringData)}); + const auto numRows = inputVector->size(); + + core::QueryConfig config({}); + auto replaceFunction = exec::getVectorFunction( + function, {VARCHAR(), VARCHAR(), VARCHAR()}, {}, config); + SelectivityVector rows(numRows); + ExprSet exprSet({}, &execCtx_); + RowVectorPtr inputRows = makeRowVector({}); + exec::EvalCtx evalCtx(&execCtx_, &exprSet, inputRows.get()); + + std::vector functionInputs{ + std::move(inputVector), + makeConstant("abc", numRows), + makeConstant("def", numRows)}; + VectorPtr resultPtr; + replaceFunction->apply(rows, functionInputs, VARCHAR(), evalCtx, resultPtr); + + auto* results = resultPtr->as>(); + EXPECT_EQ(results->valueAt(0), "defdefghijklm"); + EXPECT_EQ(results->valueAt(1), "bcdefghijklmn"); + EXPECT_EQ(results->valueAt(2), "cdefghijklmno"); + }; + + test("replace"); + test("replace_first"); +} + TEST_F(StringFunctionsTest, controlExprEncodingPropagation) { std::vector dataASCII({"ali", "ali", "ali"}); std::vector dataUTF8({"àáâãäåæçè", "àáâãäåæçè", "àáâãäå"}); diff --git a/velox/functions/prestosql/types/IPAddressType.cpp b/velox/functions/prestosql/types/IPAddressType.cpp index 35d4270fe439..7f3884d33a37 100644 --- a/velox/functions/prestosql/types/IPAddressType.cpp +++ b/velox/functions/prestosql/types/IPAddressType.cpp @@ -114,7 +114,7 @@ class IPAddressCastOperator : public exec::CastOperator { std::reverse(addrBytes.begin(), addrBytes.end()); folly::IPAddressV6 v6Addr(addrBytes); - exec::StringWriter result(flatResult, row); + exec::StringWriter result(flatResult, row); if (v6Addr.isIPv4Mapped()) { result.append(v6Addr.createIPv4().str()); } else { @@ -165,7 +165,7 @@ class IPAddressCastOperator : public 
exec::CastOperator { memcpy(&addrBytes, &intAddr, ipaddress::kIPAddressBytes); std::reverse(addrBytes.begin(), addrBytes.end()); - exec::StringWriter result(flatResult, row); + exec::StringWriter result(flatResult, row); result.resize(ipaddress::kIPAddressBytes); memcpy(result.data(), &addrBytes, ipaddress::kIPAddressBytes); result.finalize(); diff --git a/velox/functions/prestosql/types/IPPrefixType.cpp b/velox/functions/prestosql/types/IPPrefixType.cpp index 554742bf34ab..5f1842d24737 100644 --- a/velox/functions/prestosql/types/IPPrefixType.cpp +++ b/velox/functions/prestosql/types/IPPrefixType.cpp @@ -122,7 +122,7 @@ class IPPrefixCastOperator : public exec::CastOperator { auto stringRet = fmt::format("{}/{}", ipString, prefixVal); // Write the string to the result vector - exec::StringWriter result(flatResult, row); + exec::StringWriter result(flatResult, row); result.append(stringRet); result.finalize(); }); diff --git a/velox/functions/prestosql/types/JsonType.cpp b/velox/functions/prestosql/types/JsonType.cpp index 83a6ba140ad3..ebbb722e2e55 100644 --- a/velox/functions/prestosql/types/JsonType.cpp +++ b/velox/functions/prestosql/types/JsonType.cpp @@ -314,7 +314,7 @@ struct AsJson { } // Appends the json string of the value at i to a string writer. 
- void append(vector_size_t i, exec::StringWriter<>& proxy) const { + void append(vector_size_t i, exec::StringWriter& proxy) const { if (decoded_->isNullAt(i)) { proxy.append("null"); } else { @@ -435,7 +435,7 @@ void castToJsonFromArray( auto offset = inputArray->offsetAt(row); auto size = inputArray->sizeAt(row); - auto proxy = exec::StringWriter<>(&flatResult, row); + auto proxy = exec::StringWriter(&flatResult, row); proxy.append("["_sv); for (int i = offset, end = offset + size; i < end; ++i) { @@ -529,7 +529,7 @@ void castToJsonFromMap( } std::sort(sortedKeys.begin(), sortedKeys.end()); - auto proxy = exec::StringWriter<>(&flatResult, row); + auto proxy = exec::StringWriter(&flatResult, row); proxy.append("{"_sv); for (auto it = sortedKeys.begin(); it != sortedKeys.end(); ++it) { @@ -587,7 +587,7 @@ void castToJsonFromRow( return; } - auto proxy = exec::StringWriter<>(&flatResult, row); + auto proxy = exec::StringWriter(&flatResult, row); proxy.append("["_sv); for (int i = 0; i < childrenSize; ++i) { diff --git a/velox/functions/prestosql/types/TimestampWithTimeZoneType.cpp b/velox/functions/prestosql/types/TimestampWithTimeZoneType.cpp index d6fbf76f6413..bf1effadb45b 100644 --- a/velox/functions/prestosql/types/TimestampWithTimeZoneType.cpp +++ b/velox/functions/prestosql/types/TimestampWithTimeZoneType.cpp @@ -131,7 +131,7 @@ void castToString( const auto timeZoneId = unpackZoneKeyId(timestampWithTimezone); const auto* timezonePtr = tz::locateZone(tz::getTimeZoneName(timeZoneId)); - exec::StringWriter result(flatResult, row); + exec::StringWriter result(flatResult, row); const auto maxResultSize = formatter->maxResultSize(timezonePtr); result.reserve(maxResultSize); diff --git a/velox/functions/prestosql/types/UuidType.cpp b/velox/functions/prestosql/types/UuidType.cpp index 8475b831d496..d25488438553 100644 --- a/velox/functions/prestosql/types/UuidType.cpp +++ b/velox/functions/prestosql/types/UuidType.cpp @@ -93,7 +93,7 @@ class UuidCastOperator : 
public exec::CastOperator { "c0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedf" "e0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff"; - exec::StringWriter result(flatResult, row); + exec::StringWriter result(flatResult, row); result.resize(36); size_t offset = 0; diff --git a/velox/functions/sparksql/String.h b/velox/functions/sparksql/String.h index 2f188edc807c..b90b9b851d33 100644 --- a/velox/functions/sparksql/String.h +++ b/velox/functions/sparksql/String.h @@ -17,11 +17,9 @@ #include "folly/ssl/OpenSSLHash.h" -#include #include #include "velox/expression/VectorFunction.h" #include "velox/functions/Macros.h" -#include "velox/functions/UDFOutputString.h" #include "velox/functions/lib/string/StringCore.h" #include "velox/functions/lib/string/StringImpl.h" @@ -683,7 +681,7 @@ struct SubstrFunction { struct OverlayFunctionBase { template FOLLY_ALWAYS_INLINE void doCall( - exec::StringWriter& result, + exec::StringWriter& result, StringView input, StringView replace, int32_t pos, @@ -713,7 +711,7 @@ struct OverlayFunctionBase { template FOLLY_ALWAYS_INLINE void append( - exec::StringWriter& result, + exec::StringWriter& result, StringView input, std::pair pair) { if constexpr (isVarchar && !isAscii) {