From af37ae830c64d973f1b4d43552979cf0e06ba021 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 31 May 2024 22:49:03 +0530 Subject: [PATCH] fix: various bug and test fixes --- src/systems/abstractsystem.jl | 12 ++++++++---- src/systems/parameter_buffer.jl | 7 ++++--- test/mtkparameters.jl | 2 +- test/symbolic_indexing_interface.jl | 3 ++- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index b2e4be12f7..84db5f7b7c 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -427,7 +427,8 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym) sym = unwrap(sym) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing return sym isa ParameterIndex || is_parameter(ic, sym) || - istree(sym) && operation(sym) === getindex && + istree(sym) && + operation(sym) === getindex && is_parameter(ic, first(arguments(sym))) end if unwrap(sym) isa Int @@ -462,10 +463,12 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym) end elseif istree(sym) && operation(sym) === getindex && (idx = parameter_index(ic, first(arguments(sym)))) !== nothing - if idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == nothing + if idx.portion isa SciMLStructures.Discrete && + idx.idx[2] == idx.idx[3] == nothing return nothing else - ParameterIndex(idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...)) + ParameterIndex( + idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...)) end else nothing @@ -485,7 +488,8 @@ end function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing idx = parameter_index(ic, sym) - if idx === nothing || idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0 + if idx === nothing || + idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0 return nothing else return idx diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index ddcaa8a7d8..9694d1e749 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -132,7 +132,7 @@ function MTKParameters( end end tunable_buffer = narrow_buffer_type.(tunable_buffer) - disc_buffer = narrow_buffer_type.(disc_buffer) + disc_buffer = broadcast.(narrow_buffer_type, disc_buffer) const_buffer = narrow_buffer_type.(const_buffer) nonnumeric_buffer = narrow_buffer_type.(nonnumeric_buffer) @@ -149,7 +149,8 @@ function MTKParameters( oop, iip = build_function(dep_exprs, p...) update_function_iip, update_function_oop = RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(iip), RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(oop) - update_function_iip(ArrayPartition(dep_buffer), tunable_buffer..., disc_buffer..., + update_function_iip(ArrayPartition(dep_buffer), tunable_buffer..., + Iterators.flatten(disc_buffer)..., const_buffer..., nonnumeric_buffer..., dep_buffer...) dep_buffer = narrow_buffer_type.(dep_buffer) else @@ -442,7 +443,7 @@ function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, val @set! newbuf.dependent = narrow_buffer_type_and_fallback_undefs.( oldbuf.dependent, split_into_buffers( - newbuf.dependent_update_oop(newbuf...), oldbuf.dependent, Val(false))) + newbuf.dependent_update_oop(newbuf...), oldbuf.dependent, Val(0))) end return newbuf end diff --git a/test/mtkparameters.jl b/test/mtkparameters.jl index 8b75c4f0e8..648bcbe95a 100644 --- a/test/mtkparameters.jl +++ b/test/mtkparameters.jl @@ -294,9 +294,9 @@ ps = MTKParameters(sys, yd2 => 2.0 + Sample(ssc)(x), Sample(t, dt)(x) => x, Sample(ssc)(x) => x, Hold(yd1) => yd1, Hold(yd2) => yd2], [x => 3.0]) -@test SciMLBase.get_saveable_values(ps, 1).x isa Tuple{Vector{Float64}, Vector{Bool}} tsidx1 = timeseries_parameter_index(sys, flag).timeseries_idx tsidx2 = 3 - tsidx1 +@test SciMLBase.get_saveable_values(ps, tsidx1).x isa Tuple{Vector{Float64}, BitVector} @test length(ps.discrete[tsidx1][1]) == 3 @test length(ps.discrete[tsidx1][2]) == 1 @test length(ps.discrete[tsidx2][1]) == 3 diff --git a/test/symbolic_indexing_interface.jl b/test/symbolic_indexing_interface.jl index 4511b50ea6..8696e33684 100644 --- a/test/symbolic_indexing_interface.jl +++ b/test/symbolic_indexing_interface.jl @@ -19,7 +19,8 @@ using SciMLStructures: Tunable @test parameter_index(odesys, a) isa ParameterIndex{Tunable, Tuple{Int, Int}} @test parameter_index(odesys, b) == parameter_index(odesys, :b) @test parameter_index(odesys, b) isa ParameterIndex{Tunable, Tuple{Int, Int}} - @test parameter_index.((odesys,), [x, y, t, ParameterIndex(Tunable(), (1, 1)), :x, :y,]) == + @test parameter_index.( + (odesys,), [x, y, t, ParameterIndex(Tunable(), (1, 1)), :x, :y]) == [nothing, nothing, nothing, ParameterIndex(Tunable(), (1, 1)), nothing, nothing] @test isequal(parameter_symbols(odesys), [a, b]) @test all(is_independent_variable.((odesys,), [t, :t]))