Skip to content

Commit

Permalink
[Fix] Support loading space character from dict file (#854)
Browse files Browse the repository at this point in the history
* [Feature] Support loading dict file with space character

* fix tests

* clean up \\r in tests

* add DICT37 and DICT91

* update docstr
  • Loading branch information
gaotongxiao authored Mar 22, 2022
1 parent 33c5e41 commit 37833ad
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 25 deletions.
42 changes: 27 additions & 15 deletions mmocr/models/textrecog/convertors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ class BaseConvertor:
"""Convert between text, index and tensor for text recognize pipeline.
Args:
dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
dict_type (str): Type of dict, options are 'DICT36', 'DICT37', 'DICT90'
and 'DICT91'.
dict_file (None|str): Character dict file path. If not none,
the dict_file is of higher priority than dict_type.
dict_list (None|list[str]): Character list. If not none, the list
Expand All @@ -18,32 +19,43 @@ class BaseConvertor:
unknown_idx = None
lower = False

DICT36 = tuple('0123456789abcdefghijklmnopqrstuvwxyz')
DICT90 = tuple('0123456789abcdefghijklmnopqrstuvwxyz'
'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()'
'*+,-./:;<=>?@[\\]_`~')
dicts = dict(
DICT36=tuple('0123456789abcdefghijklmnopqrstuvwxyz'),
DICT90=tuple('0123456789abcdefghijklmnopqrstuvwxyz'
'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()'
'*+,-./:;<=>?@[\\]_`~'),
# With space character
DICT37=tuple('0123456789abcdefghijklmnopqrstuvwxyz '),
DICT91=tuple('0123456789abcdefghijklmnopqrstuvwxyz'
'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()'
'*+,-./:;<=>?@[\\]_`~ '))

def __init__(self, dict_type='DICT90', dict_file=None, dict_list=None):
assert dict_type in ('DICT36', 'DICT90')
assert dict_file is None or isinstance(dict_file, str)
assert dict_list is None or isinstance(dict_list, list)
self.idx2char = []
if dict_file is not None:
for line in list_from_file(dict_file):
line = line.strip()
for line_num, line in enumerate(list_from_file(dict_file)):
line = line.strip('\r\n')
if len(line) > 1:
raise ValueError('Expect each line has 0 or 1 character, '
f'got {len(line)} characters '
f'at line {line_num + 1}')
if line != '':
self.idx2char.append(line)
elif dict_list is not None:
self.idx2char = dict_list
self.idx2char = list(dict_list)
else:
if dict_type == 'DICT36':
self.idx2char = list(self.DICT36)
if dict_type in self.dicts:
self.idx2char = list(self.dicts[dict_type])
else:
self.idx2char = list(self.DICT90)
raise NotImplementedError(f'Dict type {dict_type} is not '
'supported')

self.char2idx = {}
for idx, char in enumerate(self.idx2char):
self.char2idx[char] = idx
assert len(set(self.idx2char)) == len(self.idx2char), \
'Invalid dictionary: Has duplicated characters.'

self.char2idx = {char: idx for idx, char in enumerate(self.idx2char)}

def num_classes(self):
"""Number of output classes."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_attn_label_convertor():
_create_dummy_dict_file(dict_file)

# test invalid arguments
with pytest.raises(AssertionError):
with pytest.raises(NotImplementedError):
AttnConvertor(5)
with pytest.raises(AssertionError):
AttnConvertor('DICT90', dict_file, '1')
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile

import pytest

from mmocr.models.textrecog.convertors import BaseConvertor


def test_base_label_convertor():
with pytest.raises(NotImplementedError):
label_convertor = BaseConvertor()
label_convertor.str2tensor(None)
label_convertor.tensor2idx(None)

tmp_dir = tempfile.TemporaryDirectory()
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')

# Test loading a dictionary from file

# Test the capability of handling different line separator style
# Set newline='' to preserve the line separators as given in the test file
# *nix style line separator
with open(dict_file, 'w', newline='') as fw:
fw.write('a\nb\n\n \n\n')
label_convertor = BaseConvertor(dict_file=dict_file)
assert label_convertor.idx2char == ['a', 'b', ' ']
# Windows style line separator
with open(dict_file, 'w', newline='') as fw:
fw.write('a\r\nb\r\n\r\n \r\n\r\n')
label_convertor = BaseConvertor(dict_file=dict_file)
assert label_convertor.idx2char == ['a', 'b', ' ']

# Ensure it won't parse line separator as a space character
with open(dict_file, 'w') as fw:
fw.write('a\nb\n\n\nc\n\n')
label_convertor = BaseConvertor(dict_file=dict_file)
assert label_convertor.idx2char == ['a', 'b', 'c']

# Test loading an illegal dictionary
# Duplciated characters
with open(dict_file, 'w') as fw:
fw.write('a\nb\n\n \n\na')
with pytest.raises(AssertionError):
label_convertor = BaseConvertor(dict_file=dict_file)

# Too many characters per line
with open(dict_file, 'w') as fw:
fw.write('a\nb\nc \n')
with pytest.raises(
ValueError,
match='Expect each line has 0 or 1 character, got 2'
' characters at line 3'):
label_convertor = BaseConvertor(dict_file=dict_file)
with open(dict_file, 'w') as fw:
fw.write(' \n')
with pytest.raises(
ValueError,
match='Expect each line has 0 or 1 character, got 3'
' characters at line 1'):
label_convertor = BaseConvertor(dict_file=dict_file)

# Test creating a dictionary from dict_type
label_convertor = BaseConvertor(dict_type='DICT37')
assert len(label_convertor.idx2char) == 37
with pytest.raises(
NotImplementedError, match='Dict type DICT100 is not supported'):
label_convertor = BaseConvertor(dict_type='DICT100')

# Test creating a dictionary from dict_list
label_convertor = BaseConvertor(dict_list=['a', 'b', 'c', 'd', ' '])
assert label_convertor.idx2char == ['a', 'b', 'c', 'd', ' ']

tmp_dir.cleanup()
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
import torch

from mmocr.models.textrecog.convertors import BaseConvertor, CTCConvertor
from mmocr.models.textrecog.convertors import CTCConvertor


def _create_dummy_dict_file(dict_file):
Expand All @@ -23,7 +23,7 @@ def test_ctc_label_convertor():
_create_dummy_dict_file(dict_file)

# test invalid arguments
with pytest.raises(AssertionError):
with pytest.raises(NotImplementedError):
CTCConvertor(5)

label_convertor = CTCConvertor(dict_file=dict_file, with_unknown=False)
Expand Down Expand Up @@ -71,10 +71,3 @@ def test_ctc_label_convertor():
assert output_strings[0] == 'hell'

tmp_dir.cleanup()


def test_base_label_convertor():
with pytest.raises(NotImplementedError):
label_convertor = BaseConvertor()
label_convertor.str2tensor(None)
label_convertor.tensor2idx(None)

0 comments on commit 37833ad

Please sign in to comment.