diff --git a/mpl_image_labeller/_util.py b/mpl_image_labeller/_util.py index 1541926..19b9e44 100644 --- a/mpl_image_labeller/_util.py +++ b/mpl_image_labeller/_util.py @@ -52,7 +52,9 @@ def list_to_onehot(labels, classes): arr = np.zeros((len(labels), len(classes)), dtype=bool) for i, l in enumerate(labels): - if isinstance(l, str) or not isinstance(l, Iterable): + if l is None: + continue + elif isinstance(l, str) or not isinstance(l, Iterable): # str, or number, or something like that arr[i, lookup[l]] = True else: