Skip to content

Commit

Permalink
feat first typing attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
nmathieufact committed Mar 28, 2024
1 parent d3f0030 commit c532eee
Showing 1 changed file with 56 additions and 39 deletions.
95 changes: 56 additions & 39 deletions bytetracker/byte_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
from bytetracker.kalman_filter import KalmanFilter


def xywh2xyxy(x):
def xywh2xyxy(x: np.ndarray) -> np.ndarray:
"""
converts bounding boxes from [x, y, w, h] format to [x1, y1, x2, y2] format
param x(n x 4 array): array at [x, y, w, h] format
return: y(n x 4 array): array [x1, y1, x2, y2] format
Converts bounding boxes from [x, y, w, h] format to [x1, y1, x2, y2] format
Parameters
----------
x: Array at [x, y, w, h] format
Returns
-------
y: Array [x1, y1, x2, y2] format
"""
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
y = np.copy(x)
Expand All @@ -22,13 +25,17 @@ def xywh2xyxy(x):
return y


def xyxy2xywh(x):
def xyxy2xywh(x: np.ndarray) -> np.ndarray:
"""
converts bounding boxes from [x1, y1, x2, y2] format to [x, y, w, h] format
Converts bounding boxes from [x1, y1, x2, y2] format to [x, y, w, h] format
param x(n x 4 array): array at [x1, y1, x2, y2] format
Parameters
----------
x: Array at [x1, y1, x2, y2] format
return: y(n x 4 array): array [x, y, w, h] format
Returns
-------
y: Array at [x, y, w, h] format
"""
y = np.copy(x)
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
Expand Down Expand Up @@ -62,12 +69,13 @@ def predict(self):
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

@staticmethod
def multi_predict(stracks):
def multi_predict(stracks: list["STrack"]) -> None:
"""
takes a list of tracks, updates their mean and covariance values, and
performs a Kalman filter prediction step.
:param
Parameters
----------
stracks (list): list of STrack objects
"""
if len(stracks) > 0:
Expand All @@ -83,15 +91,15 @@ def multi_predict(stracks):
stracks[i].mean = mean
stracks[i].covariance = cov

def activate(self, kalman_filter, frame_id):
def activate(self, kalman_filter: KalmanFilter, frame_id: int) -> None:
"""
initializes a new tracklet with a Kalman filter and assigns a track ID and
state based on the frame ID.
:param
Parameters
----------
kalman_filter: Kalman filter object
frame_id (int): The `frame_id` parameter in the `activate` method represents the identifier of the
frame in which the tracklet is being activated.
frame_id (int): The `frame_id` parameter in the `activate` method.
"""
"""Start a new tracklet"""
self.kalman_filter = kalman_filter
Expand All @@ -105,15 +113,19 @@ def activate(self, kalman_filter, frame_id):
self.frame_id = frame_id
self.start_frame = frame_id

def re_activate(self, new_track, frame_id, new_id=False):
def re_activate(self, new_track: "STrack", frame_id: int, new_id: bool = False) -> None:
"""
updates a track using Kalman filtering and sets various attributes based
Updates a track using Kalman filtering and sets various attributes based
on input parameters.
:param
kalman_filter: Kalman filter object
frame_id (int): The `frame_id` parameter in the `activate` method represents the identifier of the
frame in which the tracklet is being re_activated.
Parameters
----------
new_track : STrack
The new track object to update.
frame_id : int
The frame ID.
new_id : bool, optional
Whether to assign a new ID to the track, by default False.
"""
self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
Expand All @@ -126,13 +138,16 @@ def re_activate(self, new_track, frame_id, new_id=False):
self.score = new_track.score
self.cls = new_track.cls

def update(self, new_track, frame_id):
def update(self, new_track: "STrack", frame_id: int) -> None:
"""
Update a matched track
:type new_track: STrack
:type frame_id: int
:type update_feature: bool
:return:
Update a matched track.
Parameters
----------
new_track : STrack
The new track object to update.
frame_id : int
The frame ID.
"""
self.frame_id = frame_id
self.cls = new_track.cls
Expand Down Expand Up @@ -371,20 +386,22 @@ def update(self, dets, frame_id):
return outputs


def joint_stracks(tlista, tlistb):
def joint_stracks(tlista: list["STrack"], tlistb: list["STrack"]) -> list["STrack"]:
"""
merges two lists of objects based on a specific attribute while
Merges two lists of objects based on a specific attribute while
ensuring no duplicates are added.
:param tlista: It seems like you were about to provide some information about the `tlista`
parameter, but the message got cut off. Could you please provide more details or complete the
information so that I can assist you further?
:param tlistb: It seems like you have not provided the `tlistb` parameter for the `joint_stracks`
function. Could you please provide the `tlistb` parameter so that I can assist you further with the
function?
:return: The function `joint_stracks` returns a list that contains all unique elements from both
input lists `tlista` and `tlistb`. Duplicate elements are removed based on the `track_id` attribute
of the elements.
Parameters
----------
tlista : List[STrack]
list of STrack objects.
tlistb : List[STrack]
list of STrack objects.
Returns
-------
List[STrack]
A list containing all unique elements from both input lists.
"""
exists = {}
res = []
Expand Down

0 comments on commit c532eee

Please sign in to comment.