support export with pir and no pir (#14379)
zhangyubo0722 authored Dec 19, 2024
1 parent 04c989b commit 0697d24
Showing 2 changed files with 53 additions and 10 deletions.
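
This commit makes model export write both a PIR (new intermediate representation) program and an old-IR program whenever the installed Paddle supports PIR. The gate that decides this lives in export_single_model in the diff below; a minimal standalone sketch of the same check (the FLAGS_enable_pir_api variable and the "3.0.0b2" / "0.0.0" thresholds are taken from the diff, the helper name is illustrative):

import os
import paddle
from packaging import version

def should_export_both_ir_formats():
    # Hypothetical helper mirroring the gate added in export_single_model.
    paddle_version = version.parse(paddle.__version__)
    new_enough = (
        paddle_version >= version.parse("3.0.0b2")
        or paddle_version == version.parse("0.0.0")  # dev/nightly builds report "0.0.0"
    )
    # PIR export is skipped only when the flag is explicitly disabled.
    pir_not_disabled = os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"]
    return new_enough and pir_not_disabled
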
62 changes: 52 additions & 10 deletions ppocr/utils/export_model.py
@@ -16,11 +16,13 @@
import yaml
import json
import copy
import shutil
import paddle
import paddle.nn as nn
from paddle.jit import to_static

from collections import OrderedDict
from packaging import version
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
@@ -39,21 +41,23 @@ def setup_orderdict():
def dump_infer_config(config, path, logger):
setup_orderdict()
infer_cfg = OrderedDict()
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
if config["Global"].get("pdx_model_name", None):
infer_cfg["Global"] = {"model_name": config["Global"]["pdx_model_name"]}
if config["Global"].get("uniform_output_enabled", None):
arch_config = config["Architecture"]
if arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
common_dynamic_shapes = {
"x": [[1, 3, 48, 320], [1, 3, 48, 320], [8, 3, 48, 320]]
"x": [[1, 3, 24, 160], [1, 3, 48, 320], [8, 3, 96, 640]]
}
elif arch_config["model_type"] == "det":
common_dynamic_shapes = {
"x": [[1, 3, 160, 160], [1, 3, 160, 160], [1, 3, 1280, 1280]]
}
elif arch_config["algorithm"] == "SLANet":
common_dynamic_shapes = {
"x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 192, 672]]
"x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 488, 488]]
}
elif arch_config["algorithm"] == "LaTeXOCR":
common_dynamic_shapes = {
@@ -101,9 +105,7 @@ def dump_infer_config(config, path, logger):
logger.info("Export inference config file to {}".format(os.path.join(path)))


-def export_single_model(
-    model, arch_config, save_path, logger, input_shape=None, quanter=None
-):
def dynamic_to_static(model, arch_config, logger, input_shape=None):
if arch_config["algorithm"] == "SRN":
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
@@ -262,9 +264,46 @@ def export_single_model(
for layer in model.sublayers():
if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
layer.rep()
return model


def export_single_model(
model, arch_config, save_path, logger, yaml_path, input_shape=None, quanter=None
):

model = dynamic_to_static(model, arch_config, logger, input_shape)

if quanter is None:
-paddle.jit.save(model, save_path)
paddle_version = version.parse(paddle.__version__)
if (
paddle_version >= version.parse("3.0.0b2")
or paddle_version == version.parse("0.0.0")
) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"]:
save_path = os.path.dirname(save_path)
for enable_pir in [True, False]:
if not enable_pir:
save_path_no_pir = os.path.join(save_path, "inference")
model.forward.rollback()
with paddle.pir_utils.OldIrGuard():
model = dynamic_to_static(
model, arch_config, logger, input_shape
)
paddle.jit.save(model, save_path_no_pir)
else:
save_path_pir = os.path.join(
os.path.dirname(save_path),
f"{os.path.basename(save_path)}_pir",
"inference",
)
paddle.jit.save(model, save_path_pir)
shutil.copy(
yaml_path,
os.path.join(
os.path.dirname(save_path_pir), os.path.basename(yaml_path)
),
)
else:
paddle.jit.save(model, save_path)
else:
quanter.save_quantized_model(model, save_path)
logger.info("inference model is saved to {}".format(save_path))
@@ -362,19 +401,22 @@ def export(config, base_model=None, save_path=None):
input_shape = rec_rs[0]["ABINetRecResizeImg"]["image_shape"] if rec_rs else None
else:
input_shape = None

dump_infer_config(config, yaml_path, logger)
if arch_config["algorithm"] in [
"Distillation",
]: # distillation model
archs = list(arch_config["Models"].values())
for idx, name in enumerate(model.model_name_list):
sub_model_save_path = os.path.join(save_path, name, "inference")
export_single_model(
-model.model_list[idx], archs[idx], sub_model_save_path, logger
model.model_list[idx],
archs[idx],
sub_model_save_path,
logger,
yaml_path,
)
else:
save_path = os.path.join(save_path, "inference")
export_single_model(
-model, arch_config, save_path, logger, input_shape=input_shape
model, arch_config, save_path, logger, yaml_path, input_shape=input_shape
)
-dump_infer_config(config, yaml_path, logger)
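
With the PIR path enabled, export_single_model saves the model twice: the old-IR program keeps the original <save_dir>/inference prefix, the PIR program goes to a sibling <save_dir>_pir/inference prefix, and the inference YAML is copied next to the PIR model (which is presumably why dump_infer_config is now called before export_single_model rather than after it). A small sketch of the path arithmetic, using a hypothetical save_path of "./output/rec/inference" (the directory name is illustrative; the dirname/basename/join logic mirrors the diff above):

import os

save_path = "./output/rec/inference"      # hypothetical value passed to export_single_model
save_dir = os.path.dirname(save_path)     # "./output/rec"

# Old-IR (non-PIR) model keeps the original prefix:
save_path_no_pir = os.path.join(save_dir, "inference")   # "./output/rec/inference"

# PIR model is written under a sibling "<name>_pir" directory:
save_path_pir = os.path.join(
    os.path.dirname(save_dir), f"{os.path.basename(save_dir)}_pir", "inference"
)                                                         # "./output/rec_pir/inference"
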
1 change: 1 addition & 0 deletions requirements.txt
@@ -14,3 +14,4 @@ requests
albumentations==1.4.10
# to be compatible with albumentations
albucore==0.0.13
packaging
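
requirements.txt gains packaging because the new export gate compares Paddle versions with version.parse, which understands PEP 440 pre-release tags such as 3.0.0b2; a plain string comparison would order them incorrectly:

from packaging import version

assert version.parse("3.0.0b2") < version.parse("3.0.0")   # beta sorts before the final release
assert not ("3.0.0b2" < "3.0.0")                           # naive string comparison gets this wrong
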
