1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Copied from tensorflow/core/util/ctc/ctc_beam_search.h
17 // TODO(b/111524997): Remove this file.
18 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_
19 #define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_
20
21 #include <algorithm>
22 #include <cmath>
23 #include <limits>
24 #include <memory>
25 #include <vector>
26
27 #include "third_party/eigen3/Eigen/Core"
28 #include "tensorflow/lite/experimental/kernels/ctc_beam_entry.h"
29 #include "tensorflow/lite/experimental/kernels/ctc_beam_scorer.h"
30 #include "tensorflow/lite/experimental/kernels/ctc_decoder.h"
31 #include "tensorflow/lite/experimental/kernels/ctc_loss_util.h"
32 #include "tensorflow/lite/experimental/kernels/top_n.h"
33 #include "tensorflow/lite/kernels/internal/compatibility.h"
34
35 namespace tflite {
36 namespace experimental {
37 namespace ctc {
38
39 template <typename CTCBeamState = ctc_beam_search::EmptyBeamState,
40 typename CTCBeamComparer =
41 ctc_beam_search::BeamComparer<CTCBeamState>>
42 class CTCBeamSearchDecoder : public CTCDecoder {
43 // Beam Search
44 //
45 // Example (GravesTh Fig. 7.5):
46 // a -
47 // P = [ 0.3 0.7 ] t = 0
48 // [ 0.4 0.6 ] t = 1
49 //
50 // Then P(l = -) = P(--) = 0.7 * 0.6 = 0.42
51 // P(l = a) = P(a-) + P(aa) + P(-a) = 0.3*0.4 + ... = 0.58
52 //
53 // In this case, Best Path decoding is suboptimal.
54 //
55 // For Beam Search, we use the following main recurrence relations:
56 //
57 // Relation 1:
58 // ---------------------------------------------------------- Eq. 1
59 // P(l=abcd @ t=7) = P(l=abc @ t=6) * P(d @ 7)
60 // + P(l=abcd @ t=6) * (P(d @ 7) + P(- @ 7))
61 // where P(l=? @ t=7), ? = a, ab, abc, abcd are all stored and
62 // updated recursively in the beam entry.
63 //
64 // Relation 2:
65 // ---------------------------------------------------------- Eq. 2
66 // P(l=abc? @ t=3) = P(l=abc @ t=2) * P(? @ 3)
67 // for ? in a, b, d, ..., (not including c or the blank index),
68 // and the recurrence starts from the beam entry for P(l=abc @ t=2).
69 //
70 // For this case, the length of the new sequence equals t+1 (t
71 // starts at 0). This special case can be calculated as:
72 // P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
73 // but we calculate it recursively for speed purposes.
74 typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry;
75 typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot;
76 typedef ctc_beam_search::BeamProbability BeamProbability;
77
78 public:
79 typedef BaseBeamScorer<CTCBeamState> DefaultBeamScorer;
80
81 // The beam search decoder is constructed specifying the beam_width (number of
82 // candidates to keep at each decoding timestep) and a beam scorer (used for
83 // custom scoring, for example enabling the use of a language model).
84 // The ownership of the scorer remains with the caller. The default
85 // implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the
86 // standard beam search.
87 CTCBeamSearchDecoder(int num_classes, int beam_width,
88 BaseBeamScorer<CTCBeamState>* scorer, int batch_size = 1,
89 bool merge_repeated = false)
CTCDecoder(num_classes,batch_size,merge_repeated)90 : CTCDecoder(num_classes, batch_size, merge_repeated),
91 beam_width_(beam_width),
92 leaves_(beam_width),
93 beam_scorer_(scorer) {
94 Reset();
95 }
96
~CTCBeamSearchDecoder()97 ~CTCBeamSearchDecoder() override {}
98
99 // Run the hibernating beam search algorithm on the given input.
100 bool Decode(const CTCDecoder::SequenceLength& seq_len,
101 const std::vector<CTCDecoder::Input>& input,
102 std::vector<CTCDecoder::Output>* output,
103 CTCDecoder::ScoreOutput* scores) override;
104
105 // Calculate the next step of the beam search and update the internal state.
106 template <typename Vector>
107 void Step(const Vector& log_input_t);
108
109 template <typename Vector>
110 float GetTopK(const int K, const Vector& input,
111 std::vector<float>* top_k_logits,
112 std::vector<int>* top_k_indices);
113
114 // Retrieve the beam scorer instance used during decoding.
GetBeamScorer()115 BaseBeamScorer<CTCBeamState>* GetBeamScorer() const { return beam_scorer_; }
116
117 // Set label selection parameters for faster decoding.
118 // See comments for label_selection_size_ and label_selection_margin_.
SetLabelSelectionParameters(int label_selection_size,float label_selection_margin)119 void SetLabelSelectionParameters(int label_selection_size,
120 float label_selection_margin) {
121 label_selection_size_ = label_selection_size;
122 label_selection_margin_ = label_selection_margin;
123 }
124
125 // Reset the beam search
126 void Reset();
127
128 // Extract the top n paths at current time step
129 bool TopPaths(int n, std::vector<std::vector<int>>* paths,
130 std::vector<float>* log_probs, bool merge_repeated) const;
131
132 private:
133 int beam_width_;
134
135 // Label selection is designed to avoid possibly very expensive scorer calls,
136 // by pruning the hypotheses based on the input alone.
137 // Label selection size controls how many items in each beam are passed
138 // through to the beam scorer. Only items with top N input scores are
139 // considered.
140 // Label selection margin controls the difference between minimal input score
141 // (versus the best scoring label) for an item to be passed to the beam
142 // scorer. This margin is expressed in terms of log-probability.
143 // Default is to do no label selection.
144 // For more detail: https://research.google.com/pubs/pub44823.html
145 int label_selection_size_ = 0; // zero means unlimited
146 float label_selection_margin_ = -1; // -1 means unlimited.
147
148 gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
149 std::unique_ptr<BeamRoot> beam_root_;
150 BaseBeamScorer<CTCBeamState>* beam_scorer_;
151
152 CTCBeamSearchDecoder(const CTCBeamSearchDecoder&) = delete;
153 void operator=(const CTCBeamSearchDecoder&) = delete;
154 };
155
156 template <typename CTCBeamState, typename CTCBeamComparer>
Decode(const CTCDecoder::SequenceLength & seq_len,const std::vector<CTCDecoder::Input> & input,std::vector<CTCDecoder::Output> * output,ScoreOutput * scores)157 bool CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
158 const CTCDecoder::SequenceLength& seq_len,
159 const std::vector<CTCDecoder::Input>& input,
160 std::vector<CTCDecoder::Output>* output, ScoreOutput* scores) {
161 // Storage for top paths.
162 std::vector<std::vector<int>> beams;
163 std::vector<float> beam_log_probabilities;
164 int top_n = output->size();
165 if (std::any_of(output->begin(), output->end(),
166 [this](const CTCDecoder::Output& output) -> bool {
167 return output.size() < this->batch_size_;
168 })) {
169 return false;
170 }
171 if (scores->rows() < batch_size_ || scores->cols() < top_n) {
172 return false;
173 }
174
175 for (int b = 0; b < batch_size_; ++b) {
176 int seq_len_b = seq_len[b];
177 Reset();
178
179 for (int t = 0; t < seq_len_b; ++t) {
180 // Pass log-probabilities for this example + time.
181 Step(input[t].row(b));
182 } // for (int t...
183
184 // O(n * log(n))
185 std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
186 leaves_.Reset();
187 for (int i = 0; i < branches->size(); ++i) {
188 BeamEntry* entry = (*branches)[i];
189 beam_scorer_->ExpandStateEnd(&entry->state);
190 entry->newp.total +=
191 beam_scorer_->GetStateEndExpansionScore(entry->state);
192 leaves_.push(entry);
193 }
194
195 bool status =
196 TopPaths(top_n, &beams, &beam_log_probabilities, merge_repeated_);
197 if (!status) {
198 return status;
199 }
200
201 TFLITE_DCHECK_EQ(top_n, beam_log_probabilities.size());
202 TFLITE_DCHECK_EQ(beams.size(), beam_log_probabilities.size());
203
204 for (int i = 0; i < top_n; ++i) {
205 // Copy output to the correct beam + batch
206 (*output)[i][b].swap(beams[i]);
207 (*scores)(b, i) = -beam_log_probabilities[i];
208 }
209 } // for (int b...
210 return true;
211 }
212
213 template <typename CTCBeamState, typename CTCBeamComparer>
214 template <typename Vector>
GetTopK(const int K,const Vector & input,std::vector<float> * top_k_logits,std::vector<int> * top_k_indices)215 float CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::GetTopK(
216 const int K, const Vector& input, std::vector<float>* top_k_logits,
217 std::vector<int>* top_k_indices) {
218 // Find Top K choices, complexity nk in worst case. The array input is read
219 // just once.
220 TFLITE_DCHECK_EQ(num_classes_, input.size());
221 top_k_logits->clear();
222 top_k_indices->clear();
223 top_k_logits->resize(K, -INFINITY);
224 top_k_indices->resize(K, -1);
225 for (int j = 0; j < num_classes_ - 1; ++j) {
226 const float logit = input(j);
227 if (logit > (*top_k_logits)[K - 1]) {
228 int k = K - 1;
229 while (k > 0 && logit > (*top_k_logits)[k - 1]) {
230 (*top_k_logits)[k] = (*top_k_logits)[k - 1];
231 (*top_k_indices)[k] = (*top_k_indices)[k - 1];
232 k--;
233 }
234 (*top_k_logits)[k] = logit;
235 (*top_k_indices)[k] = j;
236 }
237 }
238 // Return max value which is in 0th index or blank character logit
239 return std::max((*top_k_logits)[0], input(num_classes_ - 1));
240 }
241
242 template <typename CTCBeamState, typename CTCBeamComparer>
243 template <typename Vector>
Step(const Vector & raw_input)244 void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
245 const Vector& raw_input) {
246 std::vector<float> top_k_logits;
247 std::vector<int> top_k_indices;
248 const bool top_k =
249 (label_selection_size_ > 0 && label_selection_size_ < raw_input.size());
250 // Number of character classes to consider in each step.
251 const int max_classes = top_k ? label_selection_size_ : (num_classes_ - 1);
252 // Get max coefficient and remove it from raw_input later.
253 float max_coeff;
254 if (top_k) {
255 max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits,
256 &top_k_indices);
257 } else {
258 max_coeff = raw_input.maxCoeff();
259 }
260
261 // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
262 float logsumexp = 0.0;
263 for (int j = 0; j < raw_input.size(); ++j) {
264 logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
265 }
266 logsumexp = Eigen::numext::log(logsumexp);
267 // Final normalization offset to get correct log probabilities.
268 float norm_offset = max_coeff + logsumexp;
269
270 const float label_selection_input_min =
271 (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
272 : -std::numeric_limits<float>::infinity();
273
274 // Extract the beams sorted in decreasing new probability
275 TFLITE_DCHECK_EQ(num_classes_, raw_input.size());
276
277 std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
278 leaves_.Reset();
279
280 for (BeamEntry* b : *branches) {
281 // P(.. @ t) becomes the new P(.. @ t-1)
282 b->oldp = b->newp;
283 }
284
285 for (BeamEntry* b : *branches) {
286 if (b->parent != nullptr) { // if not the root
287 if (b->parent->Active()) {
288 // If last two sequence characters are identical:
289 // Plabel(l=acc @ t=6) = (Plabel(l=acc @ t=5)
290 // + Pblank(l=ac @ t=5))
291 // else:
292 // Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5)
293 // + P(l=ab @ t=5))
294 float previous = (b->label == b->parent->label) ? b->parent->oldp.blank
295 : b->parent->oldp.total;
296 b->newp.label =
297 LogSumExp(b->newp.label,
298 beam_scorer_->GetStateExpansionScore(b->state, previous));
299 }
300 // Plabel(l=abc @ t=6) *= P(c @ 6)
301 b->newp.label += raw_input(b->label) - norm_offset;
302 }
303 // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
304 b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset;
305 // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
306 b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
307
308 // Push the entry back to the top paths list.
309 // Note, this will always fill leaves back up in sorted order.
310 leaves_.push(b);
311 }
312
313 // we need to resort branches in descending oldp order.
314
315 // branches is in descending oldp order because it was
316 // originally in descending newp order and we copied newp to oldp.
317
318 // Grow new leaves
319 for (BeamEntry* b : *branches) {
320 // A new leaf (represented by its BeamProbability) is a candidate
321 // iff its total probability is nonzero and either the beam list
322 // isn't full, or the lowest probability entry in the beam has a
323 // lower probability than the leaf.
324 auto is_candidate = [this](const BeamProbability& prob) {
325 return (prob.total > kLogZero &&
326 (leaves_.size() < beam_width_ ||
327 prob.total > leaves_.peek_bottom()->newp.total));
328 };
329
330 if (!is_candidate(b->oldp)) {
331 continue;
332 }
333
334 for (int ind = 0; ind < max_classes; ind++) {
335 const int label = top_k ? top_k_indices[ind] : ind;
336 const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
337 // Perform label selection: if input for this label looks very
338 // unpromising, never evaluate it with a scorer.
339 // We may compare logits instead of log probabilities,
340 // since the difference is the same in both cases.
341 if (logit < label_selection_input_min) {
342 continue;
343 }
344 BeamEntry& c = b->GetChild(label);
345 if (!c.Active()) {
346 // Pblank(l=abcd @ t=6) = 0
347 c.newp.blank = kLogZero;
348 // If new child label is identical to beam label:
349 // Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6)
350 // Otherwise:
351 // Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
352 beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
353 float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
354 c.newp.label = logit - norm_offset +
355 beam_scorer_->GetStateExpansionScore(c.state, previous);
356 // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
357 c.newp.total = c.newp.label;
358
359 if (is_candidate(c.newp)) {
360 // Before adding the new node to the beam, check if the beam
361 // is already at maximum width.
362 if (leaves_.size() == beam_width_) {
363 // Bottom is no longer in the beam search. Reset
364 // its probability; signal it's no longer in the beam search.
365 BeamEntry* bottom = leaves_.peek_bottom();
366 bottom->newp.Reset();
367 }
368 leaves_.push(&c);
369 } else {
370 // Deactivate child.
371 c.oldp.Reset();
372 c.newp.Reset();
373 }
374 }
375 }
376 } // for (BeamEntry* b...
377 }
378
379 template <typename CTCBeamState, typename CTCBeamComparer>
Reset()380 void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
381 leaves_.Reset();
382
383 // This beam root, and all of its children, will be in memory until
384 // the next reset.
385 beam_root_.reset(new BeamRoot(nullptr, -1));
386 beam_root_->RootEntry()->newp.total = 0.0; // ln(1)
387 beam_root_->RootEntry()->newp.blank = 0.0; // ln(1)
388
389 // Add the root as the initial leaf.
390 leaves_.push(beam_root_->RootEntry());
391
392 // Call initialize state on the root object.
393 beam_scorer_->InitializeState(&beam_root_->RootEntry()->state);
394 }
395
396 template <typename CTCBeamState, typename CTCBeamComparer>
TopPaths(int n,std::vector<std::vector<int>> * paths,std::vector<float> * log_probs,bool merge_repeated)397 bool CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::TopPaths(
398 int n, std::vector<std::vector<int>>* paths, std::vector<float>* log_probs,
399 bool merge_repeated) const {
400 TFLITE_DCHECK(paths);
401 TFLITE_DCHECK(log_probs);
402 paths->clear();
403 log_probs->clear();
404 if (n > beam_width_) {
405 return false;
406 }
407 if (n > leaves_.size()) {
408 return false;
409 }
410
411 gtl::TopN<BeamEntry*, CTCBeamComparer> top_branches(n);
412
413 // O(beam_width_ * log(n)), space complexity is O(n)
414 for (auto it = leaves_.unsorted_begin(); it != leaves_.unsorted_end(); ++it) {
415 top_branches.push(*it);
416 }
417 // O(n * log(n))
418 std::unique_ptr<std::vector<BeamEntry*>> branches(top_branches.Extract());
419
420 for (int i = 0; i < n; ++i) {
421 BeamEntry* e((*branches)[i]);
422 paths->push_back(e->LabelSeq(merge_repeated));
423 log_probs->push_back(e->newp.total);
424 }
425 return true;
426 }
427
428 } // namespace ctc
429 } // namespace experimental
430 } // namespace tflite
431
432 #endif // TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_
433