/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_ports/mem.h"
#include "av1/common/arm/convolve_neon.h"
#include "av1/common/convolve.h"
#include "av1/common/filter.h"

DECLARE_ALIGNED(16, static const uint8_t, dot_prod_permute_tbl[48]) = {
  0, 1, 2,  3,  1, 2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6,
  4, 5, 6,  7,  5, 6,  7,  8,  6,  7,  8,  9,  7,  8,  9,  10,
  8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
};

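// Note: the signed dot-product (SDOT) helpers below all follow the same
// pattern. The unsigned 8-bit pixels are first shifted into the signed range
// by subtracting 'range_limit' (128), and a pre-computed 'correction' term is
// used as the accumulator so that the unsigned result is restored:
//   sum(s[i] * f[i]) = sum((s[i] - 128) * f[i]) + 128 * sum(f[i])
// (any rounding shim required by the caller is folded into 'correction' too).
// The table above supplies the overlapping 4-sample windows consumed by each
// 4-way dot product.
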
static INLINE int16x4_t convolve12_4_x(uint8x16_t samples,
                                       const int8x16_t filter,
                                       const int32x4_t correction,
                                       const uint8x16_t range_limit,
                                       const uint8x16x3_t permute_tbl) {
  int8x16_t clamped_samples, permuted_samples[3];
  int32x4_t sum;

  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);

  // Accumulate dot product into 'correction' to account for range clamp.
  // First 4 output values.
  sum = vdotq_laneq_s32(correction, permuted_samples[0], filter, 0);
  sum = vdotq_laneq_s32(sum, permuted_samples[1], filter, 1);
  sum = vdotq_laneq_s32(sum, permuted_samples[2], filter, 2);

  return vqrshrn_n_s32(sum, FILTER_BITS);
}

static INLINE uint8x8_t convolve12_8_x(uint8x16_t samples[2],
                                       const int8x16_t filter,
                                       const int32x4_t correction,
                                       const uint8x16_t range_limit,
                                       const uint8x16x3_t permute_tbl) {
  int8x16_t clamped_samples[2], permuted_samples[4];
  int32x4_t sum[2];

  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  clamped_samples[0] = vreinterpretq_s8_u8(vsubq_u8(samples[0], range_limit));
  clamped_samples[1] = vreinterpretq_s8_u8(vsubq_u8(samples[1], range_limit));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[2]);
  // {12, 13, 14, 15, 13, 14, 15, 16, 14, 15, 16, 17, 15, 16, 17, 18 }
  permuted_samples[3] = vqtbl1q_s8(clamped_samples[1], permute_tbl.val[2]);

  // Accumulate dot product into 'correction' to account for range clamp.
  // First 4 output values.
  sum[0] = vdotq_laneq_s32(correction, permuted_samples[0], filter, 0);
  sum[0] = vdotq_laneq_s32(sum[0], permuted_samples[1], filter, 1);
  sum[0] = vdotq_laneq_s32(sum[0], permuted_samples[2], filter, 2);
  // Second 4 output values.
  sum[1] = vdotq_laneq_s32(correction, permuted_samples[1], filter, 0);
  sum[1] = vdotq_laneq_s32(sum[1], permuted_samples[2], filter, 1);
  sum[1] = vdotq_laneq_s32(sum[1], permuted_samples[3], filter, 2);

  // Narrow and re-pack.
  int16x8_t sum_s16 = vcombine_s16(vqrshrn_n_s32(sum[0], FILTER_BITS),
                                   vqrshrn_n_s32(sum[1], FILTER_BITS));
  return vqmovun_s16(sum_s16);
}

static INLINE void convolve_x_sr_12tap_neon_dotprod(
    const uint8_t *src, int src_stride, uint8_t *dst, int dst_stride, int w,
    int h, const int16_t *x_filter_ptr) {
  const int16x8_t filter_0_7 = vld1q_s16(x_filter_ptr);
  const int16x4_t filter_8_11 = vld1_s16(x_filter_ptr + 8);
  const int16x8_t filter_8_15 = vcombine_s16(filter_8_11, vdup_n_s16(0));
  const int8x16_t filter =
      vcombine_s8(vmovn_s16(filter_0_7), vmovn_s16(filter_8_15));

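  // vshlq_n_s16(x, FILTER_BITS) scales each tap by 1 << FILTER_BITS (128), so
  // correction_s32 below is 128 * sum(filter taps) - exactly the term needed
  // to undo the -128 clamp applied to the samples in the helpers above.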
  const int32_t correction_s32 =
      vaddvq_s32(vaddq_s32(vpaddlq_s16(vshlq_n_s16(filter_0_7, FILTER_BITS)),
                           vpaddlq_s16(vshlq_n_s16(filter_8_15, FILTER_BITS))));
  // A shim of 1 << (ROUND0_BITS - 1) enables us to use a single rounding right
  // shift by FILTER_BITS - instead of a first rounding right shift by
  // ROUND0_BITS, followed by a second rounding right shift by FILTER_BITS -
  // ROUND0_BITS.
  int32x4_t correction = vdupq_n_s32(correction_s32 + (1 << (ROUND0_BITS - 1)));
  const uint8x16_t range_limit = vdupq_n_u8(128);
  const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);

  // Special case the following no-op filter as 128 won't fit into the
  // 8-bit signed dot-product instruction:
  // { 0, 0, 0, 0, 0, 128, 0, 0, 0, 0, 0, 0 }
  if (vgetq_lane_s16(filter_0_7, 5) == 128) {
    // Undo the horizontal offset in the calling function.
    src += 5;

    do {
      const uint8_t *s = src;
      uint8_t *d = dst;
      int width = w;

      do {
        uint8x8_t d0 = vld1_u8(s);
        if (w == 4) {
          store_u8_4x1(d, d0);
        } else {
          vst1_u8(d, d0);
        }

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src += src_stride;
      dst += dst_stride;
    } while (--h != 0);
  } else {
    if (w <= 4) {
      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

        int16x4_t d0 =
            convolve12_4_x(s0, filter, correction, range_limit, permute_tbl);
        int16x4_t d1 =
            convolve12_4_x(s1, filter, correction, range_limit, permute_tbl);
        int16x4_t d2 =
            convolve12_4_x(s2, filter, correction, range_limit, permute_tbl);
        int16x4_t d3 =
            convolve12_4_x(s3, filter, correction, range_limit, permute_tbl);

        uint8x8_t d01 = vqmovun_s16(vcombine_s16(d0, d1));
        uint8x8_t d23 = vqmovun_s16(vcombine_s16(d2, d3));

        store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
        store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);

        dst += 4 * dst_stride;
        src += 4 * src_stride;
        h -= 4;
      } while (h != 0);
    } else {
      do {
        const uint8_t *s = src;
        uint8_t *d = dst;
        int width = w;

        do {
          uint8x16_t s0[2], s1[2], s2[2], s3[2];
          load_u8_16x4(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
          load_u8_16x4(s + 4, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);

          uint8x8_t d0 =
              convolve12_8_x(s0, filter, correction, range_limit, permute_tbl);
          uint8x8_t d1 =
              convolve12_8_x(s1, filter, correction, range_limit, permute_tbl);
          uint8x8_t d2 =
              convolve12_8_x(s2, filter, correction, range_limit, permute_tbl);
          uint8x8_t d3 =
              convolve12_8_x(s3, filter, correction, range_limit, permute_tbl);

          store_u8_8x4(d + 0 * dst_stride, dst_stride, d0, d1, d2, d3);

          s += 8;
          d += 8;
          width -= 8;
        } while (width != 0);
        src += 4 * src_stride;
        dst += 4 * dst_stride;
        h -= 4;
      } while (h != 0);
    }
  }
}

static INLINE int16x4_t convolve4_4_x(uint8x16_t samples, const int8x8_t filter,
                                      const int32x4_t correction,
                                      const uint8x16_t range_limit,
                                      const uint8x16_t permute_tbl) {
  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  int8x16_t clamped_samples =
      vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  int8x16_t permuted_samples = vqtbl1q_s8(clamped_samples, permute_tbl);

  // Accumulate dot product into 'correction' to account for range clamp.
  int32x4_t sum = vdotq_lane_s32(correction, permuted_samples, filter, 0);

  // Packing is performed by the caller.
  return vmovn_s32(sum);
}

static INLINE uint8x8_t convolve8_8_x(uint8x16_t samples, const int8x8_t filter,
                                      const int32x4_t correction,
                                      const uint8x16_t range_limit,
                                      const uint8x16x3_t permute_tbl) {
  int8x16_t clamped_samples, permuted_samples[3];
  int32x4_t sum[2];

  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);

  // Accumulate dot product into 'correction' to account for range clamp.
  // First 4 output values.
  sum[0] = vdotq_lane_s32(correction, permuted_samples[0], filter, 0);
  sum[0] = vdotq_lane_s32(sum[0], permuted_samples[1], filter, 1);
  // Second 4 output values.
  sum[1] = vdotq_lane_s32(correction, permuted_samples[1], filter, 0);
  sum[1] = vdotq_lane_s32(sum[1], permuted_samples[2], filter, 1);

  // Narrow and re-pack.
  int16x8_t sum_s16 = vcombine_s16(vmovn_s32(sum[0]), vmovn_s32(sum[1]));
  // We halved the convolution filter values so -1 from the right shift.
  return vqrshrun_n_s16(sum_s16, FILTER_BITS - 1);
}

void av1_convolve_x_sr_neon_dotprod(const uint8_t *src, int src_stride,
                                    uint8_t *dst, int dst_stride, int w, int h,
                                    const InterpFilterParams *filter_params_x,
                                    const int subpel_x_qn,
                                    ConvolveParams *conv_params) {
  if (w == 2 || h == 2) {
    av1_convolve_x_sr_c(src, src_stride, dst, dst_stride, w, h, filter_params_x,
                        subpel_x_qn, conv_params);
    return;
  }

  const uint8_t horiz_offset = filter_params_x->taps / 2 - 1;
  src -= horiz_offset;

  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);

  if (filter_params_x->taps > 8) {
    convolve_x_sr_12tap_neon_dotprod(src, src_stride, dst, dst_stride, w, h,
                                     x_filter_ptr);
    return;
  }

  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);
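  // The 8-tap and 4-tap paths below halve the (all-even) filter taps before
  // use, so the range-clamp correction scales each tap by
  // 1 << (FILTER_BITS - 1) (64): 64 * sum(original taps) equals
  // 128 * sum(halved taps).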
  // Dot product constants.
  const int32_t correction_s32 =
      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
  // This shim of (1 << ((ROUND0_BITS - 1) - 1)) enables us to use a single
  // rounding right shift by FILTER_BITS - instead of a first rounding right
  // shift by ROUND0_BITS, followed by a second rounding right shift by
  // FILTER_BITS - ROUND0_BITS.
  // The outermost -1 is needed because we will halve the filter values.
  const int32x4_t correction =
      vdupq_n_s32(correction_s32 + (1 << ((ROUND0_BITS - 1) - 1)));
  const uint8x16_t range_limit = vdupq_n_u8(128);

  if (w <= 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    src += 2;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

      int16x4_t d0 =
          convolve4_4_x(s0, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d1 =
          convolve4_4_x(s1, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d2 =
          convolve4_4_x(s2, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d3 =
          convolve4_4_x(s3, x_filter, correction, range_limit, permute_tbl);

      // We halved the convolution filter values so - 1 from the right shift.
      uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS - 1);
      uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS - 1);

      store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
      store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    do {
      int width = w;
      const uint8_t *s = src;
      uint8_t *d = dst;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint8x8_t d0 =
            convolve8_8_x(s0, x_filter, correction, range_limit, permute_tbl);
        uint8x8_t d1 =
            convolve8_8_x(s1, x_filter, correction, range_limit, permute_tbl);
        uint8x8_t d2 =
            convolve8_8_x(s2, x_filter, correction, range_limit, permute_tbl);
        uint8x8_t d3 =
            convolve8_8_x(s3, x_filter, correction, range_limit, permute_tbl);

        store_u8_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);
  }
}

static INLINE int16x4_t convolve12_4_2d_h(uint8x16_t samples,
                                          const int8x16_t filters,
                                          const int32x4_t correction,
                                          const uint8x16_t range_limit,
                                          const uint8x16x3_t permute_tbl) {
  int8x16_t clamped_samples, permuted_samples[3];
  int32x4_t sum;

  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);

  // Accumulate dot product into 'correction' to account for range clamp.
  // First 4 output values.
  sum = vdotq_laneq_s32(correction, permuted_samples[0], filters, 0);
  sum = vdotq_laneq_s32(sum, permuted_samples[1], filters, 1);
  sum = vdotq_laneq_s32(sum, permuted_samples[2], filters, 2);

  // Narrow and re-pack.
  return vshrn_n_s32(sum, ROUND0_BITS);
}

static INLINE int16x8_t convolve12_8_2d_h(uint8x16_t samples[2],
                                          const int8x16_t filters,
                                          const int32x4_t correction,
                                          const uint8x16_t range_limit,
                                          const uint8x16x3_t permute_tbl) {
  int8x16_t clamped_samples[2], permuted_samples[4];
  int32x4_t sum[2];

  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  clamped_samples[0] = vreinterpretq_s8_u8(vsubq_u8(samples[0], range_limit));
  clamped_samples[1] = vreinterpretq_s8_u8(vsubq_u8(samples[1], range_limit));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[2]);
  // {12, 13, 14, 15, 13, 14, 15, 16, 14, 15, 16, 17, 15, 16, 17, 18 }
  permuted_samples[3] = vqtbl1q_s8(clamped_samples[1], permute_tbl.val[2]);

  // Accumulate dot product into 'correction' to account for range clamp.
  // First 4 output values.
  sum[0] = vdotq_laneq_s32(correction, permuted_samples[0], filters, 0);
  sum[0] = vdotq_laneq_s32(sum[0], permuted_samples[1], filters, 1);
  sum[0] = vdotq_laneq_s32(sum[0], permuted_samples[2], filters, 2);
  // Second 4 output values.
  sum[1] = vdotq_laneq_s32(correction, permuted_samples[1], filters, 0);
  sum[1] = vdotq_laneq_s32(sum[1], permuted_samples[2], filters, 1);
  sum[1] = vdotq_laneq_s32(sum[1], permuted_samples[3], filters, 2);

  // Narrow and re-pack.
  return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS),
                      vshrn_n_s32(sum[1], ROUND0_BITS));
}

static INLINE void convolve_2d_sr_horiz_12tap_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, int16_t *dst_ptr,
    const int dst_stride, int w, int h, const int16x8_t x_filter_0_7,
    const int16x4_t x_filter_8_11) {
  const int bd = 8;

  // Special case the following no-op filter as 128 won't fit into the 8-bit
  // signed dot-product instruction:
  // { 0, 0, 0, 0, 0, 128, 0, 0, 0, 0, 0, 0 }
  if (vgetq_lane_s16(x_filter_0_7, 5) == 128) {
    const uint16x8_t horiz_const = vdupq_n_u16((1 << (bd - 1)));
    // Undo the horizontal offset in the calling function.
    src_ptr += 5;

    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x8_t s0 = vld1_u8(s);
        uint16x8_t d0 = vaddw_u8(horiz_const, s0);
        d0 = vshlq_n_u16(d0, FILTER_BITS - ROUND0_BITS);
        // Store 8 elements to avoid additional branches. This is safe if the
        // actual block width is < 8 because the intermediate buffer is large
        // enough to accommodate 128x128 blocks.
        vst1q_s16(d, vreinterpretq_s16_u16(d0));

        d += 8;
        s += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);

  } else {
    // Narrow filter values to 8-bit.
    const int16x8x2_t x_filter_s16 = {
      { x_filter_0_7, vcombine_s16(x_filter_8_11, vdup_n_s16(0)) }
    };
    const int8x16_t x_filter = vcombine_s8(vmovn_s16(x_filter_s16.val[0]),
                                           vmovn_s16(x_filter_s16.val[1]));

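    // Folding 1 << (ROUND0_BITS - 1) into the accumulator lets the helpers
    // use plain narrowing shifts while still matching
    // ROUND_POWER_OF_TWO(x, ROUND0_BITS) ==
    //   (x + (1 << (ROUND0_BITS - 1))) >> ROUND0_BITS.
    // The 1 << (bd + FILTER_BITS - 1) term appears to be the usual
    // intermediate offset that keeps the horizontal-pass output non-negative;
    // the vertical pass compensates for it.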
    // This shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts
    // - which are generally faster than rounding shifts on modern CPUs.
    const int32_t horiz_const =
        ((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));
    // Dot product constants.
    const int32x4_t correct_tmp =
        vaddq_s32(vpaddlq_s16(vshlq_n_s16(x_filter_s16.val[0], 7)),
                  vpaddlq_s16(vshlq_n_s16(x_filter_s16.val[1], 7)));
    const int32x4_t correction =
        vdupq_n_s32(vaddvq_s32(correct_tmp) + horiz_const);
    const uint8x16_t range_limit = vdupq_n_u8(128);
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);

    if (w <= 4) {
      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

        int16x4_t d0 = convolve12_4_2d_h(s0, x_filter, correction, range_limit,
                                         permute_tbl);
        int16x4_t d1 = convolve12_4_2d_h(s1, x_filter, correction, range_limit,
                                         permute_tbl);
        int16x4_t d2 = convolve12_4_2d_h(s2, x_filter, correction, range_limit,
                                         permute_tbl);
        int16x4_t d3 = convolve12_4_2d_h(s3, x_filter, correction, range_limit,
                                         permute_tbl);

        store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

        src_ptr += 4 * src_stride;
        dst_ptr += 4 * dst_stride;
        h -= 4;
      } while (h > 4);

      do {
        uint8x16_t s0 = vld1q_u8(src_ptr);
        int16x4_t d0 = convolve12_4_2d_h(s0, x_filter, correction, range_limit,
                                         permute_tbl);
        vst1_s16(dst_ptr, d0);

        src_ptr += src_stride;
        dst_ptr += dst_stride;
      } while (--h != 0);

    } else {
      do {
        const uint8_t *s = src_ptr;
        int16_t *d = dst_ptr;
        int width = w;

        do {
          uint8x16_t s0[2], s1[2], s2[2], s3[2];
          load_u8_16x4(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
          load_u8_16x4(s + 4, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);

          int16x8_t d0 = convolve12_8_2d_h(s0, x_filter, correction,
                                           range_limit, permute_tbl);
          int16x8_t d1 = convolve12_8_2d_h(s1, x_filter, correction,
                                           range_limit, permute_tbl);
          int16x8_t d2 = convolve12_8_2d_h(s2, x_filter, correction,
                                           range_limit, permute_tbl);
          int16x8_t d3 = convolve12_8_2d_h(s3, x_filter, correction,
                                           range_limit, permute_tbl);

          store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

          s += 8;
          d += 8;
          width -= 8;
        } while (width != 0);
        src_ptr += 4 * src_stride;
        dst_ptr += 4 * dst_stride;
        h -= 4;
      } while (h > 4);

      do {
        const uint8_t *s = src_ptr;
        int16_t *d = dst_ptr;
        int width = w;

        do {
          uint8x16_t s0[2];
          s0[0] = vld1q_u8(s);
          s0[1] = vld1q_u8(s + 4);
          int16x8_t d0 = convolve12_8_2d_h(s0, x_filter, correction,
                                           range_limit, permute_tbl);
          vst1q_s16(d, d0);

          s += 8;
          d += 8;
          width -= 8;
        } while (width != 0);
        src_ptr += src_stride;
        dst_ptr += dst_stride;
      } while (--h != 0);
    }
  }
}

static INLINE int16x4_t convolve4_4_2d_h(uint8x16_t samples,
                                         const int8x8_t filters,
                                         const int32x4_t correction,
                                         const uint8x16_t range_limit,
                                         const uint8x16_t permute_tbl) {
  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  int8x16_t clamped_samples =
      vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  int8x16_t permuted_samples = vqtbl1q_s8(clamped_samples, permute_tbl);

  // Accumulate dot product into 'correction' to account for range clamp.
  int32x4_t sum = vdotq_lane_s32(correction, permuted_samples, filters, 0);

  // We halved the convolution filter values so -1 from the right shift.
  return vshrn_n_s32(sum, ROUND0_BITS - 1);
}

static INLINE int16x8_t convolve8_8_2d_h(uint8x16_t samples,
                                         const int8x8_t filters,
                                         const int32x4_t correction,
                                         const uint8x16_t range_limit,
                                         const uint8x16x3_t permute_tbl) {
  int8x16_t clamped_samples, permuted_samples[3];
  int32x4_t sum[2];

  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);

  // Accumulate dot product into 'correction' to account for range clamp.
  // First 4 output values.
  sum[0] = vdotq_lane_s32(correction, permuted_samples[0], filters, 0);
  sum[0] = vdotq_lane_s32(sum[0], permuted_samples[1], filters, 1);
  // Second 4 output values.
  sum[1] = vdotq_lane_s32(correction, permuted_samples[1], filters, 0);
  sum[1] = vdotq_lane_s32(sum[1], permuted_samples[2], filters, 1);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                      vshrn_n_s32(sum[1], ROUND0_BITS - 1));
}

static INLINE void convolve_2d_sr_horiz_neon_dotprod(
    const uint8_t *src, int src_stride, int16_t *im_block, int im_stride, int w,
    int im_h, const int16_t *x_filter_ptr) {
  const int bd = 8;
  // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // The outermost -1 is needed because we halved the filter values.
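  // Relative to the 12-tap path above, both terms below are halved: the
  // filter taps are halved, so the whole accumulator (intermediate offset,
  // rounding shim and dot product) is computed at half scale and the
  // narrowing shift becomes ROUND0_BITS - 1 instead of ROUND0_BITS.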
  const int32_t horiz_const =
      ((1 << (bd + FILTER_BITS - 2)) + (1 << ((ROUND0_BITS - 1) - 1)));
  // Dot product constants.
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);
  const int32_t correction_s32 =
      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
  const int32x4_t correction = vdupq_n_s32(correction_s32 + horiz_const);
  const uint8x16_t range_limit = vdupq_n_u8(128);

  const uint8_t *src_ptr = src;
  int16_t *dst_ptr = im_block;
  int dst_stride = im_stride;
  int height = im_h;

  if (w <= 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    src_ptr += 2;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      int16x4_t d0 =
          convolve4_4_2d_h(s0, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d1 =
          convolve4_4_2d_h(s1, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d2 =
          convolve4_4_2d_h(s2, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d3 =
          convolve4_4_2d_h(s3, x_filter, correction, range_limit, permute_tbl);

      store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    do {
      uint8x16_t s0 = vld1q_u8(src_ptr);
      int16x4_t d0 =
          convolve4_4_2d_h(s0, x_filter, correction, range_limit, permute_tbl);
      vst1_s16(dst_ptr, d0);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d1 = convolve8_8_2d_h(s1, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d2 = convolve8_8_2d_h(s2, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d3 = convolve8_8_2d_h(s3, x_filter, correction, range_limit,
                                        permute_tbl);

        store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0 = vld1q_u8(s);
        int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, correction, range_limit,
                                        permute_tbl);
        vst1q_s16(d, d0);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  }
}

void av1_convolve_2d_sr_neon_dotprod(const uint8_t *src, int src_stride,
                                     uint8_t *dst, int dst_stride, int w, int h,
                                     const InterpFilterParams *filter_params_x,
                                     const InterpFilterParams *filter_params_y,
                                     const int subpel_x_qn,
                                     const int subpel_y_qn,
                                     ConvolveParams *conv_params) {
  if (w == 2 || h == 2) {
    av1_convolve_2d_sr_c(src, src_stride, dst, dst_stride, w, h,
                         filter_params_x, filter_params_y, subpel_x_qn,
                         subpel_y_qn, conv_params);
    return;
  }

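  // The horizontal pass below writes an intermediate block im_h rows tall,
  // i.e. (taps - 1) rows more than the output, so that the vertical pass has
  // enough context; src_ptr is stepped back by the corresponding vertical and
  // horizontal filter offsets.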
  const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
  const int clamped_y_taps = y_filter_taps < 6 ? 6 : y_filter_taps;
  const int im_h = h + clamped_y_taps - 1;
  const int im_stride = MAX_SB_SIZE;
  const int vert_offset = clamped_y_taps / 2 - 1;
  const int horiz_offset = filter_params_x->taps / 2 - 1;
  const uint8_t *src_ptr = src - vert_offset * src_stride - horiz_offset;

  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_y, subpel_y_qn & SUBPEL_MASK);

  if (filter_params_x->taps > 8) {
    DECLARE_ALIGNED(16, int16_t,
                    im_block[(MAX_SB_SIZE + MAX_FILTER_TAP - 1) * MAX_SB_SIZE]);

    const int16x8_t x_filter_0_7 = vld1q_s16(x_filter_ptr);
    const int16x4_t x_filter_8_11 = vld1_s16(x_filter_ptr + 8);
    const int16x8_t y_filter_0_7 = vld1q_s16(y_filter_ptr);
    const int16x4_t y_filter_8_11 = vld1_s16(y_filter_ptr + 8);

    convolve_2d_sr_horiz_12tap_neon_dotprod(src_ptr, src_stride, im_block,
                                            im_stride, w, im_h, x_filter_0_7,
                                            x_filter_8_11);

    convolve_2d_sr_vert_12tap_neon(im_block, im_stride, dst, dst_stride, w, h,
                                   y_filter_0_7, y_filter_8_11);
  } else {
    DECLARE_ALIGNED(16, int16_t,
                    im_block[(MAX_SB_SIZE + SUBPEL_TAPS - 1) * MAX_SB_SIZE]);

    convolve_2d_sr_horiz_neon_dotprod(src_ptr, src_stride, im_block, im_stride,
                                      w, im_h, x_filter_ptr);

    const int16x8_t y_filter = vld1q_s16(y_filter_ptr);

    if (clamped_y_taps <= 6) {
      convolve_2d_sr_vert_6tap_neon(im_block, im_stride, dst, dst_stride, w, h,
                                    y_filter);
    } else {
      convolve_2d_sr_vert_8tap_neon(im_block, im_stride, dst, dst_stride, w, h,
                                    y_filter);
    }
  }
}