-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Replace Cassette for an overlayed MethodTable
#40
Conversation
MethodTable
MethodTable
@wsmoses This seems like a problem of Enzyme.jl on Julia 1.11 https://github.com/EnzymeAD/Reactant.jl/actions/runs/9944022931/job/27469120868?pr=40#step:9:561 |
Oh yea, but that's an easy fix here, we should just not run the non-reactant autodiff code on Julia <= 1.10 |
In Julia <= 1.10? Wouldn't it be in Julia >=1.11? |
Oh yeah sorry that’s what I meant |
Okay, but where does that happen? I mean, we are overriding |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
11/10
I think I found the problem for the crash on With new implementation based on type parameters instead of runtime values, it works. ComparisonJulia 1.9julia> code_warntype(*, (Reactant.TracedRArray{Float32, (10,20), 2}, Reactant.TracedRArray{Float32, (20,5), 2}))
MethodInstance for *(::Reactant.TracedRArray{Float32, (10, 20), 2}, ::Reactant.TracedRArray{Float32, (20, 5), 2})
from *(lhs::Reactant.TracedRArray{ElType, Shape, 2}, rhs::Reactant.TracedRArray{ElType, Shape2, 2}) where {ElType, Shape, Shape2} @ Reactant /media/M2/mofeing/repos/Reactant.jl/src/overloads.jl:143
Static Parameters
ElType = Float32
Shape = (10, 20)
Shape2 = (20, 5)
Arguments
#self#::Core.Const(*)
lhs::Reactant.TracedRArray{Float32, (10, 20), 2}
rhs::Reactant.TracedRArray{Float32, (20, 5), 2}
Locals
res::Reactant.MLIR.IR.Value
precar::Reactant.MLIR.IR.Attribute
prec::Reactant.MLIR.IR.Attribute
dot_dimension_numbers::Reactant.MLIR.API.MlirAttribute
resty::Reactant.MLIR.IR.Type
rhsty::Reactant.MLIR.IR.Type
lhsty::Reactant.MLIR.IR.Type
Body::Reactant.TracedRArray{Float32, _A, 2} where _A
1 ─ %1 = Reactant.MLIR.IR::Core.Const(Reactant.MLIR.IR)
│ %2 = Base.getproperty(%1, :type)::Core.Const(Reactant.MLIR.IR.type)
...
│ %67 = (%65)(%66, res)::Reactant.TracedRArray{Float32, _A, 2} where _A
└── return %67 Julia 1.9 with new implementationjulia> code_warntype(*, (Reactant.TracedRArray{Float32, (10,20), 2}, Reactant.TracedRArray{Float32, (20,5), 2}))
MethodInstance for *(::Reactant.TracedRArray{Float32, (10, 20), 2}, ::Reactant.TracedRArray{Float32, (20, 5), 2})
from *(lhs::Reactant.TracedRArray{ElType, Shape, 2}, rhs::Reactant.TracedRArray{ElType, Shape2, 2}) where {ElType, Shape, Shape2} @ Reactant /media/M2/mofeing/repos/Reactant.jl/src/overloads.jl:143
Static Parameters
ElType = Float32
Shape = (10, 20)
Shape2 = (20, 5)
Arguments
#self#::Core.Const(*)
lhs::Reactant.TracedRArray{Float32, (10, 20), 2}
rhs::Reactant.TracedRArray{Float32, (20, 5), 2}
Locals
res::Reactant.MLIR.IR.Value
precar::Reactant.MLIR.IR.Attribute
prec::Reactant.MLIR.IR.Attribute
dot_dimension_numbers::Reactant.MLIR.API.MlirAttribute
resty::Reactant.MLIR.IR.Type
rhsty::Reactant.MLIR.IR.Type
lhsty::Reactant.MLIR.IR.Type
Body::Reactant.TracedRArray{Float32, (10, 5), 2}
1 ─ %1 = Reactant.MLIR.IR::Core.Const(Reactant.MLIR.IR)
│ %2 = Base.getproperty(%1, :type)::Core.Const(Reactant.MLIR.IR.type)
...
│ %63 = (%61)(%62, res)::Reactant.TracedRArray{Float32, (10, 5), 2}
└── return %63 |
The kwargs support can be postponed since |
CC @wsmoses @vchuravy
The only remaining thing is an alternative implementation ofBase.invoke_within
to effectively overlaying the methods with ourMethodTable
.Also, we might need to add some definitions based on Julia version due to changes in the
AbstractInterpreter
API.OpaqueClosure
call on Julia 1.9TypeError: in new, expected DataType, got Type{Symbol}
on Julia 1.11Support kwargs