• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
10 
11 #include "modules/audio_processing/aec3/suppression_gain.h"
12 
13 #include "modules/audio_processing/aec3/aec_state.h"
14 #include "modules/audio_processing/aec3/render_delay_buffer.h"
15 #include "modules/audio_processing/aec3/subtractor.h"
16 #include "modules/audio_processing/aec3/subtractor_output.h"
17 #include "modules/audio_processing/logging/apm_data_dumper.h"
18 #include "rtc_base/checks.h"
19 #include "system_wrappers/include/cpu_features_wrapper.h"
20 #include "test/gtest.h"
21 
22 namespace webrtc {
23 namespace aec3 {
24 
25 #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
26 
27 // Verifies that the check for non-null output gains works.
TEST(SuppressionGainDeathTest,NullOutputGains)28 TEST(SuppressionGainDeathTest, NullOutputGains) {
29   std::vector<std::array<float, kFftLengthBy2Plus1>> E2(1, {0.f});
30   std::vector<std::array<float, kFftLengthBy2Plus1>> R2(1, {0.f});
31   std::vector<std::array<float, kFftLengthBy2Plus1>> S2(1);
32   std::vector<std::array<float, kFftLengthBy2Plus1>> N2(1, {0.f});
33   for (auto& S2_k : S2) {
34     S2_k.fill(.1f);
35   }
36   FftData E;
37   FftData Y;
38   E.re.fill(0.f);
39   E.im.fill(0.f);
40   Y.re.fill(0.f);
41   Y.im.fill(0.f);
42 
43   float high_bands_gain;
44   AecState aec_state(EchoCanceller3Config{}, 1);
45   EXPECT_DEATH(
46       SuppressionGain(EchoCanceller3Config{}, DetectOptimization(), 16000, 1)
47           .GetGain(E2, S2, R2, N2,
48                    RenderSignalAnalyzer((EchoCanceller3Config{})), aec_state,
49                    std::vector<std::vector<std::vector<float>>>(
50                        3, std::vector<std::vector<float>>(
51                               1, std::vector<float>(kBlockSize, 0.f))),
52                    &high_bands_gain, nullptr),
53       "");
54 }
55 
56 #endif
57 
58 // Does a sanity check that the gains are correctly computed.
TEST(SuppressionGain,BasicGainComputation)59 TEST(SuppressionGain, BasicGainComputation) {
60   constexpr size_t kNumRenderChannels = 1;
61   constexpr size_t kNumCaptureChannels = 2;
62   constexpr int kSampleRateHz = 16000;
63   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
64   SuppressionGain suppression_gain(EchoCanceller3Config(), DetectOptimization(),
65                                    kSampleRateHz, kNumCaptureChannels);
66   RenderSignalAnalyzer analyzer(EchoCanceller3Config{});
67   float high_bands_gain;
68   std::vector<std::array<float, kFftLengthBy2Plus1>> E2(kNumCaptureChannels);
69   std::vector<std::array<float, kFftLengthBy2Plus1>> S2(kNumCaptureChannels,
70                                                         {0.f});
71   std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(kNumCaptureChannels);
72   std::vector<std::array<float, kFftLengthBy2Plus1>> R2(kNumCaptureChannels);
73   std::vector<std::array<float, kFftLengthBy2Plus1>> N2(kNumCaptureChannels);
74   std::array<float, kFftLengthBy2Plus1> g;
75   std::vector<SubtractorOutput> output(kNumCaptureChannels);
76   std::vector<std::vector<std::vector<float>>> x(
77       kNumBands, std::vector<std::vector<float>>(
78                      kNumRenderChannels, std::vector<float>(kBlockSize, 0.f)));
79   EchoCanceller3Config config;
80   AecState aec_state(config, kNumCaptureChannels);
81   ApmDataDumper data_dumper(42);
82   Subtractor subtractor(config, kNumRenderChannels, kNumCaptureChannels,
83                         &data_dumper, DetectOptimization());
84   std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
85       RenderDelayBuffer::Create(config, kSampleRateHz, kNumRenderChannels));
86   absl::optional<DelayEstimate> delay_estimate;
87 
88   // Ensure that a strong noise is detected to mask any echoes.
89   for (size_t ch = 0; ch < kNumCaptureChannels; ++ch) {
90     E2[ch].fill(10.f);
91     Y2[ch].fill(10.f);
92     R2[ch].fill(.1f);
93     N2[ch].fill(100.f);
94   }
95   for (auto& subtractor_output : output) {
96     subtractor_output.Reset();
97   }
98 
99   // Ensure that the gain is no longer forced to zero.
100   for (int k = 0; k <= kNumBlocksPerSecond / 5 + 1; ++k) {
101     aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(),
102                      subtractor.FilterImpulseResponses(),
103                      *render_delay_buffer->GetRenderBuffer(), E2, Y2, output);
104   }
105 
106   for (int k = 0; k < 100; ++k) {
107     aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(),
108                      subtractor.FilterImpulseResponses(),
109                      *render_delay_buffer->GetRenderBuffer(), E2, Y2, output);
110     suppression_gain.GetGain(E2, S2, R2, N2, analyzer, aec_state, x,
111                              &high_bands_gain, &g);
112   }
113   std::for_each(g.begin(), g.end(),
114                 [](float a) { EXPECT_NEAR(1.f, a, 0.001); });
115 
116   // Ensure that a strong nearend is detected to mask any echoes.
117   for (size_t ch = 0; ch < kNumCaptureChannels; ++ch) {
118     E2[ch].fill(100.f);
119     Y2[ch].fill(100.f);
120     R2[ch].fill(0.1f);
121     S2[ch].fill(0.1f);
122     N2[ch].fill(0.f);
123   }
124 
125   for (int k = 0; k < 100; ++k) {
126     aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(),
127                      subtractor.FilterImpulseResponses(),
128                      *render_delay_buffer->GetRenderBuffer(), E2, Y2, output);
129     suppression_gain.GetGain(E2, S2, R2, N2, analyzer, aec_state, x,
130                              &high_bands_gain, &g);
131   }
132   std::for_each(g.begin(), g.end(),
133                 [](float a) { EXPECT_NEAR(1.f, a, 0.001); });
134 
135   // Add a strong echo to one of the channels and ensure that it is suppressed.
136   E2[1].fill(1000000000.f);
137   R2[1].fill(10000000000000.f);
138 
139   for (int k = 0; k < 10; ++k) {
140     suppression_gain.GetGain(E2, S2, R2, N2, analyzer, aec_state, x,
141                              &high_bands_gain, &g);
142   }
143   std::for_each(g.begin(), g.end(),
144                 [](float a) { EXPECT_NEAR(0.f, a, 0.001); });
145 }
146 
147 }  // namespace aec3
148 }  // namespace webrtc
149