/*
 *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
10 
11 #include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
12 
13 #include <cmath>
14 #include <vector>
15 
16 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
17 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
18 // #include "test/fpe_observer.h"
19 #include "test/gtest.h"
20 
21 namespace webrtc {
22 namespace rnn_vad {
23 namespace test {
24 namespace {
25 
// Returns the quotient of |n| divided by |m| rounded up to the next integer.
// |m| must be non-zero. The result is computed from the truncated quotient and
// the remainder rather than from |n| + |m| - 1, since that sum can wrap around
// for large |n| and yield a result that is too small.
constexpr size_t ceil(size_t n, size_t m) {
  return n / m + (n % m != 0 ? 1 : 0);
}
29 
30 // Number of 10 ms frames required to fill a pitch buffer having size
31 // |kBufSize24kHz|.
32 constexpr size_t kNumTestDataFrames = ceil(kBufSize24kHz, kFrameSize10ms24kHz);
33 // Number of samples for the test data.
34 constexpr size_t kNumTestDataSize = kNumTestDataFrames * kFrameSize10ms24kHz;
35 
36 // Verifies that the pitch in Hz is in the detectable range.
PitchIsValid(float pitch_hz)37 bool PitchIsValid(float pitch_hz) {
38   const size_t pitch_period =
39       static_cast<size_t>(static_cast<float>(kSampleRate24kHz) / pitch_hz);
40   return kInitialMinPitch24kHz <= pitch_period &&
41          pitch_period <= kMaxPitch24kHz;
42 }
43 
CreatePureTone(float amplitude,float freq_hz,rtc::ArrayView<float> dst)44 void CreatePureTone(float amplitude, float freq_hz, rtc::ArrayView<float> dst) {
45   for (size_t i = 0; i < dst.size(); ++i) {
46     dst[i] = amplitude * std::sin(2.f * kPi * freq_hz * i / kSampleRate24kHz);
47   }
48 }
49 
50 // Feeds |features_extractor| with |samples| splitting it in 10 ms frames.
51 // For every frame, the output is written into |feature_vector|. Returns true
52 // if silence is detected in the last frame.
FeedTestData(FeaturesExtractor * features_extractor,rtc::ArrayView<const float> samples,rtc::ArrayView<float,kFeatureVectorSize> feature_vector)53 bool FeedTestData(FeaturesExtractor* features_extractor,
54                   rtc::ArrayView<const float> samples,
55                   rtc::ArrayView<float, kFeatureVectorSize> feature_vector) {
56   // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
57   // FloatingPointExceptionObserver fpe_observer;
58   bool is_silence = true;
59   const size_t num_frames = samples.size() / kFrameSize10ms24kHz;
60   for (size_t i = 0; i < num_frames; ++i) {
61     is_silence = features_extractor->CheckSilenceComputeFeatures(
62         {samples.data() + i * kFrameSize10ms24kHz, kFrameSize10ms24kHz},
63         feature_vector);
64   }
65   return is_silence;
66 }
67 
68 }  // namespace
69 
70 // Extracts the features for two pure tones and verifies that the pitch field
71 // values reflect the known tone frequencies.
TEST(RnnVadTest,FeatureExtractionLowHighPitch)72 TEST(RnnVadTest, FeatureExtractionLowHighPitch) {
73   constexpr float amplitude = 1000.f;
74   constexpr float low_pitch_hz = 150.f;
75   constexpr float high_pitch_hz = 250.f;
76   ASSERT_TRUE(PitchIsValid(low_pitch_hz));
77   ASSERT_TRUE(PitchIsValid(high_pitch_hz));
78 
79   FeaturesExtractor features_extractor;
80   std::vector<float> samples(kNumTestDataSize);
81   std::vector<float> feature_vector(kFeatureVectorSize);
82   ASSERT_EQ(kFeatureVectorSize, feature_vector.size());
83   rtc::ArrayView<float, kFeatureVectorSize> feature_vector_view(
84       feature_vector.data(), kFeatureVectorSize);
85 
86   // Extract the normalized scalar feature that is proportional to the estimated
87   // pitch period.
88   constexpr size_t pitch_feature_index = kFeatureVectorSize - 2;
89   // Low frequency tone - i.e., high period.
90   CreatePureTone(amplitude, low_pitch_hz, samples);
91   ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view));
92   float high_pitch_period = feature_vector_view[pitch_feature_index];
93   // High frequency tone - i.e., low period.
94   features_extractor.Reset();
95   CreatePureTone(amplitude, high_pitch_hz, samples);
96   ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view));
97   float low_pitch_period = feature_vector_view[pitch_feature_index];
98   // Check.
99   EXPECT_LT(low_pitch_period, high_pitch_period);
100 }
101 
102 }  // namespace test
103 }  // namespace rnn_vad
104 }  // namespace webrtc
105