• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // equal.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file
17 // Function to test equality of two Fsts.
18 
19 #ifndef FST_LIB_EQUAL_H__
20 #define FST_LIB_EQUAL_H__
21 
22 #include "fst/lib/fst.h"
23 
24 namespace fst {
25 
26 // Tests if two Fsts have the same states and arcs in the same order.
27 template<class Arc>
Equal(const Fst<Arc> & fst1,const Fst<Arc> & fst2)28 bool Equal(const Fst<Arc> &fst1, const Fst<Arc> &fst2) {
29   typedef typename Arc::StateId StateId;
30   typedef typename Arc::Weight Weight;
31 
32   if (fst1.Start() != fst2.Start()) {
33     VLOG(1) << "Equal: mismatched start states";
34     return false;
35   }
36 
37   StateIterator< Fst<Arc> > siter1(fst1);
38   StateIterator< Fst<Arc> > siter2(fst2);
39 
40   while (!siter1.Done() || !siter2.Done()) {
41     if (siter1.Done() || siter2.Done()) {
42       VLOG(1) << "Equal: mismatched # of states";
43       return false;
44     }
45     StateId s1 = siter1.Value();
46     StateId s2 = siter2.Value();
47     if (s1 != s2) {
48       VLOG(1) << "Equal: mismatched states:"
49               << ", state1 = " << s1
50               << ", state2 = " << s2;
51       return false;
52     }
53     Weight final1 = fst1.Final(s1);
54     Weight final2 = fst2.Final(s2);
55     if (!ApproxEqual(final1, final2)) {
56       VLOG(1) << "Equal: mismatched final weights:"
57               << " state = " << s1
58               << ", final1 = " << final1
59               << ", final2 = " << final2;
60       return false;
61      }
62     ArcIterator< Fst<Arc> > aiter1(fst1, s1);
63     ArcIterator< Fst<Arc> > aiter2(fst2, s2);
64     for (size_t a = 0; !aiter1.Done() || !aiter2.Done(); ++a) {
65       if (aiter1.Done() || aiter2.Done()) {
66         VLOG(1) << "Equal: mismatched # of arcs"
67                 << " state = " << s1;
68         return false;
69       }
70       Arc arc1 = aiter1.Value();
71       Arc arc2 = aiter2.Value();
72       if (arc1.ilabel != arc2.ilabel) {
73         VLOG(1) << "Equal: mismatched arc input labels:"
74                 << " state = " << s1
75                 << ", arc = " << a
76                 << ", ilabel1 = " << arc1.ilabel
77                 << ", ilabel2 = " << arc2.ilabel;
78         return false;
79       } else  if (arc1.olabel != arc2.olabel) {
80         VLOG(1) << "Equal: mismatched arc output labels:"
81                 << " state = " << s1
82                 << ", arc = " << a
83                 << ", olabel1 = " << arc1.olabel
84                 << ", olabel2 = " << arc2.olabel;
85         return false;
86       } else  if (!ApproxEqual(arc1.weight, arc2.weight)) {
87         VLOG(1) << "Equal: mismatched arc weights:"
88                 << " state = " << s1
89                 << ", arc = " << a
90                 << ", weight1 = " << arc1.weight
91                 << ", weight2 = " << arc2.weight;
92         return false;
93       } else  if (arc1.nextstate != arc2.nextstate) {
94         VLOG(1) << "Equal: mismatched input label:"
95                 << " state = " << s1
96                 << ", arc = " << a
97                 << ", nextstate1 = " << arc1.nextstate
98                 << ", nextstate2 = " << arc2.nextstate;
99         return false;
100       }
101       aiter1.Next();
102       aiter2.Next();
103 
104     }
105     // Sanity checks
106     CHECK_EQ(fst1.NumArcs(s1), fst2.NumArcs(s2));
107     CHECK_EQ(fst1.NumInputEpsilons(s1), fst2.NumInputEpsilons(s2));
108     CHECK_EQ(fst1.NumOutputEpsilons(s1), fst2.NumOutputEpsilons(s2));
109 
110     siter1.Next();
111     siter2.Next();
112   }
113   return true;
114 }
115 
116 }  // namespace fst
117 
118 
119 #endif  // FST_LIB_EQUAL_H__
120