/*
 *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
10 
11 #include "modules/audio_processing/aec3/matched_filter.h"
12 
13 // Defines WEBRTC_ARCH_X86_FAMILY, used below.
14 #include "rtc_base/system/arch.h"
15 
16 #if defined(WEBRTC_ARCH_X86_FAMILY)
17 #include <emmintrin.h>
18 #endif
19 #include <algorithm>
20 #include <string>
21 
22 #include "modules/audio_processing/aec3/aec3_common.h"
23 #include "modules/audio_processing/aec3/decimator.h"
24 #include "modules/audio_processing/aec3/render_delay_buffer.h"
25 #include "modules/audio_processing/logging/apm_data_dumper.h"
26 #include "modules/audio_processing/test/echo_canceller_test_tools.h"
27 #include "rtc_base/random.h"
28 #include "rtc_base/strings/string_builder.h"
29 #include "system_wrappers/include/cpu_features_wrapper.h"
30 #include "test/field_trial.h"
31 #include "test/gtest.h"
32 
33 namespace webrtc {
34 namespace aec3 {
35 namespace {
36 
ProduceDebugText(size_t delay,size_t down_sampling_factor)37 std::string ProduceDebugText(size_t delay, size_t down_sampling_factor) {
38   rtc::StringBuilder ss;
39   ss << "Delay: " << delay;
40   ss << ", Down sampling factor: " << down_sampling_factor;
41   return ss.Release();
42 }
43 
44 constexpr size_t kNumMatchedFilters = 10;
45 constexpr size_t kDownSamplingFactors[] = {2, 4, 8};
46 constexpr size_t kWindowSizeSubBlocks = 32;
47 constexpr size_t kAlignmentShiftSubBlocks = kWindowSizeSubBlocks * 3 / 4;
48 
49 }  // namespace
50 
51 class MatchedFilterTest : public ::testing::TestWithParam<bool> {};
52 
53 #if defined(WEBRTC_HAS_NEON)
54 // Verifies that the optimized methods for NEON are similar to their reference
55 // counterparts.
TEST_P(MatchedFilterTest,TestNeonOptimizations)56 TEST_P(MatchedFilterTest, TestNeonOptimizations) {
57   Random random_generator(42U);
58   constexpr float kSmoothing = 0.7f;
59   const bool kComputeAccumulatederror = GetParam();
60   for (auto down_sampling_factor : kDownSamplingFactors) {
61     const size_t sub_block_size = kBlockSize / down_sampling_factor;
62 
63     std::vector<float> x(2000);
64     RandomizeSampleVector(&random_generator, x);
65     std::vector<float> y(sub_block_size);
66     std::vector<float> h_NEON(512);
67     std::vector<float> h(512);
68     std::vector<float> accumulated_error(512);
69     std::vector<float> accumulated_error_NEON(512);
70     std::vector<float> scratch_memory(512);
71 
72     int x_index = 0;
73     for (int k = 0; k < 1000; ++k) {
74       RandomizeSampleVector(&random_generator, y);
75 
76       bool filters_updated = false;
77       float error_sum = 0.f;
78       bool filters_updated_NEON = false;
79       float error_sum_NEON = 0.f;
80 
81       MatchedFilterCore_NEON(x_index, h.size() * 150.f * 150.f, kSmoothing, x,
82                              y, h_NEON, &filters_updated_NEON, &error_sum_NEON,
83                              kComputeAccumulatederror, accumulated_error_NEON,
84                              scratch_memory);
85 
86       MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y, h,
87                         &filters_updated, &error_sum, kComputeAccumulatederror,
88                         accumulated_error);
89 
90       EXPECT_EQ(filters_updated, filters_updated_NEON);
91       EXPECT_NEAR(error_sum, error_sum_NEON, error_sum / 100000.f);
92 
93       for (size_t j = 0; j < h.size(); ++j) {
94         EXPECT_NEAR(h[j], h_NEON[j], 0.00001f);
95       }
96 
97       if (kComputeAccumulatederror) {
98         for (size_t j = 0; j < accumulated_error.size(); ++j) {
99           float difference =
100               std::abs(accumulated_error[j] - accumulated_error_NEON[j]);
101           float relative_difference = accumulated_error[j] > 0
102                                           ? difference / accumulated_error[j]
103                                           : difference;
104           EXPECT_NEAR(relative_difference, 0.0f, 0.02f);
105         }
106       }
107 
108       x_index = (x_index + sub_block_size) % x.size();
109     }
110   }
111 }
112 #endif
113 
114 #if defined(WEBRTC_ARCH_X86_FAMILY)
115 // Verifies that the optimized methods for SSE2 are bitexact to their reference
116 // counterparts.
TEST_P(MatchedFilterTest,TestSse2Optimizations)117 TEST_P(MatchedFilterTest, TestSse2Optimizations) {
118   const bool kComputeAccumulatederror = GetParam();
119   bool use_sse2 = (GetCPUInfo(kSSE2) != 0);
120   if (use_sse2) {
121     Random random_generator(42U);
122     constexpr float kSmoothing = 0.7f;
123     for (auto down_sampling_factor : kDownSamplingFactors) {
124       const size_t sub_block_size = kBlockSize / down_sampling_factor;
125       std::vector<float> x(2000);
126       RandomizeSampleVector(&random_generator, x);
127       std::vector<float> y(sub_block_size);
128       std::vector<float> h_SSE2(512);
129       std::vector<float> h(512);
130       std::vector<float> accumulated_error(512 / 4);
131       std::vector<float> accumulated_error_SSE2(512 / 4);
132       std::vector<float> scratch_memory(512);
133       int x_index = 0;
134       for (int k = 0; k < 1000; ++k) {
135         RandomizeSampleVector(&random_generator, y);
136 
137         bool filters_updated = false;
138         float error_sum = 0.f;
139         bool filters_updated_SSE2 = false;
140         float error_sum_SSE2 = 0.f;
141 
142         MatchedFilterCore_SSE2(x_index, h.size() * 150.f * 150.f, kSmoothing, x,
143                                y, h_SSE2, &filters_updated_SSE2,
144                                &error_sum_SSE2, kComputeAccumulatederror,
145                                accumulated_error_SSE2, scratch_memory);
146 
147         MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y,
148                           h, &filters_updated, &error_sum,
149                           kComputeAccumulatederror, accumulated_error);
150 
151         EXPECT_EQ(filters_updated, filters_updated_SSE2);
152         EXPECT_NEAR(error_sum, error_sum_SSE2, error_sum / 100000.f);
153 
154         for (size_t j = 0; j < h.size(); ++j) {
155           EXPECT_NEAR(h[j], h_SSE2[j], 0.00001f);
156         }
157 
158         for (size_t j = 0; j < accumulated_error.size(); ++j) {
159           float difference =
160               std::abs(accumulated_error[j] - accumulated_error_SSE2[j]);
161           float relative_difference = accumulated_error[j] > 0
162                                           ? difference / accumulated_error[j]
163                                           : difference;
164           EXPECT_NEAR(relative_difference, 0.0f, 0.00001f);
165         }
166 
167         x_index = (x_index + sub_block_size) % x.size();
168       }
169     }
170   }
171 }
172 
// Verifies that the AVX2-optimized matched filter core reproduces the reference
// MatchedFilterCore output to within tight tolerances.
TEST_P(MatchedFilterTest, TestAvx2Optimizations) {
  const bool compute_accumulated_error = GetParam();
  if (GetCPUInfo(kAVX2) == 0) {
    // The AVX2 path cannot run on this CPU; there is nothing to compare.
    return;
  }

  constexpr float kSmoothing = 0.7f;
  // Difference between the two outputs, relative to the reference value when
  // that value is positive and absolute otherwise.
  const auto relative_difference = [](float reference, float optimized) {
    const float difference = std::abs(reference - optimized);
    return reference > 0 ? difference / reference : difference;
  };

  Random random_generator(42U);
  for (const size_t down_sampling_factor : kDownSamplingFactors) {
    const size_t sub_block_size = kBlockSize / down_sampling_factor;

    std::vector<float> x(2000);
    RandomizeSampleVector(&random_generator, x);
    std::vector<float> y(sub_block_size);

    std::vector<float> h(512);
    std::vector<float> h_avx2(512);
    // Accumulated-error buffers are a quarter of the filter length.
    std::vector<float> accumulated_error(512 / 4);
    std::vector<float> accumulated_error_avx2(512 / 4);
    std::vector<float> scratch_memory(512);
    const float x2_sum_threshold = h.size() * 150.f * 150.f;

    int x_index = 0;
    for (int iteration = 0; iteration < 1000; ++iteration) {
      RandomizeSampleVector(&random_generator, y);

      bool filters_updated_avx2 = false;
      float error_sum_avx2 = 0.f;
      MatchedFilterCore_AVX2(x_index, x2_sum_threshold, kSmoothing, x, y,
                             h_avx2, &filters_updated_avx2, &error_sum_avx2,
                             compute_accumulated_error, accumulated_error_avx2,
                             scratch_memory);

      bool filters_updated = false;
      float error_sum = 0.f;
      MatchedFilterCore(x_index, x2_sum_threshold, kSmoothing, x, y, h,
                        &filters_updated, &error_sum, compute_accumulated_error,
                        accumulated_error);

      EXPECT_EQ(filters_updated, filters_updated_avx2);
      EXPECT_NEAR(error_sum, error_sum_avx2, error_sum / 100000.f);

      for (size_t tap = 0; tap < h.size(); ++tap) {
        EXPECT_NEAR(h[tap], h_avx2[tap], 0.00001f);
      }

      // NOTE(review): only every fourth accumulated-error entry is compared
      // here, whereas the SSE2 test compares all of them; confirm whether
      // this stride is intentional.
      for (size_t j = 0; j < accumulated_error.size(); j += 4) {
        EXPECT_NEAR(
            relative_difference(accumulated_error[j], accumulated_error_avx2[j]),
            0.0f, 0.00001f);
      }

      x_index = (x_index + sub_block_size) % x.size();
    }
  }
}
221 
222 #endif
223 
224 // Verifies that the (optimized) function MaxSquarePeakIndex() produces output
225 // equal to the corresponding std-functions.
TEST(MatchedFilter,MaxSquarePeakIndex)226 TEST(MatchedFilter, MaxSquarePeakIndex) {
227   Random random_generator(42U);
228   constexpr int kMaxLength = 128;
229   constexpr int kNumIterationsPerLength = 256;
230   for (int length = 1; length < kMaxLength; ++length) {
231     std::vector<float> y(length);
232     for (int i = 0; i < kNumIterationsPerLength; ++i) {
233       RandomizeSampleVector(&random_generator, y);
234 
235       size_t lag_from_function = MaxSquarePeakIndex(y);
236       size_t lag_from_std = std::distance(
237           y.begin(),
238           std::max_element(y.begin(), y.end(), [](float a, float b) -> bool {
239             return a * a < b * b;
240           }));
241       EXPECT_EQ(lag_from_function, lag_from_std);
242     }
243   }
244 }
245 
246 // Verifies that the matched filter produces proper lag estimates for
247 // artificially delayed signals.
TEST_P(MatchedFilterTest,LagEstimation)248 TEST_P(MatchedFilterTest, LagEstimation) {
249   const bool kDetectPreEcho = GetParam();
250   Random random_generator(42U);
251   constexpr size_t kNumChannels = 1;
252   constexpr int kSampleRateHz = 48000;
253   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
254 
255   for (auto down_sampling_factor : kDownSamplingFactors) {
256     const size_t sub_block_size = kBlockSize / down_sampling_factor;
257 
258     Block render(kNumBands, kNumChannels);
259     std::vector<std::vector<float>> capture(
260         1, std::vector<float>(kBlockSize, 0.f));
261     ApmDataDumper data_dumper(0);
262     for (size_t delay_samples : {5, 64, 150, 200, 800, 1000}) {
263       SCOPED_TRACE(ProduceDebugText(delay_samples, down_sampling_factor));
264       EchoCanceller3Config config;
265       config.delay.down_sampling_factor = down_sampling_factor;
266       config.delay.num_filters = kNumMatchedFilters;
267       Decimator capture_decimator(down_sampling_factor);
268       DelayBuffer<float> signal_delay_buffer(down_sampling_factor *
269                                              delay_samples);
270       MatchedFilter filter(
271           &data_dumper, DetectOptimization(), sub_block_size,
272           kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks,
273           150, config.delay.delay_estimate_smoothing,
274           config.delay.delay_estimate_smoothing_delay_found,
275           config.delay.delay_candidate_detection_threshold, kDetectPreEcho);
276 
277       std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
278           RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
279 
280       // Analyze the correlation between render and capture.
281       for (size_t k = 0; k < (600 + delay_samples / sub_block_size); ++k) {
282         for (size_t band = 0; band < kNumBands; ++band) {
283           for (size_t channel = 0; channel < kNumChannels; ++channel) {
284             RandomizeSampleVector(&random_generator,
285                                   render.View(band, channel));
286           }
287         }
288         signal_delay_buffer.Delay(render.View(/*band=*/0, /*channel=*/0),
289                                   capture[0]);
290         render_delay_buffer->Insert(render);
291 
292         if (k == 0) {
293           render_delay_buffer->Reset();
294         }
295 
296         render_delay_buffer->PrepareCaptureProcessing();
297         std::array<float, kBlockSize> downsampled_capture_data;
298         rtc::ArrayView<float> downsampled_capture(
299             downsampled_capture_data.data(), sub_block_size);
300         capture_decimator.Decimate(capture[0], downsampled_capture);
301         filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(),
302                       downsampled_capture, /*use_slow_smoothing=*/false);
303       }
304 
305       // Obtain the lag estimates.
306       auto lag_estimate = filter.GetBestLagEstimate();
307       EXPECT_TRUE(lag_estimate.has_value());
308 
309       // Verify that the expected most accurate lag estimate is correct.
310       if (lag_estimate.has_value()) {
311         EXPECT_EQ(delay_samples, lag_estimate->lag);
312         EXPECT_EQ(delay_samples, lag_estimate->pre_echo_lag);
313       }
314     }
315   }
316 }
317 
318 // Test the pre echo estimation.
TEST_P(MatchedFilterTest,PreEchoEstimation)319 TEST_P(MatchedFilterTest, PreEchoEstimation) {
320   const bool kDetectPreEcho = GetParam();
321   Random random_generator(42U);
322   constexpr size_t kNumChannels = 1;
323   constexpr int kSampleRateHz = 48000;
324   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
325 
326   for (auto down_sampling_factor : kDownSamplingFactors) {
327     const size_t sub_block_size = kBlockSize / down_sampling_factor;
328 
329     Block render(kNumBands, kNumChannels);
330     std::vector<std::vector<float>> capture(
331         1, std::vector<float>(kBlockSize, 0.f));
332     std::vector<float> capture_with_pre_echo(kBlockSize, 0.f);
333     ApmDataDumper data_dumper(0);
334     // data_dumper.SetActivated(true);
335     size_t pre_echo_delay_samples = 20e-3 * 16000 / down_sampling_factor;
336     size_t echo_delay_samples = 50e-3 * 16000 / down_sampling_factor;
337     EchoCanceller3Config config;
338     config.delay.down_sampling_factor = down_sampling_factor;
339     config.delay.num_filters = kNumMatchedFilters;
340     Decimator capture_decimator(down_sampling_factor);
341     DelayBuffer<float> signal_echo_delay_buffer(down_sampling_factor *
342                                                 echo_delay_samples);
343     DelayBuffer<float> signal_pre_echo_delay_buffer(down_sampling_factor *
344                                                     pre_echo_delay_samples);
345     MatchedFilter filter(
346         &data_dumper, DetectOptimization(), sub_block_size,
347         kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150,
348         config.delay.delay_estimate_smoothing,
349         config.delay.delay_estimate_smoothing_delay_found,
350         config.delay.delay_candidate_detection_threshold, kDetectPreEcho);
351     std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
352         RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
353     // Analyze the correlation between render and capture.
354     for (size_t k = 0; k < (600 + echo_delay_samples / sub_block_size); ++k) {
355       for (size_t band = 0; band < kNumBands; ++band) {
356         for (size_t channel = 0; channel < kNumChannels; ++channel) {
357           RandomizeSampleVector(&random_generator, render.View(band, channel));
358         }
359       }
360       signal_echo_delay_buffer.Delay(render.View(0, 0), capture[0]);
361       signal_pre_echo_delay_buffer.Delay(render.View(0, 0),
362                                          capture_with_pre_echo);
363       for (size_t k = 0; k < capture[0].size(); ++k) {
364         constexpr float gain_pre_echo = 0.8f;
365         capture[0][k] += gain_pre_echo * capture_with_pre_echo[k];
366       }
367       render_delay_buffer->Insert(render);
368       if (k == 0) {
369         render_delay_buffer->Reset();
370       }
371       render_delay_buffer->PrepareCaptureProcessing();
372       std::array<float, kBlockSize> downsampled_capture_data;
373       rtc::ArrayView<float> downsampled_capture(downsampled_capture_data.data(),
374                                                 sub_block_size);
375       capture_decimator.Decimate(capture[0], downsampled_capture);
376       filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(),
377                     downsampled_capture, /*use_slow_smoothing=*/false);
378     }
379     // Obtain the lag estimates.
380     auto lag_estimate = filter.GetBestLagEstimate();
381     EXPECT_TRUE(lag_estimate.has_value());
382     // Verify that the expected most accurate lag estimate is correct.
383     if (lag_estimate.has_value()) {
384       EXPECT_EQ(echo_delay_samples, lag_estimate->lag);
385       if (kDetectPreEcho) {
386         // The pre echo delay is estimated in a subsampled domain and a larger
387         // error is allowed.
388         EXPECT_NEAR(pre_echo_delay_samples, lag_estimate->pre_echo_lag, 4);
389       } else {
390         // The pre echo delay fallback to the highest mached filter peak when
391         // its detection is disabled.
392         EXPECT_EQ(echo_delay_samples, lag_estimate->pre_echo_lag);
393       }
394     }
395   }
396 }
397 
398 // Verifies that the matched filter does not produce reliable and accurate
399 // estimates for uncorrelated render and capture signals.
TEST_P(MatchedFilterTest,LagNotReliableForUncorrelatedRenderAndCapture)400 TEST_P(MatchedFilterTest, LagNotReliableForUncorrelatedRenderAndCapture) {
401   const bool kDetectPreEcho = GetParam();
402   constexpr size_t kNumChannels = 1;
403   constexpr int kSampleRateHz = 48000;
404   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
405   Random random_generator(42U);
406   for (auto down_sampling_factor : kDownSamplingFactors) {
407     EchoCanceller3Config config;
408     config.delay.down_sampling_factor = down_sampling_factor;
409     config.delay.num_filters = kNumMatchedFilters;
410     const size_t sub_block_size = kBlockSize / down_sampling_factor;
411 
412     Block render(kNumBands, kNumChannels);
413     std::array<float, kBlockSize> capture_data;
414     rtc::ArrayView<float> capture(capture_data.data(), sub_block_size);
415     std::fill(capture.begin(), capture.end(), 0.f);
416     ApmDataDumper data_dumper(0);
417     std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
418         RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
419     MatchedFilter filter(
420         &data_dumper, DetectOptimization(), sub_block_size,
421         kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150,
422         config.delay.delay_estimate_smoothing,
423         config.delay.delay_estimate_smoothing_delay_found,
424         config.delay.delay_candidate_detection_threshold, kDetectPreEcho);
425 
426     // Analyze the correlation between render and capture.
427     for (size_t k = 0; k < 100; ++k) {
428       RandomizeSampleVector(&random_generator,
429                             render.View(/*band=*/0, /*channel=*/0));
430       RandomizeSampleVector(&random_generator, capture);
431       render_delay_buffer->Insert(render);
432       filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(), capture,
433                     false);
434     }
435 
436     // Obtain the best lag estimate and Verify that no lag estimates are
437     // reliable.
438     auto best_lag_estimates = filter.GetBestLagEstimate();
439     EXPECT_FALSE(best_lag_estimates.has_value());
440   }
441 }
442 
443 // Verifies that the matched filter does not produce updated lag estimates for
444 // render signals of low level.
TEST_P(MatchedFilterTest,LagNotUpdatedForLowLevelRender)445 TEST_P(MatchedFilterTest, LagNotUpdatedForLowLevelRender) {
446   const bool kDetectPreEcho = GetParam();
447   Random random_generator(42U);
448   constexpr size_t kNumChannels = 1;
449   constexpr int kSampleRateHz = 48000;
450   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
451 
452   for (auto down_sampling_factor : kDownSamplingFactors) {
453     const size_t sub_block_size = kBlockSize / down_sampling_factor;
454 
455     Block render(kNumBands, kNumChannels);
456     std::vector<std::vector<float>> capture(
457         1, std::vector<float>(kBlockSize, 0.f));
458     ApmDataDumper data_dumper(0);
459     EchoCanceller3Config config;
460     MatchedFilter filter(
461         &data_dumper, DetectOptimization(), sub_block_size,
462         kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150,
463         config.delay.delay_estimate_smoothing,
464         config.delay.delay_estimate_smoothing_delay_found,
465         config.delay.delay_candidate_detection_threshold, kDetectPreEcho);
466     std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
467         RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz,
468                                   kNumChannels));
469     Decimator capture_decimator(down_sampling_factor);
470 
471     // Analyze the correlation between render and capture.
472     for (size_t k = 0; k < 100; ++k) {
473       RandomizeSampleVector(&random_generator, render.View(0, 0));
474       for (auto& render_k : render.View(0, 0)) {
475         render_k *= 149.f / 32767.f;
476       }
477       std::copy(render.begin(0, 0), render.end(0, 0), capture[0].begin());
478       std::array<float, kBlockSize> downsampled_capture_data;
479       rtc::ArrayView<float> downsampled_capture(downsampled_capture_data.data(),
480                                                 sub_block_size);
481       capture_decimator.Decimate(capture[0], downsampled_capture);
482       filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(),
483                     downsampled_capture, false);
484     }
485 
486     // Verify that no lag estimate has been produced.
487     auto lag_estimate = filter.GetBestLagEstimate();
488     EXPECT_FALSE(lag_estimate.has_value());
489   }
490 }
491 
492 INSTANTIATE_TEST_SUITE_P(_, MatchedFilterTest, testing::Values(true, false));
493 
494 #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
495 
496 class MatchedFilterDeathTest : public ::testing::TestWithParam<bool> {};
497 
498 // Verifies the check for non-zero windows size.
TEST_P(MatchedFilterDeathTest,ZeroWindowSize)499 TEST_P(MatchedFilterDeathTest, ZeroWindowSize) {
500   const bool kDetectPreEcho = GetParam();
501   ApmDataDumper data_dumper(0);
502   EchoCanceller3Config config;
503   EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 16, 0, 1, 1,
504                              150, config.delay.delay_estimate_smoothing,
505                              config.delay.delay_estimate_smoothing_delay_found,
506                              config.delay.delay_candidate_detection_threshold,
507                              kDetectPreEcho),
508                "");
509 }
510 
511 // Verifies the check for non-null data dumper.
TEST_P(MatchedFilterDeathTest,NullDataDumper)512 TEST_P(MatchedFilterDeathTest, NullDataDumper) {
513   const bool kDetectPreEcho = GetParam();
514   EchoCanceller3Config config;
515   EXPECT_DEATH(MatchedFilter(nullptr, DetectOptimization(), 16, 1, 1, 1, 150,
516                              config.delay.delay_estimate_smoothing,
517                              config.delay.delay_estimate_smoothing_delay_found,
518                              config.delay.delay_candidate_detection_threshold,
519                              kDetectPreEcho),
520                "");
521 }
522 
523 // Verifies the check for that the sub block size is a multiple of 4.
524 // TODO(peah): Activate the unittest once the required code has been landed.
TEST_P(MatchedFilterDeathTest,DISABLED_BlockSizeMultipleOf4)525 TEST_P(MatchedFilterDeathTest, DISABLED_BlockSizeMultipleOf4) {
526   const bool kDetectPreEcho = GetParam();
527   ApmDataDumper data_dumper(0);
528   EchoCanceller3Config config;
529   EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 15, 1, 1, 1,
530                              150, config.delay.delay_estimate_smoothing,
531                              config.delay.delay_estimate_smoothing_delay_found,
532                              config.delay.delay_candidate_detection_threshold,
533                              kDetectPreEcho),
534                "");
535 }
536 
537 // Verifies the check for that there is an integer number of sub blocks that add
538 // up to a block size.
539 // TODO(peah): Activate the unittest once the required code has been landed.
TEST_P(MatchedFilterDeathTest,DISABLED_SubBlockSizeAddsUpToBlockSize)540 TEST_P(MatchedFilterDeathTest, DISABLED_SubBlockSizeAddsUpToBlockSize) {
541   const bool kDetectPreEcho = GetParam();
542   ApmDataDumper data_dumper(0);
543   EchoCanceller3Config config;
544   EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 12, 1, 1, 1,
545                              150, config.delay.delay_estimate_smoothing,
546                              config.delay.delay_estimate_smoothing_delay_found,
547                              config.delay.delay_candidate_detection_threshold,
548                              kDetectPreEcho),
549                "");
550 }
551 
552 INSTANTIATE_TEST_SUITE_P(_,
553                          MatchedFilterDeathTest,
554                          testing::Values(true, false));
555 
556 #endif
557 
558 }  // namespace aec3
559 
// Verifies that a valid "WebRTC-Aec3PreEchoConfiguration" field trial sets the
// pre-echo threshold and mode reported by MatchedFilter.
TEST(MatchedFilterFieldTrialTest, PreEchoConfigurationTest) {
  const float threshold_in = 0.1f;
  const int mode_in = 2;
  rtc::StringBuilder trial_string;
  trial_string << "WebRTC-Aec3PreEchoConfiguration/threshold:" << threshold_in
               << ",mode:" << mode_in << "/";
  webrtc::test::ScopedFieldTrials field_trials(trial_string.str());

  ApmDataDumper data_dumper(0);
  EchoCanceller3Config config;
  const size_t sub_block_size = kBlockSize / config.delay.down_sampling_factor;
  MatchedFilter matched_filter(
      &data_dumper, DetectOptimization(), sub_block_size,
      aec3::kWindowSizeSubBlocks, aec3::kNumMatchedFilters,
      aec3::kAlignmentShiftSubBlocks,
      config.render_levels.poor_excitation_render_limit,
      config.delay.delay_estimate_smoothing,
      config.delay.delay_estimate_smoothing_delay_found,
      config.delay.delay_candidate_detection_threshold,
      config.delay.detect_pre_echo);

  const auto& pre_echo_config = matched_filter.GetPreEchoConfiguration();
  EXPECT_EQ(pre_echo_config.threshold, threshold_in);
  EXPECT_EQ(pre_echo_config.mode, mode_in);
}
584 
// Verifies that an out-of-range "WebRTC-Aec3PreEchoConfiguration" field trial
// (negative threshold, unknown mode) is ignored and the defaults are kept.
TEST(MatchedFilterFieldTrialTest, WrongPreEchoConfigurationTest) {
  constexpr float kDefaultThreshold = 0.5f;
  constexpr int kDefaultMode = 0;
  const float threshold_in = -0.1f;
  const int mode_in = 5;
  rtc::StringBuilder trial_string;
  trial_string << "WebRTC-Aec3PreEchoConfiguration/threshold:" << threshold_in
               << ",mode:" << mode_in << "/";
  webrtc::test::ScopedFieldTrials field_trials(trial_string.str());

  ApmDataDumper data_dumper(0);
  EchoCanceller3Config config;
  const size_t sub_block_size = kBlockSize / config.delay.down_sampling_factor;
  MatchedFilter matched_filter(
      &data_dumper, DetectOptimization(), sub_block_size,
      aec3::kWindowSizeSubBlocks, aec3::kNumMatchedFilters,
      aec3::kAlignmentShiftSubBlocks,
      config.render_levels.poor_excitation_render_limit,
      config.delay.delay_estimate_smoothing,
      config.delay.delay_estimate_smoothing_delay_found,
      config.delay.delay_candidate_detection_threshold,
      config.delay.detect_pre_echo);

  const auto& pre_echo_config = matched_filter.GetPreEchoConfiguration();
  EXPECT_EQ(pre_echo_config.threshold, kDefaultThreshold);
  EXPECT_EQ(pre_echo_config.mode, kDefaultMode);
}
611 
612 }  // namespace webrtc
613