• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2020, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <arm_neon.h>
13 #include <assert.h>
14 
15 #include "config/aom_config.h"
16 
17 #include "aom/aom_integer.h"
18 #include "aom_dsp/arm/mem_neon.h"
19 #include "aom_dsp/arm/sum_neon.h"
20 
21 #define MAX_UPSAMPLE_SZ 16
22 
23 // TODO(aomedia:349436249): enable for armv7 after SIGBUS is fixed.
24 #if AOM_ARCH_AARCH64
25 
26 // These kernels are a transposed version of those defined in reconintra.c,
27 // with the absolute value of the negatives taken in the top row.
28 DECLARE_ALIGNED(16, const uint8_t,
29                 av1_filter_intra_taps_neon[FILTER_INTRA_MODES][7][8]) = {
30   // clang-format off
31   {
32       {  6,  5,  3,  3,  4,  3,  3,  3 },
33       { 10,  2,  1,  1,  6,  2,  2,  1 },
34       {  0, 10,  1,  1,  0,  6,  2,  2 },
35       {  0,  0, 10,  2,  0,  0,  6,  2 },
36       {  0,  0,  0, 10,  0,  0,  0,  6 },
37       { 12,  9,  7,  5,  2,  2,  2,  3 },
38       {  0,  0,  0,  0, 12,  9,  7,  5 }
39   },
40   {
41       { 10,  6,  4,  2, 10,  6,  4,  2 },
42       { 16,  0,  0,  0, 16,  0,  0,  0 },
43       {  0, 16,  0,  0,  0, 16,  0,  0 },
44       {  0,  0, 16,  0,  0,  0, 16,  0 },
45       {  0,  0,  0, 16,  0,  0,  0, 16 },
46       { 10,  6,  4,  2,  0,  0,  0,  0 },
47       {  0,  0,  0,  0, 10,  6,  4,  2 }
48   },
49   {
50       {  8,  8,  8,  8,  4,  4,  4,  4 },
51       {  8,  0,  0,  0,  4,  0,  0,  0 },
52       {  0,  8,  0,  0,  0,  4,  0,  0 },
53       {  0,  0,  8,  0,  0,  0,  4,  0 },
54       {  0,  0,  0,  8,  0,  0,  0,  4 },
55       { 16, 16, 16, 16,  0,  0,  0,  0 },
56       {  0,  0,  0,  0, 16, 16, 16, 16 }
57   },
58   {
59       {  2,  1,  1,  0,  1,  1,  1,  1 },
60       {  8,  3,  2,  1,  4,  3,  2,  2 },
61       {  0,  8,  3,  2,  0,  4,  3,  2 },
62       {  0,  0,  8,  3,  0,  0,  4,  3 },
63       {  0,  0,  0,  8,  0,  0,  0,  4 },
64       { 10,  6,  4,  2,  3,  4,  4,  3 },
65       {  0,  0,  0,  0, 10,  6,  4,  3 }
66   },
67   {
68       { 12, 10,  9,  8, 10,  9,  8,  7 },
69       { 14,  0,  0,  0, 12,  1,  0,  0 },
70       {  0, 14,  0,  0,  0, 12,  0,  0 },
71       {  0,  0, 14,  0,  0,  0, 12,  1 },
72       {  0,  0,  0, 14,  0,  0,  0, 12 },
73       { 14, 12, 11, 10,  0,  0,  1,  1 },
74       {  0,  0,  0,  0, 14, 12, 11,  9 }
75   }
76   // clang-format on
77 };
78 
79 #define FILTER_INTRA_SCALE_BITS 4
80 
// Evaluates the 7-tap filter-intra kernel for one 4x2 patch.
//   taps      - the seven rows of av1_filter_intra_taps_neon[mode].
//   corner    - top-left neighbour, broadcast to all lanes (weight subtracted).
//   a0..a3    - the four neighbours above the patch, each broadcast.
//   l0, l1    - left neighbours of the patch's top and bottom rows, broadcast.
// Returns the patch with lanes 0-3 holding its top row and lanes 4-7 its
// bottom row, rounded by FILTER_INTRA_SCALE_BITS and saturated to [0, 255].
static inline uint8x8_t filter_intra_4x2_neon(const uint8x8_t taps[7],
                                              uint8x8_t corner, uint8x8_t a0,
                                              uint8x8_t a1, uint8x8_t a2,
                                              uint8x8_t a3, uint8x8_t l0,
                                              uint8x8_t l1) {
  uint16x8_t acc = vmull_u8(a0, taps[1]);
  // Row 0 of the tap table stores magnitudes of negative weights.
  acc = vmlsl_u8(acc, corner, taps[0]);
  acc = vmlal_u8(acc, a1, taps[2]);
  acc = vmlal_u8(acc, a2, taps[3]);
  acc = vmlal_u8(acc, a3, taps[4]);
  acc = vmlal_u8(acc, l0, taps[5]);
  acc = vmlal_u8(acc, l1, taps[6]);
  // The wrapped 16-bit sum is the signed result; narrow with rounding and
  // unsigned saturation.
  return vqrshrun_n_s16(vreinterpretq_s16_u16(acc), FILTER_INTRA_SCALE_BITS);
}

