//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ArmnnNetworkExecutor.hpp"
#include "Decoder.hpp"
#include "MFCC.hpp"
#include "Wav2LetterPreprocessor.hpp"

namespace asr
{
/**
 * Generic speech recognition pipeline with three steps: data pre-processing,
 * inference execution and inference result post-processing.
 *
 * Instances are typically created via CreatePipeline(); a usage sketch follows
 * its declaration at the end of this header.
 */
class ASRPipeline
{
public:

    /**
     * Creates a speech recognition pipeline with the given network executor,
     * decoder and preprocessor.
     * @param executor - unique pointer to the inference runner
     * @param decoder - unique pointer to the inference results decoder
     * @param preprocessor - unique pointer to the audio preprocessor
     */
    ASRPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> executor,
                std::unique_ptr<Decoder> decoder, std::unique_ptr<Wav2LetterPreprocessor> preprocessor);

    /**
     * @brief Standard audio pre-processing implementation.
     *
     * Preprocesses and prepares the data for inference by
     * extracting the MFCC features.
     *
     * @param[in] audio - the raw audio data
     * @return the extracted MFCC features as int8 data, ready for inference
     */
    std::vector<int8_t> PreProcessing(std::vector<float>& audio);

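    /**
     * @return the number of audio samples expected as input for a single inference.
     */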
    int getInputSamplesSize();
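
    /**
     * @return the offset, in audio samples, between successive sliding windows.
     */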
    int getSlidingWindowOffset();

    // Exposing a hardcoded constant, as it can only be derived from knowledge of the model
    // and not from the model itself. Will need to be refactored so that hardcoded values
    // are not defined outside of the model settings.
    int SLIDING_WINDOW_OFFSET;

    /**
     * @brief Executes inference.
     *
     * Calls the inference runner provided during instance construction.
     *
     * @param[in] preprocessedData - input inference data; its type should match the model's input tensor.
     * @param[out] result - raw inference results.
     */
    template<typename T>
    void Inference(const std::vector<T>& preprocessedData, common::InferenceResults<int8_t>& result)
    {
        size_t data_bytes = sizeof(T) * preprocessedData.size();
        m_executor->Run(preprocessedData.data(), data_bytes, result);
    }

    /**
     * @brief Standard inference results post-processing implementation.
     *
     * Decodes inference results using the decoder provided during construction.
     *
     * @param[in] inferenceResult - inference results to be decoded.
     * @param[in,out] isFirstWindow - whether this is the first window of the sliding window sequence;
     *                                reset to false after the first call.
     * @param[in] isLastWindow - whether this is the last window of the sliding window sequence.
     * @param[in] currentRContext - the right context of the output text, printed if this is the last window.
     */
    template<typename T>
    void PostProcessing(common::InferenceResults<int8_t>& inferenceResult,
                        bool& isFirstWindow,
                        bool isLastWindow,
                        std::string currentRContext)
    {
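        // The raw output is treated as a sequence of rows of class scores, one row
        // per output time step. The rows are partitioned into left, middle and right
        // context regions; the boundaries below are hardcoded for this model's
        // output shape (see the note on SLIDING_WINDOW_OFFSET above).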
        int rowLength = 29;
        int middleContextStart = 49;
        int middleContextEnd = 99;
        int leftContextStart = 0;
        int rightContextStart = 100;
        int rightContextEnd = 148;

        std::vector<T> contextToProcess;

        // If this is the first window, keep the left context of the output
        if (isFirstWindow)
        {
            std::vector<T> chunk(&inferenceResult[0][leftContextStart],
                                 &inferenceResult[0][middleContextEnd * rowLength]);
            contextToProcess = chunk;
        }
        else
        {
            // Otherwise, keep only the middle context of the output
            std::vector<T> chunk(&inferenceResult[0][middleContextStart * rowLength],
                                 &inferenceResult[0][middleContextEnd * rowLength]);
            contextToProcess = chunk;
        }
        std::string output = this->m_decoder->DecodeOutput<T>(contextToProcess);
        isFirstWindow = false;
        std::cout << output << std::flush;

        // If this is the last window, also decode and print the right context of the output
        if (isLastWindow)
        {
            std::vector<T> rContext(&inferenceResult[0][rightContextStart * rowLength],
                                    &inferenceResult[0][rightContextEnd * rowLength]);
            currentRContext = this->m_decoder->DecodeOutput(rContext);
            std::cout << currentRContext << std::endl;
        }
    }

protected:
    std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> m_executor;
    std::unique_ptr<Decoder> m_decoder;
    std::unique_ptr<Wav2LetterPreprocessor> m_preProcessor;
};

using IPipelinePtr = std::unique_ptr<asr::ASRPipeline>;

/**
 * Constructs a speech recognition pipeline based on the configuration provided.
 *
 * @param[in] config - speech recognition pipeline configuration.
 * @param[in] labels - ASR labels.
 *
 * @return unique pointer to an ASR pipeline.
 */
IPipelinePtr CreatePipeline(common::PipelineOptions& config, std::map<int, std::string>& labels);
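
// A minimal usage sketch. This is illustrative only: how the options, labels
// and audio buffer are populated is an assumption, not part of this header.
//
//     common::PipelineOptions options;    // filled in from application config
//     std::map<int, std::string> labels;  // ASR character labels
//     asr::IPipelinePtr pipeline = asr::CreatePipeline(options, labels);
//
//     // Feed the pipeline one sliding window of audio at a time.
//     std::vector<float> audioBlock(pipeline->getInputSamplesSize());
//     std::vector<int8_t> input = pipeline->PreProcessing(audioBlock);
//
//     common::InferenceResults<int8_t> results;
//     pipeline->Inference<int8_t>(input, results);
//
//     bool isFirstWindow = true;
//     std::string rightContext;
//     pipeline->PostProcessing<int8_t>(results, isFirstWindow,
//                                      /*isLastWindow=*/false, rightContext);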

} // namespace asr