/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>

#include "aom_dsp/arm/aom_neon_sve_bridge.h"
#include "aom_dsp/arm/mem_neon.h"
#include "config/aom_dsp_rtcd.h"

// Sum of squared int16 elements over a 4-wide block. Two 4-element rows are
// packed into one 8-lane vector per iteration. Assumes height is even.
static INLINE uint64_t aom_sum_squares_2d_i16_4xh_sve(const int16_t *src,
                                                      int stride, int height) {
  int64x2_t acc = vdupq_n_s64(0);
  int h = height;

  do {
    const int16x8_t rows =
        vcombine_s16(vld1_s16(src), vld1_s16(src + stride));
    acc = aom_sdotq_s16(acc, rows, rows);

    h -= 2;
    src += 2 * stride;
  } while (h != 0);

  return (uint64_t)vaddvq_s64(acc);
}

// Sum of squared int16 elements over an 8-wide block. Two rows are processed
// per iteration into independent accumulators to shorten the dependency
// chain. Assumes height is even.
static INLINE uint64_t aom_sum_squares_2d_i16_8xh_sve(const int16_t *src,
                                                      int stride, int height) {
  int64x2_t acc0 = vdupq_n_s64(0);
  int64x2_t acc1 = vdupq_n_s64(0);

  do {
    const int16x8_t r0 = vld1q_s16(src);
    const int16x8_t r1 = vld1q_s16(src + stride);

    acc0 = aom_sdotq_s16(acc0, r0, r0);
    acc1 = aom_sdotq_s16(acc1, r1, r1);

    src += 2 * stride;
    height -= 2;
  } while (height != 0);

  return (uint64_t)vaddvq_s64(vaddq_s64(acc0, acc1));
}

// Sum of squared int16 elements for widths that are a multiple of 16.
// Each inner iteration consumes 16 elements split across two accumulators.
static INLINE uint64_t aom_sum_squares_2d_i16_large_sve(const int16_t *src,
                                                        int stride, int width,
                                                        int height) {
  int64x2_t acc0 = vdupq_n_s64(0);
  int64x2_t acc1 = vdupq_n_s64(0);

  int h = height;
  do {
    const int16_t *row = src;
    int remaining = width;
    do {
      const int16x8_t lo = vld1q_s16(row);
      const int16x8_t hi = vld1q_s16(row + 8);

      acc0 = aom_sdotq_s16(acc0, lo, lo);
      acc1 = aom_sdotq_s16(acc1, hi, hi);

      row += 16;
      remaining -= 16;
    } while (remaining != 0);

    src += stride;
  } while (--h != 0);

  return (uint64_t)vaddvq_s64(vaddq_s64(acc0, acc1));
}

// Generic-width fallback: predicated SVE loads cover the ragged tail of each
// row, so any width is handled without scalar cleanup code.
static INLINE uint64_t aom_sum_squares_2d_i16_wxh_sve(const int16_t *src,
                                                      int stride, int width,
                                                      int height) {
  svint64_t acc = svdup_n_s64(0);
  // Number of 16-bit lanes in one SVE vector on this machine.
  uint64_t lanes = svcnth();

  do {
    const int16_t *row = src;
    int col = 0;
    do {
      const svbool_t mask = svwhilelt_b16_u32(col, width);
      const svint16_t v = svld1_s16(mask, row);

      acc = svdot_s64(acc, v, v);

      row += lanes;
      col += lanes;
    } while (col < width);

    src += stride;
  } while (--height != 0);

  return (uint64_t)svaddv_s64(svptrue_b64(), acc);
}

// Dispatch on block width to the most efficient specialization; widths that
// are not 4, 8 or a multiple of 16 fall through to the predicated SVE path.
uint64_t aom_sum_squares_2d_i16_sve(const int16_t *src, int stride, int width,
                                    int height) {
  switch (width) {
    case 4: return aom_sum_squares_2d_i16_4xh_sve(src, stride, height);
    case 8: return aom_sum_squares_2d_i16_8xh_sve(src, stride, height);
    default: break;
  }
  if (width % 16 == 0) {
    return aom_sum_squares_2d_i16_large_sve(src, stride, width, height);
  }
  return aom_sum_squares_2d_i16_wxh_sve(src, stride, width, height);
}

