Skip to content

Commit

Permalink
Improve SegRed sequentialization in certain cases (#2054)
Browse files Browse the repository at this point in the history
This commit improves GPU code generation for non-segmented
reductions and for segmented reductions with large segments,
in the case where the reduction operator is both non-commutative
and primitive (see the description in the module header).
  • Loading branch information
sortraev authored Dec 5, 2023
1 parent 06732a5 commit 5234eb8
Show file tree
Hide file tree
Showing 5 changed files with 808 additions and 633 deletions.
4 changes: 2 additions & 2 deletions src/Futhark/CodeGen/ImpGen/GPU.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import Futhark.CodeGen.ImpGen.GPU.SegScan
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.MonadFreshNames
import Futhark.Util.IntegralExp (divUp, rem)
import Futhark.Util.IntegralExp (divUp, nextMul)
import Prelude hiding (quot, rem)

callKernelOperations :: Operations GPUMem HostEnv Imp.HostOp
Expand Down Expand Up @@ -200,7 +200,7 @@ checkLocalMemoryReqs code = do
-- These allocations will actually be padded to an 8-byte aligned
-- size, so we should take that into account when checking whether
-- they fit.
alignedSize x = x + ((8 - (x `rem` 8)) `rem` 8)
alignedSize x = nextMul x 8

withAcc ::
Pat LetDecMem ->
Expand Down
79 changes: 50 additions & 29 deletions src/Futhark/CodeGen/ImpGen/GPU/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ module Futhark.CodeGen.ImpGen.GPU.Base
fenceForArrays,
updateAcc,
genZeroes,
isPrimParam,
getChunkSize,

-- * Host-level bulk operations
sReplicate,
Expand Down Expand Up @@ -257,6 +259,30 @@ fenceForArrays = fmap (foldl' max Imp.FenceLocal) . mapM need
. entryArrayLoc
=<< lookupArray arr

-- | Does the given parameter have a primitive type?  This is simply
-- 'primType' applied to the parameter's type.
isPrimParam :: (Typed p) => Param p -> Bool
isPrimParam param = primType (paramType param)

-- | Given a list of parameter types, compute the largest available
-- chunk size for the parameters that we want chunking for, subject to
-- the available resources.
--
-- Only the primitive types in the list contribute to the computation;
-- non-primitive types are ignored.  The result is the smaller of a
-- memory-derived bound and a register-derived bound, and is never less
-- than 1.  If the list contains no primitive types at all, there is
-- nothing to derive a bound from, and the minimum chunk size of 1 is
-- returned (rather than failing on @maximum []@ or dividing by zero).
--
-- Used in SegScan.SinglePass.compileSegScan, and SegRed.compileSegRed
-- (with primitive non-commutative operators only).
getChunkSize :: (Num a) => [Type] -> a
getChunkSize types
  | null prim_types = 1
  | otherwise = fromInteger $ max 1 $ min mem_constraint reg_constraint
  where
    prim_types = map elemType $ filter primType types
    sizes = map primByteSize prim_types

    -- Total size in bytes of all primitive element types.
    sum_sizes = sum sizes
    -- Total size in units of 4 bytes, where every type counts as
    -- occupying at least 4 bytes.  This is at least 1 whenever
    -- 'prim_types' is non-empty, so the division in 'reg_constraint'
    -- is safe in the @otherwise@ branch.
    sum_sizes' = sum (map (max 4 . primByteSize) prim_types) `div` 4
    -- Safe: only demanded in the @otherwise@ branch, where 'sizes' is
    -- non-empty.
    max_size = maximum sizes

    mem_constraint = max k_mem sum_sizes `div` max_size
    reg_constraint = (k_reg - 1 - sum_sizes') `div` (2 * sum_sizes')

    -- TODO: Make these constants dynamic by querying device
    -- NOTE(review): presumably 'k_reg' is a per-thread register budget
    -- and 'k_mem' a per-thread memory budget; confirm units with the
    -- original author.
    k_reg = 64
    k_mem = 95

inBlockScan ::
KernelConstants ->
Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
Expand All @@ -275,7 +301,7 @@ inBlockScan constants seg_flag arrs_full_size lockstep_width block_size active a
splitAt (length actual_params `div` 2) actual_params
y_to_x =
forM_ (zip x_params y_params) $ \(x, y) ->
when (primType (paramType x)) $
when (isPrimParam x) $
copyDWIM (paramName x) [] (Var (paramName y)) []

-- Set initial y values
Expand Down Expand Up @@ -342,23 +368,21 @@ inBlockScan constants seg_flag arrs_full_size lockstep_width block_size active a
array_scan = not $ all primType $ lambdaReturnType scan_lam

readInitial p arr
| primType $ paramType p =
| isPrimParam p =
copyDWIMFix (paramName p) [] (Var arr) [ltid]
| otherwise =
copyDWIMFix (paramName p) [] (Var arr) [gtid]

readParam behind p arr
| primType $ paramType p =
| isPrimParam p =
copyDWIMFix (paramName p) [] (Var arr) [ltid - behind]
| otherwise =
copyDWIMFix (paramName p) [] (Var arr) [gtid - behind + arrs_full_size]

writeResult x y arr
| primType $ paramType x = do
copyDWIMFix arr [ltid] (Var $ paramName x) []
copyDWIM (paramName y) [] (Var $ paramName x) []
| otherwise =
copyDWIM (paramName y) [] (Var $ paramName x) []
writeResult x y arr = do
when (isPrimParam x) $
copyDWIMFix arr [ltid] (Var $ paramName x) []
copyDWIM (paramName y) [] (Var $ paramName x) []

groupScan ::
Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
Expand Down Expand Up @@ -421,13 +445,13 @@ groupScan seg_flag arrs_full_size w lam arrs = do
group_offset = sExt64 (kernelGroupId constants) * kernelGroupSize constants

writeBlockResult p arr
| primType $ paramType p =
| isPrimParam p =
copyDWIMFix arr [sExt64 block_id] (Var $ paramName p) []
| otherwise =
copyDWIMFix arr [group_offset + sExt64 block_id] (Var $ paramName p) []

readPrevBlockResult p arr
| primType $ paramType p =
| isPrimParam p =
copyDWIMFix (paramName p) [] (Var arr) [sExt64 block_id - 1]
| otherwise =
copyDWIMFix (paramName p) [] (Var arr) [group_offset + sExt64 block_id - 1]
Expand All @@ -440,7 +464,7 @@ groupScan seg_flag arrs_full_size w lam arrs = do
sComment "save correct values for first block" $
sWhen is_first_block $
forM_ (zip x_params arrs) $ \(x, arr) ->
unless (primType $ paramType x) $
unless (isPrimParam x) $
copyDWIMFix arr [arrs_full_size + group_offset + sExt64 block_size + ltid] (Var $ paramName x) []

barrier
Expand All @@ -467,7 +491,7 @@ groupScan seg_flag arrs_full_size w lam arrs = do
sComment "move correct values for first block back a block" $
sWhen is_first_block $
forM_ (zip x_params arrs) $ \(x, arr) ->
unless (primType $ paramType x) $
unless (isPrimParam x) $
copyDWIMFix
arr
[arrs_full_size + group_offset + ltid]
Expand Down Expand Up @@ -498,7 +522,7 @@ groupScan seg_flag arrs_full_size w lam arrs = do

write_final_result =
forM_ (zip x_params arrs) $ \(p, arr) ->
when (primType $ paramType p) $
when (isPrimParam p) $
copyDWIMFix arr [ltid] (Var $ paramName p) []

sComment "carry-in for every block except the first" $
Expand All @@ -512,7 +536,7 @@ groupScan seg_flag arrs_full_size w lam arrs = do
sComment "restore correct values for first block" $
sWhen (is_first_block .&&. ltid_in_bounds) $
forM_ (zip3 x_params y_params arrs) $ \(x, y, arr) ->
if primType (paramType y)
if isPrimParam y
then copyDWIMFix arr [ltid] (Var $ paramName y) []
else copyDWIMFix (paramName x) [] (Var arr) [arrs_full_size + group_offset + ltid]

Expand Down Expand Up @@ -550,17 +574,13 @@ groupReduceWithOffset offset w lam arrs = do
let i = local_tid + tvExp offset
copyDWIMFix (paramName param) [] (Var arr) [sExt64 i]

writeReduceOpResult param arr
| Prim _ <- paramType param =
copyDWIMFix arr [sExt64 local_tid] (Var $ paramName param) []
| otherwise =
pure ()
writeReduceOpResult param arr =
when (isPrimParam param) $
copyDWIMFix arr [sExt64 local_tid] (Var $ paramName param) []

writeArrayOpResult param arr
| Prim _ <- paramType param =
pure ()
| otherwise =
copyDWIMFix arr [sExt64 local_tid] (Var $ paramName param) []
writeArrayOpResult param arr =
unless (isPrimParam param) $
copyDWIMFix arr [sExt64 local_tid] (Var $ paramName param) []

let (reduce_acc_params, reduce_arr_params) =
splitAt (length arrs) $ lambdaParams lam
Expand Down Expand Up @@ -622,10 +642,11 @@ groupReduceWithOffset offset w lam arrs = do
cross_wave_reductions
errorsync

sComment "Copy array-typed operands to result array" $ do
sWhen (local_tid .==. 0) $
localOps threadOperations $
zipWithM_ writeArrayOpResult reduce_acc_params arrs
unless (all isPrimParam reduce_acc_params) $
sComment "Copy array-typed operands to result array" $
sWhen (local_tid .==. 0) $
localOps threadOperations $
zipWithM_ writeArrayOpResult reduce_acc_params arrs

compileThreadOp :: OpCompiler GPUMem KernelEnv Imp.KernelOp
compileThreadOp pat (Alloc size space) =
Expand Down
Loading

0 comments on commit 5234eb8

Please sign in to comment.