Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding cudf::get_current_device_resource_ref in place of rmm calls #2398

Open
wants to merge 4 commits into
base: branch-24.12
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/main/cpp/benchmarks/common/generate_input.cu
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ std::unique_ptr<cudf::column> create_random_column(data_profile const& profile,
null_mask.end(),
thrust::identity<bool>{},
cudf::get_default_stream(),
rmm::mr::get_current_device_resource());
cudf::get_current_device_resource_ref());

return std::make_unique<cudf::column>(
cudf::data_type{cudf::type_to_id<T>()},
Expand Down Expand Up @@ -517,7 +517,7 @@ std::unique_ptr<cudf::column> create_random_utf8_string_column(data_profile cons
null_mask.end() - 1,
thrust::identity<bool>{},
cudf::get_default_stream(),
rmm::mr::get_current_device_resource());
cudf::get_current_device_resource_ref());

return cudf::make_strings_column(
num_rows,
Expand Down Expand Up @@ -553,7 +553,7 @@ std::unique_ptr<cudf::column> create_random_column<cudf::string_view>(data_profi
cudf::out_of_bounds_policy::DONT_CHECK,
cudf::detail::negative_index_policy::NOT_ALLOWED,
cudf::get_default_stream(),
rmm::mr::get_current_device_resource());
cudf::get_current_device_resource_ref());
return std::move(str_table->release()[0]);
}

