• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 
8 #include <psimd.h>
9 
10 #include <xnnpack/gavgpool.h>
11 #include <xnnpack/math.h>
12 
13 
xnn_f32_gavgpool_ukernel_mp7p7q__psimd(size_t m,size_t n,const float * input,size_t input_stride,const float * zero,float * buffer,float * output,const union xnn_f32_avgpool_params params[restrict static1])14 void xnn_f32_gavgpool_ukernel_mp7p7q__psimd(
15     size_t m,
16     size_t n,
17     const float* input,
18     size_t input_stride,
19     const float* zero,
20     float* buffer,
21     float* output,
22     const union xnn_f32_avgpool_params params[restrict static 1])
23 {
24   assert(m > 7);
25   assert(n != 0);
26 
27   const float* i0 = input;
28   const float* i1 = (const float*) ((uintptr_t) i0 + input_stride);
29   const float* i2 = (const float*) ((uintptr_t) i1 + input_stride);
30   const float* i3 = (const float*) ((uintptr_t) i2 + input_stride);
31   const float* i4 = (const float*) ((uintptr_t) i3 + input_stride);
32   const float* i5 = (const float*) ((uintptr_t) i4 + input_stride);
33   const float* i6 = (const float*) ((uintptr_t) i5 + input_stride);
34   const size_t packed_n = round_up_po2(n, 4);
35   const size_t input_increment = 7 * input_stride - packed_n * sizeof(float);
36 
37   float* b = buffer;
38   for (size_t k = 0; k < n; k += 4) {
39     const psimd_f32 vi0 = psimd_load_f32(i0);
40     i0 += 4;
41     const psimd_f32 vi1 = psimd_load_f32(i1);
42     i1 += 4;
43     const psimd_f32 vi2 = psimd_load_f32(i2);
44     i2 += 4;
45     const psimd_f32 vi3 = psimd_load_f32(i3);
46     i3 += 4;
47     const psimd_f32 vi4 = psimd_load_f32(i4);
48     i4 += 4;
49     const psimd_f32 vi5 = psimd_load_f32(i5);
50     i5 += 4;
51     const psimd_f32 vi6 = psimd_load_f32(i6);
52     i6 += 4;
53 
54     const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
55     const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
56     const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
57 
58     const psimd_f32 vsum016 = psimd_add_f32(vsum01, vi6);
59     const psimd_f32 vsum2345 = psimd_add_f32(vsum23, vsum45);
60 
61     const psimd_f32 vsum = psimd_add_f32(vsum016, vsum2345);
62 
63     psimd_store_f32(b, vsum); b += 4;
64   }
65   for (m -= 7; m > 7; m -= 7) {
66     b = buffer;
67 
68     i0 = (const float*) ((uintptr_t) i0 + input_increment);
69     i1 = (const float*) ((uintptr_t) i1 + input_increment);
70     i2 = (const float*) ((uintptr_t) i2 + input_increment);
71     i3 = (const float*) ((uintptr_t) i3 + input_increment);
72     i4 = (const float*) ((uintptr_t) i4 + input_increment);
73     i5 = (const float*) ((uintptr_t) i5 + input_increment);
74     i6 = (const float*) ((uintptr_t) i6 + input_increment);
75 
76     for (size_t k = 0; k < n; k += 4) {
77       const psimd_f32 vi0 = psimd_load_f32(i0);
78       i0 += 4;
79       const psimd_f32 vi1 = psimd_load_f32(i1);
80       i1 += 4;
81       const psimd_f32 vi2 = psimd_load_f32(i2);
82       i2 += 4;
83       const psimd_f32 vi3 = psimd_load_f32(i3);
84       i3 += 4;
85       const psimd_f32 vi4 = psimd_load_f32(i4);
86       i4 += 4;
87       const psimd_f32 vi5 = psimd_load_f32(i5);
88       i5 += 4;
89       const psimd_f32 vi6 = psimd_load_f32(i6);
90       i6 += 4;
91       const psimd_f32 vacc = psimd_load_f32(b);
92 
93       const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
94       const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
95       const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
96       const psimd_f32 vsum6a = psimd_add_f32(vi6, vacc);
97 
98       const psimd_f32 vsum0123 = psimd_add_f32(vsum01, vsum23);
99       const psimd_f32 vsum456a = psimd_add_f32(vsum45, vsum6a);
100 
101       const psimd_f32 vsum = psimd_add_f32(vsum0123, vsum456a);
102 
103       psimd_store_f32(b, vsum); b += 4;
104     }
105   }
106 
107   i0 = (const float*) ((uintptr_t) i0 + input_increment);
108   i1 = (const float*) ((uintptr_t) i1 + input_increment);
109   if (m < 2) {
110     i1 = zero;
111   }
112   i2 = (const float*) ((uintptr_t) i2 + input_increment);
113   if (m <= 2) {
114     i2 = zero;
115   }
116   i3 = (const float*) ((uintptr_t) i3 + input_increment);
117   if (m < 4) {
118     i3 = zero;
119   }
120   i4 = (const float*) ((uintptr_t) i4 + input_increment);
121   if (m <= 4) {
122     i4 = zero;
123   }
124   i5 = (const float*) ((uintptr_t) i5 + input_increment);
125   if (m < 6) {
126     i5 = zero;
127   }
128   i6 = (const float*) ((uintptr_t) i6 + input_increment);
129   if (m <= 6) {
130     i6 = zero;
131   }
132   const psimd_f32 vmultiplier = psimd_load_splat_f32(&params->scalar.multiplier);
133   const psimd_f32 voutput_min = psimd_load_splat_f32(&params->scalar.output_min);
134   const psimd_f32 voutput_max = psimd_load_splat_f32(&params->scalar.output_max);
135 
136   b = buffer;
137   while (n >= 4) {
138     const psimd_f32 vi0 = psimd_load_f32(i0);
139     i0 += 4;
140     const psimd_f32 vi1 = psimd_load_f32(i1);
141     i1 += 4;
142     const psimd_f32 vi2 = psimd_load_f32(i2);
143     i2 += 4;
144     const psimd_f32 vi3 = psimd_load_f32(i3);
145     i3 += 4;
146     const psimd_f32 vi4 = psimd_load_f32(i4);
147     i4 += 4;
148     const psimd_f32 vi5 = psimd_load_f32(i5);
149     i5 += 4;
150     const psimd_f32 vi6 = psimd_load_f32(i6);
151     i6 += 4;
152     const psimd_f32 vacc = psimd_load_f32(b);
153     b += 4;
154 
155     const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
156     const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
157     const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
158     const psimd_f32 vsum6a = psimd_add_f32(vi6, vacc);
159 
160     const psimd_f32 vsum0123 = psimd_add_f32(vsum01, vsum23);
161     const psimd_f32 vsum456a = psimd_add_f32(vsum45, vsum6a);
162 
163     const psimd_f32 vsum = psimd_add_f32(vsum0123, vsum456a);
164 
165     psimd_f32 vout = psimd_mul_f32(vsum, vmultiplier);
166     vout = psimd_max_f32(vout, voutput_min);
167     vout = psimd_min_f32(vout, voutput_max);
168 
169     psimd_store_f32(output, vout);
170     output += 4;
171 
172     n -= 4;
173   }
174   if (n != 0) {
175     const psimd_f32 vi0 = psimd_load_f32(i0);
176     const psimd_f32 vi1 = psimd_load_f32(i1);
177     const psimd_f32 vi2 = psimd_load_f32(i2);
178     const psimd_f32 vi3 = psimd_load_f32(i3);
179     const psimd_f32 vi4 = psimd_load_f32(i4);
180     const psimd_f32 vi5 = psimd_load_f32(i5);
181     const psimd_f32 vi6 = psimd_load_f32(i6);
182     const psimd_f32 vacc = psimd_load_f32(b);
183 
184     const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
185     const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
186     const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
187     const psimd_f32 vsum6a = psimd_add_f32(vi6, vacc);
188 
189     const psimd_f32 vsum0123 = psimd_add_f32(vsum01, vsum23);
190     const psimd_f32 vsum456a = psimd_add_f32(vsum45, vsum6a);
191 
192     const psimd_f32 vsum = psimd_add_f32(vsum0123, vsum456a);
193 
194     psimd_f32 vout = psimd_mul_f32(vsum, vmultiplier);
195     vout = psimd_max_f32(vout, voutput_min);
196     vout = psimd_min_f32(vout, voutput_max);
197 
198     if (n & 2) {
199       psimd_store2_f32(output, vout);
200       output += 2;
201       vout = psimd_concat_hi_f32(vout, vout);
202     }
203     if (n & 1) {
204       psimd_store1_f32(output, vout);
205     }
206   }
207 }
208