Skip to content

Commit

Permalink
apacheGH-41358: [R] Support join "na_matches" argument (apache#41372)
Browse files Browse the repository at this point in the history
### Rationale for this change

Noticed in apache#41350, I made apache#41358 to implement this in C++, but it turns
out the option was there, just buried a bit.

### What changes are included in this PR?

`na_matches` is mapped through to the `key_cmp` field in
`HashJoinNodeOptions`. Acero supports having a different value for this
for each of the join keys, but dplyr does not, so I kept it constant for
all key columns to match the dplyr behavior.

### Are these changes tested?

Yes

### Are there any user-facing changes?

Yes
* GitHub Issue: apache#41358
  • Loading branch information
nealrichardson authored Apr 26, 2024
1 parent 15986ae commit ea314a3
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 36 deletions.
1 change: 1 addition & 0 deletions r/NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

* R functions that users write that use functions that Arrow supports in dataset queries now can be used in queries too. Previously, only functions that used arithmetic operators worked. For example, `time_hours <- function(mins) mins / 60` worked, but `time_hours_rounded <- function(mins) round(mins / 60)` did not; now both work. These are automatic translations rather than true user-defined functions (UDFs); for UDFs, see `register_scalar_function()`. (#41223)
* `summarize()` supports more complex expressions, and correctly handles cases where column names are reused in expressions.
* The `na_matches` argument to the `dplyr::*_join()` functions is now supported. This argument controls whether `NA` values are considered equal when joining. (#41358)

# arrow 16.0.0

Expand Down
12 changes: 6 additions & 6 deletions r/R/arrow-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ supported_dplyr_methods <- list(
compute = NULL,
collapse = NULL,
distinct = "`.keep_all = TRUE` not supported",
left_join = "the `copy` and `na_matches` arguments are ignored",
right_join = "the `copy` and `na_matches` arguments are ignored",
inner_join = "the `copy` and `na_matches` arguments are ignored",
full_join = "the `copy` and `na_matches` arguments are ignored",
semi_join = "the `copy` and `na_matches` arguments are ignored",
anti_join = "the `copy` and `na_matches` arguments are ignored",
left_join = "the `copy` argument is ignored",
right_join = "the `copy` argument is ignored",
inner_join = "the `copy` argument is ignored",
full_join = "the `copy` argument is ignored",
semi_join = "the `copy` argument is ignored",
anti_join = "the `copy` argument is ignored",
count = NULL,
tally = NULL,
rename_with = NULL,
Expand Down
4 changes: 2 additions & 2 deletions r/R/arrowExports.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions r/R/dplyr-funcs-doc.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
#' which returns an `arrow` [Table], or `collect()`, which pulls the resulting
#' Table into an R `tibble`.
#'
#' * [`anti_join()`][dplyr::anti_join()]: the `copy` and `na_matches` arguments are ignored
#' * [`anti_join()`][dplyr::anti_join()]: the `copy` argument is ignored
#' * [`arrange()`][dplyr::arrange()]
#' * [`collapse()`][dplyr::collapse()]
#' * [`collect()`][dplyr::collect()]
Expand All @@ -45,22 +45,22 @@
#' * [`distinct()`][dplyr::distinct()]: `.keep_all = TRUE` not supported
#' * [`explain()`][dplyr::explain()]
#' * [`filter()`][dplyr::filter()]
#' * [`full_join()`][dplyr::full_join()]: the `copy` and `na_matches` arguments are ignored
#' * [`full_join()`][dplyr::full_join()]: the `copy` argument is ignored
#' * [`glimpse()`][dplyr::glimpse()]
#' * [`group_by()`][dplyr::group_by()]
#' * [`group_by_drop_default()`][dplyr::group_by_drop_default()]
#' * [`group_vars()`][dplyr::group_vars()]
#' * [`groups()`][dplyr::groups()]
#' * [`inner_join()`][dplyr::inner_join()]: the `copy` and `na_matches` arguments are ignored
#' * [`left_join()`][dplyr::left_join()]: the `copy` and `na_matches` arguments are ignored
#' * [`inner_join()`][dplyr::inner_join()]: the `copy` argument is ignored
#' * [`left_join()`][dplyr::left_join()]: the `copy` argument is ignored
#' * [`mutate()`][dplyr::mutate()]: window functions (e.g. things that require aggregation within groups) not currently supported
#' * [`pull()`][dplyr::pull()]: the `name` argument is not supported; returns an R vector by default but this behavior is deprecated and will return an Arrow [ChunkedArray] in a future release. Provide `as_vector = TRUE/FALSE` to control this behavior, or set `options(arrow.pull_as_vector)` globally.
#' * [`relocate()`][dplyr::relocate()]
#' * [`rename()`][dplyr::rename()]
#' * [`rename_with()`][dplyr::rename_with()]
#' * [`right_join()`][dplyr::right_join()]: the `copy` and `na_matches` arguments are ignored
#' * [`right_join()`][dplyr::right_join()]: the `copy` argument is ignored
#' * [`select()`][dplyr::select()]
#' * [`semi_join()`][dplyr::semi_join()]: the `copy` and `na_matches` arguments are ignored
#' * [`semi_join()`][dplyr::semi_join()]: the `copy` argument is ignored
#' * [`show_query()`][dplyr::show_query()]
#' * [`slice_head()`][dplyr::slice_head()]: slicing within groups not supported; Arrow datasets do not have row order, so head is non-deterministic; `prop` only supported on queries where `nrow()` is knowable without evaluating
#' * [`slice_max()`][dplyr::slice_max()]: slicing within groups not supported; `with_ties = TRUE` (dplyr default) is not supported; `prop` only supported on queries where `nrow()` is knowable without evaluating
Expand Down
8 changes: 5 additions & 3 deletions r/R/dplyr-join.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ do_join <- function(x,
suffix = c(".x", ".y"),
...,
keep = FALSE,
na_matches,
na_matches = c("na", "never"),
join_type) {
# TODO: handle `copy` arg: ignore?
# TODO: handle `na_matches` arg
x <- as_adq(x)
y <- as_adq(y)
by <- handle_join_by(by, x, y)

na_matches <- match.arg(na_matches)

# For outer joins, we need to output the join keys on both sides so we
# can coalesce them afterwards.
left_output <- if (!keep && join_type == "RIGHT_OUTER") {
Expand All @@ -54,7 +55,8 @@ do_join <- function(x,
left_output = left_output,
right_output = right_output,
suffix = suffix,
keep = keep
keep = keep,
na_matches = na_matches == "na"
)
collapse.arrow_dplyr_query(x)
}
Expand Down
8 changes: 5 additions & 3 deletions r/R/query-engine.R
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ ExecPlan <- R6Class("ExecPlan",
left_output = .data$join$left_output,
right_output = .data$join$right_output,
left_suffix = .data$join$suffix[[1]],
right_suffix = .data$join$suffix[[2]]
right_suffix = .data$join$suffix[[2]],
na_matches = .data$join$na_matches
)
}

Expand Down Expand Up @@ -307,7 +308,7 @@ ExecNode <- R6Class("ExecNode",
out$extras$source_schema$metadata[["r"]]$attributes <- NULL
out
},
Join = function(type, right_node, by, left_output, right_output, left_suffix, right_suffix) {
Join = function(type, right_node, by, left_output, right_output, left_suffix, right_suffix, na_matches = TRUE) {
self$preserve_extras(
ExecNode_Join(
self,
Expand All @@ -318,7 +319,8 @@ ExecNode <- R6Class("ExecNode",
left_output = left_output,
right_output = right_output,
output_suffix_for_left = left_suffix,
output_suffix_for_right = right_suffix
output_suffix_for_right = right_suffix,
na_matches = na_matches
)
)
},
Expand Down
12 changes: 6 additions & 6 deletions r/man/acero.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 6 additions & 5 deletions r/src/arrowExports.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 13 additions & 5 deletions r/src/compute-exec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,17 @@ std::shared_ptr<acero::ExecNode> ExecNode_Join(
const std::shared_ptr<acero::ExecNode>& right_data,
std::vector<std::string> left_keys, std::vector<std::string> right_keys,
std::vector<std::string> left_output, std::vector<std::string> right_output,
std::string output_suffix_for_left, std::string output_suffix_for_right) {
std::string output_suffix_for_left, std::string output_suffix_for_right,
bool na_matches) {
std::vector<arrow::FieldRef> left_refs, right_refs, left_out_refs, right_out_refs;
std::vector<acero::JoinKeyCmp> key_cmps;
for (auto&& name : left_keys) {
left_refs.emplace_back(std::move(name));
// Populate key_cmps in this loop, one for each key
// Note that Acero supports having different values for each key, but dplyr
// only supports one value for all keys, so we're only going to support that
// for now.
key_cmps.emplace_back(na_matches ? acero::JoinKeyCmp::IS : acero::JoinKeyCmp::EQ);
}
for (auto&& name : right_keys) {
right_refs.emplace_back(std::move(name));
Expand All @@ -434,10 +441,11 @@ std::shared_ptr<acero::ExecNode> ExecNode_Join(

return MakeExecNodeOrStop(
"hashjoin", input->plan(), {input.get(), right_data.get()},
acero::HashJoinNodeOptions{
join_type, std::move(left_refs), std::move(right_refs),
std::move(left_out_refs), std::move(right_out_refs), compute::literal(true),
std::move(output_suffix_for_left), std::move(output_suffix_for_right)});
acero::HashJoinNodeOptions{join_type, std::move(left_refs), std::move(right_refs),
std::move(left_out_refs), std::move(right_out_refs),
std::move(key_cmps), compute::literal(true),
std::move(output_suffix_for_left),
std::move(output_suffix_for_right)});
}

// [[acero::export]]
Expand Down
32 changes: 32 additions & 0 deletions r/tests/testthat/test-dplyr-join.R
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,35 @@ test_that("full joins handle keep", {
small_dataset_df
)
})

left <- tibble::tibble(
x = c(1, NA, 3),
)
right <- tibble::tibble(
x = c(1, NA, 3),
y = c("a", "b", "c")
)
na_matches_na <- right
na_matches_never <- tibble::tibble(
x = c(1, NA, 3),
y = c("a", NA, "c")
)
test_that("na_matches argument to join: na (default)", {
expect_equal(
arrow_table(left) %>%
left_join(right, by = "x", na_matches = "na") %>%
arrange(x) %>%
collect(),
na_matches_na %>% arrange(x)
)
})

test_that("na_matches argument to join: never", {
expect_equal(
arrow_table(left) %>%
left_join(right, by = "x", na_matches = "never") %>%
arrange(x) %>%
collect(),
na_matches_never %>% arrange(x)
)
})

0 comments on commit ea314a3

Please sign in to comment.