diff --git a/bytetracker/byte_tracker.py b/bytetracker/byte_tracker.py index 5e0f8f0..f9c0d3c 100644 --- a/bytetracker/byte_tracker.py +++ b/bytetracker/byte_tracker.py @@ -5,8 +5,17 @@ from bytetracker.kalman_filter import KalmanFilter -def xywh2xyxy(x): - # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right +def xywh2xyxy(x: np.ndarray): + """ + Converts bounding boxes from [x, y, w, h] format to [x1, y1, x2, y2] format + + Parameters + ---------- + x: Array at [x, y, w, h] format + Returns + ------- + y: Array [x1, y1, x2, y2] format + """ y = np.copy(x) y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y @@ -15,8 +24,18 @@ def xywh2xyxy(x): return y -def xyxy2xywh(x): - # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right +def xyxy2xywh(x: np.ndarray): + """ + Converts bounding boxes from [x1, y1, x2, y2] format to [x, y, w, h] format + + Parameters + ---------- + x: Array at [x1, y1, x2, y2] format + + Returns + ------- + y: Array at [x, y, w, h] format + """ y = np.copy(x) y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center @@ -39,13 +58,25 @@ def __init__(self, tlwh, score, cls): self.cls = cls def predict(self): + """ + updates the mean and covariance using a Kalman filter prediction, with a condition + based on the state of the track. + """ mean_state = self.mean.copy() if self.state != TrackState.Tracked: mean_state[7] = 0 self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) @staticmethod - def multi_predict(stracks): + def multi_predict(stracks: list["STrack"]): + """ + takes a list of tracks, updates their mean and covariance values, and + performs a Kalman filter prediction step. 
+ + Parameters + ---------- + stracks (list): list of STrack objects + """ if len(stracks) > 0: multi_mean = np.asarray([st.mean.copy() for st in stracks]) multi_covariance = np.asarray([st.covariance for st in stracks]) @@ -59,8 +90,16 @@ def multi_predict(stracks): stracks[i].mean = mean stracks[i].covariance = cov - def activate(self, kalman_filter, frame_id): - """Start a new tracklet""" + def activate(self, kalman_filter: KalmanFilter, frame_id: int): + """ + initializes a new tracklet with a Kalman filter and assigns a track ID and + state based on the frame ID. + + Parameters + ---------- + kalman_filter: Kalman filter object + frame_id (int): ID of the frame in which the track is activated; stored as both the track's current frame and its start frame. + """ self.kalman_filter = kalman_filter self.track_id = self.next_id() self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh)) @@ -72,7 +111,19 @@ def activate(self, kalman_filter, frame_id): self.frame_id = frame_id self.start_frame = frame_id - def re_activate(self, new_track, frame_id, new_id=False): + def re_activate(self, new_track: "STrack", frame_id: int, new_id: bool = False): + """ + Updates a track using Kalman filtering + + Parameters + ---------- + new_track : STrack + The new track object to update. + frame_id : int + The frame ID. + new_id : bool + Whether to assign a new ID to the track, by default False. + """ self.mean, self.covariance = self.kalman_filter.update( self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh) ) @@ -84,13 +135,16 @@ def re_activate(self, new_track, frame_id, new_id=False): self.score = new_track.score self.cls = new_track.cls - def update(self, new_track, frame_id): + def update(self, new_track: "STrack", frame_id: int): """ - Update a matched track - :type new_track: STrack - :type frame_id: int - :type update_feature: bool - :return: + Update a matched track. + + Parameters + ---------- + new_track : STrack + The new track object to update. + frame_id : int + The frame ID.
""" self.frame_id = frame_id self.cls = new_track.cls @@ -120,7 +174,8 @@ def tlwh(self): @property # @jit(nopython=True) def tlbr(self): - """Convert bounding box to format `(min x, min y, max x, max y)`, i.e., + """ + Convert bounding box to format `(min x, min y, max x, max y)`, i.e., `(top left, bottom right)`. """ ret = self.tlwh.copy() @@ -130,7 +185,8 @@ def tlbr(self): @staticmethod # @jit(nopython=True) def tlwh_to_xyah(tlwh): - """Convert bounding box to format `(center x, center y, aspect ratio, + """ + Convert bounding box to format `(center x, center y, aspect ratio, height)`, where the aspect ratio is `width / height`. """ ret = np.asarray(tlwh).copy() @@ -167,7 +223,23 @@ def reset(self): self.kalman_filter = KalmanFilter() BaseTrack._count = 0 - def update(self, dets, frame_id): + def update(self, dets: np.ndarray, frame_id: int): + """ + Performs object tracking by associating detections with existing tracks and updating their states accordingly. + + Parameters + ---------- + dets : np.ndarray + Detection boxes of objects in the format (n x 6), where each row contains (x1, y1, x2, y2, score, class). + frame_id : int + The ID of the current frame in the video. + + Returns + ------- + np.ndarray + An array of outputs containing bounding box coordinates, track ID, class label, and + score for each tracked object. + """ self.frame_id = frame_id activated_starcks = [] refind_stracks = [] @@ -318,7 +390,23 @@ def update(self, dets, frame_id): return outputs -def joint_stracks(tlista, tlistb): +def joint_stracks(tlista: list["STrack"], tlistb: list["STrack"]): + """ + Merges two lists of objects based on a specific attribute while + ensuring no duplicates are added. + + Parameters + ---------- + tlista : List[STrack] + list of STrack objects. + tlistb : List[STrack] + list of STrack objects. + + Returns + ------- + List[STrack] + A list containing all unique elements from both input lists. 
+ """ exists = {} res = [] for t in tlista: @@ -332,7 +420,22 @@ def joint_stracks(tlista, tlistb): return res -def sub_stracks(tlista, tlistb): +def sub_stracks(tlista: list["STrack"], tlistb: list["STrack"]): + """ + Returns a list of STrack objects that are present in tlista but not in tlistb. + + Parameters + ---------- + tlista : List[STrack] + list of STrack objects. + tlistb : List[STrack] + list of STrack objects. + + Returns + ------- + List[STrack] + A list containing STrack objects present in tlista but not in tlistb. + """ stracks = {} for t in tlista: stracks[t.track_id] = t @@ -343,7 +446,22 @@ def sub_stracks(tlista, tlistb): return list(stracks.values()) -def remove_duplicate_stracks(stracksa, stracksb): +def remove_duplicate_stracks(stracksa: list["STrack"], stracksb: list["STrack"]): + """ + Removes duplicate STrack objects from the input lists based on their frame IDs. + + Parameters + ---------- + stracksa : List[STrack] + list of STrack objects. + stracksb : List[STrack] + list of STrack objects. + + Returns + ------- + Tuple[List[STrack], List[STrack]] + Two lists containing unique STrack objects after removing duplicates. + """ pdist = matching.iou_distance(stracksa, stracksb) pairs = np.where(pdist < 0.15) dupa, dupb = list(), list() diff --git a/bytetracker/kalman_filter.py b/bytetracker/kalman_filter.py index 2259754..06f920f 100644 --- a/bytetracker/kalman_filter.py +++ b/bytetracker/kalman_filter.py @@ -68,7 +68,7 @@ def initiate(self, measurement): covariance = np.diag(np.square(std)) return mean, covariance - def predict(self, mean, covariance): + def predict(self, mean: np.ndarray, covariance: np.ndarray): """Run Kalman filter prediction step. Parameters @@ -109,7 +109,7 @@ def predict(self, mean, covariance): return mean, covariance - def project(self, mean, covariance): + def project(self, mean: np.ndarray, covariance: np.ndarray): """Project state distribution to measurement space. 
Parameters @@ -138,7 +138,7 @@ def project(self, mean, covariance): covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T)) return mean, covariance + innovation_cov - def multi_predict(self, mean, covariance): + def multi_predict(self, mean: np.ndarray, covariance: np.ndarray): """Run Kalman filter prediction step (Vectorized version). Parameters ---------- @@ -179,7 +179,7 @@ def multi_predict(self, mean, covariance): return mean, covariance - def update(self, mean, covariance, measurement): + def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray): """Run Kalman filter correction step. Parameters diff --git a/bytetracker/matching.py b/bytetracker/matching.py index 2e198d0..753fbc9 100644 --- a/bytetracker/matching.py +++ b/bytetracker/matching.py @@ -2,7 +2,21 @@ import numpy as np -def linear_assignment(cost_matrix, thresh): +def linear_assignment(cost_matrix: np.ndarray, thresh: float): + """ + Assigns detections to existing tracks based on a cost matrix using linear assignment. + + Parameters + ---------- + cost_matrix : np.ndarray + The cost matrix representing the association cost between detections and tracks. + thresh : float + The threshold for cost matching. + + Returns + ------- + Tuple containing matches, unmatched detections, and unmatched tracks. + """ if cost_matrix.size == 0: return ( np.empty((0, 2), dtype=int), @@ -22,11 +36,19 @@ def linear_assignment(cost_matrix, thresh): def ious(atlbrs, btlbrs): """ - Compute cost based on IoU - :type atlbrs: list[tlbr] | np.ndarray - :type atlbrs: list[tlbr] | np.ndarray + Compute Intersection over Union (IoU) between bounding box pairs + + Parameters + ---------- + atlbrs : Union[list, np.ndarray] + The bounding boxes of the first set in (min x, min y, max x, max y) format. + btlbrs : Union[list, np.ndarray] + The bounding boxes of the second set in (min x, min y, max x, max y) format.
- :rtype ious np.ndarray + Returns + ------- + np.ndarray + An array containing IoU values for each pair of bounding boxes. """ ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32) if ious.size == 0: @@ -63,7 +85,22 @@ def iou_distance(atracks, btracks): return cost_matrix -def fuse_score(cost_matrix, detections): +def fuse_score(cost_matrix: np.ndarray, detections: np.ndarray): + """ + Fuse detection scores with similarity scores from a cost matrix. + + Parameters + ---------- + cost_matrix : np.ndarray + The cost matrix representing the dissimilarity between tracks and detections. + detections : np.ndarray + The array of detections, each containing a score. + + Returns + ------- + np.ndarray + The fused cost matrix, incorporating both similarity scores and detection scores. + """ if cost_matrix.size == 0: return cost_matrix iou_sim = 1 - cost_matrix @@ -74,7 +111,7 @@ def fuse_score(cost_matrix, detections): return fuse_cost -def bbox_ious(boxes, query_boxes): +def bbox_ious(boxes: np.ndarray, query_boxes: np.ndarray): """ Parameters ----------