From f7a1c02a40c784980c735eb27a52c79fb6425c2b Mon Sep 17 00:00:00 2001 From: Wangxueqing <870530576@qq.com> Date: Mon, 21 Oct 2024 15:04:07 +0000 Subject: [PATCH] init commit with face recgnition pipeline --- paddlex/configs/face_detection/BlazeFace.yaml | 39 +++ .../configs/face_recognition/IResNet50.yaml | 41 +++ .../face_recognition/MobileFaceNet.yaml | 41 +++ paddlex/inference/components/__init__.py | 1 + .../components/retrieval/__init__.py | 15 + .../inference/components/retrieval/faiss.py | 256 ++++++++++++++++++ .../inference/components/task_related/clas.py | 8 +- paddlex/inference/models/__init__.py | 1 + paddlex/inference/models/face_recognition.py | 99 +++++++ paddlex/inference/models/object_detection.py | 2 +- paddlex/inference/pipelines/__init__.py | 1 + .../pipelines/face_recognition/__init__.py | 15 + .../face_recognition/face_recognition.py | 125 +++++++++ paddlex/inference/results/__init__.py | 1 + paddlex/inference/results/face_rec.py | 35 +++ paddlex/inference/utils/official_models.py | 3 + paddlex/modules/__init__.py | 7 + paddlex/modules/face_recognition/__init__.py | 18 ++ .../dataset_checker/__init__.py | 71 +++++ .../dataset_checker/dataset_src/__init__.py | 16 ++ .../dataset_src/check_dataset.py | 156 +++++++++++ .../dataset_src/utils/__init__.py | 13 + .../dataset_src/utils/visualizer.py | 156 +++++++++++ paddlex/modules/face_recognition/evaluator.py | 51 ++++ paddlex/modules/face_recognition/exportor.py | 22 ++ .../modules/face_recognition/model_list.py | 18 ++ paddlex/modules/face_recognition/trainer.py | 73 +++++ .../modules/object_detection/model_list.py | 1 + paddlex/pipelines/face_recognition.yaml | 13 + .../repo_apis/PaddleClas_api/cls/register.py | 22 ++ .../PaddleClas_api/configs/IResNet50.yaml | 123 +++++++++ .../PaddleClas_api/configs/MobileFaceNet.yaml | 126 +++++++++ .../configs/BlazeFace.yaml | 144 ++++++++++ .../object_det/register.py | 31 +++ 34 files changed, 1739 insertions(+), 5 deletions(-) create mode 100644 
paddlex/configs/face_detection/BlazeFace.yaml create mode 100644 paddlex/configs/face_recognition/IResNet50.yaml create mode 100644 paddlex/configs/face_recognition/MobileFaceNet.yaml create mode 100644 paddlex/inference/components/retrieval/__init__.py create mode 100644 paddlex/inference/components/retrieval/faiss.py create mode 100644 paddlex/inference/models/face_recognition.py create mode 100644 paddlex/inference/pipelines/face_recognition/__init__.py create mode 100644 paddlex/inference/pipelines/face_recognition/face_recognition.py create mode 100644 paddlex/inference/results/face_rec.py create mode 100644 paddlex/modules/face_recognition/__init__.py create mode 100644 paddlex/modules/face_recognition/dataset_checker/__init__.py create mode 100644 paddlex/modules/face_recognition/dataset_checker/dataset_src/__init__.py create mode 100644 paddlex/modules/face_recognition/dataset_checker/dataset_src/check_dataset.py create mode 100644 paddlex/modules/face_recognition/dataset_checker/dataset_src/utils/__init__.py create mode 100644 paddlex/modules/face_recognition/dataset_checker/dataset_src/utils/visualizer.py create mode 100644 paddlex/modules/face_recognition/evaluator.py create mode 100644 paddlex/modules/face_recognition/exportor.py create mode 100644 paddlex/modules/face_recognition/model_list.py create mode 100644 paddlex/modules/face_recognition/trainer.py create mode 100644 paddlex/pipelines/face_recognition.yaml create mode 100644 paddlex/repo_apis/PaddleClas_api/configs/IResNet50.yaml create mode 100644 paddlex/repo_apis/PaddleClas_api/configs/MobileFaceNet.yaml create mode 100644 paddlex/repo_apis/PaddleDetection_api/configs/BlazeFace.yaml diff --git a/paddlex/configs/face_detection/BlazeFace.yaml b/paddlex/configs/face_detection/BlazeFace.yaml new file mode 100644 index 000000000..a3c505c91 --- /dev/null +++ b/paddlex/configs/face_detection/BlazeFace.yaml @@ -0,0 +1,39 @@ +Global: + model: BlazeFace + mode: check_dataset # 
check_dataset/train/evaluate/predict + dataset_dir: "/paddle/dataset/paddlex/det/widerface_coco_examples" + device: gpu:0,1,2,3 + output: "output" + +CheckDataset: + convert: + enable: False + src_dataset_type: null + split: + enable: False + train_percent: null + val_percent: null + +Train: + epochs_iters: 1000 + batch_size: 4 + learning_rate: 0.001 + pretrain_weight_path: null + warmup_steps: 500 + resume_path: null + log_interval: 10 + eval_interval: 10 + +Evaluate: + weight_path: "output/best_model/best_model.pdparams" + log_interval: 10 + +Export: + weight_path: https://paddledet.bj.bcebos.com/models/blazeface_1000e.pdparams + +Predict: + batch_size: 1 + model_dir: "output/blazeface" + input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/face_detection.png" + kernel_option: + run_mode: paddle diff --git a/paddlex/configs/face_recognition/IResNet50.yaml b/paddlex/configs/face_recognition/IResNet50.yaml new file mode 100644 index 000000000..a7215b92d --- /dev/null +++ b/paddlex/configs/face_recognition/IResNet50.yaml @@ -0,0 +1,41 @@ +Global: + model: IResNet50 + mode: check_dataset # check_dataset/train/evaluate/predict + dataset_dir: "/paddle/dataset/paddlex/cls/face_train_examples" + device: gpu:0,1,2,3 + output: "output" + +CheckDataset: + convert: + enable: False + src_dataset_type: null + split: + enable: False + train_percent: null + val_percent: null + +Train: + num_classes: 995 + epochs_iters: 25 + batch_size: 128 + learning_rate: 0.002 + pretrain_weight_path: null + warmup_steps: 1 + resume_path: null + log_interval: 1 + eval_interval: 1 + save_interval: 1 + +Evaluate: + weight_path: "output/best_model/best_model.pdparams" + log_interval: 1 + +Export: + weight_path: https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/iresnet50.pdparams + +Predict: + batch_size: 1 + model_dir: "output/best_model/inference" + input: 
"https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/face_classification_001.jpg" + kernel_option: + run_mode: paddle diff --git a/paddlex/configs/face_recognition/MobileFaceNet.yaml b/paddlex/configs/face_recognition/MobileFaceNet.yaml new file mode 100644 index 000000000..541370234 --- /dev/null +++ b/paddlex/configs/face_recognition/MobileFaceNet.yaml @@ -0,0 +1,41 @@ +Global: + model: MobileFaceNet + mode: check_dataset # check_dataset/train/evaluate/predict + dataset_dir: "/paddle/dataset/paddlex/cls/face_train_examples" + device: gpu:0,1,2,3 + output: "output" + +CheckDataset: + convert: + enable: False + src_dataset_type: null + split: + enable: False + train_percent: null + val_percent: null + +Train: + num_classes: 995 + epochs_iters: 25 + batch_size: 128 + learning_rate: 0.002 + pretrain_weight_path: null + warmup_steps: 1 + resume_path: null + log_interval: 1 + eval_interval: 1 + save_interval: 1 + +Evaluate: + weight_path: "output/best_model/best_model.pdparams" + log_interval: 1 + +Export: + weight_path: https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/mobilefacenet.pdparams + +Predict: + batch_size: 1 + model_dir: "output/best_model/inference" + input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/face_classification_001.jpg" + kernel_option: + run_mode: paddle diff --git a/paddlex/inference/components/__init__.py b/paddlex/inference/components/__init__.py index b8bfb17d5..072d09e60 100644 --- a/paddlex/inference/components/__init__.py +++ b/paddlex/inference/components/__init__.py @@ -15,3 +15,4 @@ from .transforms import * from .paddle_predictor import * from .task_related import * +from .retrieval import * diff --git a/paddlex/inference/components/retrieval/__init__.py b/paddlex/inference/components/retrieval/__init__.py new file mode 100644 index 000000000..7cfcb5767 --- /dev/null +++ b/paddlex/inference/components/retrieval/__init__.py @@ -0,0 +1,15 @@ +# copyright (c) 2024 
PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .faiss import FaissIndexer diff --git a/paddlex/inference/components/retrieval/faiss.py b/paddlex/inference/components/retrieval/faiss.py new file mode 100644 index 000000000..ce30ed515 --- /dev/null +++ b/paddlex/inference/components/retrieval/faiss.py @@ -0,0 +1,256 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import pickle +from pathlib import Path +import faiss +import numpy as np + +from ....utils import logging +from ..base import BaseComponent + + +class FaissIndexer(BaseComponent): + + INPUT_KEYS = "feature" + OUTPUT_KEYS = ["label", "score"] + DEAULT_INPUTS = {"feature": "feature"} + DEAULT_OUTPUTS = {"label": "label", "score": "score", "unique_id": "unique_id"} + + ENABLE_BATCH = True + + def __init__( + self, + index_dir, + metric_type="IP", + return_k=1, + score_thres=None, + hamming_radius=None, + ): + super().__init__() + index_dir = Path(index_dir) + vector_path = (index_dir / "vector.index").as_posix() + id_map_path = (index_dir / "id_map.pkl").as_posix() + + if metric_type == "hamming": + self._indexer = faiss.read_index_binary(vector_path) + self.hamming_radius = hamming_radius + else: + self._indexer = faiss.read_index(vector_path) + self.score_thres = score_thres + with open(id_map_path, "rb") as fd: + self.id_map = pickle.load(fd) + self.metric_type = metric_type + self.return_k = return_k + self.unique_id_map = {k: v+1 for v, k in enumerate(sorted(set(self.id_map.values())))} + + def apply(self, feature): + """search the index with the query features and return one dict (score/label/unique_id) per query""" + scores_list, ids_list = self._indexer.search(np.array(feature), self.return_k) + preds = [] + for scores, ids in zip(scores_list, ids_list): + labels = [] + for id in ids: + if id >= 0:  # faiss marks missing neighbours with -1; 0 is a valid gallery id + labels.append(self.id_map[id]) + preds.append({"score": scores, + "label": labels, + "unique_id": [self.unique_id_map[l] for l in labels]}) + + if self.metric_type == "hamming": + idxs = np.where(scores_list[:, 0] > self.hamming_radius)[0] if self.hamming_radius is not None else []  # None: no radius filtering + else: + idxs = np.where(scores_list[:, 0] < self.score_thres)[0] if self.score_thres is not None else []  # None: no score filtering + for idx in idxs: + preds[idx] = {"score": None, "label": None, "unique_id": None} + return preds + + +class FaissBuilder: + + SUPPORT_MODE = ("new", "remove", "append") + SUPPORT_METRIC_TYPE = ("hamming", "IP", "L2") + SUPPORT_INDEX_TYPE = ("Flat", "IVF", "HNSW32") + BINARY_METRIC_TYPE = ("hamming", "jaccard") + 
BINARY_SUPPORT_INDEX_TYPE = ("Flat", "IVF", "BinaryHash") + + def __init__(self, predict, mode="new", index_type="HNSW32", metric_type="IP"): + super().__init__() + assert mode in self.SUPPORT_MODE, f"Supported modes only: {self.SUPPORT_MODE}!" + assert ( + metric_type in self.SUPPORT_METRIC_TYPE + ), f"Supported metric types only: {self.SUPPORT_METRIC_TYPE}!" + assert ( + index_type in self.SUPPORT_INDEX_TYPE + ), f"Supported index types only: {self.SUPPORT_INDEX_TYPE}!" + + self._predict = predict + self._mode = mode + self._metric_type = metric_type + self._index_type = index_type + + def _get_index_type(self, num=None): + if self._metric_type in self.BINARY_METRIC_TYPE: + assert ( + self._index_type in self.BINARY_SUPPORT_INDEX_TYPE + ), f"The metric type({self._metric_type}) only support {self.BINARY_SUPPORT_INDEX_TYPE} index types!" + + index_type = self._index_type  # default factory string (e.g. "Flat"); IVF appends an auto-computed list count below + if self._index_type == "IVF": + index_type = self._index_type + str(min(int(num // 8), 65536)) + if self._metric_type in self.BINARY_METRIC_TYPE: + index_type += ",BFlat" + else: + index_type += ",Flat" + + # for binary index, add B at head of index_type + if self._metric_type in self.BINARY_METRIC_TYPE: + return "B" + index_type + + if self._index_type == "HNSW32": + index_type = self._index_type + logging.warning("The HNSW32 method dose not support 'remove' operation") + return index_type + + def _get_metric_type(self): + if self._metric_type == "hamming": + return faiss.METRIC_Hamming + elif self._metric_type == "jaccard": + return faiss.METRIC_Jaccard + elif self._metric_type == "IP": + return faiss.METRIC_INNER_PRODUCT + elif self._metric_type == "L2": + return faiss.METRIC_L2 + + def build( + self, + label_file, + image_root, + index_dir, + ): + file_list, gallery_docs = get_file_list(label_file, image_root) + if self._mode != "remove": + features = [res["feature"] for res in self._predict(file_list)] + dtype = ( + np.uint8 if self._metric_type in self.BINARY_METRIC_TYPE 
else np.float32 + ) + features = np.array(features).astype(dtype) + vector_num, vector_dim = features.shape + + if self._mode in ["remove", "append"]: + # if remove or append, load vector.index and id_map.pkl + index, ids = self._load_index(index_dir) + else: + # build index + if self._metric_type in self.BINARY_METRIC_TYPE: + index = faiss.index_binary_factory( + vector_dim, + self._get_index_type(vector_num), + self._get_metric_type(), + ) + else: + index = faiss.index_factory( + vector_dim, + self._get_index_type(vector_num), + self._get_metric_type(), + ) + index = faiss.IndexIDMap2(index) + ids = {} + + if self._mode != "remove": + # calculate id for new data + index, ids = self._add_gallery(index, ids, features, gallery_docs) + else: + if self._index_type == "HNSW32": + raise RuntimeError( + "The index_type: HNSW32 dose not support 'remove' operation" + ) + # remove ids in id_map, remove index data in faiss index + index, ids = self._rm_id_in_galllery(index, ids, gallery_docs) + + # store faiss index file and id_map file + self._save_gallery(index, ids, index_dir) + + def _load_index(self, index_dir):  # load the existing vector.index and id_map.pkl from index_dir + assert os.path.exists(os.path.join( + index_dir, "vector.index" + )), "The vector.index dose not exist in {} when 'index_operation' is not None".format( + index_dir + ) + assert os.path.exists(os.path.join( + index_dir, "id_map.pkl" + )), "The id_map.pkl dose not exist in {} when 'index_operation' is not None".format( + index_dir + ) + index = (faiss.read_index_binary if self._metric_type in self.BINARY_METRIC_TYPE else faiss.read_index)(os.path.join(index_dir, "vector.index"))  # reader must match the writer used when saving + with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd: + ids = pickle.load(fd) + assert index.ntotal == len( + ids.keys() + ), "data number in index is not equal in in id_map" + return index, ids + + def _add_gallery(self, index, ids, gallery_features, gallery_docs): + start_id = max(ids.keys()) + 1 if ids else 0 + ids_now = (np.arange(0, len(gallery_docs)) + start_id).astype(np.int64) + + # only train when new index file + if self._mode == "new": + if self._metric_type in 
self.BINARY_METRIC_TYPE: + index.add(gallery_features) + else: + index.train(gallery_features) + + if not self._metric_type in self.BINARY_METRIC_TYPE: + index.add_with_ids(gallery_features, ids_now) + + for i, d in zip(list(ids_now), gallery_docs): + ids[i] = d + return index, ids + + def _rm_id_in_galllery(self, index, ids, gallery_docs): + remove_ids = list(filter(lambda k: ids.get(k) in gallery_docs, ids.keys())) + remove_ids = np.asarray(remove_ids) + index.remove_ids(remove_ids) + for k in remove_ids: + del ids[k] + + return index, ids + + def _save_gallery(self, index, ids, index_dir): + Path(index_dir).mkdir(parents=True, exist_ok=True) + if self._metric_type in self.BINARY_METRIC_TYPE: + faiss.write_index_binary(index, os.path.join(index_dir, "vector.index")) + else: + faiss.write_index(index, os.path.join(index_dir, "vector.index")) + + with open(os.path.join(index_dir, "id_map.pkl"), "wb") as fd: + pickle.dump(ids, fd) + + +def get_file_list(data_file, root_dir, delimiter="\t"): + root_dir = Path(root_dir) + files = [] + labels = [] + lines = [] + with open(data_file, "r", encoding="utf-8") as f: + lines = f.readlines() + for line in lines: + path, label = line.strip().split(delimiter) + file_path = root_dir / path + files.append(file_path.as_posix()) + labels.append(label) + + return files, labels \ No newline at end of file diff --git a/paddlex/inference/components/task_related/clas.py b/paddlex/inference/components/task_related/clas.py index 1ecf1d2e4..f1e1356e8 100644 --- a/paddlex/inference/components/task_related/clas.py +++ b/paddlex/inference/components/task_related/clas.py @@ -113,12 +113,12 @@ class NormalizeFeatures(BaseComponent): """Normalize Features Transform""" INPUT_KEYS = ["pred"] - OUTPUT_KEYS = ["rec_feature"] + OUTPUT_KEYS = ["feature"] DEAULT_INPUTS = {"pred": "pred"} - DEAULT_OUTPUTS = {"rec_feature": "rec_feature"} + DEAULT_OUTPUTS = {"feature": "feature"} def apply(self, pred): """apply""" feas_norm = 
np.sqrt(np.sum(np.square(pred[0]), axis=0, keepdims=True)) - rec_feature = np.divide(pred[0], feas_norm) - return {"rec_feature": rec_feature} + feature = np.divide(pred[0], feas_norm) + return {"feature": feature} diff --git a/paddlex/inference/models/__init__.py b/paddlex/inference/models/__init__.py index 49143fa4a..07feeb58f 100644 --- a/paddlex/inference/models/__init__.py +++ b/paddlex/inference/models/__init__.py @@ -34,6 +34,7 @@ from .multilabel_classification import MLClasPredictor from .anomaly_detection import UadPredictor from .formula_recognition import LaTeXOCRPredictor +from .face_recognition import FaceRecPredictor def _create_hp_predictor( diff --git a/paddlex/inference/models/face_recognition.py b/paddlex/inference/models/face_recognition.py new file mode 100644 index 000000000..b165639e8 --- /dev/null +++ b/paddlex/inference/models/face_recognition.py @@ -0,0 +1,99 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np + +from paddlex.utils.func_register import FuncRegister +from paddlex.modules.face_recognition.model_list import MODELS +from ..components import * +from ..results import BaseResult +from .base import BasicPredictor + + +class FaceRecPredictor(BasicPredictor): + + entities = MODELS + + _FUNC_MAP = {} + register = FuncRegister(_FUNC_MAP) + + def _build_components(self): + self._add_component(ReadImage(format="RGB")) + for cfg in self.config["PreProcess"]["transform_ops"]: + tf_key = list(cfg.keys())[0] + func = self._FUNC_MAP[tf_key] + args = cfg.get(tf_key, {}) + op = func(self, **args) if args else func(self) + self._add_component(op) + + predictor = ImagePredictor( + model_dir=self.model_dir, + model_prefix=self.MODEL_FILE_PREFIX, + option=self.pp_option, + ) + self._add_component(predictor) + + post_processes = self.config["PostProcess"] + for key in post_processes: + func = self._FUNC_MAP.get(key) + args = post_processes.get(key, {}) + op = func(self, **args) if args else func(self) + self._add_component(op) + + @register("ResizeImage") + # TODO(gaotingquan): backend & interpolation + def build_resize( + self, + resize_short=None, + size=None, + backend="cv2", + interpolation="LINEAR", + return_numpy=False, + ): + assert resize_short or size + if resize_short: + op = ResizeByShort( + target_short_edge=resize_short, size_divisor=None, interp="LINEAR" + ) + else: + op = Resize(target_size=size) + return op + + @register("CropImage") + def build_crop(self, size=224): + return Crop(crop_size=size) + + @register("NormalizeImage") + def build_normalize( + self, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + scale=1 / 255, + order="", + channel_num=3, + ): + assert channel_num == 3 + return Normalize(mean=mean, std=std) + + @register("ToCHWImage") + def build_to_chw(self): + return ToCHWImage() + + @register("NormalizeFeatures") + def build_normalize_features(self): + return NormalizeFeatures() + + def _pack_res(self, data): + keys 
= ["input_path", "feature"] + return BaseResult({key: data[key] for key in keys}) \ No newline at end of file diff --git a/paddlex/inference/models/object_detection.py b/paddlex/inference/models/object_detection.py index 49d8cb904..bf010442e 100644 --- a/paddlex/inference/models/object_detection.py +++ b/paddlex/inference/models/object_detection.py @@ -53,7 +53,7 @@ def _build_components(self): } ) - if self.model_name == "Blazeface": + if self.model_name == "BlazeFace": predictor.set_inputs( { "img": "img", diff --git a/paddlex/inference/pipelines/__init__.py b/paddlex/inference/pipelines/__init__.py index 8f1da24d4..e2aee8822 100644 --- a/paddlex/inference/pipelines/__init__.py +++ b/paddlex/inference/pipelines/__init__.py @@ -34,6 +34,7 @@ from .ocr import OCRPipeline from .formula_recognition import FormulaRecognitionPipeline from .table_recognition import TableRecPipeline +from .face_recognition import FaceRecPipeline from .seal_recognition import SealOCRPipeline from .ppchatocrv3 import PPChatOCRPipeline from .layout_parsing import LayoutParsingPipeline diff --git a/paddlex/inference/pipelines/face_recognition/__init__.py b/paddlex/inference/pipelines/face_recognition/__init__.py new file mode 100644 index 000000000..d4e1ff6dc --- /dev/null +++ b/paddlex/inference/pipelines/face_recognition/__init__.py @@ -0,0 +1,15 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .face_recognition import FaceRecPipeline diff --git a/paddlex/inference/pipelines/face_recognition/face_recognition.py b/paddlex/inference/pipelines/face_recognition/face_recognition.py new file mode 100644 index 000000000..7fb9bfeca --- /dev/null +++ b/paddlex/inference/pipelines/face_recognition/face_recognition.py @@ -0,0 +1,125 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from paddlex.inference.components import CropByBoxes, FaissIndexer +from paddlex.inference.components.retrieval.faiss import FaissBuilder +from paddlex.inference.results import FaceRecResult +from paddlex.inference.utils.io import ImageReader +from ..base import BasePipeline + + +class FaceRecPipeline(BasePipeline): + """Face Recognition Pipeline""" + + entities = "face_recognition" + + def __init__( + self, + det_model, + rec_model, + det_batch_size=1, + rec_batch_size=1, + index_dir=None, + metric_type="IP", + score_thres=None, + hamming_radius=None, + return_k=5, + device=None, + predictor_kwargs=None, + ): + super().__init__(device, predictor_kwargs) + self._build_predictor(det_model, rec_model) + self.set_predictor(det_batch_size, rec_batch_size, device) + self._indexer = ( + FaissIndexer(index_dir, metric_type, return_k, score_thres, hamming_radius) + if index_dir + else None + ) + + def _build_predictor(self, det_model, rec_model): + self.det_model = self._create(model=det_model) + 
self.rec_model = self._create(model=rec_model) + self._crop_by_boxes = CropByBoxes() + self._img_reader = ImageReader(backend="opencv") + + def set_predictor(self, det_batch_size=None, rec_batch_size=None, device=None): + if det_batch_size: + self.det_model.set_predictor(batch_size=det_batch_size) + if rec_batch_size: + self.rec_model.set_predictor(batch_size=rec_batch_size) + if device: + self.det_model.set_predictor(device=device) + self.rec_model.set_predictor(device=device) + + def predict(self, input, **kwargs): + assert self._indexer + self.set_predictor(**kwargs) + for det_res in self.det_model(input): + rec_res = self.get_rec_result(det_res) + yield self.get_final_result(det_res, rec_res) + + def get_rec_result(self, det_res): + full_img = self._img_reader.read(det_res["input_path"]) + w, h = full_img.shape[:2] + # det_res["boxes"].append( + # {"cls_id": 0, "label": "full_img", "score": 0, "coordinate": [0, 0, h, w]} + # ) + subs_of_img = list(self._crop_by_boxes(det_res)) + img_list = [img["img"] for img in subs_of_img] + all_rec_res = list(self.rec_model(img_list)) + all_rec_res = next(self._indexer(all_rec_res)) + output = {"label": [], "score": [], "unique_id": []} + for res in all_rec_res: + output["label"].append(res["label"]) + output["score"].append(res["score"]) + output["unique_id"].append(res["unique_id"]) + return output + + def get_final_result(self, det_res, rec_res): + single_img_res = {"input_path": det_res["input_path"], "boxes": []} + for i, obj in enumerate(det_res["boxes"]): + rec_scores = rec_res["score"][i] + labels = rec_res["label"][i] + rec_ids = rec_res["unique_id"][i] + single_img_res["boxes"].append( + { + "labels": labels, + "rec_scores": rec_scores, + "rec_ids": rec_ids, + "det_score": obj["score"], + "coordinate": obj["coordinate"], + } + ) + return FaceRecResult(single_img_res) + + def build_index( + self, + label_file, + image_root, + index_dir, + mode="new", + metric_type="IP", + index_type="HNSW32", + **kwargs, + ): + 
self.set_predictor(**kwargs) + builder = FaissBuilder( + self.rec_model.predict, + mode=mode, + metric_type=metric_type, + index_type=index_type, + ) + builder.build(label_file, image_root, index_dir) + return diff --git a/paddlex/inference/results/__init__.py b/paddlex/inference/results/__init__.py index 0438063f4..63fe66ce4 100644 --- a/paddlex/inference/results/__init__.py +++ b/paddlex/inference/results/__init__.py @@ -26,3 +26,4 @@ from .ts import TSFcResult, TSAdResult, TSClsResult from .warp import DocTrResult from .chat_ocr import * +from .face_rec import FaceRecResult diff --git a/paddlex/inference/results/face_rec.py b/paddlex/inference/results/face_rec.py new file mode 100644 index 000000000..e86311bf5 --- /dev/null +++ b/paddlex/inference/results/face_rec.py @@ -0,0 +1,35 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +from .base import CVResult +from .det import draw_box + + +class FaceRecResult(CVResult): + + def _to_img(self): + """apply""" + image = self._img_reader.read(self["input_path"]) + boxes = [ + { + "coordinate": box["coordinate"], + "label": box["labels"][0] if box["labels"] is not None else "Unknown", + "score": box["det_score"], + "cls_id": box["rec_ids"][0] if box["rec_ids"] is not None else 0 # rec ids start from 1 + } + for box in self["boxes"] + ] + image = draw_box(image, boxes) + return image diff --git a/paddlex/inference/utils/official_models.py b/paddlex/inference/utils/official_models.py index e4a39c8c3..419d6434e 100644 --- a/paddlex/inference/utils/official_models.py +++ b/paddlex/inference/utils/official_models.py @@ -258,6 +258,9 @@ "RT-DETR-H_layout_3cls": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/RT-DETR-H_layout_3cls_infer.tar", "RT-DETR-H_layout_17cls": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/RT-DETR-H_layout_17cls_infer.tar", "PicoDet_LCNet_x2_5_face": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/PicoDet_LCNet_x2_5_face_infer.tar", + "BlazeFace": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/BlazeFace_infer.tar", + "MobileFaceNet": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/MobileFaceNet_infer.tar", + "IResNet50": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/IResNet50_infer.tar" } diff --git a/paddlex/modules/__init__.py b/paddlex/modules/__init__.py index 0369f2d58..27c23eeec 100644 --- a/paddlex/modules/__init__.py +++ b/paddlex/modules/__init__.py @@ -95,4 +95,11 @@ TSCLSExportor, ) +from .face_recognition import ( + FaceRecDatasetChecker, + FaceRecTrainer, + FaceRecEvaluator, + FaceRecExportor, +) + from .ts_forecast import 
TSFCDatasetChecker, TSFCTrainer, TSFCEvaluator diff --git a/paddlex/modules/face_recognition/__init__.py b/paddlex/modules/face_recognition/__init__.py new file mode 100644 index 000000000..c9092df7d --- /dev/null +++ b/paddlex/modules/face_recognition/__init__.py @@ -0,0 +1,18 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .trainer import FaceRecTrainer +from .dataset_checker import FaceRecDatasetChecker +from .evaluator import FaceRecEvaluator +from .exportor import FaceRecExportor diff --git a/paddlex/modules/face_recognition/dataset_checker/__init__.py b/paddlex/modules/face_recognition/dataset_checker/__init__.py new file mode 100644 index 000000000..a7f2fac8a --- /dev/null +++ b/paddlex/modules/face_recognition/dataset_checker/__init__.py @@ -0,0 +1,71 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class FaceRecDatasetChecker(BaseDatasetChecker):
    """Dataset Checker for Face Recognition Model"""

    entities = MODELS
    sample_num = 10

    def get_dataset_root(self, dataset_dir: str) -> str:
        """find the dataset root dir

        The root is located through the ``images`` directories found below
        ``dataset_dir``; exactly two are expected.

        Args:
            dataset_dir (str): the directory that contain dataset.

        Returns:
            str: the root directory of dataset.

        Raises:
            AssertionError: if the number of ``images`` directories is not 2.
        """
        anno_dirs = list(Path(dataset_dir).glob("**/images"))
        assert len(anno_dirs) == 2, (
            f"Expected exactly 2 `images` directories under {dataset_dir}, "
            f"but found {len(anno_dirs)}."
        )
        # <root>/<split>/images -> <root>
        return anno_dirs[0].parent.parent.as_posix()

    def check_dataset(self, dataset_dir: str, sample_num: int = sample_num) -> dict:
        """check if the dataset meets the specifications and get dataset summary

        Args:
            dataset_dir (str): the root directory of dataset.
            sample_num (int): the number to be sampled.
        Returns:
            dict: dataset summary, the train attributes merged with the
                validation attributes.
        """
        # Forward `sample_num`; it used to be accepted here but silently ignored.
        train_attr = check_train(
            os.path.join(dataset_dir, "train"), self.output, sample_num
        )
        val_attr = check_val(os.path.join(dataset_dir, "val"), self.output, sample_num)
        train_attr.update(val_attr)
        return train_attr

    def get_show_type(self) -> str:
        """get the show type of dataset

        Returns:
            str: show type
        """
        return "image"

    def get_dataset_type(self) -> str:
        """return the dataset type

        Returns:
            str: dataset type
        """
        return "ClsDataset"
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .check_dataset import check_train, check_val diff --git a/paddlex/modules/face_recognition/dataset_checker/dataset_src/check_dataset.py b/paddlex/modules/face_recognition/dataset_checker/dataset_src/check_dataset.py new file mode 100644 index 000000000..939a51113 --- /dev/null +++ b/paddlex/modules/face_recognition/dataset_checker/dataset_src/check_dataset.py @@ -0,0 +1,156 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def _read_shuffled_lines(label_file):
    """Read all rows of *label_file* and shuffle them with a fixed seed."""
    with open(label_file, "r", encoding="utf-8") as f:
        all_lines = f.readlines()
    # Fixed seed keeps the sampled demo images reproducible between runs.
    random.seed(123)
    random.shuffle(all_lines)
    return all_lines


def _split_row(line, label_file, delim, valid_num_parts):
    """Split one annotation row and verify its number of fields."""
    fields = line.strip("\n").split(delim)
    if len(fields) != valid_num_parts:
        raise CheckFailedError(
            f"The number of delimiter-separated items in each row in {label_file} "
            f"should be {valid_num_parts} (current delimiter is '{delim}')."
        )
    return fields


def check_train(dataset_dir, output, sample_num=10):
    """Validate the training split and render a few labelled demo images.

    ``<dataset_dir>/label.txt`` must hold one ``<image path> <int label>`` row
    per sample, separated by a single space, with labels starting from 0.

    Args:
        dataset_dir (str): directory of the training split.
        output (str): directory that receives the ``demo_img`` visualisations.
        sample_num (int): maximum number of demo images to render.

    Returns:
        dict: ``train_label_file``, ``train_num_classes``, ``train_samples``
            and ``train_sample_paths``.

    Raises:
        DatasetFileNotFoundError: if the split directory, ``label.txt`` or a
            listed image is missing.
        CheckFailedError: if a row is malformed, a label is not an int, the
            label file is empty or the labels do not start from 0.
    """
    dataset_dir = osp.abspath(dataset_dir)
    # Custom dataset
    if not osp.isdir(dataset_dir):
        raise DatasetFileNotFoundError(file_path=dataset_dir)

    delim = " "
    valid_num_parts = 2

    label_file = osp.join(dataset_dir, "label.txt")
    if not osp.exists(label_file):
        raise DatasetFileNotFoundError(
            file_path=label_file,
            solution=f"Ensure that `label.txt` exist in {dataset_dir}",
        )

    all_lines = _read_shuffled_lines(label_file)
    sample_cnts = len(all_lines)

    vis_save_dir = osp.join(output, "demo_img")
    os.makedirs(vis_save_dir, exist_ok=True)

    label_map_dict = {}
    sample_paths = []
    labels = []
    for line in all_lines:
        file_name, label = _split_row(line, label_file, delim, valid_num_parts)

        img_path = osp.join(dataset_dir, file_name)
        if not osp.exists(img_path):
            raise DatasetFileNotFoundError(file_path=img_path)

        try:
            label = int(label)
        except (ValueError, TypeError) as e:
            raise CheckFailedError(
                f"Ensure that the second number in each line in {label_file} should be int."
            ) from e
        label_map_dict[label] = str(label)

        if len(sample_paths) < sample_num:
            vis_path = osp.join(vis_save_dir, osp.basename(file_name))
            # Close the source file as soon as the visualisation is written.
            with Image.open(img_path) as img:
                vis_im = draw_label(
                    ImageOps.exif_transpose(img), label, label_map_dict
                )
                vis_im.save(vis_path)
            sample_paths.append(
                osp.join("check_dataset", os.path.relpath(vis_path, output))
            )
        labels.append(label)

    if not labels:
        raise CheckFailedError(f"Ensure that `{label_file}` is not empty.")
    if min(labels) != 0:
        raise CheckFailedError(
            f"Ensure that the index starts from 0 in `{label_file}`."
        )

    return {
        "train_label_file": osp.relpath(label_file, output),
        "train_num_classes": max(labels) + 1,
        "train_samples": sample_cnts,
        "train_sample_paths": sample_paths,
    }


def check_val(dataset_dir, output, sample_num=10):
    """Validate the pair-wise evaluation split.

    ``<dataset_dir>/pair_label.txt`` must hold one
    ``<left image> <right image> <0|1>`` row per pair, separated by a single
    space.

    Args:
        dataset_dir (str): directory of the validation split.
        output (str): output directory; only used to relativise the reported
            label file path.
        sample_num (int): accepted for signature parity with ``check_train``;
            no demo images are rendered for pairs.

    Returns:
        dict: ``val_label_file``, ``val_num_classes`` and ``val_samples``.

    Raises:
        DatasetFileNotFoundError: if the split directory, ``pair_label.txt``
            or a listed image is missing.
        CheckFailedError: if a row is malformed, a label is not 0 or 1, or the
            label file is empty.
    """
    dataset_dir = osp.abspath(dataset_dir)
    # Custom dataset
    if not osp.isdir(dataset_dir):
        raise DatasetFileNotFoundError(file_path=dataset_dir)

    delim = " "
    valid_num_parts = 3

    label_file = osp.join(dataset_dir, "pair_label.txt")
    if not osp.exists(label_file):
        raise DatasetFileNotFoundError(
            file_path=label_file,
            solution=f"Ensure that `pair_label.txt` exist in {dataset_dir}",
        )

    all_lines = _read_shuffled_lines(label_file)
    sample_cnts = len(all_lines)

    labels = []
    for line in all_lines:
        left_file_name, right_file_name, label = _split_row(
            line, label_file, delim, valid_num_parts
        )

        for file_name in (left_file_name, right_file_name):
            img_path = osp.join(dataset_dir, file_name)
            if not osp.exists(img_path):
                raise DatasetFileNotFoundError(file_path=img_path)

        try:
            label = int(label)
        except (ValueError, TypeError) as e:
            raise CheckFailedError(
                f"Ensure that the third item in each line in {label_file} should be int."
            ) from e
        # Explicit check instead of `assert`, which disappears under `python -O`.
        if label not in (0, 1):
            raise CheckFailedError(
                f"Face eval dataset only support two classes, "
                f"but got label {label} in {label_file}."
            )
        labels.append(label)

    if not labels:
        raise CheckFailedError(f"Ensure that `{label_file}` is not empty.")

    return {
        "val_label_file": osp.relpath(label_file, output),
        "val_num_classes": max(labels) + 1,
        "val_samples": sample_cnts,
    }
def colormap(rgb=False):
    """
    Get colormap

    Args:
        rgb (bool): return the channels in RGB order when True, otherwise in
            reversed (BGR) order.

    Returns:
        numpy.ndarray: int32 array of shape (20, 3), one colour per row.
    """
    color_list = np.array(
        [
            [0xFF, 0x00, 0x00],
            [0xCC, 0xFF, 0x00],
            [0x00, 0xFF, 0x66],
            [0x00, 0x66, 0xFF],
            [0xCC, 0x00, 0xFF],
            [0xFF, 0x4D, 0x00],
            [0x80, 0xFF, 0x00],
            [0x00, 0xFF, 0xB2],
            [0x00, 0x1A, 0xFF],
            [0xFF, 0x00, 0xE5],
            [0xFF, 0x99, 0x00],
            [0x33, 0xFF, 0x00],
            [0x00, 0xFF, 0xFF],
            [0x33, 0x00, 0xFF],
            [0xFF, 0x00, 0x99],
            [0xFF, 0xE5, 0x00],
            [0x00, 0xFF, 0x1A],
            [0x00, 0xB2, 0xFF],
            [0x80, 0x00, 0xFF],
            [0xFF, 0x00, 0x4D],
        ]
    ).astype(np.float32)
    if not rgb:
        color_list = color_list[:, ::-1]
    return color_list.astype("int32")


def font_colormap(color_index):
    """
    Get font colormap

    Args:
        color_index (int): row index into the table returned by ``colormap``.

    Returns:
        numpy.ndarray: int32 colour triple, white for the indexes listed as
            light and a dark navy otherwise.
    """
    dark = np.array([0x14, 0x0E, 0x35])
    light = np.array([0xFF, 0xFF, 0xFF])
    light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
    chosen = light if color_index in light_indexs else dark
    return chosen.astype("int32")


def _text_size(draw, text, font):
    """Return ``(width, height)`` of *text* rendered with *font*.

    Feature detection replaces the former ``PIL.__version__`` comparison:
    ``textsize`` is used wherever it still exists, ``textbbox`` otherwise, so
    the choice no longer depends on parsing the version string.
    """
    if hasattr(draw, "textsize"):
        return draw.textsize(text, font)
    left, top, right, bottom = draw.textbbox((0, 0), text, font)
    return right - left, bottom - top


def draw_label(image, label, label_map_dict):
    """Draw label on image

    The label text is written on a filled rectangle in the top-left corner,
    using the largest font between 2% and 5% of the image width that still
    fits the image width.

    Args:
        image (PIL.Image.Image): image to annotate; it is converted to RGB and
            the converted copy is drawn on.
        label: class id, converted with ``int`` before the lookup.
        label_map_dict (dict): maps int class ids to the text to draw.

    Returns:
        PIL.Image.Image: the annotated RGB image.
    """
    image = image.convert("RGB")
    image_width = image.size[0]
    draw = ImageDraw.Draw(image)
    text = label_map_dict[int(label)]

    # Clamp to 1 so that narrow images never request a zero-sized font.
    min_font_size = max(1, int(image_width * 0.02))
    max_font_size = max(min_font_size, int(image_width * 0.05))
    for font_size in range(max_font_size, min_font_size - 1, -1):
        font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
        text_width_tmp, _ = _text_size(draw, text, font)
        if text_width_tmp <= image_width:
            break
    else:
        # Nothing fits: fall back to the smallest allowed size.
        font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, min_font_size)

    color = tuple(colormap(rgb=True)[0])
    font_color = tuple(font_colormap(3))
    text_width, text_height = _text_size(draw, text, font)

    rect_left = 3
    rect_top = 3
    rect_right = rect_left + text_width + 3
    rect_bottom = rect_top + text_height + 6

    draw.rectangle([(rect_left, rect_top), (rect_right, rect_bottom)], fill=color)

    text_x = rect_left + 3
    text_y = rect_top
    draw.text((text_x, text_y), text, fill=font_color, font=font)

    return image
class FaceRecEvaluator(BaseEvaluator):
    """Face Recognition Model Evaluator"""

    entities = MODELS

    def update_config(self):
        """Apply the evaluation settings to the underlying model config."""
        log_interval = self.eval_config.log_interval
        if log_interval:
            self.pdx_config.update_log_interval(log_interval)
        self.update_dataset_cfg()
        self.pdx_config.update_pretrained_weights(self.eval_config.weight_path)

    def update_dataset_cfg(self):
        """Point the eval data loader at ``<dataset_dir>/val`` and its pair list."""
        val_root = abspath(os.path.join(self.global_config.dataset_dir, "val"))
        pair_file = abspath(os.path.join(val_root, "pair_label.txt"))
        overrides = {
            "DataLoader.Eval.dataset.name": "FaceEvalDataset",
            "DataLoader.Eval.dataset.dataset_root": val_root,
            "DataLoader.Eval.dataset.pair_label_path": pair_file,
        }
        self.pdx_config.update([f"{key}={value}" for key, value in overrides.items()])

    def get_eval_kwargs(self) -> dict:
        """Build the keyword arguments for the model evaluation function.

        Returns:
            dict: ``weight_path`` from the eval config and a single-device
                ``device`` string.
        """
        weight_path = self.eval_config.weight_path
        device = self.get_device(using_device_number=1)
        return {"weight_path": weight_path, "device": device}


class FaceRecExportor(BaseExportor):
    """Face Recognition Model Exportor"""

    entities = MODELS
# Model names handled by the face recognition trainer/evaluator/exportor.
MODELS = [
    "MobileFaceNet",
    # Canonical spelling, matching the `IResNet50.yaml` configs and the
    # "IResNet50" official inference model key.
    "IResNet50",
    # Legacy spelling, kept so existing references to it keep resolving.
    "IResnet50",
]
class FaceRecTrainer(ClsTrainer):
    """Face Recognition Model Trainer"""

    entities = MODELS

    def update_config(self):
        """Apply the training settings to the underlying model config."""
        train_cfg = self.train_config
        pdx_cfg = self.pdx_config

        # Logging / evaluation / checkpoint cadence: only when set to a truthy value.
        if train_cfg.log_interval:
            pdx_cfg.update_log_interval(train_cfg.log_interval)
        if train_cfg.eval_interval:
            pdx_cfg.update_eval_interval(train_cfg.eval_interval)
        if train_cfg.save_interval:
            pdx_cfg.update_save_interval(train_cfg.save_interval)

        self.update_dataset_cfg()

        if train_cfg.num_classes is not None:
            pdx_cfg.update_num_classes(train_cfg.num_classes)
        if train_cfg.pretrain_weight_path != "":
            pdx_cfg.update_pretrained_weights(train_cfg.pretrain_weight_path)

        label_dict_path = Path(self.global_config.dataset_dir) / "label.txt"
        if label_dict_path.exists():
            self.dump_label_dict(label_dict_path)

        # Optimisation hyper-parameters: only when explicitly provided.
        if train_cfg.batch_size is not None:
            pdx_cfg.update_batch_size(train_cfg.batch_size)
        if train_cfg.learning_rate is not None:
            pdx_cfg.update_learning_rate(train_cfg.learning_rate)
        if train_cfg.epochs_iters is not None:
            pdx_cfg._update_epochs(train_cfg.epochs_iters)
        if train_cfg.warmup_steps is not None:
            pdx_cfg.update_warmup_epochs(train_cfg.warmup_steps)
        if self.global_config.output is not None:
            pdx_cfg._update_output_dir(self.global_config.output)

    def update_dataset_cfg(self):
        """Point the train/eval data loaders at the ``train`` and ``val`` splits."""
        dataset_root = self.global_config.dataset_dir
        train_root = abspath(os.path.join(dataset_root, "train"))
        val_root = abspath(os.path.join(dataset_root, "val"))
        train_list = abspath(os.path.join(train_root, "label.txt"))
        pair_list = abspath(os.path.join(val_root, "pair_label.txt"))

        overrides = {
            "DataLoader.Train.dataset.name": "ClsDataset",
            "DataLoader.Train.dataset.image_root": train_root,
            "DataLoader.Train.dataset.cls_label_path": train_list,
            "DataLoader.Eval.dataset.name": "FaceEvalDataset",
            "DataLoader.Eval.dataset.dataset_root": val_root,
            "DataLoader.Eval.dataset.pair_label_path": pair_list,
        }
        self.pdx_config.update([f"{key}={value}" for key, value in overrides.items()])
"IResnet50.yaml"), + "supported_apis": ["train", "evaluate", "predict", "export", "infer"], + "infer_config": "deploy/configs/inference_cls.yaml", + "hpi_config_path": None, + } +) diff --git a/paddlex/repo_apis/PaddleClas_api/configs/IResNet50.yaml b/paddlex/repo_apis/PaddleClas_api/configs/IResNet50.yaml new file mode 100644 index 000000000..fcd78adda --- /dev/null +++ b/paddlex/repo_apis/PaddleClas_api/configs/IResNet50.yaml @@ -0,0 +1,123 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: output/face_arcface_ir50 + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 25 + print_batch_step: 20 + use_visualdl: False + eval_mode: face_recognition + retrieval_feature_from: backbone + flip_test: True + feature_normalize: False + re_ranking: False + use_dali: False + # used for static mode and model export + image_shape: [3, 112, 112] + save_inference_dir: ./inference + +AMP: + scale_loss: 27648.0 + use_dynamic_loss_scaling: True + # O1: mixed fp16 + level: O1 + +# model architecture +Arch: + name: RecModel + infer_output_key: features + infer_add_softmax: False + + Backbone: + name: FresResNet50 + Head: + name: ArcMargin + embedding_size: 512 + class_num: 93431 + margin: 0.5 + scale: 64 +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Piecewise + decay_epochs: [10, 16, 22] + values: [0.02, 0.002, 0.0002, 0.00002] + by_epoch: True + regularizer: + name: L2 + coeff: 0.0005 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: dataset/MS1M_v3/ + cls_label_path: dataset/MS1M_v3/label.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + backend: cv2 + - RandFlipImage: + flip_code: 1 + - ResizeImage: + size: [112, 112] + return_numpy: False + interpolation: bilinear + backend: cv2 + - NormalizeImage: + scale: 1.0/255.0 + 
mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: hwc + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True + + Eval: + dataset: + name: FiveEvalDataset + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + backend: cv2 + - ResizeImage: + size: [112, 112] + return_numpy: False + interpolation: bilinear + backend: cv2 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: hwc + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Metric: + Eval: + - BestAccuracy: {} diff --git a/paddlex/repo_apis/PaddleClas_api/configs/MobileFaceNet.yaml b/paddlex/repo_apis/PaddleClas_api/configs/MobileFaceNet.yaml new file mode 100644 index 000000000..c33eba3bb --- /dev/null +++ b/paddlex/repo_apis/PaddleClas_api/configs/MobileFaceNet.yaml @@ -0,0 +1,126 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 25 + print_batch_step: 20 + use_visualdl: False + eval_mode: face_recognition + retrieval_feature_from: backbone + flip_test: True + feature_normalize: False + re_ranking: False + use_dali: False + # used for static mode and model export + image_shape: [3, 112, 112] + save_inference_dir: ./inference + +AMP: + scale_loss: 27648 + use_dynamic_loss_scaling: True + # O1: mixed fp16 + level: O1 + +# model architecture +Arch: + name: RecModel + infer_output_key: features + infer_add_softmax: False + + Backbone: + name: MobileFaceNet + Head: + name: ArcMargin + embedding_size: 128 + class_num: 93431 + margin: 0.5 + scale: 64 +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.05 
+ one_dim_param_no_weight_decay: True + lr: + # for 8 cards + name: Cosine + learning_rate: 4e-3 # lr 4e-3 for total_batch_size 1024 + eta_min: 1e-6 + warmup_epoch: 1 + warmup_start_lr: 0 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: dataset/MS1M_v3/ + cls_label_path: dataset/MS1M_v3/label.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + backend: cv2 + - RandFlipImage: + flip_code: 1 + - ResizeImage: + size: [112, 112] + return_numpy: False + interpolation: bilinear + backend: cv2 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: hwc + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True + + Eval: + dataset: + name: FiveEvalDataset + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + backend: cv2 + - ResizeImage: + size: [112, 112] + return_numpy: False + interpolation: bilinear + backend: cv2 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: hwc + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Metric: + Eval: + - BestAccuracy: {} diff --git a/paddlex/repo_apis/PaddleDetection_api/configs/BlazeFace.yaml b/paddlex/repo_apis/PaddleDetection_api/configs/BlazeFace.yaml new file mode 100644 index 000000000..f7d95d3d8 --- /dev/null +++ b/paddlex/repo_apis/PaddleDetection_api/configs/BlazeFace.yaml @@ -0,0 +1,144 @@ +# Runtime +use_gpu: true +use_xpu: false +use_mlu: false +use_npu: false +log_iter: 20 +save_dir: output +print_flops: false +print_params: false +weights: output/blazeface_1000e/model_final +snapshot_epoch: 10 + +# Model +architecture: BlazeFace +BlazeFace: + backbone: BlazeNet + neck: BlazeNeck + blaze_head: FaceHead + post_process: BBoxPostProcess +BlazeNet: + 
blaze_filters: [[24, 24], [24, 24], [24, 48, 2], [48, 48], [48, 48]] + double_blaze_filters: [[48, 24, 96, 2], [96, 24, 96], [96, 24, 96], + [96, 24, 96, 2], [96, 24, 96], [96, 24, 96]] + act: relu +BlazeNeck: + neck_type : None + in_channel: [96,96] +FaceHead: + in_channels: [96,96] + anchor_generator: AnchorGeneratorSSD + loss: SSDLoss +SSDLoss: + overlap_threshold: 0.35 +AnchorGeneratorSSD: + steps: [8., 16.] + aspect_ratios: [[1.], [1.]] + min_sizes: [[16.,24.], [32., 48., 64., 80., 96., 128.]] + max_sizes: [[], []] + offset: 0.5 + flip: False + min_max_aspect_ratios_order: false +BBoxPostProcess: + decode: + name: SSDBox + nms: + name: MultiClassNMS + keep_top_k: 750 + score_threshold: 0.01 + nms_threshold: 0.3 + nms_top_k: 5000 + nms_eta: 1.0 + +# Optimizer +epoch: 1000 +LearningRate: + base_lr: 0.001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: + - 333 + - 800 + - !LinearWarmup + start_factor: 0.3333333333333333 + steps: 500 +OptimizerBuilder: + optimizer: + momentum: 0.0 + type: RMSProp + regularizer: + factor: 0.0005 + type: L2 + +# Dataset +metric: WiderFace +num_classes: 1 +TrainDataset: + !WIDERFaceDataSet + dataset_dir: dataset/wider_face + anno_path: wider_face_split/wider_face_train_bbx_gt.txt + image_dir: WIDER_train/images + data_fields: ['image', 'gt_bbox', 'gt_class'] +EvalDataset: + !WIDERFaceValDataset + dataset_dir: dataset/wider_face + image_dir: WIDER_val/images + anno_path: wider_face_split/wider_face_val_bbx_gt.txt + gt_mat_path: WIDER_val/ground_truth + data_fields: ['image', 'gt_bbox', 'gt_class', 'ori_gt_bbox'] +TestDataset: + !ImageFolder + use_default_label: true + +# Reader +worker_num: 8 +TrainReader: + inputs_def: + num_max_boxes: 90 + sample_transforms: + - Decode: {} + - RandomDistort: {brightness: [0.5, 1.125, 0.875], random_apply: False} + - RandomExpand: {fill_value: [123.675, 116.28, 103.53]} + - RandomFlip: {} + - CropWithDataAchorSampling: { + anchor_sampler: [[1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 
0.0]], + batch_sampler: [ + [1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0], + [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0], + [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0], + [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0], + [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0], + ], + target_size: 640} + - Resize: {target_size: [640, 640], keep_ratio: False, interp: 1} + - NormalizeBox: {} + - PadBox: {num_max_boxes: 90} + batch_transforms: + - NormalizeImage: {mean: [123, 117, 104], std: [127.502231, 127.502231, 127.502231], is_scale: false} + - Permute: {} + batch_size: 16 + shuffle: true + drop_last: true +EvalReader: + sample_transforms: + - Decode: {} + - NormalizeImage: {mean: [123, 117, 104], std: [127.502231, 127.502231, 127.502231], is_scale: false} + - Permute: {} + batch_size: 1 + collate_samples: false + shuffle: false + drop_last: false +TestReader: + sample_transforms: + - Decode: {} + - NormalizeImage: {mean: [123, 117, 104], std: [127.502231, 127.502231, 127.502231], is_scale: false} + - Permute: {} + batch_size: 1 + +# Exporting the model +export: + post_process: True # Whether post-processing is included in the network when export model. + nms: True # Whether NMS is included in the network when export model. + benchmark: False # It is used to testing model performance, if set `True`, post-process and NMS will not be exported. 
+ fuse_conv_bn: False diff --git a/paddlex/repo_apis/PaddleDetection_api/object_det/register.py b/paddlex/repo_apis/PaddleDetection_api/object_det/register.py index 20cb4ea58..d01a00517 100644 --- a/paddlex/repo_apis/PaddleDetection_api/object_det/register.py +++ b/paddlex/repo_apis/PaddleDetection_api/object_det/register.py @@ -837,3 +837,34 @@ }, } ) + + +register_model_info( + { + "model_name": "BlazeFace", + "suite": "Det", + "config_path": osp.join(PDX_CONFIG_DIR, "BlazeFace.yaml"), + "supported_apis": ["train", "evaluate", "predict", "export", "infer"], + "supported_dataset_types": ["WIDERFaceDataset"], + "supported_train_opts": { + "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"], + "dy2st": False, + "amp": ["OFF"], + }, + } +) + +register_model_info( + { + "model_name": "BlazeFace_FPN_SSH", + "suite": "Det", + "config_path": osp.join(PDX_CONFIG_DIR, "BlazeFace-FPN-SSH.yaml"), + "supported_apis": ["train", "evaluate", "predict", "export", "infer"], + "supported_dataset_types": ["WIDERFaceDataset"], + "supported_train_opts": { + "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"], + "dy2st": False, + "amp": ["OFF"], + }, + } +) \ No newline at end of file