Skip to content

Commit

Permalink
refactor: improve Symbol indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jul 29, 2024
1 parent 2c56f1c commit 3a073ec
Showing 1 changed file with 40 additions and 84 deletions.
124 changes: 40 additions & 84 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ end

ParameterIndex(portion, idx) = ParameterIndex(portion, idx, false)

const ParamIndexMap = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int}}
const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}}
const UnknownIndexMap = Dict{
Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}}

struct IndexCache
unknown_idx::UnknownIndexMap
discrete_idx::Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int, Int}}
discrete_idx::Dict{BasicSymbolic, Tuple{Int, Int, Int}}
tunable_idx::ParamIndexMap
constant_idx::ParamIndexMap
dependent_idx::ParamIndexMap
nonnumeric_idx::ParamIndexMap
observed_syms::Set{Union{Symbol, BasicSymbolic}}
observed_syms::Set{BasicSymbolic}
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
tunable_buffer_sizes::Vector{BufferTemplate}
constant_buffer_sizes::Vector{BufferTemplate}
Expand All @@ -57,14 +57,6 @@ function IndexCache(sys::AbstractSystem)
end
unk_idxs[usym] = sym_idx
unk_idxs[rsym] = sym_idx
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
name = getname(usym)
rname = getname(rsym)
unk_idxs[name] = sym_idx
unk_idxs[rname] = sym_idx
symbol_to_variable[name] = sym
symbol_to_variable[rname] = sym
end
idx += length(sym)
end
for sym in unks
Expand All @@ -80,14 +72,6 @@ function IndexCache(sys::AbstractSystem)
rsym = renamespace(sys, arrsym)
unk_idxs[arrsym] = idxs
unk_idxs[rsym] = idxs
if hasname(arrsym)
name = getname(arrsym)
rname = getname(rsym)
unk_idxs[name] = idxs
unk_idxs[rname] = idxs
symbol_to_variable[name] = arrsym
symbol_to_variable[rname] = arrsym
end
end
end

Expand All @@ -102,16 +86,6 @@ function IndexCache(sys::AbstractSystem)
push!(observed_syms, ttsym)
push!(observed_syms, rsym)
push!(observed_syms, rttsym)
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
symbol_to_variable[getname(sym)] = eq.lhs
symbol_to_variable[getname(ttsym)] = eq.lhs
symbol_to_variable[getname(rsym)] = eq.lhs
symbol_to_variable[getname(rttsym)] = eq.lhs
push!(observed_syms, getname(sym))
push!(observed_syms, getname(ttsym))
push!(observed_syms, getname(rsym))
push!(observed_syms, getname(rttsym))
end
end
end

Expand Down Expand Up @@ -143,16 +117,12 @@ function IndexCache(sys::AbstractSystem)
rttinp = renamespace(sys, ttinp)
is_parameter(sys, inp) ||
error("Discrete subsystem $i input $inp is not a parameter")

disc_clocks[inp] = i
disc_clocks[ttinp] = i
disc_clocks[rinp] = i
disc_clocks[rttinp] = i
if hasname(inp) && (!iscall(inp) || operation(inp) !== getindex)
disc_clocks[getname(inp)] = i
disc_clocks[getname(ttinp)] = i
disc_clocks[getname(rinp)] = i
disc_clocks[getname(rttinp)] = i
end

insert_by_type!(disc_buffers[i], inp)
end

Expand All @@ -163,16 +133,12 @@ function IndexCache(sys::AbstractSystem)
rttsym = renamespace(sys, ttsym)
is_parameter(sys, sym) ||
error("Discrete subsystem $i unknown $sym is not a parameter")

disc_clocks[sym] = i
disc_clocks[ttsym] = i
disc_clocks[rsym] = i
disc_clocks[rttsym] = i
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
disc_clocks[getname(sym)] = i
disc_clocks[getname(ttsym)] = i
disc_clocks[getname(rsym)] = i
disc_clocks[getname(rttsym)] = i
end

insert_by_type!(disc_buffers[i], sym)
end
t = get_iv(sys)
Expand All @@ -191,12 +157,6 @@ function IndexCache(sys::AbstractSystem)
disc_clocks[ttsym] = i
disc_clocks[rsym] = i
disc_clocks[rttsym] = i
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
disc_clocks[getname(sym)] = i
disc_clocks[getname(ttsym)] = i
disc_clocks[getname(rsym)] = i
disc_clocks[getname(rttsym)] = i
end
end
end

Expand Down Expand Up @@ -237,13 +197,7 @@ function IndexCache(sys::AbstractSystem)
disc_clocks[ttsym] = user_affect_clock
disc_clocks[rsym] = user_affect_clock
disc_clocks[rttsym] = user_affect_clock
if hasname(sym) &&
(!iscall(sym) || operation(sym) !== getindex)
disc_clocks[getname(sym)] = user_affect_clock
disc_clocks[getname(ttsym)] = user_affect_clock
disc_clocks[getname(rsym)] = user_affect_clock
disc_clocks[getname(rttsym)] = user_affect_clock
end

