• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/graph_constructor.h"
17 
18 #include <algorithm>
19 #include <set>
20 #include <string>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_cat.h"
28 #include "tensorflow/core/common_runtime/shape_refiner.h"
29 #include "tensorflow/core/framework/function.h"
30 #include "tensorflow/core/framework/function.pb.h"
31 #include "tensorflow/core/framework/graph.pb.h"
32 #include "tensorflow/core/framework/node_def.pb.h"
33 #include "tensorflow/core/framework/node_def_util.h"
34 #include "tensorflow/core/framework/tensor_shape.pb.h"
35 #include "tensorflow/core/framework/types.h"
36 #include "tensorflow/core/framework/versions.h"
37 #include "tensorflow/core/framework/versions.pb.h"
38 #include "tensorflow/core/graph/algorithm.h"
39 #include "tensorflow/core/graph/graph.h"
40 #include "tensorflow/core/graph/tensor_id.h"
41 #include "tensorflow/core/lib/core/errors.h"
42 #include "tensorflow/core/lib/gtl/flatmap.h"
43 #include "tensorflow/core/lib/gtl/flatset.h"
44 #include "tensorflow/core/lib/gtl/inlined_vector.h"
45 #include "tensorflow/core/lib/strings/scanner.h"
46 #include "tensorflow/core/lib/strings/str_util.h"
47 #include "tensorflow/core/platform/errors.h"
48 #include "tensorflow/core/platform/logging.h"
49 #include "tensorflow/core/platform/macros.h"
50 #include "tensorflow/core/public/version.h"
51 
52 namespace tensorflow {
53 
54 namespace {
55 
// We remove duplicate control inputs before adding edges to the Graph, so we
// can skip expensive duplicates check in 'AddControlEdge'.
// NOTE: `constexpr` already implies const, and the anonymous namespace already
// gives internal linkage, so the original `static`/extra `const` were
// redundant.
constexpr bool kDoNotCheckDuplicates = true;
59 
IsMerge(const NodeDef & node_def)60 inline bool IsMerge(const NodeDef& node_def) {
61   return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
62          node_def.op() == "_XlaMerge";
63 }
64 
IsNextIteration(const NodeDef & node_def)65 inline bool IsNextIteration(const NodeDef& node_def) {
66   return node_def.op() == "NextIteration" ||
67          node_def.op() == "RefNextIteration";
68 }
69 
// Returns true if `s` is a syntactically valid node name.
//
// A valid name is one or more segments separated by '>'. Each segment must
// start with a letter, digit, or dot (additionally underscore when
// `allow_internal_ops` is true) and may continue with letters, digits,
// dashes, dots, slashes, and underscores.
bool IsValidNodeName(StringPiece s, bool allow_internal_ops) {
  using ::tensorflow::strings::Scanner;
  Scanner scanner(s);
  // First character class depends on whether internal ("_"-prefixed) ops are
  // allowed.
  scanner
      .One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE
                              : Scanner::LETTER_DIGIT_DOT)
      .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);

  while (true) {
    if (!scanner.GetResult())  // Some error in previous iteration.
      return false;
    if (scanner.empty())  // No error, but nothing left, good.
      return true;

    // Absorb another piece, starting with a '>'
    scanner.One(Scanner::RANGLE)
        .One(Scanner::LETTER_DIGIT_DOT)
        .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
  }
}
90 
// Shared implementation of NodeDef/GraphDef -> Graph conversion, used by both
// ConvertGraphDefToGraph and ImportGraphDef. Subclasses provide access to the
// underlying NodeDef storage (copying vs. destructive-move; see the two
// concrete subclasses below).
class GraphConstructor {
 public:
  struct Options {
    // Options used by ConvertGraphDefToGraph(): no prefixing/uniquification,
    // no input remapping, and `importing` is false.
    Options(const GraphConstructorOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(in.allow_internal_ops),
          expect_device_spec(in.expect_device_spec),
          importing(false),
          validate_nodes(in.validate_nodes),
          validate_colocation_constraints(false),
          add_default_attributes(in.add_default_attributes) {}
    // Options used by ImportGraphDef(): normalizes `prefix` so it always ends
    // with '/', and forces full node validation (`importing`/`validate_nodes`
    // are true).
    Options(const ImportGraphDefOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(false),
          expect_device_spec(false),
          prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/")
                     ? in.prefix
                     : in.prefix + "/"),
          uniquify_names(in.uniquify_names),
          uniquify_prefix(in.uniquify_prefix),
          input_map(in.input_map.begin(), in.input_map.end()),
          skip_mapped_nodes(in.skip_mapped_nodes),
          control_dependencies(in.control_dependencies),
          return_tensors(in.return_tensors.begin(), in.return_tensors.end()),
          return_nodes(in.return_nodes),
          importing(true),
          validate_nodes(true),
          validate_colocation_constraints(in.validate_colocation_constraints),
          validate_shape(in.validate_shape),
          default_device(in.default_device) {}

    bool allow_internal_ops;
    bool expect_device_spec;

    // Always empty or '/'-terminated (enforced by the constructors above).
    string prefix;
    bool uniquify_names;
    bool uniquify_prefix;
    std::map<TensorId, TensorId> input_map;
    bool skip_mapped_nodes;
    std::vector<string> control_dependencies;
    std::vector<TensorId> return_tensors;
    std::vector<string> return_nodes;

    // TODO(ashankar): This bool exists to separate out functionality required
    // to make ImportGraphDef a close equivalent of Python's import_graph_def
    // without affecting the behavior of ConvertGraphDefToGraph at the time
    // ImportGraphDef was added.
    //
    // That said, the functionality here (shape and op validation) seems
    // applicable to ConvertGraphDefToGraph as well, so make an attempt to
    // remove this.
    bool importing;
    // If true, validates that nodes being converted have all expected attrs
    // set and no unknown attrs set by calling ValidateNodeDef().
    // `validate_nodes` is always true when `importing` is set.
    bool validate_nodes;
    bool validate_colocation_constraints;
    bool validate_shape = true;

    // If true, GraphConstructor will add attributes with their default
    // value to the Node when they are missing from the NodeDef.
    bool add_default_attributes = true;

    string default_device;
  };

  typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;

  // versions and library may be nullptr
  static Status Construct(
      const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
      const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
      std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys);

  // Takes ownership of `graph_def`, which allows destructive (move-based)
  // reads of its NodeDefs.
  static Status Construct(
      const Options& opts, GraphDef&& graph_def, Graph* g,
      ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys);

 protected:
  GraphConstructor(const Options& opts, Graph* g, ShapeRefiner* refiner,
                   std::vector<std::pair<Node*, int>>* return_tensors,
                   std::vector<Node*>* return_nodes,
                   std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : opts_(opts),
        g_(g),
        original_versions_(g->versions()),
        prefix_(opts.prefix),
        refiner_(refiner),
        return_tensors_(return_tensors),
        return_nodes_(return_nodes),
        missing_unused_input_map_keys_(missing_unused_input_map_keys) {}

  virtual ~GraphConstructor() {}

  // Runs the full import pipeline. On failure the caller is responsible for
  // calling Undo() to roll back any nodes already added to g_.
  Status TryImport() {
    TF_RETURN_IF_ERROR(EnsureNoNameCollisions());
    TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies());
    TF_RETURN_IF_ERROR(BuildNodeIndex());
    TF_RETURN_IF_ERROR(InitFromEdges());

    // NOTE: Convert() invokes `consume_node_def()` on each node in the input
    // graph, so `get_node_def()` is no longer usable once it is called.
    TF_RETURN_IF_ERROR(Convert());

    TF_RETURN_IF_ERROR(AddBackEdges());
    TF_RETURN_IF_ERROR(UpdateVersionDef());
    TF_RETURN_IF_ERROR(PopulateReturnTensors());
    TF_RETURN_IF_ERROR(PopulateReturnNodes());
    TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys());
    UpdateUniquifiedColocationNames();
    FixupSourceAndSinkEdges(g_);
    return Status::OK();
  }

