Skip to content

Commit

Permalink
properly use or create buffer in DataLoader (#191)
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello authored Jan 28, 2025
1 parent 469c272 commit 3a1d0ff
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 92 deletions.
11 changes: 0 additions & 11 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,3 @@ StatsBase = "0.33, 0.34"
Tables = "1.10"
Transducers = "0.4"
julia = "1.6"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ChainRulesTestUtils", "CUDA", "DataFrames", "SparseArrays", "Test", "Zygote"]
2 changes: 1 addition & 1 deletion src/MLUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ include("batchview.jl")
export batchsize,
BatchView

include("eachobs.jl")
include("dataloader.jl")
export eachobs, DataLoader

include("parallel.jl")
Expand Down
28 changes: 21 additions & 7 deletions src/batchview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,15 +162,31 @@ Base.@propagate_inbounds function Base.getindex(A::BatchView, is::AbstractVector
return _getbatch(A, obsindices)
end

function _getbatch(A::BatchView{TElem, TData, TCollate}, obsindices) where {TElem, TData, TCollate}
return A.collate([getobs(A.data, i) for i in obsindices])
function getobs!(buffer, A::BatchView, i::Int)
obsindices = _batchrange(A, i)
return _getbatch!(buffer, A, obsindices)
end

function _getbatch(A::BatchView{TElem,TData,TCollate}, obsindices) where {TElem,TData,TCollate}
return A.collate([getobs(A.data, idx) for idx in obsindices])
end
function _getbatch(A::BatchView{TElem, TData, Val{false}}, obsindices) where {TElem, TData}
return [getobs(A.data, i) for i in obsindices]
function _getbatch!(buffer, A::BatchView{TElem,TData,TCollate}, obsindices) where {TElem,TData,TCollate}
return A.collate([getobs!(buffer[i], A.data, idx) for (i,idx) in enumerate(obsindices)])
end

function _getbatch(A::BatchView{TElem,TData,Val{false}}, obsindices) where {TElem,TData}
return [getobs(A.data, idx) for idx in obsindices]
end
function _getbatch(A::BatchView{TElem, TData, Val{nothing}}, obsindices) where {TElem, TData}
function _getbatch!(buffer, A::BatchView{TElem,TData,Val{false}}, obsindices) where {TElem,TData}
return [getobs!(buffer[i], A.data, idx) for (i,idx) in enumerate(obsindices)]
end

function _getbatch(A::BatchView{TElem,TData,Val{nothing}}, obsindices) where {TElem,TData}
return getobs(A.data, obsindices)
end
function _getbatch!(buffer, A::BatchView{TElem,TData,Val{nothing}}, obsindices) where {TElem,TData}
return getobs!(buffer, A.data, obsindices)
end

Base.parent(A::BatchView) = A.data
Base.eltype(::BatchView{Tel}) where Tel = Tel
Expand All @@ -196,5 +212,3 @@ function Base.showarg(io::IO, A::BatchView, toplevel)
print(io, ')')
toplevel && print(io, " with eltype ", nameof(eltype(A))) # simplify
end

# --------------------------------------------------------------------
18 changes: 14 additions & 4 deletions src/eachobs.jl → src/dataloader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,13 @@ The original data is preserved in the `data` field of the DataLoader.
- **`buffer`**: If `buffer=true` and supported by the type of `data`,
a buffer will be allocated and reused for memory efficiency.
May want to set `partial=false` to avoid size mismatch.
Finally, can pass an external buffer to be used in `getobs!(buffer, data, idx)`.
Finally, can pass an external buffer to be used in `getobs!`
(depending on the `collate` and `batchsize` options, could be `getobs!(buffer, data, idxs)` or `getobs!(buffer[i], data, idx)`).
Default `false`.
- **`collate`**: Defines the batching behavior. Default `nothing`.
- If `nothing` , a batch is `getobs(data, indices)`.
- If `false`, each batch is `[getobs(data, i) for i in indices]`.
- If `true`, applies MLUtils to the vector of observations in a batch,
- If `true`, applies `MLUtils.batch` to the vector of observations in a batch,
recursively collating arrays in the last dimensions. See [`MLUtils.batch`](@ref) for more information
and examples.
- If a custom function, it will be used in place of `MLUtils.batch`. It should take a vector of observations as input.
Expand Down Expand Up @@ -138,7 +139,7 @@ julia> first(DataLoader(["a", "b", "c", "d"], batchsize=2, collate=collate_fn))
struct DataLoader{T,B,C,R<:AbstractRNG}
data::T
batchsize::Int
buffer::B
buffer::B # boolean, or external buffer
partial::Bool
shuffle::Bool
parallel::Bool
Expand Down Expand Up @@ -183,7 +184,7 @@ function Base.iterate(d::DataLoader)
if d.buffer == false
iter = (getobs(data, i) for i in 1:numobs(data))
elseif d.buffer == true
buf = getobs(data, 1)
buf = create_buffer(data)
iter = (getobs!(buf, data, i) for i in 1:numobs(data))
else # external buffer
buf = d.buffer
Expand All @@ -194,6 +195,15 @@ function Base.iterate(d::DataLoader)
return obs, (iter, state)
end

create_buffer(x) = getobs(x, 1)
function create_buffer(x::BatchView)
obsindices = _batchrange(x, 1)
return [getobs(A.data, idx) for idx in enumerate(obsindices)]
end
function create_buffer(x::BatchView{TElem,TData,Val{nothing}}) where {TElem,TData}
obsindices = _batchrange(x, 1)
return getobs(x.data, obsindices)
end

function Base.iterate(::DataLoader, (iter, state))
ret = iterate(iter, state)
Expand Down
12 changes: 12 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
30 changes: 30 additions & 0 deletions test/batchview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,34 @@ using MLUtils: obsview
@test y isa String
end
end


@testset "getobs!" begin
X = rand(4, 15)
buf1 = rand(4, 3)
bv = BatchView(X, batchsize=3)
@test @inferred(getobs!(buf1, bv, 2)) === buf1
@test buf1 == getobs(bv, 2)

buf12 = [rand(4) for _=1:3]
bv12 = BatchView(X, batchsize=3, collate=false)
res = @inferred(getobs!(buf12, bv12, 2))
@test all(res .=== buf12)
@test buf12 == getobs(bv12, 2)

@testset "custom type" begin # issue #156
struct DummyData{X}
x::X
end
MLUtils.numobs(data::DummyData) = numobs(data.x)
MLUtils.getobs(data::DummyData, idx) = getobs(data.x, idx)
MLUtils.getobs!(buffer, data::DummyData, idx) = getobs!(buffer, data.x, idx)

data = DummyData(X)
buf = rand(4, 3)
bv = BatchView(data, batchsize=3)
@test @inferred(getobs!(buf, bv, 2)) === buf
@test buf == getobs(bv, 2)
end
end
end
69 changes: 69 additions & 0 deletions test/dataloader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,72 @@
end
end
end

@testset "eachobs" begin
for (i,x) in enumerate(eachobs(X))
@test x == X[:,i]
end

for (i,x) in enumerate(eachobs(X, buffer=true))
@test x == X[:,i]
end

b = zeros(size(X, 1))
for (i,x) in enumerate(eachobs(X, buffer=b))
@test x == X[:,i]
end
@test b == X[:,end]

@testset "batched" begin
for (i, x) in enumerate(eachobs(X, batchsize=2, partial=true))
if i != 8
@test size(x) == (4,2)
@test x == X[:,2i-1:2i]
else
@test size(x) == (4,1)
@test x == X[:,2i-1:2i-1]
end
end

for (i, x) in enumerate(eachobs(X, batchsize=2, buffer=true, partial=false))
@test size(x) == (4,2)
@test x == X[:,2i-1:2i]
end

b = zeros(4, 2)
for (i, x) in enumerate(eachobs(X, batchsize=2, buffer=b, partial=false))
@test size(x) == (4,2)
@test x == X[:,2i-1:2i]
end
@test b == X[:,end-2:end-1]
end

@testset "shuffled" begin
# does not reshuffle on iteration
shuffled = eachobs(shuffleobs(1:50))
@test collect(shuffled) == collect(shuffled)

# does reshuffle
reshuffled = eachobs(1:50, shuffle = true)
@test collect(reshuffled) != collect(reshuffled)

reshuffled = eachobs(1:50, shuffle = true, buffer = true)
@test collect(reshuffled) != collect(reshuffled)

reshuffled = eachobs(1:50, shuffle = true, parallel = true)
@test collect(reshuffled) != collect(reshuffled)

reshuffled = eachobs(1:50, shuffle = true, buffer = true, parallel = true)
@test collect(reshuffled) != collect(reshuffled)
end
@testset "Argument combinations" begin
for batchsize (-1, 2), buffer (true, false), collate (nothing, true, false),
parallel (true, false), shuffle (true, false), partial (true, false)
if !(buffer isa Bool) && batchsize > 0
buffer = getobs(BatchView(X; batchsize), 1)
end
iter = eachobs(X; batchsize, shuffle, buffer, parallel, partial)
@test_nowarn for _ in iter end
end
end
end
67 changes: 0 additions & 67 deletions test/eachobs.jl

This file was deleted.

3 changes: 1 addition & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using Transducers
using ChainRulesTestUtils: test_rrule
using Zygote: ZygoteRuleConfig
using ChainRulesCore: rrule_via_ad
using DataFrames
using DataFrames: DataFrame
using CUDA

showcompact(io, x) = show(IOContext(io, :compact => true), x)
Expand Down Expand Up @@ -90,7 +90,6 @@ include("test_utils.jl")
# @testset "MLUtils.jl" begin

@testset "batchview" begin; include("batchview.jl"); end
@testset "eachobs" begin; include("eachobs.jl"); end
@testset "dataloader" begin; include("dataloader.jl"); end
@testset "folds" begin; include("folds.jl"); end
@testset "observation" begin; include("observation.jl"); end
Expand Down

0 comments on commit 3a1d0ff

Please sign in to comment.