Skip to content

Commit

Permalink
feat: weight TemplateExpression sampling by num nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Oct 17, 2024
1 parent d0a6420 commit a9e5332
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/TemplateExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ using DynamicExpressions:
get_variable_names,
get_tree,
node_type,
eval_tree_array
eval_tree_array,
count_nodes
using DynamicExpressions.InterfacesModule:
ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments

Expand Down Expand Up @@ -264,8 +265,15 @@ and also return the symbol we mutated on so that we can put it back together lat
function MF.get_contents_for_mutation(ex::TemplateExpression, rng::AbstractRNG)
raw_contents = get_contents(ex)
function_keys = keys(raw_contents)
key_to_mutate = rand(rng, function_keys)

# Sample weighted by number of nodes in each subexpression
num_nodes = map(count_nodes, values(raw_contents))
weights = map(Base.Fix2(/, sum(num_nodes)), num_nodes)
cumsum_weights = cumsum(weights)
rand_val = rand(rng)
idx = findfirst(Base.Fix2(>=, rand_val), cumsum_weights)::Int

key_to_mutate = function_keys[idx]
return raw_contents[key_to_mutate], key_to_mutate
end

Expand Down

0 comments on commit a9e5332

Please sign in to comment.