/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>

#include "aom_dsp/arm/aom_neon_sve_bridge.h"
#include "aom_dsp/arm/mem_neon.h"
#include "config/aom_dsp_rtcd.h"

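// Squared-sum of a 4-wide block: pack two rows of four 16-bit samples into a
// single 128-bit vector and accumulate the squares into 64-bit lanes with the
// SVE signed dot product (via the Neon/SVE bridge helper aom_sdotq_s16).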
static INLINE uint64_t aom_sum_squares_2d_i16_4xh_sve(const int16_t *src,
                                                      int stride, int height) {
  int64x2_t sum_squares = vdupq_n_s64(0);

  do {
    int16x8_t s = vcombine_s16(vld1_s16(src), vld1_s16(src + stride));

    sum_squares = aom_sdotq_s16(sum_squares, s, s);

    src += 2 * stride;
    height -= 2;
  } while (height != 0);

  return (uint64_t)vaddvq_s64(sum_squares);
}

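// 8-wide variant: one full vector per row, with two accumulators so the two
// dot products issued per iteration are independent of each other.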
static INLINE uint64_t aom_sum_squares_2d_i16_8xh_sve(const int16_t *src,
                                                      int stride, int height) {
  int64x2_t sum_squares[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };

  do {
    int16x8_t s0 = vld1q_s16(src + 0 * stride);
    int16x8_t s1 = vld1q_s16(src + 1 * stride);

    sum_squares[0] = aom_sdotq_s16(sum_squares[0], s0, s0);
    sum_squares[1] = aom_sdotq_s16(sum_squares[1], s1, s1);

    src += 2 * stride;
    height -= 2;
  } while (height != 0);

  sum_squares[0] = vaddq_s64(sum_squares[0], sum_squares[1]);
  return (uint64_t)vaddvq_s64(sum_squares[0]);
}

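// Widths that are a multiple of 16: walk each row 16 elements at a time.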
static INLINE uint64_t aom_sum_squares_2d_i16_large_sve(const int16_t *src,
                                                        int stride, int width,
                                                        int height) {
  int64x2_t sum_squares[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };

  do {
    const int16_t *src_ptr = src;
    int w = width;
    do {
      int16x8_t s0 = vld1q_s16(src_ptr);
      int16x8_t s1 = vld1q_s16(src_ptr + 8);

      sum_squares[0] = aom_sdotq_s16(sum_squares[0], s0, s0);
      sum_squares[1] = aom_sdotq_s16(sum_squares[1], s1, s1);

      src_ptr += 16;
      w -= 16;
    } while (w != 0);

    src += stride;
  } while (--height != 0);

  sum_squares[0] = vaddq_s64(sum_squares[0], sum_squares[1]);
  return (uint64_t)vaddvq_s64(sum_squares[0]);
}

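// Arbitrary widths: use native SVE predication (svwhilelt + svld1) so any
// row tail is masked off rather than needing a scalar fixup loop. svcnth()
// is the number of 16-bit lanes in one SVE vector.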
static INLINE uint64_t aom_sum_squares_2d_i16_wxh_sve(const int16_t *src,
                                                      int stride, int width,
                                                      int height) {
  svint64_t sum_squares = svdup_n_s64(0);
  uint64_t step = svcnth();

  do {
    const int16_t *src_ptr = src;
    int w = 0;
    do {
      svbool_t pred = svwhilelt_b16_u32(w, width);
      svint16_t s0 = svld1_s16(pred, src_ptr);

      sum_squares = svdot_s64(sum_squares, s0, s0);

      src_ptr += step;
      w += step;
    } while (w < width);

    src += stride;
  } while (--height != 0);

  return (uint64_t)svaddv_s64(svptrue_b64(), sum_squares);
}

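// Dispatch on block width. Widths 4 and 8 use fixed-size kernels, multiples
// of 16 use the unrolled kernel, and everything else takes the predicated
// SVE loop. All paths compute the same quantity as the scalar sketch below
// (illustrative only, not the actual aom_sum_squares_2d_i16_c
// implementation):
//
//   uint64_t ss = 0;
//   for (int r = 0; r < height; ++r)
//     for (int c = 0; c < width; ++c)
//       ss += (uint64_t)((int64_t)src[r * stride + c] * src[r * stride + c]);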
uint64_t aom_sum_squares_2d_i16_sve(const int16_t *src, int stride, int width,
                                    int height) {
  if (width == 4) {
    return aom_sum_squares_2d_i16_4xh_sve(src, stride, height);
  }
  if (width == 8) {
    return aom_sum_squares_2d_i16_8xh_sve(src, stride, height);
  }
  if (width % 16 == 0) {
    return aom_sum_squares_2d_i16_large_sve(src, stride, width, height);
  }
  return aom_sum_squares_2d_i16_wxh_sve(src, stride, width, height);
}

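// 1-D squared-sum over a flat buffer of n elements.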
uint64_t aom_sum_squares_i16_sve(const int16_t *src, uint32_t n) {
  // This function seems to be called only for values of N >= 64. See
  // av1/encoder/compound_type.c. Additionally, because N = width x height for
  // width and height between the standard block sizes, N will also be a
  // multiple of 64.
  if (LIKELY(n % 64 == 0)) {
    int64x2_t sum[4] = { vdupq_n_s64(0), vdupq_n_s64(0), vdupq_n_s64(0),
                         vdupq_n_s64(0) };

    do {
      int16x8_t s0 = vld1q_s16(src);
      int16x8_t s1 = vld1q_s16(src + 8);
      int16x8_t s2 = vld1q_s16(src + 16);
      int16x8_t s3 = vld1q_s16(src + 24);

      sum[0] = aom_sdotq_s16(sum[0], s0, s0);
      sum[1] = aom_sdotq_s16(sum[1], s1, s1);
      sum[2] = aom_sdotq_s16(sum[2], s2, s2);
      sum[3] = aom_sdotq_s16(sum[3], s3, s3);

      src += 32;
      n -= 32;
    } while (n != 0);

    sum[0] = vaddq_s64(sum[0], sum[1]);
    sum[2] = vaddq_s64(sum[2], sum[3]);
    sum[0] = vaddq_s64(sum[0], sum[2]);
    return (uint64_t)vaddvq_s64(sum[0]);
  }
  return aom_sum_squares_i16_c(src, n);
}

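// The aom_sum_sse_* kernels compute the element sum and the sum of squares in
// a single pass: aom_sdotq_s16 accumulates the squares into 64-bit lanes
// while vpadalq_s16 pairwise-widens the raw samples into a 32-bit sum.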
static INLINE uint64_t aom_sum_sse_2d_i16_4xh_sve(const int16_t *src,
                                                  int stride, int height,
                                                  int *sum) {
  int64x2_t sse = vdupq_n_s64(0);
  int32x4_t sum_s32 = vdupq_n_s32(0);

  do {
    int16x8_t s = vcombine_s16(vld1_s16(src), vld1_s16(src + stride));

    sse = aom_sdotq_s16(sse, s, s);

    sum_s32 = vpadalq_s16(sum_s32, s);

    src += 2 * stride;
    height -= 2;
  } while (height != 0);

  *sum += vaddvq_s32(sum_s32);
  return vaddvq_s64(sse);
}

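// 8-wide joint sum/SSE, processing two rows per iteration.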
static INLINE uint64_t aom_sum_sse_2d_i16_8xh_sve(const int16_t *src,
                                                  int stride, int height,
                                                  int *sum) {
  int64x2_t sse[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };
  int32x4_t sum_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  do {
    int16x8_t s0 = vld1q_s16(src);
    int16x8_t s1 = vld1q_s16(src + stride);

    sse[0] = aom_sdotq_s16(sse[0], s0, s0);
    sse[1] = aom_sdotq_s16(sse[1], s1, s1);

    sum_acc[0] = vpadalq_s16(sum_acc[0], s0);
    sum_acc[1] = vpadalq_s16(sum_acc[1], s1);

    src += 2 * stride;
    height -= 2;
  } while (height != 0);

  *sum += vaddvq_s32(vaddq_s32(sum_acc[0], sum_acc[1]));
  return vaddvq_s64(vaddq_s64(sse[0], sse[1]));
}

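// Joint sum/SSE for widths that are a multiple of 16.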
static INLINE uint64_t aom_sum_sse_2d_i16_16xh_sve(const int16_t *src,
                                                   int stride, int width,
                                                   int height, int *sum) {
  int64x2_t sse[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };
  int32x4_t sum_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  do {
    int w = 0;
    do {
      int16x8_t s0 = vld1q_s16(src + w);
      int16x8_t s1 = vld1q_s16(src + w + 8);

      sse[0] = aom_sdotq_s16(sse[0], s0, s0);
      sse[1] = aom_sdotq_s16(sse[1], s1, s1);

      sum_acc[0] = vpadalq_s16(sum_acc[0], s0);
      sum_acc[1] = vpadalq_s16(sum_acc[1], s1);

      w += 16;
    } while (w < width);

    src += stride;
  } while (--height != 0);

  *sum += vaddvq_s32(vaddq_s32(sum_acc[0], sum_acc[1]));
  return vaddvq_s64(vaddq_s64(sse[0], sse[1]));
}

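// Dispatch on block width; widths that are not 4, 8 or a multiple of 16 fall
// back to the C implementation.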
uint64_t aom_sum_sse_2d_i16_sve(const int16_t *src, int stride, int width,
                                int height, int *sum) {
  uint64_t sse;

  if (width == 4) {
    sse = aom_sum_sse_2d_i16_4xh_sve(src, stride, height, sum);
  } else if (width == 8) {
    sse = aom_sum_sse_2d_i16_8xh_sve(src, stride, height, sum);
  } else if (width % 16 == 0) {
    sse = aom_sum_sse_2d_i16_16xh_sve(src, stride, width, height, sum);
  } else {
    sse = aom_sum_sse_2d_i16_c(src, stride, width, height, sum);
  }

  return sse;
}

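// Highbd variance helpers: src is a CONVERT_TO_SHORTPTR() pointer to 16-bit
// pixels. Each helper returns the sum of squared deviations from the mean
// via the single-pass identity
//   sum((x - mean)^2) = sse - sum^2 / N, with N = width * height
// (sum^2 / N uses truncating integer division), so only the running sum and
// sse need to be tracked inside the loops.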
static INLINE uint64_t aom_var_2d_u16_4xh_sve(uint8_t *src, int src_stride,
                                              int width, int height) {
  uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x4_t sum_u32 = vdupq_n_u32(0);
  uint64x2_t sse_u64 = vdupq_n_u64(0);

  int h = height;
  do {
    uint16x8_t s0 =
        vcombine_u16(vld1_u16(src_u16), vld1_u16(src_u16 + src_stride));

    sum_u32 = vpadalq_u16(sum_u32, s0);

    sse_u64 = aom_udotq_u16(sse_u64, s0, s0);

    src_u16 += 2 * src_stride;
    h -= 2;
  } while (h != 0);

  sum += vaddlvq_u32(sum_u32);
  sse += vaddvq_u64(sse_u64);

  return sse - sum * sum / (width * height);
}

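// 8-wide variance helper; the inner loop walks each row 8 elements at a time.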
static INLINE uint64_t aom_var_2d_u16_8xh_sve(uint8_t *src, int src_stride,
                                              int width, int height) {
  uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x4_t sum_u32 = vdupq_n_u32(0);
  uint64x2_t sse_u64 = vdupq_n_u64(0);

  int h = height;
  do {
    int w = width;
    uint16_t *src_ptr = src_u16;
    do {
      uint16x8_t s0 = vld1q_u16(src_ptr);

      sum_u32 = vpadalq_u16(sum_u32, s0);

      sse_u64 = aom_udotq_u16(sse_u64, s0, s0);

      src_ptr += 8;
      w -= 8;
    } while (w != 0);

    src_u16 += src_stride;
  } while (--h != 0);

  sum += vaddlvq_u32(sum_u32);
  sse += vaddvq_u64(sse_u64);

  return sse - sum * sum / (width * height);
}

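// 16-wide variance helper with split sum/sse accumulators for the two halves
// of each row.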
static INLINE uint64_t aom_var_2d_u16_16xh_sve(uint8_t *src, int src_stride,
                                               int width, int height) {
  uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x4_t sum_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
  uint64x2_t sse_u64[2] = { vdupq_n_u64(0), vdupq_n_u64(0) };

  int h = height;
  do {
    int w = width;
    uint16_t *src_ptr = src_u16;
    do {
      uint16x8_t s0 = vld1q_u16(src_ptr);
      uint16x8_t s1 = vld1q_u16(src_ptr + 8);

      sum_u32[0] = vpadalq_u16(sum_u32[0], s0);
      sum_u32[1] = vpadalq_u16(sum_u32[1], s1);

      sse_u64[0] = aom_udotq_u16(sse_u64[0], s0, s0);
      sse_u64[1] = aom_udotq_u16(sse_u64[1], s1, s1);

      src_ptr += 16;
      w -= 16;
    } while (w != 0);

    src_u16 += src_stride;
  } while (--h != 0);

  sum_u32[0] = vaddq_u32(sum_u32[0], sum_u32[1]);
  sse_u64[0] = vaddq_u64(sse_u64[0], sse_u64[1]);

  sum += vaddlvq_u32(sum_u32[0]);
  sse += vaddvq_u64(sse_u64[0]);

  return sse - sum * sum / (width * height);
}

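// Four-way unrolled variance helper for widths that are a multiple of 32.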
static INLINE uint64_t aom_var_2d_u16_large_sve(uint8_t *src, int src_stride,
                                                int width, int height) {
  uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x4_t sum_u32[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                            vdupq_n_u32(0) };
  uint64x2_t sse_u64[4] = { vdupq_n_u64(0), vdupq_n_u64(0), vdupq_n_u64(0),
                            vdupq_n_u64(0) };

  int h = height;
  do {
    int w = width;
    uint16_t *src_ptr = src_u16;
    do {
      uint16x8_t s0 = vld1q_u16(src_ptr);
      uint16x8_t s1 = vld1q_u16(src_ptr + 8);
      uint16x8_t s2 = vld1q_u16(src_ptr + 16);
      uint16x8_t s3 = vld1q_u16(src_ptr + 24);

      sum_u32[0] = vpadalq_u16(sum_u32[0], s0);
      sum_u32[1] = vpadalq_u16(sum_u32[1], s1);
      sum_u32[2] = vpadalq_u16(sum_u32[2], s2);
      sum_u32[3] = vpadalq_u16(sum_u32[3], s3);

      sse_u64[0] = aom_udotq_u16(sse_u64[0], s0, s0);
      sse_u64[1] = aom_udotq_u16(sse_u64[1], s1, s1);
      sse_u64[2] = aom_udotq_u16(sse_u64[2], s2, s2);
      sse_u64[3] = aom_udotq_u16(sse_u64[3], s3, s3);

      src_ptr += 32;
      w -= 32;
    } while (w != 0);

    src_u16 += src_stride;
  } while (--h != 0);

  sum_u32[0] = vaddq_u32(sum_u32[0], sum_u32[1]);
  sum_u32[2] = vaddq_u32(sum_u32[2], sum_u32[3]);
  sum_u32[0] = vaddq_u32(sum_u32[0], sum_u32[2]);
  sse_u64[0] = vaddq_u64(sse_u64[0], sse_u64[1]);
  sse_u64[2] = vaddq_u64(sse_u64[2], sse_u64[3]);
  sse_u64[0] = vaddq_u64(sse_u64[0], sse_u64[2]);

  sum += vaddlvq_u32(sum_u32[0]);
  sse += vaddvq_u64(sse_u64[0]);

  return sse - sum * sum / (width * height);
}

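// Dispatch on block width; widths not handled above fall back to the Neon
// implementation.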
uint64_t aom_var_2d_u16_sve(uint8_t *src, int src_stride, int width,
                            int height) {
  if (width == 4) {
    return aom_var_2d_u16_4xh_sve(src, src_stride, width, height);
  }
  if (width == 8) {
    return aom_var_2d_u16_8xh_sve(src, src_stride, width, height);
  }
  if (width == 16) {
    return aom_var_2d_u16_16xh_sve(src, src_stride, width, height);
  }
  if (width % 32 == 0) {
    return aom_var_2d_u16_large_sve(src, src_stride, width, height);
  }
  return aom_var_2d_u16_neon(src, src_stride, width, height);
}