diff --git a/src/KSVD.jl b/src/KSVD.jl index 898d011..f2d1f2a 100644 --- a/src/KSVD.jl +++ b/src/KSVD.jl @@ -87,7 +87,7 @@ To enable timing outputs, run `TimerOutputs.enable_debug_timings(KSVD)`. function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10; ksvd_update_method = BatchedParallelKSVD{false, T}(; shuffle_indices=true, batch_size_per_thread=1), sparse_coding_method = ParallelMatchingPursuit(; max_nnz, rtol=5e-2), - verbose=false, + minibatch_size=nothing, # termination conditions maxiters::Int=100, #: The maximum number of iterations to perform. Defaults to 100. maxtime::Union{Nothing, <:Real}=nothing,# : The maximum time for solving the nonlinear system of equations. Defaults to nothing which means no time limit. Note that setting a time limit does have a small overhead. @@ -96,8 +96,8 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10; nnz_per_col_target::Number=0.0, # tracing options show_trace::Bool=false, - # store_trace::Bool, callback_fn::Union{Nothing, Function}=nothing, + verbose=false, ) where T timer = TimerOutput() emb_dim, n_samples = size(Y) @@ -107,7 +107,9 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10; D = init_dictionary(T, emb_dim, n_atoms) # size(D) == (n, K) @assert all(≈(1.0), norm.(eachcol(D))) end - X = sparse(zeros(T, 0, 0)) # to assign to later + X = sparse_coding(sparse_coding_method, Y, D; timer) + + Y_ = !isnothing(minibatch_size) ? similar(Y, size(Y, 1), minibatch_size) : similar(Y, 0, 0) # progressbar = Progress(maxiter) maybe_init_buffers!(ksvd_update_method, emb_dim, n_atoms, n_samples) @@ -127,10 +129,18 @@ function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10; termination_condition = :nothing tic = time() for iter in 1:maxiters + verbose && @info "Starting svd" + (Y_, X_) = if isnothing(minibatch_size) + minibatch_indices = sort(shuffle(axes(Y, 2))[1:minibatch_size]) + Y_ .= Y[:, minibatch_indices] + X_ = X[:, minibatch_indices] + (Y_, X_) + else + (Y, X) + end + ksvd_update(ksvd_update_method, Y_, D, X_; timer) verbose && @info "Starting sparse coding" X = sparse_coding(sparse_coding_method, Y, D; timer) - verbose && @info "Starting svd" - D, X = ksvd_update(ksvd_update_method, Y, D, X; timer) # put a task to compute the trace / termination conditions. push!(trace_channel, (iter, copy(D), copy(X)))