// Filter-intra prediction of a tx_size block into dst (row pitch `stride`).
//   above - reads above[-1 .. width-1] (top-left corner plus the top row).
//   left  - reads left[0 .. height-1].
//   mode  - index into av1_filter_intra_taps_neon; assumed to be in
//           [0, FILTER_INTRA_MODES) (not checked here).
// The block is produced in 4x2 patches in raster order. Each patch is filtered
// from its top-left, above and left neighbours, which may be pixels predicted
// by earlier patches, so results are staged in a scratch buffer.
// NOTE(review): rows are processed in pairs and columns in groups of 4 (width
// <= 8) or 8 (wider), so height is presumably even and width a multiple of 4 or
// 8 for every TX size that reaches this function - confirm against callers.
void av1_filter_intra_predictor_neon(uint8_t *dst, ptrdiff_t stride,
                                     TX_SIZE tx_size, const uint8_t *above,
                                     const uint8_t *left, int mode) {
  const int bw = tx_size_wide[tx_size];
  const int bh = tx_size_high[tx_size];
  assert(bw <= 32 && bh <= 32);

  uint8x8_t taps[7];
  for (int k = 0; k < 7; ++k) {
    taps[k] = vld1_u8(av1_filter_intra_taps_neon[mode][k]);
  }

  // Scratch copy of the prediction: row 0 holds above[-1..bw-1], column 0
  // holds left[], and buf[1 + y][1 + x] receives predicted pixel (x, y).
  uint8_t buf[33][33];
  const ptrdiff_t buf_stride = (ptrdiff_t)sizeof(buf[0]);
  memcpy(buf[0], &above[-1], (bw + 1) * sizeof(uint8_t));
  int y = 0;
  do {
    buf[y + 1][0] = left[y];
  } while (++y < bh);

  if (bw <= 8) {
    // For 8x<h> blocks a single 4x2 patch per step beats two at a time.
    int row = 1;
    do {
      // Neighbours on the left of the first patch in this row pair.
      uint8x8_t corner = vld1_dup_u8(&buf[row - 1][0]);
      uint8x8_t left0 = vld1_dup_u8(&buf[row][0]);
      uint8x8_t left1 = vld1_dup_u8(&buf[row + 1][0]);
      int col = 1;
      do {
        const uint8x8_t above4 = load_u8_4x1(&buf[row - 1][col]);
        const uint8x8_t a3 = vdup_lane_u8(above4, 3);
        const uint8x8_t out = filter_intra_4x2_neon(
            taps, corner, vdup_lane_u8(above4, 0), vdup_lane_u8(above4, 1),
            vdup_lane_u8(above4, 2), a3, left0, left1);

        // Stage rows `row` and `row + 1` for later patches, then emit them.
        store_u8x4_strided_x2(&buf[row][col], buf_stride, out);
        store_u8x4_strided_x2(dst + (row - 1) * stride + (col - 1), stride,
                              out);

        // The next patch to the right sees this patch's last column as its
        // left neighbours and this patch's last above pixel as its corner.
        corner = a3;
        left0 = vdup_lane_u8(out, 3);
        left1 = vdup_lane_u8(out, 7);
        col += 4;
      } while (col <= bw);
      row += 2;
    } while (row <= bh);
  } else {
    // Two horizontally adjacent 4x2 patches (an 8x2 strip) per step.
    int row = 1;
    do {
      uint8x8_t corner = vld1_dup_u8(&buf[row - 1][0]);
      uint8x8_t left0 = vld1_dup_u8(&buf[row][0]);
      uint8x8_t left1 = vld1_dup_u8(&buf[row + 1][0]);
      int col = 1;
      do {
        const uint8x8_t above8 = vld1_u8(&buf[row - 1][col]);
        const uint8x8_t a3 = vdup_lane_u8(above8, 3);
        const uint8x8_t a7 = vdup_lane_u8(above8, 7);

        // Left patch: columns col .. col + 3.
        const uint8x8_t lo = filter_intra_4x2_neon(
            taps, corner, vdup_lane_u8(above8, 0), vdup_lane_u8(above8, 1),
            vdup_lane_u8(above8, 2), a3, left0, left1);
        // Right patch: columns col + 4 .. col + 7. Its left neighbours are the
        // last column the left patch just produced.
        const uint8x8_t hi = filter_intra_4x2_neon(
            taps, a3, vdup_lane_u8(above8, 4), vdup_lane_u8(above8, 5),
            vdup_lane_u8(above8, 6), a7, vdup_lane_u8(lo, 3),
            vdup_lane_u8(lo, 7));

        // Zip the 32-bit halves: val[0] = top rows of lo|hi, val[1] = bottom.
        const uint32x2x2_t zipped =
            vzip_u32(vreinterpret_u32_u8(lo), vreinterpret_u32_u8(hi));
        const uint8x8_t top = vreinterpret_u8_u32(zipped.val[0]);
        const uint8x8_t bottom = vreinterpret_u8_u32(zipped.val[1]);

        vst1_u8(&buf[row][col], top);
        vst1_u8(&buf[row + 1][col], bottom);
        vst1_u8(dst + (row - 1) * stride + (col - 1), top);
        vst1_u8(dst + row * stride + (col - 1), bottom);

        corner = a7;
        left0 = vdup_lane_u8(hi, 3);
        left1 = vdup_lane_u8(hi, 7);
        col += 8;
      } while (col <= bw);
      row += 2;
    } while (row <= bh);
  }
}
214 #endif  // AOM_ARCH_AARCH64
215 
// Eight outputs of the {4, 8, 4} / 16 kernel centred on src[1 .. 8]. Uses the
// identity (4a + 8b + 4c + 8) >> 4 == (a + 2b + c + 2) >> 2.
static inline uint8x8_t edge_filter_4_8_4_neon(const uint8_t *src) {
  const uint8x8_t s0 = vld1_u8(src);
  const uint8x8_t s1 = vld1_u8(src + 1);
  const uint8x8_t s2 = vld1_u8(src + 2);
  const uint16x8_t sum = vaddq_u16(vaddl_u8(s0, s2), vaddl_u8(s1, s1));
  return vrshrn_n_u16(sum, 2);
}

