Skip to content

Commit

Permalink
use namedtuple instead of immutabledict for analyses
Browse files Browse the repository at this point in the history
  • Loading branch information
Alessandro committed Apr 1, 2022
1 parent e66a0ca commit edfe57e
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 174 deletions.
27 changes: 12 additions & 15 deletions docs/src/egraphs.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,22 +243,22 @@ Here's an example:
# This is a cost function that behaves like `astsize` but increments the cost
# of nodes containing the `^` operation. This results in a tendency to avoid
# extraction of expressions containing '^'.
function cost_function(n::ENodeTerm, g::EGraph, an::Type{<:AbstractAnalysis})
function cost_function(n::ENodeTerm, g::EGraph)
cost = 1 + arity(n)

operation(n) == :^ && (cost += 2)

for id in arguments(n)
eclass = g[id]
# if the child e-class has not yet been analyzed, return +Inf
!hasdata(eclass, an) && (cost += Inf; break)
cost += last(getdata(eclass, an))
!hasdata(eclass, cost_function) && (cost += Inf; break)
cost += last(getdata(eclass, cost_function))
end
return cost
end

# All literal expressions (e.g `a`, 123, 0.42, "hello") have cost 1
cost_function(n::ENodeLiteral, g::EGraph, an::Type{<:AbstractAnalysis}) = 1
cost_function(n::ENodeLiteral, g::EGraph) = 1
```

## EGraph Analyses
Expand All @@ -271,10 +271,10 @@ Theoretically, the domain should form a [join semilattice](https://en.wikipedia.
Rewrites can cooperate with e-class analyses by depending on analysis facts and adding
equivalences that in turn establish additional facts.

In Metatheory.jl, EGraph Analyses are identified by a *type* that is subtype of `AbstractAnalysis`.
In Metatheory.jl, EGraph Analyses are identified by a unique name of type `Symbol`.
An [`EGraph`](@ref) can only contain one analysis per name.
The following functions define an interface for analyses based on multiple dispatch
on `AbstractAnalysis` types:
on `Val{analysis_name}` types:
* [islazy](@ref) should return true if the analysis should NOT be computed on-the-fly during egraphs operation, only when required.
* [make](@ref) should take an ENode and return a value from the analysis domain.
* [join](@ref) should return the semilattice join of two values in the analysis domain (e.g. *given two analysis values from ENodes in the same EClass, which one should I choose?*)
Expand All @@ -292,14 +292,11 @@ the actual numeric result of the expressions in the EGraph, but we only care to
the symbolic expressions that will result in an even or an odd number.

Defining an EGraph Analysis is similar to the process of [Mathematical Induction](https://en.wikipedia.org/wiki/Mathematical_induction).
To define a custom EGraph Analysis, one should start by defining a type that
subtypes `AbstractAnalysis` that will be used to identify this specific analysis and
to dispatch against the required methods.
To define a custom EGraph Analysis, one should start by defining a name of type `Symbol` that will be used to identify this specific analysis and to dispatch against the required methods.

```julia
using Metatheory
using Metatheory.EGraphs
abstract type OddEvenAnalysis <: AbstractAnalysis end
```

The next step, the base case of induction, is to define a method for
Expand All @@ -308,7 +305,7 @@ associate an analysis value only to the *literals* contained in the EGraph. To d
take advantage of multiple dispatch against `ENodeLiteral`.

```julia
function EGraphs.make(an::Type{OddEvenAnalysis}, g::EGraph, n::ENodeLiteral)
function EGraphs.make(::Val{:OddEvenAnalysis}, g::EGraph, n::ENodeLiteral)
if n.value isa Integer
return iseven(n.value) ? :even : :odd
else
Expand Down Expand Up @@ -336,7 +333,7 @@ From the definition of an [ENode](@ref), we know that children of ENodes are alw
to EClasses in the EGraph.

