1 // rational.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file
17 // An Fst implementation and base interface for delayed unions,
18 // concatenations and closures.
19
20 #ifndef FST_LIB_RATIONAL_H__
21 #define FST_LIB_RATIONAL_H__
22
23 #include "fst/lib/map.h"
24 #include "fst/lib/mutable-fst.h"
25 #include "fst/lib/replace.h"
26 #include "fst/lib/test-properties.h"
27
28 namespace fst {
29
30 typedef CacheOptions RationalFstOptions;
31
32 // This specifies whether to add the empty string.
33 enum ClosureType { CLOSURE_STAR = 0, // T* -> add the empty string
34 CLOSURE_PLUS = 1 }; // T+ -> don't add the empty string
35
36 template <class A> class RationalFst;
37 template <class A> void Union(RationalFst<A> *fst1, const Fst<A> &fst2);
38 template <class A> void Concat(RationalFst<A> *fst1, const Fst<A> &fst2);
39 template <class A> void Closure(RationalFst<A> *fst, ClosureType closure_type);
40
41
42 // Implementation class for delayed unions, concatenations and closures.
43 template<class A>
44 class RationalFstImpl : public ReplaceFstImpl<A> {
45 public:
46 using FstImpl<A>::SetType;
47 using FstImpl<A>::SetProperties;
48 using FstImpl<A>::Properties;
49 using FstImpl<A>::SetInputSymbols;
50 using FstImpl<A>::SetOutputSymbols;
51 using ReplaceFstImpl<A>::SetRoot;
52
53 typedef typename A::Weight Weight;
54 typedef typename A::Label Label;
55
RationalFstImpl(const RationalFstOptions & opts)56 explicit RationalFstImpl(const RationalFstOptions &opts)
57 : ReplaceFstImpl<A>(ReplaceFstOptions(opts, kNoLabel)),
58 nonterminals_(0) {
59 SetType("rational");
60 }
61
62 // Implementation of UnionFst(fst1,fst2)
InitUnion(const Fst<A> & fst1,const Fst<A> & fst2)63 void InitUnion(const Fst<A> &fst1, const Fst<A> &fst2) {
64 uint64 props1 = fst1.Properties(kFstProperties, false);
65 uint64 props2 = fst2.Properties(kFstProperties, false);
66 SetInputSymbols(fst1.InputSymbols());
67 SetOutputSymbols(fst1.OutputSymbols());
68 rfst_.AddState();
69 rfst_.AddState();
70 rfst_.SetStart(0);
71 rfst_.SetFinal(1, Weight::One());
72 rfst_.SetInputSymbols(fst1.InputSymbols());
73 rfst_.SetOutputSymbols(fst1.OutputSymbols());
74 nonterminals_ = 2;
75 rfst_.AddArc(0, A(0, -1, Weight::One(), 1));
76 rfst_.AddArc(0, A(0, -2, Weight::One(), 1));
77 AddFst(0, &rfst_);
78 AddFst(-1, &fst1);
79 AddFst(-2, &fst2);
80 SetRoot(0);
81 SetProperties(UnionProperties(props1, props2, true), kCopyProperties);
82 }
83
84 // Implementation of ConcatFst(fst1,fst2)
InitConcat(const Fst<A> & fst1,const Fst<A> & fst2)85 void InitConcat(const Fst<A> &fst1, const Fst<A> &fst2) {
86 uint64 props1 = fst1.Properties(kFstProperties, false);
87 uint64 props2 = fst2.Properties(kFstProperties, false);
88 SetInputSymbols(fst1.InputSymbols());
89 SetOutputSymbols(fst1.OutputSymbols());
90 rfst_.AddState();
91 rfst_.AddState();
92 rfst_.AddState();
93 rfst_.SetStart(0);
94 rfst_.SetFinal(2, Weight::One());
95 rfst_.SetInputSymbols(fst1.InputSymbols());
96 rfst_.SetOutputSymbols(fst1.OutputSymbols());
97 nonterminals_ = 2;
98 rfst_.AddArc(0, A(0, -1, Weight::One(), 1));
99 rfst_.AddArc(1, A(0, -2, Weight::One(), 2));
100 AddFst(0, &rfst_);
101 AddFst(-1, &fst1);
102 AddFst(-2, &fst2);
103 SetRoot(0);
104 SetProperties(ConcatProperties(props1, props2, true), kCopyProperties);
105 }
106
107 // Implementation of ClosureFst(fst, closure_type)
InitClosure(const Fst<A> & fst,ClosureType closure_type)108 void InitClosure(const Fst<A> &fst, ClosureType closure_type) {
109 uint64 props = fst.Properties(kFstProperties, false);
110 SetInputSymbols(fst.InputSymbols());
111 SetOutputSymbols(fst.OutputSymbols());
112 if (closure_type == CLOSURE_STAR) {
113 rfst_.AddState();
114 rfst_.SetStart(0);
115 rfst_.SetFinal(0, Weight::One());
116 rfst_.AddArc(0, A(0, -1, Weight::One(), 0));
117 } else {
118 rfst_.AddState();
119 rfst_.AddState();
120 rfst_.SetStart(0);
121 rfst_.SetFinal(1, Weight::One());
122 rfst_.AddArc(0, A(0, -1, Weight::One(), 1));
123 rfst_.AddArc(1, A(0, 0, Weight::One(), 0));
124 }
125 rfst_.SetInputSymbols(fst.InputSymbols());
126 rfst_.SetOutputSymbols(fst.OutputSymbols());
127 AddFst(0, &rfst_);
128 AddFst(-1, &fst);
129 SetRoot(0);
130 nonterminals_ = 1;
131 SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR, true),
132 kCopyProperties);
133 }
134
135 // Implementation of Union(Fst &, RationalFst *)
AddUnion(const Fst<A> & fst)136 void AddUnion(const Fst<A> &fst) {
137 uint64 props1 = Properties();
138 uint64 props2 = fst.Properties(kFstProperties, false);
139 VectorFst<A> afst;
140 afst.AddState();
141 afst.AddState();
142 afst.SetStart(0);
143 afst.SetFinal(1, Weight::One());
144 afst.AddArc(0, A(0, -nonterminals_, Weight::One(), 1));
145 Union(&rfst_, afst);
146 SetFst(0, &rfst_);
147 ++nonterminals_;
148 SetProperties(UnionProperties(props1, props2, true), kCopyProperties);
149 }
150
151 // Implementation of Concat(Fst &, RationalFst *)
AddConcat(const Fst<A> & fst)152 void AddConcat(const Fst<A> &fst) {
153 uint64 props1 = Properties();
154 uint64 props2 = fst.Properties(kFstProperties, false);
155 VectorFst<A> afst;
156 afst.AddState();
157 afst.AddState();
158 afst.SetStart(0);
159 afst.SetFinal(1, Weight::One());
160 afst.AddArc(0, A(0, -nonterminals_, Weight::One(), 1));
161 Concat(&rfst_, afst);
162 SetFst(0, &rfst_);
163 ++nonterminals_;
164 SetProperties(ConcatProperties(props1, props2, true), kCopyProperties);
165 }
166
167 // Implementation of Closure(RationalFst *, closure_type)
AddClosure(ClosureType closure_type)168 void AddClosure(ClosureType closure_type) {
169 uint64 props = Properties();
170 Closure(&rfst_, closure_type);
171 SetFst(0, &rfst_);
172 SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR, true),
173 kCopyProperties);
174 }
175
176 private:
177 VectorFst<A> rfst_; // rational topology machine; uses neg. nonterminals
178 Label nonterminals_; // # of nonterminals used
179
180 DISALLOW_EVIL_CONSTRUCTORS(RationalFstImpl);
181 };
182
183 // Parent class for the delayed rational operations - delayed union,
184 // concatenation, and closure. This class attaches interface to
185 // implementation and handles reference counting.
186 template <class A>
187 class RationalFst : public Fst<A> {
188 public:
189 friend class CacheStateIterator< RationalFst<A> >;
190 friend class ArcIterator< RationalFst<A> >;
191 friend class CacheArcIterator< RationalFst<A> >;
192 friend void Union<>(RationalFst<A> *fst1, const Fst<A> &fst2);
193 friend void Concat<>(RationalFst<A> *fst1, const Fst<A> &fst2);
194 friend void Closure<>(RationalFst<A> *fst, ClosureType closure_type);
195
196 typedef A Arc;
197 typedef typename A::Weight Weight;
198 typedef typename A::StateId StateId;
199 typedef CacheState<A> State;
200
Start()201 virtual StateId Start() const { return impl_->Start(); }
Final(StateId s)202 virtual Weight Final(StateId s) const { return impl_->Final(s); }
NumArcs(StateId s)203 virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
NumInputEpsilons(StateId s)204 virtual size_t NumInputEpsilons(StateId s) const {
205 return impl_->NumInputEpsilons(s);
206 }
NumOutputEpsilons(StateId s)207 virtual size_t NumOutputEpsilons(StateId s) const {
208 return impl_->NumOutputEpsilons(s);
209 }
Properties(uint64 mask,bool test)210 virtual uint64 Properties(uint64 mask, bool test) const {
211 if (test) {
212 uint64 known, test = TestProperties(*this, mask, &known);
213 impl_->SetProperties(test, known);
214 return test & mask;
215 } else {
216 return impl_->Properties(mask);
217 }
218 }
Type()219 virtual const string& Type() const { return impl_->Type(); }
InputSymbols()220 virtual const SymbolTable* InputSymbols() const {
221 return impl_->InputSymbols();
222 }
OutputSymbols()223 virtual const SymbolTable* OutputSymbols() const {
224 return impl_->OutputSymbols();
225 }
226
227 virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
228
InitArcIterator(StateId s,ArcIteratorData<A> * data)229 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
230 impl_->InitArcIterator(s, data);
231 }
232
233 protected:
RationalFst()234 RationalFst() : impl_(new RationalFstImpl<A>(RationalFstOptions())) {}
RationalFst(const RationalFstOptions & opts)235 explicit RationalFst(const RationalFstOptions &opts)
236 : impl_(new RationalFstImpl<A>(opts)) {}
237
238
RationalFst(const RationalFst<A> & fst)239 RationalFst(const RationalFst<A> &fst) : impl_(fst.impl_) {
240 impl_->IncrRefCount();
241 }
242
~RationalFst()243 virtual ~RationalFst() { if (!impl_->DecrRefCount()) delete impl_; }
244
Impl()245 RationalFstImpl<A> *Impl() { return impl_; }
246
247 private:
248 RationalFstImpl<A> *impl_;
249
250 void operator=(const RationalFst<A> &fst); // disallow
251 };
252
253 // Specialization for RationalFst.
254 template <class A>
255 class StateIterator< RationalFst<A> >
256 : public CacheStateIterator< RationalFst<A> > {
257 public:
StateIterator(const RationalFst<A> & fst)258 explicit StateIterator(const RationalFst<A> &fst)
259 : CacheStateIterator< RationalFst<A> >(fst) {}
260 };
261
262 // Specialization for RationalFst.
263 template <class A>
264 class ArcIterator< RationalFst<A> >
265 : public CacheArcIterator< RationalFst<A> > {
266 public:
267 typedef typename A::StateId StateId;
268
ArcIterator(const RationalFst<A> & fst,StateId s)269 ArcIterator(const RationalFst<A> &fst, StateId s)
270 : CacheArcIterator< RationalFst<A> >(fst, s) {
271 if (!fst.impl_->HasArcs(s))
272 fst.impl_->Expand(s);
273 }
274
275 private:
276 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
277 };
278
279 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)280 void RationalFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
281 data->base = new StateIterator< RationalFst<A> >(*this);
282 }
283
284 } // namespace fst
285
286 #endif // FST_LIB_RATIONAL_H__
287