• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <assert.h>
13 #include <emmintrin.h>  // SSE2
14 
15 #include "config/aom_config.h"
16 #include "config/aom_dsp_rtcd.h"
17 
18 #include "aom_dsp/x86/synonyms.h"
19 #include "aom_ports/mem.h"
20 
21 #include "av1/common/filter.h"
22 #include "av1/common/reconinter.h"
23 
// Signature shared by the per-tile variance kernels: accumulate the sum of
// squared differences into *sse and the signed sum of differences into *sum
// for one block_size x block_size tile of 16-bit pixels.
typedef uint32_t (*high_variance_fn_t)(const uint16_t *src, int src_stride,
                                       const uint16_t *ref, int ref_stride,
                                       uint32_t *sse, int *sum);

// 8x8 SSE2 tile kernel; defined elsewhere (presumably assembly — see the
// RTCD tables in aom_dsp_rtcd.h).
uint32_t aom_highbd_calc8x8var_sse2(const uint16_t *src, int src_stride,
                                    const uint16_t *ref, int ref_stride,
                                    uint32_t *sse, int *sum);

// 16x16 SSE2 tile kernel; defined elsewhere, as above.
uint32_t aom_highbd_calc16x16var_sse2(const uint16_t *src, int src_stride,
                                      const uint16_t *ref, int ref_stride,
                                      uint32_t *sse, int *sum);
