1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/graph_constructor.h"
17
18 #include <algorithm>
19 #include <set>
20 #include <string>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "tensorflow/core/common_runtime/shape_refiner.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/framework/function.pb.h"
30 #include "tensorflow/core/framework/graph.pb.h"
31 #include "tensorflow/core/framework/node_def.pb.h"
32 #include "tensorflow/core/framework/node_def_util.h"
33 #include "tensorflow/core/framework/tensor_shape.pb.h"
34 #include "tensorflow/core/framework/types.h"
35 #include "tensorflow/core/framework/versions.h"
36 #include "tensorflow/core/framework/versions.pb.h"
37 #include "tensorflow/core/graph/algorithm.h"
38 #include "tensorflow/core/graph/graph.h"
39 #include "tensorflow/core/graph/tensor_id.h"
40 #include "tensorflow/core/lib/core/errors.h"
41 #include "tensorflow/core/lib/gtl/flatmap.h"
42 #include "tensorflow/core/lib/gtl/flatset.h"
43 #include "tensorflow/core/lib/gtl/inlined_vector.h"
44 #include "tensorflow/core/lib/strings/scanner.h"
45 #include "tensorflow/core/lib/strings/str_util.h"
46 #include "tensorflow/core/platform/logging.h"
47 #include "tensorflow/core/platform/macros.h"
48 #include "tensorflow/core/public/version.h"
49
50 namespace tensorflow {
51
52 namespace {
53
// We remove duplicate control inputs before adding edges to the Graph, so we
// can skip expensive duplicates check in 'AddControlEdge'.
// NOTE: at namespace scope (inside an anonymous namespace) `constexpr` already
// implies `const` and internal linkage, so `static` and the extra `const`
// qualifiers were redundant.
constexpr bool kDoNotCheckDuplicates = true;
57
IsMerge(const NodeDef & node_def)58 inline bool IsMerge(const NodeDef& node_def) {
59 return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
60 node_def.op() == "_XlaMerge";
61 }
62
IsNextIteration(const NodeDef & node_def)63 inline bool IsNextIteration(const NodeDef& node_def) {
64 return node_def.op() == "NextIteration" ||
65 node_def.op() == "RefNextIteration";
66 }
67
IsValidNodeName(StringPiece s,bool allow_internal_ops)68 bool IsValidNodeName(StringPiece s, bool allow_internal_ops) {
69 using ::tensorflow::strings::Scanner;
70 Scanner scanner(s);
71 scanner
72 .One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE
73 : Scanner::LETTER_DIGIT_DOT)
74 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
75
76 while (true) {
77 if (!scanner.GetResult()) // Some error in previous iteration.
78 return false;
79 if (scanner.empty()) // No error, but nothing left, good.
80 return true;
81
82 // Absorb another piece, starting with a '>'
83 scanner.One(Scanner::RANGLE)
84 .One(Scanner::LETTER_DIGIT_DOT)
85 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
86 }
87 }
88
// Converts a GraphDef (or a slice of NodeDefs) into an in-memory Graph. This
// is the shared implementation behind ConvertGraphDefToGraph and
// ImportGraphDef; the two subclasses below provide access to the underlying
// protocol buffers either by copying or by destructively moving them.
class GraphConstructor {
 public:
  struct Options {
    // Options for ConvertGraphDefToGraph: no prefixing or name
    // uniquification, and import-only validation is disabled.
    Options(const GraphConstructorOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(in.allow_internal_ops),
          expect_device_spec(in.expect_device_spec),
          importing(false),
          validate_nodes(in.validate_nodes),
          validate_colocation_constraints(false),
          add_default_attributes(in.add_default_attributes) {}
    // Options for ImportGraphDef: normalizes `prefix` so it either is empty
    // or ends with '/', and enables import-time validation.
    Options(const ImportGraphDefOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(false),
          expect_device_spec(false),
          prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/")
                     ? in.prefix
                     : in.prefix + "/"),
          uniquify_names(in.uniquify_names),
          uniquify_prefix(in.uniquify_prefix),
          input_map(in.input_map.begin(), in.input_map.end()),
          skip_mapped_nodes(in.skip_mapped_nodes),
          control_dependencies(in.control_dependencies),
          return_tensors(in.return_tensors.begin(), in.return_tensors.end()),
          return_nodes(in.return_nodes),
          importing(true),
          validate_nodes(true),
          validate_colocation_constraints(in.validate_colocation_constraints),
          validate_shape(in.validate_shape),
          default_device(in.default_device) {}

    bool allow_internal_ops;
    bool expect_device_spec;

    // Empty, or a node-name prefix that always ends with '/'.
    string prefix;
    bool uniquify_names;
    bool uniquify_prefix;
    std::map<TensorId, TensorId> input_map;
    bool skip_mapped_nodes;
    std::vector<string> control_dependencies;
    std::vector<TensorId> return_tensors;
    std::vector<string> return_nodes;

    // TODO(ashankar): This bool exists to separate out functionality required
    // to make ImportGraphDef a close equivalent of Python's import_graph_def
    // without affecting the behavior of ConvertGraphDefToGraph at the time
    // ImportGraphDef was added.
    //
    // That said, the functionality here (shape and op validation) seems
    // applicable to ConvertGraphDefToGraph as well, so make an attempt to
    // remove this.
    bool importing;
    // If true, validates that nodes being converted have all expected attrs
    // set and no unknown attrs set by calling ValidateNodeDef().
    // `validate_nodes` is always true when `importing` is set.
    bool validate_nodes;
    bool validate_colocation_constraints;
    bool validate_shape = true;

    // If true, GraphConstructor will add attributes with their default
    // value to the Node when they are missing from the NodeDef.
    bool add_default_attributes = true;

    string default_device;
  };

  typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;

  // versions and library may be nullptr
  static Status Construct(
      const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
      const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
      std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys);

  // Takes ownership of `graph_def`, allowing destructive reads of its nodes
  // (see NodeDefMovingGraphConstructor below).
  static Status Construct(
      const Options& opts, GraphDef&& graph_def, Graph* g,
      ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys);

 protected:
  GraphConstructor(const Options& opts, Graph* g, ShapeRefiner* refiner,
                   std::vector<std::pair<Node*, int>>* return_tensors,
                   std::vector<Node*>* return_nodes,
                   std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : opts_(opts),
        g_(g),
        original_versions_(g->versions()),
        prefix_(opts.prefix),
        refiner_(refiner),
        return_tensors_(return_tensors),
        return_nodes_(return_nodes),
        missing_unused_input_map_keys_(missing_unused_input_map_keys) {}

  virtual ~GraphConstructor() {}

  // Runs the full import pipeline. On failure the caller is expected to call
  // Undo() to roll back any partial modifications made to g_.
  Status TryImport() {
    TF_RETURN_IF_ERROR(EnsureNoNameCollisions());
    TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies());
    TF_RETURN_IF_ERROR(BuildNodeIndex());
    TF_RETURN_IF_ERROR(InitFromEdges());

    // NOTE: Convert() invokes `consume_node_def()` on each node in the input
    // graph, so `get_node_def()` is no longer usable once it is called.
    TF_RETURN_IF_ERROR(Convert());

    TF_RETURN_IF_ERROR(AddBackEdges());
    TF_RETURN_IF_ERROR(UpdateVersionDef());
    TF_RETURN_IF_ERROR(PopulateReturnTensors());
    TF_RETURN_IF_ERROR(PopulateReturnNodes());
    TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys());
    UpdateUniquifiedColocationNames();
    FixupSourceAndSinkEdges(g_);
    return Status::OK();
  }

 private:
  Status EnsureNoNameCollisions();
  Status ValidateInputMapAndControlDependencies();
  Status BuildNodeIndex();
  Status InitFromEdges();
  Status Convert();
  Status AddBackEdges();
  Status UpdateVersionDef();
  Status PopulateReturnTensors();
  Status PopulateReturnNodes();
  Status PopulateMissingUnusedInputMapKeys();

  // Rolls back modifications made to g_ by a failed import.
  void Undo();

  // Prints cycles in the graph.
  void PrintCycles();
  // Performs DFS starting at `cur_node` and prints any cycles found.
  void DFS(int cur_node, std::vector<int>* cur_branch,
           std::vector<bool>* is_on_cur_branch,
           absl::flat_hash_set<int>* unvisited);
  Status IsNodeFullyMapped(const NodeDef& node_def, bool* is_node_mapped);
  Status ValidateColocationConstraints(const NodeDef& node_def);
  Status MakeNode(NodeDef&& node_def, Node** node);
  Status MakeEdge(Node* src, int output_index, Node* dst, int input_index);
  Status ValidateShape(Node* node);
  Status ModifyNodeDefForImport(NodeDef* node_def);
  // Modifies node_def's inputs according to opts_.input_map.
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function will mark inputs that are remapped to
  // true.
  void RemapNodeDefInputs(NodeDef* node_def,
                          std::vector<bool>* input_already_exists);
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function will add and mark control inputs as
  // true.
  void AddControlDependencies(NodeDef* node_def,
                              std::vector<bool>* input_already_exists);
  void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists,
                          NodeDef* node_def);

  // Modifies `node_def` if its name isn't unique, or if any of its inputs'
  // names have been uniquified. This must be called in topological order on all
  // nodes.
  void UniquifyNames(const std::vector<bool>& input_already_exists,
                     NodeDef* node_def);

  // Updates any constructed nodes' colocation group names if the name has been
  // updated by UniquifyNames. This is called after all the nodes have been
  // constructed so all the names have been uniquified if necessary.
  void UpdateUniquifiedColocationNames();

  // Returns true if `name` already exists in `g_` (either as a node name or
  // prefix).
  bool NameExistsInGraph(StringPiece name);

  // Returns true if `name` already exists in the GraphDef being imported
  // (either as a node name or prefix).
  bool NameExistsInGraphDef(StringPiece name);

  // Returns a unique version of `original_name`, or `original_name` if it's
  // already unique in the graph.
  string FindUniqueName(StringPiece original_name);

  // Decrement pending count for users of `processed` and add the ones that now
  // have all of their pending inputs satisfied to `ready_`.
  void UpdatePendingCountAndReady(int processed, bool is_next_iteration);

  // Subclasses override the following virtual methods to provide efficient
  // access to the original protocol buffer-based graph.

  // Returns the number of nodes in the graph.
  virtual size_t node_def_count() const = 0;
  // Returns the i^th node in the graph. Must not be called after
  // consume_node_def(i).
  virtual const NodeDef& get_node_def(int i) const = 0;
  // Destructively reads the i^th node in the graph, avoiding a copy if
  // possible. After calling this method, the result of get_node_def(i) is
  // undefined.
  virtual NodeDef consume_node_def(int i) = 0;
  // Returns the version information for the graph, or nullptr if none is
  // available.
  virtual const VersionDef* versions() const = 0;
  // Returns the function information for the graph, or nullptr if none is
  // available.
  virtual const FunctionDefLibrary* library() const = 0;

  // From constructor
  const Options opts_;
  Graph* g_;
  // Snapshot of g_'s versions at construction time; used by Undo().
  const VersionDef original_versions_;

  // A copy of opts_.prefix, possibly uniquified.
  string prefix_;

  ShapeRefiner* refiner_;

  // May be null. Not owned.
  std::vector<std::pair<Node*, int>>* return_tensors_;

  // May be null. Not owned.
  std::vector<Node*>* return_nodes_;

  // May be null. Not owned.
  std::vector<SafeTensorId>* missing_unused_input_map_keys_;

  // Intermediate datastructure used to populate
  // `missing_unused_input_map_keys_`.
  std::set<TensorId> used_input_map_keys_;

  // Intermediate datastructure used to track the destinations of back edges.
  absl::flat_hash_set<int> merge_node_indices_;

  // Mapping from node name to the index within node_defs_.
  struct NodeInfo {
    explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
    // Containers require that we have a default constructor.
    NodeInfo() : NodeInfo(-1) {}
    int gdef_index;
    Node* node;  // nullptr until the NodeDef is converted to a Node.
  };
  gtl::FlatMap<StringPiece, NodeInfo, StringPieceHasher> gdef_nodes_;

  // Storage for StringPiece keys in gdef_nodes_. Typically, the StringPiece key
  // will refer to the string stored in `NodeDef::name()`. This intern table is
  // only used when the original NodeDef's name is changed.
  std::vector<string> string_intern_table_;

  // Prefixes already used in the GraphDef being imported.
  gtl::FlatSet<StringPiece, StringPieceHasher> gdef_prefixes_;

  // Mapping from node name to the existing node in g_.
  gtl::FlatMap<StringPiece, Node*, StringPieceHasher> existing_nodes_;

  // Prefixes already used in the graph.
  gtl::FlatSet<StringPiece, StringPieceHasher> existing_prefixes_;

  // Imported node names that have been uniquified. The key is the original
  // name, the value is the new unique name.
  gtl::FlatMap<string, string> uniquified_names_;

  // Index of NodeDefs in node_defs_ with all inputs already converted. We use a
  // (sorted) set so nodes are created in the order defined in the GraphDef.
  std::set<int> ready_;

  // Mapping between index within node_defs_ and the number of inputs that
  // still need to be converted.
  std::vector<int> pending_count_;

  // Mapping between index within node_defs_ and the index within node_defs_ of
  // all nodes it outputs to.
  std::vector<gtl::InlinedVector<int, 4>> outputs_;

  // Used in the conversion from node_defs_ to g_ to represent the ith input
  // of a node.
  struct InputInfo {
    explicit InputInfo(const string& node_name, Node* n, int i)
        : name(node_name), node(n), index(i) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string name;
    Node* node;
    int index;

    static bool IsControlInput(const InputInfo& input) {
      return input.index == Graph::kControlSlot;
    }
    static int CompareName(const InputInfo& lhs, const InputInfo& rhs) {
      return lhs.name < rhs.name;
    }
    static bool IsSameName(const InputInfo& lhs, const InputInfo& rhs) {
      return lhs.name == rhs.name;
    }
  };

  // Used in the conversion from node_defs_ to g_ to represent an edge from
  // the node named 'name' to node 'n'.
  struct EdgeInfo {
    explicit EdgeInfo(const string& name, int i1, Node* n, int i2)
        : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string src_name;
    int src_index;
    Node* dst_node;
    int dst_index;
  };
  // Back edges (NextIteration -> Merge) deferred until AddBackEdges().
  std::vector<EdgeInfo> back_edges_;

  TF_DISALLOW_COPY_AND_ASSIGN(GraphConstructor);
};
393
394 // Implementation of GraphConstructor that does not take ownership of the
395 // input NodeDef messages and thus copies the nodes into the constructed Graph*.
396 //
397 // NOTE(mrry): Whenever possible, use NodeDefMovingGraphConstructor, which
398 // avoids copying each NodeDef into the constructed Graph*.
class NodeDefCopyingGraphConstructor : public GraphConstructor {
 public:
  NodeDefCopyingGraphConstructor(
      const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
      const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
      std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : GraphConstructor(opts, g, refiner, return_tensors, return_nodes,
                         missing_unused_input_map_keys),
        node_defs_(node_defs),
        versions_(versions),
        library_(library) {}

 private:
  size_t node_def_count() const override { return node_defs_.size(); }
  const NodeDef& get_node_def(int i) const override { return *node_defs_[i]; }
  // "Consuming" still copies here, since the NodeDefs are not owned.
  NodeDef consume_node_def(int i) override { return *node_defs_[i]; }
  const VersionDef* versions() const override { return versions_; }
  const FunctionDefLibrary* library() const override { return library_; }

  // None of these are owned; they must outlive this object.
  const NodeDefSlice node_defs_;
  const VersionDef* const versions_;
  const FunctionDefLibrary* const library_;
};
424
425 // Implementation of GraphConstructor that takes ownership of the input
426 // GraphDef, and can perform destructive reads.
class NodeDefMovingGraphConstructor : public GraphConstructor {
 public:
  NodeDefMovingGraphConstructor(
      const Options& opts, GraphDef&& graph_def, Graph* g,
      ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : GraphConstructor(opts, g, refiner, return_tensors, return_nodes,
                         missing_unused_input_map_keys),
        graph_def_(std::move(graph_def)),
        // NOTE: this relies on `graph_def_` being declared (and therefore
        // initialized) before `is_consumed_`, so node_size() sees the
        // moved-in graph.
        is_consumed_(graph_def_.node_size(), false) {}

 private:
  size_t node_def_count() const override { return graph_def_.node().size(); }
  const NodeDef& get_node_def(int i) const override {
    CHECK(!is_consumed_[i])
        << "NodeDef " << i << " accessed after it was consumed.";
    return graph_def_.node(i);
  }
  NodeDef consume_node_def(int i) override {
    CHECK(!is_consumed_[i]) << "NodeDef " << i << " consumed twice.";
    is_consumed_[i] = true;
    // Destructive move out of the owned GraphDef; get_node_def(i) must not be
    // called for this index afterwards.
    return std::move(*graph_def_.mutable_node(i));
  }
  const VersionDef* versions() const override { return &graph_def_.versions(); }
  const FunctionDefLibrary* library() const override {
    return &graph_def_.library();
  }

  // Owned; individual nodes are moved out as they are consumed.
  GraphDef graph_def_;
  // Tracks, per node index, whether consume_node_def() has been called.
  std::vector<bool> is_consumed_;
};
459
Construct(const Options & opts,NodeDefSlice node_defs,const VersionDef * versions,const FunctionDefLibrary * library,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)460 /* static */ Status GraphConstructor::Construct(
461 const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
462 const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
463 std::vector<std::pair<Node*, int>>* return_tensors,
464 std::vector<Node*>* return_nodes,
465 std::vector<SafeTensorId>* missing_unused_input_map_keys) {
466 if (versions) {
467 TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
468 TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
469 "GraphDef", "graph"));
470 }
471 NodeDefCopyingGraphConstructor c(opts, node_defs, versions, library, g,
472 refiner, return_tensors, return_nodes,
473 missing_unused_input_map_keys);
474 const Status s = c.TryImport();
475 if (!s.ok()) c.Undo();
476 return s;
477 }
478
Construct(const Options & opts,GraphDef && graph_def,Graph * g,ShapeRefiner * refiner,std::vector<std::pair<Node *,int>> * return_tensors,std::vector<Node * > * return_nodes,std::vector<SafeTensorId> * missing_unused_input_map_keys)479 /* static */ Status GraphConstructor::Construct(
480 const Options& opts, GraphDef&& graph_def, Graph* g, ShapeRefiner* refiner,
481 std::vector<std::pair<Node*, int>>* return_tensors,
482 std::vector<Node*>* return_nodes,
483 std::vector<SafeTensorId>* missing_unused_input_map_keys) {
484 TF_RETURN_IF_ERROR(CheckVersions(graph_def.versions(), TF_GRAPH_DEF_VERSION,
485 TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
486 "GraphDef", "graph"));
487 NodeDefMovingGraphConstructor c(opts, std::move(graph_def), g, refiner,
488 return_tensors, return_nodes,
489 missing_unused_input_map_keys);
490 const Status s = c.TryImport();
491 if (!s.ok()) c.Undo();
492 return s;
493 }
494
UpdatePendingCountAndReady(int processed,bool is_next_iteration)495 void GraphConstructor::UpdatePendingCountAndReady(int processed,
496 bool is_next_iteration) {
497 for (size_t i = 0; i < outputs_[processed].size(); ++i) {
498 const int output = outputs_[processed][i];
499 // We didn't consider NextIteration->Merge edges when computing
500 // pending_counts_ so we should not have to consider it here either.
501 bool is_next_iteration_to_merge_edge =
502 is_next_iteration && merge_node_indices_.count(output) == 1;
503 if (!is_next_iteration_to_merge_edge) {
504 int* current_pending_count = &pending_count_[output];
505 CHECK_GT(*current_pending_count, 0);
506 (*current_pending_count)--;
507 if (*current_pending_count == 0) {
508 ready_.insert(output);
509 }
510 }
511 }
512 }
513
514 // This could be expensive but we don't expect to call it often, if at all (only
515 // if there are multiple nodes in g_ with the same name)
NodeNameInValues(const std::map<TensorId,TensorId> & input_map,const StringPiece & node_name)516 bool NodeNameInValues(const std::map<TensorId, TensorId>& input_map,
517 const StringPiece& node_name) {
518 for (auto iter = input_map.begin(); iter != input_map.end(); ++iter) {
519 if (iter->second.first == node_name) return true;
520 }
521 return false;
522 }
523
NodeNameInValues(const std::vector<string> & control_dependencies,const StringPiece & node_name)524 bool NodeNameInValues(const std::vector<string>& control_dependencies,
525 const StringPiece& node_name) {
526 return std::find(control_dependencies.begin(), control_dependencies.end(),
527 node_name) != control_dependencies.end();
528 }
529
530 // Adds any prefixes of `node_name` (not including the full name itself) to
531 // `prefixes`.
AddPrefixes(StringPiece node_name,gtl::FlatSet<StringPiece,StringPieceHasher> * prefixes)532 void AddPrefixes(StringPiece node_name,
533 gtl::FlatSet<StringPiece, StringPieceHasher>* prefixes) {
534 size_t idx = -1;
535 while ((idx = node_name.find('/', idx + 1)) != StringPiece::npos) {
536 prefixes->insert(node_name.substr(0, idx));
537 }
538 }
539
// Populates existing_nodes_ and existing_prefixes_ from the destination graph
// and verifies the import will not create duplicate or ambiguous names. May
// replace prefix_ with a uniquified version when opts_.uniquify_prefix is set.
Status GraphConstructor::EnsureNoNameCollisions() {
  existing_nodes_.reserve(g_->num_nodes());
  // Populate existing_nodes_ and existing_prefixes_.
  for (Node* n : g_->nodes()) {
    bool already_exists = !existing_nodes_.insert({n->name(), n}).second;
    if (already_exists) {
      // Duplicate names already present in g_ are tolerated unless
      // input_map or control_dependencies refer to the duplicated name,
      // which would make the reference ambiguous.
      if (NodeNameInValues(opts_.input_map, n->name())) {
        return errors::InvalidArgument(
            "cannot resolve input_map because multiple nodes exist with name '",
            n->name(), "'");
      }
      if (NodeNameInValues(opts_.control_dependencies, n->name())) {
        return errors::InvalidArgument(
            "cannot resolve control_dependencies because multiple nodes exist "
            "with name '",
            n->name(), "'");
      }
    }
    AddPrefixes(n->name(), &existing_prefixes_);
  }
  if (prefix_.empty() && opts_.importing && !opts_.uniquify_names) {
    // No prefix and no uniquification: every imported node name must be new.
    for (size_t i = 0; i < node_def_count(); ++i) {
      const string& name = get_node_def(i).name();
      if (NameExistsInGraph(name)) {
        return errors::InvalidArgument("Node name '", name,
                                       "' already exists in the Graph");
      }
    }
  } else if (!prefix_.empty()) {
    StringPiece prefix_no_slash(prefix_);
    prefix_no_slash.remove_suffix(1);  // prefix_ always ends with '/'.
    if (!IsValidNodeName(prefix_no_slash, false)) {
      return errors::InvalidArgument("Imported node name prefix '", prefix_,
                                     "' would lead to invalid node names");
    }
    if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) {
      prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
    }
  }
  return Status::OK();
}
581
ValidateInputMapAndControlDependencies()582 Status GraphConstructor::ValidateInputMapAndControlDependencies() {
583 for (const auto& mapping : opts_.input_map) {
584 TensorId src = mapping.first;
585 TensorId dst = mapping.second;
586 if (existing_nodes_.count(dst.first) == 0) {
587 return errors::InvalidArgument(
588 "node '", dst.first, "' in input_map does not exist in graph ",
589 "(input_map entry: ", src.ToString(), "->", dst.ToString(), ")");
590 }
591 if ((src.second == Graph::kControlSlot) !=
592 (dst.second == Graph::kControlSlot)) {
593 return errors::InvalidArgument("input_map entry ", src.ToString(), "->",
594 dst.ToString(), " between ",
595 "control edge and non-control edge");
596 }
597 }
598 for (const string& node : opts_.control_dependencies) {
599 if (existing_nodes_.count(node) == 0) {
600 return errors::InvalidArgument(
601 "node '", node,
602 "' in control_dependencies does not exist in "
603 "graph");
604 }
605 }
606 return Status::OK();
607 }
608
// Validates each imported NodeDef's name, op, device, and input ordering, and
// builds gdef_nodes_ (name -> NodeDef index) plus gdef_prefixes_. Also records
// the indices of Merge nodes (potential back-edge destinations) in
// merge_node_indices_ for later use by InitFromEdges().
Status GraphConstructor::BuildNodeIndex() {
  // Validate the node names and add them to gdef_nodes_ and gdef_prefixes_.
  for (int n = 0; n < node_def_count(); ++n) {
    const NodeDef& node_def = get_node_def(n);
    if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "': Node name contains invalid characters");
    }
    // Node names must be unique within the GraphDef being imported.
    if (!gdef_nodes_
             .insert(std::make_pair(StringPiece(node_def.name()), NodeInfo(n)))
             .second) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' is not unique");
    }
    // Validate the operation's type.
    if (node_def.op().empty()) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' does not specify an operation");
    }
    if (opts_.expect_device_spec && node_def.device().empty()) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' is missing a device specification");
    }
    if (IsMerge(node_def)) {
      merge_node_indices_.insert(n);
    }
    // Validate control edges at end: inputs prefixed with '^' are control
    // dependencies and must follow all regular (data) inputs.
    bool in_control_dependence = false;
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      if (!input_name.empty() && absl::StartsWith(input_name, "^")) {
        in_control_dependence = true;
      } else if (in_control_dependence) {
        return errors::InvalidArgument(
            "Node '", node_def.name(),
            "': Control dependencies must come after regular dependencies");
      }
    }
    // Update gdef_prefixes_.
    AddPrefixes(node_def.name(), &gdef_prefixes_);
  }
  return Status::OK();
}
653
// Computes the initial pending-input count for every NodeDef, builds the
// outputs_ adjacency lists, and seeds ready_ with the nodes that can be
// converted immediately. Merge nodes fed by a NextIteration node get a
// reduced pending count so that while-loop cycles do not stall the
// topological conversion in Convert().
Status GraphConstructor::InitFromEdges() {
  const int num_nodes = node_def_count();
  pending_count_.reserve(num_nodes);
  outputs_.resize(num_nodes);
  // First pass: collect the names of all NextIteration nodes so Merge nodes
  // can detect loop-back edges below.
  gtl::FlatSet<string> next_iteration_nodes;
  for (int n = 0; n < node_def_count(); ++n) {
    const NodeDef& node_def = get_node_def(n);
    if (IsNextIteration(node_def)) {
      next_iteration_nodes.insert(node_def.name());
    }
  }

  // Parse the inputs for each node.
  for (int n = 0; n < num_nodes; ++n) {
    const NodeDef& node_def = get_node_def(n);
    int pending_count = node_def.input_size();
    if (IsMerge(node_def)) {
      // Cycles in the graph are only allowed for while loops. A while loop is
      // identified by an edge from a NextIteration node to a Merge node. For
      // such Merge nodes, only wait for one non-control input before
      // considering the node ready to process in Convert().
      int32 num_control_edges = 0;
      bool has_loop_back_edge = false;
      for (int i = 0; i < node_def.input_size(); ++i) {
        StringPiece input_name(node_def.input(i));
        if (absl::StartsWith(input_name, "^")) {
          num_control_edges++;
        } else {
          TensorId id(ParseTensorName(input_name));
          if (next_iteration_nodes.find(string(id.first)) !=
              next_iteration_nodes.end()) {
            has_loop_back_edge = true;
          }
        }
      }
      if (has_loop_back_edge) {
        // One non-control input plus all control inputs must be satisfied;
        // the loop-back edge is added later by AddBackEdges().
        pending_count = num_control_edges + 1;
      }
    }
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      TensorId id(ParseTensorName(input_name));
      if (opts_.input_map.count(id) == 0) {
        // If an input is not mapped, then the input should appear in the graph
        // being imported.
        auto iter = gdef_nodes_.find(id.first);
        if (iter == gdef_nodes_.end()) {
          return errors::InvalidArgument("Node '", node_def.name(),
                                         "': Unknown input node '",
                                         node_def.input(i), "'");
        }
        outputs_[iter->second.gdef_index].push_back(n);
      } else {
        // This input is mapped to an existing edge. Therefore this input is
        // as good as being already processed.
        --pending_count;
        DCHECK_GE(pending_count, 0);
      }
    }
    if (pending_count == 0) {
      ready_.insert(n);
    }
    pending_count_.push_back(pending_count);
  }
  return Status::OK();
}
720
ValidateColocationConstraints(const NodeDef & node_def)721 Status GraphConstructor::ValidateColocationConstraints(
722 const NodeDef& node_def) {
723 if (!opts_.validate_colocation_constraints || !opts_.importing)
724 return Status::OK();
725 const auto iter = node_def.attr().find(kColocationAttrName);
726 if (iter == node_def.attr().end()) return Status::OK();
727 for (const string& c : iter->second.list().s()) {
728 StringPiece s(c);
729 if (absl::ConsumePrefix(&s, kColocationGroupPrefix) &&
730 gdef_nodes_.find(s) == gdef_nodes_.end()) {
731 return errors::InvalidArgument(
732 "Node '", node_def.name(),
733 "' expects to be colocated with unknown node '", s, "'");
734 }
735 }
736 return Status::OK();
737 }
738
MakeNode(NodeDef && node_def,Node ** node)739 Status GraphConstructor::MakeNode(NodeDef&& node_def, Node** node) {
740 // Add the node to the graph.
741 Status status;
742 *node = g_->AddNode(std::move(node_def), &status);
743 if (!status.ok()) return status;
744 if (opts_.expect_device_spec) {
745 (*node)->set_assigned_device_name((*node)->def().device());
746 }
747 return Status::OK();
748 }
749
ValidateShape(Node * node)750 Status GraphConstructor::ValidateShape(Node* node) {
751 if (!opts_.importing || !opts_.validate_shape) return Status::OK();
752 TF_RETURN_IF_ERROR(refiner_->AddNode(node));
753 // For nodes with the _output_shapes attribute, override the shape.
754 std::vector<const TensorShapeProto*> shape_attrs;
755 const char* kAttrName = "_output_shapes";
756 if (!TryGetNodeAttr(node->attrs(), kAttrName, &shape_attrs)) {
757 // No _output_shapes attribute, the AddNode call above was sufficient.
758 return Status::OK();
759 }
760 auto* ic = refiner_->GetContext(node);
761 DCHECK(ic != nullptr)
762 << "ShapeRefiner::AddNode() should have created the InferenceContext";
763 if (shape_attrs.size() < node->num_outputs()) {
764 return errors::InvalidArgument(
765 "Node '", node->name(), "' has ", node->num_outputs(),
766 " outputs but the ", kAttrName, " attribute specifies shapes for ",
767 shape_attrs.size(), " outputs");
768 }
769 // NOTE(skyewm): we don't raise an error here because some users depend on
770 // this behavior, even though it's unsafe.
771 // TODO(b/74619486): raise an error.
772 if (shape_attrs.size() > node->num_outputs()) {
773 LOG(WARNING) << "Node '" << node->name() << "' has " << node->num_outputs()
774 << " outputs but the " << kAttrName
775 << " attribute specifies shapes for " << shape_attrs.size()
776 << " outputs. Output shapes may be inaccurate.";
777 }
778 for (int i = 0; i < node->num_outputs(); ++i) {
779 const TensorShapeProto& p = *shape_attrs[i];
780 shape_inference::ShapeHandle h;
781 Status s = ic->MakeShapeFromShapeProto(p, &h);
782 if (!s.ok()) {
783 return errors::InvalidArgument("Node '", node->name(), " has an invalid ",
784 kAttrName, " attribute (shape #", i,
785 " error:'", s.error_message(), "'");
786 }
787 s = refiner_->SetShape(node, i, h);
788 if (!s.ok()) {
789 // If the output shape is incompatible with what is inferred
790 // by the graph for a very specific whitelist of ops, then we
791 // ignore this output shape. This can happen if there is a
792 // bug in the shape function for some operation, and the
793 // serialized graph def has the incorrect shape set when
794 // running on a newer binary with the fixed shape function.
795 // This is an escape hatch that allows us to correct shape
796 // functions that are not critical to correct execution but
797 // would cause graphs to fail if imported after correcting.
798 const string& op = node->type_string();
799 const std::vector<string> whitelist = {
800 // To be removed after 2017/03/08.
801 "RandomShuffleQueue",
802 "PaddingFIFOQueue",
803 "FIFOQueue",
804 "PriorityQueue",
805 "QueueSize",
806 "Stack",
807 "Barrier",
808 "BarrierReadySize",
809 "BarrierIncompleteSize",
810 "HashTable",
811 "MutableHashTable",
812 "MutableHashTableOfTensors",
813 "Mutex",
814 "CuckooTable",
815 "IndexTable",
816 "WholeFileReader",
817 "TextLineReader",
818 "FixedLengthRecordReader",
819 "TFRecordReader",
820 "IdentityReader",
821 "RefSwitch",
822 "RefEnter",
823 "RefNextIteration",
824 "RefMerge",
825 "RefIdentity",
826 "LMDBReader",
827 // To be removed after 2017/04/24.
828 "ConditionalAccumulator",
829 "SparseConditionalAccumulator",
830 "Table",
831 };
832 if (std::find(whitelist.begin(), whitelist.end(), op) ==
833 whitelist.end()) {
834 return errors::InvalidArgument(
835 "Node '", node->name(), "' has an ", kAttrName,
836 " attribute inconsistent with the GraphDef for output #", i, ": ",
837 s.error_message());
838 }
839 }
840 }
841 node->ClearAttr(kAttrName);
842 return Status::OK();
843 }
844
ModifyNodeDefForImport(NodeDef * node_def)845 Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) {
846 const OpDef* op_def;
847 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
848 AddDefaultsToNodeDef(*op_def, node_def);
849 TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def));
850 if (versions()) {
851 TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions()->producer()));
852 }
853 return Status::OK();
854 }
855
RemoveInputs(const std::vector<int> & inputs_to_remove,NodeDef * node_def,std::vector<bool> * input_already_exists)856 void RemoveInputs(const std::vector<int>& inputs_to_remove, NodeDef* node_def,
857 std::vector<bool>* input_already_exists) {
858 // Remove 'inputs_to_remove' from 'node_def'
859 NodeDef copy;
860 copy.mutable_input()->Reserve(node_def->input_size() -
861 inputs_to_remove.size());
862 for (int i = 0, j = 0; i < node_def->input_size(); ++i) {
863 if (j < inputs_to_remove.size() && i == inputs_to_remove[j]) {
864 ++j;
865 } else {
866 copy.add_input()->swap(*node_def->mutable_input(i));
867 }
868 }
869 node_def->mutable_input()->Swap(copy.mutable_input());
870 // Remove 'inputs_to_remove' from 'input_already_exists'
871 for (int idx : inputs_to_remove) {
872 input_already_exists->erase(input_already_exists->begin() + idx);
873 }
874 DCHECK_EQ(input_already_exists->size(), node_def->input_size());
875 }
876
// Rewrites node_def's inputs according to opts_.input_map: any input whose
// TensorId appears as a key in the map is replaced by the mapped tensor
// (which refers to a preexisting node in g_). Marks such inputs as true in
// *input_already_exists, records consumed keys in used_input_map_keys_, and
// deletes inputs whose remapping would duplicate an earlier control input.
void GraphConstructor::RemapNodeDefInputs(
    NodeDef* node_def, std::vector<bool>* input_already_exists) {
  DCHECK_EQ(input_already_exists->size(), node_def->input_size());
  // Control inputs already produced by remapping, used for duplicate
  // detection.
  std::set<TensorId> control_inputs;
  // Positions to delete because they duplicate a control input; built in
  // ascending order of i.
  std::vector<int> inputs_to_remove;

  for (int i = 0; i < node_def->input_size(); ++i) {
    auto iter = opts_.input_map.find(ParseTensorName(node_def->input(i)));
    if (iter == opts_.input_map.end()) continue;
    // Remember the key was used so PopulateMissingUnusedInputMapKeys() can
    // report the ones that never matched.
    used_input_map_keys_.insert(iter->first);

    TensorId new_input = iter->second;
    if (new_input.second == Graph::kControlSlot) {
      // Check if we've already remapped a different input to new_input, and if
      // so remove this input.
      if (control_inputs.count(new_input) > 0) {
        inputs_to_remove.push_back(i);
        continue;
      }
      control_inputs.insert(new_input);
    }
    node_def->set_input(i, new_input.ToString());
    (*input_already_exists)[i] = true;
  }
  if (!inputs_to_remove.empty()) {
    RemoveInputs(inputs_to_remove, node_def, input_already_exists);
  }
}
905
// Appends opts_.control_dependencies to node_def as control inputs, unless
// the node will inherit them transitively through another newly-imported
// input. Each appended entry is mirrored as `true` in
// *input_already_exists, since the dependency targets are preexisting nodes
// in g_.
void GraphConstructor::AddControlDependencies(
    NodeDef* node_def, std::vector<bool>* input_already_exists) {
  // To avoid adding redundant control dependencies to every imported node, skip
  // nodes that will inherit the dependencies from another imported node.
  bool inherits_deps = false;
  for (int i = 0; i < node_def->input_size(); ++i) {
    // Assume we won't inherit dependencies from remapped inputs that already
    // exist in the graph. Even if we're wrong, we'll only add redundant
    // dependencies.
    if ((*input_already_exists)[i]) continue;

    // If this input is a backedge, assume we won't inherit the dependencies.
    // TODO(skyewm): we have many redundant ParseTensorName calls. It could be
    // worth optimizing these.
    TensorId id(ParseTensorName(node_def->input(i)));
    auto iter = gdef_nodes_.find(id.first);
    DCHECK(iter != gdef_nodes_.end()) << id.first;
    if (iter->second.node == nullptr) {
      // Input hasn't been created yet, indicating it's a backedge.
      continue;
    }
    // At least one input is a newly-imported, already-created node, so the
    // dependencies will arrive transitively through it.
    inherits_deps = true;
  }
  if (inherits_deps) return;

  // node_def either has no inputs or all remapped inputs, add the control
  // dependencies
  for (const string& control_dep : opts_.control_dependencies) {
    string input = TensorId(control_dep, Graph::kControlSlot).ToString();
    bool found = false;
    // Scan the tail of the input list, where control inputs live, to avoid
    // adding a dependency that is already present.
    for (int i = node_def->input_size() - 1; i >= 0; --i) {
      const string& node_input = node_def->input(i);
      if (node_input[0] != '^') {
        // Control inputs are at the end. Break when we reach the non-control
        // inputs.
        break;
      }
      if (node_input == input) {
        // Control dependency already exists
        found = true;
        break;
      }
    }
    if (found) {
      continue;
    }
    node_def->add_input(input);
    // The dependency target already exists in g_.
    input_already_exists->push_back(true);
  }
}
956
AddPrefixToNodeDef(const std::vector<bool> & input_already_exists,NodeDef * node_def)957 void GraphConstructor::AddPrefixToNodeDef(
958 const std::vector<bool>& input_already_exists, NodeDef* node_def) {
959 if (prefix_.empty()) return;
960 node_def->set_name(strings::StrCat(prefix_, node_def->name()));
961 // Update names of input nodes
962 for (int i = 0; i < node_def->input_size(); ++i) {
963 // Skip remapped inputs (which already exist in g_ and are not being
964 // imported).
965 if (input_already_exists[i]) continue;
966 StringPiece input(node_def->input(i));
967 if (absl::ConsumePrefix(&input, "^")) {
968 node_def->set_input(i, strings::StrCat("^", prefix_, input));
969 } else {
970 node_def->set_input(i, strings::StrCat(prefix_, input));
971 }
972 }
973 // Update names of colocation groups
974 if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) {
975 auto* list =
976 node_def->mutable_attr()->at(kColocationAttrName).mutable_list();
977 for (int i = 0; i < list->s_size(); ++i) {
978 StringPiece v(list->s(i));
979 if (absl::ConsumePrefix(&v, kColocationGroupPrefix)) {
980 list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v));
981 }
982 }
983 }
984 }
985
// Renames node_def if its name collides with a name already in g_, and
// rewrites its (non-remapped) inputs to account for producers that were
// renamed earlier. Must be called on NodeDefs in topological order so the
// uniquified_names_ mapping is complete for all of node_def's producers.
void GraphConstructor::UniquifyNames(
    const std::vector<bool>& input_already_exists, NodeDef* node_def) {
  if (NameExistsInGraph(node_def->name())) {
    string old_name = node_def->name();
    node_def->set_name(FindUniqueName(node_def->name()));
    // Record the rename so consumers processed later can fix their inputs.
    uniquified_names_[old_name] = node_def->name();
    // Note that we don't have to update gdef_nodes_ or gdef_prefixes_ with
    // `name` because we guarantee the original NodeDef names are unique,
    // meaning we won't generate this name again.
  }
  for (int i = 0; i < node_def->input_size(); ++i) {
    // Skip remapped inputs (which already exist in g_ and are not being
    // imported).
    if (input_already_exists[i]) continue;
    TensorId id = ParseTensorName(node_def->input(i));
    // We require that UniquifyNames() is called on all NodeDefs in topological
    // order. This guarantees that node_def's inputs will already be uniquified
    // if necessary.
    auto iter = uniquified_names_.find(string(id.first));
    if (iter == uniquified_names_.end()) continue;
    id.first = iter->second;
    node_def->set_input(i, id.ToString());
  }
}
1010
UpdateUniquifiedColocationNames()1011 void GraphConstructor::UpdateUniquifiedColocationNames() {
1012 for (const auto& pair : gdef_nodes_) {
1013 Node* node = pair.second.node;
1014 if (node == nullptr) continue;
1015 std::vector<string> coloc_values;
1016 if (!TryGetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values))
1017 continue;
1018 bool updated = false;
1019 for (size_t i = 0; i < coloc_values.size(); ++i) {
1020 StringPiece val(coloc_values[i]);
1021 if (absl::ConsumePrefix(&val, kColocationGroupPrefix)) {
1022 auto name_pair = uniquified_names_.find(string(val));
1023 if (name_pair == uniquified_names_.end()) continue;
1024 updated = true;
1025 coloc_values[i] =
1026 strings::StrCat(kColocationGroupPrefix, name_pair->second);
1027 }
1028 }
1029 if (updated) {
1030 node->AddAttr(kColocationAttrName, std::move(coloc_values));
1031 }
1032 }
1033 }
1034
NameExistsInGraph(StringPiece name)1035 bool GraphConstructor::NameExistsInGraph(StringPiece name) {
1036 if (existing_nodes_.find(name) != existing_nodes_.end()) return true;
1037 if (existing_prefixes_.find(name) != existing_prefixes_.end()) return true;
1038 return false;
1039 }
1040
NameExistsInGraphDef(StringPiece name)1041 bool GraphConstructor::NameExistsInGraphDef(StringPiece name) {
1042 if (gdef_nodes_.find(name) != gdef_nodes_.end()) return true;
1043 if (gdef_prefixes_.find(name) != gdef_prefixes_.end()) return true;
1044 return false;
1045 }
1046
FindUniqueName(StringPiece original_name)1047 string GraphConstructor::FindUniqueName(StringPiece original_name) {
1048 string name(original_name);
1049 int count = 0;
1050 // Check that any generated names don't collide with imported NodeDefs (as
1051 // well as nodes in g_).
1052 while (NameExistsInGraph(name) || (count > 0 && NameExistsInGraphDef(name))) {
1053 name = strings::StrCat(original_name, "_", ++count);
1054 }
1055 return name;
1056 }
1057
IsNodeFullyMapped(const NodeDef & node_def,bool * is_node_mapped)1058 Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def,
1059 bool* is_node_mapped) {
1060 const OpDef* op_def;
1061 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
1062 for (int i = 0; i < op_def->output_arg_size(); ++i) {
1063 if (opts_.input_map.find({node_def.name(), i}) == opts_.input_map.end()) {
1064 *is_node_mapped = false;
1065 return Status::OK();
1066 }
1067 }
1068 *is_node_mapped = true;
1069 return Status::OK();
1070 }
1071
// Depth-first search over outputs_ starting at cur_node, logging every
// cycle it encounters. cur_branch holds the nodes on the current DFS path
// in traversal order; is_on_cur_branch mirrors it for O(1) membership
// tests; unvisited shrinks as nodes finish (post-order), so each node is
// fully explored at most once across all DFS roots.
void GraphConstructor::DFS(int cur_node, std::vector<int>* cur_branch,
                           std::vector<bool>* is_on_cur_branch,
                           absl::flat_hash_set<int>* unvisited) {
  cur_branch->push_back(cur_node);
  is_on_cur_branch->at(cur_node) = true;
  for (auto next_node : outputs_[cur_node]) {
    if (unvisited->find(next_node) != unvisited->end()) {
      if (is_on_cur_branch->at(next_node)) {
        // next_node is already on the current path: the suffix of cur_branch
        // starting at next_node forms a cycle; log its members.
        auto iter =
            std::find(cur_branch->begin(), cur_branch->end(), next_node);
        LOG(WARNING) << "Cycle detected:";
        while (iter != cur_branch->end()) {
          LOG(WARNING) << SummarizeNodeDef(get_node_def(*iter));
          ++iter;
        }
        LOG(WARNING) << "End of cycle";
      } else {
        DFS(next_node, cur_branch, is_on_cur_branch, unvisited);
      }
    }
  }
  // Backtrack: cur_node is finished, remove it from the path and from the
  // unvisited set.
  cur_branch->pop_back();
  is_on_cur_branch->at(cur_node) = false;
  unvisited->erase(cur_node);
}
1097
PrintCycles()1098 void GraphConstructor::PrintCycles() {
1099 int num_nodes = outputs_.size();
1100 absl::flat_hash_set<int> unvisited;
1101 for (int i = 0; i < num_nodes; i++) {
1102 unvisited.insert(i);
1103 }
1104 while (!unvisited.empty()) {
1105 int cur_node = *unvisited.begin();
1106 // Nodes on the current branch of DFS in traversal order. This is used for
1107 // printing the nodes in the cycle.
1108 std::vector<int> cur_branch;
1109 // This is just to make lookups O(1).
1110 // is_on_cur_branch[i] ==
1111 // (std::find(cur_branch.start(),
1112 // cur_branch.end(), i) != cur_branch.end())
1113 std::vector<bool> is_on_cur_branch(num_nodes, false);
1114 DFS(cur_node, &cur_branch, &is_on_cur_branch, &unvisited);
1115 }
1116 }
1117
// Core conversion loop: consumes the NodeDefs in topological order
// (maintained via ready_ / pending_count_ / outputs_), applies the
// import-time transformations (input remapping, control dependencies,
// prefixing, name uniquification, defaults/validation), creates the Node,
// wires its input edges (deferring back edges), and validates shapes.
// Returns InvalidArgument if any nodes remain unprocessed, i.e. the
// GraphDef contains a cycle not mediated by Merge/NextIteration.
Status GraphConstructor::Convert() {
  // Import functions before adding nodes, since imported nodes may refer to
  // functions
  if (library()) {
    // TODO(b/135705010): Add rvalue overloads into the function library, to
    // avoid unnecessarily copying `*library()` here.
    TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library()));
  }

  std::vector<InputInfo> inputs;
  int processed = 0;

  std::vector<bool> input_already_exists;

  // Process the NodeDefs in topological order.
  // (InitFromEdges() sets this up by filling in ready_ with nodes that have no
  // inputs, pending_counts_ with the number of inputs for each node and
  // outputs_ with the outputs of each node).
  while (!ready_.empty()) {
    int o = *ready_.begin();
    ready_.erase(ready_.begin());
    ++processed;
    inputs.clear();
    bool has_data_back_edge = false;

    NodeDef node_def = consume_node_def(o);

    // input_already_exists[i] is true iff the i-th input of the node we're
    // importing refers to a preexisting node in g_ (i.e. input[i] existed prior
    // to importing node_defs_). Conversely, input_already_exists[i] is false
    // iff the input refers to a node in node_defs_.
    input_already_exists.clear();
    input_already_exists.resize(node_def.input_size(), false);

    // Index into string_intern_table_ of this node's original name, or -1
    // when not importing.
    ssize_t string_intern_table_index = -1;

    if (opts_.importing) {
      // Intern the original node name, so that we can use a StringPiece of the
      // name to index gdef_nodes_.
      string_intern_table_index = string_intern_table_.size();
      string_intern_table_.push_back(node_def.name());

      if (opts_.skip_mapped_nodes) {
        bool is_node_mapped = false;
        TF_RETURN_IF_ERROR(IsNodeFullyMapped(node_def, &is_node_mapped));
        if (is_node_mapped) {
          // Skip this node after updating pending_count_ for outputs
          UpdatePendingCountAndReady(o, IsNextIteration(node_def));
          continue;
        }
      }

      if (!opts_.input_map.empty()) {
        // Note that input_already_exists can shrink here
        RemapNodeDefInputs(&node_def, &input_already_exists);
      }
      if (!opts_.control_dependencies.empty()) {
        // Note that input_already_exists can grow here
        AddControlDependencies(&node_def, &input_already_exists);
      }
      if (!opts_.default_device.empty() && node_def.device().empty()) {
        node_def.set_device(opts_.default_device);
      }
    }

    DCHECK_EQ(node_def.input_size(), input_already_exists.size());
    TF_RETURN_IF_ERROR(ValidateColocationConstraints(node_def));
    // Resolve each input to its source node (either newly imported or
    // preexisting) and output index.
    for (int i = 0; i < node_def.input_size(); ++i) {
      TensorId tensor_id = ParseTensorName(node_def.input(i));
      Node* src_node;
      int src_index;

      if (!input_already_exists[i]) {
        // Locate input in newly-imported nodes
        auto iter = gdef_nodes_.find(tensor_id.node());
        DCHECK(iter != gdef_nodes_.end()) << tensor_id.node();
        src_node = iter->second.node;
        src_index = tensor_id.index();
        // A null node here means the producer has not been created yet,
        // which only happens for back edges.
        if (src_node == nullptr) has_data_back_edge = true;
      } else {
        // Input refers to preexistng node in graph
        auto iter = existing_nodes_.find(tensor_id.node());
        DCHECK(iter != existing_nodes_.end()) << tensor_id.node();
        src_node = iter->second;
        src_index = tensor_id.index();
      }

      if (src_node != nullptr && src_index >= src_node->num_outputs()) {
        std::ostringstream out;
        out << "Node '" << node_def.name() << "': Connecting to invalid output "
            << tensor_id.index() << " of source node " << tensor_id.node()
            << " which has " << src_node->num_outputs() << " outputs.";

        if (src_node->type_string() == "If" ||
            src_node->type_string() == "StatelessIf" ||
            src_node->type_string() == "While" ||
            src_node->type_string() == "StatelessWhile") {
          out << " Try using "
              << "tf.compat.v1.experimental.output_all_intermediates(True).";
        }
        return errors::InvalidArgument(out.str());
      }

      inputs.emplace_back(string(tensor_id.node()), src_node, src_index);
    }

    if (has_data_back_edge && !IsMerge(node_def)) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "' had a back edge, but only Merge nodes can have back edges.");
    }

    Node* node;
    if (opts_.importing) {
      if (!prefix_.empty()) {
        AddPrefixToNodeDef(input_already_exists, &node_def);
      }
      // Note: no need to uniquify names if the prefix already guarantees
      // uniqueness
      if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) {
        UniquifyNames(input_already_exists, &node_def);
      }
    }

    if (opts_.importing) {
      TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&node_def));
    } else {
      const OpDef* op_def;
      TF_RETURN_IF_ERROR(
          g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
      if (opts_.add_default_attributes) {
        AddDefaultsToNodeDef(*op_def, &node_def);
      }
      if (opts_.validate_nodes) {
        TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def));
      }
    }

    TF_RETURN_IF_ERROR(MakeNode(std::move(node_def), &node));

    if (opts_.importing) {
      // Use interned original node name so StringPiece remains valid.
      DCHECK_GE(string_intern_table_index, 0);
      gdef_nodes_[string_intern_table_[string_intern_table_index]].node = node;
    } else {
      DCHECK_EQ(string_intern_table_index, -1);
      gdef_nodes_[node->name()].node = node;
    }

    // Remove duplicate control inputs before adding edges to the graph. It
    // will allow us to skip expensive duplicates check in 'AddControlEdge'.
    // (Control inputs form a suffix of `inputs`; sort that suffix by name
    // and drop adjacent duplicates.)
    auto first_control = absl::c_find_if(inputs, &InputInfo::IsControlInput);
    auto first_control_copy = first_control;
    std::sort(first_control, inputs.end(), &InputInfo::CompareName);
    inputs.erase(
        std::unique(first_control_copy, inputs.end(), &InputInfo::IsSameName),
        inputs.end());

    // Add edges from inputs to *node to the graph.
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (inputs[i].node == nullptr) {
        // Record this back edge, which will be added after all nodes
        // are created.
        back_edges_.emplace_back(inputs[i].name, inputs[i].index, node, i);
      } else if (inputs[i].index == Graph::kControlSlot) {
        g_->AddControlEdge(inputs[i].node, node, kDoNotCheckDuplicates);
      } else {
        TF_RETURN_IF_ERROR(MakeEdge(inputs[i].node, inputs[i].index, node, i));
      }
    }

    TF_RETURN_IF_ERROR(ValidateShape(node));

    // Update pending_count_ for outputs.
    UpdatePendingCountAndReady(o, node->IsNextIteration());
  }

  // Any nodes left unprocessed participate in a cycle; report them.
  if (processed < node_def_count()) {
    LOG(WARNING) << "IN " << __func__ << " " << (node_def_count() - processed)
                 << " NODES IN A CYCLE";
    for (int64 i = 0; i < node_def_count(); i++) {
      if (pending_count_[i] != 0) {
        LOG(WARNING) << "PENDING: " << SummarizeNodeDef(get_node_def(i))
                     << " WITH PENDING COUNT = " << pending_count_[i];
      }
    }
    PrintCycles();
    return errors::InvalidArgument(node_def_count() - processed,
                                   " nodes in a cycle");
  }

  return Status::OK();
}
1311
AddBackEdges()1312 Status GraphConstructor::AddBackEdges() {
1313 // Add the back edges after all nodes are created.
1314 for (auto e : back_edges_) {
1315 Node* src_node = gdef_nodes_[e.src_name].node;
1316 if (e.src_index == Graph::kControlSlot) {
1317 g_->AddControlEdge(src_node, e.dst_node, kDoNotCheckDuplicates);
1318 } else {
1319 TF_RETURN_IF_ERROR(
1320 MakeEdge(src_node, e.src_index, e.dst_node, e.dst_index));
1321 }
1322
1323 VLOG(2) << "Add back edge: " << src_node->name() << " -> "
1324 << e.dst_node->name();
1325 }
1326 return Status::OK();
1327 }
1328
UpdateVersionDef()1329 Status GraphConstructor::UpdateVersionDef() {
1330 if (versions() == nullptr) return Status::OK();
1331
1332 if (!opts_.importing) {
1333 g_->set_versions(*versions());
1334 return Status::OK();
1335 }
1336 VersionDef g_versions = g_->versions();
1337 g_versions.set_producer(
1338 std::min(g_versions.producer(), versions()->producer()));
1339 g_versions.set_min_consumer(
1340 std::max(g_versions.min_consumer(), versions()->min_consumer()));
1341 if (versions()->bad_consumers_size() > 0) {
1342 std::set<int> bad(g_versions.bad_consumers().begin(),
1343 g_versions.bad_consumers().end());
1344 bad.insert(versions()->bad_consumers().begin(),
1345 versions()->bad_consumers().end());
1346 g_versions.clear_bad_consumers();
1347 for (int v : bad) {
1348 g_versions.add_bad_consumers(v);
1349 }
1350 }
1351 g_->set_versions(g_versions);
1352 return Status::OK();
1353 }
1354
PopulateReturnTensors()1355 Status GraphConstructor::PopulateReturnTensors() {
1356 if (opts_.return_tensors.empty()) return Status::OK();
1357 for (const TensorId& id : opts_.return_tensors) {
1358 auto iter = opts_.input_map.find(id);
1359 if (iter == opts_.input_map.end()) {
1360 // Locate id in imported nodes
1361 auto iter = gdef_nodes_.find(id.first);
1362 if (iter == gdef_nodes_.end()) {
1363 return errors::InvalidArgument("Requested return tensor '",
1364 id.ToString(),
1365 "' not found in graph def");
1366 }
1367 int num_outputs = iter->second.node->num_outputs();
1368 if ((id.second < 0 || id.second >= num_outputs) &&
1369 id.second != Graph::kControlSlot) {
1370 return errors::InvalidArgument("Invalid return output ", id.second,
1371 " of node '", id.first, "', which has ",
1372 num_outputs, " output(s)");
1373 }
1374 return_tensors_->push_back({iter->second.node, id.second});
1375 } else {
1376 // id was remapped to existing node
1377 TensorId remapped_id = iter->second;
1378 DCHECK_GT(existing_nodes_.count(remapped_id.first), 0);
1379 Node* node = existing_nodes_[remapped_id.first];
1380 return_tensors_->push_back({node, remapped_id.second});
1381 }
1382 }
1383 return Status::OK();
1384 }
1385
PopulateReturnNodes()1386 Status GraphConstructor::PopulateReturnNodes() {
1387 if (opts_.return_nodes.empty()) return Status::OK();
1388 for (StringPiece name : opts_.return_nodes) {
1389 auto iter = gdef_nodes_.find(name);
1390 if (iter == gdef_nodes_.end()) {
1391 return errors::InvalidArgument("Requested return node '", name,
1392 "' not found in graph def");
1393 }
1394 return_nodes_->push_back(iter->second.node);
1395 }
1396 return Status::OK();
1397 }
1398
PopulateMissingUnusedInputMapKeys()1399 Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
1400 if (missing_unused_input_map_keys_ == nullptr) return Status::OK();
1401 for (const auto& input_map_pair : opts_.input_map) {
1402 TensorId key = input_map_pair.first;
1403 if (used_input_map_keys_.count(key) > 0) continue;
1404
1405 auto pair = gdef_nodes_.find(key.first);
1406 if (pair == gdef_nodes_.end()) {
1407 // key's node doesn't exist in GraphDef
1408 missing_unused_input_map_keys_->push_back(key);
1409 continue;
1410 }
1411
1412 // Check that key's index is in bounds. Get the number of outputs from the
1413 // NodeDef, rather than the imported Node, since the Node may not exist if
1414 // opts_.skip_mapped_nodes is true.
1415 const NodeDef& node_def = get_node_def(pair->second.gdef_index);
1416 const OpDef* op_def;
1417 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
1418 int num_outputs;
1419 TF_RETURN_IF_ERROR(NumOutputsForNode(node_def, *op_def, &num_outputs));
1420 if (key.second >= num_outputs) {
1421 // key's index out of bounds
1422 missing_unused_input_map_keys_->push_back(key);
1423 }
1424 }
1425 return Status::OK();
1426 }
1427
Undo()1428 void GraphConstructor::Undo() {
1429 for (const auto& iter : gdef_nodes_) {
1430 if (iter.second.node != nullptr) {
1431 g_->RemoveNode(iter.second.node);
1432 }
1433 }
1434 g_->set_versions(original_versions_);
1435 }
1436
MakeEdge(Node * src,int output_index,Node * dst,int input_index)1437 Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
1438 int input_index) {
1439 DataType src_out = src->output_type(output_index);
1440 DataType dst_in = dst->input_type(input_index);
1441 if (!TypesCompatible(dst_in, src_out)) {
1442 return errors::InvalidArgument(
1443 "Input ", input_index, " of node ", dst->name(), " was passed ",
1444 DataTypeString(src_out), " from ", src->name(), ":", output_index,
1445 " incompatible with expected ", DataTypeString(dst_in), ".");
1446 }
1447 g_->AddEdge(src, output_index, dst, input_index);
1448 return Status::OK();
1449 }
1450
1451 } // namespace
1452
ConvertGraphDefToGraph(const GraphConstructorOptions & opts,const GraphDef & gdef,Graph * g)1453 Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
1454 const GraphDef& gdef, Graph* g) {
1455 ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
1456 return GraphConstructor::Construct(
1457 opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
1458 /*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
1459 /*missing_unused_input_map_keys=*/nullptr);
1460 }
1461
ConvertGraphDefToGraph(const GraphConstructorOptions & opts,GraphDef && gdef,Graph * g)1462 Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
1463 GraphDef&& gdef, Graph* g) {
1464 ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
1465 return GraphConstructor::Construct(opts, std::move(gdef), g, &refiner,
1466 /*return_tensors=*/nullptr,
1467 /*return_nodes=*/nullptr,
1468 /*missing_unused_input_map_keys=*/nullptr);
1469 }
1470
ConvertNodeDefsToGraph(const GraphConstructorOptions & opts,gtl::ArraySlice<NodeDef> nodes,Graph * g)1471 Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
1472 gtl::ArraySlice<NodeDef> nodes, Graph* g) {
1473 ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry());
1474 // TODO(irving): Copy will go away once NodeInfo exists
1475 std::vector<const NodeDef*> node_defs;
1476 node_defs.reserve(nodes.size());
1477 for (const auto& n : nodes) {
1478 node_defs.push_back(&n);
1479 }
1480 return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
1481 &refiner, /*return_tensors=*/nullptr,
1482 /*return_nodes=*/nullptr,
1483 /*missing_unused_input_map_keys=*/nullptr);
1484 }
1485
// Imports `gdef` into `g` with full ImportGraphDefOptions support.
// Validates that the options/results combination is usable, chooses (or
// locally creates) a ShapeRefiner, aligns its graph_def_version with the
// incoming GraphDef, and delegates to GraphConstructor::Construct.
Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
                      Graph* g, ShapeRefiner* refiner,
                      ImportGraphDefResults* results) {
  if (!opts.return_tensors.empty()) {
    if (results == nullptr) {
      return errors::InvalidArgument(
          "results argument to ImportGraphDef() must be non-null if "
          "opts.return_tensors is non-empty");
    }
  }

  if (!opts.return_nodes.empty()) {
    if (opts.skip_mapped_nodes) {
      return errors::InvalidArgument(
          "Requesting return_nodes with skip_mapped_nodes set is not currently "
          "supported");
    }
    if (results == nullptr) {
      return errors::InvalidArgument(
          "results argument to ImportGraphDef() must be non-null if "
          "opts.return_nodes is non-empty");
    }
  }

  // The caller must hand us a fresh results struct; populated fields would
  // be silently mixed with this import's output.
  if (results != nullptr) {
    if (!results->return_tensors.empty() || !results->return_nodes.empty() ||
        !results->missing_unused_input_map_keys.empty()) {
      return errors::InvalidArgument(
          "All fields in results argument to ImportGraphDef() must be empty.");
    }
  }

  ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
  if (refiner == nullptr) {
    refiner = &default_refiner;
  } else {
    // Log a warning if we are importing a GraphDef at an older
    // producer version after already having added non-source/sink
    // nodes to the graph in the past.
    if (gdef.versions().producer() > 0 &&
        gdef.versions().producer() < refiner->graph_def_version() &&
        g->num_nodes() > 2) {
      LOG(WARNING) << "Importing a graph with a lower producer version "
                   << gdef.versions().producer()
                   << " into an existing graph with producer version "
                   << refiner->graph_def_version() << ". Shape inference will "
                   << "have run different parts of the graph with different "
                   << "producer versions.";
    }
  }

  // Set the graph def version of the refiner as the min of the
  // current value and the version from the graph we are about to
  // import.
  //
  // Note: to match Run() semantics, we should re-run shape inference
  // on the entire graph if the producer version has changed. For now
  // we log the warning above.
  refiner->set_graph_def_version(
      std::min(refiner->graph_def_version(), gdef.versions().producer()));

  if (results == nullptr) {
    return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
                                       &gdef.library(), g, refiner, nullptr,
                                       nullptr, nullptr);
  } else {
    return GraphConstructor::Construct(
        opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
        &results->return_tensors, &results->return_nodes,
        &results->missing_unused_input_map_keys);
  }
}
1558
CopyGraph(const Graph & src,Graph * dest)1559 void CopyGraph(const Graph& src, Graph* dest) {
1560 for (Node* n : dest->nodes()) {
1561 CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
1562 }
1563
1564 // Copy GraphDef versions
1565 dest->set_versions(src.versions());
1566
1567 // Copy the nodes.
1568 // "Node in src" -> "Node in *dest"
1569 gtl::FlatMap<const Node*, Node*> node_map;
1570 node_map[src.source_node()] = dest->source_node();
1571 node_map[src.sink_node()] = dest->sink_node();
1572 for (Node* n : src.op_nodes()) {
1573 node_map[n] = dest->CopyNode(n);
1574 }
1575
1576 // Copy the edges
1577 for (const Edge* e : src.edges()) {
1578 Node* src_copy = node_map[e->src()];
1579 Node* dst_copy = node_map[e->dst()];
1580 dest->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
1581 }
1582 }
1583
1584 } // namespace tensorflow
1585