1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/hash_utils.h"
16
17 #include <queue>
18
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/container/flat_hash_set.h"
21 #include "tensorflow/core/common_runtime/function.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/dataset.h"
24 #include "tensorflow/core/framework/function.h"
25 #include "tensorflow/core/framework/node_def_util.h"
26 #include "tensorflow/core/framework/op.h"
27 #include "tensorflow/core/framework/op_def.pb.h"
28 #include "tensorflow/core/framework/op_def_builder.h"
29 #include "tensorflow/core/framework/op_def_util.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/framework/tensor.pb.h"
32 #include "tensorflow/core/framework/types.h"
33 #include "tensorflow/core/graph/graph_def_builder.h"
34 #include "tensorflow/core/kernels/data/dataset_utils.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/lib/hash/hash.h"
37 #include "tensorflow/core/lib/strings/proto_serialization.h"
38 #include "tensorflow/core/platform/errors.h"
39 #include "tensorflow/core/platform/regexp.h"
40 #include "tensorflow/core/platform/status.h"
41 #include "tensorflow/core/util/work_sharder.h"
42
43 namespace tensorflow {
44 namespace data {
45 namespace {
46
// Ops whose seed-related inputs are left out of the graph hash (see
// ShouldIgnoreInput). A node is considered one of these ops when
// MatchesAnyVersion accepts the entry for the node's op name.
// clang-format off
constexpr std::array<const char*, 3> kOpsWithSeed = {
    "AnonymousRandomSeedGenerator",
    "ShuffleDataset",
    "ShuffleAndRepeatDataset"
};
// clang-format on
// Op-def input argument names that ShouldIgnoreInput treats as seeds.
constexpr char kSeedInputName[] = "seed";
constexpr char kSeed2InputName[] = "seed2";
constexpr char kSeedGeneratorInputName[] = "seed_generator";
57
58 template <std::size_t SIZE>
IsNodeOfType(const NodeDef & node,const std::array<const char *,SIZE> & op_types)59 bool IsNodeOfType(const NodeDef& node,
60 const std::array<const char*, SIZE>& op_types) {
61 for (const auto& type : op_types) {
62 if (MatchesAnyVersion(type, node.op())) {
63 return true;
64 }
65 }
66 return false;
67 }
68
FindNode(const GraphDef & graph,const string & name,const NodeDef ** result)69 Status FindNode(const GraphDef& graph, const string& name,
70 const NodeDef** result) {
71 for (const auto& node : graph.node()) {
72 if (node.name() == name) {
73 *result = &node;
74 return Status::OK();
75 }
76 }
77 return errors::NotFound("Could not find node ", name, ".");
78 }
79
GetSink(const GraphDef & graph_def,const NodeDef ** sink)80 Status GetSink(const GraphDef& graph_def, const NodeDef** sink) {
81 for (auto& node : graph_def.node()) {
82 if (node.op() == "_Retval") {
83 *sink = &node;
84 break;
85 }
86 }
87
88 if (sink == nullptr) {
89 return errors::Internal("Cannot find sink node for dataset graph.");
90 }
91 return Status::OK();
92 }
93
ShouldIgnoreInput(const NodeDef & node,int i,bool * result)94 Status ShouldIgnoreInput(const NodeDef& node, int i, bool* result) {
95 *result = false;
96 if (IsNodeOfType(node, kOpsWithSeed)) {
97 const OpRegistrationData* reg;
98 auto status = OpRegistry::Global()->LookUp(node.op(), ®);
99
100 if (status.ok()) {
101 if (reg->op_def.input_arg_size() > i) {
102 const std::string input_arg_name = reg->op_def.input_arg(i).name();
103 if (input_arg_name == kSeedInputName ||
104 input_arg_name == kSeed2InputName ||
105 input_arg_name == kSeedGeneratorInputName) {
106 VLOG(2) << "Ignoring arg: " << input_arg_name
107 << " from node: " << node.name();
108 *result = true;
109 return Status::OK();
110 }
111 }
112 } else if (errors::IsNotFound(status)) {
113 LOG(WARNING) << "Cannot find " << node.op()
114 << " in global op registry, so cannot determine which "
115 "inputs are seeds.";
116 } else {
117 return status;
118 }
119 }
120 return Status::OK();
121 }
122
ParseInputNodeName(const std::string & input_name,std::string * node_name,std::string * suffix,bool * is_control_input)123 Status ParseInputNodeName(const std::string& input_name, std::string* node_name,
124 std::string* suffix, bool* is_control_input) {
125 if (input_name[0] == '^') {
126 *node_name = input_name.substr(1);
127 *is_control_input = true;
128 return Status::OK();
129 }
130 std::pair<std::string, std::string> node_spec =
131 absl::StrSplit(input_name, absl::MaxSplits(':', 1));
132 *node_name = node_spec.first;
133 *suffix = node_spec.second;
134 *is_control_input = false;
135 return Status::OK();
136 }
137
138 // Given a graph_def and a root_node, this class computes a fingerprint that
139 // tries to capture the structure of the graph rooted at the provided node.
140 // It does not at any point rely on the names of the nodes in the graph and
141 // just relies on the connections between different nodes. In the presence of
142 // multiple cycles in the graph, there is a non-zero possibility that two
143 // graphs with different structure might end up with the same fingerprint
144 // as in order to break cycles we prune away some edges (in a deterministic
145 // fashion though). Idea for this algorithm was borrowed from:
146 // https://stackoverflow.com/questions/11338746/directed-graphs-with-a-given-root-node-match-another-directed-graph-for-equali
147 class GraphHasher {
148 public:
149 // `GraphHasher` does not take ownership of `graph_def`, `root_node`, or
150 // `flib_def`.
GraphHasher(const GraphDef * graph,const NodeDef * root,const FunctionLibraryDefinition * flib)151 explicit GraphHasher(const GraphDef* graph, const NodeDef* root,
152 const FunctionLibraryDefinition* flib)
153 : graph_(graph), root_(root), flib_(flib) {}
154
Init()155 Status Init() {
156 // Pre-process the graph to do a BFS and prune away cycles that might cause
157 // problems.
158 absl::flat_hash_set<std::string> visited;
159 std::queue<const NodeDef*> bfs_queue;
160 bfs_queue.push(root_);
161 while (!bfs_queue.empty()) {
162 const NodeDef* node = bfs_queue.front();
163 bfs_queue.pop();
164 if (visited.contains(node->name())) {
165 continue;
166 }
167 visited.insert(node->name());
168 NodeRep node_rep;
169 for (int i = 0; i < node->input_size(); ++i) {
170 DCHECK_GT(node->input(i).length(), 0);
171
172 // We skip trying to take the hash of the seeds of any ops, as they
173 // are irrelevant to the hash of the graph and may vary from run to run.
174 bool should_ignore_input = false;
175 TF_RETURN_IF_ERROR(ShouldIgnoreInput(*node, i, &should_ignore_input));
176 if (should_ignore_input) continue;
177
178 std::string node_name, suffix;
179 bool is_control_input;
180 TF_RETURN_IF_ERROR(ParseInputNodeName(node->input(i), &node_name,
181 &suffix, &is_control_input));
182 const NodeDef* input_node;
183 TF_RETURN_IF_ERROR(FindNode(*graph_, node_name, &input_node));
184
185 // If we've already seen this node before, skip it and don't add it to
186 // the queue.
187 if (visited.find(node_name) != visited.end()) {
188 EdgeRep cycle_edge(node, input_node);
189 cycle_forming_edges_.insert(cycle_edge.GetHash());
190 continue;
191 }
192 if (is_control_input) {
193 node_rep.node_control_inputs.push_back(input_node);
194 } else {
195 node_rep.node_inputs.push_back(std::make_pair(input_node, suffix));
196 bfs_queue.push(input_node);
197 }
198 }
199 nodes_[node] = node_rep;
200 }
201 return Status::OK();
202 }
203
// Computes the hash of the graph rooted at `root_` and stores it in `*hash`.
// Delegates to HashNode, which returns InvalidArgument if `root_` has no
// entry in `nodes_` (e.g. when Init() has not been run).
Status HashRoot(uint64* hash) { return HashNode(root_, hash); }
205
// Compares the graph rooted at this hasher's root with the one rooted at
// `that`'s root. Returns OK when CheckNodesEqual finds no difference, and a
// FailedPrecondition describing the first difference otherwise.
Status CheckEqual(GraphHasher* that) {
  return CheckNodesEqual(root_, that, that->root_);
}
209
210 private:
HashNode(const NodeDef * node,uint64 * hash)211 Status HashNode(const NodeDef* node, uint64* hash) {
212 auto it = cache_.find(node);
213 if (it != cache_.end()) {
214 *hash = it->second;
215 return Status::OK();
216 }
217
218 NodeRep* node_rep = gtl::FindOrNull(nodes_, node);
219 if (node_rep == nullptr) {
220 return errors::InvalidArgument("Could not find node: ", node->name());
221 }
222
223 uint64 non_input_hash;
224 TF_RETURN_IF_ERROR(
225 HashNodeNonInput(node, /*hash_functions=*/true, &non_input_hash));
226
227 uint64 control_inputs_hash;
228 TF_RETURN_IF_ERROR(
229 HashControlInputs(node_rep->node_control_inputs, &control_inputs_hash));
230
231 // Hash regular inputs. We combine them in an ordered fashion.
232 uint64 inputs_hash = 0;
233 for (const auto& input : node_rep->node_inputs) {
234 uint64 node_hash = 0;
235 EdgeRep edge(node, input.first);
236 // If the edge was pruned we get the non input node hash to avoid cycles.
237 if (cycle_forming_edges_.find(edge.GetHash()) !=
238 cycle_forming_edges_.end()) {
239 TF_RETURN_IF_ERROR(
240 HashNodeNonInput(input.first, /*hash_functions=*/true, &node_hash));
241 } else {
242 TF_RETURN_IF_ERROR(HashNode(input.first, &node_hash));
243 }
244 inputs_hash = Hash64Combine(
245 inputs_hash, Hash64Combine(node_hash, Hash64(input.second)));
246 }
247
248 *hash = Hash64Combine(non_input_hash,
249 Hash64Combine(control_inputs_hash, inputs_hash));
250 cache_[node] = *hash;
251 return Status::OK();
252 }
253
CheckNodesEqual(const NodeDef * this_node,GraphHasher * that,const NodeDef * that_node)254 Status CheckNodesEqual(const NodeDef* this_node, GraphHasher* that,
255 const NodeDef* that_node) {
256 Status s = CheckNodesEqualHelper(this_node, that, that_node);
257 if (!s.ok()) {
258 return errors::FailedPrecondition("Nodes ", this_node->name(), " and ",
259 that_node->name(),
260 " are not the same:\n", s);
261 }
262 return s;
263 }
264
CheckNodesEqualHelper(const NodeDef * this_node,GraphHasher * that,const NodeDef * that_node)265 Status CheckNodesEqualHelper(const NodeDef* this_node, GraphHasher* that,
266 const NodeDef* that_node) {
267 TF_RETURN_IF_ERROR(CheckNodesEqualNonInput(this_node, that, that_node,
268 /*compare_functions=*/true));
269
270 TF_RETURN_IF_ERROR(
271 CheckControlInputsEqual(nodes_[this_node].node_control_inputs, that,
272 that->nodes_[that_node].node_control_inputs));
273
274 auto& this_node_inputs = nodes_[this_node].node_inputs;
275 auto& that_node_inputs = that->nodes_[that_node].node_inputs;
276 if (this_node_inputs.size() != that_node_inputs.size()) {
277 return errors::FailedPrecondition(
278 "Nodes have different numbers of node inputs: ",
279 this_node_inputs.size(), " vs ", that_node_inputs.size());
280 }
281 for (int i = 0; i < this_node_inputs.size(); ++i) {
282 const NodeDef* this_input = this_node_inputs[i].first;
283 const NodeDef* that_input = that_node_inputs[i].first;
284 if (is_cycle_forming_edge(this_node, this_input)) {
285 TF_RETURN_IF_ERROR(CheckNodesEqualNonInput(this_input, that, that_input,
286 /*compare_functions=*/true));
287 } else {
288 TF_RETURN_IF_ERROR(CheckNodesEqual(this_input, that, that_input));
289 }
290 std::string this_input_suffix = this_node_inputs[i].second;
291 std::string that_input_suffix = that_node_inputs[i].second;
292 if (this_input_suffix != that_input_suffix) {
293 return errors::FailedPrecondition(
294 "Node inputs ", this_input->name(), " and ", that_input->name(),
295 " have different suffixes: ", this_input_suffix, " vs ",
296 that_input_suffix);
297 }
298 }
299 return Status::OK();
300 }
301
HashNodeNonInput(const NodeDef * node,bool hash_functions,uint64 * hash)302 Status HashNodeNonInput(const NodeDef* node, bool hash_functions,
303 uint64* hash) {
304 // Hash Attrs. We get the list of attrs from the op registry and then look
305 // up their values in the NodeDef attr map. This avoids looping over
306 // a map which is non-deterministic.
307 uint64 attrs_hash = 0;
308 const OpRegistrationData* reg;
309 TF_RETURN_IF_ERROR(flib_->LookUp(node->op(), ®));
310 uint64 op_hash = 0;
311 if (reg->is_function_op) {
312 if (hash_functions) {
313 TF_RETURN_IF_ERROR(HashFunction(node->op(), node->attr(), &op_hash));
314 }
315 } else {
316 op_hash = Hash64(node->op());
317 }
318
319 for (const auto& attr : reg->op_def.attr()) {
320 const auto& attr_key = attr.name();
321 if (!node->attr().contains(attr_key)) continue;
322 auto attr_value = node->attr().at(attr_key);
323 if (attr_key == kColocationAttrName ||
324 attr_key == kColocationGroupPrefix) {
325 continue;
326 }
327 uint64 attr_hash = 0;
328 TF_RETURN_IF_ERROR(
329 HashAttr(attr_key, attr_value, hash_functions, &attr_hash));
330 attrs_hash = Hash64Combine(attrs_hash, attr_hash);
331 }
332
333 // Hash Device.
334 uint64 device_hash = Hash64(node->device());
335
336 *hash = Hash64Combine(op_hash, Hash64Combine(attrs_hash, device_hash));
337 return Status::OK();
338 }
339
CheckNodesEqualNonInput(const NodeDef * this_node,GraphHasher * that,const NodeDef * that_node,bool compare_functions)340 Status CheckNodesEqualNonInput(const NodeDef* this_node, GraphHasher* that,
341 const NodeDef* that_node,
342 bool compare_functions) {
343 // We get the list of attrs from the op registry and then look
344 // up their values in the NodeDef attr map. This avoids looping over
345 // a map which is non-deterministic.
346 const OpRegistrationData* reg;
347 TF_RETURN_IF_ERROR(flib_->LookUp(this_node->op(), ®));
348 if (reg->is_function_op) {
349 if (compare_functions) {
350 TF_RETURN_IF_ERROR(
351 CheckFunctionsEqual(this_node->op(), this_node->attr(), that,
352 that_node->op(), that_node->attr()));
353 }
354 } else {
355 if (this_node->op() != that_node->op()) {
356 return errors::FailedPrecondition(
357 "ops for nodes ", this_node->name(), " and ", that_node->name(),
358 " are different: ", this_node->op(), " != ", that_node->op());
359 }
360 }
361
362 for (const auto& attr : reg->op_def.attr()) {
363 const auto& attr_key = attr.name();
364 if (this_node->attr().contains(attr_key) !=
365 that_node->attr().contains(attr_key)) {
366 return errors::FailedPrecondition(
367 "attr with key ", attr_key, " is different for nodes ",
368 this_node->name(), " and ", that_node->name(),
369 ". Present in former: ", this_node->attr().contains(attr_key),
370 ". Present in latter: ", that_node->attr().contains(attr_key));
371 }
372 if (!this_node->attr().contains(attr_key)) continue;
373 if (attr_key == kColocationAttrName ||
374 attr_key == kColocationGroupPrefix) {
375 continue;
376 }
377 auto this_attr = this_node->attr().at(attr_key);
378 auto that_attr = that_node->attr().at(attr_key);
379 TF_RETURN_IF_ERROR(CheckAttrsEqual(attr_key, this_attr, that, that_attr,
380 compare_functions));
381 }
382
383 if (this_node->device() != that_node->device()) {
384 return errors::FailedPrecondition(
385 "Devices are different for nodes ", this_node->name(), " and ",
386 that_node->name(), ": ", this_node->device(), " vs ",
387 that_node->device());
388 }
389 return Status::OK();
390 }
391
HashAttr(const std::string & attr_name,const AttrValue & attr_value,bool hash_functions,uint64 * hash)392 Status HashAttr(const std::string& attr_name, const AttrValue& attr_value,
393 bool hash_functions, uint64* hash) {
394 uint64 value_hash = 0;
395 if (attr_value.has_func()) {
396 if (hash_functions) {
397 TF_RETURN_IF_ERROR(HashFunction(attr_value.func(), &value_hash));
398 }
399 } else if (attr_value.has_list() && attr_value.list().func_size() > 0) {
400 if (hash_functions) {
401 for (auto& func : attr_value.list().func()) {
402 uint64 func_hash;
403 TF_RETURN_IF_ERROR(HashFunction(func, &func_hash));
404 value_hash = Hash64Combine(value_hash, func_hash);
405 }
406 }
407 } else {
408 value_hash = DeterministicProtoHash64(attr_value);
409 }
410 *hash = Hash64(absl::StrCat(attr_name, "=", value_hash));
411 return Status::OK();
412 }
413
CheckAttrsEqual(const std::string & attr_name,const AttrValue & this_attr,GraphHasher * that,const AttrValue & that_attr,bool compare_functions)414 Status CheckAttrsEqual(const std::string& attr_name,
415 const AttrValue& this_attr, GraphHasher* that,
416 const AttrValue& that_attr, bool compare_functions) {
417 if (this_attr.has_func() != that_attr.has_func()) {
418 return errors::FailedPrecondition(
419 "AttrValues are of different types: ", this_attr.DebugString(),
420 " vs ", that_attr.DebugString());
421 }
422 if (this_attr.has_func()) {
423 if (compare_functions) {
424 TF_RETURN_IF_ERROR(
425 CheckFunctionsEqual(this_attr.func(), that, that_attr.func()));
426 }
427 return Status::OK();
428 }
429 if (this_attr.has_list() != that_attr.has_list()) {
430 return errors::FailedPrecondition(
431 "AttrValues are of different types: ", this_attr.DebugString(),
432 " vs ", that_attr.DebugString());
433 }
434 if (this_attr.has_list()) {
435 if (this_attr.list().func_size() != that_attr.list().func_size()) {
436 return errors::FailedPrecondition(
437 "AttrValues have func lists of different sizes: ",
438 this_attr.DebugString(), " vs ", that_attr.DebugString());
439 }
440 if (compare_functions) {
441 for (int i = 0; i < this_attr.list().func_size(); ++i) {
442 TF_RETURN_IF_ERROR(CheckFunctionsEqual(this_attr.list().func(i), that,
443 that_attr.list().func(i)));
444 }
445 }
446 return Status::OK();
447 }
448 uint64 this_hash, that_hash;
449 TF_RETURN_IF_ERROR(
450 HashAttr(attr_name, this_attr, /*hash_functions=*/true, &this_hash));
451 TF_RETURN_IF_ERROR(that->HashAttr(attr_name, that_attr,
452 /*hash_functions=*/true, &that_hash));
453 if (this_hash != that_hash) {
454 return errors::FailedPrecondition(
455 "AttrValues are different: ", this_attr.DebugString(), " vs ",
456 that_attr.DebugString());
457 }
458 return Status::OK();
459 }
460
// Convenience overload: hashes the function referenced by `func`, forwarding
// its name and attrs to the (name, attrs) overload.
Status HashFunction(const NameAttrList& func, uint64* hash) {
  return HashFunction(func.name(), func.attr(), hash);
}
464
HashFunction(const std::string & name,const AttrValueMap & attrs,uint64 * hash)465 Status HashFunction(const std::string& name, const AttrValueMap& attrs,
466 uint64* hash) {
467 const FunctionDef* fdef = flib_->Find(name);
468
469 // Convert to a GraphDef.
470 std::unique_ptr<FunctionBody> fbody;
471 TF_RETURN_IF_ERROR(
472 FunctionDefToBodyHelper(*fdef, AttrSlice(&attrs), flib_, &fbody));
473 GraphDef graph_def = fbody->graph->ToGraphDefDebug();
474
475 // For each return node, we create a new GraphHasher to compute a hash.
476 // We then combine these hashes to produce the hash ordered.
477 uint64 ret_nodes_hash = 0;
478 for (const auto& ret_node : fbody->ret_nodes) {
479 uint64 ret_node_hash = 0;
480 GraphHasher hasher(&graph_def, &ret_node->def(), flib_);
481 TF_RETURN_IF_ERROR(hasher.Init());
482 TF_RETURN_IF_ERROR(hasher.HashRoot(&ret_node_hash));
483 ret_nodes_hash = Hash64Combine(ret_nodes_hash, ret_node_hash);
484 }
485
486 std::vector<const NodeDef*> control_rets;
487 for (const auto& control_ret_node : fbody->control_ret_nodes) {
488 control_rets.push_back(&control_ret_node->def());
489 }
490 uint64 control_ret_nodes_hash = 0;
491 TF_RETURN_IF_ERROR(
492 HashControlInputs(control_rets, &control_ret_nodes_hash));
493
494 *hash = Hash64Combine(ret_nodes_hash, control_ret_nodes_hash);
495 return Status::OK();
496 }
497
// Convenience overload: compares the functions referenced by `this_func`
// (looked up via this hasher) and `that_func` (looked up via `that`),
// forwarding their names and attrs to the (name, attrs) overload.
Status CheckFunctionsEqual(const NameAttrList& this_func, GraphHasher* that,
                           const NameAttrList& that_func) {
  return CheckFunctionsEqual(this_func.name(), this_func.attr(), that,
                             that_func.name(), that_func.attr());
}
CheckFunctionsEqual(const std::string & this_name,const AttrValueMap & this_attrs,GraphHasher * that,const std::string & that_name,const AttrValueMap & that_attrs)503 Status CheckFunctionsEqual(const std::string& this_name,
504 const AttrValueMap& this_attrs, GraphHasher* that,
505 const std::string& that_name,
506 const AttrValueMap& that_attrs) {
507 Status s = CheckFunctionsEqualHelper(this_name, this_attrs, that, that_name,
508 that_attrs);
509 if (!s.ok()) {
510 return errors::FailedPrecondition("Functions ", this_name, " and ",
511 that_name, " are not the same:\n", s);
512 }
513 return s;
514 }
515
CheckFunctionsEqualHelper(const std::string & this_name,const AttrValueMap & this_attrs,GraphHasher * that,const std::string & that_name,const AttrValueMap & that_attrs)516 Status CheckFunctionsEqualHelper(const std::string& this_name,
517 const AttrValueMap& this_attrs,
518 GraphHasher* that,
519 const std::string& that_name,
520 const AttrValueMap& that_attrs) {
521 const FunctionDef* this_fdef = flib_->Find(this_name);
522 const FunctionDef* that_fdef = that->flib_->Find(that_name);
523
524 // Convert to GraphDefs.
525 std::unique_ptr<FunctionBody> this_fbody;
526 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
527 *this_fdef, AttrSlice(&this_attrs), flib_, &this_fbody));
528 GraphDef this_graph_def = this_fbody->graph->ToGraphDefDebug();
529 std::unique_ptr<FunctionBody> that_fbody;
530 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
531 *that_fdef, AttrSlice(&that_attrs), that->flib_, &that_fbody));
532 GraphDef that_graph_def = that_fbody->graph->ToGraphDefDebug();
533
534 if (this_fbody->ret_nodes.size() != that_fbody->ret_nodes.size()) {
535 return errors::FailedPrecondition(
536 "Different numbers of ret nodes for functions ", this_name, " and ",
537 that_name, ": ", this_fbody->ret_nodes.size(), " vs ",
538 that_fbody->ret_nodes.size());
539 }
540 for (int i = 0; i < this_fbody->ret_nodes.size(); ++i) {
541 const NodeDef* this_root = &this_fbody->ret_nodes[i]->def();
542 const NodeDef* that_root = &that_fbody->ret_nodes[i]->def();
543 GraphHasher this_hasher(&this_graph_def, this_root, flib_);
544 TF_RETURN_IF_ERROR(this_hasher.Init());
545 GraphHasher that_hasher(&that_graph_def, that_root, that->flib_);
546 TF_RETURN_IF_ERROR(that_hasher.Init());
547 TF_RETURN_IF_ERROR(this_hasher.CheckEqual(&that_hasher));
548 }
549
550 std::vector<const NodeDef*> this_control_rets;
551 for (const auto& control_ret_node : this_fbody->control_ret_nodes) {
552 this_control_rets.push_back(&control_ret_node->def());
553 }
554 std::vector<const NodeDef*> that_control_rets;
555 for (const auto& control_ret_node : that_fbody->control_ret_nodes) {
556 that_control_rets.push_back(&control_ret_node->def());
557 }
558 TF_RETURN_IF_ERROR(
559 CheckControlInputsEqual(this_control_rets, that, that_control_rets));
560 return Status::OK();
561 }
562
HashControlInputs(const std::vector<const NodeDef * > & inputs,uint64 * hash)563 Status HashControlInputs(const std::vector<const NodeDef*>& inputs,
564 uint64* hash) {
565 *hash = 0;
566 for (const NodeDef* input : inputs) {
567 uint64 node_hash = 0;
568 TF_RETURN_IF_ERROR(
569 HashNodeNonInput(input, /*hash_functions=*/false, &node_hash));
570 *hash = Hash64CombineUnordered(*hash, node_hash);
571 }
572 return Status::OK();
573 }
574
CheckControlInputsEqual(const std::vector<const NodeDef * > & this_inputs,GraphHasher * that,const std::vector<const NodeDef * > & that_inputs)575 Status CheckControlInputsEqual(
576 const std::vector<const NodeDef*>& this_inputs, GraphHasher* that,
577 const std::vector<const NodeDef*>& that_inputs) {
578 absl::flat_hash_map<uint64, const NodeDef*> this_hashes;
579 for (const NodeDef* input : this_inputs) {
580 uint64 node_hash = 0;
581 TF_RETURN_IF_ERROR(
582 HashNodeNonInput(input, /*hash_functions=*/false, &node_hash));
583 this_hashes[node_hash] = input;
584 }
585 absl::flat_hash_map<uint64, const NodeDef*> that_hashes;
586 for (const NodeDef* input : that_inputs) {
587 uint64 node_hash = 0;
588 TF_RETURN_IF_ERROR(
589 HashNodeNonInput(input, /*hash_functions=*/false, &node_hash));
590 if (this_hashes.contains(node_hash)) {
591 this_hashes.erase(node_hash);
592 } else {
593 that_hashes[node_hash] = input;
594 }
595 }
596 if (!this_hashes.empty()) {
597 std::vector<std::string> this_unmatched;
598 for (const auto& it : this_hashes) {
599 this_unmatched.push_back(it.second->name());
600 }
601 std::vector<std::string> that_unmatched;
602 for (const auto& it : that_hashes) {
603 that_unmatched.push_back(it.second->name());
604 }
605 return errors::FailedPrecondition(
606 "Control dependencies are different. One node has dependencies [",
607 absl::StrJoin(this_unmatched, ", "),
608 "], which don't match any of the other node's dependencies [",
609 absl::StrJoin(that_unmatched, ", "), "]");
610 }
611 return Status::OK();
612 }
613
614 private:
is_cycle_forming_edge(const NodeDef * start,const NodeDef * end)615 bool is_cycle_forming_edge(const NodeDef* start, const NodeDef* end) {
616 EdgeRep edge(start, end);
617 return cycle_forming_edges_.contains(edge.GetHash());
618 }
619
// The inputs recorded for one node during Init().
struct NodeRep {
  // Nodes feeding this node through control edges ("^name" inputs).
  std::vector<const NodeDef*> node_control_inputs;
  // Nodes feeding this node through regular inputs, in input order, each
  // paired with the text following the first ':' of the input string (empty
  // when the input has no ':').
  std::vector<std::pair<const NodeDef*, std::string>> node_inputs;
};
624
625 struct EdgeRep {
626 const NodeDef* start_node;
627 const NodeDef* end_node;
628
EdgeReptensorflow::data::__anon391c42a80111::GraphHasher::EdgeRep629 EdgeRep(const NodeDef* start, const NodeDef* end)
630 : start_node(start), end_node(end) {}
631
GetHashtensorflow::data::__anon391c42a80111::GraphHasher::EdgeRep632 uint64 GetHash() {
633 return Hash64Combine(absl::Hash<const NodeDef*>()(start_node),
634 absl::Hash<const NodeDef*>()(end_node));
635 }
636 };
const GraphDef* const graph_; // Not owned.
const NodeDef* const root_; // Not owned.
const FunctionLibraryDefinition* const flib_; // Not owned.
// Edges that need to be pruned as their presence will cause cycles.
// Holds EdgeRep::GetHash() values; filled in by Init().
absl::flat_hash_set<uint64> cycle_forming_edges_;
// Inputs recorded by Init() for every node it dequeued.
absl::flat_hash_map<const NodeDef*, NodeRep> nodes_;
// Memoized HashNode() results, keyed by node.
absl::flat_hash_map<const NodeDef*, uint64> cache_;
644 };
645
646 } // anonymous namespace
647
HashTensor(const Tensor & tensor,uint64 * hash)648 Status HashTensor(const Tensor& tensor, uint64* hash) {
649 const tstring* s = nullptr;
650 // Hash tensor type.
651 *hash = Hash64Combine(0, tensor.dtype());
652 // Hash tensor shape.
653 for (int i = 0; i < tensor.shape().dims(); ++i) {
654 *hash = Hash64Combine(*hash, tensor.shape().dim_size(i));
655 }
656 // Hash tensor data.
657 switch (tensor.dtype()) {
658 case DT_RESOURCE:
659 case DT_VARIANT:
660 return errors::Unimplemented("Hashing ", DataTypeString(tensor.dtype()),
661 " is not supported.");
662 case DT_STRING:
663 s = tensor.flat<tstring>().data();
664 for (int i = 0; i < tensor.NumElements(); ++i, ++s) {
665 *hash = Hash64Combine(*hash, Hash64(s->data(), s->size()));
666 }
667 break;
668 default:
669 *hash = Hash64(tensor.tensor_data().data(), tensor.tensor_data().size());
670 }
671 return Status::OK();
672 }
673
// Hashes the subgraph of `graph` rooted at `node`, resolving ops and
// functions through a function library built from the global op registry and
// `graph`'s own function library.
Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash) {
  const FunctionLibraryDefinition flib_def(OpRegistry::Global(),
                                           graph.library());
  return HashNode(graph, node, flib_def, hash);
}
679
HashNode(const GraphDef & graph,const NodeDef & node,const FunctionLibraryDefinition & flib_def,uint64 * hash)680 Status HashNode(const GraphDef& graph, const NodeDef& node,
681 const FunctionLibraryDefinition& flib_def, uint64* hash) {
682 GraphHasher hasher(&graph, &node, &flib_def);
683 TF_RETURN_IF_ERROR(hasher.Init());
684 return hasher.HashRoot(hash);
685 }
686
HashGraph(const GraphDef & graph_def,uint64 * hash)687 Status HashGraph(const GraphDef& graph_def, uint64* hash) {
688 const NodeDef* sink = nullptr;
689 TF_RETURN_IF_ERROR(GetSink(graph_def, &sink));
690 return HashNode(graph_def, *sink, hash);
691 }
692
CheckGraphsEqual(const GraphDef & a,const GraphDef & b)693 Status CheckGraphsEqual(const GraphDef& a, const GraphDef& b) {
694 const NodeDef* sink_a;
695 TF_RETURN_IF_ERROR(GetSink(a, &sink_a));
696 const NodeDef* sink_b;
697 TF_RETURN_IF_ERROR(GetSink(b, &sink_b));
698 return CheckSubgraphsEqual(a, sink_a, b, sink_b);
699 }
700
CheckSubgraphsEqual(const GraphDef & a,const NodeDef * node_a,const GraphDef & b,const NodeDef * node_b)701 Status CheckSubgraphsEqual(const GraphDef& a, const NodeDef* node_a,
702 const GraphDef& b, const NodeDef* node_b) {
703 const FunctionLibraryDefinition flib_def_a(OpRegistry::Global(), a.library());
704 GraphHasher hasher_a(&a, node_a, &flib_def_a);
705 TF_RETURN_IF_ERROR(hasher_a.Init());
706
707 const FunctionLibraryDefinition flib_def_b(OpRegistry::Global(), b.library());
708 GraphHasher hasher_b(&b, node_b, &flib_def_b);
709 TF_RETURN_IF_ERROR(hasher_b.Init());
710
711 return hasher_a.CheckEqual(&hasher_b);
712 }
713
714 } // namespace data
715 } // namespace tensorflow
716