buffer = get!(disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}())
insert_by_type!(buffer, affect.lhs)
else
Expand All @@ -259,12 +213,7 @@ function IndexCache(sys::AbstractSystem)
disc_clocks[ttdisc] = user_affect_clock
disc_clocks[rdisc] = user_affect_clock
disc_clocks[rttdisc] = user_affect_clock
if hasname(disc) && (!iscall(disc) || operation(disc) !== getindex)
disc_clocks[getname(disc)] = user_affect_clock
disc_clocks[getname(ttdisc)] = user_affect_clock
disc_clocks[getname(rdisc)] = user_affect_clock
disc_clocks[getname(rttdisc)] = user_affect_clock
end

buffer = get!(
disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}())
insert_by_type!(buffer, disc)
Expand Down Expand Up @@ -316,21 +265,13 @@ function IndexCache(sys::AbstractSystem)
for (j, sym) in enumerate(buffer[btype])
disc_idxs[sym] = (clockidx, i, j)
disc_idxs[default_toterm(sym)] = (clockidx, i, j)
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
disc_idxs[getname(sym)] = (clockidx, i, j)
disc_idxs[getname(default_toterm(sym))] = (clockidx, i, j)
end
end
end
end
for (sym, clockid) in disc_clocks
haskey(disc_idxs, sym) && continue
disc_idxs[sym] = (clockid, 0, 0)
disc_idxs[default_toterm(sym)] = (clockid, 0, 0)
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
disc_idxs[getname(sym)] = (clockid, 0, 0)
disc_idxs[getname(default_toterm(sym))] = (clockid, 0, 0)
end
end

function get_buffer_sizes_and_idxs(buffers::Dict{Any, Set{BasicSymbolic}})
Expand All @@ -345,16 +286,6 @@ function IndexCache(sys::AbstractSystem)
idxs[ttp] = (i, j)
idxs[rp] = (i, j)
idxs[rttp] = (i, j)
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
idxs[getname(p)] = (i, j)
idxs[getname(ttp)] = (i, j)
idxs[getname(rp)] = (i, j)
idxs[getname(rttp)] = (i, j)
symbol_to_variable[getname(p)] = p
symbol_to_variable[getname(ttp)] = p
symbol_to_variable[getname(rp)] = p
symbol_to_variable[getname(rttp)] = p
end
end
push!(buffer_sizes, BufferTemplate(T, length(buf)))
end
Expand All @@ -366,6 +297,14 @@ function IndexCache(sys::AbstractSystem)
dependent_idxs, dependent_buffer_sizes = get_buffer_sizes_and_idxs(dependent_buffers)
nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(nonnumeric_buffers)

for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs),
keys(const_idxs), keys(dependent_idxs), keys(nonnumeric_idxs),
observed_syms, independent_variable_symbols(sys)))
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
symbol_to_variable[getname(sym)] = sym
end
end

return IndexCache(
unk_idxs,
disc_idxs,
Expand All @@ -384,18 +323,26 @@ function IndexCache(sys::AbstractSystem)
end

function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym)
return check_index_map(ic.unknown_idx, sym) !== nothing
end

function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym::Symbol)
if sym isa Symbol
sym = get(ic.symbol_to_variable, sym, nothing)
sym === nothing && return false
end
return check_index_map(ic.unknown_idx, sym) !== nothing
end

function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym)
if sym isa Symbol
sym = get(ic.symbol_to_variable, sym, nothing)
sym === nothing && return nothing
end
return check_index_map(ic.unknown_idx, sym)
end

function SymbolicIndexingInterface.is_parameter(ic::IndexCache, sym)
if sym isa Symbol
sym = get(ic.symbol_to_variable, sym, nothing)
sym === nothing && return false
end
return check_index_map(ic.tunable_idx, sym) !== nothing ||
check_index_map(ic.discrete_idx, sym) !== nothing ||
check_index_map(ic.constant_idx, sym) !== nothing ||
Expand All @@ -405,7 +352,8 @@ end

function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
if sym isa Symbol
sym = ic.symbol_to_variable[sym]
sym = get(ic.symbol_to_variable, sym, nothing)
sym === nothing && return nothing
end
validate_size = Symbolics.isarraysymbolic(sym) &&
Symbolics.shape(sym) !== Symbolics.Unknown()
Expand All @@ -425,10 +373,18 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
end

function SymbolicIndexingInterface.is_timeseries_parameter(ic::IndexCache, sym)
if sym isa Symbol
sym = get(ic.symbol_to_variable, sym, nothing)
sym === nothing && return false
end
return check_index_map(ic.discrete_idx, sym) !== nothing
end

function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sym)
if sym isa Symbol
sym = get(ic.symbol_to_variable, sym, nothing)
sym === nothing && return nothing
end
idx = check_index_map(ic.discrete_idx, sym)
idx === nothing && return nothing
clockid, partitionid... = idx
Expand Down

0 comments on commit 3a073ec

Please sign in to comment.