From 1494bde425763c2111754ef2b6f2840c92c44610 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 13 Oct 2024 19:18:09 +0100 Subject: [PATCH 01/58] refactor: move ParametricExpression overloads to separate file --- src/ExpressionBuilder.jl | 166 ++++++------------------------------ src/ParametricExpression.jl | 152 +++++++++++++++++++++++++++++++++ src/SymbolicRegression.jl | 1 + 3 files changed, 177 insertions(+), 142 deletions(-) create mode 100644 src/ParametricExpression.jl diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl index 12b20a06c..633391ed0 100644 --- a/src/ExpressionBuilder.jl +++ b/src/ExpressionBuilder.jl @@ -5,8 +5,6 @@ using DynamicExpressions: AbstractExpressionNode, AbstractExpression, Expression, - ParametricExpression, - ParametricNode, constructorof, get_tree, get_contents, @@ -15,43 +13,49 @@ using DynamicExpressions: with_metadata, count_scalar_constants, eval_tree_array -using Random: default_rng, AbstractRNG using StatsBase: StatsBase using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE using ..HallOfFameModule: HallOfFame -using ..LossFunctionsModule: maybe_getindex -using ..InterfaceDynamicExpressionsModule: expected_array_type using ..PopulationModule: Population using ..PopMemberModule: PopMember import DynamicExpressions: get_operators import ..CoreModule: create_expression -import ..MutationFunctionsModule: - make_random_leaf, crossover_trees, mutate_constant, mutate_factor -import ..LossFunctionsModule: eval_tree_dispatch -import ..ConstantOptimizationModule: count_constants_for_optimization @unstable function create_expression( t::T, options::AbstractOptions, dataset::Dataset{T,L}, ::Val{embed}=Val(false) ) where {T,L,embed} return create_expression( - constructorof(options.node_type)(; val=t), options, dataset, Val(embed) + t, options, dataset, options.node_type, options.expression_type, Val(embed) ) end +function create_expression( + ex::AbstractExpression{T}, + options::AbstractOptions, + ::Dataset{T,L}, + ::Val{embed}=Val(false), +) where {T,L,embed} + return ex::options.expression_type +end @unstable function create_expression( - t::AbstractExpressionNode{T}, + t::T, options::AbstractOptions, dataset::Dataset{T,L}, + ::Type{N}, + ::Type{<:AbstractExpression}, ::Val{embed}=Val(false), -) where {T,L,embed} - return constructorof(options.expression_type)( - t; init_params(options, dataset, nothing, Val(embed))... - ) +) where {T,L,embed,N<:AbstractExpressionNode} + return create_expression(constructorof(N)(; val=t), options, dataset, N, E, Val(embed)) end -function create_expression( - ex::AbstractExpression{T}, ::AbstractOptions, ::Dataset{T,L}, ::Val{embed}=Val(false) -) where {T,L,embed} - return ex +@unstable function create_expression( + t::AbstractExpressionNode{T}, + options::AbstractOptions, + dataset::Dataset{T,L}, + ::Type{<:AbstractExpressionNode}, + ::Type{E}, + ::Val{embed}=Val(false), +) where {T,L,embed,E<:AbstractExpression} + return constructorof(E)(t; init_params(options, dataset, nothing, Val(embed))...) end @unstable function init_params( options::AbstractOptions, @@ -75,46 +79,16 @@ function extra_init_params( dataset::Dataset{T}, ::Val{embed}, ) where {T,embed,E<:AbstractExpression} + # TODO: Potential aliasing here return (; options.expression_options...) end -function extra_init_params( - ::Type{E}, - prototype::Union{Nothing,ParametricExpression}, - options::AbstractOptions, - dataset::Dataset{T}, - ::Val{embed}, -) where {T,embed,E<:ParametricExpression} - num_params = options.expression_options.max_parameters - num_classes = length(unique(dataset.extra.classes)) - parameter_names = embed ? ["p$i" for i in 1:num_params] : nothing - _parameters = if prototype === nothing - randn(T, (num_params, num_classes)) - else - copy(get_metadata(prototype).parameters) - end - return (; parameters=_parameters, parameter_names) -end consistency_checks(::AbstractOptions, prototype::Nothing) = nothing function consistency_checks(options::AbstractOptions, prototype) - if prototype === nothing - return nothing - end @assert( prototype isa options.expression_type, "Need prototype to be of type $(options.expression_type), but got $(prototype)::$(typeof(prototype))" ) - if prototype isa ParametricExpression - if prototype.metadata.parameter_names !== nothing - @assert( - length(prototype.metadata.parameter_names) == - options.expression_options.max_parameters, - "Mismatch between options.expression_options.max_parameters=$(options.expression_options.max_parameters) and prototype.metadata.parameter_names=$(prototype.metadata.parameter_names)" - ) - end - @assert size(prototype.metadata.parameters, 1) == - options.expression_options.max_parameters - end return nothing end @@ -164,11 +138,6 @@ function strip_metadata( ) where {T,L} return with_metadata(ex; init_params(options, dataset, ex, Val(false))...) end -function strip_metadata( - ex::ParametricExpression, options::AbstractOptions, dataset::Dataset{T,L} -) where {T,L} - return with_metadata(ex; init_params(options, dataset, ex, Val(false))...) -end function strip_metadata( member::PopMember, options::AbstractOptions, dataset::Dataset{T,L} ) where {T,L} @@ -195,93 +164,6 @@ function strip_metadata( ) end -function eval_tree_dispatch( - tree::ParametricExpression{T}, dataset::Dataset{T}, options::AbstractOptions, idx -) where {T<:DATA_TYPE} - A = expected_array_type(dataset.X) - return eval_tree_array( - tree, - maybe_getindex(dataset.X, :, idx), - maybe_getindex(dataset.extra.classes, idx), - options.operators, - )::Tuple{A,Bool} -end - -function make_random_leaf( - nfeatures::Int, - ::Type{T}, - ::Type{N}, - rng::AbstractRNG=default_rng(), - options::Union{AbstractOptions,Nothing}=nothing, -) where {T<:DATA_TYPE,N<:ParametricNode} - choice = rand(rng, 1:3) - if choice == 1 - return ParametricNode(; val=randn(rng, T)) - elseif choice == 2 - return ParametricNode(T; feature=rand(rng, 1:nfeatures)) - else - tree = ParametricNode{T}() - tree.val = zero(T) - tree.degree = 0 - tree.feature = 0 - tree.constant = false - tree.is_parameter = true - tree.parameter = rand( - rng, UInt16(1):UInt16(options.expression_options.max_parameters) - ) - return tree - end -end - -function crossover_trees( - ex1::ParametricExpression{T}, ex2::AbstractExpression{T}, rng::AbstractRNG=default_rng() -) where {T} - tree1 = get_contents(ex1) - tree2 = get_contents(ex2) - out1, out2 = crossover_trees(tree1, tree2, rng) - ex1 = with_contents(ex1, out1) - ex2 = with_contents(ex2, out2) - - # We also randomly share parameters - nparams1 = size(ex1.metadata.parameters, 1) - nparams2 = size(ex2.metadata.parameters, 1) - num_params_switch = min(nparams1, nparams2) - idx_to_switch = StatsBase.sample( - rng, 1:num_params_switch, num_params_switch; replace=false - ) - for param_idx in idx_to_switch - ex2_params = ex2.metadata.parameters[param_idx, :] - ex2.metadata.parameters[param_idx, :] .= ex1.metadata.parameters[param_idx, :] - ex1.metadata.parameters[param_idx, :] .= ex2_params - end - - return ex1, ex2 -end - -function count_constants_for_optimization(ex::ParametricExpression) - return count_scalar_constants(get_tree(ex)) + length(ex.metadata.parameters) -end - -function mutate_constant( - ex::ParametricExpression{T}, - temperature, - options::AbstractOptions, - rng::AbstractRNG=default_rng(), -) where {T<:DATA_TYPE} - if rand(rng, Bool) - # Normal mutation of inner constant - tree = get_contents(ex) - return with_contents(ex, mutate_constant(tree, temperature, options, rng)) - else - # Mutate parameters - parameter_index = rand(rng, 1:(options.expression_options.max_parameters)) - # We mutate all the parameters at once - factor = mutate_factor(T, temperature, options, rng) - ex.metadata.parameters[parameter_index, :] .*= factor - return ex - end -end - @unstable function get_operators(ex::AbstractExpression, options::AbstractOptions) return get_operators(ex, options.operators) end diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl new file mode 100644 index 000000000..3ac4f93d0 --- /dev/null +++ b/src/ParametricExpression.jl @@ -0,0 +1,152 @@ +module ParametricExpressionModule +## Note that ParametricExpression is defined within DynamicExpressions.jl, +## this file just adds custom behavior for SymbolicRegression.jl, where needed + +using DynamicExpressions: + AbstractExpression, + ParametricExpression, + ParametricNode, + get_metadata, + with_metadata, + get_contents, + with_contents, + eval_tree_array +using StatsBase: StatsBase +using Random: default_rng, AbstractRNG + +using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE +using ..InterfaceDynamicExpressionsModule: expected_array_type +using ..LossFunctionsModule: LossFunctionsModule as LF +using ..ExpressionBuilderModule: ExpressionBuilderModule as EB +using ..MutationFunctionsModule: MutationFunctionsModule as MF +using ..ConstantOptimizationModule: ConstantOptimizationModule as CO + +function EB.extra_init_params( + ::Type{E}, + prototype::Union{Nothing,ParametricExpression}, + options::AbstractOptions, + dataset::Dataset{T}, + ::Val{embed}, +) where {T,embed,E<:ParametricExpression} + num_params = options.expression_options.max_parameters + num_classes = length(unique(dataset.extra.classes)) + parameter_names = embed ? ["p$i" for i in 1:num_params] : nothing + _parameters = if prototype === nothing + randn(T, (num_params, num_classes)) + else + copy(get_metadata(prototype).parameters) + end + return (; parameters=_parameters, parameter_names) +end +function EB.strip_metadata( + ex::ParametricExpression, options::AbstractOptions, dataset::Dataset{T,L} +) where {T,L} + return with_metadata(ex; EB.init_params(options, dataset, ex, Val(false))...) +end +function EB.consistency_checks(options::AbstractOptions, prototype::ParametricExpression) + @assert options.expression_type <: ParametricExpression + if get_metadata(prototype).parameter_names !== nothing + @assert( + length(get_metadata(prototype).parameter_names) == + options.expression_options.max_parameters, + "Mismatch between options.expression_options.max_parameters=$(options.expression_options.max_parameters) and prototype.metadata.parameter_names=$(get_metadata(prototype).parameter_names)" + ) + end + @assert size(get_metadata(prototype).parameters, 1) == + options.expression_options.max_parameters + return nothing +end + +function LF.eval_tree_dispatch( + tree::ParametricExpression{T}, dataset::Dataset{T}, options::AbstractOptions, idx +) where {T<:DATA_TYPE} + A = expected_array_type(dataset.X) + return eval_tree_array( + tree, + LF.maybe_getindex(dataset.X, :, idx), + LF.maybe_getindex(dataset.extra.classes, idx), + options.operators, + )::Tuple{A,Bool} +end + +function MF.make_random_leaf( + nfeatures::Int, + ::Type{T}, + ::Type{N}, + rng::AbstractRNG=default_rng(), + options::Union{AbstractOptions,Nothing}=nothing, +) where {T<:DATA_TYPE,N<:ParametricNode} + choice = rand(rng, 1:3) + if choice == 1 + return ParametricNode(; val=randn(rng, T)) + elseif choice == 2 + return ParametricNode(T; feature=rand(rng, 1:nfeatures)) + else + tree = ParametricNode{T}() + tree.val = zero(T) + tree.degree = 0 + tree.feature = 0 + tree.constant = false + tree.is_parameter = true + tree.parameter = rand( + rng, UInt16(1):UInt16(options.expression_options.max_parameters) + ) + return tree + end +end + +function MF.crossover_trees( + ex1::ParametricExpression{T}, + ex2::ParametricExpression{T}, + rng::AbstractRNG=default_rng(), +) where {T} + tree1 = get_contents(ex1) + tree2 = get_contents(ex2) + out1, out2 = MF.crossover_trees(tree1, tree2, rng) + ex1 = with_contents(ex1, out1) + ex2 = with_contents(ex2, out2) + + # We also randomly share parameters + nparams1 = size(get_metadata(ex1).parameters, 1) + nparams2 = size(get_metadata(ex2).parameters, 1) + num_params_switch = min(nparams1, nparams2) + idx_to_switch = StatsBase.sample( + rng, 1:num_params_switch, num_params_switch; replace=false + ) + for param_idx in idx_to_switch + # TODO: Ensure no issues from aliasing here + ex2_params = get_metadata(ex2).parameters[param_idx, :] + get_metadata(ex2).parameters[param_idx, :] .= get_metadata(ex1).parameters[ + param_idx, :, + ] + get_metadata(ex1).parameters[param_idx, :] .= ex2_params + end + + return ex1, ex2 +end + +function CO.count_constants_for_optimization(ex::ParametricExpression) + return CO.count_scalar_constants(get_tree(ex)) + length(get_metadata(ex).parameters) +end + +function MF.mutate_constant( + ex::ParametricExpression{T}, + temperature, + options::AbstractOptions, + rng::AbstractRNG=default_rng(), +) where {T<:DATA_TYPE} + if rand(rng, Bool) + # Normal mutation of inner constant + tree = get_contents(ex) + return with_contents(ex, MF.mutate_constant(tree, temperature, options, rng)) + else + # Mutate parameters + parameter_index = rand(rng, 1:(options.expression_options.max_parameters)) + # We mutate all the parameters at once + factor = MF.mutate_factor(T, temperature, options, rng) + get_metadata(ex).parameters[parameter_index, :] .*= factor + return ex + end +end + +end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 06512aa6e..d34d24d69 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -217,6 +217,7 @@ using DispatchDoctor: @stable include("Migration.jl") include("SearchUtils.jl") include("ExpressionBuilder.jl") + include("ParametricExpression.jl") end using .CoreModule: From 6c7814573daea7d0c65e5a03dcfa140ae9ac3109 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 13 Oct 2024 19:24:50 +0100 Subject: [PATCH 02/58] fix: some imports in ExpressionBuilder --- src/ExpressionBuilder.jl | 18 ++++++++++++++++-- src/ParametricExpression.jl | 7 +++++-- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl index 633391ed0..c840d3935 100644 --- a/src/ExpressionBuilder.jl +++ b/src/ExpressionBuilder.jl @@ -1,3 +1,7 @@ +""" +This module provides functions for creating, initializing, and manipulating +`<:AbstractExpression` instances and their metadata within the SymbolicRegression.jl framework. +""" module ExpressionBuilderModule using DispatchDoctor: @unstable @@ -29,6 +33,16 @@ import ..CoreModule: create_expression t, options, dataset, options.node_type, options.expression_type, Val(embed) ) end +@unstable function create_expression( + t::AbstractExpressionNode{T}, + options::AbstractOptions, + dataset::Dataset{T,L}, + ::Val{embed}=Val(false), +) where {T,L,embed} + return create_expression( + t, options, dataset, options.node_type, options.expression_type, Val(embed) + ) +end function create_expression( ex::AbstractExpression{T}, options::AbstractOptions, @@ -42,9 +56,9 @@ end options::AbstractOptions, dataset::Dataset{T,L}, ::Type{N}, - ::Type{<:AbstractExpression}, + ::Type{E}, ::Val{embed}=Val(false), -) where {T,L,embed,N<:AbstractExpressionNode} +) where {T,L,embed,N<:AbstractExpressionNode,E<:AbstractExpression} return create_expression(constructorof(N)(; val=t), options, dataset, N, E, Val(embed)) end @unstable function create_expression( diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 3ac4f93d0..40f7930a8 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -1,6 +1,8 @@ +""" +Note that ParametricExpression is defined within DynamicExpressions.jl, +this file just adds custom behavior for SymbolicRegression.jl, where needed. +""" module ParametricExpressionModule -## Note that ParametricExpression is defined within DynamicExpressions.jl, -## this file just adds custom behavior for SymbolicRegression.jl, where needed using DynamicExpressions: AbstractExpression, @@ -10,6 +12,7 @@ using DynamicExpressions: with_metadata, get_contents, with_contents, + get_tree, eval_tree_array using StatsBase: StatsBase using Random: default_rng, AbstractRNG From 87d37aa17a60fba8f1223cb09436b6e810342416 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 13 Oct 2024 20:56:47 +0100 Subject: [PATCH 03/58] feat: initial implementation of ConstrainedExpression --- docs/src/customization.md | 5 + src/ConstrainedExpression.jl | 175 +++++++++++++++++++++++++++++++++++ src/Core.jl | 7 +- src/Mutate.jl | 8 +- src/MutationFunctions.jl | 107 ++++++++++++++++----- src/Options.jl | 6 +- src/OptionsStruct.jl | 16 ++-- src/SymbolicRegression.jl | 5 +- 8 files changed, 288 insertions(+), 41 deletions(-) create mode 100644 src/ConstrainedExpression.jl diff --git a/docs/src/customization.md b/docs/src/customization.md index 2a9d9a072..7b04604bc 100644 --- a/docs/src/customization.md +++ b/docs/src/customization.md @@ -49,6 +49,11 @@ If needed, you may need to overload `SymbolicRegression.ExpressionBuilder.extra_ case your expression needs additional parameters. See the method for `ParametricExpression` as an example. +You can look at the files `src/ParametricExpression.jl` and `src/ConstrainedExpression.jl` +for more examples of custom expression types, though note that `ParametricExpression` itself +is defined in DynamicExpressions.jl, while that file just overloads some methods for +SymbolicRegression.jl. + ## Other Customizations Other internal abstract types include the following: diff --git a/src/ConstrainedExpression.jl b/src/ConstrainedExpression.jl new file mode 100644 index 000000000..3bc2dd2e5 --- /dev/null +++ b/src/ConstrainedExpression.jl @@ -0,0 +1,175 @@ +module ConstrainedExpressionModule + +using Random: AbstractRNG +using DynamicExpressions: + DynamicExpressions as DE, + AbstractStructuredExpression, + AbstractExpressionNode, + AbstractExpression, + AbstractOperatorEnum, + OperatorEnum, + Expression, + Metadata, + get_contents, + with_contents, + get_operators, + get_variable_names, + node_type +using DynamicExpressions.InterfacesModule: + ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments + +using ..CoreModule: AbstractOptions, Dataset, CoreModule as CM +using ..ConstantOptimizationModule: ConstantOptimizationModule as CO +using ..MutationFunctionsModule: MutationFunctionsModule as MF +using ..ExpressionBuilderModule: ExpressionBuilderModule as EB + +struct ConstrainedExpression{ + T, + F<:Function, + N<:AbstractExpressionNode{T}, + E<:Expression{T,N}, # TODO: Generalize this + TS<:NamedTuple{<:Any,<:NTuple{<:Any,E}}, + C<:NamedTuple{<:Any,<:NTuple{<:Any,Vector{Int}}}, # The constraints + D<:@NamedTuple{ + structure::F, operators::O, variable_names::V, variable_mapping::C + } where {O,V}, +} <: AbstractStructuredExpression{T,F,N,E,D} + trees::TS + metadata::Metadata{D} + + function ConstrainedExpression( + trees::TS, metadata::Metadata{D} + ) where { + TS, + F<:Function, + C<:NamedTuple{<:Any,<:NTuple{<:Any,Vector{Int}}}, + D<:@NamedTuple{ + structure::F, operators::O, variable_names::V, variable_mapping::C + } where {O,V}, + } + E = typeof(first(values(trees))) + N = node_type(E) + return new{eltype(N),F,N,E,TS,C,D}(trees, metadata) + end +end + +function ConstrainedExpression( + trees::TS; + structure::F, + operators::Union{AbstractOperatorEnum,Nothing}=nothing, + variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, + variable_mapping::NamedTuple{<:Any,<:NTuple{<:Any,Vector{Int}}}, +) where {E<:AbstractExpression,TS<:NamedTuple{<:Any,<:NTuple{<:Any,E}},F<:Function} + @assert length(trees) == length(variable_mapping) + if variable_names !== nothing + # TODO: Should this be removed? + @assert Set(eachindex(variable_names)) == + Set(Iterators.flatten(values(variable_mapping))) + end + @assert keys(trees) == keys(variable_mapping) + example_tree = first(values(trees))::AbstractExpression + operators = get_operators(example_tree, operators) + variable_names = get_variable_names(example_tree, variable_names) + metadata = (; structure, operators, variable_names, variable_mapping) + return ConstrainedExpression(trees, Metadata(metadata)) +end + +DE.constructorof(::Type{<:ConstrainedExpression}) = ConstrainedExpression + +@implements( + ExpressionInterface{all_ei_methods_except(())}, ConstrainedExpression, [Arguments()] +) + +function EB.create_expression( + t::AbstractExpressionNode{T}, + options::AbstractOptions, + dataset::Dataset{T,L}, + ::Type{<:AbstractExpressionNode}, + ::Type{E}, + ::Val{embed}=Val(false), +) where {T,L,embed,E<:ConstrainedExpression} + function_keys = keys(options.expression_options.variable_mapping) + + # TODO: Generalize to other inner expression types + operators = options.operators + variable_names = embed ? dataset.variable_names : nothing + inner_expressions = ntuple( + _ -> Expression(copy(t); operators, variable_names), length(function_keys) + ) + return DE.constructorof(E)( + NamedTuple{function_keys}(inner_expressions); + EB.init_params(options, dataset, nothing, Val(embed))..., + ) +end + +""" +We need full specialization for constrained expressions, as they rely on subexpressions being combined. +""" +CM.operator_specialization( + ::Type{O}, ::Type{<:ConstrainedExpression} +) where {O<:OperatorEnum} = O + +""" +We pick a random subexpression to mutate, +and also return the symbol we mutated on so that we can put it back together later. +""" +function MF.get_contents_for_mutation(ex::ConstrainedExpression, rng::AbstractRNG) + raw_contents = get_contents(ex) + function_keys = keys(raw_contents) + key_to_mutate = rand(rng, function_keys) + + return raw_contents[key_to_mutate], key_to_mutate +end + +"""See `get_contents_for_mutation(::ConstrainedExpression, ::AbstractRNG)`.""" +function MF.with_contents_for_mutation( + ex::ConstrainedExpression, new_inner_contents, context::Symbol +) + raw_contents = get_contents(ex) + raw_contents_keys = keys(raw_contents) + new_contents = NamedTuple{raw_contents_keys}( + ntuple(length(raw_contents_keys)) do i + if raw_contents_keys[i] == context + new_inner_contents + else + raw_contents[raw_contents_keys[i]] + end + end, + ) + return with_contents(ex, new_contents) +end + +"""We combine the operators of each inner expression.""" +function DE.combine_operators( + ex::ConstrainedExpression{T,N}, operators::Union{AbstractOperatorEnum,Nothing}=nothing +) where {T,N} + raw_contents = get_contents(ex) + function_keys = keys(raw_contents) + new_contents = NamedTuple{function_keys}( + map(Base.Fix2(DE.combine_operators, operators), values(raw_contents)) + ) + return with_contents(ex, new_contents) +end + +"""We simplify each inner expression.""" +function DE.simplify_tree!( + ex::ConstrainedExpression{T,N}, operators::Union{AbstractOperatorEnum,Nothing}=nothing +) where {T,N} + raw_contents = get_contents(ex) + function_keys = keys(raw_contents) + new_contents = NamedTuple{function_keys}( + map(Base.Fix2(DE.simplify_tree!, operators), values(raw_contents)) + ) + return with_contents(ex, new_contents) +end + +function CO.count_constants_for_optimization(ex::ConstrainedExpression) + return sum(CO.count_constants_for_optimization, values(get_contents(ex))) +end + +# TODO: Add custom behavior to adjust what feature nodes can be generated +# TODO: Better versions: +# - Allow evaluation to call structure function - in which case the structure would simply combine the results. +# - Maybe we want to do similar for string output as well. That way, the operators provided don't really matter. + +end diff --git a/src/Core.jl b/src/Core.jl index 0a16b52d0..1ae420863 100644 --- a/src/Core.jl +++ b/src/Core.jl @@ -14,7 +14,12 @@ using .ProgramConstantsModule: MAX_DEGREE, BATCH_DIM, FEATURE_DIM, RecordType, DATA_TYPE, LOSS_TYPE using .DatasetModule: Dataset using .MutationWeightsModule: AbstractMutationWeights, MutationWeights, sample_mutation -using .OptionsStructModule: AbstractOptions, Options, ComplexityMapping, specialized_options +using .OptionsStructModule: + AbstractOptions, + Options, + ComplexityMapping, + specialized_options, + operator_specialization using .OperatorsModule: plus, sub, diff --git a/src/Mutate.jl b/src/Mutate.jl index 40c14e807..3cf7aadd2 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -29,7 +29,8 @@ using ..MutationFunctionsModule: crossover_trees, form_random_connection!, break_random_connection!, - randomly_rotate_tree! + randomly_rotate_tree!, + randomize_tree using ..ConstantOptimizationModule: optimize_constants using ..RecorderModule: @recorder @@ -593,10 +594,7 @@ function mutate!( nfeatures, kws..., ) where {T,N<:AbstractExpression{T},P<:PopMember} - tree_size_to_generate = rand(1:curmaxsize) - tree = with_contents( - tree, gen_random_tree_fixed_size(tree_size_to_generate, options, nfeatures, T) - ) + tree = randomize_tree(tree, curmaxsize, options, nfeatures) @recorder recorder["type"] = "randomize" return MutationResult{N,P}(; tree=tree) end diff --git a/src/MutationFunctions.jl b/src/MutationFunctions.jl index 496d584dd..e0a7fa3f4 100644 --- a/src/MutationFunctions.jl +++ b/src/MutationFunctions.jl @@ -16,6 +16,31 @@ using DynamicExpressions: has_operators using ..CoreModule: AbstractOptions, DATA_TYPE +""" + get_contents_for_mutation(ex::AbstractExpression, rng::AbstractRNG) + +Return the contents of an expression, which can be mutated. +You can overload this function for custom expression types that +need to be mutated in a specific way. + +The second return value is an optional context object that will be +passed to the `with_contents_for_mutation` function. +""" +function get_contents_for_mutation(ex::AbstractExpression, rng::AbstractRNG) + return get_contents(ex), nothing +end + +""" + with_contents_for_mutation(ex::AbstractExpression, context) + +Replace the contents of an expression with the given context object. +You can overload this function for custom expression types that +need to be mutated in a specific way. +""" +function with_contents_for_mutation(ex::AbstractExpression, new_contents, ::Nothing) + return with_contents(ex, new_contents) +end + """ random_node(tree::AbstractNode; filter::F=Returns(true)) @@ -34,8 +59,8 @@ end """Swap operands in binary operator for ops like pow and divide""" function swap_operands(ex::AbstractExpression, rng::AbstractRNG=default_rng()) - tree = get_contents(ex) - ex = with_contents(ex, swap_operands(tree, rng)) + tree, context = get_contents_for_mutation(ex, rng) + ex = with_contents_for_mutation(ex, swap_operands(tree, rng), context) return ex end function swap_operands(tree::AbstractNode, rng::AbstractRNG=default_rng()) @@ -51,8 +76,8 @@ end function mutate_operator( ex::AbstractExpression{T}, options::AbstractOptions, rng::AbstractRNG=default_rng() ) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = with_contents(ex, mutate_operator(tree, options, rng)) + tree, context = get_contents_for_mutation(ex, rng) + ex = with_contents_for_mutation(ex, mutate_operator(tree, options, rng), context) return ex end function mutate_operator( @@ -79,8 +104,10 @@ function mutate_constant( options::AbstractOptions, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = with_contents(ex, mutate_constant(tree, temperature, options, rng)) + tree, context = get_contents_for_mutation(ex, rng) + ex = with_contents_for_mutation( + ex, mutate_constant(tree, temperature, options, rng), context + ) return ex end function mutate_constant( @@ -125,8 +152,10 @@ function append_random_op( rng::AbstractRNG=default_rng(); makeNewBinOp::Union{Bool,Nothing}=nothing, ) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = with_contents(ex, append_random_op(tree, options, nfeatures, rng; makeNewBinOp)) + tree, context = get_contents_for_mutation(ex, rng) + ex = with_contents_for_mutation( + ex, append_random_op(tree, options, nfeatures, rng; makeNewBinOp), context + ) return ex end function append_random_op( @@ -168,8 +197,10 @@ function insert_random_op( nfeatures::Int, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = with_contents(ex, insert_random_op(tree, options, nfeatures, rng)) + tree, context = get_contents_for_mutation(ex, rng) + ex = with_contents_for_mutation( + ex, insert_random_op(tree, options, nfeatures, rng), context + ) return ex end function insert_random_op( @@ -202,8 +233,10 @@ function prepend_random_op( nfeatures::Int, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = with_contents(ex, prepend_random_op(tree, options, nfeatures, rng)) + tree, context = get_contents_for_mutation(ex, rng) + ex = with_contents_for_mutation( + ex, prepend_random_op(tree, options, nfeatures, rng), context + ) return ex end function prepend_random_op( @@ -263,8 +296,10 @@ function delete_random_op!( nfeatures::Int, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = with_contents(ex, delete_random_op!(tree, options, nfeatures, rng)) + tree, context = get_contents_for_mutation(ex, rng) + ex = with_contents_for_mutation( + ex, delete_random_op!(tree, options, nfeatures, rng), context + ) return ex end function delete_random_op!( @@ -312,6 +347,30 @@ function delete_random_op!( return tree end +function randomize_tree( + ex::AbstractExpression, + curmaxsize::Int, + options::AbstractOptions, + nfeatures::Int, + rng::AbstractRNG=default_rng(), +) + tree, context = get_contents_for_mutation(ex, rng) + ex = with_contents_for_mutation( + ex, randomize_tree(tree, curmaxsize, options, nfeatures, rng), context + ) + return ex +end +function randomize_tree( + ::AbstractExpressionNode{T}, + curmaxsize::Int, + options::AbstractOptions, + nfeatures::Int, + rng::AbstractRNG=default_rng(), +) where {T<:DATA_TYPE} + tree_size_to_generate = rand(rng, 1:curmaxsize) + return gen_random_tree_fixed_size(tree_size_to_generate, options, nfeatures, T, rng) +end + """Create a random equation by appending random operators""" function gen_random_tree( length::Int, @@ -353,11 +412,11 @@ end function crossover_trees( ex1::E, ex2::E, rng::AbstractRNG=default_rng() ) where {T,E<:AbstractExpression{T}} - tree1 = get_contents(ex1) - tree2 = get_contents(ex2) + tree1, context1 = get_contents_for_mutation(ex1, rng) + tree2, context2 = get_contents_for_mutation(ex2, rng) out1, out2 = crossover_trees(tree1, tree2, rng) - ex1 = with_contents(ex1, out1) - ex2 = with_contents(ex2, out2) + ex1 = with_contents_for_mutation(ex1, out1, context1) + ex2 = with_contents_for_mutation(ex2, out2, context2) return ex1, ex2 end @@ -408,8 +467,8 @@ function get_two_nodes_without_loop(tree::AbstractNode, rng::AbstractRNG; max_at end function form_random_connection!(ex::AbstractExpression, rng::AbstractRNG=default_rng()) - tree = get_contents(ex) - return with_contents(ex, form_random_connection!(tree, rng)) + tree, context = get_contents_for_mutation(ex, rng) + return with_contents_for_mutation(ex, form_random_connection!(tree, rng), context) end function form_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rng()) if length(tree) < 5 @@ -432,8 +491,8 @@ function form_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rn end function break_random_connection!(ex::AbstractExpression, rng::AbstractRNG=default_rng()) - tree = get_contents(ex) - return with_contents(ex, break_random_connection!(tree, rng)) + tree, context = get_contents_for_mutation(ex, rng) + return with_contents_for_mutation(ex, break_random_connection!(tree, rng), context) end function break_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rng()) tree.degree == 0 && return tree @@ -451,9 +510,9 @@ function is_valid_rotation_node(node::AbstractNode) end function randomly_rotate_tree!(ex::AbstractExpression, rng::AbstractRNG=default_rng()) - tree = get_contents(ex) + tree, context = get_contents_for_mutation(ex, rng) rotated_tree = randomly_rotate_tree!(tree, rng) - return with_contents(ex, rotated_tree) + return with_contents_for_mutation(ex, rotated_tree, context) end function randomly_rotate_tree!(tree::AbstractNode, rng::AbstractRNG=default_rng()) num_rotation_nodes = count(is_valid_rotation_node, tree) diff --git a/src/Options.jl b/src/Options.jl index 701fde2ff..4c84eb56e 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -510,11 +510,13 @@ $(OPTION_DESCRIPTIONS) # Not search options; just construction options: define_helper_functions::Bool=true, deprecated_return_state=nothing, - # Deprecated args: + ######################################### + # Deprecated args: ###################### fast_cycle::Bool=false, npopulations::Union{Nothing,Integer}=nothing, npop::Union{Nothing,Integer}=nothing, kws..., + ######################################### ) for k in keys(kws) !haskey(deprecated_options_mapping, k) && error("Unknown keyword argument: $k") @@ -733,7 +735,7 @@ $(OPTION_DESCRIPTIONS) options = Options{ typeof(complexity_mapping), - operator_specialization(typeof(operators)), + operator_specialization(typeof(operators), expression_type), node_type, expression_type, typeof(expression_options), diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index ba20fe92c..a0cea1b8b 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -113,13 +113,15 @@ function ComplexityMapping( ) end -# Controls level of specialization we compile -function operator_specialization end -if VERSION >= v"1.10.0-DEV.0" - @eval operator_specialization(::Type{<:OperatorEnum}) = OperatorEnum -else - @eval operator_specialization(O::Type{<:OperatorEnum}) = O -end +""" +Controls level of specialization we compile into `Options`. + +Overload if needed for custom expression types. +""" +operator_specialization( + ::Type{O}, ::Type{<:AbstractExpression} +) where {O<:AbstractOperatorEnum} = O +operator_specialization(::Type{<:OperatorEnum}, ::Type{<:AbstractExpression}) = OperatorEnum """ AbstractOptions diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index d34d24d69..d1ca92fc9 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -12,7 +12,7 @@ export Population, ParametricNode, Expression, ParametricExpression, - StructuredExpression, + ConstrainedExpression, NodeSampler, AbstractExpression, AbstractExpressionNode, @@ -91,7 +91,6 @@ using DynamicExpressions: ParametricNode, Expression, ParametricExpression, - StructuredExpression, NodeSampler, AbstractExpression, AbstractExpressionNode, @@ -217,6 +216,7 @@ using DispatchDoctor: @stable include("Migration.jl") include("SearchUtils.jl") include("ExpressionBuilder.jl") + include("ConstrainedExpression.jl") include("ParametricExpression.jl") end @@ -309,6 +309,7 @@ using .SearchUtilsModule: save_to_file, get_cur_maxsize, update_hall_of_fame! +using .ConstrainedExpressionModule: ConstrainedExpression using .ExpressionBuilderModule: embed_metadata, strip_metadata @stable default_mode = "disable" begin From 2d9d634e540ceb421e8fd04197c8fa1b8f15c339 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 13 Oct 2024 21:44:27 +0100 Subject: [PATCH 04/58] feat: variable constraint checking --- src/CheckConstraints.jl | 2 +- src/ConstrainedExpression.jl | 66 +++++++++++++++++++++++++++++++++++- src/LossFunctions.jl | 17 +++++----- src/ParametricExpression.jl | 4 +-- 4 files changed, 77 insertions(+), 12 deletions(-) diff --git a/src/CheckConstraints.jl b/src/CheckConstraints.jl index df748c218..fb0bbb712 100644 --- a/src/CheckConstraints.jl +++ b/src/CheckConstraints.jl @@ -69,7 +69,7 @@ function flag_illegal_nests(tree::AbstractExpressionNode, options::AbstractOptio return false end -"""Check if user-passed constraints are violated or not""" +"""Check if user-passed constraints are satisfied. Returns false otherwise.""" function check_constraints( ex::AbstractExpression, options::AbstractOptions, diff --git a/src/ConstrainedExpression.jl b/src/ConstrainedExpression.jl index 3bc2dd2e5..00f3d860a 100644 --- a/src/ConstrainedExpression.jl +++ b/src/ConstrainedExpression.jl @@ -12,16 +12,22 @@ using DynamicExpressions: Metadata, get_contents, with_contents, + get_metadata, get_operators, get_variable_names, - node_type + get_tree, + node_type, + eval_tree_array using DynamicExpressions.InterfacesModule: ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments using ..CoreModule: AbstractOptions, Dataset, CoreModule as CM using ..ConstantOptimizationModule: ConstantOptimizationModule as CO +using ..InterfaceDynamicExpressionsModule: expected_array_type using ..MutationFunctionsModule: MutationFunctionsModule as MF using ..ExpressionBuilderModule: ExpressionBuilderModule as EB +using ..CheckConstraintsModule: CheckConstraintsModule as CC +using ..LossFunctionsModule: LossFunctionsModule as LF struct ConstrainedExpression{ T, @@ -102,6 +108,28 @@ function EB.create_expression( ) end +# function LF.eval_tree_dispatch( +# tree::ConstrainedExpression, dataset::Dataset, options::AbstractOptions, idx +# ) +# raw_contents = get_contents(tree) + +# # Raw numerical results of each inner expression: +# outs = map( +# ex -> LF.eval_tree_dispatch(ex, dataset, options, idx), +# values(raw_contents) +# ) + +# # Check for any invalid evaluations +# if !all(last, outs) +# # TODO: Would be nice to return early +# return first(outs), false +# end + +# # Combine them using the structure function: +# results = NamedTuple{keys(raw_contents)}(map(first, outs)) +# return get_metadata(tree).structure(results), true +# end + """ We need full specialization for constrained expressions, as they rely on subexpressions being combined. """ @@ -167,6 +195,42 @@ function CO.count_constants_for_optimization(ex::ConstrainedExpression) return sum(CO.count_constants_for_optimization, values(get_contents(ex))) end +function CC.check_constraints( + ex::ConstrainedExpression, + options::AbstractOptions, + maxsize::Int, + cursize::Union{Int,Nothing}=nothing, +)::Bool + raw_contents = get_contents(ex) + variable_mapping = get_metadata(ex).variable_mapping + + # First, we check the variable constraints at the top level: + has_invalid_variables = any(keys(raw_contents)) do key + tree = raw_contents[key] + allowed_variables = variable_mapping[key] + contains_other_features_than(tree, allowed_variables) + end + if has_invalid_variables + return false + end + + # Then, we check other constraints for inner expressions: + if any(t -> !CC.check_constraints(t, options, maxsize, cursize), values(raw_contents)) + return false + end + + # Then, we check the constraints for the combined tree: + return CC.check_constraints(get_tree(ex), options, maxsize, cursize) +end +function contains_other_features_than(tree::AbstractExpression, features) + return contains_other_features_than(get_tree(tree), features) +end +function contains_other_features_than(tree::AbstractExpressionNode, features) + any(tree) do node + node.degree == 0 && !node.constant && node.feature ∉ features + end +end + # TODO: Add custom behavior to adjust what feature nodes can be generated # TODO: Better versions: # - Allow evaluation to call structure function - in which case the structure would simply combine the results. diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl index b41ad3a38..0b6ca0414 100644 --- a/src/LossFunctions.jl +++ b/src/LossFunctions.jl @@ -41,14 +41,15 @@ end end end -function eval_tree_dispatch( - tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, - dataset::Dataset{T}, - options::AbstractOptions, - idx, -) where {T<:DATA_TYPE} - A = expected_array_type(dataset.X) - return eval_tree_array(tree, maybe_getindex(dataset.X, :, idx), options)::Tuple{A,Bool} +for N in (:AbstractExpression, :AbstractExpressionNode) + @eval function eval_tree_dispatch( + tree::$N, dataset::Dataset, options::AbstractOptions, idx + ) + A = expected_array_type(dataset.X) + return eval_tree_array( + tree, maybe_getindex(dataset.X, :, idx), options + )::Tuple{A,Bool} + end end # Evaluate the loss of a particular expression on the input dataset. diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 40f7930a8..87063e2cc 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -61,8 +61,8 @@ function EB.consistency_checks(options::AbstractOptions, prototype::ParametricEx end function LF.eval_tree_dispatch( - tree::ParametricExpression{T}, dataset::Dataset{T}, options::AbstractOptions, idx -) where {T<:DATA_TYPE} + tree::ParametricExpression, dataset::Dataset, options::AbstractOptions, idx +) A = expected_array_type(dataset.X) return eval_tree_array( tree, From 4ce678d2c9179b44c410c91ccc5ff12109172bab Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 13 Oct 2024 22:20:04 +0100 Subject: [PATCH 05/58] feat: allow custom evaluation for ConstrainedExpression --- src/ConstrainedExpression.jl | 42 +++++++++++++++++------------------- src/ExpressionBuilder.jl | 10 +++++++-- src/ParametricExpression.jl | 5 ----- 3 files changed, 28 insertions(+), 29 deletions(-) diff --git a/src/ConstrainedExpression.jl b/src/ConstrainedExpression.jl index 00f3d860a..82001e4cf 100644 --- a/src/ConstrainedExpression.jl +++ b/src/ConstrainedExpression.jl @@ -96,39 +96,37 @@ function EB.create_expression( ) where {T,L,embed,E<:ConstrainedExpression} function_keys = keys(options.expression_options.variable_mapping) - # TODO: Generalize to other inner expression types + # NOTE: We need to copy over the operators so we can call the structure function operators = options.operators variable_names = embed ? dataset.variable_names : nothing inner_expressions = ntuple( _ -> Expression(copy(t); operators, variable_names), length(function_keys) ) + # TODO: Generalize to other inner expression types return DE.constructorof(E)( NamedTuple{function_keys}(inner_expressions); EB.init_params(options, dataset, nothing, Val(embed))..., ) end -# function LF.eval_tree_dispatch( -# tree::ConstrainedExpression, dataset::Dataset, options::AbstractOptions, idx -# ) -# raw_contents = get_contents(tree) - -# # Raw numerical results of each inner expression: -# outs = map( -# ex -> LF.eval_tree_dispatch(ex, dataset, options, idx), -# values(raw_contents) -# ) - -# # Check for any invalid evaluations -# if !all(last, outs) -# # TODO: Would be nice to return early -# return first(outs), false -# end - -# # Combine them using the structure function: -# results = NamedTuple{keys(raw_contents)}(map(first, outs)) -# return get_metadata(tree).structure(results), true -# end +function LF.eval_tree_dispatch( + tree::ConstrainedExpression, dataset::Dataset, options::AbstractOptions, idx +) + raw_contents = get_contents(tree) + + # Raw numerical results of each inner expression: + outs = map(ex -> LF.eval_tree_dispatch(ex, dataset, options, idx), values(raw_contents)) + + # Check for any invalid evaluations + if !all(last, outs) + # TODO: Would be nice to return early + return first(outs), false + end + + # Combine them using the structure function: + results = NamedTuple{keys(raw_contents)}(map(first, outs)) + return get_metadata(tree).structure(results), true +end """ We need full specialization for constrained expressions, as they rely on subexpressions being combined. diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl index c840d3935..4ccfd49c0 100644 --- a/src/ExpressionBuilder.jl +++ b/src/ExpressionBuilder.jl @@ -146,9 +146,15 @@ end end end -"""Strips all metadata except for top-level information""" +""" +Strips all metadata except for top-level information, so that we avoid needing +to copy irrelevant information to the evolution itself (like variable names +stored within an expression). + +The opposite of this is `embed_metadata`. +""" function strip_metadata( - ex::Expression, options::AbstractOptions, dataset::Dataset{T,L} + ex::AbstractExpression, options::AbstractOptions, dataset::Dataset{T,L} ) where {T,L} return with_metadata(ex; init_params(options, dataset, ex, Val(false))...) end diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 87063e2cc..8ac24ef5a 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -41,11 +41,6 @@ function EB.extra_init_params( end return (; parameters=_parameters, parameter_names) end -function EB.strip_metadata( - ex::ParametricExpression, options::AbstractOptions, dataset::Dataset{T,L} -) where {T,L} - return with_metadata(ex; EB.init_params(options, dataset, ex, Val(false))...) -end function EB.consistency_checks(options::AbstractOptions, prototype::ParametricExpression) @assert options.expression_type <: ParametricExpression if get_metadata(prototype).parameter_names !== nothing From 7223da30c415accf00c6d0caff116a3cfc00698a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 13 Oct 2024 22:30:09 +0100 Subject: [PATCH 06/58] fix: some bugs with ConstrainedExpression --- src/ConstrainedExpression.jl | 7 ++++++- src/ExpressionBuilder.jl | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/ConstrainedExpression.jl b/src/ConstrainedExpression.jl index 82001e4cf..c28cb347c 100644 --- a/src/ConstrainedExpression.jl +++ b/src/ConstrainedExpression.jl @@ -108,6 +108,11 @@ function EB.create_expression( EB.init_params(options, dataset, nothing, Val(embed))..., ) end +function sort_params(params::NamedTuple, ::Type{<:ConstrainedExpression}) + return (; + params.structure, params.operators, params.variable_names, params.variable_mapping + ) +end function LF.eval_tree_dispatch( tree::ConstrainedExpression, dataset::Dataset, options::AbstractOptions, idx @@ -120,7 +125,7 @@ function LF.eval_tree_dispatch( # Check for any invalid evaluations if !all(last, outs) # TODO: Would be nice to return early - return first(outs), false + return first(first(outs)), false end # Combine them using the structure function: diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl index 4ccfd49c0..674a91122 100644 --- a/src/ExpressionBuilder.jl +++ b/src/ExpressionBuilder.jl @@ -78,13 +78,18 @@ end ::Val{embed}, ) where {T,L,embed} consistency_checks(options, prototype) - return (; + raw_params = (; operators=embed ? options.operators : nothing, variable_names=embed ? dataset.variable_names : nothing, extra_init_params( options.expression_type, prototype, options, dataset, Val(embed) )..., ) + sorted_params = sort_params(raw_params, options.expression_type) + return sorted_params +end +function sort_params(raw_params::NamedTuple, ::Type{<:AbstractExpression}) + return raw_params end function extra_init_params( ::Type{E}, From 0300b059b765b23c7eb92e12a681d0506165b464 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 13 Oct 2024 22:49:42 +0100 Subject: [PATCH 07/58] fix: call to `sort_params` --- src/ConstrainedExpression.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ConstrainedExpression.jl b/src/ConstrainedExpression.jl index c28cb347c..30b29b3a4 100644 --- a/src/ConstrainedExpression.jl +++ b/src/ConstrainedExpression.jl @@ -108,7 +108,7 @@ function EB.create_expression( EB.init_params(options, dataset, nothing, Val(embed))..., ) end -function sort_params(params::NamedTuple, ::Type{<:ConstrainedExpression}) +function EB.sort_params(params::NamedTuple, ::Type{<:ConstrainedExpression}) return (; params.structure, params.operators, params.variable_names, params.variable_mapping ) From 24edd8c246fcb4ca0679b8aa9086e04e815deb20 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 13 Oct 2024 22:49:55 +0100 Subject: [PATCH 08/58] refactor: style adjustment --- src/ExpressionBuilder.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl index 674a91122..709937ecf 100644 --- a/src/ExpressionBuilder.jl +++ b/src/ExpressionBuilder.jl @@ -85,8 +85,7 @@ end options.expression_type, prototype, options, dataset, Val(embed) )..., ) - sorted_params = sort_params(raw_params, options.expression_type) - return sorted_params + return sort_params(raw_params, options.expression_type) end function sort_params(raw_params::NamedTuple, ::Type{<:AbstractExpression}) return raw_params From 3225e6f82cb319b3e452bfa99b63ec2749e6fbbe Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 13 Oct 2024 23:09:15 +0100 Subject: [PATCH 09/58] fix: ensure `operators` always available --- src/ConstrainedExpression.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/ConstrainedExpression.jl b/src/ConstrainedExpression.jl index 30b29b3a4..28cda395c 100644 --- a/src/ConstrainedExpression.jl +++ b/src/ConstrainedExpression.jl @@ -108,6 +108,16 @@ function EB.create_expression( EB.init_params(options, dataset, nothing, Val(embed))..., ) end +function EB.extra_init_params( + ::Type{E}, + prototype::Union{Nothing,AbstractExpression}, + options::AbstractOptions, + dataset::Dataset{T}, + ::Val{embed}, +) where {T,embed,E<:ConstrainedExpression} + # We also need to include the operators here to be consistent with `create_expression`. + return (; options.operators, options.expression_options...) +end function EB.sort_params(params::NamedTuple, ::Type{<:ConstrainedExpression}) return (; params.structure, params.operators, params.variable_names, params.variable_mapping From 2e0867bc89d29b430c539acacf063baa27b65e4e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 13 Oct 2024 23:09:32 +0100 Subject: [PATCH 10/58] feat: give reduced complexity for constrained expressions --- src/ConstrainedExpression.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/ConstrainedExpression.jl b/src/ConstrainedExpression.jl index 28cda395c..f49eaaecb 100644 --- a/src/ConstrainedExpression.jl +++ b/src/ConstrainedExpression.jl @@ -27,6 +27,7 @@ using ..InterfaceDynamicExpressionsModule: expected_array_type using ..MutationFunctionsModule: MutationFunctionsModule as MF using ..ExpressionBuilderModule: ExpressionBuilderModule as EB using ..CheckConstraintsModule: CheckConstraintsModule as CC +using ..ComplexityModule: ComplexityModule using ..LossFunctionsModule: LossFunctionsModule as LF struct ConstrainedExpression{ @@ -36,6 +37,7 @@ struct ConstrainedExpression{ E<:Expression{T,N}, # TODO: Generalize this TS<:NamedTuple{<:Any,<:NTuple{<:Any,E}}, C<:NamedTuple{<:Any,<:NTuple{<:Any,Vector{Int}}}, # The constraints + # TODO: No need for this to be a parametric type D<:@NamedTuple{ structure::F, operators::O, variable_names::V, variable_mapping::C } where {O,V}, @@ -124,6 +126,17 @@ function EB.sort_params(params::NamedTuple, ::Type{<:ConstrainedExpression}) ) end +function ComplexityModule.compute_complexity( + tree::ConstrainedExpression, options::AbstractOptions; break_sharing=Val(false) +) + # Rather than including the complexity of the combined tree, + # we only sum the complexity of each inner expression, which will be smaller. + return sum( + ex -> ComplexityModule.compute_complexity(ex, options; break_sharing), + values(get_contents(tree)), + ) +end + function LF.eval_tree_dispatch( tree::ConstrainedExpression, dataset::Dataset, options::AbstractOptions, idx ) From 958890b7dc69ab81f9698c905f16e5f06f36b043 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 13 Oct 2024 23:32:18 +0100 Subject: [PATCH 11/58] docs: missing comment --- src/ConstrainedExpression.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ConstrainedExpression.jl b/src/ConstrainedExpression.jl index f49eaaecb..b1ec6186d 100644 --- a/src/ConstrainedExpression.jl +++ b/src/ConstrainedExpression.jl @@ -258,6 +258,7 @@ function contains_other_features_than(tree::AbstractExpressionNode, features) end # TODO: Add custom behavior to adjust what feature nodes can be generated +# TODO: Add custom printing # TODO: Better versions: # - Allow evaluation to call structure function - in which case the structure would simply combine the results. # - Maybe we want to do similar for string output as well. That way, the operators provided don't really matter. From bb86777c33745e089c2ed9fcd7e2589561762fac Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 00:59:07 +0100 Subject: [PATCH 12/58] fix: aliasing issue in `simplify` with extra copy --- src/Mutate.jl | 9 ++++----- src/MutationFunctions.jl | 17 ++++++++--------- src/SingleIteration.jl | 7 +++++++ 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/Mutate.jl b/src/Mutate.jl index 3cf7aadd2..ddf9a889e 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -6,7 +6,6 @@ using DynamicExpressions: with_contents, get_tree, preserve_sharing, - copy_node, count_scalar_constants, simplify_tree!, combine_operators @@ -206,7 +205,7 @@ function next_generation( ############################################# local tree while (!successful_mutation) && attempts < max_attempts - tree = copy_node(member.tree) + tree = copy(member.tree) # TODO: This uses dynamic dispatch. But it doesn't seem that bad # in terms of performance. Still should investigate in more detail. @@ -253,7 +252,7 @@ function next_generation( mutation_accepted = false return ( PopMember( - copy_node(member.tree), + copy(member.tree), beforeScore, beforeLoss, options, @@ -282,7 +281,7 @@ function next_generation( mutation_accepted = false return ( PopMember( - copy_node(member.tree), + copy(member.tree), beforeScore, beforeLoss, options, @@ -324,7 +323,7 @@ function next_generation( mutation_accepted = false return ( PopMember( - copy_node(member.tree), + copy(member.tree), beforeScore, beforeLoss, options, diff --git a/src/MutationFunctions.jl b/src/MutationFunctions.jl index e0a7fa3f4..3adee6fb8 100644 --- a/src/MutationFunctions.jl +++ b/src/MutationFunctions.jl @@ -9,7 +9,6 @@ using DynamicExpressions: get_contents, with_contents, constructorof, - copy_node, set_node!, count_nodes, has_constants, @@ -212,7 +211,7 @@ function insert_random_op( node = rand(rng, NodeSampler(; tree)) choice = rand(rng) makeNewBinOp = choice < options.nbin / (options.nuna + options.nbin) - left = copy_node(node) + left = copy(node) if makeNewBinOp right = make_random_leaf(nfeatures, T, typeof(tree), rng, options) @@ -248,7 +247,7 @@ function prepend_random_op( node = tree choice = rand(rng) makeNewBinOp = choice < options.nbin / (options.nuna + options.nbin) - left = copy_node(tree) + left = copy(tree) if makeNewBinOp right = make_random_leaf(nfeatures, T, typeof(tree), rng, options) @@ -424,23 +423,23 @@ end function crossover_trees( tree1::N, tree2::N, rng::AbstractRNG=default_rng() ) where {T,N<:AbstractExpressionNode{T}} - tree1 = copy_node(tree1) - tree2 = copy_node(tree2) + tree1 = copy(tree1) + tree2 = copy(tree2) node1, parent1, side1 = random_node_and_parent(tree1, rng) node2, parent2, side2 = random_node_and_parent(tree2, rng) - node1 = copy_node(node1) + node1 = copy(node1) if side1 == 'l' - parent1.l = copy_node(node2) + parent1.l = copy(node2) # tree1 now contains this. elseif side1 == 'r' - parent1.r = copy_node(node2) + parent1.r = copy(node2) # tree1 now contains this. else # 'n' # This means that there is no parent2. - tree1 = copy_node(node2) + tree1 = copy(node2) end if side2 == 'l' diff --git a/src/SingleIteration.jl b/src/SingleIteration.jl index d15e7914c..7b9de2784 100644 --- a/src/SingleIteration.jl +++ b/src/SingleIteration.jl @@ -105,6 +105,13 @@ function optimize_and_simplify_population( # Note: we have to turn off this threading loop due to Enzyme, since we need # to manually allocate a new task with a larger stack for Enzyme. should_thread = !(options.deterministic) && !(isa(options.autodiff_backend, AutoEnzyme)) + + # TODO: This `copy` is necessary to avoid an undefined reference + # error when simplifying, and only for `ConstrainedExpression`. + # But, why is it needed? Could it be that + # some of the expressions across the population share subtrees? + pop.members .= map(copy, pop.members) + @threads_if should_thread for j in 1:(pop.n) if options.should_simplify tree = pop.members[j].tree From c81360b1230f5eef1f4bce31736066427b9078c5 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 01:19:38 +0100 Subject: [PATCH 13/58] refactor: fix bad nospecialize --- src/Population.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Population.jl b/src/Population.jl index 3a544730b..b348f6d91 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -168,7 +168,7 @@ const CACHED_WEIGHTS = PerThreadCache{Dict{Tuple{Int,Float32},typeof(test_weights)}}() end -@unstable function get_tournament_selection_weights(@nospecialize(options::AbstractOptions)) +@unstable function get_tournament_selection_weights(@nospecialize(options::Options)) n = options.tournament_selection_n p = options.tournament_selection_p # Computing the weights for the tournament becomes quite expensive, From a3d28ed00b0e880a314709255d6c6157d5298d7e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 01:19:52 +0100 Subject: [PATCH 14/58] fix: prevent aliasing during crossover --- src/RegularizedEvolution.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/RegularizedEvolution.jl b/src/RegularizedEvolution.jl index 06358a328..a365a35a8 100644 --- a/src/RegularizedEvolution.jl +++ b/src/RegularizedEvolution.jl @@ -84,8 +84,8 @@ function reg_evol_cycle( pop.members[oldest] = baby else # Crossover - allstar1 = best_of_sample(pop, running_search_statistics, options) - allstar2 = best_of_sample(pop, running_search_statistics, options) + allstar1 = copy(best_of_sample(pop, running_search_statistics, options)) + allstar2 = copy(best_of_sample(pop, running_search_statistics, options)) baby1, baby2, crossover_accepted, tmp_num_evals = crossover_generation( allstar1, allstar2, dataset, curmaxsize, options From 52f738adc2b62b040ea7ebf3e5a055fa1fe6b678 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 01:21:40 +0100 Subject: [PATCH 15/58] fix: guard more aliasing issues --- src/MutationFunctions.jl | 6 ++++++ src/RegularizedEvolution.jl | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/MutationFunctions.jl b/src/MutationFunctions.jl index 3adee6fb8..73e0367b0 100644 --- a/src/MutationFunctions.jl +++ b/src/MutationFunctions.jl @@ -411,6 +411,9 @@ end function crossover_trees( ex1::E, ex2::E, rng::AbstractRNG=default_rng() ) where {T,E<:AbstractExpression{T}} + if ex1 === ex2 + error("Attempted to crossover the same expression!") + end tree1, context1 = get_contents_for_mutation(ex1, rng) tree2, context2 = get_contents_for_mutation(ex2, rng) out1, out2 = crossover_trees(tree1, tree2, rng) @@ -423,6 +426,9 @@ end function crossover_trees( tree1::N, tree2::N, rng::AbstractRNG=default_rng() ) where {T,N<:AbstractExpressionNode{T}} + if tree1 === tree2 + error("Attempted to crossover the same tree!") + end tree1 = copy(tree1) tree2 = copy(tree2) diff --git a/src/RegularizedEvolution.jl b/src/RegularizedEvolution.jl index a365a35a8..7b140de71 100644 --- a/src/RegularizedEvolution.jl +++ b/src/RegularizedEvolution.jl @@ -31,7 +31,7 @@ function reg_evol_cycle( for i in 1:n_evol_cycles if rand() > options.crossover_probability - allstar = best_of_sample(pop, running_search_statistics, options) + allstar = copy(best_of_sample(pop, running_search_statistics, options)) mutation_recorder = RecordType() baby, mutation_accepted, tmp_num_evals = next_generation( dataset, From ae69e104373ab49a07db616f68f89a36a6dfda3b Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 01:23:09 +0100 Subject: [PATCH 16/58] fix: missing import --- src/Population.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Population.jl b/src/Population.jl index b348f6d91..54aabd369 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -3,7 +3,7 @@ module PopulationModule using StatsBase: StatsBase using DispatchDoctor: @unstable using DynamicExpressions: AbstractExpression, string_tree -using ..CoreModule: AbstractOptions, Dataset, RecordType, DATA_TYPE, LOSS_TYPE +using ..CoreModule: AbstractOptions, Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: score_func, update_baseline_loss! using ..AdaptiveParsimonyModule: RunningSearchStatistics From dc276f9588c77f1c9dc315044365ad5fcc18ee6a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 02:17:42 +0100 Subject: [PATCH 17/58] refactor: rename to `BlueprintExpression` --- ...edExpression.jl => BlueprintExpression.jl} | 38 +++++++++---------- src/SingleIteration.jl | 2 +- src/SymbolicRegression.jl | 6 +-- 3 files changed, 23 insertions(+), 23 deletions(-) rename src/{ConstrainedExpression.jl => BlueprintExpression.jl} (87%) diff --git a/src/ConstrainedExpression.jl b/src/BlueprintExpression.jl similarity index 87% rename from src/ConstrainedExpression.jl rename to src/BlueprintExpression.jl index b1ec6186d..27f53a340 100644 --- a/src/ConstrainedExpression.jl +++ b/src/BlueprintExpression.jl @@ -30,7 +30,7 @@ using ..CheckConstraintsModule: CheckConstraintsModule as CC using ..ComplexityModule: ComplexityModule using ..LossFunctionsModule: LossFunctionsModule as LF -struct ConstrainedExpression{ +struct BlueprintExpression{ T, F<:Function, N<:AbstractExpressionNode{T}, @@ -45,7 +45,7 @@ struct ConstrainedExpression{ trees::TS metadata::Metadata{D} - function ConstrainedExpression( + function BlueprintExpression( trees::TS, metadata::Metadata{D} ) where { TS, @@ -61,7 +61,7 @@ struct ConstrainedExpression{ end end -function ConstrainedExpression( +function BlueprintExpression( trees::TS; structure::F, operators::Union{AbstractOperatorEnum,Nothing}=nothing, @@ -79,13 +79,13 @@ function ConstrainedExpression( operators = get_operators(example_tree, operators) variable_names = get_variable_names(example_tree, variable_names) metadata = (; structure, operators, variable_names, variable_mapping) - return ConstrainedExpression(trees, Metadata(metadata)) + return BlueprintExpression(trees, Metadata(metadata)) end -DE.constructorof(::Type{<:ConstrainedExpression}) = ConstrainedExpression +DE.constructorof(::Type{<:BlueprintExpression}) = BlueprintExpression @implements( - ExpressionInterface{all_ei_methods_except(())}, ConstrainedExpression, [Arguments()] + ExpressionInterface{all_ei_methods_except(())}, BlueprintExpression, [Arguments()] ) function EB.create_expression( @@ -95,7 +95,7 @@ function EB.create_expression( ::Type{<:AbstractExpressionNode}, ::Type{E}, ::Val{embed}=Val(false), -) where {T,L,embed,E<:ConstrainedExpression} +) where {T,L,embed,E<:BlueprintExpression} function_keys = keys(options.expression_options.variable_mapping) # NOTE: We need to copy over the operators so we can call the structure function @@ -116,18 +116,18 @@ function EB.extra_init_params( options::AbstractOptions, dataset::Dataset{T}, ::Val{embed}, -) where {T,embed,E<:ConstrainedExpression} +) where {T,embed,E<:BlueprintExpression} # We also need to include the operators here to be consistent with `create_expression`. return (; options.operators, options.expression_options...) end -function EB.sort_params(params::NamedTuple, ::Type{<:ConstrainedExpression}) +function EB.sort_params(params::NamedTuple, ::Type{<:BlueprintExpression}) return (; params.structure, params.operators, params.variable_names, params.variable_mapping ) end function ComplexityModule.compute_complexity( - tree::ConstrainedExpression, options::AbstractOptions; break_sharing=Val(false) + tree::BlueprintExpression, options::AbstractOptions; break_sharing=Val(false) ) # Rather than including the complexity of the combined tree, # we only sum the complexity of each inner expression, which will be smaller. @@ -138,7 +138,7 @@ function ComplexityModule.compute_complexity( end function LF.eval_tree_dispatch( - tree::ConstrainedExpression, dataset::Dataset, options::AbstractOptions, idx + tree::BlueprintExpression, dataset::Dataset, options::AbstractOptions, idx ) raw_contents = get_contents(tree) @@ -160,14 +160,14 @@ end We need full specialization for constrained expressions, as they rely on subexpressions being combined. """ CM.operator_specialization( - ::Type{O}, ::Type{<:ConstrainedExpression} + ::Type{O}, ::Type{<:BlueprintExpression} ) where {O<:OperatorEnum} = O """ We pick a random subexpression to mutate, and also return the symbol we mutated on so that we can put it back together later. """ -function MF.get_contents_for_mutation(ex::ConstrainedExpression, rng::AbstractRNG) +function MF.get_contents_for_mutation(ex::BlueprintExpression, rng::AbstractRNG) raw_contents = get_contents(ex) function_keys = keys(raw_contents) key_to_mutate = rand(rng, function_keys) @@ -175,9 +175,9 @@ function MF.get_contents_for_mutation(ex::ConstrainedExpression, rng::AbstractRN return raw_contents[key_to_mutate], key_to_mutate end -"""See `get_contents_for_mutation(::ConstrainedExpression, ::AbstractRNG)`.""" +"""See `get_contents_for_mutation(::BlueprintExpression, ::AbstractRNG)`.""" function MF.with_contents_for_mutation( - ex::ConstrainedExpression, new_inner_contents, context::Symbol + ex::BlueprintExpression, new_inner_contents, context::Symbol ) raw_contents = get_contents(ex) raw_contents_keys = keys(raw_contents) @@ -195,7 +195,7 @@ end """We combine the operators of each inner expression.""" function DE.combine_operators( - ex::ConstrainedExpression{T,N}, operators::Union{AbstractOperatorEnum,Nothing}=nothing + ex::BlueprintExpression{T,N}, operators::Union{AbstractOperatorEnum,Nothing}=nothing ) where {T,N} raw_contents = get_contents(ex) function_keys = keys(raw_contents) @@ -207,7 +207,7 @@ end """We simplify each inner expression.""" function DE.simplify_tree!( - ex::ConstrainedExpression{T,N}, operators::Union{AbstractOperatorEnum,Nothing}=nothing + ex::BlueprintExpression{T,N}, operators::Union{AbstractOperatorEnum,Nothing}=nothing ) where {T,N} raw_contents = get_contents(ex) function_keys = keys(raw_contents) @@ -217,12 +217,12 @@ function DE.simplify_tree!( return with_contents(ex, new_contents) end -function CO.count_constants_for_optimization(ex::ConstrainedExpression) +function CO.count_constants_for_optimization(ex::BlueprintExpression) return sum(CO.count_constants_for_optimization, values(get_contents(ex))) end function CC.check_constraints( - ex::ConstrainedExpression, + ex::BlueprintExpression, options::AbstractOptions, maxsize::Int, cursize::Union{Int,Nothing}=nothing, diff --git a/src/SingleIteration.jl b/src/SingleIteration.jl index 7b9de2784..36155a600 100644 --- a/src/SingleIteration.jl +++ b/src/SingleIteration.jl @@ -107,7 +107,7 @@ function optimize_and_simplify_population( should_thread = !(options.deterministic) && !(isa(options.autodiff_backend, AutoEnzyme)) # TODO: This `copy` is necessary to avoid an undefined reference - # error when simplifying, and only for `ConstrainedExpression`. + # error when simplifying, and only for `BlueprintExpression`. # But, why is it needed? Could it be that # some of the expressions across the population share subtrees? pop.members .= map(copy, pop.members) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index d1ca92fc9..603f26068 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -12,7 +12,7 @@ export Population, ParametricNode, Expression, ParametricExpression, - ConstrainedExpression, + BlueprintExpression, NodeSampler, AbstractExpression, AbstractExpressionNode, @@ -216,7 +216,7 @@ using DispatchDoctor: @stable include("Migration.jl") include("SearchUtils.jl") include("ExpressionBuilder.jl") - include("ConstrainedExpression.jl") + include("BlueprintExpression.jl") include("ParametricExpression.jl") end @@ -309,7 +309,7 @@ using .SearchUtilsModule: save_to_file, get_cur_maxsize, update_hall_of_fame! -using .ConstrainedExpressionModule: ConstrainedExpression +using .ConstrainedExpressionModule: BlueprintExpression using .ExpressionBuilderModule: embed_metadata, strip_metadata @stable default_mode = "disable" begin From f455f7aebd519880510dfbf57da5abf0e30800ed Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 02:50:37 +0100 Subject: [PATCH 18/58] fix: missing `@unstable` --- src/OptionsStruct.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index a0cea1b8b..fa8a0035b 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -121,7 +121,8 @@ Overload if needed for custom expression types. operator_specialization( ::Type{O}, ::Type{<:AbstractExpression} ) where {O<:AbstractOperatorEnum} = O -operator_specialization(::Type{<:OperatorEnum}, ::Type{<:AbstractExpression}) = OperatorEnum +@unstable operator_specialization(::Type{<:OperatorEnum}, ::Type{<:AbstractExpression}) = + OperatorEnum """ AbstractOptions From bf712c6c0666a06ccf44700430ba5b3e2578fd02 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 03:38:04 +0100 Subject: [PATCH 19/58] docs: add extended docstring for BlueprintExpression and example --- examples/blueprint_expression.jl | 85 ++++++++++++++++++++++++++++++++ src/BlueprintExpression.jl | 60 ++++++++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 examples/blueprint_expression.jl diff --git a/examples/blueprint_expression.jl b/examples/blueprint_expression.jl new file mode 100644 index 000000000..2295c990b --- /dev/null +++ b/examples/blueprint_expression.jl @@ -0,0 +1,85 @@ +using SymbolicRegression +using DynamicExpressions: + DynamicExpressions as DE, + Metadata, + get_tree, + get_operators, + get_variable_names, + OperatorEnum, + AbstractExpression +using Random: MersenneTwister +using MLJBase: machine, fit!, predict, report +using Test + +using DynamicExpressions.InterfacesModule: Interfaces, ExpressionInterface + +# Impose structure: +# Function f(x1, x2) +# Function g(x3) +# y = sin(f) + g^2 +operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) +variable_names = (i -> "x$i").(1:3) +x1, x2, x3 = (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) + +# For combining expressions to a single expression: +function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractExpression}}}) + return sin(nt.f) + nt.g * nt.g +end +# For combining numerical outputs: +function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}}) + return @. sin(nt.f) + nt.g * nt.g +end + +variable_mapping = (; f=[1, 2], g=[3]) + +st_expr = BlueprintExpression( + (; f=x1, g=x3); structure=my_structure, operators, variable_names, variable_mapping +) + +# @test Interfaces.test( +# ExpressionInterface, +# BlueprintExpression, +# [st_expr] +# ) + +model = SRRegressor(; + niterations=200, + binary_operators=(+, *, /, -), + unary_operators=(sin, cos), + populations=30, + maxsize=30, + expression_type=BlueprintExpression, + expression_options=(; structure=my_structure, variable_mapping), + parallelism=:multithreading, +) + +X = rand(100, 3) .* 5 +y = @. exp(X[:, 1]) + X[:, 3] * X[:, 3] + +mach = machine(model, X, y) + +fit!(mach) + +# using Profile +# Profile.init(n=10^8, delay=1e-4) +# mach = machine(model, X, y) +# @profile fit!(mach, verbosity=0) +# Profile.clear() +# mach = machine(model, X, y) +# @profile fit!(mach) + +# using PProf +# pprof() +# idx1 = lastindex(report(mach).equations) +# ypred1 = predict(mach, (data=X, idx=idx1)) +# loss1 = sum(i -> abs2(ypred1[i] - y[i]), eachindex(y)) + +# # Should keep all parameters +# stop_at[] = 1e-5 +# fit!(mach) +# idx2 = lastindex(report(mach).equations) +# ypred2 = predict(mach, (data=X, idx=idx2)) +# loss2 = sum(i -> abs2(ypred2[i] - y[i]), eachindex(y)) + +# # Should get better: +# @test loss1 >= loss2 diff --git a/src/BlueprintExpression.jl b/src/BlueprintExpression.jl index 27f53a340..6f0dafb35 100644 --- a/src/BlueprintExpression.jl +++ b/src/BlueprintExpression.jl @@ -30,6 +30,66 @@ using ..CheckConstraintsModule: CheckConstraintsModule as CC using ..ComplexityModule: ComplexityModule using ..LossFunctionsModule: LossFunctionsModule as LF +""" + BlueprintExpression{T,F,N,E,TS,C,D} <: AbstractStructuredExpression{T,F,N,E,D} + +A symbolic expression that allows the combination of multiple sub-expressions +in a structured way, with constraints on variable usage. + +`BlueprintExpression` is designed for symbolic regression tasks where +domain-specific knowledge or constraints must be imposed on the model's structure. + +# Constructor + +- `BlueprintExpression(trees; structure, operators, variable_names, variable_mapping)` + - `trees`: A `NamedTuple` holding the sub-expressions (e.g., `f = Expression(...)`, `g = Expression(...)`). + - `structure`: A function that defines how the sub-expressions are combined. This should have one method + that takes `trees` as input and returns a single `Expression` node, and another method which takes + a `NamedTuple` of `Vector` (representing the numerical results of each sub-expression) and returns + a single vector after combining them. + - `operators`: An `OperatorEnum` that defines the allowed operators for the sub-expressions. + - `variable_names`: An optional `Vector` of `String` that defines the names of the variables in the dataset. + - `variable_mapping`: A `NamedTuple` that defines which variables each sub-expression is allowed to access. + For example, requesting `f(x1, x2)` and `g(x3)` would be equivalent to `(; f=[1, 2], g=[3])`. + +# Example + +Let's create an example `BlueprintExpression` that combines two sub-expressions `f(x1, x2)` and `g(x3)`: + +```julia +# Define operators and variable names +operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) +variable_names = ["x1", "x2", "x3"] + +# Create sub-expressions +x1 = Expression(Node{Float64}(; feature=1); operators, variable_names) +x2 = Expression(Node{Float64}(; feature=2); operators, variable_names) +x3 = Expression(Node{Float64}(; feature=3); operators, variable_names) + +# Define structure function for symbolic and numerical evaluation +function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:Expression}}}) + return sin(nt.f) + nt.g * nt.g +end +function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}}) + return @. sin(nt.f) + nt.g * nt.g +end + +# Define variable constraints (if desired) +variable_mapping = (; f=[1, 2], g=[3]) + +# Create BlueprintExpression +example_expr = (; f=x1, g=x3) +st_expr = BlueprintExpression( + example_expr; + structure=my_structure, operators, variable_names, variable_mapping +) +``` + +When fitting a model in SymbolicRegression.jl, you would provide the `BlueprintExpression` +as the `expression_type` argument, and then pass `expression_options=(; structure=my_structure, variable_mapping=variable_mapping)` +as additional options. The `variable_mapping` will constraint `f` to only have access to `x1` and `x2`, +and `g` to only have access to `x3`. +""" struct BlueprintExpression{ T, F<:Function, From 2123ef29c11aa5370a10da4040f63b256b122ff2 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 03:49:41 +0100 Subject: [PATCH 20/58] fix: unbound type parameter --- src/BlueprintExpression.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/BlueprintExpression.jl b/src/BlueprintExpression.jl index 6f0dafb35..e56645e99 100644 --- a/src/BlueprintExpression.jl +++ b/src/BlueprintExpression.jl @@ -122,12 +122,12 @@ struct BlueprintExpression{ end function BlueprintExpression( - trees::TS; + trees::NamedTuple{<:Any,<:NTuple{<:Any,<:AbstractExpression}}; structure::F, operators::Union{AbstractOperatorEnum,Nothing}=nothing, variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, variable_mapping::NamedTuple{<:Any,<:NTuple{<:Any,Vector{Int}}}, -) where {E<:AbstractExpression,TS<:NamedTuple{<:Any,<:NTuple{<:Any,E}},F<:Function} +) where {F<:Function} @assert length(trees) == length(variable_mapping) if variable_names !== nothing # TODO: Should this be removed? From a2afe7cc66ad9db28e0ee282716d69ec2c73a363 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 03:56:11 +0100 Subject: [PATCH 21/58] test: fix parametric expression test --- src/ParametricExpression.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 8ac24ef5a..0180a1602 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -42,7 +42,10 @@ function EB.extra_init_params( return (; parameters=_parameters, parameter_names) end function EB.consistency_checks(options::AbstractOptions, prototype::ParametricExpression) - @assert options.expression_type <: ParametricExpression + @assert( + options.expression_type <: ParametricExpression, + "Need prototype to be of type $(options.expression_type), but got $(prototype)::$(typeof(prototype))" + ) if get_metadata(prototype).parameter_names !== nothing @assert( length(get_metadata(prototype).parameter_names) == From bd20e4dcfb6046e4411983250e5df8731410eae1 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 21:32:43 +0100 Subject: [PATCH 22/58] refactor: rename to `TemplateExpression` --- examples/blueprint_expression.jl | 37 +++++--- src/SingleIteration.jl | 2 +- src/SymbolicRegression.jl | 6 +- ...intExpression.jl => TemplateExpression.jl} | 88 ++++++++++++------- 4 files changed, 83 insertions(+), 50 deletions(-) rename src/{BlueprintExpression.jl => TemplateExpression.jl} (78%) diff --git a/examples/blueprint_expression.jl b/examples/blueprint_expression.jl index 2295c990b..fb53dc1b5 100644 --- a/examples/blueprint_expression.jl +++ b/examples/blueprint_expression.jl @@ -22,23 +22,28 @@ variable_names = (i -> "x$i").(1:3) x1, x2, x3 = (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) # For combining expressions to a single expression: -function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractExpression}}}) - return sin(nt.f) + nt.g * nt.g +function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractString}}}) + return "( $(nt.f) + $(nt.g1), $(nt.f) + $(nt.g2), $(nt.f) + $(nt.g3) )" end -# For combining numerical outputs: function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}}) - return @. sin(nt.f) + nt.g * nt.g + return map( + i -> (nt.f[i] + nt.g1[i], nt.f[i] + nt.g2[i], nt.f[i] + nt.g3[i]), eachindex(nt.f) + ) end -variable_mapping = (; f=[1, 2], g=[3]) +variable_mapping = (; f=[1, 2], g1=[3], g2=[3], g3=[3]) -st_expr = BlueprintExpression( - (; f=x1, g=x3); structure=my_structure, operators, variable_names, variable_mapping +st_expr = TemplateExpression( + (; f=x1, g1=x3, g2=x3, g3=x3); + structure=my_structure, + operators, + variable_names, + variable_mapping, ) # @test Interfaces.test( # ExpressionInterface, -# BlueprintExpression, +# TemplateExpression, # [st_expr] # ) @@ -48,18 +53,26 @@ model = SRRegressor(; unary_operators=(sin, cos), populations=30, maxsize=30, - expression_type=BlueprintExpression, + expression_type=TemplateExpression, expression_options=(; structure=my_structure, variable_mapping), parallelism=:multithreading, + elementwise_loss=((x1, x2, x3), (y1, y2, y3)) -> + (y1 - x1)^2 + (y2 - x2)^2 + (y3 - x3)^2, ) -X = rand(100, 3) .* 5 -y = @. exp(X[:, 1]) + X[:, 3] * X[:, 3] +X = rand(100, 3) +y = [ + (sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 2]) + X[i, 3]^2, sin(X[i, 3]) + X[i, 3]^2) for + i in eachindex(axes(X, 1)) +] -mach = machine(model, X, y) +dataset = Dataset(X', y) +mach = machine(model, X, y) fit!(mach) +println("hello") + # using Profile # Profile.init(n=10^8, delay=1e-4) # mach = machine(model, X, y) diff --git a/src/SingleIteration.jl b/src/SingleIteration.jl index 36155a600..1d305cbb2 100644 --- a/src/SingleIteration.jl +++ b/src/SingleIteration.jl @@ -107,7 +107,7 @@ function optimize_and_simplify_population( should_thread = !(options.deterministic) && !(isa(options.autodiff_backend, AutoEnzyme)) # TODO: This `copy` is necessary to avoid an undefined reference - # error when simplifying, and only for `BlueprintExpression`. + # error when simplifying, and only for `TemplateExpression`. # But, why is it needed? Could it be that # some of the expressions across the population share subtrees? pop.members .= map(copy, pop.members) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 5a087b8eb..b405fece4 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -12,7 +12,7 @@ export Population, ParametricNode, Expression, ParametricExpression, - BlueprintExpression, + TemplateExpression, NodeSampler, AbstractExpression, AbstractExpressionNode, @@ -216,7 +216,7 @@ using DispatchDoctor: @stable include("Migration.jl") include("SearchUtils.jl") include("ExpressionBuilder.jl") - include("BlueprintExpression.jl") + include("TemplateExpression.jl") include("ParametricExpression.jl") end @@ -309,7 +309,7 @@ using .SearchUtilsModule: save_to_file, get_cur_maxsize, update_hall_of_fame! -using .ConstrainedExpressionModule: BlueprintExpression +using .ConstrainedExpressionModule: TemplateExpression using .ExpressionBuilderModule: embed_metadata, strip_metadata @stable default_mode = "disable" begin diff --git a/src/BlueprintExpression.jl b/src/TemplateExpression.jl similarity index 78% rename from src/BlueprintExpression.jl rename to src/TemplateExpression.jl index e56645e99..4971d82fe 100644 --- a/src/BlueprintExpression.jl +++ b/src/TemplateExpression.jl @@ -21,27 +21,30 @@ using DynamicExpressions: using DynamicExpressions.InterfacesModule: ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments -using ..CoreModule: AbstractOptions, Dataset, CoreModule as CM +using ..CoreModule: AbstractOptions, Dataset, CoreModule as CM, AbstractMutationWeights using ..ConstantOptimizationModule: ConstantOptimizationModule as CO using ..InterfaceDynamicExpressionsModule: expected_array_type using ..MutationFunctionsModule: MutationFunctionsModule as MF using ..ExpressionBuilderModule: ExpressionBuilderModule as EB +using ..DimensionalAnalysisModule: DimensionalAnalysisModule as DA using ..CheckConstraintsModule: CheckConstraintsModule as CC using ..ComplexityModule: ComplexityModule using ..LossFunctionsModule: LossFunctionsModule as LF +using ..MutateModule: MutateModule as MM +using ..PopMemberModule: PopMember """ - BlueprintExpression{T,F,N,E,TS,C,D} <: AbstractStructuredExpression{T,F,N,E,D} + TemplateExpression{T,F,N,E,TS,C,D} <: AbstractStructuredExpression{T,F,N,E,D} A symbolic expression that allows the combination of multiple sub-expressions in a structured way, with constraints on variable usage. -`BlueprintExpression` is designed for symbolic regression tasks where +`TemplateExpression` is designed for symbolic regression tasks where domain-specific knowledge or constraints must be imposed on the model's structure. # Constructor -- `BlueprintExpression(trees; structure, operators, variable_names, variable_mapping)` +- `TemplateExpression(trees; structure, operators, variable_names, variable_mapping)` - `trees`: A `NamedTuple` holding the sub-expressions (e.g., `f = Expression(...)`, `g = Expression(...)`). - `structure`: A function that defines how the sub-expressions are combined. This should have one method that takes `trees` as input and returns a single `Expression` node, and another method which takes @@ -54,7 +57,7 @@ domain-specific knowledge or constraints must be imposed on the model's structur # Example -Let's create an example `BlueprintExpression` that combines two sub-expressions `f(x1, x2)` and `g(x3)`: +Let's create an example `TemplateExpression` that combines two sub-expressions `f(x1, x2)` and `g(x3)`: ```julia # Define operators and variable names @@ -77,20 +80,20 @@ end # Define variable constraints (if desired) variable_mapping = (; f=[1, 2], g=[3]) -# Create BlueprintExpression +# Create TemplateExpression example_expr = (; f=x1, g=x3) -st_expr = BlueprintExpression( +st_expr = TemplateExpression( example_expr; structure=my_structure, operators, variable_names, variable_mapping ) ``` -When fitting a model in SymbolicRegression.jl, you would provide the `BlueprintExpression` +When fitting a model in SymbolicRegression.jl, you would provide the `TemplateExpression` as the `expression_type` argument, and then pass `expression_options=(; structure=my_structure, variable_mapping=variable_mapping)` as additional options. The `variable_mapping` will constraint `f` to only have access to `x1` and `x2`, and `g` to only have access to `x3`. """ -struct BlueprintExpression{ +struct TemplateExpression{ T, F<:Function, N<:AbstractExpressionNode{T}, @@ -105,7 +108,7 @@ struct BlueprintExpression{ trees::TS metadata::Metadata{D} - function BlueprintExpression( + function TemplateExpression( trees::TS, metadata::Metadata{D} ) where { TS, @@ -121,7 +124,7 @@ struct BlueprintExpression{ end end -function BlueprintExpression( +function TemplateExpression( trees::NamedTuple{<:Any,<:NTuple{<:Any,<:AbstractExpression}}; structure::F, operators::Union{AbstractOperatorEnum,Nothing}=nothing, @@ -139,13 +142,13 @@ function BlueprintExpression( operators = get_operators(example_tree, operators) variable_names = get_variable_names(example_tree, variable_names) metadata = (; structure, operators, variable_names, variable_mapping) - return BlueprintExpression(trees, Metadata(metadata)) + return TemplateExpression(trees, Metadata(metadata)) end -DE.constructorof(::Type{<:BlueprintExpression}) = BlueprintExpression +DE.constructorof(::Type{<:TemplateExpression}) = TemplateExpression @implements( - ExpressionInterface{all_ei_methods_except(())}, BlueprintExpression, [Arguments()] + ExpressionInterface{all_ei_methods_except(())}, TemplateExpression, [Arguments()] ) function EB.create_expression( @@ -155,7 +158,7 @@ function EB.create_expression( ::Type{<:AbstractExpressionNode}, ::Type{E}, ::Val{embed}=Val(false), -) where {T,L,embed,E<:BlueprintExpression} +) where {T,L,embed,E<:TemplateExpression} function_keys = keys(options.expression_options.variable_mapping) # NOTE: We need to copy over the operators so we can call the structure function @@ -176,18 +179,18 @@ function EB.extra_init_params( options::AbstractOptions, dataset::Dataset{T}, ::Val{embed}, -) where {T,embed,E<:BlueprintExpression} +) where {T,embed,E<:TemplateExpression} # We also need to include the operators here to be consistent with `create_expression`. return (; options.operators, options.expression_options...) end -function EB.sort_params(params::NamedTuple, ::Type{<:BlueprintExpression}) +function EB.sort_params(params::NamedTuple, ::Type{<:TemplateExpression}) return (; params.structure, params.operators, params.variable_names, params.variable_mapping ) end function ComplexityModule.compute_complexity( - tree::BlueprintExpression, options::AbstractOptions; break_sharing=Val(false) + tree::TemplateExpression, options::AbstractOptions; break_sharing=Val(false) ) # Rather than including the complexity of the combined tree, # we only sum the complexity of each inner expression, which will be smaller. @@ -197,37 +200,54 @@ function ComplexityModule.compute_complexity( ) end +function DE.string_tree( + tree::TemplateExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws... +) + raw_contents = get_contents(tree) + function_keys = keys(raw_contents) + inner_strings = NamedTuple{function_keys}( + map(ex -> DE.string_tree(ex, operators; kws...), values(raw_contents)) + ) + # TODO: Make a fallback function in case the structure function is undefined. + return get_metadata(tree).structure(inner_strings) +end function LF.eval_tree_dispatch( - tree::BlueprintExpression, dataset::Dataset, options::AbstractOptions, idx + tree::TemplateExpression, dataset::Dataset, options::AbstractOptions, idx ) raw_contents = get_contents(tree) # Raw numerical results of each inner expression: outs = map(ex -> LF.eval_tree_dispatch(ex, dataset, options, idx), values(raw_contents)) - # Check for any invalid evaluations - if !all(last, outs) - # TODO: Would be nice to return early - return first(first(outs)), false - end - # Combine them using the structure function: results = NamedTuple{keys(raw_contents)}(map(first, outs)) - return get_metadata(tree).structure(results), true + return get_metadata(tree).structure(results), all(last, outs) +end +function DA.violates_dimensional_constraints( + tree::TemplateExpression, dataset::Dataset, options::AbstractOptions +) + @assert dataset.X_units === nothing && dataset.y_units === nothing + return false +end +function MM.condition_mutation_weights!( + _::AbstractMutationWeights, _::P, _::AbstractOptions, _::Int +) where {T,L,N<:TemplateExpression,P<:PopMember{T,L,N}} + # HACK TODO + return nothing end """ We need full specialization for constrained expressions, as they rely on subexpressions being combined. """ CM.operator_specialization( - ::Type{O}, ::Type{<:BlueprintExpression} + ::Type{O}, ::Type{<:TemplateExpression} ) where {O<:OperatorEnum} = O """ We pick a random subexpression to mutate, and also return the symbol we mutated on so that we can put it back together later. """ -function MF.get_contents_for_mutation(ex::BlueprintExpression, rng::AbstractRNG) +function MF.get_contents_for_mutation(ex::TemplateExpression, rng::AbstractRNG) raw_contents = get_contents(ex) function_keys = keys(raw_contents) key_to_mutate = rand(rng, function_keys) @@ -235,9 +255,9 @@ function MF.get_contents_for_mutation(ex::BlueprintExpression, rng::AbstractRNG) return raw_contents[key_to_mutate], key_to_mutate end -"""See `get_contents_for_mutation(::BlueprintExpression, ::AbstractRNG)`.""" +"""See `get_contents_for_mutation(::TemplateExpression, ::AbstractRNG)`.""" function MF.with_contents_for_mutation( - ex::BlueprintExpression, new_inner_contents, context::Symbol + ex::TemplateExpression, new_inner_contents, context::Symbol ) raw_contents = get_contents(ex) raw_contents_keys = keys(raw_contents) @@ -255,7 +275,7 @@ end """We combine the operators of each inner expression.""" function DE.combine_operators( - ex::BlueprintExpression{T,N}, operators::Union{AbstractOperatorEnum,Nothing}=nothing + ex::TemplateExpression{T,N}, operators::Union{AbstractOperatorEnum,Nothing}=nothing ) where {T,N} raw_contents = get_contents(ex) function_keys = keys(raw_contents) @@ -267,7 +287,7 @@ end """We simplify each inner expression.""" function DE.simplify_tree!( - ex::BlueprintExpression{T,N}, operators::Union{AbstractOperatorEnum,Nothing}=nothing + ex::TemplateExpression{T,N}, operators::Union{AbstractOperatorEnum,Nothing}=nothing ) where {T,N} raw_contents = get_contents(ex) function_keys = keys(raw_contents) @@ -277,12 +297,12 @@ function DE.simplify_tree!( return with_contents(ex, new_contents) end -function CO.count_constants_for_optimization(ex::BlueprintExpression) +function CO.count_constants_for_optimization(ex::TemplateExpression) return sum(CO.count_constants_for_optimization, values(get_contents(ex))) end function CC.check_constraints( - ex::BlueprintExpression, + ex::TemplateExpression, options::AbstractOptions, maxsize::Int, cursize::Union{Int,Nothing}=nothing, From a551e83f4c1022a0d9c4518a74ae7097614568bd Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 21:33:06 +0100 Subject: [PATCH 23/58] docs: change example name --- examples/{blueprint_expression.jl => template_expression.jl} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/{blueprint_expression.jl => template_expression.jl} (100%) diff --git a/examples/blueprint_expression.jl b/examples/template_expression.jl similarity index 100% rename from examples/blueprint_expression.jl rename to examples/template_expression.jl From 0514bcaf817f0c622de127f0d8fad316278f6527 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 21:42:12 +0100 Subject: [PATCH 24/58] feat: permit complex structured output --- src/Configure.jl | 2 +- src/Dataset.jl | 26 +++----------------------- src/LossFunctions.jl | 4 ++-- src/SymbolicRegression.jl | 15 +++------------ src/TemplateExpression.jl | 13 +++---------- 5 files changed, 12 insertions(+), 48 deletions(-) diff --git a/src/Configure.jl b/src/Configure.jl index eefd63619..0fedd2f54 100644 --- a/src/Configure.jl +++ b/src/Configure.jl @@ -88,7 +88,7 @@ function test_dataset_configuration( ) where {T<:DATA_TYPE} n = dataset.n if n != size(dataset.X, 2) || - (dataset.y !== nothing && n != size(dataset.y::AbstractArray{T}, 1)) + (dataset.y !== nothing && n != size(dataset.y::AbstractArray, 1)) throw( AssertionError( "Dataset dimensions are invalid. Make sure X is of shape [features, rows], y is of shape [rows] and if there are weights, they are of shape [rows].", diff --git a/src/Dataset.jl b/src/Dataset.jl index c8cb9767a..954f71620 100644 --- a/src/Dataset.jl +++ b/src/Dataset.jl @@ -49,7 +49,7 @@ mutable struct Dataset{ T<:DATA_TYPE, L<:LOSS_TYPE, AX<:AbstractMatrix{T}, - AY<:Union{AbstractVector{T},Nothing}, + AY<:Union{AbstractVector,Nothing}, AW<:Union{AbstractVector{T},Nothing}, NT<:NamedTuple, XU<:Union{AbstractVector{<:Quantity},Nothing}, @@ -93,7 +93,7 @@ Construct a dataset to pass between internal functions. """ function Dataset( X::AbstractMatrix{T}, - y::Union{AbstractVector{T},Nothing}=nothing, + y::Union{AbstractVector,Nothing}=nothing, loss_type::Type{L}=Nothing; index::Int=1, weights::Union{AbstractVector{T},Nothing}=nothing, @@ -150,7 +150,7 @@ function Dataset( else y_variable_name end - avg_y = if y === nothing + avg_y = if y === nothing || !(eltype(y) isa Number) nothing else if weighted @@ -222,26 +222,6 @@ function Dataset( y_sym_units, ) end -function Dataset( - X::AbstractMatrix, - y::Union{<:AbstractVector,Nothing}=nothing; - weights::Union{<:AbstractVector,Nothing}=nothing, - kws..., -) - T = promote_type( - eltype(X), - (y === nothing) ? eltype(X) : eltype(y), - (weights === nothing) ? eltype(X) : eltype(weights), - ) - X = Base.Fix1(convert, T).(X) - if y !== nothing - y = Base.Fix1(convert, T).(y) - end - if weights !== nothing - weights = Base.Fix1(convert, T).(weights) - end - return Dataset(X, y; weights=weights, kws...) -end function error_on_mismatched_size(_, ::Nothing) return nothing diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl index 0b6ca0414..c722e2790 100644 --- a/src/LossFunctions.jl +++ b/src/LossFunctions.jl @@ -12,7 +12,7 @@ using ..DimensionalAnalysisModule: violates_dimensional_constraints function _loss( x::AbstractArray{T}, y::AbstractArray{T}, loss::LT -) where {T<:DATA_TYPE,LT<:Union{Function,SupervisedLoss}} +) where {T,LT<:Union{Function,SupervisedLoss}} if loss isa SupervisedLoss return LossFunctions.mean(loss, x, y) else @@ -23,7 +23,7 @@ end function _weighted_loss( x::AbstractArray{T}, y::AbstractArray{T}, w::AbstractArray{T}, loss::LT -) where {T<:DATA_TYPE,LT<:Union{Function,SupervisedLoss}} +) where {T,LT<:Union{Function,SupervisedLoss}} if loss isa SupervisedLoss return sum(loss, x, y, w; normalize=true) else diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index b405fece4..9effdee14 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -412,7 +412,7 @@ which is useful for debugging and profiling. """ function equation_search( X::AbstractMatrix{T}, - y::AbstractMatrix{T}; + y::AbstractMatrix; niterations::Int=10, weights::Union{AbstractMatrix{T},AbstractVector{T},Nothing}=nothing, options::AbstractOptions=Options(), @@ -483,17 +483,8 @@ function equation_search( end function equation_search( - X::AbstractMatrix{T1}, y::AbstractMatrix{T2}; kw... -) where {T1<:DATA_TYPE,T2<:DATA_TYPE} - U = promote_type(T1, T2) - return equation_search( - convert(AbstractMatrix{U}, X), convert(AbstractMatrix{U}, y); kw... - ) -end - -function equation_search( - X::AbstractMatrix{T1}, y::AbstractVector{T2}; kw... -) where {T1<:DATA_TYPE,T2<:DATA_TYPE} + X::AbstractMatrix{T}, y::AbstractVector; kw... +) where {T<:DATA_TYPE} return equation_search(X, reshape(y, (1, size(y, 1))); kw..., v_dim_out=Val(1)) end diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 4971d82fe..b0cb8b29e 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -321,12 +321,9 @@ function CC.check_constraints( end # Then, we check other constraints for inner expressions: - if any(t -> !CC.check_constraints(t, options, maxsize, cursize), values(raw_contents)) - return false - end - - # Then, we check the constraints for the combined tree: - return CC.check_constraints(get_tree(ex), options, maxsize, cursize) + return all( + t -> CC.check_constraints(t, options, maxsize, cursize), values(raw_contents) + ) end function contains_other_features_than(tree::AbstractExpression, features) return contains_other_features_than(get_tree(tree), features) @@ -338,9 +335,5 @@ function contains_other_features_than(tree::AbstractExpressionNode, features) end # TODO: Add custom behavior to adjust what feature nodes can be generated -# TODO: Add custom printing -# TODO: Better versions: -# - Allow evaluation to call structure function - in which case the structure would simply combine the results. -# - Maybe we want to do similar for string output as well. That way, the operators provided don't really matter. end From acd7762c2a22a5f3a815b221ace0967b022a4ff6 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 21:48:36 +0100 Subject: [PATCH 25/58] fix: weak dispatch in TemplateExpression --- examples/template_expression.jl | 26 -------------------------- src/Mutate.jl | 7 ++----- 2 files changed, 2 insertions(+), 31 deletions(-) diff --git a/examples/template_expression.jl b/examples/template_expression.jl index fb53dc1b5..f9693d2b5 100644 --- a/examples/template_expression.jl +++ b/examples/template_expression.jl @@ -70,29 +70,3 @@ dataset = Dataset(X', y) mach = machine(model, X, y) fit!(mach) - -println("hello") - -# using Profile -# Profile.init(n=10^8, delay=1e-4) -# mach = machine(model, X, y) -# @profile fit!(mach, verbosity=0) -# Profile.clear() -# mach = machine(model, X, y) -# @profile fit!(mach) - -# using PProf -# pprof() -# idx1 = lastindex(report(mach).equations) -# ypred1 = predict(mach, (data=X, idx=idx1)) -# loss1 = sum(i -> abs2(ypred1[i] - y[i]), eachindex(y)) - -# # Should keep all parameters -# stop_at[] = 1e-5 -# fit!(mach) -# idx2 = lastindex(report(mach).equations) -# ypred2 = predict(mach, (data=X, idx=idx2)) -# loss2 = sum(i -> abs2(ypred2[i] - y[i]), eachindex(y)) - -# # Should get better: -# @test loss1 >= loss2 diff --git a/src/Mutate.jl b/src/Mutate.jl index 4e0550329..ddc481ee3 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -91,11 +91,8 @@ Note that the weights were already copied, so you don't need to worry about muta - `curmaxsize::Int`: The current maximum size constraint for the member's expression tree. """ function condition_mutation_weights!( - weights::AbstractMutationWeights, - member::PopMember, - options::AbstractOptions, - curmaxsize::Int, -) + weights::AbstractMutationWeights, member::P, options::AbstractOptions, curmaxsize::Int +) where {T,L,N<:AbstractExpression,P<:PopMember{T,L,N}} tree = get_tree(member.tree) if !preserve_sharing(typeof(member.tree)) weights.form_connection = 0.0 From 686b844debfb39a42ed3db919b7f04149d0726d1 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 22:29:06 +0100 Subject: [PATCH 26/58] fix: miscalculations of `maxsize` --- src/AdaptiveParsimony.jl | 4 +--- src/CheckConstraints.jl | 2 +- src/HallOfFame.jl | 8 +++----- src/SearchUtils.jl | 2 +- src/SymbolicRegression.jl | 2 +- 5 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/AdaptiveParsimony.jl b/src/AdaptiveParsimony.jl index aa33fa613..e3fded95c 100644 --- a/src/AdaptiveParsimony.jl +++ b/src/AdaptiveParsimony.jl @@ -24,9 +24,7 @@ struct RunningSearchStatistics end function RunningSearchStatistics(; options::AbstractOptions, window_size::Int=100000) - maxsize = options.maxsize - actualMaxsize = maxsize + MAX_DEGREE - init_frequencies = ones(Float64, actualMaxsize) + init_frequencies = ones(Float64, options.maxsize) return RunningSearchStatistics( window_size, init_frequencies, copy(init_frequencies) / sum(init_frequencies) diff --git a/src/CheckConstraints.jl b/src/CheckConstraints.jl index fb0bbb712..64e2d6df1 100644 --- a/src/CheckConstraints.jl +++ b/src/CheckConstraints.jl @@ -85,7 +85,7 @@ function check_constraints( maxsize::Int, cursize::Union{Int,Nothing}=nothing, )::Bool - ((cursize === nothing) ? compute_complexity(tree, options) : cursize) > maxsize && + ((cursize === nothing) ? compute_complexity(tree, options) : cursize) >= maxsize && return false count_depth(tree) > options.maxdepth && return false for i in 1:(options.nbin) diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index 71032dfd5..a75b82939 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -63,7 +63,6 @@ Arguments: function HallOfFame( options::AbstractOptions, dataset::Dataset{T,L} ) where {T<:DATA_TYPE,L<:LOSS_TYPE} - actualMaxsize = options.maxsize + MAX_DEGREE base_tree = create_expression(zero(T), options, dataset) return HallOfFame{T,L,typeof(base_tree)}( @@ -75,9 +74,9 @@ function HallOfFame( options; parent=-1, deterministic=options.deterministic, - ) for i in 1:actualMaxsize + ) for i in 1:(options.maxsize) ], - [false for i in 1:actualMaxsize], + [false for i in 1:(options.maxsize)], ) end @@ -95,8 +94,7 @@ function calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,N}) where {T,L,N} P = PopMember{T,L,N} # Dominating pareto curve - must be better than all simpler equations dominating = P[] - actualMaxsize = length(hallOfFame.members) - for size in 1:actualMaxsize + for size in eachindex(hallOfFame.members) if !hallOfFame.exists[size] continue end diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index f1ba1cfb3..23358d9dc 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -681,7 +681,7 @@ function update_hall_of_fame!( ) where {PM<:PopMember} for member in members size = compute_complexity(member, options) - valid_size = 0 < size < options.maxsize + MAX_DEGREE + valid_size = 0 < size <= options.maxsize if !valid_size continue end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 9effdee14..1e6cfb64f 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -1066,7 +1066,7 @@ end ) num_evals += evals_from_optimize if options.batching - for i_member in 1:(options.maxsize + MAX_DEGREE) + for i_member in 1:(options.maxsize) score, result_loss = score_func(dataset, best_seen.members[i_member], options) best_seen.members[i_member].score = score best_seen.members[i_member].loss = result_loss From e71ec0731c7dd8635b7adbd56bae1823932d53e9 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 22:29:23 +0100 Subject: [PATCH 27/58] feat: export `with_contents` and `with_metadata` --- src/SymbolicRegression.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 1e6cfb64f..12645765c 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -51,6 +51,8 @@ export Population, get_tree, get_contents, get_metadata, + with_contents, + with_metadata, #Operators plus, @@ -123,7 +125,9 @@ using DynamicExpressions: node_type, get_tree, get_contents, - get_metadata + get_metadata, + with_contents, + with_metadata using DynamicExpressions: with_type_parameters @reexport using LossFunctions: MarginLoss, From 108d1c18716fcfafa21747bcb1c6b47c07daa492 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 22:29:38 +0100 Subject: [PATCH 28/58] feat: callable TemplateExpression --- src/TemplateExpression.jl | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index b0cb8b29e..41a3e8898 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -211,18 +211,27 @@ function DE.string_tree( # TODO: Make a fallback function in case the structure function is undefined. return get_metadata(tree).structure(inner_strings) end -function LF.eval_tree_dispatch( - tree::TemplateExpression, dataset::Dataset, options::AbstractOptions, idx -) +function DE.eval_tree_array( + tree::TemplateExpression{T}, + cX::AbstractMatrix{T}, + operators::Union{AbstractOperatorEnum,Nothing}=nothing; + kws..., +) where {T} raw_contents = get_contents(tree) # Raw numerical results of each inner expression: - outs = map(ex -> LF.eval_tree_dispatch(ex, dataset, options, idx), values(raw_contents)) + outs = map(ex -> DE.eval_tree_array(ex, cX, operators; kws...), values(raw_contents)) # Combine them using the structure function: results = NamedTuple{keys(raw_contents)}(map(first, outs)) return get_metadata(tree).structure(results), all(last, outs) end +function (ex::TemplateExpression)( + X, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws... +) + # TODO: Why do we need to do this? It should automatically handle this! + return DE.eval_tree_array(ex, X, operators; kws...) +end function DA.violates_dimensional_constraints( tree::TemplateExpression, dataset::Dataset, options::AbstractOptions ) From bd6d8f520f46b7efd5e1b463954ee10770a05732 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 22:31:43 +0100 Subject: [PATCH 29/58] fix: new `maxsize` miscalc --- src/CheckConstraints.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/CheckConstraints.jl b/src/CheckConstraints.jl index 64e2d6df1..fb0bbb712 100644 --- a/src/CheckConstraints.jl +++ b/src/CheckConstraints.jl @@ -85,7 +85,7 @@ function check_constraints( maxsize::Int, cursize::Union{Int,Nothing}=nothing, )::Bool - ((cursize === nothing) ? compute_complexity(tree, options) : cursize) >= maxsize && + ((cursize === nothing) ? compute_complexity(tree, options) : cursize) > maxsize && return false count_depth(tree) > options.maxdepth && return false for i in 1:(options.nbin) From 8433d66df3ccc49f40cad21921f250a91f11979e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Oct 2024 22:38:25 +0100 Subject: [PATCH 30/58] fix: `check_constraints` for TemplateExpression --- src/TemplateExpression.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 41a3e8898..39e41fd04 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -329,10 +329,15 @@ function CC.check_constraints( return false end + # We also check the combined complexity: + ((cursize === nothing) ? ComplexityModule.compute_complexity(ex, options) : cursize) > + maxsize && return false + # Then, we check other constraints for inner expressions: return all( - t -> CC.check_constraints(t, options, maxsize, cursize), values(raw_contents) + t -> CC.check_constraints(t, options, maxsize, nothing), values(raw_contents) ) + # TODO: The concept of `cursize` doesn't really make sense here. end function contains_other_features_than(tree::AbstractExpression, features) return contains_other_features_than(get_tree(tree), features) From 7ed6e9350d4a17b2ee7d7d3f55952d54a12e81a6 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 15 Oct 2024 08:26:00 +0100 Subject: [PATCH 31/58] test: add missing tags --- test/test_template_expression.jl | 214 +++++++++++++++++++++++++++++++ 1 file changed, 214 insertions(+) create mode 100644 test/test_template_expression.jl diff --git a/test/test_template_expression.jl b/test/test_template_expression.jl new file mode 100644 index 000000000..364bea20f --- /dev/null +++ b/test/test_template_expression.jl @@ -0,0 +1,214 @@ +@testitem "Basic utility of the TemplateExpression" tags = [:part3] begin + using SymbolicRegression + using SymbolicRegression: SymbolicRegression as SR + using SymbolicRegression.CheckConstraintsModule: check_constraints + using DynamicExpressions: OperatorEnum + + options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + operators = options.operators + variable_names = (i -> "x$i").(1:3) + x1, x2, x3 = + (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) + + # For combining expressions to a single expression: + my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractString}}}) = + "sin($(nt.f)) + $(nt.g)^2" + my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}}) = + @. sin(nt.f) + nt.g^2 + my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:Expression}}}) = + sin(nt.f) + nt.g * nt.g + + variable_mapping = (; f=[1, 2], g=[3]) + st_expr = TemplateExpression( + (; f=x1, g=cos(x3)); + structure=my_structure, + operators, + variable_names, + variable_mapping, + ) + @test string_tree(st_expr) == "sin(x1) + cos(x3)^2" + operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(cos, sin)) + + # Changing the operators will change how the expression is interpreted for + # parts that are already evaluated: + @test string_tree(st_expr, operators) == "sin(x1) + sin(x3)^2" + + # We can evaluate with this too: + cX = [1.0 2.0; 3.0 4.0; 5.0 6.0] + out, completed = st_expr(cX) + @test completed + @test out ≈ [sin(1.0) + cos(5.0)^2, sin(2.0) + cos(6.0)^2] + + # And also check the contents: + @test check_constraints(st_expr, options, 100) + + # We can see that violating the constraints will cause a violation: + new_expr = with_contents(st_expr, (; f=x3, g=cos(x3))) + @test !check_constraints(new_expr, options, 100) + new_expr = with_contents(st_expr, (; f=x2, g=cos(x3))) + @test check_constraints(new_expr, options, 100) + new_expr = with_contents(st_expr, (; f=x2, g=cos(x1))) + @test !check_constraints(new_expr, options, 100) + + # Checks the size of each individual expression: + new_expr = with_contents(st_expr, (; f=x2, g=cos(x3))) + + @test compute_complexity(new_expr, options) == 3 + @test check_constraints(new_expr, options, 3) + @test !check_constraints(new_expr, options, 2) +end +@testitem "Expression interface" tags = [:part3] begin + using SymbolicRegression + using DynamicExpressions: OperatorEnum + using DynamicExpressions.InterfacesModule: Interfaces, ExpressionInterface + + operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + variable_names = (i -> "x$i").(1:3) + x1, x2, x3 = + (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) + + # For combining expressions to a single expression: + my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractString}}}) = + "sin($(nt.f)) + $(nt.g)^2" + my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}}) = + @. sin(nt.f) + nt.g^2 + my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:Expression}}}) = + sin(nt.f) + nt.g * nt.g + + variable_mapping = (; f=[1, 2], g=[3]) + st_expr = TemplateExpression( + (; f=x1, g=x3); structure=my_structure, operators, variable_names, variable_mapping + ) + @test Interfaces.test(ExpressionInterface, TemplateExpression, [st_expr]) +end +@testitem "Utilising TemplateExpression to build vector expressions" tags = [:part3] begin + using SymbolicRegression + using Random: rand + + # Define the structure function, which returns a tuple: + function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractString}}}) + return "( $(nt.f) + $(nt.g1), $(nt.f) + $(nt.g2), $(nt.f) + $(nt.g3) )" + end + function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}}) + return map( + i -> (nt.f[i] + nt.g1[i], nt.f[i] + nt.g2[i], nt.f[i] + nt.g3[i]), + eachindex(nt.f), + ) + end + + # Set up operators and variable names + options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + variable_names = (i -> "x$i").(1:3) + + # Create expressions + x1, x2, x3 = + (i -> Expression(Node(Float64; feature=i); options.operators, variable_names)).(1:3) + + # Test with vector inputs: + nt_vector = NamedTuple{(:f, :g1, :g2, :g3)}((1:3, 4:6, 7:9, 10:12)) + @test my_structure(nt_vector) == [(5, 8, 11), (7, 10, 13), (9, 12, 15)] + + # And string inputs: + nt_string = NamedTuple{(:f, :g1, :g2, :g3)}(("x1", "x2", "x3", "x2")) + @test my_structure(nt_string) == "( x1 + x2, x1 + x3, x1 + x2 )" + + # Now, using TemplateExpression: + variable_mapping = (; f=[1, 2], g1=[3], g2=[3], g3=[3]) + st_expr = TemplateExpression( + (; f=x1, g1=x2, g2=x3, g3=x2); + structure=my_structure, + options.operators, + variable_names, + variable_mapping, + ) + @test string_tree(st_expr) == "( x1 + x2, x1 + x3, x1 + x2 )" + + # We can directly call it: + cX = [1.0 2.0; 3.0 4.0; 5.0 6.0] + out, completed = st_expr(cX) + @test completed + @test out == [(1 + 3, 1 + 5, 1 + 3), (2 + 4, 2 + 6, 2 + 4)] +end +@testitem "TemplateExpression Creation" tags = [:part3] begin + using SymbolicRegression + using DynamicExpressions: get_operators, get_variable_names + + operators = + Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)).operators + variable_names = (i -> "x$i").(1:3) + x1, x2, x3 = + (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) + + my_structure(nt) = nt.f + + variable_mapping = (; f=[1, 2], g1=[3], g2=[3], g3=[3]) + + st_expr = TemplateExpression( + (; f=x1, g1=x3, g2=x3, g3=x3); + structure=my_structure, + operators, + variable_names, + variable_mapping, + ) + + @test st_expr isa TemplateExpression + @test get_operators(st_expr) == operators + @test get_variable_names(st_expr) == variable_names + @test get_metadata(st_expr).structure == my_structure +end +# Integration test with fit! and performance check +@testitem "Integration Test with fit! and Performance Check" begin + using SymbolicRegression + using Random: rand + using MLJBase: machine, fit!, report + + operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + variable_names = (i -> "x$i").(1:3) + x1, x2, x3 = + (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) + + variable_mapping = (; f=[1, 2], g1=[3], g2=[3], g3=[3]) + + st_expr = TemplateExpression( + (; f=x1, g1=x3, g2=x3, g3=x3); + structure=my_structure, + operators, + variable_names, + variable_mapping, + ) + + model = SRRegressor(; + niterations=200, + binary_operators=(+, *, /, -), + unary_operators=(sin, cos), + populations=30, + maxsize=30, + expression_type=TemplateExpression, + expression_options=(; structure=my_structure, variable_mapping), + parallelism=:multithreading, + elementwise_loss=((x1, x2, x3), (y1, y2, y3)) -> + (y1 - x1)^2 + (y2 - x2)^2 + (y3 - x3)^2, + ) + + X = rand(100, 3) + y = [ + (sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 2]) + X[i, 3]^2, sin(X[i, 3]) + X[i, 3]^2) for + i in eachindex(axes(X, 1)) + ] + + dataset = Dataset(X', y) + + mach = machine(model, X, y) + fit!(mach) + + # Check the performance of the model + r = report(mach) + idx = r.best_idx + best_loss = r.losses[idx] + + # Ensure the loss is within a reasonable range + @test best_loss < 1e-2 # Adjust this threshold based on expected performance + + # The final model should be a TemplateExpression: + +end From 5a0f77c38bda7d20bf0322c5f66decd8ce636939 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 15 Oct 2024 08:26:24 +0100 Subject: [PATCH 32/58] test: fix pretty print test with fixed `maxsize` --- test/test_pretty_printing.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/test/test_pretty_printing.jl b/test/test_pretty_printing.jl index 42a28d14e..56cfa1f6b 100644 --- a/test/test_pretty_printing.jl +++ b/test/test_pretty_printing.jl @@ -63,11 +63,7 @@ end .exists[6] = false .members[6] = undef .exists[7] = false - .members[7] = undef - .exists[8] = false - .members[8] = undef - .exists[9] = false - .members[9] = undef" + .members[7] = undef" @test s_hof == true_s end From 8c912001864e7dfb2473b0170df4f2458f305033 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 15 Oct 2024 08:28:54 +0100 Subject: [PATCH 33/58] fix: make `weighted` into callable `is_weighted` --- src/Core.jl | 2 +- src/Dataset.jl | 16 +++++++--------- src/LossFunctions.jl | 5 +++-- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/Core.jl b/src/Core.jl index 1ae420863..2d6e73d89 100644 --- a/src/Core.jl +++ b/src/Core.jl @@ -12,7 +12,7 @@ include("Options.jl") using .ProgramConstantsModule: MAX_DEGREE, BATCH_DIM, FEATURE_DIM, RecordType, DATA_TYPE, LOSS_TYPE -using .DatasetModule: Dataset +using .DatasetModule: Dataset, is_weighted using .MutationWeightsModule: AbstractMutationWeights, MutationWeights, sample_mutation using .OptionsStructModule: AbstractOptions, diff --git a/src/Dataset.jl b/src/Dataset.jl index 954f71620..f9e28bcc5 100644 --- a/src/Dataset.jl +++ b/src/Dataset.jl @@ -19,8 +19,7 @@ import ...deprecate_varmap dataset, if any. - `n::Int`: The number of samples. - `nfeatures::Int`: The number of features. -- `weighted::Bool`: Whether the dataset is non-uniformly weighted. -- `weights::Union{AbstractVector{T},Nothing}`: If the dataset is weighted, +- `weights::Union{AbstractVector,Nothing}`: If the dataset is weighted, these specify the per-sample weight (with shape `(n,)`). - `extra::NamedTuple`: Extra information to pass to a custom evaluation function. Since this is an arbitrary named tuple, you could pass @@ -50,7 +49,7 @@ mutable struct Dataset{ L<:LOSS_TYPE, AX<:AbstractMatrix{T}, AY<:Union{AbstractVector,Nothing}, - AW<:Union{AbstractVector{T},Nothing}, + AW<:Union{AbstractVector,Nothing}, NT<:NamedTuple, XU<:Union{AbstractVector{<:Quantity},Nothing}, YU<:Union{Quantity,Nothing}, @@ -62,7 +61,6 @@ mutable struct Dataset{ const index::Int const n::Int const nfeatures::Int - const weighted::Bool const weights::AW const extra::NT const avg_y::Union{T,Nothing} @@ -81,7 +79,7 @@ end Dataset(X::AbstractMatrix{T}, y::Union{AbstractVector{T},Nothing}=nothing, loss_type::Type=Nothing; - weights::Union{AbstractVector{T}, Nothing}=nothing, + weights::Union{AbstractVector, Nothing}=nothing, variable_names::Union{Array{String, 1}, Nothing}=nothing, y_variable_name::Union{String,Nothing}=nothing, extra::NamedTuple=NamedTuple(), @@ -96,7 +94,7 @@ function Dataset( y::Union{AbstractVector,Nothing}=nothing, loss_type::Type{L}=Nothing; index::Int=1, - weights::Union{AbstractVector{T},Nothing}=nothing, + weights::Union{AbstractVector,Nothing}=nothing, variable_names::Union{Array{String,1},Nothing}=nothing, display_variable_names=variable_names, y_variable_name::Union{String,Nothing}=nothing, @@ -133,7 +131,6 @@ function Dataset( n = size(X, BATCH_DIM) nfeatures = size(X, FEATURE_DIM) - weighted = weights !== nothing variable_names = if variable_names === nothing ["x$(i)" for i in 1:nfeatures] else @@ -153,7 +150,7 @@ function Dataset( avg_y = if y === nothing || !(eltype(y) isa Number) nothing else - if weighted + if weights !== nothing sum(y .* weights) / sum(weights) else sum(y) / n @@ -207,7 +204,6 @@ function Dataset( index, n, nfeatures, - weighted, weights, extra, avg_y, @@ -223,6 +219,8 @@ function Dataset( ) end +is_weighted(dataset::Dataset) = dataset.weights !== nothing + function error_on_mismatched_size(_, ::Nothing) return nothing end diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl index c722e2790..45dd10ab5 100644 --- a/src/LossFunctions.jl +++ b/src/LossFunctions.jl @@ -6,7 +6,8 @@ using DynamicExpressions: using LossFunctions: LossFunctions using LossFunctions: SupervisedLoss using ..InterfaceDynamicExpressionsModule: expected_array_type -using ..CoreModule: AbstractOptions, Dataset, create_expression, DATA_TYPE, LOSS_TYPE +using ..CoreModule: + AbstractOptions, Dataset, create_expression, DATA_TYPE, LOSS_TYPE, is_weighted using ..ComplexityModule: compute_complexity using ..DimensionalAnalysisModule: violates_dimensional_constraints @@ -65,7 +66,7 @@ function _eval_loss( return L(Inf) end - loss_val = if dataset.weighted + loss_val = if is_weighted(dataset) _weighted_loss( prediction, maybe_getindex(dataset.y, idx), From 46a59e91de463b3bc0cab09fe824c4a33e830011 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 15 Oct 2024 08:32:11 +0100 Subject: [PATCH 34/58] test: update dataset tests --- test/runtests.jl | 6 +----- test/test_dataset.jl | 27 ++++++++++++++++++++------- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index e2ebfbd4a..377dc52fc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -110,11 +110,7 @@ end end include("test_units.jl") - -@testitem "Dataset" tags = [:part3] begin - include("test_dataset.jl") -end - +include("test_dataset.jl") include("test_mixed.jl") @testitem "Testing fast-cycle and custom variable names" tags = [:part2] begin diff --git a/test/test_dataset.jl b/test/test_dataset.jl index 9fdbcfd74..505471979 100644 --- a/test/test_dataset.jl +++ b/test/test_dataset.jl @@ -1,16 +1,29 @@ -using SymbolicRegression -using DispatchDoctor: allow_unstable - -@testset "Dataset construction" begin +@testitem "Dataset construction" tags = [:part3] begin + using SymbolicRegression # Promotion of types: dataset = Dataset(randn(3, 32), randn(Float32, 32); weights=randn(Float32, 32)) - @test typeof(dataset.y) == Array{Float64,1} - @test typeof(dataset.weights) == Array{Float64,1} + + # Will not automatically convert: + @test typeof(dataset.X) == Array{Float64,2} + @test typeof(dataset.y) == Array{Float32,1} + @test typeof(dataset.weights) == Array{Float32,1} end -@testset "With deprecated kwarg" begin +@testitem "With deprecated kwarg" tags = [:part3] begin + using SymbolicRegression + using DispatchDoctor: allow_unstable dataset = allow_unstable() do Dataset(randn(ComplexF32, 3, 32), randn(ComplexF32, 32); loss_type=Float64) end @test dataset isa Dataset{ComplexF32,Float64} end + +@testitem "vector output" tags = [:part3] begin + using SymbolicRegression + + X = randn(Float64, 3, 32) + y = [ntuple(_ -> randn(Float64), 3) for _ in 1:32] + dataset = Dataset(X, y) + @test dataset isa Dataset{Float64,Float64} + @test dataset.y isa Vector{NTuple{3,Float64}} +end From 9ed392c2d55c04a110e00e6b5d0e31dba34aa27b Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 15 Oct 2024 08:37:33 +0100 Subject: [PATCH 35/58] feat: generalize MLJ interface --- src/MLJInterface.jl | 4 ++-- test/test_search_statistics.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 028493ff7..4d7a1d140 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -508,7 +508,7 @@ const input_scitype = Union{ MMI.metadata_model( SRRegressor; input_scitype, - target_scitype=AbstractVector{<:MMI.Continuous}, + target_scitype=AbstractVector{<:Any}, supports_weights=true, reports_feature_importances=false, load_path="SymbolicRegression.MLJInterfaceModule.SRRegressor", @@ -517,7 +517,7 @@ MMI.metadata_model( MMI.metadata_model( MultitargetSRRegressor; input_scitype, - target_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}}, + target_scitype=Union{MMI.Table(Any),AbstractMatrix{<:Any}}, supports_weights=true, reports_feature_importances=false, load_path="SymbolicRegression.MLJInterfaceModule.MultitargetSRRegressor", diff --git a/test/test_search_statistics.jl b/test/test_search_statistics.jl index 35a9b0175..cc2f5360a 100644 --- a/test/test_search_statistics.jl +++ b/test/test_search_statistics.jl @@ -13,7 +13,7 @@ end normalize_frequencies!(statistics) -@test sum(statistics.frequencies) == 1022 +@test sum(statistics.frequencies) == 1020 @test sum(statistics.normalized_frequencies) ≈ 1.0 @test statistics.normalized_frequencies[5] > statistics.normalized_frequencies[15] From 4e52ca4dc7b51b5b1093e77170480239e5d67281 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 15 Oct 2024 08:39:22 +0100 Subject: [PATCH 36/58] fix: missing calls to `.weighted` --- src/Configure.jl | 4 ++-- src/SymbolicRegression.jl | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Configure.jl b/src/Configure.jl index 0fedd2f54..2b184e5cd 100644 --- a/src/Configure.jl +++ b/src/Configure.jl @@ -101,7 +101,7 @@ function test_dataset_configuration( end if !(typeof(options.elementwise_loss) <: SupervisedLoss) && - dataset.weighted && + is_weighted(dataset) && !(3 in [m.nargs - 1 for m in methods(options.elementwise_loss)]) throw( AssertionError( @@ -132,7 +132,7 @@ function move_functions_to_workers( continue end ops = (options.elementwise_loss,) - example_inputs = if dataset.weighted + example_inputs = if is_weighted(dataset) (zero(T), zero(T), zero(T)) else (zero(T), zero(T)) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 12645765c..0530dba4a 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -236,6 +236,7 @@ using .CoreModule: Options, AbstractMutationWeights, MutationWeights, + is_weighted, sample_mutation, plus, sub, From 97ea32b1521aa4877f5af414f628a6211b512d56 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 15 Oct 2024 09:13:23 +0100 Subject: [PATCH 37/58] fix: update type inference utility for TemplateExpression --- src/InterfaceDynamicExpressions.jl | 47 ++++++++++-------------------- src/LossFunctions.jl | 19 ++++++------ src/Mutate.jl | 12 -------- src/ParametricExpression.jl | 36 +++++++++++++++++++++-- src/TemplateExpression.jl | 5 +++- test/test_template_expression.jl | 23 ++++++++++----- 6 files changed, 77 insertions(+), 65 deletions(-) diff --git a/src/InterfaceDynamicExpressions.jl b/src/InterfaceDynamicExpressions.jl index 887627b7d..0a9629f3e 100644 --- a/src/InterfaceDynamicExpressions.jl +++ b/src/InterfaceDynamicExpressions.jl @@ -7,11 +7,10 @@ using DynamicExpressions: GenericOperatorEnum, AbstractExpression, AbstractExpressionNode, - ParametricExpression, Node, GraphNode using DynamicQuantities: dimension, ustrip -using ..CoreModule: AbstractOptions +using ..CoreModule: AbstractOptions, Dataset using ..CoreModule.OptionsModule: inverse_binopmap, inverse_unaopmap using ..UtilsModule: subscriptify @@ -56,7 +55,6 @@ function DE.eval_tree_array( options::AbstractOptions; kws..., ) - A = expected_array_type(X) return DE.eval_tree_array( tree, X, @@ -64,31 +62,16 @@ function DE.eval_tree_array( turbo=options.turbo, bumper=options.bumper, kws..., - )::Tuple{A,Bool} -end -function DE.eval_tree_array( - tree::ParametricExpression, - X::AbstractMatrix, - classes::AbstractVector{<:Integer}, - options::AbstractOptions; - kws..., -) - A = expected_array_type(X) - return DE.eval_tree_array( - tree, - X, - classes, - DE.get_operators(tree, options); - turbo=options.turbo, - bumper=options.bumper, - kws..., - )::Tuple{A,Bool} + )::Tuple{<:expected_array_type(X, typeof(tree)),Bool} end -# Improve type inference by telling Julia the expected array returned -function expected_array_type(X::AbstractArray) +"""Improve type inference by telling Julia the expected array returned.""" +function expected_array_type(X::AbstractArray, ::Type) return typeof(similar(X, axes(X, 2))) end +function expected_array_type(X::AbstractArray, ::Type, ::Val{:eval_grad_tree_array}) + return typeof(X) +end """ eval_diff_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::AbstractOptions, direction::Int) @@ -116,11 +99,12 @@ function DE.eval_diff_tree_array( options::AbstractOptions, direction::Int, ) - A = expected_array_type(X) # TODO: Add `AbstractExpression` implementation in `Expression.jl` return DE.eval_diff_tree_array( DE.get_tree(tree), X, DE.get_operators(tree, options), direction - )::Tuple{A,A,Bool} + )::Tuple{ + <:expected_array_type(X, typeof(tree)),<:expected_array_type(X, typeof(tree)),Bool + } end """ @@ -150,11 +134,13 @@ function DE.eval_grad_tree_array( options::AbstractOptions; kws..., ) - A = expected_array_type(X) - M = typeof(X) # TODO: This won't work with StaticArrays! return DE.eval_grad_tree_array( tree, X, DE.get_operators(tree, options); kws... - )::Tuple{A,M,Bool} + )::Tuple{ + <:expected_array_type(X, typeof(tree)), + <:expected_array_type(X, typeof(tree), Val(:eval_grad_tree_array)), + Bool, + } end """ @@ -167,11 +153,10 @@ function DE.differentiable_eval_tree_array( X::AbstractArray, options::AbstractOptions, ) - A = expected_array_type(X) # TODO: Add `AbstractExpression` implementation in `Expression.jl` return DE.differentiable_eval_tree_array( DE.get_tree(tree), X, DE.get_operators(tree, options) - )::Tuple{A,Bool} + )::Tuple{<:expected_array_type(X, typeof(tree)),Bool} end const WILDCARD_UNIT_STRING = "[?]" diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl index 45dd10ab5..7f68fe682 100644 --- a/src/LossFunctions.jl +++ b/src/LossFunctions.jl @@ -5,7 +5,6 @@ using DynamicExpressions: AbstractExpression, AbstractExpressionNode, get_tree, eval_tree_array using LossFunctions: LossFunctions using LossFunctions: SupervisedLoss -using ..InterfaceDynamicExpressionsModule: expected_array_type using ..CoreModule: AbstractOptions, Dataset, create_expression, DATA_TYPE, LOSS_TYPE, is_weighted using ..ComplexityModule: compute_complexity @@ -42,15 +41,15 @@ end end end -for N in (:AbstractExpression, :AbstractExpressionNode) - @eval function eval_tree_dispatch( - tree::$N, dataset::Dataset, options::AbstractOptions, idx - ) - A = expected_array_type(dataset.X) - return eval_tree_array( - tree, maybe_getindex(dataset.X, :, idx), options - )::Tuple{A,Bool} - end +function eval_tree_dispatch( + tree::AbstractExpression, dataset::Dataset, options::AbstractOptions, idx +) + return eval_tree_array(tree, maybe_getindex(dataset.X, :, idx), options) +end +function eval_tree_dispatch( + tree::AbstractExpressionNode, dataset::Dataset, options::AbstractOptions, idx +) + return eval_tree_array(tree, maybe_getindex(dataset.X, :, idx), options) end # Evaluate the loss of a particular expression on the input dataset. diff --git a/src/Mutate.jl b/src/Mutate.jl index ddc481ee3..7b828f6f3 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -2,7 +2,6 @@ module MutateModule using DynamicExpressions: AbstractExpression, - ParametricExpression, with_contents, get_tree, preserve_sharing, @@ -149,17 +148,6 @@ function condition_mutate_constant!( return nothing end -function condition_mutate_constant!( - ::Type{<:ParametricExpression}, - weights::AbstractMutationWeights, - member::PopMember, - options::AbstractOptions, - curmaxsize::Int, -) - # Avoid modifying the mutate_constant weight, since - # otherwise we would be mutating constants all the time! - return nothing -end # Go through one simulated options.annealing mutation cycle # exp(-delta/T) defines probability of accepting a change diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 0180a1602..b31f60e70 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -5,6 +5,7 @@ this file just adds custom behavior for SymbolicRegression.jl, where needed. module ParametricExpressionModule using DynamicExpressions: + DynamicExpressions as DE, AbstractExpression, ParametricExpression, ParametricNode, @@ -17,10 +18,12 @@ using DynamicExpressions: using StatsBase: StatsBase using Random: default_rng, AbstractRNG -using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE +using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, AbstractMutationWeights +using ..PopMemberModule: PopMember using ..InterfaceDynamicExpressionsModule: expected_array_type using ..LossFunctionsModule: LossFunctionsModule as LF using ..ExpressionBuilderModule: ExpressionBuilderModule as EB +using ..MutateModule: MutateModule as MM using ..MutationFunctionsModule: MutationFunctionsModule as MF using ..ConstantOptimizationModule: ConstantOptimizationModule as CO @@ -58,18 +61,45 @@ function EB.consistency_checks(options::AbstractOptions, prototype::ParametricEx return nothing end +function DE.eval_tree_array( + tree::ParametricExpression, + X::AbstractMatrix, + classes::AbstractVector{<:Integer}, + options::AbstractOptions; + kws..., +) + return DE.eval_tree_array( + tree, + X, + classes, + DE.get_operators(tree, options); + turbo=options.turbo, + bumper=options.bumper, + kws..., + )::Tuple{<:expected_array_type(X, typeof(tree)),Bool} +end function LF.eval_tree_dispatch( tree::ParametricExpression, dataset::Dataset, options::AbstractOptions, idx ) - A = expected_array_type(dataset.X) return eval_tree_array( tree, LF.maybe_getindex(dataset.X, :, idx), LF.maybe_getindex(dataset.extra.classes, idx), options.operators, - )::Tuple{A,Bool} + ) end +function MM.condition_mutate_constant!( + ::Type{<:ParametricExpression}, + weights::AbstractMutationWeights, + member::PopMember, + options::AbstractOptions, + curmaxsize::Int, +) + # Avoid modifying the mutate_constant weight, since + # otherwise we would be mutating constants all the time! + return nothing +end function MF.make_random_leaf( nfeatures::Int, ::Type{T}, diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 39e41fd04..2715f35e1 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -1,6 +1,7 @@ module ConstrainedExpressionModule using Random: AbstractRNG +using DispatchDoctor: @unstable using DynamicExpressions: DynamicExpressions as DE, AbstractStructuredExpression, @@ -23,7 +24,7 @@ using DynamicExpressions.InterfacesModule: using ..CoreModule: AbstractOptions, Dataset, CoreModule as CM, AbstractMutationWeights using ..ConstantOptimizationModule: ConstantOptimizationModule as CO -using ..InterfaceDynamicExpressionsModule: expected_array_type +using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE using ..MutationFunctionsModule: MutationFunctionsModule as MF using ..ExpressionBuilderModule: ExpressionBuilderModule as EB using ..DimensionalAnalysisModule: DimensionalAnalysisModule as DA @@ -232,6 +233,8 @@ function (ex::TemplateExpression)( # TODO: Why do we need to do this? It should automatically handle this! return DE.eval_tree_array(ex, X, operators; kws...) end +@unstable IDE.expected_array_type(::AbstractMatrix, ::Type{<:TemplateExpression}) = Any + function DA.violates_dimensional_constraints( tree::TemplateExpression, dataset::Dataset, options::AbstractOptions ) diff --git a/test/test_template_expression.jl b/test/test_template_expression.jl index 364bea20f..e54f7f799 100644 --- a/test/test_template_expression.jl +++ b/test/test_template_expression.jl @@ -156,19 +156,29 @@ end @test get_variable_names(st_expr) == variable_names @test get_metadata(st_expr).structure == my_structure end -# Integration test with fit! and performance check -@testitem "Integration Test with fit! and Performance Check" begin +@testitem "Integration Test with fit! and Performance Check" tags = [:part3] begin using SymbolicRegression using Random: rand using MLJBase: machine, fit!, report - operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + operators = options.operators variable_names = (i -> "x$i").(1:3) x1, x2, x3 = (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) variable_mapping = (; f=[1, 2], g1=[3], g2=[3], g3=[3]) + function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractString}}}) + return "( $(nt.f) + $(nt.g1), $(nt.f) + $(nt.g2), $(nt.f) + $(nt.g3) )" + end + function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}}) + return map( + i -> (nt.f[i] + nt.g1[i], nt.f[i] + nt.g2[i], nt.f[i] + nt.g3[i]), + eachindex(nt.f), + ) + end + st_expr = TemplateExpression( (; f=x1, g1=x3, g2=x3, g3=x3); structure=my_structure, @@ -190,9 +200,9 @@ end (y1 - x1)^2 + (y2 - x2)^2 + (y3 - x3)^2, ) - X = rand(100, 3) + X = rand(100, 3) .* 10 y = [ - (sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 2]) + X[i, 3]^2, sin(X[i, 3]) + X[i, 3]^2) for + (sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 2]) + X[i, 3]^3, sin(X[i, 1]) + X[i, 3]^2) for i in eachindex(axes(X, 1)) ] @@ -208,7 +218,4 @@ end # Ensure the loss is within a reasonable range @test best_loss < 1e-2 # Adjust this threshold based on expected performance - - # The final model should be a TemplateExpression: - end From 26ca30aa72a7e99fa7c3b7b615807af24014c76a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 15 Oct 2024 09:29:05 +0100 Subject: [PATCH 38/58] test: greatly improve TemplateExpression test --- test/test_template_expression.jl | 42 ++++++++++++++++---------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/test/test_template_expression.jl b/test/test_template_expression.jl index e54f7f799..0c6211e91 100644 --- a/test/test_template_expression.jl +++ b/test/test_template_expression.jl @@ -167,20 +167,17 @@ end x1, x2, x3 = (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) - variable_mapping = (; f=[1, 2], g1=[3], g2=[3], g3=[3]) + variable_mapping = (; f=[1, 2], g1=[3], g2=[3]) function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractString}}}) - return "( $(nt.f) + $(nt.g1), $(nt.f) + $(nt.g2), $(nt.f) + $(nt.g3) )" + return "( $(nt.f) + $(nt.g1), $(nt.f) + $(nt.g2) )" end function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}}) - return map( - i -> (nt.f[i] + nt.g1[i], nt.f[i] + nt.g2[i], nt.f[i] + nt.g3[i]), - eachindex(nt.f), - ) + return map(i -> (nt.f[i] + nt.g1[i], nt.f[i] + nt.g2[i]), eachindex(nt.f)) end st_expr = TemplateExpression( - (; f=x1, g1=x3, g2=x3, g3=x3); + (; f=x1, g1=x3, g2=x3); structure=my_structure, operators, variable_names, @@ -188,23 +185,17 @@ end ) model = SRRegressor(; - niterations=200, - binary_operators=(+, *, /, -), - unary_operators=(sin, cos), - populations=30, - maxsize=30, + binary_operators=(+, *), + unary_operators=(sin,), + maxsize=15, expression_type=TemplateExpression, expression_options=(; structure=my_structure, variable_mapping), - parallelism=:multithreading, - elementwise_loss=((x1, x2, x3), (y1, y2, y3)) -> - (y1 - x1)^2 + (y2 - x2)^2 + (y3 - x3)^2, + elementwise_loss=((x1, x2), (y1, y2)) -> (y1 - x1)^2 + (y2 - x2)^2, + early_stop_condition=(loss, complexity) -> loss < 1e-5 && complexity <= 7, ) X = rand(100, 3) .* 10 - y = [ - (sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 2]) + X[i, 3]^3, sin(X[i, 1]) + X[i, 3]^2) for - i in eachindex(axes(X, 1)) - ] + y = [(sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 1]) + X[i, 3]) for i in eachindex(axes(X, 1))] dataset = Dataset(X', y) @@ -216,6 +207,15 @@ end idx = r.best_idx best_loss = r.losses[idx] - # Ensure the loss is within a reasonable range - @test best_loss < 1e-2 # Adjust this threshold based on expected performance + @test best_loss < 1e-5 + + # Check the expression is split up correctly: + best_expr = r.equations[idx] + best_f = get_contents(best_expr).f + best_g1 = get_contents(best_expr).g1 + best_g2 = get_contents(best_expr).g2 + + @test best_f(X') ≈ (@. sin(X[:, 1])) + @test best_g1(X') ≈ (@. X[:, 3] * X[:, 3]) + @test best_g2(X') ≈ (@. X[:, 3]) end From 4f1f85297d0c8d48dbafd208bf5b414784b0b3b7 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 15 Oct 2024 09:32:07 +0100 Subject: [PATCH 39/58] docs: update `TemplateExpression` name --- docs/src/customization.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/customization.md b/docs/src/customization.md index 7b04604bc..09c194341 100644 --- a/docs/src/customization.md +++ b/docs/src/customization.md @@ -49,7 +49,7 @@ If needed, you may need to overload `SymbolicRegression.ExpressionBuilder.extra_ case your expression needs additional parameters. See the method for `ParametricExpression` as an example. -You can look at the files `src/ParametricExpression.jl` and `src/ConstrainedExpression.jl` +You can look at the files `src/ParametricExpression.jl` and `src/TemplateExpression.jl` for more examples of custom expression types, though note that `ParametricExpression` itself is defined in DynamicExpressions.jl, while that file just overloads some methods for SymbolicRegression.jl. From 204e5efbd734c9108ff09aef0a1f3361893202cc Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 15 Oct 2024 09:33:31 +0100 Subject: [PATCH 40/58] docs: update `TemplateExpression` --- docs/src/types.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/src/types.md b/docs/src/types.md index 92bf5632e..cd62389be 100644 --- a/docs/src/types.md +++ b/docs/src/types.md @@ -60,6 +60,16 @@ ParametricNode These types allow you to define expressions with parameters that can be tuned to fit the data better. You can specify the maximum number of parameters using the `expression_options` argument in `SRRegressor`. +## Template Expressions + +Template expressions are a type of expression that allows you to specify a predefined structure. +This lets you also fit vector expressions, as the custom evaluation structure can simply return +a vector of tuples. + +```@docs +TemplateExpression +``` + ## Population Groups of equations are given as a population, which is From 0b215863521dd9b6c00a7a319f3eb215b89175ec Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 15 Oct 2024 09:34:58 +0100 Subject: [PATCH 41/58] test: tweak testitem name --- test/test_template_expression.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_template_expression.jl b/test/test_template_expression.jl index 0c6211e91..80d65fa9a 100644 --- a/test/test_template_expression.jl +++ b/test/test_template_expression.jl @@ -129,7 +129,7 @@ end @test completed @test out == [(1 + 3, 1 + 5, 1 + 3), (2 + 4, 2 + 6, 2 + 4)] end -@testitem "TemplateExpression Creation" tags = [:part3] begin +@testitem "TemplateExpression getters" tags = [:part3] begin using SymbolicRegression using DynamicExpressions: get_operators, get_variable_names From 4ec4df6aecfc361803a9656fdb58290b8ba13ced Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 15 Oct 2024 09:37:53 +0100 Subject: [PATCH 42/58] test: move integration test to example folder --- examples/template_expression.jl | 84 ++++++++++++++------------------ test/test_template_expression.jl | 62 +---------------------- 2 files changed, 38 insertions(+), 108 deletions(-) diff --git a/examples/template_expression.jl b/examples/template_expression.jl index f9693d2b5..ade5fc5cf 100644 --- a/examples/template_expression.jl +++ b/examples/template_expression.jl @@ -1,72 +1,62 @@ using SymbolicRegression -using DynamicExpressions: - DynamicExpressions as DE, - Metadata, - get_tree, - get_operators, - get_variable_names, - OperatorEnum, - AbstractExpression -using Random: MersenneTwister -using MLJBase: machine, fit!, predict, report -using Test +using Random: rand +using MLJBase: machine, fit!, report +using Test: @test -using DynamicExpressions.InterfacesModule: Interfaces, ExpressionInterface - -# Impose structure: -# Function f(x1, x2) -# Function g(x3) -# y = sin(f) + g^2 -operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) +options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) +operators = options.operators variable_names = (i -> "x$i").(1:3) x1, x2, x3 = (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) -# For combining expressions to a single expression: +variable_mapping = (; f=[1, 2], g1=[3], g2=[3]) + function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractString}}}) - return "( $(nt.f) + $(nt.g1), $(nt.f) + $(nt.g2), $(nt.f) + $(nt.g3) )" + return "( $(nt.f) + $(nt.g1), $(nt.f) + $(nt.g2) )" end function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}}) - return map( - i -> (nt.f[i] + nt.g1[i], nt.f[i] + nt.g2[i], nt.f[i] + nt.g3[i]), eachindex(nt.f) - ) + return map(i -> (nt.f[i] + nt.g1[i], nt.f[i] + nt.g2[i]), eachindex(nt.f)) end -variable_mapping = (; f=[1, 2], g1=[3], g2=[3], g3=[3]) - st_expr = TemplateExpression( - (; f=x1, g1=x3, g2=x3, g3=x3); + (; f=x1, g1=x3, g2=x3); structure=my_structure, operators, variable_names, variable_mapping, ) -# @test Interfaces.test( -# ExpressionInterface, -# TemplateExpression, -# [st_expr] -# ) +X = rand(100, 3) .* 10 + +# Our dataset is a vector of 2-tuples +y = [(sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 1]) + X[i, 3]) for i in eachindex(axes(X, 1))] model = SRRegressor(; - niterations=200, - binary_operators=(+, *, /, -), - unary_operators=(sin, cos), - populations=30, - maxsize=30, + binary_operators=(+, *), + unary_operators=(sin,), + maxsize=15, expression_type=TemplateExpression, expression_options=(; structure=my_structure, variable_mapping), - parallelism=:multithreading, - elementwise_loss=((x1, x2, x3), (y1, y2, y3)) -> - (y1 - x1)^2 + (y2 - x2)^2 + (y3 - x3)^2, + # The elementwise needs to operate directly on each row of `y`: + elementwise_loss=((x1, x2), (y1, y2)) -> (y1 - x1)^2 + (y2 - x2)^2, + early_stop_condition=(loss, complexity) -> loss < 1e-5 && complexity <= 7, ) -X = rand(100, 3) -y = [ - (sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 2]) + X[i, 3]^2, sin(X[i, 3]) + X[i, 3]^2) for - i in eachindex(axes(X, 1)) -] - -dataset = Dataset(X', y) - mach = machine(model, X, y) fit!(mach) + +# Check the performance of the model +r = report(mach) +idx = r.best_idx +best_loss = r.losses[idx] + +@test best_loss < 1e-5 + +# Check the expression is split up correctly: +best_expr = r.equations[idx] +best_f = get_contents(best_expr).f +best_g1 = get_contents(best_expr).g1 +best_g2 = get_contents(best_expr).g2 + +@test best_f(X') ≈ (@. sin(X[:, 1])) +@test best_g1(X') ≈ (@. X[:, 3] * X[:, 3]) +@test best_g2(X') ≈ (@. X[:, 3]) diff --git a/test/test_template_expression.jl b/test/test_template_expression.jl index 80d65fa9a..2187ea334 100644 --- a/test/test_template_expression.jl +++ b/test/test_template_expression.jl @@ -157,65 +157,5 @@ end @test get_metadata(st_expr).structure == my_structure end @testitem "Integration Test with fit! and Performance Check" tags = [:part3] begin - using SymbolicRegression - using Random: rand - using MLJBase: machine, fit!, report - - options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) - operators = options.operators - variable_names = (i -> "x$i").(1:3) - x1, x2, x3 = - (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) - - variable_mapping = (; f=[1, 2], g1=[3], g2=[3]) - - function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractString}}}) - return "( $(nt.f) + $(nt.g1), $(nt.f) + $(nt.g2) )" - end - function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}}) - return map(i -> (nt.f[i] + nt.g1[i], nt.f[i] + nt.g2[i]), eachindex(nt.f)) - end - - st_expr = TemplateExpression( - (; f=x1, g1=x3, g2=x3); - structure=my_structure, - operators, - variable_names, - variable_mapping, - ) - - model = SRRegressor(; - binary_operators=(+, *), - unary_operators=(sin,), - maxsize=15, - expression_type=TemplateExpression, - expression_options=(; structure=my_structure, variable_mapping), - elementwise_loss=((x1, x2), (y1, y2)) -> (y1 - x1)^2 + (y2 - x2)^2, - early_stop_condition=(loss, complexity) -> loss < 1e-5 && complexity <= 7, - ) - - X = rand(100, 3) .* 10 - y = [(sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 1]) + X[i, 3]) for i in eachindex(axes(X, 1))] - - dataset = Dataset(X', y) - - mach = machine(model, X, y) - fit!(mach) - - # Check the performance of the model - r = report(mach) - idx = r.best_idx - best_loss = r.losses[idx] - - @test best_loss < 1e-5 - - # Check the expression is split up correctly: - best_expr = r.equations[idx] - best_f = get_contents(best_expr).f - best_g1 = get_contents(best_expr).g1 - best_g2 = get_contents(best_expr).g2 - - @test best_f(X') ≈ (@. sin(X[:, 1])) - @test best_g1(X') ≈ (@. X[:, 3] * X[:, 3]) - @test best_g2(X') ≈ (@. X[:, 3]) + include("../examples/template_expression.jl") end From e5efb10a81bc827548d193badfa356b76cf3a186 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 15 Oct 2024 09:40:37 +0100 Subject: [PATCH 43/58] docs: update comment --- src/Utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Utils.jl b/src/Utils.jl index 473845651..da67bcf4d 100644 --- a/src/Utils.jl +++ b/src/Utils.jl @@ -208,10 +208,10 @@ end function _get_thread_cache(cache::PerThreadCache{T}) where {T} if cache.num_threads[] < Threads.nthreads() Base.@lock cache.lock begin - # The reason we have this extra `.len[]` parameter is to avoid + # The reason we have this extra `.num_threads[]` parameter is to avoid # a race condition between a thread resizing the array concurrent # to the check above. Basically we want to make sure the array is - # always big enough by the time we get to using it. Since `.len[]` + # always big enough by the time we get to using it. Since `.num_threads[]` # is set last, we can safely use the array. if cache.num_threads[] < Threads.nthreads() resize!(cache.x, Threads.nthreads()) From 1c809d1149e110a59248d29a0e9d8d427b19f086 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 15 Oct 2024 09:42:46 +0100 Subject: [PATCH 44/58] fix: type instability in `eval_tree_dispatch` --- src/LossFunctions.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl index 7f68fe682..ba8abcc9e 100644 --- a/src/LossFunctions.jl +++ b/src/LossFunctions.jl @@ -9,6 +9,7 @@ using ..CoreModule: AbstractOptions, Dataset, create_expression, DATA_TYPE, LOSS_TYPE, is_weighted using ..ComplexityModule: compute_complexity using ..DimensionalAnalysisModule: violates_dimensional_constraints +using ..InterfaceDynamicExpressionsModule: expected_array_type function _loss( x::AbstractArray{T}, y::AbstractArray{T}, loss::LT @@ -44,12 +45,16 @@ end function eval_tree_dispatch( tree::AbstractExpression, dataset::Dataset, options::AbstractOptions, idx ) - return eval_tree_array(tree, maybe_getindex(dataset.X, :, idx), options) + return eval_tree_array( + tree, maybe_getindex(dataset.X, :, idx), options + )::Tuple{<:expected_array_type(dataset.X, typeof(tree)),Bool} end function eval_tree_dispatch( tree::AbstractExpressionNode, dataset::Dataset, options::AbstractOptions, idx ) - return eval_tree_array(tree, maybe_getindex(dataset.X, :, idx), options) + return eval_tree_array( + tree, maybe_getindex(dataset.X, :, idx), options + )::Tuple{<:expected_array_type(dataset.X, typeof(tree)),Bool} end # Evaluate the loss of a particular expression on the input dataset. From c809c6cfea2735087682933ecdab9a0be2de751c Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 15 Oct 2024 10:13:34 +0100 Subject: [PATCH 45/58] fix: new type instabilities from `expected_array_type` --- src/InterfaceDynamicExpressions.jl | 31 ++++++++++++++++-------------- src/LossFunctions.jl | 12 ++++++------ src/ParametricExpression.jl | 6 ++++-- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/src/InterfaceDynamicExpressions.jl b/src/InterfaceDynamicExpressions.jl index 0a9629f3e..6c8aa45fd 100644 --- a/src/InterfaceDynamicExpressions.jl +++ b/src/InterfaceDynamicExpressions.jl @@ -55,14 +55,16 @@ function DE.eval_tree_array( options::AbstractOptions; kws..., ) - return DE.eval_tree_array( + A = expected_array_type(X, typeof(tree)) + out, complete = DE.eval_tree_array( tree, X, DE.get_operators(tree, options); turbo=options.turbo, bumper=options.bumper, kws..., - )::Tuple{<:expected_array_type(X, typeof(tree)),Bool} + ) + return out::A, complete::Bool end """Improve type inference by telling Julia the expected array returned.""" @@ -100,11 +102,11 @@ function DE.eval_diff_tree_array( direction::Int, ) # TODO: Add `AbstractExpression` implementation in `Expression.jl` - return DE.eval_diff_tree_array( + A = expected_array_type(X, typeof(tree)) + out, grad, complete = DE.eval_diff_tree_array( DE.get_tree(tree), X, DE.get_operators(tree, options), direction - )::Tuple{ - <:expected_array_type(X, typeof(tree)),<:expected_array_type(X, typeof(tree)),Bool - } + ) + return out::A, grad::A, complete::Bool end """ @@ -134,13 +136,12 @@ function DE.eval_grad_tree_array( options::AbstractOptions; kws..., ) - return DE.eval_grad_tree_array( + A = expected_array_type(X, typeof(tree)) + dA = expected_array_type(X, typeof(tree), Val(:eval_grad_tree_array)) + out, grad, complete = DE.eval_grad_tree_array( tree, X, DE.get_operators(tree, options); kws... - )::Tuple{ - <:expected_array_type(X, typeof(tree)), - <:expected_array_type(X, typeof(tree), Val(:eval_grad_tree_array)), - Bool, - } + ) + return out::A, grad::dA, complete::Bool end """ @@ -154,9 +155,11 @@ function DE.differentiable_eval_tree_array( options::AbstractOptions, ) # TODO: Add `AbstractExpression` implementation in `Expression.jl` - return DE.differentiable_eval_tree_array( + A = expected_array_type(X, typeof(tree)) + out, complete = DE.differentiable_eval_tree_array( DE.get_tree(tree), X, DE.get_operators(tree, options) - )::Tuple{<:expected_array_type(X, typeof(tree)),Bool} + ) + return out::A, complete::Bool end const WILDCARD_UNIT_STRING = "[?]" diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl index ba8abcc9e..2830ecdcd 100644 --- a/src/LossFunctions.jl +++ b/src/LossFunctions.jl @@ -45,16 +45,16 @@ end function eval_tree_dispatch( tree::AbstractExpression, dataset::Dataset, options::AbstractOptions, idx ) - return eval_tree_array( - tree, maybe_getindex(dataset.X, :, idx), options - )::Tuple{<:expected_array_type(dataset.X, typeof(tree)),Bool} + A = expected_array_type(dataset.X, typeof(tree)) + out, complete = eval_tree_array(tree, maybe_getindex(dataset.X, :, idx), options) + return out::A, complete::Bool end function eval_tree_dispatch( tree::AbstractExpressionNode, dataset::Dataset, options::AbstractOptions, idx ) - return eval_tree_array( - tree, maybe_getindex(dataset.X, :, idx), options - )::Tuple{<:expected_array_type(dataset.X, typeof(tree)),Bool} + A = expected_array_type(dataset.X, typeof(tree)) + out, complete = eval_tree_array(tree, maybe_getindex(dataset.X, :, idx), options) + return out::A, complete::Bool end # Evaluate the loss of a particular expression on the input dataset. diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index b31f60e70..4294acdc7 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -68,7 +68,8 @@ function DE.eval_tree_array( options::AbstractOptions; kws..., ) - return DE.eval_tree_array( + A = expected_array_type(X, typeof(tree)) + out, complete = DE.eval_tree_array( tree, X, classes, @@ -76,7 +77,8 @@ function DE.eval_tree_array( turbo=options.turbo, bumper=options.bumper, kws..., - )::Tuple{<:expected_array_type(X, typeof(tree)),Bool} + ) + return out::A, complete::Bool end function LF.eval_tree_dispatch( tree::ParametricExpression, dataset::Dataset, options::AbstractOptions, idx From 8b0860da40daef630fbd56e26cbfa404b8ce8d05 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 15 Oct 2024 10:15:40 +0100 Subject: [PATCH 46/58] style: formatting --- src/TemplateExpression.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 2715f35e1..a3b241747 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -251,9 +251,11 @@ end """ We need full specialization for constrained expressions, as they rely on subexpressions being combined. """ -CM.operator_specialization( +function CM.operator_specialization( ::Type{O}, ::Type{<:TemplateExpression} -) where {O<:OperatorEnum} = O +) where {O<:OperatorEnum} + return O +end """ We pick a random subexpression to mutate, From 27746e4383606613b6a9b67d1bfd6e4ec83c1cec Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 15 Oct 2024 10:20:19 +0100 Subject: [PATCH 47/58] fix: interaction with DispatchDoctor --- src/TemplateExpression.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index a3b241747..aeaf5aa09 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -242,7 +242,7 @@ function DA.violates_dimensional_constraints( return false end function MM.condition_mutation_weights!( - _::AbstractMutationWeights, _::P, _::AbstractOptions, _::Int + weights::AbstractMutationWeights, member::P, options::AbstractOptions, curmaxsize::Int ) where {T,L,N<:TemplateExpression,P<:PopMember{T,L,N}} # HACK TODO return nothing From 6c69700615bdc826719eaee08f880b86d8af6011 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 16 Oct 2024 20:46:32 +0100 Subject: [PATCH 48/58] test: fix JET identified issue --- src/LossFunctions.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl index 2830ecdcd..01dcca86b 100644 --- a/src/LossFunctions.jl +++ b/src/LossFunctions.jl @@ -73,12 +73,16 @@ function _eval_loss( loss_val = if is_weighted(dataset) _weighted_loss( prediction, - maybe_getindex(dataset.y, idx), + maybe_getindex(dataset.y::AbstractArray, idx), maybe_getindex(dataset.weights, idx), options.elementwise_loss, ) else - _loss(prediction, maybe_getindex(dataset.y, idx), options.elementwise_loss) + _loss( + prediction, + maybe_getindex(dataset.y::AbstractArray, idx), + options.elementwise_loss, + ) end if regularization From ee6096efdfe6a940b855f7410bf924cacdc69310 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 16 Oct 2024 20:56:14 +0100 Subject: [PATCH 49/58] test: add missing template expression test --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 377dc52fc..fcc2c5b08 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -131,6 +131,7 @@ end ENV["SYMBOLIC_REGRESSION_IS_TESTING"] = "true" include("../examples/parameterized_function.jl") end +include("test_template_expression.jl") @testitem "Testing whether the recorder works." tags = [:part3] begin include("test_recorder.jl") From 9c30d141592bdcaf00db3dc740cf92987500033a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 16 Oct 2024 21:08:05 +0100 Subject: [PATCH 50/58] fix: type instability in ParametricExpression --- src/ParametricExpression.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 4294acdc7..f98a1de08 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -83,12 +83,14 @@ end function LF.eval_tree_dispatch( tree::ParametricExpression, dataset::Dataset, options::AbstractOptions, idx ) - return eval_tree_array( + A = expected_array_type(dataset.X, typeof(tree)) + out, complete = DE.eval_tree_array( tree, LF.maybe_getindex(dataset.X, :, idx), LF.maybe_getindex(dataset.extra.classes, idx), options.operators, ) + return out::A, complete::Bool end function MM.condition_mutate_constant!( From 5194c014789045c347d6e5d6234c82885bdba4f6 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 16 Oct 2024 22:17:47 +0100 Subject: [PATCH 51/58] test: more test coverage --- src/TemplateExpression.jl | 7 ++++++- test/test_template_expression.jl | 9 +++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index aeaf5aa09..dfc8b59ca 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -146,7 +146,12 @@ function TemplateExpression( return TemplateExpression(trees, Metadata(metadata)) end -DE.constructorof(::Type{<:TemplateExpression}) = TemplateExpression +function DE.constructorof(::Type{<:TemplateExpression}) + return error( + "TemplateExpression requires additional information to constructor correctly. " * + "Please use `create_expression` instead.", + ) +end @implements( ExpressionInterface{all_ei_methods_except(())}, TemplateExpression, [Arguments()] diff --git a/test/test_template_expression.jl b/test/test_template_expression.jl index 2187ea334..8f44d5210 100644 --- a/test/test_template_expression.jl +++ b/test/test_template_expression.jl @@ -159,3 +159,12 @@ end @testitem "Integration Test with fit! and Performance Check" tags = [:part3] begin include("../examples/template_expression.jl") end +@testitem "Unimplemented functions" tags = [:part3] begin + using SymbolicRegression + using DynamicExpressions: constructorof + + @test_throws ErrorException constructorof(TemplateExpression) + @test_throws "TemplateExpression requires additional information to constructor correctly." constructorof( + TemplateExpression + ) +end From b40ed0fa9c63718e84136b079477ba420b02e834 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 16 Oct 2024 22:23:20 +0100 Subject: [PATCH 52/58] ci: update formatter --- .github/workflows/check-format.yml | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/.github/workflows/check-format.yml b/.github/workflows/check-format.yml index de643f2ef..df5b98af3 100644 --- a/.github/workflows/check-format.yml +++ b/.github/workflows/check-format.yml @@ -24,12 +24,8 @@ jobs: - name: "Cache dependencies" uses: julia-actions/cache@v2 - name: Install JuliaFormatter and format - # This will use the latest version by default but you can set the version like so: - # - # julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="0.13.0"))' run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' - julia -e 'using JuliaFormatter; format(".", verbose=true)' + julia --startup-file=no -e 'using Pkg; pkg"activate --temp"; pkg"add JuliaFormatter"; using JuliaFormatter; format(".", verbose=true)' - name: "Format check" run: | julia -e ' From 52caafbe8ccd3340753ccc04161c5e96d04aeb7a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 16 Oct 2024 22:30:44 +0100 Subject: [PATCH 53/58] ci: update formatter --- .github/workflows/check-format.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/check-format.yml b/.github/workflows/check-format.yml index df5b98af3..6bc6bdd51 100644 --- a/.github/workflows/check-format.yml +++ b/.github/workflows/check-format.yml @@ -25,7 +25,7 @@ jobs: uses: julia-actions/cache@v2 - name: Install JuliaFormatter and format run: | - julia --startup-file=no -e 'using Pkg; pkg"activate --temp"; pkg"add JuliaFormatter"; using JuliaFormatter; format(".", verbose=true)' + julia --startup-file=no -e 'using Pkg; pkg"activate --temp"; pkg"add JuliaFormatter@1.0.61"; using JuliaFormatter; format("."; verbose=true)' - name: "Format check" run: | julia -e ' From 214b95a48c48bdb68f63f4c43ea001782e1d33dc Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Thu, 17 Oct 2024 11:40:24 +0100 Subject: [PATCH 54/58] refactor: more appropriate module name --- src/SymbolicRegression.jl | 2 +- src/TemplateExpression.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 0530dba4a..53afae3af 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -314,7 +314,7 @@ using .SearchUtilsModule: save_to_file, get_cur_maxsize, update_hall_of_fame! -using .ConstrainedExpressionModule: TemplateExpression +using .TemplateExpressionModule: TemplateExpression using .ExpressionBuilderModule: embed_metadata, strip_metadata @stable default_mode = "disable" begin diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index dfc8b59ca..23f095f32 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -1,4 +1,4 @@ -module ConstrainedExpressionModule +module TemplateExpressionModule using Random: AbstractRNG using DispatchDoctor: @unstable From aae35cb11c581e043cd8758be5d508e7c8738cc5 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Thu, 17 Oct 2024 14:36:06 +0100 Subject: [PATCH 55/58] fix: put back constructorof --- src/TemplateExpression.jl | 7 +------ test/test_template_expression.jl | 9 --------- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 23f095f32..e6abc796a 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -146,12 +146,7 @@ function TemplateExpression( return TemplateExpression(trees, Metadata(metadata)) end -function DE.constructorof(::Type{<:TemplateExpression}) - return error( - "TemplateExpression requires additional information to constructor correctly. " * - "Please use `create_expression` instead.", - ) -end +@unstable DE.constructorof(::Type{<:TemplateExpression}) = TemplateExpression @implements( ExpressionInterface{all_ei_methods_except(())}, TemplateExpression, [Arguments()] diff --git a/test/test_template_expression.jl b/test/test_template_expression.jl index 8f44d5210..2187ea334 100644 --- a/test/test_template_expression.jl +++ b/test/test_template_expression.jl @@ -159,12 +159,3 @@ end @testitem "Integration Test with fit! and Performance Check" tags = [:part3] begin include("../examples/template_expression.jl") end -@testitem "Unimplemented functions" tags = [:part3] begin - using SymbolicRegression - using DynamicExpressions: constructorof - - @test_throws ErrorException constructorof(TemplateExpression) - @test_throws "TemplateExpression requires additional information to constructor correctly." constructorof( - TemplateExpression - ) -end From 6890367ef2eaba10788e8b48d493ac17796d691b Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Thu, 17 Oct 2024 23:41:30 +0100 Subject: [PATCH 56/58] refactor: remove old redundant copy --- src/SingleIteration.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/SingleIteration.jl b/src/SingleIteration.jl index 1d305cbb2..90edb8ee7 100644 --- a/src/SingleIteration.jl +++ b/src/SingleIteration.jl @@ -106,12 +106,6 @@ function optimize_and_simplify_population( # to manually allocate a new task with a larger stack for Enzyme. should_thread = !(options.deterministic) && !(isa(options.autodiff_backend, AutoEnzyme)) - # TODO: This `copy` is necessary to avoid an undefined reference - # error when simplifying, and only for `TemplateExpression`. - # But, why is it needed? Could it be that - # some of the expressions across the population share subtrees? - pop.members .= map(copy, pop.members) - @threads_if should_thread for j in 1:(pop.n) if options.should_simplify tree = pop.members[j].tree From d0a642054af351027c6e938d4457f41640b377b3 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Thu, 17 Oct 2024 23:43:47 +0100 Subject: [PATCH 57/58] refactor: copy at the source, not the caller --- src/Population.jl | 4 +--- src/RegularizedEvolution.jl | 6 +++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/Population.jl b/src/Population.jl index 54aabd369..d475da168 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -112,9 +112,7 @@ function best_of_sample( options::AbstractOptions, ) where {T,L,N} sample = sample_pop(pop, options) - return _best_of_sample( - sample.members, running_search_statistics, options - )::PopMember{T,L,N} + return copy(_best_of_sample(sample.members, running_search_statistics, options)) end function _best_of_sample( members::Vector{P}, diff --git a/src/RegularizedEvolution.jl b/src/RegularizedEvolution.jl index 7b140de71..06358a328 100644 --- a/src/RegularizedEvolution.jl +++ b/src/RegularizedEvolution.jl @@ -31,7 +31,7 @@ function reg_evol_cycle( for i in 1:n_evol_cycles if rand() > options.crossover_probability - allstar = copy(best_of_sample(pop, running_search_statistics, options)) + allstar = best_of_sample(pop, running_search_statistics, options) mutation_recorder = RecordType() baby, mutation_accepted, tmp_num_evals = next_generation( dataset, @@ -84,8 +84,8 @@ function reg_evol_cycle( pop.members[oldest] = baby else # Crossover - allstar1 = copy(best_of_sample(pop, running_search_statistics, options)) - allstar2 = copy(best_of_sample(pop, running_search_statistics, options)) + allstar1 = best_of_sample(pop, running_search_statistics, options) + allstar2 = best_of_sample(pop, running_search_statistics, options) baby1, baby2, crossover_accepted, tmp_num_evals = crossover_generation( allstar1, allstar2, dataset, curmaxsize, options From a9e5332c7a335eac9a8149a2119c62045f78691d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 18 Oct 2024 00:32:01 +0100 Subject: [PATCH 58/58] feat: weight TemplateExpression sampling by num nodes --- src/TemplateExpression.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index e6abc796a..d88c07dcc 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -18,7 +18,8 @@ using DynamicExpressions: get_variable_names, get_tree, node_type, - eval_tree_array + eval_tree_array, + count_nodes using DynamicExpressions.InterfacesModule: ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments @@ -264,8 +265,15 @@ and also return the symbol we mutated on so that we can put it back together lat function MF.get_contents_for_mutation(ex::TemplateExpression, rng::AbstractRNG) raw_contents = get_contents(ex) function_keys = keys(raw_contents) - key_to_mutate = rand(rng, function_keys) + # Sample weighted by number of nodes in each subexpression + num_nodes = map(count_nodes, values(raw_contents)) + weights = map(Base.Fix2(/, sum(num_nodes)), num_nodes) + cumsum_weights = cumsum(weights) + rand_val = rand(rng) + idx = findfirst(Base.Fix2(>=, rand_val), cumsum_weights)::Int + + key_to_mutate = function_keys[idx] return raw_contents[key_to_mutate], key_to_mutate end