diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 04579a376a..66ddec5b5f 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -62,43 +62,43 @@ def build_post_process(config, global_config=None): - support_dict = [ - "DBPostProcess", - "EASTPostProcess", - "SASTPostProcess", - "FCEPostProcess", - "CTCLabelDecode", - "AttnLabelDecode", - "ClsPostProcess", - "SRNLabelDecode", - "PGPostProcess", - "DistillationCTCLabelDecode", - "TableLabelDecode", - "DistillationDBPostProcess", - "NRTRLabelDecode", - "SARLabelDecode", - "SEEDLabelDecode", - "VQASerTokenLayoutLMPostProcess", - "VQAReTokenLayoutLMPostProcess", - "PRENLabelDecode", - "DistillationSARLabelDecode", - "ViTSTRLabelDecode", - "ABINetLabelDecode", - "TableMasterLabelDecode", - "SPINLabelDecode", - "DistillationSerPostProcess", - "DistillationRePostProcess", - "VLLabelDecode", - "PicoDetPostProcess", - "CTPostProcess", - "RFLLabelDecode", - "DRRGPostprocess", - "CANLabelDecode", - "SATRNLabelDecode", - "ParseQLabelDecode", - "CPPDLabelDecode", - "LaTeXOCRDecode", - ] + support_dict = { + "DBPostProcess": DBPostProcess, + "EASTPostProcess": EASTPostProcess, + "SASTPostProcess": SASTPostProcess, + "FCEPostProcess": FCEPostProcess, + "CTCLabelDecode": CTCLabelDecode, + "AttnLabelDecode": AttnLabelDecode, + "ClsPostProcess": ClsPostProcess, + "SRNLabelDecode": SRNLabelDecode, + "PGPostProcess": PGPostProcess, + "DistillationCTCLabelDecode": DistillationCTCLabelDecode, + "TableLabelDecode": TableLabelDecode, + "DistillationDBPostProcess": DistillationDBPostProcess, + "NRTRLabelDecode": NRTRLabelDecode, + "SARLabelDecode": SARLabelDecode, + "SEEDLabelDecode": SEEDLabelDecode, + "VQASerTokenLayoutLMPostProcess": VQASerTokenLayoutLMPostProcess, + "VQAReTokenLayoutLMPostProcess": VQAReTokenLayoutLMPostProcess, + "PRENLabelDecode": PRENLabelDecode, + "DistillationSARLabelDecode": DistillationSARLabelDecode, + "ViTSTRLabelDecode": ViTSTRLabelDecode, + "ABINetLabelDecode": ABINetLabelDecode, + "TableMasterLabelDecode": TableMasterLabelDecode, + "SPINLabelDecode": SPINLabelDecode, + "DistillationSerPostProcess": DistillationSerPostProcess, + "DistillationRePostProcess": DistillationRePostProcess, + "VLLabelDecode": VLLabelDecode, + "PicoDetPostProcess": PicoDetPostProcess, + "CTPostProcess": CTPostProcess, + "RFLLabelDecode": RFLLabelDecode, + "DRRGPostprocess": DRRGPostprocess, + "CANLabelDecode": CANLabelDecode, + "SATRNLabelDecode": SATRNLabelDecode, + "ParseQLabelDecode": ParseQLabelDecode, + "CPPDLabelDecode": CPPDLabelDecode, + "LaTeXOCRDecode": LaTeXOCRDecode, + } if config["name"] == "PSEPostProcess": from .pse_postprocess import PSEPostProcess @@ -112,7 +112,7 @@ def build_post_process(config, global_config=None): if global_config is not None: config.update(global_config) assert module_name in support_dict, Exception( - "post process only support {}".format(support_dict) + "post process only support {}".format(list(support_dict.keys())) ) - module_class = eval(module_name)(**config) + module_class = support_dict[module_name](**config) return module_class