forked from PaddlePaddle/PaddleOCR
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrec_robustscanner_head.py
710 lines (589 loc) · 25.4 KB
/
rec_robustscanner_head.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/channel_reduction_encoder.py
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/robust_scanner_decoder.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
class BaseDecoder(nn.Layer):
    """Common dispatch layer shared by the RobustScanner decoders.

    Subclasses implement ``forward_train`` and ``forward_test``; ``forward``
    records the requested mode on ``self.train_mode`` and routes the call to
    one of them.
    """

    def __init__(self, **kwargs):
        super().__init__()

    def forward_train(self, feat, out_enc, targets, img_metas):
        """Training-time decoding. Must be overridden by subclasses."""
        raise NotImplementedError

    def forward_test(self, feat, out_enc, img_metas):
        """Inference-time decoding. Must be overridden by subclasses."""
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                valid_ratios=None,
                word_positions=None,
                train_mode=True):
        """Dispatch to ``forward_train`` or ``forward_test``.

        Args:
            feat: Backbone feature map handed through to the subclass.
            out_enc: Encoder output handed through to the subclass.
            label: Target indices; only forwarded in training mode.
            valid_ratios: Per-sample valid-width ratios, forwarded as-is.
            word_positions: Position indices, forwarded as-is.
            train_mode (bool): Selects ``forward_train`` when True, otherwise
                ``forward_test``. Also stored on ``self.train_mode``.

        Returns:
            Whatever the selected subclass method returns.
        """
        self.train_mode = train_mode
        if not train_mode:
            return self.forward_test(feat, out_enc, valid_ratios,
                                     word_positions)
        return self.forward_train(feat, out_enc, label, valid_ratios,
                                  word_positions)
