1 /*
2  * Copyright (c) 2018, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <smmintrin.h>  // SSE4.1
13 #include <immintrin.h>  // AVX2
14 
15 #include <assert.h>
16 
17 #include "aom/aom_integer.h"
18 #include "aom_ports/mem.h"
19 #include "aom_dsp/aom_dsp_common.h"
20 
21 #include "aom_dsp/x86/synonyms.h"
22 #include "aom_dsp/x86/synonyms_avx2.h"
23 #include "aom_dsp/x86/blend_sse4.h"
24 #include "aom_dsp/x86/blend_mask_sse4.h"
25 
26 #include "config/aom_dsp_rtcd.h"
27 
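// The *_d16_* kernels below blend two 16-bit compound (CONV_BUF_TYPE) buffers
// into 8-bit output pixels. Per output pixel the code computes
//   dst = saturate(((m * src0 + (AOM_BLEND_A64_MAX_ALPHA - m) * src1)
//                   - round_offset) >> shift)
// where m is the per-pixel mask value; the madd on interleaved source/mask
// pairs produces both products and their 32-bit sum in a single instruction.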
28 static INLINE void blend_a64_d16_mask_w16_avx2(
29     uint8_t *dst, const CONV_BUF_TYPE *src0, const CONV_BUF_TYPE *src1,
30     const __m256i *m0, const __m256i *v_round_offset, const __m256i *v_maxval,
31     int shift) {
32   const __m256i max_minus_m0 = _mm256_sub_epi16(*v_maxval, *m0);
33   const __m256i s0_0 = yy_loadu_256(src0);
34   const __m256i s1_0 = yy_loadu_256(src1);
35   __m256i res0_lo = _mm256_madd_epi16(_mm256_unpacklo_epi16(s0_0, s1_0),
36                                       _mm256_unpacklo_epi16(*m0, max_minus_m0));
37   __m256i res0_hi = _mm256_madd_epi16(_mm256_unpackhi_epi16(s0_0, s1_0),
38                                       _mm256_unpackhi_epi16(*m0, max_minus_m0));
39   res0_lo =
40       _mm256_srai_epi32(_mm256_sub_epi32(res0_lo, *v_round_offset), shift);
41   res0_hi =
42       _mm256_srai_epi32(_mm256_sub_epi32(res0_hi, *v_round_offset), shift);
43   const __m256i res0 = _mm256_packs_epi32(res0_lo, res0_hi);
44   __m256i res = _mm256_packus_epi16(res0, res0);
45   res = _mm256_permute4x64_epi64(res, 0xd8);
46   _mm_storeu_si128((__m128i *)(dst), _mm256_castsi256_si128(res));
47 }
48 
49 static INLINE void blend_a64_d16_mask_w32_avx2(
50     uint8_t *dst, const CONV_BUF_TYPE *src0, const CONV_BUF_TYPE *src1,
51     const __m256i *m0, const __m256i *m1, const __m256i *v_round_offset,
52     const __m256i *v_maxval, int shift) {
53   const __m256i max_minus_m0 = _mm256_sub_epi16(*v_maxval, *m0);
54   const __m256i max_minus_m1 = _mm256_sub_epi16(*v_maxval, *m1);
55   const __m256i s0_0 = yy_loadu_256(src0);
56   const __m256i s0_1 = yy_loadu_256(src0 + 16);
57   const __m256i s1_0 = yy_loadu_256(src1);
58   const __m256i s1_1 = yy_loadu_256(src1 + 16);
59   __m256i res0_lo = _mm256_madd_epi16(_mm256_unpacklo_epi16(s0_0, s1_0),
60                                       _mm256_unpacklo_epi16(*m0, max_minus_m0));
61   __m256i res0_hi = _mm256_madd_epi16(_mm256_unpackhi_epi16(s0_0, s1_0),
62                                       _mm256_unpackhi_epi16(*m0, max_minus_m0));
63   __m256i res1_lo = _mm256_madd_epi16(_mm256_unpacklo_epi16(s0_1, s1_1),
64                                       _mm256_unpacklo_epi16(*m1, max_minus_m1));
65   __m256i res1_hi = _mm256_madd_epi16(_mm256_unpackhi_epi16(s0_1, s1_1),
66                                       _mm256_unpackhi_epi16(*m1, max_minus_m1));
67   res0_lo =
68       _mm256_srai_epi32(_mm256_sub_epi32(res0_lo, *v_round_offset), shift);
69   res0_hi =
70       _mm256_srai_epi32(_mm256_sub_epi32(res0_hi, *v_round_offset), shift);
71   res1_lo =
72       _mm256_srai_epi32(_mm256_sub_epi32(res1_lo, *v_round_offset), shift);
73   res1_hi =
74       _mm256_srai_epi32(_mm256_sub_epi32(res1_hi, *v_round_offset), shift);
75   const __m256i res0 = _mm256_packs_epi32(res0_lo, res0_hi);
76   const __m256i res1 = _mm256_packs_epi32(res1_lo, res1_hi);
77   __m256i res = _mm256_packus_epi16(res0, res1);
78   res = _mm256_permute4x64_epi64(res, 0xd8);
79   _mm256_storeu_si256((__m256i *)(dst), res);
80 }
81 
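// In the helper names, subw/subh indicate whether the mask is given at twice
// the output resolution horizontally/vertically and therefore needs to be
// reduced before blending. subw0_subh0: the mask already matches the output
// block, so each row is only widened from u8 to u16.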
82 static INLINE void lowbd_blend_a64_d16_mask_subw0_subh0_w16_avx2(
83     uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
84     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
85     const uint8_t *mask, uint32_t mask_stride, int h,
86     const __m256i *round_offset, int shift) {
87   const __m256i v_maxval = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
88   for (int i = 0; i < h; ++i) {
89     const __m128i m = xx_loadu_128(mask);
90     const __m256i m0 = _mm256_cvtepu8_epi16(m);
91 
92     blend_a64_d16_mask_w16_avx2(dst, src0, src1, &m0, round_offset, &v_maxval,
93                                 shift);
94     mask += mask_stride;
95     dst += dst_stride;
96     src0 += src0_stride;
97     src1 += src1_stride;
98   }
99 }
100 
101 static INLINE void lowbd_blend_a64_d16_mask_subw0_subh0_w32_avx2(
102     uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
103     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
104     const uint8_t *mask, uint32_t mask_stride, int h, int w,
105     const __m256i *round_offset, int shift) {
106   const __m256i v_maxval = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
107   for (int i = 0; i < h; ++i) {
108     for (int j = 0; j < w; j += 32) {
109       const __m256i m = yy_loadu_256(mask + j);
110       const __m256i m0 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(m));
111       const __m256i m1 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(m, 1));
112 
113       blend_a64_d16_mask_w32_avx2(dst + j, src0 + j, src1 + j, &m0, &m1,
114                                   round_offset, &v_maxval, shift);
115     }
116     mask += mask_stride;
117     dst += dst_stride;
118     src0 += src0_stride;
119     src1 += src1_stride;
120   }
121 }
122 
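// subw1_subh1: each output mask value is the rounded average of a 2x2 block
// of mask bytes. Vertical pairs are added with byte saturation, horizontal
// pairs via maddubs against 1, and the total is divided by 4 with rounding
// as (x + 2) >> 2.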
123 static INLINE void lowbd_blend_a64_d16_mask_subw1_subh1_w16_avx2(
124     uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
125     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
126     const uint8_t *mask, uint32_t mask_stride, int h,
127     const __m256i *round_offset, int shift) {
128   const __m256i v_maxval = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
129   const __m256i one_b = _mm256_set1_epi8(1);
130   const __m256i two_w = _mm256_set1_epi16(2);
131   for (int i = 0; i < h; ++i) {
132     const __m256i m_i00 = yy_loadu_256(mask);
133     const __m256i m_i10 = yy_loadu_256(mask + mask_stride);
134 
135     const __m256i m0_ac = _mm256_adds_epu8(m_i00, m_i10);
136     const __m256i m0_acbd = _mm256_maddubs_epi16(m0_ac, one_b);
137     const __m256i m0 = _mm256_srli_epi16(_mm256_add_epi16(m0_acbd, two_w), 2);
138 
139     blend_a64_d16_mask_w16_avx2(dst, src0, src1, &m0, round_offset, &v_maxval,
140                                 shift);
141     mask += mask_stride << 1;
142     dst += dst_stride;
143     src0 += src0_stride;
144     src1 += src1_stride;
145   }
146 }
147 
148 static INLINE void lowbd_blend_a64_d16_mask_subw1_subh1_w32_avx2(
149     uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
150     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
151     const uint8_t *mask, uint32_t mask_stride, int h, int w,
152     const __m256i *round_offset, int shift) {
153   const __m256i v_maxval = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
154   const __m256i one_b = _mm256_set1_epi8(1);
155   const __m256i two_w = _mm256_set1_epi16(2);
156   for (int i = 0; i < h; ++i) {
157     for (int j = 0; j < w; j += 32) {
158       const __m256i m_i00 = yy_loadu_256(mask + 2 * j);
159       const __m256i m_i01 = yy_loadu_256(mask + 2 * j + 32);
160       const __m256i m_i10 = yy_loadu_256(mask + mask_stride + 2 * j);
161       const __m256i m_i11 = yy_loadu_256(mask + mask_stride + 2 * j + 32);
162 
163       const __m256i m0_ac = _mm256_adds_epu8(m_i00, m_i10);
164       const __m256i m1_ac = _mm256_adds_epu8(m_i01, m_i11);
165       const __m256i m0_acbd = _mm256_maddubs_epi16(m0_ac, one_b);
166       const __m256i m1_acbd = _mm256_maddubs_epi16(m1_ac, one_b);
167       const __m256i m0 = _mm256_srli_epi16(_mm256_add_epi16(m0_acbd, two_w), 2);
168       const __m256i m1 = _mm256_srli_epi16(_mm256_add_epi16(m1_acbd, two_w), 2);
169 
170       blend_a64_d16_mask_w32_avx2(dst + j, src0 + j, src1 + j, &m0, &m1,
171                                   round_offset, &v_maxval, shift);
172     }
173     mask += mask_stride << 1;
174     dst += dst_stride;
175     src0 += src0_stride;
176     src1 += src1_stride;
177   }
178 }
179 
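// subw1_subh0: only horizontal pairs of mask bytes are combined; maddubs
// against 1 adds each pair and avg_epu16 with zero halves it with rounding,
// i.e. (a + b + 1) >> 1.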
180 static INLINE void lowbd_blend_a64_d16_mask_subw1_subh0_w16_avx2(
181     uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
182     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
183     const uint8_t *mask, uint32_t mask_stride, int h, int w,
184     const __m256i *round_offset, int shift) {
185   const __m256i v_maxval = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
186   const __m256i one_b = _mm256_set1_epi8(1);
187   const __m256i zeros = _mm256_setzero_si256();
188   for (int i = 0; i < h; ++i) {
189     for (int j = 0; j < w; j += 16) {
190       const __m256i m_i00 = yy_loadu_256(mask + 2 * j);
191       const __m256i m0_ac = _mm256_maddubs_epi16(m_i00, one_b);
192       const __m256i m0 = _mm256_avg_epu16(m0_ac, zeros);
193 
194       blend_a64_d16_mask_w16_avx2(dst + j, src0 + j, src1 + j, &m0,
195                                   round_offset, &v_maxval, shift);
196     }
197     mask += mask_stride;
198     dst += dst_stride;
199     src0 += src0_stride;
200     src1 += src1_stride;
201   }
202 }
203 
204 static INLINE void lowbd_blend_a64_d16_mask_subw1_subh0_w32_avx2(
205     uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
206     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
207     const uint8_t *mask, uint32_t mask_stride, int h, int w,
208     const __m256i *round_offset, int shift) {
209   const __m256i v_maxval = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
210   const __m256i one_b = _mm256_set1_epi8(1);
211   const __m256i zeros = _mm256_setzero_si256();
212   for (int i = 0; i < h; ++i) {
213     for (int j = 0; j < w; j += 32) {
214       const __m256i m_i00 = yy_loadu_256(mask + 2 * j);
215       const __m256i m_i01 = yy_loadu_256(mask + 2 * j + 32);
216       const __m256i m0_ac = _mm256_maddubs_epi16(m_i00, one_b);
217       const __m256i m1_ac = _mm256_maddubs_epi16(m_i01, one_b);
218       const __m256i m0 = _mm256_avg_epu16(m0_ac, zeros);
219       const __m256i m1 = _mm256_avg_epu16(m1_ac, zeros);
220 
221       blend_a64_d16_mask_w32_avx2(dst + j, src0 + j, src1 + j, &m0, &m1,
222                                   round_offset, &v_maxval, shift);
223     }
224     mask += mask_stride;
225     dst += dst_stride;
226     src0 += src0_stride;
227     src1 += src1_stride;
228   }
229 }
230 
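// subw0_subh1: only vertical pairs of mask rows are combined, using a
// saturating byte add followed by avg_epu8 with zero, i.e. (a + b + 1) >> 1
// per byte, before widening to u16.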
231 static INLINE void lowbd_blend_a64_d16_mask_subw0_subh1_w16_avx2(
232     uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
233     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
234     const uint8_t *mask, uint32_t mask_stride, int h, int w,
235     const __m256i *round_offset, int shift) {
236   const __m256i v_maxval = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
237   const __m128i zeros = _mm_setzero_si128();
238   for (int i = 0; i < h; ++i) {
239     for (int j = 0; j < w; j += 16) {
240       const __m128i m_i00 = xx_loadu_128(mask + j);
241       const __m128i m_i10 = xx_loadu_128(mask + mask_stride + j);
242 
243       const __m128i m_ac = _mm_avg_epu8(_mm_adds_epu8(m_i00, m_i10), zeros);
244       const __m256i m0 = _mm256_cvtepu8_epi16(m_ac);
245 
246       blend_a64_d16_mask_w16_avx2(dst + j, src0 + j, src1 + j, &m0,
247                                   round_offset, &v_maxval, shift);
248     }
249     mask += mask_stride << 1;
250     dst += dst_stride;
251     src0 += src0_stride;
252     src1 += src1_stride;
253   }
254 }
255 
256 static INLINE void lowbd_blend_a64_d16_mask_subw0_subh1_w32_avx2(
257     uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
258     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
259     const uint8_t *mask, uint32_t mask_stride, int h, int w,
260     const __m256i *round_offset, int shift) {
261   const __m256i v_maxval = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
262   const __m256i zeros = _mm256_setzero_si256();
263   for (int i = 0; i < h; ++i) {
264     for (int j = 0; j < w; j += 32) {
265       const __m256i m_i00 = yy_loadu_256(mask + j);
266       const __m256i m_i10 = yy_loadu_256(mask + mask_stride + j);
267 
268       const __m256i m_ac =
269           _mm256_avg_epu8(_mm256_adds_epu8(m_i00, m_i10), zeros);
270       const __m256i m0 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(m_ac));
271       const __m256i m1 =
272           _mm256_cvtepu8_epi16(_mm256_extracti128_si256(m_ac, 1));
273 
274       blend_a64_d16_mask_w32_avx2(dst + j, src0 + j, src1 + j, &m0, &m1,
275                                   round_offset, &v_maxval, shift);
276     }
277     mask += mask_stride << 1;
278     dst += dst_stride;
279     src0 += src0_stride;
280     src1 += src1_stride;
281   }
282 }
283 
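// Entry point for the low-bitdepth d16 blend. round_bits is the precision
// left over from the two convolution rounding stages
// (2 * FILTER_BITS - round_0 - round_1); round_offset and shift fold that
// together with the AOM_BLEND_A64_ROUND_BITS mask weighting. Widths 4 and 8
// are delegated to the SSE4.1 helpers, width 16 to the w16 kernel above, and
// wider power-of-two widths to the w32 kernel.
//
// Illustrative scalar equivalent of the per-pixel math (a sketch for
// reference using this function's round_offset and shift; not part of the
// upstream file):
//
//   for (int i = 0; i < h; ++i) {
//     for (int j = 0; j < w; ++j) {
//       const int m = /* mask value after any subw/subh downsampling */;
//       const int32_t v = m * src0[j] +
//                         (AOM_BLEND_A64_MAX_ALPHA - m) * src1[j];
//       dst[j] = clip_pixel((v - round_offset) >> shift);
//     }
//     /* advance dst, src0, src1, mask by their strides */
//   }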
284 void aom_lowbd_blend_a64_d16_mask_avx2(
285     uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
286     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
287     const uint8_t *mask, uint32_t mask_stride, int w, int h, int subw, int subh,
288     ConvolveParams *conv_params) {
289   const int bd = 8;
290   const int round_bits =
291       2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
292 
293   const int round_offset =
294       ((1 << (round_bits + bd)) + (1 << (round_bits + bd - 1)) -
295        (1 << (round_bits - 1)))
296       << AOM_BLEND_A64_ROUND_BITS;
297 
298   const int shift = round_bits + AOM_BLEND_A64_ROUND_BITS;
299   assert(IMPLIES((void *)src0 == dst, src0_stride == dst_stride));
300   assert(IMPLIES((void *)src1 == dst, src1_stride == dst_stride));
301 
302   assert(h >= 4);
303   assert(w >= 4);
304   assert(IS_POWER_OF_TWO(h));
305   assert(IS_POWER_OF_TWO(w));
306   const __m128i v_round_offset = _mm_set1_epi32(round_offset);
307   const __m256i y_round_offset = _mm256_set1_epi32(round_offset);
308 
309   if (subw == 0 && subh == 0) {
310     switch (w) {
311       case 4:
312         aom_lowbd_blend_a64_d16_mask_subw0_subh0_w4_sse4_1(
313             dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
314             mask_stride, h, &v_round_offset, shift);
315         break;
316       case 8:
317         aom_lowbd_blend_a64_d16_mask_subw0_subh0_w8_sse4_1(
318             dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
319             mask_stride, h, &v_round_offset, shift);
320         break;
321       case 16:
322         lowbd_blend_a64_d16_mask_subw0_subh0_w16_avx2(
323             dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
324             mask_stride, h, &y_round_offset, shift);
325         break;
326       default:
327         lowbd_blend_a64_d16_mask_subw0_subh0_w32_avx2(
328             dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
329             mask_stride, h, w, &y_round_offset, shift);
330         break;
331     }
332   } else if (subw == 1 && subh == 1) {
333     switch (w) {
334       case 4:
335         aom_lowbd_blend_a64_d16_mask_subw1_subh1_w4_sse4_1(
336             dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
337             mask_stride, h, &v_round_offset, shift);
338         break;
339       case 8:
340         aom_lowbd_blend_a64_d16_mask_subw1_subh1_w8_sse4_1(
341             dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
342             mask_stride, h, &v_round_offset, shift);
343         break;
344       case 16:
345         lowbd_blend_a64_d16_mask_subw1_subh1_w16_avx2(
346             dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
347             mask_stride, h, &y_round_offset, shift);
348         break;
349       default:
350         lowbd_blend_a64_d16_mask_subw1_subh1_w32_avx2(
351             dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
352             mask_stride, h, w, &y_round_offset, shift);
353         break;
354     }
355   } else if (subw == 1 && subh == 0) {
356     switch (w) {
357       case 4:
358         aom_lowbd_blend_a64_d16_mask_subw1_subh0_w4_sse4_1(
359             dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
360             mask_stride, h, &v_round_offset, shift);
361         break;
362       case 8:
363         aom_lowbd_blend_a64_d16_mask_subw1_subh0_w8_sse4_1(
364             dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
365             mask_stride, h, &v_round_offset, shift);
366         break;
367       case 16:
368         lowbd_blend_a64_d16_mask_subw1_subh0_w16_avx2(
369             dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
370             mask_stride, h, w, &y_round_offset, shift);
371         break;
372       default:
373         lowbd_blend_a64_d16_mask_subw1_subh0_w32_avx2(
374             dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
375             mask_stride, h, w, &y_round_offset, shift);
376         break;
377     }
378   } else {
379     switch (w) {
380       case 4:
381         aom_lowbd_blend_a64_d16_mask_subw0_subh1_w4_sse4_1(
382             dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
383             mask_stride, h, &v_round_offset, shift);
384         break;
385       case 8:
386         aom_lowbd_blend_a64_d16_mask_subw0_subh1_w8_sse4_1(
387             dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
388             mask_stride, h, &v_round_offset, shift);
389         break;
390       case 16:
391         lowbd_blend_a64_d16_mask_subw0_subh1_w16_avx2(
392             dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
393             mask_stride, h, w, &y_round_offset, shift);
394         break;
395       default:
396         lowbd_blend_a64_d16_mask_subw0_subh1_w32_avx2(
397             dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
398             mask_stride, h, w, &y_round_offset, shift);
399         break;
400     }
401   }
402 }
403 
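// Pixel-domain blend kernels for the non-d16 path, where both sources are
// 8-bit. maddubs on byte-interleaved (src0, src1) and (m0, m1) pairs yields
// m0 * src0 + m1 * src1 per pixel in 16 bits, which is then rounded down by
// AOM_BLEND_A64_ROUND_BITS.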
404 static INLINE __m256i blend_16_u8_avx2(const uint8_t *src0, const uint8_t *src1,
405                                        const __m256i *v_m0_b,
406                                        const __m256i *v_m1_b,
407                                        const int32_t bits) {
408   const __m256i v_s0_b = _mm256_castsi128_si256(xx_loadu_128(src0));
409   const __m256i v_s1_b = _mm256_castsi128_si256(xx_loadu_128(src1));
410   const __m256i v_s0_s_b = _mm256_permute4x64_epi64(v_s0_b, 0xd8);
411   const __m256i v_s1_s_b = _mm256_permute4x64_epi64(v_s1_b, 0xd8);
412 
413   const __m256i v_p0_w =
414       _mm256_maddubs_epi16(_mm256_unpacklo_epi8(v_s0_s_b, v_s1_s_b),
415                            _mm256_unpacklo_epi8(*v_m0_b, *v_m1_b));
416 
417   const __m256i v_res0_w = yy_roundn_epu16(v_p0_w, bits);
418   const __m256i v_res_b = _mm256_packus_epi16(v_res0_w, v_res0_w);
419   const __m256i v_res = _mm256_permute4x64_epi64(v_res_b, 0xd8);
420   return v_res;
421 }
422 
423 static INLINE __m256i blend_32_u8_avx2(const uint8_t *src0, const uint8_t *src1,
424                                        const __m256i *v_m0_b,
425                                        const __m256i *v_m1_b,
426                                        const int32_t bits) {
427   const __m256i v_s0_b = yy_loadu_256(src0);
428   const __m256i v_s1_b = yy_loadu_256(src1);
429 
430   const __m256i v_p0_w =
431       _mm256_maddubs_epi16(_mm256_unpacklo_epi8(v_s0_b, v_s1_b),
432                            _mm256_unpacklo_epi8(*v_m0_b, *v_m1_b));
433   const __m256i v_p1_w =
434       _mm256_maddubs_epi16(_mm256_unpackhi_epi8(v_s0_b, v_s1_b),
435                            _mm256_unpackhi_epi8(*v_m0_b, *v_m1_b));
436 
437   const __m256i v_res0_w = yy_roundn_epu16(v_p0_w, bits);
438   const __m256i v_res1_w = yy_roundn_epu16(v_p1_w, bits);
439   const __m256i v_res = _mm256_packus_epi16(v_res0_w, v_res1_w);
440   return v_res;
441 }
442 
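// The sx/sy helpers mirror the subw/subh variants for the pixel-domain path:
// sx reduces the mask by averaging horizontal pairs, sy by averaging vertical
// pairs, and sx_sy by averaging 2x2 blocks, before calling the blend_*_u8
// kernels.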
443 static INLINE void blend_a64_mask_sx_sy_w16_avx2(
444     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
445     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
446     const uint8_t *mask, uint32_t mask_stride, int h) {
447   const __m256i v_zmask_b = _mm256_set1_epi16(0xFF);
448   const __m256i v_maxval_b = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
449   do {
450     const __m256i v_ral_b = yy_loadu_256(mask);
451     const __m256i v_rbl_b = yy_loadu_256(mask + mask_stride);
452     const __m256i v_rvsl_b = _mm256_add_epi8(v_ral_b, v_rbl_b);
453     const __m256i v_rvsal_w = _mm256_and_si256(v_rvsl_b, v_zmask_b);
454     const __m256i v_rvsbl_w =
455         _mm256_and_si256(_mm256_srli_si256(v_rvsl_b, 1), v_zmask_b);
456     const __m256i v_rsl_w = _mm256_add_epi16(v_rvsal_w, v_rvsbl_w);
457 
458     const __m256i v_m0_w = yy_roundn_epu16(v_rsl_w, 2);
459     const __m256i v_m0_b = _mm256_packus_epi16(v_m0_w, v_m0_w);
460     const __m256i v_m1_b = _mm256_sub_epi8(v_maxval_b, v_m0_b);
461 
462     const __m256i y_res_b = blend_16_u8_avx2(src0, src1, &v_m0_b, &v_m1_b,
463                                              AOM_BLEND_A64_ROUND_BITS);
464 
465     xx_storeu_128(dst, _mm256_castsi256_si128(y_res_b));
466     dst += dst_stride;
467     src0 += src0_stride;
468     src1 += src1_stride;
469     mask += 2 * mask_stride;
470   } while (--h);
471 }
472 
473 static INLINE void blend_a64_mask_sx_sy_w32n_avx2(
474     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
475     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
476     const uint8_t *mask, uint32_t mask_stride, int w, int h) {
477   const __m256i v_maxval_b = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
478   const __m256i v_zmask_b = _mm256_set1_epi16(0xFF);
479   do {
480     int c;
481     for (c = 0; c < w; c += 32) {
482       const __m256i v_ral_b = yy_loadu_256(mask + 2 * c);
483       const __m256i v_rah_b = yy_loadu_256(mask + 2 * c + 32);
484       const __m256i v_rbl_b = yy_loadu_256(mask + mask_stride + 2 * c);
485       const __m256i v_rbh_b = yy_loadu_256(mask + mask_stride + 2 * c + 32);
486       const __m256i v_rvsl_b = _mm256_add_epi8(v_ral_b, v_rbl_b);
487       const __m256i v_rvsh_b = _mm256_add_epi8(v_rah_b, v_rbh_b);
488       const __m256i v_rvsal_w = _mm256_and_si256(v_rvsl_b, v_zmask_b);
489       const __m256i v_rvsah_w = _mm256_and_si256(v_rvsh_b, v_zmask_b);
490       const __m256i v_rvsbl_w =
491           _mm256_and_si256(_mm256_srli_si256(v_rvsl_b, 1), v_zmask_b);
492       const __m256i v_rvsbh_w =
493           _mm256_and_si256(_mm256_srli_si256(v_rvsh_b, 1), v_zmask_b);
494       const __m256i v_rsl_w = _mm256_add_epi16(v_rvsal_w, v_rvsbl_w);
495       const __m256i v_rsh_w = _mm256_add_epi16(v_rvsah_w, v_rvsbh_w);
496 
497       const __m256i v_m0l_w = yy_roundn_epu16(v_rsl_w, 2);
498       const __m256i v_m0h_w = yy_roundn_epu16(v_rsh_w, 2);
499       const __m256i v_m0_b =
500           _mm256_permute4x64_epi64(_mm256_packus_epi16(v_m0l_w, v_m0h_w), 0xd8);
501       const __m256i v_m1_b = _mm256_sub_epi8(v_maxval_b, v_m0_b);
502 
503       const __m256i v_res_b = blend_32_u8_avx2(
504           src0 + c, src1 + c, &v_m0_b, &v_m1_b, AOM_BLEND_A64_ROUND_BITS);
505 
506       yy_storeu_256(dst + c, v_res_b);
507     }
508     dst += dst_stride;
509     src0 += src0_stride;
510     src1 += src1_stride;
511     mask += 2 * mask_stride;
512   } while (--h);
513 }
514 
515 static INLINE void blend_a64_mask_sx_sy_avx2(
516     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
517     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
518     const uint8_t *mask, uint32_t mask_stride, int w, int h) {
519   const __m128i v_shuffle_b = xx_loadu_128(g_blend_a64_mask_shuffle);
520   const __m128i v_maxval_b = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
521   const __m128i _r = _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
522   switch (w) {
523     case 4:
524       do {
525         const __m128i v_ra_b = xx_loadl_64(mask);
526         const __m128i v_rb_b = xx_loadl_64(mask + mask_stride);
527         const __m128i v_rvs_b = _mm_add_epi8(v_ra_b, v_rb_b);
528         const __m128i v_r_s_b = _mm_shuffle_epi8(v_rvs_b, v_shuffle_b);
529         const __m128i v_r0_s_w = _mm_cvtepu8_epi16(v_r_s_b);
530         const __m128i v_r1_s_w = _mm_cvtepu8_epi16(_mm_srli_si128(v_r_s_b, 8));
531         const __m128i v_rs_w = _mm_add_epi16(v_r0_s_w, v_r1_s_w);
532         const __m128i v_m0_w = xx_roundn_epu16(v_rs_w, 2);
533         const __m128i v_m0_b = _mm_packus_epi16(v_m0_w, v_m0_w);
534         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
535 
536         const __m128i v_res_b = blend_4_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
537 
538         xx_storel_32(dst, v_res_b);
539 
540         dst += dst_stride;
541         src0 += src0_stride;
542         src1 += src1_stride;
543         mask += 2 * mask_stride;
544       } while (--h);
545       break;
546     case 8:
547       do {
548         const __m128i v_ra_b = xx_loadu_128(mask);
549         const __m128i v_rb_b = xx_loadu_128(mask + mask_stride);
550         const __m128i v_rvs_b = _mm_add_epi8(v_ra_b, v_rb_b);
551         const __m128i v_r_s_b = _mm_shuffle_epi8(v_rvs_b, v_shuffle_b);
552         const __m128i v_r0_s_w = _mm_cvtepu8_epi16(v_r_s_b);
553         const __m128i v_r1_s_w = _mm_cvtepu8_epi16(_mm_srli_si128(v_r_s_b, 8));
554         const __m128i v_rs_w = _mm_add_epi16(v_r0_s_w, v_r1_s_w);
555         const __m128i v_m0_w = xx_roundn_epu16(v_rs_w, 2);
556         const __m128i v_m0_b = _mm_packus_epi16(v_m0_w, v_m0_w);
557         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
558 
559         const __m128i v_res_b = blend_8_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
560 
561         xx_storel_64(dst, v_res_b);
562 
563         dst += dst_stride;
564         src0 += src0_stride;
565         src1 += src1_stride;
566         mask += 2 * mask_stride;
567       } while (--h);
568       break;
569     case 16:
570       blend_a64_mask_sx_sy_w16_avx2(dst, dst_stride, src0, src0_stride, src1,
571                                     src1_stride, mask, mask_stride, h);
572       break;
573     default:
574       blend_a64_mask_sx_sy_w32n_avx2(dst, dst_stride, src0, src0_stride, src1,
575                                      src1_stride, mask, mask_stride, w, h);
576       break;
577   }
578 }
579 
580 static INLINE void blend_a64_mask_sx_w16_avx2(
581     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
582     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
583     const uint8_t *mask, uint32_t mask_stride, int h) {
584   const __m256i v_maxval_b = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
585   const __m256i v_zmask_b = _mm256_set1_epi16(0xff);
586   do {
587     const __m256i v_rl_b = yy_loadu_256(mask);
588     const __m256i v_al_b =
589         _mm256_avg_epu8(v_rl_b, _mm256_srli_si256(v_rl_b, 1));
590 
591     const __m256i v_m0_w = _mm256_and_si256(v_al_b, v_zmask_b);
592     const __m256i v_m0_b = _mm256_packus_epi16(v_m0_w, _mm256_setzero_si256());
593     const __m256i v_m1_b = _mm256_sub_epi8(v_maxval_b, v_m0_b);
594 
595     const __m256i v_res_b = blend_16_u8_avx2(src0, src1, &v_m0_b, &v_m1_b,
596                                              AOM_BLEND_A64_ROUND_BITS);
597 
598     xx_storeu_128(dst, _mm256_castsi256_si128(v_res_b));
599     dst += dst_stride;
600     src0 += src0_stride;
601     src1 += src1_stride;
602     mask += mask_stride;
603   } while (--h);
604 }
605 
606 static INLINE void blend_a64_mask_sx_w32n_avx2(
607     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
608     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
609     const uint8_t *mask, uint32_t mask_stride, int w, int h) {
610   const __m256i v_shuffle_b = yy_loadu_256(g_blend_a64_mask_shuffle);
611   const __m256i v_maxval_b = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
612   do {
613     int c;
614     for (c = 0; c < w; c += 32) {
615       const __m256i v_r0_b = yy_loadu_256(mask + 2 * c);
616       const __m256i v_r1_b = yy_loadu_256(mask + 2 * c + 32);
617       const __m256i v_r0_s_b = _mm256_shuffle_epi8(v_r0_b, v_shuffle_b);
618       const __m256i v_r1_s_b = _mm256_shuffle_epi8(v_r1_b, v_shuffle_b);
619       const __m256i v_al_b =
620           _mm256_avg_epu8(v_r0_s_b, _mm256_srli_si256(v_r0_s_b, 8));
621       const __m256i v_ah_b =
622           _mm256_avg_epu8(v_r1_s_b, _mm256_srli_si256(v_r1_s_b, 8));
623 
624       const __m256i v_m0_b =
625           _mm256_permute4x64_epi64(_mm256_unpacklo_epi64(v_al_b, v_ah_b), 0xd8);
626       const __m256i v_m1_b = _mm256_sub_epi8(v_maxval_b, v_m0_b);
627 
628       const __m256i v_res_b = blend_32_u8_avx2(
629           src0 + c, src1 + c, &v_m0_b, &v_m1_b, AOM_BLEND_A64_ROUND_BITS);
630 
631       yy_storeu_256(dst + c, v_res_b);
632     }
633     dst += dst_stride;
634     src0 += src0_stride;
635     src1 += src1_stride;
636     mask += mask_stride;
637   } while (--h);
638 }
639 
640 static INLINE void blend_a64_mask_sx_avx2(
641     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
642     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
643     const uint8_t *mask, uint32_t mask_stride, int w, int h) {
644   const __m128i v_shuffle_b = xx_loadu_128(g_blend_a64_mask_shuffle);
645   const __m128i v_maxval_b = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
646   const __m128i _r = _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
647   switch (w) {
648     case 4:
649       do {
650         const __m128i v_r_b = xx_loadl_64(mask);
651         const __m128i v_r0_s_b = _mm_shuffle_epi8(v_r_b, v_shuffle_b);
652         const __m128i v_r_lo_b = _mm_unpacklo_epi64(v_r0_s_b, v_r0_s_b);
653         const __m128i v_r_hi_b = _mm_unpackhi_epi64(v_r0_s_b, v_r0_s_b);
654         const __m128i v_m0_b = _mm_avg_epu8(v_r_lo_b, v_r_hi_b);
655         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
656 
657         const __m128i v_res_b = blend_4_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
658 
659         xx_storel_32(dst, v_res_b);
660 
661         dst += dst_stride;
662         src0 += src0_stride;
663         src1 += src1_stride;
664         mask += mask_stride;
665       } while (--h);
666       break;
667     case 8:
668       do {
669         const __m128i v_r_b = xx_loadu_128(mask);
670         const __m128i v_r0_s_b = _mm_shuffle_epi8(v_r_b, v_shuffle_b);
671         const __m128i v_r_lo_b = _mm_unpacklo_epi64(v_r0_s_b, v_r0_s_b);
672         const __m128i v_r_hi_b = _mm_unpackhi_epi64(v_r0_s_b, v_r0_s_b);
673         const __m128i v_m0_b = _mm_avg_epu8(v_r_lo_b, v_r_hi_b);
674         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
675 
676         const __m128i v_res_b = blend_8_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
677 
678         xx_storel_64(dst, v_res_b);
679 
680         dst += dst_stride;
681         src0 += src0_stride;
682         src1 += src1_stride;
683         mask += mask_stride;
684       } while (--h);
685       break;
686     case 16:
687       blend_a64_mask_sx_w16_avx2(dst, dst_stride, src0, src0_stride, src1,
688                                  src1_stride, mask, mask_stride, h);
689       break;
690     default:
691       blend_a64_mask_sx_w32n_avx2(dst, dst_stride, src0, src0_stride, src1,
692                                   src1_stride, mask, mask_stride, w, h);
693       break;
694   }
695 }
696 
697 static INLINE void blend_a64_mask_sy_w16_avx2(
698     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
699     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
700     const uint8_t *mask, uint32_t mask_stride, int h) {
701   const __m128i _r = _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
702   const __m128i v_maxval_b = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
703   do {
704     const __m128i v_ra_b = xx_loadu_128(mask);
705     const __m128i v_rb_b = xx_loadu_128(mask + mask_stride);
706     const __m128i v_m0_b = _mm_avg_epu8(v_ra_b, v_rb_b);
707 
708     const __m128i v_m1_b = _mm_sub_epi16(v_maxval_b, v_m0_b);
709     const __m128i v_res_b = blend_16_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
710 
711     xx_storeu_128(dst, v_res_b);
712     dst += dst_stride;
713     src0 += src0_stride;
714     src1 += src1_stride;
715     mask += 2 * mask_stride;
716   } while (--h);
717 }
718 
719 static INLINE void blend_a64_mask_sy_w32n_avx2(
720     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
721     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
722     const uint8_t *mask, uint32_t mask_stride, int w, int h) {
723   const __m256i v_maxval_b = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
724   do {
725     int c;
726     for (c = 0; c < w; c += 32) {
727       const __m256i v_ra_b = yy_loadu_256(mask + c);
728       const __m256i v_rb_b = yy_loadu_256(mask + c + mask_stride);
729       const __m256i v_m0_b = _mm256_avg_epu8(v_ra_b, v_rb_b);
730       const __m256i v_m1_b = _mm256_sub_epi8(v_maxval_b, v_m0_b);
731       const __m256i v_res_b = blend_32_u8_avx2(
732           src0 + c, src1 + c, &v_m0_b, &v_m1_b, AOM_BLEND_A64_ROUND_BITS);
733 
734       yy_storeu_256(dst + c, v_res_b);
735     }
736     dst += dst_stride;
737     src0 += src0_stride;
738     src1 += src1_stride;
739     mask += 2 * mask_stride;
740   } while (--h);
741 }
742 
743 static INLINE void blend_a64_mask_sy_avx2(
744     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
745     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
746     const uint8_t *mask, uint32_t mask_stride, int w, int h) {
747   const __m128i _r = _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
748   const __m128i v_maxval_b = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
749   switch (w) {
750     case 4:
751       do {
752         const __m128i v_ra_b = xx_loadl_32(mask);
753         const __m128i v_rb_b = xx_loadl_32(mask + mask_stride);
754         const __m128i v_m0_b = _mm_avg_epu8(v_ra_b, v_rb_b);
755         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
756         const __m128i v_res_b = blend_4_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
757 
758         xx_storel_32(dst, v_res_b);
759 
760         dst += dst_stride;
761         src0 += src0_stride;
762         src1 += src1_stride;
763         mask += 2 * mask_stride;
764       } while (--h);
765       break;
766     case 8:
767       do {
768         const __m128i v_ra_b = xx_loadl_64(mask);
769         const __m128i v_rb_b = xx_loadl_64(mask + mask_stride);
770         const __m128i v_m0_b = _mm_avg_epu8(v_ra_b, v_rb_b);
771         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
772         const __m128i v_res_b = blend_8_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
773 
774         xx_storel_64(dst, v_res_b);
775 
776         dst += dst_stride;
777         src0 += src0_stride;
778         src1 += src1_stride;
779         mask += 2 * mask_stride;
780       } while (--h);
781       break;
782     case 16:
783       blend_a64_mask_sy_w16_avx2(dst, dst_stride, src0, src0_stride, src1,
784                                  src1_stride, mask, mask_stride, h);
785       break;
786     default:
787       blend_a64_mask_sy_w32n_avx2(dst, dst_stride, src0, src0_stride, src1,
788                                   src1_stride, mask, mask_stride, w, h);
789   }
790 }
791 
792 static INLINE void blend_a64_mask_w32n_avx2(
793     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
794     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
795     const uint8_t *mask, uint32_t mask_stride, int w, int h) {
796   const __m256i v_maxval_b = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
797   do {
798     int c;
799     for (c = 0; c < w; c += 32) {
800       const __m256i v_m0_b = yy_loadu_256(mask + c);
801       const __m256i v_m1_b = _mm256_sub_epi8(v_maxval_b, v_m0_b);
802 
803       const __m256i v_res_b = blend_32_u8_avx2(
804           src0 + c, src1 + c, &v_m0_b, &v_m1_b, AOM_BLEND_A64_ROUND_BITS);
805 
806       yy_storeu_256(dst + c, v_res_b);
807     }
808     dst += dst_stride;
809     src0 += src0_stride;
810     src1 += src1_stride;
811     mask += mask_stride;
812   } while (--h);
813 }
814 
815 static INLINE void blend_a64_mask_avx2(
816     uint8_t *dst, uint32_t dst_stride, const uint8_t *src0,
817     uint32_t src0_stride, const uint8_t *src1, uint32_t src1_stride,
818     const uint8_t *mask, uint32_t mask_stride, int w, int h) {
819   const __m128i v_maxval_b = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
820   const __m128i _r = _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
821   switch (w) {
822     case 4:
823       do {
824         const __m128i v_m0_b = xx_loadl_32(mask);
825         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
826         const __m128i v_res_b = blend_4_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
827 
828         xx_storel_32(dst, v_res_b);
829 
830         dst += dst_stride;
831         src0 += src0_stride;
832         src1 += src1_stride;
833         mask += mask_stride;
834       } while (--h);
835       break;
836     case 8:
837       do {
838         const __m128i v_m0_b = xx_loadl_64(mask);
839         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
840         const __m128i v_res_b = blend_8_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
841 
842         xx_storel_64(dst, v_res_b);
843 
844         dst += dst_stride;
845         src0 += src0_stride;
846         src1 += src1_stride;
847         mask += mask_stride;
848       } while (--h);
849       break;
850     case 16:
851       do {
852         const __m128i v_m0_b = xx_loadu_128(mask);
853         const __m128i v_m1_b = _mm_sub_epi8(v_maxval_b, v_m0_b);
854         const __m128i v_res_b = blend_16_u8(src0, src1, &v_m0_b, &v_m1_b, &_r);
855 
856         xx_storeu_128(dst, v_res_b);
857         dst += dst_stride;
858         src0 += src0_stride;
859         src1 += src1_stride;
860         mask += mask_stride;
861       } while (--h);
862       break;
863     default:
864       blend_a64_mask_w32n_avx2(dst, dst_stride, src0, src0_stride, src1,
865                                src1_stride, mask, mask_stride, w, h);
866   }
867 }
868 
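// Public 8-bit entry point. Blocks where w or h is not a multiple of 4
// (i.e. 1 or 2, given the power-of-two asserts) fall back to the C
// implementation; otherwise dispatch is on the mask subsampling flags subw
// and subh.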
869 void aom_blend_a64_mask_avx2(uint8_t *dst, uint32_t dst_stride,
870                              const uint8_t *src0, uint32_t src0_stride,
871                              const uint8_t *src1, uint32_t src1_stride,
872                              const uint8_t *mask, uint32_t mask_stride, int w,
873                              int h, int subw, int subh) {
874   assert(IMPLIES(src0 == dst, src0_stride == dst_stride));
875   assert(IMPLIES(src1 == dst, src1_stride == dst_stride));
876 
877   assert(h >= 1);
878   assert(w >= 1);
879   assert(IS_POWER_OF_TWO(h));
880   assert(IS_POWER_OF_TWO(w));
881 
882   if (UNLIKELY((h | w) & 3)) {  // if (w <= 2 || h <= 2)
883     aom_blend_a64_mask_c(dst, dst_stride, src0, src0_stride, src1, src1_stride,
884                          mask, mask_stride, w, h, subw, subh);
885   } else {
886     if (subw & subh) {
887       blend_a64_mask_sx_sy_avx2(dst, dst_stride, src0, src0_stride, src1,
888                                 src1_stride, mask, mask_stride, w, h);
889     } else if (subw) {
890       blend_a64_mask_sx_avx2(dst, dst_stride, src0, src0_stride, src1,
891                              src1_stride, mask, mask_stride, w, h);
892     } else if (subh) {
893       blend_a64_mask_sy_avx2(dst, dst_stride, src0, src0_stride, src1,
894                              src1_stride, mask, mask_stride, w, h);
895     } else {
896       blend_a64_mask_avx2(dst, dst_stride, src0, src0_stride, src1, src1_stride,
897                           mask, mask_stride, w, h);
898     }
899   }
900 }
901 
902 #if CONFIG_AV1_HIGHBITDEPTH
903 //////////////////////////////////////////////////////////////////////////////
904 // aom_highbd_blend_a64_d16_mask_avx2()
905 //////////////////////////////////////////////////////////////////////////////
906 
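// The high-bitdepth d16 kernels keep each mask * source product in 32 bits by
// combining mulhi_epu16/mullo_epi16 results, then subtract round_offset,
// shift, and clamp to [*clip_low, *clip_high] (presumably the valid pixel
// range for the configured bit depth) before storing 16-bit pixels.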
907 static INLINE void highbd_blend_a64_d16_mask_w4_avx2(
908     uint16_t *dst, int dst_stride, const CONV_BUF_TYPE *src0, int src0_stride,
909     const CONV_BUF_TYPE *src1, int src1_stride, const __m256i *mask0,
910     const __m256i *round_offset, int shift, const __m256i *clip_low,
911     const __m256i *clip_high, const __m256i *mask_max) {
912   // Load 4x u16 pixels from each of 4 rows from each source
913   const __m256i s0 = _mm256_set_epi64x(*(int64_t *)(src0 + 3 * src0_stride),
914                                        *(int64_t *)(src0 + 2 * src0_stride),
915                                        *(int64_t *)(src0 + 1 * src0_stride),
916                                        *(int64_t *)(src0 + 0 * src0_stride));
917   const __m256i s1 = _mm256_set_epi64x(*(int64_t *)(src1 + 3 * src1_stride),
918                                        *(int64_t *)(src1 + 2 * src1_stride),
919                                        *(int64_t *)(src1 + 1 * src1_stride),
920                                        *(int64_t *)(src1 + 0 * src1_stride));
921   // Generate the inverse mask
922   const __m256i mask1 = _mm256_sub_epi16(*mask_max, *mask0);
923 
924   // Multiply each mask by the respective source
925   const __m256i mul0_highs = _mm256_mulhi_epu16(*mask0, s0);
926   const __m256i mul0_lows = _mm256_mullo_epi16(*mask0, s0);
927   const __m256i mul0h = _mm256_unpackhi_epi16(mul0_lows, mul0_highs);
928   const __m256i mul0l = _mm256_unpacklo_epi16(mul0_lows, mul0_highs);
929   // Note that AVX2 unpack orders 64-bit words as [3 1] [2 0] to keep within
930   // lanes. Later, packs does the same again, which cancels this out with no
931   // need for a permute; the reordered intermediate values make no difference.
932 
933   const __m256i mul1_highs = _mm256_mulhi_epu16(mask1, s1);
934   const __m256i mul1_lows = _mm256_mullo_epi16(mask1, s1);
935   const __m256i mul1h = _mm256_unpackhi_epi16(mul1_lows, mul1_highs);
936   const __m256i mul1l = _mm256_unpacklo_epi16(mul1_lows, mul1_highs);
937 
938   const __m256i sumh = _mm256_add_epi32(mul0h, mul1h);
939   const __m256i suml = _mm256_add_epi32(mul0l, mul1l);
940 
941   const __m256i roundh =
942       _mm256_srai_epi32(_mm256_sub_epi32(sumh, *round_offset), shift);
943   const __m256i roundl =
944       _mm256_srai_epi32(_mm256_sub_epi32(suml, *round_offset), shift);
945 
946   const __m256i pack = _mm256_packs_epi32(roundl, roundh);
947   const __m256i clip =
948       _mm256_min_epi16(_mm256_max_epi16(pack, *clip_low), *clip_high);
949 
950   // _mm256_extract_epi64 doesn't exist on x86, so do it the old-fashioned way:
951   const __m128i cliph = _mm256_extracti128_si256(clip, 1);
952   xx_storel_64(dst + 3 * dst_stride, _mm_srli_si128(cliph, 8));
953   xx_storel_64(dst + 2 * dst_stride, cliph);
954   const __m128i clipl = _mm256_castsi256_si128(clip);
955   xx_storel_64(dst + 1 * dst_stride, _mm_srli_si128(clipl, 8));
956   xx_storel_64(dst + 0 * dst_stride, clipl);
957 }
958 
959 static INLINE void highbd_blend_a64_d16_mask_subw0_subh0_w4_avx2(
960     uint16_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
961     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
962     const uint8_t *mask, uint32_t mask_stride, int h,
963     const __m256i *round_offset, int shift, const __m256i *clip_low,
964     const __m256i *clip_high, const __m256i *mask_max) {
965   do {
966     // Load 4x u8 pixels from each of 4 rows of the mask, pad each to u16
967     const __m128i mask08 = _mm_set_epi32(*(int32_t *)(mask + 3 * mask_stride),
968                                          *(int32_t *)(mask + 2 * mask_stride),
969                                          *(int32_t *)(mask + 1 * mask_stride),
970                                          *(int32_t *)(mask + 0 * mask_stride));
971     const __m256i mask0 = _mm256_cvtepu8_epi16(mask08);
972 
973     highbd_blend_a64_d16_mask_w4_avx2(dst, dst_stride, src0, src0_stride, src1,
974                                       src1_stride, &mask0, round_offset, shift,
975                                       clip_low, clip_high, mask_max);
976 
977     dst += dst_stride * 4;
978     src0 += src0_stride * 4;
979     src1 += src1_stride * 4;
980     mask += mask_stride * 4;
981   } while (h -= 4);
982 }
983 
984 static INLINE void highbd_blend_a64_d16_mask_subw1_subh1_w4_avx2(
985     uint16_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
986     uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
987     const uint8_t *mask, uint32_t mask_stride, int h,
988     const __m256i *round_offset, int shift, const __m256i *clip_low,
989     const __m256i *clip_high, const __m256i *mask_max) {
990   const __m256i one_b = _mm256_set1_epi8(1);
991   const __m256i two_w = _mm256_set1_epi16(2);
992   do {
993     // Load 8 pixels from each of 8 rows of mask,
994     // (saturating) add together rows then use madd to add adjacent pixels
995     // Finally, divide each value by 4 (with rounding)
996     const __m256i m0246 =
997         _mm256_set_epi64x(*(int64_t *)(mask + 6 * mask_stride),
998                           *(int64_t *)(mask + 4 * mask_stride),
999                           *(int64_t *)(mask + 2 * mask_stride),
1000                           *(int64_t *)(mask + 0 * mask_stride));
1001     const __m256i m1357 =
1002         _mm256_set_epi64x(*(int64_t *)(mask + 7 * mask_stride),
1003                           *(int64_t *)(mask + 5 * mask_stride),
1004                           *(int64_t *)(mask + 3 * mask_stride),
1005                           *(int64_t *)(mask + 1 * mask_stride));
1006     const __m256i addrows = _mm256_adds_epu8(m0246, m1357);
1007     const __m256i adjacent = _mm256_maddubs_epi16(addrows, one_b);
1008     const __m256i mask0 =
1009         _mm256_srli_epi16(_mm256_add_epi16(adjacent, two_w), 2);
1010 
1011     highbd_blend_a64_d16_mask_w4_avx2(dst, dst_stride, src0, src0_stride, src1,
1012                                       src1_stride, &mask0, round_offset, shift,
1013                                       clip_low, clip_high, mask_max);
1014 
1015     dst += dst_stride * 4;
1016     src0 += src0_stride * 4;
1017     src1 += src1_stride * 4;
1018     mask += mask_stride * 8;
1019   } while (h -= 4);
1020 }
1021 
1022 static INLINE void highbd_blend_a64_d16_mask_w8_avx2(
1023     uint16_t *dst, int dst_stride, const CONV_BUF_TYPE *src0, int src0_stride,
1024     const CONV_BUF_TYPE *src1, int src1_stride, const __m256i *mask0a,
1025     const __m256i *mask0b, const __m256i *round_offset, int shift,
1026     const __m256i *clip_low, const __m256i *clip_high,
1027     const __m256i *mask_max) {
1028   // Load 8x u16 pixels from each of 4 rows from each source
1029   const __m256i s0a =
1030       yy_loadu2_128(src0 + 0 * src0_stride, src0 + 1 * src0_stride);
1031   const __m256i s0b =
1032       yy_loadu2_128(src0 + 2 * src0_stride, src0 + 3 * src0_stride);
1033   const __m256i s1a =
1034       yy_loadu2_128(src1 + 0 * src1_stride, src1 + 1 * src1_stride);
1035   const __m256i s1b =
1036       yy_loadu2_128(src1 + 2 * src1_stride, src1 + 3 * src1_stride);
1037 
1038   // Generate inverse masks
1039   const __m256i mask1a = _mm256_sub_epi16(*mask_max, *mask0a);
1040   const __m256i mask1b = _mm256_sub_epi16(*mask_max, *mask0b);
1041 
1042   // Multiply sources by respective masks
1043   const __m256i mul0a_highs = _mm256_mulhi_epu16(*mask0a, s0a);
1044   const __m256i mul0a_lows = _mm256_mullo_epi16(*mask0a, s0a);
1045   const __m256i mul0ah = _mm256_unpackhi_epi16(mul0a_lows, mul0a_highs);
1046   const __m256i mul0al = _mm256_unpacklo_epi16(mul0a_lows, mul0a_highs);
1047   // Note that AVX2 unpack orders 64-bit words as [3 1] [2 0] to keep within
1048   // lanes. Later, packs does the same again, which cancels this out with no
1049   // need for a permute; the reordered intermediate values make no difference.
1050 
1051   const __m256i mul1a_highs = _mm256_mulhi_epu16(mask1a, s1a);
1052   const __m256i mul1a_lows = _mm256_mullo_epi16(mask1a, s1a);
1053   const __m256i mul1ah = _mm256_unpackhi_epi16(mul1a_lows, mul1a_highs);
1054   const __m256i mul1al = _mm256_unpacklo_epi16(mul1a_lows, mul1a_highs);
1055 
1056   const __m256i sumah = _mm256_add_epi32(mul0ah, mul1ah);
1057   const __m256i sumal = _mm256_add_epi32(mul0al, mul1al);
1058 
1059   const __m256i mul0b_highs = _mm256_mulhi_epu16(*mask0b, s0b);
1060   const __m256i mul0b_lows = _mm256_mullo_epi16(*mask0b, s0b);
1061   const __m256i mul0bh = _mm256_unpackhi_epi16(mul0b_lows, mul0b_highs);
1062   const __m256i mul0bl = _mm256_unpacklo_epi16(mul0b_lows, mul0b_highs);
1063 
1064   const __m256i mul1b_highs = _mm256_mulhi_epu16(mask1b, s1b);
1065   const __m256i mul1b_lows = _mm256_mullo_epi16(mask1b, s1b);
1066   const __m256i mul1bh = _mm256_unpackhi_epi16(mul1b_lows, mul1b_highs);
1067   const __m256i mul1bl = _mm256_unpacklo_epi16(mul1b_lows, mul1b_highs);
1068 
1069   const __m256i sumbh = _mm256_add_epi32(mul0bh, mul1bh);
1070   const __m256i sumbl = _mm256_add_epi32(mul0bl, mul1bl);
1071 
1072   // Divide down each result, with rounding
1073   const __m256i roundah =
1074       _mm256_srai_epi32(_mm256_sub_epi32(sumah, *round_offset), shift);
1075   const __m256i roundal =
1076       _mm256_srai_epi32(_mm256_sub_epi32(sumal, *round_offset), shift);
1077   const __m256i roundbh =
1078       _mm256_srai_epi32(_mm256_sub_epi32(sumbh, *round_offset), shift);
1079   const __m256i roundbl =
1080       _mm256_srai_epi32(_mm256_sub_epi32(sumbl, *round_offset), shift);

  // Pack each i32 down to an i16 with saturation, then clip to valid range
  const __m256i packa = _mm256_packs_epi32(roundal, roundah);
  const __m256i clipa =
      _mm256_min_epi16(_mm256_max_epi16(packa, *clip_low), *clip_high);
  const __m256i packb = _mm256_packs_epi32(roundbl, roundbh);
  const __m256i clipb =
      _mm256_min_epi16(_mm256_max_epi16(packb, *clip_low), *clip_high);

  // Store 8x u16 pixels to each of 4 rows in the destination
  yy_storeu2_128(dst + 0 * dst_stride, dst + 1 * dst_stride, clipa);
  yy_storeu2_128(dst + 2 * dst_stride, dst + 3 * dst_stride, clipb);
}

static INLINE void highbd_blend_a64_d16_mask_subw0_subh0_w8_avx2(
    uint16_t *dst, int dst_stride, const CONV_BUF_TYPE *src0, int src0_stride,
    const CONV_BUF_TYPE *src1, int src1_stride, const uint8_t *mask,
    int mask_stride, int h, const __m256i *round_offset, int shift,
    const __m256i *clip_low, const __m256i *clip_high,
    const __m256i *mask_max) {
  do {
    // Load 8x u8 pixels from each of 4 rows in the mask
    const __m128i mask0a8 =
        _mm_set_epi64x(*(int64_t *)mask, *(int64_t *)(mask + mask_stride));
    const __m128i mask0b8 =
        _mm_set_epi64x(*(int64_t *)(mask + 2 * mask_stride),
                       *(int64_t *)(mask + 3 * mask_stride));
    const __m256i mask0a = _mm256_cvtepu8_epi16(mask0a8);
    const __m256i mask0b = _mm256_cvtepu8_epi16(mask0b8);
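    // Zero-extend the 8-bit mask values to 16 bits so they can feed the
    // 16-bit multiply path in the blend kernel below.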

    highbd_blend_a64_d16_mask_w8_avx2(
        dst, dst_stride, src0, src0_stride, src1, src1_stride, &mask0a, &mask0b,
        round_offset, shift, clip_low, clip_high, mask_max);

    dst += dst_stride * 4;
    src0 += src0_stride * 4;
    src1 += src1_stride * 4;
    mask += mask_stride * 4;
  } while (h -= 4);
}

static INLINE void highbd_blend_a64_d16_mask_subw1_subh1_w8_avx2(
    uint16_t *dst, int dst_stride, const CONV_BUF_TYPE *src0, int src0_stride,
    const CONV_BUF_TYPE *src1, int src1_stride, const uint8_t *mask,
    int mask_stride, int h, const __m256i *round_offset, int shift,
    const __m256i *clip_low, const __m256i *clip_high,
    const __m256i *mask_max) {
  const __m256i one_b = _mm256_set1_epi8(1);
  const __m256i two_w = _mm256_set1_epi16(2);
  do {
    // Load 16x u8 pixels from each of 8 rows in the mask, (saturating) add
    // rows together in pairs, then use madd to add adjacent pixels.
    // Finally, divide each value by 4 (with rounding).
    const __m256i m02 =
        yy_loadu2_128(mask + 0 * mask_stride, mask + 2 * mask_stride);
    const __m256i m13 =
        yy_loadu2_128(mask + 1 * mask_stride, mask + 3 * mask_stride);
    const __m256i m0123 =
        _mm256_maddubs_epi16(_mm256_adds_epu8(m02, m13), one_b);
    const __m256i mask_0a =
        _mm256_srli_epi16(_mm256_add_epi16(m0123, two_w), 2);
    const __m256i m46 =
        yy_loadu2_128(mask + 4 * mask_stride, mask + 6 * mask_stride);
    const __m256i m57 =
        yy_loadu2_128(mask + 5 * mask_stride, mask + 7 * mask_stride);
    const __m256i m4567 =
        _mm256_maddubs_epi16(_mm256_adds_epu8(m46, m57), one_b);
    const __m256i mask_0b =
        _mm256_srli_epi16(_mm256_add_epi16(m4567, two_w), 2);
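    // Each value of mask_0a/mask_0b is now the rounded average of a 2x2 block
    // of mask bytes: the saturating add sums vertically adjacent rows, maddubs
    // against a vector of ones sums horizontally adjacent bytes, and the
    // (+2, >>2) step divides by 4 with rounding.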

    highbd_blend_a64_d16_mask_w8_avx2(
        dst, dst_stride, src0, src0_stride, src1, src1_stride, &mask_0a,
        &mask_0b, round_offset, shift, clip_low, clip_high, mask_max);

    dst += dst_stride * 4;
    src0 += src0_stride * 4;
    src1 += src1_stride * 4;
    mask += mask_stride * 8;
  } while (h -= 4);
}

static INLINE void highbd_blend_a64_d16_mask_w16_avx2(
    uint16_t *dst, int dst_stride, const CONV_BUF_TYPE *src0, int src0_stride,
    const CONV_BUF_TYPE *src1, int src1_stride, const __m256i *mask0a,
    const __m256i *mask0b, const __m256i *round_offset, int shift,
    const __m256i *clip_low, const __m256i *clip_high,
    const __m256i *mask_max) {
  // Load 16x pixels from each of 2 rows from each source
  const __m256i s0a = yy_loadu_256(src0);
  const __m256i s0b = yy_loadu_256(src0 + src0_stride);
  const __m256i s1a = yy_loadu_256(src1);
  const __m256i s1b = yy_loadu_256(src1 + src1_stride);

  // Calculate inverse masks
  const __m256i mask1a = _mm256_sub_epi16(*mask_max, *mask0a);
  const __m256i mask1b = _mm256_sub_epi16(*mask_max, *mask0b);

  // Multiply each source by appropriate mask
  const __m256i mul0a_highs = _mm256_mulhi_epu16(*mask0a, s0a);
  const __m256i mul0a_lows = _mm256_mullo_epi16(*mask0a, s0a);
  const __m256i mul0ah = _mm256_unpackhi_epi16(mul0a_lows, mul0a_highs);
  const __m256i mul0al = _mm256_unpacklo_epi16(mul0a_lows, mul0a_highs);
  // Note that AVX2 unpack orders 64-bit words as [3 1] [2 0] to keep within
  // lanes. Later, packs does the same again, which cancels this out with no
  // need for a permute. The intermediate values being reordered makes no
  // difference.
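  // From here the arithmetic mirrors the w8 kernel above, but applied to two
  // rows of 16 pixels rather than four rows of 8.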

  const __m256i mul1a_highs = _mm256_mulhi_epu16(mask1a, s1a);
  const __m256i mul1a_lows = _mm256_mullo_epi16(mask1a, s1a);
  const __m256i mul1ah = _mm256_unpackhi_epi16(mul1a_lows, mul1a_highs);
  const __m256i mul1al = _mm256_unpacklo_epi16(mul1a_lows, mul1a_highs);

  const __m256i mulah = _mm256_add_epi32(mul0ah, mul1ah);
  const __m256i mulal = _mm256_add_epi32(mul0al, mul1al);

  const __m256i mul0b_highs = _mm256_mulhi_epu16(*mask0b, s0b);
  const __m256i mul0b_lows = _mm256_mullo_epi16(*mask0b, s0b);
  const __m256i mul0bh = _mm256_unpackhi_epi16(mul0b_lows, mul0b_highs);
  const __m256i mul0bl = _mm256_unpacklo_epi16(mul0b_lows, mul0b_highs);

  const __m256i mul1b_highs = _mm256_mulhi_epu16(mask1b, s1b);
  const __m256i mul1b_lows = _mm256_mullo_epi16(mask1b, s1b);
  const __m256i mul1bh = _mm256_unpackhi_epi16(mul1b_lows, mul1b_highs);
  const __m256i mul1bl = _mm256_unpacklo_epi16(mul1b_lows, mul1b_highs);

  const __m256i mulbh = _mm256_add_epi32(mul0bh, mul1bh);
  const __m256i mulbl = _mm256_add_epi32(mul0bl, mul1bl);

  const __m256i resah =
      _mm256_srai_epi32(_mm256_sub_epi32(mulah, *round_offset), shift);
  const __m256i resal =
      _mm256_srai_epi32(_mm256_sub_epi32(mulal, *round_offset), shift);
  const __m256i resbh =
      _mm256_srai_epi32(_mm256_sub_epi32(mulbh, *round_offset), shift);
  const __m256i resbl =
      _mm256_srai_epi32(_mm256_sub_epi32(mulbl, *round_offset), shift);

  // Signed saturating pack from i32 to i16:
  const __m256i packa = _mm256_packs_epi32(resal, resah);
  const __m256i packb = _mm256_packs_epi32(resbl, resbh);

  // Clip the values to the valid range
  const __m256i clipa =
      _mm256_min_epi16(_mm256_max_epi16(packa, *clip_low), *clip_high);
  const __m256i clipb =
      _mm256_min_epi16(_mm256_max_epi16(packb, *clip_low), *clip_high);

  // Store 16 pixels
  yy_storeu_256(dst, clipa);
  yy_storeu_256(dst + dst_stride, clipb);
}

static INLINE void highbd_blend_a64_d16_mask_subw0_subh0_w16_avx2(
    uint16_t *dst, int dst_stride, const CONV_BUF_TYPE *src0, int src0_stride,
    const CONV_BUF_TYPE *src1, int src1_stride, const uint8_t *mask,
    int mask_stride, int h, int w, const __m256i *round_offset, int shift,
    const __m256i *clip_low, const __m256i *clip_high,
    const __m256i *mask_max) {
  for (int i = 0; i < h; i += 2) {
    for (int j = 0; j < w; j += 16) {
      // Load 16x u8 alpha-mask values from each of two rows and pad to u16
      const __m128i masks_a8 = xx_loadu_128(mask + j);
      const __m128i masks_b8 = xx_loadu_128(mask + mask_stride + j);
      const __m256i mask0a = _mm256_cvtepu8_epi16(masks_a8);
      const __m256i mask0b = _mm256_cvtepu8_epi16(masks_b8);

      highbd_blend_a64_d16_mask_w16_avx2(
          dst + j, dst_stride, src0 + j, src0_stride, src1 + j, src1_stride,
          &mask0a, &mask0b, round_offset, shift, clip_low, clip_high, mask_max);
    }
    dst += dst_stride * 2;
    src0 += src0_stride * 2;
    src1 += src1_stride * 2;
    mask += mask_stride * 2;
  }
}

static INLINE void highbd_blend_a64_d16_mask_subw1_subh1_w16_avx2(
    uint16_t *dst, int dst_stride, const CONV_BUF_TYPE *src0, int src0_stride,
    const CONV_BUF_TYPE *src1, int src1_stride, const uint8_t *mask,
    int mask_stride, int h, int w, const __m256i *round_offset, int shift,
    const __m256i *clip_low, const __m256i *clip_high,
    const __m256i *mask_max) {
  const __m256i one_b = _mm256_set1_epi8(1);
  const __m256i two_w = _mm256_set1_epi16(2);
  for (int i = 0; i < h; i += 2) {
    for (int j = 0; j < w; j += 16) {
      // Load 32x u8 alpha-mask values from each of four rows, (saturating)
      // add pairs of rows, then use madd to add adjacent values.
      // Finally, divide down each result with rounding.
      const __m256i m0 = yy_loadu_256(mask + 0 * mask_stride + 2 * j);
      const __m256i m1 = yy_loadu_256(mask + 1 * mask_stride + 2 * j);
      const __m256i m2 = yy_loadu_256(mask + 2 * mask_stride + 2 * j);
      const __m256i m3 = yy_loadu_256(mask + 3 * mask_stride + 2 * j);

      const __m256i m01_8 = _mm256_adds_epu8(m0, m1);
      const __m256i m23_8 = _mm256_adds_epu8(m2, m3);

      const __m256i m01 = _mm256_maddubs_epi16(m01_8, one_b);
      const __m256i m23 = _mm256_maddubs_epi16(m23_8, one_b);

      const __m256i mask0a = _mm256_srli_epi16(_mm256_add_epi16(m01, two_w), 2);
      const __m256i mask0b = _mm256_srli_epi16(_mm256_add_epi16(m23, two_w), 2);
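      // As in the w8 sub-sampled variant, each output mask value is the
      // rounded average of a 2x2 block of mask bytes:
      //   (m[2i][2j] + m[2i][2j+1] + m[2i+1][2j] + m[2i+1][2j+1] + 2) >> 2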

      highbd_blend_a64_d16_mask_w16_avx2(
          dst + j, dst_stride, src0 + j, src0_stride, src1 + j, src1_stride,
          &mask0a, &mask0b, round_offset, shift, clip_low, clip_high, mask_max);
    }
    dst += dst_stride * 2;
    src0 += src0_stride * 2;
    src1 += src1_stride * 2;
    mask += mask_stride * 4;
  }
}

void aom_highbd_blend_a64_d16_mask_avx2(
    uint8_t *dst8, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
    uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
    const uint8_t *mask, uint32_t mask_stride, int w, int h, int subw, int subh,
    ConvolveParams *conv_params, const int bd) {
  uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
  const int round_bits =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int32_t round_offset =
      ((1 << (round_bits + bd)) + (1 << (round_bits + bd - 1)) -
       (1 << (round_bits - 1)))
      << AOM_BLEND_A64_ROUND_BITS;
  const __m256i v_round_offset = _mm256_set1_epi32(round_offset);
  const int shift = round_bits + AOM_BLEND_A64_ROUND_BITS;
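  // Subtracting round_offset removes the bias added to the 16-bit compound
  // prediction intermediates and folds in the rounding term for the final
  // right shift; both are pre-scaled by AOM_BLEND_A64_ROUND_BITS because the
  // mask weights are applied before that shift. See
  // aom_highbd_blend_a64_d16_mask_c for the equivalent scalar arithmetic.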

  const __m256i clip_low = _mm256_setzero_si256();
  const __m256i clip_high = _mm256_set1_epi16((1 << bd) - 1);
  const __m256i mask_max = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);

  assert(IMPLIES((void *)src0 == dst, src0_stride == dst_stride));
  assert(IMPLIES((void *)src1 == dst, src1_stride == dst_stride));

  assert(h >= 4);
  assert(w >= 4);
  assert(IS_POWER_OF_TWO(h));
  assert(IS_POWER_OF_TWO(w));

  if (subw == 0 && subh == 0) {
    switch (w) {
      case 4:
        highbd_blend_a64_d16_mask_subw0_subh0_w4_avx2(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, &v_round_offset, shift, &clip_low, &clip_high,
            &mask_max);
        break;
      case 8:
        highbd_blend_a64_d16_mask_subw0_subh0_w8_avx2(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, &v_round_offset, shift, &clip_low, &clip_high,
            &mask_max);
        break;
      default:  // >= 16
        highbd_blend_a64_d16_mask_subw0_subh0_w16_avx2(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, w, &v_round_offset, shift, &clip_low, &clip_high,
            &mask_max);
        break;
    }

  } else if (subw == 1 && subh == 1) {
    switch (w) {
      case 4:
        highbd_blend_a64_d16_mask_subw1_subh1_w4_avx2(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, &v_round_offset, shift, &clip_low, &clip_high,
            &mask_max);
        break;
      case 8:
        highbd_blend_a64_d16_mask_subw1_subh1_w8_avx2(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, &v_round_offset, shift, &clip_low, &clip_high,
            &mask_max);
        break;
      default:  // >= 16
        highbd_blend_a64_d16_mask_subw1_subh1_w16_avx2(
            dst, dst_stride, src0, src0_stride, src1, src1_stride, mask,
            mask_stride, h, w, &v_round_offset, shift, &clip_low, &clip_high,
            &mask_max);
        break;
    }
  } else {
    // Sub-sampling in only one axis doesn't seem to happen very often, so
    // fall back to the vanilla C implementation rather than adding optimised
    // code for those cases.
    aom_highbd_blend_a64_d16_mask_c(dst8, dst_stride, src0, src0_stride, src1,
                                    src1_stride, mask, mask_stride, w, h, subw,
                                    subh, conv_params, bd);
  }
}
#endif  // CONFIG_AV1_HIGHBITDEPTH
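
/*
 * Usage sketch (illustrative only; the setup of conv_params, strides and
 * buffers is an assumption, not taken from this file). dst8 is expected to be
 * a CONVERT_TO_BYTEPTR()-wrapped uint16_t buffer, and src0/src1 are the
 * 16-bit compound prediction buffers:
 *
 *   aom_highbd_blend_a64_d16_mask_avx2(dst8, dst_stride, src0, src0_stride,
 *                                      src1, src1_stride, mask, mask_stride,
 *                                      w, h, subw, subh, &conv_params, bd);
 */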