// Accumulate the SSE and sum of differences over a w x h region (8-bit
// depth) by tiling it with block_size x block_size invocations of var_fn.
// Results are accumulated directly into the 32-bit outputs; no rounding
// is applied for 8-bit content.
static void highbd_8_variance_sse2(const uint16_t *src, int src_stride,
                                   const uint16_t *ref, int ref_stride, int w,
                                   int h, uint32_t *sse, int *sum,
                                   high_variance_fn_t var_fn, int block_size) {
  *sse = 0;
  *sum = 0;

  for (int row = 0; row < h; row += block_size) {
    for (int col = 0; col < w; col += block_size) {
      uint32_t tile_sse;
      int tile_sum;
      var_fn(src + src_stride * row + col, src_stride,
             ref + ref_stride * row + col, ref_stride, &tile_sse, &tile_sum);
      *sse += tile_sse;
      *sum += tile_sum;
    }
  }
}
56 
// Accumulate SSE and sum over a w x h region (10-bit depth).  The tile
// results are gathered into wider accumulators and then rounded down to
// the 8-bit scale: sum by 2 bits, sse by 4 bits.
static void highbd_10_variance_sse2(const uint16_t *src, int src_stride,
                                    const uint16_t *ref, int ref_stride, int w,
                                    int h, uint32_t *sse, int *sum,
                                    high_variance_fn_t var_fn, int block_size) {
  uint64_t total_sse = 0;
  int32_t total_sum = 0;

  for (int row = 0; row < h; row += block_size) {
    for (int col = 0; col < w; col += block_size) {
      uint32_t tile_sse;
      int tile_sum;
      var_fn(src + src_stride * row + col, src_stride,
             ref + ref_stride * row + col, ref_stride, &tile_sse, &tile_sum);
      total_sse += tile_sse;
      total_sum += tile_sum;
    }
  }
  *sum = ROUND_POWER_OF_TWO(total_sum, 2);
  *sse = (uint32_t)ROUND_POWER_OF_TWO(total_sse, 4);
}
78 
// Accumulate SSE and sum over a w x h region (12-bit depth).  Same scheme
// as the 10-bit variant, but normalized by 4 bits for the sum and 8 bits
// for the sse to bring the result back to the 8-bit scale.
static void highbd_12_variance_sse2(const uint16_t *src, int src_stride,
                                    const uint16_t *ref, int ref_stride, int w,
                                    int h, uint32_t *sse, int *sum,
                                    high_variance_fn_t var_fn, int block_size) {
  uint64_t total_sse = 0;
  int32_t total_sum = 0;

  for (int row = 0; row < h; row += block_size) {
    for (int col = 0; col < w; col += block_size) {
      uint32_t tile_sse;
      int tile_sum;
      var_fn(src + src_stride * row + col, src_stride,
             ref + ref_stride * row + col, ref_stride, &tile_sse, &tile_sum);
      total_sse += tile_sse;
      total_sum += tile_sum;
    }
  }
  *sum = ROUND_POWER_OF_TWO(total_sum, 4);
  *sse = (uint32_t)ROUND_POWER_OF_TWO(total_sse, 8);
}
100 
// Emits the 8-, 10- and 12-bit-depth variance functions for a w x h block.
// Each tiles the block with the block_size x block_size SSE2 kernel and
// returns sse - sum^2 / (w*h), where the division is a right shift by
// `shift` (== log2(w*h)).  The 10/12-bit variants compute the difference
// in 64 bits and clamp a negative result to 0, since their rounded sse
// can fall below the squared-mean term.
// (Comments are kept outside the macro body: a `//` comment before a
// continuation backslash would swallow the next line during splicing.)
#define VAR_FN(w, h, block_size, shift)                                    \
  uint32_t aom_highbd_8_variance##w##x##h##_sse2(                          \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,            \
      int ref_stride, uint32_t *sse) {                                     \
    int sum;                                                               \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                             \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                             \
    highbd_8_variance_sse2(                                                \
        src, src_stride, ref, ref_stride, w, h, sse, &sum,                 \
        aom_highbd_calc##block_size##x##block_size##var_sse2, block_size); \
    return *sse - (uint32_t)(((int64_t)sum * sum) >> shift);               \
  }                                                                        \
                                                                           \
  uint32_t aom_highbd_10_variance##w##x##h##_sse2(                         \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,            \
      int ref_stride, uint32_t *sse) {                                     \
    int sum;                                                               \
    int64_t var;                                                           \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                             \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                             \
    highbd_10_variance_sse2(                                               \
        src, src_stride, ref, ref_stride, w, h, sse, &sum,                 \
        aom_highbd_calc##block_size##x##block_size##var_sse2, block_size); \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) >> shift);               \
    return (var >= 0) ? (uint32_t)var : 0;                                 \
  }                                                                        \
                                                                           \
  uint32_t aom_highbd_12_variance##w##x##h##_sse2(                         \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,            \
      int ref_stride, uint32_t *sse) {                                     \
    int sum;                                                               \
    int64_t var;                                                           \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                             \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                             \
    highbd_12_variance_sse2(                                               \
        src, src_stride, ref, ref_stride, w, h, sse, &sum,                 \
        aom_highbd_calc##block_size##x##block_size##var_sse2, block_size); \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) >> shift);               \
    return (var >= 0) ? (uint32_t)var : 0;                                 \
  }
141 
// Instantiate the plain variance functions for each supported block size.
// The last argument is the normalization shift and equals log2(w*h) for
// every entry (e.g. 128x128 -> 14, 8x32 -> 8).
VAR_FN(128, 128, 16, 14)
VAR_FN(128, 64, 16, 13)
VAR_FN(64, 128, 16, 13)
VAR_FN(64, 64, 16, 12)
VAR_FN(64, 32, 16, 11)
VAR_FN(32, 64, 16, 11)
VAR_FN(32, 32, 16, 10)
VAR_FN(32, 16, 16, 9)
VAR_FN(16, 32, 16, 9)
VAR_FN(16, 16, 16, 8)
VAR_FN(16, 8, 8, 7)
VAR_FN(8, 16, 8, 7)
VAR_FN(8, 8, 8, 6)
VAR_FN(8, 32, 8, 8)
VAR_FN(32, 8, 8, 8)
VAR_FN(16, 64, 16, 10)
VAR_FN(64, 16, 16, 10)

#undef VAR_FN
161 
aom_highbd_8_mse16x16_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)162 unsigned int aom_highbd_8_mse16x16_sse2(const uint8_t *src8, int src_stride,
163                                         const uint8_t *ref8, int ref_stride,
164                                         unsigned int *sse) {
165   int sum;
166   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
167   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
168   highbd_8_variance_sse2(src, src_stride, ref, ref_stride, 16, 16, sse, &sum,
169                          aom_highbd_calc16x16var_sse2, 16);
170   return *sse;
171 }
172 
aom_highbd_10_mse16x16_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)173 unsigned int aom_highbd_10_mse16x16_sse2(const uint8_t *src8, int src_stride,
174                                          const uint8_t *ref8, int ref_stride,
175                                          unsigned int *sse) {
176   int sum;
177   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
178   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
179   highbd_10_variance_sse2(src, src_stride, ref, ref_stride, 16, 16, sse, &sum,
180                           aom_highbd_calc16x16var_sse2, 16);
181   return *sse;
182 }
183 
aom_highbd_12_mse16x16_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)184 unsigned int aom_highbd_12_mse16x16_sse2(const uint8_t *src8, int src_stride,
185                                          const uint8_t *ref8, int ref_stride,
186                                          unsigned int *sse) {
187   int sum;
188   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
189   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
190   highbd_12_variance_sse2(src, src_stride, ref, ref_stride, 16, 16, sse, &sum,
191                           aom_highbd_calc16x16var_sse2, 16);
192   return *sse;
193 }
194 
aom_highbd_8_mse8x8_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)195 unsigned int aom_highbd_8_mse8x8_sse2(const uint8_t *src8, int src_stride,
196                                       const uint8_t *ref8, int ref_stride,
197                                       unsigned int *sse) {
198   int sum;
199   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
200   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
201   highbd_8_variance_sse2(src, src_stride, ref, ref_stride, 8, 8, sse, &sum,
202                          aom_highbd_calc8x8var_sse2, 8);
203   return *sse;
204 }
205 
aom_highbd_10_mse8x8_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)206 unsigned int aom_highbd_10_mse8x8_sse2(const uint8_t *src8, int src_stride,
207                                        const uint8_t *ref8, int ref_stride,
208                                        unsigned int *sse) {
209   int sum;
210   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
211   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
212   highbd_10_variance_sse2(src, src_stride, ref, ref_stride, 8, 8, sse, &sum,
213                           aom_highbd_calc8x8var_sse2, 8);
214   return *sse;
215 }
216 
aom_highbd_12_mse8x8_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)217 unsigned int aom_highbd_12_mse8x8_sse2(const uint8_t *src8, int src_stride,
218                                        const uint8_t *ref8, int ref_stride,
219                                        unsigned int *sse) {
220   int sum;
221   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
222   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
223   highbd_12_variance_sse2(src, src_stride, ref, ref_stride, 8, 8, sse, &sum,
224                           aom_highbd_calc8x8var_sse2, 8);
225   return *sse;
226 }
227 
// The 2 unused parameters are place holders for PIC enabled build.
// These definitions are for functions defined in
// highbd_subpel_variance_impl_sse2.asm
// Each kernel handles a w-wide column strip of `height` rows; wider
// blocks below are assembled from several strip calls.
#define DECL(w, opt)                                                         \
  int aom_highbd_sub_pixel_variance##w##xh_##opt(                            \
      const uint16_t *src, ptrdiff_t src_stride, int x_offset, int y_offset, \
      const uint16_t *dst, ptrdiff_t dst_stride, int height,                 \
      unsigned int *sse, void *unused0, void *unused);
