Skip to content

Commit

Permalink
trying to get KF to work right now
Browse files Browse the repository at this point in the history
  • Loading branch information
Shashank Swaminathan committed Nov 30, 2024
1 parent 86d039b commit f422f96
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 22 deletions.
6 changes: 4 additions & 2 deletions src/kalman_estimation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ function next_agent_state(agent::AgentEnvModel, ϕₖ::Vector{Float64}, cwrld::S
let k=agent.k,
new_estimate=update_SCRIBEModel(agent.estimates[k].estimate, ϕₖ),
new_obs=scribe_observations(X, cwrld, agent.bhv)
push!(agent.estimates, AgentEnvEstimate(k, new_estimate, new_obs))
push!(agent.estimates, AgentEnvEstimate(k+1, new_estimate, new_obs))
end
end

Expand Down Expand Up @@ -152,7 +152,7 @@ function simple_LGSF_Estimators(system::AgentEnvModel)
get_A(k, system) = system.estimates[k].estimate.params.A
get_ϕ(k, system) = system.estimates[k].estimate.ϕ
get_Q(k, system) = system.params.w[:Q]
get_H(k, system) = compute_obs_dynamics(system.estimates[k].estimate, system.estimates[k].observations.X)
get_H(k, system) = compute_obs_dynamics(system.estimates[k].estimate, system.estimates[k].observations.X)[1]
get_z(k, system) = system.estimates[k].observations.z
get_R(k, system) = system.estimates[k].observations.v[:R]
get_Y(k, system) = system.information[k].Y
Expand All @@ -175,6 +175,8 @@ function compute_info_priors(Ef::SystemEstimators, k::Integer)
M = inv(A(k))' * Y(k) * inv(A(k))
Y⁻ = M - M * inv(M + inv(Q(k))) * M
y⁻ = Y⁻ * A(k) * Y(k) * y(k)
println("k: ", k, " | y⁻: ", y⁻)
println("k: ", k, " | Y⁻: ", Y⁻)
return Y⁻, y⁻
end

Expand Down
9 changes: 5 additions & 4 deletions src/lineargaussianscalarfields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ struct LGSFModelParameters <: SCRIBEModelParameters
else
@assert size(ϕ₀,1)==nᵩ
end
if A===nothing; A=I(nᵩ); end
if Q===nothing; Q=I(nᵩ); end
if A===nothing; A=I(nᵩ); else; @assert size(A)==(nᵩ,nᵩ); end
if Q===nothing; Q=0.001*I(nᵩ); else; @assert size(Q)==(nᵩ,nᵩ); end
w=Dict(:Q=>Q, :w_dist=>Gaussian(zeros(nᵩ), Q))
new(nᵩ, p, ψ_p, ϕ₀, A, w)
end
Expand Down Expand Up @@ -182,14 +182,15 @@ function compute_obs_dynamics(smodel::LGSFModel, X::Matrix{Float64}; tol=0.001)
if check_observability(H, tol)
return H, X
else
return compute_obs_dynamics(smodel, X+10*tol*rand(size(X)...); tol=tol)
return H, X
# return compute_obs_dynamics(smodel, X+10*tol*rand(size(X)...); tol=tol)
end
end
end

"""Checks the rank of the observability matrix.
"""
check_observability(H::Matrix{Float64}, tol::Float64) = rank(H, atol=tol) minimum(size(H))
check_observability(H::Matrix{Float64}, tol::Float64) = rank(H, atol=tol) size(H)[2]

function scribe_observations(X::Matrix{Float64}, smodel::LGSFModel, o_b::LGSFObserverBehavior)
let nₛ=size(X,1), v_s=o_b.v_s, R=v_s[]*I(nₛ)
Expand Down
48 changes: 32 additions & 16 deletions test/test_kalman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ using LinearAlgebra: I

function test_agent_setup()
ϕ₀=[-1,1,-1,1,-1.]
gt_params=LGSFModelParameters=hcat(range(-1,1,5), zeros(5)),σ=[1.],τ=[1.],ϕ₀=ϕ₀,A=nothing,Q=nothing)
gt_params=LGSFModelParameters=hcat(range(-1,1,5), zeros(5)),σ=[1.],τ=[1.],ϕ₀=ϕ₀,A=nothing,Q=0.1.*I(5))
gt_model=[initialize_SCRIBEModel_from_parameters(gt_params)]
@test typeof(gt_model) <: Vector{T} where T<:SCRIBEModel

