From a6e4cce9b28b311994e6945ebac16af94f31a492 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Sat, 13 Jul 2024 19:57:01 -0400 Subject: [PATCH] try with getmap lookup --- src/Reactant.jl | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index 257115c2..b9e1cf62 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -208,6 +208,10 @@ using Enzyme TracedSetPath = 5 end +@inline getmap(::Val{T}) where T = nothing +@inline getmap(::Val{T}, a, b, args...) where {T} = getmap(Val(T), args...) +@inline getmap(::Val{T}, ::Val{T}, ::Val{T2}, args...) where {T, T2} = T2 + @inline is_concrete_tuple(x::T2) where {T2} = (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) @inline function traced_type(val::Type{T}, seen::ST, ::Val{mode}) where {ST,T,mode} @@ -391,17 +395,18 @@ end return IdDict{iddict_name(T),traced_type(iddict_val(T), seen, Val(mode))} end - if Val(T) ∈ seen - return T + nextTy = getmap(Val(T), seen...) + if nextTy != nothing + return nextTy end - seen = (Val(T), seen...) + seen2 = (Val(T), Val(T), seen...) changed = false subTys = Type[] for f in 1:fieldcount(T) subT = fieldtype(T, f) - subTT = traced_type(subT, seen, Val(mode)) + subTT = traced_type(subT, seen2, Val(mode)) changed |= subT != subTT push!(subTys, subTT) end @@ -421,12 +426,13 @@ end end TT2 = Core.apply_type(T.name.wrapper, subParms...) + seen3 = (Val(T), Val(TT2), seen...) if fieldcount(T) == fieldcount(TT2) legal = true for f in 1:fieldcount(T) subT = fieldtype(T, f) subT2 = fieldtype(TT2, f) - subTT = traced_type(subT, seen, Val(mode)) + subTT = traced_type(subT, seen3, Val(mode)) if subT2 != subTT @show "illegal subs", subT2, subTT legal = false