Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] Token system to avoid cross-module query dependencies #16275

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 51 additions & 24 deletions crates/red_knot_python_semantic/src/semantic_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ mod tests {

let use_def = use_def_map(&db, scope);
let binding = use_def.first_public_binding(foo).unwrap();
assert!(matches!(binding.kind(&db), DefinitionKind::Import(_)));
assert!(matches!(binding.kind(&db, file), DefinitionKind::Import(_)));
}

#[test]
Expand Down Expand Up @@ -533,7 +533,10 @@ mod tests {
.expect("symbol to exist"),
)
.unwrap();
assert!(matches!(binding.kind(&db), DefinitionKind::ImportFrom(_)));
assert!(matches!(
binding.kind(&db, file),
DefinitionKind::ImportFrom(_)
));
}

#[test]
Expand All @@ -553,7 +556,10 @@ mod tests {
let binding = use_def
.first_public_binding(global_table.symbol_id_by_name("x").expect("symbol exists"))
.unwrap();
assert!(matches!(binding.kind(&db), DefinitionKind::Assignment(_)));
assert!(matches!(
binding.kind(&db, file),
DefinitionKind::Assignment(_)
));
}

#[test]
Expand All @@ -570,7 +576,7 @@ mod tests {
.unwrap();

assert!(matches!(
binding.kind(&db),
binding.kind(&db, file),
DefinitionKind::AugmentedAssignment(_)
));
}
Expand Down Expand Up @@ -606,7 +612,10 @@ y = 2
let binding = use_def
.first_public_binding(class_table.symbol_id_by_name("x").expect("symbol exists"))
.unwrap();
assert!(matches!(binding.kind(&db), DefinitionKind::Assignment(_)));
assert!(matches!(
binding.kind(&db, file),
DefinitionKind::Assignment(_)
));
}

#[test]
Expand Down Expand Up @@ -643,7 +652,10 @@ y = 2
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(binding.kind(&db), DefinitionKind::Assignment(_)));
assert!(matches!(
binding.kind(&db, file),
DefinitionKind::Assignment(_)
));
}

#[test]
Expand Down Expand Up @@ -682,7 +694,10 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(binding.kind(&db), DefinitionKind::Parameter(_)));
assert!(matches!(
binding.kind(&db, file),
DefinitionKind::Parameter(_)
));
}
let args_binding = use_def
.first_public_binding(
Expand All @@ -692,7 +707,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
)
.unwrap();
assert!(matches!(
args_binding.kind(&db),
args_binding.kind(&db, file),
DefinitionKind::VariadicPositionalParameter(_)
));
let kwargs_binding = use_def
Expand All @@ -703,7 +718,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
)
.unwrap();
assert!(matches!(
kwargs_binding.kind(&db),
kwargs_binding.kind(&db, file),
DefinitionKind::VariadicKeywordParameter(_)
));
}
Expand Down Expand Up @@ -735,7 +750,10 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
let binding = use_def
.first_public_binding(lambda_table.symbol_id_by_name(name).expect("symbol exists"))
.unwrap();
assert!(matches!(binding.kind(&db), DefinitionKind::Parameter(_)));
assert!(matches!(
binding.kind(&db, file),
DefinitionKind::Parameter(_)
));
}
let args_binding = use_def
.first_public_binding(
Expand All @@ -745,7 +763,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
)
.unwrap();
assert!(matches!(
args_binding.kind(&db),
args_binding.kind(&db, file),
DefinitionKind::VariadicPositionalParameter(_)
));
let kwargs_binding = use_def
Expand All @@ -756,7 +774,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
)
.unwrap();
assert!(matches!(
kwargs_binding.kind(&db),
kwargs_binding.kind(&db, file),
DefinitionKind::VariadicKeywordParameter(_)
));
}
Expand Down Expand Up @@ -803,7 +821,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
)
.unwrap();
assert!(matches!(
binding.kind(&db),
binding.kind(&db, file),
DefinitionKind::Comprehension(_)
));
}
Expand Down Expand Up @@ -843,7 +861,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
element.scoped_use_id(&db, comprehension_scope_id.to_scope_id(&db, file));

