1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // LINT.IfChange 16 17 #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_ 18 #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_ 19 20 #include <algorithm> 21 #include <memory> 22 #include <vector> 23 24 #include "third_party/eigen3/Eigen/Core" 25 #include "tensorflow/core/lib/gtl/flatmap.h" 26 #include "tensorflow/core/platform/logging.h" 27 #include "tensorflow/core/platform/macros.h" 28 #include "tensorflow/core/platform/types.h" 29 #include "tensorflow/core/util/ctc/ctc_loss_util.h" 30 31 namespace tensorflow { 32 namespace ctc { 33 34 // The ctc_beam_search namespace holds several classes meant to be accessed only 35 // in case of extending the CTCBeamSearch decoder to allow custom scoring 36 // functions. 37 // 38 // BeamEntry is exposed through template arguments BeamScorer and BeamComparer 39 // of CTCBeamSearch (ctc_beam_search.h). 40 namespace ctc_beam_search { 41 42 struct EmptyBeamState {}; 43 44 struct BeamProbability { BeamProbabilityBeamProbability45 BeamProbability() : total(kLogZero), blank(kLogZero), label(kLogZero) {} ResetBeamProbability46 void Reset() { 47 total = kLogZero; 48 blank = kLogZero; 49 label = kLogZero; 50 } 51 float total; 52 float blank; 53 float label; 54 }; 55 56 template <class CTCBeamState> 57 class BeamRoot; 58 59 template <class CTCBeamState = EmptyBeamState> 60 struct BeamEntry { 61 // BeamRoot<CTCBeamState>::AddEntry() serves as the factory method. 62 friend BeamEntry<CTCBeamState>* BeamRoot<CTCBeamState>::AddEntry( 63 BeamEntry<CTCBeamState>* p, int l); ActiveBeamEntry64 inline bool Active() const { return newp.total != kLogZero; } 65 // Return the child at the given index, or construct a new one in-place if 66 // none was found. GetChildBeamEntry67 BeamEntry& GetChild(int ind) { 68 auto entry = children.emplace(ind, nullptr); 69 auto& child_entry = entry.first->second; 70 // If this is a new child, populate the BeamEntry<CTCBeamState>*. 71 if (entry.second) { 72 child_entry = beam_root->AddEntry(this, ind); 73 } 74 return *child_entry; 75 } LabelSeqBeamEntry76 std::vector<int> LabelSeq(bool merge_repeated) const { 77 std::vector<int> labels; 78 int prev_label = -1; 79 const BeamEntry* c = this; 80 while (c->parent != nullptr) { // Checking c->parent to skip root leaf. 81 if (!merge_repeated || c->label != prev_label) { 82 labels.push_back(c->label); 83 } 84 prev_label = c->label; 85 c = c->parent; 86 } 87 std::reverse(labels.begin(), labels.end()); 88 return labels; 89 } 90 91 BeamEntry<CTCBeamState>* parent; 92 int label; 93 // All instances of child BeamEntry are owned by *beam_root. 94 gtl::FlatMap<int, BeamEntry<CTCBeamState>*> children; 95 BeamProbability oldp; 96 BeamProbability newp; 97 CTCBeamState state; 98 99 private: 100 // Constructor giving parent, label, and the beam_root. 101 // The object pointed to by p cannot be copied and should not be moved, 102 // otherwise parent will become invalid. 103 // This private constructor is only called through the factory method 104 // BeamRoot<CTCBeamState>::AddEntry(). BeamEntryBeamEntry105 BeamEntry(BeamEntry* p, int l, BeamRoot<CTCBeamState>* beam_root) 106 : parent(p), label(l), beam_root(beam_root) {} 107 BeamRoot<CTCBeamState>* beam_root; 108 TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry); 109 }; 110 111 // This class owns all instances of BeamEntry. This is used to avoid recursive 112 // destructor call during destruction. 113 template <class CTCBeamState = EmptyBeamState> 114 class BeamRoot { 115 public: BeamRoot(BeamEntry<CTCBeamState> * p,int l)116 BeamRoot(BeamEntry<CTCBeamState>* p, int l) { root_entry_ = AddEntry(p, l); } 117 BeamRoot(const BeamRoot&) = delete; 118 BeamRoot& operator=(const BeamRoot&) = delete; 119 AddEntry(BeamEntry<CTCBeamState> * p,int l)120 BeamEntry<CTCBeamState>* AddEntry(BeamEntry<CTCBeamState>* p, int l) { 121 auto* new_entry = new BeamEntry<CTCBeamState>(p, l, this); 122 beam_entries_.emplace_back(new_entry); 123 return new_entry; 124 } RootEntry()125 BeamEntry<CTCBeamState>* RootEntry() const { return root_entry_; } 126 127 private: 128 BeamEntry<CTCBeamState>* root_entry_ = nullptr; 129 std::vector<std::unique_ptr<BeamEntry<CTCBeamState>>> beam_entries_; 130 }; 131 132 // BeamComparer is the default beam comparer provided in CTCBeamSearch. 133 template <class CTCBeamState = EmptyBeamState> 134 class BeamComparer { 135 public: ~BeamComparer()136 virtual ~BeamComparer() {} operator()137 virtual bool inline operator()(const BeamEntry<CTCBeamState>* a, 138 const BeamEntry<CTCBeamState>* b) const { 139 return a->newp.total > b->newp.total; 140 } 141 }; 142 143 } // namespace ctc_beam_search 144 145 } // namespace ctc 146 } // namespace tensorflow 147 148 #endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_ 149 // LINT.ThenChange(//tensorflow/lite/experimental/kernels/ctc_beam_entry.h) 150