/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/ctc_ops.cc.

#define EIGEN_USE_THREADS

#include <limits>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/ctc/ctc_beam_search.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

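// Returns the largest value in row `r` of matrix `m` and stores its column
// index in `*c`. Requires `m` to have at least one column.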
template <typename T>
inline T RowMax(const typename TTypes<T>::UnalignedConstMatrix& m, int r,
                int* c) {
  *c = 0;
  CHECK_LT(0, m.dimension(1));
  auto p = m(r, 0);
  for (int i = 1; i < m.dimension(1); ++i) {
    if (m(r, i) > p) {
      p = m(r, i);
      *c = i;
    }
  }
  return p;
}

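// Helper shared by the greedy and beam-search decoder kernels: validates the
// common inputs, allocates the outputs, and packs decoded sequences into the
// SparseTensor-style output lists.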
class CTCDecodeHelper {
 public:
  CTCDecodeHelper() : top_paths_(1) {}

  inline int GetTopPaths() const { return top_paths_; }
  void SetTopPaths(int tp) { top_paths_ = tp; }

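  // Checks that `inputs` is a [max_time, batch_size, num_classes] 3-Tensor and
  // that `sequence_length` is a length-batch_size vector whose entries do not
  // exceed max_time; allocates `log_prob` and the three decoded output lists.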
  Status ValidateInputsGenerateOutputs(
      OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,
      Tensor** log_prob, OpOutputList* decoded_indices,
      OpOutputList* decoded_values, OpOutputList* decoded_shape) const {
    Status status = ctx->input("inputs", inputs);
    if (!status.ok()) return status;
    status = ctx->input("sequence_length", seq_len);
    if (!status.ok()) return status;

    const TensorShape& inputs_shape = (*inputs)->shape();

    if (inputs_shape.dims() != 3) {
      return errors::InvalidArgument("inputs is not a 3-Tensor");
    }

    const int64 max_time = inputs_shape.dim_size(0);
    const int64 batch_size = inputs_shape.dim_size(1);

    if (max_time == 0) {
      return errors::InvalidArgument("max_time is 0");
    }
    if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {
      return errors::InvalidArgument("sequence_length is not a vector");
    }

    if (batch_size != (*seq_len)->dim_size(0)) {
      return errors::FailedPrecondition(
          "len(sequence_length) != batch_size. ",
          "len(sequence_length): ", (*seq_len)->dim_size(0),
          " batch_size: ", batch_size);
    }

    auto seq_len_t = (*seq_len)->vec<int32>();

    for (int b = 0; b < batch_size; ++b) {
      if (!(seq_len_t(b) <= max_time)) {
        return errors::FailedPrecondition("sequence_length(", b,
                                          ") = ", seq_len_t(b),
                                          " > max_time = ", max_time);
      }
    }

    Status s = ctx->allocate_output(
        "log_probability", TensorShape({batch_size, top_paths_}), log_prob);
    if (!s.ok()) return s;

    s = ctx->output_list("decoded_indices", decoded_indices);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_values", decoded_values);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_shape", decoded_shape);
    if (!s.ok()) return s;

    return Status::OK();
  }

  // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
  Status StoreAllDecodedSequences(
      const std::vector<std::vector<std::vector<int> > >& sequences,
      OpOutputList* decoded_indices, OpOutputList* decoded_values,
      OpOutputList* decoded_shape) const {
    // Calculate the total number of entries for each path
    const int64 batch_size = sequences.size();
    std::vector<int64> num_entries(top_paths_, 0);

    // Calculate num_entries per path
    for (const auto& batch_s : sequences) {
      CHECK_EQ(batch_s.size(), top_paths_);
      for (int p = 0; p < top_paths_; ++p) {
        num_entries[p] += batch_s[p].size();
      }
    }

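    // Emit one SparseTensor per path: `indices` holds (batch, time) pairs in
    // row-major order, `values` holds the decoded class labels, and `shape`
    // is [batch_size, max_decoded_length] for that path.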
    for (int p = 0; p < top_paths_; ++p) {
      Tensor* p_indices = nullptr;
      Tensor* p_values = nullptr;
      Tensor* p_shape = nullptr;

      const int64 p_num = num_entries[p];

      Status s =
          decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices);
      if (!s.ok()) return s;
      s = decoded_values->allocate(p, TensorShape({p_num}), &p_values);
      if (!s.ok()) return s;
      s = decoded_shape->allocate(p, TensorShape({2}), &p_shape);
      if (!s.ok()) return s;

      auto indices_t = p_indices->matrix<int64>();
      auto values_t = p_values->vec<int64>();
      auto shape_t = p_shape->vec<int64>();

      int64 max_decoded = 0;
      int64 offset = 0;

      for (int64 b = 0; b < batch_size; ++b) {
        auto& p_batch = sequences[b][p];
        int64 num_decoded = p_batch.size();
        max_decoded = std::max(max_decoded, num_decoded);
        if (num_decoded > 0) {
          DCHECK_NE(values_t.data(), nullptr)
              << "values_t should not be nullptr: p_num=" << p_num
              << " num_decoded=" << num_decoded;
          DCHECK_LT(offset, values_t.size())
              << "offset should be smaller than values_t.size()";
          std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));
        }
        for (int64 t = 0; t < num_decoded; ++t, ++offset) {
          indices_t(offset, 0) = b;
          indices_t(offset, 1) = t;
        }
      }

      shape_t(0) = batch_size;
      shape_t(1) = max_decoded;
    }
    return Status::OK();
  }

 private:
  int top_paths_;
  TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
};

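// CTCGreedyDecoderOp performs best-path decoding: at each time step it picks
// the class with the highest logit, then collapses the resulting sequence by
// dropping blanks (and, if merge_repeated is set, repeated labels).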
template <typename T>
class CTCGreedyDecoderOp : public OpKernel {
 public:
  explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* seq_len;
    Tensor* log_prob = nullptr;
    OpOutputList decoded_indices;
    OpOutputList decoded_values;
    OpOutputList decoded_shape;
    OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
                            ctx, &inputs, &seq_len, &log_prob,
                            &decoded_indices, &decoded_values, &decoded_shape));

    const TensorShape& inputs_shape = inputs->shape();

    std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
    const int64 max_time = inputs_shape.dim_size(0);
    const int64 batch_size = inputs_shape.dim_size(1);
    const int64 num_classes_raw = inputs_shape.dim_size(2);
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("num_classes cannot exceed max int"));
    const int num_classes = static_cast<int>(num_classes_raw);

    auto inputs_t = inputs->tensor<T, 3>();

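    // View each time step as an unaligned [batch_size, num_classes] matrix
    // into the flat input buffer.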
    input_list_t.reserve(max_time);
    for (int64 t = 0; t < max_time; ++t) {
      input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                batch_size, num_classes);
    }
    auto seq_len_t = seq_len->vec<int32>();
    auto log_prob_t = log_prob->matrix<T>();

    log_prob_t.setZero();

    // Assumption: the blank index is num_classes - 1
    int blank_index = num_classes - 1;

    // Perform best path decoding
    std::vector<std::vector<std::vector<int> > > sequences(batch_size);
    auto decode = [&](const int64 begin, const int64 end) {
      for (int b = begin; b < end; ++b) {
        sequences[b].resize(1);
        auto& sequence = sequences[b][0];
        int prev_indices = -1;
        for (int t = 0; t < seq_len_t(b); ++t) {
          int max_class_indices;
          log_prob_t(b, 0) +=
              -RowMax<T>(input_list_t[t], b, &max_class_indices);
          if (max_class_indices != blank_index &&
              !(merge_repeated_ && max_class_indices == prev_indices)) {
            sequence.push_back(max_class_indices);
          }
          prev_indices = max_class_indices;
        }
      }
    };

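    // Shard the per-batch-element decode across the CPU worker threads;
    // kCostPerUnit is a rough estimate of the work per batch element.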
    const int64 kCostPerUnit = 50 * max_time * num_classes;
    const int64 total = batch_size;
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *ctx->device()->tensorflow_cpu_worker_threads();
    Shard(worker_threads.num_threads, worker_threads.workers, total,
          kCostPerUnit, decode);

    OP_REQUIRES_OK(
        ctx, decode_helper_.StoreAllDecodedSequences(
                 sequences, &decoded_indices, &decoded_values, &decoded_shape));
  }

 private:
  CTCDecodeHelper decode_helper_;
  bool merge_repeated_;

  TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp);
};

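// Registered as the CPU kernel for the CTCGreedyDecoder op (exposed in Python
// as tf.nn.ctc_greedy_decoder).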
#define REGISTER_CPU(T)                                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("CTCGreedyDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CTCGreedyDecoderOp<T>);

REGISTER_CPU(float);
REGISTER_CPU(double);

#undef REGISTER_CPU

// CTC beam search
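// CTCBeamSearchDecoderOp runs a beam search over the class posteriors at each
// time step, keeping the beam_width most probable prefixes and returning the
// top_paths best decodings per batch element.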
template <typename T>
class CTCBeamSearchDecoderOp : public OpKernel {
 public:
  explicit CTCBeamSearchDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_));
    int top_paths;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));
    decode_helper_.SetTopPaths(top_paths);
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* seq_len;
    Tensor* log_prob = nullptr;
    OpOutputList decoded_indices;
    OpOutputList decoded_values;
    OpOutputList decoded_shape;
    OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
                            ctx, &inputs, &seq_len, &log_prob,
                            &decoded_indices, &decoded_values, &decoded_shape));

    auto inputs_t = inputs->tensor<T, 3>();
    auto seq_len_t = seq_len->vec<int32>();
    auto log_prob_t = log_prob->matrix<T>();

    const TensorShape& inputs_shape = inputs->shape();

    const int64 max_time = inputs_shape.dim_size(0);
    const int64 batch_size = inputs_shape.dim_size(1);
    const int64 num_classes_raw = inputs_shape.dim_size(2);
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("num_classes cannot exceed max int"));
    const int num_classes = static_cast<int>(num_classes_raw);

    log_prob_t.setZero();

    std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;

    input_list_t.reserve(max_time);
    for (int64 t = 0; t < max_time; ++t) {
      input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                batch_size, num_classes);
    }

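    // The beam search decoder is stateful: it is constructed once with a
    // batch size of 1 and Reset() between batch elements.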
    ctc::CTCBeamSearchDecoder<T> beam_search(num_classes, beam_width_,
                                             &beam_scorer_, 1 /* batch_size */,
                                             merge_repeated_);
    Tensor input_chip(DataTypeToEnum<T>::v(), TensorShape({num_classes}));
    auto input_chip_t = input_chip.flat<T>();

    std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
    std::vector<T> log_probs;

    // Assumption: the blank index is num_classes - 1
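    // Decode each batch element independently: feed its logits one time step
    // at a time, then read back the top paths and their log probabilities.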
    for (int b = 0; b < batch_size; ++b) {
      auto& best_paths_b = best_paths[b];
      best_paths_b.resize(decode_helper_.GetTopPaths());
      for (int t = 0; t < seq_len_t(b); ++t) {
        input_chip_t = input_list_t[t].chip(b, 0);
        auto input_bi = Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>(
            input_chip_t.data(), num_classes);
        beam_search.Step(input_bi);
      }
      OP_REQUIRES_OK(
          ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(),
                                    &best_paths_b, &log_probs,
                                    merge_repeated_));

      beam_search.Reset();

      for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {
        log_prob_t(b, bp) = log_probs[bp];
      }
    }

    OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(
                            best_paths, &decoded_indices, &decoded_values,
                            &decoded_shape));
  }

 private:
  CTCDecodeHelper decode_helper_;
  typename ctc::CTCBeamSearchDecoder<T>::DefaultBeamScorer beam_scorer_;
  bool merge_repeated_;
  int beam_width_;
  TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp<T>);
};

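// Registered as the CPU kernel for the CTCBeamSearchDecoder op (exposed in
// Python as tf.nn.ctc_beam_search_decoder).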
#define REGISTER_CPU(T)                                                       \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("CTCBeamSearchDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CTCBeamSearchDecoderOp<T>);

REGISTER_CPU(float);
REGISTER_CPU(double);

#undef REGISTER_CPU

}  // end namespace tensorflow