1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/graph.h"
17
18 #include <vector>
19
20 #include "tensorflow/core/framework/graph.pb.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/framework/op_def_builder.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/graph/graph_node_util.h"
27 #include "tensorflow/core/graph/while_context.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/gtl/map_util.h"
30 #include "tensorflow/core/lib/hash/hash.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/lib/strings/stringprintf.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/public/version.h"
35
36 namespace tensorflow {
37
// Slot number used on both ends of a control edge; distinguishes control
// dependencies from data edges (which use non-negative slot indices).
const int Graph::kControlSlot = -1;
39
// Bundles what a Node knows about itself: the (unowned) OpDef it was created
// from, its NodeDef, and the input/output types derived from the pair.
// Instances are held by shared_ptr and may be shared between Nodes (see
// Node::MaybeCopyOnWrite), which is why the type vectors are const.
struct NodeProperties {
 public:
  NodeProperties(const OpDef* op_def, NodeDef node_def,
                 const DataTypeSlice inputs, const DataTypeSlice outputs)
      : op_def(op_def),
        node_def(std::move(node_def)),
        input_types(inputs.begin(), inputs.end()),
        output_types(outputs.begin(), outputs.end()) {}

  const OpDef* op_def;  // not owned
  NodeDef node_def;
  const DataTypeVector input_types;
  const DataTypeVector output_types;
};
54
55 // Node
56
// Expands to two table entries: one for the op itself and one for its
// reference-typed variant ("Ref" + name).
#define REF_CLASS(key, value) \
  {key, value}, { "Ref" key, value }

// Static op-name -> NodeClass lookup table. Intentionally leaked (allocated
// with new and never destroyed) so lookups remain safe during shutdown.
const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
    *new std::unordered_map<string, Node::NodeClass>({
        // Keep in same order as NodeClass values
        REF_CLASS("Switch", NC_SWITCH),
        REF_CLASS("_SwitchN", NC_SWITCH),
        REF_CLASS("Merge", NC_MERGE),
        REF_CLASS("Enter", NC_ENTER),
        REF_CLASS("Exit", NC_EXIT),
        REF_CLASS("NextIteration", NC_NEXT_ITERATION),
        {"LoopCond", NC_LOOP_COND},
        {"ControlTrigger", NC_CONTROL_TRIGGER},
        {"_Send", NC_SEND},
        {"_HostSend", NC_HOST_SEND},
        {"_Recv", NC_RECV},
        {"_HostRecv", NC_HOST_RECV},
        {"Const", NC_CONSTANT},
        {"HostConst", NC_CONSTANT},
        {"Variable", NC_VARIABLE},
        {"VariableV2", NC_VARIABLE},
        REF_CLASS("Identity", NC_IDENTITY),
        {"GetSessionHandle", NC_GET_SESSION_HANDLE},
        {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
        {"GetSessionTensor", NC_GET_SESSION_TENSOR},
        {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
        {"Size", NC_METADATA},
        {"Shape", NC_METADATA},
        {"Rank", NC_METADATA},
        {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
        {"CollectiveReduce", NC_COLLECTIVE},
        {"CollectiveBcastSend", NC_COLLECTIVE},
        {"CollectiveBcastRecv", NC_COLLECTIVE},
        {"CollectiveGather", NC_COLLECTIVE},
        {"FakeParam", NC_FAKE_PARAM},
        {"PartitionedCall", NC_PARTITIONED_CALL},
        {"StatefulPartitionedCall", NC_PARTITIONED_CALL},
        {"SymbolicGradient", NC_SYMBOLIC_GRADIENT},
        {"If", NC_IF},
        {"StatelessIf", NC_IF},
        {"While", NC_WHILE},
        {"StatelessWhile", NC_WHILE},
        // Not using the constants defined in FunctionLibraryDefinition for the
        // 4 ops below because android inference library does not link
        // tf.function related files.
        {"_Arg", NC_ARG},
        {"_DeviceArg", NC_ARG},
        {"_Retval", NC_RETVAL},
        {"_DeviceRetval", NC_RETVAL},
        {"_XlaMerge", NC_MERGE},
    });

#undef REF_CLASS
111
GetNodeClassForOp(const string & ts)112 Node::NodeClass Node::GetNodeClassForOp(const string& ts) {
113 auto it = kNodeClassTable.find(ts);
114 if (it != kNodeClassTable.end()) {
115 return it->second;
116 } else {
117 return NC_OTHER;
118 }
119 }
120
DebugString() const121 string Node::DebugString() const {
122 string ret = strings::StrCat("{name:'", name(), "' id:", id_);
123 if (IsSource()) {
124 strings::StrAppend(&ret, " source}");
125 } else if (IsSink()) {
126 strings::StrAppend(&ret, " sink}");
127 } else {
128 strings::StrAppend(&ret, " op device:");
129 strings::StrAppend(&ret, "{", assigned_device_name(), "}");
130 strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}");
131 }
132 return ret;
133 }
134
// Constructs an uninitialized Node; Initialize() must be called before the
// node is used. id_/cost_id_ of -1 mark "not yet part of a graph".
Node::Node()
    : id_(-1),
      cost_id_(-1),
      class_(NC_UNINITIALIZED),
      props_(nullptr),
      assigned_device_name_index_(0),
      while_ctx_(nullptr) {}
142
Initialize(int id,int cost_id,std::shared_ptr<NodeProperties> props,bool is_function_op)143 void Node::Initialize(int id, int cost_id,
144 std::shared_ptr<NodeProperties> props,
145 bool is_function_op) {
146 DCHECK_EQ(id_, -1);
147 DCHECK(in_edges_.empty());
148 DCHECK(out_edges_.empty());
149 id_ = id;
150 cost_id_ = cost_id;
151
152 props_ = std::move(props);
153 // Initialize the class_ based on the type string
154 if (is_function_op) {
155 class_ = NC_FUNCTION_OP;
156 } else {
157 class_ = GetNodeClassForOp(props_->node_def.op());
158 }
159 }
160
Clear()161 void Node::Clear() {
162 in_edges_.clear();
163 out_edges_.clear();
164 id_ = -1;
165 cost_id_ = -1;
166 class_ = NC_UNINITIALIZED;
167 props_.reset();
168 assigned_device_name_index_ = 0;
169 }
170
// Recomputes the cached input/output types from the current NodeDef/OpDef
// pair and replaces props_ with a fresh NodeProperties. On failure the old
// properties are left untouched and the error is only logged.
void Node::UpdateProperties() {
  DataTypeVector inputs;
  DataTypeVector outputs;
  Status status =
      InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
  if (!status.ok()) {
    LOG(ERROR) << "Failed at updating node: " << status;
    return;
  }
  props_ = std::make_shared<NodeProperties>(props_->op_def, props_->node_def,
                                            inputs, outputs);
}
183
// Simple forwarding accessors into the shared NodeProperties / NodeDef.
const string& Node::name() const { return props_->node_def.name(); }
const string& Node::type_string() const { return props_->node_def.op(); }
const NodeDef& Node::def() const { return props_->node_def; }
const OpDef& Node::op_def() const { return *props_->op_def; }

int32 Node::num_inputs() const { return props_->input_types.size(); }
DataType Node::input_type(int32 i) const { return props_->input_types[i]; }
const DataTypeVector& Node::input_types() const { return props_->input_types; }

int32 Node::num_outputs() const { return props_->output_types.size(); }
DataType Node::output_type(int32 o) const { return props_->output_types[o]; }
const DataTypeVector& Node::output_types() const {
  return props_->output_types;
}

AttrSlice Node::attrs() const { return AttrSlice(def()); }

// Inputs exactly as recorded in the NodeDef (e.g. "node:1", "^ctrl").
const protobuf::RepeatedPtrField<string>& Node::requested_inputs() const {
  return def().input();
}

const string& Node::requested_device() const { return def().device(); }

// Ranges over the nodes on the far end of out_edges_ / in_edges_.
gtl::iterator_range<NeighborIter> Node::out_nodes() const {
  return gtl::make_range(NeighborIter(out_edges_.begin(), false),
                         NeighborIter(out_edges_.end(), false));
}

gtl::iterator_range<NeighborIter> Node::in_nodes() const {
  return gtl::make_range(NeighborIter(in_edges_.begin(), true),
                         NeighborIter(in_edges_.end(), true));
}
216
MaybeCopyOnWrite()217 void Node::MaybeCopyOnWrite() {
218 // NodeProperties may be shared between Nodes. Make a copy if so.
219 if (!props_.unique()) {
220 props_ = std::make_shared<NodeProperties>(*props_);
221 }
222 }
223
AddAttrHelper(const string & name)224 AttrValue* Node::AddAttrHelper(const string& name) {
225 MaybeCopyOnWrite();
226 return &((*props_->node_def.mutable_attr())[name]);
227 }
228
ClearAttr(const string & name)229 void Node::ClearAttr(const string& name) {
230 MaybeCopyOnWrite();
231 (*props_->node_def.mutable_attr()).erase(name);
232 }
233
// Renames the node in its NodeDef (copy-on-write protected).
void Node::set_name(string name) {
  MaybeCopyOnWrite();
  props_->node_def.set_name(std::move(name));
}

// Records the device the user requested. This may differ from the device the
// placer eventually assigns (see assigned_device_name()).
void Node::set_requested_device(const string& device) {
  MaybeCopyOnWrite();
  props_->node_def.set_device(device);
}
243
set_original_node_names(const std::vector<string> & names)244 void Node::set_original_node_names(const std::vector<string>& names) {
245 MaybeCopyOnWrite();
246 props_->node_def.mutable_experimental_debug_info()
247 ->clear_original_node_names();
248 if (!names.empty()) {
249 *props_->node_def.mutable_experimental_debug_info()
250 ->mutable_original_node_names() = {names.begin(), names.end()};
251 }
252 }
253
input_edge(int idx,const Edge ** e) const254 Status Node::input_edge(int idx, const Edge** e) const {
255 if (idx < 0 || idx >= num_inputs()) {
256 return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ",
257 name(), " only has ", num_inputs(),
258 " inputs.");
259 }
260
261 // This does a linear search over the edges. In the common case,
262 // the number of elements is small enough that this search isn't
263 // expensive. Should it become a bottleneck, one can make an
264 // optimization where, if the number of edges is small, we use
265 // linear iteration, and if the number of edges is large, we perform
266 // an indexing step during construction that keeps an array of Edges
267 // indexed by pointer. This would keep the size of each Node small
268 // in the common case but make this function faster when the number
269 // of edges is large.
270 for (const Edge* edge : in_edges()) {
271 if (edge->dst_input() == idx) {
272 *e = edge;
273 return Status::OK();
274 }
275 }
276
277 return errors::NotFound("Could not find input edge ", idx, " for ", name());
278 }
279
280 // Returns a vector of the non-control input edges to a node, indexed by ID.
input_edges(std::vector<const Edge * > * input_edges) const281 Status Node::input_edges(std::vector<const Edge*>* input_edges) const {
282 input_edges->clear();
283 input_edges->resize(num_inputs(), nullptr);
284
285 for (const Edge* edge : in_edges()) {
286 if (edge->IsControlEdge()) continue;
287 if (edge->dst_input() < 0 || edge->dst_input() >= num_inputs()) {
288 return errors::Internal("Invalid edge input number ", edge->dst_input());
289 }
290 if ((*input_edges)[edge->dst_input()] != nullptr) {
291 return errors::Internal("Duplicate edge input number: ",
292 edge->dst_input());
293 }
294 (*input_edges)[edge->dst_input()] = edge;
295 }
296
297 for (int i = 0; i < num_inputs(); ++i) {
298 if ((*input_edges)[i] == nullptr) {
299 return errors::InvalidArgument("Missing edge input number: ", i);
300 }
301 }
302 return Status::OK();
303 }
304
input_node(int idx,Node ** n) const305 Status Node::input_node(int idx, Node** n) const {
306 const Edge* e;
307 TF_RETURN_IF_ERROR(input_edge(idx, &e));
308 if (e == nullptr) {
309 *n = nullptr;
310 } else {
311 *n = e->src();
312 }
313 return Status::OK();
314 }
315
input_node(int idx,const Node ** const_n) const316 Status Node::input_node(int idx, const Node** const_n) const {
317 Node* n;
318 TF_RETURN_IF_ERROR(input_node(idx, &n));
319 *const_n = n;
320 return Status::OK();
321 }
322
// Resolves input slot `idx` to the (source node, output slot) pair feeding it.
Status Node::input_tensor(int idx, OutputTensor* t) const {
  const Edge* e;
  TF_RETURN_IF_ERROR(input_edge(idx, &e));
  DCHECK(e != nullptr);
  *t = OutputTensor(e->src(), e->src_output());
  return Status::OK();
}
330
// NodeDebugInfo

// Both convenience constructors delegate to the three-argument form below.
NodeDebugInfo::NodeDebugInfo(const Node& n) : NodeDebugInfo(n.def()) {}
NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef)
    : NodeDebugInfo(ndef.name(), ndef.has_experimental_debug_info(),
                    ndef.experimental_debug_info()) {}
NodeDebugInfo::NodeDebugInfo(
    StringPiece node_name, bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info)
    : name(node_name) {
  // Only copy original node names when the NodeDef actually carried debug
  // info; otherwise the list stays empty.
  if (has_experimental_debug_info) {
    const auto& names = experimental_debug_info.original_node_names();
    original_node_names.assign(names.begin(), names.end());
  }
}
346
// InputTensor

// Two InputTensors are equal iff they name the same slot on the same node
// (pointer identity on the node).
bool InputTensor::operator==(const InputTensor& other) const {
  return node == other.node && index == other.index;
}

// Hash combines the node pointer's hash with the slot index's hash.
uint64 InputTensor::Hash::operator()(InputTensor const& s) const {
  return Hash64Combine(std::hash<const Node*>()(s.node),
                       std::hash<int>()(s.index));
}

// OutputTensor

// Same equality/hash scheme as InputTensor, for output-side endpoints.
bool OutputTensor::operator==(const OutputTensor& other) const {
  return node == other.node && index == other.index;
}

uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const {
  return Hash64Combine(std::hash<const Node*>()(s.node),
                       std::hash<int>()(s.index));
}
368
// Graph

// Builds an empty graph over the given op registry. Every graph starts with
// a _SOURCE and a _SINK node (ids kSourceId and kSinkId) joined by a control
// edge.
Graph::Graph(const OpRegistryInterface* ops)
    : ops_(ops, FunctionDefLibrary()),
      versions_(new VersionDef),
      arena_(8 << 10 /* 8kB */) {
  versions_->set_producer(TF_GRAPH_DEF_VERSION);
  versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);

  // Initialize the name interning table for assigned_device_name. The empty
  // device name must occupy index 0 (InternDeviceName relies on this).
  device_names_.push_back("");
  DCHECK_EQ(0, InternDeviceName(""));

  // Source and sink have no endpoints, just control edges.
  NodeDef def;
  def.set_name("_SOURCE");
  def.set_op("NoOp");
  Status status;
  Node* source = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(source->id(), kSourceId);

  // Reuses `def` with only the name changed; the op is still "NoOp".
  def.set_name("_SINK");
  Node* sink = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(sink->id(), kSinkId);

  AddControlEdge(source, sink);
}
398
// Builds an empty graph and merges in `flib_def`'s functions. CHECK-fails if
// the library cannot be added.
Graph::Graph(const FunctionLibraryDefinition& flib_def)
    : Graph(flib_def.default_registry()) {
  // Need a new-enough consumer to support the functions we add to the graph.
  if (flib_def.num_functions() > 0 && versions_->min_consumer() < 12) {
    versions_->set_min_consumer(12);
  }
  Status s = ops_.AddLibrary(flib_def);
  CHECK(s.ok()) << s.error_message();
}
408
Graph::~Graph() {
  // Manually call the destructors for all the Nodes we constructed using
  // placement new.
  for (Node* node : nodes_) {
    if (node != nullptr) {
      node->~Node();
    }
  }
  // Nodes on the free list were Clear()ed but still need their destructors
  // run; they were also placement-new'd into the arena.
  for (Node* node : free_nodes_) {
    node->~Node();
  }
  // Edges have no destructor, and we arena-allocated them, so no need to
  // destroy them.
}
423
// Accessors for the graph's version info (producer / min consumer).
const VersionDef& Graph::versions() const { return *versions_; }
void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }
426
// Creates a Node from `node_def`: looks up its op in the registry, derives
// input/output types, and allocates the node. On failure sets *status (with
// the NodeDef attached for context where available) and returns nullptr.
Node* Graph::AddNode(NodeDef node_def, Status* status) {
  const OpRegistrationData* op_reg_data;
  status->Update(ops_.LookUp(node_def.op(), &op_reg_data));
  if (!status->ok()) return nullptr;

  DataTypeVector inputs;
  DataTypeVector outputs;
  status->Update(
      InOutTypesForNode(node_def, op_reg_data->op_def, &inputs, &outputs));
  if (!status->ok()) {
    *status = AttachDef(*status, node_def);
    return nullptr;
  }

  Node* node = AllocateNode(
      std::make_shared<NodeProperties>(&op_reg_data->op_def,
                                       std::move(node_def), inputs, outputs),
      nullptr, op_reg_data->is_function_op);
  return node;
}
447
// Adds a copy of `node` (possibly from another graph) to this graph, sharing
// its NodeProperties when possible. Source/sink nodes cannot be copied, and
// edges are not copied here.
Node* Graph::CopyNode(const Node* node) {
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());
  Node* copy =
      AllocateNode(node->props_, node, node->class_ == Node::NC_FUNCTION_OP);
  copy->set_assigned_device_name(node->assigned_device_name());

  // Since the OpDef of a function may be owned by the Graph that owns 'node',
  // relookup the OpDef in the target graph. If it differs, then clone the
  // node properties with the updated OpDef.
  const OpDef* op_def;
  TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &op_def));
  if (op_def != node->props_->op_def) {
    copy->MaybeCopyOnWrite();
    copy->props_->op_def = op_def;
  }

  return copy;
}
467
// Removes `node` and all of its incident edges from the graph, returning
// node and edge objects to the free lists for reuse. Source/sink nodes must
// never be removed.
void Graph::RemoveNode(Node* node) {
  TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());

  // Remove any edges involving this node. Each edge is detached from the
  // node at its other end, its slot in edges_ is tombstoned with nullptr,
  // and the Edge object is recycled.
  for (const Edge* e : node->in_edges_) {
    CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
    edges_[e->id_] = nullptr;
    RecycleEdge(e);
    --num_edges_;
  }
  node->in_edges_.clear();
  for (const Edge* e : node->out_edges_) {
    CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
    edges_[e->id_] = nullptr;
    RecycleEdge(e);
    --num_edges_;
  }
  node->out_edges_.clear();
  ReleaseNode(node);
}
490
// Adds an edge from source:x to dest:y, reusing a recycled Edge object when
// available. Returns the newly created edge.
const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
  TF_DCHECK_OK(IsValidNode(source)) << source->DebugString();
  TF_DCHECK_OK(IsValidNode(dest)) << dest->DebugString();

  // source/sink must only be linked via control slots, and
  // control slots must only be linked to control slots.
  if (source == source_node() || dest == sink_node() || x == kControlSlot ||
      y == kControlSlot) {
    DCHECK_EQ(x, kControlSlot) << source->DebugString();
    DCHECK_EQ(y, kControlSlot) << dest->DebugString();
  }

  Edge* e = nullptr;
  if (free_edges_.empty()) {
    e = new (arena_.Alloc(sizeof(Edge))) Edge;  // placement new
  } else {
    e = free_edges_.back();
    free_edges_.pop_back();
  }
  // Edge ids index into edges_; recycled Edge objects get a fresh id.
  e->id_ = edges_.size();
  e->src_ = source;
  e->dst_ = dest;
  e->src_output_ = x;
  e->dst_input_ = y;
  CHECK(source->out_edges_.insert(e).second);
  CHECK(dest->in_edges_.insert(e).second);
  edges_.push_back(e);
  ++num_edges_;
  return e;
}
521
// Detaches `e` from both endpoints, tombstones its slot in edges_, and
// recycles the Edge object. `e` must be a live edge of this graph.
void Graph::RemoveEdge(const Edge* e) {
  TF_DCHECK_OK(IsValidNode(e->src_)) << e->src_->DebugString();
  TF_DCHECK_OK(IsValidNode(e->dst_)) << e->dst_->DebugString();
  CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
  CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
  CHECK_EQ(e, edges_[e->id_]);
  CHECK_GT(num_edges_, 0);

  edges_[e->id_] = nullptr;
  RecycleEdge(e);
  --num_edges_;
}
534
// Puts `e` on the free list so AddEdge can reuse its storage; edges are
// arena-allocated and never individually freed.
void Graph::RecycleEdge(const Edge* e) {
  free_edges_.push_back(const_cast<Edge*>(e));
}
538
// Adds a control edge from `source` to `dest` and, for ordinary nodes,
// records a "^source" entry in dest's NodeDef inputs. When
// `allow_duplicates` is false and an equivalent control edge already exists,
// returns nullptr without adding anything.
const Edge* Graph::AddControlEdge(Node* source, Node* dest,
                                  bool allow_duplicates) {
  if (!allow_duplicates) {
    for (const Edge* edge : dest->in_edges()) {
      if (edge->IsControlEdge() && edge->src() == source) {
        // The requested edge already exists.
        return nullptr;
      }
    }
  }
  // Modify dest's NodeDef if necessary.
  if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) {
    // Check if this input is already in dest's NodeDef.
    const string new_input = strings::StrCat("^", source->name());
    bool input_exists = false;
    for (const string& input : dest->props_->node_def.input()) {
      if (input == new_input) {
        input_exists = true;
        break;
      }
    }
    if (!input_exists) {
      dest->MaybeCopyOnWrite();
      dest->props_->node_def.add_input(new_input);
    }
  }
  return AddEdge(source, kControlSlot, dest, kControlSlot);
}
567
// Removes control edge `e` and erases the matching "^src" entry from the
// destination's NodeDef inputs. Source/sink endpoints carry no NodeDef
// entry, so no fix-up is needed for them.
void Graph::RemoveControlEdge(const Edge* e) {
  if (!e->src_->IsSource() && !e->dst_->IsSink()) {
    e->dst_->MaybeCopyOnWrite();
    string e_src_name = strings::StrCat("^", e->src_->name());
    auto* inputs = e->dst_->props_->node_def.mutable_input();
    // Only the first matching entry is erased.
    for (auto it = inputs->begin(); it != inputs->end(); ++it) {
      if (*it == e_src_name) {
        inputs->erase(it);
        break;
      }
    }
  }
  RemoveEdge(e);
}
582
583 namespace {
FindEdge(const Node * dst,int index)584 const Edge* FindEdge(const Node* dst, int index) {
585 for (const Edge* e : dst->in_edges()) {
586 if (e->dst_input() == index) return e;
587 }
588 return nullptr;
589 }
590 } // namespace
591
// Redirects dst's input `dst_index` to come from new_src:new_src_index,
// updating both the edge set and dst's NodeDef input list.
Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
                         int dst_index) {
  TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
  TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
  const Edge* e = FindEdge(dst, dst_index);
  if (e == nullptr) {
    return errors::InvalidArgument("Couldn't find edge to ",
                                   FormatNodeForError(*dst));
  }
  RemoveEdge(e);
  AddEdge(new_src, new_src_index, dst, dst_index);
  dst->MaybeCopyOnWrite();
  // Rewrite the corresponding NodeDef input entry to "name:slot".
  (*dst->props_->node_def.mutable_input())[dst_index] =
      strings::StrCat(new_src->name(), ":", new_src_index);
  return Status::OK();
}
608
AddWhileInputHack(Node * new_src,int new_src_index,Node * dst)609 Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) {
610 if (!dst->IsWhileNode()) {
611 return errors::Internal(
612 "dst argument to AddWhileEdgeHack should be a While op, got: ",
613 dst->DebugString());
614 }
615 TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
616 // Find the current number of data inputs. We'll add the new edge to the next
617 // missing data input.
618 int dst_index = 0;
619 for (const Edge* edge : dst->in_edges()) {
620 if (edge->IsControlEdge()) continue;
621 ++dst_index;
622 }
623 TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
624 AddEdge(new_src, new_src_index, dst, dst_index);
625 dst->MaybeCopyOnWrite();
626 dst->props_->node_def.add_input(
627 strings::StrCat(new_src->name(), ":", new_src_index));
628 return Status::OK();
629 }
630
AddFunctionLibrary(const FunctionDefLibrary & fdef_lib)631 Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
632 // Need a new-enough consumer to support the functions we add to the graph.
633 if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
634 versions_->set_min_consumer(12);
635 }
636 return ops_.AddLibrary(fdef_lib);
637 }
638
639 namespace {
640
AddInput(NodeDef * dst,StringPiece src_name,int src_slot)641 void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
642 if (src_slot == Graph::kControlSlot) {
643 dst->add_input(strings::StrCat("^", src_name));
644 } else if (src_slot == 0) {
645 dst->add_input(src_name.data(), src_name.size());
646 } else {
647 dst->add_input(strings::StrCat(src_name, ":", src_slot));
648 }
649 }
650
651 } // namespace
652
// Serializes the entire graph (all node ids) into `graph_def`.
void Graph::ToGraphDef(GraphDef* graph_def) const {
  ToGraphDefSubRange(graph_def, 0);
}
656
ToGraphDefDebug() const657 GraphDef Graph::ToGraphDefDebug() const {
658 GraphDef ret;
659 ToGraphDef(&ret);
660 return ret;
661 }
662
// Serializes nodes with id >= from_node_id into `graph_def`, together with
// the graph's versions and function library. Each node's inputs are emitted
// with data inputs in slot order followed by control inputs sorted by source
// name, as GraphDef requires.
void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
  graph_def->Clear();
  *graph_def->mutable_versions() = versions();
  *graph_def->mutable_library() = ops_.ToProto();

  graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id));

  std::vector<const Edge*>
      inputs;  // Construct this outside the loop for speed.
  for (auto id = from_node_id; id < num_node_ids(); ++id) {
    const Node* node = FindNodeId(id);
    // Skip removed (tombstoned) slots and the source/sink nodes.
    if (node == nullptr || !node->IsOp()) continue;
    NodeDef* node_def = graph_def->add_node();
    *node_def = node->def();

    // Use the node's assigned device, if any, instead of the device requested
    // in the NodeDef.
    if (!node->assigned_device_name().empty()) {
      node_def->set_device(node->assigned_device_name());
    }

    // Get the inputs for this Node. We make sure control inputs are
    // after data inputs, as required by GraphDef. The first num_inputs()
    // slots hold data edges by destination slot; control edges are appended
    // after them.
    inputs.clear();
    inputs.resize(node->num_inputs(), nullptr);
    for (const Edge* edge : node->in_edges()) {
      if (edge->IsControlEdge()) {
        inputs.push_back(edge);
      } else {
        DCHECK(edge->dst_input() < inputs.size())
            << "Edge " << edge->DebugString()
            << " is overflowing the expected number of inputs ("
            << node->num_inputs() << ") for node " << node->DebugString();
        CHECK(inputs[edge->dst_input()] == nullptr)
            << "Edge " << edge->src()->name() << "->" << edge->dst()->name()
            << " conflicts with pre-existing input edge "
            << inputs[edge->dst_input()]->src()->name() << "->"
            << inputs[edge->dst_input()]->dst()->name();

        inputs[edge->dst_input()] = edge;
      }
    }
    // Sort the control inputs for more predictable serialization.
    std::sort(inputs.begin() + node->num_inputs(), inputs.end(),
              [](const Edge* a, const Edge* b) -> bool {
                return a->src()->name() < b->src()->name();
              });
    node_def->clear_input();
    node_def->mutable_input()->Reserve(inputs.size());

    for (size_t i = 0; i < inputs.size(); ++i) {
      const Edge* edge = inputs[i];
      if (edge == nullptr) {
        // No edge for this slot: fall back to the originally requested
        // input string, or an empty placeholder when none was recorded.
        if (i < node->requested_inputs().size()) {
          node_def->add_input(node->requested_inputs()[i]);
        } else {
          node_def->add_input("");
        }
      } else {
        const Node* src = edge->src();
        // Inputs fed by non-op nodes (source/sink) are not serialized.
        if (!src->IsOp()) continue;
        AddInput(node_def, src->name(), edge->src_output());
      }
    }
  }
}
729
NewName(StringPiece prefix)730 string Graph::NewName(StringPiece prefix) {
731 return strings::StrCat(prefix, "/_", name_counter_++);
732 }
733
// Returns OK iff `node` is non-null and is the node currently registered
// under its id in this graph (i.e. it belongs here and has not been removed).
Status Graph::IsValidNode(const Node* node) const {
  if (node == nullptr) {
    return errors::InvalidArgument("Node is null");
  }
  const int id = node->id();
  if (id < 0) {
    return errors::InvalidArgument("node id ", id, " is less than zero");
  }
  if (static_cast<size_t>(id) >= nodes_.size()) {
    return errors::InvalidArgument(
        "node id ", id, " is >= than number of nodes in graph ", nodes_.size());
  }
  // nodes_[id] is nullptr for removed nodes and a different pointer for nodes
  // from another graph.
  if (nodes_[id] != node) {
    return errors::InvalidArgument("Node with id ", id,
                                   " is different from the passed in node. "
                                   "Does it belong to a different graph?");
  }
  return Status::OK();
}
753
// Returns OK iff `node` is valid in this graph and `idx` is a real output
// slot of it.
Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
  TF_RETURN_IF_ERROR(IsValidNode(node));
  if (idx >= node->num_outputs() || idx < 0) {
    return errors::OutOfRange("Node '", node->name(), "' (type: '",
                              node->op_def().name(),
                              "', num of outputs: ", node->num_outputs(),
                              ") does not have ", "output ", idx);
  }
  return Status::OK();
}
764
// Returns OK iff `node` is valid in this graph and `idx` is a real input
// slot of it. Mirrors IsValidOutputTensor above.
Status Graph::IsValidInputTensor(const Node* node, int idx) const {
  TF_RETURN_IF_ERROR(IsValidNode(node));
  if (idx >= node->num_inputs() || idx < 0) {
    return errors::OutOfRange("Node '", node->name(), "' (type: '",
                              node->op_def().name(),
                              "', num of inputs: ", node->num_inputs(),
                              ") does not have ", "input ", idx);
  }
  return Status::OK();
}
775
// Constructs (or recycles from the free list) a Node, assigns it the next
// available id, and registers it in nodes_. The cost id is inherited from
// `cost_node` when one is given; otherwise it defaults to the node's own id.
Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
                          const Node* cost_node, bool is_function_op) {
  Node* node = nullptr;
  if (free_nodes_.empty()) {
    node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new
  } else {
    node = free_nodes_.back();
    free_nodes_.pop_back();
  }
  node->graph_ = this;
  const int id = nodes_.size();
  int cost_id = cost_node ? cost_node->cost_id() : id;
  node->Initialize(id, cost_id, std::move(props), is_function_op);
  nodes_.push_back(node);
  ++num_nodes_;
  return node;
}
793
// Unregisters `node` (tombstoning its slot in nodes_ with nullptr), clears
// its state, and returns it to the free list for reuse by AllocateNode.
void Graph::ReleaseNode(Node* node) {
  TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
  nodes_[node->id()] = nullptr;
  free_nodes_.push_back(node);
  --num_nodes_;
  node->Clear();
}
801
// Ensures that 'device_name' is present in the device name table, and returns
// the index of that device name. The index is stable, and can be used in
// calls to Node::set_assigned_device_name_index().
int Graph::InternDeviceName(const string& device_name) {
  // Special case, very common. Also, this allows us to use a single map
  // lookup below, instead of two. The 'if (index_cell > 0)' test below
  // relies on this check.
  if (device_name.empty()) {
    return 0;
  }

  int& index_cell = device_names_map_[device_name];
  if (index_cell > 0) {
    return index_cell;
  }

  // operator[] above already inserted the new key with value 0, so the map's
  // size now counts it. device_names_ additionally holds "" at slot 0 (and
  // the map does not), so the new map size equals the next free slot in
  // device_names_.
  const int index = device_names_map_.size();
  index_cell = index;
  device_names_.push_back(device_name);
  return index;
}
823
// Registers a WhileContext keyed by `frame_name`, failing with
// InvalidArgument if one already exists. On success, *result points at the
// stored context, which is owned by this graph.
Status Graph::AddWhileContext(StringPiece frame_name,
                              std::vector<Node*> enter_nodes,
                              std::vector<Node*> exit_nodes,
                              OutputTensor cond_output,
                              std::vector<OutputTensor> body_inputs,
                              std::vector<OutputTensor> body_outputs,
                              WhileContext** result) {
  auto pair = while_ctxs_.insert(std::pair<string, WhileContext>(
      string(frame_name),
      WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
                   cond_output, std::move(body_inputs),
                   std::move(body_outputs))));
  if (!pair.second) {
    // Insertion failed: a context with this frame name already exists.
    *result = nullptr;
    return errors::InvalidArgument("WhileContext with frame name '", frame_name,
                                   "' already exists");
  }
  *result = &pair.first->second;
  return Status::OK();
}
844
BuildNodeNameIndex() const845 std::unordered_map<string, Node*> Graph::BuildNodeNameIndex() const {
846 std::unordered_map<string, Node*> result;
847 for (Node* n : nodes()) {
848 result[n->name()] = n;
849 }
850 return result;
851 }
852
// Renders the edge as "[id=ID src:slot -> dst:slot]" for logging.
string Edge::DebugString() const {
  return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
                         src_output_, dst_->name().c_str(), dst_input_);
}
857
858 } // namespace tensorflow
859