-
Notifications
You must be signed in to change notification settings - Fork 50
/
Copy pathScatterNDPlugin.cu
executable file
·312 lines (250 loc) · 8.49 KB
/
ScatterNDPlugin.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
/**
* For the usage of these member functions, please refer to the
* official API doc.
* https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_plugin_v2_ext.html
*/
#include "ScatterNDPlugin.h"
#include <cassert>
#include <iostream>
#include <string.h>
#include "cuda_runtime.h"
#include "cuda_fp16.h"
// Use fp16 mode for inference
#define DATA_TYPE nvinfer1::DataType::kHALF
#define THREAD_NUM 1024
/**
 * Helper for deserializing the plugin: decodes one value of type T from a
 * byte stream and advances the cursor past it.
 *
 * @param buffer Cursor into the serialized bytes; moved forward by sizeof(T).
 * @return The decoded value.
 *
 * The bytes are copied out with memcpy instead of being dereferenced through
 * a casted pointer. Serialized fields are packed back to back, so a field is
 * not guaranteed to start at an address that is suitably aligned for T.
 */
template <typename T>
T readFromBuffer(const char*& buffer)
{
    T val;
    memcpy(&val, buffer, sizeof(T));
    buffer += sizeof(T);
    return val;
}
/**
 * Helper for serializing the plugin: encodes one value of type T into a byte
 * stream and advances the cursor past it.
 *
 * @param buffer Cursor into the destination bytes; moved forward by sizeof(T).
 * @param val    Value to store.
 *
 * The bytes are written with memcpy instead of a store through a casted
 * pointer. Serialized fields are packed back to back, so the destination is
 * not guaranteed to be suitably aligned for T.
 */
template <typename T>
void writeToBuffer(char*& buffer, const T& val)
{
    memcpy(buffer, &val, sizeof(T));
    buffer += sizeof(T);
}
using namespace nvinfer1;
using nvinfer1::plugin::ScatterNDPlugin;
using nvinfer1::plugin::ScatterNDSamplePluginCreator;
// Version and name strings handed out by getPluginVersion(), getPluginType()
// and getPluginName() below.
static const char* SCATTERND_PLUGIN_VERSION{"1"};
static const char* SCATTERND_PLUGIN_NAME{"ScatterND"};
// Out-of-class definitions of the creator's static members: the field
// collection returned by getFieldNames() and the attribute list it points at.
PluginFieldCollection ScatterNDSamplePluginCreator::mFC{};
std::vector<PluginField> ScatterNDSamplePluginCreator::mPluginAttributes;
/**
 * Builds a plugin from explicitly supplied shapes.
 *
 * @param name             Layer name, stored in mLayerName.
 * @param outputShapeArray Its first two entries are copied into mOutputSize.
 * @param indexShapeArray  Its first two entries are copied into mInputIndexSize.
 * @param type             Element type, stored in mDataType.
 */
ScatterNDPlugin::ScatterNDPlugin(const std::string name, const size_t outputShapeArray[],
                                 const size_t indexShapeArray[], const DataType type)
    : mLayerName(name)
    , mDataType(type)
{
    // Both shapes are kept as two-entry arrays.
    for (int dim = 0; dim < 2; ++dim)
    {
        mOutputSize[dim] = outputShapeArray[dim];
        mInputIndexSize[dim] = indexShapeArray[dim];
    }
}
/**
 * Rebuilds a plugin from the byte layout produced by serialize():
 * the data type, then two output extents, then two index extents.
 *
 * @param name   Layer name, stored in mLayerName.
 * @param data   Start of the serialized bytes.
 * @param length Number of serialized bytes; only checked by the assert below.
 */
ScatterNDPlugin::ScatterNDPlugin(const std::string name, const void* data, size_t length)
    : mLayerName(name)
{
    const char* const begin = static_cast<const char*>(data);
    const char* cursor = begin;

    mDataType = readFromBuffer<DataType>(cursor);
    for (int dim = 0; dim < 2; ++dim)
    {
        mOutputSize[dim] = readFromBuffer<size_t>(cursor);
    }
    for (int dim = 0; dim < 2; ++dim)
    {
        mInputIndexSize[dim] = readFromBuffer<size_t>(cursor);
    }

    // Every serialized byte must have been consumed.
    assert(cursor == begin + length);
}
/**
 * Reports how many output tensors the plugin produces.
 *
 * @return Always 1.
 */
