Skip to content

Commit

Permalink
First attempt at scripter
Browse files Browse the repository at this point in the history
  • Loading branch information
Shashank Swaminathan committed Feb 4, 2025
1 parent 8f14bd5 commit 92b1fc0
Show file tree
Hide file tree
Showing 4 changed files with 378 additions and 25 deletions.
2 changes: 1 addition & 1 deletion test/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.7"
manifest_format = "2.0"
project_hash = "0952182c7949199c353768e12df781ad89450b36"
project_hash = "79646ec35824621c96f288d3984b6af1018d4656"

[[deps.ASL_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Match = "7eb4fadd-790c-5f42-8a69-bfa0b872bfbf"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Expand Down
77 changes: 53 additions & 24 deletions test/test_mhmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function quick_setup(agent_ids, agent_conns, nᵩ; testing=true)
end
δ_w = 0.001 # represents temporal system dynamics (shift from A=identity)

gt_params=LGSFModelParameters=μ,σ=[0.5],τ=[1.],ϕ₀=ϕ₀,
gt_params=LGSFModelParameters=μ,σ=[1.],τ=[1.],ϕ₀=ϕ₀,
A=Matrix{Float64}(I(nᵩ) .* (1- δ_w)),
Q=0.000001*Matrix{Float64}(I(nᵩ)))
gt_model=[initialize_SCRIBEModel_from_parameters(gt_params)]
Expand All @@ -29,10 +29,10 @@ function quick_setup(agent_ids, agent_conns, nᵩ; testing=true)
a = 0
for aid in agent_ids
a += 1
ag_params=LGSFModelParameters=μ,σ=[0.5],τ=[1.],
ag_params=LGSFModelParameters=μ,σ=[1.],τ=[1.],
ϕ₀=zeros(nᵩ), A=Matrix{Float64}(I(nᵩ)),
Q=0.0001*Matrix{Float64}(I(nᵩ)))
observer=LGSFObserverBehavior(0.01)
observer=LGSFObserverBehavior(0.1)
init_agent_loc=[0. 0.; -0.5 -0.5] + [0. 0.; a-1 a-1]
lg_Fs = initialize_KF(ag_params, observer, copy(init_agent_loc), gt_model[1])
ng.vertices[aid] = initialize_agent(aid, lg_Fs, ng)
Expand Down Expand Up @@ -210,9 +210,9 @@ end
function simple_print_results(gt_model::Vector{T} where T<:SCRIBEModel, ng::NetworkGraph)
agent_ids = collect(keys(ng.vertices))

fests_m = [(aid*"m", ng.vertices[aid].agent.estimates[end].estimate.ϕ) for aid in agent_ids]
fests_v = [(aid*"v", inv(ng.vertices[aid].agent.information[end].Y)) for aid in agent_ids]
fvals = Dict([(:k, ng.vertices["agent1"].agent.k), (, gt_model[end].ϕ), fests_m..., fests_v...])
fests_m = [(aid*"m", ng.vertices[aid].agent.estimates[end].estimate.ϕ) for aid in agent_ids] # Final Mean Estimates
fests_v = [(aid*"v", inv(ng.vertices[aid].agent.information[end].Y)) for aid in agent_ids] # Final Variance Estimates
fvals = Dict([(:k, ng.vertices["agent1"].agent.k), (, gt_model[end].ϕ), fests_m..., fests_v...]) # collate

println("Results:")
println("k: ", fvals[:k], "\nϕ(t=final): ", fvals[])
Expand Down Expand Up @@ -242,7 +242,8 @@ function frechet_dist_eval_plot(run_name::String; layout_size=(800,500))
p2 = plot(size=layout_size, xlabel="Time (in discretized steps)",
title="Close-up of performance of agent learned environment models \nafter gathering initial observations", margin=(10, :mm))
for aid in agent_ids
plot!(p2, 1:length(gt_model), frechet_dists[aid], label="Agent "*aid[end], lw=2, yrange=[0., 1.0])
plot!(p2, (1:length(gt_model))[3:end], frechet_dists[aid][3:end], label="Agent "*aid[end], lw=2)
println(aid*" FDists: ", frechet_dists[aid])
end

plot(p1, p2, layout=(1,2), size=(layout_size[1]*2, layout_size[2]))
Expand All @@ -257,39 +258,67 @@ function error_mapf(gt::SCRIBEModel, m::SCRIBEModel, x::Vector; mode=:norm)
end
end

