/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/graph.h"

#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

const int Graph::kControlSlot = -1;

// Node
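
// Maps an op type string to the NodeClass used for quick kind dispatch.
// REF_CLASS registers both "Foo" and "RefFoo" so reference-typed variants of
// an op are classified identically.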
Node::NodeClass Node::GetNodeClassForOp(const string& ts) {
  static const absl::flat_hash_map<string, Node::NodeClass>* kNodeClassTable =
#define REF_CLASS(key, value) \
  {key, value}, { "Ref" key, value }
      new absl::flat_hash_map<string, Node::NodeClass>({
          // Keep in same order as NodeClass values
          REF_CLASS("Switch", NC_SWITCH),
          REF_CLASS("_SwitchN", NC_SWITCH),
          REF_CLASS("Merge", NC_MERGE),
          REF_CLASS("Enter", NC_ENTER),
          REF_CLASS("Exit", NC_EXIT),
          REF_CLASS("NextIteration", NC_NEXT_ITERATION),
          {"LoopCond", NC_LOOP_COND},
          {"ControlTrigger", NC_CONTROL_TRIGGER},
          {"_Send", NC_SEND},
          {"_HostSend", NC_HOST_SEND},
          {"_Recv", NC_RECV},
          {"_HostRecv", NC_HOST_RECV},
          {"Const", NC_CONSTANT},
          {"HostConst", NC_CONSTANT},
          {"Variable", NC_VARIABLE},
          {"VariableV2", NC_VARIABLE},
          REF_CLASS("Identity", NC_IDENTITY),
          {"GetSessionHandle", NC_GET_SESSION_HANDLE},
          {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
          {"GetSessionTensor", NC_GET_SESSION_TENSOR},
          {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
          {"Size", NC_METADATA},
          {"Shape", NC_METADATA},
          {"Rank", NC_METADATA},
          {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
          {"CollectiveReduce", NC_COLLECTIVE},
          {"CollectiveBcastSend", NC_COLLECTIVE},
          {"CollectiveBcastRecv", NC_COLLECTIVE},
          {"CollectiveGather", NC_COLLECTIVE},
          {"FakeParam", NC_FAKE_PARAM},
          {"PartitionedCall", NC_PARTITIONED_CALL},
          {"StatefulPartitionedCall", NC_PARTITIONED_CALL},
          {"SymbolicGradient", NC_SYMBOLIC_GRADIENT},
          {"If", NC_IF},
          {"StatelessIf", NC_IF},
          {"While", NC_WHILE},
          {"StatelessWhile", NC_WHILE},
          {"Case", NC_CASE},
          {"StatelessCase", NC_CASE},
          // Not using the constants defined in FunctionLibraryDefinition for
          // the four ops below because the Android inference library does not
          // link tf.function related files.
          {"_Arg", NC_ARG},
          {"_DeviceArg", NC_ARG},
          {"_Retval", NC_RETVAL},
          {"_DeviceRetval", NC_RETVAL},
          {"_XlaMerge", NC_MERGE},
      });
#undef REF_CLASS

  auto it = kNodeClassTable->find(ts);
  if (it != kNodeClassTable->end()) {
    return it->second;
  } else {
    return NC_OTHER;
  }
}

string Node::DebugString() const {
  string ret = strings::StrCat("{name:'", name(), "' id:", id_);
  if (IsSource()) {
    strings::StrAppend(&ret, " source}");
  } else if (IsSink()) {
    strings::StrAppend(&ret, " sink}");
  } else {
    strings::StrAppend(&ret, " op device:");
    strings::StrAppend(&ret, "{", assigned_device_name(), "}");
    strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}");
  }
  return ret;
}

Node::Node()
    : id_(-1),
      cost_id_(-1),
      class_(NC_UNINITIALIZED),
      props_(nullptr),
      assigned_device_name_index_(0),
      while_ctx_(nullptr) {}

void Node::Initialize(int id, int cost_id,
                      std::shared_ptr<NodeProperties> props,
                      Node::NodeClass node_class) {
  DCHECK_EQ(id_, -1);
  DCHECK(in_edges_.empty());
  DCHECK(out_edges_.empty());
  id_ = id;
  cost_id_ = cost_id;

  props_ = std::move(props);
  class_ = node_class;
}

void Node::Clear() {
  in_edges_.clear();
  out_edges_.clear();
  id_ = -1;
  cost_id_ = -1;
  class_ = NC_UNINITIALIZED;
  props_.reset();
  assigned_device_name_index_ = 0;
}

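// Recomputes the node's input/output types from its NodeDef and OpDef. When
// the NodeProperties are shared with other nodes, a private copy is created
// instead of mutating the shared instance.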
void Node::UpdateProperties() {
  DataTypeVector inputs;
  DataTypeVector outputs;
  Status status =
      InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
  if (!status.ok()) {
    LOG(ERROR) << "Failed to update node: " << status;
    return;
  }
  if (props_->input_types != inputs || props_->output_types != outputs) {
    if (TF_PREDICT_TRUE(props_.use_count() == 1)) {
      props_->input_types = inputs;
      props_->input_types_slice = props_->input_types;
      props_->output_types = outputs;
      props_->output_types_slice = props_->output_types;
    } else {
      props_ = std::make_shared<NodeProperties>(
          props_->op_def, std::move(props_->node_def), inputs, outputs);
    }
  }
}

const string& Node::name() const { return props_->node_def.name(); }
const string& Node::type_string() const { return props_->node_def.op(); }
const NodeDef& Node::def() const { return props_->node_def; }
const OpDef& Node::op_def() const { return *props_->op_def; }

int32 Node::num_inputs() const { return props_->input_types.size(); }
DataType Node::input_type(int32 i) const { return props_->input_types[i]; }
const DataTypeVector& Node::input_types() const { return props_->input_types; }

int32 Node::num_outputs() const { return props_->output_types.size(); }
DataType Node::output_type(int32 o) const { return props_->output_types[o]; }
const DataTypeVector& Node::output_types() const {
  return props_->output_types;
}

AttrSlice Node::attrs() const { return AttrSlice(def()); }

const protobuf::RepeatedPtrField<string>& Node::requested_inputs() const {
  return def().input();
}

const string& Node::requested_device() const { return def().device(); }

gtl::iterator_range<NeighborIter> Node::out_nodes() const {
  return gtl::make_range(NeighborIter(out_edges_.begin(), false),
                         NeighborIter(out_edges_.end(), false));
}

gtl::iterator_range<NeighborIter> Node::in_nodes() const {
  return gtl::make_range(NeighborIter(in_edges_.begin(), true),
                         NeighborIter(in_edges_.end(), true));
}

void Node::MaybeCopyOnWrite() {
  // NodeProperties may be shared between Nodes. Make a copy if so.
  if (!props_.unique()) {
    props_ = std::make_shared<NodeProperties>(*props_);
  }
}

AttrValue* Node::AddAttrHelper(const string& name) {
  MaybeCopyOnWrite();
  return &((*props_->node_def.mutable_attr())[name]);
}

void Node::ClearAttr(const string& name) {
  MaybeCopyOnWrite();
  (*props_->node_def.mutable_attr()).erase(name);
}

void Node::set_name(string name) {
  MaybeCopyOnWrite();
  props_->node_def.set_name(std::move(name));
}

void Node::set_requested_device(const string& device) {
  MaybeCopyOnWrite();
  props_->node_def.set_device(device);
}

void Node::set_original_node_names(const std::vector<string>& names) {
  MaybeCopyOnWrite();
  props_->node_def.mutable_experimental_debug_info()
      ->clear_original_node_names();
  if (!names.empty()) {
    *props_->node_def.mutable_experimental_debug_info()
         ->mutable_original_node_names() = {names.begin(), names.end()};
  }
}

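// Finds the edge feeding data input 'idx' of this node. Returns an error if
// 'idx' is out of range or no such edge exists.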
Status Node::input_edge(int idx, const Edge** e) const {
  if (idx < 0 || idx >= num_inputs()) {
    return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ",
                                   name(), " only has ", num_inputs(),
                                   " inputs.");
  }

  // This does a linear search over the edges. In the common case,
  // the number of elements is small enough that this search isn't
  // expensive. Should it become a bottleneck, one can make an
  // optimization where, if the number of edges is small, we use
  // linear iteration, and if the number of edges is large, we perform
  // an indexing step during construction that keeps an array of Edges
  // indexed by pointer. This would keep the size of each Node small
  // in the common case but make this function faster when the number
  // of edges is large.
  for (const Edge* edge : in_edges()) {
    if (edge->dst_input() == idx) {
      *e = edge;
      return Status::OK();
    }
  }

  return errors::NotFound("Could not find input edge ", idx, " for ", name());
}

// Returns a vector of the non-control input edges to a node, indexed by the
// destination input number.
Status Node::input_edges(std::vector<const Edge*>* input_edges) const {
  input_edges->clear();
  input_edges->resize(num_inputs(), nullptr);

  for (const Edge* edge : in_edges()) {
    if (edge->IsControlEdge()) continue;
    if (edge->dst_input() < 0 || edge->dst_input() >= num_inputs()) {
      return errors::Internal("Invalid edge input number ", edge->dst_input());
    }
    if ((*input_edges)[edge->dst_input()] != nullptr) {
      return errors::Internal("Duplicate edge input number: ",
                              edge->dst_input());
    }
    (*input_edges)[edge->dst_input()] = edge;
  }

  for (int i = 0; i < num_inputs(); ++i) {
    if ((*input_edges)[i] == nullptr) {
      return errors::InvalidArgument("Missing edge input number: ", i);
    }
  }
  return Status::OK();
}

Status Node::input_node(int idx, Node** n) const {
  const Edge* e;
  TF_RETURN_IF_ERROR(input_edge(idx, &e));
  if (e == nullptr) {
    *n = nullptr;
  } else {
    *n = e->src();
  }
  return Status::OK();
}

Status Node::input_node(int idx, const Node** const_n) const {
  Node* n;
  TF_RETURN_IF_ERROR(input_node(idx, &n));
  *const_n = n;
  return Status::OK();
}

Status Node::input_tensor(int idx, OutputTensor* t) const {
  const Edge* e;
  TF_RETURN_IF_ERROR(input_edge(idx, &e));
  DCHECK(e != nullptr);
  *t = OutputTensor(e->src(), e->src_output());
  return Status::OK();
}

// NodeDebugInfo

NodeDebugInfo::NodeDebugInfo(const Node& n) : NodeDebugInfo(n.def()) {}
NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef)
    : NodeDebugInfo(ndef.name(), ndef.has_experimental_debug_info(),
                    ndef.experimental_debug_info()) {}
NodeDebugInfo::NodeDebugInfo(
    StringPiece node_name, bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info)
    : name(node_name) {
  if (has_experimental_debug_info) {
    const auto& names = experimental_debug_info.original_node_names();
    original_node_names.assign(names.begin(), names.end());
  }
}

// InputTensor

bool InputTensor::operator==(const InputTensor& other) const {
  return node == other.node && index == other.index;
}

uint64 InputTensor::Hash::operator()(InputTensor const& s) const {
  return Hash64Combine(std::hash<const Node*>()(s.node),
                       std::hash<int>()(s.index));
}

// OutputTensor

bool OutputTensor::operator==(const OutputTensor& other) const {
  return node == other.node && index == other.index;
}

uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const {
  return Hash64Combine(std::hash<const Node*>()(s.node),
                       std::hash<int>()(s.index));
}

// Graph

Graph::Graph(const OpRegistryInterface* ops)
    : ops_(ops, FunctionDefLibrary()),
      versions_(new VersionDef),
      arena_(8 << 10 /* 8kB */) {
  versions_->set_producer(TF_GRAPH_DEF_VERSION);
  versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);

  // Initialize the name interning table for assigned_device_name.
  device_names_.push_back("");
  DCHECK_EQ(0, InternDeviceName(""));

  // Source and sink have no endpoints, just control edges.
  NodeDef def;
  def.set_name("_SOURCE");
  def.set_op("NoOp");
  Status status;
  Node* source = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(source->id(), kSourceId);

  def.set_name("_SINK");
  Node* sink = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(sink->id(), kSinkId);

  AddControlEdge(source, sink);
}

Graph::Graph(const FunctionLibraryDefinition& flib_def)
    : Graph(flib_def.default_registry()) {
  // Need a new-enough consumer to support the functions we add to the graph.
  if (flib_def.num_functions() > 0 && versions_->min_consumer() < 12) {
    versions_->set_min_consumer(12);
  }
  Status s = ops_.AddLibrary(flib_def);
  CHECK(s.ok()) << s.error_message();
}

Graph::~Graph() {
  // Manually call the destructors for all the Nodes we constructed using
  // placement new.
  for (Node* node : nodes_) {
    if (node != nullptr) {
      node->~Node();
    }
  }
  for (Node* node : free_nodes_) {
    node->~Node();
  }
  // Edges have no destructor, and we arena-allocated them, so no need to
  // destroy them.
}

const VersionDef& Graph::versions() const { return *versions_; }
void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }

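// Copies all op nodes and edges from 'src' into this graph, which must contain
// only its source and sink nodes. GraphDef versions are copied as well.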
void Graph::Copy(const Graph& src) {
  SetConstructionContext(src.GetConstructionContextInternal());
  for (Node* n : nodes()) {
    CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
  }

  // Copy GraphDef versions
  set_versions(src.versions());

  // Copy the nodes.
  // "Node in src" -> "Node in *dest"
  gtl::FlatMap<const Node*, Node*> node_map;
  node_map.reserve(src.num_nodes());
  node_map[src.source_node()] = source_node();
  node_map[src.sink_node()] = sink_node();
  for (Node* n : src.op_nodes()) {
    auto copy = CopyNode(n);
    copy->in_edges_.reserve(n->in_edges().size());
    copy->out_edges_.reserve(n->out_edges().size());
    node_map[n] = copy;
  }

  // Copy the edges
  edges_.reserve(src.num_edges());
  for (const Edge* e : src.edges()) {
    Node* src_copy = node_map[e->src()];
    Node* dst_copy = node_map[e->dst()];
    AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
  }
}

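// Creates a node from 'node_def' after validating its op against the graph's
// op registry and deriving its input/output types. On failure, updates
// '*status' and returns nullptr.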
Node* Graph::AddNode(NodeDef node_def, Status* status) {
  const OpRegistrationData* op_reg_data;
  status->Update(ops_.LookUp(node_def.op(), &op_reg_data));
  if (!status->ok()) return nullptr;

  DataTypeVector inputs;
  DataTypeVector outputs;
  status->Update(
      InOutTypesForNode(node_def, op_reg_data->op_def, &inputs, &outputs));
  if (!status->ok()) {
    *status = AttachDef(*status, node_def);
    return nullptr;
  }

  Node::NodeClass node_class = op_reg_data->is_function_op
                                   ? Node::NC_FUNCTION_OP
                                   : Node::GetNodeClassForOp(node_def.op());

  Node* node = AllocateNode(
      std::make_shared<NodeProperties>(&op_reg_data->op_def,
                                       std::move(node_def), inputs, outputs),
      nullptr, node_class);
  return node;
}

Node* Graph::CopyNode(const Node* node) {
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());
  Node* copy = AllocateNode(node->props_, node, node->class_);
  copy->set_assigned_device_name(node->assigned_device_name());

  // Since the OpDef of a function may be owned by the Graph that owns 'node',
  // relookup the OpDef in the target graph. If it differs, then clone the
  // node properties with the updated OpDef.
  const OpDef* op_def;
  TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &op_def));
  if (op_def != node->props_->op_def) {
    copy->MaybeCopyOnWrite();
    copy->props_->op_def = op_def;
  }
  copy->SetStackTrace(node->GetStackTrace());

  return copy;
}

void Graph::RemoveNode(Node* node) {
  TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());

  // Remove any edges involving this node.
  for (const Edge* e : node->in_edges_) {
    CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
    edges_[e->id_] = nullptr;
    RecycleEdge(e);
    --num_edges_;
  }
  node->in_edges_.clear();
  for (const Edge* e : node->out_edges_) {
    CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
    edges_[e->id_] = nullptr;
    RecycleEdge(e);
    --num_edges_;
  }
  node->out_edges_.clear();
  ReleaseNode(node);
}

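// Adds an edge from 'source':'x' to 'dest':'y', reusing a recycled Edge object
// when one is available and otherwise placement-constructing one in the arena.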
const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
  TF_DCHECK_OK(IsValidNode(source)) << source->DebugString();
  TF_DCHECK_OK(IsValidNode(dest)) << dest->DebugString();

  // source/sink must only be linked via control slots, and
  // control slots must only be linked to control slots.
  if (source == source_node() || dest == sink_node() || x == kControlSlot ||
      y == kControlSlot) {
    DCHECK_EQ(x, kControlSlot) << source->DebugString();
    DCHECK_EQ(y, kControlSlot) << dest->DebugString();
  }

  Edge* e = nullptr;
  if (free_edges_.empty()) {
    e = new (arena_.Alloc(sizeof(Edge))) Edge;  // placement new
  } else {
    e = free_edges_.back();
    free_edges_.pop_back();
  }
  e->id_ = edges_.size();
  e->src_ = source;
  e->dst_ = dest;
  e->src_output_ = x;
  e->dst_input_ = y;
  CHECK(source->out_edges_.insert(e).second);
  CHECK(dest->in_edges_.insert(e).second);
  edges_.push_back(e);
  ++num_edges_;
  return e;
}

void Graph::RemoveEdge(const Edge* e) {
  TF_DCHECK_OK(IsValidNode(e->src_)) << e->src_->DebugString();
  TF_DCHECK_OK(IsValidNode(e->dst_)) << e->dst_->DebugString();
  CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
  CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
  CHECK_EQ(e, edges_[e->id_]);
  CHECK_GT(num_edges_, 0);

  edges_[e->id_] = nullptr;
  RecycleEdge(e);
  --num_edges_;
}

void Graph::RecycleEdge(const Edge* e) {
  free_edges_.push_back(const_cast<Edge*>(e));
}

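// Adds a control edge from 'source' to 'dest'. When 'allow_duplicates' is
// false, returns nullptr if such an edge already exists, and otherwise also
// records the "^source" input in dest's NodeDef (for non-source/sink nodes).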
const Edge* Graph::AddControlEdge(Node* source, Node* dest,
                                  bool allow_duplicates) {
  if (!allow_duplicates) {
    for (const Edge* edge : dest->in_edges()) {
      if (edge->IsControlEdge() && edge->src() == source) {
        // The requested edge already exists.
        return nullptr;
      }
    }
  }
  // Modify dest's NodeDef if necessary.
  if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) {
    // Check if this input is already in dest's NodeDef.
    const string new_input = strings::StrCat("^", source->name());
    bool input_exists = false;
    for (const string& input : dest->props_->node_def.input()) {
      if (input == new_input) {
        input_exists = true;
        break;
      }
    }
    if (!input_exists) {
      dest->MaybeCopyOnWrite();
      dest->props_->node_def.add_input(new_input);
    }
  }
  return AddEdge(source, kControlSlot, dest, kControlSlot);
}

void Graph::RemoveControlEdge(const Edge* e) {
  if (!e->src_->IsSource() && !e->dst_->IsSink()) {
    e->dst_->MaybeCopyOnWrite();
    string e_src_name = strings::StrCat("^", e->src_->name());
    auto* inputs = e->dst_->props_->node_def.mutable_input();
    for (auto it = inputs->begin(); it != inputs->end(); ++it) {
      if (*it == e_src_name) {
        inputs->erase(it);
        break;
      }
    }
  }
  RemoveEdge(e);
}

namespace {
const Edge* FindEdge(const Node* dst, int index) {
  for (const Edge* e : dst->in_edges()) {
    if (e->dst_input() == index) return e;
  }
  return nullptr;
}
}  // namespace

Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
                         int dst_index) {
  TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
  TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
  const Edge* e = FindEdge(dst, dst_index);
  if (e == nullptr) {
    return errors::InvalidArgument("Couldn't find edge to ",
                                   FormatNodeForError(*dst));
  }
  RemoveEdge(e);
  AddEdge(new_src, new_src_index, dst, dst_index);
  dst->MaybeCopyOnWrite();
  (*dst->props_->node_def.mutable_input())[dst_index] =
      strings::StrCat(new_src->name(), ":", new_src_index);
  return Status::OK();
}

Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) {
  if (!dst->IsWhileNode()) {
    return errors::Internal(
        "dst argument to AddWhileInputHack should be a While op, got: ",
        dst->DebugString());
  }
  TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
  // Find the current number of data inputs. We'll add the new edge to the next
  // missing data input.
  int dst_index = 0;
  for (const Edge* edge : dst->in_edges()) {
    if (edge->IsControlEdge()) continue;
    ++dst_index;
  }
  TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
  AddEdge(new_src, new_src_index, dst, dst_index);
  dst->MaybeCopyOnWrite();
  dst->props_->node_def.add_input(
      strings::StrCat(new_src->name(), ":", new_src_index));
  return Status::OK();
}

Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
  // Need a new-enough consumer to support the functions we add to the graph.
  if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
    versions_->set_min_consumer(12);
  }
  return ops_.AddLibrary(fdef_lib);
}

namespace {

void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
  if (src_slot == Graph::kControlSlot) {
    dst->add_input(strings::StrCat("^", src_name));
  } else if (src_slot == 0) {
    dst->add_input(src_name.data(), src_name.size());
  } else {
    dst->add_input(strings::StrCat(src_name, ":", src_slot));
  }
}

}  // namespace

void Graph::ToGraphDef(GraphDef* graph_def) const {
  ToGraphDefSubRange(graph_def, 0);
}

GraphDef Graph::ToGraphDefDebug() const {
  GraphDef ret;
  ToGraphDef(&ret);
  return ret;
}

void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
  graph_def->Clear();
  *graph_def->mutable_versions() = versions();
  *graph_def->mutable_library() = ops_.ToProto();

  graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id));

  std::vector<const Edge*>
      inputs;  // Construct this outside the loop for speed.
  for (auto id = from_node_id; id < num_node_ids(); ++id) {
    const Node* node = FindNodeId(id);
    if (node == nullptr || !node->IsOp()) continue;
    NodeDef* node_def = graph_def->add_node();
    *node_def = node->def();

    // Use the node's assigned device, if any, instead of the device requested
    // in the NodeDef.
    if (!node->assigned_device_name().empty()) {
      node_def->set_device(node->assigned_device_name());
    }

    // Get the inputs for this Node. We make sure control inputs are
    // after data inputs, as required by GraphDef.
    inputs.clear();
    inputs.resize(node->num_inputs(), nullptr);
    for (const Edge* edge : node->in_edges()) {
      if (edge->IsControlEdge()) {
        inputs.push_back(edge);
      } else {
        DCHECK(edge->dst_input() < inputs.size())
            << "Edge " << edge->DebugString()
            << " is overflowing the expected number of inputs ("
            << node->num_inputs() << ") for node " << node->DebugString();
        CHECK(inputs[edge->dst_input()] == nullptr)
            << "Edge " << edge->src()->name() << "->" << edge->dst()->name()
            << " conflicts with pre-existing input edge "
            << inputs[edge->dst_input()]->src()->name() << "->"
            << inputs[edge->dst_input()]->dst()->name();

        inputs[edge->dst_input()] = edge;
      }
    }
    // Sort the control inputs for more predictable serialization.
    std::sort(inputs.begin() + node->num_inputs(), inputs.end(),
              [](const Edge* a, const Edge* b) -> bool {
                return a->src()->name() < b->src()->name();
              });
    node_def->clear_input();
    node_def->mutable_input()->Reserve(inputs.size());

    for (size_t i = 0; i < inputs.size(); ++i) {
      const Edge* edge = inputs[i];
      if (edge == nullptr) {
        if (i < node->requested_inputs().size()) {
          node_def->add_input(node->requested_inputs()[i]);
        } else {
          node_def->add_input("");
        }
      } else {
        const Node* src = edge->src();
        if (!src->IsOp()) continue;
        AddInput(node_def, src->name(), edge->src_output());
      }
    }
  }
}

string Graph::NewName(StringPiece prefix) {
  return strings::StrCat(prefix, "/_", name_counter_++);
}

Status Graph::IsValidNode(const Node* node) const {
  if (node == nullptr) {
    return errors::InvalidArgument("Node is null");
  }
  const int id = node->id();
  if (id < 0) {
    return errors::InvalidArgument("node id ", id, " is less than zero");
  }
  if (static_cast<size_t>(id) >= nodes_.size()) {
    return errors::InvalidArgument(
        "node id ", id, " is >= the number of nodes in the graph ",
        nodes_.size());
  }
  if (nodes_[id] != node) {
    return errors::InvalidArgument("Node with id ", id,
                                   " is different from the passed in node. "
                                   "Does it belong to a different graph?");
  }
  return Status::OK();
}

Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
  TF_RETURN_IF_ERROR(IsValidNode(node));
  if (idx >= node->num_outputs() || idx < 0) {
    return errors::OutOfRange("Node '", node->name(), "' (type: '",
                              node->op_def().name(),
                              "', num of outputs: ", node->num_outputs(),
                              ") does not have ", "output ", idx);
  }
  return Status::OK();
}

Status Graph::IsValidInputTensor(const Node* node, int idx) const {
  TF_RETURN_IF_ERROR(IsValidNode(node));
  if (idx >= node->num_inputs() || idx < 0) {
    return errors::OutOfRange("Node '", node->name(), "' (type: '",
                              node->op_def().name(),
                              "', num of inputs: ", node->num_inputs(),
                              ") does not have ", "input ", idx);
  }
  return Status::OK();
}

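// Allocates a Node, reusing one from 'free_nodes_' when possible and otherwise
// placement-constructing it in the arena, then assigns it the next node id.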
Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
                          const Node* cost_node, Node::NodeClass node_class) {
  Node* node = nullptr;
  if (free_nodes_.empty()) {
    node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new
  } else {
    node = free_nodes_.back();
    free_nodes_.pop_back();
  }
  node->graph_ = this;
  const int id = nodes_.size();
  int cost_id = cost_node ? cost_node->cost_id() : id;
  node->Initialize(id, cost_id, std::move(props), node_class);
  nodes_.push_back(node);
  ++num_nodes_;
  return node;
}

void Graph::ReleaseNode(Node* node) {
  TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
  nodes_[node->id()] = nullptr;
  free_nodes_.push_back(node);
  --num_nodes_;
  node->Clear();
}

// Ensures that 'device_name' is present in the device name table, and returns
// the index of that device name. The index is stable, and can be used in
// calls to Node::set_assigned_device_name_index().
int Graph::InternDeviceName(const string& device_name) {
  // Special case, very common. Also, this allows us to use a single map
  // lookup below, instead of two. The 'if (index_cell > 0)' test below
  // relies on this check.
  if (device_name.empty()) {
    return 0;
  }

  int& index_cell = device_names_map_[device_name];
  if (index_cell > 0) {
    return index_cell;
  }

  const int index = device_names_map_.size();
  index_cell = index;
  device_names_.push_back(device_name);
  return index;
}

Status Graph::AddWhileContext(StringPiece frame_name,
                              std::vector<Node*> enter_nodes,
                              std::vector<Node*> exit_nodes,
                              OutputTensor cond_output,
                              std::vector<OutputTensor> body_inputs,
                              std::vector<OutputTensor> body_outputs,
                              WhileContext** result) {
  auto pair = while_ctxs_.insert(std::pair<string, WhileContext>(
      string(frame_name),
      WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
                   cond_output, std::move(body_inputs),
                   std::move(body_outputs))));
  if (!pair.second) {
    *result = nullptr;
    return errors::InvalidArgument("WhileContext with frame name '", frame_name,
                                   "' already exists");
  }
  *result = &pair.first->second;
  return Status::OK();
}

std::unordered_map<string, Node*> Graph::BuildNodeNameIndex() const {
  std::unordered_map<string, Node*> result;
  for (Node* n : nodes()) {
    result[n->name()] = n;
  }
  return result;
}

string Edge::DebugString() const {
  return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
                         src_output_, dst_->name().c_str(), dst_input_);
}

}  // namespace tensorflow