#define DECLS(opt) \
  DECL(8, opt)     \
  DECL(16, opt)

DECLS(sse2)

#undef DECLS
#undef DECL
244 
// Emits the 8-, 10- and 12-bit-depth sub-pixel variance functions for a
// w x h block on top of the wf-wide strip kernels declared above.  Blocks
// wider than wf are covered by up to four strip calls per pass, and blocks
// wider than 64 take two 64-wide passes (row_rep).  wlog2 + hlog2 equals
// log2(w*h) and is the shift used to subtract the squared mean; `cast`
// widens se*se before the shift.  The 10-bit variant folds each strip's
// 32-bit sse into a 64-bit accumulator before rounding; the 12-bit variant
// additionally processes rows in bands of at most 16 so each kernel call's
// 32-bit sse stays small before being folded into the 64-bit total.
// The 10/12-bit variants clamp a negative variance to 0.
// (Comments stay outside the macro: a `//` before a continuation backslash
// would swallow the following line during line splicing.)
#define FN(w, h, wf, wlog2, hlog2, opt, cast)                                  \
  uint32_t aom_highbd_8_sub_pixel_variance##w##x##h##_##opt(                   \
      const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
      const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr) {                \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
    uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                                 \
    int se = 0;                                                                \
    unsigned int sse = 0;                                                      \
    unsigned int sse2;                                                         \
    int row_rep = (w > 64) ? 2 : 1;                                            \
    for (int wd_64 = 0; wd_64 < row_rep; wd_64++) {                            \
      src += wd_64 * 64;                                                       \
      dst += wd_64 * 64;                                                       \
      int se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                   \
          src, src_stride, x_offset, y_offset, dst, dst_stride, h, &sse2,      \
          NULL, NULL);                                                         \
      se += se2;                                                               \
      sse += sse2;                                                             \
      if (w > wf) {                                                            \
        se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                     \
            src + wf, src_stride, x_offset, y_offset, dst + wf, dst_stride, h, \
            &sse2, NULL, NULL);                                                \
        se += se2;                                                             \
        sse += sse2;                                                           \
        if (w > wf * 2) {                                                      \
          se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                   \
              src + 2 * wf, src_stride, x_offset, y_offset, dst + 2 * wf,      \
              dst_stride, h, &sse2, NULL, NULL);                               \
          se += se2;                                                           \
          sse += sse2;                                                         \
          se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                   \
              src + 3 * wf, src_stride, x_offset, y_offset, dst + 3 * wf,      \
              dst_stride, h, &sse2, NULL, NULL);                               \
          se += se2;                                                           \
          sse += sse2;                                                         \
        }                                                                      \
      }                                                                        \
    }                                                                          \
    *sse_ptr = sse;                                                            \
    return sse - (uint32_t)((cast se * se) >> (wlog2 + hlog2));                \
  }                                                                            \
                                                                               \
  uint32_t aom_highbd_10_sub_pixel_variance##w##x##h##_##opt(                  \
      const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
      const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr) {                \
    int64_t var;                                                               \
    uint32_t sse;                                                              \
    uint64_t long_sse = 0;                                                     \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
    uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                                 \
    int se = 0;                                                                \
    int row_rep = (w > 64) ? 2 : 1;                                            \
    for (int wd_64 = 0; wd_64 < row_rep; wd_64++) {                            \
      src += wd_64 * 64;                                                       \
      dst += wd_64 * 64;                                                       \
      int se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                   \
          src, src_stride, x_offset, y_offset, dst, dst_stride, h, &sse, NULL, \
          NULL);                                                               \
      se += se2;                                                               \
      long_sse += sse;                                                         \
      if (w > wf) {                                                            \
        uint32_t sse2;                                                         \
        se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                     \
            src + wf, src_stride, x_offset, y_offset, dst + wf, dst_stride, h, \
            &sse2, NULL, NULL);                                                \
        se += se2;                                                             \
        long_sse += sse2;                                                      \
        if (w > wf * 2) {                                                      \
          se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                   \
              src + 2 * wf, src_stride, x_offset, y_offset, dst + 2 * wf,      \
              dst_stride, h, &sse2, NULL, NULL);                               \
          se += se2;                                                           \
          long_sse += sse2;                                                    \
          se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                   \
              src + 3 * wf, src_stride, x_offset, y_offset, dst + 3 * wf,      \
              dst_stride, h, &sse2, NULL, NULL);                               \
          se += se2;                                                           \
          long_sse += sse2;                                                    \
        }                                                                      \
      }                                                                        \
    }                                                                          \
    se = ROUND_POWER_OF_TWO(se, 2);                                            \
    sse = (uint32_t)ROUND_POWER_OF_TWO(long_sse, 4);                           \
    *sse_ptr = sse;                                                            \
    var = (int64_t)(sse) - ((cast se * se) >> (wlog2 + hlog2));                \
    return (var >= 0) ? (uint32_t)var : 0;                                     \
  }                                                                            \
                                                                               \
  uint32_t aom_highbd_12_sub_pixel_variance##w##x##h##_##opt(                  \
      const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
      const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr) {                \
    int start_row;                                                             \
    uint32_t sse;                                                              \
    int se = 0;                                                                \
    int64_t var;                                                               \
    uint64_t long_sse = 0;                                                     \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
    uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                                 \
    int row_rep = (w > 64) ? 2 : 1;                                            \
    for (start_row = 0; start_row < h; start_row += 16) {                      \
      uint32_t sse2;                                                           \
      int height = h - start_row < 16 ? h - start_row : 16;                    \
      uint16_t *src_tmp = src + (start_row * src_stride);                      \
      uint16_t *dst_tmp = dst + (start_row * dst_stride);                      \
      for (int wd_64 = 0; wd_64 < row_rep; wd_64++) {                          \
        src_tmp += wd_64 * 64;                                                 \
        dst_tmp += wd_64 * 64;                                                 \
        int se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                 \
            src_tmp, src_stride, x_offset, y_offset, dst_tmp, dst_stride,      \
            height, &sse2, NULL, NULL);                                        \
        se += se2;                                                             \
        long_sse += sse2;                                                      \
        if (w > wf) {                                                          \
          se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                   \
              src_tmp + wf, src_stride, x_offset, y_offset, dst_tmp + wf,      \
              dst_stride, height, &sse2, NULL, NULL);                          \
          se += se2;                                                           \
          long_sse += sse2;                                                    \
          if (w > wf * 2) {                                                    \
            se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                 \
                src_tmp + 2 * wf, src_stride, x_offset, y_offset,              \
                dst_tmp + 2 * wf, dst_stride, height, &sse2, NULL, NULL);      \
            se += se2;                                                         \
            long_sse += sse2;                                                  \
            se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                 \
                src_tmp + 3 * wf, src_stride, x_offset, y_offset,              \
                dst_tmp + 3 * wf, dst_stride, height, &sse2, NULL, NULL);      \
            se += se2;                                                         \
            long_sse += sse2;                                                  \
          }                                                                    \
        }                                                                      \
      }                                                                        \
    }                                                                          \
    se = ROUND_POWER_OF_TWO(se, 4);                                            \
    sse = (uint32_t)ROUND_POWER_OF_TWO(long_sse, 8);                           \
    *sse_ptr = sse;                                                            \
    var = (int64_t)(sse) - ((cast se * se) >> (wlog2 + hlog2));                \
    return (var >= 0) ? (uint32_t)var : 0;                                     \
  }