int ScatterNDPlugin::getNbOutputs() const
{
    const int outputCount = 1;
    return outputCount;
}
/**
 * Derives the output shape from the first input (the scatterND data input).
 *
 * @param index       Unused.
 * @param inputs      Input dimensions; only inputs[0] is read.
 * @param nbInputDims Unused.
 * @return A 2-D shape made of the first two extents of inputs[0].
 */
Dims ScatterNDPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)
{
    const Dims& dataDims = inputs[0];
    return Dims2(dataDims.d[0], dataDims.d[1]);
}
/**
 * Initialization hook; this plugin has nothing to set up.
 *
 * @return Always 0.
 */
int ScatterNDPlugin::initialize()
{
    const int status = 0;
    return status;
}
/**
 * Reports the scratch memory requirement.
 *
 * @return Always 0; enqueue() does not use its workspace argument.
 */
size_t ScatterNDPlugin::getWorkspaceSize(int) const
{
    const size_t workspaceBytes = 0;
    return workspaceBytes;
}
/**
 * Reports the element type of the output.
 *
 * @param index      Unused.
 * @param inputTypes Types of the plugin inputs; must hold at least three entries.
 * @param nbInputs   Unused.
 * @return The type of the third input (the one enqueue() scatters from).
 */
DataType ScatterNDPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const
{
    const nvinfer1::DataType updateType = inputTypes[2];
    return updateType;
}
/**
 * Copies rows of `updata_input` into `output` at positions given by the
 * index tensor. Launched on a 1-D grid; each thread handles one index row.
 *
 * @param updata_input    Source rows, `channel_num` elements each.
 * @param indicesInputPtr Index entries, two ints per row; only the second int
 *                        of each pair is used, as the destination row.
 * @param output          Destination buffer, addressed as rows of `channel_num`.
 * @param channel_num     Number of elements copied per row.
 * @param max_index_num   Number of index rows; threads beyond it exit.
 *
 * Rows whose destination is negative are skipped.
 * NOTE(review): there is no upper-bound check on the destination row, so the
 * caller has to guarantee the indices fit inside `output`.
 */
template <typename Dtype>
__global__ void _ScatterNDKernel(const Dtype *updata_input, const int *indicesInputPtr , Dtype* output,
                                 int channel_num, int max_index_num) {
    const int row = blockDim.x * blockIdx.x + threadIdx.x;
    if (row < max_index_num) {
        const int dst_row = indicesInputPtr[row * 2 + 1];
        if (dst_row >= 0) {
            for (int c = 0; c < channel_num; ++c) {
                output[dst_row * channel_num + c] = updata_input[row * channel_num + c];
            }
        }
    }
}
/**
 * Zero-fills the output and scatters the rows of inputs[2] into it.
 *
 * @param batchSize Unused.
 * @param inputs    inputs[1] is read as int32 index pairs, inputs[2] as the
 *                  rows to scatter; inputs[0] is not read here.
 * @param outputs   outputs[0] receives mOutputSize[0] x mOutputSize[1] elements.
 * @param stream    Stream on which both the zero-fill and the kernel are queued.
 * @return 0 on success, non-zero if the data type is unsupported or a CUDA
 *         call fails.
 *
 * The element width is selected by mDataType. Half data is moved as opaque
 * 16-bit words (int16_t), since the kernel only copies values.
 */
