• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/graph_constructor.h"
17 
18 #include <algorithm>
19 #include <set>
20 #include <string>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "tensorflow/core/common_runtime/shape_refiner.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/framework/function.pb.h"
30 #include "tensorflow/core/framework/graph.pb.h"
31 #include "tensorflow/core/framework/node_def.pb.h"
32 #include "tensorflow/core/framework/node_def_util.h"
33 #include "tensorflow/core/framework/tensor_shape.pb.h"
34 #include "tensorflow/core/framework/types.h"
35 #include "tensorflow/core/framework/versions.h"
36 #include "tensorflow/core/framework/versions.pb.h"
37 #include "tensorflow/core/graph/algorithm.h"
38 #include "tensorflow/core/graph/graph.h"
39 #include "tensorflow/core/graph/tensor_id.h"
40 #include "tensorflow/core/lib/core/errors.h"
41 #include "tensorflow/core/lib/gtl/flatmap.h"
42 #include "tensorflow/core/lib/gtl/flatset.h"
43 #include "tensorflow/core/lib/gtl/inlined_vector.h"
44 #include "tensorflow/core/lib/strings/scanner.h"
45 #include "tensorflow/core/lib/strings/str_util.h"
46 #include "tensorflow/core/platform/logging.h"
47 #include "tensorflow/core/platform/macros.h"
48 #include "tensorflow/core/public/version.h"
49 
50 namespace tensorflow {
51 
52 namespace {
53 
// We remove duplicate control inputs before adding edges to the Graph, so we
// can skip the expensive duplicates check in 'AddControlEdge'.
static constexpr const bool kDoNotCheckDuplicates = true;
57 
IsMerge(const NodeDef & node_def)58 inline bool IsMerge(const NodeDef& node_def) {
59   return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
60          node_def.op() == "_XlaMerge";
61 }
62 
IsNextIteration(const NodeDef & node_def)63 inline bool IsNextIteration(const NodeDef& node_def) {
64   return node_def.op() == "NextIteration" ||
65          node_def.op() == "RefNextIteration";
66 }
67 
// Returns true if `s` is a syntactically valid node name. A name is one or
// more '>'-separated segments; each segment starts with a letter, digit, or
// dot (an underscore is also allowed as the first character when
// `allow_internal_ops` is true) followed by any number of letters, digits,
// dashes, dots, slashes, or underscores.
bool IsValidNodeName(StringPiece s, bool allow_internal_ops) {
  using ::tensorflow::strings::Scanner;
  Scanner scanner(s);
  // Consume the first segment. The leading character class is wider when
  // internal ops (leading '_') are permitted.
  scanner
      .One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE
                              : Scanner::LETTER_DIGIT_DOT)
      .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);

  while (true) {
    if (!scanner.GetResult())  // Some error in previous iteration.
      return false;
    if (scanner.empty())  // No error, but nothing left, good.
      return true;

    // Absorb another piece, starting with a '>'
    scanner.One(Scanner::RANGLE)
        .One(Scanner::LETTER_DIGIT_DOT)
        .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
  }
}
88 
// Builds a Graph from protocol-buffer node definitions, optionally importing
// into a non-empty Graph with prefixing, name uniquification, input remapping,
// and injected control dependencies. This is an abstract base class; concrete
// subclasses provide access to the underlying NodeDef storage (see
// NodeDefCopyingGraphConstructor and NodeDefMovingGraphConstructor below).
// Callers use the static Construct() overloads as entry points.
class GraphConstructor {
 public:
  struct Options {
    // Options for the non-importing path (ConvertGraphDefToGraph).
    Options(const GraphConstructorOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(in.allow_internal_ops),
          expect_device_spec(in.expect_device_spec),
          importing(false),
          validate_nodes(in.validate_nodes),
          validate_colocation_constraints(false),
          add_default_attributes(in.add_default_attributes) {}
    // Options for the importing path (ImportGraphDef).
    Options(const ImportGraphDefOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(false),
          expect_device_spec(false),
          // Normalize the prefix so that a non-empty prefix always ends in
          // '/', making it a proper name scope delimiter.
          prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/")
                     ? in.prefix
                     : in.prefix + "/"),
          uniquify_names(in.uniquify_names),
          uniquify_prefix(in.uniquify_prefix),
          input_map(in.input_map.begin(), in.input_map.end()),
          skip_mapped_nodes(in.skip_mapped_nodes),
          control_dependencies(in.control_dependencies),
          return_tensors(in.return_tensors.begin(), in.return_tensors.end()),
          return_nodes(in.return_nodes),
          importing(true),
          validate_nodes(true),
          validate_colocation_constraints(in.validate_colocation_constraints),
          validate_shape(in.validate_shape),
          default_device(in.default_device) {}

    bool allow_internal_ops;
    bool expect_device_spec;

    string prefix;
    bool uniquify_names;
    bool uniquify_prefix;
    std::map<TensorId, TensorId> input_map;
    bool skip_mapped_nodes;
    std::vector<string> control_dependencies;
    std::vector<TensorId> return_tensors;
    std::vector<string> return_nodes;

    // TODO(ashankar): This bool exists to separate out functionality required
    // to make ImportGraphDef a close equivalent of Python's import_graph_def
    // without affecting the behavior of ConvertGraphDefToGraph at the time
    // ImportGraphDef was added.
    //
    // That said, the functionality here (shape and op validation) seems
    // applicable to ConvertGraphDefToGraph as well, so make an attempt to
    // remove this.
    bool importing;
    // If true, validates that nodes being converted have all expected attrs
    // set and no unknown attrs set by calling ValidateNodeDef().
    // `validate_nodes` is always true when `importing` is set.
    bool validate_nodes;
    bool validate_colocation_constraints;
    bool validate_shape = true;

    // If true, GraphConstructor will add attributes with their default
    // value to the Node when they are missing from the NodeDef.
    bool add_default_attributes = true;

    string default_device;
  };

  typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;

  // versions and library may be nullptr
  static Status Construct(
      const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
      const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
      std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys);

  static Status Construct(
      const Options& opts, GraphDef&& graph_def, Graph* g,
      ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys);

 protected:
  GraphConstructor(const Options& opts, Graph* g, ShapeRefiner* refiner,
                   std::vector<std::pair<Node*, int>>* return_tensors,
                   std::vector<Node*>* return_nodes,
                   std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : opts_(opts),
        g_(g),
        original_versions_(g->versions()),
        prefix_(opts.prefix),
        refiner_(refiner),
        return_tensors_(return_tensors),
        return_nodes_(return_nodes),
        missing_unused_input_map_keys_(missing_unused_input_map_keys) {}

  virtual ~GraphConstructor() {}

  // Runs the full import pipeline in order. On failure the caller (the
  // static Construct() wrappers) invokes Undo() to roll back `g_`.
  Status TryImport() {
    TF_RETURN_IF_ERROR(EnsureNoNameCollisions());
    TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies());
    TF_RETURN_IF_ERROR(BuildNodeIndex());
    TF_RETURN_IF_ERROR(InitFromEdges());

    // NOTE: Convert() invokes `consume_node_def()` on each node in the input
    // graph, so `get_node_def()` is no longer usable once it is called.
    TF_RETURN_IF_ERROR(Convert());

    TF_RETURN_IF_ERROR(AddBackEdges());
    TF_RETURN_IF_ERROR(UpdateVersionDef());
    TF_RETURN_IF_ERROR(PopulateReturnTensors());
    TF_RETURN_IF_ERROR(PopulateReturnNodes());
    TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys());
    UpdateUniquifiedColocationNames();
    FixupSourceAndSinkEdges(g_);
    return Status::OK();
  }

 private:
  Status EnsureNoNameCollisions();
  Status ValidateInputMapAndControlDependencies();
  Status BuildNodeIndex();
  Status InitFromEdges();
  Status Convert();
  Status AddBackEdges();
  Status UpdateVersionDef();
  Status PopulateReturnTensors();
  Status PopulateReturnNodes();
  Status PopulateMissingUnusedInputMapKeys();

  // Restores `g_` to its state before this import (removes added nodes and
  // restores the original version info).
  void Undo();

