Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement from_json_to_structs #2510

Open
wants to merge 67 commits into
base: branch-24.12
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
1376061
Implement `castStringsToBooleans`
ttnghia Oct 16, 2024
ff2f340
Merge branch 'branch-24.12' into convert_table
ttnghia Oct 16, 2024
c3fa10d
Implement `removeQuotes`
ttnghia Oct 16, 2024
ae2b41f
Rewrite using offsets and chars
ttnghia Oct 16, 2024
8d7ad2e
Fix empty input
ttnghia Oct 17, 2024
9e759c4
Misc
ttnghia Oct 17, 2024
2fff949
Add `nullifyIfNotQuoted` option for `removeQuotes`
ttnghia Oct 17, 2024
d09de41
Implement `castStringsToDecimals`
ttnghia Oct 18, 2024
576b65c
Implement `removeQuotesForFloats`
ttnghia Oct 18, 2024
2bd5335
Fix `removeQuotesForFloats`
ttnghia Oct 18, 2024
21c80a5
Implement `castStringsToIntegers`
ttnghia Oct 18, 2024
1a7d192
Implement non-legacy `castStringsToDates`
ttnghia Oct 18, 2024
dcb463e
WIP for `cast_strings_to_dates_legacy`
ttnghia Oct 21, 2024
f059c21
Revert "WIP for `cast_strings_to_dates_legacy`"
ttnghia Oct 21, 2024
207d6a3
Merge branch 'branch-24.12' into convert_table
ttnghia Oct 23, 2024
07b23ea
Fix compile issues
ttnghia Oct 23, 2024
de83a25
WIP: Implement `from_json_to_structs`
ttnghia Oct 24, 2024
443ca38
Merge branch 'branch-24.12' into convert_table
ttnghia Oct 24, 2024
6c2bd5e
Fix cmake
ttnghia Oct 24, 2024
904d857
Fix compile issues
ttnghia Oct 24, 2024
d84f1fe
Implement `castStringsToFloats`
ttnghia Oct 24, 2024
3024583
WIP
ttnghia Oct 24, 2024
d33d8e2
WIP: Implementing `fromJSONToStructs`
ttnghia Oct 25, 2024
295c36c
Merge branch 'branch-24.12' into convert_table
ttnghia Oct 28, 2024
1ea9cc8
Fix compile errors
ttnghia Oct 29, 2024
c1bb2d4
Cleanup
ttnghia Oct 29, 2024
f6634b4
Revert code as we still need them
ttnghia Oct 29, 2024
06b2c19
Add error check
ttnghia Oct 29, 2024
2dcdd11
Add more comments
ttnghia Oct 29, 2024
f3c391b
Cleanup
ttnghia Oct 29, 2024
52c42a6
Return as-is if the column is date/time
ttnghia Oct 29, 2024
19c64be
Update test
ttnghia Oct 30, 2024
cb9d252
Merge branch 'branch-24.12' into convert_table
ttnghia Oct 30, 2024
5d07db1
Update cudf
ttnghia Oct 30, 2024
39e3a9b
Revert "Update cudf"
ttnghia Oct 30, 2024
8628136
Merge branch 'branch-24.12' into convert_table
ttnghia Oct 30, 2024
df1428d
Update cudf
ttnghia Oct 30, 2024
0fd8d0e
Merge branch 'branch-24.12' into convert_table
ttnghia Nov 8, 2024
1d48906
Update cudf
ttnghia Nov 8, 2024
d9e1db5
Change header
ttnghia Nov 9, 2024
0f053a6
Rewrite JSONUtils.cpp
ttnghia Nov 9, 2024
8912e00
Implement a common function for converting column
ttnghia Nov 12, 2024
3614718
Rewrite `convert_data_type`
ttnghia Nov 12, 2024
6d9bbdc
Remove `cast_strings_to_dates`
ttnghia Nov 12, 2024
a832938
Implement `convert_data_type`
ttnghia Nov 13, 2024
44b885b
Fix compile errors
ttnghia Nov 13, 2024
ab45de8
Add `CUDF_FUNC_RANGE();`
ttnghia Nov 13, 2024
89e74a0
Fix schema
ttnghia Nov 13, 2024
27ef532
Complete `from_json_to_structs`
ttnghia Nov 13, 2024
5b65712
Fix null mask
ttnghia Nov 13, 2024
6788471
Write Javadoc
ttnghia Nov 13, 2024
49c78ce
Rewrite JNI
ttnghia Nov 13, 2024
9d16d43
Merge branch 'branch-24.12' into convert_table
ttnghia Nov 13, 2024
bb9029b
Remove deprecated function
ttnghia Nov 14, 2024
1243599
Revert test
ttnghia Nov 14, 2024
6f89fcd
Remove header
ttnghia Nov 14, 2024
deb3ebf
Rewrite Javadoc
ttnghia Nov 14, 2024
9dc641f
Rename variable
ttnghia Nov 14, 2024
53b121d
Rewrite docs
ttnghia Nov 14, 2024
69265b4
Revert test
ttnghia Nov 14, 2024
da4d1f6
Cleanup headers
ttnghia Nov 14, 2024
1d91e64
Cleanup
ttnghia Nov 14, 2024
d0fa2ae
Rewrite the conversion functions
ttnghia Nov 14, 2024
f375a4d
Move code
ttnghia Nov 14, 2024
034a5ec
Remove call to `make_structs_column`
ttnghia Nov 14, 2024
74d858c
Cleanup
ttnghia Nov 14, 2024
7a32b6f
Merge branch 'branch-24.12' into convert_table
ttnghia Nov 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions src/main/cpp/src/JSONUtilsJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,31 @@ Java_com_nvidia_spark_rapids_jni_JSONUtils_castStringsToBooleans(JNIEnv* env, jc
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_JSONUtils_removeQuotes(JNIEnv* env,
jclass,
jlong j_input)
JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_JSONUtils_castStringsToDecimals(
JNIEnv* env, jclass, jlong j_input, jint precision, jint scale, jboolean is_us_locale)
{
JNI_NULL_CHECK(env, j_input, "j_input is null", 0);

try {
cudf::jni::auto_set_device(env);
auto const input = *reinterpret_cast<cudf::column_view const*>(j_input);
return cudf::jni::ptr_as_jlong(spark_rapids_jni::remove_quotes(input).release());

return cudf::jni::ptr_as_jlong(
spark_rapids_jni::cast_strings_to_decimals(input, precision, scale, is_us_locale).release());
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_JSONUtils_removeQuotes(
JNIEnv* env, jclass, jlong j_input, jboolean nullify_if_not_quoted)
{
JNI_NULL_CHECK(env, j_input, "j_input is null", 0);

try {
cudf::jni::auto_set_device(env);
auto const input = *reinterpret_cast<cudf::column_view const*>(j_input);
return cudf::jni::ptr_as_jlong(
spark_rapids_jni::remove_quotes(input, nullify_if_not_quoted).release());
}
CATCH_STD(env, 0);
}
Expand Down
229 changes: 212 additions & 17 deletions src/main/cpp/src/json_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include "cast_string.hpp"
#include "json_utils.hpp"

#include <cudf/column/column_device_view.cuh>
Expand All @@ -33,6 +34,7 @@

#include <cub/device/device_histogram.cuh>
#include <cub/device/device_memcpy.cuh>
#include <cub/device/device_segmented_reduce.cuh>
#include <cuda/functional>
#include <thrust/find.h>
#include <thrust/functional.h>
Expand All @@ -43,6 +45,8 @@
#include <thrust/tuple.h>
#include <thrust/uninitialized_fill.h>

#include <limits>

namespace spark_rapids_jni {

namespace detail {
Expand Down Expand Up @@ -276,7 +280,7 @@ std::pair<std::unique_ptr<cudf::column>, rmm::device_uvector<bool>> cast_strings

auto output = cudf::make_fixed_width_column(
cudf::data_type{cudf::type_id::BOOL8}, string_count, cudf::mask_state::UNALLOCATED, stream, mr);
auto validity = rmm::device_uvector<bool>(string_count, stream); // intentionally not use `mr`
auto validity = rmm::device_uvector<bool>(string_count, stream);
ttnghia marked this conversation as resolved.
Show resolved Hide resolved

auto const input_sv = cudf::strings_column_view{input};
auto const offsets_it =
Expand Down Expand Up @@ -310,6 +314,9 @@ std::pair<std::unique_ptr<cudf::column>, rmm::device_uvector<bool>> cast_strings
return {false, false};
});

// Reset null count, as it is invalidated after calling to `mutable_view()`.
output->set_null_mask(rmm::device_buffer{0, stream, mr}, 0);

return {std::move(output), std::move(validity)};
}

Expand Down Expand Up @@ -356,8 +363,161 @@ rmm::device_uvector<char> make_chars_buffer(cudf::column_view const& offsets,
return chars_data;
}

// TODO there is a bug here around 0 https://github.com/NVIDIA/spark-rapids/issues/10898
std::unique_ptr<cudf::column> cast_strings_to_decimals(cudf::column_view const& input,
int precision,
int scale,
bool is_us_locale,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
auto const string_count = input.size();
if (string_count == 0) {
auto const dtype = [precision, scale]() {
if (precision <= std::numeric_limits<int32_t>::digits10) {
return cudf::data_type(cudf::type_id::DECIMAL32, scale);
} else if (precision <= std::numeric_limits<int64_t>::digits10) {
return cudf::data_type(cudf::type_id::DECIMAL64, scale);
} else if (precision <= std::numeric_limits<__int128_t>::digits10) {
return cudf::data_type(cudf::type_id::DECIMAL128, scale);
} else {
CUDF_FAIL("Unable to support decimal with precision " + std::to_string(precision));
}
}();
return cudf::make_empty_column(dtype);
}

CUDF_EXPECTS(is_us_locale, "String to decimal conversion is only supported in US locale.");

auto const input_sv = cudf::strings_column_view{input};
auto const in_offsets =
cudf::detail::offsetalator_factory::make_input_iterator(input_sv.offsets());

// Count the number of characters `"`.
rmm::device_uvector<int8_t> quote_counts(string_count, stream);
// Count the number of characters `"` and `,` in each string.
rmm::device_uvector<int8_t> remove_counts(string_count, stream);

{
using count_type = thrust::tuple<int8_t, int8_t>;
auto const check_it = cudf::detail::make_counting_transform_iterator(
0,
cuda::proclaim_return_type<count_type>(
[chars = input_sv.chars_begin(stream)] __device__(auto idx) {
auto const c = chars[idx];
auto const is_quote = c == '"';
auto const should_remove = is_quote || c == ',';
return count_type{static_cast<int8_t>(is_quote), static_cast<int8_t>(should_remove)};
}));
auto const plus_op =
cuda::proclaim_return_type<count_type>([] __device__(count_type lhs, count_type rhs) {
return count_type{thrust::get<0>(lhs) + thrust::get<0>(rhs),
thrust::get<1>(lhs) + thrust::get<1>(rhs)};
});

auto const out_count_it =
thrust::make_zip_iterator(quote_counts.begin(), remove_counts.begin());

std::size_t temp_storage_bytes = 0;
cub::DeviceSegmentedReduce::Reduce(nullptr,
temp_storage_bytes,
check_it,
out_count_it,
string_count,
in_offsets,
in_offsets + 1,
plus_op,
count_type{0, 0},
stream.value());
auto d_temp_storage = rmm::device_buffer{temp_storage_bytes, stream};
cub::DeviceSegmentedReduce::Reduce(d_temp_storage.data(),
temp_storage_bytes,
check_it,
out_count_it,
string_count,
in_offsets,
in_offsets + 1,
plus_op,
count_type{0, 0},
stream.value());
}

auto const out_size_it = cudf::detail::make_counting_transform_iterator(
0,
cuda::proclaim_return_type<cudf::size_type>(
[offsets = in_offsets,
quote_counts = quote_counts.begin(),
remove_counts = remove_counts.begin()] __device__(auto idx) {
auto const input_size = offsets[idx + 1] - offsets[idx];
// If the current row is a non-quoted string, just return the original string.
if (quote_counts[idx] == 0) { return static_cast<cudf::size_type>(input_size); }
// Otherwise, we will modify the string, removing characters '"' and ','.
return static_cast<cudf::size_type>(input_size - remove_counts[idx]);
}));
auto [offsets_column, bytes] = cudf::strings::detail::make_offsets_child_column(
out_size_it, out_size_it + string_count, stream, mr);

// If the output strings column does not change in its total bytes, we know that it does not have
// any '"' or ',' characters.
if (bytes == input_sv.chars_size(stream)) {
return string_to_decimal(precision, scale, input_sv, false, false, stream, mr);
}

auto const out_offsets =
cudf::detail::offsetalator_factory::make_input_iterator(offsets_column->view());
auto chars_data = rmm::device_uvector<char>(bytes, stream, mr);

// Since the strings store decimal numbers, they should be very short.
// As such, using one thread per string should be good.
thrust::for_each(rmm::exec_policy_nosync(stream),
thrust::make_counting_iterator(0),
thrust::make_counting_iterator(string_count),
[in_offsets,
out_offsets,
input = input_sv.chars_begin(stream),
output = chars_data.begin()] __device__(auto idx) {
auto const in_size = in_offsets[idx + 1] - in_offsets[idx];
auto const out_size = out_offsets[idx + 1] - out_offsets[idx];
if (in_size == 0) { return; }

// If the output size is not changed, we are returning the original unquoted
// string. Such string may still contain other alphabet characters, but that
// should be handled in the conversion function later on.
if (in_size == out_size) {
memcpy(output + out_offsets[idx], input + in_offsets[idx], in_size);
} else { // copy byte by byte, ignoring '"' and ',' characters.
auto in_ptr = input + in_offsets[idx];
auto in_end = input + in_offsets[idx + 1];
auto out_ptr = output + out_offsets[idx];
while (in_ptr != in_end) {
if (*in_ptr != '"' && *in_ptr != ',') {
*out_ptr = *in_ptr;
++out_ptr;
}
++in_ptr;
}
}
});

auto const unquoted_strings = cudf::make_strings_column(string_count,
std::move(offsets_column),
chars_data.release(),
0,
rmm::device_buffer{0, stream, mr});
return string_to_decimal(precision,
scale,
cudf::strings_column_view{unquoted_strings->view()},
false,
false,
stream,
mr);
}

std::pair<std::unique_ptr<cudf::column>, rmm::device_uvector<bool>> remove_quotes(
cudf::column_view const& input, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr)
cudf::column_view const& input,
bool nullify_if_not_quoted,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
auto const string_count = input.size();
if (string_count == 0) {
Expand All @@ -375,7 +535,8 @@ std::pair<std::unique_ptr<cudf::column>, rmm::device_uvector<bool>> remove_quote
thrust::tabulate(rmm::exec_policy_nosync(stream),
string_pairs.begin(),
string_pairs.end(),
[chars = input_sv.chars_begin(stream),
[nullify_if_not_quoted,
chars = input_sv.chars_begin(stream),
offsets = input_offsets_it,
is_valid = is_valid_it] __device__(cudf::size_type idx) -> string_index_pair {
if (!is_valid[idx]) { return {nullptr, 0}; }
Expand All @@ -387,7 +548,9 @@ std::pair<std::unique_ptr<cudf::column>, rmm::device_uvector<bool>> remove_quote

// Need to check for size, since the input string may contain just a single
// character `"`. Such input should not be considered as quoted.
auto const is_quoted = size > 1 && str[0] == '"' && str[size - 1] == '"';
auto const is_quoted = size > 1 && str[0] == '"' && str[size - 1] == '"';
if (nullify_if_not_quoted && !is_quoted) { return {nullptr, 0}; }

auto const output_size = is_quoted ? size - 2 : size;
return {chars + start_offset + (is_quoted ? 1 : 0), output_size};
});
Expand All @@ -403,14 +566,32 @@ std::pair<std::unique_ptr<cudf::column>, rmm::device_uvector<bool>> remove_quote
auto chars_data = /*cudf::strings::detail::*/ make_chars_buffer(
offsets_column->view(), bytes, string_pairs.begin(), string_count, stream, mr);

auto output = cudf::make_strings_column(string_count,
std::move(offsets_column),
chars_data.release(),
input.null_count(),
cudf::detail::copy_bitmask(input, stream, mr));
if (nullify_if_not_quoted) {
auto validity = rmm::device_uvector<bool>(string_count, stream);
thrust::transform(
rmm::exec_policy_nosync(stream),
string_pairs.begin(),
string_pairs.end(),
validity.begin(),
[] __device__(string_index_pair const& pair) { return pair.first != nullptr; });

// Null mask and null count will be updated later from the validity vector.
auto output = cudf::make_strings_column(string_count,
std::move(offsets_column),
chars_data.release(),
0,
rmm::device_buffer{0, stream, mr});

return {std::move(output), std::move(validity)};
} else {
auto output = cudf::make_strings_column(string_count,
std::move(offsets_column),
chars_data.release(),
input.null_count(),
cudf::detail::copy_bitmask(input, stream, mr));

// This function does not return the validity vector.
return {std::move(output), rmm::device_uvector<bool>(0, stream)};
return {std::move(output), rmm::device_uvector<bool>(0, stream)};
}
}

std::unique_ptr<cudf::column> convert_column_type(cudf::column_view const& input,
Expand Down Expand Up @@ -480,21 +661,35 @@ std::unique_ptr<cudf::column> cast_strings_to_booleans(cudf::column_view const&
auto [output, validity] = detail::cast_strings_to_booleans(input, stream, mr);
auto [null_mask, null_count] =
cudf::detail::valid_if(validity.begin(), validity.end(), thrust::identity{}, stream, mr);
if (null_count > 0) {
output->set_null_mask(std::move(null_mask), null_count);
} else {
output->set_null_mask(rmm::device_buffer{}, 0);
}
if (null_count > 0) { output->set_null_mask(std::move(null_mask), null_count); }
return std::move(output);
}

std::unique_ptr<cudf::column> cast_strings_to_decimals(cudf::column_view const& input,
int precision,
int scale,
bool is_us_locale,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();

return detail::cast_strings_to_decimals(input, precision, scale, is_us_locale, stream, mr);
}

std::unique_ptr<cudf::column> remove_quotes(cudf::column_view const& input,
bool nullify_if_not_quoted,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();

auto [output, validity] = detail::remove_quotes(input, stream, mr);
auto [output, validity] = detail::remove_quotes(input, nullify_if_not_quoted, stream, mr);
if (validity.size() > 0) {
auto [null_mask, null_count] =
cudf::detail::valid_if(validity.begin(), validity.end(), thrust::identity{}, stream, mr);
if (null_count > 0) { output->set_null_mask(std::move(null_mask), null_count); }
}
return std::move(output);
}

Expand Down
9 changes: 9 additions & 0 deletions src/main/cpp/src/json_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,17 @@ std::unique_ptr<cudf::column> cast_strings_to_booleans(
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

std::unique_ptr<cudf::column> cast_strings_to_decimals(
cudf::column_view const& input,
int precision,
int scale,
bool is_us_locale,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

std::unique_ptr<cudf::column> remove_quotes(
cudf::column_view const& input,
bool nullify_if_not_quoted,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

Expand Down
Loading
Loading