From a28a0fc42a6f982a70588cf57e193e0fa1265601 Mon Sep 17 00:00:00 2001
From: James Schloss
Date: Mon, 16 Sep 2024 14:07:33 +0200
Subject: [PATCH] Revert "removing heuristic"

This reverts commit 54796ad22c6b74dcce0d48f1c69ad5eb8b0a5219.
---
 src/gpuarrays.jl | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/src/gpuarrays.jl b/src/gpuarrays.jl
index d15bac6..1cc3949 100644
--- a/src/gpuarrays.jl
+++ b/src/gpuarrays.jl
@@ -1,5 +1,8 @@
 # GPUArrays.jl interface
 
+import KernelAbstractions
+import KernelAbstractions: Backend
+
 #
 # Device functionality
 #
@@ -7,6 +10,30 @@
 
 ## execution
 
+# Launch heuristic that GPUArrays queries for kernels on the oneAPI back-end.
+#
+# Rounds `elements / elements_per_thread` up to obtain the ndrange, compiles `obj.f`
+# without launching it, and uses the result of `launch_configuration(kernel)` as `threads`.
+# Returns the named tuple `(threads=items, blocks=32)`.
+@inline function GPUArrays.launch_heuristic(::oneAPIBackend, obj::O, args::Vararg{Any,N};
+                                             elements::Int, elements_per_thread::Int) where {O,N}
+    # integer ceiling division: exact, with no detour through Float64
+    ndrange = cld(elements, elements_per_thread)
+    # the workgroup size and the dynamic flag returned here are not used
+    ndrange, _, iterspace, _ = KernelAbstractions.launch_config(obj, ndrange, nothing)
+
+    # this might not be the final context, since we may tune the workgroupsize
+    ctx = KernelAbstractions.mkcontext(obj, ndrange, iterspace)
+
+    kernel = @oneapi launch=false obj.f(ctx, args...)
+
+    items = launch_configuration(kernel)
+    # XXX: how many groups is a good number? the API doesn't tell us.
+    #      measured on a low-end IGP, 32 blocks seems like a good sweet spot.
+    #      note that this only matters for grid-stride kernels, like broadcast.
+    return (threads=items, blocks=32)
+end
+
 const GLOBAL_RNGs = Dict{ZeDevice,GPUArrays.RNG}()
 function GPUArrays.default_rng(::Type{<:oneArray})
     dev = device()