let binding = use_def.first_binding_at_use(element_use_id).unwrap();
let DefinitionKind::Comprehension(comprehension) = binding.kind(&db) else {
let DefinitionKind::Comprehension(comprehension) = binding.kind(&db, file) else {
panic!("expected generator definition")
};
let target = comprehension.target();
Expand Down Expand Up @@ -925,7 +943,10 @@ with item1 as x, item2 as y:
let binding = use_def
.first_public_binding(global_table.symbol_id_by_name(name).expect("symbol exists"))
.expect("Expected with item definition for {name}");
assert!(matches!(binding.kind(&db), DefinitionKind::WithItem(_)));
assert!(matches!(
binding.kind(&db, file),
DefinitionKind::WithItem(_)
));
}
}

Expand All @@ -948,7 +969,10 @@ with context() as (x, y):
let binding = use_def
.first_public_binding(global_table.symbol_id_by_name(name).expect("symbol exists"))
.expect("Expected with item definition for {name}");
assert!(matches!(binding.kind(&db), DefinitionKind::WithItem(_)));
assert!(matches!(
binding.kind(&db, file),
DefinitionKind::WithItem(_)
));
}
}

Expand Down Expand Up @@ -992,7 +1016,10 @@ def func():
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(binding.kind(&db), DefinitionKind::Function(_)));
assert!(matches!(
binding.kind(&db, file),
DefinitionKind::Function(_)
));
}

