diff --git a/mmdetection/mmdet/engine/hooks/submission_hook.py b/mmdetection/mmdet/engine/hooks/submission_hook.py
index 2a65bad..abcbef3 100644
--- a/mmdetection/mmdet/engine/hooks/submission_hook.py
+++ b/mmdetection/mmdet/engine/hooks/submission_hook.py
@@ -9,7 +9,7 @@
 from mmdet.registry import HOOKS
 from mmdet.structures import DetDataSample
 import pandas as pd
-
+import re
 
 
 @HOOKS.register_module()
@@ -28,16 +28,20 @@ class SubmissionHook(Hook):
         test_out_dir (str) : directory the submission file is written to
     """
 
-    def __init__(self, test_out_dir='submit'):
-        self.prediction_strings = []
-        self.file_names = []
+    def __init__(self, test_out_dir="submit", mode="test", out_file="submission"):
+        self.test_outputs_data = []
         self.test_out_dir = test_out_dir
-
-
-    def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
-                        outputs: Sequence[DetDataSample]) -> None:
-        """
-        Run after every testing iterations.
+        self.mode = mode
+        self.out_file = out_file
+
+    def after_test_iter(
+        self,
+        runner: Runner,
+        batch_idx: int,
+        data_batch: dict,
+        outputs: Sequence[DetDataSample],
+    ) -> None:
+        """Run after every testing iteration.
 
         Args:
             runner (:obj:`Runner`): The runner of the testing process.
@@ -46,18 +50,37 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
             outputs (Sequence[:obj:`DetDataSample`]): A batch of data samples
                 that contain annotations and predictions.
         """
-        assert len(outputs) == 1, \
-            'only batch_size=1 is supported while testing.'
         for output in outputs:
-            prediction_string = ''
-            for label, score, bbox in zip(output.pred_instances.labels, output.pred_instances.scores, output.pred_instances.bboxes):
+            prediction_string = ""
+            for label, score, bbox in zip(
+                output.pred_instances.labels,
+                output.pred_instances.scores,
+                output.pred_instances.bboxes,
+            ):
                 bbox = bbox.cpu().numpy()  # already in xyxy format
