• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 
8 #include <xnnpack/argmaxpool.h>
9 #include <xnnpack/math.h>
10 
11 
// Max pooling with argmax over at most 4 pooling elements, processing one
// channel per inner-loop iteration (scalar, channel tile of 1).
//
// For each of `output_pixels` output pixels:
//   - `input` points at an array of pointers; entries 0..pooling_elements-1
//     are read, and each is displaced by `input_offset` bytes.
//   - For every channel c in [0, channels), output[c] receives the maximum of
//     the pooled values and index[c] receives the position of the first
//     pooling element attaining it. Only the strict `>` test is used, so:
//       * ties keep the earlier element;
//       * if element 0 is NaN, the result is NaN with index 0;
//       * a NaN in a later element is ignored.
//   - Then `input` advances by `input_increment` BYTES.
//   - `output` advances by `channels` floats plus `output_increment` bytes.
//   - `index` advances by `channels` entries only (no extra increment).
//
// Preconditions (asserted): output_pixels, pooling_elements and channels are
// non-zero, and pooling_elements <= 4.
void xnn_f32_argmaxpool_ukernel_4x__scalar_c1(
    size_t output_pixels,
    size_t pooling_elements,
    size_t channels,
    const float** input,
    size_t input_offset,
    float* output,
    uint32_t* index,
    size_t input_increment,
    size_t output_increment)
{
  assert(output_pixels != 0);
  assert(pooling_elements != 0);
  assert(pooling_elements <= 4);
  assert(channels != 0);

  do {
    // Dereference only the first `pooling_elements` pointer entries, so the
    // caller is not required to pad the pointer array to 4 entries.
    // Unused slots alias element 0. The running maximum starts at element 0's
    // value and never decreases, so an aliased slot never passes the strict
    // `>` test and cannot change either the value or the index.
    const float* i0 = (const float*) ((uintptr_t) input[0] + input_offset);
    const float* i1 = i0;
    const float* i2 = i0;
    const float* i3 = i0;
    if (pooling_elements >= 2) {
      i1 = (const float*) ((uintptr_t) input[1] + input_offset);
    }
    if (pooling_elements >= 3) {
      i2 = (const float*) ((uintptr_t) input[2] + input_offset);
    }
    if (pooling_elements >= 4) {
      i3 = (const float*) ((uintptr_t) input[3] + input_offset);
    }

    size_t c = channels;
    do {
      const float vi0 = *i0++;
      const float vi1 = *i1++;
      const float vi2 = *i2++;
      const float vi3 = *i3++;

      float vmax = vi0;
      uint32_t vidx = 0;

      if (vi1 > vmax) {
        vmax = vi1;
        vidx = 1;
      }

      if (vi2 > vmax) {
        vmax = vi2;
        vidx = 2;
      }

      if (vi3 > vmax) {
        vmax = vi3;
        vidx = 3;
      }

      *output++ = vmax;
      *index++ = vidx;
    } while (--c != 0);
    // Both increments are byte strides, hence the uintptr_t arithmetic.
    input = (const float**) ((uintptr_t) input + input_increment);
    output = (float*) ((uintptr_t) output + output_increment);
  } while (--output_pixels != 0);
}
79