diff --git a/src/KSVD.jl b/src/KSVD.jl index 812238a..c55c242 100644 --- a/src/KSVD.jl +++ b/src/KSVD.jl @@ -59,7 +59,6 @@ Dictionary vectors will be normalized such that ~all(norm.(eachcol(D), 2) .≈ 1 # Arguments - `Y::AbstractMatrix{T}`: Input data matrix of size (num_features x num_samples) - `n_atoms::Int`: Number of atoms (columns) in the dictionary -- `max_nnz=max(n_atoms÷10, 1)`: Maximum number of non-zero coefficients in each sparse representation # Keyword Arguments - `ksvd_update_method`: Method used for updating the dictionary (default: BatchedParallelKSVD) @@ -84,7 +83,7 @@ A named tuple containing: # Notes To enable timing outputs, run `TimerOutputs.enable_debug_timings(KSVD)`. """ -function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=n_atoms÷10; +function ksvd(Y::AbstractMatrix{T}, n_atoms::Int; ksvd_update_method = BatchedParallelKSVD{false, T}(; shuffle_indices=true, batch_size_per_thread=1), sparse_coding_method = ParallelMatchingPursuit(; max_nnz, rtol=5e-2), minibatch_size=nothing,