• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <algorithm>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/kernels/ctc_beam_search.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace experimental {
27 namespace ctc_beam_search_decoder {
28 
// Positions of the node's two input tensors.
constexpr int kInputsTensor = 0;
constexpr int kSequenceLengthTensor = 1;

// Decoder options, filled in by Init() from the op's flexbuffer map
// ("beam_width", "top_paths", "merge_repeated").
struct CTCBeamSearchDecoderParams {
  int beam_width;
  int top_paths;
  bool merge_repeated;
};
37 
Init(TfLiteContext * context,const char * buffer,size_t length)38 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
39   TFLITE_CHECK(buffer != nullptr);
40   const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
41   const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
42 
43   CTCBeamSearchDecoderParams* option = new CTCBeamSearchDecoderParams;
44   option->beam_width = m["beam_width"].AsInt32();
45   option->top_paths = m["top_paths"].AsInt32();
46   option->merge_repeated = m["merge_repeated"].AsBool();
47 
48   return option;
49 }
50 
Free(TfLiteContext * context,void * buffer)51 void Free(TfLiteContext* context, void* buffer) {
52   delete reinterpret_cast<CTCBeamSearchDecoderParams*>(buffer);
53 }
54 
Prepare(TfLiteContext * context,TfLiteNode * node)55 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
56   const CTCBeamSearchDecoderParams* option =
57       reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);
58   const int top_paths = option->top_paths;
59   TF_LITE_ENSURE(context, option->beam_width >= top_paths);
60   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
61   // The outputs should be top_paths * 3 + 1.
62   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3 * top_paths + 1);
63 
64   const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor);
65   TF_LITE_ENSURE_EQ(context, NumDimensions(inputs), 3);
66   // TensorFlow only supports float.
67   TF_LITE_ENSURE_EQ(context, inputs->type, kTfLiteFloat32);
68   const int batch_size = SizeOfDimension(inputs, 1);
69 
70   const TfLiteTensor* sequence_length =
71       GetInput(context, node, kSequenceLengthTensor);
72   TF_LITE_ENSURE_EQ(context, NumDimensions(sequence_length), 1);
73   TF_LITE_ENSURE_EQ(context, NumElements(sequence_length), batch_size);
74   // TensorFlow only supports int32.
75   TF_LITE_ENSURE_EQ(context, sequence_length->type, kTfLiteInt32);
76 
77   // Resize decoded outputs.
78   // Do not resize indices & values cause we don't know the values yet.
79   for (int i = 0; i < top_paths; ++i) {
80     TfLiteTensor* indices = GetOutput(context, node, i);
81     SetTensorToDynamic(indices);
82     TfLiteTensor* values = GetOutput(context, node, i + top_paths);
83     SetTensorToDynamic(values);
84     TfLiteTensor* output_shape = GetOutput(context, node, i + 2 * top_paths);
85     SetTensorToDynamic(output_shape);
86   }
87 
88   // Resize log probability outputs.
89   TfLiteTensor* log_probability_output =
90       GetOutput(context, node, top_paths * 3);
91   TfLiteIntArray* log_probability_output_shape_array = TfLiteIntArrayCreate(2);
92   log_probability_output_shape_array->data[0] = batch_size;
93   log_probability_output_shape_array->data[1] = top_paths;
94   return context->ResizeTensor(context, log_probability_output,
95                                log_probability_output_shape_array);
96 }
97 
Resize(TfLiteContext * context,std::initializer_list<int32_t> output_shape,TfLiteTensor * output)98 TfLiteStatus Resize(TfLiteContext* context,
99                     std::initializer_list<int32_t> output_shape,
100                     TfLiteTensor* output) {
101   const int dimensions = output_shape.size();
102   TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(dimensions);
103   int i = 0;
104   for (const int v : output_shape) {
105     output_shape_array->data[i++] = v;
106   }
107   return context->ResizeTensor(context, output, output_shape_array);
108 }
109 
StoreAllDecodedSequences(TfLiteContext * context,const std::vector<std::vector<std::vector<int>>> & sequences,TfLiteNode * node,int top_paths)110 TfLiteStatus StoreAllDecodedSequences(
111     TfLiteContext* context,
112     const std::vector<std::vector<std::vector<int>>>& sequences,
113     TfLiteNode* node, int top_paths) {
114   const int32_t batch_size = sequences.size();
115   std::vector<int32_t> num_entries(top_paths, 0);
116 
117   // Calculate num_entries per path
118   for (const auto& batch_s : sequences) {
119     TF_LITE_ENSURE_EQ(context, batch_s.size(), top_paths);
120     for (int p = 0; p < top_paths; ++p) {
121       num_entries[p] += batch_s[p].size();
122     }
123   }
124 
125   for (int p = 0; p < top_paths; ++p) {
126     const int32_t p_num = num_entries[p];
127 
128     // Resize the decoded outputs.
129     TfLiteTensor* indices = GetOutput(context, node, p);
130     TF_LITE_ENSURE_OK(context, Resize(context, {p_num, 2}, indices));
131 
132     TfLiteTensor* values = GetOutput(context, node, p + top_paths);
133     TF_LITE_ENSURE_OK(context, Resize(context, {p_num}, values));
134 
135     TfLiteTensor* decoded_shape = GetOutput(context, node, p + 2 * top_paths);
136     TF_LITE_ENSURE_OK(context, Resize(context, {2}, decoded_shape));
137 
138     int32_t max_decoded = 0;
139     int32_t offset = 0;
140 
141     int32_t* indices_data = GetTensorData<int32_t>(indices);
142     int32_t* values_data = GetTensorData<int32_t>(values);
143     int32_t* decoded_shape_data = GetTensorData<int32_t>(decoded_shape);
144     for (int b = 0; b < batch_size; ++b) {
145       auto& p_batch = sequences[b][p];
146       int32_t num_decoded = p_batch.size();
147       max_decoded = std::max(max_decoded, num_decoded);
148 
149       std::copy_n(p_batch.begin(), num_decoded, values_data + offset);
150       for (int32_t t = 0; t < num_decoded; ++t, ++offset) {
151         indices_data[offset * 2] = b;
152         indices_data[offset * 2 + 1] = t;
153       }
154     }
155 
156     decoded_shape_data[0] = batch_size;
157     decoded_shape_data[1] = max_decoded;
158   }
159   return kTfLiteOk;
160 }
161 
Eval(TfLiteContext * context,TfLiteNode * node)162 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
163   const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor);
164   const TfLiteTensor* sequence_length =
165       GetInput(context, node, kSequenceLengthTensor);
166   const CTCBeamSearchDecoderParams* option =
167       reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);
168 
169   const int max_time = SizeOfDimension(inputs, 0);
170   const int batch_size = SizeOfDimension(inputs, 1);
171   const int num_classes = SizeOfDimension(inputs, 2);
172 
173   const int beam_width = option->beam_width;
174   const int top_paths = option->top_paths;
175   const bool merge_repeated = option->merge_repeated;
176 
177   // Validate sequence length is less or equal than max time.
178   for (int i = 0; i < batch_size; ++i) {
179     TF_LITE_ENSURE(context,
180                    max_time >= GetTensorData<int32_t>(sequence_length)[i]);
181   }
182 
183   // The following logic is implemented like
184   // tensorflow/core/kernels/ctc_decoder_ops.cc
185   std::vector<optimized_ops::TTypes<float>::UnalignedConstMatrix> input_list_t;
186 
187   for (std::size_t t = 0; t < max_time; ++t) {
188     input_list_t.emplace_back(
189         GetTensorData<float>(inputs) + t * batch_size * num_classes, batch_size,
190         num_classes);
191   }
192 
193   ::tflite::experimental::ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer
194       beam_scorer;
195   ::tflite::experimental::ctc::CTCBeamSearchDecoder<> beam_search(
196       num_classes, beam_width, &beam_scorer, 1 /* batch_size */,
197       merge_repeated);
198 
199   // Allocate temporary memory for holding chip operation data.
200   float* input_chip_t_data =
201       static_cast<float*>(malloc(num_classes * sizeof(float)));
202   Eigen::array<Eigen::DenseIndex, 1> dims;
203   dims[0] = num_classes;
204   optimized_ops::TTypes<float>::Flat input_chip_t(input_chip_t_data, dims);
205 
206   std::vector<std::vector<std::vector<int>>> best_paths(batch_size);
207   std::vector<float> log_probs;
208 
209   TfLiteTensor* log_probabilities = GetOutput(context, node, 3 * top_paths);
210   float* log_probabilities_output = GetTensorData<float>(log_probabilities);
211 
212   // Assumption: the blank index is num_classes - 1
213   for (int b = 0; b < batch_size; ++b) {
214     auto& best_paths_b = best_paths[b];
215     best_paths_b.resize(top_paths);
216     for (int t = 0; t < GetTensorData<int32_t>(sequence_length)[b]; ++t) {
217       input_chip_t = input_list_t[t].chip(b, 0);
218       auto input_bi =
219           Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
220       beam_search.Step(input_bi);
221     }
222     TF_LITE_ENSURE(context, beam_search.TopPaths(top_paths, &best_paths_b,
223                                                  &log_probs, merge_repeated));
224     beam_search.Reset();
225 
226     // Fill in log_probabilities output.
227     for (int bp = 0; bp < top_paths; ++bp) {
228       log_probabilities_output[b * top_paths + bp] = log_probs[bp];
229     }
230   }
231 
232   free(input_chip_t_data);
233   return StoreAllDecodedSequences(context, best_paths, node, top_paths);
234 }
235 
236 }  // namespace ctc_beam_search_decoder
237 
Register_CTC_BEAM_SEARCH_DECODER()238 TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER() {
239   static TfLiteRegistration r = {
240       ctc_beam_search_decoder::Init, ctc_beam_search_decoder::Free,
241       ctc_beam_search_decoder::Prepare, ctc_beam_search_decoder::Eval};
242   return &r;
243 }
244 
245 }  // namespace experimental
246 }  // namespace ops
247 }  // namespace tflite
248