Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
RasmusSkytte committed Oct 14, 2024
1 parent 2349ff3 commit 4110c98
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 115 deletions.
175 changes: 61 additions & 114 deletions R/db_joins.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,63 @@ join_na_sql <- function(x, y, by = NULL, na_by = NULL) {
}
#' Fix the select component of a lazy query after a join using `na_by`
#'
#' Columns that are matched with themselves through `na_by` (e.g. `na_by = "a"`) end up being selected
#' twice (from both x and from y) in the lazy query. This helper rewrites the `vars` table of the lazy query
#' so that each such column is only selected once.
#'
#' @param vars
#'   The `vars` table of a lazy query (callers pass `out$lazy_query$vars`).
#'   Two layouts are handled: columns `name`, `x`, `y` and columns `name`, `table`, `var`.
#' @param na_by
#'   The `na_by` specification of the join. Anything that is not a `dplyr_join_by` object is passed
#'   through `dplyr::join_by()`. If `NULL`, `vars` is returned unchanged.
#' @param right
#'   Only used for the `name`/`x`/`y` layout: if `TRUE`, the `x` entry of the fixed columns is set to `NA`
#'   so the columns are only taken from `y`.
#' @return
#'   The updated `vars` table.
#' @noRd
join_na_select_fix <- function(vars, na_by, right = FALSE) {

  # Without any na_by columns, nothing can have been selected twice
  if (is.null(na_by)) {
    return(vars)
  }

  if (!inherits(na_by, "dplyr_join_by")) na_by <- dplyr::join_by(!!na_by)

  # Only equality conditions on identically named columns give rise to doubly selected columns
  doubly_selected_columns <- na_by |>
    purrr::discard_at("exprs") |>
    tibble::as_tibble() |>
    dplyr::filter(.data$condition == "==", .data$x == .data$y) |>
    dplyr::pull("x")

  if (length(doubly_selected_columns) == 0) {
    return(vars) # no doubly selected columns
  }

  # The vars table structure is not consistent between dplyr join types
  if (checkmate::test_names(names(vars), identical.to = c("name", "x", "y"))) {

    # `if` / `else` instead of `ifelse()`: with a scalar test, `ifelse()` only returns the first
    # column name, which would then be recycled to every row
    x_columns <- if (right) NA_character_ else doubly_selected_columns

    # Rows in exactly one of the two sets: the replacement rows and the untouched rows of vars
    updated_vars <- rbind(
      tibble::tibble(
        name = doubly_selected_columns,
        x = x_columns,
        y = doubly_selected_columns
      ),
      dplyr::filter(vars, .data$x %in% !!doubly_selected_columns | .data$y %in% !!doubly_selected_columns)
    ) |>
      dplyr::symdiff(vars)

  } else if (checkmate::test_names(names(vars), identical.to = c("name", "table", "var"))) {

    updated_vars <- rbind(
      tibble::tibble(name = doubly_selected_columns, table = 1, var = doubly_selected_columns),
      dplyr::filter(vars, .data$var %in% !!doubly_selected_columns)
    ) |>
      dplyr::symdiff(vars)

  } else {

    # Fail with an informative error instead of "object 'updated_vars' not found"
    rlang::abort(
      paste0(
        "Unsupported structure of `vars` in the lazy query: ",
        paste(names(vars), collapse = ", ")
      )
    )

  }

  return(updated_vars)
}
#' Warn users that SQL does not match on NA by default
#'
#' The warning is only signalled in interactive sessions, and only when the function that called
#' `join_warn()` was itself called from the global environment.
#' It is signalled with `.frequency = "once"`, so repeated joins do not repeat the warning.
#'
#' @return
#'   Called for its side effect: a warning that *_joins on SQL backends do not match NA by default.
#' @noRd
join_warn <- function() {
  if (interactive() && identical(parent.frame(n = 2), globalenv())) {
    rlang::warn(
      paste(
        "*_joins in database-backend does not match NA by default.\n",
        "If your data contains NA, the columns with NA values must be supplied to \"na_by\",",
        "or you must specify na_matches = \"na\""
      ),
      .frequency = "once",
      .frequency_id = "*_join NA warning"
    )
  }
}

