Skip to content

Commit

Permalink
try with getmap lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jul 13, 2024
1 parent 03f3865 commit a6e4cce
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@ using Enzyme
TracedSetPath = 5
end

@inline getmap(::Val{T}) where T = nothing
@inline getmap(::Val{T}, a, b, args...) where {T} = getmap(Val(T), args...)
@inline getmap(::Val{T}, ::Val{T}, ::Val{T2}, args...) where {T, T2} = T2

@inline is_concrete_tuple(x::T2) where {T2} =
(x <: Tuple) && !(x === Tuple) && !(x isa UnionAll)
@inline function traced_type(val::Type{T}, seen::ST, ::Val{mode}) where {ST,T,mode}
Expand Down Expand Up @@ -391,17 +395,18 @@ end
return IdDict{iddict_name(T),traced_type(iddict_val(T), seen, Val(mode))}
end

if Val(T) seen
return T
nextTy = getmap(Val(T), seen...)
if nextTy != nothing
return nextTy
end

seen = (Val(T), seen...)
seen2 = (Val(T), Val(T), seen...)

changed = false
subTys = Type[]
for f in 1:fieldcount(T)
subT = fieldtype(T, f)
subTT = traced_type(subT, seen, Val(mode))
subTT = traced_type(subT, seen2, Val(mode))
changed |= subT != subTT
push!(subTys, subTT)
end
Expand All @@ -421,12 +426,13 @@ end
end

TT2 = Core.apply_type(T.name.wrapper, subParms...)
seen3 = (Val(T), Val(TT2), seen...)
if fieldcount(T) == fieldcount(TT2)
legal = true
for f in 1:fieldcount(T)
subT = fieldtype(T, f)
subT2 = fieldtype(TT2, f)
subTT = traced_type(subT, seen, Val(mode))
subTT = traced_type(subT, seen3, Val(mode))
if subT2 != subTT
@show "illegal subs", subT2, subTT
legal = false
Expand Down

0 comments on commit a6e4cce

Please sign in to comment.