Merge pull request #355 from MilesCranmer/structured-expressions
Create `TemplateExpression` for providing a pre-defined functional structure and constraints
MilesCranmer authored Oct 18, 2024
2 parents 3892a66 + a9e5332 commit a38f901
Showing 30 changed files with 1,077 additions and 353 deletions.
6 changes: 1 addition & 5 deletions .github/workflows/check-format.yml
@@ -24,12 +24,8 @@ jobs:
- name: "Cache dependencies"
uses: julia-actions/cache@v2
- name: Install JuliaFormatter and format
# This will use the latest version by default but you can set the version like so:
#
# julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="0.13.0"))'
run: |
julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))'
julia -e 'using JuliaFormatter; format(".", verbose=true)'
julia --startup-file=no -e 'using Pkg; pkg"activate --temp"; pkg"add [email protected]"; using JuliaFormatter; format("."; verbose=true)'
- name: "Format check"
run: |
julia -e '
5 changes: 5 additions & 0 deletions docs/src/customization.md
@@ -49,6 +49,11 @@ If needed, you may need to overload `SymbolicRegression.ExpressionBuilder.extra_
case your expression needs additional parameters. See the method for `ParametricExpression`
as an example.

You can look at the files `src/ParametricExpression.jl` and `src/TemplateExpression.jl`
for more examples of custom expression types. Note that `ParametricExpression` itself
is defined in DynamicExpressions.jl; `src/ParametricExpression.jl` only overloads some of
its methods for SymbolicRegression.jl.

## Other Customizations

Other internal abstract types include the following:
10 changes: 10 additions & 0 deletions docs/src/types.md
@@ -60,6 +60,16 @@ ParametricNode

These types allow you to define expressions with parameters that can be tuned to fit the data better. You can specify the maximum number of parameters using the `expression_options` argument in `SRRegressor`.
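
For example, a minimal sketch of requesting this (the operator choices here are arbitrary, and `max_parameters=2` is only an illustration):

```julia
using SymbolicRegression: SRRegressor, ParametricExpression

model = SRRegressor(;
    binary_operators=(+, *, -),
    unary_operators=(cos,),
    expression_type=ParametricExpression,
    # Allow each expression to carry at most two tunable parameters:
    expression_options=(; max_parameters=2),
)
```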

## Template Expressions

Template expressions are expressions with a predefined, user-specified structure.
Because the structure function controls evaluation, it can return a vector of tuples,
which also lets you fit vector-valued expressions.

```@docs
TemplateExpression
```
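
As a minimal construction sketch, adapted from `examples/template_expression.jl` added in this pull request (the sub-expression names `f` and `g` and the two-output structure are purely illustrative):

```julia
using SymbolicRegression

options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos))
operators = options.operators
variable_names = ["x1", "x2", "x3"]
x1, x3 = (Expression(Node(Float64; feature=i); operators, variable_names) for i in (1, 3))

# Which features each sub-expression may use:
variable_mapping = (; f=[1, 2], g=[3])

# One method renders the combined expression as a string (used for printing);
# the other evaluates it, here returning a vector of 2-tuples:
function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractString}}})
    return "( $(nt.f) + $(nt.g), $(nt.f) - $(nt.g) )"
end
function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}})
    return map(i -> (nt.f[i] + nt.g[i], nt.f[i] - nt.g[i]), eachindex(nt.f))
end

st_expr = TemplateExpression(
    (; f=x1, g=x3);
    structure=my_structure, operators, variable_names, variable_mapping,
)
```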

## Population

Groups of equations are given as a population, which is
62 changes: 62 additions & 0 deletions examples/template_expression.jl
@@ -0,0 +1,62 @@
using SymbolicRegression
using Random: rand
using MLJBase: machine, fit!, report
using Test: @test

options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos))
operators = options.operators
variable_names = (i -> "x$i").(1:3)
x1, x2, x3 = (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3)

# Which features each sub-expression may use:
variable_mapping = (; f=[1, 2], g1=[3], g2=[3])

# The structure function combines the sub-expressions `f`, `g1`, and `g2` into the final
# output. The `AbstractString` method builds the printed form of the combined expression,
# while the `AbstractVector` method evaluates it, returning a vector of 2-tuples:
function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractString}}})
    return "( $(nt.f) + $(nt.g1), $(nt.f) + $(nt.g2) )"
end
function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}})
    return map(i -> (nt.f[i] + nt.g1[i], nt.f[i] + nt.g2[i]), eachindex(nt.f))
end

st_expr = TemplateExpression(
    (; f=x1, g1=x3, g2=x3);
    structure=my_structure,
    operators,
    variable_names,
    variable_mapping,
)

X = rand(100, 3) .* 10

# Our dataset is a vector of 2-tuples
y = [(sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 1]) + X[i, 3]) for i in eachindex(axes(X, 1))]

model = SRRegressor(;
    binary_operators=(+, *),
    unary_operators=(sin,),
    maxsize=15,
    expression_type=TemplateExpression,
    expression_options=(; structure=my_structure, variable_mapping),
    # The elementwise loss needs to operate directly on each row of `y`:
    elementwise_loss=((x1, x2), (y1, y2)) -> (y1 - x1)^2 + (y2 - x2)^2,
    early_stop_condition=(loss, complexity) -> loss < 1e-5 && complexity <= 7,
)

mach = machine(model, X, y)
fit!(mach)

# Check the performance of the model
r = report(mach)
idx = r.best_idx
best_loss = r.losses[idx]

@test best_loss < 1e-5

# Check the expression is split up correctly:
best_expr = r.equations[idx]
best_f = get_contents(best_expr).f
best_g1 = get_contents(best_expr).g1
best_g2 = get_contents(best_expr).g2

@test best_f(X') ≈ (@. sin(X[:, 1]))
@test best_g1(X') ≈ (@. X[:, 3] * X[:, 3])
@test best_g2(X') ≈ (@. X[:, 3])
4 changes: 1 addition & 3 deletions src/AdaptiveParsimony.jl
@@ -24,9 +24,7 @@ struct RunningSearchStatistics
end

function RunningSearchStatistics(; options::AbstractOptions, window_size::Int=100000)
maxsize = options.maxsize
actualMaxsize = maxsize + MAX_DEGREE
init_frequencies = ones(Float64, actualMaxsize)
init_frequencies = ones(Float64, options.maxsize)

return RunningSearchStatistics(
window_size, init_frequencies, copy(init_frequencies) / sum(init_frequencies)
2 changes: 1 addition & 1 deletion src/CheckConstraints.jl
@@ -69,7 +69,7 @@ function flag_illegal_nests(tree::AbstractExpressionNode, options::AbstractOptio
return false
end

"""Check if user-passed constraints are violated or not"""
"""Check if user-passed constraints are satisfied. Returns false otherwise."""
function check_constraints(
ex::AbstractExpression,
options::AbstractOptions,
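
For context, `check_constraints` enforces the user-facing size and nesting constraints set in `Options`. A small sketch of specifying such constraints (the particular operators and limits are arbitrary):

```julia
using SymbolicRegression

options = Options(;
    binary_operators=(+, *, ^),
    unary_operators=(sin,),
    # Limit the complexity of each argument of `^`:
    # unlimited on the left, at most complexity 1 on the right.
    constraints=[(^) => (-1, 1)],
    # Forbid `sin` from appearing nested inside another `sin`:
    nested_constraints=[sin => [sin => 0]],
)
```
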
6 changes: 3 additions & 3 deletions src/Configure.jl
@@ -88,7 +88,7 @@ function test_dataset_configuration(
) where {T<:DATA_TYPE}
n = dataset.n
if n != size(dataset.X, 2) ||
(dataset.y !== nothing && n != size(dataset.y::AbstractArray{T}, 1))
(dataset.y !== nothing && n != size(dataset.y::AbstractArray, 1))
throw(
AssertionError(
"Dataset dimensions are invalid. Make sure X is of shape [features, rows], y is of shape [rows] and if there are weights, they are of shape [rows].",
@@ -101,7 +101,7 @@ end
end

if !(typeof(options.elementwise_loss) <: SupervisedLoss) &&
dataset.weighted &&
is_weighted(dataset) &&
!(3 in [m.nargs - 1 for m in methods(options.elementwise_loss)])
throw(
AssertionError(
@@ -132,7 +132,7 @@ function move_functions_to_workers(
continue
end
ops = (options.elementwise_loss,)
example_inputs = if dataset.weighted
example_inputs = if is_weighted(dataset)
(zero(T), zero(T), zero(T))
else
(zero(T), zero(T))
9 changes: 7 additions & 2 deletions src/Core.jl
@@ -12,9 +12,14 @@ include("Options.jl")

using .ProgramConstantsModule:
MAX_DEGREE, BATCH_DIM, FEATURE_DIM, RecordType, DATA_TYPE, LOSS_TYPE
using .DatasetModule: Dataset
using .DatasetModule: Dataset, is_weighted
using .MutationWeightsModule: AbstractMutationWeights, MutationWeights, sample_mutation
using .OptionsStructModule: AbstractOptions, Options, ComplexityMapping, specialized_options
using .OptionsStructModule:
AbstractOptions,
Options,
ComplexityMapping,
specialized_options,
operator_specialization
using .OperatorsModule:
plus,
sub,
42 changes: 10 additions & 32 deletions src/Dataset.jl
@@ -19,8 +19,7 @@ import ...deprecate_varmap
dataset, if any.
- `n::Int`: The number of samples.
- `nfeatures::Int`: The number of features.
- `weighted::Bool`: Whether the dataset is non-uniformly weighted.
- `weights::Union{AbstractVector{T},Nothing}`: If the dataset is weighted,
- `weights::Union{AbstractVector,Nothing}`: If the dataset is weighted,
these specify the per-sample weight (with shape `(n,)`).
- `extra::NamedTuple`: Extra information to pass to a custom evaluation
function. Since this is an arbitrary named tuple, you could pass
@@ -49,8 +48,8 @@ mutable struct Dataset{
T<:DATA_TYPE,
L<:LOSS_TYPE,
AX<:AbstractMatrix{T},
AY<:Union{AbstractVector{T},Nothing},
AW<:Union{AbstractVector{T},Nothing},
AY<:Union{AbstractVector,Nothing},
AW<:Union{AbstractVector,Nothing},
NT<:NamedTuple,
XU<:Union{AbstractVector{<:Quantity},Nothing},
YU<:Union{Quantity,Nothing},
@@ -62,7 +61,6 @@ mutable struct Dataset{
const index::Int
const n::Int
const nfeatures::Int
const weighted::Bool
const weights::AW
const extra::NT
const avg_y::Union{T,Nothing}
@@ -81,7 +79,7 @@ end
Dataset(X::AbstractMatrix{T},
y::Union{AbstractVector{T},Nothing}=nothing,
loss_type::Type=Nothing;
weights::Union{AbstractVector{T}, Nothing}=nothing,
weights::Union{AbstractVector, Nothing}=nothing,
variable_names::Union{Array{String, 1}, Nothing}=nothing,
y_variable_name::Union{String,Nothing}=nothing,
extra::NamedTuple=NamedTuple(),
@@ -93,10 +91,10 @@ Construct a dataset to pass between internal functions.
"""
function Dataset(
X::AbstractMatrix{T},
y::Union{AbstractVector{T},Nothing}=nothing,
y::Union{AbstractVector,Nothing}=nothing,
loss_type::Type{L}=Nothing;
index::Int=1,
weights::Union{AbstractVector{T},Nothing}=nothing,
weights::Union{AbstractVector,Nothing}=nothing,
variable_names::Union{Array{String,1},Nothing}=nothing,
display_variable_names=variable_names,
y_variable_name::Union{String,Nothing}=nothing,
@@ -133,7 +131,6 @@ function Dataset(

n = size(X, BATCH_DIM)
nfeatures = size(X, FEATURE_DIM)
weighted = weights !== nothing
variable_names = if variable_names === nothing
["x$(i)" for i in 1:nfeatures]
else
@@ -150,10 +147,10 @@ function Dataset(
else
y_variable_name
end
avg_y = if y === nothing
avg_y = if y === nothing || !(eltype(y) isa Number)
nothing
else
if weighted
if weights !== nothing
sum(y .* weights) / sum(weights)
else
sum(y) / n
@@ -207,7 +204,6 @@ function Dataset(
index,
n,
nfeatures,
weighted,
weights,
extra,
avg_y,
@@ -222,26 +218,8 @@ function Dataset(
y_sym_units,
)
end
function Dataset(
X::AbstractMatrix,
y::Union{<:AbstractVector,Nothing}=nothing;
weights::Union{<:AbstractVector,Nothing}=nothing,
kws...,
)
T = promote_type(
eltype(X),
(y === nothing) ? eltype(X) : eltype(y),
(weights === nothing) ? eltype(X) : eltype(weights),
)
X = Base.Fix1(convert, T).(X)
if y !== nothing
y = Base.Fix1(convert, T).(y)
end
if weights !== nothing
weights = Base.Fix1(convert, T).(weights)
end
return Dataset(X, y; weights=weights, kws...)
end

is_weighted(dataset::Dataset) = dataset.weights !== nothing

function error_on_mismatched_size(_, ::Nothing)
return nothing
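
For reference, a quick sketch of how the new `is_weighted` helper replaces the old `weighted` field (array contents are arbitrary, and this assumes `is_weighted` is importable from the top-level package; otherwise qualify it with the internal `Dataset` module):

```julia
using SymbolicRegression: Dataset
using SymbolicRegression: is_weighted  # assumed to be reachable from the top level

X = randn(Float64, 3, 10)  # shape is [features, rows]
y = randn(Float64, 10)
w = rand(Float64, 10)

plain = Dataset(X, y)
weighted = Dataset(X, y; weights=w)

is_weighted(plain)     # false
is_weighted(weighted)  # true, since `weights !== nothing`
```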