Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add import_frule (reprised) #1333

Merged
merged 12 commits into from
May 7, 2024
Merged

add import_frule (reprised) #1333

merged 12 commits into from
May 7, 2024

Conversation

CarloLucibello
Copy link
Collaborator

@CarloLucibello CarloLucibello commented Mar 7, 2024

Continuation of #996, partially addressing #583. I made the following changes:

  • rebased
  • removed the wip import_rrule
  • moved everything to an extension (this required creating a dummy function in Enzyme.jl that is overloaded in the extension)
  • added some minimal tests

TODO

  • more test coverage
  • fix broken BatchDuplicated test
  • handle functions not defined in the Main module

@wsmoses
Copy link
Member

wsmoses commented Mar 11, 2024

Is there a reason the previous reverse mode one failed?

@CarloLucibello
Copy link
Collaborator Author

CarloLucibello commented Mar 13, 2024

It think it had a syntax error, but also since I'm not really sure what I'm doing I wanted to focus on one thing at a time.
Will file a PR next with the reverse mode.

Enzyme.@import_frule typeof(Base.sort) Any
for Tret in (Duplicated, DuplicatedNoNeed)
for Tx in (Duplicated, BatchDuplicated)
test_forward(sort, Tret, (x, Tx))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this fails for Tx == BatchDuplicated. Any suggestions?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the failure?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

julia> x = [1.0, 2.0, 0.0];

julia> Enzyme.@import_frule typeof(Base.sort)  Any