#[test]
Expand Down Expand Up @@ -1093,7 +1120,7 @@ class C[T]:
let x_use_id = x_use_expr_name.scoped_use_id(&db, scope);
let use_def = use_def_map(&db, scope);
let binding = use_def.first_binding_at_use(x_use_id).unwrap();
let DefinitionKind::Assignment(assignment) = binding.kind(&db) else {
let DefinitionKind::Assignment(assignment) = binding.kind(&db, file) else {
panic!("should be an assignment definition")
};
let ast::Expr::NumberLiteral(ast::ExprNumberLiteral {
Expand Down Expand Up @@ -1226,7 +1253,7 @@ match subject:
let binding = use_def
.first_public_binding(global_table.symbol_id_by_name(name).expect("symbol exists"))
.expect("Expected with item definition for {name}");
if let DefinitionKind::MatchPattern(pattern) = binding.kind(&db) {
if let DefinitionKind::MatchPattern(pattern) = binding.kind(&db, file) {
assert_eq!(pattern.index(), expected_index);
} else {
panic!("Expected match pattern definition for {name}");
Expand Down Expand Up @@ -1256,7 +1283,7 @@ match 1:
let binding = use_def
.first_public_binding(global_table.symbol_id_by_name(name).expect("symbol exists"))
.expect("Expected with item definition for {name}");
if let DefinitionKind::MatchPattern(pattern) = binding.kind(&db) {
if let DefinitionKind::MatchPattern(pattern) = binding.kind(&db, file) {
assert_eq!(pattern.index(), expected_index);
} else {
panic!("Expected match pattern definition for {name}");
Expand All @@ -1277,7 +1304,7 @@ match 1:
.first_public_binding(global_table.symbol_id_by_name("x").unwrap())
.unwrap();

assert!(matches!(binding.kind(&db), DefinitionKind::For(_)));
assert!(matches!(binding.kind(&db, file), DefinitionKind::For(_)));
}

#[test]
Expand All @@ -1296,8 +1323,8 @@ match 1:
.first_public_binding(global_table.symbol_id_by_name("y").unwrap())
.unwrap();

assert!(matches!(x_binding.kind(&db), DefinitionKind::For(_)));
assert!(matches!(y_binding.kind(&db), DefinitionKind::For(_)));
assert!(matches!(x_binding.kind(&db, file), DefinitionKind::For(_)));
assert!(matches!(y_binding.kind(&db, file), DefinitionKind::For(_)));
}

#[test]
Expand All @@ -1313,6 +1340,6 @@ match 1:
.first_public_binding(global_table.symbol_id_by_name("a").unwrap())
.unwrap();

assert!(matches!(binding.kind(&db), DefinitionKind::For(_)));
assert!(matches!(binding.kind(&db, file), DefinitionKind::For(_)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub struct Definition<'db> {
#[no_eq]
#[return_ref]
#[tracked]
pub(crate) kind: DefinitionKind<'db>,
kind_inner: DefinitionKind<'db>,

/// This is a dedicated field to avoid accessing `kind` to compute this value.
pub(crate) is_reexported: bool,
Expand All @@ -50,6 +50,18 @@ impl<'db> Definition<'db> {
pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.file_scope(db).to_scope_id(db, self.file(db))
}

/// Returns the definition's kind which gives access to its AST.
///
/// `query_file` is the file for which the current query performs type inference.
/// It acts as a token of prove that we aren't accessing an AST node from a different file
/// than in which the current enclosing Salsa query (which would lead to cross-file dependencies).
Comment on lines +57 to +58
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// It acts as a token of prove that we aren't accessing an AST node from a different file
/// than in which the current enclosing Salsa query (which would lead to cross-file dependencies).
/// It acts as a token of proof that we aren't accessing an AST node from a different file
/// to the one from which the current enclosing Salsa query is being called.
/// Doing so would lead to cross-file dependencies, hurting incremental computation.

#[inline]
pub(crate) fn kind(self, db: &'db dyn Db, query_file: File) -> &'db DefinitionKind<'db> {
debug_assert_eq!(query_file, self.scope(db).file(db));

self.kind_inner(db)
}
}

#[derive(Copy, Clone, Debug)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub(crate) struct Expression<'db> {
#[no_eq]
#[tracked]
#[return_ref]
pub(crate) node_ref: AstNodeRef<ast::Expr>,
node_ref_inner: AstNodeRef<ast::Expr>,

/// Should this expression be inferred as a normal expression or a type expression?
pub(crate) kind: ExpressionKind,
Expand All @@ -51,6 +51,17 @@ pub(crate) struct Expression<'db> {
}

impl<'db> Expression<'db> {
/// Returns a reference to the expression's AST node.
///
/// `query_file` is the file for which the current query performs type inference.
/// It acts as a token of prove that we aren't accessing an AST node from a different file
/// than in which the current enclosing Salsa query (which would lead to cross-file dependencies).
Comment on lines +57 to +58
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// It acts as a token of prove that we aren't accessing an AST node from a different file
/// than in which the current enclosing Salsa query (which would lead to cross-file dependencies).
/// It acts as a token of proof that we aren't accessing an AST node from a different file
/// to the one from which the current enclosing Salsa query is being called.
/// Doing so would lead to cross-file dependencies, hurting incremental computation.

#[inline]
pub(crate) fn node_ref(self, db: &'db dyn Db, query_file: File) -> &'db AstNodeRef<ast::Expr> {
debug_assert_eq!(self.file(db), query_file);
self.node_ref_inner(db)
}

pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.file_scope(db).to_scope_id(db, self.file(db))
}
Expand Down
23 changes: 18 additions & 5 deletions crates/red_knot_python_semantic/src/semantic_index/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,21 +110,34 @@ pub struct ScopeId<'db> {
}

impl<'db> ScopeId<'db> {
pub(crate) fn is_function_like(self, db: &'db dyn Db) -> bool {
self.node(db).scope_kind().is_function_like()
pub(crate) fn is_function_like(self, db: &'db dyn Db, query_file: File) -> bool {
self.node(db, query_file).scope_kind().is_function_like()
}

pub(crate) fn node(self, db: &dyn Db) -> &NodeWithScopeKind {
#[inline]
pub(crate) fn node(self, db: &dyn Db, query_file: File) -> &NodeWithScopeKind {
debug_assert_eq!(self.file(db), query_file);
self.node_unchecked(db)
}

/// Returns the scope's node without checking if the query's file matches the scope's file
/// (which is desired to avoid cross-module query dependencies).
///
/// Use this method when in situations where it's okay to add a cross-module dependency.
/// For example, when emitting diagnostics.
#[inline]
pub(crate) fn node_unchecked(self, db: &dyn Db) -> &NodeWithScopeKind {
self.scope(db).node()
}

pub(crate) fn scope(self, db: &dyn Db) -> &Scope {
fn scope(self, db: &dyn Db) -> &Scope {
semantic_index(db, self.file(db)).scope(self.file_scope_id(db))
}

#[cfg(test)]
pub(crate) fn name(self, db: &'db dyn Db) -> &'db str {
match self.node(db) {
// Use `self.node` if this ever becomes a non-testing function.
match self.node_unchecked(db) {
NodeWithScopeKind::Module => "<module>",
NodeWithScopeKind::Class(class) | NodeWithScopeKind::ClassTypeParameters(class) => {
class.name.as_str()
Expand Down
Loading
Loading