/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <immintrin.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/blend.h"
#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms.h"
#include "aom_dsp/x86/masked_sad_intrin_ssse3.h"

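// For reference, a scalar sketch of the masked SAD these kernels vectorize
// (assuming 6-bit mask weights, i.e. AOM_BLEND_A64_ROUND_BITS == 6 and a
// maximum mask value of 64):
//
//   sad = 0;
//   for (y = 0; y < height; y++) {
//     for (x = 0; x < width; x++) {
//       const int pred = (a[x] * m[x] + b[x] * (64 - m[x]) + 32) >> 6;
//       sad += abs(src[x] - pred);
//     }
//     src += src_stride; a += a_stride; b += b_stride; m += m_stride;
//   }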
static INLINE unsigned int masked_sad32xh_avx2(
    const uint8_t *src_ptr, int src_stride, const uint8_t *a_ptr, int a_stride,
    const uint8_t *b_ptr, int b_stride, const uint8_t *m_ptr, int m_stride,
    int width, int height) {
  int x, y;
  __m256i res = _mm256_setzero_si256();
  const __m256i mask_max = _mm256_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m256i round_scale =
      _mm256_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
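  // Each predicted pixel is (a * m + b * (64 - m) + 32) >> 6, i.e. a rounded
  // shift by AOM_BLEND_A64_ROUND_BITS: maddubs_epi16 forms the weighted sum
  // and mulhrs_epi16 with 'round_scale' implements the rounded shift.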
  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 32) {
      const __m256i src = _mm256_lddqu_si256((const __m256i *)&src_ptr[x]);
      const __m256i a = _mm256_lddqu_si256((const __m256i *)&a_ptr[x]);
      const __m256i b = _mm256_lddqu_si256((const __m256i *)&b_ptr[x]);
      const __m256i m = _mm256_lddqu_si256((const __m256i *)&m_ptr[x]);
      const __m256i m_inv = _mm256_sub_epi8(mask_max, m);

      // Calculate 32 predicted pixels.
      // Note that the maximum value of any entry of 'pred_l' or 'pred_r'
      // is 64 * 255, so we have plenty of space to add rounding constants.
      const __m256i data_l = _mm256_unpacklo_epi8(a, b);
      const __m256i mask_l = _mm256_unpacklo_epi8(m, m_inv);
      __m256i pred_l = _mm256_maddubs_epi16(data_l, mask_l);
      pred_l = _mm256_mulhrs_epi16(pred_l, round_scale);

      const __m256i data_r = _mm256_unpackhi_epi8(a, b);
      const __m256i mask_r = _mm256_unpackhi_epi8(m, m_inv);
      __m256i pred_r = _mm256_maddubs_epi16(data_r, mask_r);
      pred_r = _mm256_mulhrs_epi16(pred_r, round_scale);

      const __m256i pred = _mm256_packus_epi16(pred_l, pred_r);
      res = _mm256_add_epi32(res, _mm256_sad_epu8(pred, src));
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // At this point, 'res' holds four 32-bit partial SADs, one in the low half
  // of each 64-bit element. Gather them into the low 128 bits and reduce
  // them with two horizontal adds.
  res = _mm256_shuffle_epi32(res, 0xd8);
  res = _mm256_permute4x64_epi64(res, 0xd8);
  res = _mm256_hadd_epi32(res, res);
  res = _mm256_hadd_epi32(res, res);
  int32_t sad = _mm256_extract_epi32(res, 0);
  return sad;
}

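// Load two unaligned 128-bit values and concatenate them into one 256-bit
// register, with 'lo' in the lower lane and 'hi' in the upper lane.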
static INLINE __m256i xx_loadu2_m128i(const void *hi, const void *lo) {
  __m128i a0 = _mm_lddqu_si128((const __m128i *)(lo));
  __m128i a1 = _mm_lddqu_si128((const __m128i *)(hi));
  __m256i a = _mm256_castsi128_si256(a0);
  return _mm256_inserti128_si256(a, a1, 1);
}

static INLINE unsigned int masked_sad16xh_avx2(
    const uint8_t *src_ptr, int src_stride, const uint8_t *a_ptr, int a_stride,
    const uint8_t *b_ptr, int b_stride, const uint8_t *m_ptr, int m_stride,
    int height) {
  int y;
  __m256i res = _mm256_setzero_si256();
  const __m256i mask_max = _mm256_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m256i round_scale =
      _mm256_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
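  // Process two rows per iteration: each 256-bit load carries one 16-pixel
  // row in each 128-bit lane.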
  for (y = 0; y < height; y += 2) {
    const __m256i src = xx_loadu2_m128i(src_ptr + src_stride, src_ptr);
    const __m256i a = xx_loadu2_m128i(a_ptr + a_stride, a_ptr);
    const __m256i b = xx_loadu2_m128i(b_ptr + b_stride, b_ptr);
    const __m256i m = xx_loadu2_m128i(m_ptr + m_stride, m_ptr);
    const __m256i m_inv = _mm256_sub_epi8(mask_max, m);

    // Calculate 32 predicted pixels (two rows of 16).
    // Note that the maximum value of any entry of 'pred_l' or 'pred_r'
    // is 64 * 255, so we have plenty of space to add rounding constants.
    const __m256i data_l = _mm256_unpacklo_epi8(a, b);
    const __m256i mask_l = _mm256_unpacklo_epi8(m, m_inv);
    __m256i pred_l = _mm256_maddubs_epi16(data_l, mask_l);
    pred_l = _mm256_mulhrs_epi16(pred_l, round_scale);

    const __m256i data_r = _mm256_unpackhi_epi8(a, b);
    const __m256i mask_r = _mm256_unpackhi_epi8(m, m_inv);
    __m256i pred_r = _mm256_maddubs_epi16(data_r, mask_r);
    pred_r = _mm256_mulhrs_epi16(pred_r, round_scale);

    const __m256i pred = _mm256_packus_epi16(pred_l, pred_r);
    res = _mm256_add_epi32(res, _mm256_sad_epu8(pred, src));

    src_ptr += src_stride << 1;
    a_ptr += a_stride << 1;
    b_ptr += b_stride << 1;
    m_ptr += m_stride << 1;
  }
  // At this point, 'res' holds four 32-bit partial SADs, one in the low half
  // of each 64-bit element. Gather them into the low 128 bits and reduce
  // them with two horizontal adds.
  res = _mm256_shuffle_epi32(res, 0xd8);
  res = _mm256_permute4x64_epi64(res, 0xd8);
  res = _mm256_hadd_epi32(res, res);
  res = _mm256_hadd_epi32(res, res);
  int32_t sad = _mm256_extract_epi32(res, 0);
  return sad;
}

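// Dispatch on block width: widths 4 and 8 reuse the SSSE3 kernels, width 16
// uses the two-row AVX2 kernel, and wider blocks use the 32-pixel-wide
// kernel. When 'invert_mask' is set, 'ref' and 'second_pred' swap roles so
// the mask weights the second predictor instead of the reference.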
static INLINE unsigned int aom_masked_sad_avx2(
    const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,
    const uint8_t *second_pred, const uint8_t *msk, int msk_stride,
    int invert_mask, int m, int n) {
  unsigned int sad;
  if (!invert_mask) {
    switch (m) {
      case 4:
        sad = aom_masked_sad4xh_ssse3(src, src_stride, ref, ref_stride,
                                      second_pred, m, msk, msk_stride, n);
        break;
      case 8:
        sad = aom_masked_sad8xh_ssse3(src, src_stride, ref, ref_stride,
                                      second_pred, m, msk, msk_stride, n);
        break;
      case 16:
        sad = masked_sad16xh_avx2(src, src_stride, ref, ref_stride, second_pred,
                                  m, msk, msk_stride, n);
        break;
      default:
        sad = masked_sad32xh_avx2(src, src_stride, ref, ref_stride, second_pred,
                                  m, msk, msk_stride, m, n);
        break;
    }
  } else {
    switch (m) {
      case 4:
        sad = aom_masked_sad4xh_ssse3(src, src_stride, second_pred, m, ref,
                                      ref_stride, msk, msk_stride, n);
        break;
      case 8:
        sad = aom_masked_sad8xh_ssse3(src, src_stride, second_pred, m, ref,
                                      ref_stride, msk, msk_stride, n);
        break;
      case 16:
        sad = masked_sad16xh_avx2(src, src_stride, second_pred, m, ref,
                                  ref_stride, msk, msk_stride, n);
        break;
      default:
        sad = masked_sad32xh_avx2(src, src_stride, second_pred, m, ref,
                                  ref_stride, msk, msk_stride, m, n);
        break;
    }
  }
  return sad;
}

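// Instantiate the aom_masked_sad<m>x<n>_avx2() entry points for each
// supported block size.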
#define MASKSADMXN_AVX2(m, n)                                                 \
  unsigned int aom_masked_sad##m##x##n##_avx2(                                \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred, const uint8_t *msk, int msk_stride,         \
      int invert_mask) {                                                      \
    return aom_masked_sad_avx2(src, src_stride, ref, ref_stride, second_pred, \
                               msk, msk_stride, invert_mask, m, n);           \
  }

MASKSADMXN_AVX2(4, 4)
MASKSADMXN_AVX2(4, 8)
MASKSADMXN_AVX2(8, 4)
MASKSADMXN_AVX2(8, 8)
MASKSADMXN_AVX2(8, 16)
MASKSADMXN_AVX2(16, 8)
MASKSADMXN_AVX2(16, 16)
MASKSADMXN_AVX2(16, 32)
MASKSADMXN_AVX2(32, 16)
MASKSADMXN_AVX2(32, 32)
MASKSADMXN_AVX2(32, 64)
MASKSADMXN_AVX2(64, 32)
MASKSADMXN_AVX2(64, 64)
MASKSADMXN_AVX2(64, 128)
MASKSADMXN_AVX2(128, 64)
MASKSADMXN_AVX2(128, 128)
MASKSADMXN_AVX2(4, 16)
MASKSADMXN_AVX2(16, 4)
MASKSADMXN_AVX2(8, 32)
MASKSADMXN_AVX2(32, 8)
MASKSADMXN_AVX2(16, 64)
MASKSADMXN_AVX2(64, 16)

static INLINE unsigned int highbd_masked_sad8xh_avx2(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m_ptr, int m_stride,
    int height) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a_ptr = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b_ptr = CONVERT_TO_SHORTPTR(b8);
  int y;
  __m256i res = _mm256_setzero_si256();
  const __m256i mask_max = _mm256_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m256i round_const =
      _mm256_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m256i one = _mm256_set1_epi16(1);

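  // Process two rows of 8 pixels per iteration. The blended prediction is
  // formed at 16-bit precision: madd_epi16 computes a * m + b * (64 - m) for
  // each pixel and the add/shift pair applies the rounded shift by
  // AOM_BLEND_A64_ROUND_BITS.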
  for (y = 0; y < height; y += 2) {
    const __m256i src = xx_loadu2_m128i(src_ptr + src_stride, src_ptr);
    const __m256i a = xx_loadu2_m128i(a_ptr + a_stride, a_ptr);
    const __m256i b = xx_loadu2_m128i(b_ptr + b_stride, b_ptr);
    // Zero-extend mask to 16 bits
    const __m256i m = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)(m_ptr)),
        _mm_loadl_epi64((const __m128i *)(m_ptr + m_stride))));
    const __m256i m_inv = _mm256_sub_epi16(mask_max, m);

    const __m256i data_l = _mm256_unpacklo_epi16(a, b);
    const __m256i mask_l = _mm256_unpacklo_epi16(m, m_inv);
    __m256i pred_l = _mm256_madd_epi16(data_l, mask_l);
    pred_l = _mm256_srai_epi32(_mm256_add_epi32(pred_l, round_const),
                               AOM_BLEND_A64_ROUND_BITS);

    const __m256i data_r = _mm256_unpackhi_epi16(a, b);
    const __m256i mask_r = _mm256_unpackhi_epi16(m, m_inv);
    __m256i pred_r = _mm256_madd_epi16(data_r, mask_r);
    pred_r = _mm256_srai_epi32(_mm256_add_epi32(pred_r, round_const),
                               AOM_BLEND_A64_ROUND_BITS);

    // Note: the maximum value in pred_l/r is (2^bd)-1 < 2^15,
    // so it is safe to do signed saturation here.
    const __m256i pred = _mm256_packs_epi32(pred_l, pred_r);
    // There is no 16-bit SAD instruction, so we synthesize one: take the
    // absolute differences and use madd_epi16 with a vector of ones to
    // accumulate eight 32-bit partial sums in 'res', reducing them at the end.
    const __m256i diff = _mm256_abs_epi16(_mm256_sub_epi16(pred, src));
    res = _mm256_add_epi32(res, _mm256_madd_epi16(diff, one));

    src_ptr += src_stride << 1;
    a_ptr += a_stride << 1;
    b_ptr += b_stride << 1;
    m_ptr += m_stride << 1;
  }
  // At this point, 'res' holds eight 32-bit partial sums. Two horizontal adds
  // reduce each 128-bit lane, and the two lane totals are combined below.
  res = _mm256_hadd_epi32(res, res);
  res = _mm256_hadd_epi32(res, res);
  int sad = _mm256_extract_epi32(res, 0) + _mm256_extract_epi32(res, 4);
  return sad;
}

static INLINE unsigned int highbd_masked_sad16xh_avx2(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m_ptr, int m_stride,
    int width, int height) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a_ptr = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b_ptr = CONVERT_TO_SHORTPTR(b8);
  int x, y;
  __m256i res = _mm256_setzero_si256();
  const __m256i mask_max = _mm256_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m256i round_const =
      _mm256_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m256i one = _mm256_set1_epi16(1);

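  // Same blend-and-accumulate scheme as highbd_masked_sad8xh_avx2, but one
  // row at a time, 16 pixels per inner iteration.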
  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 16) {
      const __m256i src = _mm256_lddqu_si256((const __m256i *)&src_ptr[x]);
      const __m256i a = _mm256_lddqu_si256((const __m256i *)&a_ptr[x]);
      const __m256i b = _mm256_lddqu_si256((const __m256i *)&b_ptr[x]);
      // Zero-extend mask to 16 bits
      const __m256i m =
          _mm256_cvtepu8_epi16(_mm_lddqu_si128((const __m128i *)&m_ptr[x]));
      const __m256i m_inv = _mm256_sub_epi16(mask_max, m);

      const __m256i data_l = _mm256_unpacklo_epi16(a, b);
      const __m256i mask_l = _mm256_unpacklo_epi16(m, m_inv);
      __m256i pred_l = _mm256_madd_epi16(data_l, mask_l);
      pred_l = _mm256_srai_epi32(_mm256_add_epi32(pred_l, round_const),
                                 AOM_BLEND_A64_ROUND_BITS);

      const __m256i data_r = _mm256_unpackhi_epi16(a, b);
      const __m256i mask_r = _mm256_unpackhi_epi16(m, m_inv);
      __m256i pred_r = _mm256_madd_epi16(data_r, mask_r);
      pred_r = _mm256_srai_epi32(_mm256_add_epi32(pred_r, round_const),
                                 AOM_BLEND_A64_ROUND_BITS);

      // Note: the maximum value in pred_l/r is (2^bd)-1 < 2^15,
      // so it is safe to do signed saturation here.
      const __m256i pred = _mm256_packs_epi32(pred_l, pred_r);
      // There is no 16-bit SAD instruction, so we synthesize one: take the
      // absolute differences and use madd_epi16 with a vector of ones to
      // accumulate eight 32-bit partial sums in 'res', reducing them at the
      // end.
      const __m256i diff = _mm256_abs_epi16(_mm256_sub_epi16(pred, src));
      res = _mm256_add_epi32(res, _mm256_madd_epi16(diff, one));
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // At this point, 'res' holds eight 32-bit partial sums. Two horizontal adds
  // reduce each 128-bit lane, and the two lane totals are combined below.
  res = _mm256_hadd_epi32(res, res);
  res = _mm256_hadd_epi32(res, res);
  int sad = _mm256_extract_epi32(res, 0) + _mm256_extract_epi32(res, 4);
  return sad;
}

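// High-bit-depth dispatcher: width 4 reuses the SSSE3 kernel, width 8 uses
// the two-row AVX2 kernel, and wider blocks use the 16-pixel-wide kernel.
// As above, 'invert_mask' swaps the roles of 'ref' and 'second_pred'.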
static INLINE unsigned int aom_highbd_masked_sad_avx2(
    const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,
    const uint8_t *second_pred, const uint8_t *msk, int msk_stride,
    int invert_mask, int m, int n) {
  unsigned int sad;
  if (!invert_mask) {
    switch (m) {
      case 4:
        sad =
            aom_highbd_masked_sad4xh_ssse3(src, src_stride, ref, ref_stride,
                                           second_pred, m, msk, msk_stride, n);
        break;
      case 8:
        sad = highbd_masked_sad8xh_avx2(src, src_stride, ref, ref_stride,
                                        second_pred, m, msk, msk_stride, n);
        break;
      default:
        sad = highbd_masked_sad16xh_avx2(src, src_stride, ref, ref_stride,
                                         second_pred, m, msk, msk_stride, m, n);
        break;
    }
  } else {
    switch (m) {
      case 4:
        sad =
            aom_highbd_masked_sad4xh_ssse3(src, src_stride, second_pred, m, ref,
                                           ref_stride, msk, msk_stride, n);
        break;
      case 8:
        sad = highbd_masked_sad8xh_avx2(src, src_stride, second_pred, m, ref,
                                        ref_stride, msk, msk_stride, n);
        break;
      default:
        sad = highbd_masked_sad16xh_avx2(src, src_stride, second_pred, m, ref,
                                         ref_stride, msk, msk_stride, m, n);
        break;
    }
  }
  return sad;
}

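// High-bit-depth counterparts of the MASKSADMXN_AVX2 instantiations above.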
#define HIGHBD_MASKSADMXN_AVX2(m, n)                                      \
  unsigned int aom_highbd_masked_sad##m##x##n##_avx2(                     \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,           \
      int ref_stride, const uint8_t *second_pred8, const uint8_t *msk,    \
      int msk_stride, int invert_mask) {                                  \
    return aom_highbd_masked_sad_avx2(src8, src_stride, ref8, ref_stride, \
                                      second_pred8, msk, msk_stride,      \
                                      invert_mask, m, n);                 \
  }

HIGHBD_MASKSADMXN_AVX2(4, 4)
HIGHBD_MASKSADMXN_AVX2(4, 8)
HIGHBD_MASKSADMXN_AVX2(8, 4)
HIGHBD_MASKSADMXN_AVX2(8, 8)
HIGHBD_MASKSADMXN_AVX2(8, 16)
HIGHBD_MASKSADMXN_AVX2(16, 8)
HIGHBD_MASKSADMXN_AVX2(16, 16)
HIGHBD_MASKSADMXN_AVX2(16, 32)
HIGHBD_MASKSADMXN_AVX2(32, 16)
HIGHBD_MASKSADMXN_AVX2(32, 32)
HIGHBD_MASKSADMXN_AVX2(32, 64)
HIGHBD_MASKSADMXN_AVX2(64, 32)
HIGHBD_MASKSADMXN_AVX2(64, 64)
HIGHBD_MASKSADMXN_AVX2(64, 128)
HIGHBD_MASKSADMXN_AVX2(128, 64)
HIGHBD_MASKSADMXN_AVX2(128, 128)
HIGHBD_MASKSADMXN_AVX2(4, 16)
HIGHBD_MASKSADMXN_AVX2(16, 4)
HIGHBD_MASKSADMXN_AVX2(8, 32)
HIGHBD_MASKSADMXN_AVX2(32, 8)
HIGHBD_MASKSADMXN_AVX2(16, 64)
HIGHBD_MASKSADMXN_AVX2(64, 16)
