Skip to content

Commit

Permalink
Merge pull request #9 from sony/feature/20170808-absolute_error
Browse files Browse the repository at this point in the history
Implement absolute error of cuda version.
  • Loading branch information
TakuyaNarihira authored Aug 21, 2017
2 parents dc5af19 + d262ae6 commit 6a4ef0a
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 0 deletions.
3 changes: 3 additions & 0 deletions doc/implements.rst
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ Function Implements
- x
-
* - SquaredError
- x
-
* - AbsoluteError
- x
-
* - KLMultinomial
Expand Down
31 changes: 31 additions & 0 deletions include/nbla/cuda/function/absolute_error.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/** AbsoluteError
*/
#ifndef __NBLA_CUDA_FUNCTION_ABSOLUTEERROR_HPP__
#define __NBLA_CUDA_FUNCTION_ABSOLUTEERROR_HPP__

#include <nbla/cuda/common.hpp>
#include <nbla/cuda/cuda.hpp>
#include <nbla/function/absolute_error.hpp>

#include <nbla/cuda/function/utils/base_transform_binary.hpp>

namespace nbla {
/** @copydoc AbsoluteError
*/
NBLA_DECLARE_TRANSFORM_BINARY_CUDA(AbsoluteError);
}
#endif
30 changes: 30 additions & 0 deletions src/nbla/cuda/function/absolute_error.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <nbla/array.hpp>
#include <nbla/cuda/common.hpp>
#include <nbla/cuda/function/absolute_error.hpp>
#include <nbla/cuda/function/utils/base_transform_binary.cuh>
#include <nbla/cuda/math.hpp>
#include <nbla/variable.hpp>

namespace nbla {

NBLA_DEFINE_TRANSFORM_BINARY_CUDA(AbsoluteError, abs((x0 - x1)),
((x0 - x1) > (T)0) ? dy : -dy,
((x0 - x1) > (T)0) ? -dy : dy);

// template instantiation
template class AbsoluteErrorCuda<float>;
}

0 comments on commit 6a4ef0a

Please sign in to comment.