Skip to content

Commit

Permalink
[NPU] Add NaiveSyncBatchNorm for NPU
Browse files Browse the repository at this point in the history
  • Loading branch information
will-jl944 committed Aug 8, 2024
1 parent ac43b3e commit 6353a4a
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 61 deletions.
115 changes: 56 additions & 59 deletions ppocr/modeling/architectures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@

from .base_model import BaseModel
from .distillation_model import DistillationModel
from .custom_device_layers import NaiveSyncBatchNorm

__all__ = ["build_model", "apply_to_static"]
__all__ = ["build_model", "apply_to_static", "NaiveSyncBatchNorm"]


def build_model(config):
Expand All @@ -38,81 +39,77 @@ def build_model(config):
def apply_to_static(model, config, logger):
if config["Global"].get("to_static", False) is not True:
return model
assert "d2s_train_image_shape" in config[
"Global"], "d2s_train_image_shape must be assigned for static training mode..."
supported_list = [
"DB", "SVTR_LCNet", "TableMaster", "LayoutXLM", "SLANet", "SVTR"
]
assert (
"d2s_train_image_shape" in config["Global"]
), "d2s_train_image_shape must be assigned for static training mode..."
supported_list = ["DB", "SVTR_LCNet", "TableMaster", "LayoutXLM", "SLANet", "SVTR"]
if config["Architecture"]["algorithm"] in ["Distillation"]:
algo = list(config["Architecture"]["Models"].values())[0]["algorithm"]
else:
algo = config["Architecture"]["algorithm"]
assert algo in supported_list, f"algorithms that supports static training must in in {supported_list} but got {algo}"
assert (
algo in supported_list
), f"algorithms that supports static training must in in {supported_list} but got {algo}"

specs = [
InputSpec(
[None] + config["Global"]["d2s_train_image_shape"], dtype='float32')
InputSpec([None] + config["Global"]["d2s_train_image_shape"], dtype="float32")
]

if algo == "SVTR_LCNet":
specs.append([
InputSpec(
[None, config["Global"]["max_text_length"]],
dtype='int64'), InputSpec(
[None, config["Global"]["max_text_length"]], dtype='int64'),
InputSpec(
[None], dtype='int64'), InputSpec(
[None], dtype='float64')
])
specs.append(
[
InputSpec([None, config["Global"]["max_text_length"]], dtype="int64"),
InputSpec([None, config["Global"]["max_text_length"]], dtype="int64"),
InputSpec([None], dtype="int64"),
InputSpec([None], dtype="float64"),
]
)
elif algo == "TableMaster":
specs.append(
[
InputSpec([None, config["Global"]["max_text_length"]], dtype="int64"),
InputSpec(
[None, config["Global"]["max_text_length"]], dtype='int64'),
InputSpec(
[None, config["Global"]["max_text_length"], 4],
dtype='float32'),
[None, config["Global"]["max_text_length"], 4], dtype="float32"
),
InputSpec(
[None, config["Global"]["max_text_length"], 1],
dtype='float32'),
InputSpec(
[None, 6], dtype='float32'),
])
[None, config["Global"]["max_text_length"], 1], dtype="float32"
),
InputSpec([None, 6], dtype="float32"),
]
)
elif algo == "LayoutXLM":
specs = [[
InputSpec(
shape=[None, 512], dtype="int64"), # input_ids
InputSpec(
shape=[None, 512, 4], dtype="int64"), # bbox
InputSpec(
shape=[None, 512], dtype="int64"), # attention_mask
InputSpec(
shape=[None, 512], dtype="int64"), # token_type_ids
InputSpec(
shape=[None, 3, 224, 224], dtype="float32"), # image
InputSpec(
shape=[None, 512], dtype="int64"), # label
]]
specs = [
[
InputSpec(shape=[None, 512], dtype="int64"), # input_ids
InputSpec(shape=[None, 512, 4], dtype="int64"), # bbox
InputSpec(shape=[None, 512], dtype="int64"), # attention_mask
InputSpec(shape=[None, 512], dtype="int64"), # token_type_ids
InputSpec(shape=[None, 3, 224, 224], dtype="float32"), # image
InputSpec(shape=[None, 512], dtype="int64"), # label
]
]
elif algo == "SLANet":
specs.append([
InputSpec(
[None, config["Global"]["max_text_length"] + 2], dtype='int64'),
InputSpec(
[None, config["Global"]["max_text_length"] + 2, 4],
dtype='float32'),
InputSpec(
[None, config["Global"]["max_text_length"] + 2, 1],
dtype='float32'),
InputSpec(
[None, 6], dtype='float64'),
])
specs.append(
[
InputSpec(
[None, config["Global"]["max_text_length"] + 2], dtype="int64"
),
InputSpec(
[None, config["Global"]["max_text_length"] + 2, 4], dtype="float32"
),
InputSpec(
[None, config["Global"]["max_text_length"] + 2, 1], dtype="float32"
),
InputSpec([None, 6], dtype="float64"),
]
)
elif algo == "SVTR":
specs.append([
InputSpec(
[None, config["Global"]["max_text_length"]], dtype='int64'),
InputSpec(
[None], dtype='int64')
])
specs.append(
[
InputSpec([None, config["Global"]["max_text_length"]], dtype="int64"),
InputSpec([None], dtype="int64"),
]
)
model = to_static(model, input_spec=specs)
logger.info("Successfully to apply @to_static with specs: {}".format(specs))
return model
125 changes: 125 additions & 0 deletions ppocr/modeling/architectures/custom_device_layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import paddle
import paddle.nn as nn
import paddle.distributed as dist

