From 93ce1f49d48d3c21a34c3b935720c4a5783fc76b Mon Sep 17 00:00:00 2001 From: wawltor Date: Wed, 20 Oct 2021 02:51:26 +0000 Subject: [PATCH] update the code for the download_file and document --- docs/model_zoo/taskflow.md | 2 +- paddlenlp/taskflow/text_correction.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/model_zoo/taskflow.md b/docs/model_zoo/taskflow.md index 6b981c33d7a7..118d6dcc3832 100644 --- a/docs/model_zoo/taskflow.md +++ b/docs/model_zoo/taskflow.md @@ -134,7 +134,7 @@ ddp("百度是一家高科技公司") >>> [{'word': ['百度', '是', '一家', '高科技', '公司'], 'postag': ['ORG', 'v', 'm', 'n', 'n'], 'head': ['2', '0', '5', '5', '2'], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB'], 'prob': [1.0, 1.0, 1.0, 1.0, 1.0]}] # 使用ddparser-ernie-1.0进行预测 -ddp = Taskflow("dependency_parsing",model="ddparser-ernie-1.0") +ddp = Taskflow("dependency_parsing", model="ddparser-ernie-1.0") ddp("百度是一家高科技公司") >>> [{'word': ['百度', '是', '一家', '高科技', '公司'], 'head': ['2', '0', '5', '5', '2'], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB']}] ``` diff --git a/paddlenlp/taskflow/text_correction.py b/paddlenlp/taskflow/text_correction.py index b453f0b56ff0..be0ad94e9cc9 100644 --- a/paddlenlp/taskflow/text_correction.py +++ b/paddlenlp/taskflow/text_correction.py @@ -139,7 +139,7 @@ def _construct_model(self, model): pad_pinyin_id=self._pinyin_vocab[self._pinyin_vocab.pad_token]) # Load the model parameter for the predict model_path = download_file(self._task_path, model + ".pdparams", - URLS[model][0], URLS[model][1], model) + URLS[model][0], URLS[model][1]) state_dict = paddle.load(model_path) model_instance.set_state_dict(state_dict) model_instance.eval()