Skip to content

Commit

Permalink
Add lambda to marginalize (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
ojwoodford authored Nov 3, 2023
1 parent bfa26ce commit 8c4e76e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 22 deletions.
20 changes: 10 additions & 10 deletions src/marginalize.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using StaticArrays, HybridArrays, Static
using StaticArrays, HybridArrays, Static, LinearAlgebra

function marginalize!(to::MultiVariateLS, from::MultiVariateLSsparse, blockind::Int, blocksz::Int)
function marginalize!(to::MultiVariateLS, from::MultiVariateLSsparse, blockind::Int, lambda::Float64, blocksz::Int)
# Get the list of blocks to marginalize out
ind = from.A.indicestransposed.colptr[blockind]:from.A.indicestransposed.colptr[blockind+1]-1
blocks = view(from.A.indicestransposed.rowval, ind)
Expand All @@ -10,7 +10,7 @@ function marginalize!(to::MultiVariateLS, from::MultiVariateLSsparse, blockind::
N = length(blocks) - 1
dataindices = view(from.A.indicestransposed.nzval, ind)
# Get the diagonal block (to be marginalized)
diagblock = reshape(view(from.A.data, (0:blocksz*blocksz-1) .+ dataindices[end]), blocksz, blocksz)
diagblock = reshape(view(from.A.data, (0:blocksz*blocksz-1) .+ dataindices[end]), blocksz, blocksz) + I * lambda
@static if VERSION v"1.9"
diagblock = bunchkaufman(diagblock)
end
Expand All @@ -33,7 +33,7 @@ function marginalize!(to::MultiVariateLS, from::MultiVariateLSsparse, blockind::
end
end

function marginalize!(to::MultiVariateLS, from::MultiVariateLSsparse, blockind::Int, ::StaticInt{blocksz}) where blocksz
function marginalize!(to::MultiVariateLS, from::MultiVariateLSsparse, blockind::Int, lambda::Float64, ::StaticInt{blocksz}) where blocksz
# Get the list of blocks to marginalize out
ind = from.A.indicestransposed.colptr[blockind]:from.A.indicestransposed.colptr[blockind+1]-1
blocks = view(from.A.indicestransposed.rowval, ind)
Expand All @@ -43,7 +43,7 @@ function marginalize!(to::MultiVariateLS, from::MultiVariateLSsparse, blockind::
N = length(blocks) - 1
dataindices = view(from.A.indicestransposed.nzval, ind)
# Get the diagonal block (to be marginalized)
diagblock = SMatrix{blocksz, blocksz}(SizedMatrix{blocksz, blocksz}(view(from.A.data, SR(0, blocksz*blocksz-1).+dataindices[end])))
diagblock = SMatrix{blocksz, blocksz}(SizedMatrix{blocksz, blocksz}(view(from.A.data, SR(0, blocksz*blocksz-1).+dataindices[end]))) + I * lambda
@static if VERSION v"1.9"
diagblock = bunchkaufman(diagblock)
end
Expand All @@ -65,13 +65,13 @@ function marginalize!(to::MultiVariateLS, from::MultiVariateLSsparse, blockind::
end
end

function marginalize!(to::MultiVariateLS, from::MultiVariateLSsparse, blocks::AbstractRange, blocksz)
function marginalize!(to::MultiVariateLS, from::MultiVariateLSsparse, blocks::AbstractRange, lambdas, blocksz)
for block in blocks
marginalize!(to, from, block, blocksz)
marginalize!(to, from, block, @inbounds(lambdas[min(block, end)]), blocksz)
end
end

function marginalize!(to::MultiVariateLS, from::MultiVariateLSsparse, fromblock = isa(to, MultiVariateLSsparse) ? length(to.A.rowblocksizes)+1 : length(to.A.rowblockoffsets))
function marginalize!(to::MultiVariateLS, from::MultiVariateLSsparse, lambdas, fromblock = isa(to, MultiVariateLSsparse) ? length(to.A.rowblocksizes)+1 : length(to.A.rowblockoffsets))
last = fromblock
finish = length(from.A.rowblocksizes)
while last <= finish
Expand All @@ -86,9 +86,9 @@ function marginalize!(to::MultiVariateLS, from::MultiVariateLSsparse, fromblock
range = first:last-1
if blocksz <= MAX_BLOCK_SZ
# marginalize!(to, from, first:last, static(blocksz))
valuedispatch(static(1), static(MAX_BLOCK_SZ), blocksz, fixallbutlast(marginalize!, to, from, range))
valuedispatch(static(1), static(MAX_BLOCK_SZ), blocksz, fixallbutlast(marginalize!, to, from, range, lambdas))
else
marginalize!(to, from, range, blocksz)
marginalize!(to, from, range, lambdas, blocksz)
end
end
return to
Expand Down
31 changes: 19 additions & 12 deletions test/marginalize.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using NLLSsolver, SparseArrays, StaticArrays, Test, Random
using NLLSsolver, SparseArrays, StaticArrays, Test, Random, LinearAlgebra

@testset "marginalize.jl" begin
function test_marginalize(lambda)
# Define the size of test problem
blocksizes = [1, 1, 2, 2, 3, 3, 3, 3, 2, 2, 1, 1]
fromblock = 7
Expand All @@ -9,7 +9,6 @@ using NLLSsolver, SparseArrays, StaticArrays, Test, Random

# Intitialize a sparse linear system randomly
from = NLLSsolver.MultiVariateLSsparse(NLLSsolver.BlockSparseMatrix{Float64}(sparse(cols, rows, trues(length(rows))), blocksizes, blocksizes), 1:length(blocksizes))
Random.seed!(1)
from.A.data .= randn(length(from.A.data))
from.b .= randn(length(from.b))
# Make the diagonal blocks symmetric
Expand All @@ -19,7 +18,7 @@ using NLLSsolver, SparseArrays, StaticArrays, Test, Random
diagblock .= diagblock * diagblock'
end
end

# Construct the cropped system
to_d = NLLSsolver.constructcrop(from, fromblock)
@test isa(to_d, NLLSsolver.MultiVariateLSdense)
Expand All @@ -39,6 +38,7 @@ using NLLSsolver, SparseArrays, StaticArrays, Test, Random
@test all((to_s.boffsets .+ blocksizes[1:fromblock-1]) .<= (length(to_s.b) + 1))

# Compute the ground truth variable update
hessian += I * lambda
gtupdate = hessian \ from.b
gtupdate = gtupdate[1:croplen]

Expand All @@ -50,8 +50,8 @@ using NLLSsolver, SparseArrays, StaticArrays, Test, Random

# Compute the marginalized system using dynamic block sizes
for block in fromblock:length(from.A.rowblocksizes)
NLLSsolver.marginalize!(to_d, from, block, Int(from.A.rowblocksizes[block]))
NLLSsolver.marginalize!(to_s, from, block, Int(from.A.rowblocksizes[block]))
NLLSsolver.marginalize!(to_d, from, block, lambda, Int(from.A.rowblocksizes[block]))
NLLSsolver.marginalize!(to_s, from, block, lambda, Int(from.A.rowblocksizes[block]))
end

# Check that the results are the same
Expand All @@ -62,16 +62,16 @@ using NLLSsolver, SparseArrays, StaticArrays, Test, Random

# Check that the reduced systems give the correct variable update
@test hessian \ gradient gtupdate
@test hess_d \ to_d.b gtupdate
@test hess_s \ to_s.b gtupdate
@test (hess_d + I * lambda) \ to_d.b gtupdate
@test (hess_s + I * lambda) \ to_s.b gtupdate

# Reset the 'to' systems
NLLSsolver.initcrop!(to_d, from)
NLLSsolver.initcrop!(to_s, from)

# Compute the marginalized system using static block sizes
NLLSsolver.marginalize!(to_d, from)
NLLSsolver.marginalize!(to_s, from)
NLLSsolver.marginalize!(to_d, from, SVector(lambda))
NLLSsolver.marginalize!(to_s, from, SVector(lambda))

# Check that the results are the same
hess_d = NLLSsolver.symmetrifyfull(to_d.A)
Expand All @@ -80,6 +80,13 @@ using NLLSsolver, SparseArrays, StaticArrays, Test, Random
@test to_d.b == to_s.b

# Check that the reduced systems give the correct variable update
@test hess_d \ to_d.b gtupdate
@test hess_s \ to_s.b gtupdate
@test (hess_d + I * lambda) \ to_d.b gtupdate
@test (hess_s + I * lambda) \ to_s.b gtupdate
return
end

@testset "marginalize.jl" begin
Random.seed!(1)
test_marginalize(0.0)
test_marginalize(1.0)
end

0 comments on commit 8c4e76e

Please sign in to comment.