/*
 * Copyright (c) 2020, Alliance for Open Media. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include <arm_neon.h>

#include "config/aom_dsp_rtcd.h"
#include "aom/aom_integer.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/arm/transpose_neon.h"

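// Accumulates the sum of squared differences for one row of 16 pixels into
// the four 32-bit lanes of *sum.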
static INLINE void sse_w16_neon(uint32x4_t *sum, const uint8_t *a,
                                const uint8_t *b) {
  const uint8x16_t v_a0 = vld1q_u8(a);
  const uint8x16_t v_b0 = vld1q_u8(b);
  const uint8x16_t diff = vabdq_u8(v_a0, v_b0);
  const uint8x8_t diff_lo = vget_low_u8(diff);
  const uint8x8_t diff_hi = vget_high_u8(diff);
  *sum = vpadalq_u16(*sum, vmull_u8(diff_lo, diff_lo));
  *sum = vpadalq_u16(*sum, vmull_u8(diff_hi, diff_hi));
}
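// Accumulates the sum of squared differences for a 4x2 block (two rows of 4
// pixels each) into *sum.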
static INLINE void aom_sse4x2_neon(const uint8_t *a, int a_stride,
                                   const uint8_t *b, int b_stride,
                                   uint32x4_t *sum) {
  uint8x8_t v_a0, v_b0;
  v_a0 = v_b0 = vcreate_u8(0);
  // The zero-initialization above only silences [-Werror=uninitialized];
  // every lane is overwritten by the lane loads below.
  v_a0 = vreinterpret_u8_u32(
      vld1_lane_u32((uint32_t *)a, vreinterpret_u32_u8(v_a0), 0));
  v_a0 = vreinterpret_u8_u32(
      vld1_lane_u32((uint32_t *)(a + a_stride), vreinterpret_u32_u8(v_a0), 1));
  v_b0 = vreinterpret_u8_u32(
      vld1_lane_u32((uint32_t *)b, vreinterpret_u32_u8(v_b0), 0));
  v_b0 = vreinterpret_u8_u32(
      vld1_lane_u32((uint32_t *)(b + b_stride), vreinterpret_u32_u8(v_b0), 1));
  const uint8x8_t v_a_w = vabd_u8(v_a0, v_b0);
  *sum = vpadalq_u16(*sum, vmull_u8(v_a_w, v_a_w));
}
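// Accumulates the sum of squared differences for one row of 8 pixels into
// *sum.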
static INLINE void aom_sse8_neon(const uint8_t *a, const uint8_t *b,
                                 uint32x4_t *sum) {
  const uint8x8_t v_a_w = vld1_u8(a);
  const uint8x8_t v_b_w = vld1_u8(b);
  const uint8x8_t v_d_w = vabd_u8(v_a_w, v_b_w);
  *sum = vpadalq_u16(*sum, vmull_u8(v_d_w, v_d_w));
}
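// Returns the sum of squared differences between the two 8-bit pixel blocks.
// For reference only, the equivalent scalar computation is:
//   int64_t sse = 0;
//   for (int y = 0; y < height; ++y)
//     for (int x = 0; x < width; ++x) {
//       const int d = a[y * a_stride + x] - b[y * b_stride + x];
//       sse += d * d;
//     }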
int64_t aom_sse_neon(const uint8_t *a, int a_stride, const uint8_t *b,
                     int b_stride, int width, int height) {
  int y = 0;
  int64_t sse = 0;
  uint32x4_t sum = vdupq_n_u32(0);
  switch (width) {
    case 4:
      do {
        aom_sse4x2_neon(a, a_stride, b, b_stride, &sum);
        a += a_stride << 1;
        b += b_stride << 1;
        y += 2;
      } while (y < height);
#if defined(__aarch64__)
      sse = vaddvq_u32(sum);
#else
      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
#endif  // __aarch64__
      break;
    case 8:
      do {
        aom_sse8_neon(a, b, &sum);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
#if defined(__aarch64__)
      sse = vaddvq_u32(sum);
#else
      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
#endif  // __aarch64__
      break;
    case 16:
      do {
        sse_w16_neon(&sum, a, b);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
#if defined(__aarch64__)
      sse = vaddvq_u32(sum);
#else
      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
#endif  // __aarch64__
      break;
    case 32:
      do {
        sse_w16_neon(&sum, a, b);
        sse_w16_neon(&sum, a + 16, b + 16);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
#if defined(__aarch64__)
      sse = vaddvq_u32(sum);
#else
      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
#endif  // __aarch64__
      break;
    case 64:
      do {
        sse_w16_neon(&sum, a, b);
        sse_w16_neon(&sum, a + 16 * 1, b + 16 * 1);
        sse_w16_neon(&sum, a + 16 * 2, b + 16 * 2);
        sse_w16_neon(&sum, a + 16 * 3, b + 16 * 3);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
#if defined(__aarch64__)
      sse = vaddvq_u32(sum);
#else
      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
#endif  // __aarch64__
      break;
    case 128:
      do {
        sse_w16_neon(&sum, a, b);
        sse_w16_neon(&sum, a + 16 * 1, b + 16 * 1);
        sse_w16_neon(&sum, a + 16 * 2, b + 16 * 2);
        sse_w16_neon(&sum, a + 16 * 3, b + 16 * 3);
        sse_w16_neon(&sum, a + 16 * 4, b + 16 * 4);
        sse_w16_neon(&sum, a + 16 * 5, b + 16 * 5);
        sse_w16_neon(&sum, a + 16 * 6, b + 16 * 6);
        sse_w16_neon(&sum, a + 16 * 7, b + 16 * 7);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
#if defined(__aarch64__)
      sse = vaddvq_u32(sum);
#else
      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
#endif  // __aarch64__
      break;
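    // Widths not handled above. If the width is a multiple of 8, each row is
    // processed in 8-pixel chunks; otherwise two rows are processed at a
    // time and each row pair is finished with a 4x2 chunk (the remaining
    // width is expected to be 4 here).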
    default:
      if (width & 0x07) {
        do {
          int i = 0;
          do {
            aom_sse8_neon(a + i, b + i, &sum);
            aom_sse8_neon(a + i + a_stride, b + i + b_stride, &sum);
            i += 8;
          } while (i + 4 < width);
          aom_sse4x2_neon(a + i, a_stride, b + i, b_stride, &sum);
          a += (a_stride << 1);
          b += (b_stride << 1);
          y += 2;
        } while (y < height);
      } else {
        do {
          int i = 0;
          do {
            aom_sse8_neon(a + i, b + i, &sum);
            i += 8;
          } while (i < width);
          a += a_stride;
          b += b_stride;
          y += 1;
        } while (y < height);
      }
#if defined(__aarch64__)
      sse = vaddvq_u32(sum);
#else
      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
#endif  // __aarch64__
      break;
  }
  return sse;
}

#if CONFIG_AV1_HIGHBITDEPTH
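// Returns the sum of squared differences of the eight 16-bit lanes of q2 and
// q3, fully reduced to a scalar.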
static INLINE uint32_t highbd_sse_W8x1_neon(uint16x8_t q2, uint16x8_t q3) {
  uint32_t sse;
  const uint32_t sse1 = 0;
  const uint32x4_t q1 = vld1q_dup_u32(&sse1);

  uint16x8_t q4 = vabdq_u16(q2, q3);  // diff = abs(a[x] - b[x])
  uint16x4_t d0 = vget_low_u16(q4);
  uint16x4_t d1 = vget_high_u16(q4);

  uint32x4_t q6 = vmlal_u16(q1, d0, d0);
  uint32x4_t q7 = vmlal_u16(q1, d1, d1);

  uint32x2_t d4 = vadd_u32(vget_low_u32(q6), vget_high_u32(q6));
  uint32x2_t d5 = vadd_u32(vget_low_u32(q7), vget_high_u32(q7));

  uint32x2_t d6 = vadd_u32(d4, d5);

  sse = vget_lane_u32(d6, 0);
  sse += vget_lane_u32(d6, 1);

  return sse;
}

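// High bit-depth variant of aom_sse_neon: a8 and b8 are CONVERT_TO_BYTEPTR
// wrappers around 16-bit pixel buffers, and the result is the same sum of
// squared differences over a width x height block.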
int64_t aom_highbd_sse_neon(const uint8_t *a8, int a_stride, const uint8_t *b8,
                            int b_stride, int width, int height) {
  const uint16x8_t q0 = { 0, 1, 2, 3, 4, 5, 6, 7 };
  int64_t sse = 0;
  uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  int x, y;
  int addinc;
  uint16x4_t d0, d1, d2, d3;
  uint16_t dx;
  uint16x8_t q2, q3, q4, q5;

  switch (width) {
    case 4:
      for (y = 0; y < height; y += 2) {
        d0 = vld1_u16(a);  // Load 4 pixels.
        a += a_stride;
        d1 = vld1_u16(a);
        a += a_stride;

        d2 = vld1_u16(b);
        b += b_stride;
        d3 = vld1_u16(b);
        b += b_stride;
        q2 = vcombine_u16(d0, d1);  // Combine the two rows into one 8-lane vector.
        q3 = vcombine_u16(d2, d3);

        sse += highbd_sse_W8x1_neon(q2, q3);
      }
      break;
    case 8:
      for (y = 0; y < height; y++) {
        q2 = vld1q_u16(a);
        q3 = vld1q_u16(b);

        sse += highbd_sse_W8x1_neon(q2, q3);

        a += a_stride;
        b += b_stride;
      }
      break;
    case 16:
      for (y = 0; y < height; y++) {
        q2 = vld1q_u16(a);
        q3 = vld1q_u16(b);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 8);
        q3 = vld1q_u16(b + 8);

        sse += highbd_sse_W8x1_neon(q2, q3);

        a += a_stride;
        b += b_stride;
      }
      break;
    case 32:
      for (y = 0; y < height; y++) {
        q2 = vld1q_u16(a);
        q3 = vld1q_u16(b);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 8);
        q3 = vld1q_u16(b + 8);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 16);
        q3 = vld1q_u16(b + 16);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 24);
        q3 = vld1q_u16(b + 24);

        sse += highbd_sse_W8x1_neon(q2, q3);

        a += a_stride;
        b += b_stride;
      }
      break;
    case 64:
      for (y = 0; y < height; y++) {
        q2 = vld1q_u16(a);
        q3 = vld1q_u16(b);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 8);
        q3 = vld1q_u16(b + 8);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 16);
        q3 = vld1q_u16(b + 16);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 24);
        q3 = vld1q_u16(b + 24);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 32);
        q3 = vld1q_u16(b + 32);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 40);
        q3 = vld1q_u16(b + 40);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 48);
        q3 = vld1q_u16(b + 48);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 56);
        q3 = vld1q_u16(b + 56);

        sse += highbd_sse_W8x1_neon(q2, q3);

        a += a_stride;
        b += b_stride;
      }
      break;
    case 128:
      for (y = 0; y < height; y++) {
        q2 = vld1q_u16(a);
        q3 = vld1q_u16(b);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 8);
        q3 = vld1q_u16(b + 8);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 16);
        q3 = vld1q_u16(b + 16);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 24);
        q3 = vld1q_u16(b + 24);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 32);
        q3 = vld1q_u16(b + 32);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 40);
        q3 = vld1q_u16(b + 40);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 48);
        q3 = vld1q_u16(b + 48);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 56);
        q3 = vld1q_u16(b + 56);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 64);
        q3 = vld1q_u16(b + 64);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 72);
        q3 = vld1q_u16(b + 72);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 80);
        q3 = vld1q_u16(b + 80);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 88);
        q3 = vld1q_u16(b + 88);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 96);
        q3 = vld1q_u16(b + 96);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 104);
        q3 = vld1q_u16(b + 104);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 112);
        q3 = vld1q_u16(b + 112);

        sse += highbd_sse_W8x1_neon(q2, q3);

        q2 = vld1q_u16(a + 120);
        q3 = vld1q_u16(b + 120);

        sse += highbd_sse_W8x1_neon(q2, q3);
        a += a_stride;
        b += b_stride;
      }
      break;
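    // Widths not handled above: process 8 pixels at a time, left to right.
    // For the final partial group of a row, lanes past the row width are
    // zeroed with a mask built by comparing the lane indices in q0 against
    // the remaining count, so they do not contribute to the sum.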
    default:
      for (y = 0; y < height; y++) {
        x = width;
        while (x > 0) {
          addinc = width - x;
          q2 = vld1q_u16(a + addinc);
          q3 = vld1q_u16(b + addinc);
          if (x < 8) {
            dx = x;
            q4 = vld1q_dup_u16(&dx);
            q5 = vcltq_u16(q0, q4);
            q2 = vandq_u16(q2, q5);
            q3 = vandq_u16(q3, q5);
          }
          sse += highbd_sse_W8x1_neon(q2, q3);
          x -= 8;
        }
        a += a_stride;
        b += b_stride;
      }
  }
  return (int64_t)sse;
}
#endif  // CONFIG_AV1_HIGHBITDEPTH
