1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7
8 #include <xnnpack/argmaxpool.h>
9 #include <xnnpack/math.h>
10
11
xnn_f32_argmaxpool_ukernel_9p8x__scalar_c1(size_t output_pixels,size_t pooling_elements,size_t channels,const float ** input,size_t input_offset,float * accumulation_buffer,uint32_t * index_buffer,float * output,uint32_t * index,size_t input_increment,size_t output_increment,const union xnn_f32_output_params params[restrict static1])12 void xnn_f32_argmaxpool_ukernel_9p8x__scalar_c1(
13 size_t output_pixels,
14 size_t pooling_elements,
15 size_t channels,
16 const float** input,
17 size_t input_offset,
18 float* accumulation_buffer,
19 uint32_t* index_buffer,
20 float* output,
21 uint32_t* index,
22 size_t input_increment,
23 size_t output_increment,
24 const union xnn_f32_output_params params[restrict static 1])
25 {
26 assert(output_pixels != 0);
27 assert(pooling_elements != 0);
28 assert(pooling_elements > 9);
29 assert(channels != 0);
30
31 const float voutput_max = params->scalar.max;
32 const float voutput_min = params->scalar.min;
33 do {
34 {
35 float* ab = accumulation_buffer;
36 uint32_t* ib = index_buffer;
37
38 const float* i0 = *input++;
39 const float* i1 = *input++;
40 const float* i2 = *input++;
41 const float* i3 = *input++;
42 const float* i4 = *input++;
43 const float* i5 = *input++;
44 const float* i6 = *input++;
45 const float* i7 = *input++;
46 const float* i8 = *input++;
47 i0 = (const float*) ((uintptr_t) i0 + input_offset);
48 i1 = (const float*) ((uintptr_t) i1 + input_offset);
49 i2 = (const float*) ((uintptr_t) i2 + input_offset);
50 i3 = (const float*) ((uintptr_t) i3 + input_offset);
51 i4 = (const float*) ((uintptr_t) i4 + input_offset);
52 i5 = (const float*) ((uintptr_t) i5 + input_offset);
53 i6 = (const float*) ((uintptr_t) i6 + input_offset);
54 i7 = (const float*) ((uintptr_t) i7 + input_offset);
55 i8 = (const float*) ((uintptr_t) i8 + input_offset);
56
57 size_t c = channels;
58 do {
59 const float vi0 = *i0++;
60 const float vi1 = *i1++;
61 const float vi2 = *i2++;
62 const float vi3 = *i3++;
63 const float vi4 = *i4++;
64 const float vi5 = *i5++;
65 const float vi6 = *i6++;
66 const float vi7 = *i7++;
67 const float vi8 = *i8++;
68
69 float vmax = vi0;
70 uint32_t vidx = 0;
71
72 if (vi1 > vmax) {
73 vmax = vi1;
74 vidx = 1;
75 }
76
77 if (vi2 > vmax) {
78 vmax = vi2;
79 vidx = 2;
80 }
81
82 if (vi3 > vmax) {
83 vmax = vi3;
84 vidx = 3;
85 }
86
87 if (vi4 > vmax) {
88 vmax = vi4;
89 vidx = 4;
90 }
91
92 if (vi5 > vmax) {
93 vmax = vi5;
94 vidx = 5;
95 }
96
97 if (vi6 > vmax) {
98 vmax = vi6;
99 vidx = 6;
100 }
101
102 if (vi7 > vmax) {
103 vmax = vi7;
104 vidx = 7;
105 }
106
107 if (vi8 > vmax) {
108 vmax = vi8;
109 vidx = 8;
110 }
111
112 *ab++ = vmax;
113 *ib++ = vidx;
114 } while (--c != 0);
115 }
116 uint32_t vidx0 = 9;
117 size_t k = pooling_elements;
118 for (k -= 9; k > 8; k -= 8) {
119 const float* i0 = *input++;
120 const float* i1 = *input++;
121 const float* i2 = *input++;
122 const float* i3 = *input++;
123 const float* i4 = *input++;
124 const float* i5 = *input++;
125 const float* i6 = *input++;
126 const float* i7 = *input++;
127 i0 = (const float*) ((uintptr_t) i0 + input_offset);
128 i1 = (const float*) ((uintptr_t) i1 + input_offset);
129 i2 = (const float*) ((uintptr_t) i2 + input_offset);
130 i3 = (const float*) ((uintptr_t) i3 + input_offset);
131 i4 = (const float*) ((uintptr_t) i4 + input_offset);
132 i5 = (const float*) ((uintptr_t) i5 + input_offset);
133 i6 = (const float*) ((uintptr_t) i6 + input_offset);
134 i7 = (const float*) ((uintptr_t) i7 + input_offset);
135
136 float* ab = accumulation_buffer;
137 uint32_t* ib = index_buffer;
138
139 size_t c = channels;
140 do {
141 const float vi0 = *i0++;
142 const float vi1 = *i1++;
143 const float vi2 = *i2++;
144 const float vi3 = *i3++;
145 const float vi4 = *i4++;
146 const float vi5 = *i5++;
147 const float vi6 = *i6++;
148 const float vi7 = *i7++;
149
150 float vmax = *ab;
151 uint32_t vidx = *ib;
152
153 if (vi0 > vmax) {
154 vmax = vi0;
155 vidx = vidx0;
156 }
157
158 if (vi1 > vmax) {
159 vmax = vi1;
160 vidx = vidx0 + 1;
161 }
162
163 if (vi2 > vmax) {
164 vmax = vi2;
165 vidx = vidx0 + 2;
166 }
167
168 if (vi3 > vmax) {
169 vmax = vi3;
170 vidx = vidx0 + 3;
171 }
172
173 if (vi4 > vmax) {
174 vmax = vi4;
175 vidx = vidx0 + 4;
176 }
177
178 if (vi5 > vmax) {
179 vmax = vi5;
180 vidx = vidx0 + 5;
181 }
182
183 if (vi6 > vmax) {
184 vmax = vi6;
185 vidx = vidx0 + 6;
186 }
187
188 if (vi7 > vmax) {
189 vmax = vi7;
190 vidx = vidx0 + 7;
191 }
192
193 *ab++ = vmax;
194 *ib++ = vidx;
195 } while (--c != 0);
196 vidx0 += 8;
197 }
198
199 float* o = output;
200 uint32_t* i = index;
201 {
202 const float* i0 = input[0];
203 const float* i1 = input[1];
204 const float* i2 = input[2];
205 const float* i3 = input[3];
206 const float* i4 = input[4];
207 const float* i5 = input[5];
208 const float* i6 = input[6];
209 const float* i7 = input[7];
210 i0 = (const float*) ((uintptr_t) i0 + input_offset);
211 i1 = (const float*) ((uintptr_t) i1 + input_offset);
212 i2 = (const float*) ((uintptr_t) i2 + input_offset);
213 i3 = (const float*) ((uintptr_t) i3 + input_offset);
214 i4 = (const float*) ((uintptr_t) i4 + input_offset);
215 i5 = (const float*) ((uintptr_t) i5 + input_offset);
216 i6 = (const float*) ((uintptr_t) i6 + input_offset);
217 i7 = (const float*) ((uintptr_t) i7 + input_offset);
218 input = (const float**) ((uintptr_t) input + input_increment);
219 if (k < 2) {
220 i1 = i0;
221 }
222 if (k <= 2) {
223 i2 = i0;
224 }
225 if (k < 4) {
226 i3 = i0;
227 }
228 if (k <= 4) {
229 i4 = i0;
230 }
231 if (k < 6) {
232 i5 = i0;
233 }
234 if (k <= 6) {
235 i6 = i0;
236 }
237 if (k != 8) {
238 i7 = i0;
239 }
240
241 size_t c = channels;
242 float* ab = accumulation_buffer;
243 uint32_t* ib = index_buffer;
244 do {
245 const float vi0 = *i0++;
246 const float vi1 = *i1++;
247 const float vi2 = *i2++;
248 const float vi3 = *i3++;
249 const float vi4 = *i4++;
250 const float vi5 = *i5++;
251 const float vi6 = *i6++;
252 const float vi7 = *i7++;
253
254 float vmax = *ab++;
255 uint32_t vidx = *ib++;
256
257 if (vi0 > vmax) {
258 vmax = vi0;
259 vidx = vidx0;
260 }
261
262 if (vi1 > vmax) {
263 vmax = vi1;
264 vidx = vidx0 + 1;
265 }
266
267 if (vi2 > vmax) {
268 vmax = vi2;
269 vidx = vidx0 + 2;
270 }
271
272 if (vi3 > vmax) {
273 vmax = vi3;
274 vidx = vidx0 + 3;
275 }
276
277 if (vi4 > vmax) {
278 vmax = vi4;
279 vidx = vidx0 + 4;
280 }
281
282 if (vi5 > vmax) {
283 vmax = vi5;
284 vidx = vidx0 + 5;
285 }
286
287 if (vi6 > vmax) {
288 vmax = vi6;
289 vidx = vidx0 + 6;
290 }
291
292 if (vi7 > vmax) {
293 vmax = vi7;
294 vidx = vidx0 + 7;
295 }
296
297 const float vout = math_max_f32(math_min_f32(vmax, voutput_max), voutput_min);
298
299 *o++ = vout;
300 *i++ = vidx;
301 } while (--c != 0);
302 }
303
304 output = (float*) ((uintptr_t) o + output_increment);
305 index = (uint32_t*) i;
306 } while (--output_pixels != 0);
307 }
308