384 
// Instantiate the sub-pixel variance functions for every supported block
// size.  The two small integers after the kernel width are log2(w) and
// log2(h); their sum is the normalization shift log2(w*h).
#define FNS(opt)                         \
  FN(128, 128, 16, 7, 7, opt, (int64_t)) \
  FN(128, 64, 16, 7, 6, opt, (int64_t))  \
  FN(64, 128, 16, 6, 7, opt, (int64_t))  \
  FN(64, 64, 16, 6, 6, opt, (int64_t))   \
  FN(64, 32, 16, 6, 5, opt, (int64_t))   \
  FN(32, 64, 16, 5, 6, opt, (int64_t))   \
  FN(32, 32, 16, 5, 5, opt, (int64_t))   \
  FN(32, 16, 16, 5, 4, opt, (int64_t))   \
  FN(16, 32, 16, 4, 5, opt, (int64_t))   \
  FN(16, 16, 16, 4, 4, opt, (int64_t))   \
  FN(16, 8, 16, 4, 3, opt, (int64_t))    \
  FN(8, 16, 8, 3, 4, opt, (int64_t))     \
  FN(8, 8, 8, 3, 3, opt, (int64_t))      \
  FN(8, 4, 8, 3, 2, opt, (int64_t))      \
  FN(16, 4, 16, 4, 2, opt, (int64_t))    \
  FN(8, 32, 8, 3, 5, opt, (int64_t))     \
  FN(32, 8, 16, 5, 3, opt, (int64_t))    \
  FN(16, 64, 16, 4, 6, opt, (int64_t))   \
  FN(64, 16, 16, 6, 4, opt, (int64_t))

