// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/argmaxpool.h>

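// f32 argmax pooling micro-kernel (9p8x variant, NEON, 4 channels per vector):
// for every output pixel, a first pass over 9 pooling elements seeds a running
// per-channel maximum and its pooling index in accumulation_buffer and
// index_buffer, passes of 8 further elements update them in place, and a final
// pass over the remaining 1-8 elements writes the maxima to `output` and their
// indices to `index`.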
void xnn_f32_argmaxpool_ukernel_9p8x__neon_c4(
    size_t output_pixels,
    size_t pooling_elements,
    size_t channels,
    const float** input,
    size_t input_offset,
    float* accumulation_buffer,
    uint32_t* index_buffer,
    float* output,
    uint32_t* index,
    size_t input_increment,
    size_t output_increment) XNN_DISABLE_TSAN
{
  assert(output_pixels != 0);
  assert(pooling_elements != 0);
  assert(pooling_elements > 9);
  assert(channels != 0);

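  // Each iteration of the outer loop produces one output pixel.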
  do {
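    // First pass: scan pooling elements 0-8, recording the per-channel running
    // maximum and the pooling index at which it occurred in the accumulation
    // and index buffers.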
    {
      float* ab = accumulation_buffer;
      uint32_t* ib = index_buffer;

      const float* i0 = *input++;
      const float* i1 = *input++;
      const float* i2 = *input++;
      const float* i3 = *input++;
      const float* i4 = *input++;
      const float* i5 = *input++;
      const float* i6 = *input++;
      const float* i7 = *input++;
      const float* i8 = *input++;
      i0 = (const float*) ((uintptr_t) i0 + input_offset);
      i1 = (const float*) ((uintptr_t) i1 + input_offset);
      i2 = (const float*) ((uintptr_t) i2 + input_offset);
      i3 = (const float*) ((uintptr_t) i3 + input_offset);
      i4 = (const float*) ((uintptr_t) i4 + input_offset);
      i5 = (const float*) ((uintptr_t) i5 + input_offset);
      i6 = (const float*) ((uintptr_t) i6 + input_offset);
      i7 = (const float*) ((uintptr_t) i7 + input_offset);
      i8 = (const float*) ((uintptr_t) i8 + input_offset);

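      // 4 channels per iteration: element 0 seeds the running maximum with
      // index 0; each later element replaces it only where strictly greater.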
      for (size_t c = 0; c < channels; c += 4) {
        const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
        const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
        const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
        const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
        const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
        const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
        const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
        const float32x4_t vi7 = vld1q_f32(i7); i7 += 4;
        const float32x4_t vi8 = vld1q_f32(i8); i8 += 4;

        float32x4_t vmax = vi0;
        uint32x4_t vidx = vmovq_n_u32(0);

        const uint32x4_t vm1 = vcgtq_f32(vi1, vmax);
        vmax = vbslq_f32(vm1, vi1, vmax);
        vidx = vbslq_u32(vm1, vmovq_n_u32(1), vidx);

        const uint32x4_t vm2 = vcgtq_f32(vi2, vmax);
        vmax = vbslq_f32(vm2, vi2, vmax);
        vidx = vbslq_u32(vm2, vmovq_n_u32(2), vidx);

        const uint32x4_t vm3 = vcgtq_f32(vi3, vmax);
        vmax = vbslq_f32(vm3, vi3, vmax);
        vidx = vbslq_u32(vm3, vmovq_n_u32(3), vidx);

        const uint32x4_t vm4 = vcgtq_f32(vi4, vmax);
        vmax = vbslq_f32(vm4, vi4, vmax);
        vidx = vbslq_u32(vm4, vmovq_n_u32(4), vidx);

        const uint32x4_t vm5 = vcgtq_f32(vi5, vmax);
        vmax = vbslq_f32(vm5, vi5, vmax);
        vidx = vbslq_u32(vm5, vmovq_n_u32(5), vidx);

        const uint32x4_t vm6 = vcgtq_f32(vi6, vmax);
        vmax = vbslq_f32(vm6, vi6, vmax);
        vidx = vbslq_u32(vm6, vmovq_n_u32(6), vidx);

        const uint32x4_t vm7 = vcgtq_f32(vi7, vmax);
        vmax = vbslq_f32(vm7, vi7, vmax);
        vidx = vbslq_u32(vm7, vmovq_n_u32(7), vidx);

        const uint32x4_t vm8 = vcgtq_f32(vi8, vmax);
        vmax = vbslq_f32(vm8, vi8, vmax);
        vidx = vbslq_u32(vm8, vmovq_n_u32(8), vidx);

        vst1q_f32(ab, vmax); ab += 4;
        vst1q_u32(ib, vidx); ib += 4;
      }
    }
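    // vidx0 is the pooling index of the first element consumed by the next
    // pass; it starts at 1 + 8 = 9 because the first pass covered indices 0-8.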
    const uint32x4_t v1 = vmovq_n_u32(1);
    const uint32x4_t v8 = vmovq_n_u32(8);
    uint32x4_t vidx0 = vaddq_u32(v1, v8);

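    // Intermediate passes: while more than 8 pooling elements remain, fold 8
    // more of them into the buffered running maximum and index.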
    size_t k = pooling_elements;
    for (k -= 9; k > 8; k -= 8) {
      const float* i0 = *input++;
      const float* i1 = *input++;
      const float* i2 = *input++;
      const float* i3 = *input++;
      const float* i4 = *input++;
      const float* i5 = *input++;
      const float* i6 = *input++;
      const float* i7 = *input++;
      i0 = (const float*) ((uintptr_t) i0 + input_offset);
      i1 = (const float*) ((uintptr_t) i1 + input_offset);
      i2 = (const float*) ((uintptr_t) i2 + input_offset);
      i3 = (const float*) ((uintptr_t) i3 + input_offset);
      i4 = (const float*) ((uintptr_t) i4 + input_offset);
      i5 = (const float*) ((uintptr_t) i5 + input_offset);
      i6 = (const float*) ((uintptr_t) i6 + input_offset);
      i7 = (const float*) ((uintptr_t) i7 + input_offset);

      float* ab = accumulation_buffer;
      uint32_t* ib = index_buffer;

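      // Reload the running maximum/index for this channel group and fold in
      // the next 8 pooling elements.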
      for (size_t c = 0; c < channels; c += 4) {
        const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
        const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
        const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
        const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
        const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
        const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
        const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
        const float32x4_t vi7 = vld1q_f32(i7); i7 += 4;

        float32x4_t vmax = vld1q_f32(ab);
        uint32x4_t vidx = vld1q_u32(ib);

        const uint32x4_t vm0 = vcgtq_f32(vi0, vmax);
        vmax = vbslq_f32(vm0, vi0, vmax);
        vidx = vbslq_u32(vm0, vidx0, vidx);

        const uint32x4_t vm1 = vcgtq_f32(vi1, vmax);
        const uint32x4_t vidx1 = vaddq_u32(vidx0, v1);
        vmax = vbslq_f32(vm1, vi1, vmax);
        vidx = vbslq_u32(vm1, vidx1, vidx);

        const uint32x4_t vm2 = vcgtq_f32(vi2, vmax);
        const uint32x4_t vidx2 = vaddq_u32(vidx1, v1);
        vmax = vbslq_f32(vm2, vi2, vmax);
        vidx = vbslq_u32(vm2, vidx2, vidx);

        const uint32x4_t vm3 = vcgtq_f32(vi3, vmax);
        const uint32x4_t vidx3 = vaddq_u32(vidx2, v1);
        vmax = vbslq_f32(vm3, vi3, vmax);
        vidx = vbslq_u32(vm3, vidx3, vidx);

        const uint32x4_t vm4 = vcgtq_f32(vi4, vmax);
        const uint32x4_t vidx4 = vaddq_u32(vidx3, v1);
        vmax = vbslq_f32(vm4, vi4, vmax);
        vidx = vbslq_u32(vm4, vidx4, vidx);

        const uint32x4_t vm5 = vcgtq_f32(vi5, vmax);
        const uint32x4_t vidx5 = vaddq_u32(vidx4, v1);
        vmax = vbslq_f32(vm5, vi5, vmax);
        vidx = vbslq_u32(vm5, vidx5, vidx);

        const uint32x4_t vm6 = vcgtq_f32(vi6, vmax);
        const uint32x4_t vidx6 = vaddq_u32(vidx5, v1);
        vmax = vbslq_f32(vm6, vi6, vmax);
        vidx = vbslq_u32(vm6, vidx6, vidx);

        const uint32x4_t vm7 = vcgtq_f32(vi7, vmax);
        const uint32x4_t vidx7 = vaddq_u32(vidx6, v1);
        vmax = vbslq_f32(vm7, vi7, vmax);
        vidx = vbslq_u32(vm7, vidx7, vidx);

        vst1q_f32(ab, vmax); ab += 4;
        vst1q_u32(ib, vidx); ib += 4;
      }
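      // Advance the base pooling index by the 8 elements just consumed.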
      vidx0 = vaddq_u32(vidx0, v8);
    }

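    // Final pass: combine the last 1-8 pooling elements with the buffered
    // running maximum/index and write the results to `output` and `index`.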
    float* o = output;
    uint32_t* i = index;
    {
      const float* i0 = input[0];
      const float* i1 = input[1];
      const float* i2 = input[2];
      const float* i3 = input[3];
      const float* i4 = input[4];
      const float* i5 = input[5];
      const float* i6 = input[6];
      const float* i7 = input[7];
      i0 = (const float*) ((uintptr_t) i0 + input_offset);
      i1 = (const float*) ((uintptr_t) i1 + input_offset);
      i2 = (const float*) ((uintptr_t) i2 + input_offset);
      i3 = (const float*) ((uintptr_t) i3 + input_offset);
      i4 = (const float*) ((uintptr_t) i4 + input_offset);
      i5 = (const float*) ((uintptr_t) i5 + input_offset);
      i6 = (const float*) ((uintptr_t) i6 + input_offset);
      i7 = (const float*) ((uintptr_t) i7 + input_offset);
      input = (const float**) ((uintptr_t) input + input_increment);
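      // With fewer than 8 elements left, alias the unused input pointers to i0.
      // Re-reading i0 is harmless: the comparison is strict (>), so a duplicate
      // can never displace the maximum already recorded at a lower index.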
      if (k < 2) {
        i1 = i0;
      }
      if (k <= 2) {
        i2 = i0;
      }
      if (k < 4) {
        i3 = i0;
      }
      if (k <= 4) {
        i4 = i0;
      }
      if (k < 6) {
        i5 = i0;
      }
      if (k <= 6) {
        i6 = i0;
      }
      if (k < 8) {
        i7 = i0;
      }

      size_t c = channels;
      float* ab = accumulation_buffer;
      uint32_t* ib = index_buffer;
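      // Full groups of 4 channels go straight to the output and index arrays.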
      for (; c >= 4; c -= 4) {
        const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
        const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
        const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
        const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
        const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
        const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
        const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
        const float32x4_t vi7 = vld1q_f32(i7); i7 += 4;

        float32x4_t vmax = vld1q_f32(ab); ab += 4;
        uint32x4_t vidx = vld1q_u32(ib); ib += 4;

        const uint32x4_t vm0 = vcgtq_f32(vi0, vmax);
        vmax = vbslq_f32(vm0, vi0, vmax);
        vidx = vbslq_u32(vm0, vidx0, vidx);

        const uint32x4_t vm1 = vcgtq_f32(vi1, vmax);
        const uint32x4_t vidx1 = vaddq_u32(vidx0, v1);
        vmax = vbslq_f32(vm1, vi1, vmax);
        vidx = vbslq_u32(vm1, vidx1, vidx);

        const uint32x4_t vm2 = vcgtq_f32(vi2, vmax);
        const uint32x4_t vidx2 = vaddq_u32(vidx1, v1);
        vmax = vbslq_f32(vm2, vi2, vmax);
        vidx = vbslq_u32(vm2, vidx2, vidx);

        const uint32x4_t vm3 = vcgtq_f32(vi3, vmax);
        const uint32x4_t vidx3 = vaddq_u32(vidx2, v1);
        vmax = vbslq_f32(vm3, vi3, vmax);
        vidx = vbslq_u32(vm3, vidx3, vidx);

        const uint32x4_t vm4 = vcgtq_f32(vi4, vmax);
        const uint32x4_t vidx4 = vaddq_u32(vidx3, v1);
        vmax = vbslq_f32(vm4, vi4, vmax);
        vidx = vbslq_u32(vm4, vidx4, vidx);

        const uint32x4_t vm5 = vcgtq_f32(vi5, vmax);
        const uint32x4_t vidx5 = vaddq_u32(vidx4, v1);
        vmax = vbslq_f32(vm5, vi5, vmax);
        vidx = vbslq_u32(vm5, vidx5, vidx);

        const uint32x4_t vm6 = vcgtq_f32(vi6, vmax);
        const uint32x4_t vidx6 = vaddq_u32(vidx5, v1);
        vmax = vbslq_f32(vm6, vi6, vmax);
        vidx = vbslq_u32(vm6, vidx6, vidx);

        const uint32x4_t vm7 = vcgtq_f32(vi7, vmax);
        const uint32x4_t vidx7 = vaddq_u32(vidx6, v1);
        vmax = vbslq_f32(vm7, vi7, vmax);
        vidx = vbslq_u32(vm7, vidx7, vidx);

        vst1q_f32(o, vmax); o += 4;
        vst1q_u32(i, vidx); i += 4;
      }
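      // Remainder of 1-3 channels: compute a full vector as above, then store
      // only the valid lanes.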
      if (c != 0) {
        const float32x4_t vi0 = vld1q_f32(i0);
        const float32x4_t vi1 = vld1q_f32(i1);
        const float32x4_t vi2 = vld1q_f32(i2);
        const float32x4_t vi3 = vld1q_f32(i3);
        const float32x4_t vi4 = vld1q_f32(i4);
        const float32x4_t vi5 = vld1q_f32(i5);
        const float32x4_t vi6 = vld1q_f32(i6);
        const float32x4_t vi7 = vld1q_f32(i7);

        float32x4_t vmax = vld1q_f32(ab);
        uint32x4_t vidx = vld1q_u32(ib);

        const uint32x4_t vm0 = vcgtq_f32(vi0, vmax);
        vmax = vbslq_f32(vm0, vi0, vmax);
        vidx = vbslq_u32(vm0, vidx0, vidx);

        const uint32x4_t vm1 = vcgtq_f32(vi1, vmax);
        const uint32x4_t vidx1 = vaddq_u32(vidx0, v1);
        vmax = vbslq_f32(vm1, vi1, vmax);
        vidx = vbslq_u32(vm1, vidx1, vidx);

        const uint32x4_t vm2 = vcgtq_f32(vi2, vmax);
        const uint32x4_t vidx2 = vaddq_u32(vidx1, v1);
        vmax = vbslq_f32(vm2, vi2, vmax);
        vidx = vbslq_u32(vm2, vidx2, vidx);

        const uint32x4_t vm3 = vcgtq_f32(vi3, vmax);
        const uint32x4_t vidx3 = vaddq_u32(vidx2, v1);
        vmax = vbslq_f32(vm3, vi3, vmax);
        vidx = vbslq_u32(vm3, vidx3, vidx);

        const uint32x4_t vm4 = vcgtq_f32(vi4, vmax);
        const uint32x4_t vidx4 = vaddq_u32(vidx3, v1);
        vmax = vbslq_f32(vm4, vi4, vmax);
        vidx = vbslq_u32(vm4, vidx4, vidx);

        const uint32x4_t vm5 = vcgtq_f32(vi5, vmax);
        const uint32x4_t vidx5 = vaddq_u32(vidx4, v1);
        vmax = vbslq_f32(vm5, vi5, vmax);
        vidx = vbslq_u32(vm5, vidx5, vidx);

        const uint32x4_t vm6 = vcgtq_f32(vi6, vmax);
        const uint32x4_t vidx6 = vaddq_u32(vidx5, v1);
        vmax = vbslq_f32(vm6, vi6, vmax);
        vidx = vbslq_u32(vm6, vidx6, vidx);

        const uint32x4_t vm7 = vcgtq_f32(vi7, vmax);
        const uint32x4_t vidx7 = vaddq_u32(vidx6, v1);
        vmax = vbslq_f32(vm7, vi7, vmax);
        vidx = vbslq_u32(vm7, vidx7, vidx);

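        // Store two lanes and/or one lane depending on how many channels remain.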
        float32x2_t vmax_lo = vget_low_f32(vmax);
        uint32x2_t vidx_lo = vget_low_u32(vidx);
        if (c & 2) {
          vst1_f32(o, vmax_lo); o += 2;
          vst1_u32(i, vidx_lo); i += 2;
          vmax_lo = vget_high_f32(vmax);
          vidx_lo = vget_high_u32(vidx);
        }
        if (c & 1) {
          vst1_lane_f32(o, vmax_lo, 0); o += 1;
          vst1_lane_u32(i, vidx_lo, 0); i += 1;
        }
      }
    }

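    // Advance the output and index pointers for the next output pixel.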
    output = (float*) ((uintptr_t) o + output_increment);
    index = (uint32_t*) i;
  } while (--output_pixels != 0);
}