/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom_dsp/arm/sum_neon.h"
#include "av1/common/restoration.h"
#include "av1/encoder/arm/neon/pickrst_neon.h"
#include "av1/encoder/pickrst.h"
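
// Mirrors the scalar model used in the tail loops below: with
// u = dat[k] << SGRPROJ_RST_BITS, the per-pixel error is
// e = ROUND_POWER_OF_TWO(xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u),
//                        SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) + dat[k] - src[k]
// when both filters are enabled, with only the active term when a single
// filter is enabled, and just dat[k] - src[k] when neither is. The function
// returns the sum of e * e over the width x height block.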
int64_t av1_lowbd_pixel_proj_error_neon(
    const uint8_t *src, int width, int height, int src_stride,
    const uint8_t *dat, int dat_stride, int32_t *flt0, int flt0_stride,
    int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) {
  int64_t sse = 0;
  int64x2_t sse_s64 = vdupq_n_s64(0);

  if (params->r[0] > 0 && params->r[1] > 0) {
    int32x2_t xq_v = vld1_s32(xq);
    int32x2_t xq_sum_v = vshl_n_s32(vpadd_s32(xq_v, xq_v), SGRPROJ_RST_BITS);

    do {
      int j = 0;
      int32x4_t sse_s32 = vdupq_n_s32(0);

      do {
        const uint8x8_t d = vld1_u8(&dat[j]);
        const uint8x8_t s = vld1_u8(&src[j]);
        int32x4_t flt0_0 = vld1q_s32(&flt0[j]);
        int32x4_t flt0_1 = vld1q_s32(&flt0[j + 4]);
        int32x4_t flt1_0 = vld1q_s32(&flt1[j]);
        int32x4_t flt1_1 = vld1q_s32(&flt1[j + 4]);

        int32x4_t offset =
            vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1));
        int32x4_t v0 = vmlaq_lane_s32(offset, flt0_0, xq_v, 0);
        int32x4_t v1 = vmlaq_lane_s32(offset, flt0_1, xq_v, 0);

        v0 = vmlaq_lane_s32(v0, flt1_0, xq_v, 1);
        v1 = vmlaq_lane_s32(v1, flt1_1, xq_v, 1);

        int16x8_t d_s16 = vreinterpretq_s16_u16(vmovl_u8(d));
        v0 = vmlsl_lane_s16(v0, vget_low_s16(d_s16),
                            vreinterpret_s16_s32(xq_sum_v), 0);
        v1 = vmlsl_lane_s16(v1, vget_high_s16(d_s16),
                            vreinterpret_s16_s32(xq_sum_v), 0);

        int16x4_t vr0 = vshrn_n_s32(v0, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);
        int16x4_t vr1 = vshrn_n_s32(v1, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);

        int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(d, s));
        int16x8_t e = vaddq_s16(vcombine_s16(vr0, vr1), diff);
        int16x4_t e_lo = vget_low_s16(e);
        int16x4_t e_hi = vget_high_s16(e);

        sse_s32 = vmlal_s16(sse_s32, e_lo, e_lo);
        sse_s32 = vmlal_s16(sse_s32, e_hi, e_hi);

        j += 8;
      } while (j <= width - 8);

      for (int k = j; k < width; ++k) {
        int32_t u = (dat[k] << SGRPROJ_RST_BITS);
        int32_t v = (1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)) +
                    xq[0] * flt0[k] + xq[1] * flt1[k] - u * (xq[0] + xq[1]);
        int32_t e =
            (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + dat[k] - src[k];
        sse += e * e;
      }

      sse_s64 = vpadalq_s32(sse_s64, sse_s32);

      dat += dat_stride;
      src += src_stride;
      flt0 += flt0_stride;
      flt1 += flt1_stride;
    } while (--height != 0);
  } else if (params->r[0] > 0 || params->r[1] > 0) {
    int xq_active = (params->r[0] > 0) ? xq[0] : xq[1];
    int32_t *flt = (params->r[0] > 0) ? flt0 : flt1;
    int flt_stride = (params->r[0] > 0) ? flt0_stride : flt1_stride;
    int32x2_t xq_v = vdup_n_s32(xq_active);

    do {
      int32x4_t sse_s32 = vdupq_n_s32(0);
      int j = 0;

      do {
        const uint8x8_t d = vld1_u8(&dat[j]);
        const uint8x8_t s = vld1_u8(&src[j]);
        int32x4_t flt_0 = vld1q_s32(&flt[j]);
        int32x4_t flt_1 = vld1q_s32(&flt[j + 4]);
        int16x8_t d_s16 =
            vreinterpretq_s16_u16(vshll_n_u8(d, SGRPROJ_RST_BITS));

        int32x4_t sub_0 = vsubw_s16(flt_0, vget_low_s16(d_s16));
        int32x4_t sub_1 = vsubw_s16(flt_1, vget_high_s16(d_s16));

        int32x4_t offset =
            vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1));
        int32x4_t v0 = vmlaq_lane_s32(offset, sub_0, xq_v, 0);
        int32x4_t v1 = vmlaq_lane_s32(offset, sub_1, xq_v, 0);

        int16x4_t vr0 = vshrn_n_s32(v0, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);
        int16x4_t vr1 = vshrn_n_s32(v1, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);

        int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(d, s));
        int16x8_t e = vaddq_s16(vcombine_s16(vr0, vr1), diff);
        int16x4_t e_lo = vget_low_s16(e);
        int16x4_t e_hi = vget_high_s16(e);

        sse_s32 = vmlal_s16(sse_s32, e_lo, e_lo);
        sse_s32 = vmlal_s16(sse_s32, e_hi, e_hi);

        j += 8;
      } while (j <= width - 8);

      for (int k = j; k < width; ++k) {
        int32_t u = dat[k] << SGRPROJ_RST_BITS;
        int32_t v = xq_active * (flt[k] - u);
        int32_t e = ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) +
                    dat[k] - src[k];
        sse += e * e;
      }

      sse_s64 = vpadalq_s32(sse_s64, sse_s32);

      dat += dat_stride;
      src += src_stride;
      flt += flt_stride;
    } while (--height != 0);
  } else {
    uint32x4_t sse_s32 = vdupq_n_u32(0);

    do {
      int j = 0;

      do {
        const uint8x16_t d = vld1q_u8(&dat[j]);
        const uint8x16_t s = vld1q_u8(&src[j]);

        uint8x16_t diff = vabdq_u8(d, s);
        uint8x8_t diff_lo = vget_low_u8(diff);
        uint8x8_t diff_hi = vget_high_u8(diff);

        sse_s32 = vpadalq_u16(sse_s32, vmull_u8(diff_lo, diff_lo));
        sse_s32 = vpadalq_u16(sse_s32, vmull_u8(diff_hi, diff_hi));

        j += 16;
      } while (j <= width - 16);

      for (int k = j; k < width; ++k) {
        int32_t e = dat[k] - src[k];
        sse += e * e;
      }

      dat += dat_stride;
      src += src_stride;
    } while (--height != 0);

    sse_s64 = vreinterpretq_s64_u64(vpaddlq_u32(sse_s32));
  }

  sse += horizontal_add_s64x2(sse_s64);
  return sse;
}

// We can accumulate up to 65536 8-bit multiplication results in 32-bit. We are
// processing 2 pixels at a time, so the accumulator max can be as high as 32768
// for the compute stats.
#define STAT_ACCUMULATOR_MAX 32768
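
// tbl2/tbl2q select bytes from the 32-byte table formed by concatenating a and
// b, using the per-lane indices in idx (indices >= 16 address b; out-of-range
// indices return 0). On AArch64 this maps directly onto vqtbl2(q)_u8; on Armv7
// it is emulated with vtbl4_u8 over the four 64-bit halves.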
static INLINE uint8x8_t tbl2(uint8x16_t a, uint8x16_t b, uint8x8_t idx) {
#if AOM_ARCH_AARCH64
  uint8x16x2_t table = { { a, b } };
  return vqtbl2_u8(table, idx);
#else
  uint8x8x4_t table = { { vget_low_u8(a), vget_high_u8(a), vget_low_u8(b),
                          vget_high_u8(b) } };
  return vtbl4_u8(table, idx);
#endif
}

static INLINE uint8x16_t tbl2q(uint8x16_t a, uint8x16_t b, uint8x16_t idx) {
#if AOM_ARCH_AARCH64
  uint8x16x2_t table = { { a, b } };
  return vqtbl2q_u8(table, idx);
#else
  uint8x8x4_t table = { { vget_low_u8(a), vget_high_u8(a), vget_low_u8(b),
                          vget_high_u8(b) } };
  return vcombine_u8(vtbl4_u8(table, vget_low_u8(idx)),
                     vtbl4_u8(table, vget_high_u8(idx)));
#endif
}

// The M matrix is accumulated in STAT_ACCUMULATOR_MAX steps to speed-up the
// computation. This function computes the final M from the accumulated
// (src_s64) and the residual parts (src_s32). It also transposes the result as
// the output needs to be column-major.
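// For example, with wiener_win == 3 the value written at row-major position
// (i, j) of dst is read from index j * 3 + i of the accumulators, i.e. dst
// accumulates the transpose of (src_s64 + src_s32), scaled by `scale` (the
// downsample factor).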
static INLINE void acc_transpose_M(int64_t *dst, const int64_t *src_s64,
                                   const int32_t *src_s32, const int wiener_win,
                                   int scale) {
  for (int i = 0; i < wiener_win; ++i) {
    for (int j = 0; j < wiener_win; ++j) {
      int tr_idx = j * wiener_win + i;
      *dst++ += (int64_t)(src_s64[tr_idx] + src_s32[tr_idx]) * scale;
    }
  }
}

// The resulting H is a column-major matrix accumulated from the transposed
// (column-major) samples of the filter kernel (5x5 or 7x7) viewed as a single
// vector. For the 7x7 filter case: H(49x49) = [49 x 1] x [1 x 49]. This
// function transforms back to the originally expected format (double
// transpose). The H matrix is accumulated in STAT_ACCUMULATOR_MAX steps to
// speed-up the computation. This function computes the final H from the
// accumulated (src_s64) and the residual parts (src_s32). The computed H is
// only an upper triangle matrix, this function also fills the lower triangle of
// the resulting matrix.
static void update_H(int64_t *dst, const int64_t *src_s64,
                     const int32_t *src_s32, const int wiener_win, int stride,
                     int scale) {
  // For a simplified theoretical 3x3 case where `wiener_win` is 3 and
  // `wiener_win2` is 9, the M matrix is 3x3:
  // 0, 3, 6
  // 1, 4, 7
  // 2, 5, 8
  //
  // This is viewed as a vector to compute H (9x9) by vector outer product:
  // 0, 3, 6, 1, 4, 7, 2, 5, 8
  //
  // Double transpose and upper triangle remapping for 3x3 -> 9x9 case:
  // 0, 3, 6, 1, 4, 7, 2, 5, 8,
  // 3, 30, 33, 12, 31, 34, 21, 32, 35,
  // 6, 33, 60, 15, 42, 61, 24, 51, 62,
  // 1, 12, 15, 10, 13, 16, 11, 14, 17,
  // 4, 31, 42, 13, 40, 43, 22, 41, 44,
  // 7, 34, 61, 16, 43, 70, 25, 52, 71,
  // 2, 21, 24, 11, 22, 25, 20, 23, 26,
  // 5, 32, 51, 14, 41, 52, 23, 50, 53,
  // 8, 35, 62, 17, 44, 71, 26, 53, 80,
  const int wiener_win2 = wiener_win * wiener_win;

  // Loop through the indices according to the remapping above, along the
  // columns:
  // 0, wiener_win, 2 * wiener_win, ..., 1, 1 + 2 * wiener_win, ...,
  // wiener_win - 1, wiener_win - 1 + wiener_win, ...
  // For the 3x3 case `j` will be: 0, 3, 6, 1, 4, 7, 2, 5, 8.
  for (int i = 0; i < wiener_win; ++i) {
    for (int j = i; j < wiener_win2; j += wiener_win) {
      // These two inner loops are the same as the two outer loops, but running
      // along rows instead of columns. For the 3x3 case `l` will be:
      // 0, 3, 6, 1, 4, 7, 2, 5, 8.
      for (int k = 0; k < wiener_win; ++k) {
        for (int l = k; l < wiener_win2; l += wiener_win) {
          // The nominal double transpose indexing would be:
          // int idx = stride * j + l;
          // However we need the upper-triangle indices, it is easy with some
          // min/max operations.
          int tr_idx = stride * AOMMIN(j, l) + AOMMAX(j, l);

          // Resulting matrix is filled by combining the 64-bit and the residual
          // 32-bit matrices together with scaling.
          *dst++ += (int64_t)(src_s64[tr_idx] + src_s32[tr_idx]) * scale;
        }
      }
    }
  }
}

// Load 7x7 matrix into 3 and a half 128-bit vectors from consecutive rows, the
// last load address is offset to prevent out-of-bounds access.
static INLINE void load_and_pack_u8_8x7(uint8x16_t dst[4], const uint8_t *src,
                                        ptrdiff_t stride) {
  dst[0] = vcombine_u8(vld1_u8(src), vld1_u8(src + stride));
  src += 2 * stride;
  dst[1] = vcombine_u8(vld1_u8(src), vld1_u8(src + stride));
  src += 2 * stride;
  dst[2] = vcombine_u8(vld1_u8(src), vld1_u8(src + stride));
  src += 2 * stride;
  dst[3] = vcombine_u8(vld1_u8(src - 1), vdup_n_u8(0));
}

static INLINE void compute_stats_win7_neon(const uint8_t *dgd,
                                           const uint8_t *src, int width,
                                           int height, int dgd_stride,
                                           int src_stride, int avg, int64_t *M,
                                           int64_t *H, int downsample_factor) {
  // Matrix names are capitalized to help readability.
  DECLARE_ALIGNED(64, int16_t, DGD_AVG0[WIENER_WIN2_ALIGN3]);
  DECLARE_ALIGNED(64, int16_t, DGD_AVG1[WIENER_WIN2_ALIGN3]);
  DECLARE_ALIGNED(64, int32_t, M_s32[WIENER_WIN2_ALIGN3]);
  DECLARE_ALIGNED(64, int64_t, M_s64[WIENER_WIN2_ALIGN3]);
  DECLARE_ALIGNED(64, int32_t, H_s32[WIENER_WIN2 * WIENER_WIN2_ALIGN2]);
  DECLARE_ALIGNED(64, int64_t, H_s64[WIENER_WIN2 * WIENER_WIN2_ALIGN2]);

  memset(M_s32, 0, sizeof(M_s32));
  memset(M_s64, 0, sizeof(M_s64));
  memset(H_s32, 0, sizeof(H_s32));
  memset(H_s64, 0, sizeof(H_s64));

  // Look-up tables to create 8x6 matrix with consecutive elements from two 7x7
  // matrices.
  // clang-format off
  DECLARE_ALIGNED(16, static const uint8_t, shuffle_stats7[96]) = {
    0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 16, 17,
    2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19,
    4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 17, 18, 19, 20, 21, 22,
    1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 17, 18,
    3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20,
    5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22, 23,
  };
  // clang-format on
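  // Each LUT byte indexes into the 32-byte table formed by the two row-pair
  // vectors passed to tbl2q, with indices >= 16 selecting from the second
  // vector. In the third and sixth rows the entries addressing the second
  // vector are shifted up by one to account for the final dgd row being loaded
  // one byte early in load_and_pack_u8_8x7.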

  const uint8x16_t lut0 = vld1q_u8(shuffle_stats7 + 0);
  const uint8x16_t lut1 = vld1q_u8(shuffle_stats7 + 16);
  const uint8x16_t lut2 = vld1q_u8(shuffle_stats7 + 32);
  const uint8x16_t lut3 = vld1q_u8(shuffle_stats7 + 48);
  const uint8x16_t lut4 = vld1q_u8(shuffle_stats7 + 64);
  const uint8x16_t lut5 = vld1q_u8(shuffle_stats7 + 80);

  int acc_cnt = STAT_ACCUMULATOR_MAX;
  const int src_next = downsample_factor * src_stride - width;
  const int dgd_next = downsample_factor * dgd_stride - width;
  const uint8x8_t avg_u8 = vdup_n_u8(avg);

  do {
    int j = width;
    while (j >= 2) {
      // Load two adjacent, overlapping 7x7 matrices: a 8x7 matrix with the
      // middle 6x7 elements being shared.
      uint8x16_t dgd_rows[4];
      load_and_pack_u8_8x7(dgd_rows, dgd, dgd_stride);

      const uint8_t *dgd_ptr = dgd + dgd_stride * 6;
      dgd += 2;

      // Re-arrange (and widen) the combined 8x7 matrix to have the 2 whole 7x7
      // matrices (1 for each of the 2 pixels) separated into distinct
      // int16x8_t[6] arrays. These arrays contain 48 elements of the 49 (7x7).
      // Compute `dgd - avg` for both buffers. Each DGD_AVG buffer contains 49
      // consecutive elements.
      int16x8_t dgd_avg0[6];
      int16x8_t dgd_avg1[6];
      uint8x16_t dgd_shuf0 = tbl2q(dgd_rows[0], dgd_rows[1], lut0);
      uint8x16_t dgd_shuf3 = tbl2q(dgd_rows[0], dgd_rows[1], lut3);

      dgd_avg0[0] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf0), avg_u8));
      dgd_avg0[1] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf0), avg_u8));
      dgd_avg1[0] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf3), avg_u8));
      dgd_avg1[1] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf3), avg_u8));

      vst1q_s16(DGD_AVG0, dgd_avg0[0]);
      vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
      vst1q_s16(DGD_AVG1, dgd_avg1[0]);
      vst1q_s16(DGD_AVG1 + 8, dgd_avg1[1]);

      uint8x16_t dgd_shuf1 = tbl2q(dgd_rows[1], dgd_rows[2], lut1);
      uint8x16_t dgd_shuf4 = tbl2q(dgd_rows[1], dgd_rows[2], lut4);

      dgd_avg0[2] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf1), avg_u8));
      dgd_avg0[3] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf1), avg_u8));
      dgd_avg1[2] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf4), avg_u8));
      dgd_avg1[3] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf4), avg_u8));

      vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
      vst1q_s16(DGD_AVG0 + 24, dgd_avg0[3]);
      vst1q_s16(DGD_AVG1 + 16, dgd_avg1[2]);
      vst1q_s16(DGD_AVG1 + 24, dgd_avg1[3]);

      uint8x16_t dgd_shuf2 = tbl2q(dgd_rows[2], dgd_rows[3], lut2);
      uint8x16_t dgd_shuf5 = tbl2q(dgd_rows[2], dgd_rows[3], lut5);

      dgd_avg0[4] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf2), avg_u8));
      dgd_avg0[5] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf2), avg_u8));
      dgd_avg1[4] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf5), avg_u8));
      dgd_avg1[5] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf5), avg_u8));

      vst1q_s16(DGD_AVG0 + 32, dgd_avg0[4]);
      vst1q_s16(DGD_AVG0 + 40, dgd_avg0[5]);
      vst1q_s16(DGD_AVG1 + 32, dgd_avg1[4]);
      vst1q_s16(DGD_AVG1 + 40, dgd_avg1[5]);

      // The remaining last (49th) elements of `dgd - avg`.
      DGD_AVG0[48] = dgd_ptr[6] - avg;
      DGD_AVG1[48] = dgd_ptr[7] - avg;

      // Accumulate into row-major variant of matrix M (cross-correlation) for 2
      // output pixels at a time. M is of size 7 * 7. It needs to be filled such
      // that multiplying one element from src with each element of a row of the
      // wiener window will fill one column of M. However this is not very
      // convenient in terms of memory access, as it means we do contiguous
      // loads of dgd but strided stores to M. As a result, we use an
      // intermediate matrix M_s32 which is instead filled such that one row of
      // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
      // then transposed to return M.
      int src_avg0 = *src++ - avg;
      int src_avg1 = *src++ - avg;
      int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
      int16x4_t src_avg1_s16 = vdup_n_s16(src_avg1);
      update_M_2pixels(M_s32 + 0, src_avg0_s16, src_avg1_s16, dgd_avg0[0],
                       dgd_avg1[0]);
      update_M_2pixels(M_s32 + 8, src_avg0_s16, src_avg1_s16, dgd_avg0[1],
                       dgd_avg1[1]);
      update_M_2pixels(M_s32 + 16, src_avg0_s16, src_avg1_s16, dgd_avg0[2],
                       dgd_avg1[2]);
      update_M_2pixels(M_s32 + 24, src_avg0_s16, src_avg1_s16, dgd_avg0[3],
                       dgd_avg1[3]);
      update_M_2pixels(M_s32 + 32, src_avg0_s16, src_avg1_s16, dgd_avg0[4],
                       dgd_avg1[4]);
      update_M_2pixels(M_s32 + 40, src_avg0_s16, src_avg1_s16, dgd_avg0[5],
                       dgd_avg1[5]);

      // Last (49th) element of M_s32 can be computed as scalar more efficiently
      // for 2 output pixels.
      M_s32[48] += DGD_AVG0[48] * src_avg0 + DGD_AVG1[48] * src_avg1;

      // Start accumulating into row-major version of matrix H
      // (auto-covariance), it expects the DGD_AVG[01] matrices to also be
      // row-major. H is of size 49 * 49. It is filled by multiplying every pair
      // of elements of the wiener window together (vector outer product). Since
      // it is a symmetric matrix, we only compute the upper-right triangle, and
      // then copy it down to the lower-left later. The upper triangle is
      // covered by 4x4 tiles. The original algorithm assumes the M matrix is
      // column-major and the resulting H matrix is also expected to be
      // column-major. It is not efficient to work with column-major matrices,
      // so we accumulate into a row-major matrix H_s32. At the end of the
      // algorithm a double transpose transformation will convert H_s32 back to
      // the expected output layout.
      update_H_7x7_2pixels(H_s32, DGD_AVG0, DGD_AVG1);

      // The last element of the triangle of H_s32 matrix can be computed as a
      // scalar more efficiently.
      H_s32[48 * WIENER_WIN2_ALIGN2 + 48] +=
          DGD_AVG0[48] * DGD_AVG0[48] + DGD_AVG1[48] * DGD_AVG1[48];

      // Accumulate into 64-bit after STAT_ACCUMULATOR_MAX iterations to prevent
      // overflow.
      if (--acc_cnt == 0) {
        acc_cnt = STAT_ACCUMULATOR_MAX;

        accumulate_and_clear(M_s64, M_s32, WIENER_WIN2_ALIGN2);

        // The widening accumulation is only needed for the upper triangle part
        // of the matrix.
        int64_t *lh = H_s64;
        int32_t *lh32 = H_s32;
        for (int k = 0; k < WIENER_WIN2; ++k) {
          // The widening accumulation is only run for the relevant parts
          // (upper-right triangle) in a row 4-element aligned.
          int k4 = k / 4 * 4;
          accumulate_and_clear(lh + k4, lh32 + k4, 48 - k4);

          // Last element of the row is computed separately.
          lh[48] += lh32[48];
          lh32[48] = 0;

          lh += WIENER_WIN2_ALIGN2;
          lh32 += WIENER_WIN2_ALIGN2;
        }
      }

      j -= 2;
    }

    // Computations for odd pixel in the row.
    if (width & 1) {
      // Load two adjacent, overlapping 7x7 matrices: a 8x7 matrix with the
      // middle 6x7 elements being shared.
      uint8x16_t dgd_rows[4];
      load_and_pack_u8_8x7(dgd_rows, dgd, dgd_stride);

      const uint8_t *dgd_ptr = dgd + dgd_stride * 6;
      ++dgd;

      // Re-arrange (and widen) the combined 8x7 matrix to have a whole 7x7
      // matrix tightly packed into a int16x8_t[6] array. This array contains
      // 48 elements of the 49 (7x7). Compute `dgd - avg` for the whole buffer.
      // The DGD_AVG buffer contains 49 consecutive elements.
      int16x8_t dgd_avg0[6];
      uint8x16_t dgd_shuf0 = tbl2q(dgd_rows[0], dgd_rows[1], lut0);
      dgd_avg0[0] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf0), avg_u8));
      dgd_avg0[1] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf0), avg_u8));
      vst1q_s16(DGD_AVG0, dgd_avg0[0]);
      vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);

      uint8x16_t dgd_shuf1 = tbl2q(dgd_rows[1], dgd_rows[2], lut1);
      dgd_avg0[2] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf1), avg_u8));
      dgd_avg0[3] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf1), avg_u8));
      vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
      vst1q_s16(DGD_AVG0 + 24, dgd_avg0[3]);

      uint8x16_t dgd_shuf2 = tbl2q(dgd_rows[2], dgd_rows[3], lut2);
      dgd_avg0[4] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf2), avg_u8));
      dgd_avg0[5] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf2), avg_u8));
      vst1q_s16(DGD_AVG0 + 32, dgd_avg0[4]);
      vst1q_s16(DGD_AVG0 + 40, dgd_avg0[5]);

      // The remaining last (49th) element of `dgd - avg`.
      DGD_AVG0[48] = dgd_ptr[6] - avg;

      // Accumulate into row-major order variant of matrix M (cross-correlation)
      // for 1 output pixel at a time. M is of size 7 * 7. It needs to be filled
      // such that multiplying one element from src with each element of a row
      // of the wiener window will fill one column of M. However this is not
      // very convenient in terms of memory access, as it means we do
      // contiguous loads of dgd but strided stores to M. As a result, we use an
      // intermediate matrix M_s32 which is instead filled such that one row of
      // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
      // then transposed to return M.
      int src_avg0 = *src++ - avg;
      int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
      update_M_1pixel(M_s32 + 0, src_avg0_s16, dgd_avg0[0]);
      update_M_1pixel(M_s32 + 8, src_avg0_s16, dgd_avg0[1]);
      update_M_1pixel(M_s32 + 16, src_avg0_s16, dgd_avg0[2]);
      update_M_1pixel(M_s32 + 24, src_avg0_s16, dgd_avg0[3]);
      update_M_1pixel(M_s32 + 32, src_avg0_s16, dgd_avg0[4]);
      update_M_1pixel(M_s32 + 40, src_avg0_s16, dgd_avg0[5]);

      // Last (49th) element of M_s32 can be computed as scalar more efficiently
      // for 1 output pixel.
      M_s32[48] += DGD_AVG0[48] * src_avg0;

      // Start accumulating into row-major order version of matrix H
      // (auto-covariance), it expects the DGD_AVG0 matrix to also be row-major.
      // H is of size 49 * 49. It is filled by multiplying every pair of
      // elements of the wiener window together (vector outer product). Since it
      // is a symmetric matrix, we only compute the upper-right triangle, and
      // then copy it down to the lower-left later. The upper triangle is
      // covered by 4x4 tiles. The original algorithm assumes the M matrix is
      // column-major and the resulting H matrix is also expected to be
      // column-major. It is not efficient to work with column-major matrices,
      // so we accumulate into a row-major matrix H_s32. At the end of the
      // algorithm a double transpose transformation will convert H_s32 back to
      // the expected output layout.
      update_H_1pixel(H_s32, DGD_AVG0, WIENER_WIN2_ALIGN2, 48);

      // The last element of the triangle of H_s32 matrix can be computed as
      // scalar more efficiently.
      H_s32[48 * WIENER_WIN2_ALIGN2 + 48] += DGD_AVG0[48] * DGD_AVG0[48];
    }

    src += src_next;
    dgd += dgd_next;
  } while (--height != 0);

  acc_transpose_M(M, M_s64, M_s32, WIENER_WIN, downsample_factor);

  update_H(H, H_s64, H_s32, WIENER_WIN, WIENER_WIN2_ALIGN2, downsample_factor);
}

// Load 5x5 matrix into 2 and a half 128-bit vectors from consecutive rows, the
// last load address is offset to prevent out-of-bounds access.
static INLINE void load_and_pack_u8_6x5(uint8x16_t dst[3], const uint8_t *src,
                                        ptrdiff_t stride) {
  dst[0] = vcombine_u8(vld1_u8(src), vld1_u8(src + stride));
  src += 2 * stride;
  dst[1] = vcombine_u8(vld1_u8(src), vld1_u8(src + stride));
  src += 2 * stride;
  dst[2] = vcombine_u8(vld1_u8(src - 3), vdup_n_u8(0));
}

static INLINE void compute_stats_win5_neon(const uint8_t *dgd,
                                           const uint8_t *src, int width,
                                           int height, int dgd_stride,
                                           int src_stride, int avg, int64_t *M,
                                           int64_t *H, int downsample_factor) {
  // Matrix names are capitalized to help readability.
  DECLARE_ALIGNED(64, int16_t, DGD_AVG0[WIENER_WIN2_REDUCED_ALIGN3]);
  DECLARE_ALIGNED(64, int16_t, DGD_AVG1[WIENER_WIN2_REDUCED_ALIGN3]);
  DECLARE_ALIGNED(64, int32_t, M_s32[WIENER_WIN2_REDUCED_ALIGN3]);
  DECLARE_ALIGNED(64, int64_t, M_s64[WIENER_WIN2_REDUCED_ALIGN3]);
  DECLARE_ALIGNED(64, int32_t,
                  H_s32[WIENER_WIN2_REDUCED * WIENER_WIN2_REDUCED_ALIGN2]);
  DECLARE_ALIGNED(64, int64_t,
                  H_s64[WIENER_WIN2_REDUCED * WIENER_WIN2_REDUCED_ALIGN2]);

  memset(M_s32, 0, sizeof(M_s32));
  memset(M_s64, 0, sizeof(M_s64));
  memset(H_s32, 0, sizeof(H_s32));
  memset(H_s64, 0, sizeof(H_s64));

  // Look-up tables to create 8x3 matrix with consecutive elements from two 5x5
  // matrices.
  // clang-format off
  DECLARE_ALIGNED(16, static const uint8_t, shuffle_stats5[48]) = {
    0, 1, 2, 3, 4, 8, 9, 10, 11, 12, 16, 17, 18, 19, 20, 24,
    1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 17, 18, 19, 20, 21, 25,
    9, 10, 11, 12, 19, 20, 21, 22, 10, 11, 12, 13, 20, 21, 22, 23,
  };
  // clang-format on
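  // As in the 7x7 case, each LUT byte indexes into the 32-byte table formed by
  // the two vectors passed to tbl2q/tbl2, with indices >= 16 selecting from the
  // second vector. In the last row the entries addressing the second vector are
  // offset by three to match the final dgd row being loaded three bytes early
  // in load_and_pack_u8_6x5.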

  const uint8x16_t lut0 = vld1q_u8(shuffle_stats5 + 0);
  const uint8x16_t lut1 = vld1q_u8(shuffle_stats5 + 16);
  const uint8x16_t lut2 = vld1q_u8(shuffle_stats5 + 32);

  int acc_cnt = STAT_ACCUMULATOR_MAX;
  const int src_next = downsample_factor * src_stride - width;
  const int dgd_next = downsample_factor * dgd_stride - width;
  const uint8x8_t avg_u8 = vdup_n_u8(avg);

  do {
    int j = width;
    while (j >= 2) {
      // Load two adjacent, overlapping 5x5 matrices: a 6x5 matrix with the
      // middle 4x5 elements being shared.
      uint8x16_t dgd_rows[3];
      load_and_pack_u8_6x5(dgd_rows, dgd, dgd_stride);

      const uint8_t *dgd_ptr = dgd + dgd_stride * 4;
      dgd += 2;

      // Re-arrange (and widen) the combined 6x5 matrix to have the 2 whole 5x5
      // matrices (1 for each of the 2 pixels) separated into distinct
      // int16x8_t[3] arrays. These arrays contain 24 elements of the 25 (5x5).
      // Compute `dgd - avg` for both buffers. Each DGD_AVG buffer contains 25
      // consecutive elements.
      int16x8_t dgd_avg0[3];
      int16x8_t dgd_avg1[3];
      uint8x16_t dgd_shuf0 = tbl2q(dgd_rows[0], dgd_rows[1], lut0);
      uint8x16_t dgd_shuf1 = tbl2q(dgd_rows[0], dgd_rows[1], lut1);
      uint8x16_t dgd_shuf2 = tbl2q(dgd_rows[1], dgd_rows[2], lut2);

      dgd_avg0[0] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf0), avg_u8));
      dgd_avg0[1] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf0), avg_u8));
      dgd_avg0[2] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf2), avg_u8));
      dgd_avg1[0] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf1), avg_u8));
      dgd_avg1[1] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf1), avg_u8));
      dgd_avg1[2] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf2), avg_u8));

      vst1q_s16(DGD_AVG0 + 0, dgd_avg0[0]);
      vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
      vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
      vst1q_s16(DGD_AVG1 + 0, dgd_avg1[0]);
      vst1q_s16(DGD_AVG1 + 8, dgd_avg1[1]);
      vst1q_s16(DGD_AVG1 + 16, dgd_avg1[2]);

      // The remaining last (25th) elements of `dgd - avg`.
      DGD_AVG0[24] = dgd_ptr[4] - avg;
      DGD_AVG1[24] = dgd_ptr[5] - avg;

      // Accumulate into row-major variant of matrix M (cross-correlation) for 2
      // output pixels at a time. M is of size 5 * 5. It needs to be filled such
      // that multiplying one element from src with each element of a row of the
      // wiener window will fill one column of M. However this is not very
      // convenient in terms of memory access, as it means we do contiguous
      // loads of dgd but strided stores to M. As a result, we use an
      // intermediate matrix M_s32 which is instead filled such that one row of
      // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
      // then transposed to return M.
      int src_avg0 = *src++ - avg;
      int src_avg1 = *src++ - avg;
      int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
      int16x4_t src_avg1_s16 = vdup_n_s16(src_avg1);
      update_M_2pixels(M_s32 + 0, src_avg0_s16, src_avg1_s16, dgd_avg0[0],
                       dgd_avg1[0]);
      update_M_2pixels(M_s32 + 8, src_avg0_s16, src_avg1_s16, dgd_avg0[1],
                       dgd_avg1[1]);
      update_M_2pixels(M_s32 + 16, src_avg0_s16, src_avg1_s16, dgd_avg0[2],
                       dgd_avg1[2]);

      // Last (25th) element of M_s32 can be computed as scalar more efficiently
      // for 2 output pixels.
      M_s32[24] += DGD_AVG0[24] * src_avg0 + DGD_AVG1[24] * src_avg1;

      // Start accumulating into row-major version of matrix H
      // (auto-covariance), it expects the DGD_AVG[01] matrices to also be
      // row-major. H is of size 25 * 25. It is filled by multiplying every pair
      // of elements of the wiener window together (vector outer product). Since
      // it is a symmetric matrix, we only compute the upper-right triangle, and
      // then copy it down to the lower-left later. The upper triangle is
      // covered by 4x4 tiles. The original algorithm assumes the M matrix is
      // column-major and the resulting H matrix is also expected to be
      // column-major. It is not efficient to work with column-major matrices,
      // so we accumulate into a row-major matrix H_s32. At the end of the
      // algorithm a double transpose transformation will convert H_s32 back to
      // the expected output layout.
      update_H_5x5_2pixels(H_s32, DGD_AVG0, DGD_AVG1);

      // The last element of the triangle of H_s32 matrix can be computed as a
      // scalar more efficiently.
      H_s32[24 * WIENER_WIN2_REDUCED_ALIGN2 + 24] +=
          DGD_AVG0[24] * DGD_AVG0[24] + DGD_AVG1[24] * DGD_AVG1[24];

      // Accumulate into 64-bit after STAT_ACCUMULATOR_MAX iterations to prevent
      // overflow.
      if (--acc_cnt == 0) {
        acc_cnt = STAT_ACCUMULATOR_MAX;

        accumulate_and_clear(M_s64, M_s32, WIENER_WIN2_REDUCED_ALIGN2);

        // The widening accumulation is only needed for the upper triangle part
        // of the matrix.
        int64_t *lh = H_s64;
        int32_t *lh32 = H_s32;
        for (int k = 0; k < WIENER_WIN2_REDUCED; ++k) {
          // The widening accumulation is only run for the relevant parts
          // (upper-right triangle) in a row 4-element aligned.
          int k4 = k / 4 * 4;
          accumulate_and_clear(lh + k4, lh32 + k4, 24 - k4);

          // Last element of the row is computed separately.
          lh[24] += lh32[24];
          lh32[24] = 0;

          lh += WIENER_WIN2_REDUCED_ALIGN2;
          lh32 += WIENER_WIN2_REDUCED_ALIGN2;
        }
      }

      j -= 2;
    }

    // Computations for odd pixel in the row.
    if (width & 1) {
      // Load two adjacent, overlapping 5x5 matrices: a 6x5 matrix with the
      // middle 4x5 elements being shared.
      uint8x16_t dgd_rows[3];
      load_and_pack_u8_6x5(dgd_rows, dgd, dgd_stride);

      const uint8_t *dgd_ptr = dgd + dgd_stride * 4;
      ++dgd;

      // Re-arrange (and widen) the combined 6x5 matrix to have a whole 5x5
      // matrix tightly packed into a int16x8_t[3] array. This array contains
      // 24 elements of the 25 (5x5). Compute `dgd - avg` for the whole buffer.
      // The DGD_AVG buffer contains 25 consecutive elements.
      int16x8_t dgd_avg0[3];
      uint8x16_t dgd_shuf0 = tbl2q(dgd_rows[0], dgd_rows[1], lut0);
      uint8x8_t dgd_shuf1 = tbl2(dgd_rows[1], dgd_rows[2], vget_low_u8(lut2));

      dgd_avg0[0] =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf0), avg_u8));
      dgd_avg0[1] =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf0), avg_u8));
      dgd_avg0[2] = vreinterpretq_s16_u16(vsubl_u8(dgd_shuf1, avg_u8));

      vst1q_s16(DGD_AVG0 + 0, dgd_avg0[0]);
      vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
      vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);

      // The remaining last (25th) element of `dgd - avg`.
      DGD_AVG0[24] = dgd_ptr[4] - avg;

      // Accumulate into row-major order variant of matrix M (cross-correlation)
      // for 1 output pixel at a time. M is of size 5 * 5. It needs to be filled
      // such that multiplying one element from src with each element of a row
      // of the wiener window will fill one column of M. However this is not
      // very convenient in terms of memory access, as it means we do
      // contiguous loads of dgd but strided stores to M. As a result, we use an
      // intermediate matrix M_s32 which is instead filled such that one row of
      // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
      // then transposed to return M.
      int src_avg0 = *src++ - avg;
      int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
      update_M_1pixel(M_s32 + 0, src_avg0_s16, dgd_avg0[0]);
      update_M_1pixel(M_s32 + 8, src_avg0_s16, dgd_avg0[1]);
      update_M_1pixel(M_s32 + 16, src_avg0_s16, dgd_avg0[2]);

      // Last (25th) element of M_s32 can be computed as scalar more efficiently
      // for 1 output pixel.
      M_s32[24] += DGD_AVG0[24] * src_avg0;

      // Start accumulating into row-major order version of matrix H
      // (auto-covariance), it expects the DGD_AVG0 matrix to also be row-major.
      // H is of size 25 * 25. It is filled by multiplying every pair of
      // elements of the wiener window together (vector outer product). Since it
      // is a symmetric matrix, we only compute the upper-right triangle, and
      // then copy it down to the lower-left later. The upper triangle is
      // covered by 4x4 tiles. The original algorithm assumes the M matrix is
      // column-major and the resulting H matrix is also expected to be
      // column-major. It is not efficient to work with column-major matrices,
      // so we accumulate into a row-major matrix H_s32. At the end of the
      // algorithm a double transpose transformation will convert H_s32 back to
      // the expected output layout.
      update_H_1pixel(H_s32, DGD_AVG0, WIENER_WIN2_REDUCED_ALIGN2, 24);

      // The last element of the triangle of H_s32 matrix can be computed as a
      // scalar more efficiently.
      H_s32[24 * WIENER_WIN2_REDUCED_ALIGN2 + 24] +=
          DGD_AVG0[24] * DGD_AVG0[24];
    }

    src += src_next;
    dgd += dgd_next;
  } while (--height != 0);

  acc_transpose_M(M, M_s64, M_s32, WIENER_WIN_REDUCED, downsample_factor);

  update_H(H, H_s64, H_s32, WIENER_WIN_REDUCED, WIENER_WIN2_REDUCED_ALIGN2,
           downsample_factor);
}
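
// Compute the average of the dgd restoration window; compute_stats_win5/7
// subtract this average from both the dgd and src samples before accumulating
// M and H.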
static INLINE uint8_t find_average_neon(const uint8_t *src, int src_stride,
                                        int width, int height) {
  uint64_t sum = 0;

  if (width >= 16) {
    int h = 0;
    // We can accumulate up to 257 8-bit values in a 16-bit value, given
    // that each 16-bit vector has 8 elements, that means we can process up to
    // int(257*8/width) rows before we need to widen to 32-bit vector
    // elements.
    int h_overflow = 257 * 8 / width;
    int h_limit = height > h_overflow ? h_overflow : height;
    uint32x4_t avg_u32 = vdupq_n_u32(0);
    do {
      uint16x8_t avg_u16 = vdupq_n_u16(0);
      do {
        int j = width;
        const uint8_t *src_ptr = src;
        do {
          uint8x16_t s = vld1q_u8(src_ptr);
          avg_u16 = vpadalq_u8(avg_u16, s);
          j -= 16;
          src_ptr += 16;
        } while (j >= 16);
        if (j >= 8) {
          uint8x8_t s = vld1_u8(src_ptr);
          avg_u16 = vaddw_u8(avg_u16, s);
          j -= 8;
          src_ptr += 8;
        }
        // Scalar tail case.
        while (j > 0) {
          sum += src[width - j];
          j--;
        }
        src += src_stride;
      } while (++h < h_limit);
      avg_u32 = vpadalq_u16(avg_u32, avg_u16);

      h_limit += h_overflow;
      h_limit = height > h_limit ? h_limit : height;
    } while (h < height);
    return (uint8_t)((horizontal_long_add_u32x4(avg_u32) + sum) /
                     (width * height));
  }
  if (width >= 8) {
    int h = 0;
    // We can accumulate up to 257 8-bit values in a 16-bit value, given
    // that each 16-bit vector has 4 elements, that means we can process up to
    // int(257*4/width) rows before we need to widen to 32-bit vector
    // elements.
    int h_overflow = 257 * 4 / width;
    int h_limit = height > h_overflow ? h_overflow : height;
    uint32x2_t avg_u32 = vdup_n_u32(0);
    do {
      uint16x4_t avg_u16 = vdup_n_u16(0);
      do {
        int j = width;
        const uint8_t *src_ptr = src;
        uint8x8_t s = vld1_u8(src_ptr);
        avg_u16 = vpadal_u8(avg_u16, s);
        j -= 8;
        src_ptr += 8;
        // Scalar tail case.
        while (j > 0) {
          sum += src[width - j];
          j--;
        }
        src += src_stride;
      } while (++h < h_limit);
      avg_u32 = vpadal_u16(avg_u32, avg_u16);

      h_limit += h_overflow;
      h_limit = height > h_limit ? h_limit : height;
    } while (h < height);
    return (uint8_t)((horizontal_long_add_u32x2(avg_u32) + sum) /
                     (width * height));
  }
  int i = height;
  do {
    int j = 0;
    do {
      sum += src[j];
    } while (++j < width);
    src += src_stride;
  } while (--i != 0);
  return (uint8_t)(sum / (width * height));
}
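
// Computes the Wiener filter normal-equation statistics: M accumulates the
// correlation between the mean-subtracted source and each position of the
// mean-subtracted dgd window, and H accumulates the auto-covariance between
// pairs of window positions. With downsampling enabled, each accumulated row
// is weighted by the downsample factor (the remainder row by the number of
// rows it stands for).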
void av1_compute_stats_neon(int wiener_win, const uint8_t *dgd,
                            const uint8_t *src, int16_t *dgd_avg,
                            int16_t *src_avg, int h_start, int h_end,
                            int v_start, int v_end, int dgd_stride,
                            int src_stride, int64_t *M, int64_t *H,
                            int use_downsampled_wiener_stats) {
  assert(wiener_win == WIENER_WIN || wiener_win == WIENER_WIN_CHROMA);
  assert(WIENER_STATS_DOWNSAMPLE_FACTOR == 4);
  (void)dgd_avg;
  (void)src_avg;

  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = wiener_win >> 1;
  const int width = h_end - h_start;
  const int height = v_end - v_start;

  const uint8_t *dgd_start = dgd + h_start + v_start * dgd_stride;
  const uint8_t *src_start = src + h_start + v_start * src_stride;

  // The wiener window will slide along the dgd frame, centered on each pixel.
  // For the top left pixel and all the pixels on the side of the frame this
  // means half of the window will be outside of the frame. As such the actual
  // buffer that we need to subtract the avg from will be 2 * wiener_halfwin
  // wider and 2 * wiener_halfwin higher than the original dgd buffer.
  const int vert_offset = v_start - wiener_halfwin;
  const int horiz_offset = h_start - wiener_halfwin;
  const uint8_t *dgd_win = dgd + horiz_offset + vert_offset * dgd_stride;

  uint8_t avg = find_average_neon(dgd_start, dgd_stride, width, height);

  // Since the height is not necessarily a multiple of the downsample factor,
  // the last line of src will be scaled according to how many rows remain.
  int downsample_factor =
      use_downsampled_wiener_stats ? WIENER_STATS_DOWNSAMPLE_FACTOR : 1;

  int downsampled_height = height / downsample_factor;
  int downsample_remainder = height % downsample_factor;

  memset(M, 0, wiener_win2 * sizeof(*M));
  memset(H, 0, wiener_win2 * wiener_win2 * sizeof(*H));

  // Calculate the M and H matrices for the normal and downsampled cases.
  if (downsampled_height > 0) {
    if (wiener_win == WIENER_WIN) {
      compute_stats_win7_neon(dgd_win, src_start, width, downsampled_height,
                              dgd_stride, src_stride, avg, M, H,
                              downsample_factor);
    } else {
      compute_stats_win5_neon(dgd_win, src_start, width, downsampled_height,
                              dgd_stride, src_stride, avg, M, H,
                              downsample_factor);
    }
  }

  // Accumulate the remaining last rows in the downsampled case.
  if (downsample_remainder > 0) {
    int remainder_offset = height - downsample_remainder;
    if (wiener_win == WIENER_WIN) {
      compute_stats_win7_neon(dgd_win + remainder_offset * dgd_stride,
                              src_start + remainder_offset * src_stride, width,
                              1, dgd_stride, src_stride, avg, M, H,
                              downsample_remainder);
    } else {
      compute_stats_win5_neon(dgd_win + remainder_offset * dgd_stride,
                              src_start + remainder_offset * src_stride, width,
                              1, dgd_stride, src_stride, avg, M, H,
                              downsample_remainder);
    }
  }
}
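
// Computes the normal equations for the self-guided projection when both
// filters are enabled. With u = dat8[k] << SGRPROJ_RST_BITS,
// s = (src8[k] << SGRPROJ_RST_BITS) - u and fi = flti[k] - u, this accumulates
// H[i][j] = sum(fi * fj) / size and C[i] = sum(fi * s) / size.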
static INLINE void calc_proj_params_r0_r1_neon(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int32_t *flt1, int flt1_stride, int64_t H[2][2], int64_t C[2]) {
  assert(width % 8 == 0);
  const int size = width * height;

  int64x2_t h00_lo = vdupq_n_s64(0);
  int64x2_t h00_hi = vdupq_n_s64(0);
  int64x2_t h11_lo = vdupq_n_s64(0);
  int64x2_t h11_hi = vdupq_n_s64(0);
  int64x2_t h01_lo = vdupq_n_s64(0);
  int64x2_t h01_hi = vdupq_n_s64(0);
  int64x2_t c0_lo = vdupq_n_s64(0);
  int64x2_t c0_hi = vdupq_n_s64(0);
  int64x2_t c1_lo = vdupq_n_s64(0);
  int64x2_t c1_hi = vdupq_n_s64(0);

  do {
    const uint8_t *src_ptr = src8;
    const uint8_t *dat_ptr = dat8;
    int32_t *flt0_ptr = flt0;
    int32_t *flt1_ptr = flt1;
    int w = width;

    do {
      uint8x8_t s = vld1_u8(src_ptr);
      uint8x8_t d = vld1_u8(dat_ptr);
      int32x4_t f0_lo = vld1q_s32(flt0_ptr);
      int32x4_t f0_hi = vld1q_s32(flt0_ptr + 4);
      int32x4_t f1_lo = vld1q_s32(flt1_ptr);
      int32x4_t f1_hi = vld1q_s32(flt1_ptr + 4);

      int16x8_t u = vreinterpretq_s16_u16(vshll_n_u8(d, SGRPROJ_RST_BITS));
      int16x8_t s_s16 = vreinterpretq_s16_u16(vshll_n_u8(s, SGRPROJ_RST_BITS));

      int32x4_t s_lo = vsubl_s16(vget_low_s16(s_s16), vget_low_s16(u));
      int32x4_t s_hi = vsubl_s16(vget_high_s16(s_s16), vget_high_s16(u));
      f0_lo = vsubw_s16(f0_lo, vget_low_s16(u));
      f0_hi = vsubw_s16(f0_hi, vget_high_s16(u));
      f1_lo = vsubw_s16(f1_lo, vget_low_s16(u));
      f1_hi = vsubw_s16(f1_hi, vget_high_s16(u));

      h00_lo = vmlal_s32(h00_lo, vget_low_s32(f0_lo), vget_low_s32(f0_lo));
      h00_lo = vmlal_s32(h00_lo, vget_high_s32(f0_lo), vget_high_s32(f0_lo));
      h00_hi = vmlal_s32(h00_hi, vget_low_s32(f0_hi), vget_low_s32(f0_hi));
      h00_hi = vmlal_s32(h00_hi, vget_high_s32(f0_hi), vget_high_s32(f0_hi));

      h11_lo = vmlal_s32(h11_lo, vget_low_s32(f1_lo), vget_low_s32(f1_lo));
      h11_lo = vmlal_s32(h11_lo, vget_high_s32(f1_lo), vget_high_s32(f1_lo));
      h11_hi = vmlal_s32(h11_hi, vget_low_s32(f1_hi), vget_low_s32(f1_hi));
      h11_hi = vmlal_s32(h11_hi, vget_high_s32(f1_hi), vget_high_s32(f1_hi));

      h01_lo = vmlal_s32(h01_lo, vget_low_s32(f0_lo), vget_low_s32(f1_lo));
      h01_lo = vmlal_s32(h01_lo, vget_high_s32(f0_lo), vget_high_s32(f1_lo));
      h01_hi = vmlal_s32(h01_hi, vget_low_s32(f0_hi), vget_low_s32(f1_hi));
      h01_hi = vmlal_s32(h01_hi, vget_high_s32(f0_hi), vget_high_s32(f1_hi));

      c0_lo = vmlal_s32(c0_lo, vget_low_s32(f0_lo), vget_low_s32(s_lo));
      c0_lo = vmlal_s32(c0_lo, vget_high_s32(f0_lo), vget_high_s32(s_lo));
      c0_hi = vmlal_s32(c0_hi, vget_low_s32(f0_hi), vget_low_s32(s_hi));
      c0_hi = vmlal_s32(c0_hi, vget_high_s32(f0_hi), vget_high_s32(s_hi));

      c1_lo = vmlal_s32(c1_lo, vget_low_s32(f1_lo), vget_low_s32(s_lo));
      c1_lo = vmlal_s32(c1_lo, vget_high_s32(f1_lo), vget_high_s32(s_lo));
      c1_hi = vmlal_s32(c1_hi, vget_low_s32(f1_hi), vget_low_s32(s_hi));
      c1_hi = vmlal_s32(c1_hi, vget_high_s32(f1_hi), vget_high_s32(s_hi));

      src_ptr += 8;
      dat_ptr += 8;
      flt0_ptr += 8;
      flt1_ptr += 8;
      w -= 8;
    } while (w != 0);

    src8 += src_stride;
    dat8 += dat_stride;
    flt0 += flt0_stride;
    flt1 += flt1_stride;
  } while (--height != 0);

  H[0][0] = horizontal_add_s64x2(vaddq_s64(h00_lo, h00_hi)) / size;
  H[0][1] = horizontal_add_s64x2(vaddq_s64(h01_lo, h01_hi)) / size;
  H[1][1] = horizontal_add_s64x2(vaddq_s64(h11_lo, h11_hi)) / size;
  H[1][0] = H[0][1];
  C[0] = horizontal_add_s64x2(vaddq_s64(c0_lo, c0_hi)) / size;
  C[1] = horizontal_add_s64x2(vaddq_s64(c1_lo, c1_hi)) / size;
}
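
// As above, but with only the first filter enabled (params->r[0] > 0), so only
// H[0][0] and C[0] are accumulated.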
static INLINE void calc_proj_params_r0_neon(const uint8_t *src8, int width,
                                            int height, int src_stride,
                                            const uint8_t *dat8, int dat_stride,
                                            int32_t *flt0, int flt0_stride,
                                            int64_t H[2][2], int64_t C[2]) {
  assert(width % 8 == 0);
  const int size = width * height;

  int64x2_t h00_lo = vdupq_n_s64(0);
  int64x2_t h00_hi = vdupq_n_s64(0);
  int64x2_t c0_lo = vdupq_n_s64(0);
  int64x2_t c0_hi = vdupq_n_s64(0);

  do {
    const uint8_t *src_ptr = src8;
    const uint8_t *dat_ptr = dat8;
    int32_t *flt0_ptr = flt0;
    int w = width;

    do {
      uint8x8_t s = vld1_u8(src_ptr);
      uint8x8_t d = vld1_u8(dat_ptr);
      int32x4_t f0_lo = vld1q_s32(flt0_ptr);
      int32x4_t f0_hi = vld1q_s32(flt0_ptr + 4);

      int16x8_t u = vreinterpretq_s16_u16(vshll_n_u8(d, SGRPROJ_RST_BITS));
      int16x8_t s_s16 = vreinterpretq_s16_u16(vshll_n_u8(s, SGRPROJ_RST_BITS));

      int32x4_t s_lo = vsubl_s16(vget_low_s16(s_s16), vget_low_s16(u));
      int32x4_t s_hi = vsubl_s16(vget_high_s16(s_s16), vget_high_s16(u));
      f0_lo = vsubw_s16(f0_lo, vget_low_s16(u));
      f0_hi = vsubw_s16(f0_hi, vget_high_s16(u));

      h00_lo = vmlal_s32(h00_lo, vget_low_s32(f0_lo), vget_low_s32(f0_lo));
      h00_lo = vmlal_s32(h00_lo, vget_high_s32(f0_lo), vget_high_s32(f0_lo));
      h00_hi = vmlal_s32(h00_hi, vget_low_s32(f0_hi), vget_low_s32(f0_hi));
      h00_hi = vmlal_s32(h00_hi, vget_high_s32(f0_hi), vget_high_s32(f0_hi));

      c0_lo = vmlal_s32(c0_lo, vget_low_s32(f0_lo), vget_low_s32(s_lo));
      c0_lo = vmlal_s32(c0_lo, vget_high_s32(f0_lo), vget_high_s32(s_lo));
      c0_hi = vmlal_s32(c0_hi, vget_low_s32(f0_hi), vget_low_s32(s_hi));
      c0_hi = vmlal_s32(c0_hi, vget_high_s32(f0_hi), vget_high_s32(s_hi));

      src_ptr += 8;
      dat_ptr += 8;
      flt0_ptr += 8;
      w -= 8;
    } while (w != 0);

    src8 += src_stride;
    dat8 += dat_stride;
    flt0 += flt0_stride;
  } while (--height != 0);

  H[0][0] = horizontal_add_s64x2(vaddq_s64(h00_lo, h00_hi)) / size;
  C[0] = horizontal_add_s64x2(vaddq_s64(c0_lo, c0_hi)) / size;
}
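
// As above, but with only the second filter enabled (params->r[1] > 0), so
// only H[1][1] and C[1] are accumulated.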
static INLINE void calc_proj_params_r1_neon(const uint8_t *src8, int width,
                                            int height, int src_stride,
                                            const uint8_t *dat8, int dat_stride,
                                            int32_t *flt1, int flt1_stride,
                                            int64_t H[2][2], int64_t C[2]) {
  assert(width % 8 == 0);
  const int size = width * height;

  int64x2_t h11_lo = vdupq_n_s64(0);
  int64x2_t h11_hi = vdupq_n_s64(0);
  int64x2_t c1_lo = vdupq_n_s64(0);
  int64x2_t c1_hi = vdupq_n_s64(0);

  do {
    const uint8_t *src_ptr = src8;
    const uint8_t *dat_ptr = dat8;
    int32_t *flt1_ptr = flt1;
    int w = width;

    do {
      uint8x8_t s = vld1_u8(src_ptr);
      uint8x8_t d = vld1_u8(dat_ptr);
      int32x4_t f1_lo = vld1q_s32(flt1_ptr);
      int32x4_t f1_hi = vld1q_s32(flt1_ptr + 4);

      int16x8_t u = vreinterpretq_s16_u16(vshll_n_u8(d, SGRPROJ_RST_BITS));
      int16x8_t s_s16 = vreinterpretq_s16_u16(vshll_n_u8(s, SGRPROJ_RST_BITS));

      int32x4_t s_lo = vsubl_s16(vget_low_s16(s_s16), vget_low_s16(u));
      int32x4_t s_hi = vsubl_s16(vget_high_s16(s_s16), vget_high_s16(u));
      f1_lo = vsubw_s16(f1_lo, vget_low_s16(u));
      f1_hi = vsubw_s16(f1_hi, vget_high_s16(u));

      h11_lo = vmlal_s32(h11_lo, vget_low_s32(f1_lo), vget_low_s32(f1_lo));
      h11_lo = vmlal_s32(h11_lo, vget_high_s32(f1_lo), vget_high_s32(f1_lo));
      h11_hi = vmlal_s32(h11_hi, vget_low_s32(f1_hi), vget_low_s32(f1_hi));
      h11_hi = vmlal_s32(h11_hi, vget_high_s32(f1_hi), vget_high_s32(f1_hi));

      c1_lo = vmlal_s32(c1_lo, vget_low_s32(f1_lo), vget_low_s32(s_lo));
      c1_lo = vmlal_s32(c1_lo, vget_high_s32(f1_lo), vget_high_s32(s_lo));
      c1_hi = vmlal_s32(c1_hi, vget_low_s32(f1_hi), vget_low_s32(s_hi));
      c1_hi = vmlal_s32(c1_hi, vget_high_s32(f1_hi), vget_high_s32(s_hi));

      src_ptr += 8;
      dat_ptr += 8;
      flt1_ptr += 8;
      w -= 8;
    } while (w != 0);

    src8 += src_stride;
    dat8 += dat_stride;
    flt1 += flt1_stride;
  } while (--height != 0);

  H[1][1] = horizontal_add_s64x2(vaddq_s64(h11_lo, h11_hi)) / size;
  C[1] = horizontal_add_s64x2(vaddq_s64(c1_lo, c1_hi)) / size;
}

// The function calls 3 subfunctions for the following cases :
// 1) When params->r[0] > 0 and params->r[1] > 0. In this case all elements
//    of C and H need to be computed.
// 2) When only params->r[0] > 0. In this case only H[0][0] and C[0] are
//    non-zero and need to be computed.
// 3) When only params->r[1] > 0. In this case only H[1][1] and C[1] are
//    non-zero and need to be computed.
void av1_calc_proj_params_neon(const uint8_t *src8, int width, int height,
                               int src_stride, const uint8_t *dat8,
                               int dat_stride, int32_t *flt0, int flt0_stride,
                               int32_t *flt1, int flt1_stride, int64_t H[2][2],
                               int64_t C[2], const sgr_params_type *params) {
  if ((params->r[0] > 0) && (params->r[1] > 0)) {
    calc_proj_params_r0_r1_neon(src8, width, height, src_stride, dat8,
                                dat_stride, flt0, flt0_stride, flt1,
                                flt1_stride, H, C);
  } else if (params->r[0] > 0) {
    calc_proj_params_r0_neon(src8, width, height, src_stride, dat8, dat_stride,
                             flt0, flt0_stride, H, C);
  } else if (params->r[1] > 0) {
    calc_proj_params_r1_neon(src8, width, height, src_stride, dat8, dat_stride,
                             flt1, flt1_stride, H, C);
  }
}