/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include <vector>
16
17 #include "flatbuffers/flexbuffers.h" // from @flatbuffers
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/experimental/kernels/ctc_beam_search.h"
20 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23 #include "tensorflow/lite/kernels/op_macros.h"
24
25 namespace tflite {
26 namespace ops {
27 namespace experimental {
28 namespace ctc_beam_search_decoder {
29
// Tensor indices of the op's two inputs.
constexpr int kInputsTensor = 0;
constexpr int kSequenceLengthTensor = 1;

// Decoder options parsed out of the flexbuffer custom-options blob in Init().
struct CTCBeamSearchDecoderParams {
  int beam_width;       // Number of beams kept alive during the search.
  int top_paths;        // Number of best decodings to emit.
  bool merge_repeated;  // Whether consecutive repeated labels are merged.
};
38
Init(TfLiteContext * context,const char * buffer,size_t length)39 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
40 TFLITE_CHECK(buffer != nullptr);
41 const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
42 const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
43
44 CTCBeamSearchDecoderParams* option = new CTCBeamSearchDecoderParams;
45 option->beam_width = m["beam_width"].AsInt32();
46 option->top_paths = m["top_paths"].AsInt32();
47 option->merge_repeated = m["merge_repeated"].AsBool();
48
49 return option;
50 }
51
Free(TfLiteContext * context,void * buffer)52 void Free(TfLiteContext* context, void* buffer) {
53 delete reinterpret_cast<CTCBeamSearchDecoderParams*>(buffer);
54 }
55
Prepare(TfLiteContext * context,TfLiteNode * node)56 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
57 const CTCBeamSearchDecoderParams* option =
58 reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);
59 const int top_paths = option->top_paths;
60 TF_LITE_ENSURE(context, option->beam_width >= top_paths);
61 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
62 // The outputs should be top_paths * 3 + 1.
63 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3 * top_paths + 1);
64
65 const TfLiteTensor* inputs;
66 TF_LITE_ENSURE_OK(context,
67 GetInputSafe(context, node, kInputsTensor, &inputs));
68 TF_LITE_ENSURE_EQ(context, NumDimensions(inputs), 3);
69 // TensorFlow only supports float.
70 TF_LITE_ENSURE_EQ(context, inputs->type, kTfLiteFloat32);
71 const int batch_size = SizeOfDimension(inputs, 1);
72
73 const TfLiteTensor* sequence_length;
74 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSequenceLengthTensor,
75 &sequence_length));
76 TF_LITE_ENSURE_EQ(context, NumDimensions(sequence_length), 1);
77 TF_LITE_ENSURE_EQ(context, NumElements(sequence_length), batch_size);
78 // TensorFlow only supports int32.
79 TF_LITE_ENSURE_EQ(context, sequence_length->type, kTfLiteInt32);
80
81 // Resize decoded outputs.
82 // Do not resize indices & values cause we don't know the values yet.
83 for (int i = 0; i < top_paths; ++i) {
84 TfLiteTensor* indices;
85 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &indices));
86 SetTensorToDynamic(indices);
87 TfLiteTensor* values;
88 TF_LITE_ENSURE_OK(context,
89 GetOutputSafe(context, node, i + top_paths, &values));
90 SetTensorToDynamic(values);
91 TfLiteTensor* output_shape;
92 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i + 2 * top_paths,
93 &output_shape));
94 SetTensorToDynamic(output_shape);
95 }
96
97 // Resize log probability outputs.
98 TfLiteTensor* log_probability_output;
99 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, top_paths * 3,
100 &log_probability_output));
101 TfLiteIntArray* log_probability_output_shape_array = TfLiteIntArrayCreate(2);
102 log_probability_output_shape_array->data[0] = batch_size;
103 log_probability_output_shape_array->data[1] = top_paths;
104 return context->ResizeTensor(context, log_probability_output,
105 log_probability_output_shape_array);
106 }
107
Resize(TfLiteContext * context,std::initializer_list<int32_t> output_shape,TfLiteTensor * output)108 TfLiteStatus Resize(TfLiteContext* context,
109 std::initializer_list<int32_t> output_shape,
110 TfLiteTensor* output) {
111 const int dimensions = output_shape.size();
112 TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(dimensions);
113 int i = 0;
114 for (const int v : output_shape) {
115 output_shape_array->data[i++] = v;
116 }
117 return context->ResizeTensor(context, output, output_shape_array);
118 }
119
StoreAllDecodedSequences(TfLiteContext * context,const std::vector<std::vector<std::vector<int>>> & sequences,TfLiteNode * node,int top_paths)120 TfLiteStatus StoreAllDecodedSequences(
121 TfLiteContext* context,
122 const std::vector<std::vector<std::vector<int>>>& sequences,
123 TfLiteNode* node, int top_paths) {
124 const int32_t batch_size = sequences.size();
125 std::vector<int32_t> num_entries(top_paths, 0);
126
127 // Calculate num_entries per path
128 for (const auto& batch_s : sequences) {
129 TF_LITE_ENSURE_EQ(context, batch_s.size(), top_paths);
130 for (int p = 0; p < top_paths; ++p) {
131 num_entries[p] += batch_s[p].size();
132 }
133 }
134
135 for (int p = 0; p < top_paths; ++p) {
136 const int32_t p_num = num_entries[p];
137
138 // Resize the decoded outputs.
139 TfLiteTensor* indices;
140 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, p, &indices));
141 TF_LITE_ENSURE_OK(context, Resize(context, {p_num, 2}, indices));
142
143 TfLiteTensor* values;
144 TF_LITE_ENSURE_OK(context,
145 GetOutputSafe(context, node, p + top_paths, &values));
146 TF_LITE_ENSURE_OK(context, Resize(context, {p_num}, values));
147
148 TfLiteTensor* decoded_shape;
149 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, p + 2 * top_paths,
150 &decoded_shape));
151 TF_LITE_ENSURE_OK(context, Resize(context, {2}, decoded_shape));
152
153 int32_t max_decoded = 0;
154 int32_t offset = 0;
155
156 int32_t* indices_data = GetTensorData<int32_t>(indices);
157 int32_t* values_data = GetTensorData<int32_t>(values);
158 int32_t* decoded_shape_data = GetTensorData<int32_t>(decoded_shape);
159 for (int b = 0; b < batch_size; ++b) {
160 auto& p_batch = sequences[b][p];
161 int32_t num_decoded = p_batch.size();
162 max_decoded = std::max(max_decoded, num_decoded);
163
164 std::copy_n(p_batch.begin(), num_decoded, values_data + offset);
165 for (int32_t t = 0; t < num_decoded; ++t, ++offset) {
166 indices_data[offset * 2] = b;
167 indices_data[offset * 2 + 1] = t;
168 }
169 }
170
171 decoded_shape_data[0] = batch_size;
172 decoded_shape_data[1] = max_decoded;
173 }
174 return kTfLiteOk;
175 }
176
Eval(TfLiteContext * context,TfLiteNode * node)177 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
178 const TfLiteTensor* inputs;
179 TF_LITE_ENSURE_OK(context,
180 GetInputSafe(context, node, kInputsTensor, &inputs));
181 const TfLiteTensor* sequence_length;
182 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSequenceLengthTensor,
183 &sequence_length));
184 const CTCBeamSearchDecoderParams* option =
185 reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);
186
187 const int max_time = SizeOfDimension(inputs, 0);
188 const int batch_size = SizeOfDimension(inputs, 1);
189 const int num_classes = SizeOfDimension(inputs, 2);
190
191 const int beam_width = option->beam_width;
192 const int top_paths = option->top_paths;
193 const bool merge_repeated = option->merge_repeated;
194
195 // Validate sequence length is less or equal than max time.
196 for (int i = 0; i < batch_size; ++i) {
197 TF_LITE_ENSURE(context,
198 max_time >= GetTensorData<int32_t>(sequence_length)[i]);
199 }
200
201 // The following logic is implemented like
202 // tensorflow/core/kernels/ctc_decoder_ops.cc
203 std::vector<optimized_ops::TTypes<float>::UnalignedConstMatrix> input_list_t;
204
205 for (std::size_t t = 0; t < max_time; ++t) {
206 input_list_t.emplace_back(
207 GetTensorData<float>(inputs) + t * batch_size * num_classes, batch_size,
208 num_classes);
209 }
210
211 ::tflite::experimental::ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer
212 beam_scorer;
213 ::tflite::experimental::ctc::CTCBeamSearchDecoder<> beam_search(
214 num_classes, beam_width, &beam_scorer, 1 /* batch_size */,
215 merge_repeated);
216
217 // Allocate temporary memory for holding chip operation data.
218 float* input_chip_t_data =
219 static_cast<float*>(malloc(num_classes * sizeof(float)));
220 Eigen::array<Eigen::DenseIndex, 1> dims;
221 dims[0] = num_classes;
222 optimized_ops::TTypes<float>::Flat input_chip_t(input_chip_t_data, dims);
223
224 std::vector<std::vector<std::vector<int>>> best_paths(batch_size);
225 std::vector<float> log_probs;
226
227 TfLiteTensor* log_probabilities;
228 TF_LITE_ENSURE_OK(
229 context, GetOutputSafe(context, node, 3 * top_paths, &log_probabilities));
230 float* log_probabilities_output = GetTensorData<float>(log_probabilities);
231
232 // Assumption: the blank index is num_classes - 1
233 for (int b = 0; b < batch_size; ++b) {
234 auto& best_paths_b = best_paths[b];
235 best_paths_b.resize(top_paths);
236 for (int t = 0; t < GetTensorData<int32_t>(sequence_length)[b]; ++t) {
237 input_chip_t = input_list_t[t].chip(b, 0);
238 auto input_bi =
239 Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
240 beam_search.Step(input_bi);
241 }
242 TF_LITE_ENSURE(context, beam_search.TopPaths(top_paths, &best_paths_b,
243 &log_probs, merge_repeated));
244 beam_search.Reset();
245
246 // Fill in log_probabilities output.
247 for (int bp = 0; bp < top_paths; ++bp) {
248 log_probabilities_output[b * top_paths + bp] = log_probs[bp];
249 }
250 }
251
252 free(input_chip_t_data);
253 return StoreAllDecodedSequences(context, best_paths, node, top_paths);
254 }
255
256 } // namespace ctc_beam_search_decoder
257
Register_CTC_BEAM_SEARCH_DECODER()258 TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER() {
259 static TfLiteRegistration r = {
260 ctc_beam_search_decoder::Init, ctc_beam_search_decoder::Free,
261 ctc_beam_search_decoder::Prepare, ctc_beam_search_decoder::Eval};
262 return &r;
263 }
264
265 } // namespace experimental
266 } // namespace ops
267 } // namespace tflite
268