/*
 * Copyright (c) 2017, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <stdio.h>
#include <tmmintrin.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/blend.h"
#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms.h"

#include "aom_dsp/x86/masked_sad_intrin_ssse3.h"

// For width a multiple of 16
static INLINE unsigned int masked_sad_ssse3(const uint8_t *src_ptr,
                                            int src_stride,
                                            const uint8_t *a_ptr, int a_stride,
                                            const uint8_t *b_ptr, int b_stride,
                                            const uint8_t *m_ptr, int m_stride,
                                            int width, int height);

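// Each aom_masked_sad<w>x<h>_ssse3() returns the SAD between 'src' and the
// blend pred = (m * a + (64 - m) * b + 32) >> AOM_BLEND_A64_ROUND_BITS of the
// two predictors. When 'invert_mask' is set, 'ref' and 'second_pred' swap
// roles, so the mask weights 'second_pred' instead of 'ref'.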
#define MASKSADMXN_SSSE3(m, n) \
  unsigned int aom_masked_sad##m##x##n##_ssse3( \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred, const uint8_t *msk, int msk_stride, \
      int invert_mask) { \
    if (!invert_mask) \
      return masked_sad_ssse3(src, src_stride, ref, ref_stride, second_pred, \
                              m, msk, msk_stride, m, n); \
    else \
      return masked_sad_ssse3(src, src_stride, second_pred, m, ref, \
                              ref_stride, msk, msk_stride, m, n); \
  }

#define MASKSAD8XN_SSSE3(n) \
  unsigned int aom_masked_sad8x##n##_ssse3( \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred, const uint8_t *msk, int msk_stride, \
      int invert_mask) { \
    if (!invert_mask) \
      return aom_masked_sad8xh_ssse3(src, src_stride, ref, ref_stride, \
                                     second_pred, 8, msk, msk_stride, n); \
    else \
      return aom_masked_sad8xh_ssse3(src, src_stride, second_pred, 8, ref, \
                                     ref_stride, msk, msk_stride, n); \
  }

#define MASKSAD4XN_SSSE3(n) \
  unsigned int aom_masked_sad4x##n##_ssse3( \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred, const uint8_t *msk, int msk_stride, \
      int invert_mask) { \
    if (!invert_mask) \
      return aom_masked_sad4xh_ssse3(src, src_stride, ref, ref_stride, \
                                     second_pred, 4, msk, msk_stride, n); \
    else \
      return aom_masked_sad4xh_ssse3(src, src_stride, second_pred, 4, ref, \
                                     ref_stride, msk, msk_stride, n); \
  }

MASKSADMXN_SSSE3(128, 128)
MASKSADMXN_SSSE3(128, 64)
MASKSADMXN_SSSE3(64, 128)
MASKSADMXN_SSSE3(64, 64)
MASKSADMXN_SSSE3(64, 32)
MASKSADMXN_SSSE3(32, 64)
MASKSADMXN_SSSE3(32, 32)
MASKSADMXN_SSSE3(32, 16)
MASKSADMXN_SSSE3(16, 32)
MASKSADMXN_SSSE3(16, 16)
MASKSADMXN_SSSE3(16, 8)
MASKSAD8XN_SSSE3(16)
MASKSAD8XN_SSSE3(8)
MASKSAD8XN_SSSE3(4)
MASKSAD4XN_SSSE3(8)
MASKSAD4XN_SSSE3(4)
MASKSAD4XN_SSSE3(16)
MASKSADMXN_SSSE3(16, 4)
MASKSAD8XN_SSSE3(32)
MASKSADMXN_SSSE3(32, 8)
MASKSADMXN_SSSE3(16, 64)
MASKSADMXN_SSSE3(64, 16)

static INLINE unsigned int masked_sad_ssse3(const uint8_t *src_ptr,
                                            int src_stride,
                                            const uint8_t *a_ptr, int a_stride,
                                            const uint8_t *b_ptr, int b_stride,
                                            const uint8_t *m_ptr, int m_stride,
                                            int width, int height) {
  int x, y;
  __m128i res = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));
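  // mask_max == AOM_BLEND_A64_MAX_ALPHA == 64, so each mask value and its
  // inverse (m_inv) always sum to 64.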

  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 16) {
      const __m128i src = _mm_loadu_si128((const __m128i *)&src_ptr[x]);
      const __m128i a = _mm_loadu_si128((const __m128i *)&a_ptr[x]);
      const __m128i b = _mm_loadu_si128((const __m128i *)&b_ptr[x]);
      const __m128i m = _mm_loadu_si128((const __m128i *)&m_ptr[x]);
      const __m128i m_inv = _mm_sub_epi8(mask_max, m);

      // Calculate 16 predicted pixels.
      // Note that the maximum value of any entry of 'pred_l' or 'pred_r'
      // is 64 * 255, so we have plenty of space to add rounding constants.
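      // Interleaving a/b with m/m_inv lets _mm_maddubs_epi16 compute
      // m * a + (64 - m) * b for each pixel as a 16-bit sum.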
      const __m128i data_l = _mm_unpacklo_epi8(a, b);
      const __m128i mask_l = _mm_unpacklo_epi8(m, m_inv);
      __m128i pred_l = _mm_maddubs_epi16(data_l, mask_l);
      pred_l = xx_roundn_epu16(pred_l, AOM_BLEND_A64_ROUND_BITS);

      const __m128i data_r = _mm_unpackhi_epi8(a, b);
      const __m128i mask_r = _mm_unpackhi_epi8(m, m_inv);
      __m128i pred_r = _mm_maddubs_epi16(data_r, mask_r);
      pred_r = xx_roundn_epu16(pred_r, AOM_BLEND_A64_ROUND_BITS);

      const __m128i pred = _mm_packus_epi16(pred_l, pred_r);
      res = _mm_add_epi32(res, _mm_sad_epu8(pred, src));
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // At this point, we have two 32-bit partial SADs in lanes 0 and 2 of 'res'.
  unsigned int sad = (unsigned int)(_mm_cvtsi128_si32(res) +
                                    _mm_cvtsi128_si32(_mm_srli_si128(res, 8)));
  return sad;
}

unsigned int aom_masked_sad8xh_ssse3(const uint8_t *src_ptr, int src_stride,
                                     const uint8_t *a_ptr, int a_stride,
                                     const uint8_t *b_ptr, int b_stride,
                                     const uint8_t *m_ptr, int m_stride,
                                     int height) {
  int y;
  __m128i res = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));

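  // Process two 8-pixel rows per iteration; the two source rows and the two
  // mask rows are packed together into single 128-bit registers.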
  for (y = 0; y < height; y += 2) {
    const __m128i src = _mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)src_ptr),
        _mm_loadl_epi64((const __m128i *)&src_ptr[src_stride]));
    const __m128i a0 = _mm_loadl_epi64((const __m128i *)a_ptr);
    const __m128i a1 = _mm_loadl_epi64((const __m128i *)&a_ptr[a_stride]);
    const __m128i b0 = _mm_loadl_epi64((const __m128i *)b_ptr);
    const __m128i b1 = _mm_loadl_epi64((const __m128i *)&b_ptr[b_stride]);
    const __m128i m =
        _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)m_ptr),
                           _mm_loadl_epi64((const __m128i *)&m_ptr[m_stride]));
    const __m128i m_inv = _mm_sub_epi8(mask_max, m);

    const __m128i data_l = _mm_unpacklo_epi8(a0, b0);
    const __m128i mask_l = _mm_unpacklo_epi8(m, m_inv);
    __m128i pred_l = _mm_maddubs_epi16(data_l, mask_l);
    pred_l = xx_roundn_epu16(pred_l, AOM_BLEND_A64_ROUND_BITS);

    const __m128i data_r = _mm_unpacklo_epi8(a1, b1);
    const __m128i mask_r = _mm_unpackhi_epi8(m, m_inv);
    __m128i pred_r = _mm_maddubs_epi16(data_r, mask_r);
    pred_r = xx_roundn_epu16(pred_r, AOM_BLEND_A64_ROUND_BITS);

    const __m128i pred = _mm_packus_epi16(pred_l, pred_r);
    res = _mm_add_epi32(res, _mm_sad_epu8(pred, src));

    src_ptr += src_stride * 2;
    a_ptr += a_stride * 2;
    b_ptr += b_stride * 2;
    m_ptr += m_stride * 2;
  }
  unsigned int sad = (unsigned int)(_mm_cvtsi128_si32(res) +
                                    _mm_cvtsi128_si32(_mm_srli_si128(res, 8)));
  return sad;
}

unsigned int aom_masked_sad4xh_ssse3(const uint8_t *src_ptr, int src_stride,
                                     const uint8_t *a_ptr, int a_stride,
                                     const uint8_t *b_ptr, int b_stride,
                                     const uint8_t *m_ptr, int m_stride,
                                     int height) {
  int y;
  __m128i res = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));

  for (y = 0; y < height; y += 2) {
    // Load two rows at a time; this seems to be a bit faster
    // than four rows at a time in this case.
    const __m128i src =
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(int *)src_ptr),
                           _mm_cvtsi32_si128(*(int *)&src_ptr[src_stride]));
    const __m128i a =
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(int *)a_ptr),
                           _mm_cvtsi32_si128(*(int *)&a_ptr[a_stride]));
    const __m128i b =
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(int *)b_ptr),
                           _mm_cvtsi32_si128(*(int *)&b_ptr[b_stride]));
    const __m128i m =
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(int *)m_ptr),
                           _mm_cvtsi32_si128(*(int *)&m_ptr[m_stride]));
    const __m128i m_inv = _mm_sub_epi8(mask_max, m);

    const __m128i data = _mm_unpacklo_epi8(a, b);
    const __m128i mask = _mm_unpacklo_epi8(m, m_inv);
    __m128i pred_16bit = _mm_maddubs_epi16(data, mask);
    pred_16bit = xx_roundn_epu16(pred_16bit, AOM_BLEND_A64_ROUND_BITS);

    const __m128i pred = _mm_packus_epi16(pred_16bit, _mm_setzero_si128());
    res = _mm_add_epi32(res, _mm_sad_epu8(pred, src));

    src_ptr += src_stride * 2;
    a_ptr += a_stride * 2;
    b_ptr += b_stride * 2;
    m_ptr += m_stride * 2;
  }
  // At this point, the SAD is stored in lane 0 of 'res'
  return (unsigned int)_mm_cvtsi128_si32(res);
}

// For width a multiple of 8
static INLINE unsigned int highbd_masked_sad_ssse3(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m_ptr, int m_stride,
    int width, int height);

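// High bit-depth variants: 'src8', 'ref8' and 'second_pred8' point at 16-bit
// pixel buffers and are converted back with CONVERT_TO_SHORTPTR inside the
// workers below.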
#define HIGHBD_MASKSADMXN_SSSE3(m, n) \
  unsigned int aom_highbd_masked_sad##m##x##n##_ssse3( \
      const uint8_t *src8, int src_stride, const uint8_t *ref8, \
      int ref_stride, const uint8_t *second_pred8, const uint8_t *msk, \
      int msk_stride, int invert_mask) { \
    if (!invert_mask) \
      return highbd_masked_sad_ssse3(src8, src_stride, ref8, ref_stride, \
                                     second_pred8, m, msk, msk_stride, m, n); \
    else \
      return highbd_masked_sad_ssse3(src8, src_stride, second_pred8, m, ref8, \
                                     ref_stride, msk, msk_stride, m, n); \
  }

#define HIGHBD_MASKSAD4XN_SSSE3(n) \
  unsigned int aom_highbd_masked_sad4x##n##_ssse3( \
      const uint8_t *src8, int src_stride, const uint8_t *ref8, \
      int ref_stride, const uint8_t *second_pred8, const uint8_t *msk, \
      int msk_stride, int invert_mask) { \
    if (!invert_mask) \
      return aom_highbd_masked_sad4xh_ssse3(src8, src_stride, ref8, \
                                            ref_stride, second_pred8, 4, msk, \
                                            msk_stride, n); \
    else \
      return aom_highbd_masked_sad4xh_ssse3(src8, src_stride, second_pred8, 4, \
                                            ref8, ref_stride, msk, msk_stride, \
                                            n); \
  }

HIGHBD_MASKSADMXN_SSSE3(128, 128)
HIGHBD_MASKSADMXN_SSSE3(128, 64)
HIGHBD_MASKSADMXN_SSSE3(64, 128)
HIGHBD_MASKSADMXN_SSSE3(64, 64)
HIGHBD_MASKSADMXN_SSSE3(64, 32)
HIGHBD_MASKSADMXN_SSSE3(32, 64)
HIGHBD_MASKSADMXN_SSSE3(32, 32)
HIGHBD_MASKSADMXN_SSSE3(32, 16)
HIGHBD_MASKSADMXN_SSSE3(16, 32)
HIGHBD_MASKSADMXN_SSSE3(16, 16)
HIGHBD_MASKSADMXN_SSSE3(16, 8)
HIGHBD_MASKSADMXN_SSSE3(8, 16)
HIGHBD_MASKSADMXN_SSSE3(8, 8)
HIGHBD_MASKSADMXN_SSSE3(8, 4)
HIGHBD_MASKSAD4XN_SSSE3(8)
HIGHBD_MASKSAD4XN_SSSE3(4)
HIGHBD_MASKSAD4XN_SSSE3(16)
HIGHBD_MASKSADMXN_SSSE3(16, 4)
HIGHBD_MASKSADMXN_SSSE3(8, 32)
HIGHBD_MASKSADMXN_SSSE3(32, 8)
HIGHBD_MASKSADMXN_SSSE3(16, 64)
HIGHBD_MASKSADMXN_SSSE3(64, 16)

static INLINE unsigned int highbd_masked_sad_ssse3(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m_ptr, int m_stride,
    int width, int height) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a_ptr = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b_ptr = CONVERT_TO_SHORTPTR(b8);
  int x, y;
  __m128i res = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i round_const =
      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m128i one = _mm_set1_epi16(1);
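  // 'round_const' is the +(1 << (AOM_BLEND_A64_ROUND_BITS - 1)) rounding term,
  // and 'one' is used with _mm_madd_epi16 to horizontally add pairs of 16-bit
  // absolute differences into 32-bit partial sums.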

  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 8) {
      const __m128i src = _mm_loadu_si128((const __m128i *)&src_ptr[x]);
      const __m128i a = _mm_loadu_si128((const __m128i *)&a_ptr[x]);
      const __m128i b = _mm_loadu_si128((const __m128i *)&b_ptr[x]);
      // Zero-extend mask to 16 bits
      const __m128i m = _mm_unpacklo_epi8(
          _mm_loadl_epi64((const __m128i *)&m_ptr[x]), _mm_setzero_si128());
      const __m128i m_inv = _mm_sub_epi16(mask_max, m);

      const __m128i data_l = _mm_unpacklo_epi16(a, b);
      const __m128i mask_l = _mm_unpacklo_epi16(m, m_inv);
      __m128i pred_l = _mm_madd_epi16(data_l, mask_l);
      pred_l = _mm_srai_epi32(_mm_add_epi32(pred_l, round_const),
                              AOM_BLEND_A64_ROUND_BITS);

      const __m128i data_r = _mm_unpackhi_epi16(a, b);
      const __m128i mask_r = _mm_unpackhi_epi16(m, m_inv);
      __m128i pred_r = _mm_madd_epi16(data_r, mask_r);
      pred_r = _mm_srai_epi32(_mm_add_epi32(pred_r, round_const),
                              AOM_BLEND_A64_ROUND_BITS);

      // Note: the maximum value in pred_l/r is (2^bd)-1 < 2^15,
      // so it is safe to do signed saturation here.
      const __m128i pred = _mm_packs_epi32(pred_l, pred_r);
      // There is no 16-bit SAD instruction, so we have to synthesize
      // an 8-element SAD. We do this by storing 4 32-bit partial SADs,
      // and accumulating them at the end
      const __m128i diff = _mm_abs_epi16(_mm_sub_epi16(pred, src));
      res = _mm_add_epi32(res, _mm_madd_epi16(diff, one));
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // At this point, we have four 32-bit partial SADs stored in 'res'.
  res = _mm_hadd_epi32(res, res);
  res = _mm_hadd_epi32(res, res);
  int sad = _mm_cvtsi128_si32(res);
  return sad;
}

unsigned int aom_highbd_masked_sad4xh_ssse3(const uint8_t *src8, int src_stride,
                                            const uint8_t *a8, int a_stride,
                                            const uint8_t *b8, int b_stride,
                                            const uint8_t *m_ptr, int m_stride,
                                            int height) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a_ptr = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b_ptr = CONVERT_TO_SHORTPTR(b8);
  int y;
  __m128i res = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i round_const =
      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m128i one = _mm_set1_epi16(1);

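  // Process two 4-pixel rows (8 bytes of 16-bit samples each) per iteration.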
  for (y = 0; y < height; y += 2) {
    const __m128i src = _mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)src_ptr),
        _mm_loadl_epi64((const __m128i *)&src_ptr[src_stride]));
    const __m128i a =
        _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)a_ptr),
                           _mm_loadl_epi64((const __m128i *)&a_ptr[a_stride]));
    const __m128i b =
        _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)b_ptr),
                           _mm_loadl_epi64((const __m128i *)&b_ptr[b_stride]));
    // Zero-extend mask to 16 bits
    const __m128i m = _mm_unpacklo_epi8(
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(const int *)m_ptr),
                           _mm_cvtsi32_si128(*(const int *)&m_ptr[m_stride])),
        _mm_setzero_si128());
    const __m128i m_inv = _mm_sub_epi16(mask_max, m);

    const __m128i data_l = _mm_unpacklo_epi16(a, b);
    const __m128i mask_l = _mm_unpacklo_epi16(m, m_inv);
    __m128i pred_l = _mm_madd_epi16(data_l, mask_l);
    pred_l = _mm_srai_epi32(_mm_add_epi32(pred_l, round_const),
                            AOM_BLEND_A64_ROUND_BITS);

    const __m128i data_r = _mm_unpackhi_epi16(a, b);
    const __m128i mask_r = _mm_unpackhi_epi16(m, m_inv);
    __m128i pred_r = _mm_madd_epi16(data_r, mask_r);
    pred_r = _mm_srai_epi32(_mm_add_epi32(pred_r, round_const),
                            AOM_BLEND_A64_ROUND_BITS);

    const __m128i pred = _mm_packs_epi32(pred_l, pred_r);
    const __m128i diff = _mm_abs_epi16(_mm_sub_epi16(pred, src));
    res = _mm_add_epi32(res, _mm_madd_epi16(diff, one));

    src_ptr += src_stride * 2;
    a_ptr += a_stride * 2;
    b_ptr += b_stride * 2;
    m_ptr += m_stride * 2;
  }
  res = _mm_hadd_epi32(res, res);
  res = _mm_hadd_epi32(res, res);
  int sad = _mm_cvtsi128_si32(res);
  return sad;
}