Skip to content

Commit

Permalink
Tweak OpsADVal for slightly faster rewriting
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Jan 29, 2025
1 parent 101229d commit b47d13d
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/HordeAd/Core/OpsADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ instance (ADReadyNoLet target, ShareTensor target)
tunpair (D u u') = let (u1, u2) = tunpair u
(d1, d2) = unDeltaPair u'
in (dDnotShared u1 d1, dDnotShared u2 d2)
tfromSShare (D u u') = dDnotShared (tfromSShare u) (dFromS u')

-- Note that these instances don't do vectorization. To enable it,
-- use the Ast instance and only then interpret in ADVal.
Expand Down Expand Up @@ -545,7 +546,7 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target)
tlambda _ = id
-- Bangs are for the proper order of sharing stamps.
tcond !stk !b !u !v =
let uv = tfromVector (SNat @2) stk (V.fromList [u, v])
let uv = tfromVectorShare (SNat @2) stk (V.fromList [u, v])
in tindexBuildShare (SNat @2) stk uv (ifF b 0 1)
tprimalPart _stk (D u _) = u
tdualPart _stk (D _ u') = u'
Expand Down

0 comments on commit b47d13d

Please sign in to comment.