• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/subtractor.h"
12 
13 #include <algorithm>
14 #include <utility>
15 
16 #include "api/array_view.h"
17 #include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h"
18 #include "modules/audio_processing/aec3/fft_data.h"
19 #include "modules/audio_processing/logging/apm_data_dumper.h"
20 #include "rtc_base/checks.h"
21 #include "rtc_base/numerics/safe_minmax.h"
22 
23 namespace webrtc {
24 
25 namespace {
26 
PredictionError(const Aec3Fft & fft,const FftData & S,rtc::ArrayView<const float> y,std::array<float,kBlockSize> * e,std::array<float,kBlockSize> * s)27 void PredictionError(const Aec3Fft& fft,
28                      const FftData& S,
29                      rtc::ArrayView<const float> y,
30                      std::array<float, kBlockSize>* e,
31                      std::array<float, kBlockSize>* s) {
32   std::array<float, kFftLength> tmp;
33   fft.Ifft(S, &tmp);
34   constexpr float kScale = 1.0f / kFftLengthBy2;
35   std::transform(y.begin(), y.end(), tmp.begin() + kFftLengthBy2, e->begin(),
36                  [&](float a, float b) { return a - b * kScale; });
37 
38   if (s) {
39     for (size_t k = 0; k < s->size(); ++k) {
40       (*s)[k] = kScale * tmp[k + kFftLengthBy2];
41     }
42   }
43 }
44 
ScaleFilterOutput(rtc::ArrayView<const float> y,float factor,rtc::ArrayView<float> e,rtc::ArrayView<float> s)45 void ScaleFilterOutput(rtc::ArrayView<const float> y,
46                        float factor,
47                        rtc::ArrayView<float> e,
48                        rtc::ArrayView<float> s) {
49   RTC_DCHECK_EQ(y.size(), e.size());
50   RTC_DCHECK_EQ(y.size(), s.size());
51   for (size_t k = 0; k < y.size(); ++k) {
52     s[k] *= factor;
53     e[k] = y[k] - s[k];
54   }
55 }
56 
57 }  // namespace
58 
Subtractor(const EchoCanceller3Config & config,size_t num_render_channels,size_t num_capture_channels,ApmDataDumper * data_dumper,Aec3Optimization optimization)59 Subtractor::Subtractor(const EchoCanceller3Config& config,
60                        size_t num_render_channels,
61                        size_t num_capture_channels,
62                        ApmDataDumper* data_dumper,
63                        Aec3Optimization optimization)
64     : fft_(),
65       data_dumper_(data_dumper),
66       optimization_(optimization),
67       config_(config),
68       num_capture_channels_(num_capture_channels),
69       refined_filters_(num_capture_channels_),
70       coarse_filter_(num_capture_channels_),
71       refined_gains_(num_capture_channels_),
72       coarse_gains_(num_capture_channels_),
73       filter_misadjustment_estimators_(num_capture_channels_),
74       poor_coarse_filter_counters_(num_capture_channels_, 0),
75       refined_frequency_responses_(
76           num_capture_channels_,
77           std::vector<std::array<float, kFftLengthBy2Plus1>>(
78               std::max(config_.filter.refined_initial.length_blocks,
79                        config_.filter.refined.length_blocks),
80               std::array<float, kFftLengthBy2Plus1>())),
81       refined_impulse_responses_(
82           num_capture_channels_,
83           std::vector<float>(GetTimeDomainLength(std::max(
84                                  config_.filter.refined_initial.length_blocks,
85                                  config_.filter.refined.length_blocks)),
86                              0.f)) {
87   for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
88     refined_filters_[ch] = std::make_unique<AdaptiveFirFilter>(
89         config_.filter.refined.length_blocks,
90         config_.filter.refined_initial.length_blocks,
91         config.filter.config_change_duration_blocks, num_render_channels,
92         optimization, data_dumper_);
93 
94     coarse_filter_[ch] = std::make_unique<AdaptiveFirFilter>(
95         config_.filter.coarse.length_blocks,
96         config_.filter.coarse_initial.length_blocks,
97         config.filter.config_change_duration_blocks, num_render_channels,
98         optimization, data_dumper_);
99     refined_gains_[ch] = std::make_unique<RefinedFilterUpdateGain>(
100         config_.filter.refined_initial,
101         config_.filter.config_change_duration_blocks);
102     coarse_gains_[ch] = std::make_unique<CoarseFilterUpdateGain>(
103         config_.filter.coarse_initial,
104         config.filter.config_change_duration_blocks);
105   }
106 
107   RTC_DCHECK(data_dumper_);
108   for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
109     for (auto& H2_k : refined_frequency_responses_[ch]) {
110       H2_k.fill(0.f);
111     }
112   }
113 }
114 
115 Subtractor::~Subtractor() = default;
116 
HandleEchoPathChange(const EchoPathVariability & echo_path_variability)117 void Subtractor::HandleEchoPathChange(
118     const EchoPathVariability& echo_path_variability) {
119   const auto full_reset = [&]() {
120     for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
121       refined_filters_[ch]->HandleEchoPathChange();
122       coarse_filter_[ch]->HandleEchoPathChange();
123       refined_gains_[ch]->HandleEchoPathChange(echo_path_variability);
124       coarse_gains_[ch]->HandleEchoPathChange();
125       refined_gains_[ch]->SetConfig(config_.filter.refined_initial, true);
126       coarse_gains_[ch]->SetConfig(config_.filter.coarse_initial, true);
127       refined_filters_[ch]->SetSizePartitions(
128           config_.filter.refined_initial.length_blocks, true);
129       coarse_filter_[ch]->SetSizePartitions(
130           config_.filter.coarse_initial.length_blocks, true);
131     }
132   };
133 
134   if (echo_path_variability.delay_change !=
135       EchoPathVariability::DelayAdjustment::kNone) {
136     full_reset();
137   }
138 
139   if (echo_path_variability.gain_change) {
140     for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
141       refined_gains_[ch]->HandleEchoPathChange(echo_path_variability);
142     }
143   }
144 }
145 
ExitInitialState()146 void Subtractor::ExitInitialState() {
147   for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
148     refined_gains_[ch]->SetConfig(config_.filter.refined, false);
149     coarse_gains_[ch]->SetConfig(config_.filter.coarse, false);
150     refined_filters_[ch]->SetSizePartitions(
151         config_.filter.refined.length_blocks, false);
152     coarse_filter_[ch]->SetSizePartitions(config_.filter.coarse.length_blocks,
153                                           false);
154   }
155 }
156 
Process(const RenderBuffer & render_buffer,const std::vector<std::vector<float>> & capture,const RenderSignalAnalyzer & render_signal_analyzer,const AecState & aec_state,rtc::ArrayView<SubtractorOutput> outputs)157 void Subtractor::Process(const RenderBuffer& render_buffer,
158                          const std::vector<std::vector<float>>& capture,
159                          const RenderSignalAnalyzer& render_signal_analyzer,
160                          const AecState& aec_state,
161                          rtc::ArrayView<SubtractorOutput> outputs) {
162   RTC_DCHECK_EQ(num_capture_channels_, capture.size());
163 
164   // Compute the render powers.
165   const bool same_filter_sizes = refined_filters_[0]->SizePartitions() ==
166                                  coarse_filter_[0]->SizePartitions();
167   std::array<float, kFftLengthBy2Plus1> X2_refined;
168   std::array<float, kFftLengthBy2Plus1> X2_coarse_data;
169   auto& X2_coarse = same_filter_sizes ? X2_refined : X2_coarse_data;
170   if (same_filter_sizes) {
171     render_buffer.SpectralSum(refined_filters_[0]->SizePartitions(),
172                               &X2_refined);
173   } else if (refined_filters_[0]->SizePartitions() >
174              coarse_filter_[0]->SizePartitions()) {
175     render_buffer.SpectralSums(coarse_filter_[0]->SizePartitions(),
176                                refined_filters_[0]->SizePartitions(),
177                                &X2_coarse, &X2_refined);
178   } else {
179     render_buffer.SpectralSums(refined_filters_[0]->SizePartitions(),
180                                coarse_filter_[0]->SizePartitions(), &X2_refined,
181                                &X2_coarse);
182   }
183 
184   // Process all capture channels
185   for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
186     RTC_DCHECK_EQ(kBlockSize, capture[ch].size());
187     SubtractorOutput& output = outputs[ch];
188     rtc::ArrayView<const float> y = capture[ch];
189     FftData& E_refined = output.E_refined;
190     FftData E_coarse;
191     std::array<float, kBlockSize>& e_refined = output.e_refined;
192     std::array<float, kBlockSize>& e_coarse = output.e_coarse;
193 
194     FftData S;
195     FftData& G = S;
196 
197     // Form the outputs of the refined and coarse filters.
198     refined_filters_[ch]->Filter(render_buffer, &S);
199     PredictionError(fft_, S, y, &e_refined, &output.s_refined);
200 
201     coarse_filter_[ch]->Filter(render_buffer, &S);
202     PredictionError(fft_, S, y, &e_coarse, &output.s_coarse);
203 
204     // Compute the signal powers in the subtractor output.
205     output.ComputeMetrics(y);
206 
207     // Adjust the filter if needed.
208     bool refined_filters_adjusted = false;
209     filter_misadjustment_estimators_[ch].Update(output);
210     if (filter_misadjustment_estimators_[ch].IsAdjustmentNeeded()) {
211       float scale = filter_misadjustment_estimators_[ch].GetMisadjustment();
212       refined_filters_[ch]->ScaleFilter(scale);
213       for (auto& h_k : refined_impulse_responses_[ch]) {
214         h_k *= scale;
215       }
216       ScaleFilterOutput(y, scale, e_refined, output.s_refined);
217       filter_misadjustment_estimators_[ch].Reset();
218       refined_filters_adjusted = true;
219     }
220 
221     // Compute the FFts of the refined and coarse filter outputs.
222     fft_.ZeroPaddedFft(e_refined, Aec3Fft::Window::kHanning, &E_refined);
223     fft_.ZeroPaddedFft(e_coarse, Aec3Fft::Window::kHanning, &E_coarse);
224 
225     // Compute spectra for future use.
226     E_coarse.Spectrum(optimization_, output.E2_coarse);
227     E_refined.Spectrum(optimization_, output.E2_refined);
228 
229     // Update the refined filter.
230     if (!refined_filters_adjusted) {
231       std::array<float, kFftLengthBy2Plus1> erl;
232       ComputeErl(optimization_, refined_frequency_responses_[ch], erl);
233       refined_gains_[ch]->Compute(X2_refined, render_signal_analyzer, output,
234                                   erl, refined_filters_[ch]->SizePartitions(),
235                                   aec_state.SaturatedCapture(), &G);
236     } else {
237       G.re.fill(0.f);
238       G.im.fill(0.f);
239     }
240     refined_filters_[ch]->Adapt(render_buffer, G,
241                                 &refined_impulse_responses_[ch]);
242     refined_filters_[ch]->ComputeFrequencyResponse(
243         &refined_frequency_responses_[ch]);
244 
245     if (ch == 0) {
246       data_dumper_->DumpRaw("aec3_subtractor_G_refined", G.re);
247       data_dumper_->DumpRaw("aec3_subtractor_G_refined", G.im);
248     }
249 
250     // Update the coarse filter.
251     poor_coarse_filter_counters_[ch] =
252         output.e2_refined < output.e2_coarse
253             ? poor_coarse_filter_counters_[ch] + 1
254             : 0;
255     if (poor_coarse_filter_counters_[ch] < 5) {
256       coarse_gains_[ch]->Compute(X2_coarse, render_signal_analyzer, E_coarse,
257                                  coarse_filter_[ch]->SizePartitions(),
258                                  aec_state.SaturatedCapture(), &G);
259     } else {
260       poor_coarse_filter_counters_[ch] = 0;
261       coarse_filter_[ch]->SetFilter(refined_filters_[ch]->SizePartitions(),
262                                     refined_filters_[ch]->GetFilter());
263       coarse_gains_[ch]->Compute(X2_coarse, render_signal_analyzer, E_refined,
264                                  coarse_filter_[ch]->SizePartitions(),
265                                  aec_state.SaturatedCapture(), &G);
266     }
267 
268     coarse_filter_[ch]->Adapt(render_buffer, G);
269     if (ch == 0) {
270       data_dumper_->DumpRaw("aec3_subtractor_G_coarse", G.re);
271       data_dumper_->DumpRaw("aec3_subtractor_G_coarse", G.im);
272       filter_misadjustment_estimators_[ch].Dump(data_dumper_);
273       DumpFilters();
274     }
275 
276     std::for_each(e_refined.begin(), e_refined.end(),
277                   [](float& a) { a = rtc::SafeClamp(a, -32768.f, 32767.f); });
278 
279     if (ch == 0) {
280       data_dumper_->DumpWav("aec3_refined_filters_output", kBlockSize,
281                             &e_refined[0], 16000, 1);
282       data_dumper_->DumpWav("aec3_coarse_filter_output", kBlockSize,
283                             &e_coarse[0], 16000, 1);
284     }
285   }
286 }
287 
Update(const SubtractorOutput & output)288 void Subtractor::FilterMisadjustmentEstimator::Update(
289     const SubtractorOutput& output) {
290   e2_acum_ += output.e2_refined;
291   y2_acum_ += output.y2;
292   if (++n_blocks_acum_ == n_blocks_) {
293     if (y2_acum_ > n_blocks_ * 200.f * 200.f * kBlockSize) {
294       float update = (e2_acum_ / y2_acum_);
295       if (e2_acum_ > n_blocks_ * 7500.f * 7500.f * kBlockSize) {
296         // Duration equal to blockSizeMs * n_blocks_ * 4.
297         overhang_ = 4;
298       } else {
299         overhang_ = std::max(overhang_ - 1, 0);
300       }
301 
302       if ((update < inv_misadjustment_) || (overhang_ > 0)) {
303         inv_misadjustment_ += 0.1f * (update - inv_misadjustment_);
304       }
305     }
306     e2_acum_ = 0.f;
307     y2_acum_ = 0.f;
308     n_blocks_acum_ = 0;
309   }
310 }
311 
Reset()312 void Subtractor::FilterMisadjustmentEstimator::Reset() {
313   e2_acum_ = 0.f;
314   y2_acum_ = 0.f;
315   n_blocks_acum_ = 0;
316   inv_misadjustment_ = 0.f;
317   overhang_ = 0.f;
318 }
319 
Dump(ApmDataDumper * data_dumper) const320 void Subtractor::FilterMisadjustmentEstimator::Dump(
321     ApmDataDumper* data_dumper) const {
322   data_dumper->DumpRaw("aec3_inv_misadjustment_factor", inv_misadjustment_);
323 }
324 
325 }  // namespace webrtc
326