Skip to content

Commit

Permalink
Allows images to be callable
Browse files Browse the repository at this point in the history
  • Loading branch information
ianhi committed Oct 27, 2021
1 parent f3a2e2f commit da46f0e
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 10 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pip install mpl-image-labeller
- Smart interactions with default Matplotlib keymap

![gif of usage for labelling images of cats and dogs](example.gif)

## Usage

```python
Expand All @@ -35,9 +36,15 @@ labeller = image_labeller(
plt.show()
```

**accessing the axis**
You can further modify the image (e.g. add masks over them) by using the plotting methods on
axis object accessible by `labeller.ax`.

**Lazy Loading Images**
If you want to lazy load your images you can provide a function to give the images. This function should take
the integer `idx` as an argument and return the image that corresponds to that index. If you do this then you
must also provide `N_images` in the constructor to let the object know how many images it should expect. See `examples/lazy_loading.py` for an example.

### Controls

- `<-` move one image back
Expand Down
19 changes: 19 additions & 0 deletions examples/lazy_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# You can lazy load images by providing a function instead of a list for *images*
# if you do this then you must also provide *N_images* in the labeller constructor


from mpl_image_labeller import image_labeller
import matplotlib.pyplot as plt

from numpy.random import default_rng


def lazy_image_generator(idx):
rng = default_rng(idx)
return rng.random((10, 10))


labeller = image_labeller(
lazy_image_generator, classes=["cool", "rad", "lame"], N_images=57
)
plt.show()
39 changes: 29 additions & 10 deletions mpl_image_labeller/_labeller.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
init_labels=None,
label_keymap: Union[List[str], str] = "1234",
labelling_advances_image: bool = True,
N_images=None,
fig: Figure = None,
):
"""
Expand All @@ -42,15 +43,34 @@ def __init__(
longer perform savefig.
labelling_advances_image : bool, default: True
Whether labelling an image should advance to the next image.
N_images : int or None
The number of images. Required if passing a Callable for images, otherwise
ignored.
fig : Figure
An empty figure to build the UI in. Use this to embed image_labeller into
a gui framework.
"""
self._images = images
if callable(images):
if not isinstance(N_images, int):
raise TypeError(
"If images is a callable then N_images must be provided"
)
self._N_images = N_images
def _get_image(i):
return self._images(i)
else:
self._N_images = len(images)
def _get_image(i):
return self._images[i]

self._get_image = _get_image

self._label_advances = labelling_advances_image

if init_labels is None:
self._labels = [None] * len(images)
elif len(init_labels) != len(images):
self._labels = [None] * self._N_images
elif len(init_labels) != self._N_images:
raise ValueError("init_labels must have the same length as images")
else:
self._labels = init_labels
Expand Down Expand Up @@ -91,7 +111,7 @@ def __init__(

self._image_index = 0
self._ax = self._fig.add_subplot(111)
self._im = self._ax.imshow(images[0])
self._im = self._ax.imshow(self._get_image(0))

# shift axis to make room for list of keybindings
box = self._ax.get_position()
Expand Down Expand Up @@ -144,7 +164,7 @@ def labels(self):

@labels.setter
def labels(self, value):
if len(value) != len(self._images):
if len(value) != self._N_images:
raise ValueError(
"Length of labels must be the same as the number of images"
)
Expand All @@ -156,15 +176,14 @@ def image_index(self):

@image_index.setter
def image_index(self, value):
N = len(self._images)
if value == self._image_index:
# quick return to avoid unnecessary draw
return
elif value >= N:
if self._image_index == N - 1:
elif value >= self._N_images:
if self._image_index == self._N_images - 1:
# quick return to avoid unnecessary draw
return
self._image_index = N - 1
self._image_index = self._N_images - 1
elif value < 0:
if self._image_index == 0:
# quick return to avoid unnecessary draw
Expand All @@ -180,7 +199,7 @@ def _update_title(self):
)

def _update_displayed(self):
self._im.set_data(self._images[self._image_index])
self._im.set_data(self._get_image(self._image_index))
self._update_title()
self._fig.canvas.draw_idle()

Expand All @@ -194,7 +213,7 @@ def _key_press(self, event):
self._label_keymap[event.key]
]
if self._label_advances:
if self.image_index == len(self._images) - 1:
if self.image_index == self._N_images - 1:
# make sure we update the title we are on the last image
self._update_title()
self._fig.canvas.draw_idle()
Expand Down

0 comments on commit da46f0e

Please sign in to comment.