1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/util/equal_graph_def.h"
17
#include <cstddef>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
29
30 namespace tensorflow {
31
EqualGraphDef(const GraphDef & actual,const GraphDef & expected,string * diff,const EqualGraphDefOptions & options)32 bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
33 string* diff, const EqualGraphDefOptions& options) {
34 // Intentionally do not check that versions match so that this routine can
35 // be used for less brittle golden file tests.
36 return EqualRepeatedNodeDef(actual.node(), expected.node(), diff, options);
37 }
38
GraphDefHash(const GraphDef & gdef,const EqualGraphDefOptions & options)39 uint64 GraphDefHash(const GraphDef& gdef, const EqualGraphDefOptions& options) {
40 return RepeatedNodeDefHash(gdef.node(), options);
41 }
42
EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef> & actual,const protobuf::RepeatedPtrField<NodeDef> & expected,string * diff,const EqualGraphDefOptions & options)43 bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual,
44 const protobuf::RepeatedPtrField<NodeDef>& expected,
45 string* diff, const EqualGraphDefOptions& options) {
46 std::unordered_map<string, const NodeDef*> actual_index;
47 for (const NodeDef& node : actual) {
48 actual_index[node.name()] = &node;
49 }
50
51 for (const NodeDef& expected_node : expected) {
52 auto actual_iter = actual_index.find(expected_node.name());
53 if (actual_iter == actual_index.end()) {
54 if (diff != nullptr) {
55 *diff = strings::StrCat("Did not find expected node '",
56 SummarizeNodeDef(expected_node), "'");
57 }
58 return false;
59 }
60
61 if (!EqualNodeDef(*actual_iter->second, expected_node, diff, options)) {
62 return false;
63 }
64
65 actual_index.erase(actual_iter);
66 }
67
68 if (!actual_index.empty()) {
69 if (diff != nullptr) {
70 *diff =
71 strings::StrCat("Found unexpected node '",
72 SummarizeNodeDef(*actual_index.begin()->second), "'");
73 }
74 return false;
75 }
76
77 return true;
78 }
79
RepeatedNodeDefHash(const protobuf::RepeatedPtrField<NodeDef> & ndefs,const EqualGraphDefOptions & options)80 uint64 RepeatedNodeDefHash(const protobuf::RepeatedPtrField<NodeDef>& ndefs,
81 const EqualGraphDefOptions& options) {
82 uint64 h = 0xDECAFCAFFE;
83 // Insert NodeDefs into map to deterministically sort by name
84 std::map<string, const NodeDef*> nodes;
85 for (const NodeDef& node : ndefs) {
86 nodes[node.name()] = &node;
87 }
88 for (const auto& pair : nodes) {
89 h = Hash64(pair.first.data(), pair.first.size(), h);
90 h = Hash64Combine(NodeDefHash(*pair.second, options), h);
91 }
92 return h;
93 }
94
95 namespace {
96
JoinStringField(const protobuf::RepeatedPtrField<string> & f)97 string JoinStringField(const protobuf::RepeatedPtrField<string>& f) {
98 string ret;
99 for (int i = 0; i < f.size(); ++i) {
100 if (i > 0) strings::StrAppend(&ret, ", ");
101 strings::StrAppend(&ret, f.Get(i));
102 }
103 return ret;
104 }
105
106 } // namespace
107
EqualNodeDef(const NodeDef & actual,const NodeDef & expected,string * diff,const EqualGraphDefOptions & options)108 bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
109 const EqualGraphDefOptions& options) {
110 if (actual.name() != expected.name()) {
111 if (diff != nullptr) {
112 *diff = strings::StrCat("Actual node name '", actual.name(),
113 "' is not expected '", expected.name(), "'");
114 }
115 return false;
116 }
117
118 if (actual.op() != expected.op()) {
119 if (diff != nullptr) {
120 *diff = strings::StrCat("Node named '", actual.name(), "' has op '",
121 actual.op(), "' that is not expected '",
122 expected.op(), "'");
123 }
124 return false;
125 }
126
127 if (actual.device() != expected.device()) {
128 if (diff != nullptr) {
129 *diff = strings::StrCat("Node named '", actual.name(), "' has device '",
130 actual.device(), "' that is not expected '",
131 expected.device(), "'");
132 }
133 return false;
134 }
135
136 if (actual.input_size() != expected.input_size()) {
137 if (diff != nullptr) {
138 *diff = strings::StrCat("Node named '", actual.name(), "' has inputs '",
139 JoinStringField(actual.input()),
140 "' that don't match expected '",
141 JoinStringField(expected.input()), "'");
142 }
143 return false;
144 }
145
146 int first_control_input = actual.input_size();
147 for (int i = 0; i < actual.input_size(); ++i) {
148 if (absl::StartsWith(actual.input(i), "^")) {
149 first_control_input = i;
150 break;
151 }
152 // Special case for inputs: "tensor" is equivalent to "tensor:0"
153 if (actual.input(i) != expected.input(i) &&
154 actual.input(i) != strings::StrCat(expected.input(i), ":0") &&
155 strings::StrCat(actual.input(i), ":0") != expected.input(i)) {
156 if (diff != nullptr) {
157 *diff = strings::StrCat("Node named '", actual.name(), "' has input ",
158 i, " '", actual.input(i),
159 "' that doesn't match expected '",
160 expected.input(i), "'");
161 }
162 return false;
163 }
164 }
165
166 std::unordered_set<string> actual_control;
167 std::unordered_set<string> expected_control;
168 for (int i = first_control_input; i < actual.input_size(); ++i) {
169 actual_control.insert(actual.input(i));
170 expected_control.insert(expected.input(i));
171 }
172 for (const auto& e : expected_control) {
173 if (actual_control.erase(e) == 0) {
174 if (diff != nullptr) {
175 *diff = strings::StrCat("Node named '", actual.name(),
176 "' missing expected control input '", e, "'");
177 }
178 return false;
179 }
180 }
181 if (!actual_control.empty()) {
182 if (diff != nullptr) {
183 *diff = strings::StrCat("Node named '", actual.name(),
184 "' has unexpected control input '",
185 *actual_control.begin(), "'");
186 }
187 return false;
188 }
189
190 std::unordered_set<string> actual_attr;
191 for (const auto& a : actual.attr()) {
192 if (options.ignore_internal_attrs && !a.first.empty() &&
193 a.first[0] == '_') {
194 continue;
195 }
196 actual_attr.insert(a.first);
197 }
198 for (const auto& e : expected.attr()) {
199 if (options.ignore_internal_attrs && !e.first.empty() &&
200 e.first[0] == '_') {
201 continue;
202 }
203
204 if (actual_attr.erase(e.first) == 0) {
205 if (diff != nullptr) {
206 *diff = strings::StrCat("Node named '", actual.name(),
207 "' missing expected attr '", e.first,
208 "' with value: ", SummarizeAttrValue(e.second));
209 }
210 return false;
211 }
212 auto iter = actual.attr().find(e.first);
213 if (!AreAttrValuesEqual(e.second, iter->second)) {
214 if (diff != nullptr) {
215 *diff = strings::StrCat(
216 "Node named '", actual.name(), "' has attr '", e.first,
217 "' with value: ", SummarizeAttrValue(iter->second),
218 " that does not match expected: ", SummarizeAttrValue(e.second));
219 }
220 return false;
221 }
222 }
223 if (!actual_attr.empty()) {
224 if (diff != nullptr) {
225 *diff = strings::StrCat(
226 "Node named '", actual.name(), "' has unexpected attr '",
227 *actual_attr.begin(), "' with value: ",
228 SummarizeAttrValue(actual.attr().find(*actual_attr.begin())->second));
229 }
230 return false;
231 }
232
233 return true;
234 }
235
NodeDefHash(const NodeDef & ndef,const EqualGraphDefOptions & options)236 uint64 NodeDefHash(const NodeDef& ndef, const EqualGraphDefOptions& options) {
237 uint64 h = Hash64(ndef.name());
238 h = Hash64(ndef.op().data(), ndef.op().size(), h);
239 h = Hash64(ndef.device().data(), ndef.device().size(), h);
240
241 // Normal inputs. Order important.
242 int first_control_input = ndef.input_size();
243 for (int i = 0; i < ndef.input_size(); ++i) {
244 if (absl::StartsWith(ndef.input(i), "^")) {
245 first_control_input = i;
246 break;
247 }
248 h = Hash64(ndef.input(i).data(), ndef.input(i).size(), h);
249 }
250
251 // Control inputs. Order irrelevant.
252 std::set<string> ndef_control;
253 for (int i = first_control_input; i < ndef.input_size(); ++i) {
254 ndef_control.insert(ndef.input(i));
255 }
256 for (const string& s : ndef_control) {
257 h = Hash64(s.data(), s.size(), h);
258 }
259
260 // Attributes
261 std::map<string, AttrValue> ndef_attr;
262 for (const auto& a : ndef.attr()) {
263 if (options.ignore_internal_attrs && !a.first.empty() &&
264 a.first[0] == '_') {
265 continue;
266 }
267 ndef_attr[a.first] = a.second;
268 }
269 for (const auto& a : ndef_attr) {
270 h = Hash64(a.first.data(), a.first.size(), h);
271 h = Hash64Combine(AttrValueHash(a.second), h);
272 }
273
274 return h;
275 }
276
277 } // namespace tensorflow
278