From 2cb9cade6c70748f4a842c3ca34f7c5574244534 Mon Sep 17 00:00:00 2001 From: Jokcer <519548295@qq.com> Date: Sun, 13 Oct 2024 23:25:54 +0800 Subject: [PATCH] feat: add slanet plus table rec --- README.md | 20 +- slanet_plus_table/__init__.py | 2 + slanet_plus_table/main.py | 103 ++++++ slanet_plus_table/models/inference.yml | 106 ++++++ slanet_plus_table/requirements.txt | 6 + slanet_plus_table/setup.py | 49 +++ slanet_plus_table/table_matcher/__init__.py | 2 + slanet_plus_table/table_matcher/matcher.py | 125 +++++++ slanet_plus_table/table_matcher/utils.py | 36 ++ slanet_plus_table/table_structure/__init__.py | 2 + .../table_structure/table_structure.py | 178 ++++++++++ slanet_plus_table/table_structure/utils.py | 331 ++++++++++++++++++ slanet_plus_table/utils.py | 142 ++++++++ 13 files changed, 1098 insertions(+), 4 deletions(-) create mode 100644 slanet_plus_table/__init__.py create mode 100644 slanet_plus_table/main.py create mode 100644 slanet_plus_table/models/inference.yml create mode 100644 slanet_plus_table/requirements.txt create mode 100644 slanet_plus_table/setup.py create mode 100644 slanet_plus_table/table_matcher/__init__.py create mode 100644 slanet_plus_table/table_matcher/matcher.py create mode 100644 slanet_plus_table/table_matcher/utils.py create mode 100644 slanet_plus_table/table_structure/__init__.py create mode 100644 slanet_plus_table/table_structure/table_structure.py create mode 100644 slanet_plus_table/table_structure/utils.py create mode 100644 slanet_plus_table/utils.py diff --git a/README.md b/README.md index bde4f09..0a627b9 100644 --- a/README.md +++ b/README.md @@ -19,12 +19,17 @@ RapidTable库是专门用来文档类图像的表格结构还原,结合RapidOC 目前支持两种类别的表格识别模型:中文和英文表格识别模型,具体可参见下面表格: - | 模型类型 | 模型名称 | 模型大小 | - | :------: | :------------------------------------: | :------: | - | 英文 | `en_ppstructure_mobile_v2_SLANet.onnx` | 7.3M | - | 中文 | `ch_ppstructure_mobile_v2_SLANet.onnx` | 7.4M | +slanet_plus是paddlex内置的SLANet升级版模型,准确率有大幅提升,但paddle2onnx暂时不支持转换 + + | 模型类型 | 模型名称 | 模型大小 | + |:--------------:|:--------------------------------------:| :------: | + | 英文 | `en_ppstructure_mobile_v2_SLANet.onnx` | 7.3M | + | 中文 | `ch_ppstructure_mobile_v2_SLANet.onnx` | 7.4M | + | slanet_plus 中文 | `inference.pdmodel` | 7.4M | + 模型来源:[PaddleOCR 表格识别](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/table/README_ch.md) +[PaddleX-SlaNetPlus 表格识别](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/ocr_modules/table_structure_recognition.md) 模型下载地址为:[百度网盘](https://pan.baidu.com/s/1PI9fksW6F6kQfJhwUkewWg?pwd=p29g) | [Google Drive](https://drive.google.com/drive/folders/1DAPWSN2zGQ-ED_Pz7RaJGTjfkN2-Mvsf?usp=sharing) | @@ -47,6 +52,8 @@ RapidTable是整理自PP-Structure中表格识别部分而来。由于PP-Structu ```bash pip install rapidocr_onnxruntime pip install rapid_table +# 安装会引入paddlepaddle cpu 3.0.0b0 +#pip install slanet_plus_table ``` ### 使用方式 @@ -57,6 +64,7 @@ RapidTable类提供model_path参数,可以自行指定上述2个模型,默 ```python table_engine = RapidTable(model_path='ch_ppstructure_mobile_v2_SLANet.onnx') +#table_engine = SLANetPlus() ``` 完整示例: @@ -68,6 +76,7 @@ from rapid_table import RapidTable from rapid_table import RapidTable, VisTable table_engine = RapidTable() +#table_engine = SLANetPlus() ocr_engine = RapidOCR() viser = VisTable() @@ -156,4 +165,7 @@ print(table_html_str) - 去掉返回表格的html字符串中的``元素,便于后续统一。 - 采用Black工具优化代码 +#### 2024.10.13 update +- 补充最新paddlex-SLANet-plus 模型(paddle2onnx原因暂不能支持onnx) + diff --git a/slanet_plus_table/__init__.py b/slanet_plus_table/__init__.py new file mode 100644 index 0000000..418f3c1 --- /dev/null +++ b/slanet_plus_table/__init__.py @@ -0,0 +1,2 @@ +from .main import SLANetPlus +from .utils import VisTable diff --git a/slanet_plus_table/main.py b/slanet_plus_table/main.py new file mode 100644 index 0000000..14eeb62 --- /dev/null +++ b/slanet_plus_table/main.py @@ -0,0 +1,103 @@ +import copy +import importlib +import time +from pathlib import Path +from typing import Optional, Union, List, Tuple + +import cv2 +import numpy as np + +from slanet_plus_table.table_matcher import TableMatch +from slanet_plus_table.table_structure import TableStructurer +from slanet_plus_table.utils import LoadImage, VisTable + +root_dir = Path(__file__).resolve().parent + + +class SLANetPlus: + def __init__(self, model_path: Optional[str] = None): + if model_path is None: + model_path = str( + root_dir / "models" + ) + + self.load_img = LoadImage() + self.table_structure = TableStructurer(model_path) + self.table_matcher = TableMatch() + + try: + self.ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR() + except ModuleNotFoundError: + self.ocr_engine = None + + def __call__( + self, + img_content: Union[str, np.ndarray, bytes, Path], + ocr_result: List[Union[List[List[float]], str, str]] = None, + ) -> Tuple[str, float]: + if self.ocr_engine is None and ocr_result is None: + raise ValueError( + "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed." + ) + + img = self.load_img(img_content) + + s = time.time() + h, w = img.shape[:2] + + if ocr_result is None: + ocr_result, _ = self.ocr_engine(img) + dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w) + + pred_structures, pred_bboxes, _ = self.table_structure(copy.deepcopy(img)) + pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res) + + elapse = time.time() - s + return pred_html, pred_bboxes, elapse + + def get_boxes_recs( + self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int + ) -> Tuple[np.ndarray, Tuple[str, str]]: + dt_boxes, rec_res, scores = list(zip(*ocr_result)) + rec_res = list(zip(rec_res, scores)) + + r_boxes = [] + for box in dt_boxes: + box = np.array(box) + x_min = max(0, box[:, 0].min() - 1) + x_max = min(w, box[:, 0].max() + 1) + y_min = max(0, box[:, 1].min() - 1) + y_max = min(h, box[:, 1].max() + 1) + box = [x_min, y_min, x_max, y_max] + r_boxes.append(box) + dt_boxes = np.array(r_boxes) + return dt_boxes, rec_res + + +if __name__ == '__main__': + slanet_table = SLANetPlus() + img_path = "D:\pythonProjects\TableStructureRec\outputs\\benchmark\\border_left_7267_OEJGHZF525Q011X2ZC34.jpg" + img = cv2.imread(img_path) + try: + ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR() + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime." + ) from exc + ocr_result, _ = ocr_engine(img) + table_html_str, table_cell_bboxes, elapse = slanet_table(img, ocr_result) + + viser = VisTable() + + img_path = Path(img_path) + + save_dir = "outputs" + save_html_path = f"{save_dir}/{Path(img_path).stem}.html" + save_drawed_path = f"{save_dir}/vis_{Path(img_path).name}" + viser( + img_path, + table_html_str, + save_html_path, + table_cell_bboxes, + save_drawed_path, + ) diff --git a/slanet_plus_table/models/inference.yml b/slanet_plus_table/models/inference.yml new file mode 100644 index 0000000..dd6db0d --- /dev/null +++ b/slanet_plus_table/models/inference.yml @@ -0,0 +1,106 @@ +Hpi: + backend_config: + paddle_infer: + cpu_num_threads: 8 + enable_log_info: false + selected_backends: + cpu: paddle_infer + gpu: paddle_infer + supported_backends: + cpu: + - paddle_infer + gpu: + - paddle_infer +Global: + model_name: SLANet_plus +PreProcess: + transform_ops: + - DecodeImage: + channel_first: false + img_mode: BGR + - TableLabelEncode: + learn_empty_box: false + loc_reg_num: 8 + max_text_length: 500 + merge_no_span_structure: true + replace_empty_cell_token: false + - TableBoxEncode: + in_box_format: xyxyxyxy + out_box_format: xyxyxyxy + - ResizeTableImage: + max_len: 488 + - NormalizeImage: + mean: + - 0.485 + - 0.456 + - 0.406 + order: hwc + scale: 1./255. + std: + - 0.229 + - 0.224 + - 0.225 + - PaddingTableImage: + size: + - 488 + - 488 + - ToCHWImage: null + - KeepKeys: + keep_keys: + - image + - structure + - bboxes + - bbox_masks + - shape +PostProcess: + name: TableLabelDecode + merge_no_span_structure: true + character_dict: + - + - + - + - + - + - + - + - ' + - + - ' colspan="2"' + - ' colspan="3"' + - ' colspan="4"' + - ' colspan="5"' + - ' colspan="6"' + - ' colspan="7"' + - ' colspan="8"' + - ' colspan="9"' + - ' colspan="10"' + - ' colspan="11"' + - ' colspan="12"' + - ' colspan="13"' + - ' colspan="14"' + - ' colspan="15"' + - ' colspan="16"' + - ' colspan="17"' + - ' colspan="18"' + - ' colspan="19"' + - ' colspan="20"' + - ' rowspan="2"' + - ' rowspan="3"' + - ' rowspan="4"' + - ' rowspan="5"' + - ' rowspan="6"' + - ' rowspan="7"' + - ' rowspan="8"' + - ' rowspan="9"' + - ' rowspan="10"' + - ' rowspan="11"' + - ' rowspan="12"' + - ' rowspan="13"' + - ' rowspan="14"' + - ' rowspan="15"' + - ' rowspan="16"' + - ' rowspan="17"' + - ' rowspan="18"' + - ' rowspan="19"' + - ' rowspan="20"' diff --git a/slanet_plus_table/requirements.txt b/slanet_plus_table/requirements.txt new file mode 100644 index 0000000..cb794ff --- /dev/null +++ b/slanet_plus_table/requirements.txt @@ -0,0 +1,6 @@ +--extra-index-url=https://www.paddlepaddle.org.cn/packages/stable/cpu/ +opencv_python>=4.5.1.48 +numpy>=1.21.6,<2 +paddlepaddle==3.0.0b0 +Pillow +requests diff --git a/slanet_plus_table/setup.py b/slanet_plus_table/setup.py new file mode 100644 index 0000000..8f20759 --- /dev/null +++ b/slanet_plus_table/setup.py @@ -0,0 +1,49 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import sys + +import setuptools + +from setuptools.command.install import install + +MODULE_NAME = "slanet_plus_table" + + + +setuptools.setup( + name=MODULE_NAME, + version="0.0.2", + platforms="Any", + long_description="simplify paddlex slanet plus table use", + long_description_content_type="text/markdown", + description="Tools for parsing table structures based paddlepaddle.", + author="jockerK", + author_email="xinyijianggo@gmail.com", + url="https://github.com/RapidAI/RapidTable", + license="Apache-2.0", + include_package_data=True, + install_requires=[ + "paddlepaddle==3.0.0b0", + "PyYAML>=6.0", + "opencv_python>=4.5.1.48", + "numpy>=1.21.6", + "Pillow", + ], + packages=[ + MODULE_NAME, + f"{MODULE_NAME}.models", + f"{MODULE_NAME}.table_matcher", + f"{MODULE_NAME}.table_structure", + ], + package_data={"": ["inference.pdiparams","inference.pdmodel"]}, + keywords=["ppstructure,table,rapidocr,rapid_table"], + classifiers=[ + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + ], + python_requires=">=3.8", +) diff --git a/slanet_plus_table/table_matcher/__init__.py b/slanet_plus_table/table_matcher/__init__.py new file mode 100644 index 0000000..3cebc9e --- /dev/null +++ b/slanet_plus_table/table_matcher/__init__.py @@ -0,0 +1,2 @@ +# -*- encoding: utf-8 -*- +from .matcher import TableMatch diff --git a/slanet_plus_table/table_matcher/matcher.py b/slanet_plus_table/table_matcher/matcher.py new file mode 100644 index 0000000..b930c70 --- /dev/null +++ b/slanet_plus_table/table_matcher/matcher.py @@ -0,0 +1,125 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -*- encoding: utf-8 -*- +import numpy as np + +from .utils import compute_iou, distance + + +class TableMatch: + def __init__(self, filter_ocr_result=True, use_master=False): + self.filter_ocr_result = filter_ocr_result + self.use_master = use_master + + def __call__(self, pred_structures, pred_bboxes, dt_boxes, rec_res): + if self.filter_ocr_result: + dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes, rec_res) + matched_index = self.match_result(dt_boxes, pred_bboxes) + pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res) + return pred_html + + def match_result(self, dt_boxes, pred_bboxes): + matched = {} + for i, gt_box in enumerate(dt_boxes): + distances = [] + for j, pred_box in enumerate(pred_bboxes): + if len(pred_box) == 8: + pred_box = [ + np.min(pred_box[0::2]), + np.min(pred_box[1::2]), + np.max(pred_box[0::2]), + np.max(pred_box[1::2]), + ] + distances.append( + (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box)) + ) # compute iou and l1 distance + sorted_distances = distances.copy() + # select det box by iou and l1 distance + sorted_distances = sorted( + sorted_distances, key=lambda item: (item[1], item[0]) + ) + if distances.index(sorted_distances[0]) not in matched.keys(): + matched[distances.index(sorted_distances[0])] = [i] + else: + matched[distances.index(sorted_distances[0])].append(i) + return matched + + def get_pred_html(self, pred_structures, matched_index, ocr_contents): + end_html = [] + td_index = 0 + for tag in pred_structures: + if "" not in tag: + end_html.append(tag) + continue + + if "" == tag: + end_html.extend("") + + if td_index in matched_index.keys(): + b_with = False + if ( + "" in ocr_contents[matched_index[td_index][0]] + and len(matched_index[td_index]) > 1 + ): + b_with = True + end_html.extend("") + + for i, td_index_index in enumerate(matched_index[td_index]): + content = ocr_contents[td_index_index][0] + if len(matched_index[td_index]) > 1: + if len(content) == 0: + continue + + if content[0] == " ": + content = content[1:] + + if "" in content: + content = content[3:] + + if "" in content: + content = content[:-4] + + if len(content) == 0: + continue + + if i != len(matched_index[td_index]) - 1 and " " != content[-1]: + content += " " + end_html.extend(content) + + if b_with: + end_html.extend("") + + if "" == tag: + end_html.append("") + else: + end_html.append(tag) + + td_index += 1 + + # Filter elements + filter_elements = ["", "", "", ""] + end_html = [v for v in end_html if v not in filter_elements] + return "".join(end_html), end_html + + def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res): + y1 = pred_bboxes[:, 1::2].min() + new_dt_boxes = [] + new_rec_res = [] + + for box, rec in zip(dt_boxes, rec_res): + if np.max(box[1::2]) < y1: + continue + new_dt_boxes.append(box) + new_rec_res.append(rec) + return new_dt_boxes, new_rec_res diff --git a/slanet_plus_table/table_matcher/utils.py b/slanet_plus_table/table_matcher/utils.py new file mode 100644 index 0000000..de55fe1 --- /dev/null +++ b/slanet_plus_table/table_matcher/utils.py @@ -0,0 +1,36 @@ +def distance(box_1, box_2): + x1, y1, x2, y2 = box_1 + x3, y3, x4, y4 = box_2 + dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2) + dis_2 = abs(x3 - x1) + abs(y3 - y1) + dis_3 = abs(x4 - x2) + abs(y4 - y2) + return dis + min(dis_2, dis_3) + + +def compute_iou(rec1, rec2): + """ + computing IoU + :param rec1: (y0, x0, y1, x1), which reflects + (top, left, bottom, right) + :param rec2: (y0, x0, y1, x1) + :return: scala value of IoU + """ + # computing area of each rectangles + S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]) + S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]) + + # computing the sum_area + sum_area = S_rec1 + S_rec2 + + # find the each edge of intersect rectangle + left_line = max(rec1[1], rec2[1]) + right_line = min(rec1[3], rec2[3]) + top_line = max(rec1[0], rec2[0]) + bottom_line = min(rec1[2], rec2[2]) + + # judge if there is an intersect + if left_line >= right_line or top_line >= bottom_line: + return 0.0 + else: + intersect = (right_line - left_line) * (bottom_line - top_line) + return (intersect / (sum_area - intersect)) * 1.0 diff --git a/slanet_plus_table/table_structure/__init__.py b/slanet_plus_table/table_structure/__init__.py new file mode 100644 index 0000000..9da248f --- /dev/null +++ b/slanet_plus_table/table_structure/__init__.py @@ -0,0 +1,2 @@ +# -*- encoding: utf-8 -*- +from .table_structure import TableStructurer diff --git a/slanet_plus_table/table_structure/table_structure.py b/slanet_plus_table/table_structure/table_structure.py new file mode 100644 index 0000000..cd7af51 --- /dev/null +++ b/slanet_plus_table/table_structure/table_structure.py @@ -0,0 +1,178 @@ +import time +import numpy as np +from .utils import TablePredictor, TablePreprocess, TableLabelDecode + + +# class SLANetPlus: +# def __init__(self, model_dir, model_prefix="inference"): +# self.preprocess_op = TablePreprocess() +# +# self.mean=[0.485, 0.456, 0.406] +# self.std=[0.229, 0.224, 0.225] +# self.target_img_size = [488, 488] +# self.scale=1 / 255 +# self.order="hwc" +# self.img_loader = LoadImage() +# self.target_size = 488 +# self.pad_color = 0 +# self.predictor = TablePredictor(model_dir, model_prefix) +# dict_character=['sos', '', '', '', '', '', '', '', '', ' colspan="2"', ' colspan="3"', ' colspan="4"', ' colspan="5"', ' colspan="6"', ' colspan="7"', ' colspan="8"', ' colspan="9"', ' colspan="10"', ' colspan="11"', ' colspan="12"', ' colspan="13"', ' colspan="14"', ' colspan="15"', ' colspan="16"', ' colspan="17"', ' colspan="18"', ' colspan="19"', ' colspan="20"', ' rowspan="2"', ' rowspan="3"', ' rowspan="4"', ' rowspan="5"', ' rowspan="6"', ' rowspan="7"', ' rowspan="8"', ' rowspan="9"', ' rowspan="10"', ' rowspan="11"', ' rowspan="12"', ' rowspan="13"', ' rowspan="14"', ' rowspan="15"', ' rowspan="16"', ' rowspan="17"', ' rowspan="18"', ' rowspan="19"', ' rowspan="20"', '', 'eos'] +# self.beg_str = "sos" +# self.end_str = "eos" +# self.dict = {} +# self.table_matcher = TableMatch() +# for i, char in enumerate(dict_character): +# self.dict[char] = i +# self.character = dict_character +# self.td_token = ["", ""] +# +# def call(self, img): +# starttime = time.time() +# data = {"image": img} +# data = self.preprocess_op(data) +# img = data[0] +# if img is None: +# return None, 0 +# img = np.expand_dims(img, axis=0) +# img = img.copy() +# def __call__(self, img, ocr_result): +# img = self.img_loader(img) +# h, w = img.shape[:2] +# n_img, h_resize, w_resize = self.resize(img) +# n_img = self.normalize(n_img) +# n_img = self.pad(n_img) +# n_img = n_img.transpose((2, 0, 1)) +# n_img = np.expand_dims(n_img, axis=0) +# start = time.time() +# batch_output = self.predictor(n_img) +# elapse_time = time.time() - start +# ori_img_size = [[w, h]] +# output = self.decode(batch_output, ori_img_size)[0] +# corners = np.stack(output['bbox'], axis=0) +# dt_boxes, rec_res = get_boxes_recs(ocr_result, h, w) +# pred_html = self.table_matcher(output['structure'], convert_corners_to_bounding_boxes(corners), dt_boxes, rec_res) +# return pred_html,output['bbox'], elapse_time +# def resize(self, img): +# h, w = img.shape[:2] +# scale = self.target_size / max(h, w) +# h_resize = round(h * scale) +# w_resize = round(w * scale) +# resized_img = cv2.resize(img, (w_resize, h_resize), interpolation=cv2.INTER_LINEAR) +# return resized_img, h_resize, w_resize +# def pad(self, img): +# h, w = img.shape[:2] +# tw, th = self.target_img_size +# ph = th - h +# pw = tw - w +# pad = (0, ph, 0, pw) +# chns = 1 if img.ndim == 2 else img.shape[2] +# im = cv2.copyMakeBorder(img, *pad, cv2.BORDER_CONSTANT, value=(self.pad_color,) * chns) +# return im +# def normalize(self, img): +# img = img.astype("float32", copy=False) +# img *= self.scale +# img -= self.mean +# img /= self.std +# return img +# +# +# def decode(self, pred, ori_img_size): +# bbox_preds, structure_probs = [], [] +# for bbox_pred, stru_prob in pred: +# bbox_preds.append(bbox_pred) +# structure_probs.append(stru_prob) +# bbox_preds = np.array(bbox_preds) +# structure_probs = np.array(structure_probs) +# +# bbox_list, structure_str_list, structure_score = self.decode_single( +# structure_probs, bbox_preds, [self.target_img_size], ori_img_size +# ) +# structure_str_list = [ +# ( +# ["", "", ""] +# + structure +# + ["
", "", ""] +# ) +# for structure in structure_str_list +# ] +# return [ +# {"bbox": bbox, "structure": structure, "structure_score": structure_score} +# for bbox, structure in zip(bbox_list, structure_str_list) +# ] +# +# +# def decode_single(self, structure_probs, bbox_preds, padding_size, ori_img_size): +# """convert text-label into text-index.""" +# ignored_tokens = [self.beg_str, self.end_str] +# end_idx = self.dict[self.end_str] +# +# structure_idx = structure_probs.argmax(axis=2) +# structure_probs = structure_probs.max(axis=2) +# +# structure_batch_list = [] +# bbox_batch_list = [] +# batch_size = len(structure_idx) +# for batch_idx in range(batch_size): +# structure_list = [] +# bbox_list = [] +# score_list = [] +# for idx in range(len(structure_idx[batch_idx])): +# char_idx = int(structure_idx[batch_idx][idx]) +# if idx > 0 and char_idx == end_idx: +# break +# if char_idx in ignored_tokens: +# continue +# text = self.character[char_idx] +# if text in self.td_token: +# bbox = bbox_preds[batch_idx, idx] +# bbox = self._bbox_decode( +# bbox, padding_size[batch_idx], ori_img_size[batch_idx] +# ) +# bbox_list.append(bbox.astype(int)) +# structure_list.append(text) +# score_list.append(structure_probs[batch_idx, idx]) +# structure_batch_list.append(structure_list) +# structure_score = np.mean(score_list) +# bbox_batch_list.append(bbox_list) +# +# return bbox_batch_list, structure_batch_list, structure_score +# +# def _bbox_decode(self, bbox, padding_shape, ori_shape): +# +# pad_w, pad_h = padding_shape +# w, h = ori_shape +# ratio_w = pad_w / w +# ratio_h = pad_h / h +# ratio = min(ratio_w, ratio_h) +# +# bbox[0::2] *= pad_w +# bbox[1::2] *= pad_h +# bbox[0::2] /= ratio +# bbox[1::2] /= ratio +# +# return bbox + + +class TableStructurer: + def __init__(self, model_path: str): + self.preprocess_op = TablePreprocess() + self.predictor = TablePredictor(model_path) + self.character = ['', '', '', '', '', '', '', '', ' colspan="2"', ' colspan="3"', ' colspan="4"', ' colspan="5"', ' colspan="6"', ' colspan="7"', ' colspan="8"', ' colspan="9"', ' colspan="10"', ' colspan="11"', ' colspan="12"', ' colspan="13"', ' colspan="14"', ' colspan="15"', ' colspan="16"', ' colspan="17"', ' colspan="18"', ' colspan="19"', ' colspan="20"', ' rowspan="2"', ' rowspan="3"', ' rowspan="4"', ' rowspan="5"', ' rowspan="6"', ' rowspan="7"', ' rowspan="8"', ' rowspan="9"', ' rowspan="10"', ' rowspan="11"', ' rowspan="12"', ' rowspan="13"', ' rowspan="14"', ' rowspan="15"', ' rowspan="16"', ' rowspan="17"', ' rowspan="18"', ' rowspan="19"', ' rowspan="20"', ''] + self.postprocess_op = TableLabelDecode(self.character) + + def __call__(self, img): + start_time = time.time() + data = {"image": img} + h, w = img.shape[:2] + ori_img_size = [[w, h]] + data = self.preprocess_op(data) + img = data[0] + if img is None: + return None, 0 + img = np.expand_dims(img, axis=0) + img = img.copy() + cur_img_size = [[488, 488]] + outputs = self.predictor(img) + output = self.postprocess_op(outputs, cur_img_size, ori_img_size)[0] + elapse = time.time() - start_time + return output["structure"], np.stack(output["bbox"]), elapse diff --git a/slanet_plus_table/table_structure/utils.py b/slanet_plus_table/table_structure/utils.py new file mode 100644 index 0000000..f1cf7a0 --- /dev/null +++ b/slanet_plus_table/table_structure/utils.py @@ -0,0 +1,331 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -*- encoding: utf-8 -*- +# @Author: Jocker1212 +# @Contact: xinyijianggo@gmail.com +from pathlib import Path + +import cv2 +import numpy as np +from paddle.inference import Config, create_predictor + +class TablePredictor: + def __init__(self, model_dir, model_prefix="inference"): + model_file = f"{model_dir}/{model_prefix}.pdmodel" + params_file = f"{model_dir}/{model_prefix}.pdiparams" + config = Config(model_file, params_file) + config.disable_gpu() + config.disable_glog_info() + config.enable_new_ir(True) + config.enable_new_executor(True) + config.enable_memory_optim() + config.switch_ir_optim(True) + # Disable feed, fetch OP, needed by zero_copy_run + config.switch_use_feed_fetch_ops(False) + predictor = create_predictor(config) + self.config = config + self.predictor = predictor + # Get input and output handlers + input_names = predictor.get_input_names() + self.input_names = input_names.sort() + self.input_handlers = [] + self.output_handlers = [] + for input_name in input_names: + input_handler = predictor.get_input_handle(input_name) + self.input_handlers.append(input_handler) + self.output_names = predictor.get_output_names() + for output_name in self.output_names: + output_handler = predictor.get_output_handle(output_name) + self.output_handlers.append(output_handler) + + def __call__(self, batch_imgs): + self.input_handlers[0].reshape(batch_imgs.shape) + self.input_handlers[0].copy_from_cpu(batch_imgs) + self.predictor.run() + output = [] + for out_tensor in self.output_handlers: + batch = out_tensor.copy_to_cpu() + output.append(batch) + return self.format_output(output) + + def format_output(self, pred): + return [res for res in zip(*pred)] + +class TableLabelDecode: + """decode the table model outputs(probs) to character str""" + def __init__(self, dict_character=[], merge_no_span_structure=True, **kwargs): + + if merge_no_span_structure: + if "" not in dict_character: + dict_character.append("") + if "" in dict_character: + dict_character.remove("") + + dict_character = self.add_special_char(dict_character) + self.dict = {} + for i, char in enumerate(dict_character): + self.dict[char] = i + self.character = dict_character + self.td_token = ["", ""] + + def add_special_char(self, dict_character): + """add_special_char""" + self.beg_str = "sos" + self.end_str = "eos" + dict_character = [self.beg_str] + dict_character + [self.end_str] + return dict_character + + def get_ignored_tokens(self): + """get_ignored_tokens""" + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + """get_beg_end_flag_idx""" + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "unsupported type %s in get_beg_end_flag_idx" % beg_or_end + return idx + + def __call__(self, pred, img_size, ori_img_size): + """apply""" + bbox_preds, structure_probs = [], [] + for bbox_pred, stru_prob in pred: + bbox_preds.append(bbox_pred) + structure_probs.append(stru_prob) + bbox_preds = np.array(bbox_preds) + structure_probs = np.array(structure_probs) + + bbox_list, structure_str_list, structure_score = self.decode( + structure_probs, bbox_preds, img_size, ori_img_size + ) + structure_str_list = [ + ( + ["", "", ""] + + structure + + ["
", "", ""] + ) + for structure in structure_str_list + ] + return [ + {"bbox": bbox, "structure": structure, "structure_score": structure_score} + for bbox, structure in zip(bbox_list, structure_str_list) + ] + + def decode(self, structure_probs, bbox_preds, padding_size, ori_img_size): + """convert text-label into text-index.""" + ignored_tokens = self.get_ignored_tokens() + end_idx = self.dict[self.end_str] + + structure_idx = structure_probs.argmax(axis=2) + structure_probs = structure_probs.max(axis=2) + + structure_batch_list = [] + bbox_batch_list = [] + batch_size = len(structure_idx) + for batch_idx in range(batch_size): + structure_list = [] + bbox_list = [] + score_list = [] + for idx in range(len(structure_idx[batch_idx])): + char_idx = int(structure_idx[batch_idx][idx]) + if idx > 0 and char_idx == end_idx: + break + if char_idx in ignored_tokens: + continue + text = self.character[char_idx] + if text in self.td_token: + bbox = bbox_preds[batch_idx, idx] + bbox = self._bbox_decode( + bbox, padding_size[batch_idx], ori_img_size[batch_idx] + ) + bbox_list.append(bbox.astype(int)) + structure_list.append(text) + score_list.append(structure_probs[batch_idx, idx]) + structure_batch_list.append(structure_list) + structure_score = np.mean(score_list) + bbox_batch_list.append(bbox_list) + + return bbox_batch_list, structure_batch_list, structure_score + + def _bbox_decode(self, bbox, padding_shape, ori_shape): + + pad_w, pad_h = padding_shape + w, h = ori_shape + ratio_w = pad_w / w + ratio_h = pad_h / h + ratio = min(ratio_w, ratio_h) + + bbox[0::2] *= pad_w + bbox[1::2] *= pad_h + bbox[0::2] /= ratio + bbox[1::2] /= ratio + + return bbox + + +class TablePreprocess: + def __init__(self): + self.table_max_len = 488 + self.build_pre_process_list() + self.ops = self.create_operators() + + def __call__(self, data): + """transform""" + if self.ops is None: + self.ops = [] + + for op in self.ops: + data = op(data) + if data is None: + return None + return data + + def create_operators( + self, + ): + """ + create operators based on the config + + Args: + params(list): a dict list, used to create some operators + """ + assert isinstance( + self.pre_process_list, list + ), "operator config should be a list" + ops = [] + for operator in self.pre_process_list: + assert ( + isinstance(operator, dict) and len(operator) == 1 + ), "yaml format error" + op_name = list(operator)[0] + param = {} if operator[op_name] is None else operator[op_name] + op = eval(op_name)(**param) + ops.append(op) + return ops + + def build_pre_process_list(self): + resize_op = { + "ResizeTableImage": { + "max_len": self.table_max_len, + } + } + pad_op = { + "PaddingTableImage": {"size": [self.table_max_len, self.table_max_len]} + } + normalize_op = { + "NormalizeImage": { + "std": [0.229, 0.224, 0.225], + "mean": [0.485, 0.456, 0.406], + "scale": "1./255.", + "order": "hwc", + } + } + to_chw_op = {"ToCHWImage": None} + keep_keys_op = {"KeepKeys": {"keep_keys": ["image", "shape"]}} + self.pre_process_list = [ + resize_op, + normalize_op, + pad_op, + to_chw_op, + keep_keys_op, + ] + + +class ResizeTableImage: + def __init__(self, max_len, resize_bboxes=False, infer_mode=False): + super(ResizeTableImage, self).__init__() + self.max_len = max_len + self.resize_bboxes = resize_bboxes + self.infer_mode = infer_mode + + def __call__(self, data): + img = data["image"] + height, width = img.shape[0:2] + ratio = self.max_len / (max(height, width) * 1.0) + resize_h = int(height * ratio) + resize_w = int(width * ratio) + resize_img = cv2.resize(img, (resize_w, resize_h)) + if self.resize_bboxes and not self.infer_mode: + data["bboxes"] = data["bboxes"] * ratio + data["image"] = resize_img + data["src_img"] = img + data["shape"] = np.array([height, width, ratio, ratio]) + data["max_len"] = self.max_len + return data + + +class PaddingTableImage: + def __init__(self, size, **kwargs): + super(PaddingTableImage, self).__init__() + self.size = size + + def __call__(self, data): + img = data["image"] + pad_h, pad_w = self.size + padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32) + height, width = img.shape[0:2] + padding_img[0:height, 0:width, :] = img.copy() + data["image"] = padding_img + shape = data["shape"].tolist() + shape.extend([pad_h, pad_w]) + data["shape"] = np.array(shape) + return data + + +class NormalizeImage: + """normalize image such as substract mean, divide std""" + + def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs): + if isinstance(scale, str): + scale = eval(scale) + self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) + mean = mean if mean is not None else [0.485, 0.456, 0.406] + std = std if std is not None else [0.229, 0.224, 0.225] + + shape = (3, 1, 1) if order == "chw" else (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype("float32") + self.std = np.array(std).reshape(shape).astype("float32") + + def __call__(self, data): + img = np.array(data["image"]) + assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage" + data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std + return data + + +class ToCHWImage: + """convert hwc image to chw image""" + + def __init__(self, **kwargs): + pass + + def __call__(self, data): + img = np.array(data["image"]) + data["image"] = img.transpose((2, 0, 1)) + return data + +class KeepKeys: + def __init__(self, keep_keys, **kwargs): + self.keep_keys = keep_keys + + def __call__(self, data): + data_list = [] + for key in self.keep_keys: + data_list.append(data[key]) + return data_list diff --git a/slanet_plus_table/utils.py b/slanet_plus_table/utils.py new file mode 100644 index 0000000..bb0736c --- /dev/null +++ b/slanet_plus_table/utils.py @@ -0,0 +1,142 @@ +from io import BytesIO +from pathlib import Path +from typing import Union, Optional + +import cv2 +import numpy as np +from PIL import Image, UnidentifiedImageError + +InputType = Union[str, np.ndarray, bytes, Path] + +class LoadImage: + def __init__( + self, + ): + pass + + def __call__(self, img: InputType) -> np.ndarray: + if not isinstance(img, InputType.__args__): + raise LoadImageError( + f"The img type {type(img)} does not in {InputType.__args__}" + ) + + img = self.load_img(img) + + if img.ndim == 2: + return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + if img.ndim == 3 and img.shape[2] == 4: + return self.cvt_four_to_three(img) + + return img + + def load_img(self, img: InputType) -> np.ndarray: + if isinstance(img, (str, Path)): + self.verify_exist(img) + try: + img = np.array(Image.open(img)) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + except UnidentifiedImageError as e: + raise LoadImageError(f"cannot identify image file {img}") from e + return img + + if isinstance(img, bytes): + img = np.array(Image.open(BytesIO(img))) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + if isinstance(img, np.ndarray): + return img + + raise LoadImageError(f"{type(img)} is not supported!") + + @staticmethod + def cvt_four_to_three(img: np.ndarray) -> np.ndarray: + """RGBA → RGB""" + r, g, b, a = cv2.split(img) + new_img = cv2.merge((b, g, r)) + + not_a = cv2.bitwise_not(a) + not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) + + new_img = cv2.bitwise_and(new_img, new_img, mask=a) + new_img = cv2.add(new_img, not_a) + return new_img + + @staticmethod + def verify_exist(file_path: Union[str, Path]): + if not Path(file_path).exists(): + raise LoadImageError(f"{file_path} does not exist.") + +class LoadImageError(Exception): + pass + +class VisTable: + def __init__( + self, + ): + self.load_img = LoadImage() + + def __call__( + self, + img_path: Union[str, Path], + table_html_str: str, + save_html_path: Optional[str] = None, + table_cell_bboxes: Optional[np.ndarray] = None, + save_drawed_path: Optional[str] = None, + ) -> None: + if save_html_path: + html_with_border = self.insert_border_style(table_html_str) + self.save_html(save_html_path, html_with_border) + + if table_cell_bboxes is None: + return None + + img = self.load_img(img_path) + + dims_bboxes = table_cell_bboxes.shape[1] + if dims_bboxes == 4: + drawed_img = self.draw_rectangle(img, table_cell_bboxes) + elif dims_bboxes == 8: + drawed_img = self.draw_polylines(img, table_cell_bboxes) + else: + raise ValueError("Shape of table bounding boxes is not between in 4 or 8.") + + if save_drawed_path: + self.save_img(save_drawed_path, drawed_img) + + return drawed_img + + def insert_border_style(self, table_html_str: str): + style_res = """""" + prefix_table, suffix_table = table_html_str.split("") + html_with_border = f"{prefix_table}{style_res}{suffix_table}" + return html_with_border + + @staticmethod + def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray: + img_copy = img.copy() + for box in boxes.astype(int): + x1, y1, x2, y2 = box + cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2) + return img_copy + + @staticmethod + def draw_polylines(img: np.ndarray, points) -> np.ndarray: + img_copy = img.copy() + for point in points.astype(int): + point = point.reshape(4, 2) + cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2) + return img_copy + + @staticmethod + def save_img(save_path: Union[str, Path], img: np.ndarray): + cv2.imwrite(str(save_path), img) + + @staticmethod + def save_html(save_path: Union[str, Path], html: str): + with open(save_path, "w", encoding="utf-8") as f: + f.write(html)