Skip to content

Commit

Permalink
refactor: Base64 APIs as non-throwing APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe-Abraham committed Jan 30, 2025
1 parent 5f8e7f1 commit 3412318
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 100 deletions.
170 changes: 113 additions & 57 deletions velox/common/encode/Base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include <folly/Portability.h>
#include <folly/container/Foreach.h>
#include <folly/io/Cursor.h>
#include <stdint.h>
#include <cstdint>

#include "velox/common/base/Exceptions.h"

Expand Down Expand Up @@ -163,7 +163,7 @@ std::string Base64::encodeImpl(
const T& input,
const Charset& charset,
bool includePadding) {
size_t encodedSize = calculateEncodedSize(input.size(), includePadding);
// The encoded size depends on this call's 'input' and 'includePadding', so it
// must be an automatic local. A function-local 'static' is initialized only
// the first time control passes through the declaration; every later call
// would reuse that stale size when resizing 'encodedResult', truncating the
// output or letting the encoder write past the end of the buffer.
const size_t encodedSize{calculateEncodedSize(input.size(), includePadding)};
std::string encodedResult;
encodedResult.resize(encodedSize);
encodeImpl(input, charset, includePadding, encodedResult.data());
Expand Down Expand Up @@ -310,8 +310,7 @@ std::string Base64::encode(const folly::IOBuf* inputBuffer) {
// static
// Convenience overload: decodes 'encodedText' and returns the decoded bytes as
// a newly built std::string. The work is delegated to the
// (payload pair, output string) overload, which fills 'decodedResult'.
// NOTE(review): this listing comes from a rendered diff without +/- markers;
// the two delegating calls below look like the superseded and the replacement
// form of the same statement (with and without the 'Base64::' qualifier) —
// confirm which one is current before reusing this code.
std::string Base64::decode(folly::StringPiece encodedText) {
std::string decodedResult;
Base64::decode(
std::make_pair(encodedText.data(), encodedText.size()), decodedResult);
decode(std::make_pair(encodedText.data(), encodedText.size()), decodedResult);
return decodedResult;
}

Expand All @@ -320,29 +319,45 @@ void Base64::decode(
const std::pair<const char*, int32_t>& payload,
std::string& decodedOutput) {
size_t inputSize = payload.second;
decodedOutput.resize(calculateDecodedSize(payload.first, inputSize));
decode(payload.first, inputSize, decodedOutput.data(), decodedOutput.size());
size_t decodedSize{0};
auto status = calculateDecodedSize(payload.first, inputSize, decodedSize);
if (!status.ok()) {
VELOX_USER_FAIL(status.message());
}
decodedOutput.resize(decodedSize);
status = decode(
payload.first, inputSize, decodedOutput.data(), decodedOutput.size());
if (!status.ok()) {
VELOX_USER_FAIL(status.message());
}
}

// static
// Decodes 'inputSize' characters from 'input' into 'outputBuffer'. The caller
// does not pass an output size; it is derived from the input via
// calculateDecodedSize().
// NOTE(review): this listing comes from a rendered diff without +/- markers;
// the first signature and the 'expectedOutputSize' lines appear to be the
// superseded form of this function, followed by its replacement — confirm.
void Base64::decode(const char* input, size_t size, char* output) {
size_t expectedOutputSize = size / 4 * 3;
Base64::decode(input, size, output, expectedOutputSize);
void Base64::decode(const char* input, size_t inputSize, char* outputBuffer) {
size_t outputSize{0};
// 'status' is declared in the if-initializer so it is visible in the body.
// The '||' short-circuits: the 4-argument decode() runs only when the size
// calculation succeeded, and 'status' then holds whichever step failed.
if (auto status = calculateDecodedSize(input, inputSize, outputSize);
!status.ok() ||
!(status = decode(input, inputSize, outputBuffer, outputSize)).ok()) {
// This overload returns void, so a non-OK status is raised as a user error.
VELOX_USER_FAIL(status.message());
}
}

// static
// Looks up 'encodedChar' in 'reverseIndex' and stores the table entry in
// 'reverseLookupValue'. Returns Status::OK() when the entry is below 0x40 and
// a UserError status naming the offending character otherwise.
// NOTE(review): this listing comes from a rendered diff without +/- markers;
// the uint8_t-returning signature, the 'auto reverseLookupValue' line, the
// VELOX_USER_FAIL call and 'return reverseLookupValue;' appear to be the
// superseded form, interleaved with the Status-returning replacement —
// confirm which lines are current.
uint8_t Base64::base64ReverseLookup(
Status Base64::base64ReverseLookup(
char encodedChar,
const Base64::ReverseIndex& reverseIndex) {
auto reverseLookupValue = reverseIndex[static_cast<uint8_t>(encodedChar)];
const ReverseIndex& reverseIndex,
uint8_t& reverseLookupValue) {
// The cast to uint8_t keeps the index non-negative even where plain 'char'
// is signed and the character has its high bit set.
reverseLookupValue = reverseIndex[static_cast<uint8_t>(encodedChar)];
// 0x40 is 64; entries at or above it are rejected as invalid characters.
if (reverseLookupValue >= 0x40) {
VELOX_USER_FAIL("decode() - invalid input string: invalid characters");
return Status::UserError(fmt::format(
"decode() - invalid input string: invalid character '{}'",
encodedChar));
}
return reverseLookupValue;
return Status::OK();
}

// static
size_t Base64::decode(
Status Base64::decode(
const char* input,
size_t inputSize,
char* output,
Expand All @@ -352,63 +367,71 @@ size_t Base64::decode(
}

// static
// Computes how many bytes the first 'inputSize' characters of 'input' decode
// to and stores the result in 'decodedSize'.
//
// 'inputSize' is taken by reference: when the input is padded, it is reduced
// by the number of padding characters, so on return it counts only the
// non-padding characters.
//
// Returns Status::OK() on success, or a UserError status when a padded
// input's length is not a multiple of kEncodedBlockByteSize, or when an
// unpadded input's length leaves a remainder of exactly 1.
// NOTE(review): this listing comes from a rendered diff without +/- markers;
// the size_t-returning signature, 'return 0;', the VELOX_USER_FAIL( lines,
// the 'auto decodedSize' declarations and 'return decodedSize' lines appear
// to be the superseded form, interleaved with the Status-returning
// replacement — confirm which lines are current.
size_t Base64::calculateDecodedSize(const char* input, size_t& inputSize) {
Status Base64::calculateDecodedSize(
const char* input,
size_t& inputSize,
size_t& decodedSize) {
// Empty input decodes to zero bytes.
if (inputSize == 0) {
return 0;
decodedSize = 0;
return Status::OK();
}

// Check if the input string is padded
if (isPadded(input, inputSize)) {
// If padded, ensure that the string length is a multiple of the encoded
// block size
if (inputSize % kEncodedBlockByteSize != 0) {
VELOX_USER_FAIL(
return Status::UserError(
"Base64::decode() - invalid input string: "
"string length is not a multiple of 4.");
}

// Size as if every encoded character, padding included, carried data.
auto decodedSize =
(inputSize * kBinaryBlockByteSize) / kEncodedBlockByteSize;
decodedSize = (inputSize * kBinaryBlockByteSize) / kEncodedBlockByteSize;
auto paddingCount = numPadding(input, inputSize);
// Drop the padding from the caller-visible input size.
inputSize -= paddingCount;

// Adjust the needed size by deducting the bytes corresponding to the
// padding from the calculated size.
// Adding (kEncodedBlockByteSize - 1) before dividing rounds the deduction up.
return decodedSize -
decodedSize -=
((paddingCount * kBinaryBlockByteSize) + (kEncodedBlockByteSize - 1)) /
kEncodedBlockByteSize;
return Status::OK();
}
// If not padded, Calculate extra bytes, if any
auto extraBytes = inputSize % kEncodedBlockByteSize;
auto decodedSize = (inputSize / kEncodedBlockByteSize) * kBinaryBlockByteSize;
decodedSize = (inputSize / kEncodedBlockByteSize) * kBinaryBlockByteSize;

// Adjust the needed size for extra bytes, if present
if (extraBytes) {
// A single leftover character is rejected as invalid input.
if (extraBytes == 1) {
VELOX_USER_FAIL(
return Status::UserError(
"Base64::decode() - invalid input string: "
"string length cannot be 1 more than a multiple of 4.");
}
// Integer division rounds down here.
decodedSize += (extraBytes * kBinaryBlockByteSize) / kEncodedBlockByteSize;
}

return decodedSize;
return Status::OK();
}

// static
size_t Base64::decodeImpl(
Status Base64::decodeImpl(
const char* input,
size_t inputSize,
char* outputBuffer,
size_t outputSize,
const ReverseIndex& reverseIndex) {
if (!inputSize) {
return 0;
if (inputSize == 0) {
return Status::OK();
}

auto decodedSize = calculateDecodedSize(input, inputSize);
size_t decodedSize;
auto status = calculateDecodedSize(input, inputSize, decodedSize);
if (!status.ok()) {
return status;
}
if (outputSize < decodedSize) {
VELOX_USER_FAIL(
return Status::UserError(
"Base64::decode() - invalid output string: "
"output string is too small.");
}
Expand All @@ -418,32 +441,57 @@ size_t Base64::decodeImpl(
// Each character of the 4 encodes 6 bits of the original, grab each with
// the appropriate shifts to rebuild the original and then split that back
// into the original 8-bit bytes.
uint32_t decodedBlock =
(base64ReverseLookup(input[0], reverseIndex) << 18) |
(base64ReverseLookup(input[1], reverseIndex) << 12) |
(base64ReverseLookup(input[2], reverseIndex) << 6) |
base64ReverseLookup(input[3], reverseIndex);
outputBuffer[0] = (decodedBlock >> 16) & 0xff;
outputBuffer[1] = (decodedBlock >> 8) & 0xff;
outputBuffer[2] = decodedBlock & 0xff;
uint32_t decodedBlock = 0;
uint8_t reverseLookupValue;
for (int i = 0; i < 4; ++i) {
status = base64ReverseLookup(input[i], reverseIndex, reverseLookupValue);
if (!status.ok()) {
return status;
}
decodedBlock |= reverseLookupValue << (18 - 6 * i);
}
outputBuffer[0] = static_cast<char>((decodedBlock >> 16) & 0xff);
outputBuffer[1] = static_cast<char>((decodedBlock >> 8) & 0xff);
outputBuffer[2] = static_cast<char>(decodedBlock & 0xff);
}

// Handle the last 2-4 characters. This is similar to the above, but the
// last 2 characters may or may not exist.
DCHECK(inputSize >= 2);
uint32_t decodedBlock = (base64ReverseLookup(input[0], reverseIndex) << 18) |
(base64ReverseLookup(input[1], reverseIndex) << 12);
outputBuffer[0] = (decodedBlock >> 16) & 0xff;
if (inputSize > 2) {
decodedBlock |= base64ReverseLookup(input[2], reverseIndex) << 6;
outputBuffer[1] = (decodedBlock >> 8) & 0xff;
if (inputSize > 3) {
decodedBlock |= base64ReverseLookup(input[3], reverseIndex);
outputBuffer[2] = decodedBlock & 0xff;
if (inputSize >= 2) {
uint32_t decodedBlock = 0;
uint8_t reverseLookupValue;

// Process the first two characters
for (int i = 0; i < 2; ++i) {
status = base64ReverseLookup(input[i], reverseIndex, reverseLookupValue);
if (!status.ok()) {
return status;
}
decodedBlock |= reverseLookupValue << (18 - 6 * i);
}
outputBuffer[0] = static_cast<char>((decodedBlock >> 16) & 0xff);

if (inputSize > 2) {
status = base64ReverseLookup(input[2], reverseIndex, reverseLookupValue);
if (!status.ok()) {
return status;
}
decodedBlock |= reverseLookupValue << 6;
outputBuffer[1] = static_cast<char>((decodedBlock >> 8) & 0xff);

if (inputSize > 3) {
status =
base64ReverseLookup(input[3], reverseIndex, reverseLookupValue);
if (!status.ok()) {
return status;
}
decodedBlock |= reverseLookupValue;
outputBuffer[2] = static_cast<char>(decodedBlock & 0xff);
}
}
}

return decodedSize;
return Status::OK();
}

// static
Expand All @@ -462,19 +510,19 @@ std::string Base64::encodeUrl(const folly::IOBuf* inputBuffer) {
}

// static
// Decodes 'inputSize' characters from 'input' into 'outputBuffer' (capacity
// 'outputSize') by calling decodeImpl() with kBase64UrlReverseIndexTable, and
// returns the Status produced by decodeImpl().
// NOTE(review): this listing comes from a rendered diff without +/- markers;
// the 'void' signature and the bare 'decodeImpl(' line appear to be the
// superseded form, followed by the Status-returning replacement — confirm.
void Base64::decodeUrl(
Status Base64::decodeUrl(
const char* input,
size_t inputSize,
char* outputBuffer,
size_t outputSize) {
decodeImpl(
return decodeImpl(
input, inputSize, outputBuffer, outputSize, kBase64UrlReverseIndexTable);
}

// static
// Convenience overload: decodes 'encodedText' with the URL variant and returns
// the decoded bytes as a newly built std::string. The work is delegated to the
// (payload pair, output string) overload, which fills 'decodedOutput'.
// NOTE(review): this listing comes from a rendered diff without +/- markers;
// 'Base64::decodeUrl(' and 'decodeUrl(' look like the superseded and the
// replacement opening of the same call — confirm which one is current.
std::string Base64::decodeUrl(folly::StringPiece encodedText) {
std::string decodedOutput;
Base64::decodeUrl(
decodeUrl(
std::make_pair(encodedText.data(), encodedText.size()), decodedOutput);
return decodedOutput;
}
Expand All @@ -483,15 +531,23 @@ std::string Base64::decodeUrl(folly::StringPiece encodedText) {
// Decodes the URL-variant base64 text described by 'payload' (pointer, length)
// into 'decodedOutput', resizing the string to the decoded size. Any non-OK
// status from the size calculation or from decodeImpl() is raised via
// VELOX_USER_FAIL, since this overload returns void.
// NOTE(review): this listing comes from a rendered diff without +/- markers;
// the '(payload.second + 3) / 4 * 3' sizing, the 'resize(decodedSize, '\0')'
// call, the 'decodedSize = Base64::decodeImpl(' line, the '&decodedOutput[0]'
// and 'decodedSize' arguments and the trailing 'resize(decodedSize)' appear to
// be the superseded form, interleaved with the replacement — confirm which
// lines are current.
void Base64::decodeUrl(
const std::pair<const char*, int32_t>& payload,
std::string& decodedOutput) {
size_t decodedSize = (payload.second + 3) / 4 * 3;
decodedOutput.resize(decodedSize, '\0');
decodedSize = Base64::decodeImpl(
// NOTE(review): 'payload.second' is an int32_t converted to size_t here; a
// negative length would become a very large value — verify callers never pass
// one.
size_t inputSize = payload.second;
size_t decodedSize{0};
// calculateDecodedSize() may shrink the local 'inputSize' (it takes it by
// reference); decodeImpl() below is still given the original 'payload.second'.
auto status = calculateDecodedSize(payload.first, inputSize, decodedSize);
if (!status.ok()) {
VELOX_USER_FAIL(status.message());
}

decodedOutput.resize(decodedSize);
status = decodeImpl(
payload.first,
payload.second,
&decodedOutput[0],
decodedSize,
decodedOutput.data(),
decodedOutput.size(),
kBase64UrlReverseIndexTable);
decodedOutput.resize(decodedSize);
if (!status.ok()) {
VELOX_USER_FAIL(status.message());
}
}

} // namespace facebook::velox::encoding
21 changes: 13 additions & 8 deletions velox/common/encode/Base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <string>

#include "velox/common/base/GTestMacros.h"
#include "velox/common/base/Status.h"

namespace facebook::velox::encoding {

Expand Down Expand Up @@ -86,7 +87,7 @@ class Base64 {

/// Decodes the specified number of characters from the 'input' and writes the
/// result to the 'outputBuffer'.
static size_t decode(
static Status decode(
const char* input,
size_t inputSize,
char* outputBuffer,
Expand All @@ -103,7 +104,7 @@ class Base64 {

/// Decodes the specified number of characters from the 'input' using URL
/// encoding and writes the result to the 'outputBuffer'
static void decodeUrl(
static Status decodeUrl(
const char* input,
size_t inputSize,
char* outputBuffer,
Expand All @@ -112,9 +113,12 @@ class Base64 {
/// Calculates the encoded size based on input 'inputSize'.
static size_t calculateEncodedSize(size_t inputSize, bool withPadding = true);

/// Returns the actual size of the decoded data. Removes the padding
/// length from the input data 'inputSize'.
static size_t calculateDecodedSize(const char* input, size_t& inputSize);
/// Calculates the decoded size based on encoded input and adjusts the input
/// size for padding.
static Status calculateDecodedSize(
const char* input,
size_t& inputSize,
size_t& decodedSize);

private:
// Padding character used in encoding.
Expand All @@ -137,9 +141,10 @@ class Base64 {

// Reverse lookup helper function to get the original index of a Base64
// character.
static uint8_t base64ReverseLookup(
static Status base64ReverseLookup(
char encodedChar,
const ReverseIndex& reverseIndex);
const ReverseIndex& reverseIndex,
uint8_t& reverseLookupValue);

// Encodes the specified data using the provided charset.
template <class T>
Expand All @@ -155,7 +160,7 @@ class Base64 {
char* outputBuffer);

// Decodes the specified data using the provided reverse lookup table.
static size_t decodeImpl(
static Status decodeImpl(
const char* input,
size_t inputSize,
char* outputBuffer,
Expand Down
2 changes: 1 addition & 1 deletion velox/common/encode/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ if(${VELOX_BUILD_TESTING})
endif()

velox_add_library(velox_encode Base64.cpp)
velox_link_libraries(velox_encode PUBLIC Folly::folly)
velox_link_libraries(velox_encode PUBLIC velox_status Folly::folly)
Loading

0 comments on commit 3412318

Please sign in to comment.