Commit

update

GiggleLiu committed Jun 3, 2024
1 parent cbdc20c commit 6c3e687
Showing 9 changed files with 138 additions and 27 deletions.
21 changes: 16 additions & 5 deletions examples/manualad.jl
@@ -3,16 +3,27 @@ using OMEinsum: cost_and_gradient
A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))

# evaluate the cost and the gradient of leaves
function gf(code, xs, res, ȳ = OMEinsum.init_gradient(code, xs))
cost, tree = OMEinsum.gradient_tree(code, xs, ȳ)
# extract the gradients on leaves (i.e. the input tensors).
return cost, OMEinsum.extract_leaves!(code, tree, res)
end

using Zygote
xA, xB, xC = randn(2, 3), randn(3, 4), randn(4, 2)
ȳ = fill(1.0)
function gfunc(A, B, C)
cost, (gA, gB, gC) = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C), ȳ)
ȳ = fill(one(eltype(A)))
res = Zygote.Buffer(Any[nothing, nothing, nothing])
cost, (gA, gB, gC) = gf(ein"(ij, jk), ki->", (A, B, C), res, ȳ)
@info "summing"
return sum(gA .* xA) + sum(gB .* xB) + sum(gC .* xC)
end
Zygote.gradient(gfunc, A, B, C)

using ReverseDiff
ReverseDiff.@grad_from_chainrules einsum(args...; kwargs...)

ReverseDiff.gradient(gfunc, (A, B, C))
zg = Zygote.gradient((a, b, c)->ein"(ij, jk), ki->"(a, b, c)[], A, B, C)
mg = gf(ein"(ij, jk), ki->", (A, B, C), Any[nothing, nothing, nothing])

using FiniteDiff
h = FiniteDiff.finite_difference_hessian(v->ein"(ij, jk), ki->"(reshape(v[1:6], 2, 3), reshape(v[7:18], 3, 4), reshape(v[19:end], 4, 2))[], [vec(A); vec(B); vec(C)])
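As a follow-up to the example above (not part of the commit), the manual result `mg` can be cross-checked against the Zygote gradient `zg` and a finite-difference gradient; `FiniteDiff.finite_difference_gradient` is assumed to be available alongside the hessian helper used above:

# Hedged sketch: `mg` above is a `(cost, gradients)` pair, so its gradients should match
# Zygote's result `zg` as well as a finite-difference gradient of the same scalar function.
@assert all(zg .≈ mg[2])
fd = FiniteDiff.finite_difference_gradient(
    v -> ein"(ij, jk), ki->"(reshape(v[1:6], 2, 3), reshape(v[7:18], 3, 4), reshape(v[19:end], 4, 2))[],
    [vec(A); vec(B); vec(C)])
@assert isapprox(reduce(vcat, vec.(mg[2])), fd; rtol=1e-4)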
4 changes: 2 additions & 2 deletions src/OMEinsum.jl
@@ -6,7 +6,7 @@ using OMEinsumContractionOrders
using AbstractTrees
import LinearAlgebra: BlasFloat

export @ein_str, @ein, @ein!, ein
export @ein_str, @ein, @ein!, ein, @optein_str
export einsum!, einsum, dynamic_einsum
export EinCode, EinIndexer, EinArray, DynamicEinCode, StaticEinCode, AbstractEinsum, NestedEinsum, SlicedEinsum, DynamicNestedEinsum, StaticNestedEinsum
export getiyv, getixsv, uniquelabels, labeltype
@@ -40,7 +40,7 @@ include("interfaces.jl")
include("einsequence.jl")
include("slicing.jl")
include("autodiff.jl")
include("manualad.jl")
include("bp.jl")

include("contractionorder.jl")

16 changes: 15 additions & 1 deletion src/autodiff.jl
@@ -52,4 +52,18 @@ end
@non_differentiable get_size_dict!(::Any, ::Any, ::Any)
@non_differentiable DynamicEinCode(::Any, ::Any)
@non_differentiable DynamicEinCode(::Any)
@non_differentiable getixsv(::Any)
@non_differentiable getixsv(::Any)

echo(x; tag="echo") = x
function ChainRulesCore.rrule(::typeof(echo), x; tag="echo")
@info "$tag: $x"
x, function (dy)
@info "$tag (back): x̄ = $dy"
return (NoTangent(), dy)
end
end

macro echo(var)
name = QuoteNode(var)
esc(:($var = $echo($var; tag="$($name)")))
end
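The `echo` function and `@echo` macro added above are debugging helpers: `echo` is an identity whose `rrule` logs the primal value on the forward pass and the incoming cotangent on the backward pass. Below is a minimal usage sketch, not part of the commit; it assumes the macro is invoked as `OMEinsum.@echo` since it is not exported:

using OMEinsum, Zygote

A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
function loss(a, b, c)
    y = ein"(ij, jk), ki->"(a, b, c)
    OMEinsum.@echo y    # logs the 0-dim result now, and its cotangent during the pullback
    return y[]
end
Zygote.gradient(loss, A, B, C)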
86 changes: 68 additions & 18 deletions src/manualad.jl → src/bp.jl
@@ -8,6 +8,19 @@ struct CacheTree{T}
end
CacheTree(content::AbstractArray{T}, siblings) where T = CacheTree(content, CacheTree{T}[siblings...])

"""
cached_einsum(code, xs, size_dict)
Compute the einsum contraction and cache the intermediate contraction results.
### Arguments
- `code`: The contraction code, which can be a `NestedEinsum` or a `SlicedEinsum`.
- `xs`: The input tensors.
- `size_dict`: The size dictionary, which maps the label to the size of the corresponding dimension.
### Returns
- `CacheTree`: The cached tree storing the intermediate results.
"""
function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)
# slicing is not supported yet.
if length(se.slicing) != 0
@@ -34,17 +47,36 @@ function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
end
end
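A minimal sketch of inspecting the cached tree (not part of the diff); the unexported helpers are qualified with the module name, and the size dictionary is built the same way as in `gradient_tree` below:

using OMEinsum

code = ein"(ij, jk), ki->"
xs = (randn(2, 3), randn(3, 4), randn(4, 2))
size_dict = OMEinsum.get_size_dict!(getixsv(code), xs, Dict{labeltype(code), Int}())
cache = OMEinsum.cached_einsum(code, xs, size_dict)
cache.content     # the final 0-dimensional result
cache.siblings    # `CacheTree`s holding the intermediate contraction results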

# computed gradient tree by back propagation
function generate_gradient_tree(se::SlicedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
"""
back_propagate(f, code, cache, ȳ, size_dict)
Back propagate the message `ȳ` through the cached tree `cache` and return a tree storing the intermediate messages.
The message can be gradients et al.
### Arguments
- `f`: The back-propagation rule. The signature is `f(eins, xs, y, size_dict, dy) -> dxs`, where
- `eins`: The contraction code at the current node.
- `xs`: The input tensors at the current node.
- `y`: The output tensor at the current node.
- `size_dict`: The size dictionary, which maps the label to the size of the corresponding dimension.
- `dy`: The message on the output tensor (`y`) to back-propagate through the current node.
- `dxs`: The message on the input tensors (`xs`) as the result of back-propagation.
- `code`: The contraction code, which can be a `NestedEinsum` or a `SlicedEinsum`.
- `cache`: The cached intermediate results, which can be generated by [`cached_einsum`](@ref).
- `ȳ`: The message to back-propagate.
- `size_dict`: The size dictionary, which maps the label to the size of the corresponding dimension.
### Returns
- `CacheTree`: The tree storing the intermediate messages.
"""
function back_propagate(f, se::SlicedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
if length(se.slicing) != 0
@warn "Slicing is not supported for generating masked tree! Fallback to `NestedEinsum`."
end
return generate_gradient_tree(se.eins, cache, dy, size_dict)
return back_propagate(f, se.eins, cache, dy, size_dict)
end

# recursively compute the gradients and store it into a tree.
# also known as the back-propagation algorithm.
function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
function back_propagate(f, code::NestedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
if isleaf(code)
return CacheTree(dy, CacheTree{T}[])
else
@@ -58,28 +90,44 @@ function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::Abs
# ...
# ```
# Let `L` be the loss, we will have `y̅ := ∂L/∂y`, `A̅ := ∂L/∂A`...
dxs = einsum_backward_rule(rootcode(code), xs, cache.content, size_dict, dy)
return CacheTree(dy, generate_gradient_tree.(siblings(code), cache.siblings, dxs, Ref(size_dict)))
dxs = f(rootcode(code), xs, cache.content, size_dict, dy)
return CacheTree(dy, back_propagate.(Ref(f), siblings(code), cache.siblings, dxs, Ref(size_dict)))
end
end

# a unified interface of the backward rules for real numbers and tropical numbers
function einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Real}} where {M}, y, size_dict, dy)
return ntuple(i -> einsum_grad(getixs(eins), xs, getiy(eins), size_dict, dy, i), length(xs))
end

# the main function for generating the gradient tree.
function gradient_tree(code::AbstractEinsum, xs, ȳ)
# infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary.
size_dict = get_size_dict!(getixsv(code), xs, Dict{labeltype(code), Int}())
# forward compute and cache intermediate results.
cache = cached_einsum(code, xs, size_dict)
# back-propagate
return copy(cache.content), generate_gradient_tree(code, cache, ȳ, size_dict)
function bprule(eins, @nospecialize(xs), @nospecialize(y), size_dict, @nospecialize(dy))
res = ntuple(i -> einsum_grad(getixs(eins), xs, getiy(eins), size_dict, dy, i), length(xs))
return res
end
return copy(cache.content), back_propagate(bprule, code, cache, ȳ, size_dict)
end

# evaluate the cost and the gradient of leaves
function cost_and_gradient(code, xs, ȳ = init_gradient(code, xs))
"""
cost_and_gradient(code, xs, ȳ)
Compute the cost and the gradients w.r.t the input tensors `xs`.
### Arguments
- `code`: The contraction code, which can be a `NestedEinsum` or a `SlicedEinsum`.
- `xs`: The input tensors.
- `ȳ`: The message to back-propagate. Default is `1`.
### Returns
- `cost`: The cost of the contraction.
- `grads`: The gradients w.r.t the input tensors.
"""
function cost_and_gradient(code, xs, ȳ = nothing)
if ȳ === nothing
ȳ = init_gradient(code, xs)
@assert ndims(ȳ) == 0 "The output must be a scalar! Or you need to feed the gradient manually. Got: $(ndims(ȳ))!"
end
cost, tree = gradient_tree(code, xs, ȳ)
# extract the gradients on leaves (i.e. the input tensors).
return cost, extract_leaves(code, tree)
@@ -94,7 +142,7 @@ function init_gradient(code, xs)
end

# since slicing is not supported, we forward it to NestedEinsum.
extract_leaves(code::SlicedEinsum, cache::CacheTree) = extract_leaves(rootcode(code), cache)
extract_leaves(code::SlicedEinsum, cache::CacheTree) = extract_leaves(code.eins, cache)

# extract gradients on leaf nodes.
function extract_leaves(code::NestedEinsum, cache::CacheTree)
@@ -108,7 +156,9 @@ function extract_leaves!(code, cache, res)
res[tensorindex(code)] = cache.content
else
# recurse deeper
extract_leaves!.(siblings(code), cache.siblings, Ref(res))
for (subcode, sib) in zip(siblings(code), cache.siblings)
extract_leaves!(subcode, sib, res)
end
end
return res
end
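Since `cost_and_gradient` now only initializes `ȳ` automatically for scalar outputs, here is a hedged sketch (not part of the commit) of supplying `ȳ` manually for a non-scalar contraction:

using OMEinsum

code = ein"(ij, jk), kl->il"
A, B, C = randn(2, 3), randn(3, 4), randn(4, 5)
y = code(A, B, C)                   # 2×5 output, not a scalar
ȳ = ones(eltype(y), size(y))        # seed for the back-propagation
# `out` is the contraction result; the gradients are those of `sum(ȳ .* y)` w.r.t. the inputs.
out, (gA, gB, gC) = OMEinsum.cost_and_gradient(code, (A, B, C), ȳ)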
10 changes: 10 additions & 0 deletions src/interfaces.jl
@@ -22,6 +22,16 @@ macro ein_str(s::AbstractString)
ein(s)
end

"""
optein"ij,jk,kl -> ik"(A, B, C)
String macro interface that similar to [`@ein_str`](@ref), with optimized contraction order (dimensions are assumed to be uniform).
"""
macro optein_str(s::AbstractString)
code = ein(s)
optimize_code(code, uniformsize(code, 20), TreeSA(; ntrials=1, niters=10)).eins
end

function ein(s::AbstractString)
s = replace(replace(s, "\n" => ""), " "=>"")
m = match(r"([\(\)a-z,α-ω]*)->([a-zα-ω]*)", s)
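A hedged usage sketch of the new `optein` string macro (not part of the diff); the contraction order is optimized with `TreeSA` under the uniform-size assumption stated in the docstring:

using OMEinsum

A, B, C = randn(10, 10), randn(10, 10), randn(10, 10)
code = optein"ij,jk,kl->il"    # a NestedEinsum with an optimized contraction order
code(A, B, C)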
2 changes: 1 addition & 1 deletion src/slicing.jl
@@ -113,7 +113,7 @@ function einsum(se::SlicedEinsum, @nospecialize(xs::NTuple{N,AbstractArray} wher
it = SliceIterator(se, size_dict)
res = get_output_array(xs, getindex.(Ref(size_dict), it.iyv))
eins_sliced = drop_slicedim(se.eins, se.slicing)
for slicemap in it
for slicemap in it # `slicemap` is a Dict storing a mapping from sliced_labels to the current slice index
# NOTE: @debug will break Zygote
# @debug "computing slice $k/$(length(it))"
xsi = ntuple(i->take_slice(xs[i], it.ixsv[i], slicemap), length(xs))
17 changes: 17 additions & 0 deletions test/bp.jl
@@ -0,0 +1,17 @@
using OMEinsum, Test, Zygote

@testset "bp check" begin
A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
cost0 = ein"(ij, jk), ki->"(A, B, C)[]
zg = Zygote.gradient((a, b, c)->ein"(ij, jk), ki->"(a, b, c)[], A, B, C)
cost, mg = OMEinsum.cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
@test cost[] ≈ cost0
@test all(zg .≈ mg)

code = optimize_code(ein"ij, jk, ki->", uniformsize(ein"ij, jk, ki->", 2), TreeSA())
cost0 = code(A, B, C)[]
zg = Zygote.gradient((a, b, c)->code(a, b, c)[], A, B, C)
cost, mg = OMEinsum.cost_and_gradient(code, (A, B, C))
@test cost[] ≈ cost0
@test all(zg .≈ mg)
end
5 changes: 5 additions & 0 deletions test/interfaces.jl
@@ -14,3 +14,8 @@ using OMEinsum: get_size_dict
ijk->
ikl" == ein"ijk,ijk->ikl"
end

@testset "opein" begin
code = optein"ij,jk,ki->"
@test code isa NestedEinsum
end
4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -55,6 +55,10 @@ end
include("contractionorder.jl")
end

@testset "back propagation" begin
include("bp.jl")
end

@testset "docstring" begin
Documenter.doctest(OMEinsum; manual=false)
end
