• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/graph_constructor.h"
17 
18 #include <algorithm>
19 #include <set>
20 #include <string>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_cat.h"
28 #include "tensorflow/core/common_runtime/shape_refiner.h"
29 #include "tensorflow/core/framework/function.h"
30 #include "tensorflow/core/framework/function.pb.h"
31 #include "tensorflow/core/framework/graph.pb.h"
32 #include "tensorflow/core/framework/node_def.pb.h"
33 #include "tensorflow/core/framework/node_def_util.h"
34 #include "tensorflow/core/framework/tensor_shape.pb.h"
35 #include "tensorflow/core/framework/types.h"
36 #include "tensorflow/core/framework/versions.h"
37 #include "tensorflow/core/framework/versions.pb.h"
38 #include "tensorflow/core/graph/algorithm.h"
39 #include "tensorflow/core/graph/graph.h"
40 #include "tensorflow/core/graph/tensor_id.h"
41 #include "tensorflow/core/lib/core/errors.h"
42 #include "tensorflow/core/lib/gtl/flatmap.h"
43 #include "tensorflow/core/lib/gtl/flatset.h"
44 #include "tensorflow/core/lib/gtl/inlined_vector.h"
45 #include "tensorflow/core/lib/strings/scanner.h"
46 #include "tensorflow/core/lib/strings/str_util.h"
47 #include "tensorflow/core/platform/errors.h"
48 #include "tensorflow/core/platform/logging.h"
49 #include "tensorflow/core/platform/macros.h"
50 #include "tensorflow/core/public/version.h"
51 
52 namespace tensorflow {
53 
54 namespace {
55 
// We remove duplicate control inputs before adding edges to the Graph, so we
// can skip expensive duplicates check in 'AddControlEdge'.
// Note: `constexpr` already implies `const`, and objects at namespace scope in
// an anonymous namespace already have internal linkage, so the previous
// `static constexpr const bool` spelling was doubly redundant.
constexpr bool kDoNotCheckDuplicates = true;
59 
IsMerge(const NodeDef & node_def)60 inline bool IsMerge(const NodeDef& node_def) {
61   return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
62          node_def.op() == "_XlaMerge";
63 }
64 
IsNextIteration(const NodeDef & node_def)65 inline bool IsNextIteration(const NodeDef& node_def) {
66   return node_def.op() == "NextIteration" ||
67          node_def.op() == "RefNextIteration";
68 }
69 
IsValidNodeName(StringPiece s,bool allow_internal_ops)70 bool IsValidNodeName(StringPiece s, bool allow_internal_ops) {
71   using ::tensorflow::strings::Scanner;
72   Scanner scanner(s);
73   scanner
74       .One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE
75                               : Scanner::LETTER_DIGIT_DOT)
76       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
77 
78   while (true) {
79     if (!scanner.GetResult())  // Some error in previous iteration.
80       return false;
81     if (scanner.empty())  // No error, but nothing left, good.
82       return true;
83 
84     // Absorb another piece, starting with a '>'
85     scanner.One(Scanner::RANGLE)
86         .One(Scanner::LETTER_DIGIT_DOT)
87         .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
88   }
89 }
90 
// Shared implementation of GraphDef -> Graph conversion, used by both
// ConvertGraphDefToGraph and ImportGraphDef (each has its own Options
// constructor below). The conversion runs as a fixed pipeline of phases; see
// TryImport() for the exact order. Subclasses supply access to the input
// NodeDefs via the virtual accessors near the bottom of the class.
class GraphConstructor {
 public:
  struct Options {
    // Options used when constructing a graph from scratch
    // (ConvertGraphDefToGraph): not importing, no prefix/input_map handling.
    Options(const GraphConstructorOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(in.allow_internal_ops),
          expect_device_spec(in.expect_device_spec),
          importing(false),
          validate_nodes(in.validate_nodes),
          validate_colocation_constraints(false),
          add_default_attributes(in.add_default_attributes) {}
    // Options used when importing into an existing graph (ImportGraphDef).
    // Note that a non-empty `prefix` is normalized to always end with '/'.
    Options(const ImportGraphDefOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(false),
          expect_device_spec(false),
          prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/")
                     ? in.prefix
                     : in.prefix + "/"),
          uniquify_names(in.uniquify_names),
          uniquify_prefix(in.uniquify_prefix),
          input_map(in.input_map.begin(), in.input_map.end()),
          skip_mapped_nodes(in.skip_mapped_nodes),
          control_dependencies(in.control_dependencies),
          return_tensors(in.return_tensors.begin(), in.return_tensors.end()),
          return_nodes(in.return_nodes),
          importing(true),
          validate_nodes(true),
          validate_colocation_constraints(in.validate_colocation_constraints),
          validate_shape(in.validate_shape),
          default_device(in.default_device) {}

    bool allow_internal_ops;
    bool expect_device_spec;

    // Prepended to imported node names; either empty or ends with '/'.
    string prefix;
    bool uniquify_names;
    bool uniquify_prefix;
    // Maps tensors in the GraphDef being imported to tensors of nodes that
    // already exist in the destination graph.
    std::map<TensorId, TensorId> input_map;
    bool skip_mapped_nodes;
    // Names of existing nodes to add as control inputs to imported nodes.
    std::vector<string> control_dependencies;
    std::vector<TensorId> return_tensors;
    std::vector<string> return_nodes;

    // TODO(ashankar): This bool exists to separate out functionality required
    // to make ImportGraphDef a close equivalent of Python's import_graph_def
    // without affecting the behavior of ConvertGraphDefToGraph at the time
    // ImportGraphDef was added.
    //
    // That said, the functionality here (shape and op validation) seems
    // applicable to ConvertGraphDefToGraph as well, so make an attempt to
    // remove this.
    bool importing;
    // If true, validates that nodes being converted have all expected attrs
    // set and no unknown attrs set by calling ValidateNodeDef().
    // `validate_nodes` is always true when `importing` is set.
    bool validate_nodes;
    bool validate_colocation_constraints;
    bool validate_shape = true;

    // If true, GraphConstructor will add attributes with their default
    // value to the Node when they are missing from the NodeDef.
    bool add_default_attributes = true;

    string default_device;
  };

  typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;

  // Builds `g` from a borrowed slice of NodeDefs (nodes are copied).
  // versions and library may be nullptr
  static Status Construct(
      const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
      const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
      std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys);

  // Builds `g` by taking ownership of `graph_def`, which allows NodeDefs to
  // be moved into the graph instead of copied.
  static Status Construct(
      const Options& opts, GraphDef&& graph_def, Graph* g,
      ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys);

 protected:
  GraphConstructor(const Options& opts, Graph* g, ShapeRefiner* refiner,
                   std::vector<std::pair<Node*, int>>* return_tensors,
                   std::vector<Node*>* return_nodes,
                   std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : opts_(opts),
        g_(g),
        original_versions_(g->versions()),
        prefix_(opts.prefix),
        refiner_(refiner),
        return_tensors_(return_tensors),
        return_nodes_(return_nodes),
        missing_unused_input_map_keys_(missing_unused_input_map_keys) {}

  virtual ~GraphConstructor() {}

  // Runs the full conversion pipeline in order. On failure, Construct() is
  // expected to call Undo() to revert any partial changes to g_.
  Status TryImport() {
    TF_RETURN_IF_ERROR(EnsureNoNameCollisions());
    TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies());
    TF_RETURN_IF_ERROR(BuildNodeIndex());
    TF_RETURN_IF_ERROR(InitFromEdges());

    // NOTE: Convert() invokes `consume_node_def()` on each node in the input
    // graph, so `get_node_def()` is no longer usable once it is called.
    TF_RETURN_IF_ERROR(Convert());

    TF_RETURN_IF_ERROR(AddBackEdges());
    TF_RETURN_IF_ERROR(UpdateVersionDef());
    TF_RETURN_IF_ERROR(PopulateReturnTensors());
    TF_RETURN_IF_ERROR(PopulateReturnNodes());
    TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys());
    UpdateUniquifiedColocationNames();
    FixupSourceAndSinkEdges(g_);
    return Status::OK();
  }

 private:
  // The phases of TryImport(), in execution order.
  Status EnsureNoNameCollisions();
  Status ValidateInputMapAndControlDependencies();
  Status BuildNodeIndex();
  Status InitFromEdges();
  Status Convert();
  Status AddBackEdges();
  Status UpdateVersionDef();
  Status PopulateReturnTensors();
  Status PopulateReturnNodes();
  Status PopulateMissingUnusedInputMapKeys();

  // Invoked by Construct() when TryImport() fails, to undo partial changes.
  void Undo();

  // Prints cycles in the graph.
  void PrintCycles();
  // Performs DFS starting at `cur_node` and prints any cycles found.
  void DFS(int cur_node, std::vector<int>* cur_branch,
           std::vector<bool>* is_on_cur_branch,
           absl::flat_hash_set<int>* unvisited);
  Status IsNodeFullyMapped(const NodeDef& node_def, bool* is_node_mapped);
  Status ValidateColocationConstraints(const NodeDef& node_def);
  Status MakeNode(NodeDef&& node_def, Node** node);
  Status MakeEdge(Node* src, int output_index, Node* dst, int input_index);
  Status ValidateShape(Node* node);
  Status ModifyNodeDefForImport(NodeDef* node_def);
  // Modifies node_def's inputs according to opts_.input_map.
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function will mark inputs that are remapped to
  // true.
  void RemapNodeDefInputs(NodeDef* node_def,
                          std::vector<bool>* input_already_exists);
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function will add and mark control inputs as
  // true.
  void AddControlDependencies(NodeDef* node_def,
                              std::vector<bool>* input_already_exists);
  void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists,
                          NodeDef* node_def);

  // Modifies `node_def` if its name isn't unique, or if any of its inputs'
  // names have been uniquified. This must be called in topological order on all
  // nodes.
  void UniquifyNames(const std::vector<bool>& input_already_exists,
                     NodeDef* node_def);

  // Updates any constructed nodes' colocation group names if the name has been
  // updated by UniquifyNames. This is called after all the nodes have been
  // constructed so all the names have been uniquified if necessary.
  void UpdateUniquifiedColocationNames();

  // Returns true if `name` already exists in `g_` (either as a node name or
  // prefix).
  bool NameExistsInGraph(StringPiece name);

  // Returns true if `name` already exists in the GraphDef being imported
  // (either as a node name or prefix).
  bool NameExistsInGraphDef(StringPiece name);

  // Returns a unique version of `original_name`, or `original_name` if it's
  // already unique in the graph.
  string FindUniqueName(StringPiece original_name);

  // Decrement pending count for users of `processed` and add the ones that now
  // have all of their pending inputs satisfied to `ready_`.
  void UpdatePendingCountAndReady(int processed, bool is_next_iteration);

  // Subclasses override the following virtual methods to provide efficient
  // access to the original protocol buffer-based graph.

  // Returns the number of nodes in the graph.
  virtual size_t node_def_count() const = 0;
  // Returns the i^th node in the graph. Must not be called after
  // consume_node_def(i).
  virtual const NodeDef& get_node_def(int i) const = 0;
  // Destructively reads the i^th node in the graph, avoiding a copy if
  // possible. After calling this method, the result of get_node_def(i) is
  // undefined.
  virtual NodeDef consume_node_def(int i) = 0;
  // Returns the version information for the graph, or nullptr if none is
  // available.
  virtual const VersionDef* versions() const = 0;
  // Returns the function information for the graph, or nullptr if none is
  // available.
  virtual const FunctionDefLibrary* library() const = 0;

  // From constructor
  const Options opts_;
  Graph* g_;
  // g_'s versions as captured at construction time, before the import makes
  // any changes.
  const VersionDef original_versions_;

  // A copy of opts_.prefix, possibly uniquified.
  string prefix_;

  ShapeRefiner* refiner_;

  // May be null. Not owned.
  std::vector<std::pair<Node*, int>>* return_tensors_;

  // May be null. Not owned.
  std::vector<Node*>* return_nodes_;

  // May be null. Not owned.
  std::vector<SafeTensorId>* missing_unused_input_map_keys_;

  // Intermediate datastructure used to populate
  // `missing_unused_input_map_keys_`.
  std::set<TensorId> used_input_map_keys_;

  // Intermediate datastructure used to track the destinations of back edges
  // (indices of Merge nodes within the input node defs).
  absl::flat_hash_set<int> merge_node_indices_;

  // Mapping from node name to the index within node_defs_.
  struct NodeInfo {
    explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
    // Containers require that we have a default constructor.
    NodeInfo() : NodeInfo(-1) {}
    int gdef_index;
    Node* node;  // nullptr until the NodeDef is converted to a Node.
  };
  gtl::FlatMap<StringPiece, NodeInfo, StringPieceHasher> gdef_nodes_;

  // Storage for StringPiece keys in gdef_nodes_. Typically, the StringPiece key
  // will refer to the string stored in `NodeDef::name()`. This intern table is
  // only used when the original NodeDef's name is changed.
  std::vector<string> string_intern_table_;

  // Prefixes already used in the GraphDef being imported.
  gtl::FlatSet<StringPiece, StringPieceHasher> gdef_prefixes_;

  // Mapping from node name to the existing node in g_.
  gtl::FlatMap<StringPiece, Node*, StringPieceHasher> existing_nodes_;

  // Prefixes already used in the graph.
  gtl::FlatSet<StringPiece, StringPieceHasher> existing_prefixes_;

  // Imported node names that have been uniquified. The key is the original
  // name, the value is the new unique name.
  gtl::FlatMap<string, string> uniquified_names_;

  // Index of NodeDefs in node_defs_ with all inputs already converted. We use a
  // (sorted) set so nodes are created in the order defined in the GraphDef.
  std::set<int> ready_;

  // Mapping between index within node_defs_ and the number of inputs that
  // still need to be converted.
  std::vector<int> pending_count_;

  // Mapping between index within node_defs_ and the index within node_defs_ of
  // all nodes it outputs to.
  std::vector<gtl::InlinedVector<int, 4>> outputs_;

  // Used in the conversion from node_defs_ to g_ to represent the ith input
  // of a node.
  struct InputInfo {
    explicit InputInfo(const string& node_name, Node* n, int i)
        : name(node_name), node(n), index(i) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string name;
    Node* node;
    int index;

    static bool IsControlInput(const InputInfo& input) {
      return input.index == Graph::kControlSlot;
    }
    // NOTE(review): returns the bool result of operator< widened to int;
    // presumably used as an ordering predicate — confirm at call sites.
    static int CompareName(const InputInfo& lhs, const InputInfo& rhs) {
      return lhs.name < rhs.name;
    }
    static bool IsSameName(const InputInfo& lhs, const InputInfo& rhs) {
      return lhs.name == rhs.name;
    }
  };

  // Used in the conversion from node_defs_ to g_ to represent an edge from
  // the node named 'name' to node 'n'.
  struct EdgeInfo {
    explicit EdgeInfo(const string& name, int i1, Node* n, int i2)
        : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string src_name;
    int src_index;
    Node* dst_node;
    int dst_index;
  };
  std::vector<EdgeInfo> back_edges_;

  TF_DISALLOW_COPY_AND_ASSIGN(GraphConstructor);
};
395 
396 // Implementation of GraphConstructor that does not take ownership of the
397 // input NodeDef messages and thus copies the nodes into the constructed Graph*.
398 //
399 // NOTE(mrry): Whenever possible, use NodeDefMovingGraphConstructor, which
400 // avoids copying each NodeDef into the constructed Graph*.
401 class NodeDefCopyingGraphConstructor : public GraphConstructor {
402  public:
NodeDefCopyingGraphConstructor(const Options & opts,NodeDefSlice node_defs,const VersionDef * versions,const FunctionDefLibrary * library,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)403   NodeDefCopyingGraphConstructor(
404       const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
405       const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
406       std::vector<std::pair<Node*, int>>* return_tensors,
407       std::vector<Node*>* return_nodes,
408       std::vector<SafeTensorId>* missing_unused_input_map_keys)
409       : GraphConstructor(opts, g, refiner, return_tensors, return_nodes,
410                          missing_unused_input_map_keys),
411         node_defs_(node_defs),
412         versions_(versions),
413         library_(library) {}
414 
415  private:
node_def_count() const416   size_t node_def_count() const override { return node_defs_.size(); }
get_node_def(int i) const417   const NodeDef& get_node_def(int i) const override { return *node_defs_[i]; }
consume_node_def(int i)418   NodeDef consume_node_def(int i) override { return *node_defs_[i]; }
versions() const419   const VersionDef* versions() const override { return versions_; }
library() const420   const FunctionDefLibrary* library() const override { return library_; }
421 
422   const NodeDefSlice node_defs_;
423   const VersionDef* const versions_;
424   const FunctionDefLibrary* const library_;
425 };
426 
427 // Implementation of GraphConstructor that takes ownership of the input
428 // GraphDef, and can perform destructive reads.
429 class NodeDefMovingGraphConstructor : public GraphConstructor {
430  public:
NodeDefMovingGraphConstructor(const Options & opts,GraphDef && graph_def,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)431   NodeDefMovingGraphConstructor(
432       const Options& opts, GraphDef&& graph_def, Graph* g,
433       ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors,
434       std::vector<Node*>* return_nodes,
435       std::vector<SafeTensorId>* missing_unused_input_map_keys)
436       : GraphConstructor(opts, g, refiner, return_tensors, return_nodes,
437                          missing_unused_input_map_keys),
438         graph_def_(std::move(graph_def)),
439         is_consumed_(graph_def_.node_size(), false) {}
440 
441  private:
node_def_count() const442   size_t node_def_count() const override { return graph_def_.node().size(); }
get_node_def(int i) const443   const NodeDef& get_node_def(int i) const override {
444     CHECK(!is_consumed_[i])
445         << "NodeDef " << i << " accessed after it was consumed.";
446     return graph_def_.node(i);
447   }
consume_node_def(int i)448   NodeDef consume_node_def(int i) override {
449     CHECK(!is_consumed_[i]) << "NodeDef " << i << " consumed twice.";
450     is_consumed_[i] = true;
451     return std::move(*graph_def_.mutable_node(i));
452   }
versions() const453   const VersionDef* versions() const override { return &graph_def_.versions(); }
library() const454   const FunctionDefLibrary* library() const override {
455     return &graph_def_.library();
456   }
457 
458   GraphDef graph_def_;
459   std::vector<bool> is_consumed_;
460 };
461 
ForwardCompatibilityWindowPassed(const VersionDef & versions)462 bool ForwardCompatibilityWindowPassed(const VersionDef& versions) {
463   // TF_GRAPH_DEF_VERSION is incremented daily.
464   // TF has a 3 week forward compatibility guarantee.
465   return (versions.producer() - TF_GRAPH_DEF_VERSION) > 21;
466 }
467 
// Returns `import_status` as-is, unless the producer of `versions` is beyond
// this binary's forward-compatibility window — in that case the status
// message is augmented to point at the version skew as a likely root cause.
// `versions` may be nullptr.
Status MaybeAppendVersionWarning(const VersionDef* versions,
                                 const Status& import_status) {
  if (versions == nullptr || !ForwardCompatibilityWindowPassed(*versions)) {
    return import_status;
  }
  return Status(
      import_status.code(),
      absl::StrCat(
          "Converting GraphDef to Graph has failed with an error: '",
          import_status.error_message(),
          "' The binary trying to import the GraphDef was built when "
          "GraphDef version was ",
          TF_GRAPH_DEF_VERSION,
          ". The GraphDef was produced by a binary built when GraphDef "
          "version was ",
          versions->producer(),
          ". The difference between these versions is larger than "
          "TensorFlow's forward compatibility guarantee, and might be the "
          "root cause for failing to import the GraphDef."));
}
488 
Construct(const Options & opts,NodeDefSlice node_defs,const VersionDef * versions,const FunctionDefLibrary * library,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)489 /* static */ Status GraphConstructor::Construct(
490     const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
491     const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
492     std::vector<std::pair<Node*, int>>* return_tensors,
493     std::vector<Node*>* return_nodes,
494     std::vector<SafeTensorId>* missing_unused_input_map_keys) {
495   if (versions) {
496     TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
497                                      TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
498                                      "GraphDef", "graph"));
499   }
500   NodeDefCopyingGraphConstructor c(opts, node_defs, versions, library, g,
501                                    refiner, return_tensors, return_nodes,
502                                    missing_unused_input_map_keys);
503   Status s = c.TryImport();
504   if (!s.ok()) {
505     c.Undo();
506     s = MaybeAppendVersionWarning(versions, s);
507   }
508   return s;
509 }
510 
Construct(const Options & opts,GraphDef && graph_def,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)511 /* static */ Status GraphConstructor::Construct(
512     const Options& opts, GraphDef&& graph_def, Graph* g, ShapeRefiner* refiner,
513     std::vector<std::pair<Node*, int>>* return_tensors,
514     std::vector<Node*>* return_nodes,
515     std::vector<SafeTensorId>* missing_unused_input_map_keys) {
516   TF_RETURN_IF_ERROR(CheckVersions(graph_def.versions(), TF_GRAPH_DEF_VERSION,
517                                    TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
518                                    "GraphDef", "graph"));
519   VersionDef version_def = graph_def.versions();
520   NodeDefMovingGraphConstructor c(opts, std::move(graph_def), g, refiner,
521                                   return_tensors, return_nodes,
522                                   missing_unused_input_map_keys);
523   Status s = c.TryImport();
524   if (!s.ok()) {
525     c.Undo();
526     s = MaybeAppendVersionWarning(&version_def, s);
527   }
528   return s;
529 }
530 
UpdatePendingCountAndReady(int processed,bool is_next_iteration)531 void GraphConstructor::UpdatePendingCountAndReady(int processed,
532                                                   bool is_next_iteration) {
533   for (size_t i = 0; i < outputs_[processed].size(); ++i) {
534     const int output = outputs_[processed][i];
535     // We didn't consider NextIteration->Merge edges when computing
536     // pending_counts_ so we should not have to consider it here either.
537     bool is_next_iteration_to_merge_edge =
538         is_next_iteration && merge_node_indices_.count(output) == 1;
539     if (!is_next_iteration_to_merge_edge) {
540       int* current_pending_count = &pending_count_[output];
541       CHECK_GT(*current_pending_count, 0);
542       (*current_pending_count)--;
543       if (*current_pending_count == 0) {
544         ready_.insert(output);
545       }
546     }
547   }
548 }
549 
550 // This could be expensive but we don't expect to call it often, if at all (only
551 // if there are multiple nodes in g_ with the same name)
NodeNameInValues(const std::map<TensorId,TensorId> & input_map,const StringPiece & node_name)552 bool NodeNameInValues(const std::map<TensorId, TensorId>& input_map,
553                       const StringPiece& node_name) {
554   for (auto iter = input_map.begin(); iter != input_map.end(); ++iter) {
555     if (iter->second.first == node_name) return true;
556   }
557   return false;
558 }
559 
NodeNameInValues(const std::vector<string> & control_dependencies,const StringPiece & node_name)560 bool NodeNameInValues(const std::vector<string>& control_dependencies,
561                       const StringPiece& node_name) {
562   return std::find(control_dependencies.begin(), control_dependencies.end(),
563                    node_name) != control_dependencies.end();
564 }
565 
566 // Adds any prefixes of `node_name` (not including the full name itself) to
567 // `prefixes`.
AddPrefixes(StringPiece node_name,gtl::FlatSet<StringPiece,StringPieceHasher> * prefixes)568 void AddPrefixes(StringPiece node_name,
569                  gtl::FlatSet<StringPiece, StringPieceHasher>* prefixes) {
570   size_t idx = -1;
571   while ((idx = node_name.find('/', idx + 1)) != StringPiece::npos) {
572     prefixes->insert(node_name.substr(0, idx));
573   }
574 }
575 
// Indexes the names/prefixes already present in g_, rejects imports whose
// names would collide with them (when no prefix is used), and validates —
// possibly uniquifying — prefix_ when one was given.
Status GraphConstructor::EnsureNoNameCollisions() {
  existing_nodes_.reserve(g_->num_nodes());
  // Populate existing_nodes_ and existing_prefixes_.
  for (Node* n : g_->nodes()) {
    bool already_exists = !existing_nodes_.insert({n->name(), n}).second;
    if (already_exists) {
      // Duplicate names in g_ are tolerated unless input_map or
      // control_dependencies refer to the ambiguous name.
      if (NodeNameInValues(opts_.input_map, n->name())) {
        return errors::InvalidArgument(
            "cannot resolve input_map because multiple nodes exist with name '",
            n->name(), "'");
      }
      if (NodeNameInValues(opts_.control_dependencies, n->name())) {
        return errors::InvalidArgument(
            "cannot resolve control_dependencies because multiple nodes exist "
            "with name '",
            n->name(), "'");
      }
    }
    AddPrefixes(n->name(), &existing_prefixes_);
  }
  if (prefix_.empty() && opts_.importing && !opts_.uniquify_names) {
    // No prefix and no renaming: every imported name must be brand new.
    for (size_t i = 0; i < node_def_count(); ++i) {
      const string& name = get_node_def(i).name();
      if (NameExistsInGraph(name)) {
        return errors::InvalidArgument("Node name '", name,
                                       "' already exists in the Graph");
      }
    }
  } else if (!prefix_.empty()) {
    // prefix_ always ends with '/' (see Options); validate it without the
    // trailing slash since that is how it will appear inside node names.
    StringPiece prefix_no_slash(prefix_);
    prefix_no_slash.remove_suffix(1);
    if (!IsValidNodeName(prefix_no_slash, false)) {
      return errors::InvalidArgument("Imported node name prefix '", prefix_,
                                     "' would lead to invalid node names");
    }
    if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) {
      prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
    }
  }
  return Status::OK();
}
617 
// Checks that every node referenced by opts_.input_map and
// opts_.control_dependencies exists in the destination graph, and that
// input_map never maps between a control edge and a data edge.
Status GraphConstructor::ValidateInputMapAndControlDependencies() {
  for (const auto& mapping : opts_.input_map) {
    TensorId src = mapping.first;   // tensor in the GraphDef being imported
    TensorId dst = mapping.second;  // tensor it is remapped to in g_
    if (existing_nodes_.count(dst.first) == 0) {
      return errors::InvalidArgument(
          "node '", dst.first, "' in input_map does not exist in graph ",
          "(input_map entry: ", src.ToString(), "->", dst.ToString(), ")");
    }
    // A control slot (Graph::kControlSlot) may only be remapped to another
    // control slot; both sides must agree.
    if ((src.second == Graph::kControlSlot) !=
        (dst.second == Graph::kControlSlot)) {
      return errors::InvalidArgument("input_map entry ", src.ToString(), "->",
                                     dst.ToString(), " between ",
                                     "control edge and non-control edge");
    }
  }
  for (const string& node : opts_.control_dependencies) {
    if (existing_nodes_.count(node) == 0) {
      return errors::InvalidArgument(
          "node '", node,
          "' in control_dependencies does not exist in "
          "graph");
    }
  }
  return Status::OK();
}
644 
// Validates each NodeDef's name, op, device spec and input ordering, and
// populates gdef_nodes_ (name -> index), gdef_prefixes_, and
// merge_node_indices_ for the GraphDef being imported.
Status GraphConstructor::BuildNodeIndex() {
  // Validate the node names and add them to gdef_nodes_ and gdef_prefixes_.
  for (int n = 0; n < node_def_count(); ++n) {
    const NodeDef& node_def = get_node_def(n);
    if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "': Node name contains invalid characters");
    }
    // The StringPiece key aliases the name stored inside node_def, so
    // duplicate names fail the insert and are rejected here.
    if (!gdef_nodes_
             .insert(std::make_pair(StringPiece(node_def.name()), NodeInfo(n)))
             .second) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' is not unique");
    }
    // Validate the operation's type.
    if (node_def.op().empty()) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' does not specify an operation");
    }
    if (opts_.expect_device_spec && node_def.device().empty()) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' is missing a device specification");
    }
    // Merge nodes are recorded as potential back-edge destinations (see
    // UpdatePendingCountAndReady()).
    if (IsMerge(node_def)) {
      merge_node_indices_.insert(n);
    }
    // Validate control edges at end: inputs prefixed with '^' are control
    // dependencies and must come after all regular inputs.
    bool in_control_dependence = false;
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      if (!input_name.empty() && absl::StartsWith(input_name, "^")) {
        in_control_dependence = true;
      } else if (in_control_dependence) {
        return errors::InvalidArgument(
            "Node '", node_def.name(),
            "': Control dependencies must come after regular dependencies");
      }
    }
    // Update gdef_prefixes_.
    AddPrefixes(node_def.name(), &gdef_prefixes_);
  }
  return Status::OK();
}
689 
// Builds the topological-processing bookkeeping used by Convert():
// pending_count_[n] (number of inputs node n still waits on), outputs_[n]
// (consumers of node n), and ready_ (nodes with no unsatisfied inputs).
// Merge nodes fed by a NextIteration node are treated specially so that
// while-loop cycles don't deadlock the topological walk.
Status GraphConstructor::InitFromEdges() {
  const int num_nodes = node_def_count();
  pending_count_.reserve(num_nodes);
  outputs_.resize(num_nodes);
  // First pass: collect the names of all NextIteration nodes so loop-back
  // edges into Merge nodes can be recognized below.
  gtl::FlatSet<string> next_iteration_nodes;
  for (int n = 0; n < node_def_count(); ++n) {
    const NodeDef& node_def = get_node_def(n);
    if (IsNextIteration(node_def)) {
      next_iteration_nodes.insert(node_def.name());
    }
  }

  // Parse the inputs for each node.
  for (int n = 0; n < num_nodes; ++n) {
    const NodeDef& node_def = get_node_def(n);
    int pending_count = node_def.input_size();
    if (IsMerge(node_def)) {
      // Cycles in the graph are only allowed for while loops. A while loop is
      // identified by an edge from a NextIteration node to a Merge node. For
      // such Merge nodes, only wait for one non-control input before
      // considering the node ready to process in Convert().
      int32 num_control_edges = 0;
      bool has_loop_back_edge = false;
      for (int i = 0; i < node_def.input_size(); ++i) {
        StringPiece input_name(node_def.input(i));
        if (absl::StartsWith(input_name, "^")) {
          num_control_edges++;
        } else {
          TensorId id(ParseTensorName(input_name));
          if (next_iteration_nodes.find(string(id.first)) !=
              next_iteration_nodes.end()) {
            has_loop_back_edge = true;
          }
        }
      }
      if (has_loop_back_edge) {
        // Wait for all control inputs plus exactly one data input; the
        // loop-back data edge is wired up later by AddBackEdges().
        pending_count = num_control_edges + 1;
      }
    }
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      TensorId id(ParseTensorName(input_name));
      if (opts_.input_map.count(id) == 0) {
        // If an input is not mapped, then the input should appear in the graph
        // being imported.
        auto iter = gdef_nodes_.find(id.first);
        if (iter == gdef_nodes_.end()) {
          return errors::InvalidArgument("Node '", node_def.name(),
                                         "': Unknown input node '",
                                         node_def.input(i), "'");
        }
        outputs_[iter->second.gdef_index].push_back(n);
      } else {
        // This input is mapped to an existing edge. Therefore this input is
        // as good as being already processed.
        --pending_count;
        DCHECK_GE(pending_count, 0);
      }
    }
    if (pending_count == 0) {
      ready_.insert(n);
    }
    pending_count_.push_back(pending_count);
  }
  return Status::OK();
}
756 
ValidateColocationConstraints(const NodeDef & node_def)757 Status GraphConstructor::ValidateColocationConstraints(
758     const NodeDef& node_def) {
759   if (!opts_.validate_colocation_constraints || !opts_.importing)
760     return Status::OK();
761   const auto iter = node_def.attr().find(kColocationAttrName);
762   if (iter == node_def.attr().end()) return Status::OK();
763   for (const string& c : iter->second.list().s()) {
764     StringPiece s(c);
765     if (absl::ConsumePrefix(&s, kColocationGroupPrefix) &&
766         gdef_nodes_.find(s) == gdef_nodes_.end()) {
767       return errors::InvalidArgument(
768           "Node '", node_def.name(),
769           "' expects to be colocated with unknown node '", s, "'");
770     }
771   }
772   return Status::OK();
773 }
774 
MakeNode(NodeDef && node_def,Node ** node)775 Status GraphConstructor::MakeNode(NodeDef&& node_def, Node** node) {
776   // Add the node to the graph.
777   Status status;
778   *node = g_->AddNode(std::move(node_def), &status);
779   if (!status.ok()) return status;
780   if (opts_.expect_device_spec) {
781     (*node)->set_assigned_device_name((*node)->def().device());
782   }
783   return Status::OK();
784 }
785 
ValidateShape(Node * node)786 Status GraphConstructor::ValidateShape(Node* node) {
787   if (!opts_.importing || !opts_.validate_shape) return Status::OK();
788   TF_RETURN_IF_ERROR(refiner_->AddNode(node));
789   // For nodes with the _output_shapes attribute, override the shape.
790   std::vector<const TensorShapeProto*> shape_attrs;
791   const char* kAttrName = "_output_shapes";
792   if (!TryGetNodeAttr(node->attrs(), kAttrName, &shape_attrs)) {
793     // No _output_shapes attribute, the AddNode call above was sufficient.
794     return Status::OK();
795   }
796   auto* ic = refiner_->GetContext(node);
797   DCHECK(ic != nullptr)
798       << "ShapeRefiner::AddNode() should have created the InferenceContext";
799   if (shape_attrs.size() < node->num_outputs()) {
800     return errors::InvalidArgument(
801         "Node '", node->name(), "' has ", node->num_outputs(),
802         " outputs but the ", kAttrName, " attribute specifies shapes for ",
803         shape_attrs.size(), " outputs");
804   }
805   // NOTE(skyewm): we don't raise an error here because some users depend on
806   // this behavior, even though it's unsafe.
807   // TODO(b/74619486): raise an error.
808   if (shape_attrs.size() > node->num_outputs()) {
809     LOG(WARNING) << "Node '" << node->name() << "' has " << node->num_outputs()
810                  << " outputs but the " << kAttrName
811                  << " attribute specifies shapes for " << shape_attrs.size()
812                  << " outputs. Output shapes may be inaccurate.";
813   }
814   for (int i = 0; i < node->num_outputs(); ++i) {
815     const TensorShapeProto& p = *shape_attrs[i];
816     shape_inference::ShapeHandle h;
817     Status s = ic->MakeShapeFromShapeProto(p, &h);
818     if (!s.ok()) {
819       return errors::InvalidArgument("Node '", node->name(), " has an invalid ",
820                                      kAttrName, " attribute (shape #", i,
821                                      " error:'", s.error_message(), "'");
822     }
823     s = refiner_->SetShape(node, i, h);
824     if (!s.ok()) {
825       return errors::InvalidArgument(
826           "Node '", node->name(), "' has an ", kAttrName,
827           " attribute inconsistent with the GraphDef for output #", i, ": ",
828           s.error_message());
829     }
830   }
831   node->ClearAttr(kAttrName);
832   return Status::OK();
833 }
834 
ModifyNodeDefForImport(NodeDef * node_def)835 Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) {
836   const OpDef* op_def;
837   TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
838   AddDefaultsToNodeDef(*op_def, node_def);
839   TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def));
840   if (versions()) {
841     TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions()->producer()));
842   }
843   return Status::OK();
844 }
845 
RemoveInputs(const std::vector<int> & inputs_to_remove,NodeDef * node_def,std::vector<bool> * input_already_exists)846 void RemoveInputs(const std::vector<int>& inputs_to_remove, NodeDef* node_def,
847                   std::vector<bool>* input_already_exists) {
848   // Remove 'inputs_to_remove' from 'node_def'
849   NodeDef copy;
850   copy.mutable_input()->Reserve(node_def->input_size() -
851                                 inputs_to_remove.size());
852   for (int i = 0, j = 0; i < node_def->input_size(); ++i) {
853     if (j < inputs_to_remove.size() && i == inputs_to_remove[j]) {
854       ++j;
855     } else {
856       copy.add_input()->swap(*node_def->mutable_input(i));
857     }
858   }
859   node_def->mutable_input()->Swap(copy.mutable_input());
860   // Remove 'inputs_to_remove' from 'input_already_exists'
861   for (int idx : inputs_to_remove) {
862     input_already_exists->erase(input_already_exists->begin() + idx);
863   }
864   DCHECK_EQ(input_already_exists->size(), node_def->input_size());
865 }
866 
RemapNodeDefInputs(NodeDef * node_def,std::vector<bool> * input_already_exists)867 void GraphConstructor::RemapNodeDefInputs(
868     NodeDef* node_def, std::vector<bool>* input_already_exists) {
869   DCHECK_EQ(input_already_exists->size(), node_def->input_size());
870   std::set<TensorId> control_inputs;
871   std::vector<int> inputs_to_remove;
872 
873   for (int i = 0; i < node_def->input_size(); ++i) {
874     auto iter = opts_.input_map.find(ParseTensorName(node_def->input(i)));
875     if (iter == opts_.input_map.end()) continue;
876     used_input_map_keys_.insert(iter->first);
877 
878     TensorId new_input = iter->second;
879     if (new_input.second == Graph::kControlSlot) {
880       // Check if we've already remapped a different input to new_input, and if
881       // so remove this input.
882       if (control_inputs.count(new_input) > 0) {
883         inputs_to_remove.push_back(i);
884         continue;
885       }
886       control_inputs.insert(new_input);
887     }
888     node_def->set_input(i, new_input.ToString());
889     (*input_already_exists)[i] = true;
890   }
891   if (!inputs_to_remove.empty()) {
892     RemoveInputs(inputs_to_remove, node_def, input_already_exists);
893   }
894 }
895 
// Appends opts_.control_dependencies as control inputs to `node_def`, unless
// the node will inherit them transitively through another imported node.
// Appended inputs are mirrored in `input_already_exists` (as true), so the
// vector can grow.
void GraphConstructor::AddControlDependencies(
    NodeDef* node_def, std::vector<bool>* input_already_exists) {
  // To avoid adding redundant control dependencies to every imported node, skip
  // nodes that will inherit the dependencies from another imported node.
  bool inherits_deps = false;
  for (int i = 0; i < node_def->input_size(); ++i) {
    // Assume we won't inherit dependencies from remapped inputs that already
    // exist in the graph. Even if we're wrong, we'll only add redundant
    // dependencies.
    if ((*input_already_exists)[i]) continue;

    // If this input is a backedge, assume we won't inherit the dependencies.
    // TODO(skyewm): we have many redundant ParseTensorName calls. It could be
    // worth optimizing these.
    TensorId id(ParseTensorName(node_def->input(i)));
    auto iter = gdef_nodes_.find(id.first);
    DCHECK(iter != gdef_nodes_.end()) << id.first;
    if (iter->second.node == nullptr) {
      // Input hasn't been created yet, indicating it's a backedge.
      continue;
    }
    inherits_deps = true;
  }
  if (inherits_deps) return;

  // node_def either has no inputs or all remapped inputs, add the control
  // dependencies
  for (const string& control_dep : opts_.control_dependencies) {
    string input = TensorId(control_dep, Graph::kControlSlot).ToString();
    bool found = false;
    // Scan backwards over the trailing control inputs only (BuildNodeIndex()
    // enforced that control inputs come last).
    for (int i = node_def->input_size() - 1; i >= 0; --i) {
      const string& node_input = node_def->input(i);
      if (node_input[0] != '^') {
        // Control inputs are at the end. Break when we reach the non-control
        // inputs.
        break;
      }
      if (node_input == input) {
        // Control dependency already exists
        found = true;
        break;
      }
    }
    if (found) {
      continue;
    }
    node_def->add_input(input);
    input_already_exists->push_back(true);
  }
}
946 
AddPrefixToNodeDef(const std::vector<bool> & input_already_exists,NodeDef * node_def)947 void GraphConstructor::AddPrefixToNodeDef(
948     const std::vector<bool>& input_already_exists, NodeDef* node_def) {
949   if (prefix_.empty()) return;
950   node_def->set_name(strings::StrCat(prefix_, node_def->name()));
951   // Update names of input nodes
952   for (int i = 0; i < node_def->input_size(); ++i) {
953     // Skip remapped inputs (which already exist in g_ and are not being
954     // imported).
955     if (input_already_exists[i]) continue;
956     StringPiece input(node_def->input(i));
957     if (absl::ConsumePrefix(&input, "^")) {
958       node_def->set_input(i, strings::StrCat("^", prefix_, input));
959     } else {
960       node_def->set_input(i, strings::StrCat(prefix_, input));
961     }
962   }
963   // Update names of colocation groups
964   if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) {
965     auto* list =
966         node_def->mutable_attr()->at(kColocationAttrName).mutable_list();
967     for (int i = 0; i < list->s_size(); ++i) {
968       StringPiece v(list->s(i));
969       if (absl::ConsumePrefix(&v, kColocationGroupPrefix)) {
970         list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v));
971       }
972     }
973   }
974 }
975 
UniquifyNames(const std::vector<bool> & input_already_exists,NodeDef * node_def)976 void GraphConstructor::UniquifyNames(
977     const std::vector<bool>& input_already_exists, NodeDef* node_def) {
978   if (NameExistsInGraph(node_def->name())) {
979     string old_name = node_def->name();
980     node_def->set_name(FindUniqueName(node_def->name()));
981     uniquified_names_[old_name] = node_def->name();
982     // Note that we don't have to update gdef_nodes_ or gdef_prefixes_ with
983     // `name` because we guarantee the original NodeDef names are unique,
984     // meaning we won't generate this name again.
985   }
986   for (int i = 0; i < node_def->input_size(); ++i) {
987     // Skip remapped inputs (which already exist in g_ and are not being
988     // imported).
989     if (input_already_exists[i]) continue;
990     TensorId id = ParseTensorName(node_def->input(i));
991     // We require that UniquifyNames() is called on all NodeDefs in topological
992     // order. This guarantees that node_def's inputs will already be uniquified
993     // if necessary.
994     auto iter = uniquified_names_.find(string(id.first));
995     if (iter == uniquified_names_.end()) continue;
996     id.first = iter->second;
997     node_def->set_input(i, id.ToString());
998   }
999 }
1000 
// After all nodes are imported, rewrites colocation-group attrs on the
// created nodes so that groups referring to renamed (uniquified) nodes use
// the new names. Only nodes whose attr actually changed are updated.
void GraphConstructor::UpdateUniquifiedColocationNames() {
  for (const auto& pair : gdef_nodes_) {
    Node* node = pair.second.node;
    // Skipped/unconverted entries have no created node.
    if (node == nullptr) continue;
    std::vector<string> coloc_values;
    if (!TryGetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values))
      continue;
    bool updated = false;
    for (size_t i = 0; i < coloc_values.size(); ++i) {
      StringPiece val(coloc_values[i]);
      if (absl::ConsumePrefix(&val, kColocationGroupPrefix)) {
        auto name_pair = uniquified_names_.find(string(val));
        if (name_pair == uniquified_names_.end()) continue;
        updated = true;
        coloc_values[i] =
            strings::StrCat(kColocationGroupPrefix, name_pair->second);
      }
    }
    if (updated) {
      // Overwrite the attr only when something changed.
      node->AddAttr(kColocationAttrName, std::move(coloc_values));
    }
  }
}
1024 
NameExistsInGraph(StringPiece name)1025 bool GraphConstructor::NameExistsInGraph(StringPiece name) {
1026   if (existing_nodes_.find(name) != existing_nodes_.end()) return true;
1027   if (existing_prefixes_.find(name) != existing_prefixes_.end()) return true;
1028   return false;
1029 }
1030 
NameExistsInGraphDef(StringPiece name)1031 bool GraphConstructor::NameExistsInGraphDef(StringPiece name) {
1032   if (gdef_nodes_.find(name) != gdef_nodes_.end()) return true;
1033   if (gdef_prefixes_.find(name) != gdef_prefixes_.end()) return true;
1034   return false;
1035 }
1036 
FindUniqueName(StringPiece original_name)1037 string GraphConstructor::FindUniqueName(StringPiece original_name) {
1038   string name(original_name);
1039   int count = 0;
1040   // Check that any generated names don't collide with imported NodeDefs (as
1041   // well as nodes in g_).
1042   while (NameExistsInGraph(name) || (count > 0 && NameExistsInGraphDef(name))) {
1043     name = strings::StrCat(original_name, "_", ++count);
1044   }
1045   return name;
1046 }
1047 
IsNodeFullyMapped(const NodeDef & node_def,bool * is_node_mapped)1048 Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def,
1049                                            bool* is_node_mapped) {
1050   const OpDef* op_def;
1051   TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
1052   for (int i = 0; i < op_def->output_arg_size(); ++i) {
1053     if (opts_.input_map.find({node_def.name(), i}) == opts_.input_map.end()) {
1054       *is_node_mapped = false;
1055       return Status::OK();
1056     }
1057   }
1058   *is_node_mapped = true;
1059   return Status::OK();
1060 }
1061 
// Depth-first traversal helper for PrintCycles(). Walks consumers of
// `cur_node` (restricted to still-unvisited nodes), logging the nodes of any
// cycle it closes. `cur_branch` holds the current DFS path in order;
// `is_on_cur_branch` mirrors it for O(1) membership tests. Finished nodes
// are erased from `unvisited`.
void GraphConstructor::DFS(int cur_node, std::vector<int>* cur_branch,
                           std::vector<bool>* is_on_cur_branch,
                           absl::flat_hash_set<int>* unvisited) {
  cur_branch->push_back(cur_node);
  is_on_cur_branch->at(cur_node) = true;
  for (auto next_node : outputs_[cur_node]) {
    if (unvisited->find(next_node) != unvisited->end()) {
      if (is_on_cur_branch->at(next_node)) {
        // next_node is already on the current path, so the path suffix
        // starting at next_node forms a cycle; log it.
        auto iter =
            std::find(cur_branch->begin(), cur_branch->end(), next_node);
        LOG(WARNING) << "Cycle detected:";
        while (iter != cur_branch->end()) {
          LOG(WARNING) << SummarizeNodeDef(get_node_def(*iter));
          ++iter;
        }
        LOG(WARNING) << "End of cycle";
      } else {
        DFS(next_node, cur_branch, is_on_cur_branch, unvisited);
      }
    }
  }
  // Backtrack: remove cur_node from the current path and mark it done.
  cur_branch->pop_back();
  is_on_cur_branch->at(cur_node) = false;
  unvisited->erase(cur_node);
}
1087 
PrintCycles()1088 void GraphConstructor::PrintCycles() {
1089   int num_nodes = outputs_.size();
1090   absl::flat_hash_set<int> unvisited;
1091   for (int i = 0; i < num_nodes; i++) {
1092     unvisited.insert(i);
1093   }
1094   while (!unvisited.empty()) {
1095     int cur_node = *unvisited.begin();
1096     // Nodes on the current branch of DFS in traversal order. This is used for
1097     // printing the nodes in the cycle.
1098     std::vector<int> cur_branch;
1099     // This is just to make lookups O(1).
1100     // is_on_cur_branch[i] ==
1101     //   (std::find(cur_branch.start(),
1102     //              cur_branch.end(), i) != cur_branch.end())
1103     std::vector<bool> is_on_cur_branch(num_nodes, false);
1104     DFS(cur_node, &cur_branch, &is_on_cur_branch, &unvisited);
1105   }
1106 }
1107 
// Core conversion loop: consumes the NodeDefs in topological order (driven by
// ready_ / pending_count_ / outputs_ from InitFromEdges()), applies import
// transformations (input remapping, control deps, prefixing, uniquifying),
// creates the Node in g_, and wires up its in-edges. Back edges (to
// not-yet-created Merge inputs) are deferred to AddBackEdges(). Fails if any
// nodes remain unprocessed, which indicates a cycle.
Status GraphConstructor::Convert() {
  // Import functions before adding nodes, since imported nodes may refer to
  // functions
  if (library()) {
    // TODO(b/135705010): Add rvalue overloads into the function library, to
    // avoid unnecessarily copying `*library()` here.
    TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library()));
  }

  std::vector<InputInfo> inputs;
  int processed = 0;

  std::vector<bool> input_already_exists;

  // Process the NodeDefs in topological order.
  // (InitFromEdges() sets this up by filling in ready_ with nodes that have no
  // inputs, pending_counts_ with the number of inputs for each node and
  // outputs_ with the outputs of each node).
  while (!ready_.empty()) {
    int o = *ready_.begin();
    ready_.erase(ready_.begin());
    ++processed;
    inputs.clear();
    bool has_data_back_edge = false;

    NodeDef node_def = consume_node_def(o);

    // input_already_exists[i] is true iff the i-th input of the node we're
    // importing refers to a preexisting node in g_ (i.e. input[i] existed prior
    // to importing node_defs_).  Conversely, input_already_exists[i] is false
    // iff the input refers to a node in node_defs_.
    input_already_exists.clear();
    input_already_exists.resize(node_def.input_size(), false);

    // Index into string_intern_table_ of this node's original name, or -1
    // when not importing (set below).
    ssize_t string_intern_table_index = -1;

    if (opts_.importing) {
      // Intern the original node name, so that we can use a StringPiece of the
      // name to index gdef_nodes_.
      string_intern_table_index = string_intern_table_.size();
      string_intern_table_.push_back(node_def.name());

      if (opts_.skip_mapped_nodes) {
        bool is_node_mapped = false;
        TF_RETURN_IF_ERROR(IsNodeFullyMapped(node_def, &is_node_mapped));
        if (is_node_mapped) {
          // Skip this node after updating pending_count_ for outputs
          UpdatePendingCountAndReady(o, IsNextIteration(node_def));
          continue;
        }
      }

      if (!opts_.input_map.empty()) {
        // Note that input_already_exists can shrink here
        RemapNodeDefInputs(&node_def, &input_already_exists);
      }
      if (!opts_.control_dependencies.empty()) {
        // Note that input_already_exists can grow here
        AddControlDependencies(&node_def, &input_already_exists);
      }
      if (!opts_.default_device.empty() && node_def.device().empty()) {
        node_def.set_device(opts_.default_device);
      }
    }

    DCHECK_EQ(node_def.input_size(), input_already_exists.size());
    TF_RETURN_IF_ERROR(ValidateColocationConstraints(node_def));
    // Resolve each input to a (source node, output index) pair.
    for (int i = 0; i < node_def.input_size(); ++i) {
      TensorId tensor_id = ParseTensorName(node_def.input(i));
      Node* src_node;
      int src_index;

      if (!input_already_exists[i]) {
        // Locate input in newly-imported nodes
        auto iter = gdef_nodes_.find(tensor_id.node());
        DCHECK(iter != gdef_nodes_.end()) << tensor_id.node();
        src_node = iter->second.node;
        src_index = tensor_id.index();
        // A null node means the producer hasn't been created yet, i.e. this
        // is a back edge into a while loop.
        if (src_node == nullptr) has_data_back_edge = true;
      } else {
        // Input refers to preexistng node in graph
        auto iter = existing_nodes_.find(tensor_id.node());
        DCHECK(iter != existing_nodes_.end()) << tensor_id.node();
        src_node = iter->second;
        src_index = tensor_id.index();
      }

      if (src_node != nullptr && src_index >= src_node->num_outputs()) {
        std::ostringstream out;
        out << "Node '" << node_def.name() << "': Connecting to invalid output "
            << tensor_id.index() << " of source node " << tensor_id.node()
            << " which has " << src_node->num_outputs() << " outputs.";

        if (src_node->type_string() == "If" ||
            src_node->type_string() == "StatelessIf" ||
            src_node->type_string() == "While" ||
            src_node->type_string() == "StatelessWhile") {
          out << " Try using "
              << "tf.compat.v1.experimental.output_all_intermediates(True).";
        }
        return errors::InvalidArgument(out.str());
      }

      inputs.emplace_back(string(tensor_id.node()), src_node, src_index);
    }

    if (has_data_back_edge && !IsMerge(node_def)) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "' had a back edge, but only Merge nodes can have back edges.");
    }

    Node* node;
    if (opts_.importing) {
      if (!prefix_.empty()) {
        AddPrefixToNodeDef(input_already_exists, &node_def);
      }
      // Note: no need to uniquify names if the prefix already guarantees
      // uniqueness
      if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) {
        UniquifyNames(input_already_exists, &node_def);
      }
    }

    if (opts_.importing) {
      TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&node_def));
    } else {
      const OpDef* op_def;
      TF_RETURN_IF_ERROR(
          g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
      if (opts_.add_default_attributes) {
        AddDefaultsToNodeDef(*op_def, &node_def);
      }
      if (opts_.validate_nodes) {
        TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def));
      }
    }

    // node_def is consumed here; do not use it below.
    TF_RETURN_IF_ERROR(MakeNode(std::move(node_def), &node));

    if (opts_.importing) {
      // Use interned original node name so StringPiece remains valid.
      DCHECK_GE(string_intern_table_index, 0);
      gdef_nodes_[string_intern_table_[string_intern_table_index]].node = node;
    } else {
      DCHECK_EQ(string_intern_table_index, -1);
      gdef_nodes_[node->name()].node = node;
    }

    // Remove duplicate control inputs before adding edges to the graph. It
    // will allow us to skip expensive duplicates check in 'AddControlEdge'.
    auto first_control = absl::c_find_if(inputs, &InputInfo::IsControlInput);
    auto first_control_copy = first_control;
    std::sort(first_control, inputs.end(), &InputInfo::CompareName);
    inputs.erase(
        std::unique(first_control_copy, inputs.end(), &InputInfo::IsSameName),
        inputs.end());

    // Add edges from inputs to *node to the graph.
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (inputs[i].node == nullptr) {
        // Record this back edge, which will be added after all nodes
        // are created.
        back_edges_.emplace_back(inputs[i].name, inputs[i].index, node, i);
      } else if (inputs[i].index == Graph::kControlSlot) {
        g_->AddControlEdge(inputs[i].node, node, kDoNotCheckDuplicates);
      } else {
        TF_RETURN_IF_ERROR(MakeEdge(inputs[i].node, inputs[i].index, node, i));
      }
    }

    TF_RETURN_IF_ERROR(ValidateShape(node));

    // Update pending_count_ for outputs.
    UpdatePendingCountAndReady(o, node->IsNextIteration());
  }

  // Any node left unprocessed could never become ready: a cycle.
  if (processed < node_def_count()) {
    LOG(WARNING) << "IN " << __func__ << " " << (node_def_count() - processed)
                 << " NODES IN A CYCLE";
    for (int64 i = 0; i < node_def_count(); i++) {
      if (pending_count_[i] != 0) {
        LOG(WARNING) << "PENDING: " << SummarizeNodeDef(get_node_def(i))
                     << " WITH PENDING COUNT = " << pending_count_[i];
      }
    }
    PrintCycles();
    return errors::InvalidArgument(node_def_count() - processed,
                                   " nodes in a cycle");
  }

  return Status::OK();
}
1301 
AddBackEdges()1302 Status GraphConstructor::AddBackEdges() {
1303   // Add the back edges after all nodes are created.
1304   for (const auto& e : back_edges_) {
1305     Node* src_node = gdef_nodes_[e.src_name].node;
1306     if (e.src_index == Graph::kControlSlot) {
1307       g_->AddControlEdge(src_node, e.dst_node, kDoNotCheckDuplicates);
1308     } else {
1309       TF_RETURN_IF_ERROR(
1310           MakeEdge(src_node, e.src_index, e.dst_node, e.dst_index));
1311     }
1312 
1313     VLOG(2) << "Add back edge: " << src_node->name() << " -> "
1314             << e.dst_node->name();
1315   }
1316   return Status::OK();
1317 }
1318 
UpdateVersionDef()1319 Status GraphConstructor::UpdateVersionDef() {
1320   if (versions() == nullptr) return Status::OK();
1321 
1322   if (!opts_.importing) {
1323     g_->set_versions(*versions());
1324     return Status::OK();
1325   }
1326   VersionDef g_versions = g_->versions();
1327   g_versions.set_producer(
1328       std::min(g_versions.producer(), versions()->producer()));
1329   g_versions.set_min_consumer(
1330       std::max(g_versions.min_consumer(), versions()->min_consumer()));
1331   if (versions()->bad_consumers_size() > 0) {
1332     std::set<int> bad(g_versions.bad_consumers().begin(),
1333                       g_versions.bad_consumers().end());
1334     bad.insert(versions()->bad_consumers().begin(),
1335                versions()->bad_consumers().end());
1336     g_versions.clear_bad_consumers();
1337     for (int v : bad) {
1338       g_versions.add_bad_consumers(v);
1339     }
1340   }
1341   g_->set_versions(g_versions);
1342   return Status::OK();
1343 }
1344 
// Resolves opts_.return_tensors into (Node*, index) pairs appended to
// return_tensors_. A requested tensor may resolve either through
// opts_.input_map (to a preexisting node) or directly to an imported node.
Status GraphConstructor::PopulateReturnTensors() {
  if (opts_.return_tensors.empty()) return Status::OK();
  for (const TensorId& id : opts_.return_tensors) {
    auto iter = opts_.input_map.find(id);
    if (iter == opts_.input_map.end()) {
      // Locate id in imported nodes
      // NOTE: this inner `iter` intentionally shadows the input_map iterator
      // above; here it indexes gdef_nodes_.
      auto iter = gdef_nodes_.find(id.first);
      if (iter == gdef_nodes_.end()) {
        return errors::InvalidArgument("Requested return tensor '",
                                       id.ToString(),
                                       "' not found in graph def");
      }
      // The index must be a valid output slot or the control slot.
      int num_outputs = iter->second.node->num_outputs();
      if ((id.second < 0 || id.second >= num_outputs) &&
          id.second != Graph::kControlSlot) {
        return errors::InvalidArgument("Invalid return output ", id.second,
                                       " of node '", id.first, "', which has ",
                                       num_outputs, " output(s)");
      }
      return_tensors_->push_back({iter->second.node, id.second});
    } else {
      // id was remapped to existing node
      TensorId remapped_id = iter->second;
      DCHECK_GT(existing_nodes_.count(remapped_id.first), 0);
      Node* node = existing_nodes_[remapped_id.first];
      return_tensors_->push_back({node, remapped_id.second});
    }
  }
  return Status::OK();
}
1375 
PopulateReturnNodes()1376 Status GraphConstructor::PopulateReturnNodes() {
1377   if (opts_.return_nodes.empty()) return Status::OK();
1378   for (StringPiece name : opts_.return_nodes) {
1379     auto iter = gdef_nodes_.find(name);
1380     if (iter == gdef_nodes_.end()) {
1381       return errors::InvalidArgument("Requested return node '", name,
1382                                      "' not found in graph def");
1383     }
1384     return_nodes_->push_back(iter->second.node);
1385   }
1386   return Status::OK();
1387 }
1388 
// Reports input_map keys that could never have matched anything: keys whose
// node is absent from the GraphDef, or whose output index is out of range.
// Results go to *missing_unused_input_map_keys_ (skipped if the caller didn't
// request them). Keys that were actually used during import are ignored.
Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
  if (missing_unused_input_map_keys_ == nullptr) return Status::OK();
  for (const auto& input_map_pair : opts_.input_map) {
    TensorId key = input_map_pair.first;
    // Keys consumed by RemapNodeDefInputs() were valid; skip them.
    if (used_input_map_keys_.count(key) > 0) continue;

    auto pair = gdef_nodes_.find(key.first);
    if (pair == gdef_nodes_.end()) {
      // key's node doesn't exist in GraphDef
      missing_unused_input_map_keys_->push_back(key);
      continue;
    }

    // Check that key's index is in bounds. Get the number of outputs from the
    // NodeDef, rather than the imported Node, since the Node may not exist if
    // opts_.skip_mapped_nodes is true.
    const NodeDef& node_def = get_node_def(pair->second.gdef_index);
    const OpDef* op_def;
    TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
    int num_outputs;
    TF_RETURN_IF_ERROR(NumOutputsForNode(node_def, *op_def, &num_outputs));
    if (key.second >= num_outputs) {
      // key's index out of bounds
      missing_unused_input_map_keys_->push_back(key);
    }
  }
  return Status::OK();
}
1417 
Undo()1418 void GraphConstructor::Undo() {
1419   for (const auto& iter : gdef_nodes_) {
1420     if (iter.second.node != nullptr) {
1421       g_->RemoveNode(iter.second.node);
1422     }
1423   }
1424   g_->set_versions(original_versions_);
1425 }
1426 
MakeEdge(Node * src,int output_index,Node * dst,int input_index)1427 Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
1428                                   int input_index) {
1429   if (output_index >= src->num_outputs()) {
1430     return errors::InvalidArgument(
1431         "Output ", output_index, " of node ", src->name(),
1432         " does not exist. Node only has ", src->num_outputs(), " outputs.");
1433   }
1434   if (input_index >= dst->num_inputs()) {
1435     return errors::InvalidArgument(
1436         "Input ", input_index, " of node ", dst->name(),
1437         " does not exist. Node only has ", dst->num_inputs(), " inputs.");
1438   }
1439 
1440   DataType src_out = src->output_type(output_index);
1441   DataType dst_in = dst->input_type(input_index);
1442   if (!TypesCompatible(dst_in, src_out)) {
1443     return errors::InvalidArgument(
1444         "Input ", input_index, " of node ", dst->name(), " was passed ",
1445         DataTypeString(src_out), " from ", src->name(), ":", output_index,
1446         " incompatible with expected ", DataTypeString(dst_in), ".");
1447   }
1448   g_->AddEdge(src, output_index, dst, input_index);
1449   return Status::OK();
1450 }
1451 
1452 }  // namespace
1453 
ConvertGraphDefToGraph(const GraphConstructorOptions & opts,const GraphDef & gdef,Graph * g)1454 Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
1455                               const GraphDef& gdef, Graph* g) {
1456   ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
1457   return GraphConstructor::Construct(
1458       opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
1459       /*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
1460       /*missing_unused_input_map_keys=*/nullptr);
1461 }
1462 
ConvertGraphDefToGraph(const GraphConstructorOptions & opts,GraphDef && gdef,Graph * g)1463 Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
1464                               GraphDef&& gdef, Graph* g) {
1465   ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
1466   return GraphConstructor::Construct(opts, std::move(gdef), g, &refiner,
1467                                      /*return_tensors=*/nullptr,
1468                                      /*return_nodes=*/nullptr,
1469                                      /*missing_unused_input_map_keys=*/nullptr);
1470 }
1471 
ConvertNodeDefsToGraph(const GraphConstructorOptions & opts,gtl::ArraySlice<NodeDef> nodes,Graph * g)1472 Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
1473                               gtl::ArraySlice<NodeDef> nodes, Graph* g) {
1474   ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry());
1475   // TODO(irving): Copy will go away once NodeInfo exists
1476   std::vector<const NodeDef*> node_defs;
1477   node_defs.reserve(nodes.size());
1478   for (const auto& n : nodes) {
1479     node_defs.push_back(&n);
1480   }
1481   return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
1482                                      &refiner, /*return_tensors=*/nullptr,
1483                                      /*return_nodes=*/nullptr,
1484                                      /*missing_unused_input_map_keys=*/nullptr);
1485 }
1486 
// Imports `gdef` into an existing graph `g`, validating the options/results
// pairing up front. The order of the validation checks below is observable
// through which InvalidArgument message callers receive; keep it stable.
Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
                      Graph* g, ShapeRefiner* refiner,
                      ImportGraphDefResults* results) {
  // return_tensors requires somewhere to put the results.
  if (!opts.return_tensors.empty()) {
    if (results == nullptr) {
      return errors::InvalidArgument(
          "results argument to ImportGraphDef() must be non-null if "
          "opts.return_tensors is non-empty");
    }
  }

  if (!opts.return_nodes.empty()) {
    // skip_mapped_nodes may elide nodes, which is incompatible with
    // promising to return specific nodes.
    if (opts.skip_mapped_nodes) {
      return errors::InvalidArgument(
          "Requesting return_nodes with skip_mapped_nodes set is not currently "
          "supported");
    }
    if (results == nullptr) {
      return errors::InvalidArgument(
          "results argument to ImportGraphDef() must be non-null if "
          "opts.return_nodes is non-empty");
    }
  }

  // The results struct is an output parameter; reject pre-populated ones so
  // stale entries cannot be mistaken for this import's results.
  if (results != nullptr) {
    if (!results->return_tensors.empty() || !results->return_nodes.empty() ||
        !results->missing_unused_input_map_keys.empty()) {
      return errors::InvalidArgument(
          "All fields in results argument to ImportGraphDef() must be empty.");
    }
  }

  // Fall back to a locally-owned refiner when the caller did not supply one.
  ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
  if (refiner == nullptr) {
    refiner = &default_refiner;
  } else {
    // Log a warning if we are importing a GraphDef at an older
    // producer version after already having added non-source/sink
    // nodes to the graph in the past.
    if (gdef.versions().producer() > 0 &&
        gdef.versions().producer() < refiner->graph_def_version() &&
        g->num_nodes() > 2) {
      LOG(WARNING) << "Importing a graph with a lower producer version "
                   << gdef.versions().producer()
                   << " into an existing graph with producer version "
                   << refiner->graph_def_version() << ". Shape inference will "
                   << "have run different parts of the graph with different "
                   << "producer versions.";
    }
  }

  // Set the graph def version of the refiner as the min of the
  // current value and the version from the graph we are about to
  // import.
  //
  // Note: to match Run() semantics, we should re-run shape inference
  // on the entire graph if the producer version has changed.  For now
  // we log the warning above.
  refiner->set_graph_def_version(
      std::min(refiner->graph_def_version(), gdef.versions().producer()));

  // Run the import; wire the results struct through only when provided.
  if (results == nullptr) {
    return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
                                       &gdef.library(), g, refiner, nullptr,
                                       nullptr, nullptr);
  } else {
    return GraphConstructor::Construct(
        opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
        &results->return_tensors, &results->return_nodes,
        &results->missing_unused_input_map_keys);
  }
}
1559 
CopyGraph(const Graph & src,Graph * dest)1560 void CopyGraph(const Graph& src, Graph* dest) { dest->Copy(src); }
1561 
1562 }  // namespace tensorflow
1563