• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
// See docs in ../ops/ctc_ops.cc.
17 
18 #define EIGEN_USE_THREADS
19 
20 #include <limits>
21 
22 #include "tensorflow/core/framework/op.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/lib/core/status.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow/core/util/ctc/ctc_beam_search.h"
30 #include "tensorflow/core/util/sparse/sparse_tensor.h"
31 #include "tensorflow/core/util/work_sharder.h"
32 
33 namespace tensorflow {
34 
35 typedef Eigen::ThreadPoolDevice CPUDevice;
36 
37 template <typename T>
RowMax(const typename TTypes<T>::UnalignedConstMatrix & m,int r,int * c)38 inline T RowMax(const typename TTypes<T>::UnalignedConstMatrix& m, int r,
39                 int* c) {
40   *c = 0;
41   CHECK_LT(0, m.dimension(1));
42   auto p = m(r, 0);
43   for (int i = 1; i < m.dimension(1); ++i) {
44     if (m(r, i) > p) {
45       p = m(r, i);
46       *c = i;
47     }
48   }
49   return p;
50 }
51 
52 class CTCDecodeHelper {
53  public:
CTCDecodeHelper()54   CTCDecodeHelper() : top_paths_(1) {}
55 
GetTopPaths() const56   inline int GetTopPaths() const { return top_paths_; }
SetTopPaths(int tp)57   void SetTopPaths(int tp) { top_paths_ = tp; }
58 
ValidateInputsGenerateOutputs(OpKernelContext * ctx,const Tensor ** inputs,const Tensor ** seq_len,Tensor ** log_prob,OpOutputList * decoded_indices,OpOutputList * decoded_values,OpOutputList * decoded_shape) const59   Status ValidateInputsGenerateOutputs(
60       OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,
61       Tensor** log_prob, OpOutputList* decoded_indices,
62       OpOutputList* decoded_values, OpOutputList* decoded_shape) const {
63     Status status = ctx->input("inputs", inputs);
64     if (!status.ok()) return status;
65     status = ctx->input("sequence_length", seq_len);
66     if (!status.ok()) return status;
67 
68     const TensorShape& inputs_shape = (*inputs)->shape();
69 
70     if (inputs_shape.dims() != 3) {
71       return errors::InvalidArgument("inputs is not a 3-Tensor");
72     }
73 
74     const int64 max_time = inputs_shape.dim_size(0);
75     const int64 batch_size = inputs_shape.dim_size(1);
76 
77     if (max_time == 0) {
78       return errors::InvalidArgument("max_time is 0");
79     }
80     if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {
81       return errors::InvalidArgument("sequence_length is not a vector");
82     }
83 
84     if (!(batch_size == (*seq_len)->dim_size(0))) {
85       return errors::FailedPrecondition(
86           "len(sequence_length) != batch_size.  ",
87           "len(sequence_length):  ", (*seq_len)->dim_size(0),
88           " batch_size: ", batch_size);
89     }
90 
91     auto seq_len_t = (*seq_len)->vec<int32>();
92 
93     for (int b = 0; b < batch_size; ++b) {
94       if (!(seq_len_t(b) <= max_time)) {
95         return errors::FailedPrecondition("sequence_length(", b,
96                                           ") <= ", max_time);
97       }
98     }
99 
100     Status s = ctx->allocate_output(
101         "log_probability", TensorShape({batch_size, top_paths_}), log_prob);
102     if (!s.ok()) return s;
103 
104     s = ctx->output_list("decoded_indices", decoded_indices);
105     if (!s.ok()) return s;
106     s = ctx->output_list("decoded_values", decoded_values);
107     if (!s.ok()) return s;
108     s = ctx->output_list("decoded_shape", decoded_shape);
109     if (!s.ok()) return s;
110 
111     return Status::OK();
112   }
113 
114   // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
StoreAllDecodedSequences(const std::vector<std::vector<std::vector<int>>> & sequences,OpOutputList * decoded_indices,OpOutputList * decoded_values,OpOutputList * decoded_shape) const115   Status StoreAllDecodedSequences(
116       const std::vector<std::vector<std::vector<int> > >& sequences,
117       OpOutputList* decoded_indices, OpOutputList* decoded_values,
118       OpOutputList* decoded_shape) const {
119     // Calculate the total number of entries for each path
120     const int64 batch_size = sequences.size();
121     std::vector<int64> num_entries(top_paths_, 0);
122 
123     // Calculate num_entries per path
124     for (const auto& batch_s : sequences) {
125       CHECK_EQ(batch_s.size(), top_paths_);
126       for (int p = 0; p < top_paths_; ++p) {
127         num_entries[p] += batch_s[p].size();
128       }
129     }
130 
131     for (int p = 0; p < top_paths_; ++p) {
132       Tensor* p_indices = nullptr;
133       Tensor* p_values = nullptr;
134       Tensor* p_shape = nullptr;
135 
136       const int64 p_num = num_entries[p];
137 
138       Status s =
139           decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices);
140       if (!s.ok()) return s;
141       s = decoded_values->allocate(p, TensorShape({p_num}), &p_values);
142       if (!s.ok()) return s;
143       s = decoded_shape->allocate(p, TensorShape({2}), &p_shape);
144       if (!s.ok()) return s;
145 
146       auto indices_t = p_indices->matrix<int64>();
147       auto values_t = p_values->vec<int64>();
148       auto shape_t = p_shape->vec<int64>();
149 
150       int64 max_decoded = 0;
151       int64 offset = 0;
152 
153       for (int64 b = 0; b < batch_size; ++b) {
154         auto& p_batch = sequences[b][p];
155         int64 num_decoded = p_batch.size();
156         max_decoded = std::max(max_decoded, num_decoded);
157         if (num_decoded > 0) {
158           DCHECK_NE(values_t.data(), nullptr)
159               << "values_t should not be nullptr: p_num=" << p_num
160               << " num_decoded=" << num_decoded;
161           DCHECK_LT(offset, values_t.size())
162               << "offset should be smaller than values_t.size()";
163           std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));
164         }
165         for (int64 t = 0; t < num_decoded; ++t, ++offset) {
166           indices_t(offset, 0) = b;
167           indices_t(offset, 1) = t;
168         }
169       }
170 
171       shape_t(0) = batch_size;
172       shape_t(1) = max_decoded;
173     }
174     return Status::OK();
175   }
176 
177  private:
178   int top_paths_;
179   TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
180 };
181 
182 template <typename T>
183 class CTCGreedyDecoderOp : public OpKernel {
184  public:
CTCGreedyDecoderOp(OpKernelConstruction * ctx)185   explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
186     OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
187   }
188 
Compute(OpKernelContext * ctx)189   void Compute(OpKernelContext* ctx) override {
190     const Tensor* inputs;
191     const Tensor* seq_len;
192     Tensor* log_prob = nullptr;
193     OpOutputList decoded_indices;
194     OpOutputList decoded_values;
195     OpOutputList decoded_shape;
196     OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
197                             ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
198                             &decoded_values, &decoded_shape));
199 
200     const TensorShape& inputs_shape = inputs->shape();
201 
202     std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
203     const int64 max_time = inputs_shape.dim_size(0);
204     const int64 batch_size = inputs_shape.dim_size(1);
205     const int64 num_classes_raw = inputs_shape.dim_size(2);
206     OP_REQUIRES(
207         ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
208         errors::InvalidArgument("num_classes cannot exceed max int"));
209     const int num_classes = static_cast<const int>(num_classes_raw);
210 
211     auto inputs_t = inputs->tensor<T, 3>();
212 
213     input_list_t.reserve(max_time);
214     for (std::size_t t = 0; t < max_time; ++t) {
215       input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
216                                 batch_size, num_classes);
217     }
218     auto seq_len_t = seq_len->vec<int32>();
219     auto log_prob_t = log_prob->matrix<T>();
220 
221     log_prob_t.setZero();
222 
223     // Assumption: the blank index is num_classes - 1
224     int blank_index = num_classes - 1;
225 
226     // Perform best path decoding
227     std::vector<std::vector<std::vector<int> > > sequences(batch_size);
228     auto decode = [&](const int64 begin, const int64 end) {
229       for (int b = begin; b < end; ++b) {
230         sequences[b].resize(1);
231         auto &sequence = sequences[b][0];
232         int prev_indices = -1;
233         for (int t = 0; t < seq_len_t(b); ++t) {
234           int max_class_indices;
235           log_prob_t(b, 0) +=
236               -RowMax<T>(input_list_t[t], b, &max_class_indices);
237           if (max_class_indices != blank_index &&
238               !(merge_repeated_ && max_class_indices == prev_indices)) {
239             sequence.push_back(max_class_indices);
240           }
241           prev_indices = max_class_indices;
242         }
243       }
244     };
245 
246     const int64 kCostPerUnit = 50 * max_time * num_classes;
247     const int64 total = batch_size;
248     const DeviceBase::CpuWorkerThreads& worker_threads =
249         *ctx->device()->tensorflow_cpu_worker_threads();
250     Shard(worker_threads.num_threads, worker_threads.workers, total,
251           kCostPerUnit, decode);
252 
253     OP_REQUIRES_OK(
254         ctx, decode_helper_.StoreAllDecodedSequences(
255                  sequences, &decoded_indices, &decoded_values, &decoded_shape));
256   }
257 
258  private:
259   CTCDecodeHelper decode_helper_;
260   bool merge_repeated_;
261 
262   TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp);
263 };
264 
265 #define REGISTER_CPU(T)                                                   \
266   REGISTER_KERNEL_BUILDER(                                                \
267       Name("CTCGreedyDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
268       CTCGreedyDecoderOp<T>);
269 
270 REGISTER_CPU(float);
271 REGISTER_CPU(double);
272 
273 #undef REGISTER_CPU
274 
275 // CTC beam search
276 template <typename T>
277 class CTCBeamSearchDecoderOp : public OpKernel {
278  public:
CTCBeamSearchDecoderOp(OpKernelConstruction * ctx)279   explicit CTCBeamSearchDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
280     OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
281     OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_));
282     int top_paths;
283     OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));
284     decode_helper_.SetTopPaths(top_paths);
285   }
286 
Compute(OpKernelContext * ctx)287   void Compute(OpKernelContext* ctx) override {
288     const Tensor* inputs;
289     const Tensor* seq_len;
290     Tensor* log_prob = nullptr;
291     OpOutputList decoded_indices;
292     OpOutputList decoded_values;
293     OpOutputList decoded_shape;
294     OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
295                             ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
296                             &decoded_values, &decoded_shape));
297 
298     auto inputs_t = inputs->tensor<T, 3>();
299     auto seq_len_t = seq_len->vec<int32>();
300     auto log_prob_t = log_prob->matrix<T>();
301 
302     const TensorShape& inputs_shape = inputs->shape();
303 
304     const int64 max_time = inputs_shape.dim_size(0);
305     const int64 batch_size = inputs_shape.dim_size(1);
306     const int64 num_classes_raw = inputs_shape.dim_size(2);
307     OP_REQUIRES(
308         ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
309         errors::InvalidArgument("num_classes cannot exceed max int"));
310     const int num_classes = static_cast<const int>(num_classes_raw);
311 
312     log_prob_t.setZero();
313 
314     std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
315 
316     input_list_t.reserve(max_time);
317     for (std::size_t t = 0; t < max_time; ++t) {
318       input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
319                                 batch_size, num_classes);
320     }
321 
322     ctc::CTCBeamSearchDecoder<T> beam_search(num_classes, beam_width_,
323                                              &beam_scorer_, 1 /* batch_size */,
324                                              merge_repeated_);
325     Tensor input_chip(DataTypeToEnum<T>::v(), TensorShape({num_classes}));
326     auto input_chip_t = input_chip.flat<T>();
327 
328     std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
329     std::vector<T> log_probs;
330 
331     // Assumption: the blank index is num_classes - 1
332     for (int b = 0; b < batch_size; ++b) {
333       auto& best_paths_b = best_paths[b];
334       best_paths_b.resize(decode_helper_.GetTopPaths());
335       for (int t = 0; t < seq_len_t(b); ++t) {
336         input_chip_t = input_list_t[t].chip(b, 0);
337         auto input_bi = Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>(
338             input_chip_t.data(), num_classes);
339         beam_search.Step(input_bi);
340       }
341       OP_REQUIRES_OK(
342           ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b,
343                                     &log_probs, merge_repeated_));
344 
345       beam_search.Reset();
346 
347       for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {
348         log_prob_t(b, bp) = log_probs[bp];
349       }
350     }
351 
352     OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(
353                             best_paths, &decoded_indices, &decoded_values,
354                             &decoded_shape));
355   }
356 
357  private:
358   CTCDecodeHelper decode_helper_;
359   typename ctc::CTCBeamSearchDecoder<T>::DefaultBeamScorer beam_scorer_;
360   bool merge_repeated_;
361   int beam_width_;
362   TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp<T>);
363 };
364 
365 #define REGISTER_CPU(T)                                                       \
366   REGISTER_KERNEL_BUILDER(                                                    \
367       Name("CTCBeamSearchDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
368       CTCBeamSearchDecoderOp<T>);
369 
370 REGISTER_CPU(float);
371 REGISTER_CPU(double);
372 
373 #undef REGISTER_CPU
374 
375 }  // end namespace tensorflow
376