Expand All @@ -98,8 +144,11 @@ join_warn <- function() {
#' @noRd
join_warn_experimental <- function() {
  # Only warn in interactive sessions, and only when the function calling join_warn_experimental()
  # was itself called from the global environment
  if (interactive() && identical(parent.frame(n = 2), globalenv())) {
    rlang::warn(
      "*_joins with na_by is still experimental. Please report issues.",
      .frequency = "once",
      # Must differ from the id used by join_warn(): with a shared id and .frequency = "once",
      # whichever warning is signalled first would suppress the other
      .frequency_id = "*_join na_by experimental warning"
    )
  }
}

Expand Down Expand Up @@ -169,24 +218,7 @@ inner_join.tbl_sql <- function(x, y, by = NULL, ...) {
args$by <- join_na_sql(x, y, by = by, na_by = .dots$na_by)

out <- do.call(dplyr::inner_join, args = args)

# The above solution breaks select statements
# The NA columns are selected twice (from both x and from y) so
# we manually fix the select component of the lazy query
if (!inherits(na_by, "dplyr_join_by")) na_by <- dplyr::join_by(!!na_by)
doubly_selected_columns <- na_by |>
purrr::discard_at("exprs") |>
tibble::as_tibble() |>
dplyr::filter(.data$condition == "==", .data$x == .data$y) |>
dplyr::pull("x")

if (length(doubly_selected_columns) > 0) {
out$lazy_query$vars <- rbind(
tibble::tibble(name = doubly_selected_columns, table = 1, var = doubly_selected_columns),
dplyr::filter(out$lazy_query$vars, var %in% doubly_selected_columns)
) |>
dplyr::symdiff(out$lazy_query$vars)
}
out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by)

return(out)
}
Expand All @@ -208,24 +240,7 @@ left_join.tbl_sql <- function(x, y, by = NULL, ...) {
args$by <- join_na_sql(x, y, by = by, na_by = .dots$na_by)

out <- do.call(dplyr::left_join, args = args)

# The above solution breaks select statements
# The NA columns are selected twice (from both x and from y) so
# we manually fix the select component of the lazy query
if (!inherits(na_by, "dplyr_join_by")) na_by <- dplyr::join_by(!!na_by)
doubly_selected_columns <- na_by |>
purrr::discard_at("exprs") |>
tibble::as_tibble() |>
dplyr::filter(.data$condition == "==", .data$x == .data$y) |>
dplyr::pull("x")

if (length(doubly_selected_columns) > 0) {
out$lazy_query$vars <- rbind(
tibble::tibble(name = doubly_selected_columns, table = 1, var = doubly_selected_columns),
dplyr::filter(out$lazy_query$vars, var %in% doubly_selected_columns)
) |>
dplyr::symdiff(out$lazy_query$vars)
}
out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by)

return(out)
}
Expand All @@ -247,24 +262,7 @@ right_join.tbl_sql <- function(x, y, by = NULL, ...) {
args$by <- join_na_sql(x, y, by = by, na_by = .dots$na_by)

out <- do.call(dplyr::right_join, args = args)

# The above solution breaks select statements
# The NA columns are selected twice (from both x and from y) so
# we manually fix the select component of the lazy query
if (!inherits(na_by, "dplyr_join_by")) na_by <- dplyr::join_by(!!na_by)
doubly_selected_columns <- na_by |>
purrr::discard_at("exprs") |>
tibble::as_tibble() |>
dplyr::filter(.data$condition == "==", .data$x == .data$y) |>
dplyr::pull("x")

if (length(doubly_selected_columns) > 0) {
out$lazy_query$vars <- rbind(
tibble::tibble(name = doubly_selected_columns, table = 1, var = doubly_selected_columns),
dplyr::filter(out$lazy_query$vars, var %in% doubly_selected_columns)
) |>
dplyr::symdiff(out$lazy_query$vars)
}
out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by, right = TRUE)

return(out)
}
Expand All @@ -287,24 +285,7 @@ full_join.tbl_sql <- function(x, y, by = NULL, ...) {
args$by <- join_na_sql(x, y, by = by, na_by = .dots$na_by)

out <- do.call(dplyr::full_join, args = args)

# The above solution breaks select statements
# The NA columns are selected twice (from both x and from y) so
# we manually fix the select component of the lazy query
if (!inherits(na_by, "dplyr_join_by")) na_by <- dplyr::join_by(!!na_by)
doubly_selected_columns <- na_by |>
purrr::discard_at("exprs") |>
tibble::as_tibble() |>
dplyr::filter(.data$condition == "==", .data$x == .data$y) |>
dplyr::pull("x")

if (length(doubly_selected_columns) > 0) {
out$lazy_query$vars <- rbind(
tibble::tibble(name = doubly_selected_columns, table = 1, var = doubly_selected_columns),
dplyr::filter(out$lazy_query$vars, var %in% doubly_selected_columns)
) |>
dplyr::symdiff(out$lazy_query$vars)
}
out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by)

