Skip to content

Commit

Permalink
Implement minibatching on the samples.
Browse files Browse the repository at this point in the history
Note that we also reverse the order of sparse coding and ksvd_update

Patch important bug with minibatch
  • Loading branch information
RomeoV committed Jul 19, 2024
1 parent e4274e2 commit e4c2fd2
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions src/KSVD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ To enable timing outputs, run `TimerOutputs.enable_debug_timings(KSVD)`.
function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10;
ksvd_update_method = BatchedParallelKSVD{false, T}(; shuffle_indices=true, batch_size_per_thread=1),
sparse_coding_method = ParallelMatchingPursuit(; max_nnz, rtol=5e-2),
verbose=false,
minibatch_size=nothing,
# termination conditions
maxiters::Int=100, #: The maximum number of iterations to perform. Defaults to 100.
maxtime::Union{Nothing, <:Real}=nothing,# : The maximum time for solving the nonlinear system of equations. Defaults to nothing which means no time limit. Note that setting a time limit does have a small overhead.
Expand All @@ -96,8 +96,8 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10;
nnz_per_col_target::Number=0.0,
# tracing options
show_trace::Bool=false,
# store_trace::Bool,
callback_fn::Union{Nothing, Function}=nothing,
verbose=false,
) where T
timer = TimerOutput()
emb_dim, n_samples = size(Y)
Expand All @@ -107,7 +107,9 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10;
D = init_dictionary(T, emb_dim, n_atoms) # size(D) == (n, K)
@assert all((1.0), norm.(eachcol(D)))
end
X = sparse(zeros(T, 0, 0)) # to assign to later
X = sparse_coding(sparse_coding_method, Y, D; timer)

Y_ = !isnothing(minibatch_size) ? similar(Y, size(Y, 1), minibatch_size) : similar(Y, 0, 0)

# progressbar = Progress(maxiter)
maybe_init_buffers!(ksvd_update_method, emb_dim, n_atoms, n_samples)
Expand All @@ -127,10 +129,18 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10;
termination_condition = :nothing
tic = time()
for iter in 1:maxiters
verbose && @info "Starting svd"
(Y_, X_) = if !isnothing(minibatch_size)
minibatch_indices = sort(shuffle(axes(Y, 2))[1:minibatch_size])
Y_ .= Y[:, minibatch_indices]
X_ = X[:, minibatch_indices]
(Y_, X_)
else
(Y, X)
end
ksvd_update(ksvd_update_method, Y_, D, X_; timer)
verbose && @info "Starting sparse coding"
X = sparse_coding(sparse_coding_method, Y, D; timer)
verbose && @info "Starting svd"
D, X = ksvd_update(ksvd_update_method, Y, D, X; timer)

# put a task to compute the trace / termination conditions.
push!(trace_channel, (iter, copy(D), copy(X)))
Expand Down

0 comments on commit e4c2fd2

Please sign in to comment.