• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include <array>
12 #include <string>
13 #include <vector>
14 
15 #include "common_audio/resampler/push_sinc_resampler.h"
16 #include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
17 #include "modules/audio_processing/agc2/rnn_vad/rnn.h"
18 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
19 #include "modules/audio_processing/test/performance_timer.h"
20 #include "rtc_base/checks.h"
21 #include "rtc_base/logging.h"
22 #include "test/gtest.h"
23 #include "third_party/rnnoise/src/rnn_activations.h"
24 #include "third_party/rnnoise/src/rnn_vad_weights.h"
25 
26 namespace webrtc {
27 namespace rnn_vad {
28 namespace test {
29 namespace {
30 
// Number of samples in one 10 ms frame at a 48 kHz sample rate (48000 / 100).
// Used below as the chunk size of the PCM reader and as the decimator input
// frame size.
31 constexpr size_t kFrameSize10ms48kHz = 480;
32 
DumpPerfStats(size_t num_samples,size_t sample_rate,double average_us,double standard_deviation)33 void DumpPerfStats(size_t num_samples,
34                    size_t sample_rate,
35                    double average_us,
36                    double standard_deviation) {
37   float audio_track_length_ms =
38       1e3f * static_cast<float>(num_samples) / static_cast<float>(sample_rate);
39   float average_ms = static_cast<float>(average_us) / 1e3f;
40   float speed = audio_track_length_ms / average_ms;
41   RTC_LOG(LS_INFO) << "track duration (ms): " << audio_track_length_ms;
42   RTC_LOG(LS_INFO) << "average processing time (ms): " << average_ms << " +/- "
43                    << (standard_deviation / 1e3);
44   RTC_LOG(LS_INFO) << "speed: " << speed << "x";
45 }
46 
47 // When the RNN VAD model is updated and the expected output changes, set the
48 // constant below to true in order to write new expected output binary files.
// While true, RnnVadProbabilityWithinTolerance dumps the computed
// probabilities to "new_vad_prob.dat" and CheckWriteComputedOutputIsFalse
// fails, so the flag has to be set back to false before landing.
49 constexpr bool kWriteComputedOutputToFile = false;
50 
51 }  // namespace
52 
53 // Avoids that one forgets to set |kWriteComputedOutputToFile| back to false
54 // when the expected output files are re-exported.
TEST(RnnVadTest,CheckWriteComputedOutputIsFalse)55 TEST(RnnVadTest, CheckWriteComputedOutputIsFalse) {
56   ASSERT_FALSE(kWriteComputedOutputToFile)
57       << "Cannot land if kWriteComputedOutput is true.";
58 }
59 
60 // Checks that the computed VAD probability for a test input sequence sampled at
61 // 48 kHz is within tolerance.
TEST(RnnVadTest,RnnVadProbabilityWithinTolerance)62 TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) {
63   // Init resampler, feature extractor and RNN.
64   PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
65   FeaturesExtractor features_extractor;
66   RnnBasedVad rnn_vad;
67 
68   // Init input samples and expected output readers.
69   auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
70   auto expected_vad_prob_reader = CreateVadProbsReader();
71 
72   // Input length.
73   const size_t num_frames = samples_reader.second;
74   ASSERT_GE(expected_vad_prob_reader.second, num_frames);
75 
76   // Init buffers.
77   std::vector<float> samples_48k(kFrameSize10ms48kHz);
78   std::vector<float> samples_24k(kFrameSize10ms24kHz);
79   std::vector<float> feature_vector(kFeatureVectorSize);
80   std::vector<float> computed_vad_prob(num_frames);
81   std::vector<float> expected_vad_prob(num_frames);
82 
83   // Read expected output.
84   ASSERT_TRUE(expected_vad_prob_reader.first->ReadChunk(expected_vad_prob));
85 
86   // Compute VAD probabilities on the downsampled input.
87   float cumulative_error = 0.f;
88   for (size_t i = 0; i < num_frames; ++i) {
89     samples_reader.first->ReadChunk(samples_48k);
90     decimator.Resample(samples_48k.data(), samples_48k.size(),
91                        samples_24k.data(), samples_24k.size());
92     bool is_silence = features_extractor.CheckSilenceComputeFeatures(
93         {samples_24k.data(), kFrameSize10ms24kHz},
94         {feature_vector.data(), kFeatureVectorSize});
95     computed_vad_prob[i] = rnn_vad.ComputeVadProbability(
96         {feature_vector.data(), kFeatureVectorSize}, is_silence);
97     EXPECT_NEAR(computed_vad_prob[i], expected_vad_prob[i], 1e-3f);
98     cumulative_error += std::abs(computed_vad_prob[i] - expected_vad_prob[i]);
99   }
100   // Check average error.
101   EXPECT_LT(cumulative_error / num_frames, 1e-4f);
102 
103   if (kWriteComputedOutputToFile) {
104     BinaryFileWriter<float> vad_prob_writer("new_vad_prob.dat");
105     vad_prob_writer.WriteChunk(computed_vad_prob);
106   }
107 }
108 
109 // Performance test for the RNN VAD (pre-fetching and downsampling are
110 // excluded). Keep disabled and only enable locally to measure performance as
111 // follows:
112 // - on desktop: run the this unit test adding "--logs";
113 // - on android: run the this unit test adding "--logcat-output-file".
TEST(RnnVadTest,DISABLED_RnnVadPerformance)114 TEST(RnnVadTest, DISABLED_RnnVadPerformance) {
115   // PCM samples reader and buffers.
116   auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
117   const size_t num_frames = samples_reader.second;
118   std::array<float, kFrameSize10ms48kHz> samples;
119   // Pre-fetch and decimate samples.
120   PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
121   std::vector<float> prefetched_decimated_samples;
122   prefetched_decimated_samples.resize(num_frames * kFrameSize10ms24kHz);
123   for (size_t i = 0; i < num_frames; ++i) {
124     samples_reader.first->ReadChunk(samples);
125     decimator.Resample(samples.data(), samples.size(),
126                        &prefetched_decimated_samples[i * kFrameSize10ms24kHz],
127                        kFrameSize10ms24kHz);
128   }
129   // Initialize.
130   FeaturesExtractor features_extractor;
131   std::array<float, kFeatureVectorSize> feature_vector;
132   RnnBasedVad rnn_vad;
133   constexpr size_t number_of_tests = 100;
134   ::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
135   for (size_t k = 0; k < number_of_tests; ++k) {
136     features_extractor.Reset();
137     rnn_vad.Reset();
138     // Process frames.
139     perf_timer.StartTimer();
140     for (size_t i = 0; i < num_frames; ++i) {
141       bool is_silence = features_extractor.CheckSilenceComputeFeatures(
142           {&prefetched_decimated_samples[i * kFrameSize10ms24kHz],
143            kFrameSize10ms24kHz},
144           feature_vector);
145       rnn_vad.ComputeVadProbability(feature_vector, is_silence);
146     }
147     perf_timer.StopTimer();
148     samples_reader.first->SeekBeginning();
149   }
150   DumpPerfStats(num_frames * kFrameSize10ms24kHz, kSampleRate24kHz,
151                 perf_timer.GetDurationAverage(),
152                 perf_timer.GetDurationStandardDeviation());
153 }
154 
155 }  // namespace test
156 }  // namespace rnn_vad
157 }  // namespace webrtc
158