/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <initializer_list>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/ctc/ctc_beam_search.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace custom {
namespace ctc_beam_search_decoder {

constexpr int kInputsTensor = 0;
constexpr int kSequenceLengthTensor = 1;

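// Op options, parsed from the custom-op flexbuffer in Init() below. When
// merge_repeated is true, consecutive repeated classes in a decoded path are
// collapsed into one.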
typedef struct {
  int beam_width;
  int top_paths;
  bool merge_repeated;
} CTCBeamSearchDecoderParams;

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_CHECK(buffer != nullptr);
  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();

  CTCBeamSearchDecoderParams* option = new CTCBeamSearchDecoderParams;
  option->beam_width = m["beam_width"].AsInt32();
  option->top_paths = m["top_paths"].AsInt32();
  option->merge_repeated = m["merge_repeated"].AsBool();

  return option;
}
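
// A minimal sketch (not part of the original file) of how the options read
// above could be serialized with the FlexBuffers builder API; the key names
// must match the ones Init() looks up:
//
//   flexbuffers::Builder fbb;
//   fbb.Map([&]() {
//     fbb.Int("beam_width", 10);
//     fbb.Int("top_paths", 1);
//     fbb.Bool("merge_repeated", false);
//   });
//   fbb.Finish();
//   // fbb.GetBuffer() holds the bytes that reach Init() as `buffer`.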

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<CTCBeamSearchDecoderParams*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const CTCBeamSearchDecoderParams* option =
      reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);
  const int top_paths = option->top_paths;
  TF_LITE_ENSURE(context, option->beam_width >= top_paths);
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  // The number of outputs should be top_paths * 3 + 1.
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3 * top_paths + 1);
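  // Output layout: outputs [0, top_paths) hold each path's sparse indices,
  // [top_paths, 2 * top_paths) its sparse values, [2 * top_paths,
  // 3 * top_paths) its dense shape, and output 3 * top_paths the per-path
  // log probabilities.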

  const TfLiteTensor* inputs;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputsTensor, &inputs));
  TF_LITE_ENSURE_EQ(context, NumDimensions(inputs), 3);
  // TensorFlow only supports float.
  TF_LITE_ENSURE_EQ(context, inputs->type, kTfLiteFloat32);
  const int batch_size = SizeOfDimension(inputs, 1);

  const TfLiteTensor* sequence_length;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSequenceLengthTensor,
                                          &sequence_length));
  TF_LITE_ENSURE_EQ(context, NumDimensions(sequence_length), 1);
  TF_LITE_ENSURE_EQ(context, NumElements(sequence_length), batch_size);
  // TensorFlow only supports int32.
  TF_LITE_ENSURE_EQ(context, sequence_length->type, kTfLiteInt32);

  // Resize decoded outputs.
  // Indices & values cannot be resized here, because their sizes are unknown
  // until decoding; mark all decoded outputs dynamic instead.
  for (int i = 0; i < top_paths; ++i) {
    TfLiteTensor* indices;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &indices));
    SetTensorToDynamic(indices);
    TfLiteTensor* values;
    TF_LITE_ENSURE_OK(context,
                      GetOutputSafe(context, node, i + top_paths, &values));
    SetTensorToDynamic(values);
    TfLiteTensor* output_shape;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i + 2 * top_paths,
                                             &output_shape));
    SetTensorToDynamic(output_shape);
  }

  // Resize log probability outputs.
  TfLiteTensor* log_probability_output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, top_paths * 3,
                                           &log_probability_output));
  TfLiteIntArray* log_probability_output_shape_array = TfLiteIntArrayCreate(2);
  log_probability_output_shape_array->data[0] = batch_size;
  log_probability_output_shape_array->data[1] = top_paths;
  return context->ResizeTensor(context, log_probability_output,
                               log_probability_output_shape_array);
}

TfLiteStatus Resize(TfLiteContext* context,
                    std::initializer_list<int32_t> output_shape,
                    TfLiteTensor* output) {
  const int dimensions = output_shape.size();
  TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(dimensions);
  int i = 0;
  for (const int v : output_shape) {
    output_shape_array->data[i++] = v;
  }
  return context->ResizeTensor(context, output, output_shape_array);
}

TfLiteStatus StoreAllDecodedSequences(
    TfLiteContext* context,
    const std::vector<std::vector<std::vector<int>>>& sequences,
    TfLiteNode* node, int top_paths) {
  const int32_t batch_size = sequences.size();
  std::vector<int32_t> num_entries(top_paths, 0);

  // Calculate num_entries per path.
  for (const auto& batch_s : sequences) {
    TF_LITE_ENSURE_EQ(context, batch_s.size(), top_paths);
    for (int p = 0; p < top_paths; ++p) {
      num_entries[p] += batch_s[p].size();
    }
  }

  for (int p = 0; p < top_paths; ++p) {
    const int32_t p_num = num_entries[p];

    // Resize the decoded outputs.
    TfLiteTensor* indices;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, p, &indices));
    TF_LITE_ENSURE_OK(context, Resize(context, {p_num, 2}, indices));

    TfLiteTensor* values;
    TF_LITE_ENSURE_OK(context,
                      GetOutputSafe(context, node, p + top_paths, &values));
    TF_LITE_ENSURE_OK(context, Resize(context, {p_num}, values));

    TfLiteTensor* decoded_shape;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, p + 2 * top_paths,
                                             &decoded_shape));
    TF_LITE_ENSURE_OK(context, Resize(context, {2}, decoded_shape));

    int32_t max_decoded = 0;
    int32_t offset = 0;

    int32_t* indices_data = GetTensorData<int32_t>(indices);
    int32_t* values_data = GetTensorData<int32_t>(values);
    int32_t* decoded_shape_data = GetTensorData<int32_t>(decoded_shape);
    for (int b = 0; b < batch_size; ++b) {
      auto& p_batch = sequences[b][p];
      int32_t num_decoded = p_batch.size();
      max_decoded = std::max(max_decoded, num_decoded);

      std::copy_n(p_batch.begin(), num_decoded, values_data + offset);
      for (int32_t t = 0; t < num_decoded; ++t, ++offset) {
        indices_data[offset * 2] = b;
        indices_data[offset * 2 + 1] = t;
      }
    }

    decoded_shape_data[0] = batch_size;
    decoded_shape_data[1] = max_decoded;
  }
  return kTfLiteOk;
}
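
// Worked example (illustrative, not from the original source): with
// batch_size = 2 and a path decoding to {1, 2} for batch 0 and {3} for
// batch 1, StoreAllDecodedSequences writes, for that path,
//   indices = [[0, 0], [0, 1], [1, 0]], values = [1, 2, 3], shape = [2, 2],
// i.e. each row of `indices` is the (batch, position) pair for the matching
// entry of `values`.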

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* inputs;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputsTensor, &inputs));
  const TfLiteTensor* sequence_length;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSequenceLengthTensor,
                                          &sequence_length));
  const CTCBeamSearchDecoderParams* option =
      reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);

  const int max_time = SizeOfDimension(inputs, 0);
  const int batch_size = SizeOfDimension(inputs, 1);
  const int num_classes = SizeOfDimension(inputs, 2);

  const int beam_width = option->beam_width;
  const int top_paths = option->top_paths;
  const bool merge_repeated = option->merge_repeated;

  // Validate that each sequence length is less than or equal to max_time.
  for (int i = 0; i < batch_size; ++i) {
    TF_LITE_ENSURE(context,
                   max_time >= GetTensorData<int32_t>(sequence_length)[i]);
  }

  // The following logic mirrors tensorflow/core/kernels/ctc_decoder_ops.cc.
  std::vector<optimized_ops::TTypes<float>::UnalignedConstMatrix> input_list_t;

  for (int t = 0; t < max_time; ++t) {
    input_list_t.emplace_back(
        GetTensorData<float>(inputs) + t * batch_size * num_classes, batch_size,
        num_classes);
  }
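  // Each entry of input_list_t views timestep t of the
  // [max_time, batch_size, num_classes] input as an unaligned
  // batch_size x num_classes matrix.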

  ::tflite::custom::ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer;
  ::tflite::custom::ctc::CTCBeamSearchDecoder<> beam_search(
      num_classes, beam_width, &beam_scorer, 1 /* batch_size */,
      merge_repeated);

  // Allocate temporary memory for holding chip operation data.
  float* input_chip_t_data =
      static_cast<float*>(malloc(num_classes * sizeof(float)));
  Eigen::array<Eigen::DenseIndex, 1> dims;
  dims[0] = num_classes;
  optimized_ops::TTypes<float>::Flat input_chip_t(input_chip_t_data, dims);

  std::vector<std::vector<std::vector<int>>> best_paths(batch_size);
  std::vector<float> log_probs;

  TfLiteTensor* log_probabilities;
  TF_LITE_ENSURE_OK(
      context, GetOutputSafe(context, node, 3 * top_paths, &log_probabilities));
  float* log_probabilities_output = GetTensorData<float>(log_probabilities);

  // Assumption: the blank index is num_classes - 1.
  for (int b = 0; b < batch_size; ++b) {
    auto& best_paths_b = best_paths[b];
    best_paths_b.resize(top_paths);
    for (int t = 0; t < GetTensorData<int32_t>(sequence_length)[b]; ++t) {
      input_chip_t = input_list_t[t].chip(b, 0);
      auto input_bi =
          Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
      beam_search.Step(input_bi);
    }
    TF_LITE_ENSURE(context, beam_search.TopPaths(top_paths, &best_paths_b,
                                                 &log_probs, merge_repeated));
    beam_search.Reset();

    // Fill in log_probabilities output.
    for (int bp = 0; bp < top_paths; ++bp) {
      log_probabilities_output[b * top_paths + bp] = log_probs[bp];
    }
  }

  free(input_chip_t_data);
  return StoreAllDecodedSequences(context, best_paths, node, top_paths);
}

}  // namespace ctc_beam_search_decoder

TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER() {
  static TfLiteRegistration r = {
      ctc_beam_search_decoder::Init, ctc_beam_search_decoder::Free,
      ctc_beam_search_decoder::Prepare, ctc_beam_search_decoder::Eval};
  return &r;
}
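
// A minimal registration sketch (assuming the model names this custom op
// "CTCBeamSearchDecoder"; use the name your converter actually emitted):
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddCustom("CTCBeamSearchDecoder",
//                      tflite::ops::custom::Register_CTC_BEAM_SEARCH_DECODER());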

}  // namespace custom
}  // namespace ops
}  // namespace tflite