From b47d13d651e5279088927a7622acc66ebc2d0167 Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Wed, 29 Jan 2025 20:42:41 +0100 Subject: [PATCH] Tweak OpsADVal for slightly faster rewriting --- src/HordeAd/Core/OpsADVal.hs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/HordeAd/Core/OpsADVal.hs b/src/HordeAd/Core/OpsADVal.hs index 7e9518c58..b9989a1c9 100644 --- a/src/HordeAd/Core/OpsADVal.hs +++ b/src/HordeAd/Core/OpsADVal.hs @@ -137,6 +137,7 @@ instance (ADReadyNoLet target, ShareTensor target) tunpair (D u u') = let (u1, u2) = tunpair u (d1, d2) = unDeltaPair u' in (dDnotShared u1 d1, dDnotShared u2 d2) + tfromSShare (D u u') = dDnotShared (tfromSShare u) (dFromS u') -- Note that these instances don't do vectorization. To enable it, -- use the Ast instance and only then interpret in ADVal. @@ -545,7 +546,7 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target) tlambda _ = id -- Bangs are for the proper order of sharing stamps. tcond !stk !b !u !v = - let uv = tfromVector (SNat @2) stk (V.fromList [u, v]) + let uv = tfromVectorShare (SNat @2) stk (V.fromList [u, v]) in tindexBuildShare (SNat @2) stk uv (ifF b 0 1) tprimalPart _stk (D u _) = u tdualPart _stk (D _ u') = u'