1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/graph_constructor.h"
17
18 #include <algorithm>
19 #include <set>
20 #include <string>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_cat.h"
28 #include "tensorflow/core/common_runtime/shape_refiner.h"
29 #include "tensorflow/core/framework/function.h"
30 #include "tensorflow/core/framework/function.pb.h"
31 #include "tensorflow/core/framework/graph.pb.h"
32 #include "tensorflow/core/framework/node_def.pb.h"
33 #include "tensorflow/core/framework/node_def_util.h"
34 #include "tensorflow/core/framework/tensor_shape.pb.h"
35 #include "tensorflow/core/framework/types.h"
36 #include "tensorflow/core/framework/versions.h"
37 #include "tensorflow/core/framework/versions.pb.h"
38 #include "tensorflow/core/graph/algorithm.h"
39 #include "tensorflow/core/graph/graph.h"
40 #include "tensorflow/core/graph/tensor_id.h"
41 #include "tensorflow/core/lib/core/errors.h"
42 #include "tensorflow/core/lib/gtl/flatmap.h"
43 #include "tensorflow/core/lib/gtl/flatset.h"
44 #include "tensorflow/core/lib/gtl/inlined_vector.h"
45 #include "tensorflow/core/lib/strings/scanner.h"
46 #include "tensorflow/core/lib/strings/str_util.h"
47 #include "tensorflow/core/platform/errors.h"
48 #include "tensorflow/core/platform/logging.h"
49 #include "tensorflow/core/platform/macros.h"
50 #include "tensorflow/core/public/version.h"
51
52 namespace tensorflow {
53
54 namespace {
55
// We remove duplicate control inputs before adding edges to the Graph, so we
// can skip the expensive duplicates check in 'AddControlEdge'.
// NOTE(review): presumably passed as AddControlEdge's allow-duplicates
// argument by code later in this file — confirm at the call sites.
static constexpr const bool kDoNotCheckDuplicates = true;
59
IsMerge(const NodeDef & node_def)60 inline bool IsMerge(const NodeDef& node_def) {
61 return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
62 node_def.op() == "_XlaMerge";
63 }
64
IsNextIteration(const NodeDef & node_def)65 inline bool IsNextIteration(const NodeDef& node_def) {
66 return node_def.op() == "NextIteration" ||
67 node_def.op() == "RefNextIteration";
68 }
69
IsValidNodeName(StringPiece s,bool allow_internal_ops)70 bool IsValidNodeName(StringPiece s, bool allow_internal_ops) {
71 using ::tensorflow::strings::Scanner;
72 Scanner scanner(s);
73 scanner
74 .One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE
75 : Scanner::LETTER_DIGIT_DOT)
76 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
77
78 while (true) {
79 if (!scanner.GetResult()) // Some error in previous iteration.
80 return false;
81 if (scanner.empty()) // No error, but nothing left, good.
82 return true;
83
84 // Absorb another piece, starting with a '>'
85 scanner.One(Scanner::RANGLE)
86 .One(Scanner::LETTER_DIGIT_DOT)
87 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
88 }
89 }
90
// Converts a GraphDef (or a slice of NodeDefs) into nodes and edges added to
// an existing Graph. A fresh GraphConstructor is used per import; callers use
// the static Construct() overloads below. Subclasses implement the node-access
// virtuals so that the underlying NodeDef storage can either be borrowed and
// copied (NodeDefCopyingGraphConstructor) or owned and destructively moved
// (NodeDefMovingGraphConstructor).
class GraphConstructor {
 public:
  struct Options {
    // Options for ConvertGraphDefToGraph-style construction: no prefixing,
    // no name uniquification, no input remapping.
    Options(const GraphConstructorOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(in.allow_internal_ops),
          expect_device_spec(in.expect_device_spec),
          importing(false),
          validate_nodes(in.validate_nodes),
          validate_colocation_constraints(false),
          add_default_attributes(in.add_default_attributes) {}
    // Options for ImportGraphDef-style construction. Normalizes a non-empty
    // prefix so it always ends with '/'; always validates nodes.
    Options(const ImportGraphDefOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(false),
          expect_device_spec(false),
          prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/")
                     ? in.prefix
                     : in.prefix + "/"),
          uniquify_names(in.uniquify_names),
          uniquify_prefix(in.uniquify_prefix),
          input_map(in.input_map.begin(), in.input_map.end()),
          skip_mapped_nodes(in.skip_mapped_nodes),
          control_dependencies(in.control_dependencies),
          return_tensors(in.return_tensors.begin(), in.return_tensors.end()),
          return_nodes(in.return_nodes),
          importing(true),
          validate_nodes(true),
          validate_colocation_constraints(in.validate_colocation_constraints),
          validate_shape(in.validate_shape),
          default_device(in.default_device) {}

    // If true, node names may begin with characters otherwise reserved for
    // internal ops (see IsValidNodeName).
    bool allow_internal_ops;
    // If true, every NodeDef must carry a device specification, which is also
    // set as the node's assigned device (see MakeNode).
    bool expect_device_spec;

    // Prepended to imported node names; empty or always ends with '/'.
    string prefix;
    bool uniquify_names;
    bool uniquify_prefix;
    std::map<TensorId, TensorId> input_map;
    bool skip_mapped_nodes;
    std::vector<string> control_dependencies;
    std::vector<TensorId> return_tensors;
    std::vector<string> return_nodes;

    // TODO(ashankar): This bool exists to separate out functionality required
    // to make ImportGraphDef a close equivalent of Python's import_graph_def
    // without affecting the behavior of ConvertGraphDefToGraph at the time
    // ImportGraphDef was added.
    //
    // That said, the functionality here (shape and op validation) seems
    // applicable to ConvertGraphDefToGraph as well, so make an attempt to
    // remove this.
    bool importing;
    // If true, validates that nodes being converted have all expected attrs
    // set and no unknown attrs set by calling ValidateNodeDef().
    // `validate_nodes` is always true when `importing` is set.
    bool validate_nodes;
    bool validate_colocation_constraints;
    bool validate_shape = true;

    // If true, GraphConstructor will add attributes with their default
    // value to the Node when they are missing from the NodeDef.
    bool add_default_attributes = true;

    string default_device;
  };

  typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;

  // versions and library may be nullptr
  static Status Construct(
      const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
      const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
      std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys);

  // Takes ownership of `graph_def`, allowing NodeDefs to be moved into the
  // constructed Graph rather than copied.
  static Status Construct(
      const Options& opts, GraphDef&& graph_def, Graph* g,
      ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys);

 protected:
  GraphConstructor(const Options& opts, Graph* g, ShapeRefiner* refiner,
                   std::vector<std::pair<Node*, int>>* return_tensors,
                   std::vector<Node*>* return_nodes,
                   std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : opts_(opts),
        g_(g),
        original_versions_(g->versions()),
        prefix_(opts.prefix),
        refiner_(refiner),
        return_tensors_(return_tensors),
        return_nodes_(return_nodes),
        missing_unused_input_map_keys_(missing_unused_input_map_keys) {}

  virtual ~GraphConstructor() {}

  // Runs the full import pipeline. On failure the caller is expected to
  // invoke Undo() to roll back any partial changes to g_ (see Construct()).
  Status TryImport() {
    TF_RETURN_IF_ERROR(EnsureNoNameCollisions());
    TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies());
    TF_RETURN_IF_ERROR(BuildNodeIndex());
    TF_RETURN_IF_ERROR(InitFromEdges());

    // NOTE: Convert() invokes `consume_node_def()` on each node in the input
    // graph, so `get_node_def()` is no longer usable once it is called.
    TF_RETURN_IF_ERROR(Convert());

    TF_RETURN_IF_ERROR(AddBackEdges());
    TF_RETURN_IF_ERROR(UpdateVersionDef());
    TF_RETURN_IF_ERROR(PopulateReturnTensors());
    TF_RETURN_IF_ERROR(PopulateReturnNodes());
    TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys());
    UpdateUniquifiedColocationNames();
    FixupSourceAndSinkEdges(g_);
    return Status::OK();
  }

 private:
  Status EnsureNoNameCollisions();
  Status ValidateInputMapAndControlDependencies();
  Status BuildNodeIndex();
  Status InitFromEdges();
  Status Convert();
  Status AddBackEdges();
  Status UpdateVersionDef();
  Status PopulateReturnTensors();
  Status PopulateReturnNodes();
  Status PopulateMissingUnusedInputMapKeys();

  // Rolls back changes made to g_ by a failed import.
  void Undo();

  // Prints cycles in the graph.
  void PrintCycles();
  // Performs DFS starting at `cur_node` and prints any cycles found.
  void DFS(int cur_node, std::vector<int>* cur_branch,
           std::vector<bool>* is_on_cur_branch,
           absl::flat_hash_set<int>* unvisited);
  Status IsNodeFullyMapped(const NodeDef& node_def, bool* is_node_mapped);
  Status ValidateColocationConstraints(const NodeDef& node_def);
  Status MakeNode(NodeDef&& node_def, Node** node);
  Status MakeEdge(Node* src, int output_index, Node* dst, int input_index);
  Status ValidateShape(Node* node);
  Status ModifyNodeDefForImport(NodeDef* node_def);
  // Modifies node_def's inputs according to opts_.input_map.
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function will mark inputs that are remapped to
  // true.
  void RemapNodeDefInputs(NodeDef* node_def,
                          std::vector<bool>* input_already_exists);
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function will add and mark control inputs as
  // true.
  void AddControlDependencies(NodeDef* node_def,
                              std::vector<bool>* input_already_exists);
  void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists,
                          NodeDef* node_def);

  // Modifies `node_def` if its name isn't unique, or if any of its inputs'
  // names have been uniquified. This must be called in topological order on all
  // nodes.
  void UniquifyNames(const std::vector<bool>& input_already_exists,
                     NodeDef* node_def);

  // Updates any constructed nodes' colocation group names if the name has been
  // updated by UniquifyNames. This is called after all the nodes have been
  // constructed so all the names have been uniquified if necessary.
  void UpdateUniquifiedColocationNames();

  // Returns true if `name` already exists in `g_` (either as a node name or
  // prefix).
  bool NameExistsInGraph(StringPiece name);

  // Returns true if `name` already exists in the GraphDef being imported
  // (either as a node name or prefix).
  bool NameExistsInGraphDef(StringPiece name);

  // Returns a unique version of `original_name`, or `original_name` if it's
  // already unique in the graph.
  string FindUniqueName(StringPiece original_name);

  // Decrement pending count for users of `processed` and add the ones that now
  // have all of their pending inputs satisfied to `ready_`.
  void UpdatePendingCountAndReady(int processed, bool is_next_iteration);

  // Subclasses override the following virtual methods to provide efficient
  // access to the original protocol buffer-based graph.

  // Returns the number of nodes in the graph.
  virtual size_t node_def_count() const = 0;
  // Returns the i^th node in the graph. Must not be called after
  // consume_node_def(i).
  virtual const NodeDef& get_node_def(int i) const = 0;
  // Destructively reads the i^th node in the graph, avoiding a copy if
  // possible. After calling this method, the result of get_node_def(i) is
  // undefined.
  virtual NodeDef consume_node_def(int i) = 0;
  // Returns the version information for the graph, or nullptr if none is
  // available.
  virtual const VersionDef* versions() const = 0;
  // Returns the function information for the graph, or nullptr if none is
  // available.
  virtual const FunctionDefLibrary* library() const = 0;

  // From constructor
  const Options opts_;
  Graph* g_;
  // Snapshot of g_'s versions at construction time.
  const VersionDef original_versions_;

  // A copy of opts_.prefix, possibly uniquified.
  string prefix_;

  ShapeRefiner* refiner_;

  // May be null. Not owned.
  std::vector<std::pair<Node*, int>>* return_tensors_;

  // May be null. Not owned.
  std::vector<Node*>* return_nodes_;

  // May be null. Not owned.
  std::vector<SafeTensorId>* missing_unused_input_map_keys_;

  // Intermediate datastructure used to populate
  // `missing_unused_input_map_keys_`.
  std::set<TensorId> used_input_map_keys_;

  // Intermediate datastructure used to track the destinations of back edges.
  absl::flat_hash_set<int> merge_node_indices_;

  // Mapping from node name to the index within node_defs_.
  struct NodeInfo {
    explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
    // Containers require that we have a default constructor.
    NodeInfo() : NodeInfo(-1) {}
    int gdef_index;
    Node* node;  // nullptr until the NodeDef is converted to a Node.
  };
  absl::flat_hash_map<std::string, NodeInfo> gdef_nodes_;

  // Prefixes already used in the GraphDef being imported.
  absl::flat_hash_set<StringPiece> gdef_prefixes_;

  // Mapping from node name to the existing node in g_.
  absl::flat_hash_map<StringPiece, Node*> existing_nodes_;

  // Prefixes already used in the graph.
  absl::flat_hash_set<StringPiece> existing_prefixes_;

  // Imported node names that have been uniquified. The key is the original
  // name, the value is the new unique name.
  gtl::FlatMap<string, string> uniquified_names_;

  // Index of NodeDefs in node_defs_ with all inputs already converted. We use a
  // (sorted) set so nodes are created in the order defined in the GraphDef.
  std::set<int> ready_;

  // Mapping between index within node_defs_ and the number of inputs that
  // still need to be converted.
  std::vector<int> pending_count_;

  // Mapping between index within node_defs_ and the index within node_defs_ of
  // all nodes it outputs to.
  std::vector<gtl::InlinedVector<int, 4>> outputs_;

  // Used in the conversion from node_defs_ to g_ to represent the ith input
  // of a node.
  struct InputInfo {
    explicit InputInfo(const string& node_name, Node* n, int i)
        : name(node_name), node(n), index(i) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string name;
    Node* node;
    int index;

    static bool IsControlInput(const InputInfo& input) {
      return input.index == Graph::kControlSlot;
    }
    static int CompareName(const InputInfo& lhs, const InputInfo& rhs) {
      return lhs.name < rhs.name;
    }
    static bool IsSameName(const InputInfo& lhs, const InputInfo& rhs) {
      return lhs.name == rhs.name;
    }
  };

  // Used in the conversion from node_defs_ to g_ to represent an edge from
  // the node named 'name' to node 'n'.
  struct EdgeInfo {
    explicit EdgeInfo(const string& name, int i1, Node* n, int i2)
        : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string src_name;
    int src_index;
    Node* dst_node;
    int dst_index;
  };
  std::vector<EdgeInfo> back_edges_;

  TF_DISALLOW_COPY_AND_ASSIGN(GraphConstructor);
};
390
391 // Implementation of GraphConstructor that does not take ownership of the
392 // input NodeDef messages and thus copies the nodes into the constructed Graph*.
393 //
394 // NOTE(mrry): Whenever possible, use NodeDefMovingGraphConstructor, which
395 // avoids copying each NodeDef into the constructed Graph*.
// Implementation of GraphConstructor that does not take ownership of the
// input NodeDef messages and thus copies the nodes into the constructed Graph*.
//
// NOTE(mrry): Whenever possible, use NodeDefMovingGraphConstructor, which
// avoids copying each NodeDef into the constructed Graph*.
class NodeDefCopyingGraphConstructor : public GraphConstructor {
 public:
  // All pointer arguments are borrowed and must outlive this object.
  NodeDefCopyingGraphConstructor(
      const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
      const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
      std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : GraphConstructor(opts, g, refiner, return_tensors, return_nodes,
                         missing_unused_input_map_keys),
        node_defs_(node_defs),
        versions_(versions),
        library_(library) {}

 private:
  size_t node_def_count() const override { return node_defs_.size(); }
  const NodeDef& get_node_def(int i) const override { return *node_defs_[i]; }
  // "Consuming" here still copies, since the NodeDefs are not owned.
  NodeDef consume_node_def(int i) override { return *node_defs_[i]; }
  const VersionDef* versions() const override { return versions_; }
  const FunctionDefLibrary* library() const override { return library_; }

  const NodeDefSlice node_defs_;
  const VersionDef* const versions_;      // may be nullptr
  const FunctionDefLibrary* const library_;  // may be nullptr
};
421
422 // Implementation of GraphConstructor that takes ownership of the input
423 // GraphDef, and can perform destructive reads.
// Implementation of GraphConstructor that takes ownership of the input
// GraphDef, and can perform destructive reads.
class NodeDefMovingGraphConstructor : public GraphConstructor {
 public:
  // NOTE: `is_consumed_` is sized from `graph_def_` (not the moved-from
  // `graph_def` parameter); this is safe because `graph_def_` is declared
  // before `is_consumed_` and members are initialized in declaration order.
  NodeDefMovingGraphConstructor(
      const Options& opts, GraphDef&& graph_def, Graph* g,
      ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : GraphConstructor(opts, g, refiner, return_tensors, return_nodes,
                         missing_unused_input_map_keys),
        graph_def_(std::move(graph_def)),
        is_consumed_(graph_def_.node_size(), false) {}

 private:
  size_t node_def_count() const override { return graph_def_.node().size(); }
  const NodeDef& get_node_def(int i) const override {
    // Reading a node after it has been moved out of graph_def_ would return
    // a moved-from (empty) proto, so crash loudly instead.
    CHECK(!is_consumed_[i])
        << "NodeDef " << i << " accessed after it was consumed.";
    return graph_def_.node(i);
  }
  NodeDef consume_node_def(int i) override {
    CHECK(!is_consumed_[i]) << "NodeDef " << i << " consumed twice.";
    is_consumed_[i] = true;
    return std::move(*graph_def_.mutable_node(i));
  }
  const VersionDef* versions() const override { return &graph_def_.versions(); }
  const FunctionDefLibrary* library() const override {
    return &graph_def_.library();
  }

  GraphDef graph_def_;
  // Tracks which NodeDefs have been destructively read.
  std::vector<bool> is_consumed_;
};
456
ForwardCompatibilityWindowPassed(const VersionDef & versions)457 bool ForwardCompatibilityWindowPassed(const VersionDef& versions) {
458 // TF_GRAPH_DEF_VERSION is incremented daily.
459 // TF has a 3 week forward compatibility guarantee.
460 return (versions.producer() - TF_GRAPH_DEF_VERSION) > 21;
461 }
462
// Returns `import_status` unchanged, unless `versions` shows the GraphDef was
// produced outside the forward compatibility window, in which case the status
// message is augmented with a version-mismatch explanation (same error code).
Status MaybeAppendVersionWarning(const Status& import_status,
                                 const VersionDef* versions);

Status MaybeAppendVersionWarning(const VersionDef* versions,
                                 const Status& import_status) {
  // Nothing to add when version info is absent or within the window.
  if (versions == nullptr || !ForwardCompatibilityWindowPassed(*versions)) {
    return import_status;
  }
  const string augmented_message = absl::StrCat(
      "Converting GraphDef to Graph has failed with an error: '",
      import_status.error_message(),
      "' The binary trying to import the GraphDef was built when "
      "GraphDef version was ",
      TF_GRAPH_DEF_VERSION,
      ". The GraphDef was produced by a binary built when GraphDef "
      "version was ",
      versions->producer(),
      ". The difference between these versions is larger than "
      "TensorFlow's forward compatibility guarantee, and might be the "
      "root cause for failing to import the GraphDef.");
  return Status(import_status.code(), augmented_message);
}
483
Construct(const Options & opts,NodeDefSlice node_defs,const VersionDef * versions,const FunctionDefLibrary * library,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)484 /* static */ Status GraphConstructor::Construct(
485 const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
486 const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
487 std::vector<std::pair<Node*, int>>* return_tensors,
488 std::vector<Node*>* return_nodes,
489 std::vector<SafeTensorId>* missing_unused_input_map_keys) {
490 if (versions) {
491 TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
492 TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
493 "GraphDef", "graph"));
494 }
495 NodeDefCopyingGraphConstructor c(opts, node_defs, versions, library, g,
496 refiner, return_tensors, return_nodes,
497 missing_unused_input_map_keys);
498 Status s = c.TryImport();
499 if (!s.ok()) {
500 c.Undo();
501 s = MaybeAppendVersionWarning(versions, s);
502 }
503 return s;
504 }
505
Construct(const Options & opts,GraphDef && graph_def,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)506 /* static */ Status GraphConstructor::Construct(
507 const Options& opts, GraphDef&& graph_def, Graph* g, ShapeRefiner* refiner,
508 std::vector<std::pair<Node*, int>>* return_tensors,
509 std::vector<Node*>* return_nodes,
510 std::vector<SafeTensorId>* missing_unused_input_map_keys) {
511 TF_RETURN_IF_ERROR(CheckVersions(graph_def.versions(), TF_GRAPH_DEF_VERSION,
512 TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
513 "GraphDef", "graph"));
514 VersionDef version_def = graph_def.versions();
515 NodeDefMovingGraphConstructor c(opts, std::move(graph_def), g, refiner,
516 return_tensors, return_nodes,
517 missing_unused_input_map_keys);
518 Status s = c.TryImport();
519 if (!s.ok()) {
520 c.Undo();
521 s = MaybeAppendVersionWarning(&version_def, s);
522 }
523 return s;
524 }
525
UpdatePendingCountAndReady(int processed,bool is_next_iteration)526 void GraphConstructor::UpdatePendingCountAndReady(int processed,
527 bool is_next_iteration) {
528 for (size_t i = 0; i < outputs_[processed].size(); ++i) {
529 const int output = outputs_[processed][i];
530 // We didn't consider NextIteration->Merge edges when computing
531 // pending_counts_ so we should not have to consider it here either.
532 bool is_next_iteration_to_merge_edge =
533 is_next_iteration && merge_node_indices_.count(output) == 1;
534 if (!is_next_iteration_to_merge_edge) {
535 int* current_pending_count = &pending_count_[output];
536 CHECK_GT(*current_pending_count, 0);
537 (*current_pending_count)--;
538 if (*current_pending_count == 0) {
539 ready_.insert(output);
540 }
541 }
542 }
543 }
544
545 // This could be expensive but we don't expect to call it often, if at all (only
546 // if there are multiple nodes in g_ with the same name)
NodeNameInValues(const std::map<TensorId,TensorId> & input_map,const StringPiece & node_name)547 bool NodeNameInValues(const std::map<TensorId, TensorId>& input_map,
548 const StringPiece& node_name) {
549 for (auto iter = input_map.begin(); iter != input_map.end(); ++iter) {
550 if (iter->second.first == node_name) return true;
551 }
552 return false;
553 }
554
NodeNameInValues(const std::vector<string> & control_dependencies,const StringPiece & node_name)555 bool NodeNameInValues(const std::vector<string>& control_dependencies,
556 const StringPiece& node_name) {
557 return std::find(control_dependencies.begin(), control_dependencies.end(),
558 node_name) != control_dependencies.end();
559 }
560
561 // Adds any prefixes of `node_name` (not including the full name itself) to
562 // `prefixes`.
AddPrefixes(StringPiece node_name,absl::flat_hash_set<StringPiece> * prefixes)563 void AddPrefixes(StringPiece node_name,
564 absl::flat_hash_set<StringPiece>* prefixes) {
565 size_t idx = -1;
566 while ((idx = node_name.find('/', idx + 1)) != StringPiece::npos) {
567 prefixes->insert(node_name.substr(0, idx));
568 }
569 }
570
// Indexes the names and prefixes already present in g_, then verifies that the
// import can proceed: imported names must not collide with existing ones
// (unless a prefix or uniquification will be applied), and a supplied prefix
// must itself produce valid, unambiguous names.
Status GraphConstructor::EnsureNoNameCollisions() {
  existing_nodes_.reserve(g_->num_nodes());
  // Populate existing_nodes_ and existing_prefixes_.
  for (Node* n : g_->nodes()) {
    bool already_exists = !existing_nodes_.insert({n->name(), n}).second;
    if (already_exists) {
      // Duplicate names in g_ are tolerated unless an input_map entry or a
      // control dependency refers to the now-ambiguous name.
      if (NodeNameInValues(opts_.input_map, n->name())) {
        return errors::InvalidArgument(
            "cannot resolve input_map because multiple nodes exist with name '",
            n->name(), "'");
      }
      if (NodeNameInValues(opts_.control_dependencies, n->name())) {
        return errors::InvalidArgument(
            "cannot resolve control_dependencies because multiple nodes exist "
            "with name '",
            n->name(), "'");
      }
    }
    AddPrefixes(n->name(), &existing_prefixes_);
  }
  if (prefix_.empty() && opts_.importing && !opts_.uniquify_names) {
    // No prefix and no uniquification: every imported name must be brand new.
    for (size_t i = 0; i < node_def_count(); ++i) {
      const string& name = get_node_def(i).name();
      if (NameExistsInGraph(name)) {
        return errors::InvalidArgument("Node name '", name,
                                       "' already exists in the Graph");
      }
    }
  } else if (!prefix_.empty()) {
    // prefix_ is guaranteed by Options to end with '/'; strip it for checks.
    StringPiece prefix_no_slash(prefix_);
    prefix_no_slash.remove_suffix(1);
    if (!IsValidNodeName(prefix_no_slash, false)) {
      return errors::InvalidArgument("Imported node name prefix '", prefix_,
                                     "' would lead to invalid node names");
    }
    if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) {
      prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
    }
  }
  return Status::OK();
}
612
ValidateInputMapAndControlDependencies()613 Status GraphConstructor::ValidateInputMapAndControlDependencies() {
614 for (const auto& mapping : opts_.input_map) {
615 TensorId src = mapping.first;
616 TensorId dst = mapping.second;
617 if (existing_nodes_.count(dst.first) == 0) {
618 return errors::InvalidArgument(
619 "node '", dst.first, "' in input_map does not exist in graph ",
620 "(input_map entry: ", src.ToString(), "->", dst.ToString(), ")");
621 }
622 if ((src.second == Graph::kControlSlot) !=
623 (dst.second == Graph::kControlSlot)) {
624 return errors::InvalidArgument("input_map entry ", src.ToString(), "->",
625 dst.ToString(), " between ",
626 "control edge and non-control edge");
627 }
628 }
629 for (const string& node : opts_.control_dependencies) {
630 if (existing_nodes_.count(node) == 0) {
631 return errors::InvalidArgument(
632 "node '", node,
633 "' in control_dependencies does not exist in "
634 "graph");
635 }
636 }
637 return Status::OK();
638 }
639
// Validates each NodeDef (name syntax, uniqueness, non-empty op, device when
// required, control inputs ordered last) and populates gdef_nodes_,
// gdef_prefixes_ and merge_node_indices_.
Status GraphConstructor::BuildNodeIndex() {
  // Validate the node names and add them to gdef_nodes_ and gdef_prefixes_.
  for (int n = 0; n < node_def_count(); ++n) {
    const NodeDef& node_def = get_node_def(n);
    if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "': Node name contains invalid characters");
    }
    if (!gdef_nodes_.insert(std::make_pair(node_def.name(), NodeInfo(n)))
             .second) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' is not unique");
    }
    // Validate the operation's type.
    if (node_def.op().empty()) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' does not specify an operation");
    }
    if (opts_.expect_device_spec && node_def.device().empty()) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' is missing a device specification");
    }
    // Merge nodes are the destinations of while-loop back edges; record them
    // for the pending-count bookkeeping in InitFromEdges().
    if (IsMerge(node_def)) {
      merge_node_indices_.insert(n);
    }
    // Validate control edges at end; control inputs start with '^' and must
    // come after all regular (data) inputs.
    bool in_control_dependence = false;
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      if (!input_name.empty() && absl::StartsWith(input_name, "^")) {
        in_control_dependence = true;
      } else if (in_control_dependence) {
        return errors::InvalidArgument(
            "Node '", node_def.name(),
            "': Control dependencies must come after regular dependencies");
      }
    }
    // Update gdef_prefixes_.
    AddPrefixes(node_def.name(), &gdef_prefixes_);
  }
  return Status::OK();
}
683
// Builds the dataflow bookkeeping used by Convert(): for every NodeDef, the
// number of not-yet-converted inputs (pending_count_), the reverse edge list
// (outputs_), and the initial set of nodes with no unconverted inputs
// (ready_). Merge nodes fed by a NextIteration node get a reduced pending
// count so that while-loop cycles do not deadlock the conversion.
Status GraphConstructor::InitFromEdges() {
  const int num_nodes = node_def_count();
  pending_count_.reserve(num_nodes);
  outputs_.resize(num_nodes);
  // First pass: collect the names of all NextIteration nodes so back edges
  // can be recognized below.
  gtl::FlatSet<string> next_iteration_nodes;
  for (int n = 0; n < node_def_count(); ++n) {
    const NodeDef& node_def = get_node_def(n);
    if (IsNextIteration(node_def)) {
      next_iteration_nodes.insert(node_def.name());
    }
  }

  // Parse the inputs for each node.
  for (int n = 0; n < num_nodes; ++n) {
    const NodeDef& node_def = get_node_def(n);
    int pending_count = node_def.input_size();
    if (IsMerge(node_def)) {
      // Cycles in the graph are only allowed for while loops. A while loop is
      // identified by an edge from a NextIteration node to a Merge node. For
      // such Merge nodes, only wait for one non-control input before
      // considering the node ready to process in Convert().
      int32_t num_control_edges = 0;
      bool has_loop_back_edge = false;
      for (int i = 0; i < node_def.input_size(); ++i) {
        StringPiece input_name(node_def.input(i));
        if (absl::StartsWith(input_name, "^")) {
          num_control_edges++;
        } else {
          TensorId id(ParseTensorName(input_name));
          if (next_iteration_nodes.find(string(id.first)) !=
              next_iteration_nodes.end()) {
            has_loop_back_edge = true;
          }
        }
      }
      if (has_loop_back_edge) {
        // Wait for all control inputs plus a single data input.
        pending_count = num_control_edges + 1;
      }
    }
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      TensorId id(ParseTensorName(input_name));
      if (opts_.input_map.count(id) == 0) {
        // If an input is not mapped, then the input should appear in the graph
        // being imported.
        auto iter = gdef_nodes_.find(id.first);
        if (iter == gdef_nodes_.end()) {
          return errors::InvalidArgument("Node '", node_def.name(),
                                         "': Unknown input node '",
                                         node_def.input(i), "'");
        }
        outputs_[iter->second.gdef_index].push_back(n);
      } else {
        // This input is mapped to an existing edge. Therefore this input is
        // as good as being already processed.
        --pending_count;
        DCHECK_GE(pending_count, 0);
      }
    }
    if (pending_count == 0) {
      ready_.insert(n);
    }
    pending_count_.push_back(pending_count);
  }
  return Status::OK();
}
750
ValidateColocationConstraints(const NodeDef & node_def)751 Status GraphConstructor::ValidateColocationConstraints(
752 const NodeDef& node_def) {
753 if (!opts_.validate_colocation_constraints || !opts_.importing)
754 return Status::OK();
755 const auto iter = node_def.attr().find(kColocationAttrName);
756 if (iter == node_def.attr().end()) return Status::OK();
757 for (const string& c : iter->second.list().s()) {
758 StringPiece s(c);
759 if (absl::ConsumePrefix(&s, kColocationGroupPrefix) &&
760 gdef_nodes_.find(s) == gdef_nodes_.end()) {
761 return errors::InvalidArgument(
762 "Node '", node_def.name(),
763 "' expects to be colocated with unknown node '", s, "'");
764 }
765 }
766 return Status::OK();
767 }
768
MakeNode(NodeDef && node_def,Node ** node)769 Status GraphConstructor::MakeNode(NodeDef&& node_def, Node** node) {
770 // Add the node to the graph.
771 Status status;
772 *node = g_->AddNode(std::move(node_def), &status);
773 if (!status.ok()) return status;
774 if (opts_.expect_device_spec) {
775 (*node)->set_assigned_device_name((*node)->def().device());
776 }
777 return Status::OK();
778 }
779
ValidateShape(Node * node)780 Status GraphConstructor::ValidateShape(Node* node) {
781 if (!opts_.importing || !opts_.validate_shape) return Status::OK();
782 TF_RETURN_IF_ERROR(refiner_->AddNode(node));
783 // For nodes with the _output_shapes attribute, override the shape.
784 std::vector<const TensorShapeProto*> shape_attrs;
785 const char* kAttrName = "_output_shapes";
786 if (!TryGetNodeAttr(node->attrs(), kAttrName, &shape_attrs)) {
787 // No _output_shapes attribute, the AddNode call above was sufficient.
788 return Status::OK();
789 }
790 auto* ic = refiner_->GetContext(node);
791 DCHECK(ic != nullptr)
792 << "ShapeRefiner::AddNode() should have created the InferenceContext";
793 if (shape_attrs.size() < node->num_outputs()) {
794 return errors::InvalidArgument(
795 "Node '", node->name(), "' has ", node->num_outputs(),
796 " outputs but the ", kAttrName, " attribute specifies shapes for ",
797 shape_attrs.size(), " outputs");
798 }
799 // NOTE(skyewm): we don't raise an error here because some users depend on
800 // this behavior, even though it's unsafe.
801 // TODO(b/74619486): raise an error.
802 if (shape_attrs.size() > node->num_outputs()) {
803 LOG(WARNING) << "Node '" << node->name() << "' has " << node->num_outputs()
804 << " outputs but the " << kAttrName
805 << " attribute specifies shapes for " << shape_attrs.size()
806 << " outputs. Output shapes may be inaccurate.";
807 }
808 for (int i = 0; i < node->num_outputs(); ++i) {
809 const TensorShapeProto& p = *shape_attrs[i];
810 shape_inference::ShapeHandle h;
811 Status s = ic->MakeShapeFromShapeProto(p, &h);
812 if (!s.ok()) {
813 return errors::InvalidArgument("Node '", node->name(), " has an invalid ",
814 kAttrName, " attribute (shape #", i,
815 " error:'", s.error_message(), "'");
816 }
817 s = refiner_->SetShape(node, i, h);
818 if (!s.ok()) {
819 return errors::InvalidArgument(
820 "Node '", node->name(), "' has an ", kAttrName,
821 " attribute inconsistent with the GraphDef for output #", i, ": ",
822 s.error_message());
823 }
824 }
825 node->ClearAttr(kAttrName);
826 return Status::OK();
827 }
828
ModifyNodeDefForImport(NodeDef * node_def)829 Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) {
830 const OpDef* op_def;
831 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
832 AddDefaultsToNodeDef(*op_def, node_def);
833 TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def));
834 if (versions()) {
835 TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions()->producer()));
836 }
837 return Status::OK();
838 }
839
RemoveInputs(const std::vector<int> & inputs_to_remove,NodeDef * node_def,std::vector<bool> * input_already_exists)840 void RemoveInputs(const std::vector<int>& inputs_to_remove, NodeDef* node_def,
841 std::vector<bool>* input_already_exists) {
842 // Remove 'inputs_to_remove' from 'node_def'
843 NodeDef copy;
844 copy.mutable_input()->Reserve(node_def->input_size() -
845 inputs_to_remove.size());
846 for (int i = 0, j = 0; i < node_def->input_size(); ++i) {
847 if (j < inputs_to_remove.size() && i == inputs_to_remove[j]) {
848 ++j;
849 } else {
850 copy.add_input()->swap(*node_def->mutable_input(i));
851 }
852 }
853 node_def->mutable_input()->Swap(copy.mutable_input());
854 // Remove 'inputs_to_remove' from 'input_already_exists'
855 for (int idx : inputs_to_remove) {
856 input_already_exists->erase(input_already_exists->begin() + idx);
857 }
858 DCHECK_EQ(input_already_exists->size(), node_def->input_size());
859 }
860
// Rewrites the inputs of `node_def` according to opts_.input_map: any input
// that matches a map key is replaced by the mapped tensor (which refers to a
// node already in g_), and the corresponding entry of `input_already_exists`
// is set to true. Duplicate control inputs produced by the remapping are
// removed, which can shrink both `node_def`'s input list and
// `input_already_exists` (the two are kept parallel). Used map keys are
// recorded in used_input_map_keys_ for later missing/unused reporting.
void GraphConstructor::RemapNodeDefInputs(
    NodeDef* node_def, std::vector<bool>* input_already_exists) {
  DCHECK_EQ(input_already_exists->size(), node_def->input_size());
  // Control inputs already produced by remapping; used to detect duplicates.
  std::set<TensorId> control_inputs;
  // Indices (ascending) of inputs that became duplicate control inputs.
  std::vector<int> inputs_to_remove;

  for (int i = 0; i < node_def->input_size(); ++i) {
    auto iter = opts_.input_map.find(ParseTensorName(node_def->input(i)));
    if (iter == opts_.input_map.end()) continue;
    used_input_map_keys_.insert(iter->first);

    TensorId new_input = iter->second;
    if (new_input.second == Graph::kControlSlot) {
      // Check if we've already remapped a different input to new_input, and if
      // so remove this input.
      if (control_inputs.count(new_input) > 0) {
        inputs_to_remove.push_back(i);
        continue;
      }
      control_inputs.insert(new_input);
    }
    node_def->set_input(i, new_input.ToString());
    (*input_already_exists)[i] = true;
  }
  if (!inputs_to_remove.empty()) {
    RemoveInputs(inputs_to_remove, node_def, input_already_exists);
  }
}
889
// Appends the control dependencies from opts_.control_dependencies to
// `node_def`, keeping `input_already_exists` parallel with the input list.
// Dependencies are only added to "root" imported nodes — nodes whose inputs
// are all either remapped to preexisting nodes or backedges — since any other
// imported node inherits the dependencies transitively through its inputs.
void GraphConstructor::AddControlDependencies(
    NodeDef* node_def, std::vector<bool>* input_already_exists) {
  // To avoid adding redundant control dependencies to every imported node, skip
  // nodes that will inherit the dependencies from another imported node.
  bool inherits_deps = false;
  for (int i = 0; i < node_def->input_size(); ++i) {
    // Assume we won't inherit dependencies from remapped inputs that already
    // exist in the graph. Even if we're wrong, we'll only add redundant
    // dependencies.
    if ((*input_already_exists)[i]) continue;

    // If this input is a backedge, assume we won't inherit the dependencies.
    // TODO(skyewm): we have many redundant ParseTensorName calls. It could be
    // worth optimizing these.
    TensorId id(ParseTensorName(node_def->input(i)));
    auto iter = gdef_nodes_.find(id.first);
    DCHECK(iter != gdef_nodes_.end()) << id.first;
    if (iter->second.node == nullptr) {
      // Input hasn't been created yet, indicating it's a backedge.
      continue;
    }
    // This input is another imported node, so the dependencies flow through
    // it and need not be duplicated here.
    inherits_deps = true;
  }
  if (inherits_deps) return;

  // node_def either has no inputs or all remapped inputs, add the control
  // dependencies
  for (const string& control_dep : opts_.control_dependencies) {
    string input = TensorId(control_dep, Graph::kControlSlot).ToString();
    bool found = false;
    // NodeDef inputs keep control inputs at the end; scan backwards over
    // that suffix to see whether the dependency is already present.
    for (int i = node_def->input_size() - 1; i >= 0; --i) {
      const string& node_input = node_def->input(i);
      if (node_input[0] != '^') {
        // Control inputs are at the end. Break when we reach the non-control
        // inputs.
        break;
      }
      if (node_input == input) {
        // Control dependency already exists
        found = true;
        break;
      }
    }
    if (found) {
      continue;
    }
    node_def->add_input(input);
    // The dependency target is a preexisting node in g_, so mark the new
    // input as already existing.
    input_already_exists->push_back(true);
  }
}
940
AddPrefixToNodeDef(const std::vector<bool> & input_already_exists,NodeDef * node_def)941 void GraphConstructor::AddPrefixToNodeDef(
942 const std::vector<bool>& input_already_exists, NodeDef* node_def) {
943 if (prefix_.empty()) return;
944 node_def->set_name(strings::StrCat(prefix_, node_def->name()));
945 // Update names of input nodes
946 for (int i = 0; i < node_def->input_size(); ++i) {
947 // Skip remapped inputs (which already exist in g_ and are not being
948 // imported).
949 if (input_already_exists[i]) continue;
950 StringPiece input(node_def->input(i));
951 if (absl::ConsumePrefix(&input, "^")) {
952 node_def->set_input(i, strings::StrCat("^", prefix_, input));
953 } else {
954 node_def->set_input(i, strings::StrCat(prefix_, input));
955 }
956 }
957 // Update names of colocation groups
958 if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) {
959 auto* list =
960 node_def->mutable_attr()->at(kColocationAttrName).mutable_list();
961 for (int i = 0; i < list->s_size(); ++i) {
962 StringPiece v(list->s(i));
963 if (absl::ConsumePrefix(&v, kColocationGroupPrefix)) {
964 list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v));
965 }
966 }
967 }
968 }
969
UniquifyNames(const std::vector<bool> & input_already_exists,NodeDef * node_def)970 void GraphConstructor::UniquifyNames(
971 const std::vector<bool>& input_already_exists, NodeDef* node_def) {
972 if (NameExistsInGraph(node_def->name())) {
973 string old_name = node_def->name();
974 node_def->set_name(FindUniqueName(node_def->name()));
975 uniquified_names_[old_name] = node_def->name();
976 // Note that we don't have to update gdef_nodes_ or gdef_prefixes_ with
977 // `name` because we guarantee the original NodeDef names are unique,
978 // meaning we won't generate this name again.
979 }
980 for (int i = 0; i < node_def->input_size(); ++i) {
981 // Skip remapped inputs (which already exist in g_ and are not being
982 // imported).
983 if (input_already_exists[i]) continue;
984 TensorId id = ParseTensorName(node_def->input(i));
985 // We require that UniquifyNames() is called on all NodeDefs in topological
986 // order. This guarantees that node_def's inputs will already be uniquified
987 // if necessary.
988 auto iter = uniquified_names_.find(string(id.first));
989 if (iter == uniquified_names_.end()) continue;
990 id.first = iter->second;
991 node_def->set_input(i, id.ToString());
992 }
993 }
994
UpdateUniquifiedColocationNames()995 void GraphConstructor::UpdateUniquifiedColocationNames() {
996 for (const auto& pair : gdef_nodes_) {
997 Node* node = pair.second.node;
998 if (node == nullptr) continue;
999 std::vector<string> coloc_values;
1000 if (!TryGetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values))
1001 continue;
1002 bool updated = false;
1003 for (size_t i = 0; i < coloc_values.size(); ++i) {
1004 StringPiece val(coloc_values[i]);
1005 if (absl::ConsumePrefix(&val, kColocationGroupPrefix)) {
1006 auto name_pair = uniquified_names_.find(string(val));
1007 if (name_pair == uniquified_names_.end()) continue;
1008 updated = true;
1009 coloc_values[i] =
1010 strings::StrCat(kColocationGroupPrefix, name_pair->second);
1011 }
1012 }
1013 if (updated) {
1014 node->AddAttr(kColocationAttrName, std::move(coloc_values));
1015 }
1016 }
1017 }
1018
NameExistsInGraph(StringPiece name)1019 bool GraphConstructor::NameExistsInGraph(StringPiece name) {
1020 if (existing_nodes_.find(name) != existing_nodes_.end()) return true;
1021 if (existing_prefixes_.find(name) != existing_prefixes_.end()) return true;
1022 return false;
1023 }
1024
NameExistsInGraphDef(StringPiece name)1025 bool GraphConstructor::NameExistsInGraphDef(StringPiece name) {
1026 if (gdef_nodes_.find(name) != gdef_nodes_.end()) return true;
1027 if (gdef_prefixes_.find(name) != gdef_prefixes_.end()) return true;
1028 return false;
1029 }
1030
FindUniqueName(StringPiece original_name)1031 string GraphConstructor::FindUniqueName(StringPiece original_name) {
1032 string name(original_name);
1033 int count = 0;
1034 // Check that any generated names don't collide with imported NodeDefs (as
1035 // well as nodes in g_).
1036 while (NameExistsInGraph(name) || (count > 0 && NameExistsInGraphDef(name))) {
1037 name = strings::StrCat(original_name, "_", ++count);
1038 }
1039 return name;
1040 }
1041
IsNodeFullyMapped(const NodeDef & node_def,bool * is_node_mapped)1042 Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def,
1043 bool* is_node_mapped) {
1044 const OpDef* op_def;
1045 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
1046 for (int i = 0; i < op_def->output_arg_size(); ++i) {
1047 if (opts_.input_map.find({node_def.name(), i}) == opts_.input_map.end()) {
1048 *is_node_mapped = false;
1049 return Status::OK();
1050 }
1051 }
1052 *is_node_mapped = true;
1053 return Status::OK();
1054 }
1055
// Depth-first traversal over outputs_ used only for diagnostics: when a node
// on the current DFS branch is reached again, the nodes forming the cycle are
// logged. `cur_branch` holds the branch in traversal order, `is_on_cur_branch`
// is its O(1) membership bitmap, and `unvisited` is the set of nodes not yet
// fully explored (nodes are erased when their subtree is done).
void GraphConstructor::DFS(int cur_node, std::vector<int>* cur_branch,
                           std::vector<bool>* is_on_cur_branch,
                           absl::flat_hash_set<int>* unvisited) {
  // Push cur_node onto the current branch.
  cur_branch->push_back(cur_node);
  is_on_cur_branch->at(cur_node) = true;
  for (auto next_node : outputs_[cur_node]) {
    if (unvisited->find(next_node) != unvisited->end()) {
      if (is_on_cur_branch->at(next_node)) {
        // next_node is an ancestor on the current branch: the segment of
        // cur_branch from next_node to the end is a cycle — log it.
        auto iter =
            std::find(cur_branch->begin(), cur_branch->end(), next_node);
        LOG(WARNING) << "Cycle detected:";
        while (iter != cur_branch->end()) {
          LOG(WARNING) << SummarizeNodeDef(get_node_def(*iter));
          ++iter;
        }
        LOG(WARNING) << "End of cycle";
      } else {
        DFS(next_node, cur_branch, is_on_cur_branch, unvisited);
      }
    }
  }
  // Pop cur_node: its subtree is fully explored.
  cur_branch->pop_back();
  is_on_cur_branch->at(cur_node) = false;
  unvisited->erase(cur_node);
}
1081
PrintCycles()1082 void GraphConstructor::PrintCycles() {
1083 int num_nodes = outputs_.size();
1084 absl::flat_hash_set<int> unvisited;
1085 for (int i = 0; i < num_nodes; i++) {
1086 unvisited.insert(i);
1087 }
1088 while (!unvisited.empty()) {
1089 int cur_node = *unvisited.begin();
1090 // Nodes on the current branch of DFS in traversal order. This is used for
1091 // printing the nodes in the cycle.
1092 std::vector<int> cur_branch;
1093 // This is just to make lookups O(1).
1094 // is_on_cur_branch[i] ==
1095 // (std::find(cur_branch.start(),
1096 // cur_branch.end(), i) != cur_branch.end())
1097 std::vector<bool> is_on_cur_branch(num_nodes, false);
1098 DFS(cur_node, &cur_branch, &is_on_cur_branch, &unvisited);
1099 }
1100 }
1101
// Core of the import: processes the NodeDefs in topological order, creating
// one Node in g_ per (non-skipped) NodeDef and wiring its input edges.
// Backedges (data inputs from not-yet-created Merge/NextIteration producers)
// are deferred to AddBackEdges(). Fails if any NodeDefs remain unprocessed,
// which indicates a cycle not attributable to a while loop.
Status GraphConstructor::Convert() {
  // Import functions before adding nodes, since imported nodes may refer to
  // functions
  if (library()) {
    // TODO(b/135705010): Add rvalue overloads into the function library, to
    // avoid unnecessarily copying `*library()` here.
    TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library()));
  }

  std::vector<InputInfo> inputs;
  int processed = 0;

  std::vector<bool> input_already_exists;

  // Process the NodeDefs in topological order.
  // (InitFromEdges() sets this up by filling in ready_ with nodes that have no
  // inputs, pending_counts_ with the number of inputs for each node and
  // outputs_ with the outputs of each node).
  while (!ready_.empty()) {
    int o = *ready_.begin();
    ready_.erase(ready_.begin());
    ++processed;
    inputs.clear();
    bool has_data_back_edge = false;

    NodeDef node_def = consume_node_def(o);

    // input_already_exists[i] is true iff the i-th input of the node we're
    // importing refers to a preexisting node in g_ (i.e. input[i] existed prior
    // to importing node_defs_). Conversely, input_already_exists[i] is false
    // iff the input refers to a node in node_defs_.
    input_already_exists.clear();
    input_already_exists.resize(node_def.input_size(), false);

    // Save the original name: gdef_nodes_ is keyed by pre-prefix/uniquify
    // names, while node_def.name() may be rewritten below.
    std::string node_name = node_def.name();

    if (opts_.importing) {
      if (opts_.skip_mapped_nodes) {
        bool is_node_mapped = false;
        TF_RETURN_IF_ERROR(IsNodeFullyMapped(node_def, &is_node_mapped));
        if (is_node_mapped) {
          // Skip this node after updating pending_count_ for outputs
          UpdatePendingCountAndReady(o, IsNextIteration(node_def));
          continue;
        }
      }

      if (!opts_.input_map.empty()) {
        // Note that input_already_exists can shrink here
        RemapNodeDefInputs(&node_def, &input_already_exists);
      }
      if (!opts_.control_dependencies.empty()) {
        // Note that input_already_exists can grow here
        AddControlDependencies(&node_def, &input_already_exists);
      }
      if (!opts_.default_device.empty() && node_def.device().empty()) {
        node_def.set_device(opts_.default_device);
      }
    }

    DCHECK_EQ(node_def.input_size(), input_already_exists.size());
    TF_RETURN_IF_ERROR(ValidateColocationConstraints(node_def));
    // Resolve each input to its producing Node (or record a backedge when
    // the producer hasn't been created yet).
    for (int i = 0; i < node_def.input_size(); ++i) {
      TensorId tensor_id = ParseTensorName(node_def.input(i));
      Node* src_node;
      int src_index;

      if (!input_already_exists[i]) {
        // Locate input in newly-imported nodes
        auto iter = gdef_nodes_.find(tensor_id.node());
        DCHECK(iter != gdef_nodes_.end()) << tensor_id.node();
        src_node = iter->second.node;
        src_index = tensor_id.index();
        if (src_node == nullptr) has_data_back_edge = true;
      } else {
        // Input refers to preexistng node in graph
        auto iter = existing_nodes_.find(tensor_id.node());
        DCHECK(iter != existing_nodes_.end()) << tensor_id.node();
        src_node = iter->second;
        src_index = tensor_id.index();
      }

      if (src_node != nullptr && src_index >= src_node->num_outputs()) {
        std::ostringstream out;
        out << "Node '" << node_def.name() << "': Connecting to invalid output "
            << tensor_id.index() << " of source node " << tensor_id.node()
            << " which has " << src_node->num_outputs() << " outputs.";

        if (src_node->type_string() == "If" ||
            src_node->type_string() == "StatelessIf" ||
            src_node->type_string() == "While" ||
            src_node->type_string() == "StatelessWhile") {
          out << " Try using "
              << "tf.compat.v1.experimental.output_all_intermediates(True).";
        }
        return errors::InvalidArgument(out.str());
      }

      inputs.emplace_back(string(tensor_id.node()), src_node, src_index);
    }

    // Only Merge nodes may legitimately receive data backedges (while loops).
    if (has_data_back_edge && !IsMerge(node_def)) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "' had a back edge, but only Merge nodes can have back edges.");
    }

    Node* node;
    if (opts_.importing) {
      if (!prefix_.empty()) {
        AddPrefixToNodeDef(input_already_exists, &node_def);
      }
      // Note: no need to uniquify names if the prefix already guarantees
      // uniqueness
      if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) {
        UniquifyNames(input_already_exists, &node_def);
      }
    }

    if (opts_.importing) {
      TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&node_def));
    } else {
      // Non-import path: defaulting/validation are individually optional.
      const OpDef* op_def;
      TF_RETURN_IF_ERROR(
          g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
      if (opts_.add_default_attributes) {
        AddDefaultsToNodeDef(*op_def, &node_def);
      }
      if (opts_.validate_nodes) {
        TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def));
      }
    }

    TF_RETURN_IF_ERROR(MakeNode(std::move(node_def), &node));

    gdef_nodes_[node_name].node = node;

    // Remove duplicate control inputs before adding edges to the graph. It
    // will allow us to skip expensive duplicates check in 'AddControlEdge'.
    auto first_control = absl::c_find_if(inputs, &InputInfo::IsControlInput);
    auto first_control_copy = first_control;
    std::sort(first_control, inputs.end(), &InputInfo::CompareName);
    inputs.erase(
        std::unique(first_control_copy, inputs.end(), &InputInfo::IsSameName),
        inputs.end());

    // Add edges from inputs to *node to the graph.
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (inputs[i].node == nullptr) {
        // Record this back edge, which will be added after all nodes
        // are created.
        back_edges_.emplace_back(inputs[i].name, inputs[i].index, node, i);
      } else if (inputs[i].index == Graph::kControlSlot) {
        g_->AddControlEdge(inputs[i].node, node, kDoNotCheckDuplicates);
      } else {
        TF_RETURN_IF_ERROR(MakeEdge(inputs[i].node, inputs[i].index, node, i));
      }
    }

    TF_RETURN_IF_ERROR(ValidateShape(node));

    // Update pending_count_ for outputs.
    UpdatePendingCountAndReady(o, node->IsNextIteration());
  }

  // Any NodeDef never reaching pending_count 0 is part of a cycle.
  if (processed < node_def_count()) {
    LOG(WARNING) << "IN " << __func__ << " " << (node_def_count() - processed)
                 << " NODES IN A CYCLE";
    for (int64_t i = 0; i < node_def_count(); i++) {
      if (pending_count_[i] != 0) {
        LOG(WARNING) << "PENDING: " << SummarizeNodeDef(get_node_def(i))
                     << " WITH PENDING COUNT = " << pending_count_[i];
      }
    }
    PrintCycles();
    return errors::InvalidArgument(node_def_count() - processed,
                                   " nodes in a cycle");
  }

  return Status::OK();
}
1283
AddBackEdges()1284 Status GraphConstructor::AddBackEdges() {
1285 // Add the back edges after all nodes are created.
1286 for (const auto& e : back_edges_) {
1287 Node* src_node = gdef_nodes_[e.src_name].node;
1288 if (e.src_index == Graph::kControlSlot) {
1289 g_->AddControlEdge(src_node, e.dst_node, kDoNotCheckDuplicates);
1290 } else {
1291 TF_RETURN_IF_ERROR(
1292 MakeEdge(src_node, e.src_index, e.dst_node, e.dst_index));
1293 }
1294
1295 VLOG(2) << "Add back edge: " << src_node->name() << " -> "
1296 << e.dst_node->name();
1297 }
1298 return Status::OK();
1299 }
1300
UpdateVersionDef()1301 Status GraphConstructor::UpdateVersionDef() {
1302 if (versions() == nullptr) return Status::OK();
1303
1304 if (!opts_.importing) {
1305 g_->set_versions(*versions());
1306 return Status::OK();
1307 }
1308 VersionDef g_versions = g_->versions();
1309 g_versions.set_producer(
1310 std::min(g_versions.producer(), versions()->producer()));
1311 g_versions.set_min_consumer(
1312 std::max(g_versions.min_consumer(), versions()->min_consumer()));
1313 if (versions()->bad_consumers_size() > 0) {
1314 std::set<int> bad(g_versions.bad_consumers().begin(),
1315 g_versions.bad_consumers().end());
1316 bad.insert(versions()->bad_consumers().begin(),
1317 versions()->bad_consumers().end());
1318 g_versions.clear_bad_consumers();
1319 for (int v : bad) {
1320 g_versions.add_bad_consumers(v);
1321 }
1322 }
1323 g_->set_versions(g_versions);
1324 return Status::OK();
1325 }
1326
PopulateReturnTensors()1327 Status GraphConstructor::PopulateReturnTensors() {
1328 if (opts_.return_tensors.empty()) return Status::OK();
1329 for (const TensorId& id : opts_.return_tensors) {
1330 auto iter = opts_.input_map.find(id);
1331 if (iter == opts_.input_map.end()) {
1332 // Locate id in imported nodes
1333 auto iter = gdef_nodes_.find(id.first);
1334 if (iter == gdef_nodes_.end()) {
1335 return errors::InvalidArgument("Requested return tensor '",
1336 id.ToString(),
1337 "' not found in graph def");
1338 }
1339 int num_outputs = iter->second.node->num_outputs();
1340 if ((id.second < 0 || id.second >= num_outputs) &&
1341 id.second != Graph::kControlSlot) {
1342 return errors::InvalidArgument("Invalid return output ", id.second,
1343 " of node '", id.first, "', which has ",
1344 num_outputs, " output(s)");
1345 }
1346 return_tensors_->push_back({iter->second.node, id.second});
1347 } else {
1348 // id was remapped to existing node
1349 TensorId remapped_id = iter->second;
1350 DCHECK_GT(existing_nodes_.count(remapped_id.first), 0);
1351 Node* node = existing_nodes_[remapped_id.first];
1352 return_tensors_->push_back({node, remapped_id.second});
1353 }
1354 }
1355 return Status::OK();
1356 }
1357
PopulateReturnNodes()1358 Status GraphConstructor::PopulateReturnNodes() {
1359 if (opts_.return_nodes.empty()) return Status::OK();
1360 for (StringPiece name : opts_.return_nodes) {
1361 auto iter = gdef_nodes_.find(name);
1362 if (iter == gdef_nodes_.end()) {
1363 return errors::InvalidArgument("Requested return node '", name,
1364 "' not found in graph def");
1365 }
1366 return_nodes_->push_back(iter->second.node);
1367 }
1368 return Status::OK();
1369 }
1370
PopulateMissingUnusedInputMapKeys()1371 Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
1372 if (missing_unused_input_map_keys_ == nullptr) return Status::OK();
1373 for (const auto& input_map_pair : opts_.input_map) {
1374 TensorId key = input_map_pair.first;
1375 if (used_input_map_keys_.count(key) > 0) continue;
1376
1377 auto pair = gdef_nodes_.find(key.first);
1378 if (pair == gdef_nodes_.end()) {
1379 // key's node doesn't exist in GraphDef
1380 missing_unused_input_map_keys_->push_back(key);
1381 continue;
1382 }
1383
1384 // Check that key's index is in bounds. Get the number of outputs from the
1385 // NodeDef, rather than the imported Node, since the Node may not exist if
1386 // opts_.skip_mapped_nodes is true.
1387 const NodeDef& node_def = get_node_def(pair->second.gdef_index);
1388 const OpDef* op_def;
1389 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
1390 int num_outputs;
1391 TF_RETURN_IF_ERROR(NumOutputsForNode(node_def, *op_def, &num_outputs));
1392 if (key.second >= num_outputs) {
1393 // key's index out of bounds
1394 missing_unused_input_map_keys_->push_back(key);
1395 }
1396 }
1397 return Status::OK();
1398 }
1399
Undo()1400 void GraphConstructor::Undo() {
1401 for (const auto& iter : gdef_nodes_) {
1402 if (iter.second.node != nullptr) {
1403 g_->RemoveNode(iter.second.node);
1404 }
1405 }
1406 g_->set_versions(original_versions_);
1407 }
1408
MakeEdge(Node * src,int output_index,Node * dst,int input_index)1409 Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
1410 int input_index) {
1411 if (output_index >= src->num_outputs()) {
1412 return errors::InvalidArgument(
1413 "Output ", output_index, " of node ", src->name(),
1414 " does not exist. Node only has ", src->num_outputs(), " outputs.");
1415 }
1416 if (input_index >= dst->num_inputs()) {
1417 return errors::InvalidArgument(
1418 "Input ", input_index, " of node ", dst->name(),
1419 " does not exist. Node only has ", dst->num_inputs(), " inputs.");
1420 }
1421
1422 DataType src_out = src->output_type(output_index);
1423 DataType dst_in = dst->input_type(input_index);
1424 if (!TypesCompatible(dst_in, src_out)) {
1425 return errors::InvalidArgument(
1426 "Input ", input_index, " of node ", dst->name(), " was passed ",
1427 DataTypeString(src_out), " from ", src->name(), ":", output_index,
1428 " incompatible with expected ", DataTypeString(dst_in), ".");
1429 }
1430 g_->AddEdge(src, output_index, dst, input_index);
1431 return Status::OK();
1432 }
1433
1434 } // namespace
1435
// Converts `gdef` (by copy) into `g` using a fresh ShapeRefiner pinned to the
// GraphDef's producer version. No return tensors/nodes or input-map
// diagnostics are requested.
Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
                              const GraphDef& gdef, Graph* g) {
  ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
  return GraphConstructor::Construct(
      opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
      /*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
      /*missing_unused_input_map_keys=*/nullptr);
}
1444
// Move overload: converts `gdef` into `g`, allowing the constructor to
// consume the GraphDef's NodeDefs instead of copying them. The refiner is
// created before the move so it can read the producer version.
Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
                              GraphDef&& gdef, Graph* g) {
  ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
  return GraphConstructor::Construct(opts, std::move(gdef), g, &refiner,
                                     /*return_tensors=*/nullptr,
                                     /*return_nodes=*/nullptr,
                                     /*missing_unused_input_map_keys=*/nullptr);
}
1453
ConvertNodeDefsToGraph(const GraphConstructorOptions & opts,gtl::ArraySlice<NodeDef> nodes,Graph * g)1454 Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
1455 gtl::ArraySlice<NodeDef> nodes, Graph* g) {
1456 ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry());
1457 // TODO(irving): Copy will go away once NodeInfo exists
1458 std::vector<const NodeDef*> node_defs;
1459 node_defs.reserve(nodes.size());
1460 for (const auto& n : nodes) {
1461 node_defs.push_back(&n);
1462 }
1463 return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
1464 &refiner, /*return_tensors=*/nullptr,
1465 /*return_nodes=*/nullptr,
1466 /*missing_unused_input_map_keys=*/nullptr);
1467 }
1468
// Public import entry point: validates the opts/results combination, picks or
// adjusts the ShapeRefiner, then delegates to GraphConstructor::Construct.
//
// Validation rules enforced up front:
//  - return_tensors/return_nodes require a non-null `results`;
//  - return_nodes is incompatible with skip_mapped_nodes;
//  - a provided `results` must arrive empty.
Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
                      Graph* g, ShapeRefiner* refiner,
                      ImportGraphDefResults* results) {
  if (!opts.return_tensors.empty()) {
    if (results == nullptr) {
      return errors::InvalidArgument(
          "results argument to ImportGraphDef() must be non-null if "
          "opts.return_tensors is non-empty");
    }
  }

  if (!opts.return_nodes.empty()) {
    if (opts.skip_mapped_nodes) {
      return errors::InvalidArgument(
          "Requesting return_nodes with skip_mapped_nodes set is not currently "
          "supported");
    }
    if (results == nullptr) {
      return errors::InvalidArgument(
          "results argument to ImportGraphDef() must be non-null if "
          "opts.return_nodes is non-empty");
    }
  }

  if (results != nullptr) {
    if (!results->return_tensors.empty() || !results->return_nodes.empty() ||
        !results->missing_unused_input_map_keys.empty()) {
      return errors::InvalidArgument(
          "All fields in results argument to ImportGraphDef() must be empty.");
    }
  }

  // Use a local refiner (at the incoming producer version) when the caller
  // did not supply one.
  ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
  if (refiner == nullptr) {
    refiner = &default_refiner;
  } else {
    // Log a warning if we are importing a GraphDef at an older
    // producer version after already having added non-source/sink
    // nodes to the graph in the past.
    if (gdef.versions().producer() > 0 &&
        gdef.versions().producer() < refiner->graph_def_version() &&
        g->num_nodes() > 2) {
      LOG(WARNING) << "Importing a graph with a lower producer version "
                   << gdef.versions().producer()
                   << " into an existing graph with producer version "
                   << refiner->graph_def_version() << ". Shape inference will "
                   << "have run different parts of the graph with different "
                   << "producer versions.";
    }
  }

  // Set the graph def version of the refiner as the min of the
  // current value and the version from the graph we are about to
  // import.
  //
  // Note: to match Run() semantics, we should re-run shape inference
  // on the entire graph if the producer version has changed. For now
  // we log the warning above.
  refiner->set_graph_def_version(
      std::min(refiner->graph_def_version(), gdef.versions().producer()));

  if (results == nullptr) {
    return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
                                       &gdef.library(), g, refiner, nullptr,
                                       nullptr, nullptr);
  } else {
    return GraphConstructor::Construct(
        opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
        &results->return_tensors, &results->return_nodes,
        &results->missing_unused_input_map_keys);
  }
}
1541
// Replaces the contents of `*dest` with a copy of `src`.
void CopyGraph(const Graph& src, Graph* dest) { dest->Copy(src); }
1543
1544 } // namespace tensorflow
1545