1 // Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2 //
3 // Use of this source code is governed by a BSD-style license
4 // that can be found in the LICENSE file in the root of the source
5 // tree. An additional intellectual property rights grant can be found
6 // in the file PATENTS. All contributing project authors may
7 // be found in the AUTHORS file in the root of the source tree.
8
9 #include "common_audio/vad/include/vad.h"
10
11 #include <array>
12 #include <fstream>
13 #include <memory>
14
15 #include "absl/flags/flag.h"
16 #include "absl/flags/parse.h"
17 #include "common_audio/wav_file.h"
18 #include "rtc_base/logging.h"
19
20 ABSL_FLAG(std::string, i, "", "Input wav file");
21 ABSL_FLAG(std::string, o, "", "VAD output file");
22
23 namespace webrtc {
24 namespace test {
25 namespace {
26
// Duration of one analysed audio frame. The allowed values are 10, 20 or 30 ms.
constexpr uint8_t kAudioFrameLengthMilliseconds = 30;
static_assert(kAudioFrameLengthMilliseconds == 10 ||
                  kAudioFrameLengthMilliseconds == 20 ||
                  kAudioFrameLengthMilliseconds == 30,
              "The frame length must be 10, 20 or 30 ms.");
// Highest input sample rate accepted by the tool.
constexpr int kMaxSampleRate = 48000;
// Number of samples in one frame at `kMaxSampleRate`; used as the capacity of
// the sample buffer.
constexpr size_t kMaxFrameLen =
    kAudioFrameLengthMilliseconds * kMaxSampleRate / 1000;

// Number of per-frame decision bits packed into each output byte.
constexpr uint8_t kBitmaskBuffSize = 8;
static_assert(kBitmaskBuffSize >= 1 && kBitmaskBuffSize <= 8,
              "The decisions of one bitmask must fit into a single byte.");
34
main(int argc,char * argv[])35 int main(int argc, char* argv[]) {
36 absl::ParseCommandLine(argc, argv);
37 const std::string input_file = absl::GetFlag(FLAGS_i);
38 const std::string output_file = absl::GetFlag(FLAGS_o);
39 // Open wav input file and check properties.
40 WavReader wav_reader(input_file);
41 if (wav_reader.num_channels() != 1) {
42 RTC_LOG(LS_ERROR) << "Only mono wav files supported";
43 return 1;
44 }
45 if (wav_reader.sample_rate() > kMaxSampleRate) {
46 RTC_LOG(LS_ERROR) << "Beyond maximum sample rate (" << kMaxSampleRate
47 << ")";
48 return 1;
49 }
50 const size_t audio_frame_length = rtc::CheckedDivExact(
51 kAudioFrameLengthMilliseconds * wav_reader.sample_rate(), 1000);
52 if (audio_frame_length > kMaxFrameLen) {
53 RTC_LOG(LS_ERROR) << "The frame size and/or the sample rate are too large.";
54 return 1;
55 }
56
57 // Create output file and write header.
58 std::ofstream out_file(output_file, std::ofstream::binary);
59 const char audio_frame_length_ms = kAudioFrameLengthMilliseconds;
60 out_file.write(&audio_frame_length_ms, 1); // Header.
61
62 // Run VAD and write decisions.
63 std::unique_ptr<Vad> vad = CreateVad(Vad::Aggressiveness::kVadNormal);
64 std::array<int16_t, kMaxFrameLen> samples;
65 char buff = 0; // Buffer to write one bit per frame.
66 uint8_t next = 0; // Points to the next bit to write in `buff`.
67 while (true) {
68 // Process frame.
69 const auto read_samples =
70 wav_reader.ReadSamples(audio_frame_length, samples.data());
71 if (read_samples < audio_frame_length)
72 break;
73 const auto is_speech = vad->VoiceActivity(
74 samples.data(), audio_frame_length, wav_reader.sample_rate());
75
76 // Write output.
77 buff = is_speech ? buff | (1 << next) : buff & ~(1 << next);
78 if (++next == kBitmaskBuffSize) {
79 out_file.write(&buff, 1); // Flush.
80 buff = 0; // Reset.
81 next = 0;
82 }
83 }
84
85 // Finalize.
86 char extra_bits = 0;
87 if (next > 0) {
88 extra_bits = kBitmaskBuffSize - next;
89 out_file.write(&buff, 1); // Flush.
90 }
91 out_file.write(&extra_bits, 1);
92 out_file.close();
93
94 return 0;
95 }
96
97 } // namespace
98 } // namespace test
99 } // namespace webrtc
100
main(int argc,char * argv[])101 int main(int argc, char* argv[]) {
102 return webrtc::test::main(argc, argv);
103 }
104