diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 0cdbdcb815..7b1d24270b 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -69,6 +69,9 @@ def __init__(self, data, affine=None, axes=None, title=None): self._title = title self._closed = False self._cross = True + self._overlay = None + self._threshold = None + self._alpha = 1 data = np.asanyarray(data) if data.ndim < 3: @@ -286,6 +289,111 @@ def clim(self, clim): self._clim = tuple(clim) self.draw() + @property + def overlay(self): + """The current overlay """ + return self._overlay + + @property + def threshold(self): + """The current data display threshold """ + return self._threshold + + @threshold.setter + def threshold(self, threshold): + # mask data array + if threshold is not None: + self._data = np.ma.masked_array(np.asarray(self._data), + np.asarray(self._data) <= threshold) + self._threshold = float(threshold) + else: + self._data = np.asarray(self._data) + self._threshold = threshold + + # update current volume data w/masked array and re-draw everything + if self._data.ndim > 3: + self._current_vol_data = self._data[..., self._data_idx[3]] + else: + self._current_vol_data = self._data + self._set_position(None, None, None, notify=False) + + @property + def alpha(self): + """ The current alpha (transparency) value """ + return self._alpha + + @alpha.setter + def alpha(self, alpha): + alpha = float(alpha) + if alpha > 1 or alpha < 0: + raise ValueError('alpha must be between 0 and 1') + for im in self._ims: + im.set_alpha(alpha) + self._alpha = alpha + self.draw() + + def set_overlay(self, data, affine=None, threshold=None, cmap='viridis'): + if affine is None: + try: # did we get an image? + affine = data.affine + data = data.dataobj + except AttributeError: + pass + + # check that we have sufficient information to match the overlays + if affine is None and data.shape[:3] != self._data.shape[:3]: + raise ValueError('Provided `data` do not match shape of ' + 'underlay and no `affine` matrix was ' + 'provided. Please provide an `affine` matrix ' + 'or resample first three dims of `data` to {}' + .format(self._data.shape[:3])) + + # we need to resample the provided data to the already-plotted data + if not np.allclose(affine, self._affine): + from .processing import resample_from_to + from .nifti1 import Nifti1Image + target_shape = self._data.shape[:3] + data.shape[3:] + # we can't just use SpatialImage because we need an image type + # where the spatial axes are _always_ first + data = resample_from_to(Nifti1Image(data, affine), + (target_shape, self._affine)).dataobj + affine = self._affine + + if self._overlay is not None: + # remove all images + cross hair lines + for nn, im in enumerate(self._overlay._ims): + im.remove() + for line in self._overlay._crosshairs[nn].values(): + line.remove() + # remove the fourth axis, if it was created for the overlay + if (self._overlay.n_volumes > 1 and len(self._overlay._axes) > 3 + and self.n_volumes == 1): + a = self._axes.pop(-1) + a.remove() + + # create an axis if we have a 4D overlay (vs a 3D underlay) + axes = self._axes + o_n_volumes = int(np.prod(data.shape[3:])) + if o_n_volumes > self.n_volumes: + axes += [axes[0].figure.add_subplot(224)] + elif o_n_volumes < self.n_volumes: + axes = axes[:-1] + + # mask array for provided threshold + self._overlay = self.__class__(data, affine=affine, axes=axes) + self._overlay.threshold = threshold + + # set transparency and new cmap + self._overlay.cmap = cmap + for im in self._overlay._ims: + im.set_alpha(0.7) + + # no double cross-hairs (they get confused when we have linked orthos) + for cross in self._overlay._crosshairs: + cross['horiz'].set_visible(False) + cross['vert'].set_visible(False) + self._overlay._draw() + def link_to(self, other): """Link positional changes between two canvases @@ -413,7 +521,7 @@ def _set_position(self, x, y, z, notify=True): idx = [slice(None)] * len(self._axes) for ii in range(3): idx[self._order[ii]] = self._data_idx[ii] - vdata = self._data[tuple(idx)].ravel() + vdata = np.asarray(self._data[tuple(idx)].ravel()) vdata = np.concatenate((vdata, [vdata[-1]])) self._volume_ax_objs['patch'].set_x(self._data_idx[3] - 0.5) self._volume_ax_objs['step'].set_ydata(vdata)