Skip to content

Commit

Permalink
Merge #8945
Browse files Browse the repository at this point in the history
8945: fix: Make expected type work in more situations r=flodiebold a=flodiebold

Also makes call info show the correct types for generic methods.

![2021-05-23-182952_1134x616_scrot](https://user-images.githubusercontent.com/906069/119269023-dd5a5b00-bbf5-11eb-993a-b6e122c3b9a6.png)
![2021-05-23-183117_922x696_scrot](https://user-images.githubusercontent.com/906069/119269025-dfbcb500-bbf5-11eb-983c-fc415b8428e0.png)


Co-authored-by: Florian Diebold <[email protected]>
  • Loading branch information
bors[bot] and flodiebold authored May 23, 2021
2 parents a2ce091 + b826209 commit 495c958
Show file tree
Hide file tree
Showing 14 changed files with 175 additions and 69 deletions.
5 changes: 2 additions & 3 deletions crates/hir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -513,9 +513,8 @@ impl Field {
}

/// Returns the type as in the signature of the struct (i.e., with
/// placeholder types for type parameters). This is good for showing
/// signature help, but not so good to actually get the type of the field
/// when you actually have a variable of the struct.
/// placeholder types for type parameters). Only use this in the context of
/// the field definition.
pub fn ty(&self, db: &dyn HirDatabase) -> Type {
let var_id = self.parent.into();
let generic_def_id: GenericDefId = match self.parent {
Expand Down
17 changes: 9 additions & 8 deletions crates/hir/src/semantics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use hir_def::{
AsMacroCall, FunctionId, TraitId, VariantId,
};
use hir_expand::{name::AsName, ExpansionInfo};
use hir_ty::associated_type_shorthand_candidates;
use hir_ty::{associated_type_shorthand_candidates, Interner};
use itertools::Itertools;
use rustc_hash::{FxHashMap, FxHashSet};
use syntax::{
Expand Down Expand Up @@ -227,7 +227,7 @@ impl<'db, DB: HirDatabase> Semantics<'db, DB> {
pub fn resolve_record_field(
&self,
field: &ast::RecordExprField,
) -> Option<(Field, Option<Local>)> {
) -> Option<(Field, Option<Local>, Type)> {
self.imp.resolve_record_field(field)
}

Expand Down Expand Up @@ -501,14 +501,12 @@ impl<'db> SemanticsImpl<'db> {
}

fn resolve_method_call(&self, call: &ast::MethodCallExpr) -> Option<FunctionId> {
self.analyze(call.syntax()).resolve_method_call(self.db, call)
self.analyze(call.syntax()).resolve_method_call(self.db, call).map(|(id, _)| id)
}

fn resolve_method_call_as_callable(&self, call: &ast::MethodCallExpr) -> Option<Callable> {
// FIXME: this erases Substs, we should instead record the correct
// substitution during inference and use that
let func = self.resolve_method_call(call)?;
let ty = hir_ty::TyBuilder::value_ty(self.db, func.into()).fill_with_unknown().build();
let (func, subst) = self.analyze(call.syntax()).resolve_method_call(self.db, call)?;
let ty = self.db.value_ty(func.into()).substitute(&Interner, &subst);
let resolver = self.analyze(call.syntax()).resolver;
let ty = Type::new_with_resolver(self.db, &resolver, ty)?;
let mut res = ty.as_callable(self.db)?;
Expand All @@ -520,7 +518,10 @@ impl<'db> SemanticsImpl<'db> {
self.analyze(field.syntax()).resolve_field(self.db, field)
}

fn resolve_record_field(&self, field: &ast::RecordExprField) -> Option<(Field, Option<Local>)> {
fn resolve_record_field(
&self,
field: &ast::RecordExprField,
) -> Option<(Field, Option<Local>, Type)> {
self.analyze(field.syntax()).resolve_record_field(self.db, field)
}

Expand Down
9 changes: 6 additions & 3 deletions crates/hir/src/source_analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ impl SourceAnalyzer {
&self,
db: &dyn HirDatabase,
call: &ast::MethodCallExpr,
) -> Option<FunctionId> {
) -> Option<(FunctionId, Substitution)> {
let expr_id = self.expr_id(db, &call.clone().into())?;
self.infer.as_ref()?.method_resolution(expr_id)
}
Expand All @@ -161,7 +161,7 @@ impl SourceAnalyzer {
&self,
db: &dyn HirDatabase,
field: &ast::RecordExprField,
) -> Option<(Field, Option<Local>)> {
) -> Option<(Field, Option<Local>, Type)> {
let record_expr = ast::RecordExpr::cast(field.syntax().parent().and_then(|p| p.parent())?)?;
let expr = ast::Expr::from(record_expr);
let expr_id = self.body_source_map.as_ref()?.node_expr(InFile::new(self.file_id, &expr))?;
Expand All @@ -178,10 +178,13 @@ impl SourceAnalyzer {
_ => None,
}
};
let (_, subst) = self.infer.as_ref()?.type_of_expr.get(expr_id)?.as_adt()?;
let variant = self.infer.as_ref()?.variant_resolution_for_expr(expr_id)?;
let variant_data = variant.variant_data(db.upcast());
let field = FieldId { parent: variant, local_id: variant_data.field(&local_name)? };
Some((field.into(), local))
let field_ty =
db.field_types(variant).get(field.local_id)?.clone().substitute(&Interner, subst);
Some((field.into(), local, Type::new_with_resolver(db, &self.resolver, field_ty)?))
}

pub(crate) fn resolve_record_pat_field(
Expand Down
8 changes: 8 additions & 0 deletions crates/hir_def/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,14 @@ impl VariantId {
VariantId::UnionId(it) => it.lookup(db).id.file_id(),
}
}

pub fn adt_id(self) -> AdtId {
match self {
VariantId::EnumVariantId(it) => it.parent.into(),
VariantId::StructId(it) => it.into(),
VariantId::UnionId(it) => it.into(),
}
}
}

trait Intern {
Expand Down
12 changes: 4 additions & 8 deletions crates/hir_ty/src/diagnostics/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ impl<'a, 'b> ExprValidator<'a, 'b> {
for (id, expr) in body.exprs.iter() {
if let Expr::MethodCall { receiver, .. } = expr {
let function_id = match self.infer.method_resolution(id) {
Some(id) => id,
Some((id, _)) => id,
None => continue,
};

Expand Down Expand Up @@ -239,15 +239,11 @@ impl<'a, 'b> ExprValidator<'a, 'b> {
return;
}

// FIXME: note that we erase information about substs here. This
// is not right, but, luckily, doesn't matter as we care only
// about the number of params
let callee = match self.infer.method_resolution(call_id) {
Some(callee) => callee,
let (callee, subst) = match self.infer.method_resolution(call_id) {
Some(it) => it,
None => return,
};
let sig =
db.callable_item_signature(callee.into()).into_value_and_skipped_binders().0;
let sig = db.callable_item_signature(callee.into()).substitute(&Interner, &subst);

(sig, args)
}
Expand Down
2 changes: 1 addition & 1 deletion crates/hir_ty/src/diagnostics/unsafe_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ fn walk_unsafe(
Expr::MethodCall { .. } => {
if infer
.method_resolution(current)
.map(|func| db.function_data(func).is_unsafe())
.map(|(func, _)| db.function_data(func).is_unsafe())
.unwrap_or(false)
{
unsafe_exprs.push(UnsafeExpr { expr: current, inside_unsafe_block });
Expand Down
25 changes: 14 additions & 11 deletions crates/hir_ty/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ use syntax::SmolStr;
use super::{DomainGoal, InEnvironment, ProjectionTy, TraitEnvironment, TraitRef, Ty};
use crate::{
db::HirDatabase, fold_tys, infer::diagnostics::InferenceDiagnostic,
lower::ImplTraitLoweringMode, to_assoc_type_id, AliasEq, AliasTy, Goal, Interner, TyBuilder,
TyExt, TyKind,
lower::ImplTraitLoweringMode, to_assoc_type_id, AliasEq, AliasTy, Goal, Interner, Substitution,
TyBuilder, TyExt, TyKind,
};

// This lint has a false positive here. See the link below for details.
Expand Down Expand Up @@ -132,7 +132,7 @@ impl Default for InternedStandardTypes {
#[derive(Clone, PartialEq, Eq, Debug, Default)]
pub struct InferenceResult {
/// For each method call expr, records the function it resolves to.
method_resolutions: FxHashMap<ExprId, FunctionId>,
method_resolutions: FxHashMap<ExprId, (FunctionId, Substitution)>,
/// For each field access expr, records the field it resolves to.
field_resolutions: FxHashMap<ExprId, FieldId>,
/// For each struct literal or pattern, records the variant it resolves to.
Expand All @@ -152,8 +152,8 @@ pub struct InferenceResult {
}

impl InferenceResult {
pub fn method_resolution(&self, expr: ExprId) -> Option<FunctionId> {
self.method_resolutions.get(&expr).copied()
pub fn method_resolution(&self, expr: ExprId) -> Option<(FunctionId, Substitution)> {
self.method_resolutions.get(&expr).cloned()
}
pub fn field_resolution(&self, expr: ExprId) -> Option<FieldId> {
self.field_resolutions.get(&expr).copied()
Expand Down Expand Up @@ -284,14 +284,17 @@ impl<'a> InferenceContext<'a> {
self.table.propagate_diverging_flag();
let mut result = std::mem::take(&mut self.result);
for ty in result.type_of_expr.values_mut() {
*ty = self.table.resolve_ty_completely(ty.clone());
*ty = self.table.resolve_completely(ty.clone());
}
for ty in result.type_of_pat.values_mut() {
*ty = self.table.resolve_ty_completely(ty.clone());
*ty = self.table.resolve_completely(ty.clone());
}
for mismatch in result.type_mismatches.values_mut() {
mismatch.expected = self.table.resolve_ty_completely(mismatch.expected.clone());
mismatch.actual = self.table.resolve_ty_completely(mismatch.actual.clone());
mismatch.expected = self.table.resolve_completely(mismatch.expected.clone());
mismatch.actual = self.table.resolve_completely(mismatch.actual.clone());
}
for (_, subst) in result.method_resolutions.values_mut() {
*subst = self.table.resolve_completely(subst.clone());
}
result
}
Expand All @@ -300,8 +303,8 @@ impl<'a> InferenceContext<'a> {
self.result.type_of_expr.insert(expr, ty);
}

fn write_method_resolution(&mut self, expr: ExprId, func: FunctionId) {
self.result.method_resolutions.insert(expr, func);
fn write_method_resolution(&mut self, expr: ExprId, func: FunctionId, subst: Substitution) {
self.result.method_resolutions.insert(expr, (func, subst));
}

fn write_field_resolution(&mut self, expr: ExprId, field: FieldId) {
Expand Down
34 changes: 18 additions & 16 deletions crates/hir_ty/src/infer/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -891,17 +891,21 @@ impl<'a> InferenceContext<'a> {
method_name,
)
});
let (derefed_receiver_ty, method_ty, def_generics) = match resolved {
let (derefed_receiver_ty, method_ty, substs) = match resolved {
Some((ty, func)) => {
let ty = canonicalized_receiver.decanonicalize_ty(ty);
self.write_method_resolution(tgt_expr, func);
(ty, self.db.value_ty(func.into()), Some(generics(self.db.upcast(), func.into())))
let generics = generics(self.db.upcast(), func.into());
let substs = self.substs_for_method_call(generics, generic_args, &ty);
self.write_method_resolution(tgt_expr, func, substs.clone());
(ty, self.db.value_ty(func.into()), substs)
}
None => (receiver_ty, Binders::empty(&Interner, self.err_ty()), None),
None => (
receiver_ty,
Binders::empty(&Interner, self.err_ty()),
Substitution::empty(&Interner),
),
};
let substs = self.substs_for_method_call(def_generics, generic_args, &derefed_receiver_ty);
let method_ty = method_ty.substitute(&Interner, &substs);
let method_ty = self.insert_type_vars(method_ty);
self.register_obligations_for_call(&method_ty);
let (expected_receiver_ty, param_tys, ret_ty) = match method_ty.callable_sig(self.db) {
Some(sig) => {
Expand Down Expand Up @@ -950,23 +954,21 @@ impl<'a> InferenceContext<'a> {

fn substs_for_method_call(
&mut self,
def_generics: Option<Generics>,
def_generics: Generics,
generic_args: Option<&GenericArgs>,
receiver_ty: &Ty,
) -> Substitution {
let (parent_params, self_params, type_params, impl_trait_params) =
def_generics.as_ref().map_or((0, 0, 0, 0), |g| g.provenance_split());
def_generics.provenance_split();
assert_eq!(self_params, 0); // method shouldn't have another Self param
let total_len = parent_params + type_params + impl_trait_params;
let mut substs = Vec::with_capacity(total_len);
// Parent arguments are unknown, except for the receiver type
if let Some(parent_generics) = def_generics.as_ref().map(|p| p.iter_parent()) {
for (_id, param) in parent_generics {
if param.provenance == hir_def::generics::TypeParamProvenance::TraitSelf {
substs.push(receiver_ty.clone());
} else {
substs.push(self.err_ty());
}
for (_id, param) in def_generics.iter_parent() {
if param.provenance == hir_def::generics::TypeParamProvenance::TraitSelf {
substs.push(receiver_ty.clone());
} else {
substs.push(self.table.new_type_var());
}
}
// handle provided type arguments
Expand All @@ -989,7 +991,7 @@ impl<'a> InferenceContext<'a> {
};
let supplied_params = substs.len();
for _ in supplied_params..total_len {
substs.push(self.err_ty());
substs.push(self.table.new_type_var());
}
assert_eq!(substs.len(), total_len);
Substitution::from_iter(&Interner, substs)
Expand Down
7 changes: 5 additions & 2 deletions crates/hir_ty/src/infer/unify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,11 @@ impl<'a> InferenceTable<'a> {
.expect("fold failed unexpectedly")
}

pub(crate) fn resolve_ty_completely(&mut self, ty: Ty) -> Ty {
self.resolve_with_fallback(ty, |_, _, d, _| d)
pub(crate) fn resolve_completely<T>(&mut self, t: T) -> T::Result
where
T: HasInterner<Interner = Interner> + Fold<Interner>,
{
self.resolve_with_fallback(t, |_, _, d, _| d)
}

/// Unify two types and register new trait goals that arise from that.
Expand Down
2 changes: 1 addition & 1 deletion crates/ide_assists/src/handlers/fix_visibility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ fn add_vis_to_referenced_module_def(acc: &mut Assists, ctx: &AssistContext) -> O

fn add_vis_to_referenced_record_field(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
let record_field: ast::RecordExprField = ctx.find_node_at_offset()?;
let (record_field_def, _) = ctx.sema.resolve_record_field(&record_field)?;
let (record_field_def, _, _) = ctx.sema.resolve_record_field(&record_field)?;

let current_module = ctx.sema.scope(record_field.syntax()).module()?;
let visibility = record_field_def.visibility(ctx.db());
Expand Down
Loading

0 comments on commit 495c958

Please sign in to comment.