optimize hardsigmoid grad decomp (PaddlePaddle#70083)
* optimize hardsigmoid grad decomp

* fix bug
phlrain authored Dec 10, 2024
1 parent fe9c296 commit 15e4cb7
Showing 3 changed files with 20 additions and 0 deletions.
@@ -168,6 +168,7 @@
CUSTOM_VJP = [
'bce_loss_grad',
'gelu_grad',
'hardsigmoid_grad',
'hardswish_grad',
'leaky_relu_grad',
'mean_grad',
1 change: 1 addition & 0 deletions paddle/fluid/primitive/codegen/decomp_vjp_gen.py
@@ -146,6 +146,7 @@
'dropout_grad',
'gelu_grad',
'group_norm_grad',
'hardsigmoid_grad',
'hardswish_grad',
'instance_norm_grad',
'layer_norm_grad',
18 changes: 18 additions & 0 deletions paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h
@@ -1823,6 +1823,24 @@ void tile_grad(const Tensor& x,
}
}

template <typename T>
void hardsigmoid_grad(const Tensor& out,
                      const Tensor& out_grad,
                      float slope,
                      float offset,
                      Tensor* x_grad) {
  if (x_grad) {
    Tensor zeros = full_scalar<T>(0.0, out.dtype());
    Tensor one = full_scalar<T>(1.0, out.dtype());
    auto mask_gt = greater_than<T>(out, zeros);
    auto mask_lt = less_than<T>(out, one);
    auto mask = bitwise_and<T>(mask_gt, mask_lt);
    Tensor slope_tensor = full_scalar<T>(slope, out.dtype());
    auto res = cast<T>(mask, out.dtype()) * slope_tensor * out_grad;
    set_output<T>(res, x_grad);
  }
}

template <typename T>
void hardswish_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
  if (x_grad) {
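For reference, the decomposed backward added in details.h follows the usual hardsigmoid definition out = clip(slope * x + offset, 0, 1): the gradient is slope * out_grad wherever the forward output lies strictly inside (0, 1), and zero where the output was clipped. Below is a minimal NumPy sketch of that rule (not Paddle code; the slope/offset defaults are illustrative), checked against a central finite difference.

import numpy as np

def hardsigmoid(x, slope=1.0 / 6.0, offset=0.5):
    # forward: out = clip(slope * x + offset, 0, 1)
    return np.clip(slope * x + offset, 0.0, 1.0)

def hardsigmoid_grad(out, out_grad, slope=1.0 / 6.0):
    # decomposed backward: pass the gradient through only where the
    # forward output is strictly inside (0, 1), scaled by slope
    mask = np.logical_and(out > 0.0, out < 1.0)
    return mask.astype(out.dtype) * slope * out_grad

# check against a central finite difference, away from the clip kinks
x = np.array([-5.0, -2.0, -0.5, 0.0, 1.5, 5.0])
out = hardsigmoid(x)
analytic = hardsigmoid_grad(out, np.ones_like(x))
eps = 1e-4
numeric = (hardsigmoid(x + eps) - hardsigmoid(x - eps)) / (2.0 * eps)
print(np.max(np.abs(analytic - numeric)))  # ~0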
