1 /*
2 * Copyright (c) 2018, Alliance for Open Media. All rights reserved
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include <arm_neon.h>
13 #include <assert.h>
14
15 #include "config/aom_config.h"
16 #include "config/av1_rtcd.h"
17
18 #include "aom_dsp/txfm_common.h"
19 #include "aom_dsp/arm/mem_neon.h"
20 #include "aom_dsp/arm/transpose_neon.h"
21 #include "aom_ports/mem.h"
22 #include "av1/common/common.h"
23 #include "av1/common/arm/convolve_neon.h"
24
/* Produces one vector of horizontally filtered samples and assigns it to
 * `res`.
 *
 * The inputs are widened to 16 bits and folded in pairs before filtering:
 * (t0 + t1), (t2 + t3), (t4 + t5), with t6 passed through on its own. The four
 * intermediate vectors are then handed to wiener_convolve8_horiz_8x8().
 * NOTE(review): the pair-wise folding presumably relies on the filter taps
 * being symmetric around the un-paired tap -- confirm against
 * wiener_convolve8_horiz_8x8().
 *
 * This macro is not hygienic. It reads and writes the following names, which
 * must already exist at the expansion site:
 *   res0, res1, res2, res3   (int16x8_t scratch, overwritten)
 *   filter_x_tmp, bd, conv_params
 *
 * The expansion is a sequence of statements that already ends in ';', so it
 * is invoked without a trailing semicolon and must not be used as the
 * unbraced body of an if/else/loop.
 */
#define HORZ_FILTERING_CORE(t0, t1, t2, t3, t4, t5, t6, res)                 \
  res0 = vreinterpretq_s16_u16(vaddl_u8(t0, t1));                            \
  res1 = vreinterpretq_s16_u16(vaddl_u8(t2, t3));                            \
  res2 = vreinterpretq_s16_u16(vaddl_u8(t4, t5));                            \
  res3 = vreinterpretq_s16_u16(vmovl_u8(t6));                                \
  res = wiener_convolve8_horiz_8x8(res0, res1, res2, res3, filter_x_tmp, bd, \
                                   conv_params->round_0);
32
/* Vertically filters an 8-pixel-wide column one output row per iteration.
 *
 * Each iteration loads the next intermediate row into s7, filters the window
 * s0..s6 with wiener_convolve8_vert_4x8(), stores the resulting 8 bytes at
 * `d`, advances `d` by dst_stride, and then slides the window down by one row
 * (s0 <- s1, ..., s6 <- s7).
 *
 * This macro is not hygienic. It uses the following names from the expansion
 * site:
 *   s, src_stride            (read pointer into the intermediate buffer and
 *                             its stride; `s` is advanced)
 *   s0..s7, t0               (vector scratch; s0..s6 must hold the current
 *                             window on entry)
 *   d, dst_stride, dst_tmp_ptr
 *   filter_y_tmp, bd, conv_params
 *   height                   (decremented once per row)
 *
 * The loop body runs before `height` is tested, so one row is always loaded
 * and stored. Only expand this macro when height >= 1.
 */
#define PROCESS_ROW_FOR_VERTICAL_FILTER                                      \
  __builtin_prefetch(dst_tmp_ptr + 0 * dst_stride);                          \
                                                                             \
  do {                                                                       \
    s7 = vld1q_s16(s);                                                       \
    s += src_stride;                                                         \
                                                                             \
    t0 = wiener_convolve8_vert_4x8(s0, s1, s2, s3, s4, s5, s6, filter_y_tmp, \
                                   bd, conv_params->round_1);                \
    vst1_u8(d, t0);                                                          \
    d += dst_stride;                                                         \
                                                                             \
    s0 = s1;                                                                 \
    s1 = s2;                                                                 \
    s2 = s3;                                                                 \
    s3 = s4;                                                                 \
    s4 = s5;                                                                 \
    s5 = s6;                                                                 \
    s6 = s7;                                                                 \
    height--;                                                                \
  } while (height > 0);