```julia
function EGraphs.make(an::Type{OddEvenAnalysis}, g::EGraph, n::ENodeTerm)
function EGraphs.make(::Val{:OddEvenAnalysis}, g::EGraph, n::ENodeTerm)
# Let's consider only binary function call terms.
if exprhead(n) == :call && arity(n) == 2
op = operation(n)
Expand All @@ -347,8 +344,8 @@ function EGraphs.make(an::Type{OddEvenAnalysis}, g::EGraph, n::ENodeTerm)

# Get the corresponding OddEvenAnalysis value of the children
# defaulting to nothing
ldata = getdata(l, an, nothing)
rdata = getdata(r, an, nothing)
ldata = getdata(l, :OddEvenAnalysis, nothing)
rdata = getdata(r, :OddEvenAnalysis, nothing)

if ldata isa Symbol && rdata isa Symbol
if op == :*
Expand All @@ -375,7 +372,7 @@ how to extract a single value out of the many analyses values contained in an EG
We do this by defining a method for [join](@ref).

```julia
function EGraphs.join(an::Type{OddEvenAnalysis}, a, b)
function EGraphs.join(::Val{:OddEvenAnalysis}, a, b)
if a == b
return a
else
Expand Down
2 changes: 0 additions & 2 deletions src/EGraphs/EGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ export analyze!
export extract!
export astsize
export astsize_inv
export AbstractAnalysis
export MetadataAnalysis
export getcost!

include("ematch.jl")
Expand Down
115 changes: 56 additions & 59 deletions src/EGraphs/analysis.jl
Original file line number Diff line number Diff line change
@@ -1,61 +1,65 @@
analysis_reference(x::Symbol) = Val(x)
analysis_reference(x::Function) = x
analysis_reference(x) = error("$x is not a valid analysis reference")

"""
islazy(an::Type{<:AbstractAnalysis})
islazy(::Val{analysis_name})
Should return `true` if the EGraph Analysis `an` is lazy
and false otherwise. A *lazy* EGraph Analysis is computed
only when [analyze!](@ref) is called. *Non-lazy*
analyses are instead computed on-the-fly every time ENodes are added to the EGraph or
EClasses are merged.
"""
islazy(an::Type{<:AbstractAnalysis})::Bool = false
islazy(::Val{analysis_name}) where {analysis_name} = false
islazy(analysis_name) = islazy(Val(analysis_name))

"""
modify!(an::Type{<:AbstractAnalysis}, g, id)
modify!(::Val{analysis_name}, g, id)
The `modify!` function for EGraph Analysis can optionally modify the eclass
`g[id]` after it has been analyzed, typically by adding an ENode.
It should be **idempotent** if no other changes occur to the EClass.
(See the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304)).
"""
modify!(analysis::Type{<:AbstractAnalysis}, g, id) = nothing
modify!(::Val{analysis_name}, g, id) where {analysis_name} = nothing
modify!(an, g, id) = modify!(analysis_reference(an), g, id)


"""
join(an::Type{<:AbstractAnalysis}, a, b)
join(::Val{analysis_name}, a, b)
Joins two analyses values into a single one, used by [analyze!](@ref)
when two eclasses are being merged or the analysis is being constructed.
"""
join(analysis::Type{<:AbstractAnalysis}, a, b) = error("Analysis does not implement join")
join(analysis::Val{analysis_name}, a, b) where {analysis_name} =
error("Analysis $analysis_name does not implement join")
join(an, a, b) = join(analysis_reference(an), a, b)

"""
make(an::Type{<:AbstractAnalysis}, g, n)
make(::Val{analysis_name}, g, n)
Given an ENode `n`, `make` should return the corresponding analysis value.
"""
make(analysis::Type{<:AbstractAnalysis}, g, n) = error("Analysis does not implement make")

make(::Val{analysis_name}, g, n) where {analysis_name} = error("Analysis $analysis_name does not implement make")
make(an, g, n) = make(analysis_reference(an), g, n)

# TODO default analysis for metadata here
abstract type MetadataAnalysis <: AbstractAnalysis end

analyze!(g::EGraph, an::Type{<:AbstractAnalysis}, id::EClassId) = analyze!(g, an, reachable(g, id))
analyze!(g::EGraph, an::Type{<:AbstractAnalysis}) = analyze!(g, an, collect(keys(g.classes)))
analyze!(g::EGraph, analysis_ref, id::EClassId) = analyze!(g, analysis_ref, reachable(g, id))
analyze!(g::EGraph, analysis_ref) = analyze!(g, analysis_ref, collect(keys(g.classes)))


