• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 
8 #include <xnnpack/argmaxpool.h>
9 #include <xnnpack/math.h>
10 
11 
xnn_f32_argmaxpool_ukernel_9p8x__scalar_c1(size_t output_pixels,size_t pooling_elements,size_t channels,const float ** input,size_t input_offset,float * accumulation_buffer,uint32_t * index_buffer,float * output,uint32_t * index,size_t input_increment,size_t output_increment,const union xnn_f32_output_params params[restrict static1])12 void xnn_f32_argmaxpool_ukernel_9p8x__scalar_c1(
13     size_t output_pixels,
14     size_t pooling_elements,
15     size_t channels,
16     const float** input,
17     size_t input_offset,
18     float* accumulation_buffer,
19     uint32_t* index_buffer,
20     float* output,
21     uint32_t* index,
22     size_t input_increment,
23     size_t output_increment,
24     const union xnn_f32_output_params params[restrict static 1])
25 {
26   assert(output_pixels != 0);
27   assert(pooling_elements != 0);
28   assert(pooling_elements > 9);
29   assert(channels != 0);
30 
31   const float voutput_max = params->scalar.max;
32   const float voutput_min = params->scalar.min;
33   do {
34     {
35       float* ab = accumulation_buffer;
36       uint32_t* ib = index_buffer;
37 
38       const float* i0 = *input++;
39       const float* i1 = *input++;
40       const float* i2 = *input++;
41       const float* i3 = *input++;
42       const float* i4 = *input++;
43       const float* i5 = *input++;
44       const float* i6 = *input++;
45       const float* i7 = *input++;
46       const float* i8 = *input++;
47       i0 = (const float*) ((uintptr_t) i0 + input_offset);
48       i1 = (const float*) ((uintptr_t) i1 + input_offset);
49       i2 = (const float*) ((uintptr_t) i2 + input_offset);
50       i3 = (const float*) ((uintptr_t) i3 + input_offset);
51       i4 = (const float*) ((uintptr_t) i4 + input_offset);
52       i5 = (const float*) ((uintptr_t) i5 + input_offset);
53       i6 = (const float*) ((uintptr_t) i6 + input_offset);
54       i7 = (const float*) ((uintptr_t) i7 + input_offset);
55       i8 = (const float*) ((uintptr_t) i8 + input_offset);
56 
57       size_t c = channels;
58       do {
59         const float vi0 = *i0++;
60         const float vi1 = *i1++;
61         const float vi2 = *i2++;
62         const float vi3 = *i3++;
63         const float vi4 = *i4++;
64         const float vi5 = *i5++;
65         const float vi6 = *i6++;
66         const float vi7 = *i7++;
67         const float vi8 = *i8++;
68 
69         float vmax = vi0;
70         uint32_t vidx = 0;
71 
72         if (vi1 > vmax) {
73           vmax = vi1;
74           vidx = 1;
75         }
76 
77         if (vi2 > vmax) {
78           vmax = vi2;
79           vidx = 2;
80         }
81 
82         if (vi3 > vmax) {
83           vmax = vi3;
84           vidx = 3;
85         }
86 
87         if (vi4 > vmax) {
88           vmax = vi4;
89           vidx = 4;
90         }
91 
92         if (vi5 > vmax) {
93           vmax = vi5;
94           vidx = 5;
95         }
96 
97         if (vi6 > vmax) {
98           vmax = vi6;
99           vidx = 6;
100         }
101 
102         if (vi7 > vmax) {
103           vmax = vi7;
104           vidx = 7;
105         }
106 
107         if (vi8 > vmax) {
108           vmax = vi8;
109           vidx = 8;
110         }
111 
112         *ab++ = vmax;
113         *ib++ = vidx;
114       } while (--c != 0);
115     }
116     uint32_t vidx0 = 9;
117     size_t k = pooling_elements;
118     for (k -= 9; k > 8; k -= 8) {
119       const float* i0 = *input++;
120       const float* i1 = *input++;
121       const float* i2 = *input++;
122       const float* i3 = *input++;
123       const float* i4 = *input++;
124       const float* i5 = *input++;
125       const float* i6 = *input++;
126       const float* i7 = *input++;
127       i0 = (const float*) ((uintptr_t) i0 + input_offset);
128       i1 = (const float*) ((uintptr_t) i1 + input_offset);
129       i2 = (const float*) ((uintptr_t) i2 + input_offset);
130       i3 = (const float*) ((uintptr_t) i3 + input_offset);
131       i4 = (const float*) ((uintptr_t) i4 + input_offset);
132       i5 = (const float*) ((uintptr_t) i5 + input_offset);
133       i6 = (const float*) ((uintptr_t) i6 + input_offset);
134       i7 = (const float*) ((uintptr_t) i7 + input_offset);
135 
136       float* ab = accumulation_buffer;
137       uint32_t* ib = index_buffer;
138 
139       size_t c = channels;
140       do {
141         const float vi0 = *i0++;
142         const float vi1 = *i1++;
143         const float vi2 = *i2++;
144         const float vi3 = *i3++;
145         const float vi4 = *i4++;
146         const float vi5 = *i5++;
147         const float vi6 = *i6++;
148         const float vi7 = *i7++;
149 
150         float vmax = *ab;
151         uint32_t vidx = *ib;
152 
153         if (vi0 > vmax) {
154           vmax = vi0;
155           vidx = vidx0;
156         }
157 
158         if (vi1 > vmax) {
159           vmax = vi1;
160           vidx = vidx0 + 1;
161         }
162 
163         if (vi2 > vmax) {
164           vmax = vi2;
165           vidx = vidx0 + 2;
166         }
167 
168         if (vi3 > vmax) {
169           vmax = vi3;
170           vidx = vidx0 + 3;
171         }
172 
173         if (vi4 > vmax) {
174           vmax = vi4;
175           vidx = vidx0 + 4;
176         }
177 
178         if (vi5 > vmax) {
179           vmax = vi5;
180           vidx = vidx0 + 5;
181         }
182 
183         if (vi6 > vmax) {
184           vmax = vi6;
185           vidx = vidx0 + 6;
186         }
187 
188         if (vi7 > vmax) {
189           vmax = vi7;
190           vidx = vidx0 + 7;
191         }
192 
193         *ab++ = vmax;
194         *ib++ = vidx;
195       } while (--c != 0);
196       vidx0 += 8;
197     }
198 
199     float* o = output;
200     uint32_t* i = index;
201     {
202       const float* i0 = input[0];
203       const float* i1 = input[1];
204       const float* i2 = input[2];
205       const float* i3 = input[3];
206       const float* i4 = input[4];
207       const float* i5 = input[5];
208       const float* i6 = input[6];
209       const float* i7 = input[7];
210       i0 = (const float*) ((uintptr_t) i0 + input_offset);
211       i1 = (const float*) ((uintptr_t) i1 + input_offset);
212       i2 = (const float*) ((uintptr_t) i2 + input_offset);
213       i3 = (const float*) ((uintptr_t) i3 + input_offset);
214       i4 = (const float*) ((uintptr_t) i4 + input_offset);
215       i5 = (const float*) ((uintptr_t) i5 + input_offset);
216       i6 = (const float*) ((uintptr_t) i6 + input_offset);
217       i7 = (const float*) ((uintptr_t) i7 + input_offset);
218       input = (const float**) ((uintptr_t) input + input_increment);
219       if (k < 2) {
220         i1 = i0;
221       }
222       if (k <= 2) {
223         i2 = i0;
224       }
225       if (k < 4) {
226         i3 = i0;
227       }
228       if (k <= 4) {
229         i4 = i0;
230       }
231       if (k < 6) {
232         i5 = i0;
233       }
234       if (k <= 6) {
235         i6 = i0;
236       }
237       if (k != 8) {
238         i7 = i0;
239       }
240 
241       size_t c = channels;
242       float* ab = accumulation_buffer;
243       uint32_t* ib = index_buffer;
244       do {
245         const float vi0 = *i0++;
246         const float vi1 = *i1++;
247         const float vi2 = *i2++;
248         const float vi3 = *i3++;
249         const float vi4 = *i4++;
250         const float vi5 = *i5++;
251         const float vi6 = *i6++;
252         const float vi7 = *i7++;
253 
254         float vmax = *ab++;
255         uint32_t vidx = *ib++;
256 
257         if (vi0 > vmax) {
258           vmax = vi0;
259           vidx = vidx0;
260         }
261 
262         if (vi1 > vmax) {
263           vmax = vi1;
264           vidx = vidx0 + 1;
265         }
266 
267         if (vi2 > vmax) {
268           vmax = vi2;
269           vidx = vidx0 + 2;
270         }
271 
272         if (vi3 > vmax) {
273           vmax = vi3;
274           vidx = vidx0 + 3;
275         }
276 
277         if (vi4 > vmax) {
278           vmax = vi4;
279           vidx = vidx0 + 4;
280         }
281 
282         if (vi5 > vmax) {
283           vmax = vi5;
284           vidx = vidx0 + 5;
285         }
286 
287         if (vi6 > vmax) {
288           vmax = vi6;
289           vidx = vidx0 + 6;
290         }
291 
292         if (vi7 > vmax) {
293           vmax = vi7;
294           vidx = vidx0 + 7;
295         }
296 
297         const float vout = math_max_f32(math_min_f32(vmax, voutput_max), voutput_min);
298 
299         *o++ = vout;
300         *i++ = vidx;
301       } while (--c != 0);
302     }
303 
304     output = (float*) ((uintptr_t) o + output_increment);
305     index = (uint32_t*) i;
306   } while (--output_pixels != 0);
307 }
308