Skip to content

Commit

Permalink
generate recursive derived instance functions
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszcz committed Nov 25, 2024
1 parent 68783a9 commit a470c62
Showing 1 changed file with 45 additions and 10 deletions.
55 changes: 45 additions & 10 deletions src/Juvix/Compiler/Internal/Translation/FromConcrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -554,12 +554,33 @@ deriveEq ::
Sem r Internal.FunctionDef
deriveEq DerivingArgs {..} = do
arg <- getArg
argty <- getArgType
argsInfo <- goArgsInfo _derivingInstanceName
lam <- eqLambda arg
lamName <- Internal.freshFunVar (getLoc _derivingInstanceName) ("__eq__" <> _derivingInstanceName ^. Internal.nameText)
let lam = Internal.ExpressionIden (Internal.IdenFunction lamName)
lamFun <- eqLambda lam arg argty
lamTy <- Internal.ExpressionHole <$> Internal.freshHole (getLoc _derivingInstanceName)
let lamDef =
Internal.FunctionDef
{ _funDefTerminating = False,
_funDefIsInstanceCoercion = Nothing,
_funDefPragmas = mempty,
_funDefArgsInfo = [],
_funDefDocComment = Nothing,
_funDefName = lamName,
_funDefType = lamTy,
_funDefBody = lamFun,
_funDefBuiltin = Nothing
}
mkEq <- getBuiltin (getLoc eqName) BuiltinMkEq
let body = mkEq Internal.@@ lam
ty = Internal.foldFunType _derivingParameters ret
pragmas' <- goPragmas _derivingPragmas
let body =
Internal.ExpressionLet
Internal.Let
{ _letClauses = pure (Internal.LetFunDef lamDef),
_letExpression = mkEq Internal.@@ lam
}
ty = Internal.foldFunType _derivingParameters ret
return
Internal.FunctionDef
{ _funDefTerminating = False,
Expand All @@ -586,8 +607,13 @@ deriveEq DerivingArgs {..} = do
Internal.ExpressionIden (Internal.IdenInductive ind) <- return (fst (Internal.unfoldExpressionApp a))
getDefinedInductive ind

eqLambda :: Internal.InductiveInfo -> Sem r Internal.Expression
eqLambda d = do
getArgType :: Sem r Internal.Expression
getArgType = runFailDefaultM (throwDerivingWrongForm ret) $ do
[Internal.ApplicationArg Explicit a] <- return args
return a

eqLambda :: Internal.Expression -> Internal.InductiveInfo -> Internal.Expression -> Sem r Internal.Expression
eqLambda lam d argty = do
let loc = getLoc eqName
band <- getBuiltin loc BuiltinBoolAnd
btrue <- getBuiltin loc BuiltinBoolTrue
Expand Down Expand Up @@ -627,6 +653,7 @@ deriveEq DerivingArgs {..} = do
Internal.ConstructorName ->
Sem r Internal.LambdaClause
lambdaClause band btrue bisEqual c = do
argsRecursive :: [Bool] <- getRecursiveArgs
numArgs :: [IsImplicit] <- getNumArgs
let loc = getLoc _derivingInstanceName
mkpat :: Sem r ([Internal.VarName], Internal.PatternArg)
Expand All @@ -641,22 +668,24 @@ deriveEq DerivingArgs {..} = do
return
Internal.LambdaClause
{ _lambdaPatterns = p1 :| [p2],
_lambdaBody = allEq (zipExact v1 v2)
_lambdaBody = allEq (zip3Exact v1 v2 argsRecursive)
}
where
allEq :: (Internal.IsExpression expr) => [(expr, expr)] -> Internal.Expression
allEq :: (Internal.IsExpression expr) => [(expr, expr, Bool)] -> Internal.Expression
allEq k = case nonEmpty k of
Nothing -> Internal.toExpression btrue
Just l -> mkAnds (fmap (uncurry mkEq) l)
Just l -> mkAnds (fmap (uncurry3 mkEq) l)

mkAnds :: (Internal.IsExpression expr) => NonEmpty expr -> Internal.Expression
mkAnds = foldl1 mkAnd . fmap Internal.toExpression

mkAnd :: (Internal.IsExpression expr) => expr -> expr -> Internal.Expression
mkAnd a b = band Internal.@@ a Internal.@@ b

mkEq :: (Internal.IsExpression expr) => expr -> expr -> Internal.Expression
mkEq a b = bisEqual Internal.@@ a Internal.@@ b
mkEq :: (Internal.IsExpression expr) => expr -> expr -> Bool -> Internal.Expression
mkEq a b isRec
| isRec = lam Internal.@@ a Internal.@@ b
| otherwise = bisEqual Internal.@@ a Internal.@@ b

getNumArgs :: Sem r [IsImplicit]
getNumArgs = do
Expand All @@ -668,6 +697,12 @@ deriveEq DerivingArgs {..} = do
. each
. Internal.paramImplicit

getRecursiveArgs :: Sem r [Bool]
getRecursiveArgs = do
def <- getDefinedConstructor c
let argTypes = map (^. Internal.paramType) $ Internal.constructorArgs (def ^. Internal.constructorInfoType)
return $ map (== argty) argTypes

goFunctionDef ::
forall r.
(Members '[Reader DefaultArgsStack, Reader Pragmas, Error ScoperError, NameIdGen, Reader S.InfoTable] r) =>
Expand Down

0 comments on commit a470c62

Please sign in to comment.