Skip to content

Commit

Permalink
Added Estimators and Initializers for Agent
Browse files Browse the repository at this point in the history
  • Loading branch information
Shashank Swaminathan committed Nov 14, 2024
1 parent c66219e commit b8e88dc
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 18 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
/Manifest.toml
.vscode/settings.json
18 changes: 16 additions & 2 deletions src/SCRIBEModels.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module SCRIBEModels

export SCRIBEModel, SCRIBEModelParameters, initialize_SCRIBEModel_from_parameters, update_SCRIBEModel, predict_SCRIBEModel
export SCRIBEObserverBehavior, SCRIBEObserverState, scribe_observations
export SCRIBEObserverBehavior, SCRIBEObserverState, compute_obs_dynamics, scribe_observations

using GaussianDistributions: Gaussian
using LinearAlgebra: norm, I,
Expand All @@ -18,7 +18,7 @@ abstract type SCRIBEModelParameters end
Takes parameter structure of a type inheriting from `SCRIBEModelParameters`.\\
Produces a model of a type inheriting from `SCRIBEModel`.
"""
function initialize_SCRIBEModel_from_parameters(params::SCRIBEModelParameters)
function initialize_SCRIBEModel_from_parameters(params::SCRIBEModelParameters; k)
# This function has no implementation and is intended to be specialized
error("`initialize_model_from_parameters` is not implemented for the abstract type SCRIBEModelParameters. Please provide a specific implementation.")
end
Expand Down Expand Up @@ -64,6 +64,20 @@ Currently defined observer state types:
"""
abstract type SCRIBEObserverState end

"""Generic observation dynamics calculation function.
"""
function compute_obs_dynamics(smodel::SCRIBEModel, X::Matrix{Float64})
# This function has no implementation and is intended to be specialized
error("`compute_obs_dynamics` is not implemented for the abstract type SCRIBEModel. Please provide a specific implementation.")
end

"""Generic observation function.
"""
function scribe_observations(X::Matrix{Float64}, smodel::SCRIBEModel, o_b::SCRIBEObserverBehavior)
# This function has no implementation and is intended to be specialized
error("`scribe_observations` is not implemented for the abstract types SCRIBEModel and SCRIBEObserverBehavior. Please provide a specific implementation.")
end

include("lineargaussianscalarfields.jl")

end
53 changes: 50 additions & 3 deletions src/kalman_estimation.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
export AgentEnvModel, initialize_agent, next_agent_state
export SystemEstimators, simple_LGSF_Estimators

