1 /*
2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
12
13 #include <memory>
14
15 #include "rtc_base/checks.h"
16 #include "rtc_base/system/arch.h"
17 #include "system_wrappers/include/cpu_features_wrapper.h"
18 #include "test/gtest.h"
19 #include "test/testsupport/file_utils.h"
20
21 namespace webrtc {
22 namespace rnn_vad {
23 namespace test {
24 namespace {
25
26 using ReaderPairType =
27 std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>;
28
29 } // namespace
30
31 using webrtc::test::ResourcePath;
32
ExpectEqualFloatArray(rtc::ArrayView<const float> expected,rtc::ArrayView<const float> computed)33 void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
34 rtc::ArrayView<const float> computed) {
35 ASSERT_EQ(expected.size(), computed.size());
36 for (size_t i = 0; i < expected.size(); ++i) {
37 SCOPED_TRACE(i);
38 EXPECT_FLOAT_EQ(expected[i], computed[i]);
39 }
40 }
41
ExpectNearAbsolute(rtc::ArrayView<const float> expected,rtc::ArrayView<const float> computed,float tolerance)42 void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
43 rtc::ArrayView<const float> computed,
44 float tolerance) {
45 ASSERT_EQ(expected.size(), computed.size());
46 for (size_t i = 0; i < expected.size(); ++i) {
47 SCOPED_TRACE(i);
48 EXPECT_NEAR(expected[i], computed[i], tolerance);
49 }
50 }
51
52 std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const size_t>
CreatePcmSamplesReader(const size_t frame_length)53 CreatePcmSamplesReader(const size_t frame_length) {
54 auto ptr = std::make_unique<BinaryFileReader<int16_t, float>>(
55 test::ResourcePath("audio_processing/agc2/rnn_vad/samples", "pcm"),
56 frame_length);
57 // The last incomplete frame is ignored.
58 return {std::move(ptr), ptr->data_length() / frame_length};
59 }
60
CreatePitchBuffer24kHzReader()61 ReaderPairType CreatePitchBuffer24kHzReader() {
62 constexpr size_t cols = 864;
63 auto ptr = std::make_unique<BinaryFileReader<float>>(
64 ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"), cols);
65 return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), cols)};
66 }
67
CreateLpResidualAndPitchPeriodGainReader()68 ReaderPairType CreateLpResidualAndPitchPeriodGainReader() {
69 constexpr size_t num_lp_residual_coeffs = 864;
70 auto ptr = std::make_unique<BinaryFileReader<float>>(
71 ResourcePath("audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"),
72 num_lp_residual_coeffs);
73 return {std::move(ptr),
74 rtc::CheckedDivExact(ptr->data_length(), 2 + num_lp_residual_coeffs)};
75 }
76
CreateVadProbsReader()77 ReaderPairType CreateVadProbsReader() {
78 auto ptr = std::make_unique<BinaryFileReader<float>>(
79 test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob", "dat"));
80 return {std::move(ptr), ptr->data_length()};
81 }
82
PitchTestData()83 PitchTestData::PitchTestData() {
84 BinaryFileReader<float> test_data_reader(
85 ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"),
86 static_cast<size_t>(1396));
87 test_data_reader.ReadChunk(test_data_);
88 }
89
90 PitchTestData::~PitchTestData() = default;
91
GetPitchBufView() const92 rtc::ArrayView<const float, kBufSize24kHz> PitchTestData::GetPitchBufView()
93 const {
94 return {test_data_.data(), kBufSize24kHz};
95 }
96
97 rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
GetPitchBufSquareEnergiesView() const98 PitchTestData::GetPitchBufSquareEnergiesView() const {
99 return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies};
100 }
101
102 rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
GetPitchBufAutoCorrCoeffsView() const103 PitchTestData::GetPitchBufAutoCorrCoeffsView() const {
104 return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies,
105 kNumPitchBufAutoCorrCoeffs};
106 }
107
IsOptimizationAvailable(Optimization optimization)108 bool IsOptimizationAvailable(Optimization optimization) {
109 switch (optimization) {
110 case Optimization::kSse2:
111 #if defined(WEBRTC_ARCH_X86_FAMILY)
112 return WebRtc_GetCPUInfo(kSSE2) != 0;
113 #else
114 return false;
115 #endif
116 case Optimization::kNeon:
117 #if defined(WEBRTC_HAS_NEON)
118 return true;
119 #else
120 return false;
121 #endif
122 case Optimization::kNone:
123 return true;
124 }
125 }
126
127 } // namespace test
128 } // namespace rnn_vad
129 } // namespace webrtc
130