From a470c62e6fad8340b2f9f764f97d4884b6a10919 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Fri, 22 Nov 2024 20:17:14 +0100 Subject: [PATCH] generate recursive derived instance functions --- .../Internal/Translation/FromConcrete.hs | 55 +++++++++++++++---- 1 file changed, 45 insertions(+), 10 deletions(-) diff --git a/src/Juvix/Compiler/Internal/Translation/FromConcrete.hs b/src/Juvix/Compiler/Internal/Translation/FromConcrete.hs index 9a6b85f731..e1b2504167 100644 --- a/src/Juvix/Compiler/Internal/Translation/FromConcrete.hs +++ b/src/Juvix/Compiler/Internal/Translation/FromConcrete.hs @@ -554,12 +554,33 @@ deriveEq :: Sem r Internal.FunctionDef deriveEq DerivingArgs {..} = do arg <- getArg + argty <- getArgType argsInfo <- goArgsInfo _derivingInstanceName - lam <- eqLambda arg + lamName <- Internal.freshFunVar (getLoc _derivingInstanceName) ("__eq__" <> _derivingInstanceName ^. Internal.nameText) + let lam = Internal.ExpressionIden (Internal.IdenFunction lamName) + lamFun <- eqLambda lam arg argty + lamTy <- Internal.ExpressionHole <$> Internal.freshHole (getLoc _derivingInstanceName) + let lamDef = + Internal.FunctionDef + { _funDefTerminating = False, + _funDefIsInstanceCoercion = Nothing, + _funDefPragmas = mempty, + _funDefArgsInfo = [], + _funDefDocComment = Nothing, + _funDefName = lamName, + _funDefType = lamTy, + _funDefBody = lamFun, + _funDefBuiltin = Nothing + } mkEq <- getBuiltin (getLoc eqName) BuiltinMkEq - let body = mkEq Internal.@@ lam - ty = Internal.foldFunType _derivingParameters ret pragmas' <- goPragmas _derivingPragmas + let body = + Internal.ExpressionLet + Internal.Let + { _letClauses = pure (Internal.LetFunDef lamDef), + _letExpression = mkEq Internal.@@ lam + } + ty = Internal.foldFunType _derivingParameters ret return Internal.FunctionDef { _funDefTerminating = False, @@ -586,8 +607,13 @@ deriveEq DerivingArgs {..} = do Internal.ExpressionIden (Internal.IdenInductive ind) <- return (fst (Internal.unfoldExpressionApp a)) getDefinedInductive ind - eqLambda :: Internal.InductiveInfo -> Sem r Internal.Expression - eqLambda d = do + getArgType :: Sem r Internal.Expression + getArgType = runFailDefaultM (throwDerivingWrongForm ret) $ do + [Internal.ApplicationArg Explicit a] <- return args + return a + + eqLambda :: Internal.Expression -> Internal.InductiveInfo -> Internal.Expression -> Sem r Internal.Expression + eqLambda lam d argty = do let loc = getLoc eqName band <- getBuiltin loc BuiltinBoolAnd btrue <- getBuiltin loc BuiltinBoolTrue @@ -627,6 +653,7 @@ deriveEq DerivingArgs {..} = do Internal.ConstructorName -> Sem r Internal.LambdaClause lambdaClause band btrue bisEqual c = do + argsRecursive :: [Bool] <- getRecursiveArgs numArgs :: [IsImplicit] <- getNumArgs let loc = getLoc _derivingInstanceName mkpat :: Sem r ([Internal.VarName], Internal.PatternArg) @@ -641,13 +668,13 @@ deriveEq DerivingArgs {..} = do return Internal.LambdaClause { _lambdaPatterns = p1 :| [p2], - _lambdaBody = allEq (zipExact v1 v2) + _lambdaBody = allEq (zip3Exact v1 v2 argsRecursive) } where - allEq :: (Internal.IsExpression expr) => [(expr, expr)] -> Internal.Expression + allEq :: (Internal.IsExpression expr) => [(expr, expr, Bool)] -> Internal.Expression allEq k = case nonEmpty k of Nothing -> Internal.toExpression btrue - Just l -> mkAnds (fmap (uncurry mkEq) l) + Just l -> mkAnds (fmap (uncurry3 mkEq) l) mkAnds :: (Internal.IsExpression expr) => NonEmpty expr -> Internal.Expression mkAnds = foldl1 mkAnd . fmap Internal.toExpression @@ -655,8 +682,10 @@ deriveEq DerivingArgs {..} = do mkAnd :: (Internal.IsExpression expr) => expr -> expr -> Internal.Expression mkAnd a b = band Internal.@@ a Internal.@@ b - mkEq :: (Internal.IsExpression expr) => expr -> expr -> Internal.Expression - mkEq a b = bisEqual Internal.@@ a Internal.@@ b + mkEq :: (Internal.IsExpression expr) => expr -> expr -> Bool -> Internal.Expression + mkEq a b isRec + | isRec = lam Internal.@@ a Internal.@@ b + | otherwise = bisEqual Internal.@@ a Internal.@@ b getNumArgs :: Sem r [IsImplicit] getNumArgs = do @@ -668,6 +697,12 @@ deriveEq DerivingArgs {..} = do . each . Internal.paramImplicit + getRecursiveArgs :: Sem r [Bool] + getRecursiveArgs = do + def <- getDefinedConstructor c + let argTypes = map (^. Internal.paramType) $ Internal.constructorArgs (def ^. Internal.constructorInfoType) + return $ map (== argty) argTypes + goFunctionDef :: forall r. (Members '[Reader DefaultArgsStack, Reader Pragmas, Error ScoperError, NameIdGen, Reader S.InfoTable] r) =>