• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <vector>
16 
17 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/experimental/kernels/ctc_beam_search.h"
20 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23 #include "tensorflow/lite/kernels/op_macros.h"
24 
25 namespace tflite {
26 namespace ops {
27 namespace experimental {
28 namespace ctc_beam_search_decoder {
29 
// Positions of this op's input tensors within the node's input list.
constexpr int kInputsTensor{0};          // rank-3 float32 scores fed to the decoder
constexpr int kSequenceLengthTensor{1};  // rank-1 int32 per-batch lengths
32 
// Op attributes, parsed once from the flexbuffer options by Init() and read
// back from node->user_data by Prepare() and Eval().
struct CTCBeamSearchDecoderParams {
  int beam_width;       // beam width handed to the beam search decoder
  int top_paths;        // number of decoded paths requested per batch item
  bool merge_repeated;  // forwarded to the decoder and to TopPaths()
};
38 
Init(TfLiteContext * context,const char * buffer,size_t length)39 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
40   TFLITE_CHECK(buffer != nullptr);
41   const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
42   const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
43 
44   CTCBeamSearchDecoderParams* option = new CTCBeamSearchDecoderParams;
45   option->beam_width = m["beam_width"].AsInt32();
46   option->top_paths = m["top_paths"].AsInt32();
47   option->merge_repeated = m["merge_repeated"].AsBool();
48 
49   return option;
50 }
51 
Free(TfLiteContext * context,void * buffer)52 void Free(TfLiteContext* context, void* buffer) {
53   delete reinterpret_cast<CTCBeamSearchDecoderParams*>(buffer);
54 }
55 
Prepare(TfLiteContext * context,TfLiteNode * node)56 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
57   const CTCBeamSearchDecoderParams* option =
58       reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);
59   const int top_paths = option->top_paths;
60   TF_LITE_ENSURE(context, option->beam_width >= top_paths);
61   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
62   // The outputs should be top_paths * 3 + 1.
63   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3 * top_paths + 1);
64 
65   const TfLiteTensor* inputs;
66   TF_LITE_ENSURE_OK(context,
67                     GetInputSafe(context, node, kInputsTensor, &inputs));
68   TF_LITE_ENSURE_EQ(context, NumDimensions(inputs), 3);
69   // TensorFlow only supports float.
70   TF_LITE_ENSURE_EQ(context, inputs->type, kTfLiteFloat32);
71   const int batch_size = SizeOfDimension(inputs, 1);
72 
73   const TfLiteTensor* sequence_length;
74   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSequenceLengthTensor,
75                                           &sequence_length));
76   TF_LITE_ENSURE_EQ(context, NumDimensions(sequence_length), 1);
77   TF_LITE_ENSURE_EQ(context, NumElements(sequence_length), batch_size);
78   // TensorFlow only supports int32.
79   TF_LITE_ENSURE_EQ(context, sequence_length->type, kTfLiteInt32);
80 
81   // Resize decoded outputs.
82   // Do not resize indices & values cause we don't know the values yet.
83   for (int i = 0; i < top_paths; ++i) {
84     TfLiteTensor* indices;
85     TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &indices));
86     SetTensorToDynamic(indices);
87     TfLiteTensor* values;
88     TF_LITE_ENSURE_OK(context,
89                       GetOutputSafe(context, node, i + top_paths, &values));
90     SetTensorToDynamic(values);
91     TfLiteTensor* output_shape;
92     TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i + 2 * top_paths,
93                                              &output_shape));
94     SetTensorToDynamic(output_shape);
95   }
96 
97   // Resize log probability outputs.
98   TfLiteTensor* log_probability_output;
99   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, top_paths * 3,
100                                            &log_probability_output));
101   TfLiteIntArray* log_probability_output_shape_array = TfLiteIntArrayCreate(2);
102   log_probability_output_shape_array->data[0] = batch_size;
103   log_probability_output_shape_array->data[1] = top_paths;
104   return context->ResizeTensor(context, log_probability_output,
105                                log_probability_output_shape_array);
106 }
107 
Resize(TfLiteContext * context,std::initializer_list<int32_t> output_shape,TfLiteTensor * output)108 TfLiteStatus Resize(TfLiteContext* context,
109                     std::initializer_list<int32_t> output_shape,
110                     TfLiteTensor* output) {
111   const int dimensions = output_shape.size();
112   TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(dimensions);
113   int i = 0;
114   for (const int v : output_shape) {
115     output_shape_array->data[i++] = v;
116   }
117   return context->ResizeTensor(context, output, output_shape_array);
118 }
119 
StoreAllDecodedSequences(TfLiteContext * context,const std::vector<std::vector<std::vector<int>>> & sequences,TfLiteNode * node,int top_paths)120 TfLiteStatus StoreAllDecodedSequences(
121     TfLiteContext* context,
122     const std::vector<std::vector<std::vector<int>>>& sequences,
123     TfLiteNode* node, int top_paths) {
124   const int32_t batch_size = sequences.size();
125   std::vector<int32_t> num_entries(top_paths, 0);
126 
127   // Calculate num_entries per path
128   for (const auto& batch_s : sequences) {
129     TF_LITE_ENSURE_EQ(context, batch_s.size(), top_paths);
130     for (int p = 0; p < top_paths; ++p) {
131       num_entries[p] += batch_s[p].size();
132     }
133   }
134 
135   for (int p = 0; p < top_paths; ++p) {
136     const int32_t p_num = num_entries[p];
137 
138     // Resize the decoded outputs.
139     TfLiteTensor* indices;
140     TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, p, &indices));
141     TF_LITE_ENSURE_OK(context, Resize(context, {p_num, 2}, indices));
142 
143     TfLiteTensor* values;
144     TF_LITE_ENSURE_OK(context,
145                       GetOutputSafe(context, node, p + top_paths, &values));
146     TF_LITE_ENSURE_OK(context, Resize(context, {p_num}, values));
147 
148     TfLiteTensor* decoded_shape;
149     TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, p + 2 * top_paths,
150                                              &decoded_shape));
151     TF_LITE_ENSURE_OK(context, Resize(context, {2}, decoded_shape));
152 
153     int32_t max_decoded = 0;
154     int32_t offset = 0;
155 
156     int32_t* indices_data = GetTensorData<int32_t>(indices);
157     int32_t* values_data = GetTensorData<int32_t>(values);
158     int32_t* decoded_shape_data = GetTensorData<int32_t>(decoded_shape);
159     for (int b = 0; b < batch_size; ++b) {
160       auto& p_batch = sequences[b][p];
161       int32_t num_decoded = p_batch.size();
162       max_decoded = std::max(max_decoded, num_decoded);
163 
164       std::copy_n(p_batch.begin(), num_decoded, values_data + offset);
165       for (int32_t t = 0; t < num_decoded; ++t, ++offset) {
166         indices_data[offset * 2] = b;
167         indices_data[offset * 2 + 1] = t;
168       }
169     }
170 
171     decoded_shape_data[0] = batch_size;
172     decoded_shape_data[1] = max_decoded;
173   }
174   return kTfLiteOk;
175 }
176 
Eval(TfLiteContext * context,TfLiteNode * node)177 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
178   const TfLiteTensor* inputs;
179   TF_LITE_ENSURE_OK(context,
180                     GetInputSafe(context, node, kInputsTensor, &inputs));
181   const TfLiteTensor* sequence_length;
182   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSequenceLengthTensor,
183                                           &sequence_length));
184   const CTCBeamSearchDecoderParams* option =
185       reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);
186 
187   const int max_time = SizeOfDimension(inputs, 0);
188   const int batch_size = SizeOfDimension(inputs, 1);
189   const int num_classes = SizeOfDimension(inputs, 2);
190 
191   const int beam_width = option->beam_width;
192   const int top_paths = option->top_paths;
193   const bool merge_repeated = option->merge_repeated;
194 
195   // Validate sequence length is less or equal than max time.
196   for (int i = 0; i < batch_size; ++i) {
197     TF_LITE_ENSURE(context,
198                    max_time >= GetTensorData<int32_t>(sequence_length)[i]);
199   }
200 
201   // The following logic is implemented like
202   // tensorflow/core/kernels/ctc_decoder_ops.cc
203   std::vector<optimized_ops::TTypes<float>::UnalignedConstMatrix> input_list_t;
204 
205   for (std::size_t t = 0; t < max_time; ++t) {
206     input_list_t.emplace_back(
207         GetTensorData<float>(inputs) + t * batch_size * num_classes, batch_size,
208         num_classes);
209   }
210 
211   ::tflite::experimental::ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer
212       beam_scorer;
213   ::tflite::experimental::ctc::CTCBeamSearchDecoder<> beam_search(
214       num_classes, beam_width, &beam_scorer, 1 /* batch_size */,
215       merge_repeated);
216 
217   // Allocate temporary memory for holding chip operation data.
218   float* input_chip_t_data =
219       static_cast<float*>(malloc(num_classes * sizeof(float)));
220   Eigen::array<Eigen::DenseIndex, 1> dims;
221   dims[0] = num_classes;
222   optimized_ops::TTypes<float>::Flat input_chip_t(input_chip_t_data, dims);
223 
224   std::vector<std::vector<std::vector<int>>> best_paths(batch_size);
225   std::vector<float> log_probs;
226 
227   TfLiteTensor* log_probabilities;
228   TF_LITE_ENSURE_OK(
229       context, GetOutputSafe(context, node, 3 * top_paths, &log_probabilities));
230   float* log_probabilities_output = GetTensorData<float>(log_probabilities);
231 
232   // Assumption: the blank index is num_classes - 1
233   for (int b = 0; b < batch_size; ++b) {
234     auto& best_paths_b = best_paths[b];
235     best_paths_b.resize(top_paths);
236     for (int t = 0; t < GetTensorData<int32_t>(sequence_length)[b]; ++t) {
237       input_chip_t = input_list_t[t].chip(b, 0);
238       auto input_bi =
239           Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
240       beam_search.Step(input_bi);
241     }
242     TF_LITE_ENSURE(context, beam_search.TopPaths(top_paths, &best_paths_b,
243                                                  &log_probs, merge_repeated));
244     beam_search.Reset();
245 
246     // Fill in log_probabilities output.
247     for (int bp = 0; bp < top_paths; ++bp) {
248       log_probabilities_output[b * top_paths + bp] = log_probs[bp];
249     }
250   }
251 
252   free(input_chip_t_data);
253   return StoreAllDecodedSequences(context, best_paths, node, top_paths);
254 }
255 
256 }  // namespace ctc_beam_search_decoder
257 
Register_CTC_BEAM_SEARCH_DECODER()258 TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER() {
259   static TfLiteRegistration r = {
260       ctc_beam_search_decoder::Init, ctc_beam_search_decoder::Free,
261       ctc_beam_search_decoder::Prepare, ctc_beam_search_decoder::Eval};
262   return &r;
263 }
264 
265 }  // namespace experimental
266 }  // namespace ops
267 }  // namespace tflite
268