• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // LINT.IfChange
16 
17 #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
18 #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
19 
20 #include <algorithm>
21 #include <memory>
22 #include <vector>
23 
24 #include "third_party/eigen3/Eigen/Core"
25 #include "tensorflow/core/lib/gtl/flatmap.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/macros.h"
28 #include "tensorflow/core/platform/types.h"
29 #include "tensorflow/core/util/ctc/ctc_loss_util.h"
30 
31 namespace tensorflow {
32 namespace ctc {
33 
// The ctc_beam_search namespace holds several classes meant to be used only
// when extending the CTCBeamSearch decoder with custom scoring functions.
//
// BeamEntry is exposed through the BeamScorer and BeamComparer template
// arguments of CTCBeamSearch (ctc_beam_search.h).
40 namespace ctc_beam_search {
41 
// Default CTCBeamState template argument: carries no per-beam data.
struct EmptyBeamState {
};
43 
44 struct BeamProbability {
BeamProbabilityBeamProbability45   BeamProbability() : total(kLogZero), blank(kLogZero), label(kLogZero) {}
ResetBeamProbability46   void Reset() {
47     total = kLogZero;
48     blank = kLogZero;
49     label = kLogZero;
50   }
51   float total;
52   float blank;
53   float label;
54 };
55 
// Forward declaration: BeamEntry below stores a BeamRoot pointer and befriends
// BeamRoot<CTCBeamState>::AddEntry() before BeamRoot is defined.
template <typename CTCBeamState>
class BeamRoot;
58 
59 template <class CTCBeamState = EmptyBeamState>
60 struct BeamEntry {
61   // BeamRoot<CTCBeamState>::AddEntry() serves as the factory method.
62   friend BeamEntry<CTCBeamState>* BeamRoot<CTCBeamState>::AddEntry(
63       BeamEntry<CTCBeamState>* p, int l);
ActiveBeamEntry64   inline bool Active() const { return newp.total != kLogZero; }
65   // Return the child at the given index, or construct a new one in-place if
66   // none was found.
GetChildBeamEntry67   BeamEntry& GetChild(int ind) {
68     auto entry = children.emplace(ind, nullptr);
69     auto& child_entry = entry.first->second;
70     // If this is a new child, populate the BeamEntry<CTCBeamState>*.
71     if (entry.second) {
72       child_entry = beam_root->AddEntry(this, ind);
73     }
74     return *child_entry;
75   }
LabelSeqBeamEntry76   std::vector<int> LabelSeq(bool merge_repeated) const {
77     std::vector<int> labels;
78     int prev_label = -1;
79     const BeamEntry* c = this;
80     while (c->parent != nullptr) {  // Checking c->parent to skip root leaf.
81       if (!merge_repeated || c->label != prev_label) {
82         labels.push_back(c->label);
83       }
84       prev_label = c->label;
85       c = c->parent;
86     }
87     std::reverse(labels.begin(), labels.end());
88     return labels;
89   }
90 
91   BeamEntry<CTCBeamState>* parent;
92   int label;
93   // All instances of child BeamEntry are owned by *beam_root.
94   gtl::FlatMap<int, BeamEntry<CTCBeamState>*> children;
95   BeamProbability oldp;
96   BeamProbability newp;
97   CTCBeamState state;
98 
99  private:
100   // Constructor giving parent, label, and the beam_root.
101   // The object pointed to by p cannot be copied and should not be moved,
102   // otherwise parent will become invalid.
103   // This private constructor is only called through the factory method
104   // BeamRoot<CTCBeamState>::AddEntry().
BeamEntryBeamEntry105   BeamEntry(BeamEntry* p, int l, BeamRoot<CTCBeamState>* beam_root)
106       : parent(p), label(l), beam_root(beam_root) {}
107   BeamRoot<CTCBeamState>* beam_root;
108   TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry);
109 };
110 
111 // This class owns all instances of BeamEntry.  This is used to avoid recursive
112 // destructor call during destruction.
113 template <class CTCBeamState = EmptyBeamState>
114 class BeamRoot {
115  public:
BeamRoot(BeamEntry<CTCBeamState> * p,int l)116   BeamRoot(BeamEntry<CTCBeamState>* p, int l) { root_entry_ = AddEntry(p, l); }
117   BeamRoot(const BeamRoot&) = delete;
118   BeamRoot& operator=(const BeamRoot&) = delete;
119 
AddEntry(BeamEntry<CTCBeamState> * p,int l)120   BeamEntry<CTCBeamState>* AddEntry(BeamEntry<CTCBeamState>* p, int l) {
121     auto* new_entry = new BeamEntry<CTCBeamState>(p, l, this);
122     beam_entries_.emplace_back(new_entry);
123     return new_entry;
124   }
RootEntry()125   BeamEntry<CTCBeamState>* RootEntry() const { return root_entry_; }
126 
127  private:
128   BeamEntry<CTCBeamState>* root_entry_ = nullptr;
129   std::vector<std::unique_ptr<BeamEntry<CTCBeamState>>> beam_entries_;
130 };
131 
132 // BeamComparer is the default beam comparer provided in CTCBeamSearch.
133 template <class CTCBeamState = EmptyBeamState>
134 class BeamComparer {
135  public:
~BeamComparer()136   virtual ~BeamComparer() {}
operator()137   virtual bool inline operator()(const BeamEntry<CTCBeamState>* a,
138                                  const BeamEntry<CTCBeamState>* b) const {
139     return a->newp.total > b->newp.total;
140   }
141 };
142 
143 }  // namespace ctc_beam_search
144 
145 }  // namespace ctc
146 }  // namespace tensorflow
147 
148 #endif  // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
149 // LINT.ThenChange(//tensorflow/lite/experimental/kernels/ctc_beam_entry.h)
150