// Eight outputs of the {5, 6, 5} / 16 kernel centred on src[1 .. 8].
static inline uint8x8_t edge_filter_5_6_5_neon(const uint8_t *src) {
  const uint8x8_t s0 = vld1_u8(src);
  const uint8x8_t s1 = vld1_u8(src + 1);
  const uint8x8_t s2 = vld1_u8(src + 2);
  uint16x8_t acc = vmull_u8(s0, vdup_n_u8(5));
  acc = vmlal_u8(acc, s1, vdup_n_u8(6));
  acc = vmlal_u8(acc, s2, vdup_n_u8(5));
  return vrshrn_n_u16(acc, 4);
}

// Eight outputs of the {2, 4, 4, 4, 2} / 16 kernel centred on src[2 .. 9].
// Uses (2a + 4b + 4c + 4d + 2e + 8) >> 4 == (a + 2(b + c + d) + e + 4) >> 3.
static inline uint8x8_t edge_filter_2_4_4_4_2_neon(const uint8_t *src) {
  const uint8x8_t s0 = vld1_u8(src);
  const uint8x8_t s1 = vld1_u8(src + 1);
  const uint8x8_t s2 = vld1_u8(src + 2);
  const uint8x8_t s3 = vld1_u8(src + 3);
  const uint8x8_t s4 = vld1_u8(src + 4);
  uint16x8_t mid = vaddw_u8(vaddl_u8(s1, s2), s3);
  mid = vaddq_u16(mid, mid);
  return vrshrn_n_u16(vaddq_u16(vaddl_u8(s0, s4), mid), 3);
}

// Stores the first n lanes of v to dst (all 8 when n >= 8). dst[n .. 7] is
// neither read nor written, so partial tails never touch memory past the edge.
static inline void edge_store_u8x8_partial(uint8_t *dst, uint8x8_t v, int n) {
  if (n >= 8) {
    vst1_u8(dst, v);
    return;
  }
  uint8_t tmp[8];
  vst1_u8(tmp, v);
  memcpy(dst, tmp, (size_t)n);
}

