1 /*
2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11 #ifndef AOM_AV1_COMMON_ARM_WARP_PLANE_NEON_H_
12 #define AOM_AV1_COMMON_ARM_WARP_PLANE_NEON_H_
13
#include <arm_neon.h>
#include <assert.h>
#include <math.h>
#include <memory.h>
#include <string.h>
18
19 #include "aom_dsp/aom_dsp_common.h"
20 #include "aom_dsp/arm/sum_neon.h"
21 #include "aom_dsp/arm/transpose_neon.h"
22 #include "aom_ports/mem.h"
23 #include "config/av1_rtcd.h"
24 #include "av1/common/warped_motion.h"
25 #include "av1/common/scale.h"
26
// Forward declarations of the per-row filter kernels called by the shared
// warp code in this header. Their definitions are not visible in this
// section; presumably each file that includes this header supplies them
// (TODO confirm).

// Horizontal kernel used when p_width == 4 and alpha != 0.
static AOM_FORCE_INLINE int16x8_t horizontal_filter_4x1_f4(const uint8x16_t in,
                                                           int sx, int alpha);

// Horizontal kernel used when p_width != 4 and alpha != 0.
static AOM_FORCE_INLINE int16x8_t horizontal_filter_8x1_f8(const uint8x16_t in,
                                                           int sx, int alpha);

// Horizontal kernel used when p_width == 4 and alpha == 0.
static AOM_FORCE_INLINE int16x8_t horizontal_filter_4x1_f1(const uint8x16_t in,
                                                           int sx);

// Horizontal kernel used when p_width != 4 and alpha == 0.
static AOM_FORCE_INLINE int16x8_t horizontal_filter_8x1_f1(const uint8x16_t in,
                                                           int sx);

// Vertical kernel used when p_width <= 4 and gamma == 0; writes one int32x4_t
// through `res`.
static AOM_FORCE_INLINE void vertical_filter_4x1_f1(const int16x8_t *src,
                                                    int32x4_t *res, int sy);

// Vertical kernel used when p_width <= 4 and gamma != 0; writes one int32x4_t
// through `res`.
static AOM_FORCE_INLINE void vertical_filter_4x1_f4(const int16x8_t *src,
                                                    int32x4_t *res, int sy,
                                                    int gamma);

// Vertical kernel used when p_width > 4 and gamma == 0; writes two int32x4_t
// results through `res_low` and `res_high`.
static AOM_FORCE_INLINE void vertical_filter_8x1_f1(const int16x8_t *src,
                                                    int32x4_t *res_low,
                                                    int32x4_t *res_high,
                                                    int sy);

// Vertical kernel used when p_width > 4 and gamma != 0; writes two int32x4_t
// results through `res_low` and `res_high`.
static AOM_FORCE_INLINE void vertical_filter_8x1_f8(const int16x8_t *src,
                                                    int32x4_t *res_low,
                                                    int32x4_t *res_high, int sy,
                                                    int gamma);
55
load_filters_4(int16x8_t out[],int offset,int stride)56 static AOM_FORCE_INLINE void load_filters_4(int16x8_t out[], int offset,
57 int stride) {
58 out[0] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 0 * stride) >>
59 WARPEDDIFF_PREC_BITS)));
60 out[1] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 1 * stride) >>
61 WARPEDDIFF_PREC_BITS)));
62 out[2] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 2 * stride) >>
63 WARPEDDIFF_PREC_BITS)));
64 out[3] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 3 * stride) >>
65 WARPEDDIFF_PREC_BITS)));
66 }
67
load_filters_8(int16x8_t out[],int offset,int stride)68 static AOM_FORCE_INLINE void load_filters_8(int16x8_t out[], int offset,
69 int stride) {
70 out[0] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 0 * stride) >>
71 WARPEDDIFF_PREC_BITS)));
72 out[1] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 1 * stride) >>
73 WARPEDDIFF_PREC_BITS)));
74 out[2] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 2 * stride) >>
75 WARPEDDIFF_PREC_BITS)));
76 out[3] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 3 * stride) >>
77 WARPEDDIFF_PREC_BITS)));
78 out[4] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 4 * stride) >>
79 WARPEDDIFF_PREC_BITS)));
80 out[5] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 5 * stride) >>
81 WARPEDDIFF_PREC_BITS)));
82 out[6] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 6 * stride) >>
83 WARPEDDIFF_PREC_BITS)));
84 out[7] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 7 * stride) >>
85 WARPEDDIFF_PREC_BITS)));
86 }
87
clamp_iy(int iy,int height)88 static AOM_FORCE_INLINE int clamp_iy(int iy, int height) {
89 return clamp(iy, 0, height - 1);
90 }
91
warp_affine_horizontal(const uint8_t * ref,int width,int height,int stride,int p_width,int p_height,int16_t alpha,int16_t beta,const int64_t x4,const int64_t y4,const int i,int16x8_t tmp[])92 static AOM_FORCE_INLINE void warp_affine_horizontal(
93 const uint8_t *ref, int width, int height, int stride, int p_width,
94 int p_height, int16_t alpha, int16_t beta, const int64_t x4,
95 const int64_t y4, const int i, int16x8_t tmp[]) {
96 const int bd = 8;
97 const int reduce_bits_horiz = ROUND0_BITS;
98 const int height_limit = AOMMIN(8, p_height - i) + 7;
99
100 int32_t ix4 = (int32_t)(x4 >> WARPEDMODEL_PREC_BITS);
101 int32_t iy4 = (int32_t)(y4 >> WARPEDMODEL_PREC_BITS);
102
103 int32_t sx4 = x4 & ((1 << WARPEDMODEL_PREC_BITS) - 1);
104 sx4 += alpha * (-4) + beta * (-4) + (1 << (WARPEDDIFF_PREC_BITS - 1)) +
105 (WARPEDPIXEL_PREC_SHIFTS << WARPEDDIFF_PREC_BITS);
106 sx4 &= ~((1 << WARP_PARAM_REDUCE_BITS) - 1);
107
108 if (ix4 <= -7) {
109 for (int k = 0; k < height_limit; ++k) {
110 int iy = clamp_iy(iy4 + k - 7, height);
111 int16_t dup_val =
112 (1 << (bd + FILTER_BITS - reduce_bits_horiz - 1)) +
113 ref[iy * stride] * (1 << (FILTER_BITS - reduce_bits_horiz));
114 tmp[k] = vdupq_n_s16(dup_val);
115 }
116 return;
117 } else if (ix4 >= width + 6) {
118 for (int k = 0; k < height_limit; ++k) {
119 int iy = clamp_iy(iy4 + k - 7, height);
120 int16_t dup_val = (1 << (bd + FILTER_BITS - reduce_bits_horiz - 1)) +
121 ref[iy * stride + (width - 1)] *
122 (1 << (FILTER_BITS - reduce_bits_horiz));
123 tmp[k] = vdupq_n_s16(dup_val);
124 }
125 return;
126 }
127
128 static const uint8_t kIotaArr[] = { 0, 1, 2, 3, 4, 5, 6, 7,
129 8, 9, 10, 11, 12, 13, 14, 15 };
130 const uint8x16_t indx = vld1q_u8(kIotaArr);
131
132 const int out_of_boundary_left = -(ix4 - 6);
133 const int out_of_boundary_right = (ix4 + 8) - width;
134
135 #define APPLY_HORIZONTAL_SHIFT(fn, ...) \
136 do { \
137 if (out_of_boundary_left >= 0 || out_of_boundary_right >= 0) { \
138 for (int k = 0; k < height_limit; ++k) { \
139 const int iy = clamp_iy(iy4 + k - 7, height); \
140 const uint8_t *src = ref + iy * stride + ix4 - 7; \
141 uint8x16_t src_1 = vld1q_u8(src); \
142 \
143 if (out_of_boundary_left >= 0) { \
144 int limit = out_of_boundary_left + 1; \
145 uint8x16_t cmp_vec = vdupq_n_u8(out_of_boundary_left); \
146 uint8x16_t vec_dup = vdupq_n_u8(*(src + limit)); \
147 uint8x16_t mask_val = vcleq_u8(indx, cmp_vec); \
148 src_1 = vbslq_u8(mask_val, vec_dup, src_1); \
149 } \
150 if (out_of_boundary_right >= 0) { \
151 int limit = 15 - (out_of_boundary_right + 1); \
152 uint8x16_t cmp_vec = vdupq_n_u8(15 - out_of_boundary_right); \
153 uint8x16_t vec_dup = vdupq_n_u8(*(src + limit)); \
154 uint8x16_t mask_val = vcgeq_u8(indx, cmp_vec); \
155 src_1 = vbslq_u8(mask_val, vec_dup, src_1); \
156 } \
157 tmp[k] = (fn)(src_1, __VA_ARGS__); \
158 } \
159 } else { \
160 for (int k = 0; k < height_limit; ++k) { \
161 const int iy = clamp_iy(iy4 + k - 7, height); \
162 const uint8_t *src = ref + iy * stride + ix4 - 7; \
163 uint8x16_t src_1 = vld1q_u8(src); \
164 tmp[k] = (fn)(src_1, __VA_ARGS__); \
165 } \
166 } \
167 } while (0)
168
169 if (p_width == 4) {
170 if (beta == 0) {
171 if (alpha == 0) {
172 APPLY_HORIZONTAL_SHIFT(horizontal_filter_4x1_f1, sx4);
173 } else {
174 APPLY_HORIZONTAL_SHIFT(horizontal_filter_4x1_f4, sx4, alpha);
175 }
176 } else {
177 if (alpha == 0) {
178 APPLY_HORIZONTAL_SHIFT(horizontal_filter_4x1_f1,
179 (sx4 + beta * (k - 3)));
180 } else {
181 APPLY_HORIZONTAL_SHIFT(horizontal_filter_4x1_f4, (sx4 + beta * (k - 3)),
182 alpha);
183 }
184 }
185 } else {
186 if (beta == 0) {
187 if (alpha == 0) {
188 APPLY_HORIZONTAL_SHIFT(horizontal_filter_8x1_f1, sx4);
189 } else {
190 APPLY_HORIZONTAL_SHIFT(horizontal_filter_8x1_f8, sx4, alpha);
191 }
192 } else {
193 if (alpha == 0) {
194 APPLY_HORIZONTAL_SHIFT(horizontal_filter_8x1_f1,
195 (sx4 + beta * (k - 3)));
196 } else {
197 APPLY_HORIZONTAL_SHIFT(horizontal_filter_8x1_f8, (sx4 + beta * (k - 3)),
198 alpha);
199 }
200 }
201 }
202 }
203
warp_affine_vertical(uint8_t * pred,int p_width,int p_height,int p_stride,int is_compound,uint16_t * dst,int dst_stride,int do_average,int use_dist_wtd_comp_avg,int16_t gamma,int16_t delta,const int64_t y4,const int i,const int j,int16x8_t tmp[],const int fwd,const int bwd)204 static AOM_FORCE_INLINE void warp_affine_vertical(
205 uint8_t *pred, int p_width, int p_height, int p_stride, int is_compound,
206 uint16_t *dst, int dst_stride, int do_average, int use_dist_wtd_comp_avg,
207 int16_t gamma, int16_t delta, const int64_t y4, const int i, const int j,
208 int16x8_t tmp[], const int fwd, const int bwd) {
209 const int bd = 8;
210 const int reduce_bits_horiz = ROUND0_BITS;
211 const int offset_bits_vert = bd + 2 * FILTER_BITS - reduce_bits_horiz;
212 int add_const_vert;
213 if (is_compound) {
214 add_const_vert =
215 (1 << offset_bits_vert) + (1 << (COMPOUND_ROUND1_BITS - 1));
216 } else {
217 add_const_vert =
218 (1 << offset_bits_vert) + (1 << (2 * FILTER_BITS - ROUND0_BITS - 1));
219 }
220 const int sub_constant = (1 << (bd - 1)) + (1 << bd);
221
222 const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
223 const int res_sub_const =
224 (1 << (2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS - 1)) -
225 (1 << (offset_bits - COMPOUND_ROUND1_BITS)) -
226 (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
227
228 int32_t sy4 = y4 & ((1 << WARPEDMODEL_PREC_BITS) - 1);
229 sy4 += gamma * (-4) + delta * (-4) + (1 << (WARPEDDIFF_PREC_BITS - 1)) +
230 (WARPEDPIXEL_PREC_SHIFTS << WARPEDDIFF_PREC_BITS);
231 sy4 &= ~((1 << WARP_PARAM_REDUCE_BITS) - 1);
232
233 if (p_width > 4) {
234 for (int k = -4; k < AOMMIN(4, p_height - i - 4); ++k) {
235 int sy = sy4 + delta * (k + 4);
236 const int16x8_t *v_src = tmp + (k + 4);
237
238 int32x4_t res_lo, res_hi;
239 if (gamma == 0) {
240 vertical_filter_8x1_f1(v_src, &res_lo, &res_hi, sy);
241 } else {
242 vertical_filter_8x1_f8(v_src, &res_lo, &res_hi, sy, gamma);
243 }
244
245 res_lo = vaddq_s32(res_lo, vdupq_n_s32(add_const_vert));
246 res_hi = vaddq_s32(res_hi, vdupq_n_s32(add_const_vert));
247
248 if (is_compound) {
249 uint16_t *const p = (uint16_t *)&dst[(i + k + 4) * dst_stride + j];
250 int16x8_t res_s16 =
251 vcombine_s16(vshrn_n_s32(res_lo, COMPOUND_ROUND1_BITS),
252 vshrn_n_s32(res_hi, COMPOUND_ROUND1_BITS));
253 if (do_average) {
254 int16x8_t tmp16 = vreinterpretq_s16_u16(vld1q_u16(p));
255 if (use_dist_wtd_comp_avg) {
256 int32x4_t tmp32_lo = vmull_n_s16(vget_low_s16(tmp16), fwd);
257 int32x4_t tmp32_hi = vmull_n_s16(vget_high_s16(tmp16), fwd);
258 tmp32_lo = vmlal_n_s16(tmp32_lo, vget_low_s16(res_s16), bwd);
259 tmp32_hi = vmlal_n_s16(tmp32_hi, vget_high_s16(res_s16), bwd);
260 tmp16 = vcombine_s16(vshrn_n_s32(tmp32_lo, DIST_PRECISION_BITS),
261 vshrn_n_s32(tmp32_hi, DIST_PRECISION_BITS));
262 } else {
263 tmp16 = vhaddq_s16(tmp16, res_s16);
264 }
265 int16x8_t res = vaddq_s16(tmp16, vdupq_n_s16(res_sub_const));
266 uint8x8_t res8 = vqshrun_n_s16(
267 res, 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS);
268 vst1_u8(&pred[(i + k + 4) * p_stride + j], res8);
269 } else {
270 vst1q_u16(p, vreinterpretq_u16_s16(res_s16));
271 }
272 } else {
273 int16x8_t res16 =
274 vcombine_s16(vshrn_n_s32(res_lo, 2 * FILTER_BITS - ROUND0_BITS),
275 vshrn_n_s32(res_hi, 2 * FILTER_BITS - ROUND0_BITS));
276 res16 = vsubq_s16(res16, vdupq_n_s16(sub_constant));
277
278 uint8_t *const p = (uint8_t *)&pred[(i + k + 4) * p_stride + j];
279 vst1_u8(p, vqmovun_s16(res16));
280 }
281 }
282 } else {
283 // p_width == 4
284 for (int k = -4; k < AOMMIN(4, p_height - i - 4); ++k) {
285 int sy = sy4 + delta * (k + 4);
286 const int16x8_t *v_src = tmp + (k + 4);
287
288 int32x4_t res_lo;
289 if (gamma == 0) {
290 vertical_filter_4x1_f1(v_src, &res_lo, sy);
291 } else {
292 vertical_filter_4x1_f4(v_src, &res_lo, sy, gamma);
293 }
294
295 res_lo = vaddq_s32(res_lo, vdupq_n_s32(add_const_vert));
296
297 if (is_compound) {
298 uint16_t *const p = (uint16_t *)&dst[(i + k + 4) * dst_stride + j];
299
300 int16x4_t res_lo_s16 = vshrn_n_s32(res_lo, COMPOUND_ROUND1_BITS);
301 if (do_average) {
302 uint8_t *const dst8 = &pred[(i + k + 4) * p_stride + j];
303 int16x4_t tmp16_lo = vreinterpret_s16_u16(vld1_u16(p));
304 if (use_dist_wtd_comp_avg) {
305 int32x4_t tmp32_lo = vmull_n_s16(tmp16_lo, fwd);
306 tmp32_lo = vmlal_n_s16(tmp32_lo, res_lo_s16, bwd);
307 tmp16_lo = vshrn_n_s32(tmp32_lo, DIST_PRECISION_BITS);
308 } else {
309 tmp16_lo = vhadd_s16(tmp16_lo, res_lo_s16);
310 }
311 int16x4_t res = vadd_s16(tmp16_lo, vdup_n_s16(res_sub_const));
312 uint8x8_t res8 = vqshrun_n_s16(
313 vcombine_s16(res, vdup_n_s16(0)),
314 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS);
315 vst1_lane_u32((uint32_t *)dst8, vreinterpret_u32_u8(res8), 0);
316 } else {
317 uint16x4_t res_u16_low = vreinterpret_u16_s16(res_lo_s16);
318 vst1_u16(p, res_u16_low);
319 }
320 } else {
321 int16x4_t res16 = vshrn_n_s32(res_lo, 2 * FILTER_BITS - ROUND0_BITS);
322 res16 = vsub_s16(res16, vdup_n_s16(sub_constant));
323
324 uint8_t *const p = (uint8_t *)&pred[(i + k + 4) * p_stride + j];
325 uint8x8_t val = vqmovun_s16(vcombine_s16(res16, vdup_n_s16(0)));
326 vst1_lane_u32((uint32_t *)p, vreinterpret_u32_u8(val), 0);
327 }
328 }
329 }
330 }
331
av1_warp_affine_common(const int32_t * mat,const uint8_t * ref,int width,int height,int stride,uint8_t * pred,int p_col,int p_row,int p_width,int p_height,int p_stride,int subsampling_x,int subsampling_y,ConvolveParams * conv_params,int16_t alpha,int16_t beta,int16_t gamma,int16_t delta)332 static AOM_FORCE_INLINE void av1_warp_affine_common(
333 const int32_t *mat, const uint8_t *ref, int width, int height, int stride,
334 uint8_t *pred, int p_col, int p_row, int p_width, int p_height,
335 int p_stride, int subsampling_x, int subsampling_y,
336 ConvolveParams *conv_params, int16_t alpha, int16_t beta, int16_t gamma,
337 int16_t delta) {
338 const int w0 = conv_params->fwd_offset;
339 const int w1 = conv_params->bck_offset;
340 const int is_compound = conv_params->is_compound;
341 uint16_t *const dst = conv_params->dst;
342 const int dst_stride = conv_params->dst_stride;
343 const int do_average = conv_params->do_average;
344 const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
345
346 assert(IMPLIES(is_compound, dst != NULL));
347 assert(IMPLIES(do_average, is_compound));
348
349 for (int i = 0; i < p_height; i += 8) {
350 for (int j = 0; j < p_width; j += 8) {
351 const int32_t src_x = (p_col + j + 4) << subsampling_x;
352 const int32_t src_y = (p_row + i + 4) << subsampling_y;
353 const int64_t dst_x =
354 (int64_t)mat[2] * src_x + (int64_t)mat[3] * src_y + (int64_t)mat[0];
355 const int64_t dst_y =
356 (int64_t)mat[4] * src_x + (int64_t)mat[5] * src_y + (int64_t)mat[1];
357
358 const int64_t x4 = dst_x >> subsampling_x;
359 const int64_t y4 = dst_y >> subsampling_y;
360
361 int16x8_t tmp[15];
362 warp_affine_horizontal(ref, width, height, stride, p_width, p_height,
363 alpha, beta, x4, y4, i, tmp);
364 warp_affine_vertical(pred, p_width, p_height, p_stride, is_compound, dst,
365 dst_stride, do_average, use_dist_wtd_comp_avg, gamma,
366 delta, y4, i, j, tmp, w0, w1);
367 }
368 }
369 }
370
371 #endif // AOM_AV1_COMMON_ARM_WARP_PLANE_NEON_H_
372