1 /*
2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
#include <array>
#include <cmath>
#include <string>
#include <vector>

#include "common_audio/resampler/push_sinc_resampler.h"
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
#include "modules/audio_processing/test/performance_timer.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "test/gtest.h"
#include "third_party/rnnoise/src/rnn_activations.h"
#include "third_party/rnnoise/src/rnn_vad_weights.h"
25
26 namespace webrtc {
27 namespace rnn_vad {
28 namespace test {
29 namespace {
30
// Number of samples in a 10 ms frame of audio sampled at 48 kHz.
constexpr size_t kFrameSize10ms48kHz = 48000 / 100;
32
DumpPerfStats(size_t num_samples,size_t sample_rate,double average_us,double standard_deviation)33 void DumpPerfStats(size_t num_samples,
34 size_t sample_rate,
35 double average_us,
36 double standard_deviation) {
37 float audio_track_length_ms =
38 1e3f * static_cast<float>(num_samples) / static_cast<float>(sample_rate);
39 float average_ms = static_cast<float>(average_us) / 1e3f;
40 float speed = audio_track_length_ms / average_ms;
41 RTC_LOG(LS_INFO) << "track duration (ms): " << audio_track_length_ms;
42 RTC_LOG(LS_INFO) << "average processing time (ms): " << average_ms << " +/- "
43 << (standard_deviation / 1e3);
44 RTC_LOG(LS_INFO) << "speed: " << speed << "x";
45 }
46
// Set to true only locally, after an RNN VAD model update has changed the
// expected output, so that the newly computed VAD probabilities are written to
// a binary file. Must be set back to false before landing.
constexpr bool kWriteComputedOutputToFile{false};
50
51 } // namespace
52
53 // Avoids that one forgets to set |kWriteComputedOutputToFile| back to false
54 // when the expected output files are re-exported.
TEST(RnnVadTest,CheckWriteComputedOutputIsFalse)55 TEST(RnnVadTest, CheckWriteComputedOutputIsFalse) {
56 ASSERT_FALSE(kWriteComputedOutputToFile)
57 << "Cannot land if kWriteComputedOutput is true.";
58 }
59
60 // Checks that the computed VAD probability for a test input sequence sampled at
61 // 48 kHz is within tolerance.
TEST(RnnVadTest,RnnVadProbabilityWithinTolerance)62 TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) {
63 // Init resampler, feature extractor and RNN.
64 PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
65 FeaturesExtractor features_extractor;
66 RnnBasedVad rnn_vad;
67
68 // Init input samples and expected output readers.
69 auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
70 auto expected_vad_prob_reader = CreateVadProbsReader();
71
72 // Input length.
73 const size_t num_frames = samples_reader.second;
74 ASSERT_GE(expected_vad_prob_reader.second, num_frames);
75
76 // Init buffers.
77 std::vector<float> samples_48k(kFrameSize10ms48kHz);
78 std::vector<float> samples_24k(kFrameSize10ms24kHz);
79 std::vector<float> feature_vector(kFeatureVectorSize);
80 std::vector<float> computed_vad_prob(num_frames);
81 std::vector<float> expected_vad_prob(num_frames);
82
83 // Read expected output.
84 ASSERT_TRUE(expected_vad_prob_reader.first->ReadChunk(expected_vad_prob));
85
86 // Compute VAD probabilities on the downsampled input.
87 float cumulative_error = 0.f;
88 for (size_t i = 0; i < num_frames; ++i) {
89 samples_reader.first->ReadChunk(samples_48k);
90 decimator.Resample(samples_48k.data(), samples_48k.size(),
91 samples_24k.data(), samples_24k.size());
92 bool is_silence = features_extractor.CheckSilenceComputeFeatures(
93 {samples_24k.data(), kFrameSize10ms24kHz},
94 {feature_vector.data(), kFeatureVectorSize});
95 computed_vad_prob[i] = rnn_vad.ComputeVadProbability(
96 {feature_vector.data(), kFeatureVectorSize}, is_silence);
97 EXPECT_NEAR(computed_vad_prob[i], expected_vad_prob[i], 1e-3f);
98 cumulative_error += std::abs(computed_vad_prob[i] - expected_vad_prob[i]);
99 }
100 // Check average error.
101 EXPECT_LT(cumulative_error / num_frames, 1e-4f);
102
103 if (kWriteComputedOutputToFile) {
104 BinaryFileWriter<float> vad_prob_writer("new_vad_prob.dat");
105 vad_prob_writer.WriteChunk(computed_vad_prob);
106 }
107 }
108
109 // Performance test for the RNN VAD (pre-fetching and downsampling are
110 // excluded). Keep disabled and only enable locally to measure performance as
111 // follows:
112 // - on desktop: run the this unit test adding "--logs";
113 // - on android: run the this unit test adding "--logcat-output-file".
TEST(RnnVadTest,DISABLED_RnnVadPerformance)114 TEST(RnnVadTest, DISABLED_RnnVadPerformance) {
115 // PCM samples reader and buffers.
116 auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
117 const size_t num_frames = samples_reader.second;
118 std::array<float, kFrameSize10ms48kHz> samples;
119 // Pre-fetch and decimate samples.
120 PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
121 std::vector<float> prefetched_decimated_samples;
122 prefetched_decimated_samples.resize(num_frames * kFrameSize10ms24kHz);
123 for (size_t i = 0; i < num_frames; ++i) {
124 samples_reader.first->ReadChunk(samples);
125 decimator.Resample(samples.data(), samples.size(),
126 &prefetched_decimated_samples[i * kFrameSize10ms24kHz],
127 kFrameSize10ms24kHz);
128 }
129 // Initialize.
130 FeaturesExtractor features_extractor;
131 std::array<float, kFeatureVectorSize> feature_vector;
132 RnnBasedVad rnn_vad;
133 constexpr size_t number_of_tests = 100;
134 ::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
135 for (size_t k = 0; k < number_of_tests; ++k) {
136 features_extractor.Reset();
137 rnn_vad.Reset();
138 // Process frames.
139 perf_timer.StartTimer();
140 for (size_t i = 0; i < num_frames; ++i) {
141 bool is_silence = features_extractor.CheckSilenceComputeFeatures(
142 {&prefetched_decimated_samples[i * kFrameSize10ms24kHz],
143 kFrameSize10ms24kHz},
144 feature_vector);
145 rnn_vad.ComputeVadProbability(feature_vector, is_silence);
146 }
147 perf_timer.StopTimer();
148 samples_reader.first->SeekBeginning();
149 }
150 DumpPerfStats(num_frames * kFrameSize10ms24kHz, kSampleRate24kHz,
151 perf_timer.GetDurationAverage(),
152 perf_timer.GetDurationStandardDeviation());
153 }
154
155 } // namespace test
156 } // namespace rnn_vad
157 } // namespace webrtc
158