//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <map>
#include <string>
#include <vector>

namespace asr
{
/**
 * @brief Class used to Decode the output of the ASR inference
 *
 * Holds a map from class index to label string and converts a flat buffer of
 * per-row scores into text by taking the highest-scoring index of every row.
 */
class Decoder
{
public:
    /** Map from class index to the label string used for decoding. */
    std::map<int, std::string> m_labels;

    /**
     * @brief Constructor
     * @param[in] labels - map of labels to be used for decoding to text.
     */
    Decoder(std::map<int, std::string>& labels);

    /**
     * @brief Function to decode the output into a text string
     *
     * The buffer is read as consecutive rows of @p rowLength scores. For each
     * complete row the index of the largest score is looked up in m_labels and
     * the first character of that label is collected. The collected characters
     * are then passed through FilterCharacters().
     *
     * @tparam T element type of the output buffer; must be comparable with operator<.
     * @param[in] contextToProcess - the flat output vector to decode. Trailing
     *            elements that do not fill a complete row are ignored.
     * @param[in] rowLength - number of scores per row (defaults to 29). A value
     *            of 0 is treated as "no rows".
     * @return the result of FilterCharacters() applied to the per-row characters.
     * @throws std::out_of_range if the winning index of a row has no entry in m_labels.
     */
    template<typename T>
    std::string DecodeOutput(std::vector<T>& contextToProcess, std::size_t rowLength = 29)
    {
        using Difference = typename std::vector<T>::difference_type;

        // Guard the division: a zero row length yields no rows instead of dividing by zero.
        const std::size_t rowCount = (rowLength == 0) ? 0 : contextToProcess.size() / rowLength;

        std::vector<char> unfilteredText;
        unfilteredText.reserve(rowCount);

        for (std::size_t row = 0; row < rowCount; ++row)
        {
            const auto rowBegin = contextToProcess.begin() + static_cast<Difference>(row * rowLength);
            const auto rowEnd   = rowBegin + static_cast<Difference>(rowLength);

            // Compare the scores in their own type: narrowing them to a 16-bit
            // integer first would truncate floating-point scores and wrap wide
            // integers, which can select the wrong index.
            const auto best    = std::max_element(rowBegin, rowEnd);
            const int maxIndex = static_cast<int>(std::distance(rowBegin, best));

            // Only the first character of the label is used.
            unfilteredText.emplace_back(this->m_labels.at(maxIndex)[0]);
        }

        return FilterCharacters(unfilteredText);
    }

    /**
     * @brief Function to filter out unwanted characters
     * @param[in] unfiltered - the unfiltered output to be processed.
     * @return the filtered text (the filtering rules are defined in the implementation file).
     */
    std::string FilterCharacters(std::vector<char>& unfiltered);
};
} // namespace asr