Skip to content

Commit

Permalink
feat: typing functions and adding some docstring (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
nmathieufact authored Apr 17, 2024
1 parent 4d8a20d commit 4fb1460
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 31 deletions.
158 changes: 138 additions & 20 deletions bytetracker/byte_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,17 @@
from bytetracker.kalman_filter import KalmanFilter


def xywh2xyxy(x):
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
def xywh2xyxy(x: np.ndarray):
"""
Converts bounding boxes from [x, y, w, h] format to [x1, y1, x2, y2] format
Parameters
----------
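x: Array of boxes in [x, y, w, h] format, where (x, y) is the box center
Returns
-------
y: Array of boxes in [x1, y1, x2, y2] format, where (x1, y1) is the top-left and (x2, y2) the bottom-right corner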
"""
y = np.copy(x)
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
Expand All @@ -15,8 +24,18 @@ def xywh2xyxy(x):
return y


def xyxy2xywh(x):
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
def xyxy2xywh(x: np.ndarray):
"""
Converts bounding boxes from [x1, y1, x2, y2] format to [x, y, w, h] format
Parameters
----------
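x: Array of boxes in [x1, y1, x2, y2] format, where (x1, y1) is the top-left and (x2, y2) the bottom-right corner
Returns
-------
y: Array of boxes in [x, y, w, h] format, where (x, y) is the box center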
"""
y = np.copy(x)
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
Expand All @@ -39,13 +58,25 @@ def __init__(self, tlwh, score, cls):
self.cls = cls

def predict(self):
"""
Runs a Kalman filter prediction step and stores the resulting mean and covariance.
If the track is not in the Tracked state, element 7 of the mean is zeroed before predicting.
"""
mean_state = self.mean.copy()
if self.state != TrackState.Tracked:
mean_state[7] = 0
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

@staticmethod
def multi_predict(stracks):
def multi_predict(stracks: list["STrack"]):
"""
takes a list of tracks, updates their mean and covariance values, and
performs a Kalman filter prediction step.
Parameters
----------
stracks (list): list of STrack objects
"""
if len(stracks) > 0:
multi_mean = np.asarray([st.mean.copy() for st in stracks])
multi_covariance = np.asarray([st.covariance for st in stracks])
Expand All @@ -59,8 +90,16 @@ def multi_predict(stracks):
stracks[i].mean = mean
stracks[i].covariance = cov

def activate(self, kalman_filter, frame_id):
"""Start a new tracklet"""
def activate(self, kalman_filter: KalmanFilter, frame_id: int):
"""
initializes a new tracklet with a Kalman filter and assigns a track ID and
state based on the frame ID.
Parameters
----------
kalman_filter: Kalman filter object
frame_id (int): ID of the frame in which the track is activated; stored as both the track's `frame_id` and `start_frame`.
"""
self.kalman_filter = kalman_filter
self.track_id = self.next_id()
self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))
Expand All @@ -72,7 +111,19 @@ def activate(self, kalman_filter, frame_id):
self.frame_id = frame_id
self.start_frame = frame_id

def re_activate(self, new_track, frame_id, new_id=False):
def re_activate(self, new_track: "STrack", frame_id: int, new_id: bool = False):
"""
Updates a track using Kalman filtering
Parameters
----------
new_track : STrack
The new track object to update.
frame_id : int
The frame ID.
new_id : bool
Whether to assign a new ID to the track, by default False.
"""
self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
)
Expand All @@ -84,13 +135,16 @@ def re_activate(self, new_track, frame_id, new_id=False):
self.score = new_track.score
self.cls = new_track.cls

def update(self, new_track, frame_id):
def update(self, new_track: "STrack", frame_id: int):
"""
Update a matched track
:type new_track: STrack
:type frame_id: int
:type update_feature: bool
:return:
Update a matched track.
Parameters
----------
new_track : STrack
The new track object to update.
frame_id : int
The frame ID.
"""
self.frame_id = frame_id
self.cls = new_track.cls
Expand Down Expand Up @@ -120,7 +174,8 @@ def tlwh(self):
@property
# @jit(nopython=True)
def tlbr(self):
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
"""
Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`.
"""
ret = self.tlwh.copy()
Expand All @@ -130,7 +185,8 @@ def tlbr(self):
@staticmethod
# @jit(nopython=True)
def tlwh_to_xyah(tlwh):
"""Convert bounding box to format `(center x, center y, aspect ratio,
"""
Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`.
"""
ret = np.asarray(tlwh).copy()
Expand Down Expand Up @@ -167,7 +223,23 @@ def reset(self):
self.kalman_filter = KalmanFilter()
BaseTrack._count = 0

def update(self, dets, frame_id):
def update(self, dets: np.ndarray, frame_id: int):
"""
Performs object tracking by associating detections with existing tracks and updating their states accordingly.
Parameters
----------
dets : np.ndarray
Detection boxes of objects in the format (n x 6), where each row contains (x1, y1, x2, y2, score, class).
frame_id : int
The ID of the current frame in the video.
Returns
-------
np.ndarray
An array of outputs containing bounding box coordinates, track ID, class label, and
score for each tracked object.
"""
self.frame_id = frame_id
activated_starcks = []
refind_stracks = []
Expand Down Expand Up @@ -318,7 +390,23 @@ def update(self, dets, frame_id):
return outputs


def joint_stracks(tlista, tlistb):
def joint_stracks(tlista: list["STrack"], tlistb: list["STrack"]):
"""
Merges two lists of objects based on a specific attribute while
ensuring no duplicates are added.
Parameters
----------
tlista : List[STrack]
list of STrack objects.
tlistb : List[STrack]
list of STrack objects.
Returns
-------
List[STrack]
A list containing all unique elements from both input lists.
"""
exists = {}
res = []
for t in tlista:
Expand All @@ -332,7 +420,22 @@ def joint_stracks(tlista, tlistb):
return res


def sub_stracks(tlista, tlistb):
def sub_stracks(tlista: list["STrack"], tlistb: list["STrack"]):
"""
Returns a list of STrack objects that are present in tlista but not in tlistb.
Parameters
----------
tlista : List[STrack]
list of STrack objects.
tlistb : List[STrack]
list of STrack objects.
Returns
-------
List[STrack]
A list containing STrack objects present in tlista but not in tlistb.
"""
stracks = {}
for t in tlista:
stracks[t.track_id] = t
Expand All @@ -343,7 +446,22 @@ def sub_stracks(tlista, tlistb):
return list(stracks.values())


def remove_duplicate_stracks(stracksa, stracksb):
def remove_duplicate_stracks(stracksa: list["STrack"], stracksb: list["STrack"]):
"""
Removes duplicate STrack objects from the input lists: track pairs whose IoU distance is below 0.15 are treated as duplicates (which one is dropped depends on their frame IDs).
Parameters
----------
stracksa : List[STrack]
list of STrack objects.
stracksb : List[STrack]
list of STrack objects.
Returns
-------
Tuple[List[STrack], List[STrack]]
Two lists containing unique STrack objects after removing duplicates.
"""
pdist = matching.iou_distance(stracksa, stracksb)
pairs = np.where(pdist < 0.15)
dupa, dupb = list(), list()
Expand Down
8 changes: 4 additions & 4 deletions bytetracker/kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def initiate(self, measurement):
covariance = np.diag(np.square(std))
return mean, covariance

def predict(self, mean, covariance):
def predict(self, mean: np.ndarray, covariance: np.ndarray):
"""Run Kalman filter prediction step.
Parameters
Expand Down Expand Up @@ -109,7 +109,7 @@ def predict(self, mean, covariance):

return mean, covariance

def project(self, mean, covariance):
def project(self, mean: np.ndarray, covariance: np.ndarray):
"""Project state distribution to measurement space.
Parameters
Expand Down Expand Up @@ -138,7 +138,7 @@ def project(self, mean, covariance):
covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
return mean, covariance + innovation_cov

def multi_predict(self, mean, covariance):
def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):
"""Run Kalman filter prediction step (Vectorized version).
Parameters
----------
Expand Down Expand Up @@ -179,7 +179,7 @@ def multi_predict(self, mean, covariance):

return mean, covariance

def update(self, mean, covariance, measurement):
def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray):
"""Run Kalman filter correction step.
Parameters
Expand Down
51 changes: 44 additions & 7 deletions bytetracker/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,21 @@
import numpy as np


def linear_assignment(cost_matrix, thresh):
def linear_assignment(cost_matrix: np.ndarray, thresh: float):
"""
Assigns detections to existing tracks based on a cost matrix using linear assignment.
Parameters
----------
cost_matrix : np.ndarray
The cost matrix representing the association cost between detections and tracks.
thresh : float
The threshold for cost matching.
Returns
-------
Tuple containing matches, unmatched detections, and unmatched tracks.
"""
if cost_matrix.size == 0:
return (
np.empty((0, 2), dtype=int),
Expand All @@ -22,11 +36,19 @@ def linear_assignment(cost_matrix, thresh):

def ious(atlbrs, btlbrs):
"""
Compute cost based on IoU
:type atlbrs: list[tlbr] | np.ndarray
:type atlbrs: list[tlbr] | np.ndarray
Compute Intersection over Union (IoU) between bounding box pairs
Parameters
----------
atlbrs : Union[list, np.ndarray]
The bounding boxes of the first set in (min x, min y, max x, max y) format.
btlbrs : Union[list, np.ndarray]
The bounding boxes of the second set in (min x, min y, max x, max y) format.
:rtype ious np.ndarray
Returns
-------
np.ndarray
An array containing IoU values for each pair of bounding boxes.
"""
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
if ious.size == 0:
Expand Down Expand Up @@ -63,7 +85,22 @@ def iou_distance(atracks, btracks):
return cost_matrix


def fuse_score(cost_matrix, detections):
def fuse_score(cost_matrix: np.ndarray, detections: np.ndarray):
"""
Fuse detection scores with similarity scores from a cost matrix.
Parameters
----------
cost_matrix : np.ndarray
The cost matrix representing the dissimilarity between tracks and detections.
detections : np.ndarray
The array of detections, each containing a score.
Returns
-------
np.ndarray
The fused cost matrix, incorporating both similarity scores and detection scores.
"""
if cost_matrix.size == 0:
return cost_matrix
iou_sim = 1 - cost_matrix
Expand All @@ -74,7 +111,7 @@ def fuse_score(cost_matrix, detections):
return fuse_cost


def bbox_ious(boxes, query_boxes):
def bbox_ious(boxes: np.ndarray, query_boxes: np.ndarray):
"""
Parameters
----------
Expand Down

0 comments on commit 4fb1460

Please sign in to comment.