Skip to content

Commit

Permalink
Move crevOnADInputs, crevOnDomains and the fwd counterparts
Browse files Browse the repository at this point in the history
down the module tree.
  • Loading branch information
Mikolaj committed Nov 29, 2023
1 parent 6ebf269 commit 9c5ffe8
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 162 deletions.
168 changes: 166 additions & 2 deletions src/HordeAd/Core/DualNumber.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,33 @@ module HordeAd.Core.DualNumber
, ensureToplevelSharing, scaleNotShared, addNotShared, multNotShared
-- , addParameters, dotParameters
, DerivativeStages (..)
, crevOnADInputs, crevOnDomains, cfwdOnADInputs, cfwdOnDomains
, generateDeltaInputsOD, generateDeltaInputsAst, makeADInputs
) where

import Prelude

import Control.Exception.Assert.Sugar
import qualified Data.Array.RankedS as OR
import Data.Bifunctor.Clown
import Data.Bifunctor.Flip
import Data.Bifunctor.Product
import Data.Functor.Const
import Data.Kind (Constraint, Type)
import GHC.TypeLits (KnownNat)
import Data.Proxy (Proxy)
import Data.Type.Equality (testEquality, (:~:) (Refl))
import qualified Data.Vector.Generic as V
import GHC.TypeLits (KnownNat, SomeNat (..), someNatVal)
import Type.Reflection (typeRep)

import HordeAd.Core.Ast
import HordeAd.Core.AstEnv
import HordeAd.Core.Delta (Dual)
import HordeAd.Core.AstTools
import HordeAd.Core.Delta
import HordeAd.Core.DualClass
import HordeAd.Core.TensorClass
import HordeAd.Core.Types
import HordeAd.Util.SizedIndex

-- * The main dual number type

Expand Down Expand Up @@ -69,6 +81,31 @@ constantADVal a = dDnotShared emptyADShare a (dZeroOfShape a)
type ADValClown dynamic = Flip (ADVal (Clown dynamic)) '()


-- * Assorted instances

type instance SimpleBoolOf (ADVal f) = SimpleBoolOf f

instance EqF f => EqF (ADVal f) where
D l1 u _ ==. D l2 v _ = (l1 `mergeADShare` l2, snd $ u ==. v)
D l1 u _ /=. D l2 v _ = (l1 `mergeADShare` l2, snd $ u /=. v)

instance OrdF f => OrdF (ADVal f) where
D l1 u _ <. D l2 v _ = (l1 `mergeADShare` l2, snd $ u <. v)
D l1 u _ <=. D l2 v _ = (l1 `mergeADShare` l2, snd $ u <=. v)
D l1 u _ >. D l2 v _ = (l1 `mergeADShare` l2, snd $ u >. v)
D l1 u _ >=. D l2 v _ = (l1 `mergeADShare` l2, snd $ u >=. v)

type instance RankedOf (ADVal f) = ADVal (RankedOf f)

type instance ShapedOf (ADVal f) = ADVal (ShapedOf f)

type instance DynamicOf (ADVal f) = ADValClown (DynamicOf f)

type instance PrimalOf (ADVal f) = f

type instance DualOf (ADVal f) = Product (Clown (Const ADShare)) (Dual f)


-- * Auxiliary definitions

-- | Add sharing information to the top level of a term, presumably
Expand Down Expand Up @@ -109,6 +146,133 @@ dotParameters (Domains a0 a1) (Domains b0 b1) =
else OD.toVector v1 LA.<.> OD.toVector u1) a1 b1)
-}