Expand Down Expand Up @@ -641,7 +641,7 @@ std::unique_ptr<cudf::column> create_random_column<cudf::struct_view>(data_profi
valids.end(),
thrust::identity<bool>{},
cudf::get_default_stream(),
rmm::mr::get_current_device_resource());
cudf::get_current_device_resource_ref());
}
return std::pair<rmm::device_buffer, cudf::size_type>{};
}();
Expand Down Expand Up @@ -731,7 +731,7 @@ std::unique_ptr<cudf::column> create_random_column<cudf::list_view>(data_profile
valids.end(),
thrust::identity<bool>{},
cudf::get_default_stream(),
rmm::mr::get_current_device_resource());
cudf::get_current_device_resource_ref());
list_column = cudf::make_lists_column(
num_rows,
std::move(offsets_column),
Expand Down Expand Up @@ -851,7 +851,7 @@ std::pair<rmm::device_buffer, cudf::size_type> create_random_null_mask(
thrust::make_counting_iterator<cudf::size_type>(size),
bool_generator{seed, 1.0 - *null_probability},
cudf::get_default_stream(),
rmm::mr::get_current_device_resource());
cudf::get_current_device_resource_ref());
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/main/cpp/src/bloom_filter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ std::unique_ptr<cudf::list_scalar> bloom_filter_create(
int num_hashes,
int bloom_filter_longs,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

/**
* @brief Inserts input values into a bloom filter.
Expand Down Expand Up @@ -79,7 +79,7 @@ std::unique_ptr<cudf::column> bloom_filter_probe(
cudf::column_view const& input,
cudf::device_span<uint8_t const> bloom_filter,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

/**
* @brief Probe a bloom filter with an input column of int64_t values.
Expand All @@ -96,7 +96,7 @@ std::unique_ptr<cudf::column> bloom_filter_probe(
cudf::column_view const& input,
cudf::list_scalar& bloom_filter,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

/**
* @brief Merge multiple bloom filters into a single output.
Expand All @@ -114,6 +114,6 @@ std::unique_ptr<cudf::column> bloom_filter_probe(
std::unique_ptr<cudf::list_scalar> bloom_filter_merge(
cudf::column_view const& bloom_filters,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

} // namespace spark_rapids_jni
2 changes: 1 addition & 1 deletion src/main/cpp/src/case_when.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,6 @@ namespace spark_rapids_jni {
std::unique_ptr<cudf::column> select_first_true_index(
cudf::table_view const& when_bool_columns,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

} // namespace spark_rapids_jni
12 changes: 6 additions & 6 deletions src/main/cpp/src/cast_string.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ std::unique_ptr<cudf::column> string_to_integer(
bool ansi_mode,
bool strip,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

/**
* @brief Convert a string column into an decimal column.
Expand All @@ -97,7 +97,7 @@ std::unique_ptr<cudf::column> string_to_decimal(
bool ansi_mode,
bool strip,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

/**
* @brief Convert a string column into an float column.
Expand All @@ -115,22 +115,22 @@ std::unique_ptr<cudf::column> string_to_float(
cudf::strings_column_view const& string_col,
bool ansi_mode,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

std::unique_ptr<cudf::column> format_float(
cudf::column_view const& input,
int const digits,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

std::unique_ptr<cudf::column> float_to_string(
cudf::column_view const& input,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

std::unique_ptr<cudf::column> decimal_to_non_ansi_string(
cudf::column_view const& input,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

} // namespace spark_rapids_jni
4 changes: 2 additions & 2 deletions src/main/cpp/src/datetime_rebase.cu
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ std::unique_ptr<cudf::column> rebase_gregorian_to_julian(cudf::column_view const
if (input.size() == 0) { return cudf::empty_like(input); }

auto const stream = cudf::get_default_stream();
auto const mr = rmm::mr::get_current_device_resource();
auto const mr = cudf::get_current_device_resource_ref();
return type == cudf::type_id::TIMESTAMP_DAYS ? gregorian_to_julian_days(input, stream, mr)
: gregorian_to_julian_micros(input, stream, mr);
}
Expand All @@ -368,7 +368,7 @@ std::unique_ptr<cudf::column> rebase_julian_to_gregorian(cudf::column_view const
if (input.size() == 0) { return cudf::empty_like(input); }

auto const stream = cudf::get_default_stream();
auto const mr = rmm::mr::get_current_device_resource();
auto const mr = cudf::get_current_device_resource_ref();
return type == cudf::type_id::TIMESTAMP_DAYS ? julian_to_gregorian_days(input, stream, mr)
: julian_to_gregorian_micros(input, stream, mr);
}
Expand Down
14 changes: 7 additions & 7 deletions src/main/cpp/src/decimal_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ std::unique_ptr<cudf::table> multiply_decimal128(cudf::column_view const& a,
auto const num_rows = a.size();
CUDF_EXPECTS(num_rows == b.size(), "inputs have mismatched row counts");
auto [result_null_mask, result_null_count] = cudf::detail::bitmask_and(
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource());
cudf::table_view{{a, b}}, stream, cudf::get_current_device_resource_ref());
std::vector<std::unique_ptr<cudf::column>> columns;
// copy the null mask here, as it will be used again later
columns.push_back(cudf::make_fixed_width_column(cudf::data_type{cudf::type_id::BOOL8},
Expand Down Expand Up @@ -1026,7 +1026,7 @@ std::unique_ptr<cudf::table> divide_decimal128(cudf::column_view const& a,
auto const num_rows = a.size();
CUDF_EXPECTS(num_rows == b.size(), "inputs have mismatched row counts");
auto [result_null_mask, result_null_count] = cudf::detail::bitmask_and(
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource());
cudf::table_view{{a, b}}, stream, cudf::get_current_device_resource_ref());
std::vector<std::unique_ptr<cudf::column>> columns;
// copy the null mask here, as it will be used again later
columns.push_back(cudf::make_fixed_width_column(cudf::data_type{cudf::type_id::BOOL8},
Expand Down Expand Up @@ -1060,7 +1060,7 @@ std::unique_ptr<cudf::table> integer_divide_decimal128(cudf::column_view const&
auto const num_rows = a.size();
CUDF_EXPECTS(num_rows == b.size(), "inputs have mismatched row counts");
auto [result_null_mask, result_null_count] = cudf::detail::bitmask_and(
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource());
cudf::table_view{{a, b}}, stream, cudf::get_current_device_resource_ref());
std::vector<std::unique_ptr<cudf::column>> columns;
// copy the null mask here, as it will be used again later
columns.push_back(cudf::make_fixed_width_column(cudf::data_type{cudf::type_id::BOOL8},
Expand Down Expand Up @@ -1093,7 +1093,7 @@ std::unique_ptr<cudf::table> remainder_decimal128(cudf::column_view const& a,
auto const num_rows = a.size();
CUDF_EXPECTS(num_rows == b.size(), "inputs have mismatched row counts");
auto [result_null_mask, result_null_count] = cudf::detail::bitmask_and(
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource());
cudf::table_view{{a, b}}, stream, cudf::get_current_device_resource_ref());
std::vector<std::unique_ptr<cudf::column>> columns;
// copy the null mask here, as it will be used again later
columns.push_back(cudf::make_fixed_width_column(cudf::data_type{cudf::type_id::BOOL8},
Expand Down Expand Up @@ -1126,7 +1126,7 @@ std::unique_ptr<cudf::table> add_decimal128(cudf::column_view const& a,
auto const num_rows = a.size();
CUDF_EXPECTS(num_rows == b.size(), "inputs have mismatched row counts");
auto [result_null_mask, result_null_count] = cudf::detail::bitmask_and(
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource());
cudf::table_view{{a, b}}, stream, cudf::get_current_device_resource_ref());
std::vector<std::unique_ptr<cudf::column>> columns;
// copy the null mask here, as it will be used again later
columns.push_back(cudf::make_fixed_width_column(cudf::data_type{cudf::type_id::BOOL8},
Expand Down Expand Up @@ -1159,7 +1159,7 @@ std::unique_ptr<cudf::table> sub_decimal128(cudf::column_view const& a,
auto const num_rows = a.size();
CUDF_EXPECTS(num_rows == b.size(), "inputs have mismatched row counts");
auto [result_null_mask, result_null_count] = cudf::detail::bitmask_and(
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource());
cudf::table_view{{a, b}}, stream, cudf::get_current_device_resource_ref());
std::vector<std::unique_ptr<cudf::column>> columns;
// copy the null mask here, as it will be used again later
columns.push_back(cudf::make_fixed_width_column(cudf::data_type{cudf::type_id::BOOL8},
Expand Down Expand Up @@ -1410,7 +1410,7 @@ std::pair<std::unique_ptr<cudf::column>, bool> floating_point_to_decimal(
output_type, input.size(), cudf::mask_state::UNALLOCATED, stream, mr);

auto const decimal_places = -output_type.scale();
auto const default_mr = rmm::mr::get_current_device_resource();
auto const default_mr = cudf::get_current_device_resource_ref();

rmm::device_uvector<int8_t> validity(input.size(), stream, default_mr);
rmm::device_scalar<bool> has_failure(false, stream, default_mr);
Expand Down
2 changes: 1 addition & 1 deletion src/main/cpp/src/decimal_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,6 @@ std::pair<std::unique_ptr<cudf::column>, bool> floating_point_to_decimal(
cudf::data_type output_type,
int32_t precision,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

} // namespace cudf::jni
6 changes: 3 additions & 3 deletions src/main/cpp/src/from_json_to_raw_map.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ rmm::device_uvector<char> unify_json_strings(cudf::strings_column_view const& in
{
if (input.is_empty()) {
return cudf::detail::make_device_uvector_async<char>(
std::vector<char>{'[', ']'}, stream, rmm::mr::get_current_device_resource());
std::vector<char>{'[', ']'}, stream, cudf::get_current_device_resource_ref());
}

auto const d_strings = cudf::column_device_view::create(input.parent(), stream);
Expand All @@ -84,7 +84,7 @@ rmm::device_uvector<char> unify_json_strings(cudf::strings_column_view const& in
cudf::string_scalar(","), // append `,` character between the input rows
cudf::string_scalar("{}"), // replacement for null rows
stream,
rmm::mr::get_current_device_resource());
cudf::get_current_device_resource_ref());
auto const joined_input_scv = cudf::strings_column_view{*joined_input};
auto const joined_input_size_bytes = joined_input_scv.chars_size(stream);
// TODO: This assertion requires a stream synchronization, may want to remove at some point.
Expand Down Expand Up @@ -656,7 +656,7 @@ std::unique_ptr<cudf::column> from_json_to_raw_map(cudf::strings_column_view con
cudf::device_span<char const>{unified_json_buff.data(), unified_json_buff.size()},
cudf::io::json_reader_options{},
stream,
rmm::mr::get_current_device_resource());
cudf::get_current_device_resource_ref());

#ifdef DEBUG_FROM_JSON
print_debug(tokens, "Tokens", ", ", stream);
Expand Down
6 changes: 3 additions & 3 deletions src/main/cpp/src/get_json_object.cu
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ construct_path_commands(
d_path_commands.reserve(h_path_commands->size());
for (auto const& path_commands : *h_path_commands) {
d_path_commands.emplace_back(cudf::detail::make_device_uvector_async(
path_commands, stream, rmm::mr::get_current_device_resource()));
path_commands, stream, cudf::get_current_device_resource_ref()));
}

return {std::move(d_path_commands),
Expand Down Expand Up @@ -1050,7 +1050,7 @@ std::vector<std::unique_ptr<cudf::column>> get_json_object_batch(
d_error_check.data() + idx});
}
auto d_path_data = cudf::detail::make_device_uvector_async(
h_path_data, stream, rmm::mr::get_current_device_resource());
h_path_data, stream, cudf::get_current_device_resource_ref());
thrust::uninitialized_fill(
rmm::exec_policy(stream), d_error_check.begin(), d_error_check.end(), 0);

Expand Down Expand Up @@ -1120,7 +1120,7 @@ std::vector<std::unique_ptr<cudf::column>> get_json_object_batch(

// Push data to the GPU and launch the kernel again.
d_path_data = cudf::detail::make_device_uvector_async(
h_path_data, stream, rmm::mr::get_current_device_resource());
h_path_data, stream, cudf::get_current_device_resource_ref());
thrust::uninitialized_fill(
rmm::exec_policy(stream), d_error_check.begin(), d_error_check.end(), 0);
kernel_launcher::exec(input, d_path_data, d_max_path_depth_exceeded, stream);
Expand Down
4 changes: 2 additions & 2 deletions src/main/cpp/src/get_json_object.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ std::unique_ptr<cudf::column> get_json_object(
cudf::strings_column_view const& input,
std::vector<std::tuple<path_instruction_type, std::string, int32_t>> const& instructions,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

/**
* @brief Extract multiple JSON objects from a JSON string based on the specified JSON paths.
Expand All @@ -67,6 +67,6 @@ std::vector<std::unique_ptr<cudf::column>> get_json_object_multiple_paths(
int64_t memory_budget_bytes,
int32_t parallel_override,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

} // namespace spark_rapids_jni
6 changes: 3 additions & 3 deletions src/main/cpp/src/hash.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ std::unique_ptr<cudf::column> murmur_hash3_32(
cudf::table_view const& input,
uint32_t seed = 0,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

/**
* @brief Computes the xxhash64 hash value of each row in the input set of columns.
Expand All @@ -56,7 +56,7 @@ std::unique_ptr<cudf::column> xxhash64(
cudf::table_view const& input,
int64_t seed = DEFAULT_XXHASH64_SEED,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

/**
* @brief Computes the Hive hash value of each row in the input set of columns.
Expand All @@ -70,6 +70,6 @@ std::unique_ptr<cudf::column> xxhash64(
std::unique_ptr<cudf::column> hive_hash(
cudf::table_view const& input,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

} // namespace spark_rapids_jni
6 changes: 3 additions & 3 deletions src/main/cpp/src/histogram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ struct percentile_dispatcher {
// - Having nulls in the input, and/or,
// - Having empty histograms.
auto out_validities =
rmm::device_uvector<int8_t>(num_histograms, stream, rmm::mr::get_current_device_resource());
rmm::device_uvector<int8_t>(num_histograms, stream, cudf::get_current_device_resource_ref());

auto const fill_percentile = [&](auto const sorted_validity_it) {
auto const sorted_input_it =
Expand Down Expand Up @@ -307,7 +307,7 @@ std::unique_ptr<cudf::column> create_histogram_if_valid(cudf::column_view const&
}
}

auto const default_mr = rmm::mr::get_current_device_resource();
auto const default_mr = cudf::get_current_device_resource_ref();

// We only check if there is any row in frequencies that are negative (invalid) or zero.
auto check_invalid_and_zero =
Expand Down Expand Up @@ -439,7 +439,7 @@ std::unique_ptr<cudf::column> percentile_from_histogram(cudf::column_view const&
auto const data_col = cudf::structs_column_view{histograms}.get_sliced_child(0);
auto const counts_col = cudf::structs_column_view{histograms}.get_sliced_child(1);

auto const default_mr = rmm::mr::get_current_device_resource();
auto const default_mr = cudf::get_current_device_resource_ref();
auto const d_data = cudf::column_device_view::create(data_col, stream);
auto const d_percentages =
cudf::detail::make_device_uvector_sync(percentages, stream, default_mr);
Expand Down
Loading
Loading