/*
 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/aom_neon_sve_bridge.h"
#include "aom_dsp/arm/aom_neon_sve2_bridge.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_ports/mem.h"
#include "av1/common/convolve.h"
#include "av1/common/filter.h"
#include "av1/common/arm/highbd_compound_convolve_neon.h"
#include "av1/common/arm/highbd_convolve_neon.h"

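// Permute table for the TBL instruction: row i gathers the four consecutive
// 16-bit elements starting at index i (wrapping within the vector), giving
// the overlapping input windows needed by the 4-tap dot-product kernels
// below.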
DECLARE_ALIGNED(16, static const uint16_t, kDotProdTbl[32]) = {
  0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6,
  4, 5, 6, 7, 5, 6, 7, 0, 6, 7, 0, 1, 7, 0, 1, 2,
};

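// Apply an 8-tap filter to 8 output pixels. Each aom_sdotq_s16 call is an
// SVE 16-bit dot product that accumulates into two 64-bit lanes, one per
// half of the input vector; pairwise-adding two results with vpaddq_s64
// yields one complete 8-tap sum per 64-bit lane.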
static INLINE uint16x8_t convolve8_8_x(int16x8_t s0[8], int16x8_t filter,
                                       int64x2_t offset, int32x4_t shift) {
  int64x2_t sum[8];
  sum[0] = aom_sdotq_s16(offset, s0[0], filter);
  sum[1] = aom_sdotq_s16(offset, s0[1], filter);
  sum[2] = aom_sdotq_s16(offset, s0[2], filter);
  sum[3] = aom_sdotq_s16(offset, s0[3], filter);
  sum[4] = aom_sdotq_s16(offset, s0[4], filter);
  sum[5] = aom_sdotq_s16(offset, s0[5], filter);
  sum[6] = aom_sdotq_s16(offset, s0[6], filter);
  sum[7] = aom_sdotq_s16(offset, s0[7], filter);

  sum[0] = vpaddq_s64(sum[0], sum[1]);
  sum[2] = vpaddq_s64(sum[2], sum[3]);
  sum[4] = vpaddq_s64(sum[4], sum[5]);
  sum[6] = vpaddq_s64(sum[6], sum[7]);

  int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum[0]), vmovn_s64(sum[2]));
  int32x4_t sum4567 = vcombine_s32(vmovn_s64(sum[4]), vmovn_s64(sum[6]));

  sum0123 = vshlq_s32(sum0123, shift);
  sum4567 = vshlq_s32(sum4567, shift);

  return vcombine_u16(vqmovun_s32(sum0123), vqmovun_s32(sum4567));
}

static INLINE void highbd_dist_wtd_convolve_x_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
    int width, int height, const int16_t *x_filter_ptr,
    ConvolveParams *conv_params, const int offset) {
  const int32x4_t shift = vdupq_n_s32(-conv_params->round_0);
  const int64x2_t offset_vec = vdupq_n_s64(offset);

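  // Keep the offset only in the low 64-bit lane so that, after the pairwise
  // additions in convolve8_8_x, it is applied exactly once per output pixel.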
  const int64x2_t offset_lo =
      vcombine_s64(vget_low_s64(offset_vec), vdup_n_s64(0));
  const int16x8_t filter = vld1q_s16(x_filter_ptr);
  do {
    const int16_t *s = (const int16_t *)src;
    uint16_t *d = dst;
    int w = width;

    do {
      int16x8_t s0[8], s1[8], s2[8], s3[8];
      load_s16_8x8(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3],
                   &s0[4], &s0[5], &s0[6], &s0[7]);
      load_s16_8x8(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3],
                   &s1[4], &s1[5], &s1[6], &s1[7]);
      load_s16_8x8(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3],
                   &s2[4], &s2[5], &s2[6], &s2[7]);
      load_s16_8x8(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3],
                   &s3[4], &s3[5], &s3[6], &s3[7]);

      uint16x8_t d0 = convolve8_8_x(s0, filter, offset_lo, shift);
      uint16x8_t d1 = convolve8_8_x(s1, filter, offset_lo, shift);
      uint16x8_t d2 = convolve8_8_x(s2, filter, offset_lo, shift);
      uint16x8_t d3 = convolve8_8_x(s3, filter, offset_lo, shift);

      store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

      s += 8;
      d += 8;
      w -= 8;
    } while (w != 0);
    src += 4 * src_stride;
    dst += 4 * dst_stride;
    height -= 4;
  } while (height != 0);
}

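// Apply a 4-tap filter to 4 output pixels held in one input vector. The TBL
// permutes build the four overlapping 4-element windows; each
// aom_svdot_lane_s16 call multiplies groups of four adjacent 16-bit elements
// by filter elements 0-3 and accumulates into 64-bit lanes, producing two
// output pixels per call.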
static INLINE uint16x4_t convolve4_4_x(int16x8_t s0, int16x8_t filter,
                                       int64x2_t offset, int32x4_t shift,
                                       uint16x8x2_t permute_tbl) {
  int16x8_t permuted_samples0 = aom_tbl_s16(s0, permute_tbl.val[0]);
  int16x8_t permuted_samples1 = aom_tbl_s16(s0, permute_tbl.val[1]);

  int64x2_t sum01 = aom_svdot_lane_s16(offset, permuted_samples0, filter, 0);
  int64x2_t sum23 = aom_svdot_lane_s16(offset, permuted_samples1, filter, 0);

  int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
  sum0123 = vshlq_s32(sum0123, shift);

  return vqmovun_s32(sum0123);
}

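// Apply a 4-tap filter to 8 output pixels. The input vectors s0[i] hold the
// samples starting at column offset i, so each dot product yields the sums
// for pixels i and i + 4; a final TBL restores linear pixel order.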
static INLINE uint16x8_t convolve4_8_x(int16x8_t s0[4], int16x8_t filter,
                                       int64x2_t offset, int32x4_t shift,
                                       uint16x8_t tbl) {
  int64x2_t sum04 = aom_svdot_lane_s16(offset, s0[0], filter, 0);
  int64x2_t sum15 = aom_svdot_lane_s16(offset, s0[1], filter, 0);
  int64x2_t sum26 = aom_svdot_lane_s16(offset, s0[2], filter, 0);
  int64x2_t sum37 = aom_svdot_lane_s16(offset, s0[3], filter, 0);

  int32x4_t sum0415 = vcombine_s32(vmovn_s64(sum04), vmovn_s64(sum15));
  sum0415 = vshlq_s32(sum0415, shift);

  int32x4_t sum2637 = vcombine_s32(vmovn_s64(sum26), vmovn_s64(sum37));
  sum2637 = vshlq_s32(sum2637, shift);

  uint16x8_t res = vcombine_u16(vqmovun_s32(sum0415), vqmovun_s32(sum2637));
  return aom_tbl_u16(res, tbl);
}

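// convolve4_8_x produces its eight results in the interleaved order
// { 0, 4, 1, 5, 2, 6, 3, 7 }; this table restores linear pixel order.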
// clang-format off
DECLARE_ALIGNED(16, static const uint16_t, kDeinterleaveTbl[8]) = {
  0, 2, 4, 6, 1, 3, 5, 7,
};
// clang-format on

static INLINE void highbd_dist_wtd_convolve_x_4tap_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
    int width, int height, const int16_t *x_filter_ptr,
    ConvolveParams *conv_params, const int offset) {
  // This shim allows us to do only one rounding shift instead of two.
  const int64x2_t offset_s64 = vdupq_n_s64(offset);
  const int32x4_t shift = vdupq_n_s32(-conv_params->round_0);

  const int16x4_t x_filter = vld1_s16(x_filter_ptr + 2);
  const int16x8_t filter = vcombine_s16(x_filter, vdup_n_s16(0));

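  // For width == 4, build the overlapping windows by permuting each row with
  // TBL. For wider blocks, overlapping loads form the windows directly and
  // the results are deinterleaved afterwards (see convolve4_8_x).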
  if (width == 4) {
    uint16x8x2_t permute_tbl = vld1q_u16_x2(kDotProdTbl);

    const int16_t *s = (const int16_t *)(src);

    do {
      int16x8_t s0, s1, s2, s3;
      load_s16_8x4(s, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 = convolve4_4_x(s0, filter, offset_s64, shift, permute_tbl);
      uint16x4_t d1 = convolve4_4_x(s1, filter, offset_s64, shift, permute_tbl);
      uint16x4_t d2 = convolve4_4_x(s2, filter, offset_s64, shift, permute_tbl);
      uint16x4_t d3 = convolve4_4_x(s3, filter, offset_s64, shift, permute_tbl);

      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);

      s += 4 * src_stride;
      dst += 4 * dst_stride;
      height -= 4;
    } while (height != 0);
  } else {
    uint16x8_t idx = vld1q_u16(kDeinterleaveTbl);

    do {
      const int16_t *s = (const int16_t *)(src);
      uint16_t *d = dst;
      int w = width;

      do {
        int16x8_t s0[4], s1[4], s2[4], s3[4];
        load_s16_8x4(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3]);
        load_s16_8x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
        load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
        load_s16_8x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);

        uint16x8_t d0 = convolve4_8_x(s0, filter, offset_s64, shift, idx);
        uint16x8_t d1 = convolve4_8_x(s1, filter, offset_s64, shift, idx);
        uint16x8_t d2 = convolve4_8_x(s2, filter, offset_s64, shift, idx);
        uint16x8_t d3 = convolve4_8_x(s3, filter, offset_s64, shift, idx);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        w -= 8;
      } while (w != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      height -= 4;
    } while (height != 0);
  }
}

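// 6-tap filters fall back to the Neon implementation; the SVE2 paths below
// cover the 4-tap and 8-tap cases.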
void av1_highbd_dist_wtd_convolve_x_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params, int bd) {
  DECLARE_ALIGNED(16, uint16_t,
                  im_block[(MAX_SB_SIZE + MAX_FILTER_TAP) * MAX_SB_SIZE]);
  CONV_BUF_TYPE *dst16 = conv_params->dst;
  const int x_filter_taps = get_filter_tap(filter_params_x, subpel_x_qn);

  if (x_filter_taps == 6) {
    av1_highbd_dist_wtd_convolve_x_neon(src, src_stride, dst, dst_stride, w, h,
                                        filter_params_x, subpel_x_qn,
                                        conv_params, bd);
    return;
  }

  int dst16_stride = conv_params->dst_stride;
  const int im_stride = MAX_SB_SIZE;
  const int horiz_offset = filter_params_x->taps / 2 - 1;
  assert(FILTER_BITS == COMPOUND_ROUND1_BITS);
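  // offset_convolve folds the round_0 rounding constant together with a bias
  // that keeps the intermediate convolution results non-negative; offset_avg
  // is the corresponding value removed again by the averaging helpers below.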
  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
  const int offset_avg = (1 << (offset_bits - conv_params->round_1)) +
                         (1 << (offset_bits - conv_params->round_1 - 1));
  const int offset_convolve = (1 << (conv_params->round_0 - 1)) +
                              (1 << (bd + FILTER_BITS)) +
                              (1 << (bd + FILTER_BITS - 1));

  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);

  src -= horiz_offset;

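  // do_average: this is the second of the two compound predictions, so
  // convolve into a temporary buffer and combine it with the first prediction
  // stored in conv_params->dst. Otherwise store the raw convolution output in
  // conv_params->dst for the second pass.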
  if (conv_params->do_average) {
    if (x_filter_taps <= 4) {
      highbd_dist_wtd_convolve_x_4tap_sve2(src + 2, src_stride, im_block,
                                           im_stride, w, h, x_filter_ptr,
                                           conv_params, offset_convolve);
    } else {
      highbd_dist_wtd_convolve_x_sve2(src, src_stride, im_block, im_stride, w,
                                      h, x_filter_ptr, conv_params,
                                      offset_convolve);
    }

    if (conv_params->use_dist_wtd_comp_avg) {
      if (bd == 12) {
        highbd_12_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride,
                                         w, h, conv_params, offset_avg, bd);
      } else {
        highbd_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w,
                                      h, conv_params, offset_avg, bd);
      }
    } else {
      if (bd == 12) {
        highbd_12_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
                                conv_params, offset_avg, bd);
      } else {
        highbd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
                             conv_params, offset_avg, bd);
      }
    }
  } else {
    if (x_filter_taps <= 4) {
      highbd_dist_wtd_convolve_x_4tap_sve2(src + 2, src_stride, dst16,
                                           dst16_stride, w, h, x_filter_ptr,
                                           conv_params, offset_convolve);
    } else {
      highbd_dist_wtd_convolve_x_sve2(src, src_stride, dst16, dst16_stride, w,
                                      h, x_filter_ptr, conv_params,
                                      offset_convolve);
    }
  }
}