diff --git a/pyproject.toml b/pyproject.toml index bc36dee..50dd6f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,13 @@ pytest = "^7.4.3" [tool.poetry.scripts] vyper-lsp = 'vyper_lsp.main:main' +[tool.coverage.run] +source = ["vyper_lsp"] +omit = ["vyper_lsp/analyzer/SourceAnalyzer.py", + "vyper_lsp/__init__.py", + "vyper_lsp/__main__.py", + ] + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/tests/test_completions.py b/tests/test_completions.py index 3b59f3a..7654f84 100644 --- a/tests/test_completions.py +++ b/tests/test_completions.py @@ -1,6 +1,7 @@ from lsprotocol.types import ( CompletionContext, CompletionParams, + CompletionTriggerKind, Position, TextDocumentIdentifier, ) @@ -27,7 +28,7 @@ def baz(): self. """ - doc = Document(uri="examples/Foo.vy", source=src) + doc = Document(uri="", source=src) pos = Position(line=11, character=7) context = CompletionContext(trigger_character=".", trigger_kind=2) params = CompletionParams( @@ -62,11 +63,11 @@ def baz(): x: Foo = Foo. """ - doc = Document(uri="examples/Foo.vy", source=src) + doc = Document(uri="", source=src) pos = Position(line=15, character=18) context = CompletionContext(trigger_character=".", trigger_kind=2) params = CompletionParams( - text_document={"uri": doc.uri, "source": src}, position=pos, context=context + text_document=TextDocumentIdentifier(uri=doc.uri), position=pos, context=context ) analyzer = AstAnalyzer(ast) @@ -74,3 +75,40 @@ def baz(): assert len(completions.items) == 2 assert "BAR" in [c.label for c in completions.items] assert "BAZ" in [c.label for c in completions.items] + + +def test_completion_fn_decorator(ast): + src = """ +@internal +def foo(): + return + +@external +def bar(): + self.foo() +""" + ast.build_ast(src) + + src += """ +@ +""" + + doc = Document(uri="", source=src) + pos = Position(line=8, character=1) + context = CompletionContext( + trigger_character="@", trigger_kind=CompletionTriggerKind.TriggerCharacter + ) + params = CompletionParams( + text_document=TextDocumentIdentifier(uri=doc.uri), position=pos, context=context + ) + + analyzer = AstAnalyzer(ast) + completions = analyzer.get_completions_in_doc(doc, params) + assert len(completions.items) == 6 + labels = [c.label for c in completions.items] + assert "internal" in labels + assert "external" in labels + assert "payable" in labels + assert "nonpayable" in labels + assert "view" in labels + assert "pure" in labels diff --git a/tests/test_debouncer.py b/tests/test_debouncer.py new file mode 100644 index 0000000..3ca1749 --- /dev/null +++ b/tests/test_debouncer.py @@ -0,0 +1,21 @@ +import time +from vyper_lsp.debounce import Debouncer # Import Debouncer from your module + + +def test_debounce(): + result = [] + + def test_function(arg): + result.append(arg) + + debouncer = Debouncer(wait=0.5) + debounced_func = debouncer.debounce(test_function) + + debounced_func("first call") + time.sleep(0.2) # Sleep for less than the debounce period + debounced_func("second call") + time.sleep( + 0.6 + ) # Sleep for more than the debounce period to allow the function to execute + + assert result == ["second call"] diff --git a/tests/test_info.py b/tests/test_info.py index dbfacac..2d7141f 100644 --- a/tests/test_info.py +++ b/tests/test_info.py @@ -13,10 +13,18 @@ def foo(x: int128, y: int128) -> int128: @external def bar(): self.foo(1, 2) + +@internal +def baz(x: int128) -> int128: + return x + +@external +def foobar(): + self.foo(self.baz(1), 2) """ ast.build_ast(src) - doc = Document(uri="examples/Foo.vy", source=src) + doc = Document(uri="", source=src) pos = Position(line=7, character=13) params = SignatureHelpParams( @@ -29,3 +37,57 @@ def bar(): assert sig_help.active_signature == 0 assert sig_help.signatures[0].active_parameter == 1 assert sig_help.signatures[0].label == "foo(x: int128, y: int128) -> int128" + + pos = Position(line=15, character=22) + params = SignatureHelpParams( + text_document=TextDocumentIdentifier(doc.uri), position=pos + ) + sig_help = analyzer.signature_help(doc, params) + assert sig_help + assert sig_help.active_signature == 0 + assert sig_help.signatures[0].active_parameter == 1 + assert sig_help.signatures[0].label == "baz(x: int128) -> int128" + + +def test_hover(ast: AST): + src = """ +@internal +def foo( + x: int128, + y: int128 +) -> int128: + return x + y + +@external +def bar(): + self.foo(1, 2) + +@internal +def noreturn(x: uint256): + y: uint256 = x + +@internal +def baz(): + self.noreturn(1) +""" + ast.build_ast(src) + + doc = Document(uri="", source=src) + + pos = Position(line=10, character=11) + + analyzer = AstAnalyzer(ast) + hover = analyzer.hover_info(doc, pos) + assert hover + assert ( + hover + == """(Internal Function) def foo( + x: int128, + y: int128 +) -> int128:""" + ) + + pos = Position(line=18, character=11) + hover = analyzer.hover_info(doc, pos) + assert hover + assert hover == "(Internal Function) def noreturn(x: uint256):" diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..b50c03c --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,43 @@ +from vyper_lsp import utils + + +def test_get_word_at_cursor(): + text = "self.foo = 123" + assert utils.get_word_at_cursor(text, 0) == "self" + assert utils.get_word_at_cursor(text, 1) == "self" + assert utils.get_word_at_cursor(text, 5) == "foo" + assert utils.get_word_at_cursor(text, 12) == "123" + + text = "foo_bar = 123" + assert utils.get_word_at_cursor(text, 0) == "foo_bar" + assert utils.get_word_at_cursor(text, 4) == "foo_bar" + + +def test_get_expression_at_cursor(): + text = "self.foo = 123" + assert utils.get_expression_at_cursor(text, 0) == "self.foo" + assert utils.get_expression_at_cursor(text, 1) == "self.foo" + assert utils.get_expression_at_cursor(text, 5) == "self.foo" + assert utils.get_expression_at_cursor(text, 12) == "123" + + text = "foo_bar = self.baz (1,2,3)" + assert utils.get_expression_at_cursor(text, 0) == "foo_bar" + assert utils.get_expression_at_cursor(text, 4) == "foo_bar" + assert utils.get_expression_at_cursor(text, 21) == "self.baz (1,2,3)" + + +def test_get_internal_fn_name_at_cursor(): + text = "self.foo = 123" + assert utils.get_internal_fn_name_at_cursor(text, 0) is None + assert utils.get_internal_fn_name_at_cursor(text, 1) is None + assert utils.get_internal_fn_name_at_cursor(text, 5) is None + assert utils.get_internal_fn_name_at_cursor(text, 12) is None + + text = "foo_bar = self.baz (1,2,3)" + assert utils.get_internal_fn_name_at_cursor(text, 0) is None + assert utils.get_internal_fn_name_at_cursor(text, 4) is None + assert utils.get_internal_fn_name_at_cursor(text, 21) == "baz" + + text = "self.foo(self.bar())" + assert utils.get_internal_fn_name_at_cursor(text, 7) == "foo" + assert utils.get_internal_fn_name_at_cursor(text, 15) == "bar" diff --git a/vyper_lsp/analyzer/AstAnalyzer.py b/vyper_lsp/analyzer/AstAnalyzer.py index 45c74bf..a9dbbe5 100644 --- a/vyper_lsp/analyzer/AstAnalyzer.py +++ b/vyper_lsp/analyzer/AstAnalyzer.py @@ -5,6 +5,7 @@ from packaging.version import Version from lsprotocol.types import ( Diagnostic, + DiagnosticSeverity, ParameterInformation, Position, Range, @@ -22,6 +23,9 @@ get_word_at_cursor, get_installed_vyper_version, get_internal_fn_name_at_cursor, + diagnostic_from_exception, + is_internal_fn, + is_state_var, ) from lsprotocol.types import ( CompletionItem, @@ -31,8 +35,8 @@ ) from pygls.server import LanguageServer -pattern = r"(.+) is deprecated\. Please use `(.+)` instead\." -compiled_pattern = re.compile(pattern) +pattern_text = r"(.+) is deprecated\. Please use `(.+)` instead\." +deprecation_pattern = re.compile(pattern_text) min_vyper_version = Version("0.3.7") @@ -44,7 +48,7 @@ BYTES_M_TYPES = {f"bytes{i}" for i in range(32, 0, -1)} DECIMAL_TYPES = {"decimal"} -BASE_TYPES = {"bool", "address"} | INTEGER_TYPES | BYTES_M_TYPES | DECIMAL_TYPES +BASE_TYPES = list({"bool", "address"} | INTEGER_TYPES | BYTES_M_TYPES | DECIMAL_TYPES) DECORATORS = ["payable", "nonpayable", "view", "pure", "external", "internal"] @@ -60,19 +64,6 @@ def __init__(self, ast: AST) -> None: else: self.diagnostics_enabled = True - def _range_from_exception(self, node: VyperException) -> Range: - return Range( - start=Position(line=node.lineno - 1, character=node.col_offset), - end=Position(line=node.end_lineno - 1, character=node.end_col_offset), - ) - - def _diagnostic_from_exception(self, node: VyperException) -> Diagnostic: - return Diagnostic( - range=self._range_from_exception(node), - message=str(node), - severity=1, - ) - def signature_help( self, doc: Document, params: SignatureHelpParams ) -> SignatureHelp: @@ -84,38 +75,36 @@ def signature_help( current_line, params.position.character - 1 ) - if expression.startswith("self."): - node = self.ast.find_function_declaration_node_for_name(fn_name) - if node: - fn_name = node.name - arg_str = ", ".join( - [f"{arg.arg}: {arg.annotation.id}" for arg in node.args.args] - ) - fn_label = f"{fn_name}({arg_str})" - parameters = [] - if node.returns: - line = doc.lines[node.lineno - 1] - fn_label = line.removeprefix("def ").removesuffix(":\n") - for arg in node.args.args: - start_index = fn_label.find(arg.arg) - end_index = start_index + len(arg.arg) - parameters.append( - ParameterInformation( - label=(start_index, end_index), documentation=None - ) - ) - active_parameter = current_line.split("(")[-1].count(",") - return SignatureHelp( - signatures=[ - SignatureInformation( - label=fn_label, - parameters=parameters, - documentation=None, - active_parameter=active_parameter or 0, - ) - ], - active_signature=0, + if not expression.startswith("self."): + return None + + node = self.ast.find_function_declaration_node_for_name(fn_name) + if not node: + return None + + fn_name = node.name + parameters = [] + line = doc.lines[node.lineno - 1] + fn_label = line.removeprefix("def ").removesuffix(":\n") + + for arg in node.args.args: + start_index = fn_label.find(arg.arg) + end_index = start_index + len(arg.arg) + parameters.append( + ParameterInformation(label=(start_index, end_index), documentation=None) + ) + active_parameter = current_line.split("(")[-1].count(",") + return SignatureHelp( + signatures=[ + SignatureInformation( + label=fn_label, + parameters=parameters, + documentation=None, + active_parameter=active_parameter or 0, ) + ], + active_signature=0, + ) def get_completions_in_doc( self, document: Document, params: CompletionParams @@ -124,48 +113,50 @@ def get_completions_in_doc( current_line = document.lines[params.position.line].strip() custom_types = self.ast.get_user_defined_types() - if params.context: - if params.context.trigger_character == ".": - # get element before the dot - element = current_line.split(" ")[-1].split(".")[0] - - # internal functions and state variables - if element == "self": - for fn in self.ast.get_internal_functions(): - items.append(CompletionItem(label=fn)) - # TODO: This should exclude constants and immutables - for var in self.ast.get_state_variables(): - items.append(CompletionItem(label=var)) - else: - # TODO: This is currently only correct for enums - # For structs, we'll need to get the type of the variable - for attr in self.ast.get_attributes_for_symbol(element): - items.append(CompletionItem(label=attr)) - completions = CompletionList(is_incomplete=False, items=items) - return completions - elif params.context.trigger_character == "@": - for dec in DECORATORS: - items.append(CompletionItem(label=dec)) - completions = CompletionList(is_incomplete=False, items=items) - return completions - elif params.context.trigger_character == ":": - for typ in custom_types + list(BASE_TYPES): - items.append(CompletionItem(label=typ, insert_text=f" {typ}")) + if not params.context: + return CompletionList(is_incomplete=False, items=[]) - completions = CompletionList(is_incomplete=False, items=items) - return completions + if params.context.trigger_character == ".": + # get element before the dot + element = current_line.split(" ")[-1].split(".")[0] + + # internal functions and state variables + if element == "self": + for fn in self.ast.get_internal_functions(): + items.append(CompletionItem(label=fn)) + # TODO: This should exclude constants and immutables + for var in self.ast.get_state_variables(): + items.append(CompletionItem(label=var)) else: - if params.context.trigger_character == " ": - if current_line[-1] == ":": - for typ in custom_types + list(BASE_TYPES): - items.append(CompletionItem(label=typ)) + # TODO: This is currently only correct for enums + # For structs, we'll need to get the type of the variable + for attr in self.ast.get_attributes_for_symbol(element): + items.append(CompletionItem(label=attr)) + completions = CompletionList(is_incomplete=False, items=items) + return completions + + if params.context.trigger_character == "@": + for dec in DECORATORS: + items.append(CompletionItem(label=dec)) + completions = CompletionList(is_incomplete=False, items=items) + return completions + + if params.context.trigger_character == ":": + for typ in custom_types + BASE_TYPES: + items.append(CompletionItem(label=typ, insert_text=f" {typ}")) + + completions = CompletionList(is_incomplete=False, items=items) + return completions + + if params.context.trigger_character == " ": + if current_line[-1] == ":": + for typ in custom_types + BASE_TYPES: + items.append(CompletionItem(label=typ)) - completions = CompletionList(is_incomplete=False, items=items) - return completions - return CompletionList(is_incomplete=False, items=[]) + completions = CompletionList(is_incomplete=False, items=items) + return completions - else: - return CompletionList(is_incomplete=False, items=[]) + return CompletionList(is_incomplete=False, items=[]) def get_completions( self, ls: LanguageServer, params: CompletionParams @@ -173,12 +164,6 @@ def get_completions( document = ls.workspace.get_text_document(params.text_document.uri) return self.get_completions_in_doc(document, params) - def _is_internal_fn(self, expression: str) -> bool: - return expression.startswith("self.") and "(" in expression - - def _is_state_var(self, expression: str) -> bool: - return expression.startswith("self.") and "(" not in expression - def _format_arg(self, arg: nodes.arg) -> str: if arg.annotation is None: return arg.arg @@ -203,55 +188,73 @@ def _format_arg(self, arg: nodes.arg) -> str: return f"{arg.arg}: {arg.annotation.id}" def _format_fn_signature(self, node: nodes.FunctionDef) -> str: - fn_name = node.name - arg_str = ", ".join([self._format_arg(arg) for arg in node.args.args]) - if node.returns: - if isinstance(node.returns, nodes.Subscript): - return_type_str = ( - f"{node.returns.value.id}[{node.returns.slice.value.value}]" - ) - else: - return_type_str = node.returns.id - return ( - f"(Internal Function) **{fn_name}**({arg_str}) -> **{return_type_str}**" - ) - return f"(Internal Function) **{fn_name}**({arg_str})" + pattern = r"def\s+(\w+)\((?:[^()]|\n)*\)(?:\s*->\s*[\w\[\], \n]+)?:" + match = re.search(pattern, node.node_source_code, re.MULTILINE) + if match: + function_def = match.group() + return f"(Internal Function) {function_def}" def hover_info(self, document: Document, pos: Position) -> Optional[str]: if len(document.lines) < pos.line: return None + og_line = document.lines[pos.line] word = get_word_at_cursor(og_line, pos.character) full_word = get_expression_at_cursor(og_line, pos.character) - if self._is_internal_fn(full_word): + if is_internal_fn(full_word): node = self.ast.find_function_declaration_node_for_name(word) - if node: - return self._format_fn_signature(node) - elif self._is_state_var(full_word): + return node and self._format_fn_signature(node) + + if is_state_var(full_word): node = self.ast.find_state_variable_declaration_node_for_name(word) - if node: - variable_type = node.annotation.id - return f"(State Variable) **{word}** : **{variable_type}**" - elif word in self.ast.get_structs(): + if not node: + return None + variable_type = node.annotation.id + return f"(State Variable) **{word}** : **{variable_type}**" + + if word in self.ast.get_structs(): node = self.ast.find_type_declaration_node_for_name(word) - if node: - return f"(Struct) **{word}**" - elif word in self.ast.get_enums(): + return node and f"(Struct) **{word}**" + + if word in self.ast.get_enums(): node = self.ast.find_type_declaration_node_for_name(word) - if node: - return f"(Enum) **{word}**" - elif word in self.ast.get_events(): + return node and f"(Enum) **{word}**" + + if word in self.ast.get_events(): node = self.ast.find_type_declaration_node_for_name(word) - if node: - return f"(Event) **{word}**" - elif word in self.ast.get_constants(): + return node and f"(Event) **{word}**" + + if word in self.ast.get_constants(): node = self.ast.find_state_variable_declaration_node_for_name(word) - if node: - variable_type = node.annotation.id - return f"(Constant) **{word}** : **{variable_type}**" - else: - return None + if not node: + return None + + variable_type = node.annotation.id + return f"(Constant) **{word}** : **{variable_type}**" + + return None + + def create_diagnostic( + self, line_num: int, character_start: int, character_end: int, message: str + ) -> Diagnostic: + """ + Helper function to create a diagnostic object. + + :param line_num: The line number of the diagnostic. + :param character_start: The starting character position of the diagnostic. + :param character_end: The ending character position of the diagnostic. + :param message: The diagnostic message. + :return: A Diagnostic object. + """ + return Diagnostic( + range=Range( + start=Position(line=line_num, character=character_start), + end=Position(line=line_num, character=character_end), + ), + message=message, + severity=DiagnosticSeverity.Warning, + ) def get_diagnostics(self, doc: Document) -> List[Diagnostic]: diagnostics = [] @@ -267,36 +270,33 @@ def get_diagnostics(self, doc: Document) -> List[Diagnostic]: compiler_data.vyper_module_folded except VyperException as e: if e.lineno is not None and e.col_offset is not None: - diagnostics.append(self._diagnostic_from_exception(e)) + diagnostics.append(diagnostic_from_exception(e)) else: for a in e.annotations: - diagnostics.append(self._diagnostic_from_exception(a)) + diagnostics.append(diagnostic_from_exception(a)) for warning in w: - match = compiled_pattern.match(str(warning.message)) - if not match: + m = deprecation_pattern.match(str(warning.message)) + if not m: continue - deprecated = match.group(1) - replacement = match.group(2) + deprecated = m.group(1) + replacement = m.group(2) replacements[deprecated] = replacement - # iterate over doc.lines and find all deprecated values - # and create a warning for each one at the correct position + # Iterate over doc.lines and find all deprecated values for i, line in enumerate(doc.lines): for deprecated, replacement in replacements.items(): - if deprecated in line: + for match in re.finditer(re.escape(deprecated), line): + character_start = match.start() + character_end = match.end() + diagnostic_message = ( + f"{deprecated} is deprecated. Please use {replacement} instead." + ) diagnostics.append( - Diagnostic( - range=Range( - start=Position( - line=i, character=line.index(deprecated) - ), - end=Position( - line=i, - character=line.index(deprecated) + len(deprecated), - ), - ), - message=f"{deprecated} is deprecated. Please use {replacement} instead.", - severity=2, + self.create_diagnostic( + line_num=i, + character_start=character_start, + character_end=character_end, + message=diagnostic_message, ) ) diff --git a/vyper_lsp/analyzer/SourceAnalyzer.py b/vyper_lsp/analyzer/SourceAnalyzer.py index 88848ff..d85ac96 100644 --- a/vyper_lsp/analyzer/SourceAnalyzer.py +++ b/vyper_lsp/analyzer/SourceAnalyzer.py @@ -1,3 +1,5 @@ +# REVIEW: not currently used + import re from typing import List, Optional from lark import UnexpectedInput, UnexpectedToken @@ -31,8 +33,8 @@ def format_parse_error(e): if isinstance(e, UnexpectedToken): expected = ", ".join(e.accepts or e.expected) return f"Unexpected token '{e.token}' at {e.line}:{e.column}. Expected one of: {expected}" - else: - return str(e) + + return str(e) LEGACY_VERSION_PRAGMA_REGEX = re.compile(r"^#\s*@version\s+(.*)$") @@ -40,12 +42,11 @@ def format_parse_error(e): def extract_version_pragma(line: str) -> Optional[str]: - if match := LEGACY_VERSION_PRAGMA_REGEX.match(line): - return match.group(1) - elif match := VERSION_PRAGMA_REGEX.match(line): - return match.group(1) - else: - return None + if (m := LEGACY_VERSION_PRAGMA_REGEX.match(line)) is not None: + return m.group(1) + elif (m := VERSION_PRAGMA_REGEX.match(line)) is not None: + return m.group(1) + return None # regex that matches numbers and underscores @@ -61,7 +62,7 @@ def __init__(self) -> None: def get_version_pragma(self, doc: Document) -> Optional[str]: doc_lines = doc.lines for line in doc_lines: - if version := extract_version_pragma(line): + if (version := extract_version_pragma(line)) is not None: return version def hover_info(self, doc: Document, pos: Position) -> Optional[str]: @@ -72,6 +73,7 @@ def get_parser_diagnostics(self, doc: Document) -> List[Diagnostic]: last_error = None def on_grammar_error(e: UnexpectedInput) -> bool: + # REVIEW: nonlocal!!! nonlocal last_error if ( last_error is not None diff --git a/vyper_lsp/ast.py b/vyper_lsp/ast.py index 45dbc0c..ad2c66c 100644 --- a/vyper_lsp/ast.py +++ b/vyper_lsp/ast.py @@ -34,17 +34,16 @@ def build_ast(self, src: str): self.ast_data = copy.deepcopy(compiler_data.vyper_module) except Exception as e: logger.error(f"Error generating AST, {e}") - pass + try: self.ast_data_unfolded = compiler_data.vyper_module_unfolded except Exception as e: logger.error(f"Error generating unfolded AST, {e}") - pass + try: self.ast_data_folded = compiler_data.vyper_module_folded except Exception as e: logger.error(f"Error generating folded AST, {e}") - pass @property def best_ast(self): @@ -54,8 +53,8 @@ def best_ast(self): return self.ast_data elif self.ast_data_folded: return self.ast_data_folded - else: - return None + + return None def get_descendants(self, *args, **kwargs): if self.best_ast is None: @@ -110,6 +109,7 @@ def get_state_variables(self): # missing from self.ast_data_unfolded and self.ast_data_folded when constants if self.ast_data is None: return [] + return [ node.target.id for node in self.ast_data.get_descendants(nodes.VariableDecl) ] @@ -155,8 +155,8 @@ def get_attributes_for_symbol(self, symbol: str): return self.get_struct_fields(symbol) elif isinstance(node, nodes.EnumDef): return self.get_enum_variants(symbol) - else: - return [] + + return [] def find_function_declaration_node_for_name(self, function: str): for node in self.get_descendants(nodes.FunctionDef): @@ -231,6 +231,8 @@ def find_top_level_node_at_pos(self, pos: Position) -> Optional[VyperNode]: if node.lineno <= pos.line and pos.line <= node.end_lineno: return node + return None + def find_nodes_referencing_symbol(self, symbol: str): # this only runs on subtrees return_nodes = [] @@ -256,3 +258,5 @@ def find_node_declaring_symbol(self, symbol: str): for node in self.get_descendants((nodes.AnnAssign, nodes.VariableDecl)): if node.target.id == symbol: return node + + return None diff --git a/vyper_lsp/grammar/grammar.lark b/vyper_lsp/grammar/grammar.lark index 0db1154..abf1244 100644 --- a/vyper_lsp/grammar/grammar.lark +++ b/vyper_lsp/grammar/grammar.lark @@ -1,5 +1,7 @@ // Vyper grammar for Lark +// REVIEW: not currently used! + // A module is a sequence of definitions and methods (and comments). // NOTE: Start symbol for the grammar // NOTE: Module can start with docstring diff --git a/vyper_lsp/logging.py b/vyper_lsp/logging.py index c78010a..54fb414 100644 --- a/vyper_lsp/logging.py +++ b/vyper_lsp/logging.py @@ -15,5 +15,6 @@ def __init__(self, ls): def emit(self, record): log_entry = self.format(record) - if self.ls: - self.ls.show_message_log(log_entry) + if not self.ls: + return + self.ls.show_message_log(log_entry) diff --git a/vyper_lsp/main.py b/vyper_lsp/main.py index 3777b8b..88c1af2 100755 --- a/vyper_lsp/main.py +++ b/vyper_lsp/main.py @@ -124,17 +124,17 @@ def go_to_definition( ) -> Optional[Location]: # TODO: Look for assignment nodes to find definition document = ls.workspace.get_text_document(params.text_document.uri) - range = navigator.find_declaration(document, params.position) - if range: - return Location(uri=params.text_document.uri, range=range) + range_ = navigator.find_declaration(document, params.position) + if range_: + return Location(uri=params.text_document.uri, range=range_) @server.feature(TEXT_DOCUMENT_REFERENCES) def find_references(ls: LanguageServer, params: DefinitionParams) -> List[Location]: document = ls.workspace.get_text_document(params.text_document.uri) return [ - Location(uri=params.text_document.uri, range=range) - for range in navigator.find_references(document, params.position) + Location(uri=params.text_document.uri, range=range_) + for range_ in navigator.find_references(document, params.position) ] @@ -160,8 +160,8 @@ def signature_help(ls: LanguageServer, params: SignatureHelpParams): @server.feature(TEXT_DOCUMENT_IMPLEMENTATION) def implementation(ls: LanguageServer, params: DefinitionParams): document = ls.workspace.get_text_document(params.text_document.uri) - range = navigator.find_implementation(document, params.position) - if range: + range_ = navigator.find_implementation(document, params.position) + if range_: return Location(uri=params.text_document.uri, range=range) diff --git a/vyper_lsp/navigation.py b/vyper_lsp/navigation.py index 7fcc8c5..ea9c3a1 100644 --- a/vyper_lsp/navigation.py +++ b/vyper_lsp/navigation.py @@ -6,7 +6,11 @@ from pygls.workspace import Document from vyper.ast import EnumDef, FunctionDef, VyperNode from vyper_lsp.ast import AST -from vyper_lsp.utils import get_expression_at_cursor, get_word_at_cursor +from vyper_lsp.utils import ( + get_expression_at_cursor, + get_word_at_cursor, + range_from_node, +) ENUM_VARIANT_PATTERN = re.compile(r"([a-zA-Z_][a-zA-Z0-9_]*)\.([a-zA-Z_][a-zA-Z0-9_]*)") @@ -21,33 +25,35 @@ class ASTNavigator: def __init__(self, ast: AST): self.ast = ast - def _create_range_from_node(self, node: VyperNode) -> Range: - return Range( - start=Position(line=node.lineno - 1, character=node.col_offset), - end=Position(line=node.end_lineno - 1, character=node.end_col_offset), - ) - def _find_state_variable_declaration(self, word: str) -> Optional[Range]: node = self.ast.find_state_variable_declaration_node_for_name(word) if node: - return self._create_range_from_node(node) + return range_from_node(node) + + return None def _find_variable_declaration_under_node( self, node: VyperNode, symbol: str ) -> Optional[Range]: decl_node = AST.from_node(node).find_node_declaring_symbol(symbol) if decl_node: - return self._create_range_from_node(decl_node) + return range_from_node(decl_node) + + return None def _find_function_declaration(self, word: str) -> Optional[Range]: node = self.ast.find_function_declaration_node_for_name(word) if node: - return self._create_range_from_node(node) + return range_from_node(node) + + return None def find_type_declaration(self, word: str) -> Optional[Range]: node = self.ast.find_type_declaration_node_for_name(word) if node: - return self._create_range_from_node(node) + return range_from_node(node) + + return None def _is_state_var_decl(self, line, word): is_top_level = not line[0].isspace() @@ -65,50 +71,55 @@ def _is_internal_fn(self, line, word, expression): return is_def and (is_internal_call or is_internal_fn) def find_references(self, doc: Document, pos: Position) -> List[Range]: + # REVIEW: return is stylistically slightly different from ast analyzer if self.ast.ast_data is None: return [] - references = [] og_line = doc.lines[pos.line] word = get_word_at_cursor(og_line, pos.character) expression = get_expression_at_cursor(og_line, pos.character) - top_level_node = self.ast.find_top_level_node_at_pos(pos) - refs = [] + def finalize(refs): + return [range_from_node(ref) for ref in refs] if word in self.ast.get_enums(): - # find all references to this type - refs = self.ast.find_nodes_referencing_enum(word) - elif word in self.ast.get_structs() or word in self.ast.get_events(): - refs = self.ast.find_nodes_referencing_struct(word) - elif self._is_internal_fn(og_line, word, expression): - refs = self.ast.find_nodes_referencing_internal_function(word) - elif self._is_constant_decl(og_line, word): - refs = self.ast.find_nodes_referencing_constant(word) - elif self._is_state_var_decl(og_line, word): - refs = self.ast.find_nodes_referencing_state_variable(word) - elif isinstance(top_level_node, EnumDef): - # find all references to this enum variant - refs = self.ast.find_nodes_referencing_enum_variant( - top_level_node.name, word + return finalize(self.ast.find_nodes_referencing_enum(word)) + + if word in self.ast.get_structs() or word in self.ast.get_events(): + return finalize(self.ast.find_nodes_referencing_struct(word)) + + if self._is_internal_fn(og_line, word, expression): + return finalize(self.ast.find_nodes_referencing_internal_function(word)) + + if self._is_constant_decl(og_line, word): + return finalize(self.ast.find_nodes_referencing_constant(word)) + + if self._is_state_var_decl(og_line, word): + return finalize(self.ast.find_nodes_referencing_state_variable(word)) + + if isinstance(top_level_node, EnumDef): + return finalize( + self.ast.find_nodes_referencing_enum_variant(top_level_node.name, word) + ) + + if isinstance(top_level_node, FunctionDef): + return finalize( + AST.from_node(top_level_node).find_nodes_referencing_symbol(word) ) - elif isinstance(top_level_node, FunctionDef): - refs = AST.from_node(top_level_node).find_nodes_referencing_symbol(word) - for ref in refs: - range = self._create_range_from_node(ref) - references.append(range) - return references + return [] def _match_enum_variant(self, full_word: str) -> Optional[re.Match]: - match = ENUM_VARIANT_PATTERN.match(full_word) + match_ = ENUM_VARIANT_PATTERN.match(full_word) + if ( - match - and match.group(1) in self.ast.get_enums() - and match.group(2) in self.ast.get_enum_variants(match.group(1)) + match_ + and match_.group(1) in self.ast.get_enums() + and match_.group(2) in self.ast.get_enum_variants(match_.group(1)) ): - return match + return match_ + return None def find_declaration(self, document: Document, pos: Position) -> Optional[Range]: @@ -133,13 +144,15 @@ def find_declaration(self, document: Document, pos: Position) -> Optional[Range] elif word in self.ast.get_constants(): return self._find_state_variable_declaration(word) elif isinstance(top_level_node, FunctionDef): - range = self._find_variable_declaration_under_node(top_level_node, word) - if range: - return range - else: - match = self._match_enum_variant(full_word) - if match: - return self.find_type_declaration(match.group(1)) + range_ = self._find_variable_declaration_under_node(top_level_node, word) + if range_: + return range_ + + match_ = self._match_enum_variant(full_word) + if match_: + return self.find_type_declaration(match_.group(1)) + + return None def find_implementation(self, document: Document, pos: Position) -> Optional[Range]: og_line = document.lines[pos.line] @@ -151,8 +164,9 @@ def find_implementation(self, document: Document, pos: Position) -> Optional[Ran if expression.startswith("self."): return self._find_function_declaration(word) - elif og_line[0].isspace() and og_line.strip().startswith("def"): + + if og_line[0].isspace() and og_line.strip().startswith("def"): # only lookup external fns if we're in an interface def return self._find_function_declaration(word) - else: - return None + + return None diff --git a/vyper_lsp/utils.py b/vyper_lsp/utils.py index c4b809b..1f11efd 100644 --- a/vyper_lsp/utils.py +++ b/vyper_lsp/utils.py @@ -1,8 +1,13 @@ import logging +import string import re from pathlib import Path from importlib.metadata import version +from typing import Optional +from lsprotocol.types import Diagnostic, DiagnosticSeverity, Position, Range from packaging.version import Version +from vyper.ast import VyperNode +from vyper.exceptions import VyperException from vyper.compiler import CompilerData @@ -41,21 +46,25 @@ def is_attribute_access(line): return bool(re.match(reg, line.strip())) -def is_word_char(char): - # true for alnum and underscore - return char.isalnum() or char == "_" +_WORD_CHARS = string.ascii_letters + string.digits + "_" + + +# REVIEW: these get_.*_at_cursor helpers would benefit from having +# access to as much cursor information as possible (ex. line number), +# it could open up some possibilies when refactoring for performance def get_word_at_cursor(sentence: str, cursor_index: int) -> str: start = cursor_index end = cursor_index + # TODO: this could be a perf hotspot # Find the start of the word - while start > 0 and is_word_char(sentence[start - 1]): + while start > 0 and sentence[start - 1] in _WORD_CHARS: start -= 1 # Find the end of the word - while end < len(sentence) and is_word_char(sentence[end]): + while end < len(sentence) and sentence[end] in _WORD_CHARS: end += 1 # Extract the word @@ -65,42 +74,38 @@ def get_word_at_cursor(sentence: str, cursor_index: int) -> str: def _check_if_cursor_is_within_parenthesis(sentence: str, cursor_index: int) -> bool: - start = cursor_index - end = cursor_index - - # Find the start of the word - # TODO: this is a hacky way to do this, should be refactored - while start > 0 and sentence[start] != "(": - start -= 1 - - # Find the end of the word - # TODO: this is a hacky way to do this, should be refactored - while end < len(sentence) and sentence[end] != ")": - end += 1 - - if start != 0 and start < cursor_index and cursor_index < end: + # Find the nearest '(' before the cursor + start = sentence[:cursor_index][::-1].find("(") + if start != -1: + start = cursor_index - start - 1 + + # Find the nearest ')' after the cursor + end = sentence[cursor_index:].find(")") + if end != -1: + end += cursor_index + + # Check if cursor is within a valid pair of parentheses + if start != -1 and end != -1 and start < cursor_index < end: return True + return False def _get_entire_function_call(sentence: str, cursor_index: int) -> str: - start = cursor_index - end = cursor_index - - # Find the start of the word - # only skip spaces if we're within the parenthesis - while start > 0 and sentence[start - 1] != "(": - start -= 1 + # Regex pattern to match function calls + # This pattern looks for a word (function name), followed by optional spaces, + # and then parentheses with anything inside. + pattern = r"\b(?:\w+\.)*\w+\s*\([^)]*\)" - while start > 0 and sentence[start - 1] != " ": - start -= 1 + # Find all matches in the sentence + matches = [match for match in re.finditer(pattern, sentence)] - # Find the end of the word - while end < len(sentence) and sentence[end] != ")": - end += 1 + # Find the match that contains the cursor + for match in matches: + if match.start() <= cursor_index <= match.end(): + return match.group() - fn_call = sentence[start:end] - return fn_call + return "" # Return an empty string if no match is found def get_expression_at_cursor(sentence: str, cursor_index: int) -> str: @@ -112,17 +117,11 @@ def get_expression_at_cursor(sentence: str, cursor_index: int) -> str: end = cursor_index # Find the start of the word - while ( - start > 0 - and is_word_char(sentence[start - 1]) - or sentence[start - 1] in ".[]()" - ): + while start > 0 and sentence[start - 1] in _WORD_CHARS + ".[]()": start -= 1 # Find the end of the word - while ( - end < len(sentence) and is_word_char(sentence[end]) or sentence[end] in ".[]()" - ): + while end < len(sentence) and sentence[end] in _WORD_CHARS + ".[]()": end += 1 # Extract the word @@ -131,15 +130,76 @@ def get_expression_at_cursor(sentence: str, cursor_index: int) -> str: return word -def get_internal_fn_name_at_cursor(sentence: str, cursor_index: int) -> str: - # TODO: dont assume the fn call is at the end of the line - word = sentence.split("(")[0].split(" ")[-1].strip().split("self.")[-1] +def get_internal_fn_name_at_cursor(sentence: str, cursor_index: int) -> Optional[str]: + # TODO: Improve this function to handle more cases + # should be simpler, and handle when the cursor is on "self." before a fn name + # Split the sentence into segments at each 'self.' + segments = sentence.split("self.") + + # Accumulated length to keep track of the cursor's position relative to the original sentence + accumulated_length = 0 + + for segment in segments: + if not segment: + accumulated_length += len("self.") + continue + + # Update the accumulated length for each segment + segment_start = accumulated_length + segment_end = accumulated_length + len(segment) + accumulated_length = segment_end + 5 # Update for next segment + + # Check if the cursor is within the current segment + if segment_start <= cursor_index <= segment_end: + # Extract the function name from the segment + function_name = re.findall(r"\b\w+\s*\(", segment) + if function_name: + # Take the function name closest to the cursor + closest_fn = min( + function_name, + key=lambda fn: abs( + cursor_index - (segment_start + segment.find(fn)) + ), + ) + return closest_fn.split("(")[0].strip() - return word + return None def extract_enum_name(line: str): - match = re.match(r"enum\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*:", line) - if match: - return match.group(1) + m = re.match(r"enum\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*:", line) + if m: + return m.group(1) return None + + +def range_from_node(node: VyperNode) -> Range: + return Range( + start=Position(line=node.lineno - 1, character=node.col_offset), + end=Position(line=node.end_lineno - 1, character=node.end_col_offset), + ) + + +def range_from_exception(node: VyperException) -> Range: + return Range( + start=Position(line=node.lineno - 1, character=node.col_offset), + end=Position(line=node.end_lineno - 1, character=node.end_col_offset), + ) + + +def diagnostic_from_exception(node: VyperException) -> Diagnostic: + return Diagnostic( + range=range_from_exception(node), + message=str(node), + severity=DiagnosticSeverity.Error, + ) + + +# this looks like duplicated code, could be in utils +def is_internal_fn(expression: str) -> bool: + return expression.startswith("self.") and "(" in expression + + +# this looks like duplicated code, could be in utils +def is_state_var(expression: str) -> bool: + return expression.startswith("self.") and "(" not in expression