Skip to content

Commit

Permalink
Add `UnionBuilder::extend
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaReiser committed Feb 18, 2025
1 parent 91e214c commit fe308d6
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 43 deletions.
22 changes: 12 additions & 10 deletions crates/red_knot_python_semantic/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ impl<'db> LookupError<'db> {
fallback: Symbol<'db>,
) -> LookupResult<'db> {
let fallback = fallback.into_lookup_result();
match (&self, &fallback) {
match (self, fallback) {
(LookupError::Unbound, _) => fallback,
(LookupError::PossiblyUnbound { .. }, Err(LookupError::Unbound)) => Err(self),
(LookupError::PossiblyUnbound(ty), Ok(ty2)) => {
Expand Down Expand Up @@ -611,15 +611,17 @@ fn symbol_from_declarations_impl<'db>(
let ty_first = first.inner_type();
let mut qualifiers = first.qualifiers();

let mut builder = UnionBuilder::new(db).add(ty_first);
for other in std::iter::once(second).chain(types) {
let other_ty = other.inner_type();
if !ty_first.is_equivalent_to(db, other_ty) {
conflicting.push(other_ty);
}
builder = builder.add(other_ty);
qualifiers = qualifiers.union(other.qualifiers());
}
let mut builder = UnionBuilder::new(db).add(ty_first).extend(
std::iter::once(second).chain(types).map(|other| {
let other_ty = other.inner_type();
if !ty_first.is_equivalent_to(db, other_ty) {
conflicting.push(other_ty);
}

qualifiers = qualifiers.union(other.qualifiers());
other_ty
}),
);
TypeAndQualifiers::new(builder.build(), qualifiers)
} else {
first
Expand Down
32 changes: 15 additions & 17 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1796,7 +1796,7 @@ impl<'db> Type<'db> {
fn iterate(self, db: &'db dyn Db) -> IterationOutcome<'db> {
if let Type::Tuple(tuple_type) = self {
return IterationOutcome::Iterable {
element_ty: UnionType::from_elements(db, tuple_type.elements(db)),
element_ty: UnionType::from_elements(db, tuple_type.elements(db).iter().copied()),
};
}

Expand Down Expand Up @@ -3700,12 +3700,13 @@ impl<'db> Class<'db> {
(Some(builder), Some(ty)) => Some(builder.add(ty)),
}
})
.map(|mut builder| {
for binding in bindings {
builder = builder.add(binding.return_type());
}

builder.build()
.map(|builder| {
builder
.extend(
IntoIterator::into_iter(bindings)
.map(|binding| binding.return_type()),
)
.build()
});

if partly_not_callable {
Expand Down Expand Up @@ -4172,27 +4173,24 @@ impl<'db> UnionType<'db> {

/// Create a union from a list of elements
/// (which may be eagerly simplified into a different variant of [`Type`] altogether).
pub fn from_elements<I, T>(db: &'db dyn Db, elements: I) -> Type<'db>
pub fn from_elements<I>(db: &'db dyn Db, elements: I) -> Type<'db>
where
I: IntoIterator<Item = T>,
T: Into<Type<'db>>,
I: IntoIterator<Item = Type<'db>>,
{
let mut elements = elements.into_iter();
let Some(first) = elements.next() else {
return Type::Never;
};

let Some(second) = elements.next() else {
return first.into();
return first;
};

let mut builder = UnionBuilder::new(db).add(first.into()).add(second.into());

for element in elements {
builder = builder.add(element.into());
}
let (lower, _) = elements.size_hint();
let mut builder = UnionBuilder::new(db);
builder.reserve(lower + 2);

builder.build()
builder.add(first).add(second).extend(elements).build()
}

/// Apply a transformation function to all elements of the union,
Expand Down
18 changes: 18 additions & 0 deletions crates/red_knot_python_semantic/src/types/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,31 @@ impl<'db> UnionBuilder<'db> {
}
}

pub(crate) fn reserve(&mut self, additional: usize) {
self.elements.reserve(additional);
}

/// Collapse the union to a single type: `object`.
fn collapse_to_object(mut self) -> Self {
self.elements.clear();
self.elements.push(Type::object(self.db));
self
}

pub(crate) fn extend(mut self, elements: impl IntoIterator<Item = Type<'db>>) -> Self {
let elements = elements.into_iter();
let (lower, _) = elements.size_hint();

// Assume that most types will be unique
self.reserve(lower);

for element in elements {
self = self.add(element);
}

self
}

/// Adds a type to this union.
pub(crate) fn add(mut self, ty: Type<'db>) -> Self {
match ty {
Expand Down
27 changes: 12 additions & 15 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1707,18 +1707,17 @@ impl<'db> TypeInferenceBuilder<'db> {
let symbol_ty = if let Type::Tuple(tuple) = node_ty {
let type_base_exception = KnownClass::BaseException.to_subclass_of(self.db());
let mut builder = UnionBuilder::new(self.db());
for element in tuple.elements(self.db()).iter().copied() {
builder = builder.add(
if element.is_assignable_to(self.db(), type_base_exception) {
element.to_instance(self.db())
} else {
if let Some(node) = node {
report_invalid_exception_caught(&self.context, node, element);
}
Type::unknown()
},
);
}
builder = builder.extend(tuple.elements(self.db()).iter().map(|element| {
if element.is_assignable_to(self.db(), type_base_exception) {
element.to_instance(self.db())
} else {
if let Some(node) = node {
report_invalid_exception_caught(&self.context, node, *element);
}
Type::unknown()
}
}));

builder.build()
} else if node_ty.is_subtype_of(self.db(), KnownClass::Tuple.to_instance(self.db())) {
todo_type!("Homogeneous tuple in exception handler")
Expand Down Expand Up @@ -3363,9 +3362,7 @@ impl<'db> TypeInferenceBuilder<'db> {
bindings: _,
errors,
} => {
// TODO: Remove the `Vec::from` call once we use the Rust 2024 edition
// which adds `Box<[T]>::into_iter`
if let Some(first) = Vec::from(errors).into_iter().next() {
if let Some(first) = IntoIterator::into_iter(errors).next() {
report_call_error(context, first, call_expression);
} else {
debug_assert!(
Expand Down
2 changes: 1 addition & 1 deletion crates/red_knot_python_semantic/src/types/unpacker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ impl<'db> Unpacker<'db> {
// SAFETY: `target_types` is initialized with the same length as `elts`.
let element_ty = match target_types[index].as_slice() {
[] => Type::unknown(),
types => UnionType::from_elements(self.db(), types),
types => UnionType::from_elements(self.db(), types.iter().copied()),
};
self.unpack_inner(element, value_expr, element_ty);
}
Expand Down

0 comments on commit fe308d6

Please sign in to comment.