1 // arcsort.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file
17 // Functions and classes to sort arcs in an FST.
18
19 #ifndef FST_LIB_ARCSORT_H__
20 #define FST_LIB_ARCSORT_H__
21
22 #include <algorithm>
23
24 #include "fst/lib/cache.h"
25 #include "fst/lib/test-properties.h"
26
27 namespace fst {
28
29 // Sorts the arcs in an FST according to function object 'comp' of
30 // type Compare. This version modifies its input. Comparison function
31 // objects IlabelCompare and OlabelCompare are provived by the
32 // library. In general, Compare must meet the requirements for an STL
33 // sort comparision function object. It must also have a member
34 // Properties(uint64) that specifies the known properties of the
35 // sorted FST; it takes as argument the input FST's known properties
36 // before the sort.
37 //
38 // Complexity:
39 // - Time: O(V + D log D)
40 // - Space: O(D)
41 // where V = # of states and D = maximum out-degree.
42 template<class Arc, class Compare>
ArcSort(MutableFst<Arc> * fst,Compare comp)43 void ArcSort(MutableFst<Arc> *fst, Compare comp) {
44 typedef typename Arc::StateId StateId;
45
46 uint64 props = fst->Properties(kFstProperties, false);
47
48 vector<Arc> arcs;
49 for (StateIterator< MutableFst<Arc> > siter(*fst);
50 !siter.Done();
51 siter.Next()) {
52 StateId s = siter.Value();
53 arcs.clear();
54 for (ArcIterator< MutableFst<Arc> > aiter(*fst, s);
55 !aiter.Done();
56 aiter.Next())
57 arcs.push_back(aiter.Value());
58 sort(arcs.begin(), arcs.end(), comp);
59 fst->DeleteArcs(s);
60 for (size_t a = 0; a < arcs.size(); ++a)
61 fst->AddArc(s, arcs[a]);
62 }
63
64 fst->SetProperties(comp.Properties(props), kFstProperties);
65 }
66
67 typedef CacheOptions ArcSortFstOptions;
68
69 // Implementation of delayed ArcSortFst.
70 template<class A, class C>
71 class ArcSortFstImpl : public CacheImpl<A> {
72 public:
73 using FstImpl<A>::SetType;
74 using FstImpl<A>::SetProperties;
75 using FstImpl<A>::Properties;
76 using FstImpl<A>::SetInputSymbols;
77 using FstImpl<A>::SetOutputSymbols;
78 using FstImpl<A>::InputSymbols;
79 using FstImpl<A>::OutputSymbols;
80
81 using VectorFstBaseImpl<typename CacheImpl<A>::State>::NumStates;
82
83 using CacheImpl<A>::HasArcs;
84 using CacheImpl<A>::HasFinal;
85 using CacheImpl<A>::HasStart;
86
87 typedef typename A::Weight Weight;
88 typedef typename A::StateId StateId;
89
ArcSortFstImpl(const Fst<A> & fst,const C & comp,const ArcSortFstOptions & opts)90 ArcSortFstImpl(const Fst<A> &fst, const C &comp,
91 const ArcSortFstOptions &opts)
92 : CacheImpl<A>(opts), fst_(fst.Copy()), comp_(comp) {
93 SetType("arcsort");
94 uint64 props = fst_->Properties(kCopyProperties, false);
95 SetProperties(comp_.Properties(props));
96 SetInputSymbols(fst.InputSymbols());
97 SetOutputSymbols(fst.OutputSymbols());
98 }
99
ArcSortFstImpl(const ArcSortFstImpl & impl)100 ArcSortFstImpl(const ArcSortFstImpl& impl)
101 : fst_(impl.fst_->Copy()), comp_(impl.comp_) {
102 SetType("arcsort");
103 SetProperties(impl.Properties(), kCopyProperties);
104 SetInputSymbols(impl.InputSymbols());
105 SetOutputSymbols(impl.OutputSymbols());
106 }
107
~ArcSortFstImpl()108 ~ArcSortFstImpl() { delete fst_; }
109
Start()110 StateId Start() {
111 if (!HasStart())
112 SetStart(fst_->Start());
113 return CacheImpl<A>::Start();
114 }
115
Final(StateId s)116 Weight Final(StateId s) {
117 if (!HasFinal(s))
118 SetFinal(s, fst_->Final(s));
119 return CacheImpl<A>::Final(s);
120 }
121
NumArcs(StateId s)122 size_t NumArcs(StateId s) {
123 if (!HasArcs(s))
124 Expand(s);
125 return CacheImpl<A>::NumArcs(s);
126 }
127
NumInputEpsilons(StateId s)128 size_t NumInputEpsilons(StateId s) {
129 if (!HasArcs(s))
130 Expand(s);
131 return CacheImpl<A>::NumInputEpsilons(s);
132 }
133
NumOutputEpsilons(StateId s)134 size_t NumOutputEpsilons(StateId s) {
135 if (!HasArcs(s))
136 Expand(s);
137 return CacheImpl<A>::NumOutputEpsilons(s);
138 }
139
InitStateIterator(StateIteratorData<A> * data)140 void InitStateIterator(StateIteratorData<A> *data) const {
141 fst_->InitStateIterator(data);
142 }
143
InitArcIterator(StateId s,ArcIteratorData<A> * data)144 void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
145 if (!HasArcs(s))
146 Expand(s);
147 CacheImpl<A>::InitArcIterator(s, data);
148 }
149
Expand(StateId s)150 void Expand(StateId s) {
151 for (ArcIterator< Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next())
152 AddArc(s, aiter.Value());
153 SetArcs(s);
154
155 if (s < NumStates()) { // ensure state exists
156 vector<A> &carcs = GetState(s)->arcs;
157 sort(carcs.begin(), carcs.end(), comp_);
158 }
159 }
160
161 private:
162 const Fst<A> *fst_;
163 C comp_;
164
165 void operator=(const ArcSortFstImpl<A, C> &impl); // Disallow
166 };
167
168
169 // Sorts the arcs in an FST according to function object 'comp' of
170 // type Compare. This version is a delayed Fst. Comparsion function
171 // objects IlabelCompare and OlabelCompare are provided by the
172 // library. In general, Compare must meet the requirements for an STL
173 // comparision function object (e.g. as used for STL sort). It must
174 // also have a member Properties(uint64) that specifies the known
175 // properties of the sorted FST; it takes as argument the input FST's
176 // known properties.
177 //
178 // Complexity:
179 // - Time: O(v + d log d)
180 // - Space: O(v + d)
181 // where v = # of states visited, d = maximum out-degree of states
182 // visited. Constant time and space to visit an input state is assumed
183 // and exclusive of caching.
184 template <class A, class C>
185 class ArcSortFst : public Fst<A> {
186 public:
187 friend class CacheArcIterator< ArcSortFst<A, C> >;
188 friend class ArcIterator< ArcSortFst<A, C> >;
189
190 typedef A Arc;
191 typedef C Compare;
192 typedef typename A::Weight Weight;
193 typedef typename A::StateId StateId;
194 typedef CacheState<A> State;
195
ArcSortFst(const Fst<A> & fst,const C & comp)196 ArcSortFst(const Fst<A> &fst, const C &comp)
197 : impl_(new ArcSortFstImpl<A, C>(fst, comp, ArcSortFstOptions())) {}
198
ArcSortFst(const Fst<A> & fst,const C & comp,const ArcSortFstOptions & opts)199 ArcSortFst(const Fst<A> &fst, const C &comp, const ArcSortFstOptions &opts)
200 : impl_(new ArcSortFstImpl<A, C>(fst, comp, opts)) {}
201
ArcSortFst(const ArcSortFst<A,C> & fst)202 ArcSortFst(const ArcSortFst<A, C> &fst) :
203 impl_(new ArcSortFstImpl<A, C>(*(fst.impl_))) {}
204
~ArcSortFst()205 virtual ~ArcSortFst() { if (!impl_->DecrRefCount()) delete impl_; }
206
Start()207 virtual StateId Start() const { return impl_->Start(); }
208
Final(StateId s)209 virtual Weight Final(StateId s) const { return impl_->Final(s); }
210
NumArcs(StateId s)211 virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
212
NumInputEpsilons(StateId s)213 virtual size_t NumInputEpsilons(StateId s) const {
214 return impl_->NumInputEpsilons(s);
215 }
216
NumOutputEpsilons(StateId s)217 virtual size_t NumOutputEpsilons(StateId s) const {
218 return impl_->NumOutputEpsilons(s);
219 }
220
Properties(uint64 mask,bool test)221 virtual uint64 Properties(uint64 mask, bool test) const {
222 if (test) {
223 uint64 known, test = TestProperties(*this, mask, &known);
224 impl_->SetProperties(test, known);
225 return test & mask;
226 } else {
227 return impl_->Properties(mask);
228 }
229 }
230
Type()231 virtual const string& Type() const { return impl_->Type(); }
232
Copy()233 virtual ArcSortFst<A, C> *Copy() const {
234 return new ArcSortFst<A, C>(*this);
235 }
236
InputSymbols()237 virtual const SymbolTable* InputSymbols() const {
238 return impl_->InputSymbols();
239 }
240
OutputSymbols()241 virtual const SymbolTable* OutputSymbols() const {
242 return impl_->OutputSymbols();
243 }
244
InitStateIterator(StateIteratorData<A> * data)245 virtual void InitStateIterator(StateIteratorData<A> *data) const {
246 impl_->InitStateIterator(data);
247 }
248
InitArcIterator(StateId s,ArcIteratorData<A> * data)249 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
250 impl_->InitArcIterator(s, data);
251 }
252
253 private:
254 ArcSortFstImpl<A, C> *impl_;
255
256 void operator=(const ArcSortFst<A, C> &fst); // Disallow
257 };
258
259
260 // Specialization for ArcSortFst.
261 template <class A, class C>
262 class ArcIterator< ArcSortFst<A, C> >
263 : public CacheArcIterator< ArcSortFst<A, C> > {
264 public:
265 typedef typename A::StateId StateId;
266
ArcIterator(const ArcSortFst<A,C> & fst,StateId s)267 ArcIterator(const ArcSortFst<A, C> &fst, StateId s)
268 : CacheArcIterator< ArcSortFst<A, C> >(fst, s) {
269 if (!fst.impl_->HasArcs(s))
270 fst.impl_->Expand(s);
271 }
272
273 private:
274 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
275 };
276
277
278 // Compare class for comparing input labels of arcs.
279 template<class A> class ILabelCompare {
280 public:
operator()281 bool operator() (A arc1, A arc2) const {
282 return arc1.ilabel < arc2.ilabel;
283 }
284
Properties(uint64 props)285 uint64 Properties(uint64 props) const {
286 return (props & kArcSortProperties) | kILabelSorted;
287 }
288 };
289
290
291 // Compare class for comparing output labels of arcs.
292 template<class A> class OLabelCompare {
293 public:
operator()294 bool operator() (const A &arc1, const A &arc2) const {
295 return arc1.olabel < arc2.olabel;
296 }
297
Properties(uint64 props)298 uint64 Properties(uint64 props) const {
299 return (props & kArcSortProperties) | kOLabelSorted;
300 }
301 };
302
303
304 // Useful aliases when using StdArc.
305 template<class C> class StdArcSortFst : public ArcSortFst<StdArc, C> {
306 public:
307 typedef StdArc Arc;
308 typedef C Compare;
309 };
310
311 typedef ILabelCompare<StdArc> StdILabelCompare;
312
313 typedef OLabelCompare<StdArc> StdOLabelCompare;
314
315 } // namespace fst
316
317 #endif // FST_LIB_ARCSORT_H__
318