/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <immintrin.h>

#include "config/av1_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/blend.h"
#include "aom_dsp/x86/synonyms.h"
#include "aom_dsp/x86/synonyms_avx2.h"
#include "av1/common/blockd.h"

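// Per-pixel mask for 8-bit sources: m = |mask_base + (|s0 - s1| >> 4)|.
// With mask_base = 38 this is the DIFFWTD_38 mask directly; with
// mask_base = 38 - AOM_BLEND_A64_MAX_ALPHA the outer abs yields the inverted
// (DIFFWTD_38_INV) mask, so one helper serves both mask types.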
static INLINE __m256i calc_mask_avx2(const __m256i mask_base, const __m256i s0,
                                     const __m256i s1) {
  const __m256i diff = _mm256_abs_epi16(_mm256_sub_epi16(s0, s1));
  return _mm256_abs_epi16(
      _mm256_add_epi16(mask_base, _mm256_srli_epi16(diff, 4)));
  // clamp(diff, 0, 64) can be skipped because diff is always in the range
  // (38, 54).
}
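
// Builds the DIFFWTD compound mask from two 8-bit predictors. The block
// width selects the load/pack strategy: w == 4 and w == 8 gather four rows
// per iteration, w == 16 two rows, and wider blocks are processed 32 pixels
// at a time.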
void av1_build_compound_diffwtd_mask_avx2(uint8_t *mask,
                                          DIFFWTD_MASK_TYPE mask_type,
                                          const uint8_t *src0, int stride0,
                                          const uint8_t *src1, int stride1,
                                          int h, int w) {
  const int mb = (mask_type == DIFFWTD_38_INV) ? AOM_BLEND_A64_MAX_ALPHA : 0;
  const __m256i y_mask_base = _mm256_set1_epi16(38 - mb);
  int i = 0;
  if (4 == w) {
    do {
      const __m128i s0A = xx_loadl_32(src0);
      const __m128i s0B = xx_loadl_32(src0 + stride0);
      const __m128i s0C = xx_loadl_32(src0 + stride0 * 2);
      const __m128i s0D = xx_loadl_32(src0 + stride0 * 3);
      const __m128i s0AB = _mm_unpacklo_epi32(s0A, s0B);
      const __m128i s0CD = _mm_unpacklo_epi32(s0C, s0D);
      const __m128i s0ABCD = _mm_unpacklo_epi64(s0AB, s0CD);
      const __m256i s0ABCD_w = _mm256_cvtepu8_epi16(s0ABCD);

      const __m128i s1A = xx_loadl_32(src1);
      const __m128i s1B = xx_loadl_32(src1 + stride1);
      const __m128i s1C = xx_loadl_32(src1 + stride1 * 2);
      const __m128i s1D = xx_loadl_32(src1 + stride1 * 3);
      const __m128i s1AB = _mm_unpacklo_epi32(s1A, s1B);
      const __m128i s1CD = _mm_unpacklo_epi32(s1C, s1D);
      const __m128i s1ABCD = _mm_unpacklo_epi64(s1AB, s1CD);
      const __m256i s1ABCD_w = _mm256_cvtepu8_epi16(s1ABCD);
      const __m256i m16 = calc_mask_avx2(y_mask_base, s0ABCD_w, s1ABCD_w);
      const __m256i m8 = _mm256_packus_epi16(m16, _mm256_setzero_si256());
      const __m128i x_m8 =
          _mm256_castsi256_si128(_mm256_permute4x64_epi64(m8, 0xd8));
      xx_storeu_128(mask, x_m8);
      src0 += (stride0 << 2);
      src1 += (stride1 << 2);
      mask += 16;
      i += 4;
    } while (i < h);
  } else if (8 == w) {
    do {
      const __m128i s0A = xx_loadl_64(src0);
      const __m128i s0B = xx_loadl_64(src0 + stride0);
      const __m128i s0C = xx_loadl_64(src0 + stride0 * 2);
      const __m128i s0D = xx_loadl_64(src0 + stride0 * 3);
      const __m256i s0AC_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(s0A, s0C));
      const __m256i s0BD_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(s0B, s0D));
      const __m128i s1A = xx_loadl_64(src1);
      const __m128i s1B = xx_loadl_64(src1 + stride1);
      const __m128i s1C = xx_loadl_64(src1 + stride1 * 2);
      const __m128i s1D = xx_loadl_64(src1 + stride1 * 3);
      const __m256i s1AC_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(s1A, s1C));
      const __m256i s1BD_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(s1B, s1D));
      const __m256i m16AC = calc_mask_avx2(y_mask_base, s0AC_w, s1AC_w);
      const __m256i m16BD = calc_mask_avx2(y_mask_base, s0BD_w, s1BD_w);
      const __m256i m8 = _mm256_packus_epi16(m16AC, m16BD);
      yy_storeu_256(mask, m8);
      src0 += stride0 << 2;
      src1 += stride1 << 2;
      mask += 32;
      i += 4;
    } while (i < h);
  } else if (16 == w) {
    do {
      const __m128i s0A = xx_load_128(src0);
      const __m128i s0B = xx_load_128(src0 + stride0);
      const __m128i s1A = xx_load_128(src1);
      const __m128i s1B = xx_load_128(src1 + stride1);
      const __m256i s0AL = _mm256_cvtepu8_epi16(s0A);
      const __m256i s0BL = _mm256_cvtepu8_epi16(s0B);
      const __m256i s1AL = _mm256_cvtepu8_epi16(s1A);
      const __m256i s1BL = _mm256_cvtepu8_epi16(s1B);

      const __m256i m16AL = calc_mask_avx2(y_mask_base, s0AL, s1AL);
      const __m256i m16BL = calc_mask_avx2(y_mask_base, s0BL, s1BL);

      const __m256i m8 =
          _mm256_permute4x64_epi64(_mm256_packus_epi16(m16AL, m16BL), 0xd8);
      yy_storeu_256(mask, m8);
      src0 += stride0 << 1;
      src1 += stride1 << 1;
      mask += 32;
      i += 2;
    } while (i < h);
  } else {
    do {
      int j = 0;
      do {
        const __m256i s0 = yy_loadu_256(src0 + j);
        const __m256i s1 = yy_loadu_256(src1 + j);
        const __m256i s0L = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(s0));
        const __m256i s1L = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(s1));
        const __m256i s0H =
            _mm256_cvtepu8_epi16(_mm256_extracti128_si256(s0, 1));
        const __m256i s1H =
            _mm256_cvtepu8_epi16(_mm256_extracti128_si256(s1, 1));
        const __m256i m16L = calc_mask_avx2(y_mask_base, s0L, s1L);
        const __m256i m16H = calc_mask_avx2(y_mask_base, s0H, s1H);
        const __m256i m8 =
            _mm256_permute4x64_epi64(_mm256_packus_epi16(m16L, m16H), 0xd8);
        yy_storeu_256(mask + j, m8);
        j += 32;
      } while (j < w);
      src0 += stride0;
      src1 += stride1;
      mask += w;
      i += 1;
    } while (i < h);
  }
}

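// Same mask computation for 16-bit convolve-buffer inputs: the absolute
// difference is rounded by 'round', scaled by DIFF_FACTOR, offset by the
// mask base and clamped to AOM_BLEND_A64_MAX_ALPHA.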
static INLINE __m256i calc_mask_d16_avx2(const __m256i *data_src0,
                                         const __m256i *data_src1,
                                         const __m256i *round_const,
                                         const __m256i *mask_base_16,
                                         const __m256i *clip_diff, int round) {
  const __m256i diffa = _mm256_subs_epu16(*data_src0, *data_src1);
  const __m256i diffb = _mm256_subs_epu16(*data_src1, *data_src0);
  const __m256i diff = _mm256_max_epu16(diffa, diffb);
  const __m256i diff_round =
      _mm256_srli_epi16(_mm256_adds_epu16(diff, *round_const), round);
  const __m256i diff_factor = _mm256_srli_epi16(diff_round, DIFF_FACTOR_LOG2);
  const __m256i diff_mask = _mm256_adds_epi16(diff_factor, *mask_base_16);
  const __m256i diff_clamp = _mm256_min_epi16(diff_mask, *clip_diff);
  return diff_clamp;
}

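// Inverted variant of calc_mask_d16_avx2: returns
// AOM_BLEND_A64_MAX_ALPHA - mask for the DIFFWTD_38_INV mask type.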
static INLINE __m256i calc_mask_d16_inv_avx2(const __m256i *data_src0,
                                             const __m256i *data_src1,
                                             const __m256i *round_const,
                                             const __m256i *mask_base_16,
                                             const __m256i *clip_diff,
                                             int round) {
  const __m256i diffa = _mm256_subs_epu16(*data_src0, *data_src1);
  const __m256i diffb = _mm256_subs_epu16(*data_src1, *data_src0);
  const __m256i diff = _mm256_max_epu16(diffa, diffb);
  const __m256i diff_round =
      _mm256_srli_epi16(_mm256_adds_epu16(diff, *round_const), round);
  const __m256i diff_factor = _mm256_srli_epi16(diff_round, DIFF_FACTOR_LOG2);
  const __m256i diff_mask = _mm256_adds_epi16(diff_factor, *mask_base_16);
  const __m256i diff_clamp = _mm256_min_epi16(diff_mask, *clip_diff);
  const __m256i diff_const_16 = _mm256_sub_epi16(*clip_diff, diff_clamp);
  return diff_const_16;
}

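// DIFFWTD_38 mask from 16-bit convolve-buffer predictors, specialized per
// block width (4, 8, 16, 32 and 64 columns, with a final branch that
// handles 128-wide rows).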
static INLINE void build_compound_diffwtd_mask_d16_avx2(
    uint8_t *mask, const CONV_BUF_TYPE *src0, int src0_stride,
    const CONV_BUF_TYPE *src1, int src1_stride, int h, int w, int shift) {
  const int mask_base = 38;
  const __m256i _r = _mm256_set1_epi16((1 << shift) >> 1);
  const __m256i y38 = _mm256_set1_epi16(mask_base);
  const __m256i y64 = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
  int i = 0;
  if (w == 4) {
    do {
      const __m128i s0A = xx_loadl_64(src0);
      const __m128i s0B = xx_loadl_64(src0 + src0_stride);
      const __m128i s0C = xx_loadl_64(src0 + src0_stride * 2);
      const __m128i s0D = xx_loadl_64(src0 + src0_stride * 3);
      const __m128i s1A = xx_loadl_64(src1);
      const __m128i s1B = xx_loadl_64(src1 + src1_stride);
      const __m128i s1C = xx_loadl_64(src1 + src1_stride * 2);
      const __m128i s1D = xx_loadl_64(src1 + src1_stride * 3);
      const __m256i s0 = yy_set_m128i(_mm_unpacklo_epi64(s0C, s0D),
                                      _mm_unpacklo_epi64(s0A, s0B));
      const __m256i s1 = yy_set_m128i(_mm_unpacklo_epi64(s1C, s1D),
                                      _mm_unpacklo_epi64(s1A, s1B));
      const __m256i m16 = calc_mask_d16_avx2(&s0, &s1, &_r, &y38, &y64, shift);
      const __m256i m8 = _mm256_packus_epi16(m16, _mm256_setzero_si256());
      xx_storeu_128(mask,
                    _mm256_castsi256_si128(_mm256_permute4x64_epi64(m8, 0xd8)));
      src0 += src0_stride << 2;
      src1 += src1_stride << 2;
      mask += 16;
      i += 4;
    } while (i < h);
  } else if (w == 8) {
    do {
      const __m256i s0AB = yy_loadu2_128(src0 + src0_stride, src0);
      const __m256i s0CD =
          yy_loadu2_128(src0 + src0_stride * 3, src0 + src0_stride * 2);
      const __m256i s1AB = yy_loadu2_128(src1 + src1_stride, src1);
      const __m256i s1CD =
          yy_loadu2_128(src1 + src1_stride * 3, src1 + src1_stride * 2);
      const __m256i m16AB =
          calc_mask_d16_avx2(&s0AB, &s1AB, &_r, &y38, &y64, shift);
      const __m256i m16CD =
          calc_mask_d16_avx2(&s0CD, &s1CD, &_r, &y38, &y64, shift);
      const __m256i m8 = _mm256_packus_epi16(m16AB, m16CD);
      yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
      src0 += src0_stride << 2;
      src1 += src1_stride << 2;
      mask += 32;
      i += 4;
    } while (i < h);
  } else if (w == 16) {
    do {
      const __m256i s0A = yy_loadu_256(src0);
      const __m256i s0B = yy_loadu_256(src0 + src0_stride);
      const __m256i s1A = yy_loadu_256(src1);
      const __m256i s1B = yy_loadu_256(src1 + src1_stride);
      const __m256i m16A =
          calc_mask_d16_avx2(&s0A, &s1A, &_r, &y38, &y64, shift);
      const __m256i m16B =
          calc_mask_d16_avx2(&s0B, &s1B, &_r, &y38, &y64, shift);
      const __m256i m8 = _mm256_packus_epi16(m16A, m16B);
      yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
      src0 += src0_stride << 1;
      src1 += src1_stride << 1;
      mask += 32;
      i += 2;
    } while (i < h);
  } else if (w == 32) {
    do {
      const __m256i s0A = yy_loadu_256(src0);
      const __m256i s0B = yy_loadu_256(src0 + 16);
      const __m256i s1A = yy_loadu_256(src1);
      const __m256i s1B = yy_loadu_256(src1 + 16);
      const __m256i m16A =
          calc_mask_d16_avx2(&s0A, &s1A, &_r, &y38, &y64, shift);
      const __m256i m16B =
          calc_mask_d16_avx2(&s0B, &s1B, &_r, &y38, &y64, shift);
      const __m256i m8 = _mm256_packus_epi16(m16A, m16B);
      yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
      src0 += src0_stride;
      src1 += src1_stride;
      mask += 32;
      i += 1;
    } while (i < h);
  } else if (w == 64) {
    do {
      const __m256i s0A = yy_loadu_256(src0);
      const __m256i s0B = yy_loadu_256(src0 + 16);
      const __m256i s0C = yy_loadu_256(src0 + 32);
      const __m256i s0D = yy_loadu_256(src0 + 48);
      const __m256i s1A = yy_loadu_256(src1);
      const __m256i s1B = yy_loadu_256(src1 + 16);
      const __m256i s1C = yy_loadu_256(src1 + 32);
      const __m256i s1D = yy_loadu_256(src1 + 48);
      const __m256i m16A =
          calc_mask_d16_avx2(&s0A, &s1A, &_r, &y38, &y64, shift);
      const __m256i m16B =
          calc_mask_d16_avx2(&s0B, &s1B, &_r, &y38, &y64, shift);
      const __m256i m16C =
          calc_mask_d16_avx2(&s0C, &s1C, &_r, &y38, &y64, shift);
      const __m256i m16D =
          calc_mask_d16_avx2(&s0D, &s1D, &_r, &y38, &y64, shift);
      const __m256i m8AB = _mm256_packus_epi16(m16A, m16B);
      const __m256i m8CD = _mm256_packus_epi16(m16C, m16D);
      yy_storeu_256(mask, _mm256_permute4x64_epi64(m8AB, 0xd8));
      yy_storeu_256(mask + 32, _mm256_permute4x64_epi64(m8CD, 0xd8));
      src0 += src0_stride;
      src1 += src1_stride;
      mask += 64;
      i += 1;
    } while (i < h);
  } else {
    do {
      const __m256i s0A = yy_loadu_256(src0);
      const __m256i s0B = yy_loadu_256(src0 + 16);
      const __m256i s0C = yy_loadu_256(src0 + 32);
      const __m256i s0D = yy_loadu_256(src0 + 48);
      const __m256i s0E = yy_loadu_256(src0 + 64);
      const __m256i s0F = yy_loadu_256(src0 + 80);
      const __m256i s0G = yy_loadu_256(src0 + 96);
      const __m256i s0H = yy_loadu_256(src0 + 112);
      const __m256i s1A = yy_loadu_256(src1);
      const __m256i s1B = yy_loadu_256(src1 + 16);
      const __m256i s1C = yy_loadu_256(src1 + 32);
      const __m256i s1D = yy_loadu_256(src1 + 48);
      const __m256i s1E = yy_loadu_256(src1 + 64);
      const __m256i s1F = yy_loadu_256(src1 + 80);
      const __m256i s1G = yy_loadu_256(src1 + 96);
      const __m256i s1H = yy_loadu_256(src1 + 112);
      const __m256i m16A =
          calc_mask_d16_avx2(&s0A, &s1A, &_r, &y38, &y64, shift);
      const __m256i m16B =
          calc_mask_d16_avx2(&s0B, &s1B, &_r, &y38, &y64, shift);
      const __m256i m16C =
          calc_mask_d16_avx2(&s0C, &s1C, &_r, &y38, &y64, shift);
      const __m256i m16D =
          calc_mask_d16_avx2(&s0D, &s1D, &_r, &y38, &y64, shift);
      const __m256i m16E =
          calc_mask_d16_avx2(&s0E, &s1E, &_r, &y38, &y64, shift);
      const __m256i m16F =
          calc_mask_d16_avx2(&s0F, &s1F, &_r, &y38, &y64, shift);
      const __m256i m16G =
          calc_mask_d16_avx2(&s0G, &s1G, &_r, &y38, &y64, shift);
      const __m256i m16H =
          calc_mask_d16_avx2(&s0H, &s1H, &_r, &y38, &y64, shift);
      const __m256i m8AB = _mm256_packus_epi16(m16A, m16B);
      const __m256i m8CD = _mm256_packus_epi16(m16C, m16D);
      const __m256i m8EF = _mm256_packus_epi16(m16E, m16F);
      const __m256i m8GH = _mm256_packus_epi16(m16G, m16H);
      yy_storeu_256(mask, _mm256_permute4x64_epi64(m8AB, 0xd8));
      yy_storeu_256(mask + 32, _mm256_permute4x64_epi64(m8CD, 0xd8));
      yy_storeu_256(mask + 64, _mm256_permute4x64_epi64(m8EF, 0xd8));
      yy_storeu_256(mask + 96, _mm256_permute4x64_epi64(m8GH, 0xd8));
      src0 += src0_stride;
      src1 += src1_stride;
      mask += 128;
      i += 1;
    } while (i < h);
  }
}

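// As above, but produces the DIFFWTD_38_INV mask via calc_mask_d16_inv_avx2.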
static INLINE void build_compound_diffwtd_mask_d16_inv_avx2(
    uint8_t *mask, const CONV_BUF_TYPE *src0, int src0_stride,
    const CONV_BUF_TYPE *src1, int src1_stride, int h, int w, int shift) {
  const int mask_base = 38;
  const __m256i _r = _mm256_set1_epi16((1 << shift) >> 1);
  const __m256i y38 = _mm256_set1_epi16(mask_base);
  const __m256i y64 = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
  int i = 0;
  if (w == 4) {
    do {
      const __m128i s0A = xx_loadl_64(src0);
      const __m128i s0B = xx_loadl_64(src0 + src0_stride);
      const __m128i s0C = xx_loadl_64(src0 + src0_stride * 2);
      const __m128i s0D = xx_loadl_64(src0 + src0_stride * 3);
      const __m128i s1A = xx_loadl_64(src1);
      const __m128i s1B = xx_loadl_64(src1 + src1_stride);
      const __m128i s1C = xx_loadl_64(src1 + src1_stride * 2);
      const __m128i s1D = xx_loadl_64(src1 + src1_stride * 3);
      const __m256i s0 = yy_set_m128i(_mm_unpacklo_epi64(s0C, s0D),
                                      _mm_unpacklo_epi64(s0A, s0B));
      const __m256i s1 = yy_set_m128i(_mm_unpacklo_epi64(s1C, s1D),
                                      _mm_unpacklo_epi64(s1A, s1B));
      const __m256i m16 =
          calc_mask_d16_inv_avx2(&s0, &s1, &_r, &y38, &y64, shift);
      const __m256i m8 = _mm256_packus_epi16(m16, _mm256_setzero_si256());
      xx_storeu_128(mask,
                    _mm256_castsi256_si128(_mm256_permute4x64_epi64(m8, 0xd8)));
      src0 += src0_stride << 2;
      src1 += src1_stride << 2;
      mask += 16;
      i += 4;
    } while (i < h);
  } else if (w == 8) {
    do {
      const __m256i s0AB = yy_loadu2_128(src0 + src0_stride, src0);
      const __m256i s0CD =
          yy_loadu2_128(src0 + src0_stride * 3, src0 + src0_stride * 2);
      const __m256i s1AB = yy_loadu2_128(src1 + src1_stride, src1);
      const __m256i s1CD =
          yy_loadu2_128(src1 + src1_stride * 3, src1 + src1_stride * 2);
      const __m256i m16AB =
          calc_mask_d16_inv_avx2(&s0AB, &s1AB, &_r, &y38, &y64, shift);
      const __m256i m16CD =
          calc_mask_d16_inv_avx2(&s0CD, &s1CD, &_r, &y38, &y64, shift);
      const __m256i m8 = _mm256_packus_epi16(m16AB, m16CD);
      yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
      src0 += src0_stride << 2;
      src1 += src1_stride << 2;
      mask += 32;
      i += 4;
    } while (i < h);
  } else if (w == 16) {
    do {
      const __m256i s0A = yy_loadu_256(src0);
      const __m256i s0B = yy_loadu_256(src0 + src0_stride);
      const __m256i s1A = yy_loadu_256(src1);
      const __m256i s1B = yy_loadu_256(src1 + src1_stride);
      const __m256i m16A =
          calc_mask_d16_inv_avx2(&s0A, &s1A, &_r, &y38, &y64, shift);
      const __m256i m16B =
          calc_mask_d16_inv_avx2(&s0B, &s1B, &_r, &y38, &y64, shift);
      const __m256i m8 = _mm256_packus_epi16(m16A, m16B);
      yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
      src0 += src0_stride << 1;
      src1 += src1_stride << 1;
      mask += 32;
      i += 2;
    } while (i < h);
  } else if (w == 32) {
    do {
      const __m256i s0A = yy_loadu_256(src0);
      const __m256i s0B = yy_loadu_256(src0 + 16);
      const __m256i s1A = yy_loadu_256(src1);
      const __m256i s1B = yy_loadu_256(src1 + 16);
      const __m256i m16A =
          calc_mask_d16_inv_avx2(&s0A, &s1A, &_r, &y38, &y64, shift);
      const __m256i m16B =
          calc_mask_d16_inv_avx2(&s0B, &s1B, &_r, &y38, &y64, shift);
      const __m256i m8 = _mm256_packus_epi16(m16A, m16B);
      yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
      src0 += src0_stride;
      src1 += src1_stride;
      mask += 32;
      i += 1;
    } while (i < h);
  } else if (w == 64) {
    do {
      const __m256i s0A = yy_loadu_256(src0);
      const __m256i s0B = yy_loadu_256(src0 + 16);
      const __m256i s0C = yy_loadu_256(src0 + 32);
      const __m256i s0D = yy_loadu_256(src0 + 48);
      const __m256i s1A = yy_loadu_256(src1);
      const __m256i s1B = yy_loadu_256(src1 + 16);
      const __m256i s1C = yy_loadu_256(src1 + 32);
      const __m256i s1D = yy_loadu_256(src1 + 48);
      const __m256i m16A =
          calc_mask_d16_inv_avx2(&s0A, &s1A, &_r, &y38, &y64, shift);
      const __m256i m16B =
          calc_mask_d16_inv_avx2(&s0B, &s1B, &_r, &y38, &y64, shift);
      const __m256i m16C =
          calc_mask_d16_inv_avx2(&s0C, &s1C, &_r, &y38, &y64, shift);
      const __m256i m16D =
          calc_mask_d16_inv_avx2(&s0D, &s1D, &_r, &y38, &y64, shift);
      const __m256i m8AB = _mm256_packus_epi16(m16A, m16B);
      const __m256i m8CD = _mm256_packus_epi16(m16C, m16D);
      yy_storeu_256(mask, _mm256_permute4x64_epi64(m8AB, 0xd8));
      yy_storeu_256(mask + 32, _mm256_permute4x64_epi64(m8CD, 0xd8));
      src0 += src0_stride;
      src1 += src1_stride;
      mask += 64;
      i += 1;
    } while (i < h);
  } else {
    do {
      const __m256i s0A = yy_loadu_256(src0);
      const __m256i s0B = yy_loadu_256(src0 + 16);
      const __m256i s0C = yy_loadu_256(src0 + 32);
      const __m256i s0D = yy_loadu_256(src0 + 48);
      const __m256i s0E = yy_loadu_256(src0 + 64);
      const __m256i s0F = yy_loadu_256(src0 + 80);
      const __m256i s0G = yy_loadu_256(src0 + 96);
      const __m256i s0H = yy_loadu_256(src0 + 112);
      const __m256i s1A = yy_loadu_256(src1);
      const __m256i s1B = yy_loadu_256(src1 + 16);
      const __m256i s1C = yy_loadu_256(src1 + 32);
      const __m256i s1D = yy_loadu_256(src1 + 48);
      const __m256i s1E = yy_loadu_256(src1 + 64);
      const __m256i s1F = yy_loadu_256(src1 + 80);
      const __m256i s1G = yy_loadu_256(src1 + 96);
      const __m256i s1H = yy_loadu_256(src1 + 112);
      const __m256i m16A =
          calc_mask_d16_inv_avx2(&s0A, &s1A, &_r, &y38, &y64, shift);
      const __m256i m16B =
          calc_mask_d16_inv_avx2(&s0B, &s1B, &_r, &y38, &y64, shift);
      const __m256i m16C =
          calc_mask_d16_inv_avx2(&s0C, &s1C, &_r, &y38, &y64, shift);
      const __m256i m16D =
          calc_mask_d16_inv_avx2(&s0D, &s1D, &_r, &y38, &y64, shift);
      const __m256i m16E =
          calc_mask_d16_inv_avx2(&s0E, &s1E, &_r, &y38, &y64, shift);
      const __m256i m16F =
          calc_mask_d16_inv_avx2(&s0F, &s1F, &_r, &y38, &y64, shift);
      const __m256i m16G =
          calc_mask_d16_inv_avx2(&s0G, &s1G, &_r, &y38, &y64, shift);
      const __m256i m16H =
          calc_mask_d16_inv_avx2(&s0H, &s1H, &_r, &y38, &y64, shift);
      const __m256i m8AB = _mm256_packus_epi16(m16A, m16B);
      const __m256i m8CD = _mm256_packus_epi16(m16C, m16D);
      const __m256i m8EF = _mm256_packus_epi16(m16E, m16F);
      const __m256i m8GH = _mm256_packus_epi16(m16G, m16H);
      yy_storeu_256(mask, _mm256_permute4x64_epi64(m8AB, 0xd8));
      yy_storeu_256(mask + 32, _mm256_permute4x64_epi64(m8CD, 0xd8));
      yy_storeu_256(mask + 64, _mm256_permute4x64_epi64(m8EF, 0xd8));
      yy_storeu_256(mask + 96, _mm256_permute4x64_epi64(m8GH, 0xd8));
      src0 += src0_stride;
      src1 += src1_stride;
      mask += 128;
      i += 1;
    } while (i < h);
  }
}

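// Public entry point for the d16 (convolve-buffer) path: derives the
// rounding shift from the convolve parameters and bit depth, then dispatches
// on the mask type.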
void av1_build_compound_diffwtd_mask_d16_avx2(
    uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const CONV_BUF_TYPE *src0,
    int src0_stride, const CONV_BUF_TYPE *src1, int src1_stride, int h, int w,
    ConvolveParams *conv_params, int bd) {
  const int shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1 + (bd - 8);
  // When the rounding constant is added, there is a possibility of overflow.
  // However, that much precision is not required. The code should also work
  // for other values of DIFF_FACTOR_LOG2 and AOM_BLEND_A64_MAX_ALPHA, but
  // there is a possibility of corner-case bugs.
  assert(DIFF_FACTOR_LOG2 == 4);
  assert(AOM_BLEND_A64_MAX_ALPHA == 64);

  if (mask_type == DIFFWTD_38) {
    build_compound_diffwtd_mask_d16_avx2(mask, src0, src0_stride, src1,
                                         src1_stride, h, w, shift);
  } else {
    build_compound_diffwtd_mask_d16_inv_avx2(mask, src0, src0_stride, src1,
                                             src1_stride, h, w, shift);
  }
}

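// High-bitdepth variant. Blocks narrower than 16 pixels fall back to the
// SSSE3 implementation; otherwise rows are processed 16 pixels at a time,
// with the difference shifted by (bd - 8) extra bits so the mask is computed
// at 8-bit precision for all bit depths.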
void av1_build_compound_diffwtd_mask_highbd_avx2(
    uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const uint8_t *src0,
    int src0_stride, const uint8_t *src1, int src1_stride, int h, int w,
    int bd) {
  if (w < 16) {
    av1_build_compound_diffwtd_mask_highbd_ssse3(
        mask, mask_type, src0, src0_stride, src1, src1_stride, h, w, bd);
  } else {
    assert(mask_type == DIFFWTD_38 || mask_type == DIFFWTD_38_INV);
    assert(bd >= 8);
    assert((w % 16) == 0);
    const __m256i y0 = _mm256_setzero_si256();
    const __m256i yAOM_BLEND_A64_MAX_ALPHA =
        _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
    const int mask_base = 38;
    const __m256i ymask_base = _mm256_set1_epi16(mask_base);
    const uint16_t *ssrc0 = CONVERT_TO_SHORTPTR(src0);
    const uint16_t *ssrc1 = CONVERT_TO_SHORTPTR(src1);
    if (bd == 8) {
      if (mask_type == DIFFWTD_38_INV) {
        for (int i = 0; i < h; ++i) {
          for (int j = 0; j < w; j += 16) {
            __m256i s0 = _mm256_loadu_si256((const __m256i *)&ssrc0[j]);
            __m256i s1 = _mm256_loadu_si256((const __m256i *)&ssrc1[j]);
            __m256i diff = _mm256_srai_epi16(
                _mm256_abs_epi16(_mm256_sub_epi16(s0, s1)), DIFF_FACTOR_LOG2);
            __m256i m = _mm256_min_epi16(
                _mm256_max_epi16(y0, _mm256_add_epi16(diff, ymask_base)),
                yAOM_BLEND_A64_MAX_ALPHA);
            m = _mm256_sub_epi16(yAOM_BLEND_A64_MAX_ALPHA, m);
            m = _mm256_packus_epi16(m, m);
            m = _mm256_permute4x64_epi64(m, _MM_SHUFFLE(0, 0, 2, 0));
            __m128i m0 = _mm256_castsi256_si128(m);
            _mm_storeu_si128((__m128i *)&mask[j], m0);
          }
          ssrc0 += src0_stride;
          ssrc1 += src1_stride;
          mask += w;
        }
      } else {
        for (int i = 0; i < h; ++i) {
          for (int j = 0; j < w; j += 16) {
            __m256i s0 = _mm256_loadu_si256((const __m256i *)&ssrc0[j]);
            __m256i s1 = _mm256_loadu_si256((const __m256i *)&ssrc1[j]);
            __m256i diff = _mm256_srai_epi16(
                _mm256_abs_epi16(_mm256_sub_epi16(s0, s1)), DIFF_FACTOR_LOG2);
            __m256i m = _mm256_min_epi16(
                _mm256_max_epi16(y0, _mm256_add_epi16(diff, ymask_base)),
                yAOM_BLEND_A64_MAX_ALPHA);
            m = _mm256_packus_epi16(m, m);
            m = _mm256_permute4x64_epi64(m, _MM_SHUFFLE(0, 0, 2, 0));
            __m128i m0 = _mm256_castsi256_si128(m);
            _mm_storeu_si128((__m128i *)&mask[j], m0);
          }
          ssrc0 += src0_stride;
          ssrc1 += src1_stride;
          mask += w;
        }
      }
    } else {
      const __m128i xshift = xx_set1_64_from_32i(bd - 8 + DIFF_FACTOR_LOG2);
      if (mask_type == DIFFWTD_38_INV) {
        for (int i = 0; i < h; ++i) {
          for (int j = 0; j < w; j += 16) {
            __m256i s0 = _mm256_loadu_si256((const __m256i *)&ssrc0[j]);
            __m256i s1 = _mm256_loadu_si256((const __m256i *)&ssrc1[j]);
            __m256i diff = _mm256_sra_epi16(
                _mm256_abs_epi16(_mm256_sub_epi16(s0, s1)), xshift);
            __m256i m = _mm256_min_epi16(
                _mm256_max_epi16(y0, _mm256_add_epi16(diff, ymask_base)),
                yAOM_BLEND_A64_MAX_ALPHA);
            m = _mm256_sub_epi16(yAOM_BLEND_A64_MAX_ALPHA, m);
            m = _mm256_packus_epi16(m, m);
            m = _mm256_permute4x64_epi64(m, _MM_SHUFFLE(0, 0, 2, 0));
            __m128i m0 = _mm256_castsi256_si128(m);
            _mm_storeu_si128((__m128i *)&mask[j], m0);
          }
          ssrc0 += src0_stride;
          ssrc1 += src1_stride;
          mask += w;
        }
      } else {
        for (int i = 0; i < h; ++i) {
          for (int j = 0; j < w; j += 16) {
            __m256i s0 = _mm256_loadu_si256((const __m256i *)&ssrc0[j]);
            __m256i s1 = _mm256_loadu_si256((const __m256i *)&ssrc1[j]);
            __m256i diff = _mm256_sra_epi16(
                _mm256_abs_epi16(_mm256_sub_epi16(s0, s1)), xshift);
            __m256i m = _mm256_min_epi16(
                _mm256_max_epi16(y0, _mm256_add_epi16(diff, ymask_base)),
                yAOM_BLEND_A64_MAX_ALPHA);
            m = _mm256_packus_epi16(m, m);
            m = _mm256_permute4x64_epi64(m, _MM_SHUFFLE(0, 0, 2, 0));
            __m128i m0 = _mm256_castsi256_si128(m);
            _mm_storeu_si128((__m128i *)&mask[j], m0);
          }
          ssrc0 += src0_stride;
          ssrc1 += src1_stride;
          mask += w;
        }
      }
    }
  }
}