Skip to content

Commit

Permalink
Define and use rfoldD
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Jan 8, 2024
1 parent 130360e commit 60f92a5
Show file tree
Hide file tree
Showing 11 changed files with 520 additions and 6 deletions.
23 changes: 23 additions & 0 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,29 @@ data AstRanked :: AstSpanType -> RankedTensorType where
-> AstRanked s rn n
-> AstRanked s rm (1 + m)
-> AstRanked s rn n
AstFoldD :: forall rn n s. KnownNat n
=> ( AstVarName (AstRanked PrimalSpan) rn n
, [AstDynamicVarName]
, AstRanked PrimalSpan rn n )
-> AstRanked s rn n
-> Domains (AstRanked s) -- one rank higher than above
-> AstRanked s rn n
AstFoldDDer :: forall rn n s. KnownNat n
=> ( AstVarName (AstRanked PrimalSpan) rn n
, [AstDynamicVarName]
, AstRanked PrimalSpan rn n )
-> ( AstVarName (AstRanked PrimalSpan) rn n
, [AstDynamicVarName]
, AstVarName (AstRanked PrimalSpan) rn n
, [AstDynamicVarName]
, AstRanked PrimalSpan rn n )
-> ( AstVarName (AstRanked PrimalSpan) rn n
, AstVarName (AstRanked PrimalSpan) rn n
, [AstDynamicVarName]
, AstDomains PrimalSpan )
-> AstRanked s rn n
-> Domains (AstRanked s) -- one rank higher than above
-> AstRanked s rn n
AstScan :: forall rn rm n m s. (GoodScalar rm, KnownNat m, KnownNat n)
=> ( AstVarName (AstRanked PrimalSpan) rn n
, AstVarName (AstRanked PrimalSpan) rm m
Expand Down
28 changes: 28 additions & 0 deletions src/HordeAd/Core/AstInline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,21 @@ inlineAst memo v0 = case v0 of
in (memo2, Ast.AstFoldDer (nvar, mvar, v2)
(varDx, varDa, varn1, varm1, ast2)
(varDt2, nvar2, mvar2, doms2) x02 as2)
Ast.AstFoldD (nvar, mvar, v) x0 as ->
let (_, v2) = inlineAst EM.empty v
(memo1, x02) = inlineAst memo x0
(memo2, as2) = mapAccumR inlineAstDynamic memo1 as
in (memo2, Ast.AstFoldD (nvar, mvar, v2) x02 as2)
Ast.AstFoldDDer (nvar, mvar, v) (varDx, varDa, varn1, varm1, ast1)
(varDt2, nvar2, mvar2, doms) x0 as ->
let (_, v2) = inlineAst EM.empty v
(_, doms2) = inlineAstDomains EM.empty doms
(_, ast2) = inlineAst EM.empty ast1
(memo1, x02) = inlineAst memo x0
(memo2, as2) = mapAccumR inlineAstDynamic memo1 as
in (memo2, Ast.AstFoldDDer (nvar, mvar, v2)
(varDx, varDa, varn1, varm1, ast2)
(varDt2, nvar2, mvar2, doms2) x02 as2)
Ast.AstScan (nvar, mvar, v) x0 as ->
let (_, v2) = inlineAst EM.empty v
(memo1, x02) = inlineAst memo x0
Expand Down Expand Up @@ -600,6 +615,19 @@ unletAst env t = case t of
, unletAstDomains (emptyUnletEnv emptyADShare) doms )
(unletAst env x0)
(unletAst env as)
Ast.AstFoldD (nvar, mvar, v) x0 as ->
Ast.AstFoldD (nvar, mvar, unletAst (emptyUnletEnv emptyADShare) v)
(unletAst env x0)
(V.map (unletAstDynamic env) as)
Ast.AstFoldDDer (nvar, mvar, v) (varDx, varDa, varn1, varm1, ast1)
(varDt2, nvar2, mvar2, doms) x0 as ->
Ast.AstFoldDDer (nvar, mvar, unletAst (emptyUnletEnv emptyADShare) v)
( varDx, varDa, varn1, varm1
, unletAst (emptyUnletEnv emptyADShare) ast1 )
( varDt2, nvar2, mvar2
, unletAstDomains (emptyUnletEnv emptyADShare) doms )
(unletAst env x0)
(V.map (unletAstDynamic env) as)
Ast.AstScan (nvar, mvar, v) x0 as ->
Ast.AstScan (nvar, mvar, unletAst (emptyUnletEnv emptyADShare) v)
(unletAst env x0)
Expand Down
21 changes: 21 additions & 0 deletions src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,27 @@ interpretAst !env = \case
x0i = interpretAst @ranked env x0
asi = interpretAst @ranked env as
in rfoldDer @ranked f df rf x0i asi
AstFoldD @_ @n1 f@(_, vars, _) x0 as ->
let g :: forall f. ADReady f => f r n1 -> DomainsOf f -> f r n1
g = interpretLambda2D interpretAst EM.empty f
-- TODO: interpretLambda2D, and others, breaks sharing!
od = V.fromList $ map odFromVar vars
x0i = interpretAst env x0
asi = interpretAstDynamic env <$> as
in rfoldD g od x0i asi
AstFoldDDer @_ @n1 f0@(_, vars, _) df0 rf0 x0 as ->
let f :: forall f. ADReady f => f r n1 -> DomainsOf f -> f r n1
f = interpretLambda2D interpretAst EM.empty f0
df :: forall f. ADReady f
=> f r n1 -> DomainsOf f -> f r n1 -> DomainsOf f -> f r n1
df = interpretLambda4D interpretAst EM.empty df0
rf :: forall f. ADReady f
=> f r n1 -> f r n1 -> DomainsOf f -> DomainsOf f
rf = interpretLambda3D interpretAstDomains EM.empty rf0
od = V.fromList $ map odFromVar vars
x0i = interpretAst env x0
asi = interpretAstDynamic env <$> as
in rfoldDDer f df rf od x0i asi
AstScan @_ @rm @n1 @m f x0 as ->
let g :: forall f. ADReady f => f r n1 -> f rm m -> f r n1
g = interpretLambda2 interpretAst EM.empty f
Expand Down
58 changes: 58 additions & 0 deletions src/HordeAd/Core/AstPrettyPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ areAllArgsInts = \case
AstFwd{} -> False
AstFold{} -> False
AstFoldDer{} -> False
AstFoldD{} -> False
AstFoldDDer{} -> False
AstScan{} -> False
AstScanDer{} -> False
AstScanD{} -> False
Expand Down Expand Up @@ -405,6 +407,62 @@ printAstAux cfg d = \case
. printAst cfg 11 x0
. showString " "
. printAst cfg 11 as
AstFoldD (nvar, mvars, v) x0 as ->
showParen (d > 10)
$ showString "rfoldD "
. (showParen True
$ showString "\\"
. showString (printAstVarName (varRenames cfg) nvar)
. showString " "
. showListWith (showString
. printAstDynamicVarName (varRenames cfg)) mvars
. showString " -> "
. printAst cfg 0 v)
. showString " "
. printAst cfg 11 x0
. showString " "
. printDomainsAst cfg as
AstFoldDDer (nvar, mvars, v) (varDx, varsDa, varn1, varsm1, ast1)
(varDt2, nvar2, mvars2, doms) x0 as ->
showParen (d > 10)
$ showString "rfoldDDer "
. (showParen True
$ showString "\\"
. showString (printAstVarName (varRenames cfg) nvar)
. showString " "
. showListWith (showString
. printAstDynamicVarName (varRenames cfg)) mvars
. showString " -> "
. printAst cfg 0 v)
. showString " "
. (showParen True
$ showString "\\"
. showString (printAstVarName (varRenames cfg) varDx)
. showString " "
. showListWith (showString
. printAstDynamicVarName (varRenames cfg)) varsDa
. showString " "
. showString (printAstVarName (varRenames cfg) varn1)
. showString " "
. showListWith (showString
. printAstDynamicVarName (varRenames cfg)) varsm1
. showString " -> "
. printAst cfg 0 ast1)
. showString " "
. (showParen True
$ showString "\\"
. showString (printAstVarName (varRenames cfg) varDt2)
. showString " "
. showString (printAstVarName (varRenames cfg) nvar2)
. showString " "
. showListWith (showString
. printAstDynamicVarName (varRenames cfg)) mvars2
. showString " -> "
. printAstDomains cfg 0 doms)
. showString " "
. printAst cfg 11 x0
. showString " "
. printDomainsAst cfg as
AstScan (nvar, mvar, v) x0 as ->
showParen (d > 10)
$ showString "rscan "
Expand Down
28 changes: 28 additions & 0 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ simplifyStepNonIndex t = case t of
Ast.AstFwd{} -> t
Ast.AstFold{} -> t
Ast.AstFoldDer{} -> t
Ast.AstFoldD{} -> t
Ast.AstFoldDDer{} -> t
Ast.AstScan{} -> t
Ast.AstScanDer{} -> t
Ast.AstScanD{} -> t
Expand Down Expand Up @@ -430,6 +432,8 @@ astIndexROrStepOnly stepOnly v0 ix@(i1 :. (rest1 :: AstIndex m1)) =
Ast.AstFwd{} -> Ast.AstIndex v0 ix
Ast.AstFold{} -> Ast.AstIndex v0 ix -- normal form
Ast.AstFoldDer{} -> Ast.AstIndex v0 ix -- normal form
Ast.AstFoldD{} -> Ast.AstIndex v0 ix -- normal form
Ast.AstFoldDDer{} -> Ast.AstIndex v0 ix -- normal form
Ast.AstScan{} -> Ast.AstIndex v0 ix -- normal form
Ast.AstScanDer{} -> Ast.AstIndex v0 ix -- normal form
Ast.AstScanD{} -> Ast.AstIndex v0 ix -- normal form
Expand Down Expand Up @@ -761,6 +765,8 @@ astGatherROrStepOnly stepOnly sh0 v0 (vars0, ix0) =
Ast.AstFwd{} -> Ast.AstGather sh4 v4 (vars4, ix4)
Ast.AstFold{} -> Ast.AstGather sh4 v4 (vars4, ix4) -- normal form
Ast.AstFoldDer{} -> Ast.AstGather sh4 v4 (vars4, ix4) -- normal form
Ast.AstFoldD{} -> Ast.AstGather sh4 v4 (vars4, ix4) -- normal form
Ast.AstFoldDDer{} -> Ast.AstGather sh4 v4 (vars4, ix4) -- normal form
Ast.AstScan{} -> Ast.AstGather sh4 v4 (vars4, ix4) -- normal form
Ast.AstScanDer{} -> Ast.AstGather sh4 v4 (vars4, ix4) -- normal form
Ast.AstScanD{} -> Ast.AstGather sh4 v4 (vars4, ix4) -- normal form
Expand Down Expand Up @@ -1513,6 +1519,8 @@ astPrimalPart t = case t of
Ast.AstFwd{} -> Ast.AstPrimalPart t -- the other only normal form
Ast.AstFold{} -> Ast.AstPrimalPart t
Ast.AstFoldDer{} -> Ast.AstPrimalPart t
Ast.AstFoldD{} -> Ast.AstPrimalPart t
Ast.AstFoldDDer{} -> Ast.AstPrimalPart t
Ast.AstScan{} -> Ast.AstPrimalPart t
Ast.AstScanDer{} -> Ast.AstPrimalPart t
Ast.AstScanD{} -> Ast.AstPrimalPart t
Expand Down Expand Up @@ -1590,6 +1598,8 @@ astDualPart t = case t of
Ast.AstFwd{} -> Ast.AstDualPart t
Ast.AstFold{} -> Ast.AstDualPart t
Ast.AstFoldDer{} -> Ast.AstDualPart t
Ast.AstFoldD{} -> Ast.AstDualPart t
Ast.AstFoldDDer{} -> Ast.AstDualPart t
Ast.AstScan{} -> Ast.AstDualPart t
Ast.AstScanDer{} -> Ast.AstDualPart t
Ast.AstScanD{} -> Ast.AstDualPart t
Expand Down Expand Up @@ -1874,6 +1884,15 @@ simplifyAst t = case t of
(varDx, varDa, varn1, varm1, simplifyAst ast1)
(varDt2, nvar2, mvar2, simplifyAstDomains doms)
(simplifyAst x0) (simplifyAst as)
Ast.AstFoldD (nvar, mvar, v) x0 as ->
Ast.AstFoldD (nvar, mvar, simplifyAst v) (simplifyAst x0)
(V.map simplifyAstDynamic as)
Ast.AstFoldDDer (nvar, mvar, v) (varDx, varDa, varn1, varm1, ast1)
(varDt2, nvar2, mvar2, doms) x0 as ->
Ast.AstFoldDDer (nvar, mvar, simplifyAst v)
(varDx, varDa, varn1, varm1, simplifyAst ast1)
(varDt2, nvar2, mvar2, simplifyAstDomains doms)
(simplifyAst x0) (V.map simplifyAstDynamic as)
Ast.AstScan (nvar, mvar, v) x0 as ->
Ast.AstScan (nvar, mvar, simplifyAst v) (simplifyAst x0) (simplifyAst as)
Ast.AstScanDer (nvar, mvar, v) (varDx, varDa, varn1, varm1, ast1)
Expand Down Expand Up @@ -2476,6 +2495,15 @@ substitute1Ast i var v1 = case v1 of
(Nothing, Nothing) -> Nothing
(mx0, mas) ->
Just $ Ast.AstFoldDer f df dr (fromMaybe x0 mx0) (fromMaybe as mas)
Ast.AstFoldD f x0 as ->
case (substitute1Ast i var x0, substitute1Domains i var as) of
(Nothing, Nothing) -> Nothing
(mx0, mas) -> Just $ Ast.AstFoldD f (fromMaybe x0 mx0) (fromMaybe as mas)
Ast.AstFoldDDer f df dr x0 as ->
case (substitute1Ast i var x0, substitute1Domains i var as) of
(Nothing, Nothing) -> Nothing
(mx0, mas) ->
Just $ Ast.AstFoldDDer f df dr (fromMaybe x0 mx0) (fromMaybe as mas)
Ast.AstScan f x0 as ->
case (substitute1Ast i var x0, substitute1Ast i var as) of
(Nothing, Nothing) -> Nothing
Expand Down
5 changes: 5 additions & 0 deletions src/HordeAd/Core/AstTools.hs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ shapeAst = \case
AstFwd (_var, v) _l _ds -> shapeAst v
AstFold _f x0 _as -> shapeAst x0
AstFoldDer _f _df _rf x0 _as -> shapeAst x0
AstFoldD _f x0 _as -> shapeAst x0
AstFoldDDer _f _df _rf x0 _as -> shapeAst x0
AstScan _f x0 as -> lengthAst as + 1 :$ shapeAst x0
AstScanDer _f _df _rf x0 as -> lengthAst as + 1 :$ shapeAst x0
AstScanD _f x0 as ->
Expand Down Expand Up @@ -177,6 +179,9 @@ varInAst var = \case
in any f l || any f ds
AstFold _f x0 as -> varInAst var x0 || varInAst var as
AstFoldDer _f _df _rf x0 as -> varInAst var x0 || varInAst var as
AstFoldD _f x0 as -> varInAst var x0 || any (varInAstDynamic var) as
AstFoldDDer _f _df _rf x0 as ->
varInAst var x0 || any (varInAstDynamic var) as
AstScan _f x0 as -> varInAst var x0 || varInAst var as
AstScanDer _f _df _rf x0 as -> varInAst var x0 || varInAst var as
AstScanD _f x0 as -> varInAst var x0 || any (varInAstDynamic var) as
Expand Down
66 changes: 66 additions & 0 deletions src/HordeAd/Core/AstVectorize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,71 @@ build1V k (var, v00) =
$ substProjDomains k var shm mvar doms) )
(build1VOccurenceUnknown k (var, x0))
(astTr $ build1VOccurenceUnknown k (var, as))
Ast.AstFoldD{} ->
error "build1V: impossible case of AstFoldD"
Ast.AstFoldDDer @_ @n2
(nvar, mvars, v) (varDx, varsDa, varn1, varsm1, ast1)
(varDt2, nvar2, mvars2, doms) x0 as ->
case someNatVal $ toInteger k of
Just (SomeNat @k3 _) ->
let shn = shapeAst x0
substProjDynamicDomains :: AstDomains PrimalSpan
-> AstDynamicVarName
-> (AstDomains PrimalSpan, AstDynamicVarName)
substProjDynamicDomains v3 (AstDynamicVarName @ty @r3 @sh3 varId)
| Just Refl <- testEquality (typeRep @ty) (typeRep @Nat) =
( withListShape (Sh.shapeT @sh3) $ \sh1 ->
substProjDomains @_ @r3 @s k var sh1 (AstVarName varId) v3
, AstDynamicVarName @ty @r3 @(k3 ': sh3) varId )
substProjDynamicDomains _ _ =
error "substProjDynamicDomains: unexpected type"
substProjVarsDomains :: [AstDynamicVarName]
-> AstDomains PrimalSpan
-> (AstDomains PrimalSpan, [AstDynamicVarName])
substProjVarsDomains vars v3 =
mapAccumR substProjDynamicDomains v3 vars
(vOut, mvarsOut) = substProjVars @k3 var mvars v
(ast1Out, varsDaOut) = substProjVars @k3 var varsDa ast1
(ast1Out2, varsm1Out) = substProjVars @k3 var varsm1 ast1Out
(domsOut, mvars2Out) = substProjVarsDomains mvars2 doms
astTrDynamicRanked :: DynamicTensor (AstRanked s)
-> DynamicTensor (AstRanked s)
astTrDynamicRanked t@(DynamicRanked @_ @n3 u) =
case cmpNat (Proxy @2) (Proxy @n3) of
EQI -> DynamicRanked $ astTr @(n3 - 2) u
LTI -> DynamicRanked $ astTr @(n3 - 2) u
_ -> t
astTrDynamicRanked DynamicShaped{} =
error "astTrDynamicRanked:DynamicShaped"
astTrDynamicRanked (DynamicRankedDummy p1 (Proxy @sh3)) =
let permute10 (m0 : m1 : ms) = m1 : m0 : ms
permute10 ms = ms
sh3Permuted = permute10 $ Sh.shapeT @sh3
in Sh.withShapeP sh3Permuted $ \proxy ->
DynamicRankedDummy p1 proxy
astTrDynamicRanked DynamicShapedDummy{} =
error "astTrDynamicRanked:DynamicShapedDummy"
in Ast.AstFoldDDer
( AstVarName $ varNameToAstVarId nvar
, mvarsOut
, build1VOccurenceUnknownRefresh
k (var, substProjRanked k var shn nvar vOut) )
( AstVarName $ varNameToAstVarId varDx
, varsDaOut
, AstVarName $ varNameToAstVarId varn1
, varsm1Out
, build1VOccurenceUnknownRefresh
k (var, substProjRanked k var shn varDx
$ substProjRanked k var shn varn1 ast1Out2) )
( AstVarName $ varNameToAstVarId varDt2
, AstVarName $ varNameToAstVarId nvar2
, mvars2Out
, build1VOccurenceUnknownAstDomainsRefresh
k (var, substProjDomains k var shn nvar domsOut) )
(build1VOccurenceUnknown k (var, x0))
(V.map (\u -> astTrDynamicRanked
$ build1VOccurenceUnknownDynamic k (var, u)) as)
_ -> error "build1V: impossible someNatVal"
Ast.AstScan{} ->
error "build1V: impossible case of AstScan"
Ast.AstScanDer @_ @_ @n2
Expand Down Expand Up @@ -641,6 +706,7 @@ build1VIndex k (var, v0, ix@(_ :. _)) =
Ast.AstScatter{} -> ruleD
Ast.AstAppend{} -> ruleD
Ast.AstFoldDer{} -> ruleD
Ast.AstFoldDDer{} -> ruleD
Ast.AstScanDer{} -> ruleD
Ast.AstScanDDer{} -> ruleD
_ -> build1VOccurenceUnknown k (var, v) -- not a normal form
Expand Down
Loading

0 comments on commit 60f92a5

Please sign in to comment.