1 /*
2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
12
13 #include <cmath>
14 #include <vector>
15
16 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
17 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
18 // #include "test/fpe_observer.h"
19 #include "test/gtest.h"
20
21 namespace webrtc {
22 namespace rnn_vad {
23 namespace test {
24 namespace {
25
// Returns |n| divided by |m|, rounded up to the nearest integer.
// |m| must be non-zero.
// The result is computed as the truncated quotient plus one when there is a
// remainder, rather than as (n + m - 1) / m, so that no intermediate sum can
// wrap around when |n| is close to the maximum value of size_t.
constexpr size_t ceil(size_t n, size_t m) {
  return n / m + (n % m != 0 ? 1 : 0);
}
29
30 // Number of 10 ms frames required to fill a pitch buffer having size
31 // |kBufSize24kHz|.
32 constexpr size_t kNumTestDataFrames = ceil(kBufSize24kHz, kFrameSize10ms24kHz);
33 // Number of samples for the test data.
34 constexpr size_t kNumTestDataSize = kNumTestDataFrames * kFrameSize10ms24kHz;
35
36 // Verifies that the pitch in Hz is in the detectable range.
PitchIsValid(float pitch_hz)37 bool PitchIsValid(float pitch_hz) {
38 const size_t pitch_period =
39 static_cast<size_t>(static_cast<float>(kSampleRate24kHz) / pitch_hz);
40 return kInitialMinPitch24kHz <= pitch_period &&
41 pitch_period <= kMaxPitch24kHz;
42 }
43
CreatePureTone(float amplitude,float freq_hz,rtc::ArrayView<float> dst)44 void CreatePureTone(float amplitude, float freq_hz, rtc::ArrayView<float> dst) {
45 for (size_t i = 0; i < dst.size(); ++i) {
46 dst[i] = amplitude * std::sin(2.f * kPi * freq_hz * i / kSampleRate24kHz);
47 }
48 }
49
50 // Feeds |features_extractor| with |samples| splitting it in 10 ms frames.
51 // For every frame, the output is written into |feature_vector|. Returns true
52 // if silence is detected in the last frame.
FeedTestData(FeaturesExtractor * features_extractor,rtc::ArrayView<const float> samples,rtc::ArrayView<float,kFeatureVectorSize> feature_vector)53 bool FeedTestData(FeaturesExtractor* features_extractor,
54 rtc::ArrayView<const float> samples,
55 rtc::ArrayView<float, kFeatureVectorSize> feature_vector) {
56 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
57 // FloatingPointExceptionObserver fpe_observer;
58 bool is_silence = true;
59 const size_t num_frames = samples.size() / kFrameSize10ms24kHz;
60 for (size_t i = 0; i < num_frames; ++i) {
61 is_silence = features_extractor->CheckSilenceComputeFeatures(
62 {samples.data() + i * kFrameSize10ms24kHz, kFrameSize10ms24kHz},
63 feature_vector);
64 }
65 return is_silence;
66 }
67
68 } // namespace
69
70 // Extracts the features for two pure tones and verifies that the pitch field
71 // values reflect the known tone frequencies.
TEST(RnnVadTest,FeatureExtractionLowHighPitch)72 TEST(RnnVadTest, FeatureExtractionLowHighPitch) {
73 constexpr float amplitude = 1000.f;
74 constexpr float low_pitch_hz = 150.f;
75 constexpr float high_pitch_hz = 250.f;
76 ASSERT_TRUE(PitchIsValid(low_pitch_hz));
77 ASSERT_TRUE(PitchIsValid(high_pitch_hz));
78
79 FeaturesExtractor features_extractor;
80 std::vector<float> samples(kNumTestDataSize);
81 std::vector<float> feature_vector(kFeatureVectorSize);
82 ASSERT_EQ(kFeatureVectorSize, feature_vector.size());
83 rtc::ArrayView<float, kFeatureVectorSize> feature_vector_view(
84 feature_vector.data(), kFeatureVectorSize);
85
86 // Extract the normalized scalar feature that is proportional to the estimated
87 // pitch period.
88 constexpr size_t pitch_feature_index = kFeatureVectorSize - 2;
89 // Low frequency tone - i.e., high period.
90 CreatePureTone(amplitude, low_pitch_hz, samples);
91 ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view));
92 float high_pitch_period = feature_vector_view[pitch_feature_index];
93 // High frequency tone - i.e., low period.
94 features_extractor.Reset();
95 CreatePureTone(amplitude, high_pitch_hz, samples);
96 ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view));
97 float low_pitch_period = feature_vector_view[pitch_feature_index];
98 // Check.
99 EXPECT_LT(low_pitch_period, high_pitch_period);
100 }
101
102 } // namespace test
103 } // namespace rnn_vad
104 } // namespace webrtc
105