• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Copied from tensorflow/core/util/ctc/ctc_beam_search.h
17 // TODO(b/111524997): Remove this file.
18 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_
19 #define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_
20 
21 #include <algorithm>
22 #include <cmath>
23 #include <limits>
24 #include <memory>
25 #include <vector>
26 
27 #include "third_party/eigen3/Eigen/Core"
28 #include "tensorflow/lite/experimental/kernels/ctc_beam_entry.h"
29 #include "tensorflow/lite/experimental/kernels/ctc_beam_scorer.h"
30 #include "tensorflow/lite/experimental/kernels/ctc_decoder.h"
31 #include "tensorflow/lite/experimental/kernels/ctc_loss_util.h"
32 #include "tensorflow/lite/experimental/kernels/top_n.h"
33 #include "tensorflow/lite/kernels/internal/compatibility.h"
34 
35 namespace tflite {
36 namespace experimental {
37 namespace ctc {
38 
39 template <typename CTCBeamState = ctc_beam_search::EmptyBeamState,
40           typename CTCBeamComparer =
41               ctc_beam_search::BeamComparer<CTCBeamState>>
42 class CTCBeamSearchDecoder : public CTCDecoder {
43   // Beam Search
44   //
45   // Example (GravesTh Fig. 7.5):
46   //         a    -
47   //  P = [ 0.3  0.7 ]  t = 0
48   //      [ 0.4  0.6 ]  t = 1
49   //
50   // Then P(l = -) = P(--) = 0.7 * 0.6 = 0.42
51   //      P(l = a) = P(a-) + P(aa) + P(-a) = 0.3*0.4 + ... = 0.58
52   //
53   // In this case, Best Path decoding is suboptimal.
54   //
55   // For Beam Search, we use the following main recurrence relations:
56   //
57   // Relation 1:
58   // ---------------------------------------------------------- Eq. 1
59   //      P(l=abcd @ t=7) = P(l=abc  @ t=6) * P(d @ 7)
60   //                      + P(l=abcd @ t=6) * (P(d @ 7) + P(- @ 7))
61   // where P(l=? @ t=7), ? = a, ab, abc, abcd are all stored and
62   // updated recursively in the beam entry.
63   //
64   // Relation 2:
65   // ---------------------------------------------------------- Eq. 2
66   //      P(l=abc? @ t=3) = P(l=abc @ t=2) * P(? @ 3)
67   // for ? in a, b, d, ..., (not including c or the blank index),
68   // and the recurrence starts from the beam entry for P(l=abc @ t=2).
69   //
70   // For this case, the length of the new sequence equals t+1 (t
71   // starts at 0).  This special case can be calculated as:
72   //   P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
73   // but we calculate it recursively for speed purposes.
74   typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry;
75   typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot;
76   typedef ctc_beam_search::BeamProbability BeamProbability;
77 
78  public:
79   typedef BaseBeamScorer<CTCBeamState> DefaultBeamScorer;
80 
81   // The beam search decoder is constructed specifying the beam_width (number of
82   // candidates to keep at each decoding timestep) and a beam scorer (used for
83   // custom scoring, for example enabling the use of a language model).
84   // The ownership of the scorer remains with the caller. The default
85   // implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the
86   // standard beam search.
87   CTCBeamSearchDecoder(int num_classes, int beam_width,
88                        BaseBeamScorer<CTCBeamState>* scorer, int batch_size = 1,
89                        bool merge_repeated = false)
CTCDecoder(num_classes,batch_size,merge_repeated)90       : CTCDecoder(num_classes, batch_size, merge_repeated),
91         beam_width_(beam_width),
92         leaves_(beam_width),
93         beam_scorer_(scorer) {
94     Reset();
95   }
96 
~CTCBeamSearchDecoder()97   ~CTCBeamSearchDecoder() override {}
98 
99   // Run the hibernating beam search algorithm on the given input.
100   bool Decode(const CTCDecoder::SequenceLength& seq_len,
101               const std::vector<CTCDecoder::Input>& input,
102               std::vector<CTCDecoder::Output>* output,
103               CTCDecoder::ScoreOutput* scores) override;
104 
105   // Calculate the next step of the beam search and update the internal state.
106   template <typename Vector>
107   void Step(const Vector& log_input_t);
108 
109   template <typename Vector>
110   float GetTopK(const int K, const Vector& input,
111                 std::vector<float>* top_k_logits,
112                 std::vector<int>* top_k_indices);
113 
114   // Retrieve the beam scorer instance used during decoding.
GetBeamScorer()115   BaseBeamScorer<CTCBeamState>* GetBeamScorer() const { return beam_scorer_; }
116 
117   // Set label selection parameters for faster decoding.
118   // See comments for label_selection_size_ and label_selection_margin_.
SetLabelSelectionParameters(int label_selection_size,float label_selection_margin)119   void SetLabelSelectionParameters(int label_selection_size,
120                                    float label_selection_margin) {
121     label_selection_size_ = label_selection_size;
122     label_selection_margin_ = label_selection_margin;
123   }
124 
125   // Reset the beam search
126   void Reset();
127 
128   // Extract the top n paths at current time step
129   bool TopPaths(int n, std::vector<std::vector<int>>* paths,
130                 std::vector<float>* log_probs, bool merge_repeated) const;
131 
132  private:
133   int beam_width_;
134 
135   // Label selection is designed to avoid possibly very expensive scorer calls,
136   // by pruning the hypotheses based on the input alone.
137   // Label selection size controls how many items in each beam are passed
138   // through to the beam scorer. Only items with top N input scores are
139   // considered.
140   // Label selection margin controls the difference between minimal input score
141   // (versus the best scoring label) for an item to be passed to the beam
142   // scorer. This margin is expressed in terms of log-probability.
143   // Default is to do no label selection.
144   // For more detail: https://research.google.com/pubs/pub44823.html
145   int label_selection_size_ = 0;       // zero means unlimited
146   float label_selection_margin_ = -1;  // -1 means unlimited.
147 
148   gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
149   std::unique_ptr<BeamRoot> beam_root_;
150   BaseBeamScorer<CTCBeamState>* beam_scorer_;
151 
152   CTCBeamSearchDecoder(const CTCBeamSearchDecoder&) = delete;
153   void operator=(const CTCBeamSearchDecoder&) = delete;
154 };
155 
156 template <typename CTCBeamState, typename CTCBeamComparer>
Decode(const CTCDecoder::SequenceLength & seq_len,const std::vector<CTCDecoder::Input> & input,std::vector<CTCDecoder::Output> * output,ScoreOutput * scores)157 bool CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
158     const CTCDecoder::SequenceLength& seq_len,
159     const std::vector<CTCDecoder::Input>& input,
160     std::vector<CTCDecoder::Output>* output, ScoreOutput* scores) {
161   // Storage for top paths.
162   std::vector<std::vector<int>> beams;
163   std::vector<float> beam_log_probabilities;
164   int top_n = output->size();
165   if (std::any_of(output->begin(), output->end(),
166                   [this](const CTCDecoder::Output& output) -> bool {
167                     return output.size() < this->batch_size_;
168                   })) {
169     return false;
170   }
171   if (scores->rows() < batch_size_ || scores->cols() < top_n) {
172     return false;
173   }
174 
175   for (int b = 0; b < batch_size_; ++b) {
176     int seq_len_b = seq_len[b];
177     Reset();
178 
179     for (int t = 0; t < seq_len_b; ++t) {
180       // Pass log-probabilities for this example + time.
181       Step(input[t].row(b));
182     }  // for (int t...
183 
184     // O(n * log(n))
185     std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
186     leaves_.Reset();
187     for (int i = 0; i < branches->size(); ++i) {
188       BeamEntry* entry = (*branches)[i];
189       beam_scorer_->ExpandStateEnd(&entry->state);
190       entry->newp.total +=
191           beam_scorer_->GetStateEndExpansionScore(entry->state);
192       leaves_.push(entry);
193     }
194 
195     bool status =
196         TopPaths(top_n, &beams, &beam_log_probabilities, merge_repeated_);
197     if (!status) {
198       return status;
199     }
200 
201     TFLITE_DCHECK_EQ(top_n, beam_log_probabilities.size());
202     TFLITE_DCHECK_EQ(beams.size(), beam_log_probabilities.size());
203 
204     for (int i = 0; i < top_n; ++i) {
205       // Copy output to the correct beam + batch
206       (*output)[i][b].swap(beams[i]);
207       (*scores)(b, i) = -beam_log_probabilities[i];
208     }
209   }  // for (int b...
210   return true;
211 }
212 
213 template <typename CTCBeamState, typename CTCBeamComparer>
214 template <typename Vector>
GetTopK(const int K,const Vector & input,std::vector<float> * top_k_logits,std::vector<int> * top_k_indices)215 float CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::GetTopK(
216     const int K, const Vector& input, std::vector<float>* top_k_logits,
217     std::vector<int>* top_k_indices) {
218   // Find Top K choices, complexity nk in worst case. The array input is read
219   // just once.
220   TFLITE_DCHECK_EQ(num_classes_, input.size());
221   top_k_logits->clear();
222   top_k_indices->clear();
223   top_k_logits->resize(K, -INFINITY);
224   top_k_indices->resize(K, -1);
225   for (int j = 0; j < num_classes_ - 1; ++j) {
226     const float logit = input(j);
227     if (logit > (*top_k_logits)[K - 1]) {
228       int k = K - 1;
229       while (k > 0 && logit > (*top_k_logits)[k - 1]) {
230         (*top_k_logits)[k] = (*top_k_logits)[k - 1];
231         (*top_k_indices)[k] = (*top_k_indices)[k - 1];
232         k--;
233       }
234       (*top_k_logits)[k] = logit;
235       (*top_k_indices)[k] = j;
236     }
237   }
238   // Return max value which is in 0th index or blank character logit
239   return std::max((*top_k_logits)[0], input(num_classes_ - 1));
240 }
241 
242 template <typename CTCBeamState, typename CTCBeamComparer>
243 template <typename Vector>
Step(const Vector & raw_input)244 void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
245     const Vector& raw_input) {
246   std::vector<float> top_k_logits;
247   std::vector<int> top_k_indices;
248   const bool top_k =
249       (label_selection_size_ > 0 && label_selection_size_ < raw_input.size());
250   // Number of character classes to consider in each step.
251   const int max_classes = top_k ? label_selection_size_ : (num_classes_ - 1);
252   // Get max coefficient and remove it from raw_input later.
253   float max_coeff;
254   if (top_k) {
255     max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits,
256                         &top_k_indices);
257   } else {
258     max_coeff = raw_input.maxCoeff();
259   }
260 
261   // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
262   float logsumexp = 0.0;
263   for (int j = 0; j < raw_input.size(); ++j) {
264     logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
265   }
266   logsumexp = Eigen::numext::log(logsumexp);
267   // Final normalization offset to get correct log probabilities.
268   float norm_offset = max_coeff + logsumexp;
269 
270   const float label_selection_input_min =
271       (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
272                                      : -std::numeric_limits<float>::infinity();
273 
274   // Extract the beams sorted in decreasing new probability
275   TFLITE_DCHECK_EQ(num_classes_, raw_input.size());
276 
277   std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
278   leaves_.Reset();
279 
280   for (BeamEntry* b : *branches) {
281     // P(.. @ t) becomes the new P(.. @ t-1)
282     b->oldp = b->newp;
283   }
284 
285   for (BeamEntry* b : *branches) {
286     if (b->parent != nullptr) {  // if not the root
287       if (b->parent->Active()) {
288         // If last two sequence characters are identical:
289         //   Plabel(l=acc @ t=6) = (Plabel(l=acc @ t=5)
290         //                          + Pblank(l=ac @ t=5))
291         // else:
292         //   Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5)
293         //                          + P(l=ab @ t=5))
294         float previous = (b->label == b->parent->label) ? b->parent->oldp.blank
295                                                         : b->parent->oldp.total;
296         b->newp.label =
297             LogSumExp(b->newp.label,
298                       beam_scorer_->GetStateExpansionScore(b->state, previous));
299       }
300       // Plabel(l=abc @ t=6) *= P(c @ 6)
301       b->newp.label += raw_input(b->label) - norm_offset;
302     }
303     // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
304     b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset;
305     // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
306     b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
307 
308     // Push the entry back to the top paths list.
309     // Note, this will always fill leaves back up in sorted order.
310     leaves_.push(b);
311   }
312 
313   // we need to resort branches in descending oldp order.
314 
315   // branches is in descending oldp order because it was
316   // originally in descending newp order and we copied newp to oldp.
317 
318   // Grow new leaves
319   for (BeamEntry* b : *branches) {
320     // A new leaf (represented by its BeamProbability) is a candidate
321     // iff its total probability is nonzero and either the beam list
322     // isn't full, or the lowest probability entry in the beam has a
323     // lower probability than the leaf.
324     auto is_candidate = [this](const BeamProbability& prob) {
325       return (prob.total > kLogZero &&
326               (leaves_.size() < beam_width_ ||
327                prob.total > leaves_.peek_bottom()->newp.total));
328     };
329 
330     if (!is_candidate(b->oldp)) {
331       continue;
332     }
333 
334     for (int ind = 0; ind < max_classes; ind++) {
335       const int label = top_k ? top_k_indices[ind] : ind;
336       const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
337       // Perform label selection: if input for this label looks very
338       // unpromising, never evaluate it with a scorer.
339       // We may compare logits instead of log probabilities,
340       // since the difference is the same in both cases.
341       if (logit < label_selection_input_min) {
342         continue;
343       }
344       BeamEntry& c = b->GetChild(label);
345       if (!c.Active()) {
346         //   Pblank(l=abcd @ t=6) = 0
347         c.newp.blank = kLogZero;
348         // If new child label is identical to beam label:
349         //   Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6)
350         // Otherwise:
351         //   Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
352         beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
353         float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
354         c.newp.label = logit - norm_offset +
355                        beam_scorer_->GetStateExpansionScore(c.state, previous);
356         // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
357         c.newp.total = c.newp.label;
358 
359         if (is_candidate(c.newp)) {
360           // Before adding the new node to the beam, check if the beam
361           // is already at maximum width.
362           if (leaves_.size() == beam_width_) {
363             // Bottom is no longer in the beam search.  Reset
364             // its probability; signal it's no longer in the beam search.
365             BeamEntry* bottom = leaves_.peek_bottom();
366             bottom->newp.Reset();
367           }
368           leaves_.push(&c);
369         } else {
370           // Deactivate child.
371           c.oldp.Reset();
372           c.newp.Reset();
373         }
374       }
375     }
376   }  // for (BeamEntry* b...
377 }
378 
379 template <typename CTCBeamState, typename CTCBeamComparer>
Reset()380 void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
381   leaves_.Reset();
382 
383   // This beam root, and all of its children, will be in memory until
384   // the next reset.
385   beam_root_.reset(new BeamRoot(nullptr, -1));
386   beam_root_->RootEntry()->newp.total = 0.0;  // ln(1)
387   beam_root_->RootEntry()->newp.blank = 0.0;  // ln(1)
388 
389   // Add the root as the initial leaf.
390   leaves_.push(beam_root_->RootEntry());
391 
392   // Call initialize state on the root object.
393   beam_scorer_->InitializeState(&beam_root_->RootEntry()->state);
394 }
395 
396 template <typename CTCBeamState, typename CTCBeamComparer>
TopPaths(int n,std::vector<std::vector<int>> * paths,std::vector<float> * log_probs,bool merge_repeated)397 bool CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::TopPaths(
398     int n, std::vector<std::vector<int>>* paths, std::vector<float>* log_probs,
399     bool merge_repeated) const {
400   TFLITE_DCHECK(paths);
401   TFLITE_DCHECK(log_probs);
402   paths->clear();
403   log_probs->clear();
404   if (n > beam_width_) {
405     return false;
406   }
407   if (n > leaves_.size()) {
408     return false;
409   }
410 
411   gtl::TopN<BeamEntry*, CTCBeamComparer> top_branches(n);
412 
413   // O(beam_width_ * log(n)), space complexity is O(n)
414   for (auto it = leaves_.unsorted_begin(); it != leaves_.unsorted_end(); ++it) {
415     top_branches.push(*it);
416   }
417   // O(n * log(n))
418   std::unique_ptr<std::vector<BeamEntry*>> branches(top_branches.Extract());
419 
420   for (int i = 0; i < n; ++i) {
421     BeamEntry* e((*branches)[i]);
422     paths->push_back(e->LabelSeq(merge_repeated));
423     log_probs->push_back(e->newp.total);
424   }
425   return true;
426 }
427 
428 }  // namespace ctc
429 }  // namespace experimental
430 }  // namespace tflite
431 
432 #endif  // TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_
433