diff --git a/README.md b/README.md
index 1acff842e..6043feacc 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,7 @@
[![Average time to resolve an issue](https://isitmaintained.com/badge/resolution/open-mmlab/mmocr.svg)](https://github.com/open-mmlab/mmocr/issues)
[![Percentage of issues still open](https://isitmaintained.com/badge/open/open-mmlab/mmocr.svg)](https://github.com/open-mmlab/mmocr/issues)
+[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_demo.svg)](https://openxlab.org.cn/apps?search=mmocr)
[📘Documentation](https://mmocr.readthedocs.io/en/dev-1.x/) |
[🛠️Installation](https://mmocr.readthedocs.io/en/dev-1.x/get_started/install.html) |
@@ -151,6 +152,7 @@ Supported algorithms:
- [x] [ABINet](configs/textrecog/abinet/README.md) (CVPR'2021)
- [x] [ASTER](configs/textrecog/aster/README.md) (TPAMI'2018)
- [x] [CRNN](configs/textrecog/crnn/README.md) (TPAMI'2016)
+- [x] [MAERec](configs/textrecog/maerec/README.md) (ICCV'2023)
- [x] [MASTER](configs/textrecog/master/README.md) (PR'2021)
- [x] [NRTR](configs/textrecog/nrtr/README.md) (ICDAR'2019)
- [x] [RobustScanner](configs/textrecog/robust_scanner/README.md) (ECCV'2020)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index c38839637..54357485f 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -26,6 +26,7 @@
[![Average time to resolve an issue](https://isitmaintained.com/badge/resolution/open-mmlab/mmocr.svg)](https://github.com/open-mmlab/mmocr/issues)
[![Percentage of issues still open](https://isitmaintained.com/badge/open/open-mmlab/mmocr.svg)](https://github.com/open-mmlab/mmocr/issues)
+[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_demo.svg)](https://openxlab.org.cn/apps?search=mmocr)
[📘文档](https://mmocr.readthedocs.io/zh_CN/dev-1.x/) |
[🛠️安装](https://mmocr.readthedocs.io/zh_CN/dev-1.x/get_started/install.html) |
@@ -150,6 +151,7 @@ mim install -e .
- [x] [ABINet](configs/textrecog/abinet/README.md) (CVPR'2021)
- [x] [ASTER](configs/textrecog/aster/README.md) (TPAMI'2018)
- [x] [CRNN](configs/textrecog/crnn/README.md) (TPAMI'2016)
+- [x] [MAERec](configs/textrecog/maerec/README.md) (ICCV'2023)
- [x] [MASTER](configs/textrecog/master/README.md) (PR'2021)
- [x] [NRTR](configs/textrecog/nrtr/README.md) (ICDAR'2019)
- [x] [RobustScanner](configs/textrecog/robust_scanner/README.md) (ECCV'2020)
diff --git a/configs/textrecog/_base_/datasets/union14m_benchmark.py b/configs/textrecog/_base_/datasets/union14m_benchmark.py
new file mode 100644
index 000000000..007e4f878
--- /dev/null
+++ b/configs/textrecog/_base_/datasets/union14m_benchmark.py
@@ -0,0 +1,65 @@
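+# Base definitions for the Union14M-Benchmark test subsets. Downstream
+# configs inherit this file via `_base_` and assemble the
+# `union14m_benchmark_*` variables into a ConcatDataset for testing.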
+union14m_root = 'data/Union14M-L/'
+union14m_benchmark_root = 'data/Union14M-L/Union14M-Benchmarks'
+
+union14m_benchmark_artistic = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_benchmark_root}/artistic'),
+ ann_file=f'{union14m_benchmark_root}/artistic/annotation.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_benchmark_contextless = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_benchmark_root}/contextless'),
+ ann_file=f'{union14m_benchmark_root}/contextless/annotation.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_benchmark_curve = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_benchmark_root}/curve'),
+ ann_file=f'{union14m_benchmark_root}/curve/annotation.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_benchmark_incomplete = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_benchmark_root}/incomplete'),
+ ann_file=f'{union14m_benchmark_root}/incomplete/annotation.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_benchmark_incomplete_ori = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_benchmark_root}/incomplete_ori'),
+ ann_file=f'{union14m_benchmark_root}/incomplete_ori/annotation.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_benchmark_multi_oriented = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_benchmark_root}/multi_oriented'),
+ ann_file=f'{union14m_benchmark_root}/multi_oriented/annotation.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_benchmark_multi_words = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_benchmark_root}/multi_words'),
+ ann_file=f'{union14m_benchmark_root}/multi_words/annotation.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_benchmark_salient = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_benchmark_root}/salient'),
+ ann_file=f'{union14m_benchmark_root}/salient/annotation.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_benchmark_general = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_root}/'),
+ ann_file=f'{union14m_benchmark_root}/general/annotation.json',
+ test_mode=True,
+ pipeline=None)
diff --git a/configs/textrecog/_base_/datasets/union14m_train.py b/configs/textrecog/_base_/datasets/union14m_train.py
new file mode 100644
index 000000000..a91f2b104
--- /dev/null
+++ b/configs/textrecog/_base_/datasets/union14m_train.py
@@ -0,0 +1,38 @@
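+# Difficulty-based training splits of Union14M-L (challenging / hard /
+# medium / normal / easy) plus the validation split. Downstream configs
+# inherit this file via `_base_` and concatenate the splits for training.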
+union14m_data_root = 'data/Union14M-L/'
+
+union14m_challenging = dict(
+ type='OCRDataset',
+ data_root=union14m_data_root,
+ ann_file='train_annos/mmocr1.0/train_challenging.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_hard = dict(
+ type='OCRDataset',
+ data_root=union14m_data_root,
+ ann_file='train_annos/mmocr1.0/train_hard.json',
+ pipeline=None)
+
+union14m_medium = dict(
+ type='OCRDataset',
+ data_root=union14m_data_root,
+ ann_file='train_annos/mmocr1.0/train_medium.json',
+ pipeline=None)
+
+union14m_normal = dict(
+ type='OCRDataset',
+ data_root=union14m_data_root,
+ ann_file='train_annos/mmocr1.0/train_normal.json',
+ pipeline=None)
+
+union14m_easy = dict(
+ type='OCRDataset',
+ data_root=union14m_data_root,
+ ann_file='train_annos/mmocr1.0/train_easy.json',
+ pipeline=None)
+
+union14m_val = dict(
+ type='OCRDataset',
+ data_root=union14m_data_root,
+ ann_file='train_annos/mmocr1.0/val_annos.json',
+ pipeline=None)
diff --git a/configs/textrecog/_base_/schedules/schedule_adamw_cos_10e.py b/configs/textrecog/_base_/schedules/schedule_adamw_cos_10e.py
new file mode 100644
index 000000000..4f5c32a32
--- /dev/null
+++ b/configs/textrecog/_base_/schedules/schedule_adamw_cos_10e.py
@@ -0,0 +1,21 @@
+# optimizer
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW',
+ lr=4e-4,
+ betas=(0.9, 0.999),
+ eps=1e-08,
+ weight_decay=0.01))
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=10, val_interval=1)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+# learning policy
+param_scheduler = [
+ dict(
+ type='CosineAnnealingLR',
+ T_max=10,
+ eta_min=4e-6,
+ convert_to_iter_based=True)
+]
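+
+# Note: with `convert_to_iter_based=True`, MMEngine converts the epoch-based
+# cosine schedule (T_max=10 epochs) into per-iteration updates, so the
+# learning rate anneals smoothly from 4e-4 to 4e-6 instead of stepping once
+# per epoch.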
diff --git a/configs/textrecog/abinet/README.md b/configs/textrecog/abinet/README.md
index 6a7faadb3..727269fe4 100644
--- a/configs/textrecog/abinet/README.md
+++ b/configs/textrecog/abinet/README.md
@@ -47,6 +47,22 @@ Linguistic knowledge is of great benefit to scene text recognition. However, how
2. Facts about the pretrained model: MMOCR does not have a systematic pipeline to pretrain the language model (LM) yet, thus the weights of LM are converted from [the official pretrained model](https://github.com/FangShancheng/ABINet). The weights of ABINet-Vision are directly used as the vision model of ABINet.
```
+We also provide ABINet trained on [Union14M](https://github.com/Mountchicken/Union14M).
+
+- Evaluated on six common benchmarks
+
+ | methods | pretrained | | Regular Text | | | | Irregular Text | | download |
+ | :---------------------------------------------------------------: | :--------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :----------------------------------------------------------------- |
+ | | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
+ | [ABINet-Vision](configs/textrecog/abinet/abinet-vision_10e_union14m.py) | - | 0.9730 | 0.9645 | 0.9552 | | 0.8536 | 0.8977 | 0.9479 | [model](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_union14m-cbf19742.pth) |
+
+- Evaluated on [Union14M-Benchmark](https://github.com/Mountchicken/Union14M)
+
+ | Methods | | Unsolved Challenges | | | | | Additional Challenges | | General | download |
+ | ------------------------------------------------------ | ----- | ------------------- | -------- | ----------- | --- | ------- | --------------------- | ---------- | ------- | ------------------------------------------------------- |
+ | | Curve | Multi-Oriented | Artistic | Contextless | | Salient | Multi-Words | Incomplete | General | |
+ | [ABINet-Vision](configs/textrecog/abinet/abinet-vision_10e_union14m.py) | 0.750 | 0.615 | 0.653 | 0.711 | | 0.729 | 0.591 | 0.026 | 0.794 | [model](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_union14m-cbf19742.pth) |
+
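+The released checkpoint can be tried out directly with the recognition
+inferencer. A minimal sketch (`demo/demo_text_recog.jpg` ships with the
+repository):
+
+```python
+from mmocr.apis import TextRecInferencer
+
+# Build the inferencer from the Union14M config and the released weights.
+inferencer = TextRecInferencer(
+    model='configs/textrecog/abinet/abinet-vision_10e_union14m.py',
+    weights='https://download.openmmlab.com/mmocr/textrecog/abinet/'
+    'abinet_union14m-cbf19742.pth')
+print(inferencer('demo/demo_text_recog.jpg')['predictions'])
+```
+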
## Citation
```bibtex
diff --git a/configs/textrecog/abinet/abinet-vision_10e_union14m.py b/configs/textrecog/abinet/abinet-vision_10e_union14m.py
new file mode 100644
index 000000000..18961901e
--- /dev/null
+++ b/configs/textrecog/abinet/abinet-vision_10e_union14m.py
@@ -0,0 +1,118 @@
+_base_ = [
+ '../_base_/datasets/union14m_train.py',
+ '../_base_/datasets/union14m_benchmark.py',
+ '../_base_/datasets/cute80.py',
+ '../_base_/datasets/iiit5k.py',
+ '../_base_/datasets/svt.py',
+ '../_base_/datasets/svtp.py',
+ '../_base_/datasets/icdar2013.py',
+ '../_base_/datasets/icdar2015.py',
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_adamw_cos_10e.py',
+ '_base_abinet.py',
+]
+
+load_from = 'https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_pretrain-45deac15.pth' # noqa
+
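+# Drop the full ABINet model inherited from `_base_abinet.py`; a vision-only
+# variant (no language decoder) is redefined below.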
+_base_.pop('model')
+dictionary = dict(
+ type='Dictionary',
+ dict_file= # noqa
+ '{{ fileDirname }}/../../../dicts/english_digits_symbols_space.txt',
+ with_padding=True,
+ with_unknown=True,
+ same_start_end=True,
+ with_start=True,
+ with_end=True)
+
+model = dict(
+ type='ABINet',
+ backbone=dict(type='ResNetABI'),
+ encoder=dict(
+ type='ABIEncoder',
+ n_layers=3,
+ n_head=8,
+ d_model=512,
+ d_inner=2048,
+ dropout=0.1,
+ max_len=8 * 32,
+ ),
+ decoder=dict(
+ type='ABIFuser',
+ vision_decoder=dict(
+ type='ABIVisionDecoder',
+ in_channels=512,
+ num_channels=64,
+ attn_height=8,
+ attn_width=32,
+ attn_mode='nearest',
+ init_cfg=dict(type='Xavier', layer='Conv2d')),
+ module_loss=dict(type='ABIModuleLoss'),
+ postprocessor=dict(type='AttentionPostprocessor'),
+ dictionary=dictionary,
+ max_seq_len=26,
+ ),
+ data_preprocessor=dict(
+ type='TextRecogDataPreprocessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375]))
+
+# dataset settings
+train_list = [
+ _base_.union14m_challenging, _base_.union14m_hard, _base_.union14m_medium,
+ _base_.union14m_normal, _base_.union14m_easy
+]
+val_list = [
+ _base_.cute80_textrecog_test, _base_.iiit5k_textrecog_test,
+ _base_.svt_textrecog_test, _base_.svtp_textrecog_test,
+ _base_.icdar2013_textrecog_test, _base_.icdar2015_textrecog_test
+]
+test_list = [
+ _base_.union14m_benchmark_artistic,
+ _base_.union14m_benchmark_multi_oriented,
+ _base_.union14m_benchmark_contextless,
+ _base_.union14m_benchmark_curve,
+ _base_.union14m_benchmark_incomplete,
+ _base_.union14m_benchmark_incomplete_ori,
+ _base_.union14m_benchmark_multi_words,
+ _base_.union14m_benchmark_salient,
+ _base_.union14m_benchmark_general,
+]
+
+train_dataset = dict(
+ type='ConcatDataset', datasets=train_list, pipeline=_base_.train_pipeline)
+test_dataset = dict(
+ type='ConcatDataset', datasets=test_list, pipeline=_base_.test_pipeline)
+val_dataset = dict(
+ type='ConcatDataset', datasets=val_list, pipeline=_base_.test_pipeline)
+
+train_dataloader = dict(
+ batch_size=128,
+ num_workers=24,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ dataset=train_dataset)
+
+test_dataloader = dict(
+ batch_size=128,
+ num_workers=4,
+ persistent_workers=True,
+ drop_last=False,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=test_dataset)
+
+val_dataloader = dict(
+ batch_size=128,
+ num_workers=4,
+ persistent_workers=True,
+ pin_memory=True,
+ drop_last=False,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=val_dataset)
+
+val_evaluator = dict(
+ dataset_prefixes=['CUTE80', 'IIIT5K', 'SVT', 'SVTP', 'IC13', 'IC15'])
+test_evaluator = dict(dataset_prefixes=[
+ 'artistic', 'multi-oriented', 'contextless', 'curve', 'incomplete',
+ 'incomplete-ori', 'multi-words', 'salient', 'general'
+])
diff --git a/configs/textrecog/aster/README.md b/configs/textrecog/aster/README.md
index 0e795b7eb..ee4829aa6 100644
--- a/configs/textrecog/aster/README.md
+++ b/configs/textrecog/aster/README.md
@@ -40,6 +40,22 @@ A challenging aspect of scene text recognition is to handle text with distortion
| [ASTER](/configs/textrecog/aster/aster_resnet45_6e_st_mj.py) | ResNet45 | 0.9357 | 0.8949 | 0.9281 | | 0.7665 | 0.8062 | 0.8507 | [model](https://download.openmmlab.com/mmocr/textrecog/aster/aster_resnet45_6e_st_mj/aster_resnet45_6e_st_mj-cc56eca4.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/aster/aster_resnet45_6e_st_mj/20221214_232605.log) |
| [ASTER-TTA](/configs/textrecog/aster/aster_resnet45_6e_st_mj.py) | ResNet45 | 0.9337 | 0.8949 | 0.9251 | | 0.7925 | 0.8109 | 0.8507 | |
+We also provide ASTER trained on [Union14M](https://github.com/Mountchicken/Union14M).
+
+- Evaluated on six common benchmarks
+
+ | Methods | pretrained | | Regular Text | | | | Irregular Text | | download |
+ | :------------------------------------------------------------: | :--------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :-------------------------------------------------------------------- |
+ | | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
+ | [ASTER](configs/textrecog/aster/aster_resnet45_6e_union14m.py) | - | 0.9437 | 0.8903 | 0.9360 | | 0.7857 | 0.8093 | 0.9097 | [model](https://download.openmmlab.com/mmocr/textrecog/aster/aster_union14m/aster_union14m-230eb471.pth) |
+
+- Evaluated on [Union14M-Benchmark](https://github.com/Mountchicken/Union14M)
+
+ | Methods | | Unsolved Challenges | | | | | Additional Challenges | | General | download |
+ | ------------------------------------------------------ | ----- | ------------------- | -------- | ----------- | --- | ------- | --------------------- | ---------- | ------- | ------------------------------------------------------- |
+ | | Curve | Multi-Oriented | Artistic | Contextless | | Salient | Multi-Words | Incomplete | General | |
+ | [ASTER](configs/textrecog/aster/aster_resnet45_6e_union14m.py) | 0.384 | 0.130 | 0.418 | 0.529 | | 0.319 | 0.498 | 0.013 | 0.667 | [model](https://download.openmmlab.com/mmocr/textrecog/aster/aster_union14m/aster_union14m-230eb471.pth) |
+
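+The released model can also serve as an initialization for further
+fine-tuning. A minimal sketch (the config file name here is hypothetical):
+
+```python
+# aster_finetune.py: inherit the Union14M config and warm-start from the
+# released checkpoint via `load_from`.
+_base_ = ['aster_resnet45_6e_union14m.py']
+
+load_from = ('https://download.openmmlab.com/mmocr/textrecog/aster/'
+             'aster_union14m/aster_union14m-230eb471.pth')
+```
+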
## Citation
```bibtex
diff --git a/configs/textrecog/aster/aster_resnet45_6e_union14m.py b/configs/textrecog/aster/aster_resnet45_6e_union14m.py
new file mode 100644
index 000000000..fd702f348
--- /dev/null
+++ b/configs/textrecog/aster/aster_resnet45_6e_union14m.py
@@ -0,0 +1,91 @@
+# ASTER trained on Union14M (6-epoch schedule)
+_base_ = [
+ '_base_aster.py',
+ '../_base_/datasets/union14m_train.py',
+ '../_base_/datasets/union14m_benchmark.py',
+ '../_base_/datasets/cute80.py',
+ '../_base_/datasets/iiit5k.py',
+ '../_base_/datasets/svt.py',
+ '../_base_/datasets/svtp.py',
+ '../_base_/datasets/icdar2013.py',
+ '../_base_/datasets/icdar2015.py',
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_adamw_cos_6e.py',
+]
+
+dictionary = dict(
+ type='Dictionary',
+ dict_file= # noqa
+ '{{ fileDirname }}/../../../dicts/english_digits_symbols_space.txt',
+ with_padding=True,
+ with_unknown=True,
+ same_start_end=True,
+ with_start=True,
+ with_end=True)
+
+# dataset settings
+train_list = [
+ _base_.union14m_challenging, _base_.union14m_hard, _base_.union14m_medium,
+ _base_.union14m_normal, _base_.union14m_easy
+]
+val_list = [
+ _base_.cute80_textrecog_test, _base_.iiit5k_textrecog_test,
+ _base_.svt_textrecog_test, _base_.svtp_textrecog_test,
+ _base_.icdar2013_textrecog_test, _base_.icdar2015_textrecog_test
+]
+test_list = [
+ _base_.union14m_benchmark_artistic,
+ _base_.union14m_benchmark_multi_oriented,
+ _base_.union14m_benchmark_contextless,
+ _base_.union14m_benchmark_curve,
+ _base_.union14m_benchmark_incomplete,
+ _base_.union14m_benchmark_incomplete_ori,
+ _base_.union14m_benchmark_multi_words,
+ _base_.union14m_benchmark_salient,
+ _base_.union14m_benchmark_general,
+]
+
+default_hooks = dict(logger=dict(type='LoggerHook', interval=50))
+
+auto_scale_lr = dict(base_batch_size=512)
+
+train_dataset = dict(
+ type='ConcatDataset', datasets=train_list, pipeline=_base_.train_pipeline)
+test_dataset = dict(
+ type='ConcatDataset', datasets=test_list, pipeline=_base_.test_pipeline)
+val_dataset = dict(
+ type='ConcatDataset', datasets=val_list, pipeline=_base_.test_pipeline)
+
+train_dataloader = dict(
+ batch_size=512,
+ num_workers=12,
+ persistent_workers=True,
+ pin_memory=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ dataset=train_dataset)
+
+test_dataloader = dict(
+ batch_size=128,
+ num_workers=4,
+ persistent_workers=True,
+ pin_memory=True,
+ drop_last=False,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=test_dataset)
+
+val_dataloader = dict(
+ batch_size=128,
+ num_workers=4,
+ persistent_workers=True,
+ pin_memory=True,
+ drop_last=False,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=val_dataset)
+
+val_evaluator = dict(
+ dataset_prefixes=['CUTE80', 'IIIT5K', 'SVT', 'SVTP', 'IC13', 'IC15'])
+
+test_evaluator = dict(dataset_prefixes=[
+ 'artistic', 'multi-oriented', 'contextless', 'curve', 'incomplete',
+ 'incomplete-ori', 'multi-words', 'salient', 'general'
+])
diff --git a/configs/textrecog/maerec/README.md b/configs/textrecog/maerec/README.md
new file mode 100644
index 000000000..18b3b87c7
--- /dev/null
+++ b/configs/textrecog/maerec/README.md
@@ -0,0 +1,80 @@
+# MAERec
+
+> [Revisiting Scene Text Recognition: A Data Perspective](https://arxiv.org/abs/2307.08723)
+
+## Abstract
+
+This paper aims to re-assess scene text recognition (STR) from a data-oriented perspective. We begin by revisiting the six commonly used benchmarks in STR and observe a trend of performance saturation, whereby only 2.91% of the benchmark images cannot be accurately recognized by an ensemble of 13 representative models. While these results are impressive and suggest that STR could be considered solved, we argue that this is primarily due to the less challenging nature of the common benchmarks, thus concealing the underlying issues that STR faces. To this end, we consolidate a large-scale real STR dataset, namely Union14M, which comprises 4 million labeled images and 10 million unlabeled images, to assess the performance of STR models in more complex real-world scenarios. Our experiments demonstrate that the 13 models can only achieve an average accuracy of 66.53% on the 4 million labeled images, indicating that STR still faces numerous challenges in the real world. By analyzing the error patterns of the 13 models, we identify seven open challenges in STR and develop a challenge-driven benchmark consisting of eight distinct subsets to facilitate further progress in the field. Our exploration demonstrates that STR is far from being solved and that leveraging data may be a promising solution. In this regard, we find that utilizing the 10 million unlabeled images through self-supervised pre-training can significantly improve the robustness of STR models in real-world scenarios and lead to state-of-the-art performance.
+
+## Dataset
+
+### Train Dataset
+
+| trainset | instance_num | repeat_num | source |
+| :--------------------------------------------------------------: | :----------: | :--------: | :----: |
+| [Union14M](https://github.com/Mountchicken/Union14M#34-download) | 3230742 | 1 | real |
+
+### Test Dataset
+
+- On six common benchmarks
+
+ | testset | instance_num | type |
+ | :-----: | :----------: | :-------: |
+ | IIIT5K | 3000 | regular |
+ | SVT | 647 | regular |
+ | IC13 | 1015 | regular |
+ | IC15 | 2077 | irregular |
+ | SVTP | 645 | irregular |
+ | CT80 | 288 | irregular |
+
+- On Union14M-Benchmark
+
+ | testset | instance_num | type |
+ | :------------: | :----------: | :------------------: |
+ | Artistic | 900 | Unsolved Challenge |
+ | Curve | 2426 | Unsolved Challenge |
+ | Multi-Oriented | 1369 | Unsolved Challenge |
+ | Contextless | 779 | Additional Challenge |
+ | Multi-Words | 829 | Additional Challenge |
+ | Salient | 1585 | Additional Challenge |
+ | Incomplete | 1495 | Additional Challenge |
+ | General | 400,000 | - |
+
+## Results and Models
+
+- Evaluated on six common benchmarks
+
+ | Methods | Backbone | | Regular Text | | | | Irregular Text | | download |
+ | :---------------------------------------------: | :----------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :--: | :----------------------------------------------: |
+ | | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
+ | [MAERec-S](configs/textrecog/maerec/maerec_s_union14m.py) | [ViT-Small (Pretrained on Union14M-U)](https://github.com/Mountchicken/Union14M#51-pre-training) | 98.0 | 97.6 | 96.8 | | 87.1 | 93.2 | 97.9 | [model](https://download.openmmlab.com/mmocr/textrecog/mae/mae_union14m/maerec_s_union14m-a9a157e5.pth) |
+ | [MAERec-B](configs/textrecog/maerec/maerec_b_union14m.py) | [ViT-Base (Pretrained on Union14M-U)](https://github.com/Mountchicken/Union14M#51-pre-training) | 98.5 | 98.1 | 97.8 | | 89.5 | 94.4 | 98.6 | [model](https://download.openmmlab.com/mmocr/textrecog/mae/mae_union14m/maerec_b_union14m-4b98d1b4.pth) |
+
+- Evaluated on Union14M-Benchmark
+
+ | Methods | Backbone | | Unsolved Challenges | | | | | Additional Challenges | | General | download |
+ | ----------------------------------- | ------------------------------------- | ----- | ------------------- | -------- | ----------- | --- | ------- | --------------------- | ---------- | ------- | ------------------------------------- |
+ | | | Curve | Multi-Oriented | Artistic | Contextless | | Salient | Multi-Words | Incomplete | General | |
+ | [MAERec-S](configs/textrecog/maerec/maerec_s_union14m.py) | [ViT-Small (Pretrained on Union14M-U)](https://github.com/Mountchicken/Union14M#51-pre-training) | 81.4 | 71.4 | 72.0 | 82.0 | | 78.5 | 82.4 | 2.7 | 82.5 | [model](https://download.openmmlab.com/mmocr/textrecog/mae/mae_union14m/maerec_s_union14m-a9a157e5.pth) |
+ | [MAERec-B](configs/textrecog/maerec/maerec_b_union14m.py) | [ViT-Base (Pretrained on Union14M-U)](https://github.com/Mountchicken/Union14M#51-pre-training) | 88.8 | 83.9 | 80.0 | 85.5 | | 84.9 | 87.5 | 2.6 | 85.8 | [model](https://download.openmmlab.com/mmocr/textrecog/mae/mae_union14m/maerec_b_union14m-4b98d1b4.pth) |
+
+- **To train with MAERec, you need to download the pretrained ViT weights and load them in the config file, as sketched below. Check [here](https://github.com/Mountchicken/Union14M/blob/main/docs/finetune.md) for instructions.**
+
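+A minimal sketch of loading the weights (the local checkpoint path below is a
+placeholder; substitute the file you actually downloaded):
+
+```python
+_base_ = ['_base_marec_vit_s.py']
+
+model = dict(
+    backbone=dict(
+        # hypothetical path to the downloaded MAE-pretrained ViT weights
+        pretrained='checkpoints/vit_small_union14m_pretrain.pth'))
+```
+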
+## Citation
+
+```bibtex
+@misc{jiang2023revisiting,
+ title={Revisiting Scene Text Recognition: A Data Perspective},
+ author={Qing Jiang and Jiapeng Wang and Dezhi Peng and Chongyu Liu and Lianwen Jin},
+ year={2023},
+ eprint={2307.08723},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
diff --git a/configs/textrecog/maerec/_base_marec_vit_s.py b/configs/textrecog/maerec/_base_marec_vit_s.py
new file mode 100644
index 000000000..06febd088
--- /dev/null
+++ b/configs/textrecog/maerec/_base_marec_vit_s.py
@@ -0,0 +1,159 @@
+dictionary = dict(
+ type='Dictionary',
+ dict_file= # noqa
+ '{{ fileDirname }}/../../../dicts/english_digits_symbols_space.txt',
+ with_padding=True,
+ with_unknown=True,
+ same_start_end=True,
+ with_start=True,
+ with_end=True)
+
+model = dict(
+ type='MAERec',
+ backbone=dict(
+ type='VisionTransformer',
+ img_size=(32, 128),
+ patch_size=(4, 4),
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ pretrained=None),
+ decoder=dict(
+ type='MAERecDecoder',
+ n_layers=6,
+ d_embedding=384,
+ n_head=8,
+ d_model=384,
+ d_inner=384 * 4,
+ d_k=48,
+ d_v=48,
+ postprocessor=dict(type='AttentionPostprocessor'),
+ module_loss=dict(
+ type='CEModuleLoss', reduction='mean', ignore_first_char=True),
+ max_seq_len=48,
+ dictionary=dictionary),
+ data_preprocessor=dict(
+ type='TextRecogDataPreprocessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375]))
+
+train_pipeline = [
+ dict(type='LoadImageFromFile', ignore_empty=True, min_size=0),
+ dict(type='LoadOCRAnnotations', with_text=True),
+ dict(type='Resize', scale=(128, 32)),
+ dict(
+ type='RandomApply',
+ prob=0.5,
+ transforms=[
+ dict(
+ type='RandomChoice',
+ transforms=[
+ dict(
+ type='RandomRotate',
+ max_angle=15,
+ ),
+ dict(
+ type='TorchVisionWrapper',
+ op='RandomAffine',
+ degrees=15,
+ translate=(0.3, 0.3),
+ scale=(0.5, 2.),
+ shear=(-45, 45),
+ ),
+ dict(
+ type='TorchVisionWrapper',
+ op='RandomPerspective',
+ distortion_scale=0.5,
+ p=1,
+ ),
+ ])
+ ],
+ ),
+ dict(
+ type='RandomApply',
+ prob=0.25,
+ transforms=[
+ dict(type='PyramidRescale'),
+ dict(
+ type='mmdet.Albu',
+ transforms=[
+ dict(type='GaussNoise', var_limit=(20, 20), p=0.5),
+ dict(type='MotionBlur', blur_limit=7, p=0.5),
+ ]),
+ ]),
+ dict(
+ type='RandomApply',
+ prob=0.25,
+ transforms=[
+ dict(
+ type='TorchVisionWrapper',
+ op='ColorJitter',
+ brightness=0.5,
+ saturation=0.5,
+ contrast=0.5,
+ hue=0.1),
+ ]),
+ dict(
+ type='PackTextRecogInputs',
+ meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=(128, 32)),
+    # load the annotation after ``Resize``, because the ground truth does
+    # not need to be resized
+ dict(type='LoadOCRAnnotations', with_text=True),
+ dict(
+ type='PackTextRecogInputs',
+ meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
+]
+
+tta_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TestTimeAug',
+ transforms=[
+ [
+ dict(
+ type='ConditionApply',
+ true_transforms=[
+ dict(
+ type='ImgAugWrapper',
+ args=[dict(cls='Rot90', k=0, keep_size=False)])
+ ],
+                    condition="results['img_shape'][1]<results['img_shape'][0]"
+                ),
+                dict(
+                    type='ConditionApply',
+                    true_transforms=[
+                        dict(
+                            type='ImgAugWrapper',
+                            args=[dict(cls='Rot90', k=1, keep_size=False)])
+                    ],
+                    condition="results['img_shape'][1]<results['img_shape'][0]"
+                ),
+                dict(
+                    type='ConditionApply',
+                    true_transforms=[
+                        dict(
+                            type='ImgAugWrapper',
+                            args=[dict(cls='Rot90', k=3, keep_size=False)])
+                    ],
+                    condition="results['img_shape'][1]<results['img_shape'][0]"
+                ),
+            ],
+            [dict(type='Resize', scale=(128, 32))],
+            [dict(type='LoadOCRAnnotations', with_text=True)],
+            [
+                dict(
+                    type='PackTextRecogInputs',
+                    meta_keys=('img_path', 'ori_shape', 'img_shape',
+                               'valid_ratio'))
+            ]
+        ])
+]
diff --git a/mmocr/models/textdet/postprocessors/pse_postprocessor.py b/mmocr/models/textdet/postprocessors/pse_postprocessor.py
index b0a1fb9f8..45461fbd5 100644
--- a/mmocr/models/textdet/postprocessors/pse_postprocessor.py
+++ b/mmocr/models/textdet/postprocessors/pse_postprocessor.py
@@ -93,7 +93,7 @@ def get_text_instances(self, pred_results: torch.Tensor,
area = points.shape[0]
score_instance = np.mean(score[labels == i])
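+            # A text instance must pass BOTH the minimum-area and the
+            # score-threshold checks; `or` mistakenly kept instances that
+            # satisfied only one of them.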
if not (area >= self.min_text_area
- or score_instance > self.score_threshold):
+ and score_instance > self.score_threshold):
continue
polygon = self._points2boundary(points)
diff --git a/mmocr/models/textrecog/backbones/__init__.py b/mmocr/models/textrecog/backbones/__init__.py
index 3201de388..feff05645 100644
--- a/mmocr/models/textrecog/backbones/__init__.py
+++ b/mmocr/models/textrecog/backbones/__init__.py
@@ -6,8 +6,9 @@
from .resnet31_ocr import ResNet31OCR
from .resnet_abi import ResNetABI
from .shallow_cnn import ShallowCNN
+from .vit import VisionTransformer
__all__ = [
'ResNet31OCR', 'MiniVGG', 'NRTRModalityTransform', 'ShallowCNN',
- 'ResNetABI', 'ResNet', 'MobileNetV2'
+ 'ResNetABI', 'ResNet', 'MobileNetV2', 'VisionTransformer'
]
diff --git a/mmocr/models/textrecog/backbones/vit.py b/mmocr/models/textrecog/backbones/vit.py
new file mode 100644
index 000000000..5787b72e2
--- /dev/null
+++ b/mmocr/models/textrecog/backbones/vit.py
@@ -0,0 +1,106 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Code is migrated from MAE
+# References:
+# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# DeiT: https://github.com/facebookresearch/deit
+
+from functools import partial
+from typing import Optional, Tuple
+
+import timm.models.vision_transformer
+import torch
+import torch.nn as nn
+
+from mmocr.registry import MODELS
+
+
+@MODELS.register_module()
+class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
+ """Vision Transformer migrated from timm.
+
+ Args:
+ global_pool (bool): If True, apply global pooling to the output
+ of the last stage. Default: False.
+ patch_size (int): Patch token size. Default: 8.
+ img_size (tuple[int]): Input image size. Default: (32, 128).
+ embed_dim (int): Number of linear projection output channels.
+ Default: 192.
+ depth (int): Number of blocks. Default: 12.
+ num_heads (int): Number of attention heads. Default: 3.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ Default: 4.
+ qkv_bias (bool): If True, add a learnable bias to query, key,
+ value. Default: True.
+ norm_layer (nn.Module): Normalization layer. Default:
+ partial(nn.LayerNorm, eps=1e-6).
+ pretrained (str): Path to pre-trained checkpoint. Default: None.
+ """
+
+ def __init__(self,
+ global_pool: bool = False,
+ patch_size: int = 8,
+ img_size: Tuple[int, int] = (32, 128),
+ embed_dim: int = 192,
+ depth: int = 12,
+ num_heads: int = 3,
+                 mlp_ratio: float = 4.,
+ qkv_bias: bool = True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+                 pretrained: Optional[str] = None,
+ **kwargs):
+ super(VisionTransformer, self).__init__(
+ patch_size=patch_size,
+ img_size=img_size,
+ embed_dim=embed_dim,
+ depth=depth,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ norm_layer=norm_layer,
+ **kwargs)
+
+ self.global_pool = global_pool
+ if self.global_pool:
+            # `norm_layer` and `embed_dim` are explicit arguments of this
+            # constructor, so use them directly; indexing into `kwargs`
+            # would raise a KeyError here.
+            self.fc_norm = norm_layer(embed_dim)
+
+ del self.norm # remove the original norm
+ self.reset_classifier(0)
+
+ if pretrained:
+ checkpoint = torch.load(pretrained, map_location='cpu')
+
+ print('Load pre-trained checkpoint from: %s' % pretrained)
+ checkpoint_model = checkpoint['model']
+ state_dict = self.state_dict()
+ for k in ['head.weight', 'head.bias']:
+ if k in checkpoint_model and checkpoint_model[
+ k].shape != state_dict[k].shape:
+ print(f'Removing key {k} from pretrained checkpoint')
+ del checkpoint_model[k]
+ # remove key with decoder
+ for k in list(checkpoint_model.keys()):
+ if 'decoder' in k:
+ del checkpoint_model[k]
+ msg = self.load_state_dict(checkpoint_model, strict=False)
+ print(msg)
+
+ def forward_features(self, x: torch.Tensor):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+        # `self.norm` is deleted when global pooling is enabled, so fall
+        # back to `fc_norm` in that case.
+        x = self.fc_norm(x) if self.global_pool else self.norm(x)
+ return x
+
+ def forward(self, x):
+ return self.forward_features(x)
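+
+# A minimal smoke test (assumption: shapes follow the MAERec-S config):
+# >>> import torch
+# >>> vit = VisionTransformer(
+# ...     img_size=(32, 128), patch_size=(4, 4), embed_dim=384,
+# ...     depth=12, num_heads=6)
+# >>> vit(torch.rand(1, 3, 32, 128)).shape  # 8 * 32 patches + 1 cls token
+# torch.Size([1, 257, 384])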
diff --git a/mmocr/models/textrecog/decoders/__init__.py b/mmocr/models/textrecog/decoders/__init__.py
index 1ea5b881f..d68dcc224 100755
--- a/mmocr/models/textrecog/decoders/__init__.py
+++ b/mmocr/models/textrecog/decoders/__init__.py
@@ -5,6 +5,7 @@
from .aster_decoder import ASTERDecoder
from .base import BaseDecoder
from .crnn_decoder import CRNNDecoder
+from .maerec_decoder import MAERecDecoder
from .master_decoder import MasterDecoder
from .nrtr_decoder import NRTRDecoder
from .position_attention_decoder import PositionAttentionDecoder
@@ -19,5 +20,6 @@
'ParallelSARDecoderWithBS', 'NRTRDecoder', 'BaseDecoder',
'SequenceAttentionDecoder', 'PositionAttentionDecoder',
'ABILanguageDecoder', 'ABIVisionDecoder', 'MasterDecoder',
- 'RobustScannerFuser', 'ABIFuser', 'SVTRDecoder', 'ASTERDecoder'
+ 'RobustScannerFuser', 'ABIFuser', 'SVTRDecoder', 'ASTERDecoder',
+ 'MAERecDecoder'
]
diff --git a/mmocr/models/textrecog/decoders/maerec_decoder.py b/mmocr/models/textrecog/decoders/maerec_decoder.py
new file mode 100644
index 000000000..e772676d0
--- /dev/null
+++ b/mmocr/models/textrecog/decoders/maerec_decoder.py
@@ -0,0 +1,256 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Dict, List, Optional, Sequence, Union
+
+import torch
+import torch.nn as nn
+from mmengine.model import ModuleList
+
+from mmocr.models.common import PositionalEncoding, TFDecoderLayer
+from mmocr.models.common.dictionary import Dictionary
+from mmocr.registry import MODELS
+from mmocr.structures import TextRecogDataSample
+from .base import BaseDecoder
+
+
+@MODELS.register_module()
+class MAERecDecoder(BaseDecoder):
+ """Transformer Decoder block with self attention mechanism.
+
+ Args:
+ n_layers (int): Number of attention layers. Defaults to 6.
+ d_embedding (int): Language embedding dimension. Defaults to 512.
+ n_head (int): Number of parallel attention heads. Defaults to 8.
+ d_k (int): Dimension of the key vector. Defaults to 64.
+        d_v (int): Dimension of the value vector. Defaults to 64.
+ d_model (int): Dimension :math:`D_m` of the input from previous model.
+ Defaults to 512.
+ d_inner (int): Hidden dimension of feedforward layers. Defaults to 256.
+ n_position (int): Length of the positional encoding vector. Must be
+ greater than ``max_seq_len``. Defaults to 200.
+ dropout (float): Dropout rate for text embedding, MHSA, FFN. Defaults
+ to 0.1.
+ module_loss (dict, optional): Config to build module_loss. Defaults
+ to None.
+ postprocessor (dict, optional): Config to build postprocessor.
+ Defaults to None.
+ dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
+ the instance of `Dictionary`.
+ max_seq_len (int): Maximum output sequence length :math:`T`. Defaults
+ to 30.
+ init_cfg (dict or list[dict], optional): Initialization configs.
+ """
+
+ def __init__(self,
+ n_layers: int = 6,
+ d_embedding: int = 512,
+ n_head: int = 8,
+ d_k: int = 64,
+ d_v: int = 64,
+ d_model: int = 512,
+ d_inner: int = 256,
+ n_position: int = 200,
+ dropout: float = 0.1,
+ module_loss: Optional[Dict] = None,
+ postprocessor: Optional[Dict] = None,
+ dictionary: Optional[Union[Dict, Dictionary]] = None,
+ max_seq_len: int = 30,
+ init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
+ super().__init__(
+ module_loss=module_loss,
+ postprocessor=postprocessor,
+ dictionary=dictionary,
+ init_cfg=init_cfg,
+ max_seq_len=max_seq_len)
+
+ self.padding_idx = self.dictionary.padding_idx
+ self.start_idx = self.dictionary.start_idx
+ self.max_seq_len = max_seq_len
+
+ self.trg_word_emb = nn.Embedding(
+ self.dictionary.num_classes,
+ d_embedding,
+ padding_idx=self.padding_idx)
+
+ self.position_enc = PositionalEncoding(
+ d_embedding, n_position=n_position)
+ self.dropout = nn.Dropout(p=dropout)
+
+ self.layer_stack = ModuleList([
+ TFDecoderLayer(
+ d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
+ for _ in range(n_layers)
+ ])
+ self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+
+ pred_num_class = self.dictionary.num_classes
+ self.classifier = nn.Linear(d_model, pred_num_class)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def _get_target_mask(self, trg_seq: torch.Tensor) -> torch.Tensor:
+ """Generate mask for target sequence.
+
+ Args:
+ trg_seq (torch.Tensor): Input text sequence. Shape :math:`(N, T)`.
+
+ Returns:
+ Tensor: Target mask. Shape :math:`(N, T, T)`.
+ E.g.:
+ seq = torch.Tensor([[1, 2, 0, 0]]), pad_idx = 0, then
+ target_mask =
+ torch.Tensor([[[True, False, False, False],
+ [True, True, False, False],
+ [True, True, False, False],
+ [True, True, False, False]]])
+ """
+
+ pad_mask = (trg_seq != self.padding_idx).unsqueeze(-2)
+
+ len_s = trg_seq.size(1)
+ subsequent_mask = 1 - torch.triu(
+ torch.ones((len_s, len_s), device=trg_seq.device), diagonal=1)
+ subsequent_mask = subsequent_mask.unsqueeze(0).bool()
+
+ return pad_mask & subsequent_mask
+
+ def _get_source_mask(self, src_seq: torch.Tensor,
+ valid_ratios: Sequence[float]) -> torch.Tensor:
+ """Generate mask for source sequence.
+
+ Args:
+ src_seq (torch.Tensor): Image sequence. Shape :math:`(N, T, C)`.
+ valid_ratios (list[float]): The valid ratio of input image. For
+ example, if the width of the original image is w1 and the width
+ after padding is w2, then valid_ratio = w1/w2. Source mask is
+ used to cover the area of the padding region.
+
+ Returns:
+            Tensor or None: Source mask. Shape :math:`(N, T)`. Padding
+            regions are masked out (False) and valid regions are kept (True).
+ """
+
+ N, T, _ = src_seq.size()
+ mask = None
+ if len(valid_ratios) > 0:
+ mask = src_seq.new_zeros((N, T), device=src_seq.device)
+ for i, valid_ratio in enumerate(valid_ratios):
+ valid_width = min(T, math.ceil(T * valid_ratio))
+ mask[i, :valid_width] = 1
+
+ return mask
+
+ def _attention(self,
+ trg_seq: torch.Tensor,
+ src: torch.Tensor,
+ src_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """A wrapped process for transformer based decoder including text
+ embedding, position embedding, N x transformer decoder and a LayerNorm
+ operation.
+
+ Args:
+            trg_seq (Tensor): Target sequence of shape :math:`(N, T)`.
+            src (Tensor): Source sequence from the encoder, of shape
+                :math:`(N, T, D_m)` where :math:`D_m` is ``d_model``.
+            src_mask (Tensor, optional): Mask for source sequence.
+                Shape :math:`(N, T)`. Defaults to None.
+
+ Returns:
+ Tensor: Output sequence from transformer decoder.
+ Shape :math:`(N, T, D_m)` where :math:`D_m` is ``d_model``.
+ """
+
+ trg_embedding = self.trg_word_emb(trg_seq)
+ trg_pos_encoded = self.position_enc(trg_embedding)
+ trg_mask = self._get_target_mask(trg_seq)
+ tgt_seq = self.dropout(trg_pos_encoded)
+
+ output = tgt_seq
+ for dec_layer in self.layer_stack:
+ output = dec_layer(
+ output,
+ src,
+ self_attn_mask=trg_mask,
+ dec_enc_attn_mask=src_mask)
+ output = self.layer_norm(output)
+
+ return output
+
+ def forward_train(self,
+ feat: Optional[torch.Tensor] = None,
+                      out_enc: Optional[torch.Tensor] = None,
+ data_samples: Sequence[TextRecogDataSample] = None
+ ) -> torch.Tensor:
+ """Forward for training. Source mask will be used here.
+
+ Args:
+            feat (Tensor): Backbone feature of shape :math:`(N, T, D_m)`
+                where :math:`D_m` is ``d_model``. Defaults to None.
+            out_enc (Tensor, optional): Unused; MAERec has no encoder.
+ data_samples (list[TextRecogDataSample]): Batch of
+ TextRecogDataSample, containing gt_text and valid_ratio
+ information. Defaults to None.
+
+ Returns:
+ Tensor: The raw logit tensor. Shape :math:`(N, T, C)` where
+ :math:`C` is ``num_classes``.
+ """
+ valid_ratios = []
+ for data_sample in data_samples:
+ valid_ratios.append(data_sample.get('valid_ratio'))
+ src_mask = self._get_source_mask(feat, valid_ratios)
+ trg_seq = []
+ for data_sample in data_samples:
+ trg_seq.append(data_sample.gt_text.padded_indexes.to(feat.device))
+ trg_seq = torch.stack(trg_seq, dim=0)
+ attn_output = self._attention(trg_seq, feat, src_mask=src_mask)
+ outputs = self.classifier(attn_output)
+
+ return outputs
+
+ def forward_test(self,
+ feat: Optional[torch.Tensor] = None,
+                     out_enc: Optional[torch.Tensor] = None,
+ data_samples: Sequence[TextRecogDataSample] = None
+ ) -> torch.Tensor:
+ """Forward for testing.
+
+ Args:
+            feat (Tensor): Backbone feature of shape :math:`(N, T, D_m)`
+                where :math:`D_m` is ``d_model``. Defaults to None.
+            out_enc (Tensor, optional): Unused; MAERec has no encoder.
+ data_samples (list[TextRecogDataSample]): Batch of
+ TextRecogDataSample, containing gt_text and valid_ratio
+ information. Defaults to None.
+
+ Returns:
+            Tensor: Character probabilities of shape :math:`(N, T, C)`,
+            where :math:`T` is ``max_seq_len`` and :math:`C` is
+            ``num_classes``.
+ """
+ valid_ratios = []
+ for data_sample in data_samples:
+ valid_ratios.append(data_sample.get('valid_ratio'))
+ src_mask = self._get_source_mask(feat, valid_ratios)
+ N = feat.size(0)
+ init_target_seq = torch.full((N, self.max_seq_len + 1),
+ self.padding_idx,
+ device=feat.device,
+ dtype=torch.long)
+ # bsz * seq_len
+ init_target_seq[:, 0] = self.start_idx
+
+ outputs = []
+ for step in range(0, self.max_seq_len):
+ decoder_output = self._attention(
+ init_target_seq, feat, src_mask=src_mask)
+ # bsz * seq_len * C
+ step_result = self.classifier(decoder_output[:, step, :])
+ # bsz * num_classes
+ outputs.append(step_result)
+ _, step_max_index = torch.max(step_result, dim=-1)
+ init_target_seq[:, step + 1] = step_max_index
+
+ outputs = torch.stack(outputs, dim=1)
+
+ return self.softmax(outputs)
diff --git a/mmocr/models/textrecog/recognizers/__init__.py b/mmocr/models/textrecog/recognizers/__init__.py
index d9016492d..d517d6fbd 100644
--- a/mmocr/models/textrecog/recognizers/__init__.py
+++ b/mmocr/models/textrecog/recognizers/__init__.py
@@ -5,6 +5,7 @@
from .crnn import CRNN
from .encoder_decoder_recognizer import EncoderDecoderRecognizer
from .encoder_decoder_recognizer_tta import EncoderDecoderRecognizerTTAModel
+from .maerec import MAERec
from .master import MASTER
from .nrtr import NRTR
from .robust_scanner import RobustScanner
@@ -15,5 +16,5 @@
__all__ = [
'BaseRecognizer', 'EncoderDecoderRecognizer', 'CRNN', 'SARNet', 'NRTR',
'RobustScanner', 'SATRN', 'ABINet', 'MASTER', 'SVTR', 'ASTER',
- 'EncoderDecoderRecognizerTTAModel'
+ 'EncoderDecoderRecognizerTTAModel', 'MAERec'
]
diff --git a/mmocr/models/textrecog/recognizers/maerec.py b/mmocr/models/textrecog/recognizers/maerec.py
new file mode 100644
index 000000000..788978f18
--- /dev/null
+++ b/mmocr/models/textrecog/recognizers/maerec.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmocr.registry import MODELS
+from .encoder_decoder_recognizer import EncoderDecoderRecognizer
+
+
+@MODELS.register_module()
+class MAERec(EncoderDecoderRecognizer):
+    """Implementation of MAERec.
+
+    MAERec couples an MAE-pretrained ViT backbone with a transformer
+    decoder and uses no separate encoder: the backbone feature is fed
+    directly to the decoder.
+    """
diff --git a/mmocr/version.py b/mmocr/version.py
index e83928324..17fdcd360 100644
--- a/mmocr/version.py
+++ b/mmocr/version.py
@@ -1,4 +1,4 @@
# Copyright (c) Open-MMLab. All rights reserved.
-__version__ = '1.0.0'
+__version__ = '1.0.1'
short_version = __version__
diff --git a/model-index.yml b/model-index.yml
index 563372c26..2a227cee0 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -10,6 +10,7 @@ Import:
- configs/textrecog/abinet/metafile.yml
- configs/textrecog/aster/metafile.yml
- configs/textrecog/crnn/metafile.yml
+ - configs/textrecog/maerec/metafile.yml
- configs/textrecog/master/metafile.yml
- configs/textrecog/nrtr/metafile.yml
- configs/textrecog/svtr/metafile.yml
diff --git a/requirements/mminstall.txt b/requirements/mminstall.txt
index fe6b6d945..7f7953038 100644
--- a/requirements/mminstall.txt
+++ b/requirements/mminstall.txt
@@ -1,3 +1,3 @@
-mmcv>=2.0.0rc4,<2.1.0
-mmdet>=3.0.0rc5,<3.1.0
-mmengine>=0.7.0, <1.0.0
+mmcv>=2.0.0rc4,<2.2.0
+mmdet>=3.0.0rc5,<3.2.0
+mmengine>=0.7.0, <1.1.0
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 52a9eec3c..e39d7328b 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -7,3 +7,4 @@ pyclipper
pycocotools
rapidfuzz>=2.0.0
scikit-image
+timm==0.9.2