"""Current discrete-time estimate of the linear system being observed.
Will store the discrete time of estimation for redundancy checking.
Expand All @@ -14,10 +17,16 @@ end
function init_agent_estimate(world::SCRIBEModel, k::Integer,
params::SCRIBEModelParameters, bhv::SCRIBEObserverBehavior,
X::VecOrMat{Float64})
AgentEnvEstimate(k, initialize_SCRIBEModel_from_parameters(params),
AgentEnvEstimate(k, initialize_SCRIBEModel_from_parameters(params, k=k),
scribe_observations(X,world,bhv))
end

function new_agent_estimate(world::SCRIBEModel, k::Integer,
estimate::SCRIBEModel, ϕₖ::Vector{Float64},
bhv::SCRIBEObserverBehavior, X::Matrix{Float64})
AgentEnvEstimate(k, update_SCRIBEModel(estimate, ϕₖ), scribe_observations(X, world, bhv))
end

"""Current discrete-time information.
This is separate from the actual model estimate - this is what the Kalman Filter interacts with.
Expand All @@ -26,9 +35,22 @@ The model estimate is a representation of the system, which is recovered from th
"""
struct AgentEnvInfo
y::Vector{Float64}
Y::Vector{Float64}
Y::Matrix{Float64}
i::Vector{Float64}
I::Vector{Float64}
I::Matrix{Float64}

AgentEnvInfo(y::Vector{Float64}, Y::Matrix{Float64},
i::Vector{Float64}, I::Matrix{Float64}) = new(y,Y,i,I)
end


"""Initial information associated per agent.
There is no information about the agent state, so all set to zero.
There is no such thing as "inital innovation", so arbitrarily set to zero.
"""
function init_agent_info(nᵩ::Integer)
AgentEnvInfo(zeros(nᵩ), zeros(nᵩ,nᵩ), zeros(nᵩ), zeros(nᵩ,nᵩ))
end

"""This is the collection of the system over time.
Expand All @@ -37,10 +59,24 @@ This is a linearized representation.
The mutating elements are the vectors, which are appended to.
"""
mutable struct AgentEnvModel
k::Integer
cwrld_sync::Integer
params::SCRIBEModelParameters
bhv::SCRIBEObserverBehavior
estimates::Vector{AgentEnvEstimate}
information::Vector{AgentEnvInfo}

AgentEnvModel(k::Integer, cwrld_sync::Integer, params::SCRIBEModelParameters, bhv::SCRIBEObserverBehavior,
estimates::Vector{AgentEnvEstimate},
information::Vector{AgentEnvInfo}) = new(k, cwrld_sync, params, bhv, estimates, information)
end

function initialize_agent(params::SCRIBEModelParameters, bhv::SCRIBEObserverBehavior, cwrld::SCRIBEModel, X₀::Matrix{Float64})
AgentEnvModel(1, cwrld.k-1, params, bhv, [init_agent_estimate(cwrld, 1, params, bhv, X₀)], [init_agent_info(params.nᵩ)])
end

function next_agent_state(agent::AgentEnvModel, ϕₖ::Vector{Float64}, cwrld::SCRIBEModel, X::Matrix{Float64})
push!(agent.estimates, new_agent_estimate(cwrld, agent.k+1, agent.estimates[agent.k].estimate, ϕₖ, agent.bhv, X))
end

struct SystemEstimators
Expand All @@ -49,4 +85,15 @@ struct SystemEstimators
ϕ::Function
H::Function
z::Function

SystemEstimators(system::AgentEnvModel, A::Function, ϕ::Function, H::Function, z::Function) = new(system, A, ϕ, H, z)
end

function simple_LGSF_Estimators(system::AgentEnvModel)
get_A(k, system) = system.estimates[k].estimate.params.A
get_ϕ(k, system) = system.estimates[k].estimate.ϕ
get_H(k, system) = compute_obs_dynamics(system.estimates[k].estimate, system.estimates[k].observations.X)
get_z(k, system) = system.estimates[k].observations.z

SystemEstimators(system, kA->get_A(kA, system), kϕ->get_ϕ(kϕ, system), kH->get_H(kH, system), kz->get_z(kz, system))
end
24 changes: 15 additions & 9 deletions src/lineargaussianscalarfields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ Input:
Output:
model::LGSFModel
"""
function initialize_SCRIBEModel_from_parameters(params::LGSFModelParameters)
return LGSFModel(1, params, x->ψ_from_params(x, params), params.ϕ₀, rand(params.w[:w_dist]))
function initialize_SCRIBEModel_from_parameters(params::LGSFModelParameters; k=1)
return LGSFModel(k, params, x->ψ_from_params(x, params), params.ϕ₀, rand(params.w[:w_dist]))
end

