diff --git a/.github/workflows/check-format.yml b/.github/workflows/check-format.yml index de643f2ef..6bc6bdd51 100644 --- a/.github/workflows/check-format.yml +++ b/.github/workflows/check-format.yml @@ -24,12 +24,8 @@ jobs: - name: "Cache dependencies" uses: julia-actions/cache@v2 - name: Install JuliaFormatter and format - # This will use the latest version by default but you can set the version like so: - # - # julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="0.13.0"))' run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' - julia -e 'using JuliaFormatter; format(".", verbose=true)' + julia --startup-file=no -e 'using Pkg; pkg"activate --temp"; pkg"add JuliaFormatter@1.0.61"; using JuliaFormatter; format("."; verbose=true)' - name: "Format check" run: | julia -e ' diff --git a/docs/src/customization.md b/docs/src/customization.md index 2a9d9a072..09c194341 100644 --- a/docs/src/customization.md +++ b/docs/src/customization.md @@ -49,6 +49,11 @@ If needed, you may need to overload `SymbolicRegression.ExpressionBuilder.extra_ case your expression needs additional parameters. See the method for `ParametricExpression` as an example. +For more examples of custom expression types, see `src/ParametricExpression.jl` +and `src/TemplateExpression.jl`. Note that `ParametricExpression` itself is defined +in DynamicExpressions.jl; the file here only overloads some of its methods for +SymbolicRegression.jl. + ## Other Customizations Other internal abstract types include the following: diff --git a/docs/src/types.md b/docs/src/types.md index 92bf5632e..cd62389be 100644 --- a/docs/src/types.md +++ b/docs/src/types.md @@ -60,6 +60,16 @@ ParametricNode These types allow you to define expressions with parameters that can be tuned to fit the data better. You can specify the maximum number of parameters using the `expression_options` argument in `SRRegressor`. +## Template Expressions + +Template expressions let you impose a predefined structure on the fitted expression. +They also let you fit vector-valued expressions, since the custom evaluation +function can simply return a vector of tuples.
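For instance, the structure function can have one method that combines the sub-expressions' string forms (used for printing) and another that combines their numerical outputs into tuples (used for evaluation). A minimal sketch of such a function, mirroring `examples/template_expression.jl` from this PR (the field names `f`, `g1`, and `g2` are just that example's choice):

```julia
# Combine sub-expressions symbolically, for printing:
function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractString}}})
    return "( $(nt.f) + $(nt.g1), $(nt.f) + $(nt.g2) )"
end
# Combine them numerically, returning a vector of 2-tuples, for evaluation:
function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}})
    return map(i -> (nt.f[i] + nt.g1[i], nt.f[i] + nt.g2[i]), eachindex(nt.f))
end
```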
+ +```@docs +TemplateExpression +``` + ## Population Groups of equations are given as a population, which is diff --git a/examples/template_expression.jl b/examples/template_expression.jl new file mode 100644 index 000000000..ade5fc5cf --- /dev/null +++ b/examples/template_expression.jl @@ -0,0 +1,62 @@ +using SymbolicRegression +using Random: rand +using MLJBase: machine, fit!, report +using Test: @test + +options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) operators = options.operators variable_names = (i -> "x$i").(1:3) x1, x2, x3 = (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) + +variable_mapping = (; f=[1, 2], g1=[3], g2=[3]) + +function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractString}}}) + return "( $(nt.f) + $(nt.g1), $(nt.f) + $(nt.g2) )" +end +function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}}) + return map(i -> (nt.f[i] + nt.g1[i], nt.f[i] + nt.g2[i]), eachindex(nt.f)) +end + +st_expr = TemplateExpression( + (; f=x1, g1=x3, g2=x3); + structure=my_structure, + operators, + variable_names, + variable_mapping, +) + +X = rand(100, 3) .* 10 + +# Our dataset is a vector of 2-tuples +y = [(sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 1]) + X[i, 3]) for i in eachindex(axes(X, 1))] + +model = SRRegressor(; binary_operators=(+, *), unary_operators=(sin,), maxsize=15, expression_type=TemplateExpression, expression_options=(; structure=my_structure, variable_mapping), # The elementwise loss needs to operate directly on each row of `y`: elementwise_loss=((x1, x2), (y1, y2)) -> (y1 - x1)^2 + (y2 - x2)^2, early_stop_condition=(loss, complexity) -> loss < 1e-5 && complexity <= 7, ) + +mach = machine(model, X, y) fit!(mach) + +# Check the performance of the model r = report(mach) idx = r.best_idx best_loss = r.losses[idx] + +@test best_loss < 1e-5 + +# Check the expression is split up correctly: best_expr = r.equations[idx] best_f = get_contents(best_expr).f best_g1 = get_contents(best_expr).g1 best_g2 = get_contents(best_expr).g2 + +@test best_f(X') ≈ (@. sin(X[:, 1])) +@test best_g1(X') ≈ (@. X[:, 3] * X[:, 3]) +@test best_g2(X') ≈ (@. X[:, 3]) diff --git a/src/AdaptiveParsimony.jl b/src/AdaptiveParsimony.jl index aa33fa613..e3fded95c 100644 --- a/src/AdaptiveParsimony.jl +++ b/src/AdaptiveParsimony.jl @@ -24,9 +24,7 @@ struct RunningSearchStatistics end function RunningSearchStatistics(; options::AbstractOptions, window_size::Int=100000) - maxsize = options.maxsize - actualMaxsize = maxsize + MAX_DEGREE - init_frequencies = ones(Float64, actualMaxsize) + init_frequencies = ones(Float64, options.maxsize) return RunningSearchStatistics( window_size, init_frequencies, copy(init_frequencies) / sum(init_frequencies) diff --git a/src/CheckConstraints.jl b/src/CheckConstraints.jl index df748c218..fb0bbb712 100644 --- a/src/CheckConstraints.jl +++ b/src/CheckConstraints.jl @@ -69,7 +69,7 @@ function flag_illegal_nests(tree::AbstractExpressionNode, options::AbstractOptio return false end -"""Check if user-passed constraints are violated or not""" +"""Check if user-passed constraints are satisfied.
Returns false otherwise.""" function check_constraints( ex::AbstractExpression, options::AbstractOptions, diff --git a/src/Configure.jl b/src/Configure.jl index eefd63619..2b184e5cd 100644 --- a/src/Configure.jl +++ b/src/Configure.jl @@ -88,7 +88,7 @@ function test_dataset_configuration( ) where {T<:DATA_TYPE} n = dataset.n if n != size(dataset.X, 2) || - (dataset.y !== nothing && n != size(dataset.y::AbstractArray{T}, 1)) + (dataset.y !== nothing && n != size(dataset.y::AbstractArray, 1)) throw( AssertionError( "Dataset dimensions are invalid. Make sure X is of shape [features, rows], y is of shape [rows] and if there are weights, they are of shape [rows].", @@ -101,7 +101,7 @@ function test_dataset_configuration( end if !(typeof(options.elementwise_loss) <: SupervisedLoss) && - dataset.weighted && + is_weighted(dataset) && !(3 in [m.nargs - 1 for m in methods(options.elementwise_loss)]) throw( AssertionError( @@ -132,7 +132,7 @@ function move_functions_to_workers( continue end ops = (options.elementwise_loss,) - example_inputs = if dataset.weighted + example_inputs = if is_weighted(dataset) (zero(T), zero(T), zero(T)) else (zero(T), zero(T)) diff --git a/src/Core.jl b/src/Core.jl index 0a16b52d0..2d6e73d89 100644 --- a/src/Core.jl +++ b/src/Core.jl @@ -12,9 +12,14 @@ include("Options.jl") using .ProgramConstantsModule: MAX_DEGREE, BATCH_DIM, FEATURE_DIM, RecordType, DATA_TYPE, LOSS_TYPE -using .DatasetModule: Dataset +using .DatasetModule: Dataset, is_weighted using .MutationWeightsModule: AbstractMutationWeights, MutationWeights, sample_mutation -using .OptionsStructModule: AbstractOptions, Options, ComplexityMapping, specialized_options +using .OptionsStructModule: + AbstractOptions, + Options, + ComplexityMapping, + specialized_options, + operator_specialization using .OperatorsModule: plus, sub, diff --git a/src/Dataset.jl b/src/Dataset.jl index c8cb9767a..f9e28bcc5 100644 --- a/src/Dataset.jl +++ b/src/Dataset.jl @@ -19,8 +19,7 @@ import ...deprecate_varmap dataset, if any. - `n::Int`: The number of samples. - `nfeatures::Int`: The number of features. -- `weighted::Bool`: Whether the dataset is non-uniformly weighted. -- `weights::Union{AbstractVector{T},Nothing}`: If the dataset is weighted, +- `weights::Union{AbstractVector,Nothing}`: If the dataset is weighted, these specify the per-sample weight (with shape `(n,)`). - `extra::NamedTuple`: Extra information to pass to a custom evaluation function. Since this is an arbitrary named tuple, you could pass @@ -49,8 +48,8 @@ mutable struct Dataset{ T<:DATA_TYPE, L<:LOSS_TYPE, AX<:AbstractMatrix{T}, - AY<:Union{AbstractVector{T},Nothing}, - AW<:Union{AbstractVector{T},Nothing}, + AY<:Union{AbstractVector,Nothing}, + AW<:Union{AbstractVector,Nothing}, NT<:NamedTuple, XU<:Union{AbstractVector{<:Quantity},Nothing}, YU<:Union{Quantity,Nothing}, @@ -62,7 +61,6 @@ mutable struct Dataset{ const index::Int const n::Int const nfeatures::Int - const weighted::Bool const weights::AW const extra::NT const avg_y::Union{T,Nothing} @@ -81,7 +79,7 @@ end Dataset(X::AbstractMatrix{T}, y::Union{AbstractVector{T},Nothing}=nothing, loss_type::Type=Nothing; - weights::Union{AbstractVector{T}, Nothing}=nothing, + weights::Union{AbstractVector, Nothing}=nothing, variable_names::Union{Array{String, 1}, Nothing}=nothing, y_variable_name::Union{String,Nothing}=nothing, extra::NamedTuple=NamedTuple(), @@ -93,10 +91,10 @@ Construct a dataset to pass between internal functions. 
""" function Dataset( X::AbstractMatrix{T}, - y::Union{AbstractVector{T},Nothing}=nothing, + y::Union{AbstractVector,Nothing}=nothing, loss_type::Type{L}=Nothing; index::Int=1, - weights::Union{AbstractVector{T},Nothing}=nothing, + weights::Union{AbstractVector,Nothing}=nothing, variable_names::Union{Array{String,1},Nothing}=nothing, display_variable_names=variable_names, y_variable_name::Union{String,Nothing}=nothing, @@ -133,7 +131,6 @@ function Dataset( n = size(X, BATCH_DIM) nfeatures = size(X, FEATURE_DIM) - weighted = weights !== nothing variable_names = if variable_names === nothing ["x$(i)" for i in 1:nfeatures] else @@ -150,10 +147,10 @@ function Dataset( else y_variable_name end - avg_y = if y === nothing + avg_y = if y === nothing || !(eltype(y) isa Number) nothing else - if weighted + if weights !== nothing sum(y .* weights) / sum(weights) else sum(y) / n @@ -207,7 +204,6 @@ function Dataset( index, n, nfeatures, - weighted, weights, extra, avg_y, @@ -222,26 +218,8 @@ function Dataset( y_sym_units, ) end -function Dataset( - X::AbstractMatrix, - y::Union{<:AbstractVector,Nothing}=nothing; - weights::Union{<:AbstractVector,Nothing}=nothing, - kws..., -) - T = promote_type( - eltype(X), - (y === nothing) ? eltype(X) : eltype(y), - (weights === nothing) ? eltype(X) : eltype(weights), - ) - X = Base.Fix1(convert, T).(X) - if y !== nothing - y = Base.Fix1(convert, T).(y) - end - if weights !== nothing - weights = Base.Fix1(convert, T).(weights) - end - return Dataset(X, y; weights=weights, kws...) -end + +is_weighted(dataset::Dataset) = dataset.weights !== nothing function error_on_mismatched_size(_, ::Nothing) return nothing diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl index 12b20a06c..709937ecf 100644 --- a/src/ExpressionBuilder.jl +++ b/src/ExpressionBuilder.jl @@ -1,3 +1,7 @@ +""" +This module provides functions for creating, initializing, and manipulating +`<:AbstractExpression` instances and their metadata within the SymbolicRegression.jl framework. 
+""" module ExpressionBuilderModule using DispatchDoctor: @unstable @@ -5,8 +9,6 @@ using DynamicExpressions: AbstractExpressionNode, AbstractExpression, Expression, - ParametricExpression, - ParametricNode, constructorof, get_tree, get_contents, @@ -15,27 +17,20 @@ using DynamicExpressions: with_metadata, count_scalar_constants, eval_tree_array -using Random: default_rng, AbstractRNG using StatsBase: StatsBase using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE using ..HallOfFameModule: HallOfFame -using ..LossFunctionsModule: maybe_getindex -using ..InterfaceDynamicExpressionsModule: expected_array_type using ..PopulationModule: Population using ..PopMemberModule: PopMember import DynamicExpressions: get_operators import ..CoreModule: create_expression -import ..MutationFunctionsModule: - make_random_leaf, crossover_trees, mutate_constant, mutate_factor -import ..LossFunctionsModule: eval_tree_dispatch -import ..ConstantOptimizationModule: count_constants_for_optimization @unstable function create_expression( t::T, options::AbstractOptions, dataset::Dataset{T,L}, ::Val{embed}=Val(false) ) where {T,L,embed} return create_expression( - constructorof(options.node_type)(; val=t), options, dataset, Val(embed) + t, options, dataset, options.node_type, options.expression_type, Val(embed) ) end @unstable function create_expression( @@ -44,14 +39,37 @@ end dataset::Dataset{T,L}, ::Val{embed}=Val(false), ) where {T,L,embed} - return constructorof(options.expression_type)( - t; init_params(options, dataset, nothing, Val(embed))... + return create_expression( + t, options, dataset, options.node_type, options.expression_type, Val(embed) ) end function create_expression( - ex::AbstractExpression{T}, ::AbstractOptions, ::Dataset{T,L}, ::Val{embed}=Val(false) + ex::AbstractExpression{T}, + options::AbstractOptions, + ::Dataset{T,L}, + ::Val{embed}=Val(false), ) where {T,L,embed} - return ex + return ex::options.expression_type +end +@unstable function create_expression( + t::T, + options::AbstractOptions, + dataset::Dataset{T,L}, + ::Type{N}, + ::Type{E}, + ::Val{embed}=Val(false), +) where {T,L,embed,N<:AbstractExpressionNode,E<:AbstractExpression} + return create_expression(constructorof(N)(; val=t), options, dataset, N, E, Val(embed)) +end +@unstable function create_expression( + t::AbstractExpressionNode{T}, + options::AbstractOptions, + dataset::Dataset{T,L}, + ::Type{<:AbstractExpressionNode}, + ::Type{E}, + ::Val{embed}=Val(false), +) where {T,L,embed,E<:AbstractExpression} + return constructorof(E)(t; init_params(options, dataset, nothing, Val(embed))...) end @unstable function init_params( options::AbstractOptions, @@ -60,13 +78,17 @@ end ::Val{embed}, ) where {T,L,embed} consistency_checks(options, prototype) - return (; + raw_params = (; operators=embed ? options.operators : nothing, variable_names=embed ? dataset.variable_names : nothing, extra_init_params( options.expression_type, prototype, options, dataset, Val(embed) )..., ) + return sort_params(raw_params, options.expression_type) +end +function sort_params(raw_params::NamedTuple, ::Type{<:AbstractExpression}) + return raw_params end function extra_init_params( ::Type{E}, @@ -75,46 +97,16 @@ function extra_init_params( dataset::Dataset{T}, ::Val{embed}, ) where {T,embed,E<:AbstractExpression} + # TODO: Potential aliasing here return (; options.expression_options...) 
end -function extra_init_params( - ::Type{E}, - prototype::Union{Nothing,ParametricExpression}, - options::AbstractOptions, - dataset::Dataset{T}, - ::Val{embed}, -) where {T,embed,E<:ParametricExpression} - num_params = options.expression_options.max_parameters - num_classes = length(unique(dataset.extra.classes)) - parameter_names = embed ? ["p$i" for i in 1:num_params] : nothing - _parameters = if prototype === nothing - randn(T, (num_params, num_classes)) - else - copy(get_metadata(prototype).parameters) - end - return (; parameters=_parameters, parameter_names) -end consistency_checks(::AbstractOptions, prototype::Nothing) = nothing function consistency_checks(options::AbstractOptions, prototype) - if prototype === nothing - return nothing - end @assert( prototype isa options.expression_type, "Need prototype to be of type $(options.expression_type), but got $(prototype)::$(typeof(prototype))" ) - if prototype isa ParametricExpression - if prototype.metadata.parameter_names !== nothing - @assert( - length(prototype.metadata.parameter_names) == - options.expression_options.max_parameters, - "Mismatch between options.expression_options.max_parameters=$(options.expression_options.max_parameters) and prototype.metadata.parameter_names=$(prototype.metadata.parameter_names)" - ) - end - @assert size(prototype.metadata.parameters, 1) == - options.expression_options.max_parameters - end return nothing end @@ -158,14 +150,15 @@ end end end -"""Strips all metadata except for top-level information""" -function strip_metadata( - ex::Expression, options::AbstractOptions, dataset::Dataset{T,L} -) where {T,L} - return with_metadata(ex; init_params(options, dataset, ex, Val(false))...) -end +""" +Strips all metadata except for top-level information, so that we avoid needing +to copy irrelevant information to the evolution itself (like variable names +stored within an expression). + +The opposite of this is `embed_metadata`. +""" function strip_metadata( - ex::ParametricExpression, options::AbstractOptions, dataset::Dataset{T,L} + ex::AbstractExpression, options::AbstractOptions, dataset::Dataset{T,L} ) where {T,L} return with_metadata(ex; init_params(options, dataset, ex, Val(false))...) 
end @@ -195,93 +188,6 @@ function strip_metadata( ) end -function eval_tree_dispatch( - tree::ParametricExpression{T}, dataset::Dataset{T}, options::AbstractOptions, idx -) where {T<:DATA_TYPE} - A = expected_array_type(dataset.X) - return eval_tree_array( - tree, - maybe_getindex(dataset.X, :, idx), - maybe_getindex(dataset.extra.classes, idx), - options.operators, - )::Tuple{A,Bool} -end - -function make_random_leaf( - nfeatures::Int, - ::Type{T}, - ::Type{N}, - rng::AbstractRNG=default_rng(), - options::Union{AbstractOptions,Nothing}=nothing, -) where {T<:DATA_TYPE,N<:ParametricNode} - choice = rand(rng, 1:3) - if choice == 1 - return ParametricNode(; val=randn(rng, T)) - elseif choice == 2 - return ParametricNode(T; feature=rand(rng, 1:nfeatures)) - else - tree = ParametricNode{T}() - tree.val = zero(T) - tree.degree = 0 - tree.feature = 0 - tree.constant = false - tree.is_parameter = true - tree.parameter = rand( - rng, UInt16(1):UInt16(options.expression_options.max_parameters) - ) - return tree - end -end - -function crossover_trees( - ex1::ParametricExpression{T}, ex2::AbstractExpression{T}, rng::AbstractRNG=default_rng() -) where {T} - tree1 = get_contents(ex1) - tree2 = get_contents(ex2) - out1, out2 = crossover_trees(tree1, tree2, rng) - ex1 = with_contents(ex1, out1) - ex2 = with_contents(ex2, out2) - - # We also randomly share parameters - nparams1 = size(ex1.metadata.parameters, 1) - nparams2 = size(ex2.metadata.parameters, 1) - num_params_switch = min(nparams1, nparams2) - idx_to_switch = StatsBase.sample( - rng, 1:num_params_switch, num_params_switch; replace=false - ) - for param_idx in idx_to_switch - ex2_params = ex2.metadata.parameters[param_idx, :] - ex2.metadata.parameters[param_idx, :] .= ex1.metadata.parameters[param_idx, :] - ex1.metadata.parameters[param_idx, :] .= ex2_params - end - - return ex1, ex2 -end - -function count_constants_for_optimization(ex::ParametricExpression) - return count_scalar_constants(get_tree(ex)) + length(ex.metadata.parameters) -end - -function mutate_constant( - ex::ParametricExpression{T}, - temperature, - options::AbstractOptions, - rng::AbstractRNG=default_rng(), -) where {T<:DATA_TYPE} - if rand(rng, Bool) - # Normal mutation of inner constant - tree = get_contents(ex) - return with_contents(ex, mutate_constant(tree, temperature, options, rng)) - else - # Mutate parameters - parameter_index = rand(rng, 1:(options.expression_options.max_parameters)) - # We mutate all the parameters at once - factor = mutate_factor(T, temperature, options, rng) - ex.metadata.parameters[parameter_index, :] .*= factor - return ex - end -end - @unstable function get_operators(ex::AbstractExpression, options::AbstractOptions) return get_operators(ex, options.operators) end diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index 71032dfd5..a75b82939 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -63,7 +63,6 @@ Arguments: function HallOfFame( options::AbstractOptions, dataset::Dataset{T,L} ) where {T<:DATA_TYPE,L<:LOSS_TYPE} - actualMaxsize = options.maxsize + MAX_DEGREE base_tree = create_expression(zero(T), options, dataset) return HallOfFame{T,L,typeof(base_tree)}( @@ -75,9 +74,9 @@ function HallOfFame( options; parent=-1, deterministic=options.deterministic, - ) for i in 1:actualMaxsize + ) for i in 1:(options.maxsize) ], - [false for i in 1:actualMaxsize], + [false for i in 1:(options.maxsize)], ) end @@ -95,8 +94,7 @@ function calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,N}) where {T,L,N} P = PopMember{T,L,N} # Dominating pareto 
curve - must be better than all simpler equations dominating = P[] - actualMaxsize = length(hallOfFame.members) - for size in 1:actualMaxsize + for size in eachindex(hallOfFame.members) if !hallOfFame.exists[size] continue end diff --git a/src/InterfaceDynamicExpressions.jl b/src/InterfaceDynamicExpressions.jl index 887627b7d..6c8aa45fd 100644 --- a/src/InterfaceDynamicExpressions.jl +++ b/src/InterfaceDynamicExpressions.jl @@ -7,11 +7,10 @@ using DynamicExpressions: GenericOperatorEnum, AbstractExpression, AbstractExpressionNode, - ParametricExpression, Node, GraphNode using DynamicQuantities: dimension, ustrip -using ..CoreModule: AbstractOptions +using ..CoreModule: AbstractOptions, Dataset using ..CoreModule.OptionsModule: inverse_binopmap, inverse_unaopmap using ..UtilsModule: subscriptify @@ -56,39 +55,25 @@ function DE.eval_tree_array( options::AbstractOptions; kws..., ) - A = expected_array_type(X) - return DE.eval_tree_array( + A = expected_array_type(X, typeof(tree)) + out, complete = DE.eval_tree_array( tree, X, DE.get_operators(tree, options); turbo=options.turbo, bumper=options.bumper, kws..., - )::Tuple{A,Bool} -end -function DE.eval_tree_array( - tree::ParametricExpression, - X::AbstractMatrix, - classes::AbstractVector{<:Integer}, - options::AbstractOptions; - kws..., -) - A = expected_array_type(X) - return DE.eval_tree_array( - tree, - X, - classes, - DE.get_operators(tree, options); - turbo=options.turbo, - bumper=options.bumper, - kws..., - )::Tuple{A,Bool} + ) + return out::A, complete::Bool end -# Improve type inference by telling Julia the expected array returned -function expected_array_type(X::AbstractArray) +"""Improve type inference by telling Julia the expected array returned.""" +function expected_array_type(X::AbstractArray, ::Type) return typeof(similar(X, axes(X, 2))) end +function expected_array_type(X::AbstractArray, ::Type, ::Val{:eval_grad_tree_array}) + return typeof(X) +end """ eval_diff_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::AbstractOptions, direction::Int) @@ -116,11 +101,12 @@ function DE.eval_diff_tree_array( options::AbstractOptions, direction::Int, ) - A = expected_array_type(X) # TODO: Add `AbstractExpression` implementation in `Expression.jl` - return DE.eval_diff_tree_array( + A = expected_array_type(X, typeof(tree)) + out, grad, complete = DE.eval_diff_tree_array( DE.get_tree(tree), X, DE.get_operators(tree, options), direction - )::Tuple{A,A,Bool} + ) + return out::A, grad::A, complete::Bool end """ @@ -150,11 +136,12 @@ function DE.eval_grad_tree_array( options::AbstractOptions; kws..., ) - A = expected_array_type(X) - M = typeof(X) # TODO: This won't work with StaticArrays! - return DE.eval_grad_tree_array( + A = expected_array_type(X, typeof(tree)) + dA = expected_array_type(X, typeof(tree), Val(:eval_grad_tree_array)) + out, grad, complete = DE.eval_grad_tree_array( tree, X, DE.get_operators(tree, options); kws... 
- )::Tuple{A,M,Bool} + ) + return out::A, grad::dA, complete::Bool end """ @@ -167,11 +154,12 @@ function DE.differentiable_eval_tree_array( X::AbstractArray, options::AbstractOptions, ) - A = expected_array_type(X) # TODO: Add `AbstractExpression` implementation in `Expression.jl` - return DE.differentiable_eval_tree_array( + A = expected_array_type(X, typeof(tree)) + out, complete = DE.differentiable_eval_tree_array( DE.get_tree(tree), X, DE.get_operators(tree, options) - )::Tuple{A,Bool} + ) + return out::A, complete::Bool end const WILDCARD_UNIT_STRING = "[?]" diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl index b41ad3a38..01dcca86b 100644 --- a/src/LossFunctions.jl +++ b/src/LossFunctions.jl @@ -5,14 +5,15 @@ using DynamicExpressions: AbstractExpression, AbstractExpressionNode, get_tree, eval_tree_array using LossFunctions: LossFunctions using LossFunctions: SupervisedLoss -using ..InterfaceDynamicExpressionsModule: expected_array_type -using ..CoreModule: AbstractOptions, Dataset, create_expression, DATA_TYPE, LOSS_TYPE +using ..CoreModule: + AbstractOptions, Dataset, create_expression, DATA_TYPE, LOSS_TYPE, is_weighted using ..ComplexityModule: compute_complexity using ..DimensionalAnalysisModule: violates_dimensional_constraints +using ..InterfaceDynamicExpressionsModule: expected_array_type function _loss( x::AbstractArray{T}, y::AbstractArray{T}, loss::LT -) where {T<:DATA_TYPE,LT<:Union{Function,SupervisedLoss}} +) where {T,LT<:Union{Function,SupervisedLoss}} if loss isa SupervisedLoss return LossFunctions.mean(loss, x, y) else @@ -23,7 +24,7 @@ end function _weighted_loss( x::AbstractArray{T}, y::AbstractArray{T}, w::AbstractArray{T}, loss::LT -) where {T<:DATA_TYPE,LT<:Union{Function,SupervisedLoss}} +) where {T,LT<:Union{Function,SupervisedLoss}} if loss isa SupervisedLoss return sum(loss, x, y, w; normalize=true) else @@ -42,13 +43,18 @@ end end function eval_tree_dispatch( - tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, - dataset::Dataset{T}, - options::AbstractOptions, - idx, -) where {T<:DATA_TYPE} - A = expected_array_type(dataset.X) - return eval_tree_array(tree, maybe_getindex(dataset.X, :, idx), options)::Tuple{A,Bool} + tree::AbstractExpression, dataset::Dataset, options::AbstractOptions, idx +) + A = expected_array_type(dataset.X, typeof(tree)) + out, complete = eval_tree_array(tree, maybe_getindex(dataset.X, :, idx), options) + return out::A, complete::Bool +end +function eval_tree_dispatch( + tree::AbstractExpressionNode, dataset::Dataset, options::AbstractOptions, idx +) + A = expected_array_type(dataset.X, typeof(tree)) + out, complete = eval_tree_array(tree, maybe_getindex(dataset.X, :, idx), options) + return out::A, complete::Bool end # Evaluate the loss of a particular expression on the input dataset. 
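Since the stored `weighted` flag was removed from `Dataset`, loss evaluation now branches on the new `is_weighted` helper, as the next hunk shows. A minimal sketch of how that helper behaves (the arrays here are hypothetical, not part of this PR):

```julia
using SymbolicRegression: Dataset, is_weighted

X = randn(3, 100)  # shape [features, rows]
y = randn(100)

# With `weights === nothing`, the dataset is unweighted:
@assert !is_weighted(Dataset(X, y))
# Passing per-sample weights makes `is_weighted` return true:
@assert is_weighted(Dataset(X, y; weights=rand(100)))
```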
@@ -64,15 +70,19 @@ function _eval_loss( return L(Inf) end - loss_val = if dataset.weighted + loss_val = if is_weighted(dataset) _weighted_loss( prediction, - maybe_getindex(dataset.y, idx), + maybe_getindex(dataset.y::AbstractArray, idx), maybe_getindex(dataset.weights, idx), options.elementwise_loss, ) else - _loss(prediction, maybe_getindex(dataset.y, idx), options.elementwise_loss) + _loss( + prediction, + maybe_getindex(dataset.y::AbstractArray, idx), + options.elementwise_loss, + ) end if regularization diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 028493ff7..4d7a1d140 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -508,7 +508,7 @@ const input_scitype = Union{ MMI.metadata_model( SRRegressor; input_scitype, - target_scitype=AbstractVector{<:MMI.Continuous}, + target_scitype=AbstractVector{<:Any}, supports_weights=true, reports_feature_importances=false, load_path="SymbolicRegression.MLJInterfaceModule.SRRegressor", @@ -517,7 +517,7 @@ MMI.metadata_model( MMI.metadata_model( MultitargetSRRegressor; input_scitype, - target_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}}, + target_scitype=Union{MMI.Table(Any),AbstractMatrix{<:Any}}, supports_weights=true, reports_feature_importances=false, load_path="SymbolicRegression.MLJInterfaceModule.MultitargetSRRegressor", diff --git a/src/Mutate.jl b/src/Mutate.jl index 0f383177b..7b828f6f3 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -2,11 +2,9 @@ module MutateModule using DynamicExpressions: AbstractExpression, - ParametricExpression, with_contents, get_tree, preserve_sharing, - copy_node, count_scalar_constants, simplify_tree!, combine_operators @@ -29,7 +27,8 @@ using ..MutationFunctionsModule: crossover_trees, form_random_connection!, break_random_connection!, - randomly_rotate_tree! + randomly_rotate_tree!, + randomize_tree using ..ConstantOptimizationModule: optimize_constants using ..RecorderModule: @recorder @@ -91,11 +90,8 @@ Note that the weights were already copied, so you don't need to worry about muta - `curmaxsize::Int`: The current maximum size constraint for the member's expression tree. """ function condition_mutation_weights!( - weights::AbstractMutationWeights, - member::PopMember, - options::AbstractOptions, - curmaxsize::Int, -) + weights::AbstractMutationWeights, member::P, options::AbstractOptions, curmaxsize::Int +) where {T,L,N<:AbstractExpression,P<:PopMember{T,L,N}} tree = get_tree(member.tree) if !preserve_sharing(typeof(member.tree)) weights.form_connection = 0.0 @@ -152,17 +148,6 @@ function condition_mutate_constant!( return nothing end -function condition_mutate_constant!( - ::Type{<:ParametricExpression}, - weights::AbstractMutationWeights, - member::PopMember, - options::AbstractOptions, - curmaxsize::Int, -) - # Avoid modifying the mutate_constant weight, since - # otherwise we would be mutating constants all the time! 
- return nothing -end # Go through one simulated options.annealing mutation cycle # exp(-delta/T) defines probability of accepting a change @@ -205,7 +190,7 @@ function next_generation( ############################################# local tree while (!successful_mutation) && attempts < max_attempts - tree = copy_node(member.tree) + tree = copy(member.tree) mutation_result = _dispatch_mutations!( tree, @@ -250,7 +235,7 @@ function next_generation( mutation_accepted = false return ( PopMember( - copy_node(member.tree), + copy(member.tree), beforeScore, beforeLoss, options, @@ -279,7 +264,7 @@ function next_generation( mutation_accepted = false return ( PopMember( - copy_node(member.tree), + copy(member.tree), beforeScore, beforeLoss, options, @@ -322,7 +307,7 @@ function next_generation( mutation_accepted = false return ( PopMember( - copy_node(member.tree), + copy(member.tree), beforeScore, beforeLoss, options, @@ -592,10 +577,7 @@ function mutate!( nfeatures, kws..., ) where {T,N<:AbstractExpression{T},P<:PopMember} - tree_size_to_generate = rand(1:curmaxsize) - tree = with_contents( - tree, gen_random_tree_fixed_size(tree_size_to_generate, options, nfeatures, T) - ) + tree = randomize_tree(tree, curmaxsize, options, nfeatures) @recorder recorder["type"] = "randomize" return MutationResult{N,P}(; tree=tree) end diff --git a/src/MutationFunctions.jl b/src/MutationFunctions.jl index 496d584dd..73e0367b0 100644 --- a/src/MutationFunctions.jl +++ b/src/MutationFunctions.jl @@ -9,13 +9,37 @@ using DynamicExpressions: get_contents, with_contents, constructorof, - copy_node, set_node!, count_nodes, has_constants, has_operators using ..CoreModule: AbstractOptions, DATA_TYPE +""" + get_contents_for_mutation(ex::AbstractExpression, rng::AbstractRNG) + +Return the contents of an expression, which can be mutated. +You can overload this function for custom expression types that +need to be mutated in a specific way. + +The second return value is an optional context object that will be +passed to the `with_contents_for_mutation` function. +""" +function get_contents_for_mutation(ex::AbstractExpression, rng::AbstractRNG) + return get_contents(ex), nothing +end + +""" + with_contents_for_mutation(ex::AbstractExpression, new_contents, context) + +Replace the contents of an expression with `new_contents`, where `context` is +the optional object returned by `get_contents_for_mutation`. +You can overload this function for custom expression types that +need to be mutated in a specific way.
+""" +function with_contents_for_mutation(ex::AbstractExpression, new_contents, ::Nothing) + return with_contents(ex, new_contents) +end + """ random_node(tree::AbstractNode; filter::F=Returns(true)) @@ -34,8 +58,8 @@ end """Swap operands in binary operator for ops like pow and divide""" function swap_operands(ex::AbstractExpression, rng::AbstractRNG=default_rng()) - tree = get_contents(ex) - ex = with_contents(ex, swap_operands(tree, rng)) + tree, context = get_contents_for_mutation(ex, rng) + ex = with_contents_for_mutation(ex, swap_operands(tree, rng), context) return ex end function swap_operands(tree::AbstractNode, rng::AbstractRNG=default_rng()) @@ -51,8 +75,8 @@ end function mutate_operator( ex::AbstractExpression{T}, options::AbstractOptions, rng::AbstractRNG=default_rng() ) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = with_contents(ex, mutate_operator(tree, options, rng)) + tree, context = get_contents_for_mutation(ex, rng) + ex = with_contents_for_mutation(ex, mutate_operator(tree, options, rng), context) return ex end function mutate_operator( @@ -79,8 +103,10 @@ function mutate_constant( options::AbstractOptions, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = with_contents(ex, mutate_constant(tree, temperature, options, rng)) + tree, context = get_contents_for_mutation(ex, rng) + ex = with_contents_for_mutation( + ex, mutate_constant(tree, temperature, options, rng), context + ) return ex end function mutate_constant( @@ -125,8 +151,10 @@ function append_random_op( rng::AbstractRNG=default_rng(); makeNewBinOp::Union{Bool,Nothing}=nothing, ) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = with_contents(ex, append_random_op(tree, options, nfeatures, rng; makeNewBinOp)) + tree, context = get_contents_for_mutation(ex, rng) + ex = with_contents_for_mutation( + ex, append_random_op(tree, options, nfeatures, rng; makeNewBinOp), context + ) return ex end function append_random_op( @@ -168,8 +196,10 @@ function insert_random_op( nfeatures::Int, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = with_contents(ex, insert_random_op(tree, options, nfeatures, rng)) + tree, context = get_contents_for_mutation(ex, rng) + ex = with_contents_for_mutation( + ex, insert_random_op(tree, options, nfeatures, rng), context + ) return ex end function insert_random_op( @@ -181,7 +211,7 @@ function insert_random_op( node = rand(rng, NodeSampler(; tree)) choice = rand(rng) makeNewBinOp = choice < options.nbin / (options.nuna + options.nbin) - left = copy_node(node) + left = copy(node) if makeNewBinOp right = make_random_leaf(nfeatures, T, typeof(tree), rng, options) @@ -202,8 +232,10 @@ function prepend_random_op( nfeatures::Int, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = with_contents(ex, prepend_random_op(tree, options, nfeatures, rng)) + tree, context = get_contents_for_mutation(ex, rng) + ex = with_contents_for_mutation( + ex, prepend_random_op(tree, options, nfeatures, rng), context + ) return ex end function prepend_random_op( @@ -215,7 +247,7 @@ function prepend_random_op( node = tree choice = rand(rng) makeNewBinOp = choice < options.nbin / (options.nuna + options.nbin) - left = copy_node(tree) + left = copy(tree) if makeNewBinOp right = make_random_leaf(nfeatures, T, typeof(tree), rng, options) @@ -263,8 +295,10 @@ function delete_random_op!( nfeatures::Int, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = 
with_contents(ex, delete_random_op!(tree, options, nfeatures, rng)) + tree, context = get_contents_for_mutation(ex, rng) + ex = with_contents_for_mutation( + ex, delete_random_op!(tree, options, nfeatures, rng), context + ) return ex end function delete_random_op!( @@ -312,6 +346,30 @@ function delete_random_op!( return tree end +function randomize_tree( + ex::AbstractExpression, + curmaxsize::Int, + options::AbstractOptions, + nfeatures::Int, + rng::AbstractRNG=default_rng(), +) + tree, context = get_contents_for_mutation(ex, rng) + ex = with_contents_for_mutation( + ex, randomize_tree(tree, curmaxsize, options, nfeatures, rng), context + ) + return ex +end +function randomize_tree( + ::AbstractExpressionNode{T}, + curmaxsize::Int, + options::AbstractOptions, + nfeatures::Int, + rng::AbstractRNG=default_rng(), +) where {T<:DATA_TYPE} + tree_size_to_generate = rand(rng, 1:curmaxsize) + return gen_random_tree_fixed_size(tree_size_to_generate, options, nfeatures, T, rng) +end + """Create a random equation by appending random operators""" function gen_random_tree( length::Int, @@ -353,11 +411,14 @@ end function crossover_trees( ex1::E, ex2::E, rng::AbstractRNG=default_rng() ) where {T,E<:AbstractExpression{T}} - tree1 = get_contents(ex1) - tree2 = get_contents(ex2) + if ex1 === ex2 + error("Attempted to crossover the same expression!") + end + tree1, context1 = get_contents_for_mutation(ex1, rng) + tree2, context2 = get_contents_for_mutation(ex2, rng) out1, out2 = crossover_trees(tree1, tree2, rng) - ex1 = with_contents(ex1, out1) - ex2 = with_contents(ex2, out2) + ex1 = with_contents_for_mutation(ex1, out1, context1) + ex2 = with_contents_for_mutation(ex2, out2, context2) return ex1, ex2 end @@ -365,23 +426,26 @@ end function crossover_trees( tree1::N, tree2::N, rng::AbstractRNG=default_rng() ) where {T,N<:AbstractExpressionNode{T}} - tree1 = copy_node(tree1) - tree2 = copy_node(tree2) + if tree1 === tree2 + error("Attempted to crossover the same tree!") + end + tree1 = copy(tree1) + tree2 = copy(tree2) node1, parent1, side1 = random_node_and_parent(tree1, rng) node2, parent2, side2 = random_node_and_parent(tree2, rng) - node1 = copy_node(node1) + node1 = copy(node1) if side1 == 'l' - parent1.l = copy_node(node2) + parent1.l = copy(node2) # tree1 now contains this. elseif side1 == 'r' - parent1.r = copy_node(node2) + parent1.r = copy(node2) # tree1 now contains this. else # 'n' # This means that there is no parent2. 
- tree1 = copy_node(node2) + tree1 = copy(node2) end if side2 == 'l' @@ -408,8 +472,8 @@ function get_two_nodes_without_loop(tree::AbstractNode, rng::AbstractRNG; max_at end function form_random_connection!(ex::AbstractExpression, rng::AbstractRNG=default_rng()) - tree = get_contents(ex) - return with_contents(ex, form_random_connection!(tree, rng)) + tree, context = get_contents_for_mutation(ex, rng) + return with_contents_for_mutation(ex, form_random_connection!(tree, rng), context) end function form_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rng()) if length(tree) < 5 @@ -432,8 +496,8 @@ function form_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rn end function break_random_connection!(ex::AbstractExpression, rng::AbstractRNG=default_rng()) - tree = get_contents(ex) - return with_contents(ex, break_random_connection!(tree, rng)) + tree, context = get_contents_for_mutation(ex, rng) + return with_contents_for_mutation(ex, break_random_connection!(tree, rng), context) end function break_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rng()) tree.degree == 0 && return tree @@ -451,9 +515,9 @@ function is_valid_rotation_node(node::AbstractNode) end function randomly_rotate_tree!(ex::AbstractExpression, rng::AbstractRNG=default_rng()) - tree = get_contents(ex) + tree, context = get_contents_for_mutation(ex, rng) rotated_tree = randomly_rotate_tree!(tree, rng) - return with_contents(ex, rotated_tree) + return with_contents_for_mutation(ex, rotated_tree, context) end function randomly_rotate_tree!(tree::AbstractNode, rng::AbstractRNG=default_rng()) num_rotation_nodes = count(is_valid_rotation_node, tree) diff --git a/src/Options.jl b/src/Options.jl index 701fde2ff..4c84eb56e 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -510,11 +510,13 @@ $(OPTION_DESCRIPTIONS) # Not search options; just construction options: define_helper_functions::Bool=true, deprecated_return_state=nothing, - # Deprecated args: + ######################################### + # Deprecated args: ###################### fast_cycle::Bool=false, npopulations::Union{Nothing,Integer}=nothing, npop::Union{Nothing,Integer}=nothing, kws..., + ######################################### ) for k in keys(kws) !haskey(deprecated_options_mapping, k) && error("Unknown keyword argument: $k") @@ -733,7 +735,7 @@ $(OPTION_DESCRIPTIONS) options = Options{ typeof(complexity_mapping), - operator_specialization(typeof(operators)), + operator_specialization(typeof(operators), expression_type), node_type, expression_type, typeof(expression_options), diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index ba20fe92c..fa8a0035b 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -113,13 +113,16 @@ function ComplexityMapping( ) end -# Controls level of specialization we compile -function operator_specialization end -if VERSION >= v"1.10.0-DEV.0" - @eval operator_specialization(::Type{<:OperatorEnum}) = OperatorEnum -else - @eval operator_specialization(O::Type{<:OperatorEnum}) = O -end +""" +Controls level of specialization we compile into `Options`. + +Overload if needed for custom expression types. 
+""" +operator_specialization( + ::Type{O}, ::Type{<:AbstractExpression} +) where {O<:AbstractOperatorEnum} = O +@unstable operator_specialization(::Type{<:OperatorEnum}, ::Type{<:AbstractExpression}) = + OperatorEnum """ AbstractOptions diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl new file mode 100644 index 000000000..f98a1de08 --- /dev/null +++ b/src/ParametricExpression.jl @@ -0,0 +1,187 @@ +""" +Note that ParametricExpression is defined within DynamicExpressions.jl, +this file just adds custom behavior for SymbolicRegression.jl, where needed. +""" +module ParametricExpressionModule + +using DynamicExpressions: + DynamicExpressions as DE, + AbstractExpression, + ParametricExpression, + ParametricNode, + get_metadata, + with_metadata, + get_contents, + with_contents, + get_tree, + eval_tree_array +using StatsBase: StatsBase +using Random: default_rng, AbstractRNG + +using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, AbstractMutationWeights +using ..PopMemberModule: PopMember +using ..InterfaceDynamicExpressionsModule: expected_array_type +using ..LossFunctionsModule: LossFunctionsModule as LF +using ..ExpressionBuilderModule: ExpressionBuilderModule as EB +using ..MutateModule: MutateModule as MM +using ..MutationFunctionsModule: MutationFunctionsModule as MF +using ..ConstantOptimizationModule: ConstantOptimizationModule as CO + +function EB.extra_init_params( + ::Type{E}, + prototype::Union{Nothing,ParametricExpression}, + options::AbstractOptions, + dataset::Dataset{T}, + ::Val{embed}, +) where {T,embed,E<:ParametricExpression} + num_params = options.expression_options.max_parameters + num_classes = length(unique(dataset.extra.classes)) + parameter_names = embed ? ["p$i" for i in 1:num_params] : nothing + _parameters = if prototype === nothing + randn(T, (num_params, num_classes)) + else + copy(get_metadata(prototype).parameters) + end + return (; parameters=_parameters, parameter_names) +end +function EB.consistency_checks(options::AbstractOptions, prototype::ParametricExpression) + @assert( + options.expression_type <: ParametricExpression, + "Need prototype to be of type $(options.expression_type), but got $(prototype)::$(typeof(prototype))" + ) + if get_metadata(prototype).parameter_names !== nothing + @assert( + length(get_metadata(prototype).parameter_names) == + options.expression_options.max_parameters, + "Mismatch between options.expression_options.max_parameters=$(options.expression_options.max_parameters) and prototype.metadata.parameter_names=$(get_metadata(prototype).parameter_names)" + ) + end + @assert size(get_metadata(prototype).parameters, 1) == + options.expression_options.max_parameters + return nothing +end + +function DE.eval_tree_array( + tree::ParametricExpression, + X::AbstractMatrix, + classes::AbstractVector{<:Integer}, + options::AbstractOptions; + kws..., +) + A = expected_array_type(X, typeof(tree)) + out, complete = DE.eval_tree_array( + tree, + X, + classes, + DE.get_operators(tree, options); + turbo=options.turbo, + bumper=options.bumper, + kws..., + ) + return out::A, complete::Bool +end +function LF.eval_tree_dispatch( + tree::ParametricExpression, dataset::Dataset, options::AbstractOptions, idx +) + A = expected_array_type(dataset.X, typeof(tree)) + out, complete = DE.eval_tree_array( + tree, + LF.maybe_getindex(dataset.X, :, idx), + LF.maybe_getindex(dataset.extra.classes, idx), + options.operators, + ) + return out::A, complete::Bool +end + +function MM.condition_mutate_constant!( + ::Type{<:ParametricExpression}, + 
weights::AbstractMutationWeights, + member::PopMember, + options::AbstractOptions, + curmaxsize::Int, +) + # Avoid modifying the mutate_constant weight, since + # otherwise we would be mutating constants all the time! + return nothing +end +function MF.make_random_leaf( + nfeatures::Int, + ::Type{T}, + ::Type{N}, + rng::AbstractRNG=default_rng(), + options::Union{AbstractOptions,Nothing}=nothing, +) where {T<:DATA_TYPE,N<:ParametricNode} + choice = rand(rng, 1:3) + if choice == 1 + return ParametricNode(; val=randn(rng, T)) + elseif choice == 2 + return ParametricNode(T; feature=rand(rng, 1:nfeatures)) + else + tree = ParametricNode{T}() + tree.val = zero(T) + tree.degree = 0 + tree.feature = 0 + tree.constant = false + tree.is_parameter = true + tree.parameter = rand( + rng, UInt16(1):UInt16(options.expression_options.max_parameters) + ) + return tree + end +end + +function MF.crossover_trees( + ex1::ParametricExpression{T}, + ex2::ParametricExpression{T}, + rng::AbstractRNG=default_rng(), +) where {T} + tree1 = get_contents(ex1) + tree2 = get_contents(ex2) + out1, out2 = MF.crossover_trees(tree1, tree2, rng) + ex1 = with_contents(ex1, out1) + ex2 = with_contents(ex2, out2) + + # We also randomly share parameters + nparams1 = size(get_metadata(ex1).parameters, 1) + nparams2 = size(get_metadata(ex2).parameters, 1) + num_params_switch = min(nparams1, nparams2) + idx_to_switch = StatsBase.sample( + rng, 1:num_params_switch, num_params_switch; replace=false + ) + for param_idx in idx_to_switch + # TODO: Ensure no issues from aliasing here + ex2_params = get_metadata(ex2).parameters[param_idx, :] + get_metadata(ex2).parameters[param_idx, :] .= get_metadata(ex1).parameters[ + param_idx, :, + ] + get_metadata(ex1).parameters[param_idx, :] .= ex2_params + end + + return ex1, ex2 +end + +function CO.count_constants_for_optimization(ex::ParametricExpression) + return CO.count_scalar_constants(get_tree(ex)) + length(get_metadata(ex).parameters) +end + +function MF.mutate_constant( + ex::ParametricExpression{T}, + temperature, + options::AbstractOptions, + rng::AbstractRNG=default_rng(), +) where {T<:DATA_TYPE} + if rand(rng, Bool) + # Normal mutation of inner constant + tree = get_contents(ex) + return with_contents(ex, MF.mutate_constant(tree, temperature, options, rng)) + else + # Mutate parameters + parameter_index = rand(rng, 1:(options.expression_options.max_parameters)) + # We mutate all the parameters at once + factor = MF.mutate_factor(T, temperature, options, rng) + get_metadata(ex).parameters[parameter_index, :] .*= factor + return ex + end +end + +end diff --git a/src/Population.jl b/src/Population.jl index 3a544730b..d475da168 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -3,7 +3,7 @@ module PopulationModule using StatsBase: StatsBase using DispatchDoctor: @unstable using DynamicExpressions: AbstractExpression, string_tree -using ..CoreModule: AbstractOptions, Dataset, RecordType, DATA_TYPE, LOSS_TYPE +using ..CoreModule: AbstractOptions, Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: score_func, update_baseline_loss! 
using ..AdaptiveParsimonyModule: RunningSearchStatistics @@ -112,9 +112,7 @@ function best_of_sample( options::AbstractOptions, ) where {T,L,N} sample = sample_pop(pop, options) - return _best_of_sample( - sample.members, running_search_statistics, options - )::PopMember{T,L,N} + return copy(_best_of_sample(sample.members, running_search_statistics, options)) end function _best_of_sample( members::Vector{P}, @@ -168,7 +166,7 @@ const CACHED_WEIGHTS = PerThreadCache{Dict{Tuple{Int,Float32},typeof(test_weights)}}() end -@unstable function get_tournament_selection_weights(@nospecialize(options::AbstractOptions)) +@unstable function get_tournament_selection_weights(@nospecialize(options::Options)) n = options.tournament_selection_n p = options.tournament_selection_p # Computing the weights for the tournament becomes quite expensive, diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index f1ba1cfb3..23358d9dc 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -681,7 +681,7 @@ function update_hall_of_fame!( ) where {PM<:PopMember} for member in members size = compute_complexity(member, options) - valid_size = 0 < size < options.maxsize + MAX_DEGREE + valid_size = 0 < size <= options.maxsize if !valid_size continue end diff --git a/src/SingleIteration.jl b/src/SingleIteration.jl index d15e7914c..90edb8ee7 100644 --- a/src/SingleIteration.jl +++ b/src/SingleIteration.jl @@ -105,6 +105,7 @@ function optimize_and_simplify_population( # Note: we have to turn off this threading loop due to Enzyme, since we need # to manually allocate a new task with a larger stack for Enzyme. should_thread = !(options.deterministic) && !(isa(options.autodiff_backend, AutoEnzyme)) + @threads_if should_thread for j in 1:(pop.n) if options.should_simplify tree = pop.members[j].tree diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 02070f324..53afae3af 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -12,7 +12,7 @@ export Population, ParametricNode, Expression, ParametricExpression, - StructuredExpression, + TemplateExpression, NodeSampler, AbstractExpression, AbstractExpressionNode, @@ -51,6 +51,8 @@ export Population, get_tree, get_contents, get_metadata, + with_contents, + with_metadata, #Operators plus, @@ -91,7 +93,6 @@ using DynamicExpressions: ParametricNode, Expression, ParametricExpression, - StructuredExpression, NodeSampler, AbstractExpression, AbstractExpressionNode, @@ -124,7 +125,9 @@ using DynamicExpressions: node_type, get_tree, get_contents, - get_metadata + get_metadata, + with_contents, + with_metadata using DynamicExpressions: with_type_parameters @reexport using LossFunctions: MarginLoss, @@ -217,6 +220,8 @@ using DispatchDoctor: @stable include("Migration.jl") include("SearchUtils.jl") include("ExpressionBuilder.jl") + include("TemplateExpression.jl") + include("ParametricExpression.jl") end using .CoreModule: @@ -231,6 +236,7 @@ using .CoreModule: Options, AbstractMutationWeights, MutationWeights, + is_weighted, sample_mutation, plus, sub, @@ -308,6 +314,7 @@ using .SearchUtilsModule: save_to_file, get_cur_maxsize, update_hall_of_fame! +using .TemplateExpressionModule: TemplateExpression using .ExpressionBuilderModule: embed_metadata, strip_metadata @stable default_mode = "disable" begin @@ -410,7 +417,7 @@ which is useful for debugging and profiling. 
""" function equation_search( X::AbstractMatrix{T}, - y::AbstractMatrix{T}; + y::AbstractMatrix; niterations::Int=10, weights::Union{AbstractMatrix{T},AbstractVector{T},Nothing}=nothing, options::AbstractOptions=Options(), @@ -481,17 +488,8 @@ function equation_search( end function equation_search( - X::AbstractMatrix{T1}, y::AbstractMatrix{T2}; kw... -) where {T1<:DATA_TYPE,T2<:DATA_TYPE} - U = promote_type(T1, T2) - return equation_search( - convert(AbstractMatrix{U}, X), convert(AbstractMatrix{U}, y); kw... - ) -end - -function equation_search( - X::AbstractMatrix{T1}, y::AbstractVector{T2}; kw... -) where {T1<:DATA_TYPE,T2<:DATA_TYPE} + X::AbstractMatrix{T}, y::AbstractVector; kw... +) where {T<:DATA_TYPE} return equation_search(X, reshape(y, (1, size(y, 1))); kw..., v_dim_out=Val(1)) end @@ -1073,7 +1071,7 @@ end ) num_evals += evals_from_optimize if options.batching - for i_member in 1:(options.maxsize + MAX_DEGREE) + for i_member in 1:(options.maxsize) score, result_loss = score_func(dataset, best_seen.members[i_member], options) best_seen.members[i_member].score = score best_seen.members[i_member].loss = result_loss diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl new file mode 100644 index 000000000..d88c07dcc --- /dev/null +++ b/src/TemplateExpression.jl @@ -0,0 +1,366 @@ +module TemplateExpressionModule + +using Random: AbstractRNG +using DispatchDoctor: @unstable +using DynamicExpressions: + DynamicExpressions as DE, + AbstractStructuredExpression, + AbstractExpressionNode, + AbstractExpression, + AbstractOperatorEnum, + OperatorEnum, + Expression, + Metadata, + get_contents, + with_contents, + get_metadata, + get_operators, + get_variable_names, + get_tree, + node_type, + eval_tree_array, + count_nodes +using DynamicExpressions.InterfacesModule: + ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments + +using ..CoreModule: AbstractOptions, Dataset, CoreModule as CM, AbstractMutationWeights +using ..ConstantOptimizationModule: ConstantOptimizationModule as CO +using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE +using ..MutationFunctionsModule: MutationFunctionsModule as MF +using ..ExpressionBuilderModule: ExpressionBuilderModule as EB +using ..DimensionalAnalysisModule: DimensionalAnalysisModule as DA +using ..CheckConstraintsModule: CheckConstraintsModule as CC +using ..ComplexityModule: ComplexityModule +using ..LossFunctionsModule: LossFunctionsModule as LF +using ..MutateModule: MutateModule as MM +using ..PopMemberModule: PopMember + +""" + TemplateExpression{T,F,N,E,TS,C,D} <: AbstractStructuredExpression{T,F,N,E,D} + +A symbolic expression that allows the combination of multiple sub-expressions +in a structured way, with constraints on variable usage. + +`TemplateExpression` is designed for symbolic regression tasks where +domain-specific knowledge or constraints must be imposed on the model's structure. + +# Constructor + +- `TemplateExpression(trees; structure, operators, variable_names, variable_mapping)` + - `trees`: A `NamedTuple` holding the sub-expressions (e.g., `f = Expression(...)`, `g = Expression(...)`). + - `structure`: A function that defines how the sub-expressions are combined. This should have one method + that takes `trees` as input and returns a single `Expression` node, and another method which takes + a `NamedTuple` of `Vector` (representing the numerical results of each sub-expression) and returns + a single vector after combining them. 
- `operators`: An `OperatorEnum` that defines the allowed operators for the sub-expressions. + - `variable_names`: An optional `Vector` of `String` that defines the names of the variables in the dataset. + - `variable_mapping`: A `NamedTuple` that defines which variables each sub-expression is allowed to access. + For example, requesting `f(x1, x2)` and `g(x3)` would be equivalent to `(; f=[1, 2], g=[3])`. + +# Example + +Let's create an example `TemplateExpression` that combines two sub-expressions `f(x1, x2)` and `g(x3)`: + +```julia +# Define operators and variable names +operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) +variable_names = ["x1", "x2", "x3"] + +# Create sub-expressions +x1 = Expression(Node{Float64}(; feature=1); operators, variable_names) +x2 = Expression(Node{Float64}(; feature=2); operators, variable_names) +x3 = Expression(Node{Float64}(; feature=3); operators, variable_names) + +# Define structure function for symbolic and numerical evaluation +function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:Expression}}}) + return sin(nt.f) + nt.g * nt.g +end +function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}}) + return @. sin(nt.f) + nt.g * nt.g +end + +# Define variable constraints (if desired) +variable_mapping = (; f=[1, 2], g=[3]) + +# Create TemplateExpression +example_expr = (; f=x1, g=x3) +st_expr = TemplateExpression( + example_expr; + structure=my_structure, operators, variable_names, variable_mapping +) +``` + +When fitting a model in SymbolicRegression.jl, you would provide the `TemplateExpression` +as the `expression_type` argument, and then pass `expression_options=(; structure=my_structure, variable_mapping=variable_mapping)` +as additional options. The `variable_mapping` will constrain `f` to only have access to `x1` and `x2`, +and `g` to only have access to `x3`. +""" +struct TemplateExpression{ T, F<:Function, N<:AbstractExpressionNode{T}, E<:Expression{T,N}, # TODO: Generalize this TS<:NamedTuple{<:Any,<:NTuple{<:Any,E}}, C<:NamedTuple{<:Any,<:NTuple{<:Any,Vector{Int}}}, # The constraints # TODO: No need for this to be a parametric type D<:@NamedTuple{ structure::F, operators::O, variable_names::V, variable_mapping::C } where {O,V}, } <: AbstractStructuredExpression{T,F,N,E,D} trees::TS metadata::Metadata{D} function TemplateExpression( trees::TS, metadata::Metadata{D} ) where { TS, F<:Function, C<:NamedTuple{<:Any,<:NTuple{<:Any,Vector{Int}}}, D<:@NamedTuple{ structure::F, operators::O, variable_names::V, variable_mapping::C } where {O,V}, } E = typeof(first(values(trees))) N = node_type(E) return new{eltype(N),F,N,E,TS,C,D}(trees, metadata) end end +function TemplateExpression( trees::NamedTuple{<:Any,<:NTuple{<:Any,<:AbstractExpression}}; structure::F, operators::Union{AbstractOperatorEnum,Nothing}=nothing, variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, variable_mapping::NamedTuple{<:Any,<:NTuple{<:Any,Vector{Int}}}, ) where {F<:Function} @assert length(trees) == length(variable_mapping) if variable_names !== nothing # TODO: Should this be removed?
+ @assert Set(eachindex(variable_names)) == + Set(Iterators.flatten(values(variable_mapping))) + end + @assert keys(trees) == keys(variable_mapping) + example_tree = first(values(trees))::AbstractExpression + operators = get_operators(example_tree, operators) + variable_names = get_variable_names(example_tree, variable_names) + metadata = (; structure, operators, variable_names, variable_mapping) + return TemplateExpression(trees, Metadata(metadata)) +end + +@unstable DE.constructorof(::Type{<:TemplateExpression}) = TemplateExpression + +@implements( + ExpressionInterface{all_ei_methods_except(())}, TemplateExpression, [Arguments()] +) + +function EB.create_expression( + t::AbstractExpressionNode{T}, + options::AbstractOptions, + dataset::Dataset{T,L}, + ::Type{<:AbstractExpressionNode}, + ::Type{E}, + ::Val{embed}=Val(false), +) where {T,L,embed,E<:TemplateExpression} + function_keys = keys(options.expression_options.variable_mapping) + + # NOTE: We need to copy over the operators so we can call the structure function + operators = options.operators + variable_names = embed ? dataset.variable_names : nothing + inner_expressions = ntuple( + _ -> Expression(copy(t); operators, variable_names), length(function_keys) + ) + # TODO: Generalize to other inner expression types + return DE.constructorof(E)( + NamedTuple{function_keys}(inner_expressions); + EB.init_params(options, dataset, nothing, Val(embed))..., + ) +end +function EB.extra_init_params( + ::Type{E}, + prototype::Union{Nothing,AbstractExpression}, + options::AbstractOptions, + dataset::Dataset{T}, + ::Val{embed}, +) where {T,embed,E<:TemplateExpression} + # We also need to include the operators here to be consistent with `create_expression`. + return (; options.operators, options.expression_options...) +end +function EB.sort_params(params::NamedTuple, ::Type{<:TemplateExpression}) + return (; + params.structure, params.operators, params.variable_names, params.variable_mapping + ) +end + +function ComplexityModule.compute_complexity( + tree::TemplateExpression, options::AbstractOptions; break_sharing=Val(false) +) + # Rather than including the complexity of the combined tree, + # we only sum the complexity of each inner expression, which will be smaller. + return sum( + ex -> ComplexityModule.compute_complexity(ex, options; break_sharing), + values(get_contents(tree)), + ) +end + +function DE.string_tree( + tree::TemplateExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws... +) + raw_contents = get_contents(tree) + function_keys = keys(raw_contents) + inner_strings = NamedTuple{function_keys}( + map(ex -> DE.string_tree(ex, operators; kws...), values(raw_contents)) + ) + # TODO: Make a fallback function in case the structure function is undefined. + return get_metadata(tree).structure(inner_strings) +end +function DE.eval_tree_array( + tree::TemplateExpression{T}, + cX::AbstractMatrix{T}, + operators::Union{AbstractOperatorEnum,Nothing}=nothing; + kws..., +) where {T} + raw_contents = get_contents(tree) + + # Raw numerical results of each inner expression: + outs = map(ex -> DE.eval_tree_array(ex, cX, operators; kws...), values(raw_contents)) + + # Combine them using the structure function: + results = NamedTuple{keys(raw_contents)}(map(first, outs)) + return get_metadata(tree).structure(results), all(last, outs) +end +function (ex::TemplateExpression)( + X, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws... +) + # TODO: Why do we need to do this? It should automatically handle this! 
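+    # Note that, like `eval_tree_array`, calling the expression returns the
+    # tuple `(output, completed)` rather than the output alone.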
+ return DE.eval_tree_array(ex, X, operators; kws...) +end +@unstable IDE.expected_array_type(::AbstractMatrix, ::Type{<:TemplateExpression}) = Any + +function DA.violates_dimensional_constraints( + tree::TemplateExpression, dataset::Dataset, options::AbstractOptions +) + @assert dataset.X_units === nothing && dataset.y_units === nothing + return false +end +function MM.condition_mutation_weights!( + weights::AbstractMutationWeights, member::P, options::AbstractOptions, curmaxsize::Int +) where {T,L,N<:TemplateExpression,P<:PopMember{T,L,N}} + # HACK TODO + return nothing +end + +""" +We need full specialization for constrained expressions, as they rely on subexpressions being combined. +""" +function CM.operator_specialization( + ::Type{O}, ::Type{<:TemplateExpression} +) where {O<:OperatorEnum} + return O +end + +""" +We pick a random subexpression to mutate, +and also return the symbol we mutated on so that we can put it back together later. +""" +function MF.get_contents_for_mutation(ex::TemplateExpression, rng::AbstractRNG) + raw_contents = get_contents(ex) + function_keys = keys(raw_contents) + + # Sample weighted by number of nodes in each subexpression + num_nodes = map(count_nodes, values(raw_contents)) + weights = map(Base.Fix2(/, sum(num_nodes)), num_nodes) + cumsum_weights = cumsum(weights) + rand_val = rand(rng) + idx = findfirst(Base.Fix2(>=, rand_val), cumsum_weights)::Int + + key_to_mutate = function_keys[idx] + return raw_contents[key_to_mutate], key_to_mutate +end + +"""See `get_contents_for_mutation(::TemplateExpression, ::AbstractRNG)`.""" +function MF.with_contents_for_mutation( + ex::TemplateExpression, new_inner_contents, context::Symbol +) + raw_contents = get_contents(ex) + raw_contents_keys = keys(raw_contents) + new_contents = NamedTuple{raw_contents_keys}( + ntuple(length(raw_contents_keys)) do i + if raw_contents_keys[i] == context + new_inner_contents + else + raw_contents[raw_contents_keys[i]] + end + end, + ) + return with_contents(ex, new_contents) +end + +"""We combine the operators of each inner expression.""" +function DE.combine_operators( + ex::TemplateExpression{T,N}, operators::Union{AbstractOperatorEnum,Nothing}=nothing +) where {T,N} + raw_contents = get_contents(ex) + function_keys = keys(raw_contents) + new_contents = NamedTuple{function_keys}( + map(Base.Fix2(DE.combine_operators, operators), values(raw_contents)) + ) + return with_contents(ex, new_contents) +end + +"""We simplify each inner expression.""" +function DE.simplify_tree!( + ex::TemplateExpression{T,N}, operators::Union{AbstractOperatorEnum,Nothing}=nothing +) where {T,N} + raw_contents = get_contents(ex) + function_keys = keys(raw_contents) + new_contents = NamedTuple{function_keys}( + map(Base.Fix2(DE.simplify_tree!, operators), values(raw_contents)) + ) + return with_contents(ex, new_contents) +end + +function CO.count_constants_for_optimization(ex::TemplateExpression) + return sum(CO.count_constants_for_optimization, values(get_contents(ex))) +end + +function CC.check_constraints( + ex::TemplateExpression, + options::AbstractOptions, + maxsize::Int, + cursize::Union{Int,Nothing}=nothing, +)::Bool + raw_contents = get_contents(ex) + variable_mapping = get_metadata(ex).variable_mapping + + # First, we check the variable constraints at the top level: + has_invalid_variables = any(keys(raw_contents)) do key + tree = raw_contents[key] + allowed_variables = variable_mapping[key] + contains_other_features_than(tree, allowed_variables) + end + if has_invalid_variables + return false + 
end + + # We also check the combined complexity: + ((cursize === nothing) ? ComplexityModule.compute_complexity(ex, options) : cursize) > + maxsize && return false + + # Then, we check other constraints for inner expressions: + return all( + t -> CC.check_constraints(t, options, maxsize, nothing), values(raw_contents) + ) + # TODO: The concept of `cursize` doesn't really make sense here. +end +function contains_other_features_than(tree::AbstractExpression, features) + return contains_other_features_than(get_tree(tree), features) +end +function contains_other_features_than(tree::AbstractExpressionNode, features) + any(tree) do node + node.degree == 0 && !node.constant && node.feature ∉ features + end +end + +# TODO: Add custom behavior to adjust what feature nodes can be generated + +end diff --git a/src/Utils.jl b/src/Utils.jl index 473845651..da67bcf4d 100644 --- a/src/Utils.jl +++ b/src/Utils.jl @@ -208,10 +208,10 @@ end function _get_thread_cache(cache::PerThreadCache{T}) where {T} if cache.num_threads[] < Threads.nthreads() Base.@lock cache.lock begin - # The reason we have this extra `.len[]` parameter is to avoid + # The reason we have this extra `.num_threads[]` parameter is to avoid # a race condition between a thread resizing the array concurrent # to the check above. Basically we want to make sure the array is - # always big enough by the time we get to using it. Since `.len[]` + # always big enough by the time we get to using it. Since `.num_threads[]` # is set last, we can safely use the array. if cache.num_threads[] < Threads.nthreads() resize!(cache.x, Threads.nthreads()) diff --git a/test/runtests.jl b/test/runtests.jl index e2ebfbd4a..fcc2c5b08 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -110,11 +110,7 @@ end end include("test_units.jl") - -@testitem "Dataset" tags = [:part3] begin - include("test_dataset.jl") -end - +include("test_dataset.jl") include("test_mixed.jl") @testitem "Testing fast-cycle and custom variable names" tags = [:part2] begin @@ -135,6 +131,7 @@ end ENV["SYMBOLIC_REGRESSION_IS_TESTING"] = "true" include("../examples/parameterized_function.jl") end +include("test_template_expression.jl") @testitem "Testing whether the recorder works." 
tags = [:part3] begin include("test_recorder.jl") diff --git a/test/test_dataset.jl b/test/test_dataset.jl index 9fdbcfd74..505471979 100644 --- a/test/test_dataset.jl +++ b/test/test_dataset.jl @@ -1,16 +1,29 @@ -using SymbolicRegression -using DispatchDoctor: allow_unstable - -@testset "Dataset construction" begin +@testitem "Dataset construction" tags = [:part3] begin + using SymbolicRegression # Promotion of types: dataset = Dataset(randn(3, 32), randn(Float32, 32); weights=randn(Float32, 32)) - @test typeof(dataset.y) == Array{Float64,1} - @test typeof(dataset.weights) == Array{Float64,1} + + # Will not automatically convert: + @test typeof(dataset.X) == Array{Float64,2} + @test typeof(dataset.y) == Array{Float32,1} + @test typeof(dataset.weights) == Array{Float32,1} end -@testset "With deprecated kwarg" begin +@testitem "With deprecated kwarg" tags = [:part3] begin + using SymbolicRegression + using DispatchDoctor: allow_unstable dataset = allow_unstable() do Dataset(randn(ComplexF32, 3, 32), randn(ComplexF32, 32); loss_type=Float64) end @test dataset isa Dataset{ComplexF32,Float64} end + +@testitem "vector output" tags = [:part3] begin + using SymbolicRegression + + X = randn(Float64, 3, 32) + y = [ntuple(_ -> randn(Float64), 3) for _ in 1:32] + dataset = Dataset(X, y) + @test dataset isa Dataset{Float64,Float64} + @test dataset.y isa Vector{NTuple{3,Float64}} +end diff --git a/test/test_pretty_printing.jl b/test/test_pretty_printing.jl index 42a28d14e..56cfa1f6b 100644 --- a/test/test_pretty_printing.jl +++ b/test/test_pretty_printing.jl @@ -63,11 +63,7 @@ end .exists[6] = false .members[6] = undef .exists[7] = false - .members[7] = undef - .exists[8] = false - .members[8] = undef - .exists[9] = false - .members[9] = undef" + .members[7] = undef" @test s_hof == true_s end diff --git a/test/test_search_statistics.jl b/test/test_search_statistics.jl index 35a9b0175..cc2f5360a 100644 --- a/test/test_search_statistics.jl +++ b/test/test_search_statistics.jl @@ -13,7 +13,7 @@ end normalize_frequencies!(statistics) -@test sum(statistics.frequencies) == 1022 +@test sum(statistics.frequencies) == 1020 @test sum(statistics.normalized_frequencies) ≈ 1.0 @test statistics.normalized_frequencies[5] > statistics.normalized_frequencies[15] diff --git a/test/test_template_expression.jl b/test/test_template_expression.jl new file mode 100644 index 000000000..2187ea334 --- /dev/null +++ b/test/test_template_expression.jl @@ -0,0 +1,161 @@ +@testitem "Basic utility of the TemplateExpression" tags = [:part3] begin + using SymbolicRegression + using SymbolicRegression: SymbolicRegression as SR + using SymbolicRegression.CheckConstraintsModule: check_constraints + using DynamicExpressions: OperatorEnum + + options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + operators = options.operators + variable_names = (i -> "x$i").(1:3) + x1, x2, x3 = + (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) + + # For combining expressions to a single expression: + my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractString}}}) = + "sin($(nt.f)) + $(nt.g)^2" + my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}}) = + @. 
sin(nt.f) + nt.g^2
+    my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:Expression}}}) =
+        sin(nt.f) + nt.g * nt.g
+
+    variable_mapping = (; f=[1, 2], g=[3])
+    st_expr = TemplateExpression(
+        (; f=x1, g=cos(x3));
+        structure=my_structure,
+        operators,
+        variable_names,
+        variable_mapping,
+    )
+    @test string_tree(st_expr) == "sin(x1) + cos(x3)^2"
+    operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(cos, sin))
+
+    # Passing a different `OperatorEnum` changes how the expression is interpreted,
+    # since nodes store operator indices rather than the functions themselves:
+    @test string_tree(st_expr, operators) == "sin(x1) + sin(x3)^2"
+
+    # We can evaluate with this too:
+    cX = [1.0 2.0; 3.0 4.0; 5.0 6.0]
+    out, completed = st_expr(cX)
+    @test completed
+    @test out ≈ [sin(1.0) + cos(5.0)^2, sin(2.0) + cos(6.0)^2]
+
+    # We can also check that the constraints are satisfied:
+    @test check_constraints(st_expr, options, 100)
+
+    # Violating the variable mapping causes `check_constraints` to fail:
+    new_expr = with_contents(st_expr, (; f=x3, g=cos(x3)))
+    @test !check_constraints(new_expr, options, 100)
+    new_expr = with_contents(st_expr, (; f=x2, g=cos(x3)))
+    @test check_constraints(new_expr, options, 100)
+    new_expr = with_contents(st_expr, (; f=x2, g=cos(x1)))
+    @test !check_constraints(new_expr, options, 100)
+
+    # The complexity is the summed size of the individual sub-expressions:
+    new_expr = with_contents(st_expr, (; f=x2, g=cos(x3)))
+
+    @test compute_complexity(new_expr, options) == 3
+    @test check_constraints(new_expr, options, 3)
+    @test !check_constraints(new_expr, options, 2)
+end
+@testitem "Expression interface" tags = [:part3] begin
+    using SymbolicRegression
+    using DynamicExpressions: OperatorEnum
+    using DynamicExpressions.InterfacesModule: Interfaces, ExpressionInterface
+
+    operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos))
+    variable_names = (i -> "x$i").(1:3)
+    x1, x2, x3 =
+        (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3)
+
+    # For combining expressions into a single expression:
+    my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractString}}}) =
+        "sin($(nt.f)) + $(nt.g)^2"
+    my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}}) =
+        @. 
sin(nt.f) + nt.g^2 + my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:Expression}}}) = + sin(nt.f) + nt.g * nt.g + + variable_mapping = (; f=[1, 2], g=[3]) + st_expr = TemplateExpression( + (; f=x1, g=x3); structure=my_structure, operators, variable_names, variable_mapping + ) + @test Interfaces.test(ExpressionInterface, TemplateExpression, [st_expr]) +end +@testitem "Utilising TemplateExpression to build vector expressions" tags = [:part3] begin + using SymbolicRegression + using Random: rand + + # Define the structure function, which returns a tuple: + function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractString}}}) + return "( $(nt.f) + $(nt.g1), $(nt.f) + $(nt.g2), $(nt.f) + $(nt.g3) )" + end + function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}}) + return map( + i -> (nt.f[i] + nt.g1[i], nt.f[i] + nt.g2[i], nt.f[i] + nt.g3[i]), + eachindex(nt.f), + ) + end + + # Set up operators and variable names + options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + variable_names = (i -> "x$i").(1:3) + + # Create expressions + x1, x2, x3 = + (i -> Expression(Node(Float64; feature=i); options.operators, variable_names)).(1:3) + + # Test with vector inputs: + nt_vector = NamedTuple{(:f, :g1, :g2, :g3)}((1:3, 4:6, 7:9, 10:12)) + @test my_structure(nt_vector) == [(5, 8, 11), (7, 10, 13), (9, 12, 15)] + + # And string inputs: + nt_string = NamedTuple{(:f, :g1, :g2, :g3)}(("x1", "x2", "x3", "x2")) + @test my_structure(nt_string) == "( x1 + x2, x1 + x3, x1 + x2 )" + + # Now, using TemplateExpression: + variable_mapping = (; f=[1, 2], g1=[3], g2=[3], g3=[3]) + st_expr = TemplateExpression( + (; f=x1, g1=x2, g2=x3, g3=x2); + structure=my_structure, + options.operators, + variable_names, + variable_mapping, + ) + @test string_tree(st_expr) == "( x1 + x2, x1 + x3, x1 + x2 )" + + # We can directly call it: + cX = [1.0 2.0; 3.0 4.0; 5.0 6.0] + out, completed = st_expr(cX) + @test completed + @test out == [(1 + 3, 1 + 5, 1 + 3), (2 + 4, 2 + 6, 2 + 4)] +end +@testitem "TemplateExpression getters" tags = [:part3] begin + using SymbolicRegression + using DynamicExpressions: get_operators, get_variable_names + + operators = + Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)).operators + variable_names = (i -> "x$i").(1:3) + x1, x2, x3 = + (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) + + my_structure(nt) = nt.f + + variable_mapping = (; f=[1, 2], g1=[3], g2=[3], g3=[3]) + + st_expr = TemplateExpression( + (; f=x1, g1=x3, g2=x3, g3=x3); + structure=my_structure, + operators, + variable_names, + variable_mapping, + ) + + @test st_expr isa TemplateExpression + @test get_operators(st_expr) == operators + @test get_variable_names(st_expr) == variable_names + @test get_metadata(st_expr).structure == my_structure +end +@testitem "Integration Test with fit! and Performance Check" tags = [:part3] begin + include("../examples/template_expression.jl") +end