Skip to content

Commit

Permalink
Allow punning in record updates (#3125)
Browse files Browse the repository at this point in the history
Now it is allowed to use field puns in record updates. E.g.
```
type R :=
  mkR@{
    a : Nat;
    b : Nat;
    c : Nat;
  };

example : R :=
  let
    z :=
      mkR@{
        a := 0;
        b := 0;
        c := 0;
      };
    a := 6;
  in z@R{a} -- the field `a` is updated to 6
```
  • Loading branch information
janmasrovira authored Nov 29, 2024
1 parent bf06cb1 commit f343096
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 62 deletions.
102 changes: 82 additions & 20 deletions src/Juvix/Compiler/Concrete/Language/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -814,12 +814,9 @@ deriving stock instance Ord (ConstructorDef 'Parsed)

deriving stock instance Ord (ConstructorDef 'Scoped)

data RecordUpdateField (s :: Stage) = RecordUpdateField
{ _fieldUpdateName :: Symbol,
_fieldUpdateArgIx :: FieldArgIxType s,
_fieldUpdateAssignKw :: Irrelevant (KeywordRef),
_fieldUpdateValue :: ExpressionType s
}
data RecordUpdateField (s :: Stage)
= RecordUpdateFieldAssign (RecordUpdateFieldItemAssign s)
| RecordUpdateFieldPun (RecordUpdatePun s)
deriving stock (Generic)

instance Serialize (RecordUpdateField 'Scoped)
Expand All @@ -842,6 +839,34 @@ deriving stock instance Ord (RecordUpdateField 'Parsed)

deriving stock instance Ord (RecordUpdateField 'Scoped)

data RecordUpdateFieldItemAssign (s :: Stage) = RecordUpdateFieldItemAssign
{ _fieldUpdateName :: Symbol,
_fieldUpdateArgIx :: FieldArgIxType s,
_fieldUpdateAssignKw :: Irrelevant (KeywordRef),
_fieldUpdateValue :: ExpressionType s
}
deriving stock (Generic)

instance Serialize (RecordUpdateFieldItemAssign 'Scoped)

instance NFData (RecordUpdateFieldItemAssign 'Scoped)

instance Serialize (RecordUpdateFieldItemAssign 'Parsed)

instance NFData (RecordUpdateFieldItemAssign 'Parsed)

deriving stock instance Show (RecordUpdateFieldItemAssign 'Parsed)

deriving stock instance Show (RecordUpdateFieldItemAssign 'Scoped)

deriving stock instance Eq (RecordUpdateFieldItemAssign 'Parsed)

deriving stock instance Eq (RecordUpdateFieldItemAssign 'Scoped)

deriving stock instance Ord (RecordUpdateFieldItemAssign 'Parsed)

deriving stock instance Ord (RecordUpdateFieldItemAssign 'Scoped)

data RecordField (s :: Stage) = RecordField
{ _fieldName :: SymbolType s,
_fieldIsImplicit :: IsImplicitField,
Expand Down Expand Up @@ -1161,34 +1186,34 @@ deriving stock instance Ord (RecordPatternAssign 'Parsed)

deriving stock instance Ord (RecordPatternAssign 'Scoped)

data FieldPun (s :: Stage) = FieldPun
data PatternFieldPun (s :: Stage) = PatternFieldPun
{ _fieldPunIx :: FieldArgIxType s,
_fieldPunField :: SymbolType s
}
deriving stock (Generic)

instance Serialize (FieldPun 'Scoped)
instance Serialize (PatternFieldPun 'Scoped)

instance NFData (FieldPun 'Scoped)
instance NFData (PatternFieldPun 'Scoped)

instance Serialize (FieldPun 'Parsed)
instance Serialize (PatternFieldPun 'Parsed)

instance NFData (FieldPun 'Parsed)
instance NFData (PatternFieldPun 'Parsed)

deriving stock instance Show (FieldPun 'Parsed)
deriving stock instance Show (PatternFieldPun 'Parsed)

deriving stock instance Show (FieldPun 'Scoped)
deriving stock instance Show (PatternFieldPun 'Scoped)

deriving stock instance Eq (FieldPun 'Parsed)
deriving stock instance Eq (PatternFieldPun 'Parsed)

deriving stock instance Eq (FieldPun 'Scoped)
deriving stock instance Eq (PatternFieldPun 'Scoped)

deriving stock instance Ord (FieldPun 'Parsed)
deriving stock instance Ord (PatternFieldPun 'Parsed)

deriving stock instance Ord (FieldPun 'Scoped)
deriving stock instance Ord (PatternFieldPun 'Scoped)

data RecordPatternItem (s :: Stage)
= RecordPatternItemFieldPun (FieldPun s)
= RecordPatternItemFieldPun (PatternFieldPun s)
| RecordPatternItemAssign (RecordPatternAssign s)
deriving stock (Generic)

Expand Down Expand Up @@ -2429,6 +2454,33 @@ deriving stock instance Ord (NamedArgumentFunctionDef 'Parsed)

deriving stock instance Ord (NamedArgumentFunctionDef 'Scoped)

data RecordUpdatePun (s :: Stage) = RecordUpdatePun
{ _recordUpdatePunSymbol :: Symbol,
_recordUpdatePunReferencedSymbol :: PunSymbolType s,
_recordUpdatePunFieldIndex :: FieldArgIxType s
}
deriving stock (Generic)

instance Serialize (RecordUpdatePun 'Scoped)

instance NFData (RecordUpdatePun 'Scoped)

instance Serialize (RecordUpdatePun 'Parsed)

instance NFData (RecordUpdatePun 'Parsed)

deriving stock instance Show (RecordUpdatePun 'Parsed)

deriving stock instance Show (RecordUpdatePun 'Scoped)

deriving stock instance Eq (RecordUpdatePun 'Parsed)

deriving stock instance Eq (RecordUpdatePun 'Scoped)

deriving stock instance Ord (RecordUpdatePun 'Parsed)

deriving stock instance Ord (RecordUpdatePun 'Scoped)

data NamedArgumentPun (s :: Stage) = NamedArgumentPun
{ _namedArgumentPunSymbol :: Symbol,
_namedArgumentReferencedSymbol :: PunSymbolType s
Expand Down Expand Up @@ -2910,6 +2962,8 @@ deriving stock instance Ord (FunctionLhs 'Parsed)
deriving stock instance Ord (FunctionLhs 'Scoped)

makeLenses ''SideIfs
makeLenses ''RecordUpdatePun
makeLenses ''RecordUpdateFieldItemAssign
makeLenses ''FunctionDefNameScoped
makeLenses ''TypeSig
makeLenses ''FunctionLhs
Expand All @@ -2922,7 +2976,7 @@ makeLenses ''RhsExpression
makeLenses ''PatternArg
makeLenses ''WildcardConstructor
makeLenses ''DoubleBracesExpression
makeLenses ''FieldPun
makeLenses ''PatternFieldPun
makeLenses ''RecordPatternAssign
makeLenses ''RecordPattern
makeLenses ''ParensRecordUpdate
Expand Down Expand Up @@ -3328,13 +3382,21 @@ instance (SingI s) => HasLoc (NamedArgumentNew s) where
NamedArgumentNewFunction f -> getLoc f
NamedArgumentItemPun f -> getLoc f

instance HasLoc (RecordUpdatePun s) where
getLoc RecordUpdatePun {..} = getLocSymbolType _recordUpdatePunSymbol

instance HasLoc (NamedArgumentPun s) where
getLoc NamedArgumentPun {..} = getLocSymbolType _namedArgumentPunSymbol

instance (SingI s) => HasLoc (NamedApplicationNew s) where
getLoc NamedApplicationNew {..} = getLocIdentifierType _namedApplicationNewName

instance (SingI s) => HasLoc (RecordUpdateField s) where
getLoc = \case
RecordUpdateFieldAssign a -> getLoc a
RecordUpdateFieldPun a -> getLoc a

instance (SingI s) => HasLoc (RecordUpdateFieldItemAssign s) where
getLoc f = getLocSymbolType (f ^. fieldUpdateName) <> getLocExpressionType (f ^. fieldUpdateValue)

instance HasLoc (RecordUpdate s) where
Expand Down Expand Up @@ -3514,7 +3576,7 @@ instance (SingI s) => HasLoc (RecordPatternAssign s) where
getLoc (a ^. recordPatternAssignField)
<> getLocPatternParensType (a ^. recordPatternAssignPattern)

instance (SingI s) => HasLoc (FieldPun s) where
instance (SingI s) => HasLoc (PatternFieldPun s) where
getLoc f = getLocSymbolType (f ^. fieldPunField)

instance (SingI s) => HasLoc (RecordPatternItem s) where
Expand Down
14 changes: 11 additions & 3 deletions src/Juvix/Compiler/Concrete/Print/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,9 @@ instance (SingI s) => PrettyPrint (NamedApplicationNew s) where
instance (SingI s) => PrettyPrint (NamedArgumentFunctionDef s) where
ppCode (NamedArgumentFunctionDef f) = ppCode f

instance PrettyPrint (RecordUpdatePun s) where
ppCode = ppCode . (^. recordUpdatePunSymbol)

instance PrettyPrint (NamedArgumentPun s) where
ppCode = ppCode . (^. namedArgumentPunSymbol)

Expand All @@ -384,10 +387,15 @@ instance (SingI s) => PrettyPrint (RecordStatement s) where
RecordStatementField f -> ppCode f
RecordStatementSyntax f -> ppCode f

instance (SingI s) => PrettyPrint (RecordUpdateField s) where
ppCode RecordUpdateField {..} =
instance (SingI s) => PrettyPrint (RecordUpdateFieldItemAssign s) where
ppCode RecordUpdateFieldItemAssign {..} =
ppSymbolType _fieldUpdateName <+> ppCode _fieldUpdateAssignKw <+> ppExpressionType _fieldUpdateValue

instance (SingI s) => PrettyPrint (RecordUpdateField s) where
ppCode = \case
RecordUpdateFieldAssign a -> ppCode a
RecordUpdateFieldPun a -> ppCode a

instance (SingI s) => PrettyPrint (RecordUpdate s) where
ppCode RecordUpdate {..} = do
let Irrelevant (l, r) = _recordUpdateDelims
Expand Down Expand Up @@ -1203,7 +1211,7 @@ instance (SingI s) => PrettyPrint (FunctionDef s) where
instance PrettyPrint Wildcard where
ppCode w = morpheme (getLoc w) C.kwWildcard

instance (SingI s) => PrettyPrint (FieldPun s) where
instance (SingI s) => PrettyPrint (PatternFieldPun s) where
ppCode = ppSymbolType . (^. fieldPunField)

instance (SingI s) => PrettyPrint (RecordPatternAssign s) where
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ reservePatternFunctionSymbols = goAtom

goRecordItem :: RecordPatternItem 'Parsed -> Sem r ()
goRecordItem = \case
RecordPatternItemFieldPun FieldPun {..} -> do
RecordPatternItemFieldPun PatternFieldPun {..} -> do
void (reservePatternName (NameUnqualified _fieldPunField))
RecordPatternItemAssign RecordPatternAssign {..} -> do
goAtoms _recordPatternAssignPattern
Expand Down Expand Up @@ -2355,17 +2355,9 @@ checkRecordPattern r = do
RecordPatternItemAssign a -> RecordPatternItemAssign <$> checkAssign a
RecordPatternItemFieldPun a -> RecordPatternItemFieldPun <$> checkPun a
where
findField :: Symbol -> Sem r' Int
findField f =
fromMaybeM (throw err) $
asks @(RecordNameSignature 'Parsed) (^? recordNames . at f . _Just . nameItemIndex)
where
err :: ScoperError
err = ErrUnexpectedField (UnexpectedField f)

checkAssign :: RecordPatternAssign 'Parsed -> Sem r' (RecordPatternAssign 'Scoped)
checkAssign RecordPatternAssign {..} = do
idx' <- findField _recordPatternAssignField
idx' <- findRecordFieldIdx _recordPatternAssignField
pat' <- checkParsePatternAtoms _recordPatternAssignPattern
return
RecordPatternAssign
Expand All @@ -2374,21 +2366,33 @@ checkRecordPattern r = do
..
}

checkPun :: FieldPun 'Parsed -> Sem r' (FieldPun 'Scoped)
checkPun :: PatternFieldPun 'Parsed -> Sem r' (PatternFieldPun 'Scoped)
checkPun f = do
idx' <- findField (f ^. fieldPunField)
idx' <- findRecordFieldIdx (f ^. fieldPunField)
pk <- ask
f' <- case pk of
PatternNamesKindVariables ->
bindVariableSymbol (f ^. fieldPunField)
PatternNamesKindFunctions -> do
getReservedDefinitionSymbol (f ^. fieldPunField)
return
FieldPun
PatternFieldPun
{ _fieldPunIx = idx',
_fieldPunField = f'
}

findRecordFieldIdx ::
forall r.
(Members '[Reader (RecordNameSignature 'Parsed), Error ScoperError] r) =>
Symbol ->
Sem r Int
findRecordFieldIdx f =
fromMaybeM (throw err) $
asks @(RecordNameSignature 'Parsed) (^? recordNames . at f . _Just . nameItemIndex)
where
err :: ScoperError
err = ErrUnexpectedField (UnexpectedField f)

checkListPattern ::
forall r.
(Members '[Reader PatternNamesKind, Error ScoperError, State Scope, State ScoperState, State ScoperSyntax, Reader BindingStrategy, InfoTableBuilder, Reader InfoTable, NameIdGen] r) =>
Expand Down Expand Up @@ -2914,7 +2918,7 @@ checkNamedApplicationNew napp = do
. each
. nameBlockSymbols
forM_ nargs (checkNameInSignature namesInSignature . (^. namedArgumentNewSymbol))
puns <- scopePuns
puns <- scopePuns (napp ^.. namedApplicationNewArguments . each . _NamedArgumentItemPun)
args' <- withLocalScope . localBindings . ignoreSyntax $ do
mapM_ reserveNamedArgumentName nargs
mapM (checkNamedArgumentNew puns) nargs
Expand All @@ -2939,12 +2943,8 @@ checkNamedApplicationNew napp = do
unless (HashSet.member fname namesInSig) $
throw (ErrUnexpectedArgument (UnexpectedArgument fname))

scopePuns :: Sem r (HashMap Symbol ScopedIden)
scopePuns =
hashMap
<$> mapWithM
scopePun
(napp ^.. namedApplicationNewArguments . each . _NamedArgumentItemPun . namedArgumentPunSymbol)
scopePuns :: [NamedArgumentPun s] -> Sem r (HashMap Symbol ScopedIden)
scopePuns puns = hashMap <$> mapWithM scopePun (puns ^.. each . namedArgumentPunSymbol)
where
scopePun :: Symbol -> Sem r ScopedIden
scopePun = checkScopedIden . NameUnqualified
Expand Down Expand Up @@ -2986,7 +2986,7 @@ checkRecordUpdate RecordUpdate {..} = do
let sig = info ^. recordInfoSignature
(vars' :: IntMap (IsImplicit, S.Symbol), fields') <- withLocalScope $ do
vs <- mapM bindRecordUpdateVariable (P.recordNameSignatureByIndex sig)
fs <- mapM (checkUpdateField sig) _recordUpdateFields
fs <- runReader sig (mapM checkUpdateField _recordUpdateFields)
return (vs, fs)
let extra' =
RecordUpdateExtra
Expand All @@ -3009,23 +3009,51 @@ checkRecordUpdate RecordUpdate {..} = do
return (_nameItemImplicit, v)

checkUpdateField ::
(Members '[HighlightBuilder, Error ScoperError, State Scope, State ScoperState, Reader ScopeParameters, InfoTableBuilder, Reader InfoTable, NameIdGen, Reader PackageId] r) =>
RecordNameSignature 'Parsed ->
forall r.
(Members '[HighlightBuilder, Error ScoperError, State Scope, State ScoperState, Reader ScopeParameters, InfoTableBuilder, Reader InfoTable, NameIdGen, Reader PackageId, Reader (RecordNameSignature 'Parsed)] r) =>
RecordUpdateField 'Parsed ->
Sem r (RecordUpdateField 'Scoped)
checkUpdateField sig f = do
checkUpdateField = \case
RecordUpdateFieldAssign a -> RecordUpdateFieldAssign <$> checkUpdateFieldAssign a
RecordUpdateFieldPun a -> RecordUpdateFieldPun <$> checkRecordPun a
where
checkRecordPun :: RecordUpdatePun 'Parsed -> Sem r (RecordUpdatePun 'Scoped)
checkRecordPun RecordUpdatePun {..} = do
idx <- findRecordFieldIdx _recordUpdatePunSymbol
s <- checkScopedIden (NameUnqualified _recordUpdatePunSymbol)
return
RecordUpdatePun
{ _recordUpdatePunSymbol,
_recordUpdatePunReferencedSymbol = s,
_recordUpdatePunFieldIndex = idx
}

getUpdateFieldIdx ::
(Member (Error ScoperError) r) =>
RecordNameSignature s2 ->
RecordUpdateFieldItemAssign s ->
Sem r Int
getUpdateFieldIdx sig f =
maybe (throw unexpectedField) return (sig ^? recordNames . at (f ^. fieldUpdateName) . _Just . nameItemIndex)
where
unexpectedField :: ScoperError
unexpectedField = ErrUnexpectedField (UnexpectedField (f ^. fieldUpdateName))

checkUpdateFieldAssign ::
(Members '[Reader (RecordNameSignature 'Parsed), HighlightBuilder, Error ScoperError, State Scope, State ScoperState, Reader ScopeParameters, InfoTableBuilder, Reader InfoTable, NameIdGen, Reader PackageId] r) =>
RecordUpdateFieldItemAssign 'Parsed ->
Sem r (RecordUpdateFieldItemAssign 'Scoped)
checkUpdateFieldAssign f = do
sig <- ask @(RecordNameSignature 'Parsed)
value' <- checkParseExpressionAtoms (f ^. fieldUpdateValue)
idx' <- maybe (throw unexpectedField) return (sig ^? recordNames . at (f ^. fieldUpdateName) . _Just . nameItemIndex)
idx' <- getUpdateFieldIdx sig f
return
RecordUpdateField
RecordUpdateFieldItemAssign
{ _fieldUpdateName = f ^. fieldUpdateName,
_fieldUpdateArgIx = idx',
_fieldUpdateAssignKw = f ^. fieldUpdateAssignKw,
_fieldUpdateValue = value'
}
where
unexpectedField :: ScoperError
unexpectedField = ErrUnexpectedField (UnexpectedField (f ^. fieldUpdateName))

getRecordInfo ::
forall r.
Expand Down
Loading

0 comments on commit f343096

Please sign in to comment.