__all__ = ["NaiveSyncBatchNorm"]


class _AllReduce(paddle.autograd.PyLayer):

@staticmethod
def forward(ctx, input):
input_list = [paddle.zeros_like(input) for k in range(dist.get_world_size())]
# Use allgather instead of allreduce since I don't trust in-place operations ..
dist.all_gather(input_list, input, sync_op=True)
inputs = paddle.stack(input_list, axis=0)
return paddle.sum(inputs, axis=0)

@staticmethod
def backward(ctx, grad_output):
dist.all_reduce(grad_output, sync_op=True)
return grad_output


def differentiable_all_reduce(input):
"""
Differentiable counterpart of `dist.all_reduce`.
"""
if (
not dist.is_available()
or not dist.is_initialized()
or dist.get_world_size() == 1
):
return input
return _AllReduce.apply(input)


class NaiveSyncBatchNorm(nn.BatchNorm2D):

def __init__(self, *args, stats_mode="", **kwargs):
super().__init__(*args, **kwargs)
assert stats_mode in ["", "N"]
self._stats_mode = stats_mode

def forward(self, input):
if dist.get_world_size() == 1 or not self.training:
return super().forward(input)

B, C = input.shape[0], input.shape[1]

mean = paddle.mean(input, axis=[0, 2, 3])
meansqr = paddle.mean(input * input, axis=[0, 2, 3])

if self._stats_mode == "":
assert (
B > 0
), 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
vec = paddle.concat([mean, meansqr], axis=0)
vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size())
mean, meansqr = paddle.split(vec, [C, C])
momentum = (
1 - self._momentum
) # NOTE: paddle has reverse momentum defination
else:
if B == 0:
vec = paddle.zeros([2 * C + 1], dtype=mean.dtype)
vec = vec + input.sum() # make sure there is gradient w.r.t input
else:
vec = paddle.concat(
[
mean,
meansqr,
paddle.ones([1], dtype=mean.dtype),
],
axis=0,
)
vec = differentiable_all_reduce(vec * B)

total_batch = vec[-1].detach()
momentum = total_batch.clip(max=1) * (
1 - self._momentum
) # no update if total_batch is 0
mean, meansqr, _ = paddle.split(
vec / total_batch.clip(min=1), [C, C, int(vec.shape[0] - 2 * C)]
) # avoid div-by-zero

var = meansqr - mean * mean
invstd = paddle.rsqrt(var + self._epsilon)
scale = self.weight * invstd
bias = self.bias - mean * scale
scale = scale.reshape([1, -1, 1, 1])
bias = bias.reshape([1, -1, 1, 1])

tmp_mean = self._mean + momentum * (mean.detach() - self._mean)
self._mean.set_value(tmp_mean)
tmp_variance = self._variance + (momentum * (var.detach() - self._variance))
self._variance.set_value(tmp_variance)
ret = input * scale + bias
return ret

@classmethod
def convert_sync_batchnorm(cls, layer):
layer_output = layer
if isinstance(layer, nn.BatchNorm2D):

layer_output = NaiveSyncBatchNorm(
layer._num_features,
layer._momentum,
layer._epsilon,
layer._weight_attr,
layer._bias_attr,
layer._data_format,
layer._name,
)

if layer._weight_attr is not False and layer._bias_attr is not False:
with paddle.no_grad():
layer_output.weight = layer.weight
layer_output.bias = layer.bias
layer_output._mean = layer._mean
layer_output._variance = layer._variance

for name, sublayer in layer.named_children():
layer_output.add_sublayer(name, cls.convert_sync_batchnorm(sublayer))
del layer
return layer_output
9 changes: 7 additions & 2 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import set_seed
from ppocr.modeling.architectures import apply_to_static
from ppocr.modeling.architectures import NaiveSyncBatchNorm
import tools.program as program

dist.get_world_size()
Expand Down Expand Up @@ -138,8 +139,12 @@ def main(config, device, logger, vdl_writer):

use_sync_bn = config["Global"].get("use_sync_bn", False)
if use_sync_bn:
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
logger.info("convert_sync_batchnorm")
if "npu" in paddle.get_device() and dist.ParallelEnv().nranks > 1:
model = NaiveSyncBatchNorm.convert_sync_batchnorm(model)
logger.info("convert_sync_batchnorm for NPU")
else:
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
logger.info("convert_sync_batchnorm")

model = apply_to_static(model, config, logger)

Expand Down

0 comments on commit 6353a4a

Please sign in to comment.