diff --git a/Project.toml b/Project.toml index 69856a04..0a0de216 100644 --- a/Project.toml +++ b/Project.toml @@ -19,6 +19,7 @@ SafeTestsets = "0.0.1" StaticArrays = "1.9" StaticArraysCore = "1.4" Test = "1" +Zygote = "0.6.67" julia = "1.10" [extras] @@ -27,6 +28,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "Pkg", "Test", "SafeTestsets", "StaticArrays"] +test = ["Aqua", "Pkg", "Test", "SafeTestsets", "StaticArrays", "Zygote"] diff --git a/docs/make.jl b/docs/make.jl index 23c85085..c1e6871a 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -12,7 +12,8 @@ makedocs(sitename = "SymbolicIndexingInterface.jl", format = Documenter.HTML(analytics = "UA-90474609-3", assets = ["assets/favicon.ico"], canonical = "https://docs.sciml.ai/SymbolicIndexingInterface/stable/"), - pages = pages) + pages = pages, + checkdocs = :exports) deploydocs(repo = "github.com/SciML/SymbolicIndexingInterface.jl.git"; push_preview = true) diff --git a/docs/src/api.md b/docs/src/api.md index dc1c7546..8286edf0 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -30,10 +30,39 @@ allvariables ```@docs observed +parameter_observed +ParameterObservedFunction +``` + +#### Parameter timeseries + +If the index provider contains parameters that change during the course of the simulation +at discrete time points, it must implement the following methods to ensure correct +functioning of [`getu`](@ref) and [`getp`](@ref) for value providers that save the parameter +timeseries. Note that there can be multiple parameter timeseries, in case different parameters +may change at different times. + +```@docs +is_timeseries_parameter +timeseries_parameter_index +ParameterTimeseriesIndex ``` ## Value provider interface +### State indexing + +```@docs +Timeseries +NotTimeseries +is_timeseries +state_values +set_state! +current_time +getu +setu +``` + ### Parameter indexing ```@docs @@ -51,24 +80,14 @@ If a solution object saves a timeseries of parameter values that are updated dur simulation (such as by callbacks), it must implement the following methods to ensure correct functioning of [`getu`](@ref) and [`getp`](@ref). -```@docs -parameter_timeseries -parameter_values_at_time -parameter_values_at_state_time -``` - - -### State indexing +Parameter timeseries support requires that the value provider store the different +timeseries in a [`ParameterTimeseriesCollection`](@ref). ```@docs -Timeseries -NotTimeseries -is_timeseries -state_values -set_state! -current_time -getu -setu +is_parameter_timeseries +get_parameter_timeseries_collection +ParameterTimeseriesCollection +with_updated_parameter_timeseries_values ``` ### Batched Queries and Updates diff --git a/docs/src/complete_sii.md b/docs/src/complete_sii.md index 047d2162..0fd65091 100644 --- a/docs/src/complete_sii.md +++ b/docs/src/complete_sii.md @@ -1,5 +1,6 @@ # Implementing the Complete Symbolic Indexing Interface +## Index Provider Interface This tutorial will show how to define the entire Symbolic Indexing Interface on an `ExampleSystem`: @@ -17,9 +18,9 @@ end Not all the methods in the interface are required. Some only need to be implemented if a type supports specific functionality. Consider the following struct, which needs to implement the interface: -## Mandatory methods +### Mandatory methods -### Simple Indexing Functions +#### Simple Indexing Functions These are the simple functions which describe how to turn symbols into indices. @@ -84,7 +85,7 @@ function SymbolicIndexingInterface.default_values(sys::ExampleSystem) end ``` -### Observed Equation Handling +#### Observed Equation Handling These are for handling symbolic expressions and generating equations which are not directly in the solution vector. @@ -131,7 +132,12 @@ end In case a type does not support such observed quantities, `is_observed` must be defined to always return `false`, and `observed` does not need to be implemented. -### Note about constant structure +The same process can be followed for [`parameter_observed`](@ref), with the exception +that the returned function must not have `u` in its signature, and must be wrapped in a +[`ParameterObservedFunction`](@ref). In-place versions can also be implemented for +`parameter_observed`. + +#### Note about constant structure Note that the method definitions are all assuming `constant_structure(p) == true`. @@ -147,14 +153,16 @@ In case `constant_structure(p) == false`, the following methods would change: `observed(sys::ExampleSystem, sym, i)` where `i` is either the time index at which the index of `sym` is required or a `Vector` of state symbols at the current time index. -## Optional methods +### Optional methods Note that `observed` is optional if `is_observed` is always `false`, or the type is only responsible for identifying observed values and `observed` will always be called on a type that wraps this type. An example is `ModelingToolkit.AbstractSystem`, which can identify whether a value is observed, but cannot implement `observed` itself. -Other optional methods relate to indexing functions. If a type contains the values of +## Value Provider Interface + +Other interface methods relate to indexing functions. If a type contains the values of parameter variables, it must implement [`parameter_values`](@ref). This allows the default definitions of [`getp`](@ref) and [`setp`](@ref) to work. While `setp` is not typically useful for solution objects, it may be useful for integrators. Typically, @@ -276,7 +284,7 @@ similar functionality, but is called for every parameter that is updated, instea once. Thus, `finalize_parameters_hook!` is better for expensive computations that can be performed for a bulk parameter update. -# The `ParameterIndexingProxy` +## The `ParameterIndexingProxy` [`ParameterIndexingProxy`](@ref) is a wrapper around another type which implements the interface and allows using [`getp`](@ref) and [`setp`](@ref) to get and set parameter @@ -305,6 +313,164 @@ integrator.ps[:b] = 3.0 setp(integrator, :b)(integrator, 3.0) # functionally the same as above ``` +## Parameter Timeseries + +If a solution object includes modified parameter values (such as through callbacks) during the +simulation, it must implement several additional functions for correct functioning of +[`getu`](@ref) and [`getp`](@ref). [`ParameterTimeseriesCollection`](@ref) helps in +implementing parameter timeseries objects. The following mockup gives an example of +correct implementation of these functions and the indexing syntax they enable. + +```@example param_timeseries +using SymbolicIndexingInterface + +# First, we must implement a parameter object that knows where the parameters in +# each parameter timeseries are stored +struct MyParameterObject + p::Vector{Float64} + disc_idxs::Vector{Vector{Int}} +end + +# To be able to access parameter values +SymbolicIndexingInterface.parameter_values(mpo::MyParameterObject) = mpo.p +# Update the parameter object with new values +function SymbolicIndexingInterface.with_updated_parameter_timeseries_values(mpo::MyParameterObject, args::Pair...) + for (ts_idx, val) in args + mpo.p[mpo.disc_idxs[ts_idx]] = val + end + return mpo +end + +struct ExampleSolution2 + sys::SymbolCache + u::Vector{Vector{Float64}} + t::Vector{Float64} + p::MyParameterObject # the parameter object. Only some parameters are timeseries params + p_ts::ParameterTimeseriesCollection +end + +# Add the `:ps` property to automatically wrap in `ParameterIndexingProxy` +function Base.getproperty(fs::ExampleSolution2, s::Symbol) + s === :ps ? ParameterIndexingProxy(fs) : getfield(fs, s) +end +# Use the contained `SymbolCache` for indexing +SymbolicIndexingInterface.symbolic_container(fs::ExampleSolution2) = fs.sys +# State indexing methods +SymbolicIndexingInterface.state_values(fs::ExampleSolution2) = fs.u +SymbolicIndexingInterface.current_time(fs::ExampleSolution2) = fs.t +# By default, `parameter_values` refers to the last value +SymbolicIndexingInterface.parameter_values(fs::ExampleSolution2) = fs.p +SymbolicIndexingInterface.get_parameter_timeseries_collection(fs::ExampleSolution2) = fs.p_ts +# Mark the object as a timeseries object +SymbolicIndexingInterface.is_timeseries(::Type{ExampleSolution2}) = Timeseries() +# Mark the object as a parameter timeseries object +SymbolicIndexingInterface.is_parameter_timeseries(::Type{ExampleSolution2}) = Timeseries() +``` + +We will also need a timeseries object which will store individual parameter timeseries. +`DiffEqArray` in `RecursiveArrayTools.jl` satisfies this use case, but we will implement +one manually here. + +```@example param_timeseries +struct MyDiffEqArray + t::Vector{Float64} + u::Vector{Vector{Float64}} +end + +# Must be a timeseries object, and implement `current_time` and `state_values` +SymbolicIndexingInterface.is_timeseries(::Type{MyDiffEqArray}) = Timeseries() +SymbolicIndexingInterface.current_time(a::MyDiffEqArray) = a.t +SymbolicIndexingInterface.state_values(a::MyDiffEqArray) = a.u +``` + +Now we can create an example object and observe the new functionality. Note that +`sol.ps[sym, args...]` is identical to `getp(sol, sym)(sol, args...)`. In a real +application, the solution object will be populated during the solve process. We manually +construct the object here for demonstration. + +```@example param_timeseries +sys = SymbolCache( + [:x, :y, :z], [:a, :b, :c, :d], :t; + # specify that :b, :c and :d are timeseries parameters + # :b and :c belong to the same timeseries + # :d is in a different timeseries + timeseries_parameters = Dict( + :b => ParameterTimeseriesIndex(1, 1), + :c => ParameterTimeseriesIndex(1, 2), + :d => ParameterTimeseriesIndex(2, 1), + )) +b_c_timeseries = MyDiffEqArray( + collect(0.0:0.1:1.0), + [[0.25i, 0.35i] for i in 1:11] +) +d_timeseries = MyDiffEqArray( + collect(0.0:0.2:1.0), + [[0.17i] for i in 1:6] +) +p = MyParameterObject( + # parameter values at the final time + [4.2, b_c_timeseries.u[end]..., d_timeseries.u[end]...], + [[2, 3], [4]] +) +sol = ExampleSolution2( + sys, + [i * ones(3) for i in 1:5], # u + collect(0.0:0.25:1.0), # t + p, + ParameterTimeseriesCollection([b_c_timeseries, d_timeseries], deepcopy(p)) +) +sol.ps[:a] # returns the value of non-timeseries parameter +``` + +```@example param_timeseries +sol.ps[:b] # returns the timeseries of :b +``` + +```@example param_timeseries +sol.ps[:b, 3] # index at a specific index in the parameter timeseries +``` + +```@example param_timeseries +sol.ps[:b, [3, 6, 8]] # index using arrays +``` + +```@example param_timeseries +idxs = @show rand(Bool, 11) # boolean mask for indexing +sol.ps[:b, idxs] +``` + +```@example param_timeseries +sol.ps[[:a, :b]] # returns the values at the last timestep, since :a is not timeseries +``` + +```@example param_timeseries +# throws an error since :b and :d belong to different timeseries +try + sol.ps[[:b, :d]] +catch e + @show e +end +``` + +```@example param_timeseries +sol.ps[:(b + c)] # observed quantities work too +``` + +```@example param_timeseries +getu(sol, :b)(sol) # returns the values :b takes at the times in the state timeseries +``` + +```@example param_timeseries +getu(sol, [:b, :d])(sol) # works +``` + +## Custom containers + +A custom container object (such as `ModelingToolkit.MTKParameters`) should implement +[`remake_buffer`](@ref) to allow creating a new buffer with updated values, possibly +with different types. This is already implemented for `AbstractArray`s (including static +arrays). + # Implementing the `SymbolicTypeTrait` for a type The `SymbolicTypeTrait` is used to identify values that can act as symbolic variables. It @@ -383,87 +549,3 @@ end Note the evaluation of the operation if all of the arguments are not symbolic. This is required since `symbolic_evaluate` must return an evaluated value if all symbolic variables are substituted. - -## Parameter Timeseries - -If a solution object saves modified parameter values (such as through callbacks) during the -simulation, it must implement [`parameter_timeseries`](@ref), -[`parameter_values_at_time`](@ref) and [`parameter_values_at_state_time`](@ref) for correct -functioning of [`getu`](@ref) and [`getp`](@ref). The following mockup gives an example -of correct implementation of these functions and the indexing syntax they enable. - -```@example param_timeseries -using SymbolicIndexingInterface - -struct ExampleSolution2 - sys::SymbolCache - u::Vector{Vector{Float64}} - t::Vector{Float64} - p::Vector{Vector{Float64}} - pt::Vector{Float64} -end - -# Add the `:ps` property to automatically wrap in `ParameterIndexingProxy` -function Base.getproperty(fs::ExampleSolution2, s::Symbol) - s === :ps ? ParameterIndexingProxy(fs) : getfield(fs, s) -end -# Use the contained `SymbolCache` for indexing -SymbolicIndexingInterface.symbolic_container(fs::ExampleSolution2) = fs.sys -# By default, `parameter_values` refers to the last value -SymbolicIndexingInterface.parameter_values(fs::ExampleSolution2) = fs.p[end] -SymbolicIndexingInterface.parameter_values(fs::ExampleSolution2, i) = fs.p[end][i] -# Index into the parameter timeseries vector -function SymbolicIndexingInterface.parameter_values_at_time(fs::ExampleSolution2, t) - fs.p[t] -end -# Find the first index in the parameter timeseries vector with a time smaller -# than the time from the state timeseries, and use that to index the parameter -# timeseries -function SymbolicIndexingInterface.parameter_values_at_state_time(fs::ExampleSolution2, t) - ptind = searchsortedfirst(fs.pt, fs.t[t]; lt = <=) - fs.p[ptind - 1] -end -SymbolicIndexingInterface.parameter_timeseries(fs::ExampleSolution2) = fs.pt -# Mark the object as a `Timeseries` object -SymbolicIndexingInterface.is_timeseries(::Type{ExampleSolution2}) = Timeseries() - -``` - -Now we can create an example object and observe the new functionality. Note that -`sol.ps[sym, args...]` is identical to `getp(sol, sym)(sol, args...)`. - -```@example param_timeseries -sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) -sol = ExampleSolution2( - sys, - [i * ones(3) for i in 1:5], - [0.2i for i in 1:5], - [2i * ones(3) for i in 1:10], - [0.1i for i in 1:10] -) -sol.ps[:a] # returns the value at the last timestep -``` - -```@example param_timeseries -sol.ps[:a, :] # use Colon to fetch the entire parameter timeseries -``` - -```@example param_timeseries -sol.ps[:a, 3] # index at a specific index in the parameter timeseries -``` - -```@example param_timeseries -sol.ps[:a, [3, 6, 8]] # index using arrays -``` - -```@example param_timeseries -idxs = @show rand(Bool, 10) # boolean mask for indexing -sol.ps[:a, idxs] -``` - -## Custom containers - -A custom container object (such as `ModelingToolkit.MTKParameters`) should implement -[`remake_buffer`](@ref) to allow creating a new buffer with updated values, possibly -with different types. This is already implemented for `AbstractArray`s (including static -arrays). diff --git a/docs/src/terminology.md b/docs/src/terminology.md index 46c51735..cb9d0372 100644 --- a/docs/src/terminology.md +++ b/docs/src/terminology.md @@ -56,6 +56,12 @@ In code samples, a value provider is typically denoted with the name `valp`. providers. This allows for several syntactic improvements. The [`symbolic_container`](@ref) function is useful in defining such objects. -!!! note "Timeseries objects" - The documentation uses "Timeseries objects" to refer to value providers which implement - the [`Timeseries`](@ref) variant of the [`is_timeseries`](@ref) trait. +### Timeseries objects + +Timeseries objects are value providers which implement the [`Timeseries`](@ref) variant of +the [`is_timeseries`](@ref) trait. + +### Parameter timeseries objects + +Parameter timeseries objects are timeseries objects which implement the +[`Timeseries`](@ref) variant of the [`is_parameter_timeseries`](@ref) trait. diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 3b90d0fe..93c53987 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -8,25 +8,29 @@ using Accessors: @reset RuntimeGeneratedFunctions.init(@__MODULE__) export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getname, - Timeseries, NotTimeseries, is_timeseries + Timeseries, NotTimeseries, is_timeseries, is_parameter_timeseries include("trait.jl") export is_variable, variable_index, variable_symbols, is_parameter, parameter_index, + is_timeseries_parameter, timeseries_parameter_index, ParameterTimeseriesIndex, parameter_symbols, is_independent_variable, independent_variable_symbols, - is_observed, - observed, is_time_dependent, constant_structure, symbolic_container, - all_variable_symbols, - all_symbols, solvedvariables, allvariables, default_values, symbolic_evaluate + is_observed, observed, parameter_observed, ParameterObservedFunction, + is_time_dependent, constant_structure, symbolic_container, + all_variable_symbols, all_symbols, solvedvariables, allvariables, default_values, + symbolic_evaluate include("index_provider_interface.jl") export SymbolCache include("symbol_cache.jl") export parameter_values, set_parameter!, finalize_parameters_hook!, - parameter_values_at_time, parameter_values_at_state_time, parameter_timeseries, + get_parameter_timeseries_collection, with_updated_parameter_timeseries_values, state_values, set_state!, current_time include("value_provider_interface.jl") +export ParameterTimeseriesCollection +include("parameter_timeseries_collection.jl") + export getp, setp include("parameter_indexing.jl") diff --git a/src/index_provider_interface.jl b/src/index_provider_interface.jl index 512bded7..a114450a 100644 --- a/src/index_provider_interface.jl +++ b/src/index_provider_interface.jl @@ -53,6 +53,92 @@ Return the index of the given parameter `sym` in `indp`, or `nothing` otherwise. """ parameter_index(indp, sym) = parameter_index(symbolic_container(indp), sym) +""" + is_timeseries_parameter(indp, sym) + +Check whether the given `sym` is a timeseries parameter in `indp`. +""" +function is_timeseries_parameter(indp, sym) + if hasmethod(symbolic_container, Tuple{typeof(indp)}) + is_timeseries_parameter(symbolic_container(indp), sym) + else + return false + end +end + +""" + struct ParameterTimeseriesIndex + function ParameterTimeseriesIndex(timeseries_idx, parameter_idx) + +A struct storing the index of the timeseries of a timeseries parameter in a parameter +timeseries object. `timeseries_idx` refers to an index that identifies the timeseries +that the parameter belongs to. `parameter_idx` refers to the index of the parameter's +timeseries in that timeseries object. Note that `parameter_idx` may be different from +the object returned by [`parameter_index`](@ref) for a given parameter. The two fields in +this struct are `timeseries_idx` and `parameter_idx`. +""" +struct ParameterTimeseriesIndex{T, I} + timeseries_idx::T + parameter_idx::I +end + +""" + timeseries_parameter_index(indp, sym) + +Return the index of timeseries parameter `sym` in `indp`. Must return this index as a +[`ParameterTimeseriesIndex`](@ref) object. Return `nothing` if `sym` is not a timeseries +parameter in `indp`. Defaults to returning `nothing`. Respects the +[`symbolic_container`](@ref) fallback for `indp` if present. +""" +function timeseries_parameter_index(indp, sym) + if hasmethod(symbolic_container, Tuple{typeof(indp)}) + timeseries_parameter_index(symbolic_container(indp), sym) + else + return nothing + end +end + +""" + struct ParameterObservedFunction + function ParameterObservedFunction(timeseries_idx, observed_fn::Function) + function ParameterObservedFunction(observed_fn::Function) + +A struct which stores the parameter observed function and optional timeseries index for +a particular symbol. The timeseries index is optional and may be omitted. Specifying the +timeseries index allows [`getp`](@ref) to return the appropriate timeseries for a +timeseries parameter. + +For time-dependent index providers (where `is_time_dependent(indp)`) `observed_fn` must +have the signature `(p, t) -> [values...]`. For non-time-dependent index providers +(where `!is_time_dependent(indp)`) `observed_fn` must have the signature +`(p) -> [values...]`. To support in-place `getp` methods, `observed_fn` must also have an +additional method which takes `buffer::AbstractArray` as its first argument. The required +values must be written to the buffer in the appropriate order. +""" +struct ParameterObservedFunction{I, F <: Function} + timeseries_idx::I + observed_fn::F +end + +""" + parameter_observed(indp, sym) + +Return the observed function of `sym` in `indp` as a [`ParameterObservedFunction`](@ref). +If `sym` only involves variables from a single parameter timeseries (optionally along +with non-timeseries parameters) the timeseries index of the parameter timeseries should +be provided in the [`ParameterObservedFunction`](@ref). In all other cases, just the +observed function should be returned as part of the `ParameterObservedFunction` object. + +By default, this function returns `nothing`. +""" +function parameter_observed(indp, sym) + if hasmethod(symbolic_container, Tuple{typeof(indp)}) + return parameter_observed(symbolic_container(indp), sym) + else + return nothing + end +end + """ parameter_symbols(indp) @@ -88,7 +174,7 @@ is_observed(indp, sym) = is_observed(symbolic_container(indp), sym) Return the observed function of the given `sym` in `indp`. The returned function should have the signature `(u, p) -> [values...]` where `u` and `p` is the current state and -parameter vector, respectively. If `istimedependent(indp) == true`, the function should +parameter object, respectively. If `istimedependent(indp) == true`, the function should accept the current time `t` as its third parameter. If `constant_structure(indp) == false`, `observed` accepts a third parameter, which can either be a vector of symbols indicating the order of states or a time index, which identifies the order of states. This function diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index 79336a31..de2d6caf 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -1,21 +1,3 @@ -parameter_values(arr::AbstractArray) = arr -parameter_values(arr::Tuple) = arr -parameter_values(arr::AbstractArray, i) = arr[i] -parameter_values(arr::Tuple, i) = arr[i] -parameter_values(prob, i) = parameter_values(parameter_values(prob), i) - -parameter_values_at_time(p, i) = parameter_values(p) - -parameter_values_at_state_time(p, i) = parameter_values(p) - -parameter_timeseries(_) = [0] - -# Tuple only included for the error message -function set_parameter!(sys::Union{AbstractArray, Tuple}, val, idx) - sys[idx] = val -end -set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx) - """ getp(indp, sym) @@ -33,12 +15,25 @@ Requires that the value provider implement [`parameter_values`](@ref). This func may not always need to be implemented, and has a default implementation for collections that implement `getindex`. -If the returned function is used on a timeseries object which saves parameter timeseries, it -can be used to index said timeseries. The timeseries object must implement -[`parameter_timeseries`](@ref), [`parameter_values_at_time`](@ref) and -[`parameter_values_at_state_time`](@ref). The function returned from `getp` will can be passed -`Colon()` (`:`) as the last argument to return the entire parameter timeseries for `p`, or -any index into the parameter timeseries for a subset of values. +If the returned function is used on a timeseries object which saves parameter timeseries, +it can be used to index said timeseries. The timeseries object must implement +[`is_parameter_timeseries`](@ref) and [`get_parameter_timeseries_collection`](@ref). +Additionally, the parameter object must implement +[`with_updated_parameter_timeseries_values`](@ref). + +If `sym` is a timeseries parameter, the function will return the timeseries of the +parameter if the value provider is a parameter timeseries object. An additional argument +can be provided to the function indicating the specific indexes in the timeseries at +which to access the values. If `sym` is an array of parameters, the following cases +apply: + +- All parameters are non-timeseries parameters: The function returns the value of each + parameter. +- All parameters are timeseries parameters: All the parameters must belong to the same + timeseries (otherwise `getp` will error). The function returns the timeseries of all + parameter values, and can be accessed at specific indices in the timeseries. +- A mix of timeseries and non-timeseries parameters: The function can _only_ be used on + non-timeseries objects and will return the value of each parameter at in the object. """ function getp(sys, p) symtype = symbolic_type(p) @@ -46,105 +41,515 @@ function getp(sys, p) _getp(sys, symtype, elsymtype, p) end -struct GetParameterIndex{I} <: AbstractGetIndexer +struct GetParameterIndex{I} <: AbstractParameterGetIndexer idx::I end +is_indexer_timeseries(::Type{GetParameterIndex{I}}) where {I} = IndexerNotTimeseries() +function is_indexer_timeseries(::Type{GetParameterIndex{I}}) where {I <: + ParameterTimeseriesIndex} + IndexerTimeseries() +end +function indexer_timeseries_index(gpi::GetParameterIndex{<:ParameterTimeseriesIndex}) + gpi.idx.timeseries_idx +end function (gpi::GetParameterIndex)(::IsTimeseriesTrait, prob) parameter_values(prob, gpi.idx) end -function (gpi::GetParameterIndex)(::Timeseries, prob, i::Union{Int, CartesianIndex}) - parameter_values( - parameter_values_at_time( - prob, only(to_indices(parameter_timeseries(prob), (i,)))), - gpi.idx) +function (gpi::GetParameterIndex)(::Timeseries, prob, args) + throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, gpi, args)) +end +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::Timeseries, prob) + get_parameter_timeseries_collection(prob)[gpi.idx] +end +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})( + buffer::AbstractArray, ts::Timeseries, prob) + for (buf_idx, ts_idx) in zip(eachindex(buffer), + eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi)))) + buffer[buf_idx] = gpi(ts, prob, ts_idx) + end + return buffer +end +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})( + ::Timeseries, prob, i::Union{Int, CartesianIndex}) + parameter_values(prob, gpi.idx, i) +end +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, ::Colon) + gpi(ts, prob) +end +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})( + buffer::AbstractArray, ts::Timeseries, prob, ::Colon) + gpi(buffer, ts, prob) +end +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})( + ts::Timeseries, prob, i::AbstractArray{Bool}) + map(only(to_indices( + parameter_timeseries(prob, indexer_timeseries_index(gpi)), (i,)))) do idx + gpi(ts, prob, idx) + end +end +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})( + buffer::AbstractArray, ts::Timeseries, prob, i::AbstractArray{Bool}) + for (buf_idx, ts_idx) in zip(eachindex(buffer), + only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(gpi)), (i,)))) + buffer[buf_idx] = gpi(ts, prob, ts_idx) + end + return buffer +end +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, i) + gpi.((ts,), (prob,), i) +end +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})( + buffer::AbstractArray, ts::Timeseries, prob, i) + for (buf_idx, subidx) in zip(eachindex(buffer), i) + buffer[buf_idx] = gpi(ts, prob, subidx) + end + return buffer end -function (gpi::GetParameterIndex)(::Timeseries, prob, i::Union{AbstractArray{Bool}, Colon}) - parameter_values.( - parameter_values_at_time.((prob,), - (j for j in only(to_indices(parameter_timeseries(prob), (i,))))), - (gpi.idx,)) +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::NotTimeseries, prob) + throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, gpi)) end -function (gpi::GetParameterIndex)(::Timeseries, prob, i) - parameter_values.(parameter_values_at_time.((prob,), i), (gpi.idx,)) +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})( + ::AbstractArray, ::NotTimeseries, prob) + throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, gpi)) end function _getp(sys, ::NotSymbolic, ::NotSymbolic, p) return GetParameterIndex(p) end -function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) - idx = parameter_index(sys, p) - return invoke(_getp, Tuple{Any, NotSymbolic, NotSymbolic, Any}, - sys, NotSymbolic(), NotSymbolic(), idx) +struct GetParameterTimeseriesIndex{ + I <: GetParameterIndex, J <: GetParameterIndex{<:ParameterTimeseriesIndex}} <: + AbstractParameterGetIndexer + param_idx::I + param_timeseries_idx::J end -struct MultipleParameterGetters{G} <: AbstractGetIndexer - getters::G +is_indexer_timeseries(::Type{G}) where {G <: GetParameterTimeseriesIndex} = IndexerBoth() +function indexer_timeseries_index(gpti::GetParameterTimeseriesIndex) + indexer_timeseries_index(gpti.param_timeseries_idx) +end +as_not_timeseries_indexer(::IndexerBoth, gpti::GetParameterTimeseriesIndex) = gpti.param_idx +function as_timeseries_indexer(::IndexerBoth, gpti::GetParameterTimeseriesIndex) + gpti.param_timeseries_idx +end + +function (gpti::GetParameterTimeseriesIndex)(ts::Timeseries, prob, args...) + gpti.param_timeseries_idx(ts, prob, args...) +end +function (gpti::GetParameterTimeseriesIndex)( + buffer::AbstractArray, ts::Timeseries, prob, args...) + gpti.param_timeseries_idx(buffer, ts, prob, args...) +end +function (gpti::GetParameterTimeseriesIndex)(ts::NotTimeseries, prob) + gpti.param_idx(ts, prob) +end + +struct GetParameterObserved{I, M, F <: Function} <: AbstractParameterGetIndexer + timeseries_idx::I + obsfn::F +end + +function GetParameterObserved{Multiple}(timeseries_idx::I, obsfn::F) where {Multiple, I, F} + if !isa(Multiple, Bool) + throw(TypeError(:GetParameterObserved, "{Multiple}", Bool, Multiple)) + end + return GetParameterObserved{I, Multiple, F}(timeseries_idx, obsfn) end -function (mpg::MultipleParameterGetters)(::IsTimeseriesTrait, prob) - map(g -> g(prob), mpg.getters) +const MultipleGetParameterObserved = GetParameterObserved{I, true} where {I} +const SingleGetParameterObserved = GetParameterObserved{I, false} where {I} + +function is_indexer_timeseries(::Type{G}) where {G <: GetParameterObserved{Nothing}} + IndexerNotTimeseries() +end +is_indexer_timeseries(::Type{G}) where {G <: GetParameterObserved} = IndexerBoth() +indexer_timeseries_index(gpo::GetParameterObserved) = gpo.timeseries_idx +function as_not_timeseries_indexer( + ::IndexerBoth, gpo::GetParameterObserved{I, M}) where {I, M} + return GetParameterObserved{M}(nothing, gpo.obsfn) +end +as_timeseries_indexer(::IndexerBoth, gpo::GetParameterObserved) = gpo + +function (gpo::GetParameterObserved{Nothing})(::Timeseries, prob) + gpo.obsfn(parameter_values(prob), current_time(prob)[end]) +end +for multiple in [true, false] + @eval function (gpo::GetParameterObserved{Nothing, $multiple})( + buffer::AbstractArray, ::Timeseries, prob) + gpo.obsfn(buffer, parameter_values(prob), current_time(prob)[end]) + return buffer + end +end +for argType in [Union{Int, CartesianIndex}, Colon, AbstractArray{Bool}, Any] + @eval function (gpo::GetParameterObserved{Nothing})(::Timeseries, prob, args::$argType) + throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, gpo, args)) + end + for multiple in [true, false] + @eval function (gpo::GetParameterObserved{Nothing, $multiple})( + ::AbstractArray, ::Timeseries, prob, args::$argType) + throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, gpo, args)) + end + end end -function (mpg::MultipleParameterGetters)(::Timeseries, prob, i::Union{Int, CartesianIndex}) - map(g -> g(prob, i), mpg.getters) +function (gpo::GetParameterObserved)(::NotTimeseries, prob) + gpo.obsfn(parameter_values(prob), current_time(prob)) end -function (mpg::MultipleParameterGetters)(::Timeseries, prob, i) - [map(g -> g(prob, j), mpg.getters) - for j in only(to_indices(parameter_timeseries(prob), (i,)))] +function (gpo::GetParameterObserved)(buffer::AbstractArray, ::NotTimeseries, prob) + gpo.obsfn(buffer, parameter_values(prob), current_time(prob)) + return buffer end -function (mpg::MultipleParameterGetters)(buffer::AbstractArray, ::Timeseries, prob) - for (g, bufi) in zip(mpg.getters, eachindex(buffer)) - buffer[bufi] = g(prob) +function (gpo::GetParameterObserved)(::Timeseries, prob) + map(parameter_timeseries(prob, gpo.timeseries_idx)) do t + gpo.obsfn(parameter_values_at_time(prob, t), t) end - buffer end -function (mpg::MultipleParameterGetters)( +function (gpo::MultipleGetParameterObserved)(buffer::AbstractArray, ::Timeseries, prob) + times = parameter_timeseries(prob, gpo.timeseries_idx) + for (buf_idx, time) in zip(eachindex(buffer), times) + gpo.obsfn(buffer[buf_idx], parameter_values_at_time(prob, time), time) + end + return buffer +end +function (gpo::SingleGetParameterObserved)(buffer::AbstractArray, ::Timeseries, prob) + times = parameter_timeseries(prob, gpo.timeseries_idx) + for (buf_idx, time) in zip(eachindex(buffer), times) + buffer[buf_idx] = gpo.obsfn(parameter_values_at_time(prob, time), time) + end + return buffer +end +function (gpo::GetParameterObserved)(::Timeseries, prob, i::Union{Int, CartesianIndex}) + time = parameter_timeseries(prob, gpo.timeseries_idx)[i] + gpo.obsfn(parameter_values_at_time(prob, time), time) +end +function (gpo::MultipleGetParameterObserved)( buffer::AbstractArray, ::Timeseries, prob, i::Union{Int, CartesianIndex}) - for (g, bufi) in zip(mpg.getters, eachindex(buffer)) - buffer[bufi] = g(prob, i) + time = parameter_timeseries(prob, gpo.timeseries_idx)[i] + gpo.obsfn(buffer, parameter_values_at_time(prob, time), time) +end +function (gpo::GetParameterObserved)(ts::Timeseries, prob, ::Colon) + gpo(ts, prob) +end +for gpoType in [MultipleGetParameterObserved, SingleGetParameterObserved] + @eval function (gpo::$gpoType)(buffer::AbstractArray, ts::Timeseries, prob, ::Colon) + gpo(buffer, ts, prob) + end +end +function (gpo::GetParameterObserved)(ts::Timeseries, prob, i::AbstractArray{Bool}) + map(only(to_indices(parameter_timeseries(prob, gpo.timeseries_idx), (i,)))) do idx + gpo(ts, prob, idx) + end +end +function (gpo::MultipleGetParameterObserved)( + buffer::AbstractArray, ts::Timeseries, prob, i::AbstractArray{Bool}) + for (buf_idx, time_idx) in zip(eachindex(buffer), + only(to_indices(parameter_timeseries(prob, gpo.timeseries_idx), (i,)))) + gpo(buffer[buf_idx], ts, prob, time_idx) + end + return buffer +end +function (gpo::SingleGetParameterObserved)( + buffer::AbstractArray, ts::Timeseries, prob, i::AbstractArray{Bool}) + for (buf_idx, time_idx) in zip(eachindex(buffer), + only(to_indices(parameter_timeseries(prob, gpo.timeseries_idx), (i,)))) + buffer[buf_idx] = gpo(ts, prob, time_idx) + end + return buffer +end +function (gpo::GetParameterObserved)(ts::Timeseries, prob, i) + map(i) do idx + gpo(ts, prob, idx) + end +end +function (gpo::MultipleGetParameterObserved)(buffer::AbstractArray, ts::Timeseries, prob, i) + for (buf_idx, time_idx) in zip(eachindex(buffer), i) + gpo(buffer[buf_idx], ts, prob, time_idx) + end + return buffer +end +function (gpo::SingleGetParameterObserved)(buffer::AbstractArray, ts::Timeseries, prob, i) + for (buf_idx, time_idx) in zip(eachindex(buffer), i) + buffer[buf_idx] = gpo(ts, prob, time_idx) + end + return buffer +end + +struct GetParameterObservedNoTime{F <: Function} <: AbstractParameterGetIndexer + obsfn::F +end + +function is_indexer_timeseries(::Type{G}) where {G <: GetParameterObservedNoTime} + IndexerNotTimeseries() +end + +function (gpo::GetParameterObservedNoTime)(::NotTimeseries, prob) + gpo.obsfn(parameter_values(prob)) +end +function (gpo::GetParameterObservedNoTime)(buffer::AbstractArray, ::NotTimeseries, prob) + gpo.obsfn(buffer, parameter_values(prob)) +end + +function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) + if is_parameter(sys, p) + idx = parameter_index(sys, p) + if is_timeseries_parameter(sys, p) + ts_idx = timeseries_parameter_index(sys, p) + return GetParameterTimeseriesIndex( + GetParameterIndex(idx), GetParameterIndex(ts_idx)) + else + return GetParameterIndex(idx) + end + elseif is_observed(sys, p) + pofn = parameter_observed(sys, p) + if pofn === nothing + throw(ArgumentError("Index provider does not support `parameter_observed`; cannot use generate function for $p")) + end + if !is_time_dependent(sys) + return GetParameterObservedNoTime(pofn.observed_fn) + end + return GetParameterObserved{false}(pofn.timeseries_idx, pofn.observed_fn) + end + error("Invalid symbol $p for `getp`") +end + +struct MixedTimeseriesIndexes + indexes::Any +end + +struct MultipleParametersGetter{T <: IsIndexerTimeseries, G, I} <: + AbstractParameterGetIndexer + getters::G + timeseries_idx::I +end + +function MultipleParametersGetter(getters) + has_timeseries_indexers = any(getters) do g + is_indexer_timeseries(g) == IndexerTimeseries() + end + has_non_timeseries_indexers = any(getters) do g + is_indexer_timeseries(g) == IndexerNotTimeseries() + end + if has_timeseries_indexers && has_non_timeseries_indexers + throw(ArgumentError("Cannot mix timeseries and non-timeseries indexers in `$MultipleParametersGetter`")) end - buffer + indexer_type = if has_timeseries_indexers + getters = as_timeseries_indexer.(getters) + timeseries_idx = indexer_timeseries_index(first(getters)) + IndexerTimeseries + elseif has_non_timeseries_indexers + getters = as_not_timeseries_indexer.(getters) + timeseries_idx = nothing + IndexerNotTimeseries + else + timeseries_idx = indexer_timeseries_index(first(getters)) + IndexerBoth + end + + if indexer_type != IndexerNotTimeseries && + !allequal(indexer_timeseries_index(g) for g in getters) + if indexer_type == IndexerTimeseries + throw(ArgumentError("All parameters must belong to the same timeseries")) + else + indexer_type = IndexerNotTimeseries + timeseries_idx = MixedTimeseriesIndexes(indexer_timeseries_index.(getters)) + getters = as_not_timeseries_indexer.(getters) + end + end + + return MultipleParametersGetter{indexer_type, typeof(getters), typeof(timeseries_idx)}( + getters, timeseries_idx) +end + +const AtLeastTimeseriesMPG = Union{ + MultipleParametersGetter{IndexerTimeseries}, MultipleParametersGetter{IndexerBoth}} +const MixedTimeseriesIndexMPG = MultipleParametersGetter{ + IndexerNotTimeseries, G, MixedTimeseriesIndexes} where {G} + +is_indexer_timeseries(::Type{<:MultipleParametersGetter{T}}) where {T} = T() +function indexer_timeseries_index(mpg::MultipleParametersGetter) + mpg.timeseries_idx end -function (mpg::MultipleParameterGetters)(buffer::AbstractArray, ::Timeseries, prob, i) - for (bufi, tsi) in zip( - eachindex(buffer), only(to_indices(parameter_timeseries(prob), (i,)))) - for (g, bufj) in zip(mpg.getters, eachindex(buffer[bufi])) - buffer[bufi][bufj] = g(prob, tsi) +function as_not_timeseries_indexer(::IndexerBoth, mpg::MultipleParametersGetter) + MultipleParametersGetter(as_not_timeseries_indexer.(mpg.getters)) +end + +function as_timeseries_indexer(::IndexerBoth, mpg::MultipleParametersGetter) + MultipleParametersGetter(as_timeseries_indexer.(mpg.getters)) +end + +for (indexerTimeseriesType, timeseriesType) in [ + (IndexerNotTimeseries, IsTimeseriesTrait), + (IndexerBoth, NotTimeseries) +] + @eval function (mpg::MultipleParametersGetter{$indexerTimeseriesType})( + ::$timeseriesType, prob) + return _call.(mpg.getters, (prob,)) + end + @eval function (mpg::MultipleParametersGetter{$indexerTimeseriesType})( + buffer::AbstractArray, ::$timeseriesType, prob) + for (buf_idx, getter) in zip(eachindex(buffer), mpg.getters) + buffer[buf_idx] = getter(prob) end + return buffer + end +end + +function (mpg::MixedTimeseriesIndexMPG)(::Timeseries, prob, args...) + throw(MixedParameterTimeseriesIndexError(prob, mpg.timeseries_idx.indexes)) +end + +function (mpg::MultipleParametersGetter{IndexerNotTimeseries})(::Timeseries, prob, args) + throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, mpg, args)) +end +function (mpg::MultipleParametersGetter{IndexerNotTimeseries})( + ::AbstractArray, ::Timeseries, prob, args) + throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, mpg, args)) +end +function (mpg::AtLeastTimeseriesMPG)(ts::Timeseries, prob) + map(eachindex(parameter_timeseries(prob, indexer_timeseries_index(mpg)))) do i + mpg(ts, prob, i) + end +end +function (mpg::AtLeastTimeseriesMPG)(::Timeseries, prob, i::Union{Int, CartesianIndex}) + CallWith(prob, i).(mpg.getters) +end +function (mpg::AtLeastTimeseriesMPG)(ts::Timeseries, prob, ::Colon) + mpg(ts, prob) +end +function (mpg::AtLeastTimeseriesMPG)(ts::Timeseries, prob, i::AbstractArray{Bool}) + map(only(to_indices( + parameter_timeseries(prob, indexer_timeseries_index(mpg)), (i,)))) do idx + mpg(ts, prob, idx) + end +end +function (mpg::AtLeastTimeseriesMPG)(ts::Timeseries, prob, i) + mpg.((ts,), (prob,), i) +end +function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ts::Timeseries, prob) + for (buf_idx, ts_idx) in zip(eachindex(buffer), + eachindex(parameter_timeseries(prob, indexer_timeseries_index(mpg)))) + mpg(buffer[buf_idx], ts, prob, ts_idx) end - buffer + return buffer end -function (mpg::MultipleParameterGetters)(buffer::AbstractArray, ::NotTimeseries, prob) - for (g, bufi) in zip(mpg.getters, eachindex(buffer)) - buffer[bufi] = g(prob) +function (mpg::AtLeastTimeseriesMPG)( + buffer::AbstractArray, ::Timeseries, prob, i::Union{Int, CartesianIndex}) + for (buf_idx, getter) in zip(eachindex(buffer), mpg.getters) + buffer[buf_idx] = getter(prob, i) end - buffer + return buffer +end +function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ts::Timeseries, prob, ::Colon) + mpg(buffer, ts, prob) +end +function (mpg::AtLeastTimeseriesMPG)( + buffer::AbstractArray, ts::Timeseries, prob, i::AbstractArray{Bool}) + mpg(buffer, ts, prob, + only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(mpg)), (i,)))) +end +function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ts::Timeseries, prob, i) + for (buf_idx, ts_idx) in zip(eachindex(buffer), i) + mpg(buffer[buf_idx], ts, prob, ts_idx) + end + return buffer +end +function (mpg::MultipleParametersGetter{IndexerTimeseries})(::NotTimeseries, prob) + throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, mpg)) +end +function (mpg::MultipleParametersGetter{IndexerTimeseries})( + ::AbstractArray, ::NotTimeseries, prob) + throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, mpg)) end -function (mpg::MultipleParameterGetters)(buffer::AbstractArray, prob, i...) - mpg(buffer, is_timeseries(prob), prob, i...) +struct AsParameterTupleWrapper{N, G <: AbstractParameterGetIndexer} <: + AbstractParameterGetIndexer + getter::G end -function (mpg::MultipleParameterGetters)(prob, i...) - mpg(is_timeseries(prob), prob, i...) + +AsParameterTupleWrapper{N}(getter::G) where {N, G} = AsParameterTupleWrapper{N, G}(getter) + +function is_indexer_timeseries(::Type{AsParameterTupleWrapper{N, G}}) where {N, G} + is_indexer_timeseries(G) +end +function indexer_timeseries_index(atw::AsParameterTupleWrapper) + indexer_timeseries_index(atw.getter) +end +function as_timeseries_indexer(::IndexerBoth, atw::AsParameterTupleWrapper{N}) where {N} + AsParameterTupleWrapper{N}(as_timeseries_indexer(atw.getter)) +end +function as_not_timeseries_indexer(::IndexerBoth, atw::AsParameterTupleWrapper{N}) where {N} + AsParameterTupleWrapper{N}(as_not_timeseries_indexer(atw.getter)) +end + +wrap_tuple(::AsParameterTupleWrapper{N}, val) where {N} = ntuple(i -> val[i], Val(N)) + +function (atw::AsParameterTupleWrapper)(ts::IsTimeseriesTrait, prob, args...) + atw(ts, is_indexer_timeseries(atw), prob, args...) +end +function (atw::AsParameterTupleWrapper)(ts::Timeseries, ::AtLeastTimeseriesIndexer, prob) + wrap_tuple.((atw,), atw.getter(ts, prob)) +end +function (atw::AsParameterTupleWrapper)( + ts::Timeseries, ::AtLeastTimeseriesIndexer, prob, i::Union{Int, CartesianIndex}) + wrap_tuple(atw, atw.getter(ts, prob, i)) +end +function (atw::AsParameterTupleWrapper)(ts::Timeseries, ::AtLeastTimeseriesIndexer, prob, i) + wrap_tuple.((atw,), atw.getter(ts, prob, i)) +end +# args is just so it throws +function (atw::AsParameterTupleWrapper)( + ts::Timeseries, ::IndexerNotTimeseries, prob, args...) + wrap_tuple(atw, atw.getter(ts, prob, args...)) +end +function (atw::AsParameterTupleWrapper)( + ts::NotTimeseries, ::AtLeastNotTimeseriesIndexer, prob, args...) + wrap_tuple(atw, atw.getter(ts, prob, args...)) +end +function (atw::AsParameterTupleWrapper)( + buffer::AbstractArray, ts::IsTimeseriesTrait, prob, args...) + atw.getter(buffer, ts, prob, args...) end +is_observed_getter(_) = false +is_observed_getter(::GetParameterObserved) = true +is_observed_getter(::GetParameterObservedNoTime) = true +is_observed_getter(mpg::MultipleParametersGetter) = any(is_observed_getter, mpg.getters) + for (t1, t2) in [ (ArraySymbolic, Any), (ScalarSymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray}) ] @eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2) + # We need to do it this way because if an `ODESystem` has `p[1], p[2], p[3]` as + # parameters (all scalarized) then `is_observed(sys, p[2:3]) == true`. Then, + # `getp` errors on older MTK that doesn't support `parameter_observed`. getters = getp.((sys,), p) - return MultipleParameterGetters(getters) + num_observed = count(is_observed_getter, getters) + + if num_observed == 0 + return MultipleParametersGetter(getters) + else + pofn = parameter_observed(sys, p isa Tuple ? collect(p) : p) + if is_time_dependent(sys) + getter = GetParameterObserved{true}(pofn.timeseries_idx, pofn.observed_fn) + else + getter = GetParameterObservedNoTime(pofn.observed_fn) + end + return p isa Tuple ? AsParameterTupleWrapper{length(p)}(getter) : getter + end end end function _getp(sys, ::ArraySymbolic, ::SymbolicTypeTrait, p) if is_parameter(sys, p) idx = parameter_index(sys, p) - return invoke(_getp, Tuple{Any, NotSymbolic, NotSymbolic, Any}, - sys, NotSymbolic(), NotSymbolic(), idx) + if is_timeseries_parameter(sys, p) + ts_idx = timeseries_parameter_index(sys, p) + return GetParameterTimeseriesIndex(idx, ts_idx) + else + return GetParameterIndex(idx) + end end return getp(sys, collect(p)) end @@ -157,7 +562,7 @@ end function (phw::ParameterHookWrapper)(prob, args...) res = phw.setter(prob, args...) finalize_parameters_hook!(prob, phw.original_index) - res + return res end """ @@ -222,6 +627,8 @@ function _setp(sys, ::ArraySymbolic, ::SymbolicTypeTrait, p) if is_parameter(sys, p) idx = parameter_index(sys, p) return setp(sys, idx; run_hook = false) + elseif is_observed(sys, p) && (pobsfn = parameter_observed(sys, p)) !== nothing + return GetParameterObserved{false}(pobsfn.timeseries_idx, pobsfn.observed_fn) end return setp(sys, collect(p); run_hook = false) end diff --git a/src/parameter_timeseries_collection.jl b/src/parameter_timeseries_collection.jl new file mode 100644 index 00000000..0da27ab2 --- /dev/null +++ b/src/parameter_timeseries_collection.jl @@ -0,0 +1,205 @@ +""" + struct ParameterTimeseriesCollection{T} + function ParameterTimeseriesCollection(collection) + +A utility struct that helps in storing multiple parameter timeseries. It expects a +collection of timseries objects ([`is_timeseries`](@ref) returns [`Timeseries`](@ref)) +for each. Each of the timeseries objects should implement [`state_values`](@ref) and +[`current_time`](@ref). Effectively, the "states" of each contained timeseries object are +the parameter values it stores the timeseries of. + +The collection is expected to implement `Base.eachindex`, `Base.iterate` and +`Base.getindex`. The indexes of the collection should agree with the timeseries indexes +returned by calling [`timeseries_parameter_index`](@ref) on the corresponding index +provider. + +This type forwards `eachindex`, `iterate` and `length` to the contained `collection`. It +implements `Base.parent` to allow access to the contained `collection`, and has the +following `getindex` methods: + +- `getindex(ptc::ParameterTimeseriesCollection, idx) = ptc.collection[idx]`. +- `getindex(::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex)` returns the + timeseries of the parameter referred to by `idx`. +- `getindex(::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex, subidx)` + returns the value of the parameter referred to by `idx` at the time index `subidx`. +- Apart from these cases, if multiple indexes are provided the first is treated as a + timeseries index, the second the time index in the timeseries, and the (optional) + third the index of the parameter in an element of the timeseries. + +The three-argument version of [`parameter_values`](@ref) is implemented for this type. +The single-argument version of `parameter_values` returns the cached parameter object. +This type does not implement any traits. +""" +struct ParameterTimeseriesCollection{T, P} + collection::T + paramcache::P + + function ParameterTimeseriesCollection(collection::T, paramcache::P) where {T, P} + if any(x -> is_timeseries(x) == NotTimeseries(), collection) + throw(ArgumentError(""" + All objects in the collection `ParameterTimeseriesCollection` must be \ + timeseries objects. + """)) + end + new{T, P}(collection, paramcache) + end +end + +Base.eachindex(ptc::ParameterTimeseriesCollection) = eachindex(ptc.collection) + +Base.iterate(ptc::ParameterTimeseriesCollection, args...) = iterate(ptc.collection, args...) + +Base.length(ptc::ParameterTimeseriesCollection) = length(ptc.collection) + +Base.parent(ptc::ParameterTimeseriesCollection) = ptc.collection + +Base.getindex(ptc::ParameterTimeseriesCollection, idx) = ptc.collection[idx] +function Base.getindex(ptc::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex) + timeseries = ptc.collection[idx.timeseries_idx] + return getindex.(state_values(timeseries), (idx.parameter_idx,)) +end +function Base.getindex( + ptc::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex, subidx::Union{ + Int, CartesianIndex}) + timeseries = ptc.collection[idx.timeseries_idx] + return state_values(timeseries, subidx)[idx.parameter_idx] +end +function Base.getindex( + ptc::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex, ::Colon) + return ptc[idx] +end +function Base.getindex( + ptc::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex, subidx::AbstractArray{Bool}) + timeseries = ptc.collection[idx.timeseries_idx] + map(only(to_indices(current_time(timeseries), (subidx,)))) do i + state_values(timeseries, i)[idx.parameter_idx] + end +end +function Base.getindex( + ptc::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex, subidx) + timeseries = ptc.collection[idx.timeseries_idx] + getindex.(state_values.((timeseries,), subidx), idx.parameter_idx) +end +function Base.getindex(ptc::ParameterTimeseriesCollection, ts_idx, subidx) + return state_values(ptc.collection[ts_idx], subidx) +end +function Base.getindex(ptc::ParameterTimeseriesCollection, ts_idx, subidx, param_idx) + return ptc[ParameterTimeseriesIndex(ts_idx, param_idx), subidx] +end + +function parameter_values(ptc::ParameterTimeseriesCollection) + return ptc.paramcache +end + +function parameter_values( + ptc::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex, subidx) + return ptc[idx, subidx] +end +function parameter_values(prob, i::ParameterTimeseriesIndex, j) + parameter_values(get_parameter_timeseries_collection(prob), i, j) +end +function parameter_timeseries(ptc::ParameterTimeseriesCollection, idx) + return current_time(ptc[idx]) +end + +function _timeseries_value(ptc::ParameterTimeseriesCollection, ts_idx, t) + ts_obj = ptc[ts_idx] + time_idx = searchsortedlast(current_time(ts_obj), t) + value = state_values(ts_obj, time_idx) + return value +end + +""" + parameter_values_at_time(valp, t) + +Return an indexable collection containing the value of all parameters in `valp` at time +`t`. Note that `t` here is a floating-point time, and not an index into a timeseries. + +This has a default implementation relying on [`get_parameter_timeseries_collection`](@ref) +and [`with_updated_parameter_timeseries_values`](@ref). +""" +function parameter_values_at_time(valp, t) + ptc = get_parameter_timeseries_collection(valp) + with_updated_parameter_timeseries_values(ptc.paramcache, + (ts_idx => _timeseries_value(ptc, ts_idx, t) for ts_idx in eachindex(ptc))...) +end + +""" + parameter_values_at_state_time(valp, i) + parameter_values_at_state_time(valp) + +Return an indexable collection containing the value of all parameters in `valp` at time +index `i` in the state timeseries. + +By default, this function relies on [`parameter_values_at_time`](@ref) and +[`current_time`](@ref) for a default implementation. + +The single-argument version of this function is a shorthand to return parameter values +at each point in the state timeseries. This also has a default implementation relying on +[`parameter_values_at_time`](@ref) and [`current_time`](@ref). +""" +function parameter_values_at_state_time end + +function parameter_values_at_state_time(p, i) + state_time = current_time(p, i) + return parameter_values_at_time(p, state_time) +end +function parameter_values_at_state_time(p) + return (parameter_values_at_time(p, t) for t in current_time(p)) +end + +""" + parameter_timeseries(valp, i) + +Return a vector of the time steps at which the parameter values in the parameter +timeseries at index `i` are saved. This is only required for objects where +`is_parameter_timeseries(valp) === Timeseries()`. It will not be called otherwise. It is +assumed that the timeseries is sorted in increasing order. + +See also: [`is_parameter_timeseries`](@ref). +""" +function parameter_timeseries end + +function parameter_timeseries(valp, i) + return parameter_timeseries(get_parameter_timeseries_collection(valp), i) +end + +""" + parameter_timeseries_at_state_time(valp, i, j) + parameter_timeseries_at_state_time(valp, i) + +Return the index of the timestep in the parameter timeseries at timeseries index `i` which +occurs just before or at the same time as the state timestep with index `j`. The two- +argument version of this function returns an iterable of indexes, one for each timestep in +the state timeseries. If `j` is an object that refers to multiple values in the state +timeseries (e.g. `Colon`), return an iterable of the indexes in the parameter timeseries +at the appropriate points. + +Both versions of this function have default implementations relying on +[`current_time`](@ref) and [`parameter_timeseries`](@ref), for the cases where `j` is one +of: `Int`, `CartesianIndex`, `AbstractArray{Bool}`, `Colon` or an iterable of the +aforementioned. +""" +function parameter_timeseries_at_state_time end + +function parameter_timeseries_at_state_time(valp, i, j::Union{Int, CartesianIndex}) + state_time = current_time(valp, j) + timeseries = parameter_timeseries(valp, i) + searchsortedlast(timeseries, state_time) +end + +function parameter_timeseries_at_state_time(valp, i, ::Colon) + parameter_timeseries_at_state_time(valp, i) +end + +function parameter_timeseries_at_state_time(valp, i, j::AbstractArray{Bool}) + parameter_timeseries_at_state_time(valp, i, only(to_indices(current_time(valp), (j,)))) +end + +function parameter_timeseries_at_state_time(valp, i, j) + (parameter_timeseries_at_state_time(valp, i, jj) for jj in j) +end + +function parameter_timeseries_at_state_time(valp, i) + parameter_timeseries_at_state_time(valp, i, eachindex(current_time(valp))) +end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index cd16b10f..25d18666 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -1,12 +1,7 @@ -state_values(arr::AbstractArray) = arr -state_values(arr, i) = state_values(arr)[i] - function set_state!(sys, val, idx) state_values(sys)[idx] = val end -current_time(p, i) = current_time(p)[i] - """ getu(indp, sym) @@ -25,6 +20,10 @@ support symbolic expressions, the value provider must implement [`observed`](@re This function typically does not need to be implemented, and has a default implementation relying on the above functions. + +If the value provider is a parameter timeseries object, the same rules apply as +[`getp`](@ref). The difference here is that `sym` may also contain non-parameter symbols, +and the values are always returned corresponding to the state timeseries. """ function getu(sys, sym) symtype = symbolic_type(sym) @@ -32,15 +31,18 @@ function getu(sys, sym) _getu(sys, symtype, elsymtype, sym) end -struct GetStateIndex{I} <: AbstractGetIndexer +struct GetStateIndex{I} <: AbstractStateGetIndexer idx::I end function (gsi::GetStateIndex)(::Timeseries, prob) getindex.(state_values(prob), (gsi.idx,)) end -function (gsi::GetStateIndex)(::Timeseries, prob, i) +function (gsi::GetStateIndex)(::Timeseries, prob, i::Union{Int, CartesianIndex}) getindex(state_values(prob, i), gsi.idx) end +function (gsi::GetStateIndex)(::Timeseries, prob, i) + getindex.(state_values(prob, i), gsi.idx) +end function (gsi::GetStateIndex)(::NotTimeseries, prob) state_values(prob, gsi.idx) end @@ -49,46 +51,111 @@ function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym) return GetStateIndex(sym) end -struct GetpAtStateTime{G} <: AbstractGetIndexer +struct GetpAtStateTime{G} <: AbstractStateGetIndexer getter::G end -function (g::GetpAtStateTime)(::Timeseries, prob) - [g.getter(parameter_values_at_state_time(prob, i)) - for i in eachindex(current_time(prob))] +function (g::GetpAtStateTime)(ts::Timeseries, prob) + g(ts, is_parameter_timeseries(prob), prob) +end +function (g::GetpAtStateTime)(ts::Timeseries, prob, i) + g(ts, is_parameter_timeseries(prob), prob, i) +end +function (g::GetpAtStateTime)(::Timeseries, ::NotTimeseries, prob, _...) + g.getter(prob) +end +function (g::GetpAtStateTime)(ts::Timeseries, p_ts::Timeseries, prob) + g(ts, p_ts, is_indexer_timeseries(g.getter), prob) +end +function (g::GetpAtStateTime)( + ::Timeseries, ::Timeseries, ::Union{IndexerTimeseries, IndexerBoth}, prob) + g.getter.((prob,), + parameter_timeseries_at_state_time(prob, indexer_timeseries_index(g.getter))) +end +function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob) + g.getter(prob) +end +function (g::GetpAtStateTime)(ts::Timeseries, p_ts::Timeseries, prob, i) + g(ts, p_ts, is_indexer_timeseries(g.getter), prob, i) +end +function (g::GetpAtStateTime)( + ::Timeseries, ::Timeseries, ::Union{IndexerTimeseries, IndexerBoth}, prob, i) + g.getter(prob, + parameter_timeseries_at_state_time(prob, indexer_timeseries_index(g.getter), i)) +end +function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, + prob, ::Union{Int, CartesianIndex}) + g.getter(prob) +end +function (g::GetpAtStateTime)( + ::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, ::Colon) + map(_ -> g.getter(prob), current_time(prob)) +end +function (g::GetpAtStateTime)( + ::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, i::AbstractArray{Bool}) + num_ones = sum(i) + map(_ -> g.getter(prob), 1:num_ones) end -function (g::GetpAtStateTime)(::Timeseries, prob, i) - g.getter(parameter_values_at_state_time(prob, i)) +function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, i) + map(_ -> g.getter(prob), 1:length(i)) end function (g::GetpAtStateTime)(::NotTimeseries, prob) g.getter(prob) end -struct GetIndepvar <: AbstractGetIndexer end +struct GetIndepvar <: AbstractStateGetIndexer end (::GetIndepvar)(::IsTimeseriesTrait, prob) = current_time(prob) (::GetIndepvar)(::Timeseries, prob, i) = current_time(prob, i) -struct TimeDependentObservedFunction{F} <: AbstractGetIndexer +struct TimeDependentObservedFunction{F} <: AbstractStateGetIndexer obsfn::F end -function (o::TimeDependentObservedFunction)(::Timeseries, prob) - curtime = current_time(prob) - return o.obsfn.(state_values(prob), - (parameter_values_at_state_time(prob, i) for i in eachindex(curtime)), - curtime) +function (o::TimeDependentObservedFunction)(ts::Timeseries, prob) + return o(ts, is_parameter_timeseries(prob), prob) end -function (o::TimeDependentObservedFunction)(::Timeseries, prob, i) +function (o::TimeDependentObservedFunction)(::Timeseries, ::Timeseries, prob) + map(o.obsfn, state_values(prob), + parameter_values_at_state_time(prob), current_time(prob)) +end +function (o::TimeDependentObservedFunction)(::Timeseries, ::NotTimeseries, prob) + o.obsfn.(state_values(prob), + (parameter_values(prob),), + current_time(prob)) +end +function (o::TimeDependentObservedFunction)(ts::Timeseries, prob, i) + return o(ts, is_parameter_timeseries(prob), prob, i) +end +function (o::TimeDependentObservedFunction)( + ::Timeseries, ::Timeseries, prob, i::Union{Int, CartesianIndex}) return o.obsfn(state_values(prob, i), parameter_values_at_state_time(prob, i), current_time(prob, i)) end +function (o::TimeDependentObservedFunction)( + ts::Timeseries, p_ts::IsTimeseriesTrait, prob, ::Colon) + return o(ts, p_ts, prob) +end +function (o::TimeDependentObservedFunction)( + ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i::AbstractArray{Bool}) + map(only(to_indices(current_time(prob), (i,)))) do idx + o(ts, p_ts, prob, idx) + end +end +function (o::TimeDependentObservedFunction)( + ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i) + o.((ts,), (p_ts,), (prob,), i) +end +function (o::TimeDependentObservedFunction)( + ::Timeseries, ::NotTimeseries, prob, i::Union{Int, CartesianIndex}) + o.obsfn(state_values(prob, i), parameter_values(prob), current_time(prob, i)) +end function (o::TimeDependentObservedFunction)(::NotTimeseries, prob) return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob)) end -struct TimeIndependentObservedFunction{F} <: AbstractGetIndexer +struct TimeIndependentObservedFunction{F} <: AbstractStateGetIndexer obsfn::F end @@ -115,33 +182,50 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) error("Invalid symbol $sym for `getu`") end -struct MultipleGetters{G} <: AbstractGetIndexer +struct MultipleGetters{G} <: AbstractStateGetIndexer getters::G end -function (mg::MultipleGetters)(::Timeseries, prob) - return broadcast(i -> map(g -> g(prob, i), mg.getters), - eachindex(state_values(prob))) +function (mg::MultipleGetters)(ts::Timeseries, prob) + return mg.((ts,), (prob,), eachindex(current_time(prob))) +end +function (mg::MultipleGetters)(::Timeseries, prob, i::Union{Int, CartesianIndex}) + return map(CallWith(prob, i), mg.getters) end -function (mg::MultipleGetters)(::Timeseries, prob, i) - return map(g -> g(prob, i), mg.getters) +function (mg::MultipleGetters)(ts::Timeseries, prob, ::Colon) + return mg(ts, prob) +end +function (mg::MultipleGetters)(ts::Timeseries, prob, i::AbstractArray{Bool}) + return map(only(to_indices(current_time(prob), (i,)))) do idx + mg(ts, prob, idx) + end +end +function (mg::MultipleGetters)(ts::Timeseries, prob, i) + mg.((ts,), (prob,), i) end function (mg::MultipleGetters)(::NotTimeseries, prob) return map(g -> g(prob), mg.getters) end -struct AsTupleWrapper{G} <: AbstractGetIndexer +struct AsTupleWrapper{N, G} <: AbstractStateGetIndexer getter::G end +AsTupleWrapper{N}(getter::G) where {N, G} = AsTupleWrapper{N, G}(getter) + +wrap_tuple(::AsTupleWrapper{N}, val) where {N} = ntuple(i -> val[i], Val(N)) + function (atw::AsTupleWrapper)(::Timeseries, prob) - return Tuple.(atw.getter(prob)) + return wrap_tuple.((atw,), atw.getter(prob)) +end +function (atw::AsTupleWrapper)(::Timeseries, prob, i::Union{Int, CartesianIndex}) + return wrap_tuple(atw, atw.getter(prob, i)) end function (atw::AsTupleWrapper)(::Timeseries, prob, i) - return Tuple(atw.getter(prob, i)) + return wrap_tuple.((atw,), atw.getter(prob, i)) end function (atw::AsTupleWrapper)(::NotTimeseries, prob) - return Tuple(atw.getter(prob)) + wrap_tuple(atw, atw.getter(prob)) end for (t1, t2) in [ @@ -151,9 +235,14 @@ for (t1, t2) in [ ] @eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2) num_observed = count(x -> is_observed(sys, x), sym) - if num_observed <= 1 - getters = getu.((sys,), sym) - return MultipleGetters(getters) + if num_observed == 0 || num_observed == 1 && sym isa Tuple + if all(Base.Fix1(is_parameter, sys), sym) && + all(!Base.Fix1(is_timeseries_parameter, sys), sym) + GetpAtStateTime(getp(sys, sym)) + else + getters = getu.((sys,), sym) + return MultipleGetters(getters) + end else obs = observed(sys, sym isa Tuple ? collect(sym) : sym) getter = if is_time_dependent(sys) @@ -162,7 +251,7 @@ for (t1, t2) in [ TimeIndependentObservedFunction(obs) end if sym isa Tuple - getter = AsTupleWrapper(getter) + getter = AsTupleWrapper{length(sym)}(getter) end return getter end @@ -174,7 +263,7 @@ function _getu(sys, ::ArraySymbolic, ::SymbolicTypeTrait, sym) idx = variable_index(sys, sym) return getu(sys, idx) elseif is_parameter(sys, sym) - return getp(sys, sym) + return GetpAtStateTime(getp(sys, sym)) end return getu(sys, collect(sym)) end diff --git a/src/symbol_cache.jl b/src/symbol_cache.jl index ad61409a..76262b52 100644 --- a/src/symbol_cache.jl +++ b/src/symbol_cache.jl @@ -1,52 +1,106 @@ """ - struct SymbolCache{V,P,I} - function SymbolCache(vars, [params, [indepvars]]) + struct SymbolCache + function SymbolCache(vars, [params, [indepvars]]; defaults = Dict(), timeseries_parameters = nothing) -A struct implementing the index provider interface for the trivial case of having a -vector of variables, parameters, and independent variables. It is considered time -dependent if it contains at least one independent variable. It returns `true` for -`is_observed(::SymbolCache, sym)` if `sym isa Expr`. Functions can be generated using +A struct implementing the index provider interface for a collection of variables, +parameters, and independent variables. `vars` and `params` can be specified as arrays +(in which case the index of a symbol is its index in the array) or `AbstractDict`s +mapping symbols to indices. It is considered time dependent if it contains at least one +independent variable. + +It returns `true` for `is_observed(::SymbolCache, sym)` if +`sym isa Union{Expr, Array{Expr}, Tuple{Vararg{Expr}}`. Functions can be generated using `observed` for `Expr`s involving variables in the `SymbolCache` if it has at most one independent variable. +`defaults` is an `AbstractDict` mapping variables and/or parameters to their default +initial values. The default initial values can also be other variables/ +parameters or expressions of them. `timeseries_parameters` is an `AbstractDict` the +timeseries parameters in `params` to their [`ParameterTimeseriesIndex`](@ref) indexes. + The independent variable may be specified as a single symbolic variable instead of an array containing a single variable if the system has only one independent variable. """ struct SymbolCache{ - V <: Union{Nothing, AbstractVector}, - P <: Union{Nothing, AbstractVector}, + V <: Union{Nothing, AbstractDict}, + P <: Union{Nothing, AbstractDict}, + T <: Union{Nothing, AbstractDict}, I, - D <: Dict + D <: AbstractDict } variables::V parameters::P + timeseries_parameters::T independent_variables::I defaults::D end +function to_dict_or_nothing(arr::Union{AbstractArray, Tuple}) + eltype(arr) <: Pair && return Dict(arr) + isempty(arr) && return nothing + return Dict(v => k for (k, v) in enumerate(arr)) +end +to_dict_or_nothing(d::AbstractDict) = d +to_dict_or_nothing(::Nothing) = nothing + function SymbolCache(vars = nothing, params = nothing, indepvars = nothing; - defaults = Dict{Symbol, Union{Symbol, Expr, Number}}()) - return SymbolCache{typeof(vars), typeof(params), typeof(indepvars), typeof(defaults)}( + defaults = Dict(), timeseries_parameters = nothing) + vars = to_dict_or_nothing(vars) + params = to_dict_or_nothing(params) + timeseries_parameters = to_dict_or_nothing(timeseries_parameters) + if timeseries_parameters !== nothing + if indepvars === nothing + throw(ArgumentError("Independent variable is required for timeseries parameters to exist")) + end + for (k, v) in timeseries_parameters + if !haskey(params, k) + throw(ArgumentError("Timeseries parameter $k must also be present in parameters.")) + end + if !isa(v, ParameterTimeseriesIndex) + throw(TypeError(:SymbolCache, "index of timeseries parameter $k", + ParameterTimeseriesIndex, v)) + end + end + end + return SymbolCache{typeof(vars), typeof(params), typeof(timeseries_parameters), + typeof(indepvars), typeof(defaults)}( vars, params, + timeseries_parameters, indepvars, defaults) end function is_variable(sc::SymbolCache, sym) - sc.variables !== nothing && any(isequal(sym), sc.variables) + sc.variables !== nothing && haskey(sc.variables, sym) end function variable_index(sc::SymbolCache, sym) - sc.variables === nothing ? nothing : findfirst(isequal(sym), sc.variables) + sc.variables === nothing ? nothing : get(sc.variables, sym, nothing) +end +function variable_symbols(sc::SymbolCache, i = nothing) + sc.variables === nothing && return [] + buffer = collect(keys(sc.variables)) + for (k, v) in sc.variables + buffer[v] = k + end + return buffer end -variable_symbols(sc::SymbolCache, i = nothing) = something(sc.variables, []) function is_parameter(sc::SymbolCache, sym) - sc.parameters !== nothing && any(isequal(sym), sc.parameters) + sc.parameters !== nothing && haskey(sc.parameters, sym) end function parameter_index(sc::SymbolCache, sym) - sc.parameters === nothing ? nothing : findfirst(isequal(sym), sc.parameters) + sc.parameters === nothing ? nothing : get(sc.parameters, sym, nothing) +end +function parameter_symbols(sc::SymbolCache) + sc.parameters === nothing ? [] : collect(keys(sc.parameters)) +end +function is_timeseries_parameter(sc::SymbolCache, sym) + sc.timeseries_parameters !== nothing && haskey(sc.timeseries_parameters, sym) +end +function timeseries_parameter_index(sc::SymbolCache, sym) + sc.timeseries_parameters === nothing ? nothing : + get(sc.timeseries_parameters, sym, nothing) end -parameter_symbols(sc::SymbolCache) = something(sc.parameters, []) function is_independent_variable(sc::SymbolCache, sym) sc.independent_variables === nothing && return false if symbolic_type(sc.independent_variables) == NotSymbolic() @@ -69,15 +123,17 @@ function independent_variable_symbols(sc::SymbolCache) end is_observed(sc::SymbolCache, sym) = false is_observed(::SymbolCache, ::Expr) = true -is_observed(::SymbolCache, ::AbstractArray{Expr}) = true +is_observed(::SymbolCache, ::Array{Expr}) = true is_observed(::SymbolCache, ::Tuple{Vararg{Expr}}) = true +# TODO: Make this less hacky struct ExpressionSearcher + parameters::Set{Symbol} declared::Set{Symbol} fnbody::Expr end -ExpressionSearcher() = ExpressionSearcher(Set{Symbol}(), Expr(:block)) +ExpressionSearcher() = ExpressionSearcher(Set{Symbol}(), Set{Symbol}(), Expr(:block)) function (exs::ExpressionSearcher)(sys, expr::Expr) for arg in expr.args @@ -94,7 +150,8 @@ function (exs::ExpressionSearcher)(sys, sym::Symbol) push!(exs.fnbody.args, :($sym = u[$idx])) elseif is_parameter(sys, sym) idx = parameter_index(sys, sym) - push!(exs.fnbody.args, :($sym = p[$idx])) + push!(exs.parameters, sym) + push!(exs.fnbody.args, :($sym = parameter_values(p, $idx))) elseif is_independent_variable(sys, sym) push!(exs.fnbody.args, :($sym = t)) end @@ -124,11 +181,104 @@ function observed(sc::SymbolCache, expr::Expr) end end end -function observed(sc::SymbolCache, exprs::AbstractArray{Expr}) - return observed(sc, :(reshape([$(exprs...)], $(size(exprs))))) + +to_expr(exprs::AbstractArray) = :(reshape([$(exprs...)], $(size(exprs)))) +to_expr(exprs::Tuple) = :(($(exprs...),)) + +function inplace_observed(sc::SymbolCache, exprs::Union{AbstractArray, Tuple}) + let cache = Dict{Expr, Function}() + return get!(cache, to_expr(exprs)) do + exs = ExpressionSearcher() + for expr in exprs + exs(sc, expr) + end + update_expr = Expr(:block) + for (i, expr) in enumerate(exprs) + push!(update_expr.args, :(buffer[$i] = $expr)) + end + fnexpr = if is_time_dependent(sc) + :(function (buffer, u, p, t) + $(exs.fnbody) + $update_expr + return buffer + end) + else + :(function (buffer, u, p) + $(exs.fnbody) + $update_expr + return buffer + end) + end + return RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(fnexpr) + end + end +end + +function observed(sc::SymbolCache, exprs::Union{AbstractArray, Tuple}) + for expr in exprs + if !(expr isa Union{Symbol, Expr}) + throw(TypeError(:observed, "SymbolCache", Union{Symbol, Expr}, expr)) + end + end + return observed(sc, to_expr(exprs)) +end + +function parameter_observed(sc::SymbolCache, expr::Expr) + if is_time_dependent(sc) + exs = ExpressionSearcher() + exs(sc, expr) + ts_idxs = Set() + for p in exs.parameters + is_timeseries_parameter(sc, p) || continue + push!(ts_idxs, timeseries_parameter_index(sc, p).timeseries_idx) + end + f = let fn = observed(sc, expr) + f1(p, t) = fn(nothing, p, t) + end + if length(ts_idxs) == 1 + return ParameterObservedFunction(only(ts_idxs), f) + else + return ParameterObservedFunction(nothing, f) + end + else + f = let fn = observed(sc, expr) + f2(p) = fn(nothing, p) + end + return ParameterObservedFunction(nothing, f) + end end -function observed(sc::SymbolCache, exprs::Tuple{Vararg{Expr}}) - return observed(sc, :(($(exprs...),))) + +function parameter_observed(sc::SymbolCache, exprs::Union{AbstractArray, Tuple}) + for ex in exprs + if !(ex isa Union{Symbol, Expr}) + throw(TypeError(:parameter_observed, "SymbolCache", Union{Symbol, Expr}, ex)) + end + end + if is_time_dependent(sc) + exs = ExpressionSearcher() + exs(sc, to_expr(exprs)) + ts_idxs = Set() + for p in exs.parameters + is_timeseries_parameter(sc, p) || continue + push!(ts_idxs, timeseries_parameter_index(sc, p).timeseries_idx) + end + + f = let oop = observed(sc, to_expr(exprs)), iip = inplace_observed(sc, exprs) + f1(p, t) = oop(nothing, p, t) + f1(buffer, p, t) = iip(buffer, nothing, p, t) + end + if length(ts_idxs) == 1 + return ParameterObservedFunction(only(ts_idxs), f) + else + return ParameterObservedFunction(nothing, f) + end + else + f = let oop = observed(sc, to_expr(exprs)), iip = inplace_observed(sc, exprs) + f2(p) = oop(nothing, p) + f2(buffer, p) = iip(buffer, nothing, p) + end + return ParameterObservedFunction(nothing, f) + end end function is_time_dependent(sc::SymbolCache) @@ -149,6 +299,7 @@ default_values(sc::SymbolCache) = sc.defaults function Base.copy(sc::SymbolCache) return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables), sc.parameters === nothing ? nothing : copy(sc.parameters), + sc.timeseries_parameters === nothing ? nothing : copy(sc.timeseries_parameters), sc.independent_variables isa AbstractArray ? copy(sc.independent_variables) : sc.independent_variables, copy(sc.defaults)) end diff --git a/src/trait.jl b/src/trait.jl index ea7964a9..de44aa5d 100644 --- a/src/trait.jl +++ b/src/trait.jl @@ -129,7 +129,7 @@ data. It may still be time-dependent. For example, an `ODEProblem` only stores the initial state of a system, so it is `NotTimeseries`, but still time-dependent. This is the default trait variant for all types. -See also: [`Timeseries`](@ref), [`is_timeseries`](@ref) +See also: [`Timeseries`](@ref), [`is_timeseries`](@ref). """ struct NotTimeseries <: IsTimeseriesTrait end @@ -138,10 +138,27 @@ struct NotTimeseries <: IsTimeseriesTrait end is_timeseries(::Type) Get the timeseries trait of a type. Defaults to [`NotTimeseries`](@ref) for all types. +A type for which `is_timeseries(T) == Timeseries()` may also have a parameter timeseries. +This is determined by the [`is_parameter_timeseries`](@ref) trait. -See also: [`Timeseries`](@ref), [`NotTimeseries`](@ref) +See also: [`Timeseries`](@ref), [`NotTimeseries`](@ref), [`is_parameter_timeseries`](@ref). """ function is_timeseries end is_timeseries(x) = is_timeseries(typeof(x)) is_timeseries(::Type) = NotTimeseries() + +""" + is_parameter_timeseries(x) = is_parameter_timeseries(typeof(x)) + is_parameter_timeseries(::Type) + +Get the parameter timeseries trait of a type. Defaults to [`NotTimeseries`](@ref) for all +types. A type for which `is_parameter_timeseries(T) == Timeseries()` must also have +`is_timeseries(T) == Timeseries()`. + +See also: [`Timeseries`](@ref), [`NotTimeseries`](@ref), [`is_timeseries`](@ref). +""" +function is_parameter_timeseries end + +is_parameter_timeseries(x) = is_parameter_timeseries(typeof(x)) +is_parameter_timeseries(::Type) = NotTimeseries() diff --git a/src/value_provider_interface.jl b/src/value_provider_interface.jl index 1f455b8c..9a9cb0e4 100644 --- a/src/value_provider_interface.jl +++ b/src/value_provider_interface.jl @@ -16,51 +16,32 @@ array/tuple. """ function parameter_values end -""" - parameter_values_at_time(valp, i) - -Return an indexable collection containing the value of all parameters in `valp` at time -index `i`. This is useful when parameter values change during the simulation (such as -through callbacks) and their values are saved. `i` is the time index in the timeserie - formed by these changing parameter values, obtained using [`parameter_timeseries`](@ref). - -By default, this function returns `parameter_values(valp)` regardless of `i`, and only needs -to be specialized for timeseries objects where parameter values are not constant at all -times. The resultant object should be indexable using [`parameter_values`](@ref). +parameter_values(arr::AbstractArray) = arr +parameter_values(arr::Tuple) = arr +parameter_values(arr::AbstractArray, i) = arr[i] +parameter_values(arr::Tuple, i) = arr[i] +parameter_values(prob, i) = parameter_values(parameter_values(prob), i) -If this function is implemented, [`parameter_values_at_state_time`](@ref) must be -implemented for [`getu`](@ref) to work correctly. """ -function parameter_values_at_time end - -""" - parameter_values_at_state_time(valp, i) + parameter_values_at_time(valp, t) Return an indexable collection containing the value of all parameters in `valp` at time -index `i`. This is useful when parameter values change during the simulation (such as -through callbacks) and their values are saved. `i` is the time index in the timeseries -formed by dependent variables (as opposed to the timeseries of the parameters, as in -[`parameter_values_at_time`](@ref)). +`t`. Note that `t` here is a floating-point time, and not an index into a timeseries. -By default, this function returns `parameter_values(valp)` regardless of `i`, and only -needs to be specialized for timeseries objects where parameter values are not constant at -all times. The resultant object should be indexable using [`parameter_values`](@ref). - -If this function is implemented, [`parameter_values_at_time`](@ref) must be implemented for -[`getp`](@ref) to work correctly. +This is useful for parameter timeseries objects, since some parameters change over time. """ -function parameter_values_at_state_time end +function get_parameter_timeseries_collection end """ - parameter_timeseries(valp) - -Return an iterable of time steps at which the parameter values are saved. This is only -required for objects where `is_timeseries(valp) === Timeseries()` and the parameter values -change during the simulation (such as through callbacks). By default, this returns `[0]`. + with_updated_parameter_timeseries_values(valp, args::Pair...) -See also: [`parameter_values_at_time`](@ref). +Return an indexable collection containing the value of all parameters in `valp`, with +parameters belonging to specific timeseries updated to different values. Each element in +`args...` contains the timeseries index as the first value, and the saved parameter values +in that partition. Not all parameter timeseries have to be updated using this method. If +an in-place update can be performed, it should be done and the modified `valp` returned. """ -function parameter_timeseries end +function with_updated_parameter_timeseries_values end """ set_parameter!(valp, val, idx) @@ -74,6 +55,12 @@ See: [`parameter_values`](@ref) """ function set_parameter! end +# Tuple only included for the error message +function set_parameter!(sys::Union{AbstractArray, Tuple}, val, idx) + sys[idx] = val +end +set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx) + """ finalize_parameters_hook!(valp, sym) @@ -97,7 +84,10 @@ Return an indexable collection containing the values of all states in the value each of which contain the state values at the corresponding timestep. In this case, the two-argument version of the function can also be implemented to efficiently return the state values at timestep `i`. By default, the two-argument method calls -`state_values(valp)[i]` +`state_values(valp)[i]`. If `i` consists of multiple indices (for example, `Colon`, +`AbstractArray{Int}`, `AbstractArray{Bool}`) specialized methods may be defined for +efficiency. By default, `state_values(valp, ::Colon) = state_values(valp)` to avoid +copying the timeseries. If this function is called with an `AbstractArray`, it will return the same array. @@ -105,6 +95,10 @@ See: [`is_timeseries`](@ref) """ function state_values end +state_values(arr::AbstractArray) = arr +state_values(arr, i) = state_values(arr)[i] +state_values(arr, ::Colon) = state_values(arr) + """ set_state!(valp, val, idx) @@ -125,13 +119,24 @@ Return the current time in the value provider `valp`. If `is_timeseries(valp)` is [`Timeseries`](@ref), return the vector of timesteps at which the state value is saved. In this case, the two-argument version of the function can also be implemented to efficiently return the time at timestep `i`. By default, the two- -argument method calls `current_time(p)[i]` +argument method calls `current_time(p)[i]`. It is assumed that the timeseries is sorted +in increasing order. +If `i` consists of multiple indices (for example, `Colon`, `AbstractArray{Int}`, +`AbstractArray{Bool}`) specialized methods may be defined for efficiency. By default, +`current_time(valp, ::Colon) = current_time(valp)` to avoid copying the timeseries. + +By default, the single-argument version acts as the identity function if +`valp isa AbstractVector`. See: [`is_timeseries`](@ref) """ function current_time end +current_time(arr::AbstractVector) = arr +current_time(valp, i) = current_time(valp)[i] +current_time(valp, ::Colon) = current_time(valp) + ########### # Utilities ########### @@ -139,7 +144,133 @@ function current_time end abstract type AbstractIndexer end abstract type AbstractGetIndexer <: AbstractIndexer end +abstract type AbstractStateGetIndexer <: AbstractGetIndexer end +abstract type AbstractParameterGetIndexer <: AbstractGetIndexer end abstract type AbstractSetIndexer <: AbstractIndexer end -(ai::AbstractGetIndexer)(prob) = ai(is_timeseries(prob), prob) -(ai::AbstractGetIndexer)(prob, i) = ai(is_timeseries(prob), prob, i) +(ai::AbstractStateGetIndexer)(prob) = ai(is_timeseries(prob), prob) +(ai::AbstractStateGetIndexer)(prob, i) = ai(is_timeseries(prob), prob, i) +(ai::AbstractParameterGetIndexer)(prob) = ai(is_parameter_timeseries(prob), prob) +(ai::AbstractParameterGetIndexer)(prob, i) = ai(is_parameter_timeseries(prob), prob, i) +function (ai::AbstractParameterGetIndexer)(buffer::AbstractArray, prob) + ai(buffer, is_parameter_timeseries(prob), prob) +end +function (ai::AbstractParameterGetIndexer)(buffer::AbstractArray, prob, i) + ai(buffer, is_parameter_timeseries(prob), prob, i) +end + +abstract type IsIndexerTimeseries end + +struct IndexerTimeseries <: IsIndexerTimeseries end +struct IndexerNotTimeseries <: IsIndexerTimeseries end +struct IndexerBoth <: IsIndexerTimeseries end + +const AtLeastTimeseriesIndexer = Union{IndexerTimeseries, IndexerBoth} +const AtLeastNotTimeseriesIndexer = Union{IndexerNotTimeseries, IndexerBoth} + +is_indexer_timeseries(x) = is_indexer_timeseries(typeof(x)) +function indexer_timeseries_index end + +as_not_timeseries_indexer(x) = as_not_timeseries_indexer(is_indexer_timeseries(x), x) +as_not_timeseries_indexer(::IndexerNotTimeseries, x) = x +function as_not_timeseries_indexer(::IndexerTimeseries, x) + error(""" + Tried to convert an `$IndexerTimeseries` to an `$IndexerNotTimeseries`. This \ + should never happen. Please file an issue with an MWE. + """) +end + +as_timeseries_indexer(x) = as_timeseries_indexer(is_indexer_timeseries(x), x) +as_timeseries_indexer(::IndexerTimeseries, x) = x +function as_timeseries_indexer(::IndexerNotTimeseries, x) + error(""" + Tried to convert an `$IndexerNotTimeseries` to an `$IndexerTimeseries`. This \ + should never happen. Please file an issue with an MWE. + """) +end + +struct CallWith{A} + args::A + + CallWith(args...) = new{typeof(args)}(args) +end + +function (cw::CallWith)(arg) + arg(cw.args...) +end + +function _call(f, args...) + return f(args...) +end + +########### +# Errors +########### + +struct ParameterTimeseriesValueIndexMismatchError{P <: IsTimeseriesTrait} <: Exception + valp::Any + indexer::Any + args::Any + + function ParameterTimeseriesValueIndexMismatchError{Timeseries}(valp, indexer, args) + if is_parameter_timeseries(valp) != Timeseries() + throw(ArgumentError(""" + This should never happen. Expected parameter timeseries value provider, \ + got $(valp). Open an issue in SymbolicIndexingInterface.jl with an MWE. + """)) + end + if is_indexer_timeseries(indexer) != IndexerNotTimeseries() + throw(ArgumentError(""" + This should never happen. Expected non-timeseries indexer, got \ + $(indexer). Open an issue in SymbolicIndexingInterface.jl with an MWE. + """)) + end + return new{Timeseries}(valp, indexer, args) + end + function ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(valp, indexer) + if is_parameter_timeseries(valp) != NotTimeseries() + throw(ArgumentError(""" + This should never happen. Expected non-parameter timeseries value \ + provider, got $(valp). Open an issue in SymbolicIndexingInterface.jl \ + with an MWE. + """)) + end + if is_indexer_timeseries(indexer) != IndexerTimeseries() + throw(ArgumentError(""" + This should never happen. Expected timeseries indexer, got $(indexer). \ + Open an issue in SymbolicIndexingInterface.jl with an MWE. + """)) + end + return new{NotTimeseries}(valp, indexer, nothing) + end +end + +function Base.showerror(io::IO, err::ParameterTimeseriesValueIndexMismatchError{Timeseries}) + print(io, """ + Invalid indexing operation: tried to access object of type $(typeof(err.valp)) \ + (which is a parameter timeseries object) with non-timeseries indexer \ + $(err.indexer) at index $(err.args) in the timeseries. + """) +end + +function Base.showerror( + io::IO, err::ParameterTimeseriesValueIndexMismatchError{NotTimeseries}) + print(io, """ + Invalid indexing operation: tried to access object of type $(typeof(err.valp)) \ + (which is not a parameter timeseries object) using timeseries indexer \ + $(err.indexer). + """) +end + +struct MixedParameterTimeseriesIndexError <: Exception + valp::Any + ts_idxs::Any +end + +function Base.showerror(io::IO, err::MixedParameterTimeseriesIndexError) + print(io, """ + Invalid indexing operation: tried to access object of type $(typeof(err.valp)) \ + (which is a parameter timeseries object) with variables having mixed timeseries \ + indexes $(err.ts_idxs). + """) +end diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index bad15fbc..954b5e24 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -1,4 +1,8 @@ [deps] ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" + +[compat] +SymbolicUtils = "<1.6" diff --git a/test/example_test.jl b/test/example_test.jl index 5d3caa0c..521de32a 100644 --- a/test/example_test.jl +++ b/test/example_test.jl @@ -71,6 +71,8 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t) @test all(.!is_parameter.((sys,), [:x, :y, :z, :t, :p, :q, :r])) @test all(parameter_index.((sys,), [:c, :a, :b]) .== [3, 1, 2]) @test all(parameter_index.((sys,), [:x, :y, :z, :t, :p, :q, :r]) .=== nothing) +@test all(.!is_timeseries_parameter.((sys,), [:x, :y, :z, :t, :p, :q, :r])) # fallback even if not implemented +@test all(timeseries_parameter_index.((sys,), [:x, :y, :z, :t, :p, :q, :r]) .=== nothing) # fallback @test is_independent_variable(sys, :t) @test all(.!is_independent_variable.((sys,), [:x, :y, :z, :a, :b, :c, :p, :q, :r])) @test all(is_observed.((sys,), [:x, :y, :z, :a, :b, :c, :t])) @@ -88,6 +90,7 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t) @test independent_variable_symbols(sys) == [:t] @test all_variable_symbols(sys) == [:x, :y, :z] @test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z] +@test default_values(sys) == Dict() # fallback even if not implemented sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing) diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 42dd9500..5b392ab1 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -1,9 +1,20 @@ using SymbolicIndexingInterface +using SymbolicIndexingInterface: IndexerTimeseries, IndexerNotTimeseries, IndexerBoth, + is_indexer_timeseries, indexer_timeseries_index, + ParameterTimeseriesValueIndexMismatchError, + MixedParameterTimeseriesIndexError using Test +arr = [1.0, 2.0, 3.0] +@test parameter_values(arr) == arr +@test current_time(arr) == arr +tp = (1.0, 2.0, 3.0) +@test parameter_values(tp) == tp + struct FakeIntegrator{S, P} sys::S p::P + t::Float64 counter::Ref{Int} end @@ -12,99 +23,139 @@ function Base.getproperty(fi::FakeIntegrator, s::Symbol) end SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p +SymbolicIndexingInterface.current_time(fp::FakeIntegrator) = fp.t function SymbolicIndexingInterface.finalize_parameters_hook!(fi::FakeIntegrator, p) fi.counter[] += 1 end -sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) -for pType in [Vector, Tuple] - p = [1.0, 2.0, 3.0] - fi = FakeIntegrator(sys, pType(copy(p)), Ref(0)) - new_p = [4.0, 5.0, 6.0] - @test parameter_timeseries(fi) == [0] - for (sym, oldval, newval, check_inference) in [ - (:a, p[1], new_p[1], true), - (1, p[1], new_p[1], true), - ([:a, :b], p[1:2], new_p[1:2], true), - (1:2, p[1:2], new_p[1:2], true), - ((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true), - ([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), - ([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), - ((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), - ((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true), - ([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), - ([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), - ((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), - ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true) - ] - get = getp(sys, sym) - set! = setp(sys, sym) - if check_inference - @inferred get(fi) - end - @test get(fi) == fi.ps[sym] - @test get(fi) == oldval +for sys in [ + SymbolCache([:x, :y, :z], [:a, :b, :c, :d], [:t]), + SymbolCache([:x, :y, :z], + [:a, :b, :c, :d], + [:t], + timeseries_parameters = Dict( + :b => ParameterTimeseriesIndex(1, 1), :c => ParameterTimeseriesIndex(2, 1))) +] + has_ts = sys.timeseries_parameters !== nothing + for pType in [Vector, Tuple] + p = [1.0, 2.0, 3.0, 4.0] + fi = FakeIntegrator(sys, pType(copy(p)), 9.0, Ref(0)) + new_p = [4.0, 5.0, 6.0, 7.0] + for (sym, oldval, newval, check_inference) in [ + (:a, p[1], new_p[1], true), + (1, p[1], new_p[1], true), + ([:a, :b], p[1:2], new_p[1:2], !has_ts), + (1:2, p[1:2], new_p[1:2], true), + ((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true), + ([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), + ([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), + ((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), + ((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true), + ([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), + ([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), + ((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), + ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true) + ] + get = getp(sys, sym) + set! = setp(sys, sym) + if check_inference + @inferred get(fi) + end + @test get(fi) == fi.ps[sym] + @test get(fi) == oldval - if pType === Tuple - @test_throws MethodError set!(fi, newval) - continue - end + if pType === Tuple + @test_throws MethodError set!(fi, newval) + continue + end - @test fi.counter[] == 0 - if check_inference - @inferred set!(fi, newval) - else - set!(fi, newval) - end - @test fi.counter[] == 1 + @test fi.counter[] == 0 + if check_inference + @inferred set!(fi, newval) + else + set!(fi, newval) + end + @test fi.counter[] == 1 - @test get(fi) == newval - set!(fi, oldval) - @test get(fi) == oldval - @test fi.counter[] == 2 + @test get(fi) == newval + set!(fi, oldval) + @test get(fi) == oldval + @test fi.counter[] == 2 - fi.ps[sym] = newval - @test get(fi) == newval - @test fi.counter[] == 3 - fi.ps[sym] = oldval - @test get(fi) == oldval - @test fi.counter[] == 4 + fi.ps[sym] = newval + @test get(fi) == newval + @test fi.counter[] == 3 + fi.ps[sym] = oldval + @test get(fi) == oldval + @test fi.counter[] == 4 - if check_inference - @inferred get(p) + if check_inference + @inferred get(p) + end + @test get(p) == oldval + if check_inference + @inferred set!(p, newval) + else + set!(p, newval) + end + @test get(p) == newval + set!(p, oldval) + @test get(p) == oldval + @test fi.counter[] == 4 + fi.counter[] = 0 end - @test get(p) == oldval - if check_inference - @inferred set!(p, newval) - else - set!(p, newval) + + for (sym, val) in [ + ([:a, :b, :c, :d], p), + ([:c, :a], p[[3, 1]]), + ((:b, :a), Tuple(p[[2, 1]])), + ((1, :c), Tuple(p[[1, 3]])), + (:(a + b + t), p[1] + p[2] + fi.t), + ([:(a + b + t), :c], [p[1] + p[2] + fi.t, p[3]]), + ((:(a + b + t), :c), (p[1] + p[2] + fi.t, p[3])) + ] + get = getp(sys, sym) + @inferred get(fi) + @test get(fi) == val + if sym isa Union{Array, Tuple} + buffer = zeros(length(sym)) + @inferred get(buffer, fi) + @test buffer == collect(val) + end end - @test get(p) == newval - set!(p, oldval) - @test get(p) == oldval - @test fi.counter[] == 4 - fi.counter[] = 0 end +end - for (sym, val) in [ - ([:a, :b, :c], p), - ([:c, :a], p[[3, 1]]), - ((:b, :a), p[[2, 1]]), - ((1, :c), p[[1, 3]]) - ] - buffer = zeros(length(sym)) - get = getp(sys, sym) - @inferred get(buffer, fi) - @test buffer == val +struct MyDiffEqArray + t::Vector{Float64} + u::Vector{Vector{Float64}} +end +SymbolicIndexingInterface.current_time(mda::MyDiffEqArray) = mda.t +SymbolicIndexingInterface.state_values(mda::MyDiffEqArray) = mda.u +SymbolicIndexingInterface.is_timeseries(::Type{MyDiffEqArray}) = Timeseries() + +struct MyParameterObject + p::Vector{Float64} + disc_idxs::Vector{Vector{Int}} +end + +SymbolicIndexingInterface.parameter_values(mpo::MyParameterObject) = mpo.p +function SymbolicIndexingInterface.with_updated_parameter_timeseries_values( + mpo::MyParameterObject, args::Pair...) + for (ts_idx, val) in args + mpo.p[mpo.disc_idxs[ts_idx]] = val end + return mpo end +Base.getindex(mpo::MyParameterObject, i) = mpo.p[i] + struct FakeSolution sys::SymbolCache u::Vector{Vector{Float64}} t::Vector{Float64} - p::Vector{Vector{Float64}} - pt::Vector{Float64} + p::MyParameterObject + p_ts::ParameterTimeseriesCollection{Vector{MyDiffEqArray}, MyParameterObject} end function Base.getproperty(fs::FakeSolution, s::Symbol) @@ -113,78 +164,262 @@ end SymbolicIndexingInterface.state_values(fs::FakeSolution) = fs.u SymbolicIndexingInterface.current_time(fs::FakeSolution) = fs.t SymbolicIndexingInterface.symbolic_container(fs::FakeSolution) = fs.sys -SymbolicIndexingInterface.parameter_values(fs::FakeSolution) = fs.p[end] -SymbolicIndexingInterface.parameter_values(fs::FakeSolution, i) = fs.p[end][i] -function SymbolicIndexingInterface.parameter_values_at_time(fs::FakeSolution, t) - fs.p[t] -end -function SymbolicIndexingInterface.parameter_values_at_state_time(fs::FakeSolution, t) - ptind = searchsortedfirst(fs.pt, fs.t[t]; lt = <=) - fs.p[ptind - 1] -end -SymbolicIndexingInterface.parameter_timeseries(fs::FakeSolution) = fs.pt +SymbolicIndexingInterface.parameter_values(fs::FakeSolution) = fs.p +SymbolicIndexingInterface.parameter_values(fs::FakeSolution, i) = fs.p[i] +SymbolicIndexingInterface.get_parameter_timeseries_collection(fs::FakeSolution) = fs.p_ts SymbolicIndexingInterface.is_timeseries(::Type{FakeSolution}) = Timeseries() -sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) +SymbolicIndexingInterface.is_parameter_timeseries(::Type{FakeSolution}) = Timeseries() +sys = SymbolCache([:x, :y, :z], + [:a, :b, :c, :d], + :t; + timeseries_parameters = Dict( + :b => ParameterTimeseriesIndex(1, 1), :c => ParameterTimeseriesIndex(2, 1))) +b_timeseries = MyDiffEqArray(collect(0:0.1:0.9), [[2.5i] for i in 1:10]) +c_timeseries = MyDiffEqArray(collect(0:0.25:0.9), [[3.5i] for i in 1:4]) +p = MyParameterObject( + [20.0, b_timeseries.u[end][1], c_timeseries.u[end][1], 30.0], [[2], [3]]) fs = FakeSolution( sys, [i * ones(3) for i in 1:5], [0.2i for i in 1:5], - [2i * ones(3) for i in 1:10], - [0.1i for i in 1:10] + p, + ParameterTimeseriesCollection([b_timeseries, c_timeseries], deepcopy(p)) ) -ps = fs.p -p = fs.p[end] -avals = getindex.(ps, 1) -bvals = getindex.(ps, 2) -cvals = getindex.(ps, 3) -@test parameter_timeseries(fs) == fs.pt -for (sym, val, arrval, check_inference) in [ - (:a, p[1], avals, true), - (1, p[1], avals, true), - ([:a, :b], p[1:2], vcat.(avals, bvals), true), - (1:2, p[1:2], vcat.(avals, bvals), true), - ((1, 2), Tuple(p[1:2]), tuple.(avals, bvals), true), - ([:a, [:b, :c]], [p[1], p[2:3]], - [[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false), - ([:a, (:b, :c)], [p[1], (p[2], p[3])], vcat.(avals, tuple.(bvals, cvals)), false), - ((:a, [:b, :c]), (p[1], p[2:3]), tuple.(avals, vcat.(bvals, cvals)), true), - ((:a, (:b, :c)), (p[1], (p[2], p[3])), tuple.(avals, tuple.(bvals, cvals)), true), - ([1, [:b, :c]], [p[1], p[2:3]], - [[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false), - ([1, (:b, :c)], [p[1], (p[2], p[3])], vcat.(avals, tuple.(bvals, cvals)), false), - ((1, [:b, :c]), (p[1], p[2:3]), tuple.(avals, vcat.(bvals, cvals)), true), - ((1, (:b, :c)), (p[1], (p[2], p[3])), tuple.(avals, tuple.(bvals, cvals)), true) +aval = fs.p[1] +bval = getindex.(b_timeseries.u) +cval = getindex.(c_timeseries.u) +dval = fs.p[4] +bidx = timeseries_parameter_index(sys, :b) +cidx = timeseries_parameter_index(sys, :c) + +for (sym, indexer_trait, timeseries_index, val, buffer, check_inference) in [ + (:a, IndexerNotTimeseries, 0, aval, nothing, true), + (1, IndexerNotTimeseries, 0, aval, nothing, true), + ([:a, :d], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true), + ((:a, :d), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true), + ([1, 4], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true), + ((1, 4), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true), + ([:a, 4], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true), + ((:a, 4), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true), + (:b, IndexerBoth, 1, bval, zeros(length(bval)), true), + (bidx, IndexerTimeseries, 1, bval, zeros(length(bval)), true), + ([:a, :b], IndexerNotTimeseries, 0, [aval, bval[end]], zeros(2), true), + ((:a, :b), IndexerNotTimeseries, 0, (aval, bval[end]), zeros(2), true), + ([1, :b], IndexerNotTimeseries, 0, [aval, bval[end]], zeros(2), true), + ((1, :b), IndexerNotTimeseries, 0, (aval, bval[end]), zeros(2), true), + ([:b, :b], IndexerBoth, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), + ((:b, :b), IndexerBoth, 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), + ([bidx, :b], IndexerTimeseries, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), + ((bidx, :b), IndexerTimeseries, 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), + ([bidx, bidx], IndexerTimeseries, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), + ((bidx, bidx), IndexerTimeseries, 1, + tuple.(bval, bval), map(_ -> zeros(2), bval), true), + (:(a + b), IndexerBoth, 1, bval .+ aval, zeros(length(bval)), true), + ([:(a + b), :a], IndexerBoth, 1, vcat.(bval .+ aval, aval), + map(_ -> zeros(2), bval), true), + ((:(a + b), :a), IndexerBoth, 1, tuple.(bval .+ aval, aval), + map(_ -> zeros(2), bval), true), + ([:(a + b), :b], IndexerBoth, 1, vcat.(bval .+ aval, bval), + map(_ -> zeros(2), bval), true), + ((:(a + b), :b), IndexerBoth, 1, tuple.(bval .+ aval, bval), + map(_ -> zeros(2), bval), true), + ([:(a + b), :c], IndexerNotTimeseries, 0, + [aval + bval[end], cval[end]], zeros(2), true), + ((:(a + b), :c), IndexerNotTimeseries, 0, + (aval + bval[end], cval[end]), zeros(2), true) +] + getter = getp(sys, sym) + @test is_indexer_timeseries(getter) isa indexer_trait + if indexer_trait <: Union{IndexerTimeseries, IndexerBoth} + @test indexer_timeseries_index(getter) == timeseries_index + end + test_inplace = buffer !== nothing + test_non_timeseries = indexer_trait !== IndexerTimeseries + if test_inplace && test_non_timeseries + non_timeseries_val = indexer_trait == IndexerNotTimeseries ? val : val[end] + non_timeseries_buffer = indexer_trait == IndexerNotTimeseries ? deepcopy(buffer) : + deepcopy(buffer[end]) + test_non_timeseries_inplace = non_timeseries_buffer isa AbstractArray + end + isobs = sym isa Union{AbstractArray, Tuple} ? any(Base.Fix1(is_observed, sys), sym) : + is_observed(sys, sym) + if check_inference + @inferred getter(fs) + if test_inplace + @inferred getter(deepcopy(buffer), fs) + end + if test_non_timeseries && !isobs + @inferred getter(parameter_values(fs)) + if test_inplace && test_non_timeseries_inplace && test_non_timeseries_inplace + @inferred getter(deepcopy(non_timeseries_buffer), parameter_values(fs)) + end + end + end + @test getter(fs) == val + if test_inplace + tmp = deepcopy(buffer) + getter(tmp, fs) + if val isa Tuple + target = collect(val) + elseif eltype(val) <: Tuple + target = collect.(val) + else + target = val + end + @test tmp == target + end + if test_non_timeseries && !isobs + non_timeseries_val = indexer_trait == IndexerNotTimeseries ? val : val[end] + @test getter(parameter_values(fs)) == non_timeseries_val + if test_inplace && test_non_timeseries && test_non_timeseries_inplace + getter(non_timeseries_buffer, parameter_values(fs)) + if non_timeseries_val isa Tuple + target = collect(non_timeseries_val) + else + target = non_timeseries_val + end + @test non_timeseries_buffer == target + end + elseif !isobs + @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter(parameter_values(fs)) + if test_inplace + @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter( + [], parameter_values(fs)) + end + end + for subidx in [ + 1, CartesianIndex(1), :, rand(Bool, length(val)), rand(eachindex(val), 3), 1:2] + if indexer_trait <: IndexerNotTimeseries + @test_throws ParameterTimeseriesValueIndexMismatchError{Timeseries} getter( + fs, subidx) + if test_inplace + @test_throws ParameterTimeseriesValueIndexMismatchError{Timeseries} getter( + [], fs, subidx) + end + else + if check_inference + @inferred getter(fs, subidx) + if test_inplace && buffer[subidx] isa AbstractArray + @inferred getter(deepcopy(buffer[subidx]), fs, subidx) + end + end + @test getter(fs, subidx) == val[subidx] + if test_inplace && buffer[subidx] isa AbstractArray + tmp = deepcopy(buffer[subidx]) + getter(tmp, fs, subidx) + if val[subidx] isa Tuple + target = collect(val[subidx]) + elseif eltype(val) <: Tuple + target = collect.(val[subidx]) + else + target = val[subidx] + end + @test tmp == target + end + end + end +end + +for sym in [[:a, bidx], (:a, bidx), [1, bidx], (1, bidx), + [bidx, :c], (bidx, :c), [bidx, cidx], (bidx, cidx)] + @test_throws ArgumentError getp(sys, sym) +end + +for (sym, val) in [ + ([:b, :c], [bval[end], cval[end]]), + ((:b, :c), (bval[end], cval[end])) ] - get = getp(sys, sym) + getter = getp(sys, sym) + @test is_indexer_timeseries(getter) == IndexerNotTimeseries() + @test_throws MixedParameterTimeseriesIndexError getter(fs) + @test getter(parameter_values(fs)) == val +end + +bval_state = [b_timeseries.u[searchsortedlast(b_timeseries.t, t)][] for t in fs.t] +cval_state = [c_timeseries.u[searchsortedlast(c_timeseries.t, t)][] for t in fs.t] +xval = getindex.(fs.u, 1) + +for (sym, val_is_timeseries, val, check_inference) in [ + (:a, false, aval, true), + ([:a, :d], false, [aval, dval], true), + ((:a, :d), false, (aval, dval), true), + (:b, true, bval_state, true), + ([:a, :b], true, vcat.(aval, bval_state), false), + ((:a, :b), true, tuple.(aval, bval_state), true), + ([:b, :c], true, vcat.(bval_state, cval_state), true), + ((:b, :c), true, tuple.(bval_state, cval_state), true), + ([:a, :b, :c], true, vcat.(aval, bval_state, cval_state), false), + ((:a, :b, :c), true, tuple.(aval, bval_state, cval_state), true), + ([:x, :b], true, vcat.(xval, bval_state), false), + ((:x, :b), true, tuple.(xval, bval_state), true), + ([:x, :b, :c], true, vcat.(xval, bval_state, cval_state), false), + ((:x, :b, :c), true, tuple.(xval, bval_state, cval_state), true), + ([:a, :b, :x], true, vcat.(aval, bval_state, xval), false), + ((:a, :b, :x), true, tuple.(aval, bval_state, xval), true), + (:(2b), true, 2 .* bval_state, true), + ([:x, :(2b), :(3c)], true, vcat.(xval, 2 .* bval_state, 3 .* cval_state), true), + ((:x, :(2b), :(3c)), true, tuple.(xval, 2 .* bval_state, 3 .* cval_state), true) +] + getter = getu(sys, sym) if check_inference - @inferred get(fs) + @inferred getter(fs) end - @test get(fs) == fs.ps[sym] - @test get(fs) == val + @test getter(fs) == val - for sub_inds in [ - :, 3:5, rand(Bool, length(ps)), rand(eachindex(ps)), rand(CartesianIndices(ps))] + for subidx in [ + 1, CartesianIndex(2), :, rand(Bool, length(fs.t)), rand(eachindex(fs.t), 3), 1:2] if check_inference - @inferred get(fs, sub_inds) + @inferred getter(fs, subidx) + end + target = if val_is_timeseries + val[subidx] + else + if fs.t[subidx] isa AbstractArray + len = length(fs.t[subidx]) + fill(val, len) + else + val + end end - @test get(fs, sub_inds) == fs.ps[sym, sub_inds] - @test get(fs, sub_inds) == arrval[sub_inds] + @test getter(fs, subidx) == target end end -ps = fs.p[2:2:end] -avals = getindex.(ps, 1) -bvals = getindex.(ps, 2) -cvals = getindex.(ps, 3) -for (sym, val, arrval) in [ - (:a, p[1], avals), - ((:b, :c), p[2:3], tuple.(bvals, cvals)), - ([:c, :a], p[[3, 1]], vcat.(cvals, avals)) +@test_throws ErrorException getp(sys, :not_a_param) + +struct FakeNoTimeSolution + sys::SymbolCache + u::Vector{Float64} + p::Vector{Float64} +end + +SymbolicIndexingInterface.state_values(fs::FakeNoTimeSolution) = fs.u +SymbolicIndexingInterface.symbolic_container(fs::FakeNoTimeSolution) = fs.sys +SymbolicIndexingInterface.parameter_values(fs::FakeNoTimeSolution) = fs.p +SymbolicIndexingInterface.parameter_values(fs::FakeNoTimeSolution, i) = fs.p[i] + +sys = SymbolCache([:x, :y, :z], [:a, :b, :c]) +u = [1.0, 2.0, 3.0] +p = [10.0, 20.0, 30.0] +fs = FakeNoTimeSolution(sys, u, p) + +for (sym, val, check_inference) in [ + (:a, p[1], true), + ([:a, :b], p[1:2], true), + ((:c, :b), (p[3], p[2]), true), + (:(a + b), p[1] + p[2], true), + ([:(a + b), :c], [p[1] + p[2], p[3]], true), + ((:(a + b), :c), (p[1] + p[2], p[3]), true) ] - get = getu(sys, sym) - @inferred get(fs) - @test get(fs) == arrval - for i in eachindex(ps) - @test get(fs, i) == arrval[i] + getter = getp(sys, sym) + if check_inference + @inferred getter(fs) + end + @test getter(fs) == val + + if sym isa Union{Array, Tuple} + buffer = zeros(length(sym)) + @inferred getter(buffer, fs) + @test buffer == collect(val) end end diff --git a/test/parameter_timeseries_collection_test.jl b/test/parameter_timeseries_collection_test.jl new file mode 100644 index 00000000..2b691a5c --- /dev/null +++ b/test/parameter_timeseries_collection_test.jl @@ -0,0 +1,50 @@ +using SymbolicIndexingInterface +using SymbolicIndexingInterface: parameter_timeseries +using Test + +struct MyDiffEqArray + t::Vector{Float64} + u::Vector{Vector{Float64}} +end +SymbolicIndexingInterface.current_time(mda::MyDiffEqArray) = mda.t +SymbolicIndexingInterface.state_values(mda::MyDiffEqArray) = mda.u +SymbolicIndexingInterface.is_timeseries(::Type{MyDiffEqArray}) = Timeseries() + +ps = ones(3) +@test_throws ArgumentError ParameterTimeseriesCollection((ones(3), 2ones(3)), ps) + +a_timeseries = MyDiffEqArray(collect(0:0.1:0.9), [[2.5i, sin(0.2i)] for i in 1:10]) +b_timeseries = MyDiffEqArray(collect(0:0.25:0.9), [[3.5i, log(1.3i)] for i in 1:4]) +c_timeseries = MyDiffEqArray(collect(0:0.17:0.90), [[4.3i] for i in 1:5]) +collection = (a_timeseries, b_timeseries, c_timeseries) +ptc = ParameterTimeseriesCollection(collection, ps) + +@test collect(eachindex(ptc)) == [1, 2, 3] +@test [x for x in ptc] == [a_timeseries, b_timeseries, c_timeseries] +@test length(ptc) == 3 +@test parent(ptc) === collection +@test parameter_values(ptc) === ps + +for i in 1:3 + @test ptc[i] === collection[i] + @test parameter_timeseries(ptc, i) == collection[i].t + for j in eachindex(collection[i].u[1]) + pti = ParameterTimeseriesIndex(i, j) + @test ptc[pti] == getindex.(collection[i].u, j) + for k in eachindex(collection[i].u) + rhs = collection[i].u[k][j] + @test ptc[pti, CartesianIndex(k)] == rhs + @test ptc[pti, k] == rhs + @test ptc[i, k] == collection[i].u[k] + @test ptc[i, k, j] == rhs + @test parameter_values(ptc, pti, k) == rhs + end + allidxs = eachindex(collection[i].u) + for subidx in [:, rand(allidxs, 3), rand(Bool, length(allidxs))] + rhs = getindex.(collection[i].u[subidx], j) + @test ptc[pti, subidx] == rhs + @test ptc[i, subidx, j] == rhs + @test parameter_values(ptc, pti, subidx) == rhs + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index a295448e..b91706cf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,6 +27,9 @@ if GROUP == "All" || GROUP == "Core" @safetestset "Fallback test" begin @time include("fallback_test.jl") end + @safetestset "ParameterTimeseriesCollection test" begin + @time include("parameter_timeseries_collection_test.jl") + end @safetestset "Parameter indexing test" begin @time include("parameter_indexing_test.jl") end @@ -42,6 +45,9 @@ if GROUP == "All" || GROUP == "Core" @safetestset "BatchedInterface test" begin @time include("batched_interface_test.jl") end + @safetestset "Simple Adjoints test" begin + @time include("simple_adjoints_test.jl") + end end if GROUP == "All" || GROUP == "Downstream" diff --git a/test/simple_adjoints_test.jl b/test/simple_adjoints_test.jl new file mode 100644 index 00000000..329fd104 --- /dev/null +++ b/test/simple_adjoints_test.jl @@ -0,0 +1,17 @@ +using SymbolicIndexingInterface +using Zygote + +sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) +pstate = ProblemState(; u = rand(3), p = rand(3), t = rand()) + +getter = getu(sys, :x) +@test Zygote.gradient(getter, pstate)[1].u == [1.0, 0.0, 0.0] + +getter = getu(sys, [:x, :z]) +@test Zygote.gradient(sum ∘ getter, pstate)[1].u == [1.0, 0.0, 1.0] + +getter = getu(sys, :a) +@test Zygote.gradient(getter, pstate)[1].p == [1.0, 0.0, 0.0] + +getter = getu(sys, [:a, :c]) +@test Zygote.gradient(sum ∘ getter, pstate)[1].p == [1.0, 0.0, 1.0] diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index 16c8a718..d8392636 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -13,6 +13,10 @@ SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p SymbolicIndexingInterface.current_time(fp::FakeIntegrator) = fp.t sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) + +@test_throws ErrorException getu(sys, :q) +@test_throws ErrorException setu(sys, :q) + u = [1.0, 2.0, 3.0] p = [11.0, 12.0, 13.0] t = 0.5 @@ -130,14 +134,18 @@ struct FakeSolution{S, U, P, T} end SymbolicIndexingInterface.is_timeseries(::Type{<:FakeSolution}) = Timeseries() +function SymbolicIndexingInterface.is_timeseries(::Type{<:FakeSolution{ + S, U, P, Nothing}}) where {S, U, P} + NotTimeseries() +end SymbolicIndexingInterface.symbolic_container(fp::FakeSolution) = fp.sys SymbolicIndexingInterface.state_values(fp::FakeSolution) = fp.u SymbolicIndexingInterface.parameter_values(fp::FakeSolution) = fp.p SymbolicIndexingInterface.current_time(fp::FakeSolution) = fp.t sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) -u = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] -t = [1.5, 2.0] +u = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]] +t = [1.5, 2.0, 2.3, 4.0] sol = FakeSolution(sys, u, p, t) xvals = getindex.(sol.u, 1) @@ -150,7 +158,7 @@ for (sym, ans, check_inference) in [(:x, xvals, true) (1, xvals, true) ([:x, :y], vcat.(xvals, yvals), true) (1:2, vcat.(xvals, yvals), true) - ([:x, 2], vcat.(xvals, yvals), false) + ([:x, 2], vcat.(xvals, yvals), true) ((:z, :y), tuple.(zvals, yvals), true) ((3, 2), tuple.(zvals, yvals), true) ([:x, [:y, :z]], @@ -186,7 +194,8 @@ for (sym, ans, check_inference) in [(:x, xvals, true) @inferred get(sol) end @test get(sol) == ans - for i in eachindex(u) + for i in [rand(eachindex(u)), CartesianIndex(1), :, + rand(Bool, length(u)), rand(eachindex(u), 3), 1:3] if check_inference @inferred get(sol, i) end @@ -204,6 +213,13 @@ for (sym, val, check_inference) in [ @inferred get(sol) end @test get(sol) == val + for i in [rand(eachindex(u)), CartesianIndex(1), :, + rand(Bool, length(u)), rand(eachindex(u), 3), 1:3] + if check_inference + @inferred get(sol, i) + end + @test get(sol, i) == val[i] + end end for (sym, val) in [(:a, p[1]) @@ -211,7 +227,41 @@ for (sym, val) in [(:a, p[1]) (:c, p[3]) ([:a, :b], p[1:2]) ((:c, :b), (p[3], p[2]))] - get = getu(fi, sym) - @inferred get(fi) - @test get(fi) == val + get = getu(sys, sym) + @inferred get(sol) + @test get(sol) == val +end + +sys = SymbolCache([:x, :y, :z], [:a, :b, :c]) +u = [1.0, 2.0, 3.0] +p = [10.0, 20.0, 30.0] +fs = FakeSolution(sys, u, p, nothing) +@test is_timeseries(fs) == NotTimeseries() + +for (sym, val, check_inference) in [ + (:x, u[1], true), + (1, u[1], true), + ([:x, :y], u[1:2], true), + ((:x, :y), Tuple(u[1:2]), true), + (1:2, u[1:2], true), + ([:x, 2], u[1:2], true), + ((:x, 2), Tuple(u[1:2]), true), + ([1, 2], u[1:2], true), + ((1, 2), Tuple(u[1:2]), true), + (:a, p[1], true), + ([:a, :b], p[1:2], true), + ((:a, :b), Tuple(p[1:2]), true), + ([:x, :a], [u[1], p[1]], false), + ((:x, :a), (u[1], p[1]), true), + ([1, :a], [u[1], p[1]], false), + ((1, :a), (u[1], p[1]), true), + (:(x + y + a + b), u[1] + u[2] + p[1] + p[2], true), + ([:(x + a), :(y + b)], [u[1] + p[1], u[2] + p[2]], true), + ((:(x + a), :(y + b)), (u[1] + p[1], u[2] + p[2]), true) +] + getter = getu(sys, sym) + if check_inference + @inferred getter(fs) + end + @test getter(fs) == val end diff --git a/test/symbol_cache_test.jl b/test/symbol_cache_test.jl index f8a861b3..d4a55ab6 100644 --- a/test/symbol_cache_test.jl +++ b/test/symbol_cache_test.jl @@ -16,10 +16,10 @@ sc = SymbolCache( @test is_time_dependent(sc) @test constant_structure(sc) @test variable_symbols(sc) == [:x, :y, :z] -@test parameter_symbols(sc) == [:a, :b] +@test sort(parameter_symbols(sc)) == [:a, :b] @test independent_variable_symbols(sc) == [:t] @test all_variable_symbols(sc) == [:x, :y, :z] -@test sort(all_symbols(sc)) == [:a, :b, :t, :x, :y, :z] +@test sort(sort(all_symbols(sc))) == [:a, :b, :t, :x, :y, :z] @test default_values(sc)[:x] == 1 @test default_values(sc)[:y] == :(2b) @test default_values(sc)[:b] == :(2a + x) @@ -45,6 +45,44 @@ obsfn4 = observed(sc, [:(x + a) :(y + b); :(x + y) :(a + b)]) obsfn5 = observed(sc, (:(x + a), :(y + b))) @test all(obsfn5(ones(3), 2ones(2), 3.0) .≈ (3.0, 3.0)) +@test_throws TypeError observed(sc, [:(x + a), 2]) +@test_throws TypeError observed(sc, (:(x + a), 2)) + +pobsfn1 = parameter_observed(sc, :(a + b + t)).observed_fn +@test pobsfn1(2ones(2), 3.0) == 7.0 +pobsfn2 = parameter_observed(sc, [:(a + b + t), :(a + t)]).observed_fn +@test pobsfn2(2ones(2), 3.0) == [7.0, 5.0] +buffer = zeros(2) +pobsfn2(buffer, 2ones(2), 3.0) +@test buffer == [7.0, 5.0] +pobsfn3 = parameter_observed(sc, (:(a + b + t), :(a + t))).observed_fn +@test pobsfn3(2ones(2), 3.0) == (7.0, 5.0) +buffer = zeros(2) +pobsfn3(buffer, 2ones(2), 3.0) +@test buffer == [7.0, 5.0] + +@test_throws TypeError parameter_observed(sc, [:(a + b), 4]) +@test_throws TypeError parameter_observed(sc, (:(a + b), 4)) + +sc = SymbolCache([:x, :y], [:a, :b, :c], :t; + timeseries_parameters = Dict( + :b => ParameterTimeseriesIndex(1, 1), :c => ParameterTimeseriesIndex(2, 1))) +@test parameter_observed(sc, :(a + c)).timeseries_idx == 2 +@test parameter_observed(sc, [:a, :c]).timeseries_idx == 2 +@test parameter_observed(sc, (:a, :c)).timeseries_idx == 2 +@test parameter_observed(sc, :(2a)).timeseries_idx === nothing +@test parameter_observed(sc, [:(2a), :(3a)]).timeseries_idx === nothing +@test parameter_observed(sc, (:(2a), :(3a))).timeseries_idx === nothing +@test parameter_observed(sc, [:b, :c]).timeseries_idx === nothing +@test parameter_observed(sc, (:b, :c)).timeseries_idx === nothing + +@test_throws ArgumentError SymbolCache([:x, :y], [:a, :b], :t; + timeseries_parameters = Dict(:c => ParameterTimeseriesIndex(1, 1))) +@test_throws TypeError SymbolCache( + [:x, :y], [:a, :c], :t; timeseries_parameters = Dict(:c => (1, 1))) +@test_nowarn SymbolCache([:x, :y], [:a, :c], :t; + timeseries_parameters = Dict(:c => ParameterTimeseriesIndex(1, 1))) + sc = SymbolCache([:x, :y], [:a, :b]) @test !is_time_dependent(sc) @test sort(all_symbols(sc)) == [:a, :b, :x, :y] @@ -54,6 +92,9 @@ obsfn = observed(sc, :(x + b)) # make sure the constructor works @test_nowarn SymbolCache([:x, :y]) +@test_throws ArgumentError SymbolCache( + [:x, :y], [:a, :b]; timeseries_parameters = Dict(:b => ParameterTimeseriesIndex(1, 1))) + sc = SymbolCache() @test all(.!is_variable.((sc,), [:x, :y, :a, :b, :t])) @test all(variable_index.((sc,), [:x, :y, :a, :b, :t]) .== nothing) @@ -77,6 +118,10 @@ sc = SymbolCache(nothing, nothing, :t) @test all_symbols(sc) == [:t] @test isempty(default_values(sc)) +sc = SymbolCache(nothing, nothing, [:t1, :t2, :t3]) +@test all(is_independent_variable.((sc,), [:t1, :t2, :t3])) +@test independent_variable_symbols(sc) == [:t1, :t2, :t3] + sc2 = copy(sc) @test sc.variables == sc2.variables @test sc.parameters == sc2.parameters