/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom_dsp/txfm_common.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/transpose_neon.h"
#include "aom_ports/mem.h"
#include "av1/common/common.h"
#include "av1/common/arm/convolve_neon.h"

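// Core of the horizontal Wiener filtering for one transposed 8x8 tile. The
// 7-tap filter is symmetric, so the caller passes the pixel pairs that share
// a coefficient; each pair is summed and widened to 16 bits before being
// convolved with filter_x_tmp.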
#define HORZ_FILTERING_CORE(t0, t1, t2, t3, t4, t5, t6, res)                 \
  res0 = vreinterpretq_s16_u16(vaddl_u8(t0, t1));                            \
  res1 = vreinterpretq_s16_u16(vaddl_u8(t2, t3));                            \
  res2 = vreinterpretq_s16_u16(vaddl_u8(t4, t5));                            \
  res3 = vreinterpretq_s16_u16(vmovl_u8(t6));                                \
  res = wiener_convolve8_horiz_8x8(res0, res1, res2, res3, filter_x_tmp, bd, \
                                   conv_params->round_0);

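// Vertical Wiener filtering of one 8-wide column: each iteration produces one
// output row and slides the seven-row window s0..s6 down by one.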
#define PROCESS_ROW_FOR_VERTICAL_FILTER                                      \
  __builtin_prefetch(dst_tmp_ptr + 0 * dst_stride);                          \
                                                                             \
  do {                                                                       \
    s7 = vld1q_s16(s);                                                       \
    s += src_stride;                                                         \
                                                                             \
    t0 = wiener_convolve8_vert_4x8(s0, s1, s2, s3, s4, s5, s6, filter_y_tmp, \
                                   bd, conv_params->round_1);                \
    vst1_u8(d, t0);                                                          \
    d += dst_stride;                                                         \
                                                                             \
    s0 = s1;                                                                 \
    s1 = s2;                                                                 \
    s2 = s3;                                                                 \
    s3 = s4;                                                                 \
    s4 = s5;                                                                 \
    s5 = s6;                                                                 \
    s6 = s7;                                                                 \
    height--;                                                                \
  } while (height > 0);

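// Applies the horizontal Wiener filter one row at a time for an arbitrary
// number of rows, writing the 16-bit intermediate results to dst_ptr.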
static INLINE void process_row_for_horz_filtering(
    uint16_t *dst_ptr, int16_t *filter_x, const uint8_t *src_ptr,
    ptrdiff_t src_stride, ptrdiff_t dst_stride, int round0_bits, int w,
    int height, int bd) {
  do {
    __builtin_prefetch(src_ptr);

    uint8x8_t tt0 = vld1_u8(src_ptr);  // a0 a1 a2 a3 a4 a5 a6 a7

    __builtin_prefetch(dst_ptr);

    const uint8_t *ss = src_ptr + 8;
    uint16_t *d_tmp = dst_ptr;
    int width = w;

    do {
      uint8x8_t tt7 = vld1_u8(ss);  // a8 a9 a10 a11 a12 a13 a14 a15
      uint8x8_t ttemp_0 = tt0;
      tt0 = tt7;

      uint8x8_t tt1 = vext_u8(ttemp_0, tt7, 1);  // a1 a2 a3 a4 a5 a6 a7 a8
      uint8x8_t tt2 = vext_u8(ttemp_0, tt7, 2);  // a2 a3 a4 a5 a6 a7 a8 a9
      uint8x8_t tt3 = vext_u8(ttemp_0, tt7, 3);  // a3 a4 a5 a6 a7 a8 a9 a10
      uint8x8_t tt4 = vext_u8(ttemp_0, tt7, 4);  // a4 a5 a6 a7 a8 a9 a10 a11
      uint8x8_t tt5 = vext_u8(ttemp_0, tt7, 5);  // a5 a6 a7 a8 a9 a10 a11 a12
      uint8x8_t tt6 = vext_u8(ttemp_0, tt7, 6);  // a6 a7 a8 a9 a10 a11 a12 a13
      tt7 = vext_u8(ttemp_0, tt7, 7);            // a7 a8 a9 a10 a11 a12 a13 a14

      int16x8_t ttt0 = vreinterpretq_s16_u16(vaddl_u8(ttemp_0, tt6));
      int16x8_t ttt1 = vreinterpretq_s16_u16(vaddl_u8(tt1, tt5));
      int16x8_t ttt2 = vreinterpretq_s16_u16(vaddl_u8(tt2, tt4));
      int16x8_t ttt3 = vreinterpretq_s16_u16(vmovl_u8(tt3));
      uint16x8_t dd0 = wiener_convolve8_horiz_8x8(ttt0, ttt1, ttt2, ttt3,
                                                  filter_x, bd, round0_bits);

      vst1q_u16(d_tmp, dd0);

      ss += 8;
      d_tmp += 8;
      width -= 8;
    } while (width > 0);

    src_ptr += src_stride;
    dst_ptr += dst_stride;
    height--;
  } while (height > 0);
}

/* Wiener filter 2D
   Apply the horizontal filter and store the result in a temporary buffer.
   When applying the vertical filter, overwrite the original pixel values.
*/
void av1_wiener_convolve_add_src_neon(const uint8_t *src, ptrdiff_t src_stride,
                                      uint8_t *dst, ptrdiff_t dst_stride,
                                      const int16_t *filter_x, int x_step_q4,
                                      const int16_t *filter_y, int y_step_q4,
                                      int w, int h,
                                      const ConvolveParams *conv_params) {
  uint8_t *d;
  const uint8_t *src_ptr, *s_tmp;
  uint16_t *dst_ptr;
  (void)x_step_q4;
  (void)y_step_q4;

  int height;
  const int bd = 8;
  // Height that needs to be processed during horizontal filtering.
  const int intermediate_height = h + SUBPEL_TAPS - 1;
  const int center_tap = ((SUBPEL_TAPS - 1) / 2);
  int16_t filter_x_tmp[7], filter_y_tmp[7];

  DECLARE_ALIGNED(16, uint16_t,
                  temp[(MAX_SB_SIZE + HORIZ_EXTRA_ROWS) * MAX_SB_SIZE]);

  assert(x_step_q4 == 16 && y_step_q4 == 16);
  assert(!(w % 8));

  assert(w <= MAX_SB_SIZE);
  assert(h <= MAX_SB_SIZE);

  assert(filter_x[7] == 0);
  assert(filter_y[7] == 0);

  /* The output of the horizontal filtering is assumed to fit in 15 bits:
     (bd) + 1 + FILTER_BITS - conv_params->round_0 <= 15
     16 - conv_params->round_0 <= 15  =>  conv_params->round_0 >= 1
   */
  assert((conv_params->round_0) >= 1);

  memcpy(&filter_x_tmp[0], filter_x, sizeof(*filter_x) * FILTER_BITS);
  memcpy(&filter_y_tmp[0], filter_y, sizeof(*filter_y) * FILTER_BITS);

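  // Fold the "add source" part of the Wiener filter into the centre tap:
  // adding 1 in Q(FILTER_BITS) fixed point makes the convolution itself add
  // the original pixel value.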
  filter_x_tmp[3] += (1 << FILTER_BITS);
  filter_y_tmp[3] += (1 << FILTER_BITS);

  s_tmp = src - center_tap * src_stride - center_tap;
  dst_ptr = temp;
  src_ptr = s_tmp;
  height = intermediate_height;

  // For AArch64.
#if defined(__aarch64__)
  int processed_height = 0;
  uint16_t *d_tmp;
  int width, remaining_height;
  // Start of horizontal filtering.
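  // Process 8 rows per iteration: transpose each 8x8 tile so that t0..t7 hold
  // columns, filter across them, then transpose the 16-bit results back to
  // row order before storing them to the intermediate buffer.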
  if (intermediate_height > 7) {
    uint16x8_t res4, res5, res6, res7, res8, res9, res10, res11;
    uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7;
    do {
      const uint8_t *s;

      __builtin_prefetch(src_ptr + 0 * src_stride);
      __builtin_prefetch(src_ptr + 1 * src_stride);
      __builtin_prefetch(src_ptr + 2 * src_stride);
      __builtin_prefetch(src_ptr + 3 * src_stride);
      __builtin_prefetch(src_ptr + 4 * src_stride);
      __builtin_prefetch(src_ptr + 5 * src_stride);
      __builtin_prefetch(src_ptr + 6 * src_stride);
      __builtin_prefetch(src_ptr + 7 * src_stride);

      load_u8_8x8(src_ptr, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
      transpose_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);

      s = src_ptr + 7;
      d_tmp = dst_ptr;
      width = w;

      __builtin_prefetch(dst_ptr + 0 * dst_stride);
      __builtin_prefetch(dst_ptr + 1 * dst_stride);
      __builtin_prefetch(dst_ptr + 2 * dst_stride);
      __builtin_prefetch(dst_ptr + 3 * dst_stride);
      __builtin_prefetch(dst_ptr + 4 * dst_stride);
      __builtin_prefetch(dst_ptr + 5 * dst_stride);
      __builtin_prefetch(dst_ptr + 6 * dst_stride);
      __builtin_prefetch(dst_ptr + 7 * dst_stride);

      do {
        int16x8_t res0, res1, res2, res3;
        uint8x8_t t8, t9, t10, t11, t12, t13, t14;
        load_u8_8x8(s, src_stride, &t7, &t8, &t9, &t10, &t11, &t12, &t13, &t14);
        transpose_u8_8x8(&t7, &t8, &t9, &t10, &t11, &t12, &t13, &t14);

        HORZ_FILTERING_CORE(t0, t6, t1, t5, t2, t4, t3, res4)
        HORZ_FILTERING_CORE(t1, t7, t2, t6, t3, t5, t4, res5)
        HORZ_FILTERING_CORE(t2, t8, t3, t7, t4, t6, t5, res6)
        HORZ_FILTERING_CORE(t3, t9, t4, t8, t5, t7, t6, res7)
        HORZ_FILTERING_CORE(t4, t10, t5, t9, t6, t8, t7, res8)
        HORZ_FILTERING_CORE(t5, t11, t6, t10, t7, t9, t8, res9)
        HORZ_FILTERING_CORE(t6, t12, t7, t11, t8, t10, t9, res10)
        HORZ_FILTERING_CORE(t7, t13, t8, t12, t9, t11, t10, res11)

        transpose_u16_8x8(&res4, &res5, &res6, &res7, &res8, &res9, &res10,
                          &res11);
        store_u16_8x8(d_tmp, MAX_SB_SIZE, res4, res5, res6, res7, res8, res9,
                      res10, res11);

        t0 = t8;
        t1 = t9;
        t2 = t10;
        t3 = t11;
        t4 = t12;
        t5 = t13;
        t6 = t14;
        s += 8;
        d_tmp += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += 8 * src_stride;
      dst_ptr += 8 * MAX_SB_SIZE;
      height -= 8;
      processed_height += 8;
    } while (height > 7);
  }

  // Process the remaining rows for horizontal filtering.
  remaining_height = intermediate_height - processed_height;
  if (remaining_height)
    process_row_for_horz_filtering(dst_ptr, filter_x_tmp, src_ptr, src_stride,
                                   MAX_SB_SIZE, conv_params->round_0, w, height,
                                   bd);

  // Start of vertical filtering.
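  // Filter 8-wide columns of the 16-bit intermediate buffer, producing four
  // output rows per iteration; any leftover rows are handled by the
  // single-row macro.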
  {
    int16_t *src_tmp_ptr, *s;
    uint8_t *dst_tmp_ptr;
    height = h;
    width = w;
    src_tmp_ptr = (int16_t *)temp;
    dst_tmp_ptr = dst;
    src_stride = MAX_SB_SIZE;

    do {
      int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;
      uint8x8_t t0;
      s = src_tmp_ptr;
      s0 = vld1q_s16(s);
      s += src_stride;
      s1 = vld1q_s16(s);
      s += src_stride;
      s2 = vld1q_s16(s);
      s += src_stride;
      s3 = vld1q_s16(s);
      s += src_stride;
      s4 = vld1q_s16(s);
      s += src_stride;
      s5 = vld1q_s16(s);
      s += src_stride;
      s6 = vld1q_s16(s);
      s += src_stride;
      d = dst_tmp_ptr;
      height = h;

      do {
        int16x8_t s8, s9, s10;
        uint8x8_t t1, t2, t3;
        __builtin_prefetch(dst_tmp_ptr + 0 * dst_stride);
        __builtin_prefetch(dst_tmp_ptr + 1 * dst_stride);
        __builtin_prefetch(dst_tmp_ptr + 2 * dst_stride);
        __builtin_prefetch(dst_tmp_ptr + 3 * dst_stride);

        s7 = vld1q_s16(s);
        s += src_stride;
        s8 = vld1q_s16(s);
        s += src_stride;
        s9 = vld1q_s16(s);
        s += src_stride;
        s10 = vld1q_s16(s);
        s += src_stride;

        t0 = wiener_convolve8_vert_4x8(s0, s1, s2, s3, s4, s5, s6, filter_y_tmp,
                                       bd, conv_params->round_1);
        t1 = wiener_convolve8_vert_4x8(s1, s2, s3, s4, s5, s6, s7, filter_y_tmp,
                                       bd, conv_params->round_1);
        t2 = wiener_convolve8_vert_4x8(s2, s3, s4, s5, s6, s7, s8, filter_y_tmp,
                                       bd, conv_params->round_1);
        t3 = wiener_convolve8_vert_4x8(s3, s4, s5, s6, s7, s8, s9, filter_y_tmp,
                                       bd, conv_params->round_1);

        vst1_u8(d, t0);
        d += dst_stride;
        vst1_u8(d, t1);
        d += dst_stride;
        vst1_u8(d, t2);
        d += dst_stride;
        vst1_u8(d, t3);
        d += dst_stride;

        s0 = s4;
        s1 = s5;
        s2 = s6;
        s3 = s7;
        s4 = s8;
        s5 = s9;
        s6 = s10;
        height -= 4;
      } while (height > 3);

      if (height) {
        PROCESS_ROW_FOR_VERTICAL_FILTER
      }
      src_tmp_ptr += 8;
      dst_tmp_ptr += 8;
      w -= 8;
    } while (w > 0);
  }
#else
  // Start of horizontal filtering.
  process_row_for_horz_filtering(dst_ptr, filter_x_tmp, src_ptr, src_stride,
                                 MAX_SB_SIZE, conv_params->round_0, w, height,
                                 bd);

  // Start of vertical filtering.
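  // Filter each 8-wide column of the intermediate buffer over the full height
  // one output row at a time.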
  {
    int16_t *src_tmp_ptr, *s;
    uint8_t *dst_tmp_ptr;
    src_tmp_ptr = (int16_t *)temp;
    dst_tmp_ptr = dst;
    src_stride = MAX_SB_SIZE;

    do {
      uint8x8_t t0;
      int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;
      s = src_tmp_ptr;
      s0 = vld1q_s16(s);
      s += src_stride;
      s1 = vld1q_s16(s);
      s += src_stride;
      s2 = vld1q_s16(s);
      s += src_stride;
      s3 = vld1q_s16(s);
      s += src_stride;
      s4 = vld1q_s16(s);
      s += src_stride;
      s5 = vld1q_s16(s);
      s += src_stride;
      s6 = vld1q_s16(s);
      s += src_stride;
      d = dst_tmp_ptr;
      height = h;
      PROCESS_ROW_FOR_VERTICAL_FILTER

      src_tmp_ptr += 8;
      dst_tmp_ptr += 8;

      w -= 8;
    } while (w > 0);
  }
#endif
}