/*
 * Copyright (c) 2023, Alliance for Open Media. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include <arm_neon.h>

#include "aom_dsp/arm/aom_neon_sve_bridge.h"
#include "aom_dsp/arm/mem_neon.h"
#include "config/aom_dsp_rtcd.h"

// Accumulate the sum of squared differences of 8 consecutive pixels:
// *sse += sum_{i=0}^{7} (src[i] - ref[i])^2.
static INLINE void highbd_sse_8x1_neon(const uint16_t *src, const uint16_t *ref,
                                       uint64x2_t *sse) {
  uint16x8_t s = vld1q_u16(src);
  uint16x8_t r = vld1q_u16(ref);

  // |s - r| as an unsigned 16-bit value; squaring it via a widening
  // 16x16 -> 64-bit dot product gives (s - r)^2 summed into *sse.
  uint16x8_t abs_diff = vabdq_u16(s, r);

  // aom_udotq_u16 is the Neon/SVE bridge wrapper around the SVE UDOT
  // instruction (unsigned dot product accumulating into 64-bit lanes).
  *sse = aom_udotq_u16(*sse, abs_diff, abs_diff);
}
26
// Sum of squared errors for a 128-pixel-wide block of the given height.
// Each row is processed as sixteen 8-pixel chunks, rotated over four
// independent accumulators to avoid serializing consecutive UDOTs.
static INLINE int64_t highbd_sse_128xh_sve(const uint16_t *src, int src_stride,
                                           const uint16_t *ref, int ref_stride,
                                           int height) {
  uint64x2_t sse[4] = { vdupq_n_u64(0), vdupq_n_u64(0), vdupq_n_u64(0),
                        vdupq_n_u64(0) };

  do {
    for (int i = 0; i < 16; i++) {
      highbd_sse_8x1_neon(src + i * 8, ref + i * 8, &sse[i & 3]);
    }

    src += src_stride;
    ref += ref_stride;
  } while (--height != 0);

  // Reduce the four vector accumulators to a single scalar.
  uint64x2_t total =
      vaddq_u64(vaddq_u64(sse[0], sse[1]), vaddq_u64(sse[2], sse[3]));
  return vaddvq_u64(total);
}
60
// Sum of squared errors for a 64-pixel-wide block of the given height.
// Each row is eight 8-pixel chunks, rotated over four accumulators.
static INLINE int64_t highbd_sse_64xh_sve(const uint16_t *src, int src_stride,
                                          const uint16_t *ref, int ref_stride,
                                          int height) {
  uint64x2_t sse[4] = { vdupq_n_u64(0), vdupq_n_u64(0), vdupq_n_u64(0),
                        vdupq_n_u64(0) };

  do {
    for (int i = 0; i < 8; i++) {
      highbd_sse_8x1_neon(src + i * 8, ref + i * 8, &sse[i & 3]);
    }

    src += src_stride;
    ref += ref_stride;
  } while (--height != 0);

  // Reduce the four vector accumulators to a single scalar.
  uint64x2_t total =
      vaddq_u64(vaddq_u64(sse[0], sse[1]), vaddq_u64(sse[2], sse[3]));
  return vaddvq_u64(total);
}
86
// Sum of squared errors for a 32-pixel-wide block of the given height.
// Each row is four 8-pixel chunks, one per accumulator.
static INLINE int64_t highbd_sse_32xh_sve(const uint16_t *src, int src_stride,
                                          const uint16_t *ref, int ref_stride,
                                          int height) {
  uint64x2_t sse[4] = { vdupq_n_u64(0), vdupq_n_u64(0), vdupq_n_u64(0),
                        vdupq_n_u64(0) };

  do {
    for (int i = 0; i < 4; i++) {
      highbd_sse_8x1_neon(src + i * 8, ref + i * 8, &sse[i]);
    }

    src += src_stride;
    ref += ref_stride;
  } while (--height != 0);

  // Reduce the four vector accumulators to a single scalar.
  uint64x2_t total =
      vaddq_u64(vaddq_u64(sse[0], sse[1]), vaddq_u64(sse[2], sse[3]));
  return vaddvq_u64(total);
}
108
// Sum of squared errors for a 16-pixel-wide block of the given height.
// Two accumulators keep the back-to-back UDOTs independent.
static INLINE int64_t highbd_sse_16xh_sve(const uint16_t *src, int src_stride,
                                          const uint16_t *ref, int ref_stride,
                                          int height) {
  uint64x2_t acc_lo = vdupq_n_u64(0);
  uint64x2_t acc_hi = vdupq_n_u64(0);

  do {
    highbd_sse_8x1_neon(src, ref, &acc_lo);
    highbd_sse_8x1_neon(src + 8, ref + 8, &acc_hi);

    src += src_stride;
    ref += ref_stride;
  } while (--height != 0);

  return vaddvq_u64(vaddq_u64(acc_lo, acc_hi));
}
124
// Sum of squared errors for an 8-pixel-wide block of the given height.
// Processes two rows per iteration into separate accumulators; callers
// only use even heights.
static INLINE int64_t highbd_sse_8xh_sve(const uint16_t *src, int src_stride,
                                         const uint16_t *ref, int ref_stride,
                                         int height) {
  uint64x2_t acc0 = vdupq_n_u64(0);
  uint64x2_t acc1 = vdupq_n_u64(0);

  do {
    highbd_sse_8x1_neon(src, ref, &acc0);
    highbd_sse_8x1_neon(src + src_stride, ref + ref_stride, &acc1);

    src += 2 * src_stride;
    ref += 2 * ref_stride;
    height -= 2;
  } while (height != 0);

  return vaddvq_u64(vaddq_u64(acc0, acc1));
}
141
// Sum of squared errors for a 4-pixel-wide block of the given height.
// Packs two 4-pixel rows into one 8-lane vector per iteration; callers
// only use even heights.
static INLINE int64_t highbd_sse_4xh_sve(const uint16_t *src, int src_stride,
                                         const uint16_t *ref, int ref_stride,
                                         int height) {
  uint64x2_t acc = vdupq_n_u64(0);

  do {
    const uint16x8_t s = load_unaligned_u16_4x2(src, src_stride);
    const uint16x8_t r = load_unaligned_u16_4x2(ref, ref_stride);
    const uint16x8_t diff = vabdq_u16(s, r);

    // Widening dot product: acc += sum(diff[i]^2) per 64-bit lane.
    acc = aom_udotq_u16(acc, diff, diff);

    src += 2 * src_stride;
    ref += 2 * ref_stride;
    height -= 2;
  } while (height != 0);

  return vaddvq_u64(acc);
}
161
// Sum of squared errors for arbitrary (non-specialized) widths using SVE
// predication: the ragged tail of each row is handled by a WHILELT-generated
// predicate instead of a scalar fall-back loop.
static INLINE int64_t highbd_sse_wxh_sve(const uint16_t *src, int src_stride,
                                         const uint16_t *ref, int ref_stride,
                                         int width, int height) {
  svuint64_t sse = svdup_n_u64(0);
  // Number of 16-bit lanes per SVE vector (vector-length dependent).
  uint64_t step = svcnth();

  do {
    int w = 0;
    const uint16_t *src_ptr = src;
    const uint16_t *ref_ptr = ref;

    do {
      // Lanes [w, min(w + step, width)) are active; inactive lanes load as
      // zero, and svabd_u16_z zeroes them again, so they contribute nothing
      // to the accumulator.
      svbool_t pred = svwhilelt_b16_u32(w, width);
      svuint16_t s = svld1_u16(pred, src_ptr);
      svuint16_t r = svld1_u16(pred, ref_ptr);

      svuint16_t abs_diff = svabd_u16_z(pred, s, r);

      // Widening 16x16 -> 64-bit dot product accumulates squared diffs.
      sse = svdot_u64(sse, abs_diff, abs_diff);

      src_ptr += step;
      ref_ptr += step;
      w += step;
    } while (w < width);

    src += src_stride;
    ref += ref_stride;
  } while (--height != 0);

  // Horizontal reduction of all 64-bit lanes to a scalar.
  return svaddv_u64(svptrue_b64(), sse);
}
193
// Sum of squared errors between two high-bitdepth blocks of
// width x height pixels. src8/ref8 are uint16_t planes passed through
// uint8_t pointers (see CONVERT_TO_SHORTPTR).
int64_t aom_highbd_sse_sve(const uint8_t *src8, int src_stride,
                           const uint8_t *ref8, int ref_stride, int width,
                           int height) {
  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);

  // Dispatch to a specialized kernel for the codec's power-of-two block
  // widths; anything else takes the predicated variable-width path.
  if (width == 4) {
    return highbd_sse_4xh_sve(src, src_stride, ref, ref_stride, height);
  }
  if (width == 8) {
    return highbd_sse_8xh_sve(src, src_stride, ref, ref_stride, height);
  }
  if (width == 16) {
    return highbd_sse_16xh_sve(src, src_stride, ref, ref_stride, height);
  }
  if (width == 32) {
    return highbd_sse_32xh_sve(src, src_stride, ref, ref_stride, height);
  }
  if (width == 64) {
    return highbd_sse_64xh_sve(src, src_stride, ref, ref_stride, height);
  }
  if (width == 128) {
    return highbd_sse_128xh_sve(src, src_stride, ref, ref_stride, height);
  }
  return highbd_sse_wxh_sve(src, src_stride, ref, ref_stride, width, height);
}
216