/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>
#include <stdint.h>

#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "av1/encoder/arm/neon/pickrst_neon.h"
#include "av1/encoder/pickrst.h"

static INLINE void highbd_calc_proj_params_r0_r1_neon(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int32_t *flt1, int flt1_stride, int64_t H[2][2], int64_t C[2]) {
  assert(width % 8 == 0);
  const int size = width * height;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);

  int64x2_t h00_lo = vdupq_n_s64(0);
  int64x2_t h00_hi = vdupq_n_s64(0);
  int64x2_t h11_lo = vdupq_n_s64(0);
  int64x2_t h11_hi = vdupq_n_s64(0);
  int64x2_t h01_lo = vdupq_n_s64(0);
  int64x2_t h01_hi = vdupq_n_s64(0);
  int64x2_t c0_lo = vdupq_n_s64(0);
  int64x2_t c0_hi = vdupq_n_s64(0);
  int64x2_t c1_lo = vdupq_n_s64(0);
  int64x2_t c1_hi = vdupq_n_s64(0);

  do {
    const uint16_t *src_ptr = src;
    const uint16_t *dat_ptr = dat;
    int32_t *flt0_ptr = flt0;
    int32_t *flt1_ptr = flt1;
    int w = width;

    do {
      uint16x8_t s = vld1q_u16(src_ptr);
      uint16x8_t d = vld1q_u16(dat_ptr);
      int32x4_t f0_lo = vld1q_s32(flt0_ptr);
      int32x4_t f0_hi = vld1q_s32(flt0_ptr + 4);
      int32x4_t f1_lo = vld1q_s32(flt1_ptr);
      int32x4_t f1_hi = vld1q_s32(flt1_ptr + 4);

      int32x4_t u_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(d), SGRPROJ_RST_BITS));
      int32x4_t u_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(d), SGRPROJ_RST_BITS));
      int32x4_t s_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(s), SGRPROJ_RST_BITS));
      int32x4_t s_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(s), SGRPROJ_RST_BITS));
      s_lo = vsubq_s32(s_lo, u_lo);
      s_hi = vsubq_s32(s_hi, u_hi);

      f0_lo = vsubq_s32(f0_lo, u_lo);
      f0_hi = vsubq_s32(f0_hi, u_hi);
      f1_lo = vsubq_s32(f1_lo, u_lo);
      f1_hi = vsubq_s32(f1_hi, u_hi);

      h00_lo = vmlal_s32(h00_lo, vget_low_s32(f0_lo), vget_low_s32(f0_lo));
      h00_lo = vmlal_s32(h00_lo, vget_high_s32(f0_lo), vget_high_s32(f0_lo));
      h00_hi = vmlal_s32(h00_hi, vget_low_s32(f0_hi), vget_low_s32(f0_hi));
      h00_hi = vmlal_s32(h00_hi, vget_high_s32(f0_hi), vget_high_s32(f0_hi));

      h11_lo = vmlal_s32(h11_lo, vget_low_s32(f1_lo), vget_low_s32(f1_lo));
      h11_lo = vmlal_s32(h11_lo, vget_high_s32(f1_lo), vget_high_s32(f1_lo));
      h11_hi = vmlal_s32(h11_hi, vget_low_s32(f1_hi), vget_low_s32(f1_hi));
      h11_hi = vmlal_s32(h11_hi, vget_high_s32(f1_hi), vget_high_s32(f1_hi));

      h01_lo = vmlal_s32(h01_lo, vget_low_s32(f0_lo), vget_low_s32(f1_lo));
      h01_lo = vmlal_s32(h01_lo, vget_high_s32(f0_lo), vget_high_s32(f1_lo));
      h01_hi = vmlal_s32(h01_hi, vget_low_s32(f0_hi), vget_low_s32(f1_hi));
      h01_hi = vmlal_s32(h01_hi, vget_high_s32(f0_hi), vget_high_s32(f1_hi));

      c0_lo = vmlal_s32(c0_lo, vget_low_s32(f0_lo), vget_low_s32(s_lo));
      c0_lo = vmlal_s32(c0_lo, vget_high_s32(f0_lo), vget_high_s32(s_lo));
      c0_hi = vmlal_s32(c0_hi, vget_low_s32(f0_hi), vget_low_s32(s_hi));
      c0_hi = vmlal_s32(c0_hi, vget_high_s32(f0_hi), vget_high_s32(s_hi));

      c1_lo = vmlal_s32(c1_lo, vget_low_s32(f1_lo), vget_low_s32(s_lo));
      c1_lo = vmlal_s32(c1_lo, vget_high_s32(f1_lo), vget_high_s32(s_lo));
      c1_hi = vmlal_s32(c1_hi, vget_low_s32(f1_hi), vget_low_s32(s_hi));
      c1_hi = vmlal_s32(c1_hi, vget_high_s32(f1_hi), vget_high_s32(s_hi));

      src_ptr += 8;
      dat_ptr += 8;
      flt0_ptr += 8;
      flt1_ptr += 8;
      w -= 8;
    } while (w != 0);

    src += src_stride;
    dat += dat_stride;
    flt0 += flt0_stride;
    flt1 += flt1_stride;
  } while (--height != 0);

  H[0][0] = horizontal_add_s64x2(vaddq_s64(h00_lo, h00_hi)) / size;
  H[0][1] = horizontal_add_s64x2(vaddq_s64(h01_lo, h01_hi)) / size;
  H[1][1] = horizontal_add_s64x2(vaddq_s64(h11_lo, h11_hi)) / size;
  H[1][0] = H[0][1];
  C[0] = horizontal_add_s64x2(vaddq_s64(c0_lo, c0_hi)) / size;
  C[1] = horizontal_add_s64x2(vaddq_s64(c1_lo, c1_hi)) / size;
}

static INLINE void highbd_calc_proj_params_r0_neon(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int64_t H[2][2], int64_t C[2]) {
  assert(width % 8 == 0);
  const int size = width * height;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);

  int64x2_t h00_lo = vdupq_n_s64(0);
  int64x2_t h00_hi = vdupq_n_s64(0);
  int64x2_t c0_lo = vdupq_n_s64(0);
  int64x2_t c0_hi = vdupq_n_s64(0);

  do {
    const uint16_t *src_ptr = src;
    const uint16_t *dat_ptr = dat;
    int32_t *flt0_ptr = flt0;
    int w = width;

    do {
      uint16x8_t s = vld1q_u16(src_ptr);
      uint16x8_t d = vld1q_u16(dat_ptr);
      int32x4_t f0_lo = vld1q_s32(flt0_ptr);
      int32x4_t f0_hi = vld1q_s32(flt0_ptr + 4);

      int32x4_t u_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(d), SGRPROJ_RST_BITS));
      int32x4_t u_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(d), SGRPROJ_RST_BITS));
      int32x4_t s_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(s), SGRPROJ_RST_BITS));
      int32x4_t s_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(s), SGRPROJ_RST_BITS));
      s_lo = vsubq_s32(s_lo, u_lo);
      s_hi = vsubq_s32(s_hi, u_hi);

      f0_lo = vsubq_s32(f0_lo, u_lo);
      f0_hi = vsubq_s32(f0_hi, u_hi);

      h00_lo = vmlal_s32(h00_lo, vget_low_s32(f0_lo), vget_low_s32(f0_lo));
      h00_lo = vmlal_s32(h00_lo, vget_high_s32(f0_lo), vget_high_s32(f0_lo));
      h00_hi = vmlal_s32(h00_hi, vget_low_s32(f0_hi), vget_low_s32(f0_hi));
      h00_hi = vmlal_s32(h00_hi, vget_high_s32(f0_hi), vget_high_s32(f0_hi));

      c0_lo = vmlal_s32(c0_lo, vget_low_s32(f0_lo), vget_low_s32(s_lo));
      c0_lo = vmlal_s32(c0_lo, vget_high_s32(f0_lo), vget_high_s32(s_lo));
      c0_hi = vmlal_s32(c0_hi, vget_low_s32(f0_hi), vget_low_s32(s_hi));
      c0_hi = vmlal_s32(c0_hi, vget_high_s32(f0_hi), vget_high_s32(s_hi));

      src_ptr += 8;
      dat_ptr += 8;
      flt0_ptr += 8;
      w -= 8;
    } while (w != 0);

    src += src_stride;
    dat += dat_stride;
    flt0 += flt0_stride;
  } while (--height != 0);

  H[0][0] = horizontal_add_s64x2(vaddq_s64(h00_lo, h00_hi)) / size;
  C[0] = horizontal_add_s64x2(vaddq_s64(c0_lo, c0_hi)) / size;
}

static INLINE void highbd_calc_proj_params_r1_neon(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt1, int flt1_stride,
    int64_t H[2][2], int64_t C[2]) {
  assert(width % 8 == 0);
  const int size = width * height;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);

  int64x2_t h11_lo = vdupq_n_s64(0);
  int64x2_t h11_hi = vdupq_n_s64(0);
  int64x2_t c1_lo = vdupq_n_s64(0);
  int64x2_t c1_hi = vdupq_n_s64(0);

  do {
    const uint16_t *src_ptr = src;
    const uint16_t *dat_ptr = dat;
    int32_t *flt1_ptr = flt1;
    int w = width;

    do {
      uint16x8_t s = vld1q_u16(src_ptr);
      uint16x8_t d = vld1q_u16(dat_ptr);
      int32x4_t f1_lo = vld1q_s32(flt1_ptr);
      int32x4_t f1_hi = vld1q_s32(flt1_ptr + 4);

      int32x4_t u_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(d), SGRPROJ_RST_BITS));
      int32x4_t u_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(d), SGRPROJ_RST_BITS));
      int32x4_t s_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(s), SGRPROJ_RST_BITS));
      int32x4_t s_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(s), SGRPROJ_RST_BITS));
      s_lo = vsubq_s32(s_lo, u_lo);
      s_hi = vsubq_s32(s_hi, u_hi);

      f1_lo = vsubq_s32(f1_lo, u_lo);
      f1_hi = vsubq_s32(f1_hi, u_hi);

      h11_lo = vmlal_s32(h11_lo, vget_low_s32(f1_lo), vget_low_s32(f1_lo));
      h11_lo = vmlal_s32(h11_lo, vget_high_s32(f1_lo), vget_high_s32(f1_lo));
      h11_hi = vmlal_s32(h11_hi, vget_low_s32(f1_hi), vget_low_s32(f1_hi));
      h11_hi = vmlal_s32(h11_hi, vget_high_s32(f1_hi), vget_high_s32(f1_hi));

      c1_lo = vmlal_s32(c1_lo, vget_low_s32(f1_lo), vget_low_s32(s_lo));
      c1_lo = vmlal_s32(c1_lo, vget_high_s32(f1_lo), vget_high_s32(s_lo));
      c1_hi = vmlal_s32(c1_hi, vget_low_s32(f1_hi), vget_low_s32(s_hi));
      c1_hi = vmlal_s32(c1_hi, vget_high_s32(f1_hi), vget_high_s32(s_hi));

      src_ptr += 8;
      dat_ptr += 8;
      flt1_ptr += 8;
      w -= 8;
    } while (w != 0);

    src += src_stride;
    dat += dat_stride;
    flt1 += flt1_stride;
  } while (--height != 0);

  H[1][1] = horizontal_add_s64x2(vaddq_s64(h11_lo, h11_hi)) / size;
  C[1] = horizontal_add_s64x2(vaddq_s64(c1_lo, c1_hi)) / size;
}

// The function calls 3 subfunctions for the following cases:
// 1) When params->r[0] > 0 and params->r[1] > 0. In this case all elements
// of C and H need to be computed.
// 2) When only params->r[0] > 0. In this case only H[0][0] and C[0] are
// non-zero and need to be computed.
// 3) When only params->r[1] > 0. In this case only H[1][1] and C[1] are
// non-zero and need to be computed.
void av1_calc_proj_params_high_bd_neon(const uint8_t *src8, int width,
                                       int height, int src_stride,
                                       const uint8_t *dat8, int dat_stride,
                                       int32_t *flt0, int flt0_stride,
                                       int32_t *flt1, int flt1_stride,
                                       int64_t H[2][2], int64_t C[2],
                                       const sgr_params_type *params) {
  if ((params->r[0] > 0) && (params->r[1] > 0)) {
    highbd_calc_proj_params_r0_r1_neon(src8, width, height, src_stride, dat8,
                                       dat_stride, flt0, flt0_stride, flt1,
                                       flt1_stride, H, C);
  } else if (params->r[0] > 0) {
    highbd_calc_proj_params_r0_neon(src8, width, height, src_stride, dat8,
                                    dat_stride, flt0, flt0_stride, H, C);
  } else if (params->r[1] > 0) {
    highbd_calc_proj_params_r1_neon(src8, width, height, src_stride, dat8,
                                    dat_stride, flt1, flt1_stride, H, C);
  }
}
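
// Illustrative scalar sketch of the quantities accumulated by the NEON kernels
// above for case 1) (params->r[0] > 0 && params->r[1] > 0): H and C are the
// averaged normal equations of the two guided-filter outputs against the
// source, all expressed relative to u = dat << SGRPROJ_RST_BITS. The helper
// name is hypothetical, the function is not called anywhere, and it exists
// only as a readable reference for the vectorized code.
static INLINE void calc_proj_params_r0_r1_scalar_sketch(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int32_t *flt1, int flt1_stride, int64_t H[2][2], int64_t C[2]) {
  const int size = width * height;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
  H[0][0] = H[0][1] = H[1][1] = 0;
  C[0] = C[1] = 0;
  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; ++j) {
      const int32_t u = (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
      const int32_t s =
          (int32_t)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
      const int32_t f1 = flt0[i * flt0_stride + j] - u;
      const int32_t f2 = flt1[i * flt1_stride + j] - u;
      H[0][0] += (int64_t)f1 * f1;
      H[1][1] += (int64_t)f2 * f2;
      H[0][1] += (int64_t)f1 * f2;
      C[0] += (int64_t)f1 * s;
      C[1] += (int64_t)f2 * s;
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
}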

static INLINE int16x8_t tbl2q(int16x8_t a, int16x8_t b, uint8x16_t idx) {
#if AOM_ARCH_AARCH64
  uint8x16x2_t table = { { vreinterpretq_u8_s16(a), vreinterpretq_u8_s16(b) } };
  return vreinterpretq_s16_u8(vqtbl2q_u8(table, idx));
#else
  uint8x8x4_t table = { { vreinterpret_u8_s16(vget_low_s16(a)),
                          vreinterpret_u8_s16(vget_high_s16(a)),
                          vreinterpret_u8_s16(vget_low_s16(b)),
                          vreinterpret_u8_s16(vget_high_s16(b)) } };
  return vreinterpretq_s16_u8(vcombine_u8(vtbl4_u8(table, vget_low_u8(idx)),
                                          vtbl4_u8(table, vget_high_u8(idx))));
#endif
}

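// Note on tbl2q above and tbl3q below: the byte indices in `idx` address the
// concatenated table registers, so indices 0..15 select bytes of `a`, 16..31
// select bytes of `b`, and (for tbl3q) 32..47 select bytes of `c`. Each 16-bit
// lane is therefore gathered with a pair of consecutive byte indices, e.g. an
// index pair {4, 5} picks element 2 of `a` and {16, 17} picks element 0 of
// `b`. The shuffle_stats* look-up tables further down are laid out with this
// convention in mind.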
static INLINE int16x8_t tbl3q(int16x8_t a, int16x8_t b, int16x8_t c,
                              uint8x16_t idx) {
#if AOM_ARCH_AARCH64
  uint8x16x3_t table = { { vreinterpretq_u8_s16(a), vreinterpretq_u8_s16(b),
                           vreinterpretq_u8_s16(c) } };
  return vreinterpretq_s16_u8(vqtbl3q_u8(table, idx));
#else
  // This is a specific implementation working only for compute stats with
  // wiener_win == 5.
  uint8x8x3_t table_lo = { { vreinterpret_u8_s16(vget_low_s16(a)),
                             vreinterpret_u8_s16(vget_high_s16(a)),
                             vreinterpret_u8_s16(vget_low_s16(b)) } };
  uint8x8x3_t table_hi = { { vreinterpret_u8_s16(vget_low_s16(b)),
                             vreinterpret_u8_s16(vget_high_s16(b)),
                             vreinterpret_u8_s16(vget_low_s16(c)) } };
  return vreinterpretq_s16_u8(vcombine_u8(
      vtbl3_u8(table_lo, vget_low_u8(idx)),
      vtbl3_u8(table_hi, vsub_u8(vget_high_u8(idx), vdup_n_u8(16)))));
#endif
}

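// Divides by 2^power, truncating toward zero like C integer division. For
// example (illustrative values): div_shift_s64(-5, 1) evaluates to
// (-5 + 1) >> 1 == -2, whereas a plain arithmetic shift (-5 >> 1) would give
// -3; non-negative inputs are simply shifted, so div_shift_s64(5, 1) == 2.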
static INLINE int64_t div_shift_s64(int64_t x, int power) {
  return (x < 0 ? x + (1ll << power) - 1 : x) >> power;
}

// The M matrix is accumulated in a bitdepth-dependent number of steps to
// speed up the computation. This function computes the final M from the
// accumulated (src_s64) and the residual parts (src_s32). It also transposes
// the result as the output needs to be column-major.
static INLINE void acc_transpose_M(int64_t *dst, const int64_t *src_s64,
                                   const int32_t *src_s32, const int wiener_win,
                                   int shift) {
  for (int i = 0; i < wiener_win; ++i) {
    for (int j = 0; j < wiener_win; ++j) {
      int tr_idx = j * wiener_win + i;
      *dst++ = div_shift_s64(src_s64[tr_idx] + src_s32[tr_idx], shift);
    }
  }
}

// The resulting H is a column-major matrix accumulated from the transposed
// (column-major) samples of the filter kernel (5x5 or 7x7) viewed as a single
// vector. For the 7x7 filter case: H(49x49) = [49 x 1] x [1 x 49]. This
// function transforms back to the originally expected format (double
// transpose). The H matrix is accumulated in a bitdepth-dependent number of
// steps to speed up the computation. This function computes the final H from
// the accumulated (src_s64) and the residual parts (src_s32). The computed H
// is only an upper triangle matrix; this function also fills the lower
// triangle of the resulting matrix.
static INLINE void update_H(int64_t *dst, const int64_t *src_s64,
                            const int32_t *src_s32, const int wiener_win,
                            int stride, int shift) {
  // For a simplified theoretical 3x3 case where `wiener_win` is 3 and
  // `wiener_win2` is 9, the M matrix is 3x3:
  // 0, 3, 6
  // 1, 4, 7
  // 2, 5, 8
  //
  // This is viewed as a vector to compute H (9x9) by vector outer product:
  // 0, 3, 6, 1, 4, 7, 2, 5, 8
  //
  // Double transpose and upper triangle remapping for 3x3 -> 9x9 case:
  // 0, 3, 6, 1, 4, 7, 2, 5, 8,
  // 3, 30, 33, 12, 31, 34, 21, 32, 35,
  // 6, 33, 60, 15, 42, 61, 24, 51, 62,
  // 1, 12, 15, 10, 13, 16, 11, 14, 17,
  // 4, 31, 42, 13, 40, 43, 22, 41, 44,
  // 7, 34, 61, 16, 43, 70, 25, 52, 71,
  // 2, 21, 24, 11, 22, 25, 20, 23, 26,
  // 5, 32, 51, 14, 41, 52, 23, 50, 53,
  // 8, 35, 62, 17, 44, 71, 26, 53, 80,
  const int wiener_win2 = wiener_win * wiener_win;

  // Loop through the indices according to the remapping above, along the
  // columns:
  // 0, wiener_win, 2 * wiener_win, ..., 1, 1 + 2 * wiener_win, ...,
  // wiener_win - 1, wiener_win - 1 + wiener_win, ...
  // For the 3x3 case `j` will be: 0, 3, 6, 1, 4, 7, 2, 5, 8.
  for (int i = 0; i < wiener_win; ++i) {
    for (int j = i; j < wiener_win2; j += wiener_win) {
      // These two inner loops are the same as the two outer loops, but running
      // along rows instead of columns. For the 3x3 case `l` will be:
      // 0, 3, 6, 1, 4, 7, 2, 5, 8.
      for (int k = 0; k < wiener_win; ++k) {
        for (int l = k; l < wiener_win2; l += wiener_win) {
          // The nominal double transpose indexing would be:
          // int idx = stride * j + l;
          // However, only the upper-right triangle was accumulated, so remap
          // the index into it with a pair of min/max operations.
          int tr_idx = stride * AOMMIN(j, l) + AOMMAX(j, l);

          // Resulting matrix is filled by combining the 64-bit and the residual
          // 32-bit matrices together with scaling.
          *dst++ = div_shift_s64(src_s64[tr_idx] + src_s32[tr_idx], shift);
        }
      }
    }
  }
}

// Load a 7x7 matrix into 7 128-bit vectors from consecutive rows; the last
// load address is offset to prevent out-of-bounds access.
static INLINE void load_and_pack_s16_8x7(int16x8_t dst[7], const int16_t *src,
                                         ptrdiff_t stride) {
  dst[0] = vld1q_s16(src);
  src += stride;
  dst[1] = vld1q_s16(src);
  src += stride;
  dst[2] = vld1q_s16(src);
  src += stride;
  dst[3] = vld1q_s16(src);
  src += stride;
  dst[4] = vld1q_s16(src);
  src += stride;
  dst[5] = vld1q_s16(src);
  src += stride;
  dst[6] = vld1q_s16(src - 1);
}

static INLINE void highbd_compute_stats_win7_neon(
    const uint16_t *dgd, const uint16_t *src, int avg, int width, int height,
    int dgd_stride, int src_stride, int64_t *M, int64_t *H,
    aom_bit_depth_t bit_depth) {
  // Matrix names are capitalized to help readability.
  DECLARE_ALIGNED(64, int16_t, DGD_AVG0[WIENER_WIN2_ALIGN3]);
  DECLARE_ALIGNED(64, int16_t, DGD_AVG1[WIENER_WIN2_ALIGN3]);
  DECLARE_ALIGNED(64, int32_t, M_s32[WIENER_WIN2_ALIGN3]);
  DECLARE_ALIGNED(64, int64_t, M_s64[WIENER_WIN2_ALIGN3]);
  DECLARE_ALIGNED(64, int32_t, H_s32[WIENER_WIN2 * WIENER_WIN2_ALIGN2]);
  DECLARE_ALIGNED(64, int64_t, H_s64[WIENER_WIN2 * WIENER_WIN2_ALIGN2]);

  memset(M_s32, 0, sizeof(M_s32));
  memset(M_s64, 0, sizeof(M_s64));
  memset(H_s32, 0, sizeof(H_s32));
  memset(H_s64, 0, sizeof(H_s64));

  // Look-up tables to create 8x6 matrix with consecutive elements from two 7x7
  // matrices.
  // clang-format off
  DECLARE_ALIGNED(16, static const uint8_t, shuffle_stats7_highbd[192]) = {
     0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 16, 17,
     2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 16, 17, 18, 19,
     4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 16, 17, 18, 19, 20, 21,
     6,  7,  8,  9, 10, 11, 12, 13, 16, 17, 18, 19, 20, 21, 22, 23,
     8,  9, 10, 11, 12, 13, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
    10, 11, 12, 13, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
     2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 18, 19,
     4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 18, 19, 20, 21,
     6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22, 23,
     8,  9, 10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22, 23, 24, 25,
    10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
    12, 13, 14, 15, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
  };
  // clang-format on

  const uint8x16_t lut0 = vld1q_u8(shuffle_stats7_highbd + 0);
  const uint8x16_t lut1 = vld1q_u8(shuffle_stats7_highbd + 16);
  const uint8x16_t lut2 = vld1q_u8(shuffle_stats7_highbd + 32);
  const uint8x16_t lut3 = vld1q_u8(shuffle_stats7_highbd + 48);
  const uint8x16_t lut4 = vld1q_u8(shuffle_stats7_highbd + 64);
  const uint8x16_t lut5 = vld1q_u8(shuffle_stats7_highbd + 80);
  const uint8x16_t lut6 = vld1q_u8(shuffle_stats7_highbd + 96);
  const uint8x16_t lut7 = vld1q_u8(shuffle_stats7_highbd + 112);
  const uint8x16_t lut8 = vld1q_u8(shuffle_stats7_highbd + 128);
  const uint8x16_t lut9 = vld1q_u8(shuffle_stats7_highbd + 144);
  const uint8x16_t lut10 = vld1q_u8(shuffle_stats7_highbd + 160);
  const uint8x16_t lut11 = vld1q_u8(shuffle_stats7_highbd + 176);

  // We can accumulate up to 65536/4096/256 8/10/12-bit multiplication results
  // in 32-bit. We are processing 2 pixels at a time, so the accumulator max can
  // be as high as 32768/2048/128 for the compute stats.
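  // For example, at 12-bit the magnitude of each (dgd - avg) term is below
  // 2^12, so a single product is below 2^24 and 2^(32 - 24) = 256 products fit
  // in the 32-bit accumulators, i.e. 128 iterations of the two-pixel loop
  // below before the totals must be widened into the 64-bit buffers.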
  const int acc_cnt_max = (1 << (32 - 2 * bit_depth)) >> 1;
  int acc_cnt = acc_cnt_max;
  const int src_next = src_stride - width;
  const int dgd_next = dgd_stride - width;
  const int16x8_t avg_s16 = vdupq_n_s16(avg);

  do {
    int j = width;
    while (j >= 2) {
      // Load two adjacent, overlapping 7x7 matrices: an 8x7 matrix with the
      // middle 6x7 elements being shared.
      int16x8_t dgd_rows[7];
      load_and_pack_s16_8x7(dgd_rows, (const int16_t *)dgd, dgd_stride);

      const int16_t *dgd_ptr = (const int16_t *)dgd + dgd_stride * 6;
      dgd += 2;

      dgd_rows[0] = vsubq_s16(dgd_rows[0], avg_s16);
      dgd_rows[1] = vsubq_s16(dgd_rows[1], avg_s16);
      dgd_rows[2] = vsubq_s16(dgd_rows[2], avg_s16);
      dgd_rows[3] = vsubq_s16(dgd_rows[3], avg_s16);
      dgd_rows[4] = vsubq_s16(dgd_rows[4], avg_s16);
      dgd_rows[5] = vsubq_s16(dgd_rows[5], avg_s16);
      dgd_rows[6] = vsubq_s16(dgd_rows[6], avg_s16);

      // Re-arrange the combined 8x7 matrix to have the 2 whole 7x7 matrices (1
      // for each of the 2 pixels) separated into distinct int16x8_t[6] arrays.
      // These arrays contain 48 elements of the 49 (7x7). Compute `dgd - avg`
      // for both buffers. Each DGD_AVG buffer contains 49 consecutive elements.
      int16x8_t dgd_avg0[6];
      int16x8_t dgd_avg1[6];

      dgd_avg0[0] = tbl2q(dgd_rows[0], dgd_rows[1], lut0);
      dgd_avg1[0] = tbl2q(dgd_rows[0], dgd_rows[1], lut6);
      dgd_avg0[1] = tbl2q(dgd_rows[1], dgd_rows[2], lut1);
      dgd_avg1[1] = tbl2q(dgd_rows[1], dgd_rows[2], lut7);
      dgd_avg0[2] = tbl2q(dgd_rows[2], dgd_rows[3], lut2);
      dgd_avg1[2] = tbl2q(dgd_rows[2], dgd_rows[3], lut8);
      dgd_avg0[3] = tbl2q(dgd_rows[3], dgd_rows[4], lut3);
      dgd_avg1[3] = tbl2q(dgd_rows[3], dgd_rows[4], lut9);
      dgd_avg0[4] = tbl2q(dgd_rows[4], dgd_rows[5], lut4);
      dgd_avg1[4] = tbl2q(dgd_rows[4], dgd_rows[5], lut10);
      dgd_avg0[5] = tbl2q(dgd_rows[5], dgd_rows[6], lut5);
      dgd_avg1[5] = tbl2q(dgd_rows[5], dgd_rows[6], lut11);

      vst1q_s16(DGD_AVG0, dgd_avg0[0]);
      vst1q_s16(DGD_AVG1, dgd_avg1[0]);
      vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
      vst1q_s16(DGD_AVG1 + 8, dgd_avg1[1]);
      vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
      vst1q_s16(DGD_AVG1 + 16, dgd_avg1[2]);
      vst1q_s16(DGD_AVG0 + 24, dgd_avg0[3]);
      vst1q_s16(DGD_AVG1 + 24, dgd_avg1[3]);
      vst1q_s16(DGD_AVG0 + 32, dgd_avg0[4]);
      vst1q_s16(DGD_AVG1 + 32, dgd_avg1[4]);
      vst1q_s16(DGD_AVG0 + 40, dgd_avg0[5]);
      vst1q_s16(DGD_AVG1 + 40, dgd_avg1[5]);

      // The remaining last (49th) elements of `dgd - avg`.
      DGD_AVG0[48] = dgd_ptr[6] - avg;
      DGD_AVG1[48] = dgd_ptr[7] - avg;

      // Accumulate into row-major variant of matrix M (cross-correlation) for 2
      // output pixels at a time. M is of size 7 * 7. It needs to be filled such
      // that multiplying one element from src with each element of a row of the
      // wiener window will fill one column of M. However this is not very
      // convenient in terms of memory access, as it means we do contiguous
      // loads of dgd but strided stores to M. As a result, we use an
      // intermediate matrix M_s32 which is instead filled such that one row of
      // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
      // then transposed to return M.
      int src_avg0 = *src++ - avg;
      int src_avg1 = *src++ - avg;
      int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
      int16x4_t src_avg1_s16 = vdup_n_s16(src_avg1);
      update_M_2pixels(M_s32 + 0, src_avg0_s16, src_avg1_s16, dgd_avg0[0],
                       dgd_avg1[0]);
      update_M_2pixels(M_s32 + 8, src_avg0_s16, src_avg1_s16, dgd_avg0[1],
                       dgd_avg1[1]);
      update_M_2pixels(M_s32 + 16, src_avg0_s16, src_avg1_s16, dgd_avg0[2],
                       dgd_avg1[2]);
      update_M_2pixels(M_s32 + 24, src_avg0_s16, src_avg1_s16, dgd_avg0[3],
                       dgd_avg1[3]);
      update_M_2pixels(M_s32 + 32, src_avg0_s16, src_avg1_s16, dgd_avg0[4],
                       dgd_avg1[4]);
      update_M_2pixels(M_s32 + 40, src_avg0_s16, src_avg1_s16, dgd_avg0[5],
                       dgd_avg1[5]);

      // Last (49th) element of M_s32 can be computed as a scalar more
      // efficiently for 2 output pixels.
      M_s32[48] += DGD_AVG0[48] * src_avg0 + DGD_AVG1[48] * src_avg1;

      // Start accumulating into row-major version of matrix H
      // (auto-covariance), it expects the DGD_AVG[01] matrices to also be
      // row-major. H is of size 49 * 49. It is filled by multiplying every pair
      // of elements of the wiener window together (vector outer product). Since
      // it is a symmetric matrix, we only compute the upper-right triangle, and
      // then copy it down to the lower-left later. The upper triangle is
      // covered by 4x4 tiles. The original algorithm assumes the M matrix is
      // column-major and the resulting H matrix is also expected to be
      // column-major. It is not efficient to work with column-major matrices,
      // so we accumulate into a row-major matrix H_s32. At the end of the
      // algorithm a double transpose transformation will convert H_s32 back to
      // the expected output layout.
      update_H_7x7_2pixels(H_s32, DGD_AVG0, DGD_AVG1);

      // The last element of the triangle of H_s32 matrix can be computed as a
      // scalar more efficiently.
      H_s32[48 * WIENER_WIN2_ALIGN2 + 48] +=
          DGD_AVG0[48] * DGD_AVG0[48] + DGD_AVG1[48] * DGD_AVG1[48];

      // Accumulate into 64-bit after a bit depth dependent number of iterations
      // to prevent overflow.
      if (--acc_cnt == 0) {
        acc_cnt = acc_cnt_max;

        accumulate_and_clear(M_s64, M_s32, WIENER_WIN2_ALIGN2);

        // The widening accumulation is only needed for the upper triangle part
        // of the matrix.
        int64_t *lh = H_s64;
        int32_t *lh32 = H_s32;
        for (int k = 0; k < WIENER_WIN2; ++k) {
          // The widening accumulation is only run for the relevant parts
          // (upper-right triangle) in a row 4-element aligned.
          int k4 = k / 4 * 4;
          accumulate_and_clear(lh + k4, lh32 + k4, 48 - k4);

          // Last element of the row is computed separately.
          lh[48] += lh32[48];
          lh32[48] = 0;

          lh += WIENER_WIN2_ALIGN2;
          lh32 += WIENER_WIN2_ALIGN2;
        }
      }

      j -= 2;
    }

    // Computations for odd pixel in the row.
    if (width & 1) {
      // Load two adjacent, overlapping 7x7 matrices: an 8x7 matrix with the
      // middle 6x7 elements being shared.
      int16x8_t dgd_rows[7];
      load_and_pack_s16_8x7(dgd_rows, (const int16_t *)dgd, dgd_stride);

      const int16_t *dgd_ptr = (const int16_t *)dgd + dgd_stride * 6;
      ++dgd;

      // Re-arrange the combined 8x7 matrix to have a whole 7x7 matrix tightly
      // packed into a int16x8_t[6] array. This array contains 48 elements of
      // the 49 (7x7). Compute `dgd - avg` for the whole buffer. The DGD_AVG
      // buffer contains 49 consecutive elements.
      int16x8_t dgd_avg0[6];

      dgd_avg0[0] = vsubq_s16(tbl2q(dgd_rows[0], dgd_rows[1], lut0), avg_s16);
      dgd_avg0[1] = vsubq_s16(tbl2q(dgd_rows[1], dgd_rows[2], lut1), avg_s16);
      dgd_avg0[2] = vsubq_s16(tbl2q(dgd_rows[2], dgd_rows[3], lut2), avg_s16);
      dgd_avg0[3] = vsubq_s16(tbl2q(dgd_rows[3], dgd_rows[4], lut3), avg_s16);
      dgd_avg0[4] = vsubq_s16(tbl2q(dgd_rows[4], dgd_rows[5], lut4), avg_s16);
      dgd_avg0[5] = vsubq_s16(tbl2q(dgd_rows[5], dgd_rows[6], lut5), avg_s16);

      vst1q_s16(DGD_AVG0, dgd_avg0[0]);
      vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
      vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
      vst1q_s16(DGD_AVG0 + 24, dgd_avg0[3]);
      vst1q_s16(DGD_AVG0 + 32, dgd_avg0[4]);
      vst1q_s16(DGD_AVG0 + 40, dgd_avg0[5]);

      // The remaining last (49th) element of `dgd - avg`.
      DGD_AVG0[48] = dgd_ptr[6] - avg;

      // Accumulate into row-major order variant of matrix M (cross-correlation)
      // for 1 output pixel at a time. M is of size 7 * 7. It needs to be filled
      // such that multiplying one element from src with each element of a row
      // of the wiener window will fill one column of M. However this is not
      // very convenient in terms of memory access, as it means we do
      // contiguous loads of dgd but strided stores to M. As a result, we use an
      // intermediate matrix M_s32 which is instead filled such that one row of
      // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
      // then transposed to return M.
      int src_avg0 = *src++ - avg;
      int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
      update_M_1pixel(M_s32 + 0, src_avg0_s16, dgd_avg0[0]);
      update_M_1pixel(M_s32 + 8, src_avg0_s16, dgd_avg0[1]);
      update_M_1pixel(M_s32 + 16, src_avg0_s16, dgd_avg0[2]);
      update_M_1pixel(M_s32 + 24, src_avg0_s16, dgd_avg0[3]);
      update_M_1pixel(M_s32 + 32, src_avg0_s16, dgd_avg0[4]);
      update_M_1pixel(M_s32 + 40, src_avg0_s16, dgd_avg0[5]);

      // Last (49th) element of M_s32 can be computed as a scalar more
      // efficiently for 1 output pixel.
      M_s32[48] += DGD_AVG0[48] * src_avg0;

      // Start accumulating into row-major order version of matrix H
      // (auto-covariance), it expects the DGD_AVG0 matrix to also be row-major.
      // H is of size 49 * 49. It is filled by multiplying every pair of
      // elements of the wiener window together (vector outer product). Since it
      // is a symmetric matrix, we only compute the upper-right triangle, and
      // then copy it down to the lower-left later. The upper triangle is
      // covered by 4x4 tiles. The original algorithm assumes the M matrix is
      // column-major and the resulting H matrix is also expected to be
      // column-major. It is not efficient to work with column-major matrices,
      // so we accumulate into a row-major matrix H_s32. At the end of the
      // algorithm a double transpose transformation will convert H_s32 back to
      // the expected output layout.
      update_H_1pixel(H_s32, DGD_AVG0, WIENER_WIN2_ALIGN2, 48);

      // The last element of the triangle of H_s32 matrix can be computed as a
      // scalar more efficiently.
      H_s32[48 * WIENER_WIN2_ALIGN2 + 48] += DGD_AVG0[48] * DGD_AVG0[48];
    }

    src += src_next;
    dgd += dgd_next;
  } while (--height != 0);

  int bit_depth_shift = bit_depth - AOM_BITS_8;

  acc_transpose_M(M, M_s64, M_s32, WIENER_WIN, bit_depth_shift);

  update_H(H, H_s64, H_s32, WIENER_WIN, WIENER_WIN2_ALIGN2, bit_depth_shift);
}
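
// Illustrative scalar sketch of the per-pixel statistics the win7 kernel above
// accumulates: M is the cross-correlation of the column-major (dgd - avg)
// window with (src - avg), and H is its auto-covariance (outer product with
// itself). The helper name is hypothetical, the function is never called, and
// the final bit-depth scaling performed by acc_transpose_M() / update_H() is
// deliberately omitted here.
static INLINE void compute_stats_win7_1pixel_scalar_sketch(
    const uint16_t *dgd_win, int dgd_stride, int16_t src_minus_avg, int avg,
    int64_t *M, int64_t *H) {
  int16_t w[WIENER_WIN2];
  for (int r = 0; r < WIENER_WIN; ++r) {
    for (int c = 0; c < WIENER_WIN; ++c) {
      // Column-major view of the window, matching the expected output layout.
      w[c * WIENER_WIN + r] = (int16_t)(dgd_win[r * dgd_stride + c] - avg);
    }
  }
  for (int i = 0; i < WIENER_WIN2; ++i) {
    M[i] += (int64_t)w[i] * src_minus_avg;
    for (int j = 0; j < WIENER_WIN2; ++j) {
      H[i * WIENER_WIN2 + j] += (int64_t)w[i] * w[j];
    }
  }
}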

// Load a 5x5 matrix into 5 128-bit vectors from consecutive rows; the last
// load address is offset to prevent out-of-bounds access.
static INLINE void load_and_pack_s16_6x5(int16x8_t dst[5], const int16_t *src,
                                         ptrdiff_t stride) {
  dst[0] = vld1q_s16(src);
  src += stride;
  dst[1] = vld1q_s16(src);
  src += stride;
  dst[2] = vld1q_s16(src);
  src += stride;
  dst[3] = vld1q_s16(src);
  src += stride;
  dst[4] = vld1q_s16(src - 3);
}

static void highbd_compute_stats_win5_neon(const uint16_t *dgd,
                                           const uint16_t *src, int avg,
                                           int width, int height,
                                           int dgd_stride, int src_stride,
                                           int64_t *M, int64_t *H,
                                           aom_bit_depth_t bit_depth) {
  // Matrix names are capitalized to help readability.
  DECLARE_ALIGNED(64, int16_t, DGD_AVG0[WIENER_WIN2_REDUCED_ALIGN3]);
  DECLARE_ALIGNED(64, int16_t, DGD_AVG1[WIENER_WIN2_REDUCED_ALIGN3]);
  DECLARE_ALIGNED(64, int32_t, M_s32[WIENER_WIN2_REDUCED_ALIGN3]);
  DECLARE_ALIGNED(64, int64_t, M_s64[WIENER_WIN2_REDUCED_ALIGN3]);
  DECLARE_ALIGNED(64, int32_t,
                  H_s32[WIENER_WIN2_REDUCED * WIENER_WIN2_REDUCED_ALIGN2]);
  DECLARE_ALIGNED(64, int64_t,
                  H_s64[WIENER_WIN2_REDUCED * WIENER_WIN2_REDUCED_ALIGN2]);

  memset(M_s32, 0, sizeof(M_s32));
  memset(M_s64, 0, sizeof(M_s64));
  memset(H_s32, 0, sizeof(H_s32));
  memset(H_s64, 0, sizeof(H_s64));

  // Look-up tables to create 8x3 matrix with consecutive elements from a 5x5
  // matrix.
  DECLARE_ALIGNED(16, static const uint8_t, shuffle_stats5_highbd[96]) = {
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 17, 18, 19, 20, 21,
    6, 7, 8, 9, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 32, 33,
    2, 3, 4, 5, 6, 7, 8, 9, 22, 23, 24, 25, 26, 27, 28, 29,
    2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 18, 19, 20, 21, 22, 23,
    8, 9, 10, 11, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 34, 35,
    4, 5, 6, 7, 8, 9, 10, 11, 24, 25, 26, 27, 28, 29, 30, 31,
  };

  const uint8x16_t lut0 = vld1q_u8(shuffle_stats5_highbd + 0);
  const uint8x16_t lut1 = vld1q_u8(shuffle_stats5_highbd + 16);
  const uint8x16_t lut2 = vld1q_u8(shuffle_stats5_highbd + 32);
  const uint8x16_t lut3 = vld1q_u8(shuffle_stats5_highbd + 48);
  const uint8x16_t lut4 = vld1q_u8(shuffle_stats5_highbd + 64);
  const uint8x16_t lut5 = vld1q_u8(shuffle_stats5_highbd + 80);

  // We can accumulate up to 65536/4096/256 8/10/12-bit multiplication results
  // in 32-bit. We are processing 2 pixels at a time, so the accumulator max can
  // be as high as 32768/2048/128 for the compute stats.
  const int acc_cnt_max = (1 << (32 - 2 * bit_depth)) >> 1;
  int acc_cnt = acc_cnt_max;
  const int src_next = src_stride - width;
  const int dgd_next = dgd_stride - width;
  const int16x8_t avg_s16 = vdupq_n_s16(avg);

  do {
    int j = width;
    while (j >= 2) {
      // Load two adjacent, overlapping 5x5 matrices: a 6x5 matrix with the
      // middle 4x5 elements being shared.
      int16x8_t dgd_rows[5];
      load_and_pack_s16_6x5(dgd_rows, (const int16_t *)dgd, dgd_stride);

      const int16_t *dgd_ptr = (const int16_t *)dgd + dgd_stride * 4;
      dgd += 2;

      dgd_rows[0] = vsubq_s16(dgd_rows[0], avg_s16);
      dgd_rows[1] = vsubq_s16(dgd_rows[1], avg_s16);
      dgd_rows[2] = vsubq_s16(dgd_rows[2], avg_s16);
      dgd_rows[3] = vsubq_s16(dgd_rows[3], avg_s16);
      dgd_rows[4] = vsubq_s16(dgd_rows[4], avg_s16);

      // Re-arrange the combined 6x5 matrix to have the 2 whole 5x5 matrices (1
      // for each of the 2 pixels) separated into distinct int16x8_t[3] arrays.
      // These arrays contain 24 elements of the 25 (5x5). Compute `dgd - avg`
      // for both buffers. Each DGD_AVG buffer contains 25 consecutive elements.
      int16x8_t dgd_avg0[3];
      int16x8_t dgd_avg1[3];

      dgd_avg0[0] = tbl2q(dgd_rows[0], dgd_rows[1], lut0);
      dgd_avg1[0] = tbl2q(dgd_rows[0], dgd_rows[1], lut3);
      dgd_avg0[1] = tbl3q(dgd_rows[1], dgd_rows[2], dgd_rows[3], lut1);
      dgd_avg1[1] = tbl3q(dgd_rows[1], dgd_rows[2], dgd_rows[3], lut4);
      dgd_avg0[2] = tbl2q(dgd_rows[3], dgd_rows[4], lut2);
      dgd_avg1[2] = tbl2q(dgd_rows[3], dgd_rows[4], lut5);

      vst1q_s16(DGD_AVG0, dgd_avg0[0]);
      vst1q_s16(DGD_AVG1, dgd_avg1[0]);
      vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
      vst1q_s16(DGD_AVG1 + 8, dgd_avg1[1]);
      vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
      vst1q_s16(DGD_AVG1 + 16, dgd_avg1[2]);

      // The remaining last (25th) elements of `dgd - avg`.
      DGD_AVG0[24] = dgd_ptr[4] - avg;
      DGD_AVG1[24] = dgd_ptr[5] - avg;

      // Accumulate into row-major variant of matrix M (cross-correlation) for 2
      // output pixels at a time. M is of size 5 * 5. It needs to be filled such
      // that multiplying one element from src with each element of a row of the
      // wiener window will fill one column of M. However this is not very
      // convenient in terms of memory access, as it means we do contiguous
      // loads of dgd but strided stores to M. As a result, we use an
      // intermediate matrix M_s32 which is instead filled such that one row of
      // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
      // then transposed to return M.
      int src_avg0 = *src++ - avg;
      int src_avg1 = *src++ - avg;
      int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
      int16x4_t src_avg1_s16 = vdup_n_s16(src_avg1);
      update_M_2pixels(M_s32 + 0, src_avg0_s16, src_avg1_s16, dgd_avg0[0],
                       dgd_avg1[0]);
      update_M_2pixels(M_s32 + 8, src_avg0_s16, src_avg1_s16, dgd_avg0[1],
                       dgd_avg1[1]);
      update_M_2pixels(M_s32 + 16, src_avg0_s16, src_avg1_s16, dgd_avg0[2],
                       dgd_avg1[2]);

      // Last (25th) element of M_s32 can be computed as a scalar more
      // efficiently for 2 output pixels.
      M_s32[24] += DGD_AVG0[24] * src_avg0 + DGD_AVG1[24] * src_avg1;

      // Start accumulating into row-major version of matrix H
      // (auto-covariance), it expects the DGD_AVG[01] matrices to also be
      // row-major. H is of size 25 * 25. It is filled by multiplying every pair
      // of elements of the wiener window together (vector outer product). Since
      // it is a symmetric matrix, we only compute the upper-right triangle, and
      // then copy it down to the lower-left later. The upper triangle is
      // covered by 4x4 tiles. The original algorithm assumes the M matrix is
      // column-major and the resulting H matrix is also expected to be
      // column-major. It is not efficient to work with column-major matrices,
      // so we accumulate into a row-major matrix H_s32. At the end of the
      // algorithm a double transpose transformation will convert H_s32 back to
      // the expected output layout.
      update_H_5x5_2pixels(H_s32, DGD_AVG0, DGD_AVG1);

      // The last element of the triangle of H_s32 matrix can be computed as a
      // scalar more efficiently.
      H_s32[24 * WIENER_WIN2_REDUCED_ALIGN2 + 24] +=
          DGD_AVG0[24] * DGD_AVG0[24] + DGD_AVG1[24] * DGD_AVG1[24];

      // Accumulate into 64-bit after a bit depth dependent number of iterations
      // to prevent overflow.
      if (--acc_cnt == 0) {
        acc_cnt = acc_cnt_max;

        accumulate_and_clear(M_s64, M_s32, WIENER_WIN2_REDUCED_ALIGN2);

        // The widening accumulation is only needed for the upper triangle part
        // of the matrix.
        int64_t *lh = H_s64;
        int32_t *lh32 = H_s32;
        for (int k = 0; k < WIENER_WIN2_REDUCED; ++k) {
          // The widening accumulation is only run for the relevant parts
          // (upper-right triangle) in a row 4-element aligned.
          int k4 = k / 4 * 4;
          accumulate_and_clear(lh + k4, lh32 + k4, 24 - k4);

          // Last element of the row is computed separately.
          lh[24] += lh32[24];
          lh32[24] = 0;

          lh += WIENER_WIN2_REDUCED_ALIGN2;
          lh32 += WIENER_WIN2_REDUCED_ALIGN2;
        }
      }

      j -= 2;
    }

    // Computations for odd pixel in the row.
    if (width & 1) {
      // Load two adjacent, overlapping 5x5 matrices: a 6x5 matrix with the
      // middle 4x5 elements being shared.
      int16x8_t dgd_rows[5];
      load_and_pack_s16_6x5(dgd_rows, (const int16_t *)dgd, dgd_stride);

      const int16_t *dgd_ptr = (const int16_t *)dgd + dgd_stride * 4;
      ++dgd;

      // Re-arrange the combined 6x5 matrix to have a whole 5x5 matrix tightly
      // packed into a int16x8_t[3] array. This array contains 24 elements of
      // the 25 (5x5). Compute `dgd - avg` for the whole buffer. The DGD_AVG
      // buffer contains 25 consecutive elements.
      int16x8_t dgd_avg0[3];

      dgd_avg0[0] = vsubq_s16(tbl2q(dgd_rows[0], dgd_rows[1], lut0), avg_s16);
      dgd_avg0[1] = vsubq_s16(
          tbl3q(dgd_rows[1], dgd_rows[2], dgd_rows[3], lut1), avg_s16);
      dgd_avg0[2] = vsubq_s16(tbl2q(dgd_rows[3], dgd_rows[4], lut2), avg_s16);

      vst1q_s16(DGD_AVG0, dgd_avg0[0]);
      vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
      vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);

      // The remaining last (25th) element of `dgd - avg`.
      DGD_AVG0[24] = dgd_ptr[4] - avg;
      DGD_AVG1[24] = dgd_ptr[5] - avg;

      // Accumulate into row-major order variant of matrix M (cross-correlation)
      // for 1 output pixel at a time. M is of size 5 * 5. It needs to be filled
      // such that multiplying one element from src with each element of a row
      // of the wiener window will fill one column of M. However this is not
      // very convenient in terms of memory access, as it means we do
      // contiguous loads of dgd but strided stores to M. As a result, we use an
      // intermediate matrix M_s32 which is instead filled such that one row of
      // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
      // then transposed to return M.
      int src_avg0 = *src++ - avg;
      int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
      update_M_1pixel(M_s32 + 0, src_avg0_s16, dgd_avg0[0]);
      update_M_1pixel(M_s32 + 8, src_avg0_s16, dgd_avg0[1]);
      update_M_1pixel(M_s32 + 16, src_avg0_s16, dgd_avg0[2]);

      // Last (25th) element of M_s32 can be computed as a scalar more
      // efficiently for 1 output pixel.
      M_s32[24] += DGD_AVG0[24] * src_avg0;

      // Start accumulating into row-major order version of matrix H
      // (auto-covariance), it expects the DGD_AVG0 matrix to also be row-major.
      // H is of size 25 * 25. It is filled by multiplying every pair of
      // elements of the wiener window together (vector outer product). Since it
      // is a symmetric matrix, we only compute the upper-right triangle, and
      // then copy it down to the lower-left later. The upper triangle is
      // covered by 4x4 tiles. The original algorithm assumes the M matrix is
      // column-major and the resulting H matrix is also expected to be
      // column-major. It is not efficient to work with column-major matrices,
      // so we accumulate into a row-major matrix H_s32. At the end of the
      // algorithm a double transpose transformation will convert H_s32 back to
      // the expected output layout.
      update_H_1pixel(H_s32, DGD_AVG0, WIENER_WIN2_REDUCED_ALIGN2, 24);

      // The last element of the triangle of H_s32 matrix can be computed as a
      // scalar more efficiently.
      H_s32[24 * WIENER_WIN2_REDUCED_ALIGN2 + 24] +=
          DGD_AVG0[24] * DGD_AVG0[24];
    }

    src += src_next;
    dgd += dgd_next;
  } while (--height != 0);

  int bit_depth_shift = bit_depth - AOM_BITS_8;

  acc_transpose_M(M, M_s64, M_s32, WIENER_WIN_REDUCED, bit_depth_shift);

  update_H(H, H_s64, H_s32, WIENER_WIN_REDUCED, WIENER_WIN2_REDUCED_ALIGN2,
           bit_depth_shift);
}

static uint16_t highbd_find_average_neon(const uint16_t *src, int src_stride,
                                         int width, int height) {
  assert(width > 0);
  assert(height > 0);

  uint64x2_t sum_u64 = vdupq_n_u64(0);
  uint64_t sum = 0;

  int h = height;
  do {
    uint32x4_t sum_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

    int w = width;
    const uint16_t *row = src;
    while (w >= 32) {
      uint16x8_t s0 = vld1q_u16(row + 0);
      uint16x8_t s1 = vld1q_u16(row + 8);
      uint16x8_t s2 = vld1q_u16(row + 16);
      uint16x8_t s3 = vld1q_u16(row + 24);

      s0 = vaddq_u16(s0, s1);
      s2 = vaddq_u16(s2, s3);
      sum_u32[0] = vpadalq_u16(sum_u32[0], s0);
      sum_u32[1] = vpadalq_u16(sum_u32[1], s2);

      row += 32;
      w -= 32;
    }

    if (w >= 16) {
      uint16x8_t s0 = vld1q_u16(row + 0);
      uint16x8_t s1 = vld1q_u16(row + 8);

      s0 = vaddq_u16(s0, s1);
      sum_u32[0] = vpadalq_u16(sum_u32[0], s0);

      row += 16;
      w -= 16;
    }

    if (w >= 8) {
      uint16x8_t s0 = vld1q_u16(row);
      sum_u32[1] = vpadalq_u16(sum_u32[1], s0);

      row += 8;
      w -= 8;
    }

    if (w >= 4) {
      uint16x8_t s0 = vcombine_u16(vld1_u16(row), vdup_n_u16(0));
      sum_u32[0] = vpadalq_u16(sum_u32[0], s0);

      row += 4;
      w -= 4;
    }

    while (w-- > 0) {
      sum += *row++;
    }

    sum_u64 = vpadalq_u32(sum_u64, vaddq_u32(sum_u32[0], sum_u32[1]));

    src += src_stride;
  } while (--h != 0);

  return (uint16_t)((horizontal_add_u64x2(sum_u64) + sum) / (height * width));
}

void av1_compute_stats_highbd_neon(int wiener_win, const uint8_t *dgd8,
                                   const uint8_t *src8, int h_start, int h_end,
                                   int v_start, int v_end, int dgd_stride,
                                   int src_stride, int64_t *M, int64_t *H,
                                   aom_bit_depth_t bit_depth) {
  assert(wiener_win == WIENER_WIN || wiener_win == WIENER_WIN_REDUCED);

  const int wiener_halfwin = wiener_win >> 1;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
  const int height = v_end - v_start;
  const int width = h_end - h_start;

  const uint16_t *dgd_start = dgd + h_start + v_start * dgd_stride;
  const uint16_t *src_start = src + h_start + v_start * src_stride;

  // The wiener window will slide along the dgd frame, centered on each pixel.
  // For the top left pixel and all the pixels on the side of the frame this
  // means half of the window will be outside of the frame. As such the actual
  // buffer that we need to subtract the avg from will be 2 * wiener_halfwin
  // wider and 2 * wiener_halfwin higher than the original dgd buffer.
  const int vert_offset = v_start - wiener_halfwin;
  const int horiz_offset = h_start - wiener_halfwin;
  const uint16_t *dgd_win = dgd + horiz_offset + vert_offset * dgd_stride;
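  // For example (illustrative values): with wiener_win == WIENER_WIN (7) the
  // half window is 3, so dgd_win starts 3 rows above and 3 columns to the left
  // of dgd_start, i.e. dgd_win == dgd_start - 3 * dgd_stride - 3.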

  uint16_t avg = highbd_find_average_neon(dgd_start, dgd_stride, width, height);

  if (wiener_win == WIENER_WIN) {
    highbd_compute_stats_win7_neon(dgd_win, src_start, avg, width, height,
                                   dgd_stride, src_stride, M, H, bit_depth);
  } else {
    highbd_compute_stats_win5_neon(dgd_win, src_start, avg, width, height,
                                   dgd_stride, src_stride, M, H, bit_depth);
  }
}

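// In scalar form (for reference only), the fully-enabled case below
// (r[0] > 0 && r[1] > 0) accumulates per pixel:
//   u = dat << SGRPROJ_RST_BITS
//   v = xq[0] * (flt0 - u) + xq[1] * (flt1 - u)
//   e = dat - src + ((v + (1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)))
//                    >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS))
//   sse += e * e
// The single-filter and no-filter branches are simplifications of the same
// expression.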
int64_t av1_highbd_pixel_proj_error_neon(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
  int64_t sse = 0;
  int64x2_t sse_s64 = vdupq_n_s64(0);

  if (params->r[0] > 0 && params->r[1] > 0) {
    int32x2_t xq_v = vld1_s32(xq);
    int32x2_t xq_sum_v = vshl_n_s32(vpadd_s32(xq_v, xq_v), 4);

    do {
      int j = 0;
      int32x4_t sse_s32 = vdupq_n_s32(0);

      do {
        const uint16x8_t d = vld1q_u16(&dat[j]);
        const uint16x8_t s = vld1q_u16(&src[j]);
        int32x4_t flt0_0 = vld1q_s32(&flt0[j]);
        int32x4_t flt0_1 = vld1q_s32(&flt0[j + 4]);
        int32x4_t flt1_0 = vld1q_s32(&flt1[j]);
        int32x4_t flt1_1 = vld1q_s32(&flt1[j + 4]);

        int32x4_t d_s32_lo = vreinterpretq_s32_u32(
            vmull_lane_u16(vget_low_u16(d), vreinterpret_u16_s32(xq_sum_v), 0));
        int32x4_t d_s32_hi = vreinterpretq_s32_u32(vmull_lane_u16(
            vget_high_u16(d), vreinterpret_u16_s32(xq_sum_v), 0));

        int32x4_t v0 = vsubq_s32(
            vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)),
            d_s32_lo);
        int32x4_t v1 = vsubq_s32(
            vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)),
            d_s32_hi);

        v0 = vmlaq_lane_s32(v0, flt0_0, xq_v, 0);
        v1 = vmlaq_lane_s32(v1, flt0_1, xq_v, 0);
        v0 = vmlaq_lane_s32(v0, flt1_0, xq_v, 1);
        v1 = vmlaq_lane_s32(v1, flt1_1, xq_v, 1);

        int16x4_t vr0 = vshrn_n_s32(v0, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);
        int16x4_t vr1 = vshrn_n_s32(v1, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);

        int16x8_t e = vaddq_s16(vcombine_s16(vr0, vr1),
                                vreinterpretq_s16_u16(vsubq_u16(d, s)));
        int16x4_t e_lo = vget_low_s16(e);
        int16x4_t e_hi = vget_high_s16(e);

        sse_s32 = vmlal_s16(sse_s32, e_lo, e_lo);
        sse_s32 = vmlal_s16(sse_s32, e_hi, e_hi);

        j += 8;
      } while (j <= width - 8);

      for (int k = j; k < width; ++k) {
        int32_t v = 1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1);
        v += xq[0] * (flt0[k]) + xq[1] * (flt1[k]);
        v -= (xq[1] + xq[0]) * (int32_t)(dat[k] << 4);
        int32_t e =
            (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + dat[k] - src[k];
        sse += ((int64_t)e * e);
      }

      sse_s64 = vpadalq_s32(sse_s64, sse_s32);

      dat += dat_stride;
      src += src_stride;
      flt0 += flt0_stride;
      flt1 += flt1_stride;
    } while (--height != 0);
  } else if (params->r[0] > 0 || params->r[1] > 0) {
    int xq_active = (params->r[0] > 0) ? xq[0] : xq[1];
    int32_t *flt = (params->r[0] > 0) ? flt0 : flt1;
    int flt_stride = (params->r[0] > 0) ? flt0_stride : flt1_stride;
    int32x4_t xq_v = vdupq_n_s32(xq_active);

    do {
      int j = 0;
      int32x4_t sse_s32 = vdupq_n_s32(0);
      do {
        const uint16x8_t d0 = vld1q_u16(&dat[j]);
        const uint16x8_t s0 = vld1q_u16(&src[j]);
        int32x4_t flt0_0 = vld1q_s32(&flt[j]);
        int32x4_t flt0_1 = vld1q_s32(&flt[j + 4]);

        uint16x8_t d_u16 = vshlq_n_u16(d0, 4);
        int32x4_t sub0 = vreinterpretq_s32_u32(
            vsubw_u16(vreinterpretq_u32_s32(flt0_0), vget_low_u16(d_u16)));
        int32x4_t sub1 = vreinterpretq_s32_u32(
            vsubw_u16(vreinterpretq_u32_s32(flt0_1), vget_high_u16(d_u16)));

        int32x4_t v0 = vmlaq_s32(
            vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)), sub0,
            xq_v);
        int32x4_t v1 = vmlaq_s32(
            vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)), sub1,
            xq_v);

        int16x4_t vr0 = vshrn_n_s32(v0, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);
        int16x4_t vr1 = vshrn_n_s32(v1, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);

        int16x8_t e = vaddq_s16(vcombine_s16(vr0, vr1),
                                vreinterpretq_s16_u16(vsubq_u16(d0, s0)));
        int16x4_t e_lo = vget_low_s16(e);
        int16x4_t e_hi = vget_high_s16(e);

        sse_s32 = vmlal_s16(sse_s32, e_lo, e_lo);
        sse_s32 = vmlal_s16(sse_s32, e_hi, e_hi);

        j += 8;
      } while (j <= width - 8);

      for (int k = j; k < width; ++k) {
        int32_t v = 1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1);
        v += xq_active * (int32_t)((uint32_t)flt[k] - (uint16_t)(dat[k] << 4));
        const int32_t e =
            (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + dat[k] - src[k];
        sse += ((int64_t)e * e);
      }

      sse_s64 = vpadalq_s32(sse_s64, sse_s32);

      dat += dat_stride;
      flt += flt_stride;
      src += src_stride;
    } while (--height != 0);
  } else {
    do {
      int j = 0;

      do {
        const uint16x8_t d = vld1q_u16(&dat[j]);
        const uint16x8_t s = vld1q_u16(&src[j]);

        uint16x8_t diff = vabdq_u16(d, s);
        uint16x4_t diff_lo = vget_low_u16(diff);
        uint16x4_t diff_hi = vget_high_u16(diff);

        uint32x4_t sqr_lo = vmull_u16(diff_lo, diff_lo);
        uint32x4_t sqr_hi = vmull_u16(diff_hi, diff_hi);

        sse_s64 = vpadalq_s32(sse_s64, vreinterpretq_s32_u32(sqr_lo));
        sse_s64 = vpadalq_s32(sse_s64, vreinterpretq_s32_u32(sqr_hi));

        j += 8;
      } while (j <= width - 8);

      for (int k = j; k < width; ++k) {
        int32_t e = dat[k] - src[k];
        sse += e * e;
      }

      dat += dat_stride;
      src += src_stride;
    } while (--height != 0);
  }

  sse += horizontal_add_s64x2(sse_s64);
  return sse;
}