diff --git a/library/Control/Monad/Mock.hs b/library/Control/Monad/Mock.hs index b5d01d3..257a67a 100644 --- a/library/Control/Monad/Mock.hs +++ b/library/Control/Monad/Mock.hs @@ -63,12 +63,15 @@ boilerplate, look at 'Control.Monad.Mock.TH.makeAction' from "Control.Monad.Mock.TH". -} module Control.Monad.Mock - ( -- * The MockT monad transformer - MockT + ( + -- * The MonadMock class + MonadMock(..) + + -- * The MockT monad transformer + , MockT , Mock , runMockT , runMock - , mockAction -- * Actions and actions with results , Action(..) @@ -79,6 +82,7 @@ import Control.Monad.Base (MonadBase) import Control.Monad.Catch (MonadCatch, MonadThrow, MonadMask) import Control.Monad.Cont (MonadCont) import Control.Monad.Except (MonadError) +import Control.Monad.Fix import Control.Monad.IO.Class (MonadIO) import Control.Monad.Reader (MonadReader) import Control.Monad.State (StateT, MonadState(..), runStateT) @@ -129,7 +133,7 @@ data WithResult f where -- f m a@, @f@ should be an 'Action', which should be a GADT that represents a -- reified version of typeclass method calls. newtype MockT f m a = MockT (StateT [WithResult f] m a) - deriving ( Functor, Applicative, Monad, MonadTrans, MonadIO, MonadBase b + deriving ( Functor, Applicative, Monad, MonadTrans, MonadIO, MonadFix, MonadBase b , MonadReader r, MonadCont, MonadError e, MonadWriter w , MonadCatch, MonadThrow, MonadMask ) @@ -164,15 +168,18 @@ runMockT actions (MockT x) = runStateT x actions >>= \case runMock :: forall f a. Action f => [WithResult f] -> Mock f a -> a runMock actions x = runIdentity $ runMockT actions x --- | Logs a method call within a mock. -mockAction :: (Action f, Monad m) => String -> f r -> MockT f m r -mockAction fnName action = MockT $ get >>= \case - [] -> error' - $ "runMockT: expected end of program, called " ++ fnName ++ "\n" - ++ " given action: " ++ showAction action ++ "\n" - (action' :-> r) : actions - | Just Refl <- action `eqAction` action' -> put actions >> return r - | otherwise -> error' - $ "runMockT: argument mismatch in " ++ fnName ++ "\n" - ++ " given: " ++ showAction action ++ "\n" - ++ " expected: " ++ showAction action' ++ "\n" +class MonadMock f m where + -- | Logs a method call within a mock. + mockAction :: Action f => String -> f r -> m r + +instance Monad m => MonadMock f (MockT f m) where + mockAction fnName action = MockT $ get >>= \case + [] -> error' + $ "runMockT: expected end of program, called " ++ fnName ++ "\n" + ++ " given action: " ++ showAction action ++ "\n" + (action' :-> r) : actions + | Just Refl <- action `eqAction` action' -> put actions >> return r + | otherwise -> error' + $ "runMockT: argument mismatch in " ++ fnName ++ "\n" + ++ " given: " ++ showAction action ++ "\n" + ++ " expected: " ++ showAction action' ++ "\n" diff --git a/library/Control/Monad/Mock/Stateless.hs b/library/Control/Monad/Mock/Stateless.hs new file mode 100644 index 0000000..0b09c45 --- /dev/null +++ b/library/Control/Monad/Mock/Stateless.hs @@ -0,0 +1,115 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE UndecidableInstances #-} + +-- | A version of 'MockT' with a stateless 'MonadTransControl' instance +module Control.Monad.Mock.Stateless + ( + -- * The MonadMock class + MonadMock(..) + + -- * The MockT monad transformer + , MockT + , Mock + , runMockT + , runMock + , MockT_ + + -- * Actions and actions with results + , Action(..) + , WithResult(..) + ) where + +import Control.Monad.Base (MonadBase) +import Control.Monad.Catch (MonadCatch, MonadThrow, MonadMask) +import Control.Monad.Cont (MonadCont) +import Control.Monad.Except (MonadError) +import Control.Monad.Fix +import Control.Monad.Identity +import Control.Monad.IO.Class (MonadIO) +import Control.Monad.Primitive (PrimMonad(..)) +import Control.Monad.Reader (ReaderT(..), MonadReader(..)) +import Control.Monad.State (MonadState) +import Control.Monad.ST (ST, runST) +import Control.Monad.Trans (MonadTrans(..)) +import Control.Monad.Trans.Control +import Control.Monad.Writer (MonadWriter) +import Data.Primitive.MutVar (MutVar, newMutVar, readMutVar, writeMutVar) +import Data.Type.Equality ((:~:)(..)) + +import Control.Monad.Mock (Action(..), MonadMock(..)) + +type MockT f m = MockT_ (PrimState m) f m m + +type Mock s f = MockT f (ST s) + +-- | Represents both an expected call (an 'Action') and its expected result. +data WithResult m f where + -- | Matches a specific command + (:->) :: f r -> m r -> WithResult m f + -- | Skips commands as long as the predicate returns something + SkipWhile :: (forall r. f r -> Maybe (m r)) -> WithResult m f + +newtype MockT_ s f n m a = MockT (ReaderT (MutVar s [WithResult n f]) m a) + deriving ( Functor, Applicative, Monad, MonadIO, MonadFix + , MonadState st, MonadCont, MonadError e, MonadWriter w + , MonadCatch, MonadThrow, MonadMask + , MonadTrans, MonadTransControl + , MonadBase b, MonadBaseControl b + , PrimMonad) + +instance MonadReader r m => MonadReader r (MockT_ s f n m) where + ask = lift ask + local f (MockT act) = MockT $ do + env <- ask + lift $ local f $ runReaderT act env + +runMockT :: forall f m a . + (Action f, PrimMonad m) => + [WithResult m f] -> MockT f m a -> m a +runMockT actions (MockT x) = do + ref <- newMutVar actions + r <- runReaderT x ref + leftovers <- readMutVar ref + case leftovers of + [] -> return r + remainingActions -> error' + $ "runMockT: expected the following unexecuted actions to be run:\n" + ++ unlines (map (\(action :-> _) -> " " ++ showAction action) remainingActions) + +runMock :: forall f a. Action f => [WithResult Identity f] -> (forall s. Mock s f a) -> a +runMock actions x = runST $ runMockT (map (\(a :-> b) -> a :-> return(runIdentity b)) actions) x + +instance (PrimMonad m, PrimState m ~ s) => MonadMock f (MockT_ s f m m) where + mockAction fnName action = do + ref <- MockT ask + results <- lift $ readMutVar ref + case results of + [] -> error' + $ "runMockT: expected end of program, called " ++ fnName ++ "\n" + ++ " given action: " ++ showAction action ++ "\n" + SkipWhile f : actions + | Just res <- f action + -> lift res + | otherwise -> do + lift $ writeMutVar ref actions + mockAction fnName action + (action' :-> r) : actions + | Just Refl <- action `eqAction` action' -> do + lift $ writeMutVar ref actions + lift r + | otherwise -> error' + $ "runMockT: argument mismatch in " ++ fnName ++ "\n" + ++ " given: " ++ showAction action ++ "\n" + ++ " expected: " ++ showAction action' ++ "\n" + + +error' :: String -> a +#if MIN_VERSION_base(4,9,0) +error' = errorWithoutStackTrace +#else +error' = error +#endif diff --git a/library/Control/Monad/Mock/TH.hs b/library/Control/Monad/Mock/TH.hs index 7093d22..cab7599 100644 --- a/library/Control/Monad/Mock/TH.hs +++ b/library/Control/Monad/Mock/TH.hs @@ -81,6 +81,7 @@ spec = describe "copyFile" '$' module Control.Monad.Mock.TH (makeAction, deriveAction, ts) where import Control.Monad (replicateM, when, zipWithM) +import Control.Monad.Primitive (PrimMonad, PrimState) import Data.Char (toUpper) import Data.Foldable (traverse_) import Data.List (foldl', nub, partition) @@ -89,6 +90,7 @@ import GHC.Exts (Constraint) import Language.Haskell.TH import Control.Monad.Mock (Action(..), MockT, mockAction) +import qualified Control.Monad.Mock.Stateless as Stateless import Control.Monad.Mock.TH.Internal.TypesQuasi (ts) -- | Given a list of monadic typeclass constraints of kind @* -> 'Constraint'@, @@ -119,9 +121,12 @@ makeAction actionNameStr classTs = do mkStandaloneDec derivT = standaloneDeriveD' [] (derivT `AppT` (actionTypeCon `AppT` VarT actionParamName)) standaloneDecs = [mkStandaloneDec (ConT ''Eq), mkStandaloneDec (ConT ''Show)] actionInstanceDec <- deriveAction' actionTypeCon actionCons - classInstanceDecs <- zipWithM (mkInstance actionTypeCon) classTs methods + classInstanceDecs1 <- zipWithM (mkInstance (ConT ''MockT) (const []) actionTypeCon) classTs methods + primStateVar <- newName "s" + let primStateConstraint baseM = [ConT ''PrimMonad `AppT` baseM, EqualityT `AppT` VarT primStateVar `AppT` (ConT ''PrimState `AppT` baseM)] + classInstanceDecs2 <- zipWithM (mkInstance (ConT ''Stateless.MockT_ `AppT` VarT primStateVar) primStateConstraint actionTypeCon) classTs methods - return $ [actionDec] ++ standaloneDecs ++ [actionInstanceDec] ++ classInstanceDecs + return $ [actionDec] ++ standaloneDecs ++ [actionInstanceDec] ++ classInstanceDecs1 ++ classInstanceDecs2 where -- | Ensures that a provided constraint is something monad-mock can actually -- derive an instance for. Specifically, it must be a constraint of kind @@ -203,8 +208,8 @@ makeAction actionNameStr classTs = do methodNameToConstructorName name = mkName (toUpper c : cs) where (c:cs) = nameBase name - mkInstance :: Type -> Type -> [Dec] -> Q Dec - mkInstance actionT classT methodSigs = do + mkInstance :: Type -> (Type -> [Pred]) -> Type -> Type -> [Dec] -> Q Dec + mkInstance mockT mkExtraConstraints actionT classT methodSigs = do mVar <- newName "m" -- In order to calculate the constraints on the instance, we need to look @@ -229,10 +234,10 @@ makeAction actionNameStr classTs = do contextSubFns = map (uncurry substituteTypeVar) classBindsToInstanceBinds instanceContext = foldr map classContext contextSubFns - let instanceHead = classT `AppT` (ConT ''MockT `AppT` actionT `AppT` VarT mVar) + let instanceHead = classT `AppT` (mockT `AppT` actionT `AppT` VarT mVar) methodImpls <- traverse mkInstanceMethod methodSigs - return $ instanceD' instanceContext instanceHead methodImpls + return $ instanceD' (instanceContext ++ mkExtraConstraints (VarT mVar)) instanceHead methodImpls mkInstanceMethod :: Dec -> Q Dec mkInstanceMethod (SigD name typ) = do diff --git a/package.yaml b/package.yaml index 69e1786..a9402ff 100644 --- a/package.yaml +++ b/package.yaml @@ -48,7 +48,8 @@ library: - th-orphans - monad-control >= 1.0.0.0 && < 2 - mtl - - template-haskell >= 2.10.0.0 && < 2.13 + - primitive + - template-haskell >= 2.10.0.0 && < 2.15 - transformers-base when: - condition: impl(ghc < 8) diff --git a/test-suite/Control/Monad/Mock/StatelessSpec.hs b/test-suite/Control/Monad/Mock/StatelessSpec.hs new file mode 100644 index 0000000..166db0a --- /dev/null +++ b/test-suite/Control/Monad/Mock/StatelessSpec.hs @@ -0,0 +1,71 @@ +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE FunctionalDependencies #-} + +module Control.Monad.Mock.StatelessSpec (spec) where + +import Prelude hiding (readFile, writeFile) + +import Control.Exception (evaluate) +import Control.Monad.Except (MonadError, runExceptT) +import Control.Monad.ST (runST) +import Data.Function ((&)) +import Test.Hspec + +import Control.Monad.Mock.Stateless +import Control.Monad.Mock.TH + +class MonadError e m => MonadFileSystem e m | m -> e where + readFile :: FilePath -> m String + writeFile :: FilePath -> String -> m () +makeAction "FileSystemAction" [ts| MonadFileSystem String |] + +copyFileAndReturn :: MonadFileSystem e m => FilePath -> FilePath -> m String +copyFileAndReturn a b = do + x <- readFile a + writeFile b x + return x + +spec :: Spec +spec = describe "MockT" $ do + it "runs computations with mocked method implementations" $ do + let result = runST + $ copyFileAndReturn "foo.txt" "bar.txt" + & runMockT [ ReadFile "foo.txt" :-> "file contents" + , WriteFile "bar.txt" "file contents" :-> () ] + & runExceptT + result `shouldBe` Right "file contents" + + it "raises an exception if calls are not in the right order" $ do + let result = runST + $ copyFileAndReturn "foo.txt" "bar.txt" + & runMockT [ WriteFile "bar.txt" "file contents" :-> () + , ReadFile "foo.txt" :-> "file contents" ] + & runExceptT + exnMessage = + "runMockT: argument mismatch in readFile\n\ + \ given: ReadFile \"foo.txt\"\n\ + \ expected: WriteFile \"bar.txt\" \"file contents\"\n" + evaluate result `shouldThrow` errorCall exnMessage + + it "raises an exception if calls are missing" $ do + let result = -- running on top of IO + copyFileAndReturn "foo.txt" "bar.txt" + & runMockT [ ReadFile "foo.txt" :-> "file contents" + , WriteFile "bar.txt" "file contents" :-> () + , ReadFile "qux.txt" :-> "file contents 2" ] + & runExceptT + let exnMessage = + "runMockT: expected the following unexecuted actions to be run:\n\ + \ ReadFile \"qux.txt\"\n" + result `shouldThrow` errorCall exnMessage + + it "raises an exception if there are too many calls" $ do + let result = runST + $ copyFileAndReturn "foo.txt" "bar.txt" + & runMockT [ ReadFile "foo.txt" :-> "file contents" ] + & runExceptT + exnMessage = + "runMockT: expected end of program, called writeFile\n\ + \ given action: WriteFile \"bar.txt\" \"file contents\"\n" + evaluate result `shouldThrow` errorCall exnMessage