/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_ports/mem.h"
#include "av1/common/arm/convolve_neon.h"
#include "av1/common/convolve.h"
#include "av1/common/filter.h"

DECLARE_ALIGNED(16, static const uint8_t, dot_prod_permute_tbl[48]) = {
  0, 1, 2,  3,  1, 2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6,
  4, 5, 6,  7,  5, 6,  7,  8,  6,  7,  8,  9,  7,  8,  9,  10,
  8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
};
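
// Illustrative note: each 16-byte row of dot_prod_permute_tbl gathers four
// overlapping 4-byte windows of the source (row 0 selects bytes {0..3},
// {1..4}, {2..5}, {3..6}), so a single vusdotq_laneq_s32 against one 4-tap
// slice of the filter accumulates partial sums for four adjacent output
// pixels at once. A rough scalar sketch of what one such dot-product step
// computes (for explanation only; hypothetical variable names, not part of
// the kernels below):
//
//   for (int i = 0; i < 4; ++i) {    // four adjacent output pixels
//     for (int k = 0; k < 4; ++k) {  // four taps per dot-product lane
//       sum[i] += (int32_t)src[i + k] * (int32_t)filter[k];
//     }
//   }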

static INLINE int16x4_t convolve12_4_x(uint8x16_t samples,
                                       const int8x16_t filter,
                                       const uint8x16x3_t permute_tbl,
                                       const int32x4_t horiz_const) {
  uint8x16_t permuted_samples[3];
  int32x4_t sum;

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);

  // First 4 output values.
  sum = vusdotq_laneq_s32(horiz_const, permuted_samples[0], filter, 0);
  sum = vusdotq_laneq_s32(sum, permuted_samples[1], filter, 1);
  sum = vusdotq_laneq_s32(sum, permuted_samples[2], filter, 2);

  return vqrshrn_n_s32(sum, FILTER_BITS);
}

static INLINE uint8x8_t convolve12_8_x(uint8x16_t samples[2],
                                       const int8x16_t filter,
                                       const uint8x16x3_t permute_tbl,
                                       const int32x4_t horiz_const) {
  uint8x16_t permuted_samples[4];
  int32x4_t sum[2];

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_u8(samples[0], permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_u8(samples[0], permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_u8(samples[0], permute_tbl.val[2]);
  // {12, 13, 14, 15, 13, 14, 15, 16, 14, 15, 16, 17, 15, 16, 17, 18 }
  permuted_samples[3] = vqtbl1q_u8(samples[1], permute_tbl.val[2]);

  // First 4 output values.
  sum[0] = vusdotq_laneq_s32(horiz_const, permuted_samples[0], filter, 0);
  sum[0] = vusdotq_laneq_s32(sum[0], permuted_samples[1], filter, 1);
  sum[0] = vusdotq_laneq_s32(sum[0], permuted_samples[2], filter, 2);
  // Second 4 output values.
  sum[1] = vusdotq_laneq_s32(horiz_const, permuted_samples[1], filter, 0);
  sum[1] = vusdotq_laneq_s32(sum[1], permuted_samples[2], filter, 1);
  sum[1] = vusdotq_laneq_s32(sum[1], permuted_samples[3], filter, 2);

  // Narrow and re-pack.
  int16x8_t sum_s16 = vcombine_s16(vqrshrn_n_s32(sum[0], FILTER_BITS),
                                   vqrshrn_n_s32(sum[1], FILTER_BITS));
  return vqmovun_s16(sum_s16);
}

static INLINE void convolve_x_sr_12tap_neon_i8mm(const uint8_t *src,
                                                 int src_stride, uint8_t *dst,
                                                 int dst_stride, int w, int h,
                                                 const int16_t *x_filter_ptr) {
  const int16x8_t filter_0_7 = vld1q_s16(x_filter_ptr);
  const int16x4_t filter_8_11 = vld1_s16(x_filter_ptr + 8);
  const int16x8_t filter_8_15 = vcombine_s16(filter_8_11, vdup_n_s16(0));
  const int8x16_t filter =
      vcombine_s8(vmovn_s16(filter_0_7), vmovn_s16(filter_8_15));

  // Special case the following no-op filter as 128 won't fit into the
  // 8-bit signed dot-product instruction:
  // { 0, 0, 0, 0, 0, 128, 0, 0, 0, 0, 0, 0 }
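  // (With FILTER_BITS == 7 the taps sum to 128, so this kernel is the
  // identity filter for the integer-pel position: the output is simply a
  // copy of the source pixel under tap 5, which is what the branch below
  // implements with plain loads and stores.)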
  if (vgetq_lane_s16(filter_0_7, 5) == 128) {
    // Undo the horizontal offset in the calling function.
    src += 5;

    do {
      const uint8_t *s = src;
      uint8_t *d = dst;
      int width = w;

      do {
        uint8x8_t d0 = vld1_u8(s);
        if (w == 4) {
          store_u8_4x1(d, d0);
        } else {
          vst1_u8(d, d0);
        }

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src += src_stride;
      dst += dst_stride;
    } while (--h != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // This shim of 1 << (ROUND0_BITS - 1) enables us to use a single rounding
    // right shift by FILTER_BITS - instead of a first rounding right shift by
    // ROUND0_BITS, followed by a second rounding right shift by FILTER_BITS -
    // ROUND0_BITS.
    const int32x4_t horiz_const = vdupq_n_s32(1 << (ROUND0_BITS - 1));
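    // Illustrative note (assuming the usual ROUND0_BITS == 3 and
    // FILTER_BITS == 7): the two-stage approach described above is
    //   tmp = ROUND_POWER_OF_TWO(sum, 3);   // add 4, then >> 3
    //   res = ROUND_POWER_OF_TWO(tmp, 4);   // add 8, then >> 4
    // and by the identity ((x + 4) >> 3 + 8) >> 4 == (x + 4 + 64) >> 7 it
    // collapses to a single rounding shift by FILTER_BITS once the extra
    // "+ 4" (the shim above) is folded into the dot-product accumulator,
    // which is exactly what vqrshrn_n_s32(sum, FILTER_BITS) then applies.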

    if (w <= 4) {
      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

        int16x4_t d0 = convolve12_4_x(s0, filter, permute_tbl, horiz_const);
        int16x4_t d1 = convolve12_4_x(s1, filter, permute_tbl, horiz_const);
        int16x4_t d2 = convolve12_4_x(s2, filter, permute_tbl, horiz_const);
        int16x4_t d3 = convolve12_4_x(s3, filter, permute_tbl, horiz_const);

        uint8x8_t d01 = vqmovun_s16(vcombine_s16(d0, d1));
        uint8x8_t d23 = vqmovun_s16(vcombine_s16(d2, d3));

        store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
        store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);

        dst += 4 * dst_stride;
        src += 4 * src_stride;
        h -= 4;
      } while (h != 0);
    } else {
      do {
        const uint8_t *s = src;
        uint8_t *d = dst;
        int width = w;

        do {
          uint8x16_t s0[2], s1[2], s2[2], s3[2];
          load_u8_16x4(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
          load_u8_16x4(s + 4, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);

          uint8x8_t d0 = convolve12_8_x(s0, filter, permute_tbl, horiz_const);
          uint8x8_t d1 = convolve12_8_x(s1, filter, permute_tbl, horiz_const);
          uint8x8_t d2 = convolve12_8_x(s2, filter, permute_tbl, horiz_const);
          uint8x8_t d3 = convolve12_8_x(s3, filter, permute_tbl, horiz_const);

          store_u8_8x4(d + 0 * dst_stride, dst_stride, d0, d1, d2, d3);

          s += 8;
          d += 8;
          width -= 8;
        } while (width != 0);
        src += 4 * src_stride;
        dst += 4 * dst_stride;
        h -= 4;
      } while (h != 0);
    }
  }
}

static INLINE int16x4_t convolve4_4_x(uint8x16_t samples, const int8x8_t filter,
                                      const uint8x16_t permute_tbl,
                                      const int32x4_t horiz_const) {
  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  uint8x16_t permuted_samples = vqtbl1q_u8(samples, permute_tbl);

  // First 4 output values.
  int32x4_t sum = vusdotq_lane_s32(horiz_const, permuted_samples, filter, 0);

  // Packing is performed by the caller.
  return vmovn_s32(sum);
}

static INLINE uint8x8_t convolve8_8_x(uint8x16_t samples, const int8x8_t filter,
                                      const uint8x16x3_t permute_tbl,
                                      const int32x4_t horiz_const) {
  uint8x16_t permuted_samples[3];
  int32x4_t sum[2];

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);

  // First 4 output values.
  sum[0] = vusdotq_lane_s32(horiz_const, permuted_samples[0], filter, 0);
  sum[0] = vusdotq_lane_s32(sum[0], permuted_samples[1], filter, 1);
  // Second 4 output values.
  sum[1] = vusdotq_lane_s32(horiz_const, permuted_samples[1], filter, 0);
  sum[1] = vusdotq_lane_s32(sum[1], permuted_samples[2], filter, 1);

  int16x8_t sum_s16 = vcombine_s16(vmovn_s32(sum[0]), vmovn_s32(sum[1]));
  // We halved the convolution filter values so -1 from the right shift.
  return vqrshrun_n_s16(sum_s16, FILTER_BITS - 1);
}

void av1_convolve_x_sr_neon_i8mm(const uint8_t *src, int src_stride,
                                 uint8_t *dst, int dst_stride, int w, int h,
                                 const InterpFilterParams *filter_params_x,
                                 const int subpel_x_qn,
                                 ConvolveParams *conv_params) {
  if (w == 2 || h == 2) {
    av1_convolve_x_sr_c(src, src_stride, dst, dst_stride, w, h, filter_params_x,
                        subpel_x_qn, conv_params);
    return;
  }

  const uint8_t horiz_offset = filter_params_x->taps / 2 - 1;
  src -= horiz_offset;

  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);

  if (filter_params_x->taps > 8) {
    convolve_x_sr_12tap_neon_i8mm(src, src_stride, dst, dst_stride, w, h,
                                  x_filter_ptr);
    return;
  }

  // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use a single
  // rounding right shift by FILTER_BITS - instead of a first rounding right
  // shift by ROUND0_BITS, followed by a second rounding right shift by
  // FILTER_BITS - ROUND0_BITS.
  // The outermost -1 is needed because we will halve the filter values.
  const int32x4_t horiz_const = vdupq_n_s32(1 << ((ROUND0_BITS - 1) - 1));
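  // Illustrative note (assuming ROUND0_BITS == 3, FILTER_BITS == 7 and even
  // filter taps): because the taps are halved below, every intermediate sum
  // is exactly half of the full-precision value, so both the rounding shim
  // and the final shift drop by one bit. Concretely,
  //   (sum + 4 + 64) >> 7        // full-precision single-shift form
  // becomes
  //   (sum / 2 + 2 + 32) >> 6    // halved taps: shim 1 << 1, shift by 6
  // which is why convolve4_4_x / convolve8_8_x shift by FILTER_BITS - 1.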

  if (w <= 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    src += 2;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

      int16x4_t d0 = convolve4_4_x(s0, x_filter, permute_tbl, horiz_const);
      int16x4_t d1 = convolve4_4_x(s1, x_filter, permute_tbl, horiz_const);
      int16x4_t d2 = convolve4_4_x(s2, x_filter, permute_tbl, horiz_const);
      int16x4_t d3 = convolve4_4_x(s3, x_filter, permute_tbl, horiz_const);

      // We halved the convolution filter values so -1 from the right shift.
      uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS - 1);
      uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS - 1);

      store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
      store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);

  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

    do {
      const uint8_t *s = src;
      uint8_t *d = dst;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint8x8_t d0 = convolve8_8_x(s0, x_filter, permute_tbl, horiz_const);
        uint8x8_t d1 = convolve8_8_x(s1, x_filter, permute_tbl, horiz_const);
        uint8x8_t d2 = convolve8_8_x(s2, x_filter, permute_tbl, horiz_const);
        uint8x8_t d3 = convolve8_8_x(s3, x_filter, permute_tbl, horiz_const);

        store_u8_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);
  }
}

static INLINE int16x4_t convolve12_4_2d_h(uint8x16_t samples,
                                          const int8x16_t filters,
                                          const uint8x16x3_t permute_tbl,
                                          int32x4_t horiz_const) {
  uint8x16_t permuted_samples[3];
  int32x4_t sum;

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);

  // First 4 output values.
  sum = vusdotq_laneq_s32(horiz_const, permuted_samples[0], filters, 0);
  sum = vusdotq_laneq_s32(sum, permuted_samples[1], filters, 1);
  sum = vusdotq_laneq_s32(sum, permuted_samples[2], filters, 2);

  // Narrow and re-pack.
  return vshrn_n_s32(sum, ROUND0_BITS);
}

static INLINE int16x8_t convolve12_8_2d_h(uint8x16_t samples[2],
                                          const int8x16_t filters,
                                          const uint8x16x3_t permute_tbl,
                                          const int32x4_t horiz_const) {
  uint8x16_t permuted_samples[4];
  int32x4_t sum[2];

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_u8(samples[0], permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_u8(samples[0], permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_u8(samples[0], permute_tbl.val[2]);
  // {12, 13, 14, 15, 13, 14, 15, 16, 14, 15, 16, 17, 15, 16, 17, 18 }
  permuted_samples[3] = vqtbl1q_u8(samples[1], permute_tbl.val[2]);

  // First 4 output values.
  sum[0] = vusdotq_laneq_s32(horiz_const, permuted_samples[0], filters, 0);
  sum[0] = vusdotq_laneq_s32(sum[0], permuted_samples[1], filters, 1);
  sum[0] = vusdotq_laneq_s32(sum[0], permuted_samples[2], filters, 2);
  // Second 4 output values.
  sum[1] = vusdotq_laneq_s32(horiz_const, permuted_samples[1], filters, 0);
  sum[1] = vusdotq_laneq_s32(sum[1], permuted_samples[2], filters, 1);
  sum[1] = vusdotq_laneq_s32(sum[1], permuted_samples[3], filters, 2);

  // Narrow and re-pack.
  return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS),
                      vshrn_n_s32(sum[1], ROUND0_BITS));
}

static INLINE void convolve_2d_sr_horiz_12tap_neon_i8mm(
    const uint8_t *src_ptr, int src_stride, int16_t *dst_ptr,
    const int dst_stride, int w, int h, const int16x8_t x_filter_0_7,
    const int16x4_t x_filter_8_11) {
  const int bd = 8;

  // Special case the following no-op filter as 128 won't fit into the
  // 8-bit signed dot-product instruction:
  // { 0, 0, 0, 0, 0, 128, 0, 0, 0, 0, 0, 0 }
  if (vgetq_lane_s16(x_filter_0_7, 5) == 128) {
    const uint16x8_t horiz_const = vdupq_n_u16((1 << (bd - 1)));
    // Undo the horizontal offset in the calling function.
    src_ptr += 5;

    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x8_t s0 = vld1_u8(s);
        uint16x8_t d0 = vaddw_u8(horiz_const, s0);
        d0 = vshlq_n_u16(d0, FILTER_BITS - ROUND0_BITS);
        // Store 8 elements to avoid additional branches. This is safe if the
        // actual block width is < 8 because the intermediate buffer is large
        // enough to accommodate 128x128 blocks.
        vst1q_s16(d, vreinterpretq_s16_u16(d0));

        d += 8;
        s += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);

  } else {
    // Narrow filter values to 8-bit.
    const int16x8x2_t x_filter_s16 = {
      { x_filter_0_7, vcombine_s16(x_filter_8_11, vdup_n_s16(0)) }
    };
    const int8x16_t x_filter = vcombine_s8(vmovn_s16(x_filter_s16.val[0]),
                                           vmovn_s16(x_filter_s16.val[1]));
    // This shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts
    // - which are generally faster than rounding shifts on modern CPUs.
    const int32x4_t horiz_const =
        vdupq_n_s32((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));
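    // Illustrative note: folding 1 << (ROUND0_BITS - 1) into the accumulator
    // makes the truncating vshrn_n_s32(sum, ROUND0_BITS) in the helpers above
    // equivalent to ROUND_POWER_OF_TWO(sum, ROUND0_BITS). The additional
    // 1 << (bd + FILTER_BITS - 1) term is the intermediate offset used by the
    // 2D path so that the horizontal results stay non-negative and fit in the
    // int16_t im_block; the vertical pass is expected to account for it.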
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);

    if (w <= 4) {
      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

        int16x4_t d0 =
            convolve12_4_2d_h(s0, x_filter, permute_tbl, horiz_const);
        int16x4_t d1 =
            convolve12_4_2d_h(s1, x_filter, permute_tbl, horiz_const);
        int16x4_t d2 =
            convolve12_4_2d_h(s2, x_filter, permute_tbl, horiz_const);
        int16x4_t d3 =
            convolve12_4_2d_h(s3, x_filter, permute_tbl, horiz_const);

        store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

        src_ptr += 4 * src_stride;
        dst_ptr += 4 * dst_stride;
        h -= 4;
      } while (h > 4);

      do {
        uint8x16_t s0 = vld1q_u8(src_ptr);
        int16x4_t d0 =
            convolve12_4_2d_h(s0, x_filter, permute_tbl, horiz_const);
        vst1_s16(dst_ptr, d0);

        src_ptr += src_stride;
        dst_ptr += dst_stride;
      } while (--h != 0);

    } else {
      do {
        const uint8_t *s = src_ptr;
        int16_t *d = dst_ptr;
        int width = w;

        do {
          uint8x16_t s0[2], s1[2], s2[2], s3[2];
          load_u8_16x4(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
          load_u8_16x4(s + 4, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);

          int16x8_t d0 =
              convolve12_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
          int16x8_t d1 =
              convolve12_8_2d_h(s1, x_filter, permute_tbl, horiz_const);
          int16x8_t d2 =
              convolve12_8_2d_h(s2, x_filter, permute_tbl, horiz_const);
          int16x8_t d3 =
              convolve12_8_2d_h(s3, x_filter, permute_tbl, horiz_const);

          store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

          s += 8;
          d += 8;
          width -= 8;
        } while (width != 0);

        src_ptr += 4 * src_stride;
        dst_ptr += 4 * dst_stride;
        h -= 4;
      } while (h > 4);

      do {
        const uint8_t *s = src_ptr;
        int16_t *d = dst_ptr;
        int width = w;

        do {
          uint8x16_t s0[2];
          s0[0] = vld1q_u8(s);
          s0[1] = vld1q_u8(s + 4);
          int16x8_t d0 =
              convolve12_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
          vst1q_s16(d, d0);

          s += 8;
          d += 8;
          width -= 8;
        } while (width != 0);
        src_ptr += src_stride;
        dst_ptr += dst_stride;
      } while (--h != 0);
    }
  }
}

static INLINE int16x4_t convolve4_4_2d_h(uint8x16_t samples,
                                         const int8x8_t filters,
                                         const uint8x16_t permute_tbl,
                                         const int32x4_t horiz_const) {
  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  uint8x16_t permuted_samples = vqtbl1q_u8(samples, permute_tbl);

  // First 4 output values.
  int32x4_t sum = vusdotq_lane_s32(horiz_const, permuted_samples, filters, 0);

  // We halved the convolution filter values so -1 from the right shift.
  return vshrn_n_s32(sum, ROUND0_BITS - 1);
}

static INLINE int16x8_t convolve8_8_2d_h(uint8x16_t samples,
                                         const int8x8_t filters,
                                         const uint8x16x3_t permute_tbl,
                                         const int32x4_t horiz_const) {
  uint8x16_t permuted_samples[3];
  int32x4_t sum[2];

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);

  // First 4 output values.
  sum[0] = vusdotq_lane_s32(horiz_const, permuted_samples[0], filters, 0);
  sum[0] = vusdotq_lane_s32(sum[0], permuted_samples[1], filters, 1);
  // Second 4 output values.
  sum[1] = vusdotq_lane_s32(horiz_const, permuted_samples[1], filters, 0);
  sum[1] = vusdotq_lane_s32(sum[1], permuted_samples[2], filters, 1);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                      vshrn_n_s32(sum[1], ROUND0_BITS - 1));
}

static INLINE void convolve_2d_sr_horiz_neon_i8mm(
    const uint8_t *src, int src_stride, int16_t *im_block, int im_stride, int w,
    int im_h, const int16_t *x_filter_ptr) {
  const int bd = 8;
  // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // The outermost -1 is needed because we halved the filter values.
  const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) +
                                            (1 << ((ROUND0_BITS - 1) - 1)));
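  // Illustrative note (assuming bd == 8, FILTER_BITS == 7, ROUND0_BITS == 3):
  // with full-precision taps the constant would be (1 << 14) + (1 << 2) and
  // the helpers would shift right by 3; since the even taps are halved below,
  // every sum is exactly half of that, so the constant becomes
  // (1 << 13) + (1 << 1) and convolve4_4_2d_h / convolve8_8_2d_h shift by
  // ROUND0_BITS - 1 instead, giving bit-identical intermediate results.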

  const uint8_t *src_ptr = src;
  int16_t *dst_ptr = im_block;
  int dst_stride = im_stride;
  int height = im_h;

  if (w <= 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    src_ptr += 2;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      int16x4_t d0 = convolve4_4_2d_h(s0, x_filter, permute_tbl, horiz_const);
      int16x4_t d1 = convolve4_4_2d_h(s1, x_filter, permute_tbl, horiz_const);
      int16x4_t d2 = convolve4_4_2d_h(s2, x_filter, permute_tbl, horiz_const);
      int16x4_t d3 = convolve4_4_2d_h(s3, x_filter, permute_tbl, horiz_const);

      store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    do {
      uint8x16_t s0 = vld1q_u8(src_ptr);
      int16x4_t d0 = convolve4_4_2d_h(s0, x_filter, permute_tbl, horiz_const);
      vst1_s16(dst_ptr, d0);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
        int16x8_t d1 = convolve8_8_2d_h(s1, x_filter, permute_tbl, horiz_const);
        int16x8_t d2 = convolve8_8_2d_h(s2, x_filter, permute_tbl, horiz_const);
        int16x8_t d3 = convolve8_8_2d_h(s3, x_filter, permute_tbl, horiz_const);

        store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0 = vld1q_u8(s);
        int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
        vst1q_s16(d, d0);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  }
}

void av1_convolve_2d_sr_neon_i8mm(const uint8_t *src, int src_stride,
                                  uint8_t *dst, int dst_stride, int w, int h,
                                  const InterpFilterParams *filter_params_x,
                                  const InterpFilterParams *filter_params_y,
                                  const int subpel_x_qn, const int subpel_y_qn,
                                  ConvolveParams *conv_params) {
  if (w == 2 || h == 2) {
    av1_convolve_2d_sr_c(src, src_stride, dst, dst_stride, w, h,
                         filter_params_x, filter_params_y, subpel_x_qn,
                         subpel_y_qn, conv_params);
    return;
  }

  const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
  const int clamped_y_taps = y_filter_taps < 6 ? 6 : y_filter_taps;
  const int im_h = h + clamped_y_taps - 1;
  const int im_stride = MAX_SB_SIZE;
  const int vert_offset = clamped_y_taps / 2 - 1;
  const int horiz_offset = filter_params_x->taps / 2 - 1;
  const uint8_t *src_ptr = src - vert_offset * src_stride - horiz_offset;

  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_y, subpel_y_qn & SUBPEL_MASK);

  if (filter_params_x->taps > 8) {
    DECLARE_ALIGNED(16, int16_t,
                    im_block[(MAX_SB_SIZE + MAX_FILTER_TAP - 1) * MAX_SB_SIZE]);

    const int16x8_t x_filter_0_7 = vld1q_s16(x_filter_ptr);
    const int16x4_t x_filter_8_11 = vld1_s16(x_filter_ptr + 8);
    const int16x8_t y_filter_0_7 = vld1q_s16(y_filter_ptr);
    const int16x4_t y_filter_8_11 = vld1_s16(y_filter_ptr + 8);

    convolve_2d_sr_horiz_12tap_neon_i8mm(src_ptr, src_stride, im_block,
                                         im_stride, w, im_h, x_filter_0_7,
                                         x_filter_8_11);

    convolve_2d_sr_vert_12tap_neon(im_block, im_stride, dst, dst_stride, w, h,
                                   y_filter_0_7, y_filter_8_11);
  } else {
    DECLARE_ALIGNED(16, int16_t,
                    im_block[(MAX_SB_SIZE + SUBPEL_TAPS - 1) * MAX_SB_SIZE]);

    convolve_2d_sr_horiz_neon_i8mm(src_ptr, src_stride, im_block, im_stride, w,
                                   im_h, x_filter_ptr);

    const int16x8_t y_filter = vld1q_s16(y_filter_ptr);

    if (clamped_y_taps <= 6) {
      convolve_2d_sr_vert_6tap_neon(im_block, im_stride, dst, dst_stride, w, h,
                                    y_filter);
    } else {
      convolve_2d_sr_vert_8tap_neon(im_block, im_stride, dst, dst_stride, w, h,
                                    y_filter);
    }
  }
}