• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2020, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <arm_neon.h>
13 
14 #include "config/aom_config.h"
15 #include "config/av1_rtcd.h"
16 
17 #include "aom_dsp/arm/sum_neon.h"
18 #include "av1/common/restoration.h"
19 #include "av1/encoder/arm/neon/pickrst_neon.h"
20 #include "av1/encoder/pickrst.h"
21 
/*
 * Computes the sum of squared error (SSE) between the source block and the
 * self-guided-restoration projection of the degraded block, for low bit-depth
 * (8-bit) input.
 *
 * The reconstruction for each pixel is:
 *   v = xq[0] * flt0 + xq[1] * flt1 - (xq[0] + xq[1]) * (dat << SGRPROJ_RST_BITS)
 *   e = round(v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + dat - src
 * and the function returns sum(e * e) over the width x height region.
 *
 * Three paths are taken depending on which self-guided filter outputs are
 * enabled (params->r[0] / params->r[1]): both, exactly one, or neither (in
 * which case the error is simply (dat - src)^2).
 *
 * NOTE(review): the inner vector loops are do/while, so they assume
 * width >= 8 (>= 16 in the no-filter path) — confirm callers guarantee this.
 */
int64_t av1_lowbd_pixel_proj_error_neon(
    const uint8_t *src, int width, int height, int src_stride,
    const uint8_t *dat, int dat_stride, int32_t *flt0, int flt0_stride,
    int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) {
  int64_t sse = 0;  // Scalar accumulator for remainder (tail) columns.
  int64x2_t sse_s64 = vdupq_n_s64(0);

  if (params->r[0] > 0 && params->r[1] > 0) {
    // Both filters enabled.
    int32x2_t xq_v = vld1_s32(xq);
    // (xq[0] + xq[1]) << SGRPROJ_RST_BITS, broadcast to both lanes.
    int32x2_t xq_sum_v = vshl_n_s32(vpadd_s32(xq_v, xq_v), SGRPROJ_RST_BITS);

    do {
      int j = 0;
      int32x4_t sse_s32 = vdupq_n_s32(0);

      // Process 8 pixels per iteration.
      do {
        const uint8x8_t d = vld1_u8(&dat[j]);
        const uint8x8_t s = vld1_u8(&src[j]);
        int32x4_t flt0_0 = vld1q_s32(&flt0[j]);
        int32x4_t flt0_1 = vld1q_s32(&flt0[j + 4]);
        int32x4_t flt1_0 = vld1q_s32(&flt1[j]);
        int32x4_t flt1_1 = vld1q_s32(&flt1[j + 4]);

        // Rounding offset folded in before the final right shift below.
        int32x4_t offset =
            vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1));
        int32x4_t v0 = vmlaq_lane_s32(offset, flt0_0, xq_v, 0);
        int32x4_t v1 = vmlaq_lane_s32(offset, flt0_1, xq_v, 0);

        v0 = vmlaq_lane_s32(v0, flt1_0, xq_v, 1);
        v1 = vmlaq_lane_s32(v1, flt1_1, xq_v, 1);

        // Subtract dat * ((xq[0] + xq[1]) << SGRPROJ_RST_BITS).
        int16x8_t d_s16 = vreinterpretq_s16_u16(vmovl_u8(d));
        v0 = vmlsl_lane_s16(v0, vget_low_s16(d_s16),
                            vreinterpret_s16_s32(xq_sum_v), 0);
        v1 = vmlsl_lane_s16(v1, vget_high_s16(d_s16),
                            vreinterpret_s16_s32(xq_sum_v), 0);

        // Narrowing shift; rounding already applied via `offset`.
        int16x4_t vr0 = vshrn_n_s32(v0, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);
        int16x4_t vr1 = vshrn_n_s32(v1, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);

        int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(d, s));
        int16x8_t e = vaddq_s16(vcombine_s16(vr0, vr1), diff);
        int16x4_t e_lo = vget_low_s16(e);
        int16x4_t e_hi = vget_high_s16(e);

        sse_s32 = vmlal_s16(sse_s32, e_lo, e_lo);
        sse_s32 = vmlal_s16(sse_s32, e_hi, e_hi);

        j += 8;
      } while (j <= width - 8);

      // Scalar tail for the remaining (width % 8) columns; the pre-added
      // offset makes the plain shift equivalent to a rounding shift.
      for (int k = j; k < width; ++k) {
        int32_t u = (dat[k] << SGRPROJ_RST_BITS);
        int32_t v = (1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)) +
                    xq[0] * flt0[k] + xq[1] * flt1[k] - u * (xq[0] + xq[1]);
        int32_t e =
            (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + dat[k] - src[k];
        sse += e * e;
      }

      // Widen per-row 32-bit partial sums into the 64-bit accumulator.
      sse_s64 = vpadalq_s32(sse_s64, sse_s32);

      dat += dat_stride;
      src += src_stride;
      flt0 += flt0_stride;
      flt1 += flt1_stride;
    } while (--height != 0);
  } else if (params->r[0] > 0 || params->r[1] > 0) {
    // Exactly one filter enabled: select the active gain and buffer.
    int xq_active = (params->r[0] > 0) ? xq[0] : xq[1];
    int32_t *flt = (params->r[0] > 0) ? flt0 : flt1;
    int flt_stride = (params->r[0] > 0) ? flt0_stride : flt1_stride;
    int32x2_t xq_v = vdup_n_s32(xq_active);

    do {
      int32x4_t sse_s32 = vdupq_n_s32(0);
      int j = 0;

      do {
        const uint8x8_t d = vld1_u8(&dat[j]);
        const uint8x8_t s = vld1_u8(&src[j]);
        int32x4_t flt_0 = vld1q_s32(&flt[j]);
        int32x4_t flt_1 = vld1q_s32(&flt[j + 4]);
        // dat << SGRPROJ_RST_BITS, widened to 16 bits.
        int16x8_t d_s16 =
            vreinterpretq_s16_u16(vshll_n_u8(d, SGRPROJ_RST_BITS));

        // flt - (dat << SGRPROJ_RST_BITS)
        int32x4_t sub_0 = vsubw_s16(flt_0, vget_low_s16(d_s16));
        int32x4_t sub_1 = vsubw_s16(flt_1, vget_high_s16(d_s16));

        // Rounding offset for the narrowing shift below.
        int32x4_t offset =
            vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1));
        int32x4_t v0 = vmlaq_lane_s32(offset, sub_0, xq_v, 0);
        int32x4_t v1 = vmlaq_lane_s32(offset, sub_1, xq_v, 0);

        int16x4_t vr0 = vshrn_n_s32(v0, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);
        int16x4_t vr1 = vshrn_n_s32(v1, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);

        int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(d, s));
        int16x8_t e = vaddq_s16(vcombine_s16(vr0, vr1), diff);
        int16x4_t e_lo = vget_low_s16(e);
        int16x4_t e_hi = vget_high_s16(e);

        sse_s32 = vmlal_s16(sse_s32, e_lo, e_lo);
        sse_s32 = vmlal_s16(sse_s32, e_hi, e_hi);

        j += 8;
      } while (j <= width - 8);

      // Scalar tail; ROUND_POWER_OF_TWO matches the vector path's
      // offset-then-shift rounding.
      for (int k = j; k < width; ++k) {
        int32_t u = dat[k] << SGRPROJ_RST_BITS;
        int32_t v = xq_active * (flt[k] - u);
        int32_t e = ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) +
                    dat[k] - src[k];
        sse += e * e;
      }

      sse_s64 = vpadalq_s32(sse_s64, sse_s32);

      dat += dat_stride;
      src += src_stride;
      flt += flt_stride;
    } while (--height != 0);
  } else {
    // Neither filter enabled: the error reduces to (dat - src)^2, so we can
    // process 16 pixels at a time with unsigned absolute differences.
    uint32x4_t sse_s32 = vdupq_n_u32(0);

    do {
      int j = 0;

      do {
        const uint8x16_t d = vld1q_u8(&dat[j]);
        const uint8x16_t s = vld1q_u8(&src[j]);

        uint8x16_t diff = vabdq_u8(d, s);
        uint8x8_t diff_lo = vget_low_u8(diff);
        uint8x8_t diff_hi = vget_high_u8(diff);

        sse_s32 = vpadalq_u16(sse_s32, vmull_u8(diff_lo, diff_lo));
        sse_s32 = vpadalq_u16(sse_s32, vmull_u8(diff_hi, diff_hi));

        j += 16;
      } while (j <= width - 16);

      for (int k = j; k < width; ++k) {
        int32_t e = dat[k] - src[k];
        sse += e * e;
      }

      dat += dat_stride;
      src += src_stride;
    } while (--height != 0);

    sse_s64 = vreinterpretq_s64_u64(vpaddlq_u32(sse_s32));
  }

  // Combine the two 64-bit lanes with the scalar tail accumulator.
  sse += horizontal_add_s64x2(sse_s64);
  return sse;
}
178 
179 // We can accumulate up to 65536 8-bit multiplication results in 32-bit. We are
180 // processing 2 pixels at a time, so the accumulator max can be as high as 32768
181 // for the compute stats.
182 #define STAT_ACCUMULATOR_MAX 32768
183 
// Byte-wise table lookup across the 32-byte concatenation { a, b }, selecting
// 8 output bytes according to idx (out-of-range indices produce 0).
static INLINE uint8x8_t tbl2(uint8x16_t a, uint8x16_t b, uint8x8_t idx) {
#if AOM_ARCH_AARCH64
  uint8x16x2_t lut = { { a, b } };
  return vqtbl2_u8(lut, idx);
#else
  // AArch32 has no 128-bit table instruction: split the table into four
  // 64-bit halves and use the 4-register lookup form instead.
  uint8x8x4_t lut = { { vget_low_u8(a), vget_high_u8(a), vget_low_u8(b),
                        vget_high_u8(b) } };
  return vtbl4_u8(lut, idx);
#endif
}
194 
// Byte-wise table lookup across the 32-byte concatenation { a, b }, selecting
// 16 output bytes according to idx (out-of-range indices produce 0).
static INLINE uint8x16_t tbl2q(uint8x16_t a, uint8x16_t b, uint8x16_t idx) {
#if AOM_ARCH_AARCH64
  uint8x16x2_t lut = { { a, b } };
  return vqtbl2q_u8(lut, idx);
#else
  // AArch32 has no 128-bit table instruction: split the table into four
  // 64-bit halves and perform two 8-byte lookups, one per half of idx.
  uint8x8x4_t lut = { { vget_low_u8(a), vget_high_u8(a), vget_low_u8(b),
                        vget_high_u8(b) } };
  uint8x8_t lo = vtbl4_u8(lut, vget_low_u8(idx));
  uint8x8_t hi = vtbl4_u8(lut, vget_high_u8(idx));
  return vcombine_u8(lo, hi);
#endif
}
206 
207 // The M matrix is accumulated in STAT_ACCUMULATOR_MAX steps to speed-up the
208 // computation. This function computes the final M from the accumulated
209 // (src_s64) and the residual parts (src_s32). It also transposes the result as
210 // the output needs to be column-major.
// Finalize the cross-correlation matrix M. The M matrix is accumulated in
// STAT_ACCUMULATOR_MAX steps to speed up the computation: this combines the
// 64-bit accumulated part (src_s64) with the 32-bit residual part (src_s32),
// applies the (downsampling) scale, and transposes the row-major accumulator
// into the column-major layout the caller expects.
static inline void acc_transpose_M(int64_t *dst, const int64_t *src_s64,
                                   const int32_t *src_s32, const int wiener_win,
                                   int scale) {
  for (int col = 0; col < wiener_win; ++col) {
    for (int row = 0; row < wiener_win; ++row) {
      // Read the transposed position from the accumulators.
      const int src_idx = row * wiener_win + col;
      const int64_t sum = src_s64[src_idx] + src_s32[src_idx];
      dst[col * wiener_win + row] += sum * scale;
    }
  }
}
221 
222 // The resulting H is a column-major matrix accumulated from the transposed
223 // (column-major) samples of the filter kernel (5x5 or 7x7) viewed as a single
224 // vector. For the 7x7 filter case: H(49x49) = [49 x 1] x [1 x 49]. This
225 // function transforms back to the originally expected format (double
226 // transpose). The H matrix is accumulated in STAT_ACCUMULATOR_MAX steps to
227 // speed-up the computation. This function computes the final H from the
228 // accumulated (src_s64) and the residual parts (src_s32). The computed H is
229 // only an upper triangle matrix, this function also fills the lower triangle of
230 // the resulting matrix.
// Finalize the auto-covariance matrix H. H was accumulated column-major from
// the transposed (column-major) samples of the filter kernel (5x5 or 7x7)
// viewed as a single vector; for the 7x7 case H(49x49) = [49 x 1] x [1 x 49].
// This routine transforms back to the originally expected layout (a double
// transpose). H is accumulated in STAT_ACCUMULATOR_MAX steps to speed up the
// computation, so the final value combines the 64-bit part (src_s64) with the
// residual 32-bit part (src_s32). Only the upper triangle was computed; the
// min/max index remapping below also mirrors it into the lower triangle.
//
// For a simplified theoretical 3x3 case where `wiener_win` is 3 and
// `wiener_win2` is 9, the M matrix is 3x3:
// 0, 3, 6
// 1, 4, 7
// 2, 5, 8
//
// This is viewed as a vector to compute H (9x9) by vector outer product:
// 0, 3, 6, 1, 4, 7, 2, 5, 8
//
// Double transpose and upper triangle remapping for 3x3 -> 9x9 case:
// 0,    3,    6,    1,    4,    7,    2,    5,    8,
// 3,   30,   33,   12,   31,   34,   21,   32,   35,
// 6,   33,   60,   15,   42,   61,   24,   51,   62,
// 1,   12,   15,   10,   13,   16,   11,   14,   17,
// 4,   31,   42,   13,   40,   43,   22,   41,   44,
// 7,   34,   61,   16,   43,   70,   25,   52,   71,
// 2,   21,   24,   11,   22,   25,   20,   23,   26,
// 5,   32,   51,   14,   41,   52,   23,   50,   53,
// 8,   35,   62,   17,   44,   71,   26,   53,   80,
static void update_H(int64_t *dst, const int64_t *src_s64,
                     const int32_t *src_s32, const int wiener_win, int stride,
                     int scale) {
  const int wiener_win2 = wiener_win * wiener_win;

  // Walk the indices according to the remapping above, along the columns:
  // 0, wiener_win, 2 * wiener_win, ..., 1, 1 + 2 * wiener_win, ...,
  // wiener_win - 1, wiener_win - 1 + wiener_win, ...
  // For the 3x3 case `j` visits: 0, 3, 6, 1, 4, 7, 2, 5, 8.
  for (int i = 0; i < wiener_win; ++i) {
    for (int j = i; j < wiener_win2; j += wiener_win) {
      // The two inner loops mirror the outer pair, running along rows instead
      // of columns; `l` visits the same sequence as `j`.
      for (int k = 0; k < wiener_win; ++k) {
        for (int l = k; l < wiener_win2; l += wiener_win) {
          // The nominal double-transpose index would be stride * j + l, but
          // only the upper triangle was accumulated, so map every (j, l)
          // pair onto its upper-triangle representative.
          const int row = (j < l) ? j : l;  // == AOMMIN(j, l)
          const int col = (j < l) ? l : j;  // == AOMMAX(j, l)
          const int tr_idx = stride * row + col;

          // Combine the 64-bit and residual 32-bit accumulators with scaling.
          *dst++ += (src_s64[tr_idx] + src_s32[tr_idx]) * (int64_t)scale;
        }
      }
    }
  }
}
281 
282 // Load 7x7 matrix into 3 and a half 128-bit vectors from consecutive rows, the
283 // last load address is offset to prevent out-of-bounds access.
// Load a 7x8 block (7 rows of 8 bytes) into 3.5 packed 128-bit vectors, two
// rows per vector. The final row is loaded one byte early (src - 1) so the
// 8-byte load never reads past the end of the buffer; the caller's shuffle
// indices compensate for the 1-byte offset. The upper half of dst[3] is
// zeroed.
static INLINE void load_and_pack_u8_8x7(uint8x16_t dst[4], const uint8_t *src,
                                        ptrdiff_t stride) {
  for (int i = 0; i < 3; ++i) {
    dst[i] = vcombine_u8(vld1_u8(src), vld1_u8(src + stride));
    src += 2 * stride;
  }
  dst[3] = vcombine_u8(vld1_u8(src - 1), vdup_n_u8(0));
}
294 
// Accumulates the Wiener filter statistics for the 7x7 window case over a
// width x height region: M (7x7 cross-correlation between `src` and the
// mean-removed degraded samples `dgd - avg`) and H (49x49 auto-covariance of
// the mean-removed degraded samples). Two output pixels are processed per
// iteration where possible. Partial sums are kept in 32-bit accumulators and
// spilled into 64-bit ones every STAT_ACCUMULATOR_MAX iterations to avoid
// overflow. The final results are scaled by `downsample_factor` and written
// (accumulated) into M and H in the column-major layout expected by callers.
static INLINE void compute_stats_win7_neon(const uint8_t *dgd,
                                           const uint8_t *src, int width,
                                           int height, int dgd_stride,
                                           int src_stride, int avg, int64_t *M,
                                           int64_t *H, int downsample_factor) {
  // Matrix names are capitalized to help readability.
  DECLARE_ALIGNED(64, int16_t, DGD_AVG0[WIENER_WIN2_ALIGN3]);
  DECLARE_ALIGNED(64, int16_t, DGD_AVG1[WIENER_WIN2_ALIGN3]);
  DECLARE_ALIGNED(64, int32_t, M_s32[WIENER_WIN2_ALIGN3]);
  DECLARE_ALIGNED(64, int64_t, M_s64[WIENER_WIN2_ALIGN3]);
  DECLARE_ALIGNED(64, int32_t, H_s32[WIENER_WIN2 * WIENER_WIN2_ALIGN2]);
  DECLARE_ALIGNED(64, int64_t, H_s64[WIENER_WIN2 * WIENER_WIN2_ALIGN2]);

  memset(M_s32, 0, sizeof(M_s32));
  memset(M_s64, 0, sizeof(M_s64));
  memset(H_s32, 0, sizeof(H_s32));
  memset(H_s64, 0, sizeof(H_s64));

  // Look-up tables to create 8x6 matrix with consecutive elements from two 7x7
  // matrices.
  // clang-format off
  DECLARE_ALIGNED(16, static const uint8_t, shuffle_stats7[96]) = {
    0,  1,  2,  3,  4,  5,  6,  8,  9, 10, 11, 12, 13, 14, 16, 17,
    2,  3,  4,  5,  6,  8,  9, 10, 11, 12, 13, 14, 16, 17, 18, 19,
    4,  5,  6,  8,  9, 10, 11, 12, 13, 14, 17, 18, 19, 20, 21, 22,
    1,  2,  3,  4,  5,  6,  7,  9, 10, 11, 12, 13, 14, 15, 17, 18,
    3,  4,  5,  6,  7,  9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20,
    5,  6,  7,  9, 10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22, 23,
  };
  // clang-format on

  const uint8x16_t lut0 = vld1q_u8(shuffle_stats7 + 0);
  const uint8x16_t lut1 = vld1q_u8(shuffle_stats7 + 16);
  const uint8x16_t lut2 = vld1q_u8(shuffle_stats7 + 32);
  const uint8x16_t lut3 = vld1q_u8(shuffle_stats7 + 48);
  const uint8x16_t lut4 = vld1q_u8(shuffle_stats7 + 64);
  const uint8x16_t lut5 = vld1q_u8(shuffle_stats7 + 80);

  int acc_cnt = STAT_ACCUMULATOR_MAX;
  // Per-row pointer advance: skip `downsample_factor` rows, minus the width
  // already consumed by the column loop.
  const int src_next = downsample_factor * src_stride - width;
  const int dgd_next = downsample_factor * dgd_stride - width;
  const uint8x8_t avg_u8 = vdup_n_u8(avg);

  do {
    int j = width;
    while (j >= 2) {
      // Load two adjacent, overlapping 7x7 matrices: a 8x7 matrix with the
      // middle 6x7 elements being shared.
      uint8x16_t dgd_rows[4];
      load_and_pack_u8_8x7(dgd_rows, dgd, dgd_stride);

      const uint8_t *dgd_ptr = dgd + dgd_stride * 6;
      dgd += 2;

      // Re-arrange (and widen) the combined 8x7 matrix to have the 2 whole 7x7
      // matrices (1 for each of the 2 pixels) separated into distinct
      // int16x8_t[6] arrays. These arrays contain 48 elements of the 49 (7x7).
      // Compute `dgd - avg` for both buffers. Each DGD_AVG buffer contains 49
      // consecutive elements.
      int16x8_t dgd_avg0[6];
      int16x8_t dgd_avg1[6];
      uint8x16_t dgd_shuf0 = tbl2q(dgd_rows[0], dgd_rows[1], lut0);
      uint8x16_t dgd_shuf3 = tbl2q(dgd_rows[0], dgd_rows[1], lut3);

      dgd_avg0[0] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf0), avg_u8));
      dgd_avg0[1] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf0), avg_u8));
      dgd_avg1[0] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf3), avg_u8));
      dgd_avg1[1] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf3), avg_u8));

      vst1q_s16(DGD_AVG0, dgd_avg0[0]);
      vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
      vst1q_s16(DGD_AVG1, dgd_avg1[0]);
      vst1q_s16(DGD_AVG1 + 8, dgd_avg1[1]);

      uint8x16_t dgd_shuf1 = tbl2q(dgd_rows[1], dgd_rows[2], lut1);
      uint8x16_t dgd_shuf4 = tbl2q(dgd_rows[1], dgd_rows[2], lut4);

      dgd_avg0[2] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf1), avg_u8));
      dgd_avg0[3] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf1), avg_u8));
      dgd_avg1[2] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf4), avg_u8));
      dgd_avg1[3] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf4), avg_u8));

      vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
      vst1q_s16(DGD_AVG0 + 24, dgd_avg0[3]);
      vst1q_s16(DGD_AVG1 + 16, dgd_avg1[2]);
      vst1q_s16(DGD_AVG1 + 24, dgd_avg1[3]);

      uint8x16_t dgd_shuf2 = tbl2q(dgd_rows[2], dgd_rows[3], lut2);
      uint8x16_t dgd_shuf5 = tbl2q(dgd_rows[2], dgd_rows[3], lut5);

      dgd_avg0[4] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf2), avg_u8));
      dgd_avg0[5] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf2), avg_u8));
      dgd_avg1[4] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf5), avg_u8));
      dgd_avg1[5] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf5), avg_u8));

      vst1q_s16(DGD_AVG0 + 32, dgd_avg0[4]);
      vst1q_s16(DGD_AVG0 + 40, dgd_avg0[5]);
      vst1q_s16(DGD_AVG1 + 32, dgd_avg1[4]);
      vst1q_s16(DGD_AVG1 + 40, dgd_avg1[5]);

      // The remaining last (49th) elements of `dgd - avg`.
      DGD_AVG0[48] = dgd_ptr[6] - avg;
      DGD_AVG1[48] = dgd_ptr[7] - avg;

      // Accumulate into row-major variant of matrix M (cross-correlation) for 2
      // output pixels at a time. M is of size 7 * 7. It needs to be filled such
      // that multiplying one element from src with each element of a row of the
      // wiener window will fill one column of M. However this is not very
      // convenient in terms of memory access, as it means we do contiguous
      // loads of dgd but strided stores to M. As a result, we use an
      // intermediate matrix M_s32 which is instead filled such that one row of
      // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
      // then transposed to return M.
      int src_avg0 = *src++ - avg;
      int src_avg1 = *src++ - avg;
      int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
      int16x4_t src_avg1_s16 = vdup_n_s16(src_avg1);
      update_M_2pixels(M_s32 + 0, src_avg0_s16, src_avg1_s16, dgd_avg0[0],
                       dgd_avg1[0]);
      update_M_2pixels(M_s32 + 8, src_avg0_s16, src_avg1_s16, dgd_avg0[1],
                       dgd_avg1[1]);
      update_M_2pixels(M_s32 + 16, src_avg0_s16, src_avg1_s16, dgd_avg0[2],
                       dgd_avg1[2]);
      update_M_2pixels(M_s32 + 24, src_avg0_s16, src_avg1_s16, dgd_avg0[3],
                       dgd_avg1[3]);
      update_M_2pixels(M_s32 + 32, src_avg0_s16, src_avg1_s16, dgd_avg0[4],
                       dgd_avg1[4]);
      update_M_2pixels(M_s32 + 40, src_avg0_s16, src_avg1_s16, dgd_avg0[5],
                       dgd_avg1[5]);

      // Last (49th) element of M_s32 can be computed as scalar more efficiently
      // for 2 output pixels.
      M_s32[48] += DGD_AVG0[48] * src_avg0 + DGD_AVG1[48] * src_avg1;

      // Start accumulating into row-major version of matrix H
      // (auto-covariance), it expects the DGD_AVG[01] matrices to also be
      // row-major. H is of size 49 * 49. It is filled by multiplying every pair
      // of elements of the wiener window together (vector outer product). Since
      // it is a symmetric matrix, we only compute the upper-right triangle, and
      // then copy it down to the lower-left later. The upper triangle is
      // covered by 4x4 tiles. The original algorithm assumes the M matrix is
      // column-major and the resulting H matrix is also expected to be
      // column-major. It is not efficient to work with column-major matrices,
      // so we accumulate into a row-major matrix H_s32. At the end of the
      // algorithm a double transpose transformation will convert H_s32 back to
      // the expected output layout.
      update_H_7x7_2pixels(H_s32, DGD_AVG0, DGD_AVG1);

      // The last element of the triangle of H_s32 matrix can be computed as a
      // scalar more efficiently.
      H_s32[48 * WIENER_WIN2_ALIGN2 + 48] +=
          DGD_AVG0[48] * DGD_AVG0[48] + DGD_AVG1[48] * DGD_AVG1[48];

      // Accumulate into 64-bit after STAT_ACCUMULATOR_MAX iterations to prevent
      // overflow.
      if (--acc_cnt == 0) {
        acc_cnt = STAT_ACCUMULATOR_MAX;

        accumulate_and_clear(M_s64, M_s32, WIENER_WIN2_ALIGN2);

        // The widening accumulation is only needed for the upper triangle part
        // of the matrix.
        int64_t *lh = H_s64;
        int32_t *lh32 = H_s32;
        for (int k = 0; k < WIENER_WIN2; ++k) {
          // The widening accumulation is only run for the relevant parts
          // (upper-right triangle) in a row 4-element aligned.
          int k4 = k / 4 * 4;
          accumulate_and_clear(lh + k4, lh32 + k4, 48 - k4);

          // Last element of the row is computed separately.
          lh[48] += lh32[48];
          lh32[48] = 0;

          lh += WIENER_WIN2_ALIGN2;
          lh32 += WIENER_WIN2_ALIGN2;
        }
      }

      j -= 2;
    }

    // Computations for odd pixel in the row.
    if (width & 1) {
      // Load two adjacent, overlapping 7x7 matrices: a 8x7 matrix with the
      // middle 6x7 elements being shared.
      uint8x16_t dgd_rows[4];
      load_and_pack_u8_8x7(dgd_rows, dgd, dgd_stride);

      const uint8_t *dgd_ptr = dgd + dgd_stride * 6;
      ++dgd;

      // Re-arrange (and widen) the combined 8x7 matrix to have a whole 7x7
      // matrix tightly packed into a int16x8_t[6] array. This array contains
      // 48 elements of the 49 (7x7). Compute `dgd - avg` for the whole buffer.
      // The DGD_AVG buffer contains 49 consecutive elements.
      int16x8_t dgd_avg0[6];
      uint8x16_t dgd_shuf0 = tbl2q(dgd_rows[0], dgd_rows[1], lut0);
      dgd_avg0[0] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf0), avg_u8));
      dgd_avg0[1] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf0), avg_u8));
      vst1q_s16(DGD_AVG0, dgd_avg0[0]);
      vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);

      uint8x16_t dgd_shuf1 = tbl2q(dgd_rows[1], dgd_rows[2], lut1);
      dgd_avg0[2] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf1), avg_u8));
      dgd_avg0[3] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf1), avg_u8));
      vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
      vst1q_s16(DGD_AVG0 + 24, dgd_avg0[3]);

      uint8x16_t dgd_shuf2 = tbl2q(dgd_rows[2], dgd_rows[3], lut2);
      dgd_avg0[4] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf2), avg_u8));
      dgd_avg0[5] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf2), avg_u8));
      vst1q_s16(DGD_AVG0 + 32, dgd_avg0[4]);
      vst1q_s16(DGD_AVG0 + 40, dgd_avg0[5]);

      // The remaining last (49th) element of `dgd - avg`.
      DGD_AVG0[48] = dgd_ptr[6] - avg;

      // Accumulate into row-major order variant of matrix M (cross-correlation)
      // for 1 output pixel at a time. M is of size 7 * 7. It needs to be filled
      // such that multiplying one element from src with each element of a row
      // of the wiener window will fill one column of M. However this is not
      // very convenient in terms of memory access, as it means we do
      // contiguous loads of dgd but strided stores to M. As a result, we use an
      // intermediate matrix M_s32 which is instead filled such that one row of
      // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
      // then transposed to return M.
      int src_avg0 = *src++ - avg;
      int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
      update_M_1pixel(M_s32 + 0, src_avg0_s16, dgd_avg0[0]);
      update_M_1pixel(M_s32 + 8, src_avg0_s16, dgd_avg0[1]);
      update_M_1pixel(M_s32 + 16, src_avg0_s16, dgd_avg0[2]);
      update_M_1pixel(M_s32 + 24, src_avg0_s16, dgd_avg0[3]);
      update_M_1pixel(M_s32 + 32, src_avg0_s16, dgd_avg0[4]);
      update_M_1pixel(M_s32 + 40, src_avg0_s16, dgd_avg0[5]);

      // Last (49th) element of M_s32 can be computed as scalar more efficiently
      // for 1 output pixel.
      M_s32[48] += DGD_AVG0[48] * src_avg0;

      // Start accumulating into row-major order version of matrix H
      // (auto-covariance), it expects the DGD_AVG0 matrix to also be row-major.
      // H is of size 49 * 49. It is filled by multiplying every pair of
      // elements of the wiener window together (vector outer product). Since it
      // is a symmetric matrix, we only compute the upper-right triangle, and
      // then copy it down to the lower-left later. The upper triangle is
      // covered by 4x4 tiles. The original algorithm assumes the M matrix is
      // column-major and the resulting H matrix is also expected to be
      // column-major. It is not efficient to work column-major matrices, so we
      // accumulate into a row-major matrix H_s32. At the end of the algorithm a
      // double transpose transformation will convert H_s32 back to the expected
      // output layout.
      update_H_1pixel(H_s32, DGD_AVG0, WIENER_WIN2_ALIGN2, 48);

      // The last element of the triangle of H_s32 matrix can be computed as
      // scalar more efficiently.
      H_s32[48 * WIENER_WIN2_ALIGN2 + 48] += DGD_AVG0[48] * DGD_AVG0[48];
    }

    src += src_next;
    dgd += dgd_next;
  } while (--height != 0);

  // Fold remaining 32-bit residuals into the 64-bit sums, scale, and emit the
  // final column-major M and H.
  acc_transpose_M(M, M_s64, M_s32, WIENER_WIN, downsample_factor);

  update_H(H, H_s64, H_s32, WIENER_WIN, WIENER_WIN2_ALIGN2, downsample_factor);
}
580 
581 // Load 5x5 matrix into 2 and a half 128-bit vectors from consecutive rows, the
582 // last load address is offset to prevent out-of-bounds access.
// Load a 5x8 block (5 rows of 8 bytes) into 2.5 packed 128-bit vectors, two
// rows per vector. The final row is loaded three bytes early (src - 3) so the
// 8-byte load never reads past the end of the buffer; the caller's shuffle
// indices compensate for the offset. The upper half of dst[2] is zeroed.
static INLINE void load_and_pack_u8_6x5(uint8x16_t dst[3], const uint8_t *src,
                                        ptrdiff_t stride) {
  for (int i = 0; i < 2; ++i) {
    dst[i] = vcombine_u8(vld1_u8(src), vld1_u8(src + stride));
    src += 2 * stride;
  }
  dst[2] = vcombine_u8(vld1_u8(src - 3), vdup_n_u8(0));
}
591 
compute_stats_win5_neon(const uint8_t * dgd,const uint8_t * src,int width,int height,int dgd_stride,int src_stride,int avg,int64_t * M,int64_t * H,int downsample_factor)592 static INLINE void compute_stats_win5_neon(const uint8_t *dgd,
593                                            const uint8_t *src, int width,
594                                            int height, int dgd_stride,
595                                            int src_stride, int avg, int64_t *M,
596                                            int64_t *H, int downsample_factor) {
597   // Matrix names are capitalized to help readability.
598   DECLARE_ALIGNED(64, int16_t, DGD_AVG0[WIENER_WIN2_REDUCED_ALIGN3]);
599   DECLARE_ALIGNED(64, int16_t, DGD_AVG1[WIENER_WIN2_REDUCED_ALIGN3]);
600   DECLARE_ALIGNED(64, int32_t, M_s32[WIENER_WIN2_REDUCED_ALIGN3]);
601   DECLARE_ALIGNED(64, int64_t, M_s64[WIENER_WIN2_REDUCED_ALIGN3]);
602   DECLARE_ALIGNED(64, int32_t,
603                   H_s32[WIENER_WIN2_REDUCED * WIENER_WIN2_REDUCED_ALIGN2]);
604   DECLARE_ALIGNED(64, int64_t,
605                   H_s64[WIENER_WIN2_REDUCED * WIENER_WIN2_REDUCED_ALIGN2]);
606 
607   memset(M_s32, 0, sizeof(M_s32));
608   memset(M_s64, 0, sizeof(M_s64));
609   memset(H_s32, 0, sizeof(H_s32));
610   memset(H_s64, 0, sizeof(H_s64));
611 
612   // Look-up tables to create 8x3 matrix with consecutive elements from two 5x5
613   // matrices.
614   // clang-format off
615   DECLARE_ALIGNED(16, static const uint8_t, shuffle_stats5[48]) = {
616     0,  1,  2,  3,  4,  8,  9, 10, 11, 12, 16, 17, 18, 19, 20, 24,
617     1,  2,  3,  4,  5,  9, 10, 11, 12, 13, 17, 18, 19, 20, 21, 25,
618     9, 10, 11, 12, 19, 20, 21, 22, 10, 11, 12, 13, 20, 21, 22, 23,
619   };
620   // clang-format on
621 
622   const uint8x16_t lut0 = vld1q_u8(shuffle_stats5 + 0);
623   const uint8x16_t lut1 = vld1q_u8(shuffle_stats5 + 16);
624   const uint8x16_t lut2 = vld1q_u8(shuffle_stats5 + 32);
625 
626   int acc_cnt = STAT_ACCUMULATOR_MAX;
627   const int src_next = downsample_factor * src_stride - width;
628   const int dgd_next = downsample_factor * dgd_stride - width;
629   const uint8x8_t avg_u8 = vdup_n_u8(avg);
630 
631   do {
632     int j = width;
633     while (j >= 2) {
634       // Load two adjacent, overlapping 5x5 matrices: a 6x5 matrix with the
635       // middle 4x5 elements being shared.
636       uint8x16_t dgd_rows[3];
637       load_and_pack_u8_6x5(dgd_rows, dgd, dgd_stride);
638 
639       const uint8_t *dgd_ptr = dgd + dgd_stride * 4;
640       dgd += 2;
641 
642       // Re-arrange (and widen) the combined 6x5 matrix to have the 2 whole 5x5
643       // matrices (1 for each of the 2 pixels) separated into distinct
644       // int16x8_t[3] arrays. These arrays contain 24 elements of the 25 (5x5).
645       // Compute `dgd - avg` for both buffers. Each DGD_AVG buffer contains 25
646       // consecutive elements.
647       int16x8_t dgd_avg0[3];
648       int16x8_t dgd_avg1[3];
649       uint8x16_t dgd_shuf0 = tbl2q(dgd_rows[0], dgd_rows[1], lut0);
650       uint8x16_t dgd_shuf1 = tbl2q(dgd_rows[0], dgd_rows[1], lut1);
651       uint8x16_t dgd_shuf2 = tbl2q(dgd_rows[1], dgd_rows[2], lut2);
652 
653       dgd_avg0[0] =
654           vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf0), avg_u8));
655       dgd_avg0[1] =
656           vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf0), avg_u8));
657       dgd_avg0[2] =
658           vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf2), avg_u8));
659       dgd_avg1[0] =
660           vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf1), avg_u8));
661       dgd_avg1[1] =
662           vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf1), avg_u8));
663       dgd_avg1[2] =
664           vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf2), avg_u8));
665 
666       vst1q_s16(DGD_AVG0 + 0, dgd_avg0[0]);
667       vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
668       vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
669       vst1q_s16(DGD_AVG1 + 0, dgd_avg1[0]);
670       vst1q_s16(DGD_AVG1 + 8, dgd_avg1[1]);
671       vst1q_s16(DGD_AVG1 + 16, dgd_avg1[2]);
672 
673       // The remaining last (25th) elements of `dgd - avg`.
674       DGD_AVG0[24] = dgd_ptr[4] - avg;
675       DGD_AVG1[24] = dgd_ptr[5] - avg;
676 
677       // Accumulate into row-major variant of matrix M (cross-correlation) for 2
678       // output pixels at a time. M is of size 5 * 5. It needs to be filled such
679       // that multiplying one element from src with each element of a row of the
680       // wiener window will fill one column of M. However this is not very
681       // convenient in terms of memory access, as it means we do contiguous
682       // loads of dgd but strided stores to M. As a result, we use an
683       // intermediate matrix M_s32 which is instead filled such that one row of
684       // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
685       // then transposed to return M.
686       int src_avg0 = *src++ - avg;
687       int src_avg1 = *src++ - avg;
688       int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
689       int16x4_t src_avg1_s16 = vdup_n_s16(src_avg1);
690       update_M_2pixels(M_s32 + 0, src_avg0_s16, src_avg1_s16, dgd_avg0[0],
691                        dgd_avg1[0]);
692       update_M_2pixels(M_s32 + 8, src_avg0_s16, src_avg1_s16, dgd_avg0[1],
693                        dgd_avg1[1]);
694       update_M_2pixels(M_s32 + 16, src_avg0_s16, src_avg1_s16, dgd_avg0[2],
695                        dgd_avg1[2]);
696 
697       // Last (25th) element of M_s32 can be computed as scalar more efficiently
698       // for 2 output pixels.
699       M_s32[24] += DGD_AVG0[24] * src_avg0 + DGD_AVG1[24] * src_avg1;
700 
701       // Start accumulating into row-major version of matrix H
702       // (auto-covariance), it expects the DGD_AVG[01] matrices to also be
703       // row-major. H is of size 25 * 25. It is filled by multiplying every pair
704       // of elements of the wiener window together (vector outer product). Since
705       // it is a symmetric matrix, we only compute the upper-right triangle, and
706       // then copy it down to the lower-left later. The upper triangle is
707       // covered by 4x4 tiles. The original algorithm assumes the M matrix is
708       // column-major and the resulting H matrix is also expected to be
709       // column-major. It is not efficient to work with column-major matrices,
710       // so we accumulate into a row-major matrix H_s32. At the end of the
711       // algorithm a double transpose transformation will convert H_s32 back to
712       // the expected output layout.
713       update_H_5x5_2pixels(H_s32, DGD_AVG0, DGD_AVG1);
714 
715       // The last element of the triangle of H_s32 matrix can be computed as a
716       // scalar more efficiently.
717       H_s32[24 * WIENER_WIN2_REDUCED_ALIGN2 + 24] +=
718           DGD_AVG0[24] * DGD_AVG0[24] + DGD_AVG1[24] * DGD_AVG1[24];
719 
720       // Accumulate into 64-bit after STAT_ACCUMULATOR_MAX iterations to prevent
721       // overflow.
722       if (--acc_cnt == 0) {
723         acc_cnt = STAT_ACCUMULATOR_MAX;
724 
725         accumulate_and_clear(M_s64, M_s32, WIENER_WIN2_REDUCED_ALIGN2);
726 
727         // The widening accumulation is only needed for the upper triangle part
728         // of the matrix.
729         int64_t *lh = H_s64;
730         int32_t *lh32 = H_s32;
731         for (int k = 0; k < WIENER_WIN2_REDUCED; ++k) {
732           // The widening accumulation is only run for the relevant parts
733           // (upper-right triangle) in a row 4-element aligned.
734           int k4 = k / 4 * 4;
735           accumulate_and_clear(lh + k4, lh32 + k4, 24 - k4);
736 
737           // Last element of the row is computed separately.
738           lh[24] += lh32[24];
739           lh32[24] = 0;
740 
741           lh += WIENER_WIN2_REDUCED_ALIGN2;
742           lh32 += WIENER_WIN2_REDUCED_ALIGN2;
743         }
744       }
745 
746       j -= 2;
747     }
748 
749     // Computations for odd pixel in the row.
750     if (width & 1) {
751       // Load two adjacent, overlapping 5x5 matrices: a 6x5 matrix with the
752       // middle 4x5 elements being shared.
753       uint8x16_t dgd_rows[3];
754       load_and_pack_u8_6x5(dgd_rows, dgd, dgd_stride);
755 
756       const uint8_t *dgd_ptr = dgd + dgd_stride * 4;
757       ++dgd;
758 
759       // Re-arrange (and widen) the combined 6x5 matrix to have a whole 5x5
760       // matrix tightly packed into a int16x8_t[3] array. This array contains
761       // 24 elements of the 25 (5x5). Compute `dgd - avg` for the whole buffer.
762       // The DGD_AVG buffer contains 25 consecutive elements.
763       int16x8_t dgd_avg0[3];
764       uint8x16_t dgd_shuf0 = tbl2q(dgd_rows[0], dgd_rows[1], lut0);
765       uint8x8_t dgd_shuf1 = tbl2(dgd_rows[1], dgd_rows[2], vget_low_u8(lut2));
766 
767       dgd_avg0[0] =
768           vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf0), avg_u8));
769       dgd_avg0[1] =
770           vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf0), avg_u8));
771       dgd_avg0[2] = vreinterpretq_s16_u16(vsubl_u8(dgd_shuf1, avg_u8));
772 
773       vst1q_s16(DGD_AVG0 + 0, dgd_avg0[0]);
774       vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
775       vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
776 
777       // The remaining last (25th) element of `dgd - avg`.
778       DGD_AVG0[24] = dgd_ptr[4] - avg;
779 
780       // Accumulate into row-major order variant of matrix M (cross-correlation)
781       // for 1 output pixel at a time. M is of size 5 * 5. It needs to be filled
782       // such that multiplying one element from src with each element of a row
783       // of the wiener window will fill one column of M. However this is not
784       // very convenient in terms of memory access, as it means we do
785       // contiguous loads of dgd but strided stores to M. As a result, we use an
786       // intermediate matrix M_s32 which is instead filled such that one row of
787       // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
788       // then transposed to return M.
789       int src_avg0 = *src++ - avg;
790       int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
791       update_M_1pixel(M_s32 + 0, src_avg0_s16, dgd_avg0[0]);
792       update_M_1pixel(M_s32 + 8, src_avg0_s16, dgd_avg0[1]);
793       update_M_1pixel(M_s32 + 16, src_avg0_s16, dgd_avg0[2]);
794 
795       // Last (25th) element of M_s32 can be computed as scalar more efficiently
796       // for 1 output pixel.
797       M_s32[24] += DGD_AVG0[24] * src_avg0;
798 
799       // Start accumulating into row-major order version of matrix H
800       // (auto-covariance), it expects the DGD_AVG0 matrix to also be row-major.
801       // H is of size 25 * 25. It is filled by multiplying every pair of
802       // elements of the wiener window together (vector outer product). Since it
803       // is a symmetric matrix, we only compute the upper-right triangle, and
804       // then copy it down to the lower-left later. The upper triangle is
805       // covered by 4x4 tiles. The original algorithm assumes the M matrix is
806       // column-major and the resulting H matrix is also expected to be
807       // column-major. It is not efficient to work column-major matrices, so we
808       // accumulate into a row-major matrix H_s32. At the end of the algorithm a
809       // double transpose transformation will convert H_s32 back to the expected
810       // output layout.
811       update_H_1pixel(H_s32, DGD_AVG0, WIENER_WIN2_REDUCED_ALIGN2, 24);
812 
813       // The last element of the triangle of H_s32 matrix can be computed as a
814       // scalar more efficiently.
815       H_s32[24 * WIENER_WIN2_REDUCED_ALIGN2 + 24] +=
816           DGD_AVG0[24] * DGD_AVG0[24];
817     }
818 
819     src += src_next;
820     dgd += dgd_next;
821   } while (--height != 0);
822 
823   acc_transpose_M(M, M_s64, M_s32, WIENER_WIN_REDUCED, downsample_factor);
824 
825   update_H(H, H_s64, H_s32, WIENER_WIN_REDUCED, WIENER_WIN2_REDUCED_ALIGN2,
826            downsample_factor);
827 }
828 
// Compute the average pixel value of a `width` x `height` block of `src`.
//
// Pixels are summed in 16-bit vector lanes for as long as overflow cannot
// occur (up to 257 bytes per lane), then periodically widened into 32-bit
// lanes; non-multiple-of-8 columns are handled by a scalar tail. The final
// division truncates, matching the scalar implementation.
static INLINE uint8_t find_average_neon(const uint8_t *src, int src_stride,
                                        int width, int height) {
  uint64_t sum = 0;

  if (width >= 16) {
    int h = 0;
    // We can accumulate up to 257 8-bit values in a 16-bit value, given
    // that each 16-bit vector has 8 elements, that means we can process up to
    // int(257*8/width) rows before we need to widen to 32-bit vector
    // elements.
    int h_overflow = 257 * 8 / width;
    int h_limit = height > h_overflow ? h_overflow : height;
    uint32x4_t avg_u32 = vdupq_n_u32(0);
    do {
      uint16x8_t avg_u16 = vdupq_n_u16(0);
      do {
        int j = width;
        const uint8_t *src_ptr = src;
        do {
          uint8x16_t s = vld1q_u8(src_ptr);
          avg_u16 = vpadalq_u8(avg_u16, s);
          j -= 16;
          src_ptr += 16;
        } while (j >= 16);
        if (j >= 8) {
          uint8x8_t s = vld1_u8(src_ptr);
          avg_u16 = vaddw_u8(avg_u16, s);
          j -= 8;
          src_ptr += 8;
        }
        // Scalar tail case.
        while (j > 0) {
          sum += src[width - j];
          j--;
        }
        src += src_stride;
      } while (++h < h_limit);
      avg_u32 = vpadalq_u16(avg_u32, avg_u16);

      // Extend the row limit by another overflow-safe batch, clamped to the
      // total height. Clamping the advanced limit against itself (instead of
      // re-deriving it from h_overflow, which would reset it to its initial
      // value) ensures each outer iteration processes a full batch of rows
      // before widening, as described above.
      h_limit += h_overflow;
      h_limit = height > h_limit ? h_limit : height;
    } while (h < height);
    return (uint8_t)((horizontal_long_add_u32x4(avg_u32) + sum) /
                     (width * height));
  }
  if (width >= 8) {
    int h = 0;
    // We can accumulate up to 257 8-bit values in a 16-bit value, given
    // that each 16-bit vector has 4 elements, that means we can process up to
    // int(257*4/width) rows before we need to widen to 32-bit vector
    // elements.
    int h_overflow = 257 * 4 / width;
    int h_limit = height > h_overflow ? h_overflow : height;
    uint32x2_t avg_u32 = vdup_n_u32(0);
    do {
      uint16x4_t avg_u16 = vdup_n_u16(0);
      do {
        int j = width;
        const uint8_t *src_ptr = src;
        uint8x8_t s = vld1_u8(src_ptr);
        avg_u16 = vpadal_u8(avg_u16, s);
        j -= 8;
        src_ptr += 8;
        // Scalar tail case.
        while (j > 0) {
          sum += src[width - j];
          j--;
        }
        src += src_stride;
      } while (++h < h_limit);
      avg_u32 = vpadal_u16(avg_u32, avg_u16);

      // See the width >= 16 case: the new limit must be clamped against
      // itself, not re-derived from h_overflow.
      h_limit += h_overflow;
      h_limit = height > h_limit ? h_limit : height;
    } while (h < height);
    return (uint8_t)((horizontal_long_add_u32x2(avg_u32) + sum) /
                     (width * height));
  }
  // width < 8: plain scalar accumulation.
  int i = height;
  do {
    int j = 0;
    do {
      sum += src[j];
    } while (++j < width);
    src += src_stride;
  } while (--i != 0);
  return (uint8_t)(sum / (width * height));
}
917 
av1_compute_stats_neon(int wiener_win,const uint8_t * dgd,const uint8_t * src,int16_t * dgd_avg,int16_t * src_avg,int h_start,int h_end,int v_start,int v_end,int dgd_stride,int src_stride,int64_t * M,int64_t * H,int use_downsampled_wiener_stats)918 void av1_compute_stats_neon(int wiener_win, const uint8_t *dgd,
919                             const uint8_t *src, int16_t *dgd_avg,
920                             int16_t *src_avg, int h_start, int h_end,
921                             int v_start, int v_end, int dgd_stride,
922                             int src_stride, int64_t *M, int64_t *H,
923                             int use_downsampled_wiener_stats) {
924   assert(wiener_win == WIENER_WIN || wiener_win == WIENER_WIN_CHROMA);
925   assert(WIENER_STATS_DOWNSAMPLE_FACTOR == 4);
926   (void)dgd_avg;
927   (void)src_avg;
928 
929   const int wiener_win2 = wiener_win * wiener_win;
930   const int wiener_halfwin = wiener_win >> 1;
931   const int width = h_end - h_start;
932   const int height = v_end - v_start;
933 
934   const uint8_t *dgd_start = dgd + h_start + v_start * dgd_stride;
935   const uint8_t *src_start = src + h_start + v_start * src_stride;
936 
937   // The wiener window will slide along the dgd frame, centered on each pixel.
938   // For the top left pixel and all the pixels on the side of the frame this
939   // means half of the window will be outside of the frame. As such the actual
940   // buffer that we need to subtract the avg from will be 2 * wiener_halfwin
941   // wider and 2 * wiener_halfwin higher than the original dgd buffer.
942   const int vert_offset = v_start - wiener_halfwin;
943   const int horiz_offset = h_start - wiener_halfwin;
944   const uint8_t *dgd_win = dgd + horiz_offset + vert_offset * dgd_stride;
945 
946   uint8_t avg = find_average_neon(dgd_start, dgd_stride, width, height);
947 
948   // Since the height is not necessarily a multiple of the downsample factor,
949   // the last line of src will be scaled according to how many rows remain.
950   int downsample_factor =
951       use_downsampled_wiener_stats ? WIENER_STATS_DOWNSAMPLE_FACTOR : 1;
952 
953   int downsampled_height = height / downsample_factor;
954   int downsample_remainder = height % downsample_factor;
955 
956   memset(M, 0, wiener_win2 * sizeof(*M));
957   memset(H, 0, wiener_win2 * wiener_win2 * sizeof(*H));
958 
959   // Calculate the M and H matrices for the normal and downsampled cases.
960   if (downsampled_height > 0) {
961     if (wiener_win == WIENER_WIN) {
962       compute_stats_win7_neon(dgd_win, src_start, width, downsampled_height,
963                               dgd_stride, src_stride, avg, M, H,
964                               downsample_factor);
965     } else {
966       compute_stats_win5_neon(dgd_win, src_start, width, downsampled_height,
967                               dgd_stride, src_stride, avg, M, H,
968                               downsample_factor);
969     }
970   }
971 
972   // Accumulate the remaining last rows in the downsampled case.
973   if (downsample_remainder > 0) {
974     int remainder_offset = height - downsample_remainder;
975     if (wiener_win == WIENER_WIN) {
976       compute_stats_win7_neon(dgd_win + remainder_offset * dgd_stride,
977                               src_start + remainder_offset * src_stride, width,
978                               1, dgd_stride, src_stride, avg, M, H,
979                               downsample_remainder);
980     } else {
981       compute_stats_win5_neon(dgd_win + remainder_offset * dgd_stride,
982                               src_start + remainder_offset * src_stride, width,
983                               1, dgd_stride, src_stride, avg, M, H,
984                               downsample_remainder);
985     }
986   }
987 }
988 
// Compute the normal equations for the self-guided restoration projection
// parameters when both SGR filter outputs (flt0 and flt1) are enabled:
//   H[i][j] = sum((flt_i - u) * (flt_j - u)) / size
//   C[i]    = sum((flt_i - u) * (s - u))     / size
// where u and s are the degraded and source pixels, upscaled by
// SGRPROJ_RST_BITS to match the precision of the filtered planes. H is
// symmetric, so H[1][0] is copied from H[0][1].
static INLINE void calc_proj_params_r0_r1_neon(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int32_t *flt1, int flt1_stride, int64_t H[2][2], int64_t C[2]) {
  assert(width % 8 == 0);
  const int size = width * height;

  // Each 64-bit accumulator is split into a low/high pair (covering the low
  // and high halves of each 8-wide batch) to shorten dependency chains; the
  // pairs are summed once after the loops.
  int64x2_t h00_lo = vdupq_n_s64(0);
  int64x2_t h00_hi = vdupq_n_s64(0);
  int64x2_t h11_lo = vdupq_n_s64(0);
  int64x2_t h11_hi = vdupq_n_s64(0);
  int64x2_t h01_lo = vdupq_n_s64(0);
  int64x2_t h01_hi = vdupq_n_s64(0);
  int64x2_t c0_lo = vdupq_n_s64(0);
  int64x2_t c0_hi = vdupq_n_s64(0);
  int64x2_t c1_lo = vdupq_n_s64(0);
  int64x2_t c1_hi = vdupq_n_s64(0);

  do {
    const uint8_t *src_ptr = src8;
    const uint8_t *dat_ptr = dat8;
    int32_t *flt0_ptr = flt0;
    int32_t *flt1_ptr = flt1;
    int w = width;

    // Process the row 8 pixels at a time (width is asserted to be a multiple
    // of 8).
    do {
      uint8x8_t s = vld1_u8(src_ptr);
      uint8x8_t d = vld1_u8(dat_ptr);
      int32x4_t f0_lo = vld1q_s32(flt0_ptr);
      int32x4_t f0_hi = vld1q_s32(flt0_ptr + 4);
      int32x4_t f1_lo = vld1q_s32(flt1_ptr);
      int32x4_t f1_hi = vld1q_s32(flt1_ptr + 4);

      // Upscale the 8-bit pixels by SGRPROJ_RST_BITS so that they share the
      // precision of the filtered values.
      int16x8_t u = vreinterpretq_s16_u16(vshll_n_u8(d, SGRPROJ_RST_BITS));
      int16x8_t s_s16 = vreinterpretq_s16_u16(vshll_n_u8(s, SGRPROJ_RST_BITS));

      // Form the differences: s = src - dat, f0 = flt0 - dat, f1 = flt1 - dat
      // (all in the upscaled domain).
      int32x4_t s_lo = vsubl_s16(vget_low_s16(s_s16), vget_low_s16(u));
      int32x4_t s_hi = vsubl_s16(vget_high_s16(s_s16), vget_high_s16(u));
      f0_lo = vsubw_s16(f0_lo, vget_low_s16(u));
      f0_hi = vsubw_s16(f0_hi, vget_high_s16(u));
      f1_lo = vsubw_s16(f1_lo, vget_low_s16(u));
      f1_hi = vsubw_s16(f1_hi, vget_high_s16(u));

      // H[0][0] += f0 * f0.
      h00_lo = vmlal_s32(h00_lo, vget_low_s32(f0_lo), vget_low_s32(f0_lo));
      h00_lo = vmlal_s32(h00_lo, vget_high_s32(f0_lo), vget_high_s32(f0_lo));
      h00_hi = vmlal_s32(h00_hi, vget_low_s32(f0_hi), vget_low_s32(f0_hi));
      h00_hi = vmlal_s32(h00_hi, vget_high_s32(f0_hi), vget_high_s32(f0_hi));

      // H[1][1] += f1 * f1.
      h11_lo = vmlal_s32(h11_lo, vget_low_s32(f1_lo), vget_low_s32(f1_lo));
      h11_lo = vmlal_s32(h11_lo, vget_high_s32(f1_lo), vget_high_s32(f1_lo));
      h11_hi = vmlal_s32(h11_hi, vget_low_s32(f1_hi), vget_low_s32(f1_hi));
      h11_hi = vmlal_s32(h11_hi, vget_high_s32(f1_hi), vget_high_s32(f1_hi));

      // H[0][1] += f0 * f1.
      h01_lo = vmlal_s32(h01_lo, vget_low_s32(f0_lo), vget_low_s32(f1_lo));
      h01_lo = vmlal_s32(h01_lo, vget_high_s32(f0_lo), vget_high_s32(f1_lo));
      h01_hi = vmlal_s32(h01_hi, vget_low_s32(f0_hi), vget_low_s32(f1_hi));
      h01_hi = vmlal_s32(h01_hi, vget_high_s32(f0_hi), vget_high_s32(f1_hi));

      // C[0] += f0 * s.
      c0_lo = vmlal_s32(c0_lo, vget_low_s32(f0_lo), vget_low_s32(s_lo));
      c0_lo = vmlal_s32(c0_lo, vget_high_s32(f0_lo), vget_high_s32(s_lo));
      c0_hi = vmlal_s32(c0_hi, vget_low_s32(f0_hi), vget_low_s32(s_hi));
      c0_hi = vmlal_s32(c0_hi, vget_high_s32(f0_hi), vget_high_s32(s_hi));

      // C[1] += f1 * s.
      c1_lo = vmlal_s32(c1_lo, vget_low_s32(f1_lo), vget_low_s32(s_lo));
      c1_lo = vmlal_s32(c1_lo, vget_high_s32(f1_lo), vget_high_s32(s_lo));
      c1_hi = vmlal_s32(c1_hi, vget_low_s32(f1_hi), vget_low_s32(s_hi));
      c1_hi = vmlal_s32(c1_hi, vget_high_s32(f1_hi), vget_high_s32(s_hi));

      src_ptr += 8;
      dat_ptr += 8;
      flt0_ptr += 8;
      flt1_ptr += 8;
      w -= 8;
    } while (w != 0);

    src8 += src_stride;
    dat8 += dat_stride;
    flt0 += flt0_stride;
    flt1 += flt1_stride;
  } while (--height != 0);

  // Reduce the vector accumulators and normalize by the number of pixels.
  H[0][0] = horizontal_add_s64x2(vaddq_s64(h00_lo, h00_hi)) / size;
  H[0][1] = horizontal_add_s64x2(vaddq_s64(h01_lo, h01_hi)) / size;
  H[1][1] = horizontal_add_s64x2(vaddq_s64(h11_lo, h11_hi)) / size;
  H[1][0] = H[0][1];  // H is symmetric.
  C[0] = horizontal_add_s64x2(vaddq_s64(c0_lo, c0_hi)) / size;
  C[1] = horizontal_add_s64x2(vaddq_s64(c1_lo, c1_hi)) / size;
}
1077 
// Compute the normal equations for the self-guided restoration projection
// parameters when only the first SGR filter output (flt0) is enabled; only
// H[0][0] and C[0] are produced.
static INLINE void calc_proj_params_r0_neon(const uint8_t *src8, int width,
                                            int height, int src_stride,
                                            const uint8_t *dat8, int dat_stride,
                                            int32_t *flt0, int flt0_stride,
                                            int64_t H[2][2], int64_t C[2]) {
  assert(width % 8 == 0);
  const int size = width * height;

  // One accumulator per half of each 8-wide batch; reduced after the loops.
  int64x2_t h00[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };
  int64x2_t c0[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };

  do {
    const uint8_t *src_row = src8;
    const uint8_t *dat_row = dat8;
    int32_t *flt_row = flt0;
    int w = width;

    do {
      uint8x8_t s8 = vld1_u8(src_row);
      uint8x8_t d8 = vld1_u8(dat_row);
      int32x4_t f[2];
      f[0] = vld1q_s32(flt_row);
      f[1] = vld1q_s32(flt_row + 4);

      // Upscale the 8-bit pixels to the precision of the filtered plane.
      int16x8_t u = vreinterpretq_s16_u16(vshll_n_u8(d8, SGRPROJ_RST_BITS));
      int16x8_t s = vreinterpretq_s16_u16(vshll_n_u8(s8, SGRPROJ_RST_BITS));

      // sd = src - dat and f = flt0 - dat, in the upscaled domain.
      int32x4_t sd[2];
      sd[0] = vsubl_s16(vget_low_s16(s), vget_low_s16(u));
      sd[1] = vsubl_s16(vget_high_s16(s), vget_high_s16(u));
      f[0] = vsubw_s16(f[0], vget_low_s16(u));
      f[1] = vsubw_s16(f[1], vget_high_s16(u));

      for (int i = 0; i < 2; ++i) {
        // H[0][0] += f0 * f0 and C[0] += f0 * sd.
        h00[i] = vmlal_s32(h00[i], vget_low_s32(f[i]), vget_low_s32(f[i]));
        h00[i] = vmlal_s32(h00[i], vget_high_s32(f[i]), vget_high_s32(f[i]));
        c0[i] = vmlal_s32(c0[i], vget_low_s32(f[i]), vget_low_s32(sd[i]));
        c0[i] = vmlal_s32(c0[i], vget_high_s32(f[i]), vget_high_s32(sd[i]));
      }

      src_row += 8;
      dat_row += 8;
      flt_row += 8;
      w -= 8;
    } while (w != 0);

    src8 += src_stride;
    dat8 += dat_stride;
    flt0 += flt0_stride;
  } while (--height != 0);

  // Reduce and normalize by the number of pixels.
  H[0][0] = horizontal_add_s64x2(vaddq_s64(h00[0], h00[1])) / size;
  C[0] = horizontal_add_s64x2(vaddq_s64(c0[0], c0[1])) / size;
}
1135 
// Compute the normal equations for the self-guided restoration projection
// parameters when only the second SGR filter output (flt1) is enabled; only
// H[1][1] and C[1] are produced.
static INLINE void calc_proj_params_r1_neon(const uint8_t *src8, int width,
                                            int height, int src_stride,
                                            const uint8_t *dat8, int dat_stride,
                                            int32_t *flt1, int flt1_stride,
                                            int64_t H[2][2], int64_t C[2]) {
  assert(width % 8 == 0);
  const int size = width * height;

  // One accumulator per half of each 8-wide batch; reduced after the loops.
  int64x2_t h11[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };
  int64x2_t c1[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };

  do {
    const uint8_t *src_row = src8;
    const uint8_t *dat_row = dat8;
    int32_t *flt_row = flt1;
    int w = width;

    do {
      uint8x8_t s8 = vld1_u8(src_row);
      uint8x8_t d8 = vld1_u8(dat_row);
      int32x4_t f[2];
      f[0] = vld1q_s32(flt_row);
      f[1] = vld1q_s32(flt_row + 4);

      // Upscale the 8-bit pixels to the precision of the filtered plane.
      int16x8_t u = vreinterpretq_s16_u16(vshll_n_u8(d8, SGRPROJ_RST_BITS));
      int16x8_t s = vreinterpretq_s16_u16(vshll_n_u8(s8, SGRPROJ_RST_BITS));

      // sd = src - dat and f = flt1 - dat, in the upscaled domain.
      int32x4_t sd[2];
      sd[0] = vsubl_s16(vget_low_s16(s), vget_low_s16(u));
      sd[1] = vsubl_s16(vget_high_s16(s), vget_high_s16(u));
      f[0] = vsubw_s16(f[0], vget_low_s16(u));
      f[1] = vsubw_s16(f[1], vget_high_s16(u));

      for (int i = 0; i < 2; ++i) {
        // H[1][1] += f1 * f1 and C[1] += f1 * sd.
        h11[i] = vmlal_s32(h11[i], vget_low_s32(f[i]), vget_low_s32(f[i]));
        h11[i] = vmlal_s32(h11[i], vget_high_s32(f[i]), vget_high_s32(f[i]));
        c1[i] = vmlal_s32(c1[i], vget_low_s32(f[i]), vget_low_s32(sd[i]));
        c1[i] = vmlal_s32(c1[i], vget_high_s32(f[i]), vget_high_s32(sd[i]));
      }

      src_row += 8;
      dat_row += 8;
      flt_row += 8;
      w -= 8;
    } while (w != 0);

    src8 += src_stride;
    dat8 += dat_stride;
    flt1 += flt1_stride;
  } while (--height != 0);

  // Reduce and normalize by the number of pixels.
  H[1][1] = horizontal_add_s64x2(vaddq_s64(h11[0], h11[1])) / size;
  C[1] = horizontal_add_s64x2(vaddq_s64(c1[0], c1[1])) / size;
}
1193 
// This function dispatches to one of 3 subfunctions, depending on the
// following cases:
// 1) When params->r[0] > 0 and params->r[1] > 0: all elements of C and H
//    need to be computed.
// 2) When only params->r[0] > 0: only H[0][0] and C[0] are non-zero and
//    need to be computed.
// 3) When only params->r[1] > 0: only H[1][1] and C[1] are non-zero and
//    need to be computed.
av1_calc_proj_params_neon(const uint8_t * src8,int width,int height,int src_stride,const uint8_t * dat8,int dat_stride,int32_t * flt0,int flt0_stride,int32_t * flt1,int flt1_stride,int64_t H[2][2],int64_t C[2],const sgr_params_type * params)1201 void av1_calc_proj_params_neon(const uint8_t *src8, int width, int height,
1202                                int src_stride, const uint8_t *dat8,
1203                                int dat_stride, int32_t *flt0, int flt0_stride,
1204                                int32_t *flt1, int flt1_stride, int64_t H[2][2],
1205                                int64_t C[2], const sgr_params_type *params) {
1206   if ((params->r[0] > 0) && (params->r[1] > 0)) {
1207     calc_proj_params_r0_r1_neon(src8, width, height, src_stride, dat8,
1208                                 dat_stride, flt0, flt0_stride, flt1,
1209                                 flt1_stride, H, C);
1210   } else if (params->r[0] > 0) {
1211     calc_proj_params_r0_neon(src8, width, height, src_stride, dat8, dat_stride,
1212                              flt0, flt0_stride, H, C);
1213   } else if (params->r[1] > 0) {
1214     calc_proj_params_r1_neon(src8, width, height, src_stride, dat8, dat_stride,
1215                              flt1, flt1_stride, H, C);
1216   }
1217 }
1218