Skip to content

Commit

Permalink
Big debugging shifts ...
Browse files Browse the repository at this point in the history
  • Loading branch information
Shashank Swaminathan committed Nov 20, 2024
1 parent 67d0542 commit 1c1ee4f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 46 deletions.
10 changes: 5 additions & 5 deletions test/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[deps.Distributions]]
deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"]
git-tree-sha1 = "d7477ecdafb813ddee2ae727afa94e9dcb5f3fb0"
git-tree-sha1 = "3101c32aab536e7a27b1763c0797dba151b899ad"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.112"
version = "0.25.113"

[deps.Distributions.extensions]
DistributionsChainRulesCoreExt = "ChainRulesCore"
Expand Down Expand Up @@ -116,9 +116,9 @@ version = "0.5.2"

[[deps.HypergeometricFunctions]]
deps = ["LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"]
git-tree-sha1 = "7c4195be1649ae622304031ed46a2f4df989f1eb"
git-tree-sha1 = "b1c2585431c382e3fe5805874bda6aea90a95de9"
uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a"
version = "0.3.24"
version = "0.3.25"

[[deps.InteractiveUtils]]
deps = ["Markdown"]
Expand Down Expand Up @@ -322,7 +322,7 @@ uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
version = "0.5.1+0"

[[deps.SCRIBE]]
deps = ["Combinatorics", "GaussianDistributions", "Kalman", "LinearAlgebra", "Parameters", "Reexport"]
deps = ["Combinatorics", "GaussianDistributions", "Kalman", "LinearAlgebra", "Parameters", "Reexport", "Statistics"]
path = ".."
uuid = "9d033a4c-5281-4799-847b-6fce4a2de302"
version = "1.0.0-DEV"
Expand Down
78 changes: 37 additions & 41 deletions test/test_kalman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,18 @@ function test_agent_setup()
return gt_params, gt_model, sample_locations, observer, observations
end

function quick_setup(; testing=true)
ϕ₀=[-0.5,0.5,0.75,0.5,-0.5]
gt_params=LGSFModelParameters=hcat(range(-1,1,5), zeros(5)),σ=[1.],τ=[1.],ϕ₀=ϕ₀,A=nothing,Q=nothing)
function quick_setup(nᵩ=2; testing=true)
if nᵩ==2
ϕ₀=[-0.5,0.5]
else
ϕ₀=[-0.5,0.5,0.75,0.5,-0.5]
end

gt_params=LGSFModelParameters=hcat(range(-1,1,nᵩ), zeros(nᵩ)),σ=[1.],τ=[1.],ϕ₀=ϕ₀,A=nothing,Q=nothing)
gt_model=[initialize_SCRIBEModel_from_parameters(gt_params)]

ag_params=LGSFModelParameters=hcat(range(-1,1,5), zeros(5)),σ=[1.],τ=[1.],
ϕ₀=zeros(size(ϕ₀,1)),A=Matrix{Float64}(I(size(ϕ₀,1))),Q=Matrix{Float64}(I(size(ϕ₀,1))))
ag_params=LGSFModelParameters=hcat(range(-1,1,nᵩ), zeros(nᵩ)),σ=[1.],τ=[1.],
ϕ₀=zeros(nᵩ), A=Matrix{Float64}(I(nᵩ)), Q=Matrix{Float64}(I(nᵩ)))
observer=LGSFObserverBehavior(0.3)
sample_locations=[]
init_agent_loc=[0. 0.;]
Expand All @@ -42,20 +47,7 @@ function quick_setup(; testing=true)
end

function test_estimators(; testing=true)
ϕ₀=[-0.5,0.5,0.75,0.5,-0.5]
gt_params=LGSFModelParameters=hcat(range(-1,1,5), zeros(5)),σ=[1.],τ=[1.],ϕ₀=ϕ₀,A=nothing,Q=nothing)
gt_model=[initialize_SCRIBEModel_from_parameters(gt_params)]