int ScatterNDPlugin::enqueue(int batchSize, const void* const* inputs, void** outputs, void*, cudaStream_t stream)
{
    const int channel_num = static_cast<int>(mOutputSize[1]);
    const int max_index_num = static_cast<int>(mInputIndexSize[0]);
    // Keep the element count in size_t so large outputs do not overflow int.
    const size_t totalElems = mOutputSize[0] * mOutputSize[1];

    size_t elemBytes = 0;
    switch (mDataType)
    {
    case nvinfer1::DataType::kFLOAT:
        elemBytes = sizeof(float);
        break;
    case nvinfer1::DataType::kHALF:
        elemBytes = sizeof(int16_t);
        break;
    default:
        std::cout << "[ERROR]: mDataType dones't support" << std::endl;
        // Report failure instead of pretending the layer ran.
        return 1;
    }

    // Clear the output on the caller's stream so the fill is ordered before
    // the scatter kernel and does not touch the default stream.
    if (totalElems > 0)
    {
        const cudaError_t fillStatus = cudaMemsetAsync(outputs[0], 0, totalElems * elemBytes, stream);
        if (fillStatus != cudaSuccess)
        {
            std::cerr << "[ERROR]: cudaMemsetAsync failed: " << cudaGetErrorString(fillStatus) << std::endl;
            return 1;
        }
    }

    // Nothing to scatter; a zero-sized grid would be an invalid launch.
    if (max_index_num <= 0)
    {
        return 0;
    }

    const dim3 blockSize(THREAD_NUM);
    const dim3 gridSize((max_index_num + THREAD_NUM - 1) / THREAD_NUM);
    const int32_t* indices = static_cast<int32_t const*>(inputs[1]);

    if (mDataType == nvinfer1::DataType::kFLOAT)
    {
        _ScatterNDKernel<<<gridSize, blockSize, 0, stream>>>(static_cast<float const*>(inputs[2]), indices,
            static_cast<float*>(outputs[0]), channel_num, max_index_num);
    }
    else
    {
        _ScatterNDKernel<<<gridSize, blockSize, 0, stream>>>(static_cast<int16_t const*>(inputs[2]), indices,
            static_cast<int16_t*>(outputs[0]), channel_num, max_index_num);
    }

    // Kernel launches do not return a status; pick up launch-configuration errors here.
    const cudaError_t launchStatus = cudaGetLastError();
    if (launchStatus != cudaSuccess)
    {
        std::cerr << "[ERROR]: _ScatterNDKernel launch failed: " << cudaGetErrorString(launchStatus) << std::endl;
        return 1;
    }
    return 0;
}
/**
 * Writes the plugin state into `buffer`: the data type, then two output
 * extents, then two index extents.
 *
 * @param buffer Destination; must hold getSerializationSize() bytes.
 */
void ScatterNDPlugin::serialize(void* buffer) const
{
    char* const begin = static_cast<char*>(buffer);
    char* cursor = begin;

    writeToBuffer<DataType>(cursor, mDataType);
    for (int dim = 0; dim < 2; ++dim)
    {
        writeToBuffer<size_t>(cursor, mOutputSize[dim]);
    }
    for (int dim = 0; dim < 2; ++dim)
    {
        writeToBuffer<size_t>(cursor, mInputIndexSize[dim]);
    }

    // The number of bytes written must match the advertised size.
    assert(cursor == begin + getSerializationSize());
}
/**
 * Teardown hook; intentionally empty because initialize() sets nothing up.
 */
void ScatterNDPlugin::terminate()
{
}
/**
 * Reports how many bytes serialize() writes.
 *
 * @return Size of one DataType plus four size_t shape entries.
 */
size_t ScatterNDPlugin::getSerializationSize() const
{
    const size_t shapeBytes = 4 * sizeof(size_t);
    return sizeof(DataType) + shapeBytes;
}
/**
 * Reports whether an output is broadcast across the batch.
 *
 * @return Always false; all arguments are ignored.
 */
bool ScatterNDPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const
{
    const bool broadcast = false;
    return broadcast;
}
/**
 * Reports whether an input may be broadcast across the batch.
 *
 * @return Always false; the argument is ignored.
 */
