diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index a2c766cccc570..dc203c9343ba1 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -492,7 +492,7 @@ mod tests { let use_def = use_def_map(&db, scope); let binding = use_def.first_public_binding(foo).unwrap(); - assert!(matches!(binding.kind(&db), DefinitionKind::Import(_))); + assert!(matches!(binding.kind(&db, file), DefinitionKind::Import(_))); } #[test] @@ -533,7 +533,10 @@ mod tests { .expect("symbol to exist"), ) .unwrap(); - assert!(matches!(binding.kind(&db), DefinitionKind::ImportFrom(_))); + assert!(matches!( + binding.kind(&db, file), + DefinitionKind::ImportFrom(_) + )); } #[test] @@ -553,7 +556,10 @@ mod tests { let binding = use_def .first_public_binding(global_table.symbol_id_by_name("x").expect("symbol exists")) .unwrap(); - assert!(matches!(binding.kind(&db), DefinitionKind::Assignment(_))); + assert!(matches!( + binding.kind(&db, file), + DefinitionKind::Assignment(_) + )); } #[test] @@ -570,7 +576,7 @@ mod tests { .unwrap(); assert!(matches!( - binding.kind(&db), + binding.kind(&db, file), DefinitionKind::AugmentedAssignment(_) )); } @@ -606,7 +612,10 @@ y = 2 let binding = use_def .first_public_binding(class_table.symbol_id_by_name("x").expect("symbol exists")) .unwrap(); - assert!(matches!(binding.kind(&db), DefinitionKind::Assignment(_))); + assert!(matches!( + binding.kind(&db, file), + DefinitionKind::Assignment(_) + )); } #[test] @@ -643,7 +652,10 @@ y = 2 .expect("symbol exists"), ) .unwrap(); - assert!(matches!(binding.kind(&db), DefinitionKind::Assignment(_))); + assert!(matches!( + binding.kind(&db, file), + DefinitionKind::Assignment(_) + )); } #[test] @@ -682,7 +694,10 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): .expect("symbol exists"), ) .unwrap(); - assert!(matches!(binding.kind(&db), DefinitionKind::Parameter(_))); + assert!(matches!( + binding.kind(&db, file), + DefinitionKind::Parameter(_) + )); } let args_binding = use_def .first_public_binding( @@ -692,7 +707,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): ) .unwrap(); assert!(matches!( - args_binding.kind(&db), + args_binding.kind(&db, file), DefinitionKind::VariadicPositionalParameter(_) )); let kwargs_binding = use_def @@ -703,7 +718,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): ) .unwrap(); assert!(matches!( - kwargs_binding.kind(&db), + kwargs_binding.kind(&db, file), DefinitionKind::VariadicKeywordParameter(_) )); } @@ -735,7 +750,10 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): let binding = use_def .first_public_binding(lambda_table.symbol_id_by_name(name).expect("symbol exists")) .unwrap(); - assert!(matches!(binding.kind(&db), DefinitionKind::Parameter(_))); + assert!(matches!( + binding.kind(&db, file), + DefinitionKind::Parameter(_) + )); } let args_binding = use_def .first_public_binding( @@ -745,7 +763,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): ) .unwrap(); assert!(matches!( - args_binding.kind(&db), + args_binding.kind(&db, file), DefinitionKind::VariadicPositionalParameter(_) )); let kwargs_binding = use_def @@ -756,7 +774,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): ) .unwrap(); assert!(matches!( - kwargs_binding.kind(&db), + kwargs_binding.kind(&db, file), DefinitionKind::VariadicKeywordParameter(_) )); } @@ -803,7 +821,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): ) .unwrap(); assert!(matches!( - binding.kind(&db), + binding.kind(&db, file), DefinitionKind::Comprehension(_) )); } @@ -843,7 +861,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): element.scoped_use_id(&db, comprehension_scope_id.to_scope_id(&db, file)); let binding = use_def.first_binding_at_use(element_use_id).unwrap(); - let DefinitionKind::Comprehension(comprehension) = binding.kind(&db) else { + let DefinitionKind::Comprehension(comprehension) = binding.kind(&db, file) else { panic!("expected generator definition") }; let target = comprehension.target(); @@ -925,7 +943,10 @@ with item1 as x, item2 as y: let binding = use_def .first_public_binding(global_table.symbol_id_by_name(name).expect("symbol exists")) .expect("Expected with item definition for {name}"); - assert!(matches!(binding.kind(&db), DefinitionKind::WithItem(_))); + assert!(matches!( + binding.kind(&db, file), + DefinitionKind::WithItem(_) + )); } } @@ -948,7 +969,10 @@ with context() as (x, y): let binding = use_def .first_public_binding(global_table.symbol_id_by_name(name).expect("symbol exists")) .expect("Expected with item definition for {name}"); - assert!(matches!(binding.kind(&db), DefinitionKind::WithItem(_))); + assert!(matches!( + binding.kind(&db, file), + DefinitionKind::WithItem(_) + )); } } @@ -992,7 +1016,10 @@ def func(): .expect("symbol exists"), ) .unwrap(); - assert!(matches!(binding.kind(&db), DefinitionKind::Function(_))); + assert!(matches!( + binding.kind(&db, file), + DefinitionKind::Function(_) + )); } #[test] @@ -1093,7 +1120,7 @@ class C[T]: let x_use_id = x_use_expr_name.scoped_use_id(&db, scope); let use_def = use_def_map(&db, scope); let binding = use_def.first_binding_at_use(x_use_id).unwrap(); - let DefinitionKind::Assignment(assignment) = binding.kind(&db) else { + let DefinitionKind::Assignment(assignment) = binding.kind(&db, file) else { panic!("should be an assignment definition") }; let ast::Expr::NumberLiteral(ast::ExprNumberLiteral { @@ -1226,7 +1253,7 @@ match subject: let binding = use_def .first_public_binding(global_table.symbol_id_by_name(name).expect("symbol exists")) .expect("Expected with item definition for {name}"); - if let DefinitionKind::MatchPattern(pattern) = binding.kind(&db) { + if let DefinitionKind::MatchPattern(pattern) = binding.kind(&db, file) { assert_eq!(pattern.index(), expected_index); } else { panic!("Expected match pattern definition for {name}"); @@ -1256,7 +1283,7 @@ match 1: let binding = use_def .first_public_binding(global_table.symbol_id_by_name(name).expect("symbol exists")) .expect("Expected with item definition for {name}"); - if let DefinitionKind::MatchPattern(pattern) = binding.kind(&db) { + if let DefinitionKind::MatchPattern(pattern) = binding.kind(&db, file) { assert_eq!(pattern.index(), expected_index); } else { panic!("Expected match pattern definition for {name}"); @@ -1277,7 +1304,7 @@ match 1: .first_public_binding(global_table.symbol_id_by_name("x").unwrap()) .unwrap(); - assert!(matches!(binding.kind(&db), DefinitionKind::For(_))); + assert!(matches!(binding.kind(&db, file), DefinitionKind::For(_))); } #[test] @@ -1296,8 +1323,8 @@ match 1: .first_public_binding(global_table.symbol_id_by_name("y").unwrap()) .unwrap(); - assert!(matches!(x_binding.kind(&db), DefinitionKind::For(_))); - assert!(matches!(y_binding.kind(&db), DefinitionKind::For(_))); + assert!(matches!(x_binding.kind(&db, file), DefinitionKind::For(_))); + assert!(matches!(y_binding.kind(&db, file), DefinitionKind::For(_))); } #[test] @@ -1313,6 +1340,6 @@ match 1: .first_public_binding(global_table.symbol_id_by_name("a").unwrap()) .unwrap(); - assert!(matches!(binding.kind(&db), DefinitionKind::For(_))); + assert!(matches!(binding.kind(&db, file), DefinitionKind::For(_))); } } diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index 19e49f244e493..d81cb17b9bdcf 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -38,7 +38,7 @@ pub struct Definition<'db> { #[no_eq] #[return_ref] #[tracked] - pub(crate) kind: DefinitionKind<'db>, + kind_inner: DefinitionKind<'db>, /// This is a dedicated field to avoid accessing `kind` to compute this value. pub(crate) is_reexported: bool, @@ -50,6 +50,18 @@ impl<'db> Definition<'db> { pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> { self.file_scope(db).to_scope_id(db, self.file(db)) } + + /// Returns the definition's kind which gives access to its AST. + /// + /// `query_file` is the file for which the current query performs type inference. + /// It acts as a token of prove that we aren't accessing an AST node from a different file + /// than in which the current enclosing Salsa query (which would lead to cross-file dependencies). + #[inline] + pub(crate) fn kind(self, db: &'db dyn Db, query_file: File) -> &'db DefinitionKind<'db> { + debug_assert_eq!(query_file, self.scope(db).file(db)); + + self.kind_inner(db) + } } #[derive(Copy, Clone, Debug)] diff --git a/crates/red_knot_python_semantic/src/semantic_index/expression.rs b/crates/red_knot_python_semantic/src/semantic_index/expression.rs index 6189c4cb68c56..3218e36486c03 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/expression.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/expression.rs @@ -42,7 +42,7 @@ pub(crate) struct Expression<'db> { #[no_eq] #[tracked] #[return_ref] - pub(crate) node_ref: AstNodeRef, + node_ref_inner: AstNodeRef, /// Should this expression be inferred as a normal expression or a type expression? pub(crate) kind: ExpressionKind, @@ -51,6 +51,17 @@ pub(crate) struct Expression<'db> { } impl<'db> Expression<'db> { + /// Returns a reference to the expression's AST node. + /// + /// `query_file` is the file for which the current query performs type inference. + /// It acts as a token of prove that we aren't accessing an AST node from a different file + /// than in which the current enclosing Salsa query (which would lead to cross-file dependencies). + #[inline] + pub(crate) fn node_ref(self, db: &'db dyn Db, query_file: File) -> &'db AstNodeRef { + debug_assert_eq!(self.file(db), query_file); + self.node_ref_inner(db) + } + pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> { self.file_scope(db).to_scope_id(db, self.file(db)) } diff --git a/crates/red_knot_python_semantic/src/semantic_index/symbol.rs b/crates/red_knot_python_semantic/src/semantic_index/symbol.rs index 5e678c2526703..8c691591d94ee 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/symbol.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/symbol.rs @@ -110,21 +110,34 @@ pub struct ScopeId<'db> { } impl<'db> ScopeId<'db> { - pub(crate) fn is_function_like(self, db: &'db dyn Db) -> bool { - self.node(db).scope_kind().is_function_like() + pub(crate) fn is_function_like(self, db: &'db dyn Db, query_file: File) -> bool { + self.node(db, query_file).scope_kind().is_function_like() } - pub(crate) fn node(self, db: &dyn Db) -> &NodeWithScopeKind { + #[inline] + pub(crate) fn node(self, db: &dyn Db, query_file: File) -> &NodeWithScopeKind { + debug_assert_eq!(self.file(db), query_file); + self.node_unchecked(db) + } + + /// Returns the scope's node without checking if the query's file matches the scope's file + /// (which is desired to avoid cross-module query dependencies). + /// + /// Use this method when in situations where it's okay to add a cross-module dependency. + /// For example, when emitting diagnostics. + #[inline] + pub(crate) fn node_unchecked(self, db: &dyn Db) -> &NodeWithScopeKind { self.scope(db).node() } - pub(crate) fn scope(self, db: &dyn Db) -> &Scope { + fn scope(self, db: &dyn Db) -> &Scope { semantic_index(db, self.file(db)).scope(self.file_scope_id(db)) } #[cfg(test)] pub(crate) fn name(self, db: &'db dyn Db) -> &'db str { - match self.node(db) { + // Use `self.node` if this ever becomes a non-testing function. + match self.node_unchecked(db) { NodeWithScopeKind::Module => "", NodeWithScopeKind::Class(class) | NodeWithScopeKind::ClassTypeParameters(class) => { class.name.as_str() diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 625a48689f4dc..9d4faa8fd070b 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -3144,7 +3144,8 @@ impl<'db> FunctionType<'db> { /// would depend on the function's AST and rerun for every change in that file. #[salsa::tracked(return_ref)] pub fn signature(self, db: &'db dyn Db) -> Signature<'db> { - let function_stmt_node = self.body_scope(db).node(db).expect_function(); + let body_scope = self.body_scope(db); + let function_stmt_node = body_scope.node(db, body_scope.file(db)).expect_function(); let internal_signature = self.internal_signature(db); if function_stmt_node.decorator_list.is_empty() { return internal_signature; @@ -3165,8 +3166,10 @@ impl<'db> FunctionType<'db> { /// scope. fn internal_signature(self, db: &'db dyn Db) -> Signature<'db> { let scope = self.body_scope(db); - let function_stmt_node = scope.node(db).expect_function(); - let definition = semantic_index(db, scope.file(db)).definition(function_stmt_node); + let file = scope.file(db); + + let function_stmt_node = scope.node(db, file).expect_function(); + let definition = semantic_index(db, file).definition(function_stmt_node); Signature::from_function(db, definition, function_stmt_node) } @@ -3490,9 +3493,10 @@ impl<'db> Class<'db> { #[salsa::tracked(return_ref)] fn explicit_bases_query(self, db: &'db dyn Db) -> Box<[Type<'db>]> { - let class_stmt = self.node(db); + let file = self.file(db); + let class_stmt = self.node(db, file); - let class_definition = semantic_index(db, self.file(db)).definition(class_stmt); + let class_definition = semantic_index(db, file).definition(class_stmt); class_stmt .bases() @@ -3510,18 +3514,20 @@ impl<'db> Class<'db> { /// ## Note /// Only call this function from queries in the same file or your /// query depends on the AST of another file (bad!). - fn node(self, db: &'db dyn Db) -> &'db ast::StmtClassDef { - self.body_scope(db).node(db).expect_class() + #[inline] + fn node(self, db: &'db dyn Db, query_file: File) -> &'db ast::StmtClassDef { + self.body_scope(db).node(db, query_file).expect_class() } /// Return the types of the decorators on this class #[salsa::tracked(return_ref)] fn decorators(self, db: &'db dyn Db) -> Box<[Type<'db>]> { - let class_stmt = self.node(db); + let file = self.file(db); + let class_stmt = self.node(db, file); if class_stmt.decorator_list.is_empty() { return Box::new([]); } - let class_definition = semantic_index(db, self.file(db)).definition(class_stmt); + let class_definition = semantic_index(db, file).definition(class_stmt); class_stmt .decorator_list .iter() @@ -3577,14 +3583,17 @@ impl<'db> Class<'db> { /// ## Note /// Only call this function from queries in the same file or your /// query depends on the AST of another file (bad!). - fn explicit_metaclass(self, db: &'db dyn Db) -> Option> { - let class_stmt = self.node(db); + fn explicit_metaclass(self, db: &'db dyn Db, query_file: File) -> Option> { + let file = self.file(db); + debug_assert_eq!(query_file, file); + + let class_stmt = self.node(db, file); let metaclass_node = &class_stmt .arguments .as_ref()? .find_keyword("metaclass")? .value; - let class_definition = semantic_index(db, self.file(db)).definition(class_stmt); + let class_definition = semantic_index(db, file).definition(class_stmt); let metaclass_ty = definition_expression_type(db, class_definition, metaclass_node); Some(metaclass_ty) } @@ -3608,7 +3617,7 @@ impl<'db> Class<'db> { return Ok(SubclassOfType::subclass_of_unknown()); } - let explicit_metaclass = self.explicit_metaclass(db); + let explicit_metaclass = self.explicit_metaclass(db, self.file(db)); let (metaclass, class_metaclass_was_from) = if let Some(metaclass) = explicit_metaclass { (metaclass, self) } else if let Some(base_class) = base_classes.next() { @@ -4015,9 +4024,10 @@ impl<'db> TypeAliasType<'db> { #[salsa::tracked] pub fn value_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self.rhs_scope(db); + let file = scope.file(db); - let type_alias_stmt_node = scope.node(db).expect_type_alias(); - let definition = semantic_index(db, scope.file(db)).definition(type_alias_stmt_node); + let type_alias_stmt_node = scope.node(db, file).expect_type_alias(); + let definition = semantic_index(db, file).definition(type_alias_stmt_node); definition_expression_type(db, definition, &type_alias_stmt_node.value) } diff --git a/crates/red_knot_python_semantic/src/types/call/bind.rs b/crates/red_knot_python_semantic/src/types/call/bind.rs index 44c0d298c2682..77720378a6920 100644 --- a/crates/red_knot_python_semantic/src/types/call/bind.rs +++ b/crates/red_knot_python_semantic/src/types/call/bind.rs @@ -315,7 +315,7 @@ impl<'db> CallBindingError<'db> { if let Some(func_lit) = callable_ty.into_function_literal() { let func_lit_scope = func_lit.body_scope(context.db()); let mut span = Span::from(func_lit_scope.file(context.db())); - let node = func_lit_scope.node(context.db()); + let node = func_lit_scope.node_unchecked(context.db()); if let Some(func_def) = node.as_function() { let range = func_def .parameters diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 02b269bb2a3cb..17cf12c4b75d3 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -117,8 +117,9 @@ fn infer_definition_types_cycle_recovery<'db>( input: Definition<'db>, ) -> TypeInference<'db> { tracing::trace!("infer_definition_types_cycle_recovery"); - let mut inference = TypeInference::empty(input.scope(db)); - let category = input.kind(db).category(); + let scope = input.scope(db); + let mut inference = TypeInference::empty(scope); + let category = input.kind(db, scope.file(db)).category(); if category.is_declaration() { inference .declarations @@ -142,7 +143,7 @@ pub(crate) fn infer_definition_types<'db>( let file = definition.file(db); let _span = tracing::trace_span!( "infer_definition_types", - range = ?definition.kind(db).target_range(), + range = ?definition.kind(db, file).target_range(), file = %file.path(db) ) .entered(); @@ -165,7 +166,7 @@ pub(crate) fn infer_deferred_types<'db>( let _span = tracing::trace_span!( "infer_deferred_types", definition = ?definition.as_id(), - range = ?definition.kind(db).target_range(), + range = ?definition.kind(db, file).target_range(), file = %file.path(db) ) .entered(); @@ -188,7 +189,7 @@ pub(crate) fn infer_expression_types<'db>( let _span = tracing::trace_span!( "infer_expression_types", expression = ?expression.as_id(), - range = ?expression.node_ref(db).range(), + range = ?expression.node_ref(db, file).range(), file = %file.path(db) ) .entered(); @@ -203,13 +204,22 @@ pub(crate) fn infer_expression_types<'db>( /// This is a small helper around [`infer_expression_types()`] to reduce the boilerplate. /// Use [`infer_expression_type()`] if it isn't guaranteed that `expression` is in the same file to /// avoid cross-file query dependencies. +/// +/// `query_file` is the file for which the current query performs type inference. +/// It acts as a token of prove that we aren't accessing an AST node from a different file +/// than in which the current enclosing Salsa query (which would lead to cross-file dependencies). pub(super) fn infer_same_file_expression_type<'db>( db: &'db dyn Db, expression: Expression<'db>, + query_file: File, ) -> Type<'db> { let inference = infer_expression_types(db, expression); let scope = expression.scope(db); - inference.expression_type(expression.node_ref(db).scoped_expression_id(db, scope)) + inference.expression_type( + expression + .node_ref(db, query_file) + .scoped_expression_id(db, scope), + ) } /// Infers the type of an expression where the expression might come from another file. @@ -225,7 +235,7 @@ pub(crate) fn infer_expression_type<'db>( expression: Expression<'db>, ) -> Type<'db> { // It's okay to call the "same file" version here because we're inside a salsa query. - infer_same_file_expression_type(db, expression) + infer_same_file_expression_type(db, expression, expression.scope(db).file(db)) } /// Infer the types for an [`Unpack`] operation. @@ -238,11 +248,11 @@ pub(crate) fn infer_expression_type<'db>( pub(super) fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> UnpackResult<'db> { let file = unpack.file(db); let _span = - tracing::trace_span!("infer_unpack_types", range=?unpack.range(db), file=%file.path(db)) + tracing::trace_span!("infer_unpack_types", range=?unpack.range(db, file), file=%file.path(db)) .entered(); let mut unpacker = Unpacker::new(db, unpack.scope(db)); - unpacker.unpack(unpack.target(db), unpack.value(db)); + unpacker.unpack(unpack.target(db, file), unpack.value(db)); unpacker.finish() } @@ -537,7 +547,7 @@ impl<'db> TypeInferenceBuilder<'db> { } fn infer_region_scope(&mut self, scope: ScopeId<'db>) { - let node = scope.node(self.db()); + let node = scope.node(self.db(), self.file()); match node { NodeWithScopeKind::Module => { let parsed = parsed_module(self.db().upcast(), self.file()); @@ -603,7 +613,7 @@ impl<'db> TypeInferenceBuilder<'db> { .iter() .filter_map(|(definition, ty)| { // Filter out class literals that result from imports - if let DefinitionKind::Class(class) = definition.kind(self.db()) { + if let DefinitionKind::Class(class) = definition.kind(self.db(), self.file()) { ty.inner_type() .into_class_literal() .map(|ty| (ty.class, class.node())) @@ -759,7 +769,7 @@ impl<'db> TypeInferenceBuilder<'db> { } fn infer_region_definition(&mut self, definition: Definition<'db>) { - match definition.kind(self.db()) { + match definition.kind(self.db(), self.file()) { DefinitionKind::Function(function) => { self.infer_function_definition(function.node(), definition); } @@ -850,7 +860,7 @@ impl<'db> TypeInferenceBuilder<'db> { // to use end-of-scope semantics. This would require custom and possibly a complex // implementation to allow this "split" to happen. - match definition.kind(self.db()) { + match definition.kind(self.db(), self.file()) { DefinitionKind::Function(function) => self.infer_function_deferred(function.node()), DefinitionKind::Class(class) => self.infer_class_deferred(class.node()), _ => {} @@ -860,10 +870,10 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_region_expression(&mut self, expression: Expression<'db>) { match expression.kind(self.db()) { ExpressionKind::Normal => { - self.infer_expression_impl(expression.node_ref(self.db())); + self.infer_expression_impl(expression.node_ref(self.db(), self.file())); } ExpressionKind::TypeExpression => { - self.infer_type_expression(expression.node_ref(self.db())); + self.infer_type_expression(expression.node_ref(self.db(), self.file())); } } } @@ -900,7 +910,7 @@ impl<'db> TypeInferenceBuilder<'db> { } fn add_binding(&mut self, node: AnyNodeRef, binding: Definition<'db>, ty: Type<'db>) { - debug_assert!(binding.kind(self.db()).category().is_binding()); + debug_assert!(binding.kind(self.db(), self.file()).category().is_binding()); let use_def = self.index.use_def_map(binding.file_scope(self.db())); let declarations = use_def.declarations_at_binding(binding); let mut bound_ty = ty; @@ -935,7 +945,10 @@ impl<'db> TypeInferenceBuilder<'db> { declaration: Definition<'db>, ty: TypeAndQualifiers<'db>, ) { - debug_assert!(declaration.kind(self.db()).category().is_declaration()); + debug_assert!(declaration + .kind(self.db(), self.file()) + .category() + .is_declaration()); let use_def = self.index.use_def_map(declaration.file_scope(self.db())); let prior_bindings = use_def.bindings_at_declaration(declaration); // unbound_ty is Never because for this check we don't care about unbound @@ -965,8 +978,14 @@ impl<'db> TypeInferenceBuilder<'db> { definition: Definition<'db>, declared_and_inferred_ty: &DeclaredAndInferredType<'db>, ) { - debug_assert!(definition.kind(self.db()).category().is_binding()); - debug_assert!(definition.kind(self.db()).category().is_declaration()); + debug_assert!(definition + .kind(self.db(), self.file()) + .category() + .is_binding()); + debug_assert!(definition + .kind(self.db(), self.file()) + .category() + .is_declaration()); let (declared_ty, inferred_ty) = match *declared_and_inferred_ty { DeclaredAndInferredType::AreTheSame(ty) => (ty.into(), ty), @@ -3520,7 +3539,7 @@ impl<'db> TypeInferenceBuilder<'db> { // a local variable or not in function-like scopes. If a variable has any bindings in a // function-like scope, it is considered a local variable; it never references another // scope. (At runtime, it would use the `LOAD_FAST` opcode.) - if has_bindings_in_this_scope && scope.is_function_like(db) { + if has_bindings_in_this_scope && scope.is_function_like(db, self.file()) { return Symbol::Unbound; } @@ -3535,7 +3554,7 @@ impl<'db> TypeInferenceBuilder<'db> { // scope differently (because an unbound name there falls back to builtins), so // check only function-like scopes. let enclosing_scope_id = enclosing_scope_file_id.to_scope_id(db, current_file); - if !enclosing_scope_id.is_function_like(db) { + if !enclosing_scope_id.is_function_like(db, self.file()) { continue; } diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 8822a3a18f190..92935bbc51aa1 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -199,7 +199,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> { expression: Expression<'db>, is_positive: bool, ) -> Option> { - let expression_node = expression.node_ref(self.db).node(); + let expression_node = expression + .node_ref(self.db, self.scope().file(self.db)) + .node(); self.evaluate_expression_node_constraint(expression_node, expression, is_positive) } @@ -246,6 +248,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> { symbol_table(self.db, self.scope()) } + /// Returns the `constraint`'s scope. + /// + /// This is also the scope of the enclosing salsa query. fn scope(&self) -> ScopeId<'db> { match self.constraint { ConstraintNode::Expression(expression) => expression.scope(self.db), @@ -473,7 +478,10 @@ impl<'db> NarrowingConstraintsBuilder<'db> { subject: Expression<'db>, singleton: ast::Singleton, ) -> Option> { - if let Some(ast::ExprName { id, .. }) = subject.node_ref(self.db).as_name_expr() { + if let Some(ast::ExprName { id, .. }) = subject + .node_ref(self.db, self.scope().file(self.db)) + .as_name_expr() + { // SAFETY: we should always have a symbol for every Name node. let symbol = self.symbols().symbol_id_by_name(id).unwrap(); @@ -495,10 +503,14 @@ impl<'db> NarrowingConstraintsBuilder<'db> { subject: Expression<'db>, cls: Expression<'db>, ) -> Option> { - if let Some(ast::ExprName { id, .. }) = subject.node_ref(self.db).as_name_expr() { + if let Some(ast::ExprName { id, .. }) = subject + .node_ref(self.db, self.scope().file(self.db)) + .as_name_expr() + { // SAFETY: we should always have a symbol for every Name node. let symbol = self.symbols().symbol_id_by_name(id).unwrap(); - let ty = infer_same_file_expression_type(self.db, cls).to_instance(self.db); + let ty = infer_same_file_expression_type(self.db, cls, self.scope().file(self.db)) + .to_instance(self.db); let mut constraints = NarrowingConstraints::default(); constraints.insert(symbol, ty); diff --git a/crates/red_knot_python_semantic/src/types/unpacker.rs b/crates/red_knot_python_semantic/src/types/unpacker.rs index 3173a9dc28539..96bf426d01fbc 100644 --- a/crates/red_knot_python_semantic/src/types/unpacker.rs +++ b/crates/red_knot_python_semantic/src/types/unpacker.rs @@ -49,7 +49,7 @@ impl<'db> Unpacker<'db> { && self.context.in_stub() && value .expression() - .node_ref(self.db()) + .node_ref(self.db(), self.scope.file(self.db())) .is_ellipsis_literal_expr() { value_ty = Type::unknown(); @@ -57,12 +57,17 @@ impl<'db> Unpacker<'db> { if value.is_iterable() { // If the value is an iterable, then the type that needs to be unpacked is the iterator // type. - value_ty = value_ty - .iterate(self.db()) - .unwrap_with_diagnostic(&self.context, value.as_any_node_ref(self.db())); + value_ty = value_ty.iterate(self.db()).unwrap_with_diagnostic( + &self.context, + value.as_any_node_ref(self.db(), self.scope.file(self.db())), + ); } - self.unpack_inner(target, value.as_any_node_ref(self.db()), value_ty); + self.unpack_inner( + target, + value.as_any_node_ref(self.db(), self.scope.file(self.db())), + value_ty, + ); } fn unpack_inner( diff --git a/crates/red_knot_python_semantic/src/unpack.rs b/crates/red_knot_python_semantic/src/unpack.rs index ad6eac413d01c..a753380369aa9 100644 --- a/crates/red_knot_python_semantic/src/unpack.rs +++ b/crates/red_knot_python_semantic/src/unpack.rs @@ -32,12 +32,10 @@ pub(crate) struct Unpack<'db> { pub(crate) file_scope: FileScopeId, - /// The target expression that is being unpacked. For example, in `(a, b) = (1, 2)`, the target - /// expression is `(a, b)`. #[no_eq] #[return_ref] #[tracked] - pub(crate) target: AstNodeRef, + target_inner: AstNodeRef, /// The ingredient representing the value expression of the unpacking. For example, in /// `(a, b) = (1, 2)`, the value expression is `(1, 2)`. @@ -54,8 +52,25 @@ impl<'db> Unpack<'db> { } /// Returns the range of the unpack target expression. - pub(crate) fn range(self, db: &'db dyn Db) -> TextRange { - self.target(db).range() + /// + /// `query_file` is the file for which the current query performs type inference. + /// It acts as a token of prove that we aren't accessing an AST node from a different file + /// than in which the current enclosing Salsa query (which would lead to cross-file dependencies). + #[inline] + pub(crate) fn range(self, db: &'db dyn Db, query_file: File) -> TextRange { + self.target(db, query_file).range() + } + + /// The target expression that is being unpacked. For example, in `(a, b) = (1, 2)`, the target + /// expression is `(a, b)`. + /// + /// `query_file` is the file for which the current query performs type inference. + /// It acts as a token of prove that we aren't accessing an AST node from a different file + /// than in which the current enclosing Salsa query (which would lead to cross-file dependencies). + #[inline] + pub(crate) fn target(self, db: &'db dyn Db, query_file: File) -> &'db AstNodeRef { + debug_assert_eq!(self.file(db), query_file); + self.target_inner(db) } } @@ -93,12 +108,12 @@ impl<'db> UnpackValue<'db> { scope: ScopeId<'db>, ) -> ScopedExpressionId { self.expression() - .node_ref(db) + .node_ref(db, scope.file(db)) .scoped_expression_id(db, scope) } /// Returns the expression as an [`AnyNodeRef`]. - pub(crate) fn as_any_node_ref(self, db: &'db dyn Db) -> AnyNodeRef<'db> { - self.expression().node_ref(db).node().into() + pub(crate) fn as_any_node_ref(self, db: &'db dyn Db, file: File) -> AnyNodeRef<'db> { + self.expression().node_ref(db, file).node().into() } }