forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cast_op.cu
134 lines (124 loc) · 3.68 KB
/
cast_op.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/cast_op.h"
#include "caffe2/utils/conversions.h"
namespace caffe2 {
// Element-wise cast kernel.
//
// For every index i produced by CUDA_1D_KERNEL_LOOP over N, reads X[i],
// converts it from SrcType to DstType with convert::To, and writes the result
// to Y[i].
//
//   N  number of elements to convert
//   X  source elements (read-only)
//   Y  destination elements
//      NOTE(review): assumed to provide room for at least N elements --
//      confirm against the launching code.
template <typename DstType, typename SrcType>
__global__ void CastKernel(const int N, const SrcType* X, DstType* Y) {
  CUDA_1D_KERNEL_LOOP(i, N) {
    const SrcType src = X[i];
    Y[i] = convert::To<SrcType, DstType>(src);
  }
}
// Casts Input(0) element-wise from SrcType to DstType on the GPU.
//
// Output(0) is requested with the input's sizes and dtype DstType, and
// CastKernel is launched on this operator's CUDA stream.
//
// Returns true on success; an empty input returns true without launching a
// kernel. Throws (CAFFE_THROW) when the input holds INT_MAX or more elements,
// because the element count is handed to the kernel as an int.
template <>
template <typename DstType, typename SrcType>
bool CastOp<CUDAContext>::DoRunWithType() {
  auto& input = Input(0);

  // The element count is narrowed to int below. Guard the narrowing in every
  // build type: a DCHECK disappears in release (NDEBUG) builds, which would
  // let an oversized tensor be silently truncated to a wrong N.
  const auto numel = input.numel();
  if (!(numel < INT_MAX)) {
    CAFFE_THROW(
        "Cast: input has ",
        numel,
        " elements, but at most ",
        INT_MAX - 1,
        " are supported");
  }

  auto* output = Output(0, input.sizes(), at::dtype<DstType>());
  const auto* data = input.template data<SrcType>();
  auto* out = output->template mutable_data<DstType>();

  const int N = static_cast<int>(numel);
  if (N == 0) {
    // Nothing to convert; the output has already been set up above.
    return true;
  }

  CastKernel<DstType, SrcType>
      <<<CAFFE_GET_BLOCKS(N),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(N, data, out);
  // Surfaces launch-time errors (e.g. an invalid launch configuration).
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return true;
}
// Generic entry point for a fixed destination type DstType.
//
// Hands Input(0) and the list of accepted source element types to
// DispatchHelper. NOTE(review): DispatchHelper is defined elsewhere;
// presumably it selects the entry matching Input(0)'s dtype and invokes
// DoRunWithType<DstType, SrcType> -- confirm there.
//
// at::Half is not among the source types accepted here.
template <>
template <typename DstType>
bool CastOp<CUDAContext>::DoRunWithDstType() {
  using SourceTypes = TensorTypes<
      float,
      int32_t,
      bool,
      uint8_t,
      int8_t,
      uint16_t,
      int16_t,
      int64_t,
      double>;
  using Dispatcher = DispatchHelper<SourceTypes, DstType>;
  return Dispatcher::call(this, Input(0));
}
// Specialization for DstType = float.
//
// Its source-type list is the generic one plus at::Half, so this is the path
// that allows casting _from_ fp16 (to float).
template <>
template <>
bool CastOp<CUDAContext>::DoRunWithDstType<float>() {
  using SourceTypes = TensorTypes<
      float,
      at::Half,
      int32_t,
      bool,
      uint8_t,
      int8_t,
      uint16_t,
      int16_t,
      int64_t,
      double>;
  using Dispatcher = DispatchHelper<SourceTypes, float /* DstType */>;
  return Dispatcher::call(this, Input(0));
}
// Specialization for DstType = at::Half, i.e. the path that casts _to_ fp16.
//
// Only float and at::Half are listed as accepted source types.
template <>
template <>
bool CastOp<CUDAContext>::DoRunWithDstType<at::Half>() {
  using SourceTypes = TensorTypes<float, at::Half>;
  using Dispatcher = DispatchHelper<SourceTypes, at::Half /* DstType */>;
  return Dispatcher::call(this, Input(0));
}
// Selects the member function stored in body_ according to the requested
// destination data type `to`.
//
// Supported destinations store a pointer to the matching
// DoRunWithDstType<T> instantiation. BYTE is reported through LOG(FATAL);
// STRING, UNDEFINED and any unrecognized value raise via CAFFE_THROW.
template <>
void CastOp<CUDAContext>::SetBody(TensorProto_DataType to) {
  using Self = CastOp<CUDAContext>;

  switch (to) {
    // ---- Supported destination types ----
    case TensorProto_DataType_FLOAT:
      body_ = &Self::DoRunWithDstType<float>;
      return;
    case TensorProto_DataType_FLOAT16:
      body_ = &Self::DoRunWithDstType<at::Half>;
      return;
    case TensorProto_DataType_DOUBLE:
      body_ = &Self::DoRunWithDstType<double>;
      return;
    case TensorProto_DataType_BOOL:
      body_ = &Self::DoRunWithDstType<bool>;
      return;
    case TensorProto_DataType_INT8:
      body_ = &Self::DoRunWithDstType<int8_t>;
      return;
    case TensorProto_DataType_UINT8:
      body_ = &Self::DoRunWithDstType<uint8_t>;
      return;
    case TensorProto_DataType_INT16:
      body_ = &Self::DoRunWithDstType<int16_t>;
      return;
    case TensorProto_DataType_UINT16:
      body_ = &Self::DoRunWithDstType<uint16_t>;
      return;
    case TensorProto_DataType_INT32:
      body_ = &Self::DoRunWithDstType<int>;
      return;
    case TensorProto_DataType_INT64:
      body_ = &Self::DoRunWithDstType<int64_t>;
      return;

    // ---- Rejected destination types ----
    case TensorProto_DataType_BYTE:
      LOG(FATAL) << "BYTE is deprecated";
      return;
    case TensorProto_DataType_STRING:
      CAFFE_THROW("Casting to and from strings is not supported yet");
    case TensorProto_DataType_UNDEFINED:
      CAFFE_THROW("Cast op must have 'to' argument of type DataType");
    default:
      CAFFE_THROW("Unexpected 'to' argument value: ", to);
  }
}
// Registers CastOp<CUDAContext> under the operator name "Cast" through
// REGISTER_CUDA_OPERATOR. NOTE(review): presumably this makes it the CUDA
// implementation picked for Cast operators -- confirm in the registry macro.
REGISTER_CUDA_OPERATOR(Cast, CastOp<CUDAContext>);
} // namespace caffe2