feat: compile neural operators using Reactant #52

Open
wants to merge 9 commits into
base: main

Conversation

avik-pal
Member

@avik-pal avik-pal commented Nov 3, 2024

Needs EnzymeAD/Reactant.jl#217 and a JLL bump.

@avik-pal
Member Author

avik-pal commented Nov 6, 2024

FNOs crash Julia; this needs investigation.

@avik-pal
Member Author

avik-pal commented Nov 8, 2024

Some more progress: with EnzymeAD/Reactant.jl#245, the forward pass of the FNO compiles. There still seems to be an issue with the gradient, though; see EnzymeAD/Reactant.jl#246.

@avik-pal
Member Author

Bypassed the rfft issue, but compiling the gradient now hits the error from EnzymeAD/Reactant.jl#238 (comment):

julia> ∇fno_compiled = @compile ∇fno(fno, ps, st, x)
error: 'complex.add' op operand #0 must be complex type with floating-point elements, but got 'tensor<16x64x64xcomplex<f32>>'
ERROR: "failed to run pass manager on module"
Stacktrace:
  [1] run!
    @ /mnt/software/lux/Reactant.jl/src/mlir/IR/Pass.jl:70 [inlined]
  [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String)
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:251
  [3] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}; optimize::Bool)
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:283
  [4] compile_mlir!
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:265 [inlined]
  [5] (::Reactant.Compiler.var"#30#32"{Bool, typeof(∇fno), Tuple{}})()
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:720
  [6] context!(f::Reactant.Compiler.var"#30#32"{Bool, typeof(∇fno), Tuple{}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
  [7] compile_xla(f::Function, args::Tuple{…}; client::Nothing, optimize::Bool)
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:717
  [8] compile_xla
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:712 [inlined]
  [9] compile(f::Function, args::Tuple{…}; client::Nothing, optimize::Bool, sync::Bool)
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:744
 [10] top-level scope
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:490
Some type information was truncated. Use `show(err)` to see complete types.

@wsmoses

wsmoses commented Nov 11, 2024

That should be resolved by the arith raise pass, which converts it to a stablehlo.add.

@avik-pal
Member Author

FNO gradient segfaults with the AddOp pattern rewrite:

[2238547] signal 11 (128): Segmentation fault
in expression starting at REPL[31]:1
unknown function (ip: 0x7de12527076b)
_ZL8readBitsPKcmm at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZNK4mlir17DenseElementsAttr18IntElementIteratordeEv at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZNK4mlir17DenseElementsAttr13getSplatValueIN4llvm5APIntEEENSt9enable_ifIXoontsrSt10is_base_ofINS_9AttributeET_E5valuesrSt7is_sameIS6_S7_E5valueES7_E4typeEv at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZNK12_GLOBAL__N_111AddSimplify15matchAndRewriteEN4mlir9stablehlo5AddOpERNS1_15PatternRewriterE at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZZN4mlir17PatternApplicator15matchAndRewriteEPNS_9OperationERNS_15PatternRewriterEN4llvm12function_refIFbRKNS_7PatternEEEENS6_IFvS9_EEENS6_IFNS5_13LogicalResultES9_EEEENKUlvE_clEv at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir17PatternApplicator15matchAndRewriteEPNS_9OperationERNS_15PatternRewriterEN4llvm12function_refIFbRKNS_7PatternEEEENS6_IFvS9_EEENS6_IFNS5_13LogicalResultES9_EEE at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN12_GLOBAL__N_126GreedyPatternRewriteDriver15processWorklistEv at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir28applyPatternsAndFoldGreedilyERNS_6RegionERKNS_23FrozenRewritePatternSetENS_19GreedyRewriteConfigEPb at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform15ApplyPatternsOp10applyToOneERNS0_17TransformRewriterEPNS_9OperationERNS0_21ApplyToEachResultListERNS0_14TransformStateE at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform6detail20applyTransformToEachINS0_15ApplyPatternsOpERN4llvm14iterator_rangeINS4_20filter_iterator_implIPKPNS_9OperationEZNKS0_14TransformState13getPayloadOpsENS_5ValueEEUlS8_E_St26bidirectional_iterator_tagEEEEEENS_27DiagnosedSilenceableFailureET_RNS0_17TransformRewriterEOT0_RNS4_15SmallVectorImplINS0_21ApplyToEachResultListEEERSB_ at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform20TransformEachOpTraitINS0_15ApplyPatternsOpEE5applyERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform6detail35TransformOpInterfaceInterfaceTraits5ModelINS0_15ApplyPatternsOpEE5applyEPKNS2_7ConceptEPNS_9OperationERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform20TransformOpInterface5applyERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform14TransformState14applyTransformENS0_20TransformOpInterfaceE at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZL18applySequenceBlockRN4mlir5BlockENS_9transform22FailurePropagationModeERNS2_14TransformStateERNS2_16TransformResultsE at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform15NamedSequenceOp5applyERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform6detail35TransformOpInterfaceInterfaceTraits5ModelINS0_15NamedSequenceOpEE5applyEPKNS2_7ConceptEPNS_9OperationERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform20TransformOpInterface5applyERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform14TransformState14applyTransformENS0_20TransformOpInterfaceE at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform15applyTransformsEPNS_9OperationENS0_20TransformOpInterfaceERKNS_11RaggedArrayIN4llvm12PointerUnionIJS2_NS_9AttributeENS_5ValueEEEEEERKNS0_16TransformOptionsEbNS5_12function_refIFvRNS0_14TransformStateEEEENSG_IFNS5_13LogicalResultESI_EEE at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform27applyTransformNamedSequenceENS_11RaggedArrayIN4llvm12PointerUnionIJPNS_9OperationENS_9AttributeENS_5ValueEEEEEENS0_20TransformOpInterfaceENS_8ModuleOpERKNS0_16TransformOptionsE at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN12_GLOBAL__N_115InterpreterPass14runOnOperationEv at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir6detail17OpToOpPassAdaptor3runEPNS_4PassEPNS_9OperationENS_15AnalysisManagerEbj at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir6detail17OpToOpPassAdaptor11runPipelineERNS_13OpPassManagerEPNS_9OperationENS_15AnalysisManagerEbjPNS_16PassInstrumentorEPKNS_19PassInstrumentation18PipelineParentInfoE at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir11PassManager9runPassesEPNS_9OperationENS_15AnalysisManagerE at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
_ZN4mlir11PassManager3runEPNS_9OperationE at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
mlirPassManagerRunOnOp at /mnt/.julia/artifacts/09acc4d2f34243cb2b9c2df1dca0d492dcb19681/lib/libReactantExtra.so (unknown line)
mlirPassManagerRunOnOp at /mnt/software/lux/Reactant.jl/src/mlir/libMLIR_h.jl:5853 [inlined]
run! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Pass.jl:74 [inlined]
#run_pass_pipeline!#1 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:264
run_pass_pipeline! at /mnt/software/lux/Reactant.jl/src/Compiler.jl:259 [inlined]
#compile_mlir!#8 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:310
compile_mlir! at /mnt/software/lux/Reactant.jl/src/Compiler.jl:290 [inlined]
#6 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:285 [inlined]
context! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
unknown function (ip: 0x7de05cd4a9e6)
#compile_mlir#5 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:283
compile_mlir at /mnt/software/lux/Reactant.jl/src/Compiler.jl:280
unknown function (ip: 0x7de05cd48d2d)
jl_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
do_call at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:126
eval_value at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:223
eval_stmt_value at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:174 [inlined]
eval_body at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:663
jl_interpret_toplevel_thunk at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:821
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
eval_body at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:625
eval_body at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:539
jl_interpret_toplevel_thunk at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:821
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
ijl_toplevel_eval_in at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:994
eval at ./boot.jl:430 [inlined]
eval at ./Base.jl:130 [inlined]
repleval at /home/avikpal/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:229
#112 at /home/avikpal/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:192 [inlined]
with_logstate at ./logging/logging.jl:522
with_logger at ./logging/logging.jl:632 [inlined]
#111 at /home/avikpal/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:193
unknown function (ip: 0x7de06113517f)
jl_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
jl_f__call_latest at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/builtins.c:875
#invokelatest#2 at ./essentials.jl:1055 [inlined]
invokelatest at ./essentials.jl:1052
unknown function (ip: 0x7de108100822)
jl_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
do_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/builtins.c:831
#64 at /home/avikpal/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:34
unknown function (ip: 0x7de10814b74f)
jl_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
start_task at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/task.c:1202
Allocations: 286246536 (Pool: 286240354; Big: 6182); GC: 184


2 participants