Skip to content

Commit

Permalink
feat: add slanet plus table rec
Browse files Browse the repository at this point in the history
  • Loading branch information
Joker1212 committed Oct 13, 2024
1 parent eb3d095 commit 2cb9cad
Show file tree
Hide file tree
Showing 13 changed files with 1,098 additions and 4 deletions.
20 changes: 16 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@ RapidTable库是专门用来文档类图像的表格结构还原,结合RapidOC

目前支持两种类别的表格识别模型:中文和英文表格识别模型,具体可参见下面表格:

| 模型类型 | 模型名称 | 模型大小 |
| :------: | :------------------------------------: | :------: |
| 英文 | `en_ppstructure_mobile_v2_SLANet.onnx` | 7.3M |
| 中文 | `ch_ppstructure_mobile_v2_SLANet.onnx` | 7.4M |
slanet_plus是paddlex内置的SLANet升级版模型,准确率有大幅提升,但paddle2onnx暂时不支持转换

| 模型类型 | 模型名称 | 模型大小 |
|:--------------:|:--------------------------------------:| :------: |
| 英文 | `en_ppstructure_mobile_v2_SLANet.onnx` | 7.3M |
| 中文 | `ch_ppstructure_mobile_v2_SLANet.onnx` | 7.4M |
| slanet_plus 中文 | `inference.pdmodel` | 7.4M |


模型来源:[PaddleOCR 表格识别](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/table/README_ch.md)
[PaddleX-SlaNetPlus 表格识别](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/ocr_modules/table_structure_recognition.md)

模型下载地址为:[百度网盘](https://pan.baidu.com/s/1PI9fksW6F6kQfJhwUkewWg?pwd=p29g) | [Google Drive](https://drive.google.com/drive/folders/1DAPWSN2zGQ-ED_Pz7RaJGTjfkN2-Mvsf?usp=sharing) |

Expand All @@ -47,6 +52,8 @@ RapidTable是整理自PP-Structure中表格识别部分而来。由于PP-Structu
```bash
pip install rapidocr_onnxruntime
pip install rapid_table
# 安装会引入paddlepaddle cpu 3.0.0b0
#pip install slanet_plus_table
```

### 使用方式
Expand All @@ -57,6 +64,7 @@ RapidTable类提供model_path参数,可以自行指定上述2个模型,默

```python
table_engine = RapidTable(model_path='ch_ppstructure_mobile_v2_SLANet.onnx')
#table_engine = SLANetPlus()
```

完整示例:
Expand All @@ -68,6 +76,7 @@ from rapid_table import RapidTable
from rapid_table import RapidTable, VisTable

table_engine = RapidTable()
#table_engine = SLANetPlus()
ocr_engine = RapidOCR()
viser = VisTable()

Expand Down Expand Up @@ -156,4 +165,7 @@ print(table_html_str)
- 去掉返回表格的html字符串中的`<thead></thead><tbody></tbody>`元素,便于后续统一。
- 采用Black工具优化代码

#### 2024.10.13 update
- 补充最新paddlex-SLANet-plus 模型(paddle2onnx原因暂不能支持onnx)

</details>
2 changes: 2 additions & 0 deletions slanet_plus_table/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .main import SLANetPlus
from .utils import VisTable
103 changes: 103 additions & 0 deletions slanet_plus_table/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import copy
import importlib
import time
from pathlib import Path
from typing import Optional, Union, List, Tuple

import cv2
import numpy as np

from slanet_plus_table.table_matcher import TableMatch
from slanet_plus_table.table_structure import TableStructurer
from slanet_plus_table.utils import LoadImage, VisTable

root_dir = Path(__file__).resolve().parent


class SLANetPlus:
def __init__(self, model_path: Optional[str] = None):
if model_path is None:
model_path = str(
root_dir / "models"
)

self.load_img = LoadImage()
self.table_structure = TableStructurer(model_path)
self.table_matcher = TableMatch()

try:
self.ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
except ModuleNotFoundError:
self.ocr_engine = None

def __call__(
self,
img_content: Union[str, np.ndarray, bytes, Path],
ocr_result: List[Union[List[List[float]], str, str]] = None,
) -> Tuple[str, float]:
if self.ocr_engine is None and ocr_result is None:
raise ValueError(
"One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
)

img = self.load_img(img_content)

s = time.time()
h, w = img.shape[:2]

if ocr_result is None:
ocr_result, _ = self.ocr_engine(img)
dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)

pred_structures, pred_bboxes, _ = self.table_structure(copy.deepcopy(img))
pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res)

elapse = time.time() - s
return pred_html, pred_bboxes, elapse

def get_boxes_recs(
self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
) -> Tuple[np.ndarray, Tuple[str, str]]:
dt_boxes, rec_res, scores = list(zip(*ocr_result))
rec_res = list(zip(rec_res, scores))

r_boxes = []
for box in dt_boxes:
box = np.array(box)
x_min = max(0, box[:, 0].min() - 1)
x_max = min(w, box[:, 0].max() + 1)
y_min = max(0, box[:, 1].min() - 1)
y_max = min(h, box[:, 1].max() + 1)
box = [x_min, y_min, x_max, y_max]
r_boxes.append(box)
dt_boxes = np.array(r_boxes)
return dt_boxes, rec_res