// Sum of squared int16 elements over a flat buffer of n elements.
// Fast path requires n to be a multiple of 64; otherwise defer to the C
// reference implementation.
uint64_t aom_sum_squares_i16_sve(const int16_t *src, uint32_t n) {
  // This function seems to be called only for values of N >= 64. See
  // av1/encoder/compound_type.c. Additionally, because N = width x height for
  // width and height between the standard block sizes, N will also be a
  // multiple of 64.
  if (LIKELY(n % 64 == 0)) {
    // Four accumulators keep the dot-product pipeline busy; 32 elements are
    // consumed per iteration.
    int64x2_t sum[4] = { vdupq_n_s64(0), vdupq_n_s64(0), vdupq_n_s64(0),
                         vdupq_n_s64(0) };

    do {
      int16x8_t s0 = vld1q_s16(src);
      int16x8_t s1 = vld1q_s16(src + 8);
      int16x8_t s2 = vld1q_s16(src + 16);
      int16x8_t s3 = vld1q_s16(src + 24);

      sum[0] = aom_sdotq_s16(sum[0], s0, s0);
      sum[1] = aom_sdotq_s16(sum[1], s1, s1);
      sum[2] = aom_sdotq_s16(sum[2], s2, s2);
      sum[3] = aom_sdotq_s16(sum[3], s3, s3);

      src += 32;
      n -= 32;
    } while (n != 0);

    sum[0] = vaddq_s64(sum[0], sum[1]);
    sum[2] = vaddq_s64(sum[2], sum[3]);
    sum[0] = vaddq_s64(sum[0], sum[2]);
    // Explicit cast for consistency with the other kernels in this file;
    // the accumulated sum of squares is always non-negative.
    return (uint64_t)vaddvq_s64(sum[0]);
  }
  return aom_sum_squares_i16_c(src, n);
}

// Joint sum and sum-of-squares over a 4-wide block: the sum of squares is
// returned and the plain sum is accumulated into *sum. Assumes height even.
static INLINE uint64_t aom_sum_sse_2d_i16_4xh_sve(const int16_t *src,
                                                  int stride, int height,
                                                  int *sum) {
  int64x2_t sse_acc = vdupq_n_s64(0);
  int32x4_t sum_acc = vdupq_n_s32(0);

  do {
    const int16x8_t rows =
        vcombine_s16(vld1_s16(src), vld1_s16(src + stride));

    sse_acc = aom_sdotq_s16(sse_acc, rows, rows);
    sum_acc = vpadalq_s16(sum_acc, rows);

    height -= 2;
    src += 2 * stride;
  } while (height != 0);

  *sum += vaddvq_s32(sum_acc);
  return vaddvq_s64(sse_acc);
}

// Joint sum and sum-of-squares over an 8-wide block, two rows per iteration
// with paired accumulators. Assumes height even.
static INLINE uint64_t aom_sum_sse_2d_i16_8xh_sve(const int16_t *src,
                                                  int stride, int height,
                                                  int *sum) {
  int64x2_t sse0 = vdupq_n_s64(0);
  int64x2_t sse1 = vdupq_n_s64(0);
  int32x4_t sum0 = vdupq_n_s32(0);
  int32x4_t sum1 = vdupq_n_s32(0);

  do {
    const int16x8_t r0 = vld1q_s16(src);
    const int16x8_t r1 = vld1q_s16(src + stride);

    sse0 = aom_sdotq_s16(sse0, r0, r0);
    sse1 = aom_sdotq_s16(sse1, r1, r1);

    sum0 = vpadalq_s16(sum0, r0);
    sum1 = vpadalq_s16(sum1, r1);

    src += 2 * stride;
    height -= 2;
  } while (height != 0);

  *sum += vaddvq_s32(vaddq_s32(sum0, sum1));
  return vaddvq_s64(vaddq_s64(sse0, sse1));
}

