1 /*
2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/aec3/subtractor.h"
12
13 #include <algorithm>
14 #include <memory>
15 #include <numeric>
16 #include <string>
17
18 #include "modules/audio_processing/aec3/aec_state.h"
19 #include "modules/audio_processing/aec3/render_delay_buffer.h"
20 #include "modules/audio_processing/test/echo_canceller_test_tools.h"
21 #include "modules/audio_processing/utility/cascaded_biquad_filter.h"
22 #include "rtc_base/random.h"
23 #include "rtc_base/strings/string_builder.h"
24 #include "test/gtest.h"
25
26 namespace webrtc {
27 namespace {
28
RunSubtractorTest(size_t num_render_channels,size_t num_capture_channels,int num_blocks_to_process,int delay_samples,int refined_filter_length_blocks,int coarse_filter_length_blocks,bool uncorrelated_inputs,const std::vector<int> & blocks_with_echo_path_changes)29 std::vector<float> RunSubtractorTest(
30 size_t num_render_channels,
31 size_t num_capture_channels,
32 int num_blocks_to_process,
33 int delay_samples,
34 int refined_filter_length_blocks,
35 int coarse_filter_length_blocks,
36 bool uncorrelated_inputs,
37 const std::vector<int>& blocks_with_echo_path_changes) {
38 ApmDataDumper data_dumper(42);
39 constexpr int kSampleRateHz = 48000;
40 constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
41 EchoCanceller3Config config;
42 config.filter.refined.length_blocks = refined_filter_length_blocks;
43 config.filter.coarse.length_blocks = coarse_filter_length_blocks;
44
45 Subtractor subtractor(config, num_render_channels, num_capture_channels,
46 &data_dumper, DetectOptimization());
47 absl::optional<DelayEstimate> delay_estimate;
48 std::vector<std::vector<std::vector<float>>> x(
49 kNumBands, std::vector<std::vector<float>>(
50 num_render_channels, std::vector<float>(kBlockSize, 0.f)));
51 std::vector<std::vector<float>> y(num_capture_channels,
52 std::vector<float>(kBlockSize, 0.f));
53 std::array<float, kBlockSize> x_old;
54 std::vector<SubtractorOutput> output(num_capture_channels);
55 config.delay.default_delay = 1;
56 std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
57 RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels));
58 RenderSignalAnalyzer render_signal_analyzer(config);
59 Random random_generator(42U);
60 Aec3Fft fft;
61 std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels);
62 std::vector<std::array<float, kFftLengthBy2Plus1>> E2_refined(
63 num_capture_channels);
64 std::array<float, kFftLengthBy2Plus1> E2_coarse;
65 AecState aec_state(config, num_capture_channels);
66 x_old.fill(0.f);
67 for (auto& Y2_ch : Y2) {
68 Y2_ch.fill(0.f);
69 }
70 for (auto& E2_refined_ch : E2_refined) {
71 E2_refined_ch.fill(0.f);
72 }
73 E2_coarse.fill(0.f);
74
75 std::vector<std::vector<std::unique_ptr<DelayBuffer<float>>>> delay_buffer(
76 num_capture_channels);
77 for (size_t capture_ch = 0; capture_ch < num_capture_channels; ++capture_ch) {
78 delay_buffer[capture_ch].resize(num_render_channels);
79 for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) {
80 delay_buffer[capture_ch][render_ch] =
81 std::make_unique<DelayBuffer<float>>(delay_samples);
82 }
83 }
84
85 // [B,A] = butter(2,100/8000,'high')
86 constexpr CascadedBiQuadFilter::BiQuadCoefficients
87 kHighPassFilterCoefficients = {{0.97261f, -1.94523f, 0.97261f},
88 {-1.94448f, 0.94598f}};
89 std::vector<std::unique_ptr<CascadedBiQuadFilter>> x_hp_filter(
90 num_render_channels);
91 for (size_t ch = 0; ch < num_render_channels; ++ch) {
92 x_hp_filter[ch] =
93 std::make_unique<CascadedBiQuadFilter>(kHighPassFilterCoefficients, 1);
94 }
95 std::vector<std::unique_ptr<CascadedBiQuadFilter>> y_hp_filter(
96 num_capture_channels);
97 for (size_t ch = 0; ch < num_capture_channels; ++ch) {
98 y_hp_filter[ch] =
99 std::make_unique<CascadedBiQuadFilter>(kHighPassFilterCoefficients, 1);
100 }
101
102 for (int k = 0; k < num_blocks_to_process; ++k) {
103 for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) {
104 RandomizeSampleVector(&random_generator, x[0][render_ch]);
105 }
106 if (uncorrelated_inputs) {
107 for (size_t capture_ch = 0; capture_ch < num_capture_channels;
108 ++capture_ch) {
109 RandomizeSampleVector(&random_generator, y[capture_ch]);
110 }
111 } else {
112 for (size_t capture_ch = 0; capture_ch < num_capture_channels;
113 ++capture_ch) {
114 for (size_t render_ch = 0; render_ch < num_render_channels;
115 ++render_ch) {
116 std::array<float, kBlockSize> y_channel;
117 delay_buffer[capture_ch][render_ch]->Delay(x[0][render_ch],
118 y_channel);
119 for (size_t k = 0; k < y.size(); ++k) {
120 y[capture_ch][k] += y_channel[k] / num_render_channels;
121 }
122 }
123 }
124 }
125 for (size_t ch = 0; ch < num_render_channels; ++ch) {
126 x_hp_filter[ch]->Process(x[0][ch]);
127 }
128 for (size_t ch = 0; ch < num_capture_channels; ++ch) {
129 y_hp_filter[ch]->Process(y[ch]);
130 }
131
132 render_delay_buffer->Insert(x);
133 if (k == 0) {
134 render_delay_buffer->Reset();
135 }
136 render_delay_buffer->PrepareCaptureProcessing();
137 render_signal_analyzer.Update(*render_delay_buffer->GetRenderBuffer(),
138 aec_state.MinDirectPathFilterDelay());
139
140 // Handle echo path changes.
141 if (std::find(blocks_with_echo_path_changes.begin(),
142 blocks_with_echo_path_changes.end(),
143 k) != blocks_with_echo_path_changes.end()) {
144 subtractor.HandleEchoPathChange(EchoPathVariability(
145 true, EchoPathVariability::DelayAdjustment::kNewDetectedDelay,
146 false));
147 }
148 subtractor.Process(*render_delay_buffer->GetRenderBuffer(), y,
149 render_signal_analyzer, aec_state, output);
150
151 aec_state.HandleEchoPathChange(EchoPathVariability(
152 false, EchoPathVariability::DelayAdjustment::kNone, false));
153 aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(),
154 subtractor.FilterImpulseResponses(),
155 *render_delay_buffer->GetRenderBuffer(), E2_refined, Y2,
156 output);
157 }
158
159 std::vector<float> results(num_capture_channels);
160 for (size_t ch = 0; ch < num_capture_channels; ++ch) {
161 const float output_power = std::inner_product(
162 output[ch].e_refined.begin(), output[ch].e_refined.end(),
163 output[ch].e_refined.begin(), 0.f);
164 const float y_power =
165 std::inner_product(y[ch].begin(), y[ch].end(), y[ch].begin(), 0.f);
166 if (y_power == 0.f) {
167 ADD_FAILURE();
168 results[ch] = -1.f;
169 }
170 results[ch] = output_power / y_power;
171 }
172 return results;
173 }
174
ProduceDebugText(size_t num_render_channels,size_t num_capture_channels,size_t delay,int filter_length_blocks)175 std::string ProduceDebugText(size_t num_render_channels,
176 size_t num_capture_channels,
177 size_t delay,
178 int filter_length_blocks) {
179 rtc::StringBuilder ss;
180 ss << "delay: " << delay << ", ";
181 ss << "filter_length_blocks:" << filter_length_blocks << ", ";
182 ss << "num_render_channels:" << num_render_channels << ", ";
183 ss << "num_capture_channels:" << num_capture_channels;
184 return ss.Release();
185 }
186
187 } // namespace
188
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)

// Death tests are grouped in a suite whose name ends in "DeathTest", as
// GoogleTest requires, so that they run before all other tests.

// Verifies that the check for a null data dumper works.
TEST(SubtractorDeathTest, NullDataDumper) {
  EXPECT_DEATH(
      Subtractor(EchoCanceller3Config(), 1, 1, nullptr, DetectOptimization()),
      "");
}

// Verifies the check for the capture signal size: the capture block below is
// one sample shorter than kBlockSize.
TEST(SubtractorDeathTest, WrongCaptureSize) {
  ApmDataDumper data_dumper(42);
  EchoCanceller3Config config;
  Subtractor subtractor(config, 1, 1, &data_dumper, DetectOptimization());
  std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
      RenderDelayBuffer::Create(config, 48000, 1));
  RenderSignalAnalyzer render_signal_analyzer(config);
  std::vector<std::vector<float>> y(1, std::vector<float>(kBlockSize - 1, 0.f));
  std::array<SubtractorOutput, 1> output;

  EXPECT_DEATH(
      subtractor.Process(*render_delay_buffer->GetRenderBuffer(), y,
                         render_signal_analyzer, AecState(config, 1), output),
      "");
}

#endif
216
217 // Verifies that the subtractor is able to converge on correlated data.
TEST(Subtractor,Convergence)218 TEST(Subtractor, Convergence) {
219 std::vector<int> blocks_with_echo_path_changes;
220 for (size_t filter_length_blocks : {12, 20, 30}) {
221 for (size_t delay_samples : {0, 64, 150, 200, 301}) {
222 SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks));
223 std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
224 1, 1, 2500, delay_samples, filter_length_blocks, filter_length_blocks,
225 false, blocks_with_echo_path_changes);
226
227 for (float echo_to_nearend_power : echo_to_nearend_powers) {
228 EXPECT_GT(0.1f, echo_to_nearend_power);
229 }
230 }
231 }
232 }
233
234 // Verifies that the subtractor is able to handle the case when the refined
235 // filter is longer than the coarse filter.
TEST(Subtractor,RefinedFilterLongerThanCoarseFilter)236 TEST(Subtractor, RefinedFilterLongerThanCoarseFilter) {
237 std::vector<int> blocks_with_echo_path_changes;
238 std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
239 1, 1, 400, 64, 20, 15, false, blocks_with_echo_path_changes);
240 for (float echo_to_nearend_power : echo_to_nearend_powers) {
241 EXPECT_GT(0.5f, echo_to_nearend_power);
242 }
243 }
244
245 // Verifies that the subtractor is able to handle the case when the coarse
246 // filter is longer than the refined filter.
TEST(Subtractor,CoarseFilterLongerThanRefinedFilter)247 TEST(Subtractor, CoarseFilterLongerThanRefinedFilter) {
248 std::vector<int> blocks_with_echo_path_changes;
249 std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
250 1, 1, 400, 64, 15, 20, false, blocks_with_echo_path_changes);
251 for (float echo_to_nearend_power : echo_to_nearend_powers) {
252 EXPECT_GT(0.5f, echo_to_nearend_power);
253 }
254 }
255
256 // Verifies that the subtractor does not converge on uncorrelated signals.
TEST(Subtractor,NonConvergenceOnUncorrelatedSignals)257 TEST(Subtractor, NonConvergenceOnUncorrelatedSignals) {
258 std::vector<int> blocks_with_echo_path_changes;
259 for (size_t filter_length_blocks : {12, 20, 30}) {
260 for (size_t delay_samples : {0, 64, 150, 200, 301}) {
261 SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks));
262
263 std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
264 1, 1, 3000, delay_samples, filter_length_blocks, filter_length_blocks,
265 true, blocks_with_echo_path_changes);
266 for (float echo_to_nearend_power : echo_to_nearend_powers) {
267 EXPECT_NEAR(1.f, echo_to_nearend_power, 0.1);
268 }
269 }
270 }
271 }
272
273 class SubtractorMultiChannelUpToEightRender
274 : public ::testing::Test,
275 public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
276
277 #if defined(NDEBUG)
278 INSTANTIATE_TEST_SUITE_P(NonDebugMultiChannel,
279 SubtractorMultiChannelUpToEightRender,
280 ::testing::Combine(::testing::Values(1, 2, 8),
281 ::testing::Values(1, 2, 4)));
282 #else
283 INSTANTIATE_TEST_SUITE_P(DebugMultiChannel,
284 SubtractorMultiChannelUpToEightRender,
285 ::testing::Combine(::testing::Values(1, 2),
286 ::testing::Values(1, 2)));
287 #endif
288
289 // Verifies that the subtractor is able to converge on correlated data.
TEST_P(SubtractorMultiChannelUpToEightRender,Convergence)290 TEST_P(SubtractorMultiChannelUpToEightRender, Convergence) {
291 const size_t num_render_channels = std::get<0>(GetParam());
292 const size_t num_capture_channels = std::get<1>(GetParam());
293
294 std::vector<int> blocks_with_echo_path_changes;
295 size_t num_blocks_to_process = 2500 * num_render_channels;
296 std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
297 num_render_channels, num_capture_channels, num_blocks_to_process, 64, 20,
298 20, false, blocks_with_echo_path_changes);
299
300 for (float echo_to_nearend_power : echo_to_nearend_powers) {
301 EXPECT_GT(0.1f, echo_to_nearend_power);
302 }
303 }
304
305 class SubtractorMultiChannelUpToFourRender
306 : public ::testing::Test,
307 public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
308
309 #if defined(NDEBUG)
310 INSTANTIATE_TEST_SUITE_P(NonDebugMultiChannel,
311 SubtractorMultiChannelUpToFourRender,
312 ::testing::Combine(::testing::Values(1, 2, 4),
313 ::testing::Values(1, 2, 4)));
314 #else
315 INSTANTIATE_TEST_SUITE_P(DebugMultiChannel,
316 SubtractorMultiChannelUpToFourRender,
317 ::testing::Combine(::testing::Values(1, 2),
318 ::testing::Values(1, 2)));
319 #endif
320
321 // Verifies that the subtractor does not converge on uncorrelated signals.
TEST_P(SubtractorMultiChannelUpToFourRender,NonConvergenceOnUncorrelatedSignals)322 TEST_P(SubtractorMultiChannelUpToFourRender,
323 NonConvergenceOnUncorrelatedSignals) {
324 const size_t num_render_channels = std::get<0>(GetParam());
325 const size_t num_capture_channels = std::get<1>(GetParam());
326
327 std::vector<int> blocks_with_echo_path_changes;
328 size_t num_blocks_to_process = 5000 * num_render_channels;
329 std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
330 num_render_channels, num_capture_channels, num_blocks_to_process, 64, 20,
331 20, true, blocks_with_echo_path_changes);
332 for (float echo_to_nearend_power : echo_to_nearend_powers) {
333 EXPECT_LT(.8f, echo_to_nearend_power);
334 EXPECT_NEAR(1.f, echo_to_nearend_power, 0.25f);
335 }
336 }
337 } // namespace webrtc
338