v_s=0.1
observer=LGSFObserverBehavior(v_s)
@test typeof(observer) <: SCRIBEObserverBehavior
sample_locations = []
init_sample_loc = [1 1.]
init_sample_loc = [0. 0.]
push!(sample_locations, init_sample_loc)
observations=[scribe_observations(init_sample_loc,gt_model[1],observer)]
sample_locations[end] = observations[1].X
Expand All @@ -34,14 +34,18 @@ function quick_setup(nᵩ=2; testing=true)
gt_model=[initialize_SCRIBEModel_from_parameters(gt_params)]

ag_params=LGSFModelParameters=hcat(range(-1,1,nᵩ), zeros(nᵩ)),σ=[1.],τ=[1.],
ϕ₀=zeros(nᵩ), A=Matrix{Float64}(I(nᵩ)), Q=Matrix{Float64}(I(nᵩ)))
observer=LGSFObserverBehavior(0.3)
ϕ₀=zeros(nᵩ).+0.1, A=Matrix{Float64}(I(nᵩ)), Q=0.01*Matrix{Float64}(I(nᵩ)))
observer=LGSFObserverBehavior(0.01)
sample_locations=[]
init_agent_loc=[0. 0.;]
# init_agent_loc=[0. 0.;]
init_agent_loc=[0. 0.; 0.5 0.5]
ag=initialize_agent(ag_params, observer, gt_model[1], init_agent_loc)
# push!(sample_locations, init_agent_loc)
push!(sample_locations, ag.estimates[1].observations.X)
if testing; @test typeof(ag) == AgentEnvModel; end
if testing
@test typeof(ag) == AgentEnvModel
# @test !isapprox(sample_locations[1], init_agent_loc)
end

lg_Fs=simple_LGSF_Estimators(ag)

Expand Down Expand Up @@ -85,14 +89,19 @@ function test_estimators(; testing=true)
end
end

function test_centralized_KF(; testing=true)
function test_observability(; testing=true)
(ϕ₀, gt_params, gt_model, ag_params, observer, sample_locations, ag, lg_Fs) = quick_setup(2, testing=testing)
@test !isapprox(sample_locations, zeros(1,2))
end

function test_centralized_KF(nₛ=3; testing=true)
(ϕ₀, gt_params, gt_model, ag_params, observer, sample_locations, ag, lg_Fs) = quick_setup(2, testing=testing)
nᵩ=ag_params.nᵩ # storing for easier debugging

fused_info=Any[ag.information[1]]
sz=(2,2)
sz=(1,2) # or make it (1,2)
new_loc = zeros(sz...)
for i in 1:100
for i in 1:nₛ
push!(fused_info,centralized_fusion([lg_Fs], i)[1])
push!(gt_model, update_SCRIBEModel(gt_model[i]))
new_loc = 3*rand(sz...)
Expand All @@ -103,16 +112,23 @@ function test_centralized_KF(; testing=true)
# next_agent_info_state(ag, centralized_fusion([ag], ag.k)[1])
end

return ag, lg_Fs
end
println("Results:")
println("k: ", ag.k, "\nϕⱼ(t=final): ", ag.estimates[end].estimate.ϕ, "\nϕ(t=final): ", gt_model[end].ϕ)

@testset "Single Agent Setup" begin
test_agent_setup()
return ag, lg_Fs, gt_model
end

@testset "Single Agent Estimators" begin
test_estimators()
end
# @testset "Single Agent Setup" begin
# test_agent_setup()
# end

# @testset "Single Agent Estimators" begin
# test_estimators()
# end

# @testset "H Observability" begin
# test_observability()
# end

@testset "Centralized Kalman Filter" begin
test_centralized_KF()
Expand Down

0 comments on commit f422f96

Please sign in to comment.