// Joint sum and sum-of-squares for widths that are a multiple of 16.
// 16 elements are consumed per inner iteration.
static INLINE uint64_t aom_sum_sse_2d_i16_16xh_sve(const int16_t *src,
                                                   int stride, int width,
                                                   int height, int *sum) {
  int64x2_t sse0 = vdupq_n_s64(0);
  int64x2_t sse1 = vdupq_n_s64(0);
  int32x4_t sum0 = vdupq_n_s32(0);
  int32x4_t sum1 = vdupq_n_s32(0);

  do {
    const int16_t *row = src;
    int remaining = width;
    do {
      const int16x8_t lo = vld1q_s16(row);
      const int16x8_t hi = vld1q_s16(row + 8);

      sse0 = aom_sdotq_s16(sse0, lo, lo);
      sse1 = aom_sdotq_s16(sse1, hi, hi);

      sum0 = vpadalq_s16(sum0, lo);
      sum1 = vpadalq_s16(sum1, hi);

      row += 16;
      remaining -= 16;
    } while (remaining > 0);

    src += stride;
  } while (--height != 0);

  *sum += vaddvq_s32(vaddq_s32(sum0, sum1));
  return vaddvq_s64(vaddq_s64(sse0, sse1));
}

// Dispatch on width; widths that are not 4, 8 or a multiple of 16 fall back
// to the C reference implementation.
uint64_t aom_sum_sse_2d_i16_sve(const int16_t *src, int stride, int width,
                                int height, int *sum) {
  if (width == 4) {
    return aom_sum_sse_2d_i16_4xh_sve(src, stride, height, sum);
  }
  if (width == 8) {
    return aom_sum_sse_2d_i16_8xh_sve(src, stride, height, sum);
  }
  if (width % 16 == 0) {
    return aom_sum_sse_2d_i16_16xh_sve(src, stride, width, height, sum);
  }
  return aom_sum_sse_2d_i16_c(src, stride, width, height, sum);
}

// Sum of squared deviations (sse - sum^2 / N) for a 4-wide high-bitdepth
// block; src is a CONVERT_TO_SHORTPTR-wrapped uint16_t buffer. Assumes height
// is even.
static INLINE uint64_t aom_var_2d_u16_4xh_sve(uint8_t *src, int src_stride,
                                              int width, int height) {
  uint16_t *data = CONVERT_TO_SHORTPTR(src);
  uint32x4_t sum_acc = vdupq_n_u32(0);
  uint64x2_t sse_acc = vdupq_n_u64(0);

  int rows_left = height;
  do {
    const uint16x8_t s =
        vcombine_u16(vld1_u16(data), vld1_u16(data + src_stride));

    sum_acc = vpadalq_u16(sum_acc, s);
    sse_acc = aom_udotq_u16(sse_acc, s, s);

    data += 2 * src_stride;
    rows_left -= 2;
  } while (rows_left != 0);

  const uint64_t sum = vaddlvq_u32(sum_acc);
  const uint64_t sse = vaddvq_u64(sse_acc);

  return sse - sum * sum / (width * height);
}

// Sum of squared deviations (sse - sum^2 / N) for widths that are a multiple
// of 8; src is a CONVERT_TO_SHORTPTR-wrapped uint16_t buffer.
static INLINE uint64_t aom_var_2d_u16_8xh_sve(uint8_t *src, int src_stride,
                                              int width, int height) {
  uint16_t *data = CONVERT_TO_SHORTPTR(src);
  uint32x4_t sum_acc = vdupq_n_u32(0);
  uint64x2_t sse_acc = vdupq_n_u64(0);

  int rows_left = height;
  do {
    uint16_t *row = data;
    int remaining = width;
    do {
      const uint16x8_t s = vld1q_u16(row);

      sum_acc = vpadalq_u16(sum_acc, s);
      sse_acc = aom_udotq_u16(sse_acc, s, s);

      row += 8;
      remaining -= 8;
    } while (remaining != 0);

    data += src_stride;
  } while (--rows_left != 0);

  const uint64_t sum = vaddlvq_u32(sum_acc);
  const uint64_t sse = vaddvq_u64(sse_acc);

  return sse - sum * sum / (width * height);
}

