/*
 * Copyright (c) 2017, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <stdlib.h>
#include <string.h>
#include <tmmintrin.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/aom_filter.h"
#include "aom_dsp/blend.h"
#include "aom_dsp/x86/masked_variance_intrin_ssse3.h"
#include "aom_dsp/x86/synonyms.h"
#include "aom_ports/mem.h"

// For width a multiple of 16
static void bilinear_filter(const uint8_t *src, int src_stride, int xoffset,
                            int yoffset, uint8_t *dst, int w, int h);

static void bilinear_filter8xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h);

static void bilinear_filter4xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h);

// For width a multiple of 16
static void masked_variance(const uint8_t *src_ptr, int src_stride,
                            const uint8_t *a_ptr, int a_stride,
                            const uint8_t *b_ptr, int b_stride,
                            const uint8_t *m_ptr, int m_stride, int width,
                            int height, unsigned int *sse, int *sum_);

static void masked_variance8xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_);

static void masked_variance4xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_);

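// Each MASK_SUBPIX_VAR*_SSSE3 macro below expands to a function of the form
// (shown here for illustration only):
//
//   unsigned int aom_masked_sub_pixel_varianceWxH_ssse3(
//       const uint8_t *src, int src_stride, int xoffset, int yoffset,
//       const uint8_t *ref, int ref_stride, const uint8_t *second_pred,
//       const uint8_t *msk, int msk_stride, int invert_mask,
//       unsigned int *sse);
//
// which (1) bilinearly filters 'src' at the given sub-pixel offsets into an
// (H + 1) x W buffer, (2) blends that buffer with 'second_pred' using the
// 64-weight mask 'msk' (with 'invert_mask' swapping the two inputs), and
// (3) returns the variance of the blended prediction against 'ref' via the
// identity variance = sse - sum^2 / (W * H).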
#define MASK_SUBPIX_VAR_SSSE3(W, H)                                      \
  unsigned int aom_masked_sub_pixel_variance##W##x##H##_ssse3(           \
      const uint8_t *src, int src_stride, int xoffset, int yoffset,      \
      const uint8_t *ref, int ref_stride, const uint8_t *second_pred,    \
      const uint8_t *msk, int msk_stride, int invert_mask,               \
      unsigned int *sse) {                                               \
    int sum;                                                             \
    uint8_t temp[(H + 1) * W];                                           \
                                                                         \
    bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H);      \
                                                                         \
    if (!invert_mask)                                                    \
      masked_variance(ref, ref_stride, temp, W, second_pred, W, msk,     \
                      msk_stride, W, H, sse, &sum);                      \
    else                                                                 \
      masked_variance(ref, ref_stride, second_pred, W, temp, W, msk,     \
                      msk_stride, W, H, sse, &sum);                      \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (W * H));            \
  }

#define MASK_SUBPIX_VAR8XH_SSSE3(H)                                           \
  unsigned int aom_masked_sub_pixel_variance8x##H##_ssse3(                    \
      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
      const uint8_t *ref, int ref_stride, const uint8_t *second_pred,         \
      const uint8_t *msk, int msk_stride, int invert_mask,                    \
      unsigned int *sse) {                                                    \
    int sum;                                                                  \
    uint8_t temp[(H + 1) * 8];                                                \
                                                                              \
    bilinear_filter8xh(src, src_stride, xoffset, yoffset, temp, H);           \
                                                                              \
    if (!invert_mask)                                                         \
      masked_variance8xh(ref, ref_stride, temp, second_pred, msk, msk_stride, \
                         H, sse, &sum);                                       \
    else                                                                      \
      masked_variance8xh(ref, ref_stride, second_pred, temp, msk, msk_stride, \
                         H, sse, &sum);                                       \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (8 * H));                 \
  }

#define MASK_SUBPIX_VAR4XH_SSSE3(H)                                           \
  unsigned int aom_masked_sub_pixel_variance4x##H##_ssse3(                    \
      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
      const uint8_t *ref, int ref_stride, const uint8_t *second_pred,         \
      const uint8_t *msk, int msk_stride, int invert_mask,                    \
      unsigned int *sse) {                                                    \
    int sum;                                                                  \
    uint8_t temp[(H + 1) * 4];                                                \
                                                                              \
    bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H);           \
                                                                              \
    if (!invert_mask)                                                         \
      masked_variance4xh(ref, ref_stride, temp, second_pred, msk, msk_stride, \
                         H, sse, &sum);                                       \
    else                                                                      \
      masked_variance4xh(ref, ref_stride, second_pred, temp, msk, msk_stride, \
                         H, sse, &sum);                                       \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (4 * H));                 \
  }

MASK_SUBPIX_VAR_SSSE3(128, 128)
MASK_SUBPIX_VAR_SSSE3(128, 64)
MASK_SUBPIX_VAR_SSSE3(64, 128)
MASK_SUBPIX_VAR_SSSE3(64, 64)
MASK_SUBPIX_VAR_SSSE3(64, 32)
MASK_SUBPIX_VAR_SSSE3(32, 64)
MASK_SUBPIX_VAR_SSSE3(32, 32)
MASK_SUBPIX_VAR_SSSE3(32, 16)
MASK_SUBPIX_VAR_SSSE3(16, 32)
MASK_SUBPIX_VAR_SSSE3(16, 16)
MASK_SUBPIX_VAR_SSSE3(16, 8)
MASK_SUBPIX_VAR8XH_SSSE3(16)
MASK_SUBPIX_VAR8XH_SSSE3(8)
MASK_SUBPIX_VAR8XH_SSSE3(4)
MASK_SUBPIX_VAR4XH_SSSE3(8)
MASK_SUBPIX_VAR4XH_SSSE3(4)
MASK_SUBPIX_VAR4XH_SSSE3(16)
MASK_SUBPIX_VAR_SSSE3(16, 4)
MASK_SUBPIX_VAR8XH_SSSE3(32)
MASK_SUBPIX_VAR_SSSE3(32, 8)
MASK_SUBPIX_VAR_SSSE3(64, 16)
MASK_SUBPIX_VAR_SSSE3(16, 64)

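// Apply a 2-tap filter to 16 pixels at once: 'a' holds pixel i, 'b' holds
// pixel i + 1 (or the pixel one row below, for the vertical pass), and
// 'filter' holds the two 8-bit taps packed as (f1 << 8) | f0 in every 16-bit
// lane. _mm_maddubs_epi16 then forms f0 * a + f1 * b per pixel, which is
// rounded by FILTER_BITS and re-packed to 8 bits.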
static INLINE __m128i filter_block(const __m128i a, const __m128i b,
                                   const __m128i filter) {
  __m128i v0 = _mm_unpacklo_epi8(a, b);
  v0 = _mm_maddubs_epi16(v0, filter);
  v0 = xx_roundn_epu16(v0, FILTER_BITS);

  __m128i v1 = _mm_unpackhi_epi8(a, b);
  v1 = _mm_maddubs_epi16(v1, filter);
  v1 = xx_roundn_epu16(v1, FILTER_BITS);

  return _mm_packus_epi16(v0, v1);
}

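// Produce an (h + 1) x w intermediate in 'dst': a horizontal pass over h + 1
// source rows, then an in-place vertical pass over h rows. Offsets of 0 (no
// filtering, just a copy) and 4 (a simple average of adjacent pixels) are
// special-cased; other offsets use the 2-tap kernels from bilinear_filters_2t.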
static void bilinear_filter(const uint8_t *src, int src_stride, int xoffset,
                            int yoffset, uint8_t *dst, int w, int h) {
  int i, j;
  // Horizontal filter
  if (xoffset == 0) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 16) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        _mm_storeu_si128((__m128i *)&b[j], x);
      }
      src += src_stride;
      b += w;
    }
  } else if (xoffset == 4) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 16) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&src[j + 16]);
        __m128i z = _mm_alignr_epi8(y, x, 1);
        _mm_storeu_si128((__m128i *)&b[j], _mm_avg_epu8(x, z));
      }
      src += src_stride;
      b += w;
    }
  } else {
    uint8_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi16(hfilter[0] | (hfilter[1] << 8));
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 16) {
        const __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&src[j + 16]);
        const __m128i z = _mm_alignr_epi8(y, x, 1);
        const __m128i res = filter_block(x, z, hfilter_vec);
        _mm_storeu_si128((__m128i *)&b[j], res);
      }

      src += src_stride;
      b += w;
    }
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 16) {
        __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        _mm_storeu_si128((__m128i *)&dst[j], _mm_avg_epu8(x, y));
      }
      dst += w;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi16(vfilter[0] | (vfilter[1] << 8));
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 16) {
        const __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        const __m128i res = filter_block(x, y, vfilter_vec);
        _mm_storeu_si128((__m128i *)&dst[j], res);
      }

      dst += w;
    }
  }
}

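// Like filter_block(), but filters the low 8 bytes of two independent input
// pairs (a0, b0) and (a1, b1), packing both 8-pixel results into a single
// 16-byte register.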
static INLINE __m128i filter_block_2rows(const __m128i *a0, const __m128i *b0,
                                         const __m128i *a1, const __m128i *b1,
                                         const __m128i *filter) {
  __m128i v0 = _mm_unpacklo_epi8(*a0, *b0);
  v0 = _mm_maddubs_epi16(v0, *filter);
  v0 = xx_roundn_epu16(v0, FILTER_BITS);

  __m128i v1 = _mm_unpacklo_epi8(*a1, *b1);
  v1 = _mm_maddubs_epi16(v1, *filter);
  v1 = xx_roundn_epu16(v1, FILTER_BITS);

  return _mm_packus_epi16(v0, v1);
}

static void bilinear_filter8xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h) {
  int i;
  // Horizontal filter
  if (xoffset == 0) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)src);
      _mm_storel_epi64((__m128i *)b, x);
      src += src_stride;
      b += 8;
    }
  } else if (xoffset == 4) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadu_si128((__m128i *)src);
      __m128i z = _mm_srli_si128(x, 1);
      _mm_storel_epi64((__m128i *)b, _mm_avg_epu8(x, z));
      src += src_stride;
      b += 8;
    }
  } else {
    uint8_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi16(hfilter[0] | (hfilter[1] << 8));
    for (i = 0; i < h; i += 2) {
      const __m128i x0 = _mm_loadu_si128((__m128i *)src);
      const __m128i z0 = _mm_srli_si128(x0, 1);
      const __m128i x1 = _mm_loadu_si128((__m128i *)&src[src_stride]);
      const __m128i z1 = _mm_srli_si128(x1, 1);
      const __m128i res = filter_block_2rows(&x0, &z0, &x1, &z1, &hfilter_vec);
      _mm_storeu_si128((__m128i *)b, res);

      src += src_stride * 2;
      b += 16;
    }
    // Handle i = h separately
    const __m128i x0 = _mm_loadu_si128((__m128i *)src);
    const __m128i z0 = _mm_srli_si128(x0, 1);

    __m128i v0 = _mm_unpacklo_epi8(x0, z0);
    v0 = _mm_maddubs_epi16(v0, hfilter_vec);
    v0 = xx_roundn_epu16(v0, FILTER_BITS);

    _mm_storel_epi64((__m128i *)b, _mm_packus_epi16(v0, v0));
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)dst);
      __m128i y = _mm_loadl_epi64((__m128i *)&dst[8]);
      _mm_storel_epi64((__m128i *)dst, _mm_avg_epu8(x, y));
      dst += 8;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi16(vfilter[0] | (vfilter[1] << 8));
    for (i = 0; i < h; i += 2) {
      const __m128i x = _mm_loadl_epi64((__m128i *)dst);
      const __m128i y = _mm_loadl_epi64((__m128i *)&dst[8]);
      const __m128i z = _mm_loadl_epi64((__m128i *)&dst[16]);
      const __m128i res = filter_block_2rows(&x, &y, &y, &z, &vfilter_vec);
      _mm_storeu_si128((__m128i *)dst, res);

      dst += 16;
    }
  }
}

static void bilinear_filter4xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h) {
  int i;
  // Horizontal filter
  if (xoffset == 0) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = xx_loadl_32((__m128i *)src);
      xx_storel_32(b, x);
      src += src_stride;
      b += 4;
    }
  } else if (xoffset == 4) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)src);
      __m128i z = _mm_srli_si128(x, 1);
      xx_storel_32(b, _mm_avg_epu8(x, z));
      src += src_stride;
      b += 4;
    }
  } else {
    uint8_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi16(hfilter[0] | (hfilter[1] << 8));
    for (i = 0; i < h; i += 4) {
      const __m128i x0 = _mm_loadl_epi64((__m128i *)src);
      const __m128i z0 = _mm_srli_si128(x0, 1);
      const __m128i x1 = _mm_loadl_epi64((__m128i *)&src[src_stride]);
      const __m128i z1 = _mm_srli_si128(x1, 1);
      const __m128i x2 = _mm_loadl_epi64((__m128i *)&src[src_stride * 2]);
      const __m128i z2 = _mm_srli_si128(x2, 1);
      const __m128i x3 = _mm_loadl_epi64((__m128i *)&src[src_stride * 3]);
      const __m128i z3 = _mm_srli_si128(x3, 1);

      const __m128i a0 = _mm_unpacklo_epi32(x0, x1);
      const __m128i b0 = _mm_unpacklo_epi32(z0, z1);
      const __m128i a1 = _mm_unpacklo_epi32(x2, x3);
      const __m128i b1 = _mm_unpacklo_epi32(z2, z3);
      const __m128i res = filter_block_2rows(&a0, &b0, &a1, &b1, &hfilter_vec);
      _mm_storeu_si128((__m128i *)b, res);

      src += src_stride * 4;
      b += 16;
    }
    // Handle i = h separately
    const __m128i x = _mm_loadl_epi64((__m128i *)src);
    const __m128i z = _mm_srli_si128(x, 1);

    __m128i v0 = _mm_unpacklo_epi8(x, z);
    v0 = _mm_maddubs_epi16(v0, hfilter_vec);
    v0 = xx_roundn_epu16(v0, FILTER_BITS);

    xx_storel_32(b, _mm_packus_epi16(v0, v0));
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      __m128i x = xx_loadl_32((__m128i *)dst);
      __m128i y = xx_loadl_32((__m128i *)&dst[4]);
      xx_storel_32(dst, _mm_avg_epu8(x, y));
      dst += 4;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi16(vfilter[0] | (vfilter[1] << 8));
    for (i = 0; i < h; i += 4) {
      const __m128i a = xx_loadl_32((__m128i *)dst);
      const __m128i b = xx_loadl_32((__m128i *)&dst[4]);
      const __m128i c = xx_loadl_32((__m128i *)&dst[8]);
      const __m128i d = xx_loadl_32((__m128i *)&dst[12]);
      const __m128i e = xx_loadl_32((__m128i *)&dst[16]);

      const __m128i a0 = _mm_unpacklo_epi32(a, b);
      const __m128i b0 = _mm_unpacklo_epi32(b, c);
      const __m128i a1 = _mm_unpacklo_epi32(c, d);
      const __m128i b1 = _mm_unpacklo_epi32(d, e);
      const __m128i res = filter_block_2rows(&a0, &b0, &a1, &b1, &vfilter_vec);
      _mm_storeu_si128((__m128i *)dst, res);

      dst += 16;
    }
  }
}

static INLINE void accumulate_block(const __m128i *src, const __m128i *a,
                                    const __m128i *b, const __m128i *m,
                                    __m128i *sum, __m128i *sum_sq) {
  const __m128i zero = _mm_setzero_si128();
  const __m128i one = _mm_set1_epi16(1);
  const __m128i mask_max = _mm_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i m_inv = _mm_sub_epi8(mask_max, *m);

  // Calculate 16 predicted pixels.
  // Note that the maximum value of any entry of 'pred_l' or 'pred_r'
  // is 64 * 255, so we have plenty of space to add rounding constants.
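  // (Concretely, each 16-bit lane of the _mm_maddubs_epi16 output is at most
  // 64 * 255 = 16320, well below the signed 16-bit limit of 32767.)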
  const __m128i data_l = _mm_unpacklo_epi8(*a, *b);
  const __m128i mask_l = _mm_unpacklo_epi8(*m, m_inv);
  __m128i pred_l = _mm_maddubs_epi16(data_l, mask_l);
  pred_l = xx_roundn_epu16(pred_l, AOM_BLEND_A64_ROUND_BITS);

  const __m128i data_r = _mm_unpackhi_epi8(*a, *b);
  const __m128i mask_r = _mm_unpackhi_epi8(*m, m_inv);
  __m128i pred_r = _mm_maddubs_epi16(data_r, mask_r);
  pred_r = xx_roundn_epu16(pred_r, AOM_BLEND_A64_ROUND_BITS);

  const __m128i src_l = _mm_unpacklo_epi8(*src, zero);
  const __m128i src_r = _mm_unpackhi_epi8(*src, zero);
  const __m128i diff_l = _mm_sub_epi16(pred_l, src_l);
  const __m128i diff_r = _mm_sub_epi16(pred_r, src_r);

  // Update partial sums and partial sums of squares
  *sum =
      _mm_add_epi32(*sum, _mm_madd_epi16(_mm_add_epi16(diff_l, diff_r), one));
  *sum_sq =
      _mm_add_epi32(*sum_sq, _mm_add_epi32(_mm_madd_epi16(diff_l, diff_l),
                                           _mm_madd_epi16(diff_r, diff_r)));
}

static void masked_variance(const uint8_t *src_ptr, int src_stride,
                            const uint8_t *a_ptr, int a_stride,
                            const uint8_t *b_ptr, int b_stride,
                            const uint8_t *m_ptr, int m_stride, int width,
                            int height, unsigned int *sse, int *sum_) {
  int x, y;
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();

  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 16) {
      const __m128i src = _mm_loadu_si128((const __m128i *)&src_ptr[x]);
      const __m128i a = _mm_loadu_si128((const __m128i *)&a_ptr[x]);
      const __m128i b = _mm_loadu_si128((const __m128i *)&b_ptr[x]);
      const __m128i m = _mm_loadu_si128((const __m128i *)&m_ptr[x]);
      accumulate_block(&src, &a, &b, &m, &sum, &sum_sq);
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, sum);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}

static void masked_variance8xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_) {
  int y;
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();

  for (y = 0; y < height; y += 2) {
    __m128i src = _mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)src_ptr),
        _mm_loadl_epi64((const __m128i *)&src_ptr[src_stride]));
    const __m128i a = _mm_loadu_si128((const __m128i *)a_ptr);
    const __m128i b = _mm_loadu_si128((const __m128i *)b_ptr);
    const __m128i m =
        _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)m_ptr),
                           _mm_loadl_epi64((const __m128i *)&m_ptr[m_stride]));
    accumulate_block(&src, &a, &b, &m, &sum, &sum_sq);

    src_ptr += src_stride * 2;
    a_ptr += 16;
    b_ptr += 16;
    m_ptr += m_stride * 2;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, sum);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}

static void masked_variance4xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_) {
  int y;
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();

  for (y = 0; y < height; y += 4) {
    // Load four rows at a time
    __m128i src = _mm_setr_epi32(*(int *)src_ptr, *(int *)&src_ptr[src_stride],
                                 *(int *)&src_ptr[src_stride * 2],
                                 *(int *)&src_ptr[src_stride * 3]);
    const __m128i a = _mm_loadu_si128((const __m128i *)a_ptr);
    const __m128i b = _mm_loadu_si128((const __m128i *)b_ptr);
    const __m128i m = _mm_setr_epi32(*(int *)m_ptr, *(int *)&m_ptr[m_stride],
                                     *(int *)&m_ptr[m_stride * 2],
                                     *(int *)&m_ptr[m_stride * 3]);
    accumulate_block(&src, &a, &b, &m, &sum, &sum_sq);

    src_ptr += src_stride * 4;
    a_ptr += 16;
    b_ptr += 16;
    m_ptr += m_stride * 4;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, sum);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}

#if CONFIG_AV1_HIGHBITDEPTH
// For width a multiple of 8
static void highbd_bilinear_filter(const uint16_t *src, int src_stride,
                                   int xoffset, int yoffset, uint16_t *dst,
                                   int w, int h);

static void highbd_bilinear_filter4xh(const uint16_t *src, int src_stride,
                                      int xoffset, int yoffset, uint16_t *dst,
                                      int h);

// For width a multiple of 8
static void highbd_masked_variance(const uint16_t *src_ptr, int src_stride,
                                   const uint16_t *a_ptr, int a_stride,
                                   const uint16_t *b_ptr, int b_stride,
                                   const uint8_t *m_ptr, int m_stride,
                                   int width, int height, uint64_t *sse,
                                   int *sum_);

static void highbd_masked_variance4xh(const uint16_t *src_ptr, int src_stride,
                                      const uint16_t *a_ptr,
                                      const uint16_t *b_ptr,
                                      const uint8_t *m_ptr, int m_stride,
                                      int height, int *sse, int *sum_);

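// The three bit-depth variants below share the same computation and differ
// only in how the result is scaled back down: for 10-bit input, sse is scaled
// by 1/16 and sum by 1/4; for 12-bit input, by 1/256 and 1/16 respectively,
// with the variance clamped at zero in case rounding makes it negative.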
#define HIGHBD_MASK_SUBPIX_VAR_SSSE3(W, H)                                    \
  unsigned int aom_highbd_8_masked_sub_pixel_variance##W##x##H##_ssse3(       \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,          \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,       \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) {   \
    uint64_t sse64;                                                           \
    int sum;                                                                  \
    uint16_t temp[(H + 1) * W];                                               \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                          \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                          \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);          \
                                                                              \
    highbd_bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H);    \
                                                                              \
    if (!invert_mask)                                                         \
      highbd_masked_variance(ref, ref_stride, temp, W, second_pred, W, msk,   \
                             msk_stride, W, H, &sse64, &sum);                 \
    else                                                                      \
      highbd_masked_variance(ref, ref_stride, second_pred, W, temp, W, msk,   \
                             msk_stride, W, H, &sse64, &sum);                 \
    *sse = (uint32_t)sse64;                                                   \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (W * H));                 \
  }                                                                           \
  unsigned int aom_highbd_10_masked_sub_pixel_variance##W##x##H##_ssse3(      \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,          \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,       \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) {   \
    uint64_t sse64;                                                           \
    int sum;                                                                  \
    int64_t var;                                                              \
    uint16_t temp[(H + 1) * W];                                               \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                          \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                          \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);          \
                                                                              \
    highbd_bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H);    \
                                                                              \
    if (!invert_mask)                                                         \
      highbd_masked_variance(ref, ref_stride, temp, W, second_pred, W, msk,   \
                             msk_stride, W, H, &sse64, &sum);                 \
    else                                                                      \
      highbd_masked_variance(ref, ref_stride, second_pred, W, temp, W, msk,   \
                             msk_stride, W, H, &sse64, &sum);                 \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse64, 4);                            \
    sum = ROUND_POWER_OF_TWO(sum, 2);                                         \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (W * H));                 \
    return (var >= 0) ? (uint32_t)var : 0;                                    \
  }                                                                           \
  unsigned int aom_highbd_12_masked_sub_pixel_variance##W##x##H##_ssse3(      \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,          \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,       \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) {   \
    uint64_t sse64;                                                           \
    int sum;                                                                  \
    int64_t var;                                                              \
    uint16_t temp[(H + 1) * W];                                               \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                          \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                          \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);          \
                                                                              \
    highbd_bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H);    \
                                                                              \
    if (!invert_mask)                                                         \
      highbd_masked_variance(ref, ref_stride, temp, W, second_pred, W, msk,   \
                             msk_stride, W, H, &sse64, &sum);                 \
    else                                                                      \
      highbd_masked_variance(ref, ref_stride, second_pred, W, temp, W, msk,   \
                             msk_stride, W, H, &sse64, &sum);                 \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse64, 8);                            \
    sum = ROUND_POWER_OF_TWO(sum, 4);                                         \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (W * H));                 \
    return (var >= 0) ? (uint32_t)var : 0;                                    \
  }

#define HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(H)                                    \
  unsigned int aom_highbd_8_masked_sub_pixel_variance4x##H##_ssse3(           \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,          \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,       \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) {   \
    int sse_;                                                                 \
    int sum;                                                                  \
    uint16_t temp[(H + 1) * 4];                                               \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                          \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                          \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);          \
                                                                              \
    highbd_bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H);    \
                                                                              \
    if (!invert_mask)                                                         \
      highbd_masked_variance4xh(ref, ref_stride, temp, second_pred, msk,      \
                                msk_stride, H, &sse_, &sum);                  \
    else                                                                      \
      highbd_masked_variance4xh(ref, ref_stride, second_pred, temp, msk,      \
                                msk_stride, H, &sse_, &sum);                  \
    *sse = (uint32_t)sse_;                                                    \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (4 * H));                 \
  }                                                                           \
  unsigned int aom_highbd_10_masked_sub_pixel_variance4x##H##_ssse3(          \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,          \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,       \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) {   \
    int sse_;                                                                 \
    int sum;                                                                  \
    int64_t var;                                                              \
    uint16_t temp[(H + 1) * 4];                                               \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                          \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                          \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);          \
                                                                              \
    highbd_bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H);    \
                                                                              \
    if (!invert_mask)                                                         \
      highbd_masked_variance4xh(ref, ref_stride, temp, second_pred, msk,      \
                                msk_stride, H, &sse_, &sum);                  \
    else                                                                      \
      highbd_masked_variance4xh(ref, ref_stride, second_pred, temp, msk,      \
                                msk_stride, H, &sse_, &sum);                  \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_, 4);                             \
    sum = ROUND_POWER_OF_TWO(sum, 2);                                         \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (4 * H));                 \
    return (var >= 0) ? (uint32_t)var : 0;                                    \
  }                                                                           \
  unsigned int aom_highbd_12_masked_sub_pixel_variance4x##H##_ssse3(          \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,          \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,       \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) {   \
    int sse_;                                                                 \
    int sum;                                                                  \
    int64_t var;                                                              \
    uint16_t temp[(H + 1) * 4];                                               \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                          \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                          \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);          \
                                                                              \
    highbd_bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H);    \
                                                                              \
    if (!invert_mask)                                                         \
      highbd_masked_variance4xh(ref, ref_stride, temp, second_pred, msk,      \
                                msk_stride, H, &sse_, &sum);                  \
    else                                                                      \
      highbd_masked_variance4xh(ref, ref_stride, second_pred, temp, msk,      \
                                msk_stride, H, &sse_, &sum);                  \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_, 8);                             \
    sum = ROUND_POWER_OF_TWO(sum, 4);                                         \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (4 * H));                 \
    return (var >= 0) ? (uint32_t)var : 0;                                    \
  }

HIGHBD_MASK_SUBPIX_VAR_SSSE3(128, 128)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(128, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 128)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 8)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 8)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 4)
HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(8)
HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(4)
HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 4)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 8)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 16)

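// 16-bit (high bit-depth) counterpart of filter_block(): pixels are unpacked
// to 32-bit lanes, multiplied against the taps with _mm_madd_epi16, rounded
// by FILTER_BITS and packed back to 16 bits.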
static INLINE __m128i highbd_filter_block(const __m128i a, const __m128i b,
                                          const __m128i filter) {
  __m128i v0 = _mm_unpacklo_epi16(a, b);
  v0 = _mm_madd_epi16(v0, filter);
  v0 = xx_roundn_epu32(v0, FILTER_BITS);

  __m128i v1 = _mm_unpackhi_epi16(a, b);
  v1 = _mm_madd_epi16(v1, filter);
  v1 = xx_roundn_epu32(v1, FILTER_BITS);

  return _mm_packs_epi32(v0, v1);
}

static void highbd_bilinear_filter(const uint16_t *src, int src_stride,
                                   int xoffset, int yoffset, uint16_t *dst,
                                   int w, int h) {
  int i, j;
  // Horizontal filter
  if (xoffset == 0) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 8) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        _mm_storeu_si128((__m128i *)&b[j], x);
      }
      src += src_stride;
      b += w;
    }
  } else if (xoffset == 4) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 8) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&src[j + 8]);
        __m128i z = _mm_alignr_epi8(y, x, 2);
        _mm_storeu_si128((__m128i *)&b[j], _mm_avg_epu16(x, z));
      }
      src += src_stride;
      b += w;
    }
  } else {
    uint16_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi32(hfilter[0] | (hfilter[1] << 16));
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 8) {
        const __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&src[j + 8]);
        const __m128i z = _mm_alignr_epi8(y, x, 2);
        const __m128i res = highbd_filter_block(x, z, hfilter_vec);
        _mm_storeu_si128((__m128i *)&b[j], res);
      }

      src += src_stride;
      b += w;
    }
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 8) {
        __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        _mm_storeu_si128((__m128i *)&dst[j], _mm_avg_epu16(x, y));
      }
      dst += w;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi32(vfilter[0] | (vfilter[1] << 16));
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 8) {
        const __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        const __m128i res = highbd_filter_block(x, y, vfilter_vec);
        _mm_storeu_si128((__m128i *)&dst[j], res);
      }

      dst += w;
    }
  }
}

static INLINE __m128i highbd_filter_block_2rows(const __m128i *a0,
                                                const __m128i *b0,
                                                const __m128i *a1,
                                                const __m128i *b1,
                                                const __m128i *filter) {
  __m128i v0 = _mm_unpacklo_epi16(*a0, *b0);
  v0 = _mm_madd_epi16(v0, *filter);
  v0 = xx_roundn_epu32(v0, FILTER_BITS);

  __m128i v1 = _mm_unpacklo_epi16(*a1, *b1);
  v1 = _mm_madd_epi16(v1, *filter);
  v1 = xx_roundn_epu32(v1, FILTER_BITS);

  return _mm_packs_epi32(v0, v1);
}

static void highbd_bilinear_filter4xh(const uint16_t *src, int src_stride,
                                      int xoffset, int yoffset, uint16_t *dst,
                                      int h) {
  int i;
  // Horizontal filter
  if (xoffset == 0) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)src);
      _mm_storel_epi64((__m128i *)b, x);
      src += src_stride;
      b += 4;
    }
  } else if (xoffset == 4) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadu_si128((__m128i *)src);
      __m128i z = _mm_srli_si128(x, 2);
      _mm_storel_epi64((__m128i *)b, _mm_avg_epu16(x, z));
      src += src_stride;
      b += 4;
    }
  } else {
    uint16_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi32(hfilter[0] | (hfilter[1] << 16));
    for (i = 0; i < h; i += 2) {
      const __m128i x0 = _mm_loadu_si128((__m128i *)src);
      const __m128i z0 = _mm_srli_si128(x0, 2);
      const __m128i x1 = _mm_loadu_si128((__m128i *)&src[src_stride]);
      const __m128i z1 = _mm_srli_si128(x1, 2);
      const __m128i res =
          highbd_filter_block_2rows(&x0, &z0, &x1, &z1, &hfilter_vec);
      _mm_storeu_si128((__m128i *)b, res);

      src += src_stride * 2;
      b += 8;
    }
    // Process i = h separately
    __m128i x = _mm_loadu_si128((__m128i *)src);
    __m128i z = _mm_srli_si128(x, 2);

    __m128i v0 = _mm_unpacklo_epi16(x, z);
    v0 = _mm_madd_epi16(v0, hfilter_vec);
    v0 = xx_roundn_epu32(v0, FILTER_BITS);

    _mm_storel_epi64((__m128i *)b, _mm_packs_epi32(v0, v0));
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)dst);
      __m128i y = _mm_loadl_epi64((__m128i *)&dst[4]);
      _mm_storel_epi64((__m128i *)dst, _mm_avg_epu16(x, y));
      dst += 4;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi32(vfilter[0] | (vfilter[1] << 16));
    for (i = 0; i < h; i += 2) {
      const __m128i x = _mm_loadl_epi64((__m128i *)dst);
      const __m128i y = _mm_loadl_epi64((__m128i *)&dst[4]);
      const __m128i z = _mm_loadl_epi64((__m128i *)&dst[8]);
      const __m128i res =
          highbd_filter_block_2rows(&x, &y, &y, &z, &vfilter_vec);
      _mm_storeu_si128((__m128i *)dst, res);

      dst += 8;
    }
  }
}

static void highbd_masked_variance(const uint16_t *src_ptr, int src_stride,
                                   const uint16_t *a_ptr, int a_stride,
                                   const uint16_t *b_ptr, int b_stride,
                                   const uint8_t *m_ptr, int m_stride,
                                   int width, int height, uint64_t *sse,
                                   int *sum_) {
  int x, y;
  // Note on bit widths:
  // The maximum value of 'sum' is (2^12 - 1) * 128 * 128 =~ 2^26,
  // so this can be kept as four 32-bit values.
  // But the maximum value of 'sum_sq' is (2^12 - 1)^2 * 128 * 128 =~ 2^38,
  // so this must be stored as two 64-bit values.
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i round_const =
      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m128i zero = _mm_setzero_si128();

  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 8) {
      const __m128i src = _mm_loadu_si128((const __m128i *)&src_ptr[x]);
      const __m128i a = _mm_loadu_si128((const __m128i *)&a_ptr[x]);
      const __m128i b = _mm_loadu_si128((const __m128i *)&b_ptr[x]);
      const __m128i m =
          _mm_unpacklo_epi8(_mm_loadl_epi64((const __m128i *)&m_ptr[x]), zero);
      const __m128i m_inv = _mm_sub_epi16(mask_max, m);

      // Calculate 8 predicted pixels.
      const __m128i data_l = _mm_unpacklo_epi16(a, b);
      const __m128i mask_l = _mm_unpacklo_epi16(m, m_inv);
      __m128i pred_l = _mm_madd_epi16(data_l, mask_l);
      pred_l = _mm_srai_epi32(_mm_add_epi32(pred_l, round_const),
                              AOM_BLEND_A64_ROUND_BITS);

      const __m128i data_r = _mm_unpackhi_epi16(a, b);
      const __m128i mask_r = _mm_unpackhi_epi16(m, m_inv);
      __m128i pred_r = _mm_madd_epi16(data_r, mask_r);
      pred_r = _mm_srai_epi32(_mm_add_epi32(pred_r, round_const),
                              AOM_BLEND_A64_ROUND_BITS);

      const __m128i src_l = _mm_unpacklo_epi16(src, zero);
      const __m128i src_r = _mm_unpackhi_epi16(src, zero);
      __m128i diff_l = _mm_sub_epi32(pred_l, src_l);
      __m128i diff_r = _mm_sub_epi32(pred_r, src_r);

      // Update partial sums and partial sums of squares
      sum = _mm_add_epi32(sum, _mm_add_epi32(diff_l, diff_r));
      // A trick: Now each entry of diff_l and diff_r is stored in a 32-bit
      // field, but the range of values is only [-(2^12 - 1), 2^12 - 1].
      // So we can re-pack into 16-bit fields and use _mm_madd_epi16
      // to calculate the squares and partially sum them.
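      // (Each product is at most (2^12 - 1)^2 < 2^24, so the sum of two
      // products in each 32-bit lane of _mm_madd_epi16 cannot overflow.)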
      const __m128i tmp = _mm_packs_epi32(diff_l, diff_r);
      const __m128i prod = _mm_madd_epi16(tmp, tmp);
      // Then we want to sign-extend to 64 bits and accumulate
      const __m128i sign = _mm_srai_epi32(prod, 31);
      const __m128i tmp_0 = _mm_unpacklo_epi32(prod, sign);
      const __m128i tmp_1 = _mm_unpackhi_epi32(prod, sign);
      sum_sq = _mm_add_epi64(sum_sq, _mm_add_epi64(tmp_0, tmp_1));
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, zero);
  sum = _mm_hadd_epi32(sum, zero);
  *sum_ = _mm_cvtsi128_si32(sum);
  sum_sq = _mm_add_epi64(sum_sq, _mm_srli_si128(sum_sq, 8));
  _mm_storel_epi64((__m128i *)sse, sum_sq);
}

static void highbd_masked_variance4xh(const uint16_t *src_ptr, int src_stride,
                                      const uint16_t *a_ptr,
                                      const uint16_t *b_ptr,
                                      const uint8_t *m_ptr, int m_stride,
                                      int height, int *sse, int *sum_) {
  int y;
  // Note: For this function, h is at most 16 (the largest 4xH block handled
  // above is 4x16). So the maximum value of sum is (2^12 - 1) * 4 * 16 =~ 2^18
  // and the maximum value of sum_sq is (2^12 - 1)^2 * 4 * 16 =~ 2^30.
  // So we can safely pack sum_sq into 32-bit fields, which is slightly more
  // convenient.
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i round_const =
      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m128i zero = _mm_setzero_si128();

  for (y = 0; y < height; y += 2) {
    __m128i src = _mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)src_ptr),
        _mm_loadl_epi64((const __m128i *)&src_ptr[src_stride]));
    const __m128i a = _mm_loadu_si128((const __m128i *)a_ptr);
    const __m128i b = _mm_loadu_si128((const __m128i *)b_ptr);
    const __m128i m = _mm_unpacklo_epi8(
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(const int *)m_ptr),
                           _mm_cvtsi32_si128(*(const int *)&m_ptr[m_stride])),
        zero);
    const __m128i m_inv = _mm_sub_epi16(mask_max, m);

    const __m128i data_l = _mm_unpacklo_epi16(a, b);
    const __m128i mask_l = _mm_unpacklo_epi16(m, m_inv);
    __m128i pred_l = _mm_madd_epi16(data_l, mask_l);
    pred_l = _mm_srai_epi32(_mm_add_epi32(pred_l, round_const),
                            AOM_BLEND_A64_ROUND_BITS);

    const __m128i data_r = _mm_unpackhi_epi16(a, b);
    const __m128i mask_r = _mm_unpackhi_epi16(m, m_inv);
    __m128i pred_r = _mm_madd_epi16(data_r, mask_r);
    pred_r = _mm_srai_epi32(_mm_add_epi32(pred_r, round_const),
                            AOM_BLEND_A64_ROUND_BITS);

    const __m128i src_l = _mm_unpacklo_epi16(src, zero);
    const __m128i src_r = _mm_unpackhi_epi16(src, zero);
    __m128i diff_l = _mm_sub_epi32(pred_l, src_l);
    __m128i diff_r = _mm_sub_epi32(pred_r, src_r);

    // Update partial sums and partial sums of squares
    sum = _mm_add_epi32(sum, _mm_add_epi32(diff_l, diff_r));
    const __m128i tmp = _mm_packs_epi32(diff_l, diff_r);
    const __m128i prod = _mm_madd_epi16(tmp, tmp);
    sum_sq = _mm_add_epi32(sum_sq, prod);

    src_ptr += src_stride * 2;
    a_ptr += 8;
    b_ptr += 8;
    m_ptr += m_stride * 2;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, zero);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}
#endif  // CONFIG_AV1_HIGHBITDEPTH

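// Blend 'ref' and 'pred' into 'comp_pred' using the 64-weight 'mask'
// ('invert_mask' swaps which input the mask weights apply to), delegating to
// the comp_mask_pred_8/16 helpers for 8-, 16- and 32+-pixel-wide blocks.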
void aom_comp_mask_pred_ssse3(uint8_t *comp_pred, const uint8_t *pred,
                              int width, int height, const uint8_t *ref,
                              int ref_stride, const uint8_t *mask,
                              int mask_stride, int invert_mask) {
  const uint8_t *src0 = invert_mask ? pred : ref;
  const uint8_t *src1 = invert_mask ? ref : pred;
  const int stride0 = invert_mask ? width : ref_stride;
  const int stride1 = invert_mask ? ref_stride : width;
  assert(height % 2 == 0);
  int i = 0;
  if (width == 8) {
    comp_mask_pred_8_ssse3(comp_pred, height, src0, stride0, src1, stride1,
                           mask, mask_stride);
  } else if (width == 16) {
    do {
      comp_mask_pred_16_ssse3(src0, src1, mask, comp_pred);
      comp_mask_pred_16_ssse3(src0 + stride0, src1 + stride1,
                              mask + mask_stride, comp_pred + width);
      comp_pred += (width << 1);
      src0 += (stride0 << 1);
      src1 += (stride1 << 1);
      mask += (mask_stride << 1);
      i += 2;
    } while (i < height);
  } else {
    do {
      for (int x = 0; x < width; x += 32) {
        comp_mask_pred_16_ssse3(src0 + x, src1 + x, mask + x, comp_pred);
        comp_mask_pred_16_ssse3(src0 + x + 16, src1 + x + 16, mask + x + 16,
                                comp_pred + 16);
        comp_pred += 32;
      }
      src0 += (stride0);
      src1 += (stride1);
      mask += (mask_stride);
      i += 1;
    } while (i < height);
  }
}