ag_params=LGSFModelParameters=hcat(range(-1,1,5), zeros(5)),σ=[1.],τ=[1.],
ϕ₀=zeros(size(ϕ₀,1)),A=Matrix{Float64}(I(size(ϕ₀,1))),Q=Matrix{Float64}(I(size(ϕ₀,1))))
observer=LGSFObserverBehavior(0.3)
sample_locations=[]
init_agent_loc=[0. 0.;]
push!(sample_locations, init_agent_loc)
ag=initialize_agent(ag_params, observer, gt_model[1], init_agent_loc)
if testing; @test typeof(ag) == AgentEnvModel; end

lg_Fs=simple_LGSF_Estimators(ag)
(ϕ₀, gt_params, gt_model, ag_params, observer, sample_locations, ag, lg_Fs) = quick_setup(5, testing=testing)

if testing
@test lg_Fs.ϕ(1)==zeros(5)
Expand All @@ -69,34 +61,34 @@ function test_estimators(; testing=true)
@test_throws BoundsError lg_Fs.z(2)
end

# Add temporary lexical scoping for stupid testing
let gt_model=deepcopy(gt_model), sample_locations=deepcopy(sample_locations), ag=deepcopy(ag)
lg_Fs=simple_LGSF_Estimators(ag)

for i in 1:4
push!(gt_model, update_SCRIBEModel(gt_model[i]))
push!(sample_locations, rand(2,2))
next_agent_state(ag,zeros(ag_params.nᵩ), gt_model[i+1], sample_locations[i+1])
end
for i in 1:4
push!(gt_model, update_SCRIBEModel(gt_model[i]))
push!(sample_locations, rand(2,2))
next_agent_state(ag,zeros(ag_params.nᵩ), gt_model[i+1], sample_locations[i+1])
end

if testing
@test_throws BoundsError lg_Fs.ϕ(6)
@test_throws BoundsError lg_Fs.A(6)
@test_throws BoundsError lg_Fs.H(6)
@test_throws BoundsError lg_Fs.z(6)

@test all([isapprox(lg_Fs.ϕ(i), zeros(5)) for i in 1:5])
@test !all([size(lg_Fs.H(i))==(2,5) for i in 1:5])
@test all([size(lg_Fs.H(i))==(2,5) for i in 2:5])
@test all([size(lg_Fs.z(i))==(2,) for i in 2:5])
end
if testing
@test_throws BoundsError lg_Fs.ϕ(6)
@test_throws BoundsError lg_Fs.A(6)
@test_throws BoundsError lg_Fs.H(6)
@test_throws BoundsError lg_Fs.z(6)

@test all([isapprox(lg_Fs.ϕ(i), zeros(5)) for i in 1:5])
@test !all([size(lg_Fs.H(i))==(2,5) for i in 1:5])
@test all([size(lg_Fs.H(i))==(2,5) for i in 2:5])
@test all([size(lg_Fs.z(i))==(2,) for i in 2:5])
end
end

function test_centralized_KF(; testing=true)
(ϕ₀, gt_params, gt_model, ag_params, observer, sample_locations, ag, lg_Fs) = quick_setup(2, testing=testing)
nᵩ=ag_params.nᵩ # storing for easier debugging

fused_info=[ag.information[1]]
fused_info=Any[ag.information[1]]
for i in 1:100
push!(fused_info,centralized_fusion([lg_Fs], i)[1])
push!(gt_model, update_SCRIBEModel(gt_model[i]))
push!(sample_locations, 3*rand(5,2))
push!(sample_locations, 3*rand(2,2))
progress_agent_env_filter(ag, fused_info[end], gt_model[i+1], sample_locations[i+1])
# next_agent_state(ag,zeros(ag_params.nᵩ), gt_model[i+1], sample_locations[i+1])
# next_agent_time(ag)
Expand All @@ -112,4 +104,8 @@ end

@testset "Single Agent Estimators" begin
test_estimators()
end

@testset "Centralized Kalman Filter" begin
test_centralized_KF()
end

0 comments on commit 1c1ee4f

Please sign in to comment.