Skip to content

Commit

Permalink
feat: implement matches
Browse files Browse the repository at this point in the history
  • Loading branch information
amaanq authored and ObserverOfTime committed Feb 25, 2024
1 parent b0e732d commit bc59e3c
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 15 deletions.
107 changes: 105 additions & 2 deletions tests/test_tree_sitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import List, Optional, Tuple
from unittest import TestCase

from tree_sitter import Language, Parser, Tree
from tree_sitter.binding import LookaheadIterator, Node, Range
from tree_sitter import (Language, LookaheadIterator, Node, Parser, Query,
Range, Tree)

LIB_PATH = path.join("build", "languages.so")

Expand Down Expand Up @@ -1259,6 +1259,109 @@ def test_errors(self):
PYTHON.query("(list))")
PYTHON.query("(function_definition)")

def collect_matches(
self,
matches: List[Tuple[int, List[Tuple[Node, str]]]],
) -> List[Tuple[int, List[Tuple[str, str]]]]:
return [(m[0], self.format_captures(m[1])) for m in matches]

def format_captures(
self,
captures: List[Tuple[Node, str]],
) -> List[Tuple[str, str]]:
return [(capture[1], capture[0].text.decode("utf-8")) for capture in captures]

def assert_query_matches(
self,
language: Language,
query: Query,
source: bytes,
expected: List[Tuple[int, List[Tuple[str, str]]]]
):
parser = Parser()
parser.set_language(language)
tree = parser.parse(source)
matches = query.matches(tree.root_node)
matches = self.collect_matches(matches)
self.assertEqual(matches, expected)

def test_matches_with_simple_pattern(self):
query = JAVASCRIPT.query("(function_declaration name: (identifier) @fn-name)")
self.assert_query_matches(
JAVASCRIPT,
query,
b"function one() { two(); function three() {} }",
[(0, [('fn-name', 'one')]), (0, [('fn-name', 'three')])]
)

def test_matches_with_multiple_on_same_root(self):
query = JAVASCRIPT.query("""
(class_declaration
name: (identifier) @the-class-name
(class_body
(method_definition
name: (property_identifier) @the-method-name)))
""")
self.assert_query_matches(
JAVASCRIPT,
query,
b"""
class Person {
// the constructor
constructor(name) { this.name = name; }
// the getter
getFullName() { return this.name; }
}
""",
[
(0, [("the-class-name", "Person"), ("the-method-name", "constructor")]),
(0, [("the-class-name", "Person"), ("the-method-name", "getFullName")]),
]
)

def test_matches_with_multiple_patterns_different_roots(self):
query = JAVASCRIPT.query("""
(function_declaration name:(identifier) @fn-def)
(call_expression function:(identifier) @fn-ref)
""")
self.assert_query_matches(
JAVASCRIPT,
query,
b"""
function f1() {
f2(f3());
}
""",
[(0, [("fn-def", "f1")]), (1, [("fn-ref", "f2")]), (1, [("fn-ref", "f3")])]
)

def test_matches_with_nesting_and_no_fields(self):
query = JAVASCRIPT.query("""
(array
(array
(identifier) @x1
(identifier) @x2))
""")
self.assert_query_matches(
JAVASCRIPT,
query,
b"""
[[a]];
[[c, d], [e, f, g, h]];
[[h], [i]];
""",
[
(0, [("x1", "c"), ("x2", "d")]),
(0, [("x1", "e"), ("x2", "f")]),
(0, [("x1", "e"), ("x2", "g")]),
(0, [("x1", "f"), ("x2", "g")]),
(0, [("x1", "e"), ("x2", "h")]),
(0, [("x1", "f"), ("x2", "h")]),
(0, [("x1", "g"), ("x2", "h")]),
]
)

def test_captures(self):
parser = Parser()
parser.set_language(PYTHON)
Expand Down
1 change: 0 additions & 1 deletion tree_sitter/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ from tree_sitter.binding import \
from tree_sitter.binding import Node as Node
from tree_sitter.binding import Parser as Parser
from tree_sitter.binding import Query as Query
from tree_sitter.binding import QueryCapture as QueryCapture
from tree_sitter.binding import Range as Range
from tree_sitter.binding import Tree as Tree
from tree_sitter.binding import TreeCursor as TreeCursor
Expand Down
130 changes: 123 additions & 7 deletions tree_sitter/binding.c
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ typedef struct {
TSQueryCapture capture;
} QueryCapture;

