Skip to content

Commit

Permalink
basic changes made
Browse files Browse the repository at this point in the history
  • Loading branch information
Shashank Swaminathan committed Nov 24, 2024
1 parent 1c1ee4f commit 86d039b
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/SCRIBEModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export SCRIBEModel, SCRIBEModelParameters, initialize_SCRIBEModel_from_parameter
export SCRIBEObserverBehavior, SCRIBEObserverState, compute_obs_dynamics, scribe_observations

using GaussianDistributions: Gaussian
using LinearAlgebra: norm, I,
using LinearAlgebra: norm, I, , rank

"""Abstract type defined for specialization during model instantiation.
Expand Down
25 changes: 21 additions & 4 deletions src/lineargaussianscalarfields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,33 @@ struct LGSFObserverState <: SCRIBEObserverState
end
end

"""Compute the observation dynamics matrix.
"""Compute the observation dynamics matrix `H`.
This is also referred to as the **H matrix**.
Takes steps to ensure the resultant observation matrix is not rank-deficient.
It does so by perturbing the input sample location by 10x the rank-deficience tolerance.
Returns:
* `H`: The observation matrix
* `X`: The end sample location (after potential perturbation).
"""
function compute_obs_dynamics(smodel::LGSFModel, X::Matrix{Float64}; tol=0.001)
let H=mapslices(smodel.ψ,X,dims=2)
if check_observability(H, tol)
return H, X
else
return compute_obs_dynamics(smodel, X+10*tol*rand(size(X)...); tol=tol)
end
end
end

"""Checks the rank of the observability matrix.
"""
compute_obs_dynamics(smodel::LGSFModel, X::Matrix{Float64}) = mapslices(smodel.ψ,X,dims=2)
check_observability(H::Matrix{Float64}, tol::Float64) = rank(H, atol=tol) minimum(size(H))

function scribe_observations(X::Matrix{Float64}, smodel::LGSFModel, o_b::LGSFObserverBehavior)
let nₛ=size(X,1), v_s=o_b.v_s, R=v_s[]*I(nₛ)
v=Dict(:R=>R, :k=>rand(Gaussian(zeros(nₛ), R)))
H=compute_obs_dynamics(smodel, X)
(H, X)=compute_obs_dynamics(smodel, X)
z=muladd(H,smodel.ϕ,v[:k])
LGSFObserverState(smodel.k, nₛ, X, H, v, z)
end
Expand Down
18 changes: 13 additions & 5 deletions test/test_kalman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ function test_agent_setup()
init_sample_loc = [1 1.]
push!(sample_locations, init_sample_loc)
observations=[scribe_observations(init_sample_loc,gt_model[1],observer)]
sample_locations[end] = observations[1].X

@test typeof(observations[1]) <: SCRIBEObserverState
@test isapprox(observations[1].X,sample_locations[1])
Expand All @@ -37,8 +38,9 @@ function quick_setup(nᵩ=2; testing=true)
observer=LGSFObserverBehavior(0.3)
sample_locations=[]
init_agent_loc=[0. 0.;]
push!(sample_locations, init_agent_loc)
ag=initialize_agent(ag_params, observer, gt_model[1], init_agent_loc)
# push!(sample_locations, init_agent_loc)
push!(sample_locations, ag.estimates[1].observations.X)
if testing; @test typeof(ag) == AgentEnvModel; end

lg_Fs=simple_LGSF_Estimators(ag)
Expand All @@ -61,10 +63,13 @@ function test_estimators(; testing=true)
@test_throws BoundsError lg_Fs.z(2)
end

sz = (2,2)
new_loc = zeros(sz...)
for i in 1:4
push!(gt_model, update_SCRIBEModel(gt_model[i]))
push!(sample_locations, rand(2,2))
next_agent_state(ag,zeros(ag_params.nᵩ), gt_model[i+1], sample_locations[i+1])
new_loc = rand(sz...)
next_agent_state(ag,zeros(ag_params.nᵩ), gt_model[i+1], new_loc)
push!(sample_locations, ag.estimates[i+1].observations.X)
end

if testing
Expand All @@ -85,11 +90,14 @@ function test_centralized_KF(; testing=true)
nᵩ=ag_params.nᵩ # storing for easier debugging

fused_info=Any[ag.information[1]]
sz=(2,2)
new_loc = zeros(sz...)
for i in 1:100
push!(fused_info,centralized_fusion([lg_Fs], i)[1])
push!(gt_model, update_SCRIBEModel(gt_model[i]))
push!(sample_locations, 3*rand(2,2))
progress_agent_env_filter(ag, fused_info[end], gt_model[i+1], sample_locations[i+1])
new_loc = 3*rand(sz...)
progress_agent_env_filter(ag, fused_info[end], gt_model[i+1], new_loc)
push!(sample_locations, ag.estimates[i+1].observations.X)
# next_agent_state(ag,zeros(ag_params.nᵩ), gt_model[i+1], sample_locations[i+1])
# next_agent_time(ag)
# next_agent_info_state(ag, centralized_fusion([ag], ag.k)[1])
Expand Down

0 comments on commit 86d039b

Please sign in to comment.