FNS(sse2)

#undef FNS
#undef FN
410 
// The 2 unused parameters are place holders for PIC enabled build.
// Strip kernels for the averaging (compound-prediction) sub-pixel
// variance; `sec` is the second predictor the filtered result is
// averaged with.  Presumably defined in assembly alongside the
// non-averaging kernels above — confirm against the build files.
#define DECL(w, opt)                                                         \
  int aom_highbd_sub_pixel_avg_variance##w##xh_##opt(                        \
      const uint16_t *src, ptrdiff_t src_stride, int x_offset, int y_offset, \
      const uint16_t *dst, ptrdiff_t dst_stride, const uint16_t *sec,        \
      ptrdiff_t sec_stride, int height, unsigned int *sse, void *unused0,    \
      void *unused);
#define DECLS(opt) \
  DECL(16, opt)    \
  DECL(8, opt)

DECLS(sse2)
#undef DECL
#undef DECLS
425 
426 #define FN(w, h, wf, wlog2, hlog2, opt, cast)                                  \
427   uint32_t aom_highbd_8_sub_pixel_avg_variance##w##x##h##_##opt(               \
428       const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
429       const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr,                  \
430       const uint8_t *sec8) {                                                   \
431     uint32_t sse;                                                              \
432     uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
433     uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                                 \
434     uint16_t *sec = CONVERT_TO_SHORTPTR(sec8);                                 \
435     int se = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                  \
436         src, src_stride, x_offset, y_offset, dst, dst_stride, sec, w, h, &sse, \
437         NULL, NULL);                                                           \
438     if (w > wf) {                                                              \
439       uint32_t sse2;                                                           \
440       int se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
441           src + wf, src_stride, x_offset, y_offset, dst + wf, dst_stride,      \
442           sec + wf, w, h, &sse2, NULL, NULL);                                  \
443       se += se2;                                                               \
444       sse += sse2;                                                             \
445       if (w > wf * 2) {                                                        \
446         se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
447             src + 2 * wf, src_stride, x_offset, y_offset, dst + 2 * wf,        \
448             dst_stride, sec + 2 * wf, w, h, &sse2, NULL, NULL);                \
449         se += se2;                                                             \
450         sse += sse2;                                                           \
451         se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
452             src + 3 * wf, src_stride, x_offset, y_offset, dst + 3 * wf,        \
453             dst_stride, sec + 3 * wf, w, h, &sse2, NULL, NULL);                \
454         se += se2;                                                             \
455         sse += sse2;                                                           \
456       }                                                                        \
457     }                                                                          \
458     *sse_ptr = sse;                                                            \
459     return sse - (uint32_t)((cast se * se) >> (wlog2 + hlog2));                \
460   }                                                                            \
461                                                                                \
462   uint32_t aom_highbd_10_sub_pixel_avg_variance##w##x##h##_##opt(              \
463       const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
464       const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr,                  \
465       const uint8_t *sec8) {                                                   \
466     int64_t var;                                                               \
467     uint32_t sse;                                                              \
468     uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
469     uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                                 \
470     uint16_t *sec = CONVERT_TO_SHORTPTR(sec8);                                 \
471     int se = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                  \
472         src, src_stride, x_offset, y_offset, dst, dst_stride, sec, w, h, &sse, \
473         NULL, NULL);                                                           \
474     if (w > wf) {                                                              \
475       uint32_t sse2;                                                           \
476       int se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
477           src + wf, src_stride, x_offset, y_offset, dst + wf, dst_stride,      \
478           sec + wf, w, h, &sse2, NULL, NULL);                                  \
479       se += se2;                                                               \
480       sse += sse2;                                                             \
481       if (w > wf * 2) {                                                        \
482         se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
483             src + 2 * wf, src_stride, x_offset, y_offset, dst + 2 * wf,        \
484             dst_stride, sec + 2 * wf, w, h, &sse2, NULL, NULL);                \
485         se += se2;                                                             \
486         sse += sse2;                                                           \
487         se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
488             src + 3 * wf, src_stride, x_offset, y_offset, dst + 3 * wf,        \
489             dst_stride, sec + 3 * wf, w, h, &sse2, NULL, NULL);                \
490         se += se2;                                                             \
491         sse += sse2;                                                           \
492       }                                                                        \
493     }                                                                          \
494     se = ROUND_POWER_OF_TWO(se, 2);                                            \
495     sse = ROUND_POWER_OF_TWO(sse, 4);                                          \
496     *sse_ptr = sse;                                                            \
497     var = (int64_t)(sse) - ((cast se * se) >> (wlog2 + hlog2));                \
498     return (var >= 0) ? (uint32_t)var : 0;                                     \
499   }                                                                            \
500                                                                                \
501   uint32_t aom_highbd_12_sub_pixel_avg_variance##w##x##h##_##opt(              \
502       const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
503       const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr,                  \
504       const uint8_t *sec8) {                                                   \
505     int start_row;                                                             \
506     int64_t var;                                                               \
507     uint32_t sse;                                                              \
508     int se = 0;                                                                \
509     uint64_t long_sse = 0;                                                     \
510     uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
511     uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                                 \
512     uint16_t *sec = CONVERT_TO_SHORTPTR(sec8);                                 \
513     for (start_row = 0; start_row < h; start_row += 16) {                      \
514       uint32_t sse2;                                                           \
515       int height = h - start_row < 16 ? h - start_row : 16;                    \
516       int se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
517           src + (start_row * src_stride), src_stride, x_offset, y_offset,      \
518           dst + (start_row * dst_stride), dst_stride, sec + (start_row * w),   \
519           w, height, &sse2, NULL, NULL);                                       \
520       se += se2;                                                               \
521       long_sse += sse2;                                                        \
522       if (w > wf) {                                                            \
523         se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
524             src + wf + (start_row * src_stride), src_stride, x_offset,         \
525             y_offset, dst + wf + (start_row * dst_stride), dst_stride,         \
526             sec + wf + (start_row * w), w, height, &sse2, NULL, NULL);         \
527         se += se2;                                                             \
528         long_sse += sse2;                                                      \
529         if (w > wf * 2) {                                                      \
530           se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
531               src + 2 * wf + (start_row * src_stride), src_stride, x_offset,   \
532               y_offset, dst + 2 * wf + (start_row * dst_stride), dst_stride,   \
533               sec + 2 * wf + (start_row * w), w, height, &sse2, NULL, NULL);   \
534           se += se2;                                                           \
535           long_sse += sse2;                                                    \
536           se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
537               src + 3 * wf + (start_row * src_stride), src_stride, x_offset,   \
538               y_offset, dst + 3 * wf + (start_row * dst_stride), dst_stride,   \
539               sec + 3 * wf + (start_row * w), w, height, &sse2, NULL, NULL);   \
540           se += se2;                                                           \
541           long_sse += sse2;                                                    \
542         }                                                                      \
543       }                                                                        \
544     }                                                                          \
545     se = ROUND_POWER_OF_TWO(se, 4);                                            \
546     sse = (uint32_t)ROUND_POWER_OF_TWO(long_sse, 8);                           \
547     *sse_ptr = sse;                                                            \
548     var = (int64_t)(sse) - ((cast se * se) >> (wlog2 + hlog2));                \
549     return (var >= 0) ? (uint32_t)var : 0;                                     \
550   }
551 
// FN(w, h, wf, wlog2, hlog2, opt, cast) instantiates the high-bitdepth
// sub-pixel average variance functions for one w x h block size, where:
//   wf          - width of the column strip processed per kernel call (8 or 16)
//   wlog2/hlog2 - log2(w) / log2(h); se*se is normalized by >> (wlog2 + hlog2)
//   cast        - cast applied to the se*se product to avoid overflow
// FNS expands FN once for every supported block size with the given ISA
// suffix, then both helper macros are retired with #undef below.
#define FNS(opt)                       \
  FN(64, 64, 16, 6, 6, opt, (int64_t)) \
  FN(64, 32, 16, 6, 5, opt, (int64_t)) \
  FN(32, 64, 16, 5, 6, opt, (int64_t)) \
  FN(32, 32, 16, 5, 5, opt, (int64_t)) \
  FN(32, 16, 16, 5, 4, opt, (int64_t)) \
  FN(16, 32, 16, 4, 5, opt, (int64_t)) \
  FN(16, 16, 16, 4, 4, opt, (int64_t)) \
  FN(16, 8, 16, 4, 3, opt, (int64_t))  \
  FN(8, 16, 8, 3, 4, opt, (int64_t))   \
  FN(8, 8, 8, 3, 3, opt, (int64_t))    \
  FN(8, 4, 8, 3, 2, opt, (int64_t))    \
  FN(16, 4, 16, 4, 2, opt, (int64_t))  \
  FN(8, 32, 8, 3, 5, opt, (int64_t))   \
  FN(32, 8, 16, 5, 3, opt, (int64_t))  \
  FN(16, 64, 16, 4, 6, opt, (int64_t)) \
  FN(64, 16, 16, 6, 4, opt, (int64_t))

