1 /*
2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include <arm_neon.h>
13
14 #include "config/aom_config.h"
15 #include "config/av1_rtcd.h"
16
17 #include "aom_dsp/aom_dsp_common.h"
18 #include "aom_dsp/arm/mem_neon.h"
19 #include "aom_ports/mem.h"
20 #include "av1/common/arm/convolve_neon.h"
21 #include "av1/common/convolve.h"
22 #include "av1/common/filter.h"
23
// TBL indices used to gather overlapping 4-byte windows of a 16-byte source
// vector so that a single 4-way dot product per 32-bit lane can apply four
// consecutive filter taps to four adjacent output positions.
//   Row 0 (bytes  0-15): windows starting at source bytes 0, 1, 2, 3.
//   Row 1 (bytes 16-31): windows starting at source bytes 4, 5, 6, 7.
//   Row 2 (bytes 32-47): windows starting at source bytes 8, 9, 10, 11.
DECLARE_ALIGNED(16, static const uint8_t, dot_prod_permute_tbl[48]) = {
  0, 1, 2,  3,  1, 2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6,
  4, 5, 6,  7,  5, 6,  7,  8,  6,  7,  8,  9,  7,  8,  9,  10,
  8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
};
29
// Computes four horizontally filtered outputs with a 12-tap filter held in
// bytes 0-11 of 'filter'. The unsigned input is re-centred by 'range_limit'
// so it fits the signed 8-bit dot product; 'correction' (the accumulator
// seed) is expected to compensate for that bias. The result is rounded,
// shifted right by FILTER_BITS and saturated to 16 bits.
static INLINE int16x4_t convolve12_4_x(uint8x16_t samples,
                                       const int8x16_t filter,
                                       const int32x4_t correction,
                                       const uint8x16_t range_limit,
                                       const uint8x16x3_t permute_tbl) {
  const int8x16_t centred = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Overlapping 4-sample windows feeding taps 0-3, 4-7 and 8-11.
  const int8x16_t win_taps0_3 = vqtbl1q_s8(centred, permute_tbl.val[0]);
  const int8x16_t win_taps4_7 = vqtbl1q_s8(centred, permute_tbl.val[1]);
  const int8x16_t win_taps8_11 = vqtbl1q_s8(centred, permute_tbl.val[2]);

  // Seeding with 'correction' undoes the re-centring applied above.
  int32x4_t acc = vdotq_laneq_s32(correction, win_taps0_3, filter, 0);
  acc = vdotq_laneq_s32(acc, win_taps4_7, filter, 1);
  acc = vdotq_laneq_s32(acc, win_taps8_11, filter, 2);

  return vqrshrn_n_s32(acc, FILTER_BITS);
}
57
// Computes eight horizontally filtered 8-bit outputs with a 12-tap filter
// held in bytes 0-11 of 'filter'. samples[1] supplies the window for
// outputs 4-7 / taps 8-11; the callers in this file load it 4 bytes after
// samples[0]. 'correction' seeds both accumulators to compensate for the
// range_limit bias. Results are rounded, shifted right by FILTER_BITS,
// saturated to 16 bits and then saturated to unsigned 8 bits.
static INLINE uint8x8_t convolve12_8_x(uint8x16_t samples[2],
                                       const int8x16_t filter,
                                       const int32x4_t correction,
                                       const uint8x16_t range_limit,
                                       const uint8x16x3_t permute_tbl) {
  const int8x16_t lo = vreinterpretq_s8_u8(vsubq_u8(samples[0], range_limit));
  const int8x16_t hi = vreinterpretq_s8_u8(vsubq_u8(samples[1], range_limit));

  // Windows at source offsets 0, 4 and 8 come from 'lo'. The window at
  // offset 12 is taken from 'hi' using the third table row.
  const int8x16_t win0 = vqtbl1q_s8(lo, permute_tbl.val[0]);
  const int8x16_t win4 = vqtbl1q_s8(lo, permute_tbl.val[1]);
  const int8x16_t win8 = vqtbl1q_s8(lo, permute_tbl.val[2]);
  const int8x16_t win12 = vqtbl1q_s8(hi, permute_tbl.val[2]);

  // Outputs 0-3 use windows 0/4/8.
  int32x4_t acc_lo = vdotq_laneq_s32(correction, win0, filter, 0);
  acc_lo = vdotq_laneq_s32(acc_lo, win4, filter, 1);
  acc_lo = vdotq_laneq_s32(acc_lo, win8, filter, 2);

  // Outputs 4-7 use windows 4/8/12.
  int32x4_t acc_hi = vdotq_laneq_s32(correction, win4, filter, 0);
  acc_hi = vdotq_laneq_s32(acc_hi, win8, filter, 1);
  acc_hi = vdotq_laneq_s32(acc_hi, win12, filter, 2);

  const int16x4_t res_lo = vqrshrn_n_s32(acc_lo, FILTER_BITS);
  const int16x4_t res_hi = vqrshrn_n_s32(acc_hi, FILTER_BITS);
  return vqmovun_s16(vcombine_s16(res_lo, res_hi));
}
95
// Horizontal-only 12-tap convolution of an 8-bit w x h block using Neon
// dot-product instructions.
//   src          - source pointer; av1_convolve_x_sr_neon_dotprod has already
//                  moved it left by (taps / 2 - 1) columns.
//   x_filter_ptr - 12 int16 filter taps (12 values are read).
// From the loop structure: the filtering paths require h to be a multiple of
// 4, and w either <= 4 (one 4-wide column) or a multiple of 8. The caller
// sends w == 2 / h == 2 to the C implementation.
static INLINE void convolve_x_sr_12tap_neon_dotprod(
    const uint8_t *src, int src_stride, uint8_t *dst, int dst_stride, int w,
    int h, const int16_t *x_filter_ptr) {
  // Narrow the 12 taps (zero-padded to 16) to int8 for the dot product.
  const int16x8_t filter_0_7 = vld1q_s16(x_filter_ptr);
  const int16x4_t filter_8_11 = vld1_s16(x_filter_ptr + 8);
  const int16x8_t filter_8_15 = vcombine_s16(filter_8_11, vdup_n_s16(0));
  const int8x16_t filter =
      vcombine_s8(vmovn_s16(filter_0_7), vmovn_s16(filter_8_15));

  // Sum of (filter << FILTER_BITS) over all taps. This compensates for the
  // samples being biased by -128 (range_limit) before the dot product.
  // NOTE(review): this equals 128 * sum(filter) only when FILTER_BITS == 7
  // (FILTER_BITS is defined outside this file) - confirm.
  const int32_t correction_s32 =
      vaddvq_s32(vaddq_s32(vpaddlq_s16(vshlq_n_s16(filter_0_7, FILTER_BITS)),
                           vpaddlq_s16(vshlq_n_s16(filter_8_15, FILTER_BITS))));
  // A shim of 1 << (ROUND0_BITS - 1) enables us to use a single rounding right
  // shift by FILTER_BITS - instead of a first rounding right shift by
  // ROUND0_BITS, followed by second rounding right shift by FILTER_BITS -
  // ROUND0_BITS.
  int32x4_t correction = vdupq_n_s32(correction_s32 + (1 << (ROUND0_BITS - 1)));
  const uint8x16_t range_limit = vdupq_n_u8(128);
  const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);

  // Special case the following no-op filter as 128 won't fit into the
  // 8-bit signed dot-product instruction:
  // { 0, 0, 0, 0, 0, 128, 0, 0, 0, 0, 0, 0 }
  if (vgetq_lane_s16(filter_0_7, 5) == 128) {
    // Undo the horizontal offset in the calling function.
    src += 5;

    // Plain copy: one row at a time, 8 bytes loaded per step (only 4 are
    // stored when w == 4).
    do {
      const uint8_t *s = src;
      uint8_t *d = dst;
      int width = w;

      do {
        uint8x8_t d0 = vld1_u8(s);
        if (w == 4) {
          store_u8_4x1(d, d0);
        } else {
          vst1_u8(d, d0);
        }

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src += src_stride;
      dst += dst_stride;
    } while (--h != 0);
  } else {
    if (w <= 4) {
      // 4 columns x 4 rows per iteration.
      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

        int16x4_t d0 =
            convolve12_4_x(s0, filter, correction, range_limit, permute_tbl);
        int16x4_t d1 =
            convolve12_4_x(s1, filter, correction, range_limit, permute_tbl);
        int16x4_t d2 =
            convolve12_4_x(s2, filter, correction, range_limit, permute_tbl);
        int16x4_t d3 =
            convolve12_4_x(s3, filter, correction, range_limit, permute_tbl);

        uint8x8_t d01 = vqmovun_s16(vcombine_s16(d0, d1));
        uint8x8_t d23 = vqmovun_s16(vcombine_s16(d2, d3));

        store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
        store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);

        dst += 4 * dst_stride;
        src += 4 * src_stride;
        h -= 4;
      } while (h != 0);
    } else {
      // 8 columns x 4 rows per iteration. A second load at s + 4 provides
      // the samples needed for output columns 4-7 and taps 8-11.
      do {
        const uint8_t *s = src;
        uint8_t *d = dst;
        int width = w;

        do {
          uint8x16_t s0[2], s1[2], s2[2], s3[2];
          load_u8_16x4(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
          load_u8_16x4(s + 4, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);

          uint8x8_t d0 =
              convolve12_8_x(s0, filter, correction, range_limit, permute_tbl);
          uint8x8_t d1 =
              convolve12_8_x(s1, filter, correction, range_limit, permute_tbl);
          uint8x8_t d2 =
              convolve12_8_x(s2, filter, correction, range_limit, permute_tbl);
          uint8x8_t d3 =
              convolve12_8_x(s3, filter, correction, range_limit, permute_tbl);

          store_u8_8x4(d + 0 * dst_stride, dst_stride, d0, d1, d2, d3);

          s += 8;
          d += 8;
          width -= 8;
        } while (width != 0);
        src += 4 * src_stride;
        dst += 4 * dst_stride;
        h -= 4;
      } while (h != 0);
    }
  }
}
201
// Applies a 4-tap filter (int8 taps in bytes 0-3 of 'filter') to produce
// four horizontally adjacent outputs. Samples are re-centred by
// 'range_limit' for the signed dot product, and 'correction' seeds the
// accumulator to compensate. The 32-bit sums are narrowed (truncated) to
// 16 bits; the caller performs the final rounding shift and packing.
static INLINE int16x4_t convolve4_4_x(uint8x16_t samples, const int8x8_t filter,
                                      const int32x4_t correction,
                                      const uint8x16_t range_limit,
                                      const uint8x16_t permute_tbl) {
  const uint8x16_t centred_u8 = vsubq_u8(samples, range_limit);
  const int8x16_t windows =
      vqtbl1q_s8(vreinterpretq_s8_u8(centred_u8), permute_tbl);
  const int32x4_t acc = vdotq_lane_s32(correction, windows, filter, 0);
  return vmovn_s32(acc);
}
220
// Applies an 8-tap filter (halved int8 taps in 'filter') to produce eight
// horizontally adjacent unsigned 8-bit outputs from one 16-byte row.
//   samples     - 16 source bytes; outputs use bytes 0-14.
//   filter      - filter taps divided by 2 and narrowed to int8.
//   correction  - accumulator seed compensating for the range_limit bias
//                 (and carrying the caller's rounding shim).
//   range_limit - bias subtracted so samples fit the signed 8-bit dot product.
//   permute_tbl - the three rows of dot_prod_permute_tbl.
// Returns the sums rounded, shifted right by FILTER_BITS - 1 (the taps were
// halved) and saturated to [0, 255].
static INLINE uint8x8_t convolve8_8_x(uint8x16_t samples, const int8x8_t filter,
                                      const int32x4_t correction,
                                      const uint8x16_t range_limit,
                                      const uint8x16x3_t permute_tbl) {
  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  const int8x16_t clamped_samples =
      vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  int8x16_t permuted_samples[3];
  // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
  // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
  // { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);

  // Accumulate dot product into 'correction' to account for range clamp.
  int32x4_t sum[2];
  // First 4 output values: taps 0-3 on windows at 0.., taps 4-7 at 4...
  sum[0] = vdotq_lane_s32(correction, permuted_samples[0], filter, 0);
  sum[0] = vdotq_lane_s32(sum[0], permuted_samples[1], filter, 1);
  // Second 4 output values: the same taps applied 4 samples further along.
  sum[1] = vdotq_lane_s32(correction, permuted_samples[1], filter, 0);
  sum[1] = vdotq_lane_s32(sum[1], permuted_samples[2], filter, 1);

  // Narrow and re-pack.
  const int16x8_t sum_s16 = vcombine_s16(vmovn_s32(sum[0]), vmovn_s32(sum[1]));
  // We halved the convolution filter values so - 1 from the right shift.
  return vqrshrun_n_s16(sum_s16, FILTER_BITS - 1);
}
252
// Horizontal-only single-reference convolution of an 8-bit w x h block using
// Neon dot-product kernels. Dispatch:
//   - w == 2 or h == 2: delegated to av1_convolve_x_sr_c.
//   - filter_params_x->taps > 8: convolve_x_sr_12tap_neon_dotprod.
//   - w <= 4: 4-tap path using filter taps 2-5, 4 rows per iteration.
//   - otherwise: 8-tap path, 8 columns x 4 rows per iteration.
// On the <= 8-tap paths the filter values are halved before narrowing to
// int8, so the final shift is FILTER_BITS - 1 and the rounding shim is halved
// to match. From the loop structure, these paths need h to be a multiple of
// 4 and (for w > 4) w to be a multiple of 8.
// conv_params is only used by the C fallback.
void av1_convolve_x_sr_neon_dotprod(const uint8_t *src, int src_stride,
                                    uint8_t *dst, int dst_stride, int w, int h,
                                    const InterpFilterParams *filter_params_x,
                                    const int subpel_x_qn,
                                    ConvolveParams *conv_params) {
  if (w == 2 || h == 2) {
    av1_convolve_x_sr_c(src, src_stride, dst, dst_stride, w, h, filter_params_x,
                        subpel_x_qn, conv_params);
    return;
  }

  // Point src at the first sample covered by the filter's leftmost tap.
  const uint8_t horiz_offset = filter_params_x->taps / 2 - 1;
  src -= horiz_offset;

  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);

  if (filter_params_x->taps > 8) {
    convolve_x_sr_12tap_neon_dotprod(src, src_stride, dst, dst_stride, w, h,
                                     x_filter_ptr);
    return;
  }

  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);
  // Dot product constants.
  // Compensation for the -128 sample bias, using the halved filter values.
  const int32_t correction_s32 =
      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
  // This shim of (1 << ((ROUND0_BITS - 1) - 1) enables us to use a single
  // rounding right shift by FILTER_BITS - instead of a first rounding right
  // shift by ROUND0_BITS, followed by second rounding right shift by
  // FILTER_BITS - ROUND0_BITS.
  // The outermost -1 is needed because we will halve the filter values.
  const int32x4_t correction =
      vdupq_n_s32(correction_s32 + (1 << ((ROUND0_BITS - 1) - 1)));
  const uint8x16_t range_limit = vdupq_n_u8(128);

  if (w <= 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    // Skip the two leading zero taps of the 8-tap array.
    src += 2;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

      int16x4_t d0 =
          convolve4_4_x(s0, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d1 =
          convolve4_4_x(s1, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d2 =
          convolve4_4_x(s2, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d3 =
          convolve4_4_x(s3, x_filter, correction, range_limit, permute_tbl);

      // We halved the convolution filter values so - 1 from the right shift.
      uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS - 1);
      uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS - 1);

      store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
      store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    do {
      int width = w;
      const uint8_t *s = src;
      uint8_t *d = dst;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint8x8_t d0 =
            convolve8_8_x(s0, x_filter, correction, range_limit, permute_tbl);
        uint8x8_t d1 =
            convolve8_8_x(s1, x_filter, correction, range_limit, permute_tbl);
        uint8x8_t d2 =
            convolve8_8_x(s2, x_filter, correction, range_limit, permute_tbl);
        uint8x8_t d3 =
            convolve8_8_x(s3, x_filter, correction, range_limit, permute_tbl);

        store_u8_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);
  }
}
357
// Horizontal pass of the 12-tap 2D convolution for four outputs. Samples are
// re-centred by 'range_limit'. 'correction' seeds the accumulator (the
// caller in this file folds the bias compensation, the intermediate offset
// and the rounding shim into it). The result is arithmetically shifted right
// by ROUND0_BITS without rounding and truncated to 16 bits.
static INLINE int16x4_t convolve12_4_2d_h(uint8x16_t samples,
                                          const int8x16_t filters,
                                          const int32x4_t correction,
                                          const uint8x16_t range_limit,
                                          const uint8x16x3_t permute_tbl) {
  const int8x16_t centred = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  int32x4_t acc = correction;
  acc = vdotq_laneq_s32(acc, vqtbl1q_s8(centred, permute_tbl.val[0]), filters,
                        0);
  acc = vdotq_laneq_s32(acc, vqtbl1q_s8(centred, permute_tbl.val[1]), filters,
                        1);
  acc = vdotq_laneq_s32(acc, vqtbl1q_s8(centred, permute_tbl.val[2]), filters,
                        2);

  return vshrn_n_s32(acc, ROUND0_BITS);
}
386
// Horizontal pass of the 12-tap 2D convolution for eight outputs.
// samples[1] provides the window for outputs 4-7 / taps 8-11; the callers in
// this file load it 4 bytes after samples[0]. Both halves are seeded with
// 'correction' and shifted right (non-rounding) by ROUND0_BITS.
static INLINE int16x8_t convolve12_8_2d_h(uint8x16_t samples[2],
                                          const int8x16_t filters,
                                          const int32x4_t correction,
                                          const uint8x16_t range_limit,
                                          const uint8x16x3_t permute_tbl) {
  const int8x16_t first =
      vreinterpretq_s8_u8(vsubq_u8(samples[0], range_limit));
  const int8x16_t second =
      vreinterpretq_s8_u8(vsubq_u8(samples[1], range_limit));

  const int8x16_t win0 = vqtbl1q_s8(first, permute_tbl.val[0]);
  const int8x16_t win4 = vqtbl1q_s8(first, permute_tbl.val[1]);
  const int8x16_t win8 = vqtbl1q_s8(first, permute_tbl.val[2]);
  const int8x16_t win12 = vqtbl1q_s8(second, permute_tbl.val[2]);

  // Outputs 0-3.
  int32x4_t left = vdotq_laneq_s32(correction, win0, filters, 0);
  left = vdotq_laneq_s32(left, win4, filters, 1);
  left = vdotq_laneq_s32(left, win8, filters, 2);

  // Outputs 4-7.
  int32x4_t right = vdotq_laneq_s32(correction, win4, filters, 0);
  right = vdotq_laneq_s32(right, win8, filters, 1);
  right = vdotq_laneq_s32(right, win12, filters, 2);

  const int16x4_t out_left = vshrn_n_s32(left, ROUND0_BITS);
  const int16x4_t out_right = vshrn_n_s32(right, ROUND0_BITS);
  return vcombine_s16(out_left, out_right);
}
423
// Horizontal (first) pass of the 12-tap 2D convolution. It writes h rows of
// 16-bit intermediate values into dst_ptr, using a row stride of dst_stride
// elements.
//   src_ptr      - source already offset up/left by the caller.
//   x_filter_0_7 / x_filter_8_11 - the 12 int16 horizontal taps.
// Each intermediate value is
//   (sum + (1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)))
//       >> ROUND0_BITS,
// with bd == 8.
// Rows are processed 4 at a time while more than 4 remain; the last 1-4 rows
// are done one at a time. With h > 0 this terminates for any h.
static INLINE void convolve_2d_sr_horiz_12tap_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, int16_t *dst_ptr,
    const int dst_stride, int w, int h, const int16x8_t x_filter_0_7,
    const int16x4_t x_filter_8_11) {
  const int bd = 8;

  // Special case the following no-op filter as 128 won't fit into the 8-bit
  // signed dot-product instruction:
  // { 0, 0, 0, 0, 0, 128, 0, 0, 0, 0, 0, 0 }
  if (vgetq_lane_s16(x_filter_0_7, 5) == 128) {
    // Intermediate = (src + (1 << (bd - 1))) << (FILTER_BITS - ROUND0_BITS).
    // This is intended to match the general path for a single tap of 128.
    // NOTE(review): the equivalence assumes 128 == 1 << FILTER_BITS - confirm.
    const uint16x8_t horiz_const = vdupq_n_u16((1 << (bd - 1)));
    // Undo the horizontal offset in the calling function.
    src_ptr += 5;

    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x8_t s0 = vld1_u8(s);
        uint16x8_t d0 = vaddw_u8(horiz_const, s0);
        d0 = vshlq_n_u16(d0, FILTER_BITS - ROUND0_BITS);
        // Store 8 elements to avoid additional branches. This is safe if the
        // actual block width is < 8 because the intermediate buffer is large
        // enough to accommodate 128x128 blocks.
        vst1q_s16(d, vreinterpretq_s16_u16(d0));

        d += 8;
        s += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);

  } else {
    // Narrow filter values to 8-bit.
    const int16x8x2_t x_filter_s16 = {
      { x_filter_0_7, vcombine_s16(x_filter_8_11, vdup_n_s16(0)) }
    };
    const int8x16_t x_filter = vcombine_s8(vmovn_s16(x_filter_s16.val[0]),
                                           vmovn_s16(x_filter_s16.val[1]));

    // This shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts
    // - which are generally faster than rounding shifts on modern CPUs.
    const int32_t horiz_const =
        ((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));
    // Dot product constants.
    // The shift by 7 multiplies each tap by 128 (== range_limit), giving
    // 128 * sum(filter) to undo the -128 bias applied to every sample.
    const int32x4_t correct_tmp =
        vaddq_s32(vpaddlq_s16(vshlq_n_s16(x_filter_s16.val[0], 7)),
                  vpaddlq_s16(vshlq_n_s16(x_filter_s16.val[1], 7)));
    const int32x4_t correction =
        vdupq_n_s32(vaddvq_s32(correct_tmp) + horiz_const);
    const uint8x16_t range_limit = vdupq_n_u8(128);
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);

    if (w <= 4) {
      // Blocks of 4 rows while more than 4 rows remain...
      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

        int16x4_t d0 = convolve12_4_2d_h(s0, x_filter, correction, range_limit,
                                         permute_tbl);
        int16x4_t d1 = convolve12_4_2d_h(s1, x_filter, correction, range_limit,
                                         permute_tbl);
        int16x4_t d2 = convolve12_4_2d_h(s2, x_filter, correction, range_limit,
                                         permute_tbl);
        int16x4_t d3 = convolve12_4_2d_h(s3, x_filter, correction, range_limit,
                                         permute_tbl);

        store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

        src_ptr += 4 * src_stride;
        dst_ptr += 4 * dst_stride;
        h -= 4;
      } while (h > 4);

      // ...then the remaining 1-4 rows individually.
      do {
        uint8x16_t s0 = vld1q_u8(src_ptr);
        int16x4_t d0 = convolve12_4_2d_h(s0, x_filter, correction, range_limit,
                                         permute_tbl);
        vst1_s16(dst_ptr, d0);

        src_ptr += src_stride;
        dst_ptr += dst_stride;
      } while (--h != 0);

    } else {
      // 8 columns x 4 rows per step while more than 4 rows remain. The second
      // load at s + 4 supplies samples for output columns 4-7 / taps 8-11.
      do {
        const uint8_t *s = src_ptr;
        int16_t *d = dst_ptr;
        int width = w;

        do {
          uint8x16_t s0[2], s1[2], s2[2], s3[2];
          load_u8_16x4(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
          load_u8_16x4(s + 4, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);

          int16x8_t d0 = convolve12_8_2d_h(s0, x_filter, correction,
                                           range_limit, permute_tbl);
          int16x8_t d1 = convolve12_8_2d_h(s1, x_filter, correction,
                                           range_limit, permute_tbl);
          int16x8_t d2 = convolve12_8_2d_h(s2, x_filter, correction,
                                           range_limit, permute_tbl);
          int16x8_t d3 = convolve12_8_2d_h(s3, x_filter, correction,
                                           range_limit, permute_tbl);

          store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

          s += 8;
          d += 8;
          width -= 8;
        } while (width != 0);
        src_ptr += 4 * src_stride;
        dst_ptr += 4 * dst_stride;
        h -= 4;
      } while (h > 4);

      // Remaining 1-4 rows, one row at a time.
      do {
        const uint8_t *s = src_ptr;
        int16_t *d = dst_ptr;
        int width = w;

        do {
          uint8x16_t s0[2];
          s0[0] = vld1q_u8(s);
          s0[1] = vld1q_u8(s + 4);
          int16x8_t d0 = convolve12_8_2d_h(s0, x_filter, correction,
                                           range_limit, permute_tbl);
          vst1q_s16(d, d0);

          s += 8;
          d += 8;
          width -= 8;
        } while (width != 0);
        src_ptr += src_stride;
        dst_ptr += dst_stride;
      } while (--h != 0);
    }
  }
}
566
// Horizontal pass of the 2D convolution for four outputs with a 4-tap filter
// (halved int8 taps in bytes 0-3 of 'filters'). 'correction' seeds the
// accumulator. Because the taps were halved, the non-rounding shift is one
// bit less than ROUND0_BITS.
static INLINE int16x4_t convolve4_4_2d_h(uint8x16_t samples,
                                         const int8x8_t filters,
                                         const int32x4_t correction,
                                         const uint8x16_t range_limit,
                                         const uint8x16_t permute_tbl) {
  const uint8x16_t centred_u8 = vsubq_u8(samples, range_limit);
  const int8x16_t windows =
      vqtbl1q_s8(vreinterpretq_s8_u8(centred_u8), permute_tbl);
  const int32x4_t acc = vdotq_lane_s32(correction, windows, filters, 0);
  return vshrn_n_s32(acc, ROUND0_BITS - 1);
}
586
// Horizontal pass of the 2D convolution for eight outputs with an 8-tap filter
// (halved int8 taps in 'filters'). Taps 0-3 are dotted with lane 0 and taps
// 4-7 with lane 1. Both 4-output halves are seeded with 'correction'. The
// non-rounding shift is ROUND0_BITS - 1 to account for the halved taps.
static INLINE int16x8_t convolve8_8_2d_h(uint8x16_t samples,
                                         const int8x8_t filters,
                                         const int32x4_t correction,
                                         const uint8x16_t range_limit,
                                         const uint8x16x3_t permute_tbl) {
  const int8x16_t src_s8 = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  const int8x16_t win0 = vqtbl1q_s8(src_s8, permute_tbl.val[0]);
  const int8x16_t win4 = vqtbl1q_s8(src_s8, permute_tbl.val[1]);
  const int8x16_t win8 = vqtbl1q_s8(src_s8, permute_tbl.val[2]);

  const int32x4_t acc_lo = vdotq_lane_s32(
      vdotq_lane_s32(correction, win0, filters, 0), win4, filters, 1);
  const int32x4_t acc_hi = vdotq_lane_s32(
      vdotq_lane_s32(correction, win4, filters, 0), win8, filters, 1);

  return vcombine_s16(vshrn_n_s32(acc_lo, ROUND0_BITS - 1),
                      vshrn_n_s32(acc_hi, ROUND0_BITS - 1));
}
619
// Horizontal (first) pass of the <= 8-tap 2D convolution. It writes im_h rows
// of 16-bit intermediate values into im_block, using a row stride of
// im_stride elements.
//   src          - source already offset up/left by the caller.
//   x_filter_ptr - 8 int16 taps; w <= 4 uses only taps 2-5.
// Filter values are halved before narrowing to int8, so the offsets and the
// shift below are each reduced by one bit. Rows are processed 4 at a time
// while more than 4 remain, then the last 1-4 rows one at a time.
static INLINE void convolve_2d_sr_horiz_neon_dotprod(
    const uint8_t *src, int src_stride, int16_t *im_block, int im_stride, int w,
    int im_h, const int16_t *x_filter_ptr) {
  const int bd = 8;
  // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // The outermost -1 is needed because we halved the filter values.
  const int32_t horiz_const =
      ((1 << (bd + FILTER_BITS - 2)) + (1 << ((ROUND0_BITS - 1) - 1)));
  // Dot product constants.
  // Compensation for the -128 sample bias, using the halved filter values.
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);
  const int32_t correction_s32 =
      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
  const int32x4_t correction = vdupq_n_s32(correction_s32 + horiz_const);
  const uint8x16_t range_limit = vdupq_n_u8(128);

  const uint8_t *src_ptr = src;
  int16_t *dst_ptr = im_block;
  int dst_stride = im_stride;
  int height = im_h;

  if (w <= 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    // Skip the two leading zero taps of the 8-tap array.
    src_ptr += 2;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      int16x4_t d0 =
          convolve4_4_2d_h(s0, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d1 =
          convolve4_4_2d_h(s1, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d2 =
          convolve4_4_2d_h(s2, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d3 =
          convolve4_4_2d_h(s3, x_filter, correction, range_limit, permute_tbl);

      store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    // Remaining 1-4 rows, one row at a time.
    do {
      uint8x16_t s0 = vld1q_u8(src_ptr);
      int16x4_t d0 =
          convolve4_4_2d_h(s0, x_filter, correction, range_limit, permute_tbl);
      vst1_s16(dst_ptr, d0);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    // 8 columns x 4 rows per step while more than 4 rows remain.
    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d1 = convolve8_8_2d_h(s1, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d2 = convolve8_8_2d_h(s2, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d3 = convolve8_8_2d_h(s3, x_filter, correction, range_limit,
                                        permute_tbl);

        store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    // Remaining 1-4 rows, one row at a time.
    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0 = vld1q_u8(s);
        int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, correction, range_limit,
                                        permute_tbl);
        vst1q_s16(d, d0);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  }
}
733
// Two-dimensional single-reference convolution of an 8-bit w x h block.
// A dot-product horizontal pass writes 16-bit intermediates into a stack
// buffer; a Neon vertical pass then produces the 8-bit output.
// Blocks with w == 2 or h == 2 are handed to av1_convolve_2d_sr_c. Vertical
// filters shorter than 6 taps are run through the 6-tap vertical kernel. A
// horizontal filter with more than 8 taps selects the 12-tap horizontal and
// vertical kernels.
void av1_convolve_2d_sr_neon_dotprod(const uint8_t *src, int src_stride,
                                     uint8_t *dst, int dst_stride, int w, int h,
                                     const InterpFilterParams *filter_params_x,
                                     const InterpFilterParams *filter_params_y,
                                     const int subpel_x_qn,
                                     const int subpel_y_qn,
                                     ConvolveParams *conv_params) {
  if (w == 2 || h == 2) {
    av1_convolve_2d_sr_c(src, src_stride, dst, dst_stride, w, h,
                         filter_params_x, filter_params_y, subpel_x_qn,
                         subpel_y_qn, conv_params);
    return;
  }

  const int y_taps = get_filter_tap(filter_params_y, subpel_y_qn);
  const int y_taps_eff = (y_taps < 6) ? 6 : y_taps;
  // The horizontal pass must produce the extra rows the vertical filter reads.
  const int intermediate_h = h + y_taps_eff - 1;
  const int intermediate_stride = MAX_SB_SIZE;
  const int rows_above = y_taps_eff / 2 - 1;
  const int cols_left = filter_params_x->taps / 2 - 1;
  const uint8_t *src_origin = src - rows_above * src_stride - cols_left;

  const int16_t *x_kernel = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16_t *y_kernel = av1_get_interp_filter_subpel_kernel(
      filter_params_y, subpel_y_qn & SUBPEL_MASK);

  if (filter_params_x->taps <= 8) {
    DECLARE_ALIGNED(16, int16_t,
                    im_block[(MAX_SB_SIZE + SUBPEL_TAPS - 1) * MAX_SB_SIZE]);

    convolve_2d_sr_horiz_neon_dotprod(src_origin, src_stride, im_block,
                                      intermediate_stride, w, intermediate_h,
                                      x_kernel);

    const int16x8_t y_filter = vld1q_s16(y_kernel);
    if (y_taps_eff > 6) {
      convolve_2d_sr_vert_8tap_neon(im_block, intermediate_stride, dst,
                                    dst_stride, w, h, y_filter);
    } else {
      convolve_2d_sr_vert_6tap_neon(im_block, intermediate_stride, dst,
                                    dst_stride, w, h, y_filter);
    }
    return;
  }

  // 12-tap path.
  DECLARE_ALIGNED(16, int16_t,
                  im_block[(MAX_SB_SIZE + MAX_FILTER_TAP - 1) * MAX_SB_SIZE]);

  const int16x8_t x_taps_lo = vld1q_s16(x_kernel);
  const int16x4_t x_taps_hi = vld1_s16(x_kernel + 8);
  const int16x8_t y_taps_lo = vld1q_s16(y_kernel);
  const int16x4_t y_taps_hi = vld1_s16(y_kernel + 8);

  convolve_2d_sr_horiz_12tap_neon_dotprod(src_origin, src_stride, im_block,
                                          intermediate_stride, w,
                                          intermediate_h, x_taps_lo, x_taps_hi);

  convolve_2d_sr_vert_12tap_neon(im_block, intermediate_stride, dst,
                                 dst_stride, w, h, y_taps_lo, y_taps_hi);
}
794