From 2cb9cade6c70748f4a842c3ca34f7c5574244534 Mon Sep 17 00:00:00 2001
From: Jokcer <519548295@qq.com>
Date: Sun, 13 Oct 2024 23:25:54 +0800
Subject: [PATCH] feat: add slanet plus table rec
---
README.md | 20 +-
slanet_plus_table/__init__.py | 2 +
slanet_plus_table/main.py | 103 ++++++
slanet_plus_table/models/inference.yml | 106 ++++++
slanet_plus_table/requirements.txt | 6 +
slanet_plus_table/setup.py | 49 +++
slanet_plus_table/table_matcher/__init__.py | 2 +
slanet_plus_table/table_matcher/matcher.py | 125 +++++++
slanet_plus_table/table_matcher/utils.py | 36 ++
slanet_plus_table/table_structure/__init__.py | 2 +
.../table_structure/table_structure.py | 178 ++++++++++
slanet_plus_table/table_structure/utils.py | 331 ++++++++++++++++++
slanet_plus_table/utils.py | 142 ++++++++
13 files changed, 1098 insertions(+), 4 deletions(-)
create mode 100644 slanet_plus_table/__init__.py
create mode 100644 slanet_plus_table/main.py
create mode 100644 slanet_plus_table/models/inference.yml
create mode 100644 slanet_plus_table/requirements.txt
create mode 100644 slanet_plus_table/setup.py
create mode 100644 slanet_plus_table/table_matcher/__init__.py
create mode 100644 slanet_plus_table/table_matcher/matcher.py
create mode 100644 slanet_plus_table/table_matcher/utils.py
create mode 100644 slanet_plus_table/table_structure/__init__.py
create mode 100644 slanet_plus_table/table_structure/table_structure.py
create mode 100644 slanet_plus_table/table_structure/utils.py
create mode 100644 slanet_plus_table/utils.py
diff --git a/README.md b/README.md
index bde4f09..0a627b9 100644
--- a/README.md
+++ b/README.md
@@ -19,12 +19,17 @@ RapidTable库是专门用来文档类图像的表格结构还原,结合RapidOC
目前支持两种类别的表格识别模型:中文和英文表格识别模型,具体可参见下面表格:
- | 模型类型 | 模型名称 | 模型大小 |
- | :------: | :------------------------------------: | :------: |
- | 英文 | `en_ppstructure_mobile_v2_SLANet.onnx` | 7.3M |
- | 中文 | `ch_ppstructure_mobile_v2_SLANet.onnx` | 7.4M |
+slanet_plus是paddlex内置的SLANet升级版模型,准确率有大幅提升,但paddle2onnx暂时不支持转换
+
+ | 模型类型 | 模型名称 | 模型大小 |
+ |:--------------:|:--------------------------------------:| :------: |
+ | 英文 | `en_ppstructure_mobile_v2_SLANet.onnx` | 7.3M |
+ | 中文 | `ch_ppstructure_mobile_v2_SLANet.onnx` | 7.4M |
+ | slanet_plus 中文 | `inference.pdmodel` | 7.4M |
+
模型来源:[PaddleOCR 表格识别](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/table/README_ch.md)
+[PaddleX-SlaNetPlus 表格识别](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/ocr_modules/table_structure_recognition.md)
模型下载地址为:[百度网盘](https://pan.baidu.com/s/1PI9fksW6F6kQfJhwUkewWg?pwd=p29g) | [Google Drive](https://drive.google.com/drive/folders/1DAPWSN2zGQ-ED_Pz7RaJGTjfkN2-Mvsf?usp=sharing) |
@@ -47,6 +52,8 @@ RapidTable是整理自PP-Structure中表格识别部分而来。由于PP-Structu
```bash
pip install rapidocr_onnxruntime
pip install rapid_table
+# 安装会引入paddlepaddle cpu 3.0.0b0
+#pip install slanet_plus_table
```
### 使用方式
@@ -57,6 +64,7 @@ RapidTable类提供model_path参数,可以自行指定上述2个模型,默
```python
table_engine = RapidTable(model_path='ch_ppstructure_mobile_v2_SLANet.onnx')
+#table_engine = SLANetPlus()
```
完整示例:
@@ -68,6 +76,7 @@ from rapid_table import RapidTable
from rapid_table import RapidTable, VisTable
table_engine = RapidTable()
+#table_engine = SLANetPlus()
ocr_engine = RapidOCR()
viser = VisTable()
@@ -156,4 +165,7 @@ print(table_html_str)
- 去掉返回表格的html字符串中的`
`元素,便于后续统一。
- 采用Black工具优化代码
+#### 2024.10.13 update
+- 补充最新paddlex-SLANet-plus 模型(paddle2onnx原因暂不能支持onnx)
+
diff --git a/slanet_plus_table/__init__.py b/slanet_plus_table/__init__.py
new file mode 100644
index 0000000..418f3c1
--- /dev/null
+++ b/slanet_plus_table/__init__.py
@@ -0,0 +1,2 @@
+from .main import SLANetPlus
+from .utils import VisTable
diff --git a/slanet_plus_table/main.py b/slanet_plus_table/main.py
new file mode 100644
index 0000000..14eeb62
--- /dev/null
+++ b/slanet_plus_table/main.py
@@ -0,0 +1,103 @@
+import copy
+import importlib
+import time
+from pathlib import Path
+from typing import Optional, Union, List, Tuple
+
+import cv2
+import numpy as np
+
+from slanet_plus_table.table_matcher import TableMatch
+from slanet_plus_table.table_structure import TableStructurer
+from slanet_plus_table.utils import LoadImage, VisTable
+
+root_dir = Path(__file__).resolve().parent
+
+
+class SLANetPlus:
+ def __init__(self, model_path: Optional[str] = None):
+ if model_path is None:
+ model_path = str(
+ root_dir / "models"
+ )
+
+ self.load_img = LoadImage()
+ self.table_structure = TableStructurer(model_path)
+ self.table_matcher = TableMatch()
+
+ try:
+ self.ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
+ except ModuleNotFoundError:
+ self.ocr_engine = None
+
+ def __call__(
+ self,
+ img_content: Union[str, np.ndarray, bytes, Path],
+ ocr_result: List[Union[List[List[float]], str, str]] = None,
+ ) -> Tuple[str, float]:
+ if self.ocr_engine is None and ocr_result is None:
+ raise ValueError(
+ "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
+ )
+
+ img = self.load_img(img_content)
+
+ s = time.time()
+ h, w = img.shape[:2]
+
+ if ocr_result is None:
+ ocr_result, _ = self.ocr_engine(img)
+ dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
+
+ pred_structures, pred_bboxes, _ = self.table_structure(copy.deepcopy(img))
+ pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res)
+
+ elapse = time.time() - s
+ return pred_html, pred_bboxes, elapse
+
+ def get_boxes_recs(
+ self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
+ ) -> Tuple[np.ndarray, Tuple[str, str]]:
+ dt_boxes, rec_res, scores = list(zip(*ocr_result))
+ rec_res = list(zip(rec_res, scores))
+
+ r_boxes = []
+ for box in dt_boxes:
+ box = np.array(box)
+ x_min = max(0, box[:, 0].min() - 1)
+ x_max = min(w, box[:, 0].max() + 1)
+ y_min = max(0, box[:, 1].min() - 1)
+ y_max = min(h, box[:, 1].max() + 1)
+ box = [x_min, y_min, x_max, y_max]
+ r_boxes.append(box)
+ dt_boxes = np.array(r_boxes)
+ return dt_boxes, rec_res
+
+
+if __name__ == '__main__':
+ slanet_table = SLANetPlus()
+ img_path = "D:\pythonProjects\TableStructureRec\outputs\\benchmark\\border_left_7267_OEJGHZF525Q011X2ZC34.jpg"
+ img = cv2.imread(img_path)
+ try:
+ ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
+ except ModuleNotFoundError as exc:
+ raise ModuleNotFoundError(
+ "Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime."
+ ) from exc
+ ocr_result, _ = ocr_engine(img)
+ table_html_str, table_cell_bboxes, elapse = slanet_table(img, ocr_result)
+
+ viser = VisTable()
+
+ img_path = Path(img_path)
+
+ save_dir = "outputs"
+ save_html_path = f"{save_dir}/{Path(img_path).stem}.html"
+ save_drawed_path = f"{save_dir}/vis_{Path(img_path).name}"
+ viser(
+ img_path,
+ table_html_str,
+ save_html_path,
+ table_cell_bboxes,
+ save_drawed_path,
+ )
diff --git a/slanet_plus_table/models/inference.yml b/slanet_plus_table/models/inference.yml
new file mode 100644
index 0000000..dd6db0d
--- /dev/null
+++ b/slanet_plus_table/models/inference.yml
@@ -0,0 +1,106 @@
+Hpi:
+ backend_config:
+ paddle_infer:
+ cpu_num_threads: 8
+ enable_log_info: false
+ selected_backends:
+ cpu: paddle_infer
+ gpu: paddle_infer
+ supported_backends:
+ cpu:
+ - paddle_infer
+ gpu:
+ - paddle_infer
+Global:
+ model_name: SLANet_plus
+PreProcess:
+ transform_ops:
+ - DecodeImage:
+ channel_first: false
+ img_mode: BGR
+ - TableLabelEncode:
+ learn_empty_box: false
+ loc_reg_num: 8
+ max_text_length: 500
+ merge_no_span_structure: true
+ replace_empty_cell_token: false
+ - TableBoxEncode:
+ in_box_format: xyxyxyxy
+ out_box_format: xyxyxyxy
+ - ResizeTableImage:
+ max_len: 488
+ - NormalizeImage:
+ mean:
+ - 0.485
+ - 0.456
+ - 0.406
+ order: hwc
+ scale: 1./255.
+ std:
+ - 0.229
+ - 0.224
+ - 0.225
+ - PaddingTableImage:
+ size:
+ - 488
+ - 488
+ - ToCHWImage: null
+ - KeepKeys:
+ keep_keys:
+ - image
+ - structure
+ - bboxes
+ - bbox_masks
+ - shape
+PostProcess:
+ name: TableLabelDecode
+ merge_no_span_structure: true
+ character_dict:
+ -
+ -
+ -
+ -
+ -
+ -
+ -
+ - | '
+ - |
+ - ' colspan="2"'
+ - ' colspan="3"'
+ - ' colspan="4"'
+ - ' colspan="5"'
+ - ' colspan="6"'
+ - ' colspan="7"'
+ - ' colspan="8"'
+ - ' colspan="9"'
+ - ' colspan="10"'
+ - ' colspan="11"'
+ - ' colspan="12"'
+ - ' colspan="13"'
+ - ' colspan="14"'
+ - ' colspan="15"'
+ - ' colspan="16"'
+ - ' colspan="17"'
+ - ' colspan="18"'
+ - ' colspan="19"'
+ - ' colspan="20"'
+ - ' rowspan="2"'
+ - ' rowspan="3"'
+ - ' rowspan="4"'
+ - ' rowspan="5"'
+ - ' rowspan="6"'
+ - ' rowspan="7"'
+ - ' rowspan="8"'
+ - ' rowspan="9"'
+ - ' rowspan="10"'
+ - ' rowspan="11"'
+ - ' rowspan="12"'
+ - ' rowspan="13"'
+ - ' rowspan="14"'
+ - ' rowspan="15"'
+ - ' rowspan="16"'
+ - ' rowspan="17"'
+ - ' rowspan="18"'
+ - ' rowspan="19"'
+ - ' rowspan="20"'
diff --git a/slanet_plus_table/requirements.txt b/slanet_plus_table/requirements.txt
new file mode 100644
index 0000000..cb794ff
--- /dev/null
+++ b/slanet_plus_table/requirements.txt
@@ -0,0 +1,6 @@
+--extra-index-url=https://www.paddlepaddle.org.cn/packages/stable/cpu/
+opencv_python>=4.5.1.48
+numpy>=1.21.6,<2
+paddlepaddle==3.0.0b0
+Pillow
+requests
diff --git a/slanet_plus_table/setup.py b/slanet_plus_table/setup.py
new file mode 100644
index 0000000..8f20759
--- /dev/null
+++ b/slanet_plus_table/setup.py
@@ -0,0 +1,49 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import sys
+
+import setuptools
+
+from setuptools.command.install import install
+
+MODULE_NAME = "slanet_plus_table"
+
+
+
+setuptools.setup(
+ name=MODULE_NAME,
+ version="0.0.2",
+ platforms="Any",
+ long_description="simplify paddlex slanet plus table use",
+ long_description_content_type="text/markdown",
+ description="Tools for parsing table structures based paddlepaddle.",
+ author="jockerK",
+ author_email="xinyijianggo@gmail.com",
+ url="https://github.com/RapidAI/RapidTable",
+ license="Apache-2.0",
+ include_package_data=True,
+ install_requires=[
+ "paddlepaddle==3.0.0b0",
+ "PyYAML>=6.0",
+ "opencv_python>=4.5.1.48",
+ "numpy>=1.21.6",
+ "Pillow",
+ ],
+ packages=[
+ MODULE_NAME,
+ f"{MODULE_NAME}.models",
+ f"{MODULE_NAME}.table_matcher",
+ f"{MODULE_NAME}.table_structure",
+ ],
+ package_data={"": ["inference.pdiparams","inference.pdmodel"]},
+ keywords=["ppstructure,table,rapidocr,rapid_table"],
+ classifiers=[
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ ],
+ python_requires=">=3.8",
+)
diff --git a/slanet_plus_table/table_matcher/__init__.py b/slanet_plus_table/table_matcher/__init__.py
new file mode 100644
index 0000000..3cebc9e
--- /dev/null
+++ b/slanet_plus_table/table_matcher/__init__.py
@@ -0,0 +1,2 @@
+# -*- encoding: utf-8 -*-
+from .matcher import TableMatch
diff --git a/slanet_plus_table/table_matcher/matcher.py b/slanet_plus_table/table_matcher/matcher.py
new file mode 100644
index 0000000..b930c70
--- /dev/null
+++ b/slanet_plus_table/table_matcher/matcher.py
@@ -0,0 +1,125 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# -*- encoding: utf-8 -*-
+import numpy as np
+
+from .utils import compute_iou, distance
+
+
+class TableMatch:
+ def __init__(self, filter_ocr_result=True, use_master=False):
+ self.filter_ocr_result = filter_ocr_result
+ self.use_master = use_master
+
+ def __call__(self, pred_structures, pred_bboxes, dt_boxes, rec_res):
+ if self.filter_ocr_result:
+ dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes, rec_res)
+ matched_index = self.match_result(dt_boxes, pred_bboxes)
+ pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
+ return pred_html
+
+ def match_result(self, dt_boxes, pred_bboxes):
+ matched = {}
+ for i, gt_box in enumerate(dt_boxes):
+ distances = []
+ for j, pred_box in enumerate(pred_bboxes):
+ if len(pred_box) == 8:
+ pred_box = [
+ np.min(pred_box[0::2]),
+ np.min(pred_box[1::2]),
+ np.max(pred_box[0::2]),
+ np.max(pred_box[1::2]),
+ ]
+ distances.append(
+ (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
+ ) # compute iou and l1 distance
+ sorted_distances = distances.copy()
+ # select det box by iou and l1 distance
+ sorted_distances = sorted(
+ sorted_distances, key=lambda item: (item[1], item[0])
+ )
+ if distances.index(sorted_distances[0]) not in matched.keys():
+ matched[distances.index(sorted_distances[0])] = [i]
+ else:
+ matched[distances.index(sorted_distances[0])].append(i)
+ return matched
+
+ def get_pred_html(self, pred_structures, matched_index, ocr_contents):
+ end_html = []
+ td_index = 0
+ for tag in pred_structures:
+ if "" not in tag:
+ end_html.append(tag)
+ continue
+
+ if " | " == tag:
+ end_html.extend("")
+
+ if td_index in matched_index.keys():
+ b_with = False
+ if (
+ "" in ocr_contents[matched_index[td_index][0]]
+ and len(matched_index[td_index]) > 1
+ ):
+ b_with = True
+ end_html.extend("")
+
+ for i, td_index_index in enumerate(matched_index[td_index]):
+ content = ocr_contents[td_index_index][0]
+ if len(matched_index[td_index]) > 1:
+ if len(content) == 0:
+ continue
+
+ if content[0] == " ":
+ content = content[1:]
+
+ if "" in content:
+ content = content[3:]
+
+ if "" in content:
+ content = content[:-4]
+
+ if len(content) == 0:
+ continue
+
+ if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
+ content += " "
+ end_html.extend(content)
+
+ if b_with:
+ end_html.extend("")
+
+ if " | | " == tag:
+ end_html.append("")
+ else:
+ end_html.append(tag)
+
+ td_index += 1
+
+ # Filter elements
+ filter_elements = ["", "", "", ""]
+ end_html = [v for v in end_html if v not in filter_elements]
+ return "".join(end_html), end_html
+
+ def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
+ y1 = pred_bboxes[:, 1::2].min()
+ new_dt_boxes = []
+ new_rec_res = []
+
+ for box, rec in zip(dt_boxes, rec_res):
+ if np.max(box[1::2]) < y1:
+ continue
+ new_dt_boxes.append(box)
+ new_rec_res.append(rec)
+ return new_dt_boxes, new_rec_res
diff --git a/slanet_plus_table/table_matcher/utils.py b/slanet_plus_table/table_matcher/utils.py
new file mode 100644
index 0000000..de55fe1
--- /dev/null
+++ b/slanet_plus_table/table_matcher/utils.py
@@ -0,0 +1,36 @@
+def distance(box_1, box_2):
+ x1, y1, x2, y2 = box_1
+ x3, y3, x4, y4 = box_2
+ dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
+ dis_2 = abs(x3 - x1) + abs(y3 - y1)
+ dis_3 = abs(x4 - x2) + abs(y4 - y2)
+ return dis + min(dis_2, dis_3)
+
+
+def compute_iou(rec1, rec2):
+ """
+ computing IoU
+ :param rec1: (y0, x0, y1, x1), which reflects
+ (top, left, bottom, right)
+ :param rec2: (y0, x0, y1, x1)
+ :return: scala value of IoU
+ """
+ # computing area of each rectangles
+ S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
+ S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
+
+ # computing the sum_area
+ sum_area = S_rec1 + S_rec2
+
+ # find the each edge of intersect rectangle
+ left_line = max(rec1[1], rec2[1])
+ right_line = min(rec1[3], rec2[3])
+ top_line = max(rec1[0], rec2[0])
+ bottom_line = min(rec1[2], rec2[2])
+
+ # judge if there is an intersect
+ if left_line >= right_line or top_line >= bottom_line:
+ return 0.0
+ else:
+ intersect = (right_line - left_line) * (bottom_line - top_line)
+ return (intersect / (sum_area - intersect)) * 1.0
diff --git a/slanet_plus_table/table_structure/__init__.py b/slanet_plus_table/table_structure/__init__.py
new file mode 100644
index 0000000..9da248f
--- /dev/null
+++ b/slanet_plus_table/table_structure/__init__.py
@@ -0,0 +1,2 @@
+# -*- encoding: utf-8 -*-
+from .table_structure import TableStructurer
diff --git a/slanet_plus_table/table_structure/table_structure.py b/slanet_plus_table/table_structure/table_structure.py
new file mode 100644
index 0000000..cd7af51
--- /dev/null
+++ b/slanet_plus_table/table_structure/table_structure.py
@@ -0,0 +1,178 @@
+import time
+import numpy as np
+from .utils import TablePredictor, TablePreprocess, TableLabelDecode
+
+
+# class SLANetPlus:
+# def __init__(self, model_dir, model_prefix="inference"):
+# self.preprocess_op = TablePreprocess()
+#
+# self.mean=[0.485, 0.456, 0.406]
+# self.std=[0.229, 0.224, 0.225]
+# self.target_img_size = [488, 488]
+# self.scale=1 / 255
+# self.order="hwc"
+# self.img_loader = LoadImage()
+# self.target_size = 488
+# self.pad_color = 0
+# self.predictor = TablePredictor(model_dir, model_prefix)
+# dict_character=['sos', '', '', '', '', '', '
', '', ' | ', ' colspan="2"', ' colspan="3"', ' colspan="4"', ' colspan="5"', ' colspan="6"', ' colspan="7"', ' colspan="8"', ' colspan="9"', ' colspan="10"', ' colspan="11"', ' colspan="12"', ' colspan="13"', ' colspan="14"', ' colspan="15"', ' colspan="16"', ' colspan="17"', ' colspan="18"', ' colspan="19"', ' colspan="20"', ' rowspan="2"', ' rowspan="3"', ' rowspan="4"', ' rowspan="5"', ' rowspan="6"', ' rowspan="7"', ' rowspan="8"', ' rowspan="9"', ' rowspan="10"', ' rowspan="11"', ' rowspan="12"', ' rowspan="13"', ' rowspan="14"', ' rowspan="15"', ' rowspan="16"', ' rowspan="17"', ' rowspan="18"', ' rowspan="19"', ' rowspan="20"', ' | ', 'eos']
+# self.beg_str = "sos"
+# self.end_str = "eos"
+# self.dict = {}
+# self.table_matcher = TableMatch()
+# for i, char in enumerate(dict_character):
+# self.dict[char] = i
+# self.character = dict_character
+# self.td_token = ["", " | | "]
+#
+# def call(self, img):
+# starttime = time.time()
+# data = {"image": img}
+# data = self.preprocess_op(data)
+# img = data[0]
+# if img is None:
+# return None, 0
+# img = np.expand_dims(img, axis=0)
+# img = img.copy()
+# def __call__(self, img, ocr_result):
+# img = self.img_loader(img)
+# h, w = img.shape[:2]
+# n_img, h_resize, w_resize = self.resize(img)
+# n_img = self.normalize(n_img)
+# n_img = self.pad(n_img)
+# n_img = n_img.transpose((2, 0, 1))
+# n_img = np.expand_dims(n_img, axis=0)
+# start = time.time()
+# batch_output = self.predictor(n_img)
+# elapse_time = time.time() - start
+# ori_img_size = [[w, h]]
+# output = self.decode(batch_output, ori_img_size)[0]
+# corners = np.stack(output['bbox'], axis=0)
+# dt_boxes, rec_res = get_boxes_recs(ocr_result, h, w)
+# pred_html = self.table_matcher(output['structure'], convert_corners_to_bounding_boxes(corners), dt_boxes, rec_res)
+# return pred_html,output['bbox'], elapse_time
+# def resize(self, img):
+# h, w = img.shape[:2]
+# scale = self.target_size / max(h, w)
+# h_resize = round(h * scale)
+# w_resize = round(w * scale)
+# resized_img = cv2.resize(img, (w_resize, h_resize), interpolation=cv2.INTER_LINEAR)
+# return resized_img, h_resize, w_resize
+# def pad(self, img):
+# h, w = img.shape[:2]
+# tw, th = self.target_img_size
+# ph = th - h
+# pw = tw - w
+# pad = (0, ph, 0, pw)
+# chns = 1 if img.ndim == 2 else img.shape[2]
+# im = cv2.copyMakeBorder(img, *pad, cv2.BORDER_CONSTANT, value=(self.pad_color,) * chns)
+# return im
+# def normalize(self, img):
+# img = img.astype("float32", copy=False)
+# img *= self.scale
+# img -= self.mean
+# img /= self.std
+# return img
+#
+#
+# def decode(self, pred, ori_img_size):
+# bbox_preds, structure_probs = [], []
+# for bbox_pred, stru_prob in pred:
+# bbox_preds.append(bbox_pred)
+# structure_probs.append(stru_prob)
+# bbox_preds = np.array(bbox_preds)
+# structure_probs = np.array(structure_probs)
+#
+# bbox_list, structure_str_list, structure_score = self.decode_single(
+# structure_probs, bbox_preds, [self.target_img_size], ori_img_size
+# )
+# structure_str_list = [
+# (
+# ["", "", ""]
+# + structure
+# + ["
", "", ""]
+# )
+# for structure in structure_str_list
+# ]
+# return [
+# {"bbox": bbox, "structure": structure, "structure_score": structure_score}
+# for bbox, structure in zip(bbox_list, structure_str_list)
+# ]
+#
+#
+# def decode_single(self, structure_probs, bbox_preds, padding_size, ori_img_size):
+# """convert text-label into text-index."""
+# ignored_tokens = [self.beg_str, self.end_str]
+# end_idx = self.dict[self.end_str]
+#
+# structure_idx = structure_probs.argmax(axis=2)
+# structure_probs = structure_probs.max(axis=2)
+#
+# structure_batch_list = []
+# bbox_batch_list = []
+# batch_size = len(structure_idx)
+# for batch_idx in range(batch_size):
+# structure_list = []
+# bbox_list = []
+# score_list = []
+# for idx in range(len(structure_idx[batch_idx])):
+# char_idx = int(structure_idx[batch_idx][idx])
+# if idx > 0 and char_idx == end_idx:
+# break
+# if char_idx in ignored_tokens:
+# continue
+# text = self.character[char_idx]
+# if text in self.td_token:
+# bbox = bbox_preds[batch_idx, idx]
+# bbox = self._bbox_decode(
+# bbox, padding_size[batch_idx], ori_img_size[batch_idx]
+# )
+# bbox_list.append(bbox.astype(int))
+# structure_list.append(text)
+# score_list.append(structure_probs[batch_idx, idx])
+# structure_batch_list.append(structure_list)
+# structure_score = np.mean(score_list)
+# bbox_batch_list.append(bbox_list)
+#
+# return bbox_batch_list, structure_batch_list, structure_score
+#
+# def _bbox_decode(self, bbox, padding_shape, ori_shape):
+#
+# pad_w, pad_h = padding_shape
+# w, h = ori_shape
+# ratio_w = pad_w / w
+# ratio_h = pad_h / h
+# ratio = min(ratio_w, ratio_h)
+#
+# bbox[0::2] *= pad_w
+# bbox[1::2] *= pad_h
+# bbox[0::2] /= ratio
+# bbox[1::2] /= ratio
+#
+# return bbox
+
+
+class TableStructurer:
+ def __init__(self, model_path: str):
+ self.preprocess_op = TablePreprocess()
+ self.predictor = TablePredictor(model_path)
+ self.character = ['', '', '', '', '', '
', '', ' | ', ' colspan="2"', ' colspan="3"', ' colspan="4"', ' colspan="5"', ' colspan="6"', ' colspan="7"', ' colspan="8"', ' colspan="9"', ' colspan="10"', ' colspan="11"', ' colspan="12"', ' colspan="13"', ' colspan="14"', ' colspan="15"', ' colspan="16"', ' colspan="17"', ' colspan="18"', ' colspan="19"', ' colspan="20"', ' rowspan="2"', ' rowspan="3"', ' rowspan="4"', ' rowspan="5"', ' rowspan="6"', ' rowspan="7"', ' rowspan="8"', ' rowspan="9"', ' rowspan="10"', ' rowspan="11"', ' rowspan="12"', ' rowspan="13"', ' rowspan="14"', ' rowspan="15"', ' rowspan="16"', ' rowspan="17"', ' rowspan="18"', ' rowspan="19"', ' rowspan="20"', ' | ']
+ self.postprocess_op = TableLabelDecode(self.character)
+
+ def __call__(self, img):
+ start_time = time.time()
+ data = {"image": img}
+ h, w = img.shape[:2]
+ ori_img_size = [[w, h]]
+ data = self.preprocess_op(data)
+ img = data[0]
+ if img is None:
+ return None, 0
+ img = np.expand_dims(img, axis=0)
+ img = img.copy()
+ cur_img_size = [[488, 488]]
+ outputs = self.predictor(img)
+ output = self.postprocess_op(outputs, cur_img_size, ori_img_size)[0]
+ elapse = time.time() - start_time
+ return output["structure"], np.stack(output["bbox"]), elapse
diff --git a/slanet_plus_table/table_structure/utils.py b/slanet_plus_table/table_structure/utils.py
new file mode 100644
index 0000000..f1cf7a0
--- /dev/null
+++ b/slanet_plus_table/table_structure/utils.py
@@ -0,0 +1,331 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# -*- encoding: utf-8 -*-
+# @Author: Jocker1212
+# @Contact: xinyijianggo@gmail.com
+from pathlib import Path
+
+import cv2
+import numpy as np
+from paddle.inference import Config, create_predictor
+
+class TablePredictor:
+ def __init__(self, model_dir, model_prefix="inference"):
+ model_file = f"{model_dir}/{model_prefix}.pdmodel"
+ params_file = f"{model_dir}/{model_prefix}.pdiparams"
+ config = Config(model_file, params_file)
+ config.disable_gpu()
+ config.disable_glog_info()
+ config.enable_new_ir(True)
+ config.enable_new_executor(True)
+ config.enable_memory_optim()
+ config.switch_ir_optim(True)
+ # Disable feed, fetch OP, needed by zero_copy_run
+ config.switch_use_feed_fetch_ops(False)
+ predictor = create_predictor(config)
+ self.config = config
+ self.predictor = predictor
+ # Get input and output handlers
+ input_names = predictor.get_input_names()
+ self.input_names = input_names.sort()
+ self.input_handlers = []
+ self.output_handlers = []
+ for input_name in input_names:
+ input_handler = predictor.get_input_handle(input_name)
+ self.input_handlers.append(input_handler)
+ self.output_names = predictor.get_output_names()
+ for output_name in self.output_names:
+ output_handler = predictor.get_output_handle(output_name)
+ self.output_handlers.append(output_handler)
+
+ def __call__(self, batch_imgs):
+ self.input_handlers[0].reshape(batch_imgs.shape)
+ self.input_handlers[0].copy_from_cpu(batch_imgs)
+ self.predictor.run()
+ output = []
+ for out_tensor in self.output_handlers:
+ batch = out_tensor.copy_to_cpu()
+ output.append(batch)
+ return self.format_output(output)
+
+ def format_output(self, pred):
+ return [res for res in zip(*pred)]
+
+class TableLabelDecode:
+ """decode the table model outputs(probs) to character str"""
+ def __init__(self, dict_character=[], merge_no_span_structure=True, **kwargs):
+
+ if merge_no_span_structure:
+ if " | " not in dict_character:
+ dict_character.append(" | ")
+ if "" in dict_character:
+ dict_character.remove(" | ")
+
+ dict_character = self.add_special_char(dict_character)
+ self.dict = {}
+ for i, char in enumerate(dict_character):
+ self.dict[char] = i
+ self.character = dict_character
+ self.td_token = [" | ", " | | "]
+
+ def add_special_char(self, dict_character):
+ """add_special_char"""
+ self.beg_str = "sos"
+ self.end_str = "eos"
+ dict_character = [self.beg_str] + dict_character + [self.end_str]
+ return dict_character
+
+ def get_ignored_tokens(self):
+ """get_ignored_tokens"""
+ beg_idx = self.get_beg_end_flag_idx("beg")
+ end_idx = self.get_beg_end_flag_idx("end")
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end):
+ """get_beg_end_flag_idx"""
+ if beg_or_end == "beg":
+ idx = np.array(self.dict[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict[self.end_str])
+ else:
+ assert False, "unsupported type %s in get_beg_end_flag_idx" % beg_or_end
+ return idx
+
+ def __call__(self, pred, img_size, ori_img_size):
+ """apply"""
+ bbox_preds, structure_probs = [], []
+ for bbox_pred, stru_prob in pred:
+ bbox_preds.append(bbox_pred)
+ structure_probs.append(stru_prob)
+ bbox_preds = np.array(bbox_preds)
+ structure_probs = np.array(structure_probs)
+
+ bbox_list, structure_str_list, structure_score = self.decode(
+ structure_probs, bbox_preds, img_size, ori_img_size
+ )
+ structure_str_list = [
+ (
+ ["", "", "", "", ""]
+ )
+ for structure in structure_str_list
+ ]
+ return [
+ {"bbox": bbox, "structure": structure, "structure_score": structure_score}
+ for bbox, structure in zip(bbox_list, structure_str_list)
+ ]
+
+ def decode(self, structure_probs, bbox_preds, padding_size, ori_img_size):
+ """convert text-label into text-index."""
+ ignored_tokens = self.get_ignored_tokens()
+ end_idx = self.dict[self.end_str]
+
+ structure_idx = structure_probs.argmax(axis=2)
+ structure_probs = structure_probs.max(axis=2)
+
+ structure_batch_list = []
+ bbox_batch_list = []
+ batch_size = len(structure_idx)
+ for batch_idx in range(batch_size):
+ structure_list = []
+ bbox_list = []
+ score_list = []
+ for idx in range(len(structure_idx[batch_idx])):
+ char_idx = int(structure_idx[batch_idx][idx])
+ if idx > 0 and char_idx == end_idx:
+ break
+ if char_idx in ignored_tokens:
+ continue
+ text = self.character[char_idx]
+ if text in self.td_token:
+ bbox = bbox_preds[batch_idx, idx]
+ bbox = self._bbox_decode(
+ bbox, padding_size[batch_idx], ori_img_size[batch_idx]
+ )
+ bbox_list.append(bbox.astype(int))
+ structure_list.append(text)
+ score_list.append(structure_probs[batch_idx, idx])
+ structure_batch_list.append(structure_list)
+ structure_score = np.mean(score_list)
+ bbox_batch_list.append(bbox_list)
+
+ return bbox_batch_list, structure_batch_list, structure_score
+
+ def _bbox_decode(self, bbox, padding_shape, ori_shape):
+
+ pad_w, pad_h = padding_shape
+ w, h = ori_shape
+ ratio_w = pad_w / w
+ ratio_h = pad_h / h
+ ratio = min(ratio_w, ratio_h)
+
+ bbox[0::2] *= pad_w
+ bbox[1::2] *= pad_h
+ bbox[0::2] /= ratio
+ bbox[1::2] /= ratio
+
+ return bbox
+
+
+class TablePreprocess:
+ def __init__(self):
+ self.table_max_len = 488
+ self.build_pre_process_list()
+ self.ops = self.create_operators()
+
+ def __call__(self, data):
+ """transform"""
+ if self.ops is None:
+ self.ops = []
+
+ for op in self.ops:
+ data = op(data)
+ if data is None:
+ return None
+ return data
+
+ def create_operators(
+ self,
+ ):
+ """
+ create operators based on the config
+
+ Args:
+ params(list): a dict list, used to create some operators
+ """
+ assert isinstance(
+ self.pre_process_list, list
+ ), "operator config should be a list"
+ ops = []
+ for operator in self.pre_process_list:
+ assert (
+ isinstance(operator, dict) and len(operator) == 1
+ ), "yaml format error"
+ op_name = list(operator)[0]
+ param = {} if operator[op_name] is None else operator[op_name]
+ op = eval(op_name)(**param)
+ ops.append(op)
+ return ops
+
+ def build_pre_process_list(self):
+ resize_op = {
+ "ResizeTableImage": {
+ "max_len": self.table_max_len,
+ }
+ }
+ pad_op = {
+ "PaddingTableImage": {"size": [self.table_max_len, self.table_max_len]}
+ }
+ normalize_op = {
+ "NormalizeImage": {
+ "std": [0.229, 0.224, 0.225],
+ "mean": [0.485, 0.456, 0.406],
+ "scale": "1./255.",
+ "order": "hwc",
+ }
+ }
+ to_chw_op = {"ToCHWImage": None}
+ keep_keys_op = {"KeepKeys": {"keep_keys": ["image", "shape"]}}
+ self.pre_process_list = [
+ resize_op,
+ normalize_op,
+ pad_op,
+ to_chw_op,
+ keep_keys_op,
+ ]
+
+
+class ResizeTableImage:
+ def __init__(self, max_len, resize_bboxes=False, infer_mode=False):
+ super(ResizeTableImage, self).__init__()
+ self.max_len = max_len
+ self.resize_bboxes = resize_bboxes
+ self.infer_mode = infer_mode
+
+ def __call__(self, data):
+ img = data["image"]
+ height, width = img.shape[0:2]
+ ratio = self.max_len / (max(height, width) * 1.0)
+ resize_h = int(height * ratio)
+ resize_w = int(width * ratio)
+ resize_img = cv2.resize(img, (resize_w, resize_h))
+ if self.resize_bboxes and not self.infer_mode:
+ data["bboxes"] = data["bboxes"] * ratio
+ data["image"] = resize_img
+ data["src_img"] = img
+ data["shape"] = np.array([height, width, ratio, ratio])
+ data["max_len"] = self.max_len
+ return data
+
+
+class PaddingTableImage:
+ def __init__(self, size, **kwargs):
+ super(PaddingTableImage, self).__init__()
+ self.size = size
+
+ def __call__(self, data):
+ img = data["image"]
+ pad_h, pad_w = self.size
+ padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
+ height, width = img.shape[0:2]
+ padding_img[0:height, 0:width, :] = img.copy()
+ data["image"] = padding_img
+ shape = data["shape"].tolist()
+ shape.extend([pad_h, pad_w])
+ data["shape"] = np.array(shape)
+ return data
+
+
+class NormalizeImage:
+ """normalize image such as substract mean, divide std"""
+
+ def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
+ if isinstance(scale, str):
+ scale = eval(scale)
+ self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+ mean = mean if mean is not None else [0.485, 0.456, 0.406]
+ std = std if std is not None else [0.229, 0.224, 0.225]
+
+ shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
+ self.mean = np.array(mean).reshape(shape).astype("float32")
+ self.std = np.array(std).reshape(shape).astype("float32")
+
+ def __call__(self, data):
+ img = np.array(data["image"])
+ assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
+ data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
+ return data
+
+
+class ToCHWImage:
+ """convert hwc image to chw image"""
+
+ def __init__(self, **kwargs):
+ pass
+
+ def __call__(self, data):
+ img = np.array(data["image"])
+ data["image"] = img.transpose((2, 0, 1))
+ return data
+
+class KeepKeys:
+ def __init__(self, keep_keys, **kwargs):
+ self.keep_keys = keep_keys
+
+ def __call__(self, data):
+ data_list = []
+ for key in self.keep_keys:
+ data_list.append(data[key])
+ return data_list
diff --git a/slanet_plus_table/utils.py b/slanet_plus_table/utils.py
new file mode 100644
index 0000000..bb0736c
--- /dev/null
+++ b/slanet_plus_table/utils.py
@@ -0,0 +1,142 @@
+from io import BytesIO
+from pathlib import Path
+from typing import Union, Optional
+
+import cv2
+import numpy as np
+from PIL import Image, UnidentifiedImageError
+
+InputType = Union[str, np.ndarray, bytes, Path]
+
+class LoadImage:
+ def __init__(
+ self,
+ ):
+ pass
+
+ def __call__(self, img: InputType) -> np.ndarray:
+ if not isinstance(img, InputType.__args__):
+ raise LoadImageError(
+ f"The img type {type(img)} does not in {InputType.__args__}"
+ )
+
+ img = self.load_img(img)
+
+ if img.ndim == 2:
+ return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ if img.ndim == 3 and img.shape[2] == 4:
+ return self.cvt_four_to_three(img)
+
+ return img
+
+ def load_img(self, img: InputType) -> np.ndarray:
+ if isinstance(img, (str, Path)):
+ self.verify_exist(img)
+ try:
+ img = np.array(Image.open(img))
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+ except UnidentifiedImageError as e:
+ raise LoadImageError(f"cannot identify image file {img}") from e
+ return img
+
+ if isinstance(img, bytes):
+ img = np.array(Image.open(BytesIO(img)))
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+ return img
+
+ if isinstance(img, np.ndarray):
+ return img
+
+ raise LoadImageError(f"{type(img)} is not supported!")
+
+ @staticmethod
+ def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
+ """RGBA → RGB"""
+ r, g, b, a = cv2.split(img)
+ new_img = cv2.merge((b, g, r))
+
+ not_a = cv2.bitwise_not(a)
+ not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
+
+ new_img = cv2.bitwise_and(new_img, new_img, mask=a)
+ new_img = cv2.add(new_img, not_a)
+ return new_img
+
+ @staticmethod
+ def verify_exist(file_path: Union[str, Path]):
+ if not Path(file_path).exists():
+ raise LoadImageError(f"{file_path} does not exist.")
+
+class LoadImageError(Exception):
+ pass
+
+class VisTable:
+ def __init__(
+ self,
+ ):
+ self.load_img = LoadImage()
+
+ def __call__(
+ self,
+ img_path: Union[str, Path],
+ table_html_str: str,
+ save_html_path: Optional[str] = None,
+ table_cell_bboxes: Optional[np.ndarray] = None,
+ save_drawed_path: Optional[str] = None,
+ ) -> None:
+ if save_html_path:
+ html_with_border = self.insert_border_style(table_html_str)
+ self.save_html(save_html_path, html_with_border)
+
+ if table_cell_bboxes is None:
+ return None
+
+ img = self.load_img(img_path)
+
+ dims_bboxes = table_cell_bboxes.shape[1]
+ if dims_bboxes == 4:
+ drawed_img = self.draw_rectangle(img, table_cell_bboxes)
+ elif dims_bboxes == 8:
+ drawed_img = self.draw_polylines(img, table_cell_bboxes)
+ else:
+ raise ValueError("Shape of table bounding boxes is not between in 4 or 8.")
+
+ if save_drawed_path:
+ self.save_img(save_drawed_path, drawed_img)
+
+ return drawed_img
+
+ def insert_border_style(self, table_html_str: str):
+ style_res = """"""
+ prefix_table, suffix_table = table_html_str.split("")
+ html_with_border = f"{prefix_table}{style_res}{suffix_table}"
+ return html_with_border
+
+ @staticmethod
+ def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
+ img_copy = img.copy()
+ for box in boxes.astype(int):
+ x1, y1, x2, y2 = box
+ cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
+ return img_copy
+
+ @staticmethod
+ def draw_polylines(img: np.ndarray, points) -> np.ndarray:
+ img_copy = img.copy()
+ for point in points.astype(int):
+ point = point.reshape(4, 2)
+ cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2)
+ return img_copy
+
+ @staticmethod
+ def save_img(save_path: Union[str, Path], img: np.ndarray):
+ cv2.imwrite(str(save_path), img)
+
+ @staticmethod
+ def save_html(save_path: Union[str, Path], html: str):
+ with open(save_path, "w", encoding="utf-8") as f:
+ f.write(html)