Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix nested pattern matching
Browse files Browse the repository at this point in the history
lukaszcz committed Sep 18, 2024
1 parent d855023 commit 42bd0a8
Showing 3 changed files with 167 additions and 48 deletions.
116 changes: 68 additions & 48 deletions src/Juvix/Compiler/Backend/Isabelle/Translation/FromTyped.hs
Original file line number Diff line number Diff line change
@@ -2,7 +2,6 @@ module Juvix.Compiler.Backend.Isabelle.Translation.FromTyped where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Data.List.NonEmpty.Extra qualified as NonEmpty
import Data.Text qualified as T
import Data.Text qualified as Text
import Juvix.Compiler.Backend.Isabelle.Data.Result
@@ -95,19 +94,33 @@ goModule onlyTypes infoTable Internal.Module {..} =
mkExprCase c@Case {..} = case _caseValue of
ExprIden v ->
case _caseBranches of
CaseBranch {..} :| [] ->
CaseBranch {..} :| _ ->
case _caseBranchPattern of
PatVar v' -> substVar v' v _caseBranchBody
_ -> ExprCase c
_ -> ExprCase c
ExprTuple (Tuple (ExprIden v :| [])) ->
case _caseBranches of
CaseBranch {..} :| [] ->
CaseBranch {..} :| _ ->
case _caseBranchPattern of
PatTuple (Tuple (PatVar v' :| [])) -> substVar v' v _caseBranchBody
_ -> ExprCase c
_ -> ExprCase c
_ -> ExprCase c
_ ->
case _caseBranches of
br@CaseBranch {..} :| _ ->
case _caseBranchPattern of
PatVar _ ->
ExprCase
Case
{ _caseValue = _caseValue,
_caseBranches = br :| []
}
PatTuple (Tuple (PatVar _ :| [])) ->
ExprCase
Case
{ _caseValue = _caseValue,
_caseBranches = br :| []
}
_ -> ExprCase c

goMutualBlock :: Internal.MutualBlock -> [Statement]
goMutualBlock Internal.MutualBlock {..} =
@@ -243,24 +256,25 @@ goModule onlyTypes infoTable Internal.Module {..} =
: goClauses cls
Nested pats npats ->
let rhs = goExpression'' nset' nmap' _lambdaBody
argnames' = fmap getPatternArgName _lambdaPatterns
argnames' = fmap getPatternArgName lambdaPats
vnames =
fmap
( \(idx :: Int, mname) ->
maybe
( defaultName
(getLoc cl)
( disambiguate
(nset' ^. nameSet)
("v_" <> show idx)
)
)
(overNameText (disambiguate (nset' ^. nameSet)))
mname
)
(NonEmpty.zip (nonEmpty' [0 ..]) argnames')
nonEmpty' $
fmap
( \(idx :: Int, mname) ->
maybe
( defaultName
(getLoc cl)
( disambiguate
(nset' ^. nameSet)
("v_" <> show idx)
)
)
(overNameText (disambiguate (nset' ^. nameSet)))
mname
)
(zip [0 ..] argnames')
nset'' = foldl' (flip (over nameSet . HashSet.insert . (^. namePretty))) nset' vnames
remainingBranches = goLambdaClauses'' nset'' nmap' cls
remainingBranches = goLambdaClauses'' nset'' nmap' (Just ty) cls
valTuple = ExprTuple (Tuple (fmap ExprIden vnames))
patTuple = PatTuple (Tuple (nonEmpty' pats))
brs = goNestedBranches (getLoc cl) valTuple rhs remainingBranches patTuple (nonEmpty' npats)
@@ -275,7 +289,8 @@ goModule onlyTypes infoTable Internal.Module {..} =
}
]
where
(npats0, nset', nmap') = goPatternArgsTop (filterTypeArgs 0 ty (toList _lambdaPatterns))
lambdaPats = filterTypeArgs 0 ty (toList _lambdaPatterns)
(npats0, nset', nmap') = goPatternArgsTop lambdaPats
[] -> []

goNestedBranches :: Interval -> Expression -> Expression -> [CaseBranch] -> Pattern -> NonEmpty (Expression, Nested Pattern) -> NonEmpty CaseBranch
@@ -828,18 +843,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
| patsNum == 0 = goExpression (head _lambdaClauses ^. Internal.lambdaBody)
| otherwise = goLams vars
where
patsNum =
case _lambdaType of
Just ty ->
length
. filterTypeArgs 0 ty
. toList
$ head _lambdaClauses ^. Internal.lambdaPatterns
Nothing ->
length
. filter ((/= Internal.Implicit) . (^. Internal.patternArgIsImplicit))
. toList
$ head _lambdaClauses ^. Internal.lambdaPatterns
patsNum = length $ filterLambdaPatternArgs _lambdaType $ head _lambdaClauses ^. Internal.lambdaPatterns
vars = map (\i -> defaultName (getLoc lam) ("x" <> show i)) [0 .. patsNum - 1]

goLams :: [Name] -> Sem r Expression
@@ -869,7 +873,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
Tuple
{ _tupleComponents = nonEmpty' vars'
}
brs <- goLambdaClauses (toList _lambdaClauses)
brs <- goLambdaClauses _lambdaType (toList _lambdaClauses)
return $
mkExprCase
Case
@@ -926,17 +930,29 @@ goModule onlyTypes infoTable Internal.Module {..} =
Internal.CaseBranchRhsExpression e -> goExpression e
Internal.CaseBranchRhsIf {} -> error "unsupported: side conditions"

goLambdaClauses'' :: NameSet -> NameMap -> [Internal.LambdaClause] -> [CaseBranch]
goLambdaClauses'' nset nmap cls =
run $ runReader nset $ runReader nmap $ goLambdaClauses cls

goLambdaClauses :: forall r. (Members '[Reader NameSet, Reader NameMap] r) => [Internal.LambdaClause] -> Sem r [CaseBranch]
goLambdaClauses = \case
filterLambdaPatternArgs :: Maybe Internal.Expression -> NonEmpty Internal.PatternArg -> [Internal.PatternArg]
filterLambdaPatternArgs mty cls = case mty of
Just ty ->
filterTypeArgs 0 ty
. toList
$ cls
Nothing ->
filter ((/= Internal.Implicit) . (^. Internal.patternArgIsImplicit))
. toList
$ cls

goLambdaClauses'' :: NameSet -> NameMap -> Maybe Internal.Expression -> [Internal.LambdaClause] -> [CaseBranch]
goLambdaClauses'' nset nmap mty cls =
run $ runReader nset $ runReader nmap $ goLambdaClauses mty cls

goLambdaClauses :: forall r. (Members '[Reader NameSet, Reader NameMap] r) => Maybe Internal.Expression -> [Internal.LambdaClause] -> Sem r [CaseBranch]
goLambdaClauses mty = \case
cl@Internal.LambdaClause {..} : cls -> do
(npat, nset, nmap) <- case _lambdaPatterns of
p :| [] -> goPatternArgCase p
let lambdaPats = filterLambdaPatternArgs mty _lambdaPatterns
(npat, nset, nmap) <- case lambdaPats of
[p] -> goPatternArgCase p
_ -> do
(npats, nset, nmap) <- goPatternArgsCase (toList _lambdaPatterns)
(npats, nset, nmap) <- goPatternArgsCase lambdaPats
let npat =
fmap
( \pats ->
@@ -950,7 +966,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
case npat of
Nested pat [] -> do
body <- withLocalNames nset nmap $ goExpression _lambdaBody
brs <- goLambdaClauses cls
brs <- goLambdaClauses mty cls
return $
CaseBranch
{ _caseBranchPattern = pat,
@@ -961,7 +977,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
let vname = defaultName (getLoc cl) (disambiguate (nset ^. nameSet) "v")
nset' = over nameSet (HashSet.insert (vname ^. namePretty)) nset
rhs <- withLocalNames nset' nmap $ goExpression _lambdaBody
remainingBranches <- withLocalNames nset' nmap $ goLambdaClauses cls
remainingBranches <- withLocalNames nset' nmap $ goLambdaClauses mty cls
let brs' = goNestedBranches (getLoc vname) (ExprIden vname) rhs remainingBranches pat (nonEmpty' npats)
return
[ CaseBranch
@@ -1133,7 +1149,11 @@ goModule onlyTypes infoTable Internal.Module {..} =
case HashMap.lookup name (infoTable ^. Internal.infoConstructors) of
Just ctrInfo
| ctrInfo ^. Internal.constructorInfoRecord ->
Just (indName, goRecordFields (getArgtys ctrInfo) args)
case HashMap.lookup indName (infoTable ^. Internal.infoInductives) of
Just indInfo
| length (indInfo ^. Internal.inductiveInfoConstructors) == 1 ->
Just (indName, goRecordFields (getArgtys ctrInfo) args)
_ -> Nothing
where
indName = ctrInfo ^. Internal.constructorInfoInductive
_ -> Nothing
41 changes: 41 additions & 0 deletions tests/positive/Isabelle/Program.juvix
Original file line number Diff line number Diff line change
@@ -154,3 +154,44 @@ funR4 : R -> R
bf (b1 b2 : Bool) : Bool := not (b1 && b2);

nf (n1 n2 : Int) : Bool := n1 - n2 >= n1 || n2 <= n1 + n2;

-- Nested record patterns

type MessagePacket (MessageType : Type) : Type := mkMessagePacket {
target : Nat;
mailbox : Maybe Nat;
message : MessageType;
};

type EnvelopedMessage (MessageType : Type) : Type :=
mkEnvelopedMessage {
sender : Maybe Nat;
packet : MessagePacket MessageType;
};

type Timer (HandleType : Type): Type := mkTimer {
time : Nat;
handle : HandleType;
};

type Trigger (MessageType : Type) (HandleType : Type) :=
| MessageArrived { envelope : EnvelopedMessage MessageType; }
| Elapsed { timers : List (Timer HandleType) };

getMessageFromTrigger : {M H : Type} -> Trigger M H -> Maybe M
| (MessageArrived@{
envelope := (mkEnvelopedMessage@{
packet := (mkMessagePacket@{
message := m })})})
:= just m
| _ := nothing;


getMessageFromTrigger' {M H} (t : Trigger M H) : Maybe M :=
case t of
| (MessageArrived@{
envelope := (mkEnvelopedMessage@{
packet := (mkMessagePacket@{
message := m })})})
:= just m
| _ := nothing;
58 changes: 58 additions & 0 deletions tests/positive/Isabelle/isabelle/Program.thy
Original file line number Diff line number Diff line change
@@ -240,4 +240,62 @@ fun bf :: "bool \<Rightarrow> bool \<Rightarrow> bool" where
fun nf :: "int \<Rightarrow> int \<Rightarrow> bool" where
"nf n1 n2 = (n1 - n2 \<ge> n1 \<or> n2 \<le> n1 + n2)"

(* Nested record patterns *)
record 'MessageType MessagePacket =
target :: nat
mailbox :: "nat option"
message :: 'MessageType

record 'MessageType EnvelopedMessage =
sender :: "nat option"
packet :: "'MessageType MessagePacket"

record 'HandleType Timer =
time :: nat
handle :: 'HandleType

datatype ('MessageType, 'HandleType) Trigger
= MessageArrived "'MessageType EnvelopedMessage" |
Elapsed "('HandleType Timer) list"

fun target :: "'MessageType MessagePacket \<Rightarrow> nat" where
"target (| MessagePacket.target = target', MessagePacket.mailbox = mailbox', MessagePacket.message = message' |) =
target'"

fun mailbox :: "'MessageType MessagePacket \<Rightarrow> nat option" where
"mailbox (| MessagePacket.target = target', MessagePacket.mailbox = mailbox', MessagePacket.message = message' |) =
mailbox'"

fun message :: "'MessageType MessagePacket \<Rightarrow> 'MessageType" where
"message (| MessagePacket.target = target', MessagePacket.mailbox = mailbox', MessagePacket.message = message' |) =
message'"

fun sender :: "'MessageType EnvelopedMessage \<Rightarrow> nat option" where
"sender (| EnvelopedMessage.sender = sender', EnvelopedMessage.packet = packet' |) = sender'"

fun packet :: "'MessageType EnvelopedMessage \<Rightarrow> 'MessageType MessagePacket" where
"packet (| EnvelopedMessage.sender = sender', EnvelopedMessage.packet = packet' |) = packet'"

fun time :: "'HandleType Timer \<Rightarrow> nat" where
"time (| Timer.time = time', Timer.handle = handle' |) = time'"

fun handle :: "'HandleType Timer \<Rightarrow> 'HandleType" where
"handle (| Timer.time = time', Timer.handle = handle' |) = handle'"

fun getMessageFromTrigger :: "('M, 'H) Trigger \<Rightarrow> 'M option" where
"getMessageFromTrigger v_0 =
(case (v_0) of
(MessageArrived v') \<Rightarrow>
(case (EnvelopedMessage.packet v') of
(v'0) \<Rightarrow> Some (MessagePacket.message v'0)) |
v'1 \<Rightarrow> None)"

fun getMessageFromTrigger' :: "('M, 'H) Trigger \<Rightarrow> 'M option" where
"getMessageFromTrigger' t =
(case t of
(MessageArrived v') \<Rightarrow>
(case (EnvelopedMessage.packet v') of
(v'0) \<Rightarrow> Some (MessagePacket.message v'0)) |
v'2 \<Rightarrow> None)"

end

0 comments on commit 42bd0a8

Please sign in to comment.