BREAKING: Change expression types to DynamicExpressions.Expression (from DynamicExpressions.Node) #326

Merged
209 commits merged into master from parametric-expressions on Oct 6, 2024
Changes shown from 84 commits
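
For orientation, the breaking change in this PR is that search results now carry DynamicExpressions.Expression objects (which bundle the raw Node tree together with its operators and variable names) instead of bare Node trees. Below is a minimal sketch of what that means for downstream code; it assumes the post-PR behavior described by the title and the diffs further down, and the container usage (calculate_pareto_frontier, member.tree) follows the existing public API rather than anything guaranteed by this page:

using SymbolicRegression
using DynamicExpressions: Node, get_tree

X = randn(Float32, 2, 100)
y = 2 .* cos.(X[2, :]) .+ X[1, :] .^ 2
options = Options(; binary_operators=[+, *], unary_operators=[cos])

hall_of_fame = equation_search(X, y; options=options, niterations=5)
dominating = calculate_pareto_frontier(hall_of_fame)
member = dominating[end]   # most complex Pareto-optimal member

# Previously `member.tree` was a bare `Node{Float32}`; after this PR it is
# expected to be an expression type that wraps the node:
ex = member.tree
tree = get_tree(ex)              # recover the underlying Node when needed
@assert tree isa Node
compute_complexity(ex, options)  # new methods accept expressions directly
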
Commits (209)
83bb8be
chore: update gitignore
MilesCranmer May 21, 2024
c38f52b
remove precompilation
MilesCranmer May 21, 2024
3bdb18f
wip: get many parts of expression interface working
MilesCranmer May 21, 2024
76d1bf0
get more parts working for parametric expressions
MilesCranmer May 22, 2024
8ac4e43
fix constant optimization for expressions
MilesCranmer May 22, 2024
ab74ff2
various fixes for expressions
MilesCranmer May 22, 2024
a43f928
fix other parts of expression interface
MilesCranmer May 22, 2024
88ce42a
specialize operators
MilesCranmer May 22, 2024
bdbd2ab
wip: almost working creation
MilesCranmer May 22, 2024
c43c6fc
more parts working
MilesCranmer May 22, 2024
82b4ccf
fix up other parts of parametric expression
MilesCranmer May 22, 2024
9af9cfb
create mutate constant for parameters
MilesCranmer May 22, 2024
5a1ddf6
fix strings
MilesCranmer May 22, 2024
95865eb
fix strings
MilesCranmer May 22, 2024
03ac9d3
formatting
MilesCranmer May 22, 2024
d38e8ab
fix complexity
MilesCranmer May 22, 2024
1080bcb
chore: bump DynamicExpressions
MilesCranmer Jun 10, 2024
e06d982
feat: undo abstract expression changes to InterfaceDynamicExpressions
MilesCranmer Jun 10, 2024
11e30fb
feat: add index parameter for Dataset
MilesCranmer Jun 10, 2024
c7476d0
feat: set node_type based on expression_type
MilesCranmer Jun 10, 2024
6183296
feat: export node_type
MilesCranmer Jun 10, 2024
35f7c0c
generalize initialization for expressions
MilesCranmer Jun 10, 2024
314da58
add example of parameterized functions
MilesCranmer Jun 11, 2024
de940bb
formatting
MilesCranmer Jun 13, 2024
a51f4b1
fix various issues with expressions
MilesCranmer Jun 13, 2024
cfd82ff
user expressions on regular trees
MilesCranmer Jun 13, 2024
2795fef
fix various issues with expressions in tests
MilesCranmer Jun 13, 2024
c57d07a
turn back on precompilation
MilesCranmer Jun 13, 2024
2e3ab8f
fix conversion steps
MilesCranmer Jun 13, 2024
6739a99
fix bug introduced by `get_tree_from_member`
MilesCranmer Jun 13, 2024
ee76cce
fix symbolic import
MilesCranmer Jun 13, 2024
8fb1103
missing union
MilesCranmer Jun 13, 2024
59c2079
switch to zygote for optimization
MilesCranmer Jun 14, 2024
ca43666
feat: allow user-specified autodiff backend; finite or zygote
MilesCranmer Jun 14, 2024
4fb7580
feat: add TODO item
MilesCranmer Jun 14, 2024
b91e3d8
Merge branch 'master' into parametric-expressions
MilesCranmer Jun 16, 2024
566851b
refactor: use callable structs rather than anonymous functions
MilesCranmer Jun 16, 2024
eef7c2d
test: run parametrized function example
MilesCranmer Jun 16, 2024
ced2f67
ci: split up tests into three parts
MilesCranmer Jun 16, 2024
f3e82af
test: fix jet error
MilesCranmer Jun 17, 2024
fe9c603
feat: use DifferentiationInterface.jl for AD backend
MilesCranmer Jun 17, 2024
f423f48
run parametrized_function example for test
MilesCranmer Jun 17, 2024
13687b4
fix: `default_node -> default_node_type`
MilesCranmer Jun 23, 2024
84e8161
feat: allow symbol for autodiff backend
MilesCranmer Jun 23, 2024
c81f9d0
chore: set version to alpha instead
MilesCranmer Jun 23, 2024
fc48a99
test: commit Manifest to reference alpha DE
MilesCranmer Jun 23, 2024
7fe690f
build: delete Manifest.toml
MilesCranmer Jun 23, 2024
6a6ea78
feat: better printing for halls of fame
MilesCranmer Jun 23, 2024
0a731ba
style: remove extra space
MilesCranmer Jun 23, 2024
9eb0e28
refactor: improve printing of PopMember
MilesCranmer Jun 23, 2024
c8ce382
refactor: avoid holding operators within expressions
MilesCranmer Jun 23, 2024
8c4c49b
refactor: rename expressions interface
MilesCranmer Jun 23, 2024
38923b7
fix: stripping of metadata in loaded state
MilesCranmer Jun 24, 2024
191b214
test: fix stripping of metadata
MilesCranmer Jun 24, 2024
6c9b9e9
fix: order of ParametricExpression arguments
MilesCranmer Jun 24, 2024
5f5e49e
test: declare instabilities
MilesCranmer Jun 24, 2024
6095705
test: fix remaining issues with test
MilesCranmer Jun 24, 2024
5745feb
test: fix missing import
MilesCranmer Jun 24, 2024
5a08de4
refactor: use correct interface option
MilesCranmer Jun 24, 2024
b63aea7
chore: bump DynamicExpressions version
MilesCranmer Jun 24, 2024
2df4977
chore: bump SymbolicUtils
MilesCranmer Jun 24, 2024
08dc851
chore: fix symbolic regression version
MilesCranmer Jun 24, 2024
5b33168
test: fix test interference
MilesCranmer Jun 24, 2024
368a1d9
test: fix stop on clock test
MilesCranmer Jun 24, 2024
8d50d4d
fix: jet issues with union splitting
MilesCranmer Jun 24, 2024
3ba1556
test: skip jet test on 1.11
MilesCranmer Jun 24, 2024
4bd4658
fix: various method ambiguities
MilesCranmer Jun 24, 2024
3f0edec
chore: update DE
MilesCranmer Jun 24, 2024
381b3a7
fix: jet union split
MilesCranmer Jun 24, 2024
d1d8e41
fix: use of `simplify_tree!`
MilesCranmer Jun 24, 2024
f5e9108
chore: bump DE version
MilesCranmer Jun 24, 2024
2f63ccb
feat: add classes feature to MLJ interface
MilesCranmer Jun 24, 2024
05493f6
fix: mlj interface for classes
MilesCranmer Jun 25, 2024
d1ceb11
refactor: move jet test to end
MilesCranmer Jun 25, 2024
cd4f279
refactor: update MLJ scitype
MilesCranmer Jun 25, 2024
bae08a2
style: clean up interface code
MilesCranmer Jun 25, 2024
17e5229
fix missing import
MilesCranmer Jun 26, 2024
67b851b
limit max calls in optimization
MilesCranmer Jun 26, 2024
28ff786
reduce `niterations`
MilesCranmer Jun 26, 2024
6c551a1
split up MLJ test
MilesCranmer Jun 26, 2024
fa2a30a
split up test_mixed
MilesCranmer Jun 26, 2024
1b7ed4f
speed up mixed test
MilesCranmer Jun 26, 2024
d4e2ac0
add tags on integration tests
MilesCranmer Jun 26, 2024
6f8229c
clean up test mixed util
MilesCranmer Jun 26, 2024
daec95a
float up redundant lines
MilesCranmer Jun 27, 2024
c26f1fd
simpler type assertion
MilesCranmer Jun 27, 2024
96ab40e
tests of new print_tree behavior
MilesCranmer Jun 27, 2024
1749547
test derivative helper functions
MilesCranmer Jun 27, 2024
979f1f7
test other printing methods
MilesCranmer Jun 27, 2024
9e5f467
test that new options dont change PopMember printout
MilesCranmer Jun 27, 2024
0d2023b
full test of break random connection
MilesCranmer Jun 27, 2024
669cef9
test form_random_connection!
MilesCranmer Jun 27, 2024
b3e951b
remove unused converter
MilesCranmer Jun 27, 2024
9563354
smoke test for warm start parametrized function
MilesCranmer Jun 27, 2024
898c2fd
add early_stop_condition to long mixed tests
MilesCranmer Jun 27, 2024
6d765c5
fix parameterized function test
MilesCranmer Jun 27, 2024
fc164b3
early stop condition for parameter test
MilesCranmer Jun 27, 2024
13ad4e8
make sure MLJ tests run
MilesCranmer Jun 27, 2024
b604ce9
fix parametrized function test
MilesCranmer Jun 27, 2024
9627348
avoid recursion with keyword args
MilesCranmer Jun 27, 2024
b461277
add checks for classes in MLJ
MilesCranmer Jun 27, 2024
d98bb7e
clear all variables in test_params.jl
MilesCranmer Jun 27, 2024
56e7797
fix parametrized expressions being reset
MilesCranmer Jun 27, 2024
1f104aa
fix parameterized function example not selecting last
MilesCranmer Jun 27, 2024
4d5f1cc
fix: simplification within SingleIteration
MilesCranmer Jun 28, 2024
394e768
test: gradients during optimization
MilesCranmer Jun 28, 2024
65b0a8d
test: fix parametrized function example
MilesCranmer Jun 28, 2024
de8135d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 28, 2024
a5a969a
test: gradients of parametric expression
MilesCranmer Jun 28, 2024
c3c2b53
feat: get Zygote gradients working for `ParametricExpressions`
MilesCranmer Jun 29, 2024
6eebb5f
chore: bump DynamicExpression dependency
MilesCranmer Jun 29, 2024
08f8d44
fix: ensure we continue in parametrized function test
MilesCranmer Jul 1, 2024
94c82ec
chore: bump DE version
MilesCranmer Jul 1, 2024
a72e73e
Merge branch 'master' into parametric-expressions
MilesCranmer Jul 1, 2024
47da6f5
feat: better monitor of bottlenecks
MilesCranmer Jul 1, 2024
32fec18
feat: dont print occupation during warmup
MilesCranmer Jul 1, 2024
6b2464a
feat: use `Threads.@spawn` over `@async`
MilesCranmer Jul 1, 2024
8e149bc
feat: don't print worker occupation temporarily
MilesCranmer Jul 1, 2024
2190234
feat: wip Enzyme support
MilesCranmer Jul 1, 2024
1577e67
feat: working Enzyme gradients in loss function
MilesCranmer Jul 1, 2024
98d6329
deps: add Enzyme to dependencies
MilesCranmer Jul 1, 2024
0ff426b
test: update evaluator tests api calls
MilesCranmer Jul 1, 2024
60fa199
refactor: fix ambiguity
MilesCranmer Jul 1, 2024
54a38eb
feat: Enzyme gradients for ParametricExpression
MilesCranmer Jul 1, 2024
86a39f0
fix: move back to `@async` to fix race condition
MilesCranmer Jul 1, 2024
1226a3c
test: only run Enzyme test on Julia 1.10
MilesCranmer Jul 1, 2024
6a0e32a
feat: back to fast loss functions
MilesCranmer Jul 1, 2024
0dac5c8
feat: explicit stack size for Enzyme
MilesCranmer Jul 2, 2024
e7b1340
feat: set stack size to 32 MB for Enzyme
MilesCranmer Jul 2, 2024
fcc23e2
ci: move Enzyme test to part 3
MilesCranmer Jul 2, 2024
1e60677
test: skip Enzyme test on 1.11
MilesCranmer Jul 2, 2024
d16ebf6
refactor: remove unused function
MilesCranmer Jul 2, 2024
7a70dfb
test: make stop condition earlier
MilesCranmer Jul 2, 2024
0b85a34
test: fix env in mlj tests
MilesCranmer Jul 2, 2024
365d1dd
test: avoid unicode line endings to fix JuliaSyntax bug
MilesCranmer Jul 15, 2024
80fb38f
fix: type instability in `init_dummy_pops`
MilesCranmer Jul 15, 2024
1d60c81
fix: more type stability for Enzyme
MilesCranmer Jul 15, 2024
acf6222
test: force specialized operators for Enzyme
MilesCranmer Jul 15, 2024
c12bac4
fix: issue with outer threads loop
MilesCranmer Jul 27, 2024
dafc848
deps: update DynamicExpressions version
MilesCranmer Jul 30, 2024
45d71b2
refactor: clean up stale imports with ExplicitImports.jl
MilesCranmer Jul 30, 2024
9204a1d
feat: expose new `get_scalar_constants` and `set_scalar_constants!`
MilesCranmer Jul 30, 2024
ffa9f52
fix: ensure `dataset.weights` is copied for Enzyme analysis
MilesCranmer Jul 30, 2024
5b75f59
refactor: only copy batched portion
MilesCranmer Aug 1, 2024
213180b
hack: delete all old artifacts
MilesCranmer Aug 1, 2024
8983373
ci: undo cache clearing
MilesCranmer Aug 1, 2024
226a36a
deps: bump DE with depwarn fix
MilesCranmer Aug 1, 2024
ed684fc
feat: add new string representations via dispatch
MilesCranmer Aug 1, 2024
a39ae45
refactor: `dataset.weights` not actually Enzyme issue
MilesCranmer Aug 1, 2024
a4709bf
fix: avoid storing `StatsBase.Weights` within `Options`
MilesCranmer Aug 1, 2024
1209652
fix: ensure we mark `@nospecialize`d argument as `@unstable`
MilesCranmer Aug 2, 2024
37712f6
fix: consistency checks in metadata stripping
MilesCranmer Aug 3, 2024
b4265c1
fix: ignore functions in Enzyme and ChainRulesCore
MilesCranmer Aug 3, 2024
002df6d
fix: EnzymeRules marking
MilesCranmer Aug 3, 2024
018e5a0
Revert "fix: EnzymeRules marking"
MilesCranmer Aug 3, 2024
48dccb1
Revert "fix: ignore functions in Enzyme and ChainRulesCore"
MilesCranmer Aug 3, 2024
0d7fb80
ci: test Enzyme separately
MilesCranmer Aug 3, 2024
eb7cde1
ci: ensure Preferences.jl installed
MilesCranmer Aug 3, 2024
2edb010
Revert "ci: ensure Preferences.jl installed"
MilesCranmer Aug 3, 2024
6d2cc64
Revert "ci: test Enzyme separately"
MilesCranmer Aug 3, 2024
c49673e
ci: skip Enzyme tests completely
MilesCranmer Aug 3, 2024
c64c63a
refactor: create `@ignore`d code to appease static analysis
MilesCranmer Aug 3, 2024
0f14c5f
refactor: clean up constructor of options
MilesCranmer Aug 3, 2024
8793b99
refactor: clean up constraint constructor
MilesCranmer Aug 4, 2024
509efe2
refactor: reduce some specialization
MilesCranmer Aug 4, 2024
8c06e47
test: ensure DynamicExpressions installed for preferences
MilesCranmer Aug 5, 2024
dae5593
fix: mistaken assertion
MilesCranmer Aug 5, 2024
bb3846d
fix: repeated depwarns cause dict write error
MilesCranmer Aug 5, 2024
5ab8d16
fix: prevent bad RNGs in MLJ tests
MilesCranmer Aug 5, 2024
1a74749
test: improve parameterized function example
MilesCranmer Aug 5, 2024
db2d9e9
refactor: clean up imports
MilesCranmer Aug 5, 2024
33bd4fe
test: make clocked test use serial mode
MilesCranmer Aug 5, 2024
16582d3
test: ensure options are lighterweight for clock test
MilesCranmer Aug 12, 2024
25f8fcd
test: increase npop for mlj test that fails sometimes
MilesCranmer Aug 12, 2024
03624c4
feat: create `PerThreadCache` for efficient caching
MilesCranmer Aug 13, 2024
58f4adf
refactor: use `PerThreadCache` for safe functions
MilesCranmer Aug 13, 2024
e1106dd
build: specify preferences in project
MilesCranmer Aug 23, 2024
419cf01
deps: bump SymbolicUtils to 3
MilesCranmer Aug 26, 2024
13b62c0
deps: bump DynamicQuantities to 1
MilesCranmer Aug 26, 2024
158b15e
deps: fix compat settings
MilesCranmer Aug 26, 2024
c43b7c7
deps: force SymbolicUtils 2+
MilesCranmer Aug 26, 2024
e3bcb39
deps: fix issue with [compat] on older Julia
MilesCranmer Aug 26, 2024
65a55c8
deps: add back old SymbolicUtils
MilesCranmer Aug 26, 2024
6890823
deps: fix DE version
MilesCranmer Aug 26, 2024
265ce2c
deps: skip ConstructionBase bug
MilesCranmer Aug 26, 2024
2861e43
deps: move ConstructionBase to extras
MilesCranmer Aug 26, 2024
8588d01
ci: force clear cache upon new Project.toml
MilesCranmer Aug 27, 2024
7c9b6ad
deps: force ConstructionBase dependency to fix version
MilesCranmer Aug 27, 2024
1f61246
test: verbose printing of tests
MilesCranmer Aug 27, 2024
96a1f96
test: make dimensional test reproducible
MilesCranmer Aug 28, 2024
0c4c7f9
test: split up dimensional analysis tests
MilesCranmer Aug 28, 2024
fa61c49
test: set up early stopping for test
MilesCranmer Aug 28, 2024
ff33bdf
test: note slow tests
MilesCranmer Aug 28, 2024
7740f6b
test: fix missing operator
MilesCranmer Aug 28, 2024
2caee75
docs: update examples
MilesCranmer Oct 5, 2024
548dfe5
docs: update more docs for new types
MilesCranmer Oct 5, 2024
b37adad
docs: StructuredExpression
MilesCranmer Oct 5, 2024
f49f88e
docs: explain how to create a custom expression
MilesCranmer Oct 5, 2024
45fba62
refactor: modularize `condition_mutate_constant!`
MilesCranmer Oct 5, 2024
fa8ccd0
docs: tweak docstring
MilesCranmer Oct 5, 2024
e078916
fix: when bin constraints passed as dict
MilesCranmer Oct 6, 2024
9bb4922
fix: all underscore identifier
MilesCranmer Oct 6, 2024
8fb0c94
fix: ambiguities
MilesCranmer Oct 6, 2024
c753d1a
Merge branch 'master' into parametric-expressions
MilesCranmer Oct 6, 2024
149b364
test: set random seed in test_mixed
MilesCranmer Oct 6, 2024
daee883
test: weaken test condition for turbo test
MilesCranmer Oct 6, 2024
bc43dd8
refactor: fix some other potential instabilities
MilesCranmer Oct 6, 2024
d6ac1de
deps: bump DE and DQ to 1.0
MilesCranmer Oct 6, 2024
e2b369e
test: fix comparison in parametric function test
MilesCranmer Oct 6, 2024
26 changes: 18 additions & 8 deletions .github/workflows/CI.yml
@@ -26,8 +26,9 @@ jobs:
fail-fast: false
matrix:
test:
- "unit"
- "integration"
- "part1"
- "part2"
- "part3"
julia-version:
- "1.6"
- "1.8"
@@ -37,22 +38,31 @@
include:
- os: windows-latest
julia-version: "1"
test: "unit"
test: "part1"
- os: windows-latest
julia-version: "1"
test: "integration"
test: "part2"
- os: windows-latest
julia-version: "1"
test: "part3"
- os: macOS-latest
julia-version: "1"
test: "part1"
- os: macOS-latest
julia-version: "1"
test: "unit"
test: "part2"
- os: macOS-latest
julia-version: "1"
test: "integration"
test: "part3"
- os: ubuntu-latest
julia-version: "~1.11.0-0"
test: "part1"
- os: ubuntu-latest
julia-version: "~1.11.0-0"
test: "unit"
test: "part2"
- os: ubuntu-latest
julia-version: "~1.11.0-0"
test: "integration"
test: "part3"

steps:
- uses: actions/checkout@v4
1 change: 1 addition & 0 deletions .gitignore
@@ -20,3 +20,4 @@ docs/src/index.md
*.code-workspace
.vscode
**/*.json
LocalPreferences.toml
10 changes: 7 additions & 3 deletions Project.toml
@@ -1,11 +1,13 @@
name = "SymbolicRegression"
uuid = "8254be44-1295-4e6a-a16d-46603ac705cb"
authors = ["MilesCranmer <[email protected]>"]
version = "0.24.5"
version = "0.25.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
@@ -35,11 +37,13 @@ SymbolicRegressionJSON3Ext = "JSON3"
SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils"

[compat]
ADTypes = "~1.4"
Compat = "^4.2"
Dates = "1"
DifferentiationInterface = "0.5"
DispatchDoctor = "0.4"
Distributed = "1"
DynamicExpressions = "0.16"
DynamicExpressions = "0.18.2"
DynamicQuantities = "0.10, 0.11, 0.12, 0.13, 0.14"
JSON3 = "1"
LineSearches = "7"
@@ -56,7 +60,7 @@ Random = "1"
Reexport = "1"
SpecialFunctions = "0.10.1, 1, 2"
StatsBase = "0.33, 0.34"
SymbolicUtils = "0.19, ^1.0.5"
SymbolicUtils = "0.19, ^1.0.5, 2"
TOML = "1"
julia = "1.6"

30 changes: 30 additions & 0 deletions examples/parameterized_function.jl
@@ -0,0 +1,30 @@
using SymbolicRegression
using Random: MersenneTwister
using Zygote
using MLJBase: machine, fit!, predict

rng = MersenneTwister(0)
X = NamedTuple{(:x1, :x2, :x3, :x4, :x5)}(ntuple(_ -> randn(rng, Float32, 30), Val(5)))
X = (; X..., classes=rand(rng, 1:2, 30))
p1 = rand(rng, Float32, 2)
p2 = rand(rng, Float32, 2)

y = [
2 * cos(X.x4[i] + p1[X.classes[i]]) + X.x1[i]^2 - p2[X.classes[i]] for
i in eachindex(X.classes)
]

model = SRRegressor(;
niterations=10,
binary_operators=[+, *, /, -],
unary_operators=[cos, exp],
populations=10,
expression_type=ParametricExpression,
expression_options=(; max_parameters=2),
autodiff_backend=:Zygote,
parallelism=:multithreading,
)

mach = machine(model, X, y)
fit!(mach)
ypred = predict(mach, X)
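
For context, the example above generates data whose hidden parameters p1 and p2 differ by class, and asks the search for ParametricExpressions with up to two parameters per expression. A hedged sketch of inspecting the fit through the usual MLJ machinery follows; the report field names are assumptions based on the existing MLJ interface, not something this diff guarantees:

using MLJBase: report

r = report(mach)
# The report is expected to expose the Pareto front of discovered expressions,
# e.g. something like (assumed field names):
best = r.equations[r.best_idx]
println(r.equation_strings[r.best_idx])
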
29 changes: 20 additions & 9 deletions ext/SymbolicRegressionSymbolicUtilsExt.jl
@@ -1,8 +1,9 @@
module SymbolicRegressionSymbolicUtilsExt

using SymbolicUtils: Symbolic
using SymbolicRegression: AbstractExpressionNode, Node, Options
using SymbolicRegression: AbstractExpressionNode, AbstractExpression, Node, Options
using SymbolicRegression.MLJInterfaceModule: AbstractSRRegressor, get_options
using DynamicExpressions: get_tree, get_operators

import SymbolicRegression: node_to_symbolic, symbolic_to_node

@@ -11,10 +12,14 @@ import SymbolicRegression: node_to_symbolic, symbolic_to_node

Convert an expression to SymbolicUtils.jl form.
"""
function node_to_symbolic(tree::AbstractExpressionNode, options::Options; kws...)
return node_to_symbolic(tree, options.operators; kws...)
function node_to_symbolic(
tree::Union{AbstractExpressionNode,AbstractExpression}, options::Options; kws...
)
return node_to_symbolic(get_tree(tree), get_operators(tree, options); kws...)
end
function node_to_symbolic(tree::AbstractExpressionNode, m::AbstractSRRegressor; kws...)
function node_to_symbolic(
tree::Union{AbstractExpressionNode,AbstractExpression}, m::AbstractSRRegressor; kws...
)
return node_to_symbolic(tree, get_options(m); kws...)
end

@@ -31,24 +36,30 @@ function symbolic_to_node(eqn::Symbolic, m::AbstractSRRegressor; kws...)
end

function Base.convert(
::Type{Symbolic}, tree::AbstractExpressionNode, options::Options; kws...
::Type{Symbolic},
tree::Union{AbstractExpressionNode,AbstractExpression},
options::Union{Options,Nothing}=nothing;
kws...,
)
return convert(Symbolic, tree, options.operators; kws...)
return convert(Symbolic, get_tree(tree), get_operators(tree, options); kws...)
end
function Base.convert(
::Type{Symbolic}, tree::AbstractExpressionNode, m::AbstractSRRegressor; kws...
::Type{Symbolic},
tree::Union{AbstractExpressionNode,AbstractExpression},
m::AbstractSRRegressor;
kws...,
)
return convert(Symbolic, tree, get_options(m); kws...)
end

function Base.convert(
::Type{N}, x::Union{Number,Symbolic}, options::Options; kws...
) where {N<:AbstractExpressionNode}
) where {N<:Union{AbstractExpressionNode,AbstractExpression}}
return convert(N, x, options.operators; kws...)
end
function Base.convert(
::Type{N}, x::Union{Number,Symbolic}, m::AbstractSRRegressor; kws...
) where {N<:AbstractExpressionNode}
) where {N<:Union{AbstractExpressionNode,AbstractExpression}}
return convert(N, x, get_options(m); kws...)
end

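With the changes above, node_to_symbolic and Base.convert accept an AbstractExpression and unwrap it via get_tree/get_operators before converting. A small sketch of the intended round trip, assuming ex is an expression returned by a search configured with the same options:

using SymbolicUtils  # load the package so the extension activates
using SymbolicRegression: node_to_symbolic, symbolic_to_node

eqn = node_to_symbolic(ex, options)    # SymbolicUtils.jl symbolic expression
tree = symbolic_to_node(eqn, options)  # back to a plain Node tree
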
17 changes: 14 additions & 3 deletions src/CheckConstraints.jl
@@ -1,6 +1,7 @@
module CheckConstraintsModule

using DynamicExpressions: AbstractExpressionNode, count_depth, tree_mapreduce
using DynamicExpressions:
AbstractExpressionNode, AbstractExpression, get_tree, count_depth, tree_mapreduce
using ..UtilsModule: vals
using ..CoreModule: Options
using ..ComplexityModule: compute_complexity, past_complexity_limit
@@ -70,6 +71,15 @@ function flag_illegal_nests(tree::AbstractExpressionNode, options::Options)::Bool
end

"""Check if user-passed constraints are violated or not"""
function check_constraints(
ex::AbstractExpression,
options::Options,
maxsize::Int,
cursize::Union{Int,Nothing}=nothing,
)::Bool
tree = get_tree(ex)
return check_constraints(tree, options, maxsize, cursize)
end
function check_constraints(
tree::AbstractExpressionNode,
options::Options,
@@ -93,7 +103,8 @@ function check_constraints(
return true
end

check_constraints(tree::AbstractExpressionNode, options::Options)::Bool =
check_constraints(tree, options, options.maxsize)
check_constraints(
ex::Union{AbstractExpression,AbstractExpressionNode}, options::Options
)::Bool = check_constraints(ex, options, options.maxsize)

end
12 changes: 10 additions & 2 deletions src/Complexity.jl
@@ -1,9 +1,12 @@
module ComplexityModule

using DynamicExpressions: AbstractExpressionNode, count_nodes, tree_mapreduce
using DynamicExpressions:
AbstractExpression, AbstractExpressionNode, get_tree, count_nodes, tree_mapreduce
using ..CoreModule: Options, ComplexityMapping

function past_complexity_limit(tree::AbstractExpressionNode, options::Options, limit)::Bool
function past_complexity_limit(
tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options, limit
)::Bool
return compute_complexity(tree, options) > limit
end

@@ -14,6 +17,11 @@ By default, this is the number of nodes in a tree.
However, it could use the custom settings in options.complexity_mapping
if these are defined.
"""
function compute_complexity(
tree::AbstractExpression, options::Options; break_sharing=Val(false)
)
return compute_complexity(get_tree(tree), options; break_sharing)
end
function compute_complexity(
tree::AbstractExpressionNode, options::Options; break_sharing=Val(false)
)::Int
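To illustrate the complexity mapping mentioned in the docstring above together with the new AbstractExpression method: the sketch below uses the existing complexity_of_operators option, and ex is assumed to be an expression produced by a search with these options.

using SymbolicRegression: Options, compute_complexity
using DynamicExpressions: get_tree

options = Options(;
    binary_operators=[+, *],
    unary_operators=[cos],
    complexity_of_operators=[cos => 3],  # count each `cos` as 3 instead of 1
)
# The new method simply delegates to the underlying tree, so both calls agree:
compute_complexity(ex, options) == compute_complexity(get_tree(ex), options)
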
47 changes: 36 additions & 11 deletions src/ConstantOptimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ module ConstantOptimizationModule

using LineSearches: LineSearches
using Optim: Optim
using DynamicExpressions: Node, count_constants
using DifferentiationInterface: value_and_gradient
using DynamicExpressions: Expression, Node, count_constants, get_constants, set_constants!
using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE
using ..UtilsModule: get_birth_order
using ..LossFunctionsModule: eval_loss, loss_to_score, batch_sample
@@ -22,7 +23,7 @@
function dispatch_optimize_constants(
dataset::Dataset{T,L}, member::P, options::Options, idx
) where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}}
nconst = count_constants(member.tree)
nconst = count_constants_for_optimization(member.tree)
nconst == 0 && return (member, 0.0)
if nconst == 1 && !(T <: Complex)
algorithm = Optim.Newton(; linesearch=LineSearches.BackTracking())
@@ -39,36 +40,43 @@ function dispatch_optimize_constants(
idx,
)
end
count_constants_for_optimization(ex::Expression) = count_constants(ex)

function _optimize_constants(
dataset, member::P, options, algorithm, optimizer_options, idx
)::Tuple{P,Float64} where {T,L,P<:PopMember{T,L}}
tree = member.tree
eval_fraction = options.batching ? (options.batch_size / dataset.n) : 1.0
f(t) = eval_loss(t, dataset, options; regularization=false, idx=idx)::L
f = Evaluator(dataset, options, idx)
fg! = GradEvaluator(f)
obj = if algorithm isa Optim.Newton || options.autodiff_backend === nothing
f
else
Optim.only_fg!(fg!)
end
baseline = f(tree)
result = Optim.optimize(f, tree, algorithm, optimizer_options)
x0, refs = get_constants(tree)
result = Optim.optimize(obj, tree, algorithm, optimizer_options)
num_evals = result.f_calls * eval_fraction
# Try other initial conditions:
for _ in 1:(options.optimizer_nrestarts)
tmptree = copy(tree)
foreach(tmptree) do node
if node.degree == 0 && node.constant
node.val = (node.val) * (T(1) + T(1//2) * randn(T))
end
end
eps = randn(T, size(x0)...)
xt = @. x0 * (T(1) + T(1//2) * eps)
set_constants!(tmptree, xt, refs)
tmpresult = Optim.optimize(
f, tmptree, algorithm, optimizer_options; make_copy=false
obj, tmptree, algorithm, optimizer_options; make_copy=false
)
num_evals += tmpresult.f_calls * eval_fraction
# TODO: Does this need to take into account h_calls?

if tmpresult.minimum < result.minimum
result = tmpresult
end
end

if result.minimum < baseline
member.tree = result.minimizer
member.tree = result.minimizer::typeof(member.tree)
member.loss = eval_loss(member.tree, dataset, options; regularization=true, idx=idx)
member.score = loss_to_score(
member.loss, dataset.use_baseline, dataset.baseline_loss, member, options
@@ -80,4 +88,21 @@ function _optimize_constants(
return member, num_evals
end

struct Evaluator{D<:Dataset,O<:Options,I} <: Function
dataset::D
options::O
idx::I
end
(e::Evaluator)(t) = eval_loss(t, e.dataset, e.options; regularization=false, idx=e.idx)
struct GradEvaluator{F<:Evaluator} <: Function
f::F
end
function (g::GradEvaluator)(F, G, t)
(val, grad) = value_and_gradient(g.f, g.f.options.autodiff_backend, t)
if G !== nothing && grad !== nothing && grad.tree !== nothing
G .= grad.tree.gradient
end
return val
end

end
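
The rewrite above swaps anonymous closures for callable structs (Evaluator, GradEvaluator) so that Optim.only_fg! can share a single forward pass between the objective and its gradient, with the gradient computed through DifferentiationInterface. A self-contained sketch of the same pattern on a plain parameter vector (the loss and names here are hypothetical, chosen only to show the structure):

using Optim
using ADTypes: AutoForwardDiff
using ForwardDiff
using DifferentiationInterface: value_and_gradient

struct Loss{X,Y} <: Function
    x::X
    y::Y
end
(l::Loss)(p) = sum(abs2, l.y .- (p[1] .* l.x .+ p[2]))  # simple least-squares objective

struct LossAndGrad{F,B} <: Function
    f::F
    backend::B
end
function (g::LossAndGrad)(F, G, p)
    val, grad = value_and_gradient(g.f, g.backend, p)
    G === nothing || (G .= grad)          # write gradient in place when requested
    return F === nothing ? nothing : val  # return objective value when requested
end

x = collect(0.0:0.1:1.0)
y = 2.0 .* x .+ 1.0
f = Loss(x, y)
fg! = LossAndGrad(f, AutoForwardDiff())
result = Optim.optimize(Optim.only_fg!(fg!), zeros(2), Optim.BFGS())
Optim.minimizer(result)  # ≈ [2.0, 1.0]
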
2 changes: 2 additions & 0 deletions src/Core.jl
@@ -1,5 +1,7 @@
module CoreModule

function create_expression end

include("Utils.jl")
include("ProgramConstants.jl")
include("Dataset.jl")
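The new `function create_expression end` line is a forward declaration: it defines a zero-method function inside CoreModule so that files included later can attach methods to it without a circular dependency. A tiny standalone sketch of the same idiom (module names here are hypothetical):

module CoreSketch
function create_expression end   # declared now, implemented later
end

module ExpressionSketch
import ..CoreSketch: create_expression
create_expression(x::Number) = 2x   # a later module attaches the actual method
end

CoreSketch.create_expression(3)  # == 6
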
6 changes: 6 additions & 0 deletions src/Dataset.jl
@@ -22,6 +22,8 @@ import ...deprecate_varmap

- `X::AbstractMatrix{T}`: The input features, with shape `(nfeatures, n)`.
- `y::AbstractVector{T}`: The desired output values, with shape `(n,)`.
- `index::Int`: The index of the output feature corresponding to this
dataset, if any.
- `n::Int`: The number of samples.
- `nfeatures::Int`: The number of features.
- `weighted::Bool`: Whether the dataset is non-uniformly weighted.
@@ -64,6 +66,7 @@ mutable struct Dataset{
}
@constfield X::AX
@constfield y::AY
@constfield index::Int
@constfield n::Int
@constfield nfeatures::Int
@constfield weighted::Bool
@@ -99,6 +102,7 @@ function Dataset(
X::AbstractMatrix{T},
y::Union{AbstractVector{T},Nothing}=nothing,
loss_type::Type{L}=Nothing;
index::Int=1,
weights::Union{AbstractVector{T},Nothing}=nothing,
variable_names::Union{Array{String,1},Nothing}=nothing,
display_variable_names=variable_names,
@@ -123,6 +127,7 @@ function Dataset(
X,
y,
kws[:loss_type];
index,
weights,
variable_names,
display_variable_names,
@@ -206,6 +211,7 @@ function Dataset(
}(
X,
y,
index,
n,
nfeatures,
weighted,
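The new index field records which output column a Dataset corresponds to (useful when searching over multiple outputs). A minimal sketch of constructing one with the new keyword, with illustrative values:

using SymbolicRegression: Dataset

X = randn(Float64, 3, 50)            # 3 features, 50 samples
y = randn(Float64, 50)
dataset = Dataset(X, y; index=2)     # e.g. the 2nd output of a multi-output problem
dataset.index == 2                   # defaults to 1 when not given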