julia> test_forward(Base.sort, Duplicated, (x, BatchDuplicated))
test_forward: sort with return activity Duplicated on (::Vector{Float64}, BatchDuplicated): Error During Test at /Users/carlo/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:68
  Got exception outside of a @test
  DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 3 and 6
  Stacktrace:
    [1] _bcs1
      @ ./broadcast.jl:555 [inlined]
    [2] _bcs
      @ ./broadcast.jl:549 [inlined]
    [3] broadcast_shape
      @ ./broadcast.jl:543 [inlined]
    [4] combine_axes
      @ ./broadcast.jl:524 [inlined]
    [5] instantiate
      @ ./broadcast.jl:306 [inlined]
    [6] materialize
      @ ./broadcast.jl:903 [inlined]
    [7] (::FiniteDifferences.var"#85#86"{FiniteDifferences.var"#87#88"{EnzymeTestUtils.var"#fnew#28"{EnzymeTestUtils.var"#call_with_copy#38"{@NamedTuple{}}, Tuple{typeof(sort), Vector{Float64}}, Tuple{Bool, Bool}}, typeof(identity)}, Vector{Float64}, Vector{Float64}})(ε::Float64)
      @ FiniteDifferences ~/.julia/packages/FiniteDifferences/zWRHl/src/grad.jl:48
    [8] newf
      @ ~/.julia/packages/StaticArrays/EHHaF/src/broadcast.jl:186 [inlined]
    [9] macro expansion
      @ ~/.julia/packages/StaticArrays/EHHaF/src/broadcast.jl:135 [inlined]
   [10] __broadcast
      @ ~/.julia/packages/StaticArrays/EHHaF/src/broadcast.jl:123 [inlined]
   [11] _broadcast
      @ ~/.julia/packages/StaticArrays/EHHaF/src/broadcast.jl:119 [inlined]
   [12] copy
      @ ~/.julia/packages/StaticArrays/EHHaF/src/broadcast.jl:60 [inlined]
   [13] materialize
      @ ./broadcast.jl:903 [inlined]
   [14] _eval_function(m::FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}, f::FiniteDifferences.var"#85#86"{FiniteDifferences.var"#87#88"{EnzymeTestUtils.var"#fnew#28"{EnzymeTestUtils.var"#call_with_copy#38"{@NamedTuple{}}, Tuple{typeof(sort), Vector{Float64}}, Tuple{Bool, Bool}}, typeof(identity)}, Vector{Float64}, Vector{Float64}}, x::Float64, step::Float64)
      @ FiniteDifferences ~/.julia/packages/FiniteDifferences/zWRHl/src/methods.jl:249
   [15] _estimate_magnitudes(m::FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}, f::FiniteDifferences.var"#85#86"{FiniteDifferences.var"#87#88"{EnzymeTestUtils.var"#fnew#28"{EnzymeTestUtils.var"#call_with_copy#38"{@NamedTuple{}}, Tuple{typeof(sort), Vector{Float64}}, Tuple{Bool, Bool}}, typeof(identity)}, Vector{Float64}, Vector{Float64}}, x::Float64)
      @ FiniteDifferences ~/.julia/packages/FiniteDifferences/zWRHl/src/methods.jl:378
   [16] estimate_step(m::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::FiniteDifferences.var"#85#86"{FiniteDifferences.var"#87#88"{EnzymeTestUtils.var"#fnew#28"{EnzymeTestUtils.var"#call_with_copy#38"{@NamedTuple{}}, Tuple{typeof(sort), Vector{Float64}}, Tuple{Bool, Bool}}, typeof(identity)}, Vector{Float64}, Vector{Float64}}, x::Float64)
      @ FiniteDifferences ~/.julia/packages/FiniteDifferences/zWRHl/src/methods.jl:365
   [17] (::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}})(f::FiniteDifferences.var"#85#86"{FiniteDifferences.var"#87#88"{EnzymeTestUtils.var"#fnew#28"{EnzymeTestUtils.var"#call_with_copy#38"{@NamedTuple{}}, Tuple{typeof(sort), Vector{Float64}}, Tuple{Bool, Bool}}, typeof(identity)}, Vector{Float64}, Vector{Float64}}, x::Float64)
      @ FiniteDifferences ~/.julia/packages/FiniteDifferences/zWRHl/src/methods.jl:193
   [18] _jvp
      @ ~/.julia/packages/FiniteDifferences/zWRHl/src/grad.jl:48 [inlined]
   [19] jvp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::EnzymeTestUtils.var"#fnew#28"{EnzymeTestUtils.var"#call_with_copy#38"{@NamedTuple{}}, Tuple{typeof(sort), Vector{Float64}}, Tuple{Bool, Bool}}, ::Tuple{Vector{Float64}, Tuple{Vector{Float64}, Vector{Float64}}})
      @ FiniteDifferences ~/.julia/packages/FiniteDifferences/zWRHl/src/grad.jl:60
   [20] _fd_forward(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, rettype::Type, y::Vector{Float64}, activities::Tuple{Const{typeof(sort)}, BatchDuplicated{Vector{Float64}, 2}})
      @ EnzymeTestUtils ~/.julia/packages/EnzymeTestUtils/CV2ed/src/finite_difference_calls.jl:30
   [21] macro expansion
      @ ~/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:77 [inlined]
   [22] macro expansion
      @ ~/.julia/juliaup/julia-1.10.2+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [23] test_forward(f::Function, ret_activity::Type, args::Tuple{Vector{Float64}, UnionAll}; fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, fkwargs::@NamedTuple{}, rtol::Float64, atol::Float64, testset_name::Nothing)
      @ EnzymeTestUtils ~/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:70
   [24] test_forward(f::Function, ret_activity::Type, args::Tuple{Vector{Float64}, UnionAll})
      @ EnzymeTestUtils ~/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:53

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sethaxen this looks like a bug in enzymetestutils?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I know, no. As I asked in #1264 (comment), does Enzyme now allow one to mix Duplicated and BatchDuplicated? That used to cause an error. EnzymeTestUtils assumes these are not mixed and provides are_activities_compatible to skip cases in TestSets that would mix them.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, no you need one single batch width for the whole program (e.g. all duplicated, or all batchduplicated with the same width).

However, we do automatically upgrade a duplicated/dupicatednoneed return to whatever the width of the args were, if they were batch (since no data is in the return).

It would be nice for this shorthand to work. But indeed @CarloLucibello the alternate is testing {Duplicated, DuplicatedNoNeed} ret x {Const, Duplicated} input and {BatchDuplicated, BatchDuplicatedNoNeed} ret x {Const, BatchDuplicated} input

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Irrespectively, either we should upgrade testutils to handle this case (since here enzyme actually supports this, by upgrading to batchduplicated), or we should throw a nicer error here rather than bailing out in finite differences internals.

Copy link
Collaborator Author

@CarloLucibello CarloLucibello Mar 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. These 4 are fine

