1 /*
2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
12
13 #include <array>
14 #include <tuple>
15
16 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
17 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
18 // #include "test/fpe_observer.h"
19 #include "test/gtest.h"
20
21 namespace webrtc {
22 namespace rnn_vad {
23 namespace test {
24 namespace {
25
26 constexpr int kTestPitchPeriodsLow = 3 * kMinPitch48kHz / 2;
27 constexpr int kTestPitchPeriodsHigh = (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2;
28
29 constexpr float kTestPitchGainsLow = 0.35f;
30 constexpr float kTestPitchGainsHigh = 0.75f;
31
32 } // namespace
33
34 class ComputePitchGainThresholdTest
35 : public ::testing::Test,
36 public ::testing::WithParamInterface<std::tuple<
37 /*candidate_pitch_period=*/size_t,
38 /*pitch_period_ratio=*/size_t,
39 /*initial_pitch_period=*/size_t,
40 /*initial_pitch_gain=*/float,
41 /*prev_pitch_period=*/size_t,
42 /*prev_pitch_gain=*/float,
43 /*threshold=*/float>> {};
44
45 // Checks that the computed pitch gain is within tolerance given test input
46 // data.
TEST_P(ComputePitchGainThresholdTest,WithinTolerance)47 TEST_P(ComputePitchGainThresholdTest, WithinTolerance) {
48 const auto params = GetParam();
49 const size_t candidate_pitch_period = std::get<0>(params);
50 const size_t pitch_period_ratio = std::get<1>(params);
51 const size_t initial_pitch_period = std::get<2>(params);
52 const float initial_pitch_gain = std::get<3>(params);
53 const size_t prev_pitch_period = std::get<4>(params);
54 const float prev_pitch_gain = std::get<5>(params);
55 const float threshold = std::get<6>(params);
56 {
57 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
58 // FloatingPointExceptionObserver fpe_observer;
59 EXPECT_NEAR(
60 threshold,
61 ComputePitchGainThreshold(candidate_pitch_period, pitch_period_ratio,
62 initial_pitch_period, initial_pitch_gain,
63 prev_pitch_period, prev_pitch_gain),
64 5e-7f);
65 }
66 }
67
68 INSTANTIATE_TEST_SUITE_P(
69 RnnVadTest,
70 ComputePitchGainThresholdTest,
71 ::testing::Values(
72 std::make_tuple(31, 7, 219, 0.45649201f, 199, 0.604747f, 0.40000001f),
73 std::make_tuple(113,
74 2,
75 226,
76 0.20967799f,
77 219,
78 0.40392199f,
79 0.30000001f),
80 std::make_tuple(63, 2, 126, 0.210788f, 364, 0.098519f, 0.40000001f),
81 std::make_tuple(30, 5, 152, 0.82356697f, 149, 0.55535901f, 0.700032f),
82 std::make_tuple(76, 2, 151, 0.79522997f, 151, 0.82356697f, 0.675946f),
83 std::make_tuple(31, 5, 153, 0.85069299f, 150, 0.79073799f, 0.72308898f),
84 std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f)));
85
86 // Checks that the frame-wise sliding square energy function produces output
87 // within tolerance given test input data.
TEST(RnnVadTest,ComputeSlidingFrameSquareEnergiesWithinTolerance)88 TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesWithinTolerance) {
89 PitchTestData test_data;
90 std::array<float, kNumPitchBufSquareEnergies> computed_output;
91 {
92 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
93 // FloatingPointExceptionObserver fpe_observer;
94 ComputeSlidingFrameSquareEnergies(test_data.GetPitchBufView(),
95 computed_output);
96 }
97 auto square_energies_view = test_data.GetPitchBufSquareEnergiesView();
98 ExpectNearAbsolute({square_energies_view.data(), square_energies_view.size()},
99 computed_output, 3e-2f);
100 }
101
102 // Checks that the estimated pitch period is bit-exact given test input data.
TEST(RnnVadTest,FindBestPitchPeriodsBitExactness)103 TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
104 PitchTestData test_data;
105 std::array<float, kBufSize12kHz> pitch_buf_decimated;
106 Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
107 std::array<size_t, 2> pitch_candidates_inv_lags;
108 {
109 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
110 // FloatingPointExceptionObserver fpe_observer;
111 auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
112 pitch_candidates_inv_lags =
113 FindBestPitchPeriods({auto_corr_view.data(), auto_corr_view.size()},
114 pitch_buf_decimated, kMaxPitch12kHz);
115 }
116 EXPECT_EQ(pitch_candidates_inv_lags[0], static_cast<size_t>(140));
117 EXPECT_EQ(pitch_candidates_inv_lags[1], static_cast<size_t>(142));
118 }
119
120 // Checks that the refined pitch period is bit-exact given test input data.
TEST(RnnVadTest,RefinePitchPeriod48kHzBitExactness)121 TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) {
122 PitchTestData test_data;
123 size_t pitch_inv_lag;
124 {
125 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
126 // FloatingPointExceptionObserver fpe_observer;
127 const std::array<size_t, 2> pitch_candidates_inv_lags = {280, 284};
128 pitch_inv_lag = RefinePitchPeriod48kHz(test_data.GetPitchBufView(),
129 pitch_candidates_inv_lags);
130 }
131 EXPECT_EQ(560u, pitch_inv_lag);
132 }
133
134 class CheckLowerPitchPeriodsAndComputePitchGainTest
135 : public ::testing::Test,
136 public ::testing::WithParamInterface<std::tuple<
137 /*initial_pitch_period=*/int,
138 /*prev_pitch_period=*/int,
139 /*prev_pitch_gain=*/float,
140 /*expected_pitch_period=*/int,
141 /*expected_pitch_gain=*/float>> {};
142
143 // Checks that the computed pitch period is bit-exact and that the computed
144 // pitch gain is within tolerance given test input data.
TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest,PeriodBitExactnessGainWithinTolerance)145 TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest,
146 PeriodBitExactnessGainWithinTolerance) {
147 const auto params = GetParam();
148 const int initial_pitch_period = std::get<0>(params);
149 const int prev_pitch_period = std::get<1>(params);
150 const float prev_pitch_gain = std::get<2>(params);
151 const int expected_pitch_period = std::get<3>(params);
152 const float expected_pitch_gain = std::get<4>(params);
153 PitchTestData test_data;
154 {
155 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
156 // FloatingPointExceptionObserver fpe_observer;
157 const auto computed_output = CheckLowerPitchPeriodsAndComputePitchGain(
158 test_data.GetPitchBufView(), initial_pitch_period,
159 {prev_pitch_period, prev_pitch_gain});
160 EXPECT_EQ(expected_pitch_period, computed_output.period);
161 EXPECT_NEAR(expected_pitch_gain, computed_output.gain, 1e-6f);
162 }
163 }
164
165 INSTANTIATE_TEST_SUITE_P(
166 RnnVadTest,
167 CheckLowerPitchPeriodsAndComputePitchGainTest,
168 ::testing::Values(std::make_tuple(kTestPitchPeriodsLow,
169 kTestPitchPeriodsLow,
170 kTestPitchGainsLow,
171 91,
172 -0.0188608f),
173 std::make_tuple(kTestPitchPeriodsLow,
174 kTestPitchPeriodsLow,
175 kTestPitchGainsHigh,
176 91,
177 -0.0188608f),
178 std::make_tuple(kTestPitchPeriodsLow,
179 kTestPitchPeriodsHigh,
180 kTestPitchGainsLow,
181 91,
182 -0.0188608f),
183 std::make_tuple(kTestPitchPeriodsLow,
184 kTestPitchPeriodsHigh,
185 kTestPitchGainsHigh,
186 91,
187 -0.0188608f),
188 std::make_tuple(kTestPitchPeriodsHigh,
189 kTestPitchPeriodsLow,
190 kTestPitchGainsLow,
191 475,
192 -0.0904344f),
193 std::make_tuple(kTestPitchPeriodsHigh,
194 kTestPitchPeriodsLow,
195 kTestPitchGainsHigh,
196 475,
197 -0.0904344f),
198 std::make_tuple(kTestPitchPeriodsHigh,
199 kTestPitchPeriodsHigh,
200 kTestPitchGainsLow,
201 475,
202 -0.0904344f),
203 std::make_tuple(kTestPitchPeriodsHigh,
204 kTestPitchPeriodsHigh,
205 kTestPitchGainsHigh,
206 475,
207 -0.0904344f)));
208
209 } // namespace test
210 } // namespace rnn_vad
211 } // namespace webrtc
212