Skip to content

Commit

Permalink
Remove the dummy useDummies parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 27, 2023
1 parent dff6e8d commit 54eb35f
Show file tree
Hide file tree
Showing 11 changed files with 68 additions and 70 deletions.
32 changes: 16 additions & 16 deletions src/HordeAd/Core/Delta.hs
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ class DualPart (f :: TensorType ty) where
type Dual f = (result :: TensorType ty) | result -> f
reverseDervative
:: (HasSingletonDict y, GoodScalar r)
=> Bool -> DomainsOD -> f r y -> Maybe (f r y) -> Dual f r y
=> DomainsOD -> f r y -> Maybe (f r y) -> Dual f r y
-> (AstBindingsD (RankedOf f), Domains (RankedOf f))
forwardDerivative
:: (HasSingletonDict y, GoodScalar r)
Expand All @@ -477,22 +477,22 @@ gradientDtR
:: ( KnownNat y, GoodScalar r
, RankedTensor ranked, ShapedTensor shaped, ConvertTensor ranked shaped
, ShapedOf ranked ~ shaped, RankedOf shaped ~ ranked )
=> Bool -> DomainsOD
=> DomainsOD
-> ranked r y -> Maybe (ranked r y) -> DeltaR ranked shaped r y
-> (AstBindingsD ranked, Domains ranked)
gradientDtR useDummies !parameters0 value !mdt !deltaTopLevel =
gradientDtR !parameters0 value !mdt !deltaTopLevel =
let dt = fromMaybe (rreplicate0N (rshape value) 1) mdt
deltaDt = DeltaDtR dt deltaTopLevel
in gradientFromDelta useDummies parameters0 deltaDt
in gradientFromDelta parameters0 deltaDt
{-# SPECIALIZE gradientDtR
:: KnownNat y
=> Bool -> DomainsOD -> Flip OR.Array Double y -> Maybe (Flip OR.Array Double y)
=> DomainsOD -> Flip OR.Array Double y -> Maybe (Flip OR.Array Double y)
-> DeltaR (Flip OR.Array) (Flip OS.Array) Double y
-> (AstBindingsD (Flip OR.Array), Domains (Flip OR.Array) ) #-}
{- TODO: this causes a cyclic dependency:
{-# SPECIALIZE gradientDtR
:: KnownNat y
=> Bool -> DomainsOD -> AstRanked PrimalSpan Double y
=> DomainsOD -> AstRanked PrimalSpan Double y
-> Maybe (AstRanked PrimalSpan Double y)
-> DeltaR (AstRanked PrimalSpan) (AstShaped PrimalSpan) Double y
-> ( AstBindingsD (DynamicTensor (AstRanked PrimalSpan))
Expand All @@ -519,30 +519,30 @@ instance ( RankedTensor (RankedOf shaped), ShapedTensor shaped
, ShapedOf ranked ~ shaped, RankedOf shaped ~ ranked )
=> DualPart @[Nat] shaped where
type Dual shaped = DeltaS (RankedOf shaped) shaped
reverseDervative useDummies parameters0 _ = gradientDtS useDummies parameters0
reverseDervative parameters0 _ = gradientDtS parameters0
forwardDerivative = derivativeFromDeltaS

gradientDtS
:: forall ranked shaped r y.
( Sh.Shape y, GoodScalar r
, RankedTensor ranked, ShapedTensor shaped, ConvertTensor ranked shaped
, ShapedOf ranked ~ shaped, RankedOf shaped ~ ranked )
=> Bool -> DomainsOD
=> DomainsOD
-> Maybe (shaped r y) -> DeltaS ranked shaped r y
-> (AstBindingsD ranked, Domains ranked)
gradientDtS useDummies !parameters0 !mdt !deltaTopLevel =
gradientDtS !parameters0 !mdt !deltaTopLevel =
let dt = fromMaybe 1 mdt
deltaDt = DeltaDtS dt deltaTopLevel
in gradientFromDelta useDummies parameters0 deltaDt
in gradientFromDelta parameters0 deltaDt
{-# SPECIALIZE gradientDtS
:: Sh.Shape y
=> Bool -> DomainsOD -> Maybe (Flip OS.Array Double y)
=> DomainsOD -> Maybe (Flip OS.Array Double y)
-> DeltaS (Flip OR.Array) (Flip OS.Array) Double y
-> (AstBindingsD (Flip OR.Array), Domains (Flip OR.Array)) #-}
{- TODO: this causes a cyclic dependency:
{-# SPECIALIZE gradientDtS
:: Sh.Shape y
=> Bool -> DomainsOD -> Maybe (AstShaped PrimalSpan Double y)
=> DomainsOD -> Maybe (AstShaped PrimalSpan Double y)
-> DeltaS (AstRanked PrimalSpan) (AstShaped PrimalSpan) Double y
-> ( AstBindingsD (DynamicTensor (AstShaped PrimalSpan))
, Domains (DynamicTensor (AstShaped PrimalSpan)) ) #-}
Expand Down Expand Up @@ -672,9 +672,9 @@ gradientFromDelta
( GoodScalar r, RankedTensor ranked, ShapedTensor shaped
, ConvertTensor ranked shaped
, ShapedOf ranked ~ shaped, RankedOf shaped ~ ranked )
=> Bool -> DomainsOD -> DeltaDt ranked shaped r
=> DomainsOD -> DeltaDt ranked shaped r
-> (AstBindingsD ranked, Domains ranked)
gradientFromDelta useDummies !parameters0 !deltaDt =
gradientFromDelta !parameters0 !deltaDt =
-- Create finite maps that hold values associated with inputs
-- and with (possibly shared) term tree nodes.
-- The former are usually initialized with dummy values so that it's cheap
Expand Down Expand Up @@ -709,11 +709,11 @@ gradientFromDelta useDummies !parameters0 !deltaDt =
in (astBindings, gradient)
-- The warnings in the following seem spurious. A GHC issue to be opened.
{-# SPECIALIZE gradientFromDelta
:: Bool -> DomainsOD -> DeltaDt (Flip OR.Array) (Flip OS.Array) Double
:: DomainsOD -> DeltaDt (Flip OR.Array) (Flip OS.Array) Double
-> (AstBindingsD (Flip OR.Array), DomainsOD) #-}
{- TODO: this causes a cyclic dependency:
{-# SPECIALIZE gradientFromDelta
:: Bool -> DomainsOD -> DeltaDt (AstRanked PrimalSpan) (AstShaped PrimalSpan) Double
:: DomainsOD -> DeltaDt (AstRanked PrimalSpan) (AstShaped PrimalSpan) Double
-> (AstBindingsD (DynamicTensor (AstRanked PrimalSpan)), Domains (AstDynamic PrimalSpan)) #-}
-}

Expand Down
21 changes: 10 additions & 11 deletions src/HordeAd/Core/DualNumber.hs
Original file line number Diff line number Diff line change
Expand Up @@ -282,35 +282,35 @@ crevOnADInputs
:: forall ty (f :: TensorType ty) r y.
( RankedTensor (ADVal (RankedOf f))
, DualPart f, UnletGradient f, GoodScalar r, HasSingletonDict y)
=> Bool -> Maybe (f r y)
=> Maybe (f r y)
-> (Domains (ADVal (RankedOf f)) -> ADVal f r y)
-> Domains (ADVal (RankedOf f))
-> (DomainsOf (RankedOf f), f r y)
-- The functions into which @crevOnADInputs@ is inlined are not themselves
-- inlined in client code, so the bloat is limited.
{-# INLINE crevOnADInputs #-}
crevOnADInputs useDummies mdt f inputs =
crevOnADInputs mdt f inputs =
let -- Evaluate completely after terms constructed, to free memory
-- before evaluation allocates new memory and new FFI is started.
!(D l v deltaTopLevel) = f inputs in
let parameters0 = zeroParameters inputs
(!astBindings, !gradient) =
reverseDervative useDummies parameters0 v mdt deltaTopLevel
reverseDervative parameters0 v mdt deltaTopLevel
in (unletGradient @ty @f l astBindings gradient, unletValue l [] v)

crevOnDomains
:: forall r y f.
( RankedTensor (RankedOf f), RankedTensor (ADVal (RankedOf f))
, DualPart f, UnletGradient f, GoodScalar r, HasSingletonDict y
, RankedOf (ShapedOf f) ~ RankedOf f, ShapedOf (RankedOf f) ~ ShapedOf f )
=> Bool -> Maybe (f r y)
=> Maybe (f r y)
-> (Domains (ADVal (RankedOf f)) -> ADVal f r y)
-> Domains (RankedOf f)
-> (DomainsOf (RankedOf f), f r y)
crevOnDomains useDummies mdt f parameters =
crevOnDomains mdt f parameters =
let deltaInputs = generateDeltaInputs parameters
inputs = makeADInputs parameters deltaInputs
in crevOnADInputs useDummies mdt f inputs
in crevOnADInputs mdt f inputs

cfwdOnADInputs
:: (DualPart f, UnletGradient f, GoodScalar r, HasSingletonDict y)
Expand Down Expand Up @@ -403,7 +403,7 @@ class DerivativeStages g where
revArtifactFromForwardPass
:: (GoodScalar r, HasSingletonDict y)
=> TensorToken g
-> Bool -> Bool
-> Bool
-> (Domains (RankedOf (PrimalOf g))
-> [AstDynamicVarName]
-> Domains (RankedOf g)
Expand All @@ -413,16 +413,15 @@ class DerivativeStages g where

revProduceArtifact
:: (GoodScalar r, HasSingletonDict y)
=> TensorToken g -> Bool -> Bool
=> TensorToken g -> Bool
-> (Domains (RankedOf g) -> g r y)
-> AstEnv (ADVal (RankedOf (PrimalOf g)))
(ADVal (ShapedOf (PrimalOf g)))
-> DomainsOD
-> (AstArtifactRev (PrimalOf g) r y, Dual (PrimalOf g) r y)
{-# INLINE revProduceArtifact #-}
revProduceArtifact tf useDummies hasDt g envInit =
revArtifactFromForwardPass tf useDummies hasDt
(forwardPassByInterpretation g envInit)
revProduceArtifact tf hasDt g envInit =
revArtifactFromForwardPass tf hasDt (forwardPassByInterpretation g envInit)

revEvalArtifact
:: (GoodScalar r, HasSingletonDict y)
Expand Down
12 changes: 6 additions & 6 deletions src/HordeAd/Core/Engine.hs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ revDtMaybe f vals mdt =
let g domains = f $ parseDomains vals domains
domainsOD = toDomains vals
artifact = fst $ revProduceArtifact TensorToken
True (isJust mdt) g EM.empty domainsOD
(isJust mdt) g EM.empty domainsOD
in gcastWith (unsafeCoerce Refl :: Value vals :~: vals) $ -- !!!
parseDomains vals
$ fst $ revEvalArtifact artifact domainsOD mdt
Expand All @@ -168,7 +168,7 @@ revArtifactAdapt
revArtifactAdapt hasDt f vals =
let g domains = f $ parseDomains vals domains
domainsOD = toDomains vals
in revProduceArtifact TensorToken True hasDt g EM.empty domainsOD
in revProduceArtifact TensorToken hasDt g EM.empty domainsOD
{-# SPECIALIZE revArtifactAdapt
:: ( HasSingletonDict y
, AdaptableDomains (AstRanked FullSpan) astvals
Expand All @@ -189,7 +189,7 @@ revProduceArtifactWithoutInterpretation
{-# INLINE revProduceArtifactWithoutInterpretation #-}
revProduceArtifactWithoutInterpretation tf hasDt g =
revArtifactFromForwardPass
@Nat @g TensorToken True hasDt (forwardPassByApplication tf g)
@Nat @g TensorToken hasDt (forwardPassByApplication tf g)

forwardPassByApplication
:: forall g r y.
Expand Down Expand Up @@ -344,18 +344,18 @@ crevDtMaybe f vals mdt =
gcastWith (unsafeCoerce Refl :: Value vals :~: vals) $ -- !!!
let g inputs = f $ parseDomains vals inputs
in parseDomains vals
$ fst $ crevOnDomains True mdt g (toDomains vals)
$ fst $ crevOnDomains mdt g (toDomains vals)

{-# SPECIALIZE crevOnDomains
:: HasSingletonDict y
=> Bool -> Maybe (Flip OR.Array Double y)
=> Maybe (Flip OR.Array Double y)
-> (Domains (ADVal (Flip OR.Array)) -> ADVal (Flip OR.Array) Double y)
-> DomainsOD
-> (DomainsOD, Flip OR.Array Double y) #-}

{-# SPECIALIZE crevOnADInputs
:: HasSingletonDict y
=> Bool -> Maybe (Flip OR.Array Double y)
=> Maybe (Flip OR.Array Double y)
-> (Domains (ADVal (Flip OR.Array)) -> ADVal (Flip OR.Array) Double y)
-> Domains (ADVal (Flip OR.Array))
-> (DomainsOD, Flip OR.Array Double y) #-}
Expand Down
20 changes: 10 additions & 10 deletions src/HordeAd/Core/TensorADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -374,15 +374,15 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped)
-> DomainsOf (ADVal ranked)
rrev f _parameters0 parameters =
-- This computes the derivative of f again for each new @parameters@.
fst $ crevOnDomains False Nothing (f @(ADVal (ADVal ranked))) parameters
fst $ crevOnDomains Nothing (f @(ADVal (ADVal ranked))) parameters
rrevDt :: (GoodScalar r, KnownNat n)
=> (forall f. ADReady f => Domains f -> f r n)
-> DomainsOD
-> DomainsOf (ADVal ranked)
-> ADVal ranked r n
-> DomainsOf (ADVal ranked)
rrevDt f _parameters0 parameters dt =
fst $ crevOnDomains False (Just dt) (f @(ADVal (ADVal ranked))) parameters
fst $ crevOnDomains (Just dt) (f @(ADVal (ADVal ranked))) parameters
rfwd :: (GoodScalar r, KnownNat n)
=> (forall f. ADReady f => Domains f -> f r n)
-> DomainsOD
Expand All @@ -392,9 +392,9 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped)
rfwd f _parameters0 parameters ds =
fst $ cfwdOnDomains parameters (f @(ADVal (ADVal ranked))) ds
srev f _parameters0 parameters =
fst $ crevOnDomains False Nothing (f @(ADVal (ADVal shaped))) parameters
fst $ crevOnDomains Nothing (f @(ADVal (ADVal shaped))) parameters
srevDt f _parameters0 parameters dt =
fst $ crevOnDomains False (Just dt) (f @(ADVal (ADVal shaped))) parameters
fst $ crevOnDomains (Just dt) (f @(ADVal (ADVal shaped))) parameters
sfwd f _parameters0 parameters ds =
fst $ cfwdOnDomains parameters (f @(ADVal (ADVal shaped))) ds
rfold :: forall rn rm n m.
Expand Down Expand Up @@ -430,7 +430,7 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped)
-> (ranked rn n, ranked rm m)
rf dt (x, a) =
domsToPair $ dunDomains @ranked domsOD $ fst
$ crevOnDomains False (Just dt) g (V.fromList [dfromR x, dfromR a])
$ crevOnDomains (Just dt) g (V.fromList [dfromR x, dfromR a])
in D (l1 `mergeADShare` l2)
(rfold @ranked f x0 as)
(FoldR f x0 as df rf x0' as')
Expand Down Expand Up @@ -492,7 +492,7 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped)
-> (shaped rn sh, shaped rm shm)
rf dt (x, a) =
domsToPair $ dunDomains @ranked domsOD $ fst
$ crevOnDomains False (Just dt) g (V.fromList [dfromS x, dfromS a])
$ crevOnDomains (Just dt) g (V.fromList [dfromS x, dfromS a])
in D (l1 `mergeADShare` l2)
(sfold @ranked f x0 as)
(FoldS f x0 as df rf x0' as')
Expand Down Expand Up @@ -544,15 +544,15 @@ instance DomainsTensor (Flip OR.Array) (Flip OS.Array) where
-> DomainsOD
-> DomainsOD
rrev f _parameters0 parameters =
fst $ crevOnDomains False Nothing (f @(ADVal (Flip OR.Array))) parameters
fst $ crevOnDomains Nothing (f @(ADVal (Flip OR.Array))) parameters
rrevDt :: (GoodScalar r, KnownNat n)
=> (forall f. ADReady f => Domains f -> f r n)
-> DomainsOD
-> DomainsOD
-> Flip OR.Array r n
-> DomainsOD
rrevDt f _parameters0 parameters dt =
fst $ crevOnDomains False (Just dt) (f @(ADVal (Flip OR.Array))) parameters
fst $ crevOnDomains (Just dt) (f @(ADVal (Flip OR.Array))) parameters
rfwd :: (GoodScalar r, KnownNat n)
=> (forall f. ADReady f => Domains f -> f r n)
-> DomainsOD
Expand All @@ -562,9 +562,9 @@ instance DomainsTensor (Flip OR.Array) (Flip OS.Array) where
rfwd f _parameters0 parameters ds =
fst $ cfwdOnDomains parameters (f @(ADVal (Flip OR.Array))) ds
srev f _parameters0 parameters =
fst $ crevOnDomains False Nothing (f @(ADVal (Flip OS.Array))) parameters
fst $ crevOnDomains Nothing (f @(ADVal (Flip OS.Array))) parameters
srevDt f _parameters0 parameters dt =
fst $ crevOnDomains False (Just dt) (f @(ADVal (Flip OS.Array))) parameters
fst $ crevOnDomains (Just dt) (f @(ADVal (Flip OS.Array))) parameters
sfwd f _parameters0 parameters ds =
fst $ cfwdOnDomains parameters (f @(ADVal (Flip OS.Array))) ds
rfold :: GoodScalar rm
Expand Down
Loading

0 comments on commit 54eb35f

Please sign in to comment.