test_forward(Base.sort, Duplicated, (x, Duplicated))
test_forward(Base.sort, BatchDuplicated, (x, BatchDuplicated))
test_forward(Base.sort, DuplicatedNoNeed, (x, Duplicated))
test_forward(Base.sort, BatchDuplicatedNoNeed, (x, BatchDuplicated))

But anything involving Const errors

test_forward(Base.sort, DuplicatedNoNeed, (x, Const))
UndefVarError: `ChainRulesCore` not defined
  Stacktrace:
    [1] forward
      @ ~/.julia/dev/Enzyme/ext/EnzymeChainRulesCoreExt.jl:67
    [2] forward
      @ ~/.julia/dev/Enzyme/ext/EnzymeChainRulesCoreExt.jl:62 [inlined]
    [3] call_with_kwargs
      @ ~/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:64 [inlined]
    [4] fwddiffejulia_call_with_kwargs_10402wrap
      @ ~/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:0
    [5] macro expansion
      @ ~/.julia/dev/Enzyme/src/compiler.jl:5440 [inlined]
    [6] enzyme_call
      @ ~/.julia/dev/Enzyme/src/compiler.jl:5118 [inlined]
    [7] ForwardModeThunk
      @ ~/.julia/dev/Enzyme/src/compiler.jl:5003 [inlined]
    [8] autodiff
      @ ~/.julia/dev/Enzyme/src/Enzyme.jl:384 [inlined]
    [9] autodiff(::ForwardMode{FFIABI}, ::EnzymeTestUtils.var"#call_with_kwargs#39"{@NamedTuple{}}, ::Type{DuplicatedNoNeed}, ::Const{typeof(sort)}, ::Const{Vector{Float64}})
      @ Enzyme ~/.julia/dev/Enzyme/src/Enzyme.jl:287
   [10] macro expansion
      @ ~/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:79 [inlined]
   [11] macro expansion
      @ ~/.julia/juliaup/julia-1.10.2+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [12] test_forward(f::Function, ret_activity::Type, args::Tuple{Vector{Float64}, UnionAll}; fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, fkwargs::@NamedTuple{}, rtol::Float64, atol::Float64, testset_name::Nothing)
      @ EnzymeTestUtils ~/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:70
   [13] test_forward(f::Function, ret_activity::Type, args::Tuple{Vector{Float64}, UnionAll})
      @ EnzymeTestUtils ~/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:53
   [14] macro expansion
      @ ~/.julia/dev/Enzyme/test/ext/chainrulescore.jl:34 [inlined]
   [15] macro expansion
      @ ~/.julia/juliaup/julia-1.10.2+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [16] macro expansion
      @ ~/.julia/dev/Enzyme/test/ext/chainrulescore.jl:25 [inlined]
   [17] macro expansion
      @ ~/.julia/juliaup/julia-1.10.2+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [18] top-level scope
      @ ~/.julia/dev/Enzyme/test/ext/chainrulescore.jl:11

I get similar errors for

 test_forward(Base.sort, DuplicatedNoNeed, (x, Const))
 test_forward(Base.sort, Duplicated, (x, Const))
 test_forward(Base.sort, BatchDuplicatedNoNeed, (x, Const))
 test_forward(Base.sort, BatchDuplicated, (x, Const))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, yeah we definitely need that one to pass. Maybe try changing
:($ty <: Const ? ChainRulesCore.NoTangent() : $val.dval) into
:($ty <: Const ? $(ChainRulesCore.NoTangent()) : $val.dval) to fix it? (and similar throughout).

cc @vchuravy since Julia macros are not my forte

test/ext/chainrulescore.jl Outdated Show resolved Hide resolved
ext/EnzymeChainRulesCoreExt.jl Show resolved Hide resolved
ext/EnzymeChainRulesCoreExt.jl Show resolved Hide resolved
ext/EnzymeChainRulesCoreExt.jl Show resolved Hide resolved
ext/EnzymeChainRulesCoreExt.jl Outdated Show resolved Hide resolved
ext/EnzymeChainRulesCoreExt.jl Outdated Show resolved Hide resolved
test/ext/chainrulescore.jl Show resolved Hide resolved
test/runtests.jl Show resolved Hide resolved
test/runtests.jl Show resolved Hide resolved
test/runtests.jl Show resolved Hide resolved
test/runtests.jl Show resolved Hide resolved
end

