1 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 // 14 // Copyright 2005-2010 Google, Inc. 15 // Author: jpr@google.com (Jake Ratkiewicz) 16 17 #ifndef FST_SCRIPT_FST_CLASS_H_ 18 #define FST_SCRIPT_FST_CLASS_H_ 19 20 #include <string> 21 22 #include <fst/fst.h> 23 #include <fst/mutable-fst.h> 24 #include <fst/vector-fst.h> 25 #include <iostream> 26 #include <fstream> 27 28 // Classes to support "boxing" all existing types of FST arcs in a single 29 // FstClass which hides the arc types. This allows clients to load 30 // and work with FSTs without knowing the arc type. 31 32 // These classes are only recommended for use in high-level scripting 33 // applications. Most users should use the lower-level templated versions 34 // corresponding to these classes. 35 36 namespace fst { 37 namespace script { 38 39 // 40 // Abstract base class defining the set of functionalities implemented 41 // in all impls, and passed through by all bases Below FstClassBase 42 // the class hierarchy bifurcates; FstClassImplBase serves as the base 43 // class for all implementations (of which FstClassImpl is currently 44 // the only one) and FstClass serves as the base class for all 45 // interfaces. 46 // 47 class FstClassBase { 48 public: 49 virtual const string &ArcType() const = 0; 50 virtual const string &FstType() const = 0; 51 virtual const string &WeightType() const = 0; 52 virtual const SymbolTable *InputSymbols() const = 0; 53 virtual const SymbolTable *OutputSymbols() const = 0; 54 virtual void Write(const string& fname) const = 0; 55 virtual uint64 Properties(uint64 mask, bool test) const = 0; ~FstClassBase()56 virtual ~FstClassBase() { } 57 }; 58 59 class FstClassImplBase : public FstClassBase { 60 public: 61 virtual FstClassImplBase *Copy() = 0; 62 virtual void SetInputSymbols(SymbolTable *is) = 0; 63 virtual void SetOutputSymbols(SymbolTable *is) = 0; ~FstClassImplBase()64 virtual ~FstClassImplBase() { } 65 }; 66 67 68 // 69 // CONTAINER CLASS 70 // Wraps an Fst<Arc>, hiding its arc type. Whether this Fst<Arc> 71 // pointer refers to a special kind of FST (e.g. a MutableFst) is 72 // known by the type of interface class that owns the pointer to this 73 // container. 74 // 75 76 template<class Arc> 77 class FstClassImpl : public FstClassImplBase { 78 public: 79 explicit FstClassImpl(Fst<Arc> *impl, 80 bool should_own = false) : 81 impl_(should_own ? impl : impl->Copy()) { } 82 ArcType()83 virtual const string &ArcType() const { 84 return Arc::Type(); 85 } 86 FstType()87 virtual const string &FstType() const { 88 return impl_->Type(); 89 } 90 WeightType()91 virtual const string &WeightType() const { 92 return Arc::Weight::Type(); 93 } 94 InputSymbols()95 virtual const SymbolTable *InputSymbols() const { 96 return impl_->InputSymbols(); 97 } 98 OutputSymbols()99 virtual const SymbolTable *OutputSymbols() const { 100 return impl_->OutputSymbols(); 101 } 102 103 // Warning: calling this method casts the FST to a mutable FST. SetInputSymbols(SymbolTable * is)104 virtual void SetInputSymbols(SymbolTable *is) { 105 static_cast<MutableFst<Arc> *>(impl_)->SetInputSymbols(is); 106 } 107 108 // Warning: calling this method casts the FST to a mutable FST. SetOutputSymbols(SymbolTable * os)109 virtual void SetOutputSymbols(SymbolTable *os) { 110 static_cast<MutableFst<Arc> *>(impl_)->SetOutputSymbols(os); 111 } 112 Write(const string & fname)113 virtual void Write(const string &fname) const { 114 impl_->Write(fname); 115 } 116 Properties(uint64 mask,bool test)117 virtual uint64 Properties(uint64 mask, bool test) const { 118 return impl_->Properties(mask, test); 119 } 120 ~FstClassImpl()121 virtual ~FstClassImpl() { delete impl_; } 122 GetImpl()123 Fst<Arc> *GetImpl() { return impl_; } 124 Copy()125 virtual FstClassImpl *Copy() { 126 return new FstClassImpl<Arc>(impl_); 127 } 128 129 private: 130 Fst<Arc> *impl_; 131 }; 132 133 // 134 // BASE CLASS DEFINITIONS 135 // 136 137 class MutableFstClass; 138 139 class FstClass : public FstClassBase { 140 public: 141 template<class Arc> Read(istream & stream,const FstReadOptions & opts)142 static FstClass *Read(istream &stream, 143 const FstReadOptions &opts) { 144 if (!opts.header) { 145 FSTERROR() << "FstClass::Read: options header not specified"; 146 return 0; 147 } 148 const FstHeader &hdr = *opts.header; 149 150 if (hdr.Properties() & kMutable) { 151 return ReadTypedFst<MutableFstClass, MutableFst<Arc> >(stream, opts); 152 } else { 153 return ReadTypedFst<FstClass, Fst<Arc> >(stream, opts); 154 } 155 } 156 157 template<class Arc> FstClass(Fst<Arc> * fst)158 explicit FstClass(Fst<Arc> *fst) : impl_(new FstClassImpl<Arc>(fst)) { } 159 FstClass(const FstClass & other)160 explicit FstClass(const FstClass &other) : impl_(other.impl_->Copy()) { } 161 162 static FstClass *Read(const string &fname); 163 ArcType()164 virtual const string &ArcType() const { 165 return impl_->ArcType(); 166 } 167 FstType()168 virtual const string& FstType() const { 169 return impl_->FstType(); 170 } 171 InputSymbols()172 virtual const SymbolTable *InputSymbols() const { 173 return impl_->InputSymbols(); 174 } 175 OutputSymbols()176 virtual const SymbolTable *OutputSymbols() const { 177 return impl_->OutputSymbols(); 178 } 179 WeightType()180 virtual const string& WeightType() const { 181 return impl_->WeightType(); 182 } 183 Write(const string & fname)184 virtual void Write(const string &fname) const { 185 impl_->Write(fname); 186 } 187 Properties(uint64 mask,bool test)188 virtual uint64 Properties(uint64 mask, bool test) const { 189 return impl_->Properties(mask, test); 190 } 191 192 template<class Arc> GetFst()193 const Fst<Arc> *GetFst() const { 194 if (Arc::Type() != ArcType()) { 195 return NULL; 196 } else { 197 FstClassImpl<Arc> *typed_impl = static_cast<FstClassImpl<Arc> *>(impl_); 198 return typed_impl->GetImpl(); 199 } 200 } 201 ~FstClass()202 virtual ~FstClass() { delete impl_; } 203 204 // These methods are required by IO registration 205 template<class Arc> Convert(const FstClass & other)206 static FstClassImplBase *Convert(const FstClass &other) { 207 LOG(ERROR) << "Doesn't make sense to convert any class to type FstClass."; 208 return 0; 209 } 210 211 template<class Arc> Create()212 static FstClassImplBase *Create() { 213 LOG(ERROR) << "Doesn't make sense to create an FstClass with a " 214 << "particular arc type."; 215 return 0; 216 } 217 protected: FstClass(FstClassImplBase * impl)218 explicit FstClass(FstClassImplBase *impl) : impl_(impl) { } 219 220 // Generic template method for reading an arc-templated FST of type 221 // UnderlyingT, and returning it wrapped as FstClassT, with appropriate 222 // error checking. Called from arc-templated Read() static methods. 223 template<class FstClassT, class UnderlyingT> ReadTypedFst(istream & stream,const FstReadOptions & opts)224 static FstClassT* ReadTypedFst(istream &stream, 225 const FstReadOptions &opts) { 226 UnderlyingT *u = UnderlyingT::Read(stream, opts); 227 if (!u) { 228 return 0; 229 } else { 230 FstClassT *r = new FstClassT(u); 231 delete u; 232 return r; 233 } 234 } 235 GetImpl()236 FstClassImplBase *GetImpl() { return impl_; } 237 private: 238 FstClassImplBase *impl_; 239 }; 240 241 // 242 // Specific types of FstClass with special properties 243 // 244 245 class MutableFstClass : public FstClass { 246 public: 247 template<class Arc> MutableFstClass(MutableFst<Arc> * fst)248 explicit MutableFstClass(MutableFst<Arc> *fst) : 249 FstClass(fst) { } 250 251 template<class Arc> GetMutableFst()252 MutableFst<Arc> *GetMutableFst() { 253 Fst<Arc> *fst = const_cast<Fst<Arc> *>(this->GetFst<Arc>()); 254 MutableFst<Arc> *mfst = static_cast<MutableFst<Arc> *>(fst); 255 256 return mfst; 257 } 258 259 template<class Arc> Read(istream & stream,const FstReadOptions & opts)260 static MutableFstClass *Read(istream &stream, 261 const FstReadOptions &opts) { 262 MutableFst<Arc> *mfst = MutableFst<Arc>::Read(stream, opts); 263 if (!mfst) { 264 return 0; 265 } else { 266 MutableFstClass *retval = new MutableFstClass(mfst); 267 delete mfst; 268 return retval; 269 } 270 } 271 272 static MutableFstClass *Read(const string &fname, bool convert = false); 273 SetInputSymbols(SymbolTable * is)274 virtual void SetInputSymbols(SymbolTable *is) { 275 GetImpl()->SetInputSymbols(is); 276 } 277 SetOutputSymbols(SymbolTable * os)278 virtual void SetOutputSymbols(SymbolTable *os) { 279 GetImpl()->SetOutputSymbols(os); 280 } 281 282 // These methods are required by IO registration 283 template<class Arc> Convert(const FstClass & other)284 static FstClassImplBase *Convert(const FstClass &other) { 285 LOG(ERROR) << "Doesn't make sense to convert any class to type " 286 << "MutableFstClass."; 287 return 0; 288 } 289 290 template<class Arc> Create()291 static FstClassImplBase *Create() { 292 LOG(ERROR) << "Doesn't make sense to create a MutableFstClass with a " 293 << "particular arc type."; 294 return 0; 295 } 296 297 protected: MutableFstClass(FstClassImplBase * impl)298 explicit MutableFstClass(FstClassImplBase *impl) : FstClass(impl) { } 299 }; 300 301 302 class VectorFstClass : public MutableFstClass { 303 public: 304 explicit VectorFstClass(const FstClass &other); 305 explicit VectorFstClass(const string &arc_type); 306 307 template<class Arc> VectorFstClass(VectorFst<Arc> * fst)308 explicit VectorFstClass(VectorFst<Arc> *fst) : 309 MutableFstClass(fst) { } 310 311 template<class Arc> Read(istream & stream,const FstReadOptions & opts)312 static VectorFstClass *Read(istream &stream, 313 const FstReadOptions &opts) { 314 VectorFst<Arc> *vfst = VectorFst<Arc>::Read(stream, opts); 315 if (!vfst) { 316 return 0; 317 } else { 318 VectorFstClass *retval = new VectorFstClass(vfst); 319 delete vfst; 320 return retval; 321 } 322 } 323 324 static VectorFstClass *Read(const string &fname); 325 326 // Converter / creator for known arc types 327 template<class Arc> Convert(const FstClass & other)328 static FstClassImplBase *Convert(const FstClass &other) { 329 return new FstClassImpl<Arc>(new VectorFst<Arc>( 330 *other.GetFst<Arc>()), true); 331 } 332 333 template<class Arc> Create()334 static FstClassImplBase *Create() { 335 return new FstClassImpl<Arc>(new VectorFst<Arc>(), true); 336 } 337 }; 338 339 } // namespace script 340 } // namespace fst 341 342 343 #endif // FST_SCRIPT_FST_CLASS_H_ 344