// Copyright 2019 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/convolve.h"
#include "src/utils/constants.h"
#include "src/utils/cpu.h"

#if LIBGAV1_ENABLE_SSE4_1
#include <smmintrin.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>

#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/dsp/x86/common_sse4.h"
#include "src/utils/common.h"

namespace libgav1 {
namespace dsp {
namespace low_bitdepth {
namespace {

// TODO(slavarnway): Move to common neon/sse4 file.
int GetNumTapsInFilter(const int filter_index) {
  if (filter_index < 2) {
    // Despite their names, these only use 6 taps.
    // kInterpolationFilterEightTap
    // kInterpolationFilterEightTapSmooth
    return 6;
  }

  if (filter_index == 2) {
    // kInterpolationFilterEightTapSharp
    return 8;
  }

  if (filter_index == 3) {
    // kInterpolationFilterBilinear
    return 2;
  }

  assert(filter_index > 3);
  // For small sizes (width/height <= 4) the large filters are replaced with
  // 4-tap options.
  // If the original filter was |kInterpolationFilterEightTap| or
  // |kInterpolationFilterEightTapSharp| then it becomes
  // |kInterpolationFilterSwitchable|.
  // If it was |kInterpolationFilterEightTapSmooth| then it becomes an unnamed
  // 4-tap filter.
  return 4;
}
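
// In summary: |filter_index| 0 and 1 use 6 taps, 2 uses 8 taps, 3 uses 2 taps,
// and 4 and 5 use 4 taps.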

constexpr int kIntermediateStride = kMaxSuperBlockSizeInPixels;
constexpr int kHorizontalOffset = 3;
constexpr int kFilterIndexShift = 6;

// Multiply every entry in |src[]| by the corresponding entry in |taps[]| and
// sum. The filters in |taps[]| are pre-shifted by 1. This prevents the final
// sum from exceeding the range of int16_t.
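// For example, even for the filter with the largest absolute coefficient sum
// (under 256 before the pre-shift, so under 128 after it), the accumulated
// magnitude is at most 255 * 127 = 32385, which fits in int16_t.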
template <int filter_index>
__m128i SumOnePassTaps(const __m128i* const src, const __m128i* const taps) {
  __m128i sum;
  if (filter_index < 2) {
    // 6 taps.
    const __m128i v_madd_21 = _mm_maddubs_epi16(src[0], taps[0]);  // k2k1
    const __m128i v_madd_43 = _mm_maddubs_epi16(src[1], taps[1]);  // k4k3
    const __m128i v_madd_65 = _mm_maddubs_epi16(src[2], taps[2]);  // k6k5
    sum = _mm_add_epi16(v_madd_21, v_madd_43);
    sum = _mm_add_epi16(sum, v_madd_65);
  } else if (filter_index == 2) {
    // 8 taps.
    const __m128i v_madd_10 = _mm_maddubs_epi16(src[0], taps[0]);  // k1k0
    const __m128i v_madd_32 = _mm_maddubs_epi16(src[1], taps[1]);  // k3k2
    const __m128i v_madd_54 = _mm_maddubs_epi16(src[2], taps[2]);  // k5k4
    const __m128i v_madd_76 = _mm_maddubs_epi16(src[3], taps[3]);  // k7k6
    const __m128i v_sum_3210 = _mm_add_epi16(v_madd_10, v_madd_32);
    const __m128i v_sum_7654 = _mm_add_epi16(v_madd_54, v_madd_76);
    sum = _mm_add_epi16(v_sum_7654, v_sum_3210);
  } else if (filter_index == 3) {
    // 2 taps.
    sum = _mm_maddubs_epi16(src[0], taps[0]);  // k4k3
  } else {
    // 4 taps.
    const __m128i v_madd_32 = _mm_maddubs_epi16(src[0], taps[0]);  // k3k2
    const __m128i v_madd_54 = _mm_maddubs_epi16(src[1], taps[1]);  // k5k4
    sum = _mm_add_epi16(v_madd_32, v_madd_54);
  }
  return sum;
}

template <int filter_index>
__m128i SumHorizontalTaps(const uint8_t* const src,
                          const __m128i* const v_tap) {
  __m128i v_src[4];
  const __m128i src_long = LoadUnaligned16(src);
  const __m128i src_long_dup_lo = _mm_unpacklo_epi8(src_long, src_long);
  const __m128i src_long_dup_hi = _mm_unpackhi_epi8(src_long, src_long);
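  // Unpacking |src_long| against itself duplicates every byte:
  // 00 00 01 01 02 02 ... Extracting with _mm_alignr_epi8 at an odd offset
  // then yields overlapping pairs, so each 16-bit lane feeds maddubs with two
  // adjacent source pixels for one tap pair.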

  if (filter_index < 2) {
    // 6 taps.
    v_src[0] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 3);   // _21
    v_src[1] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 7);   // _43
    v_src[2] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 11);  // _65
  } else if (filter_index == 2) {
    // 8 taps.
    v_src[0] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 1);   // _10
    v_src[1] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 5);   // _32
    v_src[2] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 9);   // _54
    v_src[3] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 13);  // _76
  } else if (filter_index == 3) {
    // 2 taps.
    v_src[0] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 7);  // _43
  } else if (filter_index > 3) {
    // 4 taps.
    v_src[0] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 5);  // _32
    v_src[1] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 9);  // _54
  }
  const __m128i sum = SumOnePassTaps<filter_index>(v_src, v_tap);
  return sum;
}

template <int filter_index>
__m128i SimpleHorizontalTaps(const uint8_t* const src,
                             const __m128i* const v_tap) {
  __m128i sum = SumHorizontalTaps<filter_index>(src, v_tap);

  // Normally the Horizontal pass does the downshift in two passes:
  // kInterRoundBitsHorizontal - 1 and then (kFilterBits -
  // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them
  // requires adding the rounding offset from the skipped shift.
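  // Concretely, with kInterRoundBitsHorizontal == 3 (8bpp) and the taps
  // already pre-shifted by 1, the two stages would shift by 2 and then by 4.
  // Folding them into one rounding shift by 6 (kFilterBits - 1) requires
  // re-adding the first stage's rounding offset,
  // 1 << (kInterRoundBitsHorizontal - 2) == 2.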
  constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2);

  sum = _mm_add_epi16(sum, _mm_set1_epi16(first_shift_rounding_bit));
  sum = RightShiftWithRounding_S16(sum, kFilterBits - 1);
  return _mm_packus_epi16(sum, sum);
}

template <int filter_index>
__m128i HorizontalTaps8To16(const uint8_t* const src,
                            const __m128i* const v_tap) {
  const __m128i sum = SumHorizontalTaps<filter_index>(src, v_tap);

  return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1);
}

template <int filter_index>
__m128i SumHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride,
                             const __m128i* const v_tap) {
  const __m128i input0 = LoadLo8(&src[2]);
  const __m128i input1 = LoadLo8(&src[2 + src_stride]);

  if (filter_index == 3) {
    // 03 04 04 05 05 06 06 07 ....
    const __m128i input0_dup =
        _mm_srli_si128(_mm_unpacklo_epi8(input0, input0), 3);
    // 13 14 14 15 15 16 16 17 ....
    const __m128i input1_dup =
        _mm_srli_si128(_mm_unpacklo_epi8(input1, input1), 3);
    const __m128i v_src_43 = _mm_unpacklo_epi64(input0_dup, input1_dup);
    const __m128i v_sum_43 = _mm_maddubs_epi16(v_src_43, v_tap[0]);  // k4k3
    return v_sum_43;
  }

  // 02 03 03 04 04 05 05 06 06 07 ....
  const __m128i input0_dup =
      _mm_srli_si128(_mm_unpacklo_epi8(input0, input0), 1);
  // 12 13 13 14 14 15 15 16 16 17 ....
  const __m128i input1_dup =
      _mm_srli_si128(_mm_unpacklo_epi8(input1, input1), 1);
  // 04 05 05 06 06 07 07 08 ...
  const __m128i input0_dup_54 = _mm_srli_si128(input0_dup, 4);
  // 14 15 15 16 16 17 17 18 ...
  const __m128i input1_dup_54 = _mm_srli_si128(input1_dup, 4);
  const __m128i v_src_32 = _mm_unpacklo_epi64(input0_dup, input1_dup);
  const __m128i v_src_54 = _mm_unpacklo_epi64(input0_dup_54, input1_dup_54);
  const __m128i v_madd_32 = _mm_maddubs_epi16(v_src_32, v_tap[0]);  // k3k2
  const __m128i v_madd_54 = _mm_maddubs_epi16(v_src_54, v_tap[1]);  // k5k4
  const __m128i v_sum_5432 = _mm_add_epi16(v_madd_54, v_madd_32);
  return v_sum_5432;
}

template <int filter_index>
__m128i SimpleHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride,
                                const __m128i* const v_tap) {
  __m128i sum = SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap);

  // Normally the Horizontal pass does the downshift in two passes:
  // kInterRoundBitsHorizontal - 1 and then (kFilterBits -
  // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them
  // requires adding the rounding offset from the skipped shift.
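  // (See SimpleHorizontalTaps above for the worked shift arithmetic.)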
  constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2);

  sum = _mm_add_epi16(sum, _mm_set1_epi16(first_shift_rounding_bit));
  sum = RightShiftWithRounding_S16(sum, kFilterBits - 1);
  return _mm_packus_epi16(sum, sum);
}

template <int filter_index>
__m128i HorizontalTaps8To16_2x2(const uint8_t* src, const ptrdiff_t src_stride,
                                const __m128i* const v_tap) {
  const __m128i sum =
      SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap);

  return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1);
}

template <int num_taps, int step, int filter_index, bool is_2d = false,
          bool is_compound = false>
void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride,
                      void* const dest, const ptrdiff_t pred_stride,
                      const int width, const int height,
                      const __m128i* const v_tap) {
  auto* dest8 = static_cast<uint8_t*>(dest);
  auto* dest16 = static_cast<uint16_t*>(dest);

  // 4 tap filters are never used when width > 4.
  if (num_taps != 4 && width > 4) {
    int y = 0;
    do {
      int x = 0;
      do {
        if (is_2d || is_compound) {
          const __m128i v_sum =
              HorizontalTaps8To16<filter_index>(&src[x], v_tap);
          if (is_2d) {
            StoreAligned16(&dest16[x], v_sum);
          } else {
            StoreUnaligned16(&dest16[x], v_sum);
          }
        } else {
          const __m128i result =
              SimpleHorizontalTaps<filter_index>(&src[x], v_tap);
          StoreLo8(&dest8[x], result);
        }
        x += step;
      } while (x < width);
      src += src_stride;
      dest8 += pred_stride;
      dest16 += pred_stride;
    } while (++y < height);
    return;
  }

  // The horizontal pass only needs to account for |num_taps| of 2 and 4 when
  // |width| <= 4.
  assert(width <= 4);
  assert(num_taps <= 4);
  if (num_taps <= 4) {
    if (width == 4) {
      int y = 0;
      do {
        if (is_2d || is_compound) {
          const __m128i v_sum = HorizontalTaps8To16<filter_index>(src, v_tap);
          StoreLo8(dest16, v_sum);
        } else {
          const __m128i result = SimpleHorizontalTaps<filter_index>(src, v_tap);
          Store4(&dest8[0], result);
        }
        src += src_stride;
        dest8 += pred_stride;
        dest16 += pred_stride;
      } while (++y < height);
      return;
    }

    if (!is_compound) {
      int y = 0;
      do {
        if (is_2d) {
          const __m128i sum =
              HorizontalTaps8To16_2x2<filter_index>(src, src_stride, v_tap);
          Store4(&dest16[0], sum);
          dest16 += pred_stride;
          Store4(&dest16[0], _mm_srli_si128(sum, 8));
          dest16 += pred_stride;
        } else {
          const __m128i sum =
              SimpleHorizontalTaps2x2<filter_index>(src, src_stride, v_tap);
          Store2(dest8, sum);
          dest8 += pred_stride;
          Store2(dest8, _mm_srli_si128(sum, 4));
          dest8 += pred_stride;
        }

        src += src_stride << 1;
        y += 2;
      } while (y < height - 1);

      // The 2d filters have an odd |height| because the horizontal pass
      // generates context for the vertical pass.
      if (is_2d) {
        assert(height % 2 == 1);
        __m128i sum;
        const __m128i input = LoadLo8(&src[2]);
        if (filter_index == 3) {
          // 03 04 04 05 05 06 06 07 ....
          const __m128i v_src_43 =
              _mm_srli_si128(_mm_unpacklo_epi8(input, input), 3);
          sum = _mm_maddubs_epi16(v_src_43, v_tap[0]);  // k4k3
        } else {
          // 02 03 03 04 04 05 05 06 06 07 ....
          const __m128i v_src_32 =
              _mm_srli_si128(_mm_unpacklo_epi8(input, input), 1);
          // 04 05 05 06 06 07 07 08 ...
          const __m128i v_src_54 = _mm_srli_si128(v_src_32, 4);
          const __m128i v_madd_32 =
              _mm_maddubs_epi16(v_src_32, v_tap[0]);  // k3k2
          const __m128i v_madd_54 =
              _mm_maddubs_epi16(v_src_54, v_tap[1]);  // k5k4
          sum = _mm_add_epi16(v_madd_54, v_madd_32);
        }
        sum = RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1);
        Store4(dest16, sum);
      }
    }
  }
}

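// Each |v_tap[i]| holds one adjacent pair of 8-bit taps replicated across the
// register so it can feed maddubs. For 6- and 2-tap filters the zero outer
// coefficients are dropped by shifting the filter right by one byte. For the
// 2D vertical pass each pair is sign-extended to 16 bits for use with madd.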
template <int num_taps, bool is_2d_vertical = false>
LIBGAV1_ALWAYS_INLINE void SetupTaps(const __m128i* const filter,
                                     __m128i* v_tap) {
  if (num_taps == 8) {
    v_tap[0] = _mm_shufflelo_epi16(*filter, 0x0);   // k1k0
    v_tap[1] = _mm_shufflelo_epi16(*filter, 0x55);  // k3k2
    v_tap[2] = _mm_shufflelo_epi16(*filter, 0xaa);  // k5k4
    v_tap[3] = _mm_shufflelo_epi16(*filter, 0xff);  // k7k6
    if (is_2d_vertical) {
      v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]);
      v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]);
      v_tap[2] = _mm_cvtepi8_epi16(v_tap[2]);
      v_tap[3] = _mm_cvtepi8_epi16(v_tap[3]);
    } else {
      v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]);
      v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]);
      v_tap[2] = _mm_unpacklo_epi64(v_tap[2], v_tap[2]);
      v_tap[3] = _mm_unpacklo_epi64(v_tap[3], v_tap[3]);
    }
  } else if (num_taps == 6) {
    const __m128i adjusted_filter = _mm_srli_si128(*filter, 1);
    v_tap[0] = _mm_shufflelo_epi16(adjusted_filter, 0x0);   // k2k1
    v_tap[1] = _mm_shufflelo_epi16(adjusted_filter, 0x55);  // k4k3
    v_tap[2] = _mm_shufflelo_epi16(adjusted_filter, 0xaa);  // k6k5
    if (is_2d_vertical) {
      v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]);
      v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]);
      v_tap[2] = _mm_cvtepi8_epi16(v_tap[2]);
    } else {
      v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]);
      v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]);
      v_tap[2] = _mm_unpacklo_epi64(v_tap[2], v_tap[2]);
    }
  } else if (num_taps == 4) {
    v_tap[0] = _mm_shufflelo_epi16(*filter, 0x55);  // k3k2
    v_tap[1] = _mm_shufflelo_epi16(*filter, 0xaa);  // k5k4
    if (is_2d_vertical) {
      v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]);
      v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]);
    } else {
      v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]);
      v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]);
    }
  } else {  // num_taps == 2
    const __m128i adjusted_filter = _mm_srli_si128(*filter, 1);
    v_tap[0] = _mm_shufflelo_epi16(adjusted_filter, 0x55);  // k4k3
    if (is_2d_vertical) {
      v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]);
    } else {
      v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]);
    }
  }
}

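// The 2D vertical pass filters 16-bit intermediate values with 16-bit taps,
// so products are accumulated at 32 bits via madd before being rounded back
// down. The -1 in each shift again accounts for the pre-shifted taps.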
template <int num_taps, bool is_compound>
__m128i SimpleSum2DVerticalTaps(const __m128i* const src,
                                const __m128i* const taps) {
  __m128i sum_lo = _mm_madd_epi16(_mm_unpacklo_epi16(src[0], src[1]), taps[0]);
  __m128i sum_hi = _mm_madd_epi16(_mm_unpackhi_epi16(src[0], src[1]), taps[0]);
  if (num_taps >= 4) {
    __m128i madd_lo =
        _mm_madd_epi16(_mm_unpacklo_epi16(src[2], src[3]), taps[1]);
    __m128i madd_hi =
        _mm_madd_epi16(_mm_unpackhi_epi16(src[2], src[3]), taps[1]);
    sum_lo = _mm_add_epi32(sum_lo, madd_lo);
    sum_hi = _mm_add_epi32(sum_hi, madd_hi);
    if (num_taps >= 6) {
      madd_lo = _mm_madd_epi16(_mm_unpacklo_epi16(src[4], src[5]), taps[2]);
      madd_hi = _mm_madd_epi16(_mm_unpackhi_epi16(src[4], src[5]), taps[2]);
      sum_lo = _mm_add_epi32(sum_lo, madd_lo);
      sum_hi = _mm_add_epi32(sum_hi, madd_hi);
      if (num_taps == 8) {
        madd_lo = _mm_madd_epi16(_mm_unpacklo_epi16(src[6], src[7]), taps[3]);
        madd_hi = _mm_madd_epi16(_mm_unpackhi_epi16(src[6], src[7]), taps[3]);
        sum_lo = _mm_add_epi32(sum_lo, madd_lo);
        sum_hi = _mm_add_epi32(sum_hi, madd_hi);
      }
    }
  }

  if (is_compound) {
    return _mm_packs_epi32(
        RightShiftWithRounding_S32(sum_lo, kInterRoundBitsCompoundVertical - 1),
        RightShiftWithRounding_S32(sum_hi,
                                   kInterRoundBitsCompoundVertical - 1));
  }

  return _mm_packs_epi32(
      RightShiftWithRounding_S32(sum_lo, kInterRoundBitsVertical - 1),
      RightShiftWithRounding_S32(sum_hi, kInterRoundBitsVertical - 1));
}

template <int num_taps, bool is_compound = false>
void Filter2DVertical(const uint16_t* src, void* const dst,
                      const ptrdiff_t dst_stride, const int width,
                      const int height, const __m128i* const taps) {
  assert(width >= 8);
  constexpr int next_row = num_taps - 1;
  // The Horizontal pass uses |width| as |stride| for the intermediate buffer.
  const ptrdiff_t src_stride = width;

  auto* dst8 = static_cast<uint8_t*>(dst);
  auto* dst16 = static_cast<uint16_t*>(dst);

  int x = 0;
  do {
    __m128i srcs[8];
    const uint16_t* src_x = src + x;
    srcs[0] = LoadAligned16(src_x);
    src_x += src_stride;
    if (num_taps >= 4) {
      srcs[1] = LoadAligned16(src_x);
      src_x += src_stride;
      srcs[2] = LoadAligned16(src_x);
      src_x += src_stride;
      if (num_taps >= 6) {
        srcs[3] = LoadAligned16(src_x);
        src_x += src_stride;
        srcs[4] = LoadAligned16(src_x);
        src_x += src_stride;
        if (num_taps == 8) {
          srcs[5] = LoadAligned16(src_x);
          src_x += src_stride;
          srcs[6] = LoadAligned16(src_x);
          src_x += src_stride;
        }
      }
    }

    int y = 0;
    do {
      srcs[next_row] = LoadAligned16(src_x);
      src_x += src_stride;

      const __m128i sum =
          SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps);
      if (is_compound) {
        StoreUnaligned16(dst16 + x + y * dst_stride, sum);
      } else {
        StoreLo8(dst8 + x + y * dst_stride, _mm_packus_epi16(sum, sum));
      }

      srcs[0] = srcs[1];
      if (num_taps >= 4) {
        srcs[1] = srcs[2];
        srcs[2] = srcs[3];
        if (num_taps >= 6) {
          srcs[3] = srcs[4];
          srcs[4] = srcs[5];
          if (num_taps == 8) {
            srcs[5] = srcs[6];
            srcs[6] = srcs[7];
          }
        }
      }
    } while (++y < height);
    x += 8;
  } while (x < width);
}

// Take advantage of |src_stride| == |width| to process two rows at a time.
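// Each aligned 16-byte load from the |width| == 4 intermediate buffer covers
// two consecutive rows of four 16-bit values.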
template <int num_taps, bool is_compound = false>
void Filter2DVertical4xH(const uint16_t* src, void* const dst,
                         const ptrdiff_t dst_stride, const int height,
                         const __m128i* const taps) {
  auto* dst8 = static_cast<uint8_t*>(dst);
  auto* dst16 = static_cast<uint16_t*>(dst);

  __m128i srcs[9];
  srcs[0] = LoadAligned16(src);
  src += 8;
  if (num_taps >= 4) {
    srcs[2] = LoadAligned16(src);
    src += 8;
    srcs[1] = _mm_unpacklo_epi64(_mm_srli_si128(srcs[0], 8), srcs[2]);
    if (num_taps >= 6) {
      srcs[4] = LoadAligned16(src);
      src += 8;
      srcs[3] = _mm_unpacklo_epi64(_mm_srli_si128(srcs[2], 8), srcs[4]);
      if (num_taps == 8) {
        srcs[6] = LoadAligned16(src);
        src += 8;
        srcs[5] = _mm_unpacklo_epi64(_mm_srli_si128(srcs[4], 8), srcs[6]);
      }
    }
  }

  int y = 0;
  do {
    srcs[num_taps] = LoadAligned16(src);
    src += 8;
    srcs[num_taps - 1] = _mm_unpacklo_epi64(
        _mm_srli_si128(srcs[num_taps - 2], 8), srcs[num_taps]);

    const __m128i sum =
        SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps);
    if (is_compound) {
      StoreUnaligned16(dst16, sum);
      dst16 += 4 << 1;
    } else {
      const __m128i results = _mm_packus_epi16(sum, sum);
      Store4(dst8, results);
      dst8 += dst_stride;
      Store4(dst8, _mm_srli_si128(results, 4));
      dst8 += dst_stride;
    }

    srcs[0] = srcs[2];
    if (num_taps >= 4) {
      srcs[1] = srcs[3];
      srcs[2] = srcs[4];
      if (num_taps >= 6) {
        srcs[3] = srcs[5];
        srcs[4] = srcs[6];
        if (num_taps == 8) {
          srcs[5] = srcs[7];
          srcs[6] = srcs[8];
        }
      }
    }
    y += 2;
  } while (y < height);
}

// Take advantage of |src_stride| == |width| to process four rows at a time.
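// Here each aligned 16-byte load covers four rows of two 16-bit values.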
template <int num_taps>
void Filter2DVertical2xH(const uint16_t* src, void* const dst,
                         const ptrdiff_t dst_stride, const int height,
                         const __m128i* const taps) {
  constexpr int next_row = (num_taps < 6) ? 4 : 8;

  auto* dst8 = static_cast<uint8_t*>(dst);

  __m128i srcs[9];
  srcs[0] = LoadAligned16(src);
  src += 8;
  if (num_taps >= 6) {
    srcs[4] = LoadAligned16(src);
    src += 8;
    srcs[1] = _mm_alignr_epi8(srcs[4], srcs[0], 4);
    if (num_taps == 8) {
      srcs[2] = _mm_alignr_epi8(srcs[4], srcs[0], 8);
      srcs[3] = _mm_alignr_epi8(srcs[4], srcs[0], 12);
    }
  }

  int y = 0;
  do {
    srcs[next_row] = LoadAligned16(src);
    src += 8;
    if (num_taps == 2) {
      srcs[1] = _mm_alignr_epi8(srcs[4], srcs[0], 4);
    } else if (num_taps == 4) {
      srcs[1] = _mm_alignr_epi8(srcs[4], srcs[0], 4);
      srcs[2] = _mm_alignr_epi8(srcs[4], srcs[0], 8);
      srcs[3] = _mm_alignr_epi8(srcs[4], srcs[0], 12);
    } else if (num_taps == 6) {
      srcs[2] = _mm_alignr_epi8(srcs[4], srcs[0], 8);
      srcs[3] = _mm_alignr_epi8(srcs[4], srcs[0], 12);
      srcs[5] = _mm_alignr_epi8(srcs[8], srcs[4], 4);
    } else if (num_taps == 8) {
      srcs[5] = _mm_alignr_epi8(srcs[8], srcs[4], 4);
      srcs[6] = _mm_alignr_epi8(srcs[8], srcs[4], 8);
      srcs[7] = _mm_alignr_epi8(srcs[8], srcs[4], 12);
    }

    const __m128i sum =
        SimpleSum2DVerticalTaps<num_taps, /*is_compound=*/false>(srcs, taps);
    const __m128i results = _mm_packus_epi16(sum, sum);

    Store2(dst8, results);
    dst8 += dst_stride;
    Store2(dst8, _mm_srli_si128(results, 2));
    // When |height| <= 4 the taps are restricted to 2 and 4 tap variants.
    // Therefore we don't need to check this condition when |height| > 4.
    if (num_taps <= 4 && height == 2) return;
    dst8 += dst_stride;
    Store2(dst8, _mm_srli_si128(results, 4));
    dst8 += dst_stride;
    Store2(dst8, _mm_srli_si128(results, 6));
    dst8 += dst_stride;

    srcs[0] = srcs[4];
    if (num_taps == 6) {
      srcs[1] = srcs[5];
      srcs[4] = srcs[8];
    } else if (num_taps == 8) {
      srcs[1] = srcs[5];
      srcs[2] = srcs[6];
      srcs[3] = srcs[7];
      srcs[4] = srcs[8];
    }

    y += 4;
  } while (y < height);
}

template <bool is_2d = false, bool is_compound = false>
LIBGAV1_ALWAYS_INLINE void DoHorizontalPass(
    const uint8_t* const src, const ptrdiff_t src_stride, void* const dst,
    const ptrdiff_t dst_stride, const int width, const int height,
    const int subpixel, const int filter_index) {
  const int filter_id = (subpixel >> 6) & kSubPixelMask;
  assert(filter_id != 0);
  __m128i v_tap[4];
  const __m128i v_horizontal_filter =
      LoadLo8(kHalfSubPixelFilters[filter_index][filter_id]);

  if (filter_index == 2) {  // 8 tap.
    SetupTaps<8>(&v_horizontal_filter, v_tap);
    FilterHorizontal<8, 8, 2, is_2d, is_compound>(
        src, src_stride, dst, dst_stride, width, height, v_tap);
  } else if (filter_index == 1) {  // 6 tap.
    SetupTaps<6>(&v_horizontal_filter, v_tap);
    FilterHorizontal<6, 8, 1, is_2d, is_compound>(
        src, src_stride, dst, dst_stride, width, height, v_tap);
  } else if (filter_index == 0) {  // 6 tap.
    SetupTaps<6>(&v_horizontal_filter, v_tap);
    FilterHorizontal<6, 8, 0, is_2d, is_compound>(
        src, src_stride, dst, dst_stride, width, height, v_tap);
  } else if (filter_index == 4) {  // 4 tap.
    SetupTaps<4>(&v_horizontal_filter, v_tap);
    FilterHorizontal<4, 8, 4, is_2d, is_compound>(
        src, src_stride, dst, dst_stride, width, height, v_tap);
  } else if (filter_index == 5) {  // 4 tap.
    SetupTaps<4>(&v_horizontal_filter, v_tap);
    FilterHorizontal<4, 8, 5, is_2d, is_compound>(
        src, src_stride, dst, dst_stride, width, height, v_tap);
  } else {  // 2 tap.
    SetupTaps<2>(&v_horizontal_filter, v_tap);
    FilterHorizontal<2, 8, 3, is_2d, is_compound>(
        src, src_stride, dst, dst_stride, width, height, v_tap);
  }
}

void Convolve2D_SSE4_1(const void* const reference,
                       const ptrdiff_t reference_stride,
                       const int horizontal_filter_index,
                       const int vertical_filter_index, const int subpixel_x,
                       const int subpixel_y, const int width, const int height,
                       void* prediction, const ptrdiff_t pred_stride) {
  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
  const int vertical_taps = GetNumTapsInFilter(vert_filter_index);

  // The output of the horizontal filter is guaranteed to fit in 16 bits.
  alignas(16) uint16_t
      intermediate_result[kMaxSuperBlockSizeInPixels *
                          (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
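  // The vertical pass needs |vertical_taps| - 1 extra rows of horizontal
  // output as context, hence the taller intermediate buffer.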
  const int intermediate_height = height + vertical_taps - 1;

  const ptrdiff_t src_stride = reference_stride;
  const auto* src = static_cast<const uint8_t*>(reference) -
                    (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset;
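  // |src| now points (vertical_taps / 2 - 1) rows above and kHorizontalOffset
  // columns to the left of the block, the first sample the filters touch.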

  DoHorizontalPass</*is_2d=*/true>(src, src_stride, intermediate_result, width,
                                   width, intermediate_height, subpixel_x,
                                   horiz_filter_index);

  // Vertical filter.
  auto* dest = static_cast<uint8_t*>(prediction);
  const ptrdiff_t dest_stride = pred_stride;
  const int filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
  assert(filter_id != 0);

  __m128i taps[4];
  const __m128i v_filter =
      LoadLo8(kHalfSubPixelFilters[vert_filter_index][filter_id]);

  if (vertical_taps == 8) {
    SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter, taps);
    if (width == 2) {
      Filter2DVertical2xH<8>(intermediate_result, dest, dest_stride, height,
                             taps);
    } else if (width == 4) {
      Filter2DVertical4xH<8>(intermediate_result, dest, dest_stride, height,
                             taps);
    } else {
      Filter2DVertical<8>(intermediate_result, dest, dest_stride, width, height,
                          taps);
    }
  } else if (vertical_taps == 6) {
    SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter, taps);
    if (width == 2) {
      Filter2DVertical2xH<6>(intermediate_result, dest, dest_stride, height,
                             taps);
    } else if (width == 4) {
      Filter2DVertical4xH<6>(intermediate_result, dest, dest_stride, height,
                             taps);
    } else {
      Filter2DVertical<6>(intermediate_result, dest, dest_stride, width, height,
                          taps);
    }
  } else if (vertical_taps == 4) {
    SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter, taps);
    if (width == 2) {
      Filter2DVertical2xH<4>(intermediate_result, dest, dest_stride, height,
                             taps);
    } else if (width == 4) {
      Filter2DVertical4xH<4>(intermediate_result, dest, dest_stride, height,
                             taps);
    } else {
      Filter2DVertical<4>(intermediate_result, dest, dest_stride, width, height,
                          taps);
    }
  } else {  // |vertical_taps| == 2
    SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter, taps);
    if (width == 2) {
      Filter2DVertical2xH<2>(intermediate_result, dest, dest_stride, height,
                             taps);
    } else if (width == 4) {
      Filter2DVertical4xH<2>(intermediate_result, dest, dest_stride, height,
                             taps);
    } else {
      Filter2DVertical<2>(intermediate_result, dest, dest_stride, width, height,
                          taps);
    }
  }
}

// The 1D compound shift is always |kInterRoundBitsHorizontal|, even for 1D
// Vertical calculations.
__m128i Compound1DShift(const __m128i sum) {
  return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1);
}

template <int filter_index>
__m128i SumVerticalTaps(const __m128i* const srcs, const __m128i* const v_tap) {
  __m128i v_src[4];

  if (filter_index < 2) {
    // 6 taps.
    v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]);
    v_src[1] = _mm_unpacklo_epi8(srcs[2], srcs[3]);
    v_src[2] = _mm_unpacklo_epi8(srcs[4], srcs[5]);
  } else if (filter_index == 2) {
    // 8 taps.
    v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]);
    v_src[1] = _mm_unpacklo_epi8(srcs[2], srcs[3]);
    v_src[2] = _mm_unpacklo_epi8(srcs[4], srcs[5]);
    v_src[3] = _mm_unpacklo_epi8(srcs[6], srcs[7]);
  } else if (filter_index == 3) {
    // 2 taps.
    v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]);
  } else if (filter_index > 3) {
    // 4 taps.
    v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]);
    v_src[1] = _mm_unpacklo_epi8(srcs[2], srcs[3]);
  }
  const __m128i sum = SumOnePassTaps<filter_index>(v_src, v_tap);
  return sum;
}

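// Sliding-window vertical filters: each iteration loads the next source row,
// computes the filtered output, then shifts the window of rows down before
// the next iteration.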
template <int filter_index, bool is_compound = false>
void FilterVertical(const uint8_t* src, const ptrdiff_t src_stride,
                    void* const dst, const ptrdiff_t dst_stride,
                    const int width, const int height,
                    const __m128i* const v_tap) {
  const int num_taps = GetNumTapsInFilter(filter_index);
  const int next_row = num_taps - 1;
  auto* dst8 = static_cast<uint8_t*>(dst);
  auto* dst16 = static_cast<uint16_t*>(dst);
  assert(width >= 8);

  int x = 0;
  do {
    const uint8_t* src_x = src + x;
    __m128i srcs[8];
    srcs[0] = LoadLo8(src_x);
    src_x += src_stride;
    if (num_taps >= 4) {
      srcs[1] = LoadLo8(src_x);
      src_x += src_stride;
      srcs[2] = LoadLo8(src_x);
      src_x += src_stride;
      if (num_taps >= 6) {
        srcs[3] = LoadLo8(src_x);
        src_x += src_stride;
        srcs[4] = LoadLo8(src_x);
        src_x += src_stride;
        if (num_taps == 8) {
          srcs[5] = LoadLo8(src_x);
          src_x += src_stride;
          srcs[6] = LoadLo8(src_x);
          src_x += src_stride;
        }
      }
    }

    int y = 0;
    do {
      srcs[next_row] = LoadLo8(src_x);
      src_x += src_stride;

      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
      if (is_compound) {
        const __m128i results = Compound1DShift(sums);
        StoreUnaligned16(dst16 + x + y * dst_stride, results);
      } else {
        const __m128i results =
            RightShiftWithRounding_S16(sums, kFilterBits - 1);
        StoreLo8(dst8 + x + y * dst_stride, _mm_packus_epi16(results, results));
      }

      srcs[0] = srcs[1];
      if (num_taps >= 4) {
        srcs[1] = srcs[2];
        srcs[2] = srcs[3];
        if (num_taps >= 6) {
          srcs[3] = srcs[4];
          srcs[4] = srcs[5];
          if (num_taps == 8) {
            srcs[5] = srcs[6];
            srcs[6] = srcs[7];
          }
        }
      }
    } while (++y < height);
    x += 8;
  } while (x < width);
}

template <int filter_index, bool is_compound = false>
void FilterVertical4xH(const uint8_t* src, const ptrdiff_t src_stride,
                       void* const dst, const ptrdiff_t dst_stride,
                       const int height, const __m128i* const v_tap) {
  const int num_taps = GetNumTapsInFilter(filter_index);
  auto* dst8 = static_cast<uint8_t*>(dst);
  auto* dst16 = static_cast<uint16_t*>(dst);

  __m128i srcs[9];

  if (num_taps == 2) {
    srcs[2] = _mm_setzero_si128();
    // 00 01 02 03
    srcs[0] = Load4(src);
    src += src_stride;

    int y = 0;
    do {
      // 10 11 12 13
      const __m128i a = Load4(src);
      // 00 01 02 03 10 11 12 13
      srcs[0] = _mm_unpacklo_epi32(srcs[0], a);
      src += src_stride;
      // 20 21 22 23
      srcs[2] = Load4(src);
      src += src_stride;
      // 10 11 12 13 20 21 22 23
      srcs[1] = _mm_unpacklo_epi32(a, srcs[2]);

      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
      if (is_compound) {
        const __m128i results = Compound1DShift(sums);
        StoreUnaligned16(dst16, results);
        dst16 += 4 << 1;
      } else {
        const __m128i results_16 =
            RightShiftWithRounding_S16(sums, kFilterBits - 1);
        const __m128i results = _mm_packus_epi16(results_16, results_16);
        Store4(dst8, results);
        dst8 += dst_stride;
        Store4(dst8, _mm_srli_si128(results, 4));
        dst8 += dst_stride;
      }

      srcs[0] = srcs[2];
      y += 2;
    } while (y < height);
  } else if (num_taps == 4) {
    srcs[4] = _mm_setzero_si128();
    // 00 01 02 03
    srcs[0] = Load4(src);
    src += src_stride;
    // 10 11 12 13
    const __m128i a = Load4(src);
    // 00 01 02 03 10 11 12 13
    srcs[0] = _mm_unpacklo_epi32(srcs[0], a);
    src += src_stride;
    // 20 21 22 23
    srcs[2] = Load4(src);
    src += src_stride;
    // 10 11 12 13 20 21 22 23
    srcs[1] = _mm_unpacklo_epi32(a, srcs[2]);

    int y = 0;
    do {
      // 30 31 32 33
      const __m128i b = Load4(src);
      // 20 21 22 23 30 31 32 33
      srcs[2] = _mm_unpacklo_epi32(srcs[2], b);
      src += src_stride;
      // 40 41 42 43
      srcs[4] = Load4(src);
      src += src_stride;
      // 30 31 32 33 40 41 42 43
      srcs[3] = _mm_unpacklo_epi32(b, srcs[4]);

      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
      if (is_compound) {
        const __m128i results = Compound1DShift(sums);
        StoreUnaligned16(dst16, results);
        dst16 += 4 << 1;
      } else {
        const __m128i results_16 =
            RightShiftWithRounding_S16(sums, kFilterBits - 1);
        const __m128i results = _mm_packus_epi16(results_16, results_16);
        Store4(dst8, results);
        dst8 += dst_stride;
        Store4(dst8, _mm_srli_si128(results, 4));
        dst8 += dst_stride;
      }

      srcs[0] = srcs[2];
      srcs[1] = srcs[3];
      srcs[2] = srcs[4];
      y += 2;
    } while (y < height);
  } else if (num_taps == 6) {
    srcs[6] = _mm_setzero_si128();
    // 00 01 02 03
    srcs[0] = Load4(src);
    src += src_stride;
    // 10 11 12 13
    const __m128i a = Load4(src);
    // 00 01 02 03 10 11 12 13
    srcs[0] = _mm_unpacklo_epi32(srcs[0], a);
    src += src_stride;
    // 20 21 22 23
    srcs[2] = Load4(src);
    src += src_stride;
    // 10 11 12 13 20 21 22 23
    srcs[1] = _mm_unpacklo_epi32(a, srcs[2]);
    // 30 31 32 33
    const __m128i b = Load4(src);
    // 20 21 22 23 30 31 32 33
    srcs[2] = _mm_unpacklo_epi32(srcs[2], b);
    src += src_stride;
    // 40 41 42 43
    srcs[4] = Load4(src);
    src += src_stride;
    // 30 31 32 33 40 41 42 43
    srcs[3] = _mm_unpacklo_epi32(b, srcs[4]);

    int y = 0;
    do {
      // 50 51 52 53
      const __m128i c = Load4(src);
      // 40 41 42 43 50 51 52 53
      srcs[4] = _mm_unpacklo_epi32(srcs[4], c);
      src += src_stride;
      // 60 61 62 63
      srcs[6] = Load4(src);
      src += src_stride;
      // 50 51 52 53 60 61 62 63
      srcs[5] = _mm_unpacklo_epi32(c, srcs[6]);

      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
      if (is_compound) {
        const __m128i results = Compound1DShift(sums);
        StoreUnaligned16(dst16, results);
        dst16 += 4 << 1;
      } else {
        const __m128i results_16 =
            RightShiftWithRounding_S16(sums, kFilterBits - 1);
        const __m128i results = _mm_packus_epi16(results_16, results_16);
        Store4(dst8, results);
        dst8 += dst_stride;
        Store4(dst8, _mm_srli_si128(results, 4));
        dst8 += dst_stride;
      }

      srcs[0] = srcs[2];
      srcs[1] = srcs[3];
      srcs[2] = srcs[4];
      srcs[3] = srcs[5];
      srcs[4] = srcs[6];
      y += 2;
    } while (y < height);
  } else if (num_taps == 8) {
    srcs[8] = _mm_setzero_si128();
    // 00 01 02 03
    srcs[0] = Load4(src);
    src += src_stride;
    // 10 11 12 13
    const __m128i a = Load4(src);
    // 00 01 02 03 10 11 12 13
    srcs[0] = _mm_unpacklo_epi32(srcs[0], a);
    src += src_stride;
    // 20 21 22 23
    srcs[2] = Load4(src);
    src += src_stride;
    // 10 11 12 13 20 21 22 23
    srcs[1] = _mm_unpacklo_epi32(a, srcs[2]);
    // 30 31 32 33
    const __m128i b = Load4(src);
    // 20 21 22 23 30 31 32 33
    srcs[2] = _mm_unpacklo_epi32(srcs[2], b);
    src += src_stride;
    // 40 41 42 43
    srcs[4] = Load4(src);
    src += src_stride;
    // 30 31 32 33 40 41 42 43
    srcs[3] = _mm_unpacklo_epi32(b, srcs[4]);
    // 50 51 52 53
    const __m128i c = Load4(src);
    // 40 41 42 43 50 51 52 53
    srcs[4] = _mm_unpacklo_epi32(srcs[4], c);
    src += src_stride;
    // 60 61 62 63
    srcs[6] = Load4(src);
    src += src_stride;
    // 50 51 52 53 60 61 62 63
    srcs[5] = _mm_unpacklo_epi32(c, srcs[6]);

    int y = 0;
    do {
      // 70 71 72 73
      const __m128i d = Load4(src);
      // 60 61 62 63 70 71 72 73
      srcs[6] = _mm_unpacklo_epi32(srcs[6], d);
      src += src_stride;
      // 80 81 82 83
      srcs[8] = Load4(src);
      src += src_stride;
      // 70 71 72 73 80 81 82 83
      srcs[7] = _mm_unpacklo_epi32(d, srcs[8]);

      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
      if (is_compound) {
        const __m128i results = Compound1DShift(sums);
        StoreUnaligned16(dst16, results);
        dst16 += 4 << 1;
      } else {
        const __m128i results_16 =
            RightShiftWithRounding_S16(sums, kFilterBits - 1);
        const __m128i results = _mm_packus_epi16(results_16, results_16);
        Store4(dst8, results);
        dst8 += dst_stride;
        Store4(dst8, _mm_srli_si128(results, 4));
        dst8 += dst_stride;
      }

      srcs[0] = srcs[2];
      srcs[1] = srcs[3];
      srcs[2] = srcs[4];
      srcs[3] = srcs[5];
      srcs[4] = srcs[6];
      srcs[5] = srcs[7];
      srcs[6] = srcs[8];
      y += 2;
    } while (y < height);
  }
}

template <int filter_index, bool negative_outside_taps = false>
void FilterVertical2xH(const uint8_t* src, const ptrdiff_t src_stride,
                       void* const dst, const ptrdiff_t dst_stride,
                       const int height, const __m128i* const v_tap) {
  const int num_taps = GetNumTapsInFilter(filter_index);
  auto* dst8 = static_cast<uint8_t*>(dst);

  __m128i srcs[9];

  if (num_taps == 2) {
    srcs[2] = _mm_setzero_si128();
    // 00 01
    srcs[0] = Load2(src);
    src += src_stride;

    int y = 0;
    do {
      // 00 01 10 11
      srcs[0] = Load2<1>(src, srcs[0]);
      src += src_stride;
      // 00 01 10 11 20 21
      srcs[0] = Load2<2>(src, srcs[0]);
      src += src_stride;
      // 00 01 10 11 20 21 30 31
      srcs[0] = Load2<3>(src, srcs[0]);
      src += src_stride;
      // 40 41
      srcs[2] = Load2<0>(src, srcs[2]);
      src += src_stride;
      // 00 01 10 11 20 21 30 31 40 41
      const __m128i srcs_0_2 = _mm_unpacklo_epi64(srcs[0], srcs[2]);
      // 10 11 20 21 30 31 40 41
      srcs[1] = _mm_srli_si128(srcs_0_2, 2);
      // This uses srcs[0]..srcs[1].
      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
      const __m128i results_16 =
          RightShiftWithRounding_S16(sums, kFilterBits - 1);
      const __m128i results = _mm_packus_epi16(results_16, results_16);

      Store2(dst8, results);
      dst8 += dst_stride;
      Store2(dst8, _mm_srli_si128(results, 2));
      if (height == 2) return;
      dst8 += dst_stride;
      Store2(dst8, _mm_srli_si128(results, 4));
      dst8 += dst_stride;
      Store2(dst8, _mm_srli_si128(results, 6));
      dst8 += dst_stride;

      srcs[0] = srcs[2];
      y += 4;
    } while (y < height);
  } else if (num_taps == 4) {
    srcs[4] = _mm_setzero_si128();

    // 00 01
    srcs[0] = Load2(src);
    src += src_stride;
    // 00 01 10 11
    srcs[0] = Load2<1>(src, srcs[0]);
    src += src_stride;
    // 00 01 10 11 20 21
    srcs[0] = Load2<2>(src, srcs[0]);
    src += src_stride;

    int y = 0;
    do {
      // 00 01 10 11 20 21 30 31
      srcs[0] = Load2<3>(src, srcs[0]);
      src += src_stride;
      // 40 41
      srcs[4] = Load2<0>(src, srcs[4]);
      src += src_stride;
      // 40 41 50 51
      srcs[4] = Load2<1>(src, srcs[4]);
      src += src_stride;
      // 40 41 50 51 60 61
      srcs[4] = Load2<2>(src, srcs[4]);
      src += src_stride;
      // 00 01 10 11 20 21 30 31 40 41 50 51 60 61
      const __m128i srcs_0_4 = _mm_unpacklo_epi64(srcs[0], srcs[4]);
      // 10 11 20 21 30 31 40 41
      srcs[1] = _mm_srli_si128(srcs_0_4, 2);
      // 20 21 30 31 40 41 50 51
      srcs[2] = _mm_srli_si128(srcs_0_4, 4);
      // 30 31 40 41 50 51 60 61
      srcs[3] = _mm_srli_si128(srcs_0_4, 6);

      // This uses srcs[0]..srcs[3].
      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
      const __m128i results_16 =
          RightShiftWithRounding_S16(sums, kFilterBits - 1);
      const __m128i results = _mm_packus_epi16(results_16, results_16);

      Store2(dst8, results);
      dst8 += dst_stride;
      Store2(dst8, _mm_srli_si128(results, 2));
      if (height == 2) return;
      dst8 += dst_stride;
      Store2(dst8, _mm_srli_si128(results, 4));
      dst8 += dst_stride;
      Store2(dst8, _mm_srli_si128(results, 6));
      dst8 += dst_stride;

      srcs[0] = srcs[4];
      y += 4;
    } while (y < height);
  } else if (num_taps == 6) {
    // During the vertical pass the number of taps is restricted when
    // |height| <= 4.
    assert(height > 4);
    srcs[8] = _mm_setzero_si128();

    // 00 01
    srcs[0] = Load2(src);
    src += src_stride;
    // 00 01 10 11
    srcs[0] = Load2<1>(src, srcs[0]);
    src += src_stride;
    // 00 01 10 11 20 21
    srcs[0] = Load2<2>(src, srcs[0]);
    src += src_stride;
    // 00 01 10 11 20 21 30 31
    srcs[0] = Load2<3>(src, srcs[0]);
    src += src_stride;
    // 40 41
    srcs[4] = Load2(src);
    src += src_stride;
    // 00 01 10 11 20 21 30 31 40 41 50 51 60 61
    const __m128i srcs_0_4x = _mm_unpacklo_epi64(srcs[0], srcs[4]);
    // 10 11 20 21 30 31 40 41
    srcs[1] = _mm_srli_si128(srcs_0_4x, 2);

    int y = 0;
    do {
      // 40 41 50 51
      srcs[4] = Load2<1>(src, srcs[4]);
      src += src_stride;
      // 40 41 50 51 60 61
      srcs[4] = Load2<2>(src, srcs[4]);
      src += src_stride;
      // 40 41 50 51 60 61 70 71
      srcs[4] = Load2<3>(src, srcs[4]);
      src += src_stride;
      // 80 81
      srcs[8] = Load2<0>(src, srcs[8]);
      src += src_stride;
      // 00 01 10 11 20 21 30 31 40 41 50 51 60 61
      const __m128i srcs_0_4 = _mm_unpacklo_epi64(srcs[0], srcs[4]);
      // 20 21 30 31 40 41 50 51
      srcs[2] = _mm_srli_si128(srcs_0_4, 4);
      // 30 31 40 41 50 51 60 61
      srcs[3] = _mm_srli_si128(srcs_0_4, 6);
      const __m128i srcs_4_8 = _mm_unpacklo_epi64(srcs[4], srcs[8]);
      // 50 51 60 61 70 71 80 81
      srcs[5] = _mm_srli_si128(srcs_4_8, 2);

      // This uses srcs[0]..srcs[5].
      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
      const __m128i results_16 =
          RightShiftWithRounding_S16(sums, kFilterBits - 1);
      const __m128i results = _mm_packus_epi16(results_16, results_16);

      Store2(dst8, results);
      dst8 += dst_stride;
      Store2(dst8, _mm_srli_si128(results, 2));
      dst8 += dst_stride;
      Store2(dst8, _mm_srli_si128(results, 4));
      dst8 += dst_stride;
      Store2(dst8, _mm_srli_si128(results, 6));
      dst8 += dst_stride;

      srcs[0] = srcs[4];
      srcs[1] = srcs[5];
      srcs[4] = srcs[8];
      y += 4;
    } while (y < height);
  } else if (num_taps == 8) {
    // During the vertical pass the number of taps is restricted when
    // |height| <= 4.
    assert(height > 4);
    srcs[8] = _mm_setzero_si128();
    // 00 01
    srcs[0] = Load2(src);
    src += src_stride;
    // 00 01 10 11
    srcs[0] = Load2<1>(src, srcs[0]);
    src += src_stride;
    // 00 01 10 11 20 21
    srcs[0] = Load2<2>(src, srcs[0]);
    src += src_stride;
    // 00 01 10 11 20 21 30 31
    srcs[0] = Load2<3>(src, srcs[0]);
    src += src_stride;
    // 40 41
    srcs[4] = Load2(src);
    src += src_stride;
    // 40 41 50 51
    srcs[4] = Load2<1>(src, srcs[4]);
    src += src_stride;
    // 40 41 50 51 60 61
    srcs[4] = Load2<2>(src, srcs[4]);
    src += src_stride;

    // 00 01 10 11 20 21 30 31 40 41 50 51 60 61
    const __m128i srcs_0_4 = _mm_unpacklo_epi64(srcs[0], srcs[4]);
    // 10 11 20 21 30 31 40 41
    srcs[1] = _mm_srli_si128(srcs_0_4, 2);
    // 20 21 30 31 40 41 50 51
    srcs[2] = _mm_srli_si128(srcs_0_4, 4);
    // 30 31 40 41 50 51 60 61
    srcs[3] = _mm_srli_si128(srcs_0_4, 6);

    int y = 0;
    do {
      // 40 41 50 51 60 61 70 71
      srcs[4] = Load2<3>(src, srcs[4]);
      src += src_stride;
      // 80 81
      srcs[8] = Load2<0>(src, srcs[8]);
      src += src_stride;
      // 80 81 90 91
      srcs[8] = Load2<1>(src, srcs[8]);
      src += src_stride;
      // 80 81 90 91 a0 a1
      srcs[8] = Load2<2>(src, srcs[8]);
      src += src_stride;

      // 40 41 50 51 60 61 70 71 80 81 90 91 a0 a1
      const __m128i srcs_4_8 = _mm_unpacklo_epi64(srcs[4], srcs[8]);
      // 50 51 60 61 70 71 80 81
      srcs[5] = _mm_srli_si128(srcs_4_8, 2);
      // 60 61 70 71 80 81 90 91
      srcs[6] = _mm_srli_si128(srcs_4_8, 4);
      // 70 71 80 81 90 91 a0 a1
      srcs[7] = _mm_srli_si128(srcs_4_8, 6);

      // This uses srcs[0]..srcs[7].
      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
      const __m128i results_16 =
          RightShiftWithRounding_S16(sums, kFilterBits - 1);
      const __m128i results = _mm_packus_epi16(results_16, results_16);

      Store2(dst8, results);
      dst8 += dst_stride;
      Store2(dst8, _mm_srli_si128(results, 2));
      dst8 += dst_stride;
      Store2(dst8, _mm_srli_si128(results, 4));
      dst8 += dst_stride;
      Store2(dst8, _mm_srli_si128(results, 6));
      dst8 += dst_stride;

      srcs[0] = srcs[4];
      srcs[1] = srcs[5];
      srcs[2] = srcs[6];
      srcs[3] = srcs[7];
      srcs[4] = srcs[8];
      y += 4;
    } while (y < height);
  }
}

void ConvolveVertical_SSE4_1(const void* const reference,
                             const ptrdiff_t reference_stride,
                             const int /*horizontal_filter_index*/,
                             const int vertical_filter_index,
                             const int /*subpixel_x*/, const int subpixel_y,
                             const int width, const int height,
                             void* prediction, const ptrdiff_t pred_stride) {
  const int filter_index = GetFilterIndex(vertical_filter_index, height);
  const int vertical_taps = GetNumTapsInFilter(filter_index);
  const ptrdiff_t src_stride = reference_stride;
  const auto* src = static_cast<const uint8_t*>(reference) -
                    (vertical_taps / 2 - 1) * src_stride;
  auto* dest = static_cast<uint8_t*>(prediction);
  const ptrdiff_t dest_stride = pred_stride;
  const int filter_id = (subpixel_y >> 6) & kSubPixelMask;
  assert(filter_id != 0);

  __m128i taps[4];
  const __m128i v_filter =
      LoadLo8(kHalfSubPixelFilters[filter_index][filter_id]);

  if (filter_index < 2) {  // 6 tap.
    SetupTaps<6>(&v_filter, taps);
    if (width == 2) {
      FilterVertical2xH<0>(src, src_stride, dest, dest_stride, height, taps);
    } else if (width == 4) {
      FilterVertical4xH<0>(src, src_stride, dest, dest_stride, height, taps);
    } else {
      FilterVertical<0>(src, src_stride, dest, dest_stride, width, height,
                        taps);
    }
  } else if (filter_index == 2) {  // 8 tap.
    SetupTaps<8>(&v_filter, taps);
    if (width == 2) {
      FilterVertical2xH<2>(src, src_stride, dest, dest_stride, height, taps);
    } else if (width == 4) {
      FilterVertical4xH<2>(src, src_stride, dest, dest_stride, height, taps);
    } else {
      FilterVertical<2>(src, src_stride, dest, dest_stride, width, height,
                        taps);
    }
  } else if (filter_index == 3) {  // 2 tap.
    SetupTaps<2>(&v_filter, taps);
    if (width == 2) {
      FilterVertical2xH<3>(src, src_stride, dest, dest_stride, height, taps);
    } else if (width == 4) {
      FilterVertical4xH<3>(src, src_stride, dest, dest_stride, height, taps);
    } else {
      FilterVertical<3>(src, src_stride, dest, dest_stride, width, height,
                        taps);
    }
  } else if (filter_index == 4) {  // 4 tap.
    SetupTaps<4>(&v_filter, taps);
    if (width == 2) {
      FilterVertical2xH<4>(src, src_stride, dest, dest_stride, height, taps);
    } else if (width == 4) {
      FilterVertical4xH<4>(src, src_stride, dest, dest_stride, height, taps);
    } else {
      FilterVertical<4>(src, src_stride, dest, dest_stride, width, height,
                        taps);
    }
  } else {
    // TODO(slavarnway): Investigate adding |filter_index| == 1 special cases.
    // See convolve_neon.cc
    SetupTaps<4>(&v_filter, taps);

    if (width == 2) {
      FilterVertical2xH<5>(src, src_stride, dest, dest_stride, height, taps);
    } else if (width == 4) {
      FilterVertical4xH<5>(src, src_stride, dest, dest_stride, height, taps);
    } else {
      FilterVertical<5>(src, src_stride, dest, dest_stride, width, height,
                        taps);
    }
  }
}

ConvolveCompoundCopy_SSE4(const void * const reference,const ptrdiff_t reference_stride,const int,const int,const int,const int,const int width,const int height,void * prediction,const ptrdiff_t pred_stride)1427 void ConvolveCompoundCopy_SSE4(
1428     const void* const reference, const ptrdiff_t reference_stride,
1429     const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
1430     const int /*subpixel_x*/, const int /*subpixel_y*/, const int width,
1431     const int height, void* prediction, const ptrdiff_t pred_stride) {
1432   const auto* src = static_cast<const uint8_t*>(reference);
1433   const ptrdiff_t src_stride = reference_stride;
1434   auto* dest = static_cast<uint16_t*>(prediction);
1435   constexpr int kRoundBitsVertical =
1436       kInterRoundBitsVertical - kInterRoundBitsCompoundVertical;
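  // Assuming the 8bpp values kInterRoundBitsVertical == 11 and
  // kInterRoundBitsCompoundVertical == 7, this shift is 4: each source pixel
  // is scaled by 16 to match the precision of the filtered compound paths.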
  if (width >= 16) {
    int y = height;
    do {
      int x = 0;
      do {
        const __m128i v_src = LoadUnaligned16(&src[x]);
        const __m128i v_src_ext_lo = _mm_cvtepu8_epi16(v_src);
        const __m128i v_src_ext_hi =
            _mm_cvtepu8_epi16(_mm_srli_si128(v_src, 8));
        const __m128i v_dest_lo =
            _mm_slli_epi16(v_src_ext_lo, kRoundBitsVertical);
        const __m128i v_dest_hi =
            _mm_slli_epi16(v_src_ext_hi, kRoundBitsVertical);
        // TODO(slavarnway): Investigate using aligned stores.
        StoreUnaligned16(&dest[x], v_dest_lo);
        StoreUnaligned16(&dest[x + 8], v_dest_hi);
        x += 16;
      } while (x < width);
      src += src_stride;
      dest += pred_stride;
    } while (--y != 0);
  } else if (width == 8) {
    int y = height;
    do {
      const __m128i v_src = LoadLo8(&src[0]);
      const __m128i v_src_ext = _mm_cvtepu8_epi16(v_src);
      const __m128i v_dest = _mm_slli_epi16(v_src_ext, kRoundBitsVertical);
      StoreUnaligned16(&dest[0], v_dest);
      src += src_stride;
      dest += pred_stride;
    } while (--y != 0);
  } else { /* width == 4 */
    int y = height;
    do {
      const __m128i v_src0 = Load4(&src[0]);
      const __m128i v_src1 = Load4(&src[src_stride]);
      const __m128i v_src = _mm_unpacklo_epi32(v_src0, v_src1);
      const __m128i v_src_ext = _mm_cvtepu8_epi16(v_src);
      const __m128i v_dest = _mm_slli_epi16(v_src_ext, kRoundBitsVertical);
      StoreLo8(&dest[0], v_dest);
      StoreHi8(&dest[pred_stride], v_dest);
      src += src_stride * 2;
      dest += pred_stride * 2;
      y -= 2;
    } while (y != 0);
  }
}

void ConvolveCompoundVertical_SSE4_1(
    const void* const reference, const ptrdiff_t reference_stride,
    const int /*horizontal_filter_index*/, const int vertical_filter_index,
    const int /*subpixel_x*/, const int subpixel_y, const int width,
    const int height, void* prediction, const ptrdiff_t /*pred_stride*/) {
  const int filter_index = GetFilterIndex(vertical_filter_index, height);
  const int vertical_taps = GetNumTapsInFilter(filter_index);
  const ptrdiff_t src_stride = reference_stride;
  const auto* src = static_cast<const uint8_t*>(reference) -
                    (vertical_taps / 2 - 1) * src_stride;
  auto* dest = static_cast<uint16_t*>(prediction);
  const int filter_id = (subpixel_y >> 6) & kSubPixelMask;
  assert(filter_id != 0);

  __m128i taps[4];
  const __m128i v_filter =
      LoadLo8(kHalfSubPixelFilters[filter_index][filter_id]);

  if (filter_index < 2) {  // 6 tap.
    SetupTaps<6>(&v_filter, taps);
    if (width == 4) {
      FilterVertical4xH<0, /*is_compound=*/true>(src, src_stride, dest, 4,
                                                 height, taps);
    } else {
      FilterVertical<0, /*is_compound=*/true>(src, src_stride, dest, width,
                                              width, height, taps);
    }
  } else if (filter_index == 2) {  // 8 tap.
    SetupTaps<8>(&v_filter, taps);

    if (width == 4) {
      FilterVertical4xH<2, /*is_compound=*/true>(src, src_stride, dest, 4,
                                                 height, taps);
    } else {
      FilterVertical<2, /*is_compound=*/true>(src, src_stride, dest, width,
                                              width, height, taps);
    }
  } else if (filter_index == 3) {  // 2 tap.
    SetupTaps<2>(&v_filter, taps);

    if (width == 4) {
      FilterVertical4xH<3, /*is_compound=*/true>(src, src_stride, dest, 4,
                                                 height, taps);
    } else {
      FilterVertical<3, /*is_compound=*/true>(src, src_stride, dest, width,
                                              width, height, taps);
    }
  } else if (filter_index == 4) {  // 4 tap.
    SetupTaps<4>(&v_filter, taps);

    if (width == 4) {
      FilterVertical4xH<4, /*is_compound=*/true>(src, src_stride, dest, 4,
                                                 height, taps);
    } else {
      FilterVertical<4, /*is_compound=*/true>(src, src_stride, dest, width,
                                              width, height, taps);
    }
  } else {
    SetupTaps<4>(&v_filter, taps);

    if (width == 4) {
      FilterVertical4xH<5, /*is_compound=*/true>(src, src_stride, dest, 4,
                                                 height, taps);
    } else {
      FilterVertical<5, /*is_compound=*/true>(src, src_stride, dest, width,
                                              width, height, taps);
    }
  }
}

void ConvolveHorizontal_SSE4_1(const void* const reference,
                               const ptrdiff_t reference_stride,
                               const int horizontal_filter_index,
                               const int /*vertical_filter_index*/,
                               const int subpixel_x, const int /*subpixel_y*/,
                               const int width, const int height,
                               void* prediction, const ptrdiff_t pred_stride) {
  const int filter_index = GetFilterIndex(horizontal_filter_index, width);
  // Set |src| to the outermost tap.
  const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
  auto* dest = static_cast<uint8_t*>(prediction);

  DoHorizontalPass(src, reference_stride, dest, pred_stride, width, height,
                   subpixel_x, filter_index);
}

void ConvolveCompoundHorizontal_SSE4_1(
    const void* const reference, const ptrdiff_t reference_stride,
    const int horizontal_filter_index, const int /*vertical_filter_index*/,
    const int subpixel_x, const int /*subpixel_y*/, const int width,
    const int height, void* prediction, const ptrdiff_t /*pred_stride*/) {
  const int filter_index = GetFilterIndex(horizontal_filter_index, width);
  const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
  auto* dest = static_cast<uint16_t*>(prediction);

  DoHorizontalPass</*is_2d=*/false, /*is_compound=*/true>(
      src, reference_stride, dest, width, width, height, subpixel_x,
      filter_index);
}

void ConvolveCompound2D_SSE4_1(
    const void* const reference, const ptrdiff_t reference_stride,
    const int horizontal_filter_index, const int vertical_filter_index,
    const int subpixel_x, const int subpixel_y, const int width,
    const int height, void* prediction, const ptrdiff_t /*pred_stride*/) {
  // The output of the horizontal filter, i.e. the intermediate_result, is
  // guaranteed to fit in int16_t.
  alignas(16) uint16_t
      intermediate_result[kMaxSuperBlockSizeInPixels *
                          (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
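  // Assuming kMaxSuperBlockSizeInPixels == 128 and kSubPixelTaps == 8, this
  // is 128 * 135 uint16_t entries: |height| output rows plus the
  // (vertical_taps - 1) extra rows read by the vertical pass.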

  // Horizontal filter.
  // Filter types used for width <= 4 are different from those for width > 4.
  // When width > 4, the valid filter index range is always [0, 3].
  // When width <= 4, the valid filter index range is always [3, 5].
  // Similarly for height.
  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
  const int vertical_taps = GetNumTapsInFilter(vert_filter_index);
  const int intermediate_height = height + vertical_taps - 1;
  const ptrdiff_t src_stride = reference_stride;
  const auto* const src = static_cast<const uint8_t*>(reference) -
                          (vertical_taps / 2 - 1) * src_stride -
                          kHorizontalOffset;

  DoHorizontalPass</*is_2d=*/true, /*is_compound=*/true>(
      src, src_stride, intermediate_result, width, width, intermediate_height,
      subpixel_x, horiz_filter_index);

  // Vertical filter.
  auto* dest = static_cast<uint16_t*>(prediction);
  const int filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
  assert(filter_id != 0);

  const ptrdiff_t dest_stride = width;
  __m128i taps[4];
  const __m128i v_filter =
      LoadLo8(kHalfSubPixelFilters[vert_filter_index][filter_id]);

  if (vertical_taps == 8) {
    SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter, taps);
    if (width == 4) {
      Filter2DVertical4xH<8, /*is_compound=*/true>(intermediate_result, dest,
                                                   dest_stride, height, taps);
    } else {
      Filter2DVertical<8, /*is_compound=*/true>(
          intermediate_result, dest, dest_stride, width, height, taps);
    }
  } else if (vertical_taps == 6) {
    SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter, taps);
    if (width == 4) {
      Filter2DVertical4xH<6, /*is_compound=*/true>(intermediate_result, dest,
                                                   dest_stride, height, taps);
    } else {
      Filter2DVertical<6, /*is_compound=*/true>(
          intermediate_result, dest, dest_stride, width, height, taps);
    }
  } else if (vertical_taps == 4) {
    SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter, taps);
    if (width == 4) {
      Filter2DVertical4xH<4, /*is_compound=*/true>(intermediate_result, dest,
                                                   dest_stride, height, taps);
    } else {
      Filter2DVertical<4, /*is_compound=*/true>(
          intermediate_result, dest, dest_stride, width, height, taps);
    }
  } else {  // |vertical_taps| == 2
    SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter, taps);
    if (width == 4) {
      Filter2DVertical4xH<2, /*is_compound=*/true>(intermediate_result, dest,
                                                   dest_stride, height, taps);
    } else {
      Filter2DVertical<2, /*is_compound=*/true>(
          intermediate_result, dest, dest_stride, width, height, taps);
    }
  }
}

// Pre-transposed filters.
template <int filter_index>
inline void GetHalfSubPixelFilter(__m128i* output) {
  // Filter 0
  alignas(
      16) static constexpr int8_t kHalfSubPixel6TapSignedFilterColumns[6][16] =
      {{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0},
       {0, -3, -5, -6, -7, -7, -8, -7, -7, -6, -6, -6, -5, -4, -2, -1},
       {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4},
       {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63},
       {0, -1, -2, -4, -5, -6, -6, -6, -7, -7, -8, -7, -7, -6, -5, -3},
       {0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}};
  // Filter 1
  alignas(16) static constexpr int8_t
      kHalfSubPixel6TapMixedSignedFilterColumns[6][16] = {
          {0, 1, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, 0},
          {0, 14, 13, 11, 10, 9, 8, 8, 7, 6, 5, 4, 3, 2, 2, 1},
          {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17},
          {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31},
          {0, 1, 2, 2, 3, 4, 5, 6, 7, 8, 8, 9, 10, 11, 13, 14},
          {0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 1}};
  // Filter 2
  alignas(
      16) static constexpr int8_t kHalfSubPixel8TapSignedFilterColumns[8][16] =
      {{0, -1, -1, -1, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, 0},
       {0, 1, 3, 4, 5, 5, 5, 5, 6, 5, 4, 4, 3, 3, 2, 1},
       {0, -3, -6, -9, -11, -11, -12, -12, -12, -11, -10, -9, -7, -5, -3, -1},
       {64, 63, 62, 60, 58, 54, 50, 45, 40, 35, 30, 24, 19, 13, 8, 4},
       {0, 4, 8, 13, 19, 24, 30, 35, 40, 45, 50, 54, 58, 60, 62, 63},
       {0, -1, -3, -5, -7, -9, -10, -11, -12, -12, -12, -11, -11, -9, -6, -3},
       {0, 1, 2, 3, 3, 4, 4, 5, 6, 5, 5, 5, 5, 4, 3, 1},
       {0, 0, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -1, -1, -1}};
  // Filter 3
  alignas(16) static constexpr uint8_t kHalfSubPixel2TapFilterColumns[2][16] = {
      {64, 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4},
      {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}};
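  // Row k of each table holds tap k for all 16 |filter_id|s, so a single
  // pshufb per row gathers per-lane taps. The taps are halved, so every
  // column sums to 64; e.g. column 4 of kHalfSubPixel2TapFilterColumns is
  // {48, 16}.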
  // Filter 4
  alignas(
      16) static constexpr int8_t kHalfSubPixel4TapSignedFilterColumns[4][16] =
      {{0, -2, -4, -5, -6, -6, -7, -6, -6, -5, -5, -5, -4, -3, -2, -1},
       {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4},
       {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63},
       {0, -1, -2, -3, -4, -5, -5, -5, -6, -6, -7, -6, -6, -5, -4, -2}};
  // Filter 5
  alignas(
      16) static constexpr uint8_t kSubPixel4TapPositiveFilterColumns[4][16] = {
      {0, 15, 13, 11, 10, 9, 8, 7, 6, 6, 5, 4, 3, 2, 2, 1},
      {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17},
      {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31},
      {0, 1, 2, 2, 3, 4, 5, 6, 6, 7, 8, 9, 10, 11, 13, 15}};
  switch (filter_index) {
    case 0:
      output[0] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[0]);
      output[1] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[1]);
      output[2] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[2]);
      output[3] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[3]);
      output[4] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[4]);
      output[5] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[5]);
      break;
    case 1:
      // The term "mixed" refers to the fact that the outer taps have a mix of
      // negative and positive values.
      output[0] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[0]);
      output[1] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[1]);
      output[2] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[2]);
      output[3] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[3]);
      output[4] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[4]);
      output[5] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[5]);
      break;
    case 2:
      output[0] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[0]);
      output[1] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[1]);
      output[2] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[2]);
      output[3] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[3]);
      output[4] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[4]);
      output[5] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[5]);
      output[6] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[6]);
      output[7] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[7]);
      break;
    case 3:
      output[0] = LoadAligned16(kHalfSubPixel2TapFilterColumns[0]);
      output[1] = LoadAligned16(kHalfSubPixel2TapFilterColumns[1]);
      break;
    case 4:
      output[0] = LoadAligned16(kHalfSubPixel4TapSignedFilterColumns[0]);
      output[1] = LoadAligned16(kHalfSubPixel4TapSignedFilterColumns[1]);
      output[2] = LoadAligned16(kHalfSubPixel4TapSignedFilterColumns[2]);
      output[3] = LoadAligned16(kHalfSubPixel4TapSignedFilterColumns[3]);
      break;
    default:
      assert(filter_index == 5);
      output[0] = LoadAligned16(kSubPixel4TapPositiveFilterColumns[0]);
      output[1] = LoadAligned16(kSubPixel4TapPositiveFilterColumns[1]);
      output[2] = LoadAligned16(kSubPixel4TapPositiveFilterColumns[2]);
      output[3] = LoadAligned16(kSubPixel4TapPositiveFilterColumns[3]);
      break;
  }
}

// There are many opportunities for overreading in scaled convolve, because
// the range of starting points for filter windows is anywhere from 0 to 16
// for 8 destination pixels, and the window sizes range from 2 to 8. To
// accommodate this range concisely, we use |grade_x| to mean the most steps
// in src that can be traversed in a single |step_x| increment, i.e. 1 or 2.
// More importantly, |grade_x| answers the question "how many vector loads are
// needed to cover the source values?"
// When |grade_x| == 1, the maximum number of source values needed is 8 separate
// starting positions plus 7 more to cover taps, all fitting into 16 bytes.
// When |grade_x| > 1, we are guaranteed to exceed 8 whole steps in src for
// every 8 |step_x| increments, on top of 8 possible taps. The first load covers
// the starting sources for each kernel, while the final load covers the taps.
// Since the offset value of src_x cannot exceed 8 and |num_taps| does not
// exceed 4 when width <= 4, |grade_x| is set to 1 regardless of the value of
// |step_x|.
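// For example, assuming kScaleSubPixelBits == 10: ConvolveScale2D_SSE4_1
// asserts step_x <= 2048, i.e. at most 2 source pixels per step. With 8 taps
// and |grade_x| == 2, the last of the 8 kernels may start at byte
// (7 * 2048) >> 10 == 14 and read through byte 14 + 7 == 21, which is why a
// second load, LoadLo8(src + 16), is needed.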
template <int num_taps, int grade_x>
inline void PrepareSourceVectors(const uint8_t* src, const __m128i src_indices,
                                 __m128i* const source /*[num_taps >> 1]*/) {
  const __m128i src_vals = LoadUnaligned16(src);
  source[0] = _mm_shuffle_epi8(src_vals, src_indices);
  if (grade_x == 1) {
    if (num_taps > 2) {
      source[1] = _mm_shuffle_epi8(_mm_srli_si128(src_vals, 2), src_indices);
    }
    if (num_taps > 4) {
      source[2] = _mm_shuffle_epi8(_mm_srli_si128(src_vals, 4), src_indices);
    }
    if (num_taps > 6) {
      source[3] = _mm_shuffle_epi8(_mm_srli_si128(src_vals, 6), src_indices);
    }
  } else {
    assert(grade_x > 1);
    assert(num_taps != 4);
    // grade_x > 1 also means width >= 8 && num_taps != 4
    const __m128i src_vals_ext = LoadLo8(src + 16);
    if (num_taps > 2) {
      source[1] = _mm_shuffle_epi8(_mm_alignr_epi8(src_vals_ext, src_vals, 2),
                                   src_indices);
      source[2] = _mm_shuffle_epi8(_mm_alignr_epi8(src_vals_ext, src_vals, 4),
                                   src_indices);
    }
    if (num_taps > 6) {
      source[3] = _mm_shuffle_epi8(_mm_alignr_epi8(src_vals_ext, src_vals, 6),
                                   src_indices);
    }
  }
}

template <int num_taps>
inline void PrepareHorizontalTaps(const __m128i subpel_indices,
                                  const __m128i* filter_taps,
                                  __m128i* out_taps) {
  const __m128i scale_index_offsets =
      _mm_srli_epi16(subpel_indices, kFilterIndexShift);
  const __m128i filter_index_mask = _mm_set1_epi8(kSubPixelMask);
  const __m128i filter_indices =
      _mm_and_si128(_mm_packus_epi16(scale_index_offsets, scale_index_offsets),
                    filter_index_mask);
  // Line up taps for maddubs_epi16.
  // The unpack is also assumed to be lighter than shift+alignr.
  for (int k = 0; k < (num_taps >> 1); ++k) {
    const __m128i taps0 = _mm_shuffle_epi8(filter_taps[2 * k], filter_indices);
    const __m128i taps1 =
        _mm_shuffle_epi8(filter_taps[2 * k + 1], filter_indices);
    out_taps[k] = _mm_unpacklo_epi8(taps0, taps1);
  }
}
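
// In the layout built above, byte 2 * i of out_taps[k] is tap 2k and byte
// 2 * i + 1 is tap 2k + 1 for output pixel i, matching the unsigned/signed
// pair pattern that _mm_maddubs_epi16 multiplies and pairwise adds in
// SumOnePassTaps.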

inline __m128i HorizontalScaleIndices(const __m128i subpel_indices) {
  const __m128i src_indices16 =
      _mm_srli_epi16(subpel_indices, kScaleSubPixelBits);
  const __m128i src_indices = _mm_packus_epi16(src_indices16, src_indices16);
  return _mm_unpacklo_epi8(src_indices,
                           _mm_add_epi8(src_indices, _mm_set1_epi8(1)));
}

template <int grade_x, int filter_index, int num_taps>
inline void ConvolveHorizontalScale(const uint8_t* src, ptrdiff_t src_stride,
                                    int width, int subpixel_x, int step_x,
                                    int intermediate_height,
                                    int16_t* intermediate) {
  // Account for the 0-taps that precede the nonzero taps in the 8-tap layout.
  const int kernel_offset = (8 - num_taps) >> 1;
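  // e.g. num_taps == 2 selects taps 3 and 4 of the 8-tap layout, so
  // kernel_offset == 3.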
  const int ref_x = subpixel_x >> kScaleSubPixelBits;
  const int step_x8 = step_x << 3;
  __m128i filter_taps[num_taps];
  GetHalfSubPixelFilter<filter_index>(filter_taps);
  const __m128i index_steps =
      _mm_mullo_epi16(_mm_set_epi16(7, 6, 5, 4, 3, 2, 1, 0),
                      _mm_set1_epi16(static_cast<int16_t>(step_x)));

  __m128i taps[num_taps >> 1];
  __m128i source[num_taps >> 1];
  int p = subpixel_x;
  // Case when width <= 4 is possible.
  if (filter_index >= 3) {
    if (filter_index > 3 || width <= 4) {
      const uint8_t* src_x =
          &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
      // Only add steps to the 10-bit truncated p to avoid overflow.
      const __m128i p_fraction = _mm_set1_epi16(p & 1023);
      const __m128i subpel_indices = _mm_add_epi16(index_steps, p_fraction);
      PrepareHorizontalTaps<num_taps>(subpel_indices, filter_taps, taps);
      const __m128i packed_indices = HorizontalScaleIndices(subpel_indices);

      int y = intermediate_height;
      do {
        // Load and line up source values with the taps. Width 4 means no need
        // to load extended source.
        PrepareSourceVectors<num_taps, /*grade_x=*/1>(src_x, packed_indices,
                                                      source);

        StoreLo8(intermediate, RightShiftWithRounding_S16(
                                   SumOnePassTaps<filter_index>(source, taps),
                                   kInterRoundBitsHorizontal - 1));
        src_x += src_stride;
        intermediate += kIntermediateStride;
      } while (--y != 0);
      return;
    }
  }

  // |width| >= 8
  int x = 0;
  do {
    const uint8_t* src_x =
        &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
    int16_t* intermediate_x = intermediate + x;
    // Only add steps to the 10-bit truncated p to avoid overflow.
    const __m128i p_fraction = _mm_set1_epi16(p & 1023);
    const __m128i subpel_indices = _mm_add_epi16(index_steps, p_fraction);
    PrepareHorizontalTaps<num_taps>(subpel_indices, filter_taps, taps);
    const __m128i packed_indices = HorizontalScaleIndices(subpel_indices);

    int y = intermediate_height;
    do {
      // For each x, a lane of src_k[k] contains src_x[k].
      PrepareSourceVectors<num_taps, grade_x>(src_x, packed_indices, source);

      // Shift by one less because the taps are halved.
      StoreAligned16(
          intermediate_x,
          RightShiftWithRounding_S16(SumOnePassTaps<filter_index>(source, taps),
                                     kInterRoundBitsHorizontal - 1));
      src_x += src_stride;
      intermediate_x += kIntermediateStride;
    } while (--y != 0);
    x += 8;
    p += step_x8;
  } while (x < width);
}

template <int num_taps>
inline void PrepareVerticalTaps(const int8_t* taps, __m128i* output) {
  // Avoid overreading the filter due to starting at kernel_offset.
  // The only danger of overread is in the final filter, which has 4 taps.
  const __m128i filter =
      _mm_cvtepi8_epi16((num_taps > 4) ? LoadLo8(taps) : Load4(taps));
  output[0] = _mm_shuffle_epi32(filter, 0);
  if (num_taps > 2) {
    output[1] = _mm_shuffle_epi32(filter, 0x55);
  }
  if (num_taps > 4) {
    output[2] = _mm_shuffle_epi32(filter, 0xAA);
  }
  if (num_taps > 6) {
    output[3] = _mm_shuffle_epi32(filter, 0xFF);
  }
}
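
// In PrepareVerticalTaps, each 32-bit lane of the sign-extended |filter|
// holds one (tap 2k, tap 2k + 1) int16 pair, so each _mm_shuffle_epi32
// broadcast produces the pair that _mm_madd_epi16 consumes below.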

// Process eight 16-bit inputs and output eight 16-bit values.
template <int num_taps, bool is_compound>
inline __m128i Sum2DVerticalTaps(const __m128i* const src,
                                 const __m128i* taps) {
  const __m128i src_lo_01 = _mm_unpacklo_epi16(src[0], src[1]);
  __m128i sum_lo = _mm_madd_epi16(src_lo_01, taps[0]);
  const __m128i src_hi_01 = _mm_unpackhi_epi16(src[0], src[1]);
  __m128i sum_hi = _mm_madd_epi16(src_hi_01, taps[0]);
  if (num_taps > 2) {
    const __m128i src_lo_23 = _mm_unpacklo_epi16(src[2], src[3]);
    sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_23, taps[1]));
    const __m128i src_hi_23 = _mm_unpackhi_epi16(src[2], src[3]);
    sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_23, taps[1]));
  }
  if (num_taps > 4) {
    const __m128i src_lo_45 = _mm_unpacklo_epi16(src[4], src[5]);
    sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_45, taps[2]));
    const __m128i src_hi_45 = _mm_unpackhi_epi16(src[4], src[5]);
    sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_45, taps[2]));
  }
  if (num_taps > 6) {
    const __m128i src_lo_67 = _mm_unpacklo_epi16(src[6], src[7]);
    sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_67, taps[3]));
    const __m128i src_hi_67 = _mm_unpackhi_epi16(src[6], src[7]);
    sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_67, taps[3]));
  }
  if (is_compound) {
    return _mm_packs_epi32(
        RightShiftWithRounding_S32(sum_lo, kInterRoundBitsCompoundVertical - 1),
        RightShiftWithRounding_S32(sum_hi,
                                   kInterRoundBitsCompoundVertical - 1));
  }
  return _mm_packs_epi32(
      RightShiftWithRounding_S32(sum_lo, kInterRoundBitsVertical - 1),
      RightShiftWithRounding_S32(sum_hi, kInterRoundBitsVertical - 1));
}

// Bottom half of each src[k] is the source for one filter, and the top half
// is the source for the other filter, for the next destination row.
template <int num_taps, bool is_compound>
__m128i Sum2DVerticalTaps4x2(const __m128i* const src, const __m128i* taps_lo,
                             const __m128i* taps_hi) {
  const __m128i src_lo_01 = _mm_unpacklo_epi16(src[0], src[1]);
  __m128i sum_lo = _mm_madd_epi16(src_lo_01, taps_lo[0]);
  const __m128i src_hi_01 = _mm_unpackhi_epi16(src[0], src[1]);
  __m128i sum_hi = _mm_madd_epi16(src_hi_01, taps_hi[0]);
  if (num_taps > 2) {
    const __m128i src_lo_23 = _mm_unpacklo_epi16(src[2], src[3]);
    sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_23, taps_lo[1]));
    const __m128i src_hi_23 = _mm_unpackhi_epi16(src[2], src[3]);
    sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_23, taps_hi[1]));
  }
  if (num_taps > 4) {
    const __m128i src_lo_45 = _mm_unpacklo_epi16(src[4], src[5]);
    sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_45, taps_lo[2]));
    const __m128i src_hi_45 = _mm_unpackhi_epi16(src[4], src[5]);
    sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_45, taps_hi[2]));
  }
  if (num_taps > 6) {
    const __m128i src_lo_67 = _mm_unpacklo_epi16(src[6], src[7]);
    sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_67, taps_lo[3]));
    const __m128i src_hi_67 = _mm_unpackhi_epi16(src[6], src[7]);
    sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_67, taps_hi[3]));
  }

  if (is_compound) {
    return _mm_packs_epi32(
        RightShiftWithRounding_S32(sum_lo, kInterRoundBitsCompoundVertical - 1),
        RightShiftWithRounding_S32(sum_hi,
                                   kInterRoundBitsCompoundVertical - 1));
  }
  return _mm_packs_epi32(
      RightShiftWithRounding_S32(sum_lo, kInterRoundBitsVertical - 1),
      RightShiftWithRounding_S32(sum_hi, kInterRoundBitsVertical - 1));
}
// |width_class| is 2, 4, or 8, according to the Store function that should be
// used.
template <int num_taps, int width_class, bool is_compound>
#if LIBGAV1_MSAN
__attribute__((no_sanitize_memory)) void ConvolveVerticalScale(
#else
inline void ConvolveVerticalScale(
#endif
    const int16_t* src, const int width, const int subpixel_y,
    const int filter_index, const int step_y, const int height, void* dest,
    const ptrdiff_t dest_stride) {
  constexpr ptrdiff_t src_stride = kIntermediateStride;
  constexpr int kernel_offset = (8 - num_taps) / 2;
  const int16_t* src_y = src;
  // |dest| is 16-bit in compound mode, Pixel otherwise.
  auto* dest16_y = static_cast<uint16_t*>(dest);
  auto* dest_y = static_cast<uint8_t*>(dest);
  __m128i s[num_taps];

  int p = subpixel_y & 1023;
  int y = height;
  if (width_class <= 4) {
    __m128i filter_taps_lo[num_taps >> 1];
    __m128i filter_taps_hi[num_taps >> 1];
    do {  // y > 0
      for (int i = 0; i < num_taps; ++i) {
        s[i] = LoadLo8(src_y + i * src_stride);
      }
      int filter_id = (p >> 6) & kSubPixelMask;
      const int8_t* filter0 =
          kHalfSubPixelFilters[filter_index][filter_id] + kernel_offset;
      PrepareVerticalTaps<num_taps>(filter0, filter_taps_lo);
      p += step_y;
      src_y = src + (p >> kScaleSubPixelBits) * src_stride;

      for (int i = 0; i < num_taps; ++i) {
        s[i] = LoadHi8(s[i], src_y + i * src_stride);
      }
      filter_id = (p >> 6) & kSubPixelMask;
      const int8_t* filter1 =
          kHalfSubPixelFilters[filter_index][filter_id] + kernel_offset;
      PrepareVerticalTaps<num_taps>(filter1, filter_taps_hi);
      p += step_y;
      src_y = src + (p >> kScaleSubPixelBits) * src_stride;

      const __m128i sums = Sum2DVerticalTaps4x2<num_taps, is_compound>(
          s, filter_taps_lo, filter_taps_hi);
      if (is_compound) {
        assert(width_class > 2);
        StoreLo8(dest16_y, sums);
        dest16_y += dest_stride;
        StoreHi8(dest16_y, sums);
        dest16_y += dest_stride;
      } else {
        const __m128i result = _mm_packus_epi16(sums, sums);
        if (width_class == 2) {
          Store2(dest_y, result);
          dest_y += dest_stride;
          Store2(dest_y, _mm_srli_si128(result, 4));
        } else {
          Store4(dest_y, result);
          dest_y += dest_stride;
          Store4(dest_y, _mm_srli_si128(result, 4));
        }
        dest_y += dest_stride;
      }
      y -= 2;
    } while (y != 0);
    return;
  }

  // |width_class| >= 8
  __m128i filter_taps[num_taps >> 1];
  do {  // y > 0
    src_y = src + (p >> kScaleSubPixelBits) * src_stride;
    const int filter_id = (p >> 6) & kSubPixelMask;
    const int8_t* filter =
        kHalfSubPixelFilters[filter_index][filter_id] + kernel_offset;
    PrepareVerticalTaps<num_taps>(filter, filter_taps);

    int x = 0;
    do {  // x < width
      for (int i = 0; i < num_taps; ++i) {
        s[i] = LoadUnaligned16(src_y + i * src_stride);
      }

      const __m128i sums =
          Sum2DVerticalTaps<num_taps, is_compound>(s, filter_taps);
      if (is_compound) {
        StoreUnaligned16(dest16_y + x, sums);
      } else {
        StoreLo8(dest_y + x, _mm_packus_epi16(sums, sums));
      }
      x += 8;
      src_y += 8;
    } while (x < width);
    p += step_y;
    dest_y += dest_stride;
    dest16_y += dest_stride;
  } while (--y != 0);
}

template <bool is_compound>
void ConvolveScale2D_SSE4_1(const void* const reference,
                            const ptrdiff_t reference_stride,
                            const int horizontal_filter_index,
                            const int vertical_filter_index,
                            const int subpixel_x, const int subpixel_y,
                            const int step_x, const int step_y, const int width,
                            const int height, void* prediction,
                            const ptrdiff_t pred_stride) {
  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
  assert(step_x <= 2048);
  // The output of the horizontal filter, i.e. the intermediate_result, is
  // guaranteed to fit in int16_t.
  // TODO(petersonab): Reduce intermediate block stride to width to make smaller
  // blocks faster.
  alignas(16) int16_t
      intermediate_result[kMaxSuperBlockSizeInPixels *
                          (2 * kMaxSuperBlockSizeInPixels + kSubPixelTaps)];
  const int num_vert_taps = GetNumTapsInFilter(vert_filter_index);
  const int intermediate_height =
      (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
       kScaleSubPixelBits) +
      num_vert_taps;
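  // For example, with height == 32, an 8-tap filter and step_y == 2048 (a 2x
  // scale, assuming kScaleSubPixelBits == 10):
  // ((31 * 2048 + 1023) >> 10) + 8 == 70 intermediate rows.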

  // Horizontal filter.
  // Filter types used for width <= 4 are different from those for width > 4.
  // When width > 4, the valid filter index range is always [0, 3].
  // When width <= 4, the valid filter index range is always [3, 5].
  // Similarly for height.
  int16_t* intermediate = intermediate_result;
  const ptrdiff_t src_stride = reference_stride;
  const auto* src = static_cast<const uint8_t*>(reference);
  const int vert_kernel_offset = (8 - num_vert_taps) / 2;
  src += vert_kernel_offset * src_stride;

  // Derive the maximum value of |step_x| at which all source values fit in one
  // 16-byte load: the final index is src_x + |num_taps| - 1 < 16.
  // step_x * 7 is the final base sub-pixel index for the shuffle mask for
  // filter inputs in each iteration on large blocks. When step_x is large, we
  // need a second register and alignr in order to gather all filter inputs.
  // |num_taps| - 1 is the offset for the shuffle of inputs to the final tap.
  const int num_horiz_taps = GetNumTapsInFilter(horiz_filter_index);
  const int kernel_start_ceiling = 16 - num_horiz_taps;
  // This truncated quotient |grade_x_threshold| selects |step_x| such that:
  // (step_x * 7) >> kScaleSubPixelBits < single load limit
  const int grade_x_threshold =
      (kernel_start_ceiling << kScaleSubPixelBits) / 7;
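  // e.g. with 8 taps, kernel_start_ceiling == 8 and, assuming
  // kScaleSubPixelBits == 10, grade_x_threshold == 8192 / 7 == 1170; any
  // larger |step_x| takes the two-load (grade_x == 2) paths below.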
  switch (horiz_filter_index) {
    case 0:
      if (step_x > grade_x_threshold) {
        ConvolveHorizontalScale<2, 0, 6>(src, src_stride, width, subpixel_x,
                                         step_x, intermediate_height,
                                         intermediate);
      } else {
        ConvolveHorizontalScale<1, 0, 6>(src, src_stride, width, subpixel_x,
                                         step_x, intermediate_height,
                                         intermediate);
      }
      break;
    case 1:
      if (step_x > grade_x_threshold) {
        ConvolveHorizontalScale<2, 1, 6>(src, src_stride, width, subpixel_x,
                                         step_x, intermediate_height,
                                         intermediate);

      } else {
        ConvolveHorizontalScale<1, 1, 6>(src, src_stride, width, subpixel_x,
                                         step_x, intermediate_height,
                                         intermediate);
      }
      break;
    case 2:
      if (step_x > grade_x_threshold) {
        ConvolveHorizontalScale<2, 2, 8>(src, src_stride, width, subpixel_x,
                                         step_x, intermediate_height,
                                         intermediate);
      } else {
        ConvolveHorizontalScale<1, 2, 8>(src, src_stride, width, subpixel_x,
                                         step_x, intermediate_height,
                                         intermediate);
      }
      break;
    case 3:
      if (step_x > grade_x_threshold) {
        ConvolveHorizontalScale<2, 3, 2>(src, src_stride, width, subpixel_x,
                                         step_x, intermediate_height,
                                         intermediate);
      } else {
        ConvolveHorizontalScale<1, 3, 2>(src, src_stride, width, subpixel_x,
                                         step_x, intermediate_height,
                                         intermediate);
      }
      break;
    case 4:
      assert(width <= 4);
      ConvolveHorizontalScale<1, 4, 4>(src, src_stride, width, subpixel_x,
                                       step_x, intermediate_height,
                                       intermediate);
      break;
    default:
      assert(horiz_filter_index == 5);
      assert(width <= 4);
      ConvolveHorizontalScale<1, 5, 4>(src, src_stride, width, subpixel_x,
                                       step_x, intermediate_height,
                                       intermediate);
  }

  // Vertical filter.
  intermediate = intermediate_result;
  switch (vert_filter_index) {
    case 0:
    case 1:
      if (!is_compound && width == 2) {
        ConvolveVerticalScale<6, 2, is_compound>(
            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
            prediction, pred_stride);
      } else if (width == 4) {
        ConvolveVerticalScale<6, 4, is_compound>(
            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
            prediction, pred_stride);
      } else {
        ConvolveVerticalScale<6, 8, is_compound>(
            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
            prediction, pred_stride);
      }
      break;
    case 2:
      if (!is_compound && width == 2) {
        ConvolveVerticalScale<8, 2, is_compound>(
            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
            prediction, pred_stride);
      } else if (width == 4) {
        ConvolveVerticalScale<8, 4, is_compound>(
            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
            prediction, pred_stride);
      } else {
        ConvolveVerticalScale<8, 8, is_compound>(
            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
            prediction, pred_stride);
      }
      break;
    case 3:
      if (!is_compound && width == 2) {
        ConvolveVerticalScale<2, 2, is_compound>(
            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
            prediction, pred_stride);
      } else if (width == 4) {
        ConvolveVerticalScale<2, 4, is_compound>(
            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
            prediction, pred_stride);
      } else {
        ConvolveVerticalScale<2, 8, is_compound>(
            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
            prediction, pred_stride);
      }
      break;
    default:
      assert(vert_filter_index == 4 || vert_filter_index == 5);
      if (!is_compound && width == 2) {
        ConvolveVerticalScale<4, 2, is_compound>(
            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
            prediction, pred_stride);
      } else if (width == 4) {
        ConvolveVerticalScale<4, 4, is_compound>(
            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
            prediction, pred_stride);
      } else {
        ConvolveVerticalScale<4, 8, is_compound>(
            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
            prediction, pred_stride);
      }
  }
}

void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);
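  // The convolve table is indexed as convolve[is_intra_block_copy]
  // [is_compound][has_vertical_filter][has_horizontal_filter] (see dsp.h);
  // entries not assigned here keep their C implementations.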
  dsp->convolve[0][0][0][1] = ConvolveHorizontal_SSE4_1;
  dsp->convolve[0][0][1][0] = ConvolveVertical_SSE4_1;
  dsp->convolve[0][0][1][1] = Convolve2D_SSE4_1;

  dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_SSE4;
  dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_SSE4_1;
  dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_SSE4_1;
  dsp->convolve[0][1][1][1] = ConvolveCompound2D_SSE4_1;

  dsp->convolve_scale[0] = ConvolveScale2D_SSE4_1<false>;
  dsp->convolve_scale[1] = ConvolveScale2D_SSE4_1<true>;
}

}  // namespace
}  // namespace low_bitdepth

void ConvolveInit_SSE4_1() { low_bitdepth::Init8bpp(); }

}  // namespace dsp
}  // namespace libgav1

#else  // !LIBGAV1_ENABLE_SSE4_1
namespace libgav1 {
namespace dsp {

void ConvolveInit_SSE4_1() {}

}  // namespace dsp
}  // namespace libgav1
#endif  // LIBGAV1_ENABLE_SSE4_1