// Smooths the intra edge p[0 .. sz-1] in place; p[0] itself is never changed.
//   strength 0     - no-op.
//   strength 1     - {4, 8, 4} / 16 kernel.
//   strength 2     - {5, 6, 5} / 16 kernel.
//   any other value - {2, 4, 4, 4, 2} / 16 kernel.
// Samples outside the edge are clamped to p[0] / p[sz - 1], and all results are
// rounded. Only p[0 .. sz-1] is read or written.
void av1_filter_intra_edge_neon(uint8_t *p, int sz, int strength) {
  if (!strength) return;
  assert(sz >= 0 && sz <= 129);
  // With fewer than two samples there is nothing to filter (p[0] is kept).
  if (sz < 2) return;

  // edge[1 .. sz] mirrors p. One replicated sample sits in front and two
  // behind, and the rest of the array is slack so that 8-lane loads past the
  // end stay inside it (those lanes are discarded by the partial store).
  uint8_t edge[160];
  memcpy(edge + 1, p, sz * sizeof(*p));
  edge[0] = edge[1];
  edge[sz + 1] = edge[sz];
  edge[sz + 2] = edge[sz];

  uint8_t *dst = p + 1;  // Outputs are p[1 .. sz-1].
  const int n = sz - 1;

  if (strength == 1) {
    const uint8_t *src = edge + 1;
    for (int i = 0; i < n; i += 8) {
      edge_store_u8x8_partial(dst + i, edge_filter_4_8_4_neon(src + i), n - i);
    }
  } else if (strength == 2) {
    const uint8_t *src = edge + 1;
    for (int i = 0; i < n; i += 8) {
      edge_store_u8x8_partial(dst + i, edge_filter_5_6_5_neon(src + i), n - i);
    }
  } else {
    // The 5-tap kernel needs one more sample on the left.
    const uint8_t *src = edge;
    for (int i = 0; i < n; i += 8) {
      edge_store_u8x8_partial(dst + i, edge_filter_2_4_4_4_2_neon(src + i),
                              n - i);
    }
  }
}
359 
// Upsamples an intra edge by 2x in place.
// On entry p[-1 .. sz-1] holds sz + 1 samples. On exit p[-2 .. 2*sz-2] holds:
//   p[-2]       = p[-1] (original),
//   p[2*i]      = original p[i],
//   p[2*i - 1]  = (-e[i] + 9*e[i+1] + 9*e[i+2] - e[i+3] + 8) >> 4, saturated
//                 to [0, 255], where e[] is p[-1 .. sz-1] with the first sample
//                 repeated in front and the last sample repeated behind.
// Nothing beyond p[2*sz-2] is written. sz == 0 is a no-op; otherwise sz must be
// in [1, MAX_UPSAMPLE_SZ].
void av1_upsample_intra_edge_neon(uint8_t *p, int sz) {
  if (!sz) return;

  assert(sz > 0 && sz <= MAX_UPSAMPLE_SZ);

  // Padded copy of p[-1 .. sz-1]. Entries past edge[sz + 2] are only ever
  // read into lanes that are discarded below.
  uint8_t edge[MAX_UPSAMPLE_SZ + 3];
  edge[0] = p[-1];
  edge[1] = p[-1];
  memcpy(edge + 2, p, sz);
  edge[sz + 2] = p[sz - 1];
  p[-2] = p[-1];

  uint8_t *dst = p - 1;

  for (int i = 0; i < sz; i += 8) {
    const uint8_t *src = edge + i;
    const uint8x8_t s0 = vld1_u8(src);
    const uint8x8_t s1 = vld1_u8(src + 1);
    const uint8x8_t s2 = vld1_u8(src + 2);
    const uint8x8_t s3 = vld1_u8(src + 3);

    // 9 * (s1 + s2) - (s0 + s3) fits in int16: at most 9 * 510.
    const int16x8_t outer = vreinterpretq_s16_u16(vaddl_u8(s0, s3));
    int16x8_t inner = vreinterpretq_s16_u16(vaddl_u8(s1, s2));
    inner = vsubq_s16(vmulq_n_s16(inner, 9), outer);

    // Interleave each half-sample with the original sample that follows it.
    const uint8x8x2_t res = { { vqrshrun_n_s16(inner, 4), s2 } };

    const int count = sz - i;  // Input samples covered by this chunk.
    if (count >= 8) {
      vst2_u8(dst + 2 * i, res);
    } else {
      // Only 2 * count outputs remain. Storing all 16 lanes would overwrite
      // p[2 * sz - 1 ..] with values derived from the uninitialised padding.
      uint8_t tmp[16];
      vst2_u8(tmp, res);
      memcpy(dst + 2 * i, tmp, (size_t)(2 * count));
    }
  }
}
397