"""
analyze!(egraph, analysis, [ECLASS_IDS])
analyze!(egraph, analysis_name, [ECLASS_IDS])
Given an [EGraph](@ref) and an `analysis` of type `<:AbstractAnalysis`,
Given an [EGraph](@ref) and an `analysis` identified by name `analysis_name`,
do an automated bottom-up traversal of the EGraph, associating a value from the
domain of `analysis` to each ENode in the egraph by the [make](@ref) function.
domain of analysis to each ENode in the egraph by the [make](@ref) function.
Then, for each [EClass](@ref), compute the [join](@ref) of the children ENodes analyses values.
After `analyze!` is called, an analysis value will be associated to each EClass in the EGraph.
One can inspect and retrieve analysis values by using [hasdata](@ref) and [getdata](@ref).
Note that an [EGraph](@ref) can only contain one analysis of type `an`.
One can inspect and retrieve analysis values by using [hasdata](@ref) and [getdata](@ref).
"""
function analyze!(g::EGraph, an::Type{<:AbstractAnalysis}, ids::Vector{EClassId})
push!(g.analyses, an)
function analyze!(g::EGraph, analysis_ref, ids::Vector{EClassId})
push!(g.analyses, analysis_ref)
ids = sort(ids)
# @assert isempty(g.dirty)

