
Customize Enzyme rules #155

Open
GiggleLiu opened this issue Apr 18, 2023 · 0 comments

GiggleLiu (Collaborator) commented Apr 18, 2023

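The following is an unsuccessful attempt at a custom Enzyme rule for einsum: the derivative of the return value is not propagated backward properly. Because the return value is Duplicated, the reverse rule receives dret as a type (Type{<:Duplicated}) rather than as an object, so dret.dval raises an error.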

julia> using Enzyme, Enzyme.EnzymeRules, OMEinsum

julia> function EnzymeRules.augmented_primal(
                config::EnzymeRules.ConfigWidth{1},
                func::Const{typeof(einsum)}, ::Type{<:Duplicated}, 
                code::Const, xs::Duplicated, size_dict)
           @info("In custom augmented primal rule.")
           # Compute primal
           if EnzymeRules.needs_primal(config)
               primal = func.val(code.val, xs.val, size_dict.val)
               shadow = zero(primal)
           else
               primal, shadow = nothing, nothing
           end
           # Save x in tape if x will be overwritten
           @info EnzymeRules.overwritten(config)
           if EnzymeRules.overwritten(config)[3]
               tape = copy(xs.val)
           else
               tape = nothing
           end
           return EnzymeRules.AugmentedReturn(primal, shadow, tape)
       end

julia> function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1},
                func::Const{typeof(einsum)}, dret::Type{<:Duplicated}, tape,
                code::Const,
                xs::Duplicated, size_dict)
           @info """In custom reverse rule: $config.
           I was expecting `drect` to be an object rather than a type!!!!"""
           xval = EnzymeRules.overwritten(config)[3] ? tape : xs.val
           for i=1:length(xs.val)
               xs.dval[i] .+= OMEinsum.einsum_grad(OMEinsum.getixs(code.val),
                     xval, OMEinsum.getiy(code.val), size_dict.val, conj(dret.dval), i)
           end
           return ()
       end

julia> x = randn(3, 3);

julia> gx = zero(x);

julia> autodiff(ReverseWithPrimal, x->sum(einsum(ein"ii->i", x, Dict('i'=>3))),
                                     Duplicated((x,), (gx,)))
[ Info: In custom augmented primal rule.
[ Info: (false, false, false, true)
[ Info: In custom reverse rule: ConfigWidth{1, true, true, (false, 
              false, false, true)}(). 
    I was expecting `drect` to be an object rather than a type!!!!
ERROR: type DataType has no field dval
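A possible workaround, sketched under the assumption of the Enzyme 0.11-era EnzymeRules API used in the attempt above (later Enzyme releases changed parts of this API, so treat it as untested against other versions). Since a Duplicated return only hands reverse the type of dret, the augmented primal keeps a reference to the output shadow in the tape. Enzyme accumulates the output derivative into that shadow, and reverse reads it back from the tape. The sketch also returns one nothing per argument instead of (). OMEinsum.einsum_grad, getixs and getiy are called exactly as in the attempt above. The tape layout (xsaved, shadow) is an illustrative choice, not something Enzyme prescribes.

```julia
using Enzyme, Enzyme.EnzymeRules, OMEinsum

# Assumes the Enzyme 0.11-era API (ConfigWidth, needs_primal, needs_shadow, overwritten).
function EnzymeRules.augmented_primal(config::EnzymeRules.ConfigWidth{1},
        func::Const{typeof(einsum)}, ::Type{<:Duplicated},
        code::Const, xs::Duplicated, size_dict)
    primal = func.val(code.val, xs.val, size_dict.val)
    # Shadow of the output; Enzyme accumulates the output derivative into it.
    shadow = EnzymeRules.needs_shadow(config) ? zero(primal) : nothing
    # Copy the inputs only if they may be overwritten before the reverse pass.
    xsaved = EnzymeRules.overwritten(config)[3] ? map(copy, xs.val) : nothing
    # Keep the shadow in the tape: reverse receives only the *type* of dret.
    tape = (xsaved, shadow)
    return EnzymeRules.AugmentedReturn(
        EnzymeRules.needs_primal(config) ? primal : nothing, shadow, tape)
end

function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1},
        func::Const{typeof(einsum)}, ::Type{<:Duplicated}, tape,
        code::Const, xs::Duplicated, size_dict)
    xsaved, dy = tape
    # No output shadow means no derivative flows back; nothing to accumulate.
    dy === nothing && return (nothing, nothing, nothing)
    xval = EnzymeRules.overwritten(config)[3] ? xsaved : xs.val
    for i in 1:length(xs.val)
        # Accumulate (do not assign) into each input shadow.
        xs.dval[i] .+= OMEinsum.einsum_grad(OMEinsum.getixs(code.val), xval,
            OMEinsum.getiy(code.val), size_dict.val, conj(dy), i)
    end
    # One entry per argument (code, xs, size_dict); none of them is Active.
    return (nothing, nothing, nothing)
end

x = randn(3, 3)
gx = zero(x)
autodiff(ReverseWithPrimal, x -> sum(einsum(ein"ii->i", x, Dict('i' => 3))),
         Duplicated((x,), (gx,)))
```

Since ein"ii->i" extracts the diagonal, the function being differentiated is the trace of x. If the rule works, gx should therefore come out as the 3×3 identity matrix.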