/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/ctc_ops.cc.
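//
// Implements the CPU kernels for the CTCGreedyDecoder and
// CTCBeamSearchDecoder ops. Both consume a 3-D float tensor of scores
// shaped [max_time, batch_size, num_classes] plus a sequence_length
// vector, and emit the decoded sequences as SparseTensor components.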

#define EIGEN_USE_THREADS

#include <limits>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/ctc/ctc_beam_search.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

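// Returns the maximum value in row `r` of `m` and stores the column index
// of that maximum in `*c`.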
inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m, int r,
                    int* c) {
  *c = 0;
  CHECK_LT(0, m.dimension(1));
  float p = m(r, 0);
  for (int i = 1; i < m.dimension(1); ++i) {
    if (m(r, i) > p) {
      p = m(r, i);
      *c = i;
    }
  }
  return p;
}

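// Shared logic for the decoder kernels below: validates the common inputs,
// allocates the outputs, and packs decoded sequences into the SparseTensor
// output lists.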
class CTCDecodeHelper {
 public:
  CTCDecodeHelper() : top_paths_(1) {}

  inline int GetTopPaths() const { return top_paths_; }
  void SetTopPaths(int tp) { top_paths_ = tp; }

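  // Checks that `inputs` is a [max_time, batch_size, num_classes] tensor and
  // `sequence_length` a length-batch_size vector with entries <= max_time,
  // then allocates the log_probability output and the decoded output lists.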
  Status ValidateInputsGenerateOutputs(
      OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,
      Tensor** log_prob, OpOutputList* decoded_indices,
      OpOutputList* decoded_values, OpOutputList* decoded_shape) const {
    Status status = ctx->input("inputs", inputs);
    if (!status.ok()) return status;
    status = ctx->input("sequence_length", seq_len);
    if (!status.ok()) return status;

    const TensorShape& inputs_shape = (*inputs)->shape();

    if (inputs_shape.dims() != 3) {
      return errors::InvalidArgument("inputs is not a 3-Tensor");
    }

    const int64 max_time = inputs_shape.dim_size(0);
    const int64 batch_size = inputs_shape.dim_size(1);

    if (max_time == 0) {
      return errors::InvalidArgument("max_time is 0");
    }
    if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {
      return errors::InvalidArgument("sequence_length is not a vector");
    }

    if (batch_size != (*seq_len)->dim_size(0)) {
      return errors::FailedPrecondition(
          "len(sequence_length) != batch_size. ",
          "len(sequence_length): ", (*seq_len)->dim_size(0),
          " batch_size: ", batch_size);
    }

    auto seq_len_t = (*seq_len)->vec<int32>();

    for (int b = 0; b < batch_size; ++b) {
      if (seq_len_t(b) > max_time) {
        return errors::FailedPrecondition("sequence_length(", b, ") = ",
                                          seq_len_t(b), " > max_time = ",
                                          max_time);
      }
    }

    Status s = ctx->allocate_output(
        "log_probability", TensorShape({batch_size, top_paths_}), log_prob);
    if (!s.ok()) return s;

    s = ctx->output_list("decoded_indices", decoded_indices);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_values", decoded_values);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_shape", decoded_shape);
    if (!s.ok()) return s;

    return Status::OK();
  }

  // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
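  // Each path p becomes one SparseTensor: indices of shape
  // [num_entries[p], 2] holding (batch, time) pairs, a values vector of
  // length num_entries[p], and a dense shape {batch_size, max_decoded}.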
  Status StoreAllDecodedSequences(
      const std::vector<std::vector<std::vector<int> > >& sequences,
      OpOutputList* decoded_indices, OpOutputList* decoded_values,
      OpOutputList* decoded_shape) const {
    // Calculate the total number of entries for each path.
    const int64 batch_size = sequences.size();
    std::vector<int64> num_entries(top_paths_, 0);

    // Calculate num_entries per path.
    for (const auto& batch_s : sequences) {
      CHECK_EQ(batch_s.size(), top_paths_);
      for (int p = 0; p < top_paths_; ++p) {
        num_entries[p] += batch_s[p].size();
      }
    }

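    // Emit one (indices, values, shape) triple per path.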
    for (int p = 0; p < top_paths_; ++p) {
      Tensor* p_indices = nullptr;
      Tensor* p_values = nullptr;
      Tensor* p_shape = nullptr;

      const int64 p_num = num_entries[p];

      Status s =
          decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices);
      if (!s.ok()) return s;
      s = decoded_values->allocate(p, TensorShape({p_num}), &p_values);
      if (!s.ok()) return s;
      s = decoded_shape->allocate(p, TensorShape({2}), &p_shape);
      if (!s.ok()) return s;

      auto indices_t = p_indices->matrix<int64>();
      auto values_t = p_values->vec<int64>();
      auto shape_t = p_shape->vec<int64>();

      int64 max_decoded = 0;
      int64 offset = 0;

      for (int64 b = 0; b < batch_size; ++b) {
        auto& p_batch = sequences[b][p];
        int64 num_decoded = p_batch.size();
        max_decoded = std::max(max_decoded, num_decoded);
        std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));
        for (int64 t = 0; t < num_decoded; ++t, ++offset) {
          indices_t(offset, 0) = b;
          indices_t(offset, 1) = t;
        }
      }

      shape_t(0) = batch_size;
      shape_t(1) = max_decoded;
    }
    return Status::OK();
  }

 private:
  int top_paths_;
  TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
};

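// Greedy (best path) decoder: at every time step take the class with the
// highest score, then drop blanks and, if merge_repeated is set, collapse
// consecutive repeats.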
class CTCGreedyDecoderOp : public OpKernel {
 public:
  explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* seq_len;
    Tensor* log_prob = nullptr;
    OpOutputList decoded_indices;
    OpOutputList decoded_values;
    OpOutputList decoded_shape;
    OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
                            ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
                            &decoded_values, &decoded_shape));

    const TensorShape& inputs_shape = inputs->shape();

    std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
    const int64 max_time = inputs_shape.dim_size(0);
    const int64 batch_size = inputs_shape.dim_size(1);
    const int64 num_classes_raw = inputs_shape.dim_size(2);
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("num_classes cannot exceed max int"));
    const int num_classes = static_cast<int>(num_classes_raw);

    auto inputs_t = inputs->tensor<float, 3>();

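    // View each time step as an unaligned [batch_size x num_classes] matrix.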
    for (int64 t = 0; t < max_time; ++t) {
      input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                batch_size, num_classes);
    }
    auto seq_len_t = seq_len->vec<int32>();
    auto log_prob_t = log_prob->matrix<float>();

    log_prob_t.setZero();

    // Assumption: the blank index is num_classes - 1.
    int blank_index = num_classes - 1;

    // Perform best path decoding.
    std::vector<std::vector<std::vector<int> > > sequences(batch_size);
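    // decode() processes examples [begin, end): at each time step take the
    // argmax class, accumulate the negated row maximum into log_prob, and
    // append the class unless it is the blank or a merged repeat.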
    auto decode = [&](const int64 begin, const int64 end) {
      for (int b = begin; b < end; ++b) {
        sequences[b].resize(1);
        auto& sequence = sequences[b][0];
        int prev_indices = -1;
        for (int t = 0; t < seq_len_t(b); ++t) {
          int max_class_indices;
          log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices);
          if (max_class_indices != blank_index &&
              !(merge_repeated_ && max_class_indices == prev_indices)) {
            sequence.push_back(max_class_indices);
          }
          prev_indices = max_class_indices;
        }
      }
    };
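
    // Shard the batch across the CPU worker threads. kCostPerUnit is a rough
    // per-example cost estimate used to decide how finely to split the work.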
    const int64 kCostPerUnit = 50 * max_time * num_classes;
    const int64 total = batch_size;
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *ctx->device()->tensorflow_cpu_worker_threads();
    Shard(worker_threads.num_threads, worker_threads.workers, total,
          kCostPerUnit, decode);

    OP_REQUIRES_OK(
        ctx, decode_helper_.StoreAllDecodedSequences(
                 sequences, &decoded_indices, &decoded_values, &decoded_shape));
  }

 private:
  CTCDecodeHelper decode_helper_;
  bool merge_repeated_;

  TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp);
};

REGISTER_KERNEL_BUILDER(Name("CTCGreedyDecoder").Device(DEVICE_CPU),
                        CTCGreedyDecoderOp);

// CTC beam search.
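// Runs a prefix beam search and returns the top_paths best sequences per
// batch element along with their log probabilities.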
class CTCBeamSearchDecoderOp : public OpKernel {
 public:
  explicit CTCBeamSearchDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_));
    int top_paths;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));
    decode_helper_.SetTopPaths(top_paths);
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* seq_len;
    Tensor* log_prob = nullptr;
    OpOutputList decoded_indices;
    OpOutputList decoded_values;
    OpOutputList decoded_shape;
    OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
                            ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
                            &decoded_values, &decoded_shape));

    auto inputs_t = inputs->tensor<float, 3>();
    auto seq_len_t = seq_len->vec<int32>();
    auto log_prob_t = log_prob->matrix<float>();

    const TensorShape& inputs_shape = inputs->shape();

    const int64 max_time = inputs_shape.dim_size(0);
    const int64 batch_size = inputs_shape.dim_size(1);
    const int64 num_classes_raw = inputs_shape.dim_size(2);
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("num_classes cannot exceed max int"));
    const int num_classes = static_cast<int>(num_classes_raw);

    log_prob_t.setZero();

    std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;

    for (int64 t = 0; t < max_time; ++t) {
      input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                batch_size, num_classes);
    }

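    // The beam search decoder is constructed with an internal batch size of
    // 1; each example is fed to it step by step, and Reset() clears the
    // decoder state between examples.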
    ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width_,
                                            &beam_scorer_, 1 /* batch_size */,
                                            merge_repeated_);
    Tensor input_chip(DT_FLOAT, TensorShape({num_classes}));
    auto input_chip_t = input_chip.flat<float>();

    std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
    std::vector<float> log_probs;

    // Assumption: the blank index is num_classes - 1.
    for (int b = 0; b < batch_size; ++b) {
      auto& best_paths_b = best_paths[b];
      best_paths_b.resize(decode_helper_.GetTopPaths());
      for (int t = 0; t < seq_len_t(b); ++t) {
        input_chip_t = input_list_t[t].chip(b, 0);
        auto input_bi =
            Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
        beam_search.Step(input_bi);
      }
      OP_REQUIRES_OK(
          ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b,
                                    &log_probs, merge_repeated_));

      beam_search.Reset();

      for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {
        log_prob_t(b, bp) = log_probs[bp];
      }
    }

    OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(
                            best_paths, &decoded_indices, &decoded_values,
                            &decoded_shape));
  }

 private:
  CTCDecodeHelper decode_helper_;
  ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer_;
  bool merge_repeated_;
  int beam_width_;
  TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp);
};

REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoder").Device(DEVICE_CPU),
                        CTCBeamSearchDecoderOp);

}  // end namespace tensorflow