/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "config/aom_dsp_rtcd.h"

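// The horizontal_*add* and load_unaligned_* helpers used below come from
// aom_dsp/arm/sum_neon.h and aom_dsp/arm/mem_neon.h. A rough sketch of the
// assumed semantics (not reproduced from the headers):
//   horizontal_add_s32x4(v)       -> sum of the four int32 lanes of v.
//   horizontal_long_add_u32x2(v)  -> uint64 sum of the two uint32 lanes of v.
//   horizontal_long_add_u32x4(v)  -> uint64 sum of the four uint32 lanes of v.
//   horizontal_add_u64x2(v)       -> sum of the two uint64 lanes of v.
//   load_unaligned_u8(p, stride)  -> packs 4 bytes from p and 4 bytes from
//                                    p + stride into one uint8x8_t.
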
static INLINE uint64_t aom_sum_squares_2d_i16_4x4_neon(const int16_t *src,
                                                       int stride) {
  int16x4_t s0 = vld1_s16(src + 0 * stride);
  int16x4_t s1 = vld1_s16(src + 1 * stride);
  int16x4_t s2 = vld1_s16(src + 2 * stride);
  int16x4_t s3 = vld1_s16(src + 3 * stride);

  int32x4_t sum_squares = vmull_s16(s0, s0);
  sum_squares = vmlal_s16(sum_squares, s1, s1);
  sum_squares = vmlal_s16(sum_squares, s2, s2);
  sum_squares = vmlal_s16(sum_squares, s3, s3);

  return horizontal_long_add_u32x4(vreinterpretq_u32_s32(sum_squares));
}

static INLINE uint64_t aom_sum_squares_2d_i16_4xn_neon(const int16_t *src,
                                                       int stride, int height) {
  int32x4_t sum_squares[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  int h = height;
  do {
    int16x4_t s0 = vld1_s16(src + 0 * stride);
    int16x4_t s1 = vld1_s16(src + 1 * stride);
    int16x4_t s2 = vld1_s16(src + 2 * stride);
    int16x4_t s3 = vld1_s16(src + 3 * stride);

    sum_squares[0] = vmlal_s16(sum_squares[0], s0, s0);
    sum_squares[0] = vmlal_s16(sum_squares[0], s1, s1);
    sum_squares[1] = vmlal_s16(sum_squares[1], s2, s2);
    sum_squares[1] = vmlal_s16(sum_squares[1], s3, s3);

    src += 4 * stride;
    h -= 4;
  } while (h != 0);

  return horizontal_long_add_u32x4(
      vreinterpretq_u32_s32(vaddq_s32(sum_squares[0], sum_squares[1])));
}

static INLINE uint64_t aom_sum_squares_2d_i16_nxn_neon(const int16_t *src,
                                                       int stride, int width,
                                                       int height) {
  uint64x2_t sum_squares = vdupq_n_u64(0);

  int h = height;
  do {
    int32x4_t ss_row[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
    int w = 0;
    do {
      const int16_t *s = src + w;
      int16x8_t s0 = vld1q_s16(s + 0 * stride);
      int16x8_t s1 = vld1q_s16(s + 1 * stride);
      int16x8_t s2 = vld1q_s16(s + 2 * stride);
      int16x8_t s3 = vld1q_s16(s + 3 * stride);

      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s0), vget_low_s16(s0));
      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s1), vget_low_s16(s1));
      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s2), vget_low_s16(s2));
      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s3), vget_low_s16(s3));
      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s0), vget_high_s16(s0));
      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s1), vget_high_s16(s1));
      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s2), vget_high_s16(s2));
      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s3), vget_high_s16(s3));
      w += 8;
    } while (w < width);

    sum_squares = vpadalq_u32(
        sum_squares, vreinterpretq_u32_s32(vaddq_s32(ss_row[0], ss_row[1])));

    src += 4 * stride;
    h -= 4;
  } while (h != 0);

  return horizontal_add_u64x2(sum_squares);
}

uint64_t aom_sum_squares_2d_i16_neon(const int16_t *src, int stride, int width,
                                     int height) {
  // 4 elements per row only requires half a SIMD register, so this must be a
  // special case, but also note that over 75% of all calls are with size == 4,
  // so it is also the common case.
  if (LIKELY(width == 4 && height == 4)) {
    return aom_sum_squares_2d_i16_4x4_neon(src, stride);
  } else if (LIKELY(width == 4 && (height & 3) == 0)) {
    return aom_sum_squares_2d_i16_4xn_neon(src, stride, height);
  } else if (LIKELY((width & 7) == 0 && (height & 3) == 0)) {
    // Generic case: width is a multiple of 8, height a multiple of 4.
    return aom_sum_squares_2d_i16_nxn_neon(src, stride, width, height);
  } else {
    return aom_sum_squares_2d_i16_c(src, stride, width, height);
  }
}

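// Illustrative usage (hypothetical values, not part of the library): sum of
// squared residuals over an 8x8 block of int16_t stored with a row stride of
// 32 elements.
//
//   uint64_t ss = aom_sum_squares_2d_i16_neon(diff, /*stride=*/32,
//                                             /*width=*/8, /*height=*/8);
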
static INLINE uint64_t aom_sum_sse_2d_i16_4x4_neon(const int16_t *src,
                                                   int stride, int *sum) {
  int16x4_t s0 = vld1_s16(src + 0 * stride);
  int16x4_t s1 = vld1_s16(src + 1 * stride);
  int16x4_t s2 = vld1_s16(src + 2 * stride);
  int16x4_t s3 = vld1_s16(src + 3 * stride);

  int32x4_t sse = vmull_s16(s0, s0);
  sse = vmlal_s16(sse, s1, s1);
  sse = vmlal_s16(sse, s2, s2);
  sse = vmlal_s16(sse, s3, s3);

  int32x4_t sum_01 = vaddl_s16(s0, s1);
  int32x4_t sum_23 = vaddl_s16(s2, s3);
  *sum += horizontal_add_s32x4(vaddq_s32(sum_01, sum_23));

  return horizontal_long_add_u32x4(vreinterpretq_u32_s32(sse));
}

static INLINE uint64_t aom_sum_sse_2d_i16_4xn_neon(const int16_t *src,
                                                   int stride, int height,
                                                   int *sum) {
  int32x4_t sse[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
  int32x2_t sum_acc[2] = { vdup_n_s32(0), vdup_n_s32(0) };

  int h = height;
  do {
    int16x4_t s0 = vld1_s16(src + 0 * stride);
    int16x4_t s1 = vld1_s16(src + 1 * stride);
    int16x4_t s2 = vld1_s16(src + 2 * stride);
    int16x4_t s3 = vld1_s16(src + 3 * stride);

    sse[0] = vmlal_s16(sse[0], s0, s0);
    sse[0] = vmlal_s16(sse[0], s1, s1);
    sse[1] = vmlal_s16(sse[1], s2, s2);
    sse[1] = vmlal_s16(sse[1], s3, s3);

    sum_acc[0] = vpadal_s16(sum_acc[0], s0);
    sum_acc[0] = vpadal_s16(sum_acc[0], s1);
    sum_acc[1] = vpadal_s16(sum_acc[1], s2);
    sum_acc[1] = vpadal_s16(sum_acc[1], s3);

    src += 4 * stride;
    h -= 4;
  } while (h != 0);

  *sum += horizontal_add_s32x4(vcombine_s32(sum_acc[0], sum_acc[1]));
  return horizontal_long_add_u32x4(
      vreinterpretq_u32_s32(vaddq_s32(sse[0], sse[1])));
}

static INLINE uint64_t aom_sum_sse_2d_i16_nxn_neon(const int16_t *src,
                                                   int stride, int width,
                                                   int height, int *sum) {
  uint64x2_t sse = vdupq_n_u64(0);
  int32x4_t sum_acc = vdupq_n_s32(0);

  int h = height;
  do {
    int32x4_t sse_row[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
    int w = 0;
    do {
      const int16_t *s = src + w;
      int16x8_t s0 = vld1q_s16(s + 0 * stride);
      int16x8_t s1 = vld1q_s16(s + 1 * stride);
      int16x8_t s2 = vld1q_s16(s + 2 * stride);
      int16x8_t s3 = vld1q_s16(s + 3 * stride);

      sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s0), vget_low_s16(s0));
      sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s1), vget_low_s16(s1));
      sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s2), vget_low_s16(s2));
      sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s3), vget_low_s16(s3));
      sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s0), vget_high_s16(s0));
      sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s1), vget_high_s16(s1));
      sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s2), vget_high_s16(s2));
      sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s3), vget_high_s16(s3));

      sum_acc = vpadalq_s16(sum_acc, s0);
      sum_acc = vpadalq_s16(sum_acc, s1);
      sum_acc = vpadalq_s16(sum_acc, s2);
      sum_acc = vpadalq_s16(sum_acc, s3);

      w += 8;
    } while (w < width);

    sse = vpadalq_u32(sse,
                      vreinterpretq_u32_s32(vaddq_s32(sse_row[0], sse_row[1])));

    src += 4 * stride;
    h -= 4;
  } while (h != 0);

  *sum += horizontal_add_s32x4(sum_acc);
  return horizontal_add_u64x2(sse);
}

uint64_t aom_sum_sse_2d_i16_neon(const int16_t *src, int stride, int width,
                                 int height, int *sum) {
  uint64_t sse;

  if (LIKELY(width == 4 && height == 4)) {
    sse = aom_sum_sse_2d_i16_4x4_neon(src, stride, sum);
  } else if (LIKELY(width == 4 && (height & 3) == 0)) {
    // width = 4, height is a multiple of 4.
    sse = aom_sum_sse_2d_i16_4xn_neon(src, stride, height, sum);
  } else if (LIKELY((width & 7) == 0 && (height & 3) == 0)) {
    // Generic case - width is multiple of 8, height is multiple of 4.
    sse = aom_sum_sse_2d_i16_nxn_neon(src, stride, width, height, sum);
  } else {
    sse = aom_sum_sse_2d_i16_c(src, stride, width, height, sum);
  }

  return sse;
}

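// Illustrative usage (hypothetical values, not part of the library):
//
//   int sum = 0;
//   uint64_t sse = aom_sum_sse_2d_i16_neon(diff, diff_stride, 8, 8, &sum);
//
// Note that the helpers above accumulate into *sum rather than assigning it,
// so the caller must initialize it first.
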
static INLINE uint64_t aom_sum_squares_i16_4xn_neon(const int16_t *src,
                                                    uint32_t n) {
  uint64x2_t sum_u64 = vdupq_n_u64(0);

  int i = n;
  do {
    uint32x4_t sum;
    int16x4_t s0 = vld1_s16(src);

    sum = vreinterpretq_u32_s32(vmull_s16(s0, s0));

    sum_u64 = vpadalq_u32(sum_u64, sum);

    src += 4;
    i -= 4;
  } while (i >= 4);

  if (i > 0) {
    return horizontal_add_u64x2(sum_u64) + aom_sum_squares_i16_c(src, i);
  }
  return horizontal_add_u64x2(sum_u64);
}

static INLINE uint64_t aom_sum_squares_i16_8xn_neon(const int16_t *src,
                                                    uint32_t n) {
  uint64x2_t sum_u64[2] = { vdupq_n_u64(0), vdupq_n_u64(0) };

  int i = n;
  do {
    uint32x4_t sum[2];
    int16x8_t s0 = vld1q_s16(src);

    sum[0] =
        vreinterpretq_u32_s32(vmull_s16(vget_low_s16(s0), vget_low_s16(s0)));
    sum[1] =
        vreinterpretq_u32_s32(vmull_s16(vget_high_s16(s0), vget_high_s16(s0)));

    sum_u64[0] = vpadalq_u32(sum_u64[0], sum[0]);
    sum_u64[1] = vpadalq_u32(sum_u64[1], sum[1]);

    src += 8;
    i -= 8;
  } while (i >= 8);

  if (i > 0) {
    return horizontal_add_u64x2(vaddq_u64(sum_u64[0], sum_u64[1])) +
           aom_sum_squares_i16_c(src, i);
  }
  return horizontal_add_u64x2(vaddq_u64(sum_u64[0], sum_u64[1]));
}

uint64_t aom_sum_squares_i16_neon(const int16_t *src, uint32_t n) {
  // This function seems to be called only for values of N >= 64. See
  // av1/encoder/compound_type.c.
  if (LIKELY(n >= 8)) {
    return aom_sum_squares_i16_8xn_neon(src, n);
  }
  if (n >= 4) {
    return aom_sum_squares_i16_4xn_neon(src, n);
  }
  return aom_sum_squares_i16_c(src, n);
}

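// Illustrative usage (hypothetical values, not part of the library): sum of
// squares over a flat array of n int16_t values.
//
//   uint64_t ss = aom_sum_squares_i16_neon(residual, /*n=*/64);
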
static INLINE uint64_t aom_var_2d_u8_4xh_neon(uint8_t *src, int src_stride,
                                              int width, int height) {
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x2_t sum_u32 = vdup_n_u32(0);
  uint32x4_t sse_u32 = vdupq_n_u32(0);

  // 255 * 256 = 65280, so we can accumulate up to 256 8-bit elements in a
  // 16-bit element before we need to accumulate to 32-bit elements. Since we
  // are accumulating in uint16x4_t vectors, this means we can accumulate up
  // to 4 rows of 256 elements. Therefore the limit can be computed as:
  // h_limit = (4 * 256) / width.
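  // For example (illustrative numbers): with width == 4 each pass may cover
  // up to (4 * 256) / 4 = 256 rows, so every uint16_t lane accumulates at
  // most 256 values of at most 255, i.e. 65280, which still fits in 65535.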
  int h_limit = (4 * 256) / width;
  int h_tmp = height > h_limit ? h_limit : height;

  int h = 0;
  do {
    uint16x4_t sum_u16 = vdup_n_u16(0);
    do {
      uint8_t *src_ptr = src;
      int w = width;
      do {
        uint8x8_t s0 = load_unaligned_u8(src_ptr, src_stride);

        sum_u16 = vpadal_u8(sum_u16, s0);

        uint16x8_t sse_u16 = vmull_u8(s0, s0);

        sse_u32 = vpadalq_u16(sse_u32, sse_u16);

        src_ptr += 8;
        w -= 8;
      } while (w >= 8);

      // Process remaining columns in the row using C.
      while (w > 0) {
        int idx = width - w;
        const uint8_t v = src[idx];
        sum += v;
        sse += v * v;
        w--;
      }

      src += 2 * src_stride;
      h += 2;
    } while (h < h_tmp && h < height);

    sum_u32 = vpadal_u16(sum_u32, sum_u16);
    h_tmp += h_limit;
  } while (h < height);

  sum += horizontal_long_add_u32x2(sum_u32);
  sse += horizontal_long_add_u32x4(sse_u32);

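  // sse - sum^2 / N is, up to the truncation of the integer division, N times
  // the population variance of the N = width * height samples.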
  return sse - sum * sum / (width * height);
}

static INLINE uint64_t aom_var_2d_u8_8xh_neon(uint8_t *src, int src_stride,
                                              int width, int height) {
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x2_t sum_u32 = vdup_n_u32(0);
  uint32x4_t sse_u32 = vdupq_n_u32(0);

  // 255 * 256 = 65280, so we can accumulate up to 256 8-bit elements in a
  // 16-bit element before we need to accumulate to 32-bit elements. Since we
  // are accumulating in uint16x4_t vectors, this means we can accumulate up
  // to 4 rows of 256 elements. Therefore the limit can be computed as:
  // h_limit = (4 * 256) / width.
  int h_limit = (4 * 256) / width;
  int h_tmp = height > h_limit ? h_limit : height;

  int h = 0;
  do {
    uint16x4_t sum_u16 = vdup_n_u16(0);
    do {
      uint8_t *src_ptr = src;
      int w = width;
      do {
        uint8x8_t s0 = vld1_u8(src_ptr);

        sum_u16 = vpadal_u8(sum_u16, s0);

        uint16x8_t sse_u16 = vmull_u8(s0, s0);

        sse_u32 = vpadalq_u16(sse_u32, sse_u16);

        src_ptr += 8;
        w -= 8;
      } while (w >= 8);

      // Process remaining columns in the row using C.
      while (w > 0) {
        int idx = width - w;
        const uint8_t v = src[idx];
        sum += v;
        sse += v * v;
        w--;
      }

      src += src_stride;
      ++h;
    } while (h < h_tmp && h < height);

    sum_u32 = vpadal_u16(sum_u32, sum_u16);
    h_tmp += h_limit;
  } while (h < height);

  sum += horizontal_long_add_u32x2(sum_u32);
  sse += horizontal_long_add_u32x4(sse_u32);

  return sse - sum * sum / (width * height);
}

static INLINE uint64_t aom_var_2d_u8_16xh_neon(uint8_t *src, int src_stride,
                                               int width, int height) {
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x4_t sum_u32 = vdupq_n_u32(0);
  uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  // 255 * 256 = 65280, so we can accumulate up to 256 8-bit elements in a
  // 16-bit element before we need to accumulate to 32-bit elements. Since we
  // are accumulating in uint16x8_t vectors, this means we can accumulate up
  // to 8 rows of 256 elements. Therefore the limit can be computed as:
  // h_limit = (8 * 256) / width.
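  // For example (illustrative numbers): with width == 16 each pass may cover
  // up to (8 * 256) / 16 = 128 rows.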
  int h_limit = (8 * 256) / width;
  int h_tmp = height > h_limit ? h_limit : height;

  int h = 0;
  do {
    uint16x8_t sum_u16 = vdupq_n_u16(0);
    do {
      int w = width;
      uint8_t *src_ptr = src;
      do {
        uint8x16_t s0 = vld1q_u8(src_ptr);

        sum_u16 = vpadalq_u8(sum_u16, s0);

        uint16x8_t sse_u16_lo = vmull_u8(vget_low_u8(s0), vget_low_u8(s0));
        uint16x8_t sse_u16_hi = vmull_u8(vget_high_u8(s0), vget_high_u8(s0));

        sse_u32[0] = vpadalq_u16(sse_u32[0], sse_u16_lo);
        sse_u32[1] = vpadalq_u16(sse_u32[1], sse_u16_hi);

        src_ptr += 16;
        w -= 16;
      } while (w >= 16);

      // Process remaining columns in the row using C.
      while (w > 0) {
        int idx = width - w;
        const uint8_t v = src[idx];
        sum += v;
        sse += v * v;
        w--;
      }

      src += src_stride;
      ++h;
    } while (h < h_tmp && h < height);

    sum_u32 = vpadalq_u16(sum_u32, sum_u16);
    h_tmp += h_limit;
  } while (h < height);

  sum += horizontal_long_add_u32x4(sum_u32);
  sse += horizontal_long_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1]));

  return sse - sum * sum / (width * height);
}

uint64_t aom_var_2d_u8_neon(uint8_t *src, int src_stride, int width,
                            int height) {
  if (width >= 16) {
    return aom_var_2d_u8_16xh_neon(src, src_stride, width, height);
  }
  if (width >= 8) {
    return aom_var_2d_u8_8xh_neon(src, src_stride, width, height);
  }
  if (width >= 4 && height % 2 == 0) {
    return aom_var_2d_u8_4xh_neon(src, src_stride, width, height);
  }
  return aom_var_2d_u8_c(src, src_stride, width, height);
}

static INLINE uint64_t aom_var_2d_u16_4xh_neon(uint8_t *src, int src_stride,
                                               int width, int height) {
  uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x2_t sum_u32 = vdup_n_u32(0);
  uint64x2_t sse_u64 = vdupq_n_u64(0);

  int h = height;
  do {
    int w = width;
    uint16_t *src_ptr = src_u16;
    do {
      uint16x4_t s0 = vld1_u16(src_ptr);

      sum_u32 = vpadal_u16(sum_u32, s0);

      uint32x4_t sse_u32 = vmull_u16(s0, s0);

      sse_u64 = vpadalq_u32(sse_u64, sse_u32);

      src_ptr += 4;
      w -= 4;
    } while (w >= 4);

    // Process remaining columns in the row using C.
    while (w > 0) {
      int idx = width - w;
      const uint16_t v = src_u16[idx];
      sum += v;
      sse += v * v;
      w--;
    }

    src_u16 += src_stride;
  } while (--h != 0);

  sum += horizontal_long_add_u32x2(sum_u32);
  sse += horizontal_add_u64x2(sse_u64);

  return sse - sum * sum / (width * height);
}

static INLINE uint64_t aom_var_2d_u16_8xh_neon(uint8_t *src, int src_stride,
                                               int width, int height) {
  uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x4_t sum_u32 = vdupq_n_u32(0);
  uint64x2_t sse_u64[2] = { vdupq_n_u64(0), vdupq_n_u64(0) };

  int h = height;
  do {
    int w = width;
    uint16_t *src_ptr = src_u16;
    do {
      uint16x8_t s0 = vld1q_u16(src_ptr);

      sum_u32 = vpadalq_u16(sum_u32, s0);

      uint32x4_t sse_u32_lo = vmull_u16(vget_low_u16(s0), vget_low_u16(s0));
      uint32x4_t sse_u32_hi = vmull_u16(vget_high_u16(s0), vget_high_u16(s0));

      sse_u64[0] = vpadalq_u32(sse_u64[0], sse_u32_lo);
      sse_u64[1] = vpadalq_u32(sse_u64[1], sse_u32_hi);

      src_ptr += 8;
      w -= 8;
    } while (w >= 8);

    // Process remaining columns in the row using C.
    while (w > 0) {
      int idx = width - w;
      const uint16_t v = src_u16[idx];
      sum += v;
      sse += v * v;
      w--;
    }

    src_u16 += src_stride;
  } while (--h != 0);

  sum += horizontal_long_add_u32x4(sum_u32);
  sse += horizontal_add_u64x2(vaddq_u64(sse_u64[0], sse_u64[1]));

  return sse - sum * sum / (width * height);
}

uint64_t aom_var_2d_u16_neon(uint8_t *src, int src_stride, int width,
                             int height) {
  if (width >= 8) {
    return aom_var_2d_u16_8xh_neon(src, src_stride, width, height);
  }
  if (width >= 4) {
    return aom_var_2d_u16_4xh_neon(src, src_stride, width, height);
  }
  return aom_var_2d_u16_c(src, src_stride, width, height);
}
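
// Note on the uint16_t paths above: libaom passes high bit-depth buffers
// around as uint8_t pointers, so aom_var_2d_u16_neon takes uint8_t *src and
// recovers the real uint16_t pointer via CONVERT_TO_SHORTPTR; src_stride is
// measured in uint16_t elements. Illustrative usage (hypothetical values, not
// part of the library):
//
//   uint16_t buf[64 * 64];                   // e.g. 10-bit samples
//   uint8_t *src = CONVERT_TO_BYTEPTR(buf);  // libaom pointer convention
//   uint64_t var_n = aom_var_2d_u16_neon(src, /*src_stride=*/64,
//                                        /*width=*/64, /*height=*/64);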