typedef struct {
PyObject_HEAD
TSQueryMatch match;
PyObject *captures;
PyObject *pattern_index;
} QueryMatch;

typedef struct {
PyObject_HEAD
TSRange range;
Expand All @@ -87,6 +94,7 @@ typedef struct {
PyTypeObject *query_type;
PyTypeObject *range_type;
PyTypeObject *query_capture_type;
PyTypeObject *query_match_type;
PyTypeObject *capture_eq_capture_type;
PyTypeObject *capture_eq_string_type;
PyTypeObject *capture_match_string_type;
Expand Down Expand Up @@ -1726,6 +1734,32 @@ static PyObject *query_capture_new_internal(ModuleState *state, TSQueryCapture c
return (PyObject *)self;
}

static void match_dealloc(QueryMatch *self) { Py_TYPE(self)->tp_free(self); }

static PyType_Slot query_match_type_slots[] = {
{Py_tp_doc, "A query match"},
{Py_tp_dealloc, match_dealloc},
{0, NULL},
};

static PyType_Spec query_match_type_spec = {
.name = "tree_sitter.QueryMatch",
.basicsize = sizeof(QueryMatch),
.itemsize = 0,
.flags = Py_TPFLAGS_DEFAULT,
.slots = query_match_type_slots,
};

static PyObject *query_match_new_internal(ModuleState *state, TSQueryMatch match) {
QueryMatch *self = (QueryMatch *)state->query_match_type->tp_alloc(state->query_match_type, 0);
if (self != NULL) {
self->match = match;
self->captures = PyList_New(0);
self->pattern_index = 0;
}
return (PyObject *)self;
}

// Text Predicates

static void capture_eq_capture_dealloc(CaptureEqCapture *self) { Py_TYPE(self)->tp_free(self); }
Expand Down Expand Up @@ -1772,7 +1806,7 @@ static PyType_Spec capture_eq_string_type_spec = {

// CaptureMatchString
static PyType_Slot capture_match_string_type_slots[] = {
{Py_tp_doc, "Text predicate of the form #eq? @capture regex"},
{Py_tp_doc, "Text predicate of the form #match? @capture regex"},
{Py_tp_dealloc, capture_match_string_dealloc},
{0, NULL},
};
Expand Down Expand Up @@ -1839,11 +1873,6 @@ static bool capture_match_string_is_instance(PyObject *self) {

// Query

static PyObject *query_matches(Query *self, PyObject *args) {
PyErr_SetString(PyExc_NotImplementedError, "Not Implemented");
return NULL;
}

static Node *node_for_capture_index(ModuleState *state, uint32_t index, TSQueryMatch match,
Tree *tree) {
for (unsigned i = 0; i < match.capture_count; i++) {
Expand Down Expand Up @@ -1956,6 +1985,90 @@ static bool satisfies_text_predicates(Query *query, TSQueryMatch match, Tree *tr
return false;
}

static PyObject *query_matches(Query *self, PyObject *args, PyObject *kwargs) {
ModuleState *state = PyType_GetModuleState(Py_TYPE(self));
char *keywords[] = {
"node", "start_point", "end_point", "start_byte", "end_byte", NULL,
};

Node *node = NULL;
TSPoint start_point = {.row = 0, .column = 0};
TSPoint end_point = {.row = UINT32_MAX, .column = UINT32_MAX};
unsigned start_byte = 0, end_byte = UINT32_MAX;

int ok = PyArg_ParseTupleAndKeywords(args, kwargs, "O|(II)(II)II", keywords, (PyObject **)&node,
&start_point.row, &start_point.column, &end_point.row,
&end_point.column, &start_byte, &end_byte);
if (!ok) {
return NULL;
}

if (!PyObject_IsInstance((PyObject *)node, (PyObject *)state->node_type)) {
PyErr_SetString(PyExc_TypeError, "First argument to captures must be a Node");
return NULL;
}

ts_query_cursor_set_byte_range(state->query_cursor, start_byte, end_byte);
ts_query_cursor_set_point_range(state->query_cursor, start_point, end_point);
ts_query_cursor_exec(state->query_cursor, self->query, node->node);

QueryMatch *match = NULL;
PyObject *result = PyList_New(0);
if (result == NULL) {
goto error;
}
PyObject *captures_for_match = PyList_New(0);

TSQueryMatch _match;
while (ts_query_cursor_next_match(state->query_cursor, &_match)) {
match = (QueryMatch *)query_match_new_internal(state, _match);
if (match == NULL) {
goto error;
}
PyObject *captures_for_match = PyList_New(0);
if (captures_for_match == NULL) {
goto error;
}
for (unsigned i = 0; i < _match.capture_count; i++) {
QueryCapture *capture =
(QueryCapture *)query_capture_new_internal(state, _match.captures[i]);
if (capture == NULL) {
Py_XDECREF(captures_for_match);
goto error;
}
if (satisfies_text_predicates(self, _match, (Tree *)node->tree)) {
PyObject *capture_name =
PyList_GetItem(self->capture_names, capture->capture.index);
PyObject *capture_node =
node_new_internal(state, capture->capture.node, node->tree);
PyObject *item = PyTuple_Pack(2, capture_node, capture_name);
if (item == NULL) {
Py_XDECREF(captures_for_match);
Py_XDECREF(capture_node);
goto error;
}
Py_XDECREF(capture_node);
PyList_Append(captures_for_match, item);
Py_XDECREF(item);
}
Py_XDECREF(capture);
}
PyObject *pattern_index = PyLong_FromLong(_match.pattern_index);
PyObject *tuple_match = PyTuple_Pack(2, pattern_index, captures_for_match);
PyList_Append(result, tuple_match);
Py_XDECREF(tuple_match);
Py_XDECREF(pattern_index);
Py_XDECREF(captures_for_match);
Py_XDECREF(match);
}
return result;

error:
Py_XDECREF(result);
Py_XDECREF(match);
return NULL;
}

static PyObject *query_captures(Query *self, PyObject *args, PyObject *kwargs) {
ModuleState *state = PyType_GetModuleState(Py_TYPE(self));
char *keywords[] = {
Expand Down Expand Up @@ -2029,7 +2142,7 @@ static void query_dealloc(Query *self) {
static PyMethodDef query_methods[] = {
{.ml_name = "matches",
.ml_meth = (PyCFunction)query_matches,
.ml_flags = METH_VARARGS,
.ml_flags = METH_KEYWORDS | METH_VARARGS,
.ml_doc = "matches(node)\n--\n\n\
Get a list of all of the matches within the given node."},
{
Expand Down Expand Up @@ -2835,6 +2948,8 @@ PyMODINIT_FUNC PyInit_binding(void) {
state->range_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &range_type_spec, NULL);
state->query_capture_type =
(PyTypeObject *)PyType_FromModuleAndSpec(module, &query_capture_type_spec, NULL);
state->query_match_type =
(PyTypeObject *)PyType_FromModuleAndSpec(module, &query_match_type_spec, NULL);
state->capture_eq_capture_type =
(PyTypeObject *)PyType_FromModuleAndSpec(module, &capture_eq_capture_type_spec, NULL);
state->capture_eq_string_type =
Expand All @@ -2854,6 +2969,7 @@ PyMODINIT_FUNC PyInit_binding(void) {
(AddObjectRef(module, "Query", (PyObject *)state->query_type) < 0) ||
(AddObjectRef(module, "Range", (PyObject *)state->range_type) < 0) ||
(AddObjectRef(module, "QueryCapture", (PyObject *)state->query_capture_type) < 0) ||
(AddObjectRef(module, "QueryMatch", (PyObject *)state->query_match_type) < 0) ||
(AddObjectRef(module, "CaptureEqCapture", (PyObject *)state->capture_eq_capture_type) <
0) ||
(AddObjectRef(module, "CaptureEqString", (PyObject *)state->capture_eq_string_type) < 0) ||
Expand Down
13 changes: 8 additions & 5 deletions tree_sitter/binding.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,14 @@ class Parser:
class Query:
"""A set of patterns to search for in a syntax tree."""

# Not implemented yet. Return type is wrong
def matches(self, node: Node) -> None:
def matches(
self,
node: Node,
start_point: Optional[Tuple[int, int]] = None,
end_point: Optional[Tuple[int, int]] = None,
start_byte: Optional[int] = None,
end_byte: Optional[int] = None,
) -> List[Tuple[int, List[Tuple[Node, str]]]]:
"""Get a list of all of the matches within the given node."""
...
def captures(
Expand All @@ -361,9 +367,6 @@ class Query:
"""Get a list of all of the captures within the given node."""
...

class QueryCapture:
pass

class LookaheadIterator(Iterable):
def reset(self, language: int, state: int) -> None:
"""Reset the lookahead iterator to a new language and parse state.
Expand Down

0 comments on commit bc59e3c

Please sign in to comment.