/*
 * Copyright (c) 2023, Alliance for Open Media. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */
10
11 #include <arm_neon.h>
12
13 #include "config/aom_dsp_rtcd.h"
14 #include "aom_dsp/arm/sum_neon.h"
15
// Seeds a pair of squared-error accumulators from one run of 8 samples.
//
// Loads 8 uint16_t values from `src` and from `ref`, takes the per-lane
// absolute difference and squares it with a widening multiply. The squares of
// lanes 0-3 are stored to *sse_acc0 and those of lanes 4-7 to *sse_acc1; any
// previous contents of the accumulators are overwritten.
static INLINE void highbd_sse_8x1_init_neon(const uint16_t *src,
                                            const uint16_t *ref,
                                            uint32x4_t *sse_acc0,
                                            uint32x4_t *sse_acc1) {
  // |src - ref| stays within 16 unsigned bits, so no sign handling is needed.
  const uint16x8_t diff = vabdq_u16(vld1q_u16(src), vld1q_u16(ref));
  const uint16x4_t diff_lo = vget_low_u16(diff);
  const uint16x4_t diff_hi = vget_high_u16(diff);

  // 16-bit x 16-bit -> 32-bit products.
  *sse_acc0 = vmull_u16(diff_lo, diff_lo);
  *sse_acc1 = vmull_u16(diff_hi, diff_hi);
}
30
// Folds the squared differences of one run of 8 samples into a pair of
// running accumulators.
//
// Loads 8 uint16_t values from `src` and from `ref`, takes the per-lane
// absolute difference, and adds the square of lanes 0-3 to *sse_acc0 and the
// square of lanes 4-7 to *sse_acc1. Both accumulators must already hold valid
// values on entry.
static INLINE void highbd_sse_8x1_neon(const uint16_t *src, const uint16_t *ref,
                                       uint32x4_t *sse_acc0,
                                       uint32x4_t *sse_acc1) {
  const uint16x8_t diff = vabdq_u16(vld1q_u16(src), vld1q_u16(ref));
  const uint16x4_t diff_lo = vget_low_u16(diff);
  const uint16x4_t diff_hi = vget_high_u16(diff);

  // Widening multiply-accumulate into the 32-bit lanes.
  *sse_acc0 = vmlal_u16(*sse_acc0, diff_lo, diff_lo);
  *sse_acc1 = vmlal_u16(*sse_acc1, diff_hi, diff_hi);
}
44
// Sum of squared differences between two blocks of 16-bit samples that are
// 128 samples wide and `height` rows tall.
//
// A row consists of sixteen 8-sample chunks. Chunk c is folded into the
// accumulator pair (c mod 8), so each 32-bit accumulator lane receives two
// squared differences per row. The first row is peeled so that the first
// eight chunks seed the accumulators through the init helper.
//
// The row loop counts height - 1 down to zero, so `height` must be >= 1.
// NOTE(review): the accumulator lanes are 32 bits wide; presumably the sample
// bit depth and the heights used by callers keep them from wrapping - confirm.
static INLINE int64_t highbd_sse_128xh_neon(const uint16_t *src, int src_stride,
                                            const uint16_t *ref, int ref_stride,
                                            int height) {
  uint32x4_t sse[16];

  // Row 0, chunks 0-7: seed every accumulator pair.
  for (int c = 0; c < 8; ++c) {
    highbd_sse_8x1_init_neon(src + c * 8, ref + c * 8, &sse[2 * c],
                             &sse[2 * c + 1]);
  }
  // Row 0, chunks 8-15: fold into the pairs seeded above.
  for (int c = 8; c < 16; ++c) {
    const int pair = 2 * (c - 8);
    highbd_sse_8x1_neon(src + c * 8, ref + c * 8, &sse[pair], &sse[pair + 1]);
  }

  // Remaining rows: all sixteen chunks accumulate.
  for (int rows_left = height - 1; rows_left != 0; --rows_left) {
    src += src_stride;
    ref += ref_stride;

    for (int c = 0; c < 16; ++c) {
      const int pair = 2 * (c & 7);
      highbd_sse_8x1_neon(src + c * 8, ref + c * 8, &sse[pair],
                          &sse[pair + 1]);
    }
  }

  return horizontal_long_add_u32x4_x16(sse);
}
93
// Sum of squared differences between two blocks of 16-bit samples that are
// 64 samples wide and `height` rows tall.
//
// A row consists of eight 8-sample chunks. Chunk c is folded into the
// accumulator pair (c mod 4), so each 32-bit accumulator lane receives two
// squared differences per row. The first row is peeled so that the first four
// chunks seed the accumulators through the init helper.
//
// The row loop counts height - 1 down to zero, so `height` must be >= 1.
static INLINE int64_t highbd_sse_64xh_neon(const uint16_t *src, int src_stride,
                                           const uint16_t *ref, int ref_stride,
                                           int height) {
  uint32x4_t sse[8];

  // Row 0, chunks 0-3: seed every accumulator pair.
  for (int c = 0; c < 4; ++c) {
    highbd_sse_8x1_init_neon(src + c * 8, ref + c * 8, &sse[2 * c],
                             &sse[2 * c + 1]);
  }
  // Row 0, chunks 4-7: fold into the pairs seeded above.
  for (int c = 4; c < 8; ++c) {
    const int pair = 2 * (c - 4);
    highbd_sse_8x1_neon(src + c * 8, ref + c * 8, &sse[pair], &sse[pair + 1]);
  }

  // Remaining rows: all eight chunks accumulate.
  for (int rows_left = height - 1; rows_left != 0; --rows_left) {
    src += src_stride;
    ref += ref_stride;

    for (int c = 0; c < 8; ++c) {
      const int pair = 2 * (c & 3);
      highbd_sse_8x1_neon(src + c * 8, ref + c * 8, &sse[pair],
                          &sse[pair + 1]);
    }
  }

  return horizontal_long_add_u32x4_x8(sse);
}
126
// Sum of squared differences between two blocks of 16-bit samples that are
// 32 samples wide and `height` rows tall.
//
// Each of the four 8-sample chunks of a row owns one accumulator pair. The
// first row is peeled so that it seeds the accumulators through the init
// helper; every later row is added on top.
//
// The row loop counts height - 1 down to zero, so `height` must be >= 1.
static INLINE int64_t highbd_sse_32xh_neon(const uint16_t *src, int src_stride,
                                           const uint16_t *ref, int ref_stride,
                                           int height) {
  uint32x4_t sse[8];

  // Row 0 seeds the accumulators.
  for (int c = 0; c < 4; ++c) {
    highbd_sse_8x1_init_neon(src + c * 8, ref + c * 8, &sse[2 * c],
                             &sse[2 * c + 1]);
  }

  // Remaining rows accumulate.
  for (int rows_left = height - 1; rows_left != 0; --rows_left) {
    src += src_stride;
    ref += ref_stride;

    for (int c = 0; c < 4; ++c) {
      highbd_sse_8x1_neon(src + c * 8, ref + c * 8, &sse[2 * c],
                          &sse[2 * c + 1]);
    }
  }

  return horizontal_long_add_u32x4_x8(sse);
}
151
// Sum of squared differences between two blocks of 16-bit samples that are
// 16 samples wide and `height` rows tall.
//
// Each of the two 8-sample chunks of a row owns one accumulator pair. The
// first row is peeled so that it seeds the accumulators through the init
// helper; every later row is added on top.
//
// The row loop counts height - 1 down to zero, so `height` must be >= 1.
static INLINE int64_t highbd_sse_16xh_neon(const uint16_t *src, int src_stride,
                                           const uint16_t *ref, int ref_stride,
                                           int height) {
  uint32x4_t sse[4];

  // Row 0 seeds the accumulators.
  for (int c = 0; c < 2; ++c) {
    highbd_sse_8x1_init_neon(src + c * 8, ref + c * 8, &sse[2 * c],
                             &sse[2 * c + 1]);
  }

  // Remaining rows accumulate.
  for (int rows_left = height - 1; rows_left != 0; --rows_left) {
    src += src_stride;
    ref += ref_stride;

    for (int c = 0; c < 2; ++c) {
      highbd_sse_8x1_neon(src + c * 8, ref + c * 8, &sse[2 * c],
                          &sse[2 * c + 1]);
    }
  }

  return horizontal_long_add_u32x4_x4(sse);
}
172
// Sum of squared differences between two blocks of 16-bit samples that are
// 8 samples wide and `height` rows tall.
//
// A single accumulator pair covers the whole row: acc[0] takes lanes 0-3 and
// acc[1] takes lanes 4-7. The first row is peeled so that it seeds the pair
// through the init helper.
//
// The row loop counts height - 1 down to zero, so `height` must be >= 1.
static INLINE int64_t highbd_sse_8xh_neon(const uint16_t *src, int src_stride,
                                          const uint16_t *ref, int ref_stride,
                                          int height) {
  uint32x4_t acc[2];

  // Row 0 seeds the accumulators.
  highbd_sse_8x1_init_neon(src, ref, &acc[0], &acc[1]);

  // Remaining rows accumulate.
  for (int rows_left = height - 1; rows_left != 0; --rows_left) {
    src += src_stride;
    ref += ref_stride;

    highbd_sse_8x1_neon(src, ref, &acc[0], &acc[1]);
  }

  return horizontal_long_add_u32x4_x2(acc);
}
191
// Sum of squared differences between two blocks of 16-bit samples that are
// 4 samples wide and `height` rows tall.
//
// Each row fits in one 64-bit vector; its squared differences are kept in a
// single vector of four 32-bit lanes that is reduced once at the end.
//
// The row loop counts height - 1 down to zero, so `height` must be >= 1.
static INLINE int64_t highbd_sse_4xh_neon(const uint16_t *src, int src_stride,
                                          const uint16_t *ref, int ref_stride,
                                          int height) {
  // Row 0 is handled up front so the accumulator starts from a widening
  // multiply instead of being zeroed first.
  uint16x4_t diff = vabd_u16(vld1_u16(src), vld1_u16(ref));
  uint32x4_t acc = vmull_u16(diff, diff);

  for (int rows_left = height - 1; rows_left != 0; --rows_left) {
    src += src_stride;
    ref += ref_stride;

    diff = vabd_u16(vld1_u16(src), vld1_u16(ref));
    acc = vmlal_u16(acc, diff, diff);
  }

  return horizontal_long_add_u32x4(acc);
}
218
// Sum of squared differences between two `width` x `height` blocks of 16-bit
// samples, for widths that have no dedicated kernel.
//
// Every row is consumed in chunks of 8 samples. Full chunks are loaded
// straight from memory. When `width` is not a multiple of 8, the last
// (width & 7) samples of the row are first copied into zero-padded scratch
// buffers and loaded from there, so nothing beyond the `width` samples of a
// row is read from `src` or `ref`. The padding lanes are zero in both
// operands and therefore add nothing to the sum.
//
// Each chunk's squares are reduced into the 64-bit total immediately, so the
// 32-bit lanes never hold more than two products.
//
// The row loop runs until --height reaches zero, so `height` must be >= 1.
static INLINE int64_t highbd_sse_wxh_neon(const uint16_t *src, int src_stride,
                                          const uint16_t *ref, int ref_stride,
                                          int width, int height) {
  uint64_t sse = 0;

  do {
    for (int x = 0; x < width; x += 8) {
      const int remaining = width - x;
      uint16x8_t s;
      uint16x8_t r;

      if (remaining >= 8) {
        s = vld1q_u16(src + x);
        r = vld1q_u16(ref + x);
      } else {
        // Partial chunk: a direct 8-lane load would run past the end of the
        // row, so stage the valid samples in zero-filled buffers instead.
        uint16_t s_tail[8] = { 0 };
        uint16_t r_tail[8] = { 0 };
        for (int i = 0; i < remaining; ++i) {
          s_tail[i] = src[x + i];
          r_tail[i] = ref[x + i];
        }
        s = vld1q_u16(s_tail);
        r = vld1q_u16(r_tail);
      }

      const uint16x8_t abs_diff = vabdq_u16(s, r);
      const uint16x4_t abs_diff_lo = vget_low_u16(abs_diff);
      const uint16x4_t abs_diff_hi = vget_high_u16(abs_diff);

      uint32x4_t sse_u32 = vmull_u16(abs_diff_lo, abs_diff_lo);
      sse_u32 = vmlal_u16(sse_u32, abs_diff_hi, abs_diff_hi);

      sse += horizontal_long_add_u32x4(sse_u32);
    }

    src += src_stride;
    ref += ref_stride;
  } while (--height != 0);

  return (int64_t)sse;
}
260
// Sum of squared differences between two `width` x `height` blocks of
// high-bit-depth samples.
//
// `src8` and `ref8` are passed through CONVERT_TO_SHORTPTR to obtain the
// uint16_t sample pointers; the strides are applied to those uint16_t
// pointers, i.e. they are counted in samples. Widths of 4, 8, 16, 32, 64 and
// 128 go to dedicated kernels; every other width takes the generic path.
//
// Returns 0 when `width` or `height` is not positive. All kernels process
// the first row unconditionally and then count the remaining rows down to
// zero, so a non-positive height must never reach them.
int64_t aom_highbd_sse_neon(const uint8_t *src8, int src_stride,
                            const uint8_t *ref8, int ref_stride, int width,
                            int height) {
  // An empty region has no error, and the kernels cannot represent it.
  if (width <= 0 || height <= 0) {
    return 0;
  }

  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);

  switch (width) {
    case 4:
      return highbd_sse_4xh_neon(src, src_stride, ref, ref_stride, height);
    case 8:
      return highbd_sse_8xh_neon(src, src_stride, ref, ref_stride, height);
    case 16:
      return highbd_sse_16xh_neon(src, src_stride, ref, ref_stride, height);
    case 32:
      return highbd_sse_32xh_neon(src, src_stride, ref, ref_stride, height);
    case 64:
      return highbd_sse_64xh_neon(src, src_stride, ref, ref_stride, height);
    case 128:
      return highbd_sse_128xh_neon(src, src_stride, ref, ref_stride, height);
    default:
      return highbd_sse_wxh_neon(src, src_stride, ref, ref_stride, width,
                                 height);
  }
}
285