/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/graph.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

const int Graph::kControlSlot = -1;

// Node

Node::NodeClass Node::GetNodeClassForOp(const std::string& ts) {
  static const absl::flat_hash_map<std::string, Node::NodeClass>*
      kNodeClassTable =
#define REF_CLASS(key, value) \
  {key, value}, { "Ref" key, value }
          new absl::flat_hash_map<std::string, Node::NodeClass>({
              // Keep in same order as NodeClass values
              REF_CLASS("Switch", NC_SWITCH),
              REF_CLASS("_SwitchN", NC_SWITCH),
              REF_CLASS("Merge", NC_MERGE),
              REF_CLASS("Enter", NC_ENTER),
              REF_CLASS("Exit", NC_EXIT),
              REF_CLASS("NextIteration", NC_NEXT_ITERATION),
              {"LoopCond", NC_LOOP_COND},
              {"ControlTrigger", NC_CONTROL_TRIGGER},
              {"_Send", NC_SEND},
              {"_HostSend", NC_HOST_SEND},
              {"_Recv", NC_RECV},
              {"_HostRecv", NC_HOST_RECV},
              {"Const", NC_CONSTANT},
              {"HostConst", NC_CONSTANT},
              {"Variable", NC_VARIABLE},
              {"VariableV2", NC_VARIABLE},
              REF_CLASS("Identity", NC_IDENTITY),
              {"GetSessionHandle", NC_GET_SESSION_HANDLE},
              {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
              {"GetSessionTensor", NC_GET_SESSION_TENSOR},
              {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
              {"Size", NC_METADATA},
              {"Shape", NC_METADATA},
              {"Rank", NC_METADATA},
              {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
              {"CollectiveReduce", NC_COLLECTIVE},
              {"CollectiveBcastSend", NC_COLLECTIVE},
              {"CollectiveBcastRecv", NC_COLLECTIVE},
              {"CollectiveGather", NC_COLLECTIVE},
              {"FakeParam", NC_FAKE_PARAM},
              {"PartitionedCall", NC_PARTITIONED_CALL},
              {"StatefulPartitionedCall", NC_PARTITIONED_CALL},
              {"SymbolicGradient", NC_SYMBOLIC_GRADIENT},
              {"If", NC_IF},
              {"StatelessIf", NC_IF},
              {"While", NC_WHILE},
              {"StatelessWhile", NC_WHILE},
              {"Case", NC_CASE},
              {"StatelessCase", NC_CASE},
              // Not using the constants defined in FunctionLibraryDefinition
              // for the four ops below because the Android inference library
              // does not link tf.function related files.
              {"_Arg", NC_ARG},
              {"_DeviceArg", NC_ARG},
              {"_Retval", NC_RETVAL},
              {"_DeviceRetval", NC_RETVAL},
              {"_XlaMerge", NC_MERGE},
          });
#undef REF_CLASS

  auto it = kNodeClassTable->find(ts);
  if (it != kNodeClassTable->end()) {
    return it->second;
  } else {
    return NC_OTHER;
  }
}
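
// Illustrative lookups (not part of the table above): because of the
// REF_CLASS expansion, GetNodeClassForOp("Switch") and
// GetNodeClassForOp("RefSwitch") both return NC_SWITCH, while an op type
// missing from the table, say a custom "MyOp", falls through to NC_OTHER.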

std::string Node::DebugString() const {
  std::string ret = strings::StrCat("{name:'", name(), "' id:", id_);
  if (IsSource()) {
    strings::StrAppend(&ret, " source}");
  } else if (IsSink()) {
    strings::StrAppend(&ret, " sink}");
  } else {
    strings::StrAppend(&ret, " op device:", "{requested: '", requested_device(),
                       "', assigned: '", assigned_device_name(), "'}", " def:{",
                       SummarizeNode(*this), "}}");
  }
  return ret;
}
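
// Sample output for an op node, with illustrative values:
//   {name:'add_1' id:7 op device:{requested: '/CPU:0', assigned: ''}
//    def:{...summary from SummarizeNode()...}}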

Node::Node()
    : id_(-1),
      cost_id_(-1),
      class_(NC_UNINITIALIZED),
      props_(nullptr),
      assigned_device_name_index_(0),
      while_ctx_(nullptr) {}

void Node::Initialize(int id, int cost_id,
                      std::shared_ptr<NodeProperties> props,
                      Node::NodeClass node_class) {
  DCHECK_EQ(id_, -1);
  DCHECK(in_edges_.empty());
  DCHECK(out_edges_.empty());
  id_ = id;
  cost_id_ = cost_id;

  props_ = std::move(props);
  class_ = node_class;
}

void Node::Clear() {
  in_edges_.clear();
  out_edges_.clear();
  id_ = -1;
  cost_id_ = -1;
  class_ = NC_UNINITIALIZED;
  props_.reset();
  assigned_device_name_index_ = 0;
}

void Node::UpdateProperties() {
  DataTypeVector inputs;
  DataTypeVector outputs;
  Status status =
      InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
  if (!status.ok()) {
    LOG(ERROR) << "Failed to update node: " << status;
    return;
  }
  if (props_->input_types != inputs || props_->output_types != outputs) {
    if (TF_PREDICT_TRUE(props_.use_count() == 1)) {
      props_->input_types = inputs;
      props_->input_types_slice = props_->input_types;
      props_->output_types = outputs;
      props_->output_types_slice = props_->output_types;
    } else {
      props_ = std::make_shared<NodeProperties>(
          props_->op_def, std::move(props_->node_def), inputs, outputs);
    }
  }
}
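
// Note on the branch above: when an attr edit changes a node's type
// signature, the cached input/output types are recomputed. They are updated
// in place only when this node is the sole owner of its NodeProperties;
// otherwise a fresh copy is made so other nodes sharing the properties are
// unaffected.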

const std::string& Node::name() const { return props_->node_def.name(); }
const std::string& Node::type_string() const { return props_->node_def.op(); }
const NodeDef& Node::def() const { return props_->node_def; }
const OpDef& Node::op_def() const { return *props_->op_def; }

NodeDef* Node::mutable_def() { return &props_->node_def; }

int32 Node::num_inputs() const { return props_->input_types.size(); }
DataType Node::input_type(int32_t i) const { return props_->input_types[i]; }
const DataTypeVector& Node::input_types() const { return props_->input_types; }

int32 Node::num_outputs() const { return props_->output_types.size(); }
DataType Node::output_type(int32_t o) const { return props_->output_types[o]; }
const DataTypeVector& Node::output_types() const {
  return props_->output_types;
}

AttrSlice Node::attrs() const { return AttrSlice(def()); }

const protobuf::RepeatedPtrField<std::string>& Node::requested_inputs() const {
  return def().input();
}

const std::string& Node::requested_device() const { return def().device(); }

gtl::iterator_range<NeighborIter> Node::out_nodes() const {
  return gtl::make_range(NeighborIter(out_edges_.begin(), false),
                         NeighborIter(out_edges_.end(), false));
}

gtl::iterator_range<NeighborIter> Node::in_nodes() const {
  return gtl::make_range(NeighborIter(in_edges_.begin(), true),
                         NeighborIter(in_edges_.end(), true));
}

void Node::MaybeCopyOnWrite() {
  // NodeProperties may be shared between Nodes. Make a copy if so.
  if (props_.use_count() != 1) {
    props_ = std::make_shared<NodeProperties>(*props_);
  }
}
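
// Copy-on-write in practice (hypothetical sketch): after Graph::CopyNode, the
// original and the copy share one NodeProperties, so the first mutating call
// on either node, e.g.
//   copy->set_name("renamed");  // calls MaybeCopyOnWrite() internally
// clones the properties before editing, leaving the other node untouched.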

AttrValue* Node::AddAttrHelper(const std::string& name) {
  MaybeCopyOnWrite();
  return &((*props_->node_def.mutable_attr())[name]);
}

void Node::ClearAttr(const std::string& name) {
  MaybeCopyOnWrite();
  (*props_->node_def.mutable_attr()).erase(name);
}

void Node::set_name(std::string name) {
  MaybeCopyOnWrite();
  props_->node_def.set_name(std::move(name));
}

void Node::set_requested_device(const std::string& device) {
  MaybeCopyOnWrite();
  props_->node_def.set_device(device);
}

void Node::set_original_node_names(const std::vector<std::string>& names) {
  MaybeCopyOnWrite();
  props_->node_def.mutable_experimental_debug_info()
      ->clear_original_node_names();
  if (!names.empty()) {
    *props_->node_def.mutable_experimental_debug_info()
         ->mutable_original_node_names() = {names.begin(), names.end()};
  }
}

Status Node::input_edge(int idx, const Edge** e) const {
  if (idx < 0 || idx >= num_inputs()) {
    return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ",
                                   name(), " only has ", num_inputs(),
                                   " inputs.");
  }

  // This does a linear search over the edges. In the common case,
  // the number of elements is small enough that this search isn't
  // expensive. Should it become a bottleneck, one can make an
  // optimization where, if the number of edges is small, we use
  // linear iteration, and if the number of edges is large, we perform
  // an indexing step during construction that keeps an array of Edges
  // indexed by pointer. This would keep the size of each Node small
  // in the common case but make this function faster when the number
  // of edges is large.
  for (const Edge* edge : in_edges()) {
    if (edge->dst_input() == idx) {
      *e = edge;
      return Status::OK();
    }
  }

  return errors::NotFound("Could not find input edge ", idx, " for ", name());
}
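
// Illustrative usage:
//   const Edge* e;
//   TF_RETURN_IF_ERROR(node->input_edge(0, &e));
//   const Node* producer = e->src();  // the node feeding input 0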

// Returns a vector of the non-control input edges to a node, indexed by the
// destination input slot.
Status Node::input_edges(std::vector<const Edge*>* input_edges) const {
  input_edges->clear();
  input_edges->resize(num_inputs(), nullptr);

  for (const Edge* edge : in_edges()) {
    if (edge->IsControlEdge()) continue;
    if (edge->dst_input() < 0 || edge->dst_input() >= num_inputs()) {
      return errors::Internal("Invalid edge input number ", edge->dst_input());
    }
    if ((*input_edges)[edge->dst_input()] != nullptr) {
      return errors::Internal("Duplicate edge input number: ",
                              edge->dst_input());
    }
    (*input_edges)[edge->dst_input()] = edge;
  }

  for (int i = 0; i < num_inputs(); ++i) {
    if ((*input_edges)[i] == nullptr) {
      return errors::InvalidArgument("Missing edge input number: ", i);
    }
  }
  return Status::OK();
}

Status Node::input_node(int idx, Node** n) const {
  const Edge* e;
  TF_RETURN_IF_ERROR(input_edge(idx, &e));
  if (e == nullptr) {
    *n = nullptr;
  } else {
    *n = e->src();
  }
  return Status::OK();
}

Status Node::input_node(int idx, const Node** const_n) const {
  Node* n;
  TF_RETURN_IF_ERROR(input_node(idx, &n));
  *const_n = n;
  return Status::OK();
}

Status Node::input_tensor(int idx, OutputTensor* t) const {
  const Edge* e;
  TF_RETURN_IF_ERROR(input_edge(idx, &e));
  DCHECK(e != nullptr);
  *t = OutputTensor(e->src(), e->src_output());
  return Status::OK();
}

// NodeDebugInfo

NodeDebugInfo::NodeDebugInfo(const Node& n) : NodeDebugInfo(n.def()) {}
NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef)
    : NodeDebugInfo(ndef.name(), ndef.has_experimental_debug_info(),
                    ndef.experimental_debug_info()) {}
NodeDebugInfo::NodeDebugInfo(
    StringPiece node_name, bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info)
    : name(node_name) {
  if (has_experimental_debug_info) {
    const auto& names = experimental_debug_info.original_node_names();
    original_node_names.assign(names.begin(), names.end());
  }
}

// InputTensor

bool InputTensor::operator==(const InputTensor& other) const {
  return node == other.node && index == other.index;
}

uint64 InputTensor::Hash::operator()(InputTensor const& s) const {
  return Hash64Combine(std::hash<const Node*>()(s.node),
                       std::hash<int>()(s.index));
}

// OutputTensor

bool OutputTensor::operator==(const OutputTensor& other) const {
  return node == other.node && index == other.index;
}

uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const {
  return Hash64Combine(std::hash<const Node*>()(s.node),
                       std::hash<int>()(s.index));
}

// Graph

Graph::Graph(const OpRegistryInterface* ops)
    : ops_(ops, FunctionDefLibrary()),
      versions_(new VersionDef),
      arena_(8 << 10 /* 8kB */) {
  versions_->set_producer(TF_GRAPH_DEF_VERSION);
  versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);

  // Initialize the name interning table for assigned_device_name.
  device_names_.push_back("");
  DCHECK_EQ(0, InternDeviceName(""));

  // Source and sink have no endpoints, just control edges.
  NodeDef def;
  def.set_name("_SOURCE");
  def.set_op("NoOp");
  Status status;
  Node* source = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(source->id(), kSourceId);

  def.set_name("_SINK");
  Node* sink = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(sink->id(), kSinkId);

  AddControlEdge(source, sink);
}
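
// A freshly constructed Graph therefore contains exactly two nodes, _SOURCE
// (id kSourceId) and _SINK (id kSinkId), joined by a single control edge; op
// nodes added afterwards receive larger ids.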

Graph::Graph(const FunctionLibraryDefinition& flib_def)
    : Graph(flib_def.default_registry()) {
  // Need a new-enough consumer to support the functions we add to the graph.
  if (flib_def.num_functions() > 0 && versions_->min_consumer() < 12) {
    versions_->set_min_consumer(12);
  }
  Status s = ops_.AddLibrary(flib_def);
  CHECK(s.ok()) << s.error_message();
}

Graph::~Graph() {
  // Manually call the destructors for all the Nodes we constructed using
  // placement new.
  for (Node* node : nodes_) {
    if (node != nullptr) {
      node->~Node();
    }
  }
  for (Node* node : free_nodes_) {
    node->~Node();
  }
  // Edges have no destructor, and we arena-allocated them, so no need to
  // destroy them.
}

std::unique_ptr<Graph> Graph::Clone() {
  std::unique_ptr<Graph> new_graph(new Graph(flib_def()));
  new_graph->Copy(*this);
  return new_graph;
}

const VersionDef& Graph::versions() const { return *versions_; }
void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }

void Graph::Copy(const Graph& src) {
  SetConstructionContext(src.GetConstructionContextInternal());
  for (Node* n : nodes()) {
    CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
  }

  // Copy GraphDef versions
  set_versions(src.versions());

  // Copy the nodes.
  // "Node in src" -> "Node in *dest"
  gtl::FlatMap<const Node*, Node*> node_map;
  node_map.reserve(src.num_nodes());
  node_map[src.source_node()] = source_node();
  node_map[src.sink_node()] = sink_node();
  for (Node* n : src.op_nodes()) {
    auto copy = CopyNode(n);
    copy->in_edges_.reserve(n->in_edges().size());
    copy->out_edges_.reserve(n->out_edges().size());
    node_map[n] = copy;
  }

  // Copy the edges
  edges_.reserve(src.num_edges());
  for (const Edge* e : src.edges()) {
    Node* src_copy = node_map[e->src()];
    Node* dst_copy = node_map[e->dst()];
    AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
  }

  types_ = src.types_;
  node_name_to_out_type_ = src.node_name_to_out_type_;
}

Node* Graph::AddNode(NodeDef node_def, Status* status) {
  const OpRegistrationData* op_reg_data;
  status->Update(ops_.LookUp(node_def.op(), &op_reg_data));
  if (!status->ok()) return nullptr;

  DataTypeVector inputs;
  DataTypeVector outputs;
  status->Update(
      InOutTypesForNode(node_def, op_reg_data->op_def, &inputs, &outputs));
  if (!status->ok()) {
    *status = AttachDef(*status, node_def);
    return nullptr;
  }

  Node::NodeClass node_class = op_reg_data->is_function_op
                                   ? Node::NC_FUNCTION_OP
                                   : Node::GetNodeClassForOp(node_def.op());

  Node* node = AllocateNode(
      std::make_shared<NodeProperties>(&op_reg_data->op_def,
                                       std::move(node_def), inputs, outputs),
      nullptr, node_class);
  return node;
}
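
// Minimal sketch of adding a node (op and name are illustrative):
//   NodeDef def;
//   def.set_name("my_noop");
//   def.set_op("NoOp");
//   Status status;
//   Node* n = graph.AddNode(std::move(def), &status);
//   TF_CHECK_OK(status);  // AddNode returns nullptr and sets status on error.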

Node* Graph::CopyNode(const Node* node) {
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());
  Node* copy = AllocateNode(node->props_, node, node->class_);
  copy->set_assigned_device_name(node->assigned_device_name());

  // Since the OpDef of a function may be owned by the Graph that owns 'node',
  // relookup the OpDef in the target graph. If it differs, then clone the
  // node properties with the updated OpDef.
  const OpDef* op_def;
  TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &op_def));
  if (op_def != node->props_->op_def) {
    copy->MaybeCopyOnWrite();
    copy->props_->op_def = op_def;
  }
  copy->SetStackTrace(node->GetStackTrace());

  return copy;
}

void Graph::RemoveNode(Node* node) {
  TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());

  // Remove any edges involving this node.
  for (const Edge* e : node->in_edges_) {
    CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
    edges_[e->id_] = nullptr;
    RecycleEdge(e);
    --num_edges_;
  }
  node->in_edges_.clear();
  for (const Edge* e : node->out_edges_) {
    CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
    edges_[e->id_] = nullptr;
    RecycleEdge(e);
    --num_edges_;
  }
  node->out_edges_.clear();
  ReleaseNode(node);
}

const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
  TF_DCHECK_OK(IsValidNode(source)) << source->DebugString();
  TF_DCHECK_OK(IsValidNode(dest)) << dest->DebugString();

  // source/sink must only be linked via control slots, and
  // control slots must only be linked to control slots.
  if (source == source_node() || dest == sink_node() || x == kControlSlot ||
      y == kControlSlot) {
    DCHECK_EQ(x, kControlSlot) << source->DebugString();
    DCHECK_EQ(y, kControlSlot) << dest->DebugString();
  }

  Edge* e = nullptr;
  if (free_edges_.empty()) {
    e = new (arena_.Alloc(sizeof(Edge))) Edge;  // placement new
  } else {
    e = free_edges_.back();
    free_edges_.pop_back();
  }
  e->id_ = edges_.size();
  e->src_ = source;
  e->dst_ = dest;
  e->src_output_ = x;
  e->dst_input_ = y;
  CHECK(source->out_edges_.insert(e).second);
  CHECK(dest->in_edges_.insert(e).second);
  edges_.push_back(e);
  ++num_edges_;
  return e;
}
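
// For two hypothetical op nodes a and b, AddEdge(a, 0, b, 1) connects output
// slot 0 of a to input slot 1 of b, while AddEdge(a, Graph::kControlSlot, b,
// Graph::kControlSlot) adds a control dependency; mixing a data slot with a
// control slot on one edge trips the DCHECKs above.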

void Graph::RemoveEdge(const Edge* e) {
  TF_DCHECK_OK(IsValidNode(e->src_)) << e->src_->DebugString();
  TF_DCHECK_OK(IsValidNode(e->dst_)) << e->dst_->DebugString();
  CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
  CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
  CHECK_EQ(e, edges_[e->id_]);
  CHECK_GT(num_edges_, 0);

  edges_[e->id_] = nullptr;
  RecycleEdge(e);
  --num_edges_;
}

void Graph::RecycleEdge(const Edge* e) {
  free_edges_.push_back(const_cast<Edge*>(e));
}

const Edge* Graph::AddControlEdge(Node* source, Node* dest,
                                  bool allow_duplicates) {
  if (!allow_duplicates) {
    for (const Edge* edge : dest->in_edges()) {
      if (edge->IsControlEdge() && edge->src() == source) {
        // The requested edge already exists.
        return nullptr;
      }
    }
  }
  // Modify dest's NodeDef if necessary.
  if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) {
    // Check if this input is already in dest's NodeDef.
    const std::string new_input = strings::StrCat("^", source->name());
    bool input_exists = false;
    for (const std::string& input : dest->props_->node_def.input()) {
      if (input == new_input) {
        input_exists = true;
        break;
      }
    }
    if (!input_exists) {
      dest->MaybeCopyOnWrite();
      dest->props_->node_def.add_input(new_input);
    }
  }
  return AddEdge(source, kControlSlot, dest, kControlSlot);
}
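
// Side effect worth noting: for an edge between two op nodes, dest's NodeDef
// also gains a "^<source-name>" entry in its input list, keeping the proto
// representation in sync with the in-memory control edge.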

void Graph::RemoveControlEdge(const Edge* e) {
  if (!e->src_->IsSource() && !e->dst_->IsSink()) {
    e->dst_->MaybeCopyOnWrite();
    std::string e_src_name = strings::StrCat("^", e->src_->name());
    auto* inputs = e->dst_->props_->node_def.mutable_input();
    for (auto it = inputs->begin(); it != inputs->end(); ++it) {
      if (*it == e_src_name) {
        inputs->erase(it);
        break;
      }
    }
  }
  RemoveEdge(e);
}

namespace {
const Edge* FindEdge(const Node* dst, int index) {
  for (const Edge* e : dst->in_edges()) {
    if (e->dst_input() == index) return e;
  }
  return nullptr;
}
}  // namespace

Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
                         int dst_index) {
  TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
  TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
  const Edge* e = FindEdge(dst, dst_index);
  if (e == nullptr) {
    return errors::InvalidArgument("Couldn't find edge to ",
                                   FormatNodeForError(*dst));
  }
  RemoveEdge(e);
  AddEdge(new_src, new_src_index, dst, dst_index);
  dst->MaybeCopyOnWrite();
  (*dst->props_->node_def.mutable_input())[dst_index] =
      strings::StrCat(new_src->name(), ":", new_src_index);
  return Status::OK();
}

Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) {
  if (!dst->IsWhileNode()) {
    return errors::Internal(
        "dst argument to AddWhileInputHack should be a While op, got: ",
        dst->DebugString());
  }
  TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
  // Find the current number of data inputs. We'll add the new edge to the next
  // missing data input.
  int dst_index = 0;
  for (const Edge* edge : dst->in_edges()) {
    if (edge->IsControlEdge()) continue;
    ++dst_index;
  }
  TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
  AddEdge(new_src, new_src_index, dst, dst_index);
  dst->MaybeCopyOnWrite();
  dst->props_->node_def.add_input(
      strings::StrCat(new_src->name(), ":", new_src_index));
  return Status::OK();
}

Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
  // Need a new-enough consumer to support the functions we add to the graph.
  if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
    versions_->set_min_consumer(12);
  }
  return ops_.AddLibrary(fdef_lib);
}

namespace {

void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
  if (src_slot == Graph::kControlSlot) {
    dst->add_input(strings::StrCat("^", src_name));
  } else if (src_slot == 0) {
    dst->add_input(src_name.data(), src_name.size());
  } else {
    dst->add_input(strings::StrCat(src_name, ":", src_slot));
  }
}

}  // namespace
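
// AddInput renders the three GraphDef input spellings (slot 2 is just an
// example index):
//   src_slot == Graph::kControlSlot  ->  "^src"   (control input)
//   src_slot == 0                    ->  "src"    (shorthand for output 0)
//   src_slot == 2                    ->  "src:2"  (explicit output index)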

void Graph::ToGraphDef(GraphDef* graph_def) const {
  ToGraphDefSubRange(graph_def, 0);
}

GraphDef Graph::ToGraphDefDebug() const {
  GraphDef ret;
  ToGraphDef(&ret);
  return ret;
}

void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
  graph_def->Clear();
  *graph_def->mutable_versions() = versions();
  *graph_def->mutable_library() = ops_.ToProto();

  graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id));

  std::vector<const Edge*> inputs;  // Construct this outside the loop for
                                    // speed.
  for (auto id = from_node_id; id < num_node_ids(); ++id) {
    const Node* node = FindNodeId(id);
    if (node == nullptr || !node->IsOp()) continue;
    NodeDef* node_def = graph_def->add_node();
    *node_def = node->def();

    // Use the node's assigned device, if any, instead of the device requested
    // in the NodeDef.
    if (!node->assigned_device_name().empty()) {
      node_def->set_device(node->assigned_device_name());
    }

    // Get the inputs for this Node. We make sure control inputs are
    // after data inputs, as required by GraphDef.
    inputs.clear();
    inputs.resize(node->num_inputs(), nullptr);
    for (const Edge* edge : node->in_edges()) {
      if (edge->IsControlEdge()) {
        inputs.push_back(edge);
      } else {
        DCHECK(edge->dst_input() < inputs.size())
            << "Edge " << edge->DebugString()
            << " is overflowing the expected number of inputs ("
            << node->num_inputs() << ") for node " << node->DebugString();
        CHECK(inputs[edge->dst_input()] == nullptr)
            << "Edge " << edge->src()->name() << "->" << edge->dst()->name()
            << " conflicts with pre-existing input edge "
            << inputs[edge->dst_input()]->src()->name() << "->"
            << inputs[edge->dst_input()]->dst()->name();

        inputs[edge->dst_input()] = edge;
      }
    }
    // Sort the control inputs for more predictable serialization.
    std::sort(inputs.begin() + node->num_inputs(), inputs.end(),
              [](const Edge* a, const Edge* b) -> bool {
                return a->src()->name() < b->src()->name();
              });
    node_def->clear_input();
    node_def->mutable_input()->Reserve(inputs.size());

    for (size_t i = 0; i < inputs.size(); ++i) {
      const Edge* edge = inputs[i];
      if (edge == nullptr) {
        if (i < node->requested_inputs().size()) {
          node_def->add_input(node->requested_inputs()[i]);
        } else {
          node_def->add_input("");
        }
      } else {
        const Node* src = edge->src();
        if (!src->IsOp()) continue;
        AddInput(node_def, src->name(), edge->src_output());
      }
    }
  }
}

std::string Graph::NewName(StringPiece prefix) {
  return strings::StrCat(prefix, "/_", name_counter_++);
}
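
// Successive calls yield distinct names via the graph-wide counter, e.g. (on
// a fresh graph) NewName("gradients") -> "gradients/_0" and then
// "gradients/_1". Uniqueness relies on the counter alone, not on checking
// existing node names.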

Status Graph::IsValidNode(const Node* node) const {
  if (node == nullptr) {
    return errors::InvalidArgument("Node is null");
  }
  const int id = node->id();
  if (id < 0) {
    return errors::InvalidArgument("node id ", id, " is less than zero");
  }
  if (static_cast<size_t>(id) >= nodes_.size()) {
    return errors::InvalidArgument(
        "node id ", id, " is >= the number of nodes in graph ", nodes_.size());
  }
  if (nodes_[id] != node) {
    return errors::InvalidArgument("Node with id ", id,
                                   " is different from the passed in node. "
                                   "Does it belong to a different graph?");
  }
  return Status::OK();
}

Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
  TF_RETURN_IF_ERROR(IsValidNode(node));
  if (idx >= node->num_outputs() || idx < 0) {
    return errors::OutOfRange("Node '", node->name(), "' (type: '",
                              node->op_def().name(),
                              "', num of outputs: ", node->num_outputs(),
                              ") does not have output ", idx);
  }
  return Status::OK();
}

Status Graph::IsValidInputTensor(const Node* node, int idx) const {
  TF_RETURN_IF_ERROR(IsValidNode(node));
  if (idx >= node->num_inputs() || idx < 0) {
    return errors::OutOfRange("Node '", node->name(), "' (type: '",
                              node->op_def().name(),
                              "', num of inputs: ", node->num_inputs(),
                              ") does not have input ", idx);
  }
  return Status::OK();
}

Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
                          const Node* cost_node, Node::NodeClass node_class) {
  Node* node = nullptr;
  if (free_nodes_.empty()) {
    node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new
  } else {
    node = free_nodes_.back();
    free_nodes_.pop_back();
  }
  node->graph_ = this;
  const int id = nodes_.size();
  int cost_id = cost_node ? cost_node->cost_id() : id;
  node->Initialize(id, cost_id, std::move(props), node_class);
  nodes_.push_back(node);
  ++num_nodes_;
  return node;
}

void Graph::ReleaseNode(Node* node) {
  TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
  nodes_[node->id()] = nullptr;
  free_nodes_.push_back(node);
  --num_nodes_;
  node->Clear();
}

// Ensures that 'device_name' is present in the device name table, and returns
// the index of that device name. The index is stable, and can be used in
// calls to Node::set_assigned_device_name_index().
int Graph::InternDeviceName(const std::string& device_name) {
  // Special case, very common. Also, this allows us to use a single map
  // lookup below, instead of two. The 'if (index_cell > 0)' test below
  // relies on this check.
  if (device_name.empty()) {
    return 0;
  }

  int& index_cell = device_names_map_[device_name];
  if (index_cell > 0) {
    return index_cell;
  }

  const int index = device_names_map_.size();
  index_cell = index;
  device_names_.push_back(device_name);
  return index;
}
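
// Interning in action (device string is illustrative): the empty name always
// maps to index 0; the first distinct name, say "/device:CPU:0", gets index
// 1, and repeated calls with the same string return the same index without
// growing the table.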

Status Graph::AddWhileContext(StringPiece frame_name,
                              std::vector<Node*> enter_nodes,
                              std::vector<Node*> exit_nodes,
                              OutputTensor cond_output,
                              std::vector<OutputTensor> body_inputs,
                              std::vector<OutputTensor> body_outputs,
                              WhileContext** result) {
  auto pair = while_ctxs_.insert(std::pair<std::string, WhileContext>(
      std::string(frame_name),
      WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
                   cond_output, std::move(body_inputs),
                   std::move(body_outputs))));
  if (!pair.second) {
    *result = nullptr;
    return errors::InvalidArgument("WhileContext with frame name '", frame_name,
                                   "' already exists");
  }
  *result = &pair.first->second;
  return Status::OK();
}

std::unordered_map<std::string, Node*> Graph::BuildNodeNameIndex() const {
  std::unordered_map<std::string, Node*> result;
  for (Node* n : nodes()) {
    result[n->name()] = n;
  }
  return result;
}

void Graph::SetNodeType(StringPiece name, const FullTypeDef& ft) {
  TypeRef t = {std::make_shared<FullTypeDef>(ft)};
  auto ret = types_.emplace(t);
  if (!ret.second) {
    t = *ret.first;
  }

  node_name_to_out_type_.emplace(std::string(name), t);
}

void Graph::NodeType(StringPiece name, FullTypeDef** result) {
  *result = nullptr;
  auto it = node_name_to_out_type_.find(std::string(name));
  if (it == node_name_to_out_type_.end()) {
    return;
  }
  *result = it->second.full_type.get();
}

std::string Edge::DebugString() const {
  return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
                         src_output_, dst_->name().c_str(), dst_input_);
}

}  // namespace tensorflow
918