diff --git a/tests/test_tree_sitter.py b/tests/test_tree_sitter.py index 51b9356..06faec9 100644 --- a/tests/test_tree_sitter.py +++ b/tests/test_tree_sitter.py @@ -3,8 +3,8 @@ from typing import List, Optional, Tuple from unittest import TestCase -from tree_sitter import Language, Parser, Tree -from tree_sitter.binding import LookaheadIterator, Node, Range +from tree_sitter import (Language, LookaheadIterator, Node, Parser, Query, + Range, Tree) LIB_PATH = path.join("build", "languages.so") @@ -1259,6 +1259,109 @@ def test_errors(self): PYTHON.query("(list))") PYTHON.query("(function_definition)") + def collect_matches( + self, + matches: List[Tuple[int, List[Tuple[Node, str]]]], + ) -> List[Tuple[int, List[Tuple[str, str]]]]: + return [(m[0], self.format_captures(m[1])) for m in matches] + + def format_captures( + self, + captures: List[Tuple[Node, str]], + ) -> List[Tuple[str, str]]: + return [(capture[1], capture[0].text.decode("utf-8")) for capture in captures] + + def assert_query_matches( + self, + language: Language, + query: Query, + source: bytes, + expected: List[Tuple[int, List[Tuple[str, str]]]] + ): + parser = Parser() + parser.set_language(language) + tree = parser.parse(source) + matches = query.matches(tree.root_node) + matches = self.collect_matches(matches) + self.assertEqual(matches, expected) + + def test_matches_with_simple_pattern(self): + query = JAVASCRIPT.query("(function_declaration name: (identifier) @fn-name)") + self.assert_query_matches( + JAVASCRIPT, + query, + b"function one() { two(); function three() {} }", + [(0, [('fn-name', 'one')]), (0, [('fn-name', 'three')])] + ) + + def test_matches_with_multiple_on_same_root(self): + query = JAVASCRIPT.query(""" + (class_declaration + name: (identifier) @the-class-name + (class_body + (method_definition + name: (property_identifier) @the-method-name))) + """) + self.assert_query_matches( + JAVASCRIPT, + query, + b""" + class Person { + // the constructor + constructor(name) { this.name = name; } + + // the getter + getFullName() { return this.name; } + } + """, + [ + (0, [("the-class-name", "Person"), ("the-method-name", "constructor")]), + (0, [("the-class-name", "Person"), ("the-method-name", "getFullName")]), + ] + ) + + def test_matches_with_multiple_patterns_different_roots(self): + query = JAVASCRIPT.query(""" + (function_declaration name:(identifier) @fn-def) + (call_expression function:(identifier) @fn-ref) + """) + self.assert_query_matches( + JAVASCRIPT, + query, + b""" + function f1() { + f2(f3()); + } + """, + [(0, [("fn-def", "f1")]), (1, [("fn-ref", "f2")]), (1, [("fn-ref", "f3")])] + ) + + def test_matches_with_nesting_and_no_fields(self): + query = JAVASCRIPT.query(""" + (array + (array + (identifier) @x1 + (identifier) @x2)) + """) + self.assert_query_matches( + JAVASCRIPT, + query, + b""" + [[a]]; + [[c, d], [e, f, g, h]]; + [[h], [i]]; + """, + [ + (0, [("x1", "c"), ("x2", "d")]), + (0, [("x1", "e"), ("x2", "f")]), + (0, [("x1", "e"), ("x2", "g")]), + (0, [("x1", "f"), ("x2", "g")]), + (0, [("x1", "e"), ("x2", "h")]), + (0, [("x1", "f"), ("x2", "h")]), + (0, [("x1", "g"), ("x2", "h")]), + ] + ) + def test_captures(self): parser = Parser() parser.set_language(PYTHON) diff --git a/tree_sitter/__init__.pyi b/tree_sitter/__init__.pyi index 9ef523c..e80e2a3 100644 --- a/tree_sitter/__init__.pyi +++ b/tree_sitter/__init__.pyi @@ -8,7 +8,6 @@ from tree_sitter.binding import \ from tree_sitter.binding import Node as Node from tree_sitter.binding import Parser as Parser from tree_sitter.binding import Query as Query -from tree_sitter.binding import QueryCapture as QueryCapture from tree_sitter.binding import Range as Range from tree_sitter.binding import Tree as Tree from tree_sitter.binding import TreeCursor as TreeCursor diff --git a/tree_sitter/binding.c b/tree_sitter/binding.c index dff81b7..99a160b 100644 --- a/tree_sitter/binding.c +++ b/tree_sitter/binding.c @@ -63,6 +63,13 @@ typedef struct { TSQueryCapture capture; } QueryCapture; +typedef struct { + PyObject_HEAD + TSQueryMatch match; + PyObject *captures; + PyObject *pattern_index; +} QueryMatch; + typedef struct { PyObject_HEAD TSRange range; @@ -87,6 +94,7 @@ typedef struct { PyTypeObject *query_type; PyTypeObject *range_type; PyTypeObject *query_capture_type; + PyTypeObject *query_match_type; PyTypeObject *capture_eq_capture_type; PyTypeObject *capture_eq_string_type; PyTypeObject *capture_match_string_type; @@ -1726,6 +1734,32 @@ static PyObject *query_capture_new_internal(ModuleState *state, TSQueryCapture c return (PyObject *)self; } +static void match_dealloc(QueryMatch *self) { Py_TYPE(self)->tp_free(self); } + +static PyType_Slot query_match_type_slots[] = { + {Py_tp_doc, "A query match"}, + {Py_tp_dealloc, match_dealloc}, + {0, NULL}, +}; + +static PyType_Spec query_match_type_spec = { + .name = "tree_sitter.QueryMatch", + .basicsize = sizeof(QueryMatch), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT, + .slots = query_match_type_slots, +}; + +static PyObject *query_match_new_internal(ModuleState *state, TSQueryMatch match) { + QueryMatch *self = (QueryMatch *)state->query_match_type->tp_alloc(state->query_match_type, 0); + if (self != NULL) { + self->match = match; + self->captures = PyList_New(0); + self->pattern_index = 0; + } + return (PyObject *)self; +} + // Text Predicates static void capture_eq_capture_dealloc(CaptureEqCapture *self) { Py_TYPE(self)->tp_free(self); } @@ -1772,7 +1806,7 @@ static PyType_Spec capture_eq_string_type_spec = { // CaptureMatchString static PyType_Slot capture_match_string_type_slots[] = { - {Py_tp_doc, "Text predicate of the form #eq? @capture regex"}, + {Py_tp_doc, "Text predicate of the form #match? @capture regex"}, {Py_tp_dealloc, capture_match_string_dealloc}, {0, NULL}, }; @@ -1839,11 +1873,6 @@ static bool capture_match_string_is_instance(PyObject *self) { // Query -static PyObject *query_matches(Query *self, PyObject *args) { - PyErr_SetString(PyExc_NotImplementedError, "Not Implemented"); - return NULL; -} - static Node *node_for_capture_index(ModuleState *state, uint32_t index, TSQueryMatch match, Tree *tree) { for (unsigned i = 0; i < match.capture_count; i++) { @@ -1956,6 +1985,90 @@ static bool satisfies_text_predicates(Query *query, TSQueryMatch match, Tree *tr return false; } +static PyObject *query_matches(Query *self, PyObject *args, PyObject *kwargs) { + ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); + char *keywords[] = { + "node", "start_point", "end_point", "start_byte", "end_byte", NULL, + }; + + Node *node = NULL; + TSPoint start_point = {.row = 0, .column = 0}; + TSPoint end_point = {.row = UINT32_MAX, .column = UINT32_MAX}; + unsigned start_byte = 0, end_byte = UINT32_MAX; + + int ok = PyArg_ParseTupleAndKeywords(args, kwargs, "O|(II)(II)II", keywords, (PyObject **)&node, + &start_point.row, &start_point.column, &end_point.row, + &end_point.column, &start_byte, &end_byte); + if (!ok) { + return NULL; + } + + if (!PyObject_IsInstance((PyObject *)node, (PyObject *)state->node_type)) { + PyErr_SetString(PyExc_TypeError, "First argument to captures must be a Node"); + return NULL; + } + + ts_query_cursor_set_byte_range(state->query_cursor, start_byte, end_byte); + ts_query_cursor_set_point_range(state->query_cursor, start_point, end_point); + ts_query_cursor_exec(state->query_cursor, self->query, node->node); + + QueryMatch *match = NULL; + PyObject *result = PyList_New(0); + if (result == NULL) { + goto error; + } + PyObject *captures_for_match = PyList_New(0); + + TSQueryMatch _match; + while (ts_query_cursor_next_match(state->query_cursor, &_match)) { + match = (QueryMatch *)query_match_new_internal(state, _match); + if (match == NULL) { + goto error; + } + PyObject *captures_for_match = PyList_New(0); + if (captures_for_match == NULL) { + goto error; + } + for (unsigned i = 0; i < _match.capture_count; i++) { + QueryCapture *capture = + (QueryCapture *)query_capture_new_internal(state, _match.captures[i]); + if (capture == NULL) { + Py_XDECREF(captures_for_match); + goto error; + } + if (satisfies_text_predicates(self, _match, (Tree *)node->tree)) { + PyObject *capture_name = + PyList_GetItem(self->capture_names, capture->capture.index); + PyObject *capture_node = + node_new_internal(state, capture->capture.node, node->tree); + PyObject *item = PyTuple_Pack(2, capture_node, capture_name); + if (item == NULL) { + Py_XDECREF(captures_for_match); + Py_XDECREF(capture_node); + goto error; + } + Py_XDECREF(capture_node); + PyList_Append(captures_for_match, item); + Py_XDECREF(item); + } + Py_XDECREF(capture); + } + PyObject *pattern_index = PyLong_FromLong(_match.pattern_index); + PyObject *tuple_match = PyTuple_Pack(2, pattern_index, captures_for_match); + PyList_Append(result, tuple_match); + Py_XDECREF(tuple_match); + Py_XDECREF(pattern_index); + Py_XDECREF(captures_for_match); + Py_XDECREF(match); + } + return result; + +error: + Py_XDECREF(result); + Py_XDECREF(match); + return NULL; +} + static PyObject *query_captures(Query *self, PyObject *args, PyObject *kwargs) { ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); char *keywords[] = { @@ -2029,7 +2142,7 @@ static void query_dealloc(Query *self) { static PyMethodDef query_methods[] = { {.ml_name = "matches", .ml_meth = (PyCFunction)query_matches, - .ml_flags = METH_VARARGS, + .ml_flags = METH_KEYWORDS | METH_VARARGS, .ml_doc = "matches(node)\n--\n\n\ Get a list of all of the matches within the given node."}, { @@ -2835,6 +2948,8 @@ PyMODINIT_FUNC PyInit_binding(void) { state->range_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &range_type_spec, NULL); state->query_capture_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &query_capture_type_spec, NULL); + state->query_match_type = + (PyTypeObject *)PyType_FromModuleAndSpec(module, &query_match_type_spec, NULL); state->capture_eq_capture_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &capture_eq_capture_type_spec, NULL); state->capture_eq_string_type = @@ -2854,6 +2969,7 @@ PyMODINIT_FUNC PyInit_binding(void) { (AddObjectRef(module, "Query", (PyObject *)state->query_type) < 0) || (AddObjectRef(module, "Range", (PyObject *)state->range_type) < 0) || (AddObjectRef(module, "QueryCapture", (PyObject *)state->query_capture_type) < 0) || + (AddObjectRef(module, "QueryMatch", (PyObject *)state->query_match_type) < 0) || (AddObjectRef(module, "CaptureEqCapture", (PyObject *)state->capture_eq_capture_type) < 0) || (AddObjectRef(module, "CaptureEqString", (PyObject *)state->capture_eq_string_type) < 0) || diff --git a/tree_sitter/binding.pyi b/tree_sitter/binding.pyi index 91b79d0..a8d008a 100644 --- a/tree_sitter/binding.pyi +++ b/tree_sitter/binding.pyi @@ -346,8 +346,14 @@ class Parser: class Query: """A set of patterns to search for in a syntax tree.""" - # Not implemented yet. Return type is wrong - def matches(self, node: Node) -> None: + def matches( + self, + node: Node, + start_point: Optional[Tuple[int, int]] = None, + end_point: Optional[Tuple[int, int]] = None, + start_byte: Optional[int] = None, + end_byte: Optional[int] = None, + ) -> List[Tuple[int, List[Tuple[Node, str]]]]: """Get a list of all of the matches within the given node.""" ... def captures( @@ -361,9 +367,6 @@ class Query: """Get a list of all of the captures within the given node.""" ... -class QueryCapture: - pass - class LookaheadIterator(Iterable): def reset(self, language: int, state: int) -> None: """Reset the lookahead iterator to a new language and parse state.