1 // 2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "ArmnnNetworkExecutor.hpp" 9 #include "Decoder.hpp" 10 #include "MFCC.hpp" 11 #include "Wav2LetterPreprocessor.hpp" 12 13 namespace asr 14 { 15 /** 16 * Generic Speech Recognition pipeline with 3 steps: data pre-processing, inference execution and inference 17 * result post-processing. 18 * 19 */ 20 class ASRPipeline 21 { 22 public: 23 24 /** 25 * Creates speech recognition pipeline with given network executor and decoder. 26 * @param executor - unique pointer to inference runner 27 * @param decoder - unique pointer to inference results decoder 28 */ 29 ASRPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> executor, 30 std::unique_ptr<Decoder> decoder, std::unique_ptr<Wav2LetterPreprocessor> preprocessor); 31 32 /** 33 * @brief Standard audio pre-processing implementation. 34 * 35 * Preprocesses and prepares the data for inference by 36 * extracting the MFCC features. 37 38 * @param[in] audio - the raw audio data 39 * @param[out] preprocessor - the preprocessor object, which handles the data preparation 40 */ 41 std::vector<int8_t> PreProcessing(std::vector<float>& audio); 42 43 int getInputSamplesSize(); 44 int getSlidingWindowOffset(); 45 46 // Exposing hardcoded constant as it can only be derived from model knowledge and not from model itself 47 // Will need to be refactored so that hard coded values are not defined outside of model settings 48 int SLIDING_WINDOW_OFFSET; 49 50 /** 51 * @brief Executes inference 52 * 53 * Calls inference runner provided during instance construction. 54 * 55 * @param[in] preprocessedData - input inference data. Data type should be aligned with input tensor. 56 * @param[out] result - raw inference results. 57 */ 58 template<typename T> Inference(const std::vector<T> & preprocessedData,common::InferenceResults<int8_t> & result)59 void Inference(const std::vector<T>& preprocessedData, common::InferenceResults<int8_t>& result) 60 { 61 size_t data_bytes = sizeof(T) * preprocessedData.size(); 62 m_executor->Run(preprocessedData.data(), data_bytes, result); 63 } 64 65 /** 66 * @brief Standard inference results post-processing implementation. 67 * 68 * Decodes inference results using decoder provided during construction. 69 * 70 * @param[in] inferenceResult - inference results to be decoded. 71 * @param[in] isFirstWindow - for checking if this is the first window of the sliding window. 72 * @param[in] isLastWindow - for checking if this is the last window of the sliding window. 73 * @param[in] currentRContext - the right context of the output text. To be output if it is the last window. 74 */ 75 template<typename T> PostProcessing(common::InferenceResults<int8_t> & inferenceResult,bool & isFirstWindow,bool isLastWindow,std::string currentRContext)76 void PostProcessing(common::InferenceResults<int8_t>& inferenceResult, 77 bool& isFirstWindow, 78 bool isLastWindow, 79 std::string currentRContext) 80 { 81 int rowLength = 29; 82 int middleContextStart = 49; 83 int middleContextEnd = 99; 84 int leftContextStart = 0; 85 int rightContextStart = 100; 86 int rightContextEnd = 148; 87 88 std::vector<T> contextToProcess; 89 90 // If isFirstWindow we keep the left context of the output 91 if (isFirstWindow) 92 { 93 std::vector<T> chunk(&inferenceResult[0][leftContextStart], 94 &inferenceResult[0][middleContextEnd * rowLength]); 95 contextToProcess = chunk; 96 } 97 else 98 { 99 // Else we only keep the middle context of the output 100 std::vector<T> chunk(&inferenceResult[0][middleContextStart * rowLength], 101 &inferenceResult[0][middleContextEnd * rowLength]); 102 contextToProcess = chunk; 103 } 104 std::string output = this->m_decoder->DecodeOutput<T>(contextToProcess); 105 isFirstWindow = false; 106 std::cout << output << std::flush; 107 108 // If this is the last window, we print the right context of the output 109 if (isLastWindow) 110 { 111 std::vector<T> rContext(&inferenceResult[0][rightContextStart * rowLength], 112 &inferenceResult[0][rightContextEnd * rowLength]); 113 currentRContext = this->m_decoder->DecodeOutput(rContext); 114 std::cout << currentRContext << std::endl; 115 } 116 } 117 118 protected: 119 std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> m_executor; 120 std::unique_ptr<Decoder> m_decoder; 121 std::unique_ptr<Wav2LetterPreprocessor> m_preProcessor; 122 }; 123 124 using IPipelinePtr = std::unique_ptr<asr::ASRPipeline>; 125 126 /** 127 * Constructs speech recognition pipeline based on configuration provided. 128 * 129 * @param[in] config - speech recognition pipeline configuration. 130 * @param[in] labels - asr labels 131 * 132 * @return unique pointer to asr pipeline. 133 */ 134 IPipelinePtr CreatePipeline(common::PipelineOptions& config, std::map<int, std::string>& labels); 135 136 } // namespace asr