diff --git a/src/Juvix/Compiler/Backend/Isabelle/Translation/FromTyped.hs b/src/Juvix/Compiler/Backend/Isabelle/Translation/FromTyped.hs index ec552dd177..72b2f94779 100644 --- a/src/Juvix/Compiler/Backend/Isabelle/Translation/FromTyped.hs +++ b/src/Juvix/Compiler/Backend/Isabelle/Translation/FromTyped.hs @@ -2,7 +2,6 @@ module Juvix.Compiler.Backend.Isabelle.Translation.FromTyped where import Data.HashMap.Strict qualified as HashMap import Data.HashSet qualified as HashSet -import Data.List.NonEmpty.Extra qualified as NonEmpty import Data.Text qualified as T import Data.Text qualified as Text import Juvix.Compiler.Backend.Isabelle.Data.Result @@ -95,19 +94,33 @@ goModule onlyTypes infoTable Internal.Module {..} = mkExprCase c@Case {..} = case _caseValue of ExprIden v -> case _caseBranches of - CaseBranch {..} :| [] -> + CaseBranch {..} :| _ -> case _caseBranchPattern of PatVar v' -> substVar v' v _caseBranchBody _ -> ExprCase c - _ -> ExprCase c ExprTuple (Tuple (ExprIden v :| [])) -> case _caseBranches of - CaseBranch {..} :| [] -> + CaseBranch {..} :| _ -> case _caseBranchPattern of PatTuple (Tuple (PatVar v' :| [])) -> substVar v' v _caseBranchBody _ -> ExprCase c - _ -> ExprCase c - _ -> ExprCase c + _ -> + case _caseBranches of + br@CaseBranch {..} :| _ -> + case _caseBranchPattern of + PatVar _ -> + ExprCase + Case + { _caseValue = _caseValue, + _caseBranches = br :| [] + } + PatTuple (Tuple (PatVar _ :| [])) -> + ExprCase + Case + { _caseValue = _caseValue, + _caseBranches = br :| [] + } + _ -> ExprCase c goMutualBlock :: Internal.MutualBlock -> [Statement] goMutualBlock Internal.MutualBlock {..} = @@ -243,24 +256,25 @@ goModule onlyTypes infoTable Internal.Module {..} = : goClauses cls Nested pats npats -> let rhs = goExpression'' nset' nmap' _lambdaBody - argnames' = fmap getPatternArgName _lambdaPatterns + argnames' = fmap getPatternArgName lambdaPats vnames = - fmap - ( \(idx :: Int, mname) -> - maybe - ( defaultName - (getLoc cl) - ( disambiguate - (nset' ^. nameSet) - ("v_" <> show idx) - ) - ) - (overNameText (disambiguate (nset' ^. nameSet))) - mname - ) - (NonEmpty.zip (nonEmpty' [0 ..]) argnames') + nonEmpty' $ + fmap + ( \(idx :: Int, mname) -> + maybe + ( defaultName + (getLoc cl) + ( disambiguate + (nset' ^. nameSet) + ("v_" <> show idx) + ) + ) + (overNameText (disambiguate (nset' ^. nameSet))) + mname + ) + (zip [0 ..] argnames') nset'' = foldl' (flip (over nameSet . HashSet.insert . (^. namePretty))) nset' vnames - remainingBranches = goLambdaClauses'' nset'' nmap' cls + remainingBranches = goLambdaClauses'' nset'' nmap' (Just ty) cls valTuple = ExprTuple (Tuple (fmap ExprIden vnames)) patTuple = PatTuple (Tuple (nonEmpty' pats)) brs = goNestedBranches (getLoc cl) valTuple rhs remainingBranches patTuple (nonEmpty' npats) @@ -275,7 +289,8 @@ goModule onlyTypes infoTable Internal.Module {..} = } ] where - (npats0, nset', nmap') = goPatternArgsTop (filterTypeArgs 0 ty (toList _lambdaPatterns)) + lambdaPats = filterTypeArgs 0 ty (toList _lambdaPatterns) + (npats0, nset', nmap') = goPatternArgsTop lambdaPats [] -> [] goNestedBranches :: Interval -> Expression -> Expression -> [CaseBranch] -> Pattern -> NonEmpty (Expression, Nested Pattern) -> NonEmpty CaseBranch @@ -828,18 +843,7 @@ goModule onlyTypes infoTable Internal.Module {..} = | patsNum == 0 = goExpression (head _lambdaClauses ^. Internal.lambdaBody) | otherwise = goLams vars where - patsNum = - case _lambdaType of - Just ty -> - length - . filterTypeArgs 0 ty - . toList - $ head _lambdaClauses ^. Internal.lambdaPatterns - Nothing -> - length - . filter ((/= Internal.Implicit) . (^. Internal.patternArgIsImplicit)) - . toList - $ head _lambdaClauses ^. Internal.lambdaPatterns + patsNum = length $ filterLambdaPatternArgs _lambdaType $ head _lambdaClauses ^. Internal.lambdaPatterns vars = map (\i -> defaultName (getLoc lam) ("x" <> show i)) [0 .. patsNum - 1] goLams :: [Name] -> Sem r Expression @@ -869,7 +873,7 @@ goModule onlyTypes infoTable Internal.Module {..} = Tuple { _tupleComponents = nonEmpty' vars' } - brs <- goLambdaClauses (toList _lambdaClauses) + brs <- goLambdaClauses _lambdaType (toList _lambdaClauses) return $ mkExprCase Case @@ -926,17 +930,29 @@ goModule onlyTypes infoTable Internal.Module {..} = Internal.CaseBranchRhsExpression e -> goExpression e Internal.CaseBranchRhsIf {} -> error "unsupported: side conditions" - goLambdaClauses'' :: NameSet -> NameMap -> [Internal.LambdaClause] -> [CaseBranch] - goLambdaClauses'' nset nmap cls = - run $ runReader nset $ runReader nmap $ goLambdaClauses cls - - goLambdaClauses :: forall r. (Members '[Reader NameSet, Reader NameMap] r) => [Internal.LambdaClause] -> Sem r [CaseBranch] - goLambdaClauses = \case + filterLambdaPatternArgs :: Maybe Internal.Expression -> NonEmpty Internal.PatternArg -> [Internal.PatternArg] + filterLambdaPatternArgs mty cls = case mty of + Just ty -> + filterTypeArgs 0 ty + . toList + $ cls + Nothing -> + filter ((/= Internal.Implicit) . (^. Internal.patternArgIsImplicit)) + . toList + $ cls + + goLambdaClauses'' :: NameSet -> NameMap -> Maybe Internal.Expression -> [Internal.LambdaClause] -> [CaseBranch] + goLambdaClauses'' nset nmap mty cls = + run $ runReader nset $ runReader nmap $ goLambdaClauses mty cls + + goLambdaClauses :: forall r. (Members '[Reader NameSet, Reader NameMap] r) => Maybe Internal.Expression -> [Internal.LambdaClause] -> Sem r [CaseBranch] + goLambdaClauses mty = \case cl@Internal.LambdaClause {..} : cls -> do - (npat, nset, nmap) <- case _lambdaPatterns of - p :| [] -> goPatternArgCase p + let lambdaPats = filterLambdaPatternArgs mty _lambdaPatterns + (npat, nset, nmap) <- case lambdaPats of + [p] -> goPatternArgCase p _ -> do - (npats, nset, nmap) <- goPatternArgsCase (toList _lambdaPatterns) + (npats, nset, nmap) <- goPatternArgsCase lambdaPats let npat = fmap ( \pats -> @@ -950,7 +966,7 @@ goModule onlyTypes infoTable Internal.Module {..} = case npat of Nested pat [] -> do body <- withLocalNames nset nmap $ goExpression _lambdaBody - brs <- goLambdaClauses cls + brs <- goLambdaClauses mty cls return $ CaseBranch { _caseBranchPattern = pat, @@ -961,7 +977,7 @@ goModule onlyTypes infoTable Internal.Module {..} = let vname = defaultName (getLoc cl) (disambiguate (nset ^. nameSet) "v") nset' = over nameSet (HashSet.insert (vname ^. namePretty)) nset rhs <- withLocalNames nset' nmap $ goExpression _lambdaBody - remainingBranches <- withLocalNames nset' nmap $ goLambdaClauses cls + remainingBranches <- withLocalNames nset' nmap $ goLambdaClauses mty cls let brs' = goNestedBranches (getLoc vname) (ExprIden vname) rhs remainingBranches pat (nonEmpty' npats) return [ CaseBranch @@ -1133,7 +1149,11 @@ goModule onlyTypes infoTable Internal.Module {..} = case HashMap.lookup name (infoTable ^. Internal.infoConstructors) of Just ctrInfo | ctrInfo ^. Internal.constructorInfoRecord -> - Just (indName, goRecordFields (getArgtys ctrInfo) args) + case HashMap.lookup indName (infoTable ^. Internal.infoInductives) of + Just indInfo + | length (indInfo ^. Internal.inductiveInfoConstructors) == 1 -> + Just (indName, goRecordFields (getArgtys ctrInfo) args) + _ -> Nothing where indName = ctrInfo ^. Internal.constructorInfoInductive _ -> Nothing diff --git a/tests/positive/Isabelle/Program.juvix b/tests/positive/Isabelle/Program.juvix index 217bae8939..0d1b313db0 100644 --- a/tests/positive/Isabelle/Program.juvix +++ b/tests/positive/Isabelle/Program.juvix @@ -154,3 +154,44 @@ funR4 : R -> R bf (b1 b2 : Bool) : Bool := not (b1 && b2); nf (n1 n2 : Int) : Bool := n1 - n2 >= n1 || n2 <= n1 + n2; + +-- Nested record patterns + +type MessagePacket (MessageType : Type) : Type := mkMessagePacket { + target : Nat; + mailbox : Maybe Nat; + message : MessageType; +}; + +type EnvelopedMessage (MessageType : Type) : Type := + mkEnvelopedMessage { + sender : Maybe Nat; + packet : MessagePacket MessageType; + }; + +type Timer (HandleType : Type): Type := mkTimer { + time : Nat; + handle : HandleType; +}; + +type Trigger (MessageType : Type) (HandleType : Type) := + | MessageArrived { envelope : EnvelopedMessage MessageType; } + | Elapsed { timers : List (Timer HandleType) }; + +getMessageFromTrigger : {M H : Type} -> Trigger M H -> Maybe M + | (MessageArrived@{ + envelope := (mkEnvelopedMessage@{ + packet := (mkMessagePacket@{ + message := m })})}) + := just m + | _ := nothing; + + +getMessageFromTrigger' {M H} (t : Trigger M H) : Maybe M := + case t of + | (MessageArrived@{ + envelope := (mkEnvelopedMessage@{ + packet := (mkMessagePacket@{ + message := m })})}) + := just m + | _ := nothing; diff --git a/tests/positive/Isabelle/isabelle/Program.thy b/tests/positive/Isabelle/isabelle/Program.thy index fec5012e10..3ef6bb3b09 100644 --- a/tests/positive/Isabelle/isabelle/Program.thy +++ b/tests/positive/Isabelle/isabelle/Program.thy @@ -240,4 +240,62 @@ fun bf :: "bool \ bool \ bool" where fun nf :: "int \ int \ bool" where "nf n1 n2 = (n1 - n2 \ n1 \ n2 \ n1 + n2)" +(* Nested record patterns *) +record 'MessageType MessagePacket = + target :: nat + mailbox :: "nat option" + message :: 'MessageType + +record 'MessageType EnvelopedMessage = + sender :: "nat option" + packet :: "'MessageType MessagePacket" + +record 'HandleType Timer = + time :: nat + handle :: 'HandleType + +datatype ('MessageType, 'HandleType) Trigger + = MessageArrived "'MessageType EnvelopedMessage" | + Elapsed "('HandleType Timer) list" + +fun target :: "'MessageType MessagePacket \ nat" where + "target (| MessagePacket.target = target', MessagePacket.mailbox = mailbox', MessagePacket.message = message' |) = + target'" + +fun mailbox :: "'MessageType MessagePacket \ nat option" where + "mailbox (| MessagePacket.target = target', MessagePacket.mailbox = mailbox', MessagePacket.message = message' |) = + mailbox'" + +fun message :: "'MessageType MessagePacket \ 'MessageType" where + "message (| MessagePacket.target = target', MessagePacket.mailbox = mailbox', MessagePacket.message = message' |) = + message'" + +fun sender :: "'MessageType EnvelopedMessage \ nat option" where + "sender (| EnvelopedMessage.sender = sender', EnvelopedMessage.packet = packet' |) = sender'" + +fun packet :: "'MessageType EnvelopedMessage \ 'MessageType MessagePacket" where + "packet (| EnvelopedMessage.sender = sender', EnvelopedMessage.packet = packet' |) = packet'" + +fun time :: "'HandleType Timer \ nat" where + "time (| Timer.time = time', Timer.handle = handle' |) = time'" + +fun handle :: "'HandleType Timer \ 'HandleType" where + "handle (| Timer.time = time', Timer.handle = handle' |) = handle'" + +fun getMessageFromTrigger :: "('M, 'H) Trigger \ 'M option" where + "getMessageFromTrigger v_0 = + (case (v_0) of + (MessageArrived v') \ + (case (EnvelopedMessage.packet v') of + (v'0) \ Some (MessagePacket.message v'0)) | + v'1 \ None)" + +fun getMessageFromTrigger' :: "('M, 'H) Trigger \ 'M option" where + "getMessageFromTrigger' t = + (case t of + (MessageArrived v') \ + (case (EnvelopedMessage.packet v') of + (v'0) \ Some (MessagePacket.message v'0)) | + v'2 \ None)" + end