/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <immintrin.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms_avx2.h"
#include "aom_ports/mem.h"

// SAD
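// Horizontal reduction helper: collapses the eight 32-bit partial sums held
// in a single 256-bit accumulator into one scalar SAD value.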
static INLINE unsigned int get_sad_from_mm256_epi32(const __m256i *v) {
  // Input: 8 32-bit partial sums.
  __m128i lo128, hi128;
  __m256i u = _mm256_srli_si256(*v, 8);
  u = _mm256_add_epi32(u, *v);

  // Reduce to 4 32-bit sums.
  hi128 = _mm256_extracti128_si256(u, 1);
  lo128 = _mm256_castsi256_si128(u);
  lo128 = _mm_add_epi32(hi128, lo128);

  // Reduce to 2 32-bit sums.
  hi128 = _mm_srli_si128(lo128, 4);
  lo128 = _mm_add_epi32(lo128, hi128);

  return (unsigned int)_mm_cvtsi128_si32(lo128);
}

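// Shared core for one 16-wide x 4-row tile: takes the absolute difference of
// the four source/reference row pairs, sums them at 16-bit precision (four
// rows of up-to-12-bit differences cannot overflow 16 bits), then widens to
// 32 bits and accumulates into *sad_acc.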
static INLINE void highbd_sad16x4_core_avx2(__m256i *s, __m256i *r,
                                            __m256i *sad_acc) {
  const __m256i zero = _mm256_setzero_si256();
  int i;
  for (i = 0; i < 4; i++) {
    s[i] = _mm256_sub_epi16(s[i], r[i]);
    s[i] = _mm256_abs_epi16(s[i]);
  }

  s[0] = _mm256_add_epi16(s[0], s[1]);
  s[0] = _mm256_add_epi16(s[0], s[2]);
  s[0] = _mm256_add_epi16(s[0], s[3]);

  r[0] = _mm256_unpacklo_epi16(s[0], zero);
  r[1] = _mm256_unpackhi_epi16(s[0], zero);

  r[0] = _mm256_add_epi32(r[0], r[1]);
  *sad_acc = _mm256_add_epi32(*sad_acc, r[0]);
}

// If sec_ptr is NULL, compute the regular SAD; otherwise compute the SAD
// against the average of the reference and sec_ptr blocks.
static INLINE void sad16x4(const uint16_t *src_ptr, int src_stride,
                           const uint16_t *ref_ptr, int ref_stride,
                           const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[4], r[4];
  s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
  s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
  s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 2 * src_stride));
  s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 3 * src_stride));

  r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
  r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
  r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 2 * ref_stride));
  r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 3 * ref_stride));

  if (sec_ptr) {
    r[0] = _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
    r[1] = _mm256_avg_epu16(
        r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
    r[2] = _mm256_avg_epu16(
        r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
    r[3] = _mm256_avg_epu16(
        r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
  }
  highbd_sad16x4_core_avx2(s, r, sad_acc);
}

static AOM_FORCE_INLINE unsigned int aom_highbd_sad16xN_avx2(int N,
                                                             const uint8_t *src,
                                                             int src_stride,
                                                             const uint8_t *ref,
                                                             int ref_stride) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
  const uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);
  int i;
  __m256i sad = _mm256_setzero_si256();
  for (i = 0; i < N; i += 4) {
    sad16x4(src_ptr, src_stride, ref_ptr, ref_stride, NULL, &sad);
    src_ptr += src_stride << 2;
    ref_ptr += ref_stride << 2;
  }
  return (unsigned int)get_sad_from_mm256_epi32(&sad);
}

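// One 32-wide x 4-row tile, handled as two 32x2 sections; each section is
// four 16-lane loads fed to the shared 16x4 core.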
static void sad32x4(const uint16_t *src_ptr, int src_stride,
                    const uint16_t *ref_ptr, int ref_stride,
                    const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[4], r[4];
  int row_sections = 0;

  while (row_sections < 2) {
    s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
    s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
    s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
    s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride + 16));

    r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
    r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
    r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
    r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride + 16));

    if (sec_ptr) {
      r[0] =
          _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
      r[1] = _mm256_avg_epu16(
          r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
      r[2] = _mm256_avg_epu16(
          r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
      r[3] = _mm256_avg_epu16(
          r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
      sec_ptr += 32 << 1;
    }
    highbd_sad16x4_core_avx2(s, r, sad_acc);

    row_sections += 1;
    src_ptr += src_stride << 1;
    ref_ptr += ref_stride << 1;
  }
}

static AOM_FORCE_INLINE unsigned int aom_highbd_sad32xN_avx2(int N,
                                                             const uint8_t *src,
                                                             int src_stride,
                                                             const uint8_t *ref,
                                                             int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  const int left_shift = 2;
  int i;

  for (i = 0; i < N; i += 4) {
    sad32x4(srcp, src_stride, refp, ref_stride, NULL, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
  }
  return get_sad_from_mm256_epi32(&sad);
}

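// One 64-wide x 2-row tile: each iteration loads one row as four 16-lane
// vectors and feeds it to the shared 16x4 core; sec_ptr advances 64 samples
// per row when averaging.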
static void sad64x2(const uint16_t *src_ptr, int src_stride,
                    const uint16_t *ref_ptr, int ref_stride,
                    const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[4], r[4];
  int i;
  for (i = 0; i < 2; i++) {
    s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
    s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
    s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 32));
    s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 48));

    r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
    r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
    r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 32));
    r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 48));
    if (sec_ptr) {
      r[0] =
          _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
      r[1] = _mm256_avg_epu16(
          r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
      r[2] = _mm256_avg_epu16(
          r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
      r[3] = _mm256_avg_epu16(
          r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
      sec_ptr += 64;
    }
    highbd_sad16x4_core_avx2(s, r, sad_acc);
    src_ptr += src_stride;
    ref_ptr += ref_stride;
  }
}

static AOM_FORCE_INLINE unsigned int aom_highbd_sad64xN_avx2(int N,
                                                             const uint8_t *src,
                                                             int src_stride,
                                                             const uint8_t *ref,
                                                             int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  const int left_shift = 1;
  int i;
  for (i = 0; i < N; i += 2) {
    sad64x2(srcp, src_stride, refp, ref_stride, NULL, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
  }
  return get_sad_from_mm256_epi32(&sad);
}

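// One 128-wide row: consumed as two 64-pixel halves, each half being four
// 16-lane loads fed to the shared 16x4 core.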
static void sad128x1(const uint16_t *src_ptr, const uint16_t *ref_ptr,
                     const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[4], r[4];
  int i;
  for (i = 0; i < 2; i++) {
    s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
    s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
    s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 32));
    s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 48));
    r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
    r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
    r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 32));
    r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 48));
    if (sec_ptr) {
      r[0] =
          _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
      r[1] = _mm256_avg_epu16(
          r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
      r[2] = _mm256_avg_epu16(
          r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
      r[3] = _mm256_avg_epu16(
          r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
      sec_ptr += 64;
    }
    highbd_sad16x4_core_avx2(s, r, sad_acc);
    src_ptr += 64;
    ref_ptr += 64;
  }
}

static AOM_FORCE_INLINE unsigned int aom_highbd_sad128xN_avx2(
    int N, const uint8_t *src, int src_stride, const uint8_t *ref,
    int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  int row = 0;
  while (row < N) {
    sad128x1(srcp, refp, NULL, &sad);
    srcp += src_stride;
    refp += ref_stride;
    row++;
  }
  return get_sad_from_mm256_epi32(&sad);
}

#define HIGHBD_SADMXN_AVX2(m, n)                                            \
  unsigned int aom_highbd_sad##m##x##n##_avx2(                              \
      const uint8_t *src, int src_stride, const uint8_t *ref,               \
      int ref_stride) {                                                     \
    return aom_highbd_sad##m##xN_avx2(n, src, src_stride, ref, ref_stride); \
  }

#define HIGHBD_SAD_SKIP_MXN_AVX2(m, n)                                       \
  unsigned int aom_highbd_sad_skip_##m##x##n##_avx2(                         \
      const uint8_t *src, int src_stride, const uint8_t *ref,                \
      int ref_stride) {                                                      \
    return 2 * aom_highbd_sad##m##xN_avx2((n / 2), src, 2 * src_stride, ref, \
                                          2 * ref_stride);                   \
  }
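// For example, HIGHBD_SADMXN_AVX2(16, 8) expands to:
//   unsigned int aom_highbd_sad16x8_avx2(const uint8_t *src, int src_stride,
//                                        const uint8_t *ref, int ref_stride) {
//     return aom_highbd_sad16xN_avx2(8, src, src_stride, ref, ref_stride);
//   }
// The _skip_ variants measure every other row (n / 2 rows with doubled
// strides) and double the result to approximate the full-block SAD.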

HIGHBD_SADMXN_AVX2(16, 4)
HIGHBD_SADMXN_AVX2(16, 8)
HIGHBD_SADMXN_AVX2(16, 16)
HIGHBD_SADMXN_AVX2(16, 32)
HIGHBD_SADMXN_AVX2(16, 64)

HIGHBD_SADMXN_AVX2(32, 8)
HIGHBD_SADMXN_AVX2(32, 16)
HIGHBD_SADMXN_AVX2(32, 32)
HIGHBD_SADMXN_AVX2(32, 64)

HIGHBD_SADMXN_AVX2(64, 16)
HIGHBD_SADMXN_AVX2(64, 32)
HIGHBD_SADMXN_AVX2(64, 64)
HIGHBD_SADMXN_AVX2(64, 128)

HIGHBD_SADMXN_AVX2(128, 64)
HIGHBD_SADMXN_AVX2(128, 128)

HIGHBD_SAD_SKIP_MXN_AVX2(16, 8)
HIGHBD_SAD_SKIP_MXN_AVX2(16, 16)
HIGHBD_SAD_SKIP_MXN_AVX2(16, 32)
HIGHBD_SAD_SKIP_MXN_AVX2(16, 64)

HIGHBD_SAD_SKIP_MXN_AVX2(32, 8)
HIGHBD_SAD_SKIP_MXN_AVX2(32, 16)
HIGHBD_SAD_SKIP_MXN_AVX2(32, 32)
HIGHBD_SAD_SKIP_MXN_AVX2(32, 64)

HIGHBD_SAD_SKIP_MXN_AVX2(64, 16)
HIGHBD_SAD_SKIP_MXN_AVX2(64, 32)
HIGHBD_SAD_SKIP_MXN_AVX2(64, 64)
HIGHBD_SAD_SKIP_MXN_AVX2(64, 128)

HIGHBD_SAD_SKIP_MXN_AVX2(128, 64)
HIGHBD_SAD_SKIP_MXN_AVX2(128, 128)

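// Averaging (compound prediction) variants: the reference block is averaged
// with second_pred before the SAD is taken. second_pred is a contiguous
// width x height block, so it advances by the block width for each row.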
unsigned int aom_highbd_sad16x4_avg_avx2(const uint8_t *src, int src_stride,
                                         const uint8_t *ref, int ref_stride,
                                         const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  sad16x4(srcp, src_stride, refp, ref_stride, secp, &sad);

  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad16x8_avg_avx2(const uint8_t *src, int src_stride,
                                         const uint8_t *ref, int ref_stride,
                                         const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);

  sad16x4(srcp, src_stride, refp, ref_stride, secp, &sad);

  // Next 4 rows
  srcp += src_stride << 2;
  refp += ref_stride << 2;
  secp += 64;
  sad16x4(srcp, src_stride, refp, ref_stride, secp, &sad);
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad16x16_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 3;
  uint32_t sum = aom_highbd_sad16x8_avg_avx2(src, src_stride, ref, ref_stride,
                                             second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 16 << left_shift;
  sum += aom_highbd_sad16x8_avg_avx2(src, src_stride, ref, ref_stride,
                                     second_pred);
  return sum;
}

unsigned int aom_highbd_sad16x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 4;
  uint32_t sum = aom_highbd_sad16x16_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 16 << left_shift;
  sum += aom_highbd_sad16x16_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad16x64_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 5;
  uint32_t sum = aom_highbd_sad16x32_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 16 << left_shift;
  sum += aom_highbd_sad16x32_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad32x8_avg_avx2(const uint8_t *src, int src_stride,
                                         const uint8_t *ref, int ref_stride,
                                         const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 2;
  int row_section = 0;

  while (row_section < 2) {
    sad32x4(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 32 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad32x16_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 2;
  int row_section = 0;

  while (row_section < 4) {
    sad32x4(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 32 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad32x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 4;
  uint32_t sum = aom_highbd_sad32x16_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 32 << left_shift;
  sum += aom_highbd_sad32x16_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad32x64_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 5;
  uint32_t sum = aom_highbd_sad32x32_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 32 << left_shift;
  sum += aom_highbd_sad32x32_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad64x16_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 1;
  int row_section = 0;

  while (row_section < 8) {
    sad64x2(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 64 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad64x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 1;
  int row_section = 0;

  while (row_section < 16) {
    sad64x2(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 64 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad64x64_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 5;
  uint32_t sum = aom_highbd_sad64x32_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 64 << left_shift;
  sum += aom_highbd_sad64x32_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad64x128_avg_avx2(const uint8_t *src, int src_stride,
                                           const uint8_t *ref, int ref_stride,
                                           const uint8_t *second_pred) {
  const int left_shift = 6;
  uint32_t sum = aom_highbd_sad64x64_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 64 << left_shift;
  sum += aom_highbd_sad64x64_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad128x64_avg_avx2(const uint8_t *src, int src_stride,
                                           const uint8_t *ref, int ref_stride,
                                           const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  int row = 0;
  while (row < 64) {
    sad128x1(srcp, refp, secp, &sad);
    srcp += src_stride;
    refp += ref_stride;
    secp += 16 << 3;
    row += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad128x128_avg_avx2(const uint8_t *src, int src_stride,
                                            const uint8_t *ref, int ref_stride,
                                            const uint8_t *second_pred) {
  unsigned int sum;
  const int left_shift = 6;

  sum = aom_highbd_sad128x64_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 128 << left_shift;
  sum += aom_highbd_sad128x64_avg_avx2(src, src_stride, ref, ref_stride,
                                       second_pred);
  return sum;
}

// SAD 4D
// Combine the 4 __m256i input accumulators v[0..3] into uint32_t res[4].
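// Each v[i] holds eight 32-bit partial sums for reference i. The lanes are
// pairwise folded, masked to the even positions, interleaved across the four
// accumulators and finally reduced so that res[0..3] receive the four totals
// in reference order.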
static INLINE void get_4d_sad_from_mm256_epi32(const __m256i *v,
                                               uint32_t *res) {
  __m256i u0, u1, u2, u3;
  const __m256i mask = yy_set1_64_from_32i(~0);
  __m128i sad;

  // Fold adjacent 32-bit lanes (8 partial sums per accumulator).
  u0 = _mm256_srli_si256(v[0], 4);
  u1 = _mm256_srli_si256(v[1], 4);
  u2 = _mm256_srli_si256(v[2], 4);
  u3 = _mm256_srli_si256(v[3], 4);

  u0 = _mm256_add_epi32(u0, v[0]);
  u1 = _mm256_add_epi32(u1, v[1]);
  u2 = _mm256_add_epi32(u2, v[2]);
  u3 = _mm256_add_epi32(u3, v[3]);

  u0 = _mm256_and_si256(u0, mask);
  u1 = _mm256_and_si256(u1, mask);
  u2 = _mm256_and_si256(u2, mask);
  u3 = _mm256_and_si256(u3, mask);
  // 4 32-bit sums per accumulator, in the even lanes.

  u1 = _mm256_slli_si256(u1, 4);
  u3 = _mm256_slli_si256(u3, 4);

  u0 = _mm256_or_si256(u0, u1);
  u2 = _mm256_or_si256(u2, u3);
  // 8 32-bit sums, interleaved across accumulators.

  u1 = _mm256_unpacklo_epi64(u0, u2);
  u3 = _mm256_unpackhi_epi64(u0, u2);

  u0 = _mm256_add_epi32(u1, u3);
  sad = _mm_add_epi32(_mm256_extractf128_si256(u0, 1),
                      _mm256_castsi256_si128(u0));
  _mm_storeu_si128((__m128i *)res, sad);
}

static void convert_pointers(const uint8_t *const ref8[],
                             const uint16_t *ref[]) {
  ref[0] = CONVERT_TO_SHORTPTR(ref8[0]);
  ref[1] = CONVERT_TO_SHORTPTR(ref8[1]);
  ref[2] = CONVERT_TO_SHORTPTR(ref8[2]);
  ref[3] = CONVERT_TO_SHORTPTR(ref8[3]);
}

static void init_sad(__m256i *s) {
  s[0] = _mm256_setzero_si256();
  s[1] = _mm256_setzero_si256();
  s[2] = _mm256_setzero_si256();
  s[3] = _mm256_setzero_si256();
}

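// Generic M x N x D driver: the same source block is matched against D
// candidate references. shift_for_rows encodes how many rows each low-level
// helper consumes per call: 1 for M == 128, 2 for M == 64, 4 for M == 32/16.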
static AOM_FORCE_INLINE void aom_highbd_sadMxNxD_avx2(
    int M, int N, int D, const uint8_t *src, int src_stride,
    const uint8_t *const ref_array[4], int ref_stride, uint32_t sad_array[4]) {
  __m256i sad_vec[4];
  const uint16_t *refp[4];
  const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
  const uint16_t *srcp;
  const int shift_for_rows = (M < 128) + (M < 64);
  const int row_units = 1 << shift_for_rows;
  int i, r;

  init_sad(sad_vec);
  convert_pointers(ref_array, refp);

  for (i = 0; i < D; ++i) {
    srcp = keep;
    for (r = 0; r < N; r += row_units) {
      if (M == 128) {
        sad128x1(srcp, refp[i], NULL, &sad_vec[i]);
      } else if (M == 64) {
        sad64x2(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
      } else if (M == 32) {
        sad32x4(srcp, src_stride, refp[i], ref_stride, 0, &sad_vec[i]);
      } else if (M == 16) {
        sad16x4(srcp, src_stride, refp[i], ref_stride, 0, &sad_vec[i]);
      } else {
        assert(0);
      }
      srcp += src_stride << shift_for_rows;
      refp[i] += ref_stride << shift_for_rows;
    }
  }
  get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
}

#define HIGHBD_SAD_MXNX4D_AVX2(m, n)                                          \
  void aom_highbd_sad##m##x##n##x4d_avx2(                                     \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4],  \
      int ref_stride, uint32_t sad_array[4]) {                                \
    aom_highbd_sadMxNxD_avx2(m, n, 4, src, src_stride, ref_array, ref_stride, \
                             sad_array);                                      \
  }
#define HIGHBD_SAD_SKIP_MXNX4D_AVX2(m, n)                                    \
  void aom_highbd_sad_skip_##m##x##n##x4d_avx2(                              \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4], \
      int ref_stride, uint32_t sad_array[4]) {                               \
    aom_highbd_sadMxNxD_avx2(m, (n / 2), 4, src, 2 * src_stride, ref_array,  \
                             2 * ref_stride, sad_array);                     \
    sad_array[0] <<= 1;                                                      \
    sad_array[1] <<= 1;                                                      \
    sad_array[2] <<= 1;                                                      \
    sad_array[3] <<= 1;                                                      \
  }
#define HIGHBD_SAD_MXNX3D_AVX2(m, n)                                          \
  void aom_highbd_sad##m##x##n##x3d_avx2(                                     \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4],  \
      int ref_stride, uint32_t sad_array[4]) {                                \
    aom_highbd_sadMxNxD_avx2(m, n, 3, src, src_stride, ref_array, ref_stride, \
                             sad_array);                                      \
  }
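// The x4d kernels return the SAD against all four candidate references in a
// single pass; the x3d kernels only evaluate the first three. The skip x4d
// form, like the 2D skip kernels above, samples every other row and doubles
// each of the four results.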

HIGHBD_SAD_MXNX4D_AVX2(16, 4)
HIGHBD_SAD_MXNX4D_AVX2(16, 8)
HIGHBD_SAD_MXNX4D_AVX2(16, 16)
HIGHBD_SAD_MXNX4D_AVX2(16, 32)
HIGHBD_SAD_MXNX4D_AVX2(16, 64)

HIGHBD_SAD_MXNX4D_AVX2(32, 8)
HIGHBD_SAD_MXNX4D_AVX2(32, 16)
HIGHBD_SAD_MXNX4D_AVX2(32, 32)
HIGHBD_SAD_MXNX4D_AVX2(32, 64)

HIGHBD_SAD_MXNX4D_AVX2(64, 16)
HIGHBD_SAD_MXNX4D_AVX2(64, 32)
HIGHBD_SAD_MXNX4D_AVX2(64, 64)
HIGHBD_SAD_MXNX4D_AVX2(64, 128)

HIGHBD_SAD_MXNX4D_AVX2(128, 64)
HIGHBD_SAD_MXNX4D_AVX2(128, 128)

HIGHBD_SAD_SKIP_MXNX4D_AVX2(16, 8)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(16, 16)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(16, 32)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(16, 64)

HIGHBD_SAD_SKIP_MXNX4D_AVX2(32, 8)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(32, 16)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(32, 32)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(32, 64)

HIGHBD_SAD_SKIP_MXNX4D_AVX2(64, 16)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(64, 32)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(64, 64)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(64, 128)

HIGHBD_SAD_SKIP_MXNX4D_AVX2(128, 64)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(128, 128)

HIGHBD_SAD_MXNX3D_AVX2(16, 4)
HIGHBD_SAD_MXNX3D_AVX2(16, 8)
HIGHBD_SAD_MXNX3D_AVX2(16, 16)
HIGHBD_SAD_MXNX3D_AVX2(16, 32)
HIGHBD_SAD_MXNX3D_AVX2(16, 64)

HIGHBD_SAD_MXNX3D_AVX2(32, 8)
HIGHBD_SAD_MXNX3D_AVX2(32, 16)
HIGHBD_SAD_MXNX3D_AVX2(32, 32)
HIGHBD_SAD_MXNX3D_AVX2(32, 64)

HIGHBD_SAD_MXNX3D_AVX2(64, 16)
HIGHBD_SAD_MXNX3D_AVX2(64, 32)
HIGHBD_SAD_MXNX3D_AVX2(64, 64)
HIGHBD_SAD_MXNX3D_AVX2(64, 128)

HIGHBD_SAD_MXNX3D_AVX2(128, 64)
HIGHBD_SAD_MXNX3D_AVX2(128, 128)