• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/signal_dependent_erle_estimator.h"
12 
13 #include <algorithm>
14 #include <iostream>
15 #include <string>
16 
17 #include "api/audio/echo_canceller3_config.h"
18 #include "modules/audio_processing/aec3/render_buffer.h"
19 #include "modules/audio_processing/aec3/render_delay_buffer.h"
20 #include "rtc_base/strings/string_builder.h"
21 #include "test/gtest.h"
22 
23 namespace webrtc {
24 
25 namespace {
26 
GetActiveFrame(std::vector<std::vector<std::vector<float>>> * x)27 void GetActiveFrame(std::vector<std::vector<std::vector<float>>>* x) {
28   const std::array<float, kBlockSize> frame = {
29       7459.88, 17209.6, 17383,   20768.9, 16816.7, 18386.3, 4492.83, 9675.85,
30       6665.52, 14808.6, 9342.3,  7483.28, 19261.7, 4145.98, 1622.18, 13475.2,
31       7166.32, 6856.61, 21937,   7263.14, 9569.07, 14919,   8413.32, 7551.89,
32       7848.65, 6011.27, 13080.6, 15865.2, 12656,   17459.6, 4263.93, 4503.03,
33       9311.79, 21095.8, 12657.9, 13906.6, 19267.2, 11338.1, 16828.9, 11501.6,
34       11405,   15031.4, 14541.6, 19765.5, 18346.3, 19350.2, 3157.47, 18095.8,
35       1743.68, 21328.2, 19727.5, 7295.16, 10332.4, 11055.5, 20107.4, 14708.4,
36       12416.2, 16434,   2454.69, 9840.8,  6867.23, 1615.75, 6059.9,  8394.19};
37   for (size_t band = 0; band < x->size(); ++band) {
38     for (size_t channel = 0; channel < (*x)[band].size(); ++channel) {
39       RTC_DCHECK_GE((*x)[band][channel].size(), frame.size());
40       std::copy(frame.begin(), frame.end(), (*x)[band][channel].begin());
41     }
42   }
43 }
44 
45 class TestInputs {
46  public:
47   TestInputs(const EchoCanceller3Config& cfg,
48              size_t num_render_channels,
49              size_t num_capture_channels);
50   ~TestInputs();
GetRenderBuffer()51   const RenderBuffer& GetRenderBuffer() { return *render_buffer_; }
GetX2()52   rtc::ArrayView<const float, kFftLengthBy2Plus1> GetX2() { return X2_; }
GetY2() const53   rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> GetY2() const {
54     return Y2_;
55   }
GetE2() const56   rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> GetE2() const {
57     return E2_;
58   }
59   rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
GetH2() const60   GetH2() const {
61     return H2_;
62   }
GetConvergedFilters() const63   const std::vector<bool>& GetConvergedFilters() const {
64     return converged_filters_;
65   }
66   void Update();
67 
68  private:
69   void UpdateCurrentPowerSpectra();
70   int n_ = 0;
71   std::unique_ptr<RenderDelayBuffer> render_delay_buffer_;
72   RenderBuffer* render_buffer_;
73   std::array<float, kFftLengthBy2Plus1> X2_;
74   std::vector<std::array<float, kFftLengthBy2Plus1>> Y2_;
75   std::vector<std::array<float, kFftLengthBy2Plus1>> E2_;
76   std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> H2_;
77   std::vector<std::vector<std::vector<float>>> x_;
78   std::vector<bool> converged_filters_;
79 };
80 
TestInputs(const EchoCanceller3Config & cfg,size_t num_render_channels,size_t num_capture_channels)81 TestInputs::TestInputs(const EchoCanceller3Config& cfg,
82                        size_t num_render_channels,
83                        size_t num_capture_channels)
84     : render_delay_buffer_(
85           RenderDelayBuffer::Create(cfg, 16000, num_render_channels)),
86       Y2_(num_capture_channels),
87       E2_(num_capture_channels),
88       H2_(num_capture_channels,
89           std::vector<std::array<float, kFftLengthBy2Plus1>>(
90               cfg.filter.refined.length_blocks)),
91       x_(1,
92          std::vector<std::vector<float>>(num_render_channels,
93                                          std::vector<float>(kBlockSize, 0.f))),
94       converged_filters_(num_capture_channels, true) {
95   render_delay_buffer_->AlignFromDelay(4);
96   render_buffer_ = render_delay_buffer_->GetRenderBuffer();
97   for (auto& H2_ch : H2_) {
98     for (auto& H2_p : H2_ch) {
99       H2_p.fill(0.f);
100     }
101   }
102   for (auto& H2_p : H2_[0]) {
103     H2_p.fill(1.f);
104   }
105 }
106 
107 TestInputs::~TestInputs() = default;
108 
Update()109 void TestInputs::Update() {
110   if (n_ % 2 == 0) {
111     std::fill(x_[0][0].begin(), x_[0][0].end(), 0.f);
112   } else {
113     GetActiveFrame(&x_);
114   }
115 
116   render_delay_buffer_->Insert(x_);
117   render_delay_buffer_->PrepareCaptureProcessing();
118   UpdateCurrentPowerSpectra();
119   ++n_;
120 }
121 
UpdateCurrentPowerSpectra()122 void TestInputs::UpdateCurrentPowerSpectra() {
123   const SpectrumBuffer& spectrum_render_buffer =
124       render_buffer_->GetSpectrumBuffer();
125   size_t idx = render_buffer_->Position();
126   size_t prev_idx = spectrum_render_buffer.OffsetIndex(idx, 1);
127   auto& X2 = spectrum_render_buffer.buffer[idx][/*channel=*/0];
128   auto& X2_prev = spectrum_render_buffer.buffer[prev_idx][/*channel=*/0];
129   std::copy(X2.begin(), X2.end(), X2_.begin());
130   for (size_t ch = 0; ch < Y2_.size(); ++ch) {
131     RTC_DCHECK_EQ(X2.size(), Y2_[ch].size());
132     for (size_t k = 0; k < X2.size(); ++k) {
133       E2_[ch][k] = 0.01f * X2_prev[k];
134       Y2_[ch][k] = X2[k] + E2_[ch][k];
135     }
136   }
137 }
138 
139 }  // namespace
140 
141 class SignalDependentErleEstimatorMultiChannel
142     : public ::testing::Test,
143       public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
144 
145 INSTANTIATE_TEST_SUITE_P(MultiChannel,
146                          SignalDependentErleEstimatorMultiChannel,
147                          ::testing::Combine(::testing::Values(1, 2, 4),
148                                             ::testing::Values(1, 2, 4)));
149 
// Smoke test: for every configuration in a sweep over refined filter length,
// delay headroom and number of ERLE sections that EchoCanceller3Config::
// Validate() accepts, run the estimator for a few blocks.
TEST_P(SignalDependentErleEstimatorMultiChannel, SweepSettings) {
  const size_t num_render_channels = std::get<0>(GetParam());
  const size_t num_capture_channels = std::get<1>(GetParam());
  constexpr size_t kMaxLengthBlocks = 50;
  constexpr size_t kNumUpdates = 10;
  EchoCanceller3Config cfg;
  for (size_t blocks = 1; blocks < kMaxLengthBlocks; blocks += 10) {
    for (size_t delay_headroom = 0; delay_headroom < 5; ++delay_headroom) {
      for (size_t num_sections = 2; num_sections < kMaxLengthBlocks;
           ++num_sections) {
        // `cfg` is reused across iterations; this assignment never increases
        // refined_initial.length_blocks.
        cfg.filter.refined.length_blocks = blocks;
        cfg.filter.refined_initial.length_blocks =
            std::min(cfg.filter.refined_initial.length_blocks, blocks);
        cfg.delay.delay_headroom_samples = delay_headroom * kBlockSize;
        cfg.erle.num_sections = num_sections;
        if (!EchoCanceller3Config::Validate(&cfg)) {
          continue;
        }
        SignalDependentErleEstimator estimator(cfg, num_capture_channels);
        std::vector<std::array<float, kFftLengthBy2Plus1>> average_erle(
            num_capture_channels);
        for (auto& erle_ch : average_erle) {
          erle_ch.fill(cfg.erle.max_l);
        }
        TestInputs inputs(cfg, num_render_channels, num_capture_channels);
        for (size_t update = 0; update < kNumUpdates; ++update) {
          inputs.Update();
          estimator.Update(inputs.GetRenderBuffer(), inputs.GetH2(),
                           inputs.GetX2(), inputs.GetY2(), inputs.GetE2(),
                           average_erle, inputs.GetConvergedFilters());
        }
      }
    }
  }
}
183 
// Smoke test: runs the estimator for 200 blocks on a small, explicitly chosen
// configuration (2-block refined filter, 2 ERLE sections, no delay headroom).
TEST_P(SignalDependentErleEstimatorMultiChannel, LongerRun) {
  const size_t num_render_channels = std::get<0>(GetParam());
  const size_t num_capture_channels = std::get<1>(GetParam());
  EchoCanceller3Config cfg;
  cfg.filter.refined.length_blocks = 2;
  cfg.filter.refined_initial.length_blocks = 1;
  cfg.delay.delay_headroom_samples = 0;
  cfg.delay.hysteresis_limit_blocks = 0;
  cfg.erle.num_sections = 2;
  // Validate() takes `cfg` by pointer and may adjust it. Stop here instead of
  // silently exercising a different, auto-corrected configuration.
  ASSERT_TRUE(EchoCanceller3Config::Validate(&cfg));
  std::vector<std::array<float, kFftLengthBy2Plus1>> average_erle(
      num_capture_channels);
  for (auto& e : average_erle) {
    e.fill(cfg.erle.max_l);
  }
  SignalDependentErleEstimator s(cfg, num_capture_channels);
  TestInputs inputs(cfg, num_render_channels, num_capture_channels);
  for (size_t n = 0; n < 200; ++n) {
    inputs.Update();
    s.Update(inputs.GetRenderBuffer(), inputs.GetH2(), inputs.GetX2(),
             inputs.GetY2(), inputs.GetE2(), average_erle,
             inputs.GetConvergedFilters());
  }
}
208 
209 }  // namespace webrtc
210