From 65e8df92e84e13b6954673a1af1ec502a5300624 Mon Sep 17 00:00:00 2001 From: Nikolaj Hey Hinnerskov Date: Wed, 25 Oct 2023 14:58:35 +0200 Subject: [PATCH] Interchange histogram with inner maps during reverse AD (#2031) Also fixes a bug in code generation for segmented intra-group SegHists requiring spinlocks. --- src/Futhark/AD/Rev/Hist.hs | 59 +++++++++++++++++++++++++ src/Futhark/AD/Rev/SOAC.hs | 6 +++ src/Futhark/CodeGen/ImpGen/GPU/Group.hs | 9 ++-- tests/ad/reducebyindexvecmin0.fut | 32 ++++++++++++++ tests/ad/reducebyindexvecmul0.fut | 32 ++++++++++++++ 5 files changed, 134 insertions(+), 4 deletions(-) create mode 100644 tests/ad/reducebyindexvecmin0.fut create mode 100644 tests/ad/reducebyindexvecmul0.fut diff --git a/src/Futhark/AD/Rev/Hist.hs b/src/Futhark/AD/Rev/Hist.hs index 4a208fa30f..8da2ed537f 100644 --- a/src/Futhark/AD/Rev/Hist.hs +++ b/src/Futhark/AD/Rev/Hist.hs @@ -8,6 +8,7 @@ module Futhark.AD.Rev.Hist ( diffMinMaxHist, diffMulHist, diffAddHist, + diffVecHist, diffHist, ) where @@ -522,6 +523,64 @@ diffAddHist _ops x aux n add ne is vs w rf dst m = do vs_bar <- letExp (baseString vs <> "_bar") $ Op $ Screma n [is] $ mapSOAC lam_vsbar updateAdj vs vs_bar +-- Special case for a vectorised combining operator: interchange the histogram with the inner map. Rewrite +-- reduce_by_index dst (map2 op) nes is vss +-- to +-- map3 (\dst_col vss_col ne -> +-- reduce_by_index dst_col op ne is vss_col +-- ) (transpose dst) (transpose vss) nes |> transpose +-- before differentiating. The rewritten statements are collected and each is then passed to 'vjpStm' through a right fold, with m as the innermost continuation. +diffVecHist :: + VjpOps -> + VName -> + StmAux () -> + SubExp -> + Lambda SOACS -> + VName -> + VName -> + VName -> + SubExp -> + SubExp -> + VName -> + ADM () -> + ADM () +diffVecHist ops x aux n op nes is vss w rf dst m = do + stms <- collectStms_ $ do + rank <- arrayRank <$> lookupType vss + let dims = [1, 0] ++ drop 2 [0 .. 
rank - 1] + + dstT <- letExp "dstT" $ BasicOp $ Rearrange dims dst + vssT <- letExp "vssT" $ BasicOp $ Rearrange dims vss + t_dstT <- lookupType dstT + t_vssT <- lookupType vssT + t_nes <- lookupType nes + + dst_col <- newParam "dst_col" $ rowType t_dstT + vss_col <- newParam "vss_col" $ rowType t_vssT + ne <- newParam "ne" $ rowType t_nes + + f <- mkIdentityLambda (Prim int64 : lambdaReturnType op) + map_lam <- + mkLambda [dst_col, vss_col, ne] $ do + -- TODO: dst_col is copied (Replicate with an empty shape) before being used as the Hist destination; is the copy needed if dst_col is already unique? + dst_col_cpy <- + letExp "dst_col_cpy" . BasicOp $ + Replicate mempty (Var $ paramName dst_col) + fmap (varsRes . pure) . letExp "col_res" $ + Op $ + Hist + n + [is, paramName vss_col] + [HistOp (Shape [w]) rf [dst_col_cpy] [Var $ paramName ne] op] + f + histT <- + letExp "histT" $ + Op $ + Screma (arraySize 0 t_dstT) [dstT, vssT, nes] $ + mapSOAC map_lam + auxing aux . letBindNames [x] . BasicOp $ Rearrange dims histT + foldr (vjpStm ops) m stms + -- -- a step in the radix sort implementation -- it assumes the key we are sorting diff --git a/src/Futhark/AD/Rev/SOAC.hs b/src/Futhark/AD/Rev/SOAC.hs index 3560cfe121..8a0d008b6e 100644 --- a/src/Futhark/AD/Rev/SOAC.hs +++ b/src/Futhark/AD/Rev/SOAC.hs @@ -156,6 +156,12 @@ vjpSOAC ops pat aux (Hist n as histops f) m let (is, vs) = splitAt (length histops) as splitHist ops pat aux histops n is vs m vjpSOAC ops pat aux (Hist n [is, vs] [histop] f) m + | isIdentityLambda f, + [x] <- patNames pat, + HistOp (Shape [w]) rf [dst] [Var ne] lam <- histop, + -- Note that the operator is vectorised, so `ne` cannot be a 'PrimValue'. 
+ Just op <- mapOp lam = + diffVecHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, diff --git a/src/Futhark/CodeGen/ImpGen/GPU/Group.hs b/src/Futhark/CodeGen/ImpGen/GPU/Group.hs index 7497739811..dd6da6a3ee 100644 --- a/src/Futhark/CodeGen/ImpGen/GPU/Group.hs +++ b/src/Futhark/CodeGen/ImpGen/GPU/Group.hs @@ -201,10 +201,11 @@ compileFlatId space = do -- Construct the necessary lock arrays for an intra-group histogram. prepareIntraGroupSegHist :: + Shape -> Count GroupSize SubExp -> [HistOp GPUMem] -> InKernelGen [[Imp.TExp Int64] -> InKernelGen ()] -prepareIntraGroupSegHist group_size = +prepareIntraGroupSegHist segments group_size = fmap snd . mapAccumLM onOp Nothing where onOp l op = do @@ -221,7 +222,7 @@ prepareIntraGroupSegHist group_size = locks <- newVName "locks" let num_locks = pe64 $ unCount group_size - dims = map pe64 $ shapeDims (histOpShape op <> histShape op) + dims = map pe64 $ shapeDims (segments <> histOpShape op <> histShape op) l' = Locking locks 0 1 0 (pure . (`rem` num_locks) . flattenIndex dims) locks_t = Array int32 (Shape [unCount group_size]) NoUniqueness @@ -517,7 +518,7 @@ compileGroupOp pat (Inner (SegOp (SegRed lvl space ops _ body))) = do sOp $ Imp.Barrier Imp.FenceLocal compileGroupOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do compileFlatId space - let (ltids, _dims) = unzip $ unSegSpace space + let (ltids, dims) = unzip $ unSegSpace space -- We don't need the red_pes, because it is guaranteed by our type -- rules that they occupy the same memory as the destinations for @@ -527,7 +528,7 @@ compileGroupOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do splitAt num_red_res $ patElems pat group_size <- kernelGroupSizeCount . kernelConstants <$> askEnv - ops' <- prepareIntraGroupSegHist group_size ops + ops' <- prepareIntraGroupSegHist (Shape $ init dims) group_size ops -- Ensure that all locks have been initialised. 
sOp $ Imp.Barrier Imp.FenceLocal diff --git a/tests/ad/reducebyindexvecmin0.fut b/tests/ad/reducebyindexvecmin0.fut new file mode 100644 index 0000000000..b03a95e682 --- /dev/null +++ b/tests/ad/reducebyindexvecmin0.fut @@ -0,0 +1,32 @@ +-- == +-- entry: vecmin +-- compiled input { [5i64, 3i64, 2i64, 4i64, 3i64, 3i64, 4i64, 2i64, 2i64, 3i64] +-- [[8i32, 5i32, -2i32, 4i32, 6i32], +-- [12i32, 8i32, 7i32, 2i32, 6i32], +-- [3i32, 9i32, -2i32, 11i32, 1i32], +-- [7i32, 3i32, 12i32, 7i32, 10i32], +-- [9i32, 12i32, 4i32, 1i32, 8i32], +-- [7i32, -1i32, 11i32, 6i32, 10i32], +-- [-2i32, 6i32, 7i32, 1i32, 12i32], +-- [8i32, 0i32, 9i32, 6i32, 7i32], +-- [7i32, 3i32, 6i32, 7i32, 8i32], +-- [1i32, 4i32, 2i32, 9i32, 9i32]] +-- [[2i32, 2i32, 2i32, 2i32, 4i32], +-- [2i32, 1i32, 3i32, 3i32, 1i32], +-- [2i32, 5i32, 2i32, 4i32, 5i32], +-- [1i32, 2i32, 1i32, 1i32, 4i32], +-- [5i32, 1i32, 4i32, 4i32, 5i32], +-- [4i32, 1i32, 4i32, 5i32, 3i32]] } +-- output { [[4i32, 1i32, 4i32, 5i32, 3i32], +-- [0i32, 0i32, 0i32, 0i32, 4i32], +-- [2i32, 0i32, 2i32, 0i32, 5i32], +-- [0i32, 1i32, 0i32, 0i32, 5i32], +-- [0i32, 0i32, 0i32, 1i32, 0i32], +-- [0i32, 2i32, 0i32, 0i32, 0i32], +-- [5i32, 0i32, 4i32, 4i32, 0i32], +-- [0i32, 5i32, 0i32, 4i32, 0i32], +-- [0i32, 0i32, 0i32, 0i32, 0i32], +-- [1i32, 0i32, 1i32, 0i32, 0i32]] } + +entry vecmin [n][d][bins] (is: [n]i64) (vs: [n][d]i32) (adj_out: [bins][d]i32) = + vjp (hist (map2 i32.min) (replicate d i32.highest) bins is) vs adj_out diff --git a/tests/ad/reducebyindexvecmul0.fut b/tests/ad/reducebyindexvecmul0.fut new file mode 100644 index 0000000000..f93a779e64 --- /dev/null +++ b/tests/ad/reducebyindexvecmul0.fut @@ -0,0 +1,32 @@ +-- == +-- entry: vecmul +-- compiled input { [5i64, 3i64, 2i64, 4i64, 3i64, 3i64, 4i64, 2i64, 2i64, 3i64] +-- [[8i32, 5i32, -2i32, 4i32, 6i32], +-- [12i32, 8i32, 7i32, 2i32, 6i32], +-- [3i32, 9i32, -2i32, 11i32, 1i32], +-- [7i32, 3i32, 12i32, 7i32, 10i32], +-- [9i32, 12i32, 4i32, 1i32, 8i32], +-- [7i32, -1i32, 11i32, 6i32, 
10i32], +-- [-2i32, 6i32, 7i32, 1i32, 12i32], +-- [8i32, 0i32, 9i32, 6i32, 7i32], +-- [7i32, 3i32, 6i32, 7i32, 8i32], +-- [1i32, 4i32, 2i32, 9i32, 9i32]] +-- [[2i32, 2i32, 2i32, 2i32, 4i32], +-- [2i32, 1i32, 3i32, 3i32, 1i32], +-- [2i32, 5i32, 2i32, 4i32, 5i32], +-- [1i32, 2i32, 1i32, 1i32, 4i32], +-- [5i32, 1i32, 4i32, 4i32, 5i32], +-- [4i32, 1i32, 4i32, 5i32, 3i32]] } +-- output { [[4i32, 1i32, 4i32, 5i32, 3i32], +-- [63i32, -96i32, 88i32, 54i32, 2880i32], +-- [112i32, 0i32, 108i32, 168i32, 280i32], +-- [-10i32, 6i32, 28i32, 4i32, 60i32], +-- [84i32, -64i32, 154i32, 108i32, 2160i32], +-- [108i32, 768i32, 56i32, 18i32, 1728i32], +-- [35i32, 3i32, 48i32, 28i32, 50i32], +-- [42i32, 135i32, -24i32, 308i32, 40i32], +-- [48i32, 0i32, -36i32, 264i32, 35i32], +-- [756i32, -192i32, 308i32, 12i32, 1920i32]] } + +entry vecmul [n][d][bins] (is: [n]i64) (vs: [n][d]i32) (adj_out: [bins][d]i32) = + vjp (hist (map2 (*)) (replicate d 1i32) bins is) vs adj_out