• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/ctc_ops.cc.
17 
#define EIGEN_USE_THREADS

#include <algorithm>
#include <limits>
#include <vector>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/ctc/ctc_beam_search.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
#include "tensorflow/core/util/work_sharder.h"
31 
32 namespace tensorflow {
33 
34 typedef Eigen::ThreadPoolDevice CPUDevice;
35 
RowMax(const TTypes<float>::UnalignedConstMatrix & m,int r,int * c)36 inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m, int r,
37                     int* c) {
38   *c = 0;
39   CHECK_LT(0, m.dimension(1));
40   float p = m(r, 0);
41   for (int i = 1; i < m.dimension(1); ++i) {
42     if (m(r, i) > p) {
43       p = m(r, i);
44       *c = i;
45     }
46   }
47   return p;
48 }
49 
50 class CTCDecodeHelper {
51  public:
CTCDecodeHelper()52   CTCDecodeHelper() : top_paths_(1) {}
53 
GetTopPaths() const54   inline int GetTopPaths() const { return top_paths_; }
SetTopPaths(int tp)55   void SetTopPaths(int tp) { top_paths_ = tp; }
56 
ValidateInputsGenerateOutputs(OpKernelContext * ctx,const Tensor ** inputs,const Tensor ** seq_len,Tensor ** log_prob,OpOutputList * decoded_indices,OpOutputList * decoded_values,OpOutputList * decoded_shape) const57   Status ValidateInputsGenerateOutputs(
58       OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,
59       Tensor** log_prob, OpOutputList* decoded_indices,
60       OpOutputList* decoded_values, OpOutputList* decoded_shape) const {
61     Status status = ctx->input("inputs", inputs);
62     if (!status.ok()) return status;
63     status = ctx->input("sequence_length", seq_len);
64     if (!status.ok()) return status;
65 
66     const TensorShape& inputs_shape = (*inputs)->shape();
67 
68     if (inputs_shape.dims() != 3) {
69       return errors::InvalidArgument("inputs is not a 3-Tensor");
70     }
71 
72     const int64 max_time = inputs_shape.dim_size(0);
73     const int64 batch_size = inputs_shape.dim_size(1);
74 
75     if (max_time == 0) {
76       return errors::InvalidArgument("max_time is 0");
77     }
78     if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {
79       return errors::InvalidArgument("sequence_length is not a vector");
80     }
81 
82     if (!(batch_size == (*seq_len)->dim_size(0))) {
83       return errors::FailedPrecondition(
84           "len(sequence_length) != batch_size.  ",
85           "len(sequence_length):  ", (*seq_len)->dim_size(0),
86           " batch_size: ", batch_size);
87     }
88 
89     auto seq_len_t = (*seq_len)->vec<int32>();
90 
91     for (int b = 0; b < batch_size; ++b) {
92       if (!(seq_len_t(b) <= max_time)) {
93         return errors::FailedPrecondition("sequence_length(", b,
94                                           ") <= ", max_time);
95       }
96     }
97 
98     Status s = ctx->allocate_output(
99         "log_probability", TensorShape({batch_size, top_paths_}), log_prob);
100     if (!s.ok()) return s;
101 
102     s = ctx->output_list("decoded_indices", decoded_indices);
103     if (!s.ok()) return s;
104     s = ctx->output_list("decoded_values", decoded_values);
105     if (!s.ok()) return s;
106     s = ctx->output_list("decoded_shape", decoded_shape);
107     if (!s.ok()) return s;
108 
109     return Status::OK();
110   }
111 
112   // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
StoreAllDecodedSequences(const std::vector<std::vector<std::vector<int>>> & sequences,OpOutputList * decoded_indices,OpOutputList * decoded_values,OpOutputList * decoded_shape) const113   Status StoreAllDecodedSequences(
114       const std::vector<std::vector<std::vector<int> > >& sequences,
115       OpOutputList* decoded_indices, OpOutputList* decoded_values,
116       OpOutputList* decoded_shape) const {
117     // Calculate the total number of entries for each path
118     const int64 batch_size = sequences.size();
119     std::vector<int64> num_entries(top_paths_, 0);
120 
121     // Calculate num_entries per path
122     for (const auto& batch_s : sequences) {
123       CHECK_EQ(batch_s.size(), top_paths_);
124       for (int p = 0; p < top_paths_; ++p) {
125         num_entries[p] += batch_s[p].size();
126       }
127     }
128 
129     for (int p = 0; p < top_paths_; ++p) {
130       Tensor* p_indices = nullptr;
131       Tensor* p_values = nullptr;
132       Tensor* p_shape = nullptr;
133 
134       const int64 p_num = num_entries[p];
135 
136       Status s =
137           decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices);
138       if (!s.ok()) return s;
139       s = decoded_values->allocate(p, TensorShape({p_num}), &p_values);
140       if (!s.ok()) return s;
141       s = decoded_shape->allocate(p, TensorShape({2}), &p_shape);
142       if (!s.ok()) return s;
143 
144       auto indices_t = p_indices->matrix<int64>();
145       auto values_t = p_values->vec<int64>();
146       auto shape_t = p_shape->vec<int64>();
147 
148       int64 max_decoded = 0;
149       int64 offset = 0;
150 
151       for (int64 b = 0; b < batch_size; ++b) {
152         auto& p_batch = sequences[b][p];
153         int64 num_decoded = p_batch.size();
154         max_decoded = std::max(max_decoded, num_decoded);
155         std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));
156         for (int64 t = 0; t < num_decoded; ++t, ++offset) {
157           indices_t(offset, 0) = b;
158           indices_t(offset, 1) = t;
159         }
160       }
161 
162       shape_t(0) = batch_size;
163       shape_t(1) = max_decoded;
164     }
165     return Status::OK();
166   }
167 
168  private:
169   int top_paths_;
170   TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
171 };
172 
173 class CTCGreedyDecoderOp : public OpKernel {
174  public:
CTCGreedyDecoderOp(OpKernelConstruction * ctx)175   explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
176     OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
177   }
178 
Compute(OpKernelContext * ctx)179   void Compute(OpKernelContext* ctx) override {
180     const Tensor* inputs;
181     const Tensor* seq_len;
182     Tensor* log_prob = nullptr;
183     OpOutputList decoded_indices;
184     OpOutputList decoded_values;
185     OpOutputList decoded_shape;
186     OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
187                             ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
188                             &decoded_values, &decoded_shape));
189 
190     const TensorShape& inputs_shape = inputs->shape();
191 
192     std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
193     const int64 max_time = inputs_shape.dim_size(0);
194     const int64 batch_size = inputs_shape.dim_size(1);
195     const int64 num_classes_raw = inputs_shape.dim_size(2);
196     OP_REQUIRES(
197         ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
198         errors::InvalidArgument("num_classes cannot exceed max int"));
199     const int num_classes = static_cast<const int>(num_classes_raw);
200 
201     auto inputs_t = inputs->tensor<float, 3>();
202 
203     for (std::size_t t = 0; t < max_time; ++t) {
204       input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
205                                 batch_size, num_classes);
206     }
207     auto seq_len_t = seq_len->vec<int32>();
208     auto log_prob_t = log_prob->matrix<float>();
209 
210     log_prob_t.setZero();
211 
212     // Assumption: the blank index is num_classes - 1
213     int blank_index = num_classes - 1;
214 
215     // Perform best path decoding
216     std::vector<std::vector<std::vector<int> > > sequences(batch_size);
217     auto decode = [&](const int64 begin, const int64 end) {
218       for (int b = begin; b < end; ++b) {
219         sequences[b].resize(1);
220         auto &sequence = sequences[b][0];
221         int prev_indices = -1;
222         for (int t = 0; t < seq_len_t(b); ++t) {
223           int max_class_indices;
224           log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices);
225           if (max_class_indices != blank_index &&
226               !(merge_repeated_ && max_class_indices == prev_indices)) {
227             sequence.push_back(max_class_indices);
228           }
229           prev_indices = max_class_indices;
230         }
231       }
232     };
233 
234     const int64 kCostPerUnit = 50 * max_time * num_classes;
235     const int64 total = batch_size;
236     const DeviceBase::CpuWorkerThreads& worker_threads =
237         *ctx->device()->tensorflow_cpu_worker_threads();
238     Shard(worker_threads.num_threads, worker_threads.workers, total,
239           kCostPerUnit, decode);
240 
241     OP_REQUIRES_OK(
242         ctx, decode_helper_.StoreAllDecodedSequences(
243                  sequences, &decoded_indices, &decoded_values, &decoded_shape));
244   }
245 
246  private:
247   CTCDecodeHelper decode_helper_;
248   bool merge_repeated_;
249 
250   TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp);
251 };
252 
253 REGISTER_KERNEL_BUILDER(Name("CTCGreedyDecoder").Device(DEVICE_CPU),
254                         CTCGreedyDecoderOp);
255 
256 // CTC beam search
257 class CTCBeamSearchDecoderOp : public OpKernel {
258  public:
CTCBeamSearchDecoderOp(OpKernelConstruction * ctx)259   explicit CTCBeamSearchDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
260     OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
261     OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_));
262     int top_paths;
263     OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));
264     decode_helper_.SetTopPaths(top_paths);
265   }
266 
Compute(OpKernelContext * ctx)267   void Compute(OpKernelContext* ctx) override {
268     const Tensor* inputs;
269     const Tensor* seq_len;
270     Tensor* log_prob = nullptr;
271     OpOutputList decoded_indices;
272     OpOutputList decoded_values;
273     OpOutputList decoded_shape;
274     OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
275                             ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
276                             &decoded_values, &decoded_shape));
277 
278     auto inputs_t = inputs->tensor<float, 3>();
279     auto seq_len_t = seq_len->vec<int32>();
280     auto log_prob_t = log_prob->matrix<float>();
281 
282     const TensorShape& inputs_shape = inputs->shape();
283 
284     const int64 max_time = inputs_shape.dim_size(0);
285     const int64 batch_size = inputs_shape.dim_size(1);
286     const int64 num_classes_raw = inputs_shape.dim_size(2);
287     OP_REQUIRES(
288         ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
289         errors::InvalidArgument("num_classes cannot exceed max int"));
290     const int num_classes = static_cast<const int>(num_classes_raw);
291 
292     log_prob_t.setZero();
293 
294     std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
295 
296     for (std::size_t t = 0; t < max_time; ++t) {
297       input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
298                                 batch_size, num_classes);
299     }
300 
301     ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width_,
302                                             &beam_scorer_, 1 /* batch_size */,
303                                             merge_repeated_);
304     Tensor input_chip(DT_FLOAT, TensorShape({num_classes}));
305     auto input_chip_t = input_chip.flat<float>();
306 
307     std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
308     std::vector<float> log_probs;
309 
310     // Assumption: the blank index is num_classes - 1
311     for (int b = 0; b < batch_size; ++b) {
312       auto& best_paths_b = best_paths[b];
313       best_paths_b.resize(decode_helper_.GetTopPaths());
314       for (int t = 0; t < seq_len_t(b); ++t) {
315         input_chip_t = input_list_t[t].chip(b, 0);
316         auto input_bi =
317             Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
318         beam_search.Step(input_bi);
319       }
320       OP_REQUIRES_OK(
321           ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b,
322                                     &log_probs, merge_repeated_));
323 
324       beam_search.Reset();
325 
326       for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {
327         log_prob_t(b, bp) = log_probs[bp];
328       }
329     }
330 
331     OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(
332                             best_paths, &decoded_indices, &decoded_values,
333                             &decoded_shape));
334   }
335 
336  private:
337   CTCDecodeHelper decode_helper_;
338   ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer_;
339   bool merge_repeated_;
340   int beam_width_;
341   TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp);
342 };
343 
344 REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoder").Device(DEVICE_CPU),
345                         CTCBeamSearchDecoderOp);
346 
347 }  // end namespace tensorflow
348