54
/* Horizontally filters `height` rows, one row at a time, eight output samples
 * per step.
 *
 *   dst_ptr      first output row (16-bit samples); advanced by dst_stride
 *                elements per row.
 *   filter_x     tap array forwarded unchanged to
 *                wiener_convolve8_horiz_8x8().
 *   src_ptr      first input row; advanced by src_stride bytes per row.
 *   round0_bits  rounding amount forwarded to wiener_convolve8_horiz_8x8().
 *   w            number of output samples per row; consumed 8 at a time.
 *   height       number of rows. At least one row is always processed.
 *   bd           bit depth forwarded to wiener_convolve8_horiz_8x8().
 *
 * For every group of 8 outputs the routine reads 16 consecutive source bytes
 * (the 8 already held from the previous step plus the next 8).
 */
static INLINE void process_row_for_horz_filtering(
    uint16_t *dst_ptr, int16_t *filter_x, const uint8_t *src_ptr,
    ptrdiff_t src_stride, ptrdiff_t dst_stride, int round0_bits, int w,
    int height, int bd) {
  int rows_left = height;

  do {
    const uint8_t *in = src_ptr;
    uint16_t *out = dst_ptr;
    int cols_left = w;

    __builtin_prefetch(in);
    uint8x8_t cur = vld1_u8(in);  // a0 .. a7
    in += 8;
    __builtin_prefetch(out);

    do {
      const uint8x8_t next = vld1_u8(in);  // a8 .. a15

      // Fold the window pair-wise around offset 3: (0,6), (1,5), (2,4), 3.
      const int16x8_t pair06 =
          vreinterpretq_s16_u16(vaddl_u8(cur, vext_u8(cur, next, 6)));
      const int16x8_t pair15 = vreinterpretq_s16_u16(
          vaddl_u8(vext_u8(cur, next, 1), vext_u8(cur, next, 5)));
      const int16x8_t pair24 = vreinterpretq_s16_u16(
          vaddl_u8(vext_u8(cur, next, 2), vext_u8(cur, next, 4)));
      const int16x8_t centre =
          vreinterpretq_s16_u16(vmovl_u8(vext_u8(cur, next, 3)));

      const uint16x8_t filtered = wiener_convolve8_horiz_8x8(
          pair06, pair15, pair24, centre, filter_x, bd, round0_bits);
      vst1q_u16(out, filtered);

      // The bytes just loaded become the left half of the next window.
      cur = next;
      in += 8;
      out += 8;
      cols_left -= 8;
    } while (cols_left > 0);

    src_ptr += src_stride;
    dst_ptr += dst_stride;
  } while (--rows_left > 0);
}
102
/* Wiener filter 2D
   Apply the horizontal filter and store the result in a temporary buffer.
   The vertical filter then reads that buffer and writes the final pixel
   values to dst.
 */
/* Separable 2D Wiener filter for 8-bit pixels (NEON).
 *
 * Stage 1 filters h + SUBPEL_TAPS - 1 source rows horizontally into a 16-bit
 * intermediate buffer whose row stride is MAX_SB_SIZE. Stage 2 filters that
 * buffer vertically and stores 8-bit results to dst.
 *
 *   src, src_stride   input pixels. Reading starts center_tap rows above and
 *                     center_tap columns to the left of `src`.
 *   dst, dst_stride   output pixels, w x h.
 *   filter_x/filter_y 8-entry tap arrays; entry 7 must be 0 and only the
 *                     first 7 entries are used. (1 << FILTER_BITS) is added
 *                     to tap 3 of a local copy before filtering.
 *   x_step_q4,
 *   y_step_q4         must both be 16 (checked by assert only).
 *   w                 width; positive multiple of 8, at most MAX_SB_SIZE.
 *   h                 height; positive, at most MAX_SB_SIZE.
 *   conv_params       supplies round_0 (stage 1) and round_1 (stage 2);
 *                     round_0 must be >= 1.
 *
 * The HORZ_FILTERING_CORE and PROCESS_ROW_FOR_VERTICAL_FILTER macros expand
 * against local names in this function (res0..res3, filter_x_tmp,
 * filter_y_tmp, bd, conv_params, s, s0..s7, t0, d, height, src_stride,
 * dst_stride, dst_tmp_ptr); keep those names intact when editing.
 */
