• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 /*
12  * The core AEC algorithm, neon version of speed-critical functions.
13  *
14  * Based on aec_core_sse2.c.
15  */
16 
17 #include <arm_neon.h>
18 #include <math.h>
19 #include <string.h>  // memset
20 
21 #include "webrtc/common_audio/signal_processing/include/signal_processing_library.h"
22 #include "webrtc/modules/audio_processing/aec/aec_common.h"
23 #include "webrtc/modules/audio_processing/aec/aec_core_internal.h"
24 #include "webrtc/modules/audio_processing/aec/aec_rdft.h"
25 
// Shift amounts used by the bit-level float manipulation in vpowq_f32():
//   kShiftExponentIntoTopMantissa - right shift that moves the float exponent
//                                   field into the top of the mantissa (log2).
//   kFloatExponentShift           - left shift that places an integer into the
//                                   IEEE-754 single-precision exponent field
//                                   (exp2).
enum {
  kShiftExponentIntoTopMantissa = 8,
  kFloatExponentShift = 23
};
28 
// Real part of the complex product (aRe + j*aIm) * (bRe + j*bIm).
static __inline float MulRe(float aRe, float aIm, float bRe, float bIm) {
  const float real_part = aRe * bRe - aIm * bIm;
  return real_part;
}
32 
// Imaginary part of the complex product (aRe + j*aIm) * (bRe + j*bIm).
static __inline float MulIm(float aRe, float aIm, float bRe, float bIm) {
  const float imag_part = aRe * bIm + aIm * bRe;
  return imag_part;
}
36 
FilterFarNEON(int num_partitions,int x_fft_buf_block_pos,float x_fft_buf[2][kExtendedNumPartitions * PART_LEN1],float h_fft_buf[2][kExtendedNumPartitions * PART_LEN1],float y_fft[2][PART_LEN1])37 static void FilterFarNEON(
38     int num_partitions,
39     int x_fft_buf_block_pos,
40     float x_fft_buf[2][kExtendedNumPartitions * PART_LEN1],
41     float h_fft_buf[2][kExtendedNumPartitions * PART_LEN1],
42     float y_fft[2][PART_LEN1]) {
43   int i;
44   for (i = 0; i < num_partitions; i++) {
45     int j;
46     int xPos = (i + x_fft_buf_block_pos) * PART_LEN1;
47     int pos = i * PART_LEN1;
48     // Check for wrap
49     if (i + x_fft_buf_block_pos >= num_partitions) {
50       xPos -= num_partitions * PART_LEN1;
51     }
52 
53     // vectorized code (four at once)
54     for (j = 0; j + 3 < PART_LEN1; j += 4) {
55       const float32x4_t x_fft_buf_re = vld1q_f32(&x_fft_buf[0][xPos + j]);
56       const float32x4_t x_fft_buf_im = vld1q_f32(&x_fft_buf[1][xPos + j]);
57       const float32x4_t h_fft_buf_re = vld1q_f32(&h_fft_buf[0][pos + j]);
58       const float32x4_t h_fft_buf_im = vld1q_f32(&h_fft_buf[1][pos + j]);
59       const float32x4_t y_fft_re = vld1q_f32(&y_fft[0][j]);
60       const float32x4_t y_fft_im = vld1q_f32(&y_fft[1][j]);
61       const float32x4_t a = vmulq_f32(x_fft_buf_re, h_fft_buf_re);
62       const float32x4_t e = vmlsq_f32(a, x_fft_buf_im, h_fft_buf_im);
63       const float32x4_t c = vmulq_f32(x_fft_buf_re, h_fft_buf_im);
64       const float32x4_t f = vmlaq_f32(c, x_fft_buf_im, h_fft_buf_re);
65       const float32x4_t g = vaddq_f32(y_fft_re, e);
66       const float32x4_t h = vaddq_f32(y_fft_im, f);
67       vst1q_f32(&y_fft[0][j], g);
68       vst1q_f32(&y_fft[1][j], h);
69     }
70     // scalar code for the remaining items.
71     for (; j < PART_LEN1; j++) {
72       y_fft[0][j] += MulRe(x_fft_buf[0][xPos + j],
73                            x_fft_buf[1][xPos + j],
74                            h_fft_buf[0][pos + j],
75                            h_fft_buf[1][pos + j]);
76       y_fft[1][j] += MulIm(x_fft_buf[0][xPos + j],
77                            x_fft_buf[1][xPos + j],
78                            h_fft_buf[0][pos + j],
79                            h_fft_buf[1][pos + j]);
80     }
81   }
82 }
83 
84 // ARM64's arm_neon.h has already defined vdivq_f32 vsqrtq_f32.
85 #if !defined (WEBRTC_ARCH_ARM64)
vdivq_f32(float32x4_t a,float32x4_t b)86 static float32x4_t vdivq_f32(float32x4_t a, float32x4_t b) {
87   int i;
88   float32x4_t x = vrecpeq_f32(b);
89   // from arm documentation
90   // The Newton-Raphson iteration:
91   //     x[n+1] = x[n] * (2 - d * x[n])
92   // converges to (1/d) if x0 is the result of VRECPE applied to d.
93   //
94   // Note: The precision did not improve after 2 iterations.
95   for (i = 0; i < 2; i++) {
96     x = vmulq_f32(vrecpsq_f32(b, x), x);
97   }
98   // a/b = a*(1/b)
99   return vmulq_f32(a, x);
100 }
101 
vsqrtq_f32(float32x4_t s)102 static float32x4_t vsqrtq_f32(float32x4_t s) {
103   int i;
104   float32x4_t x = vrsqrteq_f32(s);
105 
106   // Code to handle sqrt(0).
107   // If the input to sqrtf() is zero, a zero will be returned.
108   // If the input to vrsqrteq_f32() is zero, positive infinity is returned.
109   const uint32x4_t vec_p_inf = vdupq_n_u32(0x7F800000);
110   // check for divide by zero
111   const uint32x4_t div_by_zero = vceqq_u32(vec_p_inf, vreinterpretq_u32_f32(x));
112   // zero out the positive infinity results
113   x = vreinterpretq_f32_u32(vandq_u32(vmvnq_u32(div_by_zero),
114                                       vreinterpretq_u32_f32(x)));
115   // from arm documentation
116   // The Newton-Raphson iteration:
117   //     x[n+1] = x[n] * (3 - d * (x[n] * x[n])) / 2)
118   // converges to (1/√d) if x0 is the result of VRSQRTE applied to d.
119   //
120   // Note: The precision did not improve after 2 iterations.
121   for (i = 0; i < 2; i++) {
122     x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, x), s), x);
123   }
124   // sqrt(s) = s * 1/sqrt(s)
125   return vmulq_f32(s, x);;
126 }
127 #endif  // WEBRTC_ARCH_ARM64
128 
ScaleErrorSignalNEON(int extended_filter_enabled,float normal_mu,float normal_error_threshold,float x_pow[PART_LEN1],float ef[2][PART_LEN1])129 static void ScaleErrorSignalNEON(int extended_filter_enabled,
130                                  float normal_mu,
131                                  float normal_error_threshold,
132                                  float x_pow[PART_LEN1],
133                                  float ef[2][PART_LEN1]) {
134   const float mu = extended_filter_enabled ? kExtendedMu : normal_mu;
135   const float error_threshold = extended_filter_enabled ?
136       kExtendedErrorThreshold : normal_error_threshold;
137   const float32x4_t k1e_10f = vdupq_n_f32(1e-10f);
138   const float32x4_t kMu = vmovq_n_f32(mu);
139   const float32x4_t kThresh = vmovq_n_f32(error_threshold);
140   int i;
141   // vectorized code (four at once)
142   for (i = 0; i + 3 < PART_LEN1; i += 4) {
143     const float32x4_t x_pow_local = vld1q_f32(&x_pow[i]);
144     const float32x4_t ef_re_base = vld1q_f32(&ef[0][i]);
145     const float32x4_t ef_im_base = vld1q_f32(&ef[1][i]);
146     const float32x4_t xPowPlus = vaddq_f32(x_pow_local, k1e_10f);
147     float32x4_t ef_re = vdivq_f32(ef_re_base, xPowPlus);
148     float32x4_t ef_im = vdivq_f32(ef_im_base, xPowPlus);
149     const float32x4_t ef_re2 = vmulq_f32(ef_re, ef_re);
150     const float32x4_t ef_sum2 = vmlaq_f32(ef_re2, ef_im, ef_im);
151     const float32x4_t absEf = vsqrtq_f32(ef_sum2);
152     const uint32x4_t bigger = vcgtq_f32(absEf, kThresh);
153     const float32x4_t absEfPlus = vaddq_f32(absEf, k1e_10f);
154     const float32x4_t absEfInv = vdivq_f32(kThresh, absEfPlus);
155     uint32x4_t ef_re_if = vreinterpretq_u32_f32(vmulq_f32(ef_re, absEfInv));
156     uint32x4_t ef_im_if = vreinterpretq_u32_f32(vmulq_f32(ef_im, absEfInv));
157     uint32x4_t ef_re_u32 = vandq_u32(vmvnq_u32(bigger),
158                                      vreinterpretq_u32_f32(ef_re));
159     uint32x4_t ef_im_u32 = vandq_u32(vmvnq_u32(bigger),
160                                      vreinterpretq_u32_f32(ef_im));
161     ef_re_if = vandq_u32(bigger, ef_re_if);
162     ef_im_if = vandq_u32(bigger, ef_im_if);
163     ef_re_u32 = vorrq_u32(ef_re_u32, ef_re_if);
164     ef_im_u32 = vorrq_u32(ef_im_u32, ef_im_if);
165     ef_re = vmulq_f32(vreinterpretq_f32_u32(ef_re_u32), kMu);
166     ef_im = vmulq_f32(vreinterpretq_f32_u32(ef_im_u32), kMu);
167     vst1q_f32(&ef[0][i], ef_re);
168     vst1q_f32(&ef[1][i], ef_im);
169   }
170   // scalar code for the remaining items.
171   for (; i < PART_LEN1; i++) {
172     float abs_ef;
173     ef[0][i] /= (x_pow[i] + 1e-10f);
174     ef[1][i] /= (x_pow[i] + 1e-10f);
175     abs_ef = sqrtf(ef[0][i] * ef[0][i] + ef[1][i] * ef[1][i]);
176 
177     if (abs_ef > error_threshold) {
178       abs_ef = error_threshold / (abs_ef + 1e-10f);
179       ef[0][i] *= abs_ef;
180       ef[1][i] *= abs_ef;
181     }
182 
183     // Stepsize factor
184     ef[0][i] *= mu;
185     ef[1][i] *= mu;
186   }
187 }
188 
FilterAdaptationNEON(int num_partitions,int x_fft_buf_block_pos,float x_fft_buf[2][kExtendedNumPartitions * PART_LEN1],float e_fft[2][PART_LEN1],float h_fft_buf[2][kExtendedNumPartitions * PART_LEN1])189 static void FilterAdaptationNEON(
190     int num_partitions,
191     int x_fft_buf_block_pos,
192     float x_fft_buf[2][kExtendedNumPartitions * PART_LEN1],
193     float e_fft[2][PART_LEN1],
194     float h_fft_buf[2][kExtendedNumPartitions * PART_LEN1]) {
195   float fft[PART_LEN2];
196   int i;
197   for (i = 0; i < num_partitions; i++) {
198     int xPos = (i + x_fft_buf_block_pos) * PART_LEN1;
199     int pos = i * PART_LEN1;
200     int j;
201     // Check for wrap
202     if (i + x_fft_buf_block_pos >= num_partitions) {
203       xPos -= num_partitions * PART_LEN1;
204     }
205 
206     // Process the whole array...
207     for (j = 0; j < PART_LEN; j += 4) {
208       // Load x_fft_buf and e_fft.
209       const float32x4_t x_fft_buf_re = vld1q_f32(&x_fft_buf[0][xPos + j]);
210       const float32x4_t x_fft_buf_im = vld1q_f32(&x_fft_buf[1][xPos + j]);
211       const float32x4_t e_fft_re = vld1q_f32(&e_fft[0][j]);
212       const float32x4_t e_fft_im = vld1q_f32(&e_fft[1][j]);
213       // Calculate the product of conjugate(x_fft_buf) by e_fft.
214       //   re(conjugate(a) * b) = aRe * bRe + aIm * bIm
215       //   im(conjugate(a) * b)=  aRe * bIm - aIm * bRe
216       const float32x4_t a = vmulq_f32(x_fft_buf_re, e_fft_re);
217       const float32x4_t e = vmlaq_f32(a, x_fft_buf_im, e_fft_im);
218       const float32x4_t c = vmulq_f32(x_fft_buf_re, e_fft_im);
219       const float32x4_t f = vmlsq_f32(c, x_fft_buf_im, e_fft_re);
220       // Interleave real and imaginary parts.
221       const float32x4x2_t g_n_h = vzipq_f32(e, f);
222       // Store
223       vst1q_f32(&fft[2 * j + 0], g_n_h.val[0]);
224       vst1q_f32(&fft[2 * j + 4], g_n_h.val[1]);
225     }
226     // ... and fixup the first imaginary entry.
227     fft[1] = MulRe(x_fft_buf[0][xPos + PART_LEN],
228                    -x_fft_buf[1][xPos + PART_LEN],
229                    e_fft[0][PART_LEN],
230                    e_fft[1][PART_LEN]);
231 
232     aec_rdft_inverse_128(fft);
233     memset(fft + PART_LEN, 0, sizeof(float) * PART_LEN);
234 
235     // fft scaling
236     {
237       const float scale = 2.0f / PART_LEN2;
238       const float32x4_t scale_ps = vmovq_n_f32(scale);
239       for (j = 0; j < PART_LEN; j += 4) {
240         const float32x4_t fft_ps = vld1q_f32(&fft[j]);
241         const float32x4_t fft_scale = vmulq_f32(fft_ps, scale_ps);
242         vst1q_f32(&fft[j], fft_scale);
243       }
244     }
245     aec_rdft_forward_128(fft);
246 
247     {
248       const float wt1 = h_fft_buf[1][pos];
249       h_fft_buf[0][pos + PART_LEN] += fft[1];
250       for (j = 0; j < PART_LEN; j += 4) {
251         float32x4_t wtBuf_re = vld1q_f32(&h_fft_buf[0][pos + j]);
252         float32x4_t wtBuf_im = vld1q_f32(&h_fft_buf[1][pos + j]);
253         const float32x4_t fft0 = vld1q_f32(&fft[2 * j + 0]);
254         const float32x4_t fft4 = vld1q_f32(&fft[2 * j + 4]);
255         const float32x4x2_t fft_re_im = vuzpq_f32(fft0, fft4);
256         wtBuf_re = vaddq_f32(wtBuf_re, fft_re_im.val[0]);
257         wtBuf_im = vaddq_f32(wtBuf_im, fft_re_im.val[1]);
258 
259         vst1q_f32(&h_fft_buf[0][pos + j], wtBuf_re);
260         vst1q_f32(&h_fft_buf[1][pos + j], wtBuf_im);
261       }
262       h_fft_buf[1][pos] = wt1;
263     }
264   }
265 }
266 
// Approximates a^b for four lanes as exp2(b * log2(a)), using polynomial
// approximations for both log2 and exp2. The sign bit of a is ignored (only
// the exponent and mantissa fields are read), and a == 0 evaluates
// log2(a) as -127.
static float32x4_t vpowq_f32(float32x4_t a, float32x4_t b) {
  // a^b = exp2(b * log2(a))
  //   exp2(x) and log2(x) are calculated using polynomial approximations.
  float32x4_t log2_a, b_log2_a, a_exp_b;

  // Calculate log2(x), x = a.
  {
    // To calculate log2(x), we decompose x like this:
    //   x = y * 2^n
    //     n is an integer
    //     y is in the [1.0, 2.0) range
    //
    //   log2(x) = log2(y) + n
    //     n       can be evaluated by playing with float representation.
    //     log2(y) in a small range can be approximated, this code uses an order
    //             five polynomial approximation. The coefficients have been
    //             estimated with the Remez algorithm and the resulting
    //             polynomial has a maximum relative error of 0.00086%.

    // Compute n.
    //    This is done by masking the exponent, shifting it into the top bit of
    //    the mantissa, putting eight into the biased exponent (to compensate
    //    for the exponent having been shifted into the top/fractional part),
    //    and finally getting rid of the implicit leading one from the mantissa
    //    by subtracting it out.
    const uint32x4_t vec_float_exponent_mask = vdupq_n_u32(0x7F800000);
    const uint32x4_t vec_eight_biased_exponent = vdupq_n_u32(0x43800000);
    const uint32x4_t vec_implicit_leading_one = vdupq_n_u32(0x43BF8000);
    const uint32x4_t two_n = vandq_u32(vreinterpretq_u32_f32(a),
                                       vec_float_exponent_mask);
    const uint32x4_t n_1 = vshrq_n_u32(two_n, kShiftExponentIntoTopMantissa);
    const uint32x4_t n_0 = vorrq_u32(n_1, vec_eight_biased_exponent);
    const float32x4_t n =
        vsubq_f32(vreinterpretq_f32_u32(n_0),
                  vreinterpretq_f32_u32(vec_implicit_leading_one));
    // Compute y.
    const uint32x4_t vec_mantissa_mask = vdupq_n_u32(0x007FFFFF);
    const uint32x4_t vec_zero_biased_exponent_is_one = vdupq_n_u32(0x3F800000);
    const uint32x4_t mantissa = vandq_u32(vreinterpretq_u32_f32(a),
                                          vec_mantissa_mask);
    const float32x4_t y =
        vreinterpretq_f32_u32(vorrq_u32(mantissa,
                                        vec_zero_biased_exponent_is_one));
    // Approximate log2(y) ~= (y - 1) * pol5(y).
    //    pol5(y) = C5 * y^5 + C4 * y^4 + C3 * y^3 + C2 * y^2 + C1 * y + C0
    const float32x4_t C5 = vdupq_n_f32(-3.4436006e-2f);
    const float32x4_t C4 = vdupq_n_f32(3.1821337e-1f);
    const float32x4_t C3 = vdupq_n_f32(-1.2315303f);
    const float32x4_t C2 = vdupq_n_f32(2.5988452f);
    const float32x4_t C1 = vdupq_n_f32(-3.3241990f);
    const float32x4_t C0 = vdupq_n_f32(3.1157899f);
    float32x4_t pol5_y = C5;
    pol5_y = vmlaq_f32(C4, y, pol5_y);
    pol5_y = vmlaq_f32(C3, y, pol5_y);
    pol5_y = vmlaq_f32(C2, y, pol5_y);
    pol5_y = vmlaq_f32(C1, y, pol5_y);
    pol5_y = vmlaq_f32(C0, y, pol5_y);
    const float32x4_t y_minus_one =
        vsubq_f32(y, vreinterpretq_f32_u32(vec_zero_biased_exponent_is_one));
    const float32x4_t log2_y = vmulq_f32(y_minus_one, pol5_y);

    // Combine parts.
    log2_a = vaddq_f32(n, log2_y);
  }

  // b * log2(a)
  b_log2_a = vmulq_f32(b, log2_a);

  // Calculate exp2(x), x = b * log2(a).
  {
    // To calculate 2^x, we decompose x like this:
    //   x = n + y
    //     n is an integer, floor(x), therefore
    //     y is in the [0.0, 1.0) range
    //
    //   2^x = 2^n * 2^y
    //     2^n can be evaluated by playing with float representation.
    //     2^y in a small range can be approximated, this code uses an order two
    //         polynomial approximation. Its relative error oscillates within
    //         about +/-0.17% on [0, 1] (e.g. +0.17% at y = 0, -0.17% at
    //         y = 1) and grows quickly outside that interval (about +7% at
    //         y = -0.5), which is why y must be kept in [0, 1).
    // To avoid over/underflow, we reduce the range of input to
    // [-126.99999, 128], which keeps the biased exponent n + 127 in [0, 255].
    const float32x4_t max_input = vdupq_n_f32(128.f);
    const float32x4_t min_input = vdupq_n_f32(-126.99999f);
    const float32x4_t x_min = vminq_f32(b_log2_a, max_input);
    const float32x4_t x_max = vmaxq_f32(x_min, min_input);
    // Compute n = floor(x).
    //    vcvtq_s32_f32() rounds toward zero, which equals floor() only for
    //    non-negative values. For negative non-integers the truncated value is
    //    one too large (e.g. -0.3 -> 0), which would leave y in (-1, 0) where
    //    the polynomial is inaccurate. The comparison mask is all ones (-1 as
    //    a signed integer) exactly in the lanes where trunc(x) > x, so adding
    //    it turns the truncation into a floor.
    const int32x4_t x_trunc = vcvtq_s32_f32(x_max);
    const uint32x4_t trunc_above_x =
        vcgtq_f32(vcvtq_f32_s32(x_trunc), x_max);
    const int32x4_t x_floor =
        vaddq_s32(x_trunc, vreinterpretq_s32_u32(trunc_above_x));

    // Compute 2^n.
    const int32x4_t float_exponent_bias = vdupq_n_s32(127);
    const int32x4_t two_n_exponent = vaddq_s32(x_floor, float_exponent_bias);
    const float32x4_t two_n =
        vreinterpretq_f32_s32(vshlq_n_s32(two_n_exponent, kFloatExponentShift));
    // Compute y.
    const float32x4_t y = vsubq_f32(x_max, vcvtq_f32_s32(x_floor));

    // Approximate 2^y ~= C2 * y^2 + C1 * y + C0.
    const float32x4_t C2 = vdupq_n_f32(3.3718944e-1f);
    const float32x4_t C1 = vdupq_n_f32(6.5763628e-1f);
    const float32x4_t C0 = vdupq_n_f32(1.0017247f);
    float32x4_t exp2_y = C2;
    exp2_y = vmlaq_f32(C1, y, exp2_y);
    exp2_y = vmlaq_f32(C0, y, exp2_y);

    // Combine parts.
    a_exp_b = vmulq_f32(exp2_y, two_n);
  }

  return a_exp_b;
}
381 
// Applies the suppression gains hNl to the error spectrum efw, in place:
//   1. Bins whose gain exceeds hNlFb are replaced by
//      weightCurve * hNlFb + (1 - weightCurve) * hNl.
//   2. Each gain is raised to aec->overDriveSm * WebRtcAec_overDriveCurve[i]
//      (vpowq_f32() in the vector path, powf() for the remaining bins) and
//      stored back into hNl.
//   3. efw is multiplied by the gain and its imaginary part is negated.
static void OverdriveAndSuppressNEON(AecCore* aec,
                                     float hNl[PART_LEN1],
                                     const float hNlFb,
                                     float efw[2][PART_LEN1]) {
  const float32x4_t vec_one = vdupq_n_f32(1.0f);
  const float32x4_t vec_minus_one = vdupq_n_f32(-1.0f);
  const float32x4_t vec_fb = vmovq_n_f32(hNlFb);
  const float32x4_t vec_over_drive = vmovq_n_f32(aec->overDriveSm);
  int k = 0;

  // Four bins per iteration.
  for (; k + 3 < PART_LEN1; k += 4) {
    const float32x4_t gain_in = vld1q_f32(&hNl[k]);
    const float32x4_t weight = vld1q_f32(&WebRtcAec_weightCurve[k]);
    const float32x4_t blended =
        vaddq_f32(vmulq_f32(weight, vec_fb),
                  vmulq_f32(vsubq_f32(vec_one, weight), gain_in));
    // Use the blended gain only where hNl > hNlFb (bitwise select).
    const float32x4_t weighted =
        vbslq_f32(vcgtq_f32(gain_in, vec_fb), blended, gain_in);
    const float32x4_t exponent =
        vmulq_f32(vec_over_drive, vld1q_f32(&WebRtcAec_overDriveCurve[k]));
    const float32x4_t gain = vpowq_f32(weighted, exponent);
    vst1q_f32(&hNl[k], gain);

    // Suppress error signal
    {
      const float32x4_t re = vmulq_f32(vld1q_f32(&efw[0][k]), gain);
      // Ooura fft returns incorrect sign on imaginary component. It matters
      // here because we are making an additive change with comfort noise.
      const float32x4_t im =
          vmulq_f32(vmulq_f32(vld1q_f32(&efw[1][k]), gain), vec_minus_one);
      vst1q_f32(&efw[0][k], re);
      vst1q_f32(&efw[1][k], im);
    }
  }

  // Remaining bins (PART_LEN1 is not a multiple of four).
  for (; k < PART_LEN1; k++) {
    // Weight subbands
    if (hNl[k] > hNlFb) {
      hNl[k] = WebRtcAec_weightCurve[k] * hNlFb +
               (1 - WebRtcAec_weightCurve[k]) * hNl[k];
    }

    hNl[k] = powf(hNl[k], aec->overDriveSm * WebRtcAec_overDriveCurve[k]);

    // Suppress error signal
    efw[0][k] *= hNl[k];
    efw[1][k] *= hNl[k];

    // Ooura fft returns incorrect sign on imaginary component. It matters
    // here because we are making an additive change with comfort noise.
    efw[1][k] *= -1;
  }
}
455 
PartitionDelayNEON(const AecCore * aec)456 static int PartitionDelayNEON(const AecCore* aec) {
457   // Measures the energy in each filter partition and returns the partition with
458   // highest energy.
459   // TODO(bjornv): Spread computational cost by computing one partition per
460   // block?
461   float wfEnMax = 0;
462   int i;
463   int delay = 0;
464 
465   for (i = 0; i < aec->num_partitions; i++) {
466     int j;
467     int pos = i * PART_LEN1;
468     float wfEn = 0;
469     float32x4_t vec_wfEn = vdupq_n_f32(0.0f);
470     // vectorized code (four at once)
471     for (j = 0; j + 3 < PART_LEN1; j += 4) {
472       const float32x4_t vec_wfBuf0 = vld1q_f32(&aec->wfBuf[0][pos + j]);
473       const float32x4_t vec_wfBuf1 = vld1q_f32(&aec->wfBuf[1][pos + j]);
474       vec_wfEn = vmlaq_f32(vec_wfEn, vec_wfBuf0, vec_wfBuf0);
475       vec_wfEn = vmlaq_f32(vec_wfEn, vec_wfBuf1, vec_wfBuf1);
476     }
477     {
478       float32x2_t vec_total;
479       // A B C D
480       vec_total = vpadd_f32(vget_low_f32(vec_wfEn), vget_high_f32(vec_wfEn));
481       // A+B C+D
482       vec_total = vpadd_f32(vec_total, vec_total);
483       // A+B+C+D A+B+C+D
484       wfEn = vget_lane_f32(vec_total, 0);
485     }
486 
487     // scalar code for the remaining items.
488     for (; j < PART_LEN1; j++) {
489       wfEn += aec->wfBuf[0][pos + j] * aec->wfBuf[0][pos + j] +
490               aec->wfBuf[1][pos + j] * aec->wfBuf[1][pos + j];
491     }
492 
493     if (wfEn > wfEnMax) {
494       wfEnMax = wfEn;
495       delay = i;
496     }
497   }
498   return delay;
499 }
500 
// Updates the following smoothed Power Spectral Densities (PSD):
//  - sd  : near-end
//  - se  : residual echo
//  - sx  : far-end
//  - sde : cross-PSD of near-end and residual echo
//  - sxd : cross-PSD of near-end and far-end
//
// In addition to updating the PSDs, it updates the filter divergence state
// (aec->divergeState) and reports extreme filter divergence through
// *extreme_filter_divergence.
static void SmoothedPSD(AecCore* aec,
                        float efw[2][PART_LEN1],
                        float dfw[2][PART_LEN1],
                        float xfw[2][PART_LEN1],
                        int* extreme_filter_divergence) {
  // Power estimate smoothing coefficients, chosen by filter mode and
  // aec->mult. Every PSD is updated as g_old * previous + g_new * current.
  const float* const ptrGCoh = aec->extended_filter_enabled
      ? WebRtcAec_kExtendedSmoothingCoefficients[aec->mult - 1]
      : WebRtcAec_kNormalSmoothingCoefficients[aec->mult - 1];
  const float g_old = ptrGCoh[0];
  const float g_new = ptrGCoh[1];
  const float32x4_t vec_min_farend_psd = vdupq_n_f32(WebRtcAec_kMinFarendPSD);
  float32x4_t sd_acc = vdupq_n_f32(0.0f);
  float32x4_t se_acc = vdupq_n_f32(0.0f);
  float sdSum;
  float seSum;
  int i;

  // Four bins per iteration.
  for (i = 0; i + 3 < PART_LEN1; i += 4) {
    const float32x4_t d_re = vld1q_f32(&dfw[0][i]);
    const float32x4_t d_im = vld1q_f32(&dfw[1][i]);
    const float32x4_t e_re = vld1q_f32(&efw[0][i]);
    const float32x4_t e_im = vld1q_f32(&efw[1][i]);
    const float32x4_t x_re = vld1q_f32(&xfw[0][i]);
    const float32x4_t x_im = vld1q_f32(&xfw[1][i]);
    const float32x4_t d_pow = vmlaq_f32(vmulq_f32(d_re, d_re), d_im, d_im);
    const float32x4_t e_pow = vmlaq_f32(vmulq_f32(e_re, e_re), e_im, e_im);
    // The far-end power is floored at WebRtcAec_kMinFarendPSD.
    const float32x4_t x_pow =
        vmaxq_f32(vmlaq_f32(vmulq_f32(x_re, x_re), x_im, x_im),
                  vec_min_farend_psd);
    const float32x4_t sd =
        vmlaq_n_f32(vmulq_n_f32(vld1q_f32(&aec->sd[i]), g_old), d_pow, g_new);
    const float32x4_t se =
        vmlaq_n_f32(vmulq_n_f32(vld1q_f32(&aec->se[i]), g_old), e_pow, g_new);
    const float32x4_t sx =
        vmlaq_n_f32(vmulq_n_f32(vld1q_f32(&aec->sx[i]), g_old), x_pow, g_new);

    vst1q_f32(&aec->sd[i], sd);
    vst1q_f32(&aec->se[i], se);
    vst1q_f32(&aec->sx[i], sx);

    // sde cross term: re = dRe * eRe + dIm * eIm, im = dRe * eIm - dIm * eRe.
    {
      float32x4x2_t sde = vld2q_f32(&aec->sde[i][0]);
      const float32x4_t cross_re =
          vmlaq_f32(vmulq_f32(d_re, e_re), d_im, e_im);
      const float32x4_t cross_im =
          vmlsq_f32(vmulq_f32(d_re, e_im), d_im, e_re);
      sde.val[0] = vmlaq_n_f32(vmulq_n_f32(sde.val[0], g_old), cross_re, g_new);
      sde.val[1] = vmlaq_n_f32(vmulq_n_f32(sde.val[1], g_old), cross_im, g_new);
      vst2q_f32(&aec->sde[i][0], sde);
    }

    // sxd cross term: re = dRe * xRe + dIm * xIm, im = dRe * xIm - dIm * xRe.
    {
      float32x4x2_t sxd = vld2q_f32(&aec->sxd[i][0]);
      const float32x4_t cross_re =
          vmlaq_f32(vmulq_f32(d_re, x_re), d_im, x_im);
      const float32x4_t cross_im =
          vmlsq_f32(vmulq_f32(d_re, x_im), d_im, x_re);
      sxd.val[0] = vmlaq_n_f32(vmulq_n_f32(sxd.val[0], g_old), cross_re, g_new);
      sxd.val[1] = vmlaq_n_f32(vmulq_n_f32(sxd.val[1], g_old), cross_im, g_new);
      vst2q_f32(&aec->sxd[i][0], sxd);
    }

    sd_acc = vaddq_f32(sd_acc, sd);
    se_acc = vaddq_f32(se_acc, se);
  }

  // Pairwise horizontal sums: (A+B, C+D), then (A+B)+(C+D).
  {
    const float32x2_t sd_pair =
        vpadd_f32(vget_low_f32(sd_acc), vget_high_f32(sd_acc));
    const float32x2_t se_pair =
        vpadd_f32(vget_low_f32(se_acc), vget_high_f32(se_acc));
    sdSum = vget_lane_f32(vpadd_f32(sd_pair, sd_pair), 0);
    seSum = vget_lane_f32(vpadd_f32(se_pair, se_pair), 0);
  }

  // Remaining bins (PART_LEN1 is not a multiple of four).
  for (; i < PART_LEN1; i++) {
    aec->sd[i] = g_old * aec->sd[i] +
                 g_new * (dfw[0][i] * dfw[0][i] + dfw[1][i] * dfw[1][i]);
    aec->se[i] = g_old * aec->se[i] +
                 g_new * (efw[0][i] * efw[0][i] + efw[1][i] * efw[1][i]);
    // We threshold here to protect against the ill-effects of a zero farend.
    // The threshold is not arbitrarily chosen, but balances protection and
    // adverse interaction with the algorithm's tuning.
    // TODO(bjornv): investigate further why this is so sensitive.
    aec->sx[i] =
        g_old * aec->sx[i] +
        g_new * WEBRTC_SPL_MAX(
            xfw[0][i] * xfw[0][i] + xfw[1][i] * xfw[1][i],
            WebRtcAec_kMinFarendPSD);

    aec->sde[i][0] =
        g_old * aec->sde[i][0] +
        g_new * (dfw[0][i] * efw[0][i] + dfw[1][i] * efw[1][i]);
    aec->sde[i][1] =
        g_old * aec->sde[i][1] +
        g_new * (dfw[0][i] * efw[1][i] - dfw[1][i] * efw[0][i]);

    aec->sxd[i][0] =
        g_old * aec->sxd[i][0] +
        g_new * (dfw[0][i] * xfw[0][i] + dfw[1][i] * xfw[1][i]);
    aec->sxd[i][1] =
        g_old * aec->sxd[i][1] +
        g_new * (dfw[0][i] * xfw[1][i] - dfw[1][i] * xfw[0][i]);

    sdSum += aec->sd[i];
    seSum += aec->se[i];
  }

  // Divergent filter safeguard update. Once diverged, the state is kept until
  // seSum drops to at most sdSum / 1.05 (hysteresis).
  {
    const float hysteresis = aec->divergeState ? 1.05f : 1.0f;
    aec->divergeState = hysteresis * seSum > sdSum;
  }

  // Signal extreme filter divergence if the error is significantly larger
  // than the nearend (13 dB).
  *extreme_filter_divergence = seSum > 19.95f * sdSum;
}
637 
638 // Window time domain data to be used by the fft.
WindowDataNEON(float * x_windowed,const float * x)639 static void WindowDataNEON(float* x_windowed, const float* x) {
640   int i;
641   for (i = 0; i < PART_LEN; i += 4) {
642     const float32x4_t vec_Buf1 = vld1q_f32(&x[i]);
643     const float32x4_t vec_Buf2 = vld1q_f32(&x[PART_LEN + i]);
644     const float32x4_t vec_sqrtHanning = vld1q_f32(&WebRtcAec_sqrtHanning[i]);
645     // A B C D
646     float32x4_t vec_sqrtHanning_rev =
647         vld1q_f32(&WebRtcAec_sqrtHanning[PART_LEN - i - 3]);
648     // B A D C
649     vec_sqrtHanning_rev = vrev64q_f32(vec_sqrtHanning_rev);
650     // D C B A
651     vec_sqrtHanning_rev = vcombine_f32(vget_high_f32(vec_sqrtHanning_rev),
652                                        vget_low_f32(vec_sqrtHanning_rev));
653     vst1q_f32(&x_windowed[i], vmulq_f32(vec_Buf1, vec_sqrtHanning));
654     vst1q_f32(&x_windowed[PART_LEN + i],
655             vmulq_f32(vec_Buf2, vec_sqrtHanning_rev));
656   }
657 }
658 
659 // Puts fft output data into a complex valued array.
StoreAsComplexNEON(const float * data,float data_complex[2][PART_LEN1])660 static void StoreAsComplexNEON(const float* data,
661                                float data_complex[2][PART_LEN1]) {
662   int i;
663   for (i = 0; i < PART_LEN; i += 4) {
664     const float32x4x2_t vec_data = vld2q_f32(&data[2 * i]);
665     vst1q_f32(&data_complex[0][i], vec_data.val[0]);
666     vst1q_f32(&data_complex[1][i], vec_data.val[1]);
667   }
668   // fix beginning/end values
669   data_complex[1][0] = 0;
670   data_complex[1][PART_LEN] = 0;
671   data_complex[0][0] = data[0];
672   data_complex[0][PART_LEN] = data[1];
673 }
674 
// Updates the smoothed PSDs through SmoothedPSD(), which also sets
// *extreme_filter_divergence. It then computes, per bin:
//   cohde = |sde|^2 / (sd * se + 1e-10)
//   cohxd = |sxd|^2 / (sd * sx + 1e-10)
// The fft argument is not used by this implementation.
static void SubbandCoherenceNEON(AecCore* aec,
                                 float efw[2][PART_LEN1],
                                 float dfw[2][PART_LEN1],
                                 float xfw[2][PART_LEN1],
                                 float* fft,
                                 float* cohde,
                                 float* cohxd,
                                 int* extreme_filter_divergence) {
  const float32x4_t eps = vdupq_n_f32(1e-10f);
  int k;

  (void)fft;
  SmoothedPSD(aec, efw, dfw, xfw, extreme_filter_divergence);

  // Subband coherence, four bins per iteration.
  for (k = 0; k + 3 < PART_LEN1; k += 4) {
    const float32x4_t sd = vld1q_f32(&aec->sd[k]);
    const float32x4_t se = vld1q_f32(&aec->se[k]);
    const float32x4_t sx = vld1q_f32(&aec->sx[k]);
    const float32x4x2_t sde = vld2q_f32(&aec->sde[k][0]);
    const float32x4x2_t sxd = vld2q_f32(&aec->sxd[k][0]);
    const float32x4_t sde_pow =
        vmlaq_f32(vmulq_f32(sde.val[0], sde.val[0]), sde.val[1], sde.val[1]);
    const float32x4_t sxd_pow =
        vmlaq_f32(vmulq_f32(sxd.val[0], sxd.val[0]), sxd.val[1], sxd.val[1]);
    vst1q_f32(&cohde[k], vdivq_f32(sde_pow, vmlaq_f32(eps, sd, se)));
    vst1q_f32(&cohxd[k], vdivq_f32(sxd_pow, vmlaq_f32(eps, sd, sx)));
  }

  // Remaining bins (PART_LEN1 is not a multiple of four).
  for (; k < PART_LEN1; k++) {
    cohde[k] =
        (aec->sde[k][0] * aec->sde[k][0] + aec->sde[k][1] * aec->sde[k][1]) /
        (aec->sd[k] * aec->se[k] + 1e-10f);
    cohxd[k] =
        (aec->sxd[k][0] * aec->sxd[k][0] + aec->sxd[k][1] * aec->sxd[k][1]) /
        (aec->sx[k] * aec->sd[k] + 1e-10f);
  }
}
720 
WebRtcAec_InitAec_neon(void)721 void WebRtcAec_InitAec_neon(void) {
722   WebRtcAec_FilterFar = FilterFarNEON;
723   WebRtcAec_ScaleErrorSignal = ScaleErrorSignalNEON;
724   WebRtcAec_FilterAdaptation = FilterAdaptationNEON;
725   WebRtcAec_OverdriveAndSuppress = OverdriveAndSuppressNEON;
726   WebRtcAec_SubbandCoherence = SubbandCoherenceNEON;
727   WebRtcAec_StoreAsComplex = StoreAsComplexNEON;
728   WebRtcAec_PartitionDelay = PartitionDelayNEON;
729   WebRtcAec_WindowData = WindowDataNEON;
730 }
731