// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/argmaxpool.h>


void xnn_f32_argmaxpool_ukernel_9p8x__neon_c4(
    size_t output_pixels,
    size_t pooling_elements,
    size_t channels,
    const float** input,
    size_t input_offset,
    float* accumulation_buffer,
    uint32_t* index_buffer,
    float* output,
    uint32_t* index,
    size_t input_increment,
    size_t output_increment) XNN_DISABLE_TSAN
{
  assert(output_pixels != 0);
  assert(pooling_elements != 0);
  assert(pooling_elements > 9);
  assert(channels != 0);

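  // Each output pixel pools pooling_elements (> 9) input rows. The reduction
  // runs in three phases: a first pass over 9 elements seeds a per-channel
  // running maximum and argmax index in accumulation_buffer / index_buffer,
  // middle passes fold in 8 more elements at a time, and a final pass over
  // the remaining 1-8 elements writes the pooled values and indices out.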
  do {
    {
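      // First pass: reduce pooling elements 0-8 into the buffers.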
      float* ab = accumulation_buffer;
      uint32_t* ib = index_buffer;

      const float* i0 = *input++;
      const float* i1 = *input++;
      const float* i2 = *input++;
      const float* i3 = *input++;
      const float* i4 = *input++;
      const float* i5 = *input++;
      const float* i6 = *input++;
      const float* i7 = *input++;
      const float* i8 = *input++;
      i0 = (const float*) ((uintptr_t) i0 + input_offset);
      i1 = (const float*) ((uintptr_t) i1 + input_offset);
      i2 = (const float*) ((uintptr_t) i2 + input_offset);
      i3 = (const float*) ((uintptr_t) i3 + input_offset);
      i4 = (const float*) ((uintptr_t) i4 + input_offset);
      i5 = (const float*) ((uintptr_t) i5 + input_offset);
      i6 = (const float*) ((uintptr_t) i6 + input_offset);
      i7 = (const float*) ((uintptr_t) i7 + input_offset);
      i8 = (const float*) ((uintptr_t) i8 + input_offset);

      for (size_t c = 0; c < channels; c += 4) {
        const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
        const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
        const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
        const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
        const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
        const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
        const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
        const float32x4_t vi7 = vld1q_f32(i7); i7 += 4;
        const float32x4_t vi8 = vld1q_f32(i8); i8 += 4;

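        // Lane-wise argmax update: vcgtq produces an all-ones mask in each
        // lane where the new element beats the running maximum, and vbslq
        // selects the new value and its pooling index under that mask.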
        float32x4_t vmax = vi0;
        uint32x4_t vidx = vmovq_n_u32(0);

        const uint32x4_t vm1 = vcgtq_f32(vi1, vmax);
        vmax = vbslq_f32(vm1, vi1, vmax);
        vidx = vbslq_u32(vm1, vmovq_n_u32(1), vidx);

        const uint32x4_t vm2 = vcgtq_f32(vi2, vmax);
        vmax = vbslq_f32(vm2, vi2, vmax);
        vidx = vbslq_u32(vm2, vmovq_n_u32(2), vidx);

        const uint32x4_t vm3 = vcgtq_f32(vi3, vmax);
        vmax = vbslq_f32(vm3, vi3, vmax);
        vidx = vbslq_u32(vm3, vmovq_n_u32(3), vidx);

        const uint32x4_t vm4 = vcgtq_f32(vi4, vmax);
        vmax = vbslq_f32(vm4, vi4, vmax);
        vidx = vbslq_u32(vm4, vmovq_n_u32(4), vidx);

        const uint32x4_t vm5 = vcgtq_f32(vi5, vmax);
        vmax = vbslq_f32(vm5, vi5, vmax);
        vidx = vbslq_u32(vm5, vmovq_n_u32(5), vidx);

        const uint32x4_t vm6 = vcgtq_f32(vi6, vmax);
        vmax = vbslq_f32(vm6, vi6, vmax);
        vidx = vbslq_u32(vm6, vmovq_n_u32(6), vidx);

        const uint32x4_t vm7 = vcgtq_f32(vi7, vmax);
        vmax = vbslq_f32(vm7, vi7, vmax);
        vidx = vbslq_u32(vm7, vmovq_n_u32(7), vidx);

        const uint32x4_t vm8 = vcgtq_f32(vi8, vmax);
        vmax = vbslq_f32(vm8, vi8, vmax);
        vidx = vbslq_u32(vm8, vmovq_n_u32(8), vidx);

        vst1q_f32(ab, vmax); ab += 4;
        vst1q_u32(ib, vidx); ib += 4;
      }
    }
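    // The first pass consumed pooling elements 0-8, so the next pass starts
    // at pooling index 9 (= 1 + 8).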
    const uint32x4_t v1 = vmovq_n_u32(1);
    const uint32x4_t v8 = vmovq_n_u32(8);
    uint32x4_t vidx0 = vaddq_u32(v1, v8);

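    // Middle passes: while more than 8 pooling elements remain, fold 8 more
    // rows into the accumulation and index buffers.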
    size_t k = pooling_elements;
    for (k -= 9; k > 8; k -= 8) {
      const float* i0 = *input++;
      const float* i1 = *input++;
      const float* i2 = *input++;
      const float* i3 = *input++;
      const float* i4 = *input++;
      const float* i5 = *input++;
      const float* i6 = *input++;
      const float* i7 = *input++;
      i0 = (const float*) ((uintptr_t) i0 + input_offset);
      i1 = (const float*) ((uintptr_t) i1 + input_offset);
      i2 = (const float*) ((uintptr_t) i2 + input_offset);
      i3 = (const float*) ((uintptr_t) i3 + input_offset);
      i4 = (const float*) ((uintptr_t) i4 + input_offset);
      i5 = (const float*) ((uintptr_t) i5 + input_offset);
      i6 = (const float*) ((uintptr_t) i6 + input_offset);
      i7 = (const float*) ((uintptr_t) i7 + input_offset);

      float* ab = accumulation_buffer;
      uint32_t* ib = index_buffer;

      for (size_t c = 0; c < channels; c += 4) {
        const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
        const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
        const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
        const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
        const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
        const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
        const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
        const float32x4_t vi7 = vld1q_f32(i7); i7 += 4;

        float32x4_t vmax = vld1q_f32(ab);
        uint32x4_t vidx = vld1q_u32(ib);

        const uint32x4_t vm0 = vcgtq_f32(vi0, vmax);
        vmax = vbslq_f32(vm0, vi0, vmax);
        vidx = vbslq_u32(vm0, vidx0, vidx);

        const uint32x4_t vm1 = vcgtq_f32(vi1, vmax);
        const uint32x4_t vidx1 = vaddq_u32(vidx0, v1);
        vmax = vbslq_f32(vm1, vi1, vmax);
        vidx = vbslq_u32(vm1, vidx1, vidx);

        const uint32x4_t vm2 = vcgtq_f32(vi2, vmax);
        const uint32x4_t vidx2 = vaddq_u32(vidx1, v1);
        vmax = vbslq_f32(vm2, vi2, vmax);
        vidx = vbslq_u32(vm2, vidx2, vidx);

        const uint32x4_t vm3 = vcgtq_f32(vi3, vmax);
        const uint32x4_t vidx3 = vaddq_u32(vidx2, v1);
        vmax = vbslq_f32(vm3, vi3, vmax);
        vidx = vbslq_u32(vm3, vidx3, vidx);

        const uint32x4_t vm4 = vcgtq_f32(vi4, vmax);
        const uint32x4_t vidx4 = vaddq_u32(vidx3, v1);
        vmax = vbslq_f32(vm4, vi4, vmax);
        vidx = vbslq_u32(vm4, vidx4, vidx);

        const uint32x4_t vm5 = vcgtq_f32(vi5, vmax);
        const uint32x4_t vidx5 = vaddq_u32(vidx4, v1);
        vmax = vbslq_f32(vm5, vi5, vmax);
        vidx = vbslq_u32(vm5, vidx5, vidx);

        const uint32x4_t vm6 = vcgtq_f32(vi6, vmax);
        const uint32x4_t vidx6 = vaddq_u32(vidx5, v1);
        vmax = vbslq_f32(vm6, vi6, vmax);
        vidx = vbslq_u32(vm6, vidx6, vidx);

        const uint32x4_t vm7 = vcgtq_f32(vi7, vmax);
        const uint32x4_t vidx7 = vaddq_u32(vidx6, v1);
        vmax = vbslq_f32(vm7, vi7, vmax);
        vidx = vbslq_u32(vm7, vidx7, vidx);

        vst1q_f32(ab, vmax); ab += 4;
        vst1q_u32(ib, vidx); ib += 4;
      }
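      // Advance the base pooling index past the 8 rows just processed.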
      vidx0 = vaddq_u32(vidx0, v8);
    }

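    // Final pass: reduce the remaining k (1 to 8) pooling elements and write
    // the pooled maxima and argmax indices directly to output and index.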
    float* o = output;
    uint32_t* i = index;
    {
      const float* i0 = input[0];
      const float* i1 = input[1];
      const float* i2 = input[2];
      const float* i3 = input[3];
      const float* i4 = input[4];
      const float* i5 = input[5];
      const float* i6 = input[6];
      const float* i7 = input[7];
      i0 = (const float*) ((uintptr_t) i0 + input_offset);
      i1 = (const float*) ((uintptr_t) i1 + input_offset);
      i2 = (const float*) ((uintptr_t) i2 + input_offset);
      i3 = (const float*) ((uintptr_t) i3 + input_offset);
      i4 = (const float*) ((uintptr_t) i4 + input_offset);
      i5 = (const float*) ((uintptr_t) i5 + input_offset);
      i6 = (const float*) ((uintptr_t) i6 + input_offset);
      i7 = (const float*) ((uintptr_t) i7 + input_offset);
      input = (const float**) ((uintptr_t) input + input_increment);
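      // Alias pointers for pooling elements beyond the remaining k to row 0.
      // Re-reading row 0 cannot win the strict greater-than comparison, so
      // it leaves the running maximum and index unchanged.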
      if (k < 2) {
        i1 = i0;
      }
      if (k <= 2) {
        i2 = i0;
      }
      if (k < 4) {
        i3 = i0;
      }
      if (k <= 4) {
        i4 = i0;
      }
      if (k < 6) {
        i5 = i0;
      }
      if (k <= 6) {
        i6 = i0;
      }
      if (k != 8) {
        i7 = i0;
      }

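      // Combine the buffered per-channel maxima with the final rows, four
      // channels at a time, storing results straight to the output arrays.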
      size_t c = channels;
      float* ab = accumulation_buffer;
      uint32_t* ib = index_buffer;
      for (; c >= 4; c -= 4) {
        const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
        const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
        const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
        const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
        const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
        const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
        const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
        const float32x4_t vi7 = vld1q_f32(i7); i7 += 4;

        float32x4_t vmax = vld1q_f32(ab); ab += 4;
        uint32x4_t vidx = vld1q_u32(ib); ib += 4;

        const uint32x4_t vm0 = vcgtq_f32(vi0, vmax);
        vmax = vbslq_f32(vm0, vi0, vmax);
        vidx = vbslq_u32(vm0, vidx0, vidx);

        const uint32x4_t vm1 = vcgtq_f32(vi1, vmax);
        const uint32x4_t vidx1 = vaddq_u32(vidx0, v1);
        vmax = vbslq_f32(vm1, vi1, vmax);
        vidx = vbslq_u32(vm1, vidx1, vidx);

        const uint32x4_t vm2 = vcgtq_f32(vi2, vmax);
        const uint32x4_t vidx2 = vaddq_u32(vidx1, v1);
        vmax = vbslq_f32(vm2, vi2, vmax);
        vidx = vbslq_u32(vm2, vidx2, vidx);

        const uint32x4_t vm3 = vcgtq_f32(vi3, vmax);
        const uint32x4_t vidx3 = vaddq_u32(vidx2, v1);
        vmax = vbslq_f32(vm3, vi3, vmax);
        vidx = vbslq_u32(vm3, vidx3, vidx);

        const uint32x4_t vm4 = vcgtq_f32(vi4, vmax);
        const uint32x4_t vidx4 = vaddq_u32(vidx3, v1);
        vmax = vbslq_f32(vm4, vi4, vmax);
        vidx = vbslq_u32(vm4, vidx4, vidx);

        const uint32x4_t vm5 = vcgtq_f32(vi5, vmax);
        const uint32x4_t vidx5 = vaddq_u32(vidx4, v1);
        vmax = vbslq_f32(vm5, vi5, vmax);
        vidx = vbslq_u32(vm5, vidx5, vidx);

        const uint32x4_t vm6 = vcgtq_f32(vi6, vmax);
        const uint32x4_t vidx6 = vaddq_u32(vidx5, v1);
        vmax = vbslq_f32(vm6, vi6, vmax);
        vidx = vbslq_u32(vm6, vidx6, vidx);

        const uint32x4_t vm7 = vcgtq_f32(vi7, vmax);
        const uint32x4_t vidx7 = vaddq_u32(vidx6, v1);
        vmax = vbslq_f32(vm7, vi7, vmax);
        vidx = vbslq_u32(vm7, vidx7, vidx);

        vst1q_f32(o, vmax); o += 4;
        vst1q_u32(i, vidx); i += 4;
      }
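      // Remainder of 1-3 channels: compute a full 4-lane result, then store
      // only the valid lanes.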
      if (c != 0) {
        const float32x4_t vi0 = vld1q_f32(i0);
        const float32x4_t vi1 = vld1q_f32(i1);
        const float32x4_t vi2 = vld1q_f32(i2);
        const float32x4_t vi3 = vld1q_f32(i3);
        const float32x4_t vi4 = vld1q_f32(i4);
        const float32x4_t vi5 = vld1q_f32(i5);
        const float32x4_t vi6 = vld1q_f32(i6);
        const float32x4_t vi7 = vld1q_f32(i7);

        float32x4_t vmax = vld1q_f32(ab);
        uint32x4_t vidx = vld1q_u32(ib);

        const uint32x4_t vm0 = vcgtq_f32(vi0, vmax);
        vmax = vbslq_f32(vm0, vi0, vmax);
        vidx = vbslq_u32(vm0, vidx0, vidx);

        const uint32x4_t vm1 = vcgtq_f32(vi1, vmax);
        const uint32x4_t vidx1 = vaddq_u32(vidx0, v1);
        vmax = vbslq_f32(vm1, vi1, vmax);
        vidx = vbslq_u32(vm1, vidx1, vidx);

        const uint32x4_t vm2 = vcgtq_f32(vi2, vmax);
        const uint32x4_t vidx2 = vaddq_u32(vidx1, v1);
        vmax = vbslq_f32(vm2, vi2, vmax);
        vidx = vbslq_u32(vm2, vidx2, vidx);

        const uint32x4_t vm3 = vcgtq_f32(vi3, vmax);
        const uint32x4_t vidx3 = vaddq_u32(vidx2, v1);
        vmax = vbslq_f32(vm3, vi3, vmax);
        vidx = vbslq_u32(vm3, vidx3, vidx);

        const uint32x4_t vm4 = vcgtq_f32(vi4, vmax);
        const uint32x4_t vidx4 = vaddq_u32(vidx3, v1);
        vmax = vbslq_f32(vm4, vi4, vmax);
        vidx = vbslq_u32(vm4, vidx4, vidx);

        const uint32x4_t vm5 = vcgtq_f32(vi5, vmax);
        const uint32x4_t vidx5 = vaddq_u32(vidx4, v1);
        vmax = vbslq_f32(vm5, vi5, vmax);
        vidx = vbslq_u32(vm5, vidx5, vidx);

        const uint32x4_t vm6 = vcgtq_f32(vi6, vmax);
        const uint32x4_t vidx6 = vaddq_u32(vidx5, v1);
        vmax = vbslq_f32(vm6, vi6, vmax);
        vidx = vbslq_u32(vm6, vidx6, vidx);

        const uint32x4_t vm7 = vcgtq_f32(vi7, vmax);
        const uint32x4_t vidx7 = vaddq_u32(vidx6, v1);
        vmax = vbslq_f32(vm7, vi7, vmax);
        vidx = vbslq_u32(vm7, vidx7, vidx);

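        // Store two lanes, then one, as the low bits of the remainder count
        // dictate.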
        float32x2_t vmax_lo = vget_low_f32(vmax);
        uint32x2_t vidx_lo = vget_low_u32(vidx);
        if (c & 2) {
          vst1_f32(o, vmax_lo); o += 2;
          vst1_u32(i, vidx_lo); i += 2;
          vmax_lo = vget_high_f32(vmax);
          vidx_lo = vget_high_u32(vidx);
        }
        if (c & 1) {
          vst1_lane_f32(o, vmax_lo, 0); o += 1;
          vst1_lane_u32(i, vidx_lo, 0); i += 1;
        }
      }
    }

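    // Advance to the next output pixel: output moves by an extra increment,
    // while index advances only by the channels just written.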
    output = (float*) ((uintptr_t) o + output_increment);
    index = (uint32_t*) i;
  } while (--output_pixels != 0);
}