FNS(sse2)

#undef FNS
#undef FN
574 
575 static INLINE void highbd_compute_dist_wtd_comp_avg(__m128i *p0, __m128i *p1,
576                                                     const __m128i *w0,
577                                                     const __m128i *w1,
578                                                     const __m128i *r,
579                                                     void *const result) {
580   assert(DIST_PRECISION_BITS <= 4);
581   __m128i mult0 = _mm_mullo_epi16(*p0, *w0);
582   __m128i mult1 = _mm_mullo_epi16(*p1, *w1);
583   __m128i sum = _mm_adds_epu16(mult0, mult1);
584   __m128i round = _mm_adds_epu16(sum, *r);
585   __m128i shift = _mm_srli_epi16(round, DIST_PRECISION_BITS);
586 
587   xx_storeu_128(result, shift);
588 }
589 
aom_highbd_dist_wtd_comp_avg_pred_sse2(uint8_t * comp_pred8,const uint8_t * pred8,int width,int height,const uint8_t * ref8,int ref_stride,const DIST_WTD_COMP_PARAMS * jcp_param)590 void aom_highbd_dist_wtd_comp_avg_pred_sse2(
591     uint8_t *comp_pred8, const uint8_t *pred8, int width, int height,
592     const uint8_t *ref8, int ref_stride,
593     const DIST_WTD_COMP_PARAMS *jcp_param) {
594   int i;
595   const int16_t wt0 = (int16_t)jcp_param->fwd_offset;
596   const int16_t wt1 = (int16_t)jcp_param->bck_offset;
597   const __m128i w0 = _mm_set1_epi16(wt0);
598   const __m128i w1 = _mm_set1_epi16(wt1);
599   const int16_t round = (int16_t)((1 << DIST_PRECISION_BITS) >> 1);
600   const __m128i r = _mm_set1_epi16(round);
601   uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
602   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
603   uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
604 
605   if (width >= 8) {
606     // Read 8 pixels one row at a time
607     assert(!(width & 7));
608     for (i = 0; i < height; ++i) {
609       int j;
610       for (j = 0; j < width; j += 8) {
611         __m128i p0 = xx_loadu_128(ref);
612         __m128i p1 = xx_loadu_128(pred);
613 
614         highbd_compute_dist_wtd_comp_avg(&p0, &p1, &w0, &w1, &r, comp_pred);
615 
616         comp_pred += 8;
617         pred += 8;
618         ref += 8;
619       }
620       ref += ref_stride - width;
621     }
622   } else {
623     // Read 4 pixels two rows at a time
624     assert(!(width & 3));
625     for (i = 0; i < height; i += 2) {
626       __m128i p0_0 = xx_loadl_64(ref + 0 * ref_stride);
627       __m128i p0_1 = xx_loadl_64(ref + 1 * ref_stride);
628       __m128i p0 = _mm_unpacklo_epi64(p0_0, p0_1);
629       __m128i p1 = xx_loadu_128(pred);
630 
631       highbd_compute_dist_wtd_comp_avg(&p0, &p1, &w0, &w1, &r, comp_pred);
632 
633       comp_pred += 8;
634       pred += 8;
635       ref += 2 * ref_stride;
636     }
637   }
638 }
639 
// Sum of squared 16-bit differences for a 4-wide block of height h (h even).
// Two 4-pixel rows are packed into one 8-lane vector per loop iteration.
static uint64_t mse_4xh_16bit_highbd_sse2(uint16_t *dst, int dstride,
                                          uint16_t *src, int sstride, int h) {
  const __m128i zero = _mm_setzero_si128();
  __m128i acc_2x64 = _mm_setzero_si128();

  for (int i = 0; i < h; i += 2) {
    const __m128i d0 =
        _mm_loadl_epi64((__m128i const *)(&dst[(i + 0) * dstride]));
    const __m128i d1 =
        _mm_loadl_epi64((__m128i const *)(&dst[(i + 1) * dstride]));
    const __m128i s0 =
        _mm_loadl_epi64((__m128i const *)(&src[(i + 0) * sstride]));
    const __m128i s1 =
        _mm_loadl_epi64((__m128i const *)(&src[(i + 1) * sstride]));
    const __m128i d_8x16 = _mm_unpacklo_epi64(d0, d1);
    const __m128i s_8x16 = _mm_unpacklo_epi64(s0, s1);

    // Interleave each 16-bit diff with zero so every 32-bit lane is
    // (diff, 0); madd then produces diff*diff per lane.
    const __m128i diff = _mm_sub_epi16(s_8x16, d_8x16);
    __m128i lo = _mm_unpacklo_epi16(diff, zero);
    __m128i hi = _mm_unpackhi_epi16(diff, zero);
    lo = _mm_madd_epi16(lo, lo);
    hi = _mm_madd_epi16(hi, hi);

    // Widen the 32-bit squares to 64 bits and fold them into the accumulator.
    const __m128i sq0 = _mm_unpacklo_epi32(lo, zero);
    const __m128i sq1 = _mm_unpackhi_epi32(lo, zero);
    const __m128i sq2 = _mm_unpacklo_epi32(hi, zero);
    const __m128i sq3 = _mm_unpackhi_epi32(hi, zero);
    acc_2x64 = _mm_add_epi64(
        acc_2x64,
        _mm_add_epi64(_mm_add_epi64(sq0, sq1), _mm_add_epi64(sq2, sq3)));
  }

  // Horizontal reduction of the two 64-bit partial sums.
  uint64_t sum = 0;
  const __m128i total = _mm_add_epi64(acc_2x64, _mm_srli_si128(acc_2x64, 8));
  _mm_storel_epi64((__m128i *)&sum, total);
  return sum;
}
684 
// Sum of squared 16-bit differences for an 8-wide block of height h.
// Processes one 8-pixel row per loop iteration.
static uint64_t mse_8xh_16bit_highbd_sse2(uint16_t *dst, int dstride,
                                          uint16_t *src, int sstride, int h) {
  const __m128i zero = _mm_setzero_si128();
  __m128i acc_2x64 = _mm_setzero_si128();

  for (int i = 0; i < h; i++) {
    const __m128i d_8x16 = _mm_loadu_si128((__m128i *)&dst[i * dstride]);
    const __m128i s_8x16 = _mm_loadu_si128((__m128i *)&src[i * sstride]);

    // Interleave each 16-bit diff with zero so every 32-bit lane is
    // (diff, 0); madd then produces diff*diff per lane.
    const __m128i diff = _mm_sub_epi16(s_8x16, d_8x16);
    __m128i lo = _mm_unpacklo_epi16(diff, zero);
    __m128i hi = _mm_unpackhi_epi16(diff, zero);
    lo = _mm_madd_epi16(lo, lo);
    hi = _mm_madd_epi16(hi, hi);

    // Widen the 32-bit squares to 64 bits and fold them into the accumulator.
    const __m128i sq0 = _mm_unpacklo_epi32(lo, zero);
    const __m128i sq1 = _mm_unpackhi_epi32(lo, zero);
    const __m128i sq2 = _mm_unpacklo_epi32(hi, zero);
    const __m128i sq3 = _mm_unpackhi_epi32(hi, zero);
    acc_2x64 = _mm_add_epi64(
        acc_2x64,
        _mm_add_epi64(_mm_add_epi64(sq0, sq1), _mm_add_epi64(sq2, sq3)));
  }

  // Horizontal reduction of the two 64-bit partial sums.
  uint64_t sum = 0;
  const __m128i total = _mm_add_epi64(acc_2x64, _mm_srli_si128(acc_2x64, 8));
  _mm_storel_epi64((__m128i *)&sum, total);
  return sum;
}
724 
// Dispatch to the width-specific 16-bit MSE kernel. Only 4x4/4x8/8x4/8x8
// blocks are supported; anything else trips the assert in debug builds and
// returns the -1 sentinel (wraps to UINT64_MAX) in release builds.
uint64_t aom_mse_wxh_16bit_highbd_sse2(uint16_t *dst, int dstride,
                                       uint16_t *src, int sstride, int w,
                                       int h) {
  assert((w == 8 || w == 4) && (h == 8 || h == 4) &&
         "w=8/4 and h=8/4 must satisfy");
  if (w == 4) return mse_4xh_16bit_highbd_sse2(dst, dstride, src, sstride, h);
  if (w == 8) return mse_8xh_16bit_highbd_sse2(dst, dstride, src, sstride, h);
  assert(0 && "unsupported width");
  return -1;
}
736