class ChannelReductionEncoder(nn.Layer):
    """Change the number of feature channels with a 1x1 convolution.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 **kwargs):
        super().__init__()
        # kernel 1 / stride 1 / padding 0: the spatial size is unchanged and
        # only the channel dimension is remapped. Weights use Xavier-normal.
        self.layer = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            weight_attr=nn.initializer.XavierNormal())

    def forward(self, feat):
        """Project ``feat`` to ``out_channels`` channels.

        Args:
            feat (Tensor): Image features with the shape of
                :math:`(N, C_{in}, H, W)`.

        Returns:
            Tensor: A tensor of shape :math:`(N, C_{out}, H, W)`.
        """
        reduced = self.layer(feat)
        return reduced
def masked_fill(x, mask, value):
    """Return ``x`` with every position where ``mask`` is true set to ``value``.

    Args:
        x (Tensor): Source tensor; it is not modified.
        mask (Tensor): Boolean selector passed to ``paddle.where``; presumably
            the same shape as ``x`` (TODO confirm broadcasting needs).
        value: Fill value, materialised with ``x``'s shape and dtype.

    Returns:
        Tensor: ``value`` where ``mask`` is true, ``x`` elsewhere.
    """
    return paddle.where(mask, paddle.full(x.shape, value, x.dtype), x)
class DotProductAttentionLayer(nn.Layer):
    """Dot-product attention over a flattened 2-D feature map.

    Args:
        dim_model (int, optional): When given, logits are scaled by
            ``dim_model ** -0.5``; otherwise the scale factor is 1.0.
    """

    def __init__(self, dim_model=None):
        super().__init__()
        self.scale = dim_model**-0.5 if dim_model is not None else 1.

    def forward(self, query, key, value, h, w, valid_ratios=None):
        """Attend from ``query`` positions to the ``h * w`` key positions.

        Args:
            query (Tensor): Channels-first queries, expected ``(N, C, L_q)``;
                transposed to ``(N, L_q, C)`` before the product.
            key (Tensor): Expected ``(N, C, h * w)``.
            value (Tensor): Expected ``(N, C_v, h * w)``.
            h (int): Height of the map ``key`` was flattened from.
            w (int): Width of the map ``key`` was flattened from.
            valid_ratios (optional): Per-sample fraction of the width that holds
                real image content. For each sample, columns at or beyond the
                rounded valid width are excluded from attention.

        Returns:
            Tensor: Glimpse of shape ``(N, C_v, L_q)``.
        """
        query = paddle.transpose(query, (0, 2, 1))
        logits = paddle.matmul(query, key) * self.scale
        n, c, t = logits.shape
        # Reshape to (n, c, h, w) so padding can be masked column-wise.
        logits = paddle.reshape(logits, [n, c, h, w])
        if valid_ratios is not None:
            # Mask the attention weights of the padded right-hand columns.
            with paddle.fluid.framework._stride_in_no_check_dy2st_diff():
                for i, valid_ratio in enumerate(valid_ratios):
                    valid_width = min(w, int(w * valid_ratio + 0.5))
                    # Keep at least one column: masking every column would
                    # leave an all -inf row, and its softmax is NaN.
                    valid_width = max(1, valid_width)
                    if valid_width < w:
                        logits[i, :, :, valid_width:] = float('-inf')
        # Flatten back to (n, c, t) for the softmax over key positions.
        logits = paddle.reshape(logits, [n, c, t])
        weights = F.softmax(logits, axis=2)
        value = paddle.transpose(value, (0, 2, 1))
        glimpse = paddle.matmul(weights, value)
        glimpse = paddle.transpose(glimpse, (0, 2, 1))
        return glimpse
class SequenceAttentionDecoder(BaseDecoder):
    """Sequence attention decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        rnn_layers (int): Number of RNN layers.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be the
            same as encoder output vector ``out_enc``.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        mask (bool): Whether to mask input features according to the valid
            ratio. NOTE(review): stored on ``self.mask`` but not consulted in
            this class; masking is applied whenever ``valid_ratios`` is given.
        padding_idx (int): The index of `<PAD>`, forwarded to the embedding.
        dropout (float): Dropout rate passed to the LSTM.
        return_feature (bool): Return feature or logits as the result.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Warning:
        This decoder will not predict the final class which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by loss as specified in
        :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
    """

    def __init__(self,
                 num_classes=None,
                 rnn_layers=2,
                 dim_input=512,
                 dim_model=128,
                 max_seq_len=40,
                 start_idx=0,
                 mask=True,
                 padding_idx=None,
                 dropout=0,
                 return_feature=False,
                 encode_value=False):
        super().__init__()
        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.return_feature = return_feature
        self.encode_value = encode_value
        self.max_seq_len = max_seq_len
        self.start_idx = start_idx
        self.mask = mask
        self.embedding = nn.Embedding(
            self.num_classes, self.dim_model, padding_idx=padding_idx)
        self.sequence_layer = nn.LSTM(
            input_size=dim_model,
            hidden_size=dim_model,
            num_layers=rnn_layers,
            time_major=False,
            dropout=dropout)
        self.attention_layer = DotProductAttentionLayer()
        self.prediction = None
        if not self.return_feature:
            # The last class (<PAD>) is never predicted, hence C - 1 outputs.
            pred_num_classes = num_classes - 1
            self.prediction = nn.Linear(
                dim_model if encode_value else dim_input, pred_num_classes)

    def forward_train(self, feat, out_enc, targets, valid_ratios,
                      word_positions=None):
        """Teacher-forced decoding over the whole target sequence.

        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            targets (Tensor): a tensor of shape :math:`(N, T)`. Each element
                is the index of a character.
            valid_ratios (Tensor): valid length ratio of img.
            word_positions: Unused. Accepted so that ``BaseDecoder.forward``,
                which always passes it, can dispatch here.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it would be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """
        tgt_embedding = self.embedding(targets)
        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        _, len_q, c_q = tgt_embedding.shape
        assert c_q == self.dim_model
        assert len_q <= self.max_seq_len
        query, _ = self.sequence_layer(tgt_embedding)
        query = paddle.transpose(query, (0, 2, 1))
        key = paddle.reshape(out_enc, [n, c_enc, h * w])
        if self.encode_value:
            value = key
        else:
            value = paddle.reshape(feat, [n, c_feat, h * w])
        attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
        attn_out = paddle.transpose(attn_out, (0, 2, 1))
        if self.return_feature:
            return attn_out
        out = self.prediction(attn_out)
        return out

    def forward_test(self, feat, out_enc, valid_ratios, word_positions=None):
        """Greedy autoregressive decoding for ``max_seq_len`` steps.

        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor): valid length ratio of img.
            word_positions: Unused. Accepted so that ``BaseDecoder.forward``,
                which always passes it, can dispatch here.

        Returns:
            Tensor: The per-step outputs stacked to :math:`(N, T, C-1)`
            (softmax probabilities when ``return_feature=False``).
        """
        seq_len = self.max_seq_len
        batch_size = feat.shape[0]
        # Every slot starts as <SOS>; slot i + 1 receives the argmax of step i.
        decode_sequence = (paddle.ones(
            (batch_size, seq_len), dtype='int64') * self.start_idx)
        outputs = []
        for i in range(seq_len):
            step_out = self.forward_test_step(feat, out_enc, decode_sequence,
                                              i, valid_ratios)
            outputs.append(step_out)
            max_idx = paddle.argmax(step_out, axis=1, keepdim=False)
            if i < seq_len - 1:
                decode_sequence[:, i + 1] = max_idx
        outputs = paddle.stack(outputs, 1)
        return outputs

    def forward_test_step(self, feat, out_enc, decode_sequence, current_step,
                          valid_ratios):
        """Run one decoding step over the history in ``decode_sequence``.

        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            decode_sequence (Tensor): Shape :math:`(N, T)`. The tensor that
                stores history decoding result.
            current_step (int): Current decoding step.
            valid_ratios (Tensor): valid length ratio of img

        Returns:
            Tensor: Shape :math:`(N, C-1)`. Softmax probabilities of the
            predicted tokens at the current step, or the raw attention
            feature of that step when ``return_feature=True``.
        """
        embed = self.embedding(decode_sequence)
        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        _, _, c_q = embed.shape
        assert c_q == self.dim_model
        query, _ = self.sequence_layer(embed)
        query = paddle.transpose(query, (0, 2, 1))
        key = paddle.reshape(out_enc, [n, c_enc, h * w])
        if self.encode_value:
            value = key
        else:
            value = paddle.reshape(feat, [n, c_feat, h * w])
        # [n, c, l]
        attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
        out = attn_out[:, :, current_step]
        if self.return_feature:
            return out
        out = self.prediction(out)
        # Paddle's softmax takes ``axis``; ``dim`` raises a TypeError.
        out = F.softmax(out, axis=-1)
        return out
class PositionAwareLayer(nn.Layer):
    """Row-wise LSTM followed by a small convolutional mixer.

    Args:
        dim_model (int): LSTM input/hidden size and conv channel count; the
            input feature map is expected to have this many channels.
        rnn_layers (int): Number of stacked LSTM layers.
    """

    def __init__(self, dim_model, rnn_layers=2):
        super().__init__()
        self.dim_model = dim_model
        self.rnn = nn.LSTM(
            input_size=dim_model,
            hidden_size=dim_model,
            num_layers=rnn_layers,
            time_major=False)
        # Two 3x3 convs (stride 1, padding 1) keep the spatial size unchanged.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.mixer = nn.Sequential(
            nn.Conv2D(dim_model, dim_model, **conv_kwargs),
            nn.ReLU(),
            nn.Conv2D(dim_model, dim_model, **conv_kwargs))

    def forward(self, img_feature):
        """Encode each feature-map row with the LSTM, then mix spatially.

        Args:
            img_feature (Tensor): Shape ``(N, C, H, W)`` with ``C`` equal to
                ``dim_model``.

        Returns:
            Tensor: Shape ``(N, C, H, W)``.
        """
        n, c, h, w = img_feature.shape
        # Each of the N * H rows becomes an independent length-W sequence.
        rows = paddle.reshape(
            paddle.transpose(img_feature, (0, 2, 3, 1)), (n * h, w, c))
        rows, _ = self.rnn(rows)
        # Restore the (N, C, H, W) layout expected by the conv mixer.
        restored = paddle.transpose(
            paddle.reshape(rows, (n, h, w, c)), (0, 3, 1, 2))
        return self.mixer(restored)
class PositionAttentionDecoder(BaseDecoder):
    """Position attention decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        rnn_layers (int): Number of RNN layers.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be the
            same as encoder output vector ``out_enc``.
        max_seq_len (int): Maximum output sequence length :math:`T`. The
            position embedding has ``T + 1`` entries, so position indices must
            lie in ``[0, T]``.
        mask (bool): Whether to mask input features according to the valid
            ratio. NOTE(review): stored on ``self.mask`` but not consulted in
            this class; masking is applied whenever ``valid_ratios`` is given.
        return_feature (bool): Return feature or logits as the result.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Warning:
        This decoder will not predict the final class which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by loss
    """

    def __init__(self,
                 num_classes=None,
                 rnn_layers=2,
                 dim_input=512,
                 dim_model=128,
                 max_seq_len=40,
                 mask=True,
                 return_feature=False,
                 encode_value=False):
        super().__init__()
        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.max_seq_len = max_seq_len
        self.return_feature = return_feature
        self.encode_value = encode_value
        self.mask = mask
        self.embedding = nn.Embedding(self.max_seq_len + 1, self.dim_model)
        self.position_aware_module = PositionAwareLayer(
            self.dim_model, rnn_layers)
        self.attention_layer = DotProductAttentionLayer()
        self.prediction = None
        if not self.return_feature:
            # The last class (<PAD>) is never predicted, hence C - 1 outputs.
            pred_num_classes = num_classes - 1
            self.prediction = nn.Linear(
                dim_model if encode_value else dim_input, pred_num_classes)

    def _get_position_index(self, length, batch_size):
        """Return an int64 tensor of shape ``(batch_size, length)`` in which
        every row is ``0, 1, ..., length - 1``."""
        position_index = paddle.arange(0, end=length, step=1, dtype='int64')
        return paddle.tile(
            paddle.unsqueeze(position_index, axis=0), [batch_size, 1])

    def _attend(self, feat, out_enc, position_index, valid_ratios):
        """Shared attention path of ``forward_train`` and ``forward_test``.

        Queries are embeddings of ``position_index``, keys come from the
        position-aware encoding of ``out_enc``, and values come from ``out_enc``
        (``encode_value``) or ``feat``.

        Returns:
            Tensor: Attention output of shape ``(N, T, D_v)`` before any
            prediction projection.
        """
        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        position_out_enc = self.position_aware_module(out_enc)
        query = self.embedding(position_index)
        query = paddle.transpose(query, (0, 2, 1))
        key = paddle.reshape(position_out_enc, (n, c_enc, h * w))
        if self.encode_value:
            value = paddle.reshape(out_enc, (n, c_enc, h * w))
        else:
            value = paddle.reshape(feat, (n, c_feat, h * w))
        attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
        return paddle.transpose(attn_out, (0, 2, 1))  # [n, len_q, dim_v]

    def _project(self, attn_out):
        """Apply the class projection unless ``return_feature`` is set."""
        if self.return_feature:
            return attn_out
        return self.prediction(attn_out)

    def forward_train(self, feat, out_enc, targets, valid_ratios,
                      position_index):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            targets (Tensor): A tensor of shape :math:`(N, T)`. Each element
                is the index of a character. Only its length is checked here.
            valid_ratios (Tensor): valid length ratio of img.
            position_index (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it will be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """
        _, len_q = targets.shape
        assert len_q <= self.max_seq_len
        attn_out = self._attend(feat, out_enc, position_index, valid_ratios)
        return self._project(attn_out)

    def forward_test(self, feat, out_enc, valid_ratios, position_index):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor): valid length ratio of img
            position_index (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it would be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """
        attn_out = self._attend(feat, out_enc, position_index, valid_ratios)
        return self._project(attn_out)
class RobustScannerFusionLayer(nn.Layer):
    """Fuse two equally-shaped glimpses with a linear layer and a GLU gate.

    Args:
        dim_model (int): Feature size of each input; the linear layer maps
            ``2 * dim_model`` features to ``2 * dim_model`` and the GLU halves
            them back to ``dim_model``.
        dim (int): Axis used both for concatenation and for the GLU split.
    """

    def __init__(self, dim_model, dim=-1):
        super().__init__()
        self.dim_model = dim_model
        self.dim = dim
        self.linear_layer = nn.Linear(dim_model * 2, dim_model * 2)

    def forward(self, x0, x1):
        """Concatenate ``x0`` and ``x1`` along ``dim``, project, then gate."""
        assert x0.shape == x1.shape
        projected = self.linear_layer(paddle.concat([x0, x1], self.dim))
        return F.glu(projected, self.dim)
class RobustScannerDecoder(BaseDecoder):
    """Decoder for RobustScanner.

    Combines a hybrid (sequence) attention branch and a position attention
    branch, fuses their glimpses, and projects to class scores.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be the
            same as encoder output vector ``out_enc``.
        hybrid_decoder_rnn_layers (int): LSTM layers of the hybrid branch.
        hybrid_decoder_dropout (float): Dropout of the hybrid branch LSTM.
        position_decoder_rnn_layers (int): LSTM layers of the position branch.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        mask (bool): Forwarded to both branches.
        padding_idx (int): The index of `<PAD>`.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Warning:
        This decoder will not predict the final class which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by loss as specified in
        :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
    """

    def __init__(self,
                 num_classes=None,
                 dim_input=512,
                 dim_model=128,
                 hybrid_decoder_rnn_layers=2,
                 hybrid_decoder_dropout=0,
                 position_decoder_rnn_layers=2,
                 max_seq_len=40,
                 start_idx=0,
                 mask=True,
                 padding_idx=None,
                 encode_value=False):
        super().__init__()
        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.max_seq_len = max_seq_len
        self.encode_value = encode_value
        self.start_idx = start_idx
        self.padding_idx = padding_idx
        self.mask = mask
        # Both branches return features; the class projection lives here.
        self.hybrid_decoder = SequenceAttentionDecoder(
            num_classes=num_classes,
            rnn_layers=hybrid_decoder_rnn_layers,
            dim_input=dim_input,
            dim_model=dim_model,
            max_seq_len=max_seq_len,
            start_idx=start_idx,
            mask=mask,
            padding_idx=padding_idx,
            dropout=hybrid_decoder_dropout,
            encode_value=encode_value,
            return_feature=True)
        self.position_decoder = PositionAttentionDecoder(
            num_classes=num_classes,
            rnn_layers=position_decoder_rnn_layers,
            dim_input=dim_input,
            dim_model=dim_model,
            max_seq_len=max_seq_len,
            mask=mask,
            encode_value=encode_value,
            return_feature=True)
        glimpse_dim = dim_model if encode_value else dim_input
        self.fusion_module = RobustScannerFusionLayer(glimpse_dim)
        # The last class (<PAD>) is never predicted, hence C - 1 outputs.
        self.prediction = nn.Linear(glimpse_dim, num_classes - 1)

    def forward_train(self, feat, out_enc, target, valid_ratios, word_positions):
        """Teacher-forced decoding.

        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            target (Tensor): Shape :math:`(N, T)`; each element is the index
                of a character.
            valid_ratios (Tensor): Valid width ratio of each image, forwarded
                to both branches.
            word_positions (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)`.
        """
        hybrid = self.hybrid_decoder.forward_train(
            feat, out_enc, target, valid_ratios)
        positional = self.position_decoder.forward_train(
            feat, out_enc, target, valid_ratios, word_positions)
        return self.prediction(self.fusion_module(hybrid, positional))

    def forward_test(self, feat, out_enc, valid_ratios, word_positions):
        """Greedy decoding for ``max_seq_len`` steps.

        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor): Valid width ratio of each image, forwarded
                to both branches.
            word_positions (Tensor): The position of each word.

        Returns:
            Tensor: Softmax probabilities stacked to :math:`(N, T, C-1)`.
        """
        seq_len = self.max_seq_len
        batch_size = feat.shape[0]
        # Every slot starts as <SOS>; slot step + 1 is overwritten with the
        # greedy prediction of step `step` before the next step runs.
        decode_sequence = (paddle.ones(
            (batch_size, seq_len), dtype='int64') * self.start_idx)
        # The position branch does not read decode_sequence, so it is computed
        # once for all steps.
        position_glimpse = self.position_decoder.forward_test(
            feat, out_enc, valid_ratios, word_positions)
        step_probs = []
        for step in range(seq_len):
            hybrid_step = self.hybrid_decoder.forward_test_step(
                feat, out_enc, decode_sequence, step, valid_ratios)
            fused = self.fusion_module(hybrid_step,
                                       position_glimpse[:, step, :])
            probs = F.softmax(self.prediction(fused), -1)
            step_probs.append(probs)
            if step + 1 < seq_len:
                decode_sequence[:, step + 1] = paddle.argmax(
                    probs, axis=1, keepdim=False)
        return paddle.stack(step_probs, 1)
class RobustScannerHead(nn.Layer):
    """RobustScanner recognition head.

    A 1x1 channel-reduction encoder followed by ``RobustScannerDecoder``.

    Args:
        out_channels (int): Number of classes :math:`C`
            (90 + unknown + start + padding in the reference config).
        in_channels (int): Channels of the incoming feature map (:math:`D_i`).
        enc_outchannles (int): Channels after reduction (:math:`D_m`). The
            misspelt name is kept for config compatibility.
        hybrid_dec_rnn_layers (int): LSTM layers of the hybrid branch.
        hybrid_dec_dropout (float): Dropout of the hybrid branch LSTM.
        position_dec_rnn_layers (int): LSTM layers of the position branch.
        start_idx (int): Index of `<SOS>`.
        max_text_length (int): Maximum decoded length :math:`T`.
        mask (bool): Forwarded to the decoder.
        padding_idx (int): Index of `<PAD>`.
        encode_value (bool): Forwarded to the decoder.
    """

    def __init__(self,
                 out_channels,  # 90 + unknown + start + padding
                 in_channels,
                 enc_outchannles=128,
                 hybrid_dec_rnn_layers=2,
                 hybrid_dec_dropout=0,
                 position_dec_rnn_layers=2,
                 start_idx=0,
                 max_text_length=40,
                 mask=True,
                 padding_idx=None,
                 encode_value=False,
                 **kwargs):
        super(RobustScannerHead, self).__init__()
        # encoder module
        self.encoder = ChannelReductionEncoder(
            in_channels=in_channels, out_channels=enc_outchannles)
        # decoder module
        self.decoder = RobustScannerDecoder(
            num_classes=out_channels,
            dim_input=in_channels,
            dim_model=enc_outchannles,
            hybrid_decoder_rnn_layers=hybrid_dec_rnn_layers,
            hybrid_decoder_dropout=hybrid_dec_dropout,
            position_decoder_rnn_layers=position_dec_rnn_layers,
            max_seq_len=max_text_length,
            start_idx=start_idx,
            mask=mask,
            padding_idx=padding_idx,
            encode_value=encode_value)

    def forward(self, inputs, targets=None):
        '''Run the encoder and decoder.

        Args:
            inputs (Tensor): Feature map of shape ``(N, in_channels, H, W)``.
            targets (list | tuple | None): ``[label, ..., valid_ratio,
                word_positions]``. The head reads ``targets[0]`` (label,
                training only), ``targets[-2]`` (valid ratios, when there are
                at least two entries) and ``targets[-1]`` (word positions).
                In eval mode ``targets`` may be None or empty. Word positions
                then default to ``0 .. max_text_length - 1`` for every sample,
                and valid-ratio masking is skipped.

        Returns:
            Tensor: Decoder output of shape ``(N, T, C-1)``.

        Raises:
            TypeError: If ``targets`` is None in training mode.
        '''
        out_enc = self.encoder(inputs)
        valid_ratios = None
        word_positions = None
        if targets is not None and len(targets) > 0:
            word_positions = targets[-1]
            if len(targets) > 1:
                valid_ratios = targets[-2]

        if self.training:
            if targets is None:
                raise TypeError(
                    'RobustScannerHead needs targets '
                    '[label, ..., valid_ratio, word_positions] in training')
            label = paddle.to_tensor(targets[0], dtype='int64')
            return self.decoder(
                inputs, out_enc, label, valid_ratios, word_positions)

        if word_positions is None:
            # Default positions 0..T-1 per sample; this relies on the
            # dynamic-graph batch size from inputs.shape[0].
            word_positions = self.decoder.position_decoder._get_position_index(
                self.decoder.max_seq_len, inputs.shape[0])
        return self.decoder(
            inputs,
            out_enc,
            label=None,
            valid_ratios=valid_ratios,
            word_positions=word_positions,
            train_mode=False)