/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"
#include "aom/aom_integer.h"
#include "aom_dsp/blend.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"

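// Blend a0 and b0 with the mask m0 (a0 * m0 + b0 * (AOM_BLEND_A64_MAX_ALPHA -
// m0), rounded and narrowed by AOM_BLEND_A64_ROUND_BITS), then
// pairwise-accumulate the absolute differences against s0 into the 16-bit
// lanes of 'sad'.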
static INLINE uint16x8_t masked_sad_16x1_neon(uint16x8_t sad,
                                              const uint8x16_t s0,
                                              const uint8x16_t a0,
                                              const uint8x16_t b0,
                                              const uint8x16_t m0) {
  uint8x16_t m0_inv = vsubq_u8(vdupq_n_u8(AOM_BLEND_A64_MAX_ALPHA), m0);
  uint16x8_t blend_u16_lo = vmull_u8(vget_low_u8(m0), vget_low_u8(a0));
  uint16x8_t blend_u16_hi = vmull_u8(vget_high_u8(m0), vget_high_u8(a0));
  blend_u16_lo = vmlal_u8(blend_u16_lo, vget_low_u8(m0_inv), vget_low_u8(b0));
  blend_u16_hi = vmlal_u8(blend_u16_hi, vget_high_u8(m0_inv), vget_high_u8(b0));

  uint8x8_t blend_u8_lo = vrshrn_n_u16(blend_u16_lo, AOM_BLEND_A64_ROUND_BITS);
  uint8x8_t blend_u8_hi = vrshrn_n_u16(blend_u16_hi, AOM_BLEND_A64_ROUND_BITS);
  uint8x16_t blend_u8 = vcombine_u8(blend_u8_lo, blend_u8_hi);
  return vpadalq_u8(sad, vabdq_u8(blend_u8, s0));
}

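// Masked SAD for wide blocks (64 and 128 pixels), computed against four
// reference blocks at once, with the mask inverted: second_pred is weighted
// by the mask and the reference rows by (64 - mask). Rows are accumulated in
// 16-bit lanes and spilled into the 32-bit 'sum' accumulators every
// 'h_overflow' rows so the 16-bit lanes cannot overflow.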
static INLINE void masked_inv_sadwxhx4d_large_neon(
    const uint8_t *src, int src_stride, const uint8_t *const ref[4],
    int ref_stride, const uint8_t *second_pred, const uint8_t *mask,
    int mask_stride, uint32_t res[4], int width, int height, int h_overflow) {
  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                        vdupq_n_u32(0) };
  int h_limit = height > h_overflow ? h_overflow : height;

  int ref_offset = 0;
  int i = 0;
  do {
    uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                             vdupq_n_u16(0) };
    uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                             vdupq_n_u16(0) };

    do {
      int j = 0;
      do {
        uint8x16_t s0 = vld1q_u8(src + j);
        uint8x16_t p0 = vld1q_u8(second_pred + j);
        uint8x16_t m0 = vld1q_u8(mask + j);
        sum_lo[0] = masked_sad_16x1_neon(sum_lo[0], s0, p0,
                                         vld1q_u8(ref[0] + ref_offset + j), m0);
        sum_lo[1] = masked_sad_16x1_neon(sum_lo[1], s0, p0,
                                         vld1q_u8(ref[1] + ref_offset + j), m0);
        sum_lo[2] = masked_sad_16x1_neon(sum_lo[2], s0, p0,
                                         vld1q_u8(ref[2] + ref_offset + j), m0);
        sum_lo[3] = masked_sad_16x1_neon(sum_lo[3], s0, p0,
                                         vld1q_u8(ref[3] + ref_offset + j), m0);

        uint8x16_t s1 = vld1q_u8(src + j + 16);
        uint8x16_t p1 = vld1q_u8(second_pred + j + 16);
        uint8x16_t m1 = vld1q_u8(mask + j + 16);
        sum_hi[0] = masked_sad_16x1_neon(
            sum_hi[0], s1, p1, vld1q_u8(ref[0] + ref_offset + j + 16), m1);
        sum_hi[1] = masked_sad_16x1_neon(
            sum_hi[1], s1, p1, vld1q_u8(ref[1] + ref_offset + j + 16), m1);
        sum_hi[2] = masked_sad_16x1_neon(
            sum_hi[2], s1, p1, vld1q_u8(ref[2] + ref_offset + j + 16), m1);
        sum_hi[3] = masked_sad_16x1_neon(
            sum_hi[3], s1, p1, vld1q_u8(ref[3] + ref_offset + j + 16), m1);

        j += 32;
      } while (j < width);

      src += src_stride;
      ref_offset += ref_stride;
      second_pred += width;
      mask += mask_stride;
    } while (++i < h_limit);

    sum[0] = vpadalq_u16(sum[0], sum_lo[0]);
    sum[0] = vpadalq_u16(sum[0], sum_hi[0]);
    sum[1] = vpadalq_u16(sum[1], sum_lo[1]);
    sum[1] = vpadalq_u16(sum[1], sum_hi[1]);
    sum[2] = vpadalq_u16(sum[2], sum_lo[2]);
    sum[2] = vpadalq_u16(sum[2], sum_hi[2]);
    sum[3] = vpadalq_u16(sum[3], sum_lo[3]);
    sum[3] = vpadalq_u16(sum[3], sum_hi[3]);

    h_limit += h_overflow;
  } while (i < height);

  vst1q_u32(res, horizontal_add_4d_u32x4(sum));
}

static INLINE void masked_inv_sad128xhx4d_neon(
    const uint8_t *src, int src_stride, const uint8_t *const ref[4],
    int ref_stride, const uint8_t *second_pred, const uint8_t *mask,
    int mask_stride, uint32_t res[4], int h) {
  masked_inv_sadwxhx4d_large_neon(src, src_stride, ref, ref_stride, second_pred,
                                  mask, mask_stride, res, 128, h, 32);
}

static INLINE void masked_inv_sad64xhx4d_neon(
    const uint8_t *src, int src_stride, const uint8_t *const ref[4],
    int ref_stride, const uint8_t *second_pred, const uint8_t *mask,
    int mask_stride, uint32_t res[4], int h) {
  masked_inv_sadwxhx4d_large_neon(src, src_stride, ref, ref_stride, second_pred,
                                  mask, mask_stride, res, 64, h, 64);
}

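// Same as masked_inv_sadwxhx4d_large_neon, but with the mask applied the
// usual way round: the reference rows are weighted by the mask and
// second_pred by (64 - mask).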
static INLINE void masked_sadwxhx4d_large_neon(
    const uint8_t *src, int src_stride, const uint8_t *const ref[4],
    int ref_stride, const uint8_t *second_pred, const uint8_t *mask,
    int mask_stride, uint32_t res[4], int width, int height, int h_overflow) {
  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                        vdupq_n_u32(0) };
  int h_limit = height > h_overflow ? h_overflow : height;

  int ref_offset = 0;
  int i = 0;
  do {
    uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                             vdupq_n_u16(0) };
    uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                             vdupq_n_u16(0) };

    do {
      int j = 0;
      do {
        uint8x16_t s0 = vld1q_u8(src + j);
        uint8x16_t p0 = vld1q_u8(second_pred + j);
        uint8x16_t m0 = vld1q_u8(mask + j);
        sum_lo[0] = masked_sad_16x1_neon(
            sum_lo[0], s0, vld1q_u8(ref[0] + ref_offset + j), p0, m0);
        sum_lo[1] = masked_sad_16x1_neon(
            sum_lo[1], s0, vld1q_u8(ref[1] + ref_offset + j), p0, m0);
        sum_lo[2] = masked_sad_16x1_neon(
            sum_lo[2], s0, vld1q_u8(ref[2] + ref_offset + j), p0, m0);
        sum_lo[3] = masked_sad_16x1_neon(
            sum_lo[3], s0, vld1q_u8(ref[3] + ref_offset + j), p0, m0);

        uint8x16_t s1 = vld1q_u8(src + j + 16);
        uint8x16_t p1 = vld1q_u8(second_pred + j + 16);
        uint8x16_t m1 = vld1q_u8(mask + j + 16);
        sum_hi[0] = masked_sad_16x1_neon(
            sum_hi[0], s1, vld1q_u8(ref[0] + ref_offset + j + 16), p1, m1);
        sum_hi[1] = masked_sad_16x1_neon(
            sum_hi[1], s1, vld1q_u8(ref[1] + ref_offset + j + 16), p1, m1);
        sum_hi[2] = masked_sad_16x1_neon(
            sum_hi[2], s1, vld1q_u8(ref[2] + ref_offset + j + 16), p1, m1);
        sum_hi[3] = masked_sad_16x1_neon(
            sum_hi[3], s1, vld1q_u8(ref[3] + ref_offset + j + 16), p1, m1);

        j += 32;
      } while (j < width);

      src += src_stride;
      ref_offset += ref_stride;
      second_pred += width;
      mask += mask_stride;
    } while (++i < h_limit);

    sum[0] = vpadalq_u16(sum[0], sum_lo[0]);
    sum[0] = vpadalq_u16(sum[0], sum_hi[0]);
    sum[1] = vpadalq_u16(sum[1], sum_lo[1]);
    sum[1] = vpadalq_u16(sum[1], sum_hi[1]);
    sum[2] = vpadalq_u16(sum[2], sum_lo[2]);
    sum[2] = vpadalq_u16(sum[2], sum_hi[2]);
    sum[3] = vpadalq_u16(sum[3], sum_lo[3]);
    sum[3] = vpadalq_u16(sum[3], sum_hi[3]);

    h_limit += h_overflow;
  } while (i < height);

  vst1q_u32(res, horizontal_add_4d_u32x4(sum));
}

static INLINE void masked_sad128xhx4d_neon(const uint8_t *src, int src_stride,
                                           const uint8_t *const ref[4],
                                           int ref_stride,
                                           const uint8_t *second_pred,
                                           const uint8_t *mask, int mask_stride,
                                           uint32_t res[4], int h) {
  masked_sadwxhx4d_large_neon(src, src_stride, ref, ref_stride, second_pred,
                              mask, mask_stride, res, 128, h, 32);
}

static INLINE void masked_sad64xhx4d_neon(const uint8_t *src, int src_stride,
                                          const uint8_t *const ref[4],
                                          int ref_stride,
                                          const uint8_t *second_pred,
                                          const uint8_t *mask, int mask_stride,
                                          uint32_t res[4], int h) {
  masked_sadwxhx4d_large_neon(src, src_stride, ref, ref_stride, second_pred,
                              mask, mask_stride, res, 64, h, 64);
}

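// 32-pixel-wide variants: a row fits in two 16-bit accumulator vectors per
// reference, so the whole block is summed in 16-bit lanes and widened to
// 32 bits only once at the end.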
static INLINE void masked_inv_sad32xhx4d_neon(
    const uint8_t *src, int src_stride, const uint8_t *const ref[4],
    int ref_stride, const uint8_t *second_pred, const uint8_t *mask,
    int mask_stride, uint32_t res[4], int h) {
  uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                           vdupq_n_u16(0) };
  uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                           vdupq_n_u16(0) };

  int ref_offset = 0;
  int i = h;
  do {
    uint8x16_t s0 = vld1q_u8(src);
    uint8x16_t p0 = vld1q_u8(second_pred);
    uint8x16_t m0 = vld1q_u8(mask);
    sum_lo[0] = masked_sad_16x1_neon(sum_lo[0], s0, p0,
                                     vld1q_u8(ref[0] + ref_offset), m0);
    sum_lo[1] = masked_sad_16x1_neon(sum_lo[1], s0, p0,
                                     vld1q_u8(ref[1] + ref_offset), m0);
    sum_lo[2] = masked_sad_16x1_neon(sum_lo[2], s0, p0,
                                     vld1q_u8(ref[2] + ref_offset), m0);
    sum_lo[3] = masked_sad_16x1_neon(sum_lo[3], s0, p0,
                                     vld1q_u8(ref[3] + ref_offset), m0);

    uint8x16_t s1 = vld1q_u8(src + 16);
    uint8x16_t p1 = vld1q_u8(second_pred + 16);
    uint8x16_t m1 = vld1q_u8(mask + 16);
    sum_hi[0] = masked_sad_16x1_neon(sum_hi[0], s1, p1,
                                     vld1q_u8(ref[0] + ref_offset + 16), m1);
    sum_hi[1] = masked_sad_16x1_neon(sum_hi[1], s1, p1,
                                     vld1q_u8(ref[1] + ref_offset + 16), m1);
    sum_hi[2] = masked_sad_16x1_neon(sum_hi[2], s1, p1,
                                     vld1q_u8(ref[2] + ref_offset + 16), m1);
    sum_hi[3] = masked_sad_16x1_neon(sum_hi[3], s1, p1,
                                     vld1q_u8(ref[3] + ref_offset + 16), m1);

    src += src_stride;
    ref_offset += ref_stride;
    second_pred += 32;
    mask += mask_stride;
  } while (--i != 0);

  vst1q_u32(res, horizontal_long_add_4d_u16x8(sum_lo, sum_hi));
}

static INLINE void masked_sad32xhx4d_neon(const uint8_t *src, int src_stride,
                                          const uint8_t *const ref[4],
                                          int ref_stride,
                                          const uint8_t *second_pred,
                                          const uint8_t *mask, int mask_stride,
                                          uint32_t res[4], int h) {
  uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                           vdupq_n_u16(0) };
  uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                           vdupq_n_u16(0) };

  int ref_offset = 0;
  int i = h;
  do {
    uint8x16_t s0 = vld1q_u8(src);
    uint8x16_t p0 = vld1q_u8(second_pred);
    uint8x16_t m0 = vld1q_u8(mask);
    sum_lo[0] = masked_sad_16x1_neon(sum_lo[0], s0,
                                     vld1q_u8(ref[0] + ref_offset), p0, m0);
    sum_lo[1] = masked_sad_16x1_neon(sum_lo[1], s0,
                                     vld1q_u8(ref[1] + ref_offset), p0, m0);
    sum_lo[2] = masked_sad_16x1_neon(sum_lo[2], s0,
                                     vld1q_u8(ref[2] + ref_offset), p0, m0);
    sum_lo[3] = masked_sad_16x1_neon(sum_lo[3], s0,
                                     vld1q_u8(ref[3] + ref_offset), p0, m0);

    uint8x16_t s1 = vld1q_u8(src + 16);
    uint8x16_t p1 = vld1q_u8(second_pred + 16);
    uint8x16_t m1 = vld1q_u8(mask + 16);
    sum_hi[0] = masked_sad_16x1_neon(
        sum_hi[0], s1, vld1q_u8(ref[0] + ref_offset + 16), p1, m1);
    sum_hi[1] = masked_sad_16x1_neon(
        sum_hi[1], s1, vld1q_u8(ref[1] + ref_offset + 16), p1, m1);
    sum_hi[2] = masked_sad_16x1_neon(
        sum_hi[2], s1, vld1q_u8(ref[2] + ref_offset + 16), p1, m1);
    sum_hi[3] = masked_sad_16x1_neon(
        sum_hi[3], s1, vld1q_u8(ref[3] + ref_offset + 16), p1, m1);

    src += src_stride;
    ref_offset += ref_stride;
    second_pred += 32;
    mask += mask_stride;
  } while (--i != 0);

  vst1q_u32(res, horizontal_long_add_4d_u16x8(sum_lo, sum_hi));
}

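// 16-pixel-wide variants: each row is accumulated into a single 16-bit vector
// per reference and widened to 32 bits after the row loop.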
static INLINE void masked_inv_sad16xhx4d_neon(
    const uint8_t *src, int src_stride, const uint8_t *const ref[4],
    int ref_stride, const uint8_t *second_pred, const uint8_t *mask,
    int mask_stride, uint32_t res[4], int h) {
  uint16x8_t sum_u16[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                            vdupq_n_u16(0) };
  uint32x4_t sum_u32[4];

  int ref_offset = 0;
  int i = h;
  do {
    uint8x16_t s0 = vld1q_u8(src);
    uint8x16_t p0 = vld1q_u8(second_pred);
    uint8x16_t m0 = vld1q_u8(mask);
    sum_u16[0] = masked_sad_16x1_neon(sum_u16[0], s0, p0,
                                      vld1q_u8(ref[0] + ref_offset), m0);
    sum_u16[1] = masked_sad_16x1_neon(sum_u16[1], s0, p0,
                                      vld1q_u8(ref[1] + ref_offset), m0);
    sum_u16[2] = masked_sad_16x1_neon(sum_u16[2], s0, p0,
                                      vld1q_u8(ref[2] + ref_offset), m0);
    sum_u16[3] = masked_sad_16x1_neon(sum_u16[3], s0, p0,
                                      vld1q_u8(ref[3] + ref_offset), m0);

    src += src_stride;
    ref_offset += ref_stride;
    second_pred += 16;
    mask += mask_stride;
  } while (--i != 0);

  sum_u32[0] = vpaddlq_u16(sum_u16[0]);
  sum_u32[1] = vpaddlq_u16(sum_u16[1]);
  sum_u32[2] = vpaddlq_u16(sum_u16[2]);
  sum_u32[3] = vpaddlq_u16(sum_u16[3]);

  vst1q_u32(res, horizontal_add_4d_u32x4(sum_u32));
}

static INLINE void masked_sad16xhx4d_neon(const uint8_t *src, int src_stride,
                                          const uint8_t *const ref[4],
                                          int ref_stride,
                                          const uint8_t *second_pred,
                                          const uint8_t *mask, int mask_stride,
                                          uint32_t res[4], int h) {
  uint16x8_t sum_u16[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                            vdupq_n_u16(0) };
  uint32x4_t sum_u32[4];

  int ref_offset = 0;
  int i = h;
  do {
    uint8x16_t s0 = vld1q_u8(src);
    uint8x16_t p0 = vld1q_u8(second_pred);
    uint8x16_t m0 = vld1q_u8(mask);
    sum_u16[0] = masked_sad_16x1_neon(sum_u16[0], s0,
                                      vld1q_u8(ref[0] + ref_offset), p0, m0);
    sum_u16[1] = masked_sad_16x1_neon(sum_u16[1], s0,
                                      vld1q_u8(ref[1] + ref_offset), p0, m0);
    sum_u16[2] = masked_sad_16x1_neon(sum_u16[2], s0,
                                      vld1q_u8(ref[2] + ref_offset), p0, m0);
    sum_u16[3] = masked_sad_16x1_neon(sum_u16[3], s0,
                                      vld1q_u8(ref[3] + ref_offset), p0, m0);

    src += src_stride;
    ref_offset += ref_stride;
    second_pred += 16;
    mask += mask_stride;
  } while (--i != 0);

  sum_u32[0] = vpaddlq_u16(sum_u16[0]);
  sum_u32[1] = vpaddlq_u16(sum_u16[1]);
  sum_u32[2] = vpaddlq_u16(sum_u16[2]);
  sum_u32[3] = vpaddlq_u16(sum_u16[3]);

  vst1q_u32(res, horizontal_add_4d_u32x4(sum_u32));
}

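// 8-pixel counterpart of masked_sad_16x1_neon: blend a0 and b0 with the mask
// m0, then accumulate the absolute differences against s0 into the 16-bit
// lanes of 'sad'.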
static INLINE uint16x8_t masked_sad_8x1_neon(uint16x8_t sad, const uint8x8_t s0,
                                             const uint8x8_t a0,
                                             const uint8x8_t b0,
                                             const uint8x8_t m0) {
  uint8x8_t m0_inv = vsub_u8(vdup_n_u8(AOM_BLEND_A64_MAX_ALPHA), m0);
  uint16x8_t blend_u16 = vmull_u8(m0, a0);
  blend_u16 = vmlal_u8(blend_u16, m0_inv, b0);

  uint8x8_t blend_u8 = vrshrn_n_u16(blend_u16, AOM_BLEND_A64_ROUND_BITS);
  return vabal_u8(sad, blend_u8, s0);
}

static INLINE void masked_inv_sad8xhx4d_neon(
    const uint8_t *src, int src_stride, const uint8_t *const ref[4],
    int ref_stride, const uint8_t *second_pred, const uint8_t *mask,
    int mask_stride, uint32_t res[4], int h) {
  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                        vdupq_n_u16(0) };

  int ref_offset = 0;
  int i = h;
  do {
    uint8x8_t s0 = vld1_u8(src);
    uint8x8_t p0 = vld1_u8(second_pred);
    uint8x8_t m0 = vld1_u8(mask);
    sum[0] =
        masked_sad_8x1_neon(sum[0], s0, p0, vld1_u8(ref[0] + ref_offset), m0);
    sum[1] =
        masked_sad_8x1_neon(sum[1], s0, p0, vld1_u8(ref[1] + ref_offset), m0);
    sum[2] =
        masked_sad_8x1_neon(sum[2], s0, p0, vld1_u8(ref[2] + ref_offset), m0);
    sum[3] =
        masked_sad_8x1_neon(sum[3], s0, p0, vld1_u8(ref[3] + ref_offset), m0);

    src += src_stride;
    ref_offset += ref_stride;
    second_pred += 8;
    mask += mask_stride;
  } while (--i != 0);

  vst1q_u32(res, horizontal_add_4d_u16x8(sum));
}

static INLINE void masked_sad8xhx4d_neon(const uint8_t *src, int src_stride,
                                         const uint8_t *const ref[4],
                                         int ref_stride,
                                         const uint8_t *second_pred,
                                         const uint8_t *mask, int mask_stride,
                                         uint32_t res[4], int h) {
  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                        vdupq_n_u16(0) };

  int ref_offset = 0;
  int i = h;
  do {
    uint8x8_t s0 = vld1_u8(src);
    uint8x8_t p0 = vld1_u8(second_pred);
    uint8x8_t m0 = vld1_u8(mask);

    sum[0] =
        masked_sad_8x1_neon(sum[0], s0, vld1_u8(ref[0] + ref_offset), p0, m0);
    sum[1] =
        masked_sad_8x1_neon(sum[1], s0, vld1_u8(ref[1] + ref_offset), p0, m0);
    sum[2] =
        masked_sad_8x1_neon(sum[2], s0, vld1_u8(ref[2] + ref_offset), p0, m0);
    sum[3] =
        masked_sad_8x1_neon(sum[3], s0, vld1_u8(ref[3] + ref_offset), p0, m0);

    src += src_stride;
    ref_offset += ref_stride;
    second_pred += 8;
    mask += mask_stride;
  } while (--i != 0);

  vst1q_u32(res, horizontal_add_4d_u16x8(sum));
}

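// 4-pixel-wide variants: two rows are packed into each 8-byte vector, so the
// loop processes the block two rows at a time.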
static INLINE void masked_inv_sad4xhx4d_neon(
    const uint8_t *src, int src_stride, const uint8_t *const ref[4],
    int ref_stride, const uint8_t *second_pred, const uint8_t *mask,
    int mask_stride, uint32_t res[4], int h) {
  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                        vdupq_n_u16(0) };

  int ref_offset = 0;
  int i = h / 2;
  do {
    uint8x8_t s = load_unaligned_u8(src, src_stride);
    uint8x8_t r0 = load_unaligned_u8(ref[0] + ref_offset, ref_stride);
    uint8x8_t r1 = load_unaligned_u8(ref[1] + ref_offset, ref_stride);
    uint8x8_t r2 = load_unaligned_u8(ref[2] + ref_offset, ref_stride);
    uint8x8_t r3 = load_unaligned_u8(ref[3] + ref_offset, ref_stride);
    uint8x8_t p0 = vld1_u8(second_pred);
    uint8x8_t m0 = load_unaligned_u8(mask, mask_stride);

    sum[0] = masked_sad_8x1_neon(sum[0], s, p0, r0, m0);
    sum[1] = masked_sad_8x1_neon(sum[1], s, p0, r1, m0);
    sum[2] = masked_sad_8x1_neon(sum[2], s, p0, r2, m0);
    sum[3] = masked_sad_8x1_neon(sum[3], s, p0, r3, m0);

    src += 2 * src_stride;
    ref_offset += 2 * ref_stride;
    second_pred += 2 * 4;
    mask += 2 * mask_stride;
  } while (--i != 0);

  vst1q_u32(res, horizontal_add_4d_u16x8(sum));
}

static INLINE void masked_sad4xhx4d_neon(const uint8_t *src, int src_stride,
                                         const uint8_t *const ref[4],
                                         int ref_stride,
                                         const uint8_t *second_pred,
                                         const uint8_t *mask, int mask_stride,
                                         uint32_t res[4], int h) {
  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                        vdupq_n_u16(0) };

  int ref_offset = 0;
  int i = h / 2;
  do {
    uint8x8_t s = load_unaligned_u8(src, src_stride);
    uint8x8_t r0 = load_unaligned_u8(ref[0] + ref_offset, ref_stride);
    uint8x8_t r1 = load_unaligned_u8(ref[1] + ref_offset, ref_stride);
    uint8x8_t r2 = load_unaligned_u8(ref[2] + ref_offset, ref_stride);
    uint8x8_t r3 = load_unaligned_u8(ref[3] + ref_offset, ref_stride);
    uint8x8_t p0 = vld1_u8(second_pred);
    uint8x8_t m0 = load_unaligned_u8(mask, mask_stride);

    sum[0] = masked_sad_8x1_neon(sum[0], s, r0, p0, m0);
    sum[1] = masked_sad_8x1_neon(sum[1], s, r1, p0, m0);
    sum[2] = masked_sad_8x1_neon(sum[2], s, r2, p0, m0);
    sum[3] = masked_sad_8x1_neon(sum[3], s, r3, p0, m0);

    src += 2 * src_stride;
    ref_offset += 2 * ref_stride;
    second_pred += 2 * 4;
    mask += 2 * mask_stride;
  } while (--i != 0);

  vst1q_u32(res, horizontal_add_4d_u16x8(sum));
}

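// Dispatch between the normal and inverted-mask kernels: when invert_mask is
// set, the mask weights second_pred rather than the reference block.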
#define MASKED_SAD4D_WXH_NEON(w, h)                                            \
  void aom_masked_sad##w##x##h##x4d_neon(                                      \
      const uint8_t *src, int src_stride, const uint8_t *ref[4],               \
      int ref_stride, const uint8_t *second_pred, const uint8_t *msk,          \
      int msk_stride, int invert_mask, uint32_t res[4]) {                      \
    if (invert_mask) {                                                         \
      masked_inv_sad##w##xhx4d_neon(src, src_stride, ref, ref_stride,          \
                                    second_pred, msk, msk_stride, res, h);     \
    } else {                                                                   \
      masked_sad##w##xhx4d_neon(src, src_stride, ref, ref_stride, second_pred, \
                                msk, msk_stride, res, h);                      \
    }                                                                          \
  }

MASKED_SAD4D_WXH_NEON(4, 8)
MASKED_SAD4D_WXH_NEON(4, 4)

MASKED_SAD4D_WXH_NEON(8, 16)
MASKED_SAD4D_WXH_NEON(8, 8)
MASKED_SAD4D_WXH_NEON(8, 4)

MASKED_SAD4D_WXH_NEON(16, 32)
MASKED_SAD4D_WXH_NEON(16, 16)
MASKED_SAD4D_WXH_NEON(16, 8)

MASKED_SAD4D_WXH_NEON(32, 64)
MASKED_SAD4D_WXH_NEON(32, 32)
MASKED_SAD4D_WXH_NEON(32, 16)

MASKED_SAD4D_WXH_NEON(64, 128)
MASKED_SAD4D_WXH_NEON(64, 64)
MASKED_SAD4D_WXH_NEON(64, 32)

MASKED_SAD4D_WXH_NEON(128, 128)
MASKED_SAD4D_WXH_NEON(128, 64)

#if !CONFIG_REALTIME_ONLY
MASKED_SAD4D_WXH_NEON(4, 16)
MASKED_SAD4D_WXH_NEON(16, 4)
MASKED_SAD4D_WXH_NEON(8, 32)
MASKED_SAD4D_WXH_NEON(32, 8)
MASKED_SAD4D_WXH_NEON(16, 64)
MASKED_SAD4D_WXH_NEON(64, 16)
#endif