crevOnADInputs
:: (DualPart f, GoodScalar r, HasSingletonDict y)
=> Maybe (f r y)
-> (Domains (DynamicOf (ADVal f)) -> ADVal f r y)
-> Domains (DynamicOf (ADVal f))
-> (Domains (DynamicOf f), f r y)
-- The functions in which @revOnADInputs@ inlines are not inlined themselves
-- in client code, so the bloat is limited.
{-# INLINE crevOnADInputs #-}
crevOnADInputs mdt f inputs =
let -- Evaluate completely after terms constructed, to free memory
-- before evaluation allocates new memory and new FFI is started.
!(D _ v deltaTopLevel) = f inputs in
let (!astBindings, !gradient) =
reverseDervative (V.length inputs) v mdt deltaTopLevel
in assert (null astBindings)
(gradient, v)

crevOnDomains
:: forall r y f.
( DynamicOf f ~ DynamicOf (RankedOf f)
, ConvertTensor (RankedOf f) (ShapedOf f)
, Dual (Clown (DynamicOf f)) ~ DeltaD (RankedOf f) (ShapedOf f)
, DualPart f, GoodScalar r, HasSingletonDict y)
=> Maybe (f r y)
-> (Domains (DynamicOf (ADVal f)) -> ADVal f r y)
-> Domains (DynamicOf f)
-> (Domains (DynamicOf f), f r y)
crevOnDomains mdt f parameters =
let deltaInputs = generateDeltaInputsOD parameters
inputs = makeADInputs parameters deltaInputs
in crevOnADInputs mdt f inputs

cfwdOnADInputs
:: (DualPart f, GoodScalar r, HasSingletonDict y)
=> Domains (DynamicOf (ADVal f))
-> (Domains (DynamicOf (ADVal f)) -> ADVal f r y)
-> Domains (DynamicOf f)
-> (f r y, f r y)
{-# INLINE cfwdOnADInputs #-}
cfwdOnADInputs inputs f ds =
let !(D _ v deltaTopLevel) = f inputs in
let (astBindings, derivative) =
forwardDerivative (V.length inputs) deltaTopLevel ds
in assert (null astBindings)
(derivative, v)

cfwdOnDomains
:: forall r y f.
( DynamicOf f ~ DynamicOf (RankedOf f)
, ConvertTensor (RankedOf f) (ShapedOf f)
, Dual (Clown (DynamicOf f)) ~ DeltaD (RankedOf f) (ShapedOf f)
, DualPart f, GoodScalar r, HasSingletonDict y)
=> Domains (DynamicOf f)
-> (Domains (DynamicOf (ADVal f)) -> ADVal f r y)
-> Domains (DynamicOf f)
-> (f r y, f r y)
cfwdOnDomains parameters f ds =
let deltaInputs = generateDeltaInputsOD parameters
inputs = makeADInputs parameters deltaInputs
in cfwdOnADInputs inputs f ds

-- Actually, this is fully general, not only working for DomainsOD.
generateDeltaInputsOD
:: forall ranked shaped dynamic.
( dynamic ~ DynamicOf ranked, ConvertTensor ranked shaped
, Dual (Clown dynamic) ~ DeltaD ranked shaped )
=> Domains dynamic
-> Domains (DualClown dynamic)
{-# INLINE generateDeltaInputsOD #-}
generateDeltaInputsOD params =
let arrayToInput :: Int
-> DynamicExists dynamic
-> DynamicExists (DualClown dynamic)
arrayToInput i (DynamicExists @r t) =
let shl = dshape @ranked t
in case someNatVal $ toInteger $ length shl of
Just (SomeNat (_ :: Proxy n)) ->
let sh = listShapeToShape shl
in DynamicExists $ Flip $ RToD $ InputR @ranked @shaped @r @n
sh (toInputId i)
Nothing -> error "generateDeltaInputs: impossible someNatVal error"
in V.imap arrayToInput params
{- TODO: this can't be specified without a proxy, so we inline instead
{-# SPECIALIZE generateDeltaInputs
:: DomainsOD -> Data.Vector.Vector (Dual OD.Array Double) #-}
-}

-- This is preferred for AstDynamic, because it results in shorter terms.
generateDeltaInputsAst
:: forall ranked shaped dynamic s.
( dynamic ~ AstDynamic s
, Dual (Clown dynamic) ~ DeltaD ranked shaped )
=> Domains dynamic
-> Domains (DualClown dynamic)
{-# INLINE generateDeltaInputsAst #-}
generateDeltaInputsAst params =
let arrayToInput :: Int
-> DynamicExists dynamic
-> DynamicExists (DualClown dynamic)
arrayToInput i (DynamicExists @r d) = case d of
AstRToD @n w ->
DynamicExists $ Flip $ RToD $ InputR @ranked @shaped @r @n
(shapeAst w) (toInputId i)
AstSToD @sh _w ->
DynamicExists $ Flip $ SToD $ InputS @ranked @shaped @r @sh
(toInputId i)
in V.imap arrayToInput params
{- TODO: this can't be specified without a proxy, so we inline instead
{-# SPECIALIZE generateDeltaInputs
:: DomainsOD -> Data.Vector.Vector (Dual OD.Array Double) #-}
-}

makeADInputs
:: Domains dynamic
-> Domains (DualClown dynamic)
-> Domains (ADValClown dynamic)
{-# INLINE makeADInputs #-}
makeADInputs =
V.zipWith (\(DynamicExists @r p)
(DynamicExists @r2 d) ->
case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> DynamicExists
$ Flip $ dDnotShared emptyADShare (Clown p) $ runFlip d
_ -> error "makeADInputs: type mismatch")


-- * Reverse and forward derivative stages instances

type DerivativeStages :: forall k. TensorKind k -> Constraint
Expand Down
141 changes: 8 additions & 133 deletions src/HordeAd/Core/Engine.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,18 @@ module HordeAd.Core.Engine
, fwd, fwdArtifactAdapt, fwdProduceArtifact
-- * Reverse and forward derivative stages class
, forwardPassByApplication
-- * Old gradient adaptors, with constant and fixed inputs and dt
, crev, crevDt, crevOnDomains, crevOnADInputs
-- * Old derivative adaptors, with constant and fixed inputs
, cfwd, cfwdOnDomains, cfwdOnADInputs
-- * Old gradient adaptors
, crev, crevDt
-- * Old derivative adaptors
, cfwd
-- * Additional common mechanisms
, generateDeltaInputsOD, generateDeltaInputsAst, makeADInputs, shapedToRanked
, shapedToRanked
-- * Re-exported for tests
, interpretAst
) where

import Prelude

import Control.Exception.Assert.Sugar
import qualified Data.Array.DynamicS as OD
import qualified Data.Array.RankedS as OR
import qualified Data.Array.Shape as OS
Expand All @@ -36,11 +35,10 @@ import qualified Data.EnumMap.Strict as EM
import Data.Functor.Const
import Data.Int (Int64)
import Data.Maybe (fromMaybe, isJust)
import Data.Proxy (Proxy)
import Data.Type.Equality (gcastWith, testEquality, (:~:) (Refl))
import Data.Type.Equality (gcastWith, (:~:) (Refl))
import qualified Data.Vector.Generic as V
import GHC.TypeLits (KnownNat, Nat, SomeNat (..), someNatVal)
import Type.Reflection (Typeable, typeRep)
import GHC.TypeLits (KnownNat, Nat)
import Type.Reflection (Typeable)
import Unsafe.Coerce (unsafeCoerce)

import HordeAd.Core.Adaptor
Expand Down Expand Up @@ -564,19 +562,6 @@ crevDtMaybe f vals mdt =
let g inputs = f $ parseDomains vals inputs
in parseDomains vals $ fst $ crevOnDomains mdt g (toDomains vals)

crevOnDomains
:: ( DynamicOf f ~ DynamicOf (RankedOf f)
, ConvertTensor (RankedOf f) (ShapedOf f)
, Dual (Clown (DynamicOf f)) ~ DeltaD (RankedOf f) (ShapedOf f)
, DualPart f, GoodScalar r, HasSingletonDict y)
=> Maybe (f r y)
-> (Domains (DynamicOf (ADVal f)) -> ADVal f r y)
-> Domains (DynamicOf f)
-> (Domains (DynamicOf f), f r y)
crevOnDomains mdt f parameters =
let deltaInputs = generateDeltaInputsOD parameters
inputs = makeADInputs parameters deltaInputs
in crevOnADInputs mdt f inputs
{-# SPECIALIZE crevOnDomains
:: HasSingletonDict y
=> Maybe (Flip OR.Array Double y)
Expand All @@ -585,23 +570,6 @@ crevOnDomains mdt f parameters =
-> DomainsOD
-> (DomainsOD, Flip OR.Array Double y) #-}

crevOnADInputs
:: (DualPart f, GoodScalar r, HasSingletonDict y)
=> Maybe (f r y)
-> (Domains (DynamicOf (ADVal f)) -> ADVal f r y)
-> Domains (DynamicOf (ADVal f))
-> (Domains (DynamicOf f), f r y)
-- The functions in which @revOnADInputs@ inlines are not inlined themselves
-- in client code, so the bloat is limited.
{-# INLINE crevOnADInputs #-}
crevOnADInputs mdt f inputs =
let -- Evaluate completely after terms constructed, to free memory
-- before evaluation allocates new memory and new FFI is started.
!(D _ v deltaTopLevel) = f inputs in
let (!astBindings, !gradient) =
reverseDervative (V.length inputs) v mdt deltaTopLevel
in assert (null astBindings)
(gradient, v)
{-# SPECIALIZE crevOnADInputs
:: HasSingletonDict y
=> Maybe (Flip OR.Array Double y)
Expand All @@ -628,102 +596,9 @@ cfwd f x ds =
let g inputs = f $ parseDomains ds inputs
in fst $ cfwdOnDomains (toDomains x) g (toDomains ds)

cfwdOnDomains
:: forall r y f.
( DynamicOf f ~ DynamicOf (RankedOf f)
, ConvertTensor (RankedOf f) (ShapedOf f)
, Dual (Clown (DynamicOf f)) ~ DeltaD (RankedOf f) (ShapedOf f)
, DualPart f, GoodScalar r, HasSingletonDict y)
=> Domains (DynamicOf f)
-> (Domains (DynamicOf (ADVal f)) -> ADVal f r y)
-> Domains (DynamicOf f)
-> (f r y, f r y)
cfwdOnDomains parameters f ds =
let deltaInputs = generateDeltaInputsOD parameters
inputs = makeADInputs parameters deltaInputs
in cfwdOnADInputs inputs f ds

cfwdOnADInputs
:: (DualPart f, GoodScalar r, HasSingletonDict y)
=> Domains (DynamicOf (ADVal f))
-> (Domains (DynamicOf (ADVal f)) -> ADVal f r y)
-> Domains (DynamicOf f)
-> (f r y, f r y)
{-# INLINE cfwdOnADInputs #-}
cfwdOnADInputs inputs f ds =
let !(D _ v deltaTopLevel) = f inputs in
let (astBindings, derivative) =
forwardDerivative (V.length inputs) deltaTopLevel ds
in assert (null astBindings)
(derivative, v)


-- * Additional common mechanisms

-- Actually, this is fully general, not only working for DomainsOD.
generateDeltaInputsOD
:: forall ranked shaped dynamic.
( dynamic ~ DynamicOf ranked, ConvertTensor ranked shaped
, Dual (Clown dynamic) ~ DeltaD ranked shaped )
=> Domains dynamic
-> Domains (DualClown dynamic)
{-# INLINE generateDeltaInputsOD #-}
generateDeltaInputsOD params =
let arrayToInput :: Int
-> DynamicExists dynamic
-> DynamicExists (DualClown dynamic)
arrayToInput i (DynamicExists @r t) =
let shl = dshape @ranked t
in case someNatVal $ toInteger $ length shl of
Just (SomeNat (_ :: Proxy n)) ->
let sh = listShapeToShape shl
in DynamicExists $ Flip $ RToD $ InputR @ranked @shaped @r @n
sh (toInputId i)
Nothing -> error "generateDeltaInputs: impossible someNatVal error"
in V.imap arrayToInput params
{- TODO: this can't be specified without a proxy, so we inline instead
{-# SPECIALIZE generateDeltaInputs
:: DomainsOD -> Data.Vector.Vector (Dual OD.Array Double) #-}
-}

-- This is preferred for AstDynamic, because it results in shorter terms.
generateDeltaInputsAst
:: forall ranked shaped dynamic s.
( dynamic ~ AstDynamic s
, Dual (Clown dynamic) ~ DeltaD ranked shaped )
=> Domains dynamic
-> Domains (DualClown dynamic)
{-# INLINE generateDeltaInputsAst #-}
generateDeltaInputsAst params =
let arrayToInput :: Int
-> DynamicExists dynamic
-> DynamicExists (DualClown dynamic)
arrayToInput i (DynamicExists @r d) = case d of
AstRToD @n w ->
DynamicExists $ Flip $ RToD $ InputR @ranked @shaped @r @n
(shapeAst w) (toInputId i)
AstSToD @sh _w ->
DynamicExists $ Flip $ SToD $ InputS @ranked @shaped @r @sh
(toInputId i)
in V.imap arrayToInput params
{- TODO: this can't be specified without a proxy, so we inline instead
{-# SPECIALIZE generateDeltaInputs
:: DomainsOD -> Data.Vector.Vector (Dual OD.Array Double) #-}
-}

makeADInputs
:: Domains dynamic
-> Domains (DualClown dynamic)
-> Domains (ADValClown dynamic)
{-# INLINE makeADInputs #-}
makeADInputs =
V.zipWith (\(DynamicExists @r p)
(DynamicExists @r2 d) ->
case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> DynamicExists
$ Flip $ dDnotShared emptyADShare (Clown p) $ runFlip d
_ -> error "makeADInputs: type mismatch")

shapedToRanked
:: forall vals svals dynamic.
( dynamic ~ OD.Array, NoShape svals ~ vals, Value vals ~ vals
Expand Down
Loading

0 comments on commit 9c5ffe8

Please sign in to comment.