-                prediction_string += str(int(label.cpu())) + ' ' + str(float(score.cpu())) + ' ' + str(bbox[0]) + ' ' + str(bbox[1]) + ' ' + str(bbox[2]) + ' ' + str(bbox[3]) + ' '
-            self.prediction_strings.append(prediction_string)
-            self.file_names.append(output.img_path[-13:])
-
+                prediction_string += (
+                    f"{int(label.cpu())} {float(score.cpu())} "
+                    f"{bbox[0]} {bbox[1]} {bbox[2]} {bbox[3]} "
+                )
+            match = re.search(rf"{self.mode}/(\d+\.jpg)", output.img_path)
+            if match:
+                self.test_outputs_data.append(
+                    [int(match.group(1)[:4]), prediction_string, match.group(0)]
+                )
+            else:
+                raise ValueError(
+                    f"img_path {output.img_path} does not match the expected "
+                    f"'{self.mode}/<image id>.jpg' pattern in SubmissionHook"
+                )
 
     def after_test(self, runner: Runner):
         """
@@ -67,13 +90,23 @@ def after_test(self, runner: Runner):
 
         Args:
             runner (:obj:`Runner`): The runner of the testing process.
""" if self.test_out_dir is not None: - self.test_out_dir = osp.join(runner.work_dir, runner.timestamp, - self.test_out_dir) + self.test_out_dir = osp.join(runner.work_dir, self.test_out_dir) mkdir_or_exist(self.test_out_dir) + self.test_outputs_data.sort(key=lambda x: x[0]) + + prediction_strings = [] + file_names = [] + for _, predict, file_name in self.test_outputs_data: + prediction_strings.append(predict) + file_names.append(file_name) submission = pd.DataFrame() - submission['PredictionString'] = self.prediction_strings - submission['image_id'] = self.file_names - submission.to_csv(osp.join(self.test_out_dir, 'submission.csv'), index=None) - print('submission saved to {}'.format(osp.join(self.test_out_dir, 'submission.csv'))) - \ No newline at end of file + submission["PredictionString"] = prediction_strings + submission["image_id"] = file_names + print(submission.head()) + submission.to_csv(osp.join(self.test_out_dir, f"{self.out_file}.csv"), index=None) + print( + "submission saved to {}".format( + osp.join(self.test_out_dir, "submission.csv") + ) + ) diff --git a/mmdetection/tools/multi_model_csv_test.py b/mmdetection/tools/multi_model_csv_test.py new file mode 100644 index 0000000..e87d94d --- /dev/null +++ b/mmdetection/tools/multi_model_csv_test.py @@ -0,0 +1,234 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Support for multi-model fusion, and currently only the Weighted Box Fusion +(WBF) fusion method is supported. + +References: https://github.com/ZFTurbo/Weighted-Boxes-Fusion + +Example: + + python demo/demo_multi_model.py demo/demo.jpg \ + ./configs/faster_rcnn/faster-rcnn_r50-caffe_fpn_1x_coco.py \ + ./configs/retinanet/retinanet_r50-caffe_fpn_1x_coco.py \ + --checkpoints \ + https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_caffe_fpn_1x_coco/faster_rcnn_r50_caffe_fpn_1x_coco_bbox_mAP-0.378_20200504_180032-c5925ee5.pth \ # noqa + https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_caffe_fpn_1x_coco/retinanet_r50_caffe_fpn_1x_coco_20200531-f11027c5.pth \ + --weights 1 2 +""" + +import argparse +import os.path as osp +import os +from datetime import datetime + +import mmcv +from mmengine.structures import InstanceData + +from mmdet.apis import DetInferencer + +# from mmdet.models.utils import weighted_boxes_fusion + +import torch +from pycocotools.coco import COCO +from torchmetrics.detection import MeanAveragePrecision +from tqdm import tqdm + +from ensemble_boxes import * +import pandas as pd +import numpy as np + + +IMG_EXTENSIONS = ( + ".jpg", + ".jpeg", + ".png", + ".ppm", + ".bmp", + ".pgm", + ".tif", + ".tiff", + ".webp", +) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="MMDetection multi-model inference demo" + ) + parser.add_argument( + "type", + type=str, + default="wbf", # avg, max, box_and_model_avg, absent_model_aware_avg + help="ensemble type", + ) + parser.add_argument( + "csv", + type=str, + nargs="*", + help="CSV file(s), support receive multiple files", + ) + parser.add_argument( + "--conf-type", + type=str, + default="avg", # avg, max, box_and_model_avg, absent_model_aware_avg + help="how to calculate confidence in weighted boxes in wbf", + ) + parser.add_argument( + "--weights", + type=float, + nargs="*", + default=None, + help="weights for each model, remember to " "correspond to the above config", + ) + parser.add_argument( + "--fusion-iou-thr", + type=float, + default=0.55, + help="IoU value for boxes to be a match in wbf", + ) + parser.add_argument( + "--skip-box-thr", 
+        type=float,
+        default=0.0,
+        help="exclude boxes with score lower than this variable in wbf",
+    )
+    parser.add_argument(
+        "--out-dir",
+        type=str,
+        default="outputs",
+        help="Output directory of images or prediction results.",
+    )
+    parser.add_argument("--device", default="cuda:0", help="Device used for inference")
+    args = parser.parse_args()
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    print(f"Ensemble Type : {args.type}")
+    print(f"Ensemble weight : {args.weights}")
+    print(f"Ensemble iou thr : {args.fusion_iou_thr}")
+    print(f"Ensemble skip box thr : {args.skip_box_thr}")
+    print(f"Ensemble wbf type : {args.conf_type}")
+
+    now = datetime.now()
+    out_path = osp.join(args.out_dir, "result", now.strftime("%Y_%m_%d_%H_%M_%S"))
+    os.makedirs(out_path, exist_ok=True)
+
+    image_size = 1024
+    results = []
+
+    submission_df = [pd.read_csv(file) for file in args.csv]
+    image_ids = submission_df[0]["image_id"].tolist()
+
+    prediction_strings = []
+    file_names = []
+
+    for i, image_id in tqdm(enumerate(image_ids), total=len(image_ids)):
+        prediction_string = ""
+        boxes_list = []
+        scores_list = []
+        labels_list = []
+
+        for df in submission_df:
+            predict_string = df[df["image_id"] == image_id][
+                "PredictionString"
+            ].tolist()[0]
+            predict_list = str(predict_string).split()
+
+            if len(predict_list) == 0 or len(predict_list) == 1:
+                continue
+
+            predict_list = np.reshape(predict_list, (-1, 6))
+            box_list = []
+
+            # ensemble_boxes expects coordinates normalized to [0, 1]
+            for bbox in predict_list[:, 2:6].tolist():
+                bbox[0] = float(bbox[0]) / image_size
+                bbox[1] = float(bbox[1]) / image_size
+                bbox[2] = float(bbox[2]) / image_size
+                bbox[3] = float(bbox[3]) / image_size
+                box_list.append(bbox)
+
+            boxes_list.append(box_list)
+            scores_list.append(list(map(float, predict_list[:, 1].tolist())))
+            labels_list.append(list(map(int, predict_list[:, 0].tolist())))
+        results.append([boxes_list, scores_list, labels_list])
+
+        if len(boxes_list):
+            if args.type == "nms":
+                bboxes, scores, labels = nms(
+                    boxes_list,
+                    scores_list,
+                    labels_list,
+                    weights=args.weights,
+                    iou_thr=args.fusion_iou_thr,
+                )
+            elif args.type == "soft_nms":
+                bboxes, scores, labels = soft_nms(
+                    boxes_list,
+                    scores_list,
+                    labels_list,
+                    weights=args.weights,
+                    iou_thr=args.fusion_iou_thr,
+                    mode=2,
+                    sigma=0.5,
+                )
+            elif args.type == "nmw":
+                bboxes, scores, labels = non_maximum_weighted(
+                    boxes_list,
+                    scores_list,
+                    labels_list,
+                    weights=args.weights,
+                    iou_thr=args.fusion_iou_thr,
+                    skip_box_thr=args.skip_box_thr,
+                )
+            elif args.type == "wbf":
+                bboxes, scores, labels = weighted_boxes_fusion(
+                    boxes_list,
+                    scores_list,
+                    labels_list,
+                    weights=args.weights,
+                    iou_thr=args.fusion_iou_thr,
+                    skip_box_thr=args.skip_box_thr,
+                    conf_type=args.conf_type,
+                )
+            else:
+                raise ValueError(f"unsupported ensemble type: {args.type}")
+
+            # scale the fused boxes back to pixel coordinates
+            for bbox, score, label in zip(bboxes, scores, labels):
+                prediction_string += (
+                    f"{int(label)} {score} "
+                    f"{bbox[0] * image_size} {bbox[1] * image_size} "
+                    f"{bbox[2] * image_size} {bbox[3] * image_size} "
+                )
+        prediction_strings.append(prediction_string)
+        file_names.append(image_id)
+
+    submission = pd.DataFrame()
+    submission["PredictionString"] = prediction_strings
+    submission["image_id"] = file_names
+    submission.to_csv(osp.join(out_path, "output.csv"), index=False)
+
+    with open(os.path.join(out_path, "log.txt"), "w") as f:
+        f.write("Config & Weight List\n")
+        for csv_name in args.csv:
+            slice_config_name = csv_name.split("/")[-1]
+            f.write(f"{slice_config_name}\n")
+        f.write("Ensemble Config\n")
+        f.write(f"Ensemble Type : 
{args.type}\n") + f.write(f"Ensemble weight : {args.weights}\n") + f.write(f"Ensemble iou thr : {args.fusion_iou_thr}\n") + f.write(f"Ensemble skip box thr : {args.skip_box_thr}\n") + f.write(f"Ensemble wbf type : {args.conf_type}\n") + + +if __name__ == "__main__": + main() diff --git a/mmdetection/tools/multi_model_csv_val.py b/mmdetection/tools/multi_model_csv_val.py new file mode 100644 index 0000000..9f28bdb --- /dev/null +++ b/mmdetection/tools/multi_model_csv_val.py @@ -0,0 +1,286 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Support for multi-model fusion, and currently only the Weighted Box Fusion +(WBF) fusion method is supported. + +References: https://github.com/ZFTurbo/Weighted-Boxes-Fusion + +Example: + + python demo/demo_multi_model.py demo/demo.jpg \ + ./configs/faster_rcnn/faster-rcnn_r50-caffe_fpn_1x_coco.py \ + ./configs/retinanet/retinanet_r50-caffe_fpn_1x_coco.py \ + --checkpoints \ + https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_caffe_fpn_1x_coco/faster_rcnn_r50_caffe_fpn_1x_coco_bbox_mAP-0.378_20200504_180032-c5925ee5.pth \ # noqa + https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_caffe_fpn_1x_coco/retinanet_r50_caffe_fpn_1x_coco_20200531-f11027c5.pth \ + --weights 1 2 +""" + +import argparse +import os.path as osp +import os +from datetime import datetime + +import mmcv +from mmengine.structures import InstanceData + +from mmdet.apis import DetInferencer + +# from mmdet.models.utils import weighted_boxes_fusion + +import torch +from pycocotools.coco import COCO +from torchmetrics.detection import MeanAveragePrecision +from tqdm import tqdm + +from ensemble_boxes import * +import pandas as pd +import numpy as np + + +IMG_EXTENSIONS = ( + ".jpg", + ".jpeg", + ".png", + ".ppm", + ".bmp", + ".pgm", + ".tif", + ".tiff", + ".webp", +) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="MMDetection multi-model inference demo" + ) + parser.add_argument( + "type", + type=str, + default="wbf", # avg, max, box_and_model_avg, absent_model_aware_avg + help="ensemble type", + ) + parser.add_argument( + "csv", + type=str, + nargs="*", + help="CSV file(s), support receive multiple files", + ) + parser.add_argument( + "--conf-type", + type=str, + default="avg", # avg, max, box_and_model_avg, absent_model_aware_avg + help="how to calculate confidence in weighted boxes in wbf", + ) + parser.add_argument( + "--weights", + type=float, + nargs="*", + default=None, + help="weights for each model, remember to " "correspond to the above config", + ) + parser.add_argument( + "--fusion-iou-thr", + type=float, + default=0.55, + help="IoU value for boxes to be a match in wbf", + ) + parser.add_argument( + "--skip-box-thr", + type=float, + default=0.0, + help="exclude boxes with score lower than this variable in wbf", + ) + parser.add_argument( + "--out-dir", + type=str, + default="outputs", + help="Output directory of images or prediction results.", + ) + parser.add_argument("--device", default="cuda:0", help="Device used for inference") + args = parser.parse_args() + + return args + + +def main(): + args = parse_args() + + print(f"Ensemble Type : {args.type}") + print(f"Ensemble weight : {args.weights}") + print(f"Ensemble iou thr : {args.fusion_iou_thr}") + print(f"Ensemble skip box thr : {args.skip_box_thr}") + print(f"Ensemble wbf type : {args.conf_type}") + + now = datetime.now() + out_path = osp.join(args.out_dir, "result", now.strftime("%Y_%m_%d_%H_%M_%S")) + os.makedirs(out_path, exist_ok=True) + + 
image_size = 1024 + val_json = "/home/hojun/Documents/code/boostcamp/project2/version1/dataset/val_eye_eda.json" + results = [] + pred_base_list = [] + + submission_df = [pd.read_csv(file) for file in args.csv] + image_ids = submission_df[0]["image_id"].tolist() + + for i, image_id in tqdm(enumerate(image_ids), total=len(image_ids)): + prediction_string = "" + boxes_list = [] + scores_list = [] + labels_list = [] + + for df in submission_df: + predict_string = df[df["image_id"] == image_id][ + "PredictionString" + ].tolist()[0] + predict_list = str(predict_string).split() + + if len(predict_list) == 0 or len(predict_list) == 1: + continue + + predict_list = np.reshape(predict_list, (-1, 6)) + box_list = [] + + for box in predict_list[:, 2:6].tolist(): + box[0] = float(box[0]) / image_size + box[1] = float(box[1]) / image_size + box[2] = float(box[2]) / image_size + box[3] = float(box[3]) / image_size + box_list.append(box) + + boxes_list.append(box_list) + scores_list.append(list(map(float, predict_list[:, 1].tolist()))) + labels_list.append(list(map(int, predict_list[:, 0].tolist()))) + results.append([boxes_list, scores_list, labels_list]) + + for i in tqdm(range(len(results))): + if args.type == "nms": + bboxes, scores, labels = nms( + results[i][0], + results[i][1], + results[i][2], + weights=args.weights, + iou_thr=args.fusion_iou_thr, + # skip_box_thr=args.skip_box_thr, + # conf_type=args.conf_type, + ) + if args.type == "soft_nms": + bboxes, scores, labels = soft_nms( + results[i][0], + results[i][1], + results[i][2], + weights=args.weights, + iou_thr=args.fusion_iou_thr, + mode=2, + sigma=0.5 + # skip_box_thr=args.skip_box_thr, + # conf_type=args.conf_type, + ) + elif args.type == "nmw": + bboxes, scores, labels = non_maximum_weighted( + results[i][0], + results[i][1], + results[i][2], + weights=args.weights, + iou_thr=args.fusion_iou_thr, + skip_box_thr=args.skip_box_thr, + ) + elif args.type == "wbf": + bboxes, scores, labels = weighted_boxes_fusion( + results[i][0], + results[i][1], + results[i][2], + weights=args.weights, + iou_thr=args.fusion_iou_thr, + skip_box_thr=args.skip_box_thr, + conf_type=args.conf_type, + ) + + pred_instances = InstanceData() + pred_instances.bboxes = bboxes + pred_instances.scores = scores + pred_instances.labels = labels + + pred_instances_dict = {"boxes": [], "scores": [], "labels": []} + pred_instances_dict["boxes"] = torch.tensor(bboxes).to(args.device) + pred_instances_dict["scores"] = torch.tensor(scores).to(args.device) + pred_instances_dict["labels"] = torch.tensor(labels).to(args.device) + + pred_base_list.append(pred_instances_dict) + + coco = COCO(val_json) + + gt_base_list = [] + for imgs in tqdm(coco.imgs.values()): + gt_instances = {"boxes": [], "labels": []} + file_name = imgs["file_name"] + + image_info = coco.loadImgs( + coco.getImgIds(imgIds=[int(file_name.split("/")[-1].split(".")[0])]) + )[0] + annotation_ids = coco.getAnnIds(imgIds=image_info["id"]) + annotations = coco.loadAnns(annotation_ids) + for annotation in annotations: + bbox = annotation["bbox"] + bbox[2] = bbox[0] + bbox[2] + bbox[3] = bbox[1] + bbox[3] + class_id = annotation["category_id"] + gt_instances["boxes"].append(bbox) + gt_instances["labels"].append(class_id) + + gt_instances["boxes"] = torch.tensor(gt_instances["boxes"]).to(args.device) + gt_instances["labels"] = torch.tensor(gt_instances["labels"]).to(args.device) + gt_base_list.append(gt_instances) + + base_metric = MeanAveragePrecision(iou_type="bbox", class_metrics=True) + base_metric50 = 
MeanAveragePrecision( + iou_type="bbox", class_metrics=True, iou_thresholds=[0.5] + ) + for idx in tqdm(range(len(gt_base_list))): + pred_base_list[idx]["boxes"] = torch.round( + pred_base_list[idx]["boxes"] * image_size, decimals=1 + ) + pred_base_list[idx]["labels"] = pred_base_list[idx]["labels"].type(torch.int) + base_metric.update([pred_base_list[idx]], [gt_base_list[idx]]) + base_metric50.update([pred_base_list[idx]], [gt_base_list[idx]]) + + base_metric_score = base_metric.compute() + base_metric50_score = base_metric50.compute() + + score_dict = {} + base_score_names = ["mAP", "mAP50", "mAP75", "mAR_1", "mAR_10", "mAR_100"] + base_socre_indexs = ["map", "map_50", "map_75", "mar_1", "mar_10", "mar_100"] + labels = [ + "General trash", + "Paper", + "Paper pack", + "Metal", + "Glass", + "Plastic", + "Styrofoam", + "Plastic bag", + "Battery", + "Clothing", + ] + for score_name, score_index in zip(base_score_names, base_socre_indexs): + score_dict[f"val_{score_name}"] = base_metric_score[score_index] + for index, label in enumerate(labels): + score_dict[f"val_{label}_mAP50"] = base_metric50_score["map_per_class"][index] + score_dict[f"val_{label}_mAR100"] = base_metric_score["mar_100_per_class"][ + index + ] + + with open(os.path.join(out_path, "log.txt"), "w") as f: + f.write("Config & Weight List\n") + for csv_name in args.csv: + slice_config_name = csv_name.split("/")[-1] + f.write(f"{slice_config_name}\n") + f.write("Scores\n") + for key, value in score_dict.items(): + f.write(f"{key} : {value}\n") + print(score_dict) + + +if __name__ == "__main__": + main() diff --git a/mmdetection/tools/multi_model_val.py b/mmdetection/tools/multi_model_val.py new file mode 100644 index 0000000..ab950e8 --- /dev/null +++ b/mmdetection/tools/multi_model_val.py @@ -0,0 +1,274 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Support for multi-model fusion, and currently only the Weighted Box Fusion +(WBF) fusion method is supported. 
+ +References: https://github.com/ZFTurbo/Weighted-Boxes-Fusion + +Example: + + python demo/demo_multi_model.py demo/demo.jpg \ + ./configs/faster_rcnn/faster-rcnn_r50-caffe_fpn_1x_coco.py \ + ./configs/retinanet/retinanet_r50-caffe_fpn_1x_coco.py \ + --checkpoints \ + https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_caffe_fpn_1x_coco/faster_rcnn_r50_caffe_fpn_1x_coco_bbox_mAP-0.378_20200504_180032-c5925ee5.pth \ # noqa + https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_caffe_fpn_1x_coco/retinanet_r50_caffe_fpn_1x_coco_20200531-f11027c5.pth \ + --weights 1 2 +""" + +import argparse +import os.path as osp +import os +from datetime import datetime + +import mmcv +from mmengine.structures import InstanceData + +from mmdet.apis import DetInferencer + +# from mmdet.models.utils import weighted_boxes_fusion + +import torch +from pycocotools.coco import COCO +from torchmetrics.detection import MeanAveragePrecision +from tqdm import tqdm + +from ensemble_boxes import * + + +IMG_EXTENSIONS = ( + ".jpg", + ".jpeg", + ".png", + ".ppm", + ".bmp", + ".pgm", + ".tif", + ".tiff", + ".webp", +) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="MMDetection multi-model inference demo" + ) + parser.add_argument("inputs", type=str, help="Input image file or folder path.") + parser.add_argument( + "config", + type=str, + nargs="*", + help="Config file(s), support receive multiple files", + ) + parser.add_argument( + "--checkpoints", + type=str, + nargs="*", + help="Checkpoint file(s), support receive multiple files, " + "remember to correspond to the above config", + ) + parser.add_argument( + "--weights", + type=float, + nargs="*", + default=None, + help="weights for each model, remember to " "correspond to the above config", + ) + parser.add_argument( + "--fusion-iou-thr", + type=float, + default=0.55, + help="IoU value for boxes to be a match in wbf", + ) + parser.add_argument( + "--skip-box-thr", + type=float, + default=0.0, + help="exclude boxes with score lower than this variable in wbf", + ) + parser.add_argument( + "--conf-type", + type=str, + default="avg", # avg, max, box_and_model_avg, absent_model_aware_avg + help="how to calculate confidence in weighted boxes in wbf", + ) + parser.add_argument( + "--out-dir", + type=str, + default="outputs", + help="Output directory of images or prediction results.", + ) + parser.add_argument("--device", default="cuda:0", help="Device used for inference") + parser.add_argument( + "--pred-score-thr", type=float, default=0.3, help="bbox score threshold" + ) + parser.add_argument( + "--batch-size", type=int, default=1, help="Inference batch size." + ) + parser.add_argument( + "--show", action="store_true", help="Display the image in a popup window." 
+ ) + parser.add_argument( + "--no-save-vis", action="store_true", help="Do not save detection vis results" + ) + parser.add_argument( + "--no-save-pred", action="store_true", help="Do not save detection json results" + ) + parser.add_argument( + "--palette", + default="none", + choices=["coco", "voc", "citys", "random", "none"], + help="Color palette used for visualization", + ) + + args = parser.parse_args() + + if args.no_save_vis and args.no_save_pred: + args.out_dir = "" + + return args + + +def main(): + args = parse_args() + + results = [] + + inputs = [] + filename_list = [] + + coco = COCO(args.inputs) + + gt_base_list = [] + for imgs in coco.imgs.values(): + gt_instances = {"boxes": [], "labels": []} + file_name = imgs["file_name"] + image_path = "/".join(args.inputs.split("/")[:-1]) + "/" + file_name + img = mmcv.imread(image_path) + inputs.append(img) + + filename_list.append(file_name) + image_info = coco.loadImgs( + coco.getImgIds(imgIds=[int(file_name.split("/")[-1].split(".")[0])]) + )[0] + annotation_ids = coco.getAnnIds(imgIds=image_info["id"]) + annotations = coco.loadAnns(annotation_ids) + for annotation in annotations: + bbox = annotation["bbox"] + class_id = annotation["category_id"] + gt_instances["boxes"].append(bbox) + gt_instances["labels"].append(class_id) + + gt_instances["boxes"] = torch.tensor(gt_instances["boxes"]).to(args.device) + gt_instances["labels"] = torch.tensor(gt_instances["labels"]).to(args.device) + gt_base_list.append(gt_instances) + + pred_base_list = [] + for i, (config, checkpoint) in enumerate(tqdm(zip(args.config, args.checkpoints))): + inferencer = DetInferencer( + config, checkpoint, device=args.device, palette=args.palette + ) + + result_raw = inferencer( + inputs=inputs, + batch_size=args.batch_size, + no_save_vis=True, + pred_score_thr=args.pred_score_thr, + ) + + if i == 0: + results = [ + {"bboxes_list": [], "scores_list": [], "labels_list": []} + for _ in range(len(result_raw["predictions"])) + ] + + for res, raw in zip(results, result_raw["predictions"]): + res["bboxes_list"].append(raw["bboxes"]) + res["scores_list"].append(raw["scores"]) + res["labels_list"].append(raw["labels"]) + + # visualizer = VISUALIZERS.build(cfg_visualizer) + # visualizer.dataset_meta = dataset_meta + + for i in range(len(results)): + # bboxes, scores, labels = weighted_boxes_fusion( + # results[i]["bboxes_list"], + # results[i]["scores_list"], + # results[i]["labels_list"], + # weights=args.weights, + # iou_thr=args.fusion_iou_thr, + # skip_box_thr=args.skip_box_thr, + # conf_type=args.conf_type, + # ) + bboxes, scores, labels = weighted_boxes_fusion( + results[i]["bboxes_list"], + results[i]["scores_list"], + results[i]["labels_list"], + # weights=args.weights, + # iou_thr=args.fusion_iou_thr, + # skip_box_thr=args.skip_box_thr, + # conf_type=args.conf_type, + ) + + pred_instances = InstanceData() + pred_instances.bboxes = bboxes + pred_instances.scores = scores + pred_instances.labels = labels + + pred_instances_dict = {"boxes": [], "scores": [], "labels": []} + pred_instances_dict["boxes"] = torch.tensor(bboxes).to(args.device) + pred_instances_dict["scores"] = torch.tensor(scores).to(args.device) + pred_instances_dict["labels"] = torch.tensor(labels).to(args.device) + + pred_base_list.append(pred_instances_dict) + + base_metric = MeanAveragePrecision(iou_type="bbox", class_metrics=True) + base_metric50 = MeanAveragePrecision( + iou_type="bbox", class_metrics=True, iou_thresholds=[0.5] + ) + for pred, gt in zip(pred_base_list, gt_base_list): + 
base_metric.update([pred], [gt]) + base_metric50.update([pred], [gt]) + + base_metric_score = base_metric.compute() + base_metric50_score = base_metric50.compute() + + score_dict = {} + base_score_names = ["mAP", "mAP50", "mAP75", "mAR_1", "mAR_10", "mAR_100"] + base_socre_indexs = ["map", "map_50", "map_75", "mar_1", "mar_10", "mar_100"] + labels = [ + "General trash", + "Paper", + "Paper pack", + "Metal", + "Glass", + "Plastic", + "Styrofoam", + "Plastic bag", + "Battery", + "Clothing", + ] + for score_name, score_index in zip(base_score_names, base_socre_indexs): + score_dict[f"val_{score_name}"] = base_metric_score[score_index] + for index, label in enumerate(labels): + score_dict[f"val_{label}_mAP50"] = base_metric50_score["map_per_class"][index] + score_dict[f"val_{label}_mAR100"] = base_metric_score["mar_100_per_class"][ + index + ] + + now = datetime.now() + out_path = osp.join(args.out_dir, "result", now.strftime("%Y_%m_%d_%H_%M_%S")) + os.makedirs(out_path, exist_ok=True) + with open(os.path.join(out_path, "log.txt"), "w") as f: + f.write("Config & Weight List\n") + for config_name, weight_name in zip(args.config, args.checkpoints): + slice_config_name = config_name.split("/")[-1] + slice_weight_name = weight_name.split("/")[-1] + f.write(f"{slice_config_name} {slice_weight_name}\n") + f.write("Scores") + for key, value in score_dict.items(): + f.write(f"{key} : {value}\n") + print(score_dict) + + +if __name__ == "__main__": + main() diff --git a/mmdetection/tools/multi_model_val_another.py b/mmdetection/tools/multi_model_val_another.py new file mode 100644 index 0000000..d9c3e2c --- /dev/null +++ b/mmdetection/tools/multi_model_val_another.py @@ -0,0 +1,280 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Support for multi-model fusion, and currently only the Weighted Box Fusion +(WBF) fusion method is supported. 
+ +References: https://github.com/ZFTurbo/Weighted-Boxes-Fusion + +Example: + + python demo/demo_multi_model.py demo/demo.jpg \ + ./configs/faster_rcnn/faster-rcnn_r50-caffe_fpn_1x_coco.py \ + ./configs/retinanet/retinanet_r50-caffe_fpn_1x_coco.py \ + --checkpoints \ + https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_caffe_fpn_1x_coco/faster_rcnn_r50_caffe_fpn_1x_coco_bbox_mAP-0.378_20200504_180032-c5925ee5.pth \ # noqa + https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_caffe_fpn_1x_coco/retinanet_r50_caffe_fpn_1x_coco_20200531-f11027c5.pth \ + --weights 1 2 +""" + +import argparse +import os.path as osp +import os +from datetime import datetime + +import mmcv +from mmengine.structures import InstanceData + +from mmdet.apis import DetInferencer + +# from mmdet.models.utils import weighted_boxes_fusion + +import torch +from pycocotools.coco import COCO +from torchmetrics.detection import MeanAveragePrecision +from tqdm import tqdm + +from ensemble_boxes import * + + +IMG_EXTENSIONS = ( + ".jpg", + ".jpeg", + ".png", + ".ppm", + ".bmp", + ".pgm", + ".tif", + ".tiff", + ".webp", +) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="MMDetection multi-model inference demo" + ) + parser.add_argument("inputs", type=str, help="Input image file or folder path.") + parser.add_argument( + "config", + type=str, + nargs="*", + help="Config file(s), support receive multiple files", + ) + parser.add_argument( + "--checkpoints", + type=str, + nargs="*", + help="Checkpoint file(s), support receive multiple files, " + "remember to correspond to the above config", + ) + parser.add_argument( + "--weights", + type=float, + nargs="*", + default=None, + help="weights for each model, remember to " "correspond to the above config", + ) + parser.add_argument( + "--fusion-iou-thr", + type=float, + default=0.55, + help="IoU value for boxes to be a match in wbf", + ) + parser.add_argument( + "--skip-box-thr", + type=float, + default=0.0, + help="exclude boxes with score lower than this variable in wbf", + ) + parser.add_argument( + "--conf-type", + type=str, + default="avg", # avg, max, box_and_model_avg, absent_model_aware_avg + help="how to calculate confidence in weighted boxes in wbf", + ) + parser.add_argument( + "--out-dir", + type=str, + default="outputs", + help="Output directory of images or prediction results.", + ) + parser.add_argument("--device", default="cuda:0", help="Device used for inference") + parser.add_argument( + "--pred-score-thr", type=float, default=0.3, help="bbox score threshold" + ) + parser.add_argument( + "--batch-size", type=int, default=1, help="Inference batch size." + ) + parser.add_argument( + "--show", action="store_true", help="Display the image in a popup window." 
+ ) + parser.add_argument( + "--no-save-vis", action="store_true", help="Do not save detection vis results" + ) + parser.add_argument( + "--no-save-pred", action="store_true", help="Do not save detection json results" + ) + parser.add_argument( + "--palette", + default="none", + choices=["coco", "voc", "citys", "random", "none"], + help="Color palette used for visualization", + ) + + args = parser.parse_args() + + if args.no_save_vis and args.no_save_pred: + args.out_dir = "" + + return args + + +def main(): + args = parse_args() + + results = [] + + inputs = [] + filename_list = [] + + coco = COCO(args.inputs) + + gt_base_list = [] + for imgs in coco.imgs.values(): + gt_instances = {"boxes": [], "labels": []} + file_name = imgs["file_name"] + image_path = "/".join(args.inputs.split("/")[:-1]) + "/" + file_name + img = mmcv.imread(image_path) + inputs.append(img) + + filename_list.append(file_name) + image_info = coco.loadImgs( + coco.getImgIds(imgIds=[int(file_name.split("/")[-1].split(".")[0])]) + )[0] + annotation_ids = coco.getAnnIds(imgIds=image_info["id"]) + annotations = coco.loadAnns(annotation_ids) + for annotation in annotations: + bbox = annotation["bbox"] + class_id = annotation["category_id"] + gt_instances["boxes"].append(bbox) + gt_instances["labels"].append(class_id) + + gt_instances["boxes"] = torch.tensor(gt_instances["boxes"]).to(args.device) + gt_instances["labels"] = torch.tensor(gt_instances["labels"]).to(args.device) + gt_base_list.append(gt_instances) + + pred_base_list = [] + for i, (config, checkpoint) in enumerate(tqdm(zip(args.config, args.checkpoints))): + inferencer = DetInferencer( + config, checkpoint, device=args.device, palette=args.palette + ) + + result_raw = inferencer( + inputs=inputs, + batch_size=args.batch_size, + no_save_vis=True, + pred_score_thr=args.pred_score_thr, + ) + + if i == 0: + results = [ + {"bboxes_list": [], "scores_list": [], "labels_list": []} + for _ in range(len(result_raw["predictions"])) + ] + + for res, raw in zip(results, result_raw["predictions"]): + for idx in range(len(raw["bboxes"])): + raw["bboxes"][idx] = list( + map(lambda x: float(x / 1024), raw["bboxes"][idx]) + ) + res["bboxes_list"].append(raw["bboxes"]) + res["scores_list"].append(raw["scores"]) + res["labels_list"].append(raw["labels"]) + + # visualizer = VISUALIZERS.build(cfg_visualizer) + # visualizer.dataset_meta = dataset_meta + + for i in range(len(results)): + # bboxes, scores, labels = weighted_boxes_fusion( + # results[i]["bboxes_list"], + # results[i]["scores_list"], + # results[i]["labels_list"], + # weights=args.weights, + # iou_thr=args.fusion_iou_thr, + # skip_box_thr=args.skip_box_thr, + # conf_type=args.conf_type, + # ) + bboxes, scores, labels = weighted_boxes_fusion( + results[i]["bboxes_list"], + results[i]["scores_list"], + results[i]["labels_list"], + # weights=args.weights, + # iou_thr=args.fusion_iou_thr, + # skip_box_thr=args.skip_box_thr, + # conf_type=args.conf_type, + ) + + pred_instances = InstanceData() + pred_instances.bboxes = bboxes + pred_instances.scores = scores + pred_instances.labels = labels + + pred_instances_dict = {"boxes": [], "scores": [], "labels": []} + pred_instances_dict["boxes"] = torch.tensor(bboxes).to(args.device) + pred_instances_dict["scores"] = torch.tensor(scores).to(args.device) + pred_instances_dict["labels"] = torch.tensor(labels).to(args.device) + + pred_base_list.append(pred_instances_dict) + + base_metric = MeanAveragePrecision(iou_type="bbox", class_metrics=True) + base_metric50 = MeanAveragePrecision( + 
iou_type="bbox", class_metrics=True, iou_thresholds=[0.5] + ) + for pred, gt in zip(pred_base_list, gt_base_list): + pred["boxes"] = pred["boxes"] * 1024 + pred["labels"] = pred["labels"].type(torch.int) + base_metric.update([pred], [gt]) + base_metric50.update([pred], [gt]) + + base_metric_score = base_metric.compute() + base_metric50_score = base_metric50.compute() + + score_dict = {} + base_score_names = ["mAP", "mAP50", "mAP75", "mAR_1", "mAR_10", "mAR_100"] + base_socre_indexs = ["map", "map_50", "map_75", "mar_1", "mar_10", "mar_100"] + labels = [ + "General trash", + "Paper", + "Paper pack", + "Metal", + "Glass", + "Plastic", + "Styrofoam", + "Plastic bag", + "Battery", + "Clothing", + ] + for score_name, score_index in zip(base_score_names, base_socre_indexs): + score_dict[f"val_{score_name}"] = base_metric_score[score_index] + for index, label in enumerate(labels): + score_dict[f"val_{label}_mAP50"] = base_metric50_score["map_per_class"][index] + score_dict[f"val_{label}_mAR100"] = base_metric_score["mar_100_per_class"][ + index + ] + + now = datetime.now() + out_path = osp.join(args.out_dir, "result", now.strftime("%Y_%m_%d_%H_%M_%S")) + os.makedirs(out_path, exist_ok=True) + with open(os.path.join(out_path, "log.txt"), "w") as f: + f.write("Config & Weight List\n") + for config_name, weight_name in zip(args.config, args.checkpoints): + slice_config_name = config_name.split("/")[-1] + slice_weight_name = weight_name.split("/")[-1] + f.write(f"{slice_config_name} {slice_weight_name}\n") + f.write("Scores") + for key, value in score_dict.items(): + f.write(f"{key} : {value}\n") + print(score_dict) + + +if __name__ == "__main__": + main() diff --git a/mmdetection/tools/test_tracking.py b/mmdetection/tools/test_tracking.py deleted file mode 100644 index 8b928c0..0000000 --- a/mmdetection/tools/test_tracking.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import argparse -import os -import os.path as osp - -from mmengine.config import Config, DictAction -from mmengine.model import is_model_wrapper -from mmengine.registry import RUNNERS -from mmengine.runner import Runner -from mmengine.runner.checkpoint import load_checkpoint - -from mmdet.utils import register_all_modules - - -# TODO: support fuse_conv_bn, visualization, and format_only -def parse_args(): - parser = argparse.ArgumentParser( - description='MMTrack test (and eval) a model') - parser.add_argument('config', help='test config file path') - parser.add_argument('--checkpoint', help='checkpoint file') - parser.add_argument('--detector', help='detection checkpoint file') - parser.add_argument('--reid', help='reid checkpoint file') - parser.add_argument( - '--work-dir', - help='the directory to save the file containing evaluation metrics') - parser.add_argument( - '--cfg-options', - nargs='+', - action=DictAction, - help='override some settings in the used config, the key-value pair ' - 'in xxx=yyy format will be merged into config file. If the value to ' - 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' - 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' - 'Note that the quotation marks are necessary and that no white space ' - 'is allowed.') - parser.add_argument( - '--launcher', - choices=['none', 'pytorch', 'slurm', 'mpi'], - default='none', - help='job launcher') - parser.add_argument('--local-rank', type=int, default=0) - args = parser.parse_args() - if 'LOCAL_RANK' not in os.environ: - os.environ['LOCAL_RANK'] = str(args.local_rank) - return args - - -def main(): - args = parse_args() - - # register all modules in mmtrack into the registries - # do not init the default scope here because it will be init in the runner - register_all_modules(init_default_scope=False) - - # load config - cfg = Config.fromfile(args.config) - cfg.launcher = args.launcher - if args.cfg_options is not None: - cfg.merge_from_dict(args.cfg_options) - - # work_dir is determined in this priority: CLI > segment in file > filename - if args.work_dir is not None: - # update configs according to CLI args if args.work_dir is not None - cfg.work_dir = args.work_dir - elif cfg.get('work_dir', None) is None: - # use config filename as default work_dir if cfg.work_dir is None - cfg.work_dir = osp.join('./work_dirs', - osp.splitext(osp.basename(args.config))[0]) - - cfg.load_from = args.checkpoint - - # build the runner from config - if 'runner_type' not in cfg: - # build the default runner - runner = Runner.from_cfg(cfg) - else: - # build customized runner from the registry - # if 'runner_type' is set in the cfg - runner = RUNNERS.build(cfg) - - if is_model_wrapper(runner.model): - model = runner.model.module - else: - model = runner.model - - if args.detector: - assert not (args.checkpoint and args.detector), \ - 'Error: checkpoint and detector checkpoint cannot both exist' - load_checkpoint(model.detector, args.detector) - - if args.reid: - assert not (args.checkpoint and args.reid), \ - 'Error: checkpoint and reid checkpoint cannot both exist' - load_checkpoint(model.reid, args.reid) - - # start testing - runner.test() - - -if __name__ == '__main__': - main() diff --git a/operating_configs/base_config.py b/operating_configs/base_config.py index 52b565f..d40c3e3 100644 --- a/operating_configs/base_config.py +++ b/operating_configs/base_config.py @@ -18,7 +18,10 @@ ) custom_hooks = [ - dict(type="SubmissionHook"), + # test_out_dir : 저장할 폴더 이름 + # mode : val일 경우 train, test일 경우 test + # out_file : 저장할 csv파일 이름 + dict(type="SubmissionHook", test_out_dir="submit", mode="test", out_file="submission"), dict(type="MetricHook"), ]