diff --git a/python/test/function/test_clip_grad_by_norm.py b/python/test/function/test_clip_grad_by_norm.py
index ace0ad9a6..42b32299e 100644
--- a/python/test/function/test_clip_grad_by_norm.py
+++ b/python/test/function/test_clip_grad_by_norm.py
@@ -27,9 +27,9 @@ def ref_clip_grad_by_norm(x, clip_norm, axes):
 
 def ref_grad_clip_by_norm(x, dy, clip_norm, axes):
     dx = np.copy(dy)
-    dx = clip_norm * dy / \
-        np.broadcast_to(
-            np.sqrt(np.sum(dy**2, axis=axes, keepdims=True)), dy.shape)
+    norm = np.sqrt(np.sum(dy**2, axis=axes, keepdims=True))
+    norm[norm < clip_norm] = clip_norm
+    dx = clip_norm * dy / np.broadcast_to(norm, dy.shape)
     return dx.flatten()
 
 
diff --git a/src/nbla/function/generic/clip_grad_by_norm.cpp b/src/nbla/function/generic/clip_grad_by_norm.cpp
index 8c0590381..8d85adb85 100644
--- a/src/nbla/function/generic/clip_grad_by_norm.cpp
+++ b/src/nbla/function/generic/clip_grad_by_norm.cpp
@@ -56,7 +56,7 @@ template <typename T, bool accum>
 void clip_grad_by_norm_backward_cpu(int size, T clip_norm_grad, T *dx,
                                     const T *dy, const T *m) {
   for (int s = 0; s < size; ++s) {
-    T _dx = clip_norm_grad * dy[s] / std::sqrt(m[s]);
+    T _dx = clip_norm_grad * dy[s] / std::max(std::sqrt(m[s]), clip_norm_grad);
     accum ? dx[s] += _dx : dx[s] = _dx;
   }
 }