1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_
17 #define TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_
18
19 #include <vector>
20
21 #include "third_party/eigen3/Eigen/Core"
22 #include "tensorflow/core/framework/device_base.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/lib/strings/str_util.h"
26 #include "tensorflow/core/lib/strings/strcat.h"
27 #include "tensorflow/core/util/ctc/ctc_loss_util.h"
28 #include "tensorflow/core/util/work_sharder.h"
29
30 namespace tensorflow {
31 namespace ctc {
32
33 class CTCLossCalculator {
34 // Connectionist Temporal Classification Loss
35 //
36 // Implementation by kanishkarao@, posenhuang@, and ebrevdo@.
37 //
38 // The CTC Loss layer learns a *transition* probability value for each
39 // input time step. The transitions are on the class alphabet
40 // {0, 1, ..., N-2}
41 // where N is the depth of the input layer (the size of the alphabet is N-1).
42 // Note: The token N-1 is reserved for the "no transition" output, so
43 // make sure that your input layer has a depth that's one larger than
44 // the set of classes you're training on. Also make sure that your
45 // training labels do not have a class value of N-1, as training will skip
46 // these examples.
47 //
48 // Reference materials:
49 // GravesTh: Alex Graves, "Supervised Sequence Labeling with Recurrent
50 // Neural Networks" (PhD Thesis), Technische Universit¨at M¨unchen.
51 public:
52 typedef std::vector<std::vector<int>> LabelSequences;
53 typedef Eigen::MatrixXf Matrix;
54 typedef Eigen::ArrayXf Array;
55 typedef Eigen::Map<const Eigen::MatrixXf> InputMap;
56 typedef Eigen::Map<Eigen::MatrixXf> OutputMap;
57
CTCLossCalculator(int blank_index,int output_delay)58 CTCLossCalculator(int blank_index, int output_delay)
59 : blank_index_(blank_index), output_delay_(output_delay) {}
60
61 template <typename VectorIn, typename VectorOut, typename MatrixIn,
62 typename MatrixOut>
63 Status CalculateLoss(const VectorIn& seq_len, const LabelSequences& labels,
64 const std::vector<MatrixIn>& inputs,
65 bool preprocess_collapse_repeated,
66 bool ctc_merge_repeated,
67 bool ignore_longer_outputs_than_inputs, VectorOut* loss,
68 std::vector<MatrixOut>* gradients,
69 DeviceBase::CpuWorkerThreads* workers = nullptr) const;
70
71 private:
72 void CalculateForwardVariables(const std::vector<int>& l_prime,
73 const Matrix& y, bool ctc_merge_repeated,
74 Matrix* log_alpha) const;
75
76 void CalculateBackwardVariables(const std::vector<int>& l_prime,
77 const Matrix& y, bool ctc_merge_repeated,
78 Matrix* log_beta) const;
79
80 void CalculateGradient(const std::vector<int>& l_prime, const Matrix& y,
81 const Matrix& log_alpha, const Matrix& log_beta,
82 float log_p_z_x, Matrix* dy) const;
83
84 void GetLPrimeIndices(const std::vector<int>& l,
85 std::vector<int>* l_prime) const;
86
87 // Helper function that calculates the l_prime indices for all
88 // batches at the same time, and identifies errors for any given
89 // batch. Return value:
90 // max_{b in batch_size} l_primes[b].size()
91 template <typename Vector>
92 Status PopulateLPrimes(bool preprocess_collapse_repeated,
93 bool ignore_longer_outputs_than_inputs, int batch_size,
94 int num_classes, const Vector& seq_len,
95 const LabelSequences& labels, size_t* max_u_prime,
96 LabelSequences* l_primes) const;
97
98 // Utility indices for the CTC algorithm.
99 int blank_index_;
100
101 // Delay for target labels in time steps.
102 // The delay in time steps before the output sequence.
103 const int output_delay_;
104 };
105
106 template <typename VectorIn, typename VectorOut, typename MatrixIn,
107 typename MatrixOut>
CalculateLoss(const VectorIn & seq_len,const LabelSequences & labels,const std::vector<MatrixIn> & inputs,bool preprocess_collapse_repeated,bool ctc_merge_repeated,bool ignore_longer_outputs_than_inputs,VectorOut * loss,std::vector<MatrixOut> * gradients,DeviceBase::CpuWorkerThreads * workers)108 Status CTCLossCalculator::CalculateLoss(
109 const VectorIn& seq_len, const LabelSequences& labels,
110 const std::vector<MatrixIn>& inputs, bool preprocess_collapse_repeated,
111 bool ctc_merge_repeated, bool ignore_longer_outputs_than_inputs,
112 VectorOut* loss, std::vector<MatrixOut>* gradients,
113 DeviceBase::CpuWorkerThreads* workers) const {
114 auto num_time_steps = inputs.size();
115
116 if (loss == nullptr) {
117 return errors::InvalidArgument("loss == nullptr");
118 }
119
120 bool requires_backprop = (gradients != nullptr);
121
122 auto batch_size = inputs[0].rows();
123 auto num_classes = inputs[0].cols();
124
125 if (loss->size() != batch_size) {
126 return errors::InvalidArgument("loss.size() != batch_size");
127 }
128 loss->setZero();
129
130 for (int t = 1; t < num_time_steps; ++t) {
131 if (inputs[t].rows() != batch_size) {
132 return errors::InvalidArgument("Expected batch size at t: ", t,
133 " to be: ", batch_size,
134 " but got: ", inputs[t].rows());
135 }
136 if (inputs[t].cols() != num_classes) {
137 return errors::InvalidArgument("Expected class count at t: ", t,
138 " to be: ", num_classes,
139 " but got: ", inputs[t].cols());
140 }
141 }
142
143 // Check validity of sequence_length array values.
144 auto max_seq_len = seq_len(0);
145 for (int b = 0; b < batch_size; b++) {
146 if (seq_len(b) < 0) {
147 return errors::InvalidArgument("seq_len(", b, ") < 0");
148 }
149 if (seq_len(b) > num_time_steps) {
150 return errors::InvalidArgument("seq_len(", b, ") > num_time_steps");
151 }
152 max_seq_len = std::max(seq_len(b), max_seq_len);
153 }
154
155 // Calculate the modified label sequence l' for each batch element,
156 // and calculate the maximum necessary allocation size.
157 LabelSequences l_primes(batch_size);
158 size_t max_u_prime = 0;
159 Status l_p_ret = PopulateLPrimes(
160 preprocess_collapse_repeated, ignore_longer_outputs_than_inputs,
161 batch_size, num_classes, seq_len, labels, &max_u_prime, &l_primes);
162 if (!l_p_ret.ok()) {
163 return l_p_ret;
164 }
165
166 // Process each item in a batch in parallel, using at most kMaxThreads.
167 auto ComputeLossAndGradients = [this, num_classes, &labels, &l_primes,
168 &seq_len, &inputs, requires_backprop,
169 ctc_merge_repeated,
170 ignore_longer_outputs_than_inputs, &loss,
171 &gradients](int64 start_row,
172 int64 limit_row) {
173 for (int b = start_row; b < limit_row; b++) {
174 // Return zero gradient for empty sequences or sequences with labels
175 // longer than input, which is not supported by CTC.
176 if (seq_len(b) == 0 ||
177 (ignore_longer_outputs_than_inputs &&
178 labels[b].size() > seq_len(b) - this->output_delay_)) {
179 VLOG(1) << "The sequence length is either zero or shorter than the "
180 "target output (CTC works only with shorter target sequence "
181 "than input sequence). You can turn this into a warning by "
182 "using the flag ignore_longer_outputs_than_inputs - "
183 << b << ": " << str_util::Join(labels[b], " ");
184 continue;
185 }
186
187 // For each batch element, log(alpha) and log(beta).
188 // row size is: u_prime == l_prime.size()
189 // col size is: seq_len[b] - output_delay_
190 const std::vector<int>& l_prime = l_primes[b];
191
192 Matrix log_alpha_b(l_prime.size(), seq_len(b) - this->output_delay_);
193 Matrix log_beta_b(l_prime.size(), seq_len(b) - this->output_delay_);
194
195 // Work matrices, pre-allocated to the size required by this batch item.
196 Matrix y(num_classes, seq_len(b));
197 Matrix dy;
198 if (requires_backprop) {
199 dy = Matrix::Zero(y.rows(), y.cols());
200 }
201
202 // For this batch, we'll only work with this shortened sequence_length.
203 Matrix y_b = y.leftCols(seq_len(b));
204
205 // Convert label from DistBelief
206 // y, prob are in num_classes x seq_len(b)
207 // Output activations.
208 Eigen::ArrayXf y_b_col;
209 for (int t = 0; t < seq_len(b); t++) {
210 // Calculate the softmax of y_b. Use double precision
211 // arithmetic for the sum.
212 float max_coeff = inputs[t].row(b).maxCoeff();
213 y_b_col = (inputs[t].row(b).array() - max_coeff).exp();
214 y_b.col(t) = y_b_col / y_b_col.sum();
215 }
216
217 // Compute forward, backward.
218 // Forward variables.
219 CalculateForwardVariables(l_prime, y_b, ctc_merge_repeated, &log_alpha_b);
220 // Backward variables.
221 CalculateBackwardVariables(l_prime, y_b, ctc_merge_repeated, &log_beta_b);
222
223 // The loss is computed as the log(p(z|x)) between the target and
224 // prediction. Do lazy evaluation of log_prob here.
225 float log_p_z_x = kLogZero;
226 for (int u = 0; u < l_prime.size(); ++u) {
227 // (GravesTh) Eq 7.26, sum over all paths for t = 0.
228 log_p_z_x = LogSumExp(log_p_z_x, log_alpha_b(u, 0) + log_beta_b(u, 0));
229 }
230
231 (*loss)(b) = -log_p_z_x; // Use negative log loss for display.
232
233 // We compute the derivative if needed.
234 if (requires_backprop) {
235 // Gradients with respect to input activations.
236 // Calculate gradient.
237 dy.setZero();
238 CalculateGradient(l_prime, y_b, log_alpha_b, log_beta_b, log_p_z_x,
239 &dy);
240
241 // Convert gradient for current sample to DistBelief.
242 for (int t = 0; t < seq_len(b); t++) {
243 (*gradients)[t].row(b).array() = dy.col(t);
244 }
245 }
246 } // for (int b = ...
247 };
248 if (workers) {
249 // *Rough* estimate of the cost for one item in the batch.
250 // Forward, Backward: O(T * U (= 2L + 1)), Gradients: O(T * (U + L)).
251 //
252 // softmax: T * L * (Cost(Exp) + Cost(Div))softmax +
253 // fwd,bwd: T * 2 * (2*L + 1) * (Cost(LogSumExp) + Cost(Log)) +
254 // grad: T * ((2L + 1) * Cost(LogSumExp) + L * (Cost(Expf) + Cost(Add)).
255 const int64 cost_exp = Eigen::internal::functor_traits<
256 Eigen::internal::scalar_exp_op<float>>::Cost;
257 const int64 cost_log = Eigen::internal::functor_traits<
258 Eigen::internal::scalar_log_op<float>>::Cost;
259 const int64 cost_log_sum_exp =
260 Eigen::TensorOpCost::AddCost<float>() + cost_exp + cost_log;
261 const int64 cost =
262 max_seq_len * num_classes *
263 (cost_exp + Eigen::TensorOpCost::DivCost<float>()) +
264 max_seq_len * 2 * (2 * num_classes + 1) *
265 (cost_log_sum_exp + cost_log) +
266 max_seq_len *
267 ((2 * num_classes + 1) * cost_log_sum_exp +
268 num_classes * (cost_exp + Eigen::TensorOpCost::AddCost<float>()));
269 Shard(workers->num_threads, workers->workers, batch_size, cost,
270 ComputeLossAndGradients);
271 } else {
272 ComputeLossAndGradients(0, batch_size);
273 }
274 return Status::OK();
275 }
276
277 template <typename Vector>
PopulateLPrimes(bool preprocess_collapse_repeated,bool ignore_longer_outputs_than_inputs,int batch_size,int num_classes,const Vector & seq_len,const LabelSequences & labels,size_t * max_u_prime,LabelSequences * l_primes)278 Status CTCLossCalculator::PopulateLPrimes(
279 bool preprocess_collapse_repeated, bool ignore_longer_outputs_than_inputs,
280 int batch_size, int num_classes, const Vector& seq_len,
281 const LabelSequences& labels, size_t* max_u_prime,
282 LabelSequences* l_primes) const {
283 // labels is a Label array of size batch_size
284 if (labels.size() != batch_size) {
285 return errors::InvalidArgument(
286 "labels.size() != batch_size: ", labels.size(), " vs. ", batch_size);
287 }
288
289 *max_u_prime = 0; // keep track of longest l' modified label sequence.
290 for (int b = 0; b < batch_size; b++) {
291 // Assume label is in Label proto
292 const std::vector<int>& label = labels[b];
293 if (label.size() == 0) {
294 return errors::InvalidArgument("Labels length is zero in batch ", b);
295 }
296
297 // If debugging: output the labels coming into training.
298 //
299 VLOG(2) << "label for batch: " << b << ": " << str_util::Join(label, " ");
300
301 // Target indices, length = U.
302 std::vector<int> l;
303
304 // Convert label from DistBelief
305 bool finished_sequence = false;
306 for (int i = 0; i < label.size(); ++i) {
307 if (i == 0 || !preprocess_collapse_repeated || label[i] != label[i - 1]) {
308 if (label[i] >= num_classes - 1) {
309 finished_sequence = true;
310 } else {
311 if (finished_sequence) {
312 // Saw an invalid sequence with non-null following null
313 // labels.
314 return errors::InvalidArgument(
315 "Saw a non-null label (index >= num_classes - 1) "
316 "following a ",
317 "null label, batch: ", b, " num_classes: ", num_classes,
318 " labels: ", str_util::Join(l, ","));
319 }
320 l.push_back(label[i]);
321 }
322 }
323 }
324
325 for (int l_i : l) {
326 if (l_i < 0) {
327 return errors::InvalidArgument(
328 "All labels must be nonnegative integers, batch: ", b,
329 " labels: ", str_util::Join(l, ","));
330 } else if (l_i >= num_classes) {
331 return errors::InvalidArgument(
332 "No label may be greater than num_classes. ",
333 "num_classes: ", num_classes, ", batch: ", b,
334 " labels: ", str_util::Join(l, ","));
335 }
336 }
337 if (!ignore_longer_outputs_than_inputs) {
338 // Make sure there is enough time to output the target indices.
339 int time = seq_len(b) - output_delay_;
340 int required_time = label.size();
341 if (required_time > time) {
342 return errors::InvalidArgument(
343 "Not enough time for target transition sequence ("
344 "required: ",
345 required_time, ", available: ", time, ")", b,
346 "You can turn this error into a warning by using the flag "
347 "ignore_longer_outputs_than_inputs");
348 }
349 }
350 // Target indices with blanks before each index and a blank at the end.
351 // Length U' = 2U + 1.
352 // Convert l to l_prime
353 GetLPrimeIndices(l, &l_primes->at(b));
354 *max_u_prime = std::max(*max_u_prime, l_primes->at(b).size());
355 }
356 return Status::OK();
357 }
358
359 } // namespace ctc
360 } // namespace tensorflow
361
362 #endif // TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_
363