if __name__ == '__main__':
slanet_table = SLANetPlus()
img_path = "D:\pythonProjects\TableStructureRec\outputs\\benchmark\\border_left_7267_OEJGHZF525Q011X2ZC34.jpg"
img = cv2.imread(img_path)
try:
ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime."
) from exc
ocr_result, _ = ocr_engine(img)
table_html_str, table_cell_bboxes, elapse = slanet_table(img, ocr_result)

viser = VisTable()

img_path = Path(img_path)

save_dir = "outputs"
save_html_path = f"{save_dir}/{Path(img_path).stem}.html"
save_drawed_path = f"{save_dir}/vis_{Path(img_path).name}"
viser(
img_path,
table_html_str,
save_html_path,
table_cell_bboxes,
save_drawed_path,
)
106 changes: 106 additions & 0 deletions slanet_plus_table/models/inference.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
Hpi:
backend_config:
paddle_infer:
cpu_num_threads: 8
enable_log_info: false
selected_backends:
cpu: paddle_infer
gpu: paddle_infer
supported_backends:
cpu:
- paddle_infer
gpu:
- paddle_infer
Global:
model_name: SLANet_plus
PreProcess:
transform_ops:
- DecodeImage:
channel_first: false
img_mode: BGR
- TableLabelEncode:
learn_empty_box: false
loc_reg_num: 8
max_text_length: 500
merge_no_span_structure: true
replace_empty_cell_token: false
- TableBoxEncode:
in_box_format: xyxyxyxy
out_box_format: xyxyxyxy
- ResizeTableImage:
max_len: 488
- NormalizeImage:
mean:
- 0.485
- 0.456
- 0.406
order: hwc
scale: 1./255.
std:
- 0.229
- 0.224
- 0.225
- PaddingTableImage:
size:
- 488
- 488
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- structure
- bboxes
- bbox_masks
- shape
PostProcess:
name: TableLabelDecode
merge_no_span_structure: true
character_dict:
- <thead>
- </thead>
- <tbody>
- </tbody>
- <tr>
- </tr>
- <td>
- <td
- '>'
- </td>
- ' colspan="2"'
- ' colspan="3"'
- ' colspan="4"'
- ' colspan="5"'
- ' colspan="6"'
- ' colspan="7"'
- ' colspan="8"'
- ' colspan="9"'
- ' colspan="10"'
- ' colspan="11"'
- ' colspan="12"'
- ' colspan="13"'
- ' colspan="14"'
- ' colspan="15"'
- ' colspan="16"'
- ' colspan="17"'
- ' colspan="18"'
- ' colspan="19"'
- ' colspan="20"'
- ' rowspan="2"'
- ' rowspan="3"'
- ' rowspan="4"'
- ' rowspan="5"'
- ' rowspan="6"'
- ' rowspan="7"'
- ' rowspan="8"'
- ' rowspan="9"'
- ' rowspan="10"'
- ' rowspan="11"'
- ' rowspan="12"'
- ' rowspan="13"'
- ' rowspan="14"'
- ' rowspan="15"'
- ' rowspan="16"'
- ' rowspan="17"'
- ' rowspan="18"'
- ' rowspan="19"'
- ' rowspan="20"'
6 changes: 6 additions & 0 deletions slanet_plus_table/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
--extra-index-url=https://www.paddlepaddle.org.cn/packages/stable/cpu/
opencv_python>=4.5.1.48
numpy>=1.21.6,<2
paddlepaddle==3.0.0b0
Pillow
requests
49 changes: 49 additions & 0 deletions slanet_plus_table/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
import sys

import setuptools

from setuptools.command.install import install

MODULE_NAME = "slanet_plus_table"



setuptools.setup(
name=MODULE_NAME,
version="0.0.2",
platforms="Any",
long_description="simplify paddlex slanet plus table use",
long_description_content_type="text/markdown",
description="Tools for parsing table structures based paddlepaddle.",
author="jockerK",
author_email="[email protected]",
url="https://github.com/RapidAI/RapidTable",
license="Apache-2.0",
include_package_data=True,
install_requires=[
"paddlepaddle==3.0.0b0",
"PyYAML>=6.0",
"opencv_python>=4.5.1.48",
"numpy>=1.21.6",
"Pillow",
],
packages=[
MODULE_NAME,
f"{MODULE_NAME}.models",
f"{MODULE_NAME}.table_matcher",
f"{MODULE_NAME}.table_structure",
],
package_data={"": ["inference.pdiparams","inference.pdmodel"]},
keywords=["ppstructure,table,rapidocr,rapid_table"],
classifiers=[
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
python_requires=">=3.8",
)
2 changes: 2 additions & 0 deletions slanet_plus_table/table_matcher/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# -*- encoding: utf-8 -*-
from .matcher import TableMatch
Loading

0 comments on commit 2cb9cad

Please sign in to comment.