• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
#pragma once

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>
13 
namespace asr
{
/**
* @brief Class used to decode the output of the ASR inference into text.
*
* The inference output is treated as a flattened sequence of rows, one row per
* time step, each holding one score per output class. For every complete row the
* highest-scoring class index is looked up in m_labels, and the first character of
* that label is collected. The collected characters are then passed through
* FilterCharacters.
*/
    class Decoder
    {
    public:
        /// Default number of scores per row (output classes per time step).
        /// NOTE(review): presumably this matches the label count of the model the
        /// sample targets; confirm against the label map actually in use.
        static constexpr std::size_t k_DefaultRowLength = 29;

        /// Class index -> label text. Only the first character of a label is emitted.
        std::map<int, std::string> m_labels;

        /**
        * @brief Default constructor
        * @param[in] labels - map of labels to be used for decoding to text.
        * NOTE(review): the definition is not in this header; presumably it copies
        * labels into m_labels.
        */
        Decoder(std::map<int, std::string>& labels);

        /**
        * @brief Function to decode the output into a text string.
        *
        * A trailing partial row (fewer than rowLength values) is ignored. Ties go to
        * the lowest class index.
        *
        * @tparam T element type of the raw inference output (e.g. a quantized integer
        *           type or float). Scores are compared in T itself, without narrowing.
        * @param[in] contextToProcess - the flattened output to decode (row-major, rowLength values per row).
        * @param[in] rowLength - number of class scores per row; defaults to k_DefaultRowLength.
        * @return the filtered text produced by FilterCharacters.
        * @throws std::invalid_argument if rowLength is zero.
        * @throws std::out_of_range if a winning class index has no entry in m_labels.
        */
        template<typename T>
        std::string DecodeOutput(const std::vector<T>& contextToProcess,
                                 std::size_t rowLength = k_DefaultRowLength)
        {
            std::vector<char> unfilteredText;
            const std::vector<std::size_t> bestClasses = ArgMaxPerRow(contextToProcess, rowLength);
            unfilteredText.reserve(bestClasses.size());

            for (const std::size_t classIndex : bestClasses)
            {
                // operator[] on a std::string never goes out of bounds at index 0:
                // an empty label yields '\0'.
                unfilteredText.emplace_back(this->m_labels.at(static_cast<int>(classIndex))[0]);
            }

            return FilterCharacters(unfilteredText);
        }

        /**
        * @brief Index of the largest value in every complete row of a flattened matrix.
        *
        * Values are compared in their native type T, so fractional or wide-integer
        * scores are not truncated. When several values tie, the first one wins, as
        * std::max_element does. A trailing partial row is ignored.
        *
        * @param[in] values - row-major scores, rowLength values per row.
        * @param[in] rowLength - number of values per row; must be non-zero.
        * @return one column index per complete row.
        * @throws std::invalid_argument if rowLength is zero.
        */
        template<typename T>
        static std::vector<std::size_t> ArgMaxPerRow(const std::vector<T>& values, std::size_t rowLength)
        {
            if (rowLength == 0)
            {
                // Guard against division by zero when computing the row count.
                throw std::invalid_argument("Decoder: rowLength must be greater than zero");
            }

            const std::size_t rowCount = values.size() / rowLength;
            std::vector<std::size_t> bestIndices;
            bestIndices.reserve(rowCount);

            for (std::size_t row = 0; row < rowCount; ++row)
            {
                // Scan the row in place instead of copying it into a temporary buffer.
                const auto rowBegin = values.begin() + static_cast<std::ptrdiff_t>(row * rowLength);
                const auto rowEnd = rowBegin + static_cast<std::ptrdiff_t>(rowLength);
                const auto best = std::max_element(rowBegin, rowEnd);
                bestIndices.push_back(static_cast<std::size_t>(std::distance(rowBegin, best)));
            }

            return bestIndices;
        }

        /**
        * @brief Function to filter out unwanted characters
        * @param[in] unfiltered - the unfiltered output to be processed.
        * NOTE(review): defined outside this header; the exact filtering rules are not
        * visible here.
        */
        std::string FilterCharacters(std::vector<char>& unfiltered);
    };
} // namespace asr
64