• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
17 
18 namespace tensorflow {
19 namespace ctc {
20 
21 // Calculates the alpha(t, u) as described in (GravesTh) Section 7.3.
22 // Starting with t = 0 instead of t = 1 used in the text.
23 // Based on Kanishka's CTC.
CalculateForwardVariables(const std::vector<int> & l_prime,const Matrix & y,bool ctc_merge_repeated,Matrix * log_alpha) const24 void CTCLossCalculator::CalculateForwardVariables(
25     const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
26     Matrix* log_alpha) const {
27   // Number of cols is the number of time steps = number of cols in target
28   // after the output delay.
29   log_alpha->setConstant(kLogZero);
30 
31   int U = l_prime.size();
32   int T = log_alpha->cols();
33 
34   CHECK_EQ(U, log_alpha->rows());
35 
36   // Initial alpha values in (GravesTh) Eq 7.5 and Eq 7.6.
37   log_alpha->coeffRef(0, 0) = log(y(blank_index_, output_delay_));
38   // Below, l_prime[1] == labels[0]
39   auto label_0 = (l_prime.size() > 1) ? l_prime[1] : blank_index_;
40   log_alpha->coeffRef(1, 0) = log(y(label_0, output_delay_));
41 
42   for (int t = 1; t < T; ++t) {
43     // If there is not enough time to output the remaining labels or
44     // some labels have been skipped, then let log_alpha(u, t) continue to
45     // be kLogZero.
46     for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
47          ++u) {
48       // Begin (GravesTh) Eq 7.9
49       // Add in the u, t - 1 term.
50       float sum_log_alpha = kLogZero;
51       if (ctc_merge_repeated || l_prime[u] == blank_index_) {
52         sum_log_alpha = log_alpha->coeff(u, t - 1);
53       }
54 
55       // Add in the u - 1, t - 1 term.
56       if (u > 0) {
57         sum_log_alpha =
58             LogSumExp(sum_log_alpha, log_alpha->coeff(u - 1, t - 1));
59       }
60 
61       // Add in the u - 2, t - 1 term if l_prime(u) != blank or l_prime(u-2).
62       if (u > 1) {
63         const bool matching_labels_merge =
64             ctc_merge_repeated && (l_prime[u] == l_prime[u - 2]);
65         if (l_prime[u] != blank_index_ && !matching_labels_merge) {
66           sum_log_alpha =
67               LogSumExp(sum_log_alpha, log_alpha->coeff(u - 2, t - 1));
68         }
69       }
70       // Multiply the summed alphas with the activation log probability.
71       log_alpha->coeffRef(u, t) =
72           log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha;
73     }  // End (GravesTh) Eq 7.9.
74   }
75 }
76 
77 // Calculates the beta(t, u) as described in (GravesTh) Section 7.3.
CalculateBackwardVariables(const std::vector<int> & l_prime,const Matrix & y,bool ctc_merge_repeated,Matrix * log_beta) const78 void CTCLossCalculator::CalculateBackwardVariables(
79     const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
80     Matrix* log_beta) const {
81   // Number of cols is the number of time steps =  number of cols in target.
82   // Matrix log_beta =
83   //    Matrix::Constant(l_prime.size(), y.cols() - output_delay_,
84   // kLogZero);
85   log_beta->setConstant(kLogZero);
86   int T = log_beta->cols();
87   int U = l_prime.size();
88   CHECK_EQ(U, log_beta->rows());
89 
90   // Initial beta values in (GravesTh) Eq 7.13: log of probability 1.
91   for (int u = U - 2; u < U; ++u) log_beta->coeffRef(u, T - 1) = 0;
92 
93   for (int t = T - 1 - 1; t >= 0; --t) {
94     // If there is not enough time to output the remaining labels or
95     // some labels have been skipped, then let log_beta(u, t) continue to
96     // be kLogZero.
97     for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
98          ++u) {
99       // Begin (GravesTh) Eq 7.15
100       // Add in the u, t + 1 term.
101       if (ctc_merge_repeated || l_prime[u] == blank_index_) {
102         log_beta->coeffRef(u, t) =
103             LogSumExp(log_beta->coeff(u, t),
104                       log_beta->coeff(u, t + 1) +
105                           log(y(l_prime[u], output_delay_ + t + 1)));
106       }
107 
108       // Add in the u + 1, t + 1 term.
109       if (u + 1 < U) {
110         log_beta->coeffRef(u, t) =
111             LogSumExp(log_beta->coeff(u, t),
112                       log_beta->coeff(u + 1, t + 1) +
113                           log(y(l_prime[u + 1], output_delay_ + t + 1)));
114       }
115 
116       // Add in the u + 2, t + 1 term if l_prime(u) != blank or l_prime(u+2).
117       if (u + 2 < U) {
118         const bool matching_labels_merge =
119             ctc_merge_repeated && (l_prime[u] == l_prime[u + 2]);
120         if (l_prime[u] != blank_index_ && !matching_labels_merge) {
121           // Add in u + 2 term.
122           log_beta->coeffRef(u, t) =
123               LogSumExp(log_beta->coeff(u, t),
124                         log_beta->coeff(u + 2, t + 1) +
125                             log(y(l_prime[u + 2], output_delay_ + t + 1)));
126         }
127       }  // End (GravesTh) Eq. 7.15
128     }
129   }
130 }
131 
132 // Using (GravesTh) Eq 7.26 & 7.34.
CalculateGradient(const std::vector<int> & l_prime,const Matrix & y,const Matrix & log_alpha,const Matrix & log_beta,float log_p_z_x,Matrix * dy) const133 void CTCLossCalculator::CalculateGradient(const std::vector<int>& l_prime,
134                                           const Matrix& y,
135                                           const Matrix& log_alpha,
136                                           const Matrix& log_beta,
137                                           float log_p_z_x, Matrix* dy) const {
138   // Only working with the leftmost part of dy for this batch element.
139   auto dy_b = dy->leftCols(y.cols());
140 
141   // It is possible that no valid path is found if the activations for the
142   // targets are zero.
143   if (log_p_z_x == kLogZero) {
144     LOG(WARNING) << "No valid path found.";
145     dy_b = y;
146     return;
147   }
148 
149   int L = y.rows();
150   int T = y.cols();
151   int U = l_prime.size();
152 
153   for (int t = 0; t < T - output_delay_; ++t) {
154     Array prob_sum(L);
155     prob_sum.setConstant(kLogZero);
156 
157     for (int u = 0; u < U; ++u) {
158       int l = l_prime[u];
159       prob_sum[l] = LogSumExp(prob_sum[l], log_alpha(u, t) + log_beta(u, t));
160     }
161 
162     for (int l = 0; l < L; ++l) {
163       // Negative term in (GravesTh) Eq 7.28.
164       float negative_term = expf(prob_sum[l] - log_p_z_x);
165 
166       dy_b(l, output_delay_ + t) = y(l, output_delay_ + t) - negative_term;
167     }
168   }
169 }
170 
GetLPrimeIndices(const std::vector<int> & l,std::vector<int> * l_prime) const171 void CTCLossCalculator::GetLPrimeIndices(const std::vector<int>& l,
172                                          std::vector<int>* l_prime) const {
173   // Assumption is that l_prime is empty.
174   l_prime->reserve(2 * l.size() + 1);
175 
176   for (auto label : l) {
177     l_prime->push_back(blank_index_);
178     l_prime->push_back(label);
179   }
180   // Add final blank to l'.
181   l_prime->push_back(blank_index_);
182 }
183 
184 }  // namespace ctc
185 }  // namespace tensorflow
186