/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/ctc_ops.cc.

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <utility>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif  // GOOGLE_CUDA

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/core/util/tensor_format.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
using GPUDevice = Eigen::GpuDevice;

namespace {
using se::Stream;
using se::StreamExecutor;
using se::dnn::RnnStateTensorDescriptor;
using se::dnn::ToDataType;

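// Builds a histogram of label lengths per batch element. Each row of
// labels_indices is a (batch, time) pair, so counting how often each batch
// index appears in column 0 gives that element's label length. This assumes
// every batch index lies in [0, batch_size); the indices are not re-validated
// here.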
template <typename T>
void DoHistogram(OpKernelContext* ctx, const Tensor* labels_indices,
                 int num_indices, int batch_size,
                 std::vector<int>* labels_lengths) {
  const T* h_in = labels_indices->flat<T>().data();
  for (int i = 0; i < num_indices; i++) {
    const T& key = h_in[i * 2];
    (*labels_lengths)[key]++;
  }
}

}  // end namespace
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

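// CPU kernel for the CTCLoss op. Computes the CTC loss and its gradient with
// respect to the input logits using ctc::CTCLossCalculator.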
template <typename T>
class CTCLossOp : public OpKernel {
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> >
      InputMap;
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> >
      OutputMap;

 public:
  explicit CTCLossOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("preprocess_collapse_repeated",
                                     &preprocess_collapse_repeated_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("ctc_merge_repeated", &ctc_merge_repeated_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs",
                                     &ignore_longer_outputs_than_inputs_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* labels_indices;
    const Tensor* labels_values;
    const Tensor* seq_len;
    OP_REQUIRES_OK(ctx, ctx->input("inputs", &inputs));
    OP_REQUIRES_OK(ctx, ctx->input("labels_indices", &labels_indices));
    OP_REQUIRES_OK(ctx, ctx->input("labels_values", &labels_values));
    OP_REQUIRES_OK(ctx, ctx->input("sequence_length", &seq_len));

    OP_REQUIRES(ctx, inputs->shape().dims() == 3,
                errors::InvalidArgument("inputs is not a 3-Tensor"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(seq_len->shape()),
                errors::InvalidArgument("sequence_length is not a vector"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_indices->shape()),
                errors::InvalidArgument("labels_indices is not a matrix"));
    OP_REQUIRES(ctx, labels_indices->dim_size(1) > 1,
                errors::InvalidArgument(
                    "labels_indices second dimension must be >= 2. Received ",
                    labels_indices->dim_size(1)));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_values->shape()),
                errors::InvalidArgument("labels_values is not a vector"));

    const TensorShape& inputs_shape = inputs->shape();
    const int64_t max_time = inputs_shape.dim_size(0);
    OP_REQUIRES(ctx, max_time != 0,
                errors::InvalidArgument(
                    "Max time or first dimension of input cannot be 0."));
    const int64_t batch_size = inputs_shape.dim_size(1);
    const int64_t num_classes_raw = inputs_shape.dim_size(2);
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("num_classes cannot exceed max int"));
    const int num_classes = static_cast<const int>(num_classes_raw);

    OP_REQUIRES(
        ctx, batch_size == seq_len->dim_size(0),
        errors::InvalidArgument("len(sequence_length) != batch_size.  ",
                                "len(sequence_length):  ", seq_len->dim_size(0),
                                " batch_size: ", batch_size));
    auto seq_len_t = seq_len->vec<int32>();

    OP_REQUIRES(ctx, labels_indices->dim_size(0) == labels_values->dim_size(0),
                errors::InvalidArgument(
                    "labels_indices and labels_values must contain the "
                    "same number of rows, but saw shapes: ",
                    labels_indices->shape().DebugString(), " vs. ",
                    labels_values->shape().DebugString()));

    OP_REQUIRES(ctx, batch_size != 0,
                errors::InvalidArgument("batch_size must not be 0"));

    // Figure out the maximum label length to use as sparse tensor dimension.
    auto labels_indices_t = labels_indices->matrix<int64_t>();
    int64_t max_label_len = 0;
    for (int i = 0; i < labels_indices->dim_size(0); i++) {
      max_label_len = std::max(max_label_len, labels_indices_t(i, 1) + 1);
    }

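    // Rebuild the labels as a [batch_size, max_label_len] SparseTensor so
    // they can be validated and grouped by batch element below.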
    TensorShape labels_shape({batch_size, max_label_len});
    std::vector<int64_t> order{0, 1};
    sparse::SparseTensor labels_sp;
    OP_REQUIRES_OK(
        ctx, sparse::SparseTensor::Create(*labels_indices, *labels_values,
                                          labels_shape, order, &labels_sp));

    Status labels_sp_valid = labels_sp.IndicesValid();
    OP_REQUIRES(ctx, labels_sp_valid.ok(),
                errors::InvalidArgument("label SparseTensor is not valid: ",
                                        labels_sp_valid.error_message()));

    typename ctc::CTCLossCalculator<T>::LabelSequences labels_t(batch_size);
    for (const auto& g : labels_sp.group({0})) {  // iterate by batch
      const int64_t batch_indices = g.group()[0];
      OP_REQUIRES(ctx, FastBoundsCheck(batch_indices, batch_size),
                  errors::InvalidArgument("labels batch index must be between ",
                                          0, " and ", batch_size,
                                          " but saw: ", batch_indices));

      auto values = g.values<int32>();
      std::vector<int>* b_values = &labels_t[batch_indices];
      b_values->resize(values.size());
      for (int i = 0; i < values.size(); ++i) (*b_values)[i] = values(i);
    }

    OP_REQUIRES(ctx, static_cast<size_t>(batch_size) == labels_t.size(),
                errors::InvalidArgument("len(labels) != batch_size.  ",
                                        "len(labels):  ", labels_t.size(),
                                        " batch_size: ", batch_size));

    for (int64_t b = 0; b < batch_size; ++b) {
      OP_REQUIRES(
          ctx, seq_len_t(b) <= max_time,
          errors::InvalidArgument("sequence_length(", b, ") must be <= ",
                                  max_time, " but is ", seq_len_t(b)));
    }

    Tensor* loss = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));
    auto loss_t = loss->vec<T>();

    Tensor* gradient;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("gradient", inputs_shape, &gradient));
    auto gradient_t = gradient->tensor<T, 3>();
    auto inputs_t = inputs->tensor<T, 3>();
    std::vector<OutputMap> gradient_list_t;
    std::vector<InputMap> input_list_t;

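    // Inputs and gradient are laid out as [max_time, batch_size, num_classes],
    // so each timestep t is a contiguous batch_size x num_classes slice that
    // can be wrapped in an Eigen map without copying.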
    for (int64_t t = 0; t < max_time; ++t) {
      input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                batch_size, num_classes);
      gradient_list_t.emplace_back(
          gradient_t.data() + t * batch_size * num_classes, batch_size,
          num_classes);
    }

    gradient_t.setZero();

    // Assumption: the blank index is num_classes - 1.
    ctc::CTCLossCalculator<T> ctc_loss_calculator(num_classes - 1, 0);
    DeviceBase::CpuWorkerThreads workers =
        *ctx->device()->tensorflow_cpu_worker_threads();
    OP_REQUIRES_OK(ctx, ctc_loss_calculator.CalculateLoss(
                            seq_len_t, labels_t, input_list_t,
                            preprocess_collapse_repeated_, ctc_merge_repeated_,
                            ignore_longer_outputs_than_inputs_, &loss_t,
                            &gradient_list_t, &workers));
  }

 private:
  bool preprocess_collapse_repeated_;
  bool ctc_merge_repeated_;
  bool ignore_longer_outputs_than_inputs_;

  TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp<T>);
};

#define REGISTER_CPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("CTCLoss").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CTCLossOp<T>);

REGISTER_CPU(float);
REGISTER_CPU(double);

#undef REGISTER_CPU
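
// Illustrative Python-level usage (a sketch, not part of this file): the
// registration above backs the CTCLoss op, reachable as tf.raw_ops.CTCLoss,
// with argument names matching the input names this kernel reads, e.g.
//   loss, gradient = tf.raw_ops.CTCLoss(
//       inputs=logits,            # float [max_time, batch_size, num_classes]
//       labels_indices=indices,   # int64 [num_entries, 2]
//       labels_values=values,     # int32 [num_entries]
//       sequence_length=seq_len)  # int32 [batch_size]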

#if ((GOOGLE_CUDA && CUDNN_VERSION >= 7603) || TENSORFLOW_USE_ROCM)
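// GPU kernel for CTCLossV2. Delegates the loss and gradient computation to
// the platform DNN library's CTC loss via StreamExecutor (cuDNN on CUDA
// builds, which per the guard above must be at least version 7.6.3).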
class CTCLossOpGPU : public OpKernel {
 public:
  explicit CTCLossOpGPU(OpKernelConstruction* ctx) : OpKernel(ctx) {
    bool preprocess_collapse_repeated;
    bool ctc_merge_repeated;
    bool ignore_longer_outputs_than_inputs;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("preprocess_collapse_repeated",
                                     &preprocess_collapse_repeated));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("ctc_merge_repeated", &ctc_merge_repeated));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs",
                                     &ignore_longer_outputs_than_inputs));

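    // The checks below reflect what the underlying cuDNN CTC implementation
    // supports: repeated labels are never collapsed during preprocessing,
    // repeated classes are always merged in the alignment (standard CTC),
    // and there is no mode for skipping examples whose labels are longer
    // than the inputs, so the corresponding attrs must match.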
    OP_REQUIRES(ctx, !preprocess_collapse_repeated,
                errors::InvalidArgument("GPU CTCLossOp requires "
                                        "preprocess_collapse_repeated to be "
                                        "false"));
    OP_REQUIRES(ctx, ctc_merge_repeated,
                errors::InvalidArgument("GPU CTCLossOp requires "
                                        "ctc_merge_repeated to be "
                                        "true"));
    OP_REQUIRES(ctx, !ignore_longer_outputs_than_inputs,
                errors::InvalidArgument("GPU CTCLossOp requires "
                                        "ignore_longer_outputs_than_inputs "
                                        "to be false"));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* labels_indices;
    const Tensor* labels_values;
    const Tensor* seq_len;
    OP_REQUIRES_OK(ctx, ctx->input("inputs", &inputs));
    OP_REQUIRES_OK(ctx, ctx->input("labels_indices", &labels_indices));
    OP_REQUIRES_OK(ctx, ctx->input("labels_values", &labels_values));
    OP_REQUIRES_OK(ctx, ctx->input("sequence_length", &seq_len));

    OP_REQUIRES(ctx, inputs->shape().dims() == 3,
                errors::InvalidArgument("inputs is not a 3-Tensor"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(seq_len->shape()),
                errors::InvalidArgument("sequence_length is not a vector"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_indices->shape()),
                errors::InvalidArgument("labels_indices is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_values->shape()),
                errors::InvalidArgument("labels_values is not a vector"));

    const TensorShape& inputs_shape = inputs->shape();
    const int64_t max_time_raw = inputs_shape.dim_size(0);
    const int64_t batch_size_raw = inputs_shape.dim_size(1);
    const int64_t num_classes_raw = inputs_shape.dim_size(2);
    OP_REQUIRES(ctx,
                FastBoundsCheck(max_time_raw, std::numeric_limits<int>::max()),
                errors::InvalidArgument("max_time cannot exceed max int"));
    OP_REQUIRES(
        ctx, FastBoundsCheck(batch_size_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("batch_size cannot exceed max int"));
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("num_classes cannot exceed max int"));
    const int max_time = static_cast<const int>(max_time_raw);
    const int batch_size = static_cast<const int>(batch_size_raw);
    const int num_classes = static_cast<const int>(num_classes_raw);

    OP_REQUIRES(
        ctx, batch_size == seq_len->dim_size(0),
        errors::InvalidArgument("len(sequence_length) != batch_size.  ",
                                "len(sequence_length):  ", seq_len->dim_size(0),
                                " batch_size: ", batch_size));

    OP_REQUIRES(ctx, labels_indices->dim_size(0) == labels_values->dim_size(0),
                errors::InvalidArgument(
                    "labels_indices and labels_values must contain the "
                    "same number of rows, but saw shapes: ",
                    labels_indices->shape().DebugString(), " vs. ",
                    labels_values->shape().DebugString()));
    auto num_indices = labels_indices->dim_size(0);

    OP_REQUIRES(ctx, batch_size != 0,
                errors::InvalidArgument("batch_size must not be 0"));

    Tensor* loss = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));

    Tensor* gradient = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("gradient", inputs_shape, &gradient));

    // Convert the labels_indices to labels_lengths.
    std::vector<int> labels_lengths(batch_size, 0);
    DoHistogram<int64_t>(ctx, labels_indices, num_indices, batch_size,
                         &labels_lengths);

    StreamExecutor* executor = ctx->op_device_context()->stream()->parent();
    se::dnn::DataType data_type = ToDataType<float>::value;

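    // Describe the probability and gradient tensors to the DNN backend as
    // [max_time, batch_size, num_classes] RNN state tensors.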
    auto probs_desc_s = executor->createRnnStateTensorDescriptor(
        max_time, batch_size, num_classes, data_type);
    OP_REQUIRES_OK(ctx, probs_desc_s.status());
    std::unique_ptr<RnnStateTensorDescriptor> probs_desc =
        std::move(probs_desc_s).value();

    auto grads_desc_s = executor->createRnnStateTensorDescriptor(
        max_time, batch_size, num_classes, data_type);
    OP_REQUIRES_OK(ctx, grads_desc_s.status());
    std::unique_ptr<RnnStateTensorDescriptor> grads_desc =
        std::move(grads_desc_s).value();

    absl::Span<const int32> labels_data(labels_values->flat<int32>().data(),
                                        num_indices);
    absl::Span<const int32> labels_lengths_data(labels_lengths.data(),
                                                batch_size);
    absl::Span<const int32> input_lengths_data(seq_len->flat<int32>().data(),
                                               batch_size);

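    // Reinterpret the TensorFlow tensors as StreamExecutor device memory for
    // the cuDNN call.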
    auto probs_data = StreamExecutorUtil::AsDeviceMemory<float>(*inputs);
    auto costs_data = StreamExecutorUtil::AsDeviceMemory<float>(*loss);
    auto grads_data = StreamExecutorUtil::AsDeviceMemory<float>(*gradient);

    // Limit the workspace memory to 4 GB.
    DnnScratchAllocator workspace_allocator(1LL << 32, ctx);

    Stream* stream = ctx->op_device_context()->stream();
    bool cudnn_launch_status =
        stream
            ->ThenCtcLoss(*probs_desc, probs_data, labels_data,
                          labels_lengths_data, input_lengths_data, &costs_data,
                          *grads_desc, &grads_data, &workspace_allocator)
            .ok();

    if (!cudnn_launch_status) {
      ctx->SetStatus(errors::Internal("cuDNN CTCLoss launch failure"));
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOpGPU);
};

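// The labels and sequence lengths are consumed on the host (DoHistogram and
// the spans passed to the cuDNN call read them as host arrays), so those
// inputs are pinned to HostMemory.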
REGISTER_KERNEL_BUILDER(Name("CTCLossV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("labels_indices")
                            .HostMemory("labels_values")
                            .HostMemory("sequence_length"),
                        CTCLossOpGPU);
#endif  // ((GOOGLE_CUDA && CUDNN_VERSION >= 7603) || TENSORFLOW_USE_ROCM)
}  // end namespace tensorflow