Skip to content

Commit

Permalink
fix: XmlPart._rel_ref_count
Browse files Browse the repository at this point in the history
`.rel_ref_count()` as implemented was only applicable to `XmlPart` where
references to a related part could be present in the XML. Longer term it
probably makes sense to override `Part.drop_rel()` in `XmlPart` and not
have a `_rel_ref_count()` method in `part` at all, but this works and is
less potentially disruptive.
  • Loading branch information
scanny committed May 1, 2024
1 parent 3f56b7d commit e493474
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 20 deletions.
19 changes: 13 additions & 6 deletions src/docx/opc/part.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Dict, Type
from typing import TYPE_CHECKING, Callable, Dict, Type, cast

from docx.opc.oxml import serialize_part_xml
from docx.opc.packuri import PackURI
Expand Down Expand Up @@ -149,11 +149,12 @@ def target_ref(self, rId: str) -> str:
rel = self.rels[rId]
return rel.target_ref

def _rel_ref_count(self, rId):
"""Return the count of references in this part's XML to the relationship
identified by `rId`."""
rIds = self._element.xpath("//@r:id")
return len([_rId for _rId in rIds if _rId == rId])
def _rel_ref_count(self, rId: str) -> int:
"""Return the count of references in this part to the relationship identified by `rId`.
Only an XML part can contain references, so this is 0 for `Part`.
"""
return 0


class PartFactory:
Expand Down Expand Up @@ -231,3 +232,9 @@ def part(self):
That chain of delegation ends here for child objects.
"""
return self

def _rel_ref_count(self, rId: str) -> int:
"""Return the count of references in this part's XML to the relationship
identified by `rId`."""
rIds = cast("list[str]", self._element.xpath("//@r:id"))
return len([_rId for _rId in rIds if _rId == rId])
39 changes: 25 additions & 14 deletions tests/opc/test_part.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,24 +169,13 @@ def it_can_establish_an_external_relationship(self, rels_prop_: Mock, rels_: Moc
rels_.get_or_add_ext_rel.assert_called_once_with("http://rel/type", "https://hyper/link")
assert rId == "rId27"

@pytest.mark.parametrize(
("part_cxml", "rel_should_be_dropped"),
[
("w:p", True),
("w:p/r:a{r:id=rId42}", True),
("w:p/r:a{r:id=rId42}/r:b{r:id=rId42}", False),
],
)
def it_can_drop_a_relationship(
self, part_cxml: str, rel_should_be_dropped: bool, rels_prop_: Mock
):
def it_can_drop_a_relationship(self, rels_prop_: Mock):
rels_prop_.return_value = {"rId42": None}
part = Part("partname", "content_type")
part._element = element(part_cxml) # pyright: ignore[reportAttributeAccessIssue]
part = Part(PackURI("/partname"), "content_type")

part.drop_rel("rId42")

assert ("rId42" not in part.rels) is rel_should_be_dropped
assert "rId42" not in part.rels

def it_can_find_a_related_part_by_reltype(
self, rels_prop_: Mock, rels_: Mock, other_part_: Mock
Expand Down Expand Up @@ -411,6 +400,24 @@ def it_knows_its_the_part_for_its_child_objects(self, part_fixture):
xml_part = part_fixture
assert xml_part.part is xml_part

@pytest.mark.parametrize(
("part_cxml", "rel_should_be_dropped"),
[
("w:p", True),
("w:p/r:a{r:id=rId42}", True),
("w:p/r:a{r:id=rId42}/r:b{r:id=rId42}", False),
],
)
def it_only_drops_a_relationship_with_zero_reference_count(
self, part_cxml: str, rel_should_be_dropped: bool, rels_prop_: Mock, package_: Mock
):
rels_prop_.return_value = {"rId42": None}
part = XmlPart(PackURI("/partname"), "content_type", element(part_cxml), package_)

part.drop_rel("rId42")

assert ("rId42" not in part.rels) is rel_should_be_dropped

# fixtures -------------------------------------------------------

@pytest.fixture
Expand Down Expand Up @@ -452,6 +459,10 @@ def parse_xml_(self, request, element_):
def partname_(self, request):
return instance_mock(request, PackURI)

@pytest.fixture
def rels_prop_(self, request: FixtureRequest):
return property_mock(request, XmlPart, "rels")

@pytest.fixture
def serialize_part_xml_(self, request):
return function_mock(request, "docx.opc.part.serialize_part_xml")

0 comments on commit e493474

Please sign in to comment.