1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/graph_constructor.h"
17
18 #include <algorithm>
19 #include <set>
20 #include <string>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_cat.h"
28 #include "tensorflow/core/common_runtime/shape_refiner.h"
29 #include "tensorflow/core/framework/function.h"
30 #include "tensorflow/core/framework/function.pb.h"
31 #include "tensorflow/core/framework/graph.pb.h"
32 #include "tensorflow/core/framework/node_def.pb.h"
33 #include "tensorflow/core/framework/node_def_util.h"
34 #include "tensorflow/core/framework/tensor_shape.pb.h"
35 #include "tensorflow/core/framework/types.h"
36 #include "tensorflow/core/framework/versions.h"
37 #include "tensorflow/core/framework/versions.pb.h"
38 #include "tensorflow/core/graph/algorithm.h"
39 #include "tensorflow/core/graph/graph.h"
40 #include "tensorflow/core/graph/tensor_id.h"
41 #include "tensorflow/core/lib/core/errors.h"
42 #include "tensorflow/core/lib/gtl/flatmap.h"
43 #include "tensorflow/core/lib/gtl/flatset.h"
44 #include "tensorflow/core/lib/gtl/inlined_vector.h"
45 #include "tensorflow/core/lib/strings/scanner.h"
46 #include "tensorflow/core/lib/strings/str_util.h"
47 #include "tensorflow/core/platform/errors.h"
48 #include "tensorflow/core/platform/logging.h"
49 #include "tensorflow/core/platform/macros.h"
50 #include "tensorflow/core/public/version.h"
51
52 namespace tensorflow {
53
54 namespace {
55
// We remove duplicate control inputs before adding edges to the Graph, so we
// can skip the expensive duplicates check in 'AddControlEdge'.
// Note: `static` and `const` are redundant here (anonymous namespace already
// gives internal linkage, and `constexpr` implies `const`).
constexpr bool kDoNotCheckDuplicates = true;
59
IsMerge(const NodeDef & node_def)60 inline bool IsMerge(const NodeDef& node_def) {
61 return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
62 node_def.op() == "_XlaMerge";
63 }
64
IsNextIteration(const NodeDef & node_def)65 inline bool IsNextIteration(const NodeDef& node_def) {
66 return node_def.op() == "NextIteration" ||
67 node_def.op() == "RefNextIteration";
68 }
69
IsValidNodeName(StringPiece s,bool allow_internal_ops)70 bool IsValidNodeName(StringPiece s, bool allow_internal_ops) {
71 using ::tensorflow::strings::Scanner;
72 Scanner scanner(s);
73 scanner
74 .One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE
75 : Scanner::LETTER_DIGIT_DOT)
76 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
77
78 while (true) {
79 if (!scanner.GetResult()) // Some error in previous iteration.
80 return false;
81 if (scanner.empty()) // No error, but nothing left, good.
82 return true;
83
84 // Absorb another piece, starting with a '>'
85 scanner.One(Scanner::RANGLE)
86 .One(Scanner::LETTER_DIGIT_DOT)
87 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
88 }
89 }
90
// Converts a sequence of NodeDefs into a Graph, driving the pipeline in
// TryImport(): name-collision checks, input-map/control-dependency
// validation, node indexing, topological conversion, back-edge insertion,
// and result population. Subclasses supply access to the underlying NodeDef
// storage (see NodeDefCopyingGraphConstructor and
// NodeDefMovingGraphConstructor below).
class GraphConstructor {
 public:
  struct Options {
    // Options used by ConvertGraphDefToGraph-style construction
    // (importing == false).
    Options(const GraphConstructorOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(in.allow_internal_ops),
          expect_device_spec(in.expect_device_spec),
          importing(false),
          validate_nodes(in.validate_nodes),
          validate_colocation_constraints(false),
          add_default_attributes(in.add_default_attributes) {}
    // Options used by ImportGraphDef-style construction (importing == true).
    // A non-empty prefix is normalized here so that it always ends in '/'.
    Options(const ImportGraphDefOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(false),
          expect_device_spec(false),
          prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/")
                     ? in.prefix
                     : in.prefix + "/"),
          uniquify_names(in.uniquify_names),
          uniquify_prefix(in.uniquify_prefix),
          input_map(in.input_map.begin(), in.input_map.end()),
          skip_mapped_nodes(in.skip_mapped_nodes),
          control_dependencies(in.control_dependencies),
          return_tensors(in.return_tensors.begin(), in.return_tensors.end()),
          return_nodes(in.return_nodes),
          importing(true),
          validate_nodes(true),
          validate_colocation_constraints(in.validate_colocation_constraints),
          validate_shape(in.validate_shape),
          default_device(in.default_device) {}

    bool allow_internal_ops;
    bool expect_device_spec;

    // Always either empty or '/'-terminated (see the constructor above).
    string prefix;
    bool uniquify_names;
    bool uniquify_prefix;
    std::map<TensorId, TensorId> input_map;
    bool skip_mapped_nodes;
    std::vector<string> control_dependencies;
    std::vector<TensorId> return_tensors;
    std::vector<string> return_nodes;

    // TODO(ashankar): This bool exists to separate out functionality required
    // to make ImportGraphDef a close equivalent of Python's import_graph_def
    // without affecting the behavior of ConvertGraphDefToGraph at the time
    // ImportGraphDef was added.
    //
    // That said, the functionality here (shape and op validation) seems
    // applicable to ConvertGraphDefToGraph as well, so make an attempt to
    // remove this.
    bool importing;
    // If true, validates that nodes being converted have all expected attrs
    // set and no unknown attrs set by calling ValidateNodeDef().
    // `validate_nodes` is always true when `importing` is set.
    bool validate_nodes;
    bool validate_colocation_constraints;
    bool validate_shape = true;

    // If true, GraphConstructor will add attributes with their default
    // value to the Node when they are missing from the NodeDef.
    bool add_default_attributes = true;

    string default_device;
  };

  typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;

  // versions and library may be nullptr
  static Status Construct(
      const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
      const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
      std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys);

  static Status Construct(
      const Options& opts, GraphDef&& graph_def, Graph* g,
      ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys);

 protected:
  GraphConstructor(const Options& opts, Graph* g, ShapeRefiner* refiner,
                   std::vector<std::pair<Node*, int>>* return_tensors,
                   std::vector<Node*>* return_nodes,
                   std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : opts_(opts),
        g_(g),
        original_versions_(g->versions()),
        prefix_(opts.prefix),
        refiner_(refiner),
        return_tensors_(return_tensors),
        return_nodes_(return_nodes),
        missing_unused_input_map_keys_(missing_unused_input_map_keys) {}

  virtual ~GraphConstructor() {}

  // Runs the full conversion pipeline. On error, the caller is expected to
  // invoke Undo() to roll `g_` back to its pre-import state.
  Status TryImport() {
    TF_RETURN_IF_ERROR(EnsureNoNameCollisions());
    TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies());
    TF_RETURN_IF_ERROR(BuildNodeIndex());
    TF_RETURN_IF_ERROR(InitFromEdges());

    // NOTE: Convert() invokes `consume_node_def()` on each node in the input
    // graph, so `get_node_def()` is no longer usable once it is called.
    TF_RETURN_IF_ERROR(Convert());

    TF_RETURN_IF_ERROR(AddBackEdges());
    TF_RETURN_IF_ERROR(UpdateVersionDef());
    TF_RETURN_IF_ERROR(PopulateReturnTensors());
    TF_RETURN_IF_ERROR(PopulateReturnNodes());
    TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys());
    UpdateUniquifiedColocationNames();
    FixupSourceAndSinkEdges(g_);
    return Status::OK();
  }

 private:
  Status EnsureNoNameCollisions();
  Status ValidateInputMapAndControlDependencies();
  Status BuildNodeIndex();
  Status InitFromEdges();
  Status Convert();
  Status AddBackEdges();
  Status UpdateVersionDef();
  Status PopulateReturnTensors();
  Status PopulateReturnNodes();
  Status PopulateMissingUnusedInputMapKeys();

  void Undo();

  // Prints cycles in the graph.
  void PrintCycles();
  // Performs DFS starting at `cur_node` and prints any cycles found.
  void DFS(int cur_node, std::vector<int>* cur_branch,
           std::vector<bool>* is_on_cur_branch,
           absl::flat_hash_set<int>* unvisited);
  Status IsNodeFullyMapped(const NodeDef& node_def, bool* is_node_mapped);
  Status ValidateColocationConstraints(const NodeDef& node_def);
  Status MakeNode(NodeDef&& node_def, Node** node);
  Status MakeEdge(Node* src, int output_index, Node* dst, int input_index);
  Status ValidateShape(Node* node);
  Status ModifyNodeDefForImport(NodeDef* node_def);
  // Modifies node_def's inputs according to opts_.input_map.
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function will mark inputs that are remapped to
  // true.
  void RemapNodeDefInputs(NodeDef* node_def,
                          std::vector<bool>* input_already_exists);
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function will add and mark control inputs as
  // true.
  void AddControlDependencies(NodeDef* node_def,
                              std::vector<bool>* input_already_exists);
  void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists,
                          NodeDef* node_def);

  // Modifies `node_def` if its name isn't unique, or if any of its inputs'
  // names have been uniquified. This must be called in topological order on all
  // nodes.
  void UniquifyNames(const std::vector<bool>& input_already_exists,
                     NodeDef* node_def);

  // Updates any constructed nodes' colocation group names if the name has been
  // updated by UniquifyNames. This is called after all the nodes have been
  // constructed so all the names have been uniquified if necessary.
  void UpdateUniquifiedColocationNames();

  // Returns true if `name` already exists in `g_` (either as a node name or
  // prefix).
  bool NameExistsInGraph(StringPiece name);

  // Returns true if `name` already exists in the GraphDef being imported
  // (either as a node name or prefix).
  bool NameExistsInGraphDef(StringPiece name);

  // Returns a unique version of `original_name`, or `original_name` if it's
  // already unique in the graph.
  string FindUniqueName(StringPiece original_name);

  // Decrement pending count for users of `processed` and add the ones that now
  // have all of their pending inputs satisfied to `ready_`.
  void UpdatePendingCountAndReady(int processed, bool is_next_iteration);

  // Subclasses override the following virtual methods to provide efficient
  // access to the original protocol buffer-based graph.

  // Returns the number of nodes in the graph.
  virtual size_t node_def_count() const = 0;
  // Returns the i^th node in the graph. Must not be called after
  // consume_node_def(i).
  virtual const NodeDef& get_node_def(int i) const = 0;
  // Destructively reads the i^th node in the graph, avoiding a copy if
  // possible. After calling this method, the result of get_node_def(i) is
  // undefined.
  virtual NodeDef consume_node_def(int i) = 0;
  // Returns the version information for the graph, or nullptr if none is
  // available.
  virtual const VersionDef* versions() const = 0;
  // Returns the function information for the graph, or nullptr if none is
  // available.
  virtual const FunctionDefLibrary* library() const = 0;

  // From constructor
  const Options opts_;
  Graph* g_;
  // Snapshot of g_'s versions at construction time; used to restore them in
  // Undo().
  const VersionDef original_versions_;

  // A copy of opts_.prefix, possibly uniquified.
  string prefix_;

  ShapeRefiner* refiner_;

  // May be null. Not owned.
  std::vector<std::pair<Node*, int>>* return_tensors_;

  // May be null. Not owned.
  std::vector<Node*>* return_nodes_;

  // May be null. Not owned.
  std::vector<SafeTensorId>* missing_unused_input_map_keys_;

  // Intermediate datastructure used to populate
  // `missing_unused_input_map_keys_`.
  std::set<TensorId> used_input_map_keys_;

  // Intermediate datastructure used to track the destinations of back edges.
  absl::flat_hash_set<int> merge_node_indices_;

  // Mapping from node name to the index within node_defs_.
  struct NodeInfo {
    explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
    // Containers require that we have a default constructor.
    NodeInfo() : NodeInfo(-1) {}
    int gdef_index;
    Node* node;  // nullptr until the NodeDef is converted to a Node.
  };
  gtl::FlatMap<StringPiece, NodeInfo, StringPieceHasher> gdef_nodes_;

  // Storage for StringPiece keys in gdef_nodes_. Typically, the StringPiece key
  // will refer to the string stored in `NodeDef::name()`. This intern table is
  // only used when the original NodeDef's name is changed.
  std::vector<string> string_intern_table_;

  // Prefixes already used in the GraphDef being imported.
  gtl::FlatSet<StringPiece, StringPieceHasher> gdef_prefixes_;

  // Mapping from node name to the existing node in g_.
  gtl::FlatMap<StringPiece, Node*, StringPieceHasher> existing_nodes_;

  // Prefixes already used in the graph.
  gtl::FlatSet<StringPiece, StringPieceHasher> existing_prefixes_;

  // Imported node names that have been uniquified. The key is the original
  // name, the value is the new unique name.
  gtl::FlatMap<string, string> uniquified_names_;

  // Index of NodeDefs in node_defs_ with all inputs already converted. We use a
  // (sorted) set so nodes are created in the order defined in the GraphDef.
  std::set<int> ready_;

  // Mapping between index within node_defs_ and the number of inputs that
  // still need to be converted.
  std::vector<int> pending_count_;

  // Mapping between index within node_defs_ and the index within node_defs_ of
  // all nodes it outputs to.
  std::vector<gtl::InlinedVector<int, 4>> outputs_;

  // Used in the conversion from node_defs_ to g_ to represent the ith input
  // of a node.
  struct InputInfo {
    explicit InputInfo(const string& node_name, Node* n, int i)
        : name(node_name), node(n), index(i) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string name;
    Node* node;
    int index;

    static bool IsControlInput(const InputInfo& input) {
      return input.index == Graph::kControlSlot;
    }
    static int CompareName(const InputInfo& lhs, const InputInfo& rhs) {
      return lhs.name < rhs.name;
    }
    static bool IsSameName(const InputInfo& lhs, const InputInfo& rhs) {
      return lhs.name == rhs.name;
    }
  };

  // Used in the conversion from node_defs_ to g_ to represent an edge from
  // the node named 'name' to node 'n'.
  struct EdgeInfo {
    explicit EdgeInfo(const string& name, int i1, Node* n, int i2)
        : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string src_name;
    int src_index;
    Node* dst_node;
    int dst_index;
  };
  std::vector<EdgeInfo> back_edges_;

  TF_DISALLOW_COPY_AND_ASSIGN(GraphConstructor);
};
395
396 // Implementation of GraphConstructor that does not take ownership of the
397 // input NodeDef messages and thus copies the nodes into the constructed Graph*.
398 //
399 // NOTE(mrry): Whenever possible, use NodeDefMovingGraphConstructor, which
400 // avoids copying each NodeDef into the constructed Graph*.
401 class NodeDefCopyingGraphConstructor : public GraphConstructor {
402 public:
NodeDefCopyingGraphConstructor(const Options & opts,NodeDefSlice node_defs,const VersionDef * versions,const FunctionDefLibrary * library,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)403 NodeDefCopyingGraphConstructor(
404 const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
405 const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
406 std::vector<std::pair<Node*, int>>* return_tensors,
407 std::vector<Node*>* return_nodes,
408 std::vector<SafeTensorId>* missing_unused_input_map_keys)
409 : GraphConstructor(opts, g, refiner, return_tensors, return_nodes,
410 missing_unused_input_map_keys),
411 node_defs_(node_defs),
412 versions_(versions),
413 library_(library) {}
414
415 private:
node_def_count() const416 size_t node_def_count() const override { return node_defs_.size(); }
get_node_def(int i) const417 const NodeDef& get_node_def(int i) const override { return *node_defs_[i]; }
consume_node_def(int i)418 NodeDef consume_node_def(int i) override { return *node_defs_[i]; }
versions() const419 const VersionDef* versions() const override { return versions_; }
library() const420 const FunctionDefLibrary* library() const override { return library_; }
421
422 const NodeDefSlice node_defs_;
423 const VersionDef* const versions_;
424 const FunctionDefLibrary* const library_;
425 };
426
427 // Implementation of GraphConstructor that takes ownership of the input
428 // GraphDef, and can perform destructive reads.
429 class NodeDefMovingGraphConstructor : public GraphConstructor {
430 public:
NodeDefMovingGraphConstructor(const Options & opts,GraphDef && graph_def,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)431 NodeDefMovingGraphConstructor(
432 const Options& opts, GraphDef&& graph_def, Graph* g,
433 ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors,
434 std::vector<Node*>* return_nodes,
435 std::vector<SafeTensorId>* missing_unused_input_map_keys)
436 : GraphConstructor(opts, g, refiner, return_tensors, return_nodes,
437 missing_unused_input_map_keys),
438 graph_def_(std::move(graph_def)),
439 is_consumed_(graph_def_.node_size(), false) {}
440
441 private:
node_def_count() const442 size_t node_def_count() const override { return graph_def_.node().size(); }
get_node_def(int i) const443 const NodeDef& get_node_def(int i) const override {
444 CHECK(!is_consumed_[i])
445 << "NodeDef " << i << " accessed after it was consumed.";
446 return graph_def_.node(i);
447 }
consume_node_def(int i)448 NodeDef consume_node_def(int i) override {
449 CHECK(!is_consumed_[i]) << "NodeDef " << i << " consumed twice.";
450 is_consumed_[i] = true;
451 return std::move(*graph_def_.mutable_node(i));
452 }
versions() const453 const VersionDef* versions() const override { return &graph_def_.versions(); }
library() const454 const FunctionDefLibrary* library() const override {
455 return &graph_def_.library();
456 }
457
458 GraphDef graph_def_;
459 std::vector<bool> is_consumed_;
460 };
461
ForwardCompatibilityWindowPassed(const VersionDef & versions)462 bool ForwardCompatibilityWindowPassed(const VersionDef& versions) {
463 // TF_GRAPH_DEF_VERSION is incremented daily.
464 // TF has a 3 week forward compatibility guarantee.
465 return (versions.producer() - TF_GRAPH_DEF_VERSION) > 21;
466 }
467
// Returns `import_status` unchanged, unless the GraphDef's producer version
// is beyond this binary's forward-compatibility window — in that case the
// status message is augmented with an explanation of the version skew.
Status MaybeAppendVersionWarning(const VersionDef* versions,
                                 const Status& import_status) {
  if (versions == nullptr || !ForwardCompatibilityWindowPassed(*versions)) {
    return import_status;
  }
  return Status(
      import_status.code(),
      absl::StrCat(
          "Converting GraphDef to Graph has failed with an error: '",
          import_status.error_message(),
          "' The binary trying to import the GraphDef was built when "
          "GraphDef version was ",
          TF_GRAPH_DEF_VERSION,
          ". The GraphDef was produced by a binary built when GraphDef "
          "version was ",
          versions->producer(),
          ". The difference between these versions is larger than "
          "TensorFlow's forward compatibility guarantee, and might be the "
          "root cause for failing to import the GraphDef."));
}
488
Construct(const Options & opts,NodeDefSlice node_defs,const VersionDef * versions,const FunctionDefLibrary * library,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)489 /* static */ Status GraphConstructor::Construct(
490 const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
491 const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
492 std::vector<std::pair<Node*, int>>* return_tensors,
493 std::vector<Node*>* return_nodes,
494 std::vector<SafeTensorId>* missing_unused_input_map_keys) {
495 if (versions) {
496 TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
497 TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
498 "GraphDef", "graph"));
499 }
500 NodeDefCopyingGraphConstructor c(opts, node_defs, versions, library, g,
501 refiner, return_tensors, return_nodes,
502 missing_unused_input_map_keys);
503 Status s = c.TryImport();
504 if (!s.ok()) {
505 c.Undo();
506 s = MaybeAppendVersionWarning(versions, s);
507 }
508 return s;
509 }
510
Construct(const Options & opts,GraphDef && graph_def,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)511 /* static */ Status GraphConstructor::Construct(
512 const Options& opts, GraphDef&& graph_def, Graph* g, ShapeRefiner* refiner,
513 std::vector<std::pair<Node*, int>>* return_tensors,
514 std::vector<Node*>* return_nodes,
515 std::vector<SafeTensorId>* missing_unused_input_map_keys) {
516 TF_RETURN_IF_ERROR(CheckVersions(graph_def.versions(), TF_GRAPH_DEF_VERSION,
517 TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
518 "GraphDef", "graph"));
519 VersionDef version_def = graph_def.versions();
520 NodeDefMovingGraphConstructor c(opts, std::move(graph_def), g, refiner,
521 return_tensors, return_nodes,
522 missing_unused_input_map_keys);
523 Status s = c.TryImport();
524 if (!s.ok()) {
525 c.Undo();
526 s = MaybeAppendVersionWarning(&version_def, s);
527 }
528 return s;
529 }
530
UpdatePendingCountAndReady(int processed,bool is_next_iteration)531 void GraphConstructor::UpdatePendingCountAndReady(int processed,
532 bool is_next_iteration) {
533 for (size_t i = 0; i < outputs_[processed].size(); ++i) {
534 const int output = outputs_[processed][i];
535 // We didn't consider NextIteration->Merge edges when computing
536 // pending_counts_ so we should not have to consider it here either.
537 bool is_next_iteration_to_merge_edge =
538 is_next_iteration && merge_node_indices_.count(output) == 1;
539 if (!is_next_iteration_to_merge_edge) {
540 int* current_pending_count = &pending_count_[output];
541 CHECK_GT(*current_pending_count, 0);
542 (*current_pending_count)--;
543 if (*current_pending_count == 0) {
544 ready_.insert(output);
545 }
546 }
547 }
548 }
549
550 // This could be expensive but we don't expect to call it often, if at all (only
551 // if there are multiple nodes in g_ with the same name)
NodeNameInValues(const std::map<TensorId,TensorId> & input_map,const StringPiece & node_name)552 bool NodeNameInValues(const std::map<TensorId, TensorId>& input_map,
553 const StringPiece& node_name) {
554 for (auto iter = input_map.begin(); iter != input_map.end(); ++iter) {
555 if (iter->second.first == node_name) return true;
556 }
557 return false;
558 }
559
NodeNameInValues(const std::vector<string> & control_dependencies,const StringPiece & node_name)560 bool NodeNameInValues(const std::vector<string>& control_dependencies,
561 const StringPiece& node_name) {
562 return std::find(control_dependencies.begin(), control_dependencies.end(),
563 node_name) != control_dependencies.end();
564 }
565
566 // Adds any prefixes of `node_name` (not including the full name itself) to
567 // `prefixes`.
AddPrefixes(StringPiece node_name,gtl::FlatSet<StringPiece,StringPieceHasher> * prefixes)568 void AddPrefixes(StringPiece node_name,
569 gtl::FlatSet<StringPiece, StringPieceHasher>* prefixes) {
570 size_t idx = -1;
571 while ((idx = node_name.find('/', idx + 1)) != StringPiece::npos) {
572 prefixes->insert(node_name.substr(0, idx));
573 }
574 }
575
// Indexes the names/prefixes already present in `g_` and verifies the import
// will not collide with them. May also uniquify `prefix_` when
// opts_.uniquify_prefix is set.
Status GraphConstructor::EnsureNoNameCollisions() {
  existing_nodes_.reserve(g_->num_nodes());
  // Populate existing_nodes_ and existing_prefixes_.
  for (Node* n : g_->nodes()) {
    bool already_exists = !existing_nodes_.insert({n->name(), n}).second;
    if (already_exists) {
      // Duplicate names in g_ are tolerated unless input_map or
      // control_dependencies refer to the ambiguous name.
      if (NodeNameInValues(opts_.input_map, n->name())) {
        return errors::InvalidArgument(
            "cannot resolve input_map because multiple nodes exist with name '",
            n->name(), "'");
      }
      if (NodeNameInValues(opts_.control_dependencies, n->name())) {
        return errors::InvalidArgument(
            "cannot resolve control_dependencies because multiple nodes exist "
            "with name '",
            n->name(), "'");
      }
    }
    AddPrefixes(n->name(), &existing_prefixes_);
  }
  if (prefix_.empty() && opts_.importing && !opts_.uniquify_names) {
    // No prefix and no renaming: every imported name must be brand new.
    for (size_t i = 0; i < node_def_count(); ++i) {
      const string& name = get_node_def(i).name();
      if (NameExistsInGraph(name)) {
        return errors::InvalidArgument("Node name '", name,
                                       "' already exists in the Graph");
      }
    }
  } else if (!prefix_.empty()) {
    // prefix_ is always '/'-terminated (see Options); validate it without the
    // trailing slash.
    StringPiece prefix_no_slash(prefix_);
    prefix_no_slash.remove_suffix(1);
    if (!IsValidNodeName(prefix_no_slash, false)) {
      return errors::InvalidArgument("Imported node name prefix '", prefix_,
                                     "' would lead to invalid node names");
    }
    if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) {
      prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
    }
  }
  return Status::OK();
}
617
ValidateInputMapAndControlDependencies()618 Status GraphConstructor::ValidateInputMapAndControlDependencies() {
619 for (const auto& mapping : opts_.input_map) {
620 TensorId src = mapping.first;
621 TensorId dst = mapping.second;
622 if (existing_nodes_.count(dst.first) == 0) {
623 return errors::InvalidArgument(
624 "node '", dst.first, "' in input_map does not exist in graph ",
625 "(input_map entry: ", src.ToString(), "->", dst.ToString(), ")");
626 }
627 if ((src.second == Graph::kControlSlot) !=
628 (dst.second == Graph::kControlSlot)) {
629 return errors::InvalidArgument("input_map entry ", src.ToString(), "->",
630 dst.ToString(), " between ",
631 "control edge and non-control edge");
632 }
633 }
634 for (const string& node : opts_.control_dependencies) {
635 if (existing_nodes_.count(node) == 0) {
636 return errors::InvalidArgument(
637 "node '", node,
638 "' in control_dependencies does not exist in "
639 "graph");
640 }
641 }
642 return Status::OK();
643 }
644
BuildNodeIndex()645 Status GraphConstructor::BuildNodeIndex() {
646 // Validate the node names and add them to gdef_nodes_ and gdef_prefixes_.
647 for (int n = 0; n < node_def_count(); ++n) {
648 const NodeDef& node_def = get_node_def(n);
649 if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) {
650 return errors::InvalidArgument(
651 "Node '", node_def.name(),
652 "': Node name contains invalid characters");
653 }
654 if (!gdef_nodes_
655 .insert(std::make_pair(StringPiece(node_def.name()), NodeInfo(n)))
656 .second) {
657 return errors::InvalidArgument("Node '", node_def.name(),
658 "' is not unique");
659 }
660 // Validate the operation's type.
661 if (node_def.op().empty()) {
662 return errors::InvalidArgument("Node '", node_def.name(),
663 "' does not specify an operation");
664 }
665 if (opts_.expect_device_spec && node_def.device().empty()) {
666 return errors::InvalidArgument("Node '", node_def.name(),
667 "' is missing a device specification");
668 }
669 if (IsMerge(node_def)) {
670 merge_node_indices_.insert(n);
671 }
672 // Validate control edges at end
673 bool in_control_dependence = false;
674 for (int i = 0; i < node_def.input_size(); ++i) {
675 StringPiece input_name = node_def.input(i);
676 if (!input_name.empty() && absl::StartsWith(input_name, "^")) {
677 in_control_dependence = true;
678 } else if (in_control_dependence) {
679 return errors::InvalidArgument(
680 "Node '", node_def.name(),
681 "': Control dependencies must come after regular dependencies");
682 }
683 }
684 // Update gdef_prefixes_.
685 AddPrefixes(node_def.name(), &gdef_prefixes_);
686 }
687 return Status::OK();
688 }
689
// Builds the dependency bookkeeping used by Convert(): for each node, the
// number of inputs that must be converted first (pending_count_), the reverse
// adjacency list of its consumers (outputs_), and the initial set of nodes
// with no unconverted inputs (ready_). Merge nodes fed by a NextIteration
// node get special treatment so while-loop cycles do not deadlock the
// topological conversion.
Status GraphConstructor::InitFromEdges() {
  const int num_nodes = node_def_count();
  pending_count_.reserve(num_nodes);
  outputs_.resize(num_nodes);
  // First pass: collect the names of all NextIteration nodes so back edges
  // can be recognized below.
  gtl::FlatSet<string> next_iteration_nodes;
  for (int n = 0; n < node_def_count(); ++n) {
    const NodeDef& node_def = get_node_def(n);
    if (IsNextIteration(node_def)) {
      next_iteration_nodes.insert(node_def.name());
    }
  }

  // Parse the inputs for each node.
  for (int n = 0; n < num_nodes; ++n) {
    const NodeDef& node_def = get_node_def(n);
    int pending_count = node_def.input_size();
    if (IsMerge(node_def)) {
      // Cycles in the graph are only allowed for while loops. A while loop is
      // identified by an edge from a NextIteration node to a Merge node. For
      // such Merge nodes, only wait for one non-control input before
      // considering the node ready to process in Convert().
      int32 num_control_edges = 0;
      bool has_loop_back_edge = false;
      for (int i = 0; i < node_def.input_size(); ++i) {
        StringPiece input_name(node_def.input(i));
        if (absl::StartsWith(input_name, "^")) {
          num_control_edges++;
        } else {
          TensorId id(ParseTensorName(input_name));
          if (next_iteration_nodes.find(string(id.first)) !=
              next_iteration_nodes.end()) {
            has_loop_back_edge = true;
          }
        }
      }
      if (has_loop_back_edge) {
        // All control inputs plus exactly one data input must be ready.
        pending_count = num_control_edges + 1;
      }
    }
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      TensorId id(ParseTensorName(input_name));
      if (opts_.input_map.count(id) == 0) {
        // If an input is not mapped, then the input should appear in the graph
        // being imported.
        auto iter = gdef_nodes_.find(id.first);
        if (iter == gdef_nodes_.end()) {
          return errors::InvalidArgument("Node '", node_def.name(),
                                         "': Unknown input node '",
                                         node_def.input(i), "'");
        }
        outputs_[iter->second.gdef_index].push_back(n);
      } else {
        // This input is mapped to an existing edge. Therefore this input is
        // as good as being already processed.
        --pending_count;
        DCHECK_GE(pending_count, 0);
      }
    }
    if (pending_count == 0) {
      ready_.insert(n);
    }
    pending_count_.push_back(pending_count);
  }
  return Status::OK();
}
756
ValidateColocationConstraints(const NodeDef & node_def)757 Status GraphConstructor::ValidateColocationConstraints(
758 const NodeDef& node_def) {
759 if (!opts_.validate_colocation_constraints || !opts_.importing)
760 return Status::OK();
761 const auto iter = node_def.attr().find(kColocationAttrName);
762 if (iter == node_def.attr().end()) return Status::OK();
763 for (const string& c : iter->second.list().s()) {
764 StringPiece s(c);
765 if (absl::ConsumePrefix(&s, kColocationGroupPrefix) &&
766 gdef_nodes_.find(s) == gdef_nodes_.end()) {
767 return errors::InvalidArgument(
768 "Node '", node_def.name(),
769 "' expects to be colocated with unknown node '", s, "'");
770 }
771 }
772 return Status::OK();
773 }
774
MakeNode(NodeDef && node_def,Node ** node)775 Status GraphConstructor::MakeNode(NodeDef&& node_def, Node** node) {
776 // Add the node to the graph.
777 Status status;
778 *node = g_->AddNode(std::move(node_def), &status);
779 if (!status.ok()) return status;
780 if (opts_.expect_device_spec) {
781 (*node)->set_assigned_device_name((*node)->def().device());
782 }
783 return Status::OK();
784 }
785
ValidateShape(Node * node)786 Status GraphConstructor::ValidateShape(Node* node) {
787 if (!opts_.importing || !opts_.validate_shape) return Status::OK();
788 TF_RETURN_IF_ERROR(refiner_->AddNode(node));
789 // For nodes with the _output_shapes attribute, override the shape.
790 std::vector<const TensorShapeProto*> shape_attrs;
791 const char* kAttrName = "_output_shapes";
792 if (!TryGetNodeAttr(node->attrs(), kAttrName, &shape_attrs)) {
793 // No _output_shapes attribute, the AddNode call above was sufficient.
794 return Status::OK();
795 }
796 auto* ic = refiner_->GetContext(node);
797 DCHECK(ic != nullptr)
798 << "ShapeRefiner::AddNode() should have created the InferenceContext";
799 if (shape_attrs.size() < node->num_outputs()) {
800 return errors::InvalidArgument(
801 "Node '", node->name(), "' has ", node->num_outputs(),
802 " outputs but the ", kAttrName, " attribute specifies shapes for ",
803 shape_attrs.size(), " outputs");
804 }
805 // NOTE(skyewm): we don't raise an error here because some users depend on
806 // this behavior, even though it's unsafe.
807 // TODO(b/74619486): raise an error.
808 if (shape_attrs.size() > node->num_outputs()) {
809 LOG(WARNING) << "Node '" << node->name() << "' has " << node->num_outputs()
810 << " outputs but the " << kAttrName
811 << " attribute specifies shapes for " << shape_attrs.size()
812 << " outputs. Output shapes may be inaccurate.";
813 }
814 for (int i = 0; i < node->num_outputs(); ++i) {
815 const TensorShapeProto& p = *shape_attrs[i];
816 shape_inference::ShapeHandle h;
817 Status s = ic->MakeShapeFromShapeProto(p, &h);
818 if (!s.ok()) {
819 return errors::InvalidArgument("Node '", node->name(), " has an invalid ",
820 kAttrName, " attribute (shape #", i,
821 " error:'", s.error_message(), "'");
822 }
823 s = refiner_->SetShape(node, i, h);
824 if (!s.ok()) {
825 return errors::InvalidArgument(
826 "Node '", node->name(), "' has an ", kAttrName,
827 " attribute inconsistent with the GraphDef for output #", i, ": ",
828 s.error_message());
829 }
830 }
831 node->ClearAttr(kAttrName);
832 return Status::OK();
833 }
834
ModifyNodeDefForImport(NodeDef * node_def)835 Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) {
836 const OpDef* op_def;
837 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
838 AddDefaultsToNodeDef(*op_def, node_def);
839 TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def));
840 if (versions()) {
841 TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions()->producer()));
842 }
843 return Status::OK();
844 }
845
RemoveInputs(const std::vector<int> & inputs_to_remove,NodeDef * node_def,std::vector<bool> * input_already_exists)846 void RemoveInputs(const std::vector<int>& inputs_to_remove, NodeDef* node_def,
847 std::vector<bool>* input_already_exists) {
848 // Remove 'inputs_to_remove' from 'node_def'
849 NodeDef copy;
850 copy.mutable_input()->Reserve(node_def->input_size() -
851 inputs_to_remove.size());
852 for (int i = 0, j = 0; i < node_def->input_size(); ++i) {
853 if (j < inputs_to_remove.size() && i == inputs_to_remove[j]) {
854 ++j;
855 } else {
856 copy.add_input()->swap(*node_def->mutable_input(i));
857 }
858 }
859 node_def->mutable_input()->Swap(copy.mutable_input());
860 // Remove 'inputs_to_remove' from 'input_already_exists'
861 for (int idx : inputs_to_remove) {
862 input_already_exists->erase(input_already_exists->begin() + idx);
863 }
864 DCHECK_EQ(input_already_exists->size(), node_def->input_size());
865 }
866
RemapNodeDefInputs(NodeDef * node_def,std::vector<bool> * input_already_exists)867 void GraphConstructor::RemapNodeDefInputs(
868 NodeDef* node_def, std::vector<bool>* input_already_exists) {
869 DCHECK_EQ(input_already_exists->size(), node_def->input_size());
870 std::set<TensorId> control_inputs;
871 std::vector<int> inputs_to_remove;
872
873 for (int i = 0; i < node_def->input_size(); ++i) {
874 auto iter = opts_.input_map.find(ParseTensorName(node_def->input(i)));
875 if (iter == opts_.input_map.end()) continue;
876 used_input_map_keys_.insert(iter->first);
877
878 TensorId new_input = iter->second;
879 if (new_input.second == Graph::kControlSlot) {
880 // Check if we've already remapped a different input to new_input, and if
881 // so remove this input.
882 if (control_inputs.count(new_input) > 0) {
883 inputs_to_remove.push_back(i);
884 continue;
885 }
886 control_inputs.insert(new_input);
887 }
888 node_def->set_input(i, new_input.ToString());
889 (*input_already_exists)[i] = true;
890 }
891 if (!inputs_to_remove.empty()) {
892 RemoveInputs(inputs_to_remove, node_def, input_already_exists);
893 }
894 }
895
// Appends the control dependencies in opts_.control_dependencies to
// `node_def`'s inputs (and extends `input_already_exists` accordingly),
// unless the node already has at least one non-backedge input that is itself
// being imported — in which case it will inherit the dependencies through
// that input and adding them here would be redundant.
void GraphConstructor::AddControlDependencies(
    NodeDef* node_def, std::vector<bool>* input_already_exists) {
  // To avoid adding redundant control dependencies to every imported node, skip
  // nodes that will inherit the dependencies from another imported node.
  bool inherits_deps = false;
  for (int i = 0; i < node_def->input_size(); ++i) {
    // Assume we won't inherit dependencies from remapped inputs that already
    // exist in the graph. Even if we're wrong, we'll only add redundant
    // dependencies.
    if ((*input_already_exists)[i]) continue;

    // If this input is a backedge, assume we won't inherit the dependencies.
    // TODO(skyewm): we have many redundant ParseTensorName calls. It could be
    // worth optimizing these.
    TensorId id(ParseTensorName(node_def->input(i)));
    auto iter = gdef_nodes_.find(id.first);
    DCHECK(iter != gdef_nodes_.end()) << id.first;
    if (iter->second.node == nullptr) {
      // Input hasn't been created yet, indicating it's a backedge.
      continue;
    }
    // This input is a newly-imported, already-created node: the dependencies
    // were (or will be) attached to it, so this node inherits them.
    inherits_deps = true;
  }
  if (inherits_deps) return;

  // node_def either has no inputs or all remapped inputs, add the control
  // dependencies
  for (const string& control_dep : opts_.control_dependencies) {
    string input = TensorId(control_dep, Graph::kControlSlot).ToString();
    bool found = false;
    // Scan backwards from the end, where control inputs live, to see whether
    // this dependency is already present.
    for (int i = node_def->input_size() - 1; i >= 0; --i) {
      const string& node_input = node_def->input(i);
      if (node_input[0] != '^') {
        // Control inputs are at the end. Break when we reach the non-control
        // inputs.
        break;
      }
      if (node_input == input) {
        // Control dependency already exists
        found = true;
        break;
      }
    }
    if (found) {
      continue;
    }
    node_def->add_input(input);
    // The dependency target is a preexisting node in g_, so mark it as such.
    input_already_exists->push_back(true);
  }
}
946
AddPrefixToNodeDef(const std::vector<bool> & input_already_exists,NodeDef * node_def)947 void GraphConstructor::AddPrefixToNodeDef(
948 const std::vector<bool>& input_already_exists, NodeDef* node_def) {
949 if (prefix_.empty()) return;
950 node_def->set_name(strings::StrCat(prefix_, node_def->name()));
951 // Update names of input nodes
952 for (int i = 0; i < node_def->input_size(); ++i) {
953 // Skip remapped inputs (which already exist in g_ and are not being
954 // imported).
955 if (input_already_exists[i]) continue;
956 StringPiece input(node_def->input(i));
957 if (absl::ConsumePrefix(&input, "^")) {
958 node_def->set_input(i, strings::StrCat("^", prefix_, input));
959 } else {
960 node_def->set_input(i, strings::StrCat(prefix_, input));
961 }
962 }
963 // Update names of colocation groups
964 if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) {
965 auto* list =
966 node_def->mutable_attr()->at(kColocationAttrName).mutable_list();
967 for (int i = 0; i < list->s_size(); ++i) {
968 StringPiece v(list->s(i));
969 if (absl::ConsumePrefix(&v, kColocationGroupPrefix)) {
970 list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v));
971 }
972 }
973 }
974 }
975
UniquifyNames(const std::vector<bool> & input_already_exists,NodeDef * node_def)976 void GraphConstructor::UniquifyNames(
977 const std::vector<bool>& input_already_exists, NodeDef* node_def) {
978 if (NameExistsInGraph(node_def->name())) {
979 string old_name = node_def->name();
980 node_def->set_name(FindUniqueName(node_def->name()));
981 uniquified_names_[old_name] = node_def->name();
982 // Note that we don't have to update gdef_nodes_ or gdef_prefixes_ with
983 // `name` because we guarantee the original NodeDef names are unique,
984 // meaning we won't generate this name again.
985 }
986 for (int i = 0; i < node_def->input_size(); ++i) {
987 // Skip remapped inputs (which already exist in g_ and are not being
988 // imported).
989 if (input_already_exists[i]) continue;
990 TensorId id = ParseTensorName(node_def->input(i));
991 // We require that UniquifyNames() is called on all NodeDefs in topological
992 // order. This guarantees that node_def's inputs will already be uniquified
993 // if necessary.
994 auto iter = uniquified_names_.find(string(id.first));
995 if (iter == uniquified_names_.end()) continue;
996 id.first = iter->second;
997 node_def->set_input(i, id.ToString());
998 }
999 }
1000
UpdateUniquifiedColocationNames()1001 void GraphConstructor::UpdateUniquifiedColocationNames() {
1002 for (const auto& pair : gdef_nodes_) {
1003 Node* node = pair.second.node;
1004 if (node == nullptr) continue;
1005 std::vector<string> coloc_values;
1006 if (!TryGetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values))
1007 continue;
1008 bool updated = false;
1009 for (size_t i = 0; i < coloc_values.size(); ++i) {
1010 StringPiece val(coloc_values[i]);
1011 if (absl::ConsumePrefix(&val, kColocationGroupPrefix)) {
1012 auto name_pair = uniquified_names_.find(string(val));
1013 if (name_pair == uniquified_names_.end()) continue;
1014 updated = true;
1015 coloc_values[i] =
1016 strings::StrCat(kColocationGroupPrefix, name_pair->second);
1017 }
1018 }
1019 if (updated) {
1020 node->AddAttr(kColocationAttrName, std::move(coloc_values));
1021 }
1022 }
1023 }
1024
NameExistsInGraph(StringPiece name)1025 bool GraphConstructor::NameExistsInGraph(StringPiece name) {
1026 if (existing_nodes_.find(name) != existing_nodes_.end()) return true;
1027 if (existing_prefixes_.find(name) != existing_prefixes_.end()) return true;
1028 return false;
1029 }
1030
NameExistsInGraphDef(StringPiece name)1031 bool GraphConstructor::NameExistsInGraphDef(StringPiece name) {
1032 if (gdef_nodes_.find(name) != gdef_nodes_.end()) return true;
1033 if (gdef_prefixes_.find(name) != gdef_prefixes_.end()) return true;
1034 return false;
1035 }
1036
FindUniqueName(StringPiece original_name)1037 string GraphConstructor::FindUniqueName(StringPiece original_name) {
1038 string name(original_name);
1039 int count = 0;
1040 // Check that any generated names don't collide with imported NodeDefs (as
1041 // well as nodes in g_).
1042 while (NameExistsInGraph(name) || (count > 0 && NameExistsInGraphDef(name))) {
1043 name = strings::StrCat(original_name, "_", ++count);
1044 }
1045 return name;
1046 }
1047
IsNodeFullyMapped(const NodeDef & node_def,bool * is_node_mapped)1048 Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def,
1049 bool* is_node_mapped) {
1050 const OpDef* op_def;
1051 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
1052 for (int i = 0; i < op_def->output_arg_size(); ++i) {
1053 if (opts_.input_map.find({node_def.name(), i}) == opts_.input_map.end()) {
1054 *is_node_mapped = false;
1055 return Status::OK();
1056 }
1057 }
1058 *is_node_mapped = true;
1059 return Status::OK();
1060 }
1061
// Depth-first traversal over outputs_ starting at `cur_node`, logging every
// cycle encountered. `cur_branch` holds the nodes on the current DFS path in
// traversal order; `is_on_cur_branch` mirrors it for O(1) membership tests;
// `unvisited` holds nodes not yet fully explored, and fully-explored nodes
// are removed from it so each node is expanded at most once overall.
void GraphConstructor::DFS(int cur_node, std::vector<int>* cur_branch,
                           std::vector<bool>* is_on_cur_branch,
                           absl::flat_hash_set<int>* unvisited) {
  cur_branch->push_back(cur_node);
  is_on_cur_branch->at(cur_node) = true;
  for (auto next_node : outputs_[cur_node]) {
    if (unvisited->find(next_node) != unvisited->end()) {
      if (is_on_cur_branch->at(next_node)) {
        // next_node is an ancestor on the current path, so the portion of
        // cur_branch from next_node's position to the end forms a cycle.
        auto iter =
            std::find(cur_branch->begin(), cur_branch->end(), next_node);
        LOG(WARNING) << "Cycle detected:";
        while (iter != cur_branch->end()) {
          LOG(WARNING) << SummarizeNodeDef(get_node_def(*iter));
          ++iter;
        }
        LOG(WARNING) << "End of cycle";
      } else {
        DFS(next_node, cur_branch, is_on_cur_branch, unvisited);
      }
    }
  }
  // Backtrack: cur_node is now fully explored.
  cur_branch->pop_back();
  is_on_cur_branch->at(cur_node) = false;
  unvisited->erase(cur_node);
}
1087
PrintCycles()1088 void GraphConstructor::PrintCycles() {
1089 int num_nodes = outputs_.size();
1090 absl::flat_hash_set<int> unvisited;
1091 for (int i = 0; i < num_nodes; i++) {
1092 unvisited.insert(i);
1093 }
1094 while (!unvisited.empty()) {
1095 int cur_node = *unvisited.begin();
1096 // Nodes on the current branch of DFS in traversal order. This is used for
1097 // printing the nodes in the cycle.
1098 std::vector<int> cur_branch;
1099 // This is just to make lookups O(1).
1100 // is_on_cur_branch[i] ==
1101 // (std::find(cur_branch.start(),
1102 // cur_branch.end(), i) != cur_branch.end())
1103 std::vector<bool> is_on_cur_branch(num_nodes, false);
1104 DFS(cur_node, &cur_branch, &is_on_cur_branch, &unvisited);
1105 }
1106 }
1107
// Core conversion loop: consumes the NodeDefs in topological order (driven by
// ready_ / pending_count_ / outputs_, set up earlier from the input edges),
// applies the import-time transformations (input remapping, added control
// deps, prefixing, name uniquification, defaults/validation), creates each
// Node in g_, and wires up its in-edges. Back edges (inputs whose producer is
// not yet created, only legal into Merge nodes) are recorded in back_edges_
// for AddBackEdges() to resolve later. Fails if a cycle prevents processing
// every NodeDef.
Status GraphConstructor::Convert() {
  // Import functions before adding nodes, since imported nodes may refer to
  // functions
  if (library()) {
    // TODO(b/135705010): Add rvalue overloads into the function library, to
    // avoid unnecessarily copying `*library()` here.
    TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library()));
  }

  std::vector<InputInfo> inputs;
  int processed = 0;

  std::vector<bool> input_already_exists;

  // Process the NodeDefs in topological order.
  // (InitFromEdges() sets this up by filling in ready_ with nodes that have no
  // inputs, pending_counts_ with the number of inputs for each node and
  // outputs_ with the outputs of each node).
  while (!ready_.empty()) {
    int o = *ready_.begin();
    ready_.erase(ready_.begin());
    ++processed;
    inputs.clear();
    bool has_data_back_edge = false;

    NodeDef node_def = consume_node_def(o);

    // input_already_exists[i] is true iff the i-th input of the node we're
    // importing refers to a preexisting node in g_ (i.e. input[i] existed prior
    // to importing node_defs_). Conversely, input_already_exists[i] is false
    // iff the input refers to a node in node_defs_.
    input_already_exists.clear();
    input_already_exists.resize(node_def.input_size(), false);

    // Index into string_intern_table_ of this node's original name, or -1
    // when not importing.
    ssize_t string_intern_table_index = -1;

    if (opts_.importing) {
      // Intern the original node name, so that we can use a StringPiece of the
      // name to index gdef_nodes_.
      string_intern_table_index = string_intern_table_.size();
      string_intern_table_.push_back(node_def.name());

      if (opts_.skip_mapped_nodes) {
        bool is_node_mapped = false;
        TF_RETURN_IF_ERROR(IsNodeFullyMapped(node_def, &is_node_mapped));
        if (is_node_mapped) {
          // Skip this node after updating pending_count_ for outputs
          UpdatePendingCountAndReady(o, IsNextIteration(node_def));
          continue;
        }
      }

      if (!opts_.input_map.empty()) {
        // Note that input_already_exists can shrink here
        RemapNodeDefInputs(&node_def, &input_already_exists);
      }
      if (!opts_.control_dependencies.empty()) {
        // Note that input_already_exists can grow here
        AddControlDependencies(&node_def, &input_already_exists);
      }
      if (!opts_.default_device.empty() && node_def.device().empty()) {
        node_def.set_device(opts_.default_device);
      }
    }

    DCHECK_EQ(node_def.input_size(), input_already_exists.size());
    TF_RETURN_IF_ERROR(ValidateColocationConstraints(node_def));
    // Resolve each input to its source node (either newly imported or
    // preexisting) and output index.
    for (int i = 0; i < node_def.input_size(); ++i) {
      TensorId tensor_id = ParseTensorName(node_def.input(i));
      Node* src_node;
      int src_index;

      if (!input_already_exists[i]) {
        // Locate input in newly-imported nodes
        auto iter = gdef_nodes_.find(tensor_id.node());
        DCHECK(iter != gdef_nodes_.end()) << tensor_id.node();
        src_node = iter->second.node;
        src_index = tensor_id.index();
        // A null node means the producer hasn't been created yet: back edge.
        if (src_node == nullptr) has_data_back_edge = true;
      } else {
        // Input refers to preexistng node in graph
        auto iter = existing_nodes_.find(tensor_id.node());
        DCHECK(iter != existing_nodes_.end()) << tensor_id.node();
        src_node = iter->second;
        src_index = tensor_id.index();
      }

      if (src_node != nullptr && src_index >= src_node->num_outputs()) {
        std::ostringstream out;
        out << "Node '" << node_def.name() << "': Connecting to invalid output "
            << tensor_id.index() << " of source node " << tensor_id.node()
            << " which has " << src_node->num_outputs() << " outputs.";

        if (src_node->type_string() == "If" ||
            src_node->type_string() == "StatelessIf" ||
            src_node->type_string() == "While" ||
            src_node->type_string() == "StatelessWhile") {
          out << " Try using "
              << "tf.compat.v1.experimental.output_all_intermediates(True).";
        }
        return errors::InvalidArgument(out.str());
      }

      inputs.emplace_back(string(tensor_id.node()), src_node, src_index);
    }

    if (has_data_back_edge && !IsMerge(node_def)) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "' had a back edge, but only Merge nodes can have back edges.");
    }

    Node* node;
    if (opts_.importing) {
      if (!prefix_.empty()) {
        AddPrefixToNodeDef(input_already_exists, &node_def);
      }
      // Note: no need to uniquify names if the prefix already guarantees
      // uniqueness
      if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) {
        UniquifyNames(input_already_exists, &node_def);
      }
    }

    if (opts_.importing) {
      TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&node_def));
    } else {
      // Non-importing path: apply defaults/validation directly as requested
      // by the options.
      const OpDef* op_def;
      TF_RETURN_IF_ERROR(
          g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
      if (opts_.add_default_attributes) {
        AddDefaultsToNodeDef(*op_def, &node_def);
      }
      if (opts_.validate_nodes) {
        TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def));
      }
    }

    TF_RETURN_IF_ERROR(MakeNode(std::move(node_def), &node));

    if (opts_.importing) {
      // Use interned original node name so StringPiece remains valid.
      DCHECK_GE(string_intern_table_index, 0);
      gdef_nodes_[string_intern_table_[string_intern_table_index]].node = node;
    } else {
      DCHECK_EQ(string_intern_table_index, -1);
      gdef_nodes_[node->name()].node = node;
    }

    // Remove duplicate control inputs before adding edges to the graph. It
    // will allow us to skip expensive duplicates check in 'AddControlEdge'.
    auto first_control = absl::c_find_if(inputs, &InputInfo::IsControlInput);
    auto first_control_copy = first_control;
    std::sort(first_control, inputs.end(), &InputInfo::CompareName);
    inputs.erase(
        std::unique(first_control_copy, inputs.end(), &InputInfo::IsSameName),
        inputs.end());

    // Add edges from inputs to *node to the graph.
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (inputs[i].node == nullptr) {
        // Record this back edge, which will be added after all nodes
        // are created.
        back_edges_.emplace_back(inputs[i].name, inputs[i].index, node, i);
      } else if (inputs[i].index == Graph::kControlSlot) {
        g_->AddControlEdge(inputs[i].node, node, kDoNotCheckDuplicates);
      } else {
        TF_RETURN_IF_ERROR(MakeEdge(inputs[i].node, inputs[i].index, node, i));
      }
    }

    TF_RETURN_IF_ERROR(ValidateShape(node));

    // Update pending_count_ for outputs.
    UpdatePendingCountAndReady(o, node->IsNextIteration());
  }

  // Any NodeDef left unprocessed could never become ready, which means it
  // participates in (or depends on) a cycle.
  if (processed < node_def_count()) {
    LOG(WARNING) << "IN " << __func__ << " " << (node_def_count() - processed)
                 << " NODES IN A CYCLE";
    for (int64 i = 0; i < node_def_count(); i++) {
      if (pending_count_[i] != 0) {
        LOG(WARNING) << "PENDING: " << SummarizeNodeDef(get_node_def(i))
                     << " WITH PENDING COUNT = " << pending_count_[i];
      }
    }
    PrintCycles();
    return errors::InvalidArgument(node_def_count() - processed,
                                   " nodes in a cycle");
  }

  return Status::OK();
}
1301
AddBackEdges()1302 Status GraphConstructor::AddBackEdges() {
1303 // Add the back edges after all nodes are created.
1304 for (const auto& e : back_edges_) {
1305 Node* src_node = gdef_nodes_[e.src_name].node;
1306 if (e.src_index == Graph::kControlSlot) {
1307 g_->AddControlEdge(src_node, e.dst_node, kDoNotCheckDuplicates);
1308 } else {
1309 TF_RETURN_IF_ERROR(
1310 MakeEdge(src_node, e.src_index, e.dst_node, e.dst_index));
1311 }
1312
1313 VLOG(2) << "Add back edge: " << src_node->name() << " -> "
1314 << e.dst_node->name();
1315 }
1316 return Status::OK();
1317 }
1318
UpdateVersionDef()1319 Status GraphConstructor::UpdateVersionDef() {
1320 if (versions() == nullptr) return Status::OK();
1321
1322 if (!opts_.importing) {
1323 g_->set_versions(*versions());
1324 return Status::OK();
1325 }
1326 VersionDef g_versions = g_->versions();
1327 g_versions.set_producer(
1328 std::min(g_versions.producer(), versions()->producer()));
1329 g_versions.set_min_consumer(
1330 std::max(g_versions.min_consumer(), versions()->min_consumer()));
1331 if (versions()->bad_consumers_size() > 0) {
1332 std::set<int> bad(g_versions.bad_consumers().begin(),
1333 g_versions.bad_consumers().end());
1334 bad.insert(versions()->bad_consumers().begin(),
1335 versions()->bad_consumers().end());
1336 g_versions.clear_bad_consumers();
1337 for (int v : bad) {
1338 g_versions.add_bad_consumers(v);
1339 }
1340 }
1341 g_->set_versions(g_versions);
1342 return Status::OK();
1343 }
1344
PopulateReturnTensors()1345 Status GraphConstructor::PopulateReturnTensors() {
1346 if (opts_.return_tensors.empty()) return Status::OK();
1347 for (const TensorId& id : opts_.return_tensors) {
1348 auto iter = opts_.input_map.find(id);
1349 if (iter == opts_.input_map.end()) {
1350 // Locate id in imported nodes
1351 auto iter = gdef_nodes_.find(id.first);
1352 if (iter == gdef_nodes_.end()) {
1353 return errors::InvalidArgument("Requested return tensor '",
1354 id.ToString(),
1355 "' not found in graph def");
1356 }
1357 int num_outputs = iter->second.node->num_outputs();
1358 if ((id.second < 0 || id.second >= num_outputs) &&
1359 id.second != Graph::kControlSlot) {
1360 return errors::InvalidArgument("Invalid return output ", id.second,
1361 " of node '", id.first, "', which has ",
1362 num_outputs, " output(s)");
1363 }
1364 return_tensors_->push_back({iter->second.node, id.second});
1365 } else {
1366 // id was remapped to existing node
1367 TensorId remapped_id = iter->second;
1368 DCHECK_GT(existing_nodes_.count(remapped_id.first), 0);
1369 Node* node = existing_nodes_[remapped_id.first];
1370 return_tensors_->push_back({node, remapped_id.second});
1371 }
1372 }
1373 return Status::OK();
1374 }
1375
PopulateReturnNodes()1376 Status GraphConstructor::PopulateReturnNodes() {
1377 if (opts_.return_nodes.empty()) return Status::OK();
1378 for (StringPiece name : opts_.return_nodes) {
1379 auto iter = gdef_nodes_.find(name);
1380 if (iter == gdef_nodes_.end()) {
1381 return errors::InvalidArgument("Requested return node '", name,
1382 "' not found in graph def");
1383 }
1384 return_nodes_->push_back(iter->second.node);
1385 }
1386 return Status::OK();
1387 }
1388
PopulateMissingUnusedInputMapKeys()1389 Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
1390 if (missing_unused_input_map_keys_ == nullptr) return Status::OK();
1391 for (const auto& input_map_pair : opts_.input_map) {
1392 TensorId key = input_map_pair.first;
1393 if (used_input_map_keys_.count(key) > 0) continue;
1394
1395 auto pair = gdef_nodes_.find(key.first);
1396 if (pair == gdef_nodes_.end()) {
1397 // key's node doesn't exist in GraphDef
1398 missing_unused_input_map_keys_->push_back(key);
1399 continue;
1400 }
1401
1402 // Check that key's index is in bounds. Get the number of outputs from the
1403 // NodeDef, rather than the imported Node, since the Node may not exist if
1404 // opts_.skip_mapped_nodes is true.
1405 const NodeDef& node_def = get_node_def(pair->second.gdef_index);
1406 const OpDef* op_def;
1407 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
1408 int num_outputs;
1409 TF_RETURN_IF_ERROR(NumOutputsForNode(node_def, *op_def, &num_outputs));
1410 if (key.second >= num_outputs) {
1411 // key's index out of bounds
1412 missing_unused_input_map_keys_->push_back(key);
1413 }
1414 }
1415 return Status::OK();
1416 }
1417
Undo()1418 void GraphConstructor::Undo() {
1419 for (const auto& iter : gdef_nodes_) {
1420 if (iter.second.node != nullptr) {
1421 g_->RemoveNode(iter.second.node);
1422 }
1423 }
1424 g_->set_versions(original_versions_);
1425 }
1426
MakeEdge(Node * src,int output_index,Node * dst,int input_index)1427 Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
1428 int input_index) {
1429 if (output_index >= src->num_outputs()) {
1430 return errors::InvalidArgument(
1431 "Output ", output_index, " of node ", src->name(),
1432 " does not exist. Node only has ", src->num_outputs(), " outputs.");
1433 }
1434 if (input_index >= dst->num_inputs()) {
1435 return errors::InvalidArgument(
1436 "Input ", input_index, " of node ", dst->name(),
1437 " does not exist. Node only has ", dst->num_inputs(), " inputs.");
1438 }
1439
1440 DataType src_out = src->output_type(output_index);
1441 DataType dst_in = dst->input_type(input_index);
1442 if (!TypesCompatible(dst_in, src_out)) {
1443 return errors::InvalidArgument(
1444 "Input ", input_index, " of node ", dst->name(), " was passed ",
1445 DataTypeString(src_out), " from ", src->name(), ":", output_index,
1446 " incompatible with expected ", DataTypeString(dst_in), ".");
1447 }
1448 g_->AddEdge(src, output_index, dst, input_index);
1449 return Status::OK();
1450 }
1451
1452 } // namespace
1453
// Builds `*g` from a copied `gdef`, running shape inference with a
// ShapeRefiner pinned to the GraphDef's producer version. Requests no return
// tensors/nodes and no missing-input-map-key reporting.
Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
                              const GraphDef& gdef, Graph* g) {
  ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
  return GraphConstructor::Construct(
      opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
      /*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
      /*missing_unused_input_map_keys=*/nullptr);
}
1462
// Move-enabled overload: consumes `gdef` (passed to Construct via std::move)
// instead of copying its contents.
Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
                              GraphDef&& gdef, Graph* g) {
  ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
  return GraphConstructor::Construct(opts, std::move(gdef), g, &refiner,
                                     /*return_tensors=*/nullptr,
                                     /*return_nodes=*/nullptr,
                                     /*missing_unused_input_map_keys=*/nullptr);
}
1471
// Builds `*g` from a bare list of NodeDefs (no versions or function library),
// using a refiner at the current TF_GRAPH_DEF_VERSION.
Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
                              gtl::ArraySlice<NodeDef> nodes, Graph* g) {
  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry());
  // TODO(irving): Copy will go away once NodeInfo exists
  // Construct takes pointers, so build a pointer view over `nodes`.
  std::vector<const NodeDef*> node_defs;
  node_defs.reserve(nodes.size());
  for (const auto& n : nodes) {
    node_defs.push_back(&n);
  }
  return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
                                     &refiner, /*return_tensors=*/nullptr,
                                     /*return_nodes=*/nullptr,
                                     /*missing_unused_input_map_keys=*/nullptr);
}
1486
ImportGraphDef(const ImportGraphDefOptions & opts,const GraphDef & gdef,Graph * g,ShapeRefiner * refiner,ImportGraphDefResults * results)1487 Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
1488 Graph* g, ShapeRefiner* refiner,
1489 ImportGraphDefResults* results) {
1490 if (!opts.return_tensors.empty()) {
1491 if (results == nullptr) {
1492 return errors::InvalidArgument(
1493 "results argument to ImportGraphDef() must be non-null if "
1494 "opts.return_tensors is non-empty");
1495 }
1496 }
1497
1498 if (!opts.return_nodes.empty()) {
1499 if (opts.skip_mapped_nodes) {
1500 return errors::InvalidArgument(
1501 "Requesting return_nodes with skip_mapped_nodes set is not currently "
1502 "supported");
1503 }
1504 if (results == nullptr) {
1505 return errors::InvalidArgument(
1506 "results argument to ImportGraphDef() must be non-null if "
1507 "opts.return_nodes is non-empty");
1508 }
1509 }
1510
1511 if (results != nullptr) {
1512 if (!results->return_tensors.empty() || !results->return_nodes.empty() ||
1513 !results->missing_unused_input_map_keys.empty()) {
1514 return errors::InvalidArgument(
1515 "All fields in results argument to ImportGraphDef() must be empty.");
1516 }
1517 }
1518
1519 ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
1520 if (refiner == nullptr) {
1521 refiner = &default_refiner;
1522 } else {
1523 // Log a warning if we are importing a GraphDef at an older
1524 // producer version after already having added non-source/sink
1525 // nodes to the graph in the past.
1526 if (gdef.versions().producer() > 0 &&
1527 gdef.versions().producer() < refiner->graph_def_version() &&
1528 g->num_nodes() > 2) {
1529 LOG(WARNING) << "Importing a graph with a lower producer version "
1530 << gdef.versions().producer()
1531 << " into an existing graph with producer version "
1532 << refiner->graph_def_version() << ". Shape inference will "
1533 << "have run different parts of the graph with different "
1534 << "producer versions.";
1535 }
1536 }
1537
1538 // Set the graph def version of the refiner as the min of the
1539 // current value and the version from the graph we are about to
1540 // import.
1541 //
1542 // Note: to match Run() semantics, we should re-run shape inference
1543 // on the entire graph if the producer version has changed. For now
1544 // we log the warning above.
1545 refiner->set_graph_def_version(
1546 std::min(refiner->graph_def_version(), gdef.versions().producer()));
1547
1548 if (results == nullptr) {
1549 return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
1550 &gdef.library(), g, refiner, nullptr,
1551 nullptr, nullptr);
1552 } else {
1553 return GraphConstructor::Construct(
1554 opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
1555 &results->return_tensors, &results->return_nodes,
1556 &results->missing_unused_input_map_keys);
1557 }
1558 }
1559
// Copies `src` into `*dest` via Graph::Copy.
void CopyGraph(const Graph& src, Graph* dest) { dest->Copy(src); }
1561
1562 } // namespace tensorflow
1563