Skip to content

Commit

Permalink
refactor: python queries
Browse files Browse the repository at this point in the history
  • Loading branch information
Desdaemon committed May 22, 2024
1 parent 20a554a commit 3272e43
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 70 deletions.
111 changes: 56 additions & 55 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ use tracing_subscriber::EnvFilter;
impl LanguageServer for Backend {
#[instrument(skip_all)]
async fn initialize(&self, params: InitializeParams) -> Result<InitializeResult> {
let _blocker = self.root_setup.block();
let root = params.root_uri.and_then(|uri| uri.to_file_path().ok()).or_else(
#[allow(deprecated)]
|| params.root_path.map(PathBuf::from),
Expand Down Expand Up @@ -182,60 +181,6 @@ impl LanguageServer for Backend {
self.capabilities.pull_diagnostics.store(true, Relaxed);
}

let token = NumberOrString::String("odoo-lsp/postinit".to_string());
let mut progress = None;
if self
.client
.send_request::<WorkDoneProgressCreate>(WorkDoneProgressCreateParams { token: token.clone() })
.await
.is_ok()
{
_ = self
.client
.send_notification::<Progress>(ProgressParams {
token: token.clone(),
value: ProgressParamsValue::WorkDone(WorkDoneProgress::Begin(WorkDoneProgressBegin {
title: "Indexing".to_string(),
..Default::default()
})),
})
.await;
progress = Some((&self.client, token.clone()));
}

self.ensure_nonoverlapping_roots();

for root in self.roots.iter() {
match self.index.add_root(&root, progress.clone(), false).await {
Ok(Some(results)) => {
info!(
target: "initialized",
"{} | {} modules | {} records | {} templates | {} models | {} components | {:.2}s",
root.display(),
results.module_count,
results.record_count,
results.template_count,
results.model_count,
results.component_count,
results.elapsed.as_secs_f64()
);
}
Err(err) => {
error!("could not add root {}:\n{err}", root.display());
}
_ => {}
}
}

if progress.is_some() {
_ = self
.client
.send_notification::<Progress>(ProgressParams {
token,
value: ProgressParamsValue::WorkDone(WorkDoneProgress::End(Default::default())),
})
.await;
}

Ok(InitializeResult {
server_info: None,
Expand Down Expand Up @@ -316,6 +261,62 @@ impl LanguageServer for Backend {
}])
.await;
}

let _blocker = self.root_setup.block();
let token = NumberOrString::String("odoo-lsp/postinit".to_string());
let mut progress = None;
if self
.client
.send_request::<WorkDoneProgressCreate>(WorkDoneProgressCreateParams { token: token.clone() })
.await
.is_ok()
{
_ = self
.client
.send_notification::<Progress>(ProgressParams {
token: token.clone(),
value: ProgressParamsValue::WorkDone(WorkDoneProgress::Begin(WorkDoneProgressBegin {
title: "Indexing".to_string(),
..Default::default()
})),
})
.await;
progress = Some((&self.client, token.clone()));
}

self.ensure_nonoverlapping_roots();

for root in self.roots.iter() {
match self.index.add_root(&root, progress.clone(), false).await {
Ok(Some(results)) => {
info!(
target: "initialized",
"{} | {} modules | {} records | {} templates | {} models | {} components | {:.2}s",
root.display(),
results.module_count,
results.record_count,
results.template_count,
results.model_count,
results.component_count,
results.elapsed.as_secs_f64()
);
}
Err(err) => {
error!("could not add root {}:\n{err}", root.display());
}
_ => {}
}
}

if progress.is_some() {
_ = self
.client
.send_notification::<Progress>(ProgressParams {
token,
value: ProgressParamsValue::WorkDone(WorkDoneProgress::End(Default::default())),
})
.await;
}
}
#[instrument(skip_all, ret, fields(uri=params.text_document.uri.path()))]
async fn did_open(&self, params: DidOpenTextDocumentParams) {
Expand Down
38 changes: 31 additions & 7 deletions src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,20 @@ query! {

(class_definition
(block [
(function_definition (block) @SCOPE)
(decorated_definition (function_definition (block) @SCOPE)) ]))
(function_definition) @SCOPE
(decorated_definition
(decorator
(call
(attribute (identifier) @_api (#eq? @_api "api") (identifier) @_depends (#eq? @_depends "depends"))
(argument_list ((string) @MAPPED ","?)*)))
(function_definition) @SCOPE) ]))

(class_definition
(block
(decorated_definition
(decorator (_) @_)
(function_definition) @SCOPE)*)
(#not-match? @_ "^api.depends"))
}

/// (module (_)*)
Expand Down Expand Up @@ -1206,17 +1218,22 @@ baz = fields.Many2many(comodel_name='named')
class Foo(models.AbstractModel):
_name = 'foo'
_inherit = ['inherit_foo', 'inherit_bar']
foo = fields.Char(related='related')
@api.constrains('mapped', 'meh')
def foo(self):
what = self.sudo().mapped('ha.ha')
def bar(self):
pass
foo = fields.Foo()
@api.depends('mapped2')
@api.depends_context('uid')
@api.depends('mapped2', 'mapped3')
def another(self):
pass
def no_decorators(self):
pass
"#;
let ast = parser.parse(&contents[..], None).unwrap();
let query = PyCompletions::query();
Expand All @@ -1225,13 +1242,20 @@ class Foo(models.AbstractModel):
&["_name", "'foo'"],
&["_inherit", "'inherit_foo'", "'inherit_bar'"],
&["foo", "fields", "related", "'related'"],
// api.constrains('mapped', 'meh')
&["api", "constrains", "'mapped'"],
&["api", "constrains", "'meh'"],
&["<scope>"],
// scope detection with no .depends
// note that it goes later
&["self.sudo()", "mapped", "'ha.ha'"],
&["<scope>"],
&["api.constrains('mapped', 'meh')", "<scope>"],
&["foo", "fields"],
// scope detection with both .depends and non-.depends
// first, each of the original MAPPED rules are triggered
&["api", "depends", "'mapped2'"],
&["api", "depends", "'mapped3'"],
&["api", "depends", "'mapped2'", "'mapped3'", "<scope>"],
// no decorators
&["<scope>"],
];
let actual = cursor
Expand Down
2 changes: 1 addition & 1 deletion testing/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
odoocmd = [lsp_devtools, "agent", "--", f"{__dirname}/../target/debug/odoo-lsp"]
else:
odoocmd = [f"{__dirname}/../target/debug/odoo-lsp"]
ODOO_ENV = {"RUST_LOG": "info,odoo_lsp=trace"}
ODOO_ENV = {"RUST_LOG": "info,odoo_lsp=trace", "ODOO_LSP_LOG": "1"}


@pytest.fixture
Expand Down
40 changes: 37 additions & 3 deletions testing/fixtures/basic/foo/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,19 @@ class Foo(Model):
_name = "foo"

bar = fields.Char()
foo_m2o = fields.Many2one("foo")
# ^complete bar derived.bar foo foob
foo_o2m = fields.One2many("foo")
# ^complete bar derived.bar foo foob
foo_m2m = fields.Many2many("foo")
# ^complete bar derived.bar foo foob

def completions(self):
self.env["bar"]
# ^complete bar derived.bar foo foob
for foo in self:
foo.
# ^complete bar
# ^complete bar foo_m2m foo_m2o foo_o2m

def diagnostics(self):
self.foo
Expand All @@ -19,13 +25,41 @@ def diagnostics(self):
# ^diag Model `foo` has no field `foo`
self.env["fo"]
# ^diag `fo` is not a valid model name
self._context, self.pool, self.env


# TODO: More diagnostics for depends and compute
class Foob(Model):
_name = "foob"
_inherit = "bar"
# ^complete bar derived.bar foo foob

foo_id = fields.Many2one("foo")
barb = fields.Char(related='foo_id.bar')
# ^complete bar
# ^complete bar derived.bar foo foob
barb = fields.Char(related='foo_id.')
# ^complete bar foo_m2m foo_m2o foo_o2m
hoeh = fields.Char(compute="_non_existent_method")

@api.depends("foo_id")
# ^complete barb foo_id hoeh
@api.constrains("foo_id.wah")
# ^diag Dotted access is not supported in this context
@api.onchange("foo_id.wah")
# ^diag Dotted access is not supported in this context
def handler(self):
self.create({
"foo_id"
#^complete barb foo_id hoeh
})
self.create({
"foo_id": ...
#^complete barb foo_id hoeh
})

@api.depends("barb")
def missing_depends(self):
for record in self:
record.barb = bool(record.foo_id)

class NonModel:
__slots__ = ("foo", "bar")
Expand Down
20 changes: 16 additions & 4 deletions testing/fixtures/basic/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from pytest_lsp import LanguageClient
from tree_sitter import Parser, Query, Tree
from deepdiff import DeepDiff # type: ignore
from deepdiff.diff import DiffLevel
from deepdiff.operator import BaseOperator
from tree_sitter import Language
import tree_sitter_python as tspython
from lsprotocol.types import (
Expand Down Expand Up @@ -39,6 +41,16 @@ def __init__(self):
self.diag = []
self.complete = []

class PositionOperator(BaseOperator):
def give_up_diffing(self, level: DiffLevel, diff_instance: DeepDiff) -> bool:
if isinstance(level.t1, Position) and isinstance(level.t2, Position):
if level.t1 != level.t2:
diff_instance.custom_report_result('values_changed', level) # type: ignore
return True
return False

def inc(position: Position):
return Position(position.line + 1, position.character + 1)

@pytest.mark.asyncio(scope="module")
async def test_python(client: LanguageClient, rootdir: str):
Expand Down Expand Up @@ -81,11 +93,11 @@ async def test_python(client: LanguageClient, rootdir: str):
)
await client.wait_for_notification("textDocument/publishDiagnostics")
actual_diagnostics = list(splay_diag(client.diagnostics[file.as_uri()]))
if diff := DeepDiff(expected[file].diag, actual_diagnostics):
if diff := DeepDiff(expected[file].diag, actual_diagnostics, custom_operators=[PositionOperator(types=[Position])], ignore_order=True):
for extra in diff.pop("iterable_item_added", {}).values(): # type: ignore
unexpected.append(f"diag: extra {extra}\n at {file}")
unexpected.append(f"diag: extra {extra}\n at {file}:{inc(extra[0])}") # type: ignore
for missing in diff.pop("iterable_item_removed", {}).values(): # type: ignore
unexpected.append(f"diag: missing {missing}\n at {file}")
unexpected.append(f"diag: missing {missing}\n at {file}:{inc(missing[0])}") # type: ignore
for mismatch in diff.pop("values_changed", {}).values(): # type: ignore
unexpected.append(
f"diag: expected={mismatch['old_value']!r} actual={mismatch['new_value']!r}\n at {file}"
Expand All @@ -112,7 +124,7 @@ async def test_python(client: LanguageClient, rootdir: str):
else:
node_text = ""
unexpected.append(
f"complete: actual={' '.join(actual)}\n at {file}:{pos}\n{' ' * node.start_point.column}{node_text}"
f"complete: actual={' '.join(actual)}\n at {file}:{inc(pos)}\n{' ' * node.start_point.column}{node_text}"
)
unexpected_len = len(unexpected)
assert not unexpected_len, "\n".join(unexpected)
Expand Down

0 comments on commit 3272e43

Please sign in to comment.