1 /* 2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_ 12 #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_ 13 14 #include <stddef.h> 15 #include <sys/types.h> 16 17 #include <array> 18 #include <vector> 19 20 #include "api/array_view.h" 21 #include "api/function_view.h" 22 #include "modules/audio_processing/agc2/rnn_vad/common.h" 23 #include "rtc_base/system/arch.h" 24 25 namespace webrtc { 26 namespace rnn_vad { 27 28 // Maximum number of units for a fully-connected layer. This value is used to 29 // over-allocate space for fully-connected layers output vectors (implemented as 30 // std::array). The value should equal the number of units of the largest 31 // fully-connected layer. 32 constexpr size_t kFullyConnectedLayersMaxUnits = 24; 33 34 // Maximum number of units for a recurrent layer. This value is used to 35 // over-allocate space for recurrent layers state vectors (implemented as 36 // std::array). The value should equal the number of units of the largest 37 // recurrent layer. 38 constexpr size_t kRecurrentLayersMaxUnits = 24; 39 40 // Fully-connected layer. 41 class FullyConnectedLayer { 42 public: 43 FullyConnectedLayer(size_t input_size, 44 size_t output_size, 45 rtc::ArrayView<const int8_t> bias, 46 rtc::ArrayView<const int8_t> weights, 47 rtc::FunctionView<float(float)> activation_function, 48 Optimization optimization); 49 FullyConnectedLayer(const FullyConnectedLayer&) = delete; 50 FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete; 51 ~FullyConnectedLayer(); input_size()52 size_t input_size() const { return input_size_; } output_size()53 size_t output_size() const { return output_size_; } optimization()54 Optimization optimization() const { return optimization_; } 55 rtc::ArrayView<const float> GetOutput() const; 56 // Computes the fully-connected layer output. 57 void ComputeOutput(rtc::ArrayView<const float> input); 58 59 private: 60 const size_t input_size_; 61 const size_t output_size_; 62 const std::vector<float> bias_; 63 const std::vector<float> weights_; 64 rtc::FunctionView<float(float)> activation_function_; 65 // The output vector of a recurrent layer has length equal to |output_size_|. 66 // However, for efficiency, over-allocation is used. 67 std::array<float, kFullyConnectedLayersMaxUnits> output_; 68 const Optimization optimization_; 69 }; 70 71 // Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as 72 // activation functions for the update/reset and output gates respectively. 73 class GatedRecurrentLayer { 74 public: 75 GatedRecurrentLayer(size_t input_size, 76 size_t output_size, 77 rtc::ArrayView<const int8_t> bias, 78 rtc::ArrayView<const int8_t> weights, 79 rtc::ArrayView<const int8_t> recurrent_weights, 80 Optimization optimization); 81 GatedRecurrentLayer(const GatedRecurrentLayer&) = delete; 82 GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete; 83 ~GatedRecurrentLayer(); input_size()84 size_t input_size() const { return input_size_; } output_size()85 size_t output_size() const { return output_size_; } optimization()86 Optimization optimization() const { return optimization_; } 87 rtc::ArrayView<const float> GetOutput() const; 88 void Reset(); 89 // Computes the recurrent layer output and updates the status. 90 void ComputeOutput(rtc::ArrayView<const float> input); 91 92 private: 93 const size_t input_size_; 94 const size_t output_size_; 95 const std::vector<float> bias_; 96 const std::vector<float> weights_; 97 const std::vector<float> recurrent_weights_; 98 // The state vector of a recurrent layer has length equal to |output_size_|. 99 // However, to avoid dynamic allocation, over-allocation is used. 100 std::array<float, kRecurrentLayersMaxUnits> state_; 101 const Optimization optimization_; 102 }; 103 104 // Recurrent network based VAD. 105 class RnnBasedVad { 106 public: 107 RnnBasedVad(); 108 RnnBasedVad(const RnnBasedVad&) = delete; 109 RnnBasedVad& operator=(const RnnBasedVad&) = delete; 110 ~RnnBasedVad(); 111 void Reset(); 112 // Compute and returns the probability of voice (range: [0.0, 1.0]). 113 float ComputeVadProbability( 114 rtc::ArrayView<const float, kFeatureVectorSize> feature_vector, 115 bool is_silence); 116 117 private: 118 FullyConnectedLayer input_layer_; 119 GatedRecurrentLayer hidden_layer_; 120 FullyConnectedLayer output_layer_; 121 }; 122 123 } // namespace rnn_vad 124 } // namespace webrtc 125 126 #endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_ 127