[WIP] Add SHIFT reference evaluators #45

Draft · wants to merge 3 commits into main
3 changes: 3 additions & 0 deletions examples/mmdet_dataset.py
@@ -144,6 +144,9 @@ def prepare_train_img(self, idx):
img = self.get_img(idx)
img_info = self.get_img_info(idx)
ann_info = self.get_ann_info(idx)
# Filter out images without annotations during training
if len(ann_info["bboxes"]) == 0:
return None
results = dict(img=img, img_info=img_info, ann_info=ann_info)
self.pre_pipeline(results)
return self.pipeline(results)
3 changes: 2 additions & 1 deletion examples/torch_dataset.py
@@ -31,7 +31,8 @@ def main():
Keys.segmentation_masks,
],
views_to_load=["front"],
backend=ZipBackend(), # also supports HDF5Backend(), FileBackend()
shift_type="discrete", # also supports "continuous/1x", "continuous/10x", "continuous/100x"
backend=ZipBackend(), # also supports HDF5Backend(), FileBackend()
verbose=True,
)

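For context, a minimal sketch of how the updated example constructs the dataset with the new shift_type option. The exact import locations for SHIFTDataset, Keys, and ZipBackend are assumed from the devkit's examples, and the data_root/split values are placeholders rather than part of this diff:

import numpy as np  # only to keep the sketch self-contained if samples are inspected

from shift_dev import SHIFTDataset
from shift_dev.types import Keys
from shift_dev.utils.backend import ZipBackend

dataset = SHIFTDataset(
    data_root="./SHIFT_dataset/",            # placeholder path (assumption)
    split="train",                            # placeholder split (assumption)
    keys_to_load=[Keys.segmentation_masks],   # key shown in the example diff above
    views_to_load=["front"],
    shift_type="discrete",                    # or "continuous/1x", "continuous/10x", "continuous/100x"
    backend=ZipBackend(),                     # also supports HDF5Backend(), FileBackend()
    verbose=True,
)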
42 changes: 32 additions & 10 deletions shift_dev/dataloader/shift_dataset.py
@@ -56,6 +56,8 @@ def __init__(
data_file: str = "",
annotation_file: str = "",
view: str = "front",
framerate: str = "images",
shift_type: str = "discrete",
backend: DataBackend = HDF5Backend(),
verbose: bool = False,
num_workers: int = 1,
@@ -80,14 +82,22 @@ def __init__(
assert view in _SHIFTScalabelLabels.VIEWS, f"Invalid view '{view}'"

# Set attributes
annotation_path = os.path.join(
data_root, "discrete", "images", split, view, annotation_file
)
ext = _get_extension(backend)
data_path = os.path.join(
data_root, "discrete", "images", split, view, f"{data_file}{ext}"
)

if shift_type.startswith("continuous"):
shift_speed = shift_type.split("/")[-1]
annotation_path = os.path.join(
data_root, "continuous", framerate, shift_speed, split, view, annotation_file
)
data_path = os.path.join(
data_root, "continuous", framerate, shift_speed, split, view, f"{data_file}{ext}"
)
else:
annotation_path = os.path.join(
data_root, "discrete", framerate, split, view, annotation_file
)
data_path = os.path.join(
data_root, "discrete", framerate, split, view, f"{data_file}{ext}"
)
super().__init__(data_path, annotation_path, data_backend=backend, **kwargs)

def _generate_mapping(self) -> ScalabelData:
@@ -261,9 +271,17 @@ def __init__(
self.backend = backend
self.verbose = verbose
self.ext = _get_extension(backend)
self.annotation_base = os.path.join(
self.data_root, self.shift_type, self.framerate, self.split
)
if self.shift_type.startswith("continuous"):
shift_speed = self.shift_type.split("/")[-1]
self.annotation_base = os.path.join(
self.data_root, "continuous", self.framerate, shift_speed, self.split
)
else:
self.annotation_base = os.path.join(
self.data_root, self.shift_type, self.framerate, self.split
)
if self.verbose:
logger.info(f"Base: {self.annotation_base}. Backend: {self.backend}")

# Get the data groups' classes that need to be loaded
self._data_groups_to_load = self._get_data_groups(keys_to_load)
@@ -283,6 +301,8 @@ def __init__(
data_file="lidar",
annotation_file="det_3d.json",
view=view,
framerate=self.framerate,
shift_type=self.shift_type,
keys_to_load=(Keys.points3d, *self.DATA_GROUPS["det_3d"]),
backend=backend,
num_workers=num_workers,
@@ -304,6 +324,8 @@ def __init__(
data_file="img",
annotation_file=f"{group}.json",
view=view,
framerate=self.framerate,
shift_type=self.shift_type,
keys_to_load=keys_to_load,
backend=backend,
num_workers=num_workers,
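For reference, the two branches above resolve the annotation path to layouts of the following form (illustrative values: split "train", view "front", framerate "images", annotation_file "det_3d.json", and shift_type "continuous/1x" for the continuous case):

discrete:   <data_root>/discrete/images/train/front/det_3d.json
continuous: <data_root>/continuous/images/1x/train/front/det_3d.json

The data archive is resolved the same way from f"{data_file}{ext}", with the extension determined by the chosen backend.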
6 changes: 6 additions & 0 deletions shift_dev/evaluator/__init__.py
@@ -0,0 +1,6 @@
from .depth import DepthEvaluator
from .flow import OpticalFlowEvaluator
from .semseg import SemanticSegmentationEvaluator


__all__ = ["DepthEvaluator", "OpticalFlowEvaluator", "SemanticSegmentationEvaluator"]
32 changes: 32 additions & 0 deletions shift_dev/evaluator/base.py
@@ -0,0 +1,32 @@
"""SHIFT base evaluation."""
from __future__ import annotations

from typing import Any

import numpy as np


class Evaluator:
"""Abstract evaluator class."""

METRICS: list[str] = []

def __init__(self) -> None:
"""Initialize evaluator."""
self.reset()

def reset(self) -> None:
"""Reset evaluator for new round of evaluation."""
self.metrics = {metric: [] for metric in self.METRICS}

def process(self, *args: Any) -> None: # type: ignore
"""Process a batch of data."""
raise NotImplementedError

def evaluate(self) -> dict[str, float]:
"""Evaluate all predictions according to given metric.

Returns:
dict[str, float]: Evaluation results.
"""
return {metric: np.nanmean(values) for metric, values in self.metrics.items()}
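To illustrate the intended contract (reset() creates one empty list per metric, process() appends a value per batch, and evaluate() averages them), here is a minimal hypothetical subclass; the MeanErrorEvaluator name and its single metric are made up for illustration:

import numpy as np

from shift_dev.evaluator.base import Evaluator


class MeanErrorEvaluator(Evaluator):
    """Toy evaluator tracking the mean absolute error per batch."""

    METRICS = ["mae"]

    def process(self, prediction: np.ndarray, target: np.ndarray) -> None:
        # Append one scalar per processed batch; evaluate() averages over batches.
        self.metrics["mae"].append(np.mean(np.abs(prediction - target)))


evaluator = MeanErrorEvaluator()
evaluator.process(np.ones(4), np.zeros(4))
evaluator.process(np.full(4, 3.0), np.zeros(4))
print(evaluator.evaluate())  # averages the two batches -> mae = 2.0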
65 changes: 65 additions & 0 deletions shift_dev/evaluator/depth.py
@@ -0,0 +1,65 @@
"""Depth estimation evaluator."""

from __future__ import annotations

from .base import Evaluator

import numpy as np


class DepthEvaluator(Evaluator):

METRICS = ["mae", "silog"]

def __init__(self, min_depth: float = 0.5, max_depth: float = 80.0) -> None:
"""Initialize the depth evaluator."""
self.min_depth = min_depth
self.max_depth = max_depth
super().__init__()

def mean_absolute_error(self, pred, target):
"""Compute the mean absolute error.

Args:
pred (np.array): Prediction depth map, in shape (H, W).
target (np.array): Target depth map, in shape (H, W).

Returns:
float: Mean absolute error.
"""
mask = (target > self.min_depth) & (target < self.max_depth)
return np.mean(np.abs(pred[mask] - target[mask]))

def silog(self, pred, target, eps=1e-6):
"""Compute the scale-invariant log error of KITTI.

Args:
pred (np.array): Prediction depth map, in shape (H, W).
target (np.array): Target depth map, in shape (H, W).
eps (float, optional): Epsilon. Defaults to 1e-6.

Returns:
float: Silog error.
"""
mask = (target > self.min_depth) & (target < self.max_depth)
log_diff = np.log(target[mask] + eps) - np.log(pred[mask] + eps)
# KITTI SILog subtracts the squared mean of the log difference for scale invariance.
return np.sqrt(np.mean(log_diff ** 2) - np.mean(log_diff) ** 2)

def process(self, prediction: np.array, target: np.array) -> None:
"""Process a batch of data.

Args:
prediction (np.array): Prediction depth map.
target (np.array): Target depth map.
"""
mae = self.mean_absolute_error(prediction, target)
silog = self.silog(prediction, target)
self.metrics.update({"mae": mae, "silog": silog})

def evaluate(self) -> dict[str, float]:
"""Evaluate all predictions according to given metric.

Returns:
dict[str, float]: Evaluation results.
"""
return {metric: np.nanmean(self.metrics[metric]) for metric in self.metrics}
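A short usage sketch for DepthEvaluator on synthetic depth maps; the array values are illustrative only, and depths outside the [min_depth, max_depth] range are ignored by both metrics:

import numpy as np

from shift_dev.evaluator import DepthEvaluator

evaluator = DepthEvaluator(min_depth=0.5, max_depth=80.0)
target = np.full((128, 128), 10.0)  # ground-truth depth in meters
prediction = target + np.random.default_rng(0).normal(0.0, 0.5, target.shape)
evaluator.process(prediction, target)
print(evaluator.evaluate())  # mae and silog averaged over all processed frames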
113 changes: 113 additions & 0 deletions shift_dev/evaluator/detection.py
@@ -0,0 +1,113 @@
"""Detection evaluator."""

from __future__ import annotations

from .base import Evaluator

import numpy as np
import pycocotools.mask as maskUtils
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval


class DetectionEvaluator(Evaluator):

METRICS = ["mAP", "mAP_50", "mAP_75", "mAP_s", "mAP_m", "mAP_l"]

def __init__(
self,
num_classes: int = 5,
iou_thresholds: list[float] = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95],
bbox_format: str = "xywh",
iou_type: str = "bbox",
) -> None:
"""Initialize the detection evaluator."""
self.num_classes = num_classes
self.iou_thresholds = iou_thresholds
self.bbox_format = bbox_format
self.iou_type = iou_type
assert self.bbox_format in ["xywh", "xyxy"], "Invalid bbox format."
assert self.iou_type in ["bbox", "segm"], "Invalid iou type."
super().__init__()

def reset(self) -> None:
    """Reset evaluator for new round of evaluation."""
    self._predictions = []
    self._targets = []
    self._num_images = 0

def process(self, prediction: dict[str, np.ndarray], target: dict[str, np.ndarray]) -> None:
    """Process the predictions and targets of a single image.

    Args:
        prediction (dict[str, np.ndarray]): Prediction data dictionary.
        target (dict[str, np.ndarray]): Target data dictionary.

    Note:
        The prediction dictionary should contain the following keys:
            - bbox: Bounding box array of shape (N, 4) in the format [x, y, w, h].
            - class_probs: Class probability array of shape (N, num_classes).
            - mask: Binary instance masks of shape (N, H, W), with 0 for background
              and 1 for foreground (only needed for iou_type="segm").
        The target dictionary should contain bbox of shape (M, 4), category_id of
        shape (M,) and, optionally, mask with the same conventions.
    """
    # Convert [x1, y1, x2, y2] boxes to the COCO [x, y, w, h] format.
    if self.bbox_format == "xyxy":
        prediction["bbox"][:, 2] -= prediction["bbox"][:, 0]
        prediction["bbox"][:, 3] -= prediction["bbox"][:, 1]
        target["bbox"][:, 2] -= target["bbox"][:, 0]
        target["bbox"][:, 3] -= target["bbox"][:, 1]

    image_id = self._num_images
    self._num_images += 1

    # Flatten the per-image arrays into COCO-style per-detection records.
    class_predicted = np.argmax(prediction["class_probs"], axis=-1)
    class_score = np.max(prediction["class_probs"], axis=-1)
    for i in range(len(prediction["bbox"])):
        pred = {
            "image_id": image_id,
            "bbox": [float(v) for v in prediction["bbox"][i]],
            "category_id": int(class_predicted[i]),
            "score": float(class_score[i]),
        }
        if "mask" in prediction:
            pred["segmentation"] = maskUtils.encode(
                np.asfortranarray(prediction["mask"][i], dtype=np.uint8)
            )
        self._predictions.append(pred)

    for i in range(len(target["bbox"])):
        x, y, w, h = [float(v) for v in target["bbox"][i]]
        tgt = {
            "id": len(self._targets) + 1,
            "image_id": image_id,
            "category_id": int(target["category_id"][i]),
            "bbox": [x, y, w, h],
            "area": w * h,
            "iscrowd": 0,
        }
        if "mask" in target:
            tgt["segmentation"] = maskUtils.encode(
                np.asfortranarray(target["mask"][i], dtype=np.uint8)
            )
        self._targets.append(tgt)

def evaluate(self) -> dict[str, float]:
    """Evaluate all predictions according to given metric.

    Returns:
        dict[str, float]: Evaluation results.
    """
    # Assemble a minimal COCO-style ground-truth dataset from the accumulated targets.
    coco_gt = COCO()
    coco_gt.dataset = {
        "images": [{"id": i} for i in range(self._num_images)],
        "annotations": self._targets,
        "categories": [{"id": i} for i in range(self.num_classes)],
    }
    coco_gt.createIndex()
    # Register the accumulated predictions as COCO detection results.
    coco_pred = coco_gt.loadRes(self._predictions)

    coco_eval = COCOeval(coco_gt, coco_pred, self.iou_type)
    coco_eval.params.iouThrs = np.array(self.iou_thresholds)
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

    return {
        "mAP": coco_eval.stats[0],
        "mAP_50": coco_eval.stats[1],
        "mAP_75": coco_eval.stats[2],
        "mAP_s": coco_eval.stats[3],
        "mAP_m": coco_eval.stats[4],
        "mAP_l": coco_eval.stats[5],
    }
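A minimal usage sketch for the detection evaluator, assuming one process() call per image with per-detection arrays as documented above; the boxes, class probabilities, and the direct module import (the class is not yet exported from the package __init__) are illustrative assumptions:

import numpy as np

from shift_dev.evaluator.detection import DetectionEvaluator

evaluator = DetectionEvaluator(num_classes=5, bbox_format="xywh", iou_type="bbox")
for _ in range(4):  # pretend four images were evaluated
    target = {
        "bbox": np.array([[10.0, 10.0, 50.0, 40.0]]),  # one ground-truth box
        "category_id": np.array([2]),
    }
    prediction = {
        "bbox": np.array([[12.0, 11.0, 48.0, 41.0]]),  # one detection, IoU ~0.89
        "class_probs": np.array([[0.05, 0.05, 0.80, 0.05, 0.05]]),
    }
    evaluator.process(prediction, target)
print(evaluator.evaluate())  # COCO-style mAP numbers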
41 changes: 41 additions & 0 deletions shift_dev/evaluator/flow.py
@@ -0,0 +1,41 @@
"""Optical flow estimation evaluator."""

from __future__ import annotations

from .base import Evaluator

import numpy as np


class OpticalFlowEvaluator(Evaluator):

METRICS = ["epe"]

def __init__(self, max_flow: float = 400.0) -> None:
"""Initialize the optical flow evaluator."""
self.max_flow = max_flow
super().__init__()

def end_point_error(self, pred, target):
"""Compute the end point error.

Args:
pred (np.array): Prediction optical flow, in shape (H, W, 2).
target (np.array): Target optical flow, in shape (H, W, 2).

Returns:
float: End point error.
"""
mask = np.sum(np.abs(target), axis=2) < self.max_flow
return np.mean(np.sqrt(np.sum((pred[mask] - target[mask]) ** 2, axis=1)))

def process(self, prediction: np.array, target: np.array) -> None:
"""Process a batch of data.

Args:
prediction (np.array): Prediction optical flow, in shape (H, W, 2).
target (np.array): Target optical flow, in shape (H, W, 2).
"""
epe = self.end_point_error(prediction, target)
self.metrics.update({"epe": epe})

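Similarly, a small usage sketch for OpticalFlowEvaluator with synthetic flow fields; the values are illustrative only:

import numpy as np

from shift_dev.evaluator import OpticalFlowEvaluator

evaluator = OpticalFlowEvaluator(max_flow=400.0)
target = np.ones((64, 64, 2), dtype=np.float32)  # ground-truth flow (u, v) per pixel
prediction = target + 0.5                        # constant 0.5 px offset in both components
evaluator.process(prediction, target)
print(evaluator.evaluate())  # epe = sqrt(0.5**2 + 0.5**2) ~ 0.71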