Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow pattern-matching in variable definitions #3181

Merged
merged 15 commits into from
Nov 25, 2024
2 changes: 1 addition & 1 deletion juvix-stdlib
Submodule juvix-stdlib updated 1 files
+13 −3 index.juvix
4 changes: 2 additions & 2 deletions src/Juvix/Compiler/Backend/Html/Translation/FromTyped.hs
Original file line number Diff line number Diff line change
Expand Up @@ -541,12 +541,12 @@ goAxiom axiom = do
goDeriving :: forall r. (Members '[Reader HtmlOptions] r) => Deriving 'Scoped -> Sem r Html
goDeriving def = do
sig <- ppHelper (ppCode def)
defHeader (def ^. derivingFunLhs . funLhsName) sig Nothing
defHeader (def ^. derivingFunLhs . funLhsName . functionDefName) sig Nothing

goFunctionDef :: forall r. (Members '[Reader HtmlOptions] r) => FunctionDef 'Scoped -> Sem r Html
goFunctionDef def = do
sig <- ppHelper (ppCode (functionDefLhs def))
defHeader (def ^. signName) sig (def ^. signDoc)
defHeader (def ^. signName . functionDefName) sig (def ^. signDoc)

goInductive :: forall r. (Members '[Reader HtmlOptions] r) => InductiveDef 'Scoped -> Sem r Html
goInductive def = do
Expand Down
8 changes: 4 additions & 4 deletions src/Juvix/Compiler/Concrete/Data/InfoTableBuilder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ runInfoTableBuilder ini = reinterpret (runState ini) $ \case
in do
modify' (over infoInductives (HashMap.insert (ity ^. inductiveName . nameId) ity))
highlightDoc (ity ^. inductiveName . nameId) j
RegisterFunctionDef f ->
RegisterFunctionDef f -> do
let j = f ^. signDoc
in do
modify' (over infoFunctions (HashMap.insert (f ^. signName . nameId) f))
highlightDoc (f ^. signName . nameId) j
fid = f ^. signName . functionDefName . nameId
modify' (over infoFunctions (HashMap.insert fid f))
highlightDoc fid j
RegisterName n -> highlightName (S.anameFromName n)
RegisterScopedIden n -> highlightName (anameFromScopedIden n)
RegisterModuleDoc uid doc -> highlightDoc uid doc
Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Concrete/Extra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ groupStatements = \case
definesSymbol n s = case s of
StatementInductive d -> n `elem` syms d
StatementAxiom d -> n == symbolParsed (d ^. axiomName)
StatementFunctionDef d -> n == symbolParsed (d ^. signName)
StatementFunctionDef d -> withFunctionSymbol False (\n' -> n == symbolParsed n') (d ^. signName)
_ -> False
where
syms :: InductiveDef s -> [Symbol]
Expand Down
17 changes: 13 additions & 4 deletions src/Juvix/Compiler/Concrete/Gen.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ simplestFunctionDefParsed funNameTxt funBody = do
funName <- symbol funNameTxt
return (simplestFunctionDef funName (mkExpressionAtoms funBody))

simplestFunctionDef :: FunctionName s -> ExpressionType s -> FunctionDef s
simplestFunctionDef :: forall s. (SingI s) => FunctionName s -> ExpressionType s -> FunctionDef s
simplestFunctionDef funName funBody =
FunctionDef
{ _signName = funName,
{ _signName = name,
_signBody = SigBodyExpression funBody,
_signTypeSig =
TypeSig
Expand All @@ -42,6 +42,15 @@ simplestFunctionDef funName funBody =
_signInstance = Nothing,
_signCoercion = Nothing
}
where
name :: FunctionSymbolType s
name = case sing :: SStage s of
SParsed -> FunctionDefName funName
SScoped ->
FunctionDefNameScoped
{ _functionDefName = funName,
_functionDefNamePattern = Nothing
}

smallUniverseExpression :: forall s r. (SingI s) => (Members '[Reader Interval] r) => Sem r (ExpressionType s)
smallUniverseExpression = do
Expand Down Expand Up @@ -284,7 +293,7 @@ mkTypeSigType ts = do

mkTypeSigType' :: forall s. (SingI s) => ExpressionType s -> TypeSig s -> (ExpressionType s)
mkTypeSigType' wildcard TypeSig {..} =
foldr mkFun rty (map mkFunctionParameters _typeSigArgs)
foldr (mkFun . mkFunctionParameters) rty _typeSigArgs
where
rty = fromMaybe wildcard _typeSigRetType

Expand All @@ -297,7 +306,7 @@ mkTypeSigType' wildcard TypeSig {..} =
{ _paramNames = getSigArgNames arg,
_paramImplicit = _sigArgImplicit,
_paramDelims = fmap Just _sigArgDelims,
_paramColon = Irrelevant $ maybe Nothing (Just . (^. unIrrelevant)) _sigArgColon,
_paramColon = Irrelevant $ fmap (^. unIrrelevant) _sigArgColon,
_paramType = fromMaybe (univ (getLoc arg)) _sigArgType
}

Expand Down
4 changes: 2 additions & 2 deletions src/Juvix/Compiler/Concrete/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ statementLabel = \case
StatementSyntax s -> goSyntax s
StatementOpenModule {} -> Nothing
StatementProjectionDef {} -> Nothing
StatementFunctionDef f -> Just (f ^. signName . symbolTypeLabel)
StatementDeriving f -> Just (f ^. derivingFunLhs . funLhsName . symbolTypeLabel)
StatementFunctionDef f -> withFunctionSymbol Nothing (Just . (^. symbolTypeLabel)) (f ^. signName)
StatementDeriving f -> withFunctionSymbol Nothing (Just . (^. symbolTypeLabel)) (f ^. derivingFunLhs . funLhsName)
StatementImport i -> Just (i ^. importModulePath . to modulePathTypeLabel)
StatementInductive i -> Just (i ^. inductiveName . symbolTypeLabel)
StatementModule i -> Just (i ^. modulePath . to modulePathTypeLabel)
Expand Down
97 changes: 80 additions & 17 deletions src/Juvix/Compiler/Concrete/Language/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ type family SymbolType s = res | res -> s where
SymbolType 'Parsed = Symbol
SymbolType 'Scoped = S.Symbol

type FunctionSymbolType :: Stage -> GHCType
type family FunctionSymbolType s = res | res -> s where
FunctionSymbolType 'Parsed = FunctionDefNameParsed
FunctionSymbolType 'Scoped = FunctionDefNameScoped

type IdentifierType :: Stage -> GHCType
type family IdentifierType s = res | res -> s where
IdentifierType 'Parsed = Name
Expand Down Expand Up @@ -701,8 +706,27 @@ deriving stock instance Ord (Deriving 'Parsed)

deriving stock instance Ord (Deriving 'Scoped)

data FunctionDefNameParsed
= FunctionDefNamePattern (PatternAtom 'Parsed)
| FunctionDefName Symbol
deriving stock (Eq, Ord, Show, Generic)

instance Serialize FunctionDefNameParsed

instance NFData FunctionDefNameParsed

data FunctionDefNameScoped = FunctionDefNameScoped
{ _functionDefName :: S.Symbol,
_functionDefNamePattern :: Maybe PatternArg
}
deriving stock (Eq, Ord, Show, Generic)

instance Serialize FunctionDefNameScoped

instance NFData FunctionDefNameScoped

data FunctionDef (s :: Stage) = FunctionDef
{ _signName :: FunctionName s,
{ _signName :: FunctionSymbolType s,
_signTypeSig :: TypeSig s,
_signDoc :: Maybe (Judoc s),
_signPragmas :: Maybe ParsedPragmas,
Expand Down Expand Up @@ -2860,7 +2884,7 @@ data FunctionLhs (s :: Stage) = FunctionLhs
_funLhsTerminating :: Maybe KeywordRef,
_funLhsInstance :: Maybe KeywordRef,
_funLhsCoercion :: Maybe KeywordRef,
_funLhsName :: FunctionName s,
_funLhsName :: FunctionSymbolType s,
_funLhsTypeSig :: TypeSig s
}
deriving stock (Generic)
Expand All @@ -2886,6 +2910,7 @@ deriving stock instance Ord (FunctionLhs 'Parsed)
deriving stock instance Ord (FunctionLhs 'Scoped)

makeLenses ''SideIfs
makeLenses ''FunctionDefNameScoped
makeLenses ''TypeSig
makeLenses ''FunctionLhs
makeLenses ''Statements
Expand Down Expand Up @@ -2975,6 +3000,7 @@ makeLenses ''MarkdownInfo
makeLenses ''Deriving

makePrisms ''NamedArgumentNew
makePrisms ''FunctionDefNameParsed

functionDefLhs :: FunctionDef s -> FunctionLhs s
functionDefLhs FunctionDef {..} =
Expand Down Expand Up @@ -3146,6 +3172,29 @@ instance HasLoc (OpenModule s short) where
instance HasLoc (ProjectionDef s) where
getLoc = getLoc . (^. projectionConstructor)

getLocFunctionSymbolType :: forall s. (SingI s) => FunctionSymbolType s -> Interval
getLocFunctionSymbolType = case sing :: SStage s of
SParsed -> getLoc
SScoped -> getLoc

instance HasLoc FunctionDefNameScoped where
getLoc FunctionDefNameScoped {..} =
getLoc _functionDefName
<>? (getLoc <$> _functionDefNamePattern)

instance HasLoc FunctionDefNameParsed where
getLoc = \case
FunctionDefNamePattern a -> getLoc a
FunctionDefName s -> getLoc s

instance (SingI s) => HasLoc (FunctionLhs s) where
getLoc FunctionLhs {..} =
(getLoc <$> _funLhsBuiltin)
?<> (getLoc <$> _funLhsTerminating)
?<> ( getLocFunctionSymbolType _funLhsName
<>? (getLocExpressionType <$> _funLhsTypeSig ^. typeSigRetType)
)

instance (SingI s) => HasLoc (Deriving s) where
getLoc Deriving {..} =
getLoc _derivingKw
Expand Down Expand Up @@ -3382,22 +3431,14 @@ instance (SingI s) => HasLoc (FunctionDefBody s) where
SigBodyExpression e -> getLocExpressionType e
SigBodyClauses cl -> getLocSpan cl

instance (SingI s) => HasLoc (FunctionLhs s) where
getLoc FunctionLhs {..} =
(getLoc <$> _funLhsBuiltin)
?<> (getLoc <$> _funLhsTerminating)
?<> ( getLocSymbolType _funLhsName
<>? (getLocExpressionType <$> _funLhsTypeSig ^. typeSigRetType)
)

instance (SingI s) => HasLoc (FunctionDef s) where
getLoc FunctionDef {..} =
(getLoc <$> _signDoc)
?<> (getLoc <$> _signPragmas)
?<> (getLoc <$> _signBuiltin)
?<> (getLoc <$> _signTerminating)
?<> getLocSymbolType _signName
<> (getLoc _signBody)
?<> (getLocFunctionSymbolType _signName)
<> getLoc _signBody

instance HasLoc (Example s) where
getLoc e = e ^. exampleLoc
Expand Down Expand Up @@ -3433,6 +3474,11 @@ getLocPatternParensType = case sing :: SStage s of
SScoped -> getLoc
SParsed -> getLoc

getLocPatternAtomType :: forall s. (SingI s) => PatternAtomType s -> Interval
getLocPatternAtomType = case sing :: SStage s of
SScoped -> getLoc
SParsed -> getLoc

instance (SingI s) => HasLoc (RecordPatternAssign s) where
getLoc a =
getLoc (a ^. recordPatternAssignField)
Expand Down Expand Up @@ -3581,17 +3627,34 @@ symbolParsed sym = case sing :: SStage s of
SParsed -> sym
SScoped -> sym ^. S.nameConcrete

getFunctionSymbol :: forall s. (SingI s) => FunctionSymbolType s -> SymbolType s
getFunctionSymbol sym = case sing :: SStage s of
SParsed -> case sym of
FunctionDefName p -> p
FunctionDefNamePattern {} -> impossibleError "invalid call"
SScoped -> sym ^. functionDefName

functionSymbolPattern :: forall s. (SingI s) => FunctionSymbolType s -> Maybe (PatternAtomType s)
functionSymbolPattern f = case sing :: SStage s of
SParsed -> f ^? _FunctionDefNamePattern
SScoped -> f ^. functionDefNamePattern

withFunctionSymbol :: forall s a. (SingI s) => a -> (SymbolType s -> a) -> FunctionSymbolType s -> a
withFunctionSymbol a f sym = case sing :: SStage s of
SParsed -> maybe a f (sym ^? _FunctionDefName)
SScoped -> f (sym ^. functionDefName)

namedArgumentNewSymbolParsed :: (SingI s) => SimpleGetter (NamedArgumentNew s) Symbol
namedArgumentNewSymbolParsed = to $ \case
NamedArgumentItemPun a -> a ^. namedArgumentPunSymbol
NamedArgumentNewFunction a -> symbolParsed (a ^. namedArgumentFunctionDef . signName)
NamedArgumentNewFunction a -> symbolParsed (getFunctionSymbol (a ^. namedArgumentFunctionDef . signName))

namedArgumentNewSymbol :: Lens' (NamedArgumentNew 'Parsed) Symbol
namedArgumentNewSymbol f = \case
NamedArgumentItemPun a -> NamedArgumentItemPun <$> namedArgumentPunSymbol f a
NamedArgumentNewFunction a ->
NamedArgumentNewFunction
<$> (namedArgumentFunctionDef . signName) f a
NamedArgumentItemPun a -> NamedArgumentItemPun <$> (namedArgumentPunSymbol f a)
NamedArgumentNewFunction a -> do
a' <- f (a ^?! namedArgumentFunctionDef . signName . _FunctionDefName)
return $ NamedArgumentNewFunction (over namedArgumentFunctionDef (set signName (FunctionDefName a')) a)

scopedIdenSrcName :: Lens' ScopedIden S.Name
scopedIdenSrcName f n = case n ^. scopedIdenAlias of
Expand Down
5 changes: 4 additions & 1 deletion src/Juvix/Compiler/Concrete/Print/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,10 @@ instance (SingI s) => PrettyPrint (FunctionLhs s) where
coercion' = (<> if isJust instance' then space else line) . ppCode <$> _funLhsCoercion
instance' = (<> line) . ppCode <$> _funLhsInstance
builtin' = (<> line) . ppCode <$> _funLhsBuiltin
name' = annDef _funLhsName (ppSymbolType _funLhsName)
mpat :: Maybe (PatternAtomType s) = functionSymbolPattern _funLhsName
name' = case mpat of
Just pat -> withFunctionSymbol id annDef _funLhsName (ppPatternAtomType pat)
Nothing -> annDef (getFunctionSymbol _funLhsName) (ppSymbolType (getFunctionSymbol _funLhsName))
sig' = ppCode _funLhsTypeSig
builtin'
?<> termin'
Expand Down
Loading
Loading