diff --git a/pyk/src/pyk/k2lean4/model.py b/pyk/src/pyk/k2lean4/model.py index c53e4c2a53..704e05315f 100644 --- a/pyk/src/pyk/k2lean4/model.py +++ b/pyk/src/pyk/k2lean4/model.py @@ -7,6 +7,9 @@ if TYPE_CHECKING: from collections.abc import Iterable + from typing import Final + +_LEAN_WORDS: Final = {'ite', 'end', 'where'} # Words that cannot be a the name of a declaration def indent(text: str, n: int) -> str: @@ -17,6 +20,16 @@ def indent(text: str, n: int) -> str: return '\n'.join(res) +def mask_name(name: str, mask: str | None) -> str: + """Append `mask` to `name` if `name` is in `_LEAN_KEYWORDS`.""" + if name == '': + return '' + elif name not in _LEAN_WORDS: + return name + else: + return f'{name}{mask}' if mask is not None and mask != '' else f'{name}Mask' + + @final @dataclass(frozen=True) class Module: @@ -64,7 +77,8 @@ def __init__(self, ident: str | DeclId, signature: Signature, modifiers: Modifie def __str__(self) -> str: modifiers = f'{self.modifiers} ' if self.modifiers else '' - return f'{modifiers}axiom {self.ident} {self.signature}' + ident = mask_name(f'{self.ident}', 'Ax') + return f'{modifiers}axiom {ident} {self.signature}' @final @@ -90,8 +104,9 @@ def __init__( def __str__(self) -> str: modifiers = f'{self.modifiers} ' if self.modifiers else '' + ident = mask_name(f'{self.ident}', 'Abbr') signature = f' {self.signature}' if self.signature else '' - return f'{modifiers}abbrev {self.ident}{signature} := {self.val}' + return f'{modifiers}abbrev {ident}{signature} := {self.val}' @final @@ -122,12 +137,13 @@ def __init__( def __str__(self) -> str: modifiers = f'{self.modifiers} ' if self.modifiers else '' + ident = mask_name(f'{self.ident}', 'Ind') signature = f' {self.signature}' if self.signature else '' where = ' where' if self.ctors else '' deriving = ', '.join(self.deriving) lines = [] - lines.append(f'{modifiers}inductive {self.ident}{signature}{where}') + lines.append(f'{modifiers}inductive {ident}{signature}{where}') for ctor in self.ctors: lines.append(f' | {ctor}') if deriving: @@ -171,7 +187,7 @@ def __str__(self) -> str: modifiers = f'{self.modifiers} ' if self.modifiers else '' attr_kind = f'{self.attr_kind.value} ' if self.attr_kind else '' priority = f' (priority := {self.priority})' if self.priority is not None else '' - ident = f' {self.ident}' if self.ident else '' + ident = f' {mask_name(str(self.ident), "Inst")}' if self.ident else '' signature = f' {self.signature}' if self.signature else '' decl = f'{modifiers}{attr_kind}instance{priority}{ident}{signature}' @@ -311,6 +327,7 @@ def __str__(self) -> str: lines = [] modifiers = f'{self.modifiers} ' if self.modifiers else '' + ident = mask_name(str(self.ident), 'Struct') binders = ( ' '.join(str(binder) for binder in self.signature.binders) if self.signature and self.signature.binders @@ -321,7 +338,7 @@ def __str__(self) -> str: extends = f' extends {extends}' if extends else '' ty = f' : {self.signature.ty}' if self.signature and self.signature.ty else '' where = ' where' if self.ctor else '' - lines.append(f'{modifiers}structure {self.ident}{binders}{extends}{ty}{where}') + lines.append(f'{modifiers}structure {ident}{binders}{extends}{ty}{where}') if self.deriving: lines.append(f' deriving {self.deriving}')