• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/hash_utils.h"
16 
17 #include <queue>
18 
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/container/flat_hash_set.h"
21 #include "tensorflow/core/common_runtime/function.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/dataset.h"
24 #include "tensorflow/core/framework/function.h"
25 #include "tensorflow/core/framework/node_def_util.h"
26 #include "tensorflow/core/framework/op.h"
27 #include "tensorflow/core/framework/op_def.pb.h"
28 #include "tensorflow/core/framework/op_def_builder.h"
29 #include "tensorflow/core/framework/op_def_util.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/framework/tensor.pb.h"
32 #include "tensorflow/core/framework/types.h"
33 #include "tensorflow/core/graph/graph_def_builder.h"
34 #include "tensorflow/core/kernels/data/dataset_utils.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/lib/hash/hash.h"
37 #include "tensorflow/core/lib/strings/proto_serialization.h"
38 #include "tensorflow/core/platform/errors.h"
39 #include "tensorflow/core/platform/regexp.h"
40 #include "tensorflow/core/platform/status.h"
41 #include "tensorflow/core/util/work_sharder.h"
42 
43 namespace tensorflow {
44 namespace data {
45 namespace {
46 
// Ops that consume random seeds. Their seed inputs are excluded when
// fingerprinting a graph (see ShouldIgnoreInput).
// clang-format off
constexpr std::array<const char*, 3> kOpsWithSeed = {
    "AnonymousRandomSeedGenerator",
    "ShuffleDataset",
    "ShuffleAndRepeatDataset",
};
// clang-format on

// Op-def input argument names that identify seed inputs of the ops above.
constexpr char kSeedInputName[] = "seed";
constexpr char kSeed2InputName[] = "seed2";
constexpr char kSeedGeneratorInputName[] = "seed_generator";
57 
58 template <std::size_t SIZE>
IsNodeOfType(const NodeDef & node,const std::array<const char *,SIZE> & op_types)59 bool IsNodeOfType(const NodeDef& node,
60                   const std::array<const char*, SIZE>& op_types) {
61   for (const auto& type : op_types) {
62     if (MatchesAnyVersion(type, node.op())) {
63       return true;
64     }
65   }
66   return false;
67 }
68 
FindNode(const GraphDef & graph,const string & name,const NodeDef ** result)69 Status FindNode(const GraphDef& graph, const string& name,
70                 const NodeDef** result) {
71   for (const auto& node : graph.node()) {
72     if (node.name() == name) {
73       *result = &node;
74       return Status::OK();
75     }
76   }
77   return errors::NotFound("Could not find node ", name, ".");
78 }
79 
GetSink(const GraphDef & graph_def,const NodeDef ** sink)80 Status GetSink(const GraphDef& graph_def, const NodeDef** sink) {
81   for (auto& node : graph_def.node()) {
82     if (node.op() == "_Retval") {
83       *sink = &node;
84       break;
85     }
86   }
87 
88   if (sink == nullptr) {
89     return errors::Internal("Cannot find sink node for dataset graph.");
90   }
91   return Status::OK();
92 }
93 
ShouldIgnoreInput(const NodeDef & node,int i,bool * result)94 Status ShouldIgnoreInput(const NodeDef& node, int i, bool* result) {
95   *result = false;
96   if (IsNodeOfType(node, kOpsWithSeed)) {
97     const OpRegistrationData* reg;
98     auto status = OpRegistry::Global()->LookUp(node.op(), &reg);
99 
100     if (status.ok()) {
101       if (reg->op_def.input_arg_size() > i) {
102         const std::string input_arg_name = reg->op_def.input_arg(i).name();
103         if (input_arg_name == kSeedInputName ||
104             input_arg_name == kSeed2InputName ||
105             input_arg_name == kSeedGeneratorInputName) {
106           VLOG(2) << "Ignoring arg: " << input_arg_name
107                   << " from node: " << node.name();
108           *result = true;
109           return Status::OK();
110         }
111       }
112     } else if (errors::IsNotFound(status)) {
113       LOG(WARNING) << "Cannot find " << node.op()
114                    << " in global op registry, so cannot determine which "
115                       "inputs are seeds.";
116     } else {
117       return status;
118     }
119   }
120   return Status::OK();
121 }
122 
ParseInputNodeName(const std::string & input_name,std::string * node_name,std::string * suffix,bool * is_control_input)123 Status ParseInputNodeName(const std::string& input_name, std::string* node_name,
124                           std::string* suffix, bool* is_control_input) {
125   if (input_name[0] == '^') {
126     *node_name = input_name.substr(1);
127     *is_control_input = true;
128     return Status::OK();
129   }
130   std::pair<std::string, std::string> node_spec =
131       absl::StrSplit(input_name, absl::MaxSplits(':', 1));
132   *node_name = node_spec.first;
133   *suffix = node_spec.second;
134   *is_control_input = false;
135   return Status::OK();
136 }
137 
138 // Given a graph_def and a root_node, this class computes a fingerprint that
139 // tries to capture the structure of the graph rooted at the provided node.
140 // It does not at any point rely on the names of the nodes in the graph and
141 // just relies on the connections between different nodes. In the presence of
142 // multiple cycles in the graph, there is a non-zero possibility that two
143 // graphs with different structure might end up with the same fingerprint
144 // as in order to break cycles we prune away some edges (in a deterministic
145 // fashion though). Idea for this algorithm was borrowed from:
146 // https://stackoverflow.com/questions/11338746/directed-graphs-with-a-given-root-node-match-another-directed-graph-for-equali
147 class GraphHasher {
148  public:
149   // `GraphHasher` does not take ownership of `graph_def`, `root_node`, or
150   // `flib_def`.
GraphHasher(const GraphDef * graph,const NodeDef * root,const FunctionLibraryDefinition * flib)151   explicit GraphHasher(const GraphDef* graph, const NodeDef* root,
152                        const FunctionLibraryDefinition* flib)
153       : graph_(graph), root_(root), flib_(flib) {}
154 
Init()155   Status Init() {
156     // Pre-process the graph to do a BFS and prune away cycles that might cause
157     // problems.
158     absl::flat_hash_set<std::string> visited;
159     std::queue<const NodeDef*> bfs_queue;
160     bfs_queue.push(root_);
161     while (!bfs_queue.empty()) {
162       const NodeDef* node = bfs_queue.front();
163       bfs_queue.pop();
164       if (visited.contains(node->name())) {
165         continue;
166       }
167       visited.insert(node->name());
168       NodeRep node_rep;
169       for (int i = 0; i < node->input_size(); ++i) {
170         DCHECK_GT(node->input(i).length(), 0);
171 
172         // We skip trying to take the hash of the seeds of any ops, as they
173         // are irrelevant to the hash of the graph and may vary from run to run.
174         bool should_ignore_input = false;
175         TF_RETURN_IF_ERROR(ShouldIgnoreInput(*node, i, &should_ignore_input));
176         if (should_ignore_input) continue;
177 
178         std::string node_name, suffix;
179         bool is_control_input;
180         TF_RETURN_IF_ERROR(ParseInputNodeName(node->input(i), &node_name,
181                                               &suffix, &is_control_input));
182         const NodeDef* input_node;
183         TF_RETURN_IF_ERROR(FindNode(*graph_, node_name, &input_node));
184 
185         // If we've already seen this node before, skip it and don't add it to
186         // the queue.
187         if (visited.find(node_name) != visited.end()) {
188           EdgeRep cycle_edge(node, input_node);
189           cycle_forming_edges_.insert(cycle_edge.GetHash());
190           continue;
191         }
192         if (is_control_input) {
193           node_rep.node_control_inputs.push_back(input_node);
194         } else {
195           node_rep.node_inputs.push_back(std::make_pair(input_node, suffix));
196           bfs_queue.push(input_node);
197         }
198       }
199       nodes_[node] = node_rep;
200     }
201     return Status::OK();
202   }
203 
HashRoot(uint64 * hash)204   Status HashRoot(uint64* hash) { return HashNode(root_, hash); }
205 
CheckEqual(GraphHasher * that)206   Status CheckEqual(GraphHasher* that) {
207     return CheckNodesEqual(root_, that, that->root_);
208   }
209 
210  private:
HashNode(const NodeDef * node,uint64 * hash)211   Status HashNode(const NodeDef* node, uint64* hash) {
212     auto it = cache_.find(node);
213     if (it != cache_.end()) {
214       *hash = it->second;
215       return Status::OK();
216     }
217 
218     NodeRep* node_rep = gtl::FindOrNull(nodes_, node);
219     if (node_rep == nullptr) {
220       return errors::InvalidArgument("Could not find node: ", node->name());
221     }
222 
223     uint64 non_input_hash;
224     TF_RETURN_IF_ERROR(
225         HashNodeNonInput(node, /*hash_functions=*/true, &non_input_hash));
226 
227     uint64 control_inputs_hash;
228     TF_RETURN_IF_ERROR(
229         HashControlInputs(node_rep->node_control_inputs, &control_inputs_hash));
230 
231     // Hash regular inputs. We combine them in an ordered fashion.
232     uint64 inputs_hash = 0;
233     for (const auto& input : node_rep->node_inputs) {
234       uint64 node_hash = 0;
235       EdgeRep edge(node, input.first);
236       // If the edge was pruned we get the non input node hash to avoid cycles.
237       if (cycle_forming_edges_.find(edge.GetHash()) !=
238           cycle_forming_edges_.end()) {
239         TF_RETURN_IF_ERROR(
240             HashNodeNonInput(input.first, /*hash_functions=*/true, &node_hash));
241       } else {
242         TF_RETURN_IF_ERROR(HashNode(input.first, &node_hash));
243       }
244       inputs_hash = Hash64Combine(
245           inputs_hash, Hash64Combine(node_hash, Hash64(input.second)));
246     }
247 
248     *hash = Hash64Combine(non_input_hash,
249                           Hash64Combine(control_inputs_hash, inputs_hash));
250     cache_[node] = *hash;
251     return Status::OK();
252   }
253 
CheckNodesEqual(const NodeDef * this_node,GraphHasher * that,const NodeDef * that_node)254   Status CheckNodesEqual(const NodeDef* this_node, GraphHasher* that,
255                          const NodeDef* that_node) {
256     Status s = CheckNodesEqualHelper(this_node, that, that_node);
257     if (!s.ok()) {
258       return errors::FailedPrecondition("Nodes ", this_node->name(), " and ",
259                                         that_node->name(),
260                                         " are not the same:\n", s);
261     }
262     return s;
263   }
264 
CheckNodesEqualHelper(const NodeDef * this_node,GraphHasher * that,const NodeDef * that_node)265   Status CheckNodesEqualHelper(const NodeDef* this_node, GraphHasher* that,
266                                const NodeDef* that_node) {
267     TF_RETURN_IF_ERROR(CheckNodesEqualNonInput(this_node, that, that_node,
268                                                /*compare_functions=*/true));
269 
270     TF_RETURN_IF_ERROR(
271         CheckControlInputsEqual(nodes_[this_node].node_control_inputs, that,
272                                 that->nodes_[that_node].node_control_inputs));
273 
274     auto& this_node_inputs = nodes_[this_node].node_inputs;
275     auto& that_node_inputs = that->nodes_[that_node].node_inputs;
276     if (this_node_inputs.size() != that_node_inputs.size()) {
277       return errors::FailedPrecondition(
278           "Nodes have different numbers of node inputs: ",
279           this_node_inputs.size(), " vs ", that_node_inputs.size());
280     }
281     for (int i = 0; i < this_node_inputs.size(); ++i) {
282       const NodeDef* this_input = this_node_inputs[i].first;
283       const NodeDef* that_input = that_node_inputs[i].first;
284       if (is_cycle_forming_edge(this_node, this_input)) {
285         TF_RETURN_IF_ERROR(CheckNodesEqualNonInput(this_input, that, that_input,
286                                                    /*compare_functions=*/true));
287       } else {
288         TF_RETURN_IF_ERROR(CheckNodesEqual(this_input, that, that_input));
289       }
290       std::string this_input_suffix = this_node_inputs[i].second;
291       std::string that_input_suffix = that_node_inputs[i].second;
292       if (this_input_suffix != that_input_suffix) {
293         return errors::FailedPrecondition(
294             "Node inputs ", this_input->name(), " and ", that_input->name(),
295             " have different suffixes: ", this_input_suffix, " vs ",
296             that_input_suffix);
297       }
298     }
299     return Status::OK();
300   }
301 
HashNodeNonInput(const NodeDef * node,bool hash_functions,uint64 * hash)302   Status HashNodeNonInput(const NodeDef* node, bool hash_functions,
303                           uint64* hash) {
304     // Hash Attrs. We get the list of attrs from the op registry and then look
305     // up their values in the NodeDef attr map. This avoids looping over
306     // a map which is non-deterministic.
307     uint64 attrs_hash = 0;
308     const OpRegistrationData* reg;
309     TF_RETURN_IF_ERROR(flib_->LookUp(node->op(), &reg));
310     uint64 op_hash = 0;
311     if (reg->is_function_op) {
312       if (hash_functions) {
313         TF_RETURN_IF_ERROR(HashFunction(node->op(), node->attr(), &op_hash));
314       }
315     } else {
316       op_hash = Hash64(node->op());
317     }
318 
319     for (const auto& attr : reg->op_def.attr()) {
320       const auto& attr_key = attr.name();
321       if (!node->attr().contains(attr_key)) continue;
322       auto attr_value = node->attr().at(attr_key);
323       if (attr_key == kColocationAttrName ||
324           attr_key == kColocationGroupPrefix) {
325         continue;
326       }
327       uint64 attr_hash = 0;
328       TF_RETURN_IF_ERROR(
329           HashAttr(attr_key, attr_value, hash_functions, &attr_hash));
330       attrs_hash = Hash64Combine(attrs_hash, attr_hash);
331     }
332 
333     // Hash Device.
334     uint64 device_hash = Hash64(node->device());
335 
336     *hash = Hash64Combine(op_hash, Hash64Combine(attrs_hash, device_hash));
337     return Status::OK();
338   }
339 
CheckNodesEqualNonInput(const NodeDef * this_node,GraphHasher * that,const NodeDef * that_node,bool compare_functions)340   Status CheckNodesEqualNonInput(const NodeDef* this_node, GraphHasher* that,
341                                  const NodeDef* that_node,
342                                  bool compare_functions) {
343     // We get the list of attrs from the op registry and then look
344     // up their values in the NodeDef attr map. This avoids looping over
345     // a map which is non-deterministic.
346     const OpRegistrationData* reg;
347     TF_RETURN_IF_ERROR(flib_->LookUp(this_node->op(), &reg));
348     if (reg->is_function_op) {
349       if (compare_functions) {
350         TF_RETURN_IF_ERROR(
351             CheckFunctionsEqual(this_node->op(), this_node->attr(), that,
352                                 that_node->op(), that_node->attr()));
353       }
354     } else {
355       if (this_node->op() != that_node->op()) {
356         return errors::FailedPrecondition(
357             "ops for nodes ", this_node->name(), " and ", that_node->name(),
358             " are different: ", this_node->op(), " != ", that_node->op());
359       }
360     }
361 
362     for (const auto& attr : reg->op_def.attr()) {
363       const auto& attr_key = attr.name();
364       if (this_node->attr().contains(attr_key) !=
365           that_node->attr().contains(attr_key)) {
366         return errors::FailedPrecondition(
367             "attr with key ", attr_key, " is different for nodes ",
368             this_node->name(), " and ", that_node->name(),
369             ". Present in former: ", this_node->attr().contains(attr_key),
370             ". Present in latter: ", that_node->attr().contains(attr_key));
371       }
372       if (!this_node->attr().contains(attr_key)) continue;
373       if (attr_key == kColocationAttrName ||
374           attr_key == kColocationGroupPrefix) {
375         continue;
376       }
377       auto this_attr = this_node->attr().at(attr_key);
378       auto that_attr = that_node->attr().at(attr_key);
379       TF_RETURN_IF_ERROR(CheckAttrsEqual(attr_key, this_attr, that, that_attr,
380                                          compare_functions));
381     }
382 
383     if (this_node->device() != that_node->device()) {
384       return errors::FailedPrecondition(
385           "Devices are different for nodes ", this_node->name(), " and ",
386           that_node->name(), ": ", this_node->device(), " vs ",
387           that_node->device());
388     }
389     return Status::OK();
390   }
391 
HashAttr(const std::string & attr_name,const AttrValue & attr_value,bool hash_functions,uint64 * hash)392   Status HashAttr(const std::string& attr_name, const AttrValue& attr_value,
393                   bool hash_functions, uint64* hash) {
394     uint64 value_hash = 0;
395     if (attr_value.has_func()) {
396       if (hash_functions) {
397         TF_RETURN_IF_ERROR(HashFunction(attr_value.func(), &value_hash));
398       }
399     } else if (attr_value.has_list() && attr_value.list().func_size() > 0) {
400       if (hash_functions) {
401         for (auto& func : attr_value.list().func()) {
402           uint64 func_hash;
403           TF_RETURN_IF_ERROR(HashFunction(func, &func_hash));
404           value_hash = Hash64Combine(value_hash, func_hash);
405         }
406       }
407     } else {
408       value_hash = DeterministicProtoHash64(attr_value);
409     }
410     *hash = Hash64(absl::StrCat(attr_name, "=", value_hash));
411     return Status::OK();
412   }
413 
CheckAttrsEqual(const std::string & attr_name,const AttrValue & this_attr,GraphHasher * that,const AttrValue & that_attr,bool compare_functions)414   Status CheckAttrsEqual(const std::string& attr_name,
415                          const AttrValue& this_attr, GraphHasher* that,
416                          const AttrValue& that_attr, bool compare_functions) {
417     if (this_attr.has_func() != that_attr.has_func()) {
418       return errors::FailedPrecondition(
419           "AttrValues are of different types: ", this_attr.DebugString(),
420           " vs ", that_attr.DebugString());
421     }
422     if (this_attr.has_func()) {
423       if (compare_functions) {
424         TF_RETURN_IF_ERROR(
425             CheckFunctionsEqual(this_attr.func(), that, that_attr.func()));
426       }
427       return Status::OK();
428     }
429     if (this_attr.has_list() != that_attr.has_list()) {
430       return errors::FailedPrecondition(
431           "AttrValues are of different types: ", this_attr.DebugString(),
432           " vs ", that_attr.DebugString());
433     }
434     if (this_attr.has_list()) {
435       if (this_attr.list().func_size() != that_attr.list().func_size()) {
436         return errors::FailedPrecondition(
437             "AttrValues have func lists of different sizes: ",
438             this_attr.DebugString(), " vs ", that_attr.DebugString());
439       }
440       if (compare_functions) {
441         for (int i = 0; i < this_attr.list().func_size(); ++i) {
442           TF_RETURN_IF_ERROR(CheckFunctionsEqual(this_attr.list().func(i), that,
443                                                  that_attr.list().func(i)));
444         }
445       }
446       return Status::OK();
447     }
448     uint64 this_hash, that_hash;
449     TF_RETURN_IF_ERROR(
450         HashAttr(attr_name, this_attr, /*hash_functions=*/true, &this_hash));
451     TF_RETURN_IF_ERROR(that->HashAttr(attr_name, that_attr,
452                                       /*hash_functions=*/true, &that_hash));
453     if (this_hash != that_hash) {
454       return errors::FailedPrecondition(
455           "AttrValues are different: ", this_attr.DebugString(), " vs ",
456           that_attr.DebugString());
457     }
458     return Status::OK();
459   }
460 
HashFunction(const NameAttrList & func,uint64 * hash)461   Status HashFunction(const NameAttrList& func, uint64* hash) {
462     return HashFunction(func.name(), func.attr(), hash);
463   }
464 
HashFunction(const std::string & name,const AttrValueMap & attrs,uint64 * hash)465   Status HashFunction(const std::string& name, const AttrValueMap& attrs,
466                       uint64* hash) {
467     const FunctionDef* fdef = flib_->Find(name);
468 
469     // Convert to a GraphDef.
470     std::unique_ptr<FunctionBody> fbody;
471     TF_RETURN_IF_ERROR(
472         FunctionDefToBodyHelper(*fdef, AttrSlice(&attrs), flib_, &fbody));
473     GraphDef graph_def = fbody->graph->ToGraphDefDebug();
474 
475     // For each return node, we create a new GraphHasher to compute a hash.
476     // We then combine these hashes to produce the hash ordered.
477     uint64 ret_nodes_hash = 0;
478     for (const auto& ret_node : fbody->ret_nodes) {
479       uint64 ret_node_hash = 0;
480       GraphHasher hasher(&graph_def, &ret_node->def(), flib_);
481       TF_RETURN_IF_ERROR(hasher.Init());
482       TF_RETURN_IF_ERROR(hasher.HashRoot(&ret_node_hash));
483       ret_nodes_hash = Hash64Combine(ret_nodes_hash, ret_node_hash);
484     }
485 
486     std::vector<const NodeDef*> control_rets;
487     for (const auto& control_ret_node : fbody->control_ret_nodes) {
488       control_rets.push_back(&control_ret_node->def());
489     }
490     uint64 control_ret_nodes_hash = 0;
491     TF_RETURN_IF_ERROR(
492         HashControlInputs(control_rets, &control_ret_nodes_hash));
493 
494     *hash = Hash64Combine(ret_nodes_hash, control_ret_nodes_hash);
495     return Status::OK();
496   }
497 
CheckFunctionsEqual(const NameAttrList & this_func,GraphHasher * that,const NameAttrList & that_func)498   Status CheckFunctionsEqual(const NameAttrList& this_func, GraphHasher* that,
499                              const NameAttrList& that_func) {
500     return CheckFunctionsEqual(this_func.name(), this_func.attr(), that,
501                                that_func.name(), that_func.attr());
502   }
CheckFunctionsEqual(const std::string & this_name,const AttrValueMap & this_attrs,GraphHasher * that,const std::string & that_name,const AttrValueMap & that_attrs)503   Status CheckFunctionsEqual(const std::string& this_name,
504                              const AttrValueMap& this_attrs, GraphHasher* that,
505                              const std::string& that_name,
506                              const AttrValueMap& that_attrs) {
507     Status s = CheckFunctionsEqualHelper(this_name, this_attrs, that, that_name,
508                                          that_attrs);
509     if (!s.ok()) {
510       return errors::FailedPrecondition("Functions ", this_name, " and ",
511                                         that_name, " are not the same:\n", s);
512     }
513     return s;
514   }
515 
CheckFunctionsEqualHelper(const std::string & this_name,const AttrValueMap & this_attrs,GraphHasher * that,const std::string & that_name,const AttrValueMap & that_attrs)516   Status CheckFunctionsEqualHelper(const std::string& this_name,
517                                    const AttrValueMap& this_attrs,
518                                    GraphHasher* that,
519                                    const std::string& that_name,
520                                    const AttrValueMap& that_attrs) {
521     const FunctionDef* this_fdef = flib_->Find(this_name);
522     const FunctionDef* that_fdef = that->flib_->Find(that_name);
523 
524     // Convert to GraphDefs.
525     std::unique_ptr<FunctionBody> this_fbody;
526     TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
527         *this_fdef, AttrSlice(&this_attrs), flib_, &this_fbody));
528     GraphDef this_graph_def = this_fbody->graph->ToGraphDefDebug();
529     std::unique_ptr<FunctionBody> that_fbody;
530     TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
531         *that_fdef, AttrSlice(&that_attrs), that->flib_, &that_fbody));
532     GraphDef that_graph_def = that_fbody->graph->ToGraphDefDebug();
533 
534     if (this_fbody->ret_nodes.size() != that_fbody->ret_nodes.size()) {
535       return errors::FailedPrecondition(
536           "Different numbers of ret nodes for functions ", this_name, " and ",
537           that_name, ": ", this_fbody->ret_nodes.size(), " vs ",
538           that_fbody->ret_nodes.size());
539     }
540     for (int i = 0; i < this_fbody->ret_nodes.size(); ++i) {
541       const NodeDef* this_root = &this_fbody->ret_nodes[i]->def();
542       const NodeDef* that_root = &that_fbody->ret_nodes[i]->def();
543       GraphHasher this_hasher(&this_graph_def, this_root, flib_);
544       TF_RETURN_IF_ERROR(this_hasher.Init());
545       GraphHasher that_hasher(&that_graph_def, that_root, that->flib_);
546       TF_RETURN_IF_ERROR(that_hasher.Init());
547       TF_RETURN_IF_ERROR(this_hasher.CheckEqual(&that_hasher));
548     }
549 
550     std::vector<const NodeDef*> this_control_rets;
551     for (const auto& control_ret_node : this_fbody->control_ret_nodes) {
552       this_control_rets.push_back(&control_ret_node->def());
553     }
554     std::vector<const NodeDef*> that_control_rets;
555     for (const auto& control_ret_node : that_fbody->control_ret_nodes) {
556       that_control_rets.push_back(&control_ret_node->def());
557     }
558     TF_RETURN_IF_ERROR(
559         CheckControlInputsEqual(this_control_rets, that, that_control_rets));
560     return Status::OK();
561   }
562 
HashControlInputs(const std::vector<const NodeDef * > & inputs,uint64 * hash)563   Status HashControlInputs(const std::vector<const NodeDef*>& inputs,
564                            uint64* hash) {
565     *hash = 0;
566     for (const NodeDef* input : inputs) {
567       uint64 node_hash = 0;
568       TF_RETURN_IF_ERROR(
569           HashNodeNonInput(input, /*hash_functions=*/false, &node_hash));
570       *hash = Hash64CombineUnordered(*hash, node_hash);
571     }
572     return Status::OK();
573   }
574 
CheckControlInputsEqual(const std::vector<const NodeDef * > & this_inputs,GraphHasher * that,const std::vector<const NodeDef * > & that_inputs)575   Status CheckControlInputsEqual(
576       const std::vector<const NodeDef*>& this_inputs, GraphHasher* that,
577       const std::vector<const NodeDef*>& that_inputs) {
578     absl::flat_hash_map<uint64, const NodeDef*> this_hashes;
579     for (const NodeDef* input : this_inputs) {
580       uint64 node_hash = 0;
581       TF_RETURN_IF_ERROR(
582           HashNodeNonInput(input, /*hash_functions=*/false, &node_hash));
583       this_hashes[node_hash] = input;
584     }
585     absl::flat_hash_map<uint64, const NodeDef*> that_hashes;
586     for (const NodeDef* input : that_inputs) {
587       uint64 node_hash = 0;
588       TF_RETURN_IF_ERROR(
589           HashNodeNonInput(input, /*hash_functions=*/false, &node_hash));
590       if (this_hashes.contains(node_hash)) {
591         this_hashes.erase(node_hash);
592       } else {
593         that_hashes[node_hash] = input;
594       }
595     }
596     if (!this_hashes.empty()) {
597       std::vector<std::string> this_unmatched;
598       for (const auto& it : this_hashes) {
599         this_unmatched.push_back(it.second->name());
600       }
601       std::vector<std::string> that_unmatched;
602       for (const auto& it : that_hashes) {
603         that_unmatched.push_back(it.second->name());
604       }
605       return errors::FailedPrecondition(
606           "Control dependencies are different. One node has dependencies [",
607           absl::StrJoin(this_unmatched, ", "),
608           "], which don't match any of the other node's dependencies [",
609           absl::StrJoin(that_unmatched, ", "), "]");
610     }
611     return Status::OK();
612   }
613 
614  private:
is_cycle_forming_edge(const NodeDef * start,const NodeDef * end)615   bool is_cycle_forming_edge(const NodeDef* start, const NodeDef* end) {
616     EdgeRep edge(start, end);
617     return cycle_forming_edges_.contains(edge.GetHash());
618   }
619 
620   struct NodeRep {
621     std::vector<const NodeDef*> node_control_inputs;
622     std::vector<std::pair<const NodeDef*, std::string>> node_inputs;
623   };
624 
625   struct EdgeRep {
626     const NodeDef* start_node;
627     const NodeDef* end_node;
628 
EdgeReptensorflow::data::__anon391c42a80111::GraphHasher::EdgeRep629     EdgeRep(const NodeDef* start, const NodeDef* end)
630         : start_node(start), end_node(end) {}
631 
GetHashtensorflow::data::__anon391c42a80111::GraphHasher::EdgeRep632     uint64 GetHash() {
633       return Hash64Combine(absl::Hash<const NodeDef*>()(start_node),
634                            absl::Hash<const NodeDef*>()(end_node));
635     }
636   };
637   const GraphDef* const graph_;                  // Not owned.
638   const NodeDef* const root_;                    // Not owned.
639   const FunctionLibraryDefinition* const flib_;  // Not owned.
640   // Edges that need to be pruned as their presence will cause cycles.
641   absl::flat_hash_set<uint64> cycle_forming_edges_;
642   absl::flat_hash_map<const NodeDef*, NodeRep> nodes_;
643   absl::flat_hash_map<const NodeDef*, uint64> cache_;
644 };
645 
646 }  // anonymous namespace
647 
HashTensor(const Tensor & tensor,uint64 * hash)648 Status HashTensor(const Tensor& tensor, uint64* hash) {
649   const tstring* s = nullptr;
650   // Hash tensor type.
651   *hash = Hash64Combine(0, tensor.dtype());
652   // Hash tensor shape.
653   for (int i = 0; i < tensor.shape().dims(); ++i) {
654     *hash = Hash64Combine(*hash, tensor.shape().dim_size(i));
655   }
656   // Hash tensor data.
657   switch (tensor.dtype()) {
658     case DT_RESOURCE:
659     case DT_VARIANT:
660       return errors::Unimplemented("Hashing ", DataTypeString(tensor.dtype()),
661                                    " is not supported.");
662     case DT_STRING:
663       s = tensor.flat<tstring>().data();
664       for (int i = 0; i < tensor.NumElements(); ++i, ++s) {
665         *hash = Hash64Combine(*hash, Hash64(s->data(), s->size()));
666       }
667       break;
668     default:
669       *hash = Hash64(tensor.tensor_data().data(), tensor.tensor_data().size());
670   }
671   return Status::OK();
672 }
673 
HashNode(const GraphDef & graph,const NodeDef & node,uint64 * hash)674 Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash) {
675   const FunctionLibraryDefinition flib_def(OpRegistry::Global(),
676                                            graph.library());
677   return HashNode(graph, node, flib_def, hash);
678 }
679 
HashNode(const GraphDef & graph,const NodeDef & node,const FunctionLibraryDefinition & flib_def,uint64 * hash)680 Status HashNode(const GraphDef& graph, const NodeDef& node,
681                 const FunctionLibraryDefinition& flib_def, uint64* hash) {
682   GraphHasher hasher(&graph, &node, &flib_def);
683   TF_RETURN_IF_ERROR(hasher.Init());
684   return hasher.HashRoot(hash);
685 }
686 
HashGraph(const GraphDef & graph_def,uint64 * hash)687 Status HashGraph(const GraphDef& graph_def, uint64* hash) {
688   const NodeDef* sink = nullptr;
689   TF_RETURN_IF_ERROR(GetSink(graph_def, &sink));
690   return HashNode(graph_def, *sink, hash);
691 }
692 
CheckGraphsEqual(const GraphDef & a,const GraphDef & b)693 Status CheckGraphsEqual(const GraphDef& a, const GraphDef& b) {
694   const NodeDef* sink_a;
695   TF_RETURN_IF_ERROR(GetSink(a, &sink_a));
696   const NodeDef* sink_b;
697   TF_RETURN_IF_ERROR(GetSink(b, &sink_b));
698   return CheckSubgraphsEqual(a, sink_a, b, sink_b);
699 }
700 
CheckSubgraphsEqual(const GraphDef & a,const NodeDef * node_a,const GraphDef & b,const NodeDef * node_b)701 Status CheckSubgraphsEqual(const GraphDef& a, const NodeDef* node_a,
702                            const GraphDef& b, const NodeDef* node_b) {
703   const FunctionLibraryDefinition flib_def_a(OpRegistry::Global(), a.library());
704   GraphHasher hasher_a(&a, node_a, &flib_def_a);
705   TF_RETURN_IF_ERROR(hasher_a.Init());
706 
707   const FunctionLibraryDefinition flib_def_b(OpRegistry::Global(), b.library());
708   GraphHasher hasher_b(&b, node_b, &flib_def_b);
709   TF_RETURN_IF_ERROR(hasher_b.Init());
710 
711   return hasher_a.CheckEqual(&hasher_b);
712 }
713 
714 }  // namespace data
715 }  // namespace tensorflow
716