// Sum of squared deviations (sse - sum^2 / N) for 16-wide high-bitdepth
// blocks, using paired accumulators per inner iteration.
static INLINE uint64_t aom_var_2d_u16_16xh_sve(uint8_t *src, int src_stride,
                                               int width, int height) {
  uint16_t *data = CONVERT_TO_SHORTPTR(src);
  uint32x4_t sum0 = vdupq_n_u32(0);
  uint32x4_t sum1 = vdupq_n_u32(0);
  uint64x2_t sse0 = vdupq_n_u64(0);
  uint64x2_t sse1 = vdupq_n_u64(0);

  int rows_left = height;
  do {
    uint16_t *row = data;
    int remaining = width;
    do {
      const uint16x8_t lo = vld1q_u16(row);
      const uint16x8_t hi = vld1q_u16(row + 8);

      sum0 = vpadalq_u16(sum0, lo);
      sum1 = vpadalq_u16(sum1, hi);

      sse0 = aom_udotq_u16(sse0, lo, lo);
      sse1 = aom_udotq_u16(sse1, hi, hi);

      row += 16;
      remaining -= 16;
    } while (remaining != 0);

    data += src_stride;
  } while (--rows_left != 0);

  const uint64_t sum = vaddlvq_u32(vaddq_u32(sum0, sum1));
  const uint64_t sse = vaddvq_u64(vaddq_u64(sse0, sse1));

  return sse - sum * sum / (width * height);
}

// Sum of squared deviations (sse - sum^2 / N) for widths that are a multiple
// of 32: four accumulator pairs process 32 elements per inner iteration.
static INLINE uint64_t aom_var_2d_u16_large_sve(uint8_t *src, int src_stride,
                                                int width, int height) {
  uint16_t *data = CONVERT_TO_SHORTPTR(src);
  uint32x4_t sum_acc[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                            vdupq_n_u32(0) };
  uint64x2_t sse_acc[4] = { vdupq_n_u64(0), vdupq_n_u64(0), vdupq_n_u64(0),
                            vdupq_n_u64(0) };

  int rows_left = height;
  do {
    uint16_t *row = data;
    int remaining = width;
    do {
      const uint16x8_t s0 = vld1q_u16(row);
      const uint16x8_t s1 = vld1q_u16(row + 8);
      const uint16x8_t s2 = vld1q_u16(row + 16);
      const uint16x8_t s3 = vld1q_u16(row + 24);

      sum_acc[0] = vpadalq_u16(sum_acc[0], s0);
      sum_acc[1] = vpadalq_u16(sum_acc[1], s1);
      sum_acc[2] = vpadalq_u16(sum_acc[2], s2);
      sum_acc[3] = vpadalq_u16(sum_acc[3], s3);

      sse_acc[0] = aom_udotq_u16(sse_acc[0], s0, s0);
      sse_acc[1] = aom_udotq_u16(sse_acc[1], s1, s1);
      sse_acc[2] = aom_udotq_u16(sse_acc[2], s2, s2);
      sse_acc[3] = aom_udotq_u16(sse_acc[3], s3, s3);

      row += 32;
      remaining -= 32;
    } while (remaining != 0);

    data += src_stride;
  } while (--rows_left != 0);

  // Pairwise reduction of the four accumulators before the horizontal adds.
  const uint32x4_t sum_01 = vaddq_u32(sum_acc[0], sum_acc[1]);
  const uint32x4_t sum_23 = vaddq_u32(sum_acc[2], sum_acc[3]);
  const uint64x2_t sse_01 = vaddq_u64(sse_acc[0], sse_acc[1]);
  const uint64x2_t sse_23 = vaddq_u64(sse_acc[2], sse_acc[3]);

  const uint64_t sum = vaddlvq_u32(vaddq_u32(sum_01, sum_23));
  const uint64_t sse = vaddvq_u64(vaddq_u64(sse_01, sse_23));

  return sse - sum * sum / (width * height);
}

// Dispatch on width for high-bitdepth blocks; widths that are not 4, 8, 16 or
// a multiple of 32 fall back to the NEON implementation.
uint64_t aom_var_2d_u16_sve(uint8_t *src, int src_stride, int width,
                            int height) {
  switch (width) {
    case 4: return aom_var_2d_u16_4xh_sve(src, src_stride, width, height);
    case 8: return aom_var_2d_u16_8xh_sve(src, src_stride, width, height);
    case 16: return aom_var_2d_u16_16xh_sve(src, src_stride, width, height);
    default: break;
  }
  if (width % 32 == 0) {
    return aom_var_2d_u16_large_sve(src, src_stride, width, height);
  }
  return aom_var_2d_u16_neon(src, src_stride, width, height);
}