/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <immintrin.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms_avx2.h"
#include "aom_ports/mem.h"

// SAD
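// A scalar sketch of the high-bit-depth SAD computed throughout this file,
// shown for reference only:
//   sad = 0;
//   for (row = 0; row < n; ++row)
//     for (col = 0; col < m; ++col)
//       sad += abs(src[row * src_stride + col] - ref[row * ref_stride + col]);

// Horizontally reduces the eight 32-bit partial sums in *v to a single
// scalar total.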
static inline unsigned int get_sad_from_mm256_epi32(const __m256i *v) {
  // Input: eight 32-bit partial sums.
  __m128i lo128, hi128;
  __m256i u = _mm256_srli_si256(*v, 8);
  u = _mm256_add_epi32(u, *v);

  // Down to four 32-bit partial sums.
  hi128 = _mm256_extracti128_si256(u, 1);
  lo128 = _mm256_castsi256_si128(u);
  lo128 = _mm_add_epi32(hi128, lo128);

  // Down to two 32-bit partial sums.
  hi128 = _mm_srli_si128(lo128, 4);
  lo128 = _mm_add_epi32(lo128, hi128);

  return (unsigned int)_mm_cvtsi128_si32(lo128);
}

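// Accumulates sum(|s[i] - r[i]|) over four vectors of 16-bit samples into the
// eight 32-bit partial sums of *sad_acc. Both s[] and r[] are clobbered.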
static inline void highbd_sad16x4_core_avx2(__m256i *s, __m256i *r,
                                            __m256i *sad_acc) {
  const __m256i zero = _mm256_setzero_si256();
  int i;
  for (i = 0; i < 4; i++) {
    s[i] = _mm256_sub_epi16(s[i], r[i]);
    s[i] = _mm256_abs_epi16(s[i]);
  }

  s[0] = _mm256_add_epi16(s[0], s[1]);
  s[0] = _mm256_add_epi16(s[0], s[2]);
  s[0] = _mm256_add_epi16(s[0], s[3]);

  r[0] = _mm256_unpacklo_epi16(s[0], zero);
  r[1] = _mm256_unpackhi_epi16(s[0], zero);

  r[0] = _mm256_add_epi32(r[0], r[1]);
  *sad_acc = _mm256_add_epi32(*sad_acc, r[0]);
}

// If sec_ptr is NULL, compute the regular SAD. Otherwise, compute the SAD of
// src against the average of ref and sec (compound prediction).
static inline void sad16x4(const uint16_t *src_ptr, int src_stride,
                           const uint16_t *ref_ptr, int ref_stride,
                           const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[4], r[4];
  s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
  s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
  s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 2 * src_stride));
  s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 3 * src_stride));

  r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
  r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
  r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 2 * ref_stride));
  r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 3 * ref_stride));

  if (sec_ptr) {
    r[0] = _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
    r[1] = _mm256_avg_epu16(
        r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
    r[2] = _mm256_avg_epu16(
        r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
    r[3] = _mm256_avg_epu16(
        r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
  }
  highbd_sad16x4_core_avx2(s, r, sad_acc);
}

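// Computes the SAD of a 16xN high-bit-depth block; N must be a multiple of 4.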
static AOM_FORCE_INLINE unsigned int aom_highbd_sad16xN_avx2(int N,
                                                             const uint8_t *src,
                                                             int src_stride,
                                                             const uint8_t *ref,
                                                             int ref_stride) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
  const uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);
  int i;
  __m256i sad = _mm256_setzero_si256();
  for (i = 0; i < N; i += 4) {
    sad16x4(src_ptr, src_stride, ref_ptr, ref_stride, NULL, &sad);
    src_ptr += src_stride << 2;
    ref_ptr += ref_stride << 2;
  }
  return (unsigned int)get_sad_from_mm256_epi32(&sad);
}

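// Accumulates the SAD of a 32x4 region, processed as two sections of two
// 32-sample rows each.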
static void sad32x4(const uint16_t *src_ptr, int src_stride,
                    const uint16_t *ref_ptr, int ref_stride,
                    const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[4], r[4];
  int row_sections = 0;

  while (row_sections < 2) {
    s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
    s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
    s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
    s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride + 16));

    r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
    r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
    r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
    r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride + 16));

    if (sec_ptr) {
      r[0] =
          _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
      r[1] = _mm256_avg_epu16(
          r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
      r[2] = _mm256_avg_epu16(
          r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
      r[3] = _mm256_avg_epu16(
          r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
      sec_ptr += 32 << 1;  // Advance two rows of the 32-wide second predictor.
    }
    highbd_sad16x4_core_avx2(s, r, sad_acc);

    row_sections += 1;
    src_ptr += src_stride << 1;
    ref_ptr += ref_stride << 1;
  }
}

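// Computes the SAD of a 32xN high-bit-depth block; N must be a multiple of 4.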
static AOM_FORCE_INLINE unsigned int aom_highbd_sad32xN_avx2(int N,
                                                             const uint8_t *src,
                                                             int src_stride,
                                                             const uint8_t *ref,
                                                             int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  const int left_shift = 2;
  int i;

  for (i = 0; i < N; i += 4) {
    sad32x4(srcp, src_stride, refp, ref_stride, NULL, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
  }
  return get_sad_from_mm256_epi32(&sad);
}

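// Accumulates the SAD of a 64x2 region, one 64-sample row per iteration.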
static void sad64x2(const uint16_t *src_ptr, int src_stride,
                    const uint16_t *ref_ptr, int ref_stride,
                    const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[4], r[4];
  int i;
  for (i = 0; i < 2; i++) {
    s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
    s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
    s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 32));
    s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 48));

    r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
    r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
    r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 32));
    r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 48));
    if (sec_ptr) {
      r[0] =
          _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
      r[1] = _mm256_avg_epu16(
          r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
      r[2] = _mm256_avg_epu16(
          r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
      r[3] = _mm256_avg_epu16(
          r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
      sec_ptr += 64;
    }
    highbd_sad16x4_core_avx2(s, r, sad_acc);
    src_ptr += src_stride;
    ref_ptr += ref_stride;
  }
}

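// Computes the SAD of a 64xN high-bit-depth block; N must be a multiple of 2.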
static AOM_FORCE_INLINE unsigned int aom_highbd_sad64xN_avx2(int N,
                                                             const uint8_t *src,
                                                             int src_stride,
                                                             const uint8_t *ref,
                                                             int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  const int left_shift = 1;
  int i;
  for (i = 0; i < N; i += 2) {
    sad64x2(srcp, src_stride, refp, ref_stride, NULL, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
  }
  return get_sad_from_mm256_epi32(&sad);
}

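// Accumulates the SAD of a single 128-sample row, processed as two 64-sample
// halves.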
static void sad128x1(const uint16_t *src_ptr, const uint16_t *ref_ptr,
                     const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[4], r[4];
  int i;
  for (i = 0; i < 2; i++) {
    s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
    s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
    s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 32));
    s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 48));
    r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
    r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
    r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 32));
    r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 48));
    if (sec_ptr) {
      r[0] =
          _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
      r[1] = _mm256_avg_epu16(
          r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
      r[2] = _mm256_avg_epu16(
          r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
      r[3] = _mm256_avg_epu16(
          r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
      sec_ptr += 64;
    }
    highbd_sad16x4_core_avx2(s, r, sad_acc);
    src_ptr += 64;
    ref_ptr += 64;
  }
}

static AOM_FORCE_INLINE unsigned int aom_highbd_sad128xN_avx2(
    int N, const uint8_t *src, int src_stride, const uint8_t *ref,
    int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  int row = 0;
  while (row < N) {
    sad128x1(srcp, refp, NULL, &sad);
    srcp += src_stride;
    refp += ref_stride;
    row++;
  }
  return get_sad_from_mm256_epi32(&sad);
}

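// HIGHBD_SADMXN_AVX2(m, n) defines the public m x n SAD entry point. The skip
// variant trades accuracy for speed: it computes the SAD over the even rows
// only (half the rows at twice the stride) and doubles the result. As a
// sketch, HIGHBD_SADMXN_AVX2(16, 8) expands to roughly:
//   unsigned int aom_highbd_sad16x8_avx2(const uint8_t *src, int src_stride,
//                                        const uint8_t *ref, int ref_stride) {
//     return aom_highbd_sad16xN_avx2(8, src, src_stride, ref, ref_stride);
//   }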
#define HIGHBD_SADMXN_AVX2(m, n)                                            \
  unsigned int aom_highbd_sad##m##x##n##_avx2(                              \
      const uint8_t *src, int src_stride, const uint8_t *ref,               \
      int ref_stride) {                                                     \
    return aom_highbd_sad##m##xN_avx2(n, src, src_stride, ref, ref_stride); \
  }

#define HIGHBD_SAD_SKIP_MXN_AVX2(m, n)                                       \
  unsigned int aom_highbd_sad_skip_##m##x##n##_avx2(                         \
      const uint8_t *src, int src_stride, const uint8_t *ref,                \
      int ref_stride) {                                                      \
    return 2 * aom_highbd_sad##m##xN_avx2((n / 2), src, 2 * src_stride, ref, \
                                          2 * ref_stride);                   \
  }

HIGHBD_SADMXN_AVX2(16, 8)
HIGHBD_SADMXN_AVX2(16, 16)
HIGHBD_SADMXN_AVX2(16, 32)

HIGHBD_SADMXN_AVX2(32, 16)
HIGHBD_SADMXN_AVX2(32, 32)
HIGHBD_SADMXN_AVX2(32, 64)

HIGHBD_SADMXN_AVX2(64, 32)
HIGHBD_SADMXN_AVX2(64, 64)
HIGHBD_SADMXN_AVX2(64, 128)

HIGHBD_SADMXN_AVX2(128, 64)
HIGHBD_SADMXN_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_SADMXN_AVX2(16, 4)
HIGHBD_SADMXN_AVX2(16, 64)
HIGHBD_SADMXN_AVX2(32, 8)
HIGHBD_SADMXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

HIGHBD_SAD_SKIP_MXN_AVX2(16, 16)
HIGHBD_SAD_SKIP_MXN_AVX2(16, 32)

HIGHBD_SAD_SKIP_MXN_AVX2(32, 16)
HIGHBD_SAD_SKIP_MXN_AVX2(32, 32)
HIGHBD_SAD_SKIP_MXN_AVX2(32, 64)

HIGHBD_SAD_SKIP_MXN_AVX2(64, 32)
HIGHBD_SAD_SKIP_MXN_AVX2(64, 64)
HIGHBD_SAD_SKIP_MXN_AVX2(64, 128)

HIGHBD_SAD_SKIP_MXN_AVX2(128, 64)
HIGHBD_SAD_SKIP_MXN_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_SAD_SKIP_MXN_AVX2(16, 64)
HIGHBD_SAD_SKIP_MXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

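// The _avg_ variants compute the SAD of src against the average of ref and
// second_pred. second_pred is a contiguous m x n block, so its row stride
// equals the block width.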
unsigned int aom_highbd_sad16x8_avg_avx2(const uint8_t *src, int src_stride,
                                         const uint8_t *ref, int ref_stride,
                                         const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);

  sad16x4(srcp, src_stride, refp, ref_stride, secp, &sad);

  // Next 4 rows
  srcp += src_stride << 2;
  refp += ref_stride << 2;
  secp += 64;
  sad16x4(srcp, src_stride, refp, ref_stride, secp, &sad);
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad16x16_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 3;
  uint32_t sum = aom_highbd_sad16x8_avg_avx2(src, src_stride, ref, ref_stride,
                                             second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 16 << left_shift;
  sum += aom_highbd_sad16x8_avg_avx2(src, src_stride, ref, ref_stride,
                                     second_pred);
  return sum;
}

unsigned int aom_highbd_sad16x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 4;
  uint32_t sum = aom_highbd_sad16x16_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 16 << left_shift;
  sum += aom_highbd_sad16x16_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

#if !CONFIG_REALTIME_ONLY
unsigned int aom_highbd_sad16x64_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 5;
  uint32_t sum = aom_highbd_sad16x32_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 16 << left_shift;
  sum += aom_highbd_sad16x32_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad32x8_avg_avx2(const uint8_t *src, int src_stride,
                                         const uint8_t *ref, int ref_stride,
                                         const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 2;
  int row_section = 0;

  while (row_section < 2) {
    sad32x4(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 32 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}
#endif  // !CONFIG_REALTIME_ONLY

unsigned int aom_highbd_sad32x16_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 2;
  int row_section = 0;

  while (row_section < 4) {
    sad32x4(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 32 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad32x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 4;
  uint32_t sum = aom_highbd_sad32x16_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 32 << left_shift;
  sum += aom_highbd_sad32x16_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad32x64_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 5;
  uint32_t sum = aom_highbd_sad32x32_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 32 << left_shift;
  sum += aom_highbd_sad32x32_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

#if !CONFIG_REALTIME_ONLY
unsigned int aom_highbd_sad64x16_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 1;
  int row_section = 0;

  while (row_section < 8) {
    sad64x2(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 64 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}
#endif  // !CONFIG_REALTIME_ONLY

unsigned int aom_highbd_sad64x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 1;
  int row_section = 0;

  while (row_section < 16) {
    sad64x2(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 64 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad64x64_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 5;
  uint32_t sum = aom_highbd_sad64x32_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 64 << left_shift;
  sum += aom_highbd_sad64x32_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad64x128_avg_avx2(const uint8_t *src, int src_stride,
                                           const uint8_t *ref, int ref_stride,
                                           const uint8_t *second_pred) {
  const int left_shift = 6;
  uint32_t sum = aom_highbd_sad64x64_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 64 << left_shift;
  sum += aom_highbd_sad64x64_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad128x64_avg_avx2(const uint8_t *src, int src_stride,
                                           const uint8_t *ref, int ref_stride,
                                           const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  int row = 0;
  while (row < 64) {
    sad128x1(srcp, refp, secp, &sad);
    srcp += src_stride;
    refp += ref_stride;
    secp += 16 << 3;  // One 128-sample row of the second predictor.
    row += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad128x128_avg_avx2(const uint8_t *src, int src_stride,
                                            const uint8_t *ref, int ref_stride,
                                            const uint8_t *second_pred) {
  unsigned int sum;
  const int left_shift = 6;

  sum = aom_highbd_sad128x64_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 128 << left_shift;
  sum += aom_highbd_sad128x64_avg_avx2(src, src_stride, ref, ref_stride,
                                       second_pred);
  return sum;
}

// SAD 4D
// Horizontally reduce each of the four __m256i accumulators v[0..3] to a
// 32-bit total and store the four results in res[0..3].
static inline void get_4d_sad_from_mm256_epi32(const __m256i *v,
                                               uint32_t *res) {
  __m256i u0, u1, u2, u3;
  // Mask that keeps the low 32 bits of each 64-bit lane.
  const __m256i mask = _mm256_set1_epi64x(~0u);
  __m128i sad;

  // 8 32-bit summation
  u0 = _mm256_srli_si256(v[0], 4);
  u1 = _mm256_srli_si256(v[1], 4);
  u2 = _mm256_srli_si256(v[2], 4);
  u3 = _mm256_srli_si256(v[3], 4);

  u0 = _mm256_add_epi32(u0, v[0]);
  u1 = _mm256_add_epi32(u1, v[1]);
  u2 = _mm256_add_epi32(u2, v[2]);
  u3 = _mm256_add_epi32(u3, v[3]);

  u0 = _mm256_and_si256(u0, mask);
  u1 = _mm256_and_si256(u1, mask);
  u2 = _mm256_and_si256(u2, mask);
  u3 = _mm256_and_si256(u3, mask);
  // 4 32-bit summation, evenly positioned

  u1 = _mm256_slli_si256(u1, 4);
  u3 = _mm256_slli_si256(u3, 4);

  u0 = _mm256_or_si256(u0, u1);
  u2 = _mm256_or_si256(u2, u3);
  // 8 32-bit summation, interleaved

  u1 = _mm256_unpacklo_epi64(u0, u2);
  u3 = _mm256_unpackhi_epi64(u0, u2);

  u0 = _mm256_add_epi32(u1, u3);
  sad = _mm_add_epi32(_mm256_extractf128_si256(u0, 1),
                      _mm256_castsi256_si128(u0));
  _mm_storeu_si128((__m128i *)res, sad);
}

static void convert_pointers(const uint8_t *const ref8[],
                             const uint16_t *ref[]) {
  ref[0] = CONVERT_TO_SHORTPTR(ref8[0]);
  ref[1] = CONVERT_TO_SHORTPTR(ref8[1]);
  ref[2] = CONVERT_TO_SHORTPTR(ref8[2]);
  ref[3] = CONVERT_TO_SHORTPTR(ref8[3]);
}

static void init_sad(__m256i *s) {
  s[0] = _mm256_setzero_si256();
  s[1] = _mm256_setzero_si256();
  s[2] = _mm256_setzero_si256();
  s[3] = _mm256_setzero_si256();
}

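// Computes D SADs (D = 3 or 4) of an MxN block against the first D reference
// candidates in ref_array. shift_for_rows selects how many rows the inner
// helper consumes per call: 1 for M == 128, 2 for M == 64, and 4 for M == 32
// or M == 16.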
static AOM_FORCE_INLINE void aom_highbd_sadMxNxD_avx2(
    int M, int N, int D, const uint8_t *src, int src_stride,
    const uint8_t *const ref_array[4], int ref_stride, uint32_t sad_array[4]) {
  __m256i sad_vec[4];
  const uint16_t *refp[4];
  const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
  const uint16_t *srcp;
  const int shift_for_rows = (M < 128) + (M < 64);
  const int row_units = 1 << shift_for_rows;
  int i, r;

  init_sad(sad_vec);
  convert_pointers(ref_array, refp);

  for (i = 0; i < D; ++i) {
    srcp = keep;
    for (r = 0; r < N; r += row_units) {
      if (M == 128) {
        sad128x1(srcp, refp[i], NULL, &sad_vec[i]);
      } else if (M == 64) {
        sad64x2(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
      } else if (M == 32) {
        sad32x4(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
      } else if (M == 16) {
        sad16x4(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
      } else {
        assert(0);
      }
      srcp += src_stride << shift_for_rows;
      refp[i] += ref_stride << shift_for_rows;
    }
  }
  get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
}

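// HIGHBD_SAD_MXNX4D_AVX2(m, n) defines the m x n x4d entry point, which
// evaluates all four reference candidates. The skip variant computes each SAD
// over the even rows only and doubles the results; the x3d variant evaluates
// only the first three candidates.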
#define HIGHBD_SAD_MXNX4D_AVX2(m, n)                                          \
  void aom_highbd_sad##m##x##n##x4d_avx2(                                     \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4],  \
      int ref_stride, uint32_t sad_array[4]) {                                \
    aom_highbd_sadMxNxD_avx2(m, n, 4, src, src_stride, ref_array, ref_stride, \
                             sad_array);                                      \
  }
#define HIGHBD_SAD_SKIP_MXNX4D_AVX2(m, n)                                    \
  void aom_highbd_sad_skip_##m##x##n##x4d_avx2(                              \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4], \
      int ref_stride, uint32_t sad_array[4]) {                               \
    aom_highbd_sadMxNxD_avx2(m, (n / 2), 4, src, 2 * src_stride, ref_array,  \
                             2 * ref_stride, sad_array);                     \
    sad_array[0] <<= 1;                                                      \
    sad_array[1] <<= 1;                                                      \
    sad_array[2] <<= 1;                                                      \
    sad_array[3] <<= 1;                                                      \
  }
#define HIGHBD_SAD_MXNX3D_AVX2(m, n)                                          \
  void aom_highbd_sad##m##x##n##x3d_avx2(                                     \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4],  \
      int ref_stride, uint32_t sad_array[4]) {                                \
    aom_highbd_sadMxNxD_avx2(m, n, 3, src, src_stride, ref_array, ref_stride, \
                             sad_array);                                      \
  }

HIGHBD_SAD_MXNX4D_AVX2(16, 8)
HIGHBD_SAD_MXNX4D_AVX2(16, 16)
HIGHBD_SAD_MXNX4D_AVX2(16, 32)

HIGHBD_SAD_MXNX4D_AVX2(32, 16)
HIGHBD_SAD_MXNX4D_AVX2(32, 32)
HIGHBD_SAD_MXNX4D_AVX2(32, 64)

HIGHBD_SAD_MXNX4D_AVX2(64, 32)
HIGHBD_SAD_MXNX4D_AVX2(64, 64)
HIGHBD_SAD_MXNX4D_AVX2(64, 128)

HIGHBD_SAD_MXNX4D_AVX2(128, 64)
HIGHBD_SAD_MXNX4D_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_SAD_MXNX4D_AVX2(16, 4)
HIGHBD_SAD_MXNX4D_AVX2(16, 64)
HIGHBD_SAD_MXNX4D_AVX2(32, 8)
HIGHBD_SAD_MXNX4D_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

HIGHBD_SAD_SKIP_MXNX4D_AVX2(16, 16)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(16, 32)

HIGHBD_SAD_SKIP_MXNX4D_AVX2(32, 16)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(32, 32)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(32, 64)

HIGHBD_SAD_SKIP_MXNX4D_AVX2(64, 32)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(64, 64)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(64, 128)

HIGHBD_SAD_SKIP_MXNX4D_AVX2(128, 64)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_SAD_SKIP_MXNX4D_AVX2(16, 64)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

HIGHBD_SAD_MXNX3D_AVX2(16, 8)
HIGHBD_SAD_MXNX3D_AVX2(16, 16)
HIGHBD_SAD_MXNX3D_AVX2(16, 32)

HIGHBD_SAD_MXNX3D_AVX2(32, 16)
HIGHBD_SAD_MXNX3D_AVX2(32, 32)
HIGHBD_SAD_MXNX3D_AVX2(32, 64)

HIGHBD_SAD_MXNX3D_AVX2(64, 32)
HIGHBD_SAD_MXNX3D_AVX2(64, 64)
HIGHBD_SAD_MXNX3D_AVX2(64, 128)

HIGHBD_SAD_MXNX3D_AVX2(128, 64)
HIGHBD_SAD_MXNX3D_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_SAD_MXNX3D_AVX2(16, 4)
HIGHBD_SAD_MXNX3D_AVX2(16, 64)
HIGHBD_SAD_MXNX3D_AVX2(32, 8)
HIGHBD_SAD_MXNX3D_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY