• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // reweight.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16 //
17 // \file
18 // Function to reweight an FST.
19 
20 #ifndef FST_LIB_REWEIGHT_H__
21 #define FST_LIB_REWEIGHT_H__
22 
23 #include "fst/lib/mutable-fst.h"
24 
25 namespace fst {
26 
// Direction in which Reweight() moves the weights of an FST.
enum ReweightType {
  REWEIGHT_TO_INITIAL,  // reweight towards the initial state
  REWEIGHT_TO_FINAL     // reweight towards the final states
};
28 
29 // Reweight FST according to the potentials defined by the POTENTIAL
30 // vector in the direction defined by TYPE. Weight needs to be left
31 // distributive when reweighting towards the initial state and right
32 // distributive when reweighting towards the final states.
33 //
34 // An arc of weight w, with an origin state of potential p and
35 // destination state of potential q, is reweighted by p\wq when
36 // reweighting towards the initial state and by pw/q when reweighting
37 // towards the final states.
38 template <class Arc>
Reweight(MutableFst<Arc> * fst,vector<typename Arc::Weight> potential,ReweightType type)39 void Reweight(MutableFst<Arc> *fst, vector<typename Arc::Weight> potential,
40               ReweightType type) {
41   typedef typename Arc::Weight Weight;
42 
43   if (!fst->NumStates())
44     return;
45   while ( (int64)potential.size() < (int64)fst->NumStates())
46     potential.push_back(Weight::Zero());
47 
48   if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring))
49     LOG(FATAL) << "Reweight: Reweighting to the final states requires "
50                << "Weight to be right distributive: "
51                << Weight::Type();
52 
53   if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring))
54     LOG(FATAL) << "Reweight: Reweighting to the initial state requires "
55                << "Weight to be left distributive: "
56                << Weight::Type();
57 
58   for (StateIterator< MutableFst<Arc> > sit(*fst);
59        !sit.Done();
60        sit.Next()) {
61     typename Arc::StateId state = sit.Value();
62     for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
63          !ait.Done();
64          ait.Next()) {
65       Arc arc = ait.Value();
66       if ((potential[state] == Weight::Zero()) ||
67 	  (potential[arc.nextstate] == Weight::Zero()))
68 	continue; //temp fix: needs to find best solution for zeros
69       if ((type == REWEIGHT_TO_INITIAL)
70 	  && (potential[state] != Weight::Zero()))
71         arc.weight = Divide(Times(arc.weight, potential[arc.nextstate]),
72 			    potential[state], DIVIDE_LEFT);
73       else if ((type == REWEIGHT_TO_FINAL)
74 	       && (potential[arc.nextstate] != Weight::Zero()))
75         arc.weight = Divide(Times(potential[state], arc.weight),
76                             potential[arc.nextstate], DIVIDE_RIGHT);
77       ait.SetValue(arc);
78     }
79     if ((type == REWEIGHT_TO_INITIAL)
80 	&& (potential[state] != Weight::Zero()))
81       fst->SetFinal(state,
82                     Divide(fst->Final(state), potential[state], DIVIDE_LEFT));
83     else if (type == REWEIGHT_TO_FINAL)
84       fst->SetFinal(state, Times(potential[state], fst->Final(state)));
85   }
86 
87   if ((potential[fst->Start()] != Weight::One()) &&
88       (potential[fst->Start()] != Weight::Zero())) {
89     if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) {
90       typename Arc::StateId state = fst->Start();
91       for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
92            !ait.Done();
93            ait.Next()) {
94         Arc arc = ait.Value();
95         if (type == REWEIGHT_TO_INITIAL)
96           arc.weight = Times(potential[state], arc.weight);
97         else
98           arc.weight = Times(
99               Divide(Weight::One(), potential[state], DIVIDE_RIGHT),
100               arc.weight);
101         ait.SetValue(arc);
102       }
103       if (type == REWEIGHT_TO_INITIAL)
104         fst->SetFinal(state, Times(potential[state], fst->Final(state)));
105       else
106         fst->SetFinal(state, Times(Divide(Weight::One(), potential[state],
107                                           DIVIDE_RIGHT),
108                                    fst->Final(state)));
109     }
110     else {
111       typename Arc::StateId state = fst->AddState();
112       Weight w = type == REWEIGHT_TO_INITIAL ?
113                  potential[fst->Start()] :
114                  Divide(Weight::One(), potential[fst->Start()], DIVIDE_RIGHT);
115       Arc arc (0, 0, w, fst->Start());
116       fst->AddArc(state, arc);
117       fst->SetStart(state);
118     }
119   }
120 
121   fst->SetProperties(ReweightProperties(
122                          fst->Properties(kFstProperties, false)),
123                      kFstProperties);
124 }
125 
126 }  // namespace fst
127 
#endif /* FST_LIB_REWEIGHT_H__ */
129