Skip to content

Commit

Permalink
Fix pylint-dev#2628 by adding ignore_duplicate parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
AleksMat committed Oct 28, 2024
1 parent e380fd1 commit 77e60bb
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 8 deletions.
4 changes: 2 additions & 2 deletions astroid/brain/brain_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _find_arguments_from_base_classes(
# See TODO down below
# all_have_defaults = True

for base in reversed(node.mro()):
for base in reversed(node.mro(ignore_duplicates=True)):
if not base.is_dataclass:
continue
try:
Expand Down Expand Up @@ -221,7 +221,7 @@ def _parse_arguments_into_strings(

def _get_previous_field_default(node: nodes.ClassDef, name: str) -> nodes.NodeNG | None:
"""Get the default value of a previously defined field."""
for base in reversed(node.mro()):
for base in reversed(node.mro(ignore_duplicates=True)):
if not base.is_dataclass:
continue
if name in base.locals:
Expand Down
21 changes: 15 additions & 6 deletions astroid/nodes/scoped_nodes/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,13 @@ def clean_duplicates_mro(
sequences: list[list[ClassDef]],
cls: ClassDef,
context: InferenceContext | None,
ignore_duplicates: bool,
) -> list[list[ClassDef]]:
for sequence in sequences:
seen = set()
for node in sequence:
lineno_and_qname = (node.lineno, node.qname())
if lineno_and_qname in seen:
if lineno_and_qname in seen and not ignore_duplicates:
raise DuplicateBasesError(
message="Duplicates found in MROs {mros} for {cls!r}.",
mros=sequences,
Expand Down Expand Up @@ -2834,7 +2835,9 @@ def _inferred_bases(self, context: InferenceContext | None = None):
else:
yield from baseobj.bases

def _compute_mro(self, context: InferenceContext | None = None):
def _compute_mro(
self, context: InferenceContext | None = None, ignore_duplicates: bool = False
):
if self.qname() == "builtins.object":
return [self]

Expand All @@ -2844,23 +2847,29 @@ def _compute_mro(self, context: InferenceContext | None = None):
if base is self:
continue

mro = base._compute_mro(context=context)
mro = base._compute_mro(
context=context, ignore_duplicates=ignore_duplicates
)
bases_mro.append(mro)

unmerged_mro: list[list[ClassDef]] = [[self], *bases_mro, inferred_bases]
unmerged_mro = clean_duplicates_mro(unmerged_mro, self, context)
unmerged_mro = clean_duplicates_mro(
unmerged_mro, self, context, ignore_duplicates=ignore_duplicates
)
clean_typing_generic_mro(unmerged_mro)
return _c3_merge(unmerged_mro, self, context)

def mro(self, context: InferenceContext | None = None) -> list[ClassDef]:
def mro(
self, context: InferenceContext | None = None, ignore_duplicates: bool = False
) -> list[ClassDef]:
"""Get the method resolution order, using C3 linearization.
:returns: The list of ancestors, sorted by the mro.
:rtype: list(NodeNG)
:raises DuplicateBasesError: Duplicate bases in the same class base
:raises InconsistentMroError: A class' MRO is inconsistent
"""
return self._compute_mro(context=context)
return self._compute_mro(context=context, ignore_duplicates=ignore_duplicates)

def bool_value(self, context: InferenceContext | None = None) -> Literal[True]:
"""Determine the boolean value of this node.
Expand Down
30 changes: 30 additions & 0 deletions tests/brain/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,3 +1350,33 @@ def attr(self, value: int) -> None:
fourth_init: bases.UnboundMethod = next(fourth.infer())
assert [a.name for a in fourth_init.args.args] == ["self", "other_attr", "attr"]
assert [a.name for a in fourth_init.args.defaults] == ["Uninferable"]


@parametrize_module
def test_dataclass_inherited_from_multiple_protocol_bases(module: str):
code = astroid.extract_node(
f"""
from {module} import dataclass
from typing import TypeVar, Protocol
BaseT = TypeVar("BaseT")
T = TypeVar("T", bound=BaseT)
class A(Protocol[BaseT]):
pass
class B(A[T], Protocol[T]):
pass
@dataclass
class Dataclass(B[T]):
pass
"""
)
inferred = code.inferred()
assert len(inferred) == 1
assert isinstance(inferred[0], nodes.ClassDef)
assert inferred[0].is_dataclass
3 changes: 3 additions & 0 deletions tests/test_scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1938,6 +1938,7 @@ class A(Generic[T1], Generic[T2]): ...
assert isinstance(cls, nodes.ClassDef)
with self.assertRaises(DuplicateBasesError):
cls.mro()
assert len(cls.mro(ignore_duplicates=True)) == 3

def test_mro_generic_error_2(self):
cls = builder.extract_node(
Expand All @@ -1951,6 +1952,8 @@ class B(A[T], A[T]): ...
assert isinstance(cls, nodes.ClassDef)
with self.assertRaises(DuplicateBasesError):
cls.mro()
with self.assertRaises(InconsistentMroError):
cls.mro(ignore_duplicates=True)

def test_mro_typing_extensions(self):
"""Regression test for mro() inference on typing_extensions.
Expand Down

0 comments on commit 77e60bb

Please sign in to comment.