• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/subtractor.h"
12 
13 #include <algorithm>
14 #include <memory>
15 #include <numeric>
16 #include <string>
17 
18 #include "modules/audio_processing/aec3/aec_state.h"
19 #include "modules/audio_processing/aec3/render_delay_buffer.h"
20 #include "modules/audio_processing/test/echo_canceller_test_tools.h"
21 #include "modules/audio_processing/utility/cascaded_biquad_filter.h"
22 #include "rtc_base/random.h"
23 #include "rtc_base/strings/string_builder.h"
24 #include "test/gtest.h"
25 
26 namespace webrtc {
27 namespace {
28 
RunSubtractorTest(size_t num_render_channels,size_t num_capture_channels,int num_blocks_to_process,int delay_samples,int refined_filter_length_blocks,int coarse_filter_length_blocks,bool uncorrelated_inputs,const std::vector<int> & blocks_with_echo_path_changes)29 std::vector<float> RunSubtractorTest(
30     size_t num_render_channels,
31     size_t num_capture_channels,
32     int num_blocks_to_process,
33     int delay_samples,
34     int refined_filter_length_blocks,
35     int coarse_filter_length_blocks,
36     bool uncorrelated_inputs,
37     const std::vector<int>& blocks_with_echo_path_changes) {
38   ApmDataDumper data_dumper(42);
39   constexpr int kSampleRateHz = 48000;
40   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
41   EchoCanceller3Config config;
42   config.filter.refined.length_blocks = refined_filter_length_blocks;
43   config.filter.coarse.length_blocks = coarse_filter_length_blocks;
44 
45   Subtractor subtractor(config, num_render_channels, num_capture_channels,
46                         &data_dumper, DetectOptimization());
47   absl::optional<DelayEstimate> delay_estimate;
48   std::vector<std::vector<std::vector<float>>> x(
49       kNumBands, std::vector<std::vector<float>>(
50                      num_render_channels, std::vector<float>(kBlockSize, 0.f)));
51   std::vector<std::vector<float>> y(num_capture_channels,
52                                     std::vector<float>(kBlockSize, 0.f));
53   std::array<float, kBlockSize> x_old;
54   std::vector<SubtractorOutput> output(num_capture_channels);
55   config.delay.default_delay = 1;
56   std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
57       RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels));
58   RenderSignalAnalyzer render_signal_analyzer(config);
59   Random random_generator(42U);
60   Aec3Fft fft;
61   std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels);
62   std::vector<std::array<float, kFftLengthBy2Plus1>> E2_refined(
63       num_capture_channels);
64   std::array<float, kFftLengthBy2Plus1> E2_coarse;
65   AecState aec_state(config, num_capture_channels);
66   x_old.fill(0.f);
67   for (auto& Y2_ch : Y2) {
68     Y2_ch.fill(0.f);
69   }
70   for (auto& E2_refined_ch : E2_refined) {
71     E2_refined_ch.fill(0.f);
72   }
73   E2_coarse.fill(0.f);
74 
75   std::vector<std::vector<std::unique_ptr<DelayBuffer<float>>>> delay_buffer(
76       num_capture_channels);
77   for (size_t capture_ch = 0; capture_ch < num_capture_channels; ++capture_ch) {
78     delay_buffer[capture_ch].resize(num_render_channels);
79     for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) {
80       delay_buffer[capture_ch][render_ch] =
81           std::make_unique<DelayBuffer<float>>(delay_samples);
82     }
83   }
84 
85   // [B,A] = butter(2,100/8000,'high')
86   constexpr CascadedBiQuadFilter::BiQuadCoefficients
87       kHighPassFilterCoefficients = {{0.97261f, -1.94523f, 0.97261f},
88                                      {-1.94448f, 0.94598f}};
89   std::vector<std::unique_ptr<CascadedBiQuadFilter>> x_hp_filter(
90       num_render_channels);
91   for (size_t ch = 0; ch < num_render_channels; ++ch) {
92     x_hp_filter[ch] =
93         std::make_unique<CascadedBiQuadFilter>(kHighPassFilterCoefficients, 1);
94   }
95   std::vector<std::unique_ptr<CascadedBiQuadFilter>> y_hp_filter(
96       num_capture_channels);
97   for (size_t ch = 0; ch < num_capture_channels; ++ch) {
98     y_hp_filter[ch] =
99         std::make_unique<CascadedBiQuadFilter>(kHighPassFilterCoefficients, 1);
100   }
101 
102   for (int k = 0; k < num_blocks_to_process; ++k) {
103     for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) {
104       RandomizeSampleVector(&random_generator, x[0][render_ch]);
105     }
106     if (uncorrelated_inputs) {
107       for (size_t capture_ch = 0; capture_ch < num_capture_channels;
108            ++capture_ch) {
109         RandomizeSampleVector(&random_generator, y[capture_ch]);
110       }
111     } else {
112       for (size_t capture_ch = 0; capture_ch < num_capture_channels;
113            ++capture_ch) {
114         for (size_t render_ch = 0; render_ch < num_render_channels;
115              ++render_ch) {
116           std::array<float, kBlockSize> y_channel;
117           delay_buffer[capture_ch][render_ch]->Delay(x[0][render_ch],
118                                                      y_channel);
119           for (size_t k = 0; k < y.size(); ++k) {
120             y[capture_ch][k] += y_channel[k] / num_render_channels;
121           }
122         }
123       }
124     }
125     for (size_t ch = 0; ch < num_render_channels; ++ch) {
126       x_hp_filter[ch]->Process(x[0][ch]);
127     }
128     for (size_t ch = 0; ch < num_capture_channels; ++ch) {
129       y_hp_filter[ch]->Process(y[ch]);
130     }
131 
132     render_delay_buffer->Insert(x);
133     if (k == 0) {
134       render_delay_buffer->Reset();
135     }
136     render_delay_buffer->PrepareCaptureProcessing();
137     render_signal_analyzer.Update(*render_delay_buffer->GetRenderBuffer(),
138                                   aec_state.MinDirectPathFilterDelay());
139 
140     // Handle echo path changes.
141     if (std::find(blocks_with_echo_path_changes.begin(),
142                   blocks_with_echo_path_changes.end(),
143                   k) != blocks_with_echo_path_changes.end()) {
144       subtractor.HandleEchoPathChange(EchoPathVariability(
145           true, EchoPathVariability::DelayAdjustment::kNewDetectedDelay,
146           false));
147     }
148     subtractor.Process(*render_delay_buffer->GetRenderBuffer(), y,
149                        render_signal_analyzer, aec_state, output);
150 
151     aec_state.HandleEchoPathChange(EchoPathVariability(
152         false, EchoPathVariability::DelayAdjustment::kNone, false));
153     aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(),
154                      subtractor.FilterImpulseResponses(),
155                      *render_delay_buffer->GetRenderBuffer(), E2_refined, Y2,
156                      output);
157   }
158 
159   std::vector<float> results(num_capture_channels);
160   for (size_t ch = 0; ch < num_capture_channels; ++ch) {
161     const float output_power = std::inner_product(
162         output[ch].e_refined.begin(), output[ch].e_refined.end(),
163         output[ch].e_refined.begin(), 0.f);
164     const float y_power =
165         std::inner_product(y[ch].begin(), y[ch].end(), y[ch].begin(), 0.f);
166     if (y_power == 0.f) {
167       ADD_FAILURE();
168       results[ch] = -1.f;
169     }
170     results[ch] = output_power / y_power;
171   }
172   return results;
173 }
174 
// Builds a human-readable description of a test configuration, intended for
// SCOPED_TRACE so failing expectations show which parameters were used.
std::string ProduceDebugText(size_t num_render_channels,
                             size_t num_capture_channels,
                             size_t delay,
                             int filter_length_blocks) {
  std::string text = "delay: " + std::to_string(delay) + ", ";
  text += "filter_length_blocks:" + std::to_string(filter_length_blocks);
  text += ", ";
  text += "num_render_channels:" + std::to_string(num_render_channels);
  text += ", ";
  text += "num_capture_channels:" + std::to_string(num_capture_channels);
  return text;
}
186 
187 }  // namespace
188 
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)

