diff --git a/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/FunctionCall.hs b/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/FunctionCall.hs index 04040d4b12..33c59e53e0 100644 --- a/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/FunctionCall.hs +++ b/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/FunctionCall.hs @@ -7,10 +7,11 @@ import Data.HashMap.Strict qualified as HashMap import Juvix.Compiler.Internal.Extra import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data import Juvix.Prelude +import Safe (headMay) viewCall :: forall r. - (Members '[Reader SizeInfo] r) => + (Members '[Reader SizeInfoMap] r) => Expression -> Sem r (Maybe FunCall) viewCall = \case @@ -19,12 +20,15 @@ viewCall = \case ExpressionApplication (Application f x impl) | isImplicitOrInstance impl -> viewCall f -- implicit arguments are ignored | otherwise -> do - c <- viewCall f - x' <- callArg - return $ over callArgs (`snoc` x') <$> c + mc <- viewCall f + case mc of + Just c -> do + x' <- callArg (c ^. callRef) + return $ Just $ over callArgs (`snoc` x') c + Nothing -> return Nothing where - callArg :: Sem r (CallRow, Expression) - callArg = do + callArg :: FunctionRef -> Sem r (CallRow, Expression) + callArg fref = do lt <- (^. callRow) <$> lessThan eq <- (^. callRow) <$> equalTo return (CallRow (lt `mplus` eq), x) @@ -33,7 +37,7 @@ viewCall = \case lessThan = case viewExpressionAsPattern x of Nothing -> return (CallRow Nothing) Just x' -> do - s <- asks (findIndex (elem x') . (^. sizeSmaller)) + s <- asks (findIndex (elem x') . (^. sizeSmaller) . findSizeInfo) return $ case s of Nothing -> CallRow Nothing Just s' -> CallRow (Just (s', RLe)) @@ -41,11 +45,37 @@ viewCall = \case equalTo = case viewExpressionAsPattern x of Just x' -> do - s <- asks (elemIndex x' . (^. sizeEqual)) + s <- asks (elemIndex x' . (^. sizeEqual) . findSizeInfo) return $ case s of Nothing -> CallRow Nothing Just s' -> CallRow (Just (s', REq)) Nothing -> return (CallRow Nothing) + findSizeInfo :: SizeInfoMap -> SizeInfo + findSizeInfo infos = + {- + If the call is not to any nested function being defined, then we + associate it with the most nested function. Without this, + termination for mutually recursive functions doesn't work. + + Consider: + ``` + isEven (x : Nat) : Bool := + let + isEven' : Nat -> Bool + | zero := true + | (suc n) := isOdd' n; + isOdd' : Nat -> Bool + | zero := false + | (suc n) := isEven' n; + in isEven' x; + ``` + The call `isEven' n` inside `isOdd'` needs to be associated with + `isOdd'`, not with `isEven`, and not just forgotten. + -} + fromMaybe (maybe emptySizeInfo snd . headMay $ infos ^. sizeInfoMap) + . (lookup fref) + . (^. sizeInfoMap) + $ infos _ -> return Nothing where singletonCall :: FunctionRef -> FunCall diff --git a/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/Termination/Checker.hs b/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/Termination/Checker.hs index 08d56beb80..9923e223fa 100644 --- a/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/Termination/Checker.hs +++ b/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/Termination/Checker.hs @@ -59,6 +59,7 @@ instance Scannable Expression where buildCallMap = run . execState emptyCallMap + . runReader emptySizeInfoMap . scanTopExpression runTerminationState :: TerminationState -> Sem (Termination ': r) a -> Sem r (TerminationState, a) @@ -122,21 +123,21 @@ scanInductive i = do scanMutualStatement :: (Members '[State CallMap] r) => MutualStatement -> Sem r () scanMutualStatement = \case StatementInductive i -> scanInductive i - StatementFunction i -> scanFunctionDef i + StatementFunction i -> runReader emptySizeInfoMap $ scanFunctionDef i StatementAxiom a -> scanAxiom a scanAxiom :: (Members '[State CallMap] r) => AxiomDef -> Sem r () scanAxiom = scanTopExpression . (^. axiomType) scanFunctionDef :: - (Members '[State CallMap] r) => + (Members '[State CallMap, Reader SizeInfoMap] r) => FunctionDef -> Sem r () scanFunctionDef f@FunctionDef {..} = do registerFunctionDef f runReader (Just _funDefName) $ do scanTypeSignature _funDefType - scanFunctionBody _funDefBody + scanFunctionBody _funDefName _funDefBody scanDefaultArgs _funDefArgsInfo scanDefaultArgs :: @@ -153,38 +154,41 @@ scanTypeSignature :: (Members '[State CallMap, Reader (Maybe FunctionRef)] r) => Expression -> Sem r () -scanTypeSignature = runReader emptySizeInfo . scanExpression +scanTypeSignature = runReader emptySizeInfoMap . scanExpression scanFunctionBody :: forall r. - (Members '[State CallMap, Reader (Maybe FunctionRef)] r) => + (Members '[State CallMap, Reader SizeInfoMap, Reader (Maybe FunctionRef)] r) => + FunctionName -> Expression -> Sem r () -scanFunctionBody topbody = go [] topbody +scanFunctionBody funName topbody = go [] topbody where go :: [PatternArg] -> Expression -> Sem r () go revArgs body = case body of ExpressionLambda Lambda {..} -> mapM_ goClause _lambdaClauses - _ -> runReader (mkSizeInfo (reverse revArgs)) (scanExpression body) + _ -> + local + (over sizeInfoMap ((funName, mkSizeInfo (reverse revArgs)) :)) + (scanExpression body) where goClause :: LambdaClause -> Sem r () goClause (LambdaClause pats clBody) = go (reverse (toList pats) ++ revArgs) clBody scanLet :: - (Members '[State CallMap, Reader (Maybe FunctionRef), Reader SizeInfo] r) => + (Members '[State CallMap, Reader (Maybe FunctionRef), Reader SizeInfoMap] r) => Let -> Sem r () scanLet l = do mapM_ scanLetClause (l ^. letClauses) scanExpression (l ^. letExpression) --- NOTE that we forget about the arguments of the hosting function -scanLetClause :: (Members '[State CallMap] r) => LetClause -> Sem r () +scanLetClause :: (Members '[State CallMap, Reader SizeInfoMap] r) => LetClause -> Sem r () scanLetClause = \case LetFunDef d -> scanFunctionDef d LetMutualBlock m -> scanMutualBlockLet m -scanMutualBlockLet :: (Members '[State CallMap] r) => MutualBlockLet -> Sem r () +scanMutualBlockLet :: (Members '[State CallMap, Reader SizeInfoMap] r) => MutualBlockLet -> Sem r () scanMutualBlockLet MutualBlockLet {..} = mapM_ scanFunctionDef _mutualLet scanTopExpression :: @@ -192,18 +196,26 @@ scanTopExpression :: Expression -> Sem r () scanTopExpression = - runReader (Nothing @FunctionRef) - . runReader emptySizeInfo + runReader emptySizeInfoMap + . runReader (Nothing @FunctionRef) . scanExpression scanExpression :: - (Members '[State CallMap, Reader (Maybe FunctionRef), Reader SizeInfo] r) => + (Members '[State CallMap, Reader (Maybe FunctionRef), Reader SizeInfoMap] r) => Expression -> Sem r () scanExpression e = viewCall e >>= \case Just c -> do - whenJustM (ask @(Maybe FunctionRef)) (\caller -> runReader caller (registerCall c)) + -- Are we recursively calling a function being defined? + recCall <- asks (elem (c ^. callRef) . map fst . (^. sizeInfoMap)) + if + | recCall -> + runReader (c ^. callRef) (registerCall c) + | otherwise -> + whenJustM + (ask @(Maybe FunctionRef)) + (\caller -> runReader caller (registerCall c)) mapM_ (scanExpression . snd) (c ^. callArgs) Nothing -> case e of ExpressionApplication a -> directExpressions_ scanExpression a diff --git a/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/Termination/Data/SizeInfo.hs b/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/Termination/Data/SizeInfo.hs index 0bd7bae97e..1a736d56f4 100644 --- a/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/Termination/Data/SizeInfo.hs +++ b/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/Termination/Data/SizeInfo.hs @@ -12,6 +12,12 @@ data SizeInfo = SizeInfo makeLenses ''SizeInfo +newtype SizeInfoMap = SizeInfoMap + { _sizeInfoMap :: [(FunctionName, SizeInfo)] + } + +makeLenses ''SizeInfoMap + emptySizeInfo :: SizeInfo emptySizeInfo = SizeInfo @@ -19,6 +25,9 @@ emptySizeInfo = _sizeSmaller = mempty } +emptySizeInfoMap :: SizeInfoMap +emptySizeInfoMap = SizeInfoMap [] + mkSizeInfo :: [PatternArg] -> SizeInfo mkSizeInfo args = SizeInfo {..} where diff --git a/test/Termination/Positive.hs b/test/Termination/Positive.hs index f1fc4ca5ce..b8194d0055 100644 --- a/test/Termination/Positive.hs +++ b/test/Termination/Positive.hs @@ -67,7 +67,15 @@ tests = PosTest "Ignore instance arguments" $(mkRelDir ".") - $(mkRelFile "issue2414.juvix") + $(mkRelFile "issue2414.juvix"), + PosTest + "Nested local definitions" + $(mkRelDir ".") + $(mkRelFile "Nested1.juvix"), + PosTest + "Named arguments" + $(mkRelDir ".") + $(mkRelFile "Nested2.juvix") ] testsWithKeyword :: [PosTest] diff --git a/tests/positive/Termination/Nested1.juvix b/tests/positive/Termination/Nested1.juvix new file mode 100644 index 0000000000..4c3bab7b18 --- /dev/null +++ b/tests/positive/Termination/Nested1.juvix @@ -0,0 +1,11 @@ +module Nested1; + +import Stdlib.Data.List open; + +go {A B} (f : A -> B) : List A -> List B + | nil := nil + | (elem :: next) := + let + var1 := f elem; + var2 := go f next; + in var1 :: var2; diff --git a/tests/positive/Termination/Nested2.juvix b/tests/positive/Termination/Nested2.juvix new file mode 100644 index 0000000000..bfbe1f3fd8 --- /dev/null +++ b/tests/positive/Termination/Nested2.juvix @@ -0,0 +1,16 @@ +module Nested2; + +type MyList A := + | myNil + | myCons@{ + elem : A; + next : MyList A; + }; + +go {A B} (f : A -> B) : MyList A -> MyList B + | myNil := myNil + | myCons@{elem; next} := + myCons@{ + elem := f elem; + next := go f next; + };