Skip to content

Commit

Permalink
parse fn and module names w regex
Browse files Browse the repository at this point in the history
and small renaming
  • Loading branch information
z80dev committed Dec 14, 2024
1 parent 867d985 commit a1b0ce7
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 79 deletions.
6 changes: 3 additions & 3 deletions tests/test_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def baz():
)

analyzer = AstAnalyzer(ast)
completions = analyzer.get_completions_in_doc(doc, params)
completions = analyzer._get_completions_in_doc(doc, params)
assert len(completions.items) == 1
assert "foo" in [c.label for c in completions.items]

Expand Down Expand Up @@ -71,7 +71,7 @@ def baz():
)

analyzer = AstAnalyzer(ast)
completions = analyzer.get_completions_in_doc(doc, params)
completions = analyzer._get_completions_in_doc(doc, params)
assert len(completions.items) == 2
assert "BAR" in [c.label for c in completions.items]
assert "BAZ" in [c.label for c in completions.items]
Expand Down Expand Up @@ -103,7 +103,7 @@ def bar():
)

analyzer = AstAnalyzer(ast)
completions = analyzer.get_completions_in_doc(doc, params)
completions = analyzer._get_completions_in_doc(doc, params)
assert len(completions.items) == 7
labels = [c.label for c in completions.items]
assert "internal" in labels
Expand Down
17 changes: 4 additions & 13 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,9 @@ def test_get_expression_at_cursor():
assert utils.get_expression_at_cursor(text, 21) == "self.baz (1,2,3)"


def test_get_internal_fn_name_at_cursor():
text = "self.foo = 123"
assert utils.get_internal_fn_name_at_cursor(text, 0) is None
assert utils.get_internal_fn_name_at_cursor(text, 1) is None
assert utils.get_internal_fn_name_at_cursor(text, 5) is None
assert utils.get_internal_fn_name_at_cursor(text, 12) is None

text = "foo_bar = self.baz (1,2,3)"
assert utils.get_internal_fn_name_at_cursor(text, 0) is None
assert utils.get_internal_fn_name_at_cursor(text, 4) is None
assert utils.get_internal_fn_name_at_cursor(text, 21) == "baz"
def test_parse_fncall_expression():
text = "self.foo()"
assert utils.parse_fncall_expression(text) == ("self", "foo")

