Skip to content

Commit

Permalink
!defer_within_autodiff -> within_autodiff_rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
jumerckx authored and vchuravy committed Jan 17, 2025
1 parent aa6ef9c commit 9f54068
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
inactive_rules::Bool
broadcast_rewrite::Bool

# When true, leave the check for within_autodiff to the handler.
defer_within_autodiff::Bool
# When false, leave the check for within_autodiff to the handler.
within_autodiff_rewrite::Bool

handler::T
end
Expand Down Expand Up @@ -173,7 +173,7 @@ function EnzymeInterpreter(
reverse_rules::Bool,
inactive_rules::Bool,
broadcast_rewrite::Bool = true,
defer_within_autodiff::Bool = false,
within_autodiff_rewrite::Bool = true,
handler = nothing
)
@assert world <= Base.get_world_counter()
Expand Down Expand Up @@ -234,7 +234,7 @@ function EnzymeInterpreter(
reverse_rules::Bool,
inactive_rules::Bool,
broadcast_rewrite::Bool,
defer_within_autodiff::Bool,
within_autodiff_rewrite::Bool,
handler
)
end
Expand All @@ -246,9 +246,9 @@ EnzymeInterpreter(
mode::API.CDerivativeMode,
inactive_rules::Bool,
broadcast_rewrite::Bool = true,
defer_within_autodiff::Bool = false,
within_autodiff_rewrite::Bool = true,
handler = nothing
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, defer_within_autodiff, handler)
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, within_autodiff_rewrite, handler)

function EnzymeInterpreter(interp::EnzymeInterpreter;

Check warning on line 253 in src/compiler/interpreter.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler/interpreter.jl#L253

Added line #L253 was not covered by tests
cache_or_token = (@static if HAS_INTEGRATED_CACHE
Expand All @@ -265,7 +265,7 @@ function EnzymeInterpreter(interp::EnzymeInterpreter;
reverse_rules = interp.reverse_rules,
inactive_rules = interp.inactive_rules,
broadcast_rewrite = interp.broadcast_rewrite,
defer_within_autodiff = interp.defer_within_autodiff,
within_autodiff_rewrite = interp.within_autodiff_rewrite,
handler = interp.handler)
return EnzymeInterpreter(

Check warning on line 270 in src/compiler/interpreter.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler/interpreter.jl#L270

Added line #L270 was not covered by tests
cache_or_token,
Expand All @@ -278,7 +278,7 @@ function EnzymeInterpreter(interp::EnzymeInterpreter;
reverse_rules,
inactive_rules,
broadcast_rewrite,
defer_within_autodiff,
within_autodiff_rewrite,
handler
)
end
Expand Down Expand Up @@ -949,7 +949,7 @@ function abstract_call_known(

(; fargs, argtypes) = arginfo

if !(interp.defer_within_autodiff) && f === Enzyme.within_autodiff
if interp.within_autodiff_rewrite && f === Enzyme.within_autodiff
if length(argtypes) != 1
@static if VERSION < v"1.11.0-"
return CallMeta(Union{}, Effects(), NoCallInfo())
Expand Down

0 comments on commit 9f54068

Please sign in to comment.