From 37833ada4bbd951c55702cbb52b192358c705dff Mon Sep 17 00:00:00 2001 From: Tong Gao Date: Tue, 22 Mar 2022 17:44:32 +0800 Subject: [PATCH] [Fix] Support loading space character from dict file (#854) * [Feature] Support loading dict file with space character * fix tests * clean up \\r in tests * add DICT37 and DICT91 * update docstr --- mmocr/models/textrecog/convertors/base.py | 42 +++++++---- .../test_attn_label_convertor.py | 2 +- .../test_base_label_convertor.py | 74 +++++++++++++++++++ .../test_ctc_label_convertor.py | 11 +-- 4 files changed, 104 insertions(+), 25 deletions(-) create mode 100644 tests/test_models/test_label_convertor/test_base_label_convertor.py diff --git a/mmocr/models/textrecog/convertors/base.py b/mmocr/models/textrecog/convertors/base.py index 976299d99..83b1ab760 100644 --- a/mmocr/models/textrecog/convertors/base.py +++ b/mmocr/models/textrecog/convertors/base.py @@ -8,7 +8,8 @@ class BaseConvertor: """Convert between text, index and tensor for text recognize pipeline. Args: - dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'. + dict_type (str): Type of dict, options are 'DICT36', 'DICT37', 'DICT90' + and 'DICT91'. dict_file (None|str): Character dict file path. If not none, the dict_file is of higher priority than dict_type. dict_list (None|list[str]): Character list. 
If not none, the list @@ -18,32 +19,43 @@ class BaseConvertor: unknown_idx = None lower = False - DICT36 = tuple('0123456789abcdefghijklmnopqrstuvwxyz') - DICT90 = tuple('0123456789abcdefghijklmnopqrstuvwxyz' - 'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()' - '*+,-./:;<=>?@[\\]_`~') + dicts = dict( + DICT36=tuple('0123456789abcdefghijklmnopqrstuvwxyz'), + DICT90=tuple('0123456789abcdefghijklmnopqrstuvwxyz' + 'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()' + '*+,-./:;<=>?@[\\]_`~'), + # With space character + DICT37=tuple('0123456789abcdefghijklmnopqrstuvwxyz '), + DICT91=tuple('0123456789abcdefghijklmnopqrstuvwxyz' + 'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()' + '*+,-./:;<=>?@[\\]_`~ ')) def __init__(self, dict_type='DICT90', dict_file=None, dict_list=None): - assert dict_type in ('DICT36', 'DICT90') assert dict_file is None or isinstance(dict_file, str) assert dict_list is None or isinstance(dict_list, list) self.idx2char = [] if dict_file is not None: - for line in list_from_file(dict_file): - line = line.strip() + for line_num, line in enumerate(list_from_file(dict_file)): + line = line.strip('\r\n') + if len(line) > 1: + raise ValueError('Expect each line has 0 or 1 character, ' + f'got {len(line)} characters ' + f'at line {line_num + 1}') if line != '': self.idx2char.append(line) elif dict_list is not None: - self.idx2char = dict_list + self.idx2char = list(dict_list) else: - if dict_type == 'DICT36': - self.idx2char = list(self.DICT36) + if dict_type in self.dicts: + self.idx2char = list(self.dicts[dict_type]) else: - self.idx2char = list(self.DICT90) + raise NotImplementedError(f'Dict type {dict_type} is not ' + 'supported') - self.char2idx = {} - for idx, char in enumerate(self.idx2char): - self.char2idx[char] = idx + assert len(set(self.idx2char)) == len(self.idx2char), \ + 'Invalid dictionary: Has duplicated characters.' 
+ + self.char2idx = {char: idx for idx, char in enumerate(self.idx2char)} def num_classes(self): """Number of output classes.""" diff --git a/tests/test_models/test_label_convertor/test_attn_label_convertor.py b/tests/test_models/test_label_convertor/test_attn_label_convertor.py index 62c53466a..96b7c86d0 100644 --- a/tests/test_models/test_label_convertor/test_attn_label_convertor.py +++ b/tests/test_models/test_label_convertor/test_attn_label_convertor.py @@ -23,7 +23,7 @@ def test_attn_label_convertor(): _create_dummy_dict_file(dict_file) # test invalid arguments - with pytest.raises(AssertionError): + with pytest.raises(NotImplementedError): AttnConvertor(5) with pytest.raises(AssertionError): AttnConvertor('DICT90', dict_file, '1') diff --git a/tests/test_models/test_label_convertor/test_base_label_convertor.py b/tests/test_models/test_label_convertor/test_base_label_convertor.py new file mode 100644 index 000000000..7b1c7fecf --- /dev/null +++ b/tests/test_models/test_label_convertor/test_base_label_convertor.py @@ -0,0 +1,74 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os.path as osp +import tempfile + +import pytest + +from mmocr.models.textrecog.convertors import BaseConvertor + + +def test_base_label_convertor(): + with pytest.raises(NotImplementedError): + label_convertor = BaseConvertor() + label_convertor.str2tensor(None) + label_convertor.tensor2idx(None) + + tmp_dir = tempfile.TemporaryDirectory() + dict_file = osp.join(tmp_dir.name, 'fake_chars.txt') + + # Test loading a dictionary from file + + # Test the capability of handling different line separator style + # Set newline='' to preserve the line separators as given in the test file + # *nix style line separator + with open(dict_file, 'w', newline='') as fw: + fw.write('a\nb\n\n \n\n') + label_convertor = BaseConvertor(dict_file=dict_file) + assert label_convertor.idx2char == ['a', 'b', ' '] + # Windows style line separator + with open(dict_file, 'w', newline='') as fw: + fw.write('a\r\nb\r\n\r\n \r\n\r\n') + label_convertor = BaseConvertor(dict_file=dict_file) + assert label_convertor.idx2char == ['a', 'b', ' '] + + # Ensure it won't parse line separator as a space character + with open(dict_file, 'w') as fw: + fw.write('a\nb\n\n\nc\n\n') + label_convertor = BaseConvertor(dict_file=dict_file) + assert label_convertor.idx2char == ['a', 'b', 'c'] + + # Test loading an illegal dictionary + # Duplicated characters + with open(dict_file, 'w') as fw: + fw.write('a\nb\n\n \n\na') + with pytest.raises(AssertionError): + label_convertor = BaseConvertor(dict_file=dict_file) + + # Too many characters per line + with open(dict_file, 'w') as fw: + fw.write('a\nb\nc \n') + with pytest.raises( + ValueError, + match='Expect each line has 0 or 1 character, got 2' + ' characters at line 3'): + label_convertor = BaseConvertor(dict_file=dict_file) + with open(dict_file, 'w') as fw: + fw.write(' \n') + with pytest.raises( + ValueError, + match='Expect each line has 0 or 1 character, got 3' + ' characters at line 1'): + label_convertor = BaseConvertor(dict_file=dict_file) + + # 
Test creating a dictionary from dict_type + label_convertor = BaseConvertor(dict_type='DICT37') + assert len(label_convertor.idx2char) == 37 + with pytest.raises( + NotImplementedError, match='Dict type DICT100 is not supported'): + label_convertor = BaseConvertor(dict_type='DICT100') + + # Test creating a dictionary from dict_list + label_convertor = BaseConvertor(dict_list=['a', 'b', 'c', 'd', ' ']) + assert label_convertor.idx2char == ['a', 'b', 'c', 'd', ' '] + + tmp_dir.cleanup() diff --git a/tests/test_models/test_label_convertor/test_ctc_label_convertor.py b/tests/test_models/test_label_convertor/test_ctc_label_convertor.py index df677e688..d26d99082 100644 --- a/tests/test_models/test_label_convertor/test_ctc_label_convertor.py +++ b/tests/test_models/test_label_convertor/test_ctc_label_convertor.py @@ -6,7 +6,7 @@ import pytest import torch -from mmocr.models.textrecog.convertors import BaseConvertor, CTCConvertor +from mmocr.models.textrecog.convertors import CTCConvertor def _create_dummy_dict_file(dict_file): @@ -23,7 +23,7 @@ def test_ctc_label_convertor(): _create_dummy_dict_file(dict_file) # test invalid arguments - with pytest.raises(AssertionError): + with pytest.raises(NotImplementedError): CTCConvertor(5) label_convertor = CTCConvertor(dict_file=dict_file, with_unknown=False) @@ -71,10 +71,3 @@ def test_ctc_label_convertor(): assert output_strings[0] == 'hell' tmp_dir.cleanup() - - -def test_base_label_convertor(): - with pytest.raises(NotImplementedError): - label_convertor = BaseConvertor() - label_convertor.str2tensor(None) - label_convertor.tensor2idx(None)