/*
 *  Copyright (c) 2020, Alliance for Open Media. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include <arm_neon.h>

#include "config/aom_dsp_rtcd.h"
#include "aom/aom_integer.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/arm/transpose_neon.h"

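// Accumulate the squared differences of 16 pixels from a and b into *sum:
// absolute difference, widening 8x8->16 squared multiply, then pairwise-add
// and accumulate into the four 32-bit lanes of *sum.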
static INLINE void sse_w16_neon(uint32x4_t *sum, const uint8_t *a,
                                const uint8_t *b) {
  const uint8x16_t v_a0 = vld1q_u8(a);
  const uint8x16_t v_b0 = vld1q_u8(b);
  const uint8x16_t diff = vabdq_u8(v_a0, v_b0);
  const uint8x8_t diff_lo = vget_low_u8(diff);
  const uint8x8_t diff_hi = vget_high_u8(diff);
  *sum = vpadalq_u16(*sum, vmull_u8(diff_lo, diff_lo));
  *sum = vpadalq_u16(*sum, vmull_u8(diff_hi, diff_hi));
}
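// Accumulate the squared differences of two rows of 4 pixels each. The two
// 4-byte rows are packed into one 8-byte vector so the same
// vabd/vmull/vpadal sequence as the 8-pixel case can be reused.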
static INLINE void aom_sse4x2_neon(const uint8_t *a, int a_stride,
                                   const uint8_t *b, int b_stride,
                                   uint32x4_t *sum) {
  uint8x8_t v_a0, v_b0;
  v_a0 = v_b0 = vcreate_u8(0);
  // The zero-initialization above only silences a spurious
  // -Werror=uninitialized warning for the lane loads below.
  v_a0 = vreinterpret_u8_u32(
      vld1_lane_u32((uint32_t *)a, vreinterpret_u32_u8(v_a0), 0));
  v_a0 = vreinterpret_u8_u32(
      vld1_lane_u32((uint32_t *)(a + a_stride), vreinterpret_u32_u8(v_a0), 1));
  v_b0 = vreinterpret_u8_u32(
      vld1_lane_u32((uint32_t *)b, vreinterpret_u32_u8(v_b0), 0));
  v_b0 = vreinterpret_u8_u32(
      vld1_lane_u32((uint32_t *)(b + b_stride), vreinterpret_u32_u8(v_b0), 1));
  const uint8x8_t v_a_w = vabd_u8(v_a0, v_b0);
  *sum = vpadalq_u16(*sum, vmull_u8(v_a_w, v_a_w));
}
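// Accumulate the squared differences of a single row of 8 pixels into *sum.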
static INLINE void aom_sse8_neon(const uint8_t *a, const uint8_t *b,
                                 uint32x4_t *sum) {
  const uint8x8_t v_a_w = vld1_u8(a);
  const uint8x8_t v_b_w = vld1_u8(b);
  const uint8x8_t v_d_w = vabd_u8(v_a_w, v_b_w);
  *sum = vpadalq_u16(*sum, vmull_u8(v_d_w, v_d_w));
}
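// Sum of squared error between two 8-bit blocks. The switch dispatches on the
// block width so each case can use loads matched to the row layout. The
// 32-bit vector accumulator cannot overflow: even a 128x128 block contributes
// at most 128 * 128 * 255^2 < 2^32, and the four lanes are reduced to a
// scalar once per block.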
int64_t aom_sse_neon(const uint8_t *a, int a_stride, const uint8_t *b,
                     int b_stride, int width, int height) {
  int y = 0;
  int64_t sse = 0;
  uint32x4_t sum = vdupq_n_u32(0);
  switch (width) {
    case 4:
      do {
        aom_sse4x2_neon(a, a_stride, b, b_stride, &sum);
        a += a_stride << 1;
        b += b_stride << 1;
        y += 2;
      } while (y < height);
#if defined(__aarch64__)
      sse = vaddvq_u32(sum);
#else
      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
#endif  // __aarch64__
      break;
    case 8:
      do {
        aom_sse8_neon(a, b, &sum);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
#if defined(__aarch64__)
      sse = vaddvq_u32(sum);
#else
      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
#endif  // __aarch64__
      break;
    case 16:
      do {
        sse_w16_neon(&sum, a, b);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
#if defined(__aarch64__)
      sse = vaddvq_u32(sum);
#else
      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
#endif  // __aarch64__
      break;
    case 32:
      do {
        sse_w16_neon(&sum, a, b);
        sse_w16_neon(&sum, a + 16, b + 16);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
#if defined(__aarch64__)
      sse = vaddvq_u32(sum);
#else
      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
#endif  // __aarch64__
      break;
    case 64:
      do {
        sse_w16_neon(&sum, a, b);
        sse_w16_neon(&sum, a + 16 * 1, b + 16 * 1);
        sse_w16_neon(&sum, a + 16 * 2, b + 16 * 2);
        sse_w16_neon(&sum, a + 16 * 3, b + 16 * 3);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
#if defined(__aarch64__)
      sse = vaddvq_u32(sum);
#else
      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
#endif  // __aarch64__
      break;
    case 128:
      do {
        sse_w16_neon(&sum, a, b);
        sse_w16_neon(&sum, a + 16 * 1, b + 16 * 1);
        sse_w16_neon(&sum, a + 16 * 2, b + 16 * 2);
        sse_w16_neon(&sum, a + 16 * 3, b + 16 * 3);
        sse_w16_neon(&sum, a + 16 * 4, b + 16 * 4);
        sse_w16_neon(&sum, a + 16 * 5, b + 16 * 5);
        sse_w16_neon(&sum, a + 16 * 6, b + 16 * 6);
        sse_w16_neon(&sum, a + 16 * 7, b + 16 * 7);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
#if defined(__aarch64__)
      sse = vaddvq_u32(sum);
#else
      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
#endif  // __aarch64__
      break;
    default:
      if (width & 0x07) {
        do {
          int i = 0;
          do {
            aom_sse8_neon(a + i, b + i, &sum);
            aom_sse8_neon(a + i + a_stride, b + i + b_stride, &sum);
            i += 8;
          } while (i + 4 < width);
          aom_sse4x2_neon(a + i, a_stride, b + i, b_stride, &sum);
          a += (a_stride << 1);
          b += (b_stride << 1);
          y += 2;
        } while (y < height);
      } else {
        do {
          int i = 0;
          do {
            aom_sse8_neon(a + i, b + i, &sum);
            i += 8;
          } while (i < width);
          a += a_stride;
          b += b_stride;
          y += 1;
        } while (y < height);
      }
#if defined(__aarch64__)
      sse = vaddvq_u32(sum);
#else
      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
#endif  // __aarch64__
      break;
  }
  return sse;
}

#if CONFIG_AV1_HIGHBITDEPTH
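// Sum of squared error of 8 16-bit pixels, reduced to a scalar. Unlike the
// 8-bit path, the reduction happens per call: with up to 12-bit input a
// single squared difference can approach 2^24, so a 32-bit vector accumulator
// could overflow over a full block; the running total is instead kept in a
// 64-bit scalar by the caller.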
static INLINE uint32_t highbd_sse_W8x1_neon(uint16x8_t q2, uint16x8_t q3) {
  uint32_t sse;
  const uint32_t sse1 = 0;
  const uint32x4_t q1 = vld1q_dup_u32(&sse1);

  uint16x8_t q4 = vabdq_u16(q2, q3);  // diff = abs(a[x] - b[x])
  uint16x4_t d0 = vget_low_u16(q4);
  uint16x4_t d1 = vget_high_u16(q4);

  uint32x4_t q6 = vmlal_u16(q1, d0, d0);
  uint32x4_t q7 = vmlal_u16(q1, d1, d1);

  uint32x2_t d4 = vadd_u32(vget_low_u32(q6), vget_high_u32(q6));
  uint32x2_t d5 = vadd_u32(vget_low_u32(q7), vget_high_u32(q7));

  uint32x2_t d6 = vadd_u32(d4, d5);

  sse = vget_lane_u32(d6, 0);
  sse += vget_lane_u32(d6, 1);

  return sse;
}

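// Sum of squared error between two high-bit-depth blocks. The pointers are
// converted with CONVERT_TO_SHORTPTR to uint16_t, the common widths are fully
// unrolled, and the default case handles a trailing group of fewer than 8
// pixels by comparing the lane-index vector q0 = {0..7} against the remaining
// count with vcltq_u16 and masking the excess lanes to zero before
// accumulating.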
int64_t aom_highbd_sse_neon(const uint8_t *a8, int a_stride, const uint8_t *b8,
                            int b_stride, int width, int height) {
  const uint16x8_t q0 = { 0, 1, 2, 3, 4, 5, 6, 7 };
  int64_t sse = 0;
  uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  int x, y;
  int addinc;
  uint16x4_t d0, d1, d2, d3;
  uint16_t dx;
  uint16x8_t q2, q3, q4, q5;

  switch (width) {
    case 4:
      for (y = 0; y < height; y += 2) {
        d0 = vld1_u16(a);  // load 4 data
        a += a_stride;
        d1 = vld1_u16(a);
        a += a_stride;

        d2 = vld1_u16(b);
        b += b_stride;
        d3 = vld1_u16(b);
        b += b_stride;
        q2 = vcombine_u16(d0, d1);  // make an 8-element vector
        q3 = vcombine_u16(d2, d3);

        sse += highbd_sse_W8x1_neon(q2, q3);
      }
      break;
    case 8:
      for (y = 0; y < height; y++) {
        q2 = vld1q_u16(a);
        q3 = vld1q_u16(b);

        sse += highbd_sse_W8x1_neon(q2, q3);

        a += a_stride;
        b += b_stride;
      }
      break;
    case 16:
      for (y = 0; y < height; y++) {
        q2 = vld1q_u16(a);
        q3 = vld1q_u16(b);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 8);
        q3 = vld1q_u16(b + 8);

        sse += highbd_sse_W8x1_neon(q2, q3);

        a += a_stride;
        b += b_stride;
      }
      break;
    case 32:
      for (y = 0; y < height; y++) {
        q2 = vld1q_u16(a);
        q3 = vld1q_u16(b);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 8);
        q3 = vld1q_u16(b + 8);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 16);
        q3 = vld1q_u16(b + 16);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 24);
        q3 = vld1q_u16(b + 24);

        sse += highbd_sse_W8x1_neon(q2, q3);

        a += a_stride;
        b += b_stride;
      }
      break;
    case 64:
      for (y = 0; y < height; y++) {
        q2 = vld1q_u16(a);
        q3 = vld1q_u16(b);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 8);
        q3 = vld1q_u16(b + 8);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 16);
        q3 = vld1q_u16(b + 16);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 24);
        q3 = vld1q_u16(b + 24);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 32);
        q3 = vld1q_u16(b + 32);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 40);
        q3 = vld1q_u16(b + 40);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 48);
        q3 = vld1q_u16(b + 48);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 56);
        q3 = vld1q_u16(b + 56);

        sse += highbd_sse_W8x1_neon(q2, q3);

        a += a_stride;
        b += b_stride;
      }
      break;
    case 128:
      for (y = 0; y < height; y++) {
        q2 = vld1q_u16(a);
        q3 = vld1q_u16(b);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 8);
        q3 = vld1q_u16(b + 8);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 16);
        q3 = vld1q_u16(b + 16);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 24);
        q3 = vld1q_u16(b + 24);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 32);
        q3 = vld1q_u16(b + 32);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 40);
        q3 = vld1q_u16(b + 40);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 48);
        q3 = vld1q_u16(b + 48);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 56);
        q3 = vld1q_u16(b + 56);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 64);
        q3 = vld1q_u16(b + 64);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 72);
        q3 = vld1q_u16(b + 72);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 80);
        q3 = vld1q_u16(b + 80);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 88);
        q3 = vld1q_u16(b + 88);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 96);
        q3 = vld1q_u16(b + 96);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 104);
        q3 = vld1q_u16(b + 104);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 112);
        q3 = vld1q_u16(b + 112);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 120);
        q3 = vld1q_u16(b + 120);

        sse += highbd_sse_W8x1_neon(q2, q3);
        a += a_stride;
        b += b_stride;
      }
      break;
    default:
      for (y = 0; y < height; y++) {
        x = width;
        while (x > 0) {
          addinc = width - x;
          q2 = vld1q_u16(a + addinc);
          q3 = vld1q_u16(b + addinc);
          if (x < 8) {
            dx = x;
            q4 = vld1q_dup_u16(&dx);
            q5 = vcltq_u16(q0, q4);
            q2 = vandq_u16(q2, q5);
            q3 = vandq_u16(q3, q5);
          }
          sse += highbd_sse_W8x1_neon(q2, q3);
          x -= 8;
        }
        a += a_stride;
        b += b_stride;
      }
  }
  return (int64_t)sse;
}
#endif  // CONFIG_AV1_HIGHBITDEPTH