* Add DepthMerger host node.
* Improvements.
* Tests for DepthMerger.

Co-authored-by: klemen1999 <[email protected]>
1 parent 95d97be · commit 9edfea8 · 9 changed files with 951 additions and 0 deletions.
@@ -0,0 +1,3 @@
from .depth_merger import DepthMerger

__all__ = ["DepthMerger"]
@@ -0,0 +1,146 @@
from typing import Union

import depthai as dai

from depthai_nodes.ml.messages import ImgDetectionExtended, ImgDetectionsExtended

from .host_spatials_calc import HostSpatialsCalc


class DepthMerger(dai.node.HostNode):
    """DepthMerger is a custom host node for merging 2D detections with depth
    information to produce spatial detections.

    Attributes
    ----------
    output : dai.Node.Output
        The output of the DepthMerger node, containing dai.SpatialImgDetections.
    shrinking_factor : float
        The fraction of the bounding box to shrink from each side before sampling
        depth; 0 means no shrinking.

    Usage
    -----
    depth_merger = pipeline.create(DepthMerger).build(
        output_2d=nn.out,
        output_depth=stereo.depth,
        calib_data=device.readCalibration()
    )
    """

    def __init__(self, shrinking_factor: float = 0) -> None:
        super().__init__()

        self.output = self.createOutput(
            possibleDatatypes=[
                dai.Node.DatatypeHierarchy(dai.DatatypeEnum.SpatialImgDetections, True)
            ]
        )

        self.shrinking_factor = shrinking_factor

    def build(
        self,
        output_2d: dai.Node.Output,
        output_depth: dai.Node.Output,
        calib_data: dai.CalibrationHandler,
        depth_alignment_socket: dai.CameraBoardSocket = dai.CameraBoardSocket.CAM_A,
        shrinking_factor: float = 0,
    ) -> "DepthMerger":
        self.link_args(output_2d, output_depth)
        self.shrinking_factor = shrinking_factor
        self.host_spatials_calc = HostSpatialsCalc(calib_data, depth_alignment_socket)
        return self

    def process(self, message_2d: dai.Buffer, depth: dai.ImgFrame) -> None:
        spatial_dets = self._transform(message_2d, depth)
        self.output.send(spatial_dets)

    def _transform(
        self, message_2d: dai.Buffer, depth: dai.ImgFrame
    ) -> Union[dai.SpatialImgDetections, dai.SpatialImgDetection]:
        """Transforms 2D detections into spatial detections based on the depth frame."""
        if isinstance(message_2d, dai.ImgDetection):
            return self._detection_to_spatial(message_2d, depth)
        elif isinstance(message_2d, dai.ImgDetections):
            return self._detections_to_spatial(message_2d, depth)
        elif isinstance(message_2d, ImgDetectionExtended):
            return self._detection_to_spatial(message_2d, depth)
        elif isinstance(message_2d, ImgDetectionsExtended):
            return self._detections_to_spatial(message_2d, depth)
        else:
            raise ValueError(f"Unknown message type: {type(message_2d)}")

    def _detection_to_spatial(
        self,
        detection: Union[dai.ImgDetection, ImgDetectionExtended],
        depth: dai.ImgFrame,
    ) -> dai.SpatialImgDetection:
        """Converts a single 2D detection into a spatial detection using the depth
        frame."""
        depth_frame = depth.getCvFrame()
        x_len = depth_frame.shape[1]
        y_len = depth_frame.shape[0]
        xmin = (
            detection.rotated_rect.getOuterRect()[0]
            if isinstance(detection, ImgDetectionExtended)
            else detection.xmin
        )
        ymin = (
            detection.rotated_rect.getOuterRect()[1]
            if isinstance(detection, ImgDetectionExtended)
            else detection.ymin
        )
        xmax = (
            detection.rotated_rect.getOuterRect()[2]
            if isinstance(detection, ImgDetectionExtended)
            else detection.xmax
        )
        ymax = (
            detection.rotated_rect.getOuterRect()[3]
            if isinstance(detection, ImgDetectionExtended)
            else detection.ymax
        )
        # Shrink the box symmetrically: compute width and height first so the later
        # updates do not reuse coordinates that have already been shrunk.
        width = xmax - xmin
        height = ymax - ymin
        xmin += width * self.shrinking_factor
        ymin += height * self.shrinking_factor
        xmax -= width * self.shrinking_factor
        ymax -= height * self.shrinking_factor
        roi = [
            self._get_index(xmin, x_len),
            self._get_index(ymin, y_len),
            self._get_index(xmax, x_len),
            self._get_index(ymax, y_len),
        ]
        spatials = self.host_spatials_calc.calc_spatials(depth, roi)

        spatial_img_detection = dai.SpatialImgDetection()
        spatial_img_detection.xmin = xmin
        spatial_img_detection.ymin = ymin
        spatial_img_detection.xmax = xmax
        spatial_img_detection.ymax = ymax
        spatial_img_detection.spatialCoordinates = dai.Point3f(
            spatials["x"], spatials["y"], spatials["z"]
        )

        spatial_img_detection.confidence = detection.confidence
        spatial_img_detection.label = 0 if detection.label == -1 else detection.label
        return spatial_img_detection

    def _detections_to_spatial(
        self,
        detections: Union[dai.ImgDetections, ImgDetectionsExtended],
        depth: dai.ImgFrame,
    ) -> dai.SpatialImgDetections:
        """Converts multiple 2D detections into spatial detections using the depth
        frame."""
        new_dets = dai.SpatialImgDetections()
        new_dets.detections = [
            self._detection_to_spatial(d, depth) for d in detections.detections
        ]
        new_dets.setSequenceNum(detections.getSequenceNum())
        new_dets.setTimestamp(detections.getTimestamp())
        return new_dets

    def _get_index(self, relative_coord: float, dimension_len: int) -> int:
        """Converts a relative coordinate to an absolute index within the given
        dimension length."""
        bounded_coord = min(1, relative_coord)
        return max(0, int(bounded_coord * dimension_len) - 1)
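For context, the sketch below shows how the new node might be wired into a pipeline. It is a minimal sketch, not a definitive recipe: the import path for DepthMerger is assumed, and the `pipeline`, `device`, detection-network (`nn`) and StereoDepth (`stereo`) objects are assumed to exist already, as in the docstring's usage example.

import depthai as dai

# Import path is an assumption; adjust it to wherever the package exposes DepthMerger.
from depthai_nodes import DepthMerger

# Assumed to be set up elsewhere: `pipeline` (dai.Pipeline), `device` (dai.Device),
# `nn` (a detection network node) and `stereo` (a StereoDepth node) whose depth
# output is aligned to CAM_A.
calib_data = device.readCalibration()  # dai.CalibrationHandler for the connected device

depth_merger = pipeline.create(DepthMerger).build(
    output_2d=nn.out,            # 2D detections (dai.ImgDetections / ImgDetectionsExtended)
    output_depth=stereo.depth,   # depth frames (dai.ImgFrame)
    calib_data=calib_data,
    depth_alignment_socket=dai.CameraBoardSocket.CAM_A,
    shrinking_factor=0.1,        # shrink each box by 10% per side before averaging depth
)

# depth_merger.output now emits dai.SpatialImgDetections and can be linked to a
# downstream node or consumed through an output queue.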
@@ -0,0 +1,164 @@
from typing import Dict, List

import depthai as dai
import numpy as np


class HostSpatialsCalc:
    """HostSpatialsCalc is a helper class for calculating spatial coordinates from
    depth data.

    Attributes
    ----------
    calibData : dai.CalibrationHandler
        Calibration data handler for the device.
    depth_alignment_socket : dai.CameraBoardSocket
        The camera socket used for depth alignment.
    delta : int
        The delta value for ROI calculation. Default is 5, i.e. a 10x10 pixel ROI
        around the point is used for depth averaging.
    thresh_low : int
        The lower threshold for depth values, in millimetres. Default is 200 (20 cm).
    thresh_high : int
        The upper threshold for depth values, in millimetres. Default is 30000 (30 m).
    """

    # Calibration data is needed to obtain the camera intrinsics.
    def __init__(
        self,
        calib_data: dai.CalibrationHandler,
        depth_alignment_socket: dai.CameraBoardSocket = dai.CameraBoardSocket.CAM_A,
        delta: int = 5,
        thresh_low: int = 200,
        thresh_high: int = 30000,
    ):
        self.calibData = calib_data
        self.depth_alignment_socket = depth_alignment_socket

        self.delta = delta
        self.thresh_low = thresh_low
        self.thresh_high = thresh_high

    def setLowerThreshold(self, threshold_low: int) -> None:
        """Sets the lower threshold for depth values.
        @param threshold_low: The lower threshold for depth values.
        @type threshold_low: int
        """
        if not isinstance(threshold_low, int):
            if isinstance(threshold_low, float):
                threshold_low = int(threshold_low)
            else:
                raise TypeError(
                    "Threshold has to be an integer or float! Got {}".format(
                        type(threshold_low)
                    )
                )
        self.thresh_low = threshold_low

    def setUpperThreshold(self, threshold_high: int) -> None:
        """Sets the upper threshold for depth values.
        @param threshold_high: The upper threshold for depth values.
        @type threshold_high: int
        """
        if not isinstance(threshold_high, int):
            if isinstance(threshold_high, float):
                threshold_high = int(threshold_high)
            else:
                raise TypeError(
                    "Threshold has to be an integer or float! Got {}".format(
                        type(threshold_high)
                    )
                )
        self.thresh_high = threshold_high

    def setDeltaRoi(self, delta: int) -> None:
        """Sets the delta value for ROI calculation.
        @param delta: The delta value for ROI calculation.
        @type delta: int
        """
        if not isinstance(delta, int):
            if isinstance(delta, float):
                delta = int(delta)
            else:
                raise TypeError(
                    "Delta has to be an integer or float! Got {}".format(type(delta))
                )
        self.delta = delta

    def _check_input(self, roi: List[int], frame: np.ndarray) -> List[int]:
        """Checks if the input is ROI or point and converts point to ROI if necessary.
        @param roi: The region of interest (ROI) or point.
        @type roi: List[int]
        @param frame: The depth frame.
        @type frame: np.ndarray
        @return: The region of interest (ROI).
        @rtype: List[int]
        """
        if len(roi) == 4:
            return roi
        if len(roi) != 2:
            raise ValueError(
                "You have to pass either ROI (4 values) or point (2 values)!"
            )
        # Limit the point so ROI won't be outside the frame
        x = min(max(roi[0], self.delta), frame.shape[1] - self.delta)
        y = min(max(roi[1], self.delta), frame.shape[0] - self.delta)
        return (x - self.delta, y - self.delta, x + self.delta, y + self.delta)

    # roi has to be a list of ints
    def calc_spatials(
        self,
        depthData: dai.ImgFrame,
        roi: List[int],
        averaging_method: callable = np.mean,
    ) -> Dict[str, float]:
        """Calculates spatial coordinates from depth data within the specified ROI.
        @param depthData: The depth data.
        @type depthData: dai.ImgFrame
        @param roi: The region of interest (ROI) or point.
        @type roi: List[int]
        @param averaging_method: The method for averaging the depth values.
        @type averaging_method: callable
        @return: The spatial coordinates.
        @rtype: Dict[str, float]
        """
        depthFrame = depthData.getFrame()

        roi = self._check_input(
            roi, depthFrame
        )  # If a point was passed, convert it to a ROI
        xmin, ymin, xmax, ymax = roi

        # Calculate the average depth in the ROI.
        depthROI = depthFrame[ymin:ymax, xmin:xmax]
        inRange = (self.thresh_low <= depthROI) & (depthROI <= self.thresh_high)

        averageDepth = averaging_method(depthROI[inRange])

        centroid = np.array(  # Get the centroid of the ROI
            [
                int((xmax + xmin) / 2),
                int((ymax + ymin) / 2),
            ]
        )

        K = self.calibData.getCameraIntrinsics(
            cameraId=self.depth_alignment_socket,
            resizeWidth=depthFrame.shape[1],
            resizeHeight=depthFrame.shape[0],
        )
        K = np.array(K)
        K_inv = np.linalg.inv(K)
        homogenous_coords = np.array([centroid[0], centroid[1], 1])
        spatial_coords = averageDepth * K_inv.dot(homogenous_coords)

        spatials = {
            "x": spatial_coords[0],
            "y": spatial_coords[1],
            "z": spatial_coords[2],
        }
        return spatials
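To illustrate the back-projection that calc_spatials performs, here is a small, self-contained numerical sketch. The intrinsics matrix, averaged ROI depth, and centroid below are made-up illustrative values, not ones read from a real device: the spatial coordinates are obtained as Z * K^-1 * [u, v, 1]^T, where Z is the averaged ROI depth and (u, v) is the ROI centroid.

import numpy as np

# Hypothetical intrinsics for a 640x400 depth frame: focal length ~450 px,
# principal point at the image centre. Real values come from getCameraIntrinsics().
K = np.array([
    [450.0,   0.0, 320.0],
    [  0.0, 450.0, 200.0],
    [  0.0,   0.0,   1.0],
])

average_depth = 1500.0  # averaged depth inside the ROI, in millimetres
u, v = 400, 250         # ROI centroid in pixel coordinates

# Back-project the centroid through the inverse intrinsics and scale by depth.
homogeneous = np.array([u, v, 1.0])
spatial = average_depth * np.linalg.inv(K).dot(homogeneous)

print({"x": spatial[0], "y": spatial[1], "z": spatial[2]})
# ≈ {'x': 266.7, 'y': 166.7, 'z': 1500.0}  (x right, y down in image convention, z forward; in mm)

HostSpatialsCalc performs the same computation, except that the intrinsics are read from the device calibration and the depth is averaged only over ROI pixels that fall between thresh_low and thresh_high.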