"""Computes the stochastic evolution of ϕ for a given timestep of an LGSFModel.
Expand Down Expand Up @@ -139,7 +139,7 @@ end
The following fields define the information stored:\\
`k::Integer`: Discrete timestep associated with observations \\
`nₛ::Integer`: Number of samples gathered in this time step \\
`X::VecOrMat{Float64}`: Matrix of observation locations (Vector if single observation) \\
`X::Matrix{Float64}`: Matrix of observation locations (Vector if single observation) \\
\t* Matrix stores the locations vertically, such that each new row represents a new location\\
`H::Matrix{Float64}`: Observation matrix representing taking samples at X \\
`v::Dict{Symbol, AbstractArray{Float64}}`: Dictionary of underlying sample noise factors
Expand All @@ -150,23 +150,29 @@ Constructor input: List of locations `X`, current system state `lmodel`, observe
struct LGSFObserverState <: SCRIBEObserverState
k::Integer
nₛ::Integer
X::VecOrMat{Float64}
X::Matrix{Float64}
H::Matrix{Float64}
v::Dict{Symbol, AbstractArray{Float64}}
z::Vector{Float64}

function LGSFObserverState(k::Integer, nₛ::Integer, X::VecOrMat{Float64},
function LGSFObserverState(k::Integer, nₛ::Integer, X::Matrix{Float64},
H::Matrix{Float64}, v::Dict{Symbol, AbstractArray{Float64}},
z::Vector{Float64})
new(k,nₛ,X,H,v,z)
end
end

function scribe_observations(X::VecOrMat{Float64}, lmodel::LGSFModel, o_b::LGSFObserverBehavior)
"""Compute the observation dynamics matrix.
This is also referred to as the **H matrix**.
"""
compute_obs_dynamics(smodel::LGSFModel, X::Matrix{Float64}) = mapslices(smodel.ψ,X,dims=2)

function scribe_observations(X::Matrix{Float64}, smodel::LGSFModel, o_b::LGSFObserverBehavior)
let nₛ=size(X,1), v_s=o_b.v_s, R=v_s[]*I(nₛ)
v=Dict(:R=>R, :k=>rand(Gaussian(zeros(nₛ), R)))
H=mapslices(lmodel.ψ,X,dims=2)
z=muladd(H,lmodel.ϕ,v[:k])
LGSFObserverState(lmodel.k, nₛ, X, H, v, z)
H=compute_obs_dynamics(smodel, X)
z=muladd(H,smodel.ϕ,v[:k])
LGSFObserverState(smodel.k, nₛ, X, H, v, z)
end
end
47 changes: 43 additions & 4 deletions test/test_kalman.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
using Test
using SCRIBE
using LinearAlgebra: I

function agent_setup()
function test_agent_setup()
ϕ₀=[-1,1,-1,1,-1.]
gt_params=LGSFModelParameters=hcat(range(-1,1,5), zeros(5)),σ=[1.],τ=[1.],ϕ₀=ϕ₀,A=nothing,Q=nothing)
gt_model=[initialize_SCRIBEModel_from_parameters(gt_params)]
Expand All @@ -18,9 +19,47 @@ function agent_setup()
@test typeof(observations[1]) <: SCRIBEObserverState
@test isapprox(observations[1].X,sample_locations[1])
@test observations[1].k==1
println("Observatons: ", observations[1])
return gt_params, gt_model, sample_locations, observer, observations
end

@testset "Individual Agent Setup" begin
agent_setup()
function test_estimators(; testing=true)
ϕ₀=[-0.5,0.5,0.75,0.5,-0.5]
gt_params=LGSFModelParameters=hcat(range(-1,1,5), zeros(5)),σ=[1.],τ=[1.],ϕ₀=ϕ₀,A=nothing,Q=nothing)
gt_model=[initialize_SCRIBEModel_from_parameters(gt_params)]

ag_params=LGSFModelParameters=hcat(range(-1,1,5), zeros(5)),σ=[1.],τ=[1.],
ϕ₀=zeros(size(ϕ₀,1)),A=Matrix{Float64}(I(size(ϕ₀,1))),Q=Matrix{Float64}(I(size(ϕ₀,1))))
observer=LGSFObserverBehavior(0.3)
sample_locations=[]
init_agent_loc=[0. 0.;]
push!(sample_locations, init_agent_loc)
ag=initialize_agent(ag_params, observer, gt_model[1], init_agent_loc)
if testing; @test typeof(ag) == AgentEnvModel; end

lg_Fs=simple_LGSF_Estimators(ag)
if testing
@test lg_Fs.ϕ(1)==zeros(5)
@test lg_Fs.A(1)==Matrix{Float64}(I(5))
@test lg_Fs.H(1)!==nothing
@test lg_Fs.z(1)!==nothing

@test_throws BoundsError lg_Fs.ϕ(2)
@test_throws BoundsError lg_Fs.A(2)
@test_throws BoundsError lg_Fs.H(2)
@test_throws BoundsError lg_Fs.z(2)
end

for basic_i in 2:5
update_SCRIBEModel()
end

return ag, lg_Fs
end

@testset "Single Agent Setup" begin
test_agent_setup()
end

@testset "Single Agent Estimators" begin
test_estimators()
end

0 comments on commit b8e88dc

Please sign in to comment.