• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/adaptive_fir_filter.h"
12 
13 // Defines WEBRTC_ARCH_X86_FAMILY, used below.
14 #include "rtc_base/system/arch.h"
15 
16 #if defined(WEBRTC_HAS_NEON)
17 #include <arm_neon.h>
18 #endif
19 #if defined(WEBRTC_ARCH_X86_FAMILY)
20 #include <emmintrin.h>
21 #endif
22 #include <math.h>
23 
24 #include <algorithm>
25 #include <functional>
26 
27 #include "modules/audio_processing/aec3/fft_data.h"
28 #include "rtc_base/checks.h"
29 
30 namespace webrtc {
31 
32 namespace aec3 {
33 
34 // Computes and stores the frequency response of the filter.
ComputeFrequencyResponse(size_t num_partitions,const std::vector<std::vector<FftData>> & H,std::vector<std::array<float,kFftLengthBy2Plus1>> * H2)35 void ComputeFrequencyResponse(
36     size_t num_partitions,
37     const std::vector<std::vector<FftData>>& H,
38     std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
39   for (auto& H2_ch : *H2) {
40     H2_ch.fill(0.f);
41   }
42 
43   const size_t num_render_channels = H[0].size();
44   RTC_DCHECK_EQ(H.size(), H2->capacity());
45   for (size_t p = 0; p < num_partitions; ++p) {
46     RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size());
47     for (size_t ch = 0; ch < num_render_channels; ++ch) {
48       for (size_t j = 0; j < kFftLengthBy2Plus1; ++j) {
49         float tmp =
50             H[p][ch].re[j] * H[p][ch].re[j] + H[p][ch].im[j] * H[p][ch].im[j];
51         (*H2)[p][j] = std::max((*H2)[p][j], tmp);
52       }
53     }
54   }
55 }
56 
57 #if defined(WEBRTC_HAS_NEON)
58 // Computes and stores the frequency response of the filter.
ComputeFrequencyResponse_Neon(size_t num_partitions,const std::vector<std::vector<FftData>> & H,std::vector<std::array<float,kFftLengthBy2Plus1>> * H2)59 void ComputeFrequencyResponse_Neon(
60     size_t num_partitions,
61     const std::vector<std::vector<FftData>>& H,
62     std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
63   for (auto& H2_ch : *H2) {
64     H2_ch.fill(0.f);
65   }
66 
67   const size_t num_render_channels = H[0].size();
68   RTC_DCHECK_EQ(H.size(), H2->capacity());
69   for (size_t p = 0; p < num_partitions; ++p) {
70     RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size());
71     for (size_t ch = 0; ch < num_render_channels; ++ch) {
72       for (size_t j = 0; j < kFftLengthBy2; j += 4) {
73         const float32x4_t re = vld1q_f32(&H[p][ch].re[j]);
74         const float32x4_t im = vld1q_f32(&H[p][ch].im[j]);
75         float32x4_t H2_new = vmulq_f32(re, re);
76         H2_new = vmlaq_f32(H2_new, im, im);
77         float32x4_t H2_p_j = vld1q_f32(&(*H2)[p][j]);
78         H2_p_j = vmaxq_f32(H2_p_j, H2_new);
79         vst1q_f32(&(*H2)[p][j], H2_p_j);
80       }
81       float H2_new = H[p][ch].re[kFftLengthBy2] * H[p][ch].re[kFftLengthBy2] +
82                      H[p][ch].im[kFftLengthBy2] * H[p][ch].im[kFftLengthBy2];
83       (*H2)[p][kFftLengthBy2] = std::max((*H2)[p][kFftLengthBy2], H2_new);
84     }
85   }
86 }
87 #endif
88 
89 #if defined(WEBRTC_ARCH_X86_FAMILY)
90 // Computes and stores the frequency response of the filter.
ComputeFrequencyResponse_Sse2(size_t num_partitions,const std::vector<std::vector<FftData>> & H,std::vector<std::array<float,kFftLengthBy2Plus1>> * H2)91 void ComputeFrequencyResponse_Sse2(
92     size_t num_partitions,
93     const std::vector<std::vector<FftData>>& H,
94     std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
95   for (auto& H2_ch : *H2) {
96     H2_ch.fill(0.f);
97   }
98 
99   const size_t num_render_channels = H[0].size();
100   RTC_DCHECK_EQ(H.size(), H2->capacity());
101   // constexpr __mmmask8 kMaxMask = static_cast<__mmmask8>(256u);
102   for (size_t p = 0; p < num_partitions; ++p) {
103     RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size());
104     for (size_t ch = 0; ch < num_render_channels; ++ch) {
105       for (size_t j = 0; j < kFftLengthBy2; j += 4) {
106         const __m128 re = _mm_loadu_ps(&H[p][ch].re[j]);
107         const __m128 re2 = _mm_mul_ps(re, re);
108         const __m128 im = _mm_loadu_ps(&H[p][ch].im[j]);
109         const __m128 im2 = _mm_mul_ps(im, im);
110         const __m128 H2_new = _mm_add_ps(re2, im2);
111         __m128 H2_k_j = _mm_loadu_ps(&(*H2)[p][j]);
112         H2_k_j = _mm_max_ps(H2_k_j, H2_new);
113         _mm_storeu_ps(&(*H2)[p][j], H2_k_j);
114       }
115       float H2_new = H[p][ch].re[kFftLengthBy2] * H[p][ch].re[kFftLengthBy2] +
116                      H[p][ch].im[kFftLengthBy2] * H[p][ch].im[kFftLengthBy2];
117       (*H2)[p][kFftLengthBy2] = std::max((*H2)[p][kFftLengthBy2], H2_new);
118     }
119   }
120 }
121 #endif
122 
123 // Adapts the filter partitions as H(t+1)=H(t)+G(t)*conj(X(t)).
AdaptPartitions(const RenderBuffer & render_buffer,const FftData & G,size_t num_partitions,std::vector<std::vector<FftData>> * H)124 void AdaptPartitions(const RenderBuffer& render_buffer,
125                      const FftData& G,
126                      size_t num_partitions,
127                      std::vector<std::vector<FftData>>* H) {
128   rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
129       render_buffer.GetFftBuffer();
130   size_t index = render_buffer.Position();
131   const size_t num_render_channels = render_buffer_data[index].size();
132   for (size_t p = 0; p < num_partitions; ++p) {
133     for (size_t ch = 0; ch < num_render_channels; ++ch) {
134       const FftData& X_p_ch = render_buffer_data[index][ch];
135       FftData& H_p_ch = (*H)[p][ch];
136       for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
137         H_p_ch.re[k] += X_p_ch.re[k] * G.re[k] + X_p_ch.im[k] * G.im[k];
138         H_p_ch.im[k] += X_p_ch.re[k] * G.im[k] - X_p_ch.im[k] * G.re[k];
139       }
140     }
141     index = index < (render_buffer_data.size() - 1) ? index + 1 : 0;
142   }
143 }
144 
145 #if defined(WEBRTC_HAS_NEON)
146 // Adapts the filter partitions. (Neon variant)
AdaptPartitions_Neon(const RenderBuffer & render_buffer,const FftData & G,size_t num_partitions,std::vector<std::vector<FftData>> * H)147 void AdaptPartitions_Neon(const RenderBuffer& render_buffer,
148                           const FftData& G,
149                           size_t num_partitions,
150                           std::vector<std::vector<FftData>>* H) {
151   rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
152       render_buffer.GetFftBuffer();
153   const size_t num_render_channels = render_buffer_data[0].size();
154   const size_t lim1 = std::min(
155       render_buffer_data.size() - render_buffer.Position(), num_partitions);
156   const size_t lim2 = num_partitions;
157   constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4;
158 
159   size_t X_partition = render_buffer.Position();
160   size_t limit = lim1;
161   size_t p = 0;
162   do {
163     for (; p < limit; ++p, ++X_partition) {
164       for (size_t ch = 0; ch < num_render_channels; ++ch) {
165         FftData& H_p_ch = (*H)[p][ch];
166         const FftData& X = render_buffer_data[X_partition][ch];
167         for (size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
168           const float32x4_t G_re = vld1q_f32(&G.re[k]);
169           const float32x4_t G_im = vld1q_f32(&G.im[k]);
170           const float32x4_t X_re = vld1q_f32(&X.re[k]);
171           const float32x4_t X_im = vld1q_f32(&X.im[k]);
172           const float32x4_t H_re = vld1q_f32(&H_p_ch.re[k]);
173           const float32x4_t H_im = vld1q_f32(&H_p_ch.im[k]);
174           const float32x4_t a = vmulq_f32(X_re, G_re);
175           const float32x4_t e = vmlaq_f32(a, X_im, G_im);
176           const float32x4_t c = vmulq_f32(X_re, G_im);
177           const float32x4_t f = vmlsq_f32(c, X_im, G_re);
178           const float32x4_t g = vaddq_f32(H_re, e);
179           const float32x4_t h = vaddq_f32(H_im, f);
180           vst1q_f32(&H_p_ch.re[k], g);
181           vst1q_f32(&H_p_ch.im[k], h);
182         }
183       }
184     }
185 
186     X_partition = 0;
187     limit = lim2;
188   } while (p < lim2);
189 
190   X_partition = render_buffer.Position();
191   limit = lim1;
192   p = 0;
193   do {
194     for (; p < limit; ++p, ++X_partition) {
195       for (size_t ch = 0; ch < num_render_channels; ++ch) {
196         FftData& H_p_ch = (*H)[p][ch];
197         const FftData& X = render_buffer_data[X_partition][ch];
198 
199         H_p_ch.re[kFftLengthBy2] += X.re[kFftLengthBy2] * G.re[kFftLengthBy2] +
200                                     X.im[kFftLengthBy2] * G.im[kFftLengthBy2];
201         H_p_ch.im[kFftLengthBy2] += X.re[kFftLengthBy2] * G.im[kFftLengthBy2] -
202                                     X.im[kFftLengthBy2] * G.re[kFftLengthBy2];
203       }
204     }
205     X_partition = 0;
206     limit = lim2;
207   } while (p < lim2);
208 }
209 #endif
210 
211 #if defined(WEBRTC_ARCH_X86_FAMILY)
212 // Adapts the filter partitions. (SSE2 variant)
AdaptPartitions_Sse2(const RenderBuffer & render_buffer,const FftData & G,size_t num_partitions,std::vector<std::vector<FftData>> * H)213 void AdaptPartitions_Sse2(const RenderBuffer& render_buffer,
214                           const FftData& G,
215                           size_t num_partitions,
216                           std::vector<std::vector<FftData>>* H) {
217   rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
218       render_buffer.GetFftBuffer();
219   const size_t num_render_channels = render_buffer_data[0].size();
220   const size_t lim1 = std::min(
221       render_buffer_data.size() - render_buffer.Position(), num_partitions);
222   const size_t lim2 = num_partitions;
223   constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4;
224 
225   size_t X_partition = render_buffer.Position();
226   size_t limit = lim1;
227   size_t p = 0;
228   do {
229     for (; p < limit; ++p, ++X_partition) {
230       for (size_t ch = 0; ch < num_render_channels; ++ch) {
231         FftData& H_p_ch = (*H)[p][ch];
232         const FftData& X = render_buffer_data[X_partition][ch];
233 
234         for (size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
235           const __m128 G_re = _mm_loadu_ps(&G.re[k]);
236           const __m128 G_im = _mm_loadu_ps(&G.im[k]);
237           const __m128 X_re = _mm_loadu_ps(&X.re[k]);
238           const __m128 X_im = _mm_loadu_ps(&X.im[k]);
239           const __m128 H_re = _mm_loadu_ps(&H_p_ch.re[k]);
240           const __m128 H_im = _mm_loadu_ps(&H_p_ch.im[k]);
241           const __m128 a = _mm_mul_ps(X_re, G_re);
242           const __m128 b = _mm_mul_ps(X_im, G_im);
243           const __m128 c = _mm_mul_ps(X_re, G_im);
244           const __m128 d = _mm_mul_ps(X_im, G_re);
245           const __m128 e = _mm_add_ps(a, b);
246           const __m128 f = _mm_sub_ps(c, d);
247           const __m128 g = _mm_add_ps(H_re, e);
248           const __m128 h = _mm_add_ps(H_im, f);
249           _mm_storeu_ps(&H_p_ch.re[k], g);
250           _mm_storeu_ps(&H_p_ch.im[k], h);
251         }
252       }
253     }
254     X_partition = 0;
255     limit = lim2;
256   } while (p < lim2);
257 
258   X_partition = render_buffer.Position();
259   limit = lim1;
260   p = 0;
261   do {
262     for (; p < limit; ++p, ++X_partition) {
263       for (size_t ch = 0; ch < num_render_channels; ++ch) {
264         FftData& H_p_ch = (*H)[p][ch];
265         const FftData& X = render_buffer_data[X_partition][ch];
266 
267         H_p_ch.re[kFftLengthBy2] += X.re[kFftLengthBy2] * G.re[kFftLengthBy2] +
268                                     X.im[kFftLengthBy2] * G.im[kFftLengthBy2];
269         H_p_ch.im[kFftLengthBy2] += X.re[kFftLengthBy2] * G.im[kFftLengthBy2] -
270                                     X.im[kFftLengthBy2] * G.re[kFftLengthBy2];
271       }
272     }
273 
274     X_partition = 0;
275     limit = lim2;
276   } while (p < lim2);
277 }
278 #endif
279 
280 // Produces the filter output.
ApplyFilter(const RenderBuffer & render_buffer,size_t num_partitions,const std::vector<std::vector<FftData>> & H,FftData * S)281 void ApplyFilter(const RenderBuffer& render_buffer,
282                  size_t num_partitions,
283                  const std::vector<std::vector<FftData>>& H,
284                  FftData* S) {
285   S->re.fill(0.f);
286   S->im.fill(0.f);
287 
288   rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
289       render_buffer.GetFftBuffer();
290   size_t index = render_buffer.Position();
291   const size_t num_render_channels = render_buffer_data[index].size();
292   for (size_t p = 0; p < num_partitions; ++p) {
293     RTC_DCHECK_EQ(num_render_channels, H[p].size());
294     for (size_t ch = 0; ch < num_render_channels; ++ch) {
295       const FftData& X_p_ch = render_buffer_data[index][ch];
296       const FftData& H_p_ch = H[p][ch];
297       for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
298         S->re[k] += X_p_ch.re[k] * H_p_ch.re[k] - X_p_ch.im[k] * H_p_ch.im[k];
299         S->im[k] += X_p_ch.re[k] * H_p_ch.im[k] + X_p_ch.im[k] * H_p_ch.re[k];
300       }
301     }
302     index = index < (render_buffer_data.size() - 1) ? index + 1 : 0;
303   }
304 }
305 
306 #if defined(WEBRTC_HAS_NEON)
307 // Produces the filter output (Neon variant).
ApplyFilter_Neon(const RenderBuffer & render_buffer,size_t num_partitions,const std::vector<std::vector<FftData>> & H,FftData * S)308 void ApplyFilter_Neon(const RenderBuffer& render_buffer,
309                       size_t num_partitions,
310                       const std::vector<std::vector<FftData>>& H,
311                       FftData* S) {
312   // const RenderBuffer& render_buffer,
313   //                     rtc::ArrayView<const FftData> H,
314   //                     FftData* S) {
315   RTC_DCHECK_GE(H.size(), H.size() - 1);
316   S->Clear();
317 
318   rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
319       render_buffer.GetFftBuffer();
320   const size_t num_render_channels = render_buffer_data[0].size();
321   const size_t lim1 = std::min(
322       render_buffer_data.size() - render_buffer.Position(), num_partitions);
323   const size_t lim2 = num_partitions;
324   constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4;
325 
326   size_t X_partition = render_buffer.Position();
327   size_t p = 0;
328   size_t limit = lim1;
329   do {
330     for (; p < limit; ++p, ++X_partition) {
331       for (size_t ch = 0; ch < num_render_channels; ++ch) {
332         const FftData& H_p_ch = H[p][ch];
333         const FftData& X = render_buffer_data[X_partition][ch];
334         for (size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
335           const float32x4_t X_re = vld1q_f32(&X.re[k]);
336           const float32x4_t X_im = vld1q_f32(&X.im[k]);
337           const float32x4_t H_re = vld1q_f32(&H_p_ch.re[k]);
338           const float32x4_t H_im = vld1q_f32(&H_p_ch.im[k]);
339           const float32x4_t S_re = vld1q_f32(&S->re[k]);
340           const float32x4_t S_im = vld1q_f32(&S->im[k]);
341           const float32x4_t a = vmulq_f32(X_re, H_re);
342           const float32x4_t e = vmlsq_f32(a, X_im, H_im);
343           const float32x4_t c = vmulq_f32(X_re, H_im);
344           const float32x4_t f = vmlaq_f32(c, X_im, H_re);
345           const float32x4_t g = vaddq_f32(S_re, e);
346           const float32x4_t h = vaddq_f32(S_im, f);
347           vst1q_f32(&S->re[k], g);
348           vst1q_f32(&S->im[k], h);
349         }
350       }
351     }
352     limit = lim2;
353     X_partition = 0;
354   } while (p < lim2);
355 
356   X_partition = render_buffer.Position();
357   p = 0;
358   limit = lim1;
359   do {
360     for (; p < limit; ++p, ++X_partition) {
361       for (size_t ch = 0; ch < num_render_channels; ++ch) {
362         const FftData& H_p_ch = H[p][ch];
363         const FftData& X = render_buffer_data[X_partition][ch];
364         S->re[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2] -
365                                 X.im[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2];
366         S->im[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2] +
367                                 X.im[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2];
368       }
369     }
370     limit = lim2;
371     X_partition = 0;
372   } while (p < lim2);
373 }
374 #endif
375 
376 #if defined(WEBRTC_ARCH_X86_FAMILY)
377 // Produces the filter output (SSE2 variant).
ApplyFilter_Sse2(const RenderBuffer & render_buffer,size_t num_partitions,const std::vector<std::vector<FftData>> & H,FftData * S)378 void ApplyFilter_Sse2(const RenderBuffer& render_buffer,
379                       size_t num_partitions,
380                       const std::vector<std::vector<FftData>>& H,
381                       FftData* S) {
382   // const RenderBuffer& render_buffer,
383   //                     rtc::ArrayView<const FftData> H,
384   //                     FftData* S) {
385   RTC_DCHECK_GE(H.size(), H.size() - 1);
386   S->re.fill(0.f);
387   S->im.fill(0.f);
388 
389   rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
390       render_buffer.GetFftBuffer();
391   const size_t num_render_channels = render_buffer_data[0].size();
392   const size_t lim1 = std::min(
393       render_buffer_data.size() - render_buffer.Position(), num_partitions);
394   const size_t lim2 = num_partitions;
395   constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4;
396 
397   size_t X_partition = render_buffer.Position();
398   size_t p = 0;
399   size_t limit = lim1;
400   do {
401     for (; p < limit; ++p, ++X_partition) {
402       for (size_t ch = 0; ch < num_render_channels; ++ch) {
403         const FftData& H_p_ch = H[p][ch];
404         const FftData& X = render_buffer_data[X_partition][ch];
405         for (size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
406           const __m128 X_re = _mm_loadu_ps(&X.re[k]);
407           const __m128 X_im = _mm_loadu_ps(&X.im[k]);
408           const __m128 H_re = _mm_loadu_ps(&H_p_ch.re[k]);
409           const __m128 H_im = _mm_loadu_ps(&H_p_ch.im[k]);
410           const __m128 S_re = _mm_loadu_ps(&S->re[k]);
411           const __m128 S_im = _mm_loadu_ps(&S->im[k]);
412           const __m128 a = _mm_mul_ps(X_re, H_re);
413           const __m128 b = _mm_mul_ps(X_im, H_im);
414           const __m128 c = _mm_mul_ps(X_re, H_im);
415           const __m128 d = _mm_mul_ps(X_im, H_re);
416           const __m128 e = _mm_sub_ps(a, b);
417           const __m128 f = _mm_add_ps(c, d);
418           const __m128 g = _mm_add_ps(S_re, e);
419           const __m128 h = _mm_add_ps(S_im, f);
420           _mm_storeu_ps(&S->re[k], g);
421           _mm_storeu_ps(&S->im[k], h);
422         }
423       }
424     }
425     limit = lim2;
426     X_partition = 0;
427   } while (p < lim2);
428 
429   X_partition = render_buffer.Position();
430   p = 0;
431   limit = lim1;
432   do {
433     for (; p < limit; ++p, ++X_partition) {
434       for (size_t ch = 0; ch < num_render_channels; ++ch) {
435         const FftData& H_p_ch = H[p][ch];
436         const FftData& X = render_buffer_data[X_partition][ch];
437         S->re[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2] -
438                                 X.im[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2];
439         S->im[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2] +
440                                 X.im[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2];
441       }
442     }
443     limit = lim2;
444     X_partition = 0;
445   } while (p < lim2);
446 }
447 #endif
448 
449 }  // namespace aec3
450 
451 namespace {
452 
453 // Ensures that the newly added filter partitions after a size increase are set
454 // to zero.
ZeroFilter(size_t old_size,size_t new_size,std::vector<std::vector<FftData>> * H)455 void ZeroFilter(size_t old_size,
456                 size_t new_size,
457                 std::vector<std::vector<FftData>>* H) {
458   RTC_DCHECK_GE(H->size(), old_size);
459   RTC_DCHECK_GE(H->size(), new_size);
460 
461   for (size_t p = old_size; p < new_size; ++p) {
462     RTC_DCHECK_EQ((*H)[p].size(), (*H)[0].size());
463     for (size_t ch = 0; ch < (*H)[0].size(); ++ch) {
464       (*H)[p][ch].Clear();
465     }
466   }
467 }
468 
469 }  // namespace
470 
AdaptiveFirFilter(size_t max_size_partitions,size_t initial_size_partitions,size_t size_change_duration_blocks,size_t num_render_channels,Aec3Optimization optimization,ApmDataDumper * data_dumper)471 AdaptiveFirFilter::AdaptiveFirFilter(size_t max_size_partitions,
472                                      size_t initial_size_partitions,
473                                      size_t size_change_duration_blocks,
474                                      size_t num_render_channels,
475                                      Aec3Optimization optimization,
476                                      ApmDataDumper* data_dumper)
477     : data_dumper_(data_dumper),
478       fft_(),
479       optimization_(optimization),
480       num_render_channels_(num_render_channels),
481       max_size_partitions_(max_size_partitions),
482       size_change_duration_blocks_(
483           static_cast<int>(size_change_duration_blocks)),
484       current_size_partitions_(initial_size_partitions),
485       target_size_partitions_(initial_size_partitions),
486       old_target_size_partitions_(initial_size_partitions),
487       H_(max_size_partitions_, std::vector<FftData>(num_render_channels_)) {
488   RTC_DCHECK(data_dumper_);
489   RTC_DCHECK_GE(max_size_partitions, initial_size_partitions);
490 
491   RTC_DCHECK_LT(0, size_change_duration_blocks_);
492   one_by_size_change_duration_blocks_ = 1.f / size_change_duration_blocks_;
493 
494   ZeroFilter(0, max_size_partitions_, &H_);
495 
496   SetSizePartitions(current_size_partitions_, true);
497 }
498 
499 AdaptiveFirFilter::~AdaptiveFirFilter() = default;
500 
HandleEchoPathChange()501 void AdaptiveFirFilter::HandleEchoPathChange() {
502   // TODO(peah): Check the value and purpose of the code below.
503   ZeroFilter(current_size_partitions_, max_size_partitions_, &H_);
504 }
505 
SetSizePartitions(size_t size,bool immediate_effect)506 void AdaptiveFirFilter::SetSizePartitions(size_t size, bool immediate_effect) {
507   RTC_DCHECK_EQ(max_size_partitions_, H_.capacity());
508   RTC_DCHECK_LE(size, max_size_partitions_);
509 
510   target_size_partitions_ = std::min(max_size_partitions_, size);
511   if (immediate_effect) {
512     size_t old_size_partitions_ = current_size_partitions_;
513     current_size_partitions_ = old_target_size_partitions_ =
514         target_size_partitions_;
515     ZeroFilter(old_size_partitions_, current_size_partitions_, &H_);
516 
517     partition_to_constrain_ =
518         std::min(partition_to_constrain_, current_size_partitions_ - 1);
519     size_change_counter_ = 0;
520   } else {
521     size_change_counter_ = size_change_duration_blocks_;
522   }
523 }
524 
UpdateSize()525 void AdaptiveFirFilter::UpdateSize() {
526   RTC_DCHECK_GE(size_change_duration_blocks_, size_change_counter_);
527   size_t old_size_partitions_ = current_size_partitions_;
528   if (size_change_counter_ > 0) {
529     --size_change_counter_;
530 
531     auto average = [](float from, float to, float from_weight) {
532       return from * from_weight + to * (1.f - from_weight);
533     };
534 
535     float change_factor =
536         size_change_counter_ * one_by_size_change_duration_blocks_;
537 
538     current_size_partitions_ = average(old_target_size_partitions_,
539                                        target_size_partitions_, change_factor);
540 
541     partition_to_constrain_ =
542         std::min(partition_to_constrain_, current_size_partitions_ - 1);
543   } else {
544     current_size_partitions_ = old_target_size_partitions_ =
545         target_size_partitions_;
546   }
547   ZeroFilter(old_size_partitions_, current_size_partitions_, &H_);
548   RTC_DCHECK_LE(0, size_change_counter_);
549 }
550 
Filter(const RenderBuffer & render_buffer,FftData * S) const551 void AdaptiveFirFilter::Filter(const RenderBuffer& render_buffer,
552                                FftData* S) const {
553   RTC_DCHECK(S);
554   switch (optimization_) {
555 #if defined(WEBRTC_ARCH_X86_FAMILY)
556     case Aec3Optimization::kSse2:
557       aec3::ApplyFilter_Sse2(render_buffer, current_size_partitions_, H_, S);
558       break;
559 #endif
560 #if defined(WEBRTC_HAS_NEON)
561     case Aec3Optimization::kNeon:
562       aec3::ApplyFilter_Neon(render_buffer, current_size_partitions_, H_, S);
563       break;
564 #endif
565     default:
566       aec3::ApplyFilter(render_buffer, current_size_partitions_, H_, S);
567   }
568 }
569 
Adapt(const RenderBuffer & render_buffer,const FftData & G)570 void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer,
571                               const FftData& G) {
572   // Adapt the filter and update the filter size.
573   AdaptAndUpdateSize(render_buffer, G);
574 
575   // Constrain the filter partitions in a cyclic manner.
576   Constrain();
577 }
578 
Adapt(const RenderBuffer & render_buffer,const FftData & G,std::vector<float> * impulse_response)579 void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer,
580                               const FftData& G,
581                               std::vector<float>* impulse_response) {
582   // Adapt the filter and update the filter size.
583   AdaptAndUpdateSize(render_buffer, G);
584 
585   // Constrain the filter partitions in a cyclic manner.
586   ConstrainAndUpdateImpulseResponse(impulse_response);
587 }
588 
ComputeFrequencyResponse(std::vector<std::array<float,kFftLengthBy2Plus1>> * H2) const589 void AdaptiveFirFilter::ComputeFrequencyResponse(
590     std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) const {
591   RTC_DCHECK_GE(max_size_partitions_, H2->capacity());
592 
593   H2->resize(current_size_partitions_);
594 
595   switch (optimization_) {
596 #if defined(WEBRTC_ARCH_X86_FAMILY)
597     case Aec3Optimization::kSse2:
598       aec3::ComputeFrequencyResponse_Sse2(current_size_partitions_, H_, H2);
599       break;
600 #endif
601 #if defined(WEBRTC_HAS_NEON)
602     case Aec3Optimization::kNeon:
603       aec3::ComputeFrequencyResponse_Neon(current_size_partitions_, H_, H2);
604       break;
605 #endif
606     default:
607       aec3::ComputeFrequencyResponse(current_size_partitions_, H_, H2);
608   }
609 }
610 
AdaptAndUpdateSize(const RenderBuffer & render_buffer,const FftData & G)611 void AdaptiveFirFilter::AdaptAndUpdateSize(const RenderBuffer& render_buffer,
612                                            const FftData& G) {
613   // Update the filter size if needed.
614   UpdateSize();
615 
616   // Adapt the filter.
617   switch (optimization_) {
618 #if defined(WEBRTC_ARCH_X86_FAMILY)
619     case Aec3Optimization::kSse2:
620       aec3::AdaptPartitions_Sse2(render_buffer, G, current_size_partitions_,
621                                  &H_);
622       break;
623 #endif
624 #if defined(WEBRTC_HAS_NEON)
625     case Aec3Optimization::kNeon:
626       aec3::AdaptPartitions_Neon(render_buffer, G, current_size_partitions_,
627                                  &H_);
628       break;
629 #endif
630     default:
631       aec3::AdaptPartitions(render_buffer, G, current_size_partitions_, &H_);
632   }
633 }
634 
635 // Constrains the partition of the frequency domain filter to be limited in
636 // time via setting the relevant time-domain coefficients to zero and updates
637 // the corresponding values in an externally stored impulse response estimate.
ConstrainAndUpdateImpulseResponse(std::vector<float> * impulse_response)638 void AdaptiveFirFilter::ConstrainAndUpdateImpulseResponse(
639     std::vector<float>* impulse_response) {
640   RTC_DCHECK_EQ(GetTimeDomainLength(max_size_partitions_),
641                 impulse_response->capacity());
642   impulse_response->resize(GetTimeDomainLength(current_size_partitions_));
643   std::array<float, kFftLength> h;
644   impulse_response->resize(GetTimeDomainLength(current_size_partitions_));
645   std::fill(
646       impulse_response->begin() + partition_to_constrain_ * kFftLengthBy2,
647       impulse_response->begin() + (partition_to_constrain_ + 1) * kFftLengthBy2,
648       0.f);
649 
650   for (size_t ch = 0; ch < num_render_channels_; ++ch) {
651     fft_.Ifft(H_[partition_to_constrain_][ch], &h);
652 
653     static constexpr float kScale = 1.0f / kFftLengthBy2;
654     std::for_each(h.begin(), h.begin() + kFftLengthBy2,
655                   [](float& a) { a *= kScale; });
656     std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f);
657 
658     if (ch == 0) {
659       std::copy(
660           h.begin(), h.begin() + kFftLengthBy2,
661           impulse_response->begin() + partition_to_constrain_ * kFftLengthBy2);
662     } else {
663       for (size_t k = 0, j = partition_to_constrain_ * kFftLengthBy2;
664            k < kFftLengthBy2; ++k, ++j) {
665         if (fabsf((*impulse_response)[j]) < fabsf(h[k])) {
666           (*impulse_response)[j] = h[k];
667         }
668       }
669     }
670 
671     fft_.Fft(&h, &H_[partition_to_constrain_][ch]);
672   }
673 
674   partition_to_constrain_ =
675       partition_to_constrain_ < (current_size_partitions_ - 1)
676           ? partition_to_constrain_ + 1
677           : 0;
678 }
679 
680 // Constrains the a partiton of the frequency domain filter to be limited in
681 // time via setting the relevant time-domain coefficients to zero.
Constrain()682 void AdaptiveFirFilter::Constrain() {
683   std::array<float, kFftLength> h;
684   for (size_t ch = 0; ch < num_render_channels_; ++ch) {
685     fft_.Ifft(H_[partition_to_constrain_][ch], &h);
686 
687     static constexpr float kScale = 1.0f / kFftLengthBy2;
688     std::for_each(h.begin(), h.begin() + kFftLengthBy2,
689                   [](float& a) { a *= kScale; });
690     std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f);
691 
692     fft_.Fft(&h, &H_[partition_to_constrain_][ch]);
693   }
694 
695   partition_to_constrain_ =
696       partition_to_constrain_ < (current_size_partitions_ - 1)
697           ? partition_to_constrain_ + 1
698           : 0;
699 }
700 
ScaleFilter(float factor)701 void AdaptiveFirFilter::ScaleFilter(float factor) {
702   for (auto& H_p : H_) {
703     for (auto& H_p_ch : H_p) {
704       for (auto& re : H_p_ch.re) {
705         re *= factor;
706       }
707       for (auto& im : H_p_ch.im) {
708         im *= factor;
709       }
710     }
711   }
712 }
713 
714 // Set the filter coefficients.
SetFilter(size_t num_partitions,const std::vector<std::vector<FftData>> & H)715 void AdaptiveFirFilter::SetFilter(size_t num_partitions,
716                                   const std::vector<std::vector<FftData>>& H) {
717   const size_t min_num_partitions =
718       std::min(current_size_partitions_, num_partitions);
719   for (size_t p = 0; p < min_num_partitions; ++p) {
720     RTC_DCHECK_EQ(H_[p].size(), H[p].size());
721     RTC_DCHECK_EQ(num_render_channels_, H_[p].size());
722 
723     for (size_t ch = 0; ch < num_render_channels_; ++ch) {
724       std::copy(H[p][ch].re.begin(), H[p][ch].re.end(), H_[p][ch].re.begin());
725       std::copy(H[p][ch].im.begin(), H[p][ch].im.end(), H_[p][ch].im.begin());
726     }
727   }
728 }
729 
730 }  // namespace webrtc
731