diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index f15ab09e59..a1a258dae0 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -37,7 +37,9 @@ class TextRecognizer(object): - def __init__(self, args): + def __init__(self, args, logger=None): + if logger is None: + logger = get_logger() self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")] self.rec_batch_num = args.rec_batch_num self.rec_algorithm = args.rec_algorithm @@ -157,7 +159,7 @@ def __init__(self, args): model_precision=args.precision, batch_size=args.rec_batch_num, data_shape="dynamic", - save_path=None, # args.save_log_path, + save_path=None, # not used if logger is not None inference_config=self.config, pids=pid, process_name=None, @@ -701,14 +703,25 @@ def __call__(self, img_list): def main(args): image_file_list = get_image_file_list(args.image_dir) - text_recognizer = TextRecognizer(args) valid_image_file_list = [] img_list = [] + # logger + log_file = args.save_log_path + if os.path.is_dir(args.save_log_path) or ( + not os.path.exists(args.save_log_path) and args.save_log_path.endswith("/") + ): + log_file = os.path.join(log_file, "benchmark_recognition.log") + logger = get_logger(log_file=log_file) + + # create text recognizer + text_recognizer = TextRecognizer(args) + logger.info( "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', " "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320" ) + # warmup 2 times if args.warmup: img = np.random.uniform(0, 255, [48, 320, 3]).astype(np.uint8)