• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 
2 // string.h
3 
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 // Copyright 2005-2010 Google, Inc.
17 // Author: allauzen@google.com (Cyril Allauzen)
18 //
19 // \file
20 // Utilities to convert strings into FSTs.
21 //
22 
23 #ifndef FST_LIB_STRING_H_
24 #define FST_LIB_STRING_H_
25 
26 #include <fst/compact-fst.h>
27 #include <fst/mutable-fst.h>
28 
29 DECLARE_string(fst_field_separator);
30 
31 namespace fst {
32 
33 // Functor compiling a string in an FST
34 template <class A>
35 class StringCompiler {
36  public:
37   typedef A Arc;
38   typedef typename A::Label Label;
39   typedef typename A::Weight Weight;
40 
41   enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
42 
43   StringCompiler(TokenType type, const SymbolTable *syms = 0,
44                  Label unknown_label = kNoLabel,
45                  bool allow_negative = false)
token_type_(type)46       : token_type_(type), syms_(syms), unknown_label_(unknown_label),
47         allow_negative_(allow_negative) {}
48 
49   // Compile string 's' into FST 'fst'.
50   template <class F>
operator()51   bool operator()(const string &s, F *fst) {
52     vector<Label> labels;
53     if (!ConvertStringToLabels(s, &labels))
54       return false;
55     Compile(labels, fst);
56     return true;
57   }
58 
59  private:
ConvertStringToLabels(const string & str,vector<Label> * labels)60   bool ConvertStringToLabels(const string &str, vector<Label> *labels) const {
61     labels->clear();
62     if (token_type_ == BYTE) {
63       for (size_t i = 0; i < str.size(); ++i)
64         labels->push_back(static_cast<unsigned char>(str[i]));
65     } else if (token_type_ == UTF8) {
66       return UTF8StringToLabels(str, labels);
67     } else {
68       char *c_str = new char[str.size() + 1];
69       str.copy(c_str, str.size());
70       c_str[str.size()] = 0;
71       vector<char *> vec;
72       string separator = "\n" + FLAGS_fst_field_separator;
73       SplitToVector(c_str, separator.c_str(), &vec, true);
74       for (size_t i = 0; i < vec.size(); ++i) {
75         Label label;
76         if (!ConvertSymbolToLabel(vec[i], &label))
77           return false;
78         labels->push_back(label);
79       }
80       delete[] c_str;
81     }
82     return true;
83   }
84 
Compile(const vector<Label> & labels,MutableFst<A> * fst)85   void Compile(const vector<Label> &labels, MutableFst<A> *fst) const {
86     fst->DeleteStates();
87     while (fst->NumStates() <= labels.size())
88       fst->AddState();
89     for (size_t i = 0; i < labels.size(); ++i)
90       fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1));
91     fst->SetStart(0);
92     fst->SetFinal(labels.size(), Weight::One());
93   }
94 
95   template <class Unsigned>
Compile(const vector<Label> & labels,CompactFst<A,StringCompactor<A>,Unsigned> * fst)96   void Compile(const vector<Label> &labels, CompactFst<A, StringCompactor<A>,
97                Unsigned> *fst) const {
98     fst->SetCompactElements(labels.begin(), labels.end());
99   }
100 
ConvertSymbolToLabel(const char * s,Label * output)101   bool ConvertSymbolToLabel(const char *s, Label* output) const {
102     int64 n;
103     if (syms_) {
104       n = syms_->Find(s);
105       if ((n == -1) && (unknown_label_ != kNoLabel))
106         n = unknown_label_;
107       if (n == -1 || (!allow_negative_ && n < 0)) {
108         VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Symbol \"" << s
109                 << "\" is not mapped to any integer label, symbol table = "
110                  << syms_->Name();
111         return false;
112       }
113     } else {
114       char *p;
115       n = strtoll(s, &p, 10);
116       if (p < s + strlen(s) || (!allow_negative_ && n < 0)) {
117         VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Bad label integer "
118                 << "= \"" << s << "\"";
119         return false;
120       }
121     }
122     *output = n;
123     return true;
124   }
125 
126   TokenType token_type_;     // Token type: symbol, byte or utf8 encoded
127   const SymbolTable *syms_;  // Symbol table used when token type is symbol
128   Label unknown_label_;      // Label for token missing from symbol table
129   bool allow_negative_;      // Negative labels allowed?
130 
131   DISALLOW_COPY_AND_ASSIGN(StringCompiler);
132 };
133 
134 // Functor to print a string FST as a string.
135 template <class A>
136 class StringPrinter {
137  public:
138   typedef A Arc;
139   typedef typename A::Label Label;
140   typedef typename A::StateId StateId;
141   typedef typename A::Weight Weight;
142 
143   enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
144 
145   StringPrinter(TokenType token_type,
146                 const SymbolTable *syms = 0)
token_type_(token_type)147       : token_type_(token_type), syms_(syms) {}
148 
149   // Convert the FST 'fst' into the string 'output'
operator()150   bool operator()(const Fst<A> &fst, string *output) {
151     bool is_a_string = FstToLabels(fst);
152     if (!is_a_string) {
153       VLOG(1) << "StringPrinter::operator(): Fst is not a string.";
154       return false;
155     }
156 
157     output->clear();
158 
159     if (token_type_ == SYMBOL) {
160       stringstream sstrm;
161       for (size_t i = 0; i < labels_.size(); ++i) {
162         if (i)
163           sstrm << *(FLAGS_fst_field_separator.rbegin());
164         if (!PrintLabel(labels_[i], sstrm))
165           return false;
166       }
167       *output = sstrm.str();
168     } else if (token_type_ == BYTE) {
169       for (size_t i = 0; i < labels_.size(); ++i) {
170         output->push_back(labels_[i]);
171       }
172     } else if (token_type_ == UTF8) {
173       return LabelsToUTF8String(labels_, output);
174     } else {
175       VLOG(1) << "StringPrinter::operator(): Unknown token type: "
176               << token_type_;
177       return false;
178     }
179     return true;
180   }
181 
182  private:
FstToLabels(const Fst<A> & fst)183   bool FstToLabels(const Fst<A> &fst) {
184     labels_.clear();
185 
186     StateId s = fst.Start();
187     if (s == kNoStateId) {
188       VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for "
189               << "string fst.";
190       return false;
191     }
192 
193     while (fst.Final(s) == Weight::Zero()) {
194       ArcIterator<Fst<A> > aiter(fst, s);
195       if (aiter.Done()) {
196         VLOG(2) << "StringPrinter::FstToLabels: String fst traversal does "
197                 << "not reach final state.";
198         return false;
199       }
200 
201       const A& arc = aiter.Value();
202       labels_.push_back(arc.olabel);
203 
204       s = arc.nextstate;
205       if (s == kNoStateId) {
206         VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid "
207                 << "state.";
208         return false;
209       }
210 
211       aiter.Next();
212       if (!aiter.Done()) {
213         VLOG(2) << "StringPrinter::FstToLabels: State with multiple "
214                 << "outgoing arcs found.";
215         return false;
216       }
217     }
218 
219     return true;
220   }
221 
PrintLabel(Label lab,ostream & ostrm)222   bool PrintLabel(Label lab, ostream& ostrm) {
223     if (syms_) {
224       string symbol = syms_->Find(lab);
225       if (symbol == "") {
226         VLOG(2) << "StringPrinter::PrintLabel: Integer " << lab << " is not "
227                 << "mapped to any textual symbol, symbol table = "
228                  << syms_->Name();
229         return false;
230       }
231       ostrm << symbol;
232     } else {
233       ostrm << lab;
234     }
235     return true;
236   }
237 
238   TokenType token_type_;     // Token type: symbol, byte or utf8 encoded
239   const SymbolTable *syms_;  // Symbol table used when token type is symbol
240   vector<Label> labels_;     // Input FST labels.
241 
242   DISALLOW_COPY_AND_ASSIGN(StringPrinter);
243 };
244 
245 }  // namespace fst
246 
247 #endif // FST_LIB_STRING_H_
248