  // Prints cycles in the graph.
  void PrintCycles();
  // Performs DFS starting at `cur_node` and prints any cycles found.
  void DFS(int cur_node, std::vector<int>* cur_branch,
           std::vector<bool>* is_on_cur_branch,
           absl::flat_hash_set<int>* unvisited);
  Status IsNodeFullyMapped(const NodeDef& node_def, bool* is_node_mapped);
  Status ValidateColocationConstraints(const NodeDef& node_def);
  Status MakeNode(NodeDef&& node_def, Node** node);
  Status MakeEdge(Node* src, int output_index, Node* dst, int input_index);
  Status ValidateShape(Node* node);
  Status ModifyNodeDefForImport(NodeDef* node_def);
  // Modifies node_def's inputs according to opts_.input_map.
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function will mark inputs that are remapped to
  // true.
  void RemapNodeDefInputs(NodeDef* node_def,
                          std::vector<bool>* input_already_exists);
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function will add and mark control inputs as
  // true.
  void AddControlDependencies(NodeDef* node_def,
                              std::vector<bool>* input_already_exists);
  void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists,
                          NodeDef* node_def);

  // Modifies `node_def` if its name isn't unique, or if any of its inputs'
  // names have been uniquified. This must be called in topological order on all
  // nodes.
  void UniquifyNames(const std::vector<bool>& input_already_exists,
                     NodeDef* node_def);

  // Updates any constructed nodes' colocation group names if the name has been
  // updated by UniquifyNames. This is called after all the nodes have been
  // constructed so all the names have been uniquified if necessary.
  void UpdateUniquifiedColocationNames();

  // Returns true if `name` already exists in `g_` (either as a node name or
  // prefix).
  bool NameExistsInGraph(StringPiece name);

  // Returns true if `name` already exists in the GraphDef being imported
  // (either as a node name or prefix).
  bool NameExistsInGraphDef(StringPiece name);

  // Returns a unique version of `original_name`, or `original_name` if it's
  // already unique in the graph.
  string FindUniqueName(StringPiece original_name);

  // Decrement pending count for users of `processed` and add the ones that now
  // have all of their pending inputs satisfied to `ready_`.
  void UpdatePendingCountAndReady(int processed, bool is_next_iteration);

  // Subclasses override the following virtual methods to provide efficient
  // access to the original protocol buffer-based graph.

  // Returns the number of nodes in the graph.
  virtual size_t node_def_count() const = 0;
  // Returns the i^th node in the graph. Must not be called after
  // consume_node_def(i).
  virtual const NodeDef& get_node_def(int i) const = 0;
  // Destructively reads the i^th node in the graph, avoiding a copy if
  // possible. After calling this method, the result of get_node_def(i) is
  // undefined.
  virtual NodeDef consume_node_def(int i) = 0;
  // Returns the version information for the graph, or nullptr if none is
  // available.
  virtual const VersionDef* versions() const = 0;
  // Returns the function information for the graph, or nullptr if none is
  // available.
  virtual const FunctionDefLibrary* library() const = 0;

  // From constructor
  const Options opts_;
  Graph* g_;
  // Snapshot of g_'s versions taken before the import, used by Undo().
  const VersionDef original_versions_;

  // A copy of opts_.prefix, possibly uniquified.
  string prefix_;

  ShapeRefiner* refiner_;

  // May be null. Not owned.
  std::vector<std::pair<Node*, int>>* return_tensors_;

  // May be null. Not owned.
  std::vector<Node*>* return_nodes_;

  // May be null. Not owned.
  std::vector<SafeTensorId>* missing_unused_input_map_keys_;

  // Intermediate datastructure used to populate
  // `missing_unused_input_map_keys_`.
  std::set<TensorId> used_input_map_keys_;

  // Intermediate datastructure used to track the destinations of back edges.
  absl::flat_hash_set<int> merge_node_indices_;

  // Mapping from node name to the index within node_defs_.
  struct NodeInfo {
    explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
    // Containers require that we have a default constructor.
    NodeInfo() : NodeInfo(-1) {}
    int gdef_index;
    Node* node;  // nullptr until the NodeDef is converted to a Node.
  };
  gtl::FlatMap<StringPiece, NodeInfo, StringPieceHasher> gdef_nodes_;

  // Storage for StringPiece keys in gdef_nodes_. Typically, the StringPiece key
  // will refer to the string stored in `NodeDef::name()`. This intern table is
  // only used when the original NodeDef's name is changed.
  std::vector<string> string_intern_table_;

  // Prefixes already used in the GraphDef being imported.
  gtl::FlatSet<StringPiece, StringPieceHasher> gdef_prefixes_;

  // Mapping from node name to the existing node in g_.
  gtl::FlatMap<StringPiece, Node*, StringPieceHasher> existing_nodes_;

  // Prefixes already used in the graph.
  gtl::FlatSet<StringPiece, StringPieceHasher> existing_prefixes_;

  // Imported node names that have been uniquified. The key is the original
  // name, the value is the new unique name.
  gtl::FlatMap<string, string> uniquified_names_;

  // Index of NodeDefs in node_defs_ with all inputs already converted. We use a
  // (sorted) set so nodes are created in the order defined in the GraphDef.
  std::set<int> ready_;

  // Mapping between index within node_defs_ and the number of inputs that
  // still need to be converted.
  std::vector<int> pending_count_;

  // Mapping between index within node_defs_ and the index within node_defs_ of
  // all nodes it outputs to.
  std::vector<gtl::InlinedVector<int, 4>> outputs_;

  // Used in the conversion from node_defs_ to g_ to represent the ith input
  // of a node.
  struct InputInfo {
    explicit InputInfo(const string& node_name, Node* n, int i)
        : name(node_name), node(n), index(i) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string name;
    Node* node;
    int index;

    static bool IsControlInput(const InputInfo& input) {
      return input.index == Graph::kControlSlot;
    }
    static int CompareName(const InputInfo& lhs, const InputInfo& rhs) {
      return lhs.name < rhs.name;
    }
    static bool IsSameName(const InputInfo& lhs, const InputInfo& rhs) {
      return lhs.name == rhs.name;
    }
  };

  // Used in the conversion from node_defs_ to g_ to represent an edge from
  // the node named 'name' to node 'n'.
  struct EdgeInfo {
    explicit EdgeInfo(const string& name, int i1, Node* n, int i2)
        : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string src_name;
    int src_index;
    Node* dst_node;
    int dst_index;
  };
  // Back edges (e.g. NextIteration->Merge) deferred until AddBackEdges().
  std::vector<EdgeInfo> back_edges_;

  TF_DISALLOW_COPY_AND_ASSIGN(GraphConstructor);
};
393 
394 // Implementation of GraphConstructor that does not take ownership of the
395 // input NodeDef messages and thus copies the nodes into the constructed Graph*.
396 //
397 // NOTE(mrry): Whenever possible, use NodeDefMovingGraphConstructor, which
398 // avoids copying each NodeDef into the constructed Graph*.
class NodeDefCopyingGraphConstructor : public GraphConstructor {
 public:
  NodeDefCopyingGraphConstructor(
      const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
      const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
      std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : GraphConstructor(opts, g, refiner, return_tensors, return_nodes,
                         missing_unused_input_map_keys),
        node_defs_(node_defs),
        versions_(versions),
        library_(library) {}

 private:
  size_t node_def_count() const override { return node_defs_.size(); }
  const NodeDef& get_node_def(int i) const override { return *node_defs_[i]; }
  // The NodeDefs are borrowed, so "consuming" one necessarily makes a copy.
  NodeDef consume_node_def(int i) override { return *node_defs_[i]; }
  const VersionDef* versions() const override { return versions_; }
  const FunctionDefLibrary* library() const override { return library_; }

  const NodeDefSlice node_defs_;             // Not owned.
  const VersionDef* const versions_;         // May be null. Not owned.
  const FunctionDefLibrary* const library_;  // May be null. Not owned.
};
424 
425 // Implementation of GraphConstructor that takes ownership of the input
426 // GraphDef, and can perform destructive reads.
class NodeDefMovingGraphConstructor : public GraphConstructor {
 public:
  NodeDefMovingGraphConstructor(
      const Options& opts, GraphDef&& graph_def, Graph* g,
      ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : GraphConstructor(opts, g, refiner, return_tensors, return_nodes,
                         missing_unused_input_map_keys),
        graph_def_(std::move(graph_def)),
        // NOTE: graph_def_ is declared before is_consumed_, so it is
        // initialized first and node_size() is valid here.
        is_consumed_(graph_def_.node_size(), false) {}

 private:
  size_t node_def_count() const override { return graph_def_.node().size(); }
  const NodeDef& get_node_def(int i) const override {
    CHECK(!is_consumed_[i])
        << "NodeDef " << i << " accessed after it was consumed.";
    return graph_def_.node(i);
  }
  NodeDef consume_node_def(int i) override {
    CHECK(!is_consumed_[i]) << "NodeDef " << i << " consumed twice.";
    is_consumed_[i] = true;
    // Destructive move out of the owned GraphDef; get_node_def(i) is
    // invalid from this point on (enforced by is_consumed_).
    return std::move(*graph_def_.mutable_node(i));
  }
  const VersionDef* versions() const override { return &graph_def_.versions(); }
  const FunctionDefLibrary* library() const override {
    return &graph_def_.library();
  }

  GraphDef graph_def_;
  // is_consumed_[i] is true iff consume_node_def(i) has been called.
  std::vector<bool> is_consumed_;
};
459 
// Entry point for the copying path: validates the GraphDef version, runs the
// import, and rolls back `g` on failure.
/* static */ Status GraphConstructor::Construct(
    const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
    const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
    std::vector<std::pair<Node*, int>>* return_tensors,
    std::vector<Node*>* return_nodes,
    std::vector<SafeTensorId>* missing_unused_input_map_keys) {
  // `versions` may be null (see the declaration); only check when provided.
  if (versions) {
    TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
                                     TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
                                     "GraphDef", "graph"));
  }
  NodeDefCopyingGraphConstructor c(opts, node_defs, versions, library, g,
                                   refiner, return_tensors, return_nodes,
                                   missing_unused_input_map_keys);
  const Status s = c.TryImport();
  // Undo any partial modifications to `g` if the import failed.
  if (!s.ok()) c.Undo();
  return s;
}
478 
// Entry point for the moving path: takes ownership of `graph_def`, validates
// its version, runs the import, and rolls back `g` on failure.
/* static */ Status GraphConstructor::Construct(
    const Options& opts, GraphDef&& graph_def, Graph* g, ShapeRefiner* refiner,
    std::vector<std::pair<Node*, int>>* return_tensors,
    std::vector<Node*>* return_nodes,
    std::vector<SafeTensorId>* missing_unused_input_map_keys) {
  TF_RETURN_IF_ERROR(CheckVersions(graph_def.versions(), TF_GRAPH_DEF_VERSION,
                                   TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
                                   "GraphDef", "graph"));
  NodeDefMovingGraphConstructor c(opts, std::move(graph_def), g, refiner,
                                  return_tensors, return_nodes,
                                  missing_unused_input_map_keys);
  const Status s = c.TryImport();
  // Undo any partial modifications to `g` if the import failed.
  if (!s.ok()) c.Undo();
  return s;
}
494 
UpdatePendingCountAndReady(int processed,bool is_next_iteration)495 void GraphConstructor::UpdatePendingCountAndReady(int processed,
496                                                   bool is_next_iteration) {
497   for (size_t i = 0; i < outputs_[processed].size(); ++i) {
498     const int output = outputs_[processed][i];
499     // We didn't consider NextIteration->Merge edges when computing
500     // pending_counts_ so we should not have to consider it here either.
501     bool is_next_iteration_to_merge_edge =
502         is_next_iteration && merge_node_indices_.count(output) == 1;
503     if (!is_next_iteration_to_merge_edge) {
504       int* current_pending_count = &pending_count_[output];
505       CHECK_GT(*current_pending_count, 0);
506       (*current_pending_count)--;
507       if (*current_pending_count == 0) {
508         ready_.insert(output);
509       }
510     }
511   }
512 }
513 
514 // This could be expensive but we don't expect to call it often, if at all (only
515 // if there are multiple nodes in g_ with the same name)
NodeNameInValues(const std::map<TensorId,TensorId> & input_map,const StringPiece & node_name)516 bool NodeNameInValues(const std::map<TensorId, TensorId>& input_map,
517                       const StringPiece& node_name) {
518   for (auto iter = input_map.begin(); iter != input_map.end(); ++iter) {
519     if (iter->second.first == node_name) return true;
520   }
521   return false;
522 }
523 
NodeNameInValues(const std::vector<string> & control_dependencies,const StringPiece & node_name)524 bool NodeNameInValues(const std::vector<string>& control_dependencies,
525                       const StringPiece& node_name) {
526   return std::find(control_dependencies.begin(), control_dependencies.end(),
527                    node_name) != control_dependencies.end();
528 }
529 
530 // Adds any prefixes of `node_name` (not including the full name itself) to
531 // `prefixes`.
AddPrefixes(StringPiece node_name,gtl::FlatSet<StringPiece,StringPieceHasher> * prefixes)532 void AddPrefixes(StringPiece node_name,
533                  gtl::FlatSet<StringPiece, StringPieceHasher>* prefixes) {
534   size_t idx = -1;
535   while ((idx = node_name.find('/', idx + 1)) != StringPiece::npos) {
536     prefixes->insert(node_name.substr(0, idx));
537   }
538 }
539 
// Verifies that the names in the GraphDef being imported can coexist with the
// nodes already in `g_`, and indexes the existing node names/prefixes.
// May rewrite `prefix_` to a unique name when opts_.uniquify_prefix is set.
Status GraphConstructor::EnsureNoNameCollisions() {
  existing_nodes_.reserve(g_->num_nodes());
  // Populate existing_nodes_ and existing_prefixes_.
  for (Node* n : g_->nodes()) {
    bool already_exists = !existing_nodes_.insert({n->name(), n}).second;
    if (already_exists) {
      // Duplicate names in g_ are tolerated unless input_map or
      // control_dependencies refer to the duplicated name, in which case the
      // reference would be ambiguous.
      if (NodeNameInValues(opts_.input_map, n->name())) {
        return errors::InvalidArgument(
            "cannot resolve input_map because multiple nodes exist with name '",
            n->name(), "'");
      }
      if (NodeNameInValues(opts_.control_dependencies, n->name())) {
        return errors::InvalidArgument(
            "cannot resolve control_dependencies because multiple nodes exist "
            "with name '",
            n->name(), "'");
      }
    }
    AddPrefixes(n->name(), &existing_prefixes_);
  }
  if (prefix_.empty() && opts_.importing && !opts_.uniquify_names) {
    // No prefix and no uniquification: every imported name must be fresh.
    for (size_t i = 0; i < node_def_count(); ++i) {
      const string& name = get_node_def(i).name();
      if (NameExistsInGraph(name)) {
        return errors::InvalidArgument("Node name '", name,
                                       "' already exists in the Graph");
      }
    }
  } else if (!prefix_.empty()) {
    // `prefix_` is guaranteed to end in '/' (see Options), so strip that
    // trailing slash before validating it as a node-name segment.
    StringPiece prefix_no_slash(prefix_);
    prefix_no_slash.remove_suffix(1);
    if (!IsValidNodeName(prefix_no_slash, false)) {
      return errors::InvalidArgument("Imported node name prefix '", prefix_,
                                     "' would lead to invalid node names");
    }
    if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) {
      prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
    }
  }
  return Status::OK();
}
581 
// Checks that every destination referenced by opts_.input_map and every node
// in opts_.control_dependencies exists in `g_`, and that input_map entries
// don't mix control and non-control endpoints.
Status GraphConstructor::ValidateInputMapAndControlDependencies() {
  for (const auto& mapping : opts_.input_map) {
    TensorId src = mapping.first;
    TensorId dst = mapping.second;
    if (existing_nodes_.count(dst.first) == 0) {
      return errors::InvalidArgument(
          "node '", dst.first, "' in input_map does not exist in graph ",
          "(input_map entry: ", src.ToString(), "->", dst.ToString(), ")");
    }
    // A control endpoint (output index kControlSlot) can only be mapped to
    // another control endpoint, and vice versa.
    if ((src.second == Graph::kControlSlot) !=
        (dst.second == Graph::kControlSlot)) {
      return errors::InvalidArgument("input_map entry ", src.ToString(), "->",
                                     dst.ToString(), " between ",
                                     "control edge and non-control edge");
    }
  }
  for (const string& node : opts_.control_dependencies) {
    if (existing_nodes_.count(node) == 0) {
      return errors::InvalidArgument(
          "node '", node,
          "' in control_dependencies does not exist in "
          "graph");
    }
  }
  return Status::OK();
}
608 
// Validates every NodeDef in the input (name syntax, uniqueness, non-empty op,
// device spec when required, control inputs ordered last) and populates
// gdef_nodes_, gdef_prefixes_, and merge_node_indices_.
Status GraphConstructor::BuildNodeIndex() {
  // Validate the node names and add them to gdef_nodes_ and gdef_prefixes_.
  for (int n = 0; n < node_def_count(); ++n) {
    const NodeDef& node_def = get_node_def(n);
    if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "': Node name contains invalid characters");
    }
    if (!gdef_nodes_
             .insert(std::make_pair(StringPiece(node_def.name()), NodeInfo(n)))
             .second) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' is not unique");
    }
    // Validate the operation's type.
    if (node_def.op().empty()) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' does not specify an operation");
    }
    if (opts_.expect_device_spec && node_def.device().empty()) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' is missing a device specification");
    }
    // Track Merge nodes; they are the destinations of back edges (see
    // merge_node_indices_).
    if (IsMerge(node_def)) {
      merge_node_indices_.insert(n);
    }
    // Validate control edges at end
    bool in_control_dependence = false;
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      // Control inputs are marked with a leading '^' and must all follow the
      // regular (data) inputs.
      if (!input_name.empty() && absl::StartsWith(input_name, "^")) {
        in_control_dependence = true;
      } else if (in_control_dependence) {
        return errors::InvalidArgument(
            "Node '", node_def.name(),
            "': Control dependencies must come after regular dependencies");
      }
    }
    // Update gdef_prefixes_.
    AddPrefixes(node_def.name(), &gdef_prefixes_);
  }
  return Status::OK();
}
653 
// Builds the bookkeeping Convert() uses for its topological traversal:
//   - outputs_[src] lists the indices of the nodes consuming src's outputs,
//   - pending_count_[n] is the number of inputs of node n that must be
//     processed before n is considered ready,
//   - ready_ is seeded with the nodes whose pending count is already zero.
// Merge nodes fed by a NextIteration node (i.e. a while-loop back edge) are
// made ready after just one non-control input plus all control inputs, since
// the back-edge input cannot be satisfied before the loop body exists.
Status GraphConstructor::InitFromEdges() {
  const int num_nodes = node_def_count();
  pending_count_.reserve(num_nodes);
  outputs_.resize(num_nodes);
  // Names of all NextIteration nodes, collected up front so that loop back
  // edges into Merge nodes can be recognized in the pass below.
  gtl::FlatSet<string> next_iteration_nodes;
  for (int n = 0; n < node_def_count(); ++n) {
    const NodeDef& node_def = get_node_def(n);
    if (IsNextIteration(node_def)) {
      next_iteration_nodes.insert(node_def.name());
    }
  }

  // Parse the inputs for each node.
  for (int n = 0; n < num_nodes; ++n) {
    const NodeDef& node_def = get_node_def(n);
    int pending_count = node_def.input_size();
    if (IsMerge(node_def)) {
      // Cycles in the graph are only allowed for while loops. A while loop is
      // identified by an edge from a NextIteration node to a Merge node. For
      // such Merge nodes, only wait for one non-control input before
      // considering the node ready to process in Convert().
      int32 num_control_edges = 0;
      bool has_loop_back_edge = false;
      for (int i = 0; i < node_def.input_size(); ++i) {
        StringPiece input_name(node_def.input(i));
        if (absl::StartsWith(input_name, "^")) {
          num_control_edges++;
        } else {
          TensorId id(ParseTensorName(input_name));
          if (next_iteration_nodes.find(string(id.first)) !=
              next_iteration_nodes.end()) {
            has_loop_back_edge = true;
          }
        }
      }
      if (has_loop_back_edge) {
        pending_count = num_control_edges + 1;
      }
    }
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      TensorId id(ParseTensorName(input_name));
      if (opts_.input_map.count(id) == 0) {
        // If an input is not mapped, then the input should appear in the graph
        // being imported.
        auto iter = gdef_nodes_.find(id.first);
        if (iter == gdef_nodes_.end()) {
          return errors::InvalidArgument("Node '", node_def.name(),
                                         "': Unknown input node '",
                                         node_def.input(i), "'");
        }
        outputs_[iter->second.gdef_index].push_back(n);
      } else {
        // This input is mapped to an existing edge. Therefore this input is
        // as good as being already processed.
        --pending_count;
        DCHECK_GE(pending_count, 0);
      }
    }
    if (pending_count == 0) {
      ready_.insert(n);
    }
    pending_count_.push_back(pending_count);
  }
  return Status::OK();
}
720 
ValidateColocationConstraints(const NodeDef & node_def)721 Status GraphConstructor::ValidateColocationConstraints(
722     const NodeDef& node_def) {
723   if (!opts_.validate_colocation_constraints || !opts_.importing)
724     return Status::OK();
725   const auto iter = node_def.attr().find(kColocationAttrName);
726   if (iter == node_def.attr().end()) return Status::OK();
727   for (const string& c : iter->second.list().s()) {
728     StringPiece s(c);
729     if (absl::ConsumePrefix(&s, kColocationGroupPrefix) &&
730         gdef_nodes_.find(s) == gdef_nodes_.end()) {
731       return errors::InvalidArgument(
732           "Node '", node_def.name(),
733           "' expects to be colocated with unknown node '", s, "'");
734     }
735   }
736   return Status::OK();
737 }
738 
MakeNode(NodeDef && node_def,Node ** node)739 Status GraphConstructor::MakeNode(NodeDef&& node_def, Node** node) {
740   // Add the node to the graph.
741   Status status;
742   *node = g_->AddNode(std::move(node_def), &status);
743   if (!status.ok()) return status;
744   if (opts_.expect_device_spec) {
745     (*node)->set_assigned_device_name((*node)->def().device());
746   }
747   return Status::OK();
748 }
749 
ValidateShape(Node * node)750 Status GraphConstructor::ValidateShape(Node* node) {
751   if (!opts_.importing || !opts_.validate_shape) return Status::OK();
752   TF_RETURN_IF_ERROR(refiner_->AddNode(node));
753   // For nodes with the _output_shapes attribute, override the shape.
754   std::vector<const TensorShapeProto*> shape_attrs;
755   const char* kAttrName = "_output_shapes";
756   if (!TryGetNodeAttr(node->attrs(), kAttrName, &shape_attrs)) {
757     // No _output_shapes attribute, the AddNode call above was sufficient.
758     return Status::OK();
759   }
760   auto* ic = refiner_->GetContext(node);
761   DCHECK(ic != nullptr)
762       << "ShapeRefiner::AddNode() should have created the InferenceContext";
763   if (shape_attrs.size() < node->num_outputs()) {
764     return errors::InvalidArgument(
765         "Node '", node->name(), "' has ", node->num_outputs(),
766         " outputs but the ", kAttrName, " attribute specifies shapes for ",
767         shape_attrs.size(), " outputs");
768   }
769   // NOTE(skyewm): we don't raise an error here because some users depend on
770   // this behavior, even though it's unsafe.
771   // TODO(b/74619486): raise an error.
772   if (shape_attrs.size() > node->num_outputs()) {
773     LOG(WARNING) << "Node '" << node->name() << "' has " << node->num_outputs()
774                  << " outputs but the " << kAttrName
775                  << " attribute specifies shapes for " << shape_attrs.size()
776                  << " outputs. Output shapes may be inaccurate.";
777   }
778   for (int i = 0; i < node->num_outputs(); ++i) {
779     const TensorShapeProto& p = *shape_attrs[i];
780     shape_inference::ShapeHandle h;
781     Status s = ic->MakeShapeFromShapeProto(p, &h);
782     if (!s.ok()) {
783       return errors::InvalidArgument("Node '", node->name(), " has an invalid ",
784                                      kAttrName, " attribute (shape #", i,
785                                      " error:'", s.error_message(), "'");
786     }
787     s = refiner_->SetShape(node, i, h);
788     if (!s.ok()) {
789       // If the output shape is incompatible with what is inferred
790       // by the graph for a very specific whitelist of ops, then we
791       // ignore this output shape.  This can happen if there is a
792       // bug in the shape function for some operation, and the
793       // serialized graph def has the incorrect shape set when
794       // running on a newer binary with the fixed shape function.
795       // This is an escape hatch that allows us to correct shape
796       // functions that are not critical to correct execution but
797       // would cause graphs to fail if imported after correcting.
798       const string& op = node->type_string();
799       const std::vector<string> whitelist = {
800           // To be removed after 2017/03/08.
801           "RandomShuffleQueue",
802           "PaddingFIFOQueue",
803           "FIFOQueue",
804           "PriorityQueue",
805           "QueueSize",
806           "Stack",
807           "Barrier",
808           "BarrierReadySize",
809           "BarrierIncompleteSize",
810           "HashTable",
811           "MutableHashTable",
812           "MutableHashTableOfTensors",
813           "Mutex",
814           "CuckooTable",
815           "IndexTable",
816           "WholeFileReader",
817           "TextLineReader",
818           "FixedLengthRecordReader",
819           "TFRecordReader",
820           "IdentityReader",
821           "RefSwitch",
822           "RefEnter",
823           "RefNextIteration",
824           "RefMerge",
825           "RefIdentity",
826           "LMDBReader",
827           // To be removed after 2017/04/24.
828           "ConditionalAccumulator",
829           "SparseConditionalAccumulator",
830           "Table",
831       };
832       if (std::find(whitelist.begin(), whitelist.end(), op) ==
833           whitelist.end()) {
834         return errors::InvalidArgument(
835             "Node '", node->name(), "' has an ", kAttrName,
836             " attribute inconsistent with the GraphDef for output #", i, ": ",
837             s.error_message());
838       }
839     }
840   }
841   node->ClearAttr(kAttrName);
842   return Status::OK();
843 }
844 
ModifyNodeDefForImport(NodeDef * node_def)845 Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) {
846   const OpDef* op_def;
847   TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
848   AddDefaultsToNodeDef(*op_def, node_def);
849   TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def));
850   if (versions()) {
851     TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions()->producer()));
852   }
853   return Status::OK();
854 }
855 
RemoveInputs(const std::vector<int> & inputs_to_remove,NodeDef * node_def,std::vector<bool> * input_already_exists)856 void RemoveInputs(const std::vector<int>& inputs_to_remove, NodeDef* node_def,
857                   std::vector<bool>* input_already_exists) {
858   // Remove 'inputs_to_remove' from 'node_def'
859   NodeDef copy;
860   copy.mutable_input()->Reserve(node_def->input_size() -
861                                 inputs_to_remove.size());
862   for (int i = 0, j = 0; i < node_def->input_size(); ++i) {
863     if (j < inputs_to_remove.size() && i == inputs_to_remove[j]) {
864       ++j;
865     } else {
866       copy.add_input()->swap(*node_def->mutable_input(i));
867     }
868   }
869   node_def->mutable_input()->Swap(copy.mutable_input());
870   // Remove 'inputs_to_remove' from 'input_already_exists'
871   for (int idx : inputs_to_remove) {
872     input_already_exists->erase(input_already_exists->begin() + idx);
873   }
874   DCHECK_EQ(input_already_exists->size(), node_def->input_size());
875 }
876 
RemapNodeDefInputs(NodeDef * node_def,std::vector<bool> * input_already_exists)877 void GraphConstructor::RemapNodeDefInputs(
878     NodeDef* node_def, std::vector<bool>* input_already_exists) {
879   DCHECK_EQ(input_already_exists->size(), node_def->input_size());
880   std::set<TensorId> control_inputs;
881   std::vector<int> inputs_to_remove;
882 
883   for (int i = 0; i < node_def->input_size(); ++i) {
884     auto iter = opts_.input_map.find(ParseTensorName(node_def->input(i)));
885     if (iter == opts_.input_map.end()) continue;
886     used_input_map_keys_.insert(iter->first);
887 
888     TensorId new_input = iter->second;
889     if (new_input.second == Graph::kControlSlot) {
890       // Check if we've already remapped a different input to new_input, and if
891       // so remove this input.
892       if (control_inputs.count(new_input) > 0) {
893         inputs_to_remove.push_back(i);
894         continue;
895       }
896       control_inputs.insert(new_input);
897     }
898     node_def->set_input(i, new_input.ToString());
899     (*input_already_exists)[i] = true;
900   }
901   if (!inputs_to_remove.empty()) {
902     RemoveInputs(inputs_to_remove, node_def, input_already_exists);
903   }
904 }
905 
// Appends the control dependencies in opts_.control_dependencies to
// `node_def`, unless the node will inherit them transitively through one of
// its newly-imported inputs (in which case adding them again would only
// create redundant edges). For each dependency appended, pushes `true` onto
// `input_already_exists`, since the dependency node preexists in g_.
void GraphConstructor::AddControlDependencies(
    NodeDef* node_def, std::vector<bool>* input_already_exists) {
  // To avoid adding redundant control dependencies to every imported node, skip
  // nodes that will inherit the dependencies from another imported node.
  bool inherits_deps = false;
  for (int i = 0; i < node_def->input_size(); ++i) {
    // Assume we won't inherit dependencies from remapped inputs that already
    // exist in the graph. Even if we're wrong, we'll only add redundant
    // dependencies.
    if ((*input_already_exists)[i]) continue;

    // If this input is a backedge, assume we won't inherit the dependencies.
    // TODO(skyewm): we have many redundant ParseTensorName calls. It could be
    // worth optimizing these.
    TensorId id(ParseTensorName(node_def->input(i)));
    auto iter = gdef_nodes_.find(id.first);
    DCHECK(iter != gdef_nodes_.end()) << id.first;
    if (iter->second.node == nullptr) {
      // Input hasn't been created yet, indicating it's a backedge.
      continue;
    }
    inherits_deps = true;
  }
  if (inherits_deps) return;

  // node_def either has no inputs or all remapped inputs, add the control
  // dependencies
  for (const string& control_dep : opts_.control_dependencies) {
    string input = TensorId(control_dep, Graph::kControlSlot).ToString();
    bool found = false;
    // Scan backwards: control inputs are required to come last, so we can
    // stop at the first non-control input.
    for (int i = node_def->input_size() - 1; i >= 0; --i) {
      const string& node_input = node_def->input(i);
      // NOTE(review): node_input[0] assumes inputs are non-empty strings —
      // presumably guaranteed by earlier NodeDef validation; confirm.
      if (node_input[0] != '^') {
        // Control inputs are at the end. Break when we reach the non-control
        // inputs.
        break;
      }
      if (node_input == input) {
        // Control dependency already exists
        found = true;
        break;
      }
    }
    if (found) {
      continue;
    }
    node_def->add_input(input);
    input_already_exists->push_back(true);
  }
}
956 
AddPrefixToNodeDef(const std::vector<bool> & input_already_exists,NodeDef * node_def)957 void GraphConstructor::AddPrefixToNodeDef(
958     const std::vector<bool>& input_already_exists, NodeDef* node_def) {
959   if (prefix_.empty()) return;
960   node_def->set_name(strings::StrCat(prefix_, node_def->name()));
961   // Update names of input nodes
962   for (int i = 0; i < node_def->input_size(); ++i) {
963     // Skip remapped inputs (which already exist in g_ and are not being
964     // imported).
965     if (input_already_exists[i]) continue;
966     StringPiece input(node_def->input(i));
967     if (absl::ConsumePrefix(&input, "^")) {
968       node_def->set_input(i, strings::StrCat("^", prefix_, input));
969     } else {
970       node_def->set_input(i, strings::StrCat(prefix_, input));
971     }
972   }
973   // Update names of colocation groups
974   if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) {
975     auto* list =
976         node_def->mutable_attr()->at(kColocationAttrName).mutable_list();
977     for (int i = 0; i < list->s_size(); ++i) {
978       StringPiece v(list->s(i));
979       if (absl::ConsumePrefix(&v, kColocationGroupPrefix)) {
980         list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v));
981       }
982     }
983   }
984 }
985 
UniquifyNames(const std::vector<bool> & input_already_exists,NodeDef * node_def)986 void GraphConstructor::UniquifyNames(
987     const std::vector<bool>& input_already_exists, NodeDef* node_def) {
988   if (NameExistsInGraph(node_def->name())) {
989     string old_name = node_def->name();
990     node_def->set_name(FindUniqueName(node_def->name()));
991     uniquified_names_[old_name] = node_def->name();
992     // Note that we don't have to update gdef_nodes_ or gdef_prefixes_ with
993     // `name` because we guarantee the original NodeDef names are unique,
994     // meaning we won't generate this name again.
995   }
996   for (int i = 0; i < node_def->input_size(); ++i) {
997     // Skip remapped inputs (which already exist in g_ and are not being
998     // imported).
999     if (input_already_exists[i]) continue;
1000     TensorId id = ParseTensorName(node_def->input(i));
1001     // We require that UniquifyNames() is called on all NodeDefs in topological
1002     // order. This guarantees that node_def's inputs will already be uniquified
1003     // if necessary.
1004     auto iter = uniquified_names_.find(string(id.first));
1005     if (iter == uniquified_names_.end()) continue;
1006     id.first = iter->second;
1007     node_def->set_input(i, id.ToString());
1008   }
1009 }
1010 
UpdateUniquifiedColocationNames()1011 void GraphConstructor::UpdateUniquifiedColocationNames() {
1012   for (const auto& pair : gdef_nodes_) {
1013     Node* node = pair.second.node;
1014     if (node == nullptr) continue;
1015     std::vector<string> coloc_values;
1016     if (!TryGetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values))
1017       continue;
1018     bool updated = false;
1019     for (size_t i = 0; i < coloc_values.size(); ++i) {
1020       StringPiece val(coloc_values[i]);
1021       if (absl::ConsumePrefix(&val, kColocationGroupPrefix)) {
1022         auto name_pair = uniquified_names_.find(string(val));
1023         if (name_pair == uniquified_names_.end()) continue;
1024         updated = true;
1025         coloc_values[i] =
1026             strings::StrCat(kColocationGroupPrefix, name_pair->second);
1027       }
1028     }
1029     if (updated) {
1030       node->AddAttr(kColocationAttrName, std::move(coloc_values));
1031     }
1032   }
1033 }
1034 
NameExistsInGraph(StringPiece name)1035 bool GraphConstructor::NameExistsInGraph(StringPiece name) {
1036   if (existing_nodes_.find(name) != existing_nodes_.end()) return true;
1037   if (existing_prefixes_.find(name) != existing_prefixes_.end()) return true;
1038   return false;
1039 }
1040 
NameExistsInGraphDef(StringPiece name)1041 bool GraphConstructor::NameExistsInGraphDef(StringPiece name) {
1042   if (gdef_nodes_.find(name) != gdef_nodes_.end()) return true;
1043   if (gdef_prefixes_.find(name) != gdef_prefixes_.end()) return true;
1044   return false;
1045 }
1046 
FindUniqueName(StringPiece original_name)1047 string GraphConstructor::FindUniqueName(StringPiece original_name) {
1048   string name(original_name);
1049   int count = 0;
1050   // Check that any generated names don't collide with imported NodeDefs (as
1051   // well as nodes in g_).
1052   while (NameExistsInGraph(name) || (count > 0 && NameExistsInGraphDef(name))) {
1053     name = strings::StrCat(original_name, "_", ++count);
1054   }
1055   return name;
1056 }
1057 
IsNodeFullyMapped(const NodeDef & node_def,bool * is_node_mapped)1058 Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def,
1059                                            bool* is_node_mapped) {
1060   const OpDef* op_def;
1061   TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
1062   for (int i = 0; i < op_def->output_arg_size(); ++i) {
1063     if (opts_.input_map.find({node_def.name(), i}) == opts_.input_map.end()) {
1064       *is_node_mapped = false;
1065       return Status::OK();
1066     }
1067   }
1068   *is_node_mapped = true;
1069   return Status::OK();
1070 }
1071 
// Depth-first search helper for PrintCycles(). `cur_branch` holds the DFS
// path from the root to `cur_node` in traversal order; `is_on_cur_branch`
// mirrors its membership for O(1) lookups. An edge leading back onto the
// current branch means a cycle: the portion of `cur_branch` from that node
// onward is logged. Nodes are erased from `unvisited` post-order, after all
// paths through them have been explored.
void GraphConstructor::DFS(int cur_node, std::vector<int>* cur_branch,
                           std::vector<bool>* is_on_cur_branch,
                           absl::flat_hash_set<int>* unvisited) {
  cur_branch->push_back(cur_node);
  is_on_cur_branch->at(cur_node) = true;
  for (auto next_node : outputs_[cur_node]) {
    if (unvisited->find(next_node) != unvisited->end()) {
      if (is_on_cur_branch->at(next_node)) {
        // Back edge onto the active branch: log the cycle from next_node to
        // the end of the branch.
        auto iter =
            std::find(cur_branch->begin(), cur_branch->end(), next_node);
        LOG(WARNING) << "Cycle detected:";
        while (iter != cur_branch->end()) {
          LOG(WARNING) << SummarizeNodeDef(get_node_def(*iter));
          ++iter;
        }
        LOG(WARNING) << "End of cycle";
      } else {
        DFS(next_node, cur_branch, is_on_cur_branch, unvisited);
      }
    }
  }
  // Backtrack: cur_node leaves the active branch and is fully explored.
  cur_branch->pop_back();
  is_on_cur_branch->at(cur_node) = false;
  unvisited->erase(cur_node);
}
1097 
PrintCycles()1098 void GraphConstructor::PrintCycles() {
1099   int num_nodes = outputs_.size();
1100   absl::flat_hash_set<int> unvisited;
1101   for (int i = 0; i < num_nodes; i++) {
1102     unvisited.insert(i);
1103   }
1104   while (!unvisited.empty()) {
1105     int cur_node = *unvisited.begin();
1106     // Nodes on the current branch of DFS in traversal order. This is used for
1107     // printing the nodes in the cycle.
1108     std::vector<int> cur_branch;
1109     // This is just to make lookups O(1).
1110     // is_on_cur_branch[i] ==
1111     //   (std::find(cur_branch.start(),
1112     //              cur_branch.end(), i) != cur_branch.end())
1113     std::vector<bool> is_on_cur_branch(num_nodes, false);
1114     DFS(cur_node, &cur_branch, &is_on_cur_branch, &unvisited);
1115   }
1116 }
1117 
// Core conversion loop: processes NodeDefs in topological order (seeded by
// InitFromEdges()), applies the import-time rewrites (input remapping,
// control dependencies, default device, name prefixing/uniquification),
// creates a Node in g_ for each NodeDef, and wires its input edges. Data
// back edges into Merge nodes are deferred to AddBackEdges(). Returns
// InvalidArgument if any nodes remain unprocessed, which indicates an
// illegal cycle.
Status GraphConstructor::Convert() {
  // Import functions before adding nodes, since imported nodes may refer to
  // functions
  if (library()) {
    // TODO(b/135705010): Add rvalue overloads into the function library, to
    // avoid unnecessarily copying `*library()` here.
    TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library()));
  }

  std::vector<InputInfo> inputs;
  int processed = 0;

  std::vector<bool> input_already_exists;

  // Process the NodeDefs in topological order.
  // (InitFromEdges() sets this up by filling in ready_ with nodes that have no
  // inputs, pending_counts_ with the number of inputs for each node and
  // outputs_ with the outputs of each node).
  while (!ready_.empty()) {
    int o = *ready_.begin();
    ready_.erase(ready_.begin());
    ++processed;
    inputs.clear();
    bool has_data_back_edge = false;

    NodeDef node_def = consume_node_def(o);

    // input_already_exists[i] is true iff the i-th input of the node we're
    // importing refers to a preexisting node in g_ (i.e. input[i] existed prior
    // to importing node_defs_).  Conversely, input_already_exists[i] is false
    // iff the input refers to a node in node_defs_.
    input_already_exists.clear();
    input_already_exists.resize(node_def.input_size(), false);

    ssize_t string_intern_table_index = -1;

    if (opts_.importing) {
      // Intern the original node name, so that we can use a StringPiece of the
      // name to index gdef_nodes_.
      string_intern_table_index = string_intern_table_.size();
      string_intern_table_.push_back(node_def.name());

      if (opts_.skip_mapped_nodes) {
        bool is_node_mapped = false;
        TF_RETURN_IF_ERROR(IsNodeFullyMapped(node_def, &is_node_mapped));
        if (is_node_mapped) {
          // Skip this node after updating pending_count_ for outputs
          UpdatePendingCountAndReady(o, IsNextIteration(node_def));
          continue;
        }
      }

      if (!opts_.input_map.empty()) {
        // Note that input_already_exists can shrink here
        RemapNodeDefInputs(&node_def, &input_already_exists);
      }
      if (!opts_.control_dependencies.empty()) {
        // Note that input_already_exists can grow here
        AddControlDependencies(&node_def, &input_already_exists);
      }
      if (!opts_.default_device.empty() && node_def.device().empty()) {
        node_def.set_device(opts_.default_device);
      }
    }

    DCHECK_EQ(node_def.input_size(), input_already_exists.size());
    TF_RETURN_IF_ERROR(ValidateColocationConstraints(node_def));
    // Resolve each input to a (source node, output index) pair, looking in
    // either the newly-imported nodes or the preexisting graph.
    for (int i = 0; i < node_def.input_size(); ++i) {
      TensorId tensor_id = ParseTensorName(node_def.input(i));
      Node* src_node;
      int src_index;

      if (!input_already_exists[i]) {
        // Locate input in newly-imported nodes
        auto iter = gdef_nodes_.find(tensor_id.node());
        DCHECK(iter != gdef_nodes_.end()) << tensor_id.node();
        src_node = iter->second.node;
        src_index = tensor_id.index();
        if (src_node == nullptr) has_data_back_edge = true;
      } else {
        // Input refers to preexistng node in graph
        auto iter = existing_nodes_.find(tensor_id.node());
        DCHECK(iter != existing_nodes_.end()) << tensor_id.node();
        src_node = iter->second;
        src_index = tensor_id.index();
      }

      if (src_node != nullptr && src_index >= src_node->num_outputs()) {
        std::ostringstream out;
        out << "Node '" << node_def.name() << "': Connecting to invalid output "
            << tensor_id.index() << " of source node " << tensor_id.node()
            << " which has " << src_node->num_outputs() << " outputs.";

        if (src_node->type_string() == "If" ||
            src_node->type_string() == "StatelessIf" ||
            src_node->type_string() == "While" ||
            src_node->type_string() == "StatelessWhile") {
          out << " Try using "
              << "tf.compat.v1.experimental.output_all_intermediates(True).";
        }
        return errors::InvalidArgument(out.str());
      }

      inputs.emplace_back(string(tensor_id.node()), src_node, src_index);
    }

    if (has_data_back_edge && !IsMerge(node_def)) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "' had a back edge, but only Merge nodes can have back edges.");
    }

    Node* node;
    if (opts_.importing) {
      if (!prefix_.empty()) {
        AddPrefixToNodeDef(input_already_exists, &node_def);
      }
      // Note: no need to uniquify names if the prefix already guarantees
      // uniqueness
      if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) {
        UniquifyNames(input_already_exists, &node_def);
      }
    }

    if (opts_.importing) {
      TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&node_def));
    } else {
      const OpDef* op_def;
      TF_RETURN_IF_ERROR(
          g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
      if (opts_.add_default_attributes) {
        AddDefaultsToNodeDef(*op_def, &node_def);
      }
      if (opts_.validate_nodes) {
        TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def));
      }
    }

    TF_RETURN_IF_ERROR(MakeNode(std::move(node_def), &node));

    if (opts_.importing) {
      // Use interned original node name so StringPiece remains valid.
      DCHECK_GE(string_intern_table_index, 0);
      gdef_nodes_[string_intern_table_[string_intern_table_index]].node = node;
    } else {
      DCHECK_EQ(string_intern_table_index, -1);
      gdef_nodes_[node->name()].node = node;
    }

    // Remove duplicate control inputs before adding edges to the graph. It
    // will allow us to skip expensive duplicates check in 'AddControlEdge'.
    auto first_control = absl::c_find_if(inputs, &InputInfo::IsControlInput);
    auto first_control_copy = first_control;
    std::sort(first_control, inputs.end(), &InputInfo::CompareName);
    inputs.erase(
        std::unique(first_control_copy, inputs.end(), &InputInfo::IsSameName),
        inputs.end());

    // Add edges from inputs to *node to the graph.
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (inputs[i].node == nullptr) {
        // Record this back edge, which will be added after all nodes
        // are created.
        back_edges_.emplace_back(inputs[i].name, inputs[i].index, node, i);
      } else if (inputs[i].index == Graph::kControlSlot) {
        g_->AddControlEdge(inputs[i].node, node, kDoNotCheckDuplicates);
      } else {
        TF_RETURN_IF_ERROR(MakeEdge(inputs[i].node, inputs[i].index, node, i));
      }
    }

    TF_RETURN_IF_ERROR(ValidateShape(node));

    // Update pending_count_ for outputs.
    UpdatePendingCountAndReady(o, node->IsNextIteration());
  }

  // Nodes that were never ready are part of a cycle; report them.
  if (processed < node_def_count()) {
    LOG(WARNING) << "IN " << __func__ << " " << (node_def_count() - processed)
                 << " NODES IN A CYCLE";
    for (int64 i = 0; i < node_def_count(); i++) {
      if (pending_count_[i] != 0) {
        LOG(WARNING) << "PENDING: " << SummarizeNodeDef(get_node_def(i))
                     << " WITH PENDING COUNT = " << pending_count_[i];
      }
    }
    PrintCycles();
    return errors::InvalidArgument(node_def_count() - processed,
                                   " nodes in a cycle");
  }

  return Status::OK();
}
1311 
AddBackEdges()1312 Status GraphConstructor::AddBackEdges() {
1313   // Add the back edges after all nodes are created.
1314   for (auto e : back_edges_) {
1315     Node* src_node = gdef_nodes_[e.src_name].node;
1316     if (e.src_index == Graph::kControlSlot) {
1317       g_->AddControlEdge(src_node, e.dst_node, kDoNotCheckDuplicates);
1318     } else {
1319       TF_RETURN_IF_ERROR(
1320           MakeEdge(src_node, e.src_index, e.dst_node, e.dst_index));
1321     }
1322 
1323     VLOG(2) << "Add back edge: " << src_node->name() << " -> "
1324             << e.dst_node->name();
1325   }
1326   return Status::OK();
1327 }
1328 
UpdateVersionDef()1329 Status GraphConstructor::UpdateVersionDef() {
1330   if (versions() == nullptr) return Status::OK();
1331 
1332   if (!opts_.importing) {
1333     g_->set_versions(*versions());
1334     return Status::OK();
1335   }
1336   VersionDef g_versions = g_->versions();
1337   g_versions.set_producer(
1338       std::min(g_versions.producer(), versions()->producer()));
1339   g_versions.set_min_consumer(
1340       std::max(g_versions.min_consumer(), versions()->min_consumer()));
1341   if (versions()->bad_consumers_size() > 0) {
1342     std::set<int> bad(g_versions.bad_consumers().begin(),
1343                       g_versions.bad_consumers().end());
1344     bad.insert(versions()->bad_consumers().begin(),
1345                versions()->bad_consumers().end());
1346     g_versions.clear_bad_consumers();
1347     for (int v : bad) {
1348       g_versions.add_bad_consumers(v);
1349     }
1350   }
1351   g_->set_versions(g_versions);
1352   return Status::OK();
1353 }
1354 
PopulateReturnTensors()1355 Status GraphConstructor::PopulateReturnTensors() {
1356   if (opts_.return_tensors.empty()) return Status::OK();
1357   for (const TensorId& id : opts_.return_tensors) {
1358     auto iter = opts_.input_map.find(id);
1359     if (iter == opts_.input_map.end()) {
1360       // Locate id in imported nodes
1361       auto iter = gdef_nodes_.find(id.first);
1362       if (iter == gdef_nodes_.end()) {
1363         return errors::InvalidArgument("Requested return tensor '",
1364                                        id.ToString(),
1365                                        "' not found in graph def");
1366       }
1367       int num_outputs = iter->second.node->num_outputs();
1368       if ((id.second < 0 || id.second >= num_outputs) &&
1369           id.second != Graph::kControlSlot) {
1370         return errors::InvalidArgument("Invalid return output ", id.second,
1371                                        " of node '", id.first, "', which has ",
1372                                        num_outputs, " output(s)");
1373       }
1374       return_tensors_->push_back({iter->second.node, id.second});
1375     } else {
1376       // id was remapped to existing node
1377       TensorId remapped_id = iter->second;
1378       DCHECK_GT(existing_nodes_.count(remapped_id.first), 0);
1379       Node* node = existing_nodes_[remapped_id.first];
1380       return_tensors_->push_back({node, remapped_id.second});
1381     }
1382   }
1383   return Status::OK();
1384 }
1385 
PopulateReturnNodes()1386 Status GraphConstructor::PopulateReturnNodes() {
1387   if (opts_.return_nodes.empty()) return Status::OK();
1388   for (StringPiece name : opts_.return_nodes) {
1389     auto iter = gdef_nodes_.find(name);
1390     if (iter == gdef_nodes_.end()) {
1391       return errors::InvalidArgument("Requested return node '", name,
1392                                      "' not found in graph def");
1393     }
1394     return_nodes_->push_back(iter->second.node);
1395   }
1396   return Status::OK();
1397 }
1398 
PopulateMissingUnusedInputMapKeys()1399 Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
1400   if (missing_unused_input_map_keys_ == nullptr) return Status::OK();
1401   for (const auto& input_map_pair : opts_.input_map) {
1402     TensorId key = input_map_pair.first;
1403     if (used_input_map_keys_.count(key) > 0) continue;
1404 
1405     auto pair = gdef_nodes_.find(key.first);
1406     if (pair == gdef_nodes_.end()) {
1407       // key's node doesn't exist in GraphDef
1408       missing_unused_input_map_keys_->push_back(key);
1409       continue;
1410     }
1411 
1412     // Check that key's index is in bounds. Get the number of outputs from the
1413     // NodeDef, rather than the imported Node, since the Node may not exist if
1414     // opts_.skip_mapped_nodes is true.
1415     const NodeDef& node_def = get_node_def(pair->second.gdef_index);
1416     const OpDef* op_def;
1417     TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
1418     int num_outputs;
1419     TF_RETURN_IF_ERROR(NumOutputsForNode(node_def, *op_def, &num_outputs));
1420     if (key.second >= num_outputs) {
1421       // key's index out of bounds
1422       missing_unused_input_map_keys_->push_back(key);
1423     }
1424   }
1425   return Status::OK();
1426 }
1427 
Undo()1428 void GraphConstructor::Undo() {
1429   for (const auto& iter : gdef_nodes_) {
1430     if (iter.second.node != nullptr) {
1431       g_->RemoveNode(iter.second.node);
1432     }
1433   }
1434   g_->set_versions(original_versions_);
1435 }
1436 
MakeEdge(Node * src,int output_index,Node * dst,int input_index)1437 Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
1438                                   int input_index) {
1439   DataType src_out = src->output_type(output_index);
1440   DataType dst_in = dst->input_type(input_index);
1441   if (!TypesCompatible(dst_in, src_out)) {
1442     return errors::InvalidArgument(
1443         "Input ", input_index, " of node ", dst->name(), " was passed ",
1444         DataTypeString(src_out), " from ", src->name(), ":", output_index,
1445         " incompatible with expected ", DataTypeString(dst_in), ".");
1446   }
1447   g_->AddEdge(src, output_index, dst, input_index);
1448   return Status::OK();
1449 }
1450 
1451 }  // namespace
1452 
ConvertGraphDefToGraph(const GraphConstructorOptions & opts,const GraphDef & gdef,Graph * g)1453 Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
1454                               const GraphDef& gdef, Graph* g) {
1455   ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
1456   return GraphConstructor::Construct(
1457       opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
1458       /*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
1459       /*missing_unused_input_map_keys=*/nullptr);
1460 }
1461 
ConvertGraphDefToGraph(const GraphConstructorOptions & opts,GraphDef && gdef,Graph * g)1462 Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
1463                               GraphDef&& gdef, Graph* g) {
1464   ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
1465   return GraphConstructor::Construct(opts, std::move(gdef), g, &refiner,
1466                                      /*return_tensors=*/nullptr,
1467                                      /*return_nodes=*/nullptr,
1468                                      /*missing_unused_input_map_keys=*/nullptr);
1469 }
1470 
ConvertNodeDefsToGraph(const GraphConstructorOptions & opts,gtl::ArraySlice<NodeDef> nodes,Graph * g)1471 Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
1472                               gtl::ArraySlice<NodeDef> nodes, Graph* g) {
1473   ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry());
1474   // TODO(irving): Copy will go away once NodeInfo exists
1475   std::vector<const NodeDef*> node_defs;
1476   node_defs.reserve(nodes.size());
1477   for (const auto& n : nodes) {
1478     node_defs.push_back(&n);
1479   }
1480   return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
1481                                      &refiner, /*return_tensors=*/nullptr,
1482                                      /*return_nodes=*/nullptr,
1483                                      /*missing_unused_input_map_keys=*/nullptr);
1484 }
1485 
ImportGraphDef(const ImportGraphDefOptions & opts,const GraphDef & gdef,Graph * g,ShapeRefiner * refiner,ImportGraphDefResults * results)1486 Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
1487                       Graph* g, ShapeRefiner* refiner,
1488                       ImportGraphDefResults* results) {
1489   if (!opts.return_tensors.empty()) {
1490     if (results == nullptr) {
1491       return errors::InvalidArgument(
1492           "results argument to ImportGraphDef() must be non-null if "
1493           "opts.return_tensors is non-empty");
1494     }
1495   }
1496 
1497   if (!opts.return_nodes.empty()) {
1498     if (opts.skip_mapped_nodes) {
1499       return errors::InvalidArgument(
1500           "Requesting return_nodes with skip_mapped_nodes set is not currently "
1501           "supported");
1502     }
1503     if (results == nullptr) {
1504       return errors::InvalidArgument(
1505           "results argument to ImportGraphDef() must be non-null if "
1506           "opts.return_nodes is non-empty");
1507     }
1508   }
1509 
1510   if (results != nullptr) {
1511     if (!results->return_tensors.empty() || !results->return_nodes.empty() ||
1512         !results->missing_unused_input_map_keys.empty()) {
1513       return errors::InvalidArgument(
1514           "All fields in results argument to ImportGraphDef() must be empty.");
1515     }
1516   }
1517 
1518   ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
1519   if (refiner == nullptr) {
1520     refiner = &default_refiner;
1521   } else {
1522     // Log a warning if we are importing a GraphDef at an older
1523     // producer version after already having added non-source/sink
1524     // nodes to the graph in the past.
1525     if (gdef.versions().producer() > 0 &&
1526         gdef.versions().producer() < refiner->graph_def_version() &&
1527         g->num_nodes() > 2) {
1528       LOG(WARNING) << "Importing a graph with a lower producer version "
1529                    << gdef.versions().producer()
1530                    << " into an existing graph with producer version "
1531                    << refiner->graph_def_version() << ". Shape inference will "
1532                    << "have run different parts of the graph with different "
1533                    << "producer versions.";
1534     }
1535   }
1536 
1537   // Set the graph def version of the refiner as the min of the
1538   // current value and the version from the graph we are about to
1539   // import.
1540   //
1541   // Note: to match Run() semantics, we should re-run shape inference
1542   // on the entire graph if the producer version has changed.  For now
1543   // we log the warning above.
1544   refiner->set_graph_def_version(
1545       std::min(refiner->graph_def_version(), gdef.versions().producer()));
1546 
1547   if (results == nullptr) {
1548     return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
1549                                        &gdef.library(), g, refiner, nullptr,
1550                                        nullptr, nullptr);
1551   } else {
1552     return GraphConstructor::Construct(
1553         opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
1554         &results->return_tensors, &results->return_nodes,
1555         &results->missing_unused_input_map_keys);
1556   }
1557 }
1558 
CopyGraph(const Graph & src,Graph * dest)1559 void CopyGraph(const Graph& src, Graph* dest) {
1560   for (Node* n : dest->nodes()) {
1561     CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
1562   }
1563 
1564   // Copy GraphDef versions
1565   dest->set_versions(src.versions());
1566 
1567   // Copy the nodes.
1568   // "Node in src" -> "Node in *dest"
1569   gtl::FlatMap<const Node*, Node*> node_map;
1570   node_map[src.source_node()] = dest->source_node();
1571   node_map[src.sink_node()] = dest->sink_node();
1572   for (Node* n : src.op_nodes()) {
1573     node_map[n] = dest->CopyNode(n);
1574   }
1575 
1576   // Copy the edges
1577   for (const Edge* e : src.edges()) {
1578     Node* src_copy = node_map[e->src()];
1579     Node* dst_copy = node_map[e->dst()];
1580     dest->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
1581   }
1582 }
1583 
1584 }  // namespace tensorflow
1585