// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <emmintrin.h>

#include <xnnpack/avgpool.h>

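// Quantized 8-bit average pooling micro-kernel for SSE2. This is the
// multi-pass variant (the mp9p8q suffix reflects its structure): a first pass
// sums 9 pooling elements per channel into a 32-bit scratch buffer,
// intermediate passes add 8 elements each, and a final pass adds the last
// 1-8 elements, requantizes, and writes the 8-bit output. It is used when the
// pooling window has more than 9 elements (ks > 9).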
void xnn_q8_avgpool_ukernel_mp9p8q__sse2(
    size_t n,
    size_t ks,
    size_t kc,
    const uint8_t** input,
    const uint8_t* zero,
    int32_t* buffer,
    uint8_t* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_q8_avgpool_params params[restrict static 1])
{
  assert(n != 0);
  assert(ks > 9);
  assert(kc != 0);

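  // vbias is folded into every 32-bit accumulator when it is initialized;
  // vmultiplier, vrounding, and vright_shift implement the fixed-point
  // requantization (acc * multiplier + rounding) >> shift.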
  const __m128i vbias = _mm_load_si128((const __m128i*) &params->sse2.bias);
  const __m128i vzero = _mm_setzero_si128();
  const __m128i vmultiplier = _mm_load_si128((const __m128i*) params->sse2.multiplier);
  const __m128i vrounding = _mm_load_si128((const __m128i*) params->sse2.rounding);
  const __m128i vright_shift = _mm_loadl_epi64((const __m128i*) params->sse2.right_shift);

  do {
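    // First pass: sum the first 9 pooling elements of each channel and store
    // the bias-adjusted 32-bit partial sums in the scratch buffer.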
    {
      const uint8_t* i0 = *input++;
      const uint8_t* i1 = *input++;
      const uint8_t* i2 = *input++;
      const uint8_t* i3 = *input++;
      const uint8_t* i4 = *input++;
      const uint8_t* i5 = *input++;
      const uint8_t* i6 = *input++;
      const uint8_t* i7 = *input++;
      const uint8_t* i8 = *input++;

      int32_t* acc = buffer;
      for (size_t k = 0; k < kc; k += 8) {
        const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0); i0 += 8;
        const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1); i1 += 8;
        const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2); i2 += 8;
        const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3); i3 += 8;
        const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4); i4 += 8;
        const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5); i5 += 8;
        const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6); i6 += 8;
        const __m128i vi7 = _mm_loadl_epi64((const __m128i*) i7); i7 += 8;
        const __m128i vi8 = _mm_loadl_epi64((const __m128i*) i8); i8 += 8;

        // Zero-extend the 8-bit samples to 16 bits.
        const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
        const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
        const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
        const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
        const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
        const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
        const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
        const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero);
        const __m128i vxi8 = _mm_unpacklo_epi8(vi8, vzero);

        // Tree-sum the 9 rows in 16 bits; 9 * 255 = 2295 cannot overflow.
        const __m128i vsum018 = _mm_add_epi16(_mm_add_epi16(vxi0, vxi1), vxi8);
        const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
        const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
        const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7);

        const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45);
        const __m128i vsum01678 = _mm_add_epi16(vsum018, vsum67);
        const __m128i vsum = _mm_add_epi16(vsum2345, vsum01678);

        // Widen to 32 bits and add the bias to initialize the accumulators.
        const __m128i vacc_lo = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vsum, vzero));
        const __m128i vacc_hi = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vsum, vzero));

        _mm_store_si128((__m128i*) acc, vacc_lo);
        _mm_store_si128((__m128i*) acc + 1, vacc_hi);
        acc += 8;
      }
    }

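    // Intermediate passes: accumulate 8 more pooling elements per pass into
    // the scratch buffer until at most 8 elements remain.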
    size_t m = ks;
    for (m -= 9; m > 8; m -= 8) {
      const uint8_t* i0 = *input++;
      const uint8_t* i1 = *input++;
      const uint8_t* i2 = *input++;
      const uint8_t* i3 = *input++;
      const uint8_t* i4 = *input++;
      const uint8_t* i5 = *input++;
      const uint8_t* i6 = *input++;
      const uint8_t* i7 = *input++;

      int32_t* acc = buffer;
      for (size_t k = 0; k < kc; k += 8) {
        const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0); i0 += 8;
        const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1); i1 += 8;
        const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2); i2 += 8;
        const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3); i3 += 8;
        const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4); i4 += 8;
        const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5); i5 += 8;
        const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6); i6 += 8;
        const __m128i vi7 = _mm_loadl_epi64((const __m128i*) i7); i7 += 8;
        __m128i vacc_lo = _mm_load_si128((const __m128i*) acc);
        __m128i vacc_hi = _mm_load_si128((const __m128i*) acc + 1);

        const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
        const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
        const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
        const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
        const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
        const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
        const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
        const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero);

        const __m128i vsum01 = _mm_add_epi16(vxi0, vxi1);
        const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
        const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
        const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7);

        const __m128i vsum0123 = _mm_add_epi16(vsum01, vsum23);
        const __m128i vsum4567 = _mm_add_epi16(vsum45, vsum67);
        const __m128i vsum = _mm_add_epi16(vsum0123, vsum4567);

        vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero));
        vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero));

        _mm_store_si128((__m128i*) acc, vacc_lo);
        _mm_store_si128((__m128i*) acc + 1, vacc_hi);
        acc += 8;
      }
    }

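    // Final pass: accumulate the remaining 1-8 pooling elements, requantize
    // the 32-bit sums, and store the 8-bit outputs.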
    {
      const uint8_t* i0 = input[0];
      const uint8_t* i1 = input[1];
      const uint8_t* i2 = input[2];
      const uint8_t* i3 = input[3];
      const uint8_t* i4 = input[4];
      const uint8_t* i5 = input[5];
      const uint8_t* i6 = input[6];
      const uint8_t* i7 = input[7];
      input = (const uint8_t**) ((uintptr_t) input + input_increment);
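      // m pooling elements remain (1 <= m <= 8); redirect the unused row
      // pointers to the zero vector so they add nothing to the sums.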
      if (m < 2) {
        i1 = zero;
      }
      if (m <= 2) {
        i2 = zero;
      }
      if (m < 4) {
        i3 = zero;
      }
      if (m <= 4) {
        i4 = zero;
      }
      if (m < 6) {
        i5 = zero;
      }
      if (m <= 6) {
        i6 = zero;
      }
      if (m != 8) {
        i7 = zero;
      }

      size_t k = kc;
      int32_t* acc = buffer;
      while (k >= 8) {
        const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0); i0 += 8;
        const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1); i1 += 8;
        const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2); i2 += 8;
        const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3); i3 += 8;
        const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4); i4 += 8;
        const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5); i5 += 8;
        const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6); i6 += 8;
        const __m128i vi7 = _mm_loadl_epi64((const __m128i*) i7); i7 += 8;
        __m128i vacc_lo = _mm_load_si128((const __m128i*) acc);
        __m128i vacc_hi = _mm_load_si128((const __m128i*) acc + 1);
        acc += 8;

        const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
        const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
        const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
        const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
        const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
        const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
        const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
        const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero);

        const __m128i vsum01 = _mm_add_epi16(vxi0, vxi1);
        const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
        const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
        const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7);

        const __m128i vsum0123 = _mm_add_epi16(vsum01, vsum23);
        const __m128i vsum4567 = _mm_add_epi16(vsum45, vsum67);
        const __m128i vsum = _mm_add_epi16(vsum0123, vsum4567);

        vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero));
        vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero));

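        // Requantize: SSE2 lacks a signed 32x32->64-bit multiply, so take the
        // absolute value, multiply the even and odd lanes separately with
        // _mm_mul_epu32, add the rounding constant, shift right, re-interleave
        // the lanes, and restore the sign with the saved negation mask.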
        const __m128i vneg_mask_lo = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo);
        const __m128i vneg_mask_hi = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi);

        const __m128i vabs_lo0123 = _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo);
        const __m128i vabs_hi0123 = _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi);

        const __m128i vabs_lo1032 = _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1));
        const __m128i vabs_hi1032 = _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1));

        const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier);
        const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier);

        const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier);
        const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier);

        const __m128i vabs_scaled_lo02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift);
        const __m128i vabs_scaled_lo13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift);
        const __m128i vabs_scaled_hi02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift);
        const __m128i vabs_scaled_hi13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift);

        const __m128i vabs_scaled_lo0213 = _mm_castps_si128(
            _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_lo02), _mm_castsi128_ps(vabs_scaled_lo13), _MM_SHUFFLE(2, 0, 2, 0)));
        const __m128i vabs_scaled_hi0213 = _mm_castps_si128(
            _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_hi02), _mm_castsi128_ps(vabs_scaled_hi13), _MM_SHUFFLE(2, 0, 2, 0)));

        const __m128i vabs_scaled_lo = _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0));
        const __m128i vabs_scaled_hi = _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0));

        const __m128i vscaled_lo = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo);
        const __m128i vscaled_hi = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi);

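        // Pack to 16 bits with signed saturation, add the output zero point,
        // pack to unsigned 8 bits, and clamp to [output_min, output_max].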
        __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
        vout = _mm_adds_epi16(vout, _mm_load_si128((const __m128i*) &params->sse2.output_zero_point));
        vout = _mm_packus_epi16(vout, vout);
        vout = _mm_min_epu8(vout, _mm_load_si128((const __m128i*) &params->sse2.output_max));
        vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) &params->sse2.output_min));

        _mm_storel_epi64((__m128i*) output, vout);
        output += 8;

        k -= 8;
      }
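      // Channel remainder: when kc is not a multiple of 8, compute one more
      // group and store only the last 1-7 bytes.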
      if (k != 0) {
        const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0);
        const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1);
        const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2);
        const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3);
        const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4);
        const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5);
        const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6);
        const __m128i vi7 = _mm_loadl_epi64((const __m128i*) i7);
        __m128i vacc_lo = _mm_load_si128((const __m128i*) acc);
        __m128i vacc_hi = _mm_load_si128((const __m128i*) acc + 1);

        const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
        const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
        const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
        const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
        const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
        const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
        const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
        const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero);

        const __m128i vsum01 = _mm_add_epi16(vxi0, vxi1);
        const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
        const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
        const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7);

        const __m128i vsum0123 = _mm_add_epi16(vsum01, vsum23);
        const __m128i vsum4567 = _mm_add_epi16(vsum45, vsum67);
        const __m128i vsum = _mm_add_epi16(vsum0123, vsum4567);

        vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero));
        vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero));

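        // Same sign-magnitude requantization as in the main channel loop
        // above.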
        const __m128i vneg_mask_lo = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo);
        const __m128i vneg_mask_hi = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi);

        const __m128i vabs_lo0123 = _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo);
        const __m128i vabs_hi0123 = _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi);

        const __m128i vabs_lo1032 = _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1));
        const __m128i vabs_hi1032 = _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1));

        const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier);
        const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier);

        const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier);
        const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier);

        const __m128i vabs_scaled_lo02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift);
        const __m128i vabs_scaled_lo13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift);
        const __m128i vabs_scaled_hi02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift);
        const __m128i vabs_scaled_hi13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift);

        const __m128i vabs_scaled_lo0213 = _mm_castps_si128(
            _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_lo02), _mm_castsi128_ps(vabs_scaled_lo13), _MM_SHUFFLE(2, 0, 2, 0)));
        const __m128i vabs_scaled_hi0213 = _mm_castps_si128(
            _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_hi02), _mm_castsi128_ps(vabs_scaled_hi13), _MM_SHUFFLE(2, 0, 2, 0)));

        const __m128i vabs_scaled_lo = _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0));
        const __m128i vabs_scaled_hi = _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0));

        const __m128i vscaled_lo = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo);
        const __m128i vscaled_hi = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi);

        __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
        vout = _mm_adds_epi16(vout, _mm_load_si128((const __m128i*) &params->sse2.output_zero_point));
        vout = _mm_packus_epi16(vout, vout);
        vout = _mm_min_epu8(vout, _mm_load_si128((const __m128i*) &params->sse2.output_max));
        vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) &params->sse2.output_min));

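        // Write the remaining 1-7 bytes 4/2/1 at a time, shifting consumed
        // lanes out of vout after each partial store.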
        if (k & 4) {
          *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout);
          output += 4;
          vout = _mm_srli_epi64(vout, 32);
        }
        if (k & 2) {
          *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout, 0);
          output += 2;
          vout = _mm_srli_epi32(vout, 16);
        }
        if (k & 1) {
          *((uint8_t*) output) = (uint8_t) _mm_cvtsi128_si32(vout);
          output += 1;
        }
      }
    }
    output = (uint8_t*) ((uintptr_t) output + output_increment);
  } while (--n != 0);
}