Skip to content

Commit

Permalink
Merge pull request #290 from roboflow/feature/yolov9
Browse files Browse the repository at this point in the history
Feature/yolov9
  • Loading branch information
paulguerrie authored Feb 27, 2024
2 parents 851f538 + b899024 commit 3efd23d
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 0 deletions.
1 change: 1 addition & 0 deletions inference/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@
YOLOv8KeypointsDetection,
YOLOv8ObjectDetection,
)
from inference.models.yolov9 import YOLOv9ObjectDetection
2 changes: 2 additions & 0 deletions inference/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
YOLOv8Classification,
YOLOv8InstanceSegmentation,
YOLOv8ObjectDetection,
YOLOv9ObjectDetection,
)
from inference.models.yolov8.yolov8_keypoints_detection import YOLOv8KeypointsDetection

Expand All @@ -37,6 +38,7 @@
("object-detection", "yolov5v6m"): YOLOv5ObjectDetection,
("object-detection", "yolov5v6l"): YOLOv5ObjectDetection,
("object-detection", "yolov5v6x"): YOLOv5ObjectDetection,
("object-detection", "yolov9"): YOLOv9ObjectDetection,
("object-detection", "yolov8"): YOLOv8ObjectDetection,
("object-detection", "yolov8s"): YOLOv8ObjectDetection,
("object-detection", "yolov8n"): YOLOv8ObjectDetection,
Expand Down
1 change: 1 addition & 0 deletions inference/models/yolov9/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from inference.models.yolov9.yolov9_object_detection import YOLOv9ObjectDetection
45 changes: 45 additions & 0 deletions inference/models/yolov9/yolov9_object_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Tuple

import numpy as np

from inference.core.models.object_detection_base import (
ObjectDetectionBaseOnnxRoboflowInferenceModel,
)


class YOLOv9ObjectDetection(ObjectDetectionBaseOnnxRoboflowInferenceModel):
"""Roboflow ONNX Object detection model (Implements an object detection specific infer method).
This class is responsible for performing object detection using the YOLOv9 model
with ONNX runtime.
Attributes:
weights_file (str): Path to the ONNX weights file.
"""

@property
def weights_file(self) -> str:
"""Gets the weights file for the YOLOv9 model.
Returns:
str: Path to the ONNX weights file.
"""
return "weights.onnx"

def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
"""Performs object detection on the given image using the ONNX session.
Args:
img_in (np.ndarray): Input image as a NumPy array.
Returns:
Tuple[np.ndarray]: NumPy array representing the predictions.
"""
# (b x 8 x 8000)
predictions = self.onnx_session.run(None, {self.input_name: img_in})[0]
predictions = predictions.transpose(0, 2, 1)
boxes = predictions[:, :, :4]
class_confs = predictions[:, :, 4:]
confs = np.expand_dims(np.max(class_confs, axis=2), axis=2)
predictions = np.concatenate([boxes, confs, class_confs], axis=2)
return (predictions,)

0 comments on commit 3efd23d

Please sign in to comment.