// Verifies that the check for a null data dumper works.
TEST(SubtractorDeathTest, NullDataDumper) {
  EXPECT_DEATH(
      Subtractor(EchoCanceller3Config(), 1, 1, nullptr, DetectOptimization()),
      "");
}

// Verifies the check for the capture signal size: Process() must die when the
// capture block is one sample shorter than kBlockSize.
// The suite is named *DeathTest, as googletest recommends for any suite that
// contains death tests.
TEST(SubtractorDeathTest, WrongCaptureSize) {
  ApmDataDumper data_dumper(42);
  EchoCanceller3Config config;
  Subtractor subtractor(config, 1, 1, &data_dumper, DetectOptimization());
  std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
      RenderDelayBuffer::Create(config, 48000, 1));
  RenderSignalAnalyzer render_signal_analyzer(config);
  std::vector<std::vector<float>> y(1, std::vector<float>(kBlockSize - 1, 0.f));
  std::array<SubtractorOutput, 1> output;

  EXPECT_DEATH(
      subtractor.Process(*render_delay_buffer->GetRenderBuffer(), y,
                         render_signal_analyzer, AecState(config, 1), output),
      "");
}

#endif
216 
217 // Verifies that the subtractor is able to converge on correlated data.
TEST(Subtractor, Convergence) {
  // Single render/capture channel, correlated input, no echo path changes.
  const std::vector<int> no_echo_path_changes;
  for (int filter_length_blocks : {12, 20, 30}) {
    for (int delay_samples : {0, 64, 150, 200, 301}) {
      SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks));
      const std::vector<float> power_ratios = RunSubtractorTest(
          /*num_render_channels=*/1, /*num_capture_channels=*/1,
          /*num_blocks_to_process=*/2500, delay_samples,
          /*refined_filter_length_blocks=*/filter_length_blocks,
          /*coarse_filter_length_blocks=*/filter_length_blocks,
          /*uncorrelated_inputs=*/false, no_echo_path_changes);

      // Once converged, the residual must be well below the capture power.
      for (float ratio : power_ratios) {
        EXPECT_GT(0.1f, ratio);
      }
    }
  }
}
233 
234 // Verifies that the subtractor is able to handle the case when the refined
235 // filter is longer than the coarse filter.
TEST(Subtractor, RefinedFilterLongerThanCoarseFilter) {
  constexpr int kRefinedFilterLengthBlocks = 20;
  constexpr int kCoarseFilterLengthBlocks = 15;
  const std::vector<int> no_echo_path_changes;

  const std::vector<float> power_ratios = RunSubtractorTest(
      /*num_render_channels=*/1, /*num_capture_channels=*/1,
      /*num_blocks_to_process=*/400, /*delay_samples=*/64,
      kRefinedFilterLengthBlocks, kCoarseFilterLengthBlocks,
      /*uncorrelated_inputs=*/false, no_echo_path_changes);
  for (float ratio : power_ratios) {
    EXPECT_GT(0.5f, ratio);
  }
}
244 
245 // Verifies that the subtractor is able to handle the case when the coarse
246 // filter is longer than the refined filter.
TEST(Subtractor, CoarseFilterLongerThanRefinedFilter) {
  constexpr int kRefinedFilterLengthBlocks = 15;
  constexpr int kCoarseFilterLengthBlocks = 20;
  const std::vector<int> no_echo_path_changes;

  const std::vector<float> power_ratios = RunSubtractorTest(
      /*num_render_channels=*/1, /*num_capture_channels=*/1,
      /*num_blocks_to_process=*/400, /*delay_samples=*/64,
      kRefinedFilterLengthBlocks, kCoarseFilterLengthBlocks,
      /*uncorrelated_inputs=*/false, no_echo_path_changes);
  for (float ratio : power_ratios) {
    EXPECT_GT(0.5f, ratio);
  }
}
255 
256 // Verifies that the subtractor does not converge on uncorrelated signals.
TEST(Subtractor, NonConvergenceOnUncorrelatedSignals) {
  // Single render/capture channel, uncorrelated input, no echo path changes.
  const std::vector<int> no_echo_path_changes;
  for (int filter_length_blocks : {12, 20, 30}) {
    for (int delay_samples : {0, 64, 150, 200, 301}) {
      SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks));

      const std::vector<float> power_ratios = RunSubtractorTest(
          /*num_render_channels=*/1, /*num_capture_channels=*/1,
          /*num_blocks_to_process=*/3000, delay_samples,
          /*refined_filter_length_blocks=*/filter_length_blocks,
          /*coarse_filter_length_blocks=*/filter_length_blocks,
          /*uncorrelated_inputs=*/true, no_echo_path_changes);

      // Nothing can be cancelled, so the residual keeps about the full
      // capture power.
      for (float ratio : power_ratios) {
        EXPECT_NEAR(1.f, ratio, 0.1);
      }
    }
  }
}
272 
273 class SubtractorMultiChannelUpToEightRender
274     : public ::testing::Test,
275       public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
276 
277 #if defined(NDEBUG)
278 INSTANTIATE_TEST_SUITE_P(NonDebugMultiChannel,
279                          SubtractorMultiChannelUpToEightRender,
280                          ::testing::Combine(::testing::Values(1, 2, 8),
281                                             ::testing::Values(1, 2, 4)));
282 #else
283 INSTANTIATE_TEST_SUITE_P(DebugMultiChannel,
284                          SubtractorMultiChannelUpToEightRender,
285                          ::testing::Combine(::testing::Values(1, 2),
286                                             ::testing::Values(1, 2)));
287 #endif
288 
289 // Verifies that the subtractor is able to converge on correlated data.
TEST_P(SubtractorMultiChannelUpToEightRender, Convergence) {
  const std::tuple<size_t, size_t> channel_counts = GetParam();
  const size_t num_render_channels = std::get<0>(channel_counts);
  const size_t num_capture_channels = std::get<1>(channel_counts);
  // The run length scales with the number of render channels.
  const int num_blocks_to_process =
      static_cast<int>(2500 * num_render_channels);
  const std::vector<int> no_echo_path_changes;

  const std::vector<float> power_ratios = RunSubtractorTest(
      num_render_channels, num_capture_channels, num_blocks_to_process,
      /*delay_samples=*/64, /*refined_filter_length_blocks=*/20,
      /*coarse_filter_length_blocks=*/20, /*uncorrelated_inputs=*/false,
      no_echo_path_changes);
  for (float ratio : power_ratios) {
    EXPECT_GT(0.1f, ratio);
  }
}
304 
305 class SubtractorMultiChannelUpToFourRender
306     : public ::testing::Test,
307       public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
308 
309 #if defined(NDEBUG)
310 INSTANTIATE_TEST_SUITE_P(NonDebugMultiChannel,
311                          SubtractorMultiChannelUpToFourRender,
312                          ::testing::Combine(::testing::Values(1, 2, 4),
313                                             ::testing::Values(1, 2, 4)));
314 #else
315 INSTANTIATE_TEST_SUITE_P(DebugMultiChannel,
316                          SubtractorMultiChannelUpToFourRender,
317                          ::testing::Combine(::testing::Values(1, 2),
318                                             ::testing::Values(1, 2)));
319 #endif
320 
321 // Verifies that the subtractor does not converge on uncorrelated signals.
TEST_P(SubtractorMultiChannelUpToFourRender,
       NonConvergenceOnUncorrelatedSignals) {
  const std::tuple<size_t, size_t> channel_counts = GetParam();
  const size_t num_render_channels = std::get<0>(channel_counts);
  const size_t num_capture_channels = std::get<1>(channel_counts);
  // The run length scales with the number of render channels.
  const int num_blocks_to_process =
      static_cast<int>(5000 * num_render_channels);
  const std::vector<int> no_echo_path_changes;

  const std::vector<float> power_ratios = RunSubtractorTest(
      num_render_channels, num_capture_channels, num_blocks_to_process,
      /*delay_samples=*/64, /*refined_filter_length_blocks=*/20,
      /*coarse_filter_length_blocks=*/20, /*uncorrelated_inputs=*/true,
      no_echo_path_changes);
  for (float ratio : power_ratios) {
    EXPECT_LT(.8f, ratio);
    EXPECT_NEAR(1.f, ratio, 0.25f);
  }
}
337 }  // namespace webrtc
338