/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "aom_dsp/arm/mem_neon.h"
#include "av1/common/arm/compound_convolve_neon.h"
#include "config/aom_config.h"
#include "config/av1_rtcd.h"

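// Byte-permute indices used with vqtbl1q_s8: each 16-byte row gathers four
// overlapping 4-sample windows, so a single signed dot product
// (vdotq_lane_s32) produces four adjacent convolution outputs.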
DECLARE_ALIGNED(16, static const uint8_t, dot_prod_permute_tbl[48]) = {
  0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6,
  4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10,
  8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
};

static INLINE int16x4_t convolve4_4_2d_h(uint8x16_t samples,
                                         const int8x8_t x_filter,
                                         const int32x4_t correction,
                                         const uint8x16_t range_limit,
                                         const uint8x16_t permute_tbl) {
  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  int8x16_t clamped_samples =
      vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
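  // Subtracting 128 biases every sample by -128. Since
  // dot(x - 128, f) == dot(x, f) - 128 * sum(f), the caller pre-adds
  // 128 * sum(f) (for the halved filter) to 'correction' to undo the bias.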

  // Permute samples ready for dot product.
  // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
  int8x16_t permuted_samples = vqtbl1q_s8(clamped_samples, permute_tbl);

  // Accumulate dot product into 'correction' to account for range clamp.
  int32x4_t sum = vdotq_lane_s32(correction, permuted_samples, x_filter, 0);

  // We halved the convolution filter values so -1 from the right shift.
  return vshrn_n_s32(sum, ROUND0_BITS - 1);
}

static INLINE int16x8_t convolve8_8_2d_h(uint8x16_t samples,
                                         const int8x8_t x_filter,
                                         const int32x4_t correction,
                                         const uint8x16_t range_limit,
                                         const uint8x16x3_t permute_tbl) {
  int8x16_t clamped_samples, permuted_samples[3];
  int32x4_t sum[2];

  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
  // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
  // { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);

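  // The 8-tap filter is applied as two 4-tap halves: lane 0 of x_filter
  // covers taps 0-3 and lane 1 covers taps 4-7, each dotted with the matching
  // window of permuted samples and accumulated into the same sum.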
  // Accumulate dot product into 'correction' to account for range clamp.
  // First 4 output values.
  sum[0] = vdotq_lane_s32(correction, permuted_samples[0], x_filter, 0);
  sum[0] = vdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
  // Second 4 output values.
  sum[1] = vdotq_lane_s32(correction, permuted_samples[1], x_filter, 0);
  sum[1] = vdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                      vshrn_n_s32(sum[1], ROUND0_BITS - 1));
}

static INLINE void dist_wtd_convolve_2d_horiz_neon_dotprod(
    const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride,
    const int16_t *x_filter_ptr, const int im_h, int w) {
  const int bd = 8;
  const int32_t horiz_const = (1 << (bd + FILTER_BITS - 2));
  // Dot-product constants and other shims.
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);
  const int32_t correction_s32 =
      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
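  // With FILTER_BITS == 7, correction_s32 == 128 * sum(x_filter / 2): the
  // constant needed to undo the -128 bias that the convolve helpers apply to
  // every sample before the signed dot product.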
  // Fold horiz_const into the dot-product filter correction constant. The
  // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-
  // rounding shifts - which are generally faster than rounding shifts on
  // modern CPUs. (The extra -1 is needed because we halved the filter values.)
  const int32x4_t correction = vdupq_n_s32(correction_s32 + horiz_const +
                                           (1 << ((ROUND0_BITS - 1) - 1)));
  const uint8x16_t range_limit = vdupq_n_u8(128);

  const uint8_t *src_ptr = src;
  int16_t *dst_ptr = im_block;
  int dst_stride = im_stride;
  int height = im_h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

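    // The caller offset 'src' by filter_params_x->taps / 2 - 1 (3 for an
    // 8-tap kernel); only the middle four taps are used here, so advance by 2
    // to re-centre the 4-tap window.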
    src_ptr += 2;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      int16x4_t d0 =
          convolve4_4_2d_h(s0, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d1 =
          convolve4_4_2d_h(s1, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d2 =
          convolve4_4_2d_h(s2, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d3 =
          convolve4_4_2d_h(s3, x_filter, correction, range_limit, permute_tbl);

      store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

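    // im_h (h + vertical filter taps - 1) is not a multiple of 4, so process
    // the remaining rows one at a time.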
    do {
      uint8x16_t s0 = vld1q_u8(src_ptr);

      int16x4_t d0 =
          convolve4_4_2d_h(s0, x_filter, correction, range_limit, permute_tbl);

      vst1_s16(dst_ptr, d0);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d1 = convolve8_8_2d_h(s1, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d2 = convolve8_8_2d_h(s2, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d3 = convolve8_8_2d_h(s3, x_filter, correction, range_limit,
                                        permute_tbl);

        store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0 = vld1q_u8(s);

        int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, correction, range_limit,
                                        permute_tbl);

        vst1q_s16(d, d0);

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  }
}

void av1_dist_wtd_convolve_2d_neon_dotprod(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x,
    const InterpFilterParams *filter_params_y, const int subpel_x_qn,
    const int subpel_y_qn, ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

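  // Intermediate buffer for the horizontal pass: MAX_SB_SIZE columns by up to
  // MAX_SB_SIZE + SUBPEL_TAPS - 1 rows (block height plus vertical taps - 1).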
  DECLARE_ALIGNED(16, int16_t,
                  im_block[(MAX_SB_SIZE + SUBPEL_TAPS - 1) * MAX_SB_SIZE]);

  const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
  const int clamped_y_taps = y_filter_taps < 6 ? 6 : y_filter_taps;
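  // Only 6-tap and 8-tap vertical paths are implemented below; filters with
  // fewer taps have zero-valued outer taps and are handled by the 6-tap path.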

  const int im_h = h + clamped_y_taps - 1;
  const int im_stride = MAX_SB_SIZE;
  const int vert_offset = clamped_y_taps / 2 - 1;
  const int horiz_offset = filter_params_x->taps / 2 - 1;
  const uint8_t *src_ptr = src - vert_offset * src_stride - horiz_offset;
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_y, subpel_y_qn & SUBPEL_MASK);

  const int16x8_t y_filter = vld1q_s16(y_filter_ptr);

  dist_wtd_convolve_2d_horiz_neon_dotprod(src_ptr, src_stride, im_block,
                                          im_stride, x_filter_ptr, im_h, w);

  if (clamped_y_taps == 6) {
    if (conv_params->do_average) {
      if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
        dist_wtd_convolve_2d_vert_6tap_dist_wtd_avg_neon(
            im_block, im_stride, dst8, dst8_stride, conv_params, y_filter, h,
            w);
      } else {
        dist_wtd_convolve_2d_vert_6tap_avg_neon(im_block, im_stride, dst8,
                                                dst8_stride, conv_params,
                                                y_filter, h, w);
      }
    } else {
      dist_wtd_convolve_2d_vert_6tap_neon(im_block, im_stride, conv_params,
                                          y_filter, h, w);
    }
  } else {
    if (conv_params->do_average) {
      if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
        dist_wtd_convolve_2d_vert_8tap_dist_wtd_avg_neon(
            im_block, im_stride, dst8, dst8_stride, conv_params, y_filter, h,
            w);
      } else {
        dist_wtd_convolve_2d_vert_8tap_avg_neon(im_block, im_stride, dst8,
                                                dst8_stride, conv_params,
                                                y_filter, h, w);
      }
    } else {
      dist_wtd_convolve_2d_vert_8tap_neon(im_block, im_stride, conv_params,
                                          y_filter, h, w);
    }
  }
}

static INLINE uint16x4_t convolve4_4_x(uint8x16_t samples,
                                       const int8x8_t x_filter,
                                       const int32x4_t correction,
                                       const uint8x16_t range_limit,
                                       const uint8x16_t permute_tbl) {
  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  int8x16_t clamped_samples =
      vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
  int8x16_t permuted_samples = vqtbl1q_s8(clamped_samples, permute_tbl);

  // Accumulate dot product into 'correction' to account for range clamp.
  int32x4_t sum = vdotq_lane_s32(correction, permuted_samples, x_filter, 0);

  // We halved the convolution filter values so -1 from the right shift.
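  // The rounding offset folded into 'correction' keeps the result
  // non-negative for 8-bit input, so the reinterpret to unsigned is safe.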
  return vreinterpret_u16_s16(vshrn_n_s32(sum, ROUND0_BITS - 1));
}

static INLINE uint16x8_t convolve8_8_x(uint8x16_t samples,
                                       const int8x8_t x_filter,
                                       const int32x4_t correction,
                                       const uint8x16_t range_limit,
                                       const uint8x16x3_t permute_tbl) {
  int8x16_t clamped_samples, permuted_samples[3];
  int32x4_t sum[2];

  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
  // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
  // { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);

  // Accumulate dot product into 'correction' to account for range clamp.
  // First 4 output values.
  sum[0] = vdotq_lane_s32(correction, permuted_samples[0], x_filter, 0);
  sum[0] = vdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
  // Second 4 output values.
  sum[1] = vdotq_lane_s32(correction, permuted_samples[1], x_filter, 0);
  sum[1] = vdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  int16x8_t res = vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                               vshrn_n_s32(sum[1], ROUND0_BITS - 1));
  return vreinterpretq_u16_s16(res);
}

static INLINE void dist_wtd_convolve_x_dist_wtd_avg_neon_dotprod(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
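  // Offset added to the horizontal-pass output so the intermediate
  // CONV_BUF_TYPE values stay non-negative; the averaging helpers remove it
  // again before producing the final 8-bit pixels.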
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);

  const uint16_t fwd_offset = conv_params->fwd_offset;
  const uint16_t bck_offset = conv_params->bck_offset;

  // Horizontal filter.
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);

  // Dot-product constants and other shims.
  const uint8x16_t range_limit = vdupq_n_u8(128);
  const int32_t correction_s32 =
      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
  // Fold round_offset into the dot-product filter correction constant. The
  // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-
  // rounding shifts - which are generally faster than rounding shifts on
  // modern CPUs. (The extra -1 is needed because we halved the filter values.)
  int32x4_t correction =
      vdupq_n_s32(correction_s32 + (round_offset << (ROUND0_BITS - 1)) +
                  (1 << ((ROUND0_BITS - 1) - 1)));

  const int horiz_offset = filter_params_x->taps / 2 - 1;
  const uint8_t *src_ptr = src - horiz_offset;
  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  uint8_t *dst8_ptr = dst8;
  int dst_stride = conv_params->dst_stride;
  int height = h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    src_ptr += 2;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve4_4_x(s0, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d1 =
          convolve4_4_x(s1, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d2 =
          convolve4_4_x(s2, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d3 =
          convolve4_4_x(s3, x_filter, correction, range_limit, permute_tbl);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_dist_wtd_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                               bck_offset, round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8_ptr + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8_ptr + 2 * dst8_stride, dst8_stride, d23_u8);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      dst8_ptr += 4 * dst8_stride;
      height -= 4;
    } while (height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    do {
      const uint8_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      uint8_t *d_u8 = dst8_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve8_8_x(s0, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d1 =
            convolve8_8_x(s1, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d2 =
            convolve8_8_x(s2, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d3 =
            convolve8_8_x(s3, x_filter, correction, range_limit, permute_tbl);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_dist_wtd_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                                 bck_offset, round_offset_vec, &d0_u8, &d1_u8,
                                 &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

        s += 8;
        d += 8;
        d_u8 += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      dst8_ptr += 4 * dst8_stride;
      height -= 4;
    } while (height != 0);
  }
}

static INLINE void dist_wtd_convolve_x_avg_neon_dotprod(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);

  // Horizontal filter.
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);

  // Dot-product constants and other shims.
  const uint8x16_t range_limit = vdupq_n_u8(128);
  const int32_t correction_s32 =
      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
  // Fold round_offset into the dot-product filter correction constant. The
  // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-
  // rounding shifts - which are generally faster than rounding shifts on
  // modern CPUs. (The extra -1 is needed because we halved the filter values.)
  int32x4_t correction =
      vdupq_n_s32(correction_s32 + (round_offset << (ROUND0_BITS - 1)) +
                  (1 << ((ROUND0_BITS - 1) - 1)));

  const int horiz_offset = filter_params_x->taps / 2 - 1;
  const uint8_t *src_ptr = src - horiz_offset;
  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  uint8_t *dst8_ptr = dst8;
  int dst_stride = conv_params->dst_stride;
  int height = h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    src_ptr += 2;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve4_4_x(s0, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d1 =
          convolve4_4_x(s1, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d2 =
          convolve4_4_x(s2, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d3 =
          convolve4_4_x(s3, x_filter, correction, range_limit, permute_tbl);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_basic_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                            round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8_ptr + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8_ptr + 2 * dst8_stride, dst8_stride, d23_u8);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      dst8_ptr += 4 * dst8_stride;
      height -= 4;
    } while (height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    do {
      const uint8_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      uint8_t *d_u8 = dst8_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve8_8_x(s0, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d1 =
            convolve8_8_x(s1, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d2 =
            convolve8_8_x(s2, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d3 =
            convolve8_8_x(s3, x_filter, correction, range_limit, permute_tbl);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_basic_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                              round_offset_vec, &d0_u8, &d1_u8, &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

        s += 8;
        d += 8;
        d_u8 += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      dst8_ptr += 4 * dst8_stride;
      height -= 4;
    } while (height != 0);
  }
}

static INLINE void dist_wtd_convolve_x_neon_dotprod(
    const uint8_t *src, int src_stride, int w, int h,
    const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));

  // Horizontal filter.
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);

  // Dot-product constants and other shims.
  const uint8x16_t range_limit = vdupq_n_u8(128);
  const int32_t correction_s32 =
      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
  // Fold round_offset into the dot-product filter correction constant. The
  // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-
  // rounding shifts - which are generally faster than rounding shifts on
  // modern CPUs. (The extra -1 is needed because we halved the filter values.)
  int32x4_t correction =
      vdupq_n_s32(correction_s32 + (round_offset << (ROUND0_BITS - 1)) +
                  (1 << ((ROUND0_BITS - 1) - 1)));

  const int horiz_offset = filter_params_x->taps / 2 - 1;
  const uint8_t *src_ptr = src - horiz_offset;
  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  int height = h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    src_ptr += 2;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve4_4_x(s0, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d1 =
          convolve4_4_x(s1, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d2 =
          convolve4_4_x(s2, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d3 =
          convolve4_4_x(s3, x_filter, correction, range_limit, permute_tbl);

      store_u16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    do {
      const uint8_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve8_8_x(s0, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d1 =
            convolve8_8_x(s1, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d2 =
            convolve8_8_x(s2, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d3 =
            convolve8_8_x(s3, x_filter, correction, range_limit, permute_tbl);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height != 0);
  }
}

void av1_dist_wtd_convolve_x_neon_dotprod(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  if (conv_params->do_average) {
    if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
      dist_wtd_convolve_x_dist_wtd_avg_neon_dotprod(
          src, src_stride, dst8, dst8_stride, w, h, filter_params_x,
          subpel_x_qn, conv_params);
    } else {
      dist_wtd_convolve_x_avg_neon_dotprod(src, src_stride, dst8, dst8_stride,
                                           w, h, filter_params_x, subpel_x_qn,
                                           conv_params);
    }
  } else {
    dist_wtd_convolve_x_neon_dotprod(src, src_stride, w, h, filter_params_x,
                                     subpel_x_qn, conv_params);
  }
}