From e4934749b8c94bec743467f7c0e26384eacbd9a4 Mon Sep 17 00:00:00 2001 From: Steve Canny Date: Tue, 30 Apr 2024 23:13:57 -0700 Subject: [PATCH] fix: XmlPart._rel_ref_count `.rel_ref_count()` as implemented was only applicable to `XmlPart` where references to a related part could be present in the XML. Longer term it probably makes sense to override `Part.drop_rel()` in `XmlPart` and not have a `_rel_ref_count()` method in `part` at all, but this works and is less potentially disruptive. --- src/docx/opc/part.py | 19 +++++++++++++------ tests/opc/test_part.py | 39 +++++++++++++++++++++++++-------------- 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/src/docx/opc/part.py b/src/docx/opc/part.py index 142f49dd1..1353bb850 100644 --- a/src/docx/opc/part.py +++ b/src/docx/opc/part.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Dict, Type +from typing import TYPE_CHECKING, Callable, Dict, Type, cast from docx.opc.oxml import serialize_part_xml from docx.opc.packuri import PackURI @@ -149,11 +149,12 @@ def target_ref(self, rId: str) -> str: rel = self.rels[rId] return rel.target_ref - def _rel_ref_count(self, rId): - """Return the count of references in this part's XML to the relationship - identified by `rId`.""" - rIds = self._element.xpath("//@r:id") - return len([_rId for _rId in rIds if _rId == rId]) + def _rel_ref_count(self, rId: str) -> int: + """Return the count of references in this part to the relationship identified by `rId`. + + Only an XML part can contain references, so this is 0 for `Part`. + """ + return 0 class PartFactory: @@ -231,3 +232,9 @@ def part(self): That chain of delegation ends here for child objects. """ return self + + def _rel_ref_count(self, rId: str) -> int: + """Return the count of references in this part's XML to the relationship + identified by `rId`.""" + rIds = cast("list[str]", self._element.xpath("//@r:id")) + return len([_rId for _rId in rIds if _rId == rId]) diff --git a/tests/opc/test_part.py b/tests/opc/test_part.py index 03eacd361..b156a63f8 100644 --- a/tests/opc/test_part.py +++ b/tests/opc/test_part.py @@ -169,24 +169,13 @@ def it_can_establish_an_external_relationship(self, rels_prop_: Mock, rels_: Moc rels_.get_or_add_ext_rel.assert_called_once_with("http://rel/type", "https://hyper/link") assert rId == "rId27" - @pytest.mark.parametrize( - ("part_cxml", "rel_should_be_dropped"), - [ - ("w:p", True), - ("w:p/r:a{r:id=rId42}", True), - ("w:p/r:a{r:id=rId42}/r:b{r:id=rId42}", False), - ], - ) - def it_can_drop_a_relationship( - self, part_cxml: str, rel_should_be_dropped: bool, rels_prop_: Mock - ): + def it_can_drop_a_relationship(self, rels_prop_: Mock): rels_prop_.return_value = {"rId42": None} - part = Part("partname", "content_type") - part._element = element(part_cxml) # pyright: ignore[reportAttributeAccessIssue] + part = Part(PackURI("/partname"), "content_type") part.drop_rel("rId42") - assert ("rId42" not in part.rels) is rel_should_be_dropped + assert "rId42" not in part.rels def it_can_find_a_related_part_by_reltype( self, rels_prop_: Mock, rels_: Mock, other_part_: Mock @@ -411,6 +400,24 @@ def it_knows_its_the_part_for_its_child_objects(self, part_fixture): xml_part = part_fixture assert xml_part.part is xml_part + @pytest.mark.parametrize( + ("part_cxml", "rel_should_be_dropped"), + [ + ("w:p", True), + ("w:p/r:a{r:id=rId42}", True), + ("w:p/r:a{r:id=rId42}/r:b{r:id=rId42}", False), + ], + ) + def it_only_drops_a_relationship_with_zero_reference_count( + self, part_cxml: str, rel_should_be_dropped: bool, rels_prop_: Mock, package_: Mock + ): + rels_prop_.return_value = {"rId42": None} + part = XmlPart(PackURI("/partname"), "content_type", element(part_cxml), package_) + + part.drop_rel("rId42") + + assert ("rId42" not in part.rels) is rel_should_be_dropped + # fixtures ------------------------------------------------------- @pytest.fixture @@ -452,6 +459,10 @@ def parse_xml_(self, request, element_): def partname_(self, request): return instance_mock(request, PackURI) + @pytest.fixture + def rels_prop_(self, request: FixtureRequest): + return property_mock(request, XmlPart, "rels") + @pytest.fixture def serialize_part_xml_(self, request): return function_mock(request, "docx.opc.part.serialize_part_xml")