text = "self.foo(self.bar())"
assert utils.get_internal_fn_name_at_cursor(text, 7) == "foo"
assert utils.get_internal_fn_name_at_cursor(text, 15) == "bar"
assert utils.parse_fncall_expression(text) == ("self", "bar")
44 changes: 18 additions & 26 deletions vyper_lsp/analyzer/AstAnalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
)
from pygls.workspace import Document
from vyper.ast import nodes
from vyper_lsp import utils
from vyper_lsp.analyzer.BaseAnalyzer import Analyzer
from vyper_lsp.ast import AST
from vyper_lsp.utils import (
format_fn,
get_expression_at_cursor,
get_word_at_cursor,
get_installed_vyper_version,
get_internal_fn_name_at_cursor,
)
from lsprotocol.types import (
CompletionItem,
Expand Down Expand Up @@ -60,33 +60,26 @@ def __init__(self, ast: AST) -> None:
def signature_help(
self, doc: Document, params: SignatureHelpParams
) -> Optional[SignatureHelp]:
logger.info("signature help triggered")
# TODO: Implement checking external functions, module functions, and interfaces
current_line = doc.lines[params.position.line]
expression = get_expression_at_cursor(
current_line, params.position.character - 1
)
logger.info(f"expression: {expression}")
# regex for matching 'module.function'
fncall_pattern = "(.*)\\.(.*)"

if matches := re.match(fncall_pattern, expression):
module, fn = matches.groups()
logger.info(f"looking up function {fn} in module {module}")
if module in self.ast.imports:
logger.info("found module")
if fn := self.ast.imports[module].functions[fn]:
logger.info(f"args: {fn.arguments}")
parsed = utils.parse_fncall_expression(expression)
if parsed is None:
return None
module, fn_name = parsed

logger.info(f"looking up function {fn_name} in module {module}")
if module in self.ast.imports:
logger.info("found module")
if fn := self.ast.imports[module].functions[fn_name]:
logger.info(f"args: {fn.arguments}")

# this returns for all external functions
# TODO: Implement checking interfaces
if not expression.startswith("self."):
return None

# TODO: Implement checking external functions, module functions, and interfaces
fn_name = get_internal_fn_name_at_cursor(
current_line, params.position.character - 1
)

if not fn_name:
return None

Expand Down Expand Up @@ -196,7 +189,7 @@ def _dot_completions_for_element(

return completions

def get_completions_in_doc(
def _get_completions_in_doc(
self, document: Document, params: CompletionParams
) -> CompletionList:
items = []
Expand Down Expand Up @@ -270,7 +263,7 @@ def get_completions(
self, ls: LanguageServer, params: CompletionParams
) -> CompletionList:
document = ls.workspace.get_text_document(params.text_document.uri)
return self.get_completions_in_doc(document, params)
return self._get_completions_in_doc(document, params)

def _format_arg(self, arg: nodes.arg) -> str:
if arg.annotation is None:
Expand Down Expand Up @@ -302,13 +295,13 @@ def _format_fn_signature(self, node: nodes.FunctionDef) -> str:
function_def = match.group()
return f"(Internal Function) {function_def}"

def is_internal_fn(self, expression: str):
def _is_internal_fn(self, expression: str):
if not expression.startswith("self."):
return False
fn_name = expression.split("self.")[-1]
return fn_name in self.ast.functions and self.ast.functions[fn_name].is_internal

def is_state_var(self, expression: str):
def _is_state_var(self, expression: str):
if not expression.startswith("self."):
return False
var_name = expression.split("self.")[-1]
Expand All @@ -322,12 +315,11 @@ def hover_info(self, doc: Document, pos: Position) -> Optional[str]:
word = get_word_at_cursor(og_line, pos.character)
full_word = get_expression_at_cursor(og_line, pos.character)

if self.is_internal_fn(full_word):
logger.info("looking for internal fn")
if self._is_internal_fn(full_word):
node = self.ast.find_function_declaration_node_for_name(word)
return node and self._format_fn_signature(node)

if self.is_state_var(full_word):
if self._is_state_var(full_word):
node = self.ast.find_state_variable_declaration_node_for_name(word)
if not node:
return None
Expand Down
50 changes: 13 additions & 37 deletions vyper_lsp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from pathlib import Path
from importlib.metadata import version
from typing import Optional
from typing import Optional, Tuple
from lsprotocol.types import Diagnostic, DiagnosticSeverity, Position, Range
from packaging.version import Version
from pygls.workspace import Document
Expand Down Expand Up @@ -125,42 +125,6 @@ def get_expression_at_cursor(sentence: str, cursor_index: int) -> str:
return word


def get_internal_fn_name_at_cursor(sentence: str, cursor_index: int) -> Optional[str]:
# TODO: Improve this function to handle more cases
# should be simpler, and handle when the cursor is on "self." before a fn name
# Split the sentence into segments at each 'self.'
segments = sentence.split("self.")

# Accumulated length to keep track of the cursor's position relative to the original sentence
accumulated_length = 0

for segment in segments:
if not segment:
accumulated_length += len("self.")
continue

# Update the accumulated length for each segment
segment_start = accumulated_length
segment_end = accumulated_length + len(segment)
accumulated_length = segment_end + 5 # Update for next segment

# Check if the cursor is within the current segment
if segment_start <= cursor_index <= segment_end:
# Extract the function name from the segment
function_name = re.findall(r"\b\w+\s*\(", segment)
if function_name:
# Take the function name closest to the cursor
closest_fn = min(
function_name,
key=lambda fn: abs(
cursor_index - (segment_start + segment.find(fn))
),
)
return closest_fn.split("(")[0].strip()

return None


def extract_enum_name(line: str):
m = re.match(r"enum\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*:", line)
if m:
Expand Down Expand Up @@ -238,3 +202,15 @@ def format_fn(func) -> str:
f"def __{escape_underscores(func.name)}__({args}){return_value}: _{mutability}_"
)
return out


def parse_fncall_expression(expression: str) -> Optional[Tuple[str, str]]:
# regex for matching 'module.function' or 'module.function(args)', not capturing args
fncall_pattern = "(.*)\\.([^\\(]+)(?:\\(.*\\))?"

if matches := re.match(fncall_pattern, expression):
groups = matches.groups()
module, fn = groups
if "(" in module:
module = module.split("(")[-1]
return module, fn

0 comments on commit a1b0ce7

Please sign in to comment.