diff --git a/Project.toml b/Project.toml index 70629f84a1..1a5732bdd0 100644 --- a/Project.toml +++ b/Project.toml @@ -19,6 +19,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636" FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" @@ -79,6 +80,7 @@ DocStringExtensions = "0.7, 0.8, 0.9" DomainSets = "0.6, 0.7" DynamicQuantities = "^0.11.2, 0.12, 0.13" ExprTools = "0.1.10" +Expronicon = "0.8" FindFirstFunctions = "1" ForwardDiff = "0.10.3" FunctionWrappersWrappers = "0.1" diff --git a/docs/src/tutorials/SampledData.md b/docs/src/tutorials/SampledData.md index d2d9294bdb..e748673a1b 100644 --- a/docs/src/tutorials/SampledData.md +++ b/docs/src/tutorials/SampledData.md @@ -16,7 +16,7 @@ A clock can be seen as an *event source*, i.e., when the clock ticks, an event i - [`Hold`](@ref) - [`ShiftIndex`](@ref) -When a continuous-time variable `x` is sampled using `xd = Sample(x, dt)`, the result is a discrete-time variable `xd` that is defined and updated whenever the clock ticks. `xd` is *only defined when the clock ticks*, which it does with an interval of `dt`. If `dt` is unspecified, the tick rate of the clock associated with `xd` is inferred from the context in which `xd` appears. Any variable taking part in the same equation as `xd` is inferred to belong to the same *discrete partition* as `xd`, i.e., belonging to the same clock. A system may contain multiple different discrete-time partitions, each with a unique clock. This allows for modeling of multi-rate systems and discrete-time processes located on different computers etc. +When a continuous-time variable `x` is sampled using `xd = Sample(dt)(x)`, the result is a discrete-time variable `xd` that is defined and updated whenever the clock ticks. `xd` is *only defined when the clock ticks*, which it does with an interval of `dt`. If `dt` is unspecified, the tick rate of the clock associated with `xd` is inferred from the context in which `xd` appears. Any variable taking part in the same equation as `xd` is inferred to belong to the same *discrete partition* as `xd`, i.e., belonging to the same clock. A system may contain multiple different discrete-time partitions, each with a unique clock. This allows for modeling of multi-rate systems and discrete-time processes located on different computers etc. To make a discrete-time variable available to the continuous partition, the [`Hold`](@ref) operator is used. `xc = Hold(xd)` creates a continuous-time variable `xc` that is updated whenever the clock associated with `xd` ticks, and holds its value constant between ticks. @@ -34,7 +34,7 @@ using ModelingToolkit using ModelingToolkit: t_nounits as t @variables x(t) y(t) u(t) dt = 0.1 # Sample interval -clock = Clock(t, dt) # A periodic clock with tick rate dt +clock = Clock(dt) # A periodic clock with tick rate dt k = ShiftIndex(clock) eqs = [ @@ -98,7 +98,7 @@ may thus be modeled as ```julia @variables t y(t) [description = "Output"] u(t) [description = "Input"] -k = ShiftIndex(Clock(t, dt)) +k = ShiftIndex(Clock(dt)) eqs = [ a2 * y(k) + a1 * y(k - 1) + a0 * y(k - 2) ~ b2 * u(k) + b1 * u(k - 1) + b0 * u(k - 2) ] @@ -127,10 +127,10 @@ requires specification of the initial condition for both `x(k-1)` and `x(k-2)`. Multi-rate systems are easy to model using multiple different clocks. The following set of equations is valid, and defines *two different discrete-time partitions*, each with its own clock: ```julia -yd1 ~ Sample(t, dt1)(y) -ud1 ~ kp * (Sample(t, dt1)(r) - yd1) -yd2 ~ Sample(t, dt2)(y) -ud2 ~ kp * (Sample(t, dt2)(r) - yd2) +yd1 ~ Sample(dt1)(y) +ud1 ~ kp * (Sample(dt1)(r) - yd1) +yd2 ~ Sample(dt2)(y) +ud2 ~ kp * (Sample(dt2)(r) - yd2) ``` `yd1` and `ud1` belong to the same clock which ticks with an interval of `dt1`, while `yd2` and `ud2` belong to a different clock which ticks with an interval of `dt2`. The two clocks are *not synchronized*, i.e., they are not *guaranteed* to tick at the same point in time, even if one tick interval is a rational multiple of the other. Mechanisms for synchronization of clocks are not yet implemented. @@ -147,7 +147,7 @@ using ModelingToolkit: t_nounits as t using ModelingToolkit: D_nounits as D dt = 0.5 # Sample interval @variables r(t) -clock = Clock(t, dt) +clock = Clock(dt) k = ShiftIndex(clock) function plant(; name) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 5be9e6dbb2..bdd7ed1ba8 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -42,7 +42,8 @@ using SciMLStructures using Compat using AbstractTrees using DiffEqBase, SciMLBase, ForwardDiff -using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap +using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap, TimeDomain, + PeriodicClock, Clock, SolverStepClock, Continuous using Distributed import JuliaFormatter using MLStyle @@ -270,6 +271,6 @@ export debug_system #export has_discrete_domain, has_continuous_domain #export is_discrete_domain, is_continuous_domain, is_hybrid_domain export Sample, Hold, Shift, ShiftIndex, sampletime, SampleTime -export Clock #, InferredDiscrete, +export Clock, SolverStepClock, TimeDomain end # module diff --git a/src/clock.jl b/src/clock.jl index 5df6cfb022..13924d562f 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -1,13 +1,26 @@ -abstract type TimeDomain end -abstract type AbstractDiscrete <: TimeDomain end +module InferredClock -Base.Broadcast.broadcastable(d::TimeDomain) = Ref(d) +export InferredTimeDomain -struct Inferred <: TimeDomain end -struct InferredDiscrete <: AbstractDiscrete end -struct Continuous <: TimeDomain end +using Expronicon.ADT: @adt, @match +using SciMLBase: TimeDomain -Symbolics.option_to_metadata_type(::Val{:timedomain}) = TimeDomain +@adt InferredTimeDomain begin + Inferred + InferredDiscrete +end + +Base.Broadcast.broadcastable(x::InferredTimeDomain) = Ref(x) + +end + +using .InferredClock + +struct VariableTimeDomain end +Symbolics.option_to_metadata_type(::Val{:timedomain}) = VariableTimeDomain + +is_concrete_time_domain(::TimeDomain) = true +is_concrete_time_domain(_) = false """ is_continuous_domain(x) @@ -16,7 +29,7 @@ true if `x` contains only continuous-domain signals. See also [`has_continuous_domain`](@ref) """ function is_continuous_domain(x) - issym(x) && return getmetadata(x, TimeDomain, false) isa Continuous + issym(x) && return getmetadata(x, VariableTimeDomain, false) == Continuous !has_discrete_domain(x) && has_continuous_domain(x) end @@ -24,7 +37,7 @@ function get_time_domain(x) if iscall(x) && operation(x) isa Operator output_timedomain(x) else - getmetadata(x, TimeDomain, nothing) + getmetadata(x, VariableTimeDomain, nothing) end end get_time_domain(x::Num) = get_time_domain(value(x)) @@ -37,14 +50,14 @@ Determine if variable `x` has a time-domain attributed to it. function has_time_domain(x::Symbolic) # getmetadata(x, Continuous, nothing) !== nothing || # getmetadata(x, Discrete, nothing) !== nothing - getmetadata(x, TimeDomain, nothing) !== nothing + getmetadata(x, VariableTimeDomain, nothing) !== nothing end has_time_domain(x::Num) = has_time_domain(value(x)) has_time_domain(x) = false for op in [Differential] - @eval input_timedomain(::$op, arg = nothing) = Continuous() - @eval output_timedomain(::$op, arg = nothing) = Continuous() + @eval input_timedomain(::$op, arg = nothing) = Continuous + @eval output_timedomain(::$op, arg = nothing) = Continuous end """ @@ -83,12 +96,17 @@ true if `x` contains only discrete-domain signals. See also [`has_discrete_domain`](@ref) """ function is_discrete_domain(x) - if hasmetadata(x, TimeDomain) || issym(x) - return getmetadata(x, TimeDomain, false) isa AbstractDiscrete + if hasmetadata(x, VariableTimeDomain) || issym(x) + return is_discrete_time_domain(getmetadata(x, VariableTimeDomain, false)) end !has_discrete_domain(x) && has_continuous_domain(x) end +sampletime(c) = @match c begin + PeriodicClock(dt, _...) => dt + _ => nothing +end + struct ClockInferenceException <: Exception msg::Any end @@ -97,57 +115,7 @@ function Base.showerror(io::IO, cie::ClockInferenceException) print(io, "ClockInferenceException: ", cie.msg) end -abstract type AbstractClock <: AbstractDiscrete end - -""" - Clock <: AbstractClock - Clock([t]; dt) - -The default periodic clock with independent variables `t` and tick interval `dt`. -If `dt` is left unspecified, it will be inferred (if possible). -""" -struct Clock <: AbstractClock - "Independent variable" - t::Union{Nothing, Symbolic} - "Period" - dt::Union{Nothing, Float64} - Clock(t::Union{Num, Symbolic}, dt = nothing) = new(value(t), dt) - Clock(t::Nothing, dt = nothing) = new(t, dt) -end -Clock(dt::Real) = Clock(nothing, dt) -Clock() = Clock(nothing, nothing) - -sampletime(c) = isdefined(c, :dt) ? c.dt : nothing -Base.hash(c::Clock, seed::UInt) = hash(c.dt, seed ⊻ 0x953d7a9a18874b90) -function Base.:(==)(c1::Clock, c2::Clock) - ((c1.t === nothing || c2.t === nothing) || isequal(c1.t, c2.t)) && c1.dt == c2.dt -end - -is_concrete_time_domain(x) = x isa Union{AbstractClock, Continuous} - -""" - SolverStepClock <: AbstractClock - SolverStepClock() - SolverStepClock(t) - -A clock that ticks at each solver step (sometimes referred to as "continuous sample time"). This clock **does generally not have equidistant tick intervals**, instead, the tick interval depends on the adaptive step-size selection of the continuous solver, as well as any continuous event handling. If adaptivity of the solver is turned off and there are no continuous events, the tick interval will be given by the fixed solver time step `dt`. - -Due to possibly non-equidistant tick intervals, this clock should typically not be used with discrete-time systems that assume a fixed sample time, such as PID controllers and digital filters. -""" -struct SolverStepClock <: AbstractClock - "Independent variable" - t::Union{Nothing, Symbolic} - "Period" - SolverStepClock(t::Union{Num, Symbolic}) = new(value(t)) -end -SolverStepClock() = SolverStepClock(nothing) - -Base.hash(c::SolverStepClock, seed::UInt) = seed ⊻ 0x953d7b9a18874b91 -function Base.:(==)(c1::SolverStepClock, c2::SolverStepClock) - ((c1.t === nothing || c2.t === nothing) || isequal(c1.t, c2.t)) -end - -struct IntegerSequence <: AbstractClock +struct IntegerSequence t::Union{Nothing, Symbolic} IntegerSequence(t::Union{Num, Symbolic}) = new(value(t)) end diff --git a/src/discretedomain.jl b/src/discretedomain.jl index 34f628a8b3..9a1332bf81 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -85,8 +85,8 @@ $(TYPEDEF) Represents a sample operator. A discrete-time signal is created by sampling a continuous-time signal. # Constructors -`Sample(clock::TimeDomain = InferredDiscrete())` -`Sample([t], dt::Real)` +`Sample(clock::Union{TimeDomain, InferredTimeDomain} = InferredDiscrete)` +`Sample(dt::Real)` `Sample(x::Num)`, with a single argument, is shorthand for `Sample()(x)`. @@ -100,16 +100,23 @@ julia> using Symbolics julia> @variables t; -julia> Δ = Sample(t, 0.01) +julia> Δ = Sample(0.01) (::Sample) (generic function with 2 methods) ``` """ struct Sample <: Operator clock::Any - Sample(clock::TimeDomain = InferredDiscrete()) = new(clock) - Sample(t, dt::Real) = new(Clock(t, dt)) + Sample(clock::Union{TimeDomain, InferredTimeDomain} = InferredDiscrete) = new(clock) +end + +function Sample(arg::Real) + arg = unwrap(arg) + if symbolic_type(arg) == NotSymbolic() + Sample(Clock(arg)) + else + Sample()(arg) + end end -Sample(x) = Sample()(x) (D::Sample)(x) = Term{symtype(x)}(D, Any[x]) (D::Sample)(x::Num) = Num(D(value(x))) SymbolicUtils.promote_symtype(::Sample, x) = x @@ -178,11 +185,14 @@ Shift(t, 1)(x(t)) ``` """ struct ShiftIndex - clock::TimeDomain + clock::Union{InferredTimeDomain, TimeDomain, IntegerSequence} steps::Int - ShiftIndex(clock::TimeDomain = Inferred(), steps::Int = 0) = new(clock, steps) - ShiftIndex(t::Num, dt::Real, steps::Int = 0) = new(Clock(t, dt), steps) - ShiftIndex(t::Num, steps::Int = 0) = new(IntegerSequence(t), steps) + function ShiftIndex( + clock::Union{TimeDomain, InferredTimeDomain} = Inferred, steps::Int = 0) + new(clock, steps) + end + ShiftIndex(dt::Real, steps::Int = 0) = new(Clock(dt), steps) + ShiftIndex(t::Num, steps::Int = 0) = new(IntegerSequence(), steps) end function (xn::Num)(k::ShiftIndex) @@ -206,7 +216,7 @@ function (xn::Num)(k::ShiftIndex) # xn = Sample(t, clock)(xn) # end # QUESTION: should we return a variable with time domain set to k.clock? - xn = setmetadata(xn, TimeDomain, k.clock) + xn = setmetadata(xn, VariableTimeDomain, k.clock) if steps == 0 return xn # x(k) needs no shift operator if the step of k is 0 end @@ -219,37 +229,37 @@ Base.:-(k::ShiftIndex, i::Int) = k + (-i) """ input_timedomain(op::Operator) -Return the time-domain type (`Continuous()` or `Discrete()`) that `op` operates on. +Return the time-domain type (`Continuous` or `InferredDiscrete`) that `op` operates on. """ function input_timedomain(s::Shift, arg = nothing) if has_time_domain(arg) return get_time_domain(arg) end - InferredDiscrete() + InferredDiscrete end """ output_timedomain(op::Operator) -Return the time-domain type (`Continuous()` or `Discrete()`) that `op` results in. +Return the time-domain type (`Continuous` or `InferredDiscrete`) that `op` results in. """ function output_timedomain(s::Shift, arg = nothing) if has_time_domain(arg) return get_time_domain(arg) end - InferredDiscrete() + InferredDiscrete end -input_timedomain(::Sample, arg = nothing) = Continuous() +input_timedomain(::Sample, arg = nothing) = Continuous output_timedomain(s::Sample, arg = nothing) = s.clock function input_timedomain(h::Hold, arg = nothing) if has_time_domain(arg) return get_time_domain(arg) end - InferredDiscrete() # the Hold accepts any discrete + InferredDiscrete # the Hold accepts any discrete end -output_timedomain(::Hold, arg = nothing) = Continuous() +output_timedomain(::Hold, arg = nothing) = Continuous sampletime(op::Sample, arg = nothing) = sampletime(op.clock) sampletime(op::ShiftIndex, arg = nothing) = sampletime(op.clock) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 98e9e9586c..9a0ae5a580 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -564,8 +564,17 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym) return obsfn end +function has_observed_with_lhs(sys, sym) + has_observed(sys) || return false + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing + return any(isequal(sym), ic.observed_syms) + else + return any(isequal(sym), [eq.lhs for eq in observed(sys)]) + end +end + function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym) - if is_variable(sys, sym) + if is_variable(sys, sym) || is_independent_variable(sys, sym) push!(ts_idxs, ContinuousTimeseries()) elseif is_timeseries_parameter(sys, sym) push!(ts_idxs, timeseries_parameter_index(sys, sym).timeseries_idx) @@ -578,17 +587,33 @@ for traitT in [ ] @eval function _all_ts_idxs!(ts_idxs, ::$traitT, sys, sym) allsyms = vars(sym; op = Symbolics.Operator) - foreach(allsyms) do s - _all_ts_idxs!(ts_idxs, sys, s) + for s in allsyms + s = unwrap(s) + if is_variable(sys, s) || is_independent_variable(sys, s) || + has_observed_with_lhs(sys, s) + push!(ts_idxs, ContinuousTimeseries()) + elseif is_timeseries_parameter(sys, s) + push!(ts_idxs, timeseries_parameter_index(sys, s).timeseries_idx) + end end end end +function _all_ts_idxs!(ts_idxs, ::ScalarSymbolic, sys, sym::Symbol) + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing + return _all_ts_idxs!(ts_idxs, sys, ic.symbol_to_variable[sym]) + elseif is_variable(sys, sym) || is_independent_variable(sys, sym) || + any(isequal(sym), [getname(eq.lhs) for eq in observed(sys)]) + push!(ts_idxs, ContinuousTimeseries()) + elseif is_timeseries_parameter(sys, sym) + push!(ts_idxs, timeseries_parameter_index(sys, s).timeseries_idx) + end +end function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym::AbstractArray) - foreach(sym) do s + for s in sym _all_ts_idxs!(ts_idxs, sys, s) end end -_all_ts_idxs!(ts_idxs, sys, sym) = _all_ts_idxs!(ts_idxs, NotSymbolic(), sys, sym) +_all_ts_idxs!(ts_idxs, sys, sym) = _all_ts_idxs!(ts_idxs, symbolic_type(sym), sys, sym) function SymbolicIndexingInterface.get_all_timeseries_indexes(sys::AbstractSystem, sym) if !is_time_dependent(sys) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index c6a8464536..dfdef69034 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -8,8 +8,8 @@ end function ClockInference(ts::TransformationState) @unpack structure = ts @unpack graph = structure - eq_domain = TimeDomain[Continuous() for _ in 1:nsrcs(graph)] - var_domain = TimeDomain[Continuous() for _ in 1:ndsts(graph)] + eq_domain = TimeDomain[Continuous for _ in 1:nsrcs(graph)] + var_domain = TimeDomain[Continuous for _ in 1:ndsts(graph)] inferred = BitSet() for (i, v) in enumerate(get_fullvars(ts)) d = get_time_domain(v) @@ -151,7 +151,7 @@ function split_system(ci::ClockInference{S}) where {S} get!(clock_to_id, d) do cid = (cid_counter[] += 1) push!(id_to_clock, d) - if d isa Continuous + if d == Continuous continuous_id[] = cid end cid diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 6c60fb5991..123e4c1699 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -840,7 +840,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; # ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first if sys isa ODESystem && build_initializeprob && (((implicit_dae || !isempty(missingvars)) && - all(isequal(Continuous()), ci.var_domain) && + all(==(Continuous), ci.var_domain) && ModelingToolkit.get_tearing_state(sys) !== nothing) || !isempty(initialization_equations(sys))) && t !== nothing if eltype(u0map) <: Number @@ -1026,14 +1026,12 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = affects, clocks = ModelingToolkit.generate_discrete_affect( sys, dss...; eval_expression, eval_module) discrete_cbs = map(affects, clocks) do affect, clock - if clock isa Clock - PeriodicCallback(affect, clock.dt; + @match clock begin + PeriodicClock(dt, _...) => PeriodicCallback(affect, dt; final_affect = true, initial_affect = true) - elseif clock isa SolverStepClock - DiscreteCallback(Returns(true), affect, + &SolverStepClock => DiscreteCallback(Returns(true), affect, initialize = (c, u, t, integrator) -> affect(integrator)) - else - error("$clock is not a supported clock type.") + _ => error("$clock is not a supported clock type.") end end if cbs === nothing @@ -1127,14 +1125,15 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [], u0 = h(p, tspan[1]) cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...) if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing - affects, clocks, svs = ModelingToolkit.generate_discrete_affect( + affects, clocks = ModelingToolkit.generate_discrete_affect( sys, dss...; eval_expression, eval_module) - discrete_cbs = map(affects, clocks, svs) do affect, clock, sv - if clock isa Clock - PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt; + discrete_cbs = map(affects, clocks) do affect, clock + @match clock begin + PeriodicClock(dt, _...) => PeriodicCallback(affect, dt; final_affect = true, initial_affect = true) - else - error("$clock is not a supported clock type.") + &SolverStepClock => DiscreteCallback(Returns(true), affect, + initialize = (c, u, t, integrator) -> affect(integrator)) + _ => error("$clock is not a supported clock type.") end end if cbs === nothing @@ -1189,14 +1188,15 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [], u0 = h(p, tspan[1]) cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...) if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing - affects, clocks, svs = ModelingToolkit.generate_discrete_affect( + affects, clocks = ModelingToolkit.generate_discrete_affect( sys, dss...; eval_expression, eval_module) - discrete_cbs = map(affects, clocks, svs) do affect, clock, sv - if clock isa Clock - PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt; + discrete_cbs = map(affects, clocks) do affect, clock + @match clock begin + PeriodicClock(dt, _...) => PeriodicCallback(affect, dt; final_affect = true, initial_affect = true) - else - error("$clock is not a supported clock type.") + &SolverStepClock => DiscreteCallback(Returns(true), affect, + initialize = (c, u, t, integrator) -> affect(integrator)) + _ => error("$clock is not a supported clock type.") end end if cbs === nothing diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 13fb7adef2..f992cd4907 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -32,6 +32,7 @@ struct IndexCache constant_idx::ParamIndexMap dependent_idx::ParamIndexMap nonnumeric_idx::ParamIndexMap + observed_syms::Set{Union{Symbol, BasicSymbolic}} discrete_buffer_sizes::Vector{Vector{BufferTemplate}} tunable_buffer_sizes::Vector{BufferTemplate} constant_buffer_sizes::Vector{BufferTemplate} @@ -48,16 +49,21 @@ function IndexCache(sys::AbstractSystem) let idx = 1 for sym in unks usym = unwrap(sym) + rsym = renamespace(sys, usym) sym_idx = if Symbolics.isarraysymbolic(sym) reshape(idx:(idx + length(sym) - 1), size(sym)) else idx end unk_idxs[usym] = sym_idx + unk_idxs[rsym] = sym_idx if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) name = getname(usym) + rname = getname(rsym) unk_idxs[name] = sym_idx + unk_idxs[rname] = sym_idx symbol_to_variable[name] = sym + symbol_to_variable[rname] = sym end idx += length(sym) end @@ -71,18 +77,41 @@ function IndexCache(sys::AbstractSystem) if idxs == idxs[begin]:idxs[end] idxs = reshape(idxs[begin]:idxs[end], size(idxs)) end + rsym = renamespace(sys, arrsym) unk_idxs[arrsym] = idxs + unk_idxs[rsym] = idxs if hasname(arrsym) name = getname(arrsym) + rname = getname(rsym) unk_idxs[name] = idxs + unk_idxs[rname] = idxs symbol_to_variable[name] = arrsym + symbol_to_variable[rname] = arrsym end end end + observed_syms = Set{Union{Symbol, BasicSymbolic}}() for eq in observed(sys) - if symbolic_type(eq.lhs) != NotSymbolic() && hasname(eq.lhs) - symbol_to_variable[getname(eq.lhs)] = eq.lhs + if symbolic_type(eq.lhs) != NotSymbolic() + sym = eq.lhs + ttsym = default_toterm(sym) + rsym = renamespace(sys, sym) + rttsym = renamespace(sys, ttsym) + push!(observed_syms, sym) + push!(observed_syms, ttsym) + push!(observed_syms, rsym) + push!(observed_syms, rttsym) + if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) + symbol_to_variable[getname(sym)] = eq.lhs + symbol_to_variable[getname(ttsym)] = eq.lhs + symbol_to_variable[getname(rsym)] = eq.lhs + symbol_to_variable[getname(rttsym)] = eq.lhs + push!(observed_syms, getname(sym)) + push!(observed_syms, getname(ttsym)) + push!(observed_syms, getname(rsym)) + push!(observed_syms, getname(rttsym)) + end end end @@ -109,26 +138,40 @@ function IndexCache(sys::AbstractSystem) for inp in inps inp = unwrap(inp) + ttinp = default_toterm(inp) + rinp = renamespace(sys, inp) + rttinp = renamespace(sys, ttinp) is_parameter(sys, inp) || error("Discrete subsystem $i input $inp is not a parameter") disc_clocks[inp] = i - disc_clocks[default_toterm(inp)] = i + disc_clocks[ttinp] = i + disc_clocks[rinp] = i + disc_clocks[rttinp] = i if hasname(inp) && (!iscall(inp) || operation(inp) !== getindex) disc_clocks[getname(inp)] = i - disc_clocks[default_toterm(inp)] = i + disc_clocks[getname(ttinp)] = i + disc_clocks[getname(rinp)] = i + disc_clocks[getname(rttinp)] = i end insert_by_type!(disc_buffers[i], inp) end for sym in unknowns(disc_sys) sym = unwrap(sym) + ttsym = default_toterm(sym) + rsym = renamespace(sys, sym) + rttsym = renamespace(sys, ttsym) is_parameter(sys, sym) || error("Discrete subsystem $i unknown $sym is not a parameter") disc_clocks[sym] = i - disc_clocks[default_toterm(sym)] = i + disc_clocks[ttsym] = i + disc_clocks[rsym] = i + disc_clocks[rttsym] = i if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) disc_clocks[getname(sym)] = i - disc_clocks[getname(default_toterm(sym))] = i + disc_clocks[getname(ttsym)] = i + disc_clocks[getname(rsym)] = i + disc_clocks[getname(rttsym)] = i end insert_by_type!(disc_buffers[i], sym) end @@ -138,21 +181,31 @@ function IndexCache(sys::AbstractSystem) # FIXME: This shouldn't be necessary eq.rhs === -0.0 && continue sym = eq.lhs + ttsym = default_toterm(sym) + rsym = renamespace(sys, sym) + rttsym = renamespace(sys, ttsym) if iscall(sym) && operation(sym) == Shift(t, 1) sym = only(arguments(sym)) end disc_clocks[sym] = i - disc_clocks[sym] = i - disc_clocks[default_toterm(sym)] = i + disc_clocks[ttsym] = i + disc_clocks[rsym] = i + disc_clocks[rttsym] = i if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) disc_clocks[getname(sym)] = i - disc_clocks[getname(default_toterm(sym))] = i + disc_clocks[getname(ttsym)] = i + disc_clocks[getname(rsym)] = i + disc_clocks[getname(rttsym)] = i end end end for par in inputs[continuous_id] is_parameter(sys, par) || error("Discrete subsystem input is not a parameter") + par = unwrap(par) + ttpar = default_toterm(par) + rpar = renamespace(sys, par) + rttpar = renamespace(sys, ttpar) iscall(par) && operation(par) isa Hold || error("Continuous subsystem input is not a Hold") if haskey(disc_clocks, par) @@ -163,6 +216,9 @@ function IndexCache(sys::AbstractSystem) haskey(disc_clocks, sym) || error("Variable $par not part of a discrete subsystem") disc_clocks[par] = disc_clocks[sym] + disc_clocks[ttpar] = disc_clocks[sym] + disc_clocks[rpar] = disc_clocks[sym] + disc_clocks[rttpar] = disc_clocks[sym] insert_by_type!(disc_buffers[disc_clocks[sym]], par) end end @@ -172,13 +228,21 @@ function IndexCache(sys::AbstractSystem) for affect in affs if affect isa Equation is_parameter(sys, affect.lhs) || continue - - disc_clocks[affect.lhs] = user_affect_clock - disc_clocks[default_toterm(affect.lhs)] = user_affect_clock - if hasname(affect.lhs) && - (!iscall(affect.lhs) || operation(affect.lhs) !== getindex) - disc_clocks[getname(affect.lhs)] = user_affect_clock - disc_clocks[getname(default_toterm(affect.lhs))] = user_affect_clock + sym = affect.lhs + ttsym = default_toterm(sym) + rsym = renamespace(sys, sym) + rttsym = renamespace(sys, ttsym) + + disc_clocks[sym] = user_affect_clock + disc_clocks[ttsym] = user_affect_clock + disc_clocks[rsym] = user_affect_clock + disc_clocks[rttsym] = user_affect_clock + if hasname(sym) && + (!iscall(sym) || operation(sym) !== getindex) + disc_clocks[getname(sym)] = user_affect_clock + disc_clocks[getname(ttsym)] = user_affect_clock + disc_clocks[getname(rsym)] = user_affect_clock + disc_clocks[getname(rttsym)] = user_affect_clock end buffer = get!(disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}()) insert_by_type!(buffer, affect.lhs) @@ -188,11 +252,18 @@ function IndexCache(sys::AbstractSystem) is_parameter(sys, disc) || error("Expected discrete variable $disc in callback to be a parameter") disc = unwrap(disc) + ttdisc = default_toterm(disc) + rdisc = renamespace(sys, disc) + rttdisc = renamespace(sys, ttdisc) disc_clocks[disc] = user_affect_clock - disc_clocks[default_toterm(disc)] = user_affect_clock + disc_clocks[ttdisc] = user_affect_clock + disc_clocks[rdisc] = user_affect_clock + disc_clocks[rttdisc] = user_affect_clock if hasname(disc) && (!iscall(disc) || operation(disc) !== getindex) disc_clocks[getname(disc)] = user_affect_clock - disc_clocks[getname(default_toterm(disc))] = user_affect_clock + disc_clocks[getname(ttdisc)] = user_affect_clock + disc_clocks[getname(rdisc)] = user_affect_clock + disc_clocks[getname(rttdisc)] = user_affect_clock end buffer = get!( disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}()) @@ -267,13 +338,22 @@ function IndexCache(sys::AbstractSystem) buffer_sizes = BufferTemplate[] for (i, (T, buf)) in enumerate(buffers) for (j, p) in enumerate(buf) + ttp = default_toterm(p) + rp = renamespace(sys, p) + rttp = renamespace(sys, ttp) idxs[p] = (i, j) - idxs[default_toterm(p)] = (i, j) + idxs[ttp] = (i, j) + idxs[rp] = (i, j) + idxs[rttp] = (i, j) if hasname(p) && (!iscall(p) || operation(p) !== getindex) idxs[getname(p)] = (i, j) + idxs[getname(ttp)] = (i, j) + idxs[getname(rp)] = (i, j) + idxs[getname(rttp)] = (i, j) symbol_to_variable[getname(p)] = p - idxs[getname(default_toterm(p))] = (i, j) - symbol_to_variable[getname(default_toterm(p))] = p + symbol_to_variable[getname(ttp)] = p + symbol_to_variable[getname(rp)] = p + symbol_to_variable[getname(rttp)] = p end end push!(buffer_sizes, BufferTemplate(T, length(buf))) @@ -293,6 +373,7 @@ function IndexCache(sys::AbstractSystem) const_idxs, dependent_idxs, nonnumeric_idxs, + observed_syms, disc_buffer_sizes, tunable_buffer_sizes, const_buffer_sizes, @@ -306,6 +387,10 @@ function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym) return check_index_map(ic.unknown_idx, sym) !== nothing end +function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym::Symbol) + return check_index_map(ic.unknown_idx, sym) !== nothing +end + function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym) return check_index_map(ic.unknown_idx, sym) end diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 919c5d6760..81fb928a86 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -581,16 +581,15 @@ function SciMLBase.create_parameter_timeseries_collection( for (i, partition) in enumerate(ps.discrete) clock = id_to_clock[i] - if clock isa Clock - ts = tspan[1]:(clock.dt):tspan[2] - push!(buffers, DiffEqArray(NestedGetIndex{typeof(partition)}[], ts, (1, 1))) - elseif clock isa SolverStepClock - push!(buffers, + @match clock begin + PeriodicClock(dt, _...) => begin + ts = tspan[1]:(dt):tspan[2] + push!(buffers, DiffEqArray(NestedGetIndex{typeof(partition)}[], ts, (1, 1))) + end + &SolverStepClock => push!(buffers, DiffEqArray(NestedGetIndex{typeof(partition)}[], eltype(tspan)[], (1, 1))) - elseif clock isa Continuous - continue - else - error("Unhandled clock $clock") + &Continuous => continue + _ => error("Unhandled clock $clock") end end diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index ff26552c79..2cbf820d0d 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -8,6 +8,7 @@ import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten, isparameter, isconstant, independent_variables, SparseMatrixCLIL, AbstractSystem, equations, isirreducible, input_timedomain, TimeDomain, + InferredTimeDomain, VariableType, getvariabletype, has_equations, ODESystem using ..BipartiteGraphs import ..BipartiteGraphs: invview, complete @@ -331,7 +332,7 @@ function TearingState(sys; quick_cancel = false, check = true) !isdifferential(var) && (it = input_timedomain(var)) !== nothing set_incidence = false var = only(arguments(var)) - var = setmetadata(var, TimeDomain, it) + var = setmetadata(var, VariableTimeDomain, it) @goto ANOTHER_VAR end end @@ -660,7 +661,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals @set! sys.defaults = merge(ModelingToolkit.defaults(sys), Dict(v => 0.0 for v in Iterators.flatten(inputs))) end - ps = [setmetadata(sym, TimeDomain, get(time_domains, sym, Continuous())) + ps = [setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous)) for sym in get_ps(sys)] @set! sys.ps = ps else diff --git a/test/clock.jl b/test/clock.jl index 69b7c30c50..fc9857b257 100644 --- a/test/clock.jl +++ b/test/clock.jl @@ -14,7 +14,7 @@ dt = 0.1 @parameters kp # u(n + 1) := f(u(n)) -eqs = [yd ~ Sample(t, dt)(y) +eqs = [yd ~ Sample(dt)(y) ud ~ kp * (r - yd) r ~ 1.0 @@ -70,35 +70,35 @@ sss, = ModelingToolkit._structural_simplify!( @test equations(sss) == [D(x) ~ u - x] sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[1]), (inputs[1], ())) @test isempty(equations(sss)) -d = Clock(t, dt) +d = Clock(dt) k = ShiftIndex(d) @test observed(sss) == [yd(k + 1) ~ Sample(t, dt)(y); r(k + 1) ~ 1.0; ud(k + 1) ~ kp * (r(k + 1) - yd(k + 1))] -d = Clock(t, dt) +d = Clock(dt) # Note that TearingState reorders the equations -@test eqmap[1] == Continuous() +@test eqmap[1] == Continuous @test eqmap[2] == d @test eqmap[3] == d @test eqmap[4] == d -@test eqmap[5] == Continuous() -@test eqmap[6] == Continuous() +@test eqmap[5] == Continuous +@test eqmap[6] == Continuous @test varmap[yd] == d @test varmap[ud] == d @test varmap[r] == d -@test varmap[x] == Continuous() -@test varmap[y] == Continuous() -@test varmap[u] == Continuous() +@test varmap[x] == Continuous +@test varmap[y] == Continuous +@test varmap[u] == Continuous @info "Testing shift normalization" dt = 0.1 @variables x(t) y(t) u(t) yd(t) ud(t) @parameters kp -d = Clock(t, dt) +d = Clock(dt) k = ShiftIndex(d) -eqs = [yd ~ Sample(t, dt)(y) +eqs = [yd ~ Sample(dt)(y) ud ~ kp * yd + ud(k - 2) # plant (time continuous part) @@ -171,10 +171,10 @@ eqs = [yd ~ Sample(t, dt)(y) eqs = [ # controller (time discrete part `dt=0.1`) - yd1 ~ Sample(t, dt)(y) - ud1 ~ kp * (Sample(t, dt)(r) - yd1) - yd2 ~ Sample(t, dt2)(y) - ud2 ~ kp * (Sample(t, dt2)(r) - yd2) + yd1 ~ Sample(dt)(y) + ud1 ~ kp * (Sample(dt)(r) - yd1) + yd2 ~ Sample(dt2)(y) + ud2 ~ kp * (Sample(dt2)(r) - yd2) # plant (time continuous part) u ~ Hold(ud1) + Hold(ud2) @@ -183,8 +183,8 @@ eqs = [yd ~ Sample(t, dt)(y) @named sys = ODESystem(eqs, t) ci, varmap = infer_clocks(sys) - d = Clock(t, dt) - d2 = Clock(t, dt2) + d = Clock(dt) + d2 = Clock(dt2) #@test get_eq_domain(eqs[1]) == d #@test get_eq_domain(eqs[3]) == d2 @@ -192,15 +192,15 @@ eqs = [yd ~ Sample(t, dt)(y) @test varmap[ud1] == d @test varmap[yd2] == d2 @test varmap[ud2] == d2 - @test varmap[r] == Continuous() - @test varmap[x] == Continuous() - @test varmap[y] == Continuous() - @test varmap[u] == Continuous() + @test varmap[r] == Continuous + @test varmap[x] == Continuous + @test varmap[y] == Continuous + @test varmap[u] == Continuous @info "test composed systems" dt = 0.5 - d = Clock(t, dt) + d = Clock(dt) k = ShiftIndex(d) timevec = 0:0.1:4 @@ -240,16 +240,16 @@ eqs = [yd ~ Sample(t, dt)(y) ci, varmap = infer_clocks(cl) - @test varmap[f.x] == Clock(t, 0.5) - @test varmap[p.x] == Continuous() - @test varmap[p.y] == Continuous() - @test varmap[c.ud] == Clock(t, 0.5) - @test varmap[c.yd] == Clock(t, 0.5) - @test varmap[c.y] == Continuous() - @test varmap[f.y] == Clock(t, 0.5) - @test varmap[f.u] == Clock(t, 0.5) - @test varmap[p.u] == Continuous() - @test varmap[c.r] == Clock(t, 0.5) + @test varmap[f.x] == Clock(0.5) + @test varmap[p.x] == Continuous + @test varmap[p.y] == Continuous + @test varmap[c.ud] == Clock(0.5) + @test varmap[c.yd] == Clock(0.5) + @test varmap[c.y] == Continuous + @test varmap[f.y] == Clock(0.5) + @test varmap[f.u] == Clock(0.5) + @test varmap[p.u] == Continuous + @test varmap[c.r] == Clock(0.5) ## Multiple clock rates @info "Testing multi-rate hybrid system" @@ -260,10 +260,10 @@ eqs = [yd ~ Sample(t, dt)(y) eqs = [ # controller (time discrete part `dt=0.1`) - yd1 ~ Sample(t, dt)(y) + yd1 ~ Sample(dt)(y) ud1 ~ kp * (r - yd1) # controller (time discrete part `dt=0.2`) - yd2 ~ Sample(t, dt2)(y) + yd2 ~ Sample(dt2)(y) ud2 ~ kp * (r - yd2) # plant (time continuous part) @@ -273,8 +273,8 @@ eqs = [yd ~ Sample(t, dt)(y) @named cl = ODESystem(eqs, t) - d = Clock(t, dt) - d2 = Clock(t, dt2) + d = Clock(dt) + d2 = Clock(dt2) ci, varmap = infer_clocks(cl) @test varmap[yd1] == d @@ -331,8 +331,8 @@ eqs = [yd ~ Sample(t, dt)(y) using ModelingToolkitStandardLibrary.Blocks dt = 0.05 - d = Clock(t, dt) - k = ShiftIndex() + d = Clock(dt) + k = ShiftIndex(d) @mtkmodel DiscretePI begin @components begin @@ -362,7 +362,7 @@ eqs = [yd ~ Sample(t, dt)(y) output = RealOutput() end @equations begin - output.u ~ Sample(t, dt)(input.u) + output.u ~ Sample(dt)(input.u) end end @@ -474,7 +474,7 @@ eqs = [yd ~ Sample(t, dt)(y) ## Test continuous clock - c = ModelingToolkit.SolverStepClock(t) + c = ModelingToolkit.SolverStepClock k = ShiftIndex(c) @mtkmodel CounterSys begin diff --git a/test/parameter_dependencies.jl b/test/parameter_dependencies.jl index 242be8f1d7..fc03f53d74 100644 --- a/test/parameter_dependencies.jl +++ b/test/parameter_dependencies.jl @@ -157,10 +157,10 @@ end dt = 0.1 @variables x(t) y(t) u(t) yd(t) ud(t) r(t) z(t) @parameters kp kq - d = Clock(t, dt) + d = Clock(dt) k = ShiftIndex(d) - eqs = [yd ~ Sample(t, dt)(y) + eqs = [yd ~ Sample(dt)(y) ud ~ kp * (r - yd) + kq * z r ~ 1.0 u ~ Hold(ud)