Skip to content

Commit

Permalink
Merge pull request #29 from JuliaFolds2/cb/sizegtlen
Browse files Browse the repository at this point in the history
Support `size>length(input)`
  • Loading branch information
lmiq authored Mar 7, 2024
2 parents 8c7b63c + 7edce29 commit f65dce0
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/ChunkSplitters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,10 @@ function chunks(data; n::Integer=0, size::Integer=0, split::Symbol=:batch)
else
C = FixedSize
size >= 1 || throw(ArgumentError("size must be >= 1"))
size <= length(data) || throw(ArgumentError("size must be <= length(data)"))
end
is_chunkable(data) || not_chunkable_err(data)
(split in split_types) || split_err()
Chunk{typeof(data),C}(data, min(length(data), n), size, split)
Chunk{typeof(data),C}(data, min(length(data), n), min(length(data), size), split)
end
function missing_input_err()
throw(ArgumentError("You must either indicate the desired number of chunks (n) or the target size of a chunk (size)."))
Expand Down Expand Up @@ -281,7 +280,7 @@ function getchunk(itr, ichunk::Integer; n::Integer=0, size::Integer=0, split::Sy
C = FixedSize
size >= 1 || throw(ArgumentError("size must be >= 1"))
l = length(itr)
size <= l || throw(ArgumentError("size must be <= length(itr)"))
size = min(l, size) # handle size>length(itr)
n = cld(l, size)
end
ichunk <= n || throw(ArgumentError("index must be less or equal to number of chunks ($n)"))
Expand Down Expand Up @@ -507,6 +506,10 @@ end
# FixedSize
c = chunks(1:10; size=5)
@test length(c) == 2
# When size > array_length, we shouldn't create more than one chunk
c = chunks(1:10; size=20)
@test length(c) == 1
@test length(first(c)) == 10
for (l, s) in [(13, 10), (5, 2), (42, 7), (22, 15)]
local c = chunks(1:l; size=s)
@test all(length(c[i]) == length(c[i+1]) for i in 1:length(c)-2) # only the last chunk may have different length
Expand Down

0 comments on commit f65dce0

Please sign in to comment.