• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <immintrin.h>
13 
14 #include "config/aom_dsp_rtcd.h"
15 #include "aom/aom_integer.h"
16 #include "aom_dsp/x86/bitdepth_conversion_avx2.h"
17 #include "aom_ports/mem.h"
18 
sign_extend_16bit_to_32bit_avx2(__m256i in,__m256i zero,__m256i * out_lo,__m256i * out_hi)19 static INLINE void sign_extend_16bit_to_32bit_avx2(__m256i in, __m256i zero,
20                                                    __m256i *out_lo,
21                                                    __m256i *out_hi) {
22   const __m256i sign_bits = _mm256_cmpgt_epi16(zero, in);
23   *out_lo = _mm256_unpacklo_epi16(in, sign_bits);
24   *out_hi = _mm256_unpackhi_epi16(in, sign_bits);
25 }
26 
// One pass of an 8-point Hadamard butterfly applied to two 8x8 int16 blocks
// at once.
//
// in[0..7] are eight vectors of sixteen int16 values; the low 128-bit lane of
// every vector belongs to one 8x8 block and the high lane to the other. The
// three butterfly stages combine in[0..7] element-wise (each of the sixteen
// positions is transformed independently) using wrapping 16-bit add/sub.
//
// iter == 0: the transformed vectors are then transposed inside each 128-bit
//            lane, so in[k] receives element k of the eight transformed
//            vectors. A following call with iter != 0 therefore transforms
//            the other dimension.
// iter != 0: results are written back without a transpose.
//
// In both paths the third stage places its outputs in this permuted order:
// [0]=a0+a4, [7]=a1+a5, [3]=a2+a6, [4]=a3+a7,
// [2]=a0-a4, [6]=a1-a5, [1]=a2-a6, [5]=a3-a7.
static void hadamard_col8x2_avx2(__m256i *in, int iter) {
  __m256i a0 = in[0];
  __m256i a1 = in[1];
  __m256i a2 = in[2];
  __m256i a3 = in[3];
  __m256i a4 = in[4];
  __m256i a5 = in[5];
  __m256i a6 = in[6];
  __m256i a7 = in[7];

  // Stage 1: butterflies on adjacent vector pairs.
  __m256i b0 = _mm256_add_epi16(a0, a1);
  __m256i b1 = _mm256_sub_epi16(a0, a1);
  __m256i b2 = _mm256_add_epi16(a2, a3);
  __m256i b3 = _mm256_sub_epi16(a2, a3);
  __m256i b4 = _mm256_add_epi16(a4, a5);
  __m256i b5 = _mm256_sub_epi16(a4, a5);
  __m256i b6 = _mm256_add_epi16(a6, a7);
  __m256i b7 = _mm256_sub_epi16(a6, a7);

  // Stage 2: butterflies at distance 2.
  a0 = _mm256_add_epi16(b0, b2);
  a1 = _mm256_add_epi16(b1, b3);
  a2 = _mm256_sub_epi16(b0, b2);
  a3 = _mm256_sub_epi16(b1, b3);
  a4 = _mm256_add_epi16(b4, b6);
  a5 = _mm256_add_epi16(b5, b7);
  a6 = _mm256_sub_epi16(b4, b6);
  a7 = _mm256_sub_epi16(b5, b7);

  if (iter == 0) {
    // Stage 3: butterflies at distance 4, outputs in the permuted order
    // described above.
    b0 = _mm256_add_epi16(a0, a4);
    b7 = _mm256_add_epi16(a1, a5);
    b3 = _mm256_add_epi16(a2, a6);
    b4 = _mm256_add_epi16(a3, a7);
    b2 = _mm256_sub_epi16(a0, a4);
    b6 = _mm256_sub_epi16(a1, a5);
    b1 = _mm256_sub_epi16(a2, a6);
    b5 = _mm256_sub_epi16(a3, a7);

    // 8x8 transpose of 16-bit elements within each 128-bit lane:
    // interleave 16-bit pairs, then 32-bit pairs, then 64-bit halves.
    a0 = _mm256_unpacklo_epi16(b0, b1);
    a1 = _mm256_unpacklo_epi16(b2, b3);
    a2 = _mm256_unpackhi_epi16(b0, b1);
    a3 = _mm256_unpackhi_epi16(b2, b3);
    a4 = _mm256_unpacklo_epi16(b4, b5);
    a5 = _mm256_unpacklo_epi16(b6, b7);
    a6 = _mm256_unpackhi_epi16(b4, b5);
    a7 = _mm256_unpackhi_epi16(b6, b7);

    b0 = _mm256_unpacklo_epi32(a0, a1);
    b1 = _mm256_unpacklo_epi32(a4, a5);
    b2 = _mm256_unpackhi_epi32(a0, a1);
    b3 = _mm256_unpackhi_epi32(a4, a5);
    b4 = _mm256_unpacklo_epi32(a2, a3);
    b5 = _mm256_unpacklo_epi32(a6, a7);
    b6 = _mm256_unpackhi_epi32(a2, a3);
    b7 = _mm256_unpackhi_epi32(a6, a7);

    in[0] = _mm256_unpacklo_epi64(b0, b1);
    in[1] = _mm256_unpackhi_epi64(b0, b1);
    in[2] = _mm256_unpacklo_epi64(b2, b3);
    in[3] = _mm256_unpackhi_epi64(b2, b3);
    in[4] = _mm256_unpacklo_epi64(b4, b5);
    in[5] = _mm256_unpackhi_epi64(b4, b5);
    in[6] = _mm256_unpacklo_epi64(b6, b7);
    in[7] = _mm256_unpackhi_epi64(b6, b7);
  } else {
    // Stage 3 written straight to the output in the same permuted order;
    // no transpose on the second pass.
    in[0] = _mm256_add_epi16(a0, a4);
    in[7] = _mm256_add_epi16(a1, a5);
    in[3] = _mm256_add_epi16(a2, a6);
    in[4] = _mm256_add_epi16(a3, a7);
    in[2] = _mm256_sub_epi16(a0, a4);
    in[6] = _mm256_sub_epi16(a1, a5);
    in[1] = _mm256_sub_epi16(a2, a6);
    in[5] = _mm256_sub_epi16(a3, a7);
  }
}
102 
// Computes two horizontally adjacent 8x8 Hadamard transforms in one pass.
// Each of the eight rows read from |src_diff| (rows |src_stride| int16 apart)
// supplies 16 residuals. Columns 0-7 form the left block (low 128-bit lane)
// and columns 8-15 form the right block (high lane). The left block's 64
// coefficients go to coeff[0..63] and the right block's to coeff[64..127].
void aom_hadamard_lp_8x8_dual_avx2(const int16_t *src_diff,
                                   ptrdiff_t src_stride, int16_t *coeff) {
  __m256i rows[8];
  for (int r = 0; r < 8; ++r) {
    rows[r] = _mm256_loadu_si256((const __m256i *)(src_diff + r * src_stride));
  }

  // Column pass (followed by a per-lane transpose), then row pass.
  hadamard_col8x2_avx2(rows, 0);
  hadamard_col8x2_avx2(rows, 1);

  // Gather results two vectors at a time: selector 0x20 joins the low lanes
  // (left block), selector 0x31 joins the high lanes (right block).
  for (int r = 0; r < 8; r += 2) {
    _mm256_storeu_si256((__m256i *)(coeff + 8 * r),
                        _mm256_permute2x128_si256(rows[r], rows[r + 1], 0x20));
    _mm256_storeu_si256((__m256i *)(coeff + 64 + 8 * r),
                        _mm256_permute2x128_si256(rows[r], rows[r + 1], 0x31));
  }
}
142 
hadamard_16x16_avx2(const int16_t * src_diff,ptrdiff_t src_stride,tran_low_t * coeff,int is_final)143 static INLINE void hadamard_16x16_avx2(const int16_t *src_diff,
144                                        ptrdiff_t src_stride, tran_low_t *coeff,
145                                        int is_final) {
146   DECLARE_ALIGNED(32, int16_t, temp_coeff[16 * 16]);
147   int16_t *t_coeff = temp_coeff;
148   int16_t *coeff16 = (int16_t *)coeff;
149   int idx;
150   for (idx = 0; idx < 2; ++idx) {
151     const int16_t *src_ptr = src_diff + idx * 8 * src_stride;
152     aom_hadamard_lp_8x8_dual_avx2(src_ptr, src_stride,
153                                   t_coeff + (idx * 64 * 2));
154   }
155 
156   for (idx = 0; idx < 64; idx += 16) {
157     const __m256i coeff0 = _mm256_loadu_si256((const __m256i *)t_coeff);
158     const __m256i coeff1 = _mm256_loadu_si256((const __m256i *)(t_coeff + 64));
159     const __m256i coeff2 = _mm256_loadu_si256((const __m256i *)(t_coeff + 128));
160     const __m256i coeff3 = _mm256_loadu_si256((const __m256i *)(t_coeff + 192));
161 
162     __m256i b0 = _mm256_add_epi16(coeff0, coeff1);
163     __m256i b1 = _mm256_sub_epi16(coeff0, coeff1);
164     __m256i b2 = _mm256_add_epi16(coeff2, coeff3);
165     __m256i b3 = _mm256_sub_epi16(coeff2, coeff3);
166 
167     b0 = _mm256_srai_epi16(b0, 1);
168     b1 = _mm256_srai_epi16(b1, 1);
169     b2 = _mm256_srai_epi16(b2, 1);
170     b3 = _mm256_srai_epi16(b3, 1);
171     if (is_final) {
172       store_tran_low(_mm256_add_epi16(b0, b2), coeff);
173       store_tran_low(_mm256_add_epi16(b1, b3), coeff + 64);
174       store_tran_low(_mm256_sub_epi16(b0, b2), coeff + 128);
175       store_tran_low(_mm256_sub_epi16(b1, b3), coeff + 192);
176       coeff += 16;
177     } else {
178       _mm256_storeu_si256((__m256i *)coeff16, _mm256_add_epi16(b0, b2));
179       _mm256_storeu_si256((__m256i *)(coeff16 + 64), _mm256_add_epi16(b1, b3));
180       _mm256_storeu_si256((__m256i *)(coeff16 + 128), _mm256_sub_epi16(b0, b2));
181       _mm256_storeu_si256((__m256i *)(coeff16 + 192), _mm256_sub_epi16(b1, b3));
182       coeff16 += 16;
183     }
184     t_coeff += 16;
185   }
186 }
187 
// Public 16x16 Hadamard entry point. Runs the shared 16x16 kernel in its
// final-output mode, so the coefficients are stored as tran_low_t.
void aom_hadamard_16x16_avx2(const int16_t *src_diff, ptrdiff_t src_stride,
                             tran_low_t *coeff) {
  const int store_as_tran_low = 1;
  hadamard_16x16_avx2(src_diff, src_stride, coeff, store_as_tran_low);
}
192 
// Low-precision 16x16 Hadamard transform producing int16_t coefficients.
// The four 8x8 sub-transforms are written directly into |coeff| (256 values):
// the top pair of 8x8 blocks goes to coeff[0..127] and the bottom pair to
// coeff[128..255]. The final butterfly stage (scaled by >> 1) then rewrites
// the buffer in place, processing 16 positions of each quadrant at a time.
void aom_hadamard_lp_16x16_avx2(const int16_t *src_diff, ptrdiff_t src_stride,
                                int16_t *coeff) {
  aom_hadamard_lp_8x8_dual_avx2(src_diff, src_stride, coeff);
  aom_hadamard_lp_8x8_dual_avx2(src_diff + 8 * src_stride, src_stride,
                                coeff + 128);

  for (int16_t *q = coeff; q < coeff + 64; q += 16) {
    const __m256i c0 = _mm256_loadu_si256((const __m256i *)q);
    const __m256i c1 = _mm256_loadu_si256((const __m256i *)(q + 64));
    const __m256i c2 = _mm256_loadu_si256((const __m256i *)(q + 128));
    const __m256i c3 = _mm256_loadu_si256((const __m256i *)(q + 192));

    const __m256i s01 = _mm256_srai_epi16(_mm256_add_epi16(c0, c1), 1);
    const __m256i d01 = _mm256_srai_epi16(_mm256_sub_epi16(c0, c1), 1);
    const __m256i s23 = _mm256_srai_epi16(_mm256_add_epi16(c2, c3), 1);
    const __m256i d23 = _mm256_srai_epi16(_mm256_sub_epi16(c2, c3), 1);

    _mm256_storeu_si256((__m256i *)q, _mm256_add_epi16(s01, s23));
    _mm256_storeu_si256((__m256i *)(q + 64), _mm256_add_epi16(d01, d23));
    _mm256_storeu_si256((__m256i *)(q + 128), _mm256_sub_epi16(s01, s23));
    _mm256_storeu_si256((__m256i *)(q + 192), _mm256_sub_epi16(d01, d23));
  }
}
225 
// 32x32 Hadamard transform built from four 16x16 transforms plus a final
// butterfly stage normalized by >> 2. Output is written via store_tran_low().
void aom_hadamard_32x32_avx2(const int16_t *src_diff, ptrdiff_t src_stride,
                             tran_low_t *coeff) {
  // For high bitdepths, it is unnecessary to store_tran_low
  // (mult/unpack/store), then load_tran_low (load/pack) the same memory in the
  // next stage.  Output to an intermediate buffer first, then store_tran_low()
  // in the final stage.
  DECLARE_ALIGNED(32, int16_t, temp_coeff[32 * 32]);
  const __m256i zero = _mm256_setzero_si256();

  // 16x16 quadrant q (0: top-left, 1: top-right, 2: bottom-left,
  // 3: bottom-right) is transformed into temp_coeff[256 * q .. 256 * q + 255].
  for (int q = 0; q < 4; ++q) {
    // src_diff: 9 bit, dynamic range [-255, 255]
    const int16_t *const quad =
        src_diff + (q >> 1) * 16 * src_stride + (q & 0x01) * 16;
    hadamard_16x16_avx2(quad, src_stride, (tran_low_t *)(temp_coeff + q * 256),
                        0);
  }

  for (int pos = 0; pos < 256; pos += 16) {
    const int16_t *const in = temp_coeff + pos;
    __m256i c0_lo, c0_hi, c1_lo, c1_hi, c2_lo, c2_hi, c3_lo, c3_hi;

    // The sum of two 16-bit values can need 17 bits, so widen to 32 bits
    // before the first butterfly.
    sign_extend_16bit_to_32bit_avx2(_mm256_loadu_si256((const __m256i *)in),
                                    zero, &c0_lo, &c0_hi);
    sign_extend_16bit_to_32bit_avx2(
        _mm256_loadu_si256((const __m256i *)(in + 256)), zero, &c1_lo, &c1_hi);
    sign_extend_16bit_to_32bit_avx2(
        _mm256_loadu_si256((const __m256i *)(in + 512)), zero, &c2_lo, &c2_hi);
    sign_extend_16bit_to_32bit_avx2(
        _mm256_loadu_si256((const __m256i *)(in + 768)), zero, &c3_lo, &c3_hi);

    // First butterfly in 32 bits, scaled by >> 2, then narrowed back to
    // 16 bits with signed saturation. packs_epi32 works per 128-bit lane,
    // which undoes the per-lane lo/hi split made by the unpacks above.
    const __m256i s01 = _mm256_packs_epi32(
        _mm256_srai_epi32(_mm256_add_epi32(c0_lo, c1_lo), 2),
        _mm256_srai_epi32(_mm256_add_epi32(c0_hi, c1_hi), 2));
    const __m256i d01 = _mm256_packs_epi32(
        _mm256_srai_epi32(_mm256_sub_epi32(c0_lo, c1_lo), 2),
        _mm256_srai_epi32(_mm256_sub_epi32(c0_hi, c1_hi), 2));
    const __m256i s23 = _mm256_packs_epi32(
        _mm256_srai_epi32(_mm256_add_epi32(c2_lo, c3_lo), 2),
        _mm256_srai_epi32(_mm256_add_epi32(c2_hi, c3_hi), 2));
    const __m256i d23 = _mm256_packs_epi32(
        _mm256_srai_epi32(_mm256_sub_epi32(c2_lo, c3_lo), 2),
        _mm256_srai_epi32(_mm256_sub_epi32(c2_hi, c3_hi), 2));

    // Second butterfly in 16 bits, written out through store_tran_low().
    store_tran_low(_mm256_add_epi16(s01, s23), coeff + pos);
    store_tran_low(_mm256_add_epi16(d01, d23), coeff + pos + 256);
    store_tran_low(_mm256_sub_epi16(s01, s23), coeff + pos + 512);
    store_tran_low(_mm256_sub_epi16(d01, d23), coeff + pos + 768);
  }
}
297 
298 #if CONFIG_AV1_HIGHBITDEPTH
// One pass of an 8-point Hadamard butterfly on an 8x8 block of int32 values.
//
// in[0..7] are eight vectors of eight int32 values (one 8-element row per
// vector, spanning both 128-bit lanes). The three butterfly stages combine
// in[0..7] element-wise using wrapping 32-bit add/sub.
//
// iter == 0: the transformed vectors are then rearranged with 32-bit and
//            64-bit unpacks (per 128-bit lane), followed by cross-lane
//            permute2x128.
// iter != 0: results are written back without rearranging.
//
// In both paths the third stage places its outputs in this permuted order:
// [0]=a0+a4, [7]=a1+a5, [3]=a2+a6, [4]=a3+a7,
// [2]=a0-a4, [6]=a1-a5, [1]=a2-a6, [5]=a3-a7.
//
// NOTE(review): on the iter == 0 path, permute2x128 pairs lane 0 with lane 0
// (0x20) and lane 1 with lane 1 (0x31). The result is in[2k] = element k and
// in[2k + 1] = element k + 4 (k = 0..3) of the eight transformed vectors.
// That is a transpose whose rows come out in the order 0,4,1,5,2,6,3,7,
// unlike the natural-order per-lane transpose in hadamard_col8x2_avx2.
// Confirm the resulting coefficient order is what callers expect.
static void highbd_hadamard_col8_avx2(__m256i *in, int iter) {
  __m256i a0 = in[0];
  __m256i a1 = in[1];
  __m256i a2 = in[2];
  __m256i a3 = in[3];
  __m256i a4 = in[4];
  __m256i a5 = in[5];
  __m256i a6 = in[6];
  __m256i a7 = in[7];

  // Stage 1: butterflies on adjacent vector pairs.
  __m256i b0 = _mm256_add_epi32(a0, a1);
  __m256i b1 = _mm256_sub_epi32(a0, a1);
  __m256i b2 = _mm256_add_epi32(a2, a3);
  __m256i b3 = _mm256_sub_epi32(a2, a3);
  __m256i b4 = _mm256_add_epi32(a4, a5);
  __m256i b5 = _mm256_sub_epi32(a4, a5);
  __m256i b6 = _mm256_add_epi32(a6, a7);
  __m256i b7 = _mm256_sub_epi32(a6, a7);

  // Stage 2: butterflies at distance 2.
  a0 = _mm256_add_epi32(b0, b2);
  a1 = _mm256_add_epi32(b1, b3);
  a2 = _mm256_sub_epi32(b0, b2);
  a3 = _mm256_sub_epi32(b1, b3);
  a4 = _mm256_add_epi32(b4, b6);
  a5 = _mm256_add_epi32(b5, b7);
  a6 = _mm256_sub_epi32(b4, b6);
  a7 = _mm256_sub_epi32(b5, b7);

  if (iter == 0) {
    // Stage 3: butterflies at distance 4, outputs in the permuted order
    // described above.
    b0 = _mm256_add_epi32(a0, a4);
    b7 = _mm256_add_epi32(a1, a5);
    b3 = _mm256_add_epi32(a2, a6);
    b4 = _mm256_add_epi32(a3, a7);
    b2 = _mm256_sub_epi32(a0, a4);
    b6 = _mm256_sub_epi32(a1, a5);
    b1 = _mm256_sub_epi32(a2, a6);
    b5 = _mm256_sub_epi32(a3, a7);

    // 4x4 transposes of 32-bit elements inside each 128-bit lane.
    a0 = _mm256_unpacklo_epi32(b0, b1);
    a1 = _mm256_unpacklo_epi32(b2, b3);
    a2 = _mm256_unpackhi_epi32(b0, b1);
    a3 = _mm256_unpackhi_epi32(b2, b3);
    a4 = _mm256_unpacklo_epi32(b4, b5);
    a5 = _mm256_unpacklo_epi32(b6, b7);
    a6 = _mm256_unpackhi_epi32(b4, b5);
    a7 = _mm256_unpackhi_epi32(b6, b7);

    b0 = _mm256_unpacklo_epi64(a0, a1);
    b1 = _mm256_unpacklo_epi64(a4, a5);
    b2 = _mm256_unpackhi_epi64(a0, a1);
    b3 = _mm256_unpackhi_epi64(a4, a5);
    b4 = _mm256_unpacklo_epi64(a2, a3);
    b5 = _mm256_unpacklo_epi64(a6, a7);
    b6 = _mm256_unpackhi_epi64(a2, a3);
    b7 = _mm256_unpackhi_epi64(a6, a7);

    // Combine 128-bit halves across registers (see NOTE above for the
    // resulting row order).
    in[0] = _mm256_permute2x128_si256(b0, b1, 0x20);
    in[1] = _mm256_permute2x128_si256(b0, b1, 0x31);
    in[2] = _mm256_permute2x128_si256(b2, b3, 0x20);
    in[3] = _mm256_permute2x128_si256(b2, b3, 0x31);
    in[4] = _mm256_permute2x128_si256(b4, b5, 0x20);
    in[5] = _mm256_permute2x128_si256(b4, b5, 0x31);
    in[6] = _mm256_permute2x128_si256(b6, b7, 0x20);
    in[7] = _mm256_permute2x128_si256(b6, b7, 0x31);
  } else {
    // Stage 3 written straight to the output in the same permuted order.
    in[0] = _mm256_add_epi32(a0, a4);
    in[7] = _mm256_add_epi32(a1, a5);
    in[3] = _mm256_add_epi32(a2, a6);
    in[4] = _mm256_add_epi32(a3, a7);
    in[2] = _mm256_sub_epi32(a0, a4);
    in[6] = _mm256_sub_epi32(a1, a5);
    in[1] = _mm256_sub_epi32(a2, a6);
    in[5] = _mm256_sub_epi32(a3, a7);
  }
}
374 
// High-bitdepth 8x8 Hadamard transform. Each row of eight int16 residuals is
// widened to int32, both butterfly passes run in 32-bit arithmetic, and the
// 64 results are stored to coeff[0..63] (eight per output vector).
void aom_highbd_hadamard_8x8_avx2(const int16_t *src_diff, ptrdiff_t src_stride,
                                  tran_low_t *coeff) {
  __m256i rows[8];

  for (int r = 0; r < 8; ++r) {
    const __m128i row16 =
        _mm_loadu_si128((const __m128i *)(src_diff + r * src_stride));
    rows[r] = _mm256_cvtepi16_epi32(row16);
  }

  highbd_hadamard_col8_avx2(rows, 0);
  highbd_hadamard_col8_avx2(rows, 1);

  for (int r = 0; r < 8; ++r) {
    _mm256_storeu_si256((__m256i *)(coeff + 8 * r), rows[r]);
  }
}
417 
// High-bitdepth 16x16 Hadamard transform: four 8x8 transforms followed by a
// final butterfly across them, scaled by >> 1, in 32-bit arithmetic.
void aom_highbd_hadamard_16x16_avx2(const int16_t *src_diff,
                                    ptrdiff_t src_stride, tran_low_t *coeff) {
  // 8x8 sub-block q (0: top-left, 1: top-right, 2: bottom-left,
  // 3: bottom-right) is transformed into coeff[64 * q .. 64 * q + 63].
  for (int q = 0; q < 4; ++q) {
    const int16_t *const blk =
        src_diff + (q >> 1) * 8 * src_stride + (q & 0x01) * 8;
    aom_highbd_hadamard_8x8_avx2(blk, src_stride, coeff + 64 * q);
  }

  // Final butterfly across the four sub-blocks, computed in place eight
  // coefficients at a time.
  for (tran_low_t *p = coeff; p < coeff + 64; p += 8) {
    const __m256i c0 = _mm256_loadu_si256((const __m256i *)p);
    const __m256i c1 = _mm256_loadu_si256((const __m256i *)(p + 64));
    const __m256i c2 = _mm256_loadu_si256((const __m256i *)(p + 128));
    const __m256i c3 = _mm256_loadu_si256((const __m256i *)(p + 192));

    const __m256i s01 = _mm256_srai_epi32(_mm256_add_epi32(c0, c1), 1);
    const __m256i d01 = _mm256_srai_epi32(_mm256_sub_epi32(c0, c1), 1);
    const __m256i s23 = _mm256_srai_epi32(_mm256_add_epi32(c2, c3), 1);
    const __m256i d23 = _mm256_srai_epi32(_mm256_sub_epi32(c2, c3), 1);

    _mm256_storeu_si256((__m256i *)p, _mm256_add_epi32(s01, s23));
    _mm256_storeu_si256((__m256i *)(p + 64), _mm256_add_epi32(d01, d23));
    _mm256_storeu_si256((__m256i *)(p + 128), _mm256_sub_epi32(s01, s23));
    _mm256_storeu_si256((__m256i *)(p + 192), _mm256_sub_epi32(d01, d23));
  }
}
458 
// High-bitdepth 32x32 Hadamard transform: four 16x16 transforms followed by a
// final butterfly across them, scaled by >> 2, in 32-bit arithmetic.
void aom_highbd_hadamard_32x32_avx2(const int16_t *src_diff,
                                    ptrdiff_t src_stride, tran_low_t *coeff) {
  // 16x16 quadrant q (0: top-left, 1: top-right, 2: bottom-left,
  // 3: bottom-right) is transformed into coeff[256 * q .. 256 * q + 255].
  for (int q = 0; q < 4; ++q) {
    const int16_t *const quad =
        src_diff + (q >> 1) * 16 * src_stride + (q & 0x01) * 16;
    aom_highbd_hadamard_16x16_avx2(quad, src_stride, coeff + 256 * q);
  }

  // Final butterfly across the four quadrants, computed in place eight
  // coefficients at a time.
  for (tran_low_t *p = coeff; p < coeff + 256; p += 8) {
    const __m256i c0 = _mm256_loadu_si256((const __m256i *)p);
    const __m256i c1 = _mm256_loadu_si256((const __m256i *)(p + 256));
    const __m256i c2 = _mm256_loadu_si256((const __m256i *)(p + 512));
    const __m256i c3 = _mm256_loadu_si256((const __m256i *)(p + 768));

    const __m256i s01 = _mm256_srai_epi32(_mm256_add_epi32(c0, c1), 2);
    const __m256i d01 = _mm256_srai_epi32(_mm256_sub_epi32(c0, c1), 2);
    const __m256i s23 = _mm256_srai_epi32(_mm256_add_epi32(c2, c3), 2);
    const __m256i d23 = _mm256_srai_epi32(_mm256_sub_epi32(c2, c3), 2);

    _mm256_storeu_si256((__m256i *)p, _mm256_add_epi32(s01, s23));
    _mm256_storeu_si256((__m256i *)(p + 256), _mm256_add_epi32(d01, d23));
    _mm256_storeu_si256((__m256i *)(p + 512), _mm256_sub_epi32(s01, s23));
    _mm256_storeu_si256((__m256i *)(p + 768), _mm256_sub_epi32(d01, d23));
  }
}
499 #endif  // CONFIG_AV1_HIGHBITDEPTH
500 
// Sum of absolute values of the first |length| coefficients.
// Whole groups of eight are accumulated with AVX2 in 32-bit lanes (wrapping
// adds). The trailing length % 8 coefficients are added with scalar code, so
// the function never reads beyond coeff[length - 1]. The previous loop
// stepped by 8 unconditionally and over-read the buffer whenever |length|
// was not a multiple of 8. A non-positive |length| returns 0.
int aom_satd_avx2(const tran_low_t *coeff, int length) {
  __m256i accum = _mm256_setzero_si256();
  int i = 0;

  for (; i + 8 <= length; i += 8) {
    const __m256i src_line = _mm256_loadu_si256((const __m256i *)(coeff + i));
    const __m256i abs = _mm256_abs_epi32(src_line);
    accum = _mm256_add_epi32(accum, abs);
  }

  int satd;
  {  // 32 bit horizontal add
    const __m256i a = _mm256_srli_si256(accum, 8);
    const __m256i b = _mm256_add_epi32(accum, a);
    const __m256i c = _mm256_srli_epi64(b, 32);
    const __m256i d = _mm256_add_epi32(b, c);
    const __m128i accum_128 = _mm_add_epi32(_mm256_castsi256_si128(d),
                                            _mm256_extractf128_si256(d, 1));
    satd = _mm_cvtsi128_si32(accum_128);
  }

  // Scalar tail for lengths that are not a multiple of 8.
  for (; i < length; ++i) {
    const tran_low_t v = coeff[i];
    satd += (v < 0) ? -v : v;
  }
  return satd;
}
521 
aom_satd_lp_avx2(const int16_t * coeff,int length)522 int aom_satd_lp_avx2(const int16_t *coeff, int length) {
523   const __m256i one = _mm256_set1_epi16(1);
524   __m256i accum = _mm256_setzero_si256();
525 
526   for (int i = 0; i < length; i += 16) {
527     const __m256i src_line = _mm256_loadu_si256((const __m256i *)coeff);
528     const __m256i abs = _mm256_abs_epi16(src_line);
529     const __m256i sum = _mm256_madd_epi16(abs, one);
530     accum = _mm256_add_epi32(accum, sum);
531     coeff += 16;
532   }
533 
534   {  // 32 bit horizontal add
535     const __m256i a = _mm256_srli_si256(accum, 8);
536     const __m256i b = _mm256_add_epi32(accum, a);
537     const __m256i c = _mm256_srli_epi64(b, 32);
538     const __m256i d = _mm256_add_epi32(b, c);
539     const __m128i accum_128 = _mm_add_epi32(_mm256_castsi256_si128(d),
540                                             _mm256_extractf128_si256(d, 1));
541     return _mm_cvtsi128_si32(accum_128);
542   }
543 }
544 
xx_loadu2_mi128(const void * hi,const void * lo)545 static INLINE __m256i xx_loadu2_mi128(const void *hi, const void *lo) {
546   __m256i a = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(lo)));
547   a = _mm256_inserti128_si256(a, _mm_loadu_si128((const __m128i *)(hi)), 1);
548   return a;
549 }
550 
aom_avg_8x8_quad_avx2(const uint8_t * s,int p,int x16_idx,int y16_idx,int * avg)551 void aom_avg_8x8_quad_avx2(const uint8_t *s, int p, int x16_idx, int y16_idx,
552                            int *avg) {
553   const uint8_t *s_y0 = s + y16_idx * p + x16_idx;
554   const uint8_t *s_y1 = s_y0 + 8 * p;
555   __m256i sum0, sum1, s0, s1, s2, s3, u0;
556   u0 = _mm256_setzero_si256();
557   s0 = _mm256_sad_epu8(xx_loadu2_mi128(s_y1, s_y0), u0);
558   s1 = _mm256_sad_epu8(xx_loadu2_mi128(s_y1 + p, s_y0 + p), u0);
559   s2 = _mm256_sad_epu8(xx_loadu2_mi128(s_y1 + 2 * p, s_y0 + 2 * p), u0);
560   s3 = _mm256_sad_epu8(xx_loadu2_mi128(s_y1 + 3 * p, s_y0 + 3 * p), u0);
561   sum0 = _mm256_add_epi16(s0, s1);
562   sum1 = _mm256_add_epi16(s2, s3);
563   s0 = _mm256_sad_epu8(xx_loadu2_mi128(s_y1 + 4 * p, s_y0 + 4 * p), u0);
564   s1 = _mm256_sad_epu8(xx_loadu2_mi128(s_y1 + 5 * p, s_y0 + 5 * p), u0);
565   s2 = _mm256_sad_epu8(xx_loadu2_mi128(s_y1 + 6 * p, s_y0 + 6 * p), u0);
566   s3 = _mm256_sad_epu8(xx_loadu2_mi128(s_y1 + 7 * p, s_y0 + 7 * p), u0);
567   sum0 = _mm256_add_epi16(sum0, _mm256_add_epi16(s0, s1));
568   sum1 = _mm256_add_epi16(sum1, _mm256_add_epi16(s2, s3));
569   sum0 = _mm256_add_epi16(sum0, sum1);
570 
571   // (avg + 32) >> 6
572   __m256i rounding = _mm256_set1_epi32(32);
573   sum0 = _mm256_add_epi32(sum0, rounding);
574   sum0 = _mm256_srli_epi32(sum0, 6);
575   __m128i lo = _mm256_castsi256_si128(sum0);
576   __m128i hi = _mm256_extracti128_si256(sum0, 1);
577   avg[0] = _mm_cvtsi128_si32(lo);
578   avg[1] = _mm_extract_epi32(lo, 2);
579   avg[2] = _mm_cvtsi128_si32(hi);
580   avg[3] = _mm_extract_epi32(hi, 2);
581 }
582 
// Horizontal projection: for every column c in [0, width), hbuf[c] receives
// the sum of ref[r * ref_stride + c] over the processed rows, arithmetically
// shifted right by |norm_factor|. Widths that are a multiple of 32 are handled
// here in 32-column strips. Other multiples of 16 are delegated to the SSE2
// version.
void aom_int_pro_row_avx2(int16_t *hbuf, const uint8_t *ref,
                          const int ref_stride, const int width,
                          const int height, int norm_factor) {
  // SIMD implementation assumes width and height to be multiple of 16 and 2
  // respectively. For any odd width or height, SIMD support needs to be added.
  assert(width % 16 == 0 && height % 2 == 0);

  if (width % 32 != 0) {
    if (width % 16 == 0) {
      aom_int_pro_row_sse2(hbuf, ref, ref_stride, width, height, norm_factor);
    }
    return;
  }

  const __m256i zero = _mm256_setzero_si256();
  for (int col = 0; col < width; col += 32) {
    const uint8_t *src = ref + col;
    // unpacklo/unpackhi operate per 128-bit lane: sum_a collects strip bytes
    // 0-7 | 16-23 and sum_b collects bytes 8-15 | 24-31 as 16-bit sums.
    __m256i sum_a = zero;
    __m256i sum_b = zero;
    int rows_left = height;
    do {  // Two rows per iteration; always runs at least once.
      for (int k = 0; k < 2; ++k) {
        const __m256i px = _mm256_loadu_si256((const __m256i *)src);
        sum_a = _mm256_add_epi16(sum_a, _mm256_unpacklo_epi8(px, zero));
        sum_b = _mm256_add_epi16(sum_b, _mm256_unpackhi_epi8(px, zero));
        src += ref_stride;
      }
      rows_left -= 2;
    } while (rows_left > 0);

    sum_a = _mm256_srai_epi16(sum_a, norm_factor);
    sum_b = _mm256_srai_epi16(sum_b, norm_factor);

    // Restore natural column order when writing the four 8-column groups.
    int16_t *const dst = hbuf + col;
    _mm_storeu_si128((__m128i *)dst, _mm256_castsi256_si128(sum_a));
    _mm_storeu_si128((__m128i *)(dst + 8), _mm256_castsi256_si128(sum_b));
    _mm_storeu_si128((__m128i *)(dst + 16), _mm256_extractf128_si256(sum_a, 1));
    _mm_storeu_si128((__m128i *)(dst + 24), _mm256_extractf128_si256(sum_b, 1));
  }
}
627 
load_from_src_buf(const uint8_t * ref1,__m256i * src,const int stride)628 static INLINE void load_from_src_buf(const uint8_t *ref1, __m256i *src,
629                                      const int stride) {
630   src[0] = _mm256_loadu_si256((const __m256i *)ref1);
631   src[1] = _mm256_loadu_si256((const __m256i *)(ref1 + stride));
632   src[2] = _mm256_loadu_si256((const __m256i *)(ref1 + (2 * stride)));
633   src[3] = _mm256_loadu_si256((const __m256i *)(ref1 + (3 * stride)));
634 }
635 
// Expects these names in the enclosing scope:
// - four __m256i values r0..r3, one per row of a four-row group, each
//   holding 16-bit partial sums in the low word of every 64-bit element;
// - int16_t *vbuf;
// - norm_factor.
// Row K's sums are shifted into word K of each 64-bit element. The two
// 128-bit lanes and the two 64-bit halves are then folded together, and the
// four row totals (>> norm_factor) are stored to vbuf[0..3]. The macro
// declares r01, r012, result0, results0 and results1 in the enclosing scope.
#define CALC_TOT_SAD_AND_STORE                                                \
  /* r00 r10 x x r01 r11 x x | r02 r12 x x r03 r13 x x */                     \
  const __m256i r01 = _mm256_add_epi16(r0, _mm256_slli_si256(r1, 2));         \
  /* r00 r10 r20 x r01 r11 r21 x | r02 r12 r22 x r03 r13 r23 x */             \
  const __m256i r012 = _mm256_add_epi16(r01, _mm256_slli_si256(r2, 4));       \
  /* r00 r10 r20 r30 r01 r11 r21 r31 | r02 r12 r22 r32 r03 r13 r23 r33 */     \
  const __m256i result0 = _mm256_add_epi16(r012, _mm256_slli_si256(r3, 6));   \
  /* fold the high 128-bit lane onto the low one */                           \
  const __m128i results0 =                                                    \
      _mm_add_epi16(_mm256_extracti128_si256(result0, 1),                     \
                    _mm256_castsi256_si128(result0));                         \
  /* fold the high 64 bits onto the low 64 bits: words 0-3 = row totals */    \
  const __m128i results1 =                                                    \
      _mm_add_epi16(_mm_srli_si128(results0, 8), results0);                   \
  _mm_storel_epi64((__m128i *)vbuf, _mm_srli_epi16(results1, norm_factor));
