/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstdlib>
#include <initializer_list>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/kernels/ctc_beam_search.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace experimental {
namespace ctc_beam_search_decoder {

constexpr int kInputsTensor = 0;
constexpr int kSequenceLengthTensor = 1;
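
// Decoder options parsed from the custom op's flexbuffer payload:
//  * beam_width: beam search width, i.e. the number of hypotheses kept alive
//    at each step; must be >= top_paths (checked in Prepare below).
//  * top_paths: number of decoded paths (and log probabilities) returned per
//    batch element.
//  * merge_repeated: if true, repeated consecutive labels in a decoded path
//    are collapsed into one.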
typedef struct {
  int beam_width;
  int top_paths;
  bool merge_repeated;
} CTCBeamSearchDecoderParams;

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_CHECK(buffer != nullptr);
  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();

  CTCBeamSearchDecoderParams* option = new CTCBeamSearchDecoderParams;
  option->beam_width = m["beam_width"].AsInt32();
  option->top_paths = m["top_paths"].AsInt32();
  option->merge_repeated = m["merge_repeated"].AsBool();

  return option;
}
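
// A minimal sketch of how a client might serialize these options with the
// flexbuffers::Builder API (the key names must match the lookups above; the
// values here are only illustrative):
//
//   flexbuffers::Builder fbb;
//   fbb.Map([&]() {
//     fbb.Int("beam_width", 16);
//     fbb.Int("top_paths", 1);
//     fbb.Bool("merge_repeated", true);
//   });
//   fbb.Finish();
//   // fbb.GetBuffer() yields the bytes handed to Init() as `buffer`.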

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<CTCBeamSearchDecoderParams*>(buffer);
}

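// Expected op signature, as enforced below:
//   Input 0: `inputs`, float32 [max_time, batch_size, num_classes].
//   Input 1: `sequence_length`, int32 [batch_size].
//   Outputs 0 .. top_paths-1:             decoded sparse indices (dynamic).
//   Outputs top_paths .. 2*top_paths-1:   decoded values (dynamic).
//   Outputs 2*top_paths .. 3*top_paths-1: decoded dense shapes (dynamic).
//   Output 3*top_paths:                   log probabilities,
//                                         float32 [batch_size, top_paths].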
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const CTCBeamSearchDecoderParams* option =
      reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);
  const int top_paths = option->top_paths;
  TF_LITE_ENSURE(context, option->beam_width >= top_paths);
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  // The number of outputs should be top_paths * 3 + 1.
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3 * top_paths + 1);

  const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor);
  TF_LITE_ENSURE_EQ(context, NumDimensions(inputs), 3);
  // TensorFlow only supports float.
  TF_LITE_ENSURE_EQ(context, inputs->type, kTfLiteFloat32);
  const int batch_size = SizeOfDimension(inputs, 1);

  const TfLiteTensor* sequence_length =
      GetInput(context, node, kSequenceLengthTensor);
  TF_LITE_ENSURE_EQ(context, NumDimensions(sequence_length), 1);
  TF_LITE_ENSURE_EQ(context, NumElements(sequence_length), batch_size);
  // TensorFlow only supports int32.
  TF_LITE_ENSURE_EQ(context, sequence_length->type, kTfLiteInt32);

  // Resize decoded outputs.
  // Do not resize the indices & values tensors here because their sizes are
  // not known until decoding.
  for (int i = 0; i < top_paths; ++i) {
    TfLiteTensor* indices = GetOutput(context, node, i);
    SetTensorToDynamic(indices);
    TfLiteTensor* values = GetOutput(context, node, i + top_paths);
    SetTensorToDynamic(values);
    TfLiteTensor* output_shape = GetOutput(context, node, i + 2 * top_paths);
    SetTensorToDynamic(output_shape);
  }

  // Resize log probability outputs.
  TfLiteTensor* log_probability_output =
      GetOutput(context, node, top_paths * 3);
  TfLiteIntArray* log_probability_output_shape_array = TfLiteIntArrayCreate(2);
  log_probability_output_shape_array->data[0] = batch_size;
  log_probability_output_shape_array->data[1] = top_paths;
  return context->ResizeTensor(context, log_probability_output,
                               log_probability_output_shape_array);
}

TfLiteStatus Resize(TfLiteContext* context,
                    std::initializer_list<int32_t> output_shape,
                    TfLiteTensor* output) {
  const int dimensions = output_shape.size();
  TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(dimensions);
  int i = 0;
  for (const int v : output_shape) {
    output_shape_array->data[i++] = v;
  }
  return context->ResizeTensor(context, output, output_shape_array);
}

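// Writes each of the `top_paths` decoded sequences as a sparse (COO) triple:
// indices [num_entries, 2] holding (batch, time) pairs, values [num_entries]
// holding the labels, and a dense shape [batch_size, max_decoded_length].
// For example, decoded sequences {1, 2, 3} and {4, 5} for a batch of two are
// stored for that path as:
//   indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]
//   values  = [1, 2, 3, 4, 5]
//   shape   = [2, 3]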
TfLiteStatus StoreAllDecodedSequences(
    TfLiteContext* context,
    const std::vector<std::vector<std::vector<int>>>& sequences,
    TfLiteNode* node, int top_paths) {
  const int32_t batch_size = sequences.size();
  std::vector<int32_t> num_entries(top_paths, 0);

  // Calculate num_entries per path.
  for (const auto& batch_s : sequences) {
    TF_LITE_ENSURE_EQ(context, batch_s.size(), top_paths);
    for (int p = 0; p < top_paths; ++p) {
      num_entries[p] += batch_s[p].size();
    }
  }

  for (int p = 0; p < top_paths; ++p) {
    const int32_t p_num = num_entries[p];

    // Resize the decoded outputs.
    TfLiteTensor* indices = GetOutput(context, node, p);
    TF_LITE_ENSURE_OK(context, Resize(context, {p_num, 2}, indices));

    TfLiteTensor* values = GetOutput(context, node, p + top_paths);
    TF_LITE_ENSURE_OK(context, Resize(context, {p_num}, values));

    TfLiteTensor* decoded_shape = GetOutput(context, node, p + 2 * top_paths);
    TF_LITE_ENSURE_OK(context, Resize(context, {2}, decoded_shape));

    int32_t max_decoded = 0;
    int32_t offset = 0;

    int32_t* indices_data = GetTensorData<int32_t>(indices);
    int32_t* values_data = GetTensorData<int32_t>(values);
    int32_t* decoded_shape_data = GetTensorData<int32_t>(decoded_shape);
    for (int b = 0; b < batch_size; ++b) {
      auto& p_batch = sequences[b][p];
      int32_t num_decoded = p_batch.size();
      max_decoded = std::max(max_decoded, num_decoded);

      std::copy_n(p_batch.begin(), num_decoded, values_data + offset);
      for (int32_t t = 0; t < num_decoded; ++t, ++offset) {
        indices_data[offset * 2] = b;
        indices_data[offset * 2 + 1] = t;
      }
    }

    decoded_shape_data[0] = batch_size;
    decoded_shape_data[1] = max_decoded;
  }
  return kTfLiteOk;
}

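// Decodes each batch element independently: the beam search decoder is
// constructed once with batch size 1, stepped through that element's logits
// for sequence_length[b] timesteps, queried for its top paths, and reset
// before the next element.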
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor);
  const TfLiteTensor* sequence_length =
      GetInput(context, node, kSequenceLengthTensor);
  const CTCBeamSearchDecoderParams* option =
      reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);

  const int max_time = SizeOfDimension(inputs, 0);
  const int batch_size = SizeOfDimension(inputs, 1);
  const int num_classes = SizeOfDimension(inputs, 2);

  const int beam_width = option->beam_width;
  const int top_paths = option->top_paths;
  const bool merge_repeated = option->merge_repeated;

  // Validate that each sequence length is less than or equal to max_time.
  for (int i = 0; i < batch_size; ++i) {
    TF_LITE_ENSURE(context,
                   max_time >= GetTensorData<int32_t>(sequence_length)[i]);
  }

  // The following logic is implemented like
  // tensorflow/core/kernels/ctc_decoder_ops.cc
  std::vector<optimized_ops::TTypes<float>::UnalignedConstMatrix> input_list_t;

  for (int t = 0; t < max_time; ++t) {
    input_list_t.emplace_back(
        GetTensorData<float>(inputs) + t * batch_size * num_classes,
        batch_size, num_classes);
  }
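
  // input_list_t[t] now views the [batch_size, num_classes] logits slice at
  // timestep t; the input tensor is laid out time-major.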

  ::tflite::experimental::ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer
      beam_scorer;
  ::tflite::experimental::ctc::CTCBeamSearchDecoder<> beam_search(
      num_classes, beam_width, &beam_scorer, 1 /* batch_size */,
      merge_repeated);

  // Allocate temporary memory for holding chip operation data.
  float* input_chip_t_data =
      static_cast<float*>(malloc(num_classes * sizeof(float)));
  Eigen::array<Eigen::DenseIndex, 1> dims;
  dims[0] = num_classes;
  optimized_ops::TTypes<float>::Flat input_chip_t(input_chip_t_data, dims);

  std::vector<std::vector<std::vector<int>>> best_paths(batch_size);
  std::vector<float> log_probs;

  TfLiteTensor* log_probabilities = GetOutput(context, node, 3 * top_paths);
  float* log_probabilities_output = GetTensorData<float>(log_probabilities);

  // Assumption: the blank index is num_classes - 1.
  for (int b = 0; b < batch_size; ++b) {
    auto& best_paths_b = best_paths[b];
    best_paths_b.resize(top_paths);
    for (int t = 0; t < GetTensorData<int32_t>(sequence_length)[b]; ++t) {
      input_chip_t = input_list_t[t].chip(b, 0);
      auto input_bi =
          Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
      beam_search.Step(input_bi);
    }
    TF_LITE_ENSURE(context, beam_search.TopPaths(top_paths, &best_paths_b,
                                                 &log_probs, merge_repeated));
    beam_search.Reset();

    // Fill in the log_probabilities output.
    for (int bp = 0; bp < top_paths; ++bp) {
      log_probabilities_output[b * top_paths + bp] = log_probs[bp];
    }
  }

  free(input_chip_t_data);
  return StoreAllDecodedSequences(context, best_paths, node, top_paths);
}


}  // namespace ctc_beam_search_decoder

TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER() {
  static TfLiteRegistration r = {
      ctc_beam_search_decoder::Init, ctc_beam_search_decoder::Free,
      ctc_beam_search_decoder::Prepare, ctc_beam_search_decoder::Eval};
  return &r;
}
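
// A minimal sketch of wiring this kernel into an interpreter; the custom op
// name used here is an assumption and must match the name recorded in the
// model:
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddCustom(
//       "CTCBeamSearchDecoder",
//       tflite::ops::experimental::Register_CTC_BEAM_SEARCH_DECODER());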

}  // namespace experimental
}  // namespace ops
}  // namespace tflite