diff --git a/tcn_hpl/data/tcn_dataset.py b/tcn_hpl/data/tcn_dataset.py index 2d0dd9b90..68f8a6d5b 100644 --- a/tcn_hpl/data/tcn_dataset.py +++ b/tcn_hpl/data/tcn_dataset.py @@ -521,7 +521,7 @@ def test_dataset_for_input( # TODO: Some method of configuring which vectorizer to use. from tcn_hpl.data.vectorize.locs_and_confs import LocsAndConfs - num_object_classes = len(dets_coco.cats) + num_object_classes = max(dets_coco.cats) + 1 vectorize = LocsAndConfs( top_k=1, num_classes=num_object_classes,