From 3a1d0ff3537f3284642c33577656d3a902d6c3f5 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 28 Jan 2025 12:28:09 +0100 Subject: [PATCH] properly use or create buffer in DataLoader (#191) --- Project.toml | 11 ----- src/MLUtils.jl | 2 +- src/batchview.jl | 28 +++++++++---- src/{eachobs.jl => dataloader.jl} | 18 ++++++-- test/Project.toml | 12 ++++++ test/batchview.jl | 30 ++++++++++++++ test/dataloader.jl | 69 +++++++++++++++++++++++++++++++ test/eachobs.jl | 67 ------------------------------ test/runtests.jl | 3 +- 9 files changed, 148 insertions(+), 92 deletions(-) rename src/{eachobs.jl => dataloader.jl} (94%) create mode 100644 test/Project.toml delete mode 100644 test/eachobs.jl diff --git a/Project.toml b/Project.toml index 6b3292e..36d8f24 100644 --- a/Project.toml +++ b/Project.toml @@ -32,14 +32,3 @@ StatsBase = "0.33, 0.34" Tables = "1.10" Transducers = "0.4" julia = "1.6" - -[extras] -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" -DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[targets] -test = ["ChainRulesTestUtils", "CUDA", "DataFrames", "SparseArrays", "Test", "Zygote"] diff --git a/src/MLUtils.jl b/src/MLUtils.jl index e43098a..4c344b5 100644 --- a/src/MLUtils.jl +++ b/src/MLUtils.jl @@ -39,7 +39,7 @@ include("batchview.jl") export batchsize, BatchView -include("eachobs.jl") +include("dataloader.jl") export eachobs, DataLoader include("parallel.jl") diff --git a/src/batchview.jl b/src/batchview.jl index bfe15d7..3312858 100644 --- a/src/batchview.jl +++ b/src/batchview.jl @@ -162,15 +162,31 @@ Base.@propagate_inbounds function Base.getindex(A::BatchView, is::AbstractVector return _getbatch(A, obsindices) end -function _getbatch(A::BatchView{TElem, TData, TCollate}, obsindices) where {TElem, TData, TCollate} - return A.collate([getobs(A.data, i) for i in obsindices]) +function getobs!(buffer, A::BatchView, i::Int) + obsindices = _batchrange(A, i) + return _getbatch!(buffer, A, obsindices) +end + +function _getbatch(A::BatchView{TElem,TData,TCollate}, obsindices) where {TElem,TData,TCollate} + return A.collate([getobs(A.data, idx) for idx in obsindices]) end -function _getbatch(A::BatchView{TElem, TData, Val{false}}, obsindices) where {TElem, TData} - return [getobs(A.data, i) for i in obsindices] +function _getbatch!(buffer, A::BatchView{TElem,TData,TCollate}, obsindices) where {TElem,TData,TCollate} + return A.collate([getobs!(buffer[i], A.data, idx) for (i,idx) in enumerate(obsindices)]) +end + +function _getbatch(A::BatchView{TElem,TData,Val{false}}, obsindices) where {TElem,TData} + return [getobs(A.data, idx) for idx in obsindices] end -function _getbatch(A::BatchView{TElem, TData, Val{nothing}}, obsindices) where {TElem, TData} +function _getbatch!(buffer, A::BatchView{TElem,TData,Val{false}}, obsindices) where {TElem,TData} + return [getobs!(buffer[i], A.data, idx) for (i,idx) in enumerate(obsindices)] +end + +function _getbatch(A::BatchView{TElem,TData,Val{nothing}}, obsindices) where {TElem,TData} return getobs(A.data, obsindices) end +function _getbatch!(buffer, A::BatchView{TElem,TData,Val{nothing}}, obsindices) where {TElem,TData} + return getobs!(buffer, A.data, obsindices) +end Base.parent(A::BatchView) = A.data Base.eltype(::BatchView{Tel}) where Tel = Tel @@ -196,5 +212,3 @@ function Base.showarg(io::IO, A::BatchView, toplevel) print(io, ')') toplevel && print(io, " with eltype ", nameof(eltype(A))) # simplify end - -# -------------------------------------------------------------------- diff --git a/src/eachobs.jl b/src/dataloader.jl similarity index 94% rename from src/eachobs.jl rename to src/dataloader.jl index b31fafe..0107835 100644 --- a/src/eachobs.jl +++ b/src/dataloader.jl @@ -61,12 +61,13 @@ The original data is preserved in the `data` field of the DataLoader. - **`buffer`**: If `buffer=true` and supported by the type of `data`, a buffer will be allocated and reused for memory efficiency. May want to set `partial=false` to avoid size mismatch. - Finally, can pass an external buffer to be used in `getobs!(buffer, data, idx)`. + Finally, can pass an external buffer to be used in `getobs!` + (depending on the `collate` and `batchsize` options, could be `getobs!(buffer, data, idxs)` or `getobs!(buffer[i], data, idx)`). Default `false`. - **`collate`**: Defines the batching behavior. Default `nothing`. - If `nothing` , a batch is `getobs(data, indices)`. - If `false`, each batch is `[getobs(data, i) for i in indices]`. - - If `true`, applies MLUtils to the vector of observations in a batch, + - If `true`, applies `MLUtils.batch` to the vector of observations in a batch, recursively collating arrays in the last dimensions. See [`MLUtils.batch`](@ref) for more information and examples. - If a custom function, it will be used in place of `MLUtils.batch`. It should take a vector of observations as input. @@ -138,7 +139,7 @@ julia> first(DataLoader(["a", "b", "c", "d"], batchsize=2, collate=collate_fn)) struct DataLoader{T,B,C,R<:AbstractRNG} data::T batchsize::Int - buffer::B + buffer::B # boolean, or external buffer partial::Bool shuffle::Bool parallel::Bool @@ -183,7 +184,7 @@ function Base.iterate(d::DataLoader) if d.buffer == false iter = (getobs(data, i) for i in 1:numobs(data)) elseif d.buffer == true - buf = getobs(data, 1) + buf = create_buffer(data) iter = (getobs!(buf, data, i) for i in 1:numobs(data)) else # external buffer buf = d.buffer @@ -194,6 +195,15 @@ function Base.iterate(d::DataLoader) return obs, (iter, state) end +create_buffer(x) = getobs(x, 1) +function create_buffer(x::BatchView) + obsindices = _batchrange(x, 1) + return [getobs(A.data, idx) for idx in enumerate(obsindices)] +end +function create_buffer(x::BatchView{TElem,TData,Val{nothing}}) where {TElem,TData} + obsindices = _batchrange(x, 1) + return getobs(x.data, obsindices) +end function Base.iterate(::DataLoader, (iter, state)) ret = iterate(iter, state) diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..b83c121 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,12 @@ +[deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/batchview.jl b/test/batchview.jl index 89ab083..a400921 100644 --- a/test/batchview.jl +++ b/test/batchview.jl @@ -128,4 +128,34 @@ using MLUtils: obsview @test y isa String end end + + + @testset "getobs!" begin + X = rand(4, 15) + buf1 = rand(4, 3) + bv = BatchView(X, batchsize=3) + @test @inferred(getobs!(buf1, bv, 2)) === buf1 + @test buf1 == getobs(bv, 2) + + buf12 = [rand(4) for _=1:3] + bv12 = BatchView(X, batchsize=3, collate=false) + res = @inferred(getobs!(buf12, bv12, 2)) + @test all(res .=== buf12) + @test buf12 == getobs(bv12, 2) + + @testset "custom type" begin # issue #156 + struct DummyData{X} + x::X + end + MLUtils.numobs(data::DummyData) = numobs(data.x) + MLUtils.getobs(data::DummyData, idx) = getobs(data.x, idx) + MLUtils.getobs!(buffer, data::DummyData, idx) = getobs!(buffer, data.x, idx) + + data = DummyData(X) + buf = rand(4, 3) + bv = BatchView(data, batchsize=3) + @test @inferred(getobs!(buf, bv, 2)) === buf + @test buf == getobs(bv, 2) + end + end end diff --git a/test/dataloader.jl b/test/dataloader.jl index 12f84ec..132ab75 100644 --- a/test/dataloader.jl +++ b/test/dataloader.jl @@ -282,3 +282,72 @@ end end end + +@testset "eachobs" begin + for (i,x) in enumerate(eachobs(X)) + @test x == X[:,i] + end + + for (i,x) in enumerate(eachobs(X, buffer=true)) + @test x == X[:,i] + end + + b = zeros(size(X, 1)) + for (i,x) in enumerate(eachobs(X, buffer=b)) + @test x == X[:,i] + end + @test b == X[:,end] + + @testset "batched" begin + for (i, x) in enumerate(eachobs(X, batchsize=2, partial=true)) + if i != 8 + @test size(x) == (4,2) + @test x == X[:,2i-1:2i] + else + @test size(x) == (4,1) + @test x == X[:,2i-1:2i-1] + end + end + + for (i, x) in enumerate(eachobs(X, batchsize=2, buffer=true, partial=false)) + @test size(x) == (4,2) + @test x == X[:,2i-1:2i] + end + + b = zeros(4, 2) + for (i, x) in enumerate(eachobs(X, batchsize=2, buffer=b, partial=false)) + @test size(x) == (4,2) + @test x == X[:,2i-1:2i] + end + @test b == X[:,end-2:end-1] + end + + @testset "shuffled" begin + # does not reshuffle on iteration + shuffled = eachobs(shuffleobs(1:50)) + @test collect(shuffled) == collect(shuffled) + + # does reshuffle + reshuffled = eachobs(1:50, shuffle = true) + @test collect(reshuffled) != collect(reshuffled) + + reshuffled = eachobs(1:50, shuffle = true, buffer = true) + @test collect(reshuffled) != collect(reshuffled) + + reshuffled = eachobs(1:50, shuffle = true, parallel = true) + @test collect(reshuffled) != collect(reshuffled) + + reshuffled = eachobs(1:50, shuffle = true, buffer = true, parallel = true) + @test collect(reshuffled) != collect(reshuffled) + end + @testset "Argument combinations" begin + for batchsize ∈ (-1, 2), buffer ∈ (true, false), collate ∈ (nothing, true, false), + parallel ∈ (true, false), shuffle ∈ (true, false), partial ∈ (true, false) + if !(buffer isa Bool) && batchsize > 0 + buffer = getobs(BatchView(X; batchsize), 1) + end + iter = eachobs(X; batchsize, shuffle, buffer, parallel, partial) + @test_nowarn for _ in iter end + end + end +end diff --git a/test/eachobs.jl b/test/eachobs.jl deleted file mode 100644 index 70194e9..0000000 --- a/test/eachobs.jl +++ /dev/null @@ -1,67 +0,0 @@ -@testset "eachobs" begin - for (i,x) in enumerate(eachobs(X)) - @test x == X[:,i] - end - - for (i,x) in enumerate(eachobs(X, buffer=true)) - @test x == X[:,i] - end - - b = zeros(size(X, 1)) - for (i,x) in enumerate(eachobs(X, buffer=b)) - @test x == X[:,i] - end - @test b == X[:,end] - - @testset "batched" begin - for (i, x) in enumerate(eachobs(X, batchsize=2, partial=true)) - if i != 8 - @test size(x) == (4,2) - @test x == X[:,2i-1:2i] - else - @test size(x) == (4,1) - @test x == X[:,2i-1:2i-1] - end - end - - for (i, x) in enumerate(eachobs(X, batchsize=2, buffer=true, partial=false)) - @test size(x) == (4,2) - @test x == X[:,2i-1:2i] - end - - b = zeros(4, 2) - for (i, x) in enumerate(eachobs(X, batchsize=2, buffer=b, partial=false)) - @test size(x) == (4,2) - @test x == X[:,2i-1:2i] - end - end - - @testset "shuffled" begin - # does not reshuffle on iteration - shuffled = eachobs(shuffleobs(1:50)) - @test collect(shuffled) == collect(shuffled) - - # does reshuffle - reshuffled = eachobs(1:50, shuffle = true) - @test collect(reshuffled) != collect(reshuffled) - - reshuffled = eachobs(1:50, shuffle = true, buffer = true) - @test collect(reshuffled) != collect(reshuffled) - - reshuffled = eachobs(1:50, shuffle = true, parallel = true) - @test collect(reshuffled) != collect(reshuffled) - - reshuffled = eachobs(1:50, shuffle = true, buffer = true, parallel = true) - @test collect(reshuffled) != collect(reshuffled) - end - @testset "Argument combinations" begin - for batchsize ∈ (-1, 2), buffer ∈ (true, false), collate ∈ (nothing, true, false), - parallel ∈ (true, false), shuffle ∈ (true, false), partial ∈ (true, false) - if !(buffer isa Bool) && batchsize > 0 - buffer = getobs(BatchView(X; batchsize), 1) - end - iter = eachobs(X; batchsize, shuffle, buffer, parallel, partial) - @test_nowarn for _ in iter end - end - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 9da8387..14b4076 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,7 +9,7 @@ using Transducers using ChainRulesTestUtils: test_rrule using Zygote: ZygoteRuleConfig using ChainRulesCore: rrule_via_ad -using DataFrames +using DataFrames: DataFrame using CUDA showcompact(io, x) = show(IOContext(io, :compact => true), x) @@ -90,7 +90,6 @@ include("test_utils.jl") # @testset "MLUtils.jl" begin @testset "batchview" begin; include("batchview.jl"); end -@testset "eachobs" begin; include("eachobs.jl"); end @testset "dataloader" begin; include("dataloader.jl"); end @testset "folds" begin; include("folds.jl"); end @testset "observation" begin; include("observation.jl"); end