1 // push.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: allauzen@google.com (Cyril Allauzen)
17 //
18 // \file
// Functions to reweight/push an FST.
20
21 #ifndef FST_LIB_PUSH_H__
22 #define FST_LIB_PUSH_H__
23
24 #include <vector>
25 using std::vector;
26
27 #include <fst/factor-weight.h>
28 #include <fst/fst.h>
29 #include <fst/arc-map.h>
30 #include <fst/reweight.h>
31 #include <fst/shortest-distance.h>
32
33
34 namespace fst {
35
36 // Private helper functions for Push
37 namespace internal {
38
39 // Compute the total weight (sum of the weights of all accepting paths) from
40 // the output of ShortestDistance. 'distance' is the shortest distance from the
41 // initial state when 'reverse == false' and to the final states when
42 // 'reverse == true'.
43 template <class Arc>
ComputeTotalWeight(const Fst<Arc> & fst,const vector<typename Arc::Weight> & distance,bool reverse)44 typename Arc::Weight ComputeTotalWeight(
45 const Fst<Arc> &fst,
46 const vector<typename Arc::Weight> &distance,
47 bool reverse) {
48 if (reverse)
49 return fst.Start() < distance.size() ?
50 distance[fst.Start()] : Arc::Weight::Zero();
51
52 typename Arc::Weight sum = Arc::Weight::Zero();
53 for (typename Arc::StateId s = 0; s < distance.size(); ++s)
54 sum = Plus(sum, Times(distance[s], fst.Final(s)));
55 return sum;
56 }
57
58 // Divide the weight of every accepting path by 'w'. The weight 'w' is
59 // divided at the final states if 'at_final == true' and at the
60 // initial state otherwise.
61 template <class Arc>
RemoveWeight(MutableFst<Arc> * fst,typename Arc::Weight w,bool at_final)62 void RemoveWeight(MutableFst<Arc> *fst, typename Arc::Weight w, bool at_final) {
63 if ((w == Arc::Weight::One()) || (w == Arc::Weight::Zero()))
64 return;
65
66 if (at_final) {
67 // Remove 'w' from the final states
68 for (StateIterator< MutableFst<Arc> > sit(*fst);
69 !sit.Done();
70 sit.Next())
71 fst->SetFinal(sit.Value(),
72 Divide(fst->Final(sit.Value()), w, DIVIDE_RIGHT));
73 } else { // at_final == false
74 // Remove 'w' from the initial state
75 typename Arc::StateId start = fst->Start();
76 for (MutableArcIterator<MutableFst<Arc> > ait(fst, start);
77 !ait.Done();
78 ait.Next()) {
79 Arc arc = ait.Value();
80 arc.weight = Divide(arc.weight, w, DIVIDE_LEFT);
81 ait.SetValue(arc);
82 }
83 fst->SetFinal(start, Divide(fst->Final(start), w, DIVIDE_LEFT));
84 }
85 }
86 } // namespace internal
87
88 // Pushes the weights in FST in the direction defined by TYPE. If
89 // pushing towards the initial state, the sum of the weight of the
90 // outgoing transitions and final weight at a non-initial state is
91 // equal to One() in the resulting machine. If pushing towards the
92 // final state, the same property holds on the reverse machine.
93 //
94 // Weight needs to be left distributive when pushing towards the
95 // initial state and right distributive when pushing towards the final
96 // states.
97 template <class Arc>
98 void Push(MutableFst<Arc> *fst,
99 ReweightType type,
100 float delta = kDelta,
101 bool remove_total_weight = false) {
102 vector<typename Arc::Weight> distance;
103 ShortestDistance(*fst, &distance, type == REWEIGHT_TO_INITIAL, delta);
104 typename Arc::Weight total_weight = Arc::Weight::One();
105 if (remove_total_weight)
106 total_weight = internal::ComputeTotalWeight(*fst, distance,
107 type == REWEIGHT_TO_INITIAL);
108 Reweight(fst, distance, type);
109 if (remove_total_weight)
110 internal::RemoveWeight(fst, total_weight, type == REWEIGHT_TO_FINAL);
111 }
112
113 const uint32 kPushWeights = 0x0001;
114 const uint32 kPushLabels = 0x0002;
115 const uint32 kPushRemoveTotalWeight = 0x0004;
116 const uint32 kPushRemoveCommonAffix = 0x0008;
117
118 // OFST obtained from IFST by pushing weights and/or labels according
119 // to PTYPE in the direction defined by RTYPE. Weight needs to be
120 // left distributive when pushing weights towards the initial state
121 // and right distributive when pushing weights towards the final
122 // states.
123 template <class Arc, ReweightType rtype>
124 void Push(const Fst<Arc> &ifst,
125 MutableFst<Arc> *ofst,
126 uint32 ptype,
127 float delta = kDelta) {
128
129 if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) {
130 *ofst = ifst;
131 Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight);
132 } else if (ptype & kPushLabels) {
133 const StringType stype = rtype == REWEIGHT_TO_INITIAL
134 ? STRING_LEFT
135 : STRING_RIGHT;
136 vector<typename GallicArc<Arc, stype>::Weight> gdistance;
137 VectorFst<GallicArc<Arc, stype> > gfst;
138 ArcMap(ifst, &gfst, ToGallicMapper<Arc, stype>());
139 if (ptype & kPushWeights ) {
140 ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
141 } else {
142 ArcMapFst<Arc, Arc, RmWeightMapper<Arc> >
143 uwfst(ifst, RmWeightMapper<Arc>());
144 ArcMapFst<Arc, GallicArc<Arc, stype>, ToGallicMapper<Arc, stype> >
145 guwfst(uwfst, ToGallicMapper<Arc, stype>());
146 ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
147 }
148 typename GallicArc<Arc, stype>::Weight total_weight =
149 GallicArc<Arc, stype>::Weight::One();
150 if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) {
151 total_weight = internal::ComputeTotalWeight(
152 gfst, gdistance, rtype == REWEIGHT_TO_INITIAL);
153 total_weight = typename GallicArc<Arc, stype>::Weight(
154 ptype & kPushRemoveCommonAffix ? total_weight.Value1()
155 : StringWeight<typename Arc::Label, stype>::One(),
156 ptype & kPushRemoveTotalWeight ? total_weight.Value2()
157 : Arc::Weight::One());
158 }
159 Reweight(&gfst, gdistance, rtype);
160 if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix))
161 internal::RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL);
162 FactorWeightFst< GallicArc<Arc, stype>, GallicFactor<typename Arc::Label,
163 typename Arc::Weight, stype> > fwfst(gfst);
164 ArcMap(fwfst, ofst, FromGallicMapper<Arc, stype>());
165 ofst->SetOutputSymbols(ifst.OutputSymbols());
166 } else {
167 LOG(WARNING) << "Push: pushing type is set to 0: "
168 << "pushing neither labels nor weights.";
169 *ofst = ifst;
170 }
171 }
172
173 } // namespace fst
174
#endif /* FST_LIB_PUSH_H__ */
176