return(out)
}
Expand All @@ -327,24 +308,7 @@ semi_join.tbl_sql <- function(x, y, by = NULL, ...) {
args$by <- join_na_sql(x, y, by = by, na_by = .dots$na_by)

out <- do.call(dplyr::semi_join, args = args)

# The above solution breaks select statements
# The NA columns are selected twice (from both x and from y) so
# we manually fix the select component of the lazy query
if (!inherits(na_by, "dplyr_join_by")) na_by <- dplyr::join_by(!!na_by)
doubly_selected_columns <- na_by |>
purrr::discard_at("exprs") |>
tibble::as_tibble() |>
dplyr::filter(.data$condition == "==", .data$x == .data$y) |>
dplyr::pull("x")

if (length(doubly_selected_columns) > 0) {
out$lazy_query$vars <- rbind(
tibble::tibble(name = doubly_selected_columns, table = 1, var = doubly_selected_columns),
dplyr::filter(out$lazy_query$vars, var %in% doubly_selected_columns)
) |>
dplyr::symdiff(out$lazy_query$vars)
}
out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by)

return(out)
}
Expand All @@ -367,24 +331,7 @@ anti_join.tbl_sql <- function(x, y, by = NULL, ...) {
args$by <- join_na_sql(x, y, by = by, na_by = .dots$na_by)

out <- do.call(dplyr::anti_join, args = args)

# The above solution breaks select statements
# The NA columns are selected twice (from both x and from y) so
# we manually fix the select component of the lazy query
if (!inherits(na_by, "dplyr_join_by")) na_by <- dplyr::join_by(!!na_by)
doubly_selected_columns <- na_by |>
purrr::discard_at("exprs") |>
tibble::as_tibble() |>
dplyr::filter(.data$condition == "==", .data$x == .data$y) |>
dplyr::pull("x")

if (length(doubly_selected_columns) > 0) {
out$lazy_query$vars <- rbind(
tibble::tibble(name = doubly_selected_columns, table = 1, var = doubly_selected_columns),
dplyr::filter(out$lazy_query$vars, var %in% doubly_selected_columns)
) |>
dplyr::symdiff(out$lazy_query$vars)
}
out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by)

return(out)
}
67 changes: 66 additions & 1 deletion tests/testthat/test-db_joins.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ test_that("*_join() works with character `by` and `na_by`", {
dplyr::arrange(number, t, letter)
qr <- dplyr::left_join(dplyr::collect(x), dplyr::collect(y), by = "number", multiple = "all") |>
dplyr::arrange(number, t, letter)
expect_mapequal(q, qr)
expect_equal(q, qr)

q <- dplyr::right_join(x, y, na_by = "number") |>
dplyr::collect() |>
Expand Down Expand Up @@ -186,3 +186,68 @@ test_that("*_join() does not break any dplyr joins", {
connection_clean_up(conn)
}
})



test_that("*_join() with only na_by works as dplyr joins", {
  for (conn in get_test_conns()) {

    # Define two test datasets
    x <- get_table(conn, "__mtcars") |>
      dplyr::select(name, mpg, cyl, hp, vs, am, gear, carb)

    y <- get_table(conn, "__mtcars") |>
      dplyr::select(name, drat, wt, qsec)

    # Local copies used to compute the reference results
    x_local <- dplyr::collect(x)
    y_local <- dplyr::collect(y)

    # The standard joins to test
    joins <- list(
      "left_join"  = dplyr::left_join,
      "right_join" = dplyr::right_join,
      "inner_join" = dplyr::inner_join,
      "full_join"  = dplyr::full_join,
      "semi_join"  = dplyr::semi_join,
      "anti_join"  = dplyr::anti_join
    )

    for (join_name in names(joins)) {
      join <- joins[[join_name]]

      # Databases do not guarantee a row order, so both sides are sorted on the join key
      # before they are compared
      qr <- join(x_local, y_local, by = "name") |>
        dplyr::arrange(name)

      # na_by given as character
      q <- join(x, y, na_by = "name") |>
        dplyr::collect() |>
        dplyr::arrange(name)
      expect_equal(q, qr, info = join_name)

      # na_by given as dplyr::join_by()
      q <- join(x, y, na_by = dplyr::join_by(x$name == y$name)) |>
        dplyr::collect() |>
        dplyr::arrange(name)
      expect_equal(q, qr, info = join_name)
    }

    connection_clean_up(conn)
  }
})

0 comments on commit 4110c98

Please sign in to comment.