Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace Cassette for an overlayed MethodTable #40

Merged
merged 18 commits into from
Jul 16, 2024
Merged

Replace Cassette for an overlayed MethodTable #40

merged 18 commits into from
Jul 16, 2024

Conversation

mofeing
Copy link
Collaborator

@mofeing mofeing commented Jul 13, 2024

CC @wsmoses @vchuravy

The only remaining thing is an alternative implementation of Base.invoke_within to effectively overlaying the methods with our MethodTable.

Also, we might need to add some definitions based on Julia version due to changes in the AbstractInterpreter API.

  • Fix OpaqueClosure call on Julia 1.9
  • Fix TypeError: in new, expected DataType, got Type{Symbol} on Julia 1.11
  • Support kwargs

@mofeing mofeing marked this pull request as draft July 13, 2024 16:54
@mofeing mofeing changed the title Replace Cassette with an overlayed MethodTable Replace Cassette for an overlayed MethodTable Jul 13, 2024
README.md Outdated Show resolved Hide resolved
src/Interpreter.jl Outdated Show resolved Hide resolved
@mofeing
Copy link
Collaborator Author

mofeing commented Jul 15, 2024

@wsmoses This seems like a problem of Enzyme.jl on Julia 1.11 https://github.com/EnzymeAD/Reactant.jl/actions/runs/9944022931/job/27469120868?pr=40#step:9:561

@wsmoses
Copy link
Member

wsmoses commented Jul 15, 2024

Oh yea, but that's an easy fix here, we should just not run the non-reactant autodiff code on Julia <= 1.10

@mofeing
Copy link
Collaborator Author

mofeing commented Jul 15, 2024

In Julia <= 1.10? Wouldn't it be in Julia >=1.11?

@wsmoses
Copy link
Member

wsmoses commented Jul 15, 2024

Oh yeah sorry that’s what I meant

@mofeing
Copy link
Collaborator Author

mofeing commented Jul 15, 2024

Okay, but where does that happen? I mean, we are overriding Enzyme.autodiff so the original method shouldn't get called in Reactant right?

test/basic.jl Outdated Show resolved Hide resolved
Copy link
Member

@wsmoses wsmoses left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

11/10

@mofeing
Copy link
Collaborator Author

mofeing commented Jul 16, 2024

I think I found the problem for the crash on @code_hlo W * x. It seems like the return type for Base.:* with TracedRArrays is type-unstable, and the interpreter doesn't like that on Julia 1.9 (still type-unstable on Julia 1.10 but don't know why it accepts it 🤷).
The error is somewhat obscure since after the interpretation it says that the local vars are also arguments of the function 🫠

With new implementation based on type parameters instead of runtime values, it works.

Comparison

Julia 1.9

julia> code_warntype(*, (Reactant.TracedRArray{Float32, (10,20), 2}, Reactant.TracedRArray{Float32, (20,5), 2}))
MethodInstance for *(::Reactant.TracedRArray{Float32, (10, 20), 2}, ::Reactant.TracedRArray{Float32, (20, 5), 2})
  from *(lhs::Reactant.TracedRArray{ElType, Shape, 2}, rhs::Reactant.TracedRArray{ElType, Shape2, 2}) where {ElType, Shape, Shape2} @ Reactant /media/M2/mofeing/repos/Reactant.jl/src/overloads.jl:143
Static Parameters
  ElType = Float32
  Shape = (10, 20)
  Shape2 = (20, 5)
Arguments
  #self#::Core.Const(*)
  lhs::Reactant.TracedRArray{Float32, (10, 20), 2}
  rhs::Reactant.TracedRArray{Float32, (20, 5), 2}
Locals
  res::Reactant.MLIR.IR.Value
  precar::Reactant.MLIR.IR.Attribute
  prec::Reactant.MLIR.IR.Attribute
  dot_dimension_numbers::Reactant.MLIR.API.MlirAttribute
  resty::Reactant.MLIR.IR.Type
  rhsty::Reactant.MLIR.IR.Type
  lhsty::Reactant.MLIR.IR.Type
Body::Reactant.TracedRArray{Float32, _A, 2} where _A
1%1  = Reactant.MLIR.IR::Core.Const(Reactant.MLIR.IR)
│   %2  = Base.getproperty(%1, :type)::Core.Const(Reactant.MLIR.IR.type)
...%67 = (%65)(%66, res)::Reactant.TracedRArray{Float32, _A, 2} where _A
└──       return %67

Julia 1.9 with new implementation

julia> code_warntype(*, (Reactant.TracedRArray{Float32, (10,20), 2}, Reactant.TracedRArray{Float32, (20,5), 2}))
MethodInstance for *(::Reactant.TracedRArray{Float32, (10, 20), 2}, ::Reactant.TracedRArray{Float32, (20, 5), 2})
  from *(lhs::Reactant.TracedRArray{ElType, Shape, 2}, rhs::Reactant.TracedRArray{ElType, Shape2, 2}) where {ElType, Shape, Shape2} @ Reactant /media/M2/mofeing/repos/Reactant.jl/src/overloads.jl:143
Static Parameters
  ElType = Float32
  Shape = (10, 20)
  Shape2 = (20, 5)
Arguments
  #self#::Core.Const(*)
  lhs::Reactant.TracedRArray{Float32, (10, 20), 2}
  rhs::Reactant.TracedRArray{Float32, (20, 5), 2}
Locals
  res::Reactant.MLIR.IR.Value
  precar::Reactant.MLIR.IR.Attribute
  prec::Reactant.MLIR.IR.Attribute
  dot_dimension_numbers::Reactant.MLIR.API.MlirAttribute
  resty::Reactant.MLIR.IR.Type
  rhsty::Reactant.MLIR.IR.Type
  lhsty::Reactant.MLIR.IR.Type
Body::Reactant.TracedRArray{Float32, (10, 5), 2}
1%1  = Reactant.MLIR.IR::Core.Const(Reactant.MLIR.IR)
│   %2  = Base.getproperty(%1, :type)::Core.Const(Reactant.MLIR.IR.type)
...%63 = (%61)(%62, res)::Reactant.TracedRArray{Float32, (10, 5), 2}
└──       return %63

@mofeing mofeing marked this pull request as ready for review July 16, 2024 13:13
@mofeing
Copy link
Collaborator Author

mofeing commented Jul 16, 2024

The kwargs support can be postponed since compile doesn't support yet kwargs.

@mofeing mofeing merged commit 906395e into main Jul 16, 2024
9 of 14 checks passed
@mofeing mofeing deleted the abs-interp branch July 16, 2024 13:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants