Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Update Base64 as non-throwing API #11149

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 113 additions & 57 deletions velox/common/encode/Base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include <folly/Portability.h>
#include <folly/container/Foreach.h>
#include <folly/io/Cursor.h>
#include <stdint.h>
#include <cstdint>

#include "velox/common/base/Exceptions.h"

Expand Down Expand Up @@ -163,7 +163,7 @@ std::string Base64::encodeImpl(
const T& input,
const Charset& charset,
bool includePadding) {
size_t encodedSize = calculateEncodedSize(input.size(), includePadding);
static size_t encodedSize{calculateEncodedSize(input.size(), includePadding)};
std::string encodedResult;
encodedResult.resize(encodedSize);
encodeImpl(input, charset, includePadding, encodedResult.data());
Expand Down Expand Up @@ -310,8 +310,7 @@ std::string Base64::encode(const folly::IOBuf* inputBuffer) {
// static
std::string Base64::decode(folly::StringPiece encodedText) {
std::string decodedResult;
Base64::decode(
std::make_pair(encodedText.data(), encodedText.size()), decodedResult);
decode(std::make_pair(encodedText.data(), encodedText.size()), decodedResult);
return decodedResult;
}

Expand All @@ -320,29 +319,45 @@ void Base64::decode(
const std::pair<const char*, int32_t>& payload,
std::string& decodedOutput) {
size_t inputSize = payload.second;
decodedOutput.resize(calculateDecodedSize(payload.first, inputSize));
decode(payload.first, inputSize, decodedOutput.data(), decodedOutput.size());
size_t decodedSize{0};
auto status = calculateDecodedSize(payload.first, inputSize, decodedSize);
if (!status.ok()) {
VELOX_USER_FAIL(status.message());
}
decodedOutput.resize(decodedSize);
status = decode(
payload.first, inputSize, decodedOutput.data(), decodedOutput.size());
if (!status.ok()) {
VELOX_USER_FAIL(status.message());
}
}

// static
void Base64::decode(const char* input, size_t size, char* output) {
size_t expectedOutputSize = size / 4 * 3;
Base64::decode(input, size, output, expectedOutputSize);
void Base64::decode(const char* input, size_t inputSize, char* outputBuffer) {
size_t outputSize{0};
if (auto status = calculateDecodedSize(input, inputSize, outputSize);
!status.ok() ||
!(status = decode(input, inputSize, outputBuffer, outputSize)).ok()) {
VELOX_USER_FAIL(status.message());
}
}

// static
uint8_t Base64::base64ReverseLookup(
Status Base64::base64ReverseLookup(
char encodedChar,
const Base64::ReverseIndex& reverseIndex) {
auto reverseLookupValue = reverseIndex[static_cast<uint8_t>(encodedChar)];
const ReverseIndex& reverseIndex,
uint8_t& reverseLookupValue) {
reverseLookupValue = reverseIndex[static_cast<uint8_t>(encodedChar)];
if (reverseLookupValue >= 0x40) {
VELOX_USER_FAIL("decode() - invalid input string: invalid characters");
return Status::UserError(fmt::format(
"decode() - invalid input string: invalid character '{}'",
encodedChar));
}
return reverseLookupValue;
return Status::OK();
}

// static
size_t Base64::decode(
Status Base64::decode(
const char* input,
size_t inputSize,
char* output,
Expand All @@ -352,63 +367,71 @@ size_t Base64::decode(
}

// static
size_t Base64::calculateDecodedSize(const char* input, size_t& inputSize) {
Status Base64::calculateDecodedSize(
const char* input,
size_t& inputSize,
size_t& decodedSize) {
if (inputSize == 0) {
return 0;
decodedSize = 0;
return Status::OK();
}

// Check if the input string is padded
if (isPadded(input, inputSize)) {
// If padded, ensure that the string length is a multiple of the encoded
// block size
if (inputSize % kEncodedBlockByteSize != 0) {
VELOX_USER_FAIL(
return Status::UserError(
"Base64::decode() - invalid input string: "
"string length is not a multiple of 4.");
}

auto decodedSize =
(inputSize * kBinaryBlockByteSize) / kEncodedBlockByteSize;
decodedSize = (inputSize * kBinaryBlockByteSize) / kEncodedBlockByteSize;
auto paddingCount = numPadding(input, inputSize);
inputSize -= paddingCount;

// Adjust the needed size by deducting the bytes corresponding to the
// padding from the calculated size.
return decodedSize -
decodedSize -=
((paddingCount * kBinaryBlockByteSize) + (kEncodedBlockByteSize - 1)) /
kEncodedBlockByteSize;
return Status::OK();
}
// If not padded, Calculate extra bytes, if any
auto extraBytes = inputSize % kEncodedBlockByteSize;
auto decodedSize = (inputSize / kEncodedBlockByteSize) * kBinaryBlockByteSize;
decodedSize = (inputSize / kEncodedBlockByteSize) * kBinaryBlockByteSize;

// Adjust the needed size for extra bytes, if present
if (extraBytes) {
if (extraBytes == 1) {
VELOX_USER_FAIL(
return Status::UserError(
"Base64::decode() - invalid input string: "
"string length cannot be 1 more than a multiple of 4.");
}
decodedSize += (extraBytes * kBinaryBlockByteSize) / kEncodedBlockByteSize;
}

return decodedSize;
return Status::OK();
}

// static
size_t Base64::decodeImpl(
Status Base64::decodeImpl(
const char* input,
size_t inputSize,
char* outputBuffer,
size_t outputSize,
const ReverseIndex& reverseIndex) {
if (!inputSize) {
return 0;
if (inputSize == 0) {
return Status::OK();
}

auto decodedSize = calculateDecodedSize(input, inputSize);
size_t decodedSize;
auto status = calculateDecodedSize(input, inputSize, decodedSize);
if (!status.ok()) {
return status;
}
if (outputSize < decodedSize) {
VELOX_USER_FAIL(
return Status::UserError(
"Base64::decode() - invalid output string: "
"output string is too small.");
}
Expand All @@ -418,32 +441,57 @@ size_t Base64::decodeImpl(
// Each character of the 4 encodes 6 bits of the original, grab each with
// the appropriate shifts to rebuild the original and then split that back
// into the original 8-bit bytes.
uint32_t decodedBlock =
(base64ReverseLookup(input[0], reverseIndex) << 18) |
(base64ReverseLookup(input[1], reverseIndex) << 12) |
(base64ReverseLookup(input[2], reverseIndex) << 6) |
base64ReverseLookup(input[3], reverseIndex);
outputBuffer[0] = (decodedBlock >> 16) & 0xff;
outputBuffer[1] = (decodedBlock >> 8) & 0xff;
outputBuffer[2] = decodedBlock & 0xff;
uint32_t decodedBlock = 0;
uint8_t reverseLookupValue;
for (int i = 0; i < 4; ++i) {
status = base64ReverseLookup(input[i], reverseIndex, reverseLookupValue);
if (!status.ok()) {
return status;
}
decodedBlock |= reverseLookupValue << (18 - 6 * i);
}
outputBuffer[0] = static_cast<char>((decodedBlock >> 16) & 0xff);
outputBuffer[1] = static_cast<char>((decodedBlock >> 8) & 0xff);
outputBuffer[2] = static_cast<char>(decodedBlock & 0xff);
}

// Handle the last 2-4 characters. This is similar to the above, but the
// last 2 characters may or may not exist.
DCHECK(inputSize >= 2);
uint32_t decodedBlock = (base64ReverseLookup(input[0], reverseIndex) << 18) |
(base64ReverseLookup(input[1], reverseIndex) << 12);
outputBuffer[0] = (decodedBlock >> 16) & 0xff;
if (inputSize > 2) {
decodedBlock |= base64ReverseLookup(input[2], reverseIndex) << 6;
outputBuffer[1] = (decodedBlock >> 8) & 0xff;
if (inputSize > 3) {
decodedBlock |= base64ReverseLookup(input[3], reverseIndex);
outputBuffer[2] = decodedBlock & 0xff;
if (inputSize >= 2) {
uint32_t decodedBlock = 0;
uint8_t reverseLookupValue;

// Process the first two characters
for (int i = 0; i < 2; ++i) {
status = base64ReverseLookup(input[i], reverseIndex, reverseLookupValue);
if (!status.ok()) {
return status;
}
decodedBlock |= reverseLookupValue << (18 - 6 * i);
}
outputBuffer[0] = static_cast<char>((decodedBlock >> 16) & 0xff);

if (inputSize > 2) {
status = base64ReverseLookup(input[2], reverseIndex, reverseLookupValue);
if (!status.ok()) {
return status;
}
decodedBlock |= reverseLookupValue << 6;
outputBuffer[1] = static_cast<char>((decodedBlock >> 8) & 0xff);

if (inputSize > 3) {
status =
base64ReverseLookup(input[3], reverseIndex, reverseLookupValue);
if (!status.ok()) {
return status;
}
decodedBlock |= reverseLookupValue;
outputBuffer[2] = static_cast<char>(decodedBlock & 0xff);
}
}
}

return decodedSize;
return Status::OK();
}

// static
Expand All @@ -462,19 +510,19 @@ std::string Base64::encodeUrl(const folly::IOBuf* inputBuffer) {
}

// static
void Base64::decodeUrl(
Status Base64::decodeUrl(
const char* input,
size_t inputSize,
char* outputBuffer,
size_t outputSize) {
decodeImpl(
return decodeImpl(
input, inputSize, outputBuffer, outputSize, kBase64UrlReverseIndexTable);
}

// static
std::string Base64::decodeUrl(folly::StringPiece encodedText) {
std::string decodedOutput;
Base64::decodeUrl(
decodeUrl(
std::make_pair(encodedText.data(), encodedText.size()), decodedOutput);
return decodedOutput;
}
Expand All @@ -483,15 +531,23 @@ std::string Base64::decodeUrl(folly::StringPiece encodedText) {
void Base64::decodeUrl(
const std::pair<const char*, int32_t>& payload,
std::string& decodedOutput) {
size_t decodedSize = (payload.second + 3) / 4 * 3;
decodedOutput.resize(decodedSize, '\0');
decodedSize = Base64::decodeImpl(
size_t inputSize = payload.second;
size_t decodedSize{0};
auto status = calculateDecodedSize(payload.first, inputSize, decodedSize);
if (!status.ok()) {
VELOX_USER_FAIL(status.message());
}

decodedOutput.resize(decodedSize);
status = decodeImpl(
payload.first,
payload.second,
&decodedOutput[0],
decodedSize,
decodedOutput.data(),
decodedOutput.size(),
kBase64UrlReverseIndexTable);
decodedOutput.resize(decodedSize);
if (!status.ok()) {
VELOX_USER_FAIL(status.message());
}
}

} // namespace facebook::velox::encoding
21 changes: 13 additions & 8 deletions velox/common/encode/Base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <string>

#include "velox/common/base/GTestMacros.h"
#include "velox/common/base/Status.h"

namespace facebook::velox::encoding {

Expand Down Expand Up @@ -86,7 +87,7 @@ class Base64 {

/// Decodes the specified number of characters from the 'input' and writes the
/// result to the 'outputBuffer'.
static size_t decode(
static Status decode(
const char* input,
size_t inputSize,
char* outputBuffer,
Expand All @@ -103,7 +104,7 @@ class Base64 {

/// Decodes the specified number of characters from the 'input' using URL
/// encoding and writes the result to the 'outputBuffer'
static void decodeUrl(
static Status decodeUrl(
const char* input,
size_t inputSize,
char* outputBuffer,
Expand All @@ -112,9 +113,12 @@ class Base64 {
/// Calculates the encoded size based on input 'inputSize'.
static size_t calculateEncodedSize(size_t inputSize, bool withPadding = true);

/// Returns the actual size of the decoded data. Removes the padding
/// length from the input data 'inputSize'.
static size_t calculateDecodedSize(const char* input, size_t& inputSize);
/// Calculates the decoded size based on encoded input and adjusts the input
/// size for padding.
static Status calculateDecodedSize(
const char* input,
size_t& inputSize,
size_t& decodedSize);

private:
// Padding character used in encoding.
Expand All @@ -137,9 +141,10 @@ class Base64 {

// Reverse lookup helper function to get the original index of a Base64
// character.
static uint8_t base64ReverseLookup(
static Status base64ReverseLookup(
char encodedChar,
const ReverseIndex& reverseIndex);
const ReverseIndex& reverseIndex,
uint8_t& reverseLookupValue);

// Encodes the specified data using the provided charset.
template <class T>
Expand All @@ -155,7 +160,7 @@ class Base64 {
char* outputBuffer);

// Decodes the specified data using the provided reverse lookup table.
static size_t decodeImpl(
static Status decodeImpl(
const char* input,
size_t inputSize,
char* outputBuffer,
Expand Down
2 changes: 1 addition & 1 deletion velox/common/encode/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ if(${VELOX_BUILD_TESTING})
endif()

velox_add_library(velox_encode Base64.cpp)
velox_link_libraries(velox_encode PUBLIC Folly::folly)
velox_link_libraries(velox_encode PUBLIC velox_status Folly::folly)
Loading
Loading