forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
channel_stats_op.cu
122 lines (111 loc) · 2.92 KB
/
channel_stats_op.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include "caffe2/operators/channel_stats_op.h"

#include <cstdint>

#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/math/reduce.cuh"
namespace caffe2 {
namespace {
// Computes, for one channel per thread block, the sum and the sum of squares
// of every element of X belonging to that channel, with X indexed as
// ((n * C + c) * HxW + hw), i.e. NCHW order.
//
// Launch layout: blockIdx.x selects the channel c, threadIdx.x strides over
// the N dimension and threadIdx.y strides over the HxW dimension.
// NOTE(review): kBlockDimX / kBlockDimY are presumably required to equal
// blockDim.x / blockDim.y for BlockReduce2D to be valid -- confirm against
// the dispatch macro that launches this kernel.
//
// Args:
//   N, C, HxW: extents of the batch, channel and flattened spatial dims.
//   X:         input data, read-only.
//   sum:       output, sum[c] receives the sum over channel c.
//   sumsq:     output, sumsq[c] receives the sum of squares over channel c.
template <typename T, int kBlockDimX, int kBlockDimY>
__global__ void ChannelStatsNCHWCUDAKernel(
    const int N,
    const int C,
    const int HxW,
    const T* X,
    T* sum,
    T* sumsq) {
  // Separate scratch areas so the two block reductions below do not need a
  // barrier between them to reuse storage.
  __shared__
      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage m_storage;
  __shared__
      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage v_storage;
  const int c = blockIdx.x;
  T m_val = 0;
  T v_val = 0;
  for (int n = threadIdx.x; n < N; n += blockDim.x) {
    // The flat offset is computed in 64 bits: (n * C + c) * HxW exceeds the
    // range of a 32-bit int once the tensor holds more than 2^31 - 1 elements.
    const int64_t row_offset = (static_cast<int64_t>(n) * C + c) * HxW;
    for (int hw = threadIdx.y; hw < HxW; hw += blockDim.y) {
      // Load the element once and reuse it for both accumulators.
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
      const T x = __ldg(X + row_offset + hw);
#else
      const T x = X[row_offset + hw];
#endif
      m_val += x;
      v_val += x * x;
    }
  }
  // Combine the per-thread partials across the whole block.
  m_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(m_storage).Sum(m_val);
  v_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(v_storage).Sum(v_val);
  // Only thread (0, 0) writes the block's result.
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    sum[c] = m_val;
    sumsq[c] = v_val;
  }
}
// Computes, for one channel per thread block, the sum and the sum of squares
// of every element of X belonging to that channel, with X indexed as
// (i * C + c) for i in [0, N * HxW), i.e. NHWC order.
//
// Launch layout: blockIdx.x selects the channel c and threadIdx.x strides
// over the N * HxW positions of that channel.
//
// Args:
//   N, C, HxW: extents of the batch, channel and flattened spatial dims.
//   X:         input data, read-only.
//   sum:       output, sum[c] receives the sum over channel c.
//   sumsq:     output, sumsq[c] receives the sum of squares over channel c.
template <typename T>
__global__ void ChannelStatsNHWCCUDAKernel(
    const int N,
    const int C,
    const int HxW,
    const T* X,
    T* sum,
    T* sumsq) {
  // Separate scratch areas, one per block reduction below.
  __shared__ typename BlockReduce<T>::TempStorage m_storage;
  __shared__ typename BlockReduce<T>::TempStorage v_storage;
  // 64-bit arithmetic throughout: both N * HxW and i * C + c can exceed the
  // range of a 32-bit int for large tensors.
  const int64_t inner_size = static_cast<int64_t>(N) * HxW;
  const int c = blockIdx.x;
  T m_val = 0;
  T v_val = 0;
  for (int64_t i = threadIdx.x; i < inner_size; i += blockDim.x) {
    const int64_t index = i * C + c;
    // Load the element once and reuse it for both accumulators.
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
    const T x = __ldg(X + index);
#else
    const T x = X[index];
#endif
    m_val += x;
    v_val += x * x;
  }
  // Combine the per-thread partials across the whole block.
  m_val = BlockReduce<T>(m_storage).Sum(m_val);
  v_val = BlockReduce<T>(v_storage).Sum(v_val);
  // Only thread 0 writes the block's result.
  if (threadIdx.x == 0) {
    sum[c] = m_val;
    sumsq[c] = v_val;
  }
}
} // namespace
// float specialization of the NCHW channel-statistics computation on CUDA.
//
// Launches ChannelStatsNCHWCUDAKernel with C blocks on the context's CUDA
// stream; the dispatch macro receives HxW as its size argument.
// NOTE(review): the macro is defined outside this file -- presumably it
// picks the kernel's block-shape template arguments from HxW; confirm in
// caffe2/utils/math/reduce.cuh.
//
// Args:
//   N, C, HxW: extents of the batch, channel and flattened spatial dims.
//   X:         input data.
//   sum:       output with one entry per channel.
//   sumsq:     output with one entry per channel.
// Returns:
//   Always true.
template <>
template <>
bool ChannelStatsOp<CUDAContext>::ComputeChannelStatsNCHW<float>(
    const int N,
    const int C,
    const int HxW,
    const float* X,
    float* sum,
    float* sumsq) {
  // With no channels there is nothing to write, and a launch with a
  // zero-sized grid is an invalid configuration, so skip the launch.
  if (C == 0) {
    return true;
  }
  DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1(
      HxW,
      ChannelStatsNCHWCUDAKernel,
      float,
      C,
      context_.cuda_stream(),
      N,
      C,
      HxW,
      X,
      sum,
      sumsq);
  return true;
}
// float specialization of the NHWC channel-statistics computation on CUDA.
//
// Launches ChannelStatsNHWCCUDAKernel with C blocks of
// CAFFE_CUDA_NUM_THREADS threads on the context's CUDA stream and checks the
// launch for errors.
//
// Args:
//   N, C, HxW: extents of the batch, channel and flattened spatial dims.
//   X:         input data.
//   sum:       output with one entry per channel.
//   sumsq:     output with one entry per channel.
// Returns:
//   Always true.
template <>
template <>
bool ChannelStatsOp<CUDAContext>::ComputeChannelStatsNHWC<float>(
    const int N,
    const int C,
    const int HxW,
    const float* X,
    float* sum,
    float* sumsq) {
  // With no channels there is nothing to write, and a launch with a
  // zero-sized grid is an invalid configuration that the launch check below
  // would report as an error, so skip the launch.
  if (C == 0) {
    return true;
  }
  ChannelStatsNHWCCUDAKernel<float>
      <<<C, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
          N, C, HxW, X, sum, sumsq);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return true;
}
// Registers ChannelStatsOp<CUDAContext> as the CUDA operator named
// "ChannelStats".
REGISTER_CUDA_OPERATOR(ChannelStats, ChannelStatsOp<CUDAContext>);
} // namespace caffe2