Skip to content

Commit

Permalink
Using as_record_batch_reader() for batching. Deprecating batchSize. F…
Browse files Browse the repository at this point in the history
…ixes #53
  • Loading branch information
Admin_mschuemi authored and Admin_mschuemi committed Mar 29, 2023
1 parent 5c5ae46 commit 68f9dbe
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 26 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@ Suggests:
DBI
LazyData: false
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.1
RoxygenNote: 7.2.3
Encoding: UTF-8
24 changes: 13 additions & 11 deletions R/Operations.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#' @param tbl An [`Andromeda`] table (or any other 'DBI' table).
#' @param fun A function where the first argument is a data frame.
#' @param ... Additional parameters passed to fun.
#' @param batchSize Number of rows to fetch at a time.
#' @param batchSize DEPRECATED: Number of rows to fetch at a time.
#' @param progressBar Show a progress bar?
#' @param safe Create a copy of tbl first? Allows writing to the same Andromeda as being
#' read from.
Expand All @@ -43,7 +43,7 @@
#' return(nrow(x))
#' }
#'
#' result <- batchApply(andr$cars, fun, batchSize = 25)
#' result <- batchApply(andr$cars, fun)
#'
#' result
#' # [[1]]
Expand All @@ -55,13 +55,18 @@
#' close(andr)
#' }
#' @export
batchApply <- function(tbl, fun, ..., batchSize = 100000, progressBar = FALSE, safe = FALSE) {
batchApply <- function(tbl, fun, ..., batchSize, progressBar = FALSE, safe = FALSE) {
if (!inherits(tbl, c("FileSystemDataset", "arrow_dplyr_query"))) {
abort("First argument must be an Andromeda table or a dplyr query of an Andromeda table")
}
if (!is.function(fun)) abort("Second argument must be a function")

if (safe || inherits(tbl, "arrow_dplyr_query")) {
if (!missing(batchSize)) {
rlang::warn("The `batchSize` argument is deprecated.",
.frequency = "regularly",
.frequency_id = "batchSize"
)
}
if (safe) {
tempAndromeda <- andromeda()
on.exit(close(tempAndromeda))
tempAndromeda$tbl <- tbl
Expand All @@ -71,10 +76,7 @@ batchApply <- function(tbl, fun, ..., batchSize = 100000, progressBar = FALSE, s
if(nrow(tbl) == 0) {
return(list())
}

scanner <- arrow::ScannerBuilder$create(tbl)$BatchSize(batch_size = batchSize)$Finish()
reader <- scanner$ToRecordBatchReader()

reader <- arrow::as_record_batch_reader(tbl)
output <- list()
if (progressBar) {
pb <- txtProgressBar(style = 3)
Expand Down Expand Up @@ -115,7 +117,7 @@ batchApply <- function(tbl, fun, ..., batchSize = 100000, progressBar = FALSE, s
#' @param groupVariable The variable to group by
#' @param fun A function where the first argument is a data frame.
#' @param ... Additional parameters passed to fun.
#' @param batchSize Number of rows fetched from the table at a time. This is not the number of
#' @param batchSize DEPRECATED: Number of rows fetched from the table at a time. This is not the number of
#' rows to which the function will be applied. Included mostly for testing
#' purposes.
#' @param progressBar Show a progress bar?
Expand Down Expand Up @@ -154,7 +156,7 @@ batchApply <- function(tbl, fun, ..., batchSize = 100000, progressBar = FALSE, s
#' close(andr)
#' }
#' @export
groupApply <- function(tbl, groupVariable, fun, ..., batchSize = 100000, progressBar = FALSE, safe = FALSE) {
groupApply <- function(tbl, groupVariable, fun, ..., batchSize, progressBar = FALSE, safe = FALSE) {
if (!groupVariable %in% names(tbl))
abort(sprintf("'%s' is not a variable in the table", groupVariable))

Expand Down
6 changes: 3 additions & 3 deletions man/batchApply.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/groupApply.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 9 additions & 9 deletions tests/testthat/test-batching.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@ test_that("batchApply", {
return(nrow(batch) * multiplier)
}

result <- batchApply(andromeda$cars, doSomething, multiplier = 2, batchSize = 10)
result <- batchApply(andromeda$cars, doSomething, multiplier = 2)
result <- unlist(result)
expect_true(sum(result) == nrow(cars) * 2)
expect_true(length(result) == ceiling(nrow(cars)/10))
# expect_true(length(result) == ceiling(nrow(cars)/10))
rm(result)

# batchApply can also accept an arrow_dplyr_query
query <- dplyr::mutate(andromeda$cars, new_column = speed*dist)
expect_s3_class(query, "arrow_dplyr_query")
result <- query %>% batchApply(doSomething, multiplier = 2, batchSize = 10)
result <- query %>% batchApply(doSomething, multiplier = 2)
result <- unlist(result)
expect_true(sum(result) == nrow(cars) * 2)
expect_true(length(result) == ceiling(nrow(cars)/10))
# expect_true(length(result) == ceiling(nrow(cars)/10))

close(andromeda)
})
Expand All @@ -37,7 +37,7 @@ test_that("batchApply safe mode", {
appendToTable(andromeda$cars2, batch)
}
}
batchApply(andromeda$cars, doSomething, multiplier = 2, batchSize = 10, safe = TRUE)
batchApply(andromeda$cars, doSomething, multiplier = 2, safe = TRUE)

cars2 <- andromeda$cars2 %>% collect() %>% arrange(speed, dist)
cars1 <- cars %>% arrange(speed, dist)
Expand All @@ -54,8 +54,8 @@ test_that("batchApply progress bar", {
doSomething <- function(batch, multiplier) {
return(nrow(batch) * multiplier)
}
result <- capture_output(batchApply(andromeda$cars, doSomething, multiplier = 2, batchSize = 10, progressBar = TRUE))
expect_true(stringr::str_count(result, "=") > 100)
result <- capture_output(batchApply(andromeda$cars, doSomething, multiplier = 2, progressBar = TRUE))
expect_true(grepl("100%", result))
close(andromeda)
})

Expand All @@ -81,8 +81,8 @@ test_that("groupApply progress bar", {
doSomething <- function(batch, multiplier) {
return(nrow(batch) * multiplier)
}
result <- capture_output(groupApply(andromeda$cars, "speed", doSomething, multiplier = 2, batchSize = 10, progressBar = TRUE))
expect_true(stringr::str_count(result, "=") > 100)
result <- capture_output(groupApply(andromeda$cars, "speed", doSomething, multiplier = 2, progressBar = TRUE))
expect_true(grepl("100%", result))
close(andromeda)
})

Expand Down

0 comments on commit 68f9dbe

Please sign in to comment.