 private:
  Status EnsureNoNameCollisions();
  Status ValidateInputMapAndControlDependencies();
  Status BuildNodeIndex();
  Status InitFromEdges();
  Status Convert();
  Status AddBackEdges();
  Status UpdateVersionDef();
  Status PopulateReturnTensors();
  Status PopulateReturnNodes();
  Status PopulateMissingUnusedInputMapKeys();

  // Rolls back nodes added to g_ by a failed import.
  void Undo();

  // Prints cycles in the graph.
  void PrintCycles();
  // Performs DFS starting at `cur_node` and prints any cycles found.
  void DFS(int cur_node, std::vector<int>* cur_branch,
           std::vector<bool>* is_on_cur_branch,
           absl::flat_hash_set<int>* unvisited);
  Status IsNodeFullyMapped(const NodeDef& node_def, bool* is_node_mapped);
  Status ValidateColocationConstraints(const NodeDef& node_def);
  Status MakeNode(NodeDef&& node_def, Node** node);
  Status MakeEdge(Node* src, int output_index, Node* dst, int input_index);
  Status ValidateShape(Node* node);
  Status ModifyNodeDefForImport(NodeDef* node_def);
  // Modifies node_def's inputs according to opts_.input_map.
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function will mark inputs that are remapped to
  // true.
  void RemapNodeDefInputs(NodeDef* node_def,
                          std::vector<bool>* input_already_exists);
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function will add and mark control inputs as
  // true.
  void AddControlDependencies(NodeDef* node_def,
                              std::vector<bool>* input_already_exists);
  void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists,
                          NodeDef* node_def);

  // Modifies `node_def` if its name isn't unique, or if any of its inputs'
  // names have been uniquified. This must be called in topological order on all
  // nodes.
  void UniquifyNames(const std::vector<bool>& input_already_exists,
                     NodeDef* node_def);

  // Updates any constructed nodes' colocation group names if the name has been
  // updated by UniquifyNames. This is called after all the nodes have been
  // constructed so all the names have been uniquified if necessary.
  void UpdateUniquifiedColocationNames();

  // Returns true if `name` already exists in `g_` (either as a node name or
  // prefix).
  bool NameExistsInGraph(StringPiece name);

  // Returns true if `name` already exists in the GraphDef being imported
  // (either as a node name or prefix).
  bool NameExistsInGraphDef(StringPiece name);

  // Returns a unique version of `original_name`, or `original_name` if it's
  // already unique in the graph.
  string FindUniqueName(StringPiece original_name);

  // Decrement pending count for users of `processed` and add the ones that now
  // have all of their pending inputs satisfied to `ready_`.
  void UpdatePendingCountAndReady(int processed, bool is_next_iteration);

  // Subclasses override the following virtual methods to provide efficient
  // access to the original protocol buffer-based graph.

  // Returns the number of nodes in the graph.
  virtual size_t node_def_count() const = 0;
  // Returns the i^th node in the graph. Must not be called after
  // consume_node_def(i).
  virtual const NodeDef& get_node_def(int i) const = 0;
  // Destructively reads the i^th node in the graph, avoiding a copy if
  // possible. After calling this method, the result of get_node_def(i) is
  // undefined.
  virtual NodeDef consume_node_def(int i) = 0;
  // Returns the version information for the graph, or nullptr if none is
  // available.
  virtual const VersionDef* versions() const = 0;
  // Returns the function information for the graph, or nullptr if none is
  // available.
  virtual const FunctionDefLibrary* library() const = 0;

  // From constructor
  const Options opts_;
  Graph* g_;
  const VersionDef original_versions_;

  // A copy of opts_.prefix, possibly uniquified.
  string prefix_;

  ShapeRefiner* refiner_;

  // May be null. Not owned.
  std::vector<std::pair<Node*, int>>* return_tensors_;

  // May be null. Not owned.
  std::vector<Node*>* return_nodes_;

  // May be null. Not owned.
  std::vector<SafeTensorId>* missing_unused_input_map_keys_;

  // Intermediate datastructure used to populate
  // `missing_unused_input_map_keys_`.
  std::set<TensorId> used_input_map_keys_;

  // Intermediate datastructure used to track the destinations of back edges.
  absl::flat_hash_set<int> merge_node_indices_;

  // Mapping from node name to the index within node_defs_.
  struct NodeInfo {
    explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
    // Containers require that we have a default constructor.
    NodeInfo() : NodeInfo(-1) {}
    int gdef_index;
    Node* node;  // nullptr until the NodeDef is converted to a Node.
  };
  absl::flat_hash_map<std::string, NodeInfo> gdef_nodes_;

  // Prefixes already used in the GraphDef being imported.
  absl::flat_hash_set<StringPiece> gdef_prefixes_;

  // Mapping from node name to the existing node in g_.
  absl::flat_hash_map<StringPiece, Node*> existing_nodes_;

  // Prefixes already used in the graph.
  absl::flat_hash_set<StringPiece> existing_prefixes_;

  // Imported node names that have been uniquified. The key is the original
  // name, the value is the new unique name.
  gtl::FlatMap<string, string> uniquified_names_;

  // Index of NodeDefs in node_defs_ with all inputs already converted. We use a
  // (sorted) set so nodes are created in the order defined in the GraphDef.
  std::set<int> ready_;

  // Mapping between index within node_defs_ and the number of inputs that
  // still need to be converted.
  std::vector<int> pending_count_;

  // Mapping between index within node_defs_ and the index within node_defs_ of
  // all nodes it outputs to.
  std::vector<gtl::InlinedVector<int, 4>> outputs_;

  // Used in the conversion from node_defs_ to g_ to represent the ith input
  // of a node.
  struct InputInfo {
    explicit InputInfo(const string& node_name, Node* n, int i)
        : name(node_name), node(n), index(i) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string name;
    Node* node;
    int index;

    static bool IsControlInput(const InputInfo& input) {
      return input.index == Graph::kControlSlot;
    }
    static int CompareName(const InputInfo& lhs, const InputInfo& rhs) {
      return lhs.name < rhs.name;
    }
    static bool IsSameName(const InputInfo& lhs, const InputInfo& rhs) {
      return lhs.name == rhs.name;
    }
  };

  // Used in the conversion from node_defs_ to g_ to represent an edge from
  // the node named 'name' to node 'n'.
  struct EdgeInfo {
    explicit EdgeInfo(const string& name, int i1, Node* n, int i2)
        : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string src_name;
    int src_index;
    Node* dst_node;
    int dst_index;
  };
  std::vector<EdgeInfo> back_edges_;

  TF_DISALLOW_COPY_AND_ASSIGN(GraphConstructor);
};
390 
// Implementation of GraphConstructor that does not take ownership of the
// input NodeDef messages and thus copies the nodes into the constructed Graph*.
//
// NOTE(mrry): Whenever possible, use NodeDefMovingGraphConstructor, which
// avoids copying each NodeDef into the constructed Graph*.
class NodeDefCopyingGraphConstructor : public GraphConstructor {
 public:
  NodeDefCopyingGraphConstructor(
      const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
      const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
      std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : GraphConstructor(opts, g, refiner, return_tensors, return_nodes,
                         missing_unused_input_map_keys),
        node_defs_(node_defs),
        versions_(versions),
        library_(library) {}

 private:
  size_t node_def_count() const override { return node_defs_.size(); }
  const NodeDef& get_node_def(int i) const override { return *node_defs_[i]; }
  // "Consuming" still copies here, since the NodeDefs are not owned.
  NodeDef consume_node_def(int i) override { return *node_defs_[i]; }
  const VersionDef* versions() const override { return versions_; }
  const FunctionDefLibrary* library() const override { return library_; }

  // Not owned; must outlive this constructor. `versions_`/`library_` may be
  // null (see GraphConstructor::Construct).
  const NodeDefSlice node_defs_;
  const VersionDef* const versions_;
  const FunctionDefLibrary* const library_;
};
421 
// Implementation of GraphConstructor that takes ownership of the input
// GraphDef, and can perform destructive reads.
class NodeDefMovingGraphConstructor : public GraphConstructor {
 public:
  NodeDefMovingGraphConstructor(
      const Options& opts, GraphDef&& graph_def, Graph* g,
      ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : GraphConstructor(opts, g, refiner, return_tensors, return_nodes,
                         missing_unused_input_map_keys),
        graph_def_(std::move(graph_def)),
        // Reading graph_def_ here is safe: members are initialized in
        // declaration order, and graph_def_ is declared before is_consumed_.
        is_consumed_(graph_def_.node_size(), false) {}

 private:
  size_t node_def_count() const override { return graph_def_.node().size(); }
  const NodeDef& get_node_def(int i) const override {
    CHECK(!is_consumed_[i])
        << "NodeDef " << i << " accessed after it was consumed.";
    return graph_def_.node(i);
  }
  NodeDef consume_node_def(int i) override {
    CHECK(!is_consumed_[i]) << "NodeDef " << i << " consumed twice.";
    is_consumed_[i] = true;
    // Moves the NodeDef out of the owned GraphDef; get_node_def(i) is
    // invalid from now on (guarded by is_consumed_).
    return std::move(*graph_def_.mutable_node(i));
  }
  const VersionDef* versions() const override { return &graph_def_.versions(); }
  const FunctionDefLibrary* library() const override {
    return &graph_def_.library();
  }

  GraphDef graph_def_;
  // Tracks, per node index, whether the NodeDef has been moved out.
  std::vector<bool> is_consumed_;
};
456 
ForwardCompatibilityWindowPassed(const VersionDef & versions)457 bool ForwardCompatibilityWindowPassed(const VersionDef& versions) {
458   // TF_GRAPH_DEF_VERSION is incremented daily.
459   // TF has a 3 week forward compatibility guarantee.
460   return (versions.producer() - TF_GRAPH_DEF_VERSION) > 21;
461 }
462 
// If `versions` shows the GraphDef producer is outside this binary's forward
// compatibility window, returns a Status with the same error code as
// `import_status` but with an expanded message pointing at the version skew
// as a likely root cause. Otherwise returns `import_status` unchanged.
// `versions` may be null, in which case no warning is appended.
Status MaybeAppendVersionWarning(const VersionDef* versions,
                                 const Status& import_status) {
  if (versions && ForwardCompatibilityWindowPassed(*versions)) {
    return Status(
        import_status.code(),
        absl::StrCat(
            "Converting GraphDef to Graph has failed with an error: '",
            import_status.error_message(),
            "' The binary trying to import the GraphDef was built when "
            "GraphDef version was ",
            TF_GRAPH_DEF_VERSION,
            ". The GraphDef was produced by a binary built when GraphDef "
            "version was ",
            versions->producer(),
            ". The difference between these versions is larger than "
            "TensorFlow's forward compatibility guarantee, and might be the "
            "root cause for failing to import the GraphDef."));
  }
  return import_status;
}
483 
// Imports `node_defs` into `g` by copying each NodeDef. `versions` and
// `library` may be nullptr; when `versions` is null, version checking is
// skipped. On failure, partially-applied changes to `g` are rolled back.
/* static */ Status GraphConstructor::Construct(
    const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
    const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
    std::vector<std::pair<Node*, int>>* return_tensors,
    std::vector<Node*>* return_nodes,
    std::vector<SafeTensorId>* missing_unused_input_map_keys) {
  if (versions) {
    TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
                                     TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
                                     "GraphDef", "graph"));
  }
  NodeDefCopyingGraphConstructor c(opts, node_defs, versions, library, g,
                                   refiner, return_tensors, return_nodes,
                                   missing_unused_input_map_keys);
  Status s = c.TryImport();
  if (!s.ok()) {
    // Undo any nodes already added before surfacing the (possibly
    // version-annotated) error.
    c.Undo();
    s = MaybeAppendVersionWarning(versions, s);
  }
  return s;
}
505 
// Imports `graph_def` into `g`, taking ownership of it so NodeDefs can be
// moved rather than copied. On failure, partially-applied changes to `g` are
// rolled back.
/* static */ Status GraphConstructor::Construct(
    const Options& opts, GraphDef&& graph_def, Graph* g, ShapeRefiner* refiner,
    std::vector<std::pair<Node*, int>>* return_tensors,
    std::vector<Node*>* return_nodes,
    std::vector<SafeTensorId>* missing_unused_input_map_keys) {
  TF_RETURN_IF_ERROR(CheckVersions(graph_def.versions(), TF_GRAPH_DEF_VERSION,
                                   TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
                                   "GraphDef", "graph"));
  // Copy the version info now: `graph_def` is moved into the constructor
  // below and must not be read afterwards.
  VersionDef version_def = graph_def.versions();
  NodeDefMovingGraphConstructor c(opts, std::move(graph_def), g, refiner,
                                  return_tensors, return_nodes,
                                  missing_unused_input_map_keys);
  Status s = c.TryImport();
  if (!s.ok()) {
    c.Undo();
    s = MaybeAppendVersionWarning(&version_def, s);
  }
  return s;
}
525 
UpdatePendingCountAndReady(int processed,bool is_next_iteration)526 void GraphConstructor::UpdatePendingCountAndReady(int processed,
527                                                   bool is_next_iteration) {
528   for (size_t i = 0; i < outputs_[processed].size(); ++i) {
529     const int output = outputs_[processed][i];
530     // We didn't consider NextIteration->Merge edges when computing
531     // pending_counts_ so we should not have to consider it here either.
532     bool is_next_iteration_to_merge_edge =
533         is_next_iteration && merge_node_indices_.count(output) == 1;
534     if (!is_next_iteration_to_merge_edge) {
535       int* current_pending_count = &pending_count_[output];
536       CHECK_GT(*current_pending_count, 0);
537       (*current_pending_count)--;
538       if (*current_pending_count == 0) {
539         ready_.insert(output);
540       }
541     }
542   }
543 }
544 
545 // This could be expensive but we don't expect to call it often, if at all (only
546 // if there are multiple nodes in g_ with the same name)
NodeNameInValues(const std::map<TensorId,TensorId> & input_map,const StringPiece & node_name)547 bool NodeNameInValues(const std::map<TensorId, TensorId>& input_map,
548                       const StringPiece& node_name) {
549   for (auto iter = input_map.begin(); iter != input_map.end(); ++iter) {
550     if (iter->second.first == node_name) return true;
551   }
552   return false;
553 }
554 
NodeNameInValues(const std::vector<string> & control_dependencies,const StringPiece & node_name)555 bool NodeNameInValues(const std::vector<string>& control_dependencies,
556                       const StringPiece& node_name) {
557   return std::find(control_dependencies.begin(), control_dependencies.end(),
558                    node_name) != control_dependencies.end();
559 }
560 
561 // Adds any prefixes of `node_name` (not including the full name itself) to
562 // `prefixes`.
AddPrefixes(StringPiece node_name,absl::flat_hash_set<StringPiece> * prefixes)563 void AddPrefixes(StringPiece node_name,
564                  absl::flat_hash_set<StringPiece>* prefixes) {
565   size_t idx = -1;
566   while ((idx = node_name.find('/', idx + 1)) != StringPiece::npos) {
567     prefixes->insert(node_name.substr(0, idx));
568   }
569 }
570 
// Populates existing_nodes_ and existing_prefixes_ from the destination graph
// and verifies the import will not produce ambiguous or colliding node names.
Status GraphConstructor::EnsureNoNameCollisions() {
  existing_nodes_.reserve(g_->num_nodes());
  // Populate existing_nodes_ and existing_prefixes_.
  for (Node* n : g_->nodes()) {
    bool already_exists = !existing_nodes_.insert({n->name(), n}).second;
    if (already_exists) {
      // Duplicate names in g_ are tolerated unless an input_map or
      // control_dependencies entry refers to the ambiguous name.
      if (NodeNameInValues(opts_.input_map, n->name())) {
        return errors::InvalidArgument(
            "cannot resolve input_map because multiple nodes exist with name '",
            n->name(), "'");
      }
      if (NodeNameInValues(opts_.control_dependencies, n->name())) {
        return errors::InvalidArgument(
            "cannot resolve control_dependencies because multiple nodes exist "
            "with name '",
            n->name(), "'");
      }
    }
    AddPrefixes(n->name(), &existing_prefixes_);
  }
  if (prefix_.empty() && opts_.importing && !opts_.uniquify_names) {
    // No prefix and no uniquification: every imported node name must be new.
    for (size_t i = 0; i < node_def_count(); ++i) {
      const string& name = get_node_def(i).name();
      if (NameExistsInGraph(name)) {
        return errors::InvalidArgument("Node name '", name,
                                       "' already exists in the Graph");
      }
    }
  } else if (!prefix_.empty()) {
    StringPiece prefix_no_slash(prefix_);
    prefix_no_slash.remove_suffix(1);  // prefix_ always ends with '/'.
    if (!IsValidNodeName(prefix_no_slash, false)) {
      return errors::InvalidArgument("Imported node name prefix '", prefix_,
                                     "' would lead to invalid node names");
    }
    if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) {
      prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
    }
  }
  return Status::OK();
}
612 
// Verifies that every node referenced by opts_.input_map and
// opts_.control_dependencies exists in the destination graph, and that
// input_map never maps between a control edge and a data edge.
Status GraphConstructor::ValidateInputMapAndControlDependencies() {
  for (const auto& mapping : opts_.input_map) {
    TensorId src = mapping.first;
    TensorId dst = mapping.second;
    if (existing_nodes_.count(dst.first) == 0) {
      return errors::InvalidArgument(
          "node '", dst.first, "' in input_map does not exist in graph ",
          "(input_map entry: ", src.ToString(), "->", dst.ToString(), ")");
    }
    // A control slot may only be remapped to another control slot.
    if ((src.second == Graph::kControlSlot) !=
        (dst.second == Graph::kControlSlot)) {
      return errors::InvalidArgument("input_map entry ", src.ToString(), "->",
                                     dst.ToString(), " between ",
                                     "control edge and non-control edge");
    }
  }
  for (const string& node : opts_.control_dependencies) {
    if (existing_nodes_.count(node) == 0) {
      return errors::InvalidArgument(
          "node '", node,
          "' in control_dependencies does not exist in "
          "graph");
    }
  }
  return Status::OK();
}
639 
// Validates every NodeDef (name syntax, uniqueness, non-empty op, device spec
// when required, control-input ordering) and builds gdef_nodes_,
// gdef_prefixes_, and merge_node_indices_.
Status GraphConstructor::BuildNodeIndex() {
  // Validate the node names and add them to gdef_nodes_ and gdef_prefixes_.
  for (int n = 0; n < node_def_count(); ++n) {
    const NodeDef& node_def = get_node_def(n);
    if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "': Node name contains invalid characters");
    }
    if (!gdef_nodes_.insert(std::make_pair(node_def.name(), NodeInfo(n)))
             .second) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' is not unique");
    }
    // Validate the operation's type.
    if (node_def.op().empty()) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' does not specify an operation");
    }
    if (opts_.expect_device_spec && node_def.device().empty()) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' is missing a device specification");
    }
    if (IsMerge(node_def)) {
      // Merge nodes are the potential destinations of while-loop back edges.
      merge_node_indices_.insert(n);
    }
    // Validate control edges at end
    bool in_control_dependence = false;
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      // Control inputs are marked with a '^' prefix; once one is seen, no
      // regular (data) input may follow.
      if (!input_name.empty() && absl::StartsWith(input_name, "^")) {
        in_control_dependence = true;
      } else if (in_control_dependence) {
        return errors::InvalidArgument(
            "Node '", node_def.name(),
            "': Control dependencies must come after regular dependencies");
      }
    }
    // Update gdef_prefixes_.
    AddPrefixes(node_def.name(), &gdef_prefixes_);
  }
  return Status::OK();
}
683 
Status GraphConstructor::InitFromEdges() {
  // Builds the dependency bookkeeping that Convert() consumes:
  //   - pending_count_[n]: number of inputs of node n that must be converted
  //     before n itself is ready.
  //   - outputs_[n]: indices of the nodes consuming an output of node n.
  //   - ready_: seeded here with nodes that have no pending inputs.
  const int num_nodes = node_def_count();
  pending_count_.reserve(num_nodes);
  outputs_.resize(num_nodes);
  // First pass: collect the names of all NextIteration nodes so that loop
  // back edges (NextIteration -> Merge) can be recognized below.
  gtl::FlatSet<string> next_iteration_nodes;
  for (int n = 0; n < node_def_count(); ++n) {
    const NodeDef& node_def = get_node_def(n);
    if (IsNextIteration(node_def)) {
      next_iteration_nodes.insert(node_def.name());
    }
  }

  // Parse the inputs for each node.
  for (int n = 0; n < num_nodes; ++n) {
    const NodeDef& node_def = get_node_def(n);
    int pending_count = node_def.input_size();
    if (IsMerge(node_def)) {
      // Cycles in the graph are only allowed for while loops. A while loop is
      // identified by an edge from a NextIteration node to a Merge node. For
      // such Merge nodes, only wait for one non-control input before
      // considering the node ready to process in Convert().
      int32_t num_control_edges = 0;
      bool has_loop_back_edge = false;
      for (int i = 0; i < node_def.input_size(); ++i) {
        StringPiece input_name(node_def.input(i));
        if (absl::StartsWith(input_name, "^")) {
          num_control_edges++;
        } else {
          TensorId id(ParseTensorName(input_name));
          if (next_iteration_nodes.find(string(id.first)) !=
              next_iteration_nodes.end()) {
            has_loop_back_edge = true;
          }
        }
      }
      if (has_loop_back_edge) {
        // All control inputs plus exactly one data input must arrive before
        // this Merge is considered ready.
        pending_count = num_control_edges + 1;
      }
    }
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      TensorId id(ParseTensorName(input_name));
      if (opts_.input_map.count(id) == 0) {
        // If an input is not mapped, then the input should appear in the graph
        // being imported.
        auto iter = gdef_nodes_.find(id.first);
        if (iter == gdef_nodes_.end()) {
          return errors::InvalidArgument("Node '", node_def.name(),
                                         "': Unknown input node '",
                                         node_def.input(i), "'");
        }
        outputs_[iter->second.gdef_index].push_back(n);
      } else {
        // This input is mapped to an existing edge. Therefore this input is
        // as good as being already processed.
        --pending_count;
        DCHECK_GE(pending_count, 0);
      }
    }
    if (pending_count == 0) {
      ready_.insert(n);
    }
    pending_count_.push_back(pending_count);
  }
  return Status::OK();
}
750 
ValidateColocationConstraints(const NodeDef & node_def)751 Status GraphConstructor::ValidateColocationConstraints(
752     const NodeDef& node_def) {
753   if (!opts_.validate_colocation_constraints || !opts_.importing)
754     return Status::OK();
755   const auto iter = node_def.attr().find(kColocationAttrName);
756   if (iter == node_def.attr().end()) return Status::OK();
757   for (const string& c : iter->second.list().s()) {
758     StringPiece s(c);
759     if (absl::ConsumePrefix(&s, kColocationGroupPrefix) &&
760         gdef_nodes_.find(s) == gdef_nodes_.end()) {
761       return errors::InvalidArgument(
762           "Node '", node_def.name(),
763           "' expects to be colocated with unknown node '", s, "'");
764     }
765   }
766   return Status::OK();
767 }
768 
MakeNode(NodeDef && node_def,Node ** node)769 Status GraphConstructor::MakeNode(NodeDef&& node_def, Node** node) {
770   // Add the node to the graph.
771   Status status;
772   *node = g_->AddNode(std::move(node_def), &status);
773   if (!status.ok()) return status;
774   if (opts_.expect_device_spec) {
775     (*node)->set_assigned_device_name((*node)->def().device());
776   }
777   return Status::OK();
778 }
779 
ValidateShape(Node * node)780 Status GraphConstructor::ValidateShape(Node* node) {
781   if (!opts_.importing || !opts_.validate_shape) return Status::OK();
782   TF_RETURN_IF_ERROR(refiner_->AddNode(node));
783   // For nodes with the _output_shapes attribute, override the shape.
784   std::vector<const TensorShapeProto*> shape_attrs;
785   const char* kAttrName = "_output_shapes";
786   if (!TryGetNodeAttr(node->attrs(), kAttrName, &shape_attrs)) {
787     // No _output_shapes attribute, the AddNode call above was sufficient.
788     return Status::OK();
789   }
790   auto* ic = refiner_->GetContext(node);
791   DCHECK(ic != nullptr)
792       << "ShapeRefiner::AddNode() should have created the InferenceContext";
793   if (shape_attrs.size() < node->num_outputs()) {
794     return errors::InvalidArgument(
795         "Node '", node->name(), "' has ", node->num_outputs(),
796         " outputs but the ", kAttrName, " attribute specifies shapes for ",
797         shape_attrs.size(), " outputs");
798   }
799   // NOTE(skyewm): we don't raise an error here because some users depend on
800   // this behavior, even though it's unsafe.
801   // TODO(b/74619486): raise an error.
802   if (shape_attrs.size() > node->num_outputs()) {
803     LOG(WARNING) << "Node '" << node->name() << "' has " << node->num_outputs()
804                  << " outputs but the " << kAttrName
805                  << " attribute specifies shapes for " << shape_attrs.size()
806                  << " outputs. Output shapes may be inaccurate.";
807   }
808   for (int i = 0; i < node->num_outputs(); ++i) {
809     const TensorShapeProto& p = *shape_attrs[i];
810     shape_inference::ShapeHandle h;
811     Status s = ic->MakeShapeFromShapeProto(p, &h);
812     if (!s.ok()) {
813       return errors::InvalidArgument("Node '", node->name(), " has an invalid ",
814                                      kAttrName, " attribute (shape #", i,
815                                      " error:'", s.error_message(), "'");
816     }
817     s = refiner_->SetShape(node, i, h);
818     if (!s.ok()) {
819       return errors::InvalidArgument(
820           "Node '", node->name(), "' has an ", kAttrName,
821           " attribute inconsistent with the GraphDef for output #", i, ": ",
822           s.error_message());
823     }
824   }
825   node->ClearAttr(kAttrName);
826   return Status::OK();
827 }
828 
ModifyNodeDefForImport(NodeDef * node_def)829 Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) {
830   const OpDef* op_def;
831   TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
832   AddDefaultsToNodeDef(*op_def, node_def);
833   TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def));
834   if (versions()) {
835     TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions()->producer()));
836   }
837   return Status::OK();
838 }
839 
RemoveInputs(const std::vector<int> & inputs_to_remove,NodeDef * node_def,std::vector<bool> * input_already_exists)840 void RemoveInputs(const std::vector<int>& inputs_to_remove, NodeDef* node_def,
841                   std::vector<bool>* input_already_exists) {
842   // Remove 'inputs_to_remove' from 'node_def'
843   NodeDef copy;
844   copy.mutable_input()->Reserve(node_def->input_size() -
845                                 inputs_to_remove.size());
846   for (int i = 0, j = 0; i < node_def->input_size(); ++i) {
847     if (j < inputs_to_remove.size() && i == inputs_to_remove[j]) {
848       ++j;
849     } else {
850       copy.add_input()->swap(*node_def->mutable_input(i));
851     }
852   }
853   node_def->mutable_input()->Swap(copy.mutable_input());
854   // Remove 'inputs_to_remove' from 'input_already_exists'
855   for (int idx : inputs_to_remove) {
856     input_already_exists->erase(input_already_exists->begin() + idx);
857   }
858   DCHECK_EQ(input_already_exists->size(), node_def->input_size());
859 }
860 
RemapNodeDefInputs(NodeDef * node_def,std::vector<bool> * input_already_exists)861 void GraphConstructor::RemapNodeDefInputs(
862     NodeDef* node_def, std::vector<bool>* input_already_exists) {
863   DCHECK_EQ(input_already_exists->size(), node_def->input_size());
864   std::set<TensorId> control_inputs;
865   std::vector<int> inputs_to_remove;
866 
867   for (int i = 0; i < node_def->input_size(); ++i) {
868     auto iter = opts_.input_map.find(ParseTensorName(node_def->input(i)));
869     if (iter == opts_.input_map.end()) continue;
870     used_input_map_keys_.insert(iter->first);
871 
872     TensorId new_input = iter->second;
873     if (new_input.second == Graph::kControlSlot) {
874       // Check if we've already remapped a different input to new_input, and if
875       // so remove this input.
876       if (control_inputs.count(new_input) > 0) {
877         inputs_to_remove.push_back(i);
878         continue;
879       }
880       control_inputs.insert(new_input);
881     }
882     node_def->set_input(i, new_input.ToString());
883     (*input_already_exists)[i] = true;
884   }
885   if (!inputs_to_remove.empty()) {
886     RemoveInputs(inputs_to_remove, node_def, input_already_exists);
887   }
888 }
889 
AddControlDependencies(NodeDef * node_def,std::vector<bool> * input_already_exists)890 void GraphConstructor::AddControlDependencies(
891     NodeDef* node_def, std::vector<bool>* input_already_exists) {
892   // To avoid adding redundant control dependencies to every imported node, skip
893   // nodes that will inherit the dependencies from another imported node.
894   bool inherits_deps = false;
895   for (int i = 0; i < node_def->input_size(); ++i) {
896     // Assume we won't inherit dependencies from remapped inputs that already
897     // exist in the graph. Even if we're wrong, we'll only add redundant
898     // dependencies.
899     if ((*input_already_exists)[i]) continue;
900 
901     // If this input is a backedge, assume we won't inherit the dependencies.
902     // TODO(skyewm): we have many redundant ParseTensorName calls. It could be
903     // worth optimizing these.
904     TensorId id(ParseTensorName(node_def->input(i)));
905     auto iter = gdef_nodes_.find(id.first);
906     DCHECK(iter != gdef_nodes_.end()) << id.first;
907     if (iter->second.node == nullptr) {
908       // Input hasn't been created yet, indicating it's a backedge.
909       continue;
910     }
911     inherits_deps = true;
912   }
913   if (inherits_deps) return;
914 
915   // node_def either has no inputs or all remapped inputs, add the control
916   // dependencies
917   for (const string& control_dep : opts_.control_dependencies) {
918     string input = TensorId(control_dep, Graph::kControlSlot).ToString();
919     bool found = false;
920     for (int i = node_def->input_size() - 1; i >= 0; --i) {
921       const string& node_input = node_def->input(i);
922       if (node_input[0] != '^') {
923         // Control inputs are at the end. Break when we reach the non-control
924         // inputs.
925         break;
926       }
927       if (node_input == input) {
928         // Control dependency already exists
929         found = true;
930         break;
931       }
932     }
933     if (found) {
934       continue;
935     }
936     node_def->add_input(input);
937     input_already_exists->push_back(true);
938   }
939 }
940 
AddPrefixToNodeDef(const std::vector<bool> & input_already_exists,NodeDef * node_def)941 void GraphConstructor::AddPrefixToNodeDef(
942     const std::vector<bool>& input_already_exists, NodeDef* node_def) {
943   if (prefix_.empty()) return;
944   node_def->set_name(strings::StrCat(prefix_, node_def->name()));
945   // Update names of input nodes
946   for (int i = 0; i < node_def->input_size(); ++i) {
947     // Skip remapped inputs (which already exist in g_ and are not being
948     // imported).
949     if (input_already_exists[i]) continue;
950     StringPiece input(node_def->input(i));
951     if (absl::ConsumePrefix(&input, "^")) {
952       node_def->set_input(i, strings::StrCat("^", prefix_, input));
953     } else {
954       node_def->set_input(i, strings::StrCat(prefix_, input));
955     }
956   }
957   // Update names of colocation groups
958   if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) {
959     auto* list =
960         node_def->mutable_attr()->at(kColocationAttrName).mutable_list();
961     for (int i = 0; i < list->s_size(); ++i) {
962       StringPiece v(list->s(i));
963       if (absl::ConsumePrefix(&v, kColocationGroupPrefix)) {
964         list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v));
965       }
966     }
967   }
968 }
969 
UniquifyNames(const std::vector<bool> & input_already_exists,NodeDef * node_def)970 void GraphConstructor::UniquifyNames(
971     const std::vector<bool>& input_already_exists, NodeDef* node_def) {
972   if (NameExistsInGraph(node_def->name())) {
973     string old_name = node_def->name();
974     node_def->set_name(FindUniqueName(node_def->name()));
975     uniquified_names_[old_name] = node_def->name();
976     // Note that we don't have to update gdef_nodes_ or gdef_prefixes_ with
977     // `name` because we guarantee the original NodeDef names are unique,
978     // meaning we won't generate this name again.
979   }
980   for (int i = 0; i < node_def->input_size(); ++i) {
981     // Skip remapped inputs (which already exist in g_ and are not being
982     // imported).
983     if (input_already_exists[i]) continue;
984     TensorId id = ParseTensorName(node_def->input(i));
985     // We require that UniquifyNames() is called on all NodeDefs in topological
986     // order. This guarantees that node_def's inputs will already be uniquified
987     // if necessary.
988     auto iter = uniquified_names_.find(string(id.first));
989     if (iter == uniquified_names_.end()) continue;
990     id.first = iter->second;
991     node_def->set_input(i, id.ToString());
992   }
993 }
994 
UpdateUniquifiedColocationNames()995 void GraphConstructor::UpdateUniquifiedColocationNames() {
996   for (const auto& pair : gdef_nodes_) {
997     Node* node = pair.second.node;
998     if (node == nullptr) continue;
999     std::vector<string> coloc_values;
1000     if (!TryGetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values))
1001       continue;
1002     bool updated = false;
1003     for (size_t i = 0; i < coloc_values.size(); ++i) {
1004       StringPiece val(coloc_values[i]);
1005       if (absl::ConsumePrefix(&val, kColocationGroupPrefix)) {
1006         auto name_pair = uniquified_names_.find(string(val));
1007         if (name_pair == uniquified_names_.end()) continue;
1008         updated = true;
1009         coloc_values[i] =
1010             strings::StrCat(kColocationGroupPrefix, name_pair->second);
1011       }
1012     }
1013     if (updated) {
1014       node->AddAttr(kColocationAttrName, std::move(coloc_values));
1015     }
1016   }
1017 }
1018 
NameExistsInGraph(StringPiece name)1019 bool GraphConstructor::NameExistsInGraph(StringPiece name) {
1020   if (existing_nodes_.find(name) != existing_nodes_.end()) return true;
1021   if (existing_prefixes_.find(name) != existing_prefixes_.end()) return true;
1022   return false;
1023 }
1024 
NameExistsInGraphDef(StringPiece name)1025 bool GraphConstructor::NameExistsInGraphDef(StringPiece name) {
1026   if (gdef_nodes_.find(name) != gdef_nodes_.end()) return true;
1027   if (gdef_prefixes_.find(name) != gdef_prefixes_.end()) return true;
1028   return false;
1029 }
1030 
FindUniqueName(StringPiece original_name)1031 string GraphConstructor::FindUniqueName(StringPiece original_name) {
1032   string name(original_name);
1033   int count = 0;
1034   // Check that any generated names don't collide with imported NodeDefs (as
1035   // well as nodes in g_).
1036   while (NameExistsInGraph(name) || (count > 0 && NameExistsInGraphDef(name))) {
1037     name = strings::StrCat(original_name, "_", ++count);
1038   }
1039   return name;
1040 }
1041 
IsNodeFullyMapped(const NodeDef & node_def,bool * is_node_mapped)1042 Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def,
1043                                            bool* is_node_mapped) {
1044   const OpDef* op_def;
1045   TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
1046   for (int i = 0; i < op_def->output_arg_size(); ++i) {
1047     if (opts_.input_map.find({node_def.name(), i}) == opts_.input_map.end()) {
1048       *is_node_mapped = false;
1049       return Status::OK();
1050     }
1051   }
1052   *is_node_mapped = true;
1053   return Status::OK();
1054 }
1055 
DFS(int cur_node,std::vector<int> * cur_branch,std::vector<bool> * is_on_cur_branch,absl::flat_hash_set<int> * unvisited)1056 void GraphConstructor::DFS(int cur_node, std::vector<int>* cur_branch,
1057                            std::vector<bool>* is_on_cur_branch,
1058                            absl::flat_hash_set<int>* unvisited) {
1059   cur_branch->push_back(cur_node);
1060   is_on_cur_branch->at(cur_node) = true;
1061   for (auto next_node : outputs_[cur_node]) {
1062     if (unvisited->find(next_node) != unvisited->end()) {
1063       if (is_on_cur_branch->at(next_node)) {
1064         auto iter =
1065             std::find(cur_branch->begin(), cur_branch->end(), next_node);
1066         LOG(WARNING) << "Cycle detected:";
1067         while (iter != cur_branch->end()) {
1068           LOG(WARNING) << SummarizeNodeDef(get_node_def(*iter));
1069           ++iter;
1070         }
1071         LOG(WARNING) << "End of cycle";
1072       } else {
1073         DFS(next_node, cur_branch, is_on_cur_branch, unvisited);
1074       }
1075     }
1076   }
1077   cur_branch->pop_back();
1078   is_on_cur_branch->at(cur_node) = false;
1079   unvisited->erase(cur_node);
1080 }
1081 
PrintCycles()1082 void GraphConstructor::PrintCycles() {
1083   int num_nodes = outputs_.size();
1084   absl::flat_hash_set<int> unvisited;
1085   for (int i = 0; i < num_nodes; i++) {
1086     unvisited.insert(i);
1087   }
1088   while (!unvisited.empty()) {
1089     int cur_node = *unvisited.begin();
1090     // Nodes on the current branch of DFS in traversal order. This is used for
1091     // printing the nodes in the cycle.
1092     std::vector<int> cur_branch;
1093     // This is just to make lookups O(1).
1094     // is_on_cur_branch[i] ==
1095     //   (std::find(cur_branch.start(),
1096     //              cur_branch.end(), i) != cur_branch.end())
1097     std::vector<bool> is_on_cur_branch(num_nodes, false);
1098     DFS(cur_node, &cur_branch, &is_on_cur_branch, &unvisited);
1099   }
1100 }
1101 
// Converts the NodeDefs into Nodes in g_, processing them in topological
// order (driven by ready_/pending_count_/outputs_ built in InitFromEdges()).
// Handles input remapping, added control dependencies, prefixing and name
// uniquification when importing, defers loop back edges for AddBackEdges(),
// and fails if any nodes remain unprocessed (i.e. a disallowed cycle exists).
Status GraphConstructor::Convert() {
  // Import functions before adding nodes, since imported nodes may refer to
  // functions
  if (library()) {
    // TODO(b/135705010): Add rvalue overloads into the function library, to
    // avoid unnecessarily copying `*library()` here.
    TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library()));
  }

  std::vector<InputInfo> inputs;
  int processed = 0;

  std::vector<bool> input_already_exists;

  // Process the NodeDefs in topological order.
  // (InitFromEdges() sets this up by filling in ready_ with nodes that have no
  // inputs, pending_counts_ with the number of inputs for each node and
  // outputs_ with the outputs of each node).
  while (!ready_.empty()) {
    int o = *ready_.begin();
    ready_.erase(ready_.begin());
    ++processed;
    inputs.clear();
    bool has_data_back_edge = false;

    NodeDef node_def = consume_node_def(o);

    // input_already_exists[i] is true iff the i-th input of the node we're
    // importing refers to a preexisting node in g_ (i.e. input[i] existed prior
    // to importing node_defs_).  Conversely, input_already_exists[i] is false
    // iff the input refers to a node in node_defs_.
    input_already_exists.clear();
    input_already_exists.resize(node_def.input_size(), false);

    // Save the original name: gdef_nodes_ is keyed by it, while node_def's
    // name may be prefixed/uniquified below.
    std::string node_name = node_def.name();

    if (opts_.importing) {
      if (opts_.skip_mapped_nodes) {
        bool is_node_mapped = false;
        TF_RETURN_IF_ERROR(IsNodeFullyMapped(node_def, &is_node_mapped));
        if (is_node_mapped) {
          // Skip this node after updating pending_count_ for outputs
          UpdatePendingCountAndReady(o, IsNextIteration(node_def));
          continue;
        }
      }

      if (!opts_.input_map.empty()) {
        // Note that input_already_exists can shrink here
        RemapNodeDefInputs(&node_def, &input_already_exists);
      }
      if (!opts_.control_dependencies.empty()) {
        // Note that input_already_exists can grow here
        AddControlDependencies(&node_def, &input_already_exists);
      }
      if (!opts_.default_device.empty() && node_def.device().empty()) {
        node_def.set_device(opts_.default_device);
      }
    }

    DCHECK_EQ(node_def.input_size(), input_already_exists.size());
    TF_RETURN_IF_ERROR(ValidateColocationConstraints(node_def));
    // Resolve each input to its source Node (or record a back edge).
    for (int i = 0; i < node_def.input_size(); ++i) {
      TensorId tensor_id = ParseTensorName(node_def.input(i));
      Node* src_node;
      int src_index;

      if (!input_already_exists[i]) {
        // Locate input in newly-imported nodes
        auto iter = gdef_nodes_.find(tensor_id.node());
        DCHECK(iter != gdef_nodes_.end()) << tensor_id.node();
        src_node = iter->second.node;
        src_index = tensor_id.index();
        // A null node here means the source hasn't been converted yet, which
        // (given topological processing) indicates a loop back edge.
        if (src_node == nullptr) has_data_back_edge = true;
      } else {
        // Input refers to preexistng node in graph
        auto iter = existing_nodes_.find(tensor_id.node());
        DCHECK(iter != existing_nodes_.end()) << tensor_id.node();
        src_node = iter->second;
        src_index = tensor_id.index();
      }

      if (src_node != nullptr && src_index >= src_node->num_outputs()) {
        std::ostringstream out;
        out << "Node '" << node_def.name() << "': Connecting to invalid output "
            << tensor_id.index() << " of source node " << tensor_id.node()
            << " which has " << src_node->num_outputs() << " outputs.";

        if (src_node->type_string() == "If" ||
            src_node->type_string() == "StatelessIf" ||
            src_node->type_string() == "While" ||
            src_node->type_string() == "StatelessWhile") {
          out << " Try using "
              << "tf.compat.v1.experimental.output_all_intermediates(True).";
        }
        return errors::InvalidArgument(out.str());
      }

      inputs.emplace_back(string(tensor_id.node()), src_node, src_index);
    }

    if (has_data_back_edge && !IsMerge(node_def)) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "' had a back edge, but only Merge nodes can have back edges.");
    }

    Node* node;
    if (opts_.importing) {
      if (!prefix_.empty()) {
        AddPrefixToNodeDef(input_already_exists, &node_def);
      }
      // Note: no need to uniquify names if the prefix already guarantees
      // uniqueness
      if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) {
        UniquifyNames(input_already_exists, &node_def);
      }
    }

    if (opts_.importing) {
      TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&node_def));
    } else {
      const OpDef* op_def;
      TF_RETURN_IF_ERROR(
          g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
      if (opts_.add_default_attributes) {
        AddDefaultsToNodeDef(*op_def, &node_def);
      }
      if (opts_.validate_nodes) {
        TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def));
      }
    }

    TF_RETURN_IF_ERROR(MakeNode(std::move(node_def), &node));

    gdef_nodes_[node_name].node = node;

    // Remove duplicate control inputs before adding edges to the graph. It
    // will allow us to skip expensive duplicates check in 'AddControlEdge'.
    auto first_control = absl::c_find_if(inputs, &InputInfo::IsControlInput);
    auto first_control_copy = first_control;
    std::sort(first_control, inputs.end(), &InputInfo::CompareName);
    inputs.erase(
        std::unique(first_control_copy, inputs.end(), &InputInfo::IsSameName),
        inputs.end());

    // Add edges from inputs to *node to the graph.
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (inputs[i].node == nullptr) {
        // Record this back edge, which will be added after all nodes
        // are created.
        back_edges_.emplace_back(inputs[i].name, inputs[i].index, node, i);
      } else if (inputs[i].index == Graph::kControlSlot) {
        g_->AddControlEdge(inputs[i].node, node, kDoNotCheckDuplicates);
      } else {
        TF_RETURN_IF_ERROR(MakeEdge(inputs[i].node, inputs[i].index, node, i));
      }
    }

    TF_RETURN_IF_ERROR(ValidateShape(node));

    // Update pending_count_ for outputs.
    UpdatePendingCountAndReady(o, node->IsNextIteration());
  }

  // Any unprocessed nodes must be stuck on a cycle; report them.
  if (processed < node_def_count()) {
    LOG(WARNING) << "IN " << __func__ << " " << (node_def_count() - processed)
                 << " NODES IN A CYCLE";
    for (int64_t i = 0; i < node_def_count(); i++) {
      if (pending_count_[i] != 0) {
        LOG(WARNING) << "PENDING: " << SummarizeNodeDef(get_node_def(i))
                     << " WITH PENDING COUNT = " << pending_count_[i];
      }
    }
    PrintCycles();
    return errors::InvalidArgument(node_def_count() - processed,
                                   " nodes in a cycle");
  }

  return Status::OK();
}
1283 
AddBackEdges()1284 Status GraphConstructor::AddBackEdges() {
1285   // Add the back edges after all nodes are created.
1286   for (const auto& e : back_edges_) {
1287     Node* src_node = gdef_nodes_[e.src_name].node;
1288     if (e.src_index == Graph::kControlSlot) {
1289       g_->AddControlEdge(src_node, e.dst_node, kDoNotCheckDuplicates);
1290     } else {
1291       TF_RETURN_IF_ERROR(
1292           MakeEdge(src_node, e.src_index, e.dst_node, e.dst_index));
1293     }
1294 
1295     VLOG(2) << "Add back edge: " << src_node->name() << " -> "
1296             << e.dst_node->name();
1297   }
1298   return Status::OK();
1299 }
1300 
UpdateVersionDef()1301 Status GraphConstructor::UpdateVersionDef() {
1302   if (versions() == nullptr) return Status::OK();
1303 
1304   if (!opts_.importing) {
1305     g_->set_versions(*versions());
1306     return Status::OK();
1307   }
1308   VersionDef g_versions = g_->versions();
1309   g_versions.set_producer(
1310       std::min(g_versions.producer(), versions()->producer()));
1311   g_versions.set_min_consumer(
1312       std::max(g_versions.min_consumer(), versions()->min_consumer()));
1313   if (versions()->bad_consumers_size() > 0) {
1314     std::set<int> bad(g_versions.bad_consumers().begin(),
1315                       g_versions.bad_consumers().end());
1316     bad.insert(versions()->bad_consumers().begin(),
1317                versions()->bad_consumers().end());
1318     g_versions.clear_bad_consumers();
1319     for (int v : bad) {
1320       g_versions.add_bad_consumers(v);
1321     }
1322   }
1323   g_->set_versions(g_versions);
1324   return Status::OK();
1325 }
1326 
PopulateReturnTensors()1327 Status GraphConstructor::PopulateReturnTensors() {
1328   if (opts_.return_tensors.empty()) return Status::OK();
1329   for (const TensorId& id : opts_.return_tensors) {
1330     auto iter = opts_.input_map.find(id);
1331     if (iter == opts_.input_map.end()) {
1332       // Locate id in imported nodes
1333       auto iter = gdef_nodes_.find(id.first);
1334       if (iter == gdef_nodes_.end()) {
1335         return errors::InvalidArgument("Requested return tensor '",
1336                                        id.ToString(),
1337                                        "' not found in graph def");
1338       }
1339       int num_outputs = iter->second.node->num_outputs();
1340       if ((id.second < 0 || id.second >= num_outputs) &&
1341           id.second != Graph::kControlSlot) {
1342         return errors::InvalidArgument("Invalid return output ", id.second,
1343                                        " of node '", id.first, "', which has ",
1344                                        num_outputs, " output(s)");
1345       }
1346       return_tensors_->push_back({iter->second.node, id.second});
1347     } else {
1348       // id was remapped to existing node
1349       TensorId remapped_id = iter->second;
1350       DCHECK_GT(existing_nodes_.count(remapped_id.first), 0);
1351       Node* node = existing_nodes_[remapped_id.first];
1352       return_tensors_->push_back({node, remapped_id.second});
1353     }
1354   }
1355   return Status::OK();
1356 }
1357 
PopulateReturnNodes()1358 Status GraphConstructor::PopulateReturnNodes() {
1359   if (opts_.return_nodes.empty()) return Status::OK();
1360   for (StringPiece name : opts_.return_nodes) {
1361     auto iter = gdef_nodes_.find(name);
1362     if (iter == gdef_nodes_.end()) {
1363       return errors::InvalidArgument("Requested return node '", name,
1364                                      "' not found in graph def");
1365     }
1366     return_nodes_->push_back(iter->second.node);
1367   }
1368   return Status::OK();
1369 }
1370 
PopulateMissingUnusedInputMapKeys()1371 Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
1372   if (missing_unused_input_map_keys_ == nullptr) return Status::OK();
1373   for (const auto& input_map_pair : opts_.input_map) {
1374     TensorId key = input_map_pair.first;
1375     if (used_input_map_keys_.count(key) > 0) continue;
1376 
1377     auto pair = gdef_nodes_.find(key.first);
1378     if (pair == gdef_nodes_.end()) {
1379       // key's node doesn't exist in GraphDef
1380       missing_unused_input_map_keys_->push_back(key);
1381       continue;
1382     }
1383 
1384     // Check that key's index is in bounds. Get the number of outputs from the
1385     // NodeDef, rather than the imported Node, since the Node may not exist if
1386     // opts_.skip_mapped_nodes is true.
1387     const NodeDef& node_def = get_node_def(pair->second.gdef_index);
1388     const OpDef* op_def;
1389     TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
1390     int num_outputs;
1391     TF_RETURN_IF_ERROR(NumOutputsForNode(node_def, *op_def, &num_outputs));
1392     if (key.second >= num_outputs) {
1393       // key's index out of bounds
1394       missing_unused_input_map_keys_->push_back(key);
1395     }
1396   }
1397   return Status::OK();
1398 }
1399 
Undo()1400 void GraphConstructor::Undo() {
1401   for (const auto& iter : gdef_nodes_) {
1402     if (iter.second.node != nullptr) {
1403       g_->RemoveNode(iter.second.node);
1404     }
1405   }
1406   g_->set_versions(original_versions_);
1407 }
1408 
MakeEdge(Node * src,int output_index,Node * dst,int input_index)1409 Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
1410                                   int input_index) {
1411   if (output_index >= src->num_outputs()) {
1412     return errors::InvalidArgument(
1413         "Output ", output_index, " of node ", src->name(),
1414         " does not exist. Node only has ", src->num_outputs(), " outputs.");
1415   }
1416   if (input_index >= dst->num_inputs()) {
1417     return errors::InvalidArgument(
1418         "Input ", input_index, " of node ", dst->name(),
1419         " does not exist. Node only has ", dst->num_inputs(), " inputs.");
1420   }
1421 
1422   DataType src_out = src->output_type(output_index);
1423   DataType dst_in = dst->input_type(input_index);
1424   if (!TypesCompatible(dst_in, src_out)) {
1425     return errors::InvalidArgument(
1426         "Input ", input_index, " of node ", dst->name(), " was passed ",
1427         DataTypeString(src_out), " from ", src->name(), ":", output_index,
1428         " incompatible with expected ", DataTypeString(dst_in), ".");
1429   }
1430   g_->AddEdge(src, output_index, dst, input_index);
1431   return Status::OK();
1432 }
1433 
1434 }  // namespace
1435 
ConvertGraphDefToGraph(const GraphConstructorOptions & opts,const GraphDef & gdef,Graph * g)1436 Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
1437                               const GraphDef& gdef, Graph* g) {
1438   ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
1439   return GraphConstructor::Construct(
1440       opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
1441       /*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
1442       /*missing_unused_input_map_keys=*/nullptr);
1443 }
1444 
ConvertGraphDefToGraph(const GraphConstructorOptions & opts,GraphDef && gdef,Graph * g)1445 Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
1446                               GraphDef&& gdef, Graph* g) {
1447   ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
1448   return GraphConstructor::Construct(opts, std::move(gdef), g, &refiner,
1449                                      /*return_tensors=*/nullptr,
1450                                      /*return_nodes=*/nullptr,
1451                                      /*missing_unused_input_map_keys=*/nullptr);
1452 }
1453 
ConvertNodeDefsToGraph(const GraphConstructorOptions & opts,gtl::ArraySlice<NodeDef> nodes,Graph * g)1454 Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
1455                               gtl::ArraySlice<NodeDef> nodes, Graph* g) {
1456   ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry());
1457   // TODO(irving): Copy will go away once NodeInfo exists
1458   std::vector<const NodeDef*> node_defs;
1459   node_defs.reserve(nodes.size());
1460   for (const auto& n : nodes) {
1461     node_defs.push_back(&n);
1462   }
1463   return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
1464                                      &refiner, /*return_tensors=*/nullptr,
1465                                      /*return_nodes=*/nullptr,
1466                                      /*missing_unused_input_map_keys=*/nullptr);
1467 }
1468 
// Imports `gdef` into the (possibly non-empty) graph `g`.
//
// Validates the opts/results combination up front, then selects a shape
// refiner: the caller's `refiner` if provided, otherwise a local one created
// at gdef's producer version. The chosen refiner's graph_def_version is
// clamped to the minimum of its current value and gdef's producer version
// before construction begins.
Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
                      Graph* g, ShapeRefiner* refiner,
                      ImportGraphDefResults* results) {
  // return_tensors can only be delivered through `results`.
  if (!opts.return_tensors.empty()) {
    if (results == nullptr) {
      return errors::InvalidArgument(
          "results argument to ImportGraphDef() must be non-null if "
          "opts.return_tensors is non-empty");
    }
  }

  if (!opts.return_nodes.empty()) {
    // skip_mapped_nodes may elide the very nodes being requested.
    if (opts.skip_mapped_nodes) {
      return errors::InvalidArgument(
          "Requesting return_nodes with skip_mapped_nodes set is not currently "
          "supported");
    }
    if (results == nullptr) {
      return errors::InvalidArgument(
          "results argument to ImportGraphDef() must be non-null if "
          "opts.return_nodes is non-empty");
    }
  }

  // `results` is output-only; reject pre-populated fields.
  if (results != nullptr) {
    if (!results->return_tensors.empty() || !results->return_nodes.empty() ||
        !results->missing_unused_input_map_keys.empty()) {
      return errors::InvalidArgument(
          "All fields in results argument to ImportGraphDef() must be empty.");
    }
  }

  // Created unconditionally so it outlives the Construct() call when used as
  // the fallback below.
  ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
  if (refiner == nullptr) {
    refiner = &default_refiner;
  } else {
    // Log a warning if we are importing a GraphDef at an older
    // producer version after already having added non-source/sink
    // nodes to the graph in the past.
    if (gdef.versions().producer() > 0 &&
        gdef.versions().producer() < refiner->graph_def_version() &&
        g->num_nodes() > 2) {
      LOG(WARNING) << "Importing a graph with a lower producer version "
                   << gdef.versions().producer()
                   << " into an existing graph with producer version "
                   << refiner->graph_def_version() << ". Shape inference will "
                   << "have run different parts of the graph with different "
                   << "producer versions.";
    }
  }

  // Set the graph def version of the refiner as the min of the
  // current value and the version from the graph we are about to
  // import.
  //
  // Note: to match Run() semantics, we should re-run shape inference
  // on the entire graph if the producer version has changed.  For now
  // we log the warning above.
  refiner->set_graph_def_version(
      std::min(refiner->graph_def_version(), gdef.versions().producer()));

  if (results == nullptr) {
    return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
                                       &gdef.library(), g, refiner, nullptr,
                                       nullptr, nullptr);
  } else {
    return GraphConstructor::Construct(
        opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
        &results->return_tensors, &results->return_nodes,
        &results->missing_unused_input_map_keys);
  }
}
1541 
CopyGraph(const Graph & src,Graph * dest)1542 void CopyGraph(const Graph& src, Graph* dest) { dest->Copy(src); }
1543 
1544 }  // namespace tensorflow
1545