/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"

namespace tensorflow {
namespace ctc {

// Calculates the alpha(t, u) as described in (GravesTh) Section 7.3.
// Starting with t = 0 instead of t = 1 used in the text.
// Based on Kanishka's CTC.
CalculateForwardVariables(const std::vector<int> & l_prime,const Matrix & y,bool ctc_merge_repeated,Matrix * log_alpha) const24 void CTCLossCalculator::CalculateForwardVariables(
25 const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
26 Matrix* log_alpha) const {
27 // Number of cols is the number of time steps = number of cols in target
28 // after the output delay.
29 log_alpha->setConstant(kLogZero);
30
31 int U = l_prime.size();
32 int T = log_alpha->cols();
33
34 CHECK_EQ(U, log_alpha->rows());
35
36 // Initial alpha values in (GravesTh) Eq 7.5 and Eq 7.6.
37 log_alpha->coeffRef(0, 0) = log(y(blank_index_, output_delay_));
38 // Below, l_prime[1] == labels[0]
39 auto label_0 = (l_prime.size() > 1) ? l_prime[1] : blank_index_;
40 log_alpha->coeffRef(1, 0) = log(y(label_0, output_delay_));
41
42 for (int t = 1; t < T; ++t) {
43 // If there is not enough time to output the remaining labels or
44 // some labels have been skipped, then let log_alpha(u, t) continue to
45 // be kLogZero.
46 for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
47 ++u) {
48 // Begin (GravesTh) Eq 7.9
49 // Add in the u, t - 1 term.
50 float sum_log_alpha = kLogZero;
51 if (ctc_merge_repeated || l_prime[u] == blank_index_) {
52 sum_log_alpha = log_alpha->coeff(u, t - 1);
53 }
54
55 // Add in the u - 1, t - 1 term.
56 if (u > 0) {
57 sum_log_alpha =
58 LogSumExp(sum_log_alpha, log_alpha->coeff(u - 1, t - 1));
59 }
60
61 // Add in the u - 2, t - 1 term if l_prime(u) != blank or l_prime(u-2).
62 if (u > 1) {
63 const bool matching_labels_merge =
64 ctc_merge_repeated && (l_prime[u] == l_prime[u - 2]);
65 if (l_prime[u] != blank_index_ && !matching_labels_merge) {
66 sum_log_alpha =
67 LogSumExp(sum_log_alpha, log_alpha->coeff(u - 2, t - 1));
68 }
69 }
70 // Multiply the summed alphas with the activation log probability.
71 log_alpha->coeffRef(u, t) =
72 log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha;
73 } // End (GravesTh) Eq 7.9.
74 }
75 }

// Calculates the beta(t, u) as described in (GravesTh) Section 7.3.
void CTCLossCalculator::CalculateBackwardVariables(
    const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
    Matrix* log_beta) const {
  // l_prime: blank-augmented label sequence (length U).
  // y: softmax activations, one column per time step.
  // log_beta: output, U x T matrix of log backward variables; unreachable
  // (u, t) states keep the value kLogZero.
  //
  // Number of cols is the number of time steps = number of cols in target.
  // Matrix log_beta =
  //     Matrix::Constant(l_prime.size(), y.cols() - output_delay_,
  //     kLogZero);
  log_beta->setConstant(kLogZero);
  int T = log_beta->cols();
  int U = l_prime.size();
  CHECK_EQ(U, log_beta->rows());

  // Initial beta values in (GravesTh) Eq 7.13: log of probability 1.
  // A path may finish either on the last label or on the trailing blank,
  // so both of the last two rows get log(1) == 0 at the final time step.
  for (int u = U - 2; u < U; ++u) log_beta->coeffRef(u, T - 1) = 0;

  // Sweep backwards in time, mirroring the forward recursion.
  for (int t = T - 1 - 1; t >= 0; --t) {
    // If there is not enough time to output the remaining labels or
    // some labels have been skipped, then let log_beta(u, t) continue to
    // be kLogZero.
    for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
         ++u) {
      // Begin (GravesTh) Eq 7.15
      // Add in the u, t + 1 term (stay on the same l_prime symbol); only
      // allowed when repeats merge or the symbol is a blank.
      if (ctc_merge_repeated || l_prime[u] == blank_index_) {
        log_beta->coeffRef(u, t) =
            LogSumExp(log_beta->coeff(u, t),
                      log_beta->coeff(u, t + 1) +
                          log(y(l_prime[u], output_delay_ + t + 1)));
      }

      // Add in the u + 1, t + 1 term (advance to the next symbol).
      if (u + 1 < U) {
        log_beta->coeffRef(u, t) =
            LogSumExp(log_beta->coeff(u, t),
                      log_beta->coeff(u + 1, t + 1) +
                          log(y(l_prime[u + 1], output_delay_ + t + 1)));
      }

      // Add in the u + 2, t + 1 term if l_prime(u) != blank or l_prime(u+2).
      // This is the skip-over-blank transition, forbidden when it would
      // merge two identical labels.
      if (u + 2 < U) {
        const bool matching_labels_merge =
            ctc_merge_repeated && (l_prime[u] == l_prime[u + 2]);
        if (l_prime[u] != blank_index_ && !matching_labels_merge) {
          // Add in u + 2 term.
          log_beta->coeffRef(u, t) =
              LogSumExp(log_beta->coeff(u, t),
                        log_beta->coeff(u + 2, t + 1) +
                            log(y(l_prime[u + 2], output_delay_ + t + 1)));
        }
      }  // End (GravesTh) Eq. 7.15
    }
  }
}

