From c91d5a70ad2d2ee58caf580f5a62137982af63db Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 24 Sep 2024 14:15:38 +0200 Subject: [PATCH] add macro to create custom Ops also on aarch64 (#871) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mosè Giordano <765740+giordano@users.noreply.github.com> --- docs/examples/03-reduce.jl | 9 +++- docs/src/knownissues.md | 2 + docs/src/reference/advanced.md | 1 + src/operators.jl | 99 +++++++++++++++++++++++++++++++--- test/Project.toml | 8 +-- test/test_reduce.jl | 27 +++++----- 6 files changed, 120 insertions(+), 26 deletions(-) diff --git a/docs/examples/03-reduce.jl b/docs/examples/03-reduce.jl index 86d22be11..cfd3f9c99 100644 --- a/docs/examples/03-reduce.jl +++ b/docs/examples/03-reduce.jl @@ -31,10 +31,15 @@ function pool(S1::SummaryStat, S2::SummaryStat) SummaryStat(m,v,n) end +# Register the custom reduction operator. This is necessary only on platforms +# where Julia doesn't support closures as cfunctions (e.g. ARM), but can be used +# on all platforms for consistency. 
+MPI.@RegisterOp(pool, SummaryStat) + X = randn(10,3) .* [1,3,7]' # Perform a scalar reduction -summ = MPI.Reduce(SummaryStat(X), pool, root, comm) +summ = MPI.Reduce(SummaryStat(X), pool, comm; root) if MPI.Comm_rank(comm) == root @show summ.var @@ -42,7 +47,7 @@ end # Perform a vector reduction: # the reduction operator is applied elementwise -col_summ = MPI.Reduce(mapslices(SummaryStat,X,dims=1), pool, root, comm) +col_summ = MPI.Reduce(mapslices(SummaryStat,X,dims=1), pool, comm; root) if MPI.Comm_rank(comm) == root col_var = map(summ -> summ.var, col_summ) diff --git a/docs/src/knownissues.md b/docs/src/knownissues.md index 3cd84d067..179ccd6ed 100644 --- a/docs/src/knownissues.md +++ b/docs/src/knownissues.md @@ -210,3 +210,5 @@ However they have two limitations: * [Julia's C-compatible function pointers](https://docs.julialang.org/en/v1/manual/calling-c-and-fortran-code/index.html#Creating-C-Compatible-Julia-Function-Pointers-1) cannot be used where the `stdcall` calling convention is expected, which is the case for 32-bit Microsoft MPI, * closure cfunctions in Julia are based on LLVM trampolines, which are not supported on ARM architecture. + +As an alternative [`MPI.@RegisterOp`](@ref) may be used to statically register reduction operations. diff --git a/docs/src/reference/advanced.md b/docs/src/reference/advanced.md index 6440fd5ca..7d4ed0ab6 100644 --- a/docs/src/reference/advanced.md +++ b/docs/src/reference/advanced.md @@ -26,6 +26,7 @@ MPI.Types.duplicate ```@docs MPI.Op +MPI.@RegisterOp ``` ## Info objects diff --git a/src/operators.jl b/src/operators.jl index 5ded20f22..9f3fed798 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -16,6 +16,7 @@ associative, and if `iscommutative` is true, assumed to be commutative as well. 
- [`Allreduce!`](@ref)/[`Allreduce`](@ref) - [`Scan!`](@ref)/[`Scan`](@ref) - [`Exscan!`](@ref)/[`Exscan`](@ref) +- [`@RegisterOp`](@ref) """ mutable struct Op val::MPI_Op @@ -81,21 +82,36 @@ end function (w::OpWrapper{F,T})(_a::Ptr{Cvoid}, _b::Ptr{Cvoid}, _len::Ptr{Cint}, t::Ptr{MPI_Datatype}) where {F,T} len = unsafe_load(_len) - @assert isconcretetype(T) - a = Ptr{T}(_a) - b = Ptr{T}(_b) - for i = 1:len - unsafe_store!(b, w.f(unsafe_load(a,i), unsafe_load(b,i)), i) + if !isconcretetype(T) + concrete_T = to_type(Datatype(unsafe_load(t))) # Ptr might actually point to a Julia object so we could unsafe_pointer_to_objref? + else + concrete_T = T end + function copy(::Type{T}) where T + @assert isconcretetype(T) + a = Ptr{T}(_a) + b = Ptr{T}(_b) + for i = 1:len + unsafe_store!(b, w.f(unsafe_load(a,i), unsafe_load(b,i)), i) + end + end + copy(concrete_T) return nothing end - function Op(f, T=Any; iscommutative=false) @static if MPI_LIBRARY == "MicrosoftMPI" && Sys.WORD_SIZE == 32 - error("User-defined reduction operators are not supported on 32-bit Windows.\nSee https://github.com/JuliaParallel/MPI.jl/issues/246 for more details.") + error(""" + User-defined reduction operators are not supported on 32-bit Windows. + See https://github.com/JuliaParallel/MPI.jl/issues/246 for more details. + """) elseif Sys.ARCH ∈ (:aarch64, :ppc64le, :powerpc64le) || startswith(lowercase(String(Sys.ARCH)), "arm") - error("User-defined reduction operators are currently not supported on non-Intel architectures.\nSee https://github.com/JuliaParallel/MPI.jl/issues/404 for more details.") + error(""" + User-defined reduction operators are currently not supported on non-Intel architectures. + See https://github.com/JuliaParallel/MPI.jl/issues/404 for more details. + + You may want to use `@RegisterOp` to statically register `f`. 
+ """) end w = OpWrapper{typeof(f),T}(f) fptr = @cfunction($w, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{MPI_Datatype})) @@ -107,3 +123,70 @@ function Op(f, T=Any; iscommutative=false) finalizer(free, op) return op end + +""" + @RegisterOp(f, T) + +Register a custom operator [`Op`](@ref) using the function `f` statically. +On platforms like AArch64, Julia does not support runtime closures +being passed to C. The generic version of [`Op`](@ref) uses runtime closures +to support arbitrary functions being passed as MPI reduction operators. +`@RegisterOp` statically adds a function to the set of functions allowed +as an MPI operator. + +```julia +function my_reduce(x, y) + 2x+y-x +end +MPI.@RegisterOp(my_reduce, Int) +# ... +MPI.Reduce!(send_arr, recv_arr, my_reduce, MPI.COMM_WORLD; root=root) +#... +``` +!!! warning + Note that `@RegisterOp` works by introducing a new method of the generic function `Op`. + It can only be used as a top-level statement and may trigger method invalidations. + +!!! note + `T` can be `Any`, but this will lead to a runtime dispatch. +""" +macro RegisterOp(f, T) + name_wrapper = gensym(Symbol(f, :_, T, :_wrapper)) + name_fptr = gensym(Symbol(f, :_, T, :_ptr)) + name_module = gensym(Symbol(f, :_, T, :_module)) + # The gist is that we can use a method very similar to how we handle `min`/`max` + # but since this might be used from user code we can't use add_load_time_hook! + # this is why we introduce a new module that has a `__init__` function. 
+ # If this module approach is too costly for loading MPI.jl for internal use we could use + # `add_load_time_hook` + expr = quote + module $(name_module) + # import ..$f, ..$T + $(Expr(:import, Expr(:., :., :., f), Expr(:., :., :., T))) # Julia 1.6 struggles with import ..$f, ..$T + const $(name_wrapper) = $OpWrapper{typeof($f),$T}($f) + const $(name_fptr) = Ref(@cfunction($(name_wrapper), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{$MPI_Datatype}))) + function __init__() + $(name_fptr)[] = @cfunction($(name_wrapper), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{$MPI_Datatype})) + end + import MPI: Op + # we can't create a const Op since MPI needs to be initialized? + function Op(::typeof($f), ::Type{<:$T}; iscommutative=false) + op = Op($OP_NULL.val, $(name_fptr)[]) + # int MPI_Op_create(MPI_User_function* user_fn, int commute, MPI_Op* op) + $API.MPI_Op_create($(name_fptr)[], iscommutative, op) + + finalizer($free, op) + end + end + end + expr.head = :toplevel + esc(expr) +end + +@RegisterOp(min, Any) +@RegisterOp(max, Any) +@RegisterOp(+, Any) +@RegisterOp(*, Any) +@RegisterOp(&, Any) +@RegisterOp(|, Any) +@RegisterOp(⊻, Any) diff --git a/test/Project.toml b/test/Project.toml index f83c1fb10..52ccb4376 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,10 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[weakdeps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + [compat] AMDGPU = "0.6, 0.7, 0.8, 0.9, 1" CUDA = "3, 4, 5" @@ -16,7 +20,3 @@ TOML = "< 0.0.1, 1.0" [extras] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - -[weakdeps] -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/test/test_reduce.jl b/test/test_reduce.jl index dbf0c6e79..6ad5d24cf 100644 --- a/test/test_reduce.jl +++ 
b/test/test_reduce.jl @@ -59,10 +59,15 @@ if isroot @test sum_mesg == sz .* mesg end +function my_reduce(x, y) + 2x+y-x +end +MPI.@RegisterOp(my_reduce, Any) + if can_do_closures - operators = [MPI.SUM, +, (x,y) -> 2x+y-x] + operators = [MPI.SUM, +, my_reduce, (x,y) -> 2x+y-x] else - operators = [MPI.SUM, +] + operators = [MPI.SUM, +, my_reduce] end for T = [Int] @@ -117,19 +122,17 @@ end MPI.Barrier( MPI.COMM_WORLD ) -if can_do_closures - send_arr = [Double64(i)/10 for i = 1:10] - - result = MPI.Reduce(send_arr, +, MPI.COMM_WORLD; root=root) - if rank == root - @test result ≈ [Double64(sz*i)/10 for i = 1:10] rtol=sz*eps(Double64) - else - @test result === nothing - end +send_arr = [Double64(i)/10 for i = 1:10] - MPI.Barrier( MPI.COMM_WORLD ) +result = MPI.Reduce(send_arr, +, MPI.COMM_WORLD; root=root) +if rank == root + @test result ≈ [Double64(sz*i)/10 for i = 1:10] rtol=sz*eps(Double64) +else + @test result === nothing end +MPI.Barrier( MPI.COMM_WORLD ) + GC.gc() MPI.Finalize() @test MPI.Finalized()