From 2c2dc7b1f0a586fabc5cc62a08621641e2de5be5 Mon Sep 17 00:00:00 2001 From: roberto-corno-nttdata Date: Sat, 7 Dec 2024 14:22:24 +0100 Subject: [PATCH 01/12] yolov11n and yolov11m added to model selection --- README.md | 2 ++ benchmarks/Evaluate-Results.ipynb | 2 +- deepface/DeepFace.py | 16 +++++------ deepface/commons/weight_utils.py | 17 +++++++++--- deepface/models/face_detection/Yolo.py | 38 ++++++++++++++++++++++---- deepface/modules/demography.py | 2 +- deepface/modules/detection.py | 2 +- deepface/modules/modeling.py | 6 ++-- deepface/modules/recognition.py | 4 +-- deepface/modules/representation.py | 2 +- deepface/modules/streaming.py | 8 +++--- deepface/modules/verification.py | 2 +- tests/visual-test.py | 2 ++ 13 files changed, 72 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index c51652a9f..cbd1dc1d3 100644 --- a/README.md +++ b/README.md @@ -223,6 +223,8 @@ backends = [ 'retinaface', 'mediapipe', 'yolov8', + 'yolov11n', + 'yolov11m', 'yunet', 'centerface', ] diff --git a/benchmarks/Evaluate-Results.ipynb b/benchmarks/Evaluate-Results.ipynb index e2a7172d7..16d29dce5 100644 --- a/benchmarks/Evaluate-Results.ipynb +++ b/benchmarks/Evaluate-Results.ipynb @@ -30,7 +30,7 @@ "source": [ "alignment = [False, True]\n", "models = [\"Facenet512\", \"Facenet\", \"VGG-Face\", \"ArcFace\", \"Dlib\", \"GhostFaceNet\", \"SFace\", \"OpenFace\", \"DeepFace\", \"DeepID\"]\n", - "detectors = [\"retinaface\", \"mtcnn\", \"fastmtcnn\", \"dlib\", \"yolov8\", \"yunet\", \"centerface\", \"mediapipe\", \"ssd\", \"opencv\", \"skip\"]\n", + "detectors = [\"retinaface\", \"mtcnn\", \"fastmtcnn\", \"dlib\", \"yolov8\", \"yolov11n\", \"yolov11m\", \"yunet\", \"centerface\", \"mediapipe\", \"ssd\", \"opencv\", \"skip\"]\n", "distance_metrics = [\"euclidean\", \"euclidean_l2\", \"cosine\"]" ] }, diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index af5245fd9..a95bcc534 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -56,7 +56,7 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any: - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet for face recognition - Age, Gender, Emotion, Race for facial attributes - - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yunet, + - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', 'yolov11m', yunet, fastmtcnn or centerface for face detectors - Fasnet for spoofing task (str): facial_recognition, facial_attribute, face_detector, spoofing @@ -96,7 +96,7 @@ def verify( OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', @@ -187,7 +187,7 @@ def analyze( Set to False to avoid the exception for low-resolution images (default is True). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', @@ -298,7 +298,7 @@ def find( Set to False to avoid the exception for low-resolution images (default is True). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' (default is opencv). align (boolean): Perform alignment based on the eye positions (default is True). @@ -396,7 +396,7 @@ def represent( (default is True). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' (default is opencv). align (boolean): Perform alignment based on the eye positions (default is True). @@ -462,7 +462,7 @@ def stream( OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', @@ -517,7 +517,7 @@ def extract_faces( as a string, numpy array (BGR), or base64 encoded images. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' (default is opencv). enforce_detection (boolean): If no face is detected in an image, raise an exception. @@ -601,7 +601,7 @@ def detectFace( added to resize the image (default is (224, 224)). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' (default is opencv). enforce_detection (boolean): If no face is detected in an image, raise an exception. diff --git a/deepface/commons/weight_utils.py b/deepface/commons/weight_utils.py index d6770c08e..421d85704 100644 --- a/deepface/commons/weight_utils.py +++ b/deepface/commons/weight_utils.py @@ -128,8 +128,9 @@ def download_all_models_in_one_shot() -> None: WEIGHTS_URL as SSD_WEIGHTS, ) from deepface.models.face_detection.Yolo import ( - WEIGHT_URL as YOLOV8_WEIGHTS, - WEIGHT_NAME as YOLOV8_WEIGHT_NAME, + WEIGHT_URLS as YOLO_WEIGHTS, + WEIGHT_NAMES as YOLO_WEIGHT_NAMES, + YoloModel ) from deepface.models.face_detection.YuNet import WEIGHTS_URL as YUNET_WEIGHTS from deepface.models.face_detection.Dlib import WEIGHTS_URL as DLIB_FD_WEIGHTS @@ -162,8 +163,16 @@ def download_all_models_in_one_shot() -> None: SSD_MODEL, SSD_WEIGHTS, { - "filename": YOLOV8_WEIGHT_NAME, - "url": YOLOV8_WEIGHTS, + "filename": YOLO_WEIGHT_NAMES[YoloModel.V8N.value], + "url": YOLO_WEIGHTS[YoloModel.V8N.value], + }, + { + "filename": YOLO_WEIGHT_NAMES[YoloModel.V11N.value], + "url": YOLO_WEIGHTS[YoloModel.V11N.value], + }, + { + "filename": YOLO_WEIGHT_NAMES[YoloModel.V11M.value], + "url": YOLO_WEIGHTS[YoloModel.V11M.value], }, YUNET_WEIGHTS, DLIB_FD_WEIGHTS, diff --git a/deepface/models/face_detection/Yolo.py b/deepface/models/face_detection/Yolo.py index 77dd09b32..f0826e140 100644 --- a/deepface/models/face_detection/Yolo.py +++ b/deepface/models/face_detection/Yolo.py @@ -1,6 +1,7 @@ # built-in dependencies import os from typing import Any, List +from enum import Enum # 3rd party dependencies import numpy as np @@ -13,17 +14,27 @@ logger = Logger() # Model's weights paths -WEIGHT_NAME = "yolov8n-face.pt" +WEIGHT_NAMES = ["yolov8n-face.pt", + "yolov11n-face.pt", + "yolov11m-face.pt"] # Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB -WEIGHT_URL = "https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb" +WEIGHT_URLS = ["https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb", + "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11n-face.pt", + "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11m-face.pt"] + + +class YoloModel(Enum): + V8N = 0 + V11N = 1 + V11M = 2 class YoloClient(Detector): - def __init__(self): - self.model = self.build_model() + def __init__(self, model: YoloModel): + self.model = self.build_model(model) - def build_model(self) -> Any: + def build_model(self, model: YoloModel) -> Any: """ Build a yolo detector model Returns: @@ -40,7 +51,7 @@ def build_model(self) -> Any: ) from e weight_file = weight_utils.download_weights_if_necessary( - file_name=WEIGHT_NAME, source_url=WEIGHT_URL + file_name=WEIGHT_NAMES[model.value], source_url=WEIGHT_URLS[model.value] ) # Return face_detector @@ -98,3 +109,18 @@ def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]: resp.append(facial_area) return resp + + +class YoloClientV8n(YoloClient): + def __init__(self): + super().__init__(YoloModel.V8N) + + +class YoloClientV11n(YoloClient): + def __init__(self): + super().__init__(YoloModel.V11N) + + +class YoloClientV11m(YoloClient): + def __init__(self): + super().__init__(YoloModel.V11M) diff --git a/deepface/modules/demography.py b/deepface/modules/demography.py index b68314b9c..0f29cd9f7 100644 --- a/deepface/modules/demography.py +++ b/deepface/modules/demography.py @@ -35,7 +35,7 @@ def analyze( Set to False to avoid the exception for low-resolution images (default is True). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', diff --git a/deepface/modules/detection.py b/deepface/modules/detection.py index 17bd5d906..bce658605 100644 --- a/deepface/modules/detection.py +++ b/deepface/modules/detection.py @@ -38,7 +38,7 @@ def extract_faces( as a string, numpy array (BGR), or base64 encoded images. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' (default is opencv) enforce_detection (boolean): If no face is detected in an image, raise an exception. diff --git a/deepface/modules/modeling.py b/deepface/modules/modeling.py index c097c923e..ba65383e9 100644 --- a/deepface/modules/modeling.py +++ b/deepface/modules/modeling.py @@ -38,7 +38,7 @@ def build_model(task: str, model_name: str) -> Any: - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet for face recognition - Age, Gender, Emotion, Race for facial attributes - - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yunet, + - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', 'yolov11m', yunet, fastmtcnn or centerface for face detectors - Fasnet for spoofing Returns: @@ -77,7 +77,9 @@ def build_model(task: str, model_name: str) -> Any: "dlib": DlibDetector.DlibClient, "retinaface": RetinaFace.RetinaFaceClient, "mediapipe": MediaPipe.MediaPipeClient, - "yolov8": Yolo.YoloClient, + "yolov8": Yolo.YoloClientV8n, + "yolov11n": Yolo.YoloClientV11n, + "yolov11m": Yolo.YoloClientV11m, "yunet": YuNet.YuNetClient, "fastmtcnn": FastMtCnn.FastMtCnnClient, "centerface": CenterFace.CenterFaceClient, diff --git a/deepface/modules/recognition.py b/deepface/modules/recognition.py index df7068dc4..96313f059 100644 --- a/deepface/modules/recognition.py +++ b/deepface/modules/recognition.py @@ -54,7 +54,7 @@ def find( Default is True. Set to False to avoid the exception for low-resolution images. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'. + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8','yolov11n','yolov11m', 'centerface' or 'skip'. align (boolean): Perform alignment based on the eye positions. @@ -483,7 +483,7 @@ def find_batched( Default is True. Set to False to avoid the exception for low-resolution images. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'. + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'. align (boolean): Perform alignment based on the eye positions. diff --git a/deepface/modules/representation.py b/deepface/modules/representation.py index a1476405f..fbe952e69 100644 --- a/deepface/modules/representation.py +++ b/deepface/modules/representation.py @@ -36,7 +36,7 @@ def represent( Default is True. Set to False to avoid the exception for low-resolution images. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip'. + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'. align (boolean): Perform alignment based on the eye positions. diff --git a/deepface/modules/streaming.py b/deepface/modules/streaming.py index c1a0363e1..64ebe80aa 100644 --- a/deepface/modules/streaming.py +++ b/deepface/modules/streaming.py @@ -45,7 +45,7 @@ def analysis( OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', @@ -192,7 +192,7 @@ def search_identity( model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine). @@ -374,7 +374,7 @@ def grab_facial_areas( Args: img (np.ndarray): image itself detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' (default is opencv). threshold (int): threshold for facial area, discard smaller ones Returns @@ -443,7 +443,7 @@ def perform_facial_recognition( db_path (string): Path to the folder containing image files. All detected faces in the database will be considered in the decision-making process. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine). diff --git a/deepface/modules/verification.py b/deepface/modules/verification.py index 540b63bf7..6e05eb469 100644 --- a/deepface/modules/verification.py +++ b/deepface/modules/verification.py @@ -47,7 +47,7 @@ def verify( OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' (default is opencv) distance_metric (string): Metric for measuring similarity. Options: 'cosine', diff --git a/tests/visual-test.py b/tests/visual-test.py index 9149bc5eb..2b1ff22cd 100644 --- a/tests/visual-test.py +++ b/tests/visual-test.py @@ -34,6 +34,8 @@ "retinaface", "yunet", "yolov8", + "yolov11n", + "yolov11m", "centerface", ] From 38261e07e5fe4f6e987b5a56c7f82920229caf9f Mon Sep 17 00:00:00 2001 From: roberto-corno-nttdata Date: Sat, 7 Dec 2024 23:00:16 +0100 Subject: [PATCH 02/12] yolov11n and yolov11m added to model selection --- benchmarks/Evaluate-Results.ipynb | 4 +- deepface/DeepFace.py | 16 ++--- deepface/commons/weight_utils.py | 6 +- deepface/models/YoloClientBase.py | 37 ++++++++++ deepface/models/YoloModel.py | 21 ++++++ deepface/models/face_detection/Yolo.py | 79 +++++++--------------- deepface/models/facial_recognition/Yolo.py | 44 ++++++++++++ deepface/modules/demography.py | 2 +- deepface/modules/detection.py | 2 +- deepface/modules/modeling.py | 16 +++-- deepface/modules/recognition.py | 4 +- deepface/modules/representation.py | 13 ++-- deepface/modules/streaming.py | 8 +-- deepface/modules/verification.py | 2 +- tests/visual-test.py | 5 ++ 15 files changed, 173 insertions(+), 86 deletions(-) create mode 100644 deepface/models/YoloClientBase.py create mode 100644 deepface/models/YoloModel.py create mode 100644 deepface/models/facial_recognition/Yolo.py diff --git a/benchmarks/Evaluate-Results.ipynb b/benchmarks/Evaluate-Results.ipynb index 16d29dce5..72d74cc6a 100644 --- a/benchmarks/Evaluate-Results.ipynb +++ b/benchmarks/Evaluate-Results.ipynb @@ -29,8 +29,8 @@ "outputs": [], "source": [ "alignment = [False, True]\n", - "models = [\"Facenet512\", \"Facenet\", \"VGG-Face\", \"ArcFace\", \"Dlib\", \"GhostFaceNet\", \"SFace\", \"OpenFace\", \"DeepFace\", \"DeepID\"]\n", - "detectors = [\"retinaface\", \"mtcnn\", \"fastmtcnn\", \"dlib\", \"yolov8\", \"yolov11n\", \"yolov11m\", \"yunet\", \"centerface\", \"mediapipe\", \"ssd\", \"opencv\", \"skip\"]\n", + "models = [\"Facenet512\", \"Facenet\", \"VGG-Face\", \"ArcFace\", \"Dlib\", \"GhostFaceNet\", \"SFace\", \"OpenFace\", \"DeepFace\", \"DeepID\", \"yolov8\", \"yolov11n\", \"yolov11s\", \"yolov11m\"]\n", + "detectors = [\"retinaface\", \"mtcnn\", \"fastmtcnn\", \"dlib\", \"yolov8\", \"yolov11n\", \"yolov11s\", \"yolov11m\", \"yunet\", \"centerface\", \"mediapipe\", \"ssd\", \"opencv\", \"skip\"]\n", "distance_metrics = [\"euclidean\", \"euclidean_l2\", \"cosine\"]" ] }, diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index a95bcc534..e4e5411e8 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -56,7 +56,7 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any: - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet for face recognition - Age, Gender, Emotion, Race for facial attributes - - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', 'yolov11m', yunet, + - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', 'yolov11s','yolov11m', yunet, fastmtcnn or centerface for face detectors - Fasnet for spoofing task (str): facial_recognition, facial_attribute, face_detector, spoofing @@ -96,7 +96,7 @@ def verify( OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', @@ -187,7 +187,7 @@ def analyze( Set to False to avoid the exception for low-resolution images (default is True). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', @@ -298,7 +298,7 @@ def find( Set to False to avoid the exception for low-resolution images (default is True). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' (default is opencv). align (boolean): Perform alignment based on the eye positions (default is True). @@ -396,7 +396,7 @@ def represent( (default is True). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' (default is opencv). align (boolean): Perform alignment based on the eye positions (default is True). @@ -462,7 +462,7 @@ def stream( OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', @@ -517,7 +517,7 @@ def extract_faces( as a string, numpy array (BGR), or base64 encoded images. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' (default is opencv). enforce_detection (boolean): If no face is detected in an image, raise an exception. @@ -601,7 +601,7 @@ def detectFace( added to resize the image (default is (224, 224)). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' (default is opencv). enforce_detection (boolean): If no face is detected in an image, raise an exception. diff --git a/deepface/commons/weight_utils.py b/deepface/commons/weight_utils.py index 421d85704..d3207022a 100644 --- a/deepface/commons/weight_utils.py +++ b/deepface/commons/weight_utils.py @@ -127,7 +127,7 @@ def download_all_models_in_one_shot() -> None: MODEL_URL as SSD_MODEL, WEIGHTS_URL as SSD_WEIGHTS, ) - from deepface.models.face_detection.Yolo import ( + from deepface.models.YoloModel import ( WEIGHT_URLS as YOLO_WEIGHTS, WEIGHT_NAMES as YOLO_WEIGHT_NAMES, YoloModel @@ -170,6 +170,10 @@ def download_all_models_in_one_shot() -> None: "filename": YOLO_WEIGHT_NAMES[YoloModel.V11N.value], "url": YOLO_WEIGHTS[YoloModel.V11N.value], }, + { + "filename": YOLO_WEIGHT_NAMES[YoloModel.V11S.value], + "url": YOLO_WEIGHTS[YoloModel.V11S.value], + }, { "filename": YOLO_WEIGHT_NAMES[YoloModel.V11M.value], "url": YOLO_WEIGHTS[YoloModel.V11M.value], diff --git a/deepface/models/YoloClientBase.py b/deepface/models/YoloClientBase.py new file mode 100644 index 000000000..83ad32413 --- /dev/null +++ b/deepface/models/YoloClientBase.py @@ -0,0 +1,37 @@ +# built-in dependencies +from typing import Any + +# project dependencies +from deepface.models.YoloModel import YoloModel, WEIGHT_URLS, WEIGHT_NAMES +from deepface.commons import weight_utils +from deepface.commons.logger import Logger + +logger = Logger() + + +class YoloClientBase: + def __init__(self, model: YoloModel): + self.model = self.build_model(model) + + def build_model(self, model: YoloModel) -> Any: + """ + Build a yolo detector model + Returns: + model (Any) + """ + + # Import the optional Ultralytics YOLO model + try: + from ultralytics import YOLO + except ModuleNotFoundError as e: + raise ImportError( + "Yolo is an optional detector, ensure the library is installed. " + "Please install using 'pip install ultralytics'" + ) from e + + weight_file = weight_utils.download_weights_if_necessary( + file_name=WEIGHT_NAMES[model.value], source_url=WEIGHT_URLS[model.value] + ) + + # Return face_detector + return YOLO(weight_file) diff --git a/deepface/models/YoloModel.py b/deepface/models/YoloModel.py new file mode 100644 index 000000000..93f2a74fb --- /dev/null +++ b/deepface/models/YoloModel.py @@ -0,0 +1,21 @@ +from enum import Enum + + +class YoloModel(Enum): + V8N = 0 + V11N = 1 + V11S = 2 + V11M = 3 + + +# Model's weights paths +WEIGHT_NAMES = ["yolov8n-face.pt", + "yolov11n-face.pt", + "yolov11s-face.pt", + "yolov11m-face.pt"] + +# Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB +WEIGHT_URLS = ["https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb", + "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11n-face.pt", + "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11s-face.pt", + "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11m-face.pt"] diff --git a/deepface/models/face_detection/Yolo.py b/deepface/models/face_detection/Yolo.py index f0826e140..1f29781b9 100644 --- a/deepface/models/face_detection/Yolo.py +++ b/deepface/models/face_detection/Yolo.py @@ -1,61 +1,22 @@ # built-in dependencies import os -from typing import Any, List -from enum import Enum +from typing import List # 3rd party dependencies import numpy as np # project dependencies +from deepface.models.YoloClientBase import YoloClientBase +from deepface.models.YoloModel import YoloModel from deepface.models.Detector import Detector, FacialAreaRegion -from deepface.commons import weight_utils from deepface.commons.logger import Logger logger = Logger() -# Model's weights paths -WEIGHT_NAMES = ["yolov8n-face.pt", - "yolov11n-face.pt", - "yolov11m-face.pt"] -# Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB -WEIGHT_URLS = ["https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb", - "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11n-face.pt", - "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11m-face.pt"] - - -class YoloModel(Enum): - V8N = 0 - V11N = 1 - V11M = 2 - - -class YoloClient(Detector): +class YoloDetectorClient(YoloClientBase, Detector): def __init__(self, model: YoloModel): - self.model = self.build_model(model) - - def build_model(self, model: YoloModel) -> Any: - """ - Build a yolo detector model - Returns: - model (Any) - """ - - # Import the optional Ultralytics YOLO model - try: - from ultralytics import YOLO - except ModuleNotFoundError as e: - raise ImportError( - "Yolo is an optional detector, ensure the library is installed. " - "Please install using 'pip install ultralytics'" - ) from e - - weight_file = weight_utils.download_weights_if_necessary( - file_name=WEIGHT_NAMES[model.value], source_url=WEIGHT_URLS[model.value] - ) - - # Return face_detector - return YOLO(weight_file) + super().__init__(model) def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]: """ @@ -80,21 +41,24 @@ def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]: # For each face, extract the bounding box, the landmarks and confidence for result in results: - if result.boxes is None or result.keypoints is None: + if result.boxes is None: continue # Extract the bounding box and the confidence x, y, w, h = result.boxes.xywh.tolist()[0] confidence = result.boxes.conf.tolist()[0] - # right_eye_conf = result.keypoints.conf[0][0] - # left_eye_conf = result.keypoints.conf[0][1] - right_eye = result.keypoints.xy[0][0].tolist() - left_eye = result.keypoints.xy[0][1].tolist() + right_eye = None + left_eye = None + if result.keypoints is not None: + # right_eye_conf = result.keypoints.conf[0][0] + # left_eye_conf = result.keypoints.conf[0][1] + right_eye = result.keypoints.xy[0][0].tolist() + left_eye = result.keypoints.xy[0][1].tolist() - # eyes are list of float, need to cast them tuple of int - left_eye = tuple(int(i) for i in left_eye) - right_eye = tuple(int(i) for i in right_eye) + # eyes are list of float, need to cast them tuple of int + left_eye = tuple(int(i) for i in left_eye) + right_eye = tuple(int(i) for i in right_eye) x, y, w, h = int(x - w / 2), int(y - h / 2), int(w), int(h) facial_area = FacialAreaRegion( @@ -111,16 +75,21 @@ def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]: return resp -class YoloClientV8n(YoloClient): +class YoloDetectorClientV8n(YoloDetectorClient): def __init__(self): super().__init__(YoloModel.V8N) -class YoloClientV11n(YoloClient): +class YoloDetectorClientV11n(YoloDetectorClient): def __init__(self): super().__init__(YoloModel.V11N) -class YoloClientV11m(YoloClient): +class YoloDetectorClientV11s(YoloDetectorClient): + def __init__(self): + super().__init__(YoloModel.V11S) + + +class YoloDetectorClientV11m(YoloDetectorClient): def __init__(self): super().__init__(YoloModel.V11M) diff --git a/deepface/models/facial_recognition/Yolo.py b/deepface/models/facial_recognition/Yolo.py new file mode 100644 index 000000000..e8f6d5904 --- /dev/null +++ b/deepface/models/facial_recognition/Yolo.py @@ -0,0 +1,44 @@ +# built-in dependencies +from typing import List + +# 3rd party dependencies +import numpy as np + +# project dependencies +from deepface.models.YoloClientBase import YoloClientBase +from deepface.models.YoloModel import YoloModel +from deepface.models.FacialRecognition import FacialRecognition +from deepface.commons.logger import Logger + +logger = Logger() + + +class YoloFacialRecognitionClient(YoloClientBase, FacialRecognition): + def __init__(self, model: YoloModel): + super().__init__(model) + self.model_name = "Yolo" + self.input_shape = None + self.output_shape = 512 + + def forward(self, img: np.ndarray) -> List[float]: + return self.model.embed(img)[0].tolist() + + +class YoloFacialRecognitionClientV8n(YoloFacialRecognitionClient): + def __init__(self): + super().__init__(YoloModel.V8N) + + +class YoloFacialRecognitionClientV11n(YoloFacialRecognitionClient): + def __init__(self): + super().__init__(YoloModel.V11N) + + +class YoloFacialRecognitionClientV11s(YoloFacialRecognitionClient): + def __init__(self): + super().__init__(YoloModel.V11S) + + +class YoloFacialRecognitionClientV11m(YoloFacialRecognitionClient): + def __init__(self): + super().__init__(YoloModel.V11M) diff --git a/deepface/modules/demography.py b/deepface/modules/demography.py index 0f29cd9f7..cc5112e2c 100644 --- a/deepface/modules/demography.py +++ b/deepface/modules/demography.py @@ -35,7 +35,7 @@ def analyze( Set to False to avoid the exception for low-resolution images (default is True). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', diff --git a/deepface/modules/detection.py b/deepface/modules/detection.py index bce658605..4ed907636 100644 --- a/deepface/modules/detection.py +++ b/deepface/modules/detection.py @@ -38,7 +38,7 @@ def extract_faces( as a string, numpy array (BGR), or base64 encoded images. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' (default is opencv) enforce_detection (boolean): If no face is detected in an image, raise an exception. diff --git a/deepface/modules/modeling.py b/deepface/modules/modeling.py index ba65383e9..fa884ac6b 100644 --- a/deepface/modules/modeling.py +++ b/deepface/modules/modeling.py @@ -12,6 +12,7 @@ Dlib, Facenet, GhostFaceNet, + Yolo as YoloFacialRecognition, ) from deepface.models.face_detection import ( FastMtCnn, @@ -21,7 +22,7 @@ Dlib as DlibDetector, RetinaFace, Ssd, - Yolo, + Yolo as YoloFaceDetector, YuNet, CenterFace, ) @@ -38,7 +39,7 @@ def build_model(task: str, model_name: str) -> Any: - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet for face recognition - Age, Gender, Emotion, Race for facial attributes - - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', 'yolov11m', yunet, + - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', 'yolov11s', 'yolov11m', yunet, fastmtcnn or centerface for face detectors - Fasnet for spoofing Returns: @@ -60,6 +61,10 @@ def build_model(task: str, model_name: str) -> Any: "ArcFace": ArcFace.ArcFaceClient, "SFace": SFace.SFaceClient, "GhostFaceNet": GhostFaceNet.GhostFaceNetClient, + "yolov8": YoloFacialRecognition.YoloFacialRecognitionClientV8n, + "yolov11n": YoloFacialRecognition.YoloFacialRecognitionClientV11n, + "yolov11s": YoloFacialRecognition.YoloFacialRecognitionClientV11s, + "yolov11m": YoloFacialRecognition.YoloFacialRecognitionClientV11m }, "spoofing": { "Fasnet": FasNet.Fasnet, @@ -77,9 +82,10 @@ def build_model(task: str, model_name: str) -> Any: "dlib": DlibDetector.DlibClient, "retinaface": RetinaFace.RetinaFaceClient, "mediapipe": MediaPipe.MediaPipeClient, - "yolov8": Yolo.YoloClientV8n, - "yolov11n": Yolo.YoloClientV11n, - "yolov11m": Yolo.YoloClientV11m, + "yolov8": YoloFaceDetector.YoloDetectorClientV8n, + "yolov11n": YoloFaceDetector.YoloDetectorClientV11n, + "yolov11s": YoloFaceDetector.YoloDetectorClientV11s, + "yolov11m": YoloFaceDetector.YoloDetectorClientV11m, "yunet": YuNet.YuNetClient, "fastmtcnn": FastMtCnn.FastMtCnnClient, "centerface": CenterFace.CenterFaceClient, diff --git a/deepface/modules/recognition.py b/deepface/modules/recognition.py index 96313f059..d254d66e3 100644 --- a/deepface/modules/recognition.py +++ b/deepface/modules/recognition.py @@ -54,7 +54,7 @@ def find( Default is True. Set to False to avoid the exception for low-resolution images. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8','yolov11n','yolov11m', 'centerface' or 'skip'. + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8','yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'. align (boolean): Perform alignment based on the eye positions. @@ -483,7 +483,7 @@ def find_batched( Default is True. Set to False to avoid the exception for low-resolution images. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'. + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'. align (boolean): Perform alignment based on the eye positions. diff --git a/deepface/modules/representation.py b/deepface/modules/representation.py index fbe952e69..f9e751d99 100644 --- a/deepface/modules/representation.py +++ b/deepface/modules/representation.py @@ -36,7 +36,7 @@ def represent( Default is True. Set to False to avoid the exception for low-resolution images. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip'. + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'. align (boolean): Perform alignment based on the eye positions. @@ -122,11 +122,12 @@ def represent( confidence = img_obj["confidence"] # resize to expected shape of ml model - img = preprocessing.resize_image( - img=img, - # thanks to DeepId (!) - target_size=(target_size[1], target_size[0]), - ) + if target_size is not None: + img = preprocessing.resize_image( + img=img, + # thanks to DeepId (!) + target_size=(target_size[1], target_size[0]), + ) # custom normalization img = preprocessing.normalize_input(img=img, normalization=normalization) diff --git a/deepface/modules/streaming.py b/deepface/modules/streaming.py index 64ebe80aa..ca51989cb 100644 --- a/deepface/modules/streaming.py +++ b/deepface/modules/streaming.py @@ -45,7 +45,7 @@ def analysis( OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', @@ -192,7 +192,7 @@ def search_identity( model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine). @@ -374,7 +374,7 @@ def grab_facial_areas( Args: img (np.ndarray): image itself detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' (default is opencv). threshold (int): threshold for facial area, discard smaller ones Returns @@ -443,7 +443,7 @@ def perform_facial_recognition( db_path (string): Path to the folder containing image files. All detected faces in the database will be considered in the decision-making process. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine). diff --git a/deepface/modules/verification.py b/deepface/modules/verification.py index 6e05eb469..1c03e5c6f 100644 --- a/deepface/modules/verification.py +++ b/deepface/modules/verification.py @@ -47,7 +47,7 @@ def verify( OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11m', 'centerface' or 'skip' + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' (default is opencv) distance_metric (string): Metric for measuring similarity. Options: 'cosine', diff --git a/tests/visual-test.py b/tests/visual-test.py index 2b1ff22cd..111e9ed11 100644 --- a/tests/visual-test.py +++ b/tests/visual-test.py @@ -22,6 +22,10 @@ "ArcFace", "SFace", "GhostFaceNet", + "yolov8", + "yolov11n", + "yolov11s", + "yolov11m" ] detector_backends = [ @@ -35,6 +39,7 @@ "yunet", "yolov8", "yolov11n", + "yolov11s", "yolov11m", "centerface", ] From 7156e6f5512d6dc096f053343ace655de3f8a712 Mon Sep 17 00:00:00 2001 From: roberto-corno-nttdata Date: Mon, 9 Dec 2024 10:20:59 +0100 Subject: [PATCH 03/12] fix: method documentation --- README.md | 1 + deepface/DeepFace.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index cbd1dc1d3..eec06930c 100644 --- a/README.md +++ b/README.md @@ -223,6 +223,7 @@ backends = [ 'retinaface', 'mediapipe', 'yolov8', + 'yolov11s', 'yolov11n', 'yolov11m', 'yunet', diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index e4e5411e8..43df4a16c 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -54,7 +54,7 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any: Args: model_name (str): model identifier - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib, - ArcFace, SFace, GhostFaceNet for face recognition + ArcFace, SFace, GhostFaceNet, Yolo-Face for face recognition - Age, Gender, Emotion, Race for facial attributes - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', 'yolov11s','yolov11m', yunet, fastmtcnn or centerface for face detectors @@ -93,7 +93,7 @@ def verify( or pre-calculated embeddings. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet and Yolo-Face (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' @@ -289,7 +289,7 @@ def find( in the database will be considered in the decision-making process. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet and Yolo-Face (default is VGG-Face). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine). @@ -388,7 +388,7 @@ def represent( include information for each detected face. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet and Yolo-Face (default is VGG-Face.). enforce_detection (boolean): If no face is detected in an image, raise an exception. @@ -459,7 +459,7 @@ def stream( in the database will be considered in the decision-making process. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet and Yolo-Face (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' From aa57d879b62026d035e14836aad75847e3d44135 Mon Sep 17 00:00:00 2001 From: roberto-corno-nttdata Date: Mon, 9 Dec 2024 10:25:53 +0100 Subject: [PATCH 04/12] fix: method documentation --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index eec06930c..5addd788c 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,7 @@ Here, embedding is also [plotted](https://sefiks.com/2020/05/01/a-gentle-introdu **Face recognition models** - [`Demo`](https://youtu.be/eKOZawGR3y0) -DeepFace is a **hybrid** face recognition package. It currently wraps many **state-of-the-art** face recognition models: [`VGG-Face`](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/) , [`FaceNet`](https://sefiks.com/2018/09/03/face-recognition-with-facenet-in-keras/), [`OpenFace`](https://sefiks.com/2019/07/21/face-recognition-with-openface-in-keras/), [`DeepFace`](https://sefiks.com/2020/02/17/face-recognition-with-facebook-deepface-in-keras/), [`DeepID`](https://sefiks.com/2020/06/16/face-recognition-with-deepid-in-keras/), [`ArcFace`](https://sefiks.com/2020/12/14/deep-face-recognition-with-arcface-in-keras-and-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), `SFace` and `GhostFaceNet`. The default configuration uses VGG-Face model. +DeepFace is a **hybrid** face recognition package. It currently wraps many **state-of-the-art** face recognition models: [`VGG-Face`](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/) , [`FaceNet`](https://sefiks.com/2018/09/03/face-recognition-with-facenet-in-keras/), [`OpenFace`](https://sefiks.com/2019/07/21/face-recognition-with-openface-in-keras/), [`DeepFace`](https://sefiks.com/2020/02/17/face-recognition-with-facebook-deepface-in-keras/), [`DeepID`](https://sefiks.com/2020/06/16/face-recognition-with-deepid-in-keras/), [`ArcFace`](https://sefiks.com/2020/12/14/deep-face-recognition-with-arcface-in-keras-and-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), `SFace`, `GhostFaceNet` and `Yolo-Face`. The default configuration uses VGG-Face model. ```python models = [ @@ -122,6 +122,10 @@ models = [ "Dlib", "SFace", "GhostFaceNet", + "yolov8", + "yolov11n", + "yolov11s", + "yolov11m" ] #face verification From 01bf48dff81b712fbf4fc9a40774418b7769755d Mon Sep 17 00:00:00 2001 From: roberto-corno-nttdata Date: Mon, 9 Dec 2024 11:51:56 +0100 Subject: [PATCH 05/12] fix:YoloModel removal --- deepface/commons/weight_utils.py | 2 +- deepface/models/YoloClientBase.py | 37 --------------- deepface/models/YoloModel.py | 21 --------- deepface/models/face_detection/Yolo.py | 54 ++++++++++++++++++++-- deepface/models/facial_recognition/Yolo.py | 51 ++++++++++++++++++-- 5 files changed, 97 insertions(+), 68 deletions(-) delete mode 100644 deepface/models/YoloClientBase.py delete mode 100644 deepface/models/YoloModel.py diff --git a/deepface/commons/weight_utils.py b/deepface/commons/weight_utils.py index d3207022a..dfac1faa1 100644 --- a/deepface/commons/weight_utils.py +++ b/deepface/commons/weight_utils.py @@ -127,7 +127,7 @@ def download_all_models_in_one_shot() -> None: MODEL_URL as SSD_MODEL, WEIGHTS_URL as SSD_WEIGHTS, ) - from deepface.models.YoloModel import ( + from deepface.models.face_detection.Yolo import ( WEIGHT_URLS as YOLO_WEIGHTS, WEIGHT_NAMES as YOLO_WEIGHT_NAMES, YoloModel diff --git a/deepface/models/YoloClientBase.py b/deepface/models/YoloClientBase.py deleted file mode 100644 index 83ad32413..000000000 --- a/deepface/models/YoloClientBase.py +++ /dev/null @@ -1,37 +0,0 @@ -# built-in dependencies -from typing import Any - -# project dependencies -from deepface.models.YoloModel import YoloModel, WEIGHT_URLS, WEIGHT_NAMES -from deepface.commons import weight_utils -from deepface.commons.logger import Logger - -logger = Logger() - - -class YoloClientBase: - def __init__(self, model: YoloModel): - self.model = self.build_model(model) - - def build_model(self, model: YoloModel) -> Any: - """ - Build a yolo detector model - Returns: - model (Any) - """ - - # Import the optional Ultralytics YOLO model - try: - from ultralytics import YOLO - except ModuleNotFoundError as e: - raise ImportError( - "Yolo is an optional detector, ensure the library is installed. " - "Please install using 'pip install ultralytics'" - ) from e - - weight_file = weight_utils.download_weights_if_necessary( - file_name=WEIGHT_NAMES[model.value], source_url=WEIGHT_URLS[model.value] - ) - - # Return face_detector - return YOLO(weight_file) diff --git a/deepface/models/YoloModel.py b/deepface/models/YoloModel.py deleted file mode 100644 index 93f2a74fb..000000000 --- a/deepface/models/YoloModel.py +++ /dev/null @@ -1,21 +0,0 @@ -from enum import Enum - - -class YoloModel(Enum): - V8N = 0 - V11N = 1 - V11S = 2 - V11M = 3 - - -# Model's weights paths -WEIGHT_NAMES = ["yolov8n-face.pt", - "yolov11n-face.pt", - "yolov11s-face.pt", - "yolov11m-face.pt"] - -# Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB -WEIGHT_URLS = ["https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb", - "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11n-face.pt", - "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11s-face.pt", - "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11m-face.pt"] diff --git a/deepface/models/face_detection/Yolo.py b/deepface/models/face_detection/Yolo.py index 1f29781b9..548e4f796 100644 --- a/deepface/models/face_detection/Yolo.py +++ b/deepface/models/face_detection/Yolo.py @@ -1,22 +1,66 @@ # built-in dependencies import os -from typing import List +from typing import List, Any +from enum import Enum # 3rd party dependencies import numpy as np # project dependencies -from deepface.models.YoloClientBase import YoloClientBase -from deepface.models.YoloModel import YoloModel from deepface.models.Detector import Detector, FacialAreaRegion from deepface.commons.logger import Logger +from deepface.commons import weight_utils logger = Logger() -class YoloDetectorClient(YoloClientBase, Detector): +class YoloModel(Enum): + V8N = 0 + V11N = 1 + V11S = 2 + V11M = 3 + + +# Model's weights paths +WEIGHT_NAMES = ["yolov8n-face.pt", + "yolov11n-face.pt", + "yolov11s-face.pt", + "yolov11m-face.pt"] + +# Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB +WEIGHT_URLS = ["https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb", + "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11n-face.pt", + "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11s-face.pt", + "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11m-face.pt"] + + +class YoloDetectorClient(Detector): def __init__(self, model: YoloModel): - super().__init__(model) + super().__init__() + self.model = self.build_model(model) + + def build_model(self, model: YoloModel) -> Any: + """ + Build a yolo detector model + Returns: + model (Any) + """ + + # Import the optional Ultralytics YOLO model + try: + from ultralytics import YOLO + except ModuleNotFoundError as e: + raise ImportError( + "Yolo is an optional detector, ensure the library is installed. " + "Please install using 'pip install ultralytics'" + ) from e + + weight_file = weight_utils.download_weights_if_necessary( + file_name=WEIGHT_NAMES[model.value], source_url=WEIGHT_URLS[model.value] + ) + + # Return face_detector + return YOLO(weight_file) def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]: """ diff --git a/deepface/models/facial_recognition/Yolo.py b/deepface/models/facial_recognition/Yolo.py index e8f6d5904..fb8054d87 100644 --- a/deepface/models/facial_recognition/Yolo.py +++ b/deepface/models/facial_recognition/Yolo.py @@ -1,25 +1,68 @@ # built-in dependencies -from typing import List +from typing import List, Any +from enum import Enum # 3rd party dependencies import numpy as np # project dependencies -from deepface.models.YoloClientBase import YoloClientBase -from deepface.models.YoloModel import YoloModel from deepface.models.FacialRecognition import FacialRecognition from deepface.commons.logger import Logger +from deepface.commons import weight_utils logger = Logger() -class YoloFacialRecognitionClient(YoloClientBase, FacialRecognition): +class YoloModel(Enum): + V8N = 0 + V11N = 1 + V11S = 2 + V11M = 3 + + +# Model's weights paths +WEIGHT_NAMES = ["yolov8n-face.pt", + "yolov11n-face.pt", + "yolov11s-face.pt", + "yolov11m-face.pt"] + +# Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB +WEIGHT_URLS = ["https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb", + "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11n-face.pt", + "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11s-face.pt", + "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11m-face.pt"] + + +class YoloFacialRecognitionClient(FacialRecognition): def __init__(self, model: YoloModel): super().__init__(model) self.model_name = "Yolo" self.input_shape = None self.output_shape = 512 + def build_model(self, model: YoloModel) -> Any: + """ + Build a yolo detector model + Returns: + model (Any) + """ + + # Import the optional Ultralytics YOLO model + try: + from ultralytics import YOLO + except ModuleNotFoundError as e: + raise ImportError( + "Yolo is an optional detector, ensure the library is installed. " + "Please install using 'pip install ultralytics'" + ) from e + + weight_file = weight_utils.download_weights_if_necessary( + file_name=WEIGHT_NAMES[model.value], source_url=WEIGHT_URLS[model.value] + ) + + # Return face_detector + return YOLO(weight_file) + def forward(self, img: np.ndarray) -> List[float]: return self.model.embed(img)[0].tolist() From 79dedc08c1cf6961d426997a54a07d7482f574af Mon Sep 17 00:00:00 2001 From: roberto-corno-nttdata Date: Mon, 9 Dec 2024 14:45:34 +0100 Subject: [PATCH 06/12] fix: added input_shape to YoloFacialRecognitionClient --- deepface/models/facial_recognition/Yolo.py | 7 ++++--- deepface/modules/representation.py | 11 +++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/deepface/models/facial_recognition/Yolo.py b/deepface/models/facial_recognition/Yolo.py index fb8054d87..1e08218eb 100644 --- a/deepface/models/facial_recognition/Yolo.py +++ b/deepface/models/facial_recognition/Yolo.py @@ -35,10 +35,11 @@ class YoloModel(Enum): class YoloFacialRecognitionClient(FacialRecognition): def __init__(self, model: YoloModel): - super().__init__(model) + super().__init__() self.model_name = "Yolo" - self.input_shape = None + self.input_shape = (224, 224) self.output_shape = 512 + self.model = self.build_model(model) def build_model(self, model: YoloModel) -> Any: """ @@ -64,7 +65,7 @@ def build_model(self, model: YoloModel) -> Any: return YOLO(weight_file) def forward(self, img: np.ndarray) -> List[float]: - return self.model.embed(img)[0].tolist() + return self.model.embed(np.squeeze(img, axis=0))[0].tolist() class YoloFacialRecognitionClientV8n(YoloFacialRecognitionClient): diff --git a/deepface/modules/representation.py b/deepface/modules/representation.py index f9e751d99..c1e2a5fe4 100644 --- a/deepface/modules/representation.py +++ b/deepface/modules/representation.py @@ -122,12 +122,11 @@ def represent( confidence = img_obj["confidence"] # resize to expected shape of ml model - if target_size is not None: - img = preprocessing.resize_image( - img=img, - # thanks to DeepId (!) - target_size=(target_size[1], target_size[0]), - ) + img = preprocessing.resize_image( + img=img, + # thanks to DeepId (!) + target_size=(target_size[1], target_size[0]), + ) # custom normalization img = preprocessing.normalize_input(img=img, normalization=normalization) From d3a3f2b65fd711efb23439c1cedd6e55e51063c5 Mon Sep 17 00:00:00 2001 From: roberto-corno-nttdata Date: Mon, 9 Dec 2024 16:00:27 +0100 Subject: [PATCH 07/12] fix: documentation --- deepface/DeepFace.py | 176 +++++++++++++------------ deepface/models/face_detection/Yolo.py | 2 + deepface/modules/demography.py | 4 +- deepface/modules/detection.py | 4 +- deepface/modules/modeling.py | 10 +- deepface/modules/recognition.py | 15 ++- deepface/modules/representation.py | 2 +- deepface/modules/streaming.py | 22 ++-- deepface/modules/verification.py | 7 +- 9 files changed, 128 insertions(+), 114 deletions(-) diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index 43df4a16c..cf2c827f5 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -54,10 +54,11 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any: Args: model_name (str): model identifier - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib, - ArcFace, SFace, GhostFaceNet, Yolo-Face for face recognition + ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, Yolov11s and + Yolov11m for face recognition - Age, Gender, Emotion, Race for facial attributes - - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', 'yolov11s','yolov11m', yunet, - fastmtcnn or centerface for face detectors + - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yolov11n, + yolov11s, yolov11m, yunet, fastmtcnn or centerface for face detectors - Fasnet for spoofing task (str): facial_recognition, facial_attribute, face_detector, spoofing default is facial_recognition @@ -68,18 +69,18 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any: def verify( - img1_path: Union[str, np.ndarray, List[float]], - img2_path: Union[str, np.ndarray, List[float]], - model_name: str = "VGG-Face", - detector_backend: str = "opencv", - distance_metric: str = "cosine", - enforce_detection: bool = True, - align: bool = True, - expand_percentage: int = 0, - normalization: str = "base", - silent: bool = False, - threshold: Optional[float] = None, - anti_spoofing: bool = False, + img1_path: Union[str, np.ndarray, List[float]], + img2_path: Union[str, np.ndarray, List[float]], + model_name: str = "VGG-Face", + detector_backend: str = "opencv", + distance_metric: str = "cosine", + enforce_detection: bool = True, + align: bool = True, + expand_percentage: int = 0, + normalization: str = "base", + silent: bool = False, + threshold: Optional[float] = None, + anti_spoofing: bool = False, ) -> Dict[str, Any]: """ Verify if an image pair represents the same person or different persons. @@ -93,7 +94,8 @@ def verify( or pre-calculated embeddings. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet and Yolo-Face (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, + Yolov11s and Yolov11m (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' @@ -164,14 +166,14 @@ def verify( def analyze( - img_path: Union[str, np.ndarray], - actions: Union[tuple, list] = ("emotion", "age", "gender", "race"), - enforce_detection: bool = True, - detector_backend: str = "opencv", - align: bool = True, - expand_percentage: int = 0, - silent: bool = False, - anti_spoofing: bool = False, + img_path: Union[str, np.ndarray], + actions: Union[tuple, list] = ("emotion", "age", "gender", "race"), + enforce_detection: bool = True, + detector_backend: str = "opencv", + align: bool = True, + expand_percentage: int = 0, + silent: bool = False, + anti_spoofing: bool = False, ) -> List[Dict[str, Any]]: """ Analyze facial attributes such as age, gender, emotion, and race in the provided image. @@ -187,8 +189,8 @@ def analyze( Set to False to avoid the exception for low-resolution images (default is True). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' - (default is opencv). + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', + 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine). @@ -263,20 +265,20 @@ def analyze( def find( - img_path: Union[str, np.ndarray], - db_path: str, - model_name: str = "VGG-Face", - distance_metric: str = "cosine", - enforce_detection: bool = True, - detector_backend: str = "opencv", - align: bool = True, - expand_percentage: int = 0, - threshold: Optional[float] = None, - normalization: str = "base", - silent: bool = False, - refresh_database: bool = True, - anti_spoofing: bool = False, - batched: bool = False, + img_path: Union[str, np.ndarray], + db_path: str, + model_name: str = "VGG-Face", + distance_metric: str = "cosine", + enforce_detection: bool = True, + detector_backend: str = "opencv", + align: bool = True, + expand_percentage: int = 0, + threshold: Optional[float] = None, + normalization: str = "base", + silent: bool = False, + refresh_database: bool = True, + anti_spoofing: bool = False, + batched: bool = False, ) -> Union[List[pd.DataFrame], List[List[Dict[str, Any]]]]: """ Identify individuals in a database @@ -289,7 +291,8 @@ def find( in the database will be considered in the decision-making process. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet and Yolo-Face (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, + Yolov11s and Yolov11m (default is VGG-Face). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine). @@ -298,8 +301,8 @@ def find( Set to False to avoid the exception for low-resolution images (default is True). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' - (default is opencv). + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', + 'centerface' or 'skip' (default is opencv). align (boolean): Perform alignment based on the eye positions (default is True). @@ -369,15 +372,15 @@ def find( def represent( - img_path: Union[str, np.ndarray], - model_name: str = "VGG-Face", - enforce_detection: bool = True, - detector_backend: str = "opencv", - align: bool = True, - expand_percentage: int = 0, - normalization: str = "base", - anti_spoofing: bool = False, - max_faces: Optional[int] = None, + img_path: Union[str, np.ndarray], + model_name: str = "VGG-Face", + enforce_detection: bool = True, + detector_backend: str = "opencv", + align: bool = True, + expand_percentage: int = 0, + normalization: str = "base", + anti_spoofing: bool = False, + max_faces: Optional[int] = None, ) -> List[Dict[str, Any]]: """ Represent facial images as multi-dimensional vector embeddings. @@ -388,16 +391,16 @@ def represent( include information for each detected face. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet and Yolo-Face - (default is VGG-Face.). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, + Yolov11s and Yolov11m (default is VGG-Face.). enforce_detection (boolean): If no face is detected in an image, raise an exception. Default is True. Set to False to avoid the exception for low-resolution images (default is True). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' - (default is opencv). + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', + 'centerface' or 'skip' (default is opencv). align (boolean): Perform alignment based on the eye positions (default is True). @@ -441,15 +444,15 @@ def represent( def stream( - db_path: str = "", - model_name: str = "VGG-Face", - detector_backend: str = "opencv", - distance_metric: str = "cosine", - enable_face_analysis: bool = True, - source: Any = 0, - time_threshold: int = 5, - frame_threshold: int = 5, - anti_spoofing: bool = False, + db_path: str = "", + model_name: str = "VGG-Face", + detector_backend: str = "opencv", + distance_metric: str = "cosine", + enable_face_analysis: bool = True, + source: Any = 0, + time_threshold: int = 5, + frame_threshold: int = 5, + anti_spoofing: bool = False, ) -> None: """ Run real time face recognition and facial attribute analysis @@ -459,11 +462,12 @@ def stream( in the database will be considered in the decision-making process. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet and Yolo-Face (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, + Yolov11s and Yolov11m (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' - (default is opencv). + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', + 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine). @@ -499,15 +503,15 @@ def stream( def extract_faces( - img_path: Union[str, np.ndarray], - detector_backend: str = "opencv", - enforce_detection: bool = True, - align: bool = True, - expand_percentage: int = 0, - grayscale: bool = False, - color_face: str = "rgb", - normalize_face: bool = True, - anti_spoofing: bool = False, + img_path: Union[str, np.ndarray], + detector_backend: str = "opencv", + enforce_detection: bool = True, + align: bool = True, + expand_percentage: int = 0, + grayscale: bool = False, + color_face: str = "rgb", + normalize_face: bool = True, + anti_spoofing: bool = False, ) -> List[Dict[str, Any]]: """ Extract faces from a given image @@ -517,8 +521,8 @@ def extract_faces( as a string, numpy array (BGR), or base64 encoded images. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' - (default is opencv). + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', + 'centerface' or 'skip' (default is opencv). enforce_detection (boolean): If no face is detected in an image, raise an exception. Set to False to avoid the exception for low-resolution images (default is True). @@ -584,11 +588,11 @@ def cli() -> None: def detectFace( - img_path: Union[str, np.ndarray], - target_size: tuple = (224, 224), - detector_backend: str = "opencv", - enforce_detection: bool = True, - align: bool = True, + img_path: Union[str, np.ndarray], + target_size: tuple = (224, 224), + detector_backend: str = "opencv", + enforce_detection: bool = True, + align: bool = True, ) -> Union[np.ndarray, None]: """ Deprecated face detection function. Use extract_faces for same functionality. @@ -601,8 +605,8 @@ def detectFace( added to resize the image (default is (224, 224)). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' - (default is opencv). + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', + 'centerface' or 'skip' (default is opencv). enforce_detection (boolean): If no face is detected in an image, raise an exception. Set to False to avoid the exception for low-resolution images (default is True). diff --git a/deepface/models/face_detection/Yolo.py b/deepface/models/face_detection/Yolo.py index 548e4f796..ee756c56c 100644 --- a/deepface/models/face_detection/Yolo.py +++ b/deepface/models/face_detection/Yolo.py @@ -94,6 +94,8 @@ def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]: right_eye = None left_eye = None + + # yolo-facev8 is detecting eyes through keypoints, while for v11 keypoints are always None if result.keypoints is not None: # right_eye_conf = result.keypoints.conf[0][0] # left_eye_conf = result.keypoints.conf[0][1] diff --git a/deepface/modules/demography.py b/deepface/modules/demography.py index cc5112e2c..2258c1efe 100644 --- a/deepface/modules/demography.py +++ b/deepface/modules/demography.py @@ -35,8 +35,8 @@ def analyze( Set to False to avoid the exception for low-resolution images (default is True). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' - (default is opencv). + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', + 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine). diff --git a/deepface/modules/detection.py b/deepface/modules/detection.py index 4ed907636..eecca857c 100644 --- a/deepface/modules/detection.py +++ b/deepface/modules/detection.py @@ -38,8 +38,8 @@ def extract_faces( as a string, numpy array (BGR), or base64 encoded images. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' - (default is opencv) + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', + 'centerface' or 'skip' (default is opencv) enforce_detection (boolean): If no face is detected in an image, raise an exception. Default is True. Set to False to avoid the exception for low-resolution images. diff --git a/deepface/modules/modeling.py b/deepface/modules/modeling.py index fa884ac6b..57a5e76c2 100644 --- a/deepface/modules/modeling.py +++ b/deepface/modules/modeling.py @@ -37,7 +37,7 @@ def build_model(task: str, model_name: str) -> Any: task (str): facial_recognition, facial_attribute, face_detector, spoofing model_name (str): model identifier - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib, - ArcFace, SFace, GhostFaceNet for face recognition + ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, Yolov11s and Yolov11m for face recognition - Age, Gender, Emotion, Race for facial attributes - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', 'yolov11s', 'yolov11m', yunet, fastmtcnn or centerface for face detectors @@ -61,10 +61,10 @@ def build_model(task: str, model_name: str) -> Any: "ArcFace": ArcFace.ArcFaceClient, "SFace": SFace.SFaceClient, "GhostFaceNet": GhostFaceNet.GhostFaceNetClient, - "yolov8": YoloFacialRecognition.YoloFacialRecognitionClientV8n, - "yolov11n": YoloFacialRecognition.YoloFacialRecognitionClientV11n, - "yolov11s": YoloFacialRecognition.YoloFacialRecognitionClientV11s, - "yolov11m": YoloFacialRecognition.YoloFacialRecognitionClientV11m + "Yolov8": YoloFacialRecognition.YoloFacialRecognitionClientV8n, + "Yolov11n": YoloFacialRecognition.YoloFacialRecognitionClientV11n, + "Yolov11s": YoloFacialRecognition.YoloFacialRecognitionClientV11s, + "Yolov11m": YoloFacialRecognition.YoloFacialRecognitionClientV11m }, "spoofing": { "Fasnet": FasNet.Fasnet, diff --git a/deepface/modules/recognition.py b/deepface/modules/recognition.py index d254d66e3..25b36457d 100644 --- a/deepface/modules/recognition.py +++ b/deepface/modules/recognition.py @@ -45,7 +45,8 @@ def find( in the database will be considered in the decision-making process. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, + Yolov11s and Yolov11m (default is VGG-Face). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2'. @@ -54,7 +55,8 @@ def find( Default is True. Set to False to avoid the exception for low-resolution images. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8','yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'. + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8','yolov11n', 'yolov11s', + 'yolov11m', 'centerface' or 'skip'. align (boolean): Perform alignment based on the eye positions. @@ -359,7 +361,8 @@ def __find_bulk_embeddings( employees (list): list of exact image paths model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, + Yolov11s and Yolov11m (default is VGG-Face). detector_backend (str): face detector model name @@ -474,7 +477,8 @@ def find_batched( (used for anti-spoofing). model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, + Yolov11s and Yolov11m (default is VGG-Face). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2'. @@ -483,7 +487,8 @@ def find_batched( Default is True. Set to False to avoid the exception for low-resolution images. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'. + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', + 'yolov11m', 'centerface' or 'skip'. align (boolean): Perform alignment based on the eye positions. diff --git a/deepface/modules/representation.py b/deepface/modules/representation.py index c1e2a5fe4..b8fcd29fb 100644 --- a/deepface/modules/representation.py +++ b/deepface/modules/representation.py @@ -30,7 +30,7 @@ def represent( include information for each detected face. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, Yolov11s and Yolov11m enforce_detection (boolean): If no face is detected in an image, raise an exception. Default is True. Set to False to avoid the exception for low-resolution images. diff --git a/deepface/modules/streaming.py b/deepface/modules/streaming.py index ca51989cb..21cc6e7c3 100644 --- a/deepface/modules/streaming.py +++ b/deepface/modules/streaming.py @@ -42,11 +42,12 @@ def analysis( in the database will be considered in the decision-making process. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, + Yolov11s and Yolov11m (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' - (default is opencv). + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', + 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine). @@ -190,10 +191,11 @@ def search_identity( db_path (string): Path to the folder containing image files. All detected faces in the database will be considered in the decision-making process. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, + Yolov11s and Yolov11m (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' - (default is opencv). + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', + 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine). Returns: @@ -374,8 +376,8 @@ def grab_facial_areas( Args: img (np.ndarray): image itself detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' - (default is opencv). + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', + 'centerface' or 'skip' (default is opencv). threshold (int): threshold for facial area, discard smaller ones Returns result (list): list of tuple with x, y, w and h coordinates @@ -443,8 +445,8 @@ def perform_facial_recognition( db_path (string): Path to the folder containing image files. All detected faces in the database will be considered in the decision-making process. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' - (default is opencv). + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', + 'yolov11m', 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine). model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, diff --git a/deepface/modules/verification.py b/deepface/modules/verification.py index 1c03e5c6f..83b6f9829 100644 --- a/deepface/modules/verification.py +++ b/deepface/modules/verification.py @@ -44,11 +44,12 @@ def verify( or pre-calculated embeddings. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, + Yolov11s and Yolov11m (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' - (default is opencv) + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', + 'centerface' or 'skip' (default is opencv) distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine). From c25f14af104140ac70fb0af35f544813cd52d47c Mon Sep 17 00:00:00 2001 From: roberto-corno-nttdata Date: Mon, 9 Dec 2024 16:04:08 +0100 Subject: [PATCH 08/12] fix: yolo input_shape --- deepface/models/facial_recognition/Yolo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepface/models/facial_recognition/Yolo.py b/deepface/models/facial_recognition/Yolo.py index 1e08218eb..b87f8fc0e 100644 --- a/deepface/models/facial_recognition/Yolo.py +++ b/deepface/models/facial_recognition/Yolo.py @@ -37,7 +37,7 @@ class YoloFacialRecognitionClient(FacialRecognition): def __init__(self, model: YoloModel): super().__init__() self.model_name = "Yolo" - self.input_shape = (224, 224) + self.input_shape = (640, 640) self.output_shape = 512 self.model = self.build_model(model) From f808126101186df9f7bebc84752ca299075fdf98 Mon Sep 17 00:00:00 2001 From: roberto-corno-nttdata Date: Tue, 10 Dec 2024 16:00:43 +0100 Subject: [PATCH 09/12] fix: rechecked with ultralitycs, embeddings generation is a bit misleading in terms of documentaion. In the end not suitable for recognition at all. Removed YoloFacialRecognition --- deepface/DeepFace.py | 15 ++-- deepface/models/facial_recognition/Yolo.py | 88 ---------------------- deepface/modules/modeling.py | 11 +-- deepface/modules/recognition.py | 9 +-- deepface/modules/representation.py | 2 +- deepface/modules/streaming.py | 6 +- deepface/modules/verification.py | 3 +- 7 files changed, 15 insertions(+), 119 deletions(-) delete mode 100644 deepface/models/facial_recognition/Yolo.py diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index cf2c827f5..b192ce0bd 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -54,8 +54,7 @@ def build_model(model_name: str, task: str = "facial_recognition") -> Any: Args: model_name (str): model identifier - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib, - ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, Yolov11s and - Yolov11m for face recognition + ArcFace, SFace and GhostFaceNet for face recognition - Age, Gender, Emotion, Race for facial attributes - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, yolov11n, yolov11s, yolov11m, yunet, fastmtcnn or centerface for face detectors @@ -94,8 +93,7 @@ def verify( or pre-calculated embeddings. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, - Yolov11s and Yolov11m (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' @@ -291,8 +289,7 @@ def find( in the database will be considered in the decision-making process. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, - Yolov11s and Yolov11m (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine). @@ -391,8 +388,7 @@ def represent( include information for each detected face. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, - Yolov11s and Yolov11m (default is VGG-Face.). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face.). enforce_detection (boolean): If no face is detected in an image, raise an exception. Default is True. Set to False to avoid the exception for low-resolution images @@ -462,8 +458,7 @@ def stream( in the database will be considered in the decision-making process. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, - Yolov11s and Yolov11m (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', diff --git a/deepface/models/facial_recognition/Yolo.py b/deepface/models/facial_recognition/Yolo.py deleted file mode 100644 index b87f8fc0e..000000000 --- a/deepface/models/facial_recognition/Yolo.py +++ /dev/null @@ -1,88 +0,0 @@ -# built-in dependencies -from typing import List, Any -from enum import Enum - -# 3rd party dependencies -import numpy as np - -# project dependencies -from deepface.models.FacialRecognition import FacialRecognition -from deepface.commons.logger import Logger -from deepface.commons import weight_utils - -logger = Logger() - - -class YoloModel(Enum): - V8N = 0 - V11N = 1 - V11S = 2 - V11M = 3 - - -# Model's weights paths -WEIGHT_NAMES = ["yolov8n-face.pt", - "yolov11n-face.pt", - "yolov11s-face.pt", - "yolov11m-face.pt"] - -# Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB -WEIGHT_URLS = ["https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb", - "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11n-face.pt", - "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11s-face.pt", - "https://github.com/akanametov/yolo-face/releases/download/v0.0.0/yolov11m-face.pt"] - - -class YoloFacialRecognitionClient(FacialRecognition): - def __init__(self, model: YoloModel): - super().__init__() - self.model_name = "Yolo" - self.input_shape = (640, 640) - self.output_shape = 512 - self.model = self.build_model(model) - - def build_model(self, model: YoloModel) -> Any: - """ - Build a yolo detector model - Returns: - model (Any) - """ - - # Import the optional Ultralytics YOLO model - try: - from ultralytics import YOLO - except ModuleNotFoundError as e: - raise ImportError( - "Yolo is an optional detector, ensure the library is installed. " - "Please install using 'pip install ultralytics'" - ) from e - - weight_file = weight_utils.download_weights_if_necessary( - file_name=WEIGHT_NAMES[model.value], source_url=WEIGHT_URLS[model.value] - ) - - # Return face_detector - return YOLO(weight_file) - - def forward(self, img: np.ndarray) -> List[float]: - return self.model.embed(np.squeeze(img, axis=0))[0].tolist() - - -class YoloFacialRecognitionClientV8n(YoloFacialRecognitionClient): - def __init__(self): - super().__init__(YoloModel.V8N) - - -class YoloFacialRecognitionClientV11n(YoloFacialRecognitionClient): - def __init__(self): - super().__init__(YoloModel.V11N) - - -class YoloFacialRecognitionClientV11s(YoloFacialRecognitionClient): - def __init__(self): - super().__init__(YoloModel.V11S) - - -class YoloFacialRecognitionClientV11m(YoloFacialRecognitionClient): - def __init__(self): - super().__init__(YoloModel.V11M) diff --git a/deepface/modules/modeling.py b/deepface/modules/modeling.py index 57a5e76c2..e12804ef8 100644 --- a/deepface/modules/modeling.py +++ b/deepface/modules/modeling.py @@ -11,8 +11,7 @@ SFace, Dlib, Facenet, - GhostFaceNet, - Yolo as YoloFacialRecognition, + GhostFaceNet ) from deepface.models.face_detection import ( FastMtCnn, @@ -37,7 +36,7 @@ def build_model(task: str, model_name: str) -> Any: task (str): facial_recognition, facial_attribute, face_detector, spoofing model_name (str): model identifier - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib, - ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, Yolov11s and Yolov11m for face recognition + ArcFace, SFace and GhostFaceNet for face recognition - Age, Gender, Emotion, Race for facial attributes - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', 'yolov11s', 'yolov11m', yunet, fastmtcnn or centerface for face detectors @@ -60,11 +59,7 @@ def build_model(task: str, model_name: str) -> Any: "Dlib": Dlib.DlibClient, "ArcFace": ArcFace.ArcFaceClient, "SFace": SFace.SFaceClient, - "GhostFaceNet": GhostFaceNet.GhostFaceNetClient, - "Yolov8": YoloFacialRecognition.YoloFacialRecognitionClientV8n, - "Yolov11n": YoloFacialRecognition.YoloFacialRecognitionClientV11n, - "Yolov11s": YoloFacialRecognition.YoloFacialRecognitionClientV11s, - "Yolov11m": YoloFacialRecognition.YoloFacialRecognitionClientV11m + "GhostFaceNet": GhostFaceNet.GhostFaceNetClient }, "spoofing": { "Fasnet": FasNet.Fasnet, diff --git a/deepface/modules/recognition.py b/deepface/modules/recognition.py index 25b36457d..1edb43025 100644 --- a/deepface/modules/recognition.py +++ b/deepface/modules/recognition.py @@ -45,8 +45,7 @@ def find( in the database will be considered in the decision-making process. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, - Yolov11s and Yolov11m (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2'. @@ -361,8 +360,7 @@ def __find_bulk_embeddings( employees (list): list of exact image paths model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, - Yolov11s and Yolov11m (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). detector_backend (str): face detector model name @@ -477,8 +475,7 @@ def find_batched( (used for anti-spoofing). model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, - Yolov11s and Yolov11m (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2'. diff --git a/deepface/modules/representation.py b/deepface/modules/representation.py index b8fcd29fb..c1e2a5fe4 100644 --- a/deepface/modules/representation.py +++ b/deepface/modules/representation.py @@ -30,7 +30,7 @@ def represent( include information for each detected face. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, Yolov11s and Yolov11m + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet enforce_detection (boolean): If no face is detected in an image, raise an exception. Default is True. Set to False to avoid the exception for low-resolution images. diff --git a/deepface/modules/streaming.py b/deepface/modules/streaming.py index 21cc6e7c3..cc447830a 100644 --- a/deepface/modules/streaming.py +++ b/deepface/modules/streaming.py @@ -42,8 +42,7 @@ def analysis( in the database will be considered in the decision-making process. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, - Yolov11s and Yolov11m (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face) detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', @@ -191,8 +190,7 @@ def search_identity( db_path (string): Path to the folder containing image files. All detected faces in the database will be considered in the decision-making process. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, - Yolov11s and Yolov11m (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' (default is opencv). diff --git a/deepface/modules/verification.py b/deepface/modules/verification.py index 83b6f9829..43c3ba9e2 100644 --- a/deepface/modules/verification.py +++ b/deepface/modules/verification.py @@ -44,8 +44,7 @@ def verify( or pre-calculated embeddings. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace, GhostFaceNet, Yolov8, Yolov11n, - Yolov11s and Yolov11m (default is VGG-Face). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', From fcc093240e3cd0c37c3a33d41de45855c742cf54 Mon Sep 17 00:00:00 2001 From: roberto-corno-nttdata Date: Tue, 10 Dec 2024 17:25:16 +0100 Subject: [PATCH 10/12] reverted to original --- benchmarks/Evaluate-Results.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/Evaluate-Results.ipynb b/benchmarks/Evaluate-Results.ipynb index 72d74cc6a..e2a7172d7 100644 --- a/benchmarks/Evaluate-Results.ipynb +++ b/benchmarks/Evaluate-Results.ipynb @@ -29,8 +29,8 @@ "outputs": [], "source": [ "alignment = [False, True]\n", - "models = [\"Facenet512\", \"Facenet\", \"VGG-Face\", \"ArcFace\", \"Dlib\", \"GhostFaceNet\", \"SFace\", \"OpenFace\", \"DeepFace\", \"DeepID\", \"yolov8\", \"yolov11n\", \"yolov11s\", \"yolov11m\"]\n", - "detectors = [\"retinaface\", \"mtcnn\", \"fastmtcnn\", \"dlib\", \"yolov8\", \"yolov11n\", \"yolov11s\", \"yolov11m\", \"yunet\", \"centerface\", \"mediapipe\", \"ssd\", \"opencv\", \"skip\"]\n", + "models = [\"Facenet512\", \"Facenet\", \"VGG-Face\", \"ArcFace\", \"Dlib\", \"GhostFaceNet\", \"SFace\", \"OpenFace\", \"DeepFace\", \"DeepID\"]\n", + "detectors = [\"retinaface\", \"mtcnn\", \"fastmtcnn\", \"dlib\", \"yolov8\", \"yunet\", \"centerface\", \"mediapipe\", \"ssd\", \"opencv\", \"skip\"]\n", "distance_metrics = [\"euclidean\", \"euclidean_l2\", \"cosine\"]" ] }, From 29e41e11e37cc3f19eb4264ad01b78f84a7d6a12 Mon Sep 17 00:00:00 2001 From: roberto-corno-nttdata Date: Tue, 10 Dec 2024 17:25:49 +0100 Subject: [PATCH 11/12] fix: yolo mention removed from recognition models --- README.md | 8 ++------ tests/visual-test.py | 6 +----- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 5addd788c..822f298eb 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,7 @@ Here, embedding is also [plotted](https://sefiks.com/2020/05/01/a-gentle-introdu **Face recognition models** - [`Demo`](https://youtu.be/eKOZawGR3y0) -DeepFace is a **hybrid** face recognition package. It currently wraps many **state-of-the-art** face recognition models: [`VGG-Face`](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/) , [`FaceNet`](https://sefiks.com/2018/09/03/face-recognition-with-facenet-in-keras/), [`OpenFace`](https://sefiks.com/2019/07/21/face-recognition-with-openface-in-keras/), [`DeepFace`](https://sefiks.com/2020/02/17/face-recognition-with-facebook-deepface-in-keras/), [`DeepID`](https://sefiks.com/2020/06/16/face-recognition-with-deepid-in-keras/), [`ArcFace`](https://sefiks.com/2020/12/14/deep-face-recognition-with-arcface-in-keras-and-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), `SFace`, `GhostFaceNet` and `Yolo-Face`. The default configuration uses VGG-Face model. +DeepFace is a **hybrid** face recognition package. It currently wraps many **state-of-the-art** face recognition models: [`VGG-Face`](https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/) , [`FaceNet`](https://sefiks.com/2018/09/03/face-recognition-with-facenet-in-keras/), [`OpenFace`](https://sefiks.com/2019/07/21/face-recognition-with-openface-in-keras/), [`DeepFace`](https://sefiks.com/2020/02/17/face-recognition-with-facebook-deepface-in-keras/), [`DeepID`](https://sefiks.com/2020/06/16/face-recognition-with-deepid-in-keras/), [`ArcFace`](https://sefiks.com/2020/12/14/deep-face-recognition-with-arcface-in-keras-and-python/), [`Dlib`](https://sefiks.com/2020/07/11/face-recognition-with-dlib-in-python/), `SFace` and `GhostFaceNet`. The default configuration uses VGG-Face model. ```python models = [ @@ -121,11 +121,7 @@ models = [ "ArcFace", "Dlib", "SFace", - "GhostFaceNet", - "yolov8", - "yolov11n", - "yolov11s", - "yolov11m" + "GhostFaceNet" ] #face verification diff --git a/tests/visual-test.py b/tests/visual-test.py index 111e9ed11..9dd89863c 100644 --- a/tests/visual-test.py +++ b/tests/visual-test.py @@ -21,11 +21,7 @@ "Dlib", "ArcFace", "SFace", - "GhostFaceNet", - "yolov8", - "yolov11n", - "yolov11s", - "yolov11m" + "GhostFaceNet" ] detector_backends = [ From 9abcbd3d9c2e109cc5c8d8397ae36ae7ad53da24 Mon Sep 17 00:00:00 2001 From: roberto-corno-nttdata Date: Wed, 11 Dec 2024 12:07:52 +0100 Subject: [PATCH 12/12] fix: linting issues --- deepface/DeepFace.py | 7 ++++--- deepface/models/face_detection/Yolo.py | 3 ++- deepface/modules/modeling.py | 4 ++-- deepface/modules/representation.py | 3 ++- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index b192ce0bd..6eb31ac68 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -96,8 +96,8 @@ def verify( OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face). detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip' - (default is opencv). + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', + 'centerface' or 'skip' (default is opencv). distance_metric (string): Metric for measuring similarity. Options: 'cosine', 'euclidean', 'euclidean_l2' (default is cosine). @@ -388,7 +388,8 @@ def represent( include information for each detected face. model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512, - OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet (default is VGG-Face.). + OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet + (default is VGG-Face.). enforce_detection (boolean): If no face is detected in an image, raise an exception. Default is True. Set to False to avoid the exception for low-resolution images diff --git a/deepface/models/face_detection/Yolo.py b/deepface/models/face_detection/Yolo.py index ee756c56c..233f08851 100644 --- a/deepface/models/face_detection/Yolo.py +++ b/deepface/models/face_detection/Yolo.py @@ -95,7 +95,8 @@ def detect_faces(self, img: np.ndarray) -> List[FacialAreaRegion]: right_eye = None left_eye = None - # yolo-facev8 is detecting eyes through keypoints, while for v11 keypoints are always None + # yolo-facev8 is detecting eyes through keypoints, + # while for v11 keypoints are always None if result.keypoints is not None: # right_eye_conf = result.keypoints.conf[0][0] # left_eye_conf = result.keypoints.conf[0][1] diff --git a/deepface/modules/modeling.py b/deepface/modules/modeling.py index e12804ef8..176d9e74b 100644 --- a/deepface/modules/modeling.py +++ b/deepface/modules/modeling.py @@ -38,8 +38,8 @@ def build_model(task: str, model_name: str) -> Any: - VGG-Face, Facenet, Facenet512, OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet for face recognition - Age, Gender, Emotion, Race for facial attributes - - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', 'yolov11s', 'yolov11m', yunet, - fastmtcnn or centerface for face detectors + - opencv, mtcnn, ssd, dlib, retinaface, mediapipe, yolov8, 'yolov11n', + 'yolov11s', 'yolov11m', yunet, fastmtcnn or centerface for face detectors - Fasnet for spoofing Returns: built model class diff --git a/deepface/modules/representation.py b/deepface/modules/representation.py index c1e2a5fe4..bec0b1aed 100644 --- a/deepface/modules/representation.py +++ b/deepface/modules/representation.py @@ -36,7 +36,8 @@ def represent( Default is True. Set to False to avoid the exception for low-resolution images. detector_backend (string): face detector backend. Options: 'opencv', 'retinaface', - 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', 'yolov11m', 'centerface' or 'skip'. + 'mtcnn', 'ssd', 'dlib', 'mediapipe', 'yolov8', 'yolov11n', 'yolov11s', + 'yolov11m', 'centerface' or 'skip'. align (boolean): Perform alignment based on the eye positions.