Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
- do alignment first and detection second optionally
- expand face area with percentage instead of pixels
  • Loading branch information
serengil committed Jan 2, 2024
1 parent af524f5 commit 018a5d2
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 12 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

[![PyPI Downloads](https://static.pepy.tech/personalized-badge/retina-face?period=total&units=international_system&left_color=grey&right_color=blue&left_text=pip%20downloads)](https://pepy.tech/project/retina-face)
[![Conda Downloads](https://img.shields.io/conda/dn/conda-forge/retina-face?color=green&label=conda%20downloads)](https://anaconda.org/conda-forge/retina-face)
[![Stars](https://img.shields.io/github/stars/serengil/retinaface?color=yellow)](https://github.com/serengil/retinaface)
[![Stars](https://img.shields.io/github/stars/serengil/retinaface?color=yellow&style=flat)](https://github.com/serengil/retinaface/stargazers)
[![License](http://img.shields.io/:license-MIT-green.svg?style=flat)](https://github.com/serengil/retinaface/blob/master/LICENSE)
[![Tests](https://github.com/serengil/retinaface/actions/workflows/tests.yml/badge.svg)](https://github.com/serengil/retinaface/actions/workflows/tests.yml)

Expand Down
47 changes: 39 additions & 8 deletions retinaface/RetinaFace.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

from retinaface.model import retinaface_model
from retinaface.commons import preprocess, postprocess
from retinaface.commons.logger import Logger

logger = Logger(module="retinaface/RetinaFace.py")

# pylint: disable=global-variable-undefined, no-name-in-module, unused-import, too-many-locals, redefined-outer-name, too-many-statements, too-many-arguments

Expand Down Expand Up @@ -208,6 +211,7 @@ def extract_faces(
align: bool = True,
allow_upscaling: bool = True,
expand_face_area: int = 0,
align_first: bool = False,
) -> list:
"""
Extract detected and aligned faces
Expand All @@ -216,8 +220,10 @@ def extract_faces(
threshold (float): detection threshold
model (Model): pre-trained model can be passed to the function
align (bool): enable or disable alignment
allow_upscaling (bool)
expand_face_area (int): set this to something to expand facial area with given pixels
allow_upscaling (bool): allowing up-scaling
expand_face_area (int): expand detected facial area with a percentage
align_first (bool): set this True to align first and detect second
this can be applied only if input image has just one face
"""
resp = []

Expand All @@ -231,16 +237,31 @@ def extract_faces(
img_path=img, threshold=threshold, model=model, allow_upscaling=allow_upscaling
)

if align_first is True and len(obj) > 1:
logger.warn(
f"Even though align_first is set to True, there are {len(obj)} faces in input image."
"Align first functionality can be applied only if there is single face in the input"
)

if isinstance(obj, dict):
for _, identity in obj.items():
facial_area = identity["facial_area"]

x = facial_area[0]
y = facial_area[1]
w = facial_area[2]
h = facial_area[3]

# expand the facial area to be extracted and stay within img.shape limits
x1 = max(0, facial_area[0] - expand_face_area) # expand left
y1 = max(0, facial_area[1] - expand_face_area) # expand top
x2 = min(img.shape[1], facial_area[2] + expand_face_area) # expand right
y2 = min(img.shape[0], facial_area[3] + expand_face_area) # expand bottom
facial_img = img[y1:y2, x1:x2]
x1 = max(0, x - int((w * expand_face_area) / 100)) # expand left
y1 = max(0, y - int((h * expand_face_area) / 100)) # expand top
x2 = min(img.shape[1], w + int((w * expand_face_area) / 100)) # expand right
y2 = min(img.shape[0], h + int((h * expand_face_area) / 100)) # expand bottom

if align_first is False or (align_first is True and len(obj) > 1):
facial_img = img[y1:y2, x1:x2]
else:
facial_img = img.copy()

if align is True:
landmarks = identity["landmarks"]
Expand All @@ -249,9 +270,19 @@ def extract_faces(
nose = landmarks["nose"]
# mouth_right = landmarks["mouth_right"]
# mouth_left = landmarks["mouth_left"]

facial_img = postprocess.alignment_procedure(facial_img, right_eye, left_eye, nose)

if align_first is True and len(obj) == 1:
facial_img = extract_faces(
img_path=facial_img,
threshold=threshold,
model=model,
allow_upscaling=allow_upscaling,
expand_face_area=expand_face_area,
align=False,
align_first=False,
)[0][:, :, ::-1]

resp.append(facial_img[:, :, ::-1])

return resp
13 changes: 10 additions & 3 deletions retinaface/commons/postprocess.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
from typing import Union
import numpy as np
from PIL import Image

Expand All @@ -7,16 +8,22 @@


def findEuclideanDistance(
source_representation: np.ndarray, test_representation: np.ndarray
source_representation: Union[np.ndarray, list], test_representation: Union[np.ndarray, list]
) -> float:
"""
Find euclidean distance between 2 vectors
Args:
source_representation (numpy array)
test_representation (numpy array)
source_representation (numpy array or list)
test_representation (numpy array or list)
Returns
distance
"""
if isinstance(source_representation, list):
source_representation = np.array(source_representation)

if isinstance(test_representation, list):
test_representation = np.array(test_representation)

euclidean_distance = source_representation - test_representation
euclidean_distance = np.sum(np.multiply(euclidean_distance, euclidean_distance))
euclidean_distance = np.sqrt(euclidean_distance)
Expand Down
Binary file added tests/dataset/couple.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
55 changes: 55 additions & 0 deletions tests/test_align_first.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import numpy as np
from retinaface import RetinaFace
from retinaface.commons.logger import Logger

logger = Logger("tests/test_actions.py")

THRESHOLD = 1000


def test_detect_first():
"""
Test the default behavior. Detect first and align second causes
so many black pixels
"""
faces = RetinaFace.extract_faces(img_path="tests/dataset/img11.jpg")
num_black_pixels = np.sum(np.all(faces[0] == 0, axis=2))
assert num_black_pixels > THRESHOLD
logger.info("✅ Disabled align_first test for single face photo done")


def test_align_first():
"""
Test align first behavior. Align first and detect second do not cause
so many black pixels in contrast to default behavior
"""
faces = RetinaFace.extract_faces(img_path="tests/dataset/img11.jpg", align_first=True)
num_black_pixels = np.sum(np.all(faces[0] == 0, axis=2))
assert num_black_pixels < THRESHOLD
logger.info("✅ Enabled align_first test for single face photo done")


def test_align_first_for_group_photo():
"""
Align first will not work if the given image has many faces and
it will cause so many black pixels
"""
faces = RetinaFace.extract_faces(img_path="tests/dataset/couple.jpg", align_first=True)
for face in faces:
num_black_pixels = np.sum(np.all(face == 0, axis=2))
assert num_black_pixels > THRESHOLD

logger.info("✅ Enabled align_first test for group photo done")


def test_default_behavior_for_group_photo():
"""
Align first will not work in the default behaviour and
it will cause so many black pixels
"""
faces = RetinaFace.extract_faces(img_path="tests/dataset/couple.jpg")
for face in faces:
num_black_pixels = np.sum(np.all(face == 0, axis=2))
assert num_black_pixels > THRESHOLD

logger.info("✅ Disabled align_first test for group photo done")
34 changes: 34 additions & 0 deletions tests/test_expand_face_area.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import cv2
from retinaface import RetinaFace
from retinaface.commons import postprocess
from retinaface.commons.logger import Logger

logger = Logger("tests/test_expand_face_area.py")


def test_expand_face_area():
img_path = "tests/dataset/img11.jpg"
default_faces = RetinaFace.extract_faces(img_path=img_path, expand_face_area=10)

img1 = default_faces[0]
img1 = cv2.resize(img1, (500, 500))

obj1 = RetinaFace.detect_faces(img1, threshold=0.1)

expanded_faces = RetinaFace.extract_faces(img_path=img_path, expand_face_area=50)

img2 = expanded_faces[0]
img2 = cv2.resize(img2, (500, 500))

obj2 = RetinaFace.detect_faces(img2, threshold=0.1)

landmarks1 = obj1["face_1"]["landmarks"]
landmarks2 = obj2["face_1"]["landmarks"]

distance1 = postprocess.findEuclideanDistance(landmarks1["right_eye"], landmarks1["left_eye"])
distance2 = postprocess.findEuclideanDistance(landmarks2["right_eye"], landmarks2["left_eye"])

# 2nd one's expand ratio is higher. so, it should be smaller.
assert distance2 < distance1

logger.info("✅ Test expand face area is done")

0 comments on commit 018a5d2

Please sign in to comment.