/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_

#include <algorithm>
#include <vector>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/ctc/ctc_loss_util.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace ctc {

class CTCLossCalculator {
  // Connectionist Temporal Classification Loss
  //
  // Implementation by kanishkarao@, posenhuang@, and ebrevdo@.
  //
  // The CTC Loss layer learns a *transition* probability value for each
  // input time step.  The transitions are on the class alphabet
  //   {0, 1, ..., N-2}
  // where N is the depth of the input layer (the size of the alphabet is N-1).
  // Note: The token N-1 is reserved for the "no transition" (blank) output, so
  // make sure that your input layer has a depth that's one larger than
  // the set of classes you're training on.  Also make sure that your
  // training labels do not have a class value of N-1, as training will skip
  // these examples.
  //
  // See the illustrative usage sketch below the class definition.
  //
  // Reference materials:
  //  GravesTh: Alex Graves, "Supervised Sequence Labeling with Recurrent
  //    Neural Networks" (PhD Thesis), Technische Universität München.
 public:
  typedef std::vector<std::vector<int>> LabelSequences;
  typedef Eigen::MatrixXf Matrix;
  typedef Eigen::ArrayXf Array;
  typedef Eigen::Map<const Eigen::MatrixXf> InputMap;
  typedef Eigen::Map<Eigen::MatrixXf> OutputMap;

  CTCLossCalculator(int blank_index, int output_delay)
      : blank_index_(blank_index), output_delay_(output_delay) {}

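  // Computes the CTC loss (and, if `gradients` is non-null, the gradient with
  // respect to the input activations) for each batch element.
  //   seq_len:   per-batch-element sequence lengths (size batch_size).
  //   labels:    target label sequences, one vector<int> per batch element.
  //   inputs:    one (batch_size x num_classes) activation matrix per time step.
  //   loss:      output vector of size batch_size, filled with -log(p(z|x)).
  //   gradients: optional; same shapes as `inputs`, receives d(loss)/d(input).
  //   workers:   optional thread pool used to shard work over batch elements.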
  template <typename VectorIn, typename VectorOut, typename MatrixIn,
            typename MatrixOut>
  Status CalculateLoss(const VectorIn& seq_len, const LabelSequences& labels,
                       const std::vector<MatrixIn>& inputs,
                       bool preprocess_collapse_repeated,
                       bool ctc_merge_repeated,
                       bool ignore_longer_outputs_than_inputs, VectorOut* loss,
                       std::vector<MatrixOut>* gradients,
                       DeviceBase::CpuWorkerThreads* workers = nullptr) const;

 private:
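  // Computes the forward (log-alpha) CTC variables over the modified label
  // sequence l_prime, given the per-frame class probabilities y, writing the
  // result to log_alpha.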
  void CalculateForwardVariables(const std::vector<int>& l_prime,
                                 const Matrix& y, bool ctc_merge_repeated,
                                 Matrix* log_alpha) const;

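  // Computes the backward (log-beta) CTC variables over the modified label
  // sequence l_prime, given the per-frame class probabilities y, writing the
  // result to log_beta.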
  void CalculateBackwardVariables(const std::vector<int>& l_prime,
                                  const Matrix& y, bool ctc_merge_repeated,
                                  Matrix* log_beta) const;

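  // Computes the gradient of the loss with respect to the input activations
  // for one batch element from log_alpha, log_beta, and the total
  // log-likelihood log_p_z_x, writing the result to dy.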
  void CalculateGradient(const std::vector<int>& l_prime, const Matrix& y,
                         const Matrix& log_alpha, const Matrix& log_beta,
                         float log_p_z_x, Matrix* dy) const;

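  // Builds the modified label sequence l_prime from l by inserting a blank
  // before each label and appending one at the end, so that
  // l_prime->size() == 2 * l.size() + 1.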
  void GetLPrimeIndices(const std::vector<int>& l,
                        std::vector<int>* l_prime) const;

  // Helper function that calculates the l_prime indices for all
  // batches at the same time, and identifies errors for any given
  // batch.  Return value:
  //    max_{b in batch_size} l_primes[b].size()
  template <typename Vector>
  Status PopulateLPrimes(bool preprocess_collapse_repeated,
                         bool ignore_longer_outputs_than_inputs, int batch_size,
                         int num_classes, const Vector& seq_len,
                         const LabelSequences& labels, size_t* max_u_prime,
                         LabelSequences* l_primes) const;

  // Utility indices for the CTC algorithm.
  int blank_index_;

  // Delay, in time steps, before the output sequence.
  const int output_delay_;
};

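// Illustrative usage sketch only (the variable names below are hypothetical;
// within TensorFlow this class is driven by the CTC loss op, which prepares
// the inputs):
//
//   int num_time_steps = ..., batch_size = ..., num_classes = ...;
//   std::vector<Eigen::MatrixXf> inputs(num_time_steps);   // batch_size x num_classes each
//   Eigen::VectorXi seq_len(batch_size);                   // per-element sequence lengths
//   CTCLossCalculator::LabelSequences labels(batch_size);  // target label sequences
//   Eigen::VectorXf loss(batch_size);
//   std::vector<Eigen::MatrixXf> gradients(
//       num_time_steps, Eigen::MatrixXf(batch_size, num_classes));
//
//   // Per the class comment, the blank token is the last class, N - 1.
//   CTCLossCalculator ctc(num_classes - 1 /*blank_index*/, 0 /*output_delay*/);
//   Status s = ctc.CalculateLoss(seq_len, labels, inputs,
//                                /*preprocess_collapse_repeated=*/false,
//                                /*ctc_merge_repeated=*/true,
//                                /*ignore_longer_outputs_than_inputs=*/false,
//                                &loss, &gradients);
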
template <typename VectorIn, typename VectorOut, typename MatrixIn,
          typename MatrixOut>
Status CTCLossCalculator::CalculateLoss(
    const VectorIn& seq_len, const LabelSequences& labels,
    const std::vector<MatrixIn>& inputs, bool preprocess_collapse_repeated,
    bool ctc_merge_repeated, bool ignore_longer_outputs_than_inputs,
    VectorOut* loss, std::vector<MatrixOut>* gradients,
    DeviceBase::CpuWorkerThreads* workers) const {
  auto num_time_steps = inputs.size();

  if (loss == nullptr) {
    return errors::InvalidArgument("loss == nullptr");
  }
  // Guard against an empty input sequence before indexing inputs[0] below.
  if (inputs.empty()) {
    return errors::InvalidArgument("inputs must not be empty");
  }

  bool requires_backprop = (gradients != nullptr);

  auto batch_size = inputs[0].rows();
  auto num_classes = inputs[0].cols();

  if (loss->size() != batch_size) {
    return errors::InvalidArgument("loss.size() != batch_size");
  }
  loss->setZero();

  for (int t = 1; t < num_time_steps; ++t) {
    if (inputs[t].rows() != batch_size) {
      return errors::InvalidArgument("Expected batch size at t: ", t,
                                     " to be: ", batch_size,
                                     " but got: ", inputs[t].rows());
    }
    if (inputs[t].cols() != num_classes) {
      return errors::InvalidArgument("Expected class count at t: ", t,
                                     " to be: ", num_classes,
                                     " but got: ", inputs[t].cols());
    }
  }

  // Check validity of sequence_length array values.
  auto max_seq_len = seq_len(0);
  for (int b = 0; b < batch_size; b++) {
    if (seq_len(b) < 0) {
      return errors::InvalidArgument("seq_len(", b, ") < 0");
    }
    if (seq_len(b) > num_time_steps) {
      return errors::InvalidArgument("seq_len(", b, ") > num_time_steps");
    }
    max_seq_len = std::max(seq_len(b), max_seq_len);
  }

  // Calculate the modified label sequence l' for each batch element,
  // and calculate the maximum necessary allocation size.
  LabelSequences l_primes(batch_size);
  size_t max_u_prime = 0;
  Status l_p_ret = PopulateLPrimes(
      preprocess_collapse_repeated, ignore_longer_outputs_than_inputs,
      batch_size, num_classes, seq_len, labels, &max_u_prime, &l_primes);
  if (!l_p_ret.ok()) {
    return l_p_ret;
  }

  // Process each item in a batch in parallel, using at most kMaxThreads.
  auto ComputeLossAndGradients = [this, num_classes, &labels, &l_primes,
                                  &seq_len, &inputs, requires_backprop,
                                  ctc_merge_repeated,
                                  ignore_longer_outputs_than_inputs, &loss,
                                  &gradients](int64 start_row,
                                              int64 limit_row) {
    for (int b = start_row; b < limit_row; b++) {
      // Return zero gradient for empty sequences or sequences with labels
      // longer than input, which is not supported by CTC.
      if (seq_len(b) == 0 ||
          (ignore_longer_outputs_than_inputs &&
           labels[b].size() > seq_len(b) - this->output_delay_)) {
        VLOG(1) << "The sequence length is either zero or shorter than the "
                   "target output (CTC requires the target sequence to be no "
                   "longer than the input sequence). Skipping batch element "
                << b << ": " << str_util::Join(labels[b], " ");
        continue;
      }

      // For each batch element, log(alpha) and log(beta).
      //   row size is: u_prime == l_prime.size()
      //   col size is: seq_len[b] - output_delay_
      const std::vector<int>& l_prime = l_primes[b];

      Matrix log_alpha_b(l_prime.size(), seq_len(b) - this->output_delay_);
      Matrix log_beta_b(l_prime.size(), seq_len(b) - this->output_delay_);

      // Work matrices, pre-allocated to the size required by this batch item.
      Matrix y(num_classes, seq_len(b));
      Matrix dy;
      if (requires_backprop) {
        dy = Matrix::Zero(y.rows(), y.cols());
      }

      // For this batch, we'll only work with this shortened sequence_length.
      Matrix y_b = y.leftCols(seq_len(b));

      // Compute the per-frame softmax of the input activations for this batch
      // element; y_b is num_classes x seq_len(b).
      Eigen::ArrayXf y_b_col;
      for (int t = 0; t < seq_len(b); t++) {
        // Subtract the max coefficient before exponentiating, for numerical
        // stability.
        float max_coeff = inputs[t].row(b).maxCoeff();
        y_b_col = (inputs[t].row(b).array() - max_coeff).exp();
        y_b.col(t) = y_b_col / y_b_col.sum();
      }

      // Compute forward, backward.
      // Forward variables.
      CalculateForwardVariables(l_prime, y_b, ctc_merge_repeated, &log_alpha_b);
      // Backward variables.
      CalculateBackwardVariables(l_prime, y_b, ctc_merge_repeated, &log_beta_b);

      // The loss is -log(p(z|x)), the negative log-probability of the target
      // sequence given the input.  Do lazy evaluation of log_prob here.
      float log_p_z_x = kLogZero;
      for (int u = 0; u < l_prime.size(); ++u) {
        // (GravesTh) Eq 7.26, sum over all paths for t = 0.
        log_p_z_x = LogSumExp(log_p_z_x, log_alpha_b(u, 0) + log_beta_b(u, 0));
      }

      (*loss)(b) = -log_p_z_x;  // Use negative log loss for display.

      // We compute the derivative if needed.
      if (requires_backprop) {
        // Gradients with respect to input activations.
        // Calculate gradient.
        dy.setZero();
        CalculateGradient(l_prime, y_b, log_alpha_b, log_beta_b, log_p_z_x,
                          &dy);

        // Copy the gradient for the current sample into the per-time-step
        // output matrices.
        for (int t = 0; t < seq_len(b); t++) {
          (*gradients)[t].row(b).array() = dy.col(t);
        }
      }
    }  // for (int b = ...
  };
  if (workers) {
    // *Rough* estimate of the cost for one item in the batch.
    // Forward, Backward: O(T * U (= 2L + 1)), Gradients: O(T * (U + L)).
    //
    // softmax: T * L * (Cost(Exp) + Cost(Div)) +
    // fwd,bwd: T * 2 * (2*L + 1) * (Cost(LogSumExp) + Cost(Log)) +
    // grad: T * ((2L + 1) * Cost(LogSumExp) + L * (Cost(Expf) + Cost(Add))).
    const int64 cost_exp = Eigen::internal::functor_traits<
        Eigen::internal::scalar_exp_op<float>>::Cost;
    const int64 cost_log = Eigen::internal::functor_traits<
        Eigen::internal::scalar_log_op<float>>::Cost;
    const int64 cost_log_sum_exp =
        Eigen::TensorOpCost::AddCost<float>() + cost_exp + cost_log;
    const int64 cost =
        max_seq_len * num_classes *
            (cost_exp + Eigen::TensorOpCost::DivCost<float>()) +
        max_seq_len * 2 * (2 * num_classes + 1) *
            (cost_log_sum_exp + cost_log) +
        max_seq_len *
            ((2 * num_classes + 1) * cost_log_sum_exp +
             num_classes * (cost_exp + Eigen::TensorOpCost::AddCost<float>()));
    Shard(workers->num_threads, workers->workers, batch_size, cost,
          ComputeLossAndGradients);
  } else {
    ComputeLossAndGradients(0, batch_size);
  }
  return Status::OK();
}

template <typename Vector>
Status CTCLossCalculator::PopulateLPrimes(
    bool preprocess_collapse_repeated, bool ignore_longer_outputs_than_inputs,
    int batch_size, int num_classes, const Vector& seq_len,
    const LabelSequences& labels, size_t* max_u_prime,
    LabelSequences* l_primes) const {
  // labels is a Label array of size batch_size
  if (labels.size() != batch_size) {
    return errors::InvalidArgument(
        "labels.size() != batch_size: ", labels.size(), " vs. ", batch_size);
  }

  *max_u_prime = 0;  // keep track of longest l' modified label sequence.
  for (int b = 0; b < batch_size; b++) {
    const std::vector<int>& label = labels[b];
    if (label.size() == 0) {
      return errors::InvalidArgument("Labels length is zero in batch ", b);
    }

    // If debugging: output the labels coming into training.
    //
    VLOG(2) << "label for batch: " << b << ": " << str_util::Join(label, " ");

    // Target indices, length = U.
    std::vector<int> l;

    // Collapse repeated labels if requested.  Any index >= num_classes - 1 is
    // treated as the blank / end-of-sequence marker; non-blank labels may not
    // follow it.
    bool finished_sequence = false;
    for (int i = 0; i < label.size(); ++i) {
      if (i == 0 || !preprocess_collapse_repeated || label[i] != label[i - 1]) {
        if (label[i] >= num_classes - 1) {
          finished_sequence = true;
        } else {
          if (finished_sequence) {
            // Saw an invalid sequence with a non-null label following a
            // null label.
            return errors::InvalidArgument(
                "Saw a non-null label (index < num_classes - 1) "
                "following a null label (index >= num_classes - 1), batch: ",
                b, " num_classes: ", num_classes,
                " labels: ", str_util::Join(l, ","));
          }
          l.push_back(label[i]);
        }
      }
    }

    for (int l_i : l) {
      if (l_i < 0) {
        return errors::InvalidArgument(
            "All labels must be nonnegative integers, batch: ", b,
            " labels: ", str_util::Join(l, ","));
      } else if (l_i >= num_classes) {
        return errors::InvalidArgument(
            "No label may be greater than or equal to num_classes. ",
            "num_classes: ", num_classes, ", batch: ", b,
            " labels: ", str_util::Join(l, ","));
      }
    }
    if (!ignore_longer_outputs_than_inputs) {
      // Make sure there is enough time to output the target indices.
      int time = seq_len(b) - output_delay_;
      int required_time = label.size();
      if (required_time > time) {
        return errors::InvalidArgument(
            "Not enough time for target transition sequence "
            "(required: ", required_time, ", available: ", time,
            ") for batch: ", b,
            ". You can turn this error into a warning by using the flag "
            "ignore_longer_outputs_than_inputs");
      }
    }
    // Target indices with blanks before each index and a blank at the end.
    // Length U' = 2U + 1.
    // Convert l to l_prime.
    GetLPrimeIndices(l, &l_primes->at(b));
    *max_u_prime = std::max(*max_u_prime, l_primes->at(b).size());
  }
  return Status::OK();
}

}  // namespace ctc
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_