/*
 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>
#include <stdbool.h>
#include <arm_neon_sve_bridge.h>

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/aom_neon_sve_bridge.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/transpose_neon.h"
#include "aom_ports/mem.h"
#include "av1/common/scale.h"
#include "av1/common/warped_motion.h"
#include "config/av1_rtcd.h"
#include "highbd_warp_plane_neon.h"

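// Horizontal filter producing 4 output pixels, each with its own 8-tap filter
// derived from sx and alpha. Each tap sum is computed as a 64-bit dot product
// (aom_sdotq_s16), pairwise-reduced, offset by 1 << offset_bits_horiz and
// rounding-shifted right by round0 bits to the 16-bit intermediate precision.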
static AOM_FORCE_INLINE int16x8_t
highbd_horizontal_filter_4x1_f4(uint16x8x2_t in, int bd, int sx, int alpha) {
  int16x8_t f[4];
  load_filters_4(f, sx, alpha);

  int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 0);
  int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 1);
  int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 2);
  int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 3);

  int64x2_t m0 = aom_sdotq_s16(vdupq_n_s64(0), rv0, f[0]);
  int64x2_t m1 = aom_sdotq_s16(vdupq_n_s64(0), rv1, f[1]);
  int64x2_t m2 = aom_sdotq_s16(vdupq_n_s64(0), rv2, f[2]);
  int64x2_t m3 = aom_sdotq_s16(vdupq_n_s64(0), rv3, f[3]);

  int64x2_t m01 = vpaddq_s64(m0, m1);
  int64x2_t m23 = vpaddq_s64(m2, m3);

  const int round0 = bd == 12 ? ROUND0_BITS + 2 : ROUND0_BITS;
  const int offset_bits_horiz = bd + FILTER_BITS - 1;

  int32x4_t res = vcombine_s32(vmovn_s64(m01), vmovn_s64(m23));
  res = vaddq_s32(res, vdupq_n_s32(1 << offset_bits_horiz));
  res = vrshlq_s32(res, vdupq_n_s32(-round0));
  return vcombine_s16(vmovn_s32(res), vdup_n_s16(0));
}

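// As above, but producing 8 output pixels, each with its own 8-tap filter.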
static AOM_FORCE_INLINE int16x8_t
highbd_horizontal_filter_8x1_f8(uint16x8x2_t in, int bd, int sx, int alpha) {
  int16x8_t f[8];
  load_filters_8(f, sx, alpha);

  int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 0);
  int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 1);
  int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 2);
  int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 3);
  int16x8_t rv4 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 4);
  int16x8_t rv5 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 5);
  int16x8_t rv6 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 6);
  int16x8_t rv7 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 7);

  int64x2_t m0 = aom_sdotq_s16(vdupq_n_s64(0), rv0, f[0]);
  int64x2_t m1 = aom_sdotq_s16(vdupq_n_s64(0), rv1, f[1]);
  int64x2_t m2 = aom_sdotq_s16(vdupq_n_s64(0), rv2, f[2]);
  int64x2_t m3 = aom_sdotq_s16(vdupq_n_s64(0), rv3, f[3]);
  int64x2_t m4 = aom_sdotq_s16(vdupq_n_s64(0), rv4, f[4]);
  int64x2_t m5 = aom_sdotq_s16(vdupq_n_s64(0), rv5, f[5]);
  int64x2_t m6 = aom_sdotq_s16(vdupq_n_s64(0), rv6, f[6]);
  int64x2_t m7 = aom_sdotq_s16(vdupq_n_s64(0), rv7, f[7]);

  int64x2_t m01 = vpaddq_s64(m0, m1);
  int64x2_t m23 = vpaddq_s64(m2, m3);
  int64x2_t m45 = vpaddq_s64(m4, m5);
  int64x2_t m67 = vpaddq_s64(m6, m7);

  const int round0 = bd == 12 ? ROUND0_BITS + 2 : ROUND0_BITS;
  const int offset_bits_horiz = bd + FILTER_BITS - 1;

  int32x4_t res0 = vcombine_s32(vmovn_s64(m01), vmovn_s64(m23));
  int32x4_t res1 = vcombine_s32(vmovn_s64(m45), vmovn_s64(m67));
  res0 = vaddq_s32(res0, vdupq_n_s32(1 << offset_bits_horiz));
  res1 = vaddq_s32(res1, vdupq_n_s32(1 << offset_bits_horiz));
  res0 = vrshlq_s32(res0, vdupq_n_s32(-round0));
  res1 = vrshlq_s32(res1, vdupq_n_s32(-round0));
  return vcombine_s16(vmovn_s32(res0), vmovn_s32(res1));
}

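// Horizontal filter producing 4 output pixels that share a single 8-tap
// filter (the alpha == 0 case).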
static AOM_FORCE_INLINE int16x8_t
highbd_horizontal_filter_4x1_f1(uint16x8x2_t in, int bd, int sx) {
  int16x8_t f = load_filters_1(sx);

  int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 0);
  int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 1);
  int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 2);
  int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 3);

  int64x2_t m0 = aom_sdotq_s16(vdupq_n_s64(0), rv0, f);
  int64x2_t m1 = aom_sdotq_s16(vdupq_n_s64(0), rv1, f);
  int64x2_t m2 = aom_sdotq_s16(vdupq_n_s64(0), rv2, f);
  int64x2_t m3 = aom_sdotq_s16(vdupq_n_s64(0), rv3, f);

  int64x2_t m01 = vpaddq_s64(m0, m1);
  int64x2_t m23 = vpaddq_s64(m2, m3);

  const int round0 = bd == 12 ? ROUND0_BITS + 2 : ROUND0_BITS;
  const int offset_bits_horiz = bd + FILTER_BITS - 1;

  int32x4_t res = vcombine_s32(vmovn_s64(m01), vmovn_s64(m23));
  res = vaddq_s32(res, vdupq_n_s32(1 << offset_bits_horiz));
  res = vrshlq_s32(res, vdupq_n_s32(-round0));
  return vcombine_s16(vmovn_s32(res), vdup_n_s16(0));
}

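// Horizontal filter producing 8 output pixels that share a single 8-tap
// filter (the alpha == 0 case).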
static AOM_FORCE_INLINE int16x8_t
highbd_horizontal_filter_8x1_f1(uint16x8x2_t in, int bd, int sx) {
  int16x8_t f = load_filters_1(sx);

  int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 0);
  int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 1);
  int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 2);
  int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 3);
  int16x8_t rv4 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 4);
  int16x8_t rv5 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 5);
  int16x8_t rv6 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 6);
  int16x8_t rv7 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 7);

  int64x2_t m0 = aom_sdotq_s16(vdupq_n_s64(0), rv0, f);
  int64x2_t m1 = aom_sdotq_s16(vdupq_n_s64(0), rv1, f);
  int64x2_t m2 = aom_sdotq_s16(vdupq_n_s64(0), rv2, f);
  int64x2_t m3 = aom_sdotq_s16(vdupq_n_s64(0), rv3, f);
  int64x2_t m4 = aom_sdotq_s16(vdupq_n_s64(0), rv4, f);
  int64x2_t m5 = aom_sdotq_s16(vdupq_n_s64(0), rv5, f);
  int64x2_t m6 = aom_sdotq_s16(vdupq_n_s64(0), rv6, f);
  int64x2_t m7 = aom_sdotq_s16(vdupq_n_s64(0), rv7, f);

  int64x2_t m01 = vpaddq_s64(m0, m1);
  int64x2_t m23 = vpaddq_s64(m2, m3);
  int64x2_t m45 = vpaddq_s64(m4, m5);
  int64x2_t m67 = vpaddq_s64(m6, m7);

  const int round0 = bd == 12 ? ROUND0_BITS + 2 : ROUND0_BITS;
  const int offset_bits_horiz = bd + FILTER_BITS - 1;

  int32x4_t res0 = vcombine_s32(vmovn_s64(m01), vmovn_s64(m23));
  int32x4_t res1 = vcombine_s32(vmovn_s64(m45), vmovn_s64(m67));
  res0 = vaddq_s32(res0, vdupq_n_s32(1 << offset_bits_horiz));
  res1 = vaddq_s32(res1, vdupq_n_s32(1 << offset_bits_horiz));
  res0 = vrshlq_s32(res0, vdupq_n_s32(-round0));
  res1 = vrshlq_s32(res1, vdupq_n_s32(-round0));
  return vcombine_s16(vmovn_s32(res0), vmovn_s32(res1));
}

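// Vertical filter producing 4 output pixels that share a single 8-tap filter
// (the gamma == 0 case), using widening multiply-accumulates by lane.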
static AOM_FORCE_INLINE int32x4_t vertical_filter_4x1_f1(const int16x8_t *tmp,
                                                         int sy) {
  const int16x8_t f = load_filters_1(sy);
  const int16x4_t f0123 = vget_low_s16(f);
  const int16x4_t f4567 = vget_high_s16(f);

  // No benefit to using SDOT here, the cost of rearrangement is too high.
  int32x4_t m0123 = vmull_lane_s16(vget_low_s16(tmp[0]), f0123, 0);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[1]), f0123, 1);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[2]), f0123, 2);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[3]), f0123, 3);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[4]), f4567, 0);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[5]), f4567, 1);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[6]), f4567, 2);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[7]), f4567, 3);
  return m0123;
}

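// 8-pixel variant of the single-filter vertical convolution above.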
static AOM_FORCE_INLINE int32x4x2_t vertical_filter_8x1_f1(const int16x8_t *tmp,
                                                           int sy) {
  const int16x8_t f = load_filters_1(sy);
  const int16x4_t f0123 = vget_low_s16(f);
  const int16x4_t f4567 = vget_high_s16(f);

  // No benefit to using SDOT here, the cost of rearrangement is too high.
  int32x4_t m0123 = vmull_lane_s16(vget_low_s16(tmp[0]), f0123, 0);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[1]), f0123, 1);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[2]), f0123, 2);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[3]), f0123, 3);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[4]), f4567, 0);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[5]), f4567, 1);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[6]), f4567, 2);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[7]), f4567, 3);

  int32x4_t m4567 = vmull_lane_s16(vget_high_s16(tmp[0]), f0123, 0);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[1]), f0123, 1);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[2]), f0123, 2);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[3]), f0123, 3);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[4]), f4567, 0);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[5]), f4567, 1);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[6]), f4567, 2);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[7]), f4567, 3);
  return (int32x4x2_t){ { m0123, m4567 } };
}

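// Vertical filter producing 4 output pixels, each with its own 8-tap filter.
// The 8 intermediate rows are transposed so that each output pixel becomes a
// single 64-bit dot product (aom_sdotq_s16) followed by a pairwise reduction.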
static AOM_FORCE_INLINE int32x4_t vertical_filter_4x1_f4(const int16x8_t *tmp,
                                                         int sy, int gamma) {
  int16x8_t s0, s1, s2, s3;
  transpose_elems_s16_4x8(
      vget_low_s16(tmp[0]), vget_low_s16(tmp[1]), vget_low_s16(tmp[2]),
      vget_low_s16(tmp[3]), vget_low_s16(tmp[4]), vget_low_s16(tmp[5]),
      vget_low_s16(tmp[6]), vget_low_s16(tmp[7]), &s0, &s1, &s2, &s3);

  int16x8_t f[4];
  load_filters_4(f, sy, gamma);

  int64x2_t m0 = aom_sdotq_s16(vdupq_n_s64(0), s0, f[0]);
  int64x2_t m1 = aom_sdotq_s16(vdupq_n_s64(0), s1, f[1]);
  int64x2_t m2 = aom_sdotq_s16(vdupq_n_s64(0), s2, f[2]);
  int64x2_t m3 = aom_sdotq_s16(vdupq_n_s64(0), s3, f[3]);

  int64x2_t m01 = vpaddq_s64(m0, m1);
  int64x2_t m23 = vpaddq_s64(m2, m3);
  return vcombine_s32(vmovn_s64(m01), vmovn_s64(m23));
}

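// As above, but producing 8 output pixels, each with its own 8-tap filter.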
static AOM_FORCE_INLINE int32x4x2_t vertical_filter_8x1_f8(const int16x8_t *tmp,
                                                           int sy, int gamma) {
  int16x8_t s0 = tmp[0];
  int16x8_t s1 = tmp[1];
  int16x8_t s2 = tmp[2];
  int16x8_t s3 = tmp[3];
  int16x8_t s4 = tmp[4];
  int16x8_t s5 = tmp[5];
  int16x8_t s6 = tmp[6];
  int16x8_t s7 = tmp[7];
  transpose_elems_inplace_s16_8x8(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);

  int16x8_t f[8];
  load_filters_8(f, sy, gamma);

  int64x2_t m0 = aom_sdotq_s16(vdupq_n_s64(0), s0, f[0]);
  int64x2_t m1 = aom_sdotq_s16(vdupq_n_s64(0), s1, f[1]);
  int64x2_t m2 = aom_sdotq_s16(vdupq_n_s64(0), s2, f[2]);
  int64x2_t m3 = aom_sdotq_s16(vdupq_n_s64(0), s3, f[3]);
  int64x2_t m4 = aom_sdotq_s16(vdupq_n_s64(0), s4, f[4]);
  int64x2_t m5 = aom_sdotq_s16(vdupq_n_s64(0), s5, f[5]);
  int64x2_t m6 = aom_sdotq_s16(vdupq_n_s64(0), s6, f[6]);
  int64x2_t m7 = aom_sdotq_s16(vdupq_n_s64(0), s7, f[7]);

  int64x2_t m01 = vpaddq_s64(m0, m1);
  int64x2_t m23 = vpaddq_s64(m2, m3);
  int64x2_t m45 = vpaddq_s64(m4, m5);
  int64x2_t m67 = vpaddq_s64(m6, m7);

  int32x4x2_t ret;
  ret.val[0] = vcombine_s32(vmovn_s64(m01), vmovn_s64(m23));
  ret.val[1] = vcombine_s32(vmovn_s64(m45), vmovn_s64(m67));
  return ret;
}

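// Entry point. highbd_warp_affine_common() is the driver shared with the Neon
// implementations (see highbd_warp_plane_neon.h); it performs the blockwise
// warp and calls the filter helpers defined above.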
void av1_highbd_warp_affine_sve(const int32_t *mat, const uint16_t *ref,
                                int width, int height, int stride,
                                uint16_t *pred, int p_col, int p_row,
                                int p_width, int p_height, int p_stride,
                                int subsampling_x, int subsampling_y, int bd,
                                ConvolveParams *conv_params, int16_t alpha,
                                int16_t beta, int16_t gamma, int16_t delta) {
  highbd_warp_affine_common(mat, ref, width, height, stride, pred, p_col, p_row,
                            p_width, p_height, p_stride, subsampling_x,
                            subsampling_y, bd, conv_params, alpha, beta, gamma,
                            delta);
}