bool ScatterNDPlugin::canBroadcastInputAcrossBatch(int inputIndex) const
{
    const bool broadcastable = false;
    return broadcastable;
}
/**
 * Records the shapes and the element type that the plugin will run with.
 *
 * @param inputDims  Input dimensions; the first two extents of inputDims[1]
 *                   (the index input) are stored in mInputIndexSize.
 * @param nbInputs   Number of inputs.
 * @param outputDims Output dimensions; the first two extents of outputDims[0]
 *                   are stored in mOutputSize.
 * @param inputTypes Types selected for the inputs; inputTypes[2] is the type
 *                   of the data that enqueue() scatters.
 *
 * The remaining parameters are not used.
 *
 * mDataType is otherwise only set at construction time. supportsFormat()
 * accepts both float and half, so the stored type is synchronised with the
 * type actually selected for input 2; without this, enqueue() could run the
 * kernel with the wrong element width. Types that enqueue() cannot handle
 * leave mDataType unchanged.
 */
void ScatterNDPlugin::configurePlugin(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs,
                                      const DataType* inputTypes, const DataType* outputTypes, const bool* inputIsBroadcast,
                                      const bool* outputIsBroadcast, PluginFormat floatFormat, int maxBatchSize)
{
    mOutputSize[0] = outputDims[0].d[0];
    mOutputSize[1] = outputDims[0].d[1];
    mInputIndexSize[0] = inputDims[1].d[0];
    mInputIndexSize[1] = inputDims[1].d[1];

    if (inputTypes != nullptr && nbInputs > 2)
    {
        const DataType selected = inputTypes[2];
        if (selected == nvinfer1::DataType::kFLOAT || selected == nvinfer1::DataType::kHALF)
        {
            mDataType = selected;
        }
    }
}
/**
 * Tells whether a data type is accepted by the plugin.
 *
 * @param type   Candidate data type.
 * @param format Ignored; no format is rejected here.
 * @return true for kINT32, kFLOAT and kHALF, false for anything else.
 */
bool ScatterNDPlugin::supportsFormat(DataType type, PluginFormat format) const
{
    const bool accepted = type == nvinfer1::DataType::kINT32
        || type == nvinfer1::DataType::kFLOAT
        || type == nvinfer1::DataType::kHALF;
    return accepted;
}
/**
 * Returns the plugin type name (SCATTERND_PLUGIN_NAME).
 * NO NEED TO MODIFY
 */
const char* ScatterNDPlugin::getPluginType() const
{
    const char* const typeName = SCATTERND_PLUGIN_NAME;
    return typeName;
}
/**
 * Returns the plugin version string (SCATTERND_PLUGIN_VERSION).
 * NO NEED TO MODIFY
 */
const char* ScatterNDPlugin::getPluginVersion() const
{
    const char* const version = SCATTERND_PLUGIN_VERSION;
    return version;
}
/**
 * Releases this plugin object. The object deletes itself, so it must not be
 * used after this call.
 */
void ScatterNDPlugin::destroy()
{
    delete this;
}
/**
 * Creates a heap-allocated copy of this plugin with the same layer name,
 * shapes, data type and namespace.
 *
 * @return The new plugin.
 */
IPluginV2Ext* ScatterNDPlugin::clone() const
{
    ScatterNDPlugin* copy = new ScatterNDPlugin(mLayerName, mOutputSize, mInputIndexSize, mDataType);
    copy->setPluginNamespace(mNamespace.c_str());
    return copy;
}
/**
 * Stores the namespace this plugin belongs to.
 * NO NEED TO MODIFY
 *
 * @param libNamespace Namespace string; a null pointer is treated as the
 *                     empty namespace, because assigning a null pointer to a
 *                     std::string is undefined behaviour.
 */
void ScatterNDPlugin::setPluginNamespace(const char* libNamespace)
{
    mNamespace = (libNamespace != nullptr) ? libNamespace : "";
}
/**
 * Returns the namespace previously stored by setPluginNamespace().
 * NO NEED TO MODIFY
 *
 * @return Pointer into mNamespace.
 */
const char* ScatterNDPlugin::getPluginNamespace() const
{
    const char* const ns = mNamespace.c_str();
    return ns;
}
/**
 * Declares the attributes the creator understands ("output_shape" and
 * "index_shape", each three int32 values) and publishes them through mFC.
 *
 * mPluginAttributes is a static member shared by all creator instances, so it
 * is cleared first; otherwise every additional construction would append
 * duplicate entries.
 */
