1 // connect.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Classes and functions to remove unsuccessful paths from an Fst.
20
21 #ifndef FST_LIB_CONNECT_H__
22 #define FST_LIB_CONNECT_H__
23
24 #include <vector>
25 using std::vector;
26
27 #include <fst/dfs-visit.h>
28 #include <fst/union-find.h>
29 #include <fst/mutable-fst.h>
30
31
32 namespace fst {
33
34 // Finds and returns connected components. Use with Visit().
35 template <class A>
36 class CcVisitor {
37 public:
38 typedef A Arc;
39 typedef typename Arc::Weight Weight;
40 typedef typename A::StateId StateId;
41
42 // cc[i]: connected component number for state i.
CcVisitor(vector<StateId> * cc)43 CcVisitor(vector<StateId> *cc)
44 : comps_(new UnionFind<StateId>(0, kNoStateId)),
45 cc_(cc),
46 nstates_(0) { }
47
48 // comps: connected components equiv classes.
CcVisitor(UnionFind<StateId> * comps)49 CcVisitor(UnionFind<StateId> *comps)
50 : comps_(comps),
51 cc_(0),
52 nstates_(0) { }
53
~CcVisitor()54 ~CcVisitor() {
55 if (cc_) // own comps_?
56 delete comps_;
57 }
58
InitVisit(const Fst<A> & fst)59 void InitVisit(const Fst<A> &fst) { }
60
InitState(StateId s,StateId root)61 bool InitState(StateId s, StateId root) {
62 ++nstates_;
63 if (comps_->FindSet(s) == kNoStateId)
64 comps_->MakeSet(s);
65 return true;
66 }
67
WhiteArc(StateId s,const A & arc)68 bool WhiteArc(StateId s, const A &arc) {
69 comps_->MakeSet(arc.nextstate);
70 comps_->Union(s, arc.nextstate);
71 return true;
72 }
73
GreyArc(StateId s,const A & arc)74 bool GreyArc(StateId s, const A &arc) {
75 comps_->Union(s, arc.nextstate);
76 return true;
77 }
78
BlackArc(StateId s,const A & arc)79 bool BlackArc(StateId s, const A &arc) {
80 comps_->Union(s, arc.nextstate);
81 return true;
82 }
83
FinishState(StateId s)84 void FinishState(StateId s) { }
85
FinishVisit()86 void FinishVisit() {
87 if (cc_)
88 GetCcVector(cc_);
89 }
90
91 // cc[i]: connected component number for state i.
92 // Returns number of components.
GetCcVector(vector<StateId> * cc)93 int GetCcVector(vector<StateId> *cc) {
94 cc->clear();
95 cc->resize(nstates_, kNoStateId);
96 StateId ncomp = 0;
97 for (StateId i = 0; i < nstates_; ++i) {
98 StateId rep = comps_->FindSet(i);
99 StateId &comp = (*cc)[rep];
100 if (comp == kNoStateId) {
101 comp = ncomp;
102 ++ncomp;
103 }
104 (*cc)[i] = comp;
105 }
106 return ncomp;
107 }
108
109 private:
110 UnionFind<StateId> *comps_; // Components
111 vector<StateId> *cc_; // State's cc number
112 StateId nstates_; // State count
113 };
114
115
116 // Finds and returns strongly-connected components, accessible and
117 // coaccessible states and related properties. Uses Tarjan's single
118 // DFS SCC algorithm (see Aho, et al, "Design and Analysis of Computer
119 // Algorithms", 189pp). Use with DfsVisit();
120 template <class A>
121 class SccVisitor {
122 public:
123 typedef A Arc;
124 typedef typename A::Weight Weight;
125 typedef typename A::StateId StateId;
126
127 // scc[i]: strongly-connected component number for state i.
128 // SCC numbers will be in topological order for acyclic input.
129 // access[i]: accessibility of state i.
130 // coaccess[i]: coaccessibility of state i.
131 // Any of above can be NULL.
132 // props: related property bits (cyclicity, initial cyclicity,
133 // accessibility, coaccessibility) set/cleared (o.w. unchanged).
SccVisitor(vector<StateId> * scc,vector<bool> * access,vector<bool> * coaccess,uint64 * props)134 SccVisitor(vector<StateId> *scc, vector<bool> *access,
135 vector<bool> *coaccess, uint64 *props)
136 : scc_(scc), access_(access), coaccess_(coaccess), props_(props) {}
SccVisitor(uint64 * props)137 SccVisitor(uint64 *props)
138 : scc_(0), access_(0), coaccess_(0), props_(props) {}
139
140 void InitVisit(const Fst<A> &fst);
141
142 bool InitState(StateId s, StateId root);
143
TreeArc(StateId s,const A & arc)144 bool TreeArc(StateId s, const A &arc) { return true; }
145
BackArc(StateId s,const A & arc)146 bool BackArc(StateId s, const A &arc) {
147 StateId t = arc.nextstate;
148 if ((*dfnumber_)[t] < (*lowlink_)[s])
149 (*lowlink_)[s] = (*dfnumber_)[t];
150 if ((*coaccess_)[t])
151 (*coaccess_)[s] = true;
152 *props_ |= kCyclic;
153 *props_ &= ~kAcyclic;
154 if (arc.nextstate == start_) {
155 *props_ |= kInitialCyclic;
156 *props_ &= ~kInitialAcyclic;
157 }
158 return true;
159 }
160
ForwardOrCrossArc(StateId s,const A & arc)161 bool ForwardOrCrossArc(StateId s, const A &arc) {
162 StateId t = arc.nextstate;
163 if ((*dfnumber_)[t] < (*dfnumber_)[s] /* cross edge */ &&
164 (*onstack_)[t] && (*dfnumber_)[t] < (*lowlink_)[s])
165 (*lowlink_)[s] = (*dfnumber_)[t];
166 if ((*coaccess_)[t])
167 (*coaccess_)[s] = true;
168 return true;
169 }
170
171 void FinishState(StateId s, StateId p, const A *);
172
FinishVisit()173 void FinishVisit() {
174 // Numbers SCC's in topological order when acyclic.
175 if (scc_)
176 for (StateId i = 0; i < scc_->size(); ++i)
177 (*scc_)[i] = nscc_ - 1 - (*scc_)[i];
178 if (coaccess_internal_)
179 delete coaccess_;
180 delete dfnumber_;
181 delete lowlink_;
182 delete onstack_;
183 delete scc_stack_;
184 }
185
186 private:
187 vector<StateId> *scc_; // State's scc number
188 vector<bool> *access_; // State's accessibility
189 vector<bool> *coaccess_; // State's coaccessibility
190 uint64 *props_;
191 const Fst<A> *fst_;
192 StateId start_;
193 StateId nstates_; // State count
194 StateId nscc_; // SCC count
195 bool coaccess_internal_;
196 vector<StateId> *dfnumber_; // state discovery times
197 vector<StateId> *lowlink_; // lowlink[s] == dfnumber[s] => SCC root
198 vector<bool> *onstack_; // is a state on the SCC stack
199 vector<StateId> *scc_stack_; // SCC stack (w/ random access)
200 };
201
202 template <class A> inline
InitVisit(const Fst<A> & fst)203 void SccVisitor<A>::InitVisit(const Fst<A> &fst) {
204 if (scc_)
205 scc_->clear();
206 if (access_)
207 access_->clear();
208 if (coaccess_) {
209 coaccess_->clear();
210 coaccess_internal_ = false;
211 } else {
212 coaccess_ = new vector<bool>;
213 coaccess_internal_ = true;
214 }
215 *props_ |= kAcyclic | kInitialAcyclic | kAccessible | kCoAccessible;
216 *props_ &= ~(kCyclic | kInitialCyclic | kNotAccessible | kNotCoAccessible);
217 fst_ = &fst;
218 start_ = fst.Start();
219 nstates_ = 0;
220 nscc_ = 0;
221 dfnumber_ = new vector<StateId>;
222 lowlink_ = new vector<StateId>;
223 onstack_ = new vector<bool>;
224 scc_stack_ = new vector<StateId>;
225 }
226
227 template <class A> inline
InitState(StateId s,StateId root)228 bool SccVisitor<A>::InitState(StateId s, StateId root) {
229 scc_stack_->push_back(s);
230 while (dfnumber_->size() <= s) {
231 if (scc_)
232 scc_->push_back(-1);
233 if (access_)
234 access_->push_back(false);
235 coaccess_->push_back(false);
236 dfnumber_->push_back(-1);
237 lowlink_->push_back(-1);
238 onstack_->push_back(false);
239 }
240 (*dfnumber_)[s] = nstates_;
241 (*lowlink_)[s] = nstates_;
242 (*onstack_)[s] = true;
243 if (root == start_) {
244 if (access_)
245 (*access_)[s] = true;
246 } else {
247 if (access_)
248 (*access_)[s] = false;
249 *props_ |= kNotAccessible;
250 *props_ &= ~kAccessible;
251 }
252 ++nstates_;
253 return true;
254 }
255
256 template <class A> inline
FinishState(StateId s,StateId p,const A *)257 void SccVisitor<A>::FinishState(StateId s, StateId p, const A *) {
258 if (fst_->Final(s) != Weight::Zero())
259 (*coaccess_)[s] = true;
260 if ((*dfnumber_)[s] == (*lowlink_)[s]) { // root of new SCC
261 bool scc_coaccess = false;
262 size_t i = scc_stack_->size();
263 StateId t;
264 do {
265 t = (*scc_stack_)[--i];
266 if ((*coaccess_)[t])
267 scc_coaccess = true;
268 } while (s != t);
269 do {
270 t = scc_stack_->back();
271 if (scc_)
272 (*scc_)[t] = nscc_;
273 if (scc_coaccess)
274 (*coaccess_)[t] = true;
275 (*onstack_)[t] = false;
276 scc_stack_->pop_back();
277 } while (s != t);
278 if (!scc_coaccess) {
279 *props_ |= kNotCoAccessible;
280 *props_ &= ~kCoAccessible;
281 }
282 ++nscc_;
283 }
284 if (p != kNoStateId) {
285 if ((*coaccess_)[s])
286 (*coaccess_)[p] = true;
287 if ((*lowlink_)[s] < (*lowlink_)[p])
288 (*lowlink_)[p] = (*lowlink_)[s];
289 }
290 }
291
292
293 // Trims an FST, removing states and arcs that are not on successful
294 // paths. This version modifies its input.
295 //
296 // Complexity:
297 // - Time: O(V + E)
298 // - Space: O(V + E)
299 // where V = # of states and E = # of arcs.
300 template<class Arc>
Connect(MutableFst<Arc> * fst)301 void Connect(MutableFst<Arc> *fst) {
302 typedef typename Arc::StateId StateId;
303
304 vector<bool> access;
305 vector<bool> coaccess;
306 uint64 props = 0;
307 SccVisitor<Arc> scc_visitor(0, &access, &coaccess, &props);
308 DfsVisit(*fst, &scc_visitor);
309 vector<StateId> dstates;
310 for (StateId s = 0; s < access.size(); ++s)
311 if (!access[s] || !coaccess[s])
312 dstates.push_back(s);
313 fst->DeleteStates(dstates);
314 fst->SetProperties(kAccessible | kCoAccessible, kAccessible | kCoAccessible);
315 }
316
317 } // namespace fst
318
319 #endif // FST_LIB_CONNECT_H__
320