1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Compares two TensorFlow graphs to see if their meaning is the same. This is a
17 // semantic comparison that's intended to show whether the graphs should produce
18 // the same results, and so ignores details like version numbers or node
19 // ordering that don't affect the output. To use it, run something like this:
20 //
21 // bazel build tensorflow/tools/graph_transforms:compare_graphs
22 // bazel-bin/tensorflow/tools/graph_transforms/compare_graphs a.pb b.pb
23 //
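// The process exits with status 0 if the graphs are equal, 1 if they differ,
// and -1 (which shells typically report as 255) if there was a problem, such as
// the wrong number of arguments or a graph file that could not be loaded.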
26
#include <iostream>

#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/core/util/equal_graph_def.h"
#include "tensorflow/tools/graph_transforms/file_utils.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
34
35 namespace tensorflow {
36 namespace graph_transforms {
37 namespace {
38
ParseFlagsAndCompareGraphs(int argc,char * argv[])39 int ParseFlagsAndCompareGraphs(int argc, char* argv[]) {
40 // We need to call this to set up global state for TensorFlow.
41 port::InitMain(argv[0], &argc, &argv);
42
43 if (argc != 3) {
44 LOG(ERROR) << "compare_graphs expects two file names as arguments";
45 return -1;
46 }
47
48 GraphDef a;
49 Status a_load_status = LoadTextOrBinaryGraphFile(argv[1], &a);
50 if (!a_load_status.ok()) {
51 LOG(ERROR) << "Loading graph '" << argv[1] << "' failed with "
52 << a_load_status.error_message();
53 return -1;
54 }
55
56 GraphDef b;
57 Status b_load_status = LoadTextOrBinaryGraphFile(argv[2], &b);
58 if (!b_load_status.ok()) {
59 LOG(ERROR) << "Loading graph '" << argv[2] << "' failed with "
60 << b_load_status.error_message();
61 return -1;
62 }
63
64 string diff;
65 if (EqualGraphDef(a, b, &diff)) {
66 std::cout << "Graphs are equal." << std::endl;
67 return 0;
68 } else {
69 std::cout << diff << std::endl;
70 return 1;
71 }
72 }
73
74 } // namespace
75 } // namespace graph_transforms
76 } // namespace tensorflow
77
main(int argc,char * argv[])78 int main(int argc, char* argv[]) {
79 return tensorflow::graph_transforms::ParseFlagsAndCompareGraphs(argc, argv);
80 }
81