1 /*
2 * Copyright (c) 2023, Alliance for Open Media. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include <arm_neon.h>
12
13 #include "config/aom_dsp_rtcd.h"
14 #include "aom_dsp/arm/mem_neon.h"
15 #include "aom_dsp/arm/sum_neon.h"
16
sse_16x1_neon_dotprod(const uint8_t * src,const uint8_t * ref,uint32x4_t * sse)17 static INLINE void sse_16x1_neon_dotprod(const uint8_t *src, const uint8_t *ref,
18 uint32x4_t *sse) {
19 uint8x16_t s = vld1q_u8(src);
20 uint8x16_t r = vld1q_u8(ref);
21
22 uint8x16_t abs_diff = vabdq_u8(s, r);
23
24 *sse = vdotq_u32(*sse, abs_diff, abs_diff);
25 }
26
sse_8x1_neon_dotprod(const uint8_t * src,const uint8_t * ref,uint32x2_t * sse)27 static INLINE void sse_8x1_neon_dotprod(const uint8_t *src, const uint8_t *ref,
28 uint32x2_t *sse) {
29 uint8x8_t s = vld1_u8(src);
30 uint8x8_t r = vld1_u8(ref);
31
32 uint8x8_t abs_diff = vabd_u8(s, r);
33
34 *sse = vdot_u32(*sse, abs_diff, abs_diff);
35 }
36
sse_4x2_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,uint32x2_t * sse)37 static INLINE void sse_4x2_neon_dotprod(const uint8_t *src, int src_stride,
38 const uint8_t *ref, int ref_stride,
39 uint32x2_t *sse) {
40 uint8x8_t s = load_unaligned_u8(src, src_stride);
41 uint8x8_t r = load_unaligned_u8(ref, ref_stride);
42
43 uint8x8_t abs_diff = vabd_u8(s, r);
44
45 *sse = vdot_u32(*sse, abs_diff, abs_diff);
46 }
47
sse_wxh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int width,int height)48 static INLINE uint32_t sse_wxh_neon_dotprod(const uint8_t *src, int src_stride,
49 const uint8_t *ref, int ref_stride,
50 int width, int height) {
51 uint32x2_t sse[2] = { vdup_n_u32(0), vdup_n_u32(0) };
52
53 if ((width & 0x07) && ((width & 0x07) < 5)) {
54 int i = height;
55 do {
56 int j = 0;
57 do {
58 sse_8x1_neon_dotprod(src + j, ref + j, &sse[0]);
59 sse_8x1_neon_dotprod(src + j + src_stride, ref + j + ref_stride,
60 &sse[1]);
61 j += 8;
62 } while (j + 4 < width);
63
64 sse_4x2_neon_dotprod(src + j, src_stride, ref + j, ref_stride, &sse[0]);
65 src += 2 * src_stride;
66 ref += 2 * ref_stride;
67 i -= 2;
68 } while (i != 0);
69 } else {
70 int i = height;
71 do {
72 int j = 0;
73 do {
74 sse_8x1_neon_dotprod(src + j, ref + j, &sse[0]);
75 sse_8x1_neon_dotprod(src + j + src_stride, ref + j + ref_stride,
76 &sse[1]);
77 j += 8;
78 } while (j < width);
79
80 src += 2 * src_stride;
81 ref += 2 * ref_stride;
82 i -= 2;
83 } while (i != 0);
84 }
85 return horizontal_add_u32x4(vcombine_u32(sse[0], sse[1]));
86 }
87
sse_128xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)88 static INLINE uint32_t sse_128xh_neon_dotprod(const uint8_t *src,
89 int src_stride,
90 const uint8_t *ref,
91 int ref_stride, int height) {
92 uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
93
94 int i = height;
95 do {
96 sse_16x1_neon_dotprod(src, ref, &sse[0]);
97 sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]);
98 sse_16x1_neon_dotprod(src + 32, ref + 32, &sse[0]);
99 sse_16x1_neon_dotprod(src + 48, ref + 48, &sse[1]);
100 sse_16x1_neon_dotprod(src + 64, ref + 64, &sse[0]);
101 sse_16x1_neon_dotprod(src + 80, ref + 80, &sse[1]);
102 sse_16x1_neon_dotprod(src + 96, ref + 96, &sse[0]);
103 sse_16x1_neon_dotprod(src + 112, ref + 112, &sse[1]);
104
105 src += src_stride;
106 ref += ref_stride;
107 } while (--i != 0);
108
109 return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
110 }
111
sse_64xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)112 static INLINE uint32_t sse_64xh_neon_dotprod(const uint8_t *src, int src_stride,
113 const uint8_t *ref, int ref_stride,
114 int height) {
115 uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
116
117 int i = height;
118 do {
119 sse_16x1_neon_dotprod(src, ref, &sse[0]);
120 sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]);
121 sse_16x1_neon_dotprod(src + 32, ref + 32, &sse[0]);
122 sse_16x1_neon_dotprod(src + 48, ref + 48, &sse[1]);
123
124 src += src_stride;
125 ref += ref_stride;
126 } while (--i != 0);
127
128 return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
129 }
130
sse_32xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)131 static INLINE uint32_t sse_32xh_neon_dotprod(const uint8_t *src, int src_stride,
132 const uint8_t *ref, int ref_stride,
133 int height) {
134 uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
135
136 int i = height;
137 do {
138 sse_16x1_neon_dotprod(src, ref, &sse[0]);
139 sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]);
140
141 src += src_stride;
142 ref += ref_stride;
143 } while (--i != 0);
144
145 return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
146 }
147
sse_16xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)148 static INLINE uint32_t sse_16xh_neon_dotprod(const uint8_t *src, int src_stride,
149 const uint8_t *ref, int ref_stride,
150 int height) {
151 uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
152
153 int i = height;
154 do {
155 sse_16x1_neon_dotprod(src, ref, &sse[0]);
156 src += src_stride;
157 ref += ref_stride;
158 sse_16x1_neon_dotprod(src, ref, &sse[1]);
159 src += src_stride;
160 ref += ref_stride;
161 i -= 2;
162 } while (i != 0);
163
164 return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
165 }
166
sse_8xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)167 static INLINE uint32_t sse_8xh_neon_dotprod(const uint8_t *src, int src_stride,
168 const uint8_t *ref, int ref_stride,
169 int height) {
170 uint32x2_t sse[2] = { vdup_n_u32(0), vdup_n_u32(0) };
171
172 int i = height;
173 do {
174 sse_8x1_neon_dotprod(src, ref, &sse[0]);
175 src += src_stride;
176 ref += ref_stride;
177 sse_8x1_neon_dotprod(src, ref, &sse[1]);
178 src += src_stride;
179 ref += ref_stride;
180 i -= 2;
181 } while (i != 0);
182
183 return horizontal_add_u32x4(vcombine_u32(sse[0], sse[1]));
184 }
185
sse_4xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)186 static INLINE uint32_t sse_4xh_neon_dotprod(const uint8_t *src, int src_stride,
187 const uint8_t *ref, int ref_stride,
188 int height) {
189 uint32x2_t sse = vdup_n_u32(0);
190
191 int i = height;
192 do {
193 sse_4x2_neon_dotprod(src, src_stride, ref, ref_stride, &sse);
194
195 src += 2 * src_stride;
196 ref += 2 * ref_stride;
197 i -= 2;
198 } while (i != 0);
199
200 return horizontal_add_u32x2(sse);
201 }
202
// Sum of squared differences between two 8-bit pixel blocks.
//
// Widths of 4, 8, 16, 32, 64 and 128 are routed to a kernel specialised for
// that width; every other width goes to the generic sse_wxh_neon_dotprod().
// The kernels compute a uint32_t, which is widened to int64_t on return.
//
// src, ref:               top-left byte of each block.
// src_stride, ref_stride: distance in bytes between consecutive rows.
// width, height:          block dimensions; the constraints on `height` are
//                         those of the selected kernel.
int64_t aom_sse_neon_dotprod(const uint8_t *src, int src_stride,
                             const uint8_t *ref, int ref_stride, int width,
                             int height) {
  if (width == 4) {
    return sse_4xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
  }
  if (width == 8) {
    return sse_8xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
  }
  if (width == 16) {
    return sse_16xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
  }
  if (width == 32) {
    return sse_32xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
  }
  if (width == 64) {
    return sse_64xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
  }
  if (width == 128) {
    return sse_128xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
  }

  // No specialised kernel for this width.
  return sse_wxh_neon_dotprod(src, src_stride, ref, ref_stride, width, height);
}
224