1 // reweight.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16 //
17 // \file
18 // Function to reweight an FST.
19
20 #ifndef FST_LIB_REWEIGHT_H__
21 #define FST_LIB_REWEIGHT_H__
22
23 #include "fst/lib/mutable-fst.h"
24
25 namespace fst {
26
// Direction in which Reweight() moves weight along the paths of an FST.
enum ReweightType {
  // Reweight towards the initial state (requires a left-distributive Weight).
  REWEIGHT_TO_INITIAL,
  // Reweight towards the final states (requires a right-distributive Weight).
  REWEIGHT_TO_FINAL
};
28
29 // Reweight FST according to the potentials defined by the POTENTIAL
30 // vector in the direction defined by TYPE. Weight needs to be left
31 // distributive when reweighting towards the initial state and right
32 // distributive when reweighting towards the final states.
33 //
34 // An arc of weight w, with an origin state of potential p and
35 // destination state of potential q, is reweighted by p\wq when
36 // reweighting towards the initial state and by pw/q when reweighting
37 // towards the final states.
38 template <class Arc>
Reweight(MutableFst<Arc> * fst,vector<typename Arc::Weight> potential,ReweightType type)39 void Reweight(MutableFst<Arc> *fst, vector<typename Arc::Weight> potential,
40 ReweightType type) {
41 typedef typename Arc::Weight Weight;
42
43 if (!fst->NumStates())
44 return;
45 while ( (int64)potential.size() < (int64)fst->NumStates())
46 potential.push_back(Weight::Zero());
47
48 if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring))
49 LOG(FATAL) << "Reweight: Reweighting to the final states requires "
50 << "Weight to be right distributive: "
51 << Weight::Type();
52
53 if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring))
54 LOG(FATAL) << "Reweight: Reweighting to the initial state requires "
55 << "Weight to be left distributive: "
56 << Weight::Type();
57
58 for (StateIterator< MutableFst<Arc> > sit(*fst);
59 !sit.Done();
60 sit.Next()) {
61 typename Arc::StateId state = sit.Value();
62 for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
63 !ait.Done();
64 ait.Next()) {
65 Arc arc = ait.Value();
66 if ((potential[state] == Weight::Zero()) ||
67 (potential[arc.nextstate] == Weight::Zero()))
68 continue; //temp fix: needs to find best solution for zeros
69 if ((type == REWEIGHT_TO_INITIAL)
70 && (potential[state] != Weight::Zero()))
71 arc.weight = Divide(Times(arc.weight, potential[arc.nextstate]),
72 potential[state], DIVIDE_LEFT);
73 else if ((type == REWEIGHT_TO_FINAL)
74 && (potential[arc.nextstate] != Weight::Zero()))
75 arc.weight = Divide(Times(potential[state], arc.weight),
76 potential[arc.nextstate], DIVIDE_RIGHT);
77 ait.SetValue(arc);
78 }
79 if ((type == REWEIGHT_TO_INITIAL)
80 && (potential[state] != Weight::Zero()))
81 fst->SetFinal(state,
82 Divide(fst->Final(state), potential[state], DIVIDE_LEFT));
83 else if (type == REWEIGHT_TO_FINAL)
84 fst->SetFinal(state, Times(potential[state], fst->Final(state)));
85 }
86
87 if ((potential[fst->Start()] != Weight::One()) &&
88 (potential[fst->Start()] != Weight::Zero())) {
89 if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) {
90 typename Arc::StateId state = fst->Start();
91 for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
92 !ait.Done();
93 ait.Next()) {
94 Arc arc = ait.Value();
95 if (type == REWEIGHT_TO_INITIAL)
96 arc.weight = Times(potential[state], arc.weight);
97 else
98 arc.weight = Times(
99 Divide(Weight::One(), potential[state], DIVIDE_RIGHT),
100 arc.weight);
101 ait.SetValue(arc);
102 }
103 if (type == REWEIGHT_TO_INITIAL)
104 fst->SetFinal(state, Times(potential[state], fst->Final(state)));
105 else
106 fst->SetFinal(state, Times(Divide(Weight::One(), potential[state],
107 DIVIDE_RIGHT),
108 fst->Final(state)));
109 }
110 else {
111 typename Arc::StateId state = fst->AddState();
112 Weight w = type == REWEIGHT_TO_INITIAL ?
113 potential[fst->Start()] :
114 Divide(Weight::One(), potential[fst->Start()], DIVIDE_RIGHT);
115 Arc arc (0, 0, w, fst->Start());
116 fst->AddArc(state, arc);
117 fst->SetStart(state);
118 }
119 }
120
121 fst->SetProperties(ReweightProperties(
122 fst->Properties(kFstProperties, false)),
123 kFstProperties);
124 }
125
126 } // namespace fst
127
#endif /* FST_LIB_REWEIGHT_H__ */
129