// Using (GravesTh) Eq 7.26 & 7.34.
CalculateGradient(const std::vector<int> & l_prime,const Matrix & y,const Matrix & log_alpha,const Matrix & log_beta,float log_p_z_x,Matrix * dy) const133 void CTCLossCalculator::CalculateGradient(const std::vector<int>& l_prime,
134 const Matrix& y,
135 const Matrix& log_alpha,
136 const Matrix& log_beta,
137 float log_p_z_x, Matrix* dy) const {
138 // Only working with the leftmost part of dy for this batch element.
139 auto dy_b = dy->leftCols(y.cols());
140
141 // It is possible that no valid path is found if the activations for the
142 // targets are zero.
143 if (log_p_z_x == kLogZero) {
144 LOG(WARNING) << "No valid path found.";
145 dy_b = y;
146 return;
147 }
148
149 int L = y.rows();
150 int T = y.cols();
151 int U = l_prime.size();
152
153 for (int t = 0; t < T - output_delay_; ++t) {
154 Array prob_sum(L);
155 prob_sum.setConstant(kLogZero);
156
157 for (int u = 0; u < U; ++u) {
158 int l = l_prime[u];
159 prob_sum[l] = LogSumExp(prob_sum[l], log_alpha(u, t) + log_beta(u, t));
160 }
161
162 for (int l = 0; l < L; ++l) {
163 // Negative term in (GravesTh) Eq 7.28.
164 float negative_term = expf(prob_sum[l] - log_p_z_x);
165
166 dy_b(l, output_delay_ + t) = y(l, output_delay_ + t) - negative_term;
167 }
168 }
169 }
GetLPrimeIndices(const std::vector<int> & l,std::vector<int> * l_prime) const171 void CTCLossCalculator::GetLPrimeIndices(const std::vector<int>& l,
172 std::vector<int>* l_prime) const {
173 // Assumption is that l_prime is empty.
174 l_prime->reserve(2 * l.size() + 1);
175
176 for (auto label : l) {
177 l_prime->push_back(blank_index_);
178 l_prime->push_back(label);
179 }
180 // Add final blank to l'.
181 l_prime->push_back(blank_index_);
182 }

}  // namespace ctc
}  // namespace tensorflow