/*
 * Copyright (c) 2017, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>  // for assert() in aom_comp_mask_pred_ssse3()
#include <stdlib.h>
#include <string.h>
#include <tmmintrin.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/aom_filter.h"
#include "aom_dsp/blend.h"
#include "aom_dsp/x86/masked_variance_intrin_ssse3.h"
#include "aom_dsp/x86/synonyms.h"
#include "aom_ports/mem.h"

// For width a multiple of 16
static void bilinear_filter(const uint8_t *src, int src_stride, int xoffset,
                            int yoffset, uint8_t *dst, int w, int h);

static void bilinear_filter8xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h);

static void bilinear_filter4xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h);

// For width a multiple of 16
static void masked_variance(const uint8_t *src_ptr, int src_stride,
                            const uint8_t *a_ptr, int a_stride,
                            const uint8_t *b_ptr, int b_stride,
                            const uint8_t *m_ptr, int m_stride, int width,
                            int height, unsigned int *sse, int *sum_);

static void masked_variance8xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_);

static void masked_variance4xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_);

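// Each masked sub-pixel variance kernel below follows the same pattern:
// bilinearly filter the source block into a (H + 1)-row temporary buffer,
// blend the filtered block with 'second_pred' according to the 6-bit mask
// (invert_mask swaps which of the two operands the mask weights), accumulate
// the sum and sum of squared differences against 'ref', and return
// variance = sse - sum * sum / (W * H).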
#define MASK_SUBPIX_VAR_SSSE3(W, H)                                   \
  unsigned int aom_masked_sub_pixel_variance##W##x##H##_ssse3(        \
      const uint8_t *src, int src_stride, int xoffset, int yoffset,   \
      const uint8_t *ref, int ref_stride, const uint8_t *second_pred, \
      const uint8_t *msk, int msk_stride, int invert_mask,            \
      unsigned int *sse) {                                            \
    int sum;                                                          \
    uint8_t temp[(H + 1) * W];                                        \
                                                                      \
    bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H);   \
                                                                      \
    if (!invert_mask)                                                 \
      masked_variance(ref, ref_stride, temp, W, second_pred, W, msk,  \
                      msk_stride, W, H, sse, &sum);                   \
    else                                                              \
      masked_variance(ref, ref_stride, second_pred, W, temp, W, msk,  \
                      msk_stride, W, H, sse, &sum);                   \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (W * H));         \
  }

#define MASK_SUBPIX_VAR8XH_SSSE3(H)                                           \
  unsigned int aom_masked_sub_pixel_variance8x##H##_ssse3(                    \
      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
      const uint8_t *ref, int ref_stride, const uint8_t *second_pred,         \
      const uint8_t *msk, int msk_stride, int invert_mask,                    \
      unsigned int *sse) {                                                    \
    int sum;                                                                  \
    uint8_t temp[(H + 1) * 8];                                                \
                                                                              \
    bilinear_filter8xh(src, src_stride, xoffset, yoffset, temp, H);           \
                                                                              \
    if (!invert_mask)                                                         \
      masked_variance8xh(ref, ref_stride, temp, second_pred, msk, msk_stride, \
                         H, sse, &sum);                                       \
    else                                                                      \
      masked_variance8xh(ref, ref_stride, second_pred, temp, msk, msk_stride, \
                         H, sse, &sum);                                       \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (8 * H));                 \
  }

#define MASK_SUBPIX_VAR4XH_SSSE3(H)                                           \
  unsigned int aom_masked_sub_pixel_variance4x##H##_ssse3(                    \
      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
      const uint8_t *ref, int ref_stride, const uint8_t *second_pred,         \
      const uint8_t *msk, int msk_stride, int invert_mask,                    \
      unsigned int *sse) {                                                    \
    int sum;                                                                  \
    uint8_t temp[(H + 1) * 4];                                                \
                                                                              \
    bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H);           \
                                                                              \
    if (!invert_mask)                                                         \
      masked_variance4xh(ref, ref_stride, temp, second_pred, msk, msk_stride, \
                         H, sse, &sum);                                       \
    else                                                                      \
      masked_variance4xh(ref, ref_stride, second_pred, temp, msk, msk_stride, \
                         H, sse, &sum);                                       \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (4 * H));                 \
  }

MASK_SUBPIX_VAR_SSSE3(128, 128)
MASK_SUBPIX_VAR_SSSE3(128, 64)
MASK_SUBPIX_VAR_SSSE3(64, 128)
MASK_SUBPIX_VAR_SSSE3(64, 64)
MASK_SUBPIX_VAR_SSSE3(64, 32)
MASK_SUBPIX_VAR_SSSE3(32, 64)
MASK_SUBPIX_VAR_SSSE3(32, 32)
MASK_SUBPIX_VAR_SSSE3(32, 16)
MASK_SUBPIX_VAR_SSSE3(16, 32)
MASK_SUBPIX_VAR_SSSE3(16, 16)
MASK_SUBPIX_VAR_SSSE3(16, 8)
MASK_SUBPIX_VAR8XH_SSSE3(16)
MASK_SUBPIX_VAR8XH_SSSE3(8)
MASK_SUBPIX_VAR8XH_SSSE3(4)
MASK_SUBPIX_VAR4XH_SSSE3(8)
MASK_SUBPIX_VAR4XH_SSSE3(4)
MASK_SUBPIX_VAR4XH_SSSE3(16)
MASK_SUBPIX_VAR_SSSE3(16, 4)
MASK_SUBPIX_VAR8XH_SSSE3(32)
MASK_SUBPIX_VAR_SSSE3(32, 8)
MASK_SUBPIX_VAR_SSSE3(64, 16)
MASK_SUBPIX_VAR_SSSE3(16, 64)

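// Apply the 2-tap filter 'filter' to 16 pairs of neighbouring pixels, where
// 'a' holds the left/top pixel of each pair and 'b' the right/bottom one.
// The pairs are interleaved byte-wise so that _mm_maddubs_epi16 computes
// a * filter[0] + b * filter[1] per pixel; the result is rounded by
// FILTER_BITS and packed back down to 8 bits.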
static INLINE __m128i filter_block(const __m128i a, const __m128i b,
                                   const __m128i filter) {
  __m128i v0 = _mm_unpacklo_epi8(a, b);
  v0 = _mm_maddubs_epi16(v0, filter);
  v0 = xx_roundn_epu16(v0, FILTER_BITS);

  __m128i v1 = _mm_unpackhi_epi8(a, b);
  v1 = _mm_maddubs_epi16(v1, filter);
  v1 = xx_roundn_epu16(v1, FILTER_BITS);

  return _mm_packus_epi16(v0, v1);
}

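// Two-pass bilinear filter for blocks whose width is a multiple of 16.
// The horizontal pass writes (h + 1) filtered rows into 'dst'; the vertical
// pass then filters adjacent rows in place, leaving h rows of output.
// xoffset/yoffset index bilinear_filters_2t: offset 0 needs no filtering
// (plain copy) and offset 4 is an equal-weight average, so both get
// dedicated fast paths below.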
static void bilinear_filter(const uint8_t *src, int src_stride, int xoffset,
                            int yoffset, uint8_t *dst, int w, int h) {
  int i, j;
  // Horizontal filter
  if (xoffset == 0) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 16) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        _mm_storeu_si128((__m128i *)&b[j], x);
      }
      src += src_stride;
      b += w;
    }
  } else if (xoffset == 4) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 16) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&src[j + 16]);
        __m128i z = _mm_alignr_epi8(y, x, 1);
        _mm_storeu_si128((__m128i *)&b[j], _mm_avg_epu8(x, z));
      }
      src += src_stride;
      b += w;
    }
  } else {
    uint8_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi16(hfilter[0] | (hfilter[1] << 8));
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 16) {
        const __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&src[j + 16]);
        const __m128i z = _mm_alignr_epi8(y, x, 1);
        const __m128i res = filter_block(x, z, hfilter_vec);
        _mm_storeu_si128((__m128i *)&b[j], res);
      }

      src += src_stride;
      b += w;
    }
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 16) {
        __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        _mm_storeu_si128((__m128i *)&dst[j], _mm_avg_epu8(x, y));
      }
      dst += w;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi16(vfilter[0] | (vfilter[1] << 8));
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 16) {
        const __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        const __m128i res = filter_block(x, y, vfilter_vec);
        _mm_storeu_si128((__m128i *)&dst[j], res);
      }

      dst += w;
    }
  }
}

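// Same 2-tap filter as filter_block(), but fed with the low 8 pixels of each
// input, so a full 16-byte result can still be produced for blocks narrower
// than 16 pixels: the 8-wide path filters two rows per call and the 4-wide
// path filters four rows per call.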
static INLINE __m128i filter_block_2rows(const __m128i *a0, const __m128i *b0,
                                         const __m128i *a1, const __m128i *b1,
                                         const __m128i *filter) {
  __m128i v0 = _mm_unpacklo_epi8(*a0, *b0);
  v0 = _mm_maddubs_epi16(v0, *filter);
  v0 = xx_roundn_epu16(v0, FILTER_BITS);

  __m128i v1 = _mm_unpacklo_epi8(*a1, *b1);
  v1 = _mm_maddubs_epi16(v1, *filter);
  v1 = xx_roundn_epu16(v1, FILTER_BITS);

  return _mm_packus_epi16(v0, v1);
}

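// 8-wide variant of bilinear_filter(). The horizontal pass must produce
// h + 1 rows, and its general case filters two rows per iteration, so the
// final (odd) row is handled separately after the loop. The vertical pass
// reads pairs of adjacent 8-byte rows from 'dst' and filters them in place.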
static void bilinear_filter8xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h) {
  int i;
  // Horizontal filter
  if (xoffset == 0) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)src);
      _mm_storel_epi64((__m128i *)b, x);
      src += src_stride;
      b += 8;
    }
  } else if (xoffset == 4) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadu_si128((__m128i *)src);
      __m128i z = _mm_srli_si128(x, 1);
      _mm_storel_epi64((__m128i *)b, _mm_avg_epu8(x, z));
      src += src_stride;
      b += 8;
    }
  } else {
    uint8_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi16(hfilter[0] | (hfilter[1] << 8));
    for (i = 0; i < h; i += 2) {
      const __m128i x0 = _mm_loadu_si128((__m128i *)src);
      const __m128i z0 = _mm_srli_si128(x0, 1);
      const __m128i x1 = _mm_loadu_si128((__m128i *)&src[src_stride]);
      const __m128i z1 = _mm_srli_si128(x1, 1);
      const __m128i res = filter_block_2rows(&x0, &z0, &x1, &z1, &hfilter_vec);
      _mm_storeu_si128((__m128i *)b, res);

      src += src_stride * 2;
      b += 16;
    }
    // Handle i = h separately
    const __m128i x0 = _mm_loadu_si128((__m128i *)src);
    const __m128i z0 = _mm_srli_si128(x0, 1);

    __m128i v0 = _mm_unpacklo_epi8(x0, z0);
    v0 = _mm_maddubs_epi16(v0, hfilter_vec);
    v0 = xx_roundn_epu16(v0, FILTER_BITS);

    _mm_storel_epi64((__m128i *)b, _mm_packus_epi16(v0, v0));
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)dst);
      __m128i y = _mm_loadl_epi64((__m128i *)&dst[8]);
      _mm_storel_epi64((__m128i *)dst, _mm_avg_epu8(x, y));
      dst += 8;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi16(vfilter[0] | (vfilter[1] << 8));
    for (i = 0; i < h; i += 2) {
      const __m128i x = _mm_loadl_epi64((__m128i *)dst);
      const __m128i y = _mm_loadl_epi64((__m128i *)&dst[8]);
      const __m128i z = _mm_loadl_epi64((__m128i *)&dst[16]);
      const __m128i res = filter_block_2rows(&x, &y, &y, &z, &vfilter_vec);
      _mm_storeu_si128((__m128i *)dst, res);

      dst += 16;
    }
  }
}

static void bilinear_filter4xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h) {
  int i;
  // Horizontal filter
  if (xoffset == 0) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = xx_loadl_32((__m128i *)src);
      xx_storel_32(b, x);
      src += src_stride;
      b += 4;
    }
  } else if (xoffset == 4) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)src);
      __m128i z = _mm_srli_si128(x, 1);
      xx_storel_32(b, _mm_avg_epu8(x, z));
      src += src_stride;
      b += 4;
    }
  } else {
    uint8_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi16(hfilter[0] | (hfilter[1] << 8));
    for (i = 0; i < h; i += 4) {
      const __m128i x0 = _mm_loadl_epi64((__m128i *)src);
      const __m128i z0 = _mm_srli_si128(x0, 1);
      const __m128i x1 = _mm_loadl_epi64((__m128i *)&src[src_stride]);
      const __m128i z1 = _mm_srli_si128(x1, 1);
      const __m128i x2 = _mm_loadl_epi64((__m128i *)&src[src_stride * 2]);
      const __m128i z2 = _mm_srli_si128(x2, 1);
      const __m128i x3 = _mm_loadl_epi64((__m128i *)&src[src_stride * 3]);
      const __m128i z3 = _mm_srli_si128(x3, 1);

      const __m128i a0 = _mm_unpacklo_epi32(x0, x1);
      const __m128i b0 = _mm_unpacklo_epi32(z0, z1);
      const __m128i a1 = _mm_unpacklo_epi32(x2, x3);
      const __m128i b1 = _mm_unpacklo_epi32(z2, z3);
      const __m128i res = filter_block_2rows(&a0, &b0, &a1, &b1, &hfilter_vec);
      _mm_storeu_si128((__m128i *)b, res);

      src += src_stride * 4;
      b += 16;
    }
    // Handle i = h separately
    const __m128i x = _mm_loadl_epi64((__m128i *)src);
    const __m128i z = _mm_srli_si128(x, 1);

    __m128i v0 = _mm_unpacklo_epi8(x, z);
    v0 = _mm_maddubs_epi16(v0, hfilter_vec);
    v0 = xx_roundn_epu16(v0, FILTER_BITS);

    xx_storel_32(b, _mm_packus_epi16(v0, v0));
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      __m128i x = xx_loadl_32((__m128i *)dst);
      __m128i y = xx_loadl_32((__m128i *)&dst[4]);
      xx_storel_32(dst, _mm_avg_epu8(x, y));
      dst += 4;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi16(vfilter[0] | (vfilter[1] << 8));
    for (i = 0; i < h; i += 4) {
      const __m128i a = xx_loadl_32((__m128i *)dst);
      const __m128i b = xx_loadl_32((__m128i *)&dst[4]);
      const __m128i c = xx_loadl_32((__m128i *)&dst[8]);
      const __m128i d = xx_loadl_32((__m128i *)&dst[12]);
      const __m128i e = xx_loadl_32((__m128i *)&dst[16]);

      const __m128i a0 = _mm_unpacklo_epi32(a, b);
      const __m128i b0 = _mm_unpacklo_epi32(b, c);
      const __m128i a1 = _mm_unpacklo_epi32(c, d);
      const __m128i b1 = _mm_unpacklo_epi32(d, e);
      const __m128i res = filter_block_2rows(&a0, &b0, &a1, &b1, &vfilter_vec);
      _mm_storeu_si128((__m128i *)dst, res);

      dst += 16;
    }
  }
}

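// Accumulate the masked differences for 16 pixels. The mask 'm' and its
// complement (64 - m) are interleaved with the two predictors 'a' and 'b'
// so that a single _mm_maddubs_epi16 produces a * m + b * (64 - m) per
// pixel, i.e. the AOM_BLEND_A64 blend before rounding. The signed
// differences against 'src' are then folded into 32-bit partial sums of
// values and of squares.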
static INLINE void accumulate_block(const __m128i *src, const __m128i *a,
                                    const __m128i *b, const __m128i *m,
                                    __m128i *sum, __m128i *sum_sq) {
  const __m128i zero = _mm_setzero_si128();
  const __m128i one = _mm_set1_epi16(1);
  const __m128i mask_max = _mm_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i m_inv = _mm_sub_epi8(mask_max, *m);

  // Calculate 16 predicted pixels.
  // Note that the maximum value of any entry of 'pred_l' or 'pred_r'
  // is 64 * 255, so we have plenty of space to add rounding constants.
  const __m128i data_l = _mm_unpacklo_epi8(*a, *b);
  const __m128i mask_l = _mm_unpacklo_epi8(*m, m_inv);
  __m128i pred_l = _mm_maddubs_epi16(data_l, mask_l);
  pred_l = xx_roundn_epu16(pred_l, AOM_BLEND_A64_ROUND_BITS);

  const __m128i data_r = _mm_unpackhi_epi8(*a, *b);
  const __m128i mask_r = _mm_unpackhi_epi8(*m, m_inv);
  __m128i pred_r = _mm_maddubs_epi16(data_r, mask_r);
  pred_r = xx_roundn_epu16(pred_r, AOM_BLEND_A64_ROUND_BITS);

  const __m128i src_l = _mm_unpacklo_epi8(*src, zero);
  const __m128i src_r = _mm_unpackhi_epi8(*src, zero);
  const __m128i diff_l = _mm_sub_epi16(pred_l, src_l);
  const __m128i diff_r = _mm_sub_epi16(pred_r, src_r);

  // Update partial sums and partial sums of squares
  *sum =
      _mm_add_epi32(*sum, _mm_madd_epi16(_mm_add_epi16(diff_l, diff_r), one));
  *sum_sq =
      _mm_add_epi32(*sum_sq, _mm_add_epi32(_mm_madd_epi16(diff_l, diff_l),
                                           _mm_madd_epi16(diff_r, diff_r)));
}

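// Masked variance for blocks whose width is a multiple of 16. Each row is
// processed 16 pixels at a time; the final pair of horizontal adds folds
// the four partial sums and the four partial sums of squares together so
// that lane 0 of the result holds 'sum' and lane 1 holds 'sse'.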
static void masked_variance(const uint8_t *src_ptr, int src_stride,
                            const uint8_t *a_ptr, int a_stride,
                            const uint8_t *b_ptr, int b_stride,
                            const uint8_t *m_ptr, int m_stride, int width,
                            int height, unsigned int *sse, int *sum_) {
  int x, y;
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();

  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 16) {
      const __m128i src = _mm_loadu_si128((const __m128i *)&src_ptr[x]);
      const __m128i a = _mm_loadu_si128((const __m128i *)&a_ptr[x]);
      const __m128i b = _mm_loadu_si128((const __m128i *)&b_ptr[x]);
      const __m128i m = _mm_loadu_si128((const __m128i *)&m_ptr[x]);
      accumulate_block(&src, &a, &b, &m, &sum, &sum_sq);
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, sum);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}

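// 8-wide variant: two rows of source and mask are packed into a single
// 16-byte vector per iteration so accumulate_block() still works on 16
// pixels at a time. The 'a' and 'b' buffers are contiguous 8-wide blocks,
// so they simply advance by 16 bytes per pair of rows.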
static void masked_variance8xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_) {
  int y;
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();

  for (y = 0; y < height; y += 2) {
    __m128i src = _mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)src_ptr),
        _mm_loadl_epi64((const __m128i *)&src_ptr[src_stride]));
    const __m128i a = _mm_loadu_si128((const __m128i *)a_ptr);
    const __m128i b = _mm_loadu_si128((const __m128i *)b_ptr);
    const __m128i m =
        _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)m_ptr),
                           _mm_loadl_epi64((const __m128i *)&m_ptr[m_stride]));
    accumulate_block(&src, &a, &b, &m, &sum, &sum_sq);

    src_ptr += src_stride * 2;
    a_ptr += 16;
    b_ptr += 16;
    m_ptr += m_stride * 2;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, sum);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}

static void masked_variance4xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_) {
  int y;
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();

  for (y = 0; y < height; y += 4) {
    // Load four rows at a time
    __m128i src = _mm_setr_epi32(*(int *)src_ptr, *(int *)&src_ptr[src_stride],
                                 *(int *)&src_ptr[src_stride * 2],
                                 *(int *)&src_ptr[src_stride * 3]);
    const __m128i a = _mm_loadu_si128((const __m128i *)a_ptr);
    const __m128i b = _mm_loadu_si128((const __m128i *)b_ptr);
    const __m128i m = _mm_setr_epi32(*(int *)m_ptr, *(int *)&m_ptr[m_stride],
                                     *(int *)&m_ptr[m_stride * 2],
                                     *(int *)&m_ptr[m_stride * 3]);
    accumulate_block(&src, &a, &b, &m, &sum, &sum_sq);

    src_ptr += src_stride * 4;
    a_ptr += 16;
    b_ptr += 16;
    m_ptr += m_stride * 4;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, sum);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}

#if CONFIG_AV1_HIGHBITDEPTH
// For width a multiple of 8
static void highbd_bilinear_filter(const uint16_t *src, int src_stride,
                                   int xoffset, int yoffset, uint16_t *dst,
                                   int w, int h);

static void highbd_bilinear_filter4xh(const uint16_t *src, int src_stride,
                                      int xoffset, int yoffset, uint16_t *dst,
                                      int h);

// For width a multiple of 8
static void highbd_masked_variance(const uint16_t *src_ptr, int src_stride,
                                   const uint16_t *a_ptr, int a_stride,
                                   const uint16_t *b_ptr, int b_stride,
                                   const uint8_t *m_ptr, int m_stride,
                                   int width, int height, uint64_t *sse,
                                   int *sum_);

static void highbd_masked_variance4xh(const uint16_t *src_ptr, int src_stride,
                                      const uint16_t *a_ptr,
                                      const uint16_t *b_ptr,
                                      const uint8_t *m_ptr, int m_stride,
                                      int height, int *sse, int *sum_);

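// The high-bitdepth kernels mirror the 8-bit ones, but the 10- and 12-bit
// variants normalize the accumulated statistics back towards 8-bit range
// before computing the variance: 'sse' is rounded down by 4 (10-bit) or 8
// (12-bit) bits and 'sum' by 2 or 4 bits, and a negative result is clamped
// to zero.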
#define HIGHBD_MASK_SUBPIX_VAR_SSSE3(W, H)                                  \
  unsigned int aom_highbd_8_masked_sub_pixel_variance##W##x##H##_ssse3(     \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    uint64_t sse64;                                                         \
    int sum;                                                                \
    uint16_t temp[(H + 1) * W];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance(ref, ref_stride, temp, W, second_pred, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    else                                                                    \
      highbd_masked_variance(ref, ref_stride, second_pred, W, temp, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    *sse = (uint32_t)sse64;                                                 \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (W * H));               \
  }                                                                         \
  unsigned int aom_highbd_10_masked_sub_pixel_variance##W##x##H##_ssse3(    \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    uint64_t sse64;                                                         \
    int sum;                                                                \
    int64_t var;                                                            \
    uint16_t temp[(H + 1) * W];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance(ref, ref_stride, temp, W, second_pred, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    else                                                                    \
      highbd_masked_variance(ref, ref_stride, second_pred, W, temp, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse64, 4);                          \
    sum = ROUND_POWER_OF_TWO(sum, 2);                                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (W * H));               \
    return (var >= 0) ? (uint32_t)var : 0;                                  \
  }                                                                         \
  unsigned int aom_highbd_12_masked_sub_pixel_variance##W##x##H##_ssse3(    \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    uint64_t sse64;                                                         \
    int sum;                                                                \
    int64_t var;                                                            \
    uint16_t temp[(H + 1) * W];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance(ref, ref_stride, temp, W, second_pred, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    else                                                                    \
      highbd_masked_variance(ref, ref_stride, second_pred, W, temp, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse64, 8);                          \
    sum = ROUND_POWER_OF_TWO(sum, 4);                                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (W * H));               \
    return (var >= 0) ? (uint32_t)var : 0;                                  \
  }

#define HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(H)                                  \
  unsigned int aom_highbd_8_masked_sub_pixel_variance4x##H##_ssse3(         \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    int sse_;                                                               \
    int sum;                                                                \
    uint16_t temp[(H + 1) * 4];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance4xh(ref, ref_stride, temp, second_pred, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    else                                                                    \
      highbd_masked_variance4xh(ref, ref_stride, second_pred, temp, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    *sse = (uint32_t)sse_;                                                  \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (4 * H));               \
  }                                                                         \
  unsigned int aom_highbd_10_masked_sub_pixel_variance4x##H##_ssse3(        \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    int sse_;                                                               \
    int sum;                                                                \
    int64_t var;                                                            \
    uint16_t temp[(H + 1) * 4];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance4xh(ref, ref_stride, temp, second_pred, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    else                                                                    \
      highbd_masked_variance4xh(ref, ref_stride, second_pred, temp, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_, 4);                           \
    sum = ROUND_POWER_OF_TWO(sum, 2);                                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (4 * H));               \
    return (var >= 0) ? (uint32_t)var : 0;                                  \
  }                                                                         \
  unsigned int aom_highbd_12_masked_sub_pixel_variance4x##H##_ssse3(        \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    int sse_;                                                               \
    int sum;                                                                \
    int64_t var;                                                            \
    uint16_t temp[(H + 1) * 4];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance4xh(ref, ref_stride, temp, second_pred, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    else                                                                    \
      highbd_masked_variance4xh(ref, ref_stride, second_pred, temp, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_, 8);                           \
    sum = ROUND_POWER_OF_TWO(sum, 4);                                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (4 * H));               \
    return (var >= 0) ? (uint32_t)var : 0;                                  \
  }

HIGHBD_MASK_SUBPIX_VAR_SSSE3(128, 128)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(128, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 128)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 8)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 8)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 4)
HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(8)
HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(4)
HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 4)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 8)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 16)

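// 16-bit counterpart of filter_block(): pixels and filter taps are
// interleaved as 16-bit values so _mm_madd_epi16 forms
// a * filter[0] + b * filter[1] in 32 bits, which is rounded by FILTER_BITS
// and packed back down to 16-bit pixels.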
static INLINE __m128i highbd_filter_block(const __m128i a, const __m128i b,
                                          const __m128i filter) {
  __m128i v0 = _mm_unpacklo_epi16(a, b);
  v0 = _mm_madd_epi16(v0, filter);
  v0 = xx_roundn_epu32(v0, FILTER_BITS);

  __m128i v1 = _mm_unpackhi_epi16(a, b);
  v1 = _mm_madd_epi16(v1, filter);
  v1 = xx_roundn_epu32(v1, FILTER_BITS);

  return _mm_packs_epi32(v0, v1);
}

static void highbd_bilinear_filter(const uint16_t *src, int src_stride,
                                   int xoffset, int yoffset, uint16_t *dst,
                                   int w, int h) {
  int i, j;
  // Horizontal filter
  if (xoffset == 0) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 8) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        _mm_storeu_si128((__m128i *)&b[j], x);
      }
      src += src_stride;
      b += w;
    }
  } else if (xoffset == 4) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 8) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&src[j + 8]);
        __m128i z = _mm_alignr_epi8(y, x, 2);
        _mm_storeu_si128((__m128i *)&b[j], _mm_avg_epu16(x, z));
      }
      src += src_stride;
      b += w;
    }
  } else {
    uint16_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi32(hfilter[0] | (hfilter[1] << 16));
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 8) {
        const __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&src[j + 8]);
        const __m128i z = _mm_alignr_epi8(y, x, 2);
        const __m128i res = highbd_filter_block(x, z, hfilter_vec);
        _mm_storeu_si128((__m128i *)&b[j], res);
      }

      src += src_stride;
      b += w;
    }
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 8) {
        __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        _mm_storeu_si128((__m128i *)&dst[j], _mm_avg_epu16(x, y));
      }
      dst += w;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi32(vfilter[0] | (vfilter[1] << 16));
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 8) {
        const __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        const __m128i res = highbd_filter_block(x, y, vfilter_vec);
        _mm_storeu_si128((__m128i *)&dst[j], res);
      }

      dst += w;
    }
  }
}

static INLINE __m128i highbd_filter_block_2rows(const __m128i *a0,
                                                const __m128i *b0,
                                                const __m128i *a1,
                                                const __m128i *b1,
                                                const __m128i *filter) {
  __m128i v0 = _mm_unpacklo_epi16(*a0, *b0);
  v0 = _mm_madd_epi16(v0, *filter);
  v0 = xx_roundn_epu32(v0, FILTER_BITS);

  __m128i v1 = _mm_unpacklo_epi16(*a1, *b1);
  v1 = _mm_madd_epi16(v1, *filter);
  v1 = xx_roundn_epu32(v1, FILTER_BITS);

  return _mm_packs_epi32(v0, v1);
}

static void highbd_bilinear_filter4xh(const uint16_t *src, int src_stride,
                                      int xoffset, int yoffset, uint16_t *dst,
                                      int h) {
  int i;
  // Horizontal filter
  if (xoffset == 0) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)src);
      _mm_storel_epi64((__m128i *)b, x);
      src += src_stride;
      b += 4;
    }
  } else if (xoffset == 4) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadu_si128((__m128i *)src);
      __m128i z = _mm_srli_si128(x, 2);
      _mm_storel_epi64((__m128i *)b, _mm_avg_epu16(x, z));
      src += src_stride;
      b += 4;
    }
  } else {
    uint16_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi32(hfilter[0] | (hfilter[1] << 16));
    for (i = 0; i < h; i += 2) {
      const __m128i x0 = _mm_loadu_si128((__m128i *)src);
      const __m128i z0 = _mm_srli_si128(x0, 2);
      const __m128i x1 = _mm_loadu_si128((__m128i *)&src[src_stride]);
      const __m128i z1 = _mm_srli_si128(x1, 2);
      const __m128i res =
          highbd_filter_block_2rows(&x0, &z0, &x1, &z1, &hfilter_vec);
      _mm_storeu_si128((__m128i *)b, res);

      src += src_stride * 2;
      b += 8;
    }
    // Process i = h separately
    __m128i x = _mm_loadu_si128((__m128i *)src);
    __m128i z = _mm_srli_si128(x, 2);

    __m128i v0 = _mm_unpacklo_epi16(x, z);
    v0 = _mm_madd_epi16(v0, hfilter_vec);
    v0 = xx_roundn_epu32(v0, FILTER_BITS);

    _mm_storel_epi64((__m128i *)b, _mm_packs_epi32(v0, v0));
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)dst);
      __m128i y = _mm_loadl_epi64((__m128i *)&dst[4]);
      _mm_storel_epi64((__m128i *)dst, _mm_avg_epu16(x, y));
      dst += 4;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi32(vfilter[0] | (vfilter[1] << 16));
    for (i = 0; i < h; i += 2) {
      const __m128i x = _mm_loadl_epi64((__m128i *)dst);
      const __m128i y = _mm_loadl_epi64((__m128i *)&dst[4]);
      const __m128i z = _mm_loadl_epi64((__m128i *)&dst[8]);
      const __m128i res =
          highbd_filter_block_2rows(&x, &y, &y, &z, &vfilter_vec);
      _mm_storeu_si128((__m128i *)dst, res);

      dst += 8;
    }
  }
}

static void highbd_masked_variance(const uint16_t *src_ptr, int src_stride,
                                   const uint16_t *a_ptr, int a_stride,
                                   const uint16_t *b_ptr, int b_stride,
                                   const uint8_t *m_ptr, int m_stride,
                                   int width, int height, uint64_t *sse,
                                   int *sum_) {
  int x, y;
  // Note on bit widths:
  // The maximum value of 'sum' is (2^12 - 1) * 128 * 128 =~ 2^26,
  // so this can be kept as four 32-bit values.
  // But the maximum value of 'sum_sq' is (2^12 - 1)^2 * 128 * 128 =~ 2^38,
  // so this must be stored as two 64-bit values.
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i round_const =
      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m128i zero = _mm_setzero_si128();

  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 8) {
      const __m128i src = _mm_loadu_si128((const __m128i *)&src_ptr[x]);
      const __m128i a = _mm_loadu_si128((const __m128i *)&a_ptr[x]);
      const __m128i b = _mm_loadu_si128((const __m128i *)&b_ptr[x]);
      const __m128i m =
          _mm_unpacklo_epi8(_mm_loadl_epi64((const __m128i *)&m_ptr[x]), zero);
      const __m128i m_inv = _mm_sub_epi16(mask_max, m);

      // Calculate 8 predicted pixels.
      const __m128i data_l = _mm_unpacklo_epi16(a, b);
      const __m128i mask_l = _mm_unpacklo_epi16(m, m_inv);
      __m128i pred_l = _mm_madd_epi16(data_l, mask_l);
      pred_l = _mm_srai_epi32(_mm_add_epi32(pred_l, round_const),
                              AOM_BLEND_A64_ROUND_BITS);

      const __m128i data_r = _mm_unpackhi_epi16(a, b);
      const __m128i mask_r = _mm_unpackhi_epi16(m, m_inv);
      __m128i pred_r = _mm_madd_epi16(data_r, mask_r);
      pred_r = _mm_srai_epi32(_mm_add_epi32(pred_r, round_const),
                              AOM_BLEND_A64_ROUND_BITS);

      const __m128i src_l = _mm_unpacklo_epi16(src, zero);
      const __m128i src_r = _mm_unpackhi_epi16(src, zero);
      __m128i diff_l = _mm_sub_epi32(pred_l, src_l);
      __m128i diff_r = _mm_sub_epi32(pred_r, src_r);

      // Update partial sums and partial sums of squares
      sum = _mm_add_epi32(sum, _mm_add_epi32(diff_l, diff_r));
      // A trick: Now each entry of diff_l and diff_r is stored in a 32-bit
      // field, but the range of values is only [-(2^12 - 1), 2^12 - 1].
      // So we can re-pack into 16-bit fields and use _mm_madd_epi16
      // to calculate the squares and partially sum them.
      const __m128i tmp = _mm_packs_epi32(diff_l, diff_r);
      const __m128i prod = _mm_madd_epi16(tmp, tmp);
      // Then we want to sign-extend to 64 bits and accumulate
      const __m128i sign = _mm_srai_epi32(prod, 31);
      const __m128i tmp_0 = _mm_unpacklo_epi32(prod, sign);
      const __m128i tmp_1 = _mm_unpackhi_epi32(prod, sign);
      sum_sq = _mm_add_epi64(sum_sq, _mm_add_epi64(tmp_0, tmp_1));
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, zero);
  sum = _mm_hadd_epi32(sum, zero);
  *sum_ = _mm_cvtsi128_si32(sum);
  sum_sq = _mm_add_epi64(sum_sq, _mm_srli_si128(sum_sq, 8));
  _mm_storel_epi64((__m128i *)sse, sum_sq);
}

static void highbd_masked_variance4xh(const uint16_t *src_ptr, int src_stride,
                                      const uint16_t *a_ptr,
                                      const uint16_t *b_ptr,
                                      const uint8_t *m_ptr, int m_stride,
                                      int height, int *sse, int *sum_) {
  int y;
  // Note: For this function, h <= 8 (or maybe 16 if we add 4:1 partitions).
  // So the maximum value of sum is (2^12 - 1) * 4 * 16 =~ 2^18
  // and the maximum value of sum_sq is (2^12 - 1)^2 * 4 * 16 =~ 2^30.
  // So we can safely pack sum_sq into 32-bit fields, which is slightly more
  // convenient.
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i round_const =
      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m128i zero = _mm_setzero_si128();

  for (y = 0; y < height; y += 2) {
    __m128i src = _mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)src_ptr),
        _mm_loadl_epi64((const __m128i *)&src_ptr[src_stride]));
    const __m128i a = _mm_loadu_si128((const __m128i *)a_ptr);
    const __m128i b = _mm_loadu_si128((const __m128i *)b_ptr);
    const __m128i m = _mm_unpacklo_epi8(
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(const int *)m_ptr),
                           _mm_cvtsi32_si128(*(const int *)&m_ptr[m_stride])),
        zero);
    const __m128i m_inv = _mm_sub_epi16(mask_max, m);

    const __m128i data_l = _mm_unpacklo_epi16(a, b);
    const __m128i mask_l = _mm_unpacklo_epi16(m, m_inv);
    __m128i pred_l = _mm_madd_epi16(data_l, mask_l);
    pred_l = _mm_srai_epi32(_mm_add_epi32(pred_l, round_const),
                            AOM_BLEND_A64_ROUND_BITS);

    const __m128i data_r = _mm_unpackhi_epi16(a, b);
    const __m128i mask_r = _mm_unpackhi_epi16(m, m_inv);
    __m128i pred_r = _mm_madd_epi16(data_r, mask_r);
    pred_r = _mm_srai_epi32(_mm_add_epi32(pred_r, round_const),
                            AOM_BLEND_A64_ROUND_BITS);

    const __m128i src_l = _mm_unpacklo_epi16(src, zero);
    const __m128i src_r = _mm_unpackhi_epi16(src, zero);
    __m128i diff_l = _mm_sub_epi32(pred_l, src_l);
    __m128i diff_r = _mm_sub_epi32(pred_r, src_r);

    // Update partial sums and partial sums of squares
    sum = _mm_add_epi32(sum, _mm_add_epi32(diff_l, diff_r));
    const __m128i tmp = _mm_packs_epi32(diff_l, diff_r);
    const __m128i prod = _mm_madd_epi16(tmp, tmp);
    sum_sq = _mm_add_epi32(sum_sq, prod);

    src_ptr += src_stride * 2;
    a_ptr += 8;
    b_ptr += 8;
    m_ptr += m_stride * 2;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, zero);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}
#endif  // CONFIG_AV1_HIGHBITDEPTH

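// Build the masked compound predictor: 'comp_pred' is the A64 blend of 'ref'
// and 'pred' under the 6-bit mask. When invert_mask is set, the operands and
// their strides are swapped up front, so the comp_mask_pred_8/16 helpers
// (declared in masked_variance_intrin_ssse3.h) can always treat 'src0' as
// the mask-weighted operand and 'src1' as its complement.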
void aom_comp_mask_pred_ssse3(uint8_t *comp_pred, const uint8_t *pred,
                              int width, int height, const uint8_t *ref,
                              int ref_stride, const uint8_t *mask,
                              int mask_stride, int invert_mask) {
  const uint8_t *src0 = invert_mask ? pred : ref;
  const uint8_t *src1 = invert_mask ? ref : pred;
  const int stride0 = invert_mask ? width : ref_stride;
  const int stride1 = invert_mask ? ref_stride : width;
  assert(height % 2 == 0);
  int i = 0;
  if (width == 8) {
    comp_mask_pred_8_ssse3(comp_pred, height, src0, stride0, src1, stride1,
                           mask, mask_stride);
  } else if (width == 16) {
    do {
      comp_mask_pred_16_ssse3(src0, src1, mask, comp_pred);
      comp_mask_pred_16_ssse3(src0 + stride0, src1 + stride1,
                              mask + mask_stride, comp_pred + width);
      comp_pred += (width << 1);
      src0 += (stride0 << 1);
      src1 += (stride1 << 1);
      mask += (mask_stride << 1);
      i += 2;
    } while (i < height);
  } else {
    do {
      for (int x = 0; x < width; x += 32) {
        comp_mask_pred_16_ssse3(src0 + x, src1 + x, mask + x, comp_pred);
        comp_mask_pred_16_ssse3(src0 + x + 16, src1 + x + 16, mask + x + 16,
                                comp_pred + 16);
        comp_pred += 32;
      }
      src0 += (stride0);
      src1 += (stride1);
      mask += (mask_stride);
      i += 1;
    } while (i < height);
  }
}