Skip to content

Commit

Permalink
add label based indexing to pass sequence (#154)
Browse files Browse the repository at this point in the history
* add label based indexing to pass sequence

* add ipython completion support
  • Loading branch information
axtimhaus authored Jan 12, 2024
1 parent bf911dc commit 1384c53
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 4 deletions.
27 changes: 23 additions & 4 deletions pyroll/core/sequence/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,31 @@ def __iter__(self):
return self._subunits.__iter__()

@overload
def __getitem__(self, index: int) -> Unit:
def __getitem__(self, key: int) -> Unit:
"""Gets unit item by index."""
...

@overload
def __getitem__(self, index: slice) -> list[Unit]:
def __getitem__(self, key: str) -> Unit:
"""Gets unit item by label."""
...

def __getitem__(self, index: int) -> Unit:
return self._subunits.__getitem__(index)
@overload
def __getitem__(self, key: slice) -> list[Unit]:
"""Gets a slice of units."""
...

def __getitem__(self, key):
if isinstance(key, str):
try:
return next(u for u in self._subunits if u.label == key)
except StopIteration:
raise KeyError(f"No unit with label '{key}' found.")

if isinstance(key, int) or isinstance(key, slice):
return self._subunits.__getitem__(key)

raise TypeError("Key must be int, slice or str")

@property
def units(self) -> List[Unit]:
Expand All @@ -88,3 +104,6 @@ def __attrs__(self):
return super().__attrs__ | {
"units": self.units
}

def _ipython_key_completions_(self):
return [u.label for u in self._subunits]
63 changes: 63 additions & 0 deletions tests/test_pass_sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from pyroll.core import PassSequence, RollPass, CircularOvalGroove, Transport, RoundGroove, Roll
import pytest


# noinspection DuplicatedCode
def test_pass_sequence_indexing():
sequence = PassSequence([
RollPass(
label="Oval I",
roll=Roll(
groove=CircularOvalGroove(
depth=8e-3,
r1=6e-3,
r2=40e-3
),
nominal_radius=160e-3,
rotational_frequency=1,
neutral_point=-20e-3
),
gap=2e-3,

),
Transport(
label="I => II",
duration=1,
),
RollPass(
label="Round II",
roll=Roll(
groove=RoundGroove(
r1=1e-3,
r2=12.5e-3,
depth=11.5e-3
),
nominal_radius=160e-3,
rotational_frequency=1
),
gap=2e-3,
),
Transport(
label="II => III",
duration=1
),
RollPass(
label="Oval III",
roll=Roll(
groove=CircularOvalGroove(
depth=6e-3,
r1=6e-3,
r2=35e-3
),
nominal_radius=160e-3,
rotational_frequency=1
),
gap=2e-3,
),
])

assert sequence["Oval III"] == sequence[4]
assert sequence["I => II"] == sequence[1]

with pytest.raises(KeyError):
_ = sequence["not present"]
1 change: 1 addition & 0 deletions tests/test_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def flow_stress(self: RollPass.Profile):
return 50e6 * (1 + self.strain) ** 0.2 * self.roll_pass.strain_rate ** 0.1


# noinspection DuplicatedCode
def test_solve(tmp_path: Path, caplog):
caplog.set_level(logging.DEBUG, logger="pyroll")

Expand Down

0 comments on commit 1384c53

Please sign in to comment.