quote
function EnzymeRules.forward(fn::FA, ::Type{RetAnnotation}, $(exprs...); kwargs...) where {RetAnnotation, FA<:Annotation{<:$(esc(fn))}, $(anns...)}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function EnzymeRules.forward(fn::FA, ::Type{RetAnnotation}, $(exprs...); kwargs...) where {RetAnnotation, FA<:Annotation{<:$(esc(fn))}, $(anns...)}
function EnzymeRules.forward(fn::FA, ::Type{RetAnnotation}, $(exprs...);
kwargs...) where {RetAnnotation,
FA<:Annotation{<:$(esc(fn))},
$(anns...)}

batchsize = same_or_one(1, $(vals...))
if batchsize == 1
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval
cres = $ChainRulesCore.frule((dfn, $(tangents...),), fn.val, $(primals...); kwargs...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
cres = $ChainRulesCore.frule((dfn, $(tangents...),), fn.val, $(primals...); kwargs...)
cres = $ChainRulesCore.frule((dfn, $(tangents...)), fn.val, $(primals...);
kwargs...)

ntuple(Val(batchsize)) do i
Base.@_inline_meta
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i]
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...)
return $ChainRulesCore.frule((dfn, $(tangentsi...)), fn.val,
$(primals...); kwargs...)

cres1 = begin
i = 1
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i]
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...)
$ChainRulesCore.frule((dfn, $(tangentsi...)), fn.val, $(primals...);
kwargs...)

dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i]
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...)
end
batches = ntuple(Val(batchsize-1)) do j
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
batches = ntuple(Val(batchsize-1)) do j
batches = ntuple(Val(batchsize - 1)) do j

Comment on lines +31 to +32
@testset "batch duplicated" begin
x = [1.0, 2.0, 0.0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@testset "batch duplicated" begin
x = [1.0, 2.0, 0.0]
@testset "batch duplicated" begin
x = [1.0, 2.0, 0.0]

Comment on lines +3027 to +3034
# TEST EXTENSIONS
@static if VERSION ≥ v"1.9-"
using SpecialFunctions
@testset "SpecialFunctions ext" begin
lgabsg(x) = SpecialFunctions.logabsgamma(x)[1]
test_scalar(lgabsg, 1.0; rtol = 1.0e-5, atol = 1.0e-5)
test_scalar(lgabsg, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5)
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
# TEST EXTENSIONS
@static if VERSION v"1.9-"
using SpecialFunctions
@testset "SpecialFunctions ext" begin
lgabsg(x) = SpecialFunctions.logabsgamma(x)[1]
test_scalar(lgabsg, 1.0; rtol = 1.0e-5, atol = 1.0e-5)
test_scalar(lgabsg, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5)
end
res = first(Enzyme.autodiff(Reverse, Base.hvcat_fill!, Const, Duplicated(ar, dar),
Active((1, 2.2, 3, 4.4, 5, 6.6))))

Comment on lines +3036 to +3038
using ChainRulesCore
@testset "ChainRulesCore ext" begin
include("ext/chainrulescore.jl")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
using ChainRulesCore
@testset "ChainRulesCore ext" begin
include("ext/chainrulescore.jl")
@test res[2][1] == 0
@test res[2][2] 2.0
@test res[2][3] 0
@test res[2][4] 4.0
@test res[2][5] 0
@test res[2][6] 6.0

Comment on lines +3040 to +3041
end

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end

end



Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
# TEST EXTENSIONS
@static if VERSION v"1.9-"
using SpecialFunctions
@testset "SpecialFunctions ext" begin
lgabsg(x) = SpecialFunctions.logabsgamma(x)[1]
test_scalar(lgabsg, 1.0; rtol=1.0e-5, atol=1.0e-5)
test_scalar(lgabsg, 1.0f0; rtol=1.0e-5, atol=1.0e-5)
end

Enzyme.@import_frule typeof(Base.sort) Any

test_forward(Base.sort, Duplicated, (x, Duplicated))
# Unsupported by EnzymeTestUtils
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sethaxen would it be possible to add support for duplicatednoneed and variants to enzymetestutils.

I'm going to approve/merge this without those, but it would be nice to enable later

@wsmoses wsmoses merged commit ff878b5 into main May 7, 2024
34 of 48 checks passed
@wsmoses wsmoses deleted the cl/importfrule branch May 7, 2024 22:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants