1 // randgen.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file
17 // Function to generate random paths through an FST.
18
19 #ifndef FST_LIB_RANDGEN_H__
20 #define FST_LIB_RANDGEN_H__
21
22 #include <cmath>
23 #include <cstdlib>
24 #include <ctime>
25
26 #include "fst/lib/mutable-fst.h"
27
28 namespace fst {
29
30 //
31 // ARC SELECTORS - these function objects are used to select a random
32 // transition to take from an FST's state. They should return a number
33 // N s.t. 0 <= N <= NumArcs(). If N < NumArcs(), then the N-th
34 // transition is selected. If N == NumArcs(), then the final weight at
35 // that state is selected (i.e., the 'super-final' transition is selected).
// These function objects are only called on states that have outgoing
// transitions, are final, or both.
38 //
39
40 // Randomly selects a transition using the uniform distribution.
41 template <class A>
42 struct UniformArcSelector {
43 typedef typename A::StateId StateId;
44 typedef typename A::Weight Weight;
45
46 UniformArcSelector(int seed = time(0)) { srand(seed); }
47
operatorUniformArcSelector48 size_t operator()(const Fst<A> &fst, StateId s) const {
49 double r = rand()/(RAND_MAX + 1.0);
50 size_t n = fst.NumArcs(s);
51 if (fst.Final(s) != Weight::Zero())
52 ++n;
53 return static_cast<size_t>(r * n);
54 }
55 };
56
57 // Randomly selects a transition w.r.t. the weights treated as negative
58 // log probabilities after normalizing for the total weight leaving
59 // the state). Weight::zero transitions are disregarded.
60 // Assumes Weight::Value() accesses the floating point
61 // representation of the weight.
62 template <class A>
63 struct LogProbArcSelector {
64 typedef typename A::StateId StateId;
65 typedef typename A::Weight Weight;
66
67 LogProbArcSelector(int seed = time(0)) { srand(seed); }
68
operatorLogProbArcSelector69 size_t operator()(const Fst<A> &fst, StateId s) const {
70 // Find total weight leaving state
71 double sum = 0.0;
72 for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
73 aiter.Next()) {
74 const A &arc = aiter.Value();
75 sum += exp(-arc.weight.Value());
76 }
77 sum += exp(-fst.Final(s).Value());
78
79 double r = rand()/(RAND_MAX + 1.0);
80 double p = 0.0;
81 int n = 0;
82 for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
83 aiter.Next(), ++n) {
84 const A &arc = aiter.Value();
85 p += exp(-arc.weight.Value());
86 if (p > r * sum) return n;
87 }
88 return n;
89 }
90 };
91
92 // Convenience definitions
93 typedef LogProbArcSelector<StdArc> StdArcSelector;
94 typedef LogProbArcSelector<LogArc> LogArcSelector;
95
96
97 // Options for random path generation.
98 template <class S>
99 struct RandGenOptions {
100 const S &arc_selector; // How an arc is selected at a state
101 int max_length; // Maximum path length
102 size_t npath; // # of paths to generate
103
104 // These are used internally by RandGen
105 int64 source; // 'ifst' state to expand
106 int64 dest; // 'ofst' state to append
107
108 RandGenOptions(const S &sel, int len = INT_MAX, size_t n = 1)
arc_selectorRandGenOptions109 : arc_selector(sel), max_length(len), npath(n),
110 source(kNoStateId), dest(kNoStateId) {}
111 };
112
113
114 // Randomly generate paths through an FST; details controlled by
115 // RandGenOptions.
116 template<class Arc, class ArcSelector>
RandGen(const Fst<Arc> & ifst,MutableFst<Arc> * ofst,const RandGenOptions<ArcSelector> & opts)117 void RandGen(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
118 const RandGenOptions<ArcSelector> &opts) {
119 typedef typename Arc::Weight Weight;
120
121 if (opts.npath == 0 || opts.max_length == 0 || ifst.Start() == kNoStateId)
122 return;
123
124 if (opts.source == kNoStateId) { // first call
125 ofst->DeleteStates();
126 ofst->SetInputSymbols(ifst.InputSymbols());
127 ofst->SetOutputSymbols(ifst.OutputSymbols());
128 ofst->SetStart(ofst->AddState());
129 RandGenOptions<ArcSelector> nopts(opts);
130 nopts.source = ifst.Start();
131 nopts.dest = ofst->Start();
132 for (; nopts.npath > 0; --nopts.npath)
133 RandGen(ifst, ofst, nopts);
134 } else {
135 if (ifst.NumArcs(opts.source) == 0 &&
136 ifst.Final(opts.source) == Weight::Zero()) // Non-coaccessible
137 return;
138 // Pick a random transition from the source state
139 size_t n = opts.arc_selector(ifst, opts.source);
140 if (n == ifst.NumArcs(opts.source)) { // Take 'super-final' transition
141 ofst->SetFinal(opts.dest, Weight::One());
142 } else {
143 ArcIterator< Fst<Arc> > aiter(ifst, opts.source);
144 aiter.Seek(n);
145 const Arc &iarc = aiter.Value();
146 Arc oarc(iarc.ilabel, iarc.olabel, Weight::One(), ofst->AddState());
147 ofst->AddArc(opts.dest, oarc);
148
149 RandGenOptions<ArcSelector> nopts(opts);
150 nopts.source = iarc.nextstate;
151 nopts.dest = oarc.nextstate;
152 --nopts.max_length;
153 RandGen(ifst, ofst, nopts);
154 }
155 }
156 }
157
158 // Randomly generate a path through an FST with the uniform distribution
159 // over the transitions.
160 template<class Arc>
RandGen(const Fst<Arc> & ifst,MutableFst<Arc> * ofst)161 void RandGen(const Fst<Arc> &ifst, MutableFst<Arc> *ofst) {
162 UniformArcSelector<Arc> uniform_selector;
163 RandGenOptions< UniformArcSelector<Arc> > opts(uniform_selector);
164 RandGen(ifst, ofst, opts);
165 }
166
167 } // namespace fst
168
169 #endif // FST_LIB_RANDGEN_H__
170