• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // arcsort.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file
17 // Functions and classes to sort arcs in an FST.
18 
19 #ifndef FST_LIB_ARCSORT_H__
20 #define FST_LIB_ARCSORT_H__
21 
22 #include <algorithm>
23 
24 #include "fst/lib/cache.h"
25 #include "fst/lib/test-properties.h"
26 
27 namespace fst {
28 
29 // Sorts the arcs in an FST according to function object 'comp' of
30 // type Compare. This version modifies its input.  Comparison function
31 // objects IlabelCompare and OlabelCompare are provived by the
32 // library. In general, Compare must meet the requirements for an STL
33 // sort comparision function object. It must also have a member
34 // Properties(uint64) that specifies the known properties of the
35 // sorted FST; it takes as argument the input FST's known properties
36 // before the sort.
37 //
38 // Complexity:
39 // - Time: O(V + D log D)
40 // - Space: O(D)
41 // where V = # of states and D = maximum out-degree.
42 template<class Arc, class Compare>
ArcSort(MutableFst<Arc> * fst,Compare comp)43 void ArcSort(MutableFst<Arc> *fst, Compare comp) {
44   typedef typename Arc::StateId StateId;
45 
46   uint64 props = fst->Properties(kFstProperties, false);
47 
48   vector<Arc> arcs;
49   for (StateIterator< MutableFst<Arc> > siter(*fst);
50        !siter.Done();
51        siter.Next()) {
52     StateId s = siter.Value();
53     arcs.clear();
54     for (ArcIterator< MutableFst<Arc> > aiter(*fst, s);
55          !aiter.Done();
56          aiter.Next())
57       arcs.push_back(aiter.Value());
58     sort(arcs.begin(), arcs.end(), comp);
59     fst->DeleteArcs(s);
60     for (size_t a = 0; a < arcs.size(); ++a)
61       fst->AddArc(s, arcs[a]);
62   }
63 
64   fst->SetProperties(comp.Properties(props), kFstProperties);
65 }
66 
67 typedef CacheOptions ArcSortFstOptions;
68 
69 // Implementation of delayed ArcSortFst.
70 template<class A, class C>
71 class ArcSortFstImpl : public CacheImpl<A> {
72  public:
73   using FstImpl<A>::SetType;
74   using FstImpl<A>::SetProperties;
75   using FstImpl<A>::Properties;
76   using FstImpl<A>::SetInputSymbols;
77   using FstImpl<A>::SetOutputSymbols;
78   using FstImpl<A>::InputSymbols;
79   using FstImpl<A>::OutputSymbols;
80 
81   using VectorFstBaseImpl<typename CacheImpl<A>::State>::NumStates;
82 
83   using CacheImpl<A>::HasArcs;
84   using CacheImpl<A>::HasFinal;
85   using CacheImpl<A>::HasStart;
86 
87   typedef typename A::Weight Weight;
88   typedef typename A::StateId StateId;
89 
ArcSortFstImpl(const Fst<A> & fst,const C & comp,const ArcSortFstOptions & opts)90   ArcSortFstImpl(const Fst<A> &fst, const C &comp,
91                  const ArcSortFstOptions &opts)
92       : CacheImpl<A>(opts), fst_(fst.Copy()), comp_(comp) {
93     SetType("arcsort");
94     uint64 props = fst_->Properties(kCopyProperties, false);
95     SetProperties(comp_.Properties(props));
96     SetInputSymbols(fst.InputSymbols());
97     SetOutputSymbols(fst.OutputSymbols());
98   }
99 
ArcSortFstImpl(const ArcSortFstImpl & impl)100   ArcSortFstImpl(const ArcSortFstImpl& impl)
101       : fst_(impl.fst_->Copy()), comp_(impl.comp_) {
102     SetType("arcsort");
103     SetProperties(impl.Properties(), kCopyProperties);
104     SetInputSymbols(impl.InputSymbols());
105     SetOutputSymbols(impl.OutputSymbols());
106   }
107 
~ArcSortFstImpl()108   ~ArcSortFstImpl() { delete fst_; }
109 
Start()110   StateId Start() {
111     if (!HasStart())
112       SetStart(fst_->Start());
113     return CacheImpl<A>::Start();
114   }
115 
Final(StateId s)116   Weight Final(StateId s) {
117     if (!HasFinal(s))
118       SetFinal(s, fst_->Final(s));
119     return CacheImpl<A>::Final(s);
120   }
121 
NumArcs(StateId s)122   size_t NumArcs(StateId s) {
123     if (!HasArcs(s))
124       Expand(s);
125     return CacheImpl<A>::NumArcs(s);
126   }
127 
NumInputEpsilons(StateId s)128   size_t NumInputEpsilons(StateId s) {
129     if (!HasArcs(s))
130       Expand(s);
131     return CacheImpl<A>::NumInputEpsilons(s);
132   }
133 
NumOutputEpsilons(StateId s)134   size_t NumOutputEpsilons(StateId s) {
135     if (!HasArcs(s))
136       Expand(s);
137     return CacheImpl<A>::NumOutputEpsilons(s);
138   }
139 
InitStateIterator(StateIteratorData<A> * data)140   void InitStateIterator(StateIteratorData<A> *data) const {
141     fst_->InitStateIterator(data);
142   }
143 
InitArcIterator(StateId s,ArcIteratorData<A> * data)144   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
145     if (!HasArcs(s))
146       Expand(s);
147     CacheImpl<A>::InitArcIterator(s, data);
148   }
149 
Expand(StateId s)150   void Expand(StateId s) {
151     for (ArcIterator< Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next())
152       AddArc(s, aiter.Value());
153     SetArcs(s);
154 
155     if (s < NumStates()) {  // ensure state exists
156       vector<A> &carcs = GetState(s)->arcs;
157       sort(carcs.begin(), carcs.end(), comp_);
158     }
159   }
160 
161  private:
162   const Fst<A> *fst_;
163   C comp_;
164 
165   void operator=(const ArcSortFstImpl<A, C> &impl);  // Disallow
166 };
167 
168 
169 // Sorts the arcs in an FST according to function object 'comp' of
170 // type Compare. This version is a delayed Fst.  Comparsion function
171 // objects IlabelCompare and OlabelCompare are provided by the
172 // library. In general, Compare must meet the requirements for an STL
173 // comparision function object (e.g. as used for STL sort). It must
174 // also have a member Properties(uint64) that specifies the known
175 // properties of the sorted FST; it takes as argument the input FST's
176 // known properties.
177 //
178 // Complexity:
179 // - Time: O(v + d log d)
180 // - Space: O(v + d)
181 // where v = # of states visited, d = maximum out-degree of states
182 // visited. Constant time and space to visit an input state is assumed
183 // and exclusive of caching.
184 template <class A, class C>
185 class ArcSortFst : public Fst<A> {
186  public:
187   friend class CacheArcIterator< ArcSortFst<A, C> >;
188   friend class ArcIterator< ArcSortFst<A, C> >;
189 
190   typedef A Arc;
191   typedef C Compare;
192   typedef typename A::Weight Weight;
193   typedef typename A::StateId StateId;
194   typedef CacheState<A> State;
195 
ArcSortFst(const Fst<A> & fst,const C & comp)196   ArcSortFst(const Fst<A> &fst, const C &comp)
197       : impl_(new ArcSortFstImpl<A, C>(fst, comp, ArcSortFstOptions())) {}
198 
ArcSortFst(const Fst<A> & fst,const C & comp,const ArcSortFstOptions & opts)199   ArcSortFst(const Fst<A> &fst, const C &comp, const ArcSortFstOptions &opts)
200       : impl_(new ArcSortFstImpl<A, C>(fst, comp, opts)) {}
201 
ArcSortFst(const ArcSortFst<A,C> & fst)202   ArcSortFst(const ArcSortFst<A, C> &fst) :
203       impl_(new ArcSortFstImpl<A, C>(*(fst.impl_))) {}
204 
~ArcSortFst()205   virtual ~ArcSortFst() { if (!impl_->DecrRefCount()) delete impl_; }
206 
Start()207   virtual StateId Start() const { return impl_->Start(); }
208 
Final(StateId s)209   virtual Weight Final(StateId s) const { return impl_->Final(s); }
210 
NumArcs(StateId s)211   virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
212 
NumInputEpsilons(StateId s)213   virtual size_t NumInputEpsilons(StateId s) const {
214     return impl_->NumInputEpsilons(s);
215   }
216 
NumOutputEpsilons(StateId s)217   virtual size_t NumOutputEpsilons(StateId s) const {
218     return impl_->NumOutputEpsilons(s);
219   }
220 
Properties(uint64 mask,bool test)221   virtual uint64 Properties(uint64 mask, bool test) const {
222     if (test) {
223       uint64 known, test = TestProperties(*this, mask, &known);
224       impl_->SetProperties(test, known);
225       return test & mask;
226     } else {
227       return impl_->Properties(mask);
228     }
229   }
230 
Type()231   virtual const string& Type() const { return impl_->Type(); }
232 
Copy()233   virtual ArcSortFst<A, C> *Copy() const {
234     return new ArcSortFst<A, C>(*this);
235   }
236 
InputSymbols()237   virtual const SymbolTable* InputSymbols() const {
238     return impl_->InputSymbols();
239   }
240 
OutputSymbols()241   virtual const SymbolTable* OutputSymbols() const {
242     return impl_->OutputSymbols();
243   }
244 
InitStateIterator(StateIteratorData<A> * data)245   virtual void InitStateIterator(StateIteratorData<A> *data) const {
246     impl_->InitStateIterator(data);
247   }
248 
InitArcIterator(StateId s,ArcIteratorData<A> * data)249   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
250     impl_->InitArcIterator(s, data);
251   }
252 
253  private:
254   ArcSortFstImpl<A, C> *impl_;
255 
256   void operator=(const ArcSortFst<A, C> &fst);  // Disallow
257 };
258 
259 
260 // Specialization for ArcSortFst.
261 template <class A, class C>
262 class ArcIterator< ArcSortFst<A, C> >
263     : public CacheArcIterator< ArcSortFst<A, C> > {
264  public:
265   typedef typename A::StateId StateId;
266 
ArcIterator(const ArcSortFst<A,C> & fst,StateId s)267   ArcIterator(const ArcSortFst<A, C> &fst, StateId s)
268       : CacheArcIterator< ArcSortFst<A, C> >(fst, s) {
269     if (!fst.impl_->HasArcs(s))
270       fst.impl_->Expand(s);
271   }
272 
273  private:
274   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
275 };
276 
277 
278 // Compare class for comparing input labels of arcs.
279 template<class A> class ILabelCompare {
280  public:
operator()281   bool operator() (A arc1, A arc2) const {
282     return arc1.ilabel < arc2.ilabel;
283   }
284 
Properties(uint64 props)285   uint64 Properties(uint64 props) const {
286     return (props & kArcSortProperties) | kILabelSorted;
287   }
288 };
289 
290 
291 // Compare class for comparing output labels of arcs.
292 template<class A> class OLabelCompare {
293  public:
operator()294   bool operator() (const A &arc1, const A &arc2) const {
295     return arc1.olabel < arc2.olabel;
296   }
297 
Properties(uint64 props)298   uint64 Properties(uint64 props) const {
299     return (props & kArcSortProperties) | kOLabelSorted;
300   }
301 };
302 
303 
304 // Useful aliases when using StdArc.
305 template<class C> class StdArcSortFst : public ArcSortFst<StdArc, C> {
306  public:
307   typedef StdArc Arc;
308   typedef C Compare;
309 };
310 
311 typedef ILabelCompare<StdArc> StdILabelCompare;
312 
313 typedef OLabelCompare<StdArc> StdOLabelCompare;
314 
315 }  // namespace fst
316 
317 #endif  // FST_LIB_ARCSORT_H__
318