• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
12 
13 #include <array>
14 #include <tuple>
15 
16 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
17 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
18 // #include "test/fpe_observer.h"
19 #include "test/gtest.h"
20 
21 namespace webrtc {
22 namespace rnn_vad {
23 namespace test {
24 namespace {
25 
26 constexpr int kTestPitchPeriodsLow = 3 * kMinPitch48kHz / 2;
27 constexpr int kTestPitchPeriodsHigh = (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2;
28 
29 constexpr float kTestPitchGainsLow = 0.35f;
30 constexpr float kTestPitchGainsHigh = 0.75f;
31 
32 }  // namespace
33 
34 class ComputePitchGainThresholdTest
35     : public ::testing::Test,
36       public ::testing::WithParamInterface<std::tuple<
37           /*candidate_pitch_period=*/size_t,
38           /*pitch_period_ratio=*/size_t,
39           /*initial_pitch_period=*/size_t,
40           /*initial_pitch_gain=*/float,
41           /*prev_pitch_period=*/size_t,
42           /*prev_pitch_gain=*/float,
43           /*threshold=*/float>> {};
44 
45 // Checks that the computed pitch gain is within tolerance given test input
46 // data.
TEST_P(ComputePitchGainThresholdTest,WithinTolerance)47 TEST_P(ComputePitchGainThresholdTest, WithinTolerance) {
48   const auto params = GetParam();
49   const size_t candidate_pitch_period = std::get<0>(params);
50   const size_t pitch_period_ratio = std::get<1>(params);
51   const size_t initial_pitch_period = std::get<2>(params);
52   const float initial_pitch_gain = std::get<3>(params);
53   const size_t prev_pitch_period = std::get<4>(params);
54   const float prev_pitch_gain = std::get<5>(params);
55   const float threshold = std::get<6>(params);
56   {
57     // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
58     // FloatingPointExceptionObserver fpe_observer;
59     EXPECT_NEAR(
60         threshold,
61         ComputePitchGainThreshold(candidate_pitch_period, pitch_period_ratio,
62                                   initial_pitch_period, initial_pitch_gain,
63                                   prev_pitch_period, prev_pitch_gain),
64         5e-7f);
65   }
66 }
67 
68 INSTANTIATE_TEST_SUITE_P(
69     RnnVadTest,
70     ComputePitchGainThresholdTest,
71     ::testing::Values(
72         std::make_tuple(31, 7, 219, 0.45649201f, 199, 0.604747f, 0.40000001f),
73         std::make_tuple(113,
74                         2,
75                         226,
76                         0.20967799f,
77                         219,
78                         0.40392199f,
79                         0.30000001f),
80         std::make_tuple(63, 2, 126, 0.210788f, 364, 0.098519f, 0.40000001f),
81         std::make_tuple(30, 5, 152, 0.82356697f, 149, 0.55535901f, 0.700032f),
82         std::make_tuple(76, 2, 151, 0.79522997f, 151, 0.82356697f, 0.675946f),
83         std::make_tuple(31, 5, 153, 0.85069299f, 150, 0.79073799f, 0.72308898f),
84         std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f)));
85 
86 // Checks that the frame-wise sliding square energy function produces output
87 // within tolerance given test input data.
TEST(RnnVadTest,ComputeSlidingFrameSquareEnergiesWithinTolerance)88 TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesWithinTolerance) {
89   PitchTestData test_data;
90   std::array<float, kNumPitchBufSquareEnergies> computed_output;
91   {
92     // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
93     // FloatingPointExceptionObserver fpe_observer;
94     ComputeSlidingFrameSquareEnergies(test_data.GetPitchBufView(),
95                                       computed_output);
96   }
97   auto square_energies_view = test_data.GetPitchBufSquareEnergiesView();
98   ExpectNearAbsolute({square_energies_view.data(), square_energies_view.size()},
99                      computed_output, 3e-2f);
100 }
101 
102 // Checks that the estimated pitch period is bit-exact given test input data.
TEST(RnnVadTest,FindBestPitchPeriodsBitExactness)103 TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
104   PitchTestData test_data;
105   std::array<float, kBufSize12kHz> pitch_buf_decimated;
106   Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
107   std::array<size_t, 2> pitch_candidates_inv_lags;
108   {
109     // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
110     // FloatingPointExceptionObserver fpe_observer;
111     auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
112     pitch_candidates_inv_lags =
113         FindBestPitchPeriods({auto_corr_view.data(), auto_corr_view.size()},
114                              pitch_buf_decimated, kMaxPitch12kHz);
115   }
116   EXPECT_EQ(pitch_candidates_inv_lags[0], static_cast<size_t>(140));
117   EXPECT_EQ(pitch_candidates_inv_lags[1], static_cast<size_t>(142));
118 }
119 
120 // Checks that the refined pitch period is bit-exact given test input data.
TEST(RnnVadTest,RefinePitchPeriod48kHzBitExactness)121 TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) {
122   PitchTestData test_data;
123   size_t pitch_inv_lag;
124   {
125     // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
126     // FloatingPointExceptionObserver fpe_observer;
127     const std::array<size_t, 2> pitch_candidates_inv_lags = {280, 284};
128     pitch_inv_lag = RefinePitchPeriod48kHz(test_data.GetPitchBufView(),
129                                            pitch_candidates_inv_lags);
130   }
131   EXPECT_EQ(560u, pitch_inv_lag);
132 }
133 
134 class CheckLowerPitchPeriodsAndComputePitchGainTest
135     : public ::testing::Test,
136       public ::testing::WithParamInterface<std::tuple<
137           /*initial_pitch_period=*/int,
138           /*prev_pitch_period=*/int,
139           /*prev_pitch_gain=*/float,
140           /*expected_pitch_period=*/int,
141           /*expected_pitch_gain=*/float>> {};
142 
143 // Checks that the computed pitch period is bit-exact and that the computed
144 // pitch gain is within tolerance given test input data.
TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest,PeriodBitExactnessGainWithinTolerance)145 TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest,
146        PeriodBitExactnessGainWithinTolerance) {
147   const auto params = GetParam();
148   const int initial_pitch_period = std::get<0>(params);
149   const int prev_pitch_period = std::get<1>(params);
150   const float prev_pitch_gain = std::get<2>(params);
151   const int expected_pitch_period = std::get<3>(params);
152   const float expected_pitch_gain = std::get<4>(params);
153   PitchTestData test_data;
154   {
155     // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
156     // FloatingPointExceptionObserver fpe_observer;
157     const auto computed_output = CheckLowerPitchPeriodsAndComputePitchGain(
158         test_data.GetPitchBufView(), initial_pitch_period,
159         {prev_pitch_period, prev_pitch_gain});
160     EXPECT_EQ(expected_pitch_period, computed_output.period);
161     EXPECT_NEAR(expected_pitch_gain, computed_output.gain, 1e-6f);
162   }
163 }
164 
165 INSTANTIATE_TEST_SUITE_P(
166     RnnVadTest,
167     CheckLowerPitchPeriodsAndComputePitchGainTest,
168     ::testing::Values(std::make_tuple(kTestPitchPeriodsLow,
169                                       kTestPitchPeriodsLow,
170                                       kTestPitchGainsLow,
171                                       91,
172                                       -0.0188608f),
173                       std::make_tuple(kTestPitchPeriodsLow,
174                                       kTestPitchPeriodsLow,
175                                       kTestPitchGainsHigh,
176                                       91,
177                                       -0.0188608f),
178                       std::make_tuple(kTestPitchPeriodsLow,
179                                       kTestPitchPeriodsHigh,
180                                       kTestPitchGainsLow,
181                                       91,
182                                       -0.0188608f),
183                       std::make_tuple(kTestPitchPeriodsLow,
184                                       kTestPitchPeriodsHigh,
185                                       kTestPitchGainsHigh,
186                                       91,
187                                       -0.0188608f),
188                       std::make_tuple(kTestPitchPeriodsHigh,
189                                       kTestPitchPeriodsLow,
190                                       kTestPitchGainsLow,
191                                       475,
192                                       -0.0904344f),
193                       std::make_tuple(kTestPitchPeriodsHigh,
194                                       kTestPitchPeriodsLow,
195                                       kTestPitchGainsHigh,
196                                       475,
197                                       -0.0904344f),
198                       std::make_tuple(kTestPitchPeriodsHigh,
199                                       kTestPitchPeriodsHigh,
200                                       kTestPitchGainsLow,
201                                       475,
202                                       -0.0904344f),
203                       std::make_tuple(kTestPitchPeriodsHigh,
204                                       kTestPitchPeriodsHigh,
205                                       kTestPitchGainsHigh,
206                                       475,
207                                       -0.0904344f)));
208 
209 }  // namespace test
210 }  // namespace rnn_vad
211 }  // namespace webrtc
212