// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <immintrin.h>

#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/pavgpool.h>


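// f16 pixel-average-pooling micro-kernel for pooling windows of more than
// 9 elements: a first pass sums 9 kernel elements into `buffer`, each middle
// pass adds 8 more, and the last pass adds the remaining 1-8 elements, scales
// by the per-pixel `multiplier` (in typical usage, the reciprocal of the pool
// size), clamps to [min, max], and stores 8 f16 channels at a time.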
void xnn_f16_pavgpool_minmax_ukernel_9p8x__avx2_c8(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const void** input,
    size_t input_offset,
    const void* zero,
    const void* multiplier,
    void* buffer,
    void* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(output_pixels != 0);
  assert(kernel_elements > 9);
  assert(channels != 0);

  const __m256 voutput_min = _mm256_load_ps(params->avx.min);
  const __m256 voutput_max = _mm256_load_ps(params->avx.max);

  uint16_t* o = (uint16_t*) output;
  do {
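    // First pass: read 9 row pointers from the indirection buffer and write
    // the f16 sum of their first 8 channels into the scratch buffer.
    // Rows that alias the `zero` vector are not adjusted by `input_offset`.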
    {
      const uint16_t* i0 = (const uint16_t*) *input++;
      assert(i0 != NULL);
      if XNN_UNPREDICTABLE(i0 != zero) {
        i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
      }
      const uint16_t* i1 = (const uint16_t*) *input++;
      assert(i1 != NULL);
      if XNN_UNPREDICTABLE(i1 != zero) {
        i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
      }
      const uint16_t* i2 = (const uint16_t*) *input++;
      assert(i2 != NULL);
      if XNN_UNPREDICTABLE(i2 != zero) {
        i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
      }
      const uint16_t* i3 = (const uint16_t*) *input++;
      assert(i3 != NULL);
      if XNN_UNPREDICTABLE(i3 != zero) {
        i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
      }
      const uint16_t* i4 = (const uint16_t*) *input++;
      assert(i4 != NULL);
      if XNN_UNPREDICTABLE(i4 != zero) {
        i4 = (const uint16_t*) ((uintptr_t) i4 + input_offset);
      }
      const uint16_t* i5 = (const uint16_t*) *input++;
      assert(i5 != NULL);
      if XNN_UNPREDICTABLE(i5 != zero) {
        i5 = (const uint16_t*) ((uintptr_t) i5 + input_offset);
      }
      const uint16_t* i6 = (const uint16_t*) *input++;
      assert(i6 != NULL);
      if XNN_UNPREDICTABLE(i6 != zero) {
        i6 = (const uint16_t*) ((uintptr_t) i6 + input_offset);
      }
      const uint16_t* i7 = (const uint16_t*) *input++;
      assert(i7 != NULL);
      if XNN_UNPREDICTABLE(i7 != zero) {
        i7 = (const uint16_t*) ((uintptr_t) i7 + input_offset);
      }
      const uint16_t* i8 = (const uint16_t*) *input++;
      assert(i8 != NULL);
      if XNN_UNPREDICTABLE(i8 != zero) {
        i8 = (const uint16_t*) ((uintptr_t) i8 + input_offset);
      }

      uint16_t* b = (uint16_t*) buffer;
      for (size_t c = 0; c < channels; c += 8) {
        const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
        i0 += 8;
        const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
        i1 += 8;
        const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
        i2 += 8;
        const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
        i3 += 8;
        const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
        i4 += 8;
        const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
        i5 += 8;
        const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
        i6 += 8;
        const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
        i7 += 8;
        const __m256 vi8 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i8));
        i8 += 8;

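        // Add pairwise in f32, rounding each partial sum back to f16, so the
        // result matches an accumulation performed entirely in f16.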
        const __m256 vsum01 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi0, vi1), _MM_FROUND_NO_EXC));
        const __m256 vsum23 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi2, vi3), _MM_FROUND_NO_EXC));
        const __m256 vsum45 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi4, vi5), _MM_FROUND_NO_EXC));
        const __m256 vsum67 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi6, vi7), _MM_FROUND_NO_EXC));
        const __m256 vsum018 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01, vi8), _MM_FROUND_NO_EXC));
        const __m256 vsum2345 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum23, vsum45), _MM_FROUND_NO_EXC));
        const __m256 vsum01678 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum018, vsum67), _MM_FROUND_NO_EXC));
        const __m256 vsum = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum2345, vsum01678), _MM_FROUND_NO_EXC));

        _mm_storeu_si128((__m128i*) b, _mm256_cvtps_ph(vsum, _MM_FROUND_NO_EXC));
        b += 8;
      }
    }

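    // Middle passes: while more than 8 kernel elements remain, add the next
    // 8 rows into the partial sums held in the scratch buffer.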
    size_t k = kernel_elements;
    for (k -= 9; k > 8; k -= 8) {
      const uint16_t* i0 = (const uint16_t*) *input++;
      assert(i0 != NULL);
      if XNN_UNPREDICTABLE(i0 != zero) {
        i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
      }
      const uint16_t* i1 = (const uint16_t*) *input++;
      assert(i1 != NULL);
      if XNN_UNPREDICTABLE(i1 != zero) {
        i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
      }
      const uint16_t* i2 = (const uint16_t*) *input++;
      assert(i2 != NULL);
      if XNN_UNPREDICTABLE(i2 != zero) {
        i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
      }
      const uint16_t* i3 = (const uint16_t*) *input++;
      assert(i3 != NULL);
      if XNN_UNPREDICTABLE(i3 != zero) {
        i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
      }
      const uint16_t* i4 = (const uint16_t*) *input++;
      assert(i4 != NULL);
      if XNN_UNPREDICTABLE(i4 != zero) {
        i4 = (const uint16_t*) ((uintptr_t) i4 + input_offset);
      }
      const uint16_t* i5 = (const uint16_t*) *input++;
      assert(i5 != NULL);
      if XNN_UNPREDICTABLE(i5 != zero) {
        i5 = (const uint16_t*) ((uintptr_t) i5 + input_offset);
      }
      const uint16_t* i6 = (const uint16_t*) *input++;
      assert(i6 != NULL);
      if XNN_UNPREDICTABLE(i6 != zero) {
        i6 = (const uint16_t*) ((uintptr_t) i6 + input_offset);
      }
      const uint16_t* i7 = (const uint16_t*) *input++;
      assert(i7 != NULL);
      if XNN_UNPREDICTABLE(i7 != zero) {
        i7 = (const uint16_t*) ((uintptr_t) i7 + input_offset);
      }

      uint16_t* b = (uint16_t*) buffer;
      for (size_t c = 0; c < channels; c += 8) {
        const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
        i0 += 8;
        const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
        i1 += 8;
        const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
        i2 += 8;
        const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
        i3 += 8;
        const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
        i4 += 8;
        const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
        i5 += 8;
        const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
        i6 += 8;
        const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
        i7 += 8;
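        // Reload the running f16 partial sums for these 8 channels.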
        const __m256 vacc = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) b));

        const __m256 vsum01 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi0, vi1), _MM_FROUND_NO_EXC));
        const __m256 vsum23 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi2, vi3), _MM_FROUND_NO_EXC));
        const __m256 vsum45 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi4, vi5), _MM_FROUND_NO_EXC));
        const __m256 vsum67 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi6, vi7), _MM_FROUND_NO_EXC));
        const __m256 vsum01a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01, vacc), _MM_FROUND_NO_EXC));
        const __m256 vsum2345 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum23, vsum45), _MM_FROUND_NO_EXC));
        const __m256 vsum0167a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01a, vsum67), _MM_FROUND_NO_EXC));
        const __m256 vsum = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum2345, vsum0167a), _MM_FROUND_NO_EXC));

        _mm_storeu_si128((__m128i*) b, _mm256_cvtps_ph(vsum, _MM_FROUND_NO_EXC));
        b += 8;
      }
    }

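    // Last pass: 1 to 8 kernel elements remain (k). Rows beyond the last
    // kernel element are pointed at the `zero` vector so they add nothing.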
    {
      const uint16_t* i0 = (const uint16_t*) input[0];
      assert(i0 != NULL);
      const uint16_t* i1 = (const uint16_t*) input[1];
      const uint16_t* i2 = (const uint16_t*) input[2];
      const uint16_t* i3 = (const uint16_t*) input[3];
      const uint16_t* i4 = (const uint16_t*) input[4];
      const uint16_t* i5 = (const uint16_t*) input[5];
      const uint16_t* i6 = (const uint16_t*) input[6];
      const uint16_t* i7 = (const uint16_t*) input[7];
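      // All 8 pointers were read above, so the indirection buffer can advance
      // to the next output pixel before `zero` is substituted for the rows
      // past the last kernel element.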
      input = (const void**) ((uintptr_t) input + input_increment);
      if (k < 2) {
        i1 = (const uint16_t*) zero;
      }
      assert(i1 != NULL);
      if (k <= 2) {
        i2 = (const uint16_t*) zero;
      }
      assert(i2 != NULL);
      if (k < 4) {
        i3 = (const uint16_t*) zero;
      }
      assert(i3 != NULL);
      if (k <= 4) {
        i4 = (const uint16_t*) zero;
      }
      assert(i4 != NULL);
      if (k < 6) {
        i5 = (const uint16_t*) zero;
      }
      assert(i5 != NULL);
      if (k <= 6) {
        i6 = (const uint16_t*) zero;
      }
      assert(i6 != NULL);
      if (k < 8) {
        i7 = (const uint16_t*) zero;
      }
      assert(i7 != NULL);
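      // Offset every row pointer that does not alias the `zero` vector.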
      if XNN_UNPREDICTABLE(i0 != zero) {
        i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
      }
      if XNN_UNPREDICTABLE(i1 != zero) {
        i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
      }
      if XNN_UNPREDICTABLE(i2 != zero) {
        i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
      }
      if XNN_UNPREDICTABLE(i3 != zero) {
        i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
      }
      if XNN_UNPREDICTABLE(i4 != zero) {
        i4 = (const uint16_t*) ((uintptr_t) i4 + input_offset);
      }
      if XNN_UNPREDICTABLE(i5 != zero) {
        i5 = (const uint16_t*) ((uintptr_t) i5 + input_offset);
      }
      if XNN_UNPREDICTABLE(i6 != zero) {
        i6 = (const uint16_t*) ((uintptr_t) i6 + input_offset);
      }
      if XNN_UNPREDICTABLE(i7 != zero) {
        i7 = (const uint16_t*) ((uintptr_t) i7 + input_offset);
      }

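      // Load this pixel's f16 scale factor (in typical usage, the reciprocal
      // of the pool size) and broadcast it to all 8 lanes.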
      const __m256 vmultiplier = _mm256_cvtph_ps(_mm_set1_epi16((short) *((const uint16_t*) multiplier)));
      multiplier = (const uint16_t*) multiplier + 1;

      size_t c = channels;
      const uint16_t* b = (const uint16_t*) buffer;
      while (c >= 8) {
        const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
        i0 += 8;
        const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
        i1 += 8;
        const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
        i2 += 8;
        const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
        i3 += 8;
        const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
        i4 += 8;
        const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
        i5 += 8;
        const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
        i6 += 8;
        const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
        i7 += 8;
        const __m256 vacc = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) b));
        b += 8;

        const __m256 vsum01 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi0, vi1), _MM_FROUND_NO_EXC));
        const __m256 vsum23 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi2, vi3), _MM_FROUND_NO_EXC));
        const __m256 vsum45 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi4, vi5), _MM_FROUND_NO_EXC));
        const __m256 vsum67 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi6, vi7), _MM_FROUND_NO_EXC));
        const __m256 vsum01a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01, vacc), _MM_FROUND_NO_EXC));
        const __m256 vsum2345 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum23, vsum45), _MM_FROUND_NO_EXC));
        const __m256 vsum0167a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01a, vsum67), _MM_FROUND_NO_EXC));
        const __m256 vsum = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum2345, vsum0167a), _MM_FROUND_NO_EXC));

        __m256 vout = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vsum, vmultiplier), _MM_FROUND_NO_EXC));
        vout = _mm256_max_ps(vout, voutput_min);
        vout = _mm256_min_ps(vout, voutput_max);

        _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vout, _MM_FROUND_NO_EXC));
        o += 8;

        c -= 8;
      }
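      // Remainder: 1 to 7 channels are left. Compute a full 8-lane result
      // (the trailing loads may read past the end of the rows; the kernel is
      // annotated XNN_OOB_READS) and store it piecewise below.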
      if (c != 0) {
        const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
        const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
        const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
        const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
        const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
        const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
        const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
        const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
        const __m256 vacc = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) b));

        const __m256 vsum01 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi0, vi1), _MM_FROUND_NO_EXC));
        const __m256 vsum23 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi2, vi3), _MM_FROUND_NO_EXC));
        const __m256 vsum45 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi4, vi5), _MM_FROUND_NO_EXC));
        const __m256 vsum67 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi6, vi7), _MM_FROUND_NO_EXC));
        const __m256 vsum01a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01, vacc), _MM_FROUND_NO_EXC));
        const __m256 vsum2345 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum23, vsum45), _MM_FROUND_NO_EXC));
        const __m256 vsum0167a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01a, vsum67), _MM_FROUND_NO_EXC));
        const __m256 vsum = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum2345, vsum0167a), _MM_FROUND_NO_EXC));

        __m256 vout = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vsum, vmultiplier), _MM_FROUND_NO_EXC));
        vout = _mm256_max_ps(vout, voutput_min);
        vout = _mm256_min_ps(vout, voutput_max);

        __m128i vh = _mm256_cvtps_ph(vout, _MM_FROUND_NO_EXC);
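        // Store 4, 2, and then 1 half-precision values, as selected by the
        // low bits of the remaining channel count.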
        if (c & 4) {
          _mm_storel_epi64((__m128i*) o, vh);
          vh = _mm_unpackhi_epi64(vh, vh);
          o += 4;
        }
        if (c & 2) {
          _mm_storeu_si32(o, vh);
          vh = _mm_srli_epi64(vh, 32);
          o += 2;
        }
        if (c & 1) {
          *o = (uint16_t) _mm_extract_epi16(vh, 0);
          o += 1;
        }
      }
    }
    o = (uint16_t*) ((uintptr_t) o + output_increment);
  } while (--output_pixels != 0);
}