/*
 * Copyright (c) 2019, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <immintrin.h>
#include "config/av1_rtcd.h"
#include "av1/common/warped_motion.h"
#include "aom_dsp/x86/synonyms.h"

DECLARE_ALIGNED(32, static const uint8_t, shuffle_alpha0_mask01_avx2[32]) = {
  0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
  0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
};

DECLARE_ALIGNED(32, static const uint8_t, shuffle_alpha0_mask23_avx2[32]) = {
  2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
  2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3
};

DECLARE_ALIGNED(32, static const uint8_t, shuffle_alpha0_mask45_avx2[32]) = {
  4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
  4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5
};

DECLARE_ALIGNED(32, static const uint8_t, shuffle_alpha0_mask67_avx2[32]) = {
  6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
  6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7
};

DECLARE_ALIGNED(32, static const uint8_t, shuffle_gamma0_mask0_avx2[32]) = {
  0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3,
  0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3
};

DECLARE_ALIGNED(32, static const uint8_t, shuffle_gamma0_mask1_avx2[32]) = {
  4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7,
  4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7
};

DECLARE_ALIGNED(32, static const uint8_t, shuffle_gamma0_mask2_avx2[32]) = {
  8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11,
  8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11
};

DECLARE_ALIGNED(32, static const uint8_t, shuffle_gamma0_mask3_avx2[32]) = {
  12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15,
  12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15
};

DECLARE_ALIGNED(32, static const uint8_t,
                shuffle_src0[32]) = { 0, 2, 2, 4, 4, 6, 6, 8, 1, 3, 3,
                                      5, 5, 7, 7, 9, 0, 2, 2, 4, 4, 6,
                                      6, 8, 1, 3, 3, 5, 5, 7, 7, 9 };

DECLARE_ALIGNED(32, static const uint8_t,
                shuffle_src1[32]) = { 4,  6,  6,  8,  8,  10, 10, 12, 5,  7, 7,
                                      9,  9,  11, 11, 13, 4,  6,  6,  8,  8, 10,
                                      10, 12, 5,  7,  7,  9,  9,  11, 11, 13 };

DECLARE_ALIGNED(32, static const uint8_t,
                shuffle_src2[32]) = { 1, 3, 3, 5, 5,  7, 7, 9, 2, 4, 4,
                                      6, 6, 8, 8, 10, 1, 3, 3, 5, 5, 7,
                                      7, 9, 2, 4, 4,  6, 6, 8, 8, 10 };

DECLARE_ALIGNED(32, static const uint8_t,
                shuffle_src3[32]) = { 5,  7,  7,  9,  9,  11, 11, 13, 6,  8, 8,
                                      10, 10, 12, 12, 14, 5,  7,  7,  9,  9, 11,
                                      11, 13, 6,  8,  8,  10, 10, 12, 12, 14 };

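// Horizontally filter two source rows packed into 'src' (one row per 128-bit
// lane). The byte shuffles pair the source pixels so that
// _mm256_maddubs_epi16 applies the 8-tap filter two taps at a time; the
// rounded and shifted result for the row pair is stored in horz_out[row],
// with even output pixels in the low half of each lane and odd ones in the
// high half.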
static INLINE void filter_src_pixels_avx2(const __m256i src, __m256i *horz_out,
                                          __m256i *coeff,
                                          const __m256i *shuffle_src,
                                          const __m256i *round_const,
                                          const __m128i *shift, int row) {
  const __m256i src_0 = _mm256_shuffle_epi8(src, shuffle_src[0]);
  const __m256i src_1 = _mm256_shuffle_epi8(src, shuffle_src[1]);
  const __m256i src_2 = _mm256_shuffle_epi8(src, shuffle_src[2]);
  const __m256i src_3 = _mm256_shuffle_epi8(src, shuffle_src[3]);

  const __m256i res_02 = _mm256_maddubs_epi16(src_0, coeff[0]);
  const __m256i res_46 = _mm256_maddubs_epi16(src_1, coeff[1]);
  const __m256i res_13 = _mm256_maddubs_epi16(src_2, coeff[2]);
  const __m256i res_57 = _mm256_maddubs_epi16(src_3, coeff[3]);

  const __m256i res_even = _mm256_add_epi16(res_02, res_46);
  const __m256i res_odd = _mm256_add_epi16(res_13, res_57);
  const __m256i res =
      _mm256_add_epi16(_mm256_add_epi16(res_even, res_odd), *round_const);
  horz_out[row] = _mm256_srl_epi16(res, *shift);
}

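// Gather the 8-bit horizontal filter coefficients for two rows (offsets sx
// and sx + beta) and transpose them into the interleaved layout expected by
// filter_src_pixels_avx2: coeff[0..3] hold the tap pairs (0,2), (4,6), (1,3)
// and (5,7) for all eight output pixels of both rows.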
static INLINE void prepare_horizontal_filter_coeff_avx2(int alpha, int beta,
                                                        int sx,
                                                        __m256i *coeff) {
  __m128i tmp_0 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[((unsigned)(sx + 0 * alpha)) >>
                                  WARPEDDIFF_PREC_BITS]);
  __m128i tmp_1 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[((unsigned)(sx + 1 * alpha)) >>
                                  WARPEDDIFF_PREC_BITS]);
  __m128i tmp_2 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[((unsigned)(sx + 2 * alpha)) >>
                                  WARPEDDIFF_PREC_BITS]);
  __m128i tmp_3 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[((unsigned)(sx + 3 * alpha)) >>
                                  WARPEDDIFF_PREC_BITS]);

  __m128i tmp_4 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[((unsigned)(sx + 4 * alpha)) >>
                                  WARPEDDIFF_PREC_BITS]);
  __m128i tmp_5 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[((unsigned)(sx + 5 * alpha)) >>
                                  WARPEDDIFF_PREC_BITS]);
  __m128i tmp_6 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[((unsigned)(sx + 6 * alpha)) >>
                                  WARPEDDIFF_PREC_BITS]);
  __m128i tmp_7 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[((unsigned)(sx + 7 * alpha)) >>
                                  WARPEDDIFF_PREC_BITS]);

  __m256i tmp0_256 = _mm256_castsi128_si256(tmp_0);
  __m256i tmp2_256 = _mm256_castsi128_si256(tmp_2);
  __m256i tmp1_256 = _mm256_castsi128_si256(tmp_1);
  __m256i tmp3_256 = _mm256_castsi128_si256(tmp_3);

  __m256i tmp4_256 = _mm256_castsi128_si256(tmp_4);
  __m256i tmp6_256 = _mm256_castsi128_si256(tmp_6);
  __m256i tmp5_256 = _mm256_castsi128_si256(tmp_5);
  __m256i tmp7_256 = _mm256_castsi128_si256(tmp_7);

  __m128i tmp_8 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(unsigned)((sx + beta) + 0 * alpha) >>
                                  WARPEDDIFF_PREC_BITS]);
  tmp0_256 = _mm256_inserti128_si256(tmp0_256, tmp_8, 1);

  __m128i tmp_9 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(unsigned)((sx + beta) + 1 * alpha) >>
                                  WARPEDDIFF_PREC_BITS]);
  tmp1_256 = _mm256_inserti128_si256(tmp1_256, tmp_9, 1);

  __m128i tmp_10 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(unsigned)((sx + beta) + 2 * alpha) >>
                                  WARPEDDIFF_PREC_BITS]);
  tmp2_256 = _mm256_inserti128_si256(tmp2_256, tmp_10, 1);

  __m128i tmp_11 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(unsigned)((sx + beta) + 3 * alpha) >>
                                  WARPEDDIFF_PREC_BITS]);
  tmp3_256 = _mm256_inserti128_si256(tmp3_256, tmp_11, 1);

  tmp_2 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(unsigned)((sx + beta) + 4 * alpha) >>
                                  WARPEDDIFF_PREC_BITS]);
  tmp4_256 = _mm256_inserti128_si256(tmp4_256, tmp_2, 1);

  tmp_3 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(unsigned)((sx + beta) + 5 * alpha) >>
                                  WARPEDDIFF_PREC_BITS]);
  tmp5_256 = _mm256_inserti128_si256(tmp5_256, tmp_3, 1);

  tmp_6 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(unsigned)((sx + beta) + 6 * alpha) >>
                                  WARPEDDIFF_PREC_BITS]);
  tmp6_256 = _mm256_inserti128_si256(tmp6_256, tmp_6, 1);

  tmp_7 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(unsigned)((sx + beta) + 7 * alpha) >>
                                  WARPEDDIFF_PREC_BITS]);
  tmp7_256 = _mm256_inserti128_si256(tmp7_256, tmp_7, 1);

  const __m256i tmp_12 = _mm256_unpacklo_epi16(tmp0_256, tmp2_256);
  const __m256i tmp_13 = _mm256_unpacklo_epi16(tmp1_256, tmp3_256);
  const __m256i tmp_14 = _mm256_unpacklo_epi16(tmp4_256, tmp6_256);
  const __m256i tmp_15 = _mm256_unpacklo_epi16(tmp5_256, tmp7_256);

  const __m256i res_0 = _mm256_unpacklo_epi32(tmp_12, tmp_14);
  const __m256i res_1 = _mm256_unpackhi_epi32(tmp_12, tmp_14);
  const __m256i res_2 = _mm256_unpacklo_epi32(tmp_13, tmp_15);
  const __m256i res_3 = _mm256_unpackhi_epi32(tmp_13, tmp_15);

  coeff[0] = _mm256_unpacklo_epi64(res_0, res_2);
  coeff[1] = _mm256_unpackhi_epi64(res_0, res_2);
  coeff[2] = _mm256_unpacklo_epi64(res_1, res_3);
  coeff[3] = _mm256_unpackhi_epi64(res_1, res_3);
}

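// Special case for beta == 0: both rows use the same filter offsets, so the
// coefficients are loaded once and broadcast to both 128-bit lanes.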
static INLINE void prepare_horizontal_filter_coeff_beta0_avx2(int alpha, int sx,
                                                              __m256i *coeff) {
  __m128i tmp_0 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(sx + 0 * alpha) >> WARPEDDIFF_PREC_BITS]);
  __m128i tmp_1 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(sx + 1 * alpha) >> WARPEDDIFF_PREC_BITS]);
  __m128i tmp_2 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(sx + 2 * alpha) >> WARPEDDIFF_PREC_BITS]);
  __m128i tmp_3 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(sx + 3 * alpha) >> WARPEDDIFF_PREC_BITS]);
  __m128i tmp_4 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(sx + 4 * alpha) >> WARPEDDIFF_PREC_BITS]);
  __m128i tmp_5 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(sx + 5 * alpha) >> WARPEDDIFF_PREC_BITS]);
  __m128i tmp_6 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(sx + 6 * alpha) >> WARPEDDIFF_PREC_BITS]);
  __m128i tmp_7 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(sx + 7 * alpha) >> WARPEDDIFF_PREC_BITS]);

  tmp_0 = _mm_unpacklo_epi16(tmp_0, tmp_2);
  tmp_1 = _mm_unpacklo_epi16(tmp_1, tmp_3);
  tmp_4 = _mm_unpacklo_epi16(tmp_4, tmp_6);
  tmp_5 = _mm_unpacklo_epi16(tmp_5, tmp_7);

  const __m256i tmp_12 = _mm256_broadcastsi128_si256(tmp_0);
  const __m256i tmp_13 = _mm256_broadcastsi128_si256(tmp_1);
  const __m256i tmp_14 = _mm256_broadcastsi128_si256(tmp_4);
  const __m256i tmp_15 = _mm256_broadcastsi128_si256(tmp_5);

  const __m256i res_0 = _mm256_unpacklo_epi32(tmp_12, tmp_14);
  const __m256i res_1 = _mm256_unpackhi_epi32(tmp_12, tmp_14);
  const __m256i res_2 = _mm256_unpacklo_epi32(tmp_13, tmp_15);
  const __m256i res_3 = _mm256_unpackhi_epi32(tmp_13, tmp_15);

  coeff[0] = _mm256_unpacklo_epi64(res_0, res_2);
  coeff[1] = _mm256_unpackhi_epi64(res_0, res_2);
  coeff[2] = _mm256_unpacklo_epi64(res_1, res_3);
  coeff[3] = _mm256_unpackhi_epi64(res_1, res_3);
}

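// Special case for alpha == 0: every output pixel in a row shares one filter,
// so a single 8-tap kernel is loaded per row and replicated across the
// register with the shuffle_alpha0 masks.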
static INLINE void prepare_horizontal_filter_coeff_alpha0_avx2(int beta, int sx,
                                                               __m256i *coeff) {
  const __m128i tmp_0 =
      _mm_loadl_epi64((__m128i *)&av1_filter_8bit[sx >> WARPEDDIFF_PREC_BITS]);
  const __m128i tmp_1 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(sx + beta) >> WARPEDDIFF_PREC_BITS]);

  const __m256i res_0 =
      _mm256_inserti128_si256(_mm256_castsi128_si256(tmp_0), tmp_1, 0x1);

  coeff[0] = _mm256_shuffle_epi8(
      res_0, _mm256_load_si256((__m256i *)shuffle_alpha0_mask01_avx2));
  coeff[1] = _mm256_shuffle_epi8(
      res_0, _mm256_load_si256((__m256i *)shuffle_alpha0_mask23_avx2));
  coeff[2] = _mm256_shuffle_epi8(
      res_0, _mm256_load_si256((__m256i *)shuffle_alpha0_mask45_avx2));
  coeff[3] = _mm256_shuffle_epi8(
      res_0, _mm256_load_si256((__m256i *)shuffle_alpha0_mask67_avx2));
}

static INLINE void horizontal_filter_avx2(const __m256i src, __m256i *horz_out,
                                          int sx, int alpha, int beta, int row,
                                          const __m256i *shuffle_src,
                                          const __m256i *round_const,
                                          const __m128i *shift) {
  __m256i coeff[4];
  prepare_horizontal_filter_coeff_avx2(alpha, beta, sx, coeff);
  filter_src_pixels_avx2(src, horz_out, coeff, shuffle_src, round_const, shift,
                         row);
}
static INLINE void prepare_horizontal_filter_coeff(int alpha, int sx,
                                                   __m256i *coeff) {
  const __m128i tmp_0 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(sx + 0 * alpha) >> WARPEDDIFF_PREC_BITS]);
  const __m128i tmp_1 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(sx + 1 * alpha) >> WARPEDDIFF_PREC_BITS]);
  const __m128i tmp_2 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(sx + 2 * alpha) >> WARPEDDIFF_PREC_BITS]);
  const __m128i tmp_3 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(sx + 3 * alpha) >> WARPEDDIFF_PREC_BITS]);
  const __m128i tmp_4 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(sx + 4 * alpha) >> WARPEDDIFF_PREC_BITS]);
  const __m128i tmp_5 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(sx + 5 * alpha) >> WARPEDDIFF_PREC_BITS]);
  const __m128i tmp_6 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(sx + 6 * alpha) >> WARPEDDIFF_PREC_BITS]);
  const __m128i tmp_7 = _mm_loadl_epi64(
      (__m128i *)&av1_filter_8bit[(sx + 7 * alpha) >> WARPEDDIFF_PREC_BITS]);

  const __m128i tmp_8 = _mm_unpacklo_epi16(tmp_0, tmp_2);
  const __m128i tmp_9 = _mm_unpacklo_epi16(tmp_1, tmp_3);
  const __m128i tmp_10 = _mm_unpacklo_epi16(tmp_4, tmp_6);
  const __m128i tmp_11 = _mm_unpacklo_epi16(tmp_5, tmp_7);

  const __m128i tmp_12 = _mm_unpacklo_epi32(tmp_8, tmp_10);
  const __m128i tmp_13 = _mm_unpackhi_epi32(tmp_8, tmp_10);
  const __m128i tmp_14 = _mm_unpacklo_epi32(tmp_9, tmp_11);
  const __m128i tmp_15 = _mm_unpackhi_epi32(tmp_9, tmp_11);

  coeff[0] = _mm256_castsi128_si256(_mm_unpacklo_epi64(tmp_12, tmp_14));
  coeff[1] = _mm256_castsi128_si256(_mm_unpackhi_epi64(tmp_12, tmp_14));
  coeff[2] = _mm256_castsi128_si256(_mm_unpacklo_epi64(tmp_13, tmp_15));
  coeff[3] = _mm256_castsi128_si256(_mm_unpackhi_epi64(tmp_13, tmp_15));
}

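// Horizontal pass: filter the block of source rows needed by the vertical
// pass into horz_out. Rows are processed two at a time (one per 128-bit
// lane); the final odd row is handled separately with the 128-bit coefficient
// setup above.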
static INLINE void warp_horizontal_filter_avx2(
    const uint8_t *ref, __m256i *horz_out, int stride, int32_t ix4, int32_t iy4,
    int32_t sx4, int alpha, int beta, int p_height, int height, int i,
    const __m256i *round_const, const __m128i *shift,
    const __m256i *shuffle_src) {
  int k, iy, sx, row = 0;
  __m256i coeff[4];
  for (k = -7; k <= (AOMMIN(8, p_height - i) - 2); k += 2) {
    iy = iy4 + k;
    iy = clamp(iy, 0, height - 1);
    const __m128i src_0 =
        _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7));
    iy = iy4 + k + 1;
    iy = clamp(iy, 0, height - 1);
    const __m128i src_1 =
        _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7));
    const __m256i src_01 =
        _mm256_inserti128_si256(_mm256_castsi128_si256(src_0), src_1, 0x1);
    sx = sx4 + beta * (k + 4);
    horizontal_filter_avx2(src_01, horz_out, sx, alpha, beta, row, shuffle_src,
                           round_const, shift);
    row += 1;
  }
  iy = iy4 + k;
  iy = clamp(iy, 0, height - 1);
  const __m256i src_01 = _mm256_castsi128_si256(
      _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7)));
  sx = sx4 + beta * (k + 4);
  prepare_horizontal_filter_coeff(alpha, sx, coeff);
  filter_src_pixels_avx2(src_01, horz_out, coeff, shuffle_src, round_const,
                         shift, row);
}

static INLINE void warp_horizontal_filter_alpha0_avx2(
    const uint8_t *ref, __m256i *horz_out, int stride, int32_t ix4, int32_t iy4,
    int32_t sx4, int alpha, int beta, int p_height, int height, int i,
    const __m256i *round_const, const __m128i *shift,
    const __m256i *shuffle_src) {
  (void)alpha;
  int k, iy, sx, row = 0;
  __m256i coeff[4];
  for (k = -7; k <= (AOMMIN(8, p_height - i) - 2); k += 2) {
    iy = iy4 + k;
    iy = clamp(iy, 0, height - 1);
    const __m128i src_0 =
        _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7));
    iy = iy4 + k + 1;
    iy = clamp(iy, 0, height - 1);
    const __m128i src_1 =
        _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7));
    const __m256i src_01 =
        _mm256_inserti128_si256(_mm256_castsi128_si256(src_0), src_1, 0x1);
    sx = sx4 + beta * (k + 4);
    prepare_horizontal_filter_coeff_alpha0_avx2(beta, sx, coeff);
    filter_src_pixels_avx2(src_01, horz_out, coeff, shuffle_src, round_const,
                           shift, row);
    row += 1;
  }
  iy = iy4 + k;
  iy = clamp(iy, 0, height - 1);
  const __m256i src_01 = _mm256_castsi128_si256(
      _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7)));
  sx = sx4 + beta * (k + 4);
  prepare_horizontal_filter_coeff_alpha0_avx2(beta, sx, coeff);
  filter_src_pixels_avx2(src_01, horz_out, coeff, shuffle_src, round_const,
                         shift, row);
}

static INLINE void warp_horizontal_filter_beta0_avx2(
    const uint8_t *ref, __m256i *horz_out, int stride, int32_t ix4, int32_t iy4,
    int32_t sx4, int alpha, int beta, int p_height, int height, int i,
    const __m256i *round_const, const __m128i *shift,
    const __m256i *shuffle_src) {
  (void)beta;
  int k, iy, row = 0;
  __m256i coeff[4];
  prepare_horizontal_filter_coeff_beta0_avx2(alpha, sx4, coeff);
  for (k = -7; k <= (AOMMIN(8, p_height - i) - 2); k += 2) {
    iy = iy4 + k;
    iy = clamp(iy, 0, height - 1);
    const __m128i src_0 =
        _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7));
    iy = iy4 + k + 1;
    iy = clamp(iy, 0, height - 1);
    const __m128i src_1 =
        _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7));
    const __m256i src_01 =
        _mm256_inserti128_si256(_mm256_castsi128_si256(src_0), src_1, 0x1);
    filter_src_pixels_avx2(src_01, horz_out, coeff, shuffle_src, round_const,
                           shift, row);
    row += 1;
  }
  iy = iy4 + k;
  iy = clamp(iy, 0, height - 1);
  const __m256i src_01 = _mm256_castsi128_si256(
      _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7)));
  filter_src_pixels_avx2(src_01, horz_out, coeff, shuffle_src, round_const,
                         shift, row);
}

static INLINE void warp_horizontal_filter_alpha0_beta0_avx2(
    const uint8_t *ref, __m256i *horz_out, int stride, int32_t ix4, int32_t iy4,
    int32_t sx4, int alpha, int beta, int p_height, int height, int i,
    const __m256i *round_const, const __m128i *shift,
    const __m256i *shuffle_src) {
  (void)alpha;
  int k, iy, row = 0;
  __m256i coeff[4];
  prepare_horizontal_filter_coeff_alpha0_avx2(beta, sx4, coeff);
  for (k = -7; k <= (AOMMIN(8, p_height - i) - 2); k += 2) {
    iy = iy4 + k;
    iy = clamp(iy, 0, height - 1);
    const __m128i src0 =
        _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7));
    iy = iy4 + k + 1;
    iy = clamp(iy, 0, height - 1);
    const __m128i src1 =
        _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7));
    const __m256i src_01 =
        _mm256_inserti128_si256(_mm256_castsi128_si256(src0), src1, 0x1);
    filter_src_pixels_avx2(src_01, horz_out, coeff, shuffle_src, round_const,
                           shift, row);
    row += 1;
  }
  iy = iy4 + k;
  iy = clamp(iy, 0, height - 1);
  const __m256i src_01 = _mm256_castsi128_si256(
      _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7)));
  filter_src_pixels_avx2(src_01, horz_out, coeff, shuffle_src, round_const,
                         shift, row);
}

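// Compute the rounding constants and the interleaved forward/backward weights
// used when averaging a compound (two-reference) prediction.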
static INLINE void unpack_weights_and_set_round_const_avx2(
    ConvolveParams *conv_params, const int round_bits, const int offset_bits,
    __m256i *res_sub_const, __m256i *round_bits_const, __m256i *wt) {
  *res_sub_const =
      _mm256_set1_epi16(-(1 << (offset_bits - conv_params->round_1)) -
                        (1 << (offset_bits - conv_params->round_1 - 1)));
  *round_bits_const = _mm256_set1_epi16(((1 << round_bits) >> 1));

  const int w0 = conv_params->fwd_offset;
  const int w1 = conv_params->bck_offset;
  const __m256i wt0 = _mm256_set1_epi16((short)w0);
  const __m256i wt1 = _mm256_set1_epi16((short)w1);
  *wt = _mm256_unpacklo_epi16(wt0, wt1);
}

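// Gather the 16-bit vertical filter coefficients for two output rows (offsets
// sy and sy + delta, one per 128-bit lane) and transpose them so that
// coeffs[0..3] hold the tap pairs for the even output columns and
// coeffs[4..7] those for the odd output columns, matching the layout consumed
// by filter_src_pixels_vertical_avx2.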
static INLINE void prepare_vertical_filter_coeffs_avx2(int gamma, int delta,
                                                       int sy,
                                                       __m256i *coeffs) {
  __m128i filt_00 =
      _mm_loadu_si128((__m128i *)(av1_warped_filter +
                                  ((sy + 0 * gamma) >> WARPEDDIFF_PREC_BITS)));
  __m128i filt_01 =
      _mm_loadu_si128((__m128i *)(av1_warped_filter +
                                  ((sy + 2 * gamma) >> WARPEDDIFF_PREC_BITS)));
  __m128i filt_02 =
      _mm_loadu_si128((__m128i *)(av1_warped_filter +
                                  ((sy + 4 * gamma) >> WARPEDDIFF_PREC_BITS)));
  __m128i filt_03 =
      _mm_loadu_si128((__m128i *)(av1_warped_filter +
                                  ((sy + 6 * gamma) >> WARPEDDIFF_PREC_BITS)));

  __m128i filt_10 = _mm_loadu_si128(
      (__m128i *)(av1_warped_filter +
                  (((sy + delta) + 0 * gamma) >> WARPEDDIFF_PREC_BITS)));
  __m128i filt_11 = _mm_loadu_si128(
      (__m128i *)(av1_warped_filter +
                  (((sy + delta) + 2 * gamma) >> WARPEDDIFF_PREC_BITS)));
  __m128i filt_12 = _mm_loadu_si128(
      (__m128i *)(av1_warped_filter +
                  (((sy + delta) + 4 * gamma) >> WARPEDDIFF_PREC_BITS)));
  __m128i filt_13 = _mm_loadu_si128(
      (__m128i *)(av1_warped_filter +
                  (((sy + delta) + 6 * gamma) >> WARPEDDIFF_PREC_BITS)));

  __m256i filt_0 =
      _mm256_inserti128_si256(_mm256_castsi128_si256(filt_00), filt_10, 0x1);
  __m256i filt_1 =
      _mm256_inserti128_si256(_mm256_castsi128_si256(filt_01), filt_11, 0x1);
  __m256i filt_2 =
      _mm256_inserti128_si256(_mm256_castsi128_si256(filt_02), filt_12, 0x1);
  __m256i filt_3 =
      _mm256_inserti128_si256(_mm256_castsi128_si256(filt_03), filt_13, 0x1);

  __m256i res_0 = _mm256_unpacklo_epi32(filt_0, filt_1);
  __m256i res_1 = _mm256_unpacklo_epi32(filt_2, filt_3);
  __m256i res_2 = _mm256_unpackhi_epi32(filt_0, filt_1);
  __m256i res_3 = _mm256_unpackhi_epi32(filt_2, filt_3);

  coeffs[0] = _mm256_unpacklo_epi64(res_0, res_1);
  coeffs[1] = _mm256_unpackhi_epi64(res_0, res_1);
  coeffs[2] = _mm256_unpacklo_epi64(res_2, res_3);
  coeffs[3] = _mm256_unpackhi_epi64(res_2, res_3);

  filt_00 =
      _mm_loadu_si128((__m128i *)(av1_warped_filter +
                                  ((sy + 1 * gamma) >> WARPEDDIFF_PREC_BITS)));
  filt_01 =
      _mm_loadu_si128((__m128i *)(av1_warped_filter +
                                  ((sy + 3 * gamma) >> WARPEDDIFF_PREC_BITS)));
  filt_02 =
      _mm_loadu_si128((__m128i *)(av1_warped_filter +
                                  ((sy + 5 * gamma) >> WARPEDDIFF_PREC_BITS)));
  filt_03 =
      _mm_loadu_si128((__m128i *)(av1_warped_filter +
                                  ((sy + 7 * gamma) >> WARPEDDIFF_PREC_BITS)));

  filt_10 = _mm_loadu_si128(
      (__m128i *)(av1_warped_filter +
                  (((sy + delta) + 1 * gamma) >> WARPEDDIFF_PREC_BITS)));
  filt_11 = _mm_loadu_si128(
      (__m128i *)(av1_warped_filter +
                  (((sy + delta) + 3 * gamma) >> WARPEDDIFF_PREC_BITS)));
  filt_12 = _mm_loadu_si128(
      (__m128i *)(av1_warped_filter +
                  (((sy + delta) + 5 * gamma) >> WARPEDDIFF_PREC_BITS)));
  filt_13 = _mm_loadu_si128(
      (__m128i *)(av1_warped_filter +
                  (((sy + delta) + 7 * gamma) >> WARPEDDIFF_PREC_BITS)));

  filt_0 =
      _mm256_inserti128_si256(_mm256_castsi128_si256(filt_00), filt_10, 0x1);
  filt_1 =
      _mm256_inserti128_si256(_mm256_castsi128_si256(filt_01), filt_11, 0x1);
  filt_2 =
      _mm256_inserti128_si256(_mm256_castsi128_si256(filt_02), filt_12, 0x1);
  filt_3 =
      _mm256_inserti128_si256(_mm256_castsi128_si256(filt_03), filt_13, 0x1);

  res_0 = _mm256_unpacklo_epi32(filt_0, filt_1);
  res_1 = _mm256_unpacklo_epi32(filt_2, filt_3);
  res_2 = _mm256_unpackhi_epi32(filt_0, filt_1);
  res_3 = _mm256_unpackhi_epi32(filt_2, filt_3);

  coeffs[4] = _mm256_unpacklo_epi64(res_0, res_1);
  coeffs[5] = _mm256_unpackhi_epi64(res_0, res_1);
  coeffs[6] = _mm256_unpacklo_epi64(res_2, res_3);
  coeffs[7] = _mm256_unpackhi_epi64(res_2, res_3);
}

static INLINE void prepare_vertical_filter_coeffs_delta0_avx2(int gamma, int sy,
                                                              __m256i *coeffs) {
  __m128i filt_00 =
      _mm_loadu_si128((__m128i *)(av1_warped_filter +
                                  ((sy + 0 * gamma) >> WARPEDDIFF_PREC_BITS)));
  __m128i filt_01 =
      _mm_loadu_si128((__m128i *)(av1_warped_filter +
                                  ((sy + 2 * gamma) >> WARPEDDIFF_PREC_BITS)));
  __m128i filt_02 =
      _mm_loadu_si128((__m128i *)(av1_warped_filter +
                                  ((sy + 4 * gamma) >> WARPEDDIFF_PREC_BITS)));
  __m128i filt_03 =
      _mm_loadu_si128((__m128i *)(av1_warped_filter +
                                  ((sy + 6 * gamma) >> WARPEDDIFF_PREC_BITS)));

  __m256i filt_0 = _mm256_broadcastsi128_si256(filt_00);
  __m256i filt_1 = _mm256_broadcastsi128_si256(filt_01);
  __m256i filt_2 = _mm256_broadcastsi128_si256(filt_02);
  __m256i filt_3 = _mm256_broadcastsi128_si256(filt_03);

  __m256i res_0 = _mm256_unpacklo_epi32(filt_0, filt_1);
  __m256i res_1 = _mm256_unpacklo_epi32(filt_2, filt_3);
  __m256i res_2 = _mm256_unpackhi_epi32(filt_0, filt_1);
  __m256i res_3 = _mm256_unpackhi_epi32(filt_2, filt_3);

  coeffs[0] = _mm256_unpacklo_epi64(res_0, res_1);
  coeffs[1] = _mm256_unpackhi_epi64(res_0, res_1);
  coeffs[2] = _mm256_unpacklo_epi64(res_2, res_3);
  coeffs[3] = _mm256_unpackhi_epi64(res_2, res_3);

  filt_00 =
      _mm_loadu_si128((__m128i *)(av1_warped_filter +
                                  ((sy + 1 * gamma) >> WARPEDDIFF_PREC_BITS)));
  filt_01 =
      _mm_loadu_si128((__m128i *)(av1_warped_filter +
                                  ((sy + 3 * gamma) >> WARPEDDIFF_PREC_BITS)));
  filt_02 =
      _mm_loadu_si128((__m128i *)(av1_warped_filter +
                                  ((sy + 5 * gamma) >> WARPEDDIFF_PREC_BITS)));
  filt_03 =
      _mm_loadu_si128((__m128i *)(av1_warped_filter +
                                  ((sy + 7 * gamma) >> WARPEDDIFF_PREC_BITS)));

  filt_0 = _mm256_broadcastsi128_si256(filt_00);
  filt_1 = _mm256_broadcastsi128_si256(filt_01);
  filt_2 = _mm256_broadcastsi128_si256(filt_02);
  filt_3 = _mm256_broadcastsi128_si256(filt_03);

  res_0 = _mm256_unpacklo_epi32(filt_0, filt_1);
  res_1 = _mm256_unpacklo_epi32(filt_2, filt_3);
  res_2 = _mm256_unpackhi_epi32(filt_0, filt_1);
  res_3 = _mm256_unpackhi_epi32(filt_2, filt_3);

  coeffs[4] = _mm256_unpacklo_epi64(res_0, res_1);
  coeffs[5] = _mm256_unpackhi_epi64(res_0, res_1);
  coeffs[6] = _mm256_unpacklo_epi64(res_2, res_3);
  coeffs[7] = _mm256_unpackhi_epi64(res_2, res_3);
}

static INLINE void prepare_vertical_filter_coeffs_gamma0_avx2(int delta, int sy,
                                                              __m256i *coeffs) {
  const __m128i filt_0 = _mm_loadu_si128(
      (__m128i *)(av1_warped_filter + (sy >> WARPEDDIFF_PREC_BITS)));
  const __m128i filt_1 = _mm_loadu_si128(
      (__m128i *)(av1_warped_filter + ((sy + delta) >> WARPEDDIFF_PREC_BITS)));

  __m256i res_0 =
      _mm256_inserti128_si256(_mm256_castsi128_si256(filt_0), filt_1, 0x1);

  coeffs[0] = _mm256_shuffle_epi8(
      res_0, _mm256_load_si256((__m256i *)shuffle_gamma0_mask0_avx2));
  coeffs[1] = _mm256_shuffle_epi8(
      res_0, _mm256_load_si256((__m256i *)shuffle_gamma0_mask1_avx2));
  coeffs[2] = _mm256_shuffle_epi8(
      res_0, _mm256_load_si256((__m256i *)shuffle_gamma0_mask2_avx2));
  coeffs[3] = _mm256_shuffle_epi8(
      res_0, _mm256_load_si256((__m256i *)shuffle_gamma0_mask3_avx2));

  coeffs[4] = coeffs[0];
  coeffs[5] = coeffs[1];
  coeffs[6] = coeffs[2];
  coeffs[7] = coeffs[3];
}

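// Apply the 8-tap vertical filter to the horizontally filtered rows, two
// output rows at a time (one per lane). The 32-bit results for output pixels
// 0..3 are returned in *res_lo and those for pixels 4..7 in *res_hi.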
static INLINE void filter_src_pixels_vertical_avx2(__m256i *horz_out,
                                                   __m256i *src,
                                                   __m256i *coeffs,
                                                   __m256i *res_lo,
                                                   __m256i *res_hi, int row) {
  const __m256i src_6 = horz_out[row + 3];
  const __m256i src_7 =
      _mm256_permute2x128_si256(horz_out[row + 3], horz_out[row + 4], 0x21);

  src[6] = _mm256_unpacklo_epi16(src_6, src_7);

  const __m256i res_0 = _mm256_madd_epi16(src[0], coeffs[0]);
  const __m256i res_2 = _mm256_madd_epi16(src[2], coeffs[1]);
  const __m256i res_4 = _mm256_madd_epi16(src[4], coeffs[2]);
  const __m256i res_6 = _mm256_madd_epi16(src[6], coeffs[3]);

  const __m256i res_even = _mm256_add_epi32(_mm256_add_epi32(res_0, res_2),
                                            _mm256_add_epi32(res_4, res_6));

  src[7] = _mm256_unpackhi_epi16(src_6, src_7);

  const __m256i res_1 = _mm256_madd_epi16(src[1], coeffs[4]);
  const __m256i res_3 = _mm256_madd_epi16(src[3], coeffs[5]);
  const __m256i res_5 = _mm256_madd_epi16(src[5], coeffs[6]);
  const __m256i res_7 = _mm256_madd_epi16(src[7], coeffs[7]);

  const __m256i res_odd = _mm256_add_epi32(_mm256_add_epi32(res_1, res_3),
                                           _mm256_add_epi32(res_5, res_7));

  // Rearrange pixels back into the order 0 ... 7
  *res_lo = _mm256_unpacklo_epi32(res_even, res_odd);
  *res_hi = _mm256_unpackhi_epi32(res_even, res_odd);
}

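// Round the 32-bit vertical filter results for two output rows and store
// them. In the compound path the values are either written to the 16-bit
// conv_params->dst buffer or averaged with it (optionally distance-weighted)
// and packed to 8 bits into 'pred'; in the single-reference path they are
// packed to 8 bits and written to 'pred' directly.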
static INLINE void store_vertical_filter_output_avx2(
    const __m256i *res_lo, const __m256i *res_hi, const __m256i *res_add_const,
    const __m256i *wt, const __m256i *res_sub_const,
    const __m256i *round_bits_const, uint8_t *pred, ConvolveParams *conv_params,
    int i, int j, int k, const int reduce_bits_vert, int p_stride, int p_width,
    const int round_bits) {
  __m256i res_lo_1 = *res_lo;
  __m256i res_hi_1 = *res_hi;

  if (conv_params->is_compound) {
    __m128i *const p_0 =
        (__m128i *)&conv_params->dst[(i + k + 4) * conv_params->dst_stride + j];
    __m128i *const p_1 =
        (__m128i *)&conv_params
            ->dst[(i + (k + 1) + 4) * conv_params->dst_stride + j];

    res_lo_1 = _mm256_srai_epi32(_mm256_add_epi32(res_lo_1, *res_add_const),
                                 reduce_bits_vert);

    const __m256i temp_lo_16 = _mm256_packus_epi32(res_lo_1, res_lo_1);
    __m256i res_lo_16;
    if (conv_params->do_average) {
      __m128i *const dst8_0 = (__m128i *)&pred[(i + k + 4) * p_stride + j];
      __m128i *const dst8_1 =
          (__m128i *)&pred[(i + (k + 1) + 4) * p_stride + j];
      const __m128i p_16_0 = _mm_loadl_epi64(p_0);
      const __m128i p_16_1 = _mm_loadl_epi64(p_1);
      const __m256i p_16 =
          _mm256_inserti128_si256(_mm256_castsi128_si256(p_16_0), p_16_1, 1);
      if (conv_params->use_dist_wtd_comp_avg) {
        const __m256i p_16_lo = _mm256_unpacklo_epi16(p_16, temp_lo_16);
        const __m256i wt_res_lo = _mm256_madd_epi16(p_16_lo, *wt);
        const __m256i shifted_32 =
            _mm256_srai_epi32(wt_res_lo, DIST_PRECISION_BITS);
        res_lo_16 = _mm256_packus_epi32(shifted_32, shifted_32);
      } else {
        res_lo_16 = _mm256_srai_epi16(_mm256_add_epi16(p_16, temp_lo_16), 1);
      }
      res_lo_16 = _mm256_add_epi16(res_lo_16, *res_sub_const);
      res_lo_16 = _mm256_srai_epi16(
          _mm256_add_epi16(res_lo_16, *round_bits_const), round_bits);
      const __m256i res_8_lo = _mm256_packus_epi16(res_lo_16, res_lo_16);
      const __m128i res_8_lo_0 = _mm256_castsi256_si128(res_8_lo);
      const __m128i res_8_lo_1 = _mm256_extracti128_si256(res_8_lo, 1);
      *(int *)dst8_0 = _mm_cvtsi128_si32(res_8_lo_0);
      *(int *)dst8_1 = _mm_cvtsi128_si32(res_8_lo_1);
    } else {
      const __m128i temp_lo_16_0 = _mm256_castsi256_si128(temp_lo_16);
      const __m128i temp_lo_16_1 = _mm256_extracti128_si256(temp_lo_16, 1);
      _mm_storel_epi64(p_0, temp_lo_16_0);
      _mm_storel_epi64(p_1, temp_lo_16_1);
    }
    if (p_width > 4) {
      __m128i *const p4_0 =
          (__m128i *)&conv_params
              ->dst[(i + k + 4) * conv_params->dst_stride + j + 4];
      __m128i *const p4_1 =
          (__m128i *)&conv_params
              ->dst[(i + (k + 1) + 4) * conv_params->dst_stride + j + 4];
      res_hi_1 = _mm256_srai_epi32(_mm256_add_epi32(res_hi_1, *res_add_const),
                                   reduce_bits_vert);
      const __m256i temp_hi_16 = _mm256_packus_epi32(res_hi_1, res_hi_1);
      __m256i res_hi_16;
      if (conv_params->do_average) {
        __m128i *const dst8_4_0 =
            (__m128i *)&pred[(i + k + 4) * p_stride + j + 4];
        __m128i *const dst8_4_1 =
            (__m128i *)&pred[(i + (k + 1) + 4) * p_stride + j + 4];
        const __m128i p4_16_0 = _mm_loadl_epi64(p4_0);
        const __m128i p4_16_1 = _mm_loadl_epi64(p4_1);
        const __m256i p4_16 = _mm256_inserti128_si256(
            _mm256_castsi128_si256(p4_16_0), p4_16_1, 1);
        if (conv_params->use_dist_wtd_comp_avg) {
          const __m256i p_16_hi = _mm256_unpacklo_epi16(p4_16, temp_hi_16);
          const __m256i wt_res_hi = _mm256_madd_epi16(p_16_hi, *wt);
          const __m256i shifted_32 =
              _mm256_srai_epi32(wt_res_hi, DIST_PRECISION_BITS);
          res_hi_16 = _mm256_packus_epi32(shifted_32, shifted_32);
        } else {
          res_hi_16 = _mm256_srai_epi16(_mm256_add_epi16(p4_16, temp_hi_16), 1);
        }
        res_hi_16 = _mm256_add_epi16(res_hi_16, *res_sub_const);
        res_hi_16 = _mm256_srai_epi16(
            _mm256_add_epi16(res_hi_16, *round_bits_const), round_bits);
        __m256i res_8_hi = _mm256_packus_epi16(res_hi_16, res_hi_16);
        const __m128i res_8_hi_0 = _mm256_castsi256_si128(res_8_hi);
        const __m128i res_8_hi_1 = _mm256_extracti128_si256(res_8_hi, 1);
        *(int *)dst8_4_0 = _mm_cvtsi128_si32(res_8_hi_0);
        *(int *)dst8_4_1 = _mm_cvtsi128_si32(res_8_hi_1);
      } else {
        const __m128i temp_hi_16_0 = _mm256_castsi256_si128(temp_hi_16);
        const __m128i temp_hi_16_1 = _mm256_extracti128_si256(temp_hi_16, 1);
        _mm_storel_epi64(p4_0, temp_hi_16_0);
        _mm_storel_epi64(p4_1, temp_hi_16_1);
      }
    }
  } else {
    const __m256i res_lo_round = _mm256_srai_epi32(
        _mm256_add_epi32(res_lo_1, *res_add_const), reduce_bits_vert);
    const __m256i res_hi_round = _mm256_srai_epi32(
        _mm256_add_epi32(res_hi_1, *res_add_const), reduce_bits_vert);

    const __m256i res_16bit = _mm256_packs_epi32(res_lo_round, res_hi_round);
    const __m256i res_8bit = _mm256_packus_epi16(res_16bit, res_16bit);
    const __m128i res_8bit0 = _mm256_castsi256_si128(res_8bit);
    const __m128i res_8bit1 = _mm256_extracti128_si256(res_8bit, 1);

    // Store, blending with 'pred' if needed
    __m128i *const p = (__m128i *)&pred[(i + k + 4) * p_stride + j];
    __m128i *const p1 = (__m128i *)&pred[(i + (k + 1) + 4) * p_stride + j];

    if (p_width == 4) {
      *(int *)p = _mm_cvtsi128_si32(res_8bit0);
      *(int *)p1 = _mm_cvtsi128_si32(res_8bit1);
    } else {
      _mm_storel_epi64(p, res_8bit0);
      _mm_storel_epi64(p1, res_8bit1);
    }
  }
}

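// Vertical pass: consume the horizontally filtered rows in horz_out and
// produce the output rows of the block, two per iteration. The gamma0/delta0
// variants that follow specialize the coefficient setup for the cases where
// gamma and/or delta are zero.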
static INLINE void warp_vertical_filter_avx2(
    uint8_t *pred, __m256i *horz_out, ConvolveParams *conv_params,
    int16_t gamma, int16_t delta, int p_height, int p_stride, int p_width,
    int i, int j, int sy4, const int reduce_bits_vert,
    const __m256i *res_add_const, const int round_bits,
    const __m256i *res_sub_const, const __m256i *round_bits_const,
    const __m256i *wt) {
  int k, row = 0;
  __m256i src[8];
  const __m256i src_0 = horz_out[0];
  const __m256i src_1 =
      _mm256_permute2x128_si256(horz_out[0], horz_out[1], 0x21);
  const __m256i src_2 = horz_out[1];
  const __m256i src_3 =
      _mm256_permute2x128_si256(horz_out[1], horz_out[2], 0x21);
  const __m256i src_4 = horz_out[2];
  const __m256i src_5 =
      _mm256_permute2x128_si256(horz_out[2], horz_out[3], 0x21);

  src[0] = _mm256_unpacklo_epi16(src_0, src_1);
  src[2] = _mm256_unpacklo_epi16(src_2, src_3);
  src[4] = _mm256_unpacklo_epi16(src_4, src_5);

  src[1] = _mm256_unpackhi_epi16(src_0, src_1);
  src[3] = _mm256_unpackhi_epi16(src_2, src_3);
  src[5] = _mm256_unpackhi_epi16(src_4, src_5);

  for (k = -4; k < AOMMIN(4, p_height - i - 4); k += 2) {
    int sy = sy4 + delta * (k + 4);
    __m256i coeffs[8];
    prepare_vertical_filter_coeffs_avx2(gamma, delta, sy, coeffs);
    __m256i res_lo, res_hi;
    filter_src_pixels_vertical_avx2(horz_out, src, coeffs, &res_lo, &res_hi,
                                    row);
    store_vertical_filter_output_avx2(&res_lo, &res_hi, res_add_const, wt,
                                      res_sub_const, round_bits_const, pred,
                                      conv_params, i, j, k, reduce_bits_vert,
                                      p_stride, p_width, round_bits);
    src[0] = src[2];
    src[2] = src[4];
    src[4] = src[6];
    src[1] = src[3];
    src[3] = src[5];
    src[5] = src[7];

    row += 1;
  }
}

static INLINE void warp_vertical_filter_gamma0_avx2(
    uint8_t *pred, __m256i *horz_out, ConvolveParams *conv_params,
    int16_t gamma, int16_t delta, int p_height, int p_stride, int p_width,
    int i, int j, int sy4, const int reduce_bits_vert,
    const __m256i *res_add_const, const int round_bits,
    const __m256i *res_sub_const, const __m256i *round_bits_const,
    const __m256i *wt) {
  (void)gamma;
  int k, row = 0;
  __m256i src[8];
  const __m256i src_0 = horz_out[0];
  const __m256i src_1 =
      _mm256_permute2x128_si256(horz_out[0], horz_out[1], 0x21);
  const __m256i src_2 = horz_out[1];
  const __m256i src_3 =
      _mm256_permute2x128_si256(horz_out[1], horz_out[2], 0x21);
  const __m256i src_4 = horz_out[2];
  const __m256i src_5 =
      _mm256_permute2x128_si256(horz_out[2], horz_out[3], 0x21);

  src[0] = _mm256_unpacklo_epi16(src_0, src_1);
  src[2] = _mm256_unpacklo_epi16(src_2, src_3);
  src[4] = _mm256_unpacklo_epi16(src_4, src_5);

  src[1] = _mm256_unpackhi_epi16(src_0, src_1);
  src[3] = _mm256_unpackhi_epi16(src_2, src_3);
  src[5] = _mm256_unpackhi_epi16(src_4, src_5);

  for (k = -4; k < AOMMIN(4, p_height - i - 4); k += 2) {
    int sy = sy4 + delta * (k + 4);
    __m256i coeffs[8];
    prepare_vertical_filter_coeffs_gamma0_avx2(delta, sy, coeffs);
    __m256i res_lo, res_hi;
    filter_src_pixels_vertical_avx2(horz_out, src, coeffs, &res_lo, &res_hi,
                                    row);
    store_vertical_filter_output_avx2(&res_lo, &res_hi, res_add_const, wt,
                                      res_sub_const, round_bits_const, pred,
                                      conv_params, i, j, k, reduce_bits_vert,
                                      p_stride, p_width, round_bits);
    src[0] = src[2];
    src[2] = src[4];
    src[4] = src[6];
    src[1] = src[3];
    src[3] = src[5];
    src[5] = src[7];
    row += 1;
  }
}

static INLINE void warp_vertical_filter_delta0_avx2(
    uint8_t *pred, __m256i *horz_out, ConvolveParams *conv_params,
    int16_t gamma, int16_t delta, int p_height, int p_stride, int p_width,
    int i, int j, int sy4, const int reduce_bits_vert,
    const __m256i *res_add_const, const int round_bits,
    const __m256i *res_sub_const, const __m256i *round_bits_const,
    const __m256i *wt) {
  (void)delta;
  int k, row = 0;
  __m256i src[8], coeffs[8];
  const __m256i src_0 = horz_out[0];
  const __m256i src_1 =
      _mm256_permute2x128_si256(horz_out[0], horz_out[1], 0x21);
  const __m256i src_2 = horz_out[1];
  const __m256i src_3 =
      _mm256_permute2x128_si256(horz_out[1], horz_out[2], 0x21);
  const __m256i src_4 = horz_out[2];
  const __m256i src_5 =
      _mm256_permute2x128_si256(horz_out[2], horz_out[3], 0x21);

  src[0] = _mm256_unpacklo_epi16(src_0, src_1);
  src[2] = _mm256_unpacklo_epi16(src_2, src_3);
  src[4] = _mm256_unpacklo_epi16(src_4, src_5);

  src[1] = _mm256_unpackhi_epi16(src_0, src_1);
  src[3] = _mm256_unpackhi_epi16(src_2, src_3);
  src[5] = _mm256_unpackhi_epi16(src_4, src_5);

  prepare_vertical_filter_coeffs_delta0_avx2(gamma, sy4, coeffs);

  for (k = -4; k < AOMMIN(4, p_height - i - 4); k += 2) {
    __m256i res_lo, res_hi;
    filter_src_pixels_vertical_avx2(horz_out, src, coeffs, &res_lo, &res_hi,
                                    row);
    store_vertical_filter_output_avx2(&res_lo, &res_hi, res_add_const, wt,
                                      res_sub_const, round_bits_const, pred,
                                      conv_params, i, j, k, reduce_bits_vert,
                                      p_stride, p_width, round_bits);
    src[0] = src[2];
    src[2] = src[4];
    src[4] = src[6];
    src[1] = src[3];
    src[3] = src[5];
    src[5] = src[7];
    row += 1;
  }
}

static INLINE void warp_vertical_filter_gamma0_delta0_avx2(
    uint8_t *pred, __m256i *horz_out, ConvolveParams *conv_params,
    int16_t gamma, int16_t delta, int p_height, int p_stride, int p_width,
    int i, int j, int sy4, const int reduce_bits_vert,
    const __m256i *res_add_const, const int round_bits,
    const __m256i *res_sub_const, const __m256i *round_bits_const,
    const __m256i *wt) {
  (void)gamma;
  int k, row = 0;
  __m256i src[8], coeffs[8];
  const __m256i src_0 = horz_out[0];
  const __m256i src_1 =
      _mm256_permute2x128_si256(horz_out[0], horz_out[1], 0x21);
  const __m256i src_2 = horz_out[1];
  const __m256i src_3 =
      _mm256_permute2x128_si256(horz_out[1], horz_out[2], 0x21);
  const __m256i src_4 = horz_out[2];
  const __m256i src_5 =
      _mm256_permute2x128_si256(horz_out[2], horz_out[3], 0x21);

  src[0] = _mm256_unpacklo_epi16(src_0, src_1);
  src[2] = _mm256_unpacklo_epi16(src_2, src_3);
  src[4] = _mm256_unpacklo_epi16(src_4, src_5);

  src[1] = _mm256_unpackhi_epi16(src_0, src_1);
  src[3] = _mm256_unpackhi_epi16(src_2, src_3);
  src[5] = _mm256_unpackhi_epi16(src_4, src_5);

  prepare_vertical_filter_coeffs_gamma0_avx2(delta, sy4, coeffs);

  for (k = -4; k < AOMMIN(4, p_height - i - 4); k += 2) {
    __m256i res_lo, res_hi;
    filter_src_pixels_vertical_avx2(horz_out, src, coeffs, &res_lo, &res_hi,
                                    row);
    store_vertical_filter_output_avx2(&res_lo, &res_hi, res_add_const, wt,
                                      res_sub_const, round_bits_const, pred,
                                      conv_params, i, j, k, reduce_bits_vert,
                                      p_stride, p_width, round_bits);
    src[0] = src[2];
    src[2] = src[4];
    src[4] = src[6];
    src[1] = src[3];
    src[3] = src[5];
    src[5] = src[7];
    row += 1;
  }
}

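// Dispatch to the specialized vertical (and, below, horizontal) filter
// variants based on which of the warp parameters are zero.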
static INLINE void prepare_warp_vertical_filter_avx2(
    uint8_t *pred, __m256i *horz_out, ConvolveParams *conv_params,
    int16_t gamma, int16_t delta, int p_height, int p_stride, int p_width,
    int i, int j, int sy4, const int reduce_bits_vert,
    const __m256i *res_add_const, const int round_bits,
    const __m256i *res_sub_const, const __m256i *round_bits_const,
    const __m256i *wt) {
  if (gamma == 0 && delta == 0)
    warp_vertical_filter_gamma0_delta0_avx2(
        pred, horz_out, conv_params, gamma, delta, p_height, p_stride, p_width,
        i, j, sy4, reduce_bits_vert, res_add_const, round_bits, res_sub_const,
        round_bits_const, wt);
  else if (gamma == 0 && delta != 0)
    warp_vertical_filter_gamma0_avx2(
        pred, horz_out, conv_params, gamma, delta, p_height, p_stride, p_width,
        i, j, sy4, reduce_bits_vert, res_add_const, round_bits, res_sub_const,
        round_bits_const, wt);
  else if (gamma != 0 && delta == 0)
    warp_vertical_filter_delta0_avx2(
        pred, horz_out, conv_params, gamma, delta, p_height, p_stride, p_width,
        i, j, sy4, reduce_bits_vert, res_add_const, round_bits, res_sub_const,
        round_bits_const, wt);
  else
    warp_vertical_filter_avx2(pred, horz_out, conv_params, gamma, delta,
                              p_height, p_stride, p_width, i, j, sy4,
                              reduce_bits_vert, res_add_const, round_bits,
                              res_sub_const, round_bits_const, wt);
}

static INLINE void prepare_warp_horizontal_filter_avx2(
    const uint8_t *ref, __m256i *horz_out, int stride, int32_t ix4, int32_t iy4,
    int32_t sx4, int alpha, int beta, int p_height, int height, int i,
    const __m256i *round_const, const __m128i *shift,
    const __m256i *shuffle_src) {
  if (alpha == 0 && beta == 0)
    warp_horizontal_filter_alpha0_beta0_avx2(
        ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i,
        round_const, shift, shuffle_src);
  else if (alpha == 0 && beta != 0)
    warp_horizontal_filter_alpha0_avx2(ref, horz_out, stride, ix4, iy4, sx4,
                                       alpha, beta, p_height, height, i,
                                       round_const, shift, shuffle_src);
  else if (alpha != 0 && beta == 0)
    warp_horizontal_filter_beta0_avx2(ref, horz_out, stride, ix4, iy4, sx4,
                                      alpha, beta, p_height, height, i,
                                      round_const, shift, shuffle_src);
  else
    warp_horizontal_filter_avx2(ref, horz_out, stride, ix4, iy4, sx4, alpha,
                                beta, p_height, height, i, round_const, shift,
                                shuffle_src);
}

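// Sum the per-pixel error between 'dst' and 'ref' over a p_width x p_height
// region, 16 pixels at a time, by looking the biased differences up in
// error_measure_lut with 32-bit gathers. Columns and rows that are not a
// multiple of 16 / 4 are accumulated with the scalar error_measure().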
av1_calc_frame_error_avx2(const uint8_t * const ref,int ref_stride,const uint8_t * const dst,int p_width,int p_height,int dst_stride)1025 int64_t av1_calc_frame_error_avx2(const uint8_t *const ref, int ref_stride,
1026                                   const uint8_t *const dst, int p_width,
1027                                   int p_height, int dst_stride) {
1028   int64_t sum_error = 0;
1029   int i, j;
1030   __m256i row_error, col_error;
1031   __m256i zero = _mm256_setzero_si256();
1032   __m256i dup_255 = _mm256_set1_epi16(255);
1033   col_error = zero;
1034 
1035   for (i = 0; i < (p_height / 4); i++) {
1036     row_error = _mm256_setzero_si256();
1037     for (j = 0; j < (p_width / 16); j++) {
1038       __m256i ref_1_16 = _mm256_cvtepu8_epi16(_mm_load_si128(
1039           (__m128i *)(ref + (j * 16) + (((i * 4) + 0) * ref_stride))));
1040       __m256i dst_1_16 = _mm256_cvtepu8_epi16(_mm_load_si128(
1041           (__m128i *)(dst + (j * 16) + (((i * 4) + 0) * dst_stride))));
1042       __m256i ref_2_16 = _mm256_cvtepu8_epi16(_mm_load_si128(
1043           (__m128i *)(ref + (j * 16) + (((i * 4) + 1) * ref_stride))));
1044       __m256i dst_2_16 = _mm256_cvtepu8_epi16(_mm_load_si128(
1045           (__m128i *)(dst + (j * 16) + (((i * 4) + 1) * dst_stride))));
1046       __m256i ref_3_16 = _mm256_cvtepu8_epi16(_mm_load_si128(
1047           (__m128i *)(ref + (j * 16) + (((i * 4) + 2) * ref_stride))));
1048       __m256i dst_3_16 = _mm256_cvtepu8_epi16(_mm_load_si128(
1049           (__m128i *)(dst + (j * 16) + (((i * 4) + 2) * dst_stride))));
1050       __m256i ref_4_16 = _mm256_cvtepu8_epi16(_mm_load_si128(
1051           (__m128i *)(ref + (j * 16) + (((i * 4) + 3) * ref_stride))));
1052       __m256i dst_4_16 = _mm256_cvtepu8_epi16(_mm_load_si128(
1053           (__m128i *)(dst + (j * 16) + (((i * 4) + 3) * dst_stride))));
1054 
1055       __m256i diff_1 =
1056           _mm256_add_epi16(_mm256_sub_epi16(dst_1_16, ref_1_16), dup_255);
1057       __m256i diff_2 =
1058           _mm256_add_epi16(_mm256_sub_epi16(dst_2_16, ref_2_16), dup_255);
1059       __m256i diff_3 =
1060           _mm256_add_epi16(_mm256_sub_epi16(dst_3_16, ref_3_16), dup_255);
1061       __m256i diff_4 =
1062           _mm256_add_epi16(_mm256_sub_epi16(dst_4_16, ref_4_16), dup_255);
1063 
1064       __m256i diff_1_lo = _mm256_unpacklo_epi16(diff_1, zero);
1065       __m256i diff_1_hi = _mm256_unpackhi_epi16(diff_1, zero);
1066       __m256i diff_2_lo = _mm256_unpacklo_epi16(diff_2, zero);
1067       __m256i diff_2_hi = _mm256_unpackhi_epi16(diff_2, zero);
1068       __m256i diff_3_lo = _mm256_unpacklo_epi16(diff_3, zero);
1069       __m256i diff_3_hi = _mm256_unpackhi_epi16(diff_3, zero);
1070       __m256i diff_4_lo = _mm256_unpacklo_epi16(diff_4, zero);
1071       __m256i diff_4_hi = _mm256_unpackhi_epi16(diff_4, zero);
1072 
      __m256i error_1_lo =
          _mm256_i32gather_epi32(error_measure_lut, diff_1_lo, 4);
      __m256i error_1_hi =
          _mm256_i32gather_epi32(error_measure_lut, diff_1_hi, 4);
      __m256i error_2_lo =
          _mm256_i32gather_epi32(error_measure_lut, diff_2_lo, 4);
      __m256i error_2_hi =
          _mm256_i32gather_epi32(error_measure_lut, diff_2_hi, 4);
      __m256i error_3_lo =
          _mm256_i32gather_epi32(error_measure_lut, diff_3_lo, 4);
      __m256i error_3_hi =
          _mm256_i32gather_epi32(error_measure_lut, diff_3_hi, 4);
      __m256i error_4_lo =
          _mm256_i32gather_epi32(error_measure_lut, diff_4_lo, 4);
      __m256i error_4_hi =
          _mm256_i32gather_epi32(error_measure_lut, diff_4_hi, 4);

      __m256i error_1 = _mm256_add_epi32(error_1_lo, error_1_hi);
      __m256i error_2 = _mm256_add_epi32(error_2_lo, error_2_hi);
      __m256i error_3 = _mm256_add_epi32(error_3_lo, error_3_hi);
      __m256i error_4 = _mm256_add_epi32(error_4_lo, error_4_hi);

      __m256i error_1_2 = _mm256_add_epi32(error_1, error_2);
      __m256i error_3_4 = _mm256_add_epi32(error_3, error_4);

      __m256i error_1_2_3_4 = _mm256_add_epi32(error_1_2, error_3_4);
      row_error = _mm256_add_epi32(row_error, error_1_2_3_4);
    }
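    // Widen the 32-bit sums of this 4-row strip to 64 bits before adding them
    // to the running column total, so the accumulator cannot overflow.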
    __m256i col_error_lo = _mm256_unpacklo_epi32(row_error, zero);
    __m256i col_error_hi = _mm256_unpackhi_epi32(row_error, zero);
    __m256i col_error_temp = _mm256_add_epi64(col_error_lo, col_error_hi);
    col_error = _mm256_add_epi64(col_error, col_error_temp);
    // Error summation for the remaining width, which is not a multiple of 16
    if (p_width & 0xf) {
      for (int k = 0; k < 4; ++k) {
        for (int l = j * 16; l < p_width; ++l) {
          sum_error +=
              (int64_t)error_measure(dst[l + ((i * 4) + k) * dst_stride] -
                                     ref[l + ((i * 4) + k) * ref_stride]);
        }
      }
    }
  }
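  // Reduce the four 64-bit lanes of col_error to a scalar and combine with
  // the remainder-column sums accumulated above.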
  __m128i sum_error_q_0 = _mm256_castsi256_si128(col_error);
  __m128i sum_error_q_1 = _mm256_extracti128_si256(col_error, 1);
  sum_error_q_0 = _mm_add_epi64(sum_error_q_0, sum_error_q_1);
  int64_t sum_error_d_0, sum_error_d_1;
  xx_storel_64(&sum_error_d_0, sum_error_q_0);
  xx_storel_64(&sum_error_d_1, _mm_srli_si128(sum_error_q_0, 8));
  sum_error = (sum_error + sum_error_d_0 + sum_error_d_1);
  // Error summation for the remaining height, which is not a multiple of 4
  if (p_height & 0x3) {
    for (int k = i * 4; k < p_height; ++k) {
      for (int l = 0; l < p_width; ++l) {
        sum_error += (int64_t)error_measure(dst[l + k * dst_stride] -
                                            ref[l + k * ref_stride]);
      }
    }
  }
  return sum_error;
}

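// 8-bit AVX2 warp filter. Each 8x8 output block is filtered in two passes: a
// horizontal pass that packs up to fifteen intermediate rows, two per 256-bit
// register, into horz_out (with special handling when the source window is
// fully or partially outside the frame), followed by a vertical pass driven
// by gamma/delta. Rounding follows conv_params (compound vs. single
// prediction).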
void av1_warp_affine_avx2(const int32_t *mat, const uint8_t *ref, int width,
                          int height, int stride, uint8_t *pred, int p_col,
                          int p_row, int p_width, int p_height, int p_stride,
                          int subsampling_x, int subsampling_y,
                          ConvolveParams *conv_params, int16_t alpha,
                          int16_t beta, int16_t gamma, int16_t delta) {
  __m256i horz_out[8];
  int i, j, k;
  const int bd = 8;
  const int reduce_bits_horiz = conv_params->round_0;
  const int reduce_bits_vert = conv_params->is_compound
                                   ? conv_params->round_1
                                   : 2 * FILTER_BITS - reduce_bits_horiz;
  const int offset_bits_horiz = bd + FILTER_BITS - 1;
  assert(IMPLIES(conv_params->is_compound, conv_params->dst != NULL));

  const int offset_bits_vert = bd + 2 * FILTER_BITS - reduce_bits_horiz;
  const __m256i reduce_bits_vert_const =
      _mm256_set1_epi32(((1 << reduce_bits_vert) >> 1));
  const __m256i res_add_const = _mm256_set1_epi32(1 << offset_bits_vert);
  const int round_bits =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
  assert(IMPLIES(conv_params->do_average, conv_params->is_compound));

  const __m256i round_const = _mm256_set1_epi16(
      (1 << offset_bits_horiz) + ((1 << reduce_bits_horiz) >> 1));
  const __m128i shift = _mm_cvtsi32_si128(reduce_bits_horiz);

  __m256i res_sub_const, round_bits_const, wt;
  unpack_weights_and_set_round_const_avx2(conv_params, round_bits, offset_bits,
                                          &res_sub_const, &round_bits_const,
                                          &wt);

  __m256i res_add_const_1;
  if (conv_params->is_compound == 1) {
    res_add_const_1 = _mm256_add_epi32(reduce_bits_vert_const, res_add_const);
  } else {
    res_add_const_1 = _mm256_set1_epi32(-(1 << (bd + reduce_bits_vert - 1)) +
                                        ((1 << reduce_bits_vert) >> 1));
  }
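  // const1/const2 fold the filter-index rounding and offset, plus the -4*step
  // recentring, into sx4/sy4 once per block; const3 masks off the low
  // WARP_PARAM_REDUCE_BITS. const4/const5 reproduce the horizontal filter
  // output for a row in which every tap reads the same clamped border pixel.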
  const int32_t const1 = alpha * (-4) + beta * (-4) +
                         (1 << (WARPEDDIFF_PREC_BITS - 1)) +
                         (WARPEDPIXEL_PREC_SHIFTS << WARPEDDIFF_PREC_BITS);
  const int32_t const2 = gamma * (-4) + delta * (-4) +
                         (1 << (WARPEDDIFF_PREC_BITS - 1)) +
                         (WARPEDPIXEL_PREC_SHIFTS << WARPEDDIFF_PREC_BITS);
  const int32_t const3 = ((1 << WARP_PARAM_REDUCE_BITS) - 1);
  const int16_t const4 = (1 << (bd + FILTER_BITS - reduce_bits_horiz - 1));
  const int16_t const5 = (1 << (FILTER_BITS - reduce_bits_horiz));

  __m256i shuffle_src[4];
  shuffle_src[0] = _mm256_load_si256((__m256i *)shuffle_src0);
  shuffle_src[1] = _mm256_load_si256((__m256i *)shuffle_src1);
  shuffle_src[2] = _mm256_load_si256((__m256i *)shuffle_src2);
  shuffle_src[3] = _mm256_load_si256((__m256i *)shuffle_src3);

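  // Process the output in 8x8 blocks: project each block centre (+4, +4)
  // through the affine matrix to obtain the integer source position
  // (ix4, iy4) and the subpel filter offsets (sx4, sy4).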
  for (i = 0; i < p_height; i += 8) {
    for (j = 0; j < p_width; j += 8) {
      const int32_t src_x = (p_col + j + 4) << subsampling_x;
      const int32_t src_y = (p_row + i + 4) << subsampling_y;
      const int64_t dst_x =
          (int64_t)mat[2] * src_x + (int64_t)mat[3] * src_y + (int64_t)mat[0];
      const int64_t dst_y =
          (int64_t)mat[4] * src_x + (int64_t)mat[5] * src_y + (int64_t)mat[1];
      const int64_t x4 = dst_x >> subsampling_x;
      const int64_t y4 = dst_y >> subsampling_y;

      int32_t ix4 = (int32_t)(x4 >> WARPEDMODEL_PREC_BITS);
      int32_t sx4 = x4 & ((1 << WARPEDMODEL_PREC_BITS) - 1);
      int32_t iy4 = (int32_t)(y4 >> WARPEDMODEL_PREC_BITS);
      int32_t sy4 = y4 & ((1 << WARPEDMODEL_PREC_BITS) - 1);

      // Add in all the constant terms, including rounding and offset
      sx4 += const1;
      sy4 += const2;

      sx4 &= ~const3;
      sy4 &= ~const3;

      // Horizontal filter
      // If the block is aligned such that, after clamping, every sample
      // would be taken from the leftmost/rightmost column, then we can
      // skip the expensive horizontal filter.

      if (ix4 <= -7) {
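        // The whole source window lies off the left edge, so after clamping
        // every tap reads the leftmost pixel of its row: build constant rows,
        // packing two rows per 256-bit register.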
        int iy, row = 0;
        for (k = -7; k <= (AOMMIN(8, p_height - i) - 2); k += 2) {
          iy = iy4 + k;
          iy = clamp(iy, 0, height - 1);
          const __m256i temp_0 =
              _mm256_set1_epi16(const4 + ref[iy * stride] * const5);
          iy = iy4 + k + 1;
          iy = clamp(iy, 0, height - 1);
          const __m256i temp_1 =
              _mm256_set1_epi16(const4 + ref[iy * stride] * const5);
          horz_out[row] = _mm256_blend_epi32(temp_0, temp_1, 0xf0);
          row += 1;
        }
        iy = iy4 + k;
        iy = clamp(iy, 0, height - 1);
        horz_out[row] = _mm256_set1_epi16(const4 + ref[iy * stride] * const5);
      } else if (ix4 >= width + 6) {
        int iy, row = 0;
        for (k = -7; k <= (AOMMIN(8, p_height - i) - 2); k += 2) {
          iy = iy4 + k;
          iy = clamp(iy, 0, height - 1);
          const __m256i temp_0 = _mm256_set1_epi16(
              const4 + ref[iy * stride + (width - 1)] * const5);
          iy = iy4 + k + 1;
          iy = clamp(iy, 0, height - 1);
          const __m256i temp_1 = _mm256_set1_epi16(
              const4 + ref[iy * stride + (width - 1)] * const5);
          horz_out[row] = _mm256_blend_epi32(temp_0, temp_1, 0xf0);
          row += 1;
        }
        iy = iy4 + k;
        iy = clamp(iy, 0, height - 1);
        horz_out[row] =
            _mm256_set1_epi16(const4 + ref[iy * stride + (width - 1)] * const5);
      } else if (((ix4 - 7) < 0) || ((ix4 + 9) > width)) {
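        // The source window overlaps the frame edge: load 16 pixels and use
        // the warp_pad_left/warp_pad_right shuffle tables to pad the
        // out-of-range lanes with the frame's edge pixels before filtering.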
        const int out_of_boundary_left = -(ix4 - 6);
        const int out_of_boundary_right = (ix4 + 8) - width;
        int iy, sx, row = 0;
        for (k = -7; k <= (AOMMIN(8, p_height - i) - 2); k += 2) {
          iy = iy4 + k;
          iy = clamp(iy, 0, height - 1);
          __m128i src0 =
              _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7));
          iy = iy4 + k + 1;
          iy = clamp(iy, 0, height - 1);
          __m128i src1 =
              _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7));

          if (out_of_boundary_left >= 0) {
            const __m128i shuffle_reg_left =
                _mm_loadu_si128((__m128i *)warp_pad_left[out_of_boundary_left]);
            src0 = _mm_shuffle_epi8(src0, shuffle_reg_left);
            src1 = _mm_shuffle_epi8(src1, shuffle_reg_left);
          }
          if (out_of_boundary_right >= 0) {
            const __m128i shuffle_reg_right = _mm_loadu_si128(
                (__m128i *)warp_pad_right[out_of_boundary_right]);
            src0 = _mm_shuffle_epi8(src0, shuffle_reg_right);
            src1 = _mm_shuffle_epi8(src1, shuffle_reg_right);
          }
          sx = sx4 + beta * (k + 4);
          const __m256i src_01 =
              _mm256_inserti128_si256(_mm256_castsi128_si256(src0), src1, 0x1);
          horizontal_filter_avx2(src_01, horz_out, sx, alpha, beta, row,
                                 shuffle_src, &round_const, &shift);
          row += 1;
        }
        iy = iy4 + k;
        iy = clamp(iy, 0, height - 1);
        __m128i src = _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7));
        if (out_of_boundary_left >= 0) {
          const __m128i shuffle_reg_left =
              _mm_loadu_si128((__m128i *)warp_pad_left[out_of_boundary_left]);
          src = _mm_shuffle_epi8(src, shuffle_reg_left);
        }
        if (out_of_boundary_right >= 0) {
          const __m128i shuffle_reg_right =
              _mm_loadu_si128((__m128i *)warp_pad_right[out_of_boundary_right]);
          src = _mm_shuffle_epi8(src, shuffle_reg_right);
        }
        sx = sx4 + beta * (k + 4);
        const __m256i src_01 = _mm256_castsi128_si256(src);
        __m256i coeff[4];
        prepare_horizontal_filter_coeff(alpha, sx, coeff);
        filter_src_pixels_avx2(src_01, horz_out, coeff, shuffle_src,
                               &round_const, &shift, row);
      } else {
        prepare_warp_horizontal_filter_avx2(
            ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height, height,
            i, &round_const, &shift, shuffle_src);
      }

      // Vertical filter
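      // Consumes the intermediate rows in horz_out and produces the block's
      // output, writing to pred or handling the compound buffer
      // conv_params->dst as dictated by conv_params.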
      prepare_warp_vertical_filter_avx2(
          pred, horz_out, conv_params, gamma, delta, p_height, p_stride,
          p_width, i, j, sy4, reduce_bits_vert, &res_add_const_1, round_bits,
          &res_sub_const, &round_bits_const, &wt);
    }
  }
}