649 
aom_int_pro_col_16wd_avx2(int16_t * vbuf,const uint8_t * ref,const int ref_stride,const int height,int norm_factor)650 static INLINE void aom_int_pro_col_16wd_avx2(int16_t *vbuf, const uint8_t *ref,
651                                              const int ref_stride,
652                                              const int height,
653                                              int norm_factor) {
654   const __m256i zero = _mm256_setzero_si256();
655   int ht = 0;
656   // Post sad operation, the data is present in lower 16-bit of each 64-bit lane
657   // and higher 16-bits are Zero. Here, we are processing 8 rows at a time to
658   // utilize the higher 16-bits efficiently.
659   do {
660     __m256i src_00 =
661         _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(ref)));
662     src_00 = _mm256_inserti128_si256(
663         src_00, _mm_loadu_si128((const __m128i *)(ref + ref_stride * 4)), 1);
664     __m256i src_01 = _mm256_castsi128_si256(
665         _mm_loadu_si128((const __m128i *)(ref + ref_stride * 1)));
666     src_01 = _mm256_inserti128_si256(
667         src_01, _mm_loadu_si128((const __m128i *)(ref + ref_stride * 5)), 1);
668     __m256i src_10 = _mm256_castsi128_si256(
669         _mm_loadu_si128((const __m128i *)(ref + ref_stride * 2)));
670     src_10 = _mm256_inserti128_si256(
671         src_10, _mm_loadu_si128((const __m128i *)(ref + ref_stride * 6)), 1);
672     __m256i src_11 = _mm256_castsi128_si256(
673         _mm_loadu_si128((const __m128i *)(ref + ref_stride * 3)));
674     src_11 = _mm256_inserti128_si256(
675         src_11, _mm_loadu_si128((const __m128i *)(ref + ref_stride * 7)), 1);
676 
677     // s00 x x x s01 x x x | s40 x x x s41 x x x
678     const __m256i s0 = _mm256_sad_epu8(src_00, zero);
679     // s10 x x x s11 x x x | s50 x x x s51 x x x
680     const __m256i s1 = _mm256_sad_epu8(src_01, zero);
681     // s20 x x x s21 x x x | s60 x x x s61 x x x
682     const __m256i s2 = _mm256_sad_epu8(src_10, zero);
683     // s30 x x x s31 x x x | s70 x x x s71 x x x
684     const __m256i s3 = _mm256_sad_epu8(src_11, zero);
685 
686     // s00 s10 x x x x x x | s40 s50 x x x x x x
687     const __m256i s0_lo = _mm256_unpacklo_epi16(s0, s1);
688     // s01 s11 x x x x x x | s41 s51 x x x x x x
689     const __m256i s0_hi = _mm256_unpackhi_epi16(s0, s1);
690     // s20 s30 x x x x x x | s60 s70 x x x x x x
691     const __m256i s1_lo = _mm256_unpacklo_epi16(s2, s3);
692     // s21 s31 x x x x x x | s61 s71 x x x x x x
693     const __m256i s1_hi = _mm256_unpackhi_epi16(s2, s3);
694 
695     // s0 s1 x x x x x x | s4 s5 x x x x x x
696     const __m256i s0_add = _mm256_add_epi16(s0_lo, s0_hi);
697     // s2 s3 x x x x x x | s6 s7 x x x x x x
698     const __m256i s1_add = _mm256_add_epi16(s1_lo, s1_hi);
699 
700     // s1 s1 s2 s3 s4 s5 s6 s7
701     const __m128i results = _mm256_castsi256_si128(
702         _mm256_permute4x64_epi64(_mm256_unpacklo_epi32(s0_add, s1_add), 0x08));
703     _mm_storeu_si128((__m128i *)vbuf, _mm_srli_epi16(results, norm_factor));
704     vbuf += 8;
705     ref += (ref_stride << 3);
706     ht += 8;
707   } while (ht < height);
708 }
709 
// Computes one int16 output per row of `ref` (vbuf advances by 4 for every 4
// rows consumed), where each output is built from the sum of the row's first
// `width` bytes.
//
// Byte sums are formed with _mm256_sad_epu8 against zero: each 64-bit lane
// receives the sum of one 8-byte group (at most 8 * 255 = 2040) in its low 16
// bits. At most four such sums are then added in 16-bit lanes (width 128),
// which cannot overflow. The per-row partial sums r0..r3 are handed to
// CALC_TOT_SAD_AND_STORE, which is defined elsewhere in this file.
// NOTE(review): presumably that macro reduces the four 64-bit lanes of each
// rN, applies the `norm_factor` right shift (as the 16-wide helper does) and
// stores 4 results to vbuf — confirm against its definition.
//
// width:  only 16, 32, 64 and 128 are handled. Other multiples of 16 pass the
//         assert but match no branch, so `vbuf` is left untouched.
// height: the 32/64/128 paths consume 4 rows per iteration; the width-16
//         helper consumes 8 rows per iteration. NOTE(review): assumes height is
//         a multiple of 4 (8 for width 16); not fully asserted here.
aom_int_pro_col_avx2(int16_t * vbuf,const uint8_t * ref,const int ref_stride,const int width,const int height,int norm_factor)710 void aom_int_pro_col_avx2(int16_t *vbuf, const uint8_t *ref,
711                           const int ref_stride, const int width,
712                           const int height, int norm_factor) {
713   assert(width % 16 == 0);
714   if (width == 128) {
715     const __m256i zero = _mm256_setzero_si256();
716     for (int ht = 0; ht < height; ht += 4) {
717       __m256i src[16];
      // NOTE(review): load_from_src_buf is defined elsewhere; judging by the
      // indexing below, each call appears to fill 4 consecutive src entries
      // with 32 bytes from each of 4 consecutive rows (src[k] = row k % 4).
718       // Load source data.
719       load_from_src_buf(ref, &src[0], ref_stride);
720       load_from_src_buf(ref + 32, &src[4], ref_stride);
721       load_from_src_buf(ref + 64, &src[8], ref_stride);
722       load_from_src_buf(ref + 96, &src[12], ref_stride);
723 
      // Each row sums four 32-byte column chunks (offsets 0, 32, 64, 96).
724       // Row0 output: r00 x x x r01 x x x | r02 x x x r03 x x x
725       const __m256i s0 = _mm256_add_epi16(_mm256_sad_epu8(src[0], zero),
726                                           _mm256_sad_epu8(src[4], zero));
727       const __m256i s1 = _mm256_add_epi16(_mm256_sad_epu8(src[8], zero),
728                                           _mm256_sad_epu8(src[12], zero));
729       const __m256i r0 = _mm256_add_epi16(s0, s1);
730       // Row1 output: r10 x x x r11 x x x | r12 x x x r13 x x x
731       const __m256i s2 = _mm256_add_epi16(_mm256_sad_epu8(src[1], zero),
732                                           _mm256_sad_epu8(src[5], zero));
733       const __m256i s3 = _mm256_add_epi16(_mm256_sad_epu8(src[9], zero),
734                                           _mm256_sad_epu8(src[13], zero));
735       const __m256i r1 = _mm256_add_epi16(s2, s3);
736       // Row2 output: r20 x x x r21 x x x | r22 x x x r23 x x x
737       const __m256i s4 = _mm256_add_epi16(_mm256_sad_epu8(src[2], zero),
738                                           _mm256_sad_epu8(src[6], zero));
739       const __m256i s5 = _mm256_add_epi16(_mm256_sad_epu8(src[10], zero),
740                                           _mm256_sad_epu8(src[14], zero));
741       const __m256i r2 = _mm256_add_epi16(s4, s5);
742       // Row3 output: r30 x x x r31 x x x | r32 x x x r33 x x x
743       const __m256i s6 = _mm256_add_epi16(_mm256_sad_epu8(src[3], zero),
744                                           _mm256_sad_epu8(src[7], zero));
745       const __m256i s7 = _mm256_add_epi16(_mm256_sad_epu8(src[11], zero),
746                                           _mm256_sad_epu8(src[15], zero));
747       const __m256i r3 = _mm256_add_epi16(s6, s7);
748 
      // Macro defined elsewhere; it consumes the locals r0..r3 by name.
749       CALC_TOT_SAD_AND_STORE
750       vbuf += 4;
751       ref += ref_stride << 2;
752     }
753   } else if (width == 64) {
754     const __m256i zero = _mm256_setzero_si256();
755     for (int ht = 0; ht < height; ht += 4) {
756       __m256i src[8];
757       // Load source data.
758       load_from_src_buf(ref, &src[0], ref_stride);
759       load_from_src_buf(ref + 32, &src[4], ref_stride);
760 
      // Each row sums two 32-byte column chunks (offsets 0 and 32).
761       // Row0 output: r00 x x x r01 x x x | r02 x x x r03 x x x
762       const __m256i s0 = _mm256_sad_epu8(src[0], zero);
763       const __m256i s1 = _mm256_sad_epu8(src[4], zero);
764       const __m256i r0 = _mm256_add_epi16(s0, s1);
765       // Row1 output: r10 x x x r11 x x x | r12 x x x r13 x x x
766       const __m256i s2 = _mm256_sad_epu8(src[1], zero);
767       const __m256i s3 = _mm256_sad_epu8(src[5], zero);
768       const __m256i r1 = _mm256_add_epi16(s2, s3);
769       // Row2 output: r20 x x x r21 x x x | r22 x x x r23 x x x
770       const __m256i s4 = _mm256_sad_epu8(src[2], zero);
771       const __m256i s5 = _mm256_sad_epu8(src[6], zero);
772       const __m256i r2 = _mm256_add_epi16(s4, s5);
773       // Row3 output: r30 x x x r31 x x x | r32 x x x r33 x x x
774       const __m256i s6 = _mm256_sad_epu8(src[3], zero);
775       const __m256i s7 = _mm256_sad_epu8(src[7], zero);
776       const __m256i r3 = _mm256_add_epi16(s6, s7);
777 
778       CALC_TOT_SAD_AND_STORE
779       vbuf += 4;
780       ref += ref_stride << 2;
781     }
782   } else if (width == 32) {
    // NOTE(review): this loop also steps 4 rows per iteration, so the real
    // requirement looks like height % 4 == 0; the assert only checks % 2.
783     assert(height % 2 == 0);
784     const __m256i zero = _mm256_setzero_si256();
785     for (int ht = 0; ht < height; ht += 4) {
786       __m256i src[4];
787       // Load source data.
788       load_from_src_buf(ref, &src[0], ref_stride);
789 
      // A single 32-byte chunk per row; no cross-chunk addition needed.
790       // s00 x x x s01 x x x s02 x x x s03 x x x
791       const __m256i r0 = _mm256_sad_epu8(src[0], zero);
792       // s10 x x x s11 x x x s12 x x x s13 x x x
793       const __m256i r1 = _mm256_sad_epu8(src[1], zero);
794       // s20 x x x s21 x x x s22 x x x s23 x x x
795       const __m256i r2 = _mm256_sad_epu8(src[2], zero);
796       // s30 x x x s31 x x x s32 x x x s33 x x x
797       const __m256i r3 = _mm256_sad_epu8(src[3], zero);
798 
799       CALC_TOT_SAD_AND_STORE
800       vbuf += 4;
801       ref += ref_stride << 2;
802     }
803   } else if (width == 16) {
    // Dedicated helper that processes 8 rows per iteration.
804     aom_int_pro_col_16wd_avx2(vbuf, ref, ref_stride, height, norm_factor);
805   }
806 }
807 
// Folds 64 int16 differences d[i] = ref[i] - src[i] into two running AVX2
// accumulators:
//   *mean: 16 int16 lanes; lane j gains d[j] + d[j + 16] + d[j + 32] + d[j + 48].
//   *sse:  8 int32 lanes; lane j gains d[2j]^2 + d[2j + 1]^2 from each of the
//          four 16-element chunks (via _mm256_madd_epi16(d, d)).
// All lane arithmetic is modular (non-saturating). `ref` and `src` are read
// with unaligned loads and must each provide at least 64 elements.
static inline void calc_vector_mean_sse_64wd(const int16_t *ref,
                                             const int16_t *src, __m256i *mean,
                                             __m256i *sse) {
  __m256i sum_diff = *mean;
  __m256i sum_sq = *sse;

  for (int offset = 0; offset < 64; offset += 16) {
    const __m256i r = _mm256_loadu_si256((const __m256i *)(ref + offset));
    const __m256i s = _mm256_loadu_si256((const __m256i *)(src + offset));
    const __m256i d = _mm256_sub_epi16(r, s);

    sum_diff = _mm256_add_epi16(sum_diff, d);
    // madd(d, d) squares each element and adds adjacent pairs into 32 bits.
    sum_sq = _mm256_add_epi32(sum_sq, _mm256_madd_epi16(d, d));
  }

  *mean = sum_diff;
  *sse = sum_sq;
}
836 
// Reduces the `mean` (16 x int16 sum of differences) and `sse` (8 x int32 sum
// of squared differences) accumulators across all lanes and assigns
//   var = total_sse - ((|total_mean| * |total_mean|) >> (bwl + 2))
// Since width == 4 << bwl, the shift is a floor division of the squared sum by
// the width. The squared term is unsigned, so sse_int is converted to unsigned
// for the subtraction and the result is converted back to int on assignment.
//
// Expansion-site requirements: `var` (int) and `bwl` must be in scope, and
// `mean` is overwritten (used as scratch). The expansion is a bare `{ ... }`
// block, so invocations are written without a trailing semicolon.
837 #define CALC_VAR_FROM_MEAN_SSE(mean, sse)                                    \
838   {                                                                          \
    /* Widen: add adjacent int16 pairs of mean into 8 int32 lanes. */          \
839     mean = _mm256_madd_epi16(mean, _mm256_set1_epi16(1));                    \
    /* Per 128-bit half: [m0+m1, m2+m3, s0+s1, s2+s3]. */                      \
840     mean = _mm256_hadd_epi32(mean, sse);                                     \
    /* Per half: element 0 = mean total, element 2 = sse total. */             \
841     mean = _mm256_add_epi32(mean, _mm256_bsrli_epi128(mean, 4));             \
    /* Combine the two 128-bit halves. */                                      \
842     const __m128i result = _mm_add_epi32(_mm256_castsi256_si128(mean),       \
843                                          _mm256_extractf128_si256(mean, 1)); \
844     /* |mean| <= 128 * 510 = 65280 for ref/src in [0, 510], so             \
       mean * mean needs up to 32 bits; mean_abs is unsigned to hold it. */    \
845     const int mean_int = _mm_extract_epi32(result, 0);                       \
846     const int sse_int = _mm_extract_epi32(result, 2);                        \
847     const unsigned int mean_abs = abs(mean_int);                             \
848     var = sse_int - ((mean_abs * mean_abs) >> (bwl + 2));                    \
849   }
850 
851 // ref: [0 - 510]
852 // src: [0 - 510]
853 // bwl: {2, 3, 4, 5}
aom_vector_var_avx2(const int16_t * ref,const int16_t * src,int bwl)854 int aom_vector_var_avx2(const int16_t *ref, const int16_t *src, int bwl) {
855   const int width = 4 << bwl;
856   assert(width % 16 == 0 && width <= 128);
857   int var = 0;
858 
859   // Instead of having a loop over width 16, considered loop unrolling to avoid
860   // some addition operations.
861   if (width == 128) {
862     __m256i mean = _mm256_setzero_si256();
863     __m256i sse = _mm256_setzero_si256();
864 
865     calc_vector_mean_sse_64wd(src, ref, &mean, &sse);
866     calc_vector_mean_sse_64wd(src + 64, ref + 64, &mean, &sse);
867     CALC_VAR_FROM_MEAN_SSE(mean, sse)
868   } else if (width == 64) {
869     __m256i mean = _mm256_setzero_si256();
870     __m256i sse = _mm256_setzero_si256();
871 
872     calc_vector_mean_sse_64wd(src, ref, &mean, &sse);
873     CALC_VAR_FROM_MEAN_SSE(mean, sse)
874   } else if (width == 32) {
875     const __m256i src_line0 = _mm256_loadu_si256((const __m256i *)src);
876     const __m256i ref_line0 = _mm256_loadu_si256((const __m256i *)ref);
877     const __m256i src_line1 = _mm256_loadu_si256((const __m256i *)(src + 16));
878     const __m256i ref_line1 = _mm256_loadu_si256((const __m256i *)(ref + 16));
879 
880     const __m256i diff0 = _mm256_sub_epi16(ref_line0, src_line0);
881     const __m256i diff1 = _mm256_sub_epi16(ref_line1, src_line1);
882     const __m256i diff_sqr0 = _mm256_madd_epi16(diff0, diff0);
883     const __m256i diff_sqr1 = _mm256_madd_epi16(diff1, diff1);
884     const __m256i sse = _mm256_add_epi32(diff_sqr0, diff_sqr1);
885     __m256i mean = _mm256_add_epi16(diff0, diff1);
886 
887     CALC_VAR_FROM_MEAN_SSE(mean, sse)
888   } else if (width == 16) {
889     const __m256i src_line = _mm256_loadu_si256((const __m256i *)src);
890     const __m256i ref_line = _mm256_loadu_si256((const __m256i *)ref);
891     __m256i mean = _mm256_sub_epi16(ref_line, src_line);
892     const __m256i sse = _mm256_madd_epi16(mean, mean);
893 
894     CALC_VAR_FROM_MEAN_SSE(mean, sse)
895   }
896   return var;
897 }
898