1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7
8 #include <xnnpack/argmaxpool.h>
9 #include <xnnpack/math.h>
10
11
xnn_f32_argmaxpool_ukernel_4x__scalar_c1(size_t output_pixels,size_t pooling_elements,size_t channels,const float ** input,size_t input_offset,float * output,uint32_t * index,size_t input_increment,size_t output_increment,const union xnn_f32_output_params params[restrict static1])12 void xnn_f32_argmaxpool_ukernel_4x__scalar_c1(
13 size_t output_pixels,
14 size_t pooling_elements,
15 size_t channels,
16 const float** input,
17 size_t input_offset,
18 float* output,
19 uint32_t* index,
20 size_t input_increment,
21 size_t output_increment,
22 const union xnn_f32_output_params params[restrict static 1])
23 {
24 assert(output_pixels != 0);
25 assert(pooling_elements != 0);
26 assert(pooling_elements <= 4);
27 assert(channels != 0);
28
29 const float voutput_max = params->scalar.max;
30 const float voutput_min = params->scalar.min;
31 do {
32 const float* i0 = input[0];
33 const float* i1 = input[1];
34 const float* i2 = input[2];
35 const float* i3 = input[3];
36 i0 = (const float*) ((uintptr_t) i0 + input_offset);
37 i1 = (const float*) ((uintptr_t) i1 + input_offset);
38 i2 = (const float*) ((uintptr_t) i2 + input_offset);
39 i3 = (const float*) ((uintptr_t) i3 + input_offset);
40 if (pooling_elements < 2) {
41 i1 = i0;
42 }
43 if (pooling_elements <= 2) {
44 i2 = i0;
45 }
46 if (pooling_elements != 4) {
47 i3 = i0;
48 }
49
50 size_t c = channels;
51 do {
52 const float vi0 = *i0++;
53 const float vi1 = *i1++;
54 const float vi2 = *i2++;
55 const float vi3 = *i3++;
56
57 float vmax = vi0;
58 uint32_t vidx = 0;
59
60 if (vi1 > vmax) {
61 vmax = vi1;
62 vidx = 1;
63 }
64
65 if (vi2 > vmax) {
66 vmax = vi2;
67 vidx = 2;
68 }
69
70 if (vi3 > vmax) {
71 vmax = vi3;
72 vidx = 3;
73 }
74
75 const float vout = math_max_f32(math_min_f32(vmax, voutput_max), voutput_min);
76
77 *output++ = vout;
78 *index++ = vidx;
79 } while (--c != 0);
80 input = (const float**) ((uintptr_t) input + input_increment);
81 output = (float*) ((uintptr_t) output + output_increment);
82 } while (--output_pixels != 0);
83 }
84