diff --git a/R/db_joins.R b/R/db_joins.R index 0d207903..03a214ed 100644 --- a/R/db_joins.R +++ b/R/db_joins.R @@ -175,8 +175,10 @@ inner_join.tbl_sql <- function(x, y, by = NULL, ...) { # Check arguments assert_data_like(x) assert_data_like(y) - checkmate::assert_character(by, null.ok = TRUE) - + checkmate::assert( + checkmate::check_character(by, null.ok = TRUE), + checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) + ) .dots <- list(...) if (!"na_by" %in% names(.dots)) { @@ -209,7 +211,10 @@ left_join.tbl_sql <- function(x, y, by = NULL, ...) { # Check arguments assert_data_like(x) assert_data_like(y) - checkmate::assert_character(by, null.ok = TRUE) + checkmate::assert( + checkmate::check_character(by, null.ok = TRUE), + checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) + ) .dots <- list(...) @@ -244,7 +249,10 @@ right_join.tbl_sql <- function(x, y, by = NULL, ...) { # Check arguments assert_data_like(x) assert_data_like(y) - checkmate::assert_character(by, null.ok = TRUE) + checkmate::assert( + checkmate::check_character(by, null.ok = TRUE), + checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) + ) .dots <- list(...) @@ -280,7 +288,10 @@ full_join.tbl_sql <- function(x, y, by = NULL, ...) { # Check arguments assert_data_like(x) assert_data_like(y) - checkmate::assert_character(by, null.ok = TRUE) + checkmate::assert( + checkmate::check_character(by, null.ok = TRUE), + checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) + ) .dots <- list(...) @@ -304,7 +315,10 @@ semi_join.tbl_sql <- function(x, y, by = NULL, ...) { # Check arguments assert_data_like(x) assert_data_like(y) - checkmate::assert_character(by, null.ok = TRUE) + checkmate::assert( + checkmate::check_character(by, null.ok = TRUE), + checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) + ) .dots <- list(...) @@ -324,7 +338,10 @@ anti_join.tbl_sql <- function(x, y, by = NULL, ...) { # Check arguments assert_data_like(x) assert_data_like(y) - checkmate::assert_character(by, null.ok = TRUE) + checkmate::assert( + checkmate::check_character(by, null.ok = TRUE), + checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) + ) .dots <- list(...) diff --git a/tests/testthat/test-db_joins.R b/tests/testthat/test-db_joins.R index 320c7c25..a7c95885 100644 --- a/tests/testthat/test-db_joins.R +++ b/tests/testthat/test-db_joins.R @@ -1,4 +1,4 @@ -test_that("*_join() works", { +test_that("*_join() works with character `by` and `na_by`", { for (conn in get_test_conns()) { # Define two test datasets @@ -115,3 +115,32 @@ test_that("*_join() works", { connection_clean_up(conn) } }) + + +test_that("*_join() works with `dplyr::join_by()`", { + for (conn in get_test_conns()) { + + # Define two test datasets + x <- get_table(conn, "__mtcars") |> + dplyr::select(name, mpg, cyl, hp, vs, am, gear, carb) + + y <- get_table(conn, "__mtcars") |> + dplyr::select(name, drat, wt, qsec) + + + # Test the implemented joins + q <- dplyr::left_join(x, y, by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + qr <- dplyr::left_join(dplyr::collect(x), dplyr::collect(y), by = dplyr::join_by(x$name == y$name)) + expect_equal(q, qr) + + q <- dplyr::right_join(x, y, by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + qr <- dplyr::right_join(dplyr::collect(x), dplyr::collect(y), by = dplyr::join_by(x$name == y$name)) + expect_equal(q, qr) + + q <- dplyr::inner_join(x, y, by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + qr <- dplyr::inner_join(dplyr::collect(x), dplyr::collect(y), by = dplyr::join_by(x$name == y$name)) + expect_equal(q, qr) + + connection_clean_up(conn) + } +})