• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 
8 #include <xnnpack/argmaxpool.h>
9 #include <xnnpack/math.h>
10 
11 
xnn_f32_argmaxpool_ukernel_4x__scalar_c1(size_t output_pixels,size_t pooling_elements,size_t channels,const float ** input,size_t input_offset,float * output,uint32_t * index,size_t input_increment,size_t output_increment,const union xnn_f32_output_params params[restrict static1])12 void xnn_f32_argmaxpool_ukernel_4x__scalar_c1(
13     size_t output_pixels,
14     size_t pooling_elements,
15     size_t channels,
16     const float** input,
17     size_t input_offset,
18     float* output,
19     uint32_t* index,
20     size_t input_increment,
21     size_t output_increment,
22     const union xnn_f32_output_params params[restrict static 1])
23 {
24   assert(output_pixels != 0);
25   assert(pooling_elements != 0);
26   assert(pooling_elements <= 4);
27   assert(channels != 0);
28 
29   const float voutput_max = params->scalar.max;
30   const float voutput_min = params->scalar.min;
31   do {
32     const float* i0 = input[0];
33     const float* i1 = input[1];
34     const float* i2 = input[2];
35     const float* i3 = input[3];
36     i0 = (const float*) ((uintptr_t) i0 + input_offset);
37     i1 = (const float*) ((uintptr_t) i1 + input_offset);
38     i2 = (const float*) ((uintptr_t) i2 + input_offset);
39     i3 = (const float*) ((uintptr_t) i3 + input_offset);
40     if (pooling_elements < 2) {
41       i1 = i0;
42     }
43     if (pooling_elements <= 2) {
44       i2 = i0;
45     }
46     if (pooling_elements != 4) {
47       i3 = i0;
48     }
49 
50     size_t c = channels;
51     do {
52       const float vi0 = *i0++;
53       const float vi1 = *i1++;
54       const float vi2 = *i2++;
55       const float vi3 = *i3++;
56 
57       float vmax = vi0;
58       uint32_t vidx = 0;
59 
60       if (vi1 > vmax) {
61         vmax = vi1;
62         vidx = 1;
63       }
64 
65       if (vi2 > vmax) {
66         vmax = vi2;
67         vidx = 2;
68       }
69 
70       if (vi3 > vmax) {
71         vmax = vi3;
72         vidx = 3;
73       }
74 
75       const float vout = math_max_f32(math_min_f32(vmax, voutput_max), voutput_min);
76 
77       *output++ = vout;
78       *index++ = vidx;
79     } while (--c != 0);
80     input = (const float**) ((uintptr_t) input + input_increment);
81     output = (float*) ((uintptr_t) output + output_increment);
82   } while (--output_pixels != 0);
83 }
84