/*
 * Copyright (c) 2024, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <arm_neon_sve_bridge.h>
#include <assert.h>
#include <stdbool.h>

17 #include "aom_dsp/aom_dsp_common.h"
18 #include "aom_dsp/arm/aom_neon_sve_bridge.h"
19 #include "aom_dsp/arm/mem_neon.h"
20 #include "aom_dsp/arm/transpose_neon.h"
21 #include "aom_ports/mem.h"
22 #include "av1/common/scale.h"
23 #include "av1/common/warped_motion.h"
24 #include "config/av1_rtcd.h"
25 #include "highbd_warp_plane_neon.h"
26 
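// Horizontal filter producing one row of 4 pixels, using four filters whose
// phase advances by alpha per pixel. Each 8-tap dot product is computed with
// the SVE SDOT helper, then the horizontal offset is added and the result is
// rounded by round0 bits before narrowing to 16 bits.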
static AOM_FORCE_INLINE int16x8_t
highbd_horizontal_filter_4x1_f4(uint16x8x2_t in, int bd, int sx, int alpha) {
  int16x8_t f[4];
  load_filters_4(f, sx, alpha);

  int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 0);
  int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 1);
  int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 2);
  int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 3);

  int64x2_t m0 = aom_sdotq_s16(vdupq_n_s64(0), rv0, f[0]);
  int64x2_t m1 = aom_sdotq_s16(vdupq_n_s64(0), rv1, f[1]);
  int64x2_t m2 = aom_sdotq_s16(vdupq_n_s64(0), rv2, f[2]);
  int64x2_t m3 = aom_sdotq_s16(vdupq_n_s64(0), rv3, f[3]);

  int64x2_t m01 = vpaddq_s64(m0, m1);
  int64x2_t m23 = vpaddq_s64(m2, m3);

  const int round0 = bd == 12 ? ROUND0_BITS + 2 : ROUND0_BITS;
  const int offset_bits_horiz = bd + FILTER_BITS - 1;

  int32x4_t res = vcombine_s32(vmovn_s64(m01), vmovn_s64(m23));
  res = vaddq_s32(res, vdupq_n_s32(1 << offset_bits_horiz));
  res = vrshlq_s32(res, vdupq_n_s32(-round0));
  return vcombine_s16(vmovn_s32(res), vdup_n_s16(0));
}

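// As above, but for a full row of 8 output pixels using eight alpha-stepped
// filters.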
static AOM_FORCE_INLINE int16x8_t
highbd_horizontal_filter_8x1_f8(uint16x8x2_t in, int bd, int sx, int alpha) {
  int16x8_t f[8];
  load_filters_8(f, sx, alpha);

  int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 0);
  int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 1);
  int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 2);
  int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 3);
  int16x8_t rv4 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 4);
  int16x8_t rv5 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 5);
  int16x8_t rv6 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 6);
  int16x8_t rv7 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 7);

  int64x2_t m0 = aom_sdotq_s16(vdupq_n_s64(0), rv0, f[0]);
  int64x2_t m1 = aom_sdotq_s16(vdupq_n_s64(0), rv1, f[1]);
  int64x2_t m2 = aom_sdotq_s16(vdupq_n_s64(0), rv2, f[2]);
  int64x2_t m3 = aom_sdotq_s16(vdupq_n_s64(0), rv3, f[3]);
  int64x2_t m4 = aom_sdotq_s16(vdupq_n_s64(0), rv4, f[4]);
  int64x2_t m5 = aom_sdotq_s16(vdupq_n_s64(0), rv5, f[5]);
  int64x2_t m6 = aom_sdotq_s16(vdupq_n_s64(0), rv6, f[6]);
  int64x2_t m7 = aom_sdotq_s16(vdupq_n_s64(0), rv7, f[7]);

  int64x2_t m01 = vpaddq_s64(m0, m1);
  int64x2_t m23 = vpaddq_s64(m2, m3);
  int64x2_t m45 = vpaddq_s64(m4, m5);
  int64x2_t m67 = vpaddq_s64(m6, m7);

  const int round0 = bd == 12 ? ROUND0_BITS + 2 : ROUND0_BITS;
  const int offset_bits_horiz = bd + FILTER_BITS - 1;

  int32x4_t res0 = vcombine_s32(vmovn_s64(m01), vmovn_s64(m23));
  int32x4_t res1 = vcombine_s32(vmovn_s64(m45), vmovn_s64(m67));
  res0 = vaddq_s32(res0, vdupq_n_s32(1 << offset_bits_horiz));
  res1 = vaddq_s32(res1, vdupq_n_s32(1 << offset_bits_horiz));
  res0 = vrshlq_s32(res0, vdupq_n_s32(-round0));
  res1 = vrshlq_s32(res1, vdupq_n_s32(-round0));
  return vcombine_s16(vmovn_s32(res0), vmovn_s32(res1));
}

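// Horizontal filter for a row of 4 pixels sharing a single filter phase
// (the alpha == 0 case).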
static AOM_FORCE_INLINE int16x8_t
highbd_horizontal_filter_4x1_f1(uint16x8x2_t in, int bd, int sx) {
  int16x8_t f = load_filters_1(sx);

  int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 0);
  int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 1);
  int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 2);
  int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 3);

  int64x2_t m0 = aom_sdotq_s16(vdupq_n_s64(0), rv0, f);
  int64x2_t m1 = aom_sdotq_s16(vdupq_n_s64(0), rv1, f);
  int64x2_t m2 = aom_sdotq_s16(vdupq_n_s64(0), rv2, f);
  int64x2_t m3 = aom_sdotq_s16(vdupq_n_s64(0), rv3, f);

  int64x2_t m01 = vpaddq_s64(m0, m1);
  int64x2_t m23 = vpaddq_s64(m2, m3);

  const int round0 = bd == 12 ? ROUND0_BITS + 2 : ROUND0_BITS;
  const int offset_bits_horiz = bd + FILTER_BITS - 1;

  int32x4_t res = vcombine_s32(vmovn_s64(m01), vmovn_s64(m23));
  res = vaddq_s32(res, vdupq_n_s32(1 << offset_bits_horiz));
  res = vrshlq_s32(res, vdupq_n_s32(-round0));
  return vcombine_s16(vmovn_s32(res), vdup_n_s16(0));
}

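// Single-phase (alpha == 0) variant of the 8-pixel horizontal filter above.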
static AOM_FORCE_INLINE int16x8_t
highbd_horizontal_filter_8x1_f1(uint16x8x2_t in, int bd, int sx) {
  int16x8_t f = load_filters_1(sx);

  int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 0);
  int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 1);
  int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 2);
  int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 3);
  int16x8_t rv4 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 4);
  int16x8_t rv5 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 5);
  int16x8_t rv6 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 6);
  int16x8_t rv7 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 7);

  int64x2_t m0 = aom_sdotq_s16(vdupq_n_s64(0), rv0, f);
  int64x2_t m1 = aom_sdotq_s16(vdupq_n_s64(0), rv1, f);
  int64x2_t m2 = aom_sdotq_s16(vdupq_n_s64(0), rv2, f);
  int64x2_t m3 = aom_sdotq_s16(vdupq_n_s64(0), rv3, f);
  int64x2_t m4 = aom_sdotq_s16(vdupq_n_s64(0), rv4, f);
  int64x2_t m5 = aom_sdotq_s16(vdupq_n_s64(0), rv5, f);
  int64x2_t m6 = aom_sdotq_s16(vdupq_n_s64(0), rv6, f);
  int64x2_t m7 = aom_sdotq_s16(vdupq_n_s64(0), rv7, f);

  int64x2_t m01 = vpaddq_s64(m0, m1);
  int64x2_t m23 = vpaddq_s64(m2, m3);
  int64x2_t m45 = vpaddq_s64(m4, m5);
  int64x2_t m67 = vpaddq_s64(m6, m7);

  const int round0 = bd == 12 ? ROUND0_BITS + 2 : ROUND0_BITS;
  const int offset_bits_horiz = bd + FILTER_BITS - 1;

  int32x4_t res0 = vcombine_s32(vmovn_s64(m01), vmovn_s64(m23));
  int32x4_t res1 = vcombine_s32(vmovn_s64(m45), vmovn_s64(m67));
  res0 = vaddq_s32(res0, vdupq_n_s32(1 << offset_bits_horiz));
  res1 = vaddq_s32(res1, vdupq_n_s32(1 << offset_bits_horiz));
  res0 = vrshlq_s32(res0, vdupq_n_s32(-round0));
  res1 = vrshlq_s32(res1, vdupq_n_s32(-round0));
  return vcombine_s16(vmovn_s32(res0), vmovn_s32(res1));
}

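// Vertical filter for 4 output pixels sharing a single filter phase
// (gamma == 0): the eight intermediate rows in 'tmp' are accumulated with
// widening multiply-accumulate instructions.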
static AOM_FORCE_INLINE int32x4_t vertical_filter_4x1_f1(const int16x8_t *tmp,
                                                         int sy) {
  const int16x8_t f = load_filters_1(sy);
  const int16x4_t f0123 = vget_low_s16(f);
  const int16x4_t f4567 = vget_high_s16(f);

  // No benefit to using SDOT here, the cost of rearrangement is too high.
  int32x4_t m0123 = vmull_lane_s16(vget_low_s16(tmp[0]), f0123, 0);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[1]), f0123, 1);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[2]), f0123, 2);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[3]), f0123, 3);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[4]), f4567, 0);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[5]), f4567, 1);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[6]), f4567, 2);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[7]), f4567, 3);
  return m0123;
}

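// Single-phase vertical filter for 8 output pixels.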
static AOM_FORCE_INLINE int32x4x2_t vertical_filter_8x1_f1(const int16x8_t *tmp,
                                                           int sy) {
  const int16x8_t f = load_filters_1(sy);
  const int16x4_t f0123 = vget_low_s16(f);
  const int16x4_t f4567 = vget_high_s16(f);

  // No benefit to using SDOT here, the cost of rearrangement is too high.
  int32x4_t m0123 = vmull_lane_s16(vget_low_s16(tmp[0]), f0123, 0);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[1]), f0123, 1);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[2]), f0123, 2);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[3]), f0123, 3);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[4]), f4567, 0);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[5]), f4567, 1);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[6]), f4567, 2);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[7]), f4567, 3);

  int32x4_t m4567 = vmull_lane_s16(vget_high_s16(tmp[0]), f0123, 0);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[1]), f0123, 1);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[2]), f0123, 2);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[3]), f0123, 3);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[4]), f4567, 0);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[5]), f4567, 1);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[6]), f4567, 2);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[7]), f4567, 3);
  return (int32x4x2_t){ { m0123, m4567 } };
}

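// Vertical filter for 4 output pixels with the filter phase stepping by gamma
// per pixel. The intermediate rows are transposed so that each vector holds
// the 8 taps for one output pixel, allowing the dot products to use SDOT.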
static AOM_FORCE_INLINE int32x4_t vertical_filter_4x1_f4(const int16x8_t *tmp,
                                                         int sy, int gamma) {
  int16x8_t s0, s1, s2, s3;
  transpose_elems_s16_4x8(
      vget_low_s16(tmp[0]), vget_low_s16(tmp[1]), vget_low_s16(tmp[2]),
      vget_low_s16(tmp[3]), vget_low_s16(tmp[4]), vget_low_s16(tmp[5]),
      vget_low_s16(tmp[6]), vget_low_s16(tmp[7]), &s0, &s1, &s2, &s3);

  int16x8_t f[4];
  load_filters_4(f, sy, gamma);

  int64x2_t m0 = aom_sdotq_s16(vdupq_n_s64(0), s0, f[0]);
  int64x2_t m1 = aom_sdotq_s16(vdupq_n_s64(0), s1, f[1]);
  int64x2_t m2 = aom_sdotq_s16(vdupq_n_s64(0), s2, f[2]);
  int64x2_t m3 = aom_sdotq_s16(vdupq_n_s64(0), s3, f[3]);

  int64x2_t m01 = vpaddq_s64(m0, m1);
  int64x2_t m23 = vpaddq_s64(m2, m3);
  return vcombine_s32(vmovn_s64(m01), vmovn_s64(m23));
}

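// As above, but for 8 output pixels using eight gamma-stepped filters.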
static AOM_FORCE_INLINE int32x4x2_t vertical_filter_8x1_f8(const int16x8_t *tmp,
                                                           int sy, int gamma) {
  int16x8_t s0 = tmp[0];
  int16x8_t s1 = tmp[1];
  int16x8_t s2 = tmp[2];
  int16x8_t s3 = tmp[3];
  int16x8_t s4 = tmp[4];
  int16x8_t s5 = tmp[5];
  int16x8_t s6 = tmp[6];
  int16x8_t s7 = tmp[7];
  transpose_elems_inplace_s16_8x8(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);

  int16x8_t f[8];
  load_filters_8(f, sy, gamma);

  int64x2_t m0 = aom_sdotq_s16(vdupq_n_s64(0), s0, f[0]);
  int64x2_t m1 = aom_sdotq_s16(vdupq_n_s64(0), s1, f[1]);
  int64x2_t m2 = aom_sdotq_s16(vdupq_n_s64(0), s2, f[2]);
  int64x2_t m3 = aom_sdotq_s16(vdupq_n_s64(0), s3, f[3]);
  int64x2_t m4 = aom_sdotq_s16(vdupq_n_s64(0), s4, f[4]);
  int64x2_t m5 = aom_sdotq_s16(vdupq_n_s64(0), s5, f[5]);
  int64x2_t m6 = aom_sdotq_s16(vdupq_n_s64(0), s6, f[6]);
  int64x2_t m7 = aom_sdotq_s16(vdupq_n_s64(0), s7, f[7]);

  int64x2_t m01 = vpaddq_s64(m0, m1);
  int64x2_t m23 = vpaddq_s64(m2, m3);
  int64x2_t m45 = vpaddq_s64(m4, m5);
  int64x2_t m67 = vpaddq_s64(m6, m7);

  int32x4x2_t ret;
  ret.val[0] = vcombine_s32(vmovn_s64(m01), vmovn_s64(m23));
  ret.val[1] = vcombine_s32(vmovn_s64(m45), vmovn_s64(m67));
  return ret;
}

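// SVE entry point for the high bitdepth affine warp. It dispatches to the
// shared implementation pulled in via highbd_warp_plane_neon.h, which drives
// the horizontal and vertical filter helpers defined above.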
void av1_highbd_warp_affine_sve(const int32_t *mat, const uint16_t *ref,
                                int width, int height, int stride,
                                uint16_t *pred, int p_col, int p_row,
                                int p_width, int p_height, int p_stride,
                                int subsampling_x, int subsampling_y, int bd,
                                ConvolveParams *conv_params, int16_t alpha,
                                int16_t beta, int16_t gamma, int16_t delta) {
  highbd_warp_affine_common(mat, ref, width, height, stride, pred, p_col, p_row,
                            p_width, p_height, p_stride, subsampling_x,
                            subsampling_y, bd, conv_params, alpha, beta, gamma,
                            delta);
}