forked from alibaba/MNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathblitBuffer.cl
129 lines (110 loc) · 4.31 KB
/
blitBuffer.cl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
// Enable the cl_khr_fp16 extension for this program.
// NOTE(review): presumably required because FLOAT/FLOAT4 can be defined as half types
// by the build options; those macros are not defined in this file — confirm.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
// Sampler used for image reads in this file (see blitImageToBuffer):
// unnormalized integer coordinates, clamped addressing, nearest-texel filtering.
__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
// Copies one element per work-item from a strided source buffer to a strided
// destination buffer.
//
// The 2D global id is decomposed into a 4-component position ordered
// (N, C, H, W):
//   N = gid(1) / wh.y,  C = gid(0) / wh.x,  H = gid(1) % wh.y,  W = gid(0) % wh.x
//
// input        : source buffer (read only)
// output       : destination buffer
// inputOffset  : per-component offset added to the position before indexing input
// outputOffset : per-component offset added to the position before indexing output
// region       : only region.x (N extent) and region.y (C extent) are checked here
// inputStride  : per-component element strides of the source buffer
// outputStride : per-component element strides of the destination buffer
// wh           : (width, height) used to split the global id
__kernel void blitBuffer(
    const __global FLOAT *input,
    __global FLOAT *output,
    int4 inputOffset,
    int4 outputOffset,
    int4 region,
    int4 inputStride,
    int4 outputStride,
    int2 wh
    ) {
    const int gx = get_global_id(0);
    const int gy = get_global_id(1);

    const int batch   = gy / wh.y;
    const int channel = gx / wh.x;

    // Work-items outside the N/C extent of the region have nothing to copy.
    // H and W come from a modulo by wh, so they are not checked against region.
    if (batch >= region.x || channel >= region.y) {
        return;
    }

    // N, C, H, W
    const int4 coord = (int4)(batch, channel, gy % wh.y, gx % wh.x);

    // Component-wise (position + offset) * stride; the linear index is the sum of the lanes.
    const int4 dstTerms = (outputOffset + coord) * outputStride;
    const int4 srcTerms = (inputOffset + coord) * inputStride;

    const int dstIndex = dstTerms.x + dstTerms.y + dstTerms.z + dstTerms.w;
    const int srcIndex = srcTerms.x + srcTerms.y + srcTerms.z + srcTerms.w;

    output[dstIndex] = input[srcIndex];
}
// Copies texels from a 2D image into a strided buffer, unpacking the four
// lanes of each texel into four consecutive channel positions of the buffer.
//
// The 2D global id is decomposed into a position ordered (N, C, H, W) using
// the width/height taken from outputSize. On the image side the C component
// indexes a texel column block; on the buffer side it is multiplied by 4, so
// each work-item covers up to four buffer channels.
//
// input        : source image; texel (x, y) is read at
//                x = W + C * inputWH.x,  y = N * inputWH.y + H
// output       : destination buffer
// inputOffset  : offset added to the position before addressing the image
// outputOffset : offset added to the (channel-expanded) position before
//                addressing the buffer
// region       : only region.x (N extent) and region.y (C extent) are checked
// inputWH      : (width, height) multipliers used to form the image coordinate
// outputStride : element strides applied to the (N, C, H, W) buffer position
// outputSize   : buffer extents in n, h, w, c order (.x=n, .y=h, .z=w, .w=c);
//                .w bounds the per-lane channel writes
__kernel void blitImageToBuffer(
    __read_only image2d_t input,
    __global FLOAT *output,
    int4 inputOffset,
    int4 outputOffset,
    int4 region,
    int2 inputWH,
    int4 outputStride,
    int4 outputSize/*nhwc*/
    ) {
    const int width    = outputSize.z;
    const int height   = outputSize.y;
    const int channels = outputSize.w;

    const int2 gid = (int2)(get_global_id(0), get_global_id(1));
    // N, C, H, W
    const int4 pos = (int4)(gid.y / height, gid.x / width, gid.y % height, gid.x % width);

    // Work-items outside the N/C extent of the region do nothing.
    if (pos.x >= region.x || pos.y >= region.y) {
        return;
    }

    const int4 posInput  = inputOffset + pos;
    // One texel holds four channels, so the buffer channel index advances by 4 per C step.
    const int4 posOutput = outputOffset + pos * (int4)(1, 4, 1, 1);

    const int2 texel = (int2)(posInput.w + posInput.y * inputWH.x,
                              posInput.x * inputWH.y + posInput.z);
    const FLOAT4 color = RI_F(input, SAMPLER, texel);

    const int channelStep = outputStride.y;
    const int base = posOutput.x * outputStride.x
                   + posOutput.y * outputStride.y
                   + posOutput.z * outputStride.z
                   + posOutput.w * outputStride.w;

    // Lane 0 is always stored; lanes 1..3 are stored only while their channel
    // index is still below the buffer's channel count.
    output[base] = color.x;
    if (posOutput.y + 1 < channels) {
        output[base + 1 * channelStep] = color.y;
    }
    if (posOutput.y + 2 < channels) {
        output[base + 2 * channelStep] = color.z;
    }
    if (posOutput.y + 3 < channels) {
        output[base + 3 * channelStep] = color.w;
    }
}
// Copies values from a strided buffer into a 2D image, packing four
// consecutive channel positions of the buffer into the four lanes of one texel.
//
// The 2D global id is decomposed into a position ordered (N, C, H, W) using wh.
// On the buffer side the C component is multiplied by 4; on the image side it
// selects a texel column block.
//
// input        : source buffer (read only)
// output       : destination image; texel (x, y) is written at
//                x = W + C * outputWH.x,  y = N * outputWH.y + H
// inputOffset  : offset added to the (channel-expanded) position before
//                addressing the buffer
// outputOffset : offset added to the position before addressing the image
// region       : only region.x (N extent) and region.y (C extent) are checked
// inputStride  : element strides applied to the (N, C, H, W) buffer position
// outputWH     : (width, height) multipliers used to form the image coordinate
// wh           : (width, height) used to split the global id
//
// NOTE(review): all four channel reads are unconditional. Unlike
// blitImageToBuffer there is no channel-count argument to guard the tail, so
// the caller presumably has to guarantee that four channels starting at the
// computed index are readable — confirm against the host code.
__kernel void blitBufferToImage(
    const __global FLOAT *input,
    __write_only image2d_t output,
    int4 inputOffset,
    int4 outputOffset,
    int4 region,
    int4 inputStride,
    int2 outputWH,
    int2 wh
    ) {
    const int width  = wh.x;
    const int height = wh.y;

    const int2 gid = (int2)(get_global_id(0), get_global_id(1));
    // N, C, H, W
    const int4 pos = (int4)(gid.y / height, gid.x / width, gid.y % height, gid.x % width);

    // Work-items outside the N/C extent of the region do nothing.
    if (pos.x >= region.x || pos.y >= region.y) {
        return;
    }

    // One texel holds four channels, so the buffer channel index advances by 4 per C step.
    const int4 posInput  = inputOffset + pos * (int4)(1, 4, 1, 1);
    const int4 posOutput = outputOffset + pos;

    const int2 texel = (int2)(posOutput.w + posOutput.y * outputWH.x,
                              posOutput.x * outputWH.y + posOutput.z);

    const int channelStep = inputStride.y;
    const int base = posInput.x * inputStride.x
                   + posInput.y * inputStride.y
                   + posInput.z * inputStride.z
                   + posInput.w * inputStride.w;

    // Gather four channels, one channel stride apart, into a single texel.
    FLOAT4 color;
    color.x = input[base];
    color.y = input[base + 1 * channelStep];
    color.z = input[base + 2 * channelStep];
    color.w = input[base + 3 * channelStep];

    WI_F(output, texel, color);
}