Skip to content

Commit

Permalink
fix latex_ocr inference (#14498)
Browse files Browse the repository at this point in the history
* add

* update

* add

* add
  • Loading branch information
vivienfanghuagood authored Jan 7, 2025
1 parent ed6fe28 commit 359ab6c
Showing 1 changed file with 17 additions and 20 deletions.
37 changes: 17 additions & 20 deletions ppocr/modeling/backbones/rec_resnetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,19 @@ def __init__(
self.export = is_export
self.eps = eps

self.running_mean = paddle.zeros([self._out_channels], dtype="float32")
self.running_variance = paddle.ones([self._out_channels], dtype="float32")
orin_shape = self.weight.shape
new_weight = F.batch_norm(
self.weight.reshape([1, self._out_channels, -1]),
self.running_mean,
self.running_variance,
momentum=0.0,
epsilon=self.eps,
use_global_stats=False,
).reshape(orin_shape)
self.weight.set_value(new_weight.numpy())

def forward(self, x):
if not self.training:
self.export = True
Expand All @@ -96,30 +109,14 @@ def forward(self, x):
x = pad_same_export(x, self._kernel_size, self._stride, self._dilation)
else:
x = pad_same(x, self._kernel_size, self._stride, self._dilation)
running_mean = paddle.to_tensor([0] * self._out_channels, dtype="float32")
running_variance = paddle.to_tensor([1] * self._out_channels, dtype="float32")
if self.export:
weight = paddle.reshape(
F.batch_norm(
self.weight.reshape([1, self._out_channels, -1]).cast(
paddle.float32
),
running_mean,
running_variance,
momentum=0.0,
epsilon=self.eps,
use_global_stats=False,
),
self.weight.shape,
)
weight = self.weight
else:
weight = paddle.reshape(
F.batch_norm(
self.weight.reshape([1, self._out_channels, -1]).cast(
paddle.float32
),
running_mean,
running_variance,
self.weight.reshape([1, self._out_channels, -1]),
self.running_mean,
self.running_variance,
training=True,
momentum=0.0,
epsilon=self.eps,
Expand Down

0 comments on commit 359ab6c

Please sign in to comment.