Commit 8558e58: "fix"
ShawnXuan committed Jan 8, 2025 (1 parent: 5737916)
Showing 6 changed files with 27 additions and 36 deletions.
5 changes: 2 additions & 3 deletions oneflow/core/autograd/gradient_funcs/layer_norm.cpp
@@ -116,11 +116,10 @@ Maybe<void> LayerNorm::Apply(const LayerNormCaptureState* ctx, const TensorTuple
   if (ctx->x_requires_grad) {
     if (ctx->scale) {
       std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
-      *in_grads = *JUST(functional::FuseLayerNormAffineGrad(
+      *in_grads = *JUST(functional::FuseLayerNormGrad(
           dy, x, mean, inv_variance, gamma, begin_norm_axis, begin_params_axis, ctx->epsilon));
     } else {
-      *in_grads = *JUST(functional::FuseLayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis,
-                                                      begin_params_axis, ctx->epsilon));
+      UNIMPLEMENTED();
     }
   }
 } else {

6 changes: 1 addition & 5 deletions oneflow/core/functional/functional_api.yaml
@@ -1559,11 +1559,7 @@
   bind_python: False
 
 - name: "fuse_layer_norm_grad"
-  signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => FuseLayerNormGrad"
-  bind_python: False
-
-- name: "fuse_layer_norm_affine_grad"
-  signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => FuseLayerNormAffineGrad"
+  signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => FuseLayerNormGrad"
   bind_python: False
 
 - name: "layer_norm_param_grad"

29 changes: 1 addition & 28 deletions oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -986,34 +986,6 @@ class LayerNormAffineGradFunctor {
 class FuseLayerNormGradFunctor {
  public:
   FuseLayerNormGradFunctor() {
-    op_ = CHECK_JUST(one::OpBuilder("fuse_layer_norm_grad")
-                         .Input("dy")
-                         .Input("x")
-                         .Input("mean")
-                         .Input("inv_variance")
-                         .Output("dx")
-                         .Output("gamma_diff")
-                         .Output("beta_diff")
-                         .Build());
-  }
-  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,
-                                const std::shared_ptr<one::Tensor>& x,
-                                const std::shared_ptr<one::Tensor>& mean,
-                                const std::shared_ptr<one::Tensor>& inv_variance,
-                                const int64_t& begin_norm_axis, const int64_t& begin_params_axis,
-                                const double& epsilon) const {
-    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "begin_params_axis", "epsilon");
-    attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon);
-    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x, mean, inv_variance}, attrs);
-  }
-
- private:
-  std::shared_ptr<OpExpr> op_;
-};
-
-class FuseLayerNormAffineGradFunctor {
- public:
-  FuseLayerNormAffineGradFunctor() {
     op_ = CHECK_JUST(one::OpBuilder("fuse_layer_norm_grad")
                          .Input("dy")
                          .Input("x")
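Note: the rest of the kept functor is collapsed in this view. As an illustration only (inferred from the YAML signature above and from the deleted functor, not copied from the file), the consolidated FuseLayerNormGradFunctor plausibly continues along these lines, with gamma wired through as an extra input:

// Illustrative sketch, not the file's actual contents: gamma is assumed to be an
// additional input of the fused op, matching the consolidated YAML signature.
class FuseLayerNormGradFunctor {
 public:
  FuseLayerNormGradFunctor() {
    op_ = CHECK_JUST(one::OpBuilder("fuse_layer_norm_grad")
                         .Input("dy")
                         .Input("x")
                         .Input("mean")
                         .Input("inv_variance")
                         .Input("gamma")  // assumed extra input
                         .Output("dx")
                         .Output("gamma_diff")
                         .Output("beta_diff")
                         .Build());
  }
  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,
                                const std::shared_ptr<one::Tensor>& x,
                                const std::shared_ptr<one::Tensor>& mean,
                                const std::shared_ptr<one::Tensor>& inv_variance,
                                const std::shared_ptr<one::Tensor>& gamma,
                                const int64_t& begin_norm_axis, const int64_t& begin_params_axis,
                                const double& epsilon) const {
    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "begin_params_axis", "epsilon");
    attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon);
    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x, mean, inv_variance, gamma}, attrs);
  }

 private:
  std::shared_ptr<OpExpr> op_;
};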
@@ -1765,6 +1737,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::LayerNormGradFunctor>("LayerNormGrad");
   m.add_functor<impl::LayerNormAffineGradFunctor>("LayerNormAffineGrad");
   m.add_functor<impl::LayerNormParamGradFunctor>("LayerNormParamGrad");
+  m.add_functor<impl::FuseLayerNormGradFunctor>("FuseLayerNormGrad");
   m.add_functor<impl::GroupNormGradFunctor>("GroupNormGrad");
   m.add_functor<impl::GroupNormParamGradFunctor>("GroupNormParamGrad");
   m.add_functor<impl::BroadcastMatmulGradBFunctor>("BroadcastMatmulGradB");

2 changes: 2 additions & 0 deletions oneflow/core/job_rewriter/auto_mixed_precision.cpp
@@ -359,6 +359,8 @@ REGISTER_NO_CAST_REGISTRY("layer_norm_grad", "mean", 0)
 REGISTER_NO_CAST_REGISTRY("layer_norm_grad", "inv_variance", 0)
 REGISTER_NO_CAST_REGISTRY("layer_norm_param_grad", "mean", 0)
 REGISTER_NO_CAST_REGISTRY("layer_norm_param_grad", "inv_variance", 0)
+REGISTER_NO_CAST_REGISTRY("fuse_layer_norm_grad", "mean", 0)
+REGISTER_NO_CAST_REGISTRY("fuse_layer_norm_grad", "inv_variance", 0)
 
 }  // namespace
 
1 change: 1 addition & 0 deletions oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp
@@ -95,6 +95,7 @@ const AMPList& AutoMixedPrecisionLists::GrayList() {
                        "layer_norm",
                        "layer_norm_param_grad",
                        "layer_norm_grad",
+                       "fuse_layer_norm_grad",
                        "skip_layer_norm",
                        "rms_norm",
                        "rms_norm_grad",

20 changes: 20 additions & 0 deletions oneflow/user/kernels/layer_norm_cpu_kernel.cpp
@@ -57,6 +57,26 @@ class LayerNormGradCpuKernel final : public user_op::OpKernel {
 REGISTER_LAYER_NORM_GRAD_CPU_KERNEL(float)
 REGISTER_LAYER_NORM_GRAD_CPU_KERNEL(double)
 
+template<typename T>
+class FuseLayerNormGradCpuKernel final : public user_op::OpKernel {
+ public:
+  FuseLayerNormGradCpuKernel() = default;
+  ~FuseLayerNormGradCpuKernel() = default;
+
+ private:
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+  void Compute(user_op::KernelComputeContext* ctx) const override { TODO(); };
+};
+
+#define REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(dtype)               \
+  REGISTER_USER_KERNEL("fuse_layer_norm_grad")                        \
+      .SetCreateFn<LayerNormGradCpuKernel<dtype>>()                   \
+      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \
+                       && (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value));
+
+REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(float)
+REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(double)
+
 template<typename T>
 class LayerNormParamGradCpuKernel final : public user_op::OpKernel {
  public:
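Note: the new CPU kernel's Compute body is still a bare TODO(). As a reference for the math such a kernel eventually has to produce (dx, gamma_diff, beta_diff from dy, x, mean, inv_variance, gamma), here is a minimal standalone sketch assuming contiguous float data with m normalized rows of length n. It deliberately avoids the user_op kernel API, and all names in it are illustrative only, not part of this commit.

#include <cstdint>

// Illustrative reference for the fused layer-norm backward math (not the kernel API).
// Forward, per row: x_hat = (x - mean) * inv_variance, y = gamma * x_hat + beta.
void FuseLayerNormGradReference(const float* dy, const float* x, const float* mean,
                                const float* inv_variance, const float* gamma,
                                int64_t m, int64_t n,
                                float* dx, float* gamma_diff, float* beta_diff) {
  for (int64_t j = 0; j < n; ++j) { gamma_diff[j] = 0.0f; beta_diff[j] = 0.0f; }
  for (int64_t i = 0; i < m; ++i) {
    const float* dy_row = dy + i * n;
    const float* x_row = x + i * n;
    float* dx_row = dx + i * n;
    float sum_g = 0.0f;   // sum of gamma * dy over the row
    float sum_gx = 0.0f;  // sum of gamma * dy * x_hat over the row
    for (int64_t j = 0; j < n; ++j) {
      const float x_hat = (x_row[j] - mean[i]) * inv_variance[i];
      const float g = gamma[j] * dy_row[j];
      sum_g += g;
      sum_gx += g * x_hat;
      gamma_diff[j] += dy_row[j] * x_hat;  // parameter grads reduce over all rows
      beta_diff[j] += dy_row[j];
    }
    const float mean_g = sum_g / static_cast<float>(n);
    const float mean_gx = sum_gx / static_cast<float>(n);
    for (int64_t j = 0; j < n; ++j) {
      const float x_hat = (x_row[j] - mean[i]) * inv_variance[i];
      dx_row[j] = (gamma[j] * dy_row[j] - mean_g - x_hat * mean_gx) * inv_variance[i];
    }
  }
}

In this sketch the gamma and beta gradients are simply accumulated over the m rows; the real kernel additionally has to honor the begin_norm_axis and begin_params_axis attributes carried by the op.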
