From 8e866249bb0898a9a6c93fa2665a53265970b15c Mon Sep 17 00:00:00 2001 From: Nikolaj Hinnerskov Date: Wed, 11 Oct 2023 10:59:00 +0200 Subject: [PATCH 1/9] irwim for histograms --- src/Futhark/AD/Rev/Hist.hs | 64 ++++++++++++++++++++++++++++++++++++++ src/Futhark/AD/Rev/SOAC.hs | 20 ++++++++++-- 2 files changed, 81 insertions(+), 3 deletions(-) diff --git a/src/Futhark/AD/Rev/Hist.hs b/src/Futhark/AD/Rev/Hist.hs index 4a208fa30f..6e782fe5d2 100644 --- a/src/Futhark/AD/Rev/Hist.hs +++ b/src/Futhark/AD/Rev/Hist.hs @@ -8,6 +8,7 @@ module Futhark.AD.Rev.Hist ( diffMinMaxHist, diffMulHist, diffAddHist, + diffVecHist, diffHist, ) where @@ -20,6 +21,8 @@ import Futhark.IR.SOACS import Futhark.Tools import Futhark.Transform.Rename +import Debug.Trace (trace, traceM) + getBinOpPlus :: PrimType -> BinOp getBinOpPlus (IntType x) = Add x OverflowUndef getBinOpPlus (FloatType f) = FAdd f @@ -522,6 +525,67 @@ diffAddHist _ops x aux n add ne is vs w rf dst m = do vs_bar <- letExp (baseString vs <> "_bar") $ Op $ Screma n [is] $ mapSOAC lam_vsbar updateAdj vs vs_bar + +-- Special case for vectorised combining operator. Rewrite +-- reduce_by_index dst (map2 op) nes is vss +-- to +-- map3 (\dst_col vss_col ne -> +-- reduce_by_index dst_col op ne is vss_col +-- ) (transpose dst) (transpose vss) nes |> transpose +-- before differentiating. +diffVecHist :: + VjpOps -> + VName -> + StmAux () -> + SubExp -> + Lambda SOACS -> + VName -> + VName -> + VName -> + SubExp -> + SubExp -> + VName -> + ADM () -> + ADM () +diffVecHist ops x aux n op nes is vss w rf dst m = do + stms <- collectStms_ $ do + rank <- arrayRank <$> lookupType vss + let dims = [1, 0] ++ drop 2 [0 .. rank - 1] + + dstT <- letExp "dstT" $ BasicOp $ Rearrange dims dst + vssT <- letExp "vssT" $ BasicOp $ Rearrange dims vss + t_dstT <- lookupType dstT + t_vssT <- lookupType vssT + t_nes <- lookupType nes + + dst_col <- newParam "dst_col" $ rowType t_dstT + vss_col <- newParam "vss_col" $ rowType t_vssT + ne <- newParam "ne" $ rowType t_nes + + f <- mkIdentityLambda [Prim int64, Prim $ elemType t_dstT] + map_lam <- + mkLambda [dst_col, vss_col, ne] $ do + -- TODO is copying here the right solution?? + dst_col_cpy <- + letExp "dst_col_cpy" . BasicOp $ + Replicate mempty (Var $ paramName dst_col) + -- TODO fmap varsRes . letTupExp "col_res" $ + fmap varsRes . fmap pure . letExp "col_res" $ + Op $ + Hist + n + [is, paramName vss_col] + [HistOp (Shape [w]) rf [dst_col_cpy] [Var $ paramName ne] op] + f + histT <- letExp "histT" $ Op $ + Screma (arraySize 0 t_dstT) [dstT, vssT, nes] $ mapSOAC map_lam + -- TODO If we change x to be of type Pat Type, use: + -- addStm $ Let x aux $ Op $ + -- Just matching the style of other functions here: + auxing aux . letBindNames [x] . BasicOp $ Rearrange dims histT + trace (prettyString stms) $ foldr (vjpStm ops) m stms + + -- -- a step in the radix sort implementation -- it assumes the key we are sorting diff --git a/src/Futhark/AD/Rev/SOAC.hs b/src/Futhark/AD/Rev/SOAC.hs index 3560cfe121..8311415871 100644 --- a/src/Futhark/AD/Rev/SOAC.hs +++ b/src/Futhark/AD/Rev/SOAC.hs @@ -14,6 +14,7 @@ import Futhark.Builder import Futhark.IR.SOACS import Futhark.Tools import Futhark.Util (chunks) +import Debug.Trace (trace) -- We split any multi-op scan or reduction into multiple operations so -- we can detect special cases. Post-AD, the result may be fused @@ -156,27 +157,40 @@ vjpSOAC ops pat aux (Hist n as histops f) m let (is, vs) = splitAt (length histops) as splitHist ops pat aux histops n is vs m vjpSOAC ops pat aux (Hist n [is, vs] [histop] f) m + | isIdentityLambda f, + [x] <- patNames pat, + HistOp (Shape [w]) rf [dst] [Var ne] lam <- histop, + -- Note that the operator is vectorised, so `ne` cannot be a 'PrimValue'. + -- TODO handle nested maps? + Just op <- mapOp lam = + trace "\n\n#diffVecHist input" $ + trace (prettyString (Hist n [is, vs] [histop] f)) $ + trace "#diffVecHist result" $ + diffVecHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', isMinMaxOp op = - diffMinMaxHist ops x aux n op ne is vs w rf dst m + trace "#diffMinMaxHist" $ diffMinMaxHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', isMulOp op = - diffMulHist ops x aux n op ne is vs w rf dst m + trace "\n\n#diffMulHist input" $ + trace (prettyString (Hist n [is, vs] [histop] f)) $ + trace "#diffMulHist result" $ + diffMulHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', isAddOp op = - diffAddHist ops x aux n lam ne is vs w rf dst m + trace "#diffAddHist" $ diffAddHist ops x aux n lam ne is vs w rf dst m vjpSOAC ops pat aux (Hist n as [histop] f) m | isIdentityLambda f, HistOp (Shape w) rf dst ne lam <- histop = do From bfe3d93ffd2edf848602f74bb24c46dbbdd858d0 Mon Sep 17 00:00:00 2001 From: Nikolaj Hinnerskov Date: Wed, 11 Oct 2023 11:03:23 +0200 Subject: [PATCH 2/9] Remove debug tracing. --- src/Futhark/AD/Rev/Hist.hs | 4 +--- src/Futhark/AD/Rev/SOAC.hs | 19 ++++--------------- 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/src/Futhark/AD/Rev/Hist.hs b/src/Futhark/AD/Rev/Hist.hs index 6e782fe5d2..054c07496d 100644 --- a/src/Futhark/AD/Rev/Hist.hs +++ b/src/Futhark/AD/Rev/Hist.hs @@ -21,8 +21,6 @@ import Futhark.IR.SOACS import Futhark.Tools import Futhark.Transform.Rename -import Debug.Trace (trace, traceM) - getBinOpPlus :: PrimType -> BinOp getBinOpPlus (IntType x) = Add x OverflowUndef getBinOpPlus (FloatType f) = FAdd f @@ -583,7 +581,7 @@ diffVecHist ops x aux n op nes is vss w rf dst m = do -- addStm $ Let x aux $ Op $ -- Just matching the style of other functions here: auxing aux . letBindNames [x] . BasicOp $ Rearrange dims histT - trace (prettyString stms) $ foldr (vjpStm ops) m stms + foldr (vjpStm ops) m stms -- diff --git a/src/Futhark/AD/Rev/SOAC.hs b/src/Futhark/AD/Rev/SOAC.hs index 8311415871..f1a8a56ccc 100644 --- a/src/Futhark/AD/Rev/SOAC.hs +++ b/src/Futhark/AD/Rev/SOAC.hs @@ -14,7 +14,6 @@ import Futhark.Builder import Futhark.IR.SOACS import Futhark.Tools import Futhark.Util (chunks) -import Debug.Trace (trace) -- We split any multi-op scan or reduction into multiple operations so -- we can detect special cases. Post-AD, the result may be fused @@ -162,35 +161,25 @@ vjpSOAC ops pat aux (Hist n [is, vs] [histop] f) m HistOp (Shape [w]) rf [dst] [Var ne] lam <- histop, -- Note that the operator is vectorised, so `ne` cannot be a 'PrimValue'. -- TODO handle nested maps? - Just op <- mapOp lam = - trace "\n\n#diffVecHist input" $ - trace (prettyString (Hist n [is, vs] [histop] f)) $ - trace "#diffVecHist result" $ - diffVecHist ops x aux n op ne is vs w rf dst m + Just op <- mapOp lam = diffVecHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', - isMinMaxOp op = - trace "#diffMinMaxHist" $ diffMinMaxHist ops x aux n op ne is vs w rf dst m + isMinMaxOp op = diffMinMaxHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', - isMulOp op = - trace "\n\n#diffMulHist input" $ - trace (prettyString (Hist n [is, vs] [histop] f)) $ - trace "#diffMulHist result" $ - diffMulHist ops x aux n op ne is vs w rf dst m + isMulOp op = diffMulHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', - isAddOp op = - trace "#diffAddHist" $ diffAddHist ops x aux n lam ne is vs w rf dst m + isAddOp op = diffAddHist ops x aux n lam ne is vs w rf dst m vjpSOAC ops pat aux (Hist n as [histop] f) m | isIdentityLambda f, HistOp (Shape w) rf dst ne lam <- histop = do From 2d7a042cc7795950e59e68cd0f6b74fb755de717 Mon Sep 17 00:00:00 2001 From: Nikolaj Hinnerskov Date: Wed, 11 Oct 2023 16:33:20 +0200 Subject: [PATCH 3/9] Formatting and such. --- src/Futhark/AD/Rev/Hist.hs | 25 +++++++++++-------------- src/Futhark/AD/Rev/SOAC.hs | 13 ++++++++----- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/Futhark/AD/Rev/Hist.hs b/src/Futhark/AD/Rev/Hist.hs index 054c07496d..c2489ea91c 100644 --- a/src/Futhark/AD/Rev/Hist.hs +++ b/src/Futhark/AD/Rev/Hist.hs @@ -523,7 +523,6 @@ diffAddHist _ops x aux n add ne is vs w rf dst m = do vs_bar <- letExp (baseString vs <> "_bar") $ Op $ Screma n [is] $ mapSOAC lam_vsbar updateAdj vs vs_bar - -- Special case for vectorised combining operator. Rewrite -- reduce_by_index dst (map2 op) nes is vss -- to @@ -563,27 +562,25 @@ diffVecHist ops x aux n op nes is vss w rf dst m = do f <- mkIdentityLambda [Prim int64, Prim $ elemType t_dstT] map_lam <- mkLambda [dst_col, vss_col, ne] $ do - -- TODO is copying here the right solution?? + -- TODO Have to copy dst_col, but isn't it already unique? dst_col_cpy <- letExp "dst_col_cpy" . BasicOp $ Replicate mempty (Var $ paramName dst_col) - -- TODO fmap varsRes . letTupExp "col_res" $ - fmap varsRes . fmap pure . letExp "col_res" $ + fmap (varsRes . pure) . letExp "col_res" $ Op $ Hist - n - [is, paramName vss_col] - [HistOp (Shape [w]) rf [dst_col_cpy] [Var $ paramName ne] op] - f - histT <- letExp "histT" $ Op $ - Screma (arraySize 0 t_dstT) [dstT, vssT, nes] $ mapSOAC map_lam - -- TODO If we change x to be of type Pat Type, use: - -- addStm $ Let x aux $ Op $ - -- Just matching the style of other functions here: + n + [is, paramName vss_col] + [HistOp (Shape [w]) rf [dst_col_cpy] [Var $ paramName ne] op] + f + histT <- + letExp "histT" $ + Op $ + Screma (arraySize 0 t_dstT) [dstT, vssT, nes] $ + mapSOAC map_lam auxing aux . letBindNames [x] . BasicOp $ Rearrange dims histT foldr (vjpStm ops) m stms - -- -- a step in the radix sort implementation -- it assumes the key we are sorting diff --git a/src/Futhark/AD/Rev/SOAC.hs b/src/Futhark/AD/Rev/SOAC.hs index f1a8a56ccc..8a0d008b6e 100644 --- a/src/Futhark/AD/Rev/SOAC.hs +++ b/src/Futhark/AD/Rev/SOAC.hs @@ -160,26 +160,29 @@ vjpSOAC ops pat aux (Hist n [is, vs] [histop] f) m [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [Var ne] lam <- histop, -- Note that the operator is vectorised, so `ne` cannot be a 'PrimValue'. - -- TODO handle nested maps? - Just op <- mapOp lam = diffVecHist ops x aux n op ne is vs w rf dst m + Just op <- mapOp lam = + diffVecHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', - isMinMaxOp op = diffMinMaxHist ops x aux n op ne is vs w rf dst m + isMinMaxOp op = + diffMinMaxHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', - isMulOp op = diffMulHist ops x aux n op ne is vs w rf dst m + isMulOp op = + diffMulHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', - isAddOp op = diffAddHist ops x aux n lam ne is vs w rf dst m + isAddOp op = + diffAddHist ops x aux n lam ne is vs w rf dst m vjpSOAC ops pat aux (Hist n as [histop] f) m | isIdentityLambda f, HistOp (Shape w) rf dst ne lam <- histop = do From 6a68f73991fda9e77aa978ee7f4ec3148ae832f1 Mon Sep 17 00:00:00 2001 From: Nikolaj Hinnerskov Date: Thu, 12 Oct 2023 10:12:55 +0200 Subject: [PATCH 4/9] Add two tiny tests. --- tests/ad/reducebyindexvecmin0.fut | 32 +++++++++++++++++++++++++++++++ tests/ad/reducebyindexvecmul0.fut | 32 +++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 tests/ad/reducebyindexvecmin0.fut create mode 100644 tests/ad/reducebyindexvecmul0.fut diff --git a/tests/ad/reducebyindexvecmin0.fut b/tests/ad/reducebyindexvecmin0.fut new file mode 100644 index 0000000000..ab972a26f9 --- /dev/null +++ b/tests/ad/reducebyindexvecmin0.fut @@ -0,0 +1,32 @@ +-- == +-- entry: vecmin +-- input { [5i64, 3i64, 2i64, 4i64, 3i64, 3i64, 4i64, 2i64, 2i64, 3i64] +-- [[8i32, 5i32, -2i32, 4i32, 6i32], +-- [12i32, 8i32, 7i32, 2i32, 6i32], +-- [3i32, 9i32, -2i32, 11i32, 1i32], +-- [7i32, 3i32, 12i32, 7i32, 10i32], +-- [9i32, 12i32, 4i32, 1i32, 8i32], +-- [7i32, -1i32, 11i32, 6i32, 10i32], +-- [-2i32, 6i32, 7i32, 1i32, 12i32], +-- [8i32, 0i32, 9i32, 6i32, 7i32], +-- [7i32, 3i32, 6i32, 7i32, 8i32], +-- [1i32, 4i32, 2i32, 9i32, 9i32]] +-- [[2i32, 2i32, 2i32, 2i32, 4i32], +-- [2i32, 1i32, 3i32, 3i32, 1i32], +-- [2i32, 5i32, 2i32, 4i32, 5i32], +-- [1i32, 2i32, 1i32, 1i32, 4i32], +-- [5i32, 1i32, 4i32, 4i32, 5i32], +-- [4i32, 1i32, 4i32, 5i32, 3i32]] } +-- output { [[4i32, 1i32, 4i32, 5i32, 3i32], +-- [0i32, 0i32, 0i32, 0i32, 4i32], +-- [2i32, 0i32, 2i32, 0i32, 5i32], +-- [0i32, 1i32, 0i32, 0i32, 5i32], +-- [0i32, 0i32, 0i32, 1i32, 0i32], +-- [0i32, 2i32, 0i32, 0i32, 0i32], +-- [5i32, 0i32, 4i32, 4i32, 0i32], +-- [0i32, 5i32, 0i32, 4i32, 0i32], +-- [0i32, 0i32, 0i32, 0i32, 0i32], +-- [1i32, 0i32, 1i32, 0i32, 0i32]] } + +entry vecmin [n][d][bins] (is: [n]i64) (vs: [n][d]i32) (adj_out: [bins][d]i32) = + vjp (hist (map2 i32.min) (replicate d i32.highest) bins is) vs adj_out diff --git a/tests/ad/reducebyindexvecmul0.fut b/tests/ad/reducebyindexvecmul0.fut new file mode 100644 index 0000000000..39558c6f96 --- /dev/null +++ b/tests/ad/reducebyindexvecmul0.fut @@ -0,0 +1,32 @@ +-- == +-- entry: vecmul +-- input { [5i64, 3i64, 2i64, 4i64, 3i64, 3i64, 4i64, 2i64, 2i64, 3i64] +-- [[8i32, 5i32, -2i32, 4i32, 6i32], +-- [12i32, 8i32, 7i32, 2i32, 6i32], +-- [3i32, 9i32, -2i32, 11i32, 1i32], +-- [7i32, 3i32, 12i32, 7i32, 10i32], +-- [9i32, 12i32, 4i32, 1i32, 8i32], +-- [7i32, -1i32, 11i32, 6i32, 10i32], +-- [-2i32, 6i32, 7i32, 1i32, 12i32], +-- [8i32, 0i32, 9i32, 6i32, 7i32], +-- [7i32, 3i32, 6i32, 7i32, 8i32], +-- [1i32, 4i32, 2i32, 9i32, 9i32]] +-- [[2i32, 2i32, 2i32, 2i32, 4i32], +-- [2i32, 1i32, 3i32, 3i32, 1i32], +-- [2i32, 5i32, 2i32, 4i32, 5i32], +-- [1i32, 2i32, 1i32, 1i32, 4i32], +-- [5i32, 1i32, 4i32, 4i32, 5i32], +-- [4i32, 1i32, 4i32, 5i32, 3i32]] } +-- output { [[4i32, 1i32, 4i32, 5i32, 3i32], +-- [63i32, -96i32, 88i32, 54i32, 2880i32], +-- [112i32, 0i32, 108i32, 168i32, 280i32], +-- [-10i32, 6i32, 28i32, 4i32, 60i32], +-- [84i32, -64i32, 154i32, 108i32, 2160i32], +-- [108i32, 768i32, 56i32, 18i32, 1728i32], +-- [35i32, 3i32, 48i32, 28i32, 50i32], +-- [42i32, 135i32, -24i32, 308i32, 40i32], +-- [48i32, 0i32, -36i32, 264i32, 35i32], +-- [756i32, -192i32, 308i32, 12i32, 1920i32]] } + +entry vecmul [n][d][bins] (is: [n]i64) (vs: [n][d]i32) (adj_out: [bins][d]i32) = + vjp (hist (map2 (*)) (replicate d 1i32) bins is) vs adj_out From 3080e403a64a24938033614b0df74ac1ef0ecabe Mon Sep 17 00:00:00 2001 From: Nikolaj Hinnerskov Date: Thu, 12 Oct 2023 10:42:06 +0200 Subject: [PATCH 5/9] Test compiled only; interpreter does not support AD. --- tests/ad/reducebyindexvecmin0.fut | 34 +++++++++++++++---------------- tests/ad/reducebyindexvecmul0.fut | 34 +++++++++++++++---------------- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/tests/ad/reducebyindexvecmin0.fut b/tests/ad/reducebyindexvecmin0.fut index ab972a26f9..b03a95e682 100644 --- a/tests/ad/reducebyindexvecmin0.fut +++ b/tests/ad/reducebyindexvecmin0.fut @@ -1,22 +1,22 @@ -- == -- entry: vecmin --- input { [5i64, 3i64, 2i64, 4i64, 3i64, 3i64, 4i64, 2i64, 2i64, 3i64] --- [[8i32, 5i32, -2i32, 4i32, 6i32], --- [12i32, 8i32, 7i32, 2i32, 6i32], --- [3i32, 9i32, -2i32, 11i32, 1i32], --- [7i32, 3i32, 12i32, 7i32, 10i32], --- [9i32, 12i32, 4i32, 1i32, 8i32], --- [7i32, -1i32, 11i32, 6i32, 10i32], --- [-2i32, 6i32, 7i32, 1i32, 12i32], --- [8i32, 0i32, 9i32, 6i32, 7i32], --- [7i32, 3i32, 6i32, 7i32, 8i32], --- [1i32, 4i32, 2i32, 9i32, 9i32]] --- [[2i32, 2i32, 2i32, 2i32, 4i32], --- [2i32, 1i32, 3i32, 3i32, 1i32], --- [2i32, 5i32, 2i32, 4i32, 5i32], --- [1i32, 2i32, 1i32, 1i32, 4i32], --- [5i32, 1i32, 4i32, 4i32, 5i32], --- [4i32, 1i32, 4i32, 5i32, 3i32]] } +-- compiled input { [5i64, 3i64, 2i64, 4i64, 3i64, 3i64, 4i64, 2i64, 2i64, 3i64] +-- [[8i32, 5i32, -2i32, 4i32, 6i32], +-- [12i32, 8i32, 7i32, 2i32, 6i32], +-- [3i32, 9i32, -2i32, 11i32, 1i32], +-- [7i32, 3i32, 12i32, 7i32, 10i32], +-- [9i32, 12i32, 4i32, 1i32, 8i32], +-- [7i32, -1i32, 11i32, 6i32, 10i32], +-- [-2i32, 6i32, 7i32, 1i32, 12i32], +-- [8i32, 0i32, 9i32, 6i32, 7i32], +-- [7i32, 3i32, 6i32, 7i32, 8i32], +-- [1i32, 4i32, 2i32, 9i32, 9i32]] +-- [[2i32, 2i32, 2i32, 2i32, 4i32], +-- [2i32, 1i32, 3i32, 3i32, 1i32], +-- [2i32, 5i32, 2i32, 4i32, 5i32], +-- [1i32, 2i32, 1i32, 1i32, 4i32], +-- [5i32, 1i32, 4i32, 4i32, 5i32], +-- [4i32, 1i32, 4i32, 5i32, 3i32]] } -- output { [[4i32, 1i32, 4i32, 5i32, 3i32], -- [0i32, 0i32, 0i32, 0i32, 4i32], -- [2i32, 0i32, 2i32, 0i32, 5i32], diff --git a/tests/ad/reducebyindexvecmul0.fut b/tests/ad/reducebyindexvecmul0.fut index 39558c6f96..f93a779e64 100644 --- a/tests/ad/reducebyindexvecmul0.fut +++ b/tests/ad/reducebyindexvecmul0.fut @@ -1,22 +1,22 @@ -- == -- entry: vecmul --- input { [5i64, 3i64, 2i64, 4i64, 3i64, 3i64, 4i64, 2i64, 2i64, 3i64] --- [[8i32, 5i32, -2i32, 4i32, 6i32], --- [12i32, 8i32, 7i32, 2i32, 6i32], --- [3i32, 9i32, -2i32, 11i32, 1i32], --- [7i32, 3i32, 12i32, 7i32, 10i32], --- [9i32, 12i32, 4i32, 1i32, 8i32], --- [7i32, -1i32, 11i32, 6i32, 10i32], --- [-2i32, 6i32, 7i32, 1i32, 12i32], --- [8i32, 0i32, 9i32, 6i32, 7i32], --- [7i32, 3i32, 6i32, 7i32, 8i32], --- [1i32, 4i32, 2i32, 9i32, 9i32]] --- [[2i32, 2i32, 2i32, 2i32, 4i32], --- [2i32, 1i32, 3i32, 3i32, 1i32], --- [2i32, 5i32, 2i32, 4i32, 5i32], --- [1i32, 2i32, 1i32, 1i32, 4i32], --- [5i32, 1i32, 4i32, 4i32, 5i32], --- [4i32, 1i32, 4i32, 5i32, 3i32]] } +-- compiled input { [5i64, 3i64, 2i64, 4i64, 3i64, 3i64, 4i64, 2i64, 2i64, 3i64] +-- [[8i32, 5i32, -2i32, 4i32, 6i32], +-- [12i32, 8i32, 7i32, 2i32, 6i32], +-- [3i32, 9i32, -2i32, 11i32, 1i32], +-- [7i32, 3i32, 12i32, 7i32, 10i32], +-- [9i32, 12i32, 4i32, 1i32, 8i32], +-- [7i32, -1i32, 11i32, 6i32, 10i32], +-- [-2i32, 6i32, 7i32, 1i32, 12i32], +-- [8i32, 0i32, 9i32, 6i32, 7i32], +-- [7i32, 3i32, 6i32, 7i32, 8i32], +-- [1i32, 4i32, 2i32, 9i32, 9i32]] +-- [[2i32, 2i32, 2i32, 2i32, 4i32], +-- [2i32, 1i32, 3i32, 3i32, 1i32], +-- [2i32, 5i32, 2i32, 4i32, 5i32], +-- [1i32, 2i32, 1i32, 1i32, 4i32], +-- [5i32, 1i32, 4i32, 4i32, 5i32], +-- [4i32, 1i32, 4i32, 5i32, 3i32]] } -- output { [[4i32, 1i32, 4i32, 5i32, 3i32], -- [63i32, -96i32, 88i32, 54i32, 2880i32], -- [112i32, 0i32, 108i32, 168i32, 280i32], From 2901d0b2e646d783e5df6e67ac4fc1c17339ec7e Mon Sep 17 00:00:00 2001 From: Nikolaj Hinnerskov Date: Thu, 12 Oct 2023 13:44:10 +0200 Subject: [PATCH 6/9] Debugging. Revert "Remove debug tracing." This reverts commit bfe3d93ffd2edf848602f74bb24c46dbbdd858d0. More debug printing. --- src/Futhark/AD/Rev/Hist.hs | 10 +++++++++- src/Futhark/AD/Rev/SOAC.hs | 27 ++++++++++++++++++++++----- src/Futhark/AD/Rev/Scan.hs | 4 +++- src/Futhark/IR/Prop/Reshape.hs | 2 +- 4 files changed, 35 insertions(+), 8 deletions(-) diff --git a/src/Futhark/AD/Rev/Hist.hs b/src/Futhark/AD/Rev/Hist.hs index c2489ea91c..1f2060369a 100644 --- a/src/Futhark/AD/Rev/Hist.hs +++ b/src/Futhark/AD/Rev/Hist.hs @@ -21,6 +21,8 @@ import Futhark.IR.SOACS import Futhark.Tools import Futhark.Transform.Rename +import Debug.Trace (trace, traceM) + getBinOpPlus :: PrimType -> BinOp getBinOpPlus (IntType x) = Add x OverflowUndef getBinOpPlus (FloatType f) = FAdd f @@ -502,6 +504,12 @@ diffAddHist _ops x aux n add ne is vs w rf dst m = do f <- mkIdentityLambda [Prim int64, t] auxing aux . letBindNames [x] . Op $ Hist n [is, vs] [HistOp (Shape [w]) rf [dst_cpy] [ne] add] f + -- ### + let stm = auxing aux . letBindNames [x] . Op $ Hist n [is, vs] [HistOp (Shape [w]) rf [dst_cpy] [ne] add] f + stm' <- collectStms stm + traceM (concatMap prettyString stm') + stm + -- ### m @@ -579,7 +587,7 @@ diffVecHist ops x aux n op nes is vss w rf dst m = do Screma (arraySize 0 t_dstT) [dstT, vssT, nes] $ mapSOAC map_lam auxing aux . letBindNames [x] . BasicOp $ Rearrange dims histT - foldr (vjpStm ops) m stms + trace (prettyString stms) $ foldr (vjpStm ops) m stms -- -- a step in the radix sort implementation diff --git a/src/Futhark/AD/Rev/SOAC.hs b/src/Futhark/AD/Rev/SOAC.hs index 8a0d008b6e..75d18726f4 100644 --- a/src/Futhark/AD/Rev/SOAC.hs +++ b/src/Futhark/AD/Rev/SOAC.hs @@ -14,6 +14,7 @@ import Futhark.Builder import Futhark.IR.SOACS import Futhark.Tools import Futhark.Util (chunks) +import Debug.Trace (trace) -- We split any multi-op scan or reduction into multiple operations so -- we can detect special cases. Post-AD, the result may be fused @@ -108,7 +109,10 @@ vjpSOAC ops pat aux soac@(Screma w as form) m diffScanAdd ops x w lam ne a | Just [Scan lam ne] <- isScanSOAC form, Just op <- mapOp lam = do - diffScanVec ops (patNames pat) aux w op ne as m + trace "\n\n#diffScanVec input" $ + trace (prettyString (Let pat aux $ Op soac)) $ + trace "#diffScanVec result" $ + diffScanVec ops (patNames pat) aux w op ne as m | Just scans <- isScanSOAC form, length scans > 1 = splitScanRed ops (scanSOAC, scanNeutral) (pat, aux, scans, w, as) m @@ -160,29 +164,42 @@ vjpSOAC ops pat aux (Hist n [is, vs] [histop] f) m [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [Var ne] lam <- histop, -- Note that the operator is vectorised, so `ne` cannot be a 'PrimValue'. + -- TODO handle nested maps? Just op <- mapOp lam = - diffVecHist ops x aux n op ne is vs w rf dst m + trace "\n\n#diffVecHist input" $ + trace (prettyString (Let pat aux $ Op (Hist n [is, vs] [histop] f))) $ + trace "#diffVecHist result" $ + diffVecHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', isMinMaxOp op = - diffMinMaxHist ops x aux n op ne is vs w rf dst m + trace "\n\n#diffMinMaxHist input" $ + trace (prettyString (Let pat aux $ Op (Hist n [is, vs] [histop] f))) $ + trace "#diffMinMaxHist result" $ + diffMinMaxHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', isMulOp op = - diffMulHist ops x aux n op ne is vs w rf dst m + trace "\n\n#diffMulHist input" $ + trace (prettyString (Hist n [is, vs] [histop] f)) $ + trace "#diffMulHist result" $ + diffMulHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', isAddOp op = - diffAddHist ops x aux n lam ne is vs w rf dst m + trace "\n\n#diffAddHist input" $ + trace (prettyString (Let pat aux $ Op (Hist n [is, vs] [histop] f))) $ + trace "#diffAddHist result" $ + diffAddHist ops x aux n lam ne is vs w rf dst m vjpSOAC ops pat aux (Hist n as [histop] f) m | isIdentityLambda f, HistOp (Shape w) rf dst ne lam <- histop = do diff --git a/src/Futhark/AD/Rev/Scan.hs b/src/Futhark/AD/Rev/Scan.hs index 11088afe4f..a7f47fb179 100644 --- a/src/Futhark/AD/Rev/Scan.hs +++ b/src/Futhark/AD/Rev/Scan.hs @@ -11,6 +11,8 @@ import Futhark.Tools import Futhark.Transform.Rename import Futhark.Util (chunk) +import Debug.Trace (trace) + data FirstOrSecond = WrtFirst | WrtSecond identityM :: Int -> Type -> ADM [[SubExp]] @@ -314,7 +316,7 @@ diffScanVec ops ys aux w lam ne as m = do ys transp_ys - foldr (vjpStm ops) m stmts + trace (prettyString stmts) $ foldr (vjpStm ops) m stmts diffScanAdd :: VjpOps -> VName -> SubExp -> Lambda SOACS -> SubExp -> VName -> ADM () diffScanAdd _ops ys n lam' ne as = do diff --git a/src/Futhark/IR/Prop/Reshape.hs b/src/Futhark/IR/Prop/Reshape.hs index 4e4acd9566..33dba88bd5 100644 --- a/src/Futhark/IR/Prop/Reshape.hs +++ b/src/Futhark/IR/Prop/Reshape.hs @@ -83,7 +83,7 @@ flattenIndex :: [num] -> num flattenIndex dims is - | length is /= length slicesizes = error "flattenIndex: length mismatch" + | length is /= length slicesizes = error ("flattenIndex: length mismatch @" ++ show (length dims) ++ " " ++ show (length is)) | otherwise = sum $ zipWith (*) is slicesizes where slicesizes = drop 1 $ sliceSizes dims From f12de675079efd10637e1cc22dff7442e88614cd Mon Sep 17 00:00:00 2001 From: Nikolaj Hinnerskov Date: Thu, 12 Oct 2023 15:34:35 +0200 Subject: [PATCH 7/9] Fix first bug. --- src/Futhark/AD/Rev/Hist.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Futhark/AD/Rev/Hist.hs b/src/Futhark/AD/Rev/Hist.hs index 1f2060369a..ad54f51a3f 100644 --- a/src/Futhark/AD/Rev/Hist.hs +++ b/src/Futhark/AD/Rev/Hist.hs @@ -567,7 +567,7 @@ diffVecHist ops x aux n op nes is vss w rf dst m = do vss_col <- newParam "vss_col" $ rowType t_vssT ne <- newParam "ne" $ rowType t_nes - f <- mkIdentityLambda [Prim int64, Prim $ elemType t_dstT] + f <- mkIdentityLambda (Prim int64 : lambdaReturnType op) map_lam <- mkLambda [dst_col, vss_col, ne] $ do -- TODO Have to copy dst_col, but isn't it already unique? From 8aeea0190d950f5c30f02446aabdb6852f150b9a Mon Sep 17 00:00:00 2001 From: Nikolaj Hinnerskov Date: Tue, 24 Oct 2023 16:16:58 +0200 Subject: [PATCH 8/9] Fix second bug. --- src/Futhark/CodeGen/ImpGen/GPU/Group.hs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/Futhark/CodeGen/ImpGen/GPU/Group.hs b/src/Futhark/CodeGen/ImpGen/GPU/Group.hs index 7497739811..dd6da6a3ee 100644 --- a/src/Futhark/CodeGen/ImpGen/GPU/Group.hs +++ b/src/Futhark/CodeGen/ImpGen/GPU/Group.hs @@ -201,10 +201,11 @@ compileFlatId space = do -- Construct the necessary lock arrays for an intra-group histogram. prepareIntraGroupSegHist :: + Shape -> Count GroupSize SubExp -> [HistOp GPUMem] -> InKernelGen [[Imp.TExp Int64] -> InKernelGen ()] -prepareIntraGroupSegHist group_size = +prepareIntraGroupSegHist segments group_size = fmap snd . mapAccumLM onOp Nothing where onOp l op = do @@ -221,7 +222,7 @@ prepareIntraGroupSegHist group_size = locks <- newVName "locks" let num_locks = pe64 $ unCount group_size - dims = map pe64 $ shapeDims (histOpShape op <> histShape op) + dims = map pe64 $ shapeDims (segments <> histOpShape op <> histShape op) l' = Locking locks 0 1 0 (pure . (`rem` num_locks) . flattenIndex dims) locks_t = Array int32 (Shape [unCount group_size]) NoUniqueness @@ -517,7 +518,7 @@ compileGroupOp pat (Inner (SegOp (SegRed lvl space ops _ body))) = do sOp $ Imp.Barrier Imp.FenceLocal compileGroupOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do compileFlatId space - let (ltids, _dims) = unzip $ unSegSpace space + let (ltids, dims) = unzip $ unSegSpace space -- We don't need the red_pes, because it is guaranteed by our type -- rules that they occupy the same memory as the destinations for @@ -527,7 +528,7 @@ compileGroupOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do splitAt num_red_res $ patElems pat group_size <- kernelGroupSizeCount . kernelConstants <$> askEnv - ops' <- prepareIntraGroupSegHist group_size ops + ops' <- prepareIntraGroupSegHist (Shape $ init dims) group_size ops -- Ensure that all locks have been initialised. sOp $ Imp.Barrier Imp.FenceLocal From bceb16d89f2c783e09e38bd1ffe3f04a1b1f2b7e Mon Sep 17 00:00:00 2001 From: Nikolaj Hinnerskov Date: Tue, 24 Oct 2023 16:17:15 +0200 Subject: [PATCH 9/9] Remove debug printing again. This reverts commit 2901d0b2e646d783e5df6e67ac4fc1c17339ec7e. --- src/Futhark/AD/Rev/Hist.hs | 10 +--------- src/Futhark/AD/Rev/SOAC.hs | 27 +++++---------------------- src/Futhark/AD/Rev/Scan.hs | 4 +--- src/Futhark/IR/Prop/Reshape.hs | 2 +- 4 files changed, 8 insertions(+), 35 deletions(-) diff --git a/src/Futhark/AD/Rev/Hist.hs b/src/Futhark/AD/Rev/Hist.hs index ad54f51a3f..8da2ed537f 100644 --- a/src/Futhark/AD/Rev/Hist.hs +++ b/src/Futhark/AD/Rev/Hist.hs @@ -21,8 +21,6 @@ import Futhark.IR.SOACS import Futhark.Tools import Futhark.Transform.Rename -import Debug.Trace (trace, traceM) - getBinOpPlus :: PrimType -> BinOp getBinOpPlus (IntType x) = Add x OverflowUndef getBinOpPlus (FloatType f) = FAdd f @@ -504,12 +502,6 @@ diffAddHist _ops x aux n add ne is vs w rf dst m = do f <- mkIdentityLambda [Prim int64, t] auxing aux . letBindNames [x] . Op $ Hist n [is, vs] [HistOp (Shape [w]) rf [dst_cpy] [ne] add] f - -- ### - let stm = auxing aux . letBindNames [x] . Op $ Hist n [is, vs] [HistOp (Shape [w]) rf [dst_cpy] [ne] add] f - stm' <- collectStms stm - traceM (concatMap prettyString stm') - stm - -- ### m @@ -587,7 +579,7 @@ diffVecHist ops x aux n op nes is vss w rf dst m = do Screma (arraySize 0 t_dstT) [dstT, vssT, nes] $ mapSOAC map_lam auxing aux . letBindNames [x] . BasicOp $ Rearrange dims histT - trace (prettyString stms) $ foldr (vjpStm ops) m stms + foldr (vjpStm ops) m stms -- -- a step in the radix sort implementation diff --git a/src/Futhark/AD/Rev/SOAC.hs b/src/Futhark/AD/Rev/SOAC.hs index 75d18726f4..8a0d008b6e 100644 --- a/src/Futhark/AD/Rev/SOAC.hs +++ b/src/Futhark/AD/Rev/SOAC.hs @@ -14,7 +14,6 @@ import Futhark.Builder import Futhark.IR.SOACS import Futhark.Tools import Futhark.Util (chunks) -import Debug.Trace (trace) -- We split any multi-op scan or reduction into multiple operations so -- we can detect special cases. Post-AD, the result may be fused @@ -109,10 +108,7 @@ vjpSOAC ops pat aux soac@(Screma w as form) m diffScanAdd ops x w lam ne a | Just [Scan lam ne] <- isScanSOAC form, Just op <- mapOp lam = do - trace "\n\n#diffScanVec input" $ - trace (prettyString (Let pat aux $ Op soac)) $ - trace "#diffScanVec result" $ - diffScanVec ops (patNames pat) aux w op ne as m + diffScanVec ops (patNames pat) aux w op ne as m | Just scans <- isScanSOAC form, length scans > 1 = splitScanRed ops (scanSOAC, scanNeutral) (pat, aux, scans, w, as) m @@ -164,42 +160,29 @@ vjpSOAC ops pat aux (Hist n [is, vs] [histop] f) m [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [Var ne] lam <- histop, -- Note that the operator is vectorised, so `ne` cannot be a 'PrimValue'. - -- TODO handle nested maps? Just op <- mapOp lam = - trace "\n\n#diffVecHist input" $ - trace (prettyString (Let pat aux $ Op (Hist n [is, vs] [histop] f))) $ - trace "#diffVecHist result" $ - diffVecHist ops x aux n op ne is vs w rf dst m + diffVecHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', isMinMaxOp op = - trace "\n\n#diffMinMaxHist input" $ - trace (prettyString (Let pat aux $ Op (Hist n [is, vs] [histop] f))) $ - trace "#diffMinMaxHist result" $ - diffMinMaxHist ops x aux n op ne is vs w rf dst m + diffMinMaxHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', isMulOp op = - trace "\n\n#diffMulHist input" $ - trace (prettyString (Hist n [is, vs] [histop] f)) $ - trace "#diffMulHist result" $ - diffMulHist ops x aux n op ne is vs w rf dst m + diffMulHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', isAddOp op = - trace "\n\n#diffAddHist input" $ - trace (prettyString (Let pat aux $ Op (Hist n [is, vs] [histop] f))) $ - trace "#diffAddHist result" $ - diffAddHist ops x aux n lam ne is vs w rf dst m + diffAddHist ops x aux n lam ne is vs w rf dst m vjpSOAC ops pat aux (Hist n as [histop] f) m | isIdentityLambda f, HistOp (Shape w) rf dst ne lam <- histop = do diff --git a/src/Futhark/AD/Rev/Scan.hs b/src/Futhark/AD/Rev/Scan.hs index a7f47fb179..11088afe4f 100644 --- a/src/Futhark/AD/Rev/Scan.hs +++ b/src/Futhark/AD/Rev/Scan.hs @@ -11,8 +11,6 @@ import Futhark.Tools import Futhark.Transform.Rename import Futhark.Util (chunk) -import Debug.Trace (trace) - data FirstOrSecond = WrtFirst | WrtSecond identityM :: Int -> Type -> ADM [[SubExp]] @@ -316,7 +314,7 @@ diffScanVec ops ys aux w lam ne as m = do ys transp_ys - trace (prettyString stmts) $ foldr (vjpStm ops) m stmts + foldr (vjpStm ops) m stmts diffScanAdd :: VjpOps -> VName -> SubExp -> Lambda SOACS -> SubExp -> VName -> ADM () diffScanAdd _ops ys n lam' ne as = do diff --git a/src/Futhark/IR/Prop/Reshape.hs b/src/Futhark/IR/Prop/Reshape.hs index 33dba88bd5..4e4acd9566 100644 --- a/src/Futhark/IR/Prop/Reshape.hs +++ b/src/Futhark/IR/Prop/Reshape.hs @@ -83,7 +83,7 @@ flattenIndex :: [num] -> num flattenIndex dims is - | length is /= length slicesizes = error ("flattenIndex: length mismatch @" ++ show (length dims) ++ " " ++ show (length is)) + | length is /= length slicesizes = error "flattenIndex: length mismatch" | otherwise = sum $ zipWith (*) is slicesizes where slicesizes = drop 1 $ sliceSizes dims