/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "aom_dsp/arm/mem_neon.h"
#include "av1/common/arm/compound_convolve_neon.h"
#include "config/aom_config.h"
#include "config/av1_rtcd.h"
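
// Each 16-byte row of the permute table below gathers four overlapping
// four-sample windows (bytes 0..3, 1..4, 2..5, 3..6, and so on), so a single
// SDOT instruction can compute four adjacent convolution outputs from one
// source vector.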
DECLARE_ALIGNED(16, static const uint8_t, dot_prod_permute_tbl[48]) = {
  0, 1, 2,  3,  1, 2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6,
  4, 5, 6,  7,  5, 6,  7,  8,  6,  7,  8,  9,  7,  8,  9,  10,
  8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
};

static INLINE int16x4_t convolve4_4_2d_h(uint8x16_t samples,
                                         const int8x8_t x_filter,
                                         const int32x4_t correction,
                                         const uint8x16_t range_limit,
                                         const uint8x16_t permute_tbl) {
  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  int8x16_t clamped_samples =
      vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  int8x16_t permuted_samples = vqtbl1q_s8(clamped_samples, permute_tbl);

  // Accumulate dot product into 'correction' to account for range clamp.
  int32x4_t sum = vdotq_lane_s32(correction, permuted_samples, x_filter, 0);

  // We halved the convolution filter values, so subtract 1 from the right
  // shift.
  return vshrn_n_s32(sum, ROUND0_BITS - 1);
}

static INLINE int16x8_t convolve8_8_2d_h(uint8x16_t samples,
                                         const int8x8_t x_filter,
                                         const int32x4_t correction,
                                         const uint8x16_t range_limit,
                                         const uint8x16x3_t permute_tbl) {
  int8x16_t clamped_samples, permuted_samples[3];
  int32x4_t sum[2];

  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);

  // Accumulate dot product into 'correction' to account for range clamp.
  // First 4 output values.
  sum[0] = vdotq_lane_s32(correction, permuted_samples[0], x_filter, 0);
  sum[0] = vdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
  // Second 4 output values.
  sum[1] = vdotq_lane_s32(correction, permuted_samples[1], x_filter, 0);
  sum[1] = vdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);

  // Narrow and re-pack.
  // We halved the convolution filter values, so subtract 1 from the right
  // shift.
  return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                      vshrn_n_s32(sum[1], ROUND0_BITS - 1));
}

static INLINE void dist_wtd_convolve_2d_horiz_neon_dotprod(
    const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride,
    const int16_t *x_filter_ptr, const int im_h, int w) {
  const int bd = 8;
  const int32_t horiz_const = (1 << (bd + FILTER_BITS - 2));
  // Dot-product constants and other shims.
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);
  const int32_t correction_s32 =
      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
  // Fold horiz_const into the dot-product filter correction constant. The
  // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-
  // rounding shifts - which are generally faster than rounding shifts on
  // modern CPUs. (The extra -1 is needed because we halved the filter values.)
  const int32x4_t correction = vdupq_n_s32(correction_s32 + horiz_const +
                                           (1 << ((ROUND0_BITS - 1) - 1)));
  const uint8x16_t range_limit = vdupq_n_u8(128);
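
  // Why 'correction' recovers the unsigned result: every sample is biased by
  // -128 before the signed dot product, so for the halved filter taps f[i]:
  //   sum((s[i] - 128) * f[i]) + 128 * sum(f[i]) == sum(s[i] * f[i])
  // and 128 * sum(filter[i] / 2) == sum(filter[i] << (FILTER_BITS - 1)),
  // which is exactly correction_s32 computed above.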

  const uint8_t *src_ptr = src;
  int16_t *dst_ptr = im_block;
  int dst_stride = im_stride;
  int height = im_h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

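    // The 4-tap kernels are stored zero-padded to 8 taps, with the non-zero
    // values in positions 2-5; loading from x_filter_ptr + 2 selects them,
    // and advancing src_ptr by 2 keeps the source aligned with the kernel.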
    src_ptr += 2;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      int16x4_t d0 =
          convolve4_4_2d_h(s0, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d1 =
          convolve4_4_2d_h(s1, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d2 =
          convolve4_4_2d_h(s2, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d3 =
          convolve4_4_2d_h(s3, x_filter, correction, range_limit, permute_tbl);

      store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

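    // The intermediate height im_h (h plus vertical filter taps minus one) is
    // never a multiple of 4, so filter the remaining rows one at a time.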
    do {
      uint8x16_t s0 = vld1q_u8(src_ptr);

      int16x4_t d0 =
          convolve4_4_2d_h(s0, x_filter, correction, range_limit, permute_tbl);

      vst1_s16(dst_ptr, d0);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d1 = convolve8_8_2d_h(s1, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d2 = convolve8_8_2d_h(s2, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d3 = convolve8_8_2d_h(s3, x_filter, correction, range_limit,
                                        permute_tbl);

        store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0 = vld1q_u8(s);

        int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, correction, range_limit,
                                        permute_tbl);

        vst1q_s16(d, d0);

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  }
}

void av1_dist_wtd_convolve_2d_neon_dotprod(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x,
    const InterpFilterParams *filter_params_y, const int subpel_x_qn,
    const int subpel_y_qn, ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  DECLARE_ALIGNED(16, int16_t,
                  im_block[(MAX_SB_SIZE + SUBPEL_TAPS - 1) * MAX_SB_SIZE]);
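
  // The 2D convolution is computed in two passes: a horizontal pass writes
  // 16-bit intermediate values into im_block, then one of the vertical
  // helpers below filters im_block down to the final output precision.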

  const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
  // Only 6-tap and 8-tap vertical helpers are implemented below, so pad
  // shorter vertical filters out to 6 taps.
  const int clamped_y_taps = y_filter_taps < 6 ? 6 : y_filter_taps;

  const int im_h = h + clamped_y_taps - 1;
  const int im_stride = MAX_SB_SIZE;
  const int vert_offset = clamped_y_taps / 2 - 1;
  const int horiz_offset = filter_params_x->taps / 2 - 1;
  const uint8_t *src_ptr = src - vert_offset * src_stride - horiz_offset;
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_y, subpel_y_qn & SUBPEL_MASK);

  const int16x8_t y_filter = vld1q_s16(y_filter_ptr);

  dist_wtd_convolve_2d_horiz_neon_dotprod(src_ptr, src_stride, im_block,
                                          im_stride, x_filter_ptr, im_h, w);

  if (clamped_y_taps == 6) {
    if (conv_params->do_average) {
      if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
        dist_wtd_convolve_2d_vert_6tap_dist_wtd_avg_neon(
            im_block, im_stride, dst8, dst8_stride, conv_params, y_filter, h,
            w);
      } else {
        dist_wtd_convolve_2d_vert_6tap_avg_neon(im_block, im_stride, dst8,
                                                dst8_stride, conv_params,
                                                y_filter, h, w);
      }
    } else {
      dist_wtd_convolve_2d_vert_6tap_neon(im_block, im_stride, conv_params,
                                          y_filter, h, w);
    }
  } else {
    if (conv_params->do_average) {
      if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
        dist_wtd_convolve_2d_vert_8tap_dist_wtd_avg_neon(
            im_block, im_stride, dst8, dst8_stride, conv_params, y_filter, h,
            w);
      } else {
        dist_wtd_convolve_2d_vert_8tap_avg_neon(im_block, im_stride, dst8,
                                                dst8_stride, conv_params,
                                                y_filter, h, w);
      }
    } else {
      dist_wtd_convolve_2d_vert_8tap_neon(im_block, im_stride, conv_params,
                                          y_filter, h, w);
    }
  }
}

static INLINE uint16x4_t convolve4_4_x(uint8x16_t samples,
                                       const int8x8_t x_filter,
                                       const int32x4_t correction,
                                       const uint8x16_t range_limit,
                                       const uint8x16_t permute_tbl) {
  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  int8x16_t clamped_samples =
      vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  int8x16_t permuted_samples = vqtbl1q_s8(clamped_samples, permute_tbl);

  // Accumulate dot product into 'correction' to account for range clamp.
  int32x4_t sum = vdotq_lane_s32(correction, permuted_samples, x_filter, 0);

  // We halved the convolution filter values, so subtract 1 from the right
  // shift. The round_offset folded into 'correction' by the callers keeps the
  // result non-negative, so it can be reinterpreted as unsigned.
  return vreinterpret_u16_s16(vshrn_n_s32(sum, ROUND0_BITS - 1));
}

static INLINE uint16x8_t convolve8_8_x(uint8x16_t samples,
                                       const int8x8_t x_filter,
                                       const int32x4_t correction,
                                       const uint8x16_t range_limit,
                                       const uint8x16x3_t permute_tbl) {
  int8x16_t clamped_samples, permuted_samples[3];
  int32x4_t sum[2];

  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);

  // Accumulate dot product into 'correction' to account for range clamp.
  // First 4 output values.
  sum[0] = vdotq_lane_s32(correction, permuted_samples[0], x_filter, 0);
  sum[0] = vdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
  // Second 4 output values.
  sum[1] = vdotq_lane_s32(correction, permuted_samples[1], x_filter, 0);
  sum[1] = vdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);

  // Narrow and re-pack.
  // We halved the convolution filter values, so subtract 1 from the right
  // shift.
  int16x8_t res = vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                               vshrn_n_s32(sum[1], ROUND0_BITS - 1));
  return vreinterpretq_u16_s16(res);
}

static INLINE void dist_wtd_convolve_x_dist_wtd_avg_neon_dotprod(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
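
  // With libaom's standard constants (FILTER_BITS == 7, ROUND0_BITS == 3,
  // COMPOUND_ROUND1_BITS == 7): offset_bits == 19 and
  // round_offset == (1 << 12) + (1 << 11) == 6144.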

  const uint16_t fwd_offset = conv_params->fwd_offset;
  const uint16_t bck_offset = conv_params->bck_offset;

  // Horizontal filter.
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);

  // Dot-product constants and other shims.
  const uint8x16_t range_limit = vdupq_n_u8(128);
  const int32_t correction_s32 =
      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
  // Fold round_offset into the dot-product filter correction constant. The
  // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-
  // rounding shifts - which are generally faster than rounding shifts on
  // modern CPUs. (The extra -1 is needed because we halved the filter values.)
  int32x4_t correction =
      vdupq_n_s32(correction_s32 + (round_offset << (ROUND0_BITS - 1)) +
                  (1 << ((ROUND0_BITS - 1) - 1)));
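
  // round_offset is pre-scaled by 1 << (ROUND0_BITS - 1) because the
  // convolve helpers narrow with a right shift by ROUND0_BITS - 1:
  //   (sum + (round_offset << (R - 1))) >> (R - 1)
  //       == (sum >> (R - 1)) + round_offset
  // for R == ROUND0_BITS, so the compound offset survives the shift intact.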

  const int horiz_offset = filter_params_x->taps / 2 - 1;
  const uint8_t *src_ptr = src - horiz_offset;
  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  uint8_t *dst8_ptr = dst8;
  int dst_stride = conv_params->dst_stride;
  int height = h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    src_ptr += 2;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve4_4_x(s0, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d1 =
          convolve4_4_x(s1, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d2 =
          convolve4_4_x(s2, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d3 =
          convolve4_4_x(s3, x_filter, correction, range_limit, permute_tbl);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_dist_wtd_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                               bck_offset, round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8_ptr + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8_ptr + 2 * dst8_stride, dst8_stride, d23_u8);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      dst8_ptr += 4 * dst8_stride;
      height -= 4;
    } while (height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    do {
      const uint8_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      uint8_t *d_u8 = dst8_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve8_8_x(s0, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d1 =
            convolve8_8_x(s1, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d2 =
            convolve8_8_x(s2, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d3 =
            convolve8_8_x(s3, x_filter, correction, range_limit, permute_tbl);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_dist_wtd_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                                 bck_offset, round_offset_vec, &d0_u8, &d1_u8,
                                 &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

        s += 8;
        d += 8;
        d_u8 += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      dst8_ptr += 4 * dst8_stride;
      height -= 4;
    } while (height != 0);
  }
}

static INLINE void dist_wtd_convolve_x_avg_neon_dotprod(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);

  // Horizontal filter.
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);

  // Dot-product constants and other shims.
  const uint8x16_t range_limit = vdupq_n_u8(128);
  const int32_t correction_s32 =
      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
  // Fold round_offset into the dot-product filter correction constant. The
  // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-
  // rounding shifts - which are generally faster than rounding shifts on
  // modern CPUs. (The extra -1 is needed because we halved the filter values.)
  int32x4_t correction =
      vdupq_n_s32(correction_s32 + (round_offset << (ROUND0_BITS - 1)) +
                  (1 << ((ROUND0_BITS - 1) - 1)));

  const int horiz_offset = filter_params_x->taps / 2 - 1;
  const uint8_t *src_ptr = src - horiz_offset;
  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  uint8_t *dst8_ptr = dst8;
  int dst_stride = conv_params->dst_stride;
  int height = h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    src_ptr += 2;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve4_4_x(s0, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d1 =
          convolve4_4_x(s1, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d2 =
          convolve4_4_x(s2, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d3 =
          convolve4_4_x(s3, x_filter, correction, range_limit, permute_tbl);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_basic_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                            round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8_ptr + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8_ptr + 2 * dst8_stride, dst8_stride, d23_u8);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      dst8_ptr += 4 * dst8_stride;
      height -= 4;
    } while (height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    do {
      const uint8_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      uint8_t *d_u8 = dst8_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve8_8_x(s0, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d1 =
            convolve8_8_x(s1, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d2 =
            convolve8_8_x(s2, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d3 =
            convolve8_8_x(s3, x_filter, correction, range_limit, permute_tbl);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_basic_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                              round_offset_vec, &d0_u8, &d1_u8, &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

        s += 8;
        d += 8;
        d_u8 += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      dst8_ptr += 4 * dst8_stride;
      height -= 4;
    } while (height != 0);
  }
}

static INLINE void dist_wtd_convolve_x_neon_dotprod(
    const uint8_t *src, int src_stride, int w, int h,
    const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));

  // Horizontal filter.
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);

  // Dot-product constants and other shims.
  const uint8x16_t range_limit = vdupq_n_u8(128);
  const int32_t correction_s32 =
      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
  // Fold round_offset into the dot-product filter correction constant. The
  // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-
  // rounding shifts - which are generally faster than rounding shifts on
  // modern CPUs. (The extra -1 is needed because we halved the filter values.)
  int32x4_t correction =
      vdupq_n_s32(correction_s32 + (round_offset << (ROUND0_BITS - 1)) +
                  (1 << ((ROUND0_BITS - 1) - 1)));

  const int horiz_offset = filter_params_x->taps / 2 - 1;
  const uint8_t *src_ptr = src - horiz_offset;
  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  int height = h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    src_ptr += 2;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve4_4_x(s0, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d1 =
          convolve4_4_x(s1, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d2 =
          convolve4_4_x(s2, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d3 =
          convolve4_4_x(s3, x_filter, correction, range_limit, permute_tbl);

      store_u16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    do {
      const uint8_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve8_8_x(s0, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d1 =
            convolve8_8_x(s1, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d2 =
            convolve8_8_x(s2, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d3 =
            convolve8_8_x(s3, x_filter, correction, range_limit, permute_tbl);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height != 0);
  }
}

void av1_dist_wtd_convolve_x_neon_dotprod(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  if (conv_params->do_average) {
    if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
      dist_wtd_convolve_x_dist_wtd_avg_neon_dotprod(
          src, src_stride, dst8, dst8_stride, w, h, filter_params_x,
          subpel_x_qn, conv_params);
    } else {
      dist_wtd_convolve_x_avg_neon_dotprod(src, src_stride, dst8, dst8_stride,
                                           w, h, filter_params_x, subpel_x_qn,
                                           conv_params);
    }
  } else {
    dist_wtd_convolve_x_neon_dotprod(src, src_stride, w, h, filter_params_x,
                                     subpel_x_qn, conv_params);
  }
}