Create TemplateExpression for providing a pre-defined functional structure and constraints #355

Merged: 60 commits, Oct 18, 2024

Commits (60, all by MilesCranmer)
1494bde  refactor: move ParametricExpression overloads to separate file (Oct 13, 2024)
6c78145  fix: some imports in ExpressionBuilder (Oct 13, 2024)
87d37aa  feat: initial implementation of ConstrainedExpression (Oct 13, 2024)
2d9d634  feat: variable constraint checking (Oct 13, 2024)
4ce678d  feat: allow custom evaluation for ConstrainedExpression (Oct 13, 2024)
7223da3  fix: some bugs with ConstrainedExpression (Oct 13, 2024)
0300b05  fix: call to `sort_params` (Oct 13, 2024)
24edd8c  refactor: style adjustment (Oct 13, 2024)
3225e6f  fix: ensure `operators` always available (Oct 13, 2024)
2e0867b  feat: give reduced complexity for constrained expressions (Oct 13, 2024)
958890b  docs: missing comment (Oct 13, 2024)
bb86777  fix: aliasing issue in `simplify` with extra copy (Oct 13, 2024)
c81360b  refactor: fix bad nospecialize (Oct 14, 2024)
a3d28ed  fix: prevent aliasing during crossover (Oct 14, 2024)
52f738a  fix: guard more aliasing issues (Oct 14, 2024)
ae69e10  fix: missing import (Oct 14, 2024)
dc276f9  refactor: rename to `BlueprintExpression` (Oct 14, 2024)
f23b188  Merge branch 'master' into structured-expressions (Oct 14, 2024)
f455f7a  fix: missing `@unstable` (Oct 14, 2024)
bf712c6  docs: add extended docstring for BlueprintExpression and example (Oct 14, 2024)
2123ef2  fix: unbound type parameter (Oct 14, 2024)
a2afe7c  test: fix parametric expression test (Oct 14, 2024)
bd20e4d  refactor: rename to `TemplateExpression` (Oct 14, 2024)
a551e83  docs: change example name (Oct 14, 2024)
0514bca  feat: permit complex structured output (Oct 14, 2024)
acd7762  fix: weak dispatch in TemplateExpression (Oct 14, 2024)
686b844  fix: miscalculations of `maxsize` (Oct 14, 2024)
e71ec07  feat: export `with_contents` and `with_metadata` (Oct 14, 2024)
108d1c1  feat: callable TemplateExpression (Oct 14, 2024)
bd6d8f5  fix: new `maxsize` miscalc (Oct 14, 2024)
8433d66  fix: `check_constraints` for TemplateExpression (Oct 14, 2024)
1bb7f58  Merge branch 'master' into structured-expressions (Oct 15, 2024)
7ed6e93  test: add missing tags (Oct 15, 2024)
5a0f77c  test: fix pretty print test with fixed `maxsize` (Oct 15, 2024)
8c91200  fix: make `weighted` into callable `is_weighted` (Oct 15, 2024)
46a59e9  test: update dataset tests (Oct 15, 2024)
9ed392c  feat: generalize MLJ interface (Oct 15, 2024)
4e52ca4  fix: missing calls to `.weighted` (Oct 15, 2024)
97ea32b  fix: update type inference utility for TemplateExpression (Oct 15, 2024)
26ca30a  test: greatly improve TemplateExpression test (Oct 15, 2024)
4f1f852  docs: update `TemplateExpression` name (Oct 15, 2024)
204e5ef  docs: update `TemplateExpression` (Oct 15, 2024)
0b21586  test: tweak testitem name (Oct 15, 2024)
4ec4df6  test: move integration test to example folder (Oct 15, 2024)
e5efb10  docs: update comment (Oct 15, 2024)
1c809d1  fix: type instability in `eval_tree_dispatch` (Oct 15, 2024)
c809c6c  fix: new type instabilities from `expected_array_type` (Oct 15, 2024)
8b0860d  style: formatting (Oct 15, 2024)
27746e4  fix: interaction with DispatchDoctor (Oct 15, 2024)
6c69700  test: fix JET identified issue (Oct 16, 2024)
ee6096e  test: add missing template expression test (Oct 16, 2024)
9c30d14  fix: type instability in ParametricExpression (Oct 16, 2024)
5194c01  test: more test coverage (Oct 16, 2024)
b40ed0f  ci: update formatter (Oct 16, 2024)
52caafb  ci: update formatter (Oct 16, 2024)
214b95a  refactor: more appropriate module name (Oct 17, 2024)
aae35cb  fix: put back constructorof (Oct 17, 2024)
6890367  refactor: remove old redundant copy (Oct 17, 2024)
d0a6420  refactor: copy at the source, not the caller (Oct 17, 2024)
a9e5332  feat: weight TemplateExpression sampling by num nodes (Oct 17, 2024)
Files changed
6 changes: 1 addition & 5 deletions .github/workflows/check-format.yml
@@ -24,12 +24,8 @@ jobs:
- name: "Cache dependencies"
uses: julia-actions/cache@v2
- name: Install JuliaFormatter and format
# This will use the latest version by default but you can set the version like so:
#
# julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="0.13.0"))'
run: |
julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))'
julia -e 'using JuliaFormatter; format(".", verbose=true)'
julia --startup-file=no -e 'using Pkg; pkg"activate --temp"; pkg"add [email protected]"; using JuliaFormatter; format("."; verbose=true)'
- name: "Format check"
run: |
julia -e '
5 changes: 5 additions & 0 deletions docs/src/customization.md
@@ -49,6 +49,11 @@ If needed, you may need to overload `SymbolicRegression.ExpressionBuilder.extra_
case your expression needs additional parameters. See the method for `ParametricExpression`
as an example.

You can look at the files `src/ParametricExpression.jl` and `src/TemplateExpression.jl`
for more examples of custom expression types, though note that `ParametricExpression` itself
is defined in DynamicExpressions.jl, while that file just overloads some methods for
SymbolicRegression.jl.

## Other Customizations

Other internal abstract types include the following:
10 changes: 10 additions & 0 deletions docs/src/types.md
@@ -60,6 +60,16 @@ ParametricNode

These types allow you to define expressions with parameters that can be tuned to fit the data better. You can specify the maximum number of parameters using the `expression_options` argument in `SRRegressor`.

## Template Expressions

Template expressions are expressions with a user-specified, predefined structure.
This also lets you fit vector-valued expressions, since the custom evaluation structure can simply
return a vector of tuples.

```@docs
TemplateExpression
```

## Population

Groups of equations are given as a population, which is
62 changes: 62 additions & 0 deletions examples/template_expression.jl
@@ -0,0 +1,62 @@
using SymbolicRegression
using Random: rand
using MLJBase: machine, fit!, report
using Test: @test

options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos))
operators = options.operators
variable_names = (i -> "x$i").(1:3)
x1, x2, x3 = (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3)

variable_mapping = (; f=[1, 2], g1=[3], g2=[3])

function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractString}}})
return "( $(nt.f) + $(nt.g1), $(nt.f) + $(nt.g2) )"
end
function my_structure(nt::NamedTuple{<:Any,<:Tuple{Vararg{<:AbstractVector}}})
return map(i -> (nt.f[i] + nt.g1[i], nt.f[i] + nt.g2[i]), eachindex(nt.f))
end

st_expr = TemplateExpression(
(; f=x1, g1=x3, g2=x3);
structure=my_structure,
operators,
variable_names,
variable_mapping,
)

X = rand(100, 3) .* 10

# Our dataset is a vector of 2-tuples
y = [(sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 1]) + X[i, 3]) for i in eachindex(axes(X, 1))]

model = SRRegressor(;
binary_operators=(+, *),
unary_operators=(sin,),
maxsize=15,
expression_type=TemplateExpression,
expression_options=(; structure=my_structure, variable_mapping),
# The elementwise loss needs to operate directly on each element of `y`:
elementwise_loss=((x1, x2), (y1, y2)) -> (y1 - x1)^2 + (y2 - x2)^2,
early_stop_condition=(loss, complexity) -> loss < 1e-5 && complexity <= 7,
)

mach = machine(model, X, y)
fit!(mach)

# Check the performance of the model
r = report(mach)
idx = r.best_idx
best_loss = r.losses[idx]

@test best_loss < 1e-5

# Check the expression is split up correctly:
best_expr = r.equations[idx]
best_f = get_contents(best_expr).f
best_g1 = get_contents(best_expr).g1
best_g2 = get_contents(best_expr).g2

@test best_f(X') ≈ (@. sin(X[:, 1]))
@test best_g1(X') ≈ (@. X[:, 3] * X[:, 3])
@test best_g2(X') ≈ (@. X[:, 3])
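
As a small follow-up sketch (not part of the committed example file): the tuple-valued predictions can be rebuilt from the recovered sub-expressions in the same way `my_structure` combines them, and then compared against the training targets. The tolerance below is an arbitrary choice for illustration, not something the example asserts.

```julia
# Recombine the sub-expressions exactly as `my_structure` does: (f + g1, f + g2).
f_vals, g1_vals, g2_vals = best_f(X'), best_g1(X'), best_g2(X')
y_pred = [(f_vals[i] + g1_vals[i], f_vals[i] + g2_vals[i]) for i in eachindex(f_vals)]

# Compare against the training targets (atol chosen arbitrarily).
@test all(i -> all(isapprox.(y_pred[i], y[i]; atol=1e-2)), eachindex(y))
```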
4 changes: 1 addition & 3 deletions src/AdaptiveParsimony.jl
@@ -24,9 +24,7 @@ struct RunningSearchStatistics
end

function RunningSearchStatistics(; options::AbstractOptions, window_size::Int=100000)
maxsize = options.maxsize
actualMaxsize = maxsize + MAX_DEGREE
init_frequencies = ones(Float64, actualMaxsize)
init_frequencies = ones(Float64, options.maxsize)

return RunningSearchStatistics(
window_size, init_frequencies, copy(init_frequencies) / sum(init_frequencies)
2 changes: 1 addition & 1 deletion src/CheckConstraints.jl
@@ -69,7 +69,7 @@ function flag_illegal_nests(tree::AbstractExpressionNode, options::AbstractOptions
return false
end

"""Check if user-passed constraints are violated or not"""
"""Check if user-passed constraints are satisfied. Returns false otherwise."""
function check_constraints(
ex::AbstractExpression,
options::AbstractOptions,
6 changes: 3 additions & 3 deletions src/Configure.jl
@@ -88,7 +88,7 @@ function test_dataset_configuration(
) where {T<:DATA_TYPE}
n = dataset.n
if n != size(dataset.X, 2) ||
(dataset.y !== nothing && n != size(dataset.y::AbstractArray{T}, 1))
(dataset.y !== nothing && n != size(dataset.y::AbstractArray, 1))
throw(
AssertionError(
"Dataset dimensions are invalid. Make sure X is of shape [features, rows], y is of shape [rows] and if there are weights, they are of shape [rows].",
@@ -101,7 +101,7 @@ end
end

if !(typeof(options.elementwise_loss) <: SupervisedLoss) &&
dataset.weighted &&
is_weighted(dataset) &&
!(3 in [m.nargs - 1 for m in methods(options.elementwise_loss)])
throw(
AssertionError(
@@ -132,7 +132,7 @@ function move_functions_to_workers(
continue
end
ops = (options.elementwise_loss,)
example_inputs = if dataset.weighted
example_inputs = if is_weighted(dataset)
(zero(T), zero(T), zero(T))
else
(zero(T), zero(T))
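
For orientation on the `weighted` → `is_weighted` change shown above in src/Configure.jl, here is a minimal sketch. Whether `is_weighted` is reachable as `SymbolicRegression.is_weighted` (rather than only via an internal module) is an assumption here.

```julia
using SymbolicRegression: Dataset

X = rand(3, 100)   # shape [features, rows]
y = rand(100)
w = rand(100)

unweighted = Dataset(X, y)
weighted = Dataset(X, y; weights=w)

# `is_weighted(dataset)` is simply `dataset.weights !== nothing` (see src/Dataset.jl below),
# replacing the removed `dataset.weighted` field.
@assert !SymbolicRegression.is_weighted(unweighted)
@assert SymbolicRegression.is_weighted(weighted)
```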
9 changes: 7 additions & 2 deletions src/Core.jl
@@ -12,9 +12,14 @@ include("Options.jl")

using .ProgramConstantsModule:
MAX_DEGREE, BATCH_DIM, FEATURE_DIM, RecordType, DATA_TYPE, LOSS_TYPE
using .DatasetModule: Dataset
using .DatasetModule: Dataset, is_weighted
using .MutationWeightsModule: AbstractMutationWeights, MutationWeights, sample_mutation
using .OptionsStructModule: AbstractOptions, Options, ComplexityMapping, specialized_options
using .OptionsStructModule:
AbstractOptions,
Options,
ComplexityMapping,
specialized_options,
operator_specialization
using .OperatorsModule:
plus,
sub,
42 changes: 10 additions & 32 deletions src/Dataset.jl
@@ -19,8 +19,7 @@ import ...deprecate_varmap
dataset, if any.
- `n::Int`: The number of samples.
- `nfeatures::Int`: The number of features.
- `weighted::Bool`: Whether the dataset is non-uniformly weighted.
- `weights::Union{AbstractVector{T},Nothing}`: If the dataset is weighted,
- `weights::Union{AbstractVector,Nothing}`: If the dataset is weighted,
these specify the per-sample weight (with shape `(n,)`).
- `extra::NamedTuple`: Extra information to pass to a custom evaluation
function. Since this is an arbitrary named tuple, you could pass
@@ -49,8 +48,8 @@ mutable struct Dataset{
T<:DATA_TYPE,
L<:LOSS_TYPE,
AX<:AbstractMatrix{T},
AY<:Union{AbstractVector{T},Nothing},
AW<:Union{AbstractVector{T},Nothing},
AY<:Union{AbstractVector,Nothing},
AW<:Union{AbstractVector,Nothing},
NT<:NamedTuple,
XU<:Union{AbstractVector{<:Quantity},Nothing},
YU<:Union{Quantity,Nothing},
@@ -62,7 +61,6 @@ mutable struct Dataset{
const index::Int
const n::Int
const nfeatures::Int
const weighted::Bool
const weights::AW
const extra::NT
const avg_y::Union{T,Nothing}
@@ -81,7 +79,7 @@ end
Dataset(X::AbstractMatrix{T},
y::Union{AbstractVector{T},Nothing}=nothing,
loss_type::Type=Nothing;
weights::Union{AbstractVector{T}, Nothing}=nothing,
weights::Union{AbstractVector, Nothing}=nothing,
variable_names::Union{Array{String, 1}, Nothing}=nothing,
y_variable_name::Union{String,Nothing}=nothing,
extra::NamedTuple=NamedTuple(),
@@ -93,10 +91,10 @@ Construct a dataset to pass between internal functions.
"""
function Dataset(
X::AbstractMatrix{T},
y::Union{AbstractVector{T},Nothing}=nothing,
y::Union{AbstractVector,Nothing}=nothing,
loss_type::Type{L}=Nothing;
index::Int=1,
weights::Union{AbstractVector{T},Nothing}=nothing,
weights::Union{AbstractVector,Nothing}=nothing,
variable_names::Union{Array{String,1},Nothing}=nothing,
display_variable_names=variable_names,
y_variable_name::Union{String,Nothing}=nothing,
@@ -133,7 +131,6 @@ function Dataset(

n = size(X, BATCH_DIM)
nfeatures = size(X, FEATURE_DIM)
weighted = weights !== nothing
variable_names = if variable_names === nothing
["x$(i)" for i in 1:nfeatures]
else
@@ -150,10 +147,10 @@ function Dataset(
else
y_variable_name
end
avg_y = if y === nothing
avg_y = if y === nothing || !(eltype(y) isa Number)
nothing
else
if weighted
if weights !== nothing
sum(y .* weights) / sum(weights)
else
sum(y) / n
@@ -207,7 +204,6 @@ function Dataset(
index,
n,
nfeatures,
weighted,
weights,
extra,
avg_y,
@@ -222,26 +218,8 @@ function Dataset(
y_sym_units,
)
end
function Dataset(
X::AbstractMatrix,
y::Union{<:AbstractVector,Nothing}=nothing;
weights::Union{<:AbstractVector,Nothing}=nothing,
kws...,
)
T = promote_type(
eltype(X),
(y === nothing) ? eltype(X) : eltype(y),
(weights === nothing) ? eltype(X) : eltype(weights),
)
X = Base.Fix1(convert, T).(X)
if y !== nothing
y = Base.Fix1(convert, T).(y)
end
if weights !== nothing
weights = Base.Fix1(convert, T).(weights)
end
return Dataset(X, y; weights=weights, kws...)
end

is_weighted(dataset::Dataset) = dataset.weights !== nothing

function error_on_mismatched_size(_, ::Nothing)
return nothing
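
To close, a rough sketch of what the relaxed `Dataset` field types shown above permit. Direct construction with a non-numeric target vector is shown for illustration only and has not been verified here; the template-expression example earlier goes through the MLJ interface instead.

```julia
using SymbolicRegression: Dataset

X = rand(3, 100)                                  # [features, rows]
y = [(sin(X[1, i]), X[3, i]^2) for i in 1:100]    # tuple-valued targets, as in the template example

# `y` is no longer constrained to `AbstractVector{T}` matching `eltype(X)`;
# for a non-numeric element type, `avg_y` is simply left as `nothing`.
dataset = Dataset(X, y)
@assert dataset.avg_y === nothing
```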