• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include "aom_dsp/flow_estimation/disflow.h"
13 
14 #include <arm_neon.h>
15 #include <math.h>
16 
17 #include "aom_dsp/arm/mem_neon.h"
18 #include "aom_dsp/arm/sum_neon.h"
19 #include "config/aom_config.h"
20 #include "config/aom_dsp_rtcd.h"
21 
get_cubic_kernel_dbl(double x,double kernel[4])22 static INLINE void get_cubic_kernel_dbl(double x, double kernel[4]) {
23   // Check that the fractional position is in range.
24   //
25   // Note: x is calculated from, e.g., `u_frac = u - floor(u)`.
26   // Mathematically, this implies that 0 <= x < 1. However, in practice it is
27   // possible to have x == 1 due to floating point rounding. This is fine,
28   // and we still interpolate correctly if we allow x = 1.
29   assert(0 <= x && x <= 1);
30 
31   double x2 = x * x;
32   double x3 = x2 * x;
33   kernel[0] = -0.5 * x + x2 - 0.5 * x3;
34   kernel[1] = 1.0 - 2.5 * x2 + 1.5 * x3;
35   kernel[2] = 0.5 * x + 2.0 * x2 - 1.5 * x3;
36   kernel[3] = -0.5 * x2 + 0.5 * x3;
37 }
38 
get_cubic_kernel_int(double x,int kernel[4])39 static INLINE void get_cubic_kernel_int(double x, int kernel[4]) {
40   double kernel_dbl[4];
41   get_cubic_kernel_dbl(x, kernel_dbl);
42 
43   kernel[0] = (int)rint(kernel_dbl[0] * (1 << DISFLOW_INTERP_BITS));
44   kernel[1] = (int)rint(kernel_dbl[1] * (1 << DISFLOW_INTERP_BITS));
45   kernel[2] = (int)rint(kernel_dbl[2] * (1 << DISFLOW_INTERP_BITS));
46   kernel[3] = (int)rint(kernel_dbl[3] * (1 << DISFLOW_INTERP_BITS));
47 }
48 
// Computes the per-pixel prediction error for one patch.
//
// The DISFLOW_PATCH_SIZE x DISFLOW_PATCH_SIZE patch rooted at (x, y) in src is
// compared with ref sampled at the sub-pixel position (x + u, y + v), using a
// separable 4-tap cubic interpolation. For every pixel, dt receives
//
//   interpolated_ref - src
//
// in fixed point with DISFLOW_DERIV_SCALE_LOG2 fractional bits, i.e. at the
// same scale as the dx / dy gradient arrays. dt is written with a row stride
// of DISFLOW_PATCH_SIZE. The function returns nothing.
//
// src and ref share the same `stride`. width / height only bound the clamping
// of the reference position (presumably the frame dimensions; confirm at the
// call site).
//
// This NEON implementation handles one patch row as a single int16x8_t, so it
// assumes DISFLOW_PATCH_SIZE == 8.
static INLINE void compute_flow_error(const uint8_t *src, const uint8_t *ref,
                                      int width, int height, int stride, int x,
                                      int y, double u, double v, int16_t *dt) {
  // Split offset into integer and fractional parts, and compute cubic
  // interpolation kernels.
  const int u_int = (int)floor(u);
  const int v_int = (int)floor(v);
  const double u_frac = u - floor(u);
  const double v_frac = v - floor(v);

  int h_kernel[4];
  int v_kernel[4];
  get_cubic_kernel_int(u_frac, h_kernel);
  get_cubic_kernel_int(v_frac, v_kernel);

  // Output of the horizontal pass. The 3 extra rows give the vertical pass
  // one row of context above the patch and two rows below it.
  int16_t tmp_[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 3)];

  // Clamp coordinates so that all pixels we fetch will remain within the
  // allocated border region, but allow them to go far enough out that
  // the border pixels' values do not change.
  // Since we are calculating an 8x8 block, the bottom-right pixel
  // in the block has coordinates (x0 + 7, y0 + 7). Then, the cubic
  // interpolation has 4 taps, meaning that the output of pixel
  // (x_w, y_w) depends on the pixels in the range
  // ([x_w - 1, x_w + 2], [y_w - 1, y_w + 2]).
  //
  // Thus the most extreme coordinates which will be used are
  // (x0 - 1, y0 - 1) and (x0 + 9, y0 + 9).
  const int x0 = clamp(x + u_int, -9, width);
  const int y0 = clamp(y + v_int, -9, height);

  // Horizontal convolution.
  const uint8_t *ref_start = ref + (y0 - 1) * stride + (x0 - 1);
  int16x4_t h_filter = vmovn_s32(vld1q_s32(h_kernel));

  for (int i = 0; i < DISFLOW_PATCH_SIZE + 3; ++i) {
    // Only the first DISFLOW_PATCH_SIZE + 3 = 11 bytes of this 16-byte load
    // are used; the remaining bytes (up to column x0 + 14) are read and
    // discarded. NOTE(review): this relies on the reference border extending
    // at least that far past the clamped x0; confirm against the allocation.
    uint8x16_t r = vld1q_u8(ref_start + i * stride);
    uint16x8_t r0 = vmovl_u8(vget_low_u8(r));
    uint16x8_t r1 = vmovl_u8(vget_high_u8(r));

    // s0..s3 hold the 8 output positions shifted by 0..3 pixels, i.e. the
    // four taps of each output pixel.
    int16x8_t s0 = vreinterpretq_s16_u16(r0);
    int16x8_t s1 = vreinterpretq_s16_u16(vextq_u16(r0, r1, 1));
    int16x8_t s2 = vreinterpretq_s16_u16(vextq_u16(r0, r1, 2));
    int16x8_t s3 = vreinterpretq_s16_u16(vextq_u16(r0, r1, 3));

    int32x4_t sum_lo = vmull_lane_s16(vget_low_s16(s0), h_filter, 0);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s1), h_filter, 1);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s2), h_filter, 2);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s3), h_filter, 3);

    int32x4_t sum_hi = vmull_lane_s16(vget_high_s16(s0), h_filter, 0);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s1), h_filter, 1);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s2), h_filter, 2);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s3), h_filter, 3);

    // 6 is the maximum allowable number of extra bits which will avoid
    // the intermediate values overflowing an int16_t. The most extreme
    // intermediate value occurs when:
    // * The input pixels are [0, 255, 255, 0]
    // * u_frac = 0.5
    // In this case, the un-scaled output is 255 * 1.125 = 286.875.
    // As an integer with 6 fractional bits, that is 18360, which fits
    // in an int16_t. But with 7 fractional bits it would be 36720,
    // which is too large.
    int16x8_t sum = vcombine_s16(vrshrn_n_s32(sum_lo, DISFLOW_INTERP_BITS - 6),
                                 vrshrn_n_s32(sum_hi, DISFLOW_INTERP_BITS - 6));
    vst1q_s16(tmp_ + i * DISFLOW_PATCH_SIZE, sum);
  }

  // Vertical convolution.
  int16x4_t v_filter = vmovn_s32(vld1q_s32(v_kernel));
  // Row 0 of tmp_ is the context row above the patch, so patch row i is
  // stored at tmp_start + i * DISFLOW_PATCH_SIZE.
  int16_t *tmp_start = tmp_ + DISFLOW_PATCH_SIZE;

  for (int i = 0; i < DISFLOW_PATCH_SIZE; ++i) {
    int16x8_t t0 = vld1q_s16(tmp_start + (i - 1) * DISFLOW_PATCH_SIZE);
    int16x8_t t1 = vld1q_s16(tmp_start + i * DISFLOW_PATCH_SIZE);
    int16x8_t t2 = vld1q_s16(tmp_start + (i + 1) * DISFLOW_PATCH_SIZE);
    int16x8_t t3 = vld1q_s16(tmp_start + (i + 2) * DISFLOW_PATCH_SIZE);

    int32x4_t sum_lo = vmull_lane_s16(vget_low_s16(t0), v_filter, 0);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(t1), v_filter, 1);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(t2), v_filter, 2);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(t3), v_filter, 3);

    int32x4_t sum_hi = vmull_lane_s16(vget_high_s16(t0), v_filter, 0);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(t1), v_filter, 1);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(t2), v_filter, 2);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(t3), v_filter, 3);

    // Bring the source pixels to the same DISFLOW_DERIV_SCALE_LOG2 fractional
    // bits as the interpolated reference below. Use the named constant rather
    // than a literal so the two sides can never drift apart.
    uint8x8_t s = vld1_u8(src + (i + y) * stride + x);
    int16x8_t s_s16 =
        vreinterpretq_s16_u16(vshll_n_u8(s, DISFLOW_DERIV_SCALE_LOG2));

    // This time, we have to round off the 6 extra bits which were kept
    // earlier, but we also want to keep DISFLOW_DERIV_SCALE_LOG2 extra bits
    // of precision to match the scale of the dx and dy arrays.
    sum_lo = vrshrq_n_s32(sum_lo,
                          DISFLOW_INTERP_BITS + 6 - DISFLOW_DERIV_SCALE_LOG2);
    sum_hi = vrshrq_n_s32(sum_hi,
                          DISFLOW_INTERP_BITS + 6 - DISFLOW_DERIV_SCALE_LOG2);
    int32x4_t err_lo = vsubw_s16(sum_lo, vget_low_s16(s_s16));
    int32x4_t err_hi = vsubw_s16(sum_hi, vget_high_s16(s_s16));
    vst1q_s16(dt + i * DISFLOW_PATCH_SIZE,
              vcombine_s16(vmovn_s32(err_lo), vmovn_s32(err_hi)));
  }
}
159 
// Computes the horizontal Sobel gradient of the 8-wide, DISFLOW_PATCH_SIZE-tall
// patch rooted at src, writing int16 results to dst (row stride dst_stride).
//
// The 3x3 Sobel kernel is separated into a horizontal {1, 0, -1} difference
// (left neighbour minus right neighbour) followed by a vertical {1, 2, 1}
// smoothing. One row above/below and one column left/right of the patch are
// read. Each source row is fetched with a 16-byte load, of which only the
// first 10 bytes contribute to the result.
static INLINE void sobel_filter_x(const uint8_t *src, int src_stride,
                                  int16_t *dst, int dst_stride) {
  // Horizontal differences for DISFLOW_PATCH_SIZE + 2 rows (one context row
  // above and one below the patch).
  int16_t row_diff[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 2)];
  const uint8_t *top_left = src - src_stride - 1;

  for (int row = 0; row < DISFLOW_PATCH_SIZE + 2; row++) {
    const uint8x16_t pixels = vld1q_u8(top_left + row * src_stride);
    const uint8x8_t left = vget_low_u8(pixels);
    const uint8x8_t right = vget_low_u8(vextq_u8(pixels, pixels, 2));

    // The centre tap is 0, so the kernel reduces to left - right. The
    // unsigned wrap-around of vsubl_u8 reinterprets as the correct signed
    // difference.
    vst1q_s16(row_diff + row * DISFLOW_PATCH_SIZE,
              vreinterpretq_s16_u16(vsubl_u8(left, right)));
  }

  // Vertical {1, 2, 1} = {1, 1} applied twice: each output row is
  // (above + centre) + (centre + below), and the second pair sum is reused
  // as the first pair sum of the next output row.
  int16x8_t centre = vld1q_s16(row_diff + DISFLOW_PATCH_SIZE);
  int16x8_t upper_pair = vaddq_s16(vld1q_s16(row_diff), centre);

  for (int row = 0; row < DISFLOW_PATCH_SIZE; row++) {
    const int16x8_t below =
        vld1q_s16(row_diff + (row + 2) * DISFLOW_PATCH_SIZE);
    const int16x8_t lower_pair = vaddq_s16(centre, below);

    vst1q_s16(dst + row * dst_stride, vaddq_s16(upper_pair, lower_pair));

    upper_pair = lower_pair;
    centre = below;
  }
}
198 
// Computes the vertical Sobel gradient of the 8-wide, DISFLOW_PATCH_SIZE-tall
// patch rooted at src, writing int16 results to dst (row stride dst_stride).
//
// The 3x3 Sobel kernel is separated into a horizontal {1, 2, 1} smoothing
// followed by a vertical {1, 0, -1} difference (row above minus row below).
// One row above/below and one column left/right of the patch are read; each
// source row is fetched with a 16-byte load of which 10 bytes are used.
static INLINE void sobel_filter_y(const uint8_t *src, int src_stride,
                                  int16_t *dst, int dst_stride) {
  // Horizontally smoothed rows: DISFLOW_PATCH_SIZE + 2 of them (one context
  // row above and one below the patch).
  int16_t smoothed[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 2)];
  const uint8_t *top_left = src - src_stride - 1;

  for (int row = 0; row < DISFLOW_PATCH_SIZE + 2; row++) {
    const uint8x16_t pixels = vld1q_u8(top_left + row * src_stride);
    const uint8x8_t left = vget_low_u8(pixels);
    const uint8x8_t mid = vget_low_u8(vextq_u8(pixels, pixels, 1));
    const uint8x8_t right = vget_low_u8(vextq_u8(pixels, pixels, 2));

    // {1, 2, 1} computed as (left + mid) + (mid + right), widened to 16 bits.
    const uint16x8_t left_pair = vaddl_u8(left, mid);
    const uint16x8_t right_pair = vaddl_u8(mid, right);
    vst1q_s16(smoothed + row * DISFLOW_PATCH_SIZE,
              vreinterpretq_s16_u16(vaddq_u16(left_pair, right_pair)));
  }

  // Keep all 10 smoothed rows in registers so that no row is loaded twice.
  int16x8_t rows[10];
  load_s16_8x10(smoothed, DISFLOW_PATCH_SIZE, &rows[0], &rows[1], &rows[2],
                &rows[3], &rows[4], &rows[5], &rows[6], &rows[7], &rows[8],
                &rows[9]);

  // The centre tap is 0, so the vertical kernel is just above - below.
  for (int row = 0; row < DISFLOW_PATCH_SIZE; row++) {
    vst1q_s16(dst + row * dst_stride, vsubq_s16(rows[row], rows[row + 2]));
  }
}
235 
// Builds the 2x2 least-squares system matrix for one patch and writes its
// inverse to M_inv.
//
// For every pixel of the patch we have the current error dt and the source
// gradients dx and dy. To first order, an incremental flow change (u, v)
// turns that pixel's squared error into
//
//    (dt + u * dx + v * dy)^2
//
// Minimizing the sum of this over the DISFLOW_PATCH_SIZE x DISFLOW_PATCH_SIZE
// window is a linear least-squares problem with one equation
//
//   u * dx + v * dy = -dt
//
// per pixel. With the minus sign absorbed elsewhere, the normal equations are
// M * (u, v)^T = b, where
//
//   M = |sum(dx * dx)  sum(dx * dy)|
//       |sum(dx * dy)  sum(dy * dy)|
//
//   b = |sum(dx * dt)|
//       |sum(dy * dt)|
//
// M depends only on the gradients, so this function handles M alone (b is
// produced by compute_flow_vector). dx and dy are DISFLOW_PATCH_SIZE rows of
// 8 int16 values with the given row strides. M_inv receives the inverse of
// the regularized M in row-major order: {inv00, inv01, inv10, inv11}.
static INLINE void compute_flow_matrix(const int16_t *dx, int dx_stride,
                                       const int16_t *dy, int dy_stride,
                                       double *M_inv) {
  int32x4_t acc_xx = vdupq_n_s32(0);
  int32x4_t acc_xy = vdupq_n_s32(0);
  int32x4_t acc_yy = vdupq_n_s32(0);

  for (int row = 0; row < DISFLOW_PATCH_SIZE; row++) {
    const int16x8_t gx = vld1q_s16(dx + row * dx_stride);
    const int16x8_t gy = vld1q_s16(dy + row * dy_stride);
    const int16x4_t gx_lo = vget_low_s16(gx);
    const int16x4_t gx_hi = vget_high_s16(gx);
    const int16x4_t gy_lo = vget_low_s16(gy);
    const int16x4_t gy_hi = vget_high_s16(gy);

    acc_xx = vmlal_s16(acc_xx, gx_lo, gx_lo);
    acc_xx = vmlal_s16(acc_xx, gx_hi, gx_hi);

    acc_xy = vmlal_s16(acc_xy, gx_lo, gy_lo);
    acc_xy = vmlal_s16(acc_xy, gx_hi, gy_hi);

    acc_yy = vmlal_s16(acc_yy, gy_lo, gy_lo);
    acc_yy = vmlal_s16(acc_yy, gy_hi, gy_hi);
  }

  // M is symmetric, so the cross-term accumulator fills both off-diagonal
  // slots. After the reduction, lane k of `entries` is M[k] in row-major order.
  int32x4_t partial[4] = { acc_xx, acc_xy, acc_xy, acc_yy };
  const int32x4_t entries = horizontal_add_4d_s32x4(partial);

  // Regularize by adding k * I with k = 1 before inverting, which guarantees
  // that M is invertible. Typical entries are large (roughly 1e5 to 1e6, with
  // an upper limit of about 6e7 at the time of writing), so k = 1 barely
  // perturbs the solution. It also keeps every entry a whole number, which
  // suits an integerized SIMD implementation.
  const double m00 = (double)vgetq_lane_s32(entries, 0) + 1;
  const double m01 = (double)vgetq_lane_s32(entries, 1);
  const double m10 = (double)vgetq_lane_s32(entries, 2);
  const double m11 = (double)vgetq_lane_s32(entries, 3) + 1;

  // Closed-form inverse of a 2x2 matrix.
  const double det = (m00 * m11) - (m01 * m10);
  assert(det >= 1);
  const double inv_det = 1 / det;

  M_inv[0] = m11 * inv_det;
  M_inv[1] = -m01 * inv_det;
  M_inv[2] = -m10 * inv_det;
  M_inv[3] = m00 * inv_det;
}
312 
// Computes the right-hand side b of the flow equations for one patch:
//
//   b[0] = sum(dx * dt),  b[1] = sum(dy * dt)
//
// over DISFLOW_PATCH_SIZE rows of 8 int16 values each. Every array has its
// own row stride. Products are accumulated in 32-bit lanes.
static INLINE void compute_flow_vector(const int16_t *dx, int dx_stride,
                                       const int16_t *dy, int dy_stride,
                                       const int16_t *dt, int dt_stride,
                                       int *b) {
  int32x4_t acc_x = vdupq_n_s32(0);
  int32x4_t acc_y = vdupq_n_s32(0);

  for (int row = 0; row < DISFLOW_PATCH_SIZE; row++) {
    const int16x8_t gx = vld1q_s16(dx + row * dx_stride);
    const int16x8_t gy = vld1q_s16(dy + row * dy_stride);
    const int16x8_t err = vld1q_s16(dt + row * dt_stride);
    const int16x4_t err_lo = vget_low_s16(err);
    const int16x4_t err_hi = vget_high_s16(err);

    acc_x = vmlal_s16(acc_x, vget_low_s16(gx), err_lo);
    acc_x = vmlal_s16(acc_x, vget_high_s16(gx), err_hi);

    acc_y = vmlal_s16(acc_y, vget_low_s16(gy), err_lo);
    acc_y = vmlal_s16(acc_y, vget_high_s16(gy), err_hi);
  }

  // Collapse both accumulators and store the pair {sum(dx*dt), sum(dy*dt)}.
  const int32x4_t pair_sums = horizontal_add_2d_s32(acc_x, acc_y);
  vst1_s32(b, add_pairwise_s32x4(pair_sums));
}
334 
aom_compute_flow_at_point_neon(const uint8_t * src,const uint8_t * ref,int x,int y,int width,int height,int stride,double * u,double * v)335 void aom_compute_flow_at_point_neon(const uint8_t *src, const uint8_t *ref,
336                                     int x, int y, int width, int height,
337                                     int stride, double *u, double *v) {
338   double M_inv[4];
339   int b[2];
340   int16_t dt[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
341   int16_t dx[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
342   int16_t dy[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
343 
344   // Compute gradients within this patch
345   const uint8_t *src_patch = &src[y * stride + x];
346   sobel_filter_x(src_patch, stride, dx, DISFLOW_PATCH_SIZE);
347   sobel_filter_y(src_patch, stride, dy, DISFLOW_PATCH_SIZE);
348 
349   compute_flow_matrix(dx, DISFLOW_PATCH_SIZE, dy, DISFLOW_PATCH_SIZE, M_inv);
350 
351   for (int itr = 0; itr < DISFLOW_MAX_ITR; itr++) {
352     compute_flow_error(src, ref, width, height, stride, x, y, *u, *v, dt);
353     compute_flow_vector(dx, DISFLOW_PATCH_SIZE, dy, DISFLOW_PATCH_SIZE, dt,
354                         DISFLOW_PATCH_SIZE, b);
355 
356     // Solve flow equations to find a better estimate for the flow vector
357     // at this point
358     const double step_u = M_inv[0] * b[0] + M_inv[1] * b[1];
359     const double step_v = M_inv[2] * b[0] + M_inv[3] * b[1];
360     *u += fclamp(step_u * DISFLOW_STEP_SIZE, -2, 2);
361     *v += fclamp(step_v * DISFLOW_STEP_SIZE, -2, 2);
362 
363     if (fabs(step_u) + fabs(step_v) < DISFLOW_STEP_SIZE_THRESOLD) {
364       // Stop iteration when we're close to convergence
365       break;
366     }
367   }
368 }
369