void av1_wiener_convolve_add_src_neon(const uint8_t *src, ptrdiff_t src_stride,
                                      uint8_t *dst, ptrdiff_t dst_stride,
                                      const int16_t *filter_x, int x_step_q4,
                                      const int16_t *filter_y, int y_step_q4,
                                      int w, int h,
                                      const ConvolveParams *conv_params) {
  uint8_t *d;
  const uint8_t *src_ptr;
  uint16_t *dst_ptr;
  (void)x_step_q4;
  (void)y_step_q4;

  int height;
  const int bd = 8;
  // Number of rows the horizontal stage has to produce.
  const int intermediate_height = h + SUBPEL_TAPS - 1;
  const int center_tap = ((SUBPEL_TAPS - 1) / 2);
  int16_t filter_x_tmp[7], filter_y_tmp[7];

  DECLARE_ALIGNED(16, uint16_t,
                  temp[(MAX_SB_SIZE + HORIZ_EXTRA_ROWS) * MAX_SB_SIZE]);

  assert(x_step_q4 == 16 && y_step_q4 == 16);
  assert(!(w % 8));

  // The column and row loops below are do/while loops, so they always run at
  // least once; an empty block is not supported.
  assert(w > 0 && h > 0);
  assert(w <= MAX_SB_SIZE);
  assert(h <= MAX_SB_SIZE);

  assert(filter_x[7] == 0);
  assert(filter_y[7] == 0);

  /* assumption of horizontal filtering output will not exceed 15 bit.
     ((bd) + 1 + FILTER_BITS - conv_params->round_0) <= 15
     16 - conv_params->round_0 <= 15 -- (conv_params->round_0) >= 1
   */
  assert((conv_params->round_0) >= 1);

  // Copy exactly the 7 taps that are used. The size comes from the
  // destination arrays rather than from an unrelated constant.
  memcpy(filter_x_tmp, filter_x, sizeof(filter_x_tmp));
  memcpy(filter_y_tmp, filter_y, sizeof(filter_y_tmp));

  filter_x_tmp[3] += (1 << FILTER_BITS);
  filter_y_tmp[3] += (1 << FILTER_BITS);

  src_ptr = src - center_tap * src_stride - center_tap;
  dst_ptr = temp;
  height = intermediate_height;

  // For aarch_64.
#if defined(__aarch64__)
  uint16_t *d_tmp;
  int width;

  // Start of horizontal filtering: 8 rows at a time, working on transposed
  // 8x8 tiles so that every vector holds one source column.
  {
    uint16x8_t res4, res5, res6, res7, res8, res9, res10, res11;
    uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7;

    while (height > 7) {
      const uint8_t *s;

      __builtin_prefetch(src_ptr + 0 * src_stride);
      __builtin_prefetch(src_ptr + 1 * src_stride);
      __builtin_prefetch(src_ptr + 2 * src_stride);
      __builtin_prefetch(src_ptr + 3 * src_stride);
      __builtin_prefetch(src_ptr + 4 * src_stride);
      __builtin_prefetch(src_ptr + 5 * src_stride);
      __builtin_prefetch(src_ptr + 6 * src_stride);
      __builtin_prefetch(src_ptr + 7 * src_stride);

      load_u8_8x8(src_ptr, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
      transpose_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);

      s = src_ptr + 7;
      d_tmp = dst_ptr;
      width = w;

      // dst_ptr walks the intermediate buffer, whose row stride is
      // MAX_SB_SIZE (not the stride of the destination image).
      __builtin_prefetch(dst_ptr + 0 * MAX_SB_SIZE);
      __builtin_prefetch(dst_ptr + 1 * MAX_SB_SIZE);
      __builtin_prefetch(dst_ptr + 2 * MAX_SB_SIZE);
      __builtin_prefetch(dst_ptr + 3 * MAX_SB_SIZE);
      __builtin_prefetch(dst_ptr + 4 * MAX_SB_SIZE);
      __builtin_prefetch(dst_ptr + 5 * MAX_SB_SIZE);
      __builtin_prefetch(dst_ptr + 6 * MAX_SB_SIZE);
      __builtin_prefetch(dst_ptr + 7 * MAX_SB_SIZE);

      do {
        int16x8_t res0, res1, res2, res3;
        uint8x8_t t8, t9, t10, t11, t12, t13, t14;
        load_u8_8x8(s, src_stride, &t7, &t8, &t9, &t10, &t11, &t12, &t13, &t14);
        transpose_u8_8x8(&t7, &t8, &t9, &t10, &t11, &t12, &t13, &t14);

        // The macro already ends in ';' -- no semicolon at the call site.
        HORZ_FILTERING_CORE(t0, t6, t1, t5, t2, t4, t3, res4)
        HORZ_FILTERING_CORE(t1, t7, t2, t6, t3, t5, t4, res5)
        HORZ_FILTERING_CORE(t2, t8, t3, t7, t4, t6, t5, res6)
        HORZ_FILTERING_CORE(t3, t9, t4, t8, t5, t7, t6, res7)
        HORZ_FILTERING_CORE(t4, t10, t5, t9, t6, t8, t7, res8)
        HORZ_FILTERING_CORE(t5, t11, t6, t10, t7, t9, t8, res9)
        HORZ_FILTERING_CORE(t6, t12, t7, t11, t8, t10, t9, res10)
        HORZ_FILTERING_CORE(t7, t13, t8, t12, t9, t11, t10, res11)

        transpose_u16_8x8(&res4, &res5, &res6, &res7, &res8, &res9, &res10,
                          &res11);
        store_u16_8x8(d_tmp, MAX_SB_SIZE, res4, res5, res6, res7, res8, res9,
                      res10, res11);

        // Columns 8..14 of this tile are columns 0..6 of the next one.
        t0 = t8;
        t1 = t9;
        t2 = t10;
        t3 = t11;
        t4 = t12;
        t5 = t13;
        t6 = t14;
        s += 8;
        d_tmp += 8;
        width -= 8;
      } while (width > 0);

      src_ptr += 8 * src_stride;
      dst_ptr += 8 * MAX_SB_SIZE;
      height -= 8;
    }
  }

  // Process the remaining (fewer than 8) rows for horizontal filtering.
  if (height > 0)
    process_row_for_horz_filtering(dst_ptr, filter_x_tmp, src_ptr, src_stride,
                                   MAX_SB_SIZE, conv_params->round_0, w, height,
                                   bd);

  // Start of vertical filtering.
  {
    int16_t *src_tmp_ptr, *s;
    uint8_t *dst_tmp_ptr;
    src_tmp_ptr = (int16_t *)temp;
    dst_tmp_ptr = dst;
    // From here on src_stride describes the intermediate buffer. The
    // PROCESS_ROW_FOR_VERTICAL_FILTER macro reads it under this name.
    src_stride = MAX_SB_SIZE;

    do {
      int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;
      uint8x8_t t0;
      s = src_tmp_ptr;
      s0 = vld1q_s16(s);
      s += src_stride;
      s1 = vld1q_s16(s);
      s += src_stride;
      s2 = vld1q_s16(s);
      s += src_stride;
      s3 = vld1q_s16(s);
      s += src_stride;
      s4 = vld1q_s16(s);
      s += src_stride;
      s5 = vld1q_s16(s);
      s += src_stride;
      s6 = vld1q_s16(s);
      s += src_stride;
      d = dst_tmp_ptr;
      height = h;

      // Four output rows per iteration. This is a pre-tested loop so that
      // blocks with fewer than four rows are left entirely to the
      // single-row tail below instead of storing past row h - 1.
      while (height > 3) {
        int16x8_t s8, s9, s10;
        uint8x8_t t1, t2, t3;
        __builtin_prefetch(d + 0 * dst_stride);
        __builtin_prefetch(d + 1 * dst_stride);
        __builtin_prefetch(d + 2 * dst_stride);
        __builtin_prefetch(d + 3 * dst_stride);

        s7 = vld1q_s16(s);
        s += src_stride;
        s8 = vld1q_s16(s);
        s += src_stride;
        s9 = vld1q_s16(s);
        s += src_stride;
        s10 = vld1q_s16(s);
        s += src_stride;

        t0 = wiener_convolve8_vert_4x8(s0, s1, s2, s3, s4, s5, s6, filter_y_tmp,
                                       bd, conv_params->round_1);
        t1 = wiener_convolve8_vert_4x8(s1, s2, s3, s4, s5, s6, s7, filter_y_tmp,
                                       bd, conv_params->round_1);
        t2 = wiener_convolve8_vert_4x8(s2, s3, s4, s5, s6, s7, s8, filter_y_tmp,
                                       bd, conv_params->round_1);
        t3 = wiener_convolve8_vert_4x8(s3, s4, s5, s6, s7, s8, s9, filter_y_tmp,
                                       bd, conv_params->round_1);

        vst1_u8(d, t0);
        d += dst_stride;
        vst1_u8(d, t1);
        d += dst_stride;
        vst1_u8(d, t2);
        d += dst_stride;
        vst1_u8(d, t3);
        d += dst_stride;

        // Slide the 7-row window down by four rows.
        s0 = s4;
        s1 = s5;
        s2 = s6;
        s3 = s7;
        s4 = s8;
        s5 = s9;
        s6 = s10;
        height -= 4;
      }

      // 1..3 leftover rows. The macro always processes at least one row, so
      // it must only be entered when rows remain.
      if (height > 0) {
        PROCESS_ROW_FOR_VERTICAL_FILTER
      }
      src_tmp_ptr += 8;
      dst_tmp_ptr += 8;
      w -= 8;
    } while (w > 0);
  }
#else
  // Start of horizontal filtering.
  process_row_for_horz_filtering(dst_ptr, filter_x_tmp, src_ptr, src_stride,
                                 MAX_SB_SIZE, conv_params->round_0, w, height,
                                 bd);

  // Start of vertical filtering.
  {
    int16_t *src_tmp_ptr, *s;
    uint8_t *dst_tmp_ptr;
    src_tmp_ptr = (int16_t *)temp;
    dst_tmp_ptr = dst;
    // From here on src_stride describes the intermediate buffer. The
    // PROCESS_ROW_FOR_VERTICAL_FILTER macro reads it under this name.
    src_stride = MAX_SB_SIZE;

    do {
      uint8x8_t t0;
      int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;
      s = src_tmp_ptr;
      s0 = vld1q_s16(s);
      s += src_stride;
      s1 = vld1q_s16(s);
      s += src_stride;
      s2 = vld1q_s16(s);
      s += src_stride;
      s3 = vld1q_s16(s);
      s += src_stride;
      s4 = vld1q_s16(s);
      s += src_stride;
      s5 = vld1q_s16(s);
      s += src_stride;
      s6 = vld1q_s16(s);
      s += src_stride;
      d = dst_tmp_ptr;
      height = h;
      PROCESS_ROW_FOR_VERTICAL_FILTER

      src_tmp_ptr += 8;
      dst_tmp_ptr += 8;

      w -= 8;
    } while (w > 0);
  }
#endif
}
365