From e3145103198071c74dc9f6237721cf004e8e8b11 Mon Sep 17 00:00:00 2001 From: zhangyubo0722 <94225063+zhangyubo0722@users.noreply.github.com> Date: Fri, 3 Jan 2025 15:34:29 +0800 Subject: [PATCH] import encryption for aistudio & fix sync bn --- ppocr/utils/export_model.py | 19 +++++++++++++++ ppocr/utils/save_load.py | 48 +++++++++++++++++++++++++++++-------- 2 files changed, 57 insertions(+), 10 deletions(-) diff --git a/ppocr/utils/export_model.py b/ppocr/utils/export_model.py index 91c41dda3e..ccb4456eca 100644 --- a/ppocr/utils/export_model.py +++ b/ppocr/utils/export_model.py @@ -331,6 +331,12 @@ def export_single_model( model = dynamic_to_static(model, arch_config, logger, input_shape) if quanter is None: + try: + import encryption # Attempt to import the encryption module for AIStudio's encryption model + except ( + ModuleNotFoundError + ): # Encryption is not needed if the module cannot be imported + print("Skipping import of the encryption module") if config["Global"].get("export_with_pir", False): paddle_version = version.parse(paddle.__version__) assert ( @@ -349,6 +355,18 @@ def export_single_model( return +def convert_bn(model): + for n, m in model.named_children(): + if isinstance(m, nn.SyncBatchNorm): + bn = nn.BatchNorm2D( + m._num_features, m._momentum, m._epsilon, m._weight_attr, m._bias_attr + ) + bn.set_dict(m.state_dict()) + setattr(model, n, bn) + else: + convert_bn(m) + + def export(config, base_model=None, save_path=None): if paddle.distributed.get_rank() != 0: return @@ -424,6 +442,7 @@ def export(config, base_model=None, save_path=None): else: model = build_model(config["Architecture"]) load_model(config, model, model_type=config["Architecture"]["model_type"]) + convert_bn(model) model.eval() if not save_path: diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 865d451aa9..4d4b7ba03b 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -26,6 +26,14 @@ from ppocr.utils.logging import get_logger from ppocr.utils.network import maybe_download_params +try: + import encryption # Attempt to import the encryption module for AIStudio's encryption model + + encrypted = encryption.is_encryption_needed() +except ImportError: + get_logger().warning("Skipping import of the encryption module.") + encrypted = False # Encryption is not needed if the module cannot be imported + __all__ = ["load_model"] @@ -278,13 +286,11 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num= else: train_results = {} train_results["model_name"] = config["Global"]["pdx_model_name"] - label_dict_path = os.path.abspath( - config["Global"].get("character_dict_path", "") - ) + label_dict_path = config["Global"].get("character_dict_path", "") if label_dict_path != "": + label_dict_path = os.path.abspath(label_dict_path) if not os.path.exists(label_dict_path): label_dict_path = "" - label_dict_path = label_dict_path train_results["label_dict"] = label_dict_path train_results["train_log"] = "train.log" train_results["visualdl_log"] = "" @@ -305,9 +311,20 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num= raise ValueError("No metric score found.") train_results["models"]["best"]["score"] = metric_score for tag in save_model_tag: - train_results["models"]["best"][tag] = os.path.join( - prefix, f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states" - ) + if tag == "pdparams" and encrypted: + train_results["models"]["best"][tag] = os.path.join( + prefix, + ( + f"{prefix}.encrypted.{tag}" + if tag != "pdstates" + else f"{prefix}.states" + ), + ) + else: + train_results["models"]["best"][tag] = os.path.join( + prefix, + f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states", + ) for tag in save_inference_tag: train_results["models"]["best"][tag] = os.path.join( prefix, @@ -329,9 +346,20 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num= metric_score = 0 train_results["models"][f"last_{1}"]["score"] = metric_score for tag in save_model_tag: - train_results["models"][f"last_{1}"][tag] = os.path.join( - prefix, f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states" - ) + if tag == "pdparams" and encrypted: + train_results["models"][f"last_{1}"][tag] = os.path.join( + prefix, + ( + f"{prefix}.encrypted.{tag}" + if tag != "pdstates" + else f"{prefix}.states" + ), + ) + else: + train_results["models"][f"last_{1}"][tag] = os.path.join( + prefix, + f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states", + ) for tag in save_inference_tag: train_results["models"][f"last_{1}"][tag] = os.path.join( prefix,