/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <immintrin.h>

#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/x86/masked_variance_intrin_ssse3.h"
#include "aom_dsp/x86/synonyms.h"

static INLINE __m128i mm256_add_hi_lo_epi16(const __m256i val) {
  return _mm_add_epi16(_mm256_castsi256_si128(val),
                       _mm256_extractf128_si256(val, 1));
}

static INLINE __m128i mm256_add_hi_lo_epi32(const __m256i val) {
  return _mm_add_epi32(_mm256_castsi256_si128(val),
                       _mm256_extractf128_si256(val, 1));
}

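// Note: accumulates the differences and squared differences of 32 src/ref
// byte pairs into *sum (16-bit lanes) and *sse (32-bit lanes). The maddubs
// with the (1, -1) byte pattern computes src - ref per pixel, since the first
// operand is treated as unsigned and the second as signed.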
static INLINE void variance_kernel_avx2(const __m256i src, const __m256i ref,
                                        __m256i *const sse,
                                        __m256i *const sum) {
  const __m256i adj_sub = _mm256_set1_epi16((short)0xff01);  // (1,-1)

  // unpack into pairs of source and reference values
  const __m256i src_ref0 = _mm256_unpacklo_epi8(src, ref);
  const __m256i src_ref1 = _mm256_unpackhi_epi8(src, ref);

  // subtract adjacent elements using src*1 + ref*-1
  const __m256i diff0 = _mm256_maddubs_epi16(src_ref0, adj_sub);
  const __m256i diff1 = _mm256_maddubs_epi16(src_ref1, adj_sub);
  const __m256i madd0 = _mm256_madd_epi16(diff0, diff0);
  const __m256i madd1 = _mm256_madd_epi16(diff1, diff1);

  // add to the running totals
  *sum = _mm256_add_epi16(*sum, _mm256_add_epi16(diff0, diff1));
  *sse = _mm256_add_epi32(*sse, _mm256_add_epi32(madd0, madd1));
}

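// Note: reduces the 32-bit SSE accumulator and the 32-bit sum vector to
// scalars; the total SSE is written to *sse and the total sum is returned.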
static INLINE int variance_final_from_32bit_sum_avx2(__m256i vsse, __m128i vsum,
                                                     unsigned int *const sse) {
  // extract the low lane and add it to the high lane
  const __m128i sse_reg_128 = mm256_add_hi_lo_epi32(vsse);

  // unpack sse and sum registers and add
  const __m128i sse_sum_lo = _mm_unpacklo_epi32(sse_reg_128, vsum);
  const __m128i sse_sum_hi = _mm_unpackhi_epi32(sse_reg_128, vsum);
  const __m128i sse_sum = _mm_add_epi32(sse_sum_lo, sse_sum_hi);

  // perform the final summation and extract the results
  const __m128i res = _mm_add_epi32(sse_sum, _mm_srli_si128(sse_sum, 8));
  *((int *)sse) = _mm_cvtsi128_si32(res);
  return _mm_extract_epi32(res, 1);
}

// handle pixels (<= 512)
static INLINE int variance_final_512_avx2(__m256i vsse, __m256i vsum,
                                          unsigned int *const sse) {
  // extract the low lane and add it to the high lane
  const __m128i vsum_128 = mm256_add_hi_lo_epi16(vsum);
  const __m128i vsum_64 = _mm_add_epi16(vsum_128, _mm_srli_si128(vsum_128, 8));
  const __m128i sum_int32 = _mm_cvtepi16_epi32(vsum_64);
  return variance_final_from_32bit_sum_avx2(vsse, sum_int32, sse);
}

// handle 1024 pixels (32x32, 16x64, 64x16)
static INLINE int variance_final_1024_avx2(__m256i vsse, __m256i vsum,
                                           unsigned int *const sse) {
  // extract the low lane and add it to the high lane
  const __m128i vsum_128 = mm256_add_hi_lo_epi16(vsum);
  const __m128i vsum_64 =
      _mm_add_epi32(_mm_cvtepi16_epi32(vsum_128),
                    _mm_cvtepi16_epi32(_mm_srli_si128(vsum_128, 8)));
  return variance_final_from_32bit_sum_avx2(vsse, vsum_64, sse);
}

static INLINE __m256i sum_to_32bit_avx2(const __m256i sum) {
  const __m256i sum_lo = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(sum));
  const __m256i sum_hi =
      _mm256_cvtepi16_epi32(_mm256_extractf128_si256(sum, 1));
  return _mm256_add_epi32(sum_lo, sum_hi);
}

// handle 2048 pixels (32x64, 64x32)
static INLINE int variance_final_2048_avx2(__m256i vsse, __m256i vsum,
                                           unsigned int *const sse) {
  vsum = sum_to_32bit_avx2(vsum);
  const __m128i vsum_128 = mm256_add_hi_lo_epi32(vsum);
  return variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse);
}

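// Note: processes two 16-pixel rows per call by packing them into a single
// 256-bit register before invoking variance_kernel_avx2.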
static INLINE void variance16_kernel_avx2(
    const uint8_t *const src, const int src_stride, const uint8_t *const ref,
    const int ref_stride, __m256i *const sse, __m256i *const sum) {
  const __m128i s0 = _mm_loadu_si128((__m128i const *)(src + 0 * src_stride));
  const __m128i s1 = _mm_loadu_si128((__m128i const *)(src + 1 * src_stride));
  const __m128i r0 = _mm_loadu_si128((__m128i const *)(ref + 0 * ref_stride));
  const __m128i r1 = _mm_loadu_si128((__m128i const *)(ref + 1 * ref_stride));
  const __m256i s = _mm256_inserti128_si256(_mm256_castsi128_si256(s0), s1, 1);
  const __m256i r = _mm256_inserti128_si256(_mm256_castsi128_si256(r0), r1, 1);
  variance_kernel_avx2(s, r, sse, sum);
}

static INLINE void variance32_kernel_avx2(const uint8_t *const src,
                                          const uint8_t *const ref,
                                          __m256i *const sse,
                                          __m256i *const sum) {
  const __m256i s = _mm256_loadu_si256((__m256i const *)(src));
  const __m256i r = _mm256_loadu_si256((__m256i const *)(ref));
  variance_kernel_avx2(s, r, sse, sum);
}

static INLINE void variance16_avx2(const uint8_t *src, const int src_stride,
                                   const uint8_t *ref, const int ref_stride,
                                   const int h, __m256i *const vsse,
                                   __m256i *const vsum) {
  *vsum = _mm256_setzero_si256();

  for (int i = 0; i < h; i += 2) {
    variance16_kernel_avx2(src, src_stride, ref, ref_stride, vsse, vsum);
    src += 2 * src_stride;
    ref += 2 * ref_stride;
  }
}

static INLINE void variance32_avx2(const uint8_t *src, const int src_stride,
                                   const uint8_t *ref, const int ref_stride,
                                   const int h, __m256i *const vsse,
                                   __m256i *const vsum) {
  *vsum = _mm256_setzero_si256();

  for (int i = 0; i < h; i++) {
    variance32_kernel_avx2(src, ref, vsse, vsum);
    src += src_stride;
    ref += ref_stride;
  }
}

static INLINE void variance64_avx2(const uint8_t *src, const int src_stride,
                                   const uint8_t *ref, const int ref_stride,
                                   const int h, __m256i *const vsse,
                                   __m256i *const vsum) {
  *vsum = _mm256_setzero_si256();

  for (int i = 0; i < h; i++) {
    variance32_kernel_avx2(src + 0, ref + 0, vsse, vsum);
    variance32_kernel_avx2(src + 32, ref + 32, vsse, vsum);
    src += src_stride;
    ref += ref_stride;
  }
}

static INLINE void variance128_avx2(const uint8_t *src, const int src_stride,
                                    const uint8_t *ref, const int ref_stride,
                                    const int h, __m256i *const vsse,
                                    __m256i *const vsum) {
  *vsum = _mm256_setzero_si256();

  for (int i = 0; i < h; i++) {
    variance32_kernel_avx2(src + 0, ref + 0, vsse, vsum);
    variance32_kernel_avx2(src + 32, ref + 32, vsse, vsum);
    variance32_kernel_avx2(src + 64, ref + 64, vsse, vsum);
    variance32_kernel_avx2(src + 96, ref + 96, vsse, vsum);
    src += src_stride;
    ref += ref_stride;
  }
}

#define AOM_VAR_NO_LOOP_AVX2(bw, bh, bits, max_pixel)                         \
  unsigned int aom_variance##bw##x##bh##_avx2(                                \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      unsigned int *sse) {                                                    \
    __m256i vsse = _mm256_setzero_si256();                                    \
    __m256i vsum;                                                             \
    variance##bw##_avx2(src, src_stride, ref, ref_stride, bh, &vsse, &vsum);  \
    const int sum = variance_final_##max_pixel##_avx2(vsse, vsum, sse);       \
    return *sse - (uint32_t)(((int64_t)sum * sum) >> bits);                   \
  }

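// The macro computes variance = sse - sum^2 / (bw * bh), where bits is
// log2(bw * bh). As an illustrative hand-expanded (not generated) example,
// AOM_VAR_NO_LOOP_AVX2(16, 16, 8, 512) produces approximately:
//
//   unsigned int aom_variance16x16_avx2(const uint8_t *src, int src_stride,
//                                       const uint8_t *ref, int ref_stride,
//                                       unsigned int *sse) {
//     __m256i vsse = _mm256_setzero_si256();
//     __m256i vsum;
//     variance16_avx2(src, src_stride, ref, ref_stride, 16, &vsse, &vsum);
//     const int sum = variance_final_512_avx2(vsse, vsum, sse);
//     return *sse - (uint32_t)(((int64_t)sum * sum) >> 8);
//   }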
AOM_VAR_NO_LOOP_AVX2(16, 8, 7, 512)
AOM_VAR_NO_LOOP_AVX2(16, 16, 8, 512)
AOM_VAR_NO_LOOP_AVX2(16, 32, 9, 512)

AOM_VAR_NO_LOOP_AVX2(32, 16, 9, 512)
AOM_VAR_NO_LOOP_AVX2(32, 32, 10, 1024)
AOM_VAR_NO_LOOP_AVX2(32, 64, 11, 2048)

AOM_VAR_NO_LOOP_AVX2(64, 32, 11, 2048)

#if !CONFIG_REALTIME_ONLY
AOM_VAR_NO_LOOP_AVX2(64, 16, 10, 1024)
AOM_VAR_NO_LOOP_AVX2(32, 8, 8, 512)
AOM_VAR_NO_LOOP_AVX2(16, 64, 10, 1024)
AOM_VAR_NO_LOOP_AVX2(16, 4, 6, 512)
#endif

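// The no-loop path keeps the running sum in 16-bit lanes, which would risk
// overflow for the 4096-pixel and larger blocks below; this variant therefore
// processes 'uh' rows at a time and widens each chunk's 16-bit sums to 32 bits
// before accumulating.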
#define AOM_VAR_LOOP_AVX2(bw, bh, bits, uh)                                   \
  unsigned int aom_variance##bw##x##bh##_avx2(                                \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      unsigned int *sse) {                                                    \
    __m256i vsse = _mm256_setzero_si256();                                    \
    __m256i vsum = _mm256_setzero_si256();                                    \
    for (int i = 0; i < (bh / uh); i++) {                                     \
      __m256i vsum16;                                                         \
      variance##bw##_avx2(src, src_stride, ref, ref_stride, uh, &vsse,        \
                          &vsum16);                                           \
      vsum = _mm256_add_epi32(vsum, sum_to_32bit_avx2(vsum16));               \
      src += uh * src_stride;                                                 \
      ref += uh * ref_stride;                                                 \
    }                                                                         \
    const __m128i vsum_128 = mm256_add_hi_lo_epi32(vsum);                     \
    const int sum = variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse);  \
    return *sse - (unsigned int)(((int64_t)sum * sum) >> bits);               \
  }

AOM_VAR_LOOP_AVX2(64, 64, 12, 32)    // 64x32 * ( 64/32)
AOM_VAR_LOOP_AVX2(64, 128, 13, 32)   // 64x32 * (128/32)
AOM_VAR_LOOP_AVX2(128, 64, 13, 16)   // 128x16 * ( 64/16)
AOM_VAR_LOOP_AVX2(128, 128, 14, 16)  // 128x16 * (128/16)

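// The MSE here is the raw sum of squared differences: the variance value
// computed by aom_variance16x16_avx2 is discarded and only *sse is returned.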
unsigned int aom_mse16x16_avx2(const uint8_t *src, int src_stride,
                               const uint8_t *ref, int ref_stride,
                               unsigned int *sse) {
  aom_variance16x16_avx2(src, src_stride, ref, ref_stride, sse);
  return *sse;
}

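// Load two unaligned 16-byte rows into one 256-bit register, with p1 in the
// low lane and p0 in the high lane.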
static INLINE __m256i mm256_loadu2(const uint8_t *p0, const uint8_t *p1) {
  const __m256i d =
      _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)p1));
  return _mm256_insertf128_si256(d, _mm_loadu_si128((const __m128i *)p0), 1);
}

static INLINE __m256i mm256_loadu2_16(const uint16_t *p0, const uint16_t *p1) {
  const __m256i d =
      _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)p1));
  return _mm256_insertf128_si256(d, _mm_loadu_si128((const __m128i *)p0), 1);
}

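// Blend 32 pixels of s0 and s1 with the 6-bit mask weights in 'a':
// out = (s0 * a + s1 * (AOM_BLEND_A64_MAX_ALPHA - a)) rounded and shifted by
// AOM_BLEND_A64_ROUND_BITS; the rounding shift is performed with
// _mm256_mulhrs_epi16.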
static INLINE void comp_mask_pred_line_avx2(const __m256i s0, const __m256i s1,
                                            const __m256i a,
                                            uint8_t *comp_pred) {
  const __m256i alpha_max = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
  const int16_t round_bits = 15 - AOM_BLEND_A64_ROUND_BITS;
  const __m256i round_offset = _mm256_set1_epi16(1 << (round_bits));

  const __m256i ma = _mm256_sub_epi8(alpha_max, a);

  const __m256i ssAL = _mm256_unpacklo_epi8(s0, s1);
  const __m256i aaAL = _mm256_unpacklo_epi8(a, ma);
  const __m256i ssAH = _mm256_unpackhi_epi8(s0, s1);
  const __m256i aaAH = _mm256_unpackhi_epi8(a, ma);

  const __m256i blendAL = _mm256_maddubs_epi16(ssAL, aaAL);
  const __m256i blendAH = _mm256_maddubs_epi16(ssAH, aaAH);
  const __m256i roundAL = _mm256_mulhrs_epi16(blendAL, round_offset);
  const __m256i roundAH = _mm256_mulhrs_epi16(blendAH, round_offset);

  const __m256i roundA = _mm256_packus_epi16(roundAL, roundAH);
  _mm256_storeu_si256((__m256i *)(comp_pred), roundA);
}

void aom_comp_avg_pred_avx2(uint8_t *comp_pred, const uint8_t *pred, int width,
                            int height, const uint8_t *ref, int ref_stride) {
  int row = 0;
  if (width == 8) {
    do {
      const __m256i pred_0123 = _mm256_loadu_si256((const __m256i *)(pred));
      const __m128i ref_0 = _mm_loadl_epi64((const __m128i *)(ref));
      const __m128i ref_1 =
          _mm_loadl_epi64((const __m128i *)(ref + ref_stride));
      const __m128i ref_2 =
          _mm_loadl_epi64((const __m128i *)(ref + 2 * ref_stride));
      const __m128i ref_3 =
          _mm_loadl_epi64((const __m128i *)(ref + 3 * ref_stride));
      const __m128i ref_01 = _mm_unpacklo_epi64(ref_0, ref_1);
      const __m128i ref_23 = _mm_unpacklo_epi64(ref_2, ref_3);

      const __m256i ref_0123 =
          _mm256_inserti128_si256(_mm256_castsi128_si256(ref_01), ref_23, 1);
      const __m256i average = _mm256_avg_epu8(pred_0123, ref_0123);
      _mm256_storeu_si256((__m256i *)(comp_pred), average);

      row += 4;
      pred += 32;
      comp_pred += 32;
      ref += 4 * ref_stride;
    } while (row < height);
  } else if (width == 16) {
    do {
      const __m256i pred_0 = _mm256_loadu_si256((const __m256i *)(pred));
      const __m256i pred_1 = _mm256_loadu_si256((const __m256i *)(pred + 32));
      const __m256i tmp0 =
          _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(ref)));
      const __m256i ref_0 = _mm256_inserti128_si256(
          tmp0, _mm_loadu_si128((const __m128i *)(ref + ref_stride)), 1);
      const __m256i tmp1 = _mm256_castsi128_si256(
          _mm_loadu_si128((const __m128i *)(ref + 2 * ref_stride)));
      const __m256i ref_1 = _mm256_inserti128_si256(
          tmp1, _mm_loadu_si128((const __m128i *)(ref + 3 * ref_stride)), 1);
      const __m256i average_0 = _mm256_avg_epu8(pred_0, ref_0);
      const __m256i average_1 = _mm256_avg_epu8(pred_1, ref_1);
      _mm256_storeu_si256((__m256i *)(comp_pred), average_0);
      _mm256_storeu_si256((__m256i *)(comp_pred + 32), average_1);

      row += 4;
      pred += 64;
      comp_pred += 64;
      ref += 4 * ref_stride;
    } while (row < height);
  } else if (width == 32) {
    do {
      const __m256i pred_0 = _mm256_loadu_si256((const __m256i *)(pred));
      const __m256i pred_1 = _mm256_loadu_si256((const __m256i *)(pred + 32));
      const __m256i ref_0 = _mm256_loadu_si256((const __m256i *)(ref));
      const __m256i ref_1 =
          _mm256_loadu_si256((const __m256i *)(ref + ref_stride));
      const __m256i average_0 = _mm256_avg_epu8(pred_0, ref_0);
      const __m256i average_1 = _mm256_avg_epu8(pred_1, ref_1);
      _mm256_storeu_si256((__m256i *)(comp_pred), average_0);
      _mm256_storeu_si256((__m256i *)(comp_pred + 32), average_1);

      row += 2;
      pred += 64;
      comp_pred += 64;
      ref += 2 * ref_stride;
    } while (row < height);
  } else if (width % 64 == 0) {
    do {
      for (int x = 0; x < width; x += 64) {
        const __m256i pred_0 = _mm256_loadu_si256((const __m256i *)(pred + x));
        const __m256i pred_1 =
            _mm256_loadu_si256((const __m256i *)(pred + x + 32));
        const __m256i ref_0 = _mm256_loadu_si256((const __m256i *)(ref + x));
        const __m256i ref_1 =
            _mm256_loadu_si256((const __m256i *)(ref + x + 32));
        const __m256i average_0 = _mm256_avg_epu8(pred_0, ref_0);
        const __m256i average_1 = _mm256_avg_epu8(pred_1, ref_1);
        _mm256_storeu_si256((__m256i *)(comp_pred + x), average_0);
        _mm256_storeu_si256((__m256i *)(comp_pred + x + 32), average_1);
      }
      row++;
      pred += width;
      comp_pred += width;
      ref += ref_stride;
    } while (row < height);
  } else {
    aom_comp_avg_pred_c(comp_pred, pred, width, height, ref, ref_stride);
  }
}

void aom_comp_mask_pred_avx2(uint8_t *comp_pred, const uint8_t *pred,
                             int width, int height, const uint8_t *ref,
                             int ref_stride, const uint8_t *mask,
                             int mask_stride, int invert_mask) {
  int i = 0;
  const uint8_t *src0 = invert_mask ? pred : ref;
  const uint8_t *src1 = invert_mask ? ref : pred;
  const int stride0 = invert_mask ? width : ref_stride;
  const int stride1 = invert_mask ? ref_stride : width;
  if (width == 8) {
    comp_mask_pred_8_ssse3(comp_pred, height, src0, stride0, src1, stride1,
                           mask, mask_stride);
  } else if (width == 16) {
    do {
      const __m256i sA0 = mm256_loadu2(src0 + stride0, src0);
      const __m256i sA1 = mm256_loadu2(src1 + stride1, src1);
      const __m256i aA = mm256_loadu2(mask + mask_stride, mask);
      src0 += (stride0 << 1);
      src1 += (stride1 << 1);
      mask += (mask_stride << 1);
      const __m256i sB0 = mm256_loadu2(src0 + stride0, src0);
      const __m256i sB1 = mm256_loadu2(src1 + stride1, src1);
      const __m256i aB = mm256_loadu2(mask + mask_stride, mask);
      src0 += (stride0 << 1);
      src1 += (stride1 << 1);
      mask += (mask_stride << 1);
      // comp_pred's stride == width == 16
      comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
      comp_mask_pred_line_avx2(sB0, sB1, aB, comp_pred + 32);
      comp_pred += (16 << 2);
      i += 4;
    } while (i < height);
  } else {
    do {
      for (int x = 0; x < width; x += 32) {
        const __m256i sA0 = _mm256_lddqu_si256((const __m256i *)(src0 + x));
        const __m256i sA1 = _mm256_lddqu_si256((const __m256i *)(src1 + x));
        const __m256i aA = _mm256_lddqu_si256((const __m256i *)(mask + x));

        comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
        comp_pred += 32;
      }
      src0 += stride0;
      src1 += stride1;
      mask += mask_stride;
      i++;
    } while (i < height);
  }
}

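// High bit-depth variant of the mask blend: pixels are 16 bits, so the
// weighted sum is formed with _mm256_madd_epi16 on interleaved
// (mask, 64 - mask) pairs and rounded by AOM_BLEND_A64_ROUND_BITS.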
static INLINE __m256i highbd_comp_mask_pred_line_avx2(const __m256i s0,
                                                      const __m256i s1,
                                                      const __m256i a) {
  const __m256i alpha_max = _mm256_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m256i round_const =
      _mm256_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m256i a_inv = _mm256_sub_epi16(alpha_max, a);

  const __m256i s_lo = _mm256_unpacklo_epi16(s0, s1);
  const __m256i a_lo = _mm256_unpacklo_epi16(a, a_inv);
  const __m256i pred_lo = _mm256_madd_epi16(s_lo, a_lo);
  const __m256i pred_l = _mm256_srai_epi32(
      _mm256_add_epi32(pred_lo, round_const), AOM_BLEND_A64_ROUND_BITS);

  const __m256i s_hi = _mm256_unpackhi_epi16(s0, s1);
  const __m256i a_hi = _mm256_unpackhi_epi16(a, a_inv);
  const __m256i pred_hi = _mm256_madd_epi16(s_hi, a_hi);
  const __m256i pred_h = _mm256_srai_epi32(
      _mm256_add_epi32(pred_hi, round_const), AOM_BLEND_A64_ROUND_BITS);

  const __m256i comp = _mm256_packs_epi32(pred_l, pred_h);

  return comp;
}

void aom_highbd_comp_mask_pred_avx2(uint8_t *comp_pred8, const uint8_t *pred8,
                                    int width, int height, const uint8_t *ref8,
                                    int ref_stride, const uint8_t *mask,
                                    int mask_stride, int invert_mask) {
  int i = 0;
  uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
  uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
  uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
  const uint16_t *src0 = invert_mask ? pred : ref;
  const uint16_t *src1 = invert_mask ? ref : pred;
  const int stride0 = invert_mask ? width : ref_stride;
  const int stride1 = invert_mask ? ref_stride : width;
  const __m256i zero = _mm256_setzero_si256();

  if (width == 8) {
    do {
      const __m256i s0 = mm256_loadu2_16(src0 + stride0, src0);
      const __m256i s1 = mm256_loadu2_16(src1 + stride1, src1);

      const __m128i m_l = _mm_loadl_epi64((const __m128i *)mask);
      const __m128i m_h = _mm_loadl_epi64((const __m128i *)(mask + 8));

      __m256i m = _mm256_castsi128_si256(m_l);
      m = _mm256_insertf128_si256(m, m_h, 1);
      const __m256i m_16 = _mm256_unpacklo_epi8(m, zero);

      const __m256i comp = highbd_comp_mask_pred_line_avx2(s0, s1, m_16);

      _mm_storeu_si128((__m128i *)(comp_pred), _mm256_castsi256_si128(comp));

      _mm_storeu_si128((__m128i *)(comp_pred + width),
                       _mm256_extractf128_si256(comp, 1));

      src0 += (stride0 << 1);
      src1 += (stride1 << 1);
      mask += (mask_stride << 1);
      comp_pred += (width << 1);
      i += 2;
    } while (i < height);
  } else if (width == 16) {
    do {
      const __m256i s0 = _mm256_loadu_si256((const __m256i *)(src0));
      const __m256i s1 = _mm256_loadu_si256((const __m256i *)(src1));
      const __m256i m_16 =
          _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)mask));

      const __m256i comp = highbd_comp_mask_pred_line_avx2(s0, s1, m_16);

      _mm256_storeu_si256((__m256i *)comp_pred, comp);

      src0 += stride0;
      src1 += stride1;
      mask += mask_stride;
      comp_pred += width;
      i += 1;
    } while (i < height);
  } else {
    do {
      for (int x = 0; x < width; x += 32) {
        const __m256i s0 = _mm256_loadu_si256((const __m256i *)(src0 + x));
        const __m256i s2 =
            _mm256_loadu_si256((const __m256i *)(src0 + x + 16));
        const __m256i s1 = _mm256_loadu_si256((const __m256i *)(src1 + x));
        const __m256i s3 =
            _mm256_loadu_si256((const __m256i *)(src1 + x + 16));

        const __m256i m01_16 =
            _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)(mask + x)));
        const __m256i m23_16 = _mm256_cvtepu8_epi16(
            _mm_loadu_si128((const __m128i *)(mask + x + 16)));

        const __m256i comp = highbd_comp_mask_pred_line_avx2(s0, s1, m01_16);
        const __m256i comp1 = highbd_comp_mask_pred_line_avx2(s2, s3, m23_16);

        _mm256_storeu_si256((__m256i *)comp_pred, comp);
        _mm256_storeu_si256((__m256i *)(comp_pred + 16), comp1);

        comp_pred += 32;
      }
      src0 += stride0;
      src1 += stride1;
      mask += mask_stride;
      i += 1;
    } while (i < height);
  }
}

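// MSE between a 4xh 8-bit dst block and a 4xh 16-bit src block. Squared
// differences are accumulated in 32-bit lanes and widened to 64 bits for the
// final reduction.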
static uint64_t mse_4xh_16bit_avx2(uint8_t *dst, int dstride, uint16_t *src,
                                   int sstride, int h) {
  uint64_t sum = 0;
  __m128i dst0_4x8, dst1_4x8, dst2_4x8, dst3_4x8, dst_16x8;
  __m128i src0_4x16, src1_4x16, src2_4x16, src3_4x16;
  __m256i src0_8x16, src1_8x16, dst_16x16, src_16x16;
  __m256i res0_4x64, res1_4x64;
  __m256i sub_result;
  const __m256i zeros = _mm256_broadcastsi128_si256(_mm_setzero_si128());
  __m256i square_result = _mm256_broadcastsi128_si256(_mm_setzero_si128());
  for (int i = 0; i < h; i += 4) {
    dst0_4x8 = _mm_cvtsi32_si128(*(int const *)(&dst[(i + 0) * dstride]));
    dst1_4x8 = _mm_cvtsi32_si128(*(int const *)(&dst[(i + 1) * dstride]));
    dst2_4x8 = _mm_cvtsi32_si128(*(int const *)(&dst[(i + 2) * dstride]));
    dst3_4x8 = _mm_cvtsi32_si128(*(int const *)(&dst[(i + 3) * dstride]));
    dst_16x8 = _mm_unpacklo_epi64(_mm_unpacklo_epi32(dst0_4x8, dst1_4x8),
                                  _mm_unpacklo_epi32(dst2_4x8, dst3_4x8));
    dst_16x16 = _mm256_cvtepu8_epi16(dst_16x8);

    src0_4x16 = _mm_loadl_epi64((__m128i const *)(&src[(i + 0) * sstride]));
    src1_4x16 = _mm_loadl_epi64((__m128i const *)(&src[(i + 1) * sstride]));
    src2_4x16 = _mm_loadl_epi64((__m128i const *)(&src[(i + 2) * sstride]));
    src3_4x16 = _mm_loadl_epi64((__m128i const *)(&src[(i + 3) * sstride]));
    src0_8x16 =
        _mm256_castsi128_si256(_mm_unpacklo_epi64(src0_4x16, src1_4x16));
    src1_8x16 =
        _mm256_castsi128_si256(_mm_unpacklo_epi64(src2_4x16, src3_4x16));
    src_16x16 = _mm256_permute2x128_si256(src0_8x16, src1_8x16, 0x20);

    // r15 r14 r13------------r1 r0 - 16 bit
    sub_result = _mm256_abs_epi16(_mm256_sub_epi16(src_16x16, dst_16x16));

    // s7 s6 s5 s4 s3 s2 s1 s0 - 32bit
    src_16x16 = _mm256_madd_epi16(sub_result, sub_result);

    // accumulation of result
    square_result = _mm256_add_epi32(square_result, src_16x16);
  }

  // s5 s4 s1 s0 - 64bit
  res0_4x64 = _mm256_unpacklo_epi32(square_result, zeros);
  // s7 s6 s3 s2 - 64bit
  res1_4x64 = _mm256_unpackhi_epi32(square_result, zeros);
  // r3 r2 r1 r0 - 64bit
  res0_4x64 = _mm256_add_epi64(res0_4x64, res1_4x64);
  // r1+r3 r2+r0 - 64bit
  const __m128i sum_1x64 =
      _mm_add_epi64(_mm256_castsi256_si128(res0_4x64),
                    _mm256_extracti128_si256(res0_4x64, 1));
  xx_storel_64(&sum, _mm_add_epi64(sum_1x64, _mm_srli_si128(sum_1x64, 8)));
  return sum;
}

// Compute mse of four consecutive 4x4 blocks.
// In src buffer, each 4x4 block in a 32x32 filter block is stored sequentially.
// Hence src_blk_stride is same as block width. Whereas dst buffer is a frame
// buffer, thus dstride is a frame level stride.
static uint64_t mse_4xh_quad_16bit_avx2(uint8_t *dst, int dstride,
                                        uint16_t *src, int src_blk_stride,
                                        int h) {
  uint64_t sum = 0;
  __m128i dst0_16x8, dst1_16x8, dst2_16x8, dst3_16x8;
  __m256i dst0_16x16, dst1_16x16, dst2_16x16, dst3_16x16;
  __m256i res0_4x64, res1_4x64;
  __m256i sub_result_0, sub_result_1, sub_result_2, sub_result_3;
  const __m256i zeros = _mm256_broadcastsi128_si256(_mm_setzero_si128());
  __m256i square_result = zeros;
  uint16_t *src_temp = src;

  for (int i = 0; i < h; i += 4) {
    dst0_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 0) * dstride]));
    dst1_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 1) * dstride]));
    dst2_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 2) * dstride]));
    dst3_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 3) * dstride]));

    // row0 of 1st,2nd, 3rd and 4th 4x4 blocks - d00 d10 d20 d30
    dst0_16x16 = _mm256_cvtepu8_epi16(dst0_16x8);
    // row1 of 1st,2nd, 3rd and 4th 4x4 blocks - d01 d11 d21 d31
    dst1_16x16 = _mm256_cvtepu8_epi16(dst1_16x8);
    // row2 of 1st,2nd, 3rd and 4th 4x4 blocks - d02 d12 d22 d32
    dst2_16x16 = _mm256_cvtepu8_epi16(dst2_16x8);
    // row3 of 1st,2nd, 3rd and 4th 4x4 blocks - d03 d13 d23 d33
    dst3_16x16 = _mm256_cvtepu8_epi16(dst3_16x8);

    // All rows of 1st 4x4 block - r00 r01 r02 r03
    __m256i src0_16x16 = _mm256_loadu_si256((__m256i const *)(&src_temp[0]));
    // All rows of 2nd 4x4 block - r10 r11 r12 r13
    __m256i src1_16x16 =
        _mm256_loadu_si256((__m256i const *)(&src_temp[src_blk_stride]));
    // All rows of 3rd 4x4 block - r20 r21 r22 r23
    __m256i src2_16x16 =
        _mm256_loadu_si256((__m256i const *)(&src_temp[2 * src_blk_stride]));
    // All rows of 4th 4x4 block - r30 r31 r32 r33
    __m256i src3_16x16 =
        _mm256_loadu_si256((__m256i const *)(&src_temp[3 * src_blk_stride]));

    // r00 r10 r02 r12
    __m256i tmp0_16x16 = _mm256_unpacklo_epi64(src0_16x16, src1_16x16);
    // r01 r11 r03 r13
    __m256i tmp1_16x16 = _mm256_unpackhi_epi64(src0_16x16, src1_16x16);
    // r20 r30 r22 r32
    __m256i tmp2_16x16 = _mm256_unpacklo_epi64(src2_16x16, src3_16x16);
    // r21 r31 r23 r33
    __m256i tmp3_16x16 = _mm256_unpackhi_epi64(src2_16x16, src3_16x16);

    // r00 r10 r20 r30
    src0_16x16 = _mm256_permute2f128_si256(tmp0_16x16, tmp2_16x16, 0x20);
    // r01 r11 r21 r31
    src1_16x16 = _mm256_permute2f128_si256(tmp1_16x16, tmp3_16x16, 0x20);
    // r02 r12 r22 r32
    src2_16x16 = _mm256_permute2f128_si256(tmp0_16x16, tmp2_16x16, 0x31);
    // r03 r13 r23 r33
    src3_16x16 = _mm256_permute2f128_si256(tmp1_16x16, tmp3_16x16, 0x31);

    // r15 r14 r13------------r1 r0 - 16 bit
    sub_result_0 = _mm256_abs_epi16(_mm256_sub_epi16(src0_16x16, dst0_16x16));
    sub_result_1 = _mm256_abs_epi16(_mm256_sub_epi16(src1_16x16, dst1_16x16));
    sub_result_2 = _mm256_abs_epi16(_mm256_sub_epi16(src2_16x16, dst2_16x16));
    sub_result_3 = _mm256_abs_epi16(_mm256_sub_epi16(src3_16x16, dst3_16x16));

    // s7 s6 s5 s4 s3 s2 s1 s0 - 32bit
    src0_16x16 = _mm256_madd_epi16(sub_result_0, sub_result_0);
    src1_16x16 = _mm256_madd_epi16(sub_result_1, sub_result_1);
    src2_16x16 = _mm256_madd_epi16(sub_result_2, sub_result_2);
    src3_16x16 = _mm256_madd_epi16(sub_result_3, sub_result_3);

    // accumulation of result
    src0_16x16 = _mm256_add_epi32(src0_16x16, src1_16x16);
    src2_16x16 = _mm256_add_epi32(src2_16x16, src3_16x16);
    const __m256i square_result_0 = _mm256_add_epi32(src0_16x16, src2_16x16);
    square_result = _mm256_add_epi32(square_result, square_result_0);
    src_temp += 16;
  }

  // s5 s4 s1 s0 - 64bit
  res0_4x64 = _mm256_unpacklo_epi32(square_result, zeros);
  // s7 s6 s3 s2 - 64bit
  res1_4x64 = _mm256_unpackhi_epi32(square_result, zeros);
  // r3 r2 r1 r0 - 64bit
  res0_4x64 = _mm256_add_epi64(res0_4x64, res1_4x64);
  // r1+r3 r2+r0 - 64bit
  const __m128i sum_1x64 =
      _mm_add_epi64(_mm256_castsi256_si128(res0_4x64),
                    _mm256_extracti128_si256(res0_4x64, 1));
  xx_storel_64(&sum, _mm_add_epi64(sum_1x64, _mm_srli_si128(sum_1x64, 8)));
  return sum;
}

static uint64_t mse_8xh_16bit_avx2(uint8_t *dst, int dstride, uint16_t *src,
                                   int sstride, int h) {
  uint64_t sum = 0;
  __m128i dst0_8x8, dst1_8x8, dst3_16x8;
  __m256i src0_8x16, src1_8x16, src_16x16, dst_16x16;
  __m256i res0_4x64, res1_4x64;
  __m256i sub_result;
  const __m256i zeros = _mm256_broadcastsi128_si256(_mm_setzero_si128());
  __m256i square_result = _mm256_broadcastsi128_si256(_mm_setzero_si128());

  for (int i = 0; i < h; i += 2) {
    dst0_8x8 = _mm_loadl_epi64((__m128i const *)(&dst[(i + 0) * dstride]));
    dst1_8x8 = _mm_loadl_epi64((__m128i const *)(&dst[(i + 1) * dstride]));
    dst3_16x8 = _mm_unpacklo_epi64(dst0_8x8, dst1_8x8);
    dst_16x16 = _mm256_cvtepu8_epi16(dst3_16x8);

    src0_8x16 =
        _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)&src[i * sstride]));
    src1_8x16 = _mm256_castsi128_si256(
        _mm_loadu_si128((__m128i *)&src[(i + 1) * sstride]));
    src_16x16 = _mm256_permute2x128_si256(src0_8x16, src1_8x16, 0x20);

    // r15 r14 r13 - - - r1 r0 - 16 bit
    sub_result = _mm256_abs_epi16(_mm256_sub_epi16(src_16x16, dst_16x16));

    // s7 s6 s5 s4 s3 s2 s1 s0 - 32bit
    src_16x16 = _mm256_madd_epi16(sub_result, sub_result);

    // accumulation of result
    square_result = _mm256_add_epi32(square_result, src_16x16);
  }

  // s5 s4 s1 s0 - 64bit
  res0_4x64 = _mm256_unpacklo_epi32(square_result, zeros);
  // s7 s6 s3 s2 - 64bit
  res1_4x64 = _mm256_unpackhi_epi32(square_result, zeros);
  // r3 r2 r1 r0 - 64bit
  res0_4x64 = _mm256_add_epi64(res0_4x64, res1_4x64);
  // r1+r3 r2+r0 - 64bit
  const __m128i sum_1x64 =
      _mm_add_epi64(_mm256_castsi256_si128(res0_4x64),
                    _mm256_extracti128_si256(res0_4x64, 1));
  xx_storel_64(&sum, _mm_add_epi64(sum_1x64, _mm_srli_si128(sum_1x64, 8)));
  return sum;
}

// Compute mse of two consecutive 8x8 blocks.
// In src buffer, each 8x8 block in a 64x64 filter block is stored sequentially.
// Hence src_blk_stride is same as block width. Whereas dst buffer is a frame
// buffer, thus dstride is a frame level stride.
static uint64_t mse_8xh_dual_16bit_avx2(uint8_t *dst, int dstride,
                                        uint16_t *src, int src_blk_stride,
                                        int h) {
  uint64_t sum = 0;
  __m128i dst0_16x8, dst1_16x8;
  __m256i dst0_16x16, dst1_16x16;
  __m256i res0_4x64, res1_4x64;
  __m256i sub_result_0, sub_result_1;
  const __m256i zeros = _mm256_broadcastsi128_si256(_mm_setzero_si128());
  __m256i square_result = zeros;
  uint16_t *src_temp = src;

  for (int i = 0; i < h; i += 2) {
    dst0_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 0) * dstride]));
    dst1_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 1) * dstride]));

    // row0 of 1st and 2nd 8x8 block - d00 d10
    dst0_16x16 = _mm256_cvtepu8_epi16(dst0_16x8);
    // row1 of 1st and 2nd 8x8 block - d01 d11
    dst1_16x16 = _mm256_cvtepu8_epi16(dst1_16x8);

    // 2 rows of 1st 8x8 block - r00 r01
    __m256i src0_16x16 = _mm256_loadu_si256((__m256i const *)(&src_temp[0]));
    // 2 rows of 2nd 8x8 block - r10 r11
    __m256i src1_16x16 =
        _mm256_loadu_si256((__m256i const *)(&src_temp[src_blk_stride]));
    // r00 r10 - 128bit
    __m256i tmp0_16x16 =
        _mm256_permute2f128_si256(src0_16x16, src1_16x16, 0x20);
    // r01 r11 - 128bit
    __m256i tmp1_16x16 =
        _mm256_permute2f128_si256(src0_16x16, src1_16x16, 0x31);

    // r15 r14 r13------------r1 r0 - 16 bit
    sub_result_0 = _mm256_abs_epi16(_mm256_sub_epi16(tmp0_16x16, dst0_16x16));
    sub_result_1 = _mm256_abs_epi16(_mm256_sub_epi16(tmp1_16x16, dst1_16x16));

    // s7 s6 s5 s4 s3 s2 s1 s0 - 32bit each
    src0_16x16 = _mm256_madd_epi16(sub_result_0, sub_result_0);
    src1_16x16 = _mm256_madd_epi16(sub_result_1, sub_result_1);

    // accumulation of result
    src0_16x16 = _mm256_add_epi32(src0_16x16, src1_16x16);
    square_result = _mm256_add_epi32(square_result, src0_16x16);
    src_temp += 16;
  }

  // s5 s4 s1 s0 - 64bit
  res0_4x64 = _mm256_unpacklo_epi32(square_result, zeros);
  // s7 s6 s3 s2 - 64bit
  res1_4x64 = _mm256_unpackhi_epi32(square_result, zeros);
  // r3 r2 r1 r0 - 64bit
  res0_4x64 = _mm256_add_epi64(res0_4x64, res1_4x64);
  // r1+r3 r2+r0 - 64bit
  const __m128i sum_1x64 =
      _mm_add_epi64(_mm256_castsi256_si128(res0_4x64),
                    _mm256_extracti128_si256(res0_4x64, 1));
  xx_storel_64(&sum, _mm_add_epi64(sum_1x64, _mm_srli_si128(sum_1x64, 8)));
  return sum;
}

uint64_t aom_mse_wxh_16bit_avx2(uint8_t *dst, int dstride, uint16_t *src,
                                int sstride, int w, int h) {
  assert((w == 8 || w == 4) && (h == 8 || h == 4) &&
         "w=8/4 and h=8/4 must be satisfied");
  switch (w) {
    case 4: return mse_4xh_16bit_avx2(dst, dstride, src, sstride, h);
    case 8: return mse_8xh_16bit_avx2(dst, dstride, src, sstride, h);
    default: assert(0 && "unsupported width"); return -1;
  }
}

// Computes mse of two 8x8 or four 4x4 consecutive blocks. Luma plane uses 8x8
// block and Chroma uses 4x4 block. In src buffer, each block in a filter block
// is stored sequentially. Hence src_blk_stride is same as block width. Whereas
// dst buffer is a frame buffer, thus dstride is a frame level stride.
uint64_t aom_mse_16xh_16bit_avx2(uint8_t *dst, int dstride, uint16_t *src,
                                 int w, int h) {
  assert((w == 8 || w == 4) && (h == 8 || h == 4) &&
         "w=8/4 and h=8/4 must be satisfied");
  switch (w) {
    case 4: return mse_4xh_quad_16bit_avx2(dst, dstride, src, w * h, h);
    case 8: return mse_8xh_dual_16bit_avx2(dst, dstride, src, w * h, h);
    default: assert(0 && "unsupported width"); return -1;
  }
}

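// Accumulate the SSE and sum of one 32-pixel row into the lane-wise
// accumulators sse_8x16[] and sum_8x16[], using the same (1, -1) maddubs
// trick as variance_kernel_avx2.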
static INLINE void calc_sum_sse_wd32_avx2(const uint8_t *src,
                                          const uint8_t *ref,
                                          __m256i set_one_minusone,
                                          __m256i sse_8x16[2],
                                          __m256i sum_8x16[2]) {
  const __m256i s00_256 = _mm256_loadu_si256((__m256i const *)(src));
  const __m256i r00_256 = _mm256_loadu_si256((__m256i const *)(ref));

  const __m256i u_low_256 = _mm256_unpacklo_epi8(s00_256, r00_256);
  const __m256i u_high_256 = _mm256_unpackhi_epi8(s00_256, r00_256);

  const __m256i diff0 = _mm256_maddubs_epi16(u_low_256, set_one_minusone);
  const __m256i diff1 = _mm256_maddubs_epi16(u_high_256, set_one_minusone);

  sse_8x16[0] = _mm256_add_epi32(sse_8x16[0], _mm256_madd_epi16(diff0, diff0));
  sse_8x16[1] = _mm256_add_epi32(sse_8x16[1], _mm256_madd_epi16(diff1, diff1));
  sum_8x16[0] = _mm256_add_epi16(sum_8x16[0], diff0);
  sum_8x16[1] = _mm256_add_epi16(sum_8x16[1], diff1);
}

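// Horizontally reduce the lane-wise SSE/sum accumulators into per-block
// results, add the grand totals to *tot_sse and *tot_sum, and return the
// packed vector {s0 s1 s2 s3 | d0 d1 d2 d3} of per-block SSEs and sums.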
static INLINE __m256i calc_sum_sse_order(__m256i *sse_hx16, __m256i *sum_hx16,
                                         unsigned int *tot_sse, int *tot_sum) {
  // s00 s01 s10 s11 s20 s21 s30 s31
  const __m256i sse_results = _mm256_hadd_epi32(sse_hx16[0], sse_hx16[1]);
  // d00 d01 d02 d03 | d10 d11 d12 d13 | d20 d21 d22 d23 | d30 d31 d32 d33
  const __m256i sum_result_r0 = _mm256_hadd_epi16(sum_hx16[0], sum_hx16[1]);
  // d00 d01 d10 d11 | d00 d02 d10 d11 | d20 d21 d30 d31 | d20 d21 d30 d31
  const __m256i sum_result_1 = _mm256_hadd_epi16(sum_result_r0, sum_result_r0);
  // d00 d01 d10 d11 d20 d21 d30 d31 | X
  const __m256i sum_result_3 = _mm256_permute4x64_epi64(sum_result_1, 0x08);
  // d00 d01 d10 d11 d20 d21 d30 d31
  const __m256i sum_results =
      _mm256_cvtepi16_epi32(_mm256_castsi256_si128(sum_result_3));

  // Add sum & sse registers appropriately to get total sum & sse separately.
  // s0 s1 d0 d1 s2 s3 d2 d3
  const __m256i sum_sse_add = _mm256_hadd_epi32(sse_results, sum_results);
  // s0 s1 s2 s3 d0 d1 d2 d3
  const __m256i sum_sse_order_add = _mm256_permute4x64_epi64(sum_sse_add, 0xd8);
  // s0+s1 s2+s3 s0+s1 s2+s3 d0+d1 d2+d3 d0+d1 d2+d3
  const __m256i sum_sse_order_add_1 =
      _mm256_hadd_epi32(sum_sse_order_add, sum_sse_order_add);
  // s0 x x x | d0 x x x
  const __m256i sum_sse_order_add_final =
      _mm256_hadd_epi32(sum_sse_order_add_1, sum_sse_order_add_1);
  // s0
  const uint32_t first_value =
      (uint32_t)_mm256_extract_epi32(sum_sse_order_add_final, 0);
  *tot_sse += first_value;
  // d0
  const int second_value = _mm256_extract_epi32(sum_sse_order_add_final, 4);
  *tot_sum += second_value;
  return sum_sse_order_add;
}

static INLINE void get_var_sse_sum_8x8_quad_avx2(
    const uint8_t *src, int src_stride, const uint8_t *ref,
    const int ref_stride, const int h, uint32_t *sse8x8, int *sum8x8,
    unsigned int *tot_sse, int *tot_sum, uint32_t *var8x8) {
  assert(h <= 128);  // May overflow for larger height.
  __m256i sse_8x16[2], sum_8x16[2];
  sum_8x16[0] = _mm256_setzero_si256();
  sse_8x16[0] = _mm256_setzero_si256();
  sum_8x16[1] = sum_8x16[0];
  sse_8x16[1] = sse_8x16[0];
  const __m256i set_one_minusone = _mm256_set1_epi16((short)0xff01);

  for (int i = 0; i < h; i++) {
    // Process 8x32 block of one row.
    calc_sum_sse_wd32_avx2(src, ref, set_one_minusone, sse_8x16, sum_8x16);
    src += src_stride;
    ref += ref_stride;
  }

  const __m256i sum_sse_order_add =
      calc_sum_sse_order(sse_8x16, sum_8x16, tot_sse, tot_sum);

  // s0 s1 s2 s3
  _mm_storeu_si128((__m128i *)sse8x8,
                   _mm256_castsi256_si128(sum_sse_order_add));
  // d0 d1 d2 d3
  const __m128i sum_temp8x8 = _mm256_extractf128_si256(sum_sse_order_add, 1);
  _mm_storeu_si128((__m128i *)sum8x8, sum_temp8x8);

  // (d0xd0 >> 6)=f0 (d1xd1 >> 6)=f1 (d2xd2 >> 6)=f2 (d3xd3 >> 6)=f3
  const __m128i mull_results =
      _mm_srli_epi32(_mm_mullo_epi32(sum_temp8x8, sum_temp8x8), 6);
  // s0-f0=v0 s1-f1=v1 s2-f2=v2 s3-f3=v3
  const __m128i variance_8x8 =
      _mm_sub_epi32(_mm256_castsi256_si128(sum_sse_order_add), mull_results);
  // v0 v1 v2 v3
  _mm_storeu_si128((__m128i *)var8x8, variance_8x8);
}

static INLINE void get_var_sse_sum_16x16_dual_avx2(
    const uint8_t *src, int src_stride, const uint8_t *ref,
    const int ref_stride, const int h, uint32_t *sse16x16,
    unsigned int *tot_sse, int *tot_sum, uint32_t *var16x16) {
  assert(h <= 128);  // May overflow for larger height.
  __m256i sse_16x16[2], sum_16x16[2];
  sum_16x16[0] = _mm256_setzero_si256();
  sse_16x16[0] = _mm256_setzero_si256();
  sum_16x16[1] = sum_16x16[0];
  sse_16x16[1] = sse_16x16[0];
  const __m256i set_one_minusone = _mm256_set1_epi16((short)0xff01);

  for (int i = 0; i < h; i++) {
    // Process 16x32 block of one row.
    calc_sum_sse_wd32_avx2(src, ref, set_one_minusone, sse_16x16, sum_16x16);
    src += src_stride;
    ref += ref_stride;
  }

  const __m256i sum_sse_order_add =
      calc_sum_sse_order(sse_16x16, sum_16x16, tot_sse, tot_sum);

  const __m256i sum_sse_order_add_1 =
      _mm256_hadd_epi32(sum_sse_order_add, sum_sse_order_add);

  // s0+s1 s2+s3 x x
  _mm_storel_epi64((__m128i *)sse16x16,
                   _mm256_castsi256_si128(sum_sse_order_add_1));

  // d0+d1 d2+d3 x x
  const __m128i sum_temp16x16 =
      _mm256_extractf128_si256(sum_sse_order_add_1, 1);

  // ((d0+d1)x(d0+d1) >> 8)=f0 ((d2+d3)x(d2+d3) >> 8)=f1
  const __m128i mull_results =
      _mm_srli_epi32(_mm_mullo_epi32(sum_temp16x16, sum_temp16x16), 8);

  // (s0+s1)-f0=v0 (s2+s3)-f1=v1
  const __m128i variance_16x16 =
      _mm_sub_epi32(_mm256_castsi256_si128(sum_sse_order_add_1), mull_results);

  // v0 v1
  _mm_storel_epi64((__m128i *)var16x16, variance_16x16);
}

void aom_get_var_sse_sum_8x8_quad_avx2(const uint8_t *src_ptr,
                                       int source_stride,
                                       const uint8_t *ref_ptr, int ref_stride,
                                       uint32_t *sse8x8, int *sum8x8,
                                       unsigned int *tot_sse, int *tot_sum,
                                       uint32_t *var8x8) {
  get_var_sse_sum_8x8_quad_avx2(src_ptr, source_stride, ref_ptr, ref_stride, 8,
                                sse8x8, sum8x8, tot_sse, tot_sum, var8x8);
}

void aom_get_var_sse_sum_16x16_dual_avx2(const uint8_t *src_ptr,
                                         int source_stride,
                                         const uint8_t *ref_ptr, int ref_stride,
                                         uint32_t *sse16x16,
                                         unsigned int *tot_sse, int *tot_sum,
                                         uint32_t *var16x16) {
  get_var_sse_sum_16x16_dual_avx2(src_ptr, source_stride, ref_ptr, ref_stride,
                                  16, sse16x16, tot_sse, tot_sum, var16x16);
}