From 3144c829dd549bded7e62a44d3b7f0350d306b29 Mon Sep 17 00:00:00 2001 From: Shashank Swaminathan Date: Mon, 13 Jan 2025 23:53:15 -0500 Subject: [PATCH] Added first portion of Frechet distance calculations --- test/test_mhmc.jl | 136 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 112 insertions(+), 24 deletions(-) diff --git a/test/test_mhmc.jl b/test/test_mhmc.jl index 7622a2c..c32000a 100644 --- a/test/test_mhmc.jl +++ b/test/test_mhmc.jl @@ -35,7 +35,49 @@ function quick_setup(agent_ids, agent_conns, nᵩ; testing=true) return gt_model, ng end +function generate_agent_wpts(agent_ids, corners = [-5.,5.]) + n = length(agent_ids) + agent_coords = Dict([(aid, Any[]) for aid in agent_ids]) + + bc = [corners[1], corners[1]] + h = corners[2] - corners[1] + w = (h - 2*h/10) / n + + for aid in agent_ids + let ag = agent_coords[aid] + push!(ag, copy(bc)) + push!(ag, ag[end] + [0., h]) + push!(ag, ag[end] + [w/3, 0.]) + push!(ag, ag[end] + [0., -h]) + push!(ag, ag[end] + [w/3, 0.]) + push!(ag, ag[end] + [0., h]) + push!(ag, ag[end] + [w/3, 0.]) + push!(ag, ag[end] + [0., -h]) + + bc = ag[end] + [h/10, 0.] + end + end + + agent_coords +end + +function generate_sample_locs_from_wpts(wpts::Vector, k::Integer, d::Float64,) + let l = length(wpts), start = wpts[(k-1)%l+1], stop = wpts[k%l+1], + d1 = abs(stop[1] - start[1])/d, d2 = abs(stop[2] - start[2])/d + if d1 >= d2 + nₛ = max(Integer(round(d1)), 2) + else + nₛ = max(Integer(round(d2)), 2) + end + # println("\nstart: ", start, " | stop: ", stop, " || d1: ", d1, " | d2: ", d2, " | nₛ: ", nₛ) + hcat(range(start[1], stop[1], length=nₛ+1), + range(start[2], stop[2], length=nₛ+1)) + end +end + function mul_agent_distrib_KF(run_name::String, nₐ=3, nₛ=100; testing=true, tol=0.1) + space_corners = [-2., 2.] # corner coordinates + agent_ids = ["agent1", "agent2", "agent3", "agent4", "agent5"][1:nₐ] if nₐ==3 agent_conns = Dict([("agent1", ["agent2"]), @@ -56,8 +98,10 @@ function mul_agent_distrib_KF(run_name::String, nₐ=3, nₛ=100; testing=true, # Arrays of fused information over time per agent (array of arrays) # fused_info = Dict([(aid, [copy(ng.vertices[aid].agent.information[1])]) for aid in agent_ids]) - sz=(1,2) # or make it (1,2) - new_loc = Dict([(agent_ids[j], [0. 0.; -0.5 -0.5] + [0. 0.; j-1 j-1]) for j in eachindex(agent_ids)]) + # sz=(1,2) # or make it (1,2) + # new_loc = Dict([(agent_ids[j], [0. 0.; -0.5 -0.5] + [0. 0.; j-1 j-1]) for j in eachindex(agent_ids)]) + agent_wpts = generate_agent_wpts(agent_ids, space_corners) + sample_dists = (space_corners[2] - space_corners[1])/20 # ensure distance is small enough for lawnmower pattern # new_loc = zeros(sz...) for i in 1:nₛ print("k: ", i) @@ -74,7 +118,8 @@ function mul_agent_distrib_KF(run_name::String, nₐ=3, nₛ=100; testing=true, # new_loc = 3*rand(sz...) for aid in agent_ids # progress_agent_env_filter(ng.vertices[aid].agent, fused_info[aid][end], gt_model[i+1], copy(new_loc[aid])) - progress_agent_env_filter(ng.vertices[aid].agent, gt_model[i+1], copy(new_loc[aid])) + progress_agent_env_filter(ng.vertices[aid].agent, gt_model[i+1], + copy(generate_sample_locs_from_wpts(agent_wpts[aid], i, sample_dists))) push!(ng.vertices[aid].history, ng.vertices[aid].agent.estimates[i+1].observations.X) end println("ϕ: ", gt_model[end].ϕ, " | ̂ϕ: ", ng.vertices["agent1"].agent.estimates[end].estimate.ϕ) @@ -92,7 +137,7 @@ function mul_agent_distrib_KF(run_name::String, nₐ=3, nₛ=100; testing=true, simple_print_results(gt_model, ng) - @save "test/"*run_name*".jld2" gt_model ng + @save "test/res_data/"*run_name*".jld2" gt_model ng return gt_model, ng end @@ -110,30 +155,73 @@ function simple_print_results(gt_model::Vector{T} where T<:SCRIBEModel, ng::Netw end end -function error_map_plots(run_name::String) - @load "test/"*run_name*".jld2" gt ng +function frechet_dist_eval_plot(run_name::String; layout_size=(800,500)) + png_name = "test/res_plots/"*run_name*"_frechet_dist.png" + + @load "test/res_data/"*run_name*".jld2" gt_model ng x_range = -2:0.1:2 - y_range = -2:0.1:2 + y_range = copy(x_range) + err_sq(gt, am, v) = norm(predict_SCRIBEModel(gt, v) - predict_SCRIBEModel(am, v))^2 + frechet(gt, am) = sum([err_sq(gt, am, [x, y]) for x in x_range for y in y_range]) + agent_ids = collect(keys(ng.edges)) + frechet_dists = Dict([(aid, [frechet(gt_model[i], ng.vertices[aid].agent.estimates[i].estimate) for i in 1:length(gt_model)]) for aid in agent_ids]) + + p1 = plot(size=layout_size, xlabel="Time (in discretized steps)", ylabel="(pseudo-)Frechet distance", + title="Performance of agent learned environment models \ncompared to ground truth over time", margin=(10, :mm)) + for aid in agent_ids + plot!(p1, 1:length(gt_model), frechet_dists[aid], label="Agent "*aid[end], lw=2) + end + + p2 = plot(size=layout_size, xlabel="Time (in discretized steps)", + title="Close-up of performance of agent learned environment models \nafter gathering initial observations", margin=(10, :mm)) + for aid in agent_ids + plot!(p2, 1:length(gt_model), frechet_dists[aid], label="Agent "*aid[end], lw=2, yrange=[0., 1.0]) + end + # xlabel!(p2, "Time (in discretized steps)") + # ylabel!(p2, "(pseudo-)Frechet distance, zoomed in") + # title!(p2, "Performance of agent learned environment models \ncompared to ground truth over time") + + plot(p1, p2, layout=(1,2), size=(layout_size[1]*2, layout_size[2])) + savefig(png_name); + println("Plot saved at "*png_name) +end - let gt = gt[end], ag1_mod = ng.vertices["agent1"].agent.estimates[end].estimate, - png_name = "test/res_plots/"*run_name*"_err_map_agent1.png" - error_agent1(x::Vector) = norm(predict_SCRIBEModel(gt, x) - predict_SCRIBEModel(ag1_mod, x)) - - gt_map_vals = [predict_SCRIBEModel(gt, [x,y]) for y in y_range, x in x_range] - err1_map_vals = [error_agent1([x,y]) for y in y_range, x in x_range] - - gt_map = heatmap(x_range, y_range, gt_map_vals, - color=:viridis, xlabel="X", ylabel="Y", - title="Ground truth distribution of phenomena intensity") - err1_map = heatmap(x_range, y_range, err1_map_vals, - color=:viridis, xlabel="X", ylabel="Y", - title="Error distribution of \nagent 1's predictions of phenomena intensity") - - plot(gt_map, err1_map, layout=(2,1), size=(600,800)); - savefig(png_name); - println("Plot saved at "*png_name) +function error_map_plots(run_name::String; layout_size=(1200,1000)) + png_name = "test/res_plots/"*run_name*"_err_map.png" + + @load "test/res_data/"*run_name*".jld2" gt_model ng + gt = gt_model[end] + + num_plots = length(ng.vertices) + 1 + needs_padding = num_plots%2==1 + layout_num = (2,Integer(ceil(num_plots/2))) + + x_range = -2:0.1:2 + y_range = -2:0.1:2 + gt_map = heatmap(x_range, y_range, [predict_SCRIBEModel(gt, [x,y]) for y in y_range, x in x_range], + color=:viridis, xlabel="X", ylabel="Y", + title="\nGround truth distribution of phenomena intensity") + + error_mapf(m, x) = norm(predict_SCRIBEModel(gt, x) - predict_SCRIBEModel(m, x)) + error_maps = Dict{String, Any}() + error_mapv = Any[gt_map] + + for aid in keys(ng.vertices) + let m = ng.vertices[aid].agent.estimates[end].estimate, h=reduce(vcat, ng.vertices[aid].history) + error_maps[aid] = heatmap(x_range, y_range, [error_mapf(m, [x,y]) for y in y_range, x in x_range], + color=:viridis, xlabel="X", ylabel="Y", + title="Error distribution of \nAgent "*aid[end]*"'s predictions of phenomena intensity") + plot!(error_maps[aid], copy(h[:,1]), copy(h[:,2]), + linestyle=:dash, marker=:xcross, linecolor=:red, label="Sampling path and sites") + push!(error_mapv, error_maps[aid]) + end end + + if needs_padding; push!(error_mapv, nothing); end + plot(error_mapv..., layout=layout_num, size=layout_size) + savefig(png_name); + println("Plot saved at "*png_name) end # mul_agent_distrib_KF(); \ No newline at end of file