From 3a073eca16829d68c3334c920ff4382e3d36350f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 24 Jul 2024 12:31:56 +0530 Subject: [PATCH] refactor: improve `Symbol` indexing --- src/systems/index_cache.jl | 124 ++++++++++++------------------------- 1 file changed, 40 insertions(+), 84 deletions(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index f992cd4907..899bba4aa5 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -21,18 +21,18 @@ end ParameterIndex(portion, idx) = ParameterIndex(portion, idx, false) -const ParamIndexMap = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int}} +const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}} const UnknownIndexMap = Dict{ - Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}, AbstractArray{Int}}} + BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}} struct IndexCache unknown_idx::UnknownIndexMap - discrete_idx::Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int, Int}} + discrete_idx::Dict{BasicSymbolic, Tuple{Int, Int, Int}} tunable_idx::ParamIndexMap constant_idx::ParamIndexMap dependent_idx::ParamIndexMap nonnumeric_idx::ParamIndexMap - observed_syms::Set{Union{Symbol, BasicSymbolic}} + observed_syms::Set{BasicSymbolic} discrete_buffer_sizes::Vector{Vector{BufferTemplate}} tunable_buffer_sizes::Vector{BufferTemplate} constant_buffer_sizes::Vector{BufferTemplate} @@ -57,14 +57,6 @@ function IndexCache(sys::AbstractSystem) end unk_idxs[usym] = sym_idx unk_idxs[rsym] = sym_idx - if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) - name = getname(usym) - rname = getname(rsym) - unk_idxs[name] = sym_idx - unk_idxs[rname] = sym_idx - symbol_to_variable[name] = sym - symbol_to_variable[rname] = sym - end idx += length(sym) end for sym in unks @@ -80,14 +72,6 @@ function IndexCache(sys::AbstractSystem) rsym = renamespace(sys, arrsym) unk_idxs[arrsym] = idxs unk_idxs[rsym] = idxs - if hasname(arrsym) - name = getname(arrsym) - rname = getname(rsym) - unk_idxs[name] = idxs - unk_idxs[rname] = idxs - symbol_to_variable[name] = arrsym - symbol_to_variable[rname] = arrsym - end end end @@ -102,16 +86,6 @@ function IndexCache(sys::AbstractSystem) push!(observed_syms, ttsym) push!(observed_syms, rsym) push!(observed_syms, rttsym) - if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) - symbol_to_variable[getname(sym)] = eq.lhs - symbol_to_variable[getname(ttsym)] = eq.lhs - symbol_to_variable[getname(rsym)] = eq.lhs - symbol_to_variable[getname(rttsym)] = eq.lhs - push!(observed_syms, getname(sym)) - push!(observed_syms, getname(ttsym)) - push!(observed_syms, getname(rsym)) - push!(observed_syms, getname(rttsym)) - end end end @@ -143,16 +117,12 @@ function IndexCache(sys::AbstractSystem) rttinp = renamespace(sys, ttinp) is_parameter(sys, inp) || error("Discrete subsystem $i input $inp is not a parameter") + disc_clocks[inp] = i disc_clocks[ttinp] = i disc_clocks[rinp] = i disc_clocks[rttinp] = i - if hasname(inp) && (!iscall(inp) || operation(inp) !== getindex) - disc_clocks[getname(inp)] = i - disc_clocks[getname(ttinp)] = i - disc_clocks[getname(rinp)] = i - disc_clocks[getname(rttinp)] = i - end + insert_by_type!(disc_buffers[i], inp) end @@ -163,16 +133,12 @@ function IndexCache(sys::AbstractSystem) rttsym = renamespace(sys, ttsym) is_parameter(sys, sym) || error("Discrete subsystem $i unknown $sym is not a parameter") + disc_clocks[sym] = i disc_clocks[ttsym] = i disc_clocks[rsym] = i disc_clocks[rttsym] = i - if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) - disc_clocks[getname(sym)] = i - disc_clocks[getname(ttsym)] = i - disc_clocks[getname(rsym)] = i - disc_clocks[getname(rttsym)] = i - end + insert_by_type!(disc_buffers[i], sym) end t = get_iv(sys) @@ -191,12 +157,6 @@ function IndexCache(sys::AbstractSystem) disc_clocks[ttsym] = i disc_clocks[rsym] = i disc_clocks[rttsym] = i - if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) - disc_clocks[getname(sym)] = i - disc_clocks[getname(ttsym)] = i - disc_clocks[getname(rsym)] = i - disc_clocks[getname(rttsym)] = i - end end end @@ -237,13 +197,7 @@ function IndexCache(sys::AbstractSystem) disc_clocks[ttsym] = user_affect_clock disc_clocks[rsym] = user_affect_clock disc_clocks[rttsym] = user_affect_clock - if hasname(sym) && - (!iscall(sym) || operation(sym) !== getindex) - disc_clocks[getname(sym)] = user_affect_clock - disc_clocks[getname(ttsym)] = user_affect_clock - disc_clocks[getname(rsym)] = user_affect_clock - disc_clocks[getname(rttsym)] = user_affect_clock - end + buffer = get!(disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}()) insert_by_type!(buffer, affect.lhs) else @@ -259,12 +213,7 @@ function IndexCache(sys::AbstractSystem) disc_clocks[ttdisc] = user_affect_clock disc_clocks[rdisc] = user_affect_clock disc_clocks[rttdisc] = user_affect_clock - if hasname(disc) && (!iscall(disc) || operation(disc) !== getindex) - disc_clocks[getname(disc)] = user_affect_clock - disc_clocks[getname(ttdisc)] = user_affect_clock - disc_clocks[getname(rdisc)] = user_affect_clock - disc_clocks[getname(rttdisc)] = user_affect_clock - end + buffer = get!( disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}()) insert_by_type!(buffer, disc) @@ -316,10 +265,6 @@ function IndexCache(sys::AbstractSystem) for (j, sym) in enumerate(buffer[btype]) disc_idxs[sym] = (clockidx, i, j) disc_idxs[default_toterm(sym)] = (clockidx, i, j) - if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) - disc_idxs[getname(sym)] = (clockidx, i, j) - disc_idxs[getname(default_toterm(sym))] = (clockidx, i, j) - end end end end @@ -327,10 +272,6 @@ function IndexCache(sys::AbstractSystem) haskey(disc_idxs, sym) && continue disc_idxs[sym] = (clockid, 0, 0) disc_idxs[default_toterm(sym)] = (clockid, 0, 0) - if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) - disc_idxs[getname(sym)] = (clockid, 0, 0) - disc_idxs[getname(default_toterm(sym))] = (clockid, 0, 0) - end end function get_buffer_sizes_and_idxs(buffers::Dict{Any, Set{BasicSymbolic}}) @@ -345,16 +286,6 @@ function IndexCache(sys::AbstractSystem) idxs[ttp] = (i, j) idxs[rp] = (i, j) idxs[rttp] = (i, j) - if hasname(p) && (!iscall(p) || operation(p) !== getindex) - idxs[getname(p)] = (i, j) - idxs[getname(ttp)] = (i, j) - idxs[getname(rp)] = (i, j) - idxs[getname(rttp)] = (i, j) - symbol_to_variable[getname(p)] = p - symbol_to_variable[getname(ttp)] = p - symbol_to_variable[getname(rp)] = p - symbol_to_variable[getname(rttp)] = p - end end push!(buffer_sizes, BufferTemplate(T, length(buf))) end @@ -366,6 +297,14 @@ function IndexCache(sys::AbstractSystem) dependent_idxs, dependent_buffer_sizes = get_buffer_sizes_and_idxs(dependent_buffers) nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(nonnumeric_buffers) + for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs), + keys(const_idxs), keys(dependent_idxs), keys(nonnumeric_idxs), + observed_syms, independent_variable_symbols(sys))) + if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) + symbol_to_variable[getname(sym)] = sym + end + end + return IndexCache( unk_idxs, disc_idxs, @@ -384,18 +323,26 @@ function IndexCache(sys::AbstractSystem) end function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym) - return check_index_map(ic.unknown_idx, sym) !== nothing -end - -function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym::Symbol) + if sym isa Symbol + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return false + end return check_index_map(ic.unknown_idx, sym) !== nothing end function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym) + if sym isa Symbol + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return nothing + end return check_index_map(ic.unknown_idx, sym) end function SymbolicIndexingInterface.is_parameter(ic::IndexCache, sym) + if sym isa Symbol + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return false + end return check_index_map(ic.tunable_idx, sym) !== nothing || check_index_map(ic.discrete_idx, sym) !== nothing || check_index_map(ic.constant_idx, sym) !== nothing || @@ -405,7 +352,8 @@ end function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym) if sym isa Symbol - sym = ic.symbol_to_variable[sym] + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return nothing end validate_size = Symbolics.isarraysymbolic(sym) && Symbolics.shape(sym) !== Symbolics.Unknown() @@ -425,10 +373,18 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym) end function SymbolicIndexingInterface.is_timeseries_parameter(ic::IndexCache, sym) + if sym isa Symbol + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return false + end return check_index_map(ic.discrete_idx, sym) !== nothing end function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sym) + if sym isa Symbol + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return nothing + end idx = check_index_map(ic.discrete_idx, sym) idx === nothing && return nothing clockid, partitionid... = idx