• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2024, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <assert.h>
13 #include <arm_neon.h>
14 
15 #include "config/aom_config.h"
16 #include "config/av1_rtcd.h"
17 
18 #include "aom_dsp/aom_dsp_common.h"
19 #include "aom_dsp/arm/aom_neon_sve_bridge.h"
20 #include "aom_dsp/arm/aom_neon_sve2_bridge.h"
21 #include "aom_dsp/arm/mem_neon.h"
22 #include "aom_ports/mem.h"
23 #include "av1/common/convolve.h"
24 #include "av1/common/filter.h"
25 #include "av1/common/filter.h"
26 #include "av1/common/arm/highbd_compound_convolve_neon.h"
27 #include "av1/common/arm/highbd_convolve_neon.h"
28 
29 DECLARE_ALIGNED(16, static const uint16_t, kDotProdTbl[32]) = {
30   0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6,
31   4, 5, 6, 7, 5, 6, 7, 0, 6, 7, 0, 1, 7, 0, 1, 2,
32 };
33 
convolve8_8_x(int16x8_t s0[8],int16x8_t filter,int64x2_t offset,int32x4_t shift)34 static INLINE uint16x8_t convolve8_8_x(int16x8_t s0[8], int16x8_t filter,
35                                        int64x2_t offset, int32x4_t shift) {
36   int64x2_t sum[8];
37   sum[0] = aom_sdotq_s16(offset, s0[0], filter);
38   sum[1] = aom_sdotq_s16(offset, s0[1], filter);
39   sum[2] = aom_sdotq_s16(offset, s0[2], filter);
40   sum[3] = aom_sdotq_s16(offset, s0[3], filter);
41   sum[4] = aom_sdotq_s16(offset, s0[4], filter);
42   sum[5] = aom_sdotq_s16(offset, s0[5], filter);
43   sum[6] = aom_sdotq_s16(offset, s0[6], filter);
44   sum[7] = aom_sdotq_s16(offset, s0[7], filter);
45 
46   sum[0] = vpaddq_s64(sum[0], sum[1]);
47   sum[2] = vpaddq_s64(sum[2], sum[3]);
48   sum[4] = vpaddq_s64(sum[4], sum[5]);
49   sum[6] = vpaddq_s64(sum[6], sum[7]);
50 
51   int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum[0]), vmovn_s64(sum[2]));
52   int32x4_t sum4567 = vcombine_s32(vmovn_s64(sum[4]), vmovn_s64(sum[6]));
53 
54   sum0123 = vshlq_s32(sum0123, shift);
55   sum4567 = vshlq_s32(sum4567, shift);
56 
57   return vcombine_u16(vqmovun_s32(sum0123), vqmovun_s32(sum4567));
58 }
59 
highbd_dist_wtd_convolve_x_sve2(const uint16_t * src,int src_stride,uint16_t * dst,int dst_stride,int width,int height,const int16_t * x_filter_ptr,ConvolveParams * conv_params,const int offset)60 static INLINE void highbd_dist_wtd_convolve_x_sve2(
61     const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
62     int width, int height, const int16_t *x_filter_ptr,
63     ConvolveParams *conv_params, const int offset) {
64   const int32x4_t shift = vdupq_n_s32(-conv_params->round_0);
65   const int64x2_t offset_vec = vdupq_n_s64(offset);
66 
67   const int64x2_t offset_lo =
68       vcombine_s64(vget_low_s64(offset_vec), vdup_n_s64(0));
69   const int16x8_t filter = vld1q_s16(x_filter_ptr);
70   do {
71     const int16_t *s = (const int16_t *)src;
72     uint16_t *d = dst;
73     int w = width;
74 
75     do {
76       int16x8_t s0[8], s1[8], s2[8], s3[8];
77       load_s16_8x8(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3],
78                    &s0[4], &s0[5], &s0[6], &s0[7]);
79       load_s16_8x8(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3],
80                    &s1[4], &s1[5], &s1[6], &s1[7]);
81       load_s16_8x8(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3],
82                    &s2[4], &s2[5], &s2[6], &s2[7]);
83       load_s16_8x8(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3],
84                    &s3[4], &s3[5], &s3[6], &s3[7]);
85 
86       uint16x8_t d0 = convolve8_8_x(s0, filter, offset_lo, shift);
87       uint16x8_t d1 = convolve8_8_x(s1, filter, offset_lo, shift);
88       uint16x8_t d2 = convolve8_8_x(s2, filter, offset_lo, shift);
89       uint16x8_t d3 = convolve8_8_x(s3, filter, offset_lo, shift);
90 
91       store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
92 
93       s += 8;
94       d += 8;
95       w -= 8;
96     } while (w != 0);
97     src += 4 * src_stride;
98     dst += 4 * dst_stride;
99     height -= 4;
100   } while (height != 0);
101 }
102 
convolve4_4_x(int16x8_t s0,int16x8_t filter,int64x2_t offset,int32x4_t shift,uint16x8x2_t permute_tbl)103 static INLINE uint16x4_t convolve4_4_x(int16x8_t s0, int16x8_t filter,
104                                        int64x2_t offset, int32x4_t shift,
105                                        uint16x8x2_t permute_tbl) {
106   int16x8_t permuted_samples0 = aom_tbl_s16(s0, permute_tbl.val[0]);
107   int16x8_t permuted_samples1 = aom_tbl_s16(s0, permute_tbl.val[1]);
108 
109   int64x2_t sum01 = aom_svdot_lane_s16(offset, permuted_samples0, filter, 0);
110   int64x2_t sum23 = aom_svdot_lane_s16(offset, permuted_samples1, filter, 0);
111 
112   int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
113   sum0123 = vshlq_s32(sum0123, shift);
114 
115   return vqmovun_s32(sum0123);
116 }
117 
convolve4_8_x(int16x8_t s0[4],int16x8_t filter,int64x2_t offset,int32x4_t shift,uint16x8_t tbl)118 static INLINE uint16x8_t convolve4_8_x(int16x8_t s0[4], int16x8_t filter,
119                                        int64x2_t offset, int32x4_t shift,
120                                        uint16x8_t tbl) {
121   int64x2_t sum04 = aom_svdot_lane_s16(offset, s0[0], filter, 0);
122   int64x2_t sum15 = aom_svdot_lane_s16(offset, s0[1], filter, 0);
123   int64x2_t sum26 = aom_svdot_lane_s16(offset, s0[2], filter, 0);
124   int64x2_t sum37 = aom_svdot_lane_s16(offset, s0[3], filter, 0);
125 
126   int32x4_t sum0415 = vcombine_s32(vmovn_s64(sum04), vmovn_s64(sum15));
127   sum0415 = vshlq_s32(sum0415, shift);
128 
129   int32x4_t sum2637 = vcombine_s32(vmovn_s64(sum26), vmovn_s64(sum37));
130   sum2637 = vshlq_s32(sum2637, shift);
131 
132   uint16x8_t res = vcombine_u16(vqmovun_s32(sum0415), vqmovun_s32(sum2637));
133   return aom_tbl_u16(res, tbl);
134 }
135 
136 // clang-format off
137 DECLARE_ALIGNED(16, static const uint16_t, kDeinterleaveTbl[8]) = {
138   0, 2, 4, 6, 1, 3, 5, 7,
139 };
140 // clang-format on
141 
highbd_dist_wtd_convolve_x_4tap_sve2(const uint16_t * src,int src_stride,uint16_t * dst,int dst_stride,int width,int height,const int16_t * x_filter_ptr,ConvolveParams * conv_params,const int offset)142 static INLINE void highbd_dist_wtd_convolve_x_4tap_sve2(
143     const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
144     int width, int height, const int16_t *x_filter_ptr,
145     ConvolveParams *conv_params, const int offset) {
146   // This shim allows to do only one rounding shift instead of two.
147   const int64x2_t offset_s64 = vdupq_n_s64(offset);
148   const int32x4_t shift = vdupq_n_s32(-conv_params->round_0);
149 
150   const int16x4_t x_filter = vld1_s16(x_filter_ptr + 2);
151   const int16x8_t filter = vcombine_s16(x_filter, vdup_n_s16(0));
152 
153   if (width == 4) {
154     uint16x8x2_t permute_tbl = vld1q_u16_x2(kDotProdTbl);
155 
156     const int16_t *s = (const int16_t *)(src);
157 
158     do {
159       int16x8_t s0, s1, s2, s3;
160       load_s16_8x4(s, src_stride, &s0, &s1, &s2, &s3);
161 
162       uint16x4_t d0 = convolve4_4_x(s0, filter, offset_s64, shift, permute_tbl);
163       uint16x4_t d1 = convolve4_4_x(s1, filter, offset_s64, shift, permute_tbl);
164       uint16x4_t d2 = convolve4_4_x(s2, filter, offset_s64, shift, permute_tbl);
165       uint16x4_t d3 = convolve4_4_x(s3, filter, offset_s64, shift, permute_tbl);
166 
167       store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);
168 
169       s += 4 * src_stride;
170       dst += 4 * dst_stride;
171       height -= 4;
172     } while (height != 0);
173   } else {
174     uint16x8_t idx = vld1q_u16(kDeinterleaveTbl);
175 
176     do {
177       const int16_t *s = (const int16_t *)(src);
178       uint16_t *d = dst;
179       int w = width;
180 
181       do {
182         int16x8_t s0[4], s1[4], s2[4], s3[4];
183         load_s16_8x4(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3]);
184         load_s16_8x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
185         load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
186         load_s16_8x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);
187 
188         uint16x8_t d0 = convolve4_8_x(s0, filter, offset_s64, shift, idx);
189         uint16x8_t d1 = convolve4_8_x(s1, filter, offset_s64, shift, idx);
190         uint16x8_t d2 = convolve4_8_x(s2, filter, offset_s64, shift, idx);
191         uint16x8_t d3 = convolve4_8_x(s3, filter, offset_s64, shift, idx);
192 
193         store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
194 
195         s += 8;
196         d += 8;
197         w -= 8;
198       } while (w != 0);
199       src += 4 * src_stride;
200       dst += 4 * dst_stride;
201       height -= 4;
202     } while (height != 0);
203   }
204 }
205 
// Distance-weighted compound horizontal convolution (high bitdepth, SVE2).
//
// 6-tap kernels are delegated to av1_highbd_dist_wtd_convolve_x_neon. Other
// kernels use highbd_dist_wtd_convolve_x_4tap_sve2 (at most 4 taps) or
// highbd_dist_wtd_convolve_x_sve2 (8 taps) for the horizontal pass:
//  - Without averaging, the pass writes straight into conv_params->dst.
//  - With averaging, it writes into a local scratch block. That block is then
//    blended into `dst` using a distance-weighted or plain average, with
//    dedicated helpers for 12-bit input.
void av1_highbd_dist_wtd_convolve_x_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params, int bd) {
  const int x_filter_taps = get_filter_tap(filter_params_x, subpel_x_qn);

  if (x_filter_taps == 6) {
    av1_highbd_dist_wtd_convolve_x_neon(src, src_stride, dst, dst_stride, w, h,
                                        filter_params_x, subpel_x_qn,
                                        conv_params, bd);
    return;
  }

  // A horizontal-only pass writes exactly h rows of at most im_stride pixels,
  // so no extra filter-tail rows are needed in the scratch block.
  DECLARE_ALIGNED(16, uint16_t, im_block[MAX_SB_SIZE * MAX_SB_SIZE]);
  const int im_stride = MAX_SB_SIZE;
  assert(w <= im_stride && h <= MAX_SB_SIZE);

  CONV_BUF_TYPE *dst16 = conv_params->dst;
  const int dst16_stride = conv_params->dst_stride;
  const int horiz_offset = filter_params_x->taps / 2 - 1;
  assert(FILTER_BITS == COMPOUND_ROUND1_BITS);
  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
  const int offset_avg = (1 << (offset_bits - conv_params->round_1)) +
                         (1 << (offset_bits - conv_params->round_1 - 1));
  // Rounding constant for the round_0 shift plus the compound bias. The
  // kernels add this before a single truncating shift.
  const int offset_convolve = (1 << (conv_params->round_0 - 1)) +
                              (1 << (bd + FILTER_BITS)) +
                              (1 << (bd + FILTER_BITS - 1));

  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);

  src -= horiz_offset;

  // Pick where the horizontal pass writes: the scratch block when its output
  // is averaged afterwards, otherwise the compound buffer directly.
  uint16_t *conv_dst = conv_params->do_average ? im_block : dst16;
  const int conv_dst_stride =
      conv_params->do_average ? im_stride : dst16_stride;

  if (x_filter_taps <= 4) {
    // The 4-tap kernel reads taps 2..5 of the 8-entry filter, so advance the
    // source by two samples to keep the filter window aligned.
    highbd_dist_wtd_convolve_x_4tap_sve2(src + 2, src_stride, conv_dst,
                                         conv_dst_stride, w, h, x_filter_ptr,
                                         conv_params, offset_convolve);
  } else {
    highbd_dist_wtd_convolve_x_sve2(src, src_stride, conv_dst, conv_dst_stride,
                                    w, h, x_filter_ptr, conv_params,
                                    offset_convolve);
  }

  if (!conv_params->do_average) {
    return;
  }

  if (conv_params->use_dist_wtd_comp_avg) {
    if (bd == 12) {
      highbd_12_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w,
                                       h, conv_params, offset_avg, bd);
    } else {
      highbd_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
                                    conv_params, offset_avg, bd);
    }
  } else {
    if (bd == 12) {
      highbd_12_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
                              conv_params, offset_avg, bd);
    } else {
      highbd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
                           conv_params, offset_avg, bd);
    }
  }
}
281