Skip to content

Commit

Permalink
Starting on information fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
Shashank Swaminathan committed Nov 16, 2024
1 parent 187f525 commit a586d66
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 23 deletions.
28 changes: 19 additions & 9 deletions src/kalman_estimation.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using Parameters: @unpack
using Combinatorics: combinations

export AgentEnvModel, initialize_agent, next_agent_state
export SystemEstimators, simple_LGSF_Estimators
export AgentEnvModel, initialize_agent, next_agent_state, next_agent_time, next_agent_info_state
export SystemEstimators, simple_LGSF_Estimators, compute_info_priors, compute_innov_from_obs
export centralized_fusion

"""Current discrete-time estimate of the linear system being observed.
Expand Down Expand Up @@ -88,6 +89,16 @@ function initialize_agent(params::SCRIBEModelParameters, bhv::SCRIBEObserverBeha
AgentEnvModel(1, cwrld.k-1, params, bhv, [init_agent_estimate(cwrld, 1, params, bhv, X₀)], [init_agent_info(params.nᵩ)])
end

"""Adds new internal system estimate for t=k+1.
"""
function next_agent_state(agent::AgentEnvModel, ϕₖ::Vector{Float64}, cwrld::SCRIBEModel, X::Matrix{Float64})
push!(agent.estimates, new_agent_estimate(cwrld, agent.k+1, agent.estimates[agent.k].estimate, ϕₖ, agent.bhv, X))
end

next_agent_time(agent::AgentEnvModel) = agent.k+=1

next_agent_info_state(agent::AgentEnvModel, info::AgentEnvInfo) = push!(agent.information, info)

struct SystemEstimators
system::AgentEnvModel
A::Function
Expand Down Expand Up @@ -119,10 +130,6 @@ function simple_LGSF_Estimators(system::AgentEnvModel)
kY->get_Y(kY, system), ky->get_y(ky, system))
end

function next_agent_state(agent::AgentEnvModel, ϕₖ::Vector{Float64}, cwrld::SCRIBEModel, X::Matrix{Float64})
push!(agent.estimates, new_agent_estimate(cwrld, agent.k+1, agent.estimates[agent.k].estimate, ϕₖ, agent.bhv, X))
end

"""Computes the prior update **of the next step** Y⁻(k+1).
Takes two inputs:
Expand Down Expand Up @@ -152,11 +159,12 @@ function compute_innov_from_obs(Ef::SystemEstimators, k::Integer)
end

function centralized_fusion(agent_estimators::Vector{SystemEstimators}, k::Integer)
nₐ=size(priors,1)
# Compute priors from current system state at t=k and previous information state at t=k-1
priors = map(ef->compute_info_priors(ef, k), agent_estimators)

# Ensure all priors are the same
if size(priors,1)>1
if nₐ>1
for prior_pair in combinations(priors, 2)
let prior_a=prior_pair[1], prior_b=prior_pair[2]
@assert isapprox(prior_a[1], prior_b[1]) "Information matrix priors for Y⁻(k+1) are diverged!"
Expand All @@ -165,8 +173,10 @@ function centralized_fusion(agent_estimators::Vector{SystemEstimators}, k::Integ
end
end


# Compute innovations for t=k+1 from current observation at t=k
innovs = map(ef->compute_innov_from_obs(ef, k), agent_estimators)
avg_innov = average(innovs)
δĪ=mean(map(x->x[1], innovs))
δī=mean(map(x->x[2], innovs))

return [AgentEnvInfo(priors[a][2]+nₐ*δī, priors[a][1]+nₐ*δĪ, δī, δĪ) for a in 1:nₐ]
end
40 changes: 26 additions & 14 deletions test/test_kalman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,34 @@ function test_estimators(; testing=true)
@test_throws BoundsError lg_Fs.z(2)
end

for i in 1:4
push!(gt_model, update_SCRIBEModel(gt_model[i]))
push!(sample_locations, rand(2,2))
next_agent_state(ag,zeros(ag_params.nᵩ), gt_model[i+1], sample_locations[i+1])
# Add temporary lexical scoping for stupid testing
let gt_model=deepcopy(gt_model), sample_locations=deepcopy(sample_locations), ag=deepcopy(ag)

for i in 1:4
push!(gt_model, update_SCRIBEModel(gt_model[i]))
push!(sample_locations, rand(2,2))
next_agent_state(ag,zeros(ag_params.nᵩ), gt_model[i+1], sample_locations[i+1])
end

if testing
@test_throws BoundsError lg_Fs.ϕ(6)
@test_throws BoundsError lg_Fs.A(6)
@test_throws BoundsError lg_Fs.H(6)
@test_throws BoundsError lg_Fs.z(6)

@test all([isapprox(lg_Fs.ϕ(i), zeros(5)) for i in 1:5])
@test !all([size(lg_Fs.H(i))==(2,5) for i in 1:5])
@test all([size(lg_Fs.H(i))==(2,5) for i in 2:5])
@test all([size(lg_Fs.z(i))==(2,) for i in 2:5])
end
end

if testing
@test_throws BoundsError lg_Fs.ϕ(6)
@test_throws BoundsError lg_Fs.A(6)
@test_throws BoundsError lg_Fs.H(6)
@test_throws BoundsError lg_Fs.z(6)

@test all([isapprox(lg_Fs.ϕ(i), zeros(5)) for i in 1:5])
@test !all([size(lg_Fs.H(i))==(2,5) for i in 1:5])
@test all([size(lg_Fs.H(i))==(2,5) for i in 2:5])
@test all([size(lg_Fs.z(i))==(2,) for i in 2:5])
for i in 1:100
push!(gt_model, update_SCRIBEModel(gt_model[i]))
push!(sample_locations, 3*rand(5,2))
next_agent_state(ag,zeros(ag_params.nᵩ), gt_model[i+1], sample_locations[i+1])
next_agent_time(ag)
next_agent_info_state(ag, centralized_fusion([ag], ag.k)[1])
end

return ag, lg_Fs
Expand Down

0 comments on commit a586d66

Please sign in to comment.