Skip to content

Commit

Permalink
tolerance for pivot_cache (#32)
Browse files Browse the repository at this point in the history
* tolerance

* test non-float

* generic type

* postpone non-float test

* support complex floats using real(T)

* bump version
  • Loading branch information
JeffFessler authored Apr 20, 2021
1 parent a7ff3bb commit 2c932bb
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 60 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "NonNegLeastSquares"
uuid = "b7351bd1-99d9-5c5d-8786-f205a815c4d7"
version = "0.3.0"
version = "0.4.0"

[deps]
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand Down
98 changes: 52 additions & 46 deletions src/pivot_cache.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,32 @@
"""
x = pivot_cache(A, b; ...)
x = pivot_cache(A, b; ...)
Solves non-negative least-squares problem by block principal pivoting method
Solve non-negative least-squares problem by block principal pivoting method
(Algorithm 1) described in Kim & Park (2011).
Optional arguments:
tol: tolerance for nonnegativity constraints
max_iter: maximum number of iterations
* `tol` tolerance for nonnegativity constraints;
default for `AbstractFloat` types is `10^floor(log10(eps(T)^0.5))`,
which is `1e-8` for `Float64`, otherwise reverts to `1e-8`.
* `max_iter` maximum number of iterations, default `30 * size(A,2)`
References:
J. Kim and H. Park, Fast nonnegative matrix factorization: An
active-set-like method and comparisons, SIAM J. Sci. Comput., 33 (2011),
pp. 3261–3281.
J. Kim and H. Park, Fast nonnegative matrix factorization: An
active-set-like method and comparisons, SIAM J. Sci. Comput., 33 (2011),
pp. 3261–3281. https://doi.org/10.1137/110821172
"""
function pivot_cache(AtA,
Atb::AbstractVector{T};
tol::Float64=1e-8,
max_iter=30*size(AtA,2)) where {T}

function pivot_cache(
AtA,
Atb::AbstractVector{T};
tol::Real = (real(T) <: AbstractFloat) ? 10^floor(log10(eps(real(T))^0.5)) : 1e-8,
max_iter=30 * size(AtA,2),
) where {T}

# dimensions, initialize solution
q = size(AtA,1)

x = zeros(T, q) # primal variables
y = -Atb # dual variables
y = -Atb # dual variables

# parameters for swapping
α = 3
Expand All @@ -32,48 +35,48 @@ function pivot_cache(AtA,
# Store indices for the passive set, P
# we want Y[P] == 0, X[P] >= 0
# we want X[~P]== 0, Y[~P] >= 0
P = BitArray(false for _ in 1:q)
P = falses(q)

y[(!).(P)] = AtA[(!).(P),P]*x[P] - Atb[(!).(P)]
y[(!).(P)] = AtA[(!).(P),P] * x[P] - Atb[(!).(P)]

# identify indices of infeasible variables
V = @__dot__ (P & (x < -tol)) | (!P & (y < -tol))
V = @. (P & (x < -tol)) | (!P & (y < -tol))
nV = sum(V)

# while infeasible (number of infeasible variables > 0)
while nV > 0

if nV < β
# infeasible variables decreased
β = nV # store number of infeasible variables
α = 3 # reset α
else
# infeasible variables stayed the same or increased
if α >= 1
α = α-1 # tolerate increases for α cycles
else
# backup rule
i = findlast(V)
V = zeros(Bool,q)
V[i] = true
end
end

# update passive set
if nV < β
# infeasible variables decreased
β = nV # store number of infeasible variables
α = 3 # reset α
else
# infeasible variables stayed the same or increased
if α >= 1
α = α-1 # tolerate increases for α cycles
else
# backup rule
i = findlast(V)
V = zeros(Bool,q)
V[i] = true
end
end

# update passive set
# P & ~V removes infeasible variables from P
# V & ~P moves infeasible variables in ~P to P
@__dot__ P = (P & !V) | (V & !P)
@. P = (P & !V) | (V & !P)

# update primal/dual variables
if !all(!, P)
x[P] = _get_primal_dual(AtA, Atb, P)
end
# update primal/dual variables
if !all(!, P)
x[P] = _get_primal_dual(AtA, Atb, P)
end
#x[(!).(P)] = 0.0
y[(!).(P)] = AtA[(!).(P),P]*x[P] - Atb[(!).(P)]
#y[P] = 0.0

# check infeasibility
@__dot__ V = (P & (x < -tol)) | (!P & (y < -tol))
@. V = (P & (x < -tol)) | (!P & (y < -tol))
nV = sum(V)
end

Expand All @@ -88,17 +91,20 @@ end
return qr(AtA[P,P]) \ Atb[P]
end
end

@inline function _get_primal_dual(AtA, Atb, P)
return pinv(AtA[P,P])*Atb[P]
return pinv(AtA[P,P])*Atb[P]
end


## if multiple right hand sides are provided, solve each problem separately.
function pivot_cache(A,
B::AbstractMatrix{T};
gram::Bool = false,
use_parallel::Bool = true,
kwargs...) where {T}
# if multiple right hand sides are provided, solve each problem separately.
function pivot_cache(
A,
B::AbstractMatrix{T};
gram::Bool = false,
use_parallel::Bool = true,
kwargs...
) where {T}

n = size(A,2)
k = size(B,2)
Expand Down
2 changes: 1 addition & 1 deletion test/pivot_test.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# wrapper functions for convienence
# wrapper functions for convenience
nnls(A,b) = nonneg_lsq(A,b;alg=:nnls)
pivot(A,b) = nonneg_lsq(A,b;alg=:pivot)

Expand Down
40 changes: 28 additions & 12 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ function test_case2()
return A, b, x
end

function test_case3() # non-float
A = ones(Int, 4, 3)
b = 2*ones(Int, 4)
x = 2*ones(Int, 3)
return A, b, x
end

function test_algorithm(fh::Function, ε::Real=1e-5)
# Solve A*x = b for x, subject to x >=0
A, b, x = test_case1()
Expand All @@ -38,18 +45,18 @@ function test_algorithm(fh::Function, ε::Real=1e-5)
A, b, x = test_case2()
@test norm(fh(A,b) - x) < ε

## Test a bunch of random cases
for i = 1:100
m,n = rand(1:10),rand(1:10)
A3 = randn(m,n)
b3 = randn(m)
x3,resid = pyopt.nnls(A3,b3)
if resid > ε
@test norm(fh(A3,b3) - x3) < ε
else
@test norm(A3*fh(A3,b3) - b3) < ε
end
end
# Test a bunch of random cases
for i = 1:100
m,n = rand(1:10),rand(1:10)
A3 = randn(m,n)
b3 = randn(m)
x3,resid = pyopt.nnls(A3,b3)
if resid > ε
@test norm(fh(A3,b3) - x3) < ε
else
@test norm(A3*fh(A3,b3) - b3) < ε
end
end
end

nnls(A,b) = nonneg_lsq(A, b; alg=:nnls)
Expand All @@ -70,6 +77,15 @@ for (f, ε) in zip(algs, errs)
println("done")
end

#= non-float test fails, so revisit later
@testset "pivot_cache-non-float" begin
A, b, x = test_case3()
xi = pivot_cache(A, b)
xf = pivot_cache(Float32.(A), Float32.(b))
@test xi ≈ xf
end
=#

@testset "comb" begin
A, b, x = test_case2()
x0 = pivot_comb(A, b)
Expand Down

2 comments on commit 2c932bb

@JeffFessler
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/34759

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.0 -m "<description of version>" 2c932bb7cee975db31258555596481aa24351ee8
git push origin v0.4.0

Please sign in to comment.