Skip to content

Commit

Permalink
Interchange histogram with inner maps during reverse AD (#2031)
Browse files Browse the repository at this point in the history
Also fixes a bug in code generation for segmented intra-group SegHists
requiring spinlocks.
  • Loading branch information
nhey authored Oct 25, 2023
1 parent db84f00 commit 65e8df9
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 4 deletions.
59 changes: 59 additions & 0 deletions src/Futhark/AD/Rev/Hist.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ module Futhark.AD.Rev.Hist
( diffMinMaxHist,
diffMulHist,
diffAddHist,
diffVecHist,
diffHist,
)
where
Expand Down Expand Up @@ -522,6 +523,64 @@ diffAddHist _ops x aux n add ne is vs w rf dst m = do
vs_bar <- letExp (baseString vs <> "_bar") $ Op $ Screma n [is] $ mapSOAC lam_vsbar
updateAdj vs vs_bar

-- Special case for vectorised combining operator. Rewrite
-- reduce_by_index dst (map2 op) nes is vss
-- to
-- map3 (\dst_col vss_col ne ->
-- reduce_by_index dst_col op ne is vss_col
-- ) (transpose dst) (transpose vss) nes |> transpose
-- before differentiating.
diffVecHist ::
VjpOps ->
VName ->
StmAux () ->
SubExp ->
Lambda SOACS ->
VName ->
VName ->
VName ->
SubExp ->
SubExp ->
VName ->
ADM () ->
ADM ()
diffVecHist ops x aux n op nes is vss w rf dst m = do
stms <- collectStms_ $ do
rank <- arrayRank <$> lookupType vss
let dims = [1, 0] ++ drop 2 [0 .. rank - 1]

dstT <- letExp "dstT" $ BasicOp $ Rearrange dims dst
vssT <- letExp "vssT" $ BasicOp $ Rearrange dims vss
t_dstT <- lookupType dstT
t_vssT <- lookupType vssT
t_nes <- lookupType nes

dst_col <- newParam "dst_col" $ rowType t_dstT
vss_col <- newParam "vss_col" $ rowType t_vssT
ne <- newParam "ne" $ rowType t_nes

f <- mkIdentityLambda (Prim int64 : lambdaReturnType op)
map_lam <-
mkLambda [dst_col, vss_col, ne] $ do
-- TODO Have to copy dst_col, but isn't it already unique?
dst_col_cpy <-
letExp "dst_col_cpy" . BasicOp $
Replicate mempty (Var $ paramName dst_col)
fmap (varsRes . pure) . letExp "col_res" $
Op $
Hist
n
[is, paramName vss_col]
[HistOp (Shape [w]) rf [dst_col_cpy] [Var $ paramName ne] op]
f
histT <-
letExp "histT" $
Op $
Screma (arraySize 0 t_dstT) [dstT, vssT, nes] $
mapSOAC map_lam
auxing aux . letBindNames [x] . BasicOp $ Rearrange dims histT
foldr (vjpStm ops) m stms

--
-- a step in the radix sort implementation
-- it assumes the key we are sorting
Expand Down
6 changes: 6 additions & 0 deletions src/Futhark/AD/Rev/SOAC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,12 @@ vjpSOAC ops pat aux (Hist n as histops f) m
let (is, vs) = splitAt (length histops) as
splitHist ops pat aux histops n is vs m
vjpSOAC ops pat aux (Hist n [is, vs] [histop] f) m
| isIdentityLambda f,
[x] <- patNames pat,
HistOp (Shape [w]) rf [dst] [Var ne] lam <- histop,
-- Note that the operator is vectorised, so `ne` cannot be a 'PrimValue'.
Just op <- mapOp lam =
diffVecHist ops x aux n op ne is vs w rf dst m
| isIdentityLambda f,
[x] <- patNames pat,
HistOp (Shape [w]) rf [dst] [ne] lam <- histop,
Expand Down
9 changes: 5 additions & 4 deletions src/Futhark/CodeGen/ImpGen/GPU/Group.hs
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,11 @@ compileFlatId space = do

-- Construct the necessary lock arrays for an intra-group histogram.
prepareIntraGroupSegHist ::
Shape ->
Count GroupSize SubExp ->
[HistOp GPUMem] ->
InKernelGen [[Imp.TExp Int64] -> InKernelGen ()]
prepareIntraGroupSegHist group_size =
prepareIntraGroupSegHist segments group_size =
fmap snd . mapAccumLM onOp Nothing
where
onOp l op = do
Expand All @@ -221,7 +222,7 @@ prepareIntraGroupSegHist group_size =
locks <- newVName "locks"

let num_locks = pe64 $ unCount group_size
dims = map pe64 $ shapeDims (histOpShape op <> histShape op)
dims = map pe64 $ shapeDims (segments <> histOpShape op <> histShape op)
l' = Locking locks 0 1 0 (pure . (`rem` num_locks) . flattenIndex dims)
locks_t = Array int32 (Shape [unCount group_size]) NoUniqueness

Expand Down Expand Up @@ -517,7 +518,7 @@ compileGroupOp pat (Inner (SegOp (SegRed lvl space ops _ body))) = do
sOp $ Imp.Barrier Imp.FenceLocal
compileGroupOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do
compileFlatId space
let (ltids, _dims) = unzip $ unSegSpace space
let (ltids, dims) = unzip $ unSegSpace space

-- We don't need the red_pes, because it is guaranteed by our type
-- rules that they occupy the same memory as the destinations for
Expand All @@ -527,7 +528,7 @@ compileGroupOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do
splitAt num_red_res $ patElems pat

group_size <- kernelGroupSizeCount . kernelConstants <$> askEnv
ops' <- prepareIntraGroupSegHist group_size ops
ops' <- prepareIntraGroupSegHist (Shape $ init dims) group_size ops

-- Ensure that all locks have been initialised.
sOp $ Imp.Barrier Imp.FenceLocal
Expand Down
32 changes: 32 additions & 0 deletions tests/ad/reducebyindexvecmin0.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
-- ==
-- entry: vecmin
-- compiled input { [5i64, 3i64, 2i64, 4i64, 3i64, 3i64, 4i64, 2i64, 2i64, 3i64]
-- [[8i32, 5i32, -2i32, 4i32, 6i32],
-- [12i32, 8i32, 7i32, 2i32, 6i32],
-- [3i32, 9i32, -2i32, 11i32, 1i32],
-- [7i32, 3i32, 12i32, 7i32, 10i32],
-- [9i32, 12i32, 4i32, 1i32, 8i32],
-- [7i32, -1i32, 11i32, 6i32, 10i32],
-- [-2i32, 6i32, 7i32, 1i32, 12i32],
-- [8i32, 0i32, 9i32, 6i32, 7i32],
-- [7i32, 3i32, 6i32, 7i32, 8i32],
-- [1i32, 4i32, 2i32, 9i32, 9i32]]
-- [[2i32, 2i32, 2i32, 2i32, 4i32],
-- [2i32, 1i32, 3i32, 3i32, 1i32],
-- [2i32, 5i32, 2i32, 4i32, 5i32],
-- [1i32, 2i32, 1i32, 1i32, 4i32],
-- [5i32, 1i32, 4i32, 4i32, 5i32],
-- [4i32, 1i32, 4i32, 5i32, 3i32]] }
-- output { [[4i32, 1i32, 4i32, 5i32, 3i32],
-- [0i32, 0i32, 0i32, 0i32, 4i32],
-- [2i32, 0i32, 2i32, 0i32, 5i32],
-- [0i32, 1i32, 0i32, 0i32, 5i32],
-- [0i32, 0i32, 0i32, 1i32, 0i32],
-- [0i32, 2i32, 0i32, 0i32, 0i32],
-- [5i32, 0i32, 4i32, 4i32, 0i32],
-- [0i32, 5i32, 0i32, 4i32, 0i32],
-- [0i32, 0i32, 0i32, 0i32, 0i32],
-- [1i32, 0i32, 1i32, 0i32, 0i32]] }

entry vecmin [n][d][bins] (is: [n]i64) (vs: [n][d]i32) (adj_out: [bins][d]i32) =
vjp (hist (map2 i32.min) (replicate d i32.highest) bins is) vs adj_out
32 changes: 32 additions & 0 deletions tests/ad/reducebyindexvecmul0.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
-- ==
-- entry: vecmul
-- compiled input { [5i64, 3i64, 2i64, 4i64, 3i64, 3i64, 4i64, 2i64, 2i64, 3i64]
-- [[8i32, 5i32, -2i32, 4i32, 6i32],
-- [12i32, 8i32, 7i32, 2i32, 6i32],
-- [3i32, 9i32, -2i32, 11i32, 1i32],
-- [7i32, 3i32, 12i32, 7i32, 10i32],
-- [9i32, 12i32, 4i32, 1i32, 8i32],
-- [7i32, -1i32, 11i32, 6i32, 10i32],
-- [-2i32, 6i32, 7i32, 1i32, 12i32],
-- [8i32, 0i32, 9i32, 6i32, 7i32],
-- [7i32, 3i32, 6i32, 7i32, 8i32],
-- [1i32, 4i32, 2i32, 9i32, 9i32]]
-- [[2i32, 2i32, 2i32, 2i32, 4i32],
-- [2i32, 1i32, 3i32, 3i32, 1i32],
-- [2i32, 5i32, 2i32, 4i32, 5i32],
-- [1i32, 2i32, 1i32, 1i32, 4i32],
-- [5i32, 1i32, 4i32, 4i32, 5i32],
-- [4i32, 1i32, 4i32, 5i32, 3i32]] }
-- output { [[4i32, 1i32, 4i32, 5i32, 3i32],
-- [63i32, -96i32, 88i32, 54i32, 2880i32],
-- [112i32, 0i32, 108i32, 168i32, 280i32],
-- [-10i32, 6i32, 28i32, 4i32, 60i32],
-- [84i32, -64i32, 154i32, 108i32, 2160i32],
-- [108i32, 768i32, 56i32, 18i32, 1728i32],
-- [35i32, 3i32, 48i32, 28i32, 50i32],
-- [42i32, 135i32, -24i32, 308i32, 40i32],
-- [48i32, 0i32, -36i32, 264i32, 35i32],
-- [756i32, -192i32, 308i32, 12i32, 1920i32]] }

entry vecmul [n][d][bins] (is: [n]i64) (vs: [n][d]i32) (adj_out: [bins][d]i32) =
vjp (hist (map2 (*)) (replicate d 1i32) bins is) vs adj_out

0 comments on commit 65e8df9

Please sign in to comment.