1 // synchronize.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16 //
17 // \file
18 // Synchronize an FST with bounded delay.
19
20 #ifndef FST_LIB_SYNCHRONIZE_H__
21 #define FST_LIB_SYNCHRONIZE_H__
22
#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "fst/lib/cache.h"
#include "fst/lib/test-properties.h"
30
31 namespace fst {
32
33 typedef CacheOptions SynchronizeFstOptions;
34
35
36 // Implementation class for SynchronizeFst
37 template <class A>
38 class SynchronizeFstImpl
39 : public CacheImpl<A> {
40 public:
41 using FstImpl<A>::SetType;
42 using FstImpl<A>::SetProperties;
43 using FstImpl<A>::Properties;
44 using FstImpl<A>::SetInputSymbols;
45 using FstImpl<A>::SetOutputSymbols;
46
47 using CacheBaseImpl< CacheState<A> >::HasStart;
48 using CacheBaseImpl< CacheState<A> >::HasFinal;
49 using CacheBaseImpl< CacheState<A> >::HasArcs;
50
51 typedef A Arc;
52 typedef typename A::Label Label;
53 typedef typename A::Weight Weight;
54 typedef typename A::StateId StateId;
55
56 typedef basic_string<Label> String;
57
58 struct Element {
ElementElement59 Element() {}
60
ElementElement61 Element(StateId s, const String *i, const String *o)
62 : state(s), istring(i), ostring(o) {}
63
64 StateId state; // Input state Id
65 const String *istring; // Residual input labels
66 const String *ostring; // Residual output labels
67 // Residual strings are represented by const pointers to
68 // basic_string<Label> and are stored in a hash_set. The pointed
69 // memory is owned by the hash_set string_set_.
70 };
71
SynchronizeFstImpl(const Fst<A> & fst,const SynchronizeFstOptions & opts)72 SynchronizeFstImpl(const Fst<A> &fst, const SynchronizeFstOptions &opts)
73 : CacheImpl<A>(opts), fst_(fst.Copy()) {
74 SetType("synchronize");
75 uint64 props = fst.Properties(kFstProperties, false);
76 SetProperties(SynchronizeProperties(props), kCopyProperties);
77
78 SetInputSymbols(fst.InputSymbols());
79 SetOutputSymbols(fst.OutputSymbols());
80 }
81
~SynchronizeFstImpl()82 ~SynchronizeFstImpl() {
83 delete fst_;
84 // Extract pointers from the hash set
85 vector<const String*> strings;
86 typename StringSet::iterator it = string_set_.begin();
87 for (; it != string_set_.end(); ++it)
88 strings.push_back(*it);
89 // Free the extracted pointers
90 for (size_t i = 0; i < strings.size(); ++i)
91 delete strings[i];
92 }
93
Start()94 StateId Start() {
95 if (!HasStart()) {
96 StateId s = fst_->Start();
97 if (s == kNoStateId)
98 return kNoStateId;
99 const String *empty = FindString(new String());
100 StateId start = FindState(Element(fst_->Start(), empty, empty));
101 SetStart(start);
102 }
103 return CacheImpl<A>::Start();
104 }
105
Final(StateId s)106 Weight Final(StateId s) {
107 if (!HasFinal(s)) {
108 const Element &e = elements_[s];
109 Weight w = e.state == kNoStateId ? Weight::One() : fst_->Final(e.state);
110 if ((w != Weight::Zero()) && (e.istring)->empty() && (e.ostring)->empty())
111 SetFinal(s, w);
112 else
113 SetFinal(s, Weight::Zero());
114 }
115 return CacheImpl<A>::Final(s);
116 }
117
NumArcs(StateId s)118 size_t NumArcs(StateId s) {
119 if (!HasArcs(s))
120 Expand(s);
121 return CacheImpl<A>::NumArcs(s);
122 }
123
NumInputEpsilons(StateId s)124 size_t NumInputEpsilons(StateId s) {
125 if (!HasArcs(s))
126 Expand(s);
127 return CacheImpl<A>::NumInputEpsilons(s);
128 }
129
NumOutputEpsilons(StateId s)130 size_t NumOutputEpsilons(StateId s) {
131 if (!HasArcs(s))
132 Expand(s);
133 return CacheImpl<A>::NumOutputEpsilons(s);
134 }
135
InitArcIterator(StateId s,ArcIteratorData<A> * data)136 void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
137 if (!HasArcs(s))
138 Expand(s);
139 CacheImpl<A>::InitArcIterator(s, data);
140 }
141
142 // Returns the first character of the string obtained by
143 // concatenating s and l.
144 Label Car(const String *s, Label l = 0) const {
145 if (!s->empty())
146 return (*s)[0];
147 else
148 return l;
149 }
150
151 // Computes the residual string obtained by removing the first
152 // character in the concatenation of s and l.
153 const String *Cdr(const String *s, Label l = 0) {
154 String *r = new String();
155 for (int i = 1; i < s->size(); ++i)
156 r->push_back((*s)[i]);
157 if (l && !(s->empty())) r->push_back(l);
158 return FindString(r);
159 }
160
161 // Computes the concatenation of s and l.
162 const String *Concat(const String *s, Label l = 0) {
163 String *r = new String();
164 for (int i = 0; i < s->size(); ++i)
165 r->push_back((*s)[i]);
166 if (l) r->push_back(l);
167 return FindString(r);
168 }
169
170 // Tests if the concatenation of s and l is empty
171 bool Empty(const String *s, Label l = 0) const {
172 if (s->empty())
173 return l == 0;
174 else
175 return false;
176 }
177
178 // Finds the string pointed by s in the hash set. Transfers the
179 // pointer ownership to the hash set.
FindString(const String * s)180 const String *FindString(const String *s) {
181 typename StringSet::iterator it = string_set_.find(s);
182 if (it != string_set_.end()) {
183 delete s;
184 return (*it);
185 } else {
186 string_set_.insert(s);
187 return s;
188 }
189 }
190
191 // Finds state corresponding to an element. Creates new state
192 // if element not found.
FindState(const Element & e)193 StateId FindState(const Element &e) {
194 typename ElementMap::iterator eit = element_map_.find(e);
195 if (eit != element_map_.end()) {
196 return (*eit).second;
197 } else {
198 StateId s = elements_.size();
199 elements_.push_back(e);
200 element_map_.insert(pair<const Element, StateId>(e, s));
201 return s;
202 }
203 }
204
205
206 // Computes the outgoing transitions from a state, creating new destination
207 // states as needed.
Expand(StateId s)208 void Expand(StateId s) {
209 Element e = elements_[s];
210
211 if (e.state != kNoStateId)
212 for (ArcIterator< Fst<A> > ait(*fst_, e.state);
213 !ait.Done();
214 ait.Next()) {
215 const A &arc = ait.Value();
216 if (!Empty(e.istring, arc.ilabel) && !Empty(e.ostring, arc.olabel)) {
217 const String *istring = Cdr(e.istring, arc.ilabel);
218 const String *ostring = Cdr(e.ostring, arc.olabel);
219 StateId d = FindState(Element(arc.nextstate, istring, ostring));
220 AddArc(s, Arc(Car(e.istring, arc.ilabel),
221 Car(e.ostring, arc.olabel), arc.weight, d));
222 } else {
223 const String *istring = Concat(e.istring, arc.ilabel);
224 const String *ostring = Concat(e.ostring, arc.olabel);
225 StateId d = FindState(Element(arc.nextstate, istring, ostring));
226 AddArc(s, Arc(0 , 0, arc.weight, d));
227 }
228 }
229
230 Weight w = e.state == kNoStateId ? Weight::One() : fst_->Final(e.state);
231 if ((w != Weight::Zero()) &&
232 ((e.istring)->size() + (e.ostring)->size() > 0)) {
233 const String *istring = Cdr(e.istring);
234 const String *ostring = Cdr(e.ostring);
235 StateId d = FindState(Element(kNoStateId, istring, ostring));
236 AddArc(s, Arc(Car(e.istring), Car(e.ostring), w, d));
237 }
238 SetArcs(s);
239 }
240
241 private:
242 // Equality function for Elements, assume strings have been hashed.
243 class ElementEqual {
244 public:
operator()245 bool operator()(const Element &x, const Element &y) const {
246 return x.state == y.state &&
247 x.istring == y.istring &&
248 x.ostring == y.ostring;
249 }
250 };
251
252 // Hash function for Elements to Fst states.
253 class ElementKey {
254 public:
operator()255 size_t operator()(const Element &x) const {
256 size_t key = x.state;
257 key = (key << 1) ^ (x.istring)->size();
258 for (size_t i = 0; i < (x.istring)->size(); ++i)
259 key = (key << 1) ^ (*x.istring)[i];
260 key = (key << 1) ^ (x.ostring)->size();
261 for (size_t i = 0; i < (x.ostring)->size(); ++i)
262 key = (key << 1) ^ (*x.ostring)[i];
263 return key;
264 }
265 };
266
267 // Equality function for strings
268 class StringEqual {
269 public:
operator()270 bool operator()(const String * const &x, const String * const &y) const {
271 if (x->size() != y->size()) return false;
272 for (size_t i = 0; i < x->size(); ++i)
273 if ((*x)[i] != (*y)[i]) return false;
274 return true;
275 }
276 };
277
278 // Hash function for set of strings
279 class StringKey{
280 public:
operator()281 size_t operator()(const String * const & x) const {
282 size_t key = x->size();
283 for (size_t i = 0; i < x->size(); ++i)
284 key = (key << 1) ^ (*x)[i];
285 return key;
286 }
287 };
288
289
290 typedef std::unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
291 typedef std::unordered_set<const String*, StringKey, StringEqual> StringSet;
292
293 const Fst<A> *fst_;
294 vector<Element> elements_; // mapping Fst state to Elements
295 ElementMap element_map_; // mapping Elements to Fst state
296 StringSet string_set_;
297
298 DISALLOW_EVIL_CONSTRUCTORS(SynchronizeFstImpl);
299 };
300
301
302 // Synchronizes a transducer. This version is a delayed Fst. The
303 // result will be an equivalent FST that has the property that during
304 // the traversal of a path, the delay is either zero or strictly
305 // increasing, where the delay is the difference between the number of
306 // non-epsilon output labels and input labels along the path.
307 //
308 // For the algorithm to terminate, the input transducer must have
309 // bounded delay, i.e., the delay of every cycle must be zero.
310 //
311 // Complexity:
312 // - A has bounded delay: exponential
313 // - A does not have bounded delay: does not terminate
314 //
315 // References:
316 // - Mehryar Mohri. Edit-Distance of Weighted Automata: General
317 // Definitions and Algorithms, International Journal of Computer
318 // Science, 14(6): 957-982 (2003).
319 template <class A>
320 class SynchronizeFst : public Fst<A> {
321 public:
322 friend class ArcIterator< SynchronizeFst<A> >;
323 friend class CacheStateIterator< SynchronizeFst<A> >;
324 friend class CacheArcIterator< SynchronizeFst<A> >;
325
326 typedef A Arc;
327 typedef typename A::Weight Weight;
328 typedef typename A::StateId StateId;
329 typedef CacheState<A> State;
330
SynchronizeFst(const Fst<A> & fst)331 SynchronizeFst(const Fst<A> &fst)
332 : impl_(new SynchronizeFstImpl<A>(fst, SynchronizeFstOptions())) {}
333
SynchronizeFst(const Fst<A> & fst,const SynchronizeFstOptions & opts)334 SynchronizeFst(const Fst<A> &fst, const SynchronizeFstOptions &opts)
335 : impl_(new SynchronizeFstImpl<A>(fst, opts)) {}
336
SynchronizeFst(const SynchronizeFst<A> & fst)337 SynchronizeFst(const SynchronizeFst<A> &fst) : impl_(fst.impl_) {
338 impl_->IncrRefCount();
339 }
340
~SynchronizeFst()341 virtual ~SynchronizeFst() { if (!impl_->DecrRefCount()) delete impl_; }
342
Start()343 virtual StateId Start() const { return impl_->Start(); }
344
Final(StateId s)345 virtual Weight Final(StateId s) const { return impl_->Final(s); }
346
NumArcs(StateId s)347 virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
348
NumInputEpsilons(StateId s)349 virtual size_t NumInputEpsilons(StateId s) const {
350 return impl_->NumInputEpsilons(s);
351 }
352
NumOutputEpsilons(StateId s)353 virtual size_t NumOutputEpsilons(StateId s) const {
354 return impl_->NumOutputEpsilons(s);
355 }
356
Properties(uint64 mask,bool test)357 virtual uint64 Properties(uint64 mask, bool test) const {
358 if (test) {
359 uint64 known, test = TestProperties(*this, mask, &known);
360 impl_->SetProperties(test, known);
361 return test & mask;
362 } else {
363 return impl_->Properties(mask);
364 }
365 }
366
Type()367 virtual const string& Type() const { return impl_->Type(); }
368
Copy()369 virtual SynchronizeFst<A> *Copy() const {
370 return new SynchronizeFst<A>(*this);
371 }
372
InputSymbols()373 virtual const SymbolTable* InputSymbols() const {
374 return impl_->InputSymbols();
375 }
376
OutputSymbols()377 virtual const SymbolTable* OutputSymbols() const {
378 return impl_->OutputSymbols();
379 }
380
381 virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
382
InitArcIterator(StateId s,ArcIteratorData<A> * data)383 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
384 impl_->InitArcIterator(s, data);
385 }
386
387 private:
Impl()388 SynchronizeFstImpl<A> *Impl() { return impl_; }
389
390 SynchronizeFstImpl<A> *impl_;
391
392 void operator=(const SynchronizeFst<A> &fst); // Disallow
393 };
394
395
396 // Specialization for SynchronizeFst.
397 template<class A>
398 class StateIterator< SynchronizeFst<A> >
399 : public CacheStateIterator< SynchronizeFst<A> > {
400 public:
StateIterator(const SynchronizeFst<A> & fst)401 explicit StateIterator(const SynchronizeFst<A> &fst)
402 : CacheStateIterator< SynchronizeFst<A> >(fst) {}
403 };
404
405
406 // Specialization for SynchronizeFst.
407 template <class A>
408 class ArcIterator< SynchronizeFst<A> >
409 : public CacheArcIterator< SynchronizeFst<A> > {
410 public:
411 typedef typename A::StateId StateId;
412
ArcIterator(const SynchronizeFst<A> & fst,StateId s)413 ArcIterator(const SynchronizeFst<A> &fst, StateId s)
414 : CacheArcIterator< SynchronizeFst<A> >(fst, s) {
415 if (!fst.impl_->HasArcs(s))
416 fst.impl_->Expand(s);
417 }
418
419 private:
420 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
421 };
422
423
424 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)425 void SynchronizeFst<A>::InitStateIterator(StateIteratorData<A> *data) const
426 {
427 data->base = new StateIterator< SynchronizeFst<A> >(*this);
428 }
429
430
431
432 // Synchronizes a transducer. This version writes the synchronized
433 // result to a MutableFst. The result will be an equivalent FST that
434 // has the property that during the traversal of a path, the delay is
435 // either zero or strictly increasing, where the delay is the
436 // difference between the number of non-epsilon output labels and
437 // input labels along the path.
438 //
439 // For the algorithm to terminate, the input transducer must have
440 // bounded delay, i.e., the delay of every cycle must be zero.
441 //
442 // Complexity:
443 // - A has bounded delay: exponential
444 // - A does not have bounded delay: does not terminate
445 //
446 // References:
447 // - Mehryar Mohri. Edit-Distance of Weighted Automata: General
448 // Definitions and Algorithms, International Journal of Computer
449 // Science, 14(6): 957-982 (2003).
450 template<class Arc>
Synchronize(const Fst<Arc> & ifst,MutableFst<Arc> * ofst)451 void Synchronize(const Fst<Arc> &ifst, MutableFst<Arc> *ofst) {
452 SynchronizeFstOptions opts;
453 opts.gc_limit = 0; // Cache only the last state for fastest copy.
454 *ofst = SynchronizeFst<Arc>(ifst, opts);
455 }
456
457 } // namespace fst
458
459 #endif // FST_LIB_SYNCHRONIZE_H__
460