From 2c932bb7cee975db31258555596481aa24351ee8 Mon Sep 17 00:00:00 2001 From: Jeff Fessler Date: Tue, 20 Apr 2021 11:57:40 -0400 Subject: [PATCH] tolerance for pivot_cache (#32) * tolerance * test non-float * generic type * postpone non-float test * support complex floats using real(T) * bump version --- Project.toml | 2 +- src/pivot_cache.jl | 98 ++++++++++++++++++++++++---------------------- test/pivot_test.jl | 2 +- test/runtests.jl | 40 +++++++++++++------ 4 files changed, 82 insertions(+), 60 deletions(-) diff --git a/Project.toml b/Project.toml index 2ead300..4314b49 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "NonNegLeastSquares" uuid = "b7351bd1-99d9-5c5d-8786-f205a815c4d7" -version = "0.3.0" +version = "0.4.0" [deps] Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" diff --git a/src/pivot_cache.jl b/src/pivot_cache.jl index 2d75db5..bf0dac7 100644 --- a/src/pivot_cache.jl +++ b/src/pivot_cache.jl @@ -1,29 +1,32 @@ """ -x = pivot_cache(A, b; ...) + x = pivot_cache(A, b; ...) -Solves non-negative least-squares problem by block principal pivoting method +Solve non-negative least-squares problem by block principal pivoting method (Algorithm 1) described in Kim & Park (2011). Optional arguments: - tol: tolerance for nonnegativity constraints - max_iter: maximum number of iterations +* `tol` tolerance for nonnegativity constraints; + default for `AbstractFloat` types is `10^floor(log10(eps(T)^0.5))`, + which is `1e-8` for `Float64`, otherwise reverts to `1e-8`. +* `max_iter` maximum number of iterations, default `30 * size(A,2)` References: - J. Kim and H. Park, Fast nonnegative matrix factorization: An - active-set-like method and comparisons, SIAM J. Sci. Comput., 33 (2011), - pp. 3261–3281. + J. Kim and H. Park, Fast nonnegative matrix factorization: An + active-set-like method and comparisons, SIAM J. Sci. Comput., 33 (2011), + pp. 3261–3281. https://doi.org/10.1137/110821172 """ -function pivot_cache(AtA, - Atb::AbstractVector{T}; - tol::Float64=1e-8, - max_iter=30*size(AtA,2)) where {T} - +function pivot_cache( + AtA, + Atb::AbstractVector{T}; + tol::Real = (real(T) <: AbstractFloat) ? 10^floor(log10(eps(real(T))^0.5)) : 1e-8, + max_iter=30 * size(AtA,2), +) where {T} # dimensions, initialize solution q = size(AtA,1) x = zeros(T, q) # primal variables - y = -Atb # dual variables + y = -Atb # dual variables # parameters for swapping α = 3 @@ -32,48 +35,48 @@ function pivot_cache(AtA, # Store indices for the passive set, P # we want Y[P] == 0, X[P] >= 0 # we want X[~P]== 0, Y[~P] >= 0 - P = BitArray(false for _ in 1:q) + P = falses(q) - y[(!).(P)] = AtA[(!).(P),P]*x[P] - Atb[(!).(P)] + y[(!).(P)] = AtA[(!).(P),P] * x[P] - Atb[(!).(P)] # identify indices of infeasible variables - V = @__dot__ (P & (x < -tol)) | (!P & (y < -tol)) + V = @. (P & (x < -tol)) | (!P & (y < -tol)) nV = sum(V) # while infeasible (number of infeasible variables > 0) while nV > 0 - if nV < β - # infeasible variables decreased - β = nV # store number of infeasible variables - α = 3 # reset α - else - # infeasible variables stayed the same or increased - if α >= 1 - α = α-1 # tolerate increases for α cycles - else - # backup rule - i = findlast(V) - V = zeros(Bool,q) - V[i] = true - end - end - - # update passive set + if nV < β + # infeasible variables decreased + β = nV # store number of infeasible variables + α = 3 # reset α + else + # infeasible variables stayed the same or increased + if α >= 1 + α = α-1 # tolerate increases for α cycles + else + # backup rule + i = findlast(V) + V = zeros(Bool,q) + V[i] = true + end + end + + # update passive set # P & ~V removes infeasible variables from P # V & ~P moves infeasible variables in ~P to P - @__dot__ P = (P & !V) | (V & !P) + @. P = (P & !V) | (V & !P) - # update primal/dual variables - if !all(!, P) - x[P] = _get_primal_dual(AtA, Atb, P) - end + # update primal/dual variables + if !all(!, P) + x[P] = _get_primal_dual(AtA, Atb, P) + end #x[(!).(P)] = 0.0 y[(!).(P)] = AtA[(!).(P),P]*x[P] - Atb[(!).(P)] #y[P] = 0.0 # check infeasibility - @__dot__ V = (P & (x < -tol)) | (!P & (y < -tol)) + @. V = (P & (x < -tol)) | (!P & (y < -tol)) nV = sum(V) end @@ -88,17 +91,20 @@ end return qr(AtA[P,P]) \ Atb[P] end end + @inline function _get_primal_dual(AtA, Atb, P) - return pinv(AtA[P,P])*Atb[P] + return pinv(AtA[P,P])*Atb[P] end -## if multiple right hand sides are provided, solve each problem separately. -function pivot_cache(A, - B::AbstractMatrix{T}; - gram::Bool = false, - use_parallel::Bool = true, - kwargs...) where {T} +# if multiple right hand sides are provided, solve each problem separately. +function pivot_cache( + A, + B::AbstractMatrix{T}; + gram::Bool = false, + use_parallel::Bool = true, + kwargs... +) where {T} n = size(A,2) k = size(B,2) diff --git a/test/pivot_test.jl b/test/pivot_test.jl index 8ade9d3..1759a48 100644 --- a/test/pivot_test.jl +++ b/test/pivot_test.jl @@ -1,4 +1,4 @@ -# wrapper functions for convienence +# wrapper functions for convenience nnls(A,b) = nonneg_lsq(A,b;alg=:nnls) pivot(A,b) = nonneg_lsq(A,b;alg=:pivot) diff --git a/test/runtests.jl b/test/runtests.jl index c1be493..8c207bd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -30,6 +30,13 @@ function test_case2() return A, b, x end +function test_case3() # non-float + A = ones(Int, 4, 3) + b = 2*ones(Int, 4) + x = 2*ones(Int, 3) + return A, b, x +end + function test_algorithm(fh::Function, ε::Real=1e-5) # Solve A*x = b for x, subject to x >=0 A, b, x = test_case1() @@ -38,18 +45,18 @@ function test_algorithm(fh::Function, ε::Real=1e-5) A, b, x = test_case2() @test norm(fh(A,b) - x) < ε - ## Test a bunch of random cases - for i = 1:100 - m,n = rand(1:10),rand(1:10) - A3 = randn(m,n) - b3 = randn(m) - x3,resid = pyopt.nnls(A3,b3) - if resid > ε - @test norm(fh(A3,b3) - x3) < ε - else - @test norm(A3*fh(A3,b3) - b3) < ε - end - end + # Test a bunch of random cases + for i = 1:100 + m,n = rand(1:10),rand(1:10) + A3 = randn(m,n) + b3 = randn(m) + x3,resid = pyopt.nnls(A3,b3) + if resid > ε + @test norm(fh(A3,b3) - x3) < ε + else + @test norm(A3*fh(A3,b3) - b3) < ε + end + end end nnls(A,b) = nonneg_lsq(A, b; alg=:nnls) @@ -70,6 +77,15 @@ for (f, ε) in zip(algs, errs) println("done") end +#= non-float test fails, so revisit later +@testset "pivot_cache-non-float" begin + A, b, x = test_case3() + xi = pivot_cache(A, b) + xf = pivot_cache(Float32.(A), Float32.(b)) + @test xi ≈ xf +end +=# + @testset "comb" begin A, b, x = test_case2() x0 = pivot_comb(A, b)