/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>
#include <stdint.h>

#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "av1/encoder/arm/neon/pickrst_neon.h"
#include "av1/encoder/pickrst.h"

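// Accumulates, for the self-guided projection search,
// H[i][j] = sum((flt_i - u) * (flt_j - u)) / size and
// C[i] = sum((flt_i - u) * (s - u)) / size, where u and s are the dat and src
// samples up-shifted by SGRPROJ_RST_BITS.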
static INLINE void highbd_calc_proj_params_r0_r1_neon(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int32_t *flt1, int flt1_stride, int64_t H[2][2], int64_t C[2]) {
  assert(width % 8 == 0);
  const int size = width * height;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);

  int64x2_t h00_lo = vdupq_n_s64(0);
  int64x2_t h00_hi = vdupq_n_s64(0);
  int64x2_t h11_lo = vdupq_n_s64(0);
  int64x2_t h11_hi = vdupq_n_s64(0);
  int64x2_t h01_lo = vdupq_n_s64(0);
  int64x2_t h01_hi = vdupq_n_s64(0);
  int64x2_t c0_lo = vdupq_n_s64(0);
  int64x2_t c0_hi = vdupq_n_s64(0);
  int64x2_t c1_lo = vdupq_n_s64(0);
  int64x2_t c1_hi = vdupq_n_s64(0);

  do {
    const uint16_t *src_ptr = src;
    const uint16_t *dat_ptr = dat;
    int32_t *flt0_ptr = flt0;
    int32_t *flt1_ptr = flt1;
    int w = width;

    do {
      uint16x8_t s = vld1q_u16(src_ptr);
      uint16x8_t d = vld1q_u16(dat_ptr);
      int32x4_t f0_lo = vld1q_s32(flt0_ptr);
      int32x4_t f0_hi = vld1q_s32(flt0_ptr + 4);
      int32x4_t f1_lo = vld1q_s32(flt1_ptr);
      int32x4_t f1_hi = vld1q_s32(flt1_ptr + 4);

      int32x4_t u_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(d), SGRPROJ_RST_BITS));
      int32x4_t u_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(d), SGRPROJ_RST_BITS));
      int32x4_t s_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(s), SGRPROJ_RST_BITS));
      int32x4_t s_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(s), SGRPROJ_RST_BITS));
      s_lo = vsubq_s32(s_lo, u_lo);
      s_hi = vsubq_s32(s_hi, u_hi);

      f0_lo = vsubq_s32(f0_lo, u_lo);
      f0_hi = vsubq_s32(f0_hi, u_hi);
      f1_lo = vsubq_s32(f1_lo, u_lo);
      f1_hi = vsubq_s32(f1_hi, u_hi);

      h00_lo = vmlal_s32(h00_lo, vget_low_s32(f0_lo), vget_low_s32(f0_lo));
      h00_lo = vmlal_s32(h00_lo, vget_high_s32(f0_lo), vget_high_s32(f0_lo));
      h00_hi = vmlal_s32(h00_hi, vget_low_s32(f0_hi), vget_low_s32(f0_hi));
      h00_hi = vmlal_s32(h00_hi, vget_high_s32(f0_hi), vget_high_s32(f0_hi));

      h11_lo = vmlal_s32(h11_lo, vget_low_s32(f1_lo), vget_low_s32(f1_lo));
      h11_lo = vmlal_s32(h11_lo, vget_high_s32(f1_lo), vget_high_s32(f1_lo));
      h11_hi = vmlal_s32(h11_hi, vget_low_s32(f1_hi), vget_low_s32(f1_hi));
      h11_hi = vmlal_s32(h11_hi, vget_high_s32(f1_hi), vget_high_s32(f1_hi));

      h01_lo = vmlal_s32(h01_lo, vget_low_s32(f0_lo), vget_low_s32(f1_lo));
      h01_lo = vmlal_s32(h01_lo, vget_high_s32(f0_lo), vget_high_s32(f1_lo));
      h01_hi = vmlal_s32(h01_hi, vget_low_s32(f0_hi), vget_low_s32(f1_hi));
      h01_hi = vmlal_s32(h01_hi, vget_high_s32(f0_hi), vget_high_s32(f1_hi));

      c0_lo = vmlal_s32(c0_lo, vget_low_s32(f0_lo), vget_low_s32(s_lo));
      c0_lo = vmlal_s32(c0_lo, vget_high_s32(f0_lo), vget_high_s32(s_lo));
      c0_hi = vmlal_s32(c0_hi, vget_low_s32(f0_hi), vget_low_s32(s_hi));
      c0_hi = vmlal_s32(c0_hi, vget_high_s32(f0_hi), vget_high_s32(s_hi));

      c1_lo = vmlal_s32(c1_lo, vget_low_s32(f1_lo), vget_low_s32(s_lo));
      c1_lo = vmlal_s32(c1_lo, vget_high_s32(f1_lo), vget_high_s32(s_lo));
      c1_hi = vmlal_s32(c1_hi, vget_low_s32(f1_hi), vget_low_s32(s_hi));
      c1_hi = vmlal_s32(c1_hi, vget_high_s32(f1_hi), vget_high_s32(s_hi));

      src_ptr += 8;
      dat_ptr += 8;
      flt0_ptr += 8;
      flt1_ptr += 8;
      w -= 8;
    } while (w != 0);

    src += src_stride;
    dat += dat_stride;
    flt0 += flt0_stride;
    flt1 += flt1_stride;
  } while (--height != 0);

  H[0][0] = horizontal_add_s64x2(vaddq_s64(h00_lo, h00_hi)) / size;
  H[0][1] = horizontal_add_s64x2(vaddq_s64(h01_lo, h01_hi)) / size;
  H[1][1] = horizontal_add_s64x2(vaddq_s64(h11_lo, h11_hi)) / size;
  H[1][0] = H[0][1];
  C[0] = horizontal_add_s64x2(vaddq_s64(c0_lo, c0_hi)) / size;
  C[1] = horizontal_add_s64x2(vaddq_s64(c1_lo, c1_hi)) / size;
}

static INLINE void highbd_calc_proj_params_r0_neon(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int64_t H[2][2], int64_t C[2]) {
  assert(width % 8 == 0);
  const int size = width * height;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);

  int64x2_t h00_lo = vdupq_n_s64(0);
  int64x2_t h00_hi = vdupq_n_s64(0);
  int64x2_t c0_lo = vdupq_n_s64(0);
  int64x2_t c0_hi = vdupq_n_s64(0);

  do {
    const uint16_t *src_ptr = src;
    const uint16_t *dat_ptr = dat;
    int32_t *flt0_ptr = flt0;
    int w = width;

    do {
      uint16x8_t s = vld1q_u16(src_ptr);
      uint16x8_t d = vld1q_u16(dat_ptr);
      int32x4_t f0_lo = vld1q_s32(flt0_ptr);
      int32x4_t f0_hi = vld1q_s32(flt0_ptr + 4);

      int32x4_t u_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(d), SGRPROJ_RST_BITS));
      int32x4_t u_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(d), SGRPROJ_RST_BITS));
      int32x4_t s_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(s), SGRPROJ_RST_BITS));
      int32x4_t s_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(s), SGRPROJ_RST_BITS));
      s_lo = vsubq_s32(s_lo, u_lo);
      s_hi = vsubq_s32(s_hi, u_hi);

      f0_lo = vsubq_s32(f0_lo, u_lo);
      f0_hi = vsubq_s32(f0_hi, u_hi);

      h00_lo = vmlal_s32(h00_lo, vget_low_s32(f0_lo), vget_low_s32(f0_lo));
      h00_lo = vmlal_s32(h00_lo, vget_high_s32(f0_lo), vget_high_s32(f0_lo));
      h00_hi = vmlal_s32(h00_hi, vget_low_s32(f0_hi), vget_low_s32(f0_hi));
      h00_hi = vmlal_s32(h00_hi, vget_high_s32(f0_hi), vget_high_s32(f0_hi));

      c0_lo = vmlal_s32(c0_lo, vget_low_s32(f0_lo), vget_low_s32(s_lo));
      c0_lo = vmlal_s32(c0_lo, vget_high_s32(f0_lo), vget_high_s32(s_lo));
      c0_hi = vmlal_s32(c0_hi, vget_low_s32(f0_hi), vget_low_s32(s_hi));
      c0_hi = vmlal_s32(c0_hi, vget_high_s32(f0_hi), vget_high_s32(s_hi));

      src_ptr += 8;
      dat_ptr += 8;
      flt0_ptr += 8;
      w -= 8;
    } while (w != 0);

    src += src_stride;
    dat += dat_stride;
    flt0 += flt0_stride;
  } while (--height != 0);

  H[0][0] = horizontal_add_s64x2(vaddq_s64(h00_lo, h00_hi)) / size;
  C[0] = horizontal_add_s64x2(vaddq_s64(c0_lo, c0_hi)) / size;
}

static INLINE void highbd_calc_proj_params_r1_neon(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt1, int flt1_stride,
    int64_t H[2][2], int64_t C[2]) {
  assert(width % 8 == 0);
  const int size = width * height;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);

  int64x2_t h11_lo = vdupq_n_s64(0);
  int64x2_t h11_hi = vdupq_n_s64(0);
  int64x2_t c1_lo = vdupq_n_s64(0);
  int64x2_t c1_hi = vdupq_n_s64(0);

  do {
    const uint16_t *src_ptr = src;
    const uint16_t *dat_ptr = dat;
    int32_t *flt1_ptr = flt1;
    int w = width;

    do {
      uint16x8_t s = vld1q_u16(src_ptr);
      uint16x8_t d = vld1q_u16(dat_ptr);
      int32x4_t f1_lo = vld1q_s32(flt1_ptr);
      int32x4_t f1_hi = vld1q_s32(flt1_ptr + 4);

      int32x4_t u_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(d), SGRPROJ_RST_BITS));
      int32x4_t u_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(d), SGRPROJ_RST_BITS));
      int32x4_t s_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(s), SGRPROJ_RST_BITS));
      int32x4_t s_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(s), SGRPROJ_RST_BITS));
      s_lo = vsubq_s32(s_lo, u_lo);
      s_hi = vsubq_s32(s_hi, u_hi);

      f1_lo = vsubq_s32(f1_lo, u_lo);
      f1_hi = vsubq_s32(f1_hi, u_hi);

      h11_lo = vmlal_s32(h11_lo, vget_low_s32(f1_lo), vget_low_s32(f1_lo));
      h11_lo = vmlal_s32(h11_lo, vget_high_s32(f1_lo), vget_high_s32(f1_lo));
      h11_hi = vmlal_s32(h11_hi, vget_low_s32(f1_hi), vget_low_s32(f1_hi));
      h11_hi = vmlal_s32(h11_hi, vget_high_s32(f1_hi), vget_high_s32(f1_hi));

      c1_lo = vmlal_s32(c1_lo, vget_low_s32(f1_lo), vget_low_s32(s_lo));
      c1_lo = vmlal_s32(c1_lo, vget_high_s32(f1_lo), vget_high_s32(s_lo));
      c1_hi = vmlal_s32(c1_hi, vget_low_s32(f1_hi), vget_low_s32(s_hi));
      c1_hi = vmlal_s32(c1_hi, vget_high_s32(f1_hi), vget_high_s32(s_hi));

      src_ptr += 8;
      dat_ptr += 8;
      flt1_ptr += 8;
      w -= 8;
    } while (w != 0);

    src += src_stride;
    dat += dat_stride;
    flt1 += flt1_stride;
  } while (--height != 0);

  H[1][1] = horizontal_add_s64x2(vaddq_s64(h11_lo, h11_hi)) / size;
  C[1] = horizontal_add_s64x2(vaddq_s64(c1_lo, c1_hi)) / size;
}

// The function calls 3 subfunctions for the following cases:
// 1) When params->r[0] > 0 and params->r[1] > 0. In this case all elements
//    of C and H need to be computed.
// 2) When only params->r[0] > 0. In this case only H[0][0] and C[0] are
//    non-zero and need to be computed.
// 3) When only params->r[1] > 0. In this case only H[1][1] and C[1] are
//    non-zero and need to be computed.
void av1_calc_proj_params_high_bd_neon(const uint8_t *src8, int width,
                                       int height, int src_stride,
                                       const uint8_t *dat8, int dat_stride,
                                       int32_t *flt0, int flt0_stride,
                                       int32_t *flt1, int flt1_stride,
                                       int64_t H[2][2], int64_t C[2],
                                       const sgr_params_type *params) {
  if ((params->r[0] > 0) && (params->r[1] > 0)) {
    highbd_calc_proj_params_r0_r1_neon(src8, width, height, src_stride, dat8,
                                       dat_stride, flt0, flt0_stride, flt1,
                                       flt1_stride, H, C);
  } else if (params->r[0] > 0) {
    highbd_calc_proj_params_r0_neon(src8, width, height, src_stride, dat8,
                                    dat_stride, flt0, flt0_stride, H, C);
  } else if (params->r[1] > 0) {
    highbd_calc_proj_params_r1_neon(src8, width, height, src_stride, dat8,
                                    dat_stride, flt1, flt1_stride, H, C);
  }
}

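// Permute the bytes of two 128-bit vectors according to idx, as vqtbl2q_u8
// does on AArch64; on Armv7 the same lookup is emulated with vtbl4_u8.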
static INLINE int16x8_t tbl2q(int16x8_t a, int16x8_t b, uint8x16_t idx) {
#if AOM_ARCH_AARCH64
  uint8x16x2_t table = { { vreinterpretq_u8_s16(a), vreinterpretq_u8_s16(b) } };
  return vreinterpretq_s16_u8(vqtbl2q_u8(table, idx));
#else
  uint8x8x4_t table = { { vreinterpret_u8_s16(vget_low_s16(a)),
                          vreinterpret_u8_s16(vget_high_s16(a)),
                          vreinterpret_u8_s16(vget_low_s16(b)),
                          vreinterpret_u8_s16(vget_high_s16(b)) } };
  return vreinterpretq_s16_u8(vcombine_u8(vtbl4_u8(table, vget_low_u8(idx)),
                                          vtbl4_u8(table, vget_high_u8(idx))));
#endif
}

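// Permute the bytes of three 128-bit vectors according to idx, as vqtbl3q_u8
// does on AArch64; the Armv7 path emulates it with two vtbl3_u8 lookups.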
static INLINE int16x8_t tbl3q(int16x8_t a, int16x8_t b, int16x8_t c,
                              uint8x16_t idx) {
#if AOM_ARCH_AARCH64
  uint8x16x3_t table = { { vreinterpretq_u8_s16(a), vreinterpretq_u8_s16(b),
                           vreinterpretq_u8_s16(c) } };
  return vreinterpretq_s16_u8(vqtbl3q_u8(table, idx));
#else
  // This is a specific implementation working only for compute stats with
  // wiener_win == 5.
  uint8x8x3_t table_lo = { { vreinterpret_u8_s16(vget_low_s16(a)),
                             vreinterpret_u8_s16(vget_high_s16(a)),
                             vreinterpret_u8_s16(vget_low_s16(b)) } };
  uint8x8x3_t table_hi = { { vreinterpret_u8_s16(vget_low_s16(b)),
                             vreinterpret_u8_s16(vget_high_s16(b)),
                             vreinterpret_u8_s16(vget_low_s16(c)) } };
  return vreinterpretq_s16_u8(vcombine_u8(
      vtbl3_u8(table_lo, vget_low_u8(idx)),
      vtbl3_u8(table_hi, vsub_u8(vget_high_u8(idx), vdup_n_u8(16)))));
#endif
}

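// Divide by 2^power, rounding the result towards zero to match C integer
// division.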
static INLINE int64_t div_shift_s64(int64_t x, int power) {
  return (x < 0 ? x + (1ll << power) - 1 : x) >> power;
}

// The M matrix is accumulated in a bitdepth-dependent number of steps to
// speed up the computation. This function computes the final M from the
// accumulated (src_s64) and the residual parts (src_s32). It also transposes
// the result as the output needs to be column-major.
static INLINE void acc_transpose_M(int64_t *dst, const int64_t *src_s64,
                                   const int32_t *src_s32, const int wiener_win,
                                   int shift) {
  for (int i = 0; i < wiener_win; ++i) {
    for (int j = 0; j < wiener_win; ++j) {
      int tr_idx = j * wiener_win + i;
      *dst++ = div_shift_s64(src_s64[tr_idx] + src_s32[tr_idx], shift);
    }
  }
}

// The resulting H is a column-major matrix accumulated from the transposed
// (column-major) samples of the filter kernel (5x5 or 7x7) viewed as a single
// vector. For the 7x7 filter case: H(49x49) = [49 x 1] x [1 x 49]. This
// function transforms back to the originally expected format (double
// transpose). The H matrix is accumulated in a bitdepth-dependent number of
// steps to speed up the computation. This function computes the final H from
// the accumulated (src_s64) and the residual parts (src_s32). The computed H
// is only an upper triangle matrix; this function also fills in the lower
// triangle of the resulting matrix.
static INLINE void update_H(int64_t *dst, const int64_t *src_s64,
                            const int32_t *src_s32, const int wiener_win,
                            int stride, int shift) {
  // For a simplified theoretical 3x3 case where `wiener_win` is 3 and
  // `wiener_win2` is 9, the M matrix is 3x3:
  // 0, 3, 6
  // 1, 4, 7
  // 2, 5, 8
  //
  // This is viewed as a vector to compute H (9x9) by vector outer product:
  // 0, 3, 6, 1, 4, 7, 2, 5, 8
  //
  // Double transpose and upper triangle remapping for 3x3 -> 9x9 case:
  // 0,    3,    6,    1,    4,    7,    2,    5,    8,
  // 3,   30,   33,   12,   31,   34,   21,   32,   35,
  // 6,   33,   60,   15,   42,   61,   24,   51,   62,
  // 1,   12,   15,   10,   13,   16,   11,   14,   17,
  // 4,   31,   42,   13,   40,   43,   22,   41,   44,
  // 7,   34,   61,   16,   43,   70,   25,   52,   71,
  // 2,   21,   24,   11,   22,   25,   20,   23,   26,
  // 5,   32,   51,   14,   41,   52,   23,   50,   53,
  // 8,   35,   62,   17,   44,   71,   26,   53,   80,
  const int wiener_win2 = wiener_win * wiener_win;

  // Loop through the indices according to the remapping above, along the
  // columns:
  // 0, wiener_win, 2 * wiener_win, ..., 1, 1 + 2 * wiener_win, ...,
  // wiener_win - 1, wiener_win - 1 + wiener_win, ...
  // For the 3x3 case `j` will be: 0, 3, 6, 1, 4, 7, 2, 5, 8.
  for (int i = 0; i < wiener_win; ++i) {
    for (int j = i; j < wiener_win2; j += wiener_win) {
      // These two inner loops are the same as the two outer loops, but running
      // along rows instead of columns. For the 3x3 case `l` will be:
      // 0, 3, 6, 1, 4, 7, 2, 5, 8.
      for (int k = 0; k < wiener_win; ++k) {
        for (int l = k; l < wiener_win2; l += wiener_win) {
          // The nominal double transpose indexing would be:
          // int idx = stride * j + l;
          // However we need the upper-right triangle, which is easily obtained
          // with some min/max operations.
          int tr_idx = stride * AOMMIN(j, l) + AOMMAX(j, l);

          // Resulting matrix is filled by combining the 64-bit and the residual
          // 32-bit matrices together with scaling.
          *dst++ = div_shift_s64(src_s64[tr_idx] + src_s32[tr_idx], shift);
        }
      }
    }
  }
}

// Load 7x7 matrix into 7 128-bit vectors from consecutive rows; the last load
// address is offset to prevent out-of-bounds access.
static INLINE void load_and_pack_s16_8x7(int16x8_t dst[7], const int16_t *src,
                                         ptrdiff_t stride) {
  dst[0] = vld1q_s16(src);
  src += stride;
  dst[1] = vld1q_s16(src);
  src += stride;
  dst[2] = vld1q_s16(src);
  src += stride;
  dst[3] = vld1q_s16(src);
  src += stride;
  dst[4] = vld1q_s16(src);
  src += stride;
  dst[5] = vld1q_s16(src);
  src += stride;
  dst[6] = vld1q_s16(src - 1);
}

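// Compute the Wiener filter statistics for a 7x7 window: M accumulates the
// cross-correlation between the source and the windowed `dgd - avg` samples,
// and H accumulates the auto-covariance of those window samples.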
static INLINE void highbd_compute_stats_win7_neon(
    const uint16_t *dgd, const uint16_t *src, int avg, int width, int height,
    int dgd_stride, int src_stride, int64_t *M, int64_t *H,
    aom_bit_depth_t bit_depth) {
  // Matrix names are capitalized to help readability.
  DECLARE_ALIGNED(64, int16_t, DGD_AVG0[WIENER_WIN2_ALIGN3]);
  DECLARE_ALIGNED(64, int16_t, DGD_AVG1[WIENER_WIN2_ALIGN3]);
  DECLARE_ALIGNED(64, int32_t, M_s32[WIENER_WIN2_ALIGN3]);
  DECLARE_ALIGNED(64, int64_t, M_s64[WIENER_WIN2_ALIGN3]);
  DECLARE_ALIGNED(64, int32_t, H_s32[WIENER_WIN2 * WIENER_WIN2_ALIGN2]);
  DECLARE_ALIGNED(64, int64_t, H_s64[WIENER_WIN2 * WIENER_WIN2_ALIGN2]);

  memset(M_s32, 0, sizeof(M_s32));
  memset(M_s64, 0, sizeof(M_s64));
  memset(H_s32, 0, sizeof(H_s32));
  memset(H_s64, 0, sizeof(H_s64));

  // Look-up tables to create an 8x6 matrix with consecutive elements from two
  // 7x7 matrices.
  // clang-format off
  DECLARE_ALIGNED(16, static const uint8_t, shuffle_stats7_highbd[192]) = {
    0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 16, 17,
    2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 16, 17, 18, 19,
    4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 16, 17, 18, 19, 20, 21,
    6,  7,  8,  9,  10, 11, 12, 13, 16, 17, 18, 19, 20, 21, 22, 23,
    8,  9,  10, 11, 12, 13, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
    10, 11, 12, 13, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
    2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 18, 19,
    4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 18, 19, 20, 21,
    6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22, 23,
    8,  9,  10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22, 23, 24, 25,
    10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
    12, 13, 14, 15, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
  };
  // clang-format on

  const uint8x16_t lut0 = vld1q_u8(shuffle_stats7_highbd + 0);
  const uint8x16_t lut1 = vld1q_u8(shuffle_stats7_highbd + 16);
  const uint8x16_t lut2 = vld1q_u8(shuffle_stats7_highbd + 32);
  const uint8x16_t lut3 = vld1q_u8(shuffle_stats7_highbd + 48);
  const uint8x16_t lut4 = vld1q_u8(shuffle_stats7_highbd + 64);
  const uint8x16_t lut5 = vld1q_u8(shuffle_stats7_highbd + 80);
  const uint8x16_t lut6 = vld1q_u8(shuffle_stats7_highbd + 96);
  const uint8x16_t lut7 = vld1q_u8(shuffle_stats7_highbd + 112);
  const uint8x16_t lut8 = vld1q_u8(shuffle_stats7_highbd + 128);
  const uint8x16_t lut9 = vld1q_u8(shuffle_stats7_highbd + 144);
  const uint8x16_t lut10 = vld1q_u8(shuffle_stats7_highbd + 160);
  const uint8x16_t lut11 = vld1q_u8(shuffle_stats7_highbd + 176);

  // We can accumulate up to 65536/4096/256 8/10/12-bit multiplication results
  // in 32-bit. We are processing 2 pixels at a time, so the accumulator max can
  // be as high as 32768/2048/128 for the compute stats.
  const int acc_cnt_max = (1 << (32 - 2 * bit_depth)) >> 1;
  int acc_cnt = acc_cnt_max;
  const int src_next = src_stride - width;
  const int dgd_next = dgd_stride - width;
  const int16x8_t avg_s16 = vdupq_n_s16(avg);

  do {
    int j = width;
    while (j >= 2) {
      // Load two adjacent, overlapping 7x7 matrices: an 8x7 matrix with the
      // middle 6x7 elements being shared.
      int16x8_t dgd_rows[7];
      load_and_pack_s16_8x7(dgd_rows, (const int16_t *)dgd, dgd_stride);

      const int16_t *dgd_ptr = (const int16_t *)dgd + dgd_stride * 6;
      dgd += 2;

      dgd_rows[0] = vsubq_s16(dgd_rows[0], avg_s16);
      dgd_rows[1] = vsubq_s16(dgd_rows[1], avg_s16);
      dgd_rows[2] = vsubq_s16(dgd_rows[2], avg_s16);
      dgd_rows[3] = vsubq_s16(dgd_rows[3], avg_s16);
      dgd_rows[4] = vsubq_s16(dgd_rows[4], avg_s16);
      dgd_rows[5] = vsubq_s16(dgd_rows[5], avg_s16);
      dgd_rows[6] = vsubq_s16(dgd_rows[6], avg_s16);

      // Re-arrange the combined 8x7 matrix to have the 2 whole 7x7 matrices (1
      // for each of the 2 pixels) separated into distinct int16x8_t[6] arrays.
      // These arrays contain 48 elements of the 49 (7x7). Compute `dgd - avg`
      // for both buffers. Each DGD_AVG buffer contains 49 consecutive elements.
      int16x8_t dgd_avg0[6];
      int16x8_t dgd_avg1[6];

      dgd_avg0[0] = tbl2q(dgd_rows[0], dgd_rows[1], lut0);
      dgd_avg1[0] = tbl2q(dgd_rows[0], dgd_rows[1], lut6);
      dgd_avg0[1] = tbl2q(dgd_rows[1], dgd_rows[2], lut1);
      dgd_avg1[1] = tbl2q(dgd_rows[1], dgd_rows[2], lut7);
      dgd_avg0[2] = tbl2q(dgd_rows[2], dgd_rows[3], lut2);
      dgd_avg1[2] = tbl2q(dgd_rows[2], dgd_rows[3], lut8);
      dgd_avg0[3] = tbl2q(dgd_rows[3], dgd_rows[4], lut3);
      dgd_avg1[3] = tbl2q(dgd_rows[3], dgd_rows[4], lut9);
      dgd_avg0[4] = tbl2q(dgd_rows[4], dgd_rows[5], lut4);
      dgd_avg1[4] = tbl2q(dgd_rows[4], dgd_rows[5], lut10);
      dgd_avg0[5] = tbl2q(dgd_rows[5], dgd_rows[6], lut5);
      dgd_avg1[5] = tbl2q(dgd_rows[5], dgd_rows[6], lut11);

      vst1q_s16(DGD_AVG0, dgd_avg0[0]);
      vst1q_s16(DGD_AVG1, dgd_avg1[0]);
      vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
      vst1q_s16(DGD_AVG1 + 8, dgd_avg1[1]);
      vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
      vst1q_s16(DGD_AVG1 + 16, dgd_avg1[2]);
      vst1q_s16(DGD_AVG0 + 24, dgd_avg0[3]);
      vst1q_s16(DGD_AVG1 + 24, dgd_avg1[3]);
      vst1q_s16(DGD_AVG0 + 32, dgd_avg0[4]);
      vst1q_s16(DGD_AVG1 + 32, dgd_avg1[4]);
      vst1q_s16(DGD_AVG0 + 40, dgd_avg0[5]);
      vst1q_s16(DGD_AVG1 + 40, dgd_avg1[5]);

      // The remaining last (49th) elements of `dgd - avg`.
      DGD_AVG0[48] = dgd_ptr[6] - avg;
      DGD_AVG1[48] = dgd_ptr[7] - avg;

      // Accumulate into row-major variant of matrix M (cross-correlation) for 2
      // output pixels at a time. M is of size 7 * 7. It needs to be filled such
      // that multiplying one element from src with each element of a row of the
      // wiener window will fill one column of M. However this is not very
      // convenient in terms of memory access, as it means we do contiguous
      // loads of dgd but strided stores to M. As a result, we use an
      // intermediate matrix M_s32 which is instead filled such that one row of
      // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
      // then transposed to return M.
      int src_avg0 = *src++ - avg;
      int src_avg1 = *src++ - avg;
      int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
      int16x4_t src_avg1_s16 = vdup_n_s16(src_avg1);
      update_M_2pixels(M_s32 + 0, src_avg0_s16, src_avg1_s16, dgd_avg0[0],
                       dgd_avg1[0]);
      update_M_2pixels(M_s32 + 8, src_avg0_s16, src_avg1_s16, dgd_avg0[1],
                       dgd_avg1[1]);
      update_M_2pixels(M_s32 + 16, src_avg0_s16, src_avg1_s16, dgd_avg0[2],
                       dgd_avg1[2]);
      update_M_2pixels(M_s32 + 24, src_avg0_s16, src_avg1_s16, dgd_avg0[3],
                       dgd_avg1[3]);
      update_M_2pixels(M_s32 + 32, src_avg0_s16, src_avg1_s16, dgd_avg0[4],
                       dgd_avg1[4]);
      update_M_2pixels(M_s32 + 40, src_avg0_s16, src_avg1_s16, dgd_avg0[5],
                       dgd_avg1[5]);

      // Last (49th) element of M_s32 can be computed as a scalar more
      // efficiently for 2 output pixels.
      M_s32[48] += DGD_AVG0[48] * src_avg0 + DGD_AVG1[48] * src_avg1;

      // Start accumulating into row-major version of matrix H
      // (auto-covariance); it expects the DGD_AVG[01] matrices to also be
      // row-major. H is of size 49 * 49. It is filled by multiplying every pair
      // of elements of the wiener window together (vector outer product). Since
      // it is a symmetric matrix, we only compute the upper-right triangle, and
      // then copy it down to the lower-left later. The upper triangle is
      // covered by 4x4 tiles. The original algorithm assumes the M matrix is
      // column-major and the resulting H matrix is also expected to be
      // column-major. It is not efficient to work with column-major matrices,
      // so we accumulate into a row-major matrix H_s32. At the end of the
      // algorithm a double transpose transformation will convert H_s32 back to
      // the expected output layout.
      update_H_7x7_2pixels(H_s32, DGD_AVG0, DGD_AVG1);

      // The last element of the triangle of H_s32 matrix can be computed as a
      // scalar more efficiently.
      H_s32[48 * WIENER_WIN2_ALIGN2 + 48] +=
          DGD_AVG0[48] * DGD_AVG0[48] + DGD_AVG1[48] * DGD_AVG1[48];

      // Accumulate into 64-bit after a bit depth dependent number of iterations
      // to prevent overflow.
      if (--acc_cnt == 0) {
        acc_cnt = acc_cnt_max;

        accumulate_and_clear(M_s64, M_s32, WIENER_WIN2_ALIGN2);

        // The widening accumulation is only needed for the upper triangle part
        // of the matrix.
        int64_t *lh = H_s64;
        int32_t *lh32 = H_s32;
        for (int k = 0; k < WIENER_WIN2; ++k) {
          // The widening accumulation is only run for the relevant parts
          // (upper-right triangle) of each row, starting from a 4-element
          // aligned position.
          int k4 = k / 4 * 4;
          accumulate_and_clear(lh + k4, lh32 + k4, 48 - k4);

          // Last element of the row is computed separately.
          lh[48] += lh32[48];
          lh32[48] = 0;

          lh += WIENER_WIN2_ALIGN2;
          lh32 += WIENER_WIN2_ALIGN2;
        }
      }

      j -= 2;
    }

    // Computations for odd pixel in the row.
    if (width & 1) {
      // Load two adjacent, overlapping 7x7 matrices: an 8x7 matrix with the
      // middle 6x7 elements being shared.
      int16x8_t dgd_rows[7];
      load_and_pack_s16_8x7(dgd_rows, (const int16_t *)dgd, dgd_stride);

      const int16_t *dgd_ptr = (const int16_t *)dgd + dgd_stride * 6;
      ++dgd;

      // Re-arrange the combined 8x7 matrix to have a whole 7x7 matrix tightly
      // packed into an int16x8_t[6] array. This array contains 48 elements of
      // the 49 (7x7). Compute `dgd - avg` for the whole buffer. The DGD_AVG
      // buffer contains 49 consecutive elements.
      int16x8_t dgd_avg0[6];

      dgd_avg0[0] = vsubq_s16(tbl2q(dgd_rows[0], dgd_rows[1], lut0), avg_s16);
      dgd_avg0[1] = vsubq_s16(tbl2q(dgd_rows[1], dgd_rows[2], lut1), avg_s16);
      dgd_avg0[2] = vsubq_s16(tbl2q(dgd_rows[2], dgd_rows[3], lut2), avg_s16);
      dgd_avg0[3] = vsubq_s16(tbl2q(dgd_rows[3], dgd_rows[4], lut3), avg_s16);
      dgd_avg0[4] = vsubq_s16(tbl2q(dgd_rows[4], dgd_rows[5], lut4), avg_s16);
      dgd_avg0[5] = vsubq_s16(tbl2q(dgd_rows[5], dgd_rows[6], lut5), avg_s16);

      vst1q_s16(DGD_AVG0, dgd_avg0[0]);
      vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
      vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
      vst1q_s16(DGD_AVG0 + 24, dgd_avg0[3]);
      vst1q_s16(DGD_AVG0 + 32, dgd_avg0[4]);
      vst1q_s16(DGD_AVG0 + 40, dgd_avg0[5]);

      // The remaining last (49th) element of `dgd - avg`.
      DGD_AVG0[48] = dgd_ptr[6] - avg;

      // Accumulate into row-major order variant of matrix M (cross-correlation)
      // for 1 output pixel at a time. M is of size 7 * 7. It needs to be filled
      // such that multiplying one element from src with each element of a row
      // of the wiener window will fill one column of M. However this is not
      // very convenient in terms of memory access, as it means we do
      // contiguous loads of dgd but strided stores to M. As a result, we use an
      // intermediate matrix M_s32 which is instead filled such that one row of
      // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
      // then transposed to return M.
      int src_avg0 = *src++ - avg;
      int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
      update_M_1pixel(M_s32 + 0, src_avg0_s16, dgd_avg0[0]);
      update_M_1pixel(M_s32 + 8, src_avg0_s16, dgd_avg0[1]);
      update_M_1pixel(M_s32 + 16, src_avg0_s16, dgd_avg0[2]);
      update_M_1pixel(M_s32 + 24, src_avg0_s16, dgd_avg0[3]);
      update_M_1pixel(M_s32 + 32, src_avg0_s16, dgd_avg0[4]);
      update_M_1pixel(M_s32 + 40, src_avg0_s16, dgd_avg0[5]);

      // Last (49th) element of M_s32 can be computed as a scalar more
      // efficiently for 1 output pixel.
      M_s32[48] += DGD_AVG0[48] * src_avg0;

      // Start accumulating into row-major order version of matrix H
      // (auto-covariance); it expects the DGD_AVG0 matrix to also be row-major.
      // H is of size 49 * 49. It is filled by multiplying every pair of
      // elements of the wiener window together (vector outer product). Since it
      // is a symmetric matrix, we only compute the upper-right triangle, and
      // then copy it down to the lower-left later. The upper triangle is
      // covered by 4x4 tiles. The original algorithm assumes the M matrix is
      // column-major and the resulting H matrix is also expected to be
      // column-major. It is not efficient to work with column-major matrices,
      // so we accumulate into a row-major matrix H_s32. At the end of the
      // algorithm a double transpose transformation will convert H_s32 back to
      // the expected output layout.
      update_H_1pixel(H_s32, DGD_AVG0, WIENER_WIN2_ALIGN2, 48);

      // The last element of the triangle of H_s32 matrix can be computed as a
      // scalar more efficiently.
      H_s32[48 * WIENER_WIN2_ALIGN2 + 48] += DGD_AVG0[48] * DGD_AVG0[48];
    }

    src += src_next;
    dgd += dgd_next;
  } while (--height != 0);

  int bit_depth_shift = bit_depth - AOM_BITS_8;

  acc_transpose_M(M, M_s64, M_s32, WIENER_WIN, bit_depth_shift);

  update_H(H, H_s64, H_s32, WIENER_WIN, WIENER_WIN2_ALIGN2, bit_depth_shift);
}

// Load 5x5 matrix into 5 128-bit vectors from consecutive rows; the last load
// address is offset to prevent out-of-bounds access.
static INLINE void load_and_pack_s16_6x5(int16x8_t dst[5], const int16_t *src,
                                         ptrdiff_t stride) {
  dst[0] = vld1q_s16(src);
  src += stride;
  dst[1] = vld1q_s16(src);
  src += stride;
  dst[2] = vld1q_s16(src);
  src += stride;
  dst[3] = vld1q_s16(src);
  src += stride;
  dst[4] = vld1q_s16(src - 3);
}

static void highbd_compute_stats_win5_neon(const uint16_t *dgd,
                                           const uint16_t *src, int avg,
                                           int width, int height,
                                           int dgd_stride, int src_stride,
                                           int64_t *M, int64_t *H,
                                           aom_bit_depth_t bit_depth) {
  // Matrix names are capitalized to help readability.
  DECLARE_ALIGNED(64, int16_t, DGD_AVG0[WIENER_WIN2_REDUCED_ALIGN3]);
  DECLARE_ALIGNED(64, int16_t, DGD_AVG1[WIENER_WIN2_REDUCED_ALIGN3]);
  DECLARE_ALIGNED(64, int32_t, M_s32[WIENER_WIN2_REDUCED_ALIGN3]);
  DECLARE_ALIGNED(64, int64_t, M_s64[WIENER_WIN2_REDUCED_ALIGN3]);
  DECLARE_ALIGNED(64, int32_t,
                  H_s32[WIENER_WIN2_REDUCED * WIENER_WIN2_REDUCED_ALIGN2]);
  DECLARE_ALIGNED(64, int64_t,
                  H_s64[WIENER_WIN2_REDUCED * WIENER_WIN2_REDUCED_ALIGN2]);

  memset(M_s32, 0, sizeof(M_s32));
  memset(M_s64, 0, sizeof(M_s64));
  memset(H_s32, 0, sizeof(H_s32));
  memset(H_s64, 0, sizeof(H_s64));

  // Look-up tables to create an 8x3 matrix with consecutive elements from a
  // 5x5 matrix.
  DECLARE_ALIGNED(16, static const uint8_t, shuffle_stats5_highbd[96]) = {
    0, 1, 2,  3,  4,  5,  6,  7,  8,  9,  16, 17, 18, 19, 20, 21,
    6, 7, 8,  9,  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 32, 33,
    2, 3, 4,  5,  6,  7,  8,  9,  22, 23, 24, 25, 26, 27, 28, 29,
    2, 3, 4,  5,  6,  7,  8,  9,  10, 11, 18, 19, 20, 21, 22, 23,
    8, 9, 10, 11, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 34, 35,
    4, 5, 6,  7,  8,  9,  10, 11, 24, 25, 26, 27, 28, 29, 30, 31,
  };

  const uint8x16_t lut0 = vld1q_u8(shuffle_stats5_highbd + 0);
  const uint8x16_t lut1 = vld1q_u8(shuffle_stats5_highbd + 16);
  const uint8x16_t lut2 = vld1q_u8(shuffle_stats5_highbd + 32);
  const uint8x16_t lut3 = vld1q_u8(shuffle_stats5_highbd + 48);
  const uint8x16_t lut4 = vld1q_u8(shuffle_stats5_highbd + 64);
  const uint8x16_t lut5 = vld1q_u8(shuffle_stats5_highbd + 80);

  // We can accumulate up to 65536/4096/256 8/10/12-bit multiplication results
  // in 32-bit. We are processing 2 pixels at a time, so the accumulator max can
  // be as high as 32768/2048/128 for the compute stats.
  const int acc_cnt_max = (1 << (32 - 2 * bit_depth)) >> 1;
  int acc_cnt = acc_cnt_max;
  const int src_next = src_stride - width;
  const int dgd_next = dgd_stride - width;
  const int16x8_t avg_s16 = vdupq_n_s16(avg);

  do {
    int j = width;
    while (j >= 2) {
      // Load two adjacent, overlapping 5x5 matrices: a 6x5 matrix with the
      // middle 4x5 elements being shared.
      int16x8_t dgd_rows[5];
      load_and_pack_s16_6x5(dgd_rows, (const int16_t *)dgd, dgd_stride);

      const int16_t *dgd_ptr = (const int16_t *)dgd + dgd_stride * 4;
      dgd += 2;

      dgd_rows[0] = vsubq_s16(dgd_rows[0], avg_s16);
      dgd_rows[1] = vsubq_s16(dgd_rows[1], avg_s16);
      dgd_rows[2] = vsubq_s16(dgd_rows[2], avg_s16);
      dgd_rows[3] = vsubq_s16(dgd_rows[3], avg_s16);
      dgd_rows[4] = vsubq_s16(dgd_rows[4], avg_s16);

      // Re-arrange the combined 6x5 matrix to have the 2 whole 5x5 matrices (1
      // for each of the 2 pixels) separated into distinct int16x8_t[3] arrays.
      // These arrays contain 24 elements of the 25 (5x5). Compute `dgd - avg`
      // for both buffers. Each DGD_AVG buffer contains 25 consecutive elements.
      int16x8_t dgd_avg0[3];
      int16x8_t dgd_avg1[3];

      dgd_avg0[0] = tbl2q(dgd_rows[0], dgd_rows[1], lut0);
      dgd_avg1[0] = tbl2q(dgd_rows[0], dgd_rows[1], lut3);
      dgd_avg0[1] = tbl3q(dgd_rows[1], dgd_rows[2], dgd_rows[3], lut1);
      dgd_avg1[1] = tbl3q(dgd_rows[1], dgd_rows[2], dgd_rows[3], lut4);
      dgd_avg0[2] = tbl2q(dgd_rows[3], dgd_rows[4], lut2);
      dgd_avg1[2] = tbl2q(dgd_rows[3], dgd_rows[4], lut5);

      vst1q_s16(DGD_AVG0, dgd_avg0[0]);
      vst1q_s16(DGD_AVG1, dgd_avg1[0]);
      vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
      vst1q_s16(DGD_AVG1 + 8, dgd_avg1[1]);
      vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
      vst1q_s16(DGD_AVG1 + 16, dgd_avg1[2]);

      // The remaining last (25th) elements of `dgd - avg`.
      DGD_AVG0[24] = dgd_ptr[4] - avg;
      DGD_AVG1[24] = dgd_ptr[5] - avg;

      // Accumulate into row-major variant of matrix M (cross-correlation) for 2
      // output pixels at a time. M is of size 5 * 5. It needs to be filled such
      // that multiplying one element from src with each element of a row of the
      // wiener window will fill one column of M. However this is not very
      // convenient in terms of memory access, as it means we do contiguous
      // loads of dgd but strided stores to M. As a result, we use an
      // intermediate matrix M_s32 which is instead filled such that one row of
      // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
      // then transposed to return M.
      int src_avg0 = *src++ - avg;
      int src_avg1 = *src++ - avg;
      int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
      int16x4_t src_avg1_s16 = vdup_n_s16(src_avg1);
      update_M_2pixels(M_s32 + 0, src_avg0_s16, src_avg1_s16, dgd_avg0[0],
                       dgd_avg1[0]);
      update_M_2pixels(M_s32 + 8, src_avg0_s16, src_avg1_s16, dgd_avg0[1],
                       dgd_avg1[1]);
      update_M_2pixels(M_s32 + 16, src_avg0_s16, src_avg1_s16, dgd_avg0[2],
                       dgd_avg1[2]);

      // Last (25th) element of M_s32 can be computed as a scalar more
      // efficiently for 2 output pixels.
      M_s32[24] += DGD_AVG0[24] * src_avg0 + DGD_AVG1[24] * src_avg1;

      // Start accumulating into row-major version of matrix H
      // (auto-covariance); it expects the DGD_AVG[01] matrices to also be
      // row-major. H is of size 25 * 25. It is filled by multiplying every pair
      // of elements of the wiener window together (vector outer product). Since
      // it is a symmetric matrix, we only compute the upper-right triangle, and
      // then copy it down to the lower-left later. The upper triangle is
      // covered by 4x4 tiles. The original algorithm assumes the M matrix is
      // column-major and the resulting H matrix is also expected to be
      // column-major. It is not efficient to work with column-major matrices,
      // so we accumulate into a row-major matrix H_s32. At the end of the
      // algorithm a double transpose transformation will convert H_s32 back to
      // the expected output layout.
      update_H_5x5_2pixels(H_s32, DGD_AVG0, DGD_AVG1);

      // The last element of the triangle of H_s32 matrix can be computed as a
      // scalar more efficiently.
      H_s32[24 * WIENER_WIN2_REDUCED_ALIGN2 + 24] +=
          DGD_AVG0[24] * DGD_AVG0[24] + DGD_AVG1[24] * DGD_AVG1[24];

      // Accumulate into 64-bit after a bit depth dependent number of iterations
      // to prevent overflow.
      if (--acc_cnt == 0) {
        acc_cnt = acc_cnt_max;

        accumulate_and_clear(M_s64, M_s32, WIENER_WIN2_REDUCED_ALIGN2);

        // The widening accumulation is only needed for the upper triangle part
        // of the matrix.
        int64_t *lh = H_s64;
        int32_t *lh32 = H_s32;
        for (int k = 0; k < WIENER_WIN2_REDUCED; ++k) {
          // The widening accumulation is only run for the relevant parts
          // (upper-right triangle) of each row, starting from a 4-element
          // aligned position.
          int k4 = k / 4 * 4;
          accumulate_and_clear(lh + k4, lh32 + k4, 24 - k4);

          // Last element of the row is computed separately.
          lh[24] += lh32[24];
          lh32[24] = 0;

          lh += WIENER_WIN2_REDUCED_ALIGN2;
          lh32 += WIENER_WIN2_REDUCED_ALIGN2;
        }
      }

      j -= 2;
    }

    // Computations for odd pixel in the row.
    if (width & 1) {
      // Load two adjacent, overlapping 5x5 matrices: a 6x5 matrix with the
      // middle 4x5 elements being shared.
      int16x8_t dgd_rows[5];
      load_and_pack_s16_6x5(dgd_rows, (const int16_t *)dgd, dgd_stride);

      const int16_t *dgd_ptr = (const int16_t *)dgd + dgd_stride * 4;
      ++dgd;

      // Re-arrange (and widen) the combined 6x5 matrix to have a whole 5x5
      // matrix tightly packed into an int16x8_t[3] array. This array contains
      // 24 elements of the 25 (5x5). Compute `dgd - avg` for the whole buffer.
      // The DGD_AVG buffer contains 25 consecutive elements.
      int16x8_t dgd_avg0[3];

      dgd_avg0[0] = vsubq_s16(tbl2q(dgd_rows[0], dgd_rows[1], lut0), avg_s16);
      dgd_avg0[1] = vsubq_s16(
          tbl3q(dgd_rows[1], dgd_rows[2], dgd_rows[3], lut1), avg_s16);
      dgd_avg0[2] = vsubq_s16(tbl2q(dgd_rows[3], dgd_rows[4], lut2), avg_s16);

      vst1q_s16(DGD_AVG0, dgd_avg0[0]);
      vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
      vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);

      // The remaining last (25th) element of `dgd - avg`.
      DGD_AVG0[24] = dgd_ptr[4] - avg;
      DGD_AVG1[24] = dgd_ptr[5] - avg;

      // Accumulate into row-major order variant of matrix M (cross-correlation)
      // for 1 output pixel at a time. M is of size 5 * 5. It needs to be filled
      // such that multiplying one element from src with each element of a row
      // of the wiener window will fill one column of M. However this is not
      // very convenient in terms of memory access, as it means we do
      // contiguous loads of dgd but strided stores to M. As a result, we use an
      // intermediate matrix M_s32 which is instead filled such that one row of
      // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
      // then transposed to return M.
      int src_avg0 = *src++ - avg;
      int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
      update_M_1pixel(M_s32 + 0, src_avg0_s16, dgd_avg0[0]);
      update_M_1pixel(M_s32 + 8, src_avg0_s16, dgd_avg0[1]);
      update_M_1pixel(M_s32 + 16, src_avg0_s16, dgd_avg0[2]);

      // Last (25th) element of M_s32 can be computed as a scalar more
      // efficiently for 1 output pixel.
      M_s32[24] += DGD_AVG0[24] * src_avg0;

      // Start accumulating into row-major order version of matrix H
      // (auto-covariance); it expects the DGD_AVG0 matrix to also be row-major.
      // H is of size 25 * 25. It is filled by multiplying every pair of
      // elements of the wiener window together (vector outer product). Since it
      // is a symmetric matrix, we only compute the upper-right triangle, and
      // then copy it down to the lower-left later. The upper triangle is
      // covered by 4x4 tiles. The original algorithm assumes the M matrix is
      // column-major and the resulting H matrix is also expected to be
      // column-major. It is not efficient to work with column-major matrices,
      // so we accumulate into a row-major matrix H_s32. At the end of the
      // algorithm a double transpose transformation will convert H_s32 back to
      // the expected output layout.
      update_H_1pixel(H_s32, DGD_AVG0, WIENER_WIN2_REDUCED_ALIGN2, 24);

      // The last element of the triangle of H_s32 matrix can be computed as a
      // scalar more efficiently.
      H_s32[24 * WIENER_WIN2_REDUCED_ALIGN2 + 24] +=
          DGD_AVG0[24] * DGD_AVG0[24];
    }

    src += src_next;
    dgd += dgd_next;
  } while (--height != 0);

  int bit_depth_shift = bit_depth - AOM_BITS_8;

  acc_transpose_M(M, M_s64, M_s32, WIENER_WIN_REDUCED, bit_depth_shift);

  update_H(H, H_s64, H_s32, WIENER_WIN_REDUCED, WIENER_WIN2_REDUCED_ALIGN2,
           bit_depth_shift);
}

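// Compute the average of a width x height block of 16-bit samples, using
// widening pairwise additions to avoid overflowing the 32-bit row
// accumulators.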
static uint16_t highbd_find_average_neon(const uint16_t *src, int src_stride,
                                         int width, int height) {
  assert(width > 0);
  assert(height > 0);

  uint64x2_t sum_u64 = vdupq_n_u64(0);
  uint64_t sum = 0;

  int h = height;
  do {
    uint32x4_t sum_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

    int w = width;
    const uint16_t *row = src;
    while (w >= 32) {
      uint16x8_t s0 = vld1q_u16(row + 0);
      uint16x8_t s1 = vld1q_u16(row + 8);
      uint16x8_t s2 = vld1q_u16(row + 16);
      uint16x8_t s3 = vld1q_u16(row + 24);

      s0 = vaddq_u16(s0, s1);
      s2 = vaddq_u16(s2, s3);
      sum_u32[0] = vpadalq_u16(sum_u32[0], s0);
      sum_u32[1] = vpadalq_u16(sum_u32[1], s2);

      row += 32;
      w -= 32;
    }

    if (w >= 16) {
      uint16x8_t s0 = vld1q_u16(row + 0);
      uint16x8_t s1 = vld1q_u16(row + 8);

      s0 = vaddq_u16(s0, s1);
      sum_u32[0] = vpadalq_u16(sum_u32[0], s0);

      row += 16;
      w -= 16;
    }

    if (w >= 8) {
      uint16x8_t s0 = vld1q_u16(row);
      sum_u32[1] = vpadalq_u16(sum_u32[1], s0);

      row += 8;
      w -= 8;
    }

    if (w >= 4) {
      uint16x8_t s0 = vcombine_u16(vld1_u16(row), vdup_n_u16(0));
      sum_u32[0] = vpadalq_u16(sum_u32[0], s0);

      row += 4;
      w -= 4;
    }

    while (w-- > 0) {
      sum += *row++;
    }

    sum_u64 = vpadalq_u32(sum_u64, vaddq_u32(sum_u32[0], sum_u32[1]));

    src += src_stride;
  } while (--h != 0);

  return (uint16_t)((horizontal_add_u64x2(sum_u64) + sum) / (height * width));
}

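// Entry point for the high bit depth Wiener filter statistics: selects the
// 7x7 or reduced 5x5 window implementation based on wiener_win.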
void av1_compute_stats_highbd_neon(int wiener_win, const uint8_t *dgd8,
                                   const uint8_t *src8, int h_start, int h_end,
                                   int v_start, int v_end, int dgd_stride,
                                   int src_stride, int64_t *M, int64_t *H,
                                   aom_bit_depth_t bit_depth) {
  assert(wiener_win == WIENER_WIN || wiener_win == WIENER_WIN_REDUCED);

  const int wiener_halfwin = wiener_win >> 1;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
  const int height = v_end - v_start;
  const int width = h_end - h_start;

  const uint16_t *dgd_start = dgd + h_start + v_start * dgd_stride;
  const uint16_t *src_start = src + h_start + v_start * src_stride;

  // The wiener window will slide along the dgd frame, centered on each pixel.
  // For the top left pixel and all the pixels on the side of the frame this
  // means half of the window will be outside of the frame. As such the actual
  // buffer that we need to subtract the avg from will be 2 * wiener_halfwin
  // wider and 2 * wiener_halfwin higher than the original dgd buffer.
  const int vert_offset = v_start - wiener_halfwin;
  const int horiz_offset = h_start - wiener_halfwin;
  const uint16_t *dgd_win = dgd + horiz_offset + vert_offset * dgd_stride;

  uint16_t avg = highbd_find_average_neon(dgd_start, dgd_stride, width, height);

  if (wiener_win == WIENER_WIN) {
    highbd_compute_stats_win7_neon(dgd_win, src_start, avg, width, height,
                                   dgd_stride, src_stride, M, H, bit_depth);
  } else {
    highbd_compute_stats_win5_neon(dgd_win, src_start, avg, width, height,
                                   dgd_stride, src_stride, M, H, bit_depth);
  }
}

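// Compute the sum of squared error between the source and the self-guided
// restoration output produced by applying the projection coefficients xq to
// the filtered planes flt0/flt1.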
int64_t av1_highbd_pixel_proj_error_neon(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
  int64_t sse = 0;
  int64x2_t sse_s64 = vdupq_n_s64(0);

  if (params->r[0] > 0 && params->r[1] > 0) {
    int32x2_t xq_v = vld1_s32(xq);
    int32x2_t xq_sum_v = vshl_n_s32(vpadd_s32(xq_v, xq_v), 4);

    do {
      int j = 0;
      int32x4_t sse_s32 = vdupq_n_s32(0);

      do {
        const uint16x8_t d = vld1q_u16(&dat[j]);
        const uint16x8_t s = vld1q_u16(&src[j]);
        int32x4_t flt0_0 = vld1q_s32(&flt0[j]);
        int32x4_t flt0_1 = vld1q_s32(&flt0[j + 4]);
        int32x4_t flt1_0 = vld1q_s32(&flt1[j]);
        int32x4_t flt1_1 = vld1q_s32(&flt1[j + 4]);

        int32x4_t d_s32_lo = vreinterpretq_s32_u32(
            vmull_lane_u16(vget_low_u16(d), vreinterpret_u16_s32(xq_sum_v), 0));
        int32x4_t d_s32_hi = vreinterpretq_s32_u32(vmull_lane_u16(
            vget_high_u16(d), vreinterpret_u16_s32(xq_sum_v), 0));

        int32x4_t v0 = vsubq_s32(
            vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)),
            d_s32_lo);
        int32x4_t v1 = vsubq_s32(
            vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)),
            d_s32_hi);

        v0 = vmlaq_lane_s32(v0, flt0_0, xq_v, 0);
        v1 = vmlaq_lane_s32(v1, flt0_1, xq_v, 0);
        v0 = vmlaq_lane_s32(v0, flt1_0, xq_v, 1);
        v1 = vmlaq_lane_s32(v1, flt1_1, xq_v, 1);

        int16x4_t vr0 = vshrn_n_s32(v0, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);
        int16x4_t vr1 = vshrn_n_s32(v1, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);

        int16x8_t e = vaddq_s16(vcombine_s16(vr0, vr1),
                                vreinterpretq_s16_u16(vsubq_u16(d, s)));
        int16x4_t e_lo = vget_low_s16(e);
        int16x4_t e_hi = vget_high_s16(e);

        sse_s32 = vmlal_s16(sse_s32, e_lo, e_lo);
        sse_s32 = vmlal_s16(sse_s32, e_hi, e_hi);

        j += 8;
      } while (j <= width - 8);

      for (int k = j; k < width; ++k) {
        int32_t v = 1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1);
        v += xq[0] * (flt0[k]) + xq[1] * (flt1[k]);
        v -= (xq[1] + xq[0]) * (int32_t)(dat[k] << 4);
        int32_t e =
            (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + dat[k] - src[k];
        sse += ((int64_t)e * e);
      }

      sse_s64 = vpadalq_s32(sse_s64, sse_s32);

      dat += dat_stride;
      src += src_stride;
      flt0 += flt0_stride;
      flt1 += flt1_stride;
    } while (--height != 0);
  } else if (params->r[0] > 0 || params->r[1] > 0) {
    int xq_active = (params->r[0] > 0) ? xq[0] : xq[1];
    int32_t *flt = (params->r[0] > 0) ? flt0 : flt1;
    int flt_stride = (params->r[0] > 0) ? flt0_stride : flt1_stride;
    int32x4_t xq_v = vdupq_n_s32(xq_active);

    do {
      int j = 0;
      int32x4_t sse_s32 = vdupq_n_s32(0);
      do {
        const uint16x8_t d0 = vld1q_u16(&dat[j]);
        const uint16x8_t s0 = vld1q_u16(&src[j]);
        int32x4_t flt0_0 = vld1q_s32(&flt[j]);
        int32x4_t flt0_1 = vld1q_s32(&flt[j + 4]);

        uint16x8_t d_u16 = vshlq_n_u16(d0, 4);
        int32x4_t sub0 = vreinterpretq_s32_u32(
            vsubw_u16(vreinterpretq_u32_s32(flt0_0), vget_low_u16(d_u16)));
        int32x4_t sub1 = vreinterpretq_s32_u32(
            vsubw_u16(vreinterpretq_u32_s32(flt0_1), vget_high_u16(d_u16)));

        int32x4_t v0 = vmlaq_s32(
            vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)), sub0,
            xq_v);
        int32x4_t v1 = vmlaq_s32(
            vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)), sub1,
            xq_v);

        int16x4_t vr0 = vshrn_n_s32(v0, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);
        int16x4_t vr1 = vshrn_n_s32(v1, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);

        int16x8_t e = vaddq_s16(vcombine_s16(vr0, vr1),
                                vreinterpretq_s16_u16(vsubq_u16(d0, s0)));
        int16x4_t e_lo = vget_low_s16(e);
        int16x4_t e_hi = vget_high_s16(e);

        sse_s32 = vmlal_s16(sse_s32, e_lo, e_lo);
        sse_s32 = vmlal_s16(sse_s32, e_hi, e_hi);

        j += 8;
      } while (j <= width - 8);

      for (int k = j; k < width; ++k) {
        int32_t v = 1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1);
        v += xq_active * (int32_t)((uint32_t)flt[k] - (uint16_t)(dat[k] << 4));
        const int32_t e =
            (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + dat[k] - src[k];
        sse += ((int64_t)e * e);
      }

      sse_s64 = vpadalq_s32(sse_s64, sse_s32);

      dat += dat_stride;
      flt += flt_stride;
      src += src_stride;
    } while (--height != 0);
  } else {
    do {
      int j = 0;

      do {
        const uint16x8_t d = vld1q_u16(&dat[j]);
        const uint16x8_t s = vld1q_u16(&src[j]);

        uint16x8_t diff = vabdq_u16(d, s);
        uint16x4_t diff_lo = vget_low_u16(diff);
        uint16x4_t diff_hi = vget_high_u16(diff);

        uint32x4_t sqr_lo = vmull_u16(diff_lo, diff_lo);
        uint32x4_t sqr_hi = vmull_u16(diff_hi, diff_hi);

        sse_s64 = vpadalq_s32(sse_s64, vreinterpretq_s32_u32(sqr_lo));
        sse_s64 = vpadalq_s32(sse_s64, vreinterpretq_s32_u32(sqr_hi));

        j += 8;
      } while (j <= width - 8);

      for (int k = j; k < width; ++k) {
        int32_t e = dat[k] - src[k];
        sse += e * e;
      }

      dat += dat_stride;
      src += src_stride;
    } while (--height != 0);
  }

  sse += horizontal_add_s64x2(sse_s64);
  return sse;
}