ScatterNDSamplePluginCreator::ScatterNDSamplePluginCreator()
{
    mPluginAttributes.clear();
    mPluginAttributes.emplace_back(PluginField("output_shape", nullptr, PluginFieldType::kINT32, 3));
    mPluginAttributes.emplace_back(PluginField("index_shape", nullptr, PluginFieldType::kINT32, 3));
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}
/**
 * Returns the name of the plugin this creator builds (SCATTERND_PLUGIN_NAME).
 * NO NEED TO MODIFY
 */
const char* ScatterNDSamplePluginCreator::getPluginName() const
{
    const char* const pluginName = SCATTERND_PLUGIN_NAME;
    return pluginName;
}
/**
 * Returns the version of the plugin this creator builds
 * (SCATTERND_PLUGIN_VERSION).
 * NO NEED TO MODIFY
 */
const char* ScatterNDSamplePluginCreator::getPluginVersion() const
{
    const char* const version = SCATTERND_PLUGIN_VERSION;
    return version;
}
/**
 * Returns the collection of attributes filled in by the constructor.
 * NO NEED TO MODIFY
 *
 * @return Address of the static mFC member.
 */
const PluginFieldCollection* ScatterNDSamplePluginCreator::getFieldNames()
{
    const PluginFieldCollection* const collection = &mFC;
    return collection;
}
/**
 * Builds a ScatterNDPlugin from the attributes in `fc`.
 *
 * @param name Layer name; a null pointer is treated as an empty name.
 * @param fc   Attribute collection. "output_shape" and "index_shape" are read
 *             as int32 arrays; elements 1 and 2 of each become the two stored
 *             extents (element 0 is skipped).
 * @return The new plugin, with this creator's namespace applied.
 *
 * Fields with a null name, null data or fewer than three entries are ignored
 * instead of being dereferenced, and the matching shape then stays at zero.
 * The plugin's data type is the compile-time DATA_TYPE.
 */
IPluginV2Ext* ScatterNDSamplePluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
{
    mDataType = DATA_TYPE;
    size_t indexShapeArray[2] = {0, 0};
    size_t outputShapeArray[2] = {0, 0};

    // Treat a missing collection as one without fields.
    const bool hasFields = (fc != nullptr) && (fc->fields != nullptr);
    const int fieldCount = hasFields ? fc->nbFields : 0;

    for (int i = 0; i < fieldCount; i++) {
        const nvinfer1::PluginField& field = fc->fields[i];
        // Elements 1 and 2 are read below, so at least three entries are needed.
        if (field.name == nullptr || field.data == nullptr || field.length < 3) {
            continue;
        }
        const auto* shapeAttr = static_cast<const int32_t*>(field.data);
        if (!strcmp(field.name, "output_shape")) {
            outputShapeArray[0] = shapeAttr[1];
            outputShapeArray[1] = shapeAttr[2];
        } else if (!strcmp(field.name, "index_shape")) {
            indexShapeArray[0] = shapeAttr[1];
            indexShapeArray[1] = shapeAttr[2];
        }
    }

    auto* plugin = new ScatterNDPlugin(name != nullptr ? name : "", outputShapeArray, indexShapeArray, mDataType);
    plugin->setPluginNamespace(mNamespace.c_str());
    return plugin;
}
/**
 * Rebuilds a plugin from bytes previously produced by
 * ScatterNDPlugin::serialize().
 *
 * @param name         Layer name for the new plugin.
 * @param serialData   Start of the serialized bytes.
 * @param serialLength Number of serialized bytes.
 * @return The new heap-allocated plugin.
 */
IPluginV2Ext* ScatterNDSamplePluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength)
{
    ScatterNDPlugin* restored = new ScatterNDPlugin(name, serialData, serialLength);
    return restored;
}
// Presumably registers the creator with TensorRT's plugin registry at static
// initialization time; the macro is defined outside this file. TODO confirm.
REGISTER_TENSORRT_PLUGIN(ScatterNDSamplePluginCreator);