Expand All @@ -66,12 +70,12 @@ function analyze!(g::EGraph, an::Type{<:AbstractAnalysis}, ids::Vector{EClassId}
for id in ids
eclass = g[id]
id = eclass.id
pass = mapreduce(x -> make(an, g, x), (x, y) -> join(an, x, y), eclass)
pass = mapreduce(x -> make(analysis_ref, g, x), (x, y) -> join(analysis_ref, x, y), eclass)
# pass = make_pass(G, analysis, find(G,id))

# if pass !== missing
if !isequal(pass, getdata(eclass, an, missing))
setdata!(eclass, an, pass)
if !isequal(pass, getdata(eclass, analysis_ref, missing))
setdata!(eclass, analysis_ref, pass)
did_something = true
push!(g.dirty, id)
end
Expand All @@ -81,7 +85,7 @@ function analyze!(g::EGraph, an::Type{<:AbstractAnalysis}, ids::Vector{EClassId}
for id in ids
eclass = g[id]
id = eclass.id
if !hasdata(eclass, an)
if !hasdata(eclass, analysis_ref)
error("failed to compute analysis for eclass ", id)
end
end
Expand All @@ -93,65 +97,60 @@ end
A basic cost function, where the computed cost is the size
(number of children) of the current expression.
"""
function astsize(n::ENodeTerm, g::EGraph, an::Type{<:AbstractAnalysis})
function astsize(n::ENodeTerm, g::EGraph)
cost = 1 + arity(n)
for id in arguments(n)
eclass = g[id]
!hasdata(eclass, an) && (cost += Inf; break)
cost += last(getdata(eclass, an))
!hasdata(eclass, astsize) && (cost += Inf; break)
cost += last(getdata(eclass, astsize))
end
return cost
end

astsize(n::ENodeLiteral, g::EGraph, an::Type{<:AbstractAnalysis}) = 1
astsize(n::ENodeLiteral, g::EGraph) = 1

"""
A basic cost function, where the computed cost is the size
(number of children) of the current expression, times -1.
Strives to get the largest expression
"""
function astsize_inv(n::ENodeTerm, g::EGraph, an::Type{<:AbstractAnalysis})
function astsize_inv(n::ENodeTerm, g::EGraph)
cost = -(1 + arity(n)) # minus sign here is the only difference vs astsize
for id in arguments(n)
eclass = g[id]
!hasdata(eclass, an) && (cost += Inf; break)
cost += last(getdata(eclass, an))
!hasdata(eclass, astsize_inv) && (cost += Inf; break)
cost += last(getdata(eclass, astsize_inv))
end
return cost
end

astsize_inv(n::ENodeLiteral, g::EGraph, an::Type{<:AbstractAnalysis}) = -1
astsize_inv(n::ENodeLiteral, g::EGraph) = -1


"""
An [`AbstractAnalysis`](@ref) that computes the cost of expression nodes
and chooses the node with the smallest cost for each E-Class.
This abstract type is parametrised by a function F.
This is useful for the analysis storage in [`EClass`](@ref)
When passing a function to analysis functions it is considered as a cost function
"""
abstract type ExtractionAnalysis{F} <: AbstractAnalysis end

make(a::Type{ExtractionAnalysis{F}}, g::EGraph, n::AbstractENode) where {F} = (n, F(n, g, a))
make(f::Function, g::EGraph, n::AbstractENode) = (n, f(n, g))

join(a::Type{<:ExtractionAnalysis}, from, to) = last(from) <= last(to) ? from : to
join(f::Function, from, to) = last(from) <= last(to) ? from : to

islazy(a::Type{<:ExtractionAnalysis}) = true
islazy(::Function) = true

function rec_extract(g::EGraph, an, id::EClassId; cse_env = nothing)
function rec_extract(g::EGraph, costfun, id::EClassId; cse_env = nothing)
eclass = g[id]
if !isnothing(cse_env) && haskey(cse_env, id)
(sym, _) = cse_env[id]
return sym
end
anval = getdata(eclass, an, (nothing, Inf))
anval = getdata(eclass, costfun, (nothing, Inf))
(n, ck) = anval
ck == Inf && error("Infinite cost when extracting enode")

if n isa ENodeLiteral
return n.value
elseif n isa ENodeTerm
children = map(child -> rec_extract(g, an, child; cse_env = cse_env), arguments(n))
meta = getdata(eclass, MetadataAnalysis, nothing)
children = map(child -> rec_extract(g, costfun, child; cse_env = cse_env), arguments(n))
meta = getdata(eclass, :metadata_analysis, nothing)
T = termtype(n)
egraph_reconstruct_expression(T, operation(n), children; metadata = meta, exprhead = exprhead(n))
else
Expand All @@ -164,54 +163,52 @@ Given a cost function, extract the expression
with the smallest computed cost from an [`EGraph`](@ref)
"""
function extract!(g::EGraph, costfun::Function; root = -1, cse = false)
a = ExtractionAnalysis{costfun}
if root == -1
root = g.root
end
analyze!(g, a, root)
analyze!(g, costfun, root)
if cse
# TODO make sure there is no assignments/stateful code!!
cse_env = OrderedDict{EClassId,Tuple{Symbol,Any}}() #
collect_cse!(g, a, root, cse_env, Set{EClassId}())
collect_cse!(g, costfun, root, cse_env, Set{EClassId}())
# @show root
# @show cse_env

body = rec_extract(g, a, root; cse_env = cse_env)
body = rec_extract(g, costfun, root; cse_env = cse_env)

assignments = [Expr(:(=), name, val) for (id, (name, val)) in cse_env]
# return body
Expr(:let, Expr(:block, assignments...), body)
else
return rec_extract(g, a, root)
return rec_extract(g, costfun, root)
end
end


# Builds a dict e-class id => (symbol, extracted term) of common subexpressions in an e-graph
function collect_cse!(g::EGraph, an, id, cse_env, seen)
function collect_cse!(g::EGraph, costfun, id, cse_env, seen)
eclass = g[id]
anval = getdata(eclass, an, (nothing, Inf))
anval = getdata(eclass, costfun, (nothing, Inf))
(cn, ck) = anval
ck == Inf && error("Error when computing CSE")
if cn isa ENodeTerm
if id in seen
cse_env[id] = (gensym(), rec_extract(g, an, id))#, cse_env=cse_env)) # todo generalize symbol?
cse_env[id] = (gensym(), rec_extract(g, costfun, id))#, cse_env=cse_env)) # todo generalize symbol?
return
end
for child_id in arguments(cn)
collect_cse!(g, an, child_id, cse_env, seen)
collect_cse!(g, costfun, child_id, cse_env, seen)
end
push!(seen, id)
end
end

getcost!(g::EGraph, costfun::Function; root = -1) = getcost!(g, ExtractionAnalysis{costfun}; root = root)

function getcost!(g::EGraph, analysis::Type{ExtractionAnalysis{F}}; root = -1) where {F}
function getcost!(g::EGraph, costfun; root = -1) where {F}
if root == -1
root = g.root
end
analyze!(g, analysis, root)
bestnode, cost = getdata(g[root], analysis)
analyze!(g, costfun, root)
bestnode, cost = getdata(g[root], costfun)
return cost
end
Loading

0 comments on commit edfe57e

Please sign in to comment.