function error_map_plots(run_name::String; layout_size=(1200,1000), mode=:tane)
function error_map_plots(run_name::String; layout_size=(2700,1500), mode=:tane)
png_name = "test/res_plots/"*run_name*"_err_map.png"

@load "test/res_data/"*run_name*".jld2" gt_model ng
gt = gt_model[end]

num_plots = length(ng.vertices) + 1
num_plots = length(ng.vertices)
needs_padding = num_plots%2==1
layout_num = (2,Integer(ceil(num_plots/2)))
# layout_num = (2,Integer(ceil(num_plots/2)))
layout_num = (num_plots, num_plots+1)
pred_cgrad = :thermal
gt_err_cgrad = :solar
rel_err_cgrad = :ice

x_range = -5:0.1:5
y_range = copy(x_range)
gt_map = heatmap(x_range, y_range, [predict_SCRIBEModel(gt, [x,y]) for y in y_range, x in x_range],
color=:viridis, xlabel="X", ylabel="Y",
title="\nGround truth distribution of phenomena intensity",
c = :thermal)
color=:viridis, # xlabel="X", ylabel="Y",
title="Ground truth distribution",
c = pred_cgrad)

error_maps = Dict{String, Any}()
error_maps = Dict{String, Vector}()
error_mapv = Any[gt_map]

for aid in keys(ng.vertices)
let m = ng.vertices[aid].agent.estimates[end].estimate, h=reduce(vcat, ng.vertices[aid].history)
error_maps[aid] = heatmap(x_range, y_range, [error_mapf(gt, m, [x,y]; mode=mode) for y in y_range, x in x_range],
color=:viridis, xlabel="X", ylabel="Y",
title="Error distribution of \nAgent "*aid[end]*"'s predictions of phenomena intensity",
c = :berlin)
plot!(error_maps[aid], copy(h[:,1]), copy(h[:,2]),
linestyle=:dash, marker=:xcross, linecolor=:red, label="Sampling path and sites")
push!(error_mapv, error_maps[aid])
for (i, aid) in enumerate(keys(ng.vertices))
error_maps[aid] = Any[]
if i1
println("Blank because ground truth is already added, for "*aid[end])
push!(error_maps[aid], plot(legend=false,grid=false,foreground_color_subplot=:white))
end
for (j, a2d) in enumerate(keys(ng.vertices))
let m1 = ng.vertices[aid].agent.estimates[end].estimate,
m2 = ng.vertices[a2d].agent.estimates[end].estimate,
h1=reduce(vcat, ng.vertices[aid].history),
h2=reduce(vcat, ng.vertices[a2d].history)

if i==j
println("Ground comparison for agent "*aid[end])
push!(error_maps[aid], heatmap(x_range, y_range, [error_mapf(gt, m1, [x,y]; mode=mode) for y in y_range, x in x_range],
color=:viridis, # xlabel="X", ylabel="Y",
title="% error between \nAgent "*aid[end]*"'s predictions and ground truth",
c = gt_err_cgrad))
plot!(error_maps[aid][end], copy(h1[:,1]), copy(h1[:,2]),
linestyle=:dash, marker=:xcross, linecolor=:red, label="Agent "*aid[end]*" sampling sites")
else
println("Comparison for agent "*aid[end]*" with agent "*a2d[end])
push!(error_maps[aid], heatmap(x_range, y_range, [error_mapf(m1, m2, [x,y]; mode=mode) for y in y_range, x in x_range],
color=:viridis, # xlabel="X", ylabel="Y",
title="% error between \nAgent "*aid[end]*" and Agent "*a2d[end]*" predictions",
c = rel_err_cgrad))
# plot!(error_maps[aid][end], copy(h1[:,1]), copy(h1[:,2]),
# linestyle=:dash, marker=:xcross, linecolor=:red, label="Agent "*aid[end]*" sampling sites")
# plot!(error_maps[aid][end], copy(h2[:,1]), copy(h2[:,2]),
# linestyle=:dash, marker=:xcross, linecolor=:red, label="Agent "*a2d[end]*" sampling sites")
end
end
end
append!(error_mapv, copy(error_maps[aid]))
end

if needs_padding; push!(error_mapv, nothing); end
# if needs_padding; push!(error_mapv, nothing); end
plot(error_mapv..., layout=layout_num, size=layout_size)
savefig(png_name);
println("Plot saved at "